# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import keras_tuner as kt

# Load CIFAR-10 and scale pixels to [0, 1].
# Cast to float32 before dividing: dividing the uint8 arrays by a Python float
# would promote them to float64 and double the memory held by both image arrays.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Enhanced data augmentation: random geometric transforms applied to each
# batch produced by datagen.flow(...).
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
# NOTE: datagen.fit(x_train) is intentionally not called. It only computes
# statistics for featurewise_center / featurewise_std_normalization /
# zca_whitening, none of which are enabled above.


# Define the teacher model
def create_teacher_model():
    """Build the teacher: a frozen ImageNet EfficientNetB0 with a new 10-way head.

    Returns:
        An uncompiled Keras model that takes 32x32x3 images scaled to [0, 1]
        and outputs softmax probabilities over 10 classes.
    """
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
    base_model.trainable = False

    inputs = layers.Input(shape=(32, 32, 3))
    # Keras EfficientNet models normalise their input internally and expect raw
    # pixel values in [0, 255]. The rest of this script feeds images scaled to
    # [0, 1], so undo that scaling before the backbone sees them.
    x = layers.Rescaling(255.0)(inputs)
    # training=False keeps the frozen backbone (BatchNormalization, dropout)
    # in inference mode while the new head is being trained.
    x = base_model(x, training=False)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1028, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    output = layers.Dense(10, activation='softmax')(x)

    # NOTE: the backbone is now a nested sub-model, so weight files saved from
    # the earlier flat layout of this model will not load into it.
    return models.Model(inputs=inputs, outputs=output)

teacher_model = create_teacher_model()
optimizer = optimizers.Adam(learning_rate=1e-3)
teacher_model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Callbacks for teacher model
# NOTE(review): the test split doubles as the validation split below, so the
# "test" numbers printed at the end are not from held-out data. Consider
# carving a validation set out of x_train.
checkpoint_callback_teacher = callbacks.ModelCheckpoint(
    'teacher_model.weights.h5',
    save_weights_only=True,
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1
)
early_stopping_teacher = callbacks.EarlyStopping(monitor='val_loss', patience=10)
lr_scheduler_teacher = callbacks.ReduceLROnPlateau(factor=0.5, patience=5)

# Train the teacher model
teacher_model.fit(
    datagen.flow(x_train, y_train, batch_size=256),
    epochs=50,
    validation_data=(x_test, y_test),
    callbacks=[checkpoint_callback_teacher, early_stopping_teacher, lr_scheduler_teacher]
)

# fit() leaves the model on its last-epoch weights, which can be up to
# `patience` epochs past the best validation loss. Reload the best checkpoint
# so the evaluation below, and any later use of the teacher, use those weights.
teacher_model.load_weights('teacher_model.weights.h5')

# Evaluate the teacher model
eval_results_teacher = teacher_model.evaluate(x_test, y_test)
print(f'Test loss: {eval_results_teacher[0]}')
print(f'Test accuracy: {eval_results_teacher[1]}')

# Define the student model
def create_student_model(conv1_filters, conv2_filters, conv3_filters, dense_units, dropout_rate):
    """Assemble the small student CNN for 32x32x3 inputs.

    Three 3x3 ReLU convolutions (each followed by batch normalisation, the
    first two also by 2x2 max pooling), global average pooling, one hidden
    dense layer with batch normalisation and dropout, and a 10-way softmax.

    Args:
        conv1_filters: Filter count of the first convolution.
        conv2_filters: Filter count of the second convolution.
        conv3_filters: Filter count of the third convolution.
        dense_units: Width of the hidden dense layer.
        dropout_rate: Dropout rate applied before the output layer.

    Returns:
        The uncompiled Sequential model.
    """
    student = models.Sequential()

    # Convolution stage 1 (declares the input shape)
    student.add(layers.Conv2D(conv1_filters, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    student.add(layers.BatchNormalization())
    student.add(layers.MaxPooling2D((2, 2)))

    # Convolution stage 2
    student.add(layers.Conv2D(conv2_filters, (3, 3), activation='relu'))
    student.add(layers.BatchNormalization())
    student.add(layers.MaxPooling2D((2, 2)))

    # Convolution stage 3, collapsed to one vector per image
    student.add(layers.Conv2D(conv3_filters, (3, 3), activation='relu'))
    student.add(layers.BatchNormalization())
    student.add(layers.GlobalAveragePooling2D())

    # Classifier head
    student.add(layers.Dense(dense_units, activation='relu'))
    student.add(layers.BatchNormalization())
    student.add(layers.Dropout(dropout_rate))
    student.add(layers.Dense(10, activation='softmax'))

    return student

def distillation_loss(y_true, y_pred, teacher_logits, temperature=3, alpha=0.1, from_logits=False):
    """Combine hard-label cross-entropy with a temperature-softened teacher term.

    Args:
        y_true: Integer class labels for the batch.
        y_pred: Student outputs for the batch.
        teacher_logits: Teacher outputs for the same batch. Despite the name,
            they are interpreted the same way as ``y_pred``: as class
            probabilities by default (both models built in this file end in a
            softmax layer), or as raw logits when ``from_logits=True``.
        temperature: Softening temperature T applied to both distributions.
        alpha: Weight of the hard-label term; the soft term gets ``1 - alpha``.
        from_logits: Set to True only if both models output raw logits.

    Returns:
        Scalar tensor ``alpha * hard + (1 - alpha) * T**2 * soft``.
    """
    y_true = tf.cast(y_true, tf.float32)
    student_out = tf.cast(y_pred, tf.float32)
    teacher_out = tf.cast(teacher_logits, tf.float32)

    if from_logits:
        student_scores, teacher_scores = student_out, teacher_out
    else:
        # The inputs are already softmax probabilities. Applying softmax(p / T)
        # to them directly would squash everything towards uniform. log(p)
        # equals the underlying logits up to a per-sample constant, which
        # softmax ignores, so softmax(log(p) / T) is the true softened output.
        eps = 1e-7  # avoids log(0)
        student_scores = tf.math.log(tf.clip_by_value(student_out, eps, 1.0))
        teacher_scores = tf.math.log(tf.clip_by_value(teacher_out, eps, 1.0))

    # Compute soft targets
    soft_targets = tf.nn.softmax(teacher_scores / temperature)
    soft_prob = tf.nn.softmax(student_scores / temperature)

    # Soft targets loss. Its gradients shrink by 1/T**2, so scale by T**2 to
    # keep the soft and hard terms comparable when the temperature changes.
    soft_targets_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(soft_targets, soft_prob)
    ) * (temperature ** 2)

    # Hard targets loss
    student_loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(y_true, student_out, from_logits=from_logits)
    )

    # Combined loss
    return alpha * student_loss + (1 - alpha) * soft_targets_loss

# Define the hyperparameter tuning function
def build_model(hp):
    """Build and compile one candidate student model for Keras Tuner.

    Args:
        hp: The hyperparameter container the tuner passes in; it is queried
            for the three convolution widths, the dense width, the dropout
            rate and the Adam learning rate.

    Returns:
        The student model compiled with sparse categorical cross-entropy and
        an accuracy metric.
    """
    # Architecture hyperparameters, sampled in the same order as declared.
    architecture = {
        'conv1_filters': hp.Int('conv1_filters', 32, 128, step=32),
        'conv2_filters': hp.Int('conv2_filters', 64, 256, step=64),
        'conv3_filters': hp.Int('conv3_filters', 128, 512, step=128),
        'dense_units': hp.Int('dense_units', 128, 512, step=128),
        'dropout_rate': hp.Float('dropout_rate', 0.3, 0.6, step=0.1),
    }
    candidate = create_student_model(**architecture)

    # The learning rate is searched on a log scale between 1e-4 and 1e-2.
    learning_rate = hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')
    candidate.compile(
        optimizer=optimizers.Adam(learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return candidate

# Bayesian-optimisation tuner over the search space declared in build_model.
# Candidates are ranked by validation accuracy; results are written under
# kt_dir/student_model_tuning and any previous results there are overwritten.
# NOTE(review): with max_trials=1 only a single configuration is ever
# evaluated, so no real search takes place. Raise max_trials (and
# num_initial_points) to actually tune.
tuner = kt.BayesianOptimization(
    build_model,
    objective='val_accuracy',
    num_initial_points=1,
    max_trials=1,  # Set to 1 trial
    directory='kt_dir',
    project_name='student_model_tuning',
    overwrite=True
)

# Perform the hyperparameter search on augmented training batches.
# NOTE(review): the test split is used as validation data here, so the
# hyperparameters are selected on the same data that is later reported as
# test accuracy.
tuner.search(
    datagen.flow(x_train, y_train, batch_size=128),
    epochs=50,  # each trial trains for 50 epochs
    validation_data=(x_test, y_test)
)

# Get the best hyperparameters (top-ranked trial)
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

# Build the best model
best_student_model = create_student_model(
    best_hps.get('conv1_filters'),
    best_hps.get('conv2_filters'),
    best_hps.get('conv3_filters'),
    best_hps.get('dense_units'),
    best_hps.get('dropout_rate')
)


# A Keras loss function only receives (y_true, y_pred), and `model.input` is a
# symbolic placeholder rather than the batch being trained on. Teacher
# predictions therefore cannot be produced inside a `compile(loss=...)` lambda.
# This wrapper runs the teacher on every real (augmented) batch instead.
class StudentDistiller(models.Model):
    """Trains `student` against hard labels and the outputs of `teacher`."""

    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

    @property
    def metrics(self):
        # Listing the metrics here lets fit()/evaluate() reset them each epoch.
        return [self.loss_tracker, self.accuracy_metric]

    def call(self, inputs, training=False):
        return self.student(inputs, training=training)

    def _distill_forward(self, images, labels, training):
        # The teacher always runs in inference mode and is never updated.
        teacher_out = self.teacher(images, training=False)
        student_out = self.student(images, training=training)
        return distillation_loss(labels, student_out, teacher_out), student_out

    def _report(self, loss, labels, student_out):
        self.loss_tracker.update_state(loss)
        self.accuracy_metric.update_state(labels, student_out)
        return {metric.name: metric.result() for metric in self.metrics}

    def train_step(self, data):
        images, labels = data[0], data[1]
        with tf.GradientTape() as tape:
            loss, student_out = self._distill_forward(images, labels, training=True)
        # Differentiate with respect to the student only.
        student_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, student_vars)
        self.optimizer.apply_gradients(zip(gradients, student_vars))
        return self._report(loss, labels, student_out)

    def test_step(self, data):
        images, labels = data[0], data[1]
        loss, student_out = self._distill_forward(images, labels, training=False)
        return self._report(loss, labels, student_out)


class SaveBestStudentWeights(callbacks.Callback):
    """Saves only the student's weights whenever val_loss improves."""

    def __init__(self, student, filepath):
        super().__init__()
        self.student = student
        self.filepath = filepath
        self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get('val_loss')
        if current is not None and current < self.best:
            print(f'\nEpoch {epoch + 1}: val_loss improved from {self.best} to {current}, '
                  f'saving student weights to {self.filepath}')
            self.best = current
            self.student.save_weights(self.filepath)


distiller = StudentDistiller(best_student_model, teacher_model)
distiller.compile(optimizer=optimizers.Adam(best_hps.get('learning_rate')))

# Callbacks for student model training
# A stock ModelCheckpoint on the wrapper would also save the teacher's weights,
# so the checkpoint is written from the student directly.
checkpoint_callback_student = SaveBestStudentWeights(best_student_model, 'student_model.weights.h5')
early_stopping_student = callbacks.EarlyStopping(monitor='val_loss', patience=5)
lr_scheduler_student = callbacks.ReduceLROnPlateau(factor=0.5, patience=2)

# Train the student model with knowledge distillation
distiller.fit(
    datagen.flow(x_train, y_train, batch_size=128),
    epochs=50,  # Train the final model for more epochs
    validation_data=(x_test, y_test),
    callbacks=[checkpoint_callback_student, early_stopping_student, lr_scheduler_student]
)

# Return to the best checkpoint (if one was written) before evaluating.
if checkpoint_callback_student.best < float('inf'):
    best_student_model.load_weights('student_model.weights.h5')

# Compile the bare student with the plain classification loss so it can be
# evaluated, and trained further, without the teacher.
best_student_model.compile(
    optimizer=optimizers.Adam(best_hps.get('learning_rate')),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Evaluate the student model (the loss is plain cross-entropy on hard labels)
eval_results = best_student_model.evaluate(x_test, y_test)
print(f'Test loss: {eval_results[0]}')
print(f'Test accuracy: {eval_results[1]}')


# In [None]:
import os

def continuous_train(model, datagen, x_train, y_train, x_test, y_test, model_name, epochs=50, batch_size=128):
    """Resume (or start) training `model`, checkpointing to '<model_name>.weights.h5'.

    The model must already be compiled: this function calls `fit` and
    `evaluate` on it directly.

    Args:
        model: Compiled Keras model to train.
        datagen: Augmenting generator whose `flow` yields training batches.
        x_train, y_train: Training images and labels.
        x_test, y_test: Data used both for validation and the final evaluation.
        model_name: Prefix of the weights file to load from and save to.
        epochs: Maximum number of epochs to train.
        batch_size: Training batch size.

    Returns:
        The same model, holding the best checkpointed weights when a
        checkpoint file exists.
    """
    weights_path = f'{model_name}.weights.h5'
    best_val_loss = None
    if os.path.exists(weights_path):
        print(f'Loading existing weights from {weights_path}')
        model.load_weights(weights_path)
        # Score the restored weights first. Without this baseline the
        # checkpoint callback starts from "infinity" and the first epoch of a
        # resumed run overwrites the saved weights even if it is worse.
        baseline = model.evaluate(x_test, y_test, verbose=0)
        best_val_loss = baseline[0] if isinstance(baseline, (list, tuple)) else baseline
    else:
        print('No existing weights found. Training from scratch.')

    # Callbacks for training
    checkpoint_callback = callbacks.ModelCheckpoint(
        weights_path,
        save_weights_only=True,
        save_best_only=True,
        monitor='val_loss',
        mode='min',
        initial_value_threshold=best_val_loss,
        verbose=1
    )
    early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5)
    lr_scheduler = callbacks.ReduceLROnPlateau(factor=0.5, patience=2)

    # Train the model
    model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        callbacks=[checkpoint_callback, early_stopping, lr_scheduler]
    )

    # fit() ends on the last-epoch weights; go back to the best checkpoint so
    # the evaluation and the returned model reflect the best known state.
    if os.path.exists(weights_path):
        model.load_weights(weights_path)

    # Evaluate the model
    eval_results = model.evaluate(x_test, y_test)
    print(f'Test loss: {eval_results[0]}')
    print(f'Test accuracy: {eval_results[1]}')

    return model


def continuous_train_with_distillation(student_model, teacher_model, datagen, x_train, y_train, x_test, y_test, model_name, epochs=50, batch_size=128, temperature=3, alpha=0.1):
    """Resume (or start) distilling `teacher_model` into `student_model`.

    Both models are assumed to output class probabilities (softmax), as the
    models built in this file do.

    Args:
        student_model: Model to train; its weights are loaded from and saved
            to '<model_name>.weights.h5'.
        teacher_model: Trained model providing soft targets; it is only run in
            inference mode and is never updated.
        datagen: Augmenting generator whose `flow` yields training batches.
        x_train, y_train: Training images and integer labels.
        x_test, y_test: Data used both for validation and the final evaluation.
        model_name: Prefix of the student weights file.
        epochs: Maximum number of epochs to train.
        batch_size: Training batch size.
        temperature: Softening temperature T for the soft-target term.
        alpha: Weight of the hard-label term; the soft term gets 1 - alpha.

    Returns:
        The student model, compiled with plain sparse categorical
        cross-entropy, holding the best checkpointed weights when available.
    """
    weights_path = f'{model_name}.weights.h5'
    if os.path.exists(weights_path):
        print(f'Loading existing weights from {weights_path}')
        student_model.load_weights(weights_path)
    else:
        print('No existing weights found. Training from scratch.')

    # Define distillation loss
    def distillation_loss(y_true, y_pred, teacher_probs):
        y_true = tf.cast(y_true, tf.float32)
        teacher_probs = tf.cast(teacher_probs, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        # Both inputs are softmax probabilities. log(p) recovers the logits up
        # to a per-sample constant that softmax ignores, so this is the proper
        # temperature-softened distribution. softmax(p / T) is not.
        eps = 1e-7  # avoids log(0)
        teacher_scores = tf.math.log(tf.clip_by_value(teacher_probs, eps, 1.0))
        student_scores = tf.math.log(tf.clip_by_value(y_pred, eps, 1.0))

        # Compute soft targets
        soft_targets = tf.nn.softmax(teacher_scores / temperature)
        soft_prob = tf.nn.softmax(student_scores / temperature)

        # Soft targets loss, scaled by T**2 to offset its 1/T**2 gradient scale
        soft_targets_loss = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(soft_targets, soft_prob)
        ) * (temperature ** 2)

        # Hard targets loss
        student_loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
        )

        # Combined loss
        return alpha * student_loss + (1 - alpha) * soft_targets_loss

    # A compiled Keras loss only sees (y_true, y_pred), and `model.input` is a
    # symbolic placeholder, so the teacher has to be run on each real batch
    # inside a custom train/test step.
    class Distiller(models.Model):
        """Wraps the student so each step also evaluates the teacher."""

        def __init__(self, student, teacher):
            super().__init__()
            self.student = student
            self.teacher = teacher
            self.loss_tracker = tf.keras.metrics.Mean(name='loss')
            self.accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

        @property
        def metrics(self):
            # Listing the metrics here lets fit()/evaluate() reset them.
            return [self.loss_tracker, self.accuracy_metric]

        def call(self, inputs, training=False):
            return self.student(inputs, training=training)

        def _distill_forward(self, images, labels, training):
            teacher_out = self.teacher(images, training=False)
            student_out = self.student(images, training=training)
            return distillation_loss(labels, student_out, teacher_out), student_out

        def _report(self, loss, labels, student_out):
            self.loss_tracker.update_state(loss)
            self.accuracy_metric.update_state(labels, student_out)
            return {metric.name: metric.result() for metric in self.metrics}

        def train_step(self, data):
            images, labels = data[0], data[1]
            with tf.GradientTape() as tape:
                loss, student_out = self._distill_forward(images, labels, training=True)
            # Differentiate with respect to the student only.
            student_vars = self.student.trainable_variables
            gradients = tape.gradient(loss, student_vars)
            self.optimizer.apply_gradients(zip(gradients, student_vars))
            return self._report(loss, labels, student_out)

        def test_step(self, data):
            images, labels = data[0], data[1]
            loss, student_out = self._distill_forward(images, labels, training=False)
            return self._report(loss, labels, student_out)

    class SaveBestStudentWeights(callbacks.Callback):
        """Saves only the student's weights whenever val_loss improves."""

        def __init__(self, student, filepath):
            super().__init__()
            self.student = student
            self.filepath = filepath
            self.best = float('inf')

        def on_epoch_end(self, epoch, logs=None):
            current = (logs or {}).get('val_loss')
            if current is not None and current < self.best:
                print(f'\nEpoch {epoch + 1}: val_loss improved from {self.best} to {current}, '
                      f'saving student weights to {self.filepath}')
                self.best = current
                self.student.save_weights(self.filepath)

    distiller = Distiller(student_model, teacher_model)
    distiller.compile(optimizer=optimizers.Adam())

    # Callbacks for training
    # A stock ModelCheckpoint on the wrapper would also store the teacher's
    # weights, so the checkpoint is written from the student directly.
    checkpoint_callback = SaveBestStudentWeights(student_model, weights_path)
    early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5)
    lr_scheduler = callbacks.ReduceLROnPlateau(factor=0.5, patience=2)

    # Train the student model
    distiller.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        callbacks=[checkpoint_callback, early_stopping, lr_scheduler]
    )

    # Return to the best checkpoint written during this run, if any.
    if checkpoint_callback.best < float('inf'):
        student_model.load_weights(weights_path)

    # Compile the bare student with the plain classification loss so it can
    # be evaluated and reused without the teacher.
    student_model.compile(
        optimizer=optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # Evaluate the student model (the loss is plain cross-entropy on hard labels)
    eval_results = student_model.evaluate(x_test, y_test)
    print(f'Test loss: {eval_results[0]}')
    print(f'Test accuracy: {eval_results[1]}')

    return student_model

# Example usage for the teacher model:
# `teacher_model` is the teacher built and trained in the previous cell.
# `best_teacher_model` does not exist until this call returns, so passing it
# as the argument would raise a NameError.
best_teacher_model = continuous_train(
    teacher_model,
    datagen,
    x_train,
    y_train,
    x_test,
    y_test,
    model_name='teacher_model',
    epochs=50,
    batch_size=128
)

# Example usage for the student model:
# continuous_train() relies on the model's existing compile state, so give
# the student a plain classification loss for this teacher-free run.
best_student_model.compile(
    optimizer=optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
best_student_model = continuous_train(
    best_student_model,
    datagen,
    x_train,
    y_train,
    x_test,
    y_test,
    model_name='student_model',
    epochs=50,
    batch_size=128
)


# Example usage for the student model with knowledge distillation:
best_student_model = continuous_train_with_distillation(
    best_student_model,
    best_teacher_model,
    datagen,
    x_train,
    y_train,
    x_test,
    y_test,
    model_name='student_model',
    epochs=50,
    batch_size=128,
    temperature=3,
    alpha=0.1
)
