Import dependencies and load the Fashion-MNIST data

In [1]:
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras import layers
# Fix TensorFlow's global seed so weight initialisation and shuffling repeat
tf.random.set_seed(666)

# Load the FashionMNIST dataset
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Cast to float32 *before* scaling so that no float64 copy of the images is
# created, divide the pixel values by 255, then append a channel axis so each
# sample has shape (28, 28, 1)
X_train = (X_train.astype("float32") / 255.).reshape(-1, 28, 28, 1)
X_test = (X_test.astype("float32") / 255.).reshape(-1, 28, 28, 1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


Teacher model: a shallow ConvNet

In [2]:
def get_teacher_model():
    """Build the untrained teacher network.

    The stack is two Conv2D(5x5, relu) + MaxPooling2D(2x2) stages, a
    Dropout(0.2), then Flatten -> Dense(128, relu) -> Dense(10).

    Returns:
        A ``Sequential`` model declared for inputs of shape (28, 28, 1). The
        last ``Dense`` layer has no activation, so the outputs are raw logits.
    """
    return models.Sequential([
        layers.Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(32, (5, 5), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),
    ])

Loss function and optimizer

In [3]:
# Adam with the library's default settings
optimizer = tf.keras.optimizers.Adam()
# from_logits=True: the loss receives un-normalised scores and integer labels
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

Prepare dataset

In [4]:
# A shuffle buffer smaller than the dataset only permutes samples within a
# sliding window; sizing it to the full training set gives a uniform shuffle
# that is redrawn every epoch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=len(X_train))
    .batch(64)
)
# Evaluation data keeps its original order; only batching is applied
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)

Train teacher model

In [5]:
# Build a fresh teacher and fit it on the labelled training batches; the test
# batches are passed as validation data for the per-epoch report
teacher_model = get_teacher_model()
teacher_model.compile(
    optimizer=optimizer,
    loss=loss_func,
    metrics=["accuracy"],
)
teacher_model.fit(train_ds, epochs=10, validation_data=test_ds)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f493af846d8>

Evaluation

In [6]:
# evaluate() yields the loss followed by the compiled metrics; entry 1 is
# reported as the accuracy, expressed in percent
teacher_scores = teacher_model.evaluate(test_ds)
print("Test accuracy: {:.2f}".format(teacher_scores[1]*100))
# Persist the trained teacher's weights (weights only, via save_weights)
teacher_model.save_weights("teacher_model.h5")

Test accuracy: 90.44


Student model

In [7]:
def get_student_model():
    """Build the untrained student network.

    The student is a single-hidden-layer perceptron: the (28, 28, 1) input is
    flattened and passed through Dense(48, relu) -> Dense(10).

    Returns:
        A ``Sequential`` model whose final layer has no activation, so the
        outputs are raw logits.
    """
    return models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(48, activation="relu"),
        layers.Dense(10),
    ])

def get_kd_loss(student_logits, teacher_logits, temperature=0.5):
    """Knowledge-distillation loss between student and teacher outputs.

    Both sets of logits are divided by ``temperature``. The teacher's scaled
    logits are turned into a probability distribution with softmax and used
    as the target of a softmax cross-entropy on the student's scaled logits.
    The loss is weighted by ``temperature**2``.

    Args:
        student_logits: Raw student outputs.
        teacher_logits: Raw teacher outputs for the same inputs.
        temperature: Divisor applied to both logits (default 0.5).

    Returns:
        The value returned by ``tf.compat.v1.losses.softmax_cross_entropy``.
    """
    soft_targets = tf.nn.softmax(teacher_logits / temperature)
    scaled_student_logits = student_logits / temperature
    return tf.compat.v1.losses.softmax_cross_entropy(
        onehot_labels=soft_targets,
        logits=scaled_student_logits,
        weights=temperature**2,
    )

Student optimizer and metrics

In [8]:
student_model = get_student_model()
# NOTE: this rebinds the name `optimizer` that the teacher cell used earlier;
# from here on it refers to the student's Adam instance (learning rate 0.01)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# Running means of the per-batch loss, accumulated until the metric is reset
train_loss = tf.keras.metrics.Mean(name="train_loss")
# Labelled "valid_loss" so the metric's name matches its variable and its
# companion accuracy metric "valid_acc"
valid_loss = tf.keras.metrics.Mean(name="valid_loss")
# Accuracy of predictions against integer class labels
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

Student training step

In [9]:
@tf.function
def model_train(images, labels, teacher_model, 
                student_model, optimizer, temperature):
    """Run one distillation step on a single batch.

    The teacher's logits are computed outside the gradient tape. The student
    is updated by minimising ``get_kd_loss`` against them. The ground-truth
    ``labels`` are used only for the accuracy metric, not for the loss.

    Args:
        images: Batch of input images fed to both models.
        labels: Integer class labels for ``images``.
        teacher_model: Model providing the target logits; it is not updated.
        student_model: Model whose trainable variables are updated.
        optimizer: Optimizer applying the student's gradients.
        temperature: Temperature forwarded to ``get_kd_loss``.

    Side effects:
        Updates the module-level ``train_loss`` and ``train_acc`` metrics.
    """
    # The teacher contains a Dropout layer; request inference mode explicitly
    # rather than relying on the implicit default of the call.
    teacher_logits = teacher_model(images, training=False)

    with tf.GradientTape() as tape:
        # The student is the model being optimised, so run it in training mode
        student_logits = student_model(images, training=True)
        loss = get_kd_loss(student_logits, teacher_logits, temperature)

    gradients = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

    train_loss(loss)
    train_acc(labels, tf.nn.softmax(student_logits))

Student validation step

In [10]:
@tf.function
def model_validate(images, labels, teacher_model, 
                   student_model, temperature):
    """Score the student on one batch without updating any weights.

    Args:
        images: Batch of input images fed to both models.
        labels: Integer class labels for ``images``.
        teacher_model: Model providing the target logits.
        student_model: Model being evaluated.
        temperature: Temperature forwarded to ``get_kd_loss``.

    Side effects:
        Updates the module-level ``valid_loss`` and ``valid_acc`` metrics.
    """
    # Validation must be deterministic: request inference mode explicitly for
    # both networks instead of relying on the implicit default of the call.
    teacher_logits = teacher_model(images, training=False)
    student_logits = student_model(images, training=False)

    loss = get_kd_loss(student_logits, teacher_logits, temperature)

    valid_loss(loss)
    valid_acc(labels, tf.nn.softmax(student_logits))

Training loop

In [11]:
def train_model(epochs, teacher_model, student_model, optimizer, temp=0.5):
    """Distil the teacher into the student for a number of epochs.

    Each epoch runs ``model_train`` over every batch of the module-level
    ``train_ds``, then ``model_validate`` over every batch of ``test_ds``.
    It then prints the accumulated metrics and resets them for the next epoch.

    Args:
        epochs: Number of passes over ``train_ds``.
        teacher_model: Model providing the target logits.
        student_model: Model being trained.
        optimizer: Optimizer forwarded to ``model_train``.
        temp: Distillation temperature forwarded to both step functions.

    Returns:
        The tuple ``(teacher_model, student_model)``.
    """
    report = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"

    for epoch_idx in range(epochs):
        for batch_images, batch_labels in train_ds:
            model_train(batch_images, batch_labels,
                        teacher_model, student_model, optimizer, temp)

        for batch_images, batch_labels in test_ds:
            model_validate(batch_images, batch_labels,
                           teacher_model, student_model, temp)

        # Read every metric before any of them is cleared, in the order in
        # which the report line lists them
        epoch_metrics = (train_loss, train_acc, valid_loss, valid_acc)
        snapshot = [metric.result() for metric in epoch_metrics]
        for metric in epoch_metrics:
            metric.reset_states()

        print(report.format(epoch_idx + 1, *snapshot))

    return teacher_model, student_model

In [12]:
_, student_model = train_model(10, teacher_model, student_model, optimizer)
student_model.save_weights("student_model.h5")
!ls -lh *.h5
teacher_model.summary()
student_model.summary()
def representative_data_gen():
    """Yield calibration samples for the TFLite converter.

    Takes the first 100 single-image batches of the module-level ``X_train``
    and yields each one wrapped in a one-element list.
    """
    calibration_ds = tf.data.Dataset.from_tensor_slices(X_train).batch(1).take(100)
    for sample in calibration_ds:
        yield [sample]
def convert_to_tflite(model, tflite_file):
    """Convert a Keras model to an int8-quantised TFLite flatbuffer on disk.

    The converter uses the default optimisations, calibrates with
    ``representative_data_gen``, is restricted to the int8 built-in op set,
    and uses int8 for both the model input and output.

    Args:
        model: Keras model to convert.
        tflite_file: Path of the file to write; it is opened in binary write
            mode.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_quant_model = converter.convert()
    # Use a context manager so the handle is flushed and closed before the
    # function returns; the bare open(...).write(...) form left it open.
    with open(tflite_file, 'wb') as tflite_handle:
        tflite_handle.write(tflite_quant_model)
convert_to_tflite(teacher_model, "teacher.tflite")
convert_to_tflite(student_model, "student.tflite")
!ls -lh *.tflite

Epoch 1, loss: 0.113, acc: 0.820, val_loss: 0.101, val_acc: 0.827
Epoch 2, loss: 0.089, acc: 0.849, val_loss: 0.088, val_acc: 0.843
Epoch 3, loss: 0.083, acc: 0.859, val_loss: 0.097, val_acc: 0.833
Epoch 4, loss: 0.082, acc: 0.859, val_loss: 0.096, val_acc: 0.837
Epoch 5, loss: 0.080, acc: 0.863, val_loss: 0.089, val_acc: 0.844
Epoch 6, loss: 0.077, acc: 0.867, val_loss: 0.094, val_acc: 0.841
Epoch 7, loss: 0.076, acc: 0.867, val_loss: 0.092, val_acc: 0.849
Epoch 8, loss: 0.076, acc: 0.869, val_loss: 0.088, val_acc: 0.848
Epoch 9, loss: 0.074, acc: 0.870, val_loss: 0.093, val_acc: 0.845
Epoch 10, loss: 0.073, acc: 0.873, val_loss: 0.083, val_acc: 0.857
-rw-r--r-- 1 root root 163K Dec 18 14:54 student_model.h5
-rw-r--r-- 1 root root 335K Dec 18 14:50 teacher_model.h5
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 24, 24, 16)        416       
_____

INFO:tensorflow:Assets written to: /tmp/tmpd9bt_0k6/assets


-rw-r--r-- 1 root root 40K Dec 18 14:54 student.tflite
-rw-r--r-- 1 root root 84K Dec 18 14:54 teacher.tflite
