In [44]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import recall_score

torch.manual_seed(0)
np.random.seed(0)

normal = np.random.randn(990,2)
# 1% is cancel data
cancer = np.random.randn(10,2)*1.5 + np.array([0.5,0.5])
cancer[:3] = np.random.randn(3,2)*0.5
X = np.vstack([normal, cancer])   # (1000, 2)
X = np.hstack([
    X,
    (X[:,0] * X[:,1]).reshape(-1,1),
    (X[:,0]**2).reshape(-1,1)
])
y = np.array([0]*990 + [1]*10)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

torch.Size([1000, 4])


In [38]:
import torch.nn as nn

class Net(nn.Module):
  def __init__(self) -> None:
    super().__init__()
    self.linear = nn.Linear(4,1)

  def forward(self,x):
    return self.linear(x)


In [39]:
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.9, gamma=2):
  # penalty for wrong answer
  bce = F.binary_cross_entropy_with_logits(logits, targets, reduction = 'none')
  # the probability of right answer
  p = torch.sigmoid(logits)
  # the confidence of right answer(it is different from BCE,no log)
  pt = targets * p + (1-targets)*(1-p)
  loss = alpha * (1-pt)**gamma*bce
  return loss.mean()

In [40]:
def train(loss_fn):
  model = Net()
  opt = torch.optim.Adam(model.parameters(), lr= 0.001)
  epochs =100
  for _ in range(epochs):
    logits = model(X)
    loss = loss_fn(logits,y)
    opt.zero_grad()
    loss.backward()
    opt.step()

  with torch.no_grad():
    preds = (torch.sigmoid(model(X)) > 0.3).int().numpy()
    recall = recall_score(y.numpy(),preds)
  return recall

In [41]:
r_bce = train(lambda x,y : F.binary_cross_entropy_with_logits(x,y))
r_focal = train(lambda x,y : focal_loss(x,y))

In [42]:
print("BCE :",r_bce)
print("Focal:" ,r_focal)

BCE : 0.3
Focal: 0.9


In [43]:
gamma = [0.2,0.5,1.0,1.5,2.0]
result = {}
for g in gamma :
  loss = train(lambda x,y : focal_loss(x,y,gamma = g))
  result[g] = loss

for k, v in result.items():
  print("when gamma is ",k, " recall score is ", v)

when gamma is  0.2  recall score is  0.5
when gamma is  0.5  recall score is  0.6
when gamma is  1.0  recall score is  1.0
when gamma is  1.5  recall score is  0.5
when gamma is  2.0  recall score is  0.4
