In [None]:
# %% [markdown]
# # LongiTumorSense Model Training
# **Training on MU-Glioma-Post Dataset**
# - Segmentation: nnUNet
# - Classification: 3D DenseNet
# - Survival: CoxPH Model

In [None]:
!pip install monai torch torchvision nnunet pyradiomics lifelines pydicom nibabel wandb -q

In [3]:
import nibabel as nib
import numpy  as np
from sklearn.model_selection import train_test_split
import torch
import os
import monai
from monai.data import Dataset ,DataLoader
from monai.transforms import ( Compose , LoadImaged , EnsureChannelFirstd, ScaleIntensityd,RandRotated,RandFlipd,RandZoomd,ToTensord)
from monai.networks.nets import DenseNet121,Unet
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, FocalLoss
import wandb
import pandas as pd
from lifelines import CoxPHFitter


In [4]:
import torch

In [5]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device.")


Using cpu device.


In [6]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


**Data Clearner and saved into the drive**

In [None]:
import os
import shutil


raw_root = "/content/drive/My Drive/MU-Glioma-Post"

output_root="/content/drive/My Drive/clean_data"



imagesTr = os.path.join(output_root, "imagesTr")
labelsTr = os.path.join(output_root, "labelsTr")
os.makedirs(os.path.join(output_root, "imagesTs"), exist_ok=True)
os.makedirs(imagesTr, exist_ok=True)
os.makedirs(labelsTr, exist_ok=True)

print("Raw dataset path:", raw_root)
print("nnU-Net dataset path:", output_root)


progress_file = os.path.join(output_root, "converted_cases.txt")

if os.path.exists(progress_file):
    with open(progress_file, "r") as f:
        converted_cases = set(line.strip() for line in f)
else:
    converted_cases = set()
print(f"Found {len(converted_cases)} cases already processed.")

Raw dataset path: /content/drive/My Drive/MU-Glioma-Post
nnU-Net dataset path: /content/drive/My Drive/clean_data
Found 4 cases already processed.


In [None]:
def is_nifti(fname):
  return fname.endswith(".nii") or fname.endswith(".nii.gz")

In [None]:
mod_priority=[
      't1c','t1gd','t1ce',  # contrast-enhanced T1 variants
    't1n','t1',           # native T1
    'flair','t2f','t2flair','t2w','t2' # T2 /flair variants
]

In [None]:
def file_priority(fname):
  lf=fname.lower()
  for i,k in enumerate(mod_priority):
    if k in lf:
      return i
  return len(mod_priority) + hash(lf) % 1000

In [None]:
import re
import os
import shutil
from tqdm import tqdm
canonical_modalities = None




skipped = []
new_cases_count = 0





total_timepoints = sum(
    1 for p in sorted(os.listdir(raw_root))
    if os.path.isdir(os.path.join(raw_root, p))
    for tp in sorted(os.listdir(os.path.join(raw_root, p)))
    if os.path.isdir(os.path.join(raw_root, p, tp))
)







with tqdm(total=total_timepoints, desc="Processing cases") as pbar:
    for patient_id in sorted(os.listdir(raw_root)):
        patient_path = os.path.join(raw_root, patient_id)
        if not os.path.isdir(patient_path):
            pbar.update(1)
            continue






        for tp in sorted(os.listdir(patient_path)):
            tp_path = os.path.join(patient_path, tp)
            if not os.path.isdir(tp_path):
                pbar.update(1)
                continue



            tp_clean = re.sub(r"\s+", "_", tp)
            tp_clean = re.sub(r"[^A-Za-z0-9_-]", "_", tp_clean)
            case_id = f"{patient_id}_{tp_clean}"



            if case_id in converted_cases:
                pbar.update(1)
                continue


            files = [f for f in os.listdir(tp_path) if is_nifti(f)]
            if not files:
                skipped.append((patient_id, tp, "no nifti files"))
                pbar.update(1)
                continue



            label_candidates = [f for f in files if any(x in f.lower() for x in ["mask", "tumor", "seg", "label"])]
            if len(label_candidates) == 0:
                skipped.append((patient_id, tp, "no label found"))
                pbar.update(1)
                continue


            label_file = label_candidates[0]



            image_files = [f for f in files if f != label_file]
            if len(image_files) == 0:
                skipped.append((patient_id, tp, "no image files"))
                pbar.update(1)
                continue



            image_files_sorted = sorted(image_files, key=file_priority)
            if canonical_modalities is None:
                canonical_modalities = image_files_sorted.copy()
                print("\nDetected modality order (from first sample)")
                for idx, nm in enumerate(canonical_modalities):
                    print(f"{idx}: {nm}")
                print("If this order is wrong adjust mod_priority list in the script.")


            else:
                if len(image_files_sorted) != len(canonical_modalities):
                    skipped.append(
                        (patient_id, tp, f"modality count mismatch {len(image_files_sorted)} vs {len(canonical_modalities)}")
                    )
                    pbar.update(1)
                    continue



            for i, fname in enumerate(image_files_sorted):
                src = os.path.join(tp_path, fname)
                destination = os.path.join(imagesTr, f"{case_id}_{i:04d}.nii.gz")
                shutil.copy(src, destination)

            shutil.copy2(os.path.join(tp_path, label_file), os.path.join(labelsTr, f"{case_id}.nii.gz"))



            converted_cases.add(case_id)
            with open(progress_file, "a") as f:
                f.write(case_id + "\n")

            new_cases_count += 1
            pbar.update(1)

print(f"\nConversion finished. {len(converted_cases)} total cases processed so far.")
if skipped:
    print(f"{len(skipped)} timepoints skipped (see sample):")
    for s in skipped[:10]:
        print(" ", s)

print(f"imagesTr files: {len(os.listdir(imagesTr))}, labelsTr files: {len(os.listdir(labelsTr))}")
print(f"Newly processed this run: {new_cases_count}")

In [20]:
!rm -r /content/nnUNet_raw_data_base


**Get clean data from drive into local colab  with correct file naming structure and create json.data file according to nnUnet formet**

In [22]:
import os
import re
import json
import shutil
import nibabel as nib
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from pathlib import Path


config = {
    'drive_root': "/content/drive/MyDrive/clean_data",
    'local_root': "/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post",
    'image_pattern': r"(PatientID_\d+)_Timepoint_(\d+)_(\d{4})\.nii\.gz",
    'label_pattern': r"(PatientID_\d+)_Timepoint_(\d+)\.nii\.gz",
    'task_name': "MU-Glioma-Post",
    'expected_modalities': ["0000", "0001", "0002", "0003"],
    'create_test_folder': False,
    'handle_missing_modalities': "create_empty",
    'strict_validation': True,
    'label_classes': {"0": "background", "1": "tumor"},
}

modality_map = {
    "0000": {"suffix": "0000", "name": "t1gd", "order": 0},
    "0001": {"suffix": "0001", "name": "t1", "order": 1},
    "0002": {"suffix": "0002", "name": "t2", "order": 2},
    "0003": {"suffix": "0003", "name": "flair", "order": 3}
}


paths = {
    'source_images': Path(config['drive_root']) / "imagesTr",
    'source_labels': Path(config['drive_root']) / "labelsTr",
    'dest_images': Path(config['local_root']) / "imagesTr",
    'dest_labels': Path(config['local_root']) / "labelsTr"
}
paths['dest_images'].mkdir(parents=True, exist_ok=True)
paths['dest_labels'].mkdir(parents=True, exist_ok=True)

if config['create_test_folder']:
    paths['dest_imagesTs'] = Path(config['local_root']) / "imagesTs"
    paths['dest_imagesTs'].mkdir(parents=True, exist_ok=True)


def create_empty_volume(reference_path: Path, output_path: Path):
    """Create empty NIfTI matching reference geometry"""
    try:
        ref_img = nib.load(str(reference_path))
        empty_data = np.zeros(ref_img.shape, dtype=np.float32)
        empty_img = nib.Nifti1Image(empty_data, ref_img.affine, ref_img.header)
        nib.save(empty_img, str(output_path))
        return True
    except Exception as e:
        print(f"[ERROR] Failed to create empty volume: {output_path} | {e}")
        return False


case_counter = 1
patient_mapping = {}
case_data = defaultdict(dict)
warnings = []
errors = []

#process images
for fname in tqdm(list(paths['source_images'].iterdir()), desc="Processing images"):
    if not fname.name.endswith(".nii.gz"):
        continue
    match = re.match(config['image_pattern'], fname.name)
    if not match:
        warnings.append(f"SKIPPED: Invalid image filename - {fname.name}")
        continue
    patient_id, _, modality_idx = match.groups()
    if patient_id not in patient_mapping:
        case_id = f"Case_{case_counter:04d}"
        patient_mapping[patient_id] = case_id
        case_counter += 1
    case_id = patient_mapping[patient_id]
    modality_info = modality_map.get(modality_idx)
    if not modality_info:
        warnings.append(f"SKIPPED: Unknown modality {modality_idx} in {fname.name}")
        continue
    dest_path = paths['dest_images'] / f"{case_id}_{modality_info['suffix']}.nii.gz"
    shutil.copyfile(str(fname), str(dest_path))
    case_data[case_id].setdefault("modalities", []).append(modality_info)
    if "reference_volume" not in case_data[case_id]:
        case_data[case_id]["reference_volume"] = dest_path

#process labels
for fname in tqdm(list(paths['source_labels'].iterdir()), desc="Processing labels"):
    if not fname.name.endswith(".nii.gz"):
        continue
    match = re.match(config['label_pattern'], fname.name)
    if not match:
        warnings.append(f"SKIPPED: Invalid label filename - {fname.name}")
        continue
    patient_id, _ = match.groups()
    if patient_id not in patient_mapping:
        errors.append(f"ORPHAN LABEL: {fname.name}")
        continue
    case_id = patient_mapping[patient_id]
    dest_path = paths['dest_labels'] / f"{case_id}.nii.gz"
    shutil.copyfile(str(fname), str(dest_path))
    case_data[case_id]["label_exists"] = True

#  handle missing modalities
cases_to_remove = []
for case_id, data in case_data.items():
    if "modalities" not in data or not data.get("label_exists", False):
        cases_to_remove.append(case_id)
        continue
    found_modalities = {m["suffix"] for m in data["modalities"]}
    missing = set(config['expected_modalities']) - found_modalities
    if missing:
        if config['handle_missing_modalities'] == "skip_case":
            cases_to_remove.append(case_id)
            warnings.append(f"REMOVING CASE: {case_id} missing {sorted(missing)}")
        elif config['handle_missing_modalities'] == "create_empty":
            for mod in missing:
                empty_path = paths['dest_images'] / f"{case_id}_{mod}.nii.gz"
                if create_empty_volume(data["reference_volume"], empty_path):
                    data["modalities"].append(modality_map[mod])
                    warnings.append(f"CREATED EMPTY: {case_id}_{mod}.nii.gz")
                else:
                    errors.append(f"EMPTY VOLUME FAILED: {case_id}_{mod}.nii.gz")
                    cases_to_remove.append(case_id)
                    break

for cid in set(cases_to_remove):
    case_data.pop(cid, None)

# dataset.json file
modality_list = sorted(modality_map.values(), key=lambda x: x['order'])
is_multimodal = len(modality_list) > 1

dataset_json = {
    "name": config['task_name'],
    "description": "Post-operative glioma segmentation",
    "reference": "Your reference here",
    "licence": "Your license here",
    "release": "1.0",
    "modality": {str(i): m["name"] for i, m in enumerate(modality_list)},
    "labels": config['label_classes'],
    "numTraining": len([c for c in case_data.values() if c.get("label_exists")]),
    "numTest": 0,
    "training": [],
    "test": []
}

for case_id, data in case_data.items():
    if not data.get("label_exists"):
        continue
    image_paths = [f"./imagesTr/{case_id}_{m['suffix']}.nii.gz" for m in modality_list]
    entry = {
        "image": image_paths if is_multimodal else image_paths[0],
        "label": f"./labelsTr/{case_id}.nii.gz"
    }
    dataset_json["training"].append(entry)

# post validation

for i, case in enumerate(dataset_json["training"][:5]):
    assert (is_multimodal and isinstance(case['image'], list)) or (not is_multimodal and isinstance(case['image'], str)), \
        f"Case {i} modality type mismatch"
    for img_path in case['image'] if is_multimodal else [case['image']]:
        if not (paths['dest_images'] / Path(img_path).name).exists():
            raise FileNotFoundError(f"Missing image: {img_path}")
    if not (paths['dest_labels'] / Path(case['label']).name).exists():
        raise FileNotFoundError(f"Missing label: {case['label']}")

# save dataset.json
final_path = Path(config['local_root']) / "dataset.json"
with open(final_path, 'w') as f:
    json.dump(dataset_json, f, indent=4, ensure_ascii=False)

print(f"\nDATASET CREATION SUCCESSFUL! Total training cases: {len(dataset_json['training'])}")
print(f"Modality order: {[m['name'] for m in modality_list]}")
if warnings:
    print("\nWarnings:")
    for w in warnings:
        print(f" - {w}")
if errors:
    print("\nErrors:")
    for e in errors:
        print(f" - {e}")


Processing images: 100%|██████████| 2376/2376 [01:55<00:00, 20.55it/s]
Processing labels: 100%|██████████| 594/594 [00:05<00:00, 112.35it/s]


DATASET CREATION SUCCESSFUL! Total training cases: 203
Modality order: ['t1gd', 't1', 't2', 'flair']





**it converts your raw dataset filenames into the exact nnU-Net format**
Case_0001_T1.nii.gz → Case_0001_0000.nii.gz
Case_0001_FLAIR.nii.gz → Case_0001_0003.nii.gz

**verify your current folder structure and file consistency**

In [None]:
import os
from pathlib import Path
import shutil


root = Path("/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post")
imagesTr = root / "imagesTr"
labelsTr = root / "labelsTr"



modality_map = {
    "T1": "0000",
    "T2": "0001",
    "T1ce": "0002",
    "FLAIR": "0003"
}

# rename images

for f in imagesTr.glob("*.nii.gz"):
    fname = f.name
    case_id = fname.split("_")[0] + "_" + fname.split("_")[1]
    modality = None

    for key in modality_map.keys():
        if key in fname:
            modality = modality_map[key]
            break

    if modality is None:
        print(f" Skipping (unknown modality): {fname}")
        continue

    new_name = f"{case_id}_{modality}.nii.gz"
    new_path = imagesTr / new_name

    print(f"Renaming: {fname} → {new_name}")
    shutil.move(str(f), str(new_path))

# rename labels

for f in labelsTr.glob("*.nii.gz"):
    fname = f.name
    case_id = fname.split("_")[0] + "_" + fname.split("_")[1]  # e.g. Case_0001
    new_name = f"{case_id}.nii.gz"
    new_path = labelsTr / new_name

    if fname != new_name:
        print(f"Renaming label: {fname} → {new_name}")
        shutil.move(str(f), str(new_path))

print("\n Renaming complete! Your dataset should now follow nnUNet format.")


**Validating that dataset actually matches nnU-Net requirements before training**

In [52]:
from pathlib import Path


local_root = Path("/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post")
expected_modalities = ["0000", "0001", "0002", "0003"]

imagesTr = local_root / "imagesTr"
labelsTr = local_root / "labelsTr"
imagesTs = local_root / "imagesTs"


print("Verifying training data...")

cases = set("_".join(f.stem.replace(".nii","").split("_")[:2])
            for f in imagesTr.glob("*.nii.gz"))

errors = []

for case in cases:

    for mod in expected_modalities:
        img_path = imagesTr / f"{case}_{mod}.nii.gz"
        if not img_path.exists():
            errors.append(f"Missing modality {mod} for case {case}")

    label_path = labelsTr / f"{case}.nii.gz"
    if not label_path.exists():
        errors.append(f"Missing label for case {case}")

# verify test data
if imagesTs.exists():
    print("Verifying test data...")
    test_cases = set("_".join(f.stem.replace(".nii","").split("_")[:2])
                     for f in imagesTs.glob("*.nii.gz"))
    for case in test_cases:
        for mod in expected_modalities:
            img_path = imagesTs / f"{case}_{mod}.nii.gz"
            if not img_path.exists():
                errors.append(f"Missing modality {mod} for test case {case}")

# report
if not errors:
    print("All files verified! ")
else:
    print("Warnings/Errors found:")
    for e in errors:
        print(f" - {e}")

print(f"Total training cases: {len(cases)}")
if imagesTs.exists():
    print(f"Total test cases: {len(test_cases)}")


Verifying training data...
All files verified! 
Total training cases: 203


**After Disconnect:**

In [None]:
!pip install monai torch torchvision nnunet pyradiomics lifelines pydicom nibabel wandb -q

**Install a Python package directly from its GitHub source code, not from the normal package store (PyPI).”**

In [None]:
!pip install git+https://github.com/MIC-DKFZ/nnUNet.git

**This function loads an MRI file, converts it to a NumPy array, and scales all values to between 0 and 1 for easier analysis.**

In [None]:
import nibabel as nib
import numpy  as np

def load_and_preprocess(patient_path):
    img = nib.load(patient_path)
    data = img.get_fdata()
    data = (data - np.min(data)) / (np.max(data) - np.min(data))
    return data

**delete 0 bytes files**

In [26]:
import os

def clean_zero_byte_files(folder_path, verbose=True):
    """Safely delete 0-byte files and verify deletions.

    Args:
        folder_path: Path to directory to clean
        verbose: Whether to print deletion messages

    Returns:
        tuple: (deleted_files, failed_deletions)
    """
    deleted_files = []
    failed_deletions = []

    for root, _, files in os.walk(folder_path):
        for file in files:
            file_path = os.path.join(root, file)

            try:

                if os.path.isfile(file_path) and os.path.getsize(file_path) == 0:
                    if verbose:
                        print(f"Deleting 0-byte file: {file_path}")

                    os.remove(file_path)


                    if not os.path.exists(file_path):
                        deleted_files.append(file_path)
                    else:
                        failed_deletions.append(file_path)
                        if verbose:
                            print(f"Failed to delete: {file_path}")

            except Exception as e:
                failed_deletions.append(file_path)
                if verbose:
                    print(f"Error processing {file_path}: {str(e)}")

    return deleted_files, failed_deletions


deleted_images, failed_images = clean_zero_byte_files("/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/imagesTr")
deleted_labels, failed_labels = clean_zero_byte_files("/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/labelsTr")

print(f"\nResults:")
print(f"Deleted {len(deleted_images)} image files")
print(f"Failed to delete {len(failed_images)} image files")
print(f"Deleted {len(deleted_labels)} label files")
print(f"Failed to delete {len(failed_labels)} label files\n")

if failed_images or failed_labels:
    print("Warning: Some files couldn't be deleted. Check permissions.")


Results:
Deleted 0 image files
Failed to delete 0 image files
Deleted 0 label files
Failed to delete 0 label files



**this is used for skipped bad , unreachable ,missing samples instead of crashing**

In [None]:
from monai.transforms import LoadImaged
from monai.data import Dataset as MonaiDataset,DataLoader
class SafeDataset(MonaiDataset):
    def __getitem__(self, index):
        item = self.data[index]
        try:
            if self.transform is not None:
                transformed_item = self.transform(item)
                if transformed_item is not None:
                    return transformed_item
                else:
                    image=item.get("image","unknown")
                    print(f"Skipping sample at index: {index}:{image} (Transform returned None)")
                    return None

            return item

        except (FileNotFoundError, nib.filebasedimages.ImageFileError, RuntimeError)  as e:
              image=item.get("image","unknown")
              print(f"Skipping sample at index: {index}:{image} ({e})")
              return None
        except Exception as e:
             image= item.get("image","unknown")
             print(f"Skipping sample at index:  {index} due to unexpected error: {image} ({e})")
             return None

**t checks the batch list, kicks out the None entries, and only lets valid data into the collate party. **

In [None]:
from torch.utils.data._utils.collate import default_collate
def collate_skip_none(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return []
    return default_collate(batch)

**Load images → put channel first → convert to PyTorch tensor.**

In [None]:
from monai.transforms import Compose

transform_basic=Compose([
    LoadImaged(keys=["image"],allow_missing_keys=True),
    EnsureChannelFirstd(keys=["image"]),
    ToTensord(keys=["image"])

])

In [None]:
batch_size=4
train_dataset_basic=SafeDataset(data=train_files,transform=transform_basic)
train_loader_basic=DataLoader(train_dataset_basic,batch_size=batch_size, collate_fn=collate_skip_none, shuffle=True)
batch_shape=next(iter(train_loader_basic))["image"].shape
print("Getting batches of shape:",batch_shape)

Getting batches of shape: torch.Size([4, 4, 240, 240, 155])


**check original file in drive**

In [None]:
import torch
from tqdm import tqdm
import os

def get_mean_std(dataset_loader_basic, resume_file=None, nonzero=True):
    """
    Compute per-channel mean and std from a DataLoader that yields dicts with key 'image'.

    Args:
        dataset_loader_basic: PyTorch DataLoader returning batches with ["image"] tensors
                              of shape [B, C, H, W] or [B, C, D, H, W]
        resume_file (str or None): Optional file to store the last processed batch index for resuming.
        nonzero (bool): If True, compute statistics only over nonzero voxels (ignores background).

    Returns:
        mean (torch.Tensor): Per-channel mean values.
        std (torch.Tensor): Per-channel standard deviations.
    """
    start_index = 0
    if resume_file is not None and os.path.exists(resume_file):
        with open(resume_file, "r") as f:
            start_index = int(f.read().strip() or 0)
        print(f"Resuming from batch index {start_index}...")

    first_batch = next(iter(dataset_loader_basic))
    num_channels = first_batch["image"].shape[1]

    if nonzero:
        channels_sum = torch.zeros(num_channels)
        channels_squared_sum = torch.zeros(num_channels)
        voxel_counts = torch.zeros(num_channels)
    else:
        channels_sum = torch.zeros(num_channels)
        channels_squared_sum = torch.zeros(num_channels)
        num_batches = 0

    for idx, batch in enumerate(tqdm(dataset_loader_basic, desc="Computing mean/std", unit="batch")):
        if idx < start_index:
            continue
        try:
            data = batch["image"].float()

            if nonzero:
                for c in range(num_channels):
                    mask = data[:, c] != 0
                    vals = data[:, c][mask]
                    if vals.numel() > 0:
                        channels_sum[c] += vals.sum()
                        channels_squared_sum[c] += (vals ** 2).sum()
                        voxel_counts[c] += vals.numel()
            else:
                dims = list(range(0, data.ndim))
                dims.remove(1)
                channels_sum += data.mean(dim=dims)
                channels_squared_sum += (data ** 2).mean(dim=dims)
                num_batches += 1

            if resume_file is not None:
                with open(resume_file, "w") as f:
                    f.write(str(idx + 1))

        except Exception as e:
            print(f"Error processing batch {idx}: {e}")

    if nonzero:
        if (voxel_counts == 0).any():
            raise ValueError("Some channels have no nonzero voxels.")
        mean = channels_sum / voxel_counts
        std = torch.sqrt(channels_squared_sum / voxel_counts - mean ** 2)
    else:
        if num_batches == 0:
            raise ValueError("No valid images found in the dataset.")
        mean = channels_sum / num_batches
        std = torch.sqrt(channels_squared_sum / num_batches - mean ** 2)

    return mean, std


In [None]:

mean, std = get_mean_std(train_loader_basic,resume_file=None)
print("Mean:", mean)
print("Std:", std)

Computing mean/std:  66%|██████▋   | 79/119 [08:13<04:14,  6.36s/batch]

Skipping sample at index: 261:['/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0000.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0001.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0002.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0003.nii.gz'] (applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7a4f006b15d0>)


Computing mean/std: 100%|██████████| 119/119 [12:22<00:00,  6.24s/batch]

Mean: tensor([305.7782, 271.5573, 247.0512, 447.8155])
Std: tensor([218.4011, 160.3656, 143.3827, 233.6108])





In [None]:
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, NormalizeIntensityd, RandRotated, RandFlipd, RandZoomd, ToTensord

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    # ScaleIntensityd(keys=["image"]),
    NormalizeIntensityd(
        keys=["image"],
        subtrahend=mean.tolist(),
        divisor=std.tolist(),
        channel_wise=True,
        nonzero=True
    ),
    RandRotated(keys=["image", "label"], range_x=0.3, prob=0.5),
    RandFlipd(keys=["image", "label"], prob=0.5),
    RandZoomd(keys=["image", "label"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
    ToTensord(keys=["image", "label"]),
])

test_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    # ScaleIntensityd(keys=["image"]),
    NormalizeIntensityd(
        keys=["image"],
        subtrahend=mean.tolist(),
        divisor=std.tolist(),
        channel_wise=True,
        nonzero=True
    ),
    ToTensord(keys=["image", "label"]),
])



In [None]:
batch_size=4
train_dataset_norm =SafeDataset(data=train_files,transform=train_transforms)
dataset_loader_norm=DataLoader(train_dataset_norm,batch_size=batch_size, collate_fn=collate_skip_none)
batch = next(iter(dataset_loader_norm))
if batch:
    batch_shape=batch["image"].shape
    print("Getting batches of shape:",batch_shape)
else:
    print("No valid batches were loaded.")

Getting batches of shape: torch.Size([4, 4, 240, 240, 155])


In [None]:


norm_mean, norm_std = get_mean_std(dataset_loader_norm,resume_file=None)

print(f"Mean: {norm_mean}")
print(f"Standard deviation: {norm_std}")

Computing mean/std:  55%|█████▍    | 65/119 [22:13<17:56, 19.93s/batch]

Skipping sample at index: 261:['/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0000.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0001.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0002.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0021_Timepoint_6_0003.nii.gz'] (applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7a4f004a7cd0>)


Computing mean/std: 100%|██████████| 119/119 [39:41<00:00, 20.01s/batch]

Mean: tensor([0.0047, 0.0046, 0.0091, 0.0047])
Standard deviation: tensor([0.9846, 0.9852, 0.9753, 0.9688])





In [None]:
batch_size=2
test_dataset_norm=SafeDataset(data=test_files,transform=test_transforms)
dataset_loader_test_norm=DataLoader(test_dataset_norm,batch_size=batch_size,collate_fn=collate_skip_none,shuffle=False)
batch_shape=next(iter(dataset_loader_test_norm))["image"].shape
print("Getting batches of shape:",batch_shape)
print(type(test_dataset_norm))

Getting batches of shape: torch.Size([2, 4, 240, 240, 155])
<class '__main__.SafeDataset'>


In [None]:
norm_mean, norm_std = get_mean_std(dataset_loader_test_norm)

print(f"Mean: {norm_mean}")
print(f"Standard deviation: {norm_std}")

Computing mean/std:  90%|█████████ | 54/60 [03:36<00:22,  3.71s/batch]

Skipping sample at index: 109:['/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0000.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0001.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0002.nii.gz', '/content/clean_data_local/imagesTr/PatientID_0006_Timepoint_6_0003.nii.gz'] (applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7a4f006a8690>)


Computing mean/std: 100%|██████████| 60/60 [03:58<00:00,  3.97s/batch]

Mean: tensor([0.2668, 0.2856, 0.2744, 0.2978])
Standard deviation: tensor([0.8363, 0.7954, 0.8210, 0.8185])





In [None]:

print(type(dataset_loader_norm))
print(type(dataset_loader_test_norm))

<class 'monai.data.dataloader.DataLoader'>
<class 'monai.data.dataloader.DataLoader'>


In [None]:
from torch.utils.data import random_split
length_dataset = len(train_dataset_norm)
length_train = int(length_dataset * 0.8)
length_remaining = length_dataset - length_train
train_subset, remaining_subset = random_split(train_dataset_norm, [length_train, length_remaining])

percent_train = np.round(100 * len(train_subset) / length_dataset, 2)

print(f"Train data is {percent_train}% of full data")

NameError: name 'train_dataset_norm' is not defined

In [None]:
length_dataset_test = len(test_dataset_norm)
length_test = int(length_dataset_test * 0.2)
length_remaining_test = length_dataset_test - length_test

test_subset, remaining_subset_test = random_split(test_dataset_norm, [length_test, length_remaining_test])

percent_test = np.round(100 * len(test_subset) / length_dataset_test, 2)

print(f"Test data is {percent_test}% of full test data")


Test data is 19.33% of full test data



**# Convert dataset to nnUNet format**

In [27]:
import os
os.environ['nnUNet_raw_data_base'] = '/content/nnUNet_raw_data_base'
os.environ['nnUNet_preprocessed'] = '/content/nnUNet_preprocessed'
os.environ['RESULTS_FOLDER'] = '/content/nnUNet_results'


os.makedirs('/content/nnUNet_raw_data_base', exist_ok=True)
os.makedirs('/content/nnUNet_preprocessed', exist_ok=True)
os.makedirs('/content/nnUNet_results', exist_ok=True)

print("nnUNet_raw_data_base =", os.environ['nnUNet_raw_data_base'])
print("nnUNet_preprocessed =", os.environ['nnUNet_preprocessed'])
print("RESULTS_FOLDER =", os.environ['RESULTS_FOLDER'])

nnUNet_raw_data_base = /content/nnUNet_raw_data_base
nnUNet_preprocessed = /content/nnUNet_preprocessed
RESULTS_FOLDER = /content/nnUNet_results


**inshallah ho jhaye ga**

In [53]:
import os, json

task_name = "Task001_MU-Glioma-Post"
base = f"/content/nnUNet_raw_data_base/nnUNet_raw_data/{task_name}"
imagesTr = os.path.join(base, "imagesTr")
labelsTr = os.path.join(base, "labelsTr")


cases = sorted([
    f.replace("_0000.nii.gz", "")
    for f in os.listdir(imagesTr) if f.endswith("_0000.nii.gz")
])

dataset = {
    "name": "MU-Glioma-Post",
    "description": "Post-operative glioma segmentation",
    "reference": "Your reference here",
    "licence": "Your license here",
    "release": "1.0",
    "modality": {
        "0": "t1gd",
        "1": "t1",
        "2": "t2",
        "3": "flair"
    },
    "labels": {
        "0": "background",
        "1": "tumor",
        "2": "edema",
        "3": "enhancing_tumor",
        "4": "necrosis"



    },
    "numTraining": len(cases),
    "numTest": 0,
    "training": [],
    "test": []
}


for case in cases:
    dataset["training"].append({
        "image": f"./imagesTr/{case}.nii.gz",
        "label": f"./labelsTr/{case}.nii.gz"
    })

# save
with open(os.path.join(base, "dataset.json"), "w") as f:
    json.dump(dataset, f, indent=4)

print(f"✅ dataset.json created with {len(cases)} training cases")


✅ dataset.json created with 203 training cases


**This script converts all multi-class tumor labels into a single binary tumor label, ensuring nnU-Net will treat it as a binary segmentation task.**

In [54]:
import os
import nibabel as nib
import numpy as np

label_dir = "/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/labelsTr"

for fname in os.listdir(label_dir):
    if fname.endswith(".nii.gz"):
        path = os.path.join(label_dir, fname)
        img = nib.load(path)
        data = img.get_fdata().astype(np.uint8)

        # remap all labels >1 to 1
        data[data > 1] = 1

        new_img = nib.Nifti1Image(data, img.affine, img.header)
        nib.save(new_img, path)

print(" Labels cleaned: all values >1 mapped to 1")


✅ Labels cleaned: all values >1 mapped to 1


In [49]:
!cd /content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post/labelsTr && \
for f in *.nii.gz.nii.gz; do mv "$f" "${f%.nii.gz}"; done


**verification function**

Directly validates dataset.json: Ensures your JSON file matches the actual disk structure.

Catches mismatches: Sometimes renaming and dataset.json creation get out of sync — this flags them.

Guarantees nnU-Net won’t fail later: Preprocessing requires dataset.json to be perfectly aligned with real files.

In [55]:
import os, json

base = "/content/nnUNet_raw_data_base/nnUNet_raw_data/Task001_MU-Glioma-Post"
with open(os.path.join(base, "dataset.json")) as f:
    dataset = json.load(f)

print("Num cases in dataset.json:", dataset["numTraining"])
print("First 3 cases:")
for case in dataset["training"][:3]:
    print(case)

# Check matching files actually exist
for case in dataset["training"]:
    case_id = os.path.basename(case["image"]).replace(".nii.gz", "")
    for m in range(4):
        img_path = os.path.join(base, "imagesTr", f"{case_id}_{m:04d}.nii.gz")
        if not os.path.isfile(img_path):
            print("❌ Missing modality:", img_path)
    lbl_path = os.path.join(base, "labelsTr", f"{case_id}.nii.gz")
    if not os.path.isfile(lbl_path):
        print("❌ Missing label:", lbl_path)


Num cases in dataset.json: 203
First 3 cases:
{'image': './imagesTr/Case_0001.nii.gz', 'label': './labelsTr/Case_0001.nii.gz'}
{'image': './imagesTr/Case_0002.nii.gz', 'label': './labelsTr/Case_0002.nii.gz'}
{'image': './imagesTr/Case_0003.nii.gz', 'label': './labelsTr/Case_0003.nii.gz'}


In [56]:
!nnUNet_plan_and_preprocess -t 1 --verify_dataset_integrity



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

Verifying training set
checking case Case_0001
checking case Case_0002
checking case Case_0003
checking case Case_0004
checking case Case_0005
checking case Case_0006
checking case Case_0007
checking case Case_0008
checking case Case_0009
checking case Case_0010
checking case Case_0011
checking case Case_0012
checking case Case_0013
checking case Case_0014
checking case Case_0015
checking case Case_0016
checking case Case_0017
checking case Case_0018
checking case Case_0019
checking case Case_0020
checking case Case_0021
checking case Case_0022
checking case Case_0023
checking case Case_0024
checking case Case_0025
checking case C

In [60]:
!nnUNet_train 3d_fullres nnUNetTrainerV2 Task001_MU-Glioma-Post 0 --npz





Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

###############################################
I am running the following nnUNet: 3d_fullres
My trainer class is:  <class 'nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2'>
For that I will be using the following configuration:
num_classes:  1
modalities:  {0: 't1gd', 1: 't1', 2: 't2', 3: 'flair'}
use_mask_for_norm OrderedDict([(0, True), (1, True), (2, True), (3, True)])
keep_only_largest_region None
min_region_size_per_class None
min_size_per_class None
normalization_schemes OrderedDict([(0, 'nonCT'), (1, 'nonCT'), (2, 'nonCT'), (3, 'nonCT')])
stages...

stage:  0
{'batch_size': 2, 'num_pool_per_axis': [5, 5, 5]

In [None]:
!export nnUNet_n_proc_DA=1 && nnUNet_train 3d_fullres nnUNetTrainerV2 Task001_MU-Glioma-Post 0 --npz




Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

###############################################
I am running the following nnUNet: 3d_fullres
My trainer class is:  <class 'nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2'>
For that I will be using the following configuration:
num_classes:  1
modalities:  {0: 't1gd', 1: 't1', 2: 't2', 3: 'flair'}
use_mask_for_norm OrderedDict([(0, True), (1, True), (2, True), (3, True)])
keep_only_largest_region None
min_region_size_per_class None
min_size_per_class None
normalization_schemes OrderedDict([(0, 'nonCT'), (1, 'nonCT'), (2, 'nonCT'), (3, 'nonCT')])
stages...

stage:  0
{'batch_size': 2, 'num_pool_per_axis': [5, 5, 5]

**trick to check progress quickly**