In [2]:
# IMPORTS

import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, models
import matplotlib.pyplot as plt
from keras.utils import to_categorical
from keras import metrics, losses
from keras.models import load_model
from sklearn.model_selection import train_test_split
from PIL import Image
import os
import random


In [3]:
# Load both the teacher and student model

scratch_student = load_model('student_model.h5')
student_model = load_model('student_model.h5')
teacher_model =  load_model('teacher_model.h5')




In [None]:
# Below is just an example — here we load our own constructed dataset from google+bing(1000)
# We split each class into:
# 400 original for training
# 400 augmented (180-degree rotated) for training
# 200 remaining for testing

data_dir = "google+bing(1000)"
img_height = 32
img_width = 32
classes = sorted(os.listdir(data_dir))
class_indices = {cls: idx for idx, cls in enumerate(classes)}

X_all = []
y_all = []

for cls in classes:
    cls_path = os.path.join(data_dir, cls)
    image_files = sorted(os.listdir(cls_path))
    selected = random.sample(image_files, 60)

    for fname in selected:
        img_path = os.path.join(cls_path, fname)
        img = Image.open(img_path).convert("RGB").resize((img_width, img_height))  # ← fixed here
        X_all.append(np.array(img))
        y_all.append(class_indices[cls])

X_train = []
y_train = []
X_test = []
y_test = []

X_all = np.array(X_all)
y_all = np.array(y_all)

for cls_idx in range(len(classes)):
    idxs = np.where(y_all == cls_idx)[0]
    X_train.extend(X_all[idxs][:40])
    y_train.extend([cls_idx]*40)
    X_test.extend(X_all[idxs][40:])
    y_test.extend([cls_idx]*20)

# Augment training set with 180-degree rotated copies
X_aug = [np.array(Image.fromarray(img).rotate(180)) for img in X_train]
y_aug = y_train.copy()

X_train = np.array(X_train + X_aug)
y_train = np.array(y_train + y_aug)
X_test = np.array(X_test)
y_test = np.array(y_test)

X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

train_labels = to_categorical(y_train, num_classes=10)
test_labels = to_categorical(y_test, num_classes=10)

print(f"Training images shape: {X_train.shape}")
print(f"Testing images shape: {X_test.shape}")


Training images shape: (800, 32, 32, 3)
Testing images shape: (200, 32, 32, 3)


In [8]:
# Compute student model metrics without KD

scratch_student.compile(optimizer = 'sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [9]:
# First, let us try to see what if we directly train the student model without using knowledge distillation

scratch_student.fit(X_train, train_labels, epochs=7, batch_size=32)


Epoch 1/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 33ms/step - accuracy: 0.1295 - loss: 3.3667
Epoch 2/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 34ms/step - accuracy: 0.2117 - loss: 2.7361
Epoch 3/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 31ms/step - accuracy: 0.2403 - loss: 2.4981
Epoch 4/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 31ms/step - accuracy: 0.2871 - loss: 2.2472
Epoch 5/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 32ms/step - accuracy: 0.2774 - loss: 2.2732
Epoch 6/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - accuracy: 0.2804 - loss: 2.2052
Epoch 7/7
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - accuracy: 0.3323 - loss: 2.0708


<keras.src.callbacks.history.History at 0x25f5ba39900>

In [10]:
# We evaluate student model for its loss and accuracy, if the student model is trained without using knowledge distillation

scratch_student.evaluate(X_test, test_labels)


[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.0940 - loss: 3.2321    


[2.8883097171783447, 0.10000000149011612]

In [11]:
# KNOWLEDGE DISTILLATION CLASS, You can adjust alpha based on how much you want the student to learn from the teacher

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)


In [12]:
# Initialize the distiller
# Train the student model using knowledge distillation

distiller = Distiller(student=student_model, teacher=teacher_model)

distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[metrics.CategoricalAccuracy()],
    student_loss_fn=losses.CategoricalCrossentropy(),
    distillation_loss_fn=losses.CategoricalCrossentropy(),
    alpha=0.2,
    temperature=1,
) 

# Fitting the student model receiving KD
history = distiller.fit(
    X_train,
    train_labels,
    epochs=7,
    batch_size=32,  
    validation_split=0.2,
)


Epoch 1/7
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 78ms/step - categorical_accuracy: 0.1413 - loss: 2.2977 - val_categorical_accuracy: 0.0000e+00 - val_loss: 2.4020
Epoch 2/7
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 49ms/step - categorical_accuracy: 0.1824 - loss: 2.2692 - val_categorical_accuracy: 0.0000e+00 - val_loss: 2.4825
Epoch 3/7
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 58ms/step - categorical_accuracy: 0.2910 - loss: 2.2229 - val_categorical_accuracy: 0.0188 - val_loss: 2.4471
Epoch 4/7
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 46ms/step - categorical_accuracy: 0.3160 - loss: 2.1963 - val_categorical_accuracy: 0.0188 - val_loss: 2.6268
Epoch 5/7
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 43ms/step - categorical_accuracy: 0.4159 - loss: 2.1652 - val_categorical_accuracy: 0.0312 - val_loss: 2.5650
Epoch 6/7
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0

In [13]:
# We evaluate student model again for its loss and accuracy,
# But this time the student model is trained using knowledge distillation
# You can compare this results with the results above

distiller.evaluate(X_test, test_labels)


[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 27ms/step - categorical_accuracy: 0.4068 - loss: 2.1520


[2.1824004650115967, 0.35499998927116394]