In [None]:
import torch
from torch import nn


class CenterLoss(nn.Module):
    """Penalise the squared distance between each feature and its class center.

    Holds one learnable center per class in ``self.centers`` (shape
    ``(num_class, num_feature)``, initialised with ``torch.randn``).
    """

    def __init__(self, num_class=10, num_feature=2, **kwargs):
        """Initialize class centers.

        Args:
            num_class (int): number of classes (rows of ``centers``).
            num_feature (int): feature dimension (columns of ``centers``).
            **kwargs: ``num_classes`` and ``feat_dim`` are accepted as aliases
                for ``num_class`` and ``num_feature`` and take precedence when
                given. Any other keyword is ignored, as before.
        """
        super(CenterLoss, self).__init__()
        # Honour the alias names; previously CenterLoss(num_classes=5, feat_dim=3)
        # was swallowed by **kwargs and silently fell back to the (10, 2) defaults.
        num_class = kwargs.pop("num_classes", num_class)
        num_feature = kwargs.pop("feat_dim", num_feature)
        self.num_class = num_class
        self.num_feature = num_feature
        self.centers = nn.Parameter(
            torch.randn(self.num_class, self.num_feature))

    def forward(self, x, labels):
        """Compute the center loss for a batch.

        Args:
            x: feature matrix with shape (batch_size, num_feature).
            labels: integer class indices with shape (batch_size,); each entry
                selects one row of ``self.centers``.

        Returns:
            Mean over the last (batch) axis of the per-sample squared Euclidean
            distance to the assigned center, each distance clamped to
            [1e-12, 1e+12]. For a 2-D ``x`` this is a scalar tensor.
        """
        center = self.centers[labels]            # (batch_size, num_feature)
        dist = (x - center).pow(2).sum(dim=-1)   # (batch_size,)
        # Clamp bounds every per-sample term to [1e-12, 1e+12] before averaging.
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)

        return loss

In [None]:
# Seed torch so the randn-initialised centers (and every output below) are
# reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
center_loss = CenterLoss()

In [None]:
# Show the module's saved state (the `centers` parameter) as detached tensors.
center_loss.state_dict(keep_vars=False)

In [None]:
# NOTE(review): despite its name, this is the Gram matrix of pairwise inner
# products <c_i, c_j> between class centers, not a distance matrix.
intercenter_dist = center_loss.centers @ center_loss.centers.T
intercenter_dist

In [None]:
# Zero the diagonal (each center's inner product with itself). Done out-of-place
# so the tensor shown and cached by the previous cell's output is not mutated;
# re-running this cell yields the same values.
intercenter_dist = intercenter_dist - torch.diag(torch.diag(intercenter_dist))

In [None]:
# Detached handle on the class centers: cut from autograd, but sharing storage
# with center_loss.centers, so later in-place changes to the parameter show up here.
a = center_loss.centers.detach()

In [None]:
# Number of center coordinates that are >= 0 (NaNs count as False).
(a >= 0).sum()