In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print every file path found anywhere under the Kaggle input directory.
input_file_paths = (
    os.path.join(root_dir, file_name)
    for root_dir, _, file_names in os.walk('/kaggle/input')
    for file_name in file_names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install -q segmentation-models-pytorch monai nnunet

import os
import time
from glob import glob
from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import albumentations as A
from scipy.ndimage.morphology import binary_dilation
from sklearn.model_selection import train_test_split

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

# MONAI imports for 3D U-Net and UNETR
from monai.networks.nets import UNet, UNETR
from monai.losses import DiceLoss

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
files_dir = '/kaggle/input/lgg-mri-segmentation/lgg-mri-segmentation/kaggle_3m/'
# Image slices only: the `*[0-9].tif` pattern does not match `*_mask.tif` files.
# glob() returns paths in arbitrary filesystem order; sort them so the DataFrame
# row order and the seeded train/test split below are reproducible across machines.
file_paths = sorted(glob(os.path.join(files_dir, '*', '*[0-9].tif')))
print(f"Total files found: {len(file_paths)}")

In [None]:
def get_file_row(path):
    """Build one DataFrame row for an image slice path.

    Returns ``[patient_id, path, mask_path]`` where ``patient_id`` is the first
    three ``_``-separated tokens of the file name and ``mask_path`` is ``path``
    with ``_mask`` inserted before the extension.
    """
    stem, extension = os.path.splitext(path)
    name_tokens = os.path.basename(path).split('_')
    patient = '_'.join(name_tokens[:3])
    mask_path = f'{stem}_mask{extension}'
    return [patient, path, mask_path]

In [None]:
# One row per 2D image slice: patient id, image path, and matching mask path.
filenames_df = pd.DataFrame(
    [get_file_row(path) for path in file_paths],
    columns=['Patient', 'image_filename', 'mask_filename'],
)
# Rows are slices, not patients, so report both counts explicitly.
print(f"Total slice records: {len(filenames_df)} ({filenames_df['Patient'].nunique()} patients)")
print(filenames_df.head())

In [None]:
class MriDataset(Dataset):
    """2D slice dataset: one (image, mask) pair per row of ``df``.

    ``df`` must provide ``image_filename`` and ``mask_filename`` columns; both
    files are read as single-channel grayscale with ``cv2.imread``. If
    ``transform`` is given it is called as ``transform(image=..., mask=...)``
    (albumentations style) and must return a dict with ``'image'`` and
    ``'mask'`` keys.
    """

    def __init__(self, df, transform=None):
        super(MriDataset, self).__init__()
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx, raw=False):
        """Return ``(image, mask)`` for row ``idx`` (positional).

        With ``raw=True`` the untransformed arrays from ``cv2.imread`` are
        returned. Otherwise the image is converted with
        ``T.functional.to_tensor`` (adds a leading channel dim) and the mask is
        integer-divided by 255 and returned as a float ``torch.Tensor``
        (assumes masks are stored as 0/255).

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        row = self.df.iloc[idx]
        img = cv2.imread(row['image_filename'], cv2.IMREAD_GRAYSCALE)  # Load as grayscale
        mask = cv2.imread(row['mask_filename'], cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals an unreadable/missing file by returning None.
        if img is None or mask is None:
            raise FileNotFoundError(
                f"Could not read {row['image_filename']!r} or {row['mask_filename']!r}"
            )

        if raw:
            return img, mask

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            # Use the augmented image too: geometric augmentations (e.g. Rotate)
            # must be applied identically to image and mask to keep them aligned.
            img, mask = augmented['image'], augmented['mask']

        img = T.functional.to_tensor(img)  # Converts (H, W) to (1, H, W)
        mask = torch.Tensor(mask // 255)
        return img, mask

In [None]:
class MriDataset3D(Dataset):
    """Patient-level dataset: stacks all slices of one patient into a volume.

    ``df`` must provide ``Patient``, ``image_filename`` and ``mask_filename``
    columns. Each item is a ``(images, masks)`` pair of float tensors shaped
    ``(1, max_depth, H, W)``. Volumes shorter than ``max_depth`` are zero-padded
    at the end; longer ones are center-cropped along depth. ``transform`` is
    stored but not applied by this implementation.
    """

    def __init__(self, df, transform=None, max_depth=40):
        super(MriDataset3D, self).__init__()
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.max_depth = max_depth
        # groupby(...).groups keeps the incoming row order, which need not be
        # slice order (e.g. file-listing order puts ..._10 before ..._2), so
        # sort each patient's rows by the numeric slice suffix.
        self.patients = {
            patient_id: sorted(
                rows, key=lambda i: self._slice_sort_key(self.df.at[i, 'image_filename'])
            )
            for patient_id, rows in self.df.groupby('Patient').groups.items()
        }
        self.patient_ids = list(self.patients.keys())

    @staticmethod
    def _slice_sort_key(path):
        """Sort key: numeric suffix after the last '_' in the file stem (``..._12.tif`` -> 12)."""
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rsplit('_', 1)[-1]
        # Numeric suffixes sort numerically; anything else sorts after, by name.
        return (0, int(suffix), stem) if suffix.isdigit() else (1, 0, stem)

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        indices = list(self.patients[patient_id])

        # Load all slices for a patient, in slice order
        images = []
        masks = []
        for i in indices:
            row = self.df.loc[i]
            img = cv2.imread(row['image_filename'], cv2.IMREAD_GRAYSCALE)
            mask = cv2.imread(row['mask_filename'], cv2.IMREAD_GRAYSCALE)

            # Unreadable slices are skipped rather than aborting the volume.
            if img is None or mask is None:
                continue

            images.append(img)
            masks.append(mask // 255)  # assumes 0/255 masks

        if len(images) == 0:
            raise ValueError(f"No valid images found for patient {patient_id}")

        # Stack to create 3D volume (D, H, W)
        images = np.stack(images, axis=0)
        masks = np.stack(masks, axis=0)

        # Pad or crop to fixed depth (max_depth)
        current_depth = images.shape[0]

        if current_depth < self.max_depth:
            # Pad with zeros at the end
            pad_size = self.max_depth - current_depth
            images = np.pad(images, ((0, pad_size), (0, 0), (0, 0)), mode='constant', constant_values=0)
            masks = np.pad(masks, ((0, pad_size), (0, 0), (0, 0)), mode='constant', constant_values=0)
        elif current_depth > self.max_depth:
            # Crop to max_depth (take middle slices)
            start = (current_depth - self.max_depth) // 2
            images = images[start:start+self.max_depth]
            masks = masks[start:start+self.max_depth]

        # Convert to (C, D, H, W) format for MONAI
        images = torch.Tensor(images).unsqueeze(0).float()
        masks = torch.Tensor(masks).unsqueeze(0).float()

        return images, masks

In [None]:
# Split at the PATIENT level. Each patient contributes many slices; a slice-level
# split puts neighbouring slices of the same patient into train and test
# (leakage) and fragments every 3D volume across all three splits.
# Sorting the ids makes the seeded split independent of file-listing order.
patient_ids = np.sort(filenames_df['Patient'].unique())
train_ids, holdout_ids = train_test_split(patient_ids, test_size=0.3, random_state=42)
test_ids, valid_ids = train_test_split(holdout_ids, test_size=0.5, random_state=42)

train_df = filenames_df[filenames_df['Patient'].isin(train_ids)]
valid_df = filenames_df[filenames_df['Patient'].isin(valid_ids)]
test_df = filenames_df[filenames_df['Patient'].isin(test_ids)]

# Each patient (and so each slice) belongs to exactly one split.
assert len(train_df) + len(valid_df) + len(test_df) == len(filenames_df)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(valid_df)}")
print(f"Test samples: {len(test_df)}")

In [None]:
# 2D datasets for the 2D U-Net; augmentation is applied to the training split only.
transform = A.Compose([
    A.RandomBrightnessContrast(p=0.3),
    A.GaussNoise(p=0.2),
    A.Rotate(limit=10, p=0.3),
])

train_dataset_2d = MriDataset(train_df, transform)
valid_dataset_2d, test_dataset_2d = (MriDataset(split_df) for split_df in (valid_df, test_df))

train_loader_2d = DataLoader(train_dataset_2d, batch_size=16, shuffle=True)
valid_loader_2d = DataLoader(valid_dataset_2d, batch_size=16, shuffle=False)
test_loader_2d = DataLoader(test_dataset_2d, batch_size=1)

# 3D datasets: every patient volume is padded/cropped to a fixed depth.
depth_3d = 40
train_dataset_3d, valid_dataset_3d, test_dataset_3d = (
    MriDataset3D(split_df, max_depth=depth_3d) for split_df in (train_df, valid_df, test_df)
)

loader_kwargs_3d = {'batch_size': 2}
train_loader_3d = DataLoader(train_dataset_3d, shuffle=True, **loader_kwargs_3d)
valid_loader_3d = DataLoader(valid_dataset_3d, shuffle=False, **loader_kwargs_3d)
test_loader_3d = DataLoader(test_dataset_3d, batch_size=1)

print("2D Dataloaders created successfully!")
print("3D Dataloaders created successfully (padded to 40 slices)!")

In [None]:
class EarlyStopping():
    """Stop training once the validation loss stops improving.

    Every call that improves the best loss by more than ``min_delta`` saves the
    model's ``state_dict`` to ``weights_path`` and resets the counter; every
    other call increments it. A call returns True once ``patience``
    consecutive non-improving calls have been seen.
    """

    def __init__(self, patience: int = 6, min_delta: float = 0, weights_path: str = 'weights.pt'):
        self.weights_path = weights_path
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss: float, model: torch.nn.Module):
        improved = (self.best_loss - val_loss) > self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience
        # New best: checkpoint it and restart the patience window.
        self.best_loss = val_loss
        torch.save(model.state_dict(), self.weights_path)
        self.counter = 0
        return False

    def load_weights(self, model: torch.nn.Module):
        """Load the last checkpoint saved by this tracker into ``model``."""
        checkpoint = torch.load(self.weights_path)
        return model.load_state_dict(checkpoint)

In [None]:
def iou_pytorch(predictions: torch.Tensor, labels: torch.Tensor, e: float = 1e-7, threshold: float = 0.5):
    """Per-sample intersection-over-union of thresholded predictions vs. labels.

    Args:
        predictions: scores with a leading batch dim, e.g. (B, H, W) or
            (B, C, D, H, W); values > ``threshold`` count as foreground.
            Expects probabilities for the default 0.5 threshold.
        labels: same shape as ``predictions``; cast with ``.byte()``.
        e: smoothing term; an empty prediction vs. empty label scores 1.
        threshold: binarization cut applied to ``predictions``.

    Returns:
        Float tensor of shape (B,): one IoU per sample.
    """
    predictions = torch.where(predictions > threshold, 1, 0)
    labels = labels.byte()

    # Reduce over every non-batch dim so 2D slices and 3D volumes both yield
    # one score per sample (a fixed (1, 2) only worked for (B, H, W) inputs).
    dims = tuple(range(1, predictions.ndim))
    intersection = (predictions & labels).float().sum(dims)
    union = (predictions | labels).float().sum(dims)

    return (intersection + e) / (union + e)

def dice_pytorch(predictions: torch.Tensor, labels: torch.Tensor, e: float = 1e-7, threshold: float = 0.5):
    """Per-sample Dice coefficient of thresholded predictions vs. labels.

    Args:
        predictions: scores with a leading batch dim, e.g. (B, H, W) or
            (B, C, D, H, W); values > ``threshold`` count as foreground.
            Expects probabilities for the default 0.5 threshold.
        labels: same shape as ``predictions``; cast with ``.byte()``.
        e: smoothing term; an empty prediction vs. empty label scores 1.
        threshold: binarization cut applied to ``predictions``.

    Returns:
        Float tensor of shape (B,): one Dice score per sample. Not
        differentiable (the thresholding has no gradient).
    """
    predictions = torch.where(predictions > threshold, 1, 0)
    labels = labels.byte()

    # Reduce over every non-batch dim so 2D slices and 3D volumes both yield
    # one score per sample.
    dims = tuple(range(1, predictions.ndim))
    intersection = (predictions & labels).float().sum(dims)
    denominator = predictions.float().sum(dims) + labels.float().sum(dims)
    return ((2 * intersection) + e) / (denominator + e)

In [None]:
def BCE_dice(output, target, alpha=0.01, e: float = 1e-7):
    """Binary cross-entropy plus ``alpha`` times a differentiable soft-Dice loss.

    Args:
        output: probabilities in [0, 1] (``binary_cross_entropy`` requires
            them), with a leading batch dim.
        target: same shape as ``output``, values in {0, 1}.
        alpha: weight of the soft-Dice term.
        e: smoothing term for the Dice ratio.

    Returns:
        Scalar loss tensor.
    """
    bce = torch.nn.functional.binary_cross_entropy(output, target)
    # Soft Dice on raw probabilities (no thresholding), so it contributes a
    # gradient; a thresholded Dice is piecewise constant and gives none.
    dims = tuple(range(1, output.ndim))
    intersection = (output * target).sum(dims)
    denominator = output.sum(dims) + target.sum(dims)
    soft_dice = 1 - ((2 * intersection + e) / (denominator + e)).mean()
    return bce + alpha * soft_dice

In [None]:
def _iter_device_pairs(data, is_3d):
    """Yield ``(images, masks)`` moved to ``device`` from one loader batch.

    With ``is_3d=False`` the collated batch is yielded as one pair. With
    ``is_3d=True`` the batch is walked along its leading dim and each sample is
    re-batched with a leading dim of 1, so volumes are processed one at a time.
    """
    images, masks = data
    if not is_3d:
        yield images.to(device), masks.to(device)
        return
    for image, mask in zip(images, masks):
        yield image.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)


def _predict(model, images):
    """Forward pass; squeezes dim 1 of 4D outputs so (B, 1, H, W) matches (B, H, W) 2D masks."""
    predictions = model(images)
    if predictions.ndim == 4:
        predictions = predictions.squeeze(1)
    return predictions


def training_loop(epochs, model, train_loader, valid_loader, optimizer, loss_fn, lr_scheduler,
                  is_3d=False, outputs_are_logits=True):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Args:
        epochs: maximum number of epochs.
        model, optimizer, loss_fn, lr_scheduler: training components;
            ``lr_scheduler.step`` is called with the validation loss.
        train_loader, valid_loader: yield ``(images, masks)`` batches.
        is_3d: if True, each batch is processed one sample at a time
            (one optimizer step per sample).
        outputs_are_logits: if True (the default, matching
            ``BCEWithLogitsLoss``), outputs go through ``torch.sigmoid`` before
            the 0.5-thresholded Dice/IoU metrics.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``val_IoU``,
        ``val_dice``. On return ``model`` holds the best-validation-loss weights
        (when at least one checkpoint was saved) and is in eval mode.
    """
    history = {'train_loss': [], 'val_loss': [], 'val_IoU': [], 'val_dice': []}
    early_stopping = EarlyStopping(patience=7)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        train_samples = 0

        for data in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            for images, masks in _iter_device_pairs(data, is_3d):
                optimizer.zero_grad()
                predictions = _predict(model, images)
                loss = loss_fn(predictions, masks)
                loss.backward()
                optimizer.step()

                batch_size = images.size(0)
                running_loss += loss.item() * batch_size
                train_samples += batch_size

        # Validation
        model.eval()
        running_IoU = 0.0
        running_dice = 0.0
        running_valid_loss = 0.0
        valid_samples = 0
        with torch.no_grad():
            for data in valid_loader:
                for images, masks in _iter_device_pairs(data, is_3d):
                    predictions = _predict(model, images)
                    loss = loss_fn(predictions, masks)

                    # The metrics threshold at 0.5, which is only meaningful on
                    # probabilities: sigmoid(x) > 0.5 <=> x > 0 for logits.
                    probs = torch.sigmoid(predictions) if outputs_are_logits else predictions
                    # assumes dice_pytorch/iou_pytorch return one score per sample
                    running_dice += dice_pytorch(probs, masks).sum().item()
                    running_IoU += iou_pytorch(probs, masks).sum().item()

                    batch_size = images.size(0)
                    running_valid_loss += loss.item() * batch_size
                    valid_samples += batch_size

        train_loss = running_loss / max(1, train_samples)
        val_loss = running_valid_loss / max(1, valid_samples)
        val_dice = running_dice / max(1, valid_samples)
        val_IoU = running_IoU / max(1, valid_samples)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_IoU'].append(val_IoU)
        history['val_dice'].append(val_dice)

        print(f'Epoch {epoch}/{epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | Val Dice: {val_dice:.6f} | Val IoU: {val_IoU:.6f}')

        lr_scheduler.step(val_loss)
        if early_stopping(val_loss, model):
            print(f"Early stopping at epoch {epoch}")
            break

    # Restore the best checkpoint whether or not early stopping fired; without
    # this a full-length run would end on last-epoch weights. The guard skips
    # loading when this run never saved (e.g. every val_loss was NaN), which
    # would otherwise pick up a stale file from a previous model.
    if early_stopping.best_loss < float('inf'):
        early_stopping.load_weights(model)

    model.eval()
    return history

In [None]:
def evaluate_model(model, test_loader, device, model_name, is_3d=False, outputs_are_logits=True):
    """Compute mean/std of per-sample Dice and IoU for ``model`` on ``test_loader``.

    Args:
        model: segmentation network; put into eval mode.
        test_loader: yields ``(images, masks)`` batches.
        device: device the inputs are moved to.
        model_name: label used in the progress bar and the result dict.
        is_3d: if True, each batch is scored one sample at a time.
        outputs_are_logits: if True (default, matching ``BCEWithLogitsLoss``),
            outputs go through ``torch.sigmoid`` before the 0.5 threshold.

    Returns:
        dict with keys ``model``, ``dice``, ``dice_std``, ``iou``, ``iou_std``.
        Scores are NaN when the loader yields no samples.
    """
    model.eval()
    all_dice = []
    all_iou = []

    def _accumulate(predictions, masks):
        # The 0.5 cut is only meaningful on probabilities: sigmoid(x) > 0.5 <=> x > 0.
        probs = torch.sigmoid(predictions) if outputs_are_logits else predictions
        pred_binary = (probs > 0.5).float()
        all_dice.extend(dice_pytorch(pred_binary, masks).cpu().numpy().flatten())
        all_iou.extend(iou_pytorch(pred_binary, masks).cpu().numpy().flatten())

    with torch.no_grad():
        for data in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            images, masks = data
            if is_3d:
                # Score one volume at a time, re-batched with a leading dim of 1.
                for image, mask in zip(images, masks):
                    image = image.unsqueeze(0).to(device)
                    mask = mask.unsqueeze(0).to(device)
                    _accumulate(model(image), mask)
            else:
                images = images.to(device)
                masks = masks.to(device)
                predictions = model(images)
                if predictions.shape != masks.shape:
                    predictions = predictions.squeeze(1)
                _accumulate(predictions, masks)

    def _mean_std(values):
        # Avoid np.mean([]) RuntimeWarnings on an empty loader.
        if not values:
            return float('nan'), float('nan')
        return np.mean(values), np.std(values)

    dice_mean, dice_std = _mean_std(all_dice)
    iou_mean, iou_std = _mean_std(all_iou)
    return {
        'model': model_name,
        'dice': dice_mean,
        'dice_std': dice_std,
        'iou': iou_mean,
        'iou_std': iou_std,
    }

In [None]:

print("TRAINING MODEL 1: 2D U-Net (MONAI)")

# Residual 2D U-Net: single-channel input and output, four channel levels.
unet_2d_config = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
model_unet_2d = UNet(**unet_2d_config).to(device)

optimizer_2d = Adam(model_unet_2d.parameters(), lr=0.001)
# BCEWithLogitsLoss applies the sigmoid itself, so model outputs are treated as logits.
loss_fn_2d = torch.nn.BCEWithLogitsLoss()
# Halve the learning rate when the validation loss plateaus (patience=3 epochs).
lr_scheduler_2d = ReduceLROnPlateau(optimizer=optimizer_2d, patience=3, factor=0.5)

history_unet_2d = training_loop(
    epochs=20,
    model=model_unet_2d,
    train_loader=train_loader_2d,
    valid_loader=valid_loader_2d,
    optimizer=optimizer_2d,
    loss_fn=loss_fn_2d,
    lr_scheduler=lr_scheduler_2d,
    is_3d=False,
)

print("\n 2D U-Net Training Complete!\n")

In [None]:

print("TRAINING MODEL 2: 3D U-Net (MONAI) - True 3D")

# Same channel/stride layout as the 2D U-Net, but with volumetric convolutions.
unet_3d_config = dict(
    spatial_dims=3,  # True 3D
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
model_unet_3d = UNet(**unet_3d_config).to(device)

optimizer_3d = Adam(model_unet_3d.parameters(), lr=0.001)
loss_fn_3d = torch.nn.BCEWithLogitsLoss()
lr_scheduler_3d = ReduceLROnPlateau(optimizer=optimizer_3d, patience=3, factor=0.5)

# MriDataset3D pads/crops every volume to the same depth, so the default collate
# stacks whole (B, 1, D, H, W) batches and the batched (is_3d=False) path applies.
history_unet_3d = training_loop(
    epochs=20,
    model=model_unet_3d,
    train_loader=train_loader_3d,
    valid_loader=valid_loader_3d,
    optimizer=optimizer_3d,
    loss_fn=loss_fn_3d,
    lr_scheduler=lr_scheduler_3d,
    is_3d=False,
)

print("\n✓ 3D U-Net Training Complete!\n")

In [None]:

print("TRAINING MODEL 3: UNETR (Transformer-based, 3D)")

# MONAI's UNETR embeds the input as 16x16x16 patches and its decoder upsamples
# the patch grid back by x16 before concatenating with full-resolution encoder
# features, so every img_size dim must be a multiple of 16. The shared 3D
# loaders use 40 slices (40 % 16 != 0), which breaks that concatenation, so
# UNETR gets its own loaders padded/center-cropped to 48 slices.
UNETR_DEPTH = 48
assert UNETR_DEPTH % 16 == 0

train_loader_unetr = DataLoader(MriDataset3D(train_df, max_depth=UNETR_DEPTH), batch_size=2, shuffle=True)
valid_loader_unetr = DataLoader(MriDataset3D(valid_df, max_depth=UNETR_DEPTH), batch_size=2, shuffle=False)
# Evaluate UNETR with this loader: the 40-slice test_loader_3d has the same shape problem.
test_loader_unetr = DataLoader(MriDataset3D(test_df, max_depth=UNETR_DEPTH), batch_size=1)

model_unetr_3d = UNETR(
    in_channels=1,
    out_channels=1,
    img_size=(UNETR_DEPTH, 256, 256),  # 3D volume size (D, H, W); each dim a multiple of 16
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    norm_name="instance",
    conv_block=True,
    res_block=True,
    spatial_dims=3,  # True 3D
)
model_unetr_3d.to(device)

optimizer_unetr = Adam(model_unetr_3d.parameters(), lr=0.0005)
loss_fn_unetr = torch.nn.BCEWithLogitsLoss()
lr_scheduler_unetr = ReduceLROnPlateau(optimizer=optimizer_unetr, patience=3, factor=0.5)

history_unetr_3d = training_loop(20, model_unetr_3d, train_loader_unetr, valid_loader_unetr,
                                 optimizer_unetr, loss_fn_unetr, lr_scheduler_unetr, is_3d=False)

print("\n✓ UNETR 3D Training Complete!\n")

In [None]:
print("=" * 60)
print("TRAINING MODEL 4: nnU-Net ")
print("=" * 60)


# Simplified nnU-Net-inspired baseline: a plain MONAI 3D U-Net with wider
# channels, dropout and weight decay. It does not reproduce nnU-Net's
# self-configuring pipeline, and NO augmentation is applied here: the 3D
# datasets (train_dataset_3d etc.) are built without a transform.
model_nnunet = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 128, 256),  # 2x wider than the 3D U-Net above; same number of levels
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=0.2,  # regularization choice for this baseline
)
model_nnunet.to(device)

# Adam with a small nnU-Net-style weight decay.
optimizer_nnunet = Adam(model_nnunet.parameters(), lr=0.001, weight_decay=3e-5)
loss_fn_nnunet = torch.nn.BCEWithLogitsLoss()
lr_scheduler_nnunet = ReduceLROnPlateau(optimizer=optimizer_nnunet, patience=3, factor=0.5)

history_nnunet = training_loop(20, model_nnunet, train_loader_3d, valid_loader_3d,
                               optimizer_nnunet, loss_fn_nnunet, lr_scheduler_nnunet, is_3d=False)

print("\nnnU-Net Training Complete!\n")