# loss

> Focal loss, class-balanced loss weights, and a config-driven loss factory.

In [None]:
#| default_exp loss

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss computed on raw logits with binary cross-entropy.

    The element-wise loss is
    ``alpha_t * (1 - p_t) ** gamma * BCE(inputs, targets)`` where
    ``p_t = exp(-BCE)`` and ``alpha_t = alpha * targets + (1 - alpha) * (1 - targets)``.

    Args:
        gamma: Focusing exponent applied to ``(1 - p_t)``; ``0`` reduces the
            loss to alpha-weighted BCE.
        alpha: Weight of the positive term. A scalar is stored with shape
            ``(1,)``; a 1-D sequence/tensor (e.g. per-class weights from
            ``get_cb_weights``) broadcasts against ``targets`` with standard
            torch broadcasting (i.e. along the last dimension). ``None``
            disables alpha weighting.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction

        if alpha is None:
            self.register_buffer('alpha', None)
        else:
            # as_tensor + detach().clone() copies tensors without the
            # copy-construct warning torch.tensor() raises for tensor input;
            # float32 keeps integer alphas from producing an int64 buffer.
            alpha = torch.as_tensor(alpha, dtype=torch.float32).detach().clone()
            # Scalars are stored with shape (1,), as before.
            if alpha.ndim == 0:
                alpha = alpha.reshape(1)
            self.register_buffer('alpha', alpha)

    def extra_repr(self):
        return f"gamma={self.gamma}, reduction={self.reduction!r}"

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``.

        ``targets`` must be a float tensor of the same shape as ``inputs``
        (a requirement of ``binary_cross_entropy_with_logits``). Returns a
        scalar for ``'mean'``/``'sum'`` and the element-wise loss for ``'none'``.
        """
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # For hard 0/1 targets BCE = -log(p_t), so exp(-BCE) recovers p_t.
        p_t = torch.exp(-bce_loss)
        focal_weight = (1 - p_t) ** self.gamma

        if self.alpha is None:
            loss = focal_weight * bce_loss
        else:
            alpha = self.alpha.to(inputs.device)
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
        

In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Errors from the import (e.g. ``ModuleNotFoundError``) or from the
    attribute lookup (``AttributeError``) propagate unchanged.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
import numpy as np

def get_cb_weights(labels_count_list, beta=0.999):
    """Return class-balanced weights computed from per-class sample counts.

    Each class with ``n`` samples has effective number ``1 - beta ** n`` and
    raw weight ``(1 - beta) / (1 - beta ** n)``; the weights are then rescaled
    so they sum to the number of classes. ``beta = 0`` gives uniform weights;
    as ``beta`` approaches 1 the weights approach inverse class frequency.

    Args:
        labels_count_list: Per-class sample counts; every count must be positive.
        beta: Hyper-parameter in ``[0, 1)``.

    Returns:
        ``float32`` tensor with one weight per class.

    Raises:
        ValueError: if ``beta`` is outside ``[0, 1)`` or any count is not
            positive (either would otherwise yield inf/NaN weights).
    """
    if not 0.0 <= beta < 1.0:
        raise ValueError(f"beta must be in [0, 1), got {beta}")
    counts = np.asarray(labels_count_list, dtype=np.float64)
    if np.any(counts <= 0):
        raise ValueError("all class counts must be positive")

    effective_num = 1.0 - np.power(beta, counts)
    weights = (1.0 - beta) / effective_num
    # Normalise so the weights sum to the number of classes.
    weights = weights / np.sum(weights) * len(counts)
    return torch.tensor(weights, dtype=torch.float32)



In [None]:
#| hide
import torch
import numpy as np

# Step-by-step check of the class-balanced weighting used by get_cb_weights.
labels_count_list = [517, 163, 142]
beta = 0.999

# Effective number of samples per class: 1 - beta**n.
effective_num = 1.0 - np.power(beta, labels_count_list)
print(effective_num)

# Un-normalised weights, inversely proportional to the effective number.
weights = (1.0 - beta) / np.array(effective_num)
print(weights)

# Rescale so the weights sum to the number of classes, then show them as a tensor.
weights = weights / weights.sum() * len(labels_count_list)
torch.tensor(weights, dtype=torch.float32)

[0.40384744 0.1504781  0.13244038]
[0.00247618 0.00664549 0.00755057]


tensor([0.4456, 1.1958, 1.3586])

In [None]:
#| hide
# Sanity check: the normalised class-balanced weights should sum to the number of classes.
torch.tensor(weights, dtype=torch.float32).sum()

tensor(3.)

In [None]:
#| export
def init_loss(cfg, weights=(517, 163, 142), device='cuda'):
    """Instantiate the loss described by ``cfg.loss``.

    ``cfg.loss.name`` selects the loss:

    * ``"BCEWithLogitsLoss"``: a plain ``torch.nn.BCEWithLogitsLoss`` is
      returned as-is (not moved to ``device``).
    * ``"BalancedFocalLoss"``: ``FocalLoss`` with ``alpha`` replaced by
      class-balanced weights computed from ``weights`` via ``get_cb_weights``.
    * any other name: the class of that name is looked up in the ``dl.loss``
      module via ``get_cls`` (NOTE(review): presumably this notebook's
      export target; confirm).

    ``alpha`` / ``gamma`` are read from ``cfg.loss.params`` when present.
    Keys that are missing or ``None`` are not forwarded, so the loss class's
    own defaults apply. ``cfg`` is not modified.

    Args:
        cfg: Config exposing ``cfg.loss.name`` and optionally
            ``cfg.loss.params`` (e.g. an OmegaConf ``DictConfig``).
        weights: Per-class label counts used for ``BalancedFocalLoss``.
            NOTE(review): the default looks dataset-specific; confirm before reuse.
        device: Device the constructed loss module is moved to.

    Returns:
        The constructed loss module.
    """
    name = cfg.loss.name
    if name == "BCEWithLogitsLoss":
        return torch.nn.BCEWithLogitsLoss()

    params = getattr(cfg.loss, "params", None) or {}
    # Forward only hyper-parameters the config actually sets; passing an
    # explicit None would override (and break) the loss class's defaults.
    kwargs = {key: params[key] for key in ("alpha", "gamma") if params.get(key) is not None}

    if name == "BalancedFocalLoss":
        kwargs["alpha"] = get_cb_weights(weights)
        # Resolve the concrete class locally; the caller's cfg must stay unchanged
        # so repeated calls with the same cfg build the same loss.
        name = "FocalLoss"

    loss_cls = get_cls("dl.loss", name)
    return loss_cls(**kwargs).to(device)

In [None]:
#| hide
from omegaconf import OmegaConf
# Load the focal-loss fine-tuning config and display its loss section.
cfg = OmegaConf.load("../cfgs/task_2/efficientnet/Focal-fine-tuning.yaml")
cfg.loss

{'name': 'FocalLoss', 'params': {'gamma': 2.0, 'alpha': 0.25}}

In [None]:
#| hide
# Build the loss on CPU from the config loaded above and display it.
loss = init_loss(cfg, device='cpu')
loss

FocalLoss()

In [None]:
#| hide
# Re-display cfg.loss after init_loss (e.g. to check whether the call modified it).
cfg.loss

{'name': 'FocalLoss', 'params': {'gamma': 2.0, 'alpha': 0.25}}

In [None]:
#| hide
# Build the loss from the balanced fine-tuning config and compare it with the
# class-balanced weights computed directly from the same counts.
counts = [517, 163, 142]
cfg = OmegaConf.load("../cfgs/task_2/efficientnet/Balanced-fine-tuning.yaml")
cb_weights = get_cb_weights(counts).to("cpu")

loss = init_loss(cfg, weights=counts, device="cpu")
# cb_weights was previously computed but unused. This check assumes the config
# selects BalancedFocalLoss, so init_loss should have passed these weights through as alpha.
assert torch.allclose(loss.alpha, cb_weights)


In [None]:
#| hide
# Display the loss module built from the balanced config in the previous cell.
loss

FocalLoss()

In [None]:
#| hide
# Regenerate the library modules from this project's notebooks.
import nbdev
nbdev.nbdev_export()