In [1]:
# IMPORTS

import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, models
import matplotlib.pyplot as plt
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras import metrics, losses
from keras.models import load_model
import os
import random
from PIL import Image


In [2]:
# Load both the teacher and student model
scratch_student = load_model('student_model.h5')
student_model = load_model('student_model.h5')
teacher_model = load_model('teacher_model.h5')




In [3]:
# Compute student model metrics without KD

img_height = 32
img_width = 32

# Paths
train_data_dir = "V2(10000)"
test_data_dir = "bing_images(10000)/raw"
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}

# Get class names
classes = sorted(os.listdir(train_data_dir))
class_indices = {cls: idx for idx, cls in enumerate(classes)}

X_train, y_train, X_test, y_test = [], [], [], []

# Load training images from V2(10000)
for cls in classes:
    cls_path = os.path.join(train_data_dir, cls)
    image_files = [f for f in os.listdir(cls_path) if os.path.splitext(f)[1].lower() in valid_exts]

    if len(image_files) < 800:
        print(f"⚠️ Skipping {cls} (train) — only {len(image_files)} images found")
        continue

    selected = random.sample(image_files, 800)
    for fname in selected:
        img_path = os.path.join(cls_path, fname)
        try:
            img = Image.open(img_path).convert("RGB").resize((img_width, img_height))
            X_train.append(np.array(img))
            y_train.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped train image: {img_path}")
    print(f"✅ Processed train class: {cls}")

# Load testing images from bing_images(10000)/raw
for cls in classes:
    cls_path = os.path.join(test_data_dir, cls)

    image_paths = []
    for root, _, files in os.walk(cls_path):
        for f in files:
            if os.path.splitext(f)[1].lower() in valid_exts:
                image_paths.append(os.path.join(root, f))

    if len(image_paths) < 200:
        print(f"⚠️ Skipping {cls} (test) — only {len(image_paths)} images found")
        continue

    selected = random.sample(image_paths, 200)
    for img_path in selected:
        try:
            img = Image.open(img_path).convert("RGB").resize((img_width, img_height))
            X_test.append(np.array(img))
            y_test.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped test image: {img_path}")
    print(f"✅ Processed test class: {cls}")

# Convert and normalize
X_train = np.array(X_train).astype('float32') / 255.0
X_test = np.array(X_test).astype('float32') / 255.0
y_train = np.array(y_train)
y_test = np.array(y_test)

train_labels = to_categorical(y_train, num_classes=10)
test_labels = to_categorical(y_test, num_classes=10)

print(f"✅ Training set: {X_train.shape}, labels: {train_labels.shape}")
print(f"✅ Testing set: {X_test.shape}, labels: {test_labels.shape}")


✅ Processed train class: airplane
✅ Processed train class: automobile
✅ Processed train class: bird
✅ Processed train class: cat
✅ Processed train class: deer
✅ Processed train class: dog
✅ Processed train class: frog
✅ Processed train class: horse
✅ Processed train class: ship
✅ Processed train class: truck
✅ Processed test class: airplane
✅ Processed test class: automobile
✅ Processed test class: bird
✅ Processed test class: cat
✅ Processed test class: deer
✅ Processed test class: dog
✅ Processed test class: frog
✅ Processed test class: horse
✅ Processed test class: ship
✅ Processed test class: truck
✅ Training set: (8000, 32, 32, 3), labels: (8000, 10)
✅ Testing set: (2000, 32, 32, 3), labels: (2000, 10)


In [4]:
# Compute student model metrics without KD

scratch_student.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [5]:
# First, let us try to see what if we directly train the student model without using knowledge distillation

scratch_student.fit(X_train, train_labels, epochs=7, batch_size=32)


Epoch 1/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 28ms/step - accuracy: 0.1512 - loss: 3.0152
Epoch 2/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 33ms/step - accuracy: 0.2851 - loss: 2.1061
Epoch 3/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 40ms/step - accuracy: 0.3468 - loss: 1.8659
Epoch 4/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 33ms/step - accuracy: 0.3813 - loss: 1.7458
Epoch 5/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 30ms/step - accuracy: 0.4135 - loss: 1.6581
Epoch 6/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 38ms/step - accuracy: 0.4658 - loss: 1.5480
Epoch 7/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 25ms/step - accuracy: 0.4563 - loss: 1.5283


<keras.src.callbacks.history.History at 0x1e971729360>

In [6]:
# We evaluate student model for its loss and accuracy, if the student model is trained without using knowledge distillation

scratch_student.evaluate(X_test, test_labels)


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step - accuracy: 0.2058 - loss: 2.5685


[2.3702526092529297, 0.24250000715255737]

In [7]:
# Now let us try using knowledge distillation
# KNOWLEDGE DISTILLATION CLASS, You can adjust alpha based on how much you want the student to learn from the teacher

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)


In [8]:
# Initialize the distiller
# Train the student model using knowledge distillation

distiller = Distiller(student=student_model, teacher=teacher_model)

# Compiling the Distiller. You can adjust alpha based on how much you want the student to learn from the teacher
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[metrics.CategoricalAccuracy()],
    student_loss_fn=losses.CategoricalCrossentropy(),
    distillation_loss_fn=losses.CategoricalCrossentropy(),
    alpha=0.2,
    temperature=1,
)

# Fitting the student model receiving KD
history = distiller.fit(
    X_train,
    train_labels,
    epochs=7,
    batch_size=32,
    validation_split=0.2,
)


Epoch 1/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 48ms/step - categorical_accuracy: 0.1234 - loss: 2.2608 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.6147
Epoch 2/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 44ms/step - categorical_accuracy: 0.2540 - loss: 2.2056 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.1309
Epoch 3/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 52ms/step - categorical_accuracy: 0.4622 - loss: 2.1340 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.4301
Epoch 4/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 52ms/step - categorical_accuracy: 0.5791 - loss: 2.0805 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.8755
Epoch 5/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 44ms/step - categorical_accuracy: 0.6553 - loss: 2.0484 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.3482
Epoch 6/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━

In [9]:
# We evaluate student model again for its loss and accuracy,
# But this time the student model is trained using knowledge distillation
# You can compare this results with the results above

distiller.evaluate(X_test, test_labels)


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 15ms/step - categorical_accuracy: 0.2858 - loss: 2.4740


[2.877549886703491, 0.21699999272823334]