# Importing & Loading Dependencies 

In [1]:
!pip install monai

import nibabel as nib
from monai.transforms import LoadImage, Compose, NormalizeIntensityd, RandSpatialCropd, RandFlipd, \
                             RandRotate90d, Rand3DElasticd, RandAdjustContrastd, CenterSpatialCropd,\
                             Resized, RandRotated, RandZoomd, RandGaussianNoised, Spacingd, RandShiftIntensityd,  CropForegroundd, SpatialPadd, AsDiscrete, GridPatchd\

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch
import logging
import numpy as np
import os
import random
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import interpolate
import pdb

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.networks import one_hot
from monai.metrics import DiceMetric, HausdorffDistanceMetric
import torchvision
import math

from grpc import insecure_channel
import argparse
from torch import optim, amp
from monai.losses import DiceLoss
import torch.distributed as dist

from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
import pdb
import logging
from ipywidgets import interact, IntSlider

from monai.losses import DiceLoss
from torch import nn, optim
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0


# Creating Dataset with Preprocessing

In [2]:
class CustomDataset3D(Dataset):
    """Per-patient 3D dataset of four MRI modalities plus a segmentation mask.

    For every patient id the dataset loads the ``t1c``, ``t1n``, ``t2f`` and
    ``t2w`` NIfTI volumes and the ``seg`` mask from ``data_dir/<patient_id>/``,
    stacks the modalities into a 4-channel image, runs the pipeline selected by
    ``mode`` and returns float tensors.

    Args:
        data_dir: Directory that contains one sub-folder per patient.
        patient_lists: Sequence of patient ids (sub-folder names).
        mode: ``'training'`` enables the random augmentations; any other value
            (``'validation'``, ``'testing'``) uses the deterministic pipeline.
    """

    def __init__(self, data_dir, patient_lists, mode):
        self.data_dir = data_dir
        self.patient_lists = patient_lists
        self.mode = mode

    @staticmethod
    def _aspect_preserving_shape(original_shape, target_size):
        """Return the largest shape that fits in ``target_size`` at a uniform scale.

        The scale is the smallest per-axis ratio ``target / original``. Each
        axis is rounded (not truncated) so that floating-point error cannot
        shrink the limiting axis to ``target - 1``, and is clamped to
        ``[1, target]`` so the resize never receives an empty axis.
        """
        scaling_factor = min(
            target / original
            for target, original in zip(target_size, original_shape)
        )
        return [
            min(target, max(1, round(original * scaling_factor)))
            for original, target in zip(original_shape, target_size)
        ]

    @staticmethod
    def _resolve_file_path(folder, name):
        """Return the path of ``name`` inside ``folder``.

        If ``folder/name`` is itself a directory, the first file found below it
        whose name ends with ``.nii`` is returned instead. When no such file
        exists (or ``folder/name`` is not a directory) the joined path is
        returned unchanged.
        """
        file_path = os.path.join(folder, name)
        if os.path.isdir(file_path):
            for root, _, files in os.walk(file_path):
                for file in files:
                    if file.endswith(".nii"):
                        return os.path.join(root, file)
        return file_path

    @staticmethod
    def resize_with_aspect_ratio(keys, target_size):
        """Build a dictionary transform that resizes and pads to ``target_size``.

        Each key is first resized to the aspect-preserving intermediate shape
        and then padded up to ``target_size``. The ``"imgs"`` key is resized
        with trilinear interpolation; every other key uses ``nearest-exact``.

        Args:
            keys: Dictionary keys to process.
            target_size: Final spatial size (three values).

        Returns:
            A callable ``transform(data) -> data``.
        """
        def transform(data):
            for key in keys:
                volume = data[key]
                original_shape = volume.shape[-3:]

                # Intermediate size that preserves the aspect ratio.
                new_shape = CustomDataset3D._aspect_preserving_shape(
                    original_shape, target_size
                )

                resize_transform = Resized(
                    keys=[key],
                    spatial_size=new_shape,
                    mode="trilinear" if key == "imgs" else "nearest-exact",
                )
                data = resize_transform(data)

                # Pad the remaining axes up to the final target size.
                pad_transform = SpatialPadd(keys=[key], spatial_size=target_size, mode="constant")
                data = pad_transform(data)
            return data

        return transform

    @classmethod
    def preprocess(cls, data, mode):
        """Apply the pipeline for ``mode`` to a ``{'imgs', 'masks'}`` dictionary.

        All modes crop to the foreground of ``imgs``, resize/pad to
        128x128x128 and normalise image intensities per channel. ``'training'``
        additionally applies a random flip along spatial axis 2 and a random
        contrast adjustment on the images.

        The pipeline is rebuilt on every call, as in the original code.
        """
        steps = [
            CropForegroundd(keys=["imgs", "masks"], source_key="imgs"),
            cls.resize_with_aspect_ratio(keys=["imgs", "masks"], target_size=[128, 128, 128]),
            NormalizeIntensityd(keys=['imgs'], nonzero=False, channel_wise=True),
        ]

        if mode == 'training':
            steps.extend([
                RandFlipd(keys=["imgs", "masks"], prob=0.5, spatial_axis=2),
                RandAdjustContrastd(keys=["imgs"], prob=0.15, gamma=(0.65, 1.5)),
            ])

        transform = Compose(steps)
        return transform(data)

    def __len__(self):
        return len(self.patient_lists)

    def __getitem__(self, idx):
        """Load, preprocess and return one patient.

        Returns:
            dict with ``'imgs'`` (float tensor, 4 modality channels),
            ``'masks'`` (float tensor, 1 label channel) and ``'patient_id'``.
        """
        patient_id = self.patient_lists[idx]
        patient_folder_path = os.path.join(self.data_dir, patient_id)
        loadimage = LoadImage(reader='NibabelReader', image_only=True)

        def load_with_channel(suffix):
            # Resolve file-or-directory layouts the same way for every volume,
            # including the segmentation, then add a leading channel axis.
            path = self._resolve_file_path(patient_folder_path, patient_id + suffix)
            return torch.Tensor(loadimage(path)).unsqueeze(0)

        modality_tensors = [
            load_with_channel(suffix)
            for suffix in ('-t1c.nii', '-t1n.nii', '-t2f.nii', '-t2w.nii')
        ]
        masks_tensor = load_with_channel('-seg.nii')

        # The mask stays a single label channel. The disabled one-hot split
        # that used to live here mapped 0 -> background, 1 -> necrosis,
        # 2 -> edema, 3 -> enhancing tumor.
        data = {
            'imgs': np.array(torch.cat(modality_tensors, 0)),
            'masks': np.array(masks_tensor),
        }

        augmented_imgs_masks = self.preprocess(data, self.mode)
        imgs = np.array(augmented_imgs_masks['imgs'])
        masks = np.array(augmented_imgs_masks['masks'])

        return {
            'imgs': torch.from_numpy(imgs).type(torch.FloatTensor),
            'masks': torch.from_numpy(masks).type(torch.FloatTensor),
            'patient_id': patient_id,
        }

# Data Loaders

In [3]:
def prepare_data_loaders(args):
    """Split the patients under ``args['data_dir']`` and build three loaders.

    Patient folders are sorted, shuffled with a fixed seed and split
    70% / 10% / remainder into training, validation and testing sets.

    Args:
        args: Mapping with the keys ``'data_dir'``, ``'workers'``,
            ``'train_batch_size'``, ``'val_batch_size'`` and
            ``'test_batch_size'``.

    Returns:
        Tuple ``(trainLoader, valLoader, testLoader)``.

    Note:
        Calls ``random.seed(5)``, i.e. it reseeds Python's global RNG.
    """
    random.seed(5)
    split_ratio = {'training': 0.7, 'validation': 0.1, 'testing': 0.2}
    data_dir = args['data_dir']

    # Only sub-directories are patients; stray files next to them would
    # otherwise end up in a split and fail when the dataset tries to load them.
    patient_lists = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    total_patients = len(patient_lists)

    # Sorting before the seeded shuffle keeps the split independent of the
    # order in which the filesystem lists the entries.
    random.shuffle(patient_lists)

    train_split = int(split_ratio['training'] * total_patients)
    val_split = int(split_ratio['validation'] * total_patients)

    train_patient_lists = sorted(patient_lists[:train_split])
    val_patient_lists = sorted(patient_lists[train_split:train_split + val_split])
    # The test set takes whatever is left, so rounding never drops a patient.
    test_patient_lists = sorted(patient_lists[train_split + val_split:])

    print('Train', train_patient_lists)
    print('Val', val_patient_lists)
    print('Testing', test_patient_lists)

    print('Number of training samples', len(train_patient_lists))
    print('Number of validation samples', len(val_patient_lists))
    print('Number of testing samples', len(test_patient_lists))

    trainDataset = CustomDataset3D(data_dir, train_patient_lists, mode='training')
    valDataset = CustomDataset3D(data_dir, val_patient_lists, mode='validation')
    testDataset = CustomDataset3D(data_dir, test_patient_lists, mode='testing')

    num_workers = args['workers']
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': True}
    if num_workers > 0:
        # Recent PyTorch releases reject prefetch_factor when loading happens
        # in the main process, so only pass it with worker processes.
        loader_kwargs['prefetch_factor'] = 2

    trainLoader = DataLoader(
        trainDataset, batch_size=args['train_batch_size'], shuffle=True, **loader_kwargs)

    valLoader = DataLoader(
        valDataset, batch_size=args['val_batch_size'], shuffle=False, **loader_kwargs)

    testLoader = DataLoader(
        testDataset, batch_size=args['test_batch_size'], shuffle=False, **loader_kwargs)

    return trainLoader, valLoader, testLoader

## Glioma Data Split

In [14]:
# Run configuration for the glioma cohort. prepare_data_loaders reads the data
# directory, worker count and batch sizes; the remaining entries are presumably
# consumed by later cells.
args = dict(
    workers=2,
    epochs=25,
    train_batch_size=2,
    val_batch_size=2,
    test_batch_size=3,
    learning_rate=1e-3,
    weight_decay=1e-5,
    data_dir='/kaggle/input/bratsglioma/Training/',
    in_checkpoint_dir=Path('/kaggle/input/gliomateachernormalizednew/'),
    out_checkpoint_dir=Path('/kaggle/working/'),
)

# Build the loaders; this also prints the patient ids in each split.
trainLoader, valLoader, testLoader = prepare_data_loaders(args)
print("------------------------------------------------")

Train ['BraTS-GLI-00002-000', 'BraTS-GLI-00003-000', 'BraTS-GLI-00006-000', 'BraTS-GLI-00008-000', 'BraTS-GLI-00009-000', 'BraTS-GLI-00009-001', 'BraTS-GLI-00011-000', 'BraTS-GLI-00014-000', 'BraTS-GLI-00014-001', 'BraTS-GLI-00016-001', 'BraTS-GLI-00017-000', 'BraTS-GLI-00018-000', 'BraTS-GLI-00020-000', 'BraTS-GLI-00021-000', 'BraTS-GLI-00021-001', 'BraTS-GLI-00024-000', 'BraTS-GLI-00025-000', 'BraTS-GLI-00026-000', 'BraTS-GLI-00028-000', 'BraTS-GLI-00030-000', 'BraTS-GLI-00031-000', 'BraTS-GLI-00031-001', 'BraTS-GLI-00032-000', 'BraTS-GLI-00032-001', 'BraTS-GLI-00033-000', 'BraTS-GLI-00036-001', 'BraTS-GLI-00043-000', 'BraTS-GLI-00045-000', 'BraTS-GLI-00045-001', 'BraTS-GLI-00046-000', 'BraTS-GLI-00048-001', 'BraTS-GLI-00051-000', 'BraTS-GLI-00052-000', 'BraTS-GLI-00053-000', 'BraTS-GLI-00053-001', 'BraTS-GLI-00054-000', 'BraTS-GLI-00056-000', 'BraTS-GLI-00056-001', 'BraTS-GLI-00058-000', 'BraTS-GLI-00058-001', 'BraTS-GLI-00059-001', 'BraTS-GLI-00060-000', 'BraTS-GLI-00061-001', 'Bra

## Africa Data Split

In [15]:
# Run configuration for the Africa (BraTS-SSA) cohort. Only the data directory
# differs from the other cohort cells.
args = dict(
    workers=2,
    epochs=25,
    train_batch_size=2,
    val_batch_size=2,
    test_batch_size=3,
    learning_rate=1e-3,
    weight_decay=1e-5,
    data_dir='/kaggle/input/bratsafrica24/',
    in_checkpoint_dir=Path('/kaggle/input/gliomateachernormalizednew/'),
    out_checkpoint_dir=Path('/kaggle/working/'),
)

# Rebinds trainLoader / valLoader / testLoader to this cohort's split.
trainLoader, valLoader, testLoader = prepare_data_loaders(args)
print("------------------------------------------------")

Train ['BraTS-SSA-00007-000', 'BraTS-SSA-00008-000', 'BraTS-SSA-00010-000', 'BraTS-SSA-00011-000', 'BraTS-SSA-00012-000', 'BraTS-SSA-00015-000', 'BraTS-SSA-00025-000', 'BraTS-SSA-00026-000', 'BraTS-SSA-00028-000', 'BraTS-SSA-00040-000', 'BraTS-SSA-00041-000', 'BraTS-SSA-00046-000', 'BraTS-SSA-00047-000', 'BraTS-SSA-00051-000', 'BraTS-SSA-00056-000', 'BraTS-SSA-00074-000', 'BraTS-SSA-00076-000', 'BraTS-SSA-00078-000', 'BraTS-SSA-00080-000', 'BraTS-SSA-00081-000', 'BraTS-SSA-00089-000', 'BraTS-SSA-00092-000', 'BraTS-SSA-00095-000', 'BraTS-SSA-00096-000', 'BraTS-SSA-00105-000', 'BraTS-SSA-00106-000', 'BraTS-SSA-00108-000', 'BraTS-SSA-00110-000', 'BraTS-SSA-00111-000', 'BraTS-SSA-00112-000', 'BraTS-SSA-00113-000', 'BraTS-SSA-00114-000', 'BraTS-SSA-00116-000', 'BraTS-SSA-00117-000', 'BraTS-SSA-00119-000', 'BraTS-SSA-00120-000', 'BraTS-SSA-00121-000', 'BraTS-SSA-00124-000', 'BraTS-SSA-00125-000', 'BraTS-SSA-00126-000', 'BraTS-SSA-00127-000', 'BraTS-SSA-00128-000', 'BraTS-SSA-00129-000', 'Bra

## Pediatrics Data Split

In [16]:
# Run configuration for the pediatric (BraTS-PED) cohort. Only the data
# directory differs from the other cohort cells.
args = dict(
    workers=2,
    epochs=25,
    train_batch_size=2,
    val_batch_size=2,
    test_batch_size=3,
    learning_rate=1e-3,
    weight_decay=1e-5,
    data_dir='/kaggle/input/bratsped/Training/',
    in_checkpoint_dir=Path('/kaggle/input/gliomateachernormalizednew/'),
    out_checkpoint_dir=Path('/kaggle/working/'),
)

# Rebinds trainLoader / valLoader / testLoader to this cohort's split.
trainLoader, valLoader, testLoader = prepare_data_loaders(args)
print("------------------------------------------------")

Train ['BraTS-PED-00002-000', 'BraTS-PED-00004-000', 'BraTS-PED-00008-000', 'BraTS-PED-00009-000', 'BraTS-PED-00019-000', 'BraTS-PED-00020-000', 'BraTS-PED-00023-000', 'BraTS-PED-00024-000', 'BraTS-PED-00025-000', 'BraTS-PED-00036-000', 'BraTS-PED-00037-000', 'BraTS-PED-00039-000', 'BraTS-PED-00040-000', 'BraTS-PED-00041-000', 'BraTS-PED-00044-000', 'BraTS-PED-00046-000', 'BraTS-PED-00048-000', 'BraTS-PED-00050-000', 'BraTS-PED-00051-000', 'BraTS-PED-00055-000', 'BraTS-PED-00057-000', 'BraTS-PED-00060-000', 'BraTS-PED-00063-000', 'BraTS-PED-00064-000', 'BraTS-PED-00069-000', 'BraTS-PED-00070-000', 'BraTS-PED-00072-000', 'BraTS-PED-00078-000', 'BraTS-PED-00079-000', 'BraTS-PED-00080-000', 'BraTS-PED-00081-000', 'BraTS-PED-00082-000', 'BraTS-PED-00083-000', 'BraTS-PED-00085-000', 'BraTS-PED-00097-000', 'BraTS-PED-00098-000', 'BraTS-PED-00100-000', 'BraTS-PED-00101-000', 'BraTS-PED-00102-000', 'BraTS-PED-00103-000', 'BraTS-PED-00104-000', 'BraTS-PED-00105-000', 'BraTS-PED-00108-000', 'Bra

## Meningioma Data Split

In [17]:
# Run configuration for the meningioma (BraTS-MEN) cohort. Only the data
# directory differs from the other cohort cells.
args = dict(
    workers=2,
    epochs=25,
    train_batch_size=2,
    val_batch_size=2,
    test_batch_size=3,
    learning_rate=1e-3,
    weight_decay=1e-5,
    data_dir='/kaggle/input/bratsmen',
    in_checkpoint_dir=Path('/kaggle/input/gliomateachernormalizednew/'),
    out_checkpoint_dir=Path('/kaggle/working/'),
)

# Rebinds trainLoader / valLoader / testLoader to this cohort's split.
trainLoader, valLoader, testLoader = prepare_data_loaders(args)
print("------------------------------------------------")

Train ['BraTS-MEN-00004-000', 'BraTS-MEN-00010-000', 'BraTS-MEN-00016-000', 'BraTS-MEN-00018-000', 'BraTS-MEN-00020-000', 'BraTS-MEN-00021-000', 'BraTS-MEN-00023-000', 'BraTS-MEN-00025-000', 'BraTS-MEN-00027-000', 'BraTS-MEN-00028-000', 'BraTS-MEN-00031-000', 'BraTS-MEN-00032-000', 'BraTS-MEN-00034-000', 'BraTS-MEN-00037-000', 'BraTS-MEN-00040-000', 'BraTS-MEN-00043-000', 'BraTS-MEN-00045-000', 'BraTS-MEN-00047-000', 'BraTS-MEN-00048-000', 'BraTS-MEN-00052-000', 'BraTS-MEN-00054-000', 'BraTS-MEN-00056-000', 'BraTS-MEN-00060-000', 'BraTS-MEN-00062-000', 'BraTS-MEN-00066-000', 'BraTS-MEN-00067-000', 'BraTS-MEN-00069-000', 'BraTS-MEN-00070-000', 'BraTS-MEN-00071-000', 'BraTS-MEN-00073-000', 'BraTS-MEN-00074-000', 'BraTS-MEN-00074-003', 'BraTS-MEN-00074-004', 'BraTS-MEN-00074-006', 'BraTS-MEN-00074-008', 'BraTS-MEN-00074-009', 'BraTS-MEN-00075-000', 'BraTS-MEN-00077-000', 'BraTS-MEN-00078-000', 'BraTS-MEN-00081-000', 'BraTS-MEN-00085-000', 'BraTS-MEN-00087-000', 'BraTS-MEN-00088-000', 'Bra

## Metastatic Data Split

In [18]:
# Run configuration for the metastasis (BraTS-MET) cohort. Only the data
# directory differs from the other cohort cells.
args = dict(
    workers=2,
    epochs=25,
    train_batch_size=2,
    val_batch_size=2,
    test_batch_size=3,
    learning_rate=1e-3,
    weight_decay=1e-5,
    data_dir='/kaggle/input/bratsmet24',
    in_checkpoint_dir=Path('/kaggle/input/gliomateachernormalizednew/'),
    out_checkpoint_dir=Path('/kaggle/working/'),
)

# Rebinds trainLoader / valLoader / testLoader to this cohort's split.
trainLoader, valLoader, testLoader = prepare_data_loaders(args)

Train ['BraTS-MET-00004-000', 'BraTS-MET-00005-000', 'BraTS-MET-00006-000', 'BraTS-MET-00008-000', 'BraTS-MET-00009-000', 'BraTS-MET-00011-000', 'BraTS-MET-00013-000', 'BraTS-MET-00014-000', 'BraTS-MET-00016-000', 'BraTS-MET-00017-000', 'BraTS-MET-00018-000', 'BraTS-MET-00019-000', 'BraTS-MET-00020-000', 'BraTS-MET-00021-000', 'BraTS-MET-00022-000', 'BraTS-MET-00023-000', 'BraTS-MET-00024-000', 'BraTS-MET-00025-000', 'BraTS-MET-00026-000', 'BraTS-MET-00028-000', 'BraTS-MET-00029-000', 'BraTS-MET-00032-000', 'BraTS-MET-00033-000', 'BraTS-MET-00035-000', 'BraTS-MET-00036-000', 'BraTS-MET-00086-000', 'BraTS-MET-00089-000', 'BraTS-MET-00090-000', 'BraTS-MET-00096-000', 'BraTS-MET-00097-000', 'BraTS-MET-00098-000', 'BraTS-MET-00100-000', 'BraTS-MET-00102-000', 'BraTS-MET-00104-000', 'BraTS-MET-00106-000', 'BraTS-MET-00107-000', 'BraTS-MET-00108-000', 'BraTS-MET-00110-000', 'BraTS-MET-00112-000', 'BraTS-MET-00113-000', 'BraTS-MET-00114-000', 'BraTS-MET-00115-000', 'BraTS-MET-00117-000', 'Bra