# Knowledge Distillation of a Neural Network
Demonstration of the Knowledge Distillation method on MNIST data
- Train an **"expensive"**, **large** model on MNIST data that generalizes well to unseen data. This is the **teacher** model.
- Create a much **cheaper and smaller** model that, instead of the actual labels, is trained on the teacher model's predictions for the training data. These predictions are the **soft labels**, and this model is called the **student**.
- Compare its performance with the same model trained on the actual labels. It turns out that the student model generalizes much better to the test data.

In [None]:
# Load IPython's autoreload extension; mode 2 re-imports modules before running code.
# NOTE(review): no local modules are imported in the visible cells, so this may be unnecessary.
%load_ext autoreload
%autoreload 2

In [None]:
# Utils
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, cmap=plt.cm.Blues, figsize=(5, 5)):
    """Plot a confusion matrix as an annotated heatmap.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        classes: tick labels for the matrix rows/columns. ``confusion_matrix`` is
            called without ``labels``, so rows/columns follow the sorted set of
            labels occurring in ``y_true`` or ``y_pred``; give one tick label per
            such label, in that order.
        normalize: if True, divide each row by its total so each cell shows the
            fraction of a true class predicted as each class. Rows whose class
            never occurs in ``y_true`` stay at 0 instead of becoming NaN.
        cmap: matplotlib colormap for the heatmap.
        figsize: figure size in inches, passed to ``plt.subplots``.

    Returns:
        The matplotlib ``Axes`` containing the plot.
    """
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Avoid 0/0 for classes that only ever appear in y_pred.
        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')

    # Tick labels are horizontal, so centre them under their ticks.
    plt.setp(ax.get_xticklabels(), rotation=0, ha="center")

    # Annotate every cell; values above half the maximum get white text for contrast.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax

In [None]:
from time import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, confusion_matrix
# `cross_val_score` was misspelled, which raised ImportError and stopped the notebook.
from sklearn.model_selection import cross_val_score, KFold

## MNIST data

In [None]:
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()


def _prepare_images(images):
    """Scale pixel values by 1/255, append a trailing channels axis, and cast to float32."""
    return (images / 255.0)[..., tf.newaxis].astype("float32")


X_train, X_test = _prepare_images(X_train), _prepare_images(X_test)

In [None]:
first_image = X_train[0, ..., 0]  # first training image with its channel axis dropped
plt.imshow(first_image)

In [None]:
# Number of training and test images.
num_train, num_test = len(X_train), len(X_test)
num_train, num_test

In [None]:
# Fix TensorFlow's global random seed (weight init, dropout) for repeatable runs.
tf.random.set_seed(seed=0)

## Teacher Model
Train a large and compute-intensive model that uses Dropout and generalizes well on test data.  
Here "large" means a wide network: a convolutional layer with 32 filters followed by two fully connected hidden layers of 1,200 units each.  

In [None]:
class TeacherModel(Model):
    """Large convolutional "teacher" network with a temperature-scaled softmax output.

    Layers: Conv2D(32 filters, 3x3, ReLU) -> Flatten -> Dense(1200, ReLU) -> Dropout(0.5)
    -> Dense(1200, ReLU) -> Dropout(0.5) -> Dense(10) -> softmax(logits / T).

    Args:
        T: softmax temperature. The logits are divided by ``T`` before the softmax,
            so ``T > 1`` gives softer class probabilities and ``T == 1`` is the
            ordinary softmax. Must be positive.

    Raises:
        ValueError: if ``T`` is not positive.
    """

    def __init__(self, T: float):
        super().__init__()
        if T <= 0:
            raise ValueError(f"Softmax temperature T must be positive, got {T!r}")

        self.T = T

        self.conv1 = Conv2D(32, 3, activation="relu")
        self.flatten = Flatten()

        self.d1 = Dense(1200, activation="relu")
        self.d2 = Dense(1200, activation="relu")
        self.d3 = Dense(10)  # class logits

        # Dropout has no weights, so a single instance is shared by both hidden layers.
        self.dropout_layer_hidden = tf.keras.layers.Dropout(rate=0.5)

        self.output_layer = tf.keras.layers.Softmax()

    def call(self, x, training=None):
        """Return class probabilities ``softmax(logits / T)`` for the batch ``x``.

        ``training`` is forwarded explicitly to the dropout layers so dropout is
        active only while fitting (Keras passes True inside ``fit`` and False in
        ``predict``/``evaluate``).
        """
        x = self.conv1(x)
        x = self.flatten(x)

        x = self.d1(x)
        x = self.dropout_layer_hidden(x, training=training)

        x = self.d2(x)
        x = self.dropout_layer_hidden(x, training=training)

        x = self.d3(x)
        x = self.output_layer(x / self.T)
        return x

In [None]:
# Softmax temperature: the teacher divides its logits by T before the softmax.
T = 3.5
teacher = TeacherModel(T)

### Train teacher model

In [None]:
# The teacher already ends in a softmax, so the loss consumes probabilities, not logits.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()

teacher.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

In [None]:
num_epochs, batch_size = 1, 32

# Keras holds out the last 20% of the training samples for validation.
teacher.fit(
    X_train, y_train,
    epochs=num_epochs, batch_size=batch_size,
    validation_split=0.2, verbose=1,
)

### Evaluate generalization of teacher model

In [None]:
# Hard class decisions of the teacher on the test set.
y_pred_teacher = teacher(X_test).numpy().argmax(axis=1)

In [None]:
acc = accuracy_score(y_true=y_test, y_pred=y_pred_teacher)  # fraction of test images classified correctly
acc

In [None]:
# Number of misclassified test images. Round instead of truncating: the float product
# can land just below the true integer (e.g. 247.99999999999997), which int() turns into 247.
int(round((1 - acc) * num_test))

In the recorded run, the teacher model made **248 test errors** (the exact count can vary between runs).

In [None]:
digit_classes = list(range(10))
plot_confusion_matrix(y_test, y_pred_teacher, classes=digit_classes, normalize=True)
plt.show()

### Compute teacher predictions on the training set as well
These soft labels are used later to train the student model.

In [None]:
# Soft labels for the student: the teacher's temperature-softened class probabilities
# on every training image. Inference runs in mini-batches via predict(). A single
# teacher(X_train) call would hold the conv activations for the whole training set at
# once (60000 x 26 x 26 x 32 float32, roughly 5 GB) and can exhaust memory.
# The result is wrapped back into a tensor so later `.numpy()` calls keep working.
y_train_pred_teacher = tf.convert_to_tensor(
    teacher.predict(X_train, batch_size=1024, verbose=0)
)

## Student model
A much smaller, shallower neural network is used as the student model.

In [None]:
class StudentModel(Model):
    """Small fully connected "student" network with a temperature-scaled softmax output.

    Layers: Flatten -> Dense(10, ReLU) -> Dense(10, ReLU) -> Dense(10) -> softmax(logits / T).

    Args:
        T: softmax temperature. The logits are divided by ``T`` before the softmax;
            it should match the temperature used to produce the teacher's soft labels.
            Must be positive.

    Raises:
        ValueError: if ``T`` is not positive.
    """

    def __init__(self, T):
        super().__init__()
        if T <= 0:
            raise ValueError(f"Softmax temperature T must be positive, got {T!r}")

        self.T = T

        # Flatten accepts any per-sample shape (this notebook feeds 28x28x1 images),
        # so no input_shape is declared.
        self.input_layer = Flatten()
        self.d1 = Dense(10, activation="relu")
        self.d2 = Dense(10, activation="relu")
        self.d3 = Dense(10)  # class logits
        self.output_layer = tf.keras.layers.Softmax()

    def call(self, x):
        """Return class probabilities ``softmax(logits / T)`` for the batch ``x``."""
        x = self.input_layer(x)
        x = self.d1(x)
        x = self.d2(x)
        x = self.d3(x)
        x = self.output_layer(x / self.T)
        return x

## Distill the knowledge of the large teacher model into the small student model, without ever showing the student an example of the digit 3

In [None]:
# Distillation transfer set with every "3" removed. The student never sees a 3, so
# anything it learns about 3s must come from the teacher's soft labels on other digits.
# Boolean-mask indexing works for any training-set size and yields NumPy arrays.
is_not_three = y_train != 3
new_X_train = X_train[is_not_three]
new_y_train = y_train[is_not_three]
teacher_labels = np.asarray(y_train_pred_teacher)[is_not_three]


In [None]:
# Reuse the teacher's softmax temperature so the student's outputs are on the same
# scale as the soft labels it is trained to match.
student_model = StudentModel(T=T)

In [None]:
# Targets are the teacher's probability vectors, hence categorical (not sparse) cross-entropy.
loss = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

student_model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

In [None]:
t0 = time()
num_epochs = 3
batch_size = 32

# Cross-validate on the distillation transfer set: training images with every "3"
# removed, paired with the teacher's soft predictions. np.asarray accepts both the
# list and the array form of these variables.
X_transfer = np.asarray(new_X_train)
y_transfer = np.asarray(teacher_labels)

cv = 3
scores = np.zeros(cv)
kf = KFold(n_splits=cv)
for i, (train_index, test_index) in enumerate(kf.split(X_transfer)):
    X_train_kf, X_test_kf = X_transfer[train_index], X_transfer[test_index]
    y_train_kf, y_test_kf = y_transfer[train_index], y_transfer[test_index]

    # Train a fresh student for every fold. Reusing one model would carry over weights
    # fitted on earlier folds, whose training data overlaps this fold's held-out data,
    # and inflate the scores.
    fold_model = StudentModel(T=student_model.T)
    fold_model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["accuracy"])
    fold_model.fit(X_train_kf, y_train_kf, batch_size=batch_size, epochs=num_epochs, verbose=0)
    y_pred = np.argmax(fold_model.predict(X_test_kf, verbose=0), axis=1)

    # Agreement with the teacher's hard decision on the held-out images.
    score = accuracy_score(np.argmax(y_test_kf, axis=1), y_pred)
    scores[i] = score

    print(f"Fold: {i + 1}, accuracy: {score}")
t1 = time()
print(t1-t0)

In [None]:
np.mean(scores)  # average accuracy across the CV folds

In [None]:
# Fit the student on the transfer set that excludes every "3" (teacher soft labels as
# targets), then predict on the full test set, which still contains 3s.
student_model.fit(np.asarray(new_X_train), np.asarray(teacher_labels),
                  batch_size=batch_size, epochs=num_epochs, verbose=0)
y_pred_student = np.argmax(student_model(X_test), axis=1)

In [None]:
acc = accuracy_score(y_true=y_test, y_pred=y_pred_student)  # student test-set accuracy
acc

In [None]:
# Number of misclassified test images; round rather than truncate the float product.
int(round((1 - acc) * num_test))

## Inference times

In [None]:
from time import perf_counter

# Teacher inference time per test image, in seconds. perf_counter is a monotonic,
# high-resolution clock meant for measuring intervals, unlike wall-clock time().
t0 = perf_counter()
np.argmax(teacher(X_test), axis=1)
t1 = perf_counter()
i1 = (t1 - t0) / len(X_test)

In [None]:
from time import perf_counter

# Student inference time per test image, in seconds, measured with the monotonic
# interval clock (perf_counter) rather than wall-clock time().
t0 = perf_counter()
np.argmax(student_model(X_test), axis=1)
t1 = perf_counter()
i2 = (t1 - t0) / len(X_test)

In [None]:
# Per-image inference time (seconds) for teacher and student, then the teacher/student ratio.
teacher_secs = np.round(i1, 8)
student_secs = np.round(i2, 9)
print(teacher_secs, student_secs)
print(np.round(i1 / i2, 3))