In [7]:
from tensorflow import keras

In [8]:
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

In [9]:
batch_size = 128
validation_split = 0.1

In [13]:
reference_model = keras.Sequential(
  [
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax'),
  ],
  name='reference'
)

reference_model.compile(
  optimizer=keras.optimizers.Adam(0.001),
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

reference_model.summary()

Model: "simple_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_2 (Conv2D)           (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 13, 13, 32)       0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 flatten_2 (Flatten)         (None, 1600)              0         
                                                                 
 dropout_2 (Dropout)         (None, 1600)             

In [14]:
reference_model.fit(
  train_images,
  train_labels,
  batch_size=batch_size,
  epochs=6,
  validation_split=validation_split,
)

reference_model.save('reference')

Epoch 1/6


  output, from_logits = _get_logits(


Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


2023-06-05 18:49:55.790295: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,1600]
	 [[{{node inputs}}]]
2023-06-05 18:49:55.997899: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,1600]
	 [[{{node inputs}}]]


INFO:tensorflow:Assets written to: simple/assets


INFO:tensorflow:Assets written to: simple/assets


In [15]:
_, reference_model_accuracy = reference_model.evaluate(test_images, test_labels, verbose=0)

print('Reference test accuracy:', reference_model_accuracy)

Reference test accuracy: 0.9883999824523926


In [6]:
model = keras.models.Sequential(
  [
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(10),
  ],
  name='baseline',
)

model.compile(
  optimizer=keras.optimizers.Adam(0.001),
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

model.summary()

Model: "baseline"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_4 (Conv2D)           (None, 26, 26, 32)        320       
                                                                 
 average_pooling2d_4 (Averag  (None, 13, 13, 32)       0         
 ePooling2D)                                                     
                                                                 
 conv2d_5 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 average_pooling2d_5 (Averag  (None, 5, 5, 64)         0         
 ePooling2D)                                                     
                                                                 
 flatten_2 (Flatten)         (None, 1600)              0         
                                                                 
 dense_4 (Dense)             (None, 256)               409

In [7]:
model.fit(
  train_images,
  train_labels,
  batch_size=batch_size,
  epochs=6,
  validation_split=validation_split,
)

model.save('baseline')

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6




INFO:tensorflow:Assets written to: baseline/assets


INFO:tensorflow:Assets written to: baseline/assets


In [8]:
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

Baseline test accuracy: 0.9916999936103821


In [11]:
from turtle import mode
import tensorflow_model_optimization as tfmot
import numpy as np

# Compute end step to finish pruning after 2 epochs.
epochs = 2

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

pruning_params = {
  'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50, final_sparsity=0.80, begin_step=0, end_step=end_step)
}

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
model_for_pruning._name = 'pruned'

model_for_pruning.compile(
  optimizer=keras.optimizers.Adam(0.001),
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=[keras.metrics.SparseCategoricalAccuracy()]
)

model_for_pruning.summary()

Model: "pruned"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d_  (None, 26, 26, 32)       610       
 4 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_average  (None, 13, 13, 32)       1         
 _pooling2d_4 (PruneLowMagni                                     
 tude)                                                           
                                                                 
 prune_low_magnitude_conv2d_  (None, 11, 11, 64)       36930     
 5 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_average  (None, 5, 5, 64)         1         
 _pooling2d_5 (PruneLowMagni                                     
 tude)                                                      

In [12]:
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

model_for_pruning.fit(
    train_images, 
    train_labels, 
    batch_size=batch_size, 
    epochs=epochs, 
    validation_split=validation_split, 
    callbacks=callbacks
)

model_for_pruning = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_pruning.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()]
)
model_for_pruning.save('pruned')

Epoch 1/2
Epoch 2/2




INFO:tensorflow:Assets written to: pruned/assets


INFO:tensorflow:Assets written to: pruned/assets


In [13]:
_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)

Baseline test accuracy: 0.9916999936103821
Pruned test accuracy: 0.9909999966621399


In [14]:
import tensorflow as tf

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

In [15]:
model_for_distillation = keras.models.Sequential(
  [
    keras.layers.Conv2D(32 / 4, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)),
    keras.layers.Conv2D(64 / 4, kernel_size=(3, 3), activation='relu'),
    keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(256 / 4, activation='relu'),
    keras.layers.Dense(10),
  ],
  name='distilled',
)

model_for_distillation.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

model_for_distillation.summary()

Model: "distilled"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_6 (Conv2D)           (None, 26, 26, 8)         80        
                                                                 
 average_pooling2d_6 (Averag  (None, 13, 13, 8)        0         
 ePooling2D)                                                     
                                                                 
 conv2d_7 (Conv2D)           (None, 11, 11, 16)        1168      
                                                                 
 average_pooling2d_7 (Averag  (None, 5, 5, 16)         0         
 ePooling2D)                                                     
                                                                 
 flatten_3 (Flatten)         (None, 400)               0         
                                                                 
 dense_6 (Dense)             (None, 64)                25

In [16]:
# Initialize and compile distiller
distiller = Distiller(student=model_for_distillation, teacher=model)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(train_images, train_labels, epochs=3)

model_for_distillation.save('distilled')

Epoch 1/3
Epoch 2/3
Epoch 3/3


[0.9847999811172485, 0.00040756131056696177]

In [17]:
_, model_for_distillation_accuracy = model_for_distillation.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)
print('Distilled test accuracy:', model_for_distillation_accuracy)

Baseline test accuracy: 0.9916999936103821
Pruned test accuracy: 0.9909999966621399
Distilled test accuracy: 0.9847999811172485
