Here we combine the 3 distilled models to get a performance comaprable to the teacher.

In [1]:
!pip install keras.utils

Collecting keras.utils
  Downloading keras-utils-1.0.13.tar.gz (2.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: keras.utils
  Building wheel for keras.utils (setup.py) ... [?25l[?25hdone
  Created wheel for keras.utils: filename=keras_utils-1.0.13-py3-none-any.whl size=2631 sha256=84798b58706b021b03c8b6c027fd99a1dd85e2d4c7dc32541d550fbb6f7d5a90
  Stored in directory: /root/.cache/pip/wheels/5c/c0/b3/0c332de4fd71f3733ea6d61697464b7ae4b2b5ff0300e6ca7a
Successfully built keras.utils
Installing collected packages: keras.utils
Successfully installed keras.utils-1.0.13


In [2]:
!pip install pyyaml h5py  # Required to save models in HDF5 format



In [3]:
!pip install --upgrade keras

Collecting keras
  Downloading keras-3.3.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
Collecting namex (from keras)
  Downloading namex-0.0.8-py3-none-any.whl (5.8 kB)
Collecting optree (from keras)
  Downloading optree-0.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: namex, optree, keras
  Attempting uninstall: keras
    Found existing installation: keras 2.15.0
    Uninstalling keras-2.15.0:
      Successfully uninstalled keras-2.15.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.[

In [4]:
import os
import keras
from keras import layers
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
import PIL
import tensorflow as tf
import pathlib

from keras.datasets import cifar10
from keras.models import Sequential
from keras import datasets, layers, models
from keras import regularizers
from keras.layers import Dense, Dropout, BatchNormalization
from sklearn.preprocessing import OneHotEncoder

In [5]:
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()



Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step


In [6]:
batch_size = 64
num_classes = 10
train_images = train_images.astype('float32')
test_images = test_images.astype('float32')
train_images = train_images / 255
test_images = test_images / 255
num_classes = 10
train_labels = np.array(train_labels)
test_labels = np.array(test_labels)
train_images = np.reshape(train_images, (-1, 32, 32, 3))
test_images = np.reshape(test_images, (-1, 32, 32, 3))

In [18]:
teacher = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Dropout(0.3),

        layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Dropout(0.5),

        layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Dropout(0.5),

        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),

        layers.Dropout(0.5),

        layers.Dense(num_classes),
    ],
    name="teacher",
)


class Distiller_new(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.2,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights

            metrics: Keras metrics for evaluation

            student_loss_fn: Loss function of difference between student
                predictions and ground-truth

            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions

            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn

            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.student(x, training=True)

            loss = self.compute_loss(x, y, y_pred)

        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self.student(x, training=False)

        loss = self.compute_loss(x, y, y_pred)
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

student = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(8, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(8, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(16, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(16, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),

        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.BatchNormalization(),

        layers.Dense(num_classes),
    ],
    name="student",
)


student_scratch = keras.models.clone_model(student)
student_A = keras.models.clone_model(student)
student_B = keras.models.clone_model(student)
student_C = keras.models.clone_model(student)


teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)




In [19]:
A = Distiller_new(student= student_A, teacher=teacher)

A.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.2,
    temperature=5,
)


B = Distiller_new(student= student_B, teacher=teacher)

B.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.2,
    temperature=5,
)

C = Distiller_new(student= student_C, teacher=teacher)

C.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.2,
    temperature=5,
)

teacher.load_weights("teacher.weights.h5")
A.load_weights("distilled_student_1.weights.h5")
B.load_weights("distilled_student_2.weights.h5")
C.load_weights("distilled_student_3.weights.h5")


  saveable.load_own_variables(weights_store.get(inner_path))


In [14]:
#we verify that its loaded well
teacher.evaluate(test_images, test_labels)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.4044 - sparse_categorical_accuracy: 0.8701


[0.4072171747684479, 0.868399977684021]

In [20]:
A.evaluate(test_images, test_labels)


```
for metric in self.metrics:
    metric.update_state(y, y_pred)
```

  return self._compiled_metrics_update_state(


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - sparse_categorical_accuracy: 0.7892 - loss: -0.0818


[-0.07023735344409943,
 -0.07023735344409943,
 0.7854999899864197,
 0.7854999899864197]

In [21]:
B.evaluate(test_images, test_labels)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - sparse_categorical_accuracy: 0.7911 - loss: 0.0411


[0.045466408133506775,
 0.045466408133506775,
 0.7900000214576721,
 0.7900000214576721]

In [22]:
C.evaluate(test_images, test_labels)


```
for metric in self.metrics:
    metric.update_state(y, y_pred)
```

  return self._compiled_metrics_update_state(


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - sparse_categorical_accuracy: 0.7957 - loss: -0.0720


[-0.07121355086565018,
 -0.07121355086565018,
 0.7986999750137329,
 0.7986999750137329]

In [60]:

# To get class labels from predictions
#predicted_labels = tf.argmax(predictions, axis=1)
logitsA = A.student.predict(test_images)
logitsB = B.student.predict(test_images)
logitsC = C.student.predict(test_images)



[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step


In [61]:
average_logits = (logitsA + logitsB + logitsC) / 3.0

average_predicted_labels = tf.argmax(average_logits, axis=1)
predicted_A = np.array(tf.argmax(logitsA, axis=1))
predicted_B = np.array(tf.argmax(logitsB, axis=1))
predicted_C = np.array(tf.argmax(logitsC, axis=1))



In [71]:
test_labels = np.array([test_labels[i][0] for i in range(len(test_labels))])
average_predicted_labels = np.array(average_predicted_labels)

In [72]:
test_labels[:10]

array([3, 8, 8, 0, 6, 6, 1, 6, 3, 1], dtype=uint8)

In [73]:
average_predicted_labels[:10]

array([3, 8, 8, 0, 6, 6, 1, 6, 3, 1])

In [74]:
def find_accuracy(y_true, y_pred):
    accuracy = np.sum(y_true == y_pred) / len(y_true)
    return accuracy

In [75]:
find_accuracy(test_labels, predicted_A)


0.7855

In [76]:
find_accuracy(test_labels, predicted_B)


0.79

In [77]:
find_accuracy(test_labels, predicted_C)


0.7987

In [78]:
find_accuracy(test_labels, average_predicted_labels)

0.8226

In [79]:
teacher_pred = np.array(tf.argmax(teacher.predict(test_images), axis =1))
find_accuracy(test_labels, teacher_pred)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step


0.8684

In [81]:
scratchA = keras.models.clone_model(student)
scratchB = keras.models.clone_model(student)
scratchC = keras.models.clone_model(student)



scratchA.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

scratchB.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

scratchC.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)


scratchA.load_weights("nondistilled_student_1.weights.h5")
scratchB.load_weights("nondistilled_student_2.weights.h5")
scratchC.load_weights("nondistilled_student_3.weights.h5")

scratchA_l = scratchA.predict(test_images)
scratchB_l = scratchB.predict(test_images)
scratchC_l = scratchC.predict(test_images)


average_scratch = (scratchA_l + scratchB_l + scratchC_l) / 3.0

average_scratch_predicted_labels = np.array(tf.argmax(average_scratch, axis=1))


find_accuracy(test_labels, average_scratch_predicted_labels)

  saveable.load_own_variables(weights_store.get(inner_path))


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step


0.7928