# Project A: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [10]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union
from tensorflow.keras.models import load_model

# Make TF2 (eager) behaviour explicit. TF 2.x already runs eagerly; the
# switch itself is documented under `tf.compat.v1`, which is reachable from
# the `tensorflow.compat.v2` module imported above.
tf.compat.v1.enable_v2_behavior()

builder = tfds.builder('mnist')  # dataset metadata: label info and split sizes
BATCH_SIZE = 256
NUM_EPOCHS = 12
NUM_CLASSES = 10  # MNIST digits 0-9.


# Data loading

In [2]:
# Load train and test splits.
def preprocess(x):
  """Turn one TFDS MNIST example into an `(image, one_hot_label)` pair.

  The image is converted to float32 with `tf.image.convert_image_dtype`
  (which also rescales integer images into [0, 1]). The integer label is
  one-hot encoded over the class count reported by the dataset metadata.
  """
  float_image = tf.image.convert_image_dtype(x['image'], tf.float32)
  label_depth = builder.info.features['label'].num_classes
  one_hot_label = tf.one_hot(x['label'], label_depth)
  return float_image, one_hot_label


# Train split: cache the raw examples, preprocess them, and shuffle across the
# whole split (reshuffled on every pass). Batches are built with
# drop_remainder=True so every training batch is full.
mnist_train = tfds.load('mnist', split='train', shuffle_files=False).cache()
mnist_train = mnist_train.map(preprocess)
mnist_train = mnist_train.shuffle(builder.info.splits['train'].num_examples)
mnist_train = mnist_train.batch(BATCH_SIZE, drop_remainder=True)
# Prepare the next batch while the current training step runs.
mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)

# Test split: no shuffling. The final partial batch is kept so every test
# example is evaluated.
mnist_test = tfds.load('mnist', split='test').cache()
mnist_test = mnist_test.map(preprocess).batch(BATCH_SIZE)
mnist_test = mnist_test.prefetch(tf.data.experimental.AUTOTUNE)

# Model creation

In [3]:
#@test {"output": "ignore"}

# Build CNN teacher.
# The input is declared as a single-channel 28x28 image (as in the TA model),
# so summary() and input_shape are available before the model sees data.
cnn_model = tf.keras.Sequential()

# your code start from here for step 2
cnn_model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3),
                                     strides=(1, 1), activation='relu',
                                     input_shape=(28, 28, 1)))
# Stride-1 pooling trims each spatial dimension by one instead of halving it.
cnn_model.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)))
cnn_model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3),
                                     strides=(1, 1), activation='relu'))
cnn_model.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
cnn_model.add(tf.keras.layers.Flatten())
cnn_model.add(tf.keras.layers.Dropout(0.5))
cnn_model.add(tf.keras.layers.Dense(128, activation='relu'))
cnn_model.add(tf.keras.layers.Dropout(0.5))
# No activation: the losses consume raw logits.
cnn_model.add(tf.keras.layers.Dense(NUM_CLASSES))

In [4]:
#### Intermediate teacher ("teacher assistant", TA) model
# Same layout as the teacher CNN, with half the conv filters (16/32) and
# lighter dropout (0.25). The last layer emits raw logits.
cnn_ta_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), strides=(1, 1), activation='relu',
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),
    tf.keras.layers.Conv2D(32, (3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Dense(10),
])

In [5]:
# Build fully connected student.
# The input is declared as 28x28x1 (as in the TA model), so the model is built
# immediately and summary()/input_shape work before training.
fc_model = tf.keras.Sequential()

# your code start from here for step 2
fc_model.add(tf.keras.layers.Flatten(input_shape=(28, 28, 1)))
# Two hidden layers as wide as the flattened image (28 * 28 = 784 units).
fc_model.add(tf.keras.layers.Dense(784, activation='relu'))
fc_model.add(tf.keras.layers.Dense(784, activation='relu'))
# No activation: the losses consume raw logits.
fc_model.add(tf.keras.layers.Dense(NUM_CLASSES))

# Teacher loss function

In [None]:
@tf.function
def compute_teacher_loss(images, labels):
  """Hard-label cross-entropy loss for training the teacher CNN.

  Runs `cnn_model` in training mode (dropout active) and returns the batch
  mean of the softmax cross-entropy between its logits and `labels`.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor.
  """
  teacher_logits = cnn_model(images, training=True)

  # your code start from here for step 3
  per_example_xent = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=teacher_logits)
  return tf.reduce_mean(per_example_xent)

# TA loss function

In [6]:
#@test {"output": "ignore"}

# Distillation hyperparameters (tune as needed).
# ALPHA weights the distillation term; (1 - ALPHA) weights the hard-label
# cross-entropy term in the TA/student losses.
ALPHA = 0.5
# Softmax temperature applied to both teacher and student logits.
DISTILLATION_TEMPERATURE = 4.

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Temperature-scaled distillation loss.

  Both logit tensors are divided by `temperature`. A softmax turns the
  softened teacher logits into soft targets. The function returns the batch
  mean of the softmax cross-entropy between those targets and the softened
  student logits, multiplied by temperature**2. The scaling keeps gradient
  magnitudes roughly constant as the temperature changes (Hinton et al., 2014,
  "Distilling the knowledge in a neural network").

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
  # your code start from here for step 3
  target_probs = tf.nn.softmax(teacher_logits / temperature, axis=-1)
  softened_student_logits = student_logits / temperature
  per_example_xent = tf.nn.softmax_cross_entropy_with_logits(
      labels=target_probs, logits=softened_student_logits)
  temperature_scale = temperature ** 2
  return tf.reduce_mean(per_example_xent) * temperature_scale




### TA loss
def compute_ta_loss(images, labels):
  """Knowledge-distillation loss for training the TA model from the teacher.

  Blends a distillation term (TA logits vs. temperature-softened teacher
  logits) with hard-label cross-entropy on the TA logits:

      ALPHA * distillation + (1 - ALPHA) * cross_entropy

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor.
  """
  # The TA is the model being trained, so it runs in training mode (dropout on).
  ta_subclass_logits = cnn_ta_model(images, training=True)

  # The teacher is frozen. It runs in inference mode, and stop_gradient
  # guarantees no gradient reaches it even if its variables are watched.
  teacher_subclass_logits = tf.stop_gradient(cnn_model(images, training=False))
  distillation_loss_value = distillation_loss(
      teacher_subclass_logits, ta_subclass_logits, DISTILLATION_TEMPERATURE)

  # Cross-entropy against the hard (one-hot) labels.
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=ta_subclass_logits))

  return (ALPHA * distillation_loss_value
          + (1 - ALPHA) * cross_entropy_loss_value)



# Student loss function

In [7]:
def compute_student_loss(images, labels):
  """Knowledge-distillation loss for training the FC student from the TA.

  In this teacher-assistant chain, the TA model (`cnn_ta_model`) supplies the
  soft targets, not the original teacher. The loss blends a distillation term
  (student logits vs. temperature-softened TA logits) with hard-label
  cross-entropy on the student logits:

      ALPHA * distillation + (1 - ALPHA) * cross_entropy

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor.
  """
  # The student is the model being trained, so it runs in training mode.
  student_subclass_logits = fc_model(images, training=True)

  # The TA acts as a frozen teacher here. It runs in inference mode, and
  # stop_gradient keeps any gradient from flowing back into it.
  ta_subclass_logits = tf.stop_gradient(cnn_ta_model(images, training=False))
  ta_distillation_loss_value = distillation_loss(
      ta_subclass_logits, student_subclass_logits, DISTILLATION_TEMPERATURE)

  # Cross-entropy against the hard (one-hot) labels.
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return (ALPHA * ta_distillation_loss_value
          + (1 - ALPHA) * cross_entropy_loss_value)

# Training and evaluation

In [8]:
@tf.function
def compute_num_correct(model, images, labels):
  """Compute number of correctly classified images in a batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    A tuple `(num_correct, predicted_classes, true_classes)`:
      num_correct: float32 scalar, the number of images whose argmax
        prediction equals the argmax of `labels`.
      predicted_classes: per-image argmax of the model's logits.
      true_classes: per-image argmax of `labels`.
  """
  # Inference mode: dropout is disabled during evaluation.
  class_logits = model(images, training=False)
  # Compute each argmax once and reuse it for both the count and the return.
  predicted_classes = tf.argmax(class_logits, -1)
  true_classes = tf.argmax(labels, -1)
  num_correct = tf.reduce_sum(
      tf.cast(tf.math.equal(predicted_classes, true_classes), tf.float32))
  return num_correct, predicted_classes, true_classes


def train_and_evaluate(model, compute_loss_fn, num_epochs=None,
                       learning_rate=0.001):
  """Train `model` with Adam on `mnist_train` and report test accuracy.

  After every epoch, accuracy over all of `mnist_test` is printed.

  Args:
    model: Instance of tf.keras.Model. Only its trainable variables are
      updated.
    compute_loss_fn: A function that computes the training loss given the
      images, and labels.
    num_epochs: Number of passes over the training set. Defaults to the
      value of `NUM_EPOCHS` at call time.
    learning_rate: Adam learning rate.

  Raises:
    ValueError: if `mnist_test` yields no examples.
  """
  if num_epochs is None:
    num_epochs = NUM_EPOCHS

  # your code start from here for step 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

  for epoch in range(1, num_epochs + 1):
    # Run training.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mnist_train:
      with tf.GradientTape() as tape:
        loss_value = compute_loss_fn(images, labels)
      # Only `model`'s variables receive gradients, so any other model used
      # inside compute_loss_fn (e.g. a teacher) is left unchanged.
      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Run evaluation. The denominator counts the examples actually evaluated
    # instead of relying on split metadata, so it always matches the numerator.
    num_correct = 0.0
    num_total = 0
    for images, labels in mnist_test:
      num_correct_batch, _, _ = compute_num_correct(model, images, labels)
      num_correct += float(num_correct_batch)
      num_total += int(labels.shape[0])
    if num_total == 0:
      raise ValueError('mnist_test yielded no examples to evaluate.')
    print("Class_accuracy: " + '{:.2f}%'.format(
        num_correct / num_total * 100))


# Training models

Train and evaluate the models in the following order:
*   Train the teacher (already done in the previous task; its saved weights are loaded below).
*   Train the TA as a student of the teacher.
*   Train the student as a student of the TA, with the TA acting as its teacher.






In [11]:
# Load the teacher model trained in the previous task.
# The teacher is only used for inference (soft targets and evaluation), so its
# saved compile/optimizer state is not restored.
print('Loading Teacher Model')
cnn_model = load_model('Teacher_Model_Task1.h5', compile=False)

Loading Teacher Model




In [12]:
# Distil the teacher into the TA with the current ALPHA / temperature.
print('Training and Evaluating TA WRT to teacher model at ALPHA = {} T = {}'
      .format(ALPHA, DISTILLATION_TEMPERATURE))
train_and_evaluate(cnn_ta_model, compute_ta_loss)

Training and Evaluating TA WRT to teacher model at ALPHA = 0.5 T = 4.0
Epoch 1: Class_accuracy: 97.72%
Epoch 2: Class_accuracy: 98.50%
Epoch 3: Class_accuracy: 98.75%
Epoch 4: Class_accuracy: 98.85%
Epoch 5: Class_accuracy: 98.97%
Epoch 6: Class_accuracy: 99.02%
Epoch 7: Class_accuracy: 99.06%
Epoch 8: Class_accuracy: 99.06%
Epoch 9: Class_accuracy: 99.16%
Epoch 10: Class_accuracy: 99.13%
Epoch 11: Class_accuracy: 99.14%
Epoch 12: Class_accuracy: 99.16%


In [13]:
# Distil the TA (acting as teacher) into the fully connected student.
print('Training and Evaluating Student Model WRT to TA model at ALPHA = {} T = {}'
      .format(ALPHA, DISTILLATION_TEMPERATURE))
train_and_evaluate(fc_model, compute_student_loss)

Training and Evaluating Student Model WRT to TA model at ALPHA = 0.5 T = 4.0
Epoch 1: Class_accuracy: 96.63%
Epoch 2: Class_accuracy: 97.87%
Epoch 3: Class_accuracy: 98.16%
Epoch 4: Class_accuracy: 98.45%
Epoch 5: Class_accuracy: 98.55%
Epoch 6: Class_accuracy: 98.41%
Epoch 7: Class_accuracy: 98.63%
Epoch 8: Class_accuracy: 98.62%
Epoch 9: Class_accuracy: 98.65%
Epoch 10: Class_accuracy: 98.60%
Epoch 11: Class_accuracy: 98.67%
Epoch 12: Class_accuracy: 98.70%


In [14]:
# your code start from here for step 8
## https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-849439287
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
#print('TensorFlow:', tf.__version__)
def get_flops_number(model):
  """Count multiply-accumulate operations in one forward pass of `model`.

  `model.call` is traced into a graph for a single example (batch size 1)
  using the model's own per-example input shape. TensorFlow's profiler then
  counts the float ops in that graph.

  Args:
    model: A built tf.keras.Model whose `input_shape` is known.

  Returns:
    Integer number of multiply-accumulate ops (profiler float ops // 2).
  """
  forward_pass = tf.function(
      model.call,
      input_signature=[tf.TensorSpec(shape=(1,) + model.input_shape[1:])])

  # Same float-op counting options, but with the per-op profile report
  # suppressed so only the total is used.
  options = (ProfileOptionBuilder(ProfileOptionBuilder.float_operation())
             .with_empty_output()
             .build())
  graph_info = profile(forward_pass.get_concrete_function().graph,
                       options=options)

  # The //2 is necessary since `profile` counts multiply and accumulate
  # as two flops, here we report the total number of multiply accumulate ops
  return graph_info.total_float_ops // 2

In [15]:
#### Teacher model: layer summary and forward-pass cost ####
cnn_model.summary()
teacher_flops = get_flops_number(cnn_model)
print('Teacher Flops:', format(teacher_flops, ','))

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 25, 25, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 23, 23, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 11, 11, 64)       0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 7744)              0         
                                                                 
 dropout (Dropout)           (None, 7744)              0

Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`


Teacher Flops: 11,021,029


In [16]:
#### Teacher-assistant model: layer summary and forward-pass cost ####
cnn_ta_model.summary()
ta_flops = get_flops_number(cnn_ta_model)
print('TA Flops:', format(ta_flops, ','))

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_2 (Conv2D)           (None, 26, 26, 16)        160       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 25, 25, 16)       0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 23, 23, 32)        4640      
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 11, 11, 32)       0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 3872)              0         
                                                                 
 dropout_2 (Dropout)         (None, 3872)             

In [17]:
#### Student model: layer summary and forward-pass cost ####
fc_model.summary()
student_flops = get_flops_number(fc_model)
print('Student Flops:', format(student_flops, ','))

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_2 (Flatten)         (None, 784)               0         
                                                                 
 dense_4 (Dense)             (None, 784)               615440    
                                                                 
 dense_5 (Dense)             (None, 784)               615440    
                                                                 
 dense_6 (Dense)             (None, 10)                7850      
                                                                 
Total params: 1,238,730
Trainable params: 1,238,730
Non-trainable params: 0
_________________________________________________________________
Student Flops: 1,237,941
