In [2]:
from pathlib import Path
from datetime import datetime
import os
import re
import csv
import torch
from torch.utils.data import random_split, DataLoader
import monai
import gdown
import pandas as pd
import torchio as tio
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import csv
# Default (width, height) for every matplotlib figure drawn in this notebook.
plt.rcParams.update({'figure.figsize': (12, 8)})
# Let MONAI configure its determinism settings with the library defaults.
monai.utils.set_determinism()

In [None]:
#Set root folder

# The data folder is expected next to the notebook's working directory. os.path.join
# is used so the path separator is never written by hand.
root = os.getcwd()
Task1Folder = os.path.join(root, 'Task1Synapse')

# Fail early, with a readable message, when the folder is missing. The original cell
# called os.listdir() and threw the result away, which only acted as this check.
if not os.path.isdir(Task1Folder):
    raise FileNotFoundError(f'Training data folder not found: {Task1Folder}')
print(Task1Folder)

In [None]:
#Single channel data set creation

# List the patient folders a single time so the image and label lists are guaranteed
# to line up. Sorting makes the order stable between runs; plain files lying in the
# data folder are skipped because only directories can hold a patient.
patient_ids = sorted(
    entry for entry in os.listdir(Task1Folder)
    if os.path.isdir(os.path.join(Task1Folder, entry))
)

# Folder <id> is expected to contain <id>_flair.nii.gz and <id>_seg.nii.gz.
patient_list = [os.path.join(Task1Folder, pid, f'{pid}_flair.nii.gz') for pid in patient_ids]
label_list = [os.path.join(Task1Folder, pid, f'{pid}_seg.nii.gz') for pid in patient_ids]

subjects = [
    tio.Subject(
        image=tio.ScalarImage(flair_path),
        label=tio.LabelMap(seg_path),
        name=flair_path,  # the FLAIR path doubles as the subject identifier
    )
    for flair_path, seg_path in zip(patient_list, label_list)
]

# Rescale intensities to [-1, 1], make every spatial size a multiple of 8 and one-hot
# encode the label map. The one-hot channel count is inspected in the next cell.
preprocess = tio.Compose([
    tio.RescaleIntensity((-1, 1)),
    tio.EnsureShapeMultiple(8),
    tio.OneHot(),
])

new_data = tio.SubjectsDataset(subjects, transform=preprocess)

In [None]:
#Collect the names of subjects whose one-hot label map does not have exactly 5 channels
new_bad_list = []
for subject in new_data:
    channel_count = subject.label.shape[0]
    if channel_count == 5:
        continue
    # Report the offending subject, then remember its name (the FLAIR file path).
    print(channel_count)
    print(subject.name)
    new_bad_list.append(subject.name)


In [None]:
#Create multichannel dataset

FLAIR_SUFFIX = 'flair.nii.gz'
# The first NUM_TRAIN_SUBJECTS usable subjects are used for training/validation,
# everything after them is kept as the held-out test set.
NUM_TRAIN_SUBJECTS = 1200


def _sibling_path(flair_path, modality):
    """Return the path of another modality stored next to ``flair_path``.

    Only the trailing ``flair.nii.gz`` is swapped. The original ``str.replace`` call
    replaced every occurrence of 'flair', so a 'flair' in a folder name would have
    been rewritten too.
    """
    return flair_path[:-len(FLAIR_SUFFIX)] + f'{modality}.nii.gz'


# One sorted listing of the patient folders (plain files are ignored). Sorting fixes
# the subject order, and with it the train/test split below, across runs.
patient_ids = sorted(
    entry for entry in os.listdir(Task1Folder)
    if os.path.isdir(os.path.join(Task1Folder, entry))
)

# Unfiltered per-modality path lists, all derived from the same listing.
patient_flair_list = [os.path.join(Task1Folder, pid, f'{pid}_flair.nii.gz') for pid in patient_ids]
patient_t1_list = [_sibling_path(path, 't1') for path in patient_flair_list]
patient_t1ce_list = [_sibling_path(path, 't1ce') for path in patient_flair_list]
label_list = [_sibling_path(path, 'seg') for path in patient_flair_list]

# Drop the subjects flagged in new_bad_list (their names are FLAIR paths).
bad_names = set(new_bad_list)
new_patient_flair_list = [path for path in patient_flair_list if path not in bad_names]
list_copy = list(new_patient_flair_list)  # a real copy, not an alias
new_label_list = [_sibling_path(path, 'seg') for path in list_copy]
new_patient_t1_list = [_sibling_path(path, 't1') for path in list_copy]
new_patient_t1ce_list = [_sibling_path(path, 't1ce') for path in list_copy]

subjects = [
    tio.Subject(
        channel_flair=tio.ScalarImage(flair_path),
        channel_t1=tio.ScalarImage(t1_path),
        channel_t1ce=tio.ScalarImage(t1ce_path),
        label=tio.LabelMap(seg_path),
    )
    for flair_path, t1_path, t1ce_path, seg_path in zip(
        new_patient_flair_list,
        new_patient_t1_list,
        new_patient_t1ce_list,
        new_label_list,
    )
]

train_subjects = subjects[:NUM_TRAIN_SUBJECTS]
test_subjects = subjects[NUM_TRAIN_SUBJECTS:]

In [None]:
#Create preprocessing/augmentation transforms and dataloader for training, validation, and test sets

class DataModule(pl.LightningDataModule):
    """Lightning data module that builds the train, validation and test datasets.

    Args:
        train_subjects: Subjects that are split into training and validation sets.
        test_subjects: Subjects used only for the test set.
        batch_size: Batch size used by all three data loaders.
        train_val_ratio: Fraction of ``train_subjects`` that goes to training; the
            remainder is used for validation.
    """

    def __init__(self, train_subjects, test_subjects, batch_size, train_val_ratio):
        super().__init__()
        self.subjects = train_subjects
        self.batch_size = batch_size
        self.train_val_ratio = train_val_ratio
        self.test_subjects = test_subjects
        # Everything below is filled in by setup().
        self.preprocess = None
        self.transform = None
        self.train_set = None
        self.val_set = None
        self.test_set = None

    #can comment this out and CropOrPad for faster setup
    def get_max_shape(self, train_subjects):
        """Return the element-wise maximum spatial shape over ``train_subjects``."""
        dataset = tio.SubjectsDataset(train_subjects)
        shapes = np.array([s.spatial_shape for s in dataset])
        return shapes.max(axis=0)

    def get_preprocessing_transform(self):
        """Return the deterministic preprocessing applied to every subject."""
        preprocess = tio.Compose([
            tio.RescaleIntensity((-1, 1)),
            tio.CropOrPad((240, 240, 160)),
            tio.EnsureShapeMultiple(8),
            tio.OneHot(),
        ])
        return preprocess

    def get_augmentation_transform(self):
        """Return the random augmentations applied to training subjects only."""
        augment = tio.Compose([
            tio.RandomAffine(),
            tio.RandomGamma(p=0.5),
            tio.RandomNoise(p=0.5),
            tio.RandomMotion(p=0.1),
            tio.RandomBiasField(p=0.25),
        ])
        return augment

    def setup(self, stage=None):
        """Split the subjects and build the three datasets.

        Args:
            stage: Stage name supplied by the Lightning Trainer when it calls this
                hook. It is accepted so the Trainer's call does not raise a
                TypeError, and is otherwise ignored: all datasets are always built.
        """
        num_subjects = len(self.subjects)
        num_train_subjects = int(round(num_subjects * self.train_val_ratio))
        num_val_subjects = num_subjects - num_train_subjects
        splits = num_train_subjects, num_val_subjects
        train_subjects, val_subjects = random_split(self.subjects, splits)

        self.preprocess = self.get_preprocessing_transform()
        augment = self.get_augmentation_transform()
        # Augmentation runs after preprocessing and only for the training set.
        self.transform = tio.Compose([self.preprocess, augment])

        self.train_set = tio.SubjectsDataset(train_subjects, transform=self.transform)
        self.val_set = tio.SubjectsDataset(val_subjects, transform=self.preprocess)
        self.test_set = tio.SubjectsDataset(self.test_subjects, transform=self.preprocess)

    def train_dataloader(self):
        """Return the training loader; subjects are reshuffled every epoch."""
        return DataLoader(self.train_set, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.val_set, self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.test_set, self.batch_size)

# One subject per batch; 80 % of the training subjects train the model and the
# remaining 20 % validate it. The datasets are built immediately.
new_data_module = DataModule(
    train_subjects=train_subjects,
    test_subjects=test_subjects,
    batch_size=1,
    train_val_ratio=0.8,
)
new_data_module.setup()

In [None]:
#Set up training and validation loop, concatenate channels

class Model(pl.LightningModule):
    """Lightning module that feeds three concatenated image channels to ``net``.

    Args:
        net: Network called on the concatenated channels.
        criterion: Loss called as ``criterion(prediction, target)``.
        learning_rate: Learning rate handed to the optimizer.
        optimizer_class: Optimizer class, instantiated in configure_optimizers().
    """

    def __init__(self, net, criterion, learning_rate, optimizer_class):
        super().__init__()
        self.net = net
        self.criterion = criterion
        self.lr = learning_rate
        self.optimizer_class = optimizer_class

    def configure_optimizers(self):
        """Create the optimizer over all parameters of this module."""
        return self.optimizer_class(self.parameters(), lr=self.lr)

    def prepare_batch(self, batch):
        """Return the FLAIR, T1 and T1ce tensors and the label tensor of ``batch``."""
        flair = batch['channel_flair'][tio.DATA]
        t1 = batch['channel_t1'][tio.DATA]
        t1ce = batch['channel_t1ce'][tio.DATA]
        target = batch['label'][tio.DATA]
        return flair, t1, t1ce, target

    def infer_batch(self, batch):
        """Run the network on the channel-concatenated images.

        Returns:
            Tuple ``(prediction, target)``.
        """
        *modalities, target = self.prepare_batch(batch)
        # Stack the three modalities along dimension 1 before the forward pass.
        prediction = self.net(torch.cat(modalities, dim=1))
        return prediction, target

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it as 'train_loss' (progress bar)."""
        prediction, target = self.infer_batch(batch)
        train_loss = self.criterion(prediction, target)
        self.log('train_loss', train_loss, prog_bar=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as 'val_loss'."""
        prediction, target = self.infer_batch(batch)
        val_loss = self.criterion(prediction, target)
        self.log('val_loss', val_loss)
        return val_loss

In [None]:
#Configure Unet and training loop

# 3D U-Net with three input channels (the concatenated FLAIR, T1 and T1ce images)
# and five output channels.
unet = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=3,
    out_channels=5,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2),
)

# Combined Dice + cross-entropy loss with softmax applied to the network output.
dice_ce_loss = monai.losses.DiceCELoss(softmax=True)

model = Model(
    net=unet,
    criterion=dice_ce_loss,
    learning_rate=1e-2,
    optimizer_class=torch.optim.AdamW,
)

# Early stopping watches the 'val_loss' value logged by Model.validation_step.
early_stopping = pl.callbacks.early_stopping.EarlyStopping(monitor='val_loss')

# NOTE(review): `gpus=` (Lightning) and `dimensions=` (MONAI UNet) are reported as
# deprecated in newer releases — confirm against the installed versions.
trainer = pl.Trainer(gpus=1, precision=16, callbacks=[early_stopping])
# NOTE(review): this sets a private logger attribute; presumably it turns off the
# logger's default hp_metric entry — confirm.
trainer.logger._default_hp_metric = False

In [None]:
#Start training loop

# Time the whole fit() call so the notebook records how long training took.
fit_started_at = datetime.now()
print('Training started at', fit_started_at)
trainer.fit(model=model, datamodule=new_data_module)
fit_duration = datetime.now() - fit_started_at
print('Training duration:', fit_duration)

In [None]:
# Save the model (optional): uncomment the lines below to write the trained model to disk and load it back

# model_file = 'Unet_model_Multichannel.pth'
# torch.save(model, model_file)
# model = torch.load(model_file)

In [None]:
#Preprocessing pipeline for the BraTS evaluation data

# Same steps as the training preprocessing, except that volumes are cropped/padded
# to (240, 240, 155) before EnsureShapeMultiple(8) is applied.
preprocess_val = tio.Compose(
    [
        tio.RescaleIntensity((-1, 1)),
        tio.CropOrPad((240, 240, 155)),
        tio.EnsureShapeMultiple(8),
        tio.OneHot(),
    ]
)

In [None]:
#Create Brats Test Data Set

brats_root = os.getcwd()
brats_folder = os.path.join(brats_root, 'BRATS_TEST_DATA')

# List the patient folders once (sorted, directories only) so the three modality
# lists are guaranteed to refer to the same patients in the same order.
brats_patient_ids = sorted(
    entry for entry in os.listdir(brats_folder)
    if os.path.isdir(os.path.join(brats_folder, entry))
)

# Folder <id> is expected to contain <id>_flair.nii.gz, <id>_t1.nii.gz and <id>_t1ce.nii.gz.
patient_flair_list_brats = [os.path.join(brats_folder, pid, f'{pid}_flair.nii.gz') for pid in brats_patient_ids]
patient_t1_list_brats = [os.path.join(brats_folder, pid, f'{pid}_t1.nii.gz') for pid in brats_patient_ids]
patient_t1ce_list_brats = [os.path.join(brats_folder, pid, f'{pid}_t1ce.nii.gz') for pid in brats_patient_ids]

brats_subjects = [
    tio.Subject(
        channel_flair=tio.ScalarImage(flair_path),
        channel_t1=tio.ScalarImage(t1_path),
        channel_t1ce=tio.ScalarImage(t1ce_path),
        name=flair_path,  # the FLAIR path is later sliced to recover the patient number
    )
    for flair_path, t1_path, t1ce_path in zip(
        patient_flair_list_brats,
        patient_t1_list_brats,
        patient_t1ce_list_brats,
    )
]

brats_dataset = tio.SubjectsDataset(brats_subjects, transform = preprocess_val)

# The prediction cell iterates `brats_data_loader`, which no visible cell defined.
# One subject per batch, because that cell indexes element [0] of every batch.
brats_data_loader = DataLoader(brats_dataset, batch_size=1)

In [None]:
#Create and save predictions and generate Dice Score

# NOTE(review): `new_transform` and `one_hot_transform` are used below but are not
# defined in any cell visible here — make sure they exist before running this cell.

dice_metric = monai.metrics.DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = monai.metrics.DiceMetric(include_background=True, reduction="mean_channel")

Metric_Test_ET_val = []
Metric_Test_TC_val = []
Metric_Test_WT_val = []
Metric_val = []

# Create the output folder up front so saving a prediction cannot fail on a missing
# directory.
prediction_dir = 'TestValMultiCrop'
os.makedirs(prediction_dir, exist_ok=True)

# Switch the network to evaluation mode before predicting.
model.eval()
with torch.no_grad():
    for batch in brats_data_loader:
        #Concatenate Channels
        brats_batch_channel_tuple = (
            batch['channel_flair'][tio.DATA],
            batch['channel_t1'][tio.DATA],
            batch['channel_t1ce'][tio.DATA],
        )
        all_images_brats = torch.cat(brats_batch_channel_tuple, dim=1).to(model.device)

        #Create predictions and add predictions to subjects
        preds = model.net(all_images_brats).argmax(dim=1, keepdim=True).cpu()
        batch_subject_brat = tio.utils.get_subjects_from_batch(batch)
        tio.utils.add_images_from_batch(batch_subject_brat, preds, tio.LabelMap)
        subject = batch_subject_brat[0]
        # plot() draws the figure itself; wrapping it in print() only added a stray
        # "None" line to the output.
        subject['prediction'].plot()

        # The subject name is the FLAIR path; take the five characters that come
        # right before the 13-character '_flair.nii.gz' suffix.
        new_name = batch['name'][0][-18:-13]
        print(new_name)
        transformed_batch_subject_brat = new_transform(subject['prediction'])
        transformed_batch_subject_brat.save(os.path.join(prediction_dir, f'{new_name}.nii.gz'))

        # Subjects without a ground-truth 'label' image still get their prediction
        # saved above; only the Dice computation is skipped for them.
        if 'label' not in subject:
            continue

        #Transform predictions to one hot
        y_pred_transform = one_hot_transform(subject['prediction'])
        y_pred = y_pred_transform[tio.DATA]
        y1 = subject['label'][tio.DATA]

        #Calculate Dice scores
        # NOTE(review): y_pred and y1 come from a single subject and may lack a
        # leading batch dimension, which MONAI's DiceMetric documents as required —
        # confirm the shapes.
        dice_metric(y_pred, y1)
        metric = dice_metric.aggregate()
        Metric_val.append(metric)

        dice_metric_batch(y_pred=y_pred, y=y1)
        metric_batch = dice_metric_batch.aggregate()
        print(metric_batch)
        # NOTE(review): indices 1, 2 and 4 are reported as ET, TC and WT — confirm
        # that this matches the label encoding of the data.
        print(f'metric_batch ET: {metric_batch[1]}')
        Metric_Test_ET_val.append(metric_batch[1])
        print(f'metric_batch TC: {metric_batch[2]}')
        Metric_Test_TC_val.append(metric_batch[2])
        print(f'metric_batch WT: {metric_batch[4]}')
        Metric_Test_WT_val.append(metric_batch[4])

        # Start every subject from clean metric buffers.
        dice_metric_batch.reset()
        dice_metric.reset()

if not Metric_val:
    print('No subject had a ground-truth label; Dice scores were not computed.')


In [None]:
#Display Dice scores

def _mean_score(scores):
    """Return the mean of a list of tensors, or NaN when the list is empty.

    torch.stack raises on an empty list, which happens when no subject was scored.
    """
    if not scores:
        return torch.tensor(float('nan'))
    return torch.mean(torch.stack(scores))


ET_val_mean = _mean_score(Metric_Test_ET_val)
TC_val_mean = _mean_score(Metric_Test_TC_val)
WT_val_mean = _mean_score(Metric_Test_WT_val)
metric_val_mean = _mean_score(Metric_val)
print(f'ET: {ET_val_mean}')
print(f'TC: {TC_val_mean}')
print(f'WT: {WT_val_mean}')
print(f'Overall Dice: {metric_val_mean}')
# Number of subjects that contributed a score.
print(len(Metric_Test_ET_val))