### Connect to Google Drive and import libraries

In [None]:
!pip install SimpleITK
!pip install monai

In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import sys
# Add this Drive folder to the module search path so Python can import from it
# (presumably it holds the project-local modules imported in the next cell).
sys.path.append('/content/drive/MyDrive/Code_BrainTumorSeg_Conf/')

In [4]:
import SimpleITK as sitk
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data.dataset import Dataset
import argparse
import os
import random
import pathlib
import time
from datetime import datetime
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import yaml
from torch.autograd import Variable
from monai.data import decollate_batch
from monai.losses import DiceLoss
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

from brats_ import *
from metrics import *
from preprocessing import *
from logfile import LOGGER
from train_val_epoch import *


In [5]:
# Root folder of the training data; replace the placeholder with the real
# location before running the notebook.
BRATS_TRAIN_FOLDERS = "/path/to/yourdata"


def get_brats_folder(on="train"):
    """Return the dataset root folder for the requested split.

    Args:
        on: Name of the split. Only ``"train"`` is supported.

    Returns:
        The path stored in ``BRATS_TRAIN_FOLDERS``.

    Raises:
        ValueError: If ``on`` is not a supported split name. Falling through
            and returning ``None`` would only surface later as an unrelated
            ``TypeError`` when the result is turned into a path.
    """
    if on == "train":
        return BRATS_TRAIN_FOLDERS
    raise ValueError(f"Unsupported dataset split {on!r}; expected 'train'")

### Load data

In [None]:
def get_datasets(seed, on="train", fold_number=0, normalisation="zscore", n_splits=5):
    """Build the training and held-out datasets for one cross-validation fold.

    Every sub-directory of the dataset root is treated as one patient. The
    sorted patient list is split with a shuffled ``KFold`` and the split at
    index ``fold_number`` is selected.

    Args:
        seed: ``random_state`` passed to ``KFold``; fixes the shuffle.
        on: Split name forwarded to ``get_brats_folder``.
        fold_number: Index of the fold whose held-out part becomes the second
            returned dataset.
        normalisation: Forwarded unchanged to ``Brats``.
        n_splits: Number of cross-validation folds (default 5).

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``Brats`` datasets.
    """
    base_folder = pathlib.Path(get_brats_folder(on)).resolve()
    assert base_folder.exists(), f"Dataset folder not found: {base_folder}"
    # Sorted so the fold assignment is reproducible for a given seed.
    patients_dir = sorted(x for x in base_folder.iterdir() if x.is_dir())

    kfold = KFold(n_splits, shuffle=True, random_state=seed)
    splits = list(kfold.split(patients_dir))
    train_idx, test_idx = splits[fold_number]
    # Report only the split sizes; dumping the raw index arrays of every fold
    # floods the notebook output.
    print(
        f"Fold index {fold_number} of {n_splits} folds: "
        f"{len(train_idx)} training / {len(test_idx)} held-out patients"
    )

    train = [patients_dir[i] for i in train_idx]
    test = [patients_dir[i] for i in test_idx]

    train_dataset = Brats(train, training=True, normalisation=normalisation)
    test_dataset = Brats(test, training=False, benchmarking=True, normalisation=normalisation)

    return train_dataset, test_dataset

# Seed 123 fixes the K-fold shuffle; the fold at index 2 supplies the validation set.
full_train_dataset, val_dataset = get_datasets(seed=123, fold_number=2)
print(len(full_train_dataset), len(val_dataset))

In [None]:
## Wrap the train and validation datasets in DataLoaders (batch size 1)
train_loader = torch.utils.data.DataLoader(
    full_train_dataset,
    batch_size=1,
    shuffle=True,  # reshuffle the training samples every epoch
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,  # keep validation order fixed
    num_workers=1,
    pin_memory=True,
)

print("Train dataset number of batch:", len(train_loader))
print("Val dataset number of batch:", len(val_loader))

In [None]:
## Show shape of data
# Print the tensor shapes of the 'image' and 'label' entries of the first batch.
for batch in train_loader:
    # NOTE(review): the visualization cell indexes these tensors as
    # [batch, channel, slice] and reads shape[3]/shape[4], so they appear to be
    # 5-D volumes rather than 4-D images - confirm against the printed shapes.
    data_shape = batch['image'].shape
    label_shape = batch['label'].shape
    print("Data shape in the first batch:", data_shape)
    print("Label shape in the first batch:", label_shape)
    break  # Print only the first batch

### Visualize some data

In [None]:
# Colour names and their RGBA (uint8) values used for the composite overlay;
# label channel k is painted with label_colors[k % len(label_colors)].
label_colors = ['red', 'yellow', 'green']
color_values = {'red': (255, 0, 0, 255), 'yellow': (255, 255, 0, 255), 'green': (0, 255, 0, 255)}
background_color = (0, 0, 0, 255)


def _plot_channel_slices(volume_np, z_index, title_prefix):
    """Show slice ``z_index`` (axis 2) of every channel of sample 0 side by side."""
    num_channels = volume_np.shape[1]
    # squeeze=False keeps `axes` two-dimensional, so indexing also works when
    # the volume has a single channel (plt.subplots would otherwise return a
    # bare Axes object that cannot be indexed).
    fig, axes = plt.subplots(1, num_channels, figsize=(15, 5), squeeze=False)
    for channel, ax in enumerate(axes[0]):
        ax.imshow(volume_np[0, channel, z_index], cmap='gray')
        ax.set_title(f"{title_prefix} Channel {channel + 1}")
    plt.tight_layout()
    plt.show()


def _compose_label_overlay(label_np, z_index):
    """Return an RGBA image that paints every label channel of one slice in its own colour."""
    composite = np.full((label_np.shape[3], label_np.shape[4], 4), background_color, dtype=np.uint8)
    # Paint from the last channel down to the first so that lower-indexed
    # channels are drawn last and end up on top where masks overlap.
    for channel in range(label_np.shape[1] - 1, -1, -1):
        color_name = label_colors[channel % len(label_colors)]
        label_mask = label_np[0, channel, z_index] > 0
        composite[label_mask] = color_values[color_name]
    return composite


# Visualize the first batch only (batch_size=1, so sample 0 is the whole batch).
for batch in train_loader:
    image_sample = batch['image']
    label_sample = batch['label']

    # Convert tensors to numpy arrays
    image_sample_np = image_sample.numpy()
    label_sample_np = label_sample.numpy()

    # Slice slightly past the centre of axis 2, clamped so that very shallow
    # volumes cannot index out of bounds.
    depth = image_sample_np.shape[2]
    z_slice = min(depth // 2 + 2, depth - 1)

    _plot_channel_slices(image_sample_np, z_slice, "Image")
    _plot_channel_slices(label_sample_np, z_slice, "Label")

    # Kept in its own variable so the image array above is not overwritten.
    composite_label_np = _compose_label_overlay(label_sample_np, z_slice)
    plt.figure(figsize=(10, 5))
    plt.imshow(composite_label_np)
    plt.title("Composite Label")
    plt.show()

    break

### Model

In [None]:
from models1.unet3d import UNet3d
from models1.unet3d_cot import UNet3d_cot
from models1.unet3d_da import UNet3d_da
from models1.unet3d_da_cot import UNet3d_da_cot

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
## Change model here
model = UNet3d_da_cot(in_channels=4, n_classes=3, n_channels=16)
model = model.to(device)
print('Number of network parameters:', sum(p.numel() for p in model.parameters()))

### Function to train

In [11]:

def trainer(model, train_loader, val_loader, optimizer, loss_func, acc_func, criterian_val, metric, scheduler, batch_size, max_epochs, start_epoch=1, val_every=None, save_path="best_metric_model.pth"):
    """Run the train/validate loop and checkpoint the best model.

    Each epoch calls ``train_epoch`` and then steps ``scheduler`` (when given).
    Validation runs on epoch 1, on the last epoch and on every epoch that is a
    multiple of ``val_every``. Whenever the mean of the validation scores
    exceeds the best value so far, the model's ``state_dict`` is written to
    ``save_path``.

    Args:
        model: Network to train.
        train_loader: Loader passed to ``train_epoch``.
        val_loader: Loader passed to ``val_epoch``.
        optimizer: Optimizer passed to ``train_epoch``.
        loss_func: Training loss passed to ``train_epoch``.
        acc_func: Forwarded to ``val_epoch`` as ``acc_func``.
        criterian_val: Forwarded to ``val_epoch`` as ``criterian_val``.
        metric: Forwarded to ``val_epoch`` as ``metric``.
        scheduler: Learning-rate scheduler stepped once per epoch, or ``None``.
        batch_size: Forwarded to ``train_epoch``.
        max_epochs: Last epoch number (inclusive).
        start_epoch: First epoch number.
        val_every: Validate every this many epochs. ``None`` falls back to a
            module-level ``val_every`` variable when one exists (the behaviour
            of earlier versions of this function), otherwise to 1.
        save_path: Destination file of the best ``state_dict``.

    Returns:
        ``(val_acc_max, dices_tc, dices_wt, dices_et, dices_avg, loss_epochs,
        trains_epoch)``. The lists hold one entry per *validated* epoch.

    Raises:
        ValueError: If the resolved ``val_every`` is smaller than 1.
    """
    if val_every is None:
        # Backward compatibility: the interval used to be read from a notebook
        # global defined in a later cell.
        val_every = globals().get("val_every", 1)
    if val_every < 1:
        raise ValueError(f"val_every must be a positive integer, got {val_every!r}")

    val_acc_max, best_epoch = 0.0, 0
    total_time = time.time()
    dices_tc, dices_wt, dices_et, dices_avg, loss_epochs, trains_epoch = [], [], [], [], [], []

    for epoch in range(start_epoch, max_epochs + 1):
        LOGGER.info(f"\n{'=' * 30}Training epoch {epoch}{'=' * 30}")
        epoch_time = time.time()

        train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            epoch,
            loss_func,
            batch_size,
            max_epochs,
        )
        LOGGER.info(f"Final training epochs: {epoch}/{max_epochs} ---[loss: {train_loss:.4f}] ---[time {time.time() - epoch_time:.2f}s]")

        if scheduler is not None:
            scheduler.step()

        if epoch % val_every == 0 or epoch == max_epochs or epoch == 1:
            loss_epochs.append(train_loss)
            trains_epoch.append(int(epoch))
            epoch_time = time.time()
            LOGGER.info(f"\t{'*' * 20}Epoch {epoch} Validation{'*' * 20}")
            val_acc = val_epoch(
                model,
                val_loader,
                epoch=epoch,
                acc_func=acc_func,
                criterian_val=criterian_val,
                metric=metric,
                max_epochs=max_epochs,
            )
            # The first three validation scores are read as ET, TC and WT.
            dice_et, dice_tc, dice_wt = val_acc[0], val_acc[1], val_acc[2]
            val_avg_acc = np.mean(val_acc)
            LOGGER.info(f"\t{'*' * 20}Epoch Summary{'*' * 20}")
            LOGGER.info(f"Final validation stats {epoch}/{max_epochs}, dice_et: {dice_et:.6f}, dice_tc: {dice_tc:.6f}, dice_wt: {dice_wt:.6f} , Dice_Avg: {val_avg_acc:.6f} , time {time.time() - epoch_time:.2f}s")

            dices_tc.append(dice_tc)
            dices_wt.append(dice_wt)
            dices_et.append(dice_et)
            dices_avg.append(val_avg_acc)

            if val_avg_acc > val_acc_max:
                print("New best ({:.6f} --> {:.6f}). At epoch {}".format(val_acc_max, val_avg_acc, epoch))
                LOGGER.info(f"New best ({val_acc_max:.6f} --> {val_avg_acc:.6f}). At epoch {epoch}. Time consuming: {time.time()-total_time:.2f}")
                val_acc_max = val_avg_acc
                best_epoch = epoch
                # Create the target directory on demand so a custom save_path
                # does not fail after an expensive epoch.
                save_dir = os.path.dirname(save_path)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), save_path)
            torch.cuda.empty_cache()

    LOGGER.info(f"Training Finished !, Best Accuracy: {val_acc_max:.6f} --At epoch: {best_epoch} --Total_time: {time.time()-total_time:.2f}")
    return (val_acc_max, dices_tc, dices_wt, dices_et, dices_avg, loss_epochs, trains_epoch)

In [None]:
# Training schedule.
start_epoch = 1
max_epochs = 100  ## Change number epochs here
batch_size = 1
val_every = 1  # read by trainer() to decide on which epochs to validate

# Loss used for training and the criterion/metric pair used for validation.
criterion = EDiceLoss().to(device)
criterian_val = EDiceLoss_Val().to(device)
metric = criterian_val.metric

# MONAI Dice loss; defined here but not passed to trainer() in this cell.
loss_function = DiceLoss(
    smooth_nr=0,
    smooth_dr=1e-5,
    squared_pred=True,
    to_onehot_y=False,
    sigmoid=True,
)

dice_acc = DiceMetric(
    include_background=True,
    reduction='mean_batch',
    get_not_nans=True,
)

# AdamW with a cosine learning-rate schedule spanning the whole run.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

LOGGER.info("[TRAINER] Start TRAIN process...")

(
    val_acc_max,
    dices_tc,
    dices_wt,
    dices_et,
    dices_avg,
    loss_epochs,
    trains_epoch,
) = trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    loss_func=criterion,
    acc_func=dice_acc,
    criterian_val=criterian_val,
    metric=metric,
    scheduler=scheduler,
    batch_size=batch_size,
    max_epochs=max_epochs,
    start_epoch=start_epoch,
)

### Visualize results

In [None]:
from visualize_image_results import visualize_results
## change this file to your model
# Path of the saved model file handed to visualize_results; replace the placeholder.
# NOTE(review): trainer() writes its best weights to "best_metric_model.pth" -
# presumably this should point at that file; confirm.
model_file_out = "/path/to/yourmodel"
# Passed straight to visualize_results (presumably the number of images to show).
munber_images = 1
visualize_results(model, val_loader, model_file_out, munber_images, device)

### HD95

In [None]:
!pip install medpy

In [None]:
from hd95_metrics import *
# Run the project's calc_hd95 helper on the validation loader (presumably the
# 95th-percentile Hausdorff distance, per the section title - confirm).
# Relies on model_file_out and max_epochs defined in earlier cells.
calc_hd95(model, val_loader, device, model_file_out, max_epochs)