# Importing & Loading Dependencies

In [None]:
!pip install monai

import nibabel as nib
from monai.transforms import LoadImage, Compose, NormalizeIntensityd, RandSpatialCropd, RandFlipd, \
                             RandRotate90d, Rand3DElasticd, RandAdjustContrastd, CenterSpatialCropd,\
                             Resized, RandRotated, RandZoomd, RandGaussianNoised, Spacingd, RandShiftIntensityd,  CropForegroundd, SpatialPadd, AsDiscrete, GridPatchd\

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch
import logging
import numpy as np
import os
import random
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import interpolate
import pdb

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.networks import one_hot
from monai.metrics import DiceMetric, HausdorffDistanceMetric
import torchvision
import math

from grpc import insecure_channel
import argparse
from torch import optim, amp
from monai.losses import DiceLoss,BarlowTwinsLoss
import torch.distributed as dist

from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
import pdb
import logging
from ipywidgets import interact, IntSlider

from monai.losses import DiceLoss
from torch import nn, optim
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision

# Creating Dataset with Preprocessing

In [None]:
class CustomDataset3D(Dataset):
    """Dataset of multi-modal 3D volumes and their segmentation masks.

    Every item is assembled from four NIfTI modalities (``t1c``, ``t1n``,
    ``t2f``, ``t2w``) and one ``-seg.nii`` file, stacked channel-first,
    cropped to the image foreground, resized (aspect ratio preserved) and
    padded to 128x128x128, then intensity-normalised per channel.  In
    ``'training'`` mode a random flip and a random contrast adjustment are
    applied on top.

    Args:
        data_dirs: dataset root directories. Stored on the instance only; the
            folder actually read is chosen from ``_DATASET_ROOTS``.
        patient_lists: sequence of patient ids. The second ``-``-separated
            field of an id selects the dataset root.
        mode: ``'training'``, ``'validation'``; any other value is treated as
            testing.
    """

    # Dataset tag (second field of the patient id) -> folder holding the patients.
    _DATASET_ROOTS = {
        'GLI': '/kaggle/input/bratsglioma/Training',
        'SSA': '/kaggle/input/bratsafrica24',
        'PED': '/kaggle/input/bratsped/Training',
        'MEN': '/kaggle/input/bratsmen',
    }
    # Used for every tag that is not listed in _DATASET_ROOTS.
    _DEFAULT_ROOT = '/kaggle/input/bratsmet24'
    # Channel order of the stacked image tensor.
    _MODALITIES = ('t1c', 't1n', 't2f', 't2w')

    def __init__(
        self,
        data_dirs,
        patient_lists,
        mode
        ):

        self.data_dirs = data_dirs
        self.patient_lists = patient_lists
        self.mode = mode

    @staticmethod
    def _aspect_preserving_shape(original_shape, target_size):
        """Return the largest shape with the input's aspect ratio that fits ``target_size``.

        The computation uses exact integer arithmetic.  The float version
        ``int(dim * (target / dim))`` can land on e.g. 127.999... and truncate
        to one voxel less than the target on the limiting axis.
        """
        dims = [int(d) for d in original_shape]
        # The axis with the smallest target/original ratio limits the scaling.
        limiting = min(range(len(dims)), key=lambda i: target_size[i] / dims[i])
        numerator, denominator = int(target_size[limiting]), dims[limiting]
        return [d * numerator // denominator for d in dims]

    @staticmethod
    def resize_with_aspect_ratio(keys, target_size):
        """Build a dict transform that resizes then pads each key to ``target_size``.

        Args:
            keys: dictionary keys of the volumes to process.
            target_size: final spatial size (three entries).

        Returns:
            A callable taking and returning the data dictionary.
        """
        def transform(data):
            for key in keys:
                volume = data[key]
                new_shape = CustomDataset3D._aspect_preserving_shape(
                    volume.shape[-3:], target_size
                )

                # Images are interpolated; every other key (the label map) uses
                # nearest-neighbour so that label values are not blended.
                resize_transform = Resized(
                    keys=[key],
                    spatial_size=new_shape,
                    mode="trilinear" if key == "imgs" else "nearest-exact",
                )
                data = resize_transform(data)

                # Padding to the final target size
                pad_transform = SpatialPadd(keys=[key], spatial_size=target_size, mode="constant")
                data = pad_transform(data)
            return data

        return transform

    def preprocess(self, data, mode):
        """Apply the preprocessing (and, for training, augmentation) pipeline.

        Args:
            data: dict with ``'imgs'`` and ``'masks'`` channel-first volumes.
            mode: ``'training'`` adds the random augmentations; validation and
                testing share the same deterministic pipeline.

        Returns:
            The transformed data dictionary.
        """
        steps = [
            CropForegroundd(keys=["imgs", "masks"], source_key="imgs"),
            self.resize_with_aspect_ratio(keys=["imgs", "masks"], target_size=[128, 128, 128]),
            NormalizeIntensityd(keys=['imgs'], nonzero=False, channel_wise=True),
        ]

        if mode == 'training':
            steps += [
                RandFlipd(keys=["imgs", "masks"], prob=0.5, spatial_axis=2),
                RandAdjustContrastd(keys=["imgs"], prob=0.15, gamma=(0.65, 1.5)),
            ]

        return Compose(steps)(data)

    @staticmethod
    def _resolve_file_path(folder, name):
        """Return the path of ``name`` inside ``folder``.

        When ``folder/name`` is itself a directory, the first ``.nii`` file
        found while walking it is returned instead; if none is found the
        directory path is returned unchanged.
        """
        file_path = os.path.join(folder, name)
        if os.path.isdir(file_path):
            for root, _, files in os.walk(file_path):
                for file in files:
                    if file.endswith(".nii"):
                        return os.path.join(root, file)
        return file_path

    def __len__(self):
        """Return the number of patients in this dataset."""
        return len(self.patient_lists)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``.

        Returns:
            dict with float tensors ``'imgs'`` (four channels) and ``'masks'``
            (one channel), plus the ``'patient_id'`` and its ``'data_type'`` tag.
        """
        patient_id = self.patient_lists[idx]
        loadimage = LoadImage(reader='NibabelReader', image_only=True)

        data_type = patient_id.split('-')[1]
        patient_folder_path = os.path.join(
            self._DATASET_ROOTS.get(data_type, self._DEFAULT_ROOT), patient_id
        )

        # One (1, D, H, W)-style tensor per modality, in _MODALITIES order.
        channels = []
        for modality in self._MODALITIES:
            modality_path = self._resolve_file_path(
                patient_folder_path, patient_id + '-' + modality + '.nii'
            )
            channels.append(torch.Tensor(loadimage(modality_path)).unsqueeze(0))

        # The segmentation is read directly from the patient folder.
        seg_path = os.path.join(patient_folder_path, patient_id + '-seg.nii')
        masks_tensor = torch.Tensor(loadimage(seg_path)).unsqueeze(0)

        # Concatenating images and mask together also checks that their
        # spatial shapes agree before any transform runs.
        concat_tensor = torch.cat((*channels, masks_tensor), 0)
        n_modalities = len(self._MODALITIES)
        data = {
            'imgs'  : np.array(concat_tensor[0:n_modalities, :, :, :]),
            'masks' : np.array(concat_tensor[n_modalities:, :, :, :]),
        }

        augmented_imgs_masks = self.preprocess(data, self.mode)
        imgs  = np.array(augmented_imgs_masks['imgs'])
        masks = np.array(augmented_imgs_masks['masks'])

        return {
            'imgs'  : torch.from_numpy(imgs).type(torch.FloatTensor),
            'masks' : torch.from_numpy(masks).type(torch.FloatTensor),
            'patient_id' : patient_id,
            'data_type' : data_type,
        }

# Data Loaders

In [None]:
def combine_datasets(dataset_lists, batch_size=2):
    """Interleave fixed-size chunks taken from several id lists.

    The lists are walked in steps of ``batch_size`` up to the length of the
    shortest one.  At each step every list contributes one chunk, in the order
    the lists were given, so that consecutive groups of ``batch_size`` entries
    in the result come from a single source list.

    Args:
        dataset_lists: sequence of sliceable collections (e.g. lists of ids).
        batch_size: number of entries taken from each list per step.

    Returns:
        A flat list with the interleaved chunks.
    """
    shortest = min(len(ids) for ids in dataset_lists)
    interleaved = []

    for start in range(0, shortest, batch_size):
        stop = start + batch_size
        for ids in dataset_lists:
            chunk = ids[start:stop]
            # NOTE(review): an incomplete chunk ends this step for the lists
            # that follow, but chunks already taken from earlier lists are
            # kept, so the result depends on the order of dataset_lists.
            if len(chunk) != batch_size:
                break
            interleaved += chunk

    return interleaved

In [None]:
def prepare_data_loaders(args):
    """Split every dataset directory and build train/val/test data loaders.

    For each directory in ``args['data_dirs']`` the entries are sorted,
    shuffled with a fixed seed and split 71% / 9% / remainder into training,
    validation and testing ids.  The per-directory splits are interleaved
    with ``combine_datasets`` and wrapped in ``CustomDataset3D`` /
    ``DataLoader`` objects.

    Args:
        args: mapping with the keys ``'data_dirs'``, ``'train_batch_size'``,
            ``'val_batch_size'``, ``'test_batch_size'`` and ``'workers'``.

    Returns:
        Tuple ``(trainLoader, valLoader, testLoader)``.
    """
    split_ratio = {'training': 0.71, 'validation': 0.09, 'testing': 0.2}
    train_datasets, val_datasets, test_datasets = [], [], []

    for data_dir in args['data_dirs']:
        patient_lists = sorted(os.listdir(data_dir))
        total_patients = len(patient_lists)

        # Re-seeding per directory keeps each split reproducible on its own.
        random.seed(5)
        random.shuffle(patient_lists)

        train_split = int(split_ratio['training'] * total_patients)
        val_split = int(split_ratio['validation'] * total_patients)

        # Whatever is left after training and validation is the test split.
        train_patient_lists = sorted(patient_lists[:train_split])
        val_patient_lists = sorted(patient_lists[train_split : train_split + val_split])
        test_patient_lists = sorted(patient_lists[train_split + val_split :])

        print(f'Test IDs of {data_dir}', test_patient_lists)

        # The fourth path component is used as display name; fall back to the
        # whole path when the directory is not nested that deeply.
        path_parts = data_dir.split("/")
        dataset_name = path_parts[3] if len(path_parts) > 3 else data_dir
        print(f'Number of training samples in {dataset_name} DataSet: {len(train_patient_lists)}')
        print(f'Number of validation samples in {dataset_name} DataSet: {len(val_patient_lists)}')
        print(f'Number of testing samples in {dataset_name} DataSet: {len(test_patient_lists)} ')

        train_datasets.append(train_patient_lists)
        val_datasets.append(val_patient_lists)
        test_datasets.append(test_patient_lists)

    combined_trainDataset = combine_datasets(train_datasets, batch_size=args['train_batch_size'])
    combined_valDataset = combine_datasets(val_datasets, batch_size=args['val_batch_size'])
    combined_testDataset = combine_datasets(test_datasets, batch_size=args['test_batch_size'])

    print('Number of combined training samples', len(combined_trainDataset))
    print('Number of combined validation samples', len(combined_valDataset))
    print('Number of combined testing samples', len(combined_testDataset))

    # shuffle stays off so the interleaved per-dataset order is kept.
    loader_kwargs = {'num_workers': args['workers'], 'pin_memory': True, 'shuffle': False}
    if args['workers']:
        # DataLoader only accepts prefetch_factor when worker processes are used.
        loader_kwargs['prefetch_factor'] = 2

    def make_loader(patient_ids, mode, batch_size):
        """Wrap ``patient_ids`` in a dataset of the given mode and a loader."""
        dataset = CustomDataset3D(args['data_dirs'], patient_ids, mode=mode)
        return DataLoader(dataset, batch_size=batch_size, **loader_kwargs)

    trainLoader = make_loader(combined_trainDataset, 'training', args['train_batch_size'])
    valLoader = make_loader(combined_valDataset, 'validation', args['val_batch_size'])
    testLoader = make_loader(combined_testDataset, 'testing', args['test_batch_size'])

    return trainLoader, valLoader, testLoader

# DynUNet Model

In [None]:
class UnetBasicBlock(nn.Module):
    """
    Basic DynUNet building block: two (convolution, normalization, activation) stages.

    Based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of both convolutions.
        stride: stride of the first convolution; the second always uses 1.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: forwarded to ``get_conv_layer`` for both convolutions.
            NOTE(review): the layers are requested with ``conv_only=True`` -
            confirm whether dropout has any effect in that configuration.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        # Settings shared by the two convolutions.
        common = dict(kernel_size=kernel_size, dropout=dropout, conv_only=True)

        # Submodules are created in a fixed order (conv1, conv2, activation,
        # norm1, norm2) so parameter registration stays stable.
        self.conv1 = get_conv_layer(spatial_dims, in_channels, out_channels, stride=stride, **common)
        self.conv2 = get_conv_layer(spatial_dims, out_channels, out_channels, stride=1, **common)
        self.lrelu = get_act_layer(name=act_name)

        norm_args = dict(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm1 = get_norm_layer(**norm_args)
        self.norm2 = get_norm_layer(**norm_args)

    def forward(self, inp):
        """Run both conv -> norm -> activation stages on ``inp``."""
        first_stage = self.lrelu(self.norm1(self.conv1(inp)))
        return self.lrelu(self.norm2(self.conv2(first_stage)))



class UnetUpBlock(nn.Module):
    """
    Decoder stage for DynUNet: transposed-convolution upsampling, concatenation
    with the skip connection, then a ``UnetBasicBlock``.

    Based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels; the skip tensor must have
            this many channels as well.
        kernel_size: kernel size of the convolutions applied after concatenation.
        upsample_kernel_size: kernel size of the transposed convolution; the
            same value is used as its stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout setting forwarded to the sub-layers.
        trans_bias: whether the transposed convolution has a bias.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
        trans_bias: bool = False,
    ):
        super().__init__()

        # Upsampling: kernel size and stride are identical.
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_kernel_size,
            dropout=dropout,
            bias=trans_bias,
            conv_only=True,
            is_transposed=True,
        )

        # Fuses the upsampled features with the skip connection, hence the
        # doubled number of input channels.
        self.conv_block = UnetBasicBlock(
            spatial_dims,
            2 * out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dropout=dropout,
            norm_name=norm_name,
            act_name=act_name,
        )

    def forward(self, inp, skip):
        """Upsample ``inp``, concatenate ``skip`` along channels and convolve."""
        upsampled = self.transp_conv(inp)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv_block(merged)



class UnetOutBlock(nn.Module):
    """Output head: a single biased convolution with kernel size 1 and stride 1.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        dropout: dropout setting forwarded to ``get_conv_layer``.
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None
    ):
        super().__init__()
        head_options = dict(kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True)
        self.conv = get_conv_layer(spatial_dims, in_channels, out_channels, **head_options)

    def forward(self, inp):
        """Project ``inp`` to ``out_channels`` channels."""
        projected = self.conv(inp)
        return projected
    

def get_conv_layer(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
    stride: Union[Sequence[int], int] = 1,
    act: Optional[Union[Tuple, str]] = Act.PRELU,
    norm: Union[Tuple, str] = Norm.INSTANCE,
    dropout: Optional[Union[Tuple, str, float]] = None,
    bias: bool = False,
    conv_only: bool = True,
    is_transposed: bool = False,
):
    """Create a ``Convolution`` whose padding is derived from kernel size and stride.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        act: activation type passed through to ``Convolution``.
        norm: normalization type passed through to ``Convolution``.
        dropout: dropout setting passed through to ``Convolution``.
        bias: whether the convolution has a bias term.
        conv_only: passed through to ``Convolution``.
        is_transposed: build a transposed convolution; an output padding is
            then computed with ``get_output_padding``.

    Returns:
        The configured ``Convolution`` module.
    """
    padding = get_padding(kernel_size, stride)
    # Output padding is only computed for transposed convolutions.
    output_padding = get_output_padding(kernel_size, stride, padding) if is_transposed else None

    layer_options = dict(
        strides=stride,
        kernel_size=kernel_size,
        act=act,
        norm=norm,
        dropout=dropout,
        bias=bias,
        conv_only=conv_only,
        is_transposed=is_transposed,
        padding=padding,
        output_padding=output_padding,
    )
    return Convolution(spatial_dims, in_channels, out_channels, **layer_options)


def get_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:
    """Compute the convolution padding ``(kernel_size - stride + 1) / 2`` per axis.

    Fractional results are truncated towards zero.

    Args:
        kernel_size: kernel size, scalar or one value per axis.
        stride: stride, scalar or one value per axis.

    Returns:
        A single int for scalar input, otherwise a tuple with one int per axis.

    Raises:
        AssertionError: if the padding would be negative on any axis.
    """
    half_extent = (np.atleast_1d(kernel_size) - np.atleast_1d(stride) + 1) / 2
    if half_extent.min() < 0:
        raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.")

    per_axis = tuple(map(int, half_extent))
    if len(per_axis) == 1:
        return per_axis[0]
    return per_axis


def get_output_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:
    """Compute the transposed-convolution output padding ``2 * padding + stride - kernel_size``.

    Args:
        kernel_size: kernel size, scalar or one value per axis.
        stride: stride, scalar or one value per axis.
        padding: padding, scalar or one value per axis.

    Returns:
        A single int for scalar input, otherwise a tuple with one int per axis.

    Raises:
        AssertionError: if the output padding would be negative on any axis.
    """
    extra = 2 * np.atleast_1d(padding) + np.atleast_1d(stride) - np.atleast_1d(kernel_size)
    if extra.min() < 0:
        raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.")

    per_axis = tuple(map(int, extra))
    if len(per_axis) == 1:
        return per_axis[0]
    return per_axis

def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or several networks.

    Args:
        nets: a single network or a list of networks; ``None`` entries are skipped.
        requires_grad: value assigned to each parameter's ``requires_grad`` flag.
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for parameter in network.parameters():
            parameter.requires_grad = requires_grad

In [None]:
class DynUNet(nn.Module):
    """U-Net with a fixed layout: input block, five strided encoder blocks,
    a bottleneck, six decoder blocks and three output heads.

    Channel widths are 64, 96, 128, 192, 256, 384 along the encoder and 512
    in the bottleneck; the decoder mirrors them.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of channels produced by every output head.
        deep_supervision: in training mode, also return the two
            lower-resolution heads.
        KD: when true, the multi-head output and the bottleneck feature map
            are returned regardless of the training flag.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        deep_supervision: bool,
        KD: bool = False
    ):
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deep_supervision = deep_supervision
        self.KD_enabled = KD

        def conv_stage(n_in, n_out, stride=2):
            # Encoder-side block; a stride of 2 halves every spatial dimension.
            return UnetBasicBlock(
                spatial_dims=self.spatial_dims,
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=3,
                stride=stride,
            )

        def up_stage(n_in, n_out):
            # Decoder-side block; upsamples by 2 and merges a skip connection.
            return UnetUpBlock(
                spatial_dims=self.spatial_dims,
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=3,
                upsample_kernel_size=2,
            )

        def head(n_in):
            # Output projection to the requested number of channels.
            return UnetOutBlock(
                spatial_dims=self.spatial_dims,
                in_channels=n_in,
                out_channels=self.out_channels,
            )

        # The assignment order below defines the module registration order.
        self.input_conv = conv_stage(self.in_channels, 64, stride=1)
        self.down1 = conv_stage(64, 96)
        self.down2 = conv_stage(96, 128)
        self.down3 = conv_stage(128, 192)
        self.down4 = conv_stage(192, 256)
        self.down5 = conv_stage(256, 384)
        self.bottleneck = conv_stage(384, 512)

        self.up1 = up_stage(512, 384)
        self.up2 = up_stage(384, 256)
        self.up3 = up_stage(256, 192)
        self.up4 = up_stage(192, 128)
        self.up5 = up_stage(128, 96)
        self.up6 = up_stage(96, 64)

        self.out1 = head(64)
        self.out2 = head(96)
        self.out3 = head(128)

    def forward( self, input ):
        """Run the network.

        Returns:
            ``{'pred': output}`` with the full-resolution head only, or - when
            ``(training and deep_supervision) or KD`` - a dict whose ``'pred'``
            stacks the three heads along dim 1 (the two coarser ones
            interpolated to the full-resolution size) and whose
            ``'bottleneck_feature_map'`` holds the bottleneck activation.
        """
        # Encoder: collect every activation that later serves as a skip.
        # For a 128^3 input the spatial size goes 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        skips = [self.input_conv(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            skips.append(down(skips[-1]))

        # Bottleneck (2^3 spatial size for a 128^3 input).
        deepest = self.bottleneck(skips[-1])

        # Decoder: consume the skips from deepest to shallowest.
        decoded = []
        features = deepest
        up_blocks = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6)
        for up, skip in zip(up_blocks, reversed(skips)):
            features = up(features, skip)
            decoded.append(features)

        full_res, half_res, quarter_res = decoded[-1], decoded[-2], decoded[-3]
        output1 = self.out1(full_res)

        multi_head = (self.training and self.deep_supervision) or self.KD_enabled
        if not multi_head:
            return { 'pred' : output1 }

        # Bring the coarser heads to the full-resolution spatial size before stacking.
        full_size = output1.shape[2:]
        heads = [
            output1,
            interpolate(self.out2(half_res), full_size),
            interpolate(self.out3(quarter_res), full_size),
        ]
        return { 'pred' : torch.stack(heads, dim=1),
                 'bottleneck_feature_map' : deepest }

# Testing

In [None]:
from monai.metrics import DiceMetric, HausdorffDistanceMetric

class ComputeMetrics(nn.Module):
    """Dice score and 95th-percentile Hausdorff distance for WT / TC / ET regions.

    ``forward`` thresholds the sigmoid of the logits at 0.5 and compares
    prediction channels 1, 2 and 3 against the regions derived from the label
    map: WT = label > 0, TC = label 1 or 3, ET = label 3.
    """

    # Hausdorff value reported when exactly one of prediction / ground truth
    # is empty.  Numerically sqrt(240^2 + 240^2 + 155^2) - presumably the
    # diagonal of the native volume; TODO confirm.
    _HD_PENALTY = 373.1287

    def __init__(self):
        super(ComputeMetrics, self).__init__()
        self.dice_metric = DiceMetric(reduction="mean_batch")
        self.hausdorff_metric = HausdorffDistanceMetric(percentile=95.0, reduction="mean_batch")

    def compute(self, p, y, lbl):
        """Compute Dice and Hausdorff distance for one region.

        Args:
            p: binary prediction for the region.
            y: binary ground truth for the region, same shape as ``p``.
            lbl: region name used in the printed messages.

        Returns:
            Tuple ``(dice, hausdorff)`` of 0-dim tensors.  Empty-region cases
            return fixed values; otherwise the per-sample metric values are
            averaged (ignoring NaN entries) so that batches with more than one
            sample also yield a single value.
        """
        print(f"{lbl} - Prediction unique values: {torch.unique(p)}")
        print(f"{lbl} - Ground truth unique values: {torch.unique(y)}")

        p, y = p.float(), y.float()
        pred_is_empty = torch.sum(p) == 0
        gt_is_empty = torch.sum(y) == 0

        if gt_is_empty and pred_is_empty:  # True Negative Case: No foreground pixels in GT
            print(f"{lbl} - No positive samples in ground truth.")
            print(f"Dice scores for {lbl} for this batch: {1.0}")
            print(f"Hausdorff distances for {lbl} for this batch: {0.0}")
            return torch.tensor(1.0), torch.tensor(0.0)

        if pred_is_empty:  # False Negative Case: GT has 1s, Prediction is all 0s
            print(f"{lbl} - False Negative Case: GT has positive samples, but prediction is empty.")
            print(f"Dice scores for {lbl} for this batch: {0.0}")
            print(f"Hausdorff distances for {lbl} for this batch: {self._HD_PENALTY}")
            return torch.tensor(0.0), torch.tensor(self._HD_PENALTY)

        if gt_is_empty:  # False Positive Case: Prediction has 1s, GT is all 0s
            print(f"{lbl} - False Positive Case: Prediction has positives, but ground truth is empty.")
            print(f"Dice scores for {lbl} for this batch: {0.0}")
            print(f"Hausdorff distances for {lbl} for this batch: {self._HD_PENALTY}")
            return torch.tensor(0.0), torch.tensor(self._HD_PENALTY)

        # Compute metrics normally.  The metric objects are only needed on
        # this path, so they are reset here rather than on every call.
        self.dice_metric.reset()
        self.hausdorff_metric.reset()

        # Calling a metric yields one value per sample; reduce to a scalar so
        # that .item() works for any batch size and the return shape matches
        # the early-return cases above.
        dice_score = torch.nanmean(self.dice_metric(p, y))
        hausdorff_dist = torch.nanmean(self.hausdorff_metric(p, y))

        print(f"Dice scores for {lbl} for this batch:\n {dice_score.item()}")
        print(f"Hausdorff distances for {lbl} for this batch:\n{hausdorff_dist.item()}")

        return dice_score, hausdorff_dist

    def forward(self, p, y):
        """Evaluate the three regions.

        Args:
            p: logits; channels 1, 2 and 3 along dim 1 are read as WT, TC, ET.
            y: label map compared against the values 1 and 3 and against 0.

        Returns:
            ``[dice, hd]`` pairs for WT, TC and ET, in that order.
        """
        p = (torch.sigmoid(p) > 0.5)
        y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
        p_wt, p_tc, p_et = p[:, 1].unsqueeze(1), p[:, 2].unsqueeze(1), p[:, 3].unsqueeze(1)

        dice_wt, hd_wt = self.compute(p_wt, y_wt, 'wt')
        dice_tc, hd_tc = self.compute(p_tc, y_tc, 'tc')
        dice_et, hd_et = self.compute(p_et, y_et, 'et')

        return [dice_wt, hd_wt], [dice_tc, hd_tc], [dice_et, hd_et]

In [None]:
def test_net(model, loader):
    """Evaluate ``model`` on ``loader`` and report mean Dice / Hausdorff per region.

    Typical call: ``test_net(student_model, testLoader)``.

    Every batch yielded by ``loader`` must be a mapping with the keys
    ``'imgs'``, ``'masks'`` and ``'patient_id'``. The images are run through
    the model under CUDA autocast with gradients disabled, and
    ``output['pred']`` is scored against the masks with ``ComputeMetrics``
    for the regions WT, TC and ET. The per-batch scores are summed and then
    divided by the number of batches, so the reported values are per-batch
    means.

    NOTE(review): each per-batch score is read with ``.item()``, so it has to
    be a single-element tensor -- this presumably relies on a test batch size
    of 1; confirm before raising the batch size.

    Args:
        model: network to evaluate; its output must be indexable with
            ``'pred'``.
        loader: sized iterable of batches (``len(loader)`` is used as the
            averaging denominator).

    Returns:
        dict: ``{"WT" | "TC" | "ET": {'dice_score': float,
        'hausdorff_distance': float}}`` holding the averaged metrics.

    Raises:
        ValueError: if ``loader`` is empty (previously this surfaced as a
            ``ZeroDivisionError`` after the model had already been switched
            to eval mode).
    """
    n_test_batches = len(loader)
    if n_test_batches == 0:
        raise ValueError("test_net received an empty loader; there is nothing to evaluate.")

    # Resets the global RNG so repeated evaluations are reproducible.
    torch.manual_seed(0)

    # Remember the current mode so it can be restored even if evaluation fails.
    was_training = model.training
    model.eval()

    compute_metrics = ComputeMetrics()
    regions = ("WT", "TC", "ET")
    total_metrics = {region: {'dice_score': 0, 'hausdorff_distance': 0} for region in regions}

    try:
        with tqdm(total=n_test_batches, desc='Testing', unit='batch', leave=False) as pbar, torch.no_grad():
            for batch in loader:
                imgs = batch['imgs'].to('cuda')
                masks = batch['masks'].to('cuda')

                print("--------Now patient:", batch['patient_id'])
                with torch.amp.autocast('cuda'):
                    output = model(imgs)
                    # Returns one [dice, hausdorff] pair per region, ordered WT, TC, ET.
                    batch_metrics = compute_metrics(output['pred'], masks)

                for region, (dice, hausdorff) in zip(regions, batch_metrics):
                    total_metrics[region]['dice_score'] += dice.item()
                    total_metrics[region]['hausdorff_distance'] += hausdorff.item()

                pbar.update(1)
    finally:
        model.train(was_training)

    for region in regions:
        total_metrics[region]['dice_score'] /= n_test_batches
        total_metrics[region]['hausdorff_distance'] /= n_test_batches

    star_rule = "************************************************************************"
    dash_rule = "-----------------------------------------------------------------------------"

    print(star_rule)
    for idx, region in enumerate(regions):
        if idx:
            # Two dashed rules separate consecutive regions in the report.
            print(dash_rule)
            print(dash_rule)
        print(f"Average Dice Score for {region}: {total_metrics[region]['dice_score']:.4f}")
        print(f"Average Hausdorff Distance for {region}: {total_metrics[region]['hausdorff_distance']:.4f}")
    print(star_rule)

    return total_metrics

## Testing on GLI

In [None]:
# Evaluate the student checkpoint on the glioma (GLI) data directory only.
args = {
    'workers': 2,
    'epochs': 15,
    'train_batch_size': 2,
    'val_batch_size': 2,
    'test_batch_size': 1,
    'learning_rate': 1e-3,
    # To evaluate on all five datasets at once, use this list instead:
    # 'data_dirs': ["/kaggle/input/bratsglioma/Training/", "/kaggle/input/bratsafrica24/", "/kaggle/input/bratsped/Training/", "/kaggle/input/bratsmen/", "/kaggle/input/bratsmet24/"],
    'data_dirs': ["/kaggle/input/bratsglioma/Training/"],
}

# Only the third return value is used; the first two (presumably the
# train/validation loaders) are discarded.
_, _, testLoader = prepare_data_loaders(args)
student_path = Path('/kaggle/input/kd-5tumors-originalteachers-unbalanced-cbamkl/Student_model_after_epoch_60_trainLoss_0.8043_valLoss_0.3555.pth')
student_model = DynUNet(spatial_dims=3, in_channels=4, out_channels=4, deep_supervision=False).to('cuda')
if student_path.is_file():
    print(f"Found model: {student_path}")
    # The weights are stored under the 'student_model' key of the checkpoint.
    ckpt = torch.load(student_path, map_location='cuda', weights_only=True)
    student_model.load_state_dict(ckpt['student_model'])
    print(f"Loaded model: {student_path}")
    test_net(student_model, testLoader)
else:
    # Report a missing checkpoint instead of silently skipping the evaluation.
    print(f"Checkpoint not found, skipping evaluation: {student_path}")

## Testing on SSA

In [None]:
# Evaluate the student checkpoint on the SSA data directory (bratsafrica24) only.
args = {
    'workers': 2,
    'epochs': 15,
    'train_batch_size': 2,
    'val_batch_size': 2,
    'test_batch_size': 1,
    'learning_rate': 1e-3,
    # To evaluate on all five datasets at once, use this list instead:
    # 'data_dirs': ["/kaggle/input/bratsglioma/Training/", "/kaggle/input/bratsafrica24/", "/kaggle/input/bratsped/Training/", "/kaggle/input/bratsmen/", "/kaggle/input/bratsmet24/"],
    'data_dirs': ["/kaggle/input/bratsafrica24/"],
}

# Only the third return value is used; the first two (presumably the
# train/validation loaders) are discarded.
_, _, testLoader = prepare_data_loaders(args)
student_path = Path('/kaggle/input/kd-5tumors-originalteachers-unbalanced-cbamkl/Student_model_after_epoch_60_trainLoss_0.8043_valLoss_0.3555.pth')
student_model = DynUNet(spatial_dims=3, in_channels=4, out_channels=4, deep_supervision=False).to('cuda')
if student_path.is_file():
    print(f"Found model: {student_path}")
    # The weights are stored under the 'student_model' key of the checkpoint.
    ckpt = torch.load(student_path, map_location='cuda', weights_only=True)
    student_model.load_state_dict(ckpt['student_model'])
    print(f"Loaded model: {student_path}")
    test_net(student_model, testLoader)
else:
    # Report a missing checkpoint instead of silently skipping the evaluation.
    print(f"Checkpoint not found, skipping evaluation: {student_path}")

## Testing on PED

In [None]:
# Evaluate the student checkpoint on the PED data directory (bratsped) only.
args = {
    'workers': 2,
    'epochs': 15,
    'train_batch_size': 2,
    'val_batch_size': 2,
    'test_batch_size': 1,
    'learning_rate': 1e-3,
    # To evaluate on all five datasets at once, use this list instead:
    # 'data_dirs': ["/kaggle/input/bratsglioma/Training/", "/kaggle/input/bratsafrica24/", "/kaggle/input/bratsped/Training/", "/kaggle/input/bratsmen/", "/kaggle/input/bratsmet24/"],
    'data_dirs': ["/kaggle/input/bratsped/Training/"],
}

# Only the third return value is used; the first two (presumably the
# train/validation loaders) are discarded.
_, _, testLoader = prepare_data_loaders(args)
student_path = Path('/kaggle/input/kd-5tumors-originalteachers-unbalanced-cbamkl/Student_model_after_epoch_60_trainLoss_0.8043_valLoss_0.3555.pth')
student_model = DynUNet(spatial_dims=3, in_channels=4, out_channels=4, deep_supervision=False).to('cuda')
if student_path.is_file():
    print(f"Found model: {student_path}")
    # The weights are stored under the 'student_model' key of the checkpoint.
    ckpt = torch.load(student_path, map_location='cuda', weights_only=True)
    student_model.load_state_dict(ckpt['student_model'])
    print(f"Loaded model: {student_path}")
    test_net(student_model, testLoader)
else:
    # Report a missing checkpoint instead of silently skipping the evaluation.
    print(f"Checkpoint not found, skipping evaluation: {student_path}")

## Testing on MEN

In [None]:
# Evaluate the student checkpoint on the MEN data directory (bratsmen) only.
args = {
    'workers': 2,
    'epochs': 15,
    'train_batch_size': 2,
    'val_batch_size': 2,
    'test_batch_size': 1,
    # Added so this cell passes the same set of keys as the other evaluation
    # cells; harmless if prepare_data_loaders does not read it.
    'learning_rate': 1e-3,
    # To evaluate on all five datasets at once, use this list instead:
    # 'data_dirs': ["/kaggle/input/bratsglioma/Training/", "/kaggle/input/bratsafrica24/", "/kaggle/input/bratsped/Training/", "/kaggle/input/bratsmen/", "/kaggle/input/bratsmet24/"],
    'data_dirs': ["/kaggle/input/bratsmen/"],
}

# Only the third return value is used; the first two (presumably the
# train/validation loaders) are discarded.
_, _, testLoader = prepare_data_loaders(args)
student_path = Path('/kaggle/input/kd-5tumors-originalteachers-unbalanced-cbamkl/Student_model_after_epoch_60_trainLoss_0.8043_valLoss_0.3555.pth')
student_model = DynUNet(spatial_dims=3, in_channels=4, out_channels=4, deep_supervision=False).to('cuda')
if student_path.is_file():
    print(f"Found model: {student_path}")
    # The weights are stored under the 'student_model' key of the checkpoint.
    ckpt = torch.load(student_path, map_location='cuda', weights_only=True)
    student_model.load_state_dict(ckpt['student_model'])
    print(f"Loaded model: {student_path}")
    test_net(student_model, testLoader)
else:
    # Report a missing checkpoint instead of silently skipping the evaluation.
    print(f"Checkpoint not found, skipping evaluation: {student_path}")

## Testing on MET

In [None]:
# Evaluate the student checkpoint on the MET data directory (bratsmet24) only.
args = {
    'workers': 2,
    'epochs': 15,
    'train_batch_size': 2,
    'val_batch_size': 2,
    'test_batch_size': 1,
    'learning_rate': 1e-3,
    # To evaluate on all five datasets at once, use this list instead:
    # 'data_dirs': ["/kaggle/input/bratsglioma/Training/", "/kaggle/input/bratsafrica24/", "/kaggle/input/bratsped/Training/", "/kaggle/input/bratsmen/", "/kaggle/input/bratsmet24/"],
    'data_dirs': ["/kaggle/input/bratsmet24/"],
}

# Only the third return value is used; the first two (presumably the
# train/validation loaders) are discarded.
_, _, testLoader = prepare_data_loaders(args)
student_path = Path('/kaggle/input/kd-5tumors-originalteachers-unbalanced-cbamkl/Student_model_after_epoch_60_trainLoss_0.8043_valLoss_0.3555.pth')
student_model = DynUNet(spatial_dims=3, in_channels=4, out_channels=4, deep_supervision=False).to('cuda')
if student_path.is_file():
    print(f"Found model: {student_path}")
    # The weights are stored under the 'student_model' key of the checkpoint.
    ckpt = torch.load(student_path, map_location='cuda', weights_only=True)
    student_model.load_state_dict(ckpt['student_model'])
    print(f"Loaded model: {student_path}")
    test_net(student_model, testLoader)
else:
    # Report a missing checkpoint instead of silently skipping the evaluation.
    print(f"Checkpoint not found, skipping evaluation: {student_path}")

# Press here