# Knowledge distillation

## Purpose
This notebook tests knowledge distillation from Inception FaceNet (the teacher) to different versions of MobileNet (the students).

# Setup

## Library import
We import all the required Python libraries.

In [1]:
# Core
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T

# Data manipulation
import numpy as np
import pandas as pd

# Data visualization
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
sns.set_theme(context='notebook', style='darkgrid', palette='pastel')

# General
import math
import os
import random
from glob import glob
from tqdm import tqdm

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload

%autoreload 2

## Repository import

For these experiments, we need to import FaceNet from [timesler/facenet-pytorch](https://github.com/timesler/facenet-pytorch): we will use it as the teacher network.

In [None]:
%%bash
git clone https://github.com/timesler/facenet-pytorch facenet_pytorch

In the preliminary tests, we focus on the MobileNet implementation available in `torchvision.models`. Later, we will most likely import other repositories to test knowledge distillation on different versions of MobileNet.

## Parameter definition

We set all the relevant parameters for our notebook.

In [2]:
## General
# Set the device to a GPU if available.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Running on device: {DEVICE}')
# Set the batch size.
BATCH_SIZE = 16
print(f'Batch size: {BATCH_SIZE}')

## InceptionResnetV1
# Defines whether to load a model pre-trained on `vggface2` or on
# `casia-webface`. The first one performs a little better (0.9965 on LFW
# compared to 0.9905) and is slightly smaller (107MB compared to 111MB).
PRETRAINED = "vggface2"

## MTCNN
# Define the output image size and margin of the MTCNN module.
IMAGE_SIZE = 160
MARGIN = 0


Running on device: cuda:0
Batch size: 16


# Data

## Data import
To train the student, we need the same training set that was used for the teacher. Therefore, we retrieve the VGGFace2 dataset.

**Note**: to retrieve the VGGFace2 dataset, you need an account on [Zeus Robots](http://zeus.robots.ox.ac.uk/vgg_face2/).

Download all four files available on the dataset page. The two archives contain the train and test images, while the two `.txt` files contain the lists of paths to the train and test images.
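The `.txt` lists contain one relative image path per line. The `TripletDataset` class below assumes paths of the form `<class_id>/<image_id>_<face_id>.jpg`. A minimal pure-Python sketch of parsing such a list and grouping the images by identity follows; the helper names and the example paths are illustrative assumptions, not lines copied from the actual files.

```python
from collections import defaultdict


def parse_image_path(image_path):
    """Split '<class_id>/<image_id>_<face_id>.jpg' into its three ids."""
    class_id, file_name = image_path.split('/')
    stem, _extension = file_name.rsplit('.', 1)
    image_id, face_id = stem.split('_')
    return class_id, image_id, face_id


def group_by_class(lines):
    """Map each class id to the list of its image paths, skipping blank lines."""
    images_per_class = defaultdict(list)
    for line in lines:
        path = line.strip()
        if path:
            class_id, _, _ = parse_image_path(path)
            images_per_class[class_id].append(path)
    return dict(images_per_class)


# Hypothetical lines in the format assumed above.
example_lines = ['n000002/0001_01.jpg\n', 'n000002/0002_01.jpg\n', 'n000003/0001_01.jpg\n']
print(parse_image_path(example_lines[0].strip()))
print({class_id: len(paths) for class_id, paths in group_by_class(example_lines).items()})
```

Grouping by class up front is what makes it cheap to pick "two images of the same identity" later, instead of filtering the whole list for every triplet.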

## Data visualization

To visualize the data, we define two helper functions.

In [3]:
def show_image(tensor, denormalize=False):
    """Convert a (C, H, W) Tensor to a PIL image.

    If denormalize is True, the ImageNet normalization is undone first.
    """
    if denormalize:
        transforms = T.Compose([
            # Multiply by the ImageNet std.
            T.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
            # Add back the ImageNet mean.
            T.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
            T.ToPILImage(),
        ])
    else:
        transforms = T.ToPILImage()

    return transforms(tensor)


def show_images(tensors):
    """
    Plot a list of Tensors as images in a square grid.

    Note: if a sample from a TripletDataset is passed, its anchor, positive
    and negative images are extracted automatically.
    """
    triplet_keys = {'anchor_image', 'positive_image', 'negative_image'}
    if isinstance(tensors, dict) and triplet_keys <= tensors.keys():
        tensors = [tensors['anchor_image'], tensors['positive_image'], tensors['negative_image']]

    size = math.ceil(math.sqrt(len(tensors)))

    # Create a new figure and set the vertical space between subplot rows.
    fig = plt.figure(figsize=(size * 4, size * 4))
    plt.subplots_adjust(hspace=0.3)

    for i, tensor in enumerate(tensors, start=1):
        # Add the image as a subplot and hide its axes.
        fig.add_subplot(size, size, i)
        plt.imshow(show_image(tensor))
        plt.axis('off')

    # Show the plot.
    plt.show()
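With `denormalize=True`, `show_image` inverts the ImageNet normalization x' = (x - mean) / std by chaining two `T.Normalize` calls: the first (mean 0, std 1/std) multiplies by std, the second (mean -mean, std 1) adds the mean back. The NumPy sketch below reproduces that arithmetic (it mirrors `T.Normalize` rather than calling it) and checks the round trip on a random image.

```python
import numpy as np

# ImageNet statistics, as used by show_image(denormalize=True).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize(x, mean, std):
    """Channel-wise (x - mean) / std for a (C, H, W) array, like T.Normalize."""
    return (x - mean[:, None, None]) / std[:, None, None]


def denormalize(x):
    """Undo ImageNet normalization with the same two steps as show_image."""
    # Step 1: mean 0, std 1/STD, which computes x * STD.
    x = normalize(x, np.zeros(3), 1.0 / STD)
    # Step 2: mean -MEAN, std 1, which computes x + MEAN.
    return normalize(x, -MEAN, np.ones(3))


rng = np.random.default_rng(0)
image = rng.random((3, 4, 4))  # values in [0, 1), like T.ToTensor() output
restored = denormalize(normalize(image, MEAN, STD))
print(np.allclose(restored, image))  # True
```

This inversion only makes sense for tensors that were actually normalized with the ImageNet statistics.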


## Data process

Images must be aligned before being passed to the model. Therefore, we define a function that takes either a single tensor (representing an image) or a list of tensors and returns the ones that MTCNN is able to align.

In [4]:
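from facenet_pytorch import MTCNN

# Face detector/aligner used by align_images.
mtcnn = MTCNN(image_size=IMAGE_SIZE, margin=MARGIN, device=DEVICE)


def align_images(tensors):
    """Align a single image or a list of images, dropping those without a detected face."""
    if not isinstance(tensors, list):
        tensors = [tensors]

    aligned_tensors = []
    for tensor in tensors:
        try:
            aligned = mtcnn(tensor)
        except Exception:
            # Skip images MTCNN cannot process.
            continue
        # MTCNN returns None when it does not detect a face.
        if aligned is not None:
            aligned_tensors.append(aligned)

    if len(aligned_tensors) == 1:
        return aligned_tensors[0]

    return aligned_tensors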
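With its default `post_process=True`, facenet-pytorch's `MTCNN` returns face crops standardized as (x - 127.5) / 128 on raw 0-255 pixel values, which is the input range the pretrained `InceptionResnetV1` expects. Crops saved to disk and reloaded with `T.ToTensor()` are instead in [0, 1]. The NumPy sketch below shows both ranges and the conversion between them; the helper names are hypothetical.

```python
import numpy as np


def fixed_standardization(pixels):
    """(x - 127.5) / 128 on raw 0-255 pixel values, as in MTCNN's post-processing."""
    return (np.asarray(pixels, dtype=np.float64) - 127.5) / 128.0


def to_tensor_range(pixels):
    """Scale raw 0-255 pixel values to [0, 1], as T.ToTensor() does for 8-bit images."""
    return np.asarray(pixels, dtype=np.float64) / 255.0


def standardize_from_unit_range(x):
    """Recover the MTCNN-style standardization from a [0, 1] tensor."""
    return (x * 255.0 - 127.5) / 128.0


raw = np.array([0, 127.5, 255])
print(fixed_standardization(raw))
print(np.allclose(standardize_from_unit_range(to_tensor_range(raw)), fixed_standardization(raw)))  # True
```

The standardized values lie in roughly [-0.996, 0.996], so mixing freshly aligned and reloaded crops without this conversion would feed the teacher two different input ranges.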

    

## Dataset definition
Below, we define a Dataset class for the VGGFace2 dataset. Each sample is a triplet made of an anchor image, a positive image of the same identity and a negative image of a different identity.

**NOTE**: multiprocessing does not work in interactive mode. Therefore, to generate the triplet dataset, copy everything that is needed into a script and run it from there.

In [5]:
import multiprocessing

from facenet_pytorch import MTCNN

class TripletDataset(Dataset):
    """Dataset composed of triplets."""

    def __init__(self, root_dir, split, list_path, num_triplets=None,
                 num_processes=0, triplets_path=None):
        """
            Arguments:
                root_dir (string): root directory of the images.
                split (string): either train or test.
                list_path (string): path to the file containing all the names of the images.
                num_triplets (int): number of triplets to generate.
                num_processes (int): number of processes to use to generate triplets.
                triplets_path (string): path to the numpy file containing the triplets.
                    Instead of generating a new list, it's possible to pass a previously
                    generated one.
        """
        if num_triplets is None and triplets_path is None:
            raise ValueError('num_triplets and triplets_path cannot both be None.')

        self.df = self._extract_dataframe(list_path)
        self.root_dir = root_dir
        self.split = split
        self._mtcnn = MTCNN(image_size=IMAGE_SIZE, margin=MARGIN, device=DEVICE)

        if num_processes == 0:
            num_processes = os.cpu_count()

        if triplets_path is None:
            # generate_triplets() stores the generated list in self.triplets.
            self.generate_triplets(num_triplets, num_processes)
        else:
            self.triplets = np.load(triplets_path, allow_pickle=True)
    

    def __getitem__(self, idx):
        anchor_path, positive_path, negative_path, positive_class, negative_class = self.triplets[idx]

        # Open the images.
        anchor, positive, negative = self._get_aligned_images([anchor_path, positive_path, negative_path])

        sample = {
            'anchor_image': anchor,
            'positive_image': positive,
            'negative_image': negative,
            'positive_class': positive_class,
            'negative_class': negative_class
        }

        return sample


    def __len__(self):
        return len(self.triplets)
    

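    def _get_aligned_images(self, paths):
        """
        Get one or more image paths of the dataset and return their aligned tensors.
        If the aligned version of an image does not exist yet, create and save it.
        Images in which MTCNN detects no face are skipped.
        """
        single = not isinstance(paths, list)
        if single:
            paths = [paths]

        aligned_tensors = []
        for path in paths:
            aligned_path = os.path.join(self.root_dir, self.split + '_aligned', path)
            normal_path = os.path.join(self.root_dir, self.split, path)

            if os.path.exists(aligned_path):
                # Saved crops hold raw 0-255 pixels: apply the same
                # (x - 127.5) / 128 standardization that MTCNN applies to its output.
                img = Image.open(aligned_path).convert('RGB')
                tensor = (T.ToTensor()(img) * 255 - 127.5) / 128
            else:
                img = Image.open(normal_path).convert('RGB')
                try:
                    tensor = self._mtcnn(img, save_path=aligned_path)
                except Exception:
                    continue
                # MTCNN returns None when no face is detected.
                if tensor is None:
                    continue

            if self.split == 'train':
                tensor = T.RandomHorizontalFlip()(tensor)

            aligned_tensors.append(tensor)

        # Unwrap only when a single path was passed, so that for a list of paths
        # len() is always the number of aligned images.
        if single:
            return aligned_tensors[0] if aligned_tensors else None

        return aligned_tensors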

    
    def _get_image_annotations(self, image_path):
        """Split the pathname to an image into its ids."""
        class_id, other = image_path.split('/')
        image_id, other = other.split('_')
        face_id, _ = other.split('.')
        
        return [class_id, image_id, face_id, image_path]


    def _extract_dataframe(self, list_path):
        """Given the file containing the paths to the images, split each path into its ids."""
        with open(list_path) as f:
            images = [line.strip() for line in f if line.strip()]
        annotations = [self._get_image_annotations(image) for image in images]

        return pd.DataFrame(data=annotations, columns=['class_id', 'image_id', 'face_id', 'path'])


    def _generate_triplets(self, num_triplets, process_id):
        classes = self.df['class_id'].unique()
        triplets = []

        progress_bar = tqdm(range(int(num_triplets)))

        for _ in progress_bar:
            # Select a positive and a negative class.
            pos_class = random.choice(classes)
            neg_class = random.choice(classes)
            
            # Ensure that the positive class has at least two images.
            while len(self.df.loc[self.df['class_id'] == pos_class]) < 2:
                pos_class = random.choice(classes)

            # Ensure that the classes are different.
            while neg_class == pos_class:
                neg_class = random.choice(classes)
            
            # Get the list of images of the classes.
            pos_images = self.df.loc[self.df['class_id'] == pos_class].values
            neg_images = self.df.loc[self.df['class_id'] == neg_class].values

            # Select an anchor and a positive image.
            anchor_index = random.randint(0, len(pos_images)-1)
            positive_index = random.randint(0, len(pos_images)-1)

            # Ensure that the positive image is not equal to the anchor.
            while anchor_index == positive_index:
                positive_index = random.randint(0, len(pos_images)-1)
            
            # Select a negative image.
            negative = random.choice(neg_images)

            # Retrieve the paths to the images
            anchor_path = pos_images[anchor_index][-1]
            positive_path = pos_images[positive_index][-1]
            negative_path = negative[-1]

            # Align the three images.
            aligned_images = self._get_aligned_images([anchor_path, positive_path, negative_path])

            # If the number of aligned images is 3, MTCNN was able to align all of them,
            # therefore we save the triplet.
            if len(aligned_images) == 3:
                triplet = [anchor_path,
                        positive_path,
                        negative_path,
                        pos_class,
                        neg_class]
                triplets.append(triplet)
        
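        # Update the number of triplets.
        self.num_triplets = len(triplets)

        # Save the list to a temporary file (one per process).
        temp_dir = os.path.join(self.root_dir, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f'{self.split}_triplets_{process_id}.npy')
        np.save(path, triplets)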

        return triplets


    def generate_triplets(self, num_triplets, num_processes):
        # NOTE: the code is taken from https://github.com/tamerthamoqa/facenet-pytorch-vggface2
        total_triplets = []
        
        print("\nGenerating {} triplets using {} Python processes ...".format(
                num_triplets,
                num_processes
            )
        )

        # If True, there are residual number of triplets to be generated after the processes are done
        flag_residual_triplets = False
        triplet_residual = num_triplets % num_processes

        if triplet_residual == 0:
            num_triplets_per_process = num_triplets / num_processes
        else:
            flag_residual_triplets = True
            num_triplets_per_process = num_triplets - triplet_residual
            num_triplets_per_process = num_triplets_per_process / num_processes
        
        processes = []
        for i in range(num_processes):
            processes.append(multiprocessing.Process(
                    target=self._generate_triplets,
                    args=(num_triplets_per_process, i)
                )
            )
        
        for process in processes:
            process.start()
        
        for process in processes:
            process.join()  # Block execution until all spawned processes are done
        
        if flag_residual_triplets:
            print("Processes are done. Residual number of triplets {} detected "
                  "and are being generated by the main process ...".format(triplet_residual))
            self._generate_triplets(
                num_triplets=triplet_residual,
                process_id=num_processes + 1
            )

        numpy_files = glob(os.path.join(self.root_dir, 'temp', f'{self.split}_triplets_*.npy'))
        for numpy_file in numpy_files:
            total_triplets.append(np.load(numpy_file))
            os.remove(numpy_file)

        # Flatten the per-process lists of triplets into a single list.
        total_triplets = [triplet for triplets in total_triplets for triplet in triplets]

        # Update the triplets and the num_triplets of the dataset.
        self.triplets = total_triplets
        self.num_triplets = len(total_triplets)

        print(f"Saving {self.split} triplets list...")
        np.save(os.path.join(self.root_dir, f'{self.split}_triplets_{self.num_triplets}.npy'), total_triplets)
        print("Saved!\n")

        return total_triplets
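Stripped of alignment and multiprocessing, the sampling rule in `_generate_triplets` is: pick a positive class with at least two images, a different negative class, two distinct images from the positive class and one image from the negative class. A minimal pure-Python sketch of that rule, assuming a hypothetical `images_per_class` dictionary instead of the dataset's DataFrame; it draws the two distinct images with `random.sample` instead of the retry loop.

```python
import random


def sample_triplet(images_per_class, rng=random):
    """Return (anchor, positive, negative, positive_class, negative_class).

    images_per_class maps a class id to the list of its image paths.
    """
    # Only classes with at least two images can provide an anchor and a positive.
    eligible = [c for c, imgs in images_per_class.items() if len(imgs) >= 2]
    if not eligible or len(images_per_class) < 2:
        raise ValueError('need a class with >= 2 images and at least two classes')

    pos_class = rng.choice(eligible)
    neg_class = rng.choice([c for c in images_per_class if c != pos_class])

    # Two distinct images from the positive class, one from the negative class.
    anchor, positive = rng.sample(images_per_class[pos_class], 2)
    negative = rng.choice(images_per_class[neg_class])
    return anchor, positive, negative, pos_class, neg_class


example_images_per_class = {
    'n000002': ['n000002/0001_01.jpg', 'n000002/0002_01.jpg'],
    'n000003': ['n000003/0001_01.jpg'],
}
print(sample_triplet(example_images_per_class, random.Random(0)))
```

Both approaches pick the positive class uniformly among the classes with at least two images and the two positive-class images uniformly without replacement.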


Now, we instantiate the train and test sets. The first time `TripletDataset` is instantiated for each split, pass `num_triplets` so that the dataset generates (and saves) the specified number of triplets. In later executions, this step can be skipped: pass `triplets_path` instead to load the previously generated list of triplets.

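The test set has 169,396 images, while the training set has 3,141,890. For the time being, we create a train set with 10,000 triplets and a test set with 1,000. Triplets containing an image that MTCNN cannot align are discarded, so the saved lists hold slightly fewer triplets (9,968 and 998, as the file names in the cell below show).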
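When `num_triplets` is passed, `generate_triplets` divides the work evenly among `num_processes` worker processes and lets the main process generate the remainder. The same arithmetic as a small sketch; the helper name is hypothetical, and 12 is only an example value for `os.cpu_count()`.

```python
def split_triplets(num_triplets, num_processes):
    """Return (triplets per worker process, residual generated by the main process)."""
    residual = num_triplets % num_processes
    per_process = (num_triplets - residual) // num_processes
    return per_process, residual


per_process, residual = split_triplets(10000, 12)
print(per_process, residual)  # 833 4
```

Here 12 workers generate 833 triplets each (9,996 in total) and the main process generates the remaining 4.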

In [6]:
# Instantiate the datasets.
train_set = TripletDataset('VGGFace2', 'train', 'VGGFace2/train_list.txt', triplets_path='VGGFace2/train_triplets_9968.npy')
test_set = TripletDataset('VGGFace2', 'test', 'VGGFace2/test_list.txt', triplets_path='VGGFace2/test_triplets_998.npy')


# Model

## Teacher initialization
We instantiate the teacher model, as described in its repository.

In [7]:
from facenet_pytorch import InceptionResnetV1

# Create an inception resnet (in eval mode).
teacher = InceptionResnetV1(pretrained=PRETRAINED, classify=True).eval().to(DEVICE)


## Student initialization
For the time being, we use the MobileNetV2 implementation from the `torchvision` module.

In [14]:
from torchvision.models import mobilenet_v2

# Create a mobilenet model.
student = mobilenet_v2(pretrained=False).to(DEVICE)


Since the last layers of the student and the teacher differ, we must replace the student's `classifier` module with one that mirrors the teacher's output layers.

Note that the FaceNet model we are using has two different output modes:

 1. `classify=True`: returns the class logits (8631 output features).
 2. `classify=False`: returns the embeddings (512 output features).

Above, we set the teacher to `classify=True`. Therefore, the student's new classifier mirrors the teacher's head: after the dropout, a 512-unit linear layer and a batch norm (as in the teacher), followed by a last linear layer whose output size equals the number of classes in the teacher's training dataset.

In [15]:
student.classifier = nn.Sequential(
    # Keep the same dropout as the base MobileNetV2 classifier.
    nn.Dropout(p=0.2, inplace=False),

    # Project the 1280 MobileNetV2 features to the teacher's 512-dimensional embedding.
    nn.Linear(in_features=1280, out_features=512, bias=False),

    # Add a batch norm as in the teacher.
    nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True),

    # Add a last linear layer that matches the number of classes in VGGFace2.
    nn.Linear(512, 8631),
).to(DEVICE)
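The new head is not small. Counting its learnable parameters by hand with the usual formulas for `nn.Linear` (weights plus optional bias) and affine `nn.BatchNorm1d` (weight and bias; running statistics are buffers, not parameters) shows that the 8631-way output layer dominates. A back-of-the-envelope sketch with hypothetical helper names:

```python
def linear_params(in_features, out_features, bias=True):
    """Number of learnable parameters of nn.Linear."""
    return in_features * out_features + (out_features if bias else 0)


def batchnorm_params(num_features, affine=True):
    """Learnable parameters of nn.BatchNorm1d (weight and bias only)."""
    return 2 * num_features if affine else 0


new_head = (linear_params(1280, 512, bias=False)  # 655,360
            + batchnorm_params(512)               # 1,024
            + linear_params(512, 8631))           # 4,427,703
original_head = linear_params(1280, 1000)         # torchvision's 1000-class classifier

print(f'{new_head:,} vs {original_head:,}')  # 5,084,087 vs 1,281,000
```

On the actual module, `sum(p.numel() for p in student.classifier.parameters())` should give the same count.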


## Student training

Below, we define a function that runs one training epoch of distillation over the whole loader.
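In `train_step` (and later in `test_step`), the anchors, positives and negatives of a batch of B triplets are concatenated along the batch dimension, so triplet i sits at rows i, i + B and i + 2B. The slice `[i::B]` selects exactly those rows. A small NumPy sketch of that layout, using hypothetical integer ids in place of images:

```python
import numpy as np

batch_triplets = 4  # B; the last batch may be smaller than BATCH_SIZE
anchors = np.arange(batch_triplets) * 10 + 0    # ids 0, 10, 20, 30
positives = np.arange(batch_triplets) * 10 + 1  # ids 1, 11, 21, 31
negatives = np.arange(batch_triplets) * 10 + 2  # ids 2, 12, 22, 32

# Same layout as torch.cat([anchor, positive, negative], 0).
example_inputs = np.concatenate([anchors, positives, negatives])

for i in range(batch_triplets):
    # Rows i, i + B and i + 2B belong to triplet i.
    print(i, example_inputs[i::batch_triplets])
```

This is why the stride must be the current sample size rather than `BATCH_SIZE`: with a shorter last batch, a stride of `BATCH_SIZE` would mix rows from different triplets.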

In [10]:
def train_step(student, teacher, classes, loader, temperature, optimizer, scheduler, epoch,
               device=DEVICE, print_every=100):
    """Run one distillation epoch over `loader`, then step the learning-rate scheduler."""
    student.train()
    teacher.eval()

    # Map class names to indices once, instead of calling classes.index() per label.
    class_to_idx = {name: idx for idx, name in enumerate(classes)}

    running_loss = 0.0
    running_soft_loss = 0.0
    running_hard_loss = 0.0

    for count, sample in enumerate(loader):
        # Reset the losses.
        soft_loss = 0
        hard_loss = 0

        # Concatenate anchors, positives and negatives along the batch dimension.
        sample_size = len(sample['anchor_image'])
        inputs = torch.cat([sample['anchor_image'],
                            sample['positive_image'],
                            sample['negative_image']], 0).to(device)

        # Labels of (anchor, positive, negative): the anchor shares the positive class.
        labels = [sample['positive_class'], sample['positive_class'], sample['negative_class']]

        # Convert the class names to class ids; after transposing, labels[i] holds
        # the three labels of triplet i.
        labels = [[class_to_idx[l] for l in label] for label in labels]
        labels = torch.LongTensor(labels).T.to(device)

        # Images and labels of the same triplet can be retrieved as follows:
        # labels[ID], inputs[ID::sample_size]
        # Note: to access the inputs you should use the current sample size
        # and not the general batch size.

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Forward. The teacher is frozen, so no computation graph is built for it.
        student_outputs = student(inputs)
        with torch.no_grad():
            teacher_outputs = teacher(inputs)

        # Evaluate the soft loss.
        for i in range(sample_size):
            soft_outputs = F.log_softmax(student_outputs[i::sample_size] / temperature, dim=1)
            soft_targets = F.softmax(teacher_outputs[i::sample_size] / temperature, dim=1)
            soft_loss += F.kl_div(soft_outputs, soft_targets, reduction='batchmean')

        # It is important to multiply the soft loss by T^2 when using both hard and soft
        # targets. This ensures that the relative contributions of the hard and soft
        # targets remain roughly unchanged if the temperature used for distillation is
        # changed while experimenting with meta-parameters.
        soft_loss *= (temperature ** 2)

        # Evaluate the hard loss.
        for i in range(sample_size):
            hard_loss += F.cross_entropy(student_outputs[i::sample_size], labels[i])

        # Sum the soft and hard losses with equal weights.
        loss = soft_loss + hard_loss

        loss.backward()
        optimizer.step()

        # Print statistics (losses averaged per triplet).
        running_loss += loss.item() / sample_size
        running_soft_loss += soft_loss.item() / sample_size
        running_hard_loss += hard_loss.item() / sample_size
        if count % print_every == print_every - 1:
            print('[%d, %5d] loss: %.3f (soft: %.3f, hard: %.3f)' %
                  (epoch + 1, count + 1,
                   running_loss / print_every,
                   running_soft_loss / print_every,
                   running_hard_loss / print_every))
            running_loss = 0.0
            running_soft_loss = 0.0
            running_hard_loss = 0.0

    # The scheduler milestones are expressed in epochs, so step it once per call.
    scheduler.step()
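The loss in `train_step` is the usual soft-plus-hard distillation objective: a KL divergence between teacher and student distributions softened by the temperature T (scaled by T^2), plus cross-entropy with the true labels at T = 1. A NumPy re-implementation for a single batch, written for illustration (it mirrors `F.kl_div(..., reduction='batchmean')` and `F.cross_entropy` rather than calling them):

```python
import numpy as np


def log_softmax(z, temperature=1.0):
    """Numerically stable log-softmax over the last axis."""
    z = np.asarray(z, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels, temperature):
    """Soft KL term (scaled by T^2) and hard cross-entropy, both averaged over the batch."""
    log_p_student = log_softmax(student_logits, temperature)
    log_p_teacher = log_softmax(teacher_logits, temperature)
    p_teacher = np.exp(log_p_teacher)
    # KL(teacher || student): summed over classes, averaged over samples ('batchmean').
    soft = (p_teacher * (log_p_teacher - log_p_student)).sum(axis=-1).mean()
    soft *= temperature ** 2
    # Cross-entropy with the true labels at temperature 1.
    hard = -log_softmax(student_logits)[np.arange(len(labels)), labels].mean()
    return soft, hard


example_teacher_logits = np.array([[4.0, 1.0, 0.0], [0.5, 3.0, 0.2]])
example_labels = np.array([0, 1])
soft, hard = distillation_loss(example_teacher_logits, example_teacher_logits,
                               example_labels, temperature=5.0)
print(soft)  # 0.0: a student that copies the teacher has no soft loss
```

Summing the per-triplet `batchmean` terms, as `train_step` does, equals this batch-averaged value multiplied by the number of triplets in the batch; the same holds for the hard loss.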


## Student testing
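Here, we define a function to test the results reached by the model. A simple way to assess the student is to evaluate triplets of images and check whether the output for the anchor is closer to the output for the positive than to the output for the negative (match accuracy). The class-accuracy computation is currently commented out.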

In [11]:
def test_step(student, teacher, classes, loader, device=DEVICE):
    student.eval()
    teacher.eval()

    teacher_class_accuracy = 0
    teacher_match_accuracy = 0
    student_class_accuracy = 0
    student_match_accuracy = 0

    for _, sample in enumerate(loader):
        # Split the sample into images and labels.
        values = list(sample.values())
        sample_size = len(values[0])
        inputs = torch.cat(values[:3], 0).to(device)
        # labels = [values[3], *values[3:]]

        # # Convert the class names with the class id and transpose the tensor.
        # labels = [[classes.index(l) for l in label] for label in labels]
        # labels = torch.LongTensor(labels).T.to(DEVICE)

        # Forward (no gradients are needed at test time).
        with torch.no_grad():
            student_outputs = student(inputs)
            teacher_outputs = teacher(inputs)

        # # Evaluate class predictions.
        # student_class_predictions = [classes[int(output.topk(1).indices)] for output in student_outputs]
        # teacher_class_predictions = [classes[int(output.topk(1).indices)] for output in teacher_outputs]

        # # Update class accuracies.
        # for i in range(3):
        #     if student_class_predictions[i] == labels[i]:
        #         student_class_accuracy += 1
        #     if teacher_class_predictions[i] == labels[i]:
        #         teacher_class_accuracy += 1
        
        # Evaluate match predictions for each triplet.
        for i in range(sample_size):
            out = student_outputs[i::sample_size]
            student_match_distance = [(out[0] - out[1]).norm().item(),
                                      (out[0] - out[2]).norm().item()]

            out = teacher_outputs[i::sample_size]
            teacher_match_distance = [(out[0] - out[1]).norm().item(),
                                      (out[0] - out[2]).norm().item()]
        
            # Update match accuracy.
            if student_match_distance[0] < student_match_distance[1]:
                student_match_accuracy += 1
            
            if teacher_match_distance[0] < teacher_match_distance[1]:
                teacher_match_accuracy += 1
    
    print(f' - Student match accuracy: {student_match_accuracy/len(loader.dataset)}')
    # print(f'  - class: {student_class_accuracy/(3*len(loader.dataset))}')
    # print(f'  - match: {student_match_accuracy/len(loader.dataset)}')
    print(f' - Teacher match accuracy: {teacher_match_accuracy/len(loader.dataset)}')
    # print(f'  - class: {teacher_class_accuracy/(3*len(loader.dataset))}')
    # print(f'  - match: {teacher_match_accuracy/len(loader.dataset)}')
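`test_step` counts a triplet as matched when the anchor's output is strictly closer (L2 distance) to the positive's output than to the negative's, so ties count as mismatches. With `classify=True`, these are distances between logits, not between 512-dimensional embeddings. A vectorized NumPy sketch of the same metric over arrays of shape (num_triplets, dim), with made-up example points:

```python
import numpy as np


def match_accuracy(anchors, positives, negatives):
    """Fraction of triplets whose anchor is closer (L2) to the positive than to the negative."""
    d_pos = np.linalg.norm(anchors - positives, axis=1)
    d_neg = np.linalg.norm(anchors - negatives, axis=1)
    return float(np.mean(d_pos < d_neg))


example_anchors = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
example_positives = np.array([[0.1, 0.0], [1.0, 3.0], [2.0, 0.5]])
example_negatives = np.array([[1.0, 1.0], [1.0, 1.5], [3.0, 0.0]])

# Triplets 0 and 2 match, triplet 1 does not, so the accuracy is 2/3.
print(match_accuracy(example_anchors, example_positives, example_negatives))
```

Vectorizing over the batch avoids the per-triplet Python loop and the `.item()` calls.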


## Wrap-up

Finally, we write the function that, given a student and a teacher, performs knowledge distillation.

In [12]:
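def distill(student, teacher, epochs, initial_temperature):
    # Define the list of classes in the dataset. The list is sorted so that the
    # class-to-index mapping is deterministic (glob order is arbitrary); this
    # assumes that the teacher's logits follow the same sorted order.
    classes = sorted(os.path.basename(folder)
                     for folder in glob(os.path.join('VGGFace2', 'train', '*')))

    # Create the loader dictionary.
    loaders = {
        'train': DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0),
        'test': DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=0)
    }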

    # Set teacher to evaluation mode.
    teacher.eval()

    # Define the different temperatures.
    temperatures = np.linspace(initial_temperature, 1.0, epochs).tolist()

    # Instantiate optimizer and scheduler.
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [5, 10])

    # Run the epochs.
    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}, distillation temperature = {temperatures[epoch]}')

        # Run a train step.
        print('Training:')
        train_step(student, teacher, classes, loaders['train'], temperatures[epoch], optimizer,
                   scheduler, epoch)

        # Run a test step.
        print('Testing:')
        test_step(student, teacher, classes, loaders['test'])
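`distill` anneals the distillation temperature linearly from `initial_temperature` down to 1. If the scheduler is stepped once per epoch, `MultiStepLR` with milestones `[5, 10]` and its default `gamma=0.1` divides the learning rate by 10 at epochs 5 and 10. The sketch below computes both schedules in plain Python/NumPy, reproducing MultiStepLR's closed-form decay rule rather than calling PyTorch; the helper names are hypothetical.

```python
from bisect import bisect_right

import numpy as np


def temperature_schedule(initial_temperature, epochs):
    """Linearly anneal the temperature from initial_temperature down to 1.0."""
    return np.linspace(initial_temperature, 1.0, epochs).tolist()


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate after `epoch` scheduler steps, following MultiStepLR's decay rule."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


example_temperatures = temperature_schedule(20, 20)
print(example_temperatures[:3], example_temperatures[-1])  # [20.0, 19.0, 18.0] 1.0
print([multistep_lr(0.001, [5, 10], 0.1, e) for e in (0, 4, 5, 9, 10, 19)])
```

With `distill(student, teacher, 20, 20)`, the temperature therefore drops by exactly 1 per epoch, while the learning rate is 1e-3 for epochs 0 to 4, 1e-4 for epochs 5 to 9 and 1e-5 afterwards.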
    

In [None]:
distill(student, teacher, 20, 20)