In [None]:
!pip install nilearn
!pip install monai

In [None]:
import os
import torch
import torch.nn as nn
import nibabel as nib
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss 
import glob
import numpy as np
import torchvision

In [None]:
!tar -xvf path-to-tar-file -C path-to-extract-files

# Important Params

In [None]:
inChans = 4  # input channels fed to NvNet (first conv and VAE reconstruction width)
input_shape = (4, 160, 240, 240)  # presumably (C, D, H, W); NvNet sizes its VAE dense layers from input_shape[1:] // 16
seg_outChans = 3  # segmentation output channels (sigmoid outputs from OutputTransition)
activation = "relu"  # "relu" or "elu"
normalizaiton = "group_normalization"  # (sic) spelling matches the keyword every block expects; only value supported
VAE_enable = True  # build the VAE reconstruction branch; the training loop indexes pred[1], so it assumes this is on
train_img_root = '/content/Task01_BrainTumour/imagesTr'
train_label_root = '/content/Task01_BrainTumour/labelsTr'
# NOTE(review): the validation roots below point at the same imagesTr/labelsTr folders as
# training, so validation loss is measured on training volumes — split off a held-out subset.
val_img_root = '/content/Task01_BrainTumour/imagesTr'
val_label_root = '/content/Task01_BrainTumour/labelsTr'
train_batch_size = 1
val_batch_size = 1
checkpoint_path = '/content'  # directory that checkpoints are written into
epochs = 100
lr = 0.01  # NOTE(review): initial Adam learning rate; 0.01 is high for Adam (NVNet paper reports 1e-4) — confirm

# NVNET

In [None]:
class DownSampling(nn.Module):
    """Strided, bias-free 3-D convolution with optional channel dropout.

    With the defaults (kernel 3, padding 1, stride 2) every even spatial
    dimension is halved; ``stride=1`` keeps the resolution. When
    ``dropout_rate`` is given, an in-place ``Dropout3d`` is applied to the
    convolution output.

    Args:
        inChans: number of input channels.
        outChans: number of output channels.
        stride: convolution stride.
        kernel_size: convolution kernel size.
        padding: zero padding applied on every side.
        dropout_rate: ``Dropout3d`` probability, or ``None`` to disable dropout.
    """

    def __init__(self, inChans, outChans, stride=2, kernel_size=3, padding=1, dropout_rate=None):
        super(DownSampling, self).__init__()
        self.conv1 = nn.Conv3d(
            in_channels=inChans,
            out_channels=outChans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.dropout_flag = dropout_rate is not None
        if self.dropout_flag:
            self.dropout = nn.Dropout3d(dropout_rate, inplace=True)

    def forward(self, x):
        features = self.conv1(x)
        return self.dropout(features) if self.dropout_flag else features

In [None]:
class EncoderBlock(nn.Module):
    '''
    Pre-activation residual block used on the encoder path.

    Computes ``x + conv2(actv2(norm2(conv1(actv1(norm1(x))))))``. The shortcut
    is the identity, so the final addition is only valid when ``inChans``
    equals ``outChans`` and ``stride`` is 1.

    Args:
        inChans: channels of the block input.
        outChans: channels produced by each convolution.
        stride, padding: passed to both 3x3x3 convolutions.
        num_groups: GroupNorm group count (must divide the channel counts).
        activation: "relu" or "elu".
        normalizaiton: only "group_normalization" is supported.

    Raises:
        ValueError: if ``activation`` or ``normalizaiton`` is not supported.
    '''
    def __init__(self, inChans, outChans, stride=1, padding=1, num_groups=8, activation="relu", normalizaiton="group_normalization"):
        super(EncoderBlock, self).__init__()

        if normalizaiton == "group_normalization":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=inChans)
            # norm2 normalises conv1's output, which has outChans channels
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=outChans)
        else:
            # fail at construction instead of with a missing attribute in forward()
            raise ValueError("Unsupported normalization: %r" % (normalizaiton,))
        if activation == "relu":
            self.actv1 = nn.ReLU(inplace=True)
            self.actv2 = nn.ReLU(inplace=True)
        elif activation == "elu":
            self.actv1 = nn.ELU(inplace=True)
            self.actv2 = nn.ELU(inplace=True)
        else:
            raise ValueError("Unsupported activation: %r" % (activation,))
        self.conv1 = nn.Conv3d(in_channels=inChans, out_channels=outChans, kernel_size=3, stride=stride, padding=padding)
        # conv2 consumes conv1's output, so its input width is outChans
        self.conv2 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=3, stride=stride, padding=padding)

    def forward(self, x):
        residual = x

        out = self.norm1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.actv2(out)
        out = self.conv2(out)

        out += residual

        return out

In [None]:
class LinearUpSampling(nn.Module):
    '''
    1x1x1 convolution followed by interpolation upsampling, with optional skip fusion.

    forward(x, skipx=None):
        ``conv1`` maps ``inChans -> outChans`` and the result is upsampled by
        ``scale_factor``. If ``skipx`` is given it is concatenated on the
        channel axis and ``conv2`` (``inChans -> outChans``) fuses the two, so
        ``outChans + skipx.shape[1]`` must equal ``inChans``.

    ``align_corners`` is only forwarded for the interpolating modes that accept
    it (linear, bilinear, bicubic, trilinear). For other modes, such as
    "nearest", ``interpolate`` would otherwise raise.
    '''
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, inChans, outChans, scale_factor=2, mode="trilinear", align_corners=True):
        super(LinearUpSampling, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.conv1 = nn.Conv3d(in_channels=inChans, out_channels=outChans, kernel_size=1)
        self.conv2 = nn.Conv3d(in_channels=inChans, out_channels=outChans, kernel_size=1)

    def forward(self, x, skipx=None):
        out = self.conv1(x)
        align = self.align_corners if self.mode in self._ALIGNABLE_MODES else None
        out = nn.functional.interpolate(out, scale_factor=self.scale_factor, mode=self.mode, align_corners=align)

        if skipx is not None:
            out = torch.cat((out, skipx), 1)
            out = self.conv2(out)

        return out

In [None]:
class DecoderBlock(nn.Module):
    '''
    Pre-activation residual block used on the decoder path.

    Computes ``x + conv2(actv2(norm2(conv1(actv1(norm1(x))))))``. The shortcut
    is the identity, so the final addition is only valid when ``inChans``
    equals ``outChans`` and ``stride`` is 1.

    Args:
        inChans: channels of the block input.
        outChans: channels produced by each convolution.
        stride, padding: passed to both 3x3x3 convolutions.
        num_groups: GroupNorm group count (must divide the channel counts).
        activation: "relu" or "elu".
        normalizaiton: only "group_normalization" is supported.

    Raises:
        ValueError: if ``activation`` or ``normalizaiton`` is not supported.
    '''
    def __init__(self, inChans, outChans, stride=1, padding=1, num_groups=8, activation="relu", normalizaiton="group_normalization"):
        super(DecoderBlock, self).__init__()

        if normalizaiton == "group_normalization":
            # norm1 normalises the block input, which has inChans channels
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=inChans)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=outChans)
        else:
            # fail at construction instead of with a missing attribute in forward()
            raise ValueError("Unsupported normalization: %r" % (normalizaiton,))
        if activation == "relu":
            self.actv1 = nn.ReLU(inplace=True)
            self.actv2 = nn.ReLU(inplace=True)
        elif activation == "elu":
            self.actv1 = nn.ELU(inplace=True)
            self.actv2 = nn.ELU(inplace=True)
        else:
            raise ValueError("Unsupported activation: %r" % (activation,))
        self.conv1 = nn.Conv3d(in_channels=inChans, out_channels=outChans, kernel_size=3, stride=stride, padding=padding)
        self.conv2 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=3, stride=stride, padding=padding)

    def forward(self, x):
        residual = x

        out = self.norm1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.actv2(out)
        out = self.conv2(out)

        out += residual

        return out

In [None]:
class OutputTransition(nn.Module):
    '''
    Segmentation head: a 1x1x1 convolution to ``outChans`` channels followed by
    an element-wise sigmoid, so every output voxel lies in (0, 1).
    '''
    def __init__(self, inChans, outChans):
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=1)
        self.actv1 = torch.sigmoid

    def forward(self, x):
        logits = self.conv1(x)
        return self.actv1(logits)

In [None]:
def VDraw(x, eps=1e-8):
    '''
    Draw a latent sample from the Gaussian whose parameters are packed in ``x``.

    ``x`` has shape (batch, 2*L). The first L features are used as the mean,
    which is unconstrained, matching how CombinedLoss reads them for the KL
    term. The last L features are used as the standard deviation, made
    positive with ``abs`` and floored at ``eps`` so ``Normal`` never gets a zero
    scale.

    The sample uses ``rsample`` (reparameterisation trick), so gradients flow
    back into both the mean and the std. ``sample()`` would detach the draw
    from the graph.

    Returns:
        Tensor of shape (batch, L).
    '''
    half = x.shape[1] // 2
    mean = x[:, :half]
    std = torch.abs(x[:, half:]).clamp_min(eps)
    return torch.distributions.normal.Normal(mean, std).rsample()

In [None]:
class VDResampling(nn.Module):
    '''
    Variational Auto-Encoder resampling block.

    The pipeline is:
    GroupNorm -> activation -> strided conv to 16 channels -> flatten ->
    ``dense1`` (to ``inChans`` features, returned as ``distr``) -> ``VDraw``
    sample -> ``dense2`` -> activation -> reshape to
    ``(batch, inChans/2, *dense_features)`` -> ``LinearUpSampling`` to ``outChans``.

    ``dense_features`` must equal the spatial size produced by ``conv1``,
    otherwise ``dense1``'s input width will not match.

    Raises:
        ValueError: if ``activation`` or ``normalizaiton`` is not supported.
    '''
    def __init__(self, inChans=256, outChans=256, dense_features=(10,12,8), stride=2, kernel_size=3, padding=1, activation="relu", normalizaiton="group_normalization"):
        super(VDResampling, self).__init__()

        midChans = int(inChans / 2)
        self.dense_features = dense_features
        if normalizaiton == "group_normalization":
            self.gn1 = nn.GroupNorm(num_groups=8,num_channels=inChans)
        else:
            raise ValueError("Unsupported normalization: %r" % (normalizaiton,))
        if activation == "relu":
            self.actv1 = nn.ReLU(inplace=True)
            self.actv2 = nn.ReLU(inplace=True)
        elif activation == "elu":
            self.actv1 = nn.ELU(inplace=True)
            self.actv2 = nn.ELU(inplace=True)
        else:
            raise ValueError("Unsupported activation: %r" % (activation,))
        self.conv1 = nn.Conv3d(in_channels=inChans, out_channels=16, kernel_size=kernel_size, stride=stride, padding=padding)
        # dense1 emits inChans latent parameters, which VDraw splits into mean and std
        self.dense1 = nn.Linear(in_features=16*dense_features[0]*dense_features[1]*dense_features[2], out_features=inChans)
        self.dense2 = nn.Linear(in_features=midChans, out_features=midChans*dense_features[0]*dense_features[1]*dense_features[2])
        self.up0 = LinearUpSampling(midChans,outChans)

    def forward(self, x):
        out = self.gn1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        out = out.view(-1, self.num_flat_features(out))
        out_vd = self.dense1(out)
        distr = out_vd
        out = VDraw(out_vd)
        out = self.dense2(out)
        out = self.actv2(out)
        # Keep the real batch size and derive the channel count (midChans) from
        # dense2, rather than hard-coding (1, 128, ...), which breaks batch > 1.
        out = out.view(out.size(0), self.dense2.in_features, *self.dense_features)
        out = self.up0(out)

        return out, distr

    def num_flat_features(self, x):
        '''Number of elements per sample, i.e. the product of all non-batch dims.'''
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s

        return num_features

In [None]:
class VDecoderBlock(nn.Module):
    '''
    One VAE-decoder stage.

    Upsamples with ``LinearUpSampling`` (``inChans -> outChans``, default 2x,
    no skip connection), then refines with a residual
    ``DecoderBlock(outChans, outChans)``.
    '''
    def __init__(self, inChans, outChans, activation="relu", normalizaiton="group_normalization", mode="trilinear"):
        super(VDecoderBlock, self).__init__()

        self.up0 = LinearUpSampling(inChans, outChans, mode=mode)
        self.block = DecoderBlock(outChans, outChans, activation=activation, normalizaiton=normalizaiton)

    def forward(self, x):
        return self.block(self.up0(x))

In [None]:
class VAE(nn.Module):
    '''
    Variational Auto-Encoder branch: reconstructs an ``outChans``-channel volume
    from the deepest encoder features.

    Args:
        inChans: channels of the incoming encoder features.
        outChans: channels of the reconstruction.
        dense_features: spatial size produced by VDResampling's strided conv.
        activation, normalizaiton: forwarded to VDResampling and every VDecoderBlock.
        mode: upsampling interpolation mode, forwarded to every VDecoderBlock.

    forward(x) returns ``(reconstruction, distr)``, where ``distr`` is the
    latent-parameter tensor produced by VDResampling.
    '''
    def __init__(self, inChans=256, outChans=4, dense_features=(10,12,8), activation="relu", normalizaiton="group_normalization", mode="trilinear"):
        super(VAE, self).__init__()

        self.vd_resample = VDResampling(inChans=inChans, outChans=inChans, dense_features=dense_features,
                                        activation=activation, normalizaiton=normalizaiton)
        self.vd_block2 = VDecoderBlock(inChans, inChans//2, activation=activation, normalizaiton=normalizaiton, mode=mode)
        self.vd_block1 = VDecoderBlock(inChans//2, inChans//4, activation=activation, normalizaiton=normalizaiton, mode=mode)
        self.vd_block0 = VDecoderBlock(inChans//4, inChans//8, activation=activation, normalizaiton=normalizaiton, mode=mode)
        self.vd_end = nn.Conv3d(inChans//8, outChans, kernel_size=1)

    def forward(self, x):
        out, distr = self.vd_resample(x)
        out = self.vd_block2(out)
        out = self.vd_block1(out)
        out = self.vd_block0(out)
        out = self.vd_end(out)

        return out, distr

In [None]:
class NvNet(nn.Module):
    '''
    3-D encoder/decoder segmentation network with an optional VAE branch.

    Args:
        inChans: number of input channels.
        input_shape: (C, D, H, W) of the network input. It is only used to size
            the VAE dense layers, so with ``VAE_enable`` every spatial dim must
            be divisible by 16.
        seg_outChans: channels of the sigmoid segmentation output.
        activation: "relu" or "elu" (also forwarded to the VAE branch).
        normalizaiton: "group_normalization" (also forwarded to the VAE branch).
        VAE_enable: if True, also reconstruct the input through the VAE branch.
        mode: upsampling interpolation mode (decoder and VAE).

    forward(x) returns the segmentation map when ``VAE_enable`` is False. When it
    is True, it returns ``(cat([segmentation, reconstruction], dim=1), distr)``.

    Raises:
        ValueError: if ``VAE_enable`` and a spatial dim of ``input_shape`` is not
            divisible by 16. The VAE dense sizes would not match and forward
            would fail.
    '''
    def __init__(self, inChans, input_shape, seg_outChans, activation, normalizaiton, VAE_enable, mode):
        super(NvNet, self).__init__()

        # some critical parameters
        self.inChans = inChans
        self.input_shape = input_shape
        self.seg_outChans = seg_outChans
        self.activation = activation
        self.normalizaiton = normalizaiton
        self.mode = mode
        self.VAE_enable = VAE_enable

        # Encoder Blocks: full resolution -> 1/8 resolution, 32 -> 256 channels
        self.in_conv0 = DownSampling(inChans=self.inChans, outChans=32, stride=1,dropout_rate=0.2)
        self.en_block0 = EncoderBlock(32, 32, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_down1 = DownSampling(32, 64)
        self.en_block1_0 = EncoderBlock(64, 64, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_block1_1 = EncoderBlock(64, 64, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_down2 = DownSampling(64, 128)
        self.en_block2_0 = EncoderBlock(128, 128, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_block2_1 = EncoderBlock(128, 128, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_down3 = DownSampling(128, 256)
        self.en_block3_0 = EncoderBlock(256, 256, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_block3_1 = EncoderBlock(256, 256, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_block3_2 = EncoderBlock(256, 256, activation=self.activation, normalizaiton=self.normalizaiton)
        self.en_block3_3 = EncoderBlock(256, 256, activation=self.activation, normalizaiton=self.normalizaiton)

        # Decoder Blocks: each up-sampling fuses the matching encoder skip
        self.de_up2 =  LinearUpSampling(256, 128, mode=self.mode)
        self.de_block2 = DecoderBlock(128, 128, activation=self.activation, normalizaiton=self.normalizaiton)
        self.de_up1 =  LinearUpSampling(128, 64, mode=self.mode)
        self.de_block1 = DecoderBlock(64, 64, activation=self.activation, normalizaiton=self.normalizaiton)
        self.de_up0 =  LinearUpSampling(64, 32, mode=self.mode)
        self.de_block0 = DecoderBlock(32, 32, activation=self.activation, normalizaiton=self.normalizaiton)
        self.de_end = OutputTransition(32, self.seg_outChans)

        # Variational Auto-Encoder
        if self.VAE_enable:
            # 3 encoder down-samplings plus VDResampling's strided conv => 1/16 resolution
            spatial = tuple(self.input_shape[1:4])
            if any(s % 16 for s in spatial):
                raise ValueError(
                    "With VAE_enable, every spatial dim of input_shape must be divisible by 16; got %s" % (spatial,))
            self.dense_features = tuple(s // 16 for s in spatial)
            self.vae = VAE(256, outChans=self.inChans, dense_features=self.dense_features,
                           activation=self.activation, normalizaiton=self.normalizaiton, mode=self.mode)

    def forward(self, x):
        out_init = self.in_conv0(x)
        out_en0 = self.en_block0(out_init)
        out_en1 = self.en_block1_1(self.en_block1_0(self.en_down1(out_en0)))
        out_en2 = self.en_block2_1(self.en_block2_0(self.en_down2(out_en1)))
        out_en3 = self.en_block3_3(
            self.en_block3_2(
                self.en_block3_1(
                    self.en_block3_0(
                        self.en_down3(out_en2)))))

        out_de2 = self.de_block2(self.de_up2(out_en3, out_en2))
        out_de1 = self.de_block1(self.de_up1(out_de2, out_en1))
        out_de0 = self.de_block0(self.de_up0(out_de1, out_en0))
        out_end = self.de_end(out_de0)

        if self.VAE_enable:
            out_vae, out_distr = self.vae(out_en3)
            # channels [0, seg_outChans) = segmentation, the rest = reconstruction
            out_final = torch.cat((out_end, out_vae), 1)
            return out_final, out_distr

        return out_end

# Augmentations

In [None]:
from monai.transforms import (
    Compose, RandFlipd, Affined, Rand3DElastic, ToTensord, AddChanneld, DivisiblePadd
)

In [None]:
# Training-time pipeline for the {'image': ..., 'label': ...} dicts built by BraTSDataSet.
# NOTE(review): AddChanneld is deprecated in MONAI and may be missing from recent releases;
# with an unpinned `pip install monai` this import/cell can fail — pin MONAI or migrate.
train_transforms = Compose(transforms=[
    AddChanneld(keys=['label']),  # give the label a leading channel axis
    RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=0),  # random flip along spatial axis 0
    RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=1),  # random flip along spatial axis 1
    DivisiblePadd(keys=['image', 'label'], k=8),  # pad spatial dims up to a multiple of 8
    ToTensord(keys=['image', 'label']),
])

In [None]:
# Evaluation pipeline: same channel handling and padding as training, without random flips.
# NOTE(review): AddChanneld is deprecated in MONAI and may be missing from recent releases.
val_transforms = Compose(transforms=[
    AddChanneld(keys=['label']),
    DivisiblePadd(keys=['image', 'label'], k=8),
    ToTensord(keys=['image', 'label']),
])

# Datasets

In [None]:
class BraTSDataSet(torch.utils.data.Dataset):
  '''
  Paired image/label NIfTI dataset.

  Every ``*.nii.gz`` in ``img_root`` is paired with the file at the same
  position in ``label_root``. Both lists are sorted, because ``glob`` returns
  files in arbitrary, filesystem-dependent order and unsorted lists can
  silently pair an image with the wrong label.

  Items are ``(image, label)``. Each volume has its axes reversed with
  ``np.transpose``. The ``{'image', 'label'}`` dict then goes through
  ``transform`` when given; otherwise both arrays are converted with
  ``torch.from_numpy``.
  '''
  def __init__(self, img_root, label_root, transform=None):
    self.img_root = img_root
    self.label_root = label_root
    self.transform = transform
    self.img_list = sorted(glob.glob(os.path.join(img_root, '*.nii.gz')))
    self.label_list = sorted(glob.glob(os.path.join(label_root, '*.nii.gz')))
    assert len(self.img_list)==len(self.label_list), "Some Data Samples are missing!"

  def __len__(self):
    return len(self.img_list)

  def __getitem__(self, idx):
    image = nib.load(self.img_list[idx]).get_fdata().astype(np.float32)
    label = nib.load(self.label_list[idx]).get_fdata()
    # Reverse the on-disk axis order (presumably (H, W, D[, C]) -> ([C,] D, W, H); TODO confirm)
    image = np.transpose(image)
    label = np.transpose(label)
    item_dict = {'image': image, 'label': label}
    if self.transform:
      item_dict = self.transform(item_dict)
    else:
      # torchvision's ToTensor only accepts 2-D/3-D arrays and permutes them as HWC,
      # so it cannot handle these volumes; from_numpy keeps shape and dtype.
      item_dict['image'] = torch.from_numpy(np.ascontiguousarray(image))
      item_dict['label'] = torch.from_numpy(np.ascontiguousarray(label))
    return item_dict['image'], item_dict['label']

# Dataloader

# Losses

In [None]:
class SoftDiceLoss(_Loss):
    '''
    Soft Dice loss: ``1 - 2*sum(A*B) / (sum(A*A) + sum(B*B) + eps)``.

    The sums run over every element of both tensors (whole batch, all
    channels), so the result is a single scalar. ``eps`` keeps the denominator
    non-zero. Constructor arguments are accepted and ignored.
    '''
    def __init__(self, *args, **kwargs):
        super(SoftDiceLoss, self).__init__()

    def forward(self, y_pred, y_true, eps=1e-8):
        overlap = (y_pred * y_true).sum()
        denominator = (y_pred * y_pred).sum() + (y_true * y_true).sum() + eps
        return 1 - 2 * overlap / denominator

In [None]:
class CustomKLLoss(_Loss):
    '''
    KL-style regulariser: ``mean(mu^2) + mean(sigma^2) - mean(log(sigma^2)) - 1``.

    Each term is averaged over all elements of its tensor. The result is 0 when
    ``mean`` is all zeros and ``std`` is all ones. Constructor arguments are
    accepted and ignored.
    '''
    def __init__(self, *args, **kwargs):
        super(CustomKLLoss, self).__init__()

    def forward(self, mean, std):
        variance = std * std
        return (mean * mean).mean() + variance.mean() - torch.log(variance).mean() - 1

In [None]:
class CombinedLoss(_Loss):
    '''
    Combined_loss = Dice_loss + k1 * L2_loss + k2 * KL_loss
    As default: k1=0.1, k2=0.1

    forward(seg_y_pred, seg_y_true, rec_y_pred, rec_y_true, y_mid):
        seg_y_pred, seg_y_true: segmentation prediction/target for the soft Dice term.
        rec_y_pred, rec_y_true: reconstruction/target for the mean-squared-error term.
        y_mid: (batch, 2*L) latent parameters. The first half is read as the
            mean and the second half as the standard deviation for the KL term.
    '''
    def __init__(self, k1=0.1, k2=0.1):
        super(CombinedLoss, self).__init__()
        self.k1 = k1
        self.k2 = k2
        self.dice_loss = SoftDiceLoss()
        self.l2_loss = nn.MSELoss()
        self.kl_loss = CustomKLLoss()

    def forward(self, seg_y_pred, seg_y_true, rec_y_pred, rec_y_true, y_mid):
        # split at the midpoint instead of a hard-coded 128 so any latent width works
        half = y_mid.shape[1] // 2
        est_mean, est_std = y_mid[:, :half], y_mid[:, half:]
        dice_loss = self.dice_loss(seg_y_pred, seg_y_true)
        l2_loss = self.l2_loss(rec_y_pred, rec_y_true)
        kl_div = self.kl_loss(est_mean, est_std)
        combined_loss = dice_loss + self.k1 * l2_loss + self.k2 * kl_div

        return combined_loss

# Variables

In [None]:
# Data Load
train_dataset = BraTSDataSet(img_root=train_img_root, label_root=train_label_root, transform=train_transforms)
# shuffle=True: without it every epoch visits the training volumes in the same fixed order.
# num_workers (e.g. os.cpu_count()) can be added to parallelise NIfTI loading.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

In [None]:
# NOTE(review): the config cell points val_img_root / val_label_root at the same
# imagesTr / labelsTr folders used for training, so this set overlaps the training data.
val_dataset = BraTSDataSet(val_img_root, val_label_root, val_transforms)
val_loader = DataLoader(dataset=val_dataset, batch_size=val_batch_size)

In [None]:
# Build the model from the config-cell globals and move it to the GPU when one is available.
net = NvNet(inChans, input_shape, seg_outChans, activation, normalizaiton, VAE_enable, mode='trilinear')
if torch.cuda.is_available():
    net = net.cuda()

In [None]:
# Dice + 0.1 * reconstruction MSE + 0.1 * KL (positional args are k1, k2).
criterion = CombinedLoss(0.1, 0.1)

In [None]:
# Adam over all model parameters; the training loop rewrites the learning rate each epoch.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [None]:
import math
# Best (lowest) validation loss seen so far. Start at +inf so the first epoch's
# loss counts as an improvement; -inf would make "loss <= best_loss" never true.
best_loss = math.inf

# Training

In [None]:
torch.backends.cudnn.benchmark = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_lr = lr         # configured rate; the schedule is recomputed from it each epoch
best_loss = math.inf  # reset here so re-running this cell tracks improvements afresh


def _batch_loss(img, label):
    '''Forward one batch through `net` and return the CombinedLoss.

    Assumes VAE_enable=True, i.e. `net` returns
    (cat([segmentation, reconstruction], dim=1), latent_params).
    '''
    img, label = img.to(device), label.to(device)
    seg_and_rec, y_mid = net(img)
    seg_y_pred = seg_and_rec[:, :seg_outChans]
    rec_y_pred = seg_and_rec[:, seg_outChans:]
    # NOTE(review): `label` is a single channel (presumably raw class ids), while
    # seg_y_pred has seg_outChans sigmoid channels; the Dice term broadcasts one
    # against the other. Converting labels into seg_outChans region masks is
    # likely intended — confirm the dataset's label convention before changing.
    return criterion(seg_y_pred, label, rec_y_pred, img, y_mid)


for epoch in range(epochs):

    # Train Model
    print('\n\n\nEpoch: {}\n<Train>'.format(epoch))
    net.train(True)
    # Step decay: halve the base rate every 4 epochs (computed from base_lr, not compounded).
    epoch_lr = base_lr * (0.5 ** (epoch // 4))
    for param_group in optimizer.param_groups:
        param_group["lr"] = epoch_lr
    train_loss = 0.0
    for img, label in train_loader:
        batch_loss = _batch_loss(img, label)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        train_loss += float(batch_loss)
    train_loss /= max(len(train_loader), 1)
    print('Epoch: %d  Loss: %.5f' % (epoch, train_loss))

    # Validate Model
    print('\n\n<Validation>')
    net.eval()
    val_loss = 0.0
    with torch.no_grad():
        for img, label in val_loader:
            val_loss += float(_batch_loss(img, label))
    val_loss /= max(len(val_loader), 1)
    print('Epoch: %d  Loss: %.5f' % (epoch, val_loss))

    # Save Model when the mean validation loss improves
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(net.state_dict(), os.path.join(checkpoint_path, f'epoch{epoch}_loss{val_loss:.5f}.tar'))
        print("Saving...")