# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks, Input
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load CIFAR-10 and rescale the pixel values to float32 in [0, 1]
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# One-hot copies of the integer labels (10 classes), used with the
# teacher's categorical cross-entropy loss
y_train_one_hot = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test_one_hot = tf.keras.utils.to_categorical(y_test, num_classes=10)

# Random augmentation applied to training batches
augmentation_settings = dict(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
)
datagen = ImageDataGenerator(**augmentation_settings)
datagen.fit(x_train)

# Teacher Model
def create_teacher_model():
    """Build the teacher: an ImageNet-pretrained EfficientNetB0 backbone with a dense head.

    This script feeds images scaled to [0, 1], while the Keras EfficientNet
    applications embed their own preprocessing and expect raw pixel values
    in [0, 255]. The input is therefore scaled back up inside the model, so
    the pretrained backbone sees the value range its weights were trained on
    and callers can keep passing [0, 1] images.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping 32x32x3 images to a
        10-way softmax output.
    """
    inputs = tf.keras.Input(shape=(32, 32, 3))
    # Undo the external /255 scaling before the backbone's built-in preprocessing.
    x = layers.Rescaling(255.0)(inputs)
    base_model = EfficientNetB0(include_top=False, weights='imagenet', input_shape=(32, 32, 3))
    x = base_model(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(10, activation='softmax')(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

teacher_model = create_teacher_model()
teacher_model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizers.Adam(learning_rate=0.00027679),
    metrics=['accuracy'],
)

# Fit the teacher on augmented batches with one-hot labels; the checkpoint
# callback only overwrites its file when the monitored value improves
teacher_checkpoint = callbacks.ModelCheckpoint('teacher_model_best.h5.keras', save_best_only=True)
teacher_batches = datagen.flow(x_train, y_train_one_hot, batch_size=128)
teacher_history = teacher_model.fit(
    teacher_batches,
    epochs=50,
    validation_data=(x_test, y_test_one_hot),
    callbacks=[teacher_checkpoint],
)

# Also persist the model as it stands after the last epoch
teacher_model.save('teacher_model_final.h5.keras')

# Student Model
def create_student_model(conv1_filters, conv2_filters, conv3_filters, dense_units, dropout_rate):
    """Build the compact CNN used as the student network.

    Args:
        conv1_filters: Number of filters in the first 3x3 convolution.
        conv2_filters: Number of filters in the second 3x3 convolution.
        conv3_filters: Number of filters in the third 3x3 convolution.
        dense_units: Width of the hidden dense layer.
        dropout_rate: Dropout rate applied before the output layer.

    Returns:
        An uncompiled Keras model mapping 32x32x3 images to a 10-way
        softmax output.
    """
    image_in = Input(shape=(32, 32, 3))

    # Two conv -> batch-norm -> 2x2 max-pool stages
    features = image_in
    for n_filters in (conv1_filters, conv2_filters):
        features = layers.Conv2D(n_filters, (3, 3), activation='relu')(features)
        features = layers.BatchNormalization()(features)
        features = layers.MaxPooling2D((2, 2))(features)

    # Last conv stage is collapsed by global average pooling instead of max-pooling
    features = layers.Conv2D(conv3_filters, (3, 3), activation='relu')(features)
    features = layers.BatchNormalization()(features)
    features = layers.GlobalAveragePooling2D()(features)

    # Classification head
    hidden = layers.Dense(dense_units, activation='relu')(features)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.Dropout(dropout_rate)(hidden)
    class_probs = layers.Dense(10, activation='softmax')(hidden)

    return models.Model(inputs=image_in, outputs=class_probs)

student_model = create_student_model(
    conv1_filters=64,
    conv2_filters=192,
    conv3_filters=384,
    dense_units=768,
    dropout_rate=0.5,
)

# Integer labels are used directly, hence the sparse loss
student_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=optimizers.Adam(learning_rate=0.00035361),
    metrics=['accuracy'],
)

# Baseline student training (no distillation)
student_callbacks = [
    callbacks.ModelCheckpoint('student_model_best.h5.keras', save_best_only=True),
    callbacks.EarlyStopping(patience=5),
    callbacks.ReduceLROnPlateau(factor=0.5, patience=2),
]
student_batches = datagen.flow(x_train, y_train, batch_size=64)
student_history = student_model.fit(
    student_batches,
    epochs=40,
    validation_data=(x_test, y_test),
    callbacks=student_callbacks,
)

# Persist the model as it stands when training stops
student_model.save('student_model_final.h5.keras')

# Distillation
def distillation_loss(y_true, y_pred, teacher_logits, temperature=3, alpha=0.1):
    """Combine hard-label cross-entropy with a temperature-softened teacher-matching term.

    Both models in this file end in a softmax layer, so ``y_pred`` and
    ``teacher_logits`` arrive as class probabilities despite the parameter
    name. Applying ``softmax(p / T)`` directly to probabilities squeezes
    values from [0, 1] into almost-uniform distributions, so the log
    probabilities are recovered first. ``softmax(log(p) / T)`` equals
    ``p**(1/T)`` renormalised, which is the temperature-scaled distribution
    of the underlying logits.

    Args:
        y_true: Integer class labels, shape ``(batch,)`` or ``(batch, 1)``.
        y_pred: Student softmax probabilities, shape ``(batch, classes)``.
        teacher_logits: Teacher softmax probabilities, same shape as ``y_pred``.
        temperature: Softening temperature ``T``; larger values flatten both
            distributions.
        alpha: Weight of the hard-label term; the soft term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * hard_loss + (1 - alpha) * T**2 * soft_loss``.
    """
    labels = tf.reshape(tf.cast(y_true, tf.int64), [-1])
    student_probs = tf.cast(y_pred, tf.float32)
    teacher_probs = tf.cast(teacher_logits, tf.float32)

    # Clip before the log so zero probabilities do not produce -inf.
    eps = tf.keras.backend.epsilon()
    student_log_probs = tf.math.log(tf.clip_by_value(student_probs, eps, 1.0))
    teacher_log_probs = tf.math.log(tf.clip_by_value(teacher_probs, eps, 1.0))

    soft_targets = tf.nn.softmax(teacher_log_probs / temperature)
    soft_log_prob = tf.nn.log_softmax(student_log_probs / temperature)

    # Cross-entropy between the softened teacher and student distributions.
    soft_targets_loss = tf.reduce_mean(
        -tf.reduce_sum(soft_targets * soft_log_prob, axis=-1)
    )
    # Gradients of the soft term shrink as 1/T**2; rescale so the balance set
    # by alpha does not depend on the chosen temperature.
    soft_targets_loss *= temperature ** 2

    student_loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, student_probs)
    )

    return alpha * student_loss + (1 - alpha) * soft_targets_loss

# Reload the checkpointed teacher; it provides the soft targets during distillation
best_teacher_model = tf.keras.models.load_model('teacher_model_best.h5.keras')

# Fresh, untrained student to be fitted with the distillation loss
distilled_student_config = dict(
    conv1_filters=64,
    conv2_filters=192,
    conv3_filters=384,
    dense_units=768,
    dropout_rate=0.5,
)
distilled_student_model = create_student_model(**distilled_student_config)

# Custom training step
@tf.function
def train_step(x, y):
    """Run one distillation update of the student on a single batch.

    Uses the module-level ``best_teacher_model``, ``distilled_student_model``
    and ``optimizer``. Only the student's trainable variables are updated.

    Args:
        x: Batch of input images.
        y: Integer class labels for ``x``.

    Returns:
        The scalar distillation loss for this batch.
    """
    # The teacher is only a source of targets, so its forward pass stays off
    # the gradient tape and no gradient can flow into it.
    teacher_probs = tf.stop_gradient(best_teacher_model(x, training=False))

    with tf.GradientTape() as tape:
        student_probs = distilled_student_model(x, training=True)
        loss = distillation_loss(y, student_probs, teacher_probs)

    student_variables = distilled_student_model.trainable_variables
    gradients = tape.gradient(loss, student_variables)
    optimizer.apply_gradients(zip(gradients, student_variables))

    return loss

# Distillation training loop
optimizer = optimizers.Adam(learning_rate=0.00035361)
num_epochs = 50
batch_size = 64

# Number of batches in one pass over the training set (last batch may be partial).
steps_per_epoch = (len(x_train) + batch_size - 1) // batch_size

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Per-batch means are weighted by batch size so the smaller final batch
    # does not count as much as a full one in the epoch average.
    train_loss_sum = 0.0
    train_samples = 0
    train_batches = datagen.flow(x_train, y_train, batch_size=batch_size)
    for step, (x_batch, y_batch) in enumerate(train_batches, start=1):
        batch_loss = train_step(x_batch, y_batch)
        train_loss_sum += float(batch_loss) * len(x_batch)
        train_samples += len(x_batch)
        # Stop explicitly after one pass over the training data.
        if step >= steps_per_epoch:
            break

    train_loss = train_loss_sum / train_samples
    print(f"Training loss: {train_loss:.4f}")

    # Validation: accumulate per-sample totals, then divide by the sample count.
    val_loss_sum = 0.0
    val_correct = 0.0
    for start in range(0, len(x_test), batch_size):
        x_val_batch = x_test[start:start + batch_size]
        y_val_batch = tf.cast(y_test[start:start + batch_size], tf.int64)
        val_probs = distilled_student_model(x_val_batch, training=False)
        teacher_val_probs = best_teacher_model(x_val_batch, training=False)

        batch_val_loss = distillation_loss(y_val_batch, val_probs, teacher_val_probs)
        val_loss_sum += float(batch_val_loss) * len(x_val_batch)

        # Flatten labels explicitly; a bare squeeze would also drop the batch
        # axis for a batch of one.
        hits = tf.equal(tf.argmax(val_probs, axis=1), tf.reshape(y_val_batch, [-1]))
        val_correct += float(tf.reduce_sum(tf.cast(hits, tf.float32)))

    val_loss = val_loss_sum / len(x_test)
    val_accuracy = val_correct / len(x_test)
    print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")

# Save the distilled student model
distilled_student_model.save('distilled_student_model_final.h5.keras')

# Evaluate all models
def evaluate_model(model, name):
    """Print cross-entropy loss and accuracy of ``model`` on the module-level test set.

    The metrics are computed from ``model.predict`` rather than
    ``model.evaluate``. ``evaluate`` needs a compiled model and labels in the
    format of its compiled loss, but the models passed here differ: the
    teacher is compiled with one-hot categorical cross-entropy, the baseline
    student with the sparse variant, and the distilled student is never
    compiled. Scoring every model against the integer labels in ``y_test``
    keeps the results comparable.

    Args:
        model: Keras model whose output is a ``(samples, classes)`` array of
            class probabilities.
        name: Label used in the printed summary line.
    """
    probabilities = model.predict(x_test)
    labels = tf.reshape(tf.cast(y_test, tf.int64), [-1])

    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, probabilities)
    )
    predicted = tf.argmax(probabilities, axis=1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, labels), tf.float32))

    print(f'{name} - Test loss: {float(loss):.4f}, Test accuracy: {float(accuracy):.4f}')

# Score the checkpointed teacher and baseline student, then the in-memory distilled student
teacher_for_eval = tf.keras.models.load_model('teacher_model_best.h5.keras')
evaluate_model(teacher_for_eval, 'Teacher Model')

student_for_eval = tf.keras.models.load_model('student_model_best.h5.keras')
evaluate_model(student_for_eval, 'Student Model')

evaluate_model(distilled_student_model, 'Distilled Student Model')
