In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import os
import random
from PIL import Image

We want to implement center loss: given a batch of features of shape `N x d` (where `d` is the embedding dimension) together with the class label of each example, we compute the loss $$L = \frac{1}{2} \sum_{j \in \text{Batch}} \lvert\lvert f_j - c_j \rvert\rvert^2$$ where $f_j$ is the feature vector of training example $j$ and $c_j$ is the learned center of the class that example $j$ belongs to.

In [None]:
class CenterLoss(nn.Module):
    """Center loss: half the summed squared distance from each feature to its class center.

    Computes ``L = 1/2 * sum_j ||f_j - c_j||^2`` over the batch, where ``c_j``
    is the learnable center of the class that example ``j`` belongs to. The
    centers are stored as a single ``(num_classes, embedding_dim)`` parameter
    and are trained together with the rest of the model.
    """

    def __init__(self, num_classes, embedding_dim):
        """Create one learnable center per class.

        Assume that the classes are 0 to num_classes-1 as usual.

        Args:
            num_classes: number of distinct class labels.
            embedding_dim: dimension ``d`` of the feature vectors.
        """
        # nn.Module must be initialised before any Parameter or submodule is
        # assigned; otherwise the attribute assignment raises AttributeError.
        super().__init__()
        self.centers = nn.Parameter(torch.empty((num_classes, embedding_dim), dtype=torch.float))
        nn.init.xavier_normal_(self.centers)
        # Sum (not the default mean over all N*d elements) so that forward()
        # returns exactly 1/2 * sum_j ||f_j - c_j||^2.
        self.mse = nn.MSELoss(reduction="sum")

    def forward(self, features, class_labels):
        """Return the center loss for one batch.

        Args:
            features: tensor of shape ``N x d``.
            class_labels: integer tensor of shape ``N`` with values in
                ``[0, num_classes)``.

        Returns:
            Scalar tensor ``1/2 * sum_j ||features[j] - centers[class_labels[j]]||^2``.
        """
        # Row j of batch_centers is the center of the class of example j.
        batch_centers = self.centers[class_labels]
        return 0.5 * self.mse(features, batch_centers)
        
        
        
        

In [6]:
# Scratch check: build a 3x5 parameter and fill it in place with Xavier-normal values.
a = nn.Parameter(torch.empty(3, 5))
nn.init.xavier_normal_(a)

Parameter containing:
tensor([[ 0.3442,  0.4859, -0.0677,  0.0039,  0.0724],
        [-0.0618, -1.0802,  0.3024, -0.5622, -1.2536],
        [-0.4044, -0.3095, -0.0891, -0.7248,  0.0829]], requires_grad=True)

In [8]:
# Indexing with an integer tensor picks one row of `a` per label, so a repeated
# label (here 2) yields a repeated row; the result stays attached to autograd.
labels = torch.tensor([0,2,1,2])
a[labels]

tensor([[ 0.3442,  0.4859, -0.0677,  0.0039,  0.0724],
        [-0.4044, -0.3095, -0.0891, -0.7248,  0.0829],
        [-0.0618, -1.0802,  0.3024, -0.5622, -1.2536],
        [-0.4044, -0.3095, -0.0891, -0.7248,  0.0829]],
       grad_fn=<IndexBackward0>)