In [1]:
# Author: Naveen Lalwani
# Script to distill knowledge from LeNet-300-100 trained on CIFAR-10 to student model

import tensorflow as tf
import numpy as np
import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.models import Model
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import RMSprop, SGD, Adam
from keras.utils import np_utils
import time

Using TensorFlow backend.


In [None]:
num_classes = 10
n_input = 3072
(x_train, y_train), (x_test, y_test) =  tf.keras.datasets.cifar10.load_data()

# Enabling One Hot Encoding
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

# Changing input image datatype to float
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

# Normalizaig data
x_train  /= 255
x_test /= 255

x_train = x_train.reshape([50000, 3072])
x_test = x_test.reshape([10000, 3072])

In [None]:
# Teacher Model: LeNet-300-100
def lenet_300_100_model():
    inputs = layers.Input(shape = (3072,))
    
    x = layers.Dense(300, activation='relu', name='FC1')(inputs)
    
    x = layers.Dense(100, activation='relu', name='FC2')(x)

    x = layers.Dense(10, name='logits')(x)
    preds = layers.Activation('softmax', name='Softmax')(x)

    model = Model(inputs=inputs, outputs=preds)
    model.summary()
    return model

#**Build Model LeNet-300-100**

In [4]:
model = lenet_300_100_model()

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 3072)              0         
_________________________________________________________________
FC1 (Dense)                  (None, 300)               921900    
_________________________________________________________________
FC2 (Dense)                  (None, 100)               30100     
_________________________________________________________________
logits (Dense)               (None, 10)                1010      
_________________________________________________________________
Softmax (Activation)         (None, 10)                0         
Total params: 953,010
Trainable params: 953,010
Non-trainable params: 0
_________________________________________________________________


In [5]:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])
model.fit(x_train, y_train, epochs=20, batch_size = 512) 

Instructions for updating:
Use tf.cast instead.
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f6011c85c18>

In [6]:
test_loss, test_acc = model.evaluate(x_test, y_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 1.4184400548934937
Test Accuracy: 0.5035


In [None]:
getSoftmaxKnowledge = Model(inputs=model.input, outputs=model.get_layer("logits").output)
model_logits = getSoftmaxKnowledge.predict(x_train)

In [None]:
# Defining function described by Geoffrey Hinton in his paper of Knowledge Distillation
def softmax_with_temperature(logits, temperature):
    logits = logits / temperature
    return (np.exp(logits) / np.sum(np.exp(logits)))

In [None]:
# Temperature is a hyperparameter
temperature = 2
softened_train_prob = softmax_with_temperature(model_logits, temperature)

In [None]:
# Model Definition for the Student Model
def build_small_model():
    inputs = layers.Input(shape = (3072,))
    
    x = layers.Dense(50, activation='relu', name='FC1')(inputs)
    
    x = layers.Dense(10, name='logits')(x)
    
    preds = layers.Activation('softmax', name='Softmax')(x)
  
    model = Model(inputs=inputs, outputs=preds)
    model.summary()
    return model

In [48]:
small_model = build_small_model()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_11 (InputLayer)        (None, 3072)              0         
_________________________________________________________________
FC1 (Dense)                  (None, 50)                153650    
_________________________________________________________________
logits (Dense)               (None, 10)                510       
_________________________________________________________________
Softmax (Activation)         (None, 10)                0         
Total params: 154,160
Trainable params: 154,160
Non-trainable params: 0
_________________________________________________________________


# **Distilling Knowledge in the student model**

In [49]:
# Optimization = Adam
# Loss = Cross Entropy loss
# Epochs = 50
# Trained with dark knowledge

small_model.compile(optimizer='adam', loss= 'categorical_crossentropy', metrics=['categorical_accuracy'])
small_model.fit(x_train, softened_train_prob, epochs=50, batch_size=128)

test_loss, test_acc = small_model.evaluate(x_test, y_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Test Loss: 1.5776135791778565
Test Accuracy: 0.4769


In [None]:
small_model.save('model_50_LeNet-300-100_Distilled_CIFAR-10.h5')
model.save('model_LeNet-300-100_CIFAR10.h5')