<a href="https://colab.research.google.com/github/nassma2019/PracticalSessions/blob/master/vision/distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Part II: [Distilling the knowledge](https://arxiv.org/pdf/1503.02531.pdf) from a (larger) teacher model
- Import an already trained baseline model to use as the teacher
- Use a smaller version of the baseline model as the student
- Add a KL distillation loss between the teacher and the student
- Train the student classifier with this joint loss

#### Exercise:
- Fill in the code for the student loss


###  The total loss for the student is:
\begin{equation}
\mathcal{L} = \mathcal{L}_{\text{classif}} + \lambda \mathcal{L}_{\text{distill}}
\end{equation}

For the classification loss we use the regular cross-entropy, and for the distillation loss we use the Kullback-Leibler (KL) divergence. $\lambda$ is a normalisation factor, explained below.


**Reminder**:

Given two distributions $t$ and $s$, we define their cross-entropy over a given set as:

$$H(t,s) = H(t) + \text{KL}(t,s),$$

where $H(t)$ is the entropy of $t$, i.e. $H(t) = -\sum_{i=1}^{N}t(x_i) \cdot \log t(x_i)$

and $\text{KL}(t,s)$ is the KL divergence between $t$ and $s$, i.e. $\text{KL}(t,s) = \sum_{i=1}^{N}t(x_i) \cdot \log \frac{t(x_i)}{s(x_i)} . $

In the cases of interest here, $t$ is constant with respect to the student parameters: it is either the ground-truth labels or the teacher predictions, which we also treat as fixed. The entropy term $H(t)$ therefore has zero gradient and can be ignored.

Hence we can use cross-entropy $H(t,s)$ for both losses: 
- the mismatch between ground truth and student predictions. 
- the mismatch between teacher and student distributions.

In the context of distillation, it is useful to also remember that the outputs of the network are logits, which we interpret as probabilities when passed through softmax:

$$p_i^{(T)} =\frac{\exp{(\text{logits}_i / T) }}{\sum_j \exp{(\text{logits}_j / T) }}. $$

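$T$ is the softmax temperature, usually set to 1. Setting it to a higher value smooths the output probability distribution. This is what we want in distillation, because the teacher's softened outputs show how it ranks the incorrect classes and not just which class it predicts. More precisely, we will use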

\begin{equation}
\mathcal{L}_{\text{distill}} = H(\text{p}_{\text{teacher}}^{(T)}, \text{p}_{\text{student}}^{(T)}).
\end{equation}
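To make the two formulas above concrete, here is a small NumPy sketch of the tempered softmax and of $\mathcal{L}_{\text{distill}}$. NumPy stands in for TensorFlow, and the logits are made up for illustration. Raising $T$ flattens the teacher's distribution. Both distributions are computed with the same $T$.

```python
import numpy as np

def log_softmax_with_temperature(logits, T=1.0):
    """Row-wise log of p_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def softmax_with_temperature(logits, T=1.0):
    return np.exp(log_softmax_with_temperature(logits, T))

def distill_loss(teacher_logits, student_logits, T):
    """Batch mean of H(p_teacher^(T), p_student^(T)) = -sum_i p_t,i * log p_s,i."""
    p_t = softmax_with_temperature(teacher_logits, T)
    log_p_s = log_softmax_with_temperature(student_logits, T)
    return float(np.mean(-np.sum(p_t * log_p_s, axis=-1)))

# Made-up logits for one example and four classes.
teacher_logits = np.array([[8.0, 2.0, 1.0, -1.0]])
student_logits = np.array([[5.0, 3.0, 0.0, 0.0]])

for T in (1.0, 5.0):
    print("T =", T,
          "teacher probs:", softmax_with_temperature(teacher_logits, T).round(3),
          "distill loss:", round(distill_loss(teacher_logits, student_logits, T), 4))
```

At $T=5$ the largest teacher probability is noticeably smaller than at $T=1$. The probability mass moves to the non-target classes.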

**The normalisation factor** 

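$\lambda$ is a normalisation factor that keeps the gradients of the two loss terms on a comparable scale. The gradient of the distillation term with respect to a student logit $z_i$ is $\frac{1}{T}\left(\text{p}_{\text{student},i}^{(T)} - \text{p}_{\text{teacher},i}^{(T)}\right)$. One factor of $\frac{1}{T}$ comes from dividing the logits by $T$. In addition, for large $T$ (and zero-mean logits) the difference between the two softened distributions itself shrinks roughly as $\frac{1}{T}$. The gradients of the distillation term therefore scale as $\frac{1}{T^2}$, and we use $$\lambda = T^2$$ to bring them back to the same scale as the classification term gradients.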
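The $\frac{1}{T^2}$ scaling can be checked numerically. The sketch below uses NumPy with made-up, zero-mean logits for a single 4-class example. It compares the analytic gradient $\frac{1}{T}(\text{p}_{\text{student}}^{(T)} - \text{p}_{\text{teacher}}^{(T)})$ against central finite differences. It then shows that $T^2$ times the gradient settles to a fixed vector as $T$ grows, which for zero-mean logits is approximately $(z_{\text{student}} - z_{\text{teacher}})/N$.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # numerical stability
    e = np.exp(z)
    return e / e.sum()

def distill_grad(teacher_logits, student_logits, T):
    """Analytic gradient of H(p_t^(T), p_s^(T)) w.r.t. the student logits:
    (1/T) * (p_s^(T) - p_t^(T)); the 1/T comes from the chain rule through z / T."""
    return (softmax(student_logits / T) - softmax(teacher_logits / T)) / T

def finite_diff_grad(teacher_logits, student_logits, T, h=1e-6):
    """Central finite differences of the distillation loss, to check distill_grad."""
    p_t = softmax(teacher_logits / T)

    def loss(zs):
        return -np.sum(p_t * np.log(softmax(zs / T)))

    g = np.zeros_like(student_logits)
    for i in range(len(student_logits)):
        d = np.zeros_like(student_logits)
        d[i] = h
        g[i] = (loss(student_logits + d) - loss(student_logits - d)) / (2 * h)
    return g

# Made-up zero-mean logits for a 4-class problem.
z_t = np.array([3.0, 1.0, -1.0, -3.0])
z_s = np.array([1.0, 2.0, -2.0, -1.0])
N = len(z_t)

for T in (1.0, 5.0, 20.0, 100.0):
    g = distill_grad(z_t, z_s, T)
    # |gradient| shrinks quickly with T, while T^2 * gradient approaches (z_s - z_t) / N.
    print("T =", T, "max |grad| =", np.abs(g).max(), "T^2 * grad =", (T ** 2 * g).round(4))
```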





### Imports

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

import tensorflow as tf

# Don't forget to select GPU runtime environment in Runtime -> Change runtime type
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# we will use Sonnet on top of TF 
!pip install -q dm-sonnet
import sonnet as snt

import numpy as np

# Plotting library.
from matplotlib import pyplot as plt
import pylab as pl
from IPython import display

In [0]:
# Reset graph
tf.reset_default_graph()

### Copy the pre-trained weights of the baseline model to the virtual machine
- We download all three checkpoint files (data, index, and `checkpoint`) to the Colab virtual machine:

In [0]:
!wget https://github.com/nassma2019/PracticalSessions/blob/master/vision/baseline/baseline.ckpt.data-00000-of-00001?raw=true -O baseline.ckpt.data-00000-of-00001
!wget https://github.com/nassma2019/PracticalSessions/blob/master/vision/baseline/baseline.ckpt.index?raw=true -O baseline.ckpt.index
!wget https://github.com/nassma2019/PracticalSessions/blob/master/vision/baseline/checkpoint?raw=true -O checkpoint

In [0]:
#@title (optional, if the cell above does not work)
# Uncomment the `upload_to_colab()` call below and upload the three checkpoint files from your computer instead.

def upload_to_colab():
  from google.colab import files

  uploaded = files.upload()

  for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))
    

# upload_to_colab()

### Download the dataset used for training and testing
- CIFAR-10: the equivalent of MNIST for natural RGB images
- 60000 32x32 colour images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- train: 50000; test: 10000

In [0]:
cifar10 = tf.keras.datasets.cifar10
# (down)load dataset
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

In [0]:
#@title Prepare the data for training and testing
# define the size of the batches sampled from the datasets
BATCH_SIZE_TRAIN = 64 #@param
BATCH_SIZE_TEST = 100 #@param

# create Dataset objects using the data previously downloaded
dataset_train = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
# we shuffle the data and sample repeatedly batches for training
batched_dataset_train = dataset_train.shuffle(100000).repeat().batch(BATCH_SIZE_TRAIN)
# create iterator to retrieve batches
iterator_train = batched_dataset_train.make_one_shot_iterator()
# get a training batch of images and labels
(batch_train_images, batch_train_labels) = iterator_train.get_next()

# we do the same for test dataset
dataset_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
batched_dataset_test = dataset_test.repeat().batch(BATCH_SIZE_TEST)
iterator_test = batched_dataset_test.make_one_shot_iterator() 
(batch_test_images, batch_test_labels) = iterator_test.get_next()

# Squeeze labels and convert from uint8 to int32 - required below by the loss op
batch_test_labels = tf.cast(tf.squeeze(batch_test_labels), tf.int32)
batch_train_labels = tf.cast(tf.squeeze(batch_train_labels), tf.int32)

In [0]:
#@title Preprocessing of data
# Data augmentation used for train preprocessing
# - scale image to [-1, 1]
# - get a random crop
# - apply horizontal flip randomly

def train_image_preprocess(h, w, random_flip=True):
  """Image processing required for training the model."""
  
  def random_flip_left_right(image, flip_index, seed=None):
    shape = image.get_shape()
    if shape.ndims == 3 or shape.ndims is None:
      uniform_random = tf.random_uniform([], 0, 1.0, seed=seed)
      mirror_cond = tf.less(uniform_random, .5)
      result = tf.cond(
          mirror_cond,
          lambda: tf.reverse(image, [flip_index]),
          lambda: image
      )
      return result
    elif shape.ndims == 4:
      uniform_random = tf.random_uniform(
          [tf.shape(image)[0]], 0, 1.0, seed=seed
      )
      mirror_cond = tf.less(uniform_random, .5)
      return tf.where(
          mirror_cond,
          image,
          tf.map_fn(lambda x: tf.reverse(x, [flip_index]), image, dtype=image.dtype)
      )
    else:
      raise ValueError("\'image\' must have either 3 or 4 dimensions.")

  def fn(image):
    # Ensure the data is in range [-1, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2.0 - 1.0
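    # Randomly choose a (h, w, 3) patch. tf.random_crop draws a single offset,
    # so every image in the batch is cropped at the same location.
    image = tf.random_crop(image, size=(BATCH_SIZE_TRAIN, h, w, 3))
    # Randomly flip each image left-right. Inside random_flip_left_right the
    # reversal is applied per image of shape (h, w, 3), whose width axis is 1
    # (axis 2 would reverse the colour channels instead).
    image = random_flip_left_right(image, 1)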
    return image

  return fn

# Test preprocessing: only scale to [-1,1].
def test_image_preprocess():
  def fn(image):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2.0 - 1.0
    return image
  return fn
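The preprocessing above can be sketched in NumPy for a single image. This is an illustrative sketch, not the notebook's TF code. `tf.image.convert_image_dtype` maps `uint8` pixels to `[0, 1]` by dividing by 255, and `x * 2 - 1` then maps them to `[-1, 1]`. For a single `(h, w, c)` image, a horizontal flip reverses axis 1 (the width axis). The helper names here are hypothetical.

```python
import numpy as np

def scale_to_unit_range(image_uint8):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1]
    (divide by 255 as convert_image_dtype does, then x * 2 - 1)."""
    return image_uint8.astype(np.float32) / 255.0 * 2.0 - 1.0

def random_horizontal_flip(image, rng):
    """Flip a single (h, w, c) image left-right with probability 0.5.
    The width axis of a single image is axis 1."""
    if rng.uniform() < 0.5:
        return image[:, ::-1, :]
    return image

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)  # fake 32x32 RGB image
x = scale_to_unit_range(img)
flipped = random_horizontal_flip(img, rng)
print(x.min(), x.max(), flipped.shape)
```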

In [0]:
#@title Teacher model is the baseline
class Baseline(snt.AbstractModule):
  
  def __init__(self, num_classes, name="baseline"):
    super(Baseline, self).__init__(name=name)
    self._num_classes = num_classes
    self._output_channels = [
        64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512
        ]
    self._num_layers = len(self._output_channels)

    self._kernel_shapes = [[3, 3]] * self._num_layers  # All kernels are 3x3.
    self._strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]
    self._paddings = [snt.SAME] * self._num_layers
   
  def _build(self, inputs, is_training=None, test_local_stats=False):
    net = inputs
    # instantiate all the convolutional layers 
    # and connect them to the graph, adding batch norm and non-linearity
    for i in range(self._num_layers):
      layer = snt.Conv2D(name="conv_2d_{}".format(i),
                         output_channels=self._output_channels[i],
                         kernel_shape=self._kernel_shapes[i],
                         stride=self._strides[i],
                         padding=self._paddings[i],
                         use_bias=True)
      net = layer(net)
      bn = snt.BatchNorm(name="batch_norm_{}".format(i))
      net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
      net = tf.nn.relu(net)

    net = tf.reduce_mean(net, reduction_indices=[1, 2], keepdims=False,
                         name="avg_pool")

    logits = snt.Linear(self._num_classes)(net)

    return logits

In [0]:
#@title Student model is a smaller version of the baseline
class Student(snt.AbstractModule):
  
  def __init__(self, num_classes, name="student"):
    super(Student, self).__init__(name=name)
    self._num_classes = num_classes
    self._output_channels = [
        64, 128, 256, 512
        ]
    self._num_layers = len(self._output_channels)

    self._kernel_shapes = [[3, 3]] * self._num_layers  # All kernels are 3x3.
    self._strides = [1, 1, 2, 1]  # one stride per layer
    self._paddings = [snt.SAME] * self._num_layers
   
  def _build(self, inputs, is_training=None, test_local_stats=False):
    net = inputs
    # instantiate all the convolutional layers 
    # and connect them to the graph, adding batch norm and non-linearity
    for i in range(self._num_layers):
      layer = snt.Conv2D(name="conv_2d_{}".format(i),
                         output_channels=self._output_channels[i],
                         kernel_shape=self._kernel_shapes[i],
                         stride=self._strides[i],
                         padding=self._paddings[i],
                         use_bias=True)
      net = layer(net)
      bn = snt.BatchNorm(name="batch_norm_{}".format(i))
      net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
      net = tf.nn.relu(net)

    net = tf.reduce_mean(net, reduction_indices=[1, 2], keepdims=False,
                         name="avg_pool")

    logits = snt.Linear(self._num_classes)(net)

    return logits

### Model params

In [0]:
# First define the preprocessing ops for the train/test data
crop_height = 24 #@param
crop_width = 24 #@param
preprocess_fn_train = train_image_preprocess(crop_height, crop_width)
preprocess_fn_test = test_image_preprocess()

num_classes = 10 #@param

In [0]:
# for evaluation, we look at top_k_accuracy since it's easier to interpret; normally k=1 or k=5
def top_k_accuracy(k, labels, logits):
  in_top_k = tf.nn.in_top_k(predictions=tf.squeeze(logits), targets=tf.squeeze(tf.cast(labels, tf.int32)), k=k)
  return tf.reduce_mean(tf.cast(in_top_k, tf.float32))
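For reference, here is a NumPy sketch of the same top-k accuracy metric on made-up logits. There is one difference in tie handling. This sketch breaks ties at the k-th position by `argsort` order, whereas `tf.nn.in_top_k` counts all tied classes as being in the top k.

```python
import numpy as np

def top_k_accuracy_np(k, labels, logits):
    """Fraction of rows whose true label is among the k largest logits."""
    labels = np.asarray(labels).reshape(-1)
    logits = np.asarray(logits)
    top_k = np.argsort(-logits, axis=1)[:, :k]  # indices of the k largest logits per row
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Made-up logits for 3 examples and 3 classes.
logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 0.9],
                   [0.0, 0.4, 0.8]])
labels = np.array([1, 2, 0])
print(top_k_accuracy_np(1, labels, logits), top_k_accuracy_np(2, labels, logits))
```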

### Instantiate teacher and load pre-trained weights


In [0]:
with tf.variable_scope("teacher"):
  teacher_model = Baseline(num_classes)
predictions_teacher = teacher_model(preprocess_fn_train(batch_train_images), is_training=False)


### We do not want to alter the teacher weights, so apply `tf.stop_gradient` to `predictions_teacher`

In [0]:
#@title EXERCISE.
# predictions_teacher = ############# YOUR CODE HERE #############

In [0]:
#@title SOLUTION.
predictions_teacher = tf.stop_gradient(predictions_teacher)

### Instantiate student

In [0]:
with tf.variable_scope("student"):
  student_model = Student(num_classes=num_classes)
# get predictions from the model
predictions_student = student_model(preprocess_fn_train(batch_train_images), is_training=True)
test_predictions_student = student_model(preprocess_fn_test(batch_test_images), is_training=False)

In [0]:
#@title Compare number of parameters between teacher and student
def get_num_params(scope):
  total_parameters = 0
  for variable in tf.trainable_variables(scope):
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
      variable_parameters *= dim.value
    total_parameters += variable_parameters
  return total_parameters

print ('Number of parameters of teacher')
print (get_num_params('teacher'))
print ('Number of parameters of student')
print (get_num_params('student'))
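As a rough cross-check of the printed numbers, the sketch below counts only the weights and biases of the 3x3 convolutions and the final linear layer. It uses the channel lists of the two classes above, an RGB input, and 10 classes. Batch-norm variables are deliberately not counted, so the result is a lower bound on what `get_num_params` reports.

```python
def conv_linear_param_count(output_channels, num_classes=10, in_channels=3, kernel=3):
    """Weights + biases of the kxk conv layers and the final linear layer.
    Batch-norm variables are NOT included."""
    total = 0
    c_in = in_channels
    for c_out in output_channels:
        total += kernel * kernel * c_in * c_out + c_out  # conv kernel + bias
        c_in = c_out
    total += c_in * num_classes + num_classes  # linear weights + bias
    return total

teacher_channels = [64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512]
student_channels = [64, 128, 256, 512]

print("teacher (conv + linear only):", conv_linear_param_count(teacher_channels))
print("student (conv + linear only):", conv_linear_param_count(student_channels))
```

Under this count the student has roughly a fifth of the teacher's conv and linear parameters.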

### Set up the training for the student, adding the distillation loss weighted by the square of the temperature, as explained above.

Normally we use T = 1, but for distillation we use T > 1, e.g. T = 5. We will visualise the impact of T on the output probabilities later.

In [0]:
T_distill = 5.0 #@param 
T_normal = 1.0 #@param

#### First define the regular cross-entropy classification loss

In [0]:
def get_cross_entropy_loss(logits=None, labels=None):
  # We reduce over batch dimension, to ensure the loss is a scalar. 
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=tf.one_hot(tf.squeeze(labels), num_classes), logits=logits))
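The classification loss above is the batch mean of $-\log \text{softmax}(\text{logits})_{\text{label}}$. A NumPy sketch of the same computation (illustrative only; the function name is hypothetical) is:

```python
import numpy as np

def cross_entropy_loss_np(logits, labels, num_classes):
    """Batch mean of the cross-entropy between one-hot labels and softmax(logits)."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels).reshape(-1)
    one_hot = np.eye(num_classes)[labels]
    z = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(np.mean(-np.sum(one_hot * log_probs, axis=1)))

# With all-zero logits every class has probability 1/10, so the loss is log(10).
print(cross_entropy_loss_np(np.zeros((4, 10)), [0, 1, 2, 3], 10))
```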

#### Define the distillation loss

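You can do this either with

* `tf.distributions.kl_divergence` between the teacher and student distributions, both computed from temperature-scaled logits; or 
* `tf.nn.softmax_cross_entropy_with_logits_v2`. In this case the labels are expected to be a probability distribution (each row sums to 1), while the output of the teacher network is logits. So we need to apply `softmax` to the temperature-scaled `predictions_teacher`. Note that the `_v2` op also backpropagates into `labels`, which is one more reason to wrap the teacher predictions in `tf.stop_gradient`.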
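The two options are interchangeable for training. Per example, $H(t,s) = H(t) + \text{KL}(t,s)$, and the teacher entropy $H(t)$ does not depend on the student. The NumPy sketch below checks both facts on random, made-up logits at $T=5$. The losses differ exactly by the mean teacher entropy, and their finite-difference gradients with respect to the student logits coincide.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_loss(teacher_logits, student_logits, T):
    """Batch mean of KL(p_t^(T) || p_s^(T)) = sum p_t * (log p_t - log p_s)."""
    log_pt, log_ps = log_softmax(teacher_logits / T), log_softmax(student_logits / T)
    return np.mean(np.sum(np.exp(log_pt) * (log_pt - log_ps), axis=-1))

def ce_loss(teacher_logits, student_logits, T):
    """Batch mean of H(p_t^(T), p_s^(T)) = -sum p_t * log p_s."""
    log_pt, log_ps = log_softmax(teacher_logits / T), log_softmax(student_logits / T)
    return np.mean(-np.sum(np.exp(log_pt) * log_ps, axis=-1))

def teacher_entropy(teacher_logits, T):
    log_pt = log_softmax(teacher_logits / T)
    return np.mean(-np.sum(np.exp(log_pt) * log_pt, axis=-1))

def numerical_grad(loss_fn, teacher_logits, student_logits, T, h=1e-6):
    """Central finite differences of loss_fn w.r.t. the student logits."""
    g = np.zeros_like(student_logits)
    for idx in np.ndindex(*student_logits.shape):
        d = np.zeros_like(student_logits)
        d[idx] = h
        g[idx] = (loss_fn(teacher_logits, student_logits + d, T)
                  - loss_fn(teacher_logits, student_logits - d, T)) / (2 * h)
    return g

rng = np.random.default_rng(0)
zt = rng.normal(size=(8, 10))  # made-up teacher logits: batch of 8, 10 classes
zs = rng.normal(size=(8, 10))  # made-up student logits
T = 5.0

print("CE - KL:", ce_loss(zt, zs, T) - kl_loss(zt, zs, T))
print("teacher entropy:", teacher_entropy(zt, T))
```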


In [0]:
#@title EXERCISE.
# Using tf.distributions.kl_divergence
#########################
#                       #
# YOUR CODE             #
# distill_kl_loss = ... #         
#########################

# OR simpler, using tf.nn.softmax_cross_entropy_with_logits
#########################
#                       #
# YOUR CODE             #
# distill_kl_loss = ... #         
#########################

In [0]:
#@title SOLUTION.

# Using tf.distributions.kl_divergence (both distributions use temperature T)
# pp = tf.distributions.Categorical(logits=predictions_teacher / T_distill)
# qq = tf.distributions.Categorical(logits=predictions_student / T_distill)

# distill_kl_loss = tf.reduce_mean(tf.distributions.kl_divergence(pp, qq))

# OR simpler, using cross entropy
scaled_predictions_teacher = tf.div(predictions_teacher, T_distill)
scaled_predictions_student = tf.div(predictions_student, T_distill)

distill_kl_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.nn.softmax(scaled_predictions_teacher),
    logits=scaled_predictions_student))

**Define the joint training loss**

In [0]:
#@title EXERCISE.
# lambda_ = ############## YOUR CODE ##############
# train_loss = get_cross_entropy_loss(logits=predictions_student, labels=batch_train_labels)
# train_loss += lambda_ * distill_kl_loss

In [0]:
#@title SOLUTION.
lambda_ = T_distill * T_distill
train_loss = get_cross_entropy_loss(logits=predictions_student, labels=batch_train_labels)
train_loss += lambda_ * distill_kl_loss

In [0]:
#@title Set up the training; it is better to start with a lower learning rate and a longer training schedule
def get_optimizer(step):
  """Get the optimizer used for training."""
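  # Step boundaries at which the learning rate is multiplied by lr_factor.
  lr_schedule = (80e3, 100e3, 110e3)
  lr_schedule = tf.to_int64(lr_schedule)
  lr_factor = 0.1
  
  lr_init = 0.01
  # Number of boundaries already reached (0 to 3); this counts decays, not epochs.
  num_decays = tf.reduce_sum(tf.to_float(step >= lr_schedule))
  lr = lr_init * lr_factor**num_decays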

  return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)

# Create a global step that is incremented during training; useful for e.g. learning rate annealing
global_step = tf.train.get_or_create_global_step()

# instantiate the optimizer
optimizer = get_optimizer(global_step)

# Get training ops, including BatchNorm update ops
training_op = optimizer.minimize(train_loss, global_step)
update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
training_op = tf.group(training_op, update_ops)

# Display loss function
def plot_losses(loss_list, steps):
  display.clear_output(wait=True)
  display.display(pl.gcf())
  pl.plot(steps, loss_list, c='b')
  time.sleep(1.0)
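The learning-rate schedule in `get_optimizer` is piecewise constant. The rate starts at `lr_init` and is multiplied by `lr_factor` once for every boundary the global step has reached. A plain-Python sketch with the same default values is:

```python
def learning_rate(step, lr_init=0.01, lr_factor=0.1,
                  boundaries=(80000, 100000, 110000)):
    """Piecewise-constant schedule: one decay by lr_factor per boundary reached."""
    num_decays = sum(step >= b for b in boundaries)
    return lr_init * lr_factor ** num_decays

for step in (0, 80000, 100000, 110000, 119999):
    print(step, learning_rate(step))
```

With 120k training iterations, the last 10k steps therefore run at a learning rate of about 1e-5.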

### Teacher and student accuracy

In [0]:
test_acc = top_k_accuracy(1, batch_test_labels, test_predictions_student)

# We compute the accuracy of the teacher on the train set to make sure that
# the loading of the pre-trained weights was successful
acc_teacher = top_k_accuracy(1, batch_train_labels, predictions_teacher) 

### Define ops to visualise the impact of softmax temperature on output distributions

In [0]:
probs_high_temp = tf.nn.softmax(tf.div(predictions_teacher, T_distill)) 
probs_low_temp = tf.nn.softmax(tf.div(predictions_teacher, T_normal))

### Create the session 

In [0]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

### Load the pre-trained weights for the teacher, and check its accuracy to make sure the import was successful

In [0]:
# Create a saver to restore the pre-trained teacher. The checkpoint stores the
# variables without the "teacher/" scope prefix (and without the ":0" output
# suffix), so we strip both from the graph variable names.
var_list = snt.get_variables_in_scope("teacher", collection=tf.GraphKeys.GLOBAL_VARIABLES)  
var_map = {}
for i in range(0, len(var_list)):
  name = var_list[i].name[len("teacher/"):-2]
  var_map[name] = var_list[i]
  
saver = tf.train.Saver(var_map, reshape=True)
saver.restore(sess, "baseline.ckpt")

num_batches = 100  # 100 batches * 64 samples per batch = 6400 out of 50000
avg_accuracy = 0.0
for _ in range(num_batches):
  accuracy = sess.run(acc_teacher)
  avg_accuracy += accuracy
avg_accuracy /= num_batches

# expected_accuracy > 0.90
print ("Teacher accuracy on a subset of the train set {:.3f}".format(avg_accuracy))
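The name mapping above uses `name[len("teacher/"):-2]`, which assumes every variable name ends in exactly `":0"`. The sketch below is a slightly more defensive version of the same idea. It checks the prefix and drops whatever follows the last `":"`. The example variable name is hypothetical and only illustrates the shape of Sonnet-style names.

```python
def strip_scope(var_name, scope="teacher"):
    """Turn a graph variable name like 'teacher/baseline/conv_2d_0/w:0'
    into the checkpoint key 'baseline/conv_2d_0/w' by removing the scope
    prefix and the ':<output index>' suffix."""
    prefix = scope + "/"
    if not var_name.startswith(prefix):
        raise ValueError("{} is not in scope {}".format(var_name, scope))
    return var_name[len(prefix):].rsplit(":", 1)[0]

print(strip_scope("teacher/baseline/conv_2d_0/w:0"))  # hypothetical variable name
```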

### Visualise the impact of temperature on the output probabilities

In [0]:
probs_ht, probs_lt, ground_truth = sess.run([probs_high_temp, probs_low_temp, tf.one_hot(batch_train_labels, num_classes)])
# pick one sample and plot
idx = 10
# Optionally: display the ground truth.
# plt.plot(ground_truth[idx], 'b--', label='Ground truth')
plt.plot(probs_ht[idx], c='r', label='High Temp')
plt.plot(probs_lt[idx], c='g', label='Low Temp')
plt.xlim([0,9])
plt.legend()
plt.show()

### Train the model. Full training gives ~92% accuracy.

If you run out of memory, reduce `BATCH_SIZE_TRAIN`, e.g. to 32 or 16.

Note that training is slower and needs more memory than training the baseline alone, since every student training iteration also runs the teacher's forward pass.

In [0]:
#@title Training.

# Define number of training iterations and reporting intervals
TRAIN_ITERS = 120e3 #@param
REPORT_TRAIN_EVERY = 100 #@param
PLOT_EVERY = 500 #@param
REPORT_TEST_EVERY = 1000 #@param
TEST_ITERS = 50 #@param

train_iter = 0
losses = []
steps = []
for train_iter in range(int(TRAIN_ITERS)):
  _, train_loss_np = sess.run([training_op, train_loss])
  
  if (train_iter % REPORT_TRAIN_EVERY) == 0:
    losses.append(train_loss_np)
    steps.append(train_iter)
  if (train_iter % PLOT_EVERY) == 0:
    plot_losses(losses, steps)    
    
  if (train_iter % REPORT_TEST_EVERY) == 0:
    avg_acc = 0.0
    for test_iter in range(TEST_ITERS):
      acc = sess.run(test_acc)
      avg_acc += acc
      
    avg_acc /= (TEST_ITERS)
    print ('Test acc at iter {0:5d} out of {1:5d} is {2:.2f}%'.format(int(train_iter), int(TRAIN_ITERS), avg_acc*100.0))

In [0]:
print("Final accuracy: %.4f" % avg_acc)