# Tutorial Serise One : BCE, CrossEntropy And Focal Loss

## Meaning

## BCE: binary cross entropy
## NLL: negative log loss
## focal loss: general format of cross entropy

In [153]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Problem Scenario

### Normally the output of the model will be a unbounded number, when doing classification problem, we want to add a proper loss, such that the model could be trained

In [154]:
batch_size, n_classes = 5, 2
logits  = torch.rand(batch_size,n_classes)

In [155]:
# first applying softmax
# higher values have higher prob, which makes sense
probs = torch.softmax(logits,dim=-1)

In [156]:
probs

tensor([[0.6811, 0.3189],
        [0.4383, 0.5617],
        [0.3137, 0.6863],
        [0.4921, 0.5079],
        [0.6242, 0.3758]])

In [157]:
## softmax implementation
def cus_softmax(x):return x.exp()/x.exp().sum(-1).unsqueeze(1)

In [158]:
cus_softmax(logits)

tensor([[0.6811, 0.3189],
        [0.4383, 0.5617],
        [0.3137, 0.6863],
        [0.4921, 0.5079],
        [0.6242, 0.3758]])

### protential problem lying on the exp operation

In [159]:
y = torch.randint(high=n_classes,size=(5,))
y

tensor([1, 0, 1, 1, 0])

### softmax + nl (softmax and negative likelihood)

In [160]:
def cus_cross_entropy(inputs,targets):return -probs[range(batch_size),y].log().mean()

In [161]:
cus_cross_entropy(probs,y)

tensor(0.6986)

### log_softmax + nll

In [162]:
def cus_log_softmax(x):return x - x.exp().sum(-1).log().unsqueeze(1)

In [163]:
def nll(inputs,targets): return -inputs[range(batch_size),y].mean()

In [164]:
nll(cus_log_softmax(logits),y)

tensor(0.6986)

### The reason why doing this is:
1, Reducing computation.
2, Numerically more stable, why?

In [165]:
# above example is how F.cross_entropy fun being implemented
F.cross_entropy(logits,y.reshape(-1))

tensor(0.6986)

We intentionally choose case with 2 classes, so that we can compare directly to binary_crossentropy, which is ensentially same thing

In [166]:
logits

tensor([[0.8145, 0.0557],
        [0.2166, 0.4648],
        [0.0403, 0.8233],
        [0.3373, 0.3691],
        [0.7504, 0.2432]])

In [167]:
probs

tensor([[0.6811, 0.3189],
        [0.4383, 0.5617],
        [0.3137, 0.6863],
        [0.4921, 0.5079],
        [0.6242, 0.3758]])

In [168]:
F.binary_cross_entropy(probs[range(batch_size),[1]*batch_size],y.type(torch.float32))

tensor(0.6986)

### Notice the only difference in these two scenario is one we have two neurons and one we only have one

In [107]:
F.binary_cross_entropy??

In [108]:
F.binary_cross_entropy_with_logits??

In [None]:
def sigmoid(x): return (1 + (-x).exp()).reciprocal()

### Focal Loss

In brief: -y*log(prob) --> -y(1-prob)^gamma * log(prob)

In [20]:
# binary cases
x = torch.rand(3,1)
print(x)
F.logsigmoid(x)

tensor([[0.5827],
        [0.7705],
        [0.4685]])


tensor([[-0.4436],
        [-0.3803],
        [-0.4861]])

In [22]:
probs = torch.sigmoid(x)
probs

tensor([[0.6417],
        [0.6836],
        [0.6150]])

In [31]:
y = torch.randint(2,size=(3,1))
y

tensor([[0],
        [1],
        [1]])

In [32]:
F.binary_cross_entropy(probs,y.type(torch.float32))

tensor(0.6309)

In [None]:
def cus_binary_entropy(inputs,targets): return torch.log(probs)

In [39]:
F.logsigmoid(torch.tensor([-300.0]))

tensor([-300.])

In [44]:
torch.sigmoid(torch.tensor([-300.0]))

tensor([0.])

In [77]:
s_logits = torch.rand(batch_size,1)
s_logits

tensor([[0.2316],
        [0.4442],
        [0.0177],
        [0.3277],
        [0.0546]])

In [89]:
y = torch.randint(high=2,size=(batch_size,1))
y

tensor([[1],
        [0],
        [1],
        [0],
        [0]])

In [114]:
F.binary_cross_entropy_with_logits(s_logits,y.type(torch.float32))

tensor(0.7599)

In [116]:
-(F.logsigmoid(s_logits)*y.type(torch.float32) + F.logsigmoid(-s_logits)*(1-y.type(torch.float32))).mean()

tensor(0.7599)

In [121]:
gamma = 2
-(F.logsigmoid(s_logits)*(torch.sigmoid(-s_logits)**gamma)*y.type(torch.float32) + F.logsigmoid(-s_logits)*(torch.sigmoid(s_logits)**gamma)*(1-y.type(torch.float32))).mean()


tensor(0.2231)

In [186]:
def focal(s_logits,y,gamma=2):
    
    y = y.type(torch.float32)
    ones_loss = -F.logsigmoid(s_logits)*(torch.sigmoid(-s_logits)**gamma)*y
    zeros_loss = -F.logsigmoid(-s_logits)*(torch.sigmoid(s_logits)**gamma)*(1-y)
    total_loss = (ones_loss + zeros_loss).mean()
    
    return total_loss


In [180]:
a = torch.sigmoid(torch.rand(5,5))
a

tensor([[0.5951, 0.5312, 0.5368, 0.7122, 0.5969],
        [0.5211, 0.6784, 0.7306, 0.7261, 0.7047],
        [0.7168, 0.7294, 0.6878, 0.7264, 0.5439],
        [0.6366, 0.7089, 0.5677, 0.5763, 0.5987],
        [0.5499, 0.5554, 0.6798, 0.5378, 0.7310]])

In [212]:
s_logits

tensor([[0.2316],
        [0.4442],
        [0.0177],
        [0.3277],
        [0.0546]])

In [220]:
y = torch.randint(2,(5,))
y

tensor([0, 0, 1, 0, 0])

In [224]:
focal(s_logits,y)

tensor(0.2269)

In [225]:
forward(s_logits,y.reshape(-1,1))

tensor(0.2509)

In [179]:
%timeit torch.sigmoid(torch.rand(5,5))

4.27 µs ± 180 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [181]:
%timeit 1-a

5.25 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [183]:
focal(s_logits,y)

tensor(0.1745)

In [195]:
gamma = 2
def forward( logit, target):
    target = target.float()
    max_val = (-logit).clamp(min=0)
    loss = logit - logit * target + max_val + \
           ((-max_val).exp() + (-logit - max_val).exp()).log()

    invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
    loss = (invprobs * gamma).exp() * loss
    if len(loss.size())==2:
        loss = loss.sum(dim=1)
    return loss.mean()


In [187]:
s_logits

tensor([[0.2316],
        [0.4442],
        [0.0177],
        [0.3277],
        [0.0546]])

In [189]:
y

tensor([1, 0, 1, 1, 0])

In [194]:
F.binary_cross_entropy_with_logits(s_logits,y.reshape(-1,1).type(torch.float32))

tensor(0.6943)

In [208]:
forward(s_logits,y)

tensor(0.8726)

In [209]:
focal(s_logits,y)

tensor(0.1745)

In [130]:
%timeit forward(s_logits,y)

90.8 µs ± 3.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [131]:
%timeit focal(s_logits)

63.1 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [171]:
%timeit focal(s_logits,y)

61.8 µs ± 2.93 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [199]:
torch.sigmoid(torch.tensor([400.0]))

tensor([1.])

In [206]:
focal(torch.tensor([0.1]),torch.tensor([1]))

tensor(0.1454)

In [207]:
forward(torch.tensor([0.1]),torch.tensor([1]))

tensor(0.1454)

In [113]:
%timeit y.type(torch.float32)

2.14 µs ± 57.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [105]:
F.logsigmoid(torch.tensor([0.]))

tensor([-0.6931])

In [None]:
s_probs = torch.sigmoid(s_logi)

In [78]:
log_probs = F.logsigmoid(s_logits)
log_probs

tensor([[-0.5840],
        [-0.4955],
        [-0.6843],
        [-0.5426],
        [-0.6662]])

In [79]:
torch.sigmoid(s_logits)

tensor([[0.5576],
        [0.6093],
        [0.5044],
        [0.5812],
        [0.5136]])

In [81]:
torch.log(torch.sigmoid(-s_logits))

tensor([[-0.8156],
        [-0.9397],
        [-0.7020],
        [-0.8704],
        [-0.7208]])

In [82]:
F.logsigmoid(-s_logits)

tensor([[-0.8156],
        [-0.9397],
        [-0.7020],
        [-0.8704],
        [-0.7208]])

In [86]:
torch.log(1-torch.sigmoid(s_logits))

tensor([[-0.8156],
        [-0.9397],
        [-0.7020],
        [-0.8704],
        [-0.7208]])

In [127]:
F.logsigmoid(x)

tensor([-0.6584])

In [112]:
x - torch.log(1+x.exp())

tensor([-0.6584])

In [122]:
logits

tensor([[0.6431, 0.6980],
        [0.9518, 0.6247],
        [0.1736, 0.2046],
        [0.3819, 0.1205],
        [0.5808, 0.9027]])

In [126]:
logits.reciprocal()                            

tensor([[1.5550, 1.4326],
        [1.0506, 1.6007],
        [5.7602, 4.8872],
        [2.6182, 8.2975],
        [1.7218, 1.1078]])

In [None]:
max(-x, 0) + log(e^-(max(-x,0)) + e^(-x - max(-x,0)))

In [47]:
torch.tensor([1.0,2.]),torch.tensor([1.1,2.2])

(tensor([1., 2.]), tensor([1.1000, 2.2000]))

In [48]:
torch.max(x,torch.zeros(x.shape)) + torch.log()

tensor([1.1000, 2.2000])

In [63]:
def cus(x):
    t = torch.max(-x,torch.zeros(x.shape))
    print(t)
    tmp = t + torch.log((t).exp()+(-x-t).exp())
    return tmp

In [67]:
x = torch.tensor([0.58,-500])
x

tensor([   0.5800, -500.0000])

In [68]:
t = torch.max(x,torch.zeros(x.shape))

In [75]:
torch.sigmoid(-x) + torch.sigmoid(x)

tensor([1., 1.])

In [69]:
ta

tensor([0.5800, 0.0000])

In [None]:
torch.log()

In [70]:
s_logits=x

In [71]:
cus(s_logits)

tensor([  0., 500.])


tensor([0.4446,    inf])

In [72]:
F.logsigmoid(s_logits)

tensor([-4.4462e-01, -5.0000e+02])

In [73]:
torch.log(torch.sigmoid(s_logits))

tensor([-0.4446,    -inf])

In [113]:
def focal_loss(input, target, OHEM_percent=None):
    gamma = 2
    assert target.size() == input.size()

    max_val = (-input).clamp(min=0)
    loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
    invprobs = F.logsigmoid(-input * (target * 2 - 1))
    loss = (invprobs * gamma).exp() * loss

    if OHEM_percent is None:
        return loss.mean()
    else:
        OHEM, _ = loss.topk(k=int(10008 * OHEM_percent), dim=1, largest=True, sorted=True)
        return OHEM.mean()

In [121]:
(-torch.tensor(0.5)).clamp(min=0)

tensor(0.)

In [115]:
focal_loss(logits,y.reshape(-1,1))

AssertionError: 

In [116]:
logits.shape

torch.Size([5, 2])