### 1: Setup and Hyperparameter Preview

We load the configs to ensure the notebook matches the eventual script-based pipeline.

In [1]:
import torch
import matplotlib.pyplot as plt
from monai.data import DataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.visualize import plot_2d_or_3d_image
from tqdm import tqdm

import sys
sys.path.append("..")
from config.global_config import GlobalConfig
from config.model_config import ModelConfig
from config.transform_config import TransformConfig
from src.model_factory import ModelFactory

# Build the torch device from the name configured in GlobalConfig.DEVICE.
device = torch.device(GlobalConfig.DEVICE)
print(f"🚀 Using device: {device}")

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


NameError: name 'ModelConfig' is not defined

### 2: Mock Dataset Creation

For prototyping, we don't want to load 100GB of data. We select 4-8 samples to verify the gradient flow and transform logic.

In [None]:
import glob
import warnings

# Load paths from the processed directory. Images and labels are paired purely
# by their position in the two sorted lists, so the lists must line up 1:1.
images = sorted(glob.glob(str(GlobalConfig.PROCESSED_DATA_DIR / "*_img.nii.gz")))
labels = sorted(glob.glob(str(GlobalConfig.PROCESSED_DATA_DIR / "*_seg.nii.gz")))

# zip() silently truncates to the shorter list, which would hide a missing file
# and could pair images with the wrong segmentation -- fail loudly instead.
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} labels in "
        f"{GlobalConfig.PROCESSED_DATA_DIR}; every image needs a matching label."
    )
if not images:
    raise FileNotFoundError(
        f"No '*_img.nii.gz' files found in {GlobalConfig.PROCESSED_DATA_DIR}."
    )

data_dicts = [
    {"image": image_path, "label": label_path}
    for image_path, label_path in zip(images, labels)
]

# Split for prototype (first 4 samples train, next 2 validate)
train_files, val_files = data_dicts[:4], data_dicts[4:6]
if not val_files:
    # Slicing never raises, so a small dataset would otherwise yield an empty
    # validation loader without any hint.
    warnings.warn(
        f"Only {len(data_dicts)} samples available; the validation split is empty."
    )

train_ds = Dataset(data=train_files, transform=TransformConfig.get_train_transforms())
train_loader = DataLoader(train_ds, batch_size=ModelConfig.BATCH_SIZE, shuffle=True)

val_ds = Dataset(data=val_files, transform=TransformConfig.get_val_transforms())
val_loader = DataLoader(val_ds, batch_size=1)

print(f"📦 Prototype ready: {len(train_ds)} train, {len(val_ds)} val samples.")

### 3: Sanity Check - Visualizing Transforms
Before training, we must visualize the 3D patches. If the "organ" is missing from the random crop, the model won't learn.

# Pull a single transformed training sample through a dedicated loader so the
# patch that the training transforms produce can be inspected by eye.
check_ds = Dataset(
    data=train_files,
    transform=TransformConfig.get_train_transforms(),
)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = next(iter(check_loader))

# Index away the two leading dimensions of the batched tensors.
image = check_data["image"][0][0]
label = check_data["label"][0][0]
print(f"Image shape: {image.shape}, Label shape: {label.shape}")

# Middle index along the third axis of each patch.
image_mid = image.shape[2] // 2
label_mid = label.shape[2] // 2

# Side-by-side view: image patch on the left, label patch on the right.
plt.figure("check", (12, 6))

plt.subplot(1, 2, 1)
plt.title("Image Patch (Center Slice)")
plt.imshow(image[:, :, image_mid], cmap="gray")

plt.subplot(1, 2, 2)
plt.title("Label Patch (Center Slice)")
plt.imshow(label[:, :, label_mid])

plt.show()

### 4: Model Initialization
We instantiate the Swin-UNETR. This is a transformer-based architecture that uses shifted windows for 3D context.

from monai.losses import DiceFocalLoss

# Build the architecture named by ModelConfig.MODEL_NAME and move it to the device.
model = ModelFactory.get_model(ModelConfig.MODEL_NAME).to(device)

# Loss and Optimizer
# `to_onehot_y` is a boolean switch ("convert the target to one-hot"), not a
# class count: the number of classes is taken from the prediction's channel
# dimension. Passing OUT_CHANNELS here only worked because the int was truthy.
loss_function = DiceFocalLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=ModelConfig.LEARNING_RATE)

param_count = sum(p.numel() for p in model.parameters())
print(f"🏗️ {ModelConfig.MODEL_NAME} initialized with {param_count:,} parameters.")

### 5: The "Overfit" Test
In prototyping, the goal is to see whether the model can overfit a tiny training set (here, the 4 prototype training samples). If it cannot drive the loss to near zero on such a small set, there is a bug in the code or the transforms.

max_epochs = 50
epoch_loss_values = []  # mean training loss per epoch, plotted below
num_batches = len(train_loader)  # loop-invariant, so computed once

model.train()
for epoch in range(max_epochs):
    epoch_loss = 0.0
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        # Named `targets` rather than `labels` so the notebook-level `labels`
        # list of segmentation file paths is not overwritten by a tensor batch.
        targets = batch_data["label"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        # .item() extracts a Python float, so no autograd graph is kept alive.
        epoch_loss += loss.item()

    # Average over batches so the curve is comparable across batch sizes.
    epoch_loss /= num_batches
    epoch_loss_values.append(epoch_loss)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{max_epochs} | Loss: {epoch_loss:.4f}")

# A healthy overfit run shows this curve falling towards zero.
plt.plot(epoch_loss_values)
plt.title("Prototype Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

### 6: Prototype Inference (Sliding Window)
We use `sliding_window_inference` because the full volume won't fit in VRAM. This tests the logic that will eventually live in `src/inference.py`.

# Sliding-window settings do not change between volumes, so set them once.
roi_size = ModelConfig.ROI_SIZE
sw_batch_size = 4

model.eval()
with torch.no_grad():
    for val_data in val_loader:
        val_inputs = val_data["image"].to(device)

        val_outputs = sliding_window_inference(
            val_inputs, roi_size, sw_batch_size, model, overlap=ModelConfig.OVERLAP
        )

        # Collapse the channel dimension to a per-voxel class index.
        val_pred = torch.argmax(val_outputs, dim=1).detach().cpu()

        # Show the middle slice of the last axis. A hard-coded index of 64
        # raises IndexError on volumes with fewer than 65 slices and is only
        # the center when the depth happens to be 128.
        slice_idx = val_inputs.shape[-1] // 2

        # Visualize the result
        plt.figure("Inference", (18, 6))

        plt.subplot(1, 3, 1)
        plt.title("Input")
        plt.imshow(val_inputs.cpu()[0, 0, :, :, slice_idx], cmap="gray")

        plt.subplot(1, 3, 2)
        plt.title("Ground Truth")
        plt.imshow(val_data["label"][0, 0, :, :, slice_idx])

        plt.subplot(1, 3, 3)
        plt.title("Prediction")
        plt.imshow(val_pred[0, :, :, slice_idx])

        plt.show()
        break  # Just show one for the prototype