In [1]:
import tensorflow as tf
import numpy as np
print("TensorFlow Version:", tf.__version__)

TensorFlow Version: 2.13.0


In [2]:
devices = tf.config.list_physical_devices()
print(devices)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## Data Preparation

In [3]:
batch_size = 64
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# normalise data
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# add a channels dimension
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = np.reshape(x_test, (-1, 28, 28, 1))

## Define and Train Teacher Model (Teracher Model used for Knowledge Distillation)

In [4]:
# create teacher model

teacher = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
    tf.keras.layers.ReLU(negative_slope=0.2),
    tf.keras.layers.MaxPooling2D(pool_size=(
        2, 2), strides=(1, 1), padding="same"),
    tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
], name="teacher")


teacher.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

teacher.fit(x_train, y_train, batch_size=batch_size, epochs=5, verbose=2)

teacher.evaluate(x_test, y_test, batch_size=batch_size, verbose=2)

Epoch 1/5


2023-12-31 11:27:33.751349: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2023-12-31 11:27:33.751372: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2023-12-31 11:27:33.751377: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2023-12-31 11:27:33.751411: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-12-31 11:27:33.751427: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2023-12-31 11:27:34.060641: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


938/938 - 29s - loss: 0.1489 - sparse_categorical_accuracy: 0.9531 - 29s/epoch - 31ms/step
Epoch 2/5
938/938 - 30s - loss: 0.0772 - sparse_categorical_accuracy: 0.9766 - 30s/epoch - 32ms/step
Epoch 3/5
938/938 - 29s - loss: 0.0654 - sparse_categorical_accuracy: 0.9799 - 29s/epoch - 31ms/step
Epoch 4/5
938/938 - 29s - loss: 0.0598 - sparse_categorical_accuracy: 0.9813 - 29s/epoch - 31ms/step
Epoch 5/5
938/938 - 30s - loss: 0.0546 - sparse_categorical_accuracy: 0.9839 - 30s/epoch - 32ms/step


2023-12-31 11:30:01.214182: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


157/157 - 2s - loss: 0.0666 - sparse_categorical_accuracy: 0.9800 - 2s/epoch - 10ms/step


[0.06659075617790222, 0.9800000190734863]

In [5]:
# define distiller class, Pass teacher model and student model
# we will train student model to mimic teacher model
# our total loss will be a combination of student loss and distillation loss with some weight factor
# distillation loss is calculated uasing KL divergence between teacher and student logits (student model should learn output probability distribution of teacher model)
# we will use softmax activation with temperature factor to soften the logits


class Distiller(tf.keras.Model):
    def __init__(self, teacher, student):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def call(self, x):
        return self.student(x)
    
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
    
    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)
        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature ** 2)
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss



In [7]:
student = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
    tf.keras.layers.ReLU(negative_slope=0.2),
    tf.keras.layers.MaxPooling2D(pool_size=(
        2, 2), strides=(1, 1), padding="same"),
    tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
], name="students")


# Clone student for later comparison
student_scratch = tf.keras.models.clone_model(student)

In [9]:
distiller = Distiller(teacher=teacher, student=student)

distiller.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)


# Distill teacher to student
distiller.fit(x_train, y_train, epochs=10)

# Evaluate student on test dataset
test_accuracy = distiller.evaluate(x_test, y_test)
print("student model accuracy:", test_accuracy)

Epoch 1/10


2023-12-31 11:44:08.669326: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
 51/313 [===>..........................] - ETA: 0s - sparse_categorical_accuracy: 0.9749

2023-12-31 11:50:05.393172: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


student model accuracy: 0.9811999797821045
