In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

# One-hot encode the labels
y_train_onehot = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=10)


In [2]:
# Define the Teacher Model
teacher_model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
teacher_model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
2025-09-13 10:15:38.668328: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-09-13 10:15:38.668371: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-09-13 10:15:38.668376: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
I0000 00:00:1757733338.668674 5229756 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1757733338.668996 5229756 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [3]:
teacher_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])


In [4]:
# Train the Teacher Model
teacher_model.fit(x_train, y_train_onehot, epochs=5, validation_split=0.1)


Epoch 1/5


2025-09-13 10:15:48.192820: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 14ms/step - accuracy: 0.9046 - loss: 0.3052 - val_accuracy: 0.9865 - val_loss: 0.0469
Epoch 2/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 14ms/step - accuracy: 0.9823 - loss: 0.0581 - val_accuracy: 0.9882 - val_loss: 0.0435
Epoch 3/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 14ms/step - accuracy: 0.9846 - loss: 0.0516 - val_accuracy: 0.9863 - val_loss: 0.0770
Epoch 4/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 14ms/step - accuracy: 0.9846 - loss: 0.0669 - val_accuracy: 0.9875 - val_loss: 0.0767
Epoch 5/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 14ms/step - accuracy: 0.9839 - loss: 0.1040 - val_accuracy: 0.9827 - val_loss: 0.2430


<keras.src.callbacks.history.History at 0x380a05540>

In [5]:
# Generate Soft Labels from the Teacher Model
soft_labels = teacher_model.predict(x_train)


[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 1ms/step


In [6]:
# Define the Student Model
student_model = models.Sequential([
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])
student_model.summary()

In [8]:
# Define Knowledge Distillation Loss
def distillation_loss(y_true, y_pred, teacher_pred, temperature=5):
    # Scale predictions by temperature
    teacher_pred_scaled = tf.nn.softmax(teacher_pred / temperature)
    y_pred_scaled = tf.nn.softmax(y_pred / temperature)
    
    # Cross-entropy between teacher and student predictions
    kd_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(teacher_pred_scaled, y_pred_scaled)
    )
    # Add standard cross-entropy loss with true labels
    ce_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    )
    return kd_loss * 0.5 + ce_loss * 0.5


In [9]:
epochs = 1
batch_size = 32
num_batches = len(x_train) // batch_size

print("num_batches : ",num_batches)

num_batches :  1875


In [10]:
optimizer = tf.keras.optimizers.Adam()

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    for i in range(num_batches):
        # Get a batch of data
        start = i * batch_size
        end = start + batch_size
        x_batch = x_train[start:end]
        y_batch = y_train_onehot[start:end]
        
        with tf.GradientTape() as tape:
            predictions = student_model(x_batch, training=True)
            teacher_predictions = teacher_model(x_batch, training=True)
            loss = distillation_loss(y_batch, predictions, teacher_predictions)


        gradients = tape.gradient(loss, student_model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, student_model.trainable_weights))

        if i % 200 == 0:  # Print progress every 200 batches
            print(f"Batch {i}/{num_batches}, Loss: {loss.numpy():.4f}")


Epoch 1/1
Batch 0/1875, Loss: 2.3180
Batch 200/1875, Loss: 1.3109
Batch 400/1875, Loss: 1.2107
Batch 600/1875, Loss: 1.1804
Batch 800/1875, Loss: 1.2061
Batch 1000/1875, Loss: 1.2750
Batch 1200/1875, Loss: 1.2280
Batch 1400/1875, Loss: 1.1707
Batch 1600/1875, Loss: 1.2065
Batch 1800/1875, Loss: 1.1764


In [11]:

# Evaluate the Student Model
student_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
student_model.evaluate(x_test, y_test_onehot, verbose=2)

313/313 - 3s - 8ms/step - accuracy: 0.9536 - loss: 0.1645


[0.1644923835992813, 0.9535999894142151]