# Custom Loss functions

> Any custom loss functions should be defined here.


In [None]:
#| default_exp losses

In [None]:
#| hide
from nbdev.showdoc import *  # type: ignore # noqa: F403

In [None]:
#| export
import copy
import numpy as np
import torch, os, random
import torch.nn as nn
import torch.nn.functional as F
from fastcore.utils import * # type: ignore # noqa: F403


In [None]:
#| export
class AnchorLoss(nn.Module):
    """Loss module that owns one learnable anchor vector per class.

    The anchors start as a standard-normal draw that is L2-normalised along
    dim 1 (the `F.normalize` defaults), so every anchor row begins with unit
    norm. The loss computation is attached separately as `forward` with
    fastcore's `@patch`.

    Args:
        num_classes: number of classes, i.e. rows of `anchor`.
        feature_num: length of each anchor vector, i.e. columns of `anchor`.
    """

    def __init__(self, num_classes, feature_num):
        super().__init__()
        self.num_classes, self.feature_num = num_classes, feature_num
        # Unit-norm random starting points; they are trained like any weight.
        initial_anchors = F.normalize(torch.randn(num_classes, feature_num))
        self.anchor = nn.Parameter(initial_anchors, requires_grad=True)

In [None]:
#| export
@patch
def forward(self: AnchorLoss, feature, _target, Lambda = 0.1):
    """Weighted mean (over classes) of squared distances to the class anchors.

    For every sample, the squared Euclidean distance between its feature
    vector and the anchor of its class is divided by the number of samples of
    that class in the batch. Summing these terms therefore gives, per class
    present, the mean distance of its members. That sum is divided by
    `num_classes` and scaled by `Lambda`.

    Args:
        feature: per-sample features; expected to be (batch, feature_num) so
            it lines up row-for-row with the gathered anchors.
        _target: 1-D class labels with values in [0, num_classes); integer or
            float (it is cast with `.long()`).
        Lambda: weight applied to the final loss.

    Returns:
        A scalar tensor on `feature`'s device.

    Raises:
        ValueError: if `_target` contains NaN.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if torch.isnan(_target).any():
        raise ValueError("Found NaN in _target!")
    device = feature.device
    labels = _target.long().to(device)
    # Follow the input's device instead of hard-coding `.cuda()`, so the loss
    # also runs on CPU and on non-default GPUs; `.to` keeps the autograd link
    # back to the `anchor` Parameter.
    centre = self.anchor.to(device).index_select(dim=0, index=labels)
    # Exact per-class sample counts; `minlength` guarantees one slot per class.
    counter = torch.bincount(labels, minlength=self.num_classes)
    # Every present label has count >= 1; the clamp is kept as a guard against
    # division by zero.
    count = counter[labels].clamp(min=1).float()
    # Squared Euclidean distance between each sample and its class anchor.
    sq_dist = (feature - centre).pow(2).sum(dim=1)
    # Dividing by the class size turns the per-class sums into per-class means.
    per_sample = sq_dist / count
    class_mean = per_sample.sum() / self.num_classes
    return Lambda * class_mean

In [None]:
#| hide
import torch
# A fake batch of 32 class labels drawn uniformly from {0, 1, 2}.
y = torch.randint(low=0, high=3, size=(32,))
y

tensor([0, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 2, 2, 2, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1,
        2, 0, 1, 2, 0, 0, 2, 2])

In [None]:
# Cast the labels to int64. `.long()` keeps the tensor on its current device,
# whereas the legacy `.type(torch.LongTensor)` always yields a CPU tensor.
labels = y.long()
labels

tensor([0, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 2, 2, 2, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1,
        2, 0, 1, 2, 0, 0, 2, 2])

In [None]:
# Float copy of the labels (float32), mirroring `_target.float()` in `forward`.
ys = labels.to(torch.float32)
ys

tensor([0., 0., 1., 0., 1., 0., 0., 0., 2., 0., 1., 2., 2., 2., 1., 1., 1., 0.,
        1., 0., 1., 1., 1., 1., 2., 0., 1., 2., 0., 0., 2., 2.])

In [None]:
anchor = torch.nn.Parameter(data=torch.randn(3, 512))  # 3 classes x 512 features

In [None]:
# Gather each sample's class anchor: row i of `centre` is anchor[ys[i]].
centre = anchor[ys.long()]
centre.shape

torch.Size([32, 512])

In [None]:
# Histogram of the labels: 3 bins spanning label values 0..2, one per class.
counter = ys.histc(bins=3, min=0, max=2)
counter

tensor([12., 12.,  8.])

In [None]:
# For each sample, look up how many samples share its class.
count = counter.index_select(0, ys.long())
count

tensor([12., 12., 12., 12., 12., 12., 12., 12.,  8., 12., 12.,  8.,  8.,  8.,
        12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,  8., 12., 12.,  8.,
        12., 12.,  8.,  8.])

In [None]:
torch.equal(count.clamp(min=1), count)  # True when no count is below 1

True

In [None]:
feature = torch.randn(32, 512)  # stand-in batch of 32 feature vectors
# Per-sample offset from the matching class anchor.
centre_dis = torch.sub(feature, centre)
centre_dis.shape

torch.Size([32, 512])

In [None]:
#| hide
# Run nbdev's export step; presumably this writes the `#| export` cells of this
# notebook to the module named by `#| default_exp` (nbdev convention), so
# confirm against the nbdev docs.
import nbdev
nbdev.nbdev_export()