In [5]:
import torch
from torch.nn import functional as F
from torch import nn,Tensor

In [109]:
class PartialBCE(nn.Module):
    def __init__(self, alpha, beta, gamma):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, outputs: Tensor, targets: Tensor, weights=None):
        masks=targets.detach().clone()
        # generate mask. only 1 and -1 are valid labels.
        masks[masks==-1]=1
        # set -1 as 0 to fit standard BCE loss
        targets[targets==-1]=0

        batch_size, num_class = outputs.size()
        criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
        # BCEloss needs targets be float
        loss = criterion(outputs, targets.float())
        if weights is not None:
            loss = loss * weights
        # masks==0 are masked
        loss = loss * masks.float()
        known_ys = masks.float().sum(1)
        p_y = known_ys / num_class
        g_p_y = self.alpha * (p_y**self.gamma) + self.beta
        print(g_p_y,num_class,loss.sum(1))
        loss = ((g_p_y / num_class) * loss.sum(1)).mean()
        return loss

In [110]:
pred=torch.tensor([[100,-100,2]],dtype=torch.float,requires_grad=True)
tar=torch.tensor([[1,1,0]])
pbce=PartialBCE(-4.45,5.45,1.)
# pbce=nn.BCELoss()

In [111]:
loss=pbce(pred,tar)
loss.backward()
pred.grad

tensor([2.4833]) 3 tensor([100.], grad_fn=<SumBackward1>)


tensor([[ 0.0000, -0.8278,  0.0000]])