In [1]:
from metrics import *

In [10]:
shape = (1, 3, 512, 512)
positive = torch.ones(shape)
negative = torch.zeros(shape)
ones = torch.ones(shape) * 12

alpha = 0.5
gamma = 2.0
reduction = 'mean'

print(focal(ones, positive, alpha, gamma, reduction))
print(focal(-ones, negative, alpha, gamma, reduction))
print(focal(ones, negative, alpha, gamma, reduction))
print(focal(-ones, positive, alpha, gamma, reduction))


tensor([[1.1910e-16, 1.1910e-16, 1.1910e-16]])
tensor([[1.1682e-16, 1.1682e-16, 1.1682e-16]])
tensor([[5.9999, 5.9999, 5.9999]])
tensor([[5.9999, 5.9999, 5.9999]])


In [6]:
# this is symmetric:
assert alpha == 0.5 and focal(ones, zeros, alpha, gamma, reduction) == focal(-ones, ones, alpha, gamma, reduction)

tensor(1.)
tensor(1.)
tensor(-1.)
tensor(-1.)


# Let's try to replicate the focal loss from the paper, and see if it is indeed non-commutative

In [16]:
def paper_focal_single(outputs, targets, alpha, gamma, reduction):
    """Compute the focal loss for a single sample
    
Args:
        outputs (torch.Tensor): the outputs of the model, of shape HxW
        targets (torch.Tensor): the targets, of shape HxW
        alpha (float): Weighting factor in the range (0,1) to balance positive vs negative examples, or -1 for ignore. default: 0.25 to bias towards positive.
        gamma (float): the gamma parameter
        reduction (str): the reduction method
    """
    # turn targets from 0,1 to -1,1
    targets = targets * 2 - 1
    print(targets)
    # take the sigmoid of the outputs, to reflect the probability of the positive class:
    p = torch.sigmoid(outputs)
    print(p)
    # set p to be p if target = 1, 1-p if target is -1
    p = p * (targets == 1) + (1 - p) * (targets == -1)
    print(p)
    # compute the focal loss, by elementwise application of the function
    # -(1 - p)^gamma * log(p)
    loss = - (1 - p) ** gamma * torch.log(p)
    print(p)
    # apply the alpha factor, having 1-alpha for the positive class, alpha for the negative class
    alphas = (1-alpha) * (targets == 1) + alpha * (targets == -1)
    print(alphas)
    loss = alphas * loss
    # apply the reduction
    if reduction == 'mean':
        loss = loss.mean()
    elif reduction == 'sum':
        loss = loss.sum()
    return loss
    
    
    

In [17]:
# paper_focal_single(ones, ones, alpha, gamma, reduction)

print(paper_focal_single(ones, ones, alpha, gamma, reduction))
print()
print(paper_focal_single(zeros, -ones, alpha, gamma, reduction))
print()
print(paper_focal_single(ones, zeros, alpha, gamma, reduction))
print(paper_focal_single(zeros, ones, alpha, gamma, reduction))

tensor([[1.]])
tensor([[0.7311]])
tensor([[0.7311]])
tensor([[0.7311]])
tensor([[0.5000]])
tensor([[0.0113]])

tensor([[-1.]])
tensor([[0.5000]])
tensor([[0.5000]])
tensor([[0.5000]])
tensor([[0.5000]])
tensor([[0.0866]])

tensor([[-1.]])
tensor([[0.7311]])
tensor([[0.2689]])
tensor([[0.2689]])
tensor([[0.5000]])
tensor([[0.3509]])
tensor([[1.]])
tensor([[0.5000]])
tensor([[0.5000]])
tensor([[0.5000]])
tensor([[0.5000]])
tensor([[0.0866]])


In [18]:
torch.sigmoid(torch.Tensor([-1]))

tensor([0.2689])