## 1. Environment Setup and Data Acquisition

In [None]:
# Install MONAI and medical imaging dependencies
!pip install -q "monai[low_resource, nibabel, tqdm, ignite]" 
!pip install -q matplotlib

import os
import glob
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from monai.utils import set_determinism
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, 
    Resized, ToTensord, Lambdad, ConcatItemsd
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric
from monai.data import Dataset, DataLoader, decollate_batch

# Fix the global random seed (42) via MONAI's set_determinism so runs are repeatable.
set_determinism(seed=42)
# Prefer the GPU when one is available; the model and every batch below are moved here with `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Extract dataset
# Unpack the BraTS 2021 training archive from the Kaggle input directory into the
# working directory that the data-mapping cell globs over.
# NOTE(review): the tar runs unconditionally, so re-running this cell re-extracts everything.
!mkdir -p /kaggle/working/brats_data
!tar -xf /kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar -C /kaggle/working/brats_data

## 2. Multi-Modal Data Mapping and Pipeline

In [None]:
data_dir = "/kaggle/working/brats_data"

flair_list = sorted(glob.glob(os.path.join(data_dir, "**/*_flair.nii.gz"), recursive=True))
t1_list = sorted(glob.glob(os.path.join(data_dir, "**/*_t1.nii.gz"), recursive=True))
t2_list = sorted(glob.glob(os.path.join(data_dir, "**/*_t2.nii.gz"), recursive=True))
label_list = sorted(glob.glob(os.path.join(data_dir, "**/*_seg.nii.gz"), recursive=True))


def _index_by_case(paths, suffix):
    """Return ``{case_id: path}``, where case_id is the file name with ``suffix`` stripped.

    Pairing modalities by case id (instead of zipping four independently sorted
    lists) keeps FLAIR/T1/T2/label aligned even when some case is missing a file.
    """
    return {os.path.basename(path)[: -len(suffix)]: path for path in paths}


_files_by_key = {
    "flair": _index_by_case(flair_list, "_flair.nii.gz"),
    "t1": _index_by_case(t1_list, "_t1.nii.gz"),
    "t2": _index_by_case(t2_list, "_t2.nii.gz"),
    "label": _index_by_case(label_list, "_seg.nii.gz"),
}
_all_cases = set().union(*_files_by_key.values())
# Sorted case ids; on a complete dataset this matches the sorted-path order used for the train/test split.
_complete_cases = sorted(set.intersection(*(set(files) for files in _files_by_key.values())))
if len(_complete_cases) < len(_all_cases):
    print(f"Skipping {len(_all_cases) - len(_complete_cases)} case(s) with missing modalities")

data_dicts = [
    {key: files[case] for key, files in _files_by_key.items()}
    for case in _complete_cases
]

train_transforms = Compose([
    LoadImaged(keys=["flair", "t1", "t2", "label"]),
    EnsureChannelFirstd(keys=["flair", "t1", "t2", "label"]),
    # Rescale each modality's intensities independently.
    ScaleIntensityd(keys=["flair", "t1", "t2"]),
    # Stack modalities into one 3-channel "image": channel 0=flair, 1=t1, 2=t2.
    ConcatItemsd(keys=["flair", "t1", "t2"], name="image"),
    # Collapse every non-zero label value into a single binary "tumour" class.
    Lambdad(keys="label", func=lambda x: torch.where(x > 0, 1, 0)),
    # The mask must use "nearest": the default "area" mode would blend 0/1 into fractional labels.
    Resized(keys=["image", "label"], spatial_size=(128, 128, 64), mode=("area", "nearest")),
    ToTensord(keys=["image", "label"]),
])

train_ds = Dataset(data=data_dicts[:300], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)

## 3. Model Training and Optimization

In [None]:
# 3D U-Net: 3 input channels (flair, t1, t2) -> 2 output channels (background, tumour).
model = UNet(
    spatial_dims=3,
    in_channels=3,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# Labels are single-channel class indices, so they are one-hot encoded inside the loss.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
max_epochs = 50

for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs}")

    for step, batch_data in enumerate(progress_bar, start=1):
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # Surface the running mean loss so training progress is actually visible.
        progress_bar.set_postfix(loss=f"{epoch_loss / step:.4f}")

    mean_epoch_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}/{max_epochs} - mean Dice loss: {mean_epoch_loss:.4f}")

    # Periodic checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f"unet_brats_multimodal_epoch_{epoch+1}.pth")

torch.save(model.state_dict(), "suvarna_tumor_detector_v1.pth")

## 4. Evaluation and Visualization

In [None]:
def generate_clinical_report(patient_idx=305, slice_idx=None):
    """Plot FLAIR, T2, ground truth and the model's prediction for one slice of a case.

    The case is loaded from ``data_dicts`` through ``train_transforms`` and
    segmented with the current ``model`` in eval mode.

    Args:
        patient_idx: Index into ``data_dicts`` of the case to visualise.
        slice_idx: Index along the last spatial axis to display. Defaults to the
            middle slice of the transformed volume (32 for a depth of 64).

    Raises:
        IndexError: if ``patient_idx`` or ``slice_idx`` is out of range.
    """
    # An out-of-range index would otherwise give an empty dataset and an opaque StopIteration.
    if not 0 <= patient_idx < len(data_dicts):
        raise IndexError(
            f"patient_idx {patient_idx} out of range for {len(data_dicts)} cases"
        )

    model.eval()
    with torch.no_grad():
        single_ds = Dataset(data=[data_dicts[patient_idx]], transform=train_transforms)
        batch = next(iter(DataLoader(single_ds, batch_size=1)))
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        output = model(inputs)

    prediction = torch.argmax(output, dim=1)[0].cpu().numpy()
    input_data = inputs[0].cpu().numpy()
    ground_truth = labels[0, 0].cpu().numpy()

    depth = input_data.shape[-1]
    if slice_idx is None:
        slice_idx = depth // 2
    if not 0 <= slice_idx < depth:
        raise IndexError(f"slice_idx {slice_idx} out of range for depth {depth}")

    # Image channels follow the ConcatItemsd order: 0=flair, 1=t1, 2=t2.
    panels = [
        ("FLAIR", input_data[0], "gray"),
        ("T2", input_data[2], "gray"),
        ("Ground Truth", ground_truth, "Reds"),
        ("AI Prediction", prediction, "Greens"),
    ]
    plt.figure(figsize=(15, 8))
    for position, (title, volume, cmap) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        plt.imshow(volume[:, :, slice_idx], cmap=cmap)
    plt.show()

# Run Evaluation Metrics
def _binary_onehot(mask):
    """Turn a (B, 1, ...) binary mask into a (B, 2, ...) float one-hot [background, tumour].

    Values above 0.5 count as tumour. DiceMetric expects one-hot channels; with a
    single channel, ``include_background=False`` is ignored or special-cased
    depending on the MONAI version, so the metric's meaning would be ambiguous.
    """
    foreground = mask > 0.5
    return torch.cat([~foreground, foreground], dim=1).float()


# Tumour-only Dice (background channel excluded), averaged over held-out cases 300-349.
dice_metric = DiceMetric(include_background=False, reduction="mean")
test_loader = DataLoader(Dataset(data=data_dicts[300:350], transform=train_transforms), batch_size=1)

model.eval()
with torch.no_grad():
    for test_data in test_loader:
        outputs = model(test_data["image"].to(device))
        preds = torch.argmax(outputs, dim=1, keepdim=True)
        labels = test_data["label"].to(device)
        dice_metric(y_pred=_binary_onehot(preds), y=_binary_onehot(labels))

final_dice = dice_metric.aggregate().item()
dice_metric.reset()
print(f"Final Dice Score: {final_dice:.4f}")
# Sample visualization
generate_clinical_report(patient_idx=440)