# Knowledge Distillation

In [1]:
import numpy as np

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import Adam
from keras.layers import Lambda
from keras.datasets import mnist

Using TensorFlow backend.


In [2]:
# Training hyperparameters shared by the cells below.
batch_size = 128
num_classes = 10
epochs = 30

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every image into a 784-element row, cast to float32 and divide by
# 255. The number of rows is inferred (-1) rather than hard-coded, and new
# arrays are produced instead of mutating the loaded ones in place.
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

60000 train samples
10000 test samples


In [3]:
# Soft labels for the training set, read from disk.
# NOTE(review): presumably produced by a teacher model at the same temperature
# as `soft_temperature` below -- the file's origin is not shown here; confirm.
soft_target = np.load('soft_target.npy')

# Fail early with a clear message: the soft targets are later stacked next to
# the one-hot labels, so both arrays must have identical shapes.
assert soft_target.shape == y_train.shape, (
    'soft_target.npy has shape %s but the one-hot labels have shape %s'
    % (soft_target.shape, y_train.shape))
print(soft_target.shape, 'train targets')

(60000, 10) train targets


In [4]:
from keras.models import Model
from keras.layers import Input, concatenate

In [5]:
def tempered_softmax(logits, temperature, name):
    """Return softmax(logits / temperature) as the output of a layer called `name`.

    The temperature is captured as a function argument when the head is built,
    so rebinding a notebook variable afterwards cannot alter an existing head.
    """
    scaled = Lambda(lambda x: x / temperature)(logits)
    return Activation('softmax', name=name)(scaled)


# Student network: two 100-unit ReLU layers followed by raw logits
# (the final Dense layer has no activation).
input_tensor = Input(shape=(784,))
net = Dense(100, activation='relu')(input_tensor)
net = Dense(100, activation='relu')(net)
net = Dense(num_classes)(net)

# Head matched against the one-hot labels.
hard_temperature = 1
hard_net = tempered_softmax(net, hard_temperature, 'hard_pred')

# Head matched against the soft targets; same logits, higher temperature.
soft_temperature = 5
soft_net = tempered_softmax(net, soft_temperature, 'soft_pred')

# Single output tensor: hard probabilities first, soft probabilities second.
model = Model(input_tensor, concatenate([hard_net, soft_net]))

In [6]:
# Training target per sample: the one-hot label followed by the soft target,
# joined along the column axis.
new_train_target = np.concatenate([y_train, soft_target], axis=1)
new_train_target.shape

(60000, 20)

In [7]:
# Test target per sample: the one-hot label repeated twice along the column
# axis, so it has the same width as the training targets.
new_test_target = np.tile(y_test, (1, 2))
new_test_target.shape

(10000, 20)

In [8]:
import keras.backend as K
from keras.losses import categorical_crossentropy as xentropy

def kd_loss(y_true, y_pred):
    """Distillation loss: hard cross-entropy plus temperature-scaled soft cross-entropy.

    Both tensors carry two parts side by side along the last axis: the first
    `num_classes` columns are the hard part, the remaining columns the soft part.

    Args:
        y_true: targets, one-hot labels followed by soft targets.
        y_pred: model output, hard-head probabilities followed by soft-head
            probabilities.

    Returns:
        The per-sample sum of the hard loss and the soft loss, where the soft
        loss is multiplied by `soft_temperature` squared.
    """
    # Split on `num_classes` instead of a literal 10 so the loss stays in sync
    # with the width of the network's output heads.
    hard_true, soft_true = y_true[:, :num_classes], y_true[:, num_classes:]
    hard_pred, soft_pred = y_pred[:, :num_classes], y_pred[:, num_classes:]
    hard_loss = xentropy(hard_true, hard_pred)
    soft_loss = xentropy(soft_true, soft_pred) * soft_temperature ** 2
    return hard_loss + soft_loss

In [9]:
# Configure training with the custom distillation loss and a default Adam
# optimizer.
# NOTE(review): 'accuracy' is evaluated on the full concatenated output rather
# than on the hard half alone -- confirm this is the intended metric.
model.compile(
    optimizer=Adam(),
    loss=kd_loss,
    metrics=['accuracy'],
)

In [10]:
# Train on the combined hard+soft targets and validate on the test split
# after every epoch.
model.fit(
    x=x_train,
    y=new_train_target,
    validation_data=(x_test, new_test_target),
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<keras.callbacks.History at 0x7f8a74c134a8>

In [11]:
y_pred = model.predict(x_test, verbose=0)

# Only the hard half (first `num_classes` columns) of predictions and targets
# is compared. The arg-max comparison is vectorised over all rows instead of
# looping in Python, and the slice width follows `num_classes` rather than a
# literal 10.
predicted_class = y_pred[:, :num_classes].argmax(axis=1)
true_class = new_test_target[:, :num_classes].argmax(axis=1)
error = int((predicted_class != true_class).sum())
print('Number of error: ', error, '/', len(y_pred))

Number of error:  200 / 10000
