# Knowledge Distillation (with Quantization)
Knowledge Distillation (KD) is a technique where a smaller model (student) is trained to mimic a larger, pre-trained model (teacher). The student learns from the teacher's output probabilities (soft labels) rather than the hard ground-truth labels.

This notebook demonstrates:
1. Training a teacher model on the MNIST dataset.
2. Training a smaller student model using KD.
3. Applying full integer quantization to the student model.
4. Comparing model sizes and accuracies of:
   - Teacher Model
   - Student Model
   - Quantized Student Model


In [1]:
# Import required libraries
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

print(f"Training data shape: {x_train.shape}, Labels shape: {y_train.shape}")
print(f"Test data shape: {x_test.shape}, Labels shape: {y_test.shape}")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Training data shape: (60000, 28, 28, 1), Labels shape: (60000, 10)
Test data shape: (10000, 28, 28, 1), Labels shape: (10000, 10)


In [2]:
# Define the teacher model
def create_teacher_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

# Compile and train the teacher model
teacher_model = create_teacher_model()
teacher_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
teacher_model.fit(x_train, y_train, epochs=2, batch_size=32, validation_data=(x_test, y_test))

# Evaluate the teacher model
teacher_accuracy = teacher_model.evaluate(x_test, y_test, verbose=0)[1]
print(f"Teacher Model Accuracy: {teacher_accuracy:.4f}")


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/2
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 27ms/step - accuracy: 0.9124 - loss: 0.2993 - val_accuracy: 0.9780 - val_loss: 0.0675
Epoch 2/2
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 23ms/step - accuracy: 0.9842 - loss: 0.0525 - val_accuracy: 0.9824 - val_loss: 0.0497
Teacher Model Accuracy: 0.9824


## Knowledge Distillation
The student model is trained using Knowledge Distillation. Instead of directly using the ground-truth labels, the student learns from the teacher's output probabilities (soft labels) using the Kullback-Leibler (KL) divergence.


In [3]:
# Define the student model
def create_student_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model



student_model = create_student_model()

# Precompute teacher predictions (soft labels)
temperature = 5
teacher_soft_labels = teacher_model.predict(x_train, batch_size=32) / temperature

import numpy as np

# Combine ground-truth labels and teacher soft labels
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, teacher_soft_labels)).batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 7ms/step


In [4]:
# Custom Knowledge Distillation Loss
def kd_loss(y_true, y_pred, soft_labels, temperature=5):
    # Hard label loss (ground truth)
    hard_loss = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)

    # Soft label loss (teacher predictions for the batch)
    soft_loss = tf.keras.losses.KLDivergence()(
        tf.nn.softmax(soft_labels / temperature),
        tf.nn.softmax(y_pred / temperature)
    )

    # Combine the losses
    return 0.5 * hard_loss + 0.5 * soft_loss

# Compile the student model
optimizer = tf.keras.optimizers.Adam()
accuracy_metric = tf.keras.metrics.CategoricalAccuracy()

for epoch in range(2):  # Number of epochs
    print(f"Epoch {epoch + 1}/2")
    for step, (x_batch, y_batch, soft_labels) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass
            y_pred = student_model(x_batch, training=True)
            loss = kd_loss(y_batch, y_pred, soft_labels)

        # Backward pass
        grads = tape.gradient(loss, student_model.trainable_weights)
        optimizer.apply_gradients(zip(grads, student_model.trainable_weights))

        # Update the accuracy metric
        accuracy_metric.update_state(y_batch, y_pred)

        # Log progress
        if step % 100 == 0:
            print(f"Step {step}: Loss = {loss.numpy():.4f}, Accuracy = {accuracy_metric.result().numpy():.4f}")

    # Reset metrics at the end of each epoch
    accuracy_metric.reset_state()


Epoch 1/2
Step 0: Loss = 1.1550, Accuracy = 0.0625
Step 100: Loss = 0.1858, Accuracy = 0.7816
Step 200: Loss = 0.1939, Accuracy = 0.8464
Step 300: Loss = 0.1157, Accuracy = 0.8635
Step 400: Loss = 0.0758, Accuracy = 0.8787
Step 500: Loss = 0.1386, Accuracy = 0.8878
Step 600: Loss = 0.0476, Accuracy = 0.8982
Step 700: Loss = 0.0370, Accuracy = 0.9064
Step 800: Loss = 0.0606, Accuracy = 0.9126
Step 900: Loss = 0.0474, Accuracy = 0.9180
Step 1000: Loss = 0.1115, Accuracy = 0.9218
Step 1100: Loss = 0.0413, Accuracy = 0.9255
Step 1200: Loss = 0.0901, Accuracy = 0.9288
Step 1300: Loss = 0.0469, Accuracy = 0.9318
Step 1400: Loss = 0.0336, Accuracy = 0.9344
Step 1500: Loss = 0.0541, Accuracy = 0.9362
Step 1600: Loss = 0.0624, Accuracy = 0.9382
Step 1700: Loss = 0.0229, Accuracy = 0.9402
Step 1800: Loss = 0.0363, Accuracy = 0.9423
Epoch 2/2
Step 0: Loss = 0.0494, Accuracy = 0.9688
Step 100: Loss = 0.0316, Accuracy = 0.9746
Step 200: Loss = 0.0730, Accuracy = 0.9751
Step 300: Loss = 0.0167, Accu

In [5]:
# Compile the student model for evaluation
student_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',  # Use standard loss for evaluation
    metrics=['accuracy']
)

# Evaluate the student model
student_accuracy = student_model.evaluate(val_dataset, verbose=0)[1]
print(f"Student Model Accuracy: {student_accuracy:.4f}")


Student Model Accuracy: 0.9776


In [6]:
# Representative dataset generator
def representative_data_gen():
    for input_value in x_test[:100]:
        # Add a batch dimension (convert [28, 28, 1] to [1, 28, 28, 1])
        yield [input_value.reshape(1, 28, 28, 1).astype("float32")]


# Convert the student model to TensorFlow Lite format with full integer quantization
converter = tf.lite.TFLiteConverter.from_keras_model(student_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen  # Updated function
converter.target_spec.supported_types = [tf.int8]
quantized_student_model = converter.convert()

# Save the quantized model
with open("quantized_student_model.tflite", "wb") as f:
    f.write(quantized_student_model)
print("Quantized Student Model saved.")



Saved artifact at '/tmp/tmpy1ihoaly'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_6')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  133862297658368: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862249073824: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862249070832: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862248627040: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862248627392: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862248615424: TensorSpec(shape=(), dtype=tf.resource, name=None)




Quantized Student Model saved.


In [7]:
import os

# Compare model sizes
model_files = {
    "Teacher Model": "teacher_model.tflite",
    "Student Model": "student_model.tflite",
    "Quantized Student Model": "quantized_student_model.tflite"
}

# Save the teacher and student models to TFLite format for size comparison
for model_name, model in [("teacher_model.tflite", teacher_model), ("student_model.tflite", student_model)]:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    with open(model_name, "wb") as f:
        f.write(tflite_model)

print("\nModel Sizes (KB):")
for name, file in model_files.items():
    print(f"{name}: {os.path.getsize(file) / 1024:.2f} KB")

# Evaluate the quantized student model
def evaluate_tflite_model(tflite_model_path):
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    correct_predictions = 0
    for i in range(len(x_test)):
        input_data = x_test[i:i+1].astype("float32")
        interpreter.set_tensor(input_details[0]['index'], input_data)
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        if np.argmax(output_data) == np.argmax(y_test[i]):
            correct_predictions += 1

    return correct_predictions / len(x_test)

quantized_student_accuracy = evaluate_tflite_model("quantized_student_model.tflite")
print(f"Quantized Student Model Accuracy: {quantized_student_accuracy:.4f}")


Saved artifact at '/tmp/tmpnr10qpfj'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  133862298811232: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862298803136: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862298809648: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862298805776: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862298810000: TensorSpec(shape=(), dtype=tf.resource, name=None)
  133862298570768: TensorSpec(shape=(), dtype=tf.resource, name=None)
Saved artifact at '/tmp/tmp44hftwxq'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_6')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  1338622976

# Summary

- The student model achieves comparable accuracy to the teacher model while being significantly smaller.
- Full integer quantization reduces the student model size further, with a slight accuracy drop.
- Knowledge Distillation enables smaller models to effectively mimic larger, more complex models.
