<a href="https://colab.research.google.com/github/Gustave-MB/my-torch/blob/main/HW3/P2/Center_Loss_Starter_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Center Loss

## Description:
In brief, Center Loss reduces the spread of each class's feature cluster.

In other words, the objective of Center Loss is to minimize the intra-class variance of the output features (the model's features just before they are passed to the final classification layer).

$$\mathcal{L}_C = \frac{1}{2}\sum_{i=1}^{m}||\pmb{x}_i-\pmb{c}_{y_i}||_2^2$$

Here $\mathcal{L}_C$ denotes Center Loss, $\pmb{x}_i$ denotes the feature vector of the $i$-th sample, $y_i$ is that sample's class label, $\pmb{c}_{y_i}$ denotes the center of the feature vectors of class $y_i$, and $m$ is the number of $(\pmb{x}_i, y_i)$ pairs, i.e. the mini-batch size.

What we actually implement here is the mean of the loss over the batch, so that its scale matches that of the (batch-averaged) cross-entropy loss.

$$\mathcal{L}_C = \frac{1}{2m}\sum_{i=1}^{m}||\pmb{x}_i-\pmb{c}_{y_i}||_2^2$$
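As a quick numeric check of the mean form, the sketch below evaluates $\mathcal{L}_C$ on a toy batch with two classes and 2-D features; the centers, features, and labels are made-up illustrative values. Sample 0 is at squared distance $1$ from $\pmb{c}_0$ and sample 1 is at squared distance $4$ from $\pmb{c}_1$, so $\mathcal{L}_C = \frac{1}{2\cdot 2}(1+4) = 1.25$.

```python
import torch

# Toy setup (illustrative values): 2 classes, 2-D features, batch of m = 2
centers = torch.tensor([[0.0, 0.0],   # c_0
                        [1.0, 1.0]])  # c_1
x = torch.tensor([[1.0, 0.0],         # x_0, class 0
                  [1.0, 3.0]])        # x_1, class 1
y = torch.tensor([0, 1])

# Gather c_{y_i} for every sample, then take half the squared distance, averaged over the batch
sq_dist = ((x - centers[y]) ** 2).sum(dim=1)  # tensor([1., 4.])
loss = 0.5 * sq_dist.mean()
print(loss.item())  # 1.25
```

Indexing `centers[y]` with the label tensor is what turns the per-class center table into one center per sample.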

However, recomputing the class centers from ALL the training data at every iteration is too expensive. Wen et al. therefore update the centers per mini-batch: "In each iteration, the centers are computed by averaging the features of the corresponding classes (In this case, some of the centers may not update)."

The centers are updated with a learning rate $\alpha$:

$$\frac{\partial\mathcal{L}_C}{\partial\pmb{x}_i} = \pmb{x}_i-\pmb{c}_{y_i}$$

$$\Delta\pmb{c}_j = \frac{\sum_{i=1}^{m}\delta(y_i=j)\cdot(\pmb{c}_j-\pmb{x}_i)}{1+\sum_{i=1}^{m}\delta(y_i=j)}$$

where $\delta(\text{condition})$ equals $1$ if the condition holds and $0$ otherwise.

$$\pmb{c}_{j}^{t+1}=\pmb{c}_{j}^{t}-\alpha\cdot\Delta\pmb{c}_j$$
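The batch-wise update rule can be sketched directly in PyTorch. This only illustrates the equations above: the function name `paper_center_update` and the toy tensors are hypothetical, and in this notebook the centers are updated by an optimizer instead. The example also shows the parenthetical in the quote: a class with no samples in the batch keeps its center.

```python
import torch

def paper_center_update(centers, x, y, alpha):
    """Hypothetical sketch of the center update rule from Wen et al. (ECCV 2016).

    centers: (num_classes, feat_dim), x: (m, feat_dim), y: (m,) integer labels.
    Returns updated centers; classes absent from the batch are left unchanged.
    """
    new_centers = centers.clone()
    for j in range(centers.shape[0]):
        mask = y == j
        n_j = int(mask.sum())                          # samples of class j in this batch
        numer = (centers[j] - x[mask]).sum(dim=0)      # sum of (c_j - x_i) over class-j samples
        delta_c = numer / (1 + n_j)                    # Delta c_j
        new_centers[j] = centers[j] - alpha * delta_c  # c_j^{t+1} = c_j^t - alpha * Delta c_j
    return new_centers

centers = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
x = torch.tensor([[2.0, 0.0], [4.0, 0.0]])  # both samples belong to class 0
y = torch.tensor([0, 0])
updated = paper_center_update(centers, x, y, alpha=0.5)
print(updated)
# c_0: Delta = ((0-2) + (0-4), 0) / 3 = (-2, 0), so c_0 -> (0, 0) - 0.5 * (-2, 0) = (1, 0)
# c_1: no class-1 samples, so Delta = 0 and c_1 stays at (1, 1)
```

The `1 +` in the denominator keeps the division well defined when a class does not appear in the batch.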

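Inside the `CenterLoss` class you do not need to implement this update rule. The centers are stored as an `nn.Parameter` and updated by their own optimizer, so the module only has to compute the loss. Note that plain SGD on the mean loss is close to, but not identical to, the rule above: the gradient with respect to $\pmb{c}_j$ is $\frac{1}{m}\sum_{i=1}^{m}\delta(y_i=j)\cdot(\pmb{c}_j-\pmb{x}_i)$, which is normalized by the batch size $m$ rather than by $1+\sum_{i=1}^{m}\delta(y_i=j)$.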
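To see what the optimizer actually receives, the sketch below lets autograd compute the gradient of the mean Center Loss with respect to the centers (toy values, chosen for illustration). For class 0 it gives $-3$ along the first axis, i.e. $\frac{1}{m}\sum_i(\pmb{c}_0-\pmb{x}_i)$ with $m=2$, whereas the paper's $\Delta\pmb{c}_0$ for the same batch divides by $1+n_0=3$ and gives $-2$. An SGD step with learning rate $\alpha$ therefore uses a different per-class normalization than the paper. A class absent from the batch receives zero gradient, so plain SGD (without momentum or weight decay) leaves its center unchanged.

```python
import torch

# Toy batch: both samples belong to class 0
centers = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True)
x = torch.tensor([[2.0, 0.0], [4.0, 0.0]])
y = torch.tensor([0, 0])

# Mean Center Loss: (1 / 2m) * sum_i ||x_i - c_{y_i}||^2
loss = 0.5 * ((x - centers[y]) ** 2).sum(dim=1).mean()
loss.backward()

# d loss / d c_0 = (1/m) * sum_i (c_0 - x_i) = ((0-2) + (0-4)) / 2 = -3 on the first axis;
# c_1 gets no gradient from this batch
print(centers.grad)
```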

In [None]:
class CenterLoss(nn.Module):
    """Center Loss

    Center Loss paper (Wen et al., ECCV 2016):
    https://ydwen.github.io/papers/WenECCV16.pdf

    Args:
        num_classes (int): number of classes of our model.
        feat_dim (int): dimension of the output feature (the input of the final classification layer).
        device (str or torch.device): device on which the centers are created.
    """
    def __init__(self,
                 num_classes: int,  # number of classes of our model
                 feat_dim: int,     # dimension of your output feature
                 device="cuda",     # set to "cpu" to test the program without a GPU
                 ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim

        # One learnable center per class, hence the shape (num_classes, feat_dim)
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim, device=device))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground-truth labels with shape (batch_size,).
        Returns:
            Scalar mean Center Loss over the batch.
        """
        # centers[i] is the center of the true label of x[i]: shape (batch_size, feat_dim)
        centers = self.centers[labels]
        # Squared Euclidean distance between each input and its class center: shape (batch_size,)
        # Each element of dist is (twice) the Center Loss of one input
        dist = ((x - centers) ** 2).sum(dim=1)
        # Bound the per-sample distance, as common reference implementations do. There is no log
        # here, so the lower bound is only a safeguard; clamped entries receive zero gradient.
        dist = torch.clamp(dist, min=1e-12, max=1e+12)

        # Mean over the batch, including the 1/2 factor of the formula: (1 / 2m) * sum_i dist_i
        loss = 0.5 * dist.mean()

        return loss

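# Example Training Procedure

When you train with FP16 (automatic mixed precision) and use multiple losses, you have to follow a specific pattern: scale and backpropagate each loss separately through the same `GradScaler`, step every optimizer with `scaler.step`, and call `scaler.update()` once per iteration. Here is example code for multi-loss training with Center Loss.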

More details are in the PyTorch AMP examples:
[Working with multiple models, losses, and optimizers](https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers)
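Why two separate backward calls work can be checked on a toy example without AMP. The linear model and the two losses below are stand-ins, not the notebook's model. When two losses come from the same forward pass, the first `backward` call needs `retain_graph=True` so the shared graph is not freed, and the gradients of the two calls accumulate in `.grad` to the same result as a single backward on the sum. With AMP, each call is wrapped as `scaler.scale(loss).backward(...)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for the real model
x = torch.randn(3, 4)

# Two losses computed from the same forward pass (shared graph)
feats = model(x)
loss0 = feats.pow(2).mean()
loss1 = feats.abs().mean()

# The first backward keeps the graph alive because loss1 still needs it
loss0.backward(retain_graph=True)
loss1.backward()
grad_two_calls = model.weight.grad.clone()

# Reference: one backward call on the summed loss
model.zero_grad()
feats = model(x)
(feats.pow(2).mean() + feats.abs().mean()).backward()
grad_single_call = model.weight.grad.clone()

print(torch.allclose(grad_two_calls, grad_single_call))
```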

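The hyperparameters you need to tune are the loss weight $\lambda$ (the weight of Center Loss relative to cross-entropy) and the center learning rate $\alpha$ (the learning rate of the optimizer that updates the centers).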
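The two hyperparameters act on different parts of the update. $\lambda$ scales Center Loss in the model's objective, $\mathcal{L}_{CE} + \lambda\,\mathcal{L}_C$, so it also scales the gradient that reaches the centers; dividing the center gradients by $\lambda$ (as the training loop below does) removes that factor and leaves $\alpha$ as the only knob for how fast the centers move. The sketch below (toy values, hypothetical helper `center_update_grad`) checks that the corrected center gradient is the same for two different values of $\lambda$.

```python
import torch

def center_update_grad(loss_weight):
    """Hypothetical helper: center gradient after the 1/loss_weight correction."""
    centers = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True)
    x = torch.tensor([[2.0, 0.0], [4.0, 0.0]])
    y = torch.tensor([0, 0])
    # Weighted mean Center Loss, i.e. lambda * L_C
    loss1 = loss_weight * 0.5 * ((x - centers[y]) ** 2).sum(dim=1).mean()
    loss1.backward()
    # Undo lambda so that only alpha (the optimizer's lr) sets the center step size
    return centers.grad / loss_weight

g_small = center_update_grad(0.01)
g_large = center_update_grad(1.0)
print(torch.allclose(g_small, g_large))  # True
```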

In [None]:
center_loss = CenterLoss(num_classes=NotImplemented, feat_dim=NotImplemented)
optimizer_center_loss = torch.optim.SGD(center_loss.parameters(), lr=NotImplemented) # TODO: select a learning rate (alpha); I recommend starting with 0.1

In [None]:
def train(model: nn.Module,
          train_loader: DataLoader,
          optimizer: optim.Optimizer,
          optimizer_center_loss: optim.Optimizer,
          criterion: nn.Module,
          fine_tuning_loss: nn.Module, # here we are using Center Loss as our fine_tuning_loss
          loss_weight: float,          # lambda: weight of Center Loss relative to cross-entropy
          scheduler: optim.lr_scheduler._LRScheduler,
          scaler: torch.cuda.amp.GradScaler,
          device):

    model.train()

    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        optimizer_center_loss.zero_grad()

        images, labels = images.to(device), labels.to(device)

        with torch.cuda.amp.autocast():
            outputs, feats = model(images, return_feats=True)
            loss0 = criterion(outputs, labels)                     # cross-entropy loss
            loss1 = loss_weight * fine_tuning_loss(feats, labels)  # weighted Center Loss

        # Backward loss0 to calculate gradients for the model parameters.
        # retain_graph=True keeps the forward graph alive, because loss1 was computed
        # from the same forward pass and is backpropagated through it next.
        scaler.scale(loss0).backward(retain_graph=True)

        # Backward loss1: gradients for the centers, plus its contribution to the model parameters
        scaler.scale(loss1).backward()

        # Update fine tuning loss' parameters: undo the loss_weight scaling on the center
        # gradients, so that the centers move according to alpha alone
        for parameter in fine_tuning_loss.parameters():
            parameter.grad.data *= (1.0 / loss_weight)

        # Step every optimizer through the scaler, then update the scale once per iteration
        scaler.step(optimizer_center_loss)
        scaler.step(optimizer)
        scaler.update()

        # Called once per batch, so the scheduler should be configured in iterations
        scheduler.step()
        # if you use a scheduler to schedule your learning rate for Center Loss
        # scheduler_center_loss.step()

        del images, labels, outputs, feats, loss0, loss1
        torch.cuda.empty_cache()

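The training loop assumes that `model(images, return_feats=True)` returns both the logits and the features fed to the final classification layer (the $\pmb{x}_i$ of Center Loss). A minimal model with that interface might look like the sketch below; the class name `ToyClassifier`, the layer sizes, and the flattened-vector input are illustrative assumptions, not the homework's architecture.

```python
import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    """Hypothetical model exposing the `return_feats` interface assumed by train()."""

    def __init__(self, in_dim=16, feat_dim=8, num_classes=4):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.cls_layer = nn.Linear(feat_dim, num_classes)  # final classification layer

    def forward(self, x, return_feats=False):
        feats = self.backbone(x)         # features used by Center Loss
        outputs = self.cls_layer(feats)  # logits used by cross-entropy
        if return_feats:
            return outputs, feats
        return outputs

model = ToyClassifier()
logits, feats = model(torch.randn(5, 16), return_feats=True)
print(logits.shape, feats.shape)  # torch.Size([5, 4]) torch.Size([5, 8])
```

With this interface, `feat_dim` of `CenterLoss` must match the width of the features returned by the backbone.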