In [56]:

# IMPORTS

import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, models
import matplotlib.pyplot as plt
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras import metrics, losses
from keras.models import load_model
import urllib.request
import os
import tarfile
from PIL import Image

In [57]:
# Load both the teacher and student model
try:
    scratch_student = load_model('student_model.h5')
    student_model = load_model('student_model.h5')
    teacher_model = load_model('teacher_model.h5')
    
    # Compile models with appropriate settings
    scratch_student.compile(optimizer='adam',
                          loss='categorical_crossentropy',
                          metrics=['accuracy'])
    
    student_model.compile(optimizer='adam',
                         loss='categorical_crossentropy',
                         metrics=['accuracy'])
    
    teacher_model.compile(optimizer='adam',
                         loss='categorical_crossentropy',
                         metrics=['accuracy'])
    
    print("Models loaded and compiled successfully!")
except Exception as e:
    print(f"Error loading models: {str(e)}")



In [58]:
# Load STL-10 dataset for knowledge distillation
# This dataset will be used to train the student model
# Note: This is a similar but different dataset from what was used to train the teacher model

# Load training images
print("Loading STL-10 training images...")
train_images = []
train_path = 'STL-10/train_images'
for img_path in glob.glob(os.path.join(train_path, '*.*')):
    img = Image.open(img_path).convert('RGB')
    img_array = np.array(img)
    train_images.append(img_array)
X_train = np.array(train_images)

# Load test images
print("Loading STL-10 test images...")
test_images = []
test_path = 'STL-10/test_images'
for img_path in glob.glob(os.path.join(test_path, '*.*')):
    img = Image.open(img_path).convert('RGB')
    img_array = np.array(img)
    test_images.append(img_array)
X_test = np.array(test_images)

# Normalize the data (same as CIFAR-10)
X_train = X_train.astype('float32')/255.0
X_test = X_test.astype('float32')/255.0

# Create labels (0-9 for 10 classes)
y_train = np.zeros(len(X_train))  # We'll update these with actual labels
y_test = np.zeros(len(X_test))    # We'll update these with actual labels

# Convert labels to one-hot encoding (same as CIFAR-10)
test_labels = to_categorical(y_test)
train_labels = to_categorical(y_train)

In [60]:
# Compute student model metrics without KD
scratch_student.compile(optimizer = 'adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [61]:
# First, let us try to see what if we directly train the student model without using knowledge distillation
scratch_student.fit(X_train, train_labels, epochs=7, batch_size=32)

Epoch 1/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 11ms/step - accuracy: 0.2982 - loss: 2.1188
Epoch 2/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 11ms/step - accuracy: 0.4839 - loss: 1.4265
Epoch 3/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 11ms/step - accuracy: 0.5555 - loss: 1.2539
Epoch 4/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 11ms/step - accuracy: 0.5952 - loss: 1.1463
Epoch 5/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 11ms/step - accuracy: 0.6246 - loss: 1.0649
Epoch 6/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 11ms/step - accuracy: 0.6404 - loss: 1.0242
Epoch 7/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 11ms/step - accuracy: 0.6625 - loss: 0.9699


<keras.src.callbacks.history.History at 0x1e12c245130>

In [62]:
# We evaluate student model for its loss and accuracy, if the student model is trained without using knowledge distillation
scratch_student.evaluate(X_test, test_labels)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.5412 - loss: 1.3541


[1.356394648551941, 0.5432999730110168]

In [65]:
# Now let us try using knowledge distillation
# KNOWLEDGE DISTILLATION CLASS, You can adjust alpha based on how much you want the student to learn from the teacher

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)

In [66]:
# Initialize the distiller
# Train the student model using knowledge distillation
distiller = Distiller(student=student_model, teacher=teacher_model)

# Compiling the Distiller. You can adjust alpha based on how much you want the student to learn from the teacher
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[metrics.CategoricalAccuracy()],
    student_loss_fn=losses.CategoricalCrossentropy(),
    distillation_loss_fn=losses.CategoricalCrossentropy(),
    alpha=0.2,
    temperature=1,
) 

# Fitting the student model receiving KD
history = distiller.fit(
    X_train,
    train_labels,
    epochs=7,
    batch_size=32,  
    validation_split=0.2,
)

Epoch 1/7
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 15ms/step - categorical_accuracy: 0.3646 - loss: 2.1855 - val_categorical_accuracy: 0.5998 - val_loss: 2.0767
Epoch 2/7
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 15ms/step - categorical_accuracy: 0.6438 - loss: 2.0573 - val_categorical_accuracy: 0.6967 - val_loss: 2.0333
Epoch 3/7
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 15ms/step - categorical_accuracy: 0.7336 - loss: 2.0163 - val_categorical_accuracy: 0.7351 - val_loss: 2.0208
Epoch 4/7
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 15ms/step - categorical_accuracy: 0.7750 - loss: 1.9952 - val_categorical_accuracy: 0.7445 - val_loss: 2.0144
Epoch 5/7
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 15ms/step - categorical_accuracy: 0.8180 - loss: 1.9741 - val_categorical_accuracy: 0.7457 - val_loss: 2.0213
Epoch 6/7
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[

In [67]:
# We evaluate student model again for its loss and accuracy,
# But this time the student model is trained using knowledge distillation
# You can compare this results with the results above
distiller.evaluate(X_test, test_labels)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - categorical_accuracy: 0.7383 - loss: 2.0353


[2.03804087638855, 0.7369999885559082]