# Project A: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [3]:
import os
import shutil
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_ranking as tfr

BATCH_SIZE = 32
INITIAL_EPOCHS = 10
FINE_TUNE_EPOCHS = 25
NUM_CLASSES = 2  # 2 total classes.
# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter


# Data loading

In [4]:
# Create directories
DATASET_PATH = 'mhist_dataset'

def initPath():
  partitions = ['train', 'test']
  majorityVoteLabels = ['HP', 'SSA']
  for p in partitions:
    for m in majorityVoteLabels:
      path = os.path.join(DATASET_PATH, p, m)
      os.makedirs(path, exist_ok=True)

initPath()

# Move images to the corresponding folders
CSV_PATH = 'mhist_dataset\\annotations.csv'
IMG_PATH = 'mhist_dataset\\images'

def copyImage():
  np.set_printoptions(precision=3, suppress=True)
  csv = pd.read_csv(CSV_PATH)
  content = dict(csv)
  imageNumber = len(content['Image Name'])
  print(f'{imageNumber} images found in {CSV_PATH}')
  HPCount = 0
  SSACount = 0
  for i in range(imageNumber):
    src = os.path.join(IMG_PATH, content['Image Name'].get(i))
    dst = os.path.join(
        DATASET_PATH,
        content['Partition'].get(i),
        content['Majority Vote Label'].get(i),
        content['Image Name'].get(i)
    )
    if content['Partition'].get(i) == 'train':
      if content['Majority Vote Label'].get(i) == 'HP':
        HPCount += 1
      elif content['Majority Vote Label'].get(i) == 'SSA':
        SSACount += 1
    # print(f'Copying image from {src} to {dst}')
    shutil.copyfile(src, dst)
    # print('Done. ')
  print(f'HP Count: {HPCount}, SSA Count: {SSACount}')

# copyImage()

IMG_SIZE = (224, 224)
IMG_SHAPE = (224, 224, 3)
TRAIN_PATH = 'mhist_dataset\\train'
TEST_PATH =  'mhist_dataset\\test'

# Load from directory
trainDataset = tf.keras.utils.image_dataset_from_directory(
  TRAIN_PATH,
  shuffle=True,
  batch_size=BATCH_SIZE, 
  image_size=IMG_SIZE
)
testDataset = tf.keras.utils.image_dataset_from_directory(
  TEST_PATH, 
  shuffle=True,
  batch_size=BATCH_SIZE, 
  image_size=IMG_SIZE
)

trainDataset = trainDataset.prefetch(buffer_size=tf.data.AUTOTUNE)
testDataset = testDataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Cast label from string to one hot
trainDataset = trainDataset.map(lambda x, y: (x, tf.one_hot(y, 2)))
testDataset = testDataset.map(lambda x, y: (x, tf.one_hot(y, 2)))

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),
])


Found 2175 files belonging to 2 classes.
Found 977 files belonging to 2 classes.


# Model creation

In [5]:
#@test {"output": "ignore"}

# Import ResNet50V2
baseModel = tf.keras.applications.resnet_v2.ResNet50V2(
    include_top=False, 
    input_shape=IMG_SHAPE
)

# Feature extraction
imageBatch, labelBatch = next(iter(trainDataset))
featureBatch = baseModel(imageBatch)

# Freeze the base model
baseModel.trainable = False
baseModel.summary()

# Add a classification head
preprocess_input = tf.keras.applications.resnet_v2.preprocess_input
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = preprocess_input(x)
x = baseModel(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(2)(x)
teacherModel = tf.keras.Model(inputs, outputs)
teacherModel.summary()

# Import MobileNetV2
basicModel = tf.keras.applications.mobilenet_v2.MobileNetV2(
    include_top=False, 
    input_shape=IMG_SHAPE
)

# Feature extraction
imageBatch, labelBatch = next(iter(trainDataset))
featureBatch = basicModel(imageBatch)

# Freeze the basic model
basicModel.trainable = False
basicModel.summary()

# Add a classification head
globalAverageLayer = tf.keras.layers.GlobalAveragePooling2D()
predictionLayer = tf.keras.layers.Dense(2)

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = preprocess_input(x)
x = basicModel(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(2)(x)
studentModel = tf.keras.Model(inputs, outputs)
studentModel.summary()


Model: "resnet50v2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv1_pad (ZeroPadding2D)   (None, 230, 230, 3)          0         ['input_1[0][0]']             
                                                                                                  
 conv1_conv (Conv2D)         (None, 112, 112, 64)         9472      ['conv1_pad[0][0]']           
                                                                                                  
 pool1_pad (ZeroPadding2D)   (None, 114, 114, 64)         0         ['conv1_conv[0][0]']          
                                                                                         

# Train and evaluation -- Distiller Class

In [6]:
# In reference to https://keras.io/examples/vision/knowledge_distillation/
class Distiller(tf.keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        distillation_loss_fn,
        student_loss_fn,
        metrics,
        alpha=0.5,
        temperature=4,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.distillation_loss_fn = distillation_loss_fn
        self.student_loss_fn = student_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
    

# Training models

In [20]:
teacherModel.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tfr.keras.losses.SoftmaxLoss(),
    metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')]
)
teacherModel.fit(trainDataset, epochs=INITIAL_EPOCHS, validation_data=testDataset)

distiller = Distiller(student=studentModel, teacher=teacherModel)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    student_loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')],
    alpha=ALPHA,
    temperature=DISTILLATION_TEMPERATURE,
)
distiller.fit(trainDataset, epochs=INITIAL_EPOCHS, validation_data=testDataset)


Epoch 1/10


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x11c6c531010>

# Fine-tuning

In [24]:
baseModel.trainable = True
fine_tune_at = 100
for layer in baseModel.layers[:fine_tune_at]:
  layer.trainable = False

teacherModel.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=tfr.keras.losses.SoftmaxLoss(),
    metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')]
)
teacherModel.summary()
teacherModel.fit(trainDataset, epochs=(FINE_TUNE_EPOCHS + INITIAL_EPOCHS), validation_data=testDataset, initial_epoch=INITIAL_EPOCHS)

basicModel.trainable = True
fine_tune_at = 100
for layer in basicModel.layers[:fine_tune_at]:
  layer.trainable = False
studentModel.summary()

distiller = Distiller(student=studentModel, teacher=teacherModel)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    student_loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')],
    alpha=ALPHA,
    temperature=DISTILLATION_TEMPERATURE,
)
distiller.fit(trainDataset, epochs=(FINE_TUNE_EPOCHS + INITIAL_EPOCHS), validation_data=testDataset, initial_epoch=INITIAL_EPOCHS)

Model: "model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_14 (InputLayer)       [(None, 224, 224, 3)]     0         
                                                                 
 resnet50v2 (Functional)     (None, 7, 7, 2048)        23564800  
                                                                 
 global_average_pooling2d_9  (None, 2048)              0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dropout_6 (Dropout)         (None, 2048)              0         
                                                                 
 dense_9 (Dense)             (None, 2)                 4098      
                                                                 
Total params: 23568898 (89.91 MB)
Trainable params: 20563970 (78.45 MB)
Non-trainable params: 3004928 (11.46 MB)
____________

Epoch 11/35
Epoch 12/35
Epoch 13/35
Epoch 14/35
Epoch 15/35
Epoch 16/35
Epoch 17/35
Epoch 18/35
Epoch 19/35
Epoch 20/35
Epoch 21/35
Epoch 22/35
Epoch 23/35
Epoch 24/35
Epoch 25/35
Epoch 26/35
Epoch 27/35
Epoch 28/35
Epoch 29/35
Epoch 30/35
Epoch 31/35
Epoch 32/35
Epoch 33/35
Epoch 34/35
Epoch 35/35
Model: "model_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_16 (InputLayer)       [(None, 224, 224, 3)]     0         
                                                                 
 mobilenetv2_1.00_224 (Func  (None, 7, 7, 1280)        2257984   
 tional)                                                         
                                                                 
 global_average_pooling2d_1  (None, 1280)              0         
 1 (GlobalAveragePooling2D)                                      
                                                                 
 dropout_7 (Dropout)   

<keras.src.callbacks.History at 0x11c73930790>

# Test AUC vs. tempreture curve

In [8]:
import matplotlib.pyplot as plt

temperatureList = [1, 2, 4, 16, 32, 64]
AUCList = []

# Clear Previous Sessions
tf.keras.backend.clear_session()
for t in temperatureList:
    DISTILLATION_TEMPERATURE = t
    
    tModel = tf.keras.applications.mobilenet_v2.MobileNetV2(
        include_top=False, 
        input_shape=IMG_SHAPE
    )

    # Feature extraction
    imageBatch, labelBatch = next(iter(trainDataset))
    featureBatch = tModel(imageBatch)

    # Freeze the basic model
    tModel.trainable = False

    # Add a classification head
    globalAverageLayer = tf.keras.layers.GlobalAveragePooling2D()
    predictionLayer = tf.keras.layers.Dense(2)

    preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
    inputs = tf.keras.Input(shape=IMG_SHAPE)
    x = data_augmentation(inputs)
    x = preprocess_input(x)
    x = tModel(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = tf.keras.layers.Dense(2)(x)
    
    tmp = tf.keras.Model(inputs, outputs)
    tmpDistiller = Distiller(student=tmp, teacher=teacherModel)
    tmpDistiller.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        distillation_loss_fn=tf.keras.losses.KLDivergence(),
        student_loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')],
        alpha=ALPHA,
        temperature=t,
    )
    tmpDistiller.fit(trainDataset, epochs=INITIAL_EPOCHS, validation_data=testDataset)
    
    tModel.trainable = True
    fine_tune_at = 100
    for layer in tModel.layers[:fine_tune_at]:
        layer.trainable = False

    distiller = Distiller(student=tmp, teacher=teacherModel)
    distiller.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        distillation_loss_fn=tf.keras.losses.KLDivergence(),
        student_loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')],
        alpha=ALPHA,
        temperature=t,
    )
    history = distiller.fit(trainDataset, epochs=(FINE_TUNE_EPOCHS + INITIAL_EPOCHS), validation_data=testDataset, initial_epoch=INITIAL_EPOCHS)
    
    AUCList.append(history.history['val_auc'][-1])
    # Clear Previous Sessions
    tf.keras.backend.clear_session()

def draw(temperatureList: list, AUCList: list):
    plt.figure()
    
    plt.title('Student AUC vs Temperature Hyperparameters')
    plt.xlabel('Temperature Hyperparameters')
    plt.ylabel('Student AUC')
    
    plt.plot(AUCList, marker='x')
    xi = list(range(len(temperatureList)))
    plt.xticks(xi, temperatureList)
    
    plt.show()
    
    return 0

draw(temperatureList, AUCList)

Epoch 1/10
Epoch 2/10
 9/68 [==>...........................] - ETA: 1:50 - auc: 0.7951 - student_loss: 0.5956 - distillation_loss: 0.1573

KeyboardInterrupt: 

# Train student from scratch

In [9]:
# Import MobileNetV2
bModel = tf.keras.applications.mobilenet_v2.MobileNetV2(
    include_top=False, 
    input_shape=IMG_SHAPE
)

# Feature extraction
imageBatch, labelBatch = next(iter(trainDataset))
featureBatch = bModel(imageBatch)

# Freeze the basic model
bModel.trainable = False
bModel.summary()

# Add a classification head
globalAverageLayer = tf.keras.layers.GlobalAveragePooling2D()
predictionLayer = tf.keras.layers.Dense(2)

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = preprocess_input(x)
x = bModel(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(2)(x)
studentModelNoKD = tf.keras.Model(inputs, outputs)
studentModelNoKD.summary()
studentModelNoKD.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tfr.keras.losses.SoftmaxLoss(),
    metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')]
)
studentModelNoKD.fit(trainDataset, epochs=INITIAL_EPOCHS, validation_data=testDataset)

bModel.trainable = True
fine_tune_at = 100
for layer in bModel.layers[:fine_tune_at]:
  layer.trainable = False

studentModelNoKD.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tfr.keras.losses.SoftmaxLoss(),
    metrics=[tf.keras.metrics.AUC(from_logits=True, name='auc')]
)
studentModelNoKD.summary()
studentModelNoKD.fit(trainDataset, epochs=(FINE_TUNE_EPOCHS + INITIAL_EPOCHS), validation_data=testDataset, initial_epoch=INITIAL_EPOCHS)

Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 Conv1 (Conv2D)              (None, 112, 112, 32)         864       ['input_3[0][0]']             
                                                                                                  
 bn_Conv1 (BatchNormalizati  (None, 112, 112, 32)         128       ['Conv1[0][0]']               
 on)                                                                                              
                                                                                                  
 Conv1_relu (ReLU)           (None, 112, 112, 32)         0         ['bn_Conv1[

KeyboardInterrupt: 

# Comparing the teacher and student model (number of of parameters and FLOPs) 

In [None]:
from keras_flops import get_flops

def statistics(model):
    total_parameters = 0
    for variable in model.trainable_variables:
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim
        total_parameters += variable_parameters
    print(f'Name: {model.name}, Parameter Count: {total_parameters}, FLOP Count: {get_flops(model, batch_size=BATCH_SIZE)}')

statistics(teacherModel)
statistics(studentModel)

Name: model_14, Parameter Count: 4098, FLOP Count: 113860051008
Name: model_15, Parameter Count: 2562, FLOP Count: 9784125504
