# Project A: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [22]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union
import csv
from PIL import Image
import numpy as np
import os
from tensorflow.keras.models import load_model

tf.enable_v2_behavior()

split_data = False   ## Use it only once to split data into train/test using annotation file
BATCH_SIZE = 32
INITIAL_EPOCHS = 10
FINE_TUNE_EPOCHS = 25
NUM_CLASSES = 2
run_type = 'local'  ## 'local' or 'colab'
train_ta = True     ## True: train the TA in this notebook; False: load a saved TA model
model = 'B3'    ## TA model selection 'B3' or 'B2' or 'D121'

## Seed the TF and NumPy global RNGs (dataset shuffling, random augmentation,
## unseeded weight initialisers) so that Restart & Run All is repeatable.
## Some GPU kernels can still be non-deterministic.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [23]:
if run_type != 'local':
    # Colab: mount Google Drive so the MHIST folder is reachable from the runtime.
    from google.colab import drive
    drive.mount('/content/drive', force_remount = True)
    data_path = '/content/drive/My Drive/Colab Notebooks/MHIST/'
else:
    # Local run: path to a locally synced copy of the same Drive folder.
    data_path = 'G:/My Drive/Colab Notebooks/MHIST/'

# Data loading

In [6]:
def train_test_split(output_dir, source_dir=None):
  """Copy the MHIST images into a <partition>/<class>/ folder tree.

  Reads ``annotations.csv`` from ``source_dir`` and copies every listed image
  from ``<source_dir>/images_unzip/`` to
  ``<output_dir>/<row[3]>/<row[1]>/<row[0]>``. That is partition folder first,
  then class folder, which is the layout ``tfds.ImageFolder`` reads.

  Args:
    output_dir: Destination root. A trailing path separator is optional.
    source_dir: Folder holding ``annotations.csv`` and ``images_unzip/``.
      Defaults to the notebook-level ``data_path``.
  """
  import shutil

  if source_dir is None:
    source_dir = data_path
  file_name = 'annotations.csv'

  ## Create directories for output train/test
  for split in ('train', 'test'):
    for label in ('HP', 'SSA'):
      os.makedirs(os.path.join(output_dir, split, label), exist_ok=True)

  # newline='' is what the csv module expects for files it parses.
  with open(os.path.join(source_dir, file_name), newline='') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader, None)  # skip the header row; tolerates an empty file
    for line, row in enumerate(csv_reader, start=1):
      if not row:  # blank line in the CSV
        continue
      print('[INFO] Image number:' + str(line) + ' - Image Name: ' + row[0])
      src = os.path.join(source_dir, 'images_unzip', row[0])
      dst = os.path.join(output_dir, row[3], row[1], row[0])
      shutil.copy(src, dst)

# One-off step: materialise images_split/{train,test}/{HP,SSA}/ from annotations.csv.
split_dir = data_path + 'images_split/'
if split_data:
  train_test_split(split_dir)

# tfds reads the split folder tree directly; partition and class come from the folder names.
builder = tfds.ImageFolder(split_dir, shape = (224, 224, 3))
print(builder.info)

tfds.core.DatasetInfo(
    name='image_folder',
    full_name='image_folder/1.0.0',
    description="""
    Generic image classification dataset.
    """,
    homepage='https://www.tensorflow.org/datasets/catalog/image_folder',
    data_dir='/root/tensorflow_datasets/image_folder/1.0.0',
    file_format=tfrecord,
    download_size=Unknown size,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(224, 224, 3), dtype=uint8),
        'image/filename': Text(shape=(), dtype=string),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=2),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=977, num_shards=1>,
        'train': <SplitInfo num_examples=2175, num_shards=1>,
    },
    citation="""""",
)


In [7]:
# Load train and test splits.
def preprocess(x):
  """Turn one tfds example dict into an ``(image, one_hot_label)`` pair.

  The image goes through ``tf.image.convert_image_dtype`` to float32; for the
  uint8 images this dataset stores, that rescales pixel values to [0, 1].
  The integer label is one-hot encoded over the class count declared by
  ``builder``.
  """
  num_classes = builder.info.features['label'].num_classes
  float_image = tf.image.convert_image_dtype(x['image'], tf.float32)
  return float_image, tf.one_hot(x['label'], num_classes)


# Train: cache the decoded uint8 examples and shuffle them *before* the float32
# conversion. The full-dataset shuffle buffer then holds 1-byte pixels instead
# of 4-byte ones. Preprocess in parallel, batch, and prefetch so input
# preparation overlaps with the training step.
mhist_train = (builder.as_dataset(split='train', shuffle_files=False)
               .cache()
               .shuffle(builder.info.splits['train'].num_examples)
               .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
               .batch(BATCH_SIZE, drop_remainder=True)
               .prefetch(tf.data.AUTOTUNE))

# Test: not shuffled. The last partial batch is kept so every test image is scored.
mhist_test = (builder.as_dataset(split='test')
              .cache()
              .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
              .batch(BATCH_SIZE)
              .prefetch(tf.data.AUTOTUNE))

# Data Augmentation

In [8]:
IMG_SHAPE = (224, 224, 3)
# Augmentation shared by the TA and student models. Keras random layers only
# transform their input when called with training=True.
augmentation_layers = [
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomRotation(0.2),
]
data_augmentation = tf.keras.Sequential(augmentation_layers)

In [9]:
# your code start from here for step 8
## https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-849439287
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
#print('TensorFlow:', tf.__version__)
def get_flops_number(model):
  """Return the multiply-accumulate count of one forward pass of ``model``.

  ``model.call`` is traced for a single example (batch dimension 1). The
  resulting graph is passed to the TF profiler's float-operation counter.
  """
  single_example = tf.TensorSpec(shape=(1,) + model.input_shape[1:])
  traced_call = tf.function(model.call, input_signature=[single_example])
  graph = traced_call.get_concrete_function().graph
  report = profile(graph, options=ProfileOptionBuilder.float_operation())
  # `profile` counts a multiply and its accumulate as two float ops.
  # Halve the total so the reported figure is multiply-accumulates.
  return report.total_float_ops // 2

# Model creation

In [10]:
#@test {"output": "ignore"}
################## Teacher model  ###############################
## We load Teacher model trained from previous problems
print('Load pre-trained teacher model')
# The teacher and its per-epoch test history were saved by the earlier task.
if run_type == 'local':
  teacher_file = 'Teacher_Model_Task2.h5'
  teacher_hist_file = 'Teacher_Model_Task2_hist.npz'
else:
  teacher_file = '/content/drive/My Drive/Colab Notebooks/Teacher_Model_Task2.h5'
  teacher_hist_file = '/content/drive/My Drive/Colab Notebooks/Teacher_Model_Task2_hist.npz'
teacher_model = load_model(teacher_file)
hist = np.load(teacher_hist_file)

print('Average F1 score and class accuracy for the last five epochs')
for metric_label, hist_key in (('Teacher test F1 score = ', 'f1_hist'),
                               ('Teacher test class accuracy = ', 'acc_hist')):
  print(metric_label + str(np.mean(hist[hist_key][-5:])))

print('######### Teacher Model, Number of Parameters, Flops ################')
teacher_model.summary()
teacher_flops = get_flops_number(teacher_model)
print('Teacher Flops: {:,}'.format(teacher_flops))

Load pre-trained teacher model




Average F1 score and class accuracy for the last five epochs
Teacher test F1 score = 0.7882733316229258
Teacher test class accuracy = 85.322426
######### Teacher Model, Number of Parameters, Flops ################
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequent  (None, 224, 224, 3)       0         
 ial)                                                            
                                                                 
 resnet50v2 (Functional)     (None, 7, 7, 2048)        23564800  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                             

Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.


Teacher Flops: 3,489,281,549


In [24]:
## TA (teacher-assistant) model: an intermediate-size backbone between the
## ResNet50V2 teacher and the MobileNetV2 student, selected by `model`.
## Each entry maps to (Keras application constructor, name for the re-wrapped
## base model, or None when the base model is used as-is).
TA_BACKBONES = {
  'B3': (tf.keras.applications.EfficientNetB3, 'effecientnetB3'),
  'B2': (tf.keras.applications.EfficientNetB2, 'effecientnetB2'),
  'D121': (tf.keras.applications.DenseNet121, None),
}
if model not in TA_BACKBONES:
  # Fail fast. Otherwise `ta_base` is either undefined (NameError further
  # down) or silently reused from an earlier run of this cell.
  raise ValueError('model must be one of {}, got {!r}'.format(sorted(TA_BACKBONES), model))

backbone_fn, rewrapped_name = TA_BACKBONES[model]
ta_base = backbone_fn(include_top=False, weights="imagenet", input_shape=IMG_SHAPE)
if rewrapped_name is not None:
  ## Remove the first processing units (Normalization/rescaling): the new graph
  ## starts at the input of layer 4 (the original input is [0, 1] floats).
  ## NOTE(review): the number of built-in preprocessing layers in Keras
  ## EfficientNet differs between TF versions -- confirm index 4 for yours.
  ta_base = tf.keras.Model(inputs=ta_base.layers[4].input,
                           outputs=ta_base.layers[-1].output,
                           name=rewrapped_name)

# Stage 1 trains only the new head; the backbone is frozen. training=False
# below keeps its BatchNorm layers in inference mode even after fine-tuning
# unfreezes the top layers.
ta_base.trainable = False
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = ta_base(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
x = tf.keras.layers.Dropout(0.15)(x)
logits = tf.keras.layers.Dense(NUM_CLASSES)(x)
ta_model = tf.keras.Model(inputs, logits)

print('######### TA Model, Number of Parameters, Flops ################')
ta_model.summary()
ta_flops = get_flops_number(ta_model)
print('TA Flops: {:,}'.format(ta_flops))
print('Number of layers in Base Model = ' + str(len(ta_base.layers)))

######### TA Model, Number of Parameters, Flops ################
Model: "model_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_27 (InputLayer)       [(None, 224, 224, 3)]     0         
                                                                 
 sequential (Sequential)     (None, 224, 224, 3)       0         
                                                                 
 effecientnetB3 (Functional)  (None, 7, 7, 1536)       10783528  
                                                                 
 global_average_pooling2d_4   (None, 1536)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_8 (Dense)             (None, 512)               786944    
                                                                 
 dropout_4 (Dropout)         (None, 512)               0   

In [12]:
## Student Model: an ImageNet-pretrained MobileNetV2 backbone plus a small
## classification head. It is trained by distillation from the TA model.
mobilenet_base = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE, include_top=False, weights="imagenet")
# Freeze the whole backbone for the initial, head-only training stage.
mobilenet_base.trainable = False
mobilenet_layers_num = len(mobilenet_base.layers)

inputs = tf.keras.Input(shape=IMG_SHAPE)
augmented = data_augmentation(inputs)
# training=False keeps the backbone's BatchNorm layers in inference mode,
# including after the top layers are unfrozen for fine-tuning.
x = mobilenet_base(augmented, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
x = tf.keras.layers.Dropout(0.1)(x)
logits = tf.keras.layers.Dense(NUM_CLASSES)(x)

# Build the student model from the symbolic inputs and outputs.
student_model = tf.keras.Model(inputs=inputs, outputs=logits)

print('######### Student Model, Number of Parameters, Flops ################')
student_model.summary()
student_flops = get_flops_number(student_model)
print('Student Flops: {:,}'.format(student_flops))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
######### Student Model, Number of Parameters, Flops ################
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 sequential (Sequential)     (None, 224, 224, 3)       0         
                                                                 
 mobilenetv2_1.00_224 (Func  (None, 7, 7, 1280)        2257984   
 tional)                                                         
                                                                 
 global_average_pooling2d_1  (None, 1280)              0         
  (GlobalAveragePooling2D)                                       
                         

# Distillation Loss

In [13]:
#@test {"output": "ignore"}

# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Compute distillation loss.

  This function computes cross entropy between softened logits and softened
  targets. The resulting loss is scaled by the squared temperature so that
  the gradient magnitude remains approximately constant as the temperature is
  changed. For reference, see Hinton et al., 2014, "Distilling the knowledge in
  a neural network."

  The softened teacher distribution is treated as a constant target: no
  gradient flows back into `teacher_logits`.

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
  # softmax_cross_entropy_with_logits back-propagates into `labels` as well.
  # stop_gradient makes the soft targets a fixed target, so the teacher can
  # never receive gradients, even if its variables end up being watched.
  soft_targets = tf.stop_gradient(
      tf.nn.softmax(teacher_logits / temperature, axis=-1))

  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          soft_targets, student_logits / temperature)) * temperature ** 2

# Teacher-assistant (TA) loss function

In [14]:
def compute_ta_loss(images, labels):
  """Knowledge-distillation training loss for the teacher-assistant (TA) model.

  Blends two terms:
  - the distillation loss between the teacher's softened predictions and the
    TA's logits;
  - the plain cross-entropy of the TA's logits against the hard labels.

  The result is ``ALPHA * distillation + (1 - ALPHA) * cross_entropy``.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor.
  """
  ta_logits = ta_model(images, training=True)
  # The teacher runs in inference mode; it only supplies soft targets.
  teacher_logits = teacher_model(images, training=False)

  soft_term = distillation_loss(teacher_logits, ta_logits,
                                DISTILLATION_TEMPERATURE)
  hard_term = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels, ta_logits))
  return ALPHA * soft_term + (1 - ALPHA) * hard_term

# Student loss function

In [15]:
def compute_student_loss(images, labels):
  """Knowledge-distillation training loss for the student model.

  The TA model acts as the teacher here. The loss blends two terms:
  - the distillation loss between the TA's softened predictions and the
    student's logits;
  - the plain cross-entropy of the student's logits against the hard labels.

  The result is ``ALPHA * distillation + (1 - ALPHA) * cross_entropy``.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor.
  """
  student_logits = student_model(images, training=True)
  # The TA runs in inference mode; it only supplies soft targets.
  ta_logits = ta_model(images, training = False)

  soft_term = distillation_loss(ta_logits, student_logits,
                                DISTILLATION_TEMPERATURE)
  hard_term = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels, student_logits))
  return ALPHA * soft_term + (1 - ALPHA) * hard_term

# Training and Evaluation

In [16]:
from sklearn.metrics import f1_score

def compute_F1_score(true_binary, pred_binary):
  """Compute the binary F1 score of predicted vs. true class indices.

  Thin wrapper around ``sklearn.metrics.f1_score`` with its defaults, so the
  class with index 1 is treated as the positive class.

  Args:
    true_binary: 1-D array-like of ground-truth class indices (0 or 1).
    pred_binary: 1-D array-like of predicted class indices, same length.

  Returns:
    F1 score as a float in [0, 1].
  """
  return f1_score(true_binary, pred_binary)

def get_binary_labels_batch(model, images, labels):
  """Return ``(true_indices, predicted_indices)`` for one batch as NumPy arrays.

  The model runs with training=False. Both its logits and the one-hot labels
  are reduced with argmax over the last axis.
  """
  predictions = model(images, training=False)
  pred_idx, true_idx = (tf.argmax(t, -1).numpy() for t in (predictions, labels))
  return true_idx, pred_idx

@tf.function
def compute_num_correct(model, images, labels):
  """Count correctly classified images in a batch.

  Args:
    model: Instance of tf.keras.Model (called with training=False).
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Tuple ``(num_correct, predicted, expected)``, where:
    - ``num_correct`` is the float32 count of matches;
    - ``predicted`` and ``expected`` are the predicted and true class indices
      for the batch.
  """
  predicted = tf.argmax(model(images, training=False), -1)
  expected = tf.argmax(labels, -1)
  matches = tf.cast(tf.math.equal(predicted, expected), tf.float32)
  return tf.reduce_sum(matches), predicted, expected


def train_and_evaluate(model, compute_loss_fn, lr, num_epochs):
  """Train ``model`` with a custom loss and evaluate it on the test split each epoch.

  Args:
    model: Instance of tf.keras.Model whose trainable variables are optimised.
    compute_loss_fn: A function that computes the training loss given the
      images, and labels.
    lr: Adam learning rate.
    num_epochs: Number of passes over ``mhist_train``.

  Returns:
    ``(f1_score_epochs, class_acc_epochs)``: per-epoch test F1 score and test
    class accuracy in percent.
  """
  optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

  f1_score_epochs, class_acc_epochs = [], []
  for epoch in range(1, num_epochs + 1):
    # Run training. Only `model`'s trainable variables receive gradients.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mhist_train:
      with tf.GradientTape() as tape:
        loss_value = compute_loss_fn(images, labels)
      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Run evaluation: one forward pass per batch feeds both F1 and accuracy.
    true_binary, pred_binary = [], []
    for images, labels in mhist_test:
      true_bin_batch, pred_bin_batch = get_binary_labels_batch(model, images, labels)
      true_binary.append(true_bin_batch)
      pred_binary.append(pred_bin_batch)

    true_binary = np.concatenate(true_binary)
    pred_binary = np.concatenate(pred_binary)
    F1_score = compute_F1_score(true_binary, pred_binary)
    # Accuracy over the images actually scored, not the split metadata count.
    class_acc = float(np.mean(true_binary == pred_binary)) * 100
    f1_score_epochs.append(F1_score)
    class_acc_epochs.append(class_acc)
    print("F1 score: {:.2f} - Class accuracy: {:.2f}%".format(F1_score, class_acc))
  return f1_score_epochs, class_acc_epochs

# Training TA model

In [17]:
if train_ta:
  # Stage 1: train only the new head on top of the frozen TA backbone.
  print('INITIAL Training for TA Model')
  f1_hist, acc_hist = train_and_evaluate(ta_model, compute_ta_loss, 1e-3, INITIAL_EPOCHS)

  # Stage 2: unfreeze the backbone, then re-freeze its first `fine_tune_at`
  # layers so only the top of the network is fine-tuned, at a 10x lower LR.
  ta_base.trainable = True
  fine_tune_at = 280
  for frozen_layer in ta_base.layers[:fine_tune_at]:
    frozen_layer.trainable = False
  print('FINE TUNING Training for TA Model')
  # Only the fine-tuning history is kept for the summary and the saved .npz.
  f1_hist, acc_hist = train_and_evaluate(ta_model, compute_ta_loss, 0.1*1e-3, FINE_TUNE_EPOCHS)

  print('Average F1 score and class accuracy for the last five epochs')
  for summary_label, history in (('F1 score = ', f1_hist), ('class accuracy = ', acc_hist)):
    print(summary_label + str(np.mean(history[-5:])))

  ## Save TA Model and its history. Any other run_type saves nothing.
  save_targets = {
    'local': ('TA_Model_Task2_' + model + '.h5',
              'TA_Model_Task2_hist_' + model+ '.npz'),
    'colab': ('/content/drive/My Drive/Colab Notebooks/TA_Model_Task2_' + model +'.h5',
              '/content/drive/My Drive/Colab Notebooks/TA_Model_Task2_hist_' + model +'.npz'),
  }
  if run_type in save_targets:
    model_file, hist_file = save_targets[run_type]
    ta_model.save(model_file)
    np.savez(hist_file, f1_hist = f1_hist, acc_hist = acc_hist)

INITIAL Training for TA Model
Epoch 1: F1 score: 0.74 - Class accuracy: 78.20%
Epoch 2: F1 score: 0.67 - Class accuracy: 79.22%
Epoch 3: F1 score: 0.62 - Class accuracy: 78.20%
Epoch 4: F1 score: 0.74 - Class accuracy: 80.55%
Epoch 5: F1 score: 0.59 - Class accuracy: 76.56%
Epoch 6: F1 score: 0.73 - Class accuracy: 80.04%
Epoch 7: F1 score: 0.73 - Class accuracy: 80.66%
Epoch 8: F1 score: 0.65 - Class accuracy: 78.61%
Epoch 9: F1 score: 0.57 - Class accuracy: 76.05%
Epoch 10: F1 score: 0.74 - Class accuracy: 81.06%
FINE TUNING Training for TA Model
Epoch 1: 



F1 score: 0.75 - Class accuracy: 82.29%
Epoch 2: F1 score: 0.76 - Class accuracy: 82.19%
Epoch 3: F1 score: 0.78 - Class accuracy: 82.60%
Epoch 4: F1 score: 0.77 - Class accuracy: 83.93%
Epoch 5: F1 score: 0.78 - Class accuracy: 84.34%
Epoch 6: F1 score: 0.79 - Class accuracy: 84.85%
Epoch 7: F1 score: 0.74 - Class accuracy: 82.29%
Epoch 8: F1 score: 0.75 - Class accuracy: 83.62%
Epoch 9: F1 score: 0.71 - Class accuracy: 81.06%
Epoch 10: F1 score: 0.78 - Class accuracy: 83.62%
Epoch 11: F1 score: 0.75 - Class accuracy: 83.73%
Epoch 12: F1 score: 0.77 - Class accuracy: 84.34%
Epoch 13: F1 score: 0.76 - Class accuracy: 83.42%
Epoch 14: F1 score: 0.68 - Class accuracy: 80.25%
Epoch 15: F1 score: 0.76 - Class accuracy: 84.34%
Epoch 16: F1 score: 0.77 - Class accuracy: 83.83%
Epoch 17: F1 score: 0.76 - Class accuracy: 84.34%
Epoch 18: F1 score: 0.67 - Class accuracy: 80.04%
Epoch 19: F1 score: 0.70 - Class accuracy: 81.06%
Epoch 20: F1 score: 0.79 - Class accuracy: 83.21%
Epoch 21: F1 score

  saving_api.save_model(


F1 score: 0.76 - Class accuracy: 83.32%
Average F1 score and class accuracy for the last five epochs
F1 score = 0.7728414454765564
class accuracy = 83.97134


In [18]:
if train_ta == False:
  # Reuse a TA trained in an earlier session. It must be bound to `ta_model`:
  # compute_student_loss distils from `ta_model`. Loading it under any other
  # name would leave the student learning from the freshly built, untrained
  # TA head created in the TA model cell.
  print('Load pre-trained teacher assistant model ' + model)
  if run_type == 'local':
    ta_model = load_model('TA_Model_Task2_' + model + '.h5')
    hist = np.load('TA_Model_Task2_hist_' + model + '.npz')
  else:
    ta_model = load_model('/content/drive/My Drive/Colab Notebooks/TA_Model_Task2_' + model +'.h5')
    hist = np.load('/content/drive/My Drive/Colab Notebooks/TA_Model_Task2_hist_' + model + '.npz')

  print('Average F1 score and class accuracy for the last five epochs')
  print('TA test F1 score = ' + str(np.mean(hist['f1_hist'][-5:])))
  print('TA test class accuracy = ' + str(np.mean(hist['acc_hist'][-5:])))

In [19]:
# Stage 1: train only the head on top of the frozen MobileNetV2 backbone.
print('INITIAL Training for Student Model')
f1_hist, acc_hist = train_and_evaluate(student_model, compute_student_loss, 1e-3, INITIAL_EPOCHS)

# Stage 2: unfreeze the backbone, then re-freeze its first `fine_tune_at`
# layers so only the upper layers adapt, at a 10x lower learning rate.
mobilenet_base.trainable = True
fine_tune_at = 80
for frozen_layer in mobilenet_base.layers[:fine_tune_at]:
  frozen_layer.trainable = False
print('FINE TUNING Training for Student Model')
# Only the fine-tuning history is summarised below.
f1_hist, acc_hist = train_and_evaluate(student_model, compute_student_loss, 0.1*1e-3, FINE_TUNE_EPOCHS)

print('Average F1 score and class accuracy for the last five epochs')
for summary_label, history in (('F1 score = ', f1_hist), ('class accuracy = ', acc_hist)):
  print(summary_label + str(np.mean(history[-5:])))

INITIAL Training for Student Model
Epoch 1: F1 score: 0.70 - Class accuracy: 78.20%
Epoch 2: F1 score: 0.40 - Class accuracy: 71.24%
Epoch 3: F1 score: 0.63 - Class accuracy: 77.69%
Epoch 4: F1 score: 0.69 - Class accuracy: 78.71%
Epoch 5: F1 score: 0.61 - Class accuracy: 76.77%
Epoch 6: F1 score: 0.57 - Class accuracy: 75.74%
Epoch 7: F1 score: 0.67 - Class accuracy: 77.38%
Epoch 8: F1 score: 0.60 - Class accuracy: 76.56%
Epoch 9: F1 score: 0.47 - Class accuracy: 72.06%
Epoch 10: F1 score: 0.40 - Class accuracy: 70.52%
FINE TUNING Training for Student Model
Epoch 1: F1 score: 0.22 - Class accuracy: 66.94%
Epoch 2: F1 score: 0.57 - Class accuracy: 76.36%
Epoch 3: F1 score: 0.51 - Class accuracy: 74.51%
Epoch 4: F1 score: 0.69 - Class accuracy: 80.55%
Epoch 5: F1 score: 0.75 - Class accuracy: 82.19%
Epoch 6: F1 score: 0.65 - Class accuracy: 79.32%
Epoch 7: F1 score: 0.70 - Class accuracy: 81.47%
Epoch 8: F1 score: 0.76 - Class accuracy: 82.91%
Epoch 9: F1 score: 0.55 - Class accuracy: 7