In [1]:
# IMPORTS

import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers, models
import matplotlib.pyplot as plt
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras import metrics, losses
from keras.models import load_model
import os
import random
from PIL import Image


In [2]:
# Load both the teacher and student model
scratch_student = load_model('student_model.h5')
student_model = load_model('student_model.h5')
teacher_model = load_model('teacher_model.h5')




In [7]:
# Compute student model metrics without KD

img_height = 32
img_width = 32

google_dir = "google_images(3500)/raw"
bing_dir = "bing_images(10000)/raw"
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}

classes = sorted(os.listdir(google_dir))
class_indices = {cls: idx for idx, cls in enumerate(classes)}

X_train, y_train, X_test, y_test = [], [], [], []

# ✅ Load from GOOGLE (recursively)
for cls in classes:
    cls_path = os.path.join(google_dir, cls)
    image_paths = []
    for root, _, files in os.walk(cls_path):
        for f in files:
            if os.path.splitext(f)[1].lower() in valid_exts:
                image_paths.append(os.path.join(root, f))

    print(f"Found {len(image_paths)} images for class '{cls}' in Google")

    if len(image_paths) < 250:
        print(f"⚠️ Skipping {cls} from Google — only {len(image_paths)} found")
        continue

    selected = random.sample(image_paths, 250)
    train_imgs = selected[:200]
    test_imgs = selected[200:]

    for path in train_imgs:
        try:
            img = Image.open(path).convert("RGB").resize((img_width, img_height))
            X_train.append(np.array(img))
            y_train.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped train image (Google): {path}")

    for path in test_imgs:
        try:
            img = Image.open(path).convert("RGB").resize((img_width, img_height))
            X_test.append(np.array(img))
            y_test.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped test image (Google): {path}")

    print(f"✅ Processed class from Google: {cls}")

# ✅ Load from BING (recursive as before)
for cls in classes:
    cls_path = os.path.join(bing_dir, cls)
    image_paths = []
    for root, _, files in os.walk(cls_path):
        for f in files:
            if os.path.splitext(f)[1].lower() in valid_exts:
                image_paths.append(os.path.join(root, f))

    if len(image_paths) < 250:
        print(f"⚠️ Skipping {cls} from Bing — only {len(image_paths)} found")
        continue

    selected = random.sample(image_paths, 250)
    train_imgs = selected[:200]
    test_imgs = selected[200:]

    for path in train_imgs:
        try:
            img = Image.open(path).convert("RGB").resize((img_width, img_height))
            X_train.append(np.array(img))
            y_train.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped train image (Bing): {path}")

    for path in test_imgs:
        try:
            img = Image.open(path).convert("RGB").resize((img_width, img_height))
            X_test.append(np.array(img))
            y_test.append(class_indices[cls])
        except:
            print(f"⚠️ Skipped test image (Bing): {path}")

    print(f"✅ Processed class from Bing: {cls}")

# ✅ Final conversion
X_train = np.array(X_train).astype('float32') / 255.0
X_test = np.array(X_test).astype('float32') / 255.0
y_train = np.array(y_train)
y_test = np.array(y_test)

train_labels = to_categorical(y_train, num_classes=10)
test_labels = to_categorical(y_test, num_classes=10)

print(f"✅ Training set: {X_train.shape}, labels: {train_labels.shape}")
print(f"✅ Testing set: {X_test.shape}, labels: {test_labels.shape}")


Found 700 images for class 'airplane' in Google
✅ Processed class from Google: airplane
Found 666 images for class 'automobile' in Google
✅ Processed class from Google: automobile
Found 632 images for class 'bird' in Google
✅ Processed class from Google: bird
Found 665 images for class 'cat' in Google
✅ Processed class from Google: cat
Found 700 images for class 'deer' in Google
✅ Processed class from Google: deer
Found 700 images for class 'dog' in Google
✅ Processed class from Google: dog
Found 666 images for class 'frog' in Google




✅ Processed class from Google: frog
Found 687 images for class 'horse' in Google
✅ Processed class from Google: horse
Found 700 images for class 'ship' in Google
✅ Processed class from Google: ship
Found 666 images for class 'truck' in Google
✅ Processed class from Google: truck
✅ Processed class from Bing: airplane
✅ Processed class from Bing: automobile
✅ Processed class from Bing: bird
✅ Processed class from Bing: cat
✅ Processed class from Bing: deer
✅ Processed class from Bing: dog
✅ Processed class from Bing: frog
✅ Processed class from Bing: horse
✅ Processed class from Bing: ship
✅ Processed class from Bing: truck
✅ Training set: (4000, 32, 32, 3), labels: (4000, 10)
✅ Testing set: (1000, 32, 32, 3), labels: (1000, 10)


In [8]:
# Compute student model metrics without KD

scratch_student.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [9]:
# First, let us try to see what if we directly train the student model without using knowledge distillation

scratch_student.fit(X_train, train_labels, epochs=7, batch_size=32)


Epoch 1/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 31ms/step - accuracy: 0.1541 - loss: 3.0347
Epoch 2/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 50ms/step - accuracy: 0.2417 - loss: 2.2901
Epoch 3/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 63ms/step - accuracy: 0.2710 - loss: 2.1180
Epoch 4/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 61ms/step - accuracy: 0.3178 - loss: 1.9501
Epoch 5/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 32ms/step - accuracy: 0.3043 - loss: 1.9643
Epoch 6/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 36ms/step - accuracy: 0.3698 - loss: 1.8130
Epoch 7/7
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 25ms/step - accuracy: 0.3476 - loss: 1.8572


<keras.src.callbacks.history.History at 0x267dcef3a00>

In [10]:
# We evaluate student model for its loss and accuracy, if the student model is trained without using knowledge distillation

scratch_student.evaluate(X_test, test_labels)


[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 17ms/step - accuracy: 0.3835 - loss: 1.6962


[1.7730858325958252, 0.3919999897480011]

In [11]:
# Now let us try using knowledge distillation
# KNOWLEDGE DISTILLATION CLASS, You can adjust alpha based on how much you want the student to learn from the teacher

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)


In [12]:
# Initialize the distiller
# Train the student model using knowledge distillation

distiller = Distiller(student=student_model, teacher=teacher_model)

# Compiling the Distiller. You can adjust alpha based on how much you want the student to learn from the teacher
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[metrics.CategoricalAccuracy()],
    student_loss_fn=losses.CategoricalCrossentropy(),
    distillation_loss_fn=losses.CategoricalCrossentropy(),
    alpha=0.2,
    temperature=1,
)

# Fitting the student model receiving KD
history = distiller.fit(
    X_train,
    train_labels,
    epochs=7,
    batch_size=32,
    validation_split=0.2,
)


Epoch 1/7
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 46ms/step - categorical_accuracy: 0.1542 - loss: 2.2820 - val_categorical_accuracy: 0.0587 - val_loss: 2.2951
Epoch 2/7
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 52ms/step - categorical_accuracy: 0.2806 - loss: 2.2242 - val_categorical_accuracy: 0.1125 - val_loss: 2.3188
Epoch 3/7
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 45ms/step - categorical_accuracy: 0.3083 - loss: 2.2058 - val_categorical_accuracy: 0.2175 - val_loss: 2.3005
Epoch 4/7
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 48ms/step - categorical_accuracy: 0.3713 - loss: 2.1839 - val_categorical_accuracy: 0.2138 - val_loss: 2.3191
Epoch 5/7
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - categorical_accuracy: 0.4050 - loss: 2.1592 - val_categorical_accuracy: 0.1925 - val_loss: 2.3645
Epoch 6/7
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m

In [13]:
# We evaluate student model again for its loss and accuracy,
# But this time the student model is trained using knowledge distillation
# You can compare this results with the results above

distiller.evaluate(X_test, test_labels)


[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - categorical_accuracy: 0.4595 - loss: 2.1612


[2.21030592918396, 0.38199999928474426]