Import all the dependencies and load the data

In [1]:
#Import dependencies
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense
from tensorflow import nn
from tensorflow.compat.v1.losses import softmax_cross_entropy
from tensorflow.keras.metrics import *
from tensorflow.keras.optimizers import *
# Fix the global TF seed so TF-level randomness (initialisers, Dataset.shuffle)
# is repeatable between runs.
tf.random.set_seed(666)
# Load Data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Cast to float32 *before* scaling to [0, 1]. `load_data` returns integer arrays
# (uint8 per the Keras docs), and dividing those by a Python float would first
# materialise a full float64 copy (twice the memory) only to cast it back down.
# The resulting float32 values are identical either way.
# The reshape adds the trailing channel axis that Conv2D expects: (N, 28, 28, 1).
X_train = (X_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
X_test = (X_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


Utility to build and compile the teacher model

In [2]:
def build_teacher_model():
    """Build and compile the CNN teacher network.

    Layout: two Conv2D(5x5, ReLU) + 2x2 max-pool stages (16 then 32 filters),
    Dropout(0.2), Flatten, a 128-unit ReLU layer, and a 10-unit output layer.
    The output layer has no activation, so it emits raw logits; that is why
    the loss is configured with ``from_logits=True``.

    Returns:
        A compiled Sequential model taking (28, 28, 1) inputs, optimised with
        default-configured Adam on sparse integer labels and tracking accuracy.
    """
    model = models.Sequential([
        layers.Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(32, (5, 5), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["accuracy"],
    )
    return model

Train the teacher model

In [3]:
# Use a shuffle buffer as large as the training set. With a small buffer (e.g. 100),
# Dataset.shuffle only permutes within a sliding window of that many examples,
# so the data stays close to its on-disk order.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(len(X_train))
    .batch(64)
)
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)
teacher_model = build_teacher_model()
teacher_model.summary()
# Keep the History (per-epoch loss/accuracy) rather than echoing its repr.
teacher_history = teacher_model.fit(train_dataset, validation_data=test_dataset, epochs=20)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 24, 24, 16)        416       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 12, 12, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 8, 32)          12832     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 4, 32)          0         
_________________________________________________________________
dropout (Dropout)            (None, 4, 4, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               6

<tensorflow.python.keras.callbacks.History at 0x7f8279f01a90>

Evaluate the teacher and save its weights

In [4]:
# With metrics=["accuracy"], evaluate() returns [loss, accuracy];
# report the accuracy as a percentage.
teacher_scores = teacher_model.evaluate(test_dataset)
evaluate_1 = 100 * teacher_scores[1]
print("Test accuracy: {:.2f}".format(evaluate_1))
# Persist only the weights; the architecture is rebuilt by build_teacher_model().
teacher_model.save_weights("teacher_model.h5")

Test accuracy: 90.91


Student model and knowledge-distillation loss

In [5]:
def build_student_model():
    """Build the (uncompiled) student network.

    A small MLP: (28, 28, 1) input -> Flatten -> 48 ReLU units -> 10 raw
    logits. It is trained with a custom loop, so no compile step is done here.
    """
    return models.Sequential([
        Input(shape=(28, 28, 1)),
        layers.Flatten(),
        Dense(48, activation="relu"),
        Dense(10),
    ])

def get_kd_loss(teacher_log, student_log, temp=0.5):
    """Knowledge-distillation loss between teacher and student logits.

    The teacher's tempered softmax, softmax(teacher_log / temp), is the soft
    target for a softmax cross-entropy against the student's tempered logits
    (student_log / temp). The batch mean is multiplied by ``temp**2``. This is
    the scaling from Hinton et al. (2015), which keeps gradient magnitudes
    comparable across temperatures.

    Note: ``temp < 1`` (including the default 0.5) *sharpens* the teacher
    distribution instead of softening it; classic KD uses ``temp > 1``.

    Args:
        teacher_log: teacher logits, presumably (batch, num_classes); confirm.
        student_log: student logits, same shape as ``teacher_log``.
        temp: softmax temperature; must be positive.

    Returns:
        Scalar tensor ``temp**2 * mean_batch(CE(softmax(t / T), s / T))``.

    Raises:
        ValueError: if ``temp`` is not positive (division by zero / sign flip).
    """
    if temp <= 0:
        raise ValueError("temp must be positive, got {}".format(temp))
    # Soft targets are constants with respect to the student's parameters.
    teacher_probs = tf.stop_gradient(nn.softmax(teacher_log / temp))
    per_example = nn.softmax_cross_entropy_with_logits(
        labels=teacher_probs, logits=student_log / temp)
    return (temp ** 2) * tf.reduce_mean(per_example)

student_model = build_student_model()
# Student optimiser: Adam with a fixed learning rate of 0.01.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# Streaming metrics, updated per batch and read/reset per epoch by the training loop.
# Names follow one train_*/valid_* scheme.
train_loss = tf.keras.metrics.Mean(name="train_loss")
valid_loss = tf.keras.metrics.Mean(name="valid_loss")
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

Distillation training step (one batch)

In [6]:
def train_model(images, labels, teacher_model, student_model, optimizer, temp):
    """Run one distillation step on a single batch.

    Only the student's trainable variables are updated. The batch's KD loss
    and the student's accuracy against the hard ``labels`` are accumulated
    into the notebook-level ``train_loss`` / ``train_acc`` metrics.

    Args:
        images: input batch.
        labels: integer class labels for the batch (used for accuracy only).
        teacher_model: model providing soft targets; not updated.
        student_model: model being trained.
        optimizer: optimizer applied to the student's variables.
        temp: distillation temperature forwarded to ``get_kd_loss``.
    """
    # The teacher runs outside the tape (no gradients needed) and explicitly in
    # inference mode, so its Dropout layer is never active, whatever the global
    # Keras learning-phase default is.
    teacher_log = teacher_model(images, training=False)
    with tf.GradientTape() as tape:
        student_log = student_model(images, training=True)
        loss = get_kd_loss(teacher_log, student_log, temp)

    gradients = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

    train_loss(loss)
    train_acc(labels, nn.softmax(student_log))

Validation step (one batch)

In [7]:
def validate_model(images, labels, teacher_model, student_model, temp):
    """Score the student on one validation batch without updating any weights.

    The KD loss (against the teacher's soft targets) and the student's accuracy
    against the hard ``labels`` are accumulated into the notebook-level
    ``valid_loss`` / ``valid_acc`` metrics.
    """
    # Explicit inference mode for both models; the teacher contains Dropout,
    # which must be off during evaluation.
    teacher_log = teacher_model(images, training=False)
    student_log = student_model(images, training=False)
    loss = get_kd_loss(teacher_log, student_log, temp)
    valid_loss(loss)
    valid_acc(labels, nn.softmax(student_log))

Distillation training loop

In [8]:
def train_model_all(epochs, teacher_model, student_model, optimizer, temp=0.5,
                    train_ds=None, valid_ds=None):
    """Distil ``teacher_model`` into ``student_model`` for ``epochs`` epochs.

    Each epoch runs ``train_model`` on every training batch and
    ``validate_model`` on every validation batch. It then prints the epoch's
    mean KD loss and accuracy for both splits.

    Args:
        epochs: number of passes over the training data.
        teacher_model: model providing soft targets; only the student is updated.
        student_model: model being trained in place.
        optimizer: optimizer applied to the student's trainable variables.
        temp: distillation temperature forwarded to the per-batch steps.
        train_ds: iterable of ``(images, labels)`` batches; defaults to the
            notebook-level ``train_dataset``.
        valid_ds: iterable of ``(images, labels)`` batches; defaults to the
            notebook-level ``test_dataset``.

    Returns:
        ``(teacher_model, student_model)``, the same objects that were passed in.
    """
    if train_ds is None:
        train_ds = train_dataset
    if valid_ds is None:
        valid_ds = test_dataset

    epoch_metrics = (train_loss, train_acc, valid_loss, valid_acc)
    template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"

    for epoch in range(epochs):
        # Reset at the *start* of each epoch so values left behind by an
        # interrupted earlier run cannot leak into this epoch's averages.
        for metric in epoch_metrics:
            metric.reset_states()

        for images, labels in train_ds:
            train_model(images, labels, teacher_model, student_model, optimizer, temp)

        for images, labels in valid_ds:
            validate_model(images, labels, teacher_model, student_model, temp)

        print(template.format(epoch + 1,
                              train_loss.result(), train_acc.result(),
                              valid_loss.result(), valid_acc.result()))

    return teacher_model, student_model
_, student_model = train_model_all(20, teacher_model, student_model, optimizer)
student_model.save_weights("student_model.h5")
!ls -lh *.h5
teacher_model.summary()
student_model.summary()

Epoch 1, loss: 0.125, acc: 0.812, val_loss: 0.097, val_acc: 0.837
Epoch 2, loss: 0.099, acc: 0.846, val_loss: 0.100, val_acc: 0.836
Epoch 3, loss: 0.095, acc: 0.851, val_loss: 0.107, val_acc: 0.825
Epoch 4, loss: 0.091, acc: 0.857, val_loss: 0.092, val_acc: 0.852
Epoch 5, loss: 0.088, acc: 0.861, val_loss: 0.098, val_acc: 0.844
Epoch 6, loss: 0.086, acc: 0.863, val_loss: 0.096, val_acc: 0.845
Epoch 7, loss: 0.085, acc: 0.863, val_loss: 0.104, val_acc: 0.843
Epoch 8, loss: 0.085, acc: 0.864, val_loss: 0.106, val_acc: 0.840
Epoch 9, loss: 0.084, acc: 0.867, val_loss: 0.102, val_acc: 0.850
Epoch 10, loss: 0.083, acc: 0.868, val_loss: 0.101, val_acc: 0.853
Epoch 11, loss: 0.082, acc: 0.868, val_loss: 0.110, val_acc: 0.848
Epoch 12, loss: 0.081, acc: 0.869, val_loss: 0.111, val_acc: 0.840
Epoch 13, loss: 0.080, acc: 0.873, val_loss: 0.107, val_acc: 0.849
Epoch 14, loss: 0.081, acc: 0.871, val_loss: 0.110, val_acc: 0.842
Epoch 15, loss: 0.079, acc: 0.871, val_loss: 0.105, val_acc: 0.850
Epoc