In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.ndimage import distance_transform_edt as distance

In [None]:
def adjust_learning_rate(optimizer, base_lr, decay_rate, step_size, epoch, min_lr=0.0005):
    """Apply a step-decay learning-rate schedule to every param group of `optimizer`.

    lr = max(base_lr * decay_rate ** (epoch // step_size), min_lr)

    The rate is multiplied by `decay_rate` once every `step_size` epochs (a staircase,
    not a smooth per-epoch exponential) and is never allowed below `min_lr`.

    Args:
        optimizer: object exposing a `param_groups` list of dicts (e.g. a torch optimizer).
        base_lr: learning rate at epoch 0.
        decay_rate: multiplicative factor applied every `step_size` epochs.
        step_size: number of epochs between two decays.
        epoch: current epoch index.
        min_lr: lower bound on the learning rate (defaults to the previously
            hard-coded 0.0005).

    Returns:
        The learning rate that was written into every param group.
    """
    lr = max(base_lr * decay_rate ** (epoch // step_size), min_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

### Generalized dice loss

In [None]:
# Reference from https://github.com/AayushKrChaudhary/RITnet/blob/master/utils.py
class GeneralizedDiceLoss(nn.Module):
    """Generalized Dice loss (class-frequency weighted soft Dice) for segmentation.

    For every sample, each class is weighted by 1 / (number of target pixels of that
    class)^2, clamped below by `epsilon`. Rare classes therefore count as much as
    frequent ones.

    Args:
        epsilon: lower bound for the squared class counts (avoids division by zero
            for classes absent from the target) and for the Dice score itself.
        weight: accepted for API compatibility; currently ignored.
        softmax: if True, normalise logits with a softmax over the class axis,
            otherwise apply an element-wise sigmoid.
        reduction: if True, return the batch mean; otherwise return the per-sample
            loss of shape (N,).

    forward(ip, target):
        ip: raw logits of shape (N, C, *spatial).
        target: class-index labels of shape (N, *spatial). A label outside
            [0, C) matches no class and yields an all-zero one-hot entry.
    """

    def __init__(self, epsilon=1e-5, weight=None, softmax=True, reduction=True):
        super(GeneralizedDiceLoss, self).__init__()
        self.epsilon = epsilon
        self.weight = []  # the `weight` argument is not used by this loss
        self.reduction = reduction
        if softmax:
            self.norm = nn.Softmax(dim=1)
        else:
            self.norm = nn.Sigmoid()

    def forward(self, ip, target):
        num_classes = ip.shape[1]
        # One-hot encode on the logits' device. Earlier code round-tripped through
        # NumPy with a hard-coded 3 classes and forced .cuda(); this version works on
        # CPU, for any class count and for any number of spatial dims.
        classes = torch.arange(num_classes, device=ip.device)
        one_hot = target.to(ip.device).unsqueeze(-1) == classes  # (N, *spatial, C)
        target = torch.movedim(one_hot, -1, 1)                    # (N, C, *spatial)

        assert ip.shape == target.shape
        ip = self.norm(ip)

        # Flatten all spatial dims so every tensor below is (N, C, S).
        ip = torch.flatten(ip, start_dim=2, end_dim=-1).to(torch.float32)
        target = torch.flatten(target, start_dim=2, end_dim=-1).to(torch.float32)

        numerator = ip * target
        denominator = ip + target

        # Per-sample, per-class weight 1 / count^2; clamp guards classes that are absent.
        class_weights = 1. / (torch.sum(target, dim=2) ** 2).clamp(min=self.epsilon)

        A = class_weights * torch.sum(numerator, dim=2)
        B = class_weights * torch.sum(denominator, dim=2)

        dice_metric = 2. * torch.sum(A, dim=1) / torch.sum(B, dim=1)  # (N,)

        if self.reduction:
            return torch.mean(1. - dice_metric.clamp(min=self.epsilon))
        else:
            return 1. - dice_metric.clamp(min=self.epsilon)

### Distance map loss

In [None]:
# Reference from https://github.com/LIVIAETS/surface-loss
def one_hot2dist(posmask):
    """Signed Euclidean distance map of a binary 2-D mask, normalised by the image diagonal.

    Pixels outside the mask get their distance to the nearest mask pixel (positive).
    Pixels inside the mask get -(distance to the nearest background pixel - 1), which
    is <= 0; mask pixels edge-adjacent to the background are therefore 0. The map is
    divided by the diagonal length sqrt((h-1)^2 + (w-1)^2).

    Args:
        posmask: 2-D array; non-zero entries are the foreground.

    Returns:
        Array with the same shape as `posmask`. An empty mask, or a 1x1 image, yields
        all zeros.
    """
    assert len(posmask.shape) == 2
    h, w = posmask.shape
    res = np.zeros_like(posmask)
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the supported spelling.
    posmask = posmask.astype(bool)
    mx_dist = np.sqrt((h - 1) ** 2 + (w - 1) ** 2)
    if mx_dist == 0:
        # 1x1 image: the diagonal has zero length and normalising would give 0/0 = NaN.
        return np.zeros((h, w))
    if posmask.any():
        negmask = ~posmask
        # distance(m): for every non-zero pixel of m, Euclidean distance to the nearest zero pixel.
        res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res / mx_dist

In [None]:
class DistanceMapLoss(nn.Module):
    """Distance-map (boundary) loss: mean of predicted probabilities times a distance map.

    Args:
        epsilon: stored for API symmetry with GeneralizedDiceLoss; not used.
        softmax: if True, turn logits into probabilities with a softmax over the class
            axis; otherwise apply an element-wise sigmoid (same convention as
            GeneralizedDiceLoss). This flag used to be silently ignored.

    forward(x, distmap):
        x: logits of shape (N, C, *spatial).
        distmap: tensor whose spatially-flattened form broadcasts against x's —
            presumably per-class maps built with `one_hot2dist` (TODO confirm with caller).
        Returns a scalar: the mean of probs * distmap over batch, classes and pixels.
        The last `distmap` is kept in `self.weight_map`.
    """

    def __init__(self, epsilon=1e-5, softmax=True):
        super(DistanceMapLoss, self).__init__()
        self.epsilon = epsilon
        self.softmax = softmax
        self.weight_map = []

    def forward(self, x, distmap):
        probs = torch.softmax(x, dim=1) if self.softmax else torch.sigmoid(x)
        self.weight_map = distmap
        score = probs.flatten(start_dim=2) * distmap.flatten(start_dim=2)  # (N, C, S)
        # Averaging over pixels, then channels, then batch uses equally sized groups,
        # so it equals a single global mean.
        return score.mean()

### Missing loss (focal loss)

In [None]:
class MissingLoss(nn.Module):
    """Focal loss for multi-class classification / dense segmentation.

    loss_i = -alpha_{t_i} * (1 - p_{t_i}) ** gamma * log(p_{t_i}),
    where p_{t_i} is the softmax probability of the true class t_i of sample/pixel i.

    Args:
        gamma: focusing exponent; gamma=0 reduces to (alpha-weighted) cross-entropy.
        alpha: optional per-class weights. A float/int `a` becomes [a, 1 - a]
            (binary case); a list is converted to a float tensor; any other value
            (e.g. a tensor) is stored unchanged.
        size_average: if True, return the mean over all samples/pixels; else the sum.

    forward(input, target):
        input: logits of shape (N, C) or (N, C, *spatial).
        target: class indices of shape (N,) or (N, *spatial).
    """

    def __init__(self, gamma=2, alpha=None, size_average=True):
        super(MissingLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([alpha, 1 - alpha], dtype=torch.float32)
        if isinstance(alpha, list):
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.contiguous().view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                                       # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))                  # N,H*W,C => N*H*W,C
        # gather() requires int64 indices; also accept labels stored as float/uint8.
        target = target.contiguous().view(-1, 1).long()  # N,H,W => N*H*W,1

        logpt = F.log_softmax(input, dim=-1)        # N*H*W,C
        logpt = logpt.gather(1, target).view(-1)    # log-prob of the true class, N*H*W
        # p_t is detached (as the former `Variable(logpt.data.exp())` was), so no gradient
        # flows through the modulating factor (1 - p_t) ** gamma.
        # NOTE(review): the focal loss of Lin et al. also differentiates through this
        # factor — confirm the detach is intended.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Keep alpha on the logits' device/dtype (no-op when already matching).
            self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()