# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.14.5 # Or your version
#   kernelspec:
#     display_name: Python 3 (ipykernel) # Or your kernel name
#     language: python
#     name: python3
# ---

# # BraTS 2023 Semi-Supervised SwinUNETR Demo & Explainability
#
# This notebook demonstrates the results of the trained SSL SwinUNETR model,
# visualizes segmentations, and shows attention maps using Attention Rollout.

# ## 1. Setup and Imports
# Import necessary libraries and configure paths.

# +
import os
import sys
import json
import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import torch
from monai.config import print_config
from monai.data import Dataset, DataLoader, decollate_batch, list_data_collate
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    NormalizeIntensityd,
    # Add other necessary transforms used during validation/inference
)
from monai.visualize import plot_2d_or_3d_image # For visualization

# Add project root to path if necessary (if running notebook from a subfolder)
# module_path = os.path.abspath(os.path.join('..'))
# if module_path not in sys.path:
#     sys.path.append(module_path)

# Import helper functions if defined in separate .py files
# from brats_ssl_train import get_brats2023_datalists # Example

# Print MONAI's configuration report (MONAI, PyTorch and optional-dependency
# versions) so the notebook output records the software environment used.
print_config()
# -

# ## 2. Configuration
# Define paths to data, model checkpoint, and other parameters.

# +
# ---------------------------------------------------------------------------
# User-editable parameters
# ---------------------------------------------------------------------------
# Filesystem locations.
DATA_DIR = "./data/brats2023"  # root folder of the BraTS 2023 dataset
MAPPING_FILE = "BraTS2023_2017_GLI_Mapping.xlsx - Sheet1.csv"  # case-ID mapping table (not referenced by the cells below)
MODEL_CHECKPOINT = "./output_brats_ssl/model_best.pt"  # trained SwinUNETR weights
OUTPUT_DIR = "./output_brats_ssl/notebook_viz"  # figures produced by this notebook go here

# Network geometry -- every value here has to equal the one used for training.
ROI_SIZE = (128, 128, 128)  # sliding-window patch size
IN_CHANNELS = 4  # input modalities, stacked as t1c, t1n, t2f (FLAIR), t2w
OUT_CHANNELS = 3  # tumor sub-region channels; their order is fixed by the training label transform
FEATURE_SIZE = 48

# Case to demonstrate: use a patient ID from the 'BraTS 2023' column of the
# mapping CSV (ideally one from your validation split).
CASE_ID = "BraTS-GLI-00000-000"  # !!! REPLACE WITH A VALID CASE !!!
# ---------------------------------------------------------------------------

os.makedirs(OUTPUT_DIR, exist_ok=True)
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print(f"Using device: {device}")
# -

# ## 3. Load Model
# Load the trained SwinUNETR model from the checkpoint file.

# +
print(f"Loading model checkpoint from: {MODEL_CHECKPOINT}")
if not os.path.exists(MODEL_CHECKPOINT):
    raise FileNotFoundError(f"Model checkpoint not found at {MODEL_CHECKPOINT}")

# The architecture hyper-parameters must match those of the checkpoint, otherwise
# load_state_dict() below fails with missing/unexpected keys or shape mismatches.
model = SwinUNETR(
    img_size=ROI_SIZE,
    in_channels=IN_CHANNELS,
    out_channels=OUT_CHANNELS,
    feature_size=FEATURE_SIZE,
    use_checkpoint=False,  # gradient checkpointing only saves memory during training
).to(device)


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from each key.

    Models wrapped in ``torch.nn.DataParallel`` / ``DistributedDataParallel`` are
    saved with every parameter name prefixed by ``"module."``; the bare model
    expects the unprefixed names. Keys without the prefix are kept unchanged, and
    an empty mapping yields an empty dict.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# On PyTorch >= 2.6 (weights_only=True by default) a checkpoint that also stores
# non-tensor objects may need weights_only=False.
checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device)
if 'state_dict' in checkpoint:
    # Training-script format: a dict holding 'state_dict' and (optionally) 'epoch'.
    model_state_dict = checkpoint['state_dict']
    load_message = f"Loaded model state_dict from epoch {checkpoint.get('epoch', 'N/A')}"
else:
    # The file is the state_dict itself.
    model_state_dict = checkpoint
    load_message = "Loaded model state_dict directly."

# Applied to both checkpoint formats so DataParallel-saved weights load either way.
model_state_dict = _strip_module_prefix(model_state_dict)
model.load_state_dict(model_state_dict)
print(load_message)

model.eval()  # switch modules to evaluation behaviour (e.g. dropout disabled)
print("Model loaded successfully.")
# -

# ## 4. Prepare Data for Selected Case
# Load the MRI scans and ground truth label for the chosen `CASE_ID`.

# +
# Construct file paths for the selected case.
# Adjust 'training_data_root' if your copy of the dataset is laid out differently.
training_data_root = os.path.join(DATA_DIR, "ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData")
patient_path = os.path.join(training_data_root, CASE_ID)

if not os.path.isdir(patient_path):
    raise FileNotFoundError(f"Directory for case {CASE_ID} not found at {patient_path}")

# Channel order of the stacked input: t1c, t1n, t2f (FLAIR), t2w.
# It must be the order the model was trained with.
image_files = [
    os.path.join(patient_path, f"{CASE_ID}-t1c.nii.gz"),
    os.path.join(patient_path, f"{CASE_ID}-t1n.nii.gz"),
    os.path.join(patient_path, f"{CASE_ID}-t2f.nii.gz"),  # FLAIR
    os.path.join(patient_path, f"{CASE_ID}-t2w.nii.gz"),  # T2
]
label_file = os.path.join(patient_path, f"{CASE_ID}-seg.nii.gz")

# Fail early, listing every missing file, instead of failing inside LoadImaged.
missing_files = [f for f in image_files + [label_file] if not os.path.exists(f)]
if missing_files:
    raise FileNotFoundError(f"Missing files for case {CASE_ID}: {missing_files}")

case_data = [{"image": image_files, "label": label_file}]


def _brats_labels_to_multichannel(label):
    """Convert an integer BraTS label map into three binary channels ``[TC, WT, ET]``.

    MONAI's ``ConvertToMultiChannelBasedOnBratsClasses`` produces the same channel
    order but assumes the BraTS 2018-2021 convention in which enhancing tumor is
    label 4. BraTS 2023 encodes enhancing tumor as label 3 (1 = necrotic tumor core,
    2 = edema), so with the MONAI transform the ET channel would be empty and TC/WT
    would miss the enhancing region. Labels 3 and 4 are both treated as ET here,
    which keeps the output identical to MONAI's transform on older-format labels.

    Args:
        label: Label map of shape (H, W, D) or (1, H, W, D); a torch tensor
            (including MONAI ``MetaTensor``) or a numpy array.

    Returns:
        Boolean tensor/array of shape (3, H, W, D) with channels TC, WT, ET.
    """
    if label.ndim == 4 and label.shape[0] == 1:
        label = label[0]  # drop the singleton channel added by EnsureChannelFirstd
    enhancing = (label == 3) | (label == 4)
    tumor_core = enhancing | (label == 1)
    whole_tumor = tumor_core | (label == 2)
    channels = [tumor_core, whole_tumor, enhancing]
    if isinstance(label, torch.Tensor):
        return torch.stack(channels, 0)
    return np.stack(channels, 0)


# Loading/preprocessing pipeline -- keep this in sync with the validation
# transforms used during training.
inference_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], image_only=False),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Multi-channel [TC, WT, ET] ground truth that understands BraTS 2023 labels.
        transforms.Lambdad(keys="label", func=_brats_labels_to_multichannel),
        EnsureTyped(keys=["image", "label"], dtype=torch.float32),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        # Add Spacingd or Orientationd if they were used during training/validation, e.g.
        # transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
    ]
)

# Apply transforms and add a batch dimension.
transformed_data = inference_transforms(case_data[0])
image_tensor = transformed_data["image"].unsqueeze(0).to(device)  # model input on the inference device
# The label is only compared/plotted on the CPU, so it is not copied to the GPU.
label_tensor = transformed_data["label"].unsqueeze(0)

print(f"Input image shape: {image_tensor.shape}")  # Should be (1, 4, H, W, D)
print(f"Label shape: {label_tensor.shape}")  # Should be (1, 3, H, W, D)
# -

# ## 5. Perform Inference
# Run the loaded model on the prepared data using sliding window inference.

# +
# Sliding-window inferer (same settings as validation): the model only sees
# ROI_SIZE patches and the full volume is stitched from overlapping windows.
model_inferer = partial(
    sliding_window_inference,
    roi_size=ROI_SIZE,
    sw_batch_size=1,  # Can increase if GPU memory allows
    predictor=model,
    overlap=0.6,  # higher overlap -> smoother stitching, at the cost of more windows
    mode="gaussian",  # weight window centres more than window borders when blending
    progress=True,
)

# Post-processing: per-voxel sigmoid, then threshold at 0.5 -> binary masks.
post_pred_viz = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

print("Running inference...")
start_time = time.perf_counter()  # monotonic clock intended for measuring durations
with torch.no_grad():
    logits = model_inferer(image_tensor)
    prediction = post_pred_viz(logits)  # discrete 0/1 prediction, same shape as logits
if device.type == "cuda":
    # CUDA kernels run asynchronously; wait for them so the measured time is real.
    torch.cuda.synchronize(device)
end_time = time.perf_counter()
print(f"Inference complete. Time taken: {end_time - start_time:.2f} seconds")
print(f"Prediction shape: {prediction.shape}")  # Should be (1, 3, H, W, D)

# Drop the batch dimension and move everything to CPU numpy arrays for plotting.
prediction_cpu = prediction.squeeze(0).cpu().numpy()  # Shape (3, H, W, D)
label_cpu = label_tensor.squeeze(0).cpu().numpy()  # Shape (3, H, W, D)
image_cpu = image_tensor.squeeze(0).cpu().numpy()  # Shape (4, H, W, D)
# -

# ## 6. Visualize Segmentation Results
# Display slices of the input (e.g., FLAIR), ground truth label, and model prediction.

# +
# Channel layout of `prediction_cpu` / `label_cpu`.
# MONAI's ConvertToMultiChannelBasedOnBratsClasses stacks the channels as
# [TC, WT, ET]; the model is assumed to have been trained with the same label
# transform -- adjust these indices if your training script used another order.
TC_CHANNEL, WT_CHANNEL, ET_CHANNEL = 0, 1, 2

# Display codes for the combined map (nested regions: ET inside TC inside WT).
ET_CODE, TC_CODE, WT_CODE = 1, 2, 3


def _combine_brats_channels(multichannel):
    """Collapse binary (3, H, W, D) TC/WT/ET masks into one (H, W, D) label map.

    Regions are painted from the outermost (WT) to the innermost (ET), so each
    voxel ends up with the code of the innermost region it belongs to:
    0 = background, 1 = ET, 2 = TC, 3 = WT.
    """
    label_map = np.zeros_like(multichannel[0])
    label_map[multichannel[WT_CHANNEL] == 1] = WT_CODE
    label_map[multichannel[TC_CHANNEL] == 1] = TC_CODE
    label_map[multichannel[ET_CHANNEL] == 1] = ET_CODE
    return label_map


# Axial slice to display: the middle of the last spatial axis.
# (Replace with a slice containing plenty of tumor for a more informative figure.)
slice_idx = prediction_cpu.shape[3] // 2

# Input modality to display: FLAIR (channel order t1c=0, t1n=1, t2f=2, t2w=3).
img_display = image_cpu[2, :, :, slice_idx]

pred_viz = _combine_brats_channels(prediction_cpu)
pred_display = pred_viz[:, :, slice_idx]

gt_viz = _combine_brats_channels(label_cpu)
gt_display = gt_viz[:, :, slice_idx]

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
plt.suptitle(f"Segmentation Results for Case: {CASE_ID} (Slice: {slice_idx})", fontsize=16)

axes[0].imshow(img_display, cmap='gray')
axes[0].set_title("Input FLAIR")
axes[0].axis('off')

# Fixed vmin/vmax so each code gets the same colour in both label panels.
axes[1].imshow(gt_display, cmap='jet', vmin=0, vmax=3)
axes[1].set_title("Ground Truth (1:ET, 2:TC, 3:WT)")
axes[1].axis('off')

axes[2].imshow(pred_display, cmap='jet', vmin=0, vmax=3)
axes[2].set_title("Model Prediction (1:ET, 2:TC, 3:WT)")
axes[2].axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
save_path = os.path.join(OUTPUT_DIR, f"{CASE_ID}_segmentation_slice_{slice_idx}.png")
plt.savefig(save_path)  # save before plt.show(), which may close the figure
print(f"Segmentation visualization saved to {save_path}")
plt.show()

# Note: monai.visualize.plot_2d_or_3d_image(data, step, writer, ...) writes images
# to a TensorBoard SummaryWriter; it does not draw into matplotlib figures/axes.
# -

# ## 7. Attention Rollout Visualization
#
# **--- !!! IMPLEMENTATION REQUIRED !!! ---**
#
# This section requires implementing the Attention Rollout algorithm specifically for the SwinUNETR architecture used. This typically involves:
# 1.  **Hooking into Attention Layers:** Registering forward hooks on the attention layers (specifically the attention matrix computation) within the Swin Transformer blocks of the SwinUNETR model.
# 2.  **Extracting Attention Matrices:** During a forward pass (inference) on the input image, capture the attention matrices from each relevant layer/block.
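# 3.  **Aggregating Attention (Rollout):** Implement the rollout logic. Standard rollout fuses the heads of each layer (e.g. mean or max), adds the identity matrix to account for the residual connection, re-normalizes each row, and multiplies the resulting matrices across layers. Swin Transformers complicate this: attention is computed inside local (shifted) windows, and patch merging changes the token grid between stages, so the per-window matrices must be mapped back to a common token space before they can be multiplied. Refer to the original Attention Rollout paper (Abnar & Zuidema, 2020, "Quantifying Attention Flow in Transformers") and to Swin-specific adaptations.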
# 4.  **Generating Visualization:** Reshape and visualize the final aggregated attention map, often overlaid on the input image slice.

# +
# TODO: replace this stub with a real Attention Rollout implementation.
def generate_attention_rollout(model, input_tensor, device, **kwargs):
    """Compute an Attention Rollout map for a SwinUNETR model (stub).

    This is a placeholder: it performs no computation, prints a notice that it is
    not implemented, and returns ``None`` so that callers skip the visualization.

    Args:
        model: Trained SwinUNETR model.
        input_tensor: Input image batch, presumably (B, C, H, W, D).
        device: Device the computation should run on (CPU/GPU).
        **kwargs: Rollout options (e.g. layer selection, head fusion).

    Returns:
        Intended: the aggregated attention map as an ``np.ndarray`` -- either a
        full volume (H, W, D) or a single slice (H, W). Currently always ``None``.

    Outline of an implementation:
        1. Register forward hooks on the attention layers of the Swin blocks.
        2. Run ``model(input_tensor)`` to capture the attention matrices.
        3. Aggregate the captured matrices with the rollout rule.
        4. Remove the hooks and return the map.
    """
    not_implemented_notice = "--- !!! Attention Rollout function not implemented !!! ---"
    print(not_implemented_notice)
    attention_map = None  # nothing is computed until the steps above exist
    return attention_map

# --- Generate and Visualize Attention Map ---
print("Generating Attention Map (requires implementation)...")
attention_map = generate_attention_rollout(model, image_tensor, device)

if attention_map is not None:
    # Accept a torch tensor (possibly on the GPU) as well as a numpy array.
    if isinstance(attention_map, torch.Tensor):
        attention_map = attention_map.detach().cpu().numpy()
    # Drop singleton batch/channel dimensions, e.g. (1, H, W, D) -> (H, W, D).
    attention_map = np.squeeze(np.asarray(attention_map))
    # The rollout may return a full (H, W, D) volume or a single (H, W) slice;
    # for a volume, show the same axial slice as the segmentation figure.
    if attention_map.ndim == 3:
        attention_display = attention_map[:, :, slice_idx]
    elif attention_map.ndim == 2:
        attention_display = attention_map
    else:
        raise ValueError(
            f"Expected a 2D or 3D attention map, got shape {attention_map.shape}"
        )

    # Visualize the attention map overlaid on the input slice.
    plt.figure(figsize=(12, 6))
    plt.suptitle(f"Attention Map for Case: {CASE_ID} (Slice: {slice_idx})", fontsize=16)

    plt.subplot(1, 2, 1)
    plt.imshow(img_display, cmap='gray')
    plt.title("Input FLAIR")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(img_display, cmap='gray')
    plt.imshow(attention_display, cmap='viridis', alpha=0.6)  # semi-transparent overlay
    plt.title("Attention Rollout Map")
    plt.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_path_attn = os.path.join(OUTPUT_DIR, f"{CASE_ID}_attention_slice_{slice_idx}.png")
    plt.savefig(save_path_attn)
    print(f"Attention map visualization saved to {save_path_attn}")
    plt.show()
else:
    print("Skipping attention map visualization.")

# -

# ## 8. Quantitative Results Summary (Optional)
# Load and display key metrics from the training process (e.g., best validation Dice).

# +
history_path = os.path.join(os.path.dirname(MODEL_CHECKPOINT), "training_history.json")

if os.path.exists(history_path):
    print(f"Loading training history from: {history_path}")
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)

    val_mean_dice = history.get('val_mean_dice') or []
    if not val_mean_dice:
        # np.argmax raises on an empty sequence, so report instead of crashing.
        print("No validation Dice scores recorded in the training history. Cannot display metrics.")
    else:
        # Index of the best validation round (a position in the list, not an epoch).
        best_val_epoch_idx = int(np.argmax(val_mean_dice))
        best_val_dice = val_mean_dice[best_val_epoch_idx]
        best_et_dice = history['val_dice_et'][best_val_epoch_idx]
        best_tc_dice = history['val_dice_tc'][best_val_epoch_idx]
        best_wt_dice = history['val_dice_wt'][best_val_epoch_idx]

        # Convert the validation-round index into an approximate epoch number.
        # Assumes validation ran every `val_every` epochs starting at epoch `val_every`;
        # uses a training `args` namespace if one exists in this session, else 10.
        epochs_per_val = getattr(globals().get('args'), 'val_every', 10)
        best_epoch = (best_val_epoch_idx + 1) * epochs_per_val

        print("\n--- Best Validation Metrics ---")
        print(f"Achieved at Epoch: ~{best_epoch}")
        print(f"Overall Mean Dice: {best_val_dice:.4f}")
        print(f"  - Enhancing Tumor (ET) Dice: {best_et_dice:.4f}")
        print(f"  - Tumor Core (TC) Dice:      {best_tc_dice:.4f}")
        print(f"  - Whole Tumor (WT) Dice:     {best_wt_dice:.4f}")
else:
    print(f"Training history file not found at {history_path}. Cannot display metrics.")
# -

# ## 9. Conclusion
# Summarize the findings shown in the notebook.
#
# * The notebook demonstrated loading the SSL-trained SwinUNETR model.
# * Inference was performed on a sample BraTS 2023 case.
# * Segmentation results were visualized alongside the ground truth.
# * (If implemented) Attention Rollout maps provided insight into the model's focus areas.
# * Key quantitative metrics from training were summarized.