## Setup imports

In [None]:
import os
import random
import sys
import nibabel as nib
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import monai
from monai.data import list_data_collate, decollate_batch, partition_dataset, DatasetSummary
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    AsChannelLastd,
    AsDiscrete,
    Compose,
    ConcatItemsd,
    LoadImaged,
    RandRotated,
    EnsureTyped,
    AddChanneld,
    CropForegroundd,
    EnsureType,
    Spacingd,
    CenterSpatialCropd,
    SpatialPadd,
    ScaleIntensityRanged,
    SqueezeDimd,
    ScaleIntensityd,
    NormalizeIntensityd
)
from monai.visualize import plot_2d_or_3d_image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from monai.metrics.utils import do_metric_reduction
from monai.inferers import SimpleInferer


import sys
sys.path.append('.')
from networks import Discriminator

## Define data directory and filenames

In [None]:
# Root folder that holds one sub-folder per patient.
root_dir ='/path/data'

# File names looked up inside every patient folder.
file_name_flair = 'FLAIR_1.nii.gz'
file_name_t1 = 'T1_1.nii.gz'
file_name_seg = 'seg_les_1.nii.gz'

# os.listdir() returns entries in arbitrary order, so a seeded shuffle of this
# list would still yield different splits on different machines; sorting makes
# the order deterministic. Non-directory entries (stray files in root_dir) are
# dropped because they cannot contain the per-patient image files.
all_patients = sorted(
    entry
    for entry in os.listdir(root_dir)
    if os.path.isdir(os.path.join(root_dir, entry))
)

## Split dataset for training and validation


In [None]:

# Shuffle the patient folders with a fixed seed and cut them into two parts
# using the ratios 0.8 / 0.2 (first part -> training, second -> validation).
partitions = partition_dataset(
    data=all_patients,
    ratios=[0.8, 0.2],
    shuffle=True,
    seed=1234,
)
train_folders, val_folders = partitions[0], partitions[1]


## Read image filenames for training and validation

In [None]:

def _paths_in(folders, file_name):
    """Return ``root_dir/<folder>/<file_name>`` for every folder, keeping order."""
    return [os.path.join(root_dir, folder, file_name) for folder in folders]


# Training split: FLAIR, T1 and segmentation paths, aligned by patient.
train_flair_images = _paths_in(train_folders, file_name_flair)
train_t1_images = _paths_in(train_folders, file_name_t1)
train_segs = _paths_in(train_folders, file_name_seg)

# Validation split: same three lists.
val_flair_images = _paths_in(val_folders, file_name_flair)
val_t1_images = _paths_in(val_folders, file_name_t1)
val_segs = _paths_in(val_folders, file_name_seg)


In [None]:
# Build one dictionary per exam: "img" -> FLAIR path, "t1" -> T1 path,
# "seg" -> segmentation path. These keys are the ones the transforms refer to.
_record_keys = ("img", "t1", "seg")

train_files = [
    dict(zip(_record_keys, paths))
    for paths in zip(train_flair_images, train_t1_images, train_segs)
]
val_files = [
    dict(zip(_record_keys, paths))
    for paths in zip(val_flair_images, val_t1_images, val_segs)
]

## Define the transformations to be applied on the image data 

In [None]:

_ALL_KEYS = ['img', 't1', 'seg']
# One interpolation mode per key: bilinear for both images, nearest for the segmentation.
_INTERP_MODES = ['bilinear', 'bilinear', 'nearest']
_VOLUME_SIZE = (192, 240, 160)


def _build_transforms(augment):
    """Assemble the dictionary-transform pipeline shared by training and validation.

    Args:
        augment: when true, a ``RandRotated`` step is inserted between the
            spatial padding and the intensity normalisation; otherwise the
            pipeline contains no random transform.

    Returns:
        A ``Compose`` holding freshly created transform instances.
    """
    steps = [
        LoadImaged(keys=_ALL_KEYS, reader='ITKReader'),
        AddChanneld(keys=_ALL_KEYS),
        Spacingd(keys=_ALL_KEYS, pixdim=(0.8, 0.8, 1), mode=_INTERP_MODES),
        # The crop box is taken from the segmentation, with a margin of 10.
        CropForegroundd(keys=_ALL_KEYS, source_key='seg', margin=10),
        # Centre-crop followed by padding with the same size, applied to every key.
        CenterSpatialCropd(keys=_ALL_KEYS, roi_size=_VOLUME_SIZE),
        SpatialPadd(keys=_ALL_KEYS, spatial_size=_VOLUME_SIZE),
    ]
    if augment:
        steps.append(
            RandRotated(keys=_ALL_KEYS, range_x=0.2, prob=0.1, mode=_INTERP_MODES)
        )
    steps.extend(
        [
            NormalizeIntensityd(keys=['img', 't1']),
            # Map segmentation values from the range [0, 5] onto [0, 1].
            ScaleIntensityRanged(keys='seg', a_min=0, a_max=5, b_min=0, b_max=1),
            # Concatenate T1 and segmentation under the new key 'input'.
            ConcatItemsd(keys=['t1', 'seg'], name='input'),
            EnsureTyped(keys=_ALL_KEYS),
        ]
    )
    return Compose(steps)


train_transforms = _build_transforms(augment=True)
val_transforms = _build_transforms(augment=False)

## Create data loaders for training and validation dataset

In [None]:
# Training dataset: each record of train_files goes through train_transforms.
train_ds = monai.data.Dataset(transform=train_transforms, data=train_files)

_train_loader_options = dict(
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    # Pinned host memory is only requested when CUDA is available.
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_ds, **_train_loader_options)

# Separate iterator over the loader, used to fetch a batch for visualization.
train_loader_iter = iter(train_loader)

In [None]:
# Validation dataset and loader: one exam per batch, no shuffling requested.
val_ds = monai.data.Dataset(transform=val_transforms, data=val_files)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=1,
    num_workers=4,
    collate_fn=list_data_collate,
)
# Post-processing applied to each decollated network output.
post_trans = Compose(transforms=[EnsureType()])

## Visualize images and corresponding segmentation

In [None]:
# Fetch one training batch for a visual sanity check.
batch = next(train_loader_iter)
# NOTE: `seg` holds the concatenated 'input' item (channel 0 = T1,
# channel 1 = scaled segmentation), not the segmentation alone.
img, seg = batch['img'], batch['input']

slice_index = 90  # index along the last axis that is displayed
n_rows, n_cols = 2, 3
# Show at most one sample per grid row; a batch may hold fewer than two samples.
n_samples = min(img.shape[0], n_rows)

plt.figure(figsize=(16,10))
for j in range(n_samples):
    panels = (
        ("img, channel 0", img[j, 0]),
        ("input, channel 0", seg[j, 0]),
        ("input, channel 1", seg[j, 1]),
    )
    for k, (title, volume) in enumerate(panels):
        # Row-major subplot index: the previous `j + k` arithmetic made the
        # panels of sample j overlap those of sample j - 1 for any j > 0.
        plt.subplot(n_rows, n_cols, n_cols * j + k + 1)
        plt.imshow(volume[:, :, slice_index], cmap="gray")
        plt.title(title)

## Initialize generator and discriminator

In [None]:
       
# Prefer the second GPU ("cuda:1") as before, but fall back to "cuda:0" on
# single-GPU machines, where "cuda:1" is an invalid device ordinal and every
# later .to(device) call would fail. Without CUDA the CPU is used.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
# Generator: 3D U-Net taking the 2-channel 'input' item and producing 1 channel.
_generator_config = dict(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
)
generator = monai.networks.nets.UNet(**_generator_config)
generator = generator.to(device)

# Discriminator: project-local network imported from `networks`.
discriminator = Discriminator()
discriminator = discriminator.to(device)


## Define loss functions, optimizers and model training utilities

In [None]:
# Loss functions; both average over all elements.
_mean_reduction = {"reduction": "mean"}

# Squared error between discriminator output and its all-ones / all-zeros target.
disc_loss = torch.nn.MSELoss(**_mean_reduction)
# Absolute error between the generated image and the target image.
loss_voxelwise = torch.nn.L1Loss(**_mean_reduction)


In [None]:
# Optimizers: Adam for both networks; the discriminator uses half the
# generator's learning rate (5e-5 vs 1e-4).
optimizer_G = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
optimizer_D = torch.optim.Adam(params=discriminator.parameters(), lr=5e-5)

In [None]:
# Validation metric: mean absolute error with 'mean' reduction.
model_metric = monai.metrics.MAEMetric(reduction='mean')

# LR scheduler on the generator optimizer: multiply the learning rate by 0.9
# once the monitored value has not decreased for more than 3 scheduler steps.
scheduler = ReduceLROnPlateau(
    optimizer_G,
    mode='min',
    factor=0.9,
    patience=3,
)

In [None]:
# Shape (without batch axis) of the discriminator (PatchGAN) output map:
# one channel, and each spatial extent of the 192 x 240 x 160 volume
# integer-divided by 2 ** 4.
# NOTE(review): assumes Discriminator reduces resolution by 16 - confirm in networks.py.
_reduction_factor = 2 ** 4
patch = (1,) + tuple(extent // _reduction_factor for extent in (192, 240, 160))

## Training and validation loop, training routine adapted from https://github.com/enochkan/vox2vox


In [None]:
# ---------------------------------------------------------------------------
# Run configuration and bookkeeping
# ---------------------------------------------------------------------------
max_epochs = 100
val_interval = 2                 # validate every `val_interval` epochs
best_metric = float("inf")       # lowest validation MAE so far (lower is better)
best_metric_epoch = -1
epoch_loss_values = list()       # mean generator loss of every epoch
metric_values = list()           # validation MAE of every validation round
inf = SimpleInferer()
writer = SummaryWriter(log_dir='./runs/gan')
disc_acc_threshold = 0.8         # discriminator is only updated at or below this accuracy
lambda_voxel = 100               # weight of the voxel-wise L1 term in the generator loss

# torch.save does not create missing directories, so make sure the target exists.
model_dir = "./models"
os.makedirs(model_dir, exist_ok=True)

# Number of batches per epoch. len(train_loader) also counts a final partial
# batch, so the TensorBoard global step never collides between epochs.
epoch_len = len(train_loader)

try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        # Validation switches the generator to eval mode; switch back before
        # training, otherwise every epoch after the first validation would
        # train in eval mode.
        generator.train()
        epoch_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1

            # model inputs
            input_data = batch_data["input"].to(device)
            segs = batch_data["seg"].to(device)
            targets = batch_data["img"].to(device)

            # discriminator ground truths: all-ones for real, all-zeros for fake
            real_gt = torch.ones((segs.size(0), *patch), device=device)
            fake_gt = torch.zeros((segs.size(0), *patch), device=device)

            # Single generator forward pass; reused for the generator update
            # below (the generator weights do not change in between).
            pred_imgs = generator(input_data)

            #  Train Discriminator
            # ---------------------
            # Real loss
            pred_real = discriminator(targets, segs)
            loss_real = disc_loss(pred_real, real_gt)

            # Fake loss (detached so no gradient reaches the generator)
            pred_fake = discriminator(pred_imgs.detach(), segs)
            loss_fake = disc_loss(pred_fake, fake_gt)
            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            # Fraction of discriminator outputs on the correct side of 0.5.
            disc_acc_real = torch.ge(pred_real.squeeze(), 0.5).float()
            disc_acc_fake = torch.le(pred_fake.squeeze(), 0.5).float()
            disc_acc_total = torch.mean(torch.cat((disc_acc_real, disc_acc_fake), 0))

            # update discriminator only in alternate epochs, and only while
            # its accuracy has not exceeded the threshold
            if disc_acc_total <= disc_acc_threshold and epoch % 2 == 0:
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()

            optimizer_D.zero_grad()

            #  Train Generators
            # ------------------
            # GAN loss: the discriminator should rate generated images as real
            pred_fake = discriminator(pred_imgs, segs)
            loss_GAN = disc_loss(pred_fake, real_gt)

            # Voxel-wise loss
            loss_voxel = loss_voxelwise(pred_imgs, targets)

            # Total loss
            loss_G = loss_GAN + lambda_voxel * loss_voxel

            loss_G.backward()
            optimizer_G.step()
            optimizer_G.zero_grad()

            epoch_loss += loss_G.item()
            print(f"{step}/{epoch_len}, train_loss: {loss_G.item():.4f}")
            writer.add_scalar("train_loss", loss_G.item(), epoch_len * epoch + step)

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        # Logs the accuracy computed on the last batch of the epoch.
        writer.add_scalar("Discriminator accuracy training", disc_acc_total, epoch + 1)

        if (epoch + 1) % val_interval == 0:
            generator.eval()
            with torch.no_grad():
                val_outputs = None
                val_epoch_loss = 0
                val_step = 0
                for val_data in val_loader:
                    val_step += 1
                    val_input_data = val_data["input"].to(device)
                    val_segs = val_data["seg"].to(device)
                    val_targets = val_data["img"].to(device)

                    val_outputs = inf(inputs=val_input_data, network=generator)
                    val_loss = loss_voxelwise(val_outputs, val_targets)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # compute metric for current iteration
                    model_metric(y_pred=val_outputs, y=val_targets)
                    val_epoch_loss += val_loss.item()
                val_epoch_loss /= val_step
                # The `epoch` argument of ReduceLROnPlateau.step is deprecated;
                # the scheduler keeps its own step counter.
                scheduler.step(val_epoch_loss)
                # aggregate the final mean result
                metric = model_metric.aggregate().item()
                # reset the status for next validation round
                model_metric.reset()

                metric_values.append(metric)
                if metric < best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(generator.state_dict(), os.path.join(model_dir, "generator.pth"))
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean MAE: {:.4f} best mean MAE: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_loss", val_epoch_loss, epoch + 1)
                writer.add_scalar("lr", optimizer_G.param_groups[0]['lr'], epoch + 1)

                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_targets, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_segs, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
finally:
    # Close the event file even when training is interrupted or fails.
    writer.close()

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")



