In [2]:
# IMPORTS

import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, models
import matplotlib.pyplot as plt
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras import metrics, losses
from keras.models import load_model


In [3]:
# Load both the teacher and student model
scratch_student = load_model('student_model.h5')
student_model = load_model('student_model.h5')
teacher_model =  load_model('teacher_model.h5')




In [13]:
# Compute student model metrics without KD

import os
import random
import numpy as np
from PIL import Image
from tensorflow.keras.utils import to_categorical

# Dataset location
data_dir = "bing_images(10000)/raw"
img_height = 32
img_width = 32

# Sorted class names and label indices
classes = sorted(os.listdir(data_dir))
class_indices = {cls: idx for idx, cls in enumerate(classes)}

X_train, y_train, X_test, y_test = [], [], [], []

# Valid image extensions
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}

for cls in classes:
    cls_path = os.path.join(data_dir, cls)

    # Recursively find all valid image files
    image_paths = []
    for root, _, files in os.walk(cls_path):
        for fname in files:
            if os.path.splitext(fname)[1].lower() in valid_exts:
                image_paths.append(os.path.join(root, fname))

    if len(image_paths) < 1000:
        print(f"⚠️ Skipping '{cls}' — only {len(image_paths)} valid images found.")
        continue

    selected = random.sample(image_paths, 1000)
    train_imgs = selected[:800]
    test_imgs = selected[800:1000]

    for img_path in train_imgs:
        try:
            img = Image.open(img_path).convert("RGB").resize((img_width, img_height))
            X_train.append(np.array(img))
            y_train.append(class_indices[cls])
        except Exception as e:
            print(f"❌ Skipping train image: {img_path} - {e}")

    for img_path in test_imgs:
        try:
            img = Image.open(img_path).convert("RGB").resize((img_width, img_height))
            X_test.append(np.array(img))
            y_test.append(class_indices[cls])
        except Exception as e:
            print(f"❌ Skipping test image: {img_path} - {e}")

# Convert to numpy arrays
X_train = np.array(X_train).astype('float32') / 255.0
X_test = np.array(X_test).astype('float32') / 255.0
y_train = np.array(y_train)
y_test = np.array(y_test)

# One-hot encode labels
train_labels = to_categorical(y_train, num_classes=10)
test_labels = to_categorical(y_test, num_classes=10)

print(f"✅ Training images: {X_train.shape}, Training labels: {train_labels.shape}")
print(f"✅ Testing images: {X_test.shape}, Testing labels: {test_labels.shape}")


✅ Training images: (8000, 32, 32, 3), Training labels: (8000, 10)
✅ Testing images: (2000, 32, 32, 3), Testing labels: (2000, 10)


In [14]:
# Compute student model metrics without KD

import os
import random
from PIL import Image

# Dataset directory
data_dir = "bing_images(10000)/raw"
img_height = 32
img_width = 32

# Set up label mapping
classes = sorted(os.listdir(data_dir))
class_indices = {cls: idx for idx, cls in enumerate(classes)}

# Initialize storage
X_train, y_train, X_test, y_test = [], [], [], []

# Valid extensions
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}

# Process each class
for cls in classes:
    cls_path = os.path.join(data_dir, cls)

    # Recursively gather image files
    image_paths = []
    for root, _, files in os.walk(cls_path):
        for fname in files:
            if os.path.splitext(fname)[1].lower() in valid_exts:
                image_paths.append(os.path.join(root, fname))

    if len(image_paths) < 1000:
        print(f"⚠️ Skipping {cls}: only {len(image_paths)} images found")
        continue

    selected = random.sample(image_paths, 1000)
    train_imgs = selected[:800]
    test_imgs = selected[800:1000]

    for img_path in train_imgs:
        try:
            img = Image.open(img_path).convert("RGB").resize((img_width, img_height))
            X_train.append(np.array(img))
            y_train.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped train: {img_path}")

    for img_path in test_imgs:
        try:
            img = Image.open(img_path).convert("RGB").resize((img_width, img_height))
            X_test.append(np.array(img))
            y_test.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped test: {img_path}")

# Final conversions
X_train = np.array(X_train).astype('float32') / 255.0
X_test = np.array(X_test).astype('float32') / 255.0
y_train = np.array(y_train)
y_test = np.array(y_test)

train_labels = to_categorical(y_train, num_classes=10)
test_labels = to_categorical(y_test, num_classes=10)

print(f"✅ Training images: {X_train.shape}, labels: {train_labels.shape}")
print(f"✅ Testing images: {X_test.shape}, labels: {test_labels.shape}")


✅ Training images: (8000, 32, 32, 3), labels: (8000, 10)
✅ Testing images: (2000, 32, 32, 3), labels: (2000, 10)


In [15]:
# Compute student model metrics without KD

scratch_student.compile(optimizer = 'sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [16]:
# First, let us try to see what if we directly train the student model without using knowledge distillation

scratch_student.fit(X_train, train_labels, epochs=7, batch_size=32)


Epoch 1/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 74ms/step - accuracy: 0.1678 - loss: 2.9212
Epoch 2/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 45ms/step - accuracy: 0.2614 - loss: 2.2264
Epoch 3/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.2617 - loss: 2.0942
Epoch 4/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 32ms/step - accuracy: 0.2916 - loss: 1.9985
Epoch 5/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 32ms/step - accuracy: 0.3221 - loss: 1.9153
Epoch 6/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 27ms/step - accuracy: 0.3293 - loss: 1.8870
Epoch 7/7
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 28ms/step - accuracy: 0.3533 - loss: 1.8425


<keras.src.callbacks.history.History at 0x2b167ee8d30>

In [17]:
# We evaluate student model for its loss and accuracy, if the student model is trained without using knowledge distillation

scratch_student.evaluate(X_test, test_labels)


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 14ms/step - accuracy: 0.3194 - loss: 1.8893


[1.8287723064422607, 0.3644999861717224]

In [18]:
# Now let us try using knowledge distillation
# KNOWLEDGE DISTILLATION CLASS, You can adjust alpha based on how much you want the student to learn from the teacher

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)


In [19]:
# Initialize the distiller
# Train the student model using knowledge distillation

distiller = Distiller(student=student_model, teacher=teacher_model)

# Compiling the Distiller. You can adjust alpha based on how much you want the student to learn from the teacher
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[metrics.CategoricalAccuracy()],
    student_loss_fn=losses.CategoricalCrossentropy(),
    distillation_loss_fn=losses.CategoricalCrossentropy(),
    alpha=0.2,
    temperature=1,
)

# Fitting the student model receiving KD
history = distiller.fit(
    X_train,
    train_labels,
    epochs=7,
    batch_size=32,
    validation_split=0.2,
)


Epoch 1/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 32ms/step - categorical_accuracy: 0.1679 - loss: 2.2472 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.2024
Epoch 2/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 29ms/step - categorical_accuracy: 0.2946 - loss: 2.2002 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.9436
Epoch 3/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 31ms/step - categorical_accuracy: 0.3559 - loss: 2.1803 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.8906
Epoch 4/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 30ms/step - categorical_accuracy: 0.4252 - loss: 2.1519 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.9502
Epoch 5/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 31ms/step - categorical_accuracy: 0.4867 - loss: 2.1296 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.2124
Epoch 6/7
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━

In [None]:
# We evaluate student model again for its loss and accuracy,
# But this time the student model is trained using knowledge distillation
# You can compare this results with the results above

distiller.evaluate(X_test, test_labels)


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 17ms/step - categorical_accuracy: 0.4421 - loss: 2.2550


[2.7153728008270264, 0.3695000112056732]