In [1]:
import keras
from keras import layers
from keras import ops
import numpy as np

# Set a backend, e.g., "tensorflow", "jax", or "torch"
# Make sure to install the backend first, e.g., pip install tensorflow
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

In [7]:
# Data parameters
# Define the number of samples you want for your small dataset
N_TRAIN_SAMPLES = 5000  # Example: Use 5000 training samples (out of 50,000)
N_TEST_SAMPLES = 1000   # Example: Use 1000 testing samples (out of 10,000)

# Load the full CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# --- Create the small subsets ---
# Ensure we don't request more samples than available
n_train_actual = min(N_TRAIN_SAMPLES, len(x_train))
n_test_actual = min(N_TEST_SAMPLES, len(x_test))

x_train = x_train[:n_train_actual]
y_train = y_train[:n_train_actual]
x_test = x_test[:n_test_actual]
y_test = y_test[:n_test_actual]

x_train = x_train.astype("float32") / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
x_test = x_test.astype("float32") / 255.0
y_test = keras.utils.to_categorical(y_test, 10)


In [8]:
# --- Model Definitions ---

# Create a small student model
def create_student_model():
    inputs = keras.Input(shape=(32, 32, 3))
    x = layers.Conv2D(16, (3, 3), activation="relu")(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(32, (3, 3), activation="relu")(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    # The final layer has no activation, as we'll use from_logits=True in the loss
    outputs = layers.Dense(10)(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="student")

# Create a larger teacher model
def create_teacher_model():
    inputs = keras.Input(shape=(32, 32, 3))
    x = layers.Conv2D(64, (3, 3), padding="same", activation="relu")(inputs)
    x = layers.Conv2D(64, (3, 3), padding="same", activation="relu")(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), padding="same", activation="relu")(x)
    x = layers.Conv2D(128, (3, 3), padding="same", activation="relu")(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation="relu")(x)
    # The final layer has no activation for consistency
    outputs = layers.Dense(10)(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="teacher")


In [9]:
# --- Train the Teacher Model (as a baseline) ---
print("--- Training the Teacher Model ---")
teacher = create_teacher_model()
teacher.compile(
    optimizer="adam",
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
teacher.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.1)
teacher_test_loss, teacher_test_acc = teacher.evaluate(x_test, y_test)
print(f"Teacher Test Accuracy: {teacher_test_acc:.4f}")

--- Training the Teacher Model ---
Epoch 1/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 969ms/step - accuracy: 0.1649 - loss: 2.1832 - val_accuracy: 0.3280 - val_loss: 1.8638
Epoch 2/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 953ms/step - accuracy: 0.3531 - loss: 1.7605 - val_accuracy: 0.4180 - val_loss: 1.6884
Epoch 3/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 981ms/step - accuracy: 0.4407 - loss: 1.5387 - val_accuracy: 0.4620 - val_loss: 1.5323
Epoch 4/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 955ms/step - accuracy: 0.4940 - loss: 1.3925 - val_accuracy: 0.4360 - val_loss: 1.5905
Epoch 5/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 950ms/step - accuracy: 0.5326 - loss: 1.2747 - val_accuracy: 0.5040 - val_loss: 1.4126
Epoch 6/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 954ms/step - accuracy: 0.6028 - loss: 1.1095 - val_accuracy: 0.5240 - val_

In [16]:
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        # We will track these losses manually
        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.distillation_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        # We list our `Metric` objects here so `reset_states()` can be called automatically.
        return [
            self.student_loss_tracker,
            self.distillation_loss_tracker,
        ]

    def compile(
        self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        x, y = data

        # Get teacher predictions
        teacher_predictions = self.teacher(x, training=False)

        with keras.backend.GradientTape() as tape:
            # Get student predictions
            student_predictions = self.student(x, training=True)

            # 1. Calculate the student loss (against hard labels)
            student_loss = self.student_loss_fn(y, student_predictions)

            # 2. Calculate the distillation loss (against soft teacher labels)
            distillation_loss = self.distillation_loss_fn(
                ops.softmax(teacher_predictions / self.temperature),
                ops.softmax(student_predictions / self.temperature),
            )

            # 3. Combine the losses
            total_loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients of the total loss w.r.t the student's trainable variables
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)

        # Update student weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.student_loss_tracker.update_state(student_loss)
        self.distillation_loss_tracker.update_state(distillation_loss)
        # Note: Compiled metrics are updated automatically
        self.compute_metrics(x, y, student_predictions, None)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        # The test step only evaluates the student model on the validation data
        x, y = data
        student_predictions = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, student_predictions)
        self.student_loss_tracker.update_state(student_loss)
        self.compute_metrics(x, y, student_predictions, None)
        results = {m.name: m.result() for m in self.metrics}
        return results



In [17]:
# --- Distillation using Way 1 ---
print("\n--- Distilling with Way 1: Model Subclassing ---")
student_v1 = create_student_model()
distiller = Distiller(student=student_v1, teacher=teacher)

distiller.compile(
    optimizer="adam",
    metrics=["accuracy"],
    student_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

distiller.fit(x_train, y_train, epochs=5, batch_size=64)
student_v1_loss, student_v1_acc = student_v1.evaluate(x_test, y_test)
print(f"Way 1 - Distilled Student Accuracy: {student_v1_acc:.4f}")


--- Distilling with Way 1: Model Subclassing ---
Epoch 1/5


AttributeError: module 'keras.api.backend' has no attribute 'GradientTape'

This approach is more idiomatic for many Keras users. It avoids subclassing Model by cleverly constructing a training model and a custom loss function.
Concept:
Create a temporary "training model" using the Functional API that takes an input x and outputs both the student's and teacher's predictions.
Write a custom loss function that accepts this dual output (y_pred) and calculates the combined distillation loss.
Train this temporary model. The original student model's weights will be updated because it's part of the graph.
For evaluation, use the original student model.

In [14]:
def distillation_loss(y_true, y_pred, alpha=0.1, temperature=10):
    # y_pred will be a list of [student_logits, teacher_logits]
    student_logits, teacher_logits = y_pred

    # 1. Student loss against the hard labels
    student_loss = keras.losses.categorical_crossentropy(y_true, student_logits, from_logits=True)

    # 2. Distillation loss against the soft teacher labels
    distillation_loss = keras.losses.kl_divergence(
        ops.softmax(teacher_logits / temperature),
        ops.softmax(student_logits / temperature),
    )

    # 3. Combine the two losses
    total_loss = alpha * student_loss + (1 - self.alpha) * distillation_loss
    return total_loss

In [15]:
# --- Distillation using Way 2 ---
print("\n--- Distilling with Way 2: Functional API + Custom Loss ---")
student_v2 = create_student_model()

# Make the teacher non-trainable
teacher.trainable = False

# Create the combined training model
inputs = keras.Input(shape=(32, 32, 3))
student_output = student_v2(inputs)
teacher_output = teacher(inputs)

# The training model outputs a list of [student, teacher] predictions
training_model = keras.Model(inputs, [student_output, teacher_output])

# Compile the training model with our custom loss
training_model.compile(
    optimizer="adam",
    loss=distillation_loss, # Our custom loss function
    metrics=["accuracy"],   # This metric will be calculated on the first output (student)
)

# Train the model. The loss function gets y_true and the model's output [s_out, t_out].
training_model.fit(x_train, y_train, epochs=5, batch_size=64)

# Evaluate the final student model directly
student_v2_loss, student_v2_acc = student_v2.evaluate(x_test, y_test)
print(f"Way 2 - Distilled Student Accuracy: {student_v2_acc:.4f}")


--- Distilling with Way 2: Functional API + Custom Loss ---
Epoch 1/5


KeyError: "The path: (0,) in the `loss` argument, can't be found in either the model's output (`y_pred`) or in the labels (`y_true`)."

This is the most modern and concise Keras 3 pattern. It leverages the add_loss() method to "bake" the distillation loss directly into the model's graph.
Concept:

Define the model using the Functional API.
Within the model definition, calculate the distillation loss (KL divergence between student and teacher logits).
Add this loss directly to the model using model.add_loss().
Compile the model with only the standard student loss (e.g., Cross-Entropy). Keras will automatically combine the compile loss and the added loss during training.



In [18]:
# --- Distillation using Way 3 ---
print("\n--- Distilling with Way 3: Functional API + add_loss() ---")

# Hyperparameters
alpha = 0.1
temperature = 10

# Make the teacher non-trainable again
teacher.trainable = False

# Create the student model via the Functional API
student_v3 = create_student_model()
inputs = student_v3.input
student_logits = student_v3.output

# Get teacher logits for the same input
teacher_logits = teacher(inputs)

# Calculate the distillation loss
distillation_loss = keras.losses.kl_divergence(
    ops.softmax(teacher_logits / temperature),
    ops.softmax(student_logits / temperature),
)

# Create the final distiller model that includes the added loss
# The model takes standard inputs and produces student logits as output
distiller_model = keras.Model(inputs, student_logits)
distiller_model.add_loss((1 - alpha) * ops.mean(distillation_loss)) # Note the weighting

# Compile the model with the standard student loss
distiller_model.compile(
    optimizer="adam",
    # The loss is weighted by alpha here
    loss=keras.losses.CategoricalCrossentropy(from_logits=True, weight=alpha),
    metrics=["accuracy"],
)

# Fit the distiller model. Keras handles both losses automatically.
distiller_model.fit(x_train, y_train, epochs=5, batch_size=64)

# The distiller_model *is* our final student model.
student_v3_loss, student_v3_acc = distiller_model.evaluate(x_test, y_test)
print(f"Way 3 - Distilled Student Accuracy: {student_v3_acc:.4f}")


--- Distilling with Way 3: Functional API + add_loss() ---


NotImplementedError: 

### Step-N 🇰
1. [Knowledge Distillation using Keras](https://keras.io/examples/vision/knowledge_distillation/)