# Segmentation Example

> Train a U-Net for voxel-wise segmentation of the prostate (this notebook runs the tumor-segmentation setup configured in `tumor.yaml`).


In [None]:
# Automatically reload edited Python modules (e.g. the local prostate158 package)
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


In [None]:
# `!conda activate p158` has no effect in a notebook: every `!` line runs in a
# short-lived subshell, and the kernel's Python interpreter is fixed at kernel
# start. Select `p158` as the notebook *kernel* instead (register it once with
# `python -m ipykernel install --user --name p158` from inside that env), and
# use this cell to confirm which environment is actually running.
import os
import sys

EXPECTED_CONDA_ENV = "p158"
active_env = os.path.basename(sys.prefix)
print(f"Kernel interpreter: {sys.executable}")
if active_env != EXPECTED_CONDA_ENV:
    print(
        f"WARNING: kernel runs in '{active_env}', expected conda env "
        f"'{EXPECTED_CONDA_ENV}'. Switch the notebook kernel."
    )

In [None]:
# # if you have a requirements.txt:
!pip install --upgrade pip
!pip install monai torch pytorch-ignite matplotlib matplotlib-inline pyyaml munch imageio tqdm pandas opencv-python nibabel "protobuf<=3.20.3"
!pip install tensorboard
!pip install scikit-image

# # Otherwise install core libs directly:
# !pip install monai["all"] ignite matplotlib pyyaml munch

In [None]:
# Standard library
import os
import subprocess

# Third-party
import ignite
import monai
import psutil
import torch
from monai.data.meta_tensor import MetaTensor

# Allow-list MONAI's MetaTensor for torch.load's restricted `weights_only`
# unpickler, so checkpoints containing MetaTensor objects can be loaded.
# Registered before the project imports so it is in effect for anything they load.
torch.serialization.add_safe_globals([MetaTensor])

# Project (prostate158)
import prostate158.utils as utils
from prostate158.report import ReportGenerator
from prostate158.train import SegmentationTrainer
from prostate158.utils import load_config
from prostate158.viewer import ListViewer

In [None]:
# 0) Helper to print system + GPU memory
def print_memory_stats(stage="", reset_peak=True):
    """Print a snapshot of system RAM and, if CUDA is available, GPU memory.

    Args:
        stage: Free-text label printed with the snapshot (e.g. "Before training").
        reset_peak: If True (the default, as before), reset PyTorch's CUDA
            peak-memory counter after reporting it, so that a later
            ``torch.cuda.max_memory_allocated()`` measures only what happens
            after this call.
    """
    # System RAM
    mem = psutil.virtual_memory()
    print(
        f"\n[MEMORY] {stage} ▶ System RAM: "
        f"total {mem.total/1e9:.1f} GB, used {mem.used/1e9:.1f} GB ({mem.percent}%)"
    )
    if not torch.cuda.is_available():
        print()
        return

    # Device-wide view (all processes on the GPU) from the driver.
    print("[MEMORY] GPU status via nvidia-smi:")
    try:
        gpu_info = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=name,memory.total,memory.used",
                "--format=csv",
            ],
            timeout=30,  # never hang the notebook on a wedged driver
        ).decode("utf-8", errors="replace")
        print(gpu_info.strip())
    except (OSError, subprocess.SubprocessError) as e:
        # OSError: nvidia-smi missing/not executable;
        # SubprocessError: non-zero exit status or timeout.
        print("  (nvidia-smi failed:", e, ")")

    # This process's view from PyTorch's allocator. Report the peak *before*
    # resetting it; previously it was reset without ever being shown.
    gib = 1024**3
    print(
        f"[MEMORY] PyTorch: allocated {torch.cuda.memory_allocated() / gib:.2f} GB, "
        f"peak {torch.cuda.max_memory_allocated() / gib:.2f} GB since last reset"
    )
    if reset_peak:
        torch.cuda.reset_peak_memory_stats()
    print()

In [None]:
import os
import nibabel as nib


def check_nifti_sizes(directory, show_header=True):
    """Print the shape (and optionally the header) of every NIfTI file under ``directory``.

    The directory tree is walked in sorted order so the output is reproducible.
    Files that nibabel cannot load are reported and skipped instead of aborting
    the scan.

    Args:
        directory: Root folder to search recursively.
        show_header: If True (default, as before), also print the full NIfTI
            header of each file. Set False to keep the output short.

    Returns:
        dict mapping each successfully loaded file path to its shape tuple.
        It is empty if nothing was found or ``directory`` does not exist.
    """
    shapes = {}
    if not os.path.isdir(directory):
        # os.walk silently yields nothing for a missing path; say so explicitly.
        print(f"Directory not found: {directory}")
        return shapes

    for root, dirs, files in os.walk(directory):
        dirs.sort()  # in-place sort makes os.walk descend in a deterministic order
        for file in sorted(files):
            if not file.endswith((".nii", ".nii.gz")):
                continue
            file_path = os.path.join(root, file)
            try:
                img = nib.load(file_path)
            except Exception as e:  # best-effort scan: report and keep going
                print(f"Error loading {file_path}: {e}")
                continue
            shapes[file_path] = img.shape
            print(f"File: {file_path}")
            print(f"Shape: {img.shape}")
            if show_header:
                print(f"Header: {img.header}")
            print("-" * 50)
    return shapes

if __name__ == "__main__":
    train_dir = os.path.join(os.getcwd(), "prostate158", "train")
    check_nifti_sizes(train_dir)

In [None]:
# help(monai.networks.nets)

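All parameters needed for training and evaluation are set in a YAML config file (`anatomy.yaml` for anatomy segmentation, `tumor.yaml` for tumor segmentation). This notebook uses `tumor.yaml`.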


In [None]:
# config = load_config("tumor.yaml")  # use 'anatomy.yaml' instead for anatomy segmentation
# monai.utils.set_determinism(seed=config.seed)

---


In [None]:
# 4) Memory snapshot before training. print_memory_stats also resets the CUDA
#    peak-memory counter, so later peak readings are measured from this point.
print_memory_stats(stage="Before training")

Create a supervised trainer for the segmentation task.


In [None]:
cfg = load_config("tumor.yaml")

# Seed the RNGs for reproducible weight init / data shuffling when the config
# provides a seed (the intent of the commented-out set_determinism cell above).
# NOTE: MONAI's set_determinism also puts cuDNN in deterministic mode, which
# can slow training.
seed = getattr(cfg, "seed", None)
if seed is not None:
    monai.utils.set_determinism(seed=seed)

trainer = SegmentationTrainer(
    progress_bar=True,
    early_stopping=True,
    metrics=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
    save_latest_metrics=True,
    config=cfg,
)
# The pretrained weights (models/tumor.pt) are loaded in the next cell; loading
# them here as well would only do the same work twice.

In [None]:
# Load the pretrained weights into the network of the trainer created above.
PRETRAINED_CHECKPOINT = "models/tumor.pt"


def load_pretrained_weights(network, checkpoint_path, device):
    """Load a state dict from ``checkpoint_path`` into ``network`` and check that it matched.

    Accepts either a bare state dict or a dict that wraps it under "state_dict".
    Loading is non-strict so partially matching checkpoints still work, but any
    mismatch is reported. A checkpoint that matches *no* parameter raises
    RuntimeError instead of silently leaving the network at random init.

    Returns:
        The ``(missing_keys, unexpected_keys)`` result of ``load_state_dict``.
    """
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Handle both cases: direct state dict or wrapped in 'state_dict' key
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    result = network.load_state_dict(state_dict, strict=False)

    n_expected = len(network.state_dict())
    if n_expected and len(result.missing_keys) == n_expected:
        raise RuntimeError(
            f"No parameters from {checkpoint_path} matched the network; "
            f"first unexpected keys: {result.unexpected_keys[:5]}"
        )
    if result.missing_keys or result.unexpected_keys:
        print(
            f"[WARN] Partial weight load from {checkpoint_path}: "
            f"{len(result.missing_keys)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys"
        )
    return result


load_pretrained_weights(trainer.network, PRETRAINED_CHECKPOINT, trainer.config.device)

Add a learning-rate scheduler for the one-cycle policy.


In [None]:
# Per the markdown above, this sets up a one-cycle learning-rate schedule on the
# trainer; presumably it does not start training itself (trainer.run() below
# does). TODO confirm against SegmentationTrainer.fit_one_cycle.
trainer.fit_one_cycle()

In [None]:
# 6) Peak GPU memory since the counter was last reset by
#    print_memory_stats("Before training"). That window includes building the
#    trainer and loading weights, not just fit_one_cycle, so the label says so.
if torch.cuda.is_available():
    peak = torch.cuda.max_memory_allocated() / (1024**3)
    print(
        "[MEMORY] Peak GPU memory since 'Before training' "
        f"(trainer setup + fit_one_cycle): {peak:.2f} GB"
    )

Let's train. This can take several hours.


In [None]:
# Long-running: start training with the trainer configured in the cells above.
trainer.run()

In [None]:
import os
from prostate158.inference3 import inference_pipeline

# Case to run inference on (folder name under <data_dir>/train).
CASE_ID = "051"
data_dir = "prostate158_train"  # should match the data directory configured in tumor.yaml
case_dir = os.path.join(data_dir, "train", CASE_ID)

# Input volumes of this case
t2_path = os.path.join(case_dir, "t2.nii.gz")
adc_path = os.path.join(case_dir, "adc.nii.gz")
dwi_path = os.path.join(case_dir, "dwi.nii.gz")

# Fail fast with a clear message rather than deep inside the pipeline.
missing_inputs = [p for p in (t2_path, adc_path, dwi_path) if not os.path.isfile(p)]
if missing_inputs:
    raise FileNotFoundError(f"Missing input volume(s) for case {CASE_ID}: {missing_inputs}")

# Output path (derived from CASE_ID so the two cannot drift apart)
os.makedirs("predictions", exist_ok=True)
output_path = os.path.join("predictions", f"case_{CASE_ID}_tumor_pred.nii.gz")

# Run inference
inference_pipeline(
    t2_path=t2_path,
    adc_path=adc_path,
    dwi_path=dwi_path,
    output_path=output_path,
    config_path="tumor.yaml",
    checkpoint_path="models/tumor.pt",
)

In [None]:
import os
import nibabel as nib

# Input shapes straight from the NIfTI headers; no need to load (and convert to
# float64) every voxel just to print a shape.
for name, path in [("T2", t2_path), ("ADC", adc_path), ("DWI", dwi_path)]:
    print(f"{name}  shape: {nib.load(path).shape}")

# Prediction: read the file that was passed to inference_pipeline as output_path.
# Fall back to the previously hard-coded "predictions/t2.nii.gz" in case the
# pipeline names its output after the input file instead (TODO confirm which).
pred_path = (
    output_path
    if os.path.exists(output_path)
    else os.path.join("predictions", "t2.nii.gz")
)
pred_img = nib.load(pred_path)
pred_data = pred_img.get_fdata()
print(f"Prediction shape: {pred_data.shape}  ({pred_path})")

In [None]:
# 9) Final memory report: the overall CUDA peak since the last counter reset,
#    followed by a fresh system/GPU snapshot.
cuda_present = torch.cuda.is_available()
if cuda_present:
    bytes_per_gib = 1024**3
    peak_total = torch.cuda.max_memory_allocated() / bytes_per_gib
    print(f"[MEMORY] Peak GPU memory used across all: {peak_total:.2f} GB")
print_memory_stats("After trainer.run()")

Finish the training with a final evaluation of the best model. To allow visualization of all outputs, first attach an `EpochOutputStore` handler to the evaluator. Otherwise only the output of the last iteration (the last batch) will be accessible.


In [None]:
# Store the evaluator outputs of a whole epoch under the name "output",
# so they can be read back from trainer.evaluator.state.output further below.
from ignite.handlers import EpochOutputStore

eos_handler = EpochOutputStore()
eos_handler.attach(trainer.evaluator, "output")

In [None]:
# Forward slashes keep the path portable (a backslash is a literal filename
# character on Linux/macOS). Map the weights onto the trainer's configured device
# instead of a hard-coded "cuda:0", consistent with the weight-loading cell above.
# NOTE(review): the markdown says "best model", but this is the pretrained
# checkpoint used earlier — confirm which checkpoint should be evaluated here.
trainer.evaluate(checkpoint="models/tumor.pt", map_location=trainer.config.device)

Generate a markdown document with segmentation results


In [None]:
# Build the markdown report of segmentation results for this run.
report_generator = ReportGenerator(
    cfg.run_id,
    cfg.out_dir,
    cfg.log_dir,
)
report_generator.generate_report()

Have a look at some outputs


In [None]:
# Gather image / label / prediction tensors from the stored evaluator outputs:
# take element [0] of each stored output, detach it, move it to CPU and squeeze it.
output = trainer.evaluator.state.output
keys = ["image", "label", "pred"]
outputs = {}
for key in keys:
    outputs[key] = [
        stored[0][key].detach().cpu().squeeze() for stored in output
    ]

In [None]:
def _orient(volume):
    """Swap axes 0 and 2, then flip axis -2, to orient a volume for display."""
    return volume.transpose(0, 2).flip(-2)


n_cases_shown = 3
viewer_volumes = (
    [_orient(image) for image in outputs["image"][:n_cases_shown]]
    + [_orient(label.argmax(0)).float() for label in outputs["label"][:n_cases_shown]]
    + [_orient(pred.argmax(0)).float() for pred in outputs["pred"][:n_cases_shown]]
)
ListViewer(viewer_volumes).show()