In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Reference: Keras knowledge distillation example — https://keras.io/examples/vision/knowledge_distillation/

In [None]:
class Distiller(keras.Model):
  def __init__(self, student, teacher):
    """Bundle a student model and a teacher model into one Keras model.

    Args:
      student: model that `train_step` optimises and `test_step` evaluates.
      teacher: model that `train_step` calls with `training=False` to obtain
        the predictions the student is distilled from.
    """
    # The parent constructor must run before any attribute is attached.
    super().__init__()
    self.teacher = teacher
    self.student = student

  def compile(
      self,
      optimizer,
      metrics,
      student_loss_fn,
      distillation_loss_fn,
      alpha=0.1,
      temperature=3,
  ):
    """Configure the distiller for training.

    Args:
      optimizer: forwarded unchanged to the parent `compile`.
      metrics: forwarded unchanged to the parent `compile`.
      student_loss_fn: callable used as `fn(y, student_predictions)`.
      distillation_loss_fn: callable used on the temperature-softened softmax
        outputs of the teacher (first argument) and student (second argument).
      alpha: weight given to the student loss; the distillation loss is
        weighted by `1 - alpha`.
      temperature: value both models' outputs are divided by before the
        softmax.
    """
    super().compile(optimizer=optimizer, metrics=metrics)

    # Stored on the instance so the custom train/test steps can read them.
    self.student_loss_fn, self.distillation_loss_fn = (
        student_loss_fn,
        distillation_loss_fn,
    )
    self.alpha, self.temperature = alpha, temperature

  def train_step(self, data):
    """Run one optimisation step of the student on a single batch.

    The teacher is only evaluated (inference mode, outside the gradient tape);
    gradients are taken with respect to the student's trainable variables
    alone, so the teacher's weights are never updated here.

    Args:
      data: an `(x, y)` pair of inputs and targets for the batch.

    Returns:
      A dict mapping each compiled metric's name to its current result, plus
      the entries `"student_loss"` and `"distillation_loss"` for this batch.
      `"distillation_loss"` is reported after the `temperature ** 2` scaling.
    """
    x, y = data

    # Teacher forward pass stays outside the tape: it is not being trained.
    teacher_predictions = self.teacher(x, training=False)

    with tf.GradientTape() as tape:
      # Student forward pass is recorded so the loss can be differentiated.
      student_predictions = self.student(x, training=True)

      # Loss of the student against the batch targets.
      student_loss = self.student_loss_fn(y, student_predictions)

      # Soften both outputs with the same temperature before comparing them.
      # NOTE(review): axis=1 assumes outputs shaped (batch, classes) - confirm.
      soft_teacher = tf.nn.softmax(teacher_predictions / self.temperature, axis=1)
      soft_student = tf.nn.softmax(student_predictions / self.temperature, axis=1)

      # Gradients of the soft-target term scale roughly as 1 / temperature**2.
      # Multiplying by temperature**2 keeps this term on the same footing as
      # the student loss, so `alpha` keeps its meaning when the temperature
      # changes.
      distillation_loss = (
          self.distillation_loss_fn(soft_teacher, soft_student)
          * self.temperature ** 2
      )

      loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

    # Differentiate the combined loss w.r.t. the student's variables only.
    trainable_vars = self.student.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # Update the student's weights.
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update the metrics configured in `compile()`.
    self.compiled_metrics.update_state(y, student_predictions)

    # Report metric values together with the two loss components.
    results = {m.name: m.result() for m in self.metrics}
    results.update(
        {"student_loss": student_loss, "distillation_loss": distillation_loss}
    )
    return results

  def test_step(self, data):
    x , y = data

    y_prediction = self.student(x, training=False)
    student_loss = self.student_loss_fn(y, y_prediction)
    