
# From-Scratch 3D Segmentation Pipeline — Starter Notebook

This notebook gives you a **clean, reproducible** workflow to train, predict and evaluate with the `medssl_from_scratch` repo you downloaded.

- It reads **nnU-Net v2** plans to mirror *spacing* and *patch size* (for fair comparison).
- It reuses nnU-Net's **fold-0 split**, so both pipelines train and validate on exactly the same cases.
- It runs the provided training & inference scripts directly from the notebook.

> **Tip:** Place this notebook **in the repo root** (`~/projects/medssl_from_scratch`) or set `REPO_ROOT` below accordingly.


In [None]:

# ⬅️ set your repo root (adjust if needed)
REPO_ROOT = "/home/htetaung/projects/medssl_from_scratch"   # change if different
DATA_ROOT = "/home/htetaung/data"                            # where MSD + nnU-Net folders live
TASK = "Task02_Heart"                                        # change to "Task09_Spleen" when needed
FOLD = 0                                                     # nnU-Net fold to mirror
CLASSES = 2                                                  # Heart & Spleen are binary (bg + organ)
PATCH = (96, 96, 96)                                         # you can replace with nnU-Net patch (see below)


In [None]:

# Make the repo importable and enable autoreload for quick edits
import sys, os
sys.path.append(REPO_ROOT)

%load_ext autoreload
%autoreload 2

print("Repo in sys.path?", REPO_ROOT in sys.path)


In [None]:

# Verify PyTorch + GPU
import torch
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))



## Mirror nnU-Net v2 plans (spacing & patch size)

We read `nnUNetPlans.json` so you can **match** nnU-Net's target spacing and patch size. If the keys differ across versions, we try a few common locations.
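
The fallback lookup described above can be sketched on a toy dictionary. The plans content below is hypothetical and heavily trimmed (the numbers are made up for illustration), and `pick_first` is a helper introduced only for this example; a real `nnUNetPlans.json` contains many more keys.

```python
def pick_first(*candidates):
    """Return the first candidate that is neither None nor empty, else None."""
    for value in candidates:
        if value:
            return value
    return None

# Hypothetical, trimmed plans dict; values are illustrative, not from a real dataset.
plans = {
    "configurations": {
        "3d_fullres": {
            "spacing": [1.37, 1.25, 1.25],
            "patch_size": [80, 192, 160],
            "batch_size": 2,
        }
    }
}

cfg = plans.get("configurations", {}).get("3d_fullres", {})
# Try the per-configuration key first, then a top-level fallback.
spacing = pick_first(cfg.get("spacing"), plans.get("target_spacing"))
patch_size = tuple(cfg.get("patch_size") or ())
print(spacing, patch_size, cfg.get("batch_size"))
```

If every candidate key is missing, `spacing` ends up as `None`, which is worth checking before passing it to a resampling step.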


In [None]:

import json

def read_nnunet_cfg(ds_id, name, data_root=DATA_ROOT):
    """Return (spacing, patch_size, batch_size) from the 3d_fullres plans."""
    base = f"{data_root}/nnunet_preprocessed/Dataset{ds_id:03d}_{name}"
    with open(f"{base}/nnUNetPlans.json") as f:
        plans = json.load(f)
    cfg = plans.get("configurations", {}).get("3d_fullres", {})
    spacing = (cfg.get("spacing") or cfg.get("resampling_target_spacing") or
               plans.get("target_spacing") or plans.get("spacing"))
    patch_size = tuple(cfg.get("patch_size") or ())
    batch_size = cfg.get("batch_size")
    return spacing, patch_size, batch_size

# Heart = dataset 2, Spleen = dataset 9 (any other TASK raises a KeyError instead of silently using Spleen)
ds_id, name = {"Task02_Heart": (2, "Heart"), "Task09_Spleen": (9, "Spleen")}[TASK]
spacing, nn_patch, nn_bs = read_nnunet_cfg(ds_id, name)
print(f"Dataset {ds_id} {name} -> spacing={spacing}, patch_size={nn_patch}, batch_size={nn_bs}")



## Use nnU-Net fold-0 split for fair comparison
This cell reads `splits_final.json` and writes list files our training script expects.
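
As a minimal sketch of the two file formats involved: the split file is a JSON list with one entry per fold, each holding `train` and `val` case IDs, and each list file holds one `image,label` pair per line. The case IDs, directory names, and the `write_pair_list` helper below are hypothetical and exist only for this illustration.

```python
import json
import os
import tempfile

# Hypothetical split content; real case IDs come from nnU-Net's splits_final.json.
splits = [{"train": ["case_a", "case_b"], "val": ["case_c"]}]

def write_pair_list(cases, images_dir, labels_dir, out_path):
    """Write one 'image,label' line per case and return the lines."""
    lines = [f"{images_dir}/{c}.nii.gz,{labels_dir}/{c}.nii.gz" for c in cases]
    with open(out_path, "w") as f:
        f.write("\n".join(lines) + "\n")
    return lines

with tempfile.TemporaryDirectory() as tmp:
    # Round-trip the split through disk, as the notebook does with the real file.
    split_path = os.path.join(tmp, "splits_final.json")
    with open(split_path, "w") as f:
        json.dump(splits, f)
    with open(split_path) as f:
        fold0 = json.load(f)[0]

    out = os.path.join(tmp, "train_fold0.txt")
    write_pair_list(fold0["train"], "imagesTr", "labelsTr", out)
    with open(out) as f:
        written = f.read().splitlines()

print(written)
```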


In [None]:

import os, json

os.makedirs(f"{REPO_ROOT}/lists", exist_ok=True)

sp = f"{DATA_ROOT}/nnunet_preprocessed/Dataset{ds_id:03d}_{name}/splits_final.json"
with open(sp) as f:
    split = json.load(f)[FOLD]

def to_pair(case):
    img = f"{DATA_ROOT}/MSD/{TASK}/imagesTr/{case}.nii.gz"
    lab = f"{DATA_ROOT}/MSD/{TASK}/labelsTr/{case}.nii.gz"
    if not (os.path.exists(img) and os.path.exists(lab)):
        raise FileNotFoundError(f"Missing {img} or {lab}")
    return f"{img},{lab}"

train_pairs = [to_pair(c) for c in split["train"]]
val_pairs   = [to_pair(c) for c in split["val"]]

train_list = f"{REPO_ROOT}/lists/{TASK}_train_fold{FOLD}.txt"
val_list   = f"{REPO_ROOT}/lists/{TASK}_val_fold{FOLD}.txt"
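# Use context managers so both files are flushed and closed before training starts
with open(train_list, "w") as f:
    f.write("\n".join(train_pairs) + "\n")
with open(val_list, "w") as f:
    f.write("\n".join(val_pairs) + "\n")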

print(f"Wrote\n  {train_list} ({len(train_pairs)})\n  {val_list} ({len(val_pairs)})")



## (Optional) Set PATCH from nnU-Net plans
If you want to match exactly (recommended when VRAM allows), set `PATCH` to the `nn_patch` printed earlier.
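
When the nnU-Net patch does not fit, one possible way to derive a smaller one is to cap each axis and round down to a multiple of the network's total downsampling factor. The sketch below assumes a factor of 16 (four 2x downsampling stages); that is an assumption about the U-Net in this repo, not something the notebook states, and `shrink_patch` is a hypothetical helper.

```python
def shrink_patch(patch, max_side, multiple=16):
    """Clamp each axis to max_side, then round down to a multiple of `multiple`.

    `multiple=16` assumes four 2x downsampling stages (an assumption).
    """
    out = []
    for side in patch:
        side = min(side, max_side)
        # Never go below one full downsampling block.
        side = max(multiple, (side // multiple) * multiple)
        out.append(side)
    return tuple(out)

print(shrink_patch((80, 192, 160), max_side=128))  # (80, 128, 128)
```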


In [None]:

# Uncomment to mirror nnU-Net patch size (if it fits 12GB VRAM)
# if nn_patch: PATCH = tuple(nn_patch)
print("Using PATCH =", PATCH)



## Train from-scratch U-Net (fold-0)

This runs the provided training script with AMP. Reduce `--batch` or `PATCH` if you hit OOM.
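
To judge how much a smaller batch or patch might help, a rough proxy is the number of voxels processed per step, since activation memory grows roughly in proportion to it. This is only a back-of-the-envelope sketch: it ignores weights, optimizer state, and framework overhead, and `relative_activation_volume` is a hypothetical helper.

```python
from math import prod

def relative_activation_volume(batch, patch, ref_batch=2, ref_patch=(96, 96, 96)):
    """Voxels per step relative to a reference setting (rough memory proxy)."""
    return (batch * prod(patch)) / (ref_batch * prod(ref_patch))

# Halving the batch halves the voxel count per step.
print(relative_activation_volume(1, (96, 96, 96)))  # 0.5
# Shrinking every patch side from 96 to 64 scales it by (64/96)**3.
print(relative_activation_volume(2, (64, 64, 64)))
```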


In [None]:

# Use `!` so IPython substitutes the Python variables; a %%bash cell cannot see them.
!cd "{REPO_ROOT}" && python scripts/train_unet3d.py --train_list "lists/{TASK}_train_fold{FOLD}.txt" --val_list "lists/{TASK}_val_fold{FOLD}.txt" --out_dir "runs/{TASK}_unet_fs_fold{FOLD}" --epochs 50 --batch 2 --patch {PATCH[0]} {PATCH[1]} {PATCH[2]} --classes {CLASSES}



## Predict (sliding-window) on one test case
Replace the image path with a real test case from `imagesTs`.
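
For intuition, sliding-window inference tiles the volume with overlapping patches along each axis. The sketch below computes window start indices for one axis; it is a generic illustration with an assumed 50% overlap, not the logic of `predict_sliding.py`, whose overlap and blending scheme may differ.

```python
def window_starts(size, patch, overlap=0.5):
    """Start indices along one axis so windows of length `patch` cover [0, size)."""
    if size <= patch:
        return [0]  # a single (padded) window covers the whole axis
    step = max(1, int(patch * (1 - overlap)))
    starts = list(range(0, size - patch, step))
    starts.append(size - patch)  # final window sits flush with the far edge
    return starts

print(window_starts(200, 96))  # [0, 48, 96, 104]
```

Applying the same function to each of the three axes and taking the Cartesian product of the results gives the full set of 3D window origins.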


In [None]:

# Use `!` so IPython substitutes the Python variables; a %%bash cell cannot see them.
!cd "{REPO_ROOT}" && python scripts/predict_sliding.py --image "{DATA_ROOT}/MSD/{TASK}/imagesTs/la_001.nii.gz" --checkpoint "runs/{TASK}_unet_fs_fold{FOLD}/unet3d_best.pth" --out "preds/{TASK}_la_001_pred.nii.gz" --classes {CLASSES}



## (Optional) Evaluate Dice/HD95 if GT available
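Use this for a quick spot-check on a training or validation case. The prediction must come from a case that has a label in `labelsTr`, because MSD test cases in `imagesTs` ship without ground truth.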
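
As a reminder of what the Dice score measures, here is a minimal pure-Python sketch over flattened label sequences. It is not the code in `eval_masks.py`; real masks would be 3D arrays loaded from NIfTI files, and returning 1.0 when both masks are empty is a convention chosen for this example. HD95 is not covered here.

```python
def dice(pred, gt, label=1):
    """Dice for one label over two equal-length flat sequences of voxel labels."""
    p = [v == label for v in pred]
    g = [v == label for v in gt]
    inter = sum(a and b for a, b in zip(p, g))
    denom = sum(p) + sum(g)
    # Convention (an assumption): two empty masks count as a perfect match.
    return 1.0 if denom == 0 else 2.0 * inter / denom

pred = [0, 1, 1, 1, 0, 0]
gt = [0, 1, 1, 0, 1, 0]
# Two voxels overlap; each mask has three foreground voxels: 2*2 / (3+3).
print(dice(pred, gt))
```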


In [None]:

# Use `!` so IPython substitutes the Python variables; a %%bash cell cannot see them.
!cd "{REPO_ROOT}" && python scripts/eval_masks.py --pred "preds/{TASK}_la_001_pred.nii.gz" --gt "{DATA_ROOT}/MSD/{TASK}/labelsTr/la_001.nii.gz" --label 1



---

### Tips
- **Editable install (optional):** `pip install -e .` in the repo root so code edits are picked up by imports.
- **Autoreload** is enabled; after editing modules under `medssl_fs/`, cells will use the newest code.
- To switch to **Spleen**, set `TASK = "Task09_Spleen"` at the top and re-run the notebook.
- If nnU-Net’s patch is larger than what fits on your GPU, use a slightly smaller `PATCH` and note the difference in your dissertation.
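
Switching tasks depends on mapping `TASK` to a dataset ID and name. One possible way to derive both from the task string, rather than hard-coding them, is sketched below. It assumes the nnU-Net dataset ID equals the MSD task number, as is the case for Heart (2) and Spleen (9) in this notebook; `parse_task` is a hypothetical helper.

```python
import re

def parse_task(task):
    """Split an MSD task name like 'Task09_Spleen' into (9, 'Spleen')."""
    m = re.fullmatch(r"Task(\d+)_(\w+)", task)
    if m is None:
        raise ValueError(f"Unrecognized task name: {task!r}")
    return int(m.group(1)), m.group(2)

demo_id, demo_name = parse_task("Task09_Spleen")
print(f"Dataset{demo_id:03d}_{demo_name}")  # Dataset009_Spleen
```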
