In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import cv2
import gc
import h5py
import wandb
import sys
from tqdm import tqdm
from dotenv import load_dotenv
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Conv2D, Dropout, MaxPooling2D, BatchNormalization
from tensorflow.keras.applications import MobileNetV2, MobileNetV3Small, MobileNetV3Large
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.losses import KLDivergence

In [2]:
load_dotenv()  # Load environment variables from .env file
api_key = os.getenv("WANDB_API_KEY")

wandb.login(key=api_key)  # Login with the API key
wandb.init(project="LEC_Model", entity=os.getenv("WANDB_USERNAME"))

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: nicholas66. Use `wandb login --relogin` to force relogin


In [3]:
def load_data(hdf5_file, dataset_name_images, dataset_name_labels):
    with h5py.File(hdf5_file, 'r') as hf:
        data = np.array(hf[dataset_name_images])
        labels = np.array(hf[dataset_name_labels])
    return data, labels

# Load data
data, labels = load_data('output files/combined_images_labels.h5', 'combined_images', 'combined_labels')

# Print shapes to verify
print(f"Data shape: {data.shape}")
print(f"Labels shape: {labels.shape}")

# Ensure that data and labels are not empty
assert len(data) > 0, "Data is empty."
assert len(labels) > 0, "Labels are empty."

# Split data
X_train, X_val, y_train, y_val = train_test_split(
    data, 
    labels, 
    test_size=0.2, 
    stratify=labels
)

# Print shapes of the split data
print(f"Training data shape: {X_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Validation data shape: {X_val.shape}")
print(f"Validation labels shape: {y_val.shape}")

Data shape: (40980, 224, 224, 3)
Labels shape: (40980,)
Training data shape: (32784, 224, 224, 3)
Training labels shape: (32784,)
Validation data shape: (8196, 224, 224, 3)
Validation labels shape: (8196,)


In [4]:
# Define the data augmentation transformations
train_datagen = ImageDataGenerator(
    rotation_range=20,         # Randomly rotate images by up to 20 degrees
    width_shift_range=0.2,     # Randomly shift images horizontally by up to 20% of the width
    height_shift_range=0.2,    # Randomly shift images vertically by up to 20% of the height
    shear_range=0.2,           # Randomly apply shearing transformation
    zoom_range=0.2,            # Randomly zoom images by up to 20%
    horizontal_flip=True,      # Randomly flip images horizontally
    fill_mode='nearest'        # Strategy for filling in newly created pixels
)

# Create data generator for training data
train_generator = train_datagen.flow(
    X_train,   # Training images
    y_train,   # Training labels
    batch_size=84  # Batch size for training
)

validation_datagen = ImageDataGenerator()

# Create data generator for validation data
validation_generator = validation_datagen.flow(
    X_val,  # Validation images
    y_val,  # Validation labels
    batch_size=84   # Batch size for validation
)


# Teacher Model

In [5]:
def create_teacher_model():
    teacher_model = EfficientNetB4(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    teacher_model.trainable = False  # We don't want to train the teacher
    model = Sequential([
        teacher_model,
        GlobalAveragePooling2D(),
        Dense(128, activation='relu'),
        Dense(64, activation='relu'),
        Dense(7, activation='softmax')  # 7 emotion classes
    ])
    return model

# LEC Model (Student)

In [6]:
def create_student_model():
    # Load MobileNetV3 as base model
    base_model = MobileNetV3Small(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    
    # Freeze initial layers
    for layer in base_model.layers[:5]:  # Adjust the range as needed
        layer.trainable = False
    
    # Build the custom top layers
    model = Sequential([
        base_model,
        GlobalAveragePooling2D(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(64, activation='relu'),
        Dropout(0.5),
        Dense(7, activation='softmax')  # Assuming 7 emotion classes
    ])

    return model



In [42]:
teacher_model = create_teacher_model()
student_model = create_student_model()

# Custom callback for WandB
class WandBCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        wandb.log({
            "Epoch Loss": logs.get("loss"),
            "Epoch Accuracy": logs.get("accuracy"),
            "Val Loss": logs.get("val_loss"),
            "Val Accuracy": logs.get("val_accuracy"),
            "Epoch": epoch + 1
        })

        

# Compile the student model
# student_model.compile(
#     optimizer=Adam(learning_rate=1e-4),
#     loss=lambda y_true, y_pred: knowledge_distillation_loss(y_true, y_pred, teacher_model.predict(train_generator[0][0])),  
#     metrics=['accuracy']
# )

# Define the knowledge distillation loss function
def knowledge_distillation_loss(y_true, y_pred, teacher_logits, T=3, alpha=0.5):
    
    # Hard loss (real labels)
    hard_loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    # Soft loss (teacher predictions)
    soft_loss = KLDivergence()(
        tf.nn.softmax(teacher_logits / T),
        tf.nn.softmax(y_pred / T)
    )
    # Combined loss
    return alpha * hard_loss + (1 - alpha) * soft_loss


In [71]:
# Custom training step for batch-wise teacher logits
class CustomDistillationModel(tf.keras.Model):
    def __init__(self, student, teacher):
        super(CustomDistillationModel, self).__init__()
        self.student = student
        self.teacher = teacher

    def call(self, inputs, training=False):
        return self.student(inputs, training=training)

    def train_step(self, data):
        x, y_true = data
        
        # Get teacher predictions for the current batch
        teacher_logits = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Student predictions
            y_pred = self.student(x, training=True)
            # Compute the KD loss
            loss = knowledge_distillation_loss(y_true, y_pred, teacher_logits)

        # Compute gradients and update weights
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics (accuracy)
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

# Create the custom KD model
distilled_model = CustomDistillationModel(student_model, teacher_model)

# Compile the student model (no need to modify the WandB callback)
distilled_model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Fit the model as before



In [72]:
distilled_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=50,
    callbacks=[WandBCallback()],
    verbose=1
)

Epoch 1/50
342/342 ━━━━━━━━━━━━━━━━━━━━ 3:23:37 36s/step - accuracy: 0.1875 - loss: 0.142 ━━━━━━━━━━━━━━━━━━━━ 28:11 5s/step - accuracy: 0.2005 - loss: 0.1429  ━━━━━━━━━━━━━━━━━━━━ 28:18 5s/step - accuracy: 0.1950 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:09 5s/step - accuracy: 0.1938 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:00 5s/step - accuracy: 0.1938 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:58 5s/step - accuracy: 0.1945 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:50 5s/step - accuracy: 0.1945 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:42 5s/step - accuracy: 0.1945 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:35 5s/step - accuracy: 0.1946 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:34 5s/step - accuracy: 0.1939 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:31 5s/step - accuracy: 0.1937 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:27 5s/step - accuracy: 0.1929 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:22 5s/step - accuracy: 0.1919 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:16 5s/step - accuracy: 0.1912 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:11 5s/step -

  self._warn_if_super_not_called()


342/342 ━━━━━━━━━━━━━━━━━━━━ 1811s 5s/step - accuracy: 0.1906 - loss: 0.1429 - val_accuracy: 0.2925 - val_loss: 1.8327
Epoch 2/50
342/342 ━━━━━━━━━━━━━━━━━━━━ 1:03:58 11s/step - accuracy: 0.1979 - loss: 0.142 ━━━━━━━━━━━━━━━━━━━━ 29:25 5s/step - accuracy: 0.2135 - loss: 0.1429  ━━━━━━━━━━━━━━━━━━━━ 28:48 5s/step - accuracy: 0.2095 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:34 5s/step - accuracy: 0.2053 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:30 5s/step - accuracy: 0.2038 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:20 5s/step - accuracy: 0.2025 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:11 5s/step - accuracy: 0.2006 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 28:00 5s/step - accuracy: 0.1995 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:53 5s/step - accuracy: 0.1984 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:45 5s/step - accuracy: 0.1977 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:37 5s/step - accuracy: 0.1977 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:30 5s/step - accuracy: 0.1985 - loss: 0.14 ━━━━━━━━━━━━━━━━━━━━ 27:23 5s/step - accuracy: 0.19

KeyboardInterrupt: 

In [None]:
distilled_model.save('Trained Model/LEC_model_student_with_kd.keras')
