In [None]:
#default_exp losses
from nbdev.showdoc import show_doc

# Losses

> Implements popular segmentation loss functions.

In [None]:
#hide
from fastcore.test import *
from fastai.torch_core import TensorImage, TensorMask
from fastai.losses import CrossEntropyLossFlat

In [None]:
#export
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from fastai.torch_core import TensorBase
import segmentation_models_pytorch as smp
from deepflash2.utils import import_package

Losses implemented here:

In [None]:
#export
# Loss names accepted by `get_loss`. Order matters: tests slice off the first entry.
LOSSES = [
    'WeightedCrossEntropyLoss',
    'CrossEntropyLoss',
    'CrossEntropyDiceLoss',
    'DiceLoss',
    'JaccardLoss',
    'FocalLoss',
    'LovaszLoss',
    'TverskyLoss',
]

## Loss Wrapper functions

Wrapper for handling different tensor types from [fastai](https://docs.fast.ai/torch_core.html#TensorBase).

In [None]:
#export 
class TorchLoss(_Loss):
    'Adapter that hands a wrapped loss contiguous `TensorBase` versions of its arguments.'
    def __init__(self, loss):
        super().__init__()
        self.loss = loss

    def _contiguous(self, x):
        'Make `x` contiguous and wrap the result in `TensorBase`.'
        return TensorBase(x.contiguous())

    def forward(self, *input):
        'Convert every positional argument, then delegate to the wrapped loss.'
        converted = [self._contiguous(arg) for arg in input]
        return self.loss(*converted)

Wrapper for combining different losses from [pytorch-toolbelt](https://github.com/BloodAxe/pytorch-toolbelt)

In [None]:
#export 
# from https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/joint_loss.py
class WeightedLoss(_Loss):
    '''
    Scales the value returned by a wrapped loss by a fixed factor.
    Handy for balancing several losses whose values are on different scales.
    '''
    def __init__(self, loss, weight=1.0):
        super().__init__()
        self.weight = weight
        self.loss = loss

    def forward(self, *input):
        'Evaluate the wrapped loss on `input` and multiply the result by `weight`.'
        unscaled = self.loss(*input)
        return unscaled * self.weight

class JointLoss(_Loss):
    'Combine two losses into one by summing them, each scaled by its own weight.'

    def __init__(self, first: nn.Module, second: nn.Module, first_weight=1.0, second_weight=1.0):
        super().__init__()
        # Each term is wrapped so its weight is applied inside the sub-module.
        self.first = WeightedLoss(first, first_weight)
        self.second = WeightedLoss(second, second_weight)

    def forward(self, *input):
        'Evaluate both weighted losses on the same arguments and add the results.'
        first_term = self.first(*input)
        second_term = self.second(*input)
        return first_term + second_term

## Custom Loss Functions

### Weighted Softmax Cross Entropy Loss

As described by Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature methods 16.1 (2019): 67-70.


- `axis` for softmax calculations. Defaults to -1 (the last dimension); pass `axis=1` to use the channel dimension.
- `reduction` will be used when we call `Learner.get_preds`
- `activation` function will be applied on the raw output logits of the model when calling `Learner.get_preds` or `Learner.predict`
- `decodes` function converts the output of the model to a format similar to the target (here binary masks). This is used in `Learner.predict`

In [None]:
#export
class WeightedSoftmaxCrossEntropy(torch.nn.Module):
    "Cross entropy whose unreduced values are multiplied by a `weights` tensor before reduction"
    def __init__(self, *args, axis=-1, reduction = 'mean'):
        super().__init__()
        self.axis, self.reduction = axis, reduction

    def forward(self, inp, targ, weights):
        "Return the weighted loss reduced by 'mean' or 'sum'; any other `reduction` returns it unreduced"
        weighted = F.cross_entropy(inp, targ, reduction='none') * weights
        if self.reduction == 'mean': return weighted.mean()
        if self.reduction == 'sum': return weighted.sum()
        return weighted

    def decodes(self, x):
        "Index of the largest value along `axis`"
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Softmax along `axis`"
        return F.softmax(x, dim=self.axis)

In a segmentation task, we want to take the softmax over the channel dimension

In [None]:
#Compare WeightedSoftmaxCrossEntropy loss with weights==1 to cross entropy loss
torch.manual_seed(0) #fixed seed so the tolerance-based comparisons are reproducible
n_classes = 5
wce = TorchLoss(WeightedSoftmaxCrossEntropy(axis=1))
ce = CrossEntropyLossFlat(axis=1)
#`output` and `target` are reused by the test cells further down
output = TensorImage(torch.randn(4, n_classes, 356, 356, requires_grad=True))
target = TensorMask(torch.randint(0, n_classes, (4, 356, 356)))
weights = torch.ones(4, 356, 356)
test_close(wce(output, target, weights), ce(output, target), eps=1e-04)
#Test WeightedSoftmaxCrossEntropy loss with weights!=1 is different than weights==1
test_ne(wce(output, target, weights), wce(output, target, weights*0.9))

### Segmentation Models Pytorch Integration

The `get_loss()` function loads popular segmentation losses from [Segmentation Models Pytorch](https://github.com/qubvel/segmentation_models.pytorch): 
- (Soft) CrossEntropy Loss (*insert citation*)
- Dice Loss (*insert citation*)
- Jaccard Loss (*insert citation*)
- Focal Loss (*insert citation*)
- Lovasz Loss (*insert citation*)

### Kornia Segmentation Losses Integration

The `get_loss()` function also loads segmentation losses from [kornia](https://github.com/kornia/kornia). 
Read the [docs](https://kornia.readthedocs.io/en/latest/losses.html#module) for a detailed explanation.
- TverskyLoss (*insert citation*)

In [None]:
#export 
def get_loss(loss_name, mode='multiclass', classes=[1], smooth_factor=0., alpha=0.5, beta=0.5, gamma=2.0, reduction='mean', **kwargs):
    '''
    Build the loss selected by `loss_name` and return it wrapped in `TorchLoss`.

    - `loss_name`: one of the names in `LOSSES` (asserted).
    - `mode`: forwarded to the smp Dice, Jaccard, Focal and Lovasz losses.
    - `classes`: forwarded to the smp Dice and Jaccard losses (the default list is only passed on, never mutated here).
    - `smooth_factor`: forwarded to the smp soft cross entropy loss.
    - `alpha`: forwarded to the Focal and Tversky losses; `beta` only to Tversky, `gamma` only to Focal.
    - `reduction`: forwarded to the weighted cross entropy and Focal losses only.
    - `kwargs`: forwarded to every loss constructor except `WeightedSoftmaxCrossEntropy`.
    '''

    assert loss_name in LOSSES, f'Select one of {LOSSES}'

    if loss_name=="WeightedCrossEntropyLoss":
        # axis=1: `activation`/`decodes` must act on the channel dimension, which is
        # the class dimension `F.cross_entropy` uses in `forward`.
        loss = WeightedSoftmaxCrossEntropy(axis=1, reduction=reduction)

    elif loss_name=="CrossEntropyLoss":
        loss = smp.losses.SoftCrossEntropyLoss(smooth_factor=smooth_factor, **kwargs)

    elif loss_name=="DiceLoss":
        loss = smp.losses.DiceLoss(mode=mode, classes=classes, **kwargs)

    elif loss_name=="JaccardLoss":
        loss = smp.losses.JaccardLoss(mode=mode, classes=classes, **kwargs)

    elif loss_name=="FocalLoss":
        loss = smp.losses.FocalLoss(mode=mode, alpha=alpha, gamma=gamma, reduction=reduction, **kwargs)

    elif loss_name=="LovaszLoss":
        loss = smp.losses.LovaszLoss(mode=mode, **kwargs)

    elif loss_name=="TverskyLoss":
        # kornia is loaded through `import_package` only when this loss is requested
        kornia = import_package('kornia')
        loss = kornia.losses.TverskyLoss(alpha=alpha, beta=beta, **kwargs)

    elif loss_name=="CrossEntropyDiceLoss":
        # Unweighted sum of soft cross entropy and Dice; both receive the same kwargs
        dc = smp.losses.DiceLoss(mode=mode, classes=classes, **kwargs)
        ce = smp.losses.SoftCrossEntropyLoss(smooth_factor=smooth_factor, **kwargs)
        loss = JointLoss(ce, dc, 1, 1)

    return TorchLoss(loss)

In [None]:
#Test if all losses are running
#LOSSES[0] (WeightedCrossEntropyLoss) is skipped: its forward also needs a `weights` tensor
for loss_name in LOSSES[1:]:
    tst = get_loss(loss_name) 
    loss = tst(output, target) #smoke test only: the value is not asserted

In [None]:
#Soft cross entropy without label smoothing should match fastai's flattened cross entropy
soft_ce = get_loss('CrossEntropyLoss', smooth_factor=0)
fastai_ce = CrossEntropyLossFlat(axis=1)
soft_ce_value = soft_ce(output, target)
fastai_ce_value = fastai_ce(output, target)
test_close(soft_ce_value, fastai_ce_value, eps=1e-04)

In [None]:
#Compare Jaccard loss to the value derived from the Dice loss
jc = get_loss('JaccardLoss')
dc = get_loss('DiceLoss')
dc_loss = dc(output, target)
#Assuming each loss is 1 - score and J = D/(2-D): jaccard_loss = 2*dice_loss/(dice_loss+1)
#Use `dc_loss` computed above, not the `loss` left over from the smoke-test loop
dc_to_jc = 2*dc_loss/(dc_loss+1)
test_close(jc(output, target), dc_to_jc, eps=1e-02)

In [None]:
#TverskyLoss with alpha=beta=0.5 is expected to equal the dice loss
tw = get_loss("TverskyLoss", alpha=0.5, beta=0.5)
dice_value = dc(output, target)
tversky_value = tw(output, target)
test_close(dice_value, tversky_value, eps=1e-02)

## Export -

In [None]:
#hide
# Import only the name that is used instead of star-importing the whole module
from nbdev.export import notebook2script
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 02a_transforms.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
