<a href="https://colab.research.google.com/github/KatrinaZhang/deep-learning-coursework/blob/main/Knowledge_Distillation_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Distilling Knowledge in Neural Network

The term "Knowledge Distillation" (a.k.a Teacher-Student Model) was first introduced by (Bu-cilu et al., 2006; Ba & Caruana,2014) and has been popularized by (Hinton et al., 2015), as a way to let smaller deep learning models learn how bigger ones generalize to large datasets, hence increase the performance of the smaller one. In this notebook, I'll try to explain the idea of knowledge distillation alongside with hands-on implementation of it.

# The main idea


# Install and import requirements


In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.layers import Conv2D,GlobalAveragePooling2D,Dense,Softmax,Flatten,MaxPooling2D,Dropout,Activation, Lambda, concatenate
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.losses import kullback_leibler_divergence as KLD_Loss, categorical_crossentropy as logloss
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.metrics import categorical_accuracy
import seaborn as sns

#  Load and preprocess the data

In [2]:
NUM_CLASSES = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)

# Normalize the dataset
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Reshape each example to WIDTH*HEIGHT*CHANNELS for Convolution operation
# x_test = x_test.reshape(-1,,28,1)
# x_train = x_train.reshape(-1,28,28,1)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 0us/step
x_train shape: (50000, 32, 32, 3) y_train shape: (50000, 1)


# Create teacher model

In [3]:
Teacher = Sequential() # Must define the input shape in the first layer of the neural network
Teacher.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu', input_shape=(32,32,3)))
Teacher.add(MaxPooling2D(pool_size=2))
Teacher.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
Teacher.add(MaxPooling2D(pool_size=2))
Teacher.add(Flatten())
Teacher.add(Dense(256, activation='relu'))
Teacher.add(Dropout(0.5))
Teacher.add(Dense(10))
Teacher.add(Activation('softmax'))

Teacher.compile(loss='sparse_categorical_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])

# Take a look at the model summary

Teacher.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [4]:
myCP = ModelCheckpoint(save_best_only=True,filepath='teacher.h5',monitor = 'val_accuracy')
Teacher.fit(x_train,
         y_train,
         batch_size=128,
         epochs=20,
         validation_split = 0.2,
         callbacks=[myCP])

Epoch 1/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.3108 - loss: 1.8848



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 15ms/step - accuracy: 0.3111 - loss: 1.8840 - val_accuracy: 0.5398 - val_loss: 1.3315
Epoch 2/20
[1m308/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.5256 - loss: 1.3346



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.5259 - loss: 1.3339 - val_accuracy: 0.6124 - val_loss: 1.1430
Epoch 3/20
[1m308/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.5767 - loss: 1.1852



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.5769 - loss: 1.1849 - val_accuracy: 0.6261 - val_loss: 1.0779
Epoch 4/20
[1m305/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.6146 - loss: 1.0918



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6147 - loss: 1.0916 - val_accuracy: 0.6487 - val_loss: 1.0101
Epoch 5/20
[1m308/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.6389 - loss: 1.0227



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.6390 - loss: 1.0225 - val_accuracy: 0.6648 - val_loss: 0.9642
Epoch 6/20
[1m304/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.6601 - loss: 0.9565



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6601 - loss: 0.9567 - val_accuracy: 0.6772 - val_loss: 0.9394
Epoch 7/20
[1m306/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.6796 - loss: 0.9244



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6796 - loss: 0.9242 - val_accuracy: 0.6811 - val_loss: 0.9188
Epoch 8/20
[1m307/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.6900 - loss: 0.8732



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.6900 - loss: 0.8732 - val_accuracy: 0.6912 - val_loss: 0.8919
Epoch 9/20
[1m307/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.7000 - loss: 0.8394



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7000 - loss: 0.8394 - val_accuracy: 0.6979 - val_loss: 0.8811
Epoch 10/20
[1m308/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.7214 - loss: 0.7952



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7214 - loss: 0.7953 - val_accuracy: 0.7080 - val_loss: 0.8444
Epoch 11/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7330 - loss: 0.7528 - val_accuracy: 0.7070 - val_loss: 0.8555
Epoch 12/20
[1m305/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.7429 - loss: 0.7313



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.7428 - loss: 0.7313 - val_accuracy: 0.7092 - val_loss: 0.8364
Epoch 13/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.7521 - loss: 0.7018



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.7521 - loss: 0.7019 - val_accuracy: 0.7123 - val_loss: 0.8527
Epoch 14/20
[1m309/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.7627 - loss: 0.6716



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.7627 - loss: 0.6715 - val_accuracy: 0.7131 - val_loss: 0.8346
Epoch 15/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7755 - loss: 0.6368 - val_accuracy: 0.7083 - val_loss: 0.8403
Epoch 16/20
[1m306/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.7791 - loss: 0.6129



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7791 - loss: 0.6130 - val_accuracy: 0.7140 - val_loss: 0.8511
Epoch 17/20
[1m305/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.7944 - loss: 0.5732



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.7943 - loss: 0.5735 - val_accuracy: 0.7162 - val_loss: 0.8526
Epoch 18/20
[1m308/313[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.8034 - loss: 0.5453



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.8034 - loss: 0.5456 - val_accuracy: 0.7210 - val_loss: 0.8369
Epoch 19/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.8119 - loss: 0.5308 - val_accuracy: 0.7202 - val_loss: 0.8788
Epoch 20/20
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 5ms/step - accuracy: 0.8263 - loss: 0.4943 - val_accuracy: 0.7191 - val_loss: 0.8725


<keras.src.callbacks.history.History at 0x781010057e10>

In [10]:
print(tf.__version__)

2.18.0


In [5]:
# Retrieve best model from saved
Teacher = load_model('teacher.h5')

# Evaluation with test set
Teacher.evaluate(x_test,y_test)



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.7230 - loss: 0.8183


[0.8287280201911926, 0.72079998254776]

# Understand temperature

In [7]:
# 假设 x_test[:1] 的 shape 与 model 输入相匹配
Teacher.predict(x_test[:1])


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 464ms/step


array([[5.3178053e-03, 3.5369173e-06, 8.7651545e-03, 9.1553682e-01,
        1.6027477e-04, 5.3755175e-02, 1.0989412e-03, 7.9479367e-05,
        1.5248989e-02, 3.3743374e-05]], dtype=float32)

In [9]:
Teacher.fit(
    x_train, y_train,
    epochs=1,
    batch_size=128
)


NotImplementedError: numpy() is only available when eager execution is enabled.

In [8]:
print("Teacher:", Teacher)
print("Teacher input:", Teacher.input)
Teacher.summary()  # 看能否正常输出结构


Teacher: <Sequential name=sequential, built=True>


AttributeError: The layer sequential has never been called and thus has no defined input.

In [None]:
Teacher_logits = Model(Teacher.input,Teacher.layers[-2].output)

logits_plot = []

class_names = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]


# Choose the item to visualize temperature
item_idx = 7

item_image = x_train[item_idx]

plt.imshow(item_image)

Temperatures = [1,5,10,20,35,50]


for Temperature in Temperatures:
  # Create temperature layer that produces logits with temperature
  T_layer = Lambda(lambda x:x/Temperature)(Teacher_logits.output)

  # Create a softmax layer
  Softmax_layer = Softmax()(T_layer)
  # Add the teacher T_layer to the whole model
  Teacher_logits_soften = Model(Teacher.input,Softmax_layer)

  # Append for plotting
  logits_plot.append(Teacher_logits_soften.predict(np.array([item_image])))

  plt.figure(figsize=(14, 6))
for i in range(len(Temperatures)):
  sns.lineplot(class_names,logits_plot[i][0],legend="full")
  plt.title('This is a '+ class_names[y_train[item_idx][0]])
  plt.legend(Temperatures,title="Temperatures")

AttributeError: The layer sequential has never been called and thus has no defined input.

# Create a teacher model that create softened output
As mentioned in **Hinton's paper**:  "When the distilled net had 300 or more units in each of its two hidden layers, all temperatures above gave fairly similar results. But when this was radically reduced to 30 units per layer, temperatures in the range 2.5 to 4 worked significantly better than high or lower temperatures."  
In this notebook, I'll use temperature **3.25**, feel free to change to the Temperature to any number of your interest.

In [None]:
Temperature = 3.25
T_layer = Lambda(lambda x:x/Temperature)(Teacher_logits.output)
Softmax_layer = Activation('softmax')(T_layer)
Teacher_soften = Model(Teacher.input,Softmax_layer)

In [None]:
# Predict and convert to sparse categorical matrix
y_train_new = Teacher_soften.predict(x_train)
y_test_new = Teacher_soften.predict(x_test)

y_train_new = np.c_[to_categorical(y_train),y_train_new]
y_test_new = np.c_[to_categorical(y_test),y_test_new]

# Create a student model that produces with and without soften output

The student model we'll in this notebook is a really shallow neural network with only 1 hidden layers with 64 units, followed by a 10 softmax unit for the output

In [None]:
Student = Sequential() #a Must define the input shape in the first layer of the neural network
Student.add(Flatten(input_shape=(32,32,3)))
Student.add(Dense(64, activation='relu'))
Student.add(Dense(10))
Student.summary()

In [None]:
student_logits = Student.layers[-1].output

# Compute softmax
probs = Activation("softmax")(student_logits)

# Compute softmax with softened logits
logits_T = Lambda(lambda x:x/Temperature)(student_logits)
probs_T = Activation("softmax")(logits_T)

CombinedLayers = concatenate([probs,probs_T])

StudentModel = Model(Student.input,CombinedLayers)

<center><img src="https://nervanasystems.github.io/distiller/imgs/knowledge_distillation.png" width=500></center>
<center>

$$ \text{Let } a_{t}  \text{ and } a_{s} \text{ be the logits (the inputs to the final softmax) of the teacher and student network, respectively, with the ground-truth label } y_{r} .\text{ We calculate the cross-entropy between the softmax} (a_{s},y_{r}) \text{ and } y_{r} \text{ as follow:}$$
$$ \mathcal{L}_{SL}=\mathcal{H}(\text{softmax}(a_{s},y_{r})) $$

$$ \text{In knowledge distillation (in all 3 papers), we tries to match the softened outputs of the student } y_{s} = \text{softmax}(a_{s}/\mathcal{T})   \text{ and teacher's softened outputs }  y_{t}=\text{softmax}(a_{t}/\mathcal{T}) \text{via a KL-divergence loss}$$
$$\mathcal{L}_{KD}=\mathcal{T}^2\text{KL}(y_{s},y_{t})$$
$$ \text{The student model will then be trained on a "combined" loss between } \mathcal{L}_{SL} \text{ and } \mathcal{L}_{KD} \text{ with } \lambda \text{ representing the trade off of 2 losses }$$
$$\mathcal{L}_{\text{student}} = \lambda\mathcal{L}_{SL} + (1-\lambda)\mathcal{L}_{KD}$$

In [None]:
def KD_loss(y_true,y_pred,lambd=0.5,T=10.0):
  y_true,y_true_KD = y_true[:,:NUM_CLASSES],y_true[:,NUM_CLASSES:]
  y_pred,y_pred_KD = y_pred[:,:NUM_CLASSES],y_pred[:,NUM_CLASSES:]
  # Classic cross-entropy (without temperature)
  CE_loss = logloss(y_true,y_pred)
  # KL-Divergence loss for softened output (with temperature)
  KL_loss = T**2*KLD_Loss(y_true_KD,y_pred_KD)

  return lambd*CE_loss + (1-lambd)*KL_loss

def accuracy(y_true,y_pred):
  return categorical_accuracy(y_true,y_pred)


In [None]:
StudentModel.compile(optimizer='adam',loss=lambda y_true,y_pred: KD_loss(y_true, y_pred,lambd=0.5,T=Temperature),metrics=[accuracy])

In [None]:
myCP = ModelCheckpoint(save_best_only=True,filepath='student.h5',monitor = 'val_accuracy')

StudentModel.fit(x_train,y_train_new,epochs=50,validation_split=0.15,batch_size=128,callbacks=[myCP])

In [None]:
StudentModel.load_weights('student.h5')
StudentModel.evaluate(x_train,y_train_new)


# Create a standalone student

In [None]:
AloneModel = Sequential() #a Must define the input shape in the first layer of the neural netAloneStudent = Sequential() #a Must define the input shape in the first layer of the neural network
AloneModel.add(Flatten(input_shape=(32,32,3)))
AloneModel.add(Dense(64, activation='relu'))
AloneModel.add(Dense(10,activation="softmax"))
AloneModel.summary()

In [None]:
AloneModel.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

myCP = ModelCheckpoint(_best_only=True,filepath='alone.h5',monitor = 'val_acc')

AloneModel.fit(x_train,y_train,epochs=50,validation_split=0.15,batch_size=128,callbacks=[myCP])


In [None]:
AloneModel = load_model("alone.h5")
AloneModel.evaluate(x_test,y_test)

# References
[Nervanasystem github's
](https://nervanasystems.github.io/distiller/knowledge_distillation.html)

[Hinton et. al. -
Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)

[Seyed-Iman Mirzadeh et. al. - Improved Knowledge Distillation via Teacher Assistant:Bridging the Gap Between Student and Teacher](https://arxiv.org/abs/1902.03393)