[Link to Paper](https://arxiv.org/pdf/2010.07485.pdf)

# Related works

## Knowledge Distillation

Knowledge Distillation (KD) refers to transferring knowledge from one model to another. More often than not, a larger, more complex network, referred to as the teacher network, is used to "teach" (transfer or distill knowledge to) a smaller, less complex network known as the student network.

This knowledge transfer is done by using the teacher's class probabilities together with the true labels to train the student. The loss function used for KD can therefore be broken down into two components: one measuring the difference between the student's and the teacher's probabilities, and one measuring the difference between the student's predictions and the true labels. It is worth noting that during KD, the teacher network's parameters are fixed.


### Why does it work?

The true labels for classification tasks are often one-hot encoded, and this focuses model training on predicting the correct class. As the number of classes increases:
1. Tasks become increasingly complex, and small/shallow networks might not have the capacity to perform well.
2. Large/complex networks have a higher chance of over-emphasizing a dominant class in the dataset and/or overfitting to the dataset, so they do not perform well on unseen data.

Instead of using the true labels, which are binary values (0s and 1s), we can also use the teacher's predictions: its logits, or the class probabilities obtained from them with a softmax. These logits form a vector of continuous values representing the teacher's confidence that the data point belongs to each of the available classes. The largest value in the vector (or, for a batch, in each row of the matrix) represents the class predicted by the teacher for that data point.

In addition to providing information on the correct class, the continuous values indicate the degree of difference between the other, "incorrect" classes.
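The snippet below is a minimal sketch of how a teacher's logits become class probabilities. It applies a softmax to a logit vector and reads off the predicted class. The logits are hypothetical values, chosen so that the probabilities roughly match the car/truck/dog example in this section. The `temperature` parameter is an assumption at this point; it is only used later, to smooth the teacher's distribution.

```python
import math

def softmax(logits, temperature=1.0):
    """Convert a list of logits into class probabilities."""
    scaled = [z / temperature for z in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

classes = ["Car", "Truck", "Dog"]
teacher_logits = [2.0, 0.7, -0.7]  # hypothetical teacher output for a car image

probs = softmax(teacher_logits)
predicted = classes[max(range(len(probs)), key=lambda k: probs[k])]

for name, p in zip(classes, probs):
    print(f"{name}: {p:.2f}")  # Car: 0.75, Truck: 0.20, Dog: 0.05
print("Predicted:", predicted)  # Predicted: Car
```

Softmax is monotonic, so the class with the largest logit is also the class with the largest probability.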

Given a dataset of images of cars, trucks and dogs, it is reasonable to state that there is a higher chance of misclassifying a car as a truck than of misclassifying a car as a dog. If we were to use one-hot encoded labels, an image of a car would be labelled 1 for the car class and 0 for every other class. This does not reflect the similarities between the classes in this image classification task.

A teacher network that has been trained extensively on the true labels will recognize that a car and a truck share more similarities. For an image of a car, it may produce a prediction like:

| Class       | Probability |
| ----------- | ----------- |
| Car         | 0.75        |
| Truck       | 0.20        |
| Dog         | 0.05        |

Using the teacher's predictions to train a student network will, for this data sample, train the student to also predict a car with 75% confidence, as opposed to 100% if we were to use the true labels. Different images of cars present different degrees of difficulty for the model. An image of a car taken from further away, for example, would be reflected in the teacher's prediction as a class probability lower than 75%, perhaps around 55%.

The continuous nature of the teacher's predictions therefore provides additional information about the relationships between classes. It can also reduce the chance of overfitting, because the target class probability is 75% rather than 100%.
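The effect of soft targets on training can be illustrated by comparing cross-entropy losses. The sketch below compares a one-hot target and the teacher's soft target from the table. It scores two hypothetical students: one matching the teacher, and one pushed towards 100% "Car". The function name and the student probabilities are illustrative assumptions.

```python
import math

def cross_entropy(target, predicted):
    """H(target, predicted) = -sum_k target_k * log(predicted_k)."""
    # Terms with target_k == 0 contribute nothing, so skip them.
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

hard_target = [1.0, 0.0, 0.0]     # one-hot label for "Car"
soft_target = [0.75, 0.20, 0.05]  # teacher prediction from the table

calibrated = [0.75, 0.20, 0.05]     # hypothetical student matching the teacher
overconfident = [0.98, 0.01, 0.01]  # hypothetical student pushed towards 100% "Car"

for name, target in [("hard", hard_target), ("soft", soft_target)]:
    print(name,
          round(cross_entropy(target, calibrated), 3),
          round(cross_entropy(target, overconfident), 3))
```

With the one-hot target, the overconfident student has the lower loss, so training keeps pushing towards 100%. With the soft target, the loss is lower for the student that matches the teacher's 75/20/5 split. Cross-entropy against a fixed target is minimized when the prediction equals that target.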

### Technical Details

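In the case of image classification, the loss function used to train the student is:

$$Loss = \gamma \cdot L_{KD} + (1 - \gamma) \cdot L_{CE}$$

Here $L_{KD}$ is the distillation loss between the student's and the teacher's predictions, $L_{CE}$ is the cross-entropy loss between the student's predictions and the true labels, and $\gamma$ weights the two terms.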
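The sketch below computes this loss for a single sample. The notes do not define $L_{KD}$ precisely, so it assumes a common choice: the KL divergence between temperature-softened teacher and student distributions, scaled by $T^2$ as in Hinton et al.'s formulation. The function names and example logits are hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, gamma=0.5, temperature=1.0):
    """Loss = gamma * L_KD + (1 - gamma) * L_CE for a single sample.

    Assumption: L_KD is KL(teacher || student) on temperature-softened
    distributions, scaled by temperature**2. L_CE is the cross-entropy
    between the student (at temperature 1) and the true label index.
    """
    # The teacher is frozen: its logits are treated as fixed targets.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    l_kd = sum(pt * math.log(pt / ps)
               for pt, ps in zip(p_teacher, p_student) if pt > 0)
    l_kd *= temperature ** 2

    # Hard-label term: negative log-probability of the true class.
    l_ce = -math.log(softmax(student_logits)[label])

    return gamma * l_kd + (1 - gamma) * l_ce

teacher = [2.0, 0.7, -0.7]  # hypothetical teacher logits (car, truck, dog)
student = [1.0, 0.5, 0.2]   # hypothetical student logits
print(kd_loss(student, teacher, label=0, gamma=0.5, temperature=2.0))
```

Setting `gamma=0` recovers plain cross-entropy training on the true labels. Setting `gamma=1` trains the student purely on the teacher's softened predictions.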

## Problem Statement

### Model Confidence vs Model Generalization

Besides one-hot encoded labels, another influential factor in Knowledge Distillation is the architecture of the teacher network. Because the teacher is a deeper, more complex network, it often learns more information about the data than required and becomes overfitted. The knowledge about the data learned by the teacher network can be broken down into two components, namely **Model Generalization** and **Model Confidence**.

**Model Generalization** refers to the knowledge **important for mapping the input to its respective label**. The knowledge in this category consists of important features, conveyed through the soft targets, that allow the student model to perform well on unseen data when used to train it.

In contrast, **Model Confidence** relates to knowledge that would **negatively affect the performance of the student model**. This form of knowledge is a by-product of the teacher network's training process. During training, the model's weights are adjusted at each epoch to minimize the difference between the model's predictions and the true labels. This implies that even when the teacher model classifies all training samples correctly, it can still reduce the loss by pushing its predictions closer to the discrete 0s and 1s of the true labels.

To illustrate the concepts of soft targets, Model Generalization and Model Confidence, we can use the Modified National Institute of Standards and Technology (MNIST) dataset. The MNIST dataset consists of handwritten digits, each belonging to one of 10 classes (0-9).
Take an image of a “5”. The important features identified by the model that allow it to correctly classify the image as a “5” constitute knowledge pertaining to Model Generalization. Assume that for this image the model predicts a probability of 0.55 for class “5”. Further training on one-hot encoded labels will lead the model to decrease its loss further by pushing the predicted probability of the correct class (class “5”) closer to 1, while simultaneously reducing the predicted probabilities of the other classes. The knowledge gained during this further training can be classified as Model Confidence, as it is additional information that is not needed to classify the image correctly. The bias introduced by this further training encourages the model to predict more samples as class “5”, which hurts its ability to correctly predict samples of other classes.
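The effect described for the “5” can be shown numerically. The sketch below follows the probability assigned to class “5” as it rises from 0.55 towards 1 under one-hot training. The probabilities after 0.55 are hypothetical, and spreading the remaining probability mass evenly over the other nine classes is an assumption made for illustration.

```python
import math

def one_hot_training_trajectory(correct_probs, num_classes=10, correct_class=5):
    """For each probability assigned to the correct class, return
    (probability, predicted class, one-hot cross-entropy loss)."""
    results = []
    for p in correct_probs:
        # Hypothetical: spread the remaining mass evenly over the other classes.
        other = (1.0 - p) / (num_classes - 1)
        probs = [other] * num_classes
        probs[correct_class] = p
        predicted = max(range(num_classes), key=lambda k: probs[k])
        loss = -math.log(probs[correct_class])  # cross-entropy with a one-hot label
        results.append((p, predicted, loss))
    return results

for p, predicted, loss in one_hot_training_trajectory([0.55, 0.80, 0.95, 0.99]):
    print(f"p(5)={p:.2f}  predicted={predicted}  loss={loss:.3f}")
```

The predicted class is already “5” at 0.55, yet the loss keeps falling as the probability approaches 1. That remaining reduction corresponds to what these notes call Model Confidence.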

In addition to regulating the predicted probability of the correct class, the model should also learn the differences between the incorrect classes. Taking the same image of a “5” from the MNIST dataset, there is a chance that the model misclassifies the image as a “6”. However slim this chance is, we would expect it to be more probable than misclassifying the “5” as a “2”, because a “5” bears a closer resemblance to a “6” than to a “2”. To represent this, the class probabilities output by the model should assign a higher likelihood to class “6” than to class “2”. One-hot encoded labels cannot carry this information about the relationships between classes, since every incorrect class has a label of 0. Soft targets, which are derived from the teacher network’s predicted class probabilities, do contain it. Soft targets have a further advantage: they allow unlabelled data to be used for training, because the unlabelled data can be passed through the teacher model to produce them.
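The sketch below shows the unlabelled-data idea. Each unlabelled sample is run through a frozen teacher, and its softmax output is kept as the soft target. `make_soft_targets` and `toy_teacher` are hypothetical names. The toy teacher returns fixed, made-up logits for digits 0-9 that favour “5” and then “6”, standing in for a trained MNIST model.

```python
import math
from typing import Callable, List, Sequence

def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def make_soft_targets(teacher: Callable[[object], Sequence[float]],
                      unlabelled: Sequence[object],
                      temperature: float = 1.0) -> List[List[float]]:
    """Run each unlabelled sample through the frozen teacher and keep its
    class probabilities as the student's training target."""
    return [softmax(teacher(x), temperature) for x in unlabelled]

def toy_teacher(image: object) -> List[float]:
    # Hypothetical stand-in for a trained teacher: ignores its input and
    # returns fixed logits for digits 0-9 that favour "5", then "6".
    return [0.1, -0.5, 0.5, 1.0, -0.2, 6.0, 2.5, -1.0, 1.2, 0.3]

targets = make_soft_targets(toy_teacher, unlabelled=["img_a", "img_b"])
five = targets[0]
print(f"p(5)={five[5]:.3f}  p(6)={five[6]:.3f}  p(2)={five[2]:.3f}")
```

Unlike a one-hot label, the soft target for this “5” ranks class “6” above class “2”, and it was produced without any ground-truth label.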


# Spherical Knowledge Distillation (Solution)

# Thoughts

## Dice Loss for 3 Subspaces

Another interesting idea was to extend the Dice Loss function to account for three probability subspaces, namely the teacher logits, the student logits and the one-hot encoded ground-truth labels.

In the case of a 3-subspace Dice Loss function, the numerator is the intersection of the three subspaces, which can be computed by element-wise multiplication. The denominator is the union of these subspaces, obtained by adding them up. Note that in the denominator we also subtract the regions where pairs of subspaces intersect.
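Below is one possible way to make this description concrete. Each subspace is treated as a probability vector, and each intersection as an element-wise product summed over classes. The function name and `eps` are hypothetical. The sketch uses the full inclusion-exclusion union. After subtracting the pairwise overlaps, this also adds back the region shared by all three; without that term, the denominator becomes zero when all three vectors are the same one-hot vector. Following the notes, the ratio is intersection over union.

```python
def three_subspace_dice_loss(teacher, student, labels, eps=1e-7):
    """1 - |T ∩ S ∩ Y| / |T ∪ S ∪ Y| using element-wise (soft) set operations."""
    triple = sum(t * s * y for t, s, y in zip(teacher, student, labels))
    singles = sum(teacher) + sum(student) + sum(labels)
    pairs = sum(t * s + t * y + s * y for t, s, y in zip(teacher, student, labels))
    # Inclusion-exclusion: add the single sets, subtract pairwise overlaps,
    # then add back the region shared by all three.
    union = singles - pairs + triple
    return 1.0 - (triple + eps) / (union + eps)

teacher = [0.75, 0.20, 0.05]  # hypothetical teacher probabilities
student = [0.70, 0.20, 0.10]  # hypothetical student probabilities
labels = [1.0, 0.0, 0.0]      # one-hot ground truth

print(three_subspace_dice_loss(teacher, student, labels))

# The element-wise product with a one-hot label keeps only the correct class.
kept = [t * s * y for t, s, y in zip(teacher, student, labels)]
print(kept)
```

The `kept` vector shows the limitation discussed next: every class except the correct one is zeroed out of the numerator.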





The problem here is that the ground-truth labels are one-hot encoded. Multiplying them element-wise with the teacher and student logits therefore discards the information for every class (multiplied by 0) except the correct class (multiplied by 1).

Modifying the 3-subspace Dice Loss, I tried to break it down into a 2-component Dice Loss. The first component is the degree of difference between the smoothed teacher and the student. The second is the degree of difference between the smoothed teacher and the ground-truth labels. I left out the student vs ground-truth term, as the goal was to find a temperature value that objectively smooths the teacher's logit distribution.
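Below is a minimal sketch of the 2-component version. It makes three assumptions: Dice loss takes its usual soft form $1 - 2\sum ab / (\sum a + \sum b)$, the teacher is smoothed with a temperature-scaled softmax, and the two components are simply added (the notes do not state a weighting). All names and numbers are hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def dice_loss(a, b, eps=1e-7):
    """1 - 2|A ∩ B| / (|A| + |B|), with the intersection as an element-wise product."""
    intersection = sum(x * y for x, y in zip(a, b))
    return 1.0 - (2.0 * intersection + eps) / (sum(a) + sum(b) + eps)

def two_component_dice(teacher_logits, student_probs, labels, temperature):
    """Return (teacher-vs-student, teacher-vs-labels) Dice losses for a
    teacher smoothed at the given temperature."""
    smoothed_teacher = softmax(teacher_logits, temperature)
    teacher_vs_student = dice_loss(smoothed_teacher, student_probs)
    teacher_vs_labels = dice_loss(smoothed_teacher, labels)
    return teacher_vs_student, teacher_vs_labels

teacher_logits = [3.0, 1.5, 0.0]    # hypothetical teacher logits
student_probs = [0.60, 0.30, 0.10]  # hypothetical student probabilities
labels = [1.0, 0.0, 0.0]            # one-hot ground truth

for T in [1.0, 2.0, 4.0]:
    ts, tl = two_component_dice(teacher_logits, student_probs, labels, T)
    # Assumption: the two components are combined with equal weight.
    print(f"T={T}: teacher-student={ts:.3f}  teacher-labels={tl:.3f}  total={ts + tl:.3f}")
```

Raising `T` flattens the smoothed teacher, lowering its probability on the correct class. As a result, the teacher-vs-labels term increases with temperature.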

