In [1]:
#from resnet164 import ResNet164
#from utils import load_mnist
import numpy as np
import keras
from keras.callbacks import ModelCheckpoint
from keras.models import Sequential, Model
from keras.layers import concatenate
from keras.layers.core import Dense, Activation, Dropout, Flatten, Lambda 
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from scipy.ndimage.interpolation import rotate, shift, zoom
from keras.constraints import max_norm
from keras.datasets import mnist
import numpy as np
import tensorflow as tf
import h5py
import matplotlib.pyplot as plt


Using TensorFlow backend.


In [2]:
def evaluate(prediction, true_label):
    pred_indices = np.argmax(prediction, 1)
    true_indices = np.argmax(true_label, 1)
    return np.mean(pred_indices == true_indices)

In [3]:
def SOFTMAX(s_):
    return np.exp(s_) / np.matmul(np.ones((1, s_.shape[0])), np.exp(s_)) 

In [4]:
def knowledge_distillation_loss(y_true, y_pred, alpha):

    # Extract the one-hot encoded values and the softs separately so that we can create two objective functions
    y_true, y_true_softs = y_true[: , :nb_classes], y_true[: , nb_classes:]
    
    y_pred, y_pred_softs = y_pred[: , :nb_classes], y_pred[: , nb_classes:]
    
    loss =(alpha*tf.keras.losses.categorical_crossentropy(y_true,y_pred) +
           tf.keras.losses.categorical_crossentropy(y_true_softs, y_pred_softs))
    return loss

In [5]:
def acc(y_true, y_pred):
    y_true = y_true[:, :nb_classes]
    y_pred = y_pred[:, :nb_classes]
    return tf.keras.metrics.categorical_accuracy(y_true, y_pred)

In [6]:
def rand_jitter(temp):
    if np.random.random() > .7:
        temp = rotate(temp, angle = np.random.randint(-25, 25), reshape=False)
    return temp

## Tests:

In [7]:
(X_train,y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(60000,28,28,1).astype('float32')
X_test = X_test.reshape(10000,28,28,1).astype('float32')

X_train /= 255
X_test /= 255

n_classes = 10
y_train = keras.utils.to_categorical(y_train, n_classes)
y_test = keras.utils.to_categorical(y_test, n_classes)

* Train a large NN with two hidden layers of 1200 ReLu hidden units on all training examples. The net has to be regularized using dropout and weight-constraint. Also, the images were jittered up to two pixels in any direction. 

In [8]:
#teacher model:
teacher = Sequential()
teacher.add(Flatten(input_shape=(28,28,1)))
teacher.add(Dense(1200, kernel_constraint=max_norm(4.), activation='relu'))
teacher.add(Dense(1200, kernel_constraint=max_norm(4.), activation='relu'))
teacher.add(Dropout(.5))
teacher.add(Dense(10))
teacher.add(Activation('softmax'))

teacher.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
# print(teacher.summary())

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [9]:
X_train_temp = np.copy(X_train) # Copy to not effect the originals
#Add noise 'jitter':
for j in range(0, X_train_temp.shape[0]):
    X_train_temp[j, :, :,0] = rand_jitter(X_train_temp[j,:,:,0])

teacher.fit(X_train_temp, y_train, batch_size=128, epochs=5, verbose=1, validation_data=(X_test,y_test))


Instructions for updating:
Use tf.cast instead.
Train on 60000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x124e34fd0>

* Train a smaller network with two hidden layers of 800 ReLu hidden units without regularization.

In [10]:
#smaller teacher model:
small_teacher = Sequential()
small_teacher.add(Flatten(input_shape=(28,28,1)))
small_teacher.add(Dense(800, activation='relu'))
small_teacher.add(Dense(800, activation='relu'))
small_teacher.add(Dense(10, activation='softmax'))

small_teacher.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
#print(teacher.summary())

In [11]:
X_train_temp = np.copy(X_train) # Copy to not effect the originals
#Add noise:
for j in range(0, X_train_temp.shape[0]):
    X_train_temp[j, :, :,0] = rand_jitter(X_train_temp[j,:,:,0])

small_teacher.fit(X_train_temp, y_train, batch_size=128, epochs=5, verbose=1, validation_data=(X_test,y_test))


Train on 60000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x13ee8cef0>

If we add regularization to this small network to make it match the soft targets produced by the large net at a temperature of 20. **In the paper this network achieves 74 test errors**.

In [12]:
#if we now use the small_teacher model as the student model:
# Raise the temperature of teacher model and gather the soft targets
# Set a tempature value
temp = 20
#Collect the logits from the previous layer output and store it in a different model
teacher_WO_Softmax = Model(teacher.input, teacher.get_layer('dense_3').output)

In [13]:
nb_classes = 10
#student model:
student_m = Sequential()
student_m.add(Flatten(input_shape=(28,28,1)))
student_m.add(Dense(800, activation='relu'))
student_m.add(Dense(800, activation='relu'))
student_m.add(Dense(10))
#student_m.add(Activation('softmax'))
student_m.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
#print(student_m.summary())

logits = student_m.layers[-1].output
probs = Activation('softmax')(logits)

logits_T = Lambda(lambda x: x / temp)(logits)
probs_T = Activation('softmax')(logits_T)

output = concatenate([probs, probs_T])

student_m = Model(student_m.input, output)

student_m.compile(optimizer='SGD',
                      loss=lambda y_true, y_pred: knowledge_distillation_loss(y_true, y_pred, 1),
                      metrics=[acc])

In [14]:
teacher_train_logits = teacher_WO_Softmax.predict(X_train)
teacher_test_logits = teacher_WO_Softmax.predict(X_test)

Y_train_soft = SOFTMAX(teacher_train_logits/temp)
Y_test_soft = SOFTMAX(teacher_test_logits/temp)

Y_train_new = np.concatenate([y_train, Y_train_soft], axis=1)
Y_test_new =  np.concatenate([y_test, Y_test_soft], axis =1)

In [15]:
student_m.fit(X_train, Y_train_new,
          batch_size=128,
          epochs=5,
          verbose=1,
          validation_data=(X_test, Y_test_new))

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x10319eba8>

In [16]:
temp = 4
nb_classes = 10
#student model:
student_m2 = Sequential()
student_m2.add(Flatten(input_shape=(28,28,1)))
student_m2.add(Dense(30, activation='relu'))
student_m2.add(Dense(30, activation='relu'))
student_m2.add(Dense(10))
#student_m.add(Activation('softmax'))
student_m2.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
#print(student_m2.summary())

logits = student_m2.layers[-1].output
probs = Activation('softmax')(logits)

logits_T = Lambda(lambda x: x / temp)(logits)
probs_T = Activation('softmax')(logits_T)

output = concatenate([probs, probs_T])

student_m2 = Model(student_m2.input, output)

student_m2.compile(optimizer='SGD',
                      loss=lambda y_true, y_pred: knowledge_distillation_loss(y_true, y_pred, 1),
                      metrics=[acc])

In [17]:
student_m2.fit(X_train, Y_train_new,
          batch_size=128,
          epochs=5,
          verbose=1,
          validation_data=(X_test, Y_test_new))

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x138430668>

In [18]:
nb_classes = 10
#student model:
student_m3 = Sequential()
student_m3.add(Flatten(input_shape=(28,28,1)))
student_m3.add(Dense(800, activation='relu'))
student_m3.add(Dense(800, activation='relu'))
student_m3.add(Dense(10))

student_m3.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
# print(student_m3.summary())

logits = student_m3.layers[-1].output
probs = Activation('softmax')(logits)

logits_T = Lambda(lambda x: x / temp)(logits)
probs_T = Activation('softmax')(logits_T)

output = concatenate([probs, probs_T])

student_m3 = Model(student_m3.input, output)

student_m3.compile(optimizer='SGD',
                      loss=lambda y_true, y_pred: knowledge_distillation_loss(y_true, y_pred, 1),
                      metrics=[acc])

In [19]:
threes_idx = np.where(y_train[:,2] == 1)[0]
non_threes_idx = np.where(y_train[:,2] == 0)[0]
threes_n_examples = len(threes_idx)

new_training_X = X_train[non_threes_idx]
teacher_train_logits = teacher_WO_Softmax.predict(new_training_X)
Y_train_soft = SOFTMAX(teacher_train_logits/temp)
Y_train_new = np.concatenate([y_train[non_threes_idx], Y_train_soft], axis=1)

In [20]:
student_m3.fit(new_training_X, Y_train_new,
          batch_size=128,
          epochs=5,
          verbose=1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x139cc2ef0>

In [21]:
def excluded_label_pred(dif_matrix, label_idx_list):
    indexes = []
    wrong_excluded_lable_pred = []

    for idx, i in enumerate(dif_matrix):
        if i != 0:
            indexes.append(idx)

    for i in indexes:
        if i in label_idx_list:
            wrong_excluded_lable_pred.append(i)  
    return wrong_excluded_lable_pred

In [23]:
threes_idx = np.where(y_train[:,2] == 1)[0]
non_threes_idx = np.where(y_train[:,2] == 0)[0]
threes_n_examples = len(threes_idx)

predicted = student_m3.predict(X_test)
dif = np.argmax(Y_test_new,axis=1) - np.argmax(predicted,axis=1)
test_acc = evaluate(predicted, y_test)
test_error_rate = np.count_nonzero(dif) / 10000
wrong_threes = excluded_label_pred(dif, threes_idx)
excluded_lable_miscls_among_label= len(wrong_threes)/threes_n_examples * 100
excluded_lable_miscls_among_miscls=len(wrong_threes)/(test_error_rate*10000) * 100

print(f'test prediction accuracy: {test_acc * 100}%')
print(f'test prediction error: {test_error_rate * 100}%')
print(f'misclass percentage among excluded class : {excluded_lable_miscls_among_label}%')
print(f'misclass percentage among misclassified labels : {excluded_lable_miscls_among_miscls}%')

test prediction accuracy: 83.56%
test prediction error: 16.439999999999998%
misclass percentage among excluded class : 2.4001342732460555%
misclass percentage among misclassified labels : 8.69829683698297%
