In [None]:
# %% [markdown]
# # LongiTumorSense Model Training
# **Training on MU-Glioma-Post Dataset**
# - Segmentation: nnUNet
# - Classification: 3D DenseNet
# - Survival: CoxPH Model

In [None]:
!pip install monai torch torchvision nnunet pyradiomics lifelines pydicom nibabel wandb -q

KeyboardInterrupt: 

In [None]:
import nibabel as nib
import numpy  as np
from sklearn.model_selection import train_test_split
import torch
import os
import monai
from monai.data import Dataset ,DataLoader
from monai.transforms import ( Compose , LoadImaged , EnsureChannelFirstd, ScaleIntensityd,RandRotated,RandFlipd,RandZoomd,ToTensord)
from monai.networks.nets import DenseNet121,Unet
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, FocalLoss
import wandb
import pandas as pd
from lifelines import CoxPHFitter


In [None]:
import torch

In [None]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device.")


Using cpu device.


In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


**Data Clearner and saved into the drive**

In [None]:
import os
import shutil


raw_root = "/content/drive/My Drive/MU-Glioma-Post"

output_root="/content/drive/My Drive/clean_data"



imagesTr = os.path.join(output_root, "imagesTr")
labelsTr = os.path.join(output_root, "labelsTr")
os.makedirs(os.path.join(output_root, "imagesTs"), exist_ok=True)
os.makedirs(imagesTr, exist_ok=True)
os.makedirs(labelsTr, exist_ok=True)

print("Raw dataset path:", raw_root)
print("nnU-Net dataset path:", output_root)


progress_file = os.path.join(output_root, "converted_cases.txt")

if os.path.exists(progress_file):
    with open(progress_file, "r") as f:
        converted_cases = set(line.strip() for line in f)
else:
    converted_cases = set()
print(f"Found {len(converted_cases)} cases already processed.")

Raw dataset path: /content/drive/My Drive/MU-Glioma-Post
nnU-Net dataset path: /content/drive/My Drive/clean_data
Found 4 cases already processed.


In [None]:
def is_nifti(fname):
  return fname.endswith(".nii") or fname.endswith(".nii.gz")

In [None]:
mod_priority=[
      't1c','t1gd','t1ce',  # contrast-enhanced T1 variants
    't1n','t1',           # native T1
    'flair','t2f','t2flair','t2w','t2' # T2 /flair variants
]

In [None]:
def file_priority(fname):
  lf=fname.lower()
  for i,k in enumerate(mod_priority):
    if k in lf:
      return i
  return len(mod_priority) + hash(lf) % 1000

In [None]:
import re
import os
import shutil
from tqdm import tqdm
canonical_modalities = None




skipped = []
new_cases_count = 0





total_timepoints = sum(
    1 for p in sorted(os.listdir(raw_root))
    if os.path.isdir(os.path.join(raw_root, p))
    for tp in sorted(os.listdir(os.path.join(raw_root, p)))
    if os.path.isdir(os.path.join(raw_root, p, tp))
)







with tqdm(total=total_timepoints, desc="Processing cases") as pbar:
    for patient_id in sorted(os.listdir(raw_root)):
        patient_path = os.path.join(raw_root, patient_id)
        if not os.path.isdir(patient_path):
            pbar.update(1)
            continue






        for tp in sorted(os.listdir(patient_path)):
            tp_path = os.path.join(patient_path, tp)
            if not os.path.isdir(tp_path):
                pbar.update(1)
                continue



            tp_clean = re.sub(r"\s+", "_", tp)
            tp_clean = re.sub(r"[^A-Za-z0-9_-]", "_", tp_clean)
            case_id = f"{patient_id}_{tp_clean}"



            if case_id in converted_cases:
                pbar.update(1)
                continue


            files = [f for f in os.listdir(tp_path) if is_nifti(f)]
            if not files:
                skipped.append((patient_id, tp, "no nifti files"))
                pbar.update(1)
                continue



            label_candidates = [f for f in files if any(x in f.lower() for x in ["mask", "tumor", "seg", "label"])]
            if len(label_candidates) == 0:
                skipped.append((patient_id, tp, "no label found"))
                pbar.update(1)
                continue


            label_file = label_candidates[0]



            image_files = [f for f in files if f != label_file]
            if len(image_files) == 0:
                skipped.append((patient_id, tp, "no image files"))
                pbar.update(1)
                continue



            image_files_sorted = sorted(image_files, key=file_priority)
            if canonical_modalities is None:
                canonical_modalities = image_files_sorted.copy()
                print("\nDetected modality order (from first sample)")
                for idx, nm in enumerate(canonical_modalities):
                    print(f"{idx}: {nm}")
                print("If this order is wrong adjust mod_priority list in the script.")


            else:
                if len(image_files_sorted) != len(canonical_modalities):
                    skipped.append(
                        (patient_id, tp, f"modality count mismatch {len(image_files_sorted)} vs {len(canonical_modalities)}")
                    )
                    pbar.update(1)
                    continue



            for i, fname in enumerate(image_files_sorted):
                src = os.path.join(tp_path, fname)
                destination = os.path.join(imagesTr, f"{case_id}_{i:04d}.nii.gz")
                shutil.copy(src, destination)

            shutil.copy2(os.path.join(tp_path, label_file), os.path.join(labelsTr, f"{case_id}.nii.gz"))



            converted_cases.add(case_id)
            with open(progress_file, "a") as f:
                f.write(case_id + "\n")

            new_cases_count += 1
            pbar.update(1)

print(f"\nConversion finished. {len(converted_cases)} total cases processed so far.")
if skipped:
    print(f"{len(skipped)} timepoints skipped (see sample):")
    for s in skipped[:10]:
        print(" ", s)

print(f"imagesTr files: {len(os.listdir(imagesTr))}, labelsTr files: {len(os.listdir(labelsTr))}")
print(f"Newly processed this run: {new_cases_count}")

In [None]:
!rm -r /content/nnUNet_raw_data_base


**Get clean data from drive into local colab  with correct file naming structure and create json.data file according to nnUnet formet**

In [None]:
import os
import re
import json
import shutil
import numpy as np
import nibabel as nib
from tqdm import tqdm
from collections import defaultdict


config = {
    'drive_root': "/content/drive/MyDrive/clean_data",
    'local_root': "/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post",
    'image_pattern': r"(PatientID_\d+)_Timepoint_(\d+)_(\d{4})\.nii\.gz",
    'label_pattern': r"(PatientID_\d+)_Timepoint_(\d+)\.nii\.gz",
    'task_name': "MU-Glioma-Post",
    'expected_modalities': ["0000", "0001", "0002", "0003"],
    'create_test_folder': False,
    'handle_missing_modalities': "create_empty",
    'strict_validation': True
}


modality_map = {
    "0000": {"suffix": "0000", "name": "t1gd", "order": 0},
    "0001": {"suffix": "0001", "name": "t1", "order": 1},
    "0002": {"suffix": "0002", "name": "t2", "order": 2},
    "0003": {"suffix": "0003", "name": "flair", "order": 3}
}

def create_empty_volume(reference_nii_path, output_path):
    """Create float32 NIfTI volume matching reference geometry"""
    try:
        ref_img = nib.load(reference_nii_path)
        empty_data = np.zeros(ref_img.shape, dtype=np.float32)
        empty_img = nib.Nifti1Image(empty_data, ref_img.affine, ref_img.header)
        nib.save(empty_img, output_path)
        return True
    except Exception as e:
        print(f"Failed to create empty volume: {str(e)}")
        return False


paths = {
    'source_images': os.path.join(config['drive_root'], "imagesTr"),
    'source_labels': os.path.join(config['drive_root'], "labelsTr"),
    'dest_images': os.path.join(config['local_root'], "imagesTr"),
    'dest_labels': os.path.join(config['local_root'], "labelsTr"),
}


os.makedirs(paths['dest_images'], exist_ok=True)
os.makedirs(paths['dest_labels'], exist_ok=True)
if config['create_test_folder']:
    paths['dest_imagesTs'] = os.path.join(config['local_root'], "imagesTs")
    os.makedirs(paths['dest_imagesTs'], exist_ok=True)

print("\nRunning strict nnU-Net dataset preparation...")
warnings = []
errors = []


case_counter = 1
patient_mapping = {}
case_data = defaultdict(dict)

for fname in tqdm(os.listdir(paths['source_images']), desc="Processing images"):
    match = re.match(config['image_pattern'], fname)
    if not match:
        warnings.append(f"SKIPPED: Invalid image filename - {fname}")
        continue

    patient_id, _, modality_idx = match.groups()

    if patient_id not in patient_mapping:
        case_id = f"Case_{case_counter:04d}"
        patient_mapping[patient_id] = case_id
        case_counter += 1

    case_id = patient_mapping[patient_id]
    modality_info = modality_map.get(modality_idx)

    if not modality_info:
        warnings.append(f"SKIPPED: Unknown modality {modality_idx} in {fname}")
        continue

    new_name = f"{case_id}_{modality_info['suffix']}.nii.gz"
    src = os.path.join(paths['source_images'], fname)
    dest = os.path.join(paths['dest_images'], new_name)

    try:
        shutil.copyfile(src, dest)
        case_data[case_id].setdefault("modalities", []).append(modality_info)
        if len(case_data[case_id]["modalities"]) == 1:
            case_data[case_id]["reference_volume"] = dest
    except Exception as e:
        errors.append(f"COPY FAILED: {src} → {dest} | {str(e)}")


for fname in tqdm(os.listdir(paths['source_labels']), desc="Processing labels"):
    match = re.match(config['label_pattern'], fname)
    if not match:
        warnings.append(f"SKIPPED: Invalid label filename - {fname}")
        continue

    patient_id, _ = match.groups()

    if patient_id not in patient_mapping:
        errors.append(f"ORPHAN LABEL: No matching images for {fname}")
        continue

    case_id = patient_mapping[patient_id]
    new_name = f"{case_id}.nii.gz"
    src = os.path.join(paths['source_labels'], fname)
    dest = os.path.join(paths['dest_labels'], new_name)

    try:
        shutil.copyfile(src, dest)
        case_data[case_id]["label_exists"] = True
    except Exception as e:
        errors.append(f"LABEL COPY FAILED: {src} → {dest} | {str(e)}")


cases_to_remove = []
for case_id, data in case_data.items():

    if "modalities" not in data:
        errors.append(f"NO IMAGES: {case_id} has no valid images")
        cases_to_remove.append(case_id)
        continue

    if not data.get("label_exists"):
        if config['strict_validation']:
            errors.append(f"MISSING LABEL: {case_id} has images but no label")
            cases_to_remove.append(case_id)
        else:
            warnings.append(f"UNLABELED CASE: {case_id} has no label (skipping)")
            cases_to_remove.append(case_id)
        continue


    found_modalities = {m["suffix"] for m in data["modalities"]}
    missing_modalities = set(config['expected_modalities']) - found_modalities

    if missing_modalities:
        if config['handle_missing_modalities'] == "skip_case":
            cases_to_remove.append(case_id)
            warnings.append(f"REMOVING CASE: {case_id} missing modalities {sorted(missing_modalities)}")
        elif config['handle_missing_modalities'] == "create_empty":
            for mod in missing_modalities:
                empty_path = os.path.join(paths['dest_images'], f"{case_id}_{mod}.nii.gz")
                if create_empty_volume(data["reference_volume"], empty_path):
                    case_data[case_id]["modalities"].append(modality_map[mod])
                    warnings.append(f"CREATED EMPTY: {case_id}_{mod}.nii.gz")
                else:
                    errors.append(f"EMPTY VOLUME FAILED: {case_id}_{mod}.nii.gz")
                    cases_to_remove.append(case_id)
                    break


for cid in set(cases_to_remove):
    case_data.pop(cid, None)

# ==================== GENERATE DATASET.JSON ====================
modality_list = sorted(modality_map.values(), key=lambda x: x['order'])

# ====== 1. PRE-VALIDATION ======
print("Running pre-validation checks...")

# Check if dataset is single or multi-modal
is_multimodal = len(modality_list) > 1
print(f"Dataset type: {'Multi-modal' if is_multimodal else 'Single-modal'}")

# Verify all expected files exist
all_files_valid = True
for case_id, data in case_data.items():
    if not data.get("label_exists"):
        continue

    # Check modalities
    for m in modality_list:
        img_path = os.path.join(paths['dest_images'], f"{case_id}_{m['suffix']}.nii.gz")
        if not os.path.exists(img_path):
            print(f"Missing image: {img_path}")
            all_files_valid = False

    # Check label
    label_path = os.path.join(paths['dest_labels'], f"{case_id}.nii.gz")
    if not os.path.exists(label_path):
        print(f"Missing label: {label_path}")
        all_files_valid = False

if not all_files_valid:
    raise ValueError("Pre-validation failed: Missing files detected")

# ====== 2. DATASET GENERATION ======
dataset_json = {
    "name": config['task_name'],
    "description": "Post-operative glioma segmentation",
    "reference": "Your reference here",
    "licence": "Your license here",
    "release": "1.0",
    "modality": {str(i): m["name"] for i, m in enumerate(modality_list)},
    "labels": {
        "0": "background",
        "1": "tumor"
    },
    "numTraining": len([c for c in case_data.values() if c.get("label_exists")]),
    "numTest": 0,
    "training": [],
    "test": []
}

for case_id, data in case_data.items():
    if not data.get("label_exists"):
        continue

    # Get modalities in correct order
    available_modalities = []
    for m in modality_list:  # Follow predefined modality order
        img_path = os.path.join(paths['dest_images'], f"{case_id}_{m['suffix']}.nii.gz")
        if os.path.exists(img_path):
            available_modalities.append(m)

    # Create paths - guaranteed correct order
    image_paths = [f"./imagesTr/{case_id}_{m['suffix']}.nii.gz" for m in available_modalities]

    # CRITICAL: Force consistent format based on dataset type
    if is_multimodal:
        training_entry = {
            "image": image_paths,  # Always list for multi-modal
            "label": f"./labelsTr/{case_id}.nii.gz"
        }
    else:
        training_entry = {
            "image": image_paths[0],  # Always string for single-modal
            "label": f"./labelsTr/{case_id}.nii.gz"
        }

    dataset_json["training"].append(training_entry)

# ====== 3. POST-VALIDATION ======
print("\nRunning post-validation...")

# Check first 5 cases
for i, case in enumerate(dataset_json["training"][:5]):
    print(f"\nCase {i}:")
    print(f"Image type: {type(case['image'])}")
    print(f"Content: {case['image']}")

    # Verify type consistency
    if is_multimodal and not isinstance(case['image'], list):
        raise ValueError(f"Case {i} should be multi-modal but got single path")
    if not is_multimodal and not isinstance(case['image'], str):
        raise ValueError(f"Case {i} should be single-modal but got list")

# Verify all paths exist
print("\nVerifying file paths...")
for case in dataset_json["training"]:
    if is_multimodal:
        for img_path in case['image']:
            if not os.path.exists(os.path.join(config['local_root'], img_path.replace("./", ""))):
                raise ValueError(f"Missing image: {img_path}")
    else:
        if not os.path.exists(os.path.join(config['local_root'], case['image'].replace("./", ""))):
            raise ValueError(f"Missing image: {case['image']}")

    label_path = os.path.join(config['local_root'], case['label'].replace("./", ""))
    if not os.path.exists(label_path):
        raise ValueError(f"Missing label: {case['label']}")

# ====== 4. SAFE SAVE ======
print("\nSaving dataset.json with atomic write...")
temp_path = os.path.join(config['local_root'], "dataset.json.tmp")
final_path = os.path.join(config['local_root'], "dataset.json")

try:
    with open(temp_path, 'w') as f:
        json.dump(dataset_json, f, indent=4, ensure_ascii=False)
    os.replace(temp_path, final_path)
except Exception as e:
    if os.path.exists(temp_path):
        os.remove(temp_path)
    raise ValueError(f"Failed to save dataset.json: {str(e)}")

print("\nDATASET CREATION SUCCESSFUL!")
print(f"Total training cases: {len(dataset_json['training'])}")
print(f"Modality order: {[m['name'] for m in modality_list]}")


Running strict nnU-Net dataset preparation...


Processing images: 100%|██████████| 2376/2376 [02:36<00:00, 15.18it/s]
Processing labels: 100%|██████████| 594/594 [00:05<00:00, 108.94it/s]

Running pre-validation checks...
Dataset type: Multi-modal

Running post-validation...

Case 0:
Image type: <class 'list'>
Content: ['./imagesTr/Case_0001_0000.nii.gz', './imagesTr/Case_0001_0001.nii.gz', './imagesTr/Case_0001_0002.nii.gz', './imagesTr/Case_0001_0003.nii.gz']

Case 1:
Image type: <class 'list'>
Content: ['./imagesTr/Case_0002_0000.nii.gz', './imagesTr/Case_0002_0001.nii.gz', './imagesTr/Case_0002_0002.nii.gz', './imagesTr/Case_0002_0003.nii.gz']

Case 2:
Image type: <class 'list'>
Content: ['./imagesTr/Case_0003_0000.nii.gz', './imagesTr/Case_0003_0001.nii.gz', './imagesTr/Case_0003_0002.nii.gz', './imagesTr/Case_0003_0003.nii.gz']

Case 3:
Image type: <class 'list'>
Content: ['./imagesTr/Case_0004_0000.nii.gz', './imagesTr/Case_0004_0001.nii.gz', './imagesTr/Case_0004_0002.nii.gz', './imagesTr/Case_0004_0003.nii.gz']

Case 4:
Image type: <class 'list'>
Content: ['./imagesTr/Case_0005_0000.nii.gz', './imagesTr/Case_0005_0001.nii.gz', './imagesTr/Case_0005_0002.nii.gz',




**After Disconnect:**

In [None]:
!pip install monai torch torchvision nnunet pyradiomics lifelines pydicom nibabel wandb -q

**Install a Python package directly from its GitHub source code, not from the normal package store (PyPI).”**

In [None]:
!pip install git+https://github.com/MIC-DKFZ/nnUNet.git

Collecting git+https://github.com/MIC-DKFZ/nnUNet.git
  Cloning https://github.com/MIC-DKFZ/nnUNet.git to /tmp/pip-req-build-qolt2mfe
  Running command git clone --filter=blob:none --quiet https://github.com/MIC-DKFZ/nnUNet.git /tmp/pip-req-build-qolt2mfe
  Resolved https://github.com/MIC-DKFZ/nnUNet.git to commit 8c4184d46b60059ff7dc8f74cd535e13554bdeca
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting acvl-utils<0.3,>=0.2.3 (from nnunetv2==2.6.2)
  Downloading acvl_utils-0.2.5.tar.gz (29 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dynamic-network-architectures<0.5,>=0.4.1 (from nnunetv2==2.6.2)
  Downloading dynamic_network_architectures-0.4.2.tar.gz (28 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting imagecodecs (from nnunetv2==2.6.2)
  Downloading imagecodecs-2025.8.2-cp311-cp311-manylinux_2_27_x86_64.manylin

**This function loads an MRI file, converts it to a NumPy array, and scales all values to between 0 and 1 for easier analysis.**

In [None]:
import nibabel as nib
import numpy  as np

def load_and_preprocess(patient_path):
    img = nib.load(patient_path)
    data = img.get_fdata()
    data = (data - np.min(data)) / (np.max(data) - np.min(data))
    return data

**delete 0 bytes files**

In [None]:
import os

def clean_zero_byte_files(folder_path, verbose=True):
    """Safely delete 0-byte files and verify deletions.

    Args:
        folder_path: Path to directory to clean
        verbose: Whether to print deletion messages

    Returns:
        tuple: (deleted_files, failed_deletions)
    """
    deleted_files = []
    failed_deletions = []

    for root, _, files in os.walk(folder_path):
        for file in files:
            file_path = os.path.join(root, file)

            try:

                if os.path.isfile(file_path) and os.path.getsize(file_path) == 0:
                    if verbose:
                        print(f"Deleting 0-byte file: {file_path}")

                    os.remove(file_path)


                    if not os.path.exists(file_path):
                        deleted_files.append(file_path)
                    else:
                        failed_deletions.append(file_path)
                        if verbose:
                            print(f"Failed to delete: {file_path}")

            except Exception as e:
                failed_deletions.append(file_path)
                if verbose:
                    print(f"Error processing {file_path}: {str(e)}")

    return deleted_files, failed_deletions


deleted_images, failed_images = clean_zero_byte_files("/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/imagesTr")
deleted_labels, failed_labels = clean_zero_byte_files("/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/labelsTr")

print(f"\nResults:")
print(f"Deleted {len(deleted_images)} image files")
print(f"Failed to delete {len(failed_images)} image files")
print(f"Deleted {len(deleted_labels)} label files")
print(f"Failed to delete {len(failed_labels)} label files\n")

if failed_images or failed_labels:
    print("Warning: Some files couldn't be deleted. Check permissions.")


Results:
Deleted 0 image files
Failed to delete 0 image files
Deleted 0 label files
Failed to delete 0 label files



In [None]:
import wandb

# Print your default W&B username (entity)
print("Your W&B username:", wandb.Api().default_entity)

# Alternative: Check after login
wandb.login()
print("Logged in as:", wandb.Api().default_entity)

In [None]:
wandb.init(project="LongiTumorSense",entity="numl-f21-35629-numl")

In [None]:
import os
from sklearn.model_selection import train_test_split

def prepare_dataset(imagesTr, labelsTr, test_size=0.2):

    image_files = [f for f in os.listdir(imagesTr) if f.endswith(".nii.gz")]
    case_ids = sorted(list(set("_".join(f.split("_")[:-1]) for f in image_files)))

    print(f"Found {len(case_ids)} unique cases.")

    # Split into train and test
    train_cases, test_cases = train_test_split(case_ids, test_size=test_size, random_state=42)

    missing_labels = []

    def build_file_list(cases):
        file_list = []
        for case_id in cases:
            # Build list of all 4 modalities for this case
            modalities = [
                os.path.join(imagesTr, f"{case_id}_0000.nii.gz"),  # FLAIR
                os.path.join(imagesTr, f"{case_id}_0001.nii.gz"),  # T1
                os.path.join(imagesTr, f"{case_id}_0002.nii.gz"),  # T1ce
                os.path.join(imagesTr, f"{case_id}_0003.nii.gz")   # T2
            ]
            label_path = os.path.join(labelsTr, f"{case_id}.nii.gz")

            if not os.path.exists(label_path):
                missing_labels.append(case_id)
                continue

            file_list.append({
                "image": modalities,
                "label": label_path,
                "name": case_id
            })
        return file_list

    train_files = build_file_list(train_cases)
    test_files = build_file_list(test_cases)

    print(f"Length of training dataset: {len(train_files)}")
    print(f"Length of validation dataset: {len(test_files)}")

    if missing_labels:
        print(f"Missing labels for {len(missing_labels)} cases: {missing_labels[:10]}{'...' if len(missing_labels) > 10 else ''}")

    return train_files, test_files


In [None]:

train_files, test_files = prepare_dataset(
    "/content/clean_data_local/imagesTr",
    "/content/clean_data_local/labelsTr"
)


Found 594 unique cases.
Length of training dataset: 474
Length of validation dataset: 119
Missing labels for 1 cases: ['PatientID_0275_Timepoint_6']


**this is used for skipped bad , unreachable ,missing samples instead of crashing**

In [None]:
from monai.transforms import LoadImaged
from monai.data import Dataset as MonaiDataset,DataLoader
class SafeDataset(MonaiDataset):
    def __getitem__(self, index):
        item = self.data[index]
        try:
            if self.transform is not None:
                transformed_item = self.transform(item)
                if transformed_item is not None:
                    return transformed_item
                else:
                    image=item.get("image","unknown")
                    print(f"Skipping sample at index: {index}:{image} (Transform returned None)")
                    return None

            return item

        except (FileNotFoundError, nib.filebasedimages.ImageFileError, RuntimeError)  as e:
              image=item.get("image","unknown")
              print(f"Skipping sample at index: {index}:{image} ({e})")
              return None
        except Exception as e:
             image= item.get("image","unknown")
             print(f"Skipping sample at index:  {index} due to unexpected error: {image} ({e})")
             return None

**t checks the batch list, kicks out the None entries, and only lets valid data into the collate party. **

In [None]:
from torch.utils.data._utils.collate import default_collate
def collate_skip_none(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return []
    return default_collate(batch)

**Load images → put channel first → convert to PyTorch tensor.**

In [None]:
from monai.transforms import Compose

transform_basic=Compose([
    LoadImaged(keys=["image"],allow_missing_keys=True),
    EnsureChannelFirstd(keys=["image"]),
    ToTensord(keys=["image"])

])

In [None]:
batch_size=4
train_dataset_basic=SafeDataset(data=train_files,transform=transform_basic)
train_loader_basic=DataLoader(train_dataset_basic,batch_size=batch_size, collate_fn=collate_skip_none, shuffle=True)
batch_shape=next(iter(train_loader_basic))["image"].shape
print("Getting batches of shape:",batch_shape)

Getting batches of shape: torch.Size([4, 4, 240, 240, 155])


**check original file in drive**

In [None]:
import torch
from tqdm import tqdm
import os

def get_mean_std(dataset_loader_basic, resume_file=None, nonzero=True):
    """
    Compute per-channel mean and std from a DataLoader that yields dicts with key 'image'.

    Args:
        dataset_loader_basic: PyTorch DataLoader returning batches with ["image"] tensors
                              of shape [B, C, H, W] or [B, C, D, H, W]
        resume_file (str or None): Optional file to store the last processed batch index for resuming.
        nonzero (bool): If True, compute statistics only over nonzero voxels (ignores background).

    Returns:
        mean (torch.Tensor): Per-channel mean values.
        std (torch.Tensor): Per-channel standard deviations.
    """
    start_index = 0
    if resume_file is not None and os.path.exists(resume_file):
        with open(resume_file, "r") as f:
            start_index = int(f.read().strip() or 0)
        print(f"Resuming from batch index {start_index}...")

    first_batch = next(iter(dataset_loader_basic))
    num_channels = first_batch["image"].shape[1]

    if nonzero:
        channels_sum = torch.zeros(num_channels)
        channels_squared_sum = torch.zeros(num_channels)
        voxel_counts = torch.zeros(num_channels)
    else:
        channels_sum = torch.zeros(num_channels)
        channels_squared_sum = torch.zeros(num_channels)
        num_batches = 0

    for idx, batch in enumerate(tqdm(dataset_loader_basic, desc="Computing mean/std", unit="batch")):
        if idx < start_index:
            continue
        try:
            data = batch["image"].float()

            if nonzero:
                for c in range(num_channels):
                    mask = data[:, c] != 0
                    vals = data[:, c][mask]
                    if vals.numel() > 0:
                        channels_sum[c] += vals.sum()
                        channels_squared_sum[c] += (vals ** 2).sum()
                        voxel_counts[c] += vals.numel()
            else:
                dims = list(range(0, data.ndim))
                dims.remove(1)
                channels_sum += data.mean(dim=dims)
                channels_squared_sum += (data ** 2).mean(dim=dims)
                num_batches += 1

            if resume_file is not None:
                with open(resume_file, "w") as f:
                    f.write(str(idx + 1))

        except Exception as e:
            print(f"Error processing batch {idx}: {e}")

    if nonzero:
        if (voxel_counts == 0).any():
            raise ValueError("Some channels have no nonzero voxels.")
        mean = channels_sum / voxel_counts
        std = torch.sqrt(channels_squared_sum / voxel_counts - mean ** 2)
    else:
        if num_batches == 0:
            raise ValueError("No valid images found in the dataset.")
        mean = channels_sum / num_batches
        std = torch.sqrt(channels_squared_sum / num_batches - mean ** 2)

    return mean, std


In [None]:

mean, std = get_mean_std(train_loader_basic,resume_file=None)
print("Mean:", mean)
print("Std:", std)

Computing mean/std:  66%|██████▋   | 79/119 [08:13<04:14,  6.36s/batch]

Skipping sample at index: 261:['/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0000.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0001.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0002.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0003.nii.gz'] (applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7a4f006b15d0>)


Computing mean/std: 100%|██████████| 119/119 [12:22<00:00,  6.24s/batch]

Mean: tensor([305.7782, 271.5573, 247.0512, 447.8155])
Std: tensor([218.4011, 160.3656, 143.3827, 233.6108])





In [None]:
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, NormalizeIntensityd, RandRotated, RandFlipd, RandZoomd, ToTensord

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    # ScaleIntensityd(keys=["image"]),
    NormalizeIntensityd(
        keys=["image"],
        subtrahend=mean.tolist(),
        divisor=std.tolist(),
        channel_wise=True,
        nonzero=True
    ),
    RandRotated(keys=["image", "label"], range_x=0.3, prob=0.5),
    RandFlipd(keys=["image", "label"], prob=0.5),
    RandZoomd(keys=["image", "label"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
    ToTensord(keys=["image", "label"]),
])

test_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    # ScaleIntensityd(keys=["image"]),
    NormalizeIntensityd(
        keys=["image"],
        subtrahend=mean.tolist(),
        divisor=std.tolist(),
        channel_wise=True,
        nonzero=True
    ),
    ToTensord(keys=["image", "label"]),
])



In [None]:
batch_size=4
train_dataset_norm =SafeDataset(data=train_files,transform=train_transforms)
dataset_loader_norm=DataLoader(train_dataset_norm,batch_size=batch_size, collate_fn=collate_skip_none)
batch = next(iter(dataset_loader_norm))
if batch:
    batch_shape=batch["image"].shape
    print("Getting batches of shape:",batch_shape)
else:
    print("No valid batches were loaded.")

Getting batches of shape: torch.Size([4, 4, 240, 240, 155])


In [None]:


norm_mean, norm_std = get_mean_std(dataset_loader_norm,resume_file=None)

print(f"Mean: {norm_mean}")
print(f"Standard deviation: {norm_std}")

Computing mean/std:  55%|█████▍    | 65/119 [22:13<17:56, 19.93s/batch]

Skipping sample at index: 261:['/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0000.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0001.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0002.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0003.nii.gz'] (applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7a4f004a7cd0>)


Computing mean/std: 100%|██████████| 119/119 [39:41<00:00, 20.01s/batch]

Mean: tensor([0.0047, 0.0046, 0.0091, 0.0047])
Standard deviation: tensor([0.9846, 0.9852, 0.9753, 0.9688])





In [None]:
batch_size=2
test_dataset_norm=SafeDataset(data=test_files,transform=test_transforms)
dataset_loader_test_norm=DataLoader(test_dataset_norm,batch_size=batch_size,collate_fn=collate_skip_none,shuffle=False)
batch_shape=next(iter(dataset_loader_test_norm))["image"].shape
print("Getting batches of shape:",batch_shape)
print(type(test_dataset_norm))

Getting batches of shape: torch.Size([2, 4, 240, 240, 155])
<class '__main__.SafeDataset'>


In [None]:
norm_mean, norm_std = get_mean_std(dataset_loader_test_norm)

print(f"Mean: {norm_mean}")
print(f"Standard deviation: {norm_std}")

Computing mean/std:  90%|█████████ | 54/60 [03:36<00:22,  3.71s/batch]

Skipping sample at index: 109:['/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0000.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0001.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0002.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0003.nii.gz'] (applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7a4f006a8690>)


Computing mean/std: 100%|██████████| 60/60 [03:58<00:00,  3.97s/batch]

Mean: tensor([0.2668, 0.2856, 0.2744, 0.2978])
Standard deviation: tensor([0.8363, 0.7954, 0.8210, 0.8185])





In [None]:

print(type(dataset_loader_norm))
print(type(dataset_loader_test_norm))

<class 'monai.data.dataloader.DataLoader'>
<class 'monai.data.dataloader.DataLoader'>


In [None]:
from torch.utils.data import random_split
length_dataset = len(train_dataset_norm)
length_train = int(length_dataset * 0.8)
length_remaining = length_dataset - length_train
train_subset, remaining_subset = random_split(train_dataset_norm, [length_train, length_remaining])

percent_train = np.round(100 * len(train_subset) / length_dataset, 2)

print(f"Train data is {percent_train}% of full data")

NameError: name 'train_dataset_norm' is not defined

In [None]:
length_dataset_test = len(test_dataset_norm)
length_test = int(length_dataset_test * 0.2)
length_remaining_test = length_dataset_test - length_test

test_subset, remaining_subset_test = random_split(test_dataset_norm, [length_test, length_remaining_test])

percent_test = np.round(100 * len(test_subset) / length_dataset_test, 2)

print(f"Test data is {percent_test}% of full test data")


Test data is 19.33% of full test data



**# Convert dataset to nnUNet format**

In [None]:
import os
os.environ['nnUNet_raw_data_base'] = '/content/nnUNet_raw_data_base'
os.environ['nnUNet_preprocessed'] = '/content/nnUNet_preprocessed'
os.environ['RESULTS_FOLDER'] = '/content/nnUNet_results'


os.makedirs('/content/nnUNet_raw_data_base', exist_ok=True)
os.makedirs('/content/nnUNet_preprocessed', exist_ok=True)
os.makedirs('/content/nnUNet_results', exist_ok=True)

print("nnUNet_raw_data_base =", os.environ['nnUNet_raw_data_base'])
print("nnUNet_preprocessed =", os.environ['nnUNet_preprocessed'])
print("RESULTS_FOLDER =", os.environ['RESULTS_FOLDER'])

nnUNet_raw_data_base = /content/nnUNet_raw_data_base
nnUNet_preprocessed = /content/nnUNet_preprocessed
RESULTS_FOLDER = /content/nnUNet_results


**inshallah ho jhaye ga**

In [None]:
import os
import json
from collections import defaultdict

# Configuration
config = {
    'local_root': "/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post",
    'task_name': "MU-Glioma-Post",
    'modality_map': {
        "0000": {"name": "t1gd", "order": 0},
        "0001": {"name": "t1", "order": 1},
        "0002": {"name": "t2", "order": 2},
        "0003": {"name": "flair", "order": 3}
    }
}

# Path setup
paths = {
    'imagesTr': os.path.join(config['local_root'], "imagesTr"),
    'labelsTr': os.path.join(config['local_root'], "labelsTr"),
    'imagesTs': os.path.join(config['local_root'], "imagesTs")  # Optional test folder
}

# 1. Determine actual modality type from files
modality_suffixes = set()
for f in os.listdir(paths['imagesTr']):
    if f.endswith(".nii.gz") and "_" in f:
        suffix = f.split("_")[-1].split(".")[0]
        if suffix.isdigit():
            modality_suffixes.add(suffix)

is_multimodal = len(modality_suffixes) > 1
print(f"Dataset type: {'Multi-modal' if is_multimodal else 'Single-modal'}")
print(f"Found modalities: {sorted(modality_suffixes)}")

# 2. Generate dataset.json with guaranteed correct format
dataset_json = {
    "name": config['task_name'],
    "description": "Post-operative glioma segmentation",
    "reference": "Your reference here",
    "licence": "Your license here",
    "release": "1.0",
    "modality": {str(i): config['modality_map'][suffix]["name"]
                for i, suffix in enumerate(sorted(modality_suffixes))},
    "labels": {
        "0": "background",
        "1": "tumor"
    },
    "training": [],
    "test": []  # Will be populated below
}

# 3. Process training cases
case_ids = sorted({f.split("_")[0] for f in os.listdir(paths['imagesTr']) if f.endswith(".nii.gz")})
for case_id in case_ids:
    if is_multimodal:
        # Multi-modal - sorted list of paths
        image_paths = sorted(
            [f"./imagesTr/{f}" for f in os.listdir(paths['imagesTr'])
             if f.startswith(case_id) and f.endswith(".nii.gz")],
            key=lambda x: int(x.split("_")[-1].split(".")[0])
        )
        training_entry = {
            "image": image_paths,
            "label": f"./labelsTr/{case_id}.nii.gz"
        }
    else:
        # Single-modal - string path
        img_file = next(f for f in os.listdir(paths['imagesTr'])
                       if f.startswith(case_id) and f.endswith(".nii.gz"))
        training_entry = {
            "image": f"./imagesTr/{img_file}",
            "label": f"./labelsTr/{case_id}.nii.gz"
        }

    # Verify label exists before adding
    if os.path.exists(os.path.join(config['local_root'], training_entry['label'][2:])):
        dataset_json["training"].append(training_entry)
    else:
        print(f"Warning: Missing label for {case_id}")

# 4. Handle test set (if exists)
if os.path.exists(paths['imagesTs']):
    test_files = [f for f in os.listdir(paths['imagesTs']) if f.endswith(".nii.gz")]
    dataset_json["test"] = [f"./imagesTs/{f}" for f in test_files]
    dataset_json["numTest"] = len(test_files)
else:
    dataset_json["test"] = []
    dataset_json["numTest"] = 0

dataset_json["numTraining"] = len(dataset_json["training"])

# 5. Final validation and save
print("\nFinal validation:")
print(f"Training cases: {dataset_json['numTraining']}")
print(f"Test cases: {dataset_json['numTest']}")
print("First training case:")
print(json.dumps(dataset_json["training"][0], indent=2))

output_path = os.path.join(config['local_root'], "dataset.json")
with open(output_path, 'w') as f:
    json.dump(dataset_json, f, indent=4)

print("\nDataset successfully created at:", output_path)

Dataset type: Multi-modal
Found modalities: ['0000', '0001', '0002', '0003']

Final validation:
Training cases: 0
Test cases: 0
First training case:


IndexError: list index out of range

In [None]:
import os
import json
from collections import defaultdict

# 1. First determine REAL modality type from files
image_dir = "/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/imagesTr"
modality_suffixes = set()
for f in os.listdir(image_dir):
    if f.endswith(".nii.gz"):
        parts = f.split("_")
        if len(parts) > 2:  # Has modality suffix
            modality_suffixes.add(parts[-1].split(".")[0])

is_multimodal = len(modality_suffixes) > 1
print(f"Actual modality type: {'Multi-modal' if is_multimodal else 'Single-modal'}")
print(f"Found modalities: {sorted(modality_suffixes)}")

# 2. Generate dataset.json with PROPER format
dataset_json = {
    "name": "MU-Glioma-Post",
    "description": "Post-operative glioma segmentation",
    "modality": {str(i): f"modality_{i}" for i in range(len(modality_suffixes))},
    "labels": {"0": "background", "1": "tumor"},
    "numTraining": len([f for f in os.listdir(f"{image_dir}/../labelsTr") if f.endswith(".nii.gz")]),
    "training": []
}

# Get all unique case IDs
case_ids = sorted(set(f.split("_")[0] for f in os.listdir(image_dir) if f.endswith(".nii.gz")))

for case_id in case_ids:
    if is_multimodal:
        # Multi-modal - use sorted list of paths
        image_paths = sorted(
            [f"./imagesTr/{f}" for f in os.listdir(image_dir)
             if f.startswith(case_id) and f.endswith(".nii.gz")],
            key=lambda x: int(x.split("_")[-1].split(".")[0])
        ) # Added missing parenthesis here
        dataset_json["training"].append({
            "image": image_paths,
            "label": f"./labelsTr/{case_id}.nii.gz"
        })
    else:
        # Single-modal - use string path
        img_file = next(f for f in os.listdir(image_dir)
                      if f.startswith(case_id) and f.endswith(".nii.gz"))
        dataset_json["training"].append({
            "image": f"./imagesTr/{img_file}",
            "label": f"./labelsTr/{case_id}.nii.gz"
        })

# 3. Save with verification
output_path = f"{image_dir}/../dataset.json"
with open(output_path, 'w') as f:
    json.dump(dataset_json, f, indent=4)

print("\nFinal validation:")
print("First training case:")
print(json.dumps(dataset_json["training"][0], indent=2))
print(f"Image field type: {type(dataset_json['training'][0]['image'])}")

Actual modality type: Multi-modal
Found modalities: ['0000', '0001', '0002', '0003']

Final validation:
First training case:
{
  "image": [
    "./imagesTr/Case_0128_0000.nii.gz",
    "./imagesTr/Case_0087_0000.nii.gz",
    "./imagesTr/Case_0187_0000.nii.gz",
    "./imagesTr/Case_0074_0000.nii.gz",
    "./imagesTr/Case_0043_0000.nii.gz",
    "./imagesTr/Case_0032_0000.nii.gz",
    "./imagesTr/Case_0181_0000.nii.gz",
    "./imagesTr/Case_0075_0000.nii.gz",
    "./imagesTr/Case_0064_0000.nii.gz",
    "./imagesTr/Case_0065_0000.nii.gz",
    "./imagesTr/Case_0123_0000.nii.gz",
    "./imagesTr/Case_0161_0000.nii.gz",
    "./imagesTr/Case_0104_0000.nii.gz",
    "./imagesTr/Case_0092_0000.nii.gz",
    "./imagesTr/Case_0142_0000.nii.gz",
    "./imagesTr/Case_0085_0000.nii.gz",
    "./imagesTr/Case_0155_0000.nii.gz",
    "./imagesTr/Case_0110_0000.nii.gz",
    "./imagesTr/Case_0125_0000.nii.gz",
    "./imagesTr/Case_0050_0000.nii.gz",
    "./imagesTr/Case_0101_0000.nii.gz",
    "./imagesTr/Case

In [None]:
!nnUNet_plan_and_preprocess -t 1 --verify_dataset_integrity



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

Traceback (most recent call last):
  File "/usr/local/bin/nnUNet_plan_and_preprocess", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nnunet/experiment_planning/nnUNet_plan_and_preprocess.py", line 105, in main
    verify_dataset_integrity(join(nnUNet_raw_data, task_name))
  File "/usr/local/lib/python3.11/dist-packages/nnunet/preprocessing/sanity_checks.py", line 107, in verify_dataset_integrity
    test_cases = dataset['test']
                 ~~~~~~~^^^^^^^^
KeyError: 'test'


In [None]:
!nnUNet_train 3d_fullres nnUNetTrainerV2 Task001_Glioma 0 --npz



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

Traceback (most recent call last):
  File "/usr/local/bin/nnUNet_train", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nnunet/run/run_training.py", line 140, in main
    trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nnunet/run/default_configuration.py", line 47, in get_default_configuration
    plans = load_pickle(plans_file)
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/li