<a href="https://colab.research.google.com/github/KeremAydin98/knowledge-distillation/blob/main/KD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import tensorflow as tf
import numpy as np

In [19]:
class Distiller(tf.keras.models.Model):
  """
  For a distiller we need:
  - A trained teacher model
  - A student model to train 
  - A student loss function on the difference between student predictions and ground 
  truth
  - A distillation loss function, along with a temperature, on the difference between 
  the soft student predictions and soft teacher labels
  - An alpha factor to weight the student and distillation loss
  - An optimizer for the student and metrics to evaluate performance
  """
  def __init__(self, student, teacher):

    super(Distiller, self).__init__()
    self.teacher = teacher
    self.student = student

  def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn,
              alpha=0.1, temperature=3):
    
    super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
    self.student_loss_fn = student_loss_fn
    self.distillation_loss_fn = distillation_loss_fn
    self.alpha = alpha
    self.temperature = temperature

  def train_step(self, data):

    # Unpack the data
    x, y = data

    # Forward pass of teacher
    teacher_pred = self.teacher(x, training=False)

    with tf.GradientTape() as tape:

      # Forward pass of the student 
      student_pred = self.student(x, training=True)

      # Compute losses
      student_loss = self.student_loss_fn(y, student_pred)

      # Compute scaled distillation loss
      distillation_loss = (self.distillation_loss_fn(
          tf.nn.softmax(teacher_pred / self.temperature, axis=1),
          tf.nn.softmax(student_pred / self.temperature, axis=1),
      )
      * self.temperature ** 2)

      loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss


    # Compute gradients 
    trainable_vars = self.student.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # Update weights 
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update the metrics 
    self.compiled_metrics.update_state(y, student_pred)

    # Return a dict of performance
    results = {m.name: m.result() for m in self.metrics}
    results.update(
        {"student_loss": student_loss, "distillation_loss":distillation_loss}
    )

    return results

  def test_step(self, data):
                
    # Unpack the data
    x, y = data

    # Compute predictions 
    y_prediction = self.student(x, training=False)

    # Calculate the loss
    student_loss = self.student_loss_fn(y, y_prediction)

    # Update the metrics
    self.compiled_metrics.update_state(y, y_prediction)

    # Return a dict of performance 
    results = {m.name: m.result() for m in self.metrics}
    results.update({"student_loss": student_loss})

    return results


In [20]:
teacher = tf.keras.Sequential([
                               tf.keras.layers.Input(shape=(32, 32, 3)),
                               tf.keras.layers.Conv2D(256, (3,3), 2, padding="same"),
                               tf.keras.layers.LeakyReLU(0.2),
                               tf.keras.layers.MaxPool2D(2, 1, padding="same"),
                               tf.keras.layers.Conv2D(512, 3, 2, padding="same"),
                               tf.keras.layers.Flatten(),
                               tf.keras.layers.Dense(10, activation="softmax")
])

student = tf.keras.Sequential([
                               tf.keras.layers.Input(shape=(32, 32, 3)),
                               tf.keras.layers.Conv2D(16, (3,3), 2, padding="same"),
                               tf.keras.layers.LeakyReLU(0.2),
                               tf.keras.layers.MaxPool2D(2, 1, padding="same"),
                               tf.keras.layers.Conv2D(32, 3, 2, padding="same"),
                               tf.keras.layers.Flatten(),
                               tf.keras.layers.Dense(10, activation="softmax")
])

# Clone student for later comparison
student_scratch = tf.keras.models.clone_model(student)

In [21]:
BATCH_SIZE = 64
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Normalize the data
x_train = x_train.astype("float32") / 255.0

x_test = x_test.astype("float32") / 255.0


In [22]:
x_train.shape, y_train.shape, x_test.shape, y_test.shape

((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))

In [23]:
# train teacher
teacher.compile(optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                metrics=["accuracy"])

teacher.fit(x_train, y_train, epochs=10)
teacher.evaluate(x_test, y_test)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[1.3371248245239258, 0.597000002861023]

In [24]:
# Now we can distill the teacher to student
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10
)

# Distill teacher to student 
distiller.fit(x_train, y_train, epochs=10)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.6205000281333923, 0.923244833946228]

In [25]:
# Train student from scratch for comparison
student_scratch.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

student_scratch.fit(x_train, y_train, epochs=10)
student_scratch.evaluate(x_test,y_test)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[1.1153295040130615, 0.6187000274658203]