In [1]:
# Setup
import os
import numpy as np # For general mathematical operations
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
import matplotlib.pyplot as plt # For plotting the results
from torchinfo import summary # For model summaries
from torch.utils.tensorboard import SummaryWriter # For writing into tensorboard
import nrrd # For reading and manipulating nrrd files
from glob import glob
import torch.nn.functional as F
%load_ext autoreload
%autoreload 3

# Develop focal loss

Specifically solve the alpha issue

In [4]:
alpha = torch.Tensor(
[0.5, 1.0, 4.0, 1.0, 4.0, 4.0, 1.0, 1.0, 3.0, 3.0]
)  # TODO: focal loss weights per channels from the paper


gamma = 2.0
dims = [2, 10, 20, 30, 30]
weights = torch.ones(dims)
alpha_transformed = (weights.transpose(1,-1)*alpha).transpose(1,-1).view(-1)

targets = torch.rand(dims)
inputs = (torch.rand(dims) + targets) / 2

# Add error specifically to channel 0
inputs[0,...] = torch.rand(dims[1:])

orig_input_shape = inputs.shape
# flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
# first compute binary cross-entropy

BCE = F.binary_cross_entropy(inputs, targets, weight=alpha_transformed, reduction="mean")

BCE_EXP = torch.exp(-BCE)
focal_loss = (1 - BCE_EXP) ** gamma * BCE


focal_loss


tensor(1.2587)

# Then the dice loss


In [3]:
dims = [2,10, 50, 100, 50]
def old(inputs, targets):
    channels = inputs.size()[1]
    inputs = inputs[:].contiguous().view(-1)
    targets = targets[:].contiguous().view(-1)
    intersection = (inputs * targets).sum()
    dice = ((2.0 * intersection) / (inputs.sum() + targets.sum())) / channels
    return dice

def new(inputs, targets, return_per_channel=False):
    # Compute the elementwise operations p * y and p + y
    dice_top = 2 * inputs * targets + 1e-4
    dice_bottom = (inputs + targets + 1e-4)
    dice = dice_top / dice_bottom
    dsc_per_channel = dice.mean(dim=(0,3,2,4))
    dsc_avg = dsc_per_channel.mean()

    if return_per_channel:
        return dsc_avg, dsc_per_channel

    return dsc_avg

targets = (torch.rand(dims))
targets[:,1] *= 10 # Big organ -> 3 times the size
targets[:,0] *= 1 # Small organ -> Half the size
targets = targets.round(decimals=0)
targets[targets > 1] = 1

inputs = torch.nn.Softmax(dim=1)(torch.rand_like(targets))

good_ch = 1
bad_ch = 2

basic_error = 1.0
good_error =  0.
bad_error = 0.3

diff = targets - inputs
inputs += (1.0-basic_error)*diff # Normal error

diff = targets - inputs
inputs[:,good_ch] +=(1-good_error)*diff[:, good_ch] # Good channel

diff = targets - inputs
inputs[:, bad_ch] += (1.0-bad_error)*diff[:, bad_ch] # BAD Channel

old_dsc = old(inputs, targets)
new_dsc = new(inputs,targets, return_per_channel=True)

print("Organ sizes")
print((targets == 1).sum(dim=(0,2,3,4)) / (targets.numel()/targets.shape[1]))

old_dsc.round(decimals=3), new_dsc

Organ sizes
tensor([0.4998, 0.9498, 0.5006, 0.5009, 0.5001, 0.5003, 0.4999, 0.5005, 0.4996,
        0.4988])


(tensor(0.0450),
 (tensor(0.2152),
  tensor([0.0909, 1.0000, 0.4243, 0.0910, 0.0909, 0.0910, 0.0910, 0.0910, 0.0909,
          0.0908])))

In [363]:
dice_top = 2 * inputs * targets + 1e-4
dice_bottom = (inputs + targets + 1e-4)
dice = dice_top / dice_bottom
dsc_per_channel = dice.mean(dim=(0,3,2,4))
dsc_per_channel = dice.mean()
#dsc_avg = dsc_per_channel.mean()

print(dsc_per_channel)

tensor(0.0969)


In [145]:
diceloss = 0
dims = (18,27,27)
targets = torch.rand(dims)
inputs = (torch.rand(dims) + targets) / 2

p = inputs
y = targets
diceloss += (2*(p*y)/(p+y)).sum()

diceloss /= dims[1]

diceloss, forward(inputs, targets=targets)

(tensor(230.9566), tensor(0.5849))

In [10]:
from src.losses import DiceCoefficient, DiceLoss, CombinedLoss, FocalLoss
from copy import deepcopy

eps = 1e-4
dsc = DiceCoefficient(eps=eps)
dsloss = DiceLoss(eps=eps)
combined = CombinedLoss(alpha=[1.0], eps=eps)
focal = FocalLoss(eps=eps)

dim = (1,2,10,10,10)

targets = torch.rand(dim).round(decimals=0)
inputs = deepcopy(targets)

print("Targets")
print("Ch 0")
print(targets[:,0].numpy())
print("Ch 1")
print(targets[:,1].numpy())


print("Predictions")
print("Ch 0")
print(inputs[:,0].numpy())
print("Ch 1")
print(inputs[:,1].numpy())


Targets
Ch 0
[[[[1. 1. 0. 0. 1. 1. 0. 0. 1. 1.]
   [1. 0. 0. 1. 1. 1. 0. 0. 1. 0.]
   [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
   [0. 1. 1. 0. 0. 1. 0. 0. 0. 1.]
   [1. 1. 0. 0. 1. 0. 1. 1. 1. 1.]
   [0. 0. 1. 0. 1. 1. 1. 0. 1. 0.]
   [1. 1. 1. 0. 1. 1. 1. 0. 1. 0.]
   [1. 0. 0. 1. 1. 1. 0. 1. 1. 0.]
   [0. 1. 0. 1. 1. 0. 0. 1. 0. 0.]
   [1. 0. 1. 0. 1. 0. 1. 0. 1. 1.]]

  [[0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 1. 0. 1. 0. 0. 0. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
   [0. 1. 0. 1. 1. 1. 0. 0. 1. 0.]
   [1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 1. 0. 1. 0. 1. 0. 0. 0. 0.]
   [1. 1. 0. 1. 1. 0. 1. 0. 0. 0.]
   [1. 1. 1. 0. 0. 1. 1. 0. 0. 0.]
   [1. 0. 0. 0. 0. 1. 1. 1. 1. 1.]]

  [[0. 1. 0. 1. 1. 1. 0. 1. 0. 0.]
   [0. 1. 1. 0. 1. 1. 1. 1. 0. 0.]
   [1. 1. 0. 1. 0. 1. 0. 0. 0. 0.]
   [0. 0. 1. 0. 1. 1. 1. 1. 0. 0.]
   [1. 1. 0. 1. 0. 1. 0. 0. 1. 0.]
   [1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
   [0. 1. 0. 1. 0. 1. 0. 0. 0. 1.]
   [0. 1. 1. 0. 1. 0. 1. 0. 1. 0.]
   

In [11]:

inputs[0,0,0,:,:] = 1 - targets[0,0,0,:,:]

inputs = torch.rand_like(targets)

#inputs[0,0,0,:,:] = targets[0,0,0,:,:]


print()
print("Dice coefficient + mean")
coeff, per_ch = dsc(targets, inputs, return_per_channel_dsc=True)
print(coeff.numpy())
print(per_ch.numpy())

print("Dice loss")
loss, per_ch = dsloss(targets, inputs, return_per_channel_dsc=True)
print(loss.numpy())
print(per_ch.numpy())

print("Focal")
focal_loss = focal(targets, inputs, alpha=torch.Tensor([1.0]))
print(focal_loss.numpy())


print("Combined")
combined_loss, per_ch = combined(targets, inputs, return_per_channel_dsc=True)
print(combined_loss.numpy())
print(per_ch.numpy())

inputs = deepcopy(targets)



Dice coefficient + mean
0.3032807
[0.29969332 0.30686808]
Dice loss
0.6967193
[0.29969332 0.30686808]
Focal
2.265977
Combined
2.962696
[0.29969332 0.30686808]


In [12]:
n_correct_ch_zero = ((targets[0,0].view(-1) == 1) & (targets[0,0].view(-1) == inputs[0,0].view(-1))).sum()
n_correct_ch_one = ((targets[0,1].view(-1) == 1) & (targets[0,1].view(-1) == inputs[0,1].view(-1))).sum()

print(n_correct_ch_zero, (targets[0,0].view(-1) == 1).sum(),n_correct_ch_zero / (targets[0,0].view(-1) == 1).sum())
print(n_correct_ch_one,(targets[0,1].view(-1) == 1).sum(), n_correct_ch_one / (targets[0,1].view(-1) == 1).sum())

tensor(488) tensor(488) tensor(1.)
tensor(518) tensor(518) tensor(1.)


In [40]:

inputs = deepcopy(targets)

inputs[0,0,0,:,:] = 1 - targets[0,0,0,:,:]
inputs[0,1,0,:,:] = 1 - targets[0,1,0,:,:]

inputs = torch.rand_like(targets)
inputs[0,0,:,:,:] = targets[0,0,:,:,:]
inputs[0,0,:,:,:] = targets[0,0,:,:,:]


print()
print("Dice coefficient + mean")
coeff, per_ch = dsc(targets, inputs, return_per_channel_dsc=True)
print(coeff.numpy())
print(per_ch.numpy())

print("Dice loss")
loss, per_ch = dsloss(targets, inputs, return_per_channel_dsc=True)
print(loss.numpy())
print(per_ch.numpy())
alpha_ch0 = -1.0
alpha_ch1 = 1.0
print("Combined")
combined = CombinedLoss(alpha=[alpha_ch0, alpha_ch1])
combined_loss, per_ch = combined(targets, inputs, return_per_channel_dsc=True)
print(combined_loss.numpy())
print(per_ch.numpy())


print("Focal")
focal_loss = focal(targets, inputs, alpha=torch.Tensor([alpha_ch0, alpha_ch1]*(1000)))
print(focal_loss.numpy())



#inputs = deepcopy(targets)



Dice coefficient + mean
-74.65084
[   1.      -150.30168]
Dice loss
75.65084
[   1.      -150.30168]
Focal
nan
Combined
nan
[   1.      -150.30168]


In [31]:
torch.Tensor([alpha_ch0, 2.0]*(1000))

tensor([1., 2., 1.,  ..., 2., 1., 2.])

In [35]:
targets[:,0] = 0
targets[:,1] = -1

targets.view(-1)

tensor([ 0.,  0.,  0.,  ..., -1., -1., -1.])

In [41]:
combined.get_alpha(targets)

tensor([-1., -1., -1.,  ...,  1.,  1.,  1.])