In [None]:
#imports
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchsummary import summary

In [None]:
class ConvBlock2D(nn.Module):
  """Two in-plane convolution stages, each followed by ReLU and batch norm.

  Both convolutions are ``nn.Conv3d`` layers with a ``(1, 3, 3)`` kernel, so
  they never mix values along the first spatial axis of the 5-D input.
  ``padding='same'`` keeps the spatial size unchanged.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by both convolutions.
  """
  def __init__(self, in_channels, out_channels):
    super().__init__()
    in_plane_kernel = (1, 3, 3)
    # Whether the paper pads is unclear; padding is assumed here.
    self.conv1 = nn.Conv3d(in_channels, out_channels,
                           kernel_size=in_plane_kernel, padding='same')
    self.relu1 = nn.ReLU()
    self.batchnorm = nn.BatchNorm3d(out_channels)
    # The paper does not state explicitly that the second convolution maps
    # out_channels -> out_channels; that is assumed here.
    self.conv2 = nn.Conv3d(out_channels, out_channels,
                           kernel_size=in_plane_kernel, padding='same')
    self.relu2 = nn.ReLU()
    self.batchnorm2 = nn.BatchNorm3d(out_channels)

  def forward(self, x):
    """Apply conv -> ReLU -> batch norm twice and return the result."""
    first_stage = self.batchnorm(self.relu1(self.conv1(x)))
    second_stage = self.batchnorm2(self.relu2(self.conv2(first_stage)))
    return second_stage


In [None]:
class ConvBlock3DResse(nn.Module):
  """Two 3x3x3 conv stages followed by a squeeze-and-excitation (SE) gate.

  Each conv stage is conv -> ReLU -> batch norm with ``padding='same'``.
  The SE branch averages every channel over all spatial positions, passes
  the resulting per-channel vector through a two-layer bottleneck
  (Linear -> ReLU -> Linear -> Sigmoid) and uses the output as a per-channel
  gate ``s``. The block returns ``s * x + x``.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the block.
    reduction: bottleneck ratio of the SE branch; the hidden width is
      ``max(1, out_channels // reduction)``. The default of 16 is the value
      commonly used for SE blocks -- TODO confirm against the paper.
  """
  def __init__(self, in_channels, out_channels, reduction=16):
    super(ConvBlock3DResse, self).__init__()
    self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same')
    self.relu1 = nn.ReLU()
    self.batchnorm1 = nn.BatchNorm3d(out_channels)
    self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding='same')
    self.relu2 = nn.ReLU()
    self.batchnorm2 = nn.BatchNorm3d(out_channels)
    self.globalpool = nn.AdaptiveAvgPool3d(output_size=1)
    # Collapse the pooled (C, 1, 1, 1) trailing dims into a channel vector.
    # start_dim=-4 handles both batched (N, C, 1, 1, 1) and unbatched input.
    self.flatten1 = nn.Flatten(start_dim=-4)
    # The linear layers act across channels. A Linear(1, 1) here would apply
    # one shared scalar transform to every channel and never let channels
    # interact, which is not a squeeze-and-excitation gate.
    squeezed_channels = max(1, out_channels // reduction)
    self.linear1 = nn.Linear(out_channels, squeezed_channels)
    self.relu3 = nn.ReLU()
    self.linear2 = nn.Linear(squeezed_channels, out_channels)
    self.sigmoid1 = nn.Sigmoid()

  def forward(self, x):
    """Return the SE-gated feature map; spatial size equals the input's."""
    x = self.conv1(x)
    x = self.relu1(x)
    x = self.batchnorm1(x)
    x = self.conv2(x)
    x = self.relu2(x)
    x = self.batchnorm2(x)
    # Squeeze: one average per channel.
    gate = self.flatten1(self.globalpool(x))
    # Excite: bottleneck MLP producing a weight in (0, 1) per channel.
    gate = self.linear1(gate)
    gate = self.relu3(gate)
    gate = self.linear2(gate)
    gate = self.sigmoid1(gate)
    # Restore the three singleton spatial dims so the gate broadcasts over x.
    gate = gate[..., None, None, None]
    return gate * x + x



In [None]:
class HybridDilatedConv3DResse(nn.Module):
  """Dilated 3x3x3 conv stage followed by a squeeze-and-excitation (SE) gate.

  The conv stage is dilated conv -> ReLU -> batch norm with
  ``padding='same'``. The SE branch averages every channel over all spatial
  positions, passes the per-channel vector through a two-layer bottleneck
  (Linear -> ReLU -> Linear -> Sigmoid) and uses the output as a per-channel
  gate ``s``. The block returns ``s * x + x``.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the block.
    dilation: dilation rate of the convolution (default 2). Exposed so that
      stacked blocks can use different rates; which rates the paper intends
      is unclear -- TODO confirm.
    reduction: bottleneck ratio of the SE branch; the hidden width is
      ``max(1, out_channels // reduction)``. The default of 16 is the value
      commonly used for SE blocks -- TODO confirm against the paper.
  """
  def __init__(self, in_channels, out_channels, dilation=2, reduction=16):
    super(HybridDilatedConv3DResse, self).__init__()
    self.hdc = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding='same', dilation=dilation)
    self.relu1 = nn.ReLU()
    self.batchnorm1 = nn.BatchNorm3d(out_channels)
    self.globalpool = nn.AdaptiveAvgPool3d(output_size=1)
    # Collapse the pooled (C, 1, 1, 1) trailing dims into a channel vector.
    # start_dim=-4 handles both batched (N, C, 1, 1, 1) and unbatched input.
    self.flatten1 = nn.Flatten(start_dim=-4)
    # The linear layers act across channels. A Linear(1, 1) here would apply
    # one shared scalar transform to every channel and never let channels
    # interact, which is not a squeeze-and-excitation gate.
    squeezed_channels = max(1, out_channels // reduction)
    self.linear1 = nn.Linear(out_channels, squeezed_channels)
    self.relu2 = nn.ReLU()
    self.linear2 = nn.Linear(squeezed_channels, out_channels)
    self.sigmoid1 = nn.Sigmoid()

  def forward(self, x):
    """Return the SE-gated feature map; spatial size equals the input's."""
    x = self.hdc(x)
    x = self.relu1(x)
    x = self.batchnorm1(x)
    # Squeeze: one average per channel.
    gate = self.flatten1(self.globalpool(x))
    # Excite: bottleneck MLP producing a weight in (0, 1) per channel.
    gate = self.linear1(gate)
    gate = self.relu2(gate)
    gate = self.linear2(gate)
    gate = self.sigmoid1(gate)
    # Restore the three singleton spatial dims so the gate broadcasts over x.
    gate = gate[..., None, None, None]
    return gate * x + x


In [None]:
class Conv3Dfine(nn.Module):
  """Pointwise (1x1x1) 3-D convolution followed by ReLU and batch norm.

  Only the channel count changes; the spatial size is preserved.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels of the output tensor.
  """
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv3D = nn.Conv3d(in_channels, out_channels,
                            kernel_size=1, padding='same')
    self.relu = nn.ReLU()
    self.batchnorm = nn.BatchNorm3d(out_channels)

  def forward(self, x):
    """Apply conv -> ReLU -> batch norm and return the result."""
    activated = self.relu(self.conv3D(x))
    return self.batchnorm(activated)

Critique: The paper leaves the type of pooling largely unspecified; it is stated for only one layer, so we assume the same pooling is used throughout. Following the TA's advice we use max pooling, but it is unfortunate that we have to rely on advice rather than on the paper itself. The paper should also have specified the strides and padding.

In [None]:
# Seed PyTorch's global random number generator so that the randomly
# initialised weights and random tensors created in later cells are reproducible.
torch.manual_seed(1)

<torch._C.Generator at 0x7f416ebc1df0>

In [None]:
class TorchCNN(nn.Module):
  """Encoder-decoder segmentation network with skip connections.

  Data flow (channel counts written as c0..c4 = hidden_channels[0..4]):

    x1 = ConvBlock2D(in -> c0)
    x2 = ConvBlock3DResse(c0 -> c1) after max pooling over the last two axes
    x3 = ConvBlock3DResse(c1 -> c2) after max pooling over all three axes
    x4 = HybridDilatedConv3DResse(c2 -> c3)
    bottleneck = Conv3Dfine(HybridDilatedConv3DResse(c3 -> c4) -> c3)

  The decoder concatenates each skip tensor with the decoder tensor along
  the channel axis before every block, and upsamples with transposed
  convolutions that mirror the two pooling layers.

  The input channel count of every block that follows a concatenation is
  computed from the widths of the two concatenated tensors. For a channel
  list that doubles at each step (e.g. [16, 32, 64, 128, 256]) this gives
  exactly the same layer shapes as indexing ``hidden_channels`` directly,
  and it also works for lists that do not double.

  Because pooling halves and the transposed convolutions double the spatial
  size, the last two spatial axes of the input must be divisible by 4 and
  the first by 2; otherwise the skip concatenations fail on a shape mismatch.

  Args:
    in_channels: number of channels of the input volume.
    hidden_channels: sequence of at least five channel widths (c0..c4).
    out_features: number of channels of the output (one per label).
  """
  def __init__(self, in_channels, hidden_channels, out_features):
    super(TorchCNN, self).__init__()
    # Indexing (not slicing) keeps the IndexError for lists that are too short.
    c0, c1, c2, c3, c4 = (hidden_channels[i] for i in range(5))

    # Encoder.
    self.conv2D1 = ConvBlock2D(in_channels, c0)
    self.pool1 = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2))
    self.conv3D_coarse1 = ConvBlock3DResse(c0, c1)
    self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
    self.conv3D_coarse2 = ConvBlock3DResse(c1, c2)
    self.hdc1 = HybridDilatedConv3DResse(c2, c3)
    self.hdc2 = HybridDilatedConv3DResse(c3, c4)

    # Decoder. Each block after a torch.cat receives skip + decoder channels.
    self.conv3D_fine1 = Conv3Dfine(c4, c3)
    self.hdc3 = HybridDilatedConv3DResse(c3 + c3, c3)       # cat(x4, fine1)
    self.conv3D_fine2 = Conv3Dfine(c3, c2)
    self.conv3D_coarse3 = ConvBlock3DResse(c2 + c2, c2)     # cat(x3, fine2)
    self.transpose1 = nn.ConvTranspose3d(c2, c1, kernel_size=2, stride=2)
    self.conv3D_coarse4 = ConvBlock3DResse(c1 + c1, c1)     # cat(x2, transpose1)
    self.transpose2 = nn.ConvTranspose3d(c1, c0, kernel_size=(1,2,2), stride=(1,2,2))
    self.conv2D2 = ConvBlock2D(c0 + c0, c1)                 # cat(x1, transpose2)
    self.conv3D_fine3 = Conv3Dfine(c1, out_features)

  def forward(self, x):
    """Run the network on a 5-D tensor and return ``out_features`` channels
    at the input's spatial size."""
    # Encoder path; x1..x4 are kept for the skip connections.
    x1 = self.conv2D1(x)
    x2 = self.conv3D_coarse1(self.pool1(x1))
    x3 = self.conv3D_coarse2(self.pool2(x2))
    x4 = self.hdc1(x3)
    x5 = self.hdc2(x4)

    # Decoder path.
    x5 = self.conv3D_fine1(x5)
    x5 = torch.cat((x4, x5), dim=1)
    x5 = self.hdc3(x5)
    x5 = self.conv3D_fine2(x5)
    x5 = torch.cat((x3, x5), dim=1)
    x5 = self.conv3D_coarse3(x5)
    x5 = self.transpose1(x5)
    x5 = torch.cat((x2, x5), dim=1)
    x5 = self.conv3D_coarse4(x5)
    x5 = self.transpose2(x5)
    x5 = torch.cat((x1, x5), dim=1)
    x5 = self.conv2D2(x5)
    x5 = self.conv3D_fine3(x5)
    return x5

In [None]:
# Smoke test: build the network and push one random volume through it to
# check that all layer shapes line up.
in_channels = 1
hidden_channels = [16, 32, 64, 128, 256]
out_channels = 10 # for Miccai data set
CNN = TorchCNN(in_channels, hidden_channels, out_channels)
x = torch.randn((1, 1, 240, 240, 80))
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks, whereas a direct .forward() call silently skips them.
out = CNN(x)
print(out.shape)

torch.Size([1, 32, 240, 120, 40])
torch.Size([1, 32, 1, 1, 1])
torch.Size([1, 64, 120, 60, 20])
torch.Size([1, 64, 1, 1, 1])
torch.Size([1, 128, 120, 60, 20]) torch.Size([1, 128, 120, 60, 20])
torch.Size([1, 256, 120, 60, 20])
torch.Size([1, 64, 120, 60, 20])
torch.Size([1, 64, 1, 1, 1])
torch.Size([1, 32, 240, 120, 40])
torch.Size([1, 32, 1, 1, 1])
tensor([[[[[-6.9434e-01, -1.8858e-01, -7.2486e-01,  ...,  7.9303e-01,
            -7.2486e-01,  1.7384e-01],
           [-7.2486e-01, -7.2486e-01, -7.2486e-01,  ...,  9.1635e-01,
             5.9662e-02, -5.5567e-01],
           [-7.2486e-01, -7.2486e-01, -7.2486e-01,  ..., -7.2486e-01,
            -7.2486e-01, -7.2486e-01],
           ...,
           [ 1.1217e+00, -7.2486e-01,  1.7135e-01,  ...,  5.6459e-01,
             1.1226e+00,  2.1796e-01],
           [-3.9519e-01, -5.8134e-01,  3.6814e-01,  ...,  1.1973e+00,
            -7.2486e-01,  3.9043e-01],
           [ 1.0600e+00, -2.5486e-01,  2.3737e-02,  ...,  4.1766e-01,
             8.11

In [None]:
# Show the shape of the network output computed in the previous cell.
# NOTE(review): the recorded output below reports 25 channels although
# out_channels is set to 10 above -- the output looks stale; re-run to confirm.
print(out.shape)

torch.Size([1, 25, 240, 240, 80])


In [None]:
def dice_coef(y_true, y_pred, epsilon=1e-6):
    """Compute the soft Sorensen-Dice coefficient (DSC) per label and on average.

        DSC = 2 * |X & Y| / (|X| + |Y|)
            = 2 * sum(A * B) / (sum(A^2) + sum(B^2))

        ref: https://github.com/shalabh147/Brain-Tumor-Segmentation-and-Survival-Prediction-using-Deep-Neural-Networks/blob/master/utils.py
        ref: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08#file-soft_dice_loss-py

    Args:
        :param y_true: tensor [H, W, D, L] with the ground truth of the OAR;
            the label index is the last axis (extra leading axes are allowed)
        :param y_pred: tensor of the same shape with the predicted area of the OAR
        :param epsilon: added to the denominator to avoid division by zero

    Returns:
        ``(dice_avg, dice_scores)``: ``dice_scores`` is a list with one 0-dim
        tensor per label and ``dice_avg`` is their mean as a 0-dim tensor.
        A label that does not occur in ``y_true`` gets a score of 0.
    """
    # Work in floating point so integer/bool masks can be passed in directly.
    if not torch.is_floating_point(y_pred):
        y_pred = y_pred.float()
    y_true = y_true.to(dtype=y_pred.dtype)

    dice_scores = []

    for i in range(y_pred.shape[-1]):

        # Ellipsis indexing selects label i whatever the number of leading axes.
        y_pred_label = y_pred[..., i]
        y_true_label = y_true[..., i]

        if torch.sum(y_true_label) > 0:
            dice_numerator = 2 * torch.sum(y_true_label * y_pred_label)
            dice_denominator = (torch.sum(y_true_label ** 2)
                                + torch.sum(y_pred_label ** 2)
                                + epsilon)
            dice_score = dice_numerator / dice_denominator
        else:
            # Label absent from the ground truth: score 0, kept as a tensor so
            # all entries can be stacked below.
            dice_score = y_pred.new_zeros(())

        dice_scores.append(dice_score)

    if not dice_scores:
        # No label axis entries at all; avoid stacking an empty list.
        return y_pred.new_zeros(()), dice_scores

    dice_avg = torch.stack(dice_scores).mean()
    return dice_avg, dice_scores

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the label-averaged Dice coefficient.

    Args:
        :param y_true: ground-truth tensor, passed on to ``dice_coef``
        :param y_pred: prediction tensor, passed on to ``dice_coef``
    """
    # dice_coef returns (average, per-label scores); only the average is used.
    dice_avg, _ = dice_coef(y_true, y_pred)
    return 1 - dice_avg


def focal_loss(y_true, y_pred, epsilon=1e-6, alpha=None, gamma=2.0):
    """ Computes the focal loss.
            FL(p) = mean(-alpha * (1 - p)^gamma * y * ln(p))
            Notice: y_pred is treated as a probability and is clamped to
            [epsilon, 1 - epsilon] before the logarithm is taken.
            ref: https://arxiv.org/pdf/2109.12634.pdf

        Args:
            :param y_true: is a tensor [H, W, D, L] with the ground truth of the OAR
            :param y_pred: is a tensor [H, W, D, L] with the predicted area of the OAR
            :param epsilon: clamp margin, keeps log() away from 0
            :param alpha: per-label weights, one per entry of the last axis;
                defaults to the ten weights according to Chen et al. (2021).
                If more weights than labels are given, the first L are used.
            :param gamma: focal parameter, controls the degree of
                down-weighting of easy examples, by default 2.0

        Returns:
            0-dim tensor: the weighted focal term averaged over all elements.

        Raises:
            ValueError: if fewer weights than labels are supplied.
        """
    if alpha is None:
        alpha = (0.5, 1.0, 4.0, 1.0, 4.0, 4.0, 1.0, 1.0, 3.0, 3.0)

    # Work in floating point so integer/bool masks can be passed in directly.
    if not torch.is_floating_point(y_pred):
        y_pred = y_pred.float()
    y_true = y_true.to(dtype=y_pred.dtype)

    num_labels = y_pred.shape[-1]
    # Build the weights on the same device/dtype as the prediction.
    weights = torch.as_tensor(alpha, dtype=y_pred.dtype, device=y_pred.device).flatten()
    if weights.numel() < num_labels:
        raise ValueError(
            "focal_loss needs one alpha weight per label: got "
            f"{weights.numel()} weights for {num_labels} labels"
        )
    weights = weights[:num_labels]

    # Clamp first, then take the log of the clamped value: log(0) would give
    # -inf and 0 * -inf would turn the whole loss into NaN.
    y_pred_clamp = torch.clamp(y_pred, epsilon, 1 - epsilon)
    cross_entropy = -y_true * torch.log(y_pred_clamp)

    # The weights broadcast along the last (label) axis.
    focal_term = weights * torch.pow(1 - y_pred_clamp, gamma) * cross_entropy

    return focal_term.mean()


def final_loss(y_true, y_pred):
    """Combined training loss: focal loss plus Dice loss.

    Both terms are computed on the same ``(y_true, y_pred)`` pair and summed
    without any weighting.
    """
    focal_term = focal_loss(y_true, y_pred)
    dice_term = dice_coef_loss(y_true, y_pred)
    return focal_term + dice_term
