### Distilling the Knowledge in a Neural Network

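Hinton, Vinyals & Dean (2015), *Distilling the Knowledge in a Neural Network*: https://arxiv.org/pdf/1503.02531.pdf

This notebook trains a regular teacher, a regularized ("soft") teacher and a baseline student on CIFAR-100. It then trains students from each teacher's soft targets and compares their test accuracy and error counts.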

In [17]:
from importlib import reload
import models
reload(models)
from __future__ import print_function
import keras
from keras import utils
from keras.datasets import cifar100
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras import optimizers
import numpy as np
from keras.callbacks import ModelCheckpoint  

In [18]:
batch_size = 128
num_classes = 100
epochs = 60

# input image dimensions (CIFAR images are 32x32 RGB)
img_rows, img_cols = 32, 32

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = cifar100.load_data()
# Scale pixel intensities from [0, 255] to [0, 1].
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# One-hot encode the integer class labels.
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)
# Take the per-sample shape from the data instead of hard-coding channels_last:
# keras' cifar100.load_data follows K.image_data_format(), so this is
# (32, 32, 3) for channels_last and (3, 32, 32) for channels_first.
input_shape = x_train.shape[1:]
print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

x_train shape: (50000, 32, 32, 3)
y_train shape: (50000, 100)
50000 train samples
10000 test samples


In [19]:
# Per-experiment registries keyed by experiment name ('teacher', 'student', ...):
# compiled models, fit() histories, evaluate() scores and raw predictions.
model, hist, score, preds = {}, {}, {}, {}

In [20]:
reload(models)  # pick up edits to models.py without a kernel restart

# Regular teacher, trained on hard one-hot labels with cross-entropy.
model['teacher'] = models.TeacherModel_CIFAR(input_shape, num_classes)
model['teacher'].compile(
    optimizer='Adam',
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
# model['teacher'].summary()  # uncomment to print the architecture

In [None]:
import os

# Every training cell writes its best weights under this directory; h5py does not
# create missing parent directories, so make sure it exists on a fresh checkout.
os.makedirs('saved_models_cifar100', exist_ok=True)

# Train the regular teacher on hard labels, keeping the weights with the best val_loss
# (ModelCheckpoint's default monitor).
# NOTE(review): the test split doubles as the validation split, so checkpoint selection
# sees the test set and the "best" test accuracy is optimistically biased.
checkpointer = ModelCheckpoint(filepath='saved_models_cifar100/weights.best.teacher.hdf5',
                               verbose=0, save_best_only=True)

hist['teacher'] = model['teacher'].fit(x_train, y_train, batch_size=batch_size,
                                       epochs=epochs, verbose=1,
                                       validation_data=(x_test, y_test),
                                       callbacks=[checkpointer])
# These scores are for the final-epoch weights, not the best checkpoint.
score['teacher'] = model['teacher'].evaluate(x_test, y_test, verbose=0)
print('Test loss:', score['teacher'][0])
print('Test accuracy:', score['teacher'][1])

Train on 50000 samples, validate on 10000 samples
Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60

In [None]:
# Restore the best checkpoint written during training and report its test metrics.
model['teacher'].load_weights('saved_models_cifar100/weights.best.teacher.hdf5')
score['teacher'] = model['teacher'].evaluate(x_test, y_test, verbose=0)
# np.int was removed in NumPy 1.24; round() also stops floating-point error
# (e.g. 1234.9999...) from truncating the error count down by one.
n_errors = int(round((1 - score['teacher'][-1]) * len(y_test)))
print('Test loss:', score['teacher'][0])
print('Test accuracy:', score['teacher'][-1])
print('Test errors:', n_errors)

In [None]:
reload(models)  # pick up edits to models.py without a kernel restart

# "Soft" teacher. The l1 / l2 / b arguments are defined in models.py
# (presumably regularization settings; TODO confirm there).
model['soft_teacher'] = models.SoftTeacherModel_CIFAR(
    input_shape, num_classes, l1=0.1, l2=0.07, b=1)
model['soft_teacher'].compile(
    optimizer='Adam',
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)

In [None]:
# Train the soft teacher on hard labels, checkpointing the weights with the best val_loss.
checkpointer = ModelCheckpoint(
    filepath='saved_models_cifar100/weights.best.soft_teacher.hdf5',
    save_best_only=True,
    verbose=0,
)

hist['soft_teacher'] = model['soft_teacher'].fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=40,
    verbose=1,
    validation_data=(x_test, y_test),
    callbacks=[checkpointer],
)
# Scores of the final-epoch weights.
score['soft_teacher'] = model['soft_teacher'].evaluate(x_test, y_test, verbose=0)
print('Test loss:', score['soft_teacher'][0])
print('Test accuracy:', score['soft_teacher'][1])

In [None]:
# Restore the best soft-teacher checkpoint and report its test metrics.
model['soft_teacher'].load_weights('saved_models_cifar100/weights.best.soft_teacher.hdf5')
score['soft_teacher'] = model['soft_teacher'].evaluate(x_test, y_test, verbose=0)
# np.int was removed in NumPy 1.24; round() avoids truncating e.g. 1234.9999... to 1234.
n_errors = int(round((1 - score['soft_teacher'][-1]) * len(y_test)))
print('Test loss:', score['soft_teacher'][0])
print('Test accuracy:', score['soft_teacher'][-1])
print('Test errors:', n_errors)

In [None]:
reload(models)  # pick up edits to models.py without a kernel restart

# Baseline student trained directly on hard labels (no distillation), for comparison.
model['student'] = models.StudentModel(input_shape, num_classes)
model['student'].compile(
    optimizer='Adam',
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)

In [None]:
# Train the baseline student on hard labels, checkpointing the best val_loss weights.
checkpointer = ModelCheckpoint(
    filepath='saved_models_cifar100/weights.best.student.hdf5',
    save_best_only=True,
    verbose=0,
)

hist['student'] = model['student'].fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=40,
    verbose=1,
    validation_data=(x_test, y_test),
    callbacks=[checkpointer],
)
# Scores of the final-epoch weights.
score['student'] = model['student'].evaluate(x_test, y_test, verbose=0)
print('Test loss:', score['student'][0])
print('Test accuracy:', score['student'][1])

In [None]:
# Restore the best baseline-student checkpoint and report its test metrics.
model['student'].load_weights('saved_models_cifar100/weights.best.student.hdf5')
score['student'] = model['student'].evaluate(x_test, y_test, verbose=0)
# np.int was removed in NumPy 1.24; round() avoids truncating e.g. 1234.9999... to 1234.
n_errors = int(round((1 - score['student'][-1]) * len(y_test)))
print('Test loss:', score['student'][0])
print('Test accuracy:', score['student'][-1])
print('Test errors:', n_errors)

# Knowledge Distillation

In [None]:
# Distillation targets:
#   t_*  - regular teacher passed through T_model(T) (presumably a temperature-T softmax head)
#   st_* - soft teacher's ordinary predictions
kd_gt = dict()
T = 5
# Build the temperature head once and reuse it for both splits.
teacher_T = model['teacher'].T_model(T)
kd_gt['t_train'] = teacher_T.predict(x_train, verbose=1, batch_size=batch_size)
kd_gt['t_test'] = teacher_T.predict(x_test, verbose=1, batch_size=batch_size)
kd_gt['st_train'] = model['soft_teacher'].predict(x_train, verbose=1, batch_size=batch_size)
kd_gt['st_test'] = model['soft_teacher'].predict(x_test, verbose=1, batch_size=batch_size)

# Sanity check: mean per-sample L2 norm of each target set. Assuming the rows are
# probability vectors, a smaller norm means a flatter (softer) distribution.
np.linalg.norm(kd_gt['t_train'], axis=-1).mean(), np.linalg.norm(kd_gt['st_train'], axis=-1).mean()

In [None]:
# import numpy as np
# from keras.activations import softmax

# def softmax_with_temp(x):
#     Temp = 1.0
#     e_x = np.exp((x - x.max(axis=1, keepdims=True))/Temp)
#     out = e_x / e_x.sum(axis=1, keepdims=True)
#     return out

# def soft_with_T(T=1):
#     def swt(x):
#         return softmax(x/T)
#     return swt

In [None]:
# KNOWLEDGE DISTILLATION WITH REGULAR TEACHER (TEMPERATURE SOFTMAX)
reload(models)
# Two-output student: output 0 is fit to the teacher's temperature-T soft targets and
# output 1 to the hard one-hot labels (matching the target order given to fit()).
# Weights [1, 1/T**2] give the soft loss T**2 times the hard loss's weight, i.e. the
# gradient-scale correction described by Hinton et al. (2015).
model['student_'] = models.StudentModel(input_shape, num_classes, T=T, in_class=True)
model['student_'].compile(
    optimizer='Adam',
    loss=['categorical_crossentropy'] * 2,
    loss_weights=[1., 1. / (T ** 2)],
    metrics=['acc'],
)

In [None]:
# Distill: soft targets from the temperature-T teacher plus hard labels, with the best
# val_loss weights checkpointed.
checkpointer = ModelCheckpoint(
    filepath='saved_models_cifar100/weights.best.student_.hdf5',
    save_best_only=True,
    verbose=0,
)

hist['student_'] = model['student_'].fit(
    x_train, [kd_gt['t_train'], y_train],
    batch_size=batch_size,
    epochs=200,
    verbose=1,
    validation_data=(x_test, [kd_gt['t_test'], y_test]),
    callbacks=[checkpointer],
)
# Scores of the final-epoch weights.
score['student_'] = model['student_'].evaluate(x_test, [kd_gt['t_test'], y_test], verbose=0)
print('Test loss:', score['student_'][0])
print('Test accuracy:', score['student_'][-1])

In [None]:
# Restore the best distilled-student checkpoint and report its test metrics.
model['student_'].load_weights('saved_models_cifar100/weights.best.student_.hdf5')
score['student_'] = model['student_'].evaluate(x_test, [kd_gt['t_test'], y_test], verbose=0)
# For this two-output model the last score entry is presumably the hard-label output's
# accuracy (Keras lists per-output metrics after the losses); confirm via metrics_names.
# np.int was removed in NumPy 1.24; round() avoids truncating e.g. 1234.9999... to 1234.
n_errors = int(round((1 - score['student_'][-1]) * len(y_test)))
print('Test loss:', score['student_'][0])
print('Test accuracy:', score['student_'][-1])
print('Test errors:', n_errors)

In [None]:
# SOFT TEACHER IN CLASS
reload(models)
from keras import callbacks

# Two-output student (T=1, l2=0, b=0; semantics defined in models.StudentModel).
# Output 0 is fit to the soft teacher's predictions and output 1 to the hard labels,
# with the soft-target loss weighted twice as heavily.
model['student_st'] = models.StudentModel(input_shape, num_classes, T=1, in_class=True, l2=0, b=0)
model['student_st'].compile(
    optimizer='Adam',
    loss=['categorical_crossentropy', 'categorical_crossentropy'],
    loss_weights=[2, 1.],
    metrics=['acc'],
)

In [None]:
# Distill from the soft teacher: output 0 targets its predictions, output 1 the hard
# labels. No learning-rate schedule is used here; Adam runs at its default rate.
checkpointer = ModelCheckpoint(filepath='saved_models_cifar100/weights.best.student_st.hdf5',
                               verbose=0, save_best_only=True)

hist['student_st'] = model['student_st'].fit(x_train, [kd_gt['st_train'], y_train],
                                             batch_size=batch_size,
                                             epochs=50,
                                             verbose=1,
                                             validation_data=(x_test, [kd_gt['st_test'], y_test]),
                                             callbacks=[checkpointer],
                                             )
# Scores of the final-epoch weights.
score['student_st'] = model['student_st'].evaluate(x_test, [kd_gt['st_test'], y_test], verbose=0)
print('Test loss:', score['student_st'][0])
print('Test accuracy:', score['student_st'][-1])

In [None]:
# Restore the best soft-teacher-distilled student and report its test metrics.
model['student_st'].load_weights('saved_models_cifar100/weights.best.student_st.hdf5')
score['student_st'] = model['student_st'].evaluate(x_test, [kd_gt['st_test'], y_test], verbose=0)
# Last score entry is presumably the hard-label output's accuracy; confirm via metrics_names.
# np.int was removed in NumPy 1.24; round() avoids truncating e.g. 1234.9999... to 1234.
n_errors = int(round((1 - score['student_st'][-1]) * len(y_test)))
print('Test loss:', score['student_st'][0])
print('Test accuracy:', score['student_st'][-1])
print('Test errors:', n_errors)

In [None]:
# x_small_train = x_train[np.argmax(y_train, axis=-1) != 3]
# y_small_train = y_train[np.argmax(y_train, axis=-1) != 3]
# x_small_train.shape

In [None]:
reload(models)
from keras import callbacks

# The l2 weight is held in a backend variable so a callback can anneal it during
# training without recompiling; it starts at base_l2 and is scaled by powers of l2_decay.
base_l2, l2_decay = 0.7, 0.99
l2_weight = K.variable(base_l2)

def changeAlpha(epoch, logs=None):
    """Exponentially decay the l2 regularization weight between epochs.

    Used as an ``on_epoch_end`` LambdaCallback, so the value set here takes effect
    in the *next* epoch (``epoch + 1``). Setting ``base_l2 * l2_decay**(epoch + 1)``
    makes epoch e train with ``base_l2 * l2_decay**e``, in step with the
    ``base_lr * decay**epoch`` learning-rate schedule applied at epoch start.
    (Using ``epoch`` here re-applied ``base_l2`` after epoch 0, lagging one epoch.)

    Args:
        epoch: zero-based index of the epoch that just finished.
        logs: metrics dict supplied by Keras; unused.
    """
    K.set_value(l2_weight, base_l2 * l2_decay ** (epoch + 1))

l2Changer = callbacks.LambdaCallback(on_epoch_end=changeAlpha)

# Learning-rate schedule parameters (read by schedule()) and the optimizer.
base_lr, decay = 2e-3, 0.99
optim = keras.optimizers.Adam(lr=base_lr)

# Student trained on hard labels only, regularized with the annealed l2 weight.
model['student_reg'] = models.SoftStudentModel(input_shape, num_classes, l1=0.1, l2=l2_weight, b=1)
model['student_reg'].compile(
    optimizer=optim,
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)

In [None]:
def schedule(epoch):
    """Return the learning rate for ``epoch``: ``base_lr`` decayed by ``decay`` per epoch."""
    lr = base_lr * decay ** epoch
    return lr

ls = callbacks.LearningRateScheduler(schedule)
checkpointer = ModelCheckpoint(
    filepath='saved_models_cifar100/weights.best.student_reg.hdf5',
    save_best_only=True,
    verbose=0,
)

# Hard-label training with a decaying learning rate (ls), best-val_loss checkpointing,
# and a decaying l2 weight (l2Changer).
hist['student_reg'] = model['student_reg'].fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=100,
    verbose=1,
    validation_data=(x_test, y_test),
    callbacks=[ls, checkpointer, l2Changer],
)
# Scores of the final-epoch weights.
score['student_reg'] = model['student_reg'].evaluate(x_test, y_test, verbose=0)
print('Test loss:', score['student_reg'][0])
print('Test accuracy:', score['student_reg'][-1])

In [None]:
# Restore the best regularized-student checkpoint and report its test metrics.
model['student_reg'].load_weights('saved_models_cifar100/weights.best.student_reg.hdf5')
score['student_reg'] = model['student_reg'].evaluate(x_test, y_test, verbose=0)
# np.int was removed in NumPy 1.24; round() avoids truncating e.g. 1234.9999... to 1234.
n_errors = int(round((1 - score['student_reg'][-1]) * len(y_test)))
print('Test loss:', score['student_reg'][0])
print('Test accuracy:', score['student_reg'][-1])
print('Test errors:', n_errors)

# Analysis

In [None]:
# Training-set predictions for the analysis cells:
#   teacher_no_T - regular teacher's ordinary output
#   teacher      - regular teacher through T_model(T) (presumably temperature-T softmax)
#   soft_teacher - soft teacher's ordinary output
T = 5
preds['teacher_no_T'] = model['teacher'].predict(x_train, verbose=1, batch_size=batch_size)
preds['teacher'] = model['teacher'].T_model(T).predict(x_train, verbose=1, batch_size=batch_size)
preds['soft_teacher'] = model['soft_teacher'].predict(x_train, verbose=1, batch_size=batch_size)

# Mean per-sample L2 norm of the two soft-target sets; assuming rows are probability
# vectors, a smaller norm means a flatter distribution.
np.linalg.norm(preds['teacher'], axis=-1).mean(), np.linalg.norm(preds['soft_teacher'], axis=-1).mean()

In [None]:
## Plot sorted softmax probability profiles of 50 random training samples,
## one figure per model output.

import matplotlib.pyplot as plt

n_samples = 50
for key in ('teacher_no_T', 'teacher', 'soft_teacher'):
    probs = preds[key]
    ind = np.random.choice(len(probs), n_samples)
    fig, ax = plt.subplots()
    # Sort only the sampled rows (ascending along the class axis) instead of the
    # whole prediction matrix; the plotted curves are identical.
    ax.plot(np.sort(probs[ind], axis=-1).T)
    ax.set_title(key)
    ax.set_xlabel('Class rank (ascending probability)')
    ax.set_ylabel('Probability')
    plt.show()

In [None]:
from collections import Counter

# How often each (most likely, second most likely) class pair occurs in the soft
# teacher's predictions. argsort is ascending, so columns -1 and -2 are the top-2 classes.
pairs = list(zip(*np.argsort(preds['soft_teacher'])[:, [-1, -2]].T))
counts = Counter(pairs)
counts.most_common(20)

In [None]:
import os

# Show one random training image next to the per-class probabilities from the
# regularized (soft) teacher and from the regular teacher at temperature T.
i = np.random.randint(len(x_train))
#i = 54270
fig, ax = plt.subplots(1, 3, figsize=(10, 2.5),
                       gridspec_kw={'width_ratios': [1.6, 2, 2], 'wspace': 0.3})
fig.subplots_adjust(bottom=0.2)
ax[0].imshow(x_train[i])
ax[0].axis('off')
ax[0].set_title('Input')

panels = [(ax[1], preds['soft_teacher'][i], 'Regularized Network'),
          (ax[2], preds['teacher'][i], 'Regular Network (T={})'.format(T))]
for panel_ax, probs, title in panels:
    n_cls = len(probs)
    # One bar per class. The x-axis is derived from the prediction length: a fixed
    # 0..9 axis only fits 10 classes and fails to broadcast against 100 CIFAR-100 bars.
    panel_ax.bar(np.arange(n_cls), probs)
    # Label roughly ten ticks so the axis stays readable with many classes.
    panel_ax.set_xticks(np.arange(0, n_cls, step=max(1, n_cls // 10)))
    panel_ax.set_ylim(top=1)
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)
    panel_ax.set_xlabel('Classes')
    panel_ax.set_ylabel('Probabilities')
    panel_ax.set_title(title)

# savefig does not create missing directories.
os.makedirs('figures_cifar', exist_ok=True)
plt.savefig('figures_cifar/cifar_{}.png'.format(i))
plt.show()

In [None]:
from collections import Counter

# Frequency of (least likely, second least likely) class pairs in the soft teacher's
# training-set predictions: np.argsort is ascending, so x[0] and x[1] are the two
# lowest-probability classes.
# NOTE(review): the original referenced an undefined `preds_st` (NameError on a fresh
# kernel); it presumably meant the soft-teacher predictions used here.
pairs = [(x[0], x[1]) for x in np.argsort(preds['soft_teacher'])]
counts = Counter(pairs)
counts.most_common()

In [None]:
import tensorflow as tf

# Scratch check of tf.sort: sorts each row of a 2-D input along the last axis.
a = [
    [1, 10, 7, 9, 3, 66],
    [6, 4, 3, 2, 100, 0],
]
b = tf.sort(a, axis=-1, direction='ASCENDING', name=None)
c = tf.keras.backend.eval(b)
c

In [None]:
# Print the baseline student's layer-by-layer architecture and parameter counts.
model['student'].summary()