## Load the Teacher And distill

Here, we load the teacher model we trained and we distill it into a student.
For completeness, we also create a regular non distilled student here. We do this three times - once in the main file, once here, and lastly in Student3Distillation.ipynb file.

In [1]:
!pip install keras.utils



In [2]:
!pip install pyyaml h5py  # Required to save models in HDF5 format



In [3]:
!pip install --upgrade keras



In [4]:
import os
import keras
from keras import layers
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
import PIL
import tensorflow as tf
import pathlib

from keras.datasets import cifar10
from keras.models import Sequential
from keras import datasets, layers, models
from keras import regularizers
from keras.layers import Dense, Dropout, BatchNormalization
from sklearn.preprocessing import OneHotEncoder

In [5]:
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()



In [6]:
batch_size = 64
num_classes = 10
train_images = train_images.astype('float32')
test_images = test_images.astype('float32')
train_images = train_images / 255
test_images = test_images / 255
num_classes = 10
train_labels = np.array(train_labels)
test_labels = np.array(test_labels)
train_images = np.reshape(train_images, (-1, 32, 32, 3))
test_images = np.reshape(test_images, (-1, 32, 32, 3))

In [7]:
teacher = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Dropout(0.3),

        layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Dropout(0.5),

        layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Dropout(0.5),

        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),

        layers.Dropout(0.5),

        layers.Dense(num_classes),
    ],
    name="teacher",
)


In [8]:
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

In [9]:
teacher.load_weights("teacher.weights.h5")

  saveable.load_own_variables(weights_store.get(inner_path))


In [10]:
#we verify that its loaded well
teacher.evaluate(test_images, test_labels)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 6ms/step - loss: 0.4044 - sparse_categorical_accuracy: 0.8701


[0.4072171747684479, 0.868399977684021]

In [11]:
student = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(8, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(8, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),


        layers.Conv2D(16, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(16, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.BatchNormalization(),

        layers.Dense(num_classes),
    ],
    name="student",
)

student_scratch = keras.models.clone_model(student)

#student.summary()

In [12]:
class Distiller_new(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.student(x, training=True)

            loss = self.compute_loss(x, y, y_pred)

        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self.student(x, training=False)

        loss = self.compute_loss(x, y, y_pred)
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

In [13]:
# Initialize and compile distiller
d = Distiller_new(student= student, teacher=teacher)
d.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.2,
    temperature=5,
)

# Train and evaluate student trained from scratch.
checkpoint_path = "training_student_distilled1/cp.weights.h5"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)


# Distill teacher to student with validation data
dis = d.fit(train_images, train_labels , epochs=50, batch_size= 64)

Epoch 1/50


```
for metric in self.metrics:
    metric.update_state(y, y_pred)
```

  return self._compiled_metrics_update_state(


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 10ms/step - sparse_categorical_accuracy: 0.4345 - loss: -0.0080
Epoch 2/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 6ms/step - sparse_categorical_accuracy: 0.6520 - loss: -0.0033
Epoch 3/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - sparse_categorical_accuracy: 0.7140 - loss: -0.0037
Epoch 4/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - sparse_categorical_accuracy: 0.7382 - loss: -0.0068
Epoch 5/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 6ms/step - sparse_categorical_accuracy: 0.7578 - loss: -0.0092
Epoch 6/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - sparse_categorical_accuracy: 0.7741 - loss: -0.0106
Epoch 7/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - sparse_categorical_accuracy: 0.7804 - loss: -0.0127
Epoch 8/50
[1m782/782[0m [

In [14]:
d.save_weights("distilled_student_3.weights.h5")

In [15]:
d.evaluate(test_images,test_labels)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - sparse_categorical_accuracy: 0.7957 - loss: -0.0720


[-0.07121357321739197,
 -0.07121357321739197,
 0.7986999750137329,
 0.7986999750137329]

In [16]:
# Train student as done usually
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
checkpoint_path = "training_student_scratch1/cp.weights.h5"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)


scratch = student_scratch.fit(train_images, train_labels, batch_size=64, epochs=50)

Epoch 1/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 7ms/step - loss: 1.6751 - sparse_categorical_accuracy: 0.4089
Epoch 2/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - loss: 1.0882 - sparse_categorical_accuracy: 0.6161
Epoch 3/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 0.9098 - sparse_categorical_accuracy: 0.6822
Epoch 4/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.8090 - sparse_categorical_accuracy: 0.7146
Epoch 5/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 0.7376 - sparse_categorical_accuracy: 0.7390
Epoch 6/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 0.6869 - sparse_categorical_accuracy: 0.7581
Epoch 7/50
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 0.6332 - sparse_categorical_accuracy: 0.7783
Epoch 8/50
[1m782/782[0m

In [17]:
student_scratch.save_weights("nondistilled_student_3.weights.h5")

In [18]:
student_scratch.evaluate(test_images,test_labels)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - loss: 1.4322 - sparse_categorical_accuracy: 0.7186


[1.4486055374145508, 0.7156000137329102]