In [8]:
import torch
import torch.nn.functional as F
from torch import tensor

def focal_loss(inputs, targets, alpha=0.25, gamma=2.0, background_class=None):
    """Multi-class focal loss with a separate weight for the background class.

    Each sample contributes ``w * (1 - p_t) ** gamma * CE`` where ``p_t`` is
    the softmax probability of the sample's target class, ``CE = -log(p_t)``
    and ``w`` is ``alpha`` for background targets and ``1 - alpha`` for every
    other target.

    Args:
        inputs: Raw (un-normalised) logits of shape ``(N, C)``.
        targets: Integer class indices of shape ``(N,)``.
        alpha: Weight applied to samples whose target is the background
            class; all remaining samples are weighted ``1 - alpha``.
        gamma: Exponent on ``(1 - p_t)``. Because ``1 - p_t <= 1``, larger
            values shrink the contribution of well-classified samples.
        background_class: Index of the background class. ``None`` (the
            default) selects the last class, ``C - 1``.

    Returns:
        A scalar tensor holding the mean of the per-sample losses.
    """
    ce = F.cross_entropy(inputs, targets, reduction="none")
    # Unweighted cross-entropy is -log(p_t), so p_t can be recovered from it
    # directly instead of running a second softmax over the logits.
    pt = torch.exp(-ce)

    if background_class is None:
        background_class = inputs.shape[-1] - 1

    # Foreground weight everywhere, then overwrite the background positions.
    is_background = targets == background_class
    alpha_factor = torch.full_like(pt, 1 - alpha).masked_fill(is_background, alpha)

    loss = alpha_factor * (1 - pt) ** gamma * ce
    return loss.mean()


In [14]:
import torch
import torch.nn.functional as F

# NOTE(review): this redefines the focal_loss from an earlier cell and shadows
# it once this cell runs; presumably repeated so the cell can run on its own.
def focal_loss(inputs, targets, alpha=0.25, gamma=2.0, background_class=None):
    """Multi-class focal loss with a separate weight for the background class.

    Each sample contributes ``w * (1 - p_t) ** gamma * CE`` where ``p_t`` is
    the softmax probability of the sample's target class, ``CE = -log(p_t)``
    and ``w`` is ``alpha`` for background targets and ``1 - alpha`` for every
    other target.

    Args:
        inputs: Raw (un-normalised) logits of shape ``(N, C)``.
        targets: Integer class indices of shape ``(N,)``.
        alpha: Weight applied to samples whose target is the background
            class; all remaining samples are weighted ``1 - alpha``.
        gamma: Exponent on ``(1 - p_t)``. Because ``1 - p_t <= 1``, larger
            values shrink the contribution of well-classified samples.
        background_class: Index of the background class. ``None`` (the
            default) selects the last class, ``C - 1``.

    Returns:
        A scalar tensor holding the mean of the per-sample losses.
    """
    ce = F.cross_entropy(inputs, targets, reduction="none")
    # Unweighted cross-entropy is -log(p_t), so p_t can be recovered from it
    # directly instead of running a second softmax over the logits.
    pt = torch.exp(-ce)

    if background_class is None:
        background_class = inputs.shape[-1] - 1

    # Foreground weight everywhere, then overwrite the background positions.
    is_background = targets == background_class
    alpha_factor = torch.full_like(pt, 1 - alpha).masked_fill(is_background, alpha)

    loss = alpha_factor * (1 - pt) ** gamma * ce
    return loss.mean()

# --- Scenario 1: every row scores its labelled class highest, by a wide margin ---
targets = torch.tensor([0, 1, 2])
logits = torch.tensor(
    [
        [2.0, 0.5, 0.1],  # peak on class 0
        [0.1, 2.0, 0.5],  # peak on class 1
        [0.2, 0.3, 2.5],  # peak on class 2 (background)
    ],
    requires_grad=True,
)
loss = focal_loss(logits, targets, alpha=0.25, gamma=2.0)
print("Focal Loss:", loss.item())

# --- Scenario 2: same labels, flatter logits ---
less_confident_logits = torch.tensor(
    [
        [1.0, 1.0, 0.1],
        [0.5, 1.0, 0.5],
        [0.5, 0.5, 1.5],
    ],
    requires_grad=True,
)
less_confident_loss = focal_loss(less_confident_logits, targets, alpha=0.25, gamma=2.0)
print("Less Confident Focal Loss:", less_confident_loss.item())

# --- Scenario 3: background-only batch, evaluated at two alpha values ---
bg_targets = torch.tensor([2, 2, 2])
bg_logits = torch.tensor(
    [
        [0.2, 0.3, 2.5],
        [0.1, 0.2, 2.7],
        [0.0, 0.1, 2.8],
    ],
    requires_grad=True,
)
# Evaluated in order: alpha=0.25 first, then alpha=0.75.
bg_loss_low_alpha, bg_loss_high_alpha = (
    focal_loss(bg_logits, bg_targets, alpha=alpha_value, gamma=2.0)
    for alpha_value in (0.25, 0.75)
)
print("BG Loss (alpha=0.25):", bg_loss_low_alpha.item())
print("BG Loss (alpha=0.75):", bg_loss_high_alpha.item())

# ✅ Sanity checks on the three scenarios
assert loss.item() > 0, "Loss should be positive"
assert loss < less_confident_loss, "Less confident predictions should have higher loss"
assert bg_loss_high_alpha > bg_loss_low_alpha, "Lower alpha should reduce background loss"

print("All focal loss tests passed ✅")


Focal Loss: 0.012160624377429485
Less Confident Focal Loss: 0.14292515814304352
BG Loss (alpha=0.25): 0.0008353290613740683
BG Loss (alpha=0.75): 0.0025059871841222048
All focal loss tests passed ✅


In [15]:
# Confidently wrong case: in every row the largest logit sits on a class other
# than the label. `loss` (confident and correct) is the baseline computed in
# the earlier test cell, so that cell must run first.
wrong_targets = torch.tensor([0, 1, 2])
wrong_logits = torch.tensor(
    [
        [0.1, 2.5, 0.2],  # label 0, peak on class 1
        [2.5, 0.1, 0.2],  # label 1, peak on class 0
        [2.5, 0.1, 0.2],  # label 2 (background), peak on class 0
    ],
    requires_grad=True,
)

wrong_loss = focal_loss(wrong_logits, wrong_targets, alpha=0.25, gamma=2.0)
print("Wrong Prediction Focal Loss:", wrong_loss.item())

# 🔍 Show the confident-and-correct baseline next to it, then compare
print("Focal Loss (correct):", loss.item())
assert loss < wrong_loss, "Loss should increase when predictions are very wrong"

Wrong Prediction Focal Loss: 1.271703839302063
Focal Loss (correct): 0.012160624377429485


In [12]:
# Background-only batch: every label is class 2, the last class, which is the
# one focal_loss weights with `alpha`.
bg_targets = torch.tensor([2, 2, 2])
bg_logits = torch.tensor(
    [
        [0.2, 0.3, 2.5],
        [0.1, 0.2, 2.7],
        [0.0, 0.1, 2.8],
    ],
    requires_grad=True,
)

bg_only_loss = focal_loss(bg_logits, bg_targets, alpha=0.25, gamma=2.0)
print("Background-Only Loss:", bg_only_loss.item())

# Same batch at alpha=0.75: the background weight rises from 0.25 to 0.75, so
# the loss should grow. Despite its name, `fg_heavy_loss` is still computed on
# background-only targets.
fg_heavy_loss = focal_loss(bg_logits, bg_targets, alpha=0.75, gamma=2.0)
print("Background-Only Loss with alpha=0.75:", fg_heavy_loss.item())

assert fg_heavy_loss.item() > bg_only_loss.item(), "Lower alpha should reduce background loss"


Background-Only Loss: 0.0008353290613740683
Background-Only Loss with alpha=0.75: 0.0025059871841222048
