In [5]:
# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.14.5 # Or your version
#   kernelspec:
#     display_name: Python 3 (ipykernel) # Or your kernel name
#     language: python
#     name: python3
# ---

# # BraTS 2023 SSL - Mini Training Verification Notebook
#
# Purpose: Verify core components (data loading, transforms, dataset, model forward/backward pass)
# using a small subset of data and limited training steps before running the full script.

# ## 1. Imports

# +
import os
import sys
import json
import time
import warnings
from functools import partial

import gc

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# MONAI Imports
from monai.config import print_config
from monai.data import Dataset, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    ConvertToMultiChannelBasedOnBratsClassesd,
    NormalizeIntensityd,
    CropForegroundd,
    RandSpatialCropd,
    RandFlipd,
    RandRotate90d,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandGaussianNoised,
    RandGaussianSmoothd,
)
from monai.utils.enums import MetricReduction

print_config()
# -

# ## 2. Configuration (Mini-Version)
# Hardcode paths and use small values for quick testing.

# +
# --- Mini-Config ---
DATA_DIR = '/home/users/vraja/dl/data/brats2021challenge' # Base dir containing CSV and data folder
STRUCTURE_CSV = 'brats_dsc.csv' # Your structure CSV
OUTPUT_DIR = './output_brats_ssl_mini_test' # Temporary output for this test

ROI_SIZE = (96, 96, 96) # Keep consistent with target model
BATCH_SIZE = 1 # Small total batch size (e.g., 1 labeled, 1 unlabeled)
LABELED_BS_RATIO = 0.5 # Ratio of labeled samples in batch
LABELED_RATIO = 0.1 # Use only 10% of data as potentially labeled for this test
DEBUG_SUBSET_SIZE = 10 # <<< Use only the first N samples overall for quick testing
MAX_TRAIN_STEPS = 5 # <<< Run only N training steps total

LR = 1e-4
WEIGHT_DECAY = 1e-5
CONSISTENCY_WEIGHT = 1.0
SEED = 42
NUM_WORKERS = 2 # Reduce workers for lighter load

# --- End Mini-Config ---

os.makedirs(OUTPUT_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# -

# ## 3. Helper Classes & Functions (Copied from Training Script)

# +
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)

def get_brats2023_datalists(data_dir, structure_csv_filename, labeled_ratio=0.4, debug_subset_size=None):
    """
    Reads the BraTS 2023 structure CSV file and creates labeled/unlabeled data lists.
    Added debug_subset_size parameter.
    """
    structure_csv_path = os.path.join(data_dir, structure_csv_filename)
    if not os.path.exists(structure_csv_path):
        raise FileNotFoundError(f"Structure CSV file not found: {structure_csv_path}")

    df = pd.read_csv(structure_csv_path)
    training_data_root_identifier = "ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
    print(f"Using structure CSV: {structure_csv_path}")
    print(f"Identifying patient folders within paths containing: {training_data_root_identifier}")

    all_files = []
    patient_dirs = df[(df['Is Directory'] == True) &
                      (df['Path'].str.contains(training_data_root_identifier)) &
                      (df['Path'].str.contains('BraTS-GLI-'))]['Path'].tolist()

    print(f"Found {len(patient_dirs)} potential patient directories.")
    if debug_subset_size is not None:
        print(f"--- Using subset of first {debug_subset_size} directories for testing ---")
        patient_dirs = patient_dirs[:debug_subset_size]

    for patient_path in patient_dirs:
        patient_folder_name = os.path.basename(patient_path)
        image_files = [
            os.path.join(patient_path, f"{patient_folder_name}-t1c.nii.gz"),
            os.path.join(patient_path, f"{patient_folder_name}-t1n.nii.gz"),
            os.path.join(patient_path, f"{patient_folder_name}-t2f.nii.gz"), # FLAIR
            os.path.join(patient_path, f"{patient_folder_name}-t2w.nii.gz"), # T2
        ]
        label_file = os.path.join(patient_path, f"{patient_folder_name}-seg.nii.gz")

        if all(os.path.exists(f) for f in image_files) and os.path.exists(label_file):
            all_files.append({"image": image_files, "label": label_file})
        else:
            missing = [f for f in image_files + [label_file] if not os.path.exists(f)]
            print(f"Warning: Missing files for patient {patient_folder_name}, skipping. Missing: {missing}")

    if not all_files:
        raise ValueError("No valid data files found for the subset.")

    np.random.shuffle(all_files)
    num_labeled = int(len(all_files) * labeled_ratio)
    labeled_files = all_files[:num_labeled]
    unlabeled_files = all_files[num_labeled:]

    print(f"Total valid cases processed in subset: {len(all_files)}")
    print(f"Using {len(labeled_files)} cases as labeled data.")
    print(f"Using {len(unlabeled_files)} cases as unlabeled data.")

    # For mini-test, maybe skip validation split or use very few
    # Let's skip validation split for simplicity here.
    train_labeled_files = labeled_files
    val_files = [] # No validation in this mini-script

    return train_labeled_files, unlabeled_files, val_files # Return empty val_files

# --- Transforms (Copied) ---
train_transforms_weak = Compose(
    [
        LoadImaged(keys=["image", "label"], image_only=False, ensure_channel_first=True),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        EnsureTyped(keys=["image", "label"], dtype=torch.float32),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        CropForegroundd(keys=["image", "label"], source_key="image", k_divisible=ROI_SIZE),
        RandSpatialCropd(keys=["image", "label"], roi_size=ROI_SIZE, random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        RandRotate90d(keys=["image", "label"], prob=0.1, max_k=3),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ]
)
train_transforms_strong = Compose(
    [
        LoadImaged(keys=["image"], image_only=False, ensure_channel_first=True),
        EnsureTyped(keys=["image"], dtype=torch.float32),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        CropForegroundd(keys=["image"], source_key="image", k_divisible=ROI_SIZE),
        RandSpatialCropd(keys=["image"], roi_size=ROI_SIZE, random_size=False),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=2),
        RandRotate90d(keys=["image"], prob=0.1, max_k=3),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        RandGaussianNoised(keys="image", prob=0.3, mean=0.0, std=0.1),
        RandGaussianSmoothd(keys="image", prob=0.3, sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), sigma_z=(0.5, 1.5)),
        RandScaleIntensityd(keys="image", factors=0.2, prob=0.2),
        RandShiftIntensityd(keys="image", offsets=0.2, prob=0.2),
    ]
)

# --- Custom Dataset for SSL (Copied) ---
class SSLDataset(Dataset):
    def __init__(self, data, transform_weak, transform_strong):
        super().__init__(data=data, transform=None)
        self.transform_weak = transform_weak
        self.transform_strong = transform_strong

    def __getitem__(self, index):
        data_i = self.data[index]
        data_weak = self.transform_weak(data_i.copy())
        img_weak = data_weak['image']
        label_weak = data_weak.get('label', None)
        data_strong_input = {'image': data_i['image']}
        img_strong = self.transform_strong(data_strong_input)['image']
        output = {"image_weak": img_weak, "image_strong": img_strong}
        if label_weak is not None:
            output["label"] = label_weak
        return output
# -

# ## 4. Create Datasets and DataLoaders (Mini-Version)

# +
print("Creating datasets and dataloaders (mini-version)...")
train_labeled_files, unlabeled_files, _ = get_brats2023_datalists(
    DATA_DIR, STRUCTURE_CSV, LABELED_RATIO, debug_subset_size=DEBUG_SUBSET_SIZE
)

# Create the SSL datasets
train_ds_labeled = SSLDataset(data=train_labeled_files, transform_weak=train_transforms_weak, transform_strong=train_transforms_strong)
train_ds_unlabeled = SSLDataset(data=unlabeled_files, transform_weak=train_transforms_weak, transform_strong=train_transforms_strong)

# Calculate batch sizes for the mini-batch
labeled_batch_size = int(BATCH_SIZE * LABELED_BS_RATIO)
unlabeled_batch_size = BATCH_SIZE - labeled_batch_size

# Adjust if one dataset is empty or batch size is 0
if not train_labeled_files: labeled_batch_size = 0
if not unlabeled_files: unlabeled_batch_size = 0
if labeled_batch_size == 0 and unlabeled_batch_size > 0: labeled_batch_size = 1; unlabeled_batch_size = max(0, BATCH_SIZE - 1)
if unlabeled_batch_size == 0 and labeled_batch_size > 0: unlabeled_batch_size = 1; labeled_batch_size = max(0, BATCH_SIZE - 1)

print(f"Mini Batch size: {labeled_batch_size + unlabeled_batch_size} (Labeled: {labeled_batch_size}, Unlabeled: {unlabeled_batch_size})")

# Create DataLoaders
train_loader_labeled = DataLoader(
    train_ds_labeled, batch_size=labeled_batch_size, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available(), collate_fn=list_data_collate, drop_last=True
) if labeled_batch_size > 0 and train_ds_labeled else None

train_loader_unlabeled = DataLoader(
    train_ds_unlabeled, batch_size=unlabeled_batch_size, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available(), collate_fn=list_data_collate, drop_last=True
) if unlabeled_batch_size > 0 and train_ds_unlabeled else None

if not train_loader_labeled and not train_loader_unlabeled:
     raise ValueError("No data loaders could be created. Check data subset and batch sizes.")
# -

# ## 5. Initialize Model, Loss, Optimizer

# +
print("Initializing model, loss, optimizer...")
model = SwinUNETR(
    img_size=ROI_SIZE,
    in_channels=4,
    out_channels=3,
    feature_size=48, # Keep consistent
    use_checkpoint=True, # Disable checkpointing for faster mini-test
).to(device)

# Loss Functions
supervised_loss = DiceCELoss(to_onehot_y=False, sigmoid=True, lambda_dice=0.5, lambda_ce=0.5)
consistency_loss = torch.nn.MSELoss()

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

print("Model, Loss, Optimizer Initialized.")
# -

# ## 6. Mini Training Loop

# +
print(f"--- Starting Mini Training Verification (Max Steps: {MAX_TRAIN_STEPS}) ---")
model.train()
run_loss_sup = AverageMeter()
run_loss_cons = AverageMeter()
run_loss_total = AverageMeter()
start_time = time.time()

# Create iterators (handle potential None loaders)
iter_l = iter(train_loader_labeled) if train_loader_labeled else None
iter_u = iter(train_loader_unlabeled) if train_loader_unlabeled else None

steps_done = 0
while steps_done < MAX_TRAIN_STEPS:
    optimizer.zero_grad()
    total_loss_batch = 0.0
    current_labeled_bs = 0
    current_unlabeled_bs = 0

    # --- Supervised Loss ---
    if iter_l:
        try:
            batch_labeled = next(iter_l)
            images_l_weak, labels_l = batch_labeled["image_weak"].to(device), batch_labeled["label"].to(device)
            current_labeled_bs = images_l_weak.size(0)

            print(f"Step {steps_done+1}: Labeled shapes - Image: {images_l_weak.shape}, Label: {labels_l.shape}")

            logits_l = model(images_l_weak)
            loss_s = supervised_loss(logits_l, labels_l)
            run_loss_sup.update(loss_s.item(), n=current_labeled_bs)
            total_loss_batch += loss_s
        except StopIteration:
            print("Labeled loader finished.")
            iter_l = None # Stop trying labeled data
        except Exception as e:
            print(f"Error in labeled batch {steps_done+1}: {e}")
            # Decide whether to break or continue
            break # Stop on error during mini-test

    # --- Consistency Loss ---
    if iter_u:
        try:
            batch_unlabeled = next(iter_u)
            images_u_weak, images_u_strong = batch_unlabeled["image_weak"].to(device), batch_unlabeled["image_strong"].to(device)
            current_unlabeled_bs = images_u_weak.size(0)

            print(f"Step {steps_done+1}: Unlabeled shapes - Weak: {images_u_weak.shape}, Strong: {images_u_strong.shape}")

            with torch.no_grad():
                logits_u_weak = model(images_u_weak)
                pseudo_labels = torch.sigmoid(logits_u_weak.detach())

            logits_u_strong = model(images_u_strong)
            preds_strong_sig = torch.sigmoid(logits_u_strong)
            loss_c = consistency_loss(preds_strong_sig, pseudo_labels)
            run_loss_cons.update(loss_c.item(), n=current_unlabeled_bs)
            total_loss_batch += CONSISTENCY_WEIGHT * loss_c
        except StopIteration:
            print("Unlabeled loader finished.")
            iter_u = None # Stop trying unlabeled data
        except Exception as e:
            print(f"Error in unlabeled batch {steps_done+1}: {e}")
            break # Stop on error during mini-test

    # --- Backpropagation ---
    if isinstance(total_loss_batch, torch.Tensor) and total_loss_batch != 0:
        print(f"Step {steps_done+1}: Calculated Total Loss: {total_loss_batch.item():.4f}")
        total_loss_batch.backward()
        optimizer.step()
        total_bs = current_labeled_bs + current_unlabeled_bs
        if total_bs > 0:
            run_loss_total.update(total_loss_batch.item(), n=total_bs)
        print(f"Step {steps_done+1}: Backward pass and optimizer step completed.")
    elif iter_l is None and iter_u is None:
         print("Both loaders finished. Stopping mini-train.")
         break # Stop if both loaders are exhausted
    else:
        print(f"Step {steps_done+1}: No loss computed (likely waiting for data).")


    steps_done += 1
    print("-" * 30)

    # Break if both iterators are exhausted
    if iter_l is None and iter_u is None:
        print("Both data loaders exhausted before reaching max steps.")
        break

            
    del total_loss_batch # Delete loss tensor explicitly
    if 'loss_s' in locals(): del loss_s
    if 'loss_c' in locals(): del loss_c
    if 'logits_l' in locals(): del logits_l
    if 'logits_u_weak' in locals(): del logits_u_weak
    if 'logits_u_strong' in locals(): del logits_u_strong
    if 'pseudo_labels' in locals(): del pseudo_labels
    if 'preds_strong_sig' in locals(): del preds_strong_sig
    gc.collect()
    torch.cuda.empty_cache()
        
        


print(f"--- Mini Training Verification Finished ---")
print(f"Steps Run: {steps_done}")
print(f"Avg Total Loss: {run_loss_total.avg:.4f}")
print(f"Avg Sup Loss: {run_loss_sup.avg:.4f}")
print(f"Avg Cons Loss: {run_loss_cons.avg:.4f}")
print(f"Time Taken: {(time.time() - start_time):.2f}s")
# -

# ## 7. Verification Summary
#
# * If the notebook ran through the `Mini Training Loop` section without crashing (especially on CUDA errors, shape mismatches, or file not found errors), it indicates:
#     * Data loading from `brats_dsc.csv` is likely working correctly.
#     * The `SSLDataset` class is correctly producing dictionaries with `image_weak`, `image_strong`, and `label` (when available).
#     * Transforms are compatible with the data and model input requirements.
#     * The model's forward pass works for both weakly and strongly augmented data.
#     * Loss functions can compute values based on model outputs and labels/pseudo-labels.
#     * The backward pass and optimizer step execute.
# * Check the printed shapes during the loop to ensure they are as expected (e.g., `(BatchSize, Channels, H, W, D)`).
# * Non-zero loss values indicate the model is learning *something*. Very high or NaN losses might indicate instability (check learning rate, normalization, etc.).
#
# **If this runs successfully, you have higher confidence that the full training script (`brats_ssl_train_script`) is set up correctly.**



MONAI version: 1.5.dev2514
Numpy version: 2.0.2
Pytorch version: 2.6.0+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a3ea49fc4e600d131daadad61ea340df25fcfdaa
MONAI __file__: /home/users/<username>/dl/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.12.0
Pillow version: 11.1.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.3
einops version: 0.8.1
transformers version: 4.50.2
mlflow version: 2.21.2
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the 



Model, Loss, Optimizer Initialized.
--- Starting Mini Training Verification (Max Steps: 5) ---
Step 1: Unlabeled shapes - Weak: torch.Size([1, 4, 96, 96, 96]), Strong: torch.Size([1, 4, 96, 96, 96])
Step 1: Calculated Total Loss: 0.0207
Step 1: Backward pass and optimizer step completed.
------------------------------
Step 2: Unlabeled shapes - Weak: torch.Size([1, 4, 96, 96, 96]), Strong: torch.Size([1, 4, 96, 96, 96])
Step 2: Calculated Total Loss: 0.0150
Step 2: Backward pass and optimizer step completed.
------------------------------
Step 3: Unlabeled shapes - Weak: torch.Size([1, 4, 96, 96, 96]), Strong: torch.Size([1, 4, 96, 96, 96])
Step 3: Calculated Total Loss: 0.0152
Step 3: Backward pass and optimizer step completed.
------------------------------
Step 4: Unlabeled shapes - Weak: torch.Size([1, 4, 96, 96, 96]), Strong: torch.Size([1, 4, 96, 96, 96])
Step 4: Calculated Total Loss: 0.0119
Step 4: Backward pass and optimizer step completed.
------------------------------
Step 