## Imports

In [5]:
!python3 -c "import monai" || pip3 install -q "monai-weekly[nibabel, tqdm, einops]" 
!python3 -c "import matplotlib" || pip3 install -q matplotlib 
%matplotlib inline
!pip3 install -q einops 
!pip install -q wandb


2025-06-14 16:08:53.024311: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-14 16:08:53.045812: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-14 16:08:53.052460: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [6]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism
import torch


In [7]:
from kaggle_secrets import UserSecretsClient
# Read the Weights & Biases API key from the secret stored under the name "wandb",
# so the key itself never appears in the notebook source.
user_secrets = UserSecretsClient()
api_key = user_secrets.get_secret("wandb")

import wandb
# Authenticate this kernel session with the retrieved key.
wandb.login(key=api_key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myokai-re77[0m ([33myokai-re77-north-south-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Creating JSON for Dataset

In [8]:
import os
import glob
import json

# Every case is expected to live in its own sub-folder <case>/ holding
# <case>-t1c.nii, <case>-t1n.nii, <case>-t2f.nii, <case>-t2w.nii and <case>-seg.nii.
# Only directories are kept: a stray file in the dataset root (README, CSV, ...)
# would otherwise be turned into five non-existent paths that only fail later,
# inside the data loader.
file_paths1 = sorted(
    path for path in glob.glob('/kaggle/input/brats2023-part-1/*') if os.path.isdir(path)
)

# Derive each case name from its own path so names and paths can never drift out of step.
file_names1 = [os.path.basename(path) for path in file_paths1]

# Use the total number of cases found instead of a fixed count
num_files = len(file_paths1)
if num_files == 0:
    raise FileNotFoundError(
        "No case folders found under '/kaggle/input/brats2023-part-1/'; "
        "check that the dataset is attached to this notebook."
    )

# One list of file paths per MRI modality, plus the segmentation labels.
t1c, t1n, t2f, t2w, label = (
    [os.path.join(path, name + suffix) for path, name in zip(file_paths1, file_names1)]
    for suffix in ('-t1c.nii', '-t1n.nii', '-t2f.nii', '-t2w.nii', '-seg.nii')
)

# One record per case: the four modalities combined under "image", the mask under "label".
file_list = [
    {
        "image": [t1c[i], t1n[i], t2f[i], t2w[i]],
        "label": label[i],
    }
    for i in range(num_files)
]

file_json = {
    "training": file_list
}

# Save to JSON file
file_path = '/kaggle/working/dataset.json'
with open(file_path, 'w') as json_file:
    json.dump(file_json, json_file, indent=4)

# Print the first 100 entries
# print(json.dumps({"training": file_list[:100]}, indent=4))


## MONAI Directory

In [9]:
# MONAI_DATA_DIRECTORY is read from the environment, but `directory` is not used
# by the other two statements in this cell.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
# Root folder for this run: the Kaggle working directory (hard-coded).
root_dir = '/kaggle/working/'
print(root_dir)

/kaggle/working/


In [10]:
# Use a fixed seed (0) for MONAI's determinism helper so repeated runs are comparable.
set_determinism(seed=0)

## Convert Labels to BRATS 2023

In [11]:
class ConvertLabels(MapTransform):
    """
    Turn a BraTS 2023 integer label map into three stacked binary channels.

    Input label values:
        0 - everything else (non-tumor)
        1 - Necrotic Tumor Core (NCR)
        2 - Edema (ED)
        3 - Enhancing Tumor (ET)

    Output channels (stacked on a new leading axis, converted to float):
        0 - Tumor Core (TC)  = NCR or ET
        1 - Whole Tumor (WT) = NCR or ED or ET
        2 - Enhancing Tumor (ET)
    """

    def __call__(self, data):
        # Work on a shallow copy so the caller's mapping is not modified.
        d = dict(data)
        for key in self.keys:
            seg = d[key]
            is_ncr = seg == 1
            is_edema = seg == 2
            is_enhancing = seg == 3

            tumor_core = torch.logical_or(is_ncr, is_enhancing)
            whole_tumor = torch.logical_or(torch.logical_or(is_ncr, is_edema), is_enhancing)

            d[key] = torch.stack([tumor_core, whole_tumor, is_enhancing], axis=0).float()
        return d

## Preparing Data and Augmentation

In [12]:
# Optimized Transforms for Brain Tumor Segmentation - Dice Score Focused

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped,
    Orientationd, Spacingd, RandSpatialCropd, RandFlipd, NormalizeIntensityd,
    RandScaleIntensityd, RandShiftIntensityd, RandRotate90d, RandGaussianNoised,
    RandAdjustContrastd, RandAffined
)

# Training pipeline: deterministic loading/preprocessing first, then random
# cropping, geometric augmentation, intensity normalization and intensity augmentation.
train_transform = Compose(
    [
        # Essential loading and preprocessing
        LoadImaged(keys=["image", "label"]),
        # Only "image" goes through EnsureChannelFirstd; the label gets its channel
        # axis from ConvertLabels below, which stacks the TC/WT/ET masks on axis 0.
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertLabels(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Image is resampled with "bilinear" and label with "nearest" so the masks stay binary.
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        
        # Spatial cropping: fixed-size 96x96x96 patch (random_size=False)
        RandSpatialCropd(keys=["image", "label"], roi_size=[96, 96, 96], random_size=False),
        
        # Geometric augmentations: independent flip along each spatial axis (prob 0.5 each),
        # then a 90-degree rotation in the (0, 1) plane with prob 0.3
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        RandRotate90d(keys=["image", "label"], prob=0.3, spatial_axes=(0, 1)),
        
        # Conservative affine transformations (applied with prob 0.4)
        RandAffined(
            keys=["image", "label"],
            mode=("bilinear", "nearest"),
            prob=0.4,
            rotate_range=(0.1, 0.1, 0.1),  # Small rotation range per axis
            scale_range=(0.05, 0.05, 0.05),  # Small scaling range per axis
            translate_range=(5, 5, 5),  # Small translations
        ),
        
        # Essential intensity normalization: per channel, using non-zero voxels only
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        
        # Intensity augmentations: small magnitude (0.1) but applied often (prob 0.8)
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),  # NOTE(review): earlier comment said "reduced prob", yet prob is 0.8 - confirm intent
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),  # NOTE(review): same as above - confirm intended probability
        
        # Light Gaussian noise for robustness
        RandGaussianNoised(keys="image", prob=0.2, std=0.01),
        
        # Contrast adjustment - intended to help with tumor boundary detection
        RandAdjustContrastd(keys="image", prob=0.3, gamma=(0.8, 1.3)),  # Conservative range
    ]
)

# Validation Transform (no augmentation): the same deterministic preprocessing steps
# as the training pipeline, without cropping or any random transform.
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        # Label channel axis comes from ConvertLabels, so only "image" is listed here.
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertLabels(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        # Per-channel normalization over non-zero voxels, matching the training pipeline.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

## Dataset and Dataloader

In [13]:
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, RandSpatialCropd, RandFlipd, NormalizeIntensityd,
    RandScaleIntensityd, RandShiftIntensityd
)
from monai.data import Dataset, DataLoader, CacheDataset, PersistentDataset
from monai.utils import set_determinism
from sklearn.model_selection import train_test_split
import json

# Load the case list written to dataset.json (the "training" entry holds every case)
dataset_path = "/kaggle/working/dataset.json"
with open(dataset_path) as f:
    datalist = json.load(f)["training"]

# Split dataset into training (80%) and validation (20%); random_state fixes the split
train_files, val_files = train_test_split(datalist, test_size=0.2, random_state=42)

### For quick iterations, train on a subset (the slices below keep the first 30%)

# train_files = train_files[:int(len(train_files) * 0.3)]  # first 30% of training data
# val_files = val_files[:int(len(val_files) * 0.3)]  # first 30% of validation data

# Set deterministic behavior
set_determinism(seed=0)

# NOTE(review): cache_dir is assigned but not used anywhere in this cell
# (plain Dataset is used below) - confirm whether a caching dataset was intended.
cache_dir = "/kaggle/working/cache"


# These self-assignments have no effect; the transforms come from the cell that defines them.
train_transform = train_transform
val_transform = val_transform

# Create MONAI datasets (transforms are applied each time an item is fetched)

train_ds = Dataset(data=train_files, transform=train_transform)
val_ds = Dataset(data=val_files, transform=val_transform)

# train_ds = Dataset(data=train_files[:int(0.3 * len(train_files))], transform=train_transform)
# val_ds = Dataset(data=val_files[:int(0.3 * len(val_files))], transform=val_transform)

# Dataloaders: training batches hold 2 cropped patches and are shuffled;
# validation uses batch_size=1 because val_transform does not crop the volumes.
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=3, pin_memory=True, persistent_workers=False)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=3, pin_memory=True, persistent_workers=False)


## Train Using Torch Lightning Pipeline

In [14]:
import multiprocessing

# Select the "file_system" tensor-sharing strategy for torch multiprocessing.
# NOTE(review): presumably to avoid file-descriptor exhaustion with DataLoader workers - confirm.
torch.multiprocessing.set_sharing_strategy("file_system")
# Force the 'fork' start method; force=True overrides a start method that was already set.
# NOTE(review): this cell runs after the DataLoaders were constructed in the previous cell.
multiprocessing.set_start_method('fork', force=True)


In [15]:
# !rm -rf /kaggle/working

In [16]:
# # Restart the kernel (THIS WILL INTERRUPT THE KERNEL)
# import os
# os._exit(00)

## MedNext Block

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Sequence

# MedNeXt Components (Fixed Version)
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two tensor layouts.

    * ``channels_last``: channels are the final axis; the input is handed to
      ``F.layer_norm`` unchanged.
    * ``channels_first``: input is ``(batch, channels, *spatial)``; the statistics
      are computed over axis 1 for every position, with any number of trailing
      spatial axes (2-D images and 3-D volumes alike).

    Args:
        normalized_shape: number of channels to normalize over.
        eps: constant added to the variance for numerical stability.
        data_format: ``"channels_last"`` (default) or ``"channels_first"``.

    Raises:
        NotImplementedError: if ``data_format`` is not one of the two values above.
    """
    def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))        # gamma (scale)
        self.bias = nn.Parameter(torch.zeros(normalized_shape))         # beta (shift)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(
                f"data_format must be 'channels_last' or 'channels_first', got {data_format!r}"
            )
        self.normalized_shape = (normalized_shape, )

    def forward(self, x, dummy_tensor=False):
        """Normalize ``x``; ``dummy_tensor`` is accepted but not used."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            # Reshape the per-channel affine parameters to (1, C, 1, ..., 1) so they line up
            # with axis 1 whatever the number of spatial axes. The previous fixed
            # [:, None, None, None] indexing was only correct for 5-D input.
            affine_shape = (1, -1) + (1,) * (x.dim() - 2)
            x = self.weight.view(affine_shape) * x + self.bias.view(affine_shape)
            return x


# Private alias for the class above. A later cell runs `from torch.nn import LayerNorm`,
# which rebinds the global name `LayerNorm` to PyTorch's implementation (no `data_format`
# argument). MedNeXtBlock looks the class up when it is instantiated, so it must use a
# name that the later import cannot overwrite.
_MedNeXtLayerNorm = LayerNorm


class MedNeXtBlock(nn.Module):
    """MedNeXt block: depthwise conv -> norm -> 1x1 expansion -> GELU -> (GRN) -> 1x1 compression.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        exp_r: expansion ratio; the hidden width is ``exp_r * in_channels``.
        kernel_size: kernel size of the first convolution (padding is ``kernel_size // 2``).
        do_res: request a residual connection. With equal in/out channels the input is
            added directly; otherwise it is first passed through a 1x1 projection.
        norm_type: ``'group'`` (GroupNorm with one group per channel) or ``'layer'``
            (channels-first LayerNorm).
        n_groups: groups of the first convolution; ``None`` means depthwise
            (``groups=in_channels``).
        dim: ``'2d'`` or ``'3d'``.
        grn: enable Global Response Normalization on the expanded features.

    Raises:
        AssertionError: if ``dim`` is not ``'2d'`` or ``'3d'``.
        ValueError: if ``norm_type`` is not ``'group'`` or ``'layer'``.
    """
    def __init__(self, 
                in_channels: int, 
                out_channels: int, 
                exp_r: int = 4, 
                kernel_size: int = 7, 
                do_res: bool = True,
                norm_type: str = 'group',
                n_groups: int = None,
                dim: str = '3d',
                grn: bool = False
                ):

        super().__init__()

        self.do_res = do_res and (in_channels == out_channels)  # Direct residual only if channels match
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert dim in ['2d', '3d'], f"dim must be '2d' or '3d', got {dim!r}"
        self.dim = dim
        if self.dim == '2d':
            conv = nn.Conv2d
        elif self.dim == '3d':
            conv = nn.Conv3d

        # First convolution layer with DepthWise Convolutions
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size//2,
            groups=in_channels if n_groups is None else n_groups,
        )

        # Normalization Layer. GroupNorm is used by default.
        if norm_type == 'group':
            self.norm = nn.GroupNorm(
                num_groups=in_channels, 
                num_channels=in_channels
                )
        elif norm_type == 'layer':
            self.norm = _MedNeXtLayerNorm(
                normalized_shape=in_channels, 
                data_format='channels_first'
                )
        else:
            # Fail at construction time; previously self.norm was simply left undefined
            # and the error only surfaced on the first forward pass.
            raise ValueError(f"norm_type must be 'group' or 'layer', got {norm_type!r}")

        # Second convolution (Expansion) layer
        self.conv2 = conv(
            in_channels=in_channels,
            out_channels=exp_r*in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

        # GeLU activations
        self.act = nn.GELU()

        # Third convolution (Compression) layer
        self.conv3 = conv(
            in_channels=exp_r*in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

        # Projection for the residual path when a residual is requested but channels differ
        if in_channels != out_channels and do_res:
            self.projection = conv(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        else:
            self.projection = None

        self.grn = grn
        if grn:
            # One learnable offset/scale per expanded channel, broadcast over the spatial axes.
            grn_shape = (1, exp_r*in_channels) + (1,) * (3 if dim == '3d' else 2)
            self.grn_beta = nn.Parameter(torch.zeros(grn_shape), requires_grad=True)
            self.grn_gamma = nn.Parameter(torch.zeros(grn_shape), requires_grad=True)

    def forward(self, x, dummy_tensor=None):
        """Apply the block to ``x``; ``dummy_tensor`` is accepted but not used."""
        residual = x

        x1 = self.conv1(x)
        x1 = self.act(self.conv2(self.norm(x1)))

        if self.grn:
            # L2 norm of every channel over its spatial axes, divided by the mean of
            # those norms across channels.
            spatial_axes = (-3, -2, -1) if self.dim == '3d' else (-2, -1)
            gx = torch.norm(x1, p=2, dim=spatial_axes, keepdim=True)
            nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
            x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1

        x1 = self.conv3(x1)

        # Residual path: identity when channels match, projected input otherwise
        if self.do_res or self.projection is not None:
            if self.projection is not None:
                residual = self.projection(residual)
            x1 = residual + x1

        return x1

class MedNeXtUpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, one MedNeXt block.

    Args:
        in_channels: channels of the low-resolution input.
        out_channels: channels after upsampling and of the block output. The skip
            tensor is concatenated to the upsampled tensor, and the MedNeXt block is
            built for ``2 * out_channels`` input channels.
        exp_r, kernel_size, norm_type, grn: forwarded to ``MedNeXtBlock``.
        spatial_dims: 3 selects the 3-D layers; any other value selects the 2-D ones.
    """
    def __init__(self, 
                 in_channels: int, 
                 out_channels: int, 
                 exp_r: int = 4,
                 kernel_size: int = 7,
                 norm_type: str = 'group',
                 spatial_dims: int = 3,
                 grn: bool = False):
        super().__init__()

        self.spatial_dims = spatial_dims
        is_volumetric = spatial_dims == 3

        # Kernel 2 / stride 2 doubles every spatial extent.
        transposed_conv = nn.ConvTranspose3d if is_volumetric else nn.ConvTranspose2d
        self.upsample = transposed_conv(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2,
            stride=2,
        )

        # Fuses the concatenated (upsampled + skip) features back to out_channels.
        # No residual here: the block's input and output widths differ.
        self.mednext_block = MedNeXtBlock(
            in_channels=out_channels * 2,
            out_channels=out_channels,
            exp_r=exp_r,
            kernel_size=kernel_size,
            do_res=False,
            norm_type=norm_type,
            dim='3d' if is_volumetric else '2d',
            grn=grn,
        )

    def forward(self, x, skip):
        """Upsample ``x``, concatenate ``skip`` on the channel axis, and fuse them."""
        upsampled = self.upsample(x)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.mednext_block(merged)

class MedNeXtEncoder(nn.Module):
    """Encoder stage made of two consecutive MedNeXt blocks.

    ``block1`` maps ``in_channels`` to ``out_channels`` and only requests a residual
    when the two widths are equal; ``block2`` keeps ``out_channels`` and always
    requests one.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        exp_r, kernel_size, norm_type, grn: forwarded to both ``MedNeXtBlock`` instances.
        spatial_dims: 3 selects ``dim='3d'``; any other value selects ``dim='2d'``.
    """
    def __init__(self, 
                 in_channels: int, 
                 out_channels: int, 
                 exp_r: int = 4,
                 kernel_size: int = 7,
                 norm_type: str = 'group',
                 spatial_dims: int = 3,
                 grn: bool = False):
        super().__init__()

        # Settings shared by both blocks.
        shared = dict(
            exp_r=exp_r,
            kernel_size=kernel_size,
            norm_type=norm_type,
            dim='3d' if spatial_dims == 3 else '2d',
            grn=grn,
        )

        # Channel-changing block: residual only when the widths already agree.
        self.block1 = MedNeXtBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            do_res=(in_channels == out_channels),
            **shared,
        )

        # Width-preserving block: residual is always possible.
        self.block2 = MedNeXtBlock(
            in_channels=out_channels,
            out_channels=out_channels,
            do_res=True,
            **shared,
        )

    def forward(self, x):
        """Run ``x`` through ``block1`` and then ``block2``."""
        return self.block2(self.block1(x))

## SwinUNETR Architecture

In [22]:
import itertools
from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm

from monai.networks.blocks import MLPBlock as Mlp
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
from monai.networks.layers import DropPath, trunc_normal_
from monai.utils import ensure_tuple_rep, look_up_option, optional_import

rearrange, _ = optional_import("einops", name="rearrange")

class SwinUNETR(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int = 2,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        window_size: Sequence[int] | int = 7,
        qkv_bias: bool = True,
        mlp_ratio: float = 4.0,
        feature_size: int = 24,
        norm_name: tuple | str = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample: str | nn.Module = "merging",
        use_v2: bool = False,
        # MedNeXt specific parameters
        mednext_exp_r: int = 4,
        mednext_kernel_size: int = 7,
        mednext_norm_type: str = 'group',
        mednext_grn: bool = False
    ) -> None:
        """
        Swin transformer backbone combined with MedNeXt encoder and decoder stages.

        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            patch_size: size of the patch token.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            qkv_bias: add a learnable bias to query, key, value.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            feature_size: dimension of network feature size; must be divisible by 12.
            norm_name: feature normalization type and arguments (accepted, but not
                referenced in this constructor).
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            norm_layer: normalization layer.
            patch_norm: whether to apply normalization to the patch embedding. Default is False.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims (2 or 3).
            downsample: module used for downsampling; a string is resolved through
                ``MERGING_MODE``, any other value is passed to the transformer as is.
            use_v2: using swinunetr_v2, which adds a residual convolution block at the
                beginning of each swin stage.
            mednext_exp_r: expansion ratio passed to every MedNeXt encoder/decoder stage.
            mednext_kernel_size: kernel size passed to every MedNeXt stage.
            mednext_norm_type: normalization type passed to every MedNeXt stage.
            mednext_grn: whether every MedNeXt stage enables GRN.

        Raises:
            ValueError: for an unsupported ``spatial_dims``, a rate outside [0, 1],
                or a ``feature_size`` that is not divisible by 12.
        """

        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        self.patch_size = patch_size

        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
        window_size = ensure_tuple_rep(window_size, spatial_dims)

        # The three rates share one range check; each keeps its own error message.
        for rate, complaint in (
            (drop_rate, "dropout rate should be between 0 and 1."),
            (attn_drop_rate, "attention dropout rate should be between 0 and 1."),
            (dropout_path_rate, "drop path rate should be between 0 and 1."),
        ):
            if not (0 <= rate <= 1):
                raise ValueError(complaint)

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize

        # String options are looked up in MERGING_MODE; modules are used directly.
        merging = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample

        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_sizes,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=norm_layer,
            patch_norm=patch_norm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
            downsample=merging,
            use_v2=use_v2,
        )

        # Hyper-parameters shared by every MedNeXt encoder and decoder stage.
        mednext_kwargs = dict(
            exp_r=mednext_exp_r,
            kernel_size=mednext_kernel_size,
            norm_type=mednext_norm_type,
            spatial_dims=spatial_dims,
            grn=mednext_grn,
        )

        # MedNeXt-based encoders. encoder1 consumes the raw input; the others consume
        # transformer hidden states, so their widths follow multiples of feature_size.
        self.encoder1 = MedNeXtEncoder(in_channels=in_channels, out_channels=feature_size, **mednext_kwargs)
        self.encoder2 = MedNeXtEncoder(in_channels=feature_size, out_channels=feature_size, **mednext_kwargs)
        self.encoder3 = MedNeXtEncoder(
            in_channels=2 * feature_size, out_channels=2 * feature_size, **mednext_kwargs
        )
        self.encoder4 = MedNeXtEncoder(
            in_channels=4 * feature_size, out_channels=4 * feature_size, **mednext_kwargs
        )
        # Bottleneck stage.
        self.encoder10 = MedNeXtEncoder(
            in_channels=16 * feature_size, out_channels=16 * feature_size, **mednext_kwargs
        )

        # MedNeXt-based decoders, from the bottleneck width back down to feature_size.
        self.decoder5 = MedNeXtUpBlock(
            in_channels=16 * feature_size, out_channels=8 * feature_size, **mednext_kwargs
        )
        self.decoder4 = MedNeXtUpBlock(
            in_channels=8 * feature_size, out_channels=4 * feature_size, **mednext_kwargs
        )
        self.decoder3 = MedNeXtUpBlock(
            in_channels=4 * feature_size, out_channels=2 * feature_size, **mednext_kwargs
        )
        self.decoder2 = MedNeXtUpBlock(in_channels=2 * feature_size, out_channels=feature_size, **mednext_kwargs)
        self.decoder1 = MedNeXtUpBlock(in_channels=feature_size, out_channels=feature_size, **mednext_kwargs)

        # Output layer
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)


    def load_from(self, weights):
        layers1_0: BasicLayer = self.swinViT.layers1[0]  # type: ignore[assignment]
        layers2_0: BasicLayer = self.swinViT.layers2[0]  # type: ignore[assignment]
        layers3_0: BasicLayer = self.swinViT.layers3[0]  # type: ignore[assignment]
        layers4_0: BasicLayer = self.swinViT.layers4[0]  # type: ignore[assignment]
        wstate = weights["state_dict"]

        with torch.no_grad():
            self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
            self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
            for bname, block in layers1_0.blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers1")  # type: ignore[operator]

            if layers1_0.downsample is not None:
                d = layers1_0.downsample
                d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"])  # type: ignore
                d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"])  # type: ignore
                d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"])  # type: ignore

            for bname, block in layers2_0.blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers2")  # type: ignore[operator]

            if layers2_0.downsample is not None:
                d = layers2_0.downsample
                d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"])  # type: ignore
                d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"])  # type: ignore
                d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"])  # type: ignore

            for bname, block in layers3_0.blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers3")  # type: ignore[operator]

            if layers3_0.downsample is not None:
                d = layers3_0.downsample
                d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"])  # type: ignore
                d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"])  # type: ignore
                d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"])  # type: ignore

            for bname, block in layers4_0.blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers4")  # type: ignore[operator]

            if layers4_0.downsample is not None:
                d = layers4_0.downsample
                d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"])  # type: ignore
                d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"])  # type: ignore
                d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"])  # type: ignore

    @torch.jit.unused
    def _check_input_size(self, spatial_shape):
        img_size = np.array(spatial_shape)
        remainder = (img_size % np.power(self.patch_size, 5)) > 0
        if remainder.any():
            wrong_dims = (np.where(remainder)[0] + 2).tolist()
            raise ValueError(
                f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
                f" must be divisible by {self.patch_size}**5."
            )

    def forward(self, x_in):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            self._check_input_size(x_in.shape[2:])
        hidden_states_out = self.swinViT(x_in, self.normalize)
        enc0 = self.encoder1(x_in)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        dec4 = self.encoder10(hidden_states_out[4])
        dec3 = self.decoder5(dec4, hidden_states_out[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0, enc0)
        logits = self.out(out)
        return logits


def window_partition(x, window_size):
    """Cut a channels-last feature map into non-overlapping local windows.

    Based on "Liu et al., Swin Transformer: Hierarchical Vision Transformer using
    Shifted Windows <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x: input tensor, ``(b, d, h, w, c)`` for volumes or ``(b, h, w, c)`` for images.
        window_size: local window size, one entry per spatial axis.

    Returns:
        Tensor of shape ``(num_windows * b, voxels_per_window, c)``.
    """
    x_shape = x.size()  # length 4 or 5 only

    if len(x_shape) != 5:
        # 2-D case: tile the (h, w) plane.
        b, h, w, c = x.shape
        win_h, win_w = window_size[0], window_size[1]
        tiled = x.view(b, h // win_h, win_h, w // win_w, win_w, c)
        return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_h * win_w, c)

    # 3-D case: tile the (d, h, w) volume.
    b, d, h, w, c = x_shape
    win_d, win_h, win_w = window_size[0], window_size[1], window_size[2]
    tiled = x.view(b, d // win_d, win_d, h // win_h, win_h, w // win_w, win_w, c)
    # Bring the three window-grid axes together, followed by the three in-window axes.
    grouped = tiled.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return grouped.view(-1, win_d * win_h * win_w, c)


def window_reverse(windows, window_size, dims):
    """Reassemble local windows into a channel-last feature map (inverse of window partitioning).

    Based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        windows: windows tensor; it is viewed as
            ``(b, *window_grid, *window_size, -1)`` so the trailing axis is inferred.
        window_size: local window size, one entry per spatial axis.
        dims: target sizes, ``(b, d, h, w)`` for volumes or ``(b, h, w)`` for images.

    Returns:
        Tensor of shape ``(b, d, h, w, -1)`` or ``(b, h, w, -1)``.

    Raises:
        ValueError: if ``dims`` has neither 3 nor 4 entries.
    """
    if len(dims) == 4:
        b, d, h, w = dims
        x = windows.view(
            b,
            d // window_size[0],
            h // window_size[1],
            w // window_size[2],
            window_size[0],
            window_size[1],
            window_size[2],
            -1,
        )
        # Interleave each grid axis with its in-window axis before flattening back.
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)

    elif len(dims) == 3:
        b, h, w = dims
        x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)

    else:
        # Previously this fell through to `return x` with `x` unbound.
        raise ValueError(f"dims must have 3 (b, h, w) or 4 (b, d, h, w) entries, got {len(dims)}.")
    return x


def get_window_size(x_size, window_size, shift_size=None):
    """Clamp the window (and optionally the shift) to the actual input size.

    Based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x_size: input size, one entry per spatial axis.
        window_size: local window size.
        shift_size: window shifting size; omit to get only the window size back.

    Returns:
        ``tuple(window)`` when ``shift_size`` is None, otherwise
        ``(tuple(window), tuple(shift))``.
    """
    has_shift = shift_size is not None
    adapted_window = list(window_size)
    adapted_shift = list(shift_size) if has_shift else None

    for axis, extent in enumerate(x_size):
        # An axis no larger than the window becomes a single window with no shift.
        if extent <= window_size[axis]:
            adapted_window[axis] = extent
            if has_shift:
                adapted_shift[axis] = 0

    if has_shift:
        return tuple(adapted_window), tuple(adapted_shift)
    return tuple(adapted_window)


class WindowAttention(nn.Module):
    """
    Multi-head self-attention over local windows with a learned relative position bias, based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size (2 or 3 entries).
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.
        """

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        # `indexing` is only passed when meshgrid advertises keyword-only defaults.
        meshgrid_kwargs = {"indexing": "ij"} if torch.meshgrid.__kwdefaults__ is not None else {}

        n_axes = len(self.window_size)
        if n_axes in (2, 3):
            # Each axis has 2 * extent - 1 distinct relative offsets.
            spans = [2 * extent - 1 for extent in self.window_size]
            table_rows = 1
            for span in spans:
                table_rows *= span
            self.relative_position_bias_table = nn.Parameter(torch.zeros(table_rows, num_heads))

            axis_coords = [torch.arange(extent) for extent in self.window_size]
            grid = torch.stack(torch.meshgrid(*axis_coords, **meshgrid_kwargs))
            flat = torch.flatten(grid, 1)
            relative_coords = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()

            # Shift offsets to start at zero, then weight each axis by the span of the
            # axes after it so the per-axis sum is a unique row index.
            stride = 1
            for axis in reversed(range(n_axes)):
                relative_coords[:, :, axis] += self.window_size[axis] - 1
                relative_coords[:, :, axis] *= stride
                stride *= spans[axis]

        self.register_buffer("relative_position_index", relative_coords.sum(-1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Attend within each window.

        Args:
            x: window tokens of shape ``(b, n, c)``.
            mask: optional additive mask whose first axis is the number of windows
                ``nw`` (``b`` is viewed as ``(b // nw, nw)``), or None.

        Returns:
            Tensor of shape ``(b, n, c)``.
        """
        b, n, c = x.shape
        heads = self.num_heads
        packed = self.qkv(x).reshape(b, n, 3, heads, c // heads).permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)

        # Look up one bias per (token, token) pair and move the head axis first.
        bias_rows = self.relative_position_index.clone()[:n, :n].reshape(-1)  # type: ignore[operator]
        bias = self.relative_position_bias_table[bias_rows].reshape(n, n, -1)
        bias = bias.permute(2, 0, 1).contiguous()
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            nw = mask.shape[0]
            scores = scores.view(b // nw, nw, heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, n, n)
        weights = self.softmax(scores)

        weights = self.attn_drop(weights).to(value.dtype)
        out = (weights @ value).transpose(1, 2).reshape(b, n, c)
        return self.proj_drop(self.proj(out))


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    One block is a (possibly shifted) window-attention sub-layer followed by an MLP
    sub-layer, each wrapped in a residual connection (see ``forward``).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        shift_size: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: str = "GELU",
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: stochastic depth rate.
            act_layer: activation layer.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # Stochastic depth is only instantiated when a positive rate is requested.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")

    def forward_part1(self, x, mask_matrix):
        """Attention sub-layer: norm, pad, optional shift, window attention, un-shift, crop.

        Args:
            x: channel-last input, unpacked as ``(b, d, h, w, c)`` when 5D and as
                ``(b, h, w, c)`` otherwise.
            mask_matrix: attention mask handed to the attention module; it is only
                used when at least one effective shift is positive.

        Returns:
            Tensor with the same spatial extent as the input (padding is cropped off).
        """
        x_shape = x.size()
        x = self.norm1(x)
        if len(x_shape) == 5:
            b, d, h, w, c = x.shape
            # Window/shift are clamped to the input size per axis.
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            # Padding is only added at the end of each axis, up to the next multiple of the window.
            pad_l = pad_t = pad_d0 = 0
            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
            # F.pad lists pairs from the last axis backwards: channels, w, h, d.
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
            _, dp, hp, wp, _ = x.shape
            dims = [b, dp, hp, wp]

        else:  # elif len(x_shape) == 4
            b, h, w, c = x.shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            pad_l = pad_t = 0
            pad_b = (window_size[0] - h % window_size[0]) % window_size[0]
            pad_r = (window_size[1] - w % window_size[1]) % window_size[1]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
            _, hp, wp, _ = x.shape
            dims = [b, hp, wp]

        # Cyclic shift (negative roll) when any effective shift is positive.
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        x_windows = window_partition(shifted_x, window_size)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        # Restore the per-window spatial axes before merging windows back together.
        attn_windows = attn_windows.view(-1, *(window_size + (c,)))
        shifted_x = window_reverse(attn_windows, window_size, dims)
        # Undo the cyclic shift with the opposite roll.
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x

        # Crop the padding back off so the output matches the input's spatial size.
        if len(x_shape) == 5:
            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
                x = x[:, :d, :h, :w, :].contiguous()
        elif len(x_shape) == 4:
            if pad_r > 0 or pad_b > 0:
                x = x[:, :h, :w, :].contiguous()

        return x

    def forward_part2(self, x):
        """MLP sub-layer: second norm, MLP, then stochastic depth (no residual added here)."""
        return self.drop_path(self.mlp(self.norm2(x)))

    def load_from(self, weights, n_block, layer):
        """Copy this block's parameters in place from a checkpoint dictionary.

        Args:
            weights: mapping whose ``"state_dict"`` entry holds the source tensors.
            n_block: block index used in the source key prefix.
            layer: layer name used in the source key prefix.

        Source keys are ``module.{layer}.0.blocks.{n_block}.<name>``. The checkpoint's
        ``mlp.fc1`` / ``mlp.fc2`` entries are written to ``self.mlp.linear1`` /
        ``self.mlp.linear2``.
        """
        root = f"module.{layer}.0.blocks.{n_block}."
        block_names = [
            "norm1.weight",
            "norm1.bias",
            "attn.relative_position_bias_table",
            "attn.relative_position_index",
            "attn.qkv.weight",
            "attn.qkv.bias",
            "attn.proj.weight",
            "attn.proj.bias",
            "norm2.weight",
            "norm2.bias",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ]
        # In-place copies are done without autograd tracking.
        with torch.no_grad():
            self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
            self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
            self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])  # type: ignore[operator]
            self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
            self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
            self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
            self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
            self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
            self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
            self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
            self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
            self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
            self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])

    def forward(self, x, mask_matrix):
        """Apply the attention and MLP sub-layers, each with a residual connection.

        When ``use_checkpoint`` is set, both sub-layers run through
        ``checkpoint.checkpoint`` with ``use_reentrant=False``.
        """
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
        else:
            x = self.forward_part1(x, mask_matrix)
        # Stochastic depth for the attention branch is applied here; the MLP branch
        # applies it inside forward_part2.
        x = shortcut + self.drop_path(x)
        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
        else:
            x = x + self.forward_part2(x)
        return x


class PatchMergingV2(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Every 2x2 (or 2x2x2) neighbourhood is concatenated along the channel axis,
    normalised, and linearly reduced to ``2 * dim`` channels.
    """

    def __init__(self, dim: int, norm_layer: type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None:
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims (2 or 3).

        Raises:
            ValueError: if ``spatial_dims`` is neither 2 nor 3.
        """

        super().__init__()
        self.dim = dim
        if spatial_dims == 3:
            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(8 * dim)
        elif spatial_dims == 2:
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)
        else:
            # Fail at construction instead of leaving `reduction`/`norm` undefined.
            raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}.")

    def forward(self, x):
        """Merge neighbouring patches of a channel-last tensor.

        Args:
            x: ``(b, d, h, w, c)`` or ``(b, h, w, c)``; odd spatial sizes are padded
                by one element at the end of the axis.

        Returns:
            Tensor with each spatial axis halved (after padding) and the channel
            axis produced by ``self.reduction``.

        Raises:
            ValueError: if ``x`` is neither 4D nor 5D.
        """
        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
            x = torch.cat(
                [x[:, i::2, j::2, k::2, :] for i, j, k in itertools.product(range(2), range(2), range(2))], -1
            )

        elif len(x_shape) == 4:
            b, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
            # Note the index order: `j` strides the first spatial axis, `i` the second.
            x = torch.cat([x[:, j::2, i::2, :] for i, j in itertools.product(range(2), range(2))], -1)

        else:
            # Previously an unsupported rank fell through to norm/reduction unmerged.
            raise ValueError(f"expecting 4D or 5D x, got {x.shape}.")

        x = self.norm(x)
        x = self.reduction(x)
        return x


class PatchMerging(PatchMergingV2):
    """The `PatchMerging` module previously defined in v0.9.0."""

    def forward(self, x):
        """Merge 2x2x2 neighbourhoods of a 5D channel-last tensor.

        4D inputs are delegated to the parent implementation; any other rank
        raises ``ValueError``. The eight corners are concatenated in this class's
        own fixed order, listed in ``corner_order`` below.
        """
        rank = len(x.size())
        if rank == 4:
            return super().forward(x)
        if rank != 5:
            raise ValueError(f"expecting 5D x, got {x.shape}.")

        _, d, h, w, _ = x.size()
        if d % 2 == 1 or h % 2 == 1 or w % 2 == 1:
            # Pad one element at the end of every odd spatial axis.
            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))

        # (d, h, w) start offsets of each strided slice, in concatenation order.
        corner_order = (
            (0, 0, 0),
            (1, 0, 0),
            (0, 1, 0),
            (0, 0, 1),
            (1, 1, 0),
            (1, 0, 1),
            (0, 1, 1),
            (1, 1, 1),
        )
        merged = torch.cat([x[:, i::2, j::2, k::2, :] for i, j, k in corner_order], -1)
        return self.reduction(self.norm(merged))


# String values accepted for `downsample` and the patch-merging class each selects;
# SwinTransformer.__init__ resolves string arguments against this mapping via look_up_option.
MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}


def _axis_regions(window, shift):
    """Return the three slices one axis is split into when labelling shifted-window regions."""
    return slice(-window), slice(-window, -shift), slice(-shift, None)


def compute_mask(dims, window_size, shift_size, device):
    """Computing region masks based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Each spatial axis is cut into three slices, every combination of slices gets its
    own integer label, and the labelled map is partitioned into windows. Token pairs
    inside a window that carry different labels receive ``-100.0``; pairs with the
    same label receive ``0.0``.

    Args:
        dims: spatial sizes, ``(d, h, w)`` or ``(h, w)``.
        window_size: local window size.
        shift_size: shift size.
        device: device on which the mask is created.

    Returns:
        The additive mask, shape ``(num_windows, tokens_per_window, tokens_per_window)``.

    Raises:
        ValueError: if ``dims`` has neither 2 nor 3 entries.
    """

    region_id = 0

    if len(dims) == 3:
        depth, height, width = dims
        img_mask = torch.zeros((1, depth, height, width, 1), device=device)
        for d_region in _axis_regions(window_size[0], shift_size[0]):
            for h_region in _axis_regions(window_size[1], shift_size[1]):
                for w_region in _axis_regions(window_size[2], shift_size[2]):
                    img_mask[:, d_region, h_region, w_region, :] = region_id
                    region_id += 1

    elif len(dims) == 2:
        height, width = dims
        img_mask = torch.zeros((1, height, width, 1), device=device)
        for h_region in _axis_regions(window_size[0], shift_size[0]):
            for w_region in _axis_regions(window_size[1], shift_size[1]):
                img_mask[:, h_region, w_region, :] = region_id
                region_id += 1

    else:
        # Previously `img_mask` was left unbound for any other length.
        raise ValueError(f"dims must have 2 (h, w) or 3 (d, h, w) entries, got {len(dims)}.")

    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.squeeze(-1)
    # Pairwise label difference: non-zero wherever two tokens come from different regions.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

    return attn_mask


class BasicLayer(nn.Module):
    """
    One Swin Transformer stage (a stack of blocks plus optional downsampling) based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        downsample: nn.Module | None = None,
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of blocks in this stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate; a list is indexed per block, any other
                value is passed to every block unchanged.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: optional downsampling layer applied after the blocks; a
                callable is invoked with ``dim``, ``norm_layer`` and ``spatial_dims``.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(extent // 2 for extent in window_size)
        self.no_shift = tuple(0 for _ in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        stage_blocks = []
        for idx in range(depth):
            # Even-indexed blocks use no shift; odd-indexed blocks shift by half a window.
            block_shift = self.no_shift if idx % 2 == 0 else self.shift_size
            block_drop_path = drop_path[idx] if isinstance(drop_path, list) else drop_path
            stage_blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=self.window_size,
                    shift_size=block_shift,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        if callable(downsample):
            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
        else:
            self.downsample = downsample

    def forward(self, x):
        """Run the stage on a channel-first tensor.

        5D ``(b, c, d, h, w)`` and 4D ``(b, c, h, w)`` inputs are processed; any other
        rank is returned untouched.
        """
        x_shape = x.size()
        rank = len(x_shape)
        if rank == 5:
            to_channel_last, to_channel_first = "b c d h w -> b d h w c", "b d h w c -> b c d h w"
        elif rank == 4:
            to_channel_last, to_channel_first = "b c h w -> b h w c", "b h w c -> b c h w"
        else:
            return x

        batch = x_shape[0]
        spatial = tuple(x_shape[2:])
        window_size, shift_size = get_window_size(spatial, self.window_size, self.shift_size)
        x = rearrange(x, to_channel_last)

        # The mask is built for the spatial size rounded up to whole windows.
        padded = [int(np.ceil(size / win)) * win for size, win in zip(spatial, window_size)]
        attn_mask = compute_mask(padded, window_size, shift_size, x.device)

        for blk in self.blocks:
            x = blk(x, attn_mask)
        x = x.view(batch, *spatial, -1)

        if self.downsample is not None:
            x = self.downsample(x)
        return rearrange(x, to_channel_first)


class SwinTransformer(nn.Module):
    """
    Swin Transformer encoder returning one feature map per resolution, based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
        use_v2=False,
    ) -> None:
        """
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of blocks in each stage.
            num_heads: number of attention heads per stage.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: final stochastic depth rate; per-block rates are spaced
                linearly from 0 up to this value.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.
        """

        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        embed_norm = norm_layer if self.patch_norm else None
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=embed_norm,  # type: ignore
            spatial_dims=spatial_dims,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        drop_path_schedule = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        self.use_v2 = use_v2

        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        swin_stages = [self.layers1, self.layers2, self.layers3, self.layers4]
        conv_stages = []
        if self.use_v2:
            self.layers1c = nn.ModuleList()
            self.layers2c = nn.ModuleList()
            self.layers3c = nn.ModuleList()
            self.layers4c = nn.ModuleList()
            conv_stages = [self.layers1c, self.layers2c, self.layers3c, self.layers4c]

        if isinstance(downsample, str):
            down_sample_mod = look_up_option(downsample, MERGING_MODE)
        else:
            down_sample_mod = downsample

        first_block = 0  # index into the drop-path schedule of this stage's first block
        for i_layer, stage_depth in enumerate(depths):
            stage_dim = embed_dim * 2**i_layer
            layer = BasicLayer(
                dim=int(stage_dim),
                depth=stage_depth,
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=drop_path_schedule[first_block : first_block + stage_depth],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=down_sample_mod,
                use_checkpoint=use_checkpoint,
            )
            first_block += stage_depth
            # Only the first four stages have a container; later ones are not stored.
            if i_layer < len(swin_stages):
                swin_stages[i_layer].append(layer)
            if self.use_v2:
                layerc = UnetrBasicBlock(
                    spatial_dims=spatial_dims,
                    in_channels=stage_dim,
                    out_channels=stage_dim,
                    kernel_size=3,
                    stride=1,
                    norm_name="instance",
                    res_block=True,
                )
                if i_layer < len(conv_stages):
                    conv_stages[i_layer].append(layerc)

        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

    def proj_out(self, x, normalize=False):
        """Optionally layer-normalise a channel-first 4D/5D tensor over its channel axis.

        Returns ``x`` unchanged when ``normalize`` is false or the rank is neither 4 nor 5.
        """
        if not normalize:
            return x
        x_shape = x.shape
        # Force trace() to generate a constant by casting to int
        ch = int(x_shape[1])
        rank = len(x_shape)
        if rank == 5:
            x = rearrange(x, "n c d h w -> n d h w c")
            x = F.layer_norm(x, [ch])
            x = rearrange(x, "n d h w c -> n c d h w")
        elif rank == 4:
            x = rearrange(x, "n c h w -> n h w c")
            x = F.layer_norm(x, [ch])
            x = rearrange(x, "n h w c -> n c h w")
        return x

    def forward(self, x, normalize=True):
        """Return five feature maps: the patch embedding and the output of each of the four stages."""
        embedded = self.pos_drop(self.patch_embed(x))
        embedded_out = self.proj_out(embedded, normalize)

        # With use_v2, each stage is preceded by its residual convolution block.
        if self.use_v2:
            embedded = self.layers1c[0](embedded.contiguous())
        stage1 = self.layers1[0](embedded.contiguous())
        stage1_out = self.proj_out(stage1, normalize)

        if self.use_v2:
            stage1 = self.layers2c[0](stage1.contiguous())
        stage2 = self.layers2[0](stage1.contiguous())
        stage2_out = self.proj_out(stage2, normalize)

        if self.use_v2:
            stage2 = self.layers3c[0](stage2.contiguous())
        stage3 = self.layers3[0](stage2.contiguous())
        stage3_out = self.proj_out(stage3, normalize)

        if self.use_v2:
            stage3 = self.layers4c[0](stage3.contiguous())
        stage4 = self.layers4[0](stage3.contiguous())
        stage4_out = self.proj_out(stage4, normalize)

        return [embedded_out, stage1_out, stage2_out, stage3_out, stage4_out]


def filter_swinunetr(key, value):
    """
    Key filter that maps the pretrained weights from [1] onto the MONAI SwinUNETR model.
    Intended to be passed as ``filter_func`` to `monai.networks.copy_model_state`.
    [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
    <https://arxiv.org/abs/2307.16896>"

    Args:
        key: the key in the source state dict used for the update.
        value: the value in the source state dict used for the update.

    Returns:
        ``(new_key, value)`` for encoder entries that should be loaded, or None for
        entries that must be skipped.

    Examples::

        import torch
        from monai.apps import download_url
        from monai.networks.utils import copy_model_state
        from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr

        model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
        resource = (
            "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
        )
        ssl_weights_path = "./ssl_pretrained_weights.pth"
        download_url(resource, ssl_weights_path)
        ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]

        dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)

    """
    skipped_keys = (
        "encoder.mask_token",
        "encoder.norm.weight",
        "encoder.norm.bias",
        "out.conv.conv.weight",
        "out.conv.conv.bias",
    )
    if key in skipped_keys:
        return None

    source_prefix = "encoder."
    if key[: len(source_prefix)] != source_prefix:
        return None

    tail = key[len(source_prefix) :]
    if tail[:11] == "patch_embed":
        return "swinViT." + tail, value

    # Keep the first 10 characters after the prefix, drop the next 2, keep the rest
    # (e.g. "layers1.0.0.blocks..." becomes "layers1.0.blocks...").
    return "swinViT." + tail[:10] + tail[12:], value

In [27]:
import os
import time
import torch
import torch.nn as nn
import pytorch_lightning as pl
from monai.losses import DiceLoss, DiceCELoss, FocalLoss
from monai.metrics import DiceMetric
from monai.transforms import Compose, Activations, AsDiscrete
from monai.data import PersistentDataset, list_data_collate, decollate_batch, DataLoader, load_decathlon_datalist, CacheDataset
from monai.inferers import sliding_window_inference
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from monai.data import DataLoader, Dataset
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.timer import Timer
from torch.cuda.amp import GradScaler
import wandb
from pytorch_lightning.loggers import WandbLogger
import math

class BrainTumorSegmentation(pl.LightningModule):
    """Lightning wrapper that trains and validates a SwinUNETR segmentation network.

    The network takes 4 input channels and produces 3 output channels that are
    passed through a sigmoid and thresholded at 0.5. Throughout this class the
    three channels are reported as TC (index 0), WT (index 1) and ET (index 2).

    Training minimises ``0.6 * DiceCELoss + 0.4 * FocalLoss``. Validation runs
    sliding-window inference and accumulates MONAI Dice metrics over the whole
    validation epoch.

    Args:
        train_loader: Training dataloader; only stored on the module.
        val_loader: Validation dataloader; only stored on the module.
        max_epochs: Horizon of the warm-up + cosine LR schedule. This is not
            the Trainer's epoch limit; pass the same value to both.
        val_interval: Kept in ``hparams`` for code outside the class (plots).
        learning_rate: Base learning rate for AdamW.
        feature_size: Forwarded to ``SwinUNETR``.
        weight_decay: Weight decay for AdamW.
        warmup_epochs: Number of linear warm-up epochs (0 disables warm-up).
        roi_size: Window size for sliding-window validation inference.
        sw_batch_size: Number of windows per forward pass during validation.
        use_v2, depths, num_heads, downsample: Forwarded to ``SwinUNETR``.
        use_class_weights: If true, the focal loss uses per-channel weights.
        mednext_exp_r, mednext_kernel_size, mednext_norm_type, mednext_grn:
            Forwarded to ``SwinUNETR``.
    """

    def __init__(self, train_loader, val_loader, max_epochs=100,
                 val_interval=1, learning_rate=1e-4, feature_size=48,
                 weight_decay=1e-5, warmup_epochs=10, roi_size=(96, 96, 96),
                 sw_batch_size=2, use_v2=True, depths=(2, 2, 2, 2),
                 num_heads=(3, 6, 12, 24), downsample="mergingv2",
                 use_class_weights=True,
                 ## MedNext Params ##
                 mednext_exp_r=4, mednext_kernel_size=7,
                 mednext_norm_type='group', mednext_grn=True,
                 ):

        super().__init__()
        # The dataloaders are not hyperparameters; leaving them out keeps them
        # from being copied into ``hparams``, checkpoints and the logger config.
        self.save_hyperparameters(ignore=["train_loader", "val_loader"])

        # Base SwinUNETR model
        self.model = SwinUNETR(
            in_channels=4,
            out_channels=3,
            feature_size=self.hparams.feature_size,
            use_checkpoint=True,
            use_v2=self.hparams.use_v2,
            spatial_dims=3,
            depths=self.hparams.depths,
            num_heads=self.hparams.num_heads,
            norm_name="instance",
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            downsample=self.hparams.downsample,
            # MedNext Params
            mednext_exp_r=self.hparams.mednext_exp_r,
            mednext_kernel_size=self.hparams.mednext_kernel_size,
            mednext_norm_type=self.hparams.mednext_norm_type,
            mednext_grn=self.hparams.mednext_grn,
        )

        # Per-channel weights for the focal loss, one value per output channel.
        # With the channel naming used by the metrics below (TC, WT, ET) these
        # values weight TC=1, WT=3, ET=5.
        # NOTE(review): the stated intent was "ET (most rare) > TC > WT", which
        # would put the middle weight on TC instead of WT - confirm the channel
        # order produced by the label transform before changing the values.
        if self.hparams.use_class_weights:
            class_weights = torch.tensor([1.0, 3.0, 5.0])
        else:
            class_weights = None

        # Loss functions. ``dice_loss`` is not part of ``compute_loss``; it is
        # kept as an attribute in case external code uses it.
        self.dice_loss = DiceLoss(
            smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
            to_onehot_y=False, sigmoid=True
        )
        self.ce_loss = DiceCELoss(
            smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
            to_onehot_y=False, sigmoid=True
        )
        self.focal_loss = FocalLoss(
            gamma=2.0, weight=class_weights, reduction='mean'
        )

        # Dice metrics: overall mean and per-channel mean. The same objects are
        # used by training (reset every step) and validation (reset every epoch).
        self.dice_metric = DiceMetric(include_background=True, reduction="mean")
        self.dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

        # Logits -> sigmoid -> binary mask at 0.5
        self.post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

        self.best_metric = -1
        # Initialised so on_train_end can report even if no epoch ever improved.
        self.best_metric_epoch = -1
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Step losses of the epoch in progress; averaged in on_train_epoch_end.
        self._train_step_losses = []

        # Training history
        self.avg_train_loss_values = []      # one entry per epoch: mean step loss
        self.train_loss_values = []          # one entry per epoch: last logged step loss
        self.train_metric_values = []        # one entry per training step
        self.train_metric_values_tc = []     # one entry per training step
        self.train_metric_values_wt = []     # one entry per training step
        self.train_metric_values_et = []     # one entry per training step

        # Validation history (one entry per validation epoch)
        self.avg_val_loss_values = []
        self.epoch_loss_values = []
        self.metric_values = []
        self.metric_values_tc = []
        self.metric_values_wt = []
        self.metric_values_et = []

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw outputs (logits)."""
        return self.model(x)

    def compute_loss(self, outputs, labels):
        """Return ``0.6 * DiceCE + 0.4 * Focal`` for raw network outputs and labels."""
        dice_ce_loss = self.ce_loss(outputs, labels)
        focal_loss = self.focal_loss(outputs, labels)

        total_loss = 0.6 * dice_ce_loss + 0.4 * focal_loss
        return total_loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss and Dice for that batch."""
        inputs, labels = batch["image"], batch["label"]

        outputs = self(inputs)
        loss = self.compute_loss(outputs, labels)

        self.log("train_loss", loss, prog_bar=True)
        # Detached copy so the epoch average does not keep the graph alive.
        self._train_step_losses.append(loss.detach())

        # Apply sigmoid and threshold to every sample of the batch
        outputs = [self.post_trans(i) for i in decollate_batch(outputs)]

        # Dice for this batch only: the metrics are reset at the end of the step
        self.dice_metric(y_pred=outputs, y=labels)
        self.dice_metric_batch(y_pred=outputs, y=labels)

        train_dice = self.dice_metric.aggregate().item()
        metric_batch = self.dice_metric_batch.aggregate()
        tc_dice = metric_batch[0].item()
        wt_dice = metric_batch[1].item()
        et_dice = metric_batch[2].item()

        # Store metrics (one entry per step)
        self.train_metric_values.append(train_dice)
        self.train_metric_values_tc.append(tc_dice)
        self.train_metric_values_wt.append(wt_dice)
        self.train_metric_values_et.append(et_dice)

        # Log metrics
        self.log("train_mean_dice", train_dice, prog_bar=True)
        self.log("train_tc", tc_dice, prog_bar=True)
        self.log("train_wt", wt_dice, prog_bar=True)
        self.log("train_et", et_dice, prog_bar=True)

        # Reset metrics so the next step (or validation) starts clean
        self.dice_metric.reset()
        self.dice_metric_batch.reset()

        return loss

    def on_train_epoch_end(self):
        """Record the last logged step loss and the mean step loss of the finished epoch."""
        step_losses = self._train_step_losses
        self._train_step_losses = []

        # "train_loss" may be missing from logged_metrics if no logging step has
        # happened yet; fall back to the last step loss seen in this epoch.
        logged = self.trainer.logged_metrics.get("train_loss")
        if logged is not None:
            train_loss = logged.item()
        elif step_losses:
            train_loss = step_losses[-1].item()
        else:
            return
        self.train_loss_values.append(train_loss)

        # Mean over the steps of this epoch (previously this was a running mean
        # over all epochs of a single step loss).
        if step_losses:
            avg_train_loss = torch.stack(step_losses).mean().item()
        else:
            avg_train_loss = train_loss
        self.log("avg_train_loss", avg_train_loss, prog_bar=True, sync_dist=True)
        self.avg_train_loss_values.append(avg_train_loss)

    def validation_step(self, batch, batch_idx):
        """Run sliding-window inference on one batch, log the loss and accumulate Dice."""
        val_inputs, val_labels = batch["image"], batch["label"]

        # Window size and window batch size come from the constructor arguments.
        val_outputs = sliding_window_inference(
            val_inputs,
            roi_size=self.hparams.roi_size,
            sw_batch_size=self.hparams.sw_batch_size,
            predictor=self.model,
            overlap=0.6,
        )

        val_loss = self.compute_loss(val_outputs, val_labels)
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True, on_epoch=True)

        val_outputs = [self.post_trans(i) for i in decollate_batch(val_outputs)]

        # Accumulated across the epoch; aggregated in on_validation_epoch_end
        self.dice_metric(y_pred=val_outputs, y=val_labels)
        self.dice_metric_batch(y_pred=val_outputs, y=val_labels)
        return {"val_loss": val_loss}

    def on_validation_epoch_end(self):
        """Aggregate the validation Dice, log it, and save the weights when it improves."""
        if self.trainer.sanity_checking:
            # The pre-training sanity pass must not enter the history or
            # produce a "best" checkpoint from the untrained network.
            self.dice_metric.reset()
            self.dice_metric_batch.reset()
            return

        val_dice = self.dice_metric.aggregate().item()
        self.metric_values.append(val_dice)

        # Already logged under "val_loss" by validation_step; only read here.
        val_loss = self.trainer.logged_metrics["val_loss"].item()
        self.epoch_loss_values.append(val_loss)

        metric_batch = self.dice_metric_batch.aggregate()
        tc_dice = metric_batch[0].item()
        wt_dice = metric_batch[1].item()
        et_dice = metric_batch[2].item()
        self.metric_values_tc.append(tc_dice)
        self.metric_values_wt.append(wt_dice)
        self.metric_values_et.append(et_dice)

        # Log validation metrics
        self.log("val_mean_dice", val_dice, prog_bar=True, on_epoch=True, sync_dist=True)
        self.log("val_tc", tc_dice, prog_bar=True, on_epoch=True, sync_dist=True)
        self.log("val_wt", wt_dice, prog_bar=True, on_epoch=True, sync_dist=True)
        self.log("val_et", et_dice, prog_bar=True, on_epoch=True, sync_dist=True)

        if val_dice > self.best_metric:
            self.best_metric = val_dice
            self.best_metric_epoch = self.current_epoch
            # Only one process writes the file when training on several devices.
            if self.trainer.is_global_zero:
                torch.save(self.model.state_dict(), "best_metric_model_swinunetr_v2.pth")
            self.log("best_metric", self.best_metric, sync_dist=True, on_epoch=True)

        # Reset metrics
        self.dice_metric.reset()
        self.dice_metric_batch.reset()

    def on_train_end(self):
        """Print the best validation Dice and the per-channel Dice of the last validation."""
        if not self.metric_values:
            print("Train completed, but no validation epoch was recorded.")
            return
        print(f"Train completed, best_metric: {self.best_metric:.4f} at epoch: {self.best_metric_epoch}, "
              f"tc: {self.metric_values_tc[-1]:.4f}, "
              f"wt: {self.metric_values_wt[-1]:.4f}, "
              f"et: {self.metric_values_et[-1]:.4f}.")

    def configure_optimizers(self):
        """Build AdamW with a per-epoch linear warm-up followed by cosine decay."""
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        warmup_epochs = self.hparams.warmup_epochs
        # At least 1 so the cosine branch never divides by zero.
        decay_epochs = max(1, self.hparams.max_epochs - warmup_epochs)

        # Warmup + Cosine Annealing
        def lr_lambda(epoch):
            if epoch < warmup_epochs:
                # epoch + 1 so the very first epoch trains with a non-zero rate.
                return (epoch + 1) / warmup_epochs
            # Clamped so the rate stays at its minimum past max_epochs instead
            # of rising again.
            progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
            return 0.5 * (1 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "name": "learning_rate"
            }
        }

## Pipeline

## Run 

In [None]:
import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.strategies import DDPStrategy

# One epoch budget for both the Trainer and the model's LR schedule, so the
# cosine decay finishes when training does (the model's own default is 100).
MAX_EPOCHS = 50

# Dataloaders built from the datasets created earlier in the notebook
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=False)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=False)

# Weights & Biases logger; log_model=True also uploads checkpoints
wandb_logger = WandbLogger(
    project="brain-tumor-segmentation",
    name="swinv2-mednext",
    log_model=True,
    save_dir="./wandb_logs"
)

# Callbacks
callbacks = [
    ModelCheckpoint(
        monitor='val_mean_dice',
        mode='max',
        save_top_k=3,  # Save top 3 models
        save_last=True,
        filename='brats-{epoch:02d}-{val_mean_dice:.4f}',
        verbose=True,
        auto_insert_metric_name=False
    ),
    LearningRateMonitor(logging_interval='epoch'),
    EarlyStopping(
        monitor='val_mean_dice',
        mode='max',
        patience=15,  # Validation epochs without improvement before stopping
        verbose=True,
        min_delta=0.01  # Minimum improvement threshold
    ),
    Timer(duration="00:11:00:00")  # DD:HH:MM:SS -> stop after 11 hours
]

# Create the model
model = BrainTumorSegmentation(
    train_loader, val_loader,
    max_epochs=MAX_EPOCHS,
    feature_size=48,
    learning_rate=1e-3,
    weight_decay=2e-5,
    warmup_epochs=5,
    roi_size=(96, 96, 96),
    sw_batch_size=2,
    use_v2=True,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    downsample="mergingv2",
    ## MedNext Params ##
    mednext_exp_r=4,
    mednext_kernel_size=7,
    mednext_norm_type='group',
    mednext_grn=False,  # Global Response Normalization disabled for this run
    ## Weighted Params
    use_class_weights=True,
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    precision="16-mixed",
    devices=torch.cuda.device_count(),
    strategy="ddp_notebook",
    # strategy="auto",
    accelerator="gpu",
    gradient_clip_val=0.5,  # Slightly lower for stability
    accumulate_grad_batches=4,
    callbacks=callbacks,
    logger=wandb_logger,
    enable_checkpointing=True,
    deterministic=False,
    benchmark=True,
    log_every_n_steps=5,  # More frequent logging
    check_val_every_n_epoch=1,  # Validate after every epoch
    limit_val_batches=10,  # Only the first 10 validation batches are used
    sync_batchnorm=True,  # Better for multi-GPU
    enable_model_summary=True,
    profiler="simple"  # Basic profiling for optimization insights
)

# Train the model
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Model parameters: 38,109,021




Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('best_metric', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Training: |          | 0/? [00:00<?, ?it/s]

grad.sizes() = [768, 1, 7, 7, 7], strides() = [343, 1, 49, 7, 1]
bucket_view.sizes() = [768, 1, 7, 7, 7], strides() = [343, 343, 49, 7, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:327.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
grad.sizes() = [768, 1, 7, 7, 7], strides() = [343, 1, 49, 7, 1]
bucket_view.sizes() = [768, 1, 7, 7, 7], strides() = [343, 343, 49, 7, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:327.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

## End Of Training and Results

In [None]:
import matplotlib.pyplot as plt

# Training curves.
# The loss history holds one value per epoch, while the Dice histories are
# appended once per training step, so the Dice panels are labelled by step.

# (title, values, x-axis label, line colour, legend label)
overview_panels = [
    ("Epoch Average Loss", model.avg_train_loss_values, "Epoch", "red", "Train Loss"),
    ("Train Mean Dice", model.train_metric_values, "Training step", "green", "Train Dice"),
]

plt.figure("train", (12, 6))
for column, (title, values, x_label, colour, legend_label) in enumerate(overview_panels, start=1):
    plt.subplot(1, len(overview_panels), column)
    plt.title(title)
    plt.xlabel(x_label)
    plt.plot(range(1, len(values) + 1), values, color=colour, label=legend_label)
    plt.legend()
plt.show()

# Per-channel training Dice (TC, WT, ET), one panel each
channel_panels = [
    ("Train Mean Dice TC", model.train_metric_values_tc, "blue", "Train TC Dice"),
    ("Train Mean Dice WT", model.train_metric_values_wt, "brown", "Train WT Dice"),
    ("Train Mean Dice ET", model.train_metric_values_et, "purple", "Train ET Dice"),
]

plt.figure("train", (18, 6))
for column, (title, values, colour, legend_label) in enumerate(channel_panels, start=1):
    plt.subplot(1, len(channel_panels), column)
    plt.title(title)
    plt.xlabel("Training step")
    plt.plot(range(1, len(values) + 1), values, color=colour, label=legend_label)
    plt.legend()
plt.show()


In [None]:
# Validation curves: loss next to the mean Dice, then one Dice panel per channel.
# Dice points are placed at multiples of the model's ``val_interval``.
plt.figure("train", (12, 6))

plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
y = model.epoch_loss_values
x = list(range(1, len(y) + 1))
plt.xlabel("epoch")
plt.plot(x, y, color="red")

plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
y = model.metric_values
x = [model.hparams.val_interval * k for k in range(1, len(y) + 1)]
plt.xlabel("epoch")
plt.plot(x, y, color="green")
plt.show()

# (title, values, line colour) for the TC / WT / ET panels
plt.figure("train", (18, 6))
per_channel_panels = (
    ("Val Mean Dice TC", model.metric_values_tc, "blue"),
    ("Val Mean Dice WT", model.metric_values_wt, "brown"),
    ("Val Mean Dice ET", model.metric_values_et, "purple"),
)
for column, (panel_title, y, line_colour) in enumerate(per_channel_panels, start=1):
    plt.subplot(1, 3, column)
    plt.title(panel_title)
    x = [model.hparams.val_interval * k for k in range(1, len(y) + 1)]
    plt.xlabel("epoch")
    plt.plot(x, y, color=line_colour)
plt.show()


In [None]:
# from monai.inferers import sliding_window_inference

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# model.model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
# model = model.to(device)  # Add this line to move model to the same device as inputs
# model.eval()
# with torch.no_grad():
#     # select one image to evaluate and visualize the model output
#     val_input = val_ds[5]["image"].unsqueeze(0).to(device)
#     roi_size = (96, 96, 96)
#     sw_batch_size = 4
#     val_output = sliding_window_inference(val_input, roi_size, sw_batch_size, model)
#     val_output = post_trans(val_output[0])
#     plt.figure("image", (24, 6))
#     for i in range(4):
#         plt.subplot(1, 4, i + 1)
#         plt.title(f"image channel {i}")
#         plt.imshow(val_ds[5]["image"][i, :, :, 72].detach().cpu(), cmap="gray")
#     plt.show()
#     # visualize the 3 channels label corresponding to this image
#     plt.figure("label", (18, 6))
#     for i in range(3):
#         plt.subplot(1, 3, i + 1)
#         plt.title(f"label channel {i}")
#         plt.imshow(val_ds[5]["label"][i, :, :, 72].detach().cpu())
#     plt.show()
#     # visualize the 3 channels model output corresponding to this image
#     plt.figure("output", (18, 6))
#     for i in range(3):
#         plt.subplot(1, 3, i + 1)
#         plt.title(f"output channel {i}")
#         plt.imshow(val_output[i, :, :, 72].detach().cpu())
#     plt.show()

In [None]:
# from monai.inferers import sliding_window_inference
# import matplotlib.pyplot as plt
# import numpy as np
# import torch
# from matplotlib.colors import ListedColormap
# import os
# from monai.transforms import Compose, Activations, AsDiscrete

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# # Load model and move to device
# model.model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
# model = model.to(device)
# model.eval()

# # Define aesthetically pleasing color scheme
# # Using a professionally-designed color palette with better contrast
# class_colors = [
#     [0.7, 0.7, 0.7, 1],         # Background (neutral gray)
#     [0.85, 0.37, 0.35, 0.7],  # Class 1 (rust red - softer and more professional)
#     [0.46, 0.78, 0.56, 0.7],  # Class 2 (sage green - easier on the eyes)
#     [0.31, 0.51, 0.9, 0.7]    # Class 3 (medium blue - more saturated but not overwhelming)
# ]
# custom_cmap = ListedColormap(class_colors)

# with torch.no_grad():
#     # Select one image to evaluate
#     val_input = val_ds[8]["image"].unsqueeze(0).to(device)
#     val_label = val_ds[8]["label"]
    
#     # Inference
#     roi_size = (96, 96, 96)
#     sw_batch_size = 4
#     val_output = sliding_window_inference(val_input, roi_size, sw_batch_size, model)
#     val_output = post_trans(val_output[0])
    
#     # Move tensors to CPU and convert to numpy
#     val_input_np = val_input[0, 0].cpu().numpy()  # Shape: (H, W, D)
#     val_label_np = val_label.cpu().numpy()  # Shape: (C, H, W, D) where C is number of classes
#     val_output_np = val_output.cpu().numpy()  # Shape: (C, H, W, D)
    
#     # Normalize image for visualization with better contrast
#     val_input_np = (val_input_np - val_input_np.min()) / (val_input_np.max() - val_input_np.min())
#     val_input_np = (val_input_np * 255).astype(np.uint8)
    
#     # Determine slice to use (middle slice or 77 if available)
#     total_slices = val_input_np.shape[-1]
#     middle_slice = total_slices // 2
#     slice_idx = 77 if total_slices > 77 else middle_slice
#     print(f"Using slice {slice_idx} out of {total_slices} total slices")
    
#     # Create a combined segmentation map for ground truth and prediction
#     # Initialize with zeros (background)
#     num_classes = val_label_np.shape[0]
#     gt_combined = np.zeros((val_label_np.shape[1], val_label_np.shape[2], 4))  # RGBA
#     pred_combined = np.zeros((val_output_np.shape[1], val_output_np.shape[2], 4))  # RGBA
    
#     # Fill in each class with its color
#     for c in range(num_classes):
#         # For ground truth
#         mask = val_label_np[c, :, :, slice_idx]
#         for i in range(4):  # RGBA channels
#             gt_combined[:, :, i] = np.where(mask > 0, class_colors[c+1][i], gt_combined[:, :, i])
        
#         # For prediction
#         mask = val_output_np[c, :, :, slice_idx]
#         for i in range(4):  # RGBA channels
#             pred_combined[:, :, i] = np.where(mask > 0, class_colors[c+1][i], pred_combined[:, :, i])
    
#     # Plot the images with improved styling
#     plt.figure(figsize=(18, 6), facecolor='white')
    
#     plt.subplot(1, 3, 1)
#     plt.title("Image", fontsize=14, fontweight='bold')
#     plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#     plt.axis('off')
    
#     plt.subplot(1, 3, 2)
#     plt.title("Ground Truth", fontsize=14, fontweight='bold')
#     plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#     plt.imshow(gt_combined)  # Alpha is already in the array
#     plt.axis('off')
    
#     plt.subplot(1, 3, 3)
#     plt.title("Predicted Segmentation", fontsize=14, fontweight='bold')
#     plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#     plt.imshow(pred_combined)  # Alpha is already in the array
#     plt.axis('off')
    
#     # Add a clearer color legend
#     class_names = ["Tumor Core", "Whole Tumor", "Enhancing"]
#     legend_patches = [plt.Rectangle((0, 0), 1, 1, fc=class_colors[i+1][:3], alpha=0.7) for i in range(num_classes)]
#     plt.figlegend(legend_patches, class_names, loc='lower center', ncol=num_classes, 
#                  bbox_to_anchor=(0.5, -0.05), fontsize=12, frameon=True, edgecolor='black')
    
#     plt.tight_layout(pad=1.5)
#     plt.subplots_adjust(bottom=0.15)  # Add space for the legend
#     plt.show()
    
#     # Display each class separately with enhanced visualization
#     plt.figure(figsize=(15, 5 * num_classes), facecolor='white')
    
#     # Custom colormaps for each class - more aesthetically pleasing
#     class_cmaps = ['RdPu', 'BuGn', 'PuBu']
    
#     for c in range(num_classes):
#         # Ground Truth for this class
#         plt.subplot(num_classes, 2, 2*c+1)
#         plt.title(f"Ground Truth - {class_names[c]}", fontsize=12, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.imshow(val_label_np[c, :, :, slice_idx], cmap=class_cmaps[c], alpha=0.7, vmin=0, vmax=1)
#         plt.axis('off')
        
#         # Prediction for this class
#         plt.subplot(num_classes, 2, 2*c+2)
#         plt.title(f"Prediction - {class_names[c]}", fontsize=12, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.imshow(val_output_np[c, :, :, slice_idx], cmap=class_cmaps[c], alpha=0.7, vmin=0, vmax=1)
#         plt.axis('off')
        
#         # Add a small colorbar to show intensity
#         plt.colorbar(shrink=0.8, ax=plt.gca())
    
#     plt.tight_layout(pad=2.0)
#     plt.show()

In [None]:
# import random
# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# from matplotlib.colors import ListedColormap
# from monai.inferers import sliding_window_inference
# from monai.transforms import Compose, Activations, AsDiscrete
# import os

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# # Load model and move to device
# model.model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
# model = model.to(device)
# model.eval()

# # Define aesthetically pleasing color scheme
# class_colors = [
#     [0.7, 0.7, 0.7, 1],         # Background (neutral gray)
#     [0.85, 0.37, 0.35, 0.7],  # Class 1 (rust red - softer and more professional)
#     [0.46, 0.78, 0.56, 0.7],  # Class 2 (sage green - easier on the eyes)
#     [0.31, 0.51, 0.9, 0.7]    # Class 3 (medium blue - more saturated but not overwhelming)
# ]
# custom_cmap = ListedColormap(class_colors)

# # Randomly select 5 samples
# random_indices = random.sample(range(len(val_ds)), 5)

# for idx in random_indices:
#     with torch.no_grad():
#         # Select image to evaluate
#         val_input = val_ds[idx]["image"].unsqueeze(0).to(device)
#         val_label = val_ds[idx]["label"]
        
#         # Inference
#         roi_size = (96, 96, 96)
#         sw_batch_size = 4
#         val_output = sliding_window_inference(val_input, roi_size, sw_batch_size, model)
#         val_output = post_trans(val_output[0])
        
#         # Move tensors to CPU and convert to numpy
#         val_input_np = val_input[0, 0].cpu().numpy()  # Shape: (H, W, D)
#         val_label_np = val_label.cpu().numpy()  # Shape: (C, H, W, D) where C is number of classes
#         val_output_np = val_output.cpu().numpy()  # Shape: (C, H, W, D)
        
#         # Normalize image for visualization with better contrast
#         val_input_np = (val_input_np - val_input_np.min()) / (val_input_np.max() - val_input_np.min())
#         val_input_np = (val_input_np * 255).astype(np.uint8)
        
#         # Determine slice to use (middle slice or 77 if available)
#         total_slices = val_input_np.shape[-1]
#         middle_slice = total_slices // 2
#         slice_idx = 77 if total_slices > 77 else middle_slice
#         print(f"Using slice {slice_idx} out of {total_slices} total slices")
        
#         # Create a combined segmentation map for ground truth and prediction
#         num_classes = val_label_np.shape[0]
#         gt_combined = np.zeros((val_label_np.shape[1], val_label_np.shape[2], 4))  # RGBA
#         pred_combined = np.zeros((val_output_np.shape[1], val_output_np.shape[2], 4))  # RGBA
        
#         # Fill in each class with its color
#         for c in range(num_classes):
#             # For ground truth
#             mask = val_label_np[c, :, :, slice_idx]
#             for i in range(4):  # RGBA channels
#                 gt_combined[:, :, i] = np.where(mask > 0, class_colors[c+1][i], gt_combined[:, :, i])
            
#             # For prediction
#             mask = val_output_np[c, :, :, slice_idx]
#             for i in range(4):  # RGBA channels
#                 pred_combined[:, :, i] = np.where(mask > 0, class_colors[c+1][i], pred_combined[:, :, i])
        
#         # Plot the images with improved styling
#         plt.figure(figsize=(18, 6), facecolor='white')
        
#         plt.subplot(1, 3, 1)
#         plt.title("Image", fontsize=14, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.axis('off')
        
#         plt.subplot(1, 3, 2)
#         plt.title("Ground Truth", fontsize=14, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.imshow(gt_combined)  # Alpha is already in the array
#         plt.axis('off')
        
#         plt.subplot(1, 3, 3)
#         plt.title("Predicted Segmentation", fontsize=14, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.imshow(pred_combined)  # Alpha is already in the array
#         plt.axis('off')
        
#         # Add a clearer color legend
#         class_names = ["Tumor Core", "Whole Tumor", "Enhancing"]
#         legend_patches = [plt.Rectangle((0, 0), 1, 1, fc=class_colors[i+1][:3], alpha=0.7) for i in range(num_classes)]
#         plt.figlegend(legend_patches, class_names, loc='lower center', ncol=num_classes, 
#                      bbox_to_anchor=(0.5, -0.05), fontsize=12, frameon=True, edgecolor='black')
        
#         plt.tight_layout(pad=1.5)
#         plt.subplots_adjust(bottom=0.15)  # Add space for the legend
#         plt.show()
        
#         # Display each class separately with enhanced visualization
#         plt.figure(figsize=(15, 5 * num_classes), facecolor='white')
        
#         # Custom colormaps for each class - more aesthetically pleasing
#         class_cmaps = ['RdPu', 'BuGn', 'PuBu']
        
#         for c in range(num_classes):
#             # Ground Truth for this class
#             plt.subplot(num_classes, 2, 2*c+1)
#             plt.title(f"Ground Truth - {class_names[c]}", fontsize=12, fontweight='bold')
#             plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#             plt.imshow(val_label_np[c, :, :, slice_idx], cmap=class_cmaps[c], alpha=0.7, vmin=0, vmax=1)
#             plt.axis('off')
            
#             # Prediction for this class
#             plt.subplot(num_classes, 2, 2*c+2)
#             plt.title(f"Prediction - {class_names[c]}", fontsize=12, fontweight='bold')
#             plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#             plt.imshow(val_output_np[c, :, :, slice_idx], cmap=class_cmaps[c], alpha=0.7, vmin=0, vmax=1)
#             plt.axis('off')
            
#             # Add a small colorbar to show intensity
#             plt.colorbar(shrink=0.8, ax=plt.gca())
        
#         plt.tight_layout(pad=2.0)
#         plt.show()


In [None]:
# import numpy as np

# def overlay_label_on_image(image, label):
#     overlaid_images = []
#     for i in range(image.shape[0]):
#         overlaid_image = np.zeros_like(image[i])
#         # Overlay each label channel onto the corresponding image channel
#         for j in range(min(image.shape[0], label.shape[0])):
#             overlaid_image[label[j] > 0] = label[j][label[j] > 0]
#         overlaid_images.append(overlaid_image)
#     return np.stack(overlaid_images)

# # Usage:
# overlay = overlay_label_on_image(val_ds[1]["image"].detach().cpu().numpy(), val_ds[1]["label"].detach().cpu().numpy())


In [None]:
# overlay.shape

In [None]:
# import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

# # Iterate over each channel
# for i in range(overlay.shape[0]):
#     # Create a new figure and a set of subplots with a 3D projection for each channel
#     fig = plt.figure()
#     ax = fig.add_subplot(111, projection='3d')

#     # Display the 3D volume for the current channel
#     ax.voxels(overlay[i])

#     # Set labels and title
#     ax.set_xlabel('X')
#     ax.set_ylabel('Y')
#     ax.set_zlabel('Z')
#     ax.set_title(f'Overlay Channel {i+1}')

#     # Show plot
#     plt.show()


## Create Test Dataset JSON

In [None]:
# import os
# import glob
# import json

# # Get sorted file paths and file names
# file_paths2 = glob.glob('/kaggle/input/brats2023-part-2zip/*')  # Unseen Data 
# file_paths2.sort()

# file_names2 = [os.path.basename(path) for path in file_paths2]  # Extract file names from paths
# file_names2.sort()

# # Initialize lists for different MRI modalities and segmentation labels
# t1c, t1n, t2f, t2w, label = [], [], [], [], []

# # Use the total number of files
# num_files = len(file_paths2)

# # Populate the lists with file paths
# for i in range(num_files):
#     t1c.append(os.path.join(file_paths2[i], file_names2[i] + '-t1c.nii'))
#     t1n.append(os.path.join(file_paths2[i], file_names2[i] + '-t1n.nii'))
#     t2f.append(os.path.join(file_paths2[i], file_names2[i] + '-t2f.nii'))
#     t2w.append(os.path.join(file_paths2[i], file_names2[i] + '-t2w.nii'))
#     label.append(os.path.join(file_paths2[i], file_names2[i] + '-seg.nii'))

# # Store in a dictionary with combined image modalities and separate label
# file_list = []
# for i in range(num_files):
#     file_list.append({
#         "image": [t1c[i], t1n[i], t2f[i], t2w[i]],  # Combine modalities into one "image" field
#         "label": label[i]
#     })

# file_json = {
#     "testing": file_list  # Changed key to "testing" for clarity
# }

# # Save to JSON file
# file_path = '/kaggle/working/dataset_test.json'
# with open(file_path, 'w') as json_file:
#     json.dump(file_json, json_file, indent=4)


## Test DataLoader

In [None]:
# # Load test dataset
# dataset_path = "/kaggle/working/dataset_test.json"
# with open(dataset_path) as f:
#     datalist = json.load(f)["testing"]  # Updated key to match test dataset

# ### Run it on the first 40 samples
# datalist = datalist[:40]

# test_transform = Compose(
#     [
#         LoadImaged(keys=["image", "label"]),
#         EnsureChannelFirstd(keys="image"),
#         EnsureTyped(keys=["image", "label"]),
#         ConvertLabels(keys="label"),
#         Orientationd(keys=["image", "label"], axcodes="RAS"),
#         Spacingd(
#             keys=["image", "label"],
#             pixdim=(1.0, 1.0, 1.0),
#             mode=("bilinear", "nearest"),
#         ),
#         NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
#     ]
# )

# # Create MONAI test dataset
# test_ds = Dataset(data=datalist, transform=test_transform)

# # Dataloader
# test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=3, pin_memory=True, persistent_workers=False)


## Testing Pipeline

In [None]:
# import torch
# import pytorch_lightning as pl
# from monai.networks.nets import SwinUNETR
# from monai.transforms import Compose, Activations, AsDiscrete
# from monai.metrics import DiceMetric
# from monai.losses import DiceLoss
# from monai.data import DataLoader, Dataset, decollate_batch
# from monai.inferers import sliding_window_inference
# import matplotlib.pyplot as plt
# import pandas as pd
# import seaborn as sns
# import numpy as np

# class BrainTumorSegmentationModel(pl.LightningModule):
#     def __init__(self):
#         super(BrainTumorSegmentationModel, self).__init__()
#         self.model = SwinUNETR(
#             img_size=(96, 96, 96),
#             in_channels=4,
#             out_channels=3,
#             feature_size=48,
#             use_checkpoint=True,
#         )
        
#         # Load model weights
#         self.model.load_state_dict(torch.load("/kaggle/input/trained-model-29/pytorch/default/1/swinunetr-29epochs.pth"))
        
#         # Post-processing transformations
#         self.post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
        
#         # Dice metrics for evaluation
#         self.dice_metric = DiceMetric(include_background=True, reduction="mean")
#         self.dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
        
#         # Dice loss
#         self.dice_loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)

#         self.total_loss = 0.0  # To accumulate loss
#         self.steps = 0  # Count number of batches
        
#         # List to store all batch results
#         self.all_batch_results = []
        
#     def forward(self, x):
#         return self.model(x)
    
#     def test_step(self, batch, batch_idx):
#         test_inputs, test_labels = batch["image"].to(self.device), batch["label"].to(self.device)
    
#         # Inference
#         with torch.no_grad():
#             test_outputs = sliding_window_inference(test_inputs, roi_size=(96, 96, 96), sw_batch_size=1, predictor=self.model, overlap=0.5)
    
#         # Compute Dice Loss
#         test_loss = self.dice_loss_function(test_outputs, test_labels)
    
#         # Aggregate Loss
#         self.total_loss += test_loss
#         self.steps += 1
    
#         # Post-processing
#         test_outputs = [self.post_trans(i) for i in decollate_batch(test_outputs)]
    
#         # Compute Dice scores
#         self.dice_metric(y_pred=test_outputs, y=test_labels)
#         self.dice_metric_batch(y_pred=test_outputs, y=test_labels)
    
#         mean_dice = self.dice_metric.aggregate().item()
#         metric_batch = self.dice_metric_batch.aggregate()
    
#         dice_tc, dice_wt, dice_et = metric_batch[0].item(), metric_batch[1].item(), metric_batch[2].item()
    
#         self.log("Test Loss (Dice)", test_loss, prog_bar=True)
#         self.log("Mean Dice", mean_dice, prog_bar=True)
#         self.log("Dice TC", dice_tc, prog_bar=True)
#         self.log("Dice WT", dice_wt, prog_bar=True)
#         self.log("Dice ET", dice_et, prog_bar=True)
    
#         # Create batch result dictionary
#         batch_result = {
#             "Test Loss (Dice)": test_loss.item(),
#             "Mean Dice": mean_dice,
#             "Dice TC": dice_tc,
#             "Dice WT": dice_wt,
#             "Dice ET": dice_et,
#         }
        
#         # Store the batch result
#         self.all_batch_results.append(batch_result)
    
#         # Reset metrics
#         self.dice_metric.reset()
#         self.dice_metric_batch.reset()
    
#         # Return metrics as a dictionary
#         return batch_result

#     def on_test_epoch_end(self):
#         """Compute and log mean loss over all batches."""
#         mean_loss = self.total_loss / self.steps if self.steps > 0 else 0
#         self.log("Mean Test Loss", mean_loss, prog_bar=True)
#         print(f"Mean Test Loss: {mean_loss:.4f}")
        
#         # Calculate and log average metrics across all batches
#         avg_metrics = {
#             "Mean Test Loss": mean_loss.item() if torch.is_tensor(mean_loss) else mean_loss,
#             "Mean Dice": np.mean([res["Mean Dice"] for res in self.all_batch_results]),
#             "Dice TC": np.mean([res["Dice TC"] for res in self.all_batch_results]),
#             "Dice WT": np.mean([res["Dice WT"] for res in self.all_batch_results]),
#             "Dice ET": np.mean([res["Dice ET"] for res in self.all_batch_results]),
#         }
        
#         print("Average Metrics:", avg_metrics)
        
#         # Return all results and average metrics
#         return {"batch_results": self.all_batch_results, "avg_metrics": avg_metrics}

#     def test_dataloader(self):
#         return test_loader  # Ensure you have a DataLoader defined for your test dataset


# # Set device (GPU if available, otherwise CPU)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Initialize model
# model = BrainTumorSegmentationModel()

# # Initialize PyTorch Lightning Trainer
# trainer = pl.Trainer(
#     devices=1,
#     accelerator="gpu" if torch.cuda.is_available() else "cpu",
#     max_epochs=1,  # Since it's inference, set this to 1
#     log_every_n_steps=1,
# )

# # Run inference on test dataset and collect results
# test_results = trainer.test(model)

# # The test results from trainer.test() will be a list with one dictionary
# # If we want to save our detailed results, use:
# torch.save(model.all_batch_results, "test_batch_results.pth")

In [None]:
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd
# import seaborn as sns
# import torch

# # Load the saved batch results
# try:
#     batch_results = torch.load("test_batch_results.pth")
# except FileNotFoundError:
#     print("Could not find test_batch_results.pth. Make sure the file exists.")
#     # Fall back to an empty list so the plotting block below is skipped
#     batch_results = []
#     print("Using dummy data for demonstration purposes.")

# if batch_results:
#     # Calculate average metrics across all batches
#     avg_metrics = {
#         "Mean Dice": np.mean([res["Mean Dice"] for res in batch_results]),
#         "Dice TC": np.mean([res["Dice TC"] for res in batch_results]),
#         "Dice WT": np.mean([res["Dice WT"] for res in batch_results]),
#         "Dice ET": np.mean([res["Dice ET"] for res in batch_results]),
#         "Test Loss (Dice)": np.mean([res["Test Loss (Dice)"] for res in batch_results])
#     }
    
#     # Create a DataFrame for the Dice metrics (excluding loss for better visualization)
#     dice_metrics = {k: v for k, v in avg_metrics.items() if k != "Test Loss (Dice)"}
    
#     # Create a DataFrame for plotting
#     metrics_df = pd.DataFrame({
#         'Metric': list(dice_metrics.keys()),
#         'Value': list(dice_metrics.values())
#     })
    
#     # Set the Seaborn style
#     sns.set_style("whitegrid")
    
#     # # Create figure with two subplots: one for Dice metrics, one for Loss
#     # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), 
#     #                                gridspec_kw={'width_ratios': [3, 1]})

#     fig, ax1 = plt.subplots(figsize=(10, 8))

    
#     # 1. Plot Dice metrics
#     # Create bar plot with custom color palette
#     sns.barplot(x='Metric', y='Value', data=metrics_df, 
#                 palette='viridis', ax=ax1)
    
#     # Customize the axes (only one subplot is used; the two-panel layout is disabled above)
#     ax1.set_title('SwinUNETR Test Set Performance', fontsize=16)
#     ax1.set_xlabel('Metrics', fontsize=14)
#     ax1.set_ylabel('Score', fontsize=14)
#     ax1.set_ylim(0, 1.0)  # Dice scores are between 0 and 1
#     ax1.set_xticklabels(ax1.get_xticklabels(), rotation=45, ha='right')
    
#     # Add value labels on top of bars
#     for i, value in enumerate(metrics_df['Value']):
#         ax1.text(i, value + 0.02, f'{value:.4f}', ha='center', fontsize=12)
    
#     # # 2. Plot Test Loss
#     # loss_value = avg_metrics["Test Loss (Dice)"]
#     # ax2.bar(['Test Loss (Dice)'], [loss_value], color='crimson')
#     # ax2.set_title('Dice Loss', fontsize=16)
#     # ax2.set_ylabel('Loss Value', fontsize=14)
#     # ax2.set_ylim(0, min(1.0, loss_value * 1.5))  # Adjust based on loss value
#     # ax2.text(0, loss_value + 0.02, f'{loss_value:.4f}', ha='center', fontsize=12)
    
    
#     # Adjust layout
#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        
#     # Save the figure
#     plt.savefig('brain_tumor_segmentation_metrics.png', dpi=300, bbox_inches='tight')
    
#     # Show the plot
#     plt.show()

## Knowledge Distillation Pipeline

In [None]:
# import os
# import time
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import pytorch_lightning as pl
# from monai.networks.nets import SwinUNETR
# from monai.losses import DiceLoss, DiceCELoss
# from monai.metrics import DiceMetric
# from monai.transforms import Compose, Activations, AsDiscrete
# from monai.data import PersistentDataset, list_data_collate, decollate_batch, DataLoader
# from monai.inferers import sliding_window_inference
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.callbacks.timer import Timer
# from pytorch_lightning.loggers import WandbLogger
# import wandb

# class DistillationLoss(nn.Module):
#     def __init__(self, alpha=0.5, temperature=2.0, dice_loss=None):
#         super().__init__()
#         self.alpha = alpha  # Weight for hard/ground-truth loss
#         self.temperature = temperature  # Temperature for softening
#         self.dice_loss = dice_loss  # The original dice loss
#         self.mse_loss = nn.MSELoss()  # Use MSE loss instead of KL for segmentation
    
#     def forward(self, student_outputs, teacher_outputs, labels):
#         # Hard loss (original Dice loss with ground truth)
#         hard_loss = self.dice_loss(student_outputs, labels)
        
#         # Soft loss (MSE between teacher and student logits)
#         # Scale logits by temperature
#         soft_student = student_outputs / self.temperature
#         soft_teacher = teacher_outputs / self.temperature
        
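#         # MSE loss between student and teacher outputs. NOTE: the T**2 factor below
#         # exactly cancels the 1/T scaling above, so temperature has no net effect here.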
#         soft_loss = self.mse_loss(soft_student, soft_teacher) * (self.temperature ** 2)
        
#         # Normalize soft loss to be on a similar scale as hard loss
#         # This is important to prevent the soft loss from dominating
#         # You may need to adjust this scaling factor based on your data
#         soft_loss_scaling = 0.01  # Adjust this based on initial loss values
#         soft_loss = soft_loss * soft_loss_scaling
        
#         # Combined loss
#         total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
        
#         return total_loss, hard_loss, soft_loss

# class BrainTumorDistillation(pl.LightningModule):
#     def __init__(self, train_loader, val_loader, teacher_model_path, 
#                  max_epochs=100, val_interval=1, learning_rate=1e-4,
#                  alpha=0.5, temperature=4.0, feature_size=24):
#         super().__init__()
#         self.save_hyperparameters()
        
#         # Teacher model (frozen)
#         self.teacher_model = SwinUNETR(
#             img_size=(96, 96, 96),
#             in_channels=4,
#             out_channels=3,
#             feature_size=48,
#             use_checkpoint=True,
#         )
#         # Load teacher weights
#         self.teacher_model.load_state_dict(torch.load(teacher_model_path))
#         # Freeze teacher model
#         for param in self.teacher_model.parameters():
#             param.requires_grad = False
#         self.teacher_model.eval()
        
#         # Student model (smaller)
#         self.model = SwinUNETR(
#             img_size=(96, 96, 96),
#             in_channels=4,
#             out_channels=3,
#             feature_size=feature_size,  # Smaller feature size
#             use_checkpoint=True,
#         )
        
#         # Original loss function
#         self.dice_loss = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
        
#         # Distillation loss
#         self.loss_function = DistillationLoss(
#             alpha=alpha,
#             temperature=temperature,
#             dice_loss=self.dice_loss
#         )
        
#         # Metrics
#         self.dice_metric = DiceMetric(include_background=True, reduction="mean")
#         self.dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

#         self.post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
        
#         self.best_metric = -1
#         self.train_loader = train_loader
#         self.val_loader = val_loader

#         # Training metrics
#         self.avg_train_loss_values = []
#         self.train_loss_values = []
#         self.train_metric_values = []
#         self.train_metric_values_tc = []
#         self.train_metric_values_wt = []
#         self.train_metric_values_et = []

#         # Validation metrics
#         self.avg_val_loss_values = []
#         self.epoch_loss_values = []
#         self.metric_values = []
#         self.metric_values_tc = []
#         self.metric_values_wt = []
#         self.metric_values_et = []
        
#         # Track distillation-specific metrics
#         self.hard_loss_values = []
#         self.soft_loss_values = []

#     def forward(self, x):
#         return self.model(x)

#     def training_step(self, batch, batch_idx):
#         inputs, labels = batch["image"], batch["label"]

#         # Get teacher predictions (no gradients)
#         with torch.no_grad():
#             teacher_outputs = self.teacher_model(inputs)
        
#         # Get student predictions
#         student_outputs = self(inputs)
        
#         # Calculate losses
#         total_loss, hard_loss, soft_loss = self.loss_function(student_outputs, teacher_outputs, labels)
        
#         # Log the losses
#         self.log("train_total_loss", total_loss, prog_bar=True)
#         self.log("train_hard_loss", hard_loss, prog_bar=True)
#         self.log("train_soft_loss", soft_loss, prog_bar=True)
        
#         # Store distillation losses
#         self.hard_loss_values.append(hard_loss.item())
#         self.soft_loss_values.append(soft_loss.item())

#         # Apply sigmoid and threshold
#         outputs = [self.post_trans(i) for i in decollate_batch(student_outputs)]
        
#         # Compute Dice
#         self.dice_metric(y_pred=outputs, y=labels)
#         self.dice_metric_batch(y_pred=outputs, y=labels)

#         # Log Train Dice 
#         train_dice = self.dice_metric.aggregate().item()
#         self.log("train_mean_dice", train_dice, prog_bar=True)

#         # Store Mean Dice
#         self.train_metric_values.append(train_dice)

#         # Store the individual dice
#         metric_batch = self.dice_metric_batch.aggregate()
#         self.train_metric_values_tc.append(metric_batch[0].item())
#         self.train_metric_values_wt.append(metric_batch[1].item())
#         self.train_metric_values_et.append(metric_batch[2].item())

#         # Log the individual dice metrics
#         self.log("train_tc", metric_batch[0].item(), prog_bar=True)
#         self.log("train_wt", metric_batch[1].item(), prog_bar=True)
#         self.log("train_et", metric_batch[2].item(), prog_bar=True)

#         if batch_idx == 0 and self.current_epoch % 5 == 0:
#             print(f"Epoch {self.current_epoch}")
#             print(f"Student output range: {student_outputs.min().item():.4f} to {student_outputs.max().item():.4f}")
#             print(f"Teacher output range: {teacher_outputs.min().item():.4f} to {teacher_outputs.max().item():.4f}")
#             print(f"Hard loss: {hard_loss.item():.4f}, Soft loss: {soft_loss.item():.4f}")

#         # Reset metrics for the next batch
#         self.dice_metric.reset()
#         self.dice_metric_batch.reset()

#         return total_loss

#     def on_train_epoch_end(self):
#         train_loss = self.trainer.logged_metrics["train_total_loss"].item()
#         self.train_loss_values.append(train_loss)
        
#         # Running mean of train_loss_values over all epochs so far (one entry is appended per epoch)
#         avg_train_loss = sum(self.train_loss_values) / len(self.train_loss_values)
#         self.log("avg_train_loss", avg_train_loss, prog_bar=True)
#         self.avg_train_loss_values.append(avg_train_loss)
        
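#         # Log average distillation losses over every step since training began
#         # (hard_loss_values / soft_loss_values are never cleared between epochs)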
#         avg_hard_loss = sum(self.hard_loss_values) / len(self.hard_loss_values)
#         avg_soft_loss = sum(self.soft_loss_values) / len(self.soft_loss_values)
#         self.log("avg_hard_loss", avg_hard_loss)
#         self.log("avg_soft_loss", avg_soft_loss)

#     def validation_step(self, batch, batch_idx):
#         val_inputs, val_labels = batch["image"], batch["label"]
        
#         # Get teacher predictions for reference
#         with torch.no_grad():
#             teacher_val_outputs = sliding_window_inference(
#                 val_inputs, roi_size=(96, 96, 96), sw_batch_size=1, predictor=self.teacher_model, overlap=0.5
#             )
        
#         # Get student predictions
#         student_val_outputs = sliding_window_inference(
#             val_inputs, roi_size=(96, 96, 96), sw_batch_size=1, predictor=self.model, overlap=0.5
#         )
        
#         # Compute distillation loss
#         val_loss, val_hard_loss, val_soft_loss = self.loss_function(
#             student_val_outputs, teacher_val_outputs, val_labels
#         )
        
#         # Log validation losses
#         self.log("val_total_loss", val_loss, prog_bar=True, sync_dist=True)
#         self.log("val_hard_loss", val_hard_loss, prog_bar=True)
#         self.log("val_soft_loss", val_soft_loss, prog_bar=True)
        
#         # Process student outputs for metrics
#         student_post_outputs = [self.post_trans(i) for i in decollate_batch(student_val_outputs)]
        
#         # Compute Dice metrics
#         self.dice_metric(y_pred=student_post_outputs, y=val_labels)
#         self.dice_metric_batch(y_pred=student_post_outputs, y=val_labels)
        
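#         # Log validation Dice; dice_metric is only reset in on_validation_epoch_end, so this
#         # is presumably cumulative over the batches seen so far this epoch -- TODO confirm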
#         val_dice = self.dice_metric.aggregate().item()
#         self.log("val_mean_dice", val_dice, prog_bar=True)
        
#         # Compare with teacher (for monitoring purposes)
#         with torch.no_grad():
#             teacher_post_outputs = [self.post_trans(i) for i in decollate_batch(teacher_val_outputs)]
#             teacher_dice = DiceMetric(include_background=True, reduction="mean")
#             teacher_dice(y_pred=teacher_post_outputs, y=val_labels)
#             teacher_dice_score = teacher_dice.aggregate().item()
#             self.log("teacher_mean_dice", teacher_dice_score, prog_bar=True)
#             self.log("teacher_student_gap", teacher_dice_score - val_dice, prog_bar=True)
        
#         return {"val_loss": val_loss}

#     def on_validation_epoch_end(self):
#         # Store Dice Mean
#         val_dice = self.dice_metric.aggregate().item()
#         self.metric_values.append(val_dice)

#         # Store Validation Loss 
#         val_loss = self.trainer.logged_metrics["val_total_loss"].item()
#         self.epoch_loss_values.append(val_loss)

#         # Calculate and Store avg val loss values
#         avg_val_loss = sum(self.epoch_loss_values) / len(self.epoch_loss_values)
#         self.log("avg_val_loss", avg_val_loss, prog_bar=True)
#         self.avg_val_loss_values.append(avg_val_loss)

#         # Store Individual Dice
#         metric_batch = self.dice_metric_batch.aggregate()
#         self.metric_values_tc.append(metric_batch[0].item())
#         self.metric_values_wt.append(metric_batch[1].item())
#         self.metric_values_et.append(metric_batch[2].item())

#         # Log validation metrics
#         self.log("val_loss", val_loss, prog_bar=True)
#         self.log("val_mean_dice", val_dice, prog_bar=True)
#         self.log("val_tc", metric_batch[0].item(), prog_bar=True)
#         self.log("val_wt", metric_batch[1].item(), prog_bar=True)
#         self.log("val_et", metric_batch[2].item(), prog_bar=True)
    
#         if val_dice > self.best_metric:
#             self.best_metric = val_dice
#             self.best_metric_epoch = self.current_epoch
#             torch.save(self.model.state_dict(), "best_distilled_model.pth")
#             self.log("best_metric", self.best_metric)
    
#         # Reset metrics for the next epoch
#         self.dice_metric.reset()
#         self.dice_metric_batch.reset()

#     def on_train_end(self):
#         # Print the best metric and its epoch; the tc/wt/et values are taken from the
#         # most recent validation epoch ([-1]), not necessarily the best epoch
#         print(f"Train completed, best_metric: {self.best_metric:.4f} at epoch: {self.best_metric_epoch}, "
#               f"tc: {self.metric_values_tc[-1]:.4f}, "
#               f"wt: {self.metric_values_wt[-1]:.4f}, "
#               f"et: {self.metric_values_et[-1]:.4f}.")
        
#         # Compare model sizes
#         teacher_size = sum(p.numel() for p in self.teacher_model.parameters())
#         student_size = sum(p.numel() for p in self.model.parameters())
#         print(f"Teacher model parameters: {teacher_size:,}")
#         print(f"Student model parameters: {student_size:,}")
#         print(f"Size reduction: {(1 - student_size/teacher_size)*100:.2f}%")

#     def configure_optimizers(self):
#         optimizer = AdamW(self.model.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-5)
#         scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
#         return [optimizer], [scheduler]

# # Usage example
# def train_distilled_model(teacher_model_path, train_ds, val_ds, feature_size=24):
#     # Training setup
#     max_epochs = 30
#     train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=3, pin_memory=True)
#     val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=3, pin_memory=True)

#     # Set up early stopping
#     early_stop_callback = EarlyStopping(
#        monitor="val_total_loss",
#        min_delta=0.00,
#        patience=5,
#        verbose=True,
#        mode='min'
#     )
#     # Stop training after 11 hours (Timer duration is DD:HH:MM:SS)
#     timer_callback = Timer(duration="00:11:00:00")

#     # Initialize wandb logger
#     wandb.init(project="brain-tumor-segmentation", name="swinunetr-distillation")
#     wandb_logger = WandbLogger()

#     # Initialize and train the model
#     model = BrainTumorDistillation(
#         train_loader=train_loader, 
#         val_loader=val_loader, 
#         teacher_model_path=teacher_model_path,
#         max_epochs=max_epochs,
#         feature_size=feature_size  # Smaller feature size for student model
#     )
    
#     trainer = pl.Trainer(
#         max_epochs=max_epochs,
#         devices=1,
#         accelerator="gpu",
#         precision="16-mixed",
#         gradient_clip_val=1.0,
#         log_every_n_steps=1,
#         callbacks=[early_stop_callback, timer_callback],
#         limit_val_batches=5,
#         check_val_every_n_epoch=1,
#         logger=wandb_logger, 
#     )

#     trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    
#     return model


# model = train_distilled_model(
#     teacher_model_path="/kaggle/input/trained-model-29/pytorch/default/1/swinunetr-29epochs.pth", 
#     train_ds=train_ds, 
#     val_ds=val_ds,
#     feature_size=24  # Half the original feature size
# )

## Student Model Inference

In [None]:
# import random
# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# from matplotlib.colors import ListedColormap
# from monai.inferers import sliding_window_inference
# from monai.transforms import Compose, Activations, AsDiscrete
# import os

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# # Load model and move to device
# model.model.load_state_dict(torch.load(os.path.join(root_dir, "best_distilled_model.pth")))
# model = model.to(device)
# model.eval()

# # Define aesthetically pleasing color scheme
# class_colors = [
#     [0.7, 0.7, 0.7, 1],         # Background (neutral gray)
#     [0.85, 0.37, 0.35, 0.7],  # Class 1 (rust red - softer and more professional)
#     [0.46, 0.78, 0.56, 0.7],  # Class 2 (sage green - easier on the eyes)
#     [0.31, 0.51, 0.9, 0.7]    # Class 3 (medium blue - more saturated but not overwhelming)
# ]
# custom_cmap = ListedColormap(class_colors)

# # Randomly select 5 samples
# random_indices = random.sample(range(len(val_ds)), 5)

# for idx in random_indices:
#     with torch.no_grad():
#         # Select image to evaluate
#         val_input = val_ds[idx]["image"].unsqueeze(0).to(device)
#         val_label = val_ds[idx]["label"]
        
#         # Inference
#         roi_size = (96, 96, 96)
#         sw_batch_size = 4
#         val_output = sliding_window_inference(val_input, roi_size, sw_batch_size, model)
#         val_output = post_trans(val_output[0])
        
#         # Move tensors to CPU and convert to numpy
#         val_input_np = val_input[0, 0].cpu().numpy()  # Shape: (H, W, D)
#         val_label_np = val_label.cpu().numpy()  # Shape: (C, H, W, D) where C is number of classes
#         val_output_np = val_output.cpu().numpy()  # Shape: (C, H, W, D)
        
#         # Normalize image for visualization with better contrast
#         val_input_np = (val_input_np - val_input_np.min()) / (val_input_np.max() - val_input_np.min())
#         val_input_np = (val_input_np * 255).astype(np.uint8)
        
#         # Determine slice to use (middle slice or 77 if available)
#         total_slices = val_input_np.shape[-1]
#         middle_slice = total_slices // 2
#         slice_idx = 77 if total_slices > 77 else middle_slice
#         print(f"Using slice {slice_idx} out of {total_slices} total slices")
        
#         # Create a combined segmentation map for ground truth and prediction
#         num_classes = val_label_np.shape[0]
#         gt_combined = np.zeros((val_label_np.shape[1], val_label_np.shape[2], 4))  # RGBA
#         pred_combined = np.zeros((val_output_np.shape[1], val_output_np.shape[2], 4))  # RGBA
        
#         # Fill in each class with its color
#         for c in range(num_classes):
#             # For ground truth
#             mask = val_label_np[c, :, :, slice_idx]
#             for i in range(4):  # RGBA channels
#                 gt_combined[:, :, i] = np.where(mask > 0, class_colors[c+1][i], gt_combined[:, :, i])
            
#             # For prediction
#             mask = val_output_np[c, :, :, slice_idx]
#             for i in range(4):  # RGBA channels
#                 pred_combined[:, :, i] = np.where(mask > 0, class_colors[c+1][i], pred_combined[:, :, i])
        
#         # Plot the images with improved styling
#         plt.figure(figsize=(18, 6), facecolor='white')
        
#         plt.subplot(1, 3, 1)
#         plt.title("Image", fontsize=14, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.axis('off')
        
#         plt.subplot(1, 3, 2)
#         plt.title("Ground Truth", fontsize=14, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.imshow(gt_combined)  # Alpha is already in the array
#         plt.axis('off')
        
#         plt.subplot(1, 3, 3)
#         plt.title("Predicted Segmentation", fontsize=14, fontweight='bold')
#         plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#         plt.imshow(pred_combined)  # Alpha is already in the array
#         plt.axis('off')
        
#         # Add a clearer color legend
#         class_names = ["Tumor Core", "Whole Tumor", "Enhancing"]
#         legend_patches = [plt.Rectangle((0, 0), 1, 1, fc=class_colors[i+1][:3], alpha=0.7) for i in range(num_classes)]
#         plt.figlegend(legend_patches, class_names, loc='lower center', ncol=num_classes, 
#                      bbox_to_anchor=(0.5, -0.05), fontsize=12, frameon=True, edgecolor='black')
        
#         plt.tight_layout(pad=1.5)
#         plt.subplots_adjust(bottom=0.15)  # Add space for the legend
#         plt.show()
        
#         # Display each class separately with enhanced visualization
#         plt.figure(figsize=(15, 5 * num_classes), facecolor='white')
        
#         # Custom colormaps for each class - more aesthetically pleasing
#         class_cmaps = ['RdPu', 'BuGn', 'PuBu']
        
#         for c in range(num_classes):
#             # Ground Truth for this class
#             plt.subplot(num_classes, 2, 2*c+1)
#             plt.title(f"Ground Truth - {class_names[c]}", fontsize=12, fontweight='bold')
#             plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#             plt.imshow(val_label_np[c, :, :, slice_idx], cmap=class_cmaps[c], alpha=0.7, vmin=0, vmax=1)
#             plt.axis('off')
            
#             # Prediction for this class
#             plt.subplot(num_classes, 2, 2*c+2)
#             plt.title(f"Prediction - {class_names[c]}", fontsize=12, fontweight='bold')
#             plt.imshow(val_input_np[:, :, slice_idx], cmap="gray")
#             plt.imshow(val_output_np[c, :, :, slice_idx], cmap=class_cmaps[c], alpha=0.7, vmin=0, vmax=1)
#             plt.axis('off')
            
#             # Add a small colorbar to show intensity
#             plt.colorbar(shrink=0.8, ax=plt.gca())
        
#         plt.tight_layout(pad=2.0)
#         plt.show()

In [None]:
# from monai.inferers import sliding_window_inference

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# model.model.load_state_dict(torch.load(os.path.join(root_dir, "best_distilled_model.pth")))
# model = model.to(device)  # Add this line to move model to the same device as inputs
# model.eval()
# with torch.no_grad():
#     # select one image to evaluate and visualize the model output
#     val_input = val_ds[5]["image"].unsqueeze(0).to(device)
#     roi_size = (96, 96, 96)
#     sw_batch_size = 4
#     val_output = sliding_window_inference(val_input, roi_size, sw_batch_size, model)
#     val_output = post_trans(val_output[0])
#     plt.figure("image", (24, 6))
#     for i in range(4):
#         plt.subplot(1, 4, i + 1)
#         plt.title(f"image channel {i}")
#         plt.imshow(val_ds[5]["image"][i, :, :, 72].detach().cpu(), cmap="gray")
#     plt.show()
#     # visualize the 3 channels label corresponding to this image
#     plt.figure("label", (18, 6))
#     for i in range(3):
#         plt.subplot(1, 3, i + 1)
#         plt.title(f"label channel {i}")
#         plt.imshow(val_ds[5]["label"][i, :, :, 72].detach().cpu())
#     plt.show()
#     # visualize the 3 channels model output corresponding to this image
#     plt.figure("output", (18, 6))
#     for i in range(3):
#         plt.subplot(1, 3, i + 1)
#         plt.title(f"output channel {i}")
#         plt.imshow(val_output[i, :, :, 72].detach().cpu())
#     plt.show()

## Student Model Test Performance

### Prepare Test JSON

In [None]:
# import os
# import glob
# import json

# # Get sorted file paths and file names
# file_paths2 = glob.glob('/kaggle/input/brats2023-part-2zip/*')  # Unseen Data 
# file_paths2.sort()

# file_names2 = [os.path.basename(path) for path in file_paths2]  # Extract file names from paths
# file_names2.sort()

# # Initialize lists for different MRI modalities and segmentation labels
# t1c, t1n, t2f, t2w, label = [], [], [], [], []

# # Use the total number of files
# num_files = len(file_paths2)

# # Populate the lists with file paths
# for i in range(num_files):
#     t1c.append(os.path.join(file_paths2[i], file_names2[i] + '-t1c.nii'))
#     t1n.append(os.path.join(file_paths2[i], file_names2[i] + '-t1n.nii'))
#     t2f.append(os.path.join(file_paths2[i], file_names2[i] + '-t2f.nii'))
#     t2w.append(os.path.join(file_paths2[i], file_names2[i] + '-t2w.nii'))
#     label.append(os.path.join(file_paths2[i], file_names2[i] + '-seg.nii'))

# # Store in a dictionary with combined image modalities and separate label
# file_list = []
# for i in range(num_files):
#     file_list.append({
#         "image": [t1c[i], t1n[i], t2f[i], t2w[i]],  # Combine modalities into one "image" field
#         "label": label[i]
#     })

# file_json = {
#     "testing": file_list  # Top-level key; the test dataloader cell reads json.load(f)["testing"]
# }

# # Save to JSON file
# file_path = '/kaggle/working/dataset_test.json'
# with open(file_path, 'w') as json_file:
#     json.dump(file_json, json_file, indent=4)


### Test Dataloader 

In [None]:
# # Load test dataset
# dataset_path = "/kaggle/working/dataset_test.json"
# with open(dataset_path) as f:
#     datalist = json.load(f)["testing"]  # Same "testing" key that the Prepare Test JSON cell writes

# ### Run it on the first 40 samples
# datalist = datalist[:40]

# test_transform = Compose(
#     [
#         LoadImaged(keys=["image", "label"]),
#         EnsureChannelFirstd(keys="image"),
#         EnsureTyped(keys=["image", "label"]),
#         ConvertLabels(keys="label"),
#         Orientationd(keys=["image", "label"], axcodes="RAS"),
#         Spacingd(
#             keys=["image", "label"],
#             pixdim=(1.0, 1.0, 1.0),
#             mode=("bilinear", "nearest"),
#         ),
#         NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
#     ]
# )

# # Create MONAI test dataset
# test_ds = Dataset(data=datalist, transform=test_transform)

# # Dataloader
# test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=3, pin_memory=True, persistent_workers=False)


### Test Pipeline (Student) 

In [None]:
# import torch
# import pytorch_lightning as pl
# from monai.networks.nets import SwinUNETR
# from monai.transforms import Compose, Activations, AsDiscrete
# from monai.metrics import DiceMetric
# from monai.losses import DiceLoss
# from monai.data import DataLoader, Dataset, decollate_batch
# from monai.inferers import sliding_window_inference
# import matplotlib.pyplot as plt
# import pandas as pd
# import seaborn as sns
# import numpy as np

# class BrainTumorSegmentationModel(pl.LightningModule):
#     def __init__(self):
#         super(BrainTumorSegmentationModel, self).__init__()
#         self.model = SwinUNETR(
#             img_size=(96, 96, 96),
#             in_channels=4,
#             out_channels=3,
#             feature_size=24,
#             use_checkpoint=True,
#         )
        
#         # Load model weights
#         self.model.load_state_dict(torch.load("/kaggle/input/swin_distilled_model/pytorch/default/1/best_distilled_model.pth"))
        
#         # Post-processing transformations
#         self.post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
        
#         # Dice metrics for evaluation
#         self.dice_metric = DiceMetric(include_background=True, reduction="mean")
#         self.dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
        
#         # Dice loss
#         self.dice_loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)

#         self.total_loss = 0.0  # To accumulate loss
#         self.steps = 0  # Count number of batches
        
#         # List to store all batch results
#         self.all_batch_results = []
        
#     def forward(self, x):
#         return self.model(x)


# model = BrainTumorSegmentationModel()

In [None]:
# for name, module in model.model.named_modules():
#     print(name)


In [None]:
# import json
# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# from monai.networks.nets import SwinUNETR
# from monai.data import Dataset, DataLoader
# from monai.transforms import (
#     Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped,
#     Orientationd, Spacingd, NormalizeIntensityd, Resized
# )
# from monai.visualize import GradCAM
# from monai.transforms import Resize
# from monai.utils.misc import set_determinism

# set_determinism(42)

# # --- Preprocessing ---
# def load_datalist(path, max_samples=40):
#     with open(path) as f:
#         return json.load(f)["testing"][:max_samples]

# def get_transforms():
#     return Compose([
#         LoadImaged(keys=["image", "label"]),
#         EnsureChannelFirstd(keys=["image", "label"]),
#         EnsureTyped(keys=["image", "label"]),
#         Orientationd(keys=["image", "label"], axcodes="RAS"),
#         Spacingd(keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")),
#         NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
#         Resized(keys=["image", "label"], spatial_size=(96, 96, 96), mode=["trilinear", "nearest"]),
#     ])

# # --- Model ---
# def build_model():
#     return SwinUNETR(
#         img_size=(96, 96, 96),
#         in_channels=4,
#         out_channels=3,
#         feature_size=24,
#         use_checkpoint=True
#     )

# def load_weights(model, path):
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model.load_state_dict(torch.load(path, map_location=device))
#     return model.eval().to(device), device

# # --- GradCAM Utils ---
# def normalize(cam):
#     return (cam - cam.min()) / (cam.max() - cam.min() + 1e-5)

# import torch.nn.functional as F
# def resize_cam(cam, target_shape):
#     if cam.ndim == 4:  # [B, H, W, D]
#         cam = cam.unsqueeze(1)  # -> [B, 1, H, W, D]
#     elif cam.ndim == 3:  # [H, W, D]
#         cam = cam[None, None]   # -> [1, 1, H, W, D]

#     cam = F.interpolate(cam, size=target_shape, mode="trilinear", align_corners=False)
#     return normalize(cam)


# # --- Visualization ---
# def show_cam_overlay(image, cam, title, channel_idx=3, channel_name="T2"):
#     mid = image.shape[2] // 2
#     plt.figure(figsize=(12, 5))
#     plt.subplot(1, 2, 1)
#     plt.imshow(image[channel_idx, :, mid, :], cmap="gray")
#     plt.title(f"Original - {channel_name}")
#     plt.axis("off")
#     plt.subplot(1, 2, 2)
#     plt.imshow(image[channel_idx, :, mid, :], cmap="gray")
#     plt.imshow(cam[0, :, mid, :], cmap="jet_r", alpha=0.5)
#     plt.title(f"Grad-CAM Overlay - {channel_name}")
#     plt.axis("off")
#     plt.suptitle(title)
#     plt.tight_layout()
#     plt.show()

# # --- Main Execution ---
# def run_gradcam(
#     dataset_path,
#     checkpoint_path,
#     sample_idx=0,
#     target_class=1
# ):
#     datalist = load_datalist(dataset_path)
#     dataset = Dataset(data=datalist, transform=get_transforms())
#     loader = DataLoader(dataset, batch_size=1)

#     model = build_model()
#     model, device = load_weights(model, checkpoint_path)

#     # Get sample
#     sample = dataset[sample_idx]
#     image = sample["image"].unsqueeze(0).to(device)

#     # GradCAM
#     gradcam = GradCAM(nn_module=model, target_layers="encoder1.layer.norm3")
#     cam_raw = gradcam(x=image, class_idx=target_class)

#     # Resize and Normalize CAM
#     cam = resize_cam(cam_raw, image.shape[2:])

#     # Visualization
#     input_np = image[0].cpu().numpy()
#     cam_np = cam[0].cpu().numpy()
#     class_names = ["Tumor Core", "Whole Tumor", "Enhancing Tumor"]
#     show_cam_overlay(input_np, cam_np, f"Grad-CAM: {class_names[target_class]}")

#     return input_np, cam_np

# # Example usage
# if __name__ == "__main__":
#     run_gradcam(
#         dataset_path="/kaggle/working/dataset_test.json",
#         checkpoint_path="/kaggle/input/swin_distilled_model/pytorch/default/1/best_distilled_model.pth",
#         sample_idx=17,
#         target_class=0
#     )


In [None]:
# import os
# import json
# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# import pytorch_lightning as pl
# from monai.networks.nets import SwinUNETR
# from monai.data import Dataset, DataLoader
# from monai.transforms import (
#     Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped,
#     Orientationd, Spacingd, NormalizeIntensityd, Resized,
#     MapTransform
# )
# from monai.visualize import GradCAM
# from monai.utils.misc import set_determinism

# # Set determinism for reproducibility
# set_determinism(42)

# # Create model
# def create_model():
#     return SwinUNETR(
#         img_size=(96, 96, 96),
#         in_channels=4,
#         out_channels=3,
#         feature_size=24,
#         use_checkpoint=True,
#     )

# # Load test dataset
# dataset_path = "/kaggle/working/dataset_test.json"
# with open(dataset_path) as f:
#     datalist = json.load(f)["testing"]

# # Limit to 40 samples
# datalist = datalist[:40]

# # Test transforms with Resized
# test_transform = Compose(
#     [
#         LoadImaged(keys=["image", "label"]),
#         EnsureChannelFirstd(keys="image"),
#         EnsureTyped(keys=["image", "label"]),
#         Orientationd(keys=["image", "label"], axcodes="RAS"),
#         Resized(
#             keys=["image", "label"],
#             spatial_size=(96, 96, 96),
#             mode=("trilinear", "nearest"),
#         ),
#         NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
#     ]
# )

# # Create test dataset and dataloader
# test_ds = Dataset(data=datalist, transform=test_transform)
# test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=3, pin_memory=True, persistent_workers=False)

# # Initialize model and load checkpoint
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = create_model().to(device)
# checkpoint_path = "/kaggle/input/swin_distilled_model/pytorch/default/1/best_distilled_model.pth"
# checkpoint = torch.load(checkpoint_path, map_location=device)
# model.load_state_dict(checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint)
# model.eval()

# # Inspect model architecture (optional, for debugging)
# # for name, module in model.named_modules():
# #     print(name)

# # Initialize Grad-CAM with a new target layer
# target_layer = "encoder1"  # Use the entire encoder1 block
# grad_cam = GradCAM(nn_module=model, target_layers=target_layer)

# # Function to visualize Grad-CAM
# def visualize_gradcam(image, cam_result, slice_idx=48, alpha=0.5):
#     # Convert to numpy for visualization
#     image_np = image.cpu().detach().numpy()[0, 0]  # First batch item, first image channel
#     cam_np = cam_result.cpu().detach().numpy()[0, 0]  # First batch item, first CAM channel

#     # Normalize CAM for visualization
#     cam_np = (cam_np - cam_np.min()) / (cam_np.max() - cam_np.min() + 1e-8)

#     # Plot
#     plt.figure(figsize=(12, 4))

#     # Original image slice
#     plt.subplot(1, 3, 1)
#     plt.imshow(image_np[:, :, slice_idx], cmap="gray")
#     plt.title("Original Image")
#     plt.axis("off")

#     # Grad-CAM heatmap
#     plt.subplot(1, 3, 2)
#     plt.imshow(cam_np[:, :, slice_idx], cmap="jet")
#     plt.title("Grad-CAM Heatmap")
#     plt.axis("off")

#     # Overlay
#     plt.subplot(1, 3, 3)
#     plt.imshow(image_np[:, :, slice_idx], cmap="gray")
#     plt.imshow(cam_np[:, :, slice_idx], cmap="jet_r", alpha=alpha)
#     plt.title("Overlay")
#     plt.axis("off")

#     plt.tight_layout()
#     plt.show()

# # Run Grad-CAM on one sample
# for i, batch in enumerate(test_loader):
#     # Get image and label
#     image = batch["image"].to(device)  # Shape: [1, 4, 96, 96, 96]
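#     label = batch["label"].to(device)  # NOTE(review): [1, 3, 96, 96, 96] only holds if labels are converted to 3 channels; this cell's test_transform has no ConvertLabels step and applies EnsureChannelFirstd to "image" only - confirm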

#     # Enable gradient tracking for the input
#     image.requires_grad_(True)

#     # Compute Grad-CAM
#     class_index = 1
#     cam_result = grad_cam(x=image, class_idx=class_index)  # Shape: [1, 1, 96, 96, 96]

#     # Visualize
#     print(f"Visualizing Grad-CAM for sample {i+1}")
#     visualize_gradcam(image, cam_result)

#     # Process only one sample for demonstration
#     break

# print("Grad-CAM visualization complete.")

In [None]:
# import os
# os._exit(00)
