In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras.utils import to_categorical

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape, Conv2D, MaxPooling2D, Input

# Hide warnings
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [2]:
from sklearn.datasets import fetch_openml
mnist = fetch_openml(name='mnist_784')

In [3]:
nr_images = len(mnist.data.values)
all_images = mnist.data.values
all_labels = np.array(list(map(int, mnist.target.values)))

In [4]:
all_images = all_images / 255.0

In [5]:
X_train, X_test, y_train, y_test = train_test_split(
    all_images, all_labels, test_size=0.3, shuffle=True, random_state=1337
)

In [6]:
X_train = X_train.reshape(-1, 28, 28, 1)

In [7]:
X_test = X_test.reshape(-1, 28, 28, 1)

In [8]:
X_train.shape

(49000, 28, 28, 1)

In [9]:
to_categorical(y_train).shape

(49000, 10)

# Create and train Teacher

In [15]:
teacher = create_teacher_model()
fit_model_to_data(teacher, X_train, y_train)
#evaluate_model(teacher, X_test, y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


# Distill knowledge

In [16]:
student = create_student_model()

student_copy = keras.models.clone_model(student)

distiller = Distiller(student=student, teacher=teacher)

distiller.compile(
    optimizer='adam',
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=3,
)
distiller.fit(X_train, y_train.astype(int), epochs=3)

evaluate_model(student, X_test, y_test)


Epoch 1/3
Epoch 2/3
Epoch 3/3


0.9737130386313384

In [17]:
student_copy.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer='adam'
)

fit_model_to_data(student_copy, X_train, y_train, 3)
evaluate_model(student_copy, X_test, y_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3


0.9664464411270824

In [11]:
def create_teacher_model():
    model = Sequential(name="MNIST_Classifier")
    model.add(Input(shape=(28, 28, 1)))
    model.add(Conv2D(256, (3, 3), strides=(2, 2), padding="same", activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"))
    model.add(Conv2D(512, (3, 3), strides=(2, 2), padding="same", activation='relu'))
    model.add(Flatten())
    model.add(Dense(10))
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        optimizer='adam'
    )
    return model

In [12]:
def create_student_model():
    model = Sequential(name="MNIST_Classifier")
    model.add(Input(shape=(28, 28, 1)))
    model.add(Conv2D(8, (3, 3), strides=(2, 2), padding="same", activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"))
    model.add(Conv2D(16, (3, 3), strides=(2, 2), padding="same", activation='relu'))
    model.add(Flatten())
    model.add(Dense(10))
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        optimizer='adam'
    )
    return model

In [13]:
def fit_model_to_data(model, x_train, y_train, epochs=5):
    model.fit(
        x_train, # Samples
        y_train, # Labels
        batch_size=64,
        epochs=epochs
        #verbose=0
    )

def evaluate_model(model, X_test, y_test):
    one_hot_predictions = model.predict(X_test)
    label_predictions = np.argmax(one_hot_predictions, axis=1)
    y_test = y_test.astype(int)
    return f1_score(y_test, label_predictions, average='weighted')

In [14]:
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results