In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class DistanceLoss(nn.Module):
    """Cross-entropy + weighted center loss + MSE on predicted class indices.

    ``total = CE(logits, gt) + lambda_center_loss * center_loss(features, gt)
    + MSE(predicted_class, gt)``

    Labels may be 1-D ``(N,)``, with ``features`` ``(N, feat_d)`` and
    ``logits`` ``(N, C)``. They may also carry extra trailing dimensions, e.g.
    ``gt`` ``(N, L)`` with ``features`` ``(N, L, feat_d)`` and ``logits``
    ``(N, C, L)``. The class dimension of ``logits`` is always dim 1, as
    ``nn.CrossEntropyLoss`` requires.

    Args:
        num_classes: number of classes; one learnable center is kept per class.
        feat_d: dimensionality of each class center (must match the last
            dimension of ``features``).
        lambda_center_loss: weight applied to the center-loss term.
    """

    def __init__(self, num_classes, feat_d=500, lambda_center_loss=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_center_loss = lambda_center_loss
        self.softmax_loss = nn.CrossEntropyLoss()
        # One learnable center per class, randomly initialised.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_d))

    def center_loss(self, features, gt):
        """Mean squared Euclidean distance between each feature and its class center.

        Args:
            features: tensor of shape ``gt.shape + (feat_d,)``.
            gt: integer class labels of any shape.

        Returns:
            Scalar tensor: the summed squared distances divided by the number
            of labels.
        """
        centers = self.centers[gt].to(features.device)
        dist = (features - centers).pow(2).sum(dim=-1)
        # Normalise by the total label count instead of gt.size(0) * gt.size(1),
        # so 1-D labels work too; the result is identical for 2-D labels.
        return dist.sum() / gt.numel()

    def _ordinal_mse_loss(self, logits, gt):
        """MSE between the predicted class index and the true class index.

        The forward value uses the arg-max class, as before. Arg-max indices
        carry no gradient, so a straight-through estimator routes gradients
        through the softmax-expected class index instead.
        """
        pred = F.softmax(logits, dim=1)
        _, hard = torch.max(pred, dim=1)
        # Expected class index under the softmax, sum_c p_c * c, with the
        # class-index vector broadcast along dim 1 of pred.
        shape = [1, -1] + [1] * (pred.dim() - 2)
        class_idx = torch.arange(
            pred.size(1), device=pred.device, dtype=pred.dtype
        ).view(*shape)
        soft = (pred * class_idx).sum(dim=1)
        target = gt.float()
        # hard + (soft - soft.detach()) equals `hard` in the forward pass but
        # back-propagates through `soft` (straight-through estimator).
        y_pred = hard.to(target.dtype) + (soft - soft.detach()).to(target.dtype)
        return torch.mean((y_pred - target) ** 2)

    def forward(self, features, logits, gt):
        """Return the combined scalar loss.

        Args:
            features: per-label embeddings, shape ``gt.shape + (feat_d,)``.
            logits: class scores with the class dimension at dim 1.
            gt: integer class labels.
        """
        softmax_loss = self.softmax_loss(logits, gt)
        ctr_loss = self.center_loss(features, gt)
        mse_loss = self._ordinal_mse_loss(logits, gt)
        return softmax_loss + self.lambda_center_loss * ctr_loss + mse_loss