In [2]:
import segmentation_models_pytorch as smp
import torch

In [3]:
config = {
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}


device = "cuda" if torch.cuda.is_available() else "cpu"
model = smp.Unet(**config["model_config"]).to(device)

In [11]:
x = torch.randn(1, 3, 512, 512).to(device)
y = torch.randint(0, 6, (1, 512, 512)).to(device)
output = model(x)

In [17]:
import torch.nn.functional as F

def focal_loss(input, target, weight=None, gamma=2):
    ce = F.cross_entropy(input, target, reduction="none", weight=weight)
    if weight is not None:
        weight_matrix = weight[target]
        # probably will have stability issues with this division
        probs = torch.exp(-ce) / weight_matrix
    probs = torch.exp(-ce)
    fl_values = (1 - probs) ** gamma * ce
    return fl_values.mean()


In [21]:
loss_fn = torch.hub.load(
	"adeelh/pytorch-multi-class-focal-loss",
	model="FocalLoss",
	alpha=torch.tensor([0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.75]),
	gamma=2,
	reduction="mean"
    )

Downloading: "https://github.com/adeelh/pytorch-multi-class-focal-loss/zipball/master" to /home/davidfm43/.cache/torch/hub/master.zip


In [22]:
loss_fn(output, y)

tensor(0.0904, grad_fn=<MeanBackward0>)

In [19]:
focal_loss(output, y, weight=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))



tensor([[[1., 5., 3.,  ..., 6., 3., 5.],
         [5., 2., 6.,  ..., 4., 3., 5.],
         [5., 5., 6.,  ..., 3., 5., 6.],
         ...,
         [5., 2., 5.,  ..., 4., 4., 3.],
         [6., 6., 1.,  ..., 3., 2., 5.],
         [6., 5., 2.,  ..., 3., 1., 1.]]])


tensor(7.1935, grad_fn=<MeanBackward0>)