In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import collections

In [2]:
# Toy two-class fixtures used by the manual matrix products in the next cells.
sample_pred = torch.tensor([0.8, 0.2])       # length-2 prediction vector
sample_target = torch.tensor([1.0])          # length-1 target tensor
sample_target2 = torch.tensor([1.0, 0.0])    # length-2 target vector
penalty_matrix = 1 - torch.eye(2)            # zero diagonal, ones off the diagonal

In [3]:
sample_pred.shape, sample_target.shape, penalty_matrix.shape

(torch.Size([2]), torch.Size([1]), torch.Size([2, 2]))

In [4]:
penalty_matrix @ sample_pred

tensor([0.2000, 0.8000])

In [5]:
sample_target2 @ (penalty_matrix @ sample_pred)

tensor(0.2000)

In [6]:
sample_pred2 = torch.Tensor([0,1])
sample_target2 @ (penalty_matrix @ sample_pred2)

tensor(1.)

In [7]:
# `.T` removed: on a 1-D tensor it is a no-op (and deprecated for non-2-D tensors);
# `@` already treats a 1-D left operand as a row vector, so the value is unchanged.
torch.Tensor([0,1]) @ (penalty_matrix @ sample_target2)

tensor(1.)

In [8]:
def bilinear_loss_calc(preds, targets, FP_pen = 1, FN_pen = 0.5):
    """Bilinear misclassification loss for two-class predictions.

    The loss is ``targets^T . P . softmax(preds)`` with the penalty matrix
    ``P = [[0, FP_pen], [FN_pen, 0]]``: when the target is class 0 the loss is
    ``FP_pen * p[1]``, and when the target is class 1 it is ``FN_pen * p[0]``.

    Args:
        preds: Raw scores with the two classes on the last axis, either a
            single sample of shape ``(2,)`` or a batch of shape ``(..., 2)``.
            Softmax is applied over the last axis.
        targets: One-hot targets with exactly the same shape as ``preds``.
        FP_pen: Penalty weight on class-1 probability when the target is class 0.
        FN_pen: Penalty weight on class-0 probability when the target is class 1.

    Returns:
        A tensor of shape ``preds.shape[:-1]``: a 0-dim tensor for a single
        sample, or one loss per sample for a batch (reduce with ``.mean()`` or
        ``.sum()`` as needed).

    Raises:
        ValueError: If ``targets`` and ``preds`` have different shapes. This
            guards against silent broadcasting, e.g. ``(N, 1, 2)`` one-hot
            targets paired with ``(N, 2)`` predictions.
    """
    if targets.shape != preds.shape:
        raise ValueError(
            f"targets shape {tuple(targets.shape)} must match preds shape {tuple(preds.shape)}"
        )
    # Normalise over the class axis (last), never over the batch axis.
    probs = F.softmax(preds, dim=-1)
    # new_tensor keeps the penalty matrix on the same dtype/device as the probabilities.
    penalty_matrix = probs.new_tensor([[0, FP_pen], [FN_pen, 0]])
    # per_class_cost[..., i] = sum_j penalty_matrix[i, j] * probs[..., j]
    per_class_cost = probs @ penalty_matrix.T
    # The one-hot target picks the cost row belonging to the true class.
    return (targets.to(probs.dtype) * per_class_cost).sum(dim=-1)


In [9]:
# targets=[1,0] selects penalty row [0, FP_pen], so the loss is FP_pen * softmax(preds)[1].
preds = torch.Tensor([1,0])
targets = torch.Tensor([1,0])
bilinear_loss_calc(preds, targets)

tensor(0.2689)

In [10]:
# Same target, scores closer together: more softmax mass lands on index 1, so the loss grows.
preds = torch.Tensor([.9,0.1])
targets = torch.Tensor([1,0])
bilinear_loss_calc(preds, targets)

tensor(0.3100)

In [11]:
# Scores favour index 1 while the target is index 0: loss = FP_pen * softmax(preds)[1].
preds = torch.Tensor([0,1])
targets = torch.Tensor([1,0])
bilinear_loss_calc(preds, targets)

tensor(0.7311)

In [12]:
# targets=[0,1] selects penalty row [FN_pen, 0], so the loss is FN_pen * softmax(preds)[0].
preds = torch.Tensor([1,0])
targets = torch.Tensor([0,1])
bilinear_loss_calc(preds, targets)

tensor(0.3655)

In [16]:
# Extreme scores: softmax(preds)[0] saturates to ~1, so the loss hits its ceiling FN_pen (0.5 by default).
preds = torch.Tensor([500,-10])
targets = torch.Tensor([0,1])
bilinear_loss_calc(preds, targets)

tensor(0.5000)

In [18]:
# Extreme scores the other way: softmax(preds)[0] is ~0, so the loss goes to 0.
preds = torch.Tensor([-10,500])
targets = torch.Tensor([0,1])
bilinear_loss_calc(preds, targets)

tensor(0.)

In [19]:
# Batch of 4 rows x 2 columns of raw (un-normalised) scores for the batched experiments.
sample_preds = torch.Tensor([[-0.2944,  0.1242],
        [ 0.5356,  0.5005],
        [-1.4531, -1.0283],
        [ 0.4014,  0.7788]] )
sample_preds.shape

torch.Size([4, 2])

In [83]:
F.softmax(sample_preds, dim = 1)

tensor([[0.3969, 0.6031],
        [0.5088, 0.4912],
        [0.3954, 0.6046],
        [0.4068, 0.5932]])

In [22]:
# One target value per row of sample_preds, stored as a (4, 1) column.
sample_targets = torch.Tensor([[1], [0], [1], [1]])

In [23]:
sample_targets.shape

torch.Size([4, 1])

In [24]:
# F.one_hot needs an int64 index tensor and appends the class axis after the existing
# (4, 1) shape, so the result is (4, 1, 2); it is converted back to float32 afterwards.
one_hot_targets = F.one_hot(sample_targets.to(torch.int64)).to(torch.float32)
one_hot_targets

tensor([[[0., 1.]],

        [[1., 0.]],

        [[0., 1.]],

        [[0., 1.]]])

In [25]:
one_hot_targets.shape

torch.Size([4, 1, 2])

In [27]:
# (2, 2) @ (2, 4): applies the global 0/1 penalty_matrix to the raw scores (no softmax here),
# giving one column per sample.
inner_part = penalty_matrix @ sample_preds.T

In [43]:
inner_part.T

tensor([[ 0.1242, -0.2944],
        [ 0.5005,  0.5356],
        [-1.0283, -1.4531],
        [ 0.7788,  0.4014]])

In [45]:
inner_part.T.shape, sample_targets.shape

(torch.Size([4, 2]), torch.Size([4, 1]))

In [39]:
inner_part.shape, one_hot_targets.shape

(torch.Size([2, 4]), torch.Size([4, 1, 2]))

In [34]:
# Batched matmul broadcasts inner_part: (4, 1, 2) @ (2, 4) -> (4, 1, 4).
# Row i pairs target i with every sample's column, not only with sample i.
one_hot_targets @ inner_part

tensor([[[-0.2944,  0.5356, -1.4531,  0.4014]],

        [[ 0.1242,  0.5005, -1.0283,  0.7788]],

        [[-0.2944,  0.5356, -1.4531,  0.4014]],

        [[-0.2944,  0.5356, -1.4531,  0.4014]]])

In [42]:
one_hot_targets.permute(1,2,0) @ inner_part.T

tensor([[[ 0.5005,  0.5356],
         [-0.1253, -1.3461]]])

In [47]:
# One copy of the 2x2 penalty matrix per row of sample_preds; the batch size is read
# from the data instead of being hard-coded to 4.
expanded_penalty_matrix = penalty_matrix.repeat(sample_preds.shape[0],1,1)

In [48]:
expanded_penalty_matrix.shape

torch.Size([4, 2, 2])

In [49]:
expanded_penalty_matrix

tensor([[[0., 1.],
         [1., 0.]],

        [[0., 1.],
         [1., 0.]],

        [[0., 1.],
         [1., 0.]],

        [[0., 1.],
         [1., 0.]]])

In [52]:
# (4, 2, 2) @ (2, 4) -> (4, 2, 4): every one of the 4 slices holds the same product.
expanded_inner_part = expanded_penalty_matrix @ sample_preds.T
expanded_inner_part

tensor([[[ 0.1242,  0.5005, -1.0283,  0.7788],
         [-0.2944,  0.5356, -1.4531,  0.4014]],

        [[ 0.1242,  0.5005, -1.0283,  0.7788],
         [-0.2944,  0.5356, -1.4531,  0.4014]],

        [[ 0.1242,  0.5005, -1.0283,  0.7788],
         [-0.2944,  0.5356, -1.4531,  0.4014]],

        [[ 0.1242,  0.5005, -1.0283,  0.7788],
         [-0.2944,  0.5356, -1.4531,  0.4014]]])

In [54]:
expanded_inner_part.shape, one_hot_targets.shape

(torch.Size([4, 2, 4]), torch.Size([4, 1, 2]))

In [60]:
# (4, 1, 2) @ (4, 2, 4) -> (4, 1, 4): entry [i, 0, j] combines target i with sample j.
# NOTE(review): only the i == j entries pair a sample with its own target.
bil_loss = one_hot_targets @ expanded_inner_part

In [61]:
bil_loss

tensor([[[-0.2944,  0.5356, -1.4531,  0.4014]],

        [[ 0.1242,  0.5005, -1.0283,  0.7788]],

        [[-0.2944,  0.5356, -1.4531,  0.4014]],

        [[-0.2944,  0.5356, -1.4531,  0.4014]]])

In [79]:
# Squeezes to a (4, 4) table and averages over axis 0 (the target axis), one value per sample column.
# NOTE(review): this mixes every target with every sample and works on raw scores, so it is
# not the per-sample value bilinear_loss_calc returns — TODO confirm intent.
bil_loss.squeeze().mean(axis = 0)

tensor([-0.1898,  0.5268, -1.3469,  0.4957])

In [80]:
np.array([-0.2944,  0.5356, -1.4531,  0.4014]).mean()

-0.202625

In [78]:
# Single-sample check: first row of sample_preds with its one-hot target squeezed from (1, 2) to (2,).
preds = sample_preds[0]
targets = one_hot_targets[0].squeeze()
bilinear_loss_calc(preds, targets)

tensor(0.1984)