<a href="https://colab.research.google.com/github/DENGXUELIN/clash/blob/main/Distillation_Toy_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
# Imports
import tensorflow as tf

from tensorflow.keras import models
from tensorflow.keras import layers

tf.random.set_seed(666)

In [13]:
# Load the FashionMNIST dataset, scale the pixel values
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train/255.
X_test = X_test/255.

X_train.shape, X_test.shape, y_train.shape, y_test.shape

((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

In [15]:
# Change the pixel values to float32 and reshape input data
X_train = X_train.astype("float32").reshape(-1, 28, 28, 1)
X_test = X_test.astype("float32").reshape(-1, 28, 28, 1)

In [17]:
# Define utility function for building a basic shallow Convnet
def get_teacher_model():
    model = models.Sequential()
    model.add(layers.Conv2D(16, (5, 5), activation="relu",
        input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Conv2D(32, (5, 5), activation="relu"))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(0.2))
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation="relu"))
    model.add(layers.Dense(10))

    return model

In [18]:
# Define loss function and optimizer
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

In [19]:
# Prepare TF dataset
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(100).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)

# Train the teacher model
teacher_model = get_teacher_model()
teacher_model.compile(loss=loss_func, optimizer=optimizer, metrics=["accuracy"])
teacher_model.fit(train_ds,
                  validation_data=test_ds,
                  epochs=10)

Epoch 1/10
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 6ms/step - accuracy: 0.7137 - loss: 0.7896 - val_accuracy: 0.8519 - val_loss: 0.4113
Epoch 2/10
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - accuracy: 0.8596 - loss: 0.3870 - val_accuracy: 0.8752 - val_loss: 0.3490
Epoch 3/10
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - accuracy: 0.8807 - loss: 0.3295 - val_accuracy: 0.8801 - val_loss: 0.3261
Epoch 4/10
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - accuracy: 0.8893 - loss: 0.2980 - val_accuracy: 0.8852 - val_loss: 0.3111
Epoch 5/10
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - accuracy: 0.8978 - loss: 0.2782 - val_accuracy: 0.8931 - val_loss: 0.2940
Epoch 6/10
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - accuracy: 0.9031 - loss: 0.2585 - val_accuracy: 0.9007 - val_loss: 0.2716
Epoch 7/10
[1m938/938[0m 

<keras.src.callbacks.history.History at 0x7d1b3f38df40>

In [21]:
# Evaluate and serialize
print("Test accuracy: {:.2f}".format(teacher_model.evaluate(test_ds)[1]*100))
teacher_model.save_weights("teacher_model.weights.h5")

[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.9066 - loss: 0.2721
Test accuracy: 90.74


In [22]:
# Student model utility
def get_student_model():
    model = models.Sequential()
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    model.add(layers.Dense(48, activation="relu"))
    model.add(layers.Dense(10))

    return model

In [23]:
# Credits: https://github.com/google-research/simclr/blob/master/colabs/distillation_self_training.ipynb
def get_kd_loss(student_logits, teacher_logits, temperature=0.5):
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    kd_loss = tf.compat.v1.losses.softmax_cross_entropy(
        teacher_probs, student_logits / temperature, temperature**2)
    return kd_loss

In [24]:
# Model, optimizer
student_model = get_student_model()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Average the loss across the batch size within an epoch
train_loss = tf.keras.metrics.Mean(name="train_loss")
valid_loss = tf.keras.metrics.Mean(name="test_loss")

# Specify the performance metric
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

In [25]:
# Train utils
@tf.function
def model_train(images, labels, teacher_model,
                student_model, optimizer, temperature):
    teacher_logits = teacher_model(images)

    with tf.GradientTape() as tape:
        student_logits = student_model(images)
        loss = get_kd_loss(student_logits, teacher_logits, temperature)

    gradients = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

    train_loss(loss)
    train_acc(labels, tf.nn.softmax(student_logits))

In [26]:
# Validation utils
@tf.function
def model_validate(images, labels, teacher_model,
                   student_model, temperature):
    teacher_logits = teacher_model(images)

    student_logits = student_model(images)
    loss = get_kd_loss(student_logits, teacher_logits, temperature)

    valid_loss(loss)
    valid_acc(labels, tf.nn.softmax(student_logits))

In [30]:
# Tie everything together
def train_model(epochs, teacher_model, student_model, optimizer, temperature=0.5):
    for epoch in range(epochs):
        for (images, labels) in train_ds:
            model_train(images, labels, teacher_model, student_model, optimizer, temperature)

        for (images, labels) in test_ds:
            model_validate(images, labels, teacher_model, student_model, temperature)

        (loss, acc) = train_loss.result(), train_acc.result()
        (val_loss, val_acc) = valid_loss.result(), valid_acc.result()

        train_loss.reset_state(), train_acc.reset_state()
        valid_loss.reset_state(), valid_acc.reset_state()

        template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"
        print (template.format(epoch+1,
                            loss,
                            acc,
                            val_loss,
                            val_acc))


    return teacher_model, student_model

In [31]:
_, student_model = train_model(10, teacher_model, student_model, optimizer)

Epoch 1, loss: 0.100, acc: 0.838, val_loss: 0.101, val_acc: 0.837
Epoch 2, loss: 0.085, acc: 0.856, val_loss: 0.121, val_acc: 0.821
Epoch 3, loss: 0.084, acc: 0.858, val_loss: 0.095, val_acc: 0.850
Epoch 4, loss: 0.083, acc: 0.859, val_loss: 0.093, val_acc: 0.847
Epoch 5, loss: 0.082, acc: 0.861, val_loss: 0.100, val_acc: 0.837
Epoch 6, loss: 0.082, acc: 0.860, val_loss: 0.094, val_acc: 0.846
Epoch 7, loss: 0.081, acc: 0.862, val_loss: 0.095, val_acc: 0.844
Epoch 8, loss: 0.079, acc: 0.865, val_loss: 0.095, val_acc: 0.846
Epoch 9, loss: 0.078, acc: 0.867, val_loss: 0.093, val_acc: 0.846
Epoch 10, loss: 0.079, acc: 0.864, val_loss: 0.094, val_acc: 0.846


This can be further improved with longer training time and more careful hyperparameter tuning.

In [33]:
# Serialize
student_model.save_weights("student_model.weights.h5")

In [34]:
# Investigate the sizes
!ls -lh *.h5

-rw-r--r-- 1 root root 166K Jan  7 06:00 student_model.weights.h5
-rw-r--r-- 1 root root 975K Jan  7 05:56 teacher_model.weights.h5


Let's check the total number of trainable params.

In [35]:
teacher_model.summary()

In [36]:
student_model.summary()

Further size decrease is possible with TFLite.

In [37]:
# Credits: https://www.tensorflow.org/lite/performance/post_training_quant

def representative_data_gen():
    for input_value in tf.data.Dataset.from_tensor_slices(X_train).batch(1).take(100):
        yield [input_value]

def convert_to_tflite(model, tflite_file):
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_quant_model = converter.convert()

    open(tflite_file, 'wb').write(tflite_quant_model)

In [38]:
convert_to_tflite(teacher_model, "teacher.tflite")
convert_to_tflite(student_model, "student.tflite")

Saved artifact at '/tmp/tmpohbgb5h_'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_111')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  137555976540752: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976541520: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976541328: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976541712: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976541904: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976543056: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976543440: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137555976542096: TensorSpec(shape=(), dtype=tf.resource, name=None)




Saved artifact at '/tmp/tmp6lf22ahb'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_148')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  137556097652240: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137556097655120: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137556097655696: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137556097652816: TensorSpec(shape=(), dtype=tf.resource, name=None)




In [39]:
!ls -lh *.tflite

-rw-r--r-- 1 root root 41K Jan  7 06:00 student.tflite
-rw-r--r-- 1 root root 87K Jan  7 06:00 teacher.tflite
