In [4]:
# ==== Stage Decathlon-style files for nnU-Net inference (fixed) ====

import os, shutil, errno
from pathlib import Path
import pandas as pd

BRATS_DIR = "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
VAL_CENTRES = {18, 19, 20, 21, 22, 23}

# Build val_subjects if not already present
if "val_subjects" not in globals():
    part_df = pd.read_csv(CSV_PATH)
    partition_map = part_df.groupby("Partition_ID")["Subject_ID"].apply(list).to_dict()
    val_subjects = sum((partition_map[cid] for cid in VAL_CENTRES), [])

# nnU-Net suffix order: 0=t1n, 1=t1c, 2=t2w, 3=t2f
SUFFIX_MAP = {
    "t1":   "_0000",
    "t1ce": "_0001",
    "t2":   "_0002",
    "flair":"_0003",
}

# You can keep your custom path; nnU-Net just needs a folder of cases
INPUT_FOLDER_INFER = Path("/mnt/d/Datasets/FETS_data/INPUT_IMAGES_FOR_NNUNET")
INPUT_FOLDER_INFER.mkdir(parents=True, exist_ok=True)

def _safe_symlink_or_copy(src: Path, dst: Path):
    try:
        if dst.exists():
            return
        os.symlink(src, dst)
    except OSError as e:
        # cross-device or permission issues → copy instead
        if e.errno in (errno.EPERM, errno.EACCES, errno.EXDEV, errno.EOPNOTSUPP):
            shutil.copy2(src, dst)
        else:
            raise

missing = []
made = 0

for sid in val_subjects:
    sdir = Path(BRATS_DIR) / sid  # e.g., .../FeTS2022_00000
    for m in MODALITIES:
        src = sdir / f"{sid}_{m}.nii.gz"
        if not src.exists():
            missing.append(str(src))
            continue
        dst = INPUT_FOLDER_INFER / f"{sid}{SUFFIX_MAP[m]}.nii.gz"
        _safe_symlink_or_copy(src, dst)
        made += 1

print(f"[Done] Staged {made} files into: {INPUT_FOLDER_INFER}")
if missing:
    print(f"[Warn] Missing {len(missing)} files; first few:\n  - " + "\n  - ".join(missing[:8]))

# Quick sanity: should be exactly 4 files per subject
from collections import Counter
counts = Counter(p.name.split("_")[0] for p in INPUT_FOLDER_INFER.glob("*.nii.gz"))
bad = {k:v for k,v in counts.items() if v != 4}
print(f"[Check] Subjects staged: {len(counts)} (expected 4 files each)")
if bad:
    print("[Warn] Off-count subjects:", bad)
else:
    print("[OK] All subjects have 4 modalities.")


[Done] Staged 1864 files into: /mnt/d/Datasets/FETS_data/INPUT_IMAGES_FOR_NNUNET
[Check] Subjects staged: 1 (expected 4 files each)
[Warn] Off-count subjects: {'FeTS2022': 1864}


In [1]:
# -- One-time (in this kernel) --
import os
os.environ["nnUNet_raw"] = "/mnt/tmp/nnunet_raw"
os.environ["nnUNet_preprocessed"] = "/mnt/tmp/nnunet_preprocessed"
os.environ["nnUNet_results"] = "/mnt/tmp/nnunet_results"

# Put these assets as per README before you run predict:
# nnUNet_results/Dataset770_BraTSGLIPreCropRegion/nnUNetTrainer__nnUNetResEncUNetPlans__3d_fullres/fold_0/checkpoint_final.pth
# nnUNet_preprocessed/Dataset770_BraTSGLIPreCropRegion/nnUNetResEncUNetPlans.json

from pathlib import Path
from nnunet_api import NnUnetApi

INPUT_FOLDER_INFER  = "/mnt/d/Datasets/FETS_data/INPUT_IMAGES_FOR_NNUNET"
OUTPUT_FOLDER_PREDS = "/mnt/d/Datasets/FETS_data/NNUNET_PREDS_770"
Path(OUTPUT_FOLDER_PREDS).mkdir(parents=True, exist_ok=True)

api = NnUnetApi()
api.predict(
    input_folder=INPUT_FOLDER_INFER,
    output_folder=OUTPUT_FOLDER_PREDS,
    dataset_name_or_id="Dataset770_BraTSGLIPreCropRegion",
    plans_identifier="nnUNetResEncUNetPlans",
    configuration="3d_fullres",
    folds=[0]
)
print("Wrote preds to:", OUTPUT_FOLDER_PREDS)


FileNotFoundError: [Errno 2] No such file or directory: '/mnt/tmp/nnunet_results/Dataset770_BraTSGLIPreCropRegion/nnUNetTrainer__nnUNetResEncUNetPlans__3d_fullres/dataset.json'

In [None]:
@torch.no_grad()
def _label_to_onehot_3(raw_mask_np):
# raw_mask_np is the integer mask from nnU-Net (0 background, 1/2/3 tumor labels)
whole = (raw_mask_np == 1) | (raw_mask_np == 2) | (raw_mask_np == 3)
tumor_core = (raw_mask_np == 1) | (raw_mask_np == 3)
enhancing = (raw_mask_np == 3)
onehot = np.stack([whole, tumor_core, enhancing], axis=0).astype(np.float32)
return torch.from_numpy(onehot) # [3, D, H, W]


@torch.no_grad()
def dice3_from_preds_dir(preds_dir, test_dataset, device):
from torch.utils.data import DataLoader
loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
dsum = torch.zeros(3, device=device)


# Build a map from subject id -> pred path
pred_map = {p.name.replace(".nii.gz", ""): str(p)
for p in Path(preds_dir).glob("*.nii.gz")}


for batch in loader:
# Subject id recovered from one of the modality file names in your dataset
# E.g., "BraTS-GLI-00160-000_0000.nii.gz" => case id "BraTS-GLI-00160-000"
flair_path = batch["flair_meta_dict"]["filename_or_obj"][0]
sid = Path(flair_path).name.split("_")[0]
pred_path = pred_map.get(sid)
if pred_path is None:
raise FileNotFoundError(f"No prediction found for {sid}")


# target from your labels
raw_gt = batch["seg"].squeeze(1).cpu().numpy()[0]
target = _label_to_onehot_3(raw_gt).to(device).unsqueeze(0) # [1,3,D,H,W]


# predicted mask -> onehot 3‑channel
raw_pred = nib.load(pred_path).get_fdata().astype(np.int16)
pred = _label_to_onehot_3(raw_pred).to(device).unsqueeze(0)


inter = 2*(pred*target).sum((2,3,4))
denom = (pred+target).sum((2,3,4))+1e-6
dsum += (inter/denom).squeeze(0)


return (dsum/len(loader)).mean().item()


# Usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dice = dice3_from_preds_dir(OUTPUT_FOLDER_PRETRAINED, test_dataset, device)
print("Mean Dice (whole/core/enhancing averaged):", dice)

# No NNUnet yet  under here

In [1]:
import os, copy, time, random, torch, numpy as np                                 # ← your own
import glob
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader
from monai.data import CacheDataset
import glob, nibabel as nib, pandas as pd
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityd,
    RandFlipd, RandSpatialCropd, Compose, SelectItemsd
)

from utils import *
from models import *
  
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------
# 0. paths & meta-data (unchanged) ---------------------------
# -----------------------------------------------------------
BRATS_DIR = "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY  = "seg"

# -----------------------------------------------------------
# 1. read partition file  ➜  { id : [subjects] } ------------
# -----------------------------------------------------------
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)                               # keys are 1 … 23

VAL_CENTRES = {18, 19, 20, 21, 22, 23}          # ← our hold-out set
# VAL_CENTRES = {22, 23}          # ← our sanity set

# split once, reuse everywhere
train_partitions = {cid: sids for cid, sids in partition_map.items()
                    if cid not in VAL_CENTRES}
val_subjects     = sum((partition_map[cid] for cid in VAL_CENTRES), [])

# -----------------------------------------------------------
# 2. helper to build MONAI-style record dicts ----------------
# -----------------------------------------------------------
def build_records(subject_ids):
    recs = []
    for sid in subject_ids:
        sdir = f"{BRATS_DIR}/{sid}"
        rec  = {m: f"{sdir}/{sid}_{m}.nii.gz" for m in MODALITIES}
        rec["seg"] = f"{sdir}/{sid}_{LABEL_KEY}.nii.gz"
        recs.append(rec)
    return recs

# -----------------------------------------------------------
# 3. transforms (unchanged) ---------------------------------
# -----------------------------------------------------------
IMG_KEYS = MODALITIES + [LABEL_KEY]
train_tf = Compose([
    LoadImaged(keys=IMG_KEYS), EnsureChannelFirstd(keys=IMG_KEYS),
    Orientationd(keys=IMG_KEYS, axcodes="RAS"),
    ScaleIntensityd(keys=IMG_KEYS, minv=-1.0, maxv=1.0),
    SelectItemsd(keys=IMG_KEYS),
])
val_tf = Compose([
    LoadImaged(keys=IMG_KEYS), EnsureChannelFirstd(keys=IMG_KEYS),
    Orientationd(keys=IMG_KEYS, axcodes="RAS"),
    ScaleIntensityd(keys=MODALITIES, minv=-1.0, maxv=1.0),   # masks untouched
    SelectItemsd(keys=IMG_KEYS),
])

# -----------------------------------------------------------
# 4. MONAI CacheDatasets ------------------------------------
# -----------------------------------------------------------
# ── client-wise training sets ───────────────────────────────
CUT_OFF, FRAC, SEED = 18, 1, 42
rng = random.Random(SEED)

train_datasets = {}
for cid, subj_ids in train_partitions.items():
    if cid > CUT_OFF:                                    # keep your cap
        break
    k = max(1, int(len(subj_ids) * FRAC))                # e.g. 30 %
    sample_ids = rng.sample(subj_ids, k)
    train_datasets[cid] = CacheDataset(
        build_records(sample_ids), transform=train_tf, cache_rate=1
    )

# ── single validation dataset made from *all* val subjects ─
test_dataset = CacheDataset(
    build_records(val_subjects), transform=val_tf, cache_rate=1
)

print("train per-centre sizes:", {k: len(v) for k, v in train_datasets.items()})
print("validation size:", len(test_dataset))


  from .autonotebook import tqdm as notebook_tqdm
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 511/511 [12:28<00:00,  1.46s/it]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:07<00:00,  1.30s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:19<00:00,  1.32s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47/47 [01:08<00:00,  1.46s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:30<00:00,  1.39s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████

train per-centre sizes: {1: 511, 2: 6, 3: 15, 4: 47, 5: 22, 6: 34, 7: 12, 8: 8, 9: 4, 10: 8, 11: 14, 12: 11, 13: 35, 14: 6, 15: 13, 16: 30, 17: 9}
validation size: 466





In [2]:
print("validation size:", len(test_dataset))


validation size: 466


In [16]:
from seg_models import *      # adjust path / PYTHONPATH
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ResUNet3D(4,3).to(device)

# quick dummy forward
with torch.no_grad():
    dummy = torch.randn(1, 4, 96, 96, 96, device=device)
    logits = model(dummy)
print("logits shape:", logits.shape)


logits shape: torch.Size([1, 3, 96, 96, 96])


In [17]:
# ----------------------------------------------------------------------------------
# 2. helper: Dice on whole/TC/ET averaged to a scalar --------------------------------
# ----------------------------------------------------------------------------------
@torch.no_grad()
def dice3(model, test_dataset):
    model.eval()
    if len(test_dataset) == 0:
        raise RuntimeError(f"No validation cases found – check {VAL_DIR} and glob pattern.")
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    dsum = torch.zeros(3, device=device)
    for batch in loader:
        img = torch.cat([batch[k] for k in ("flair","t1","t1ce","t2")],1).to(device)
        raw = batch["seg"].squeeze(1).cpu().numpy()
        target = torch.tensor(preprocess_mask_labels(raw),
                              dtype=torch.float32, device=device)
        logits = model(img)
        pred = torch.nn.functional.one_hot(
                   torch.argmax(logits,1), num_classes=3
               ).permute(0,4,1,2,3).float()
        inter = 2*(pred*target).sum((2,3,4))
        denom = (pred+target).sum((2,3,4))+1e-6
        dsum += (inter/denom).squeeze(0)


    return (dsum/len(loader)).mean().item()

global_model = ResUNet3D(4,3).to(device)
print("Dice before any training:", dice3(global_model, test_dataset))

Dice before any training: 0.011921419762074947


In [None]:
from tqdm.auto import tqdm, trange   # trange == tqdm(range())

# ────────────────────────────────────────────────────────────
# 1. one-client update (returns weights + mean loss)          │
# ────────────────────────────────────────────────────────────
def local_train(model, loader, lr=1e-4, epochs=1):
    crit = BCEDiceLoss().to(device)
    opt  = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    losses = []
    for _ in range(epochs):
        running = 0.0
        for img_dict in loader:
            img = torch.cat([img_dict[k] for k in ("flair", "t1", "t1ce", "t2")], 1).to(device)
            msk = preprocess_mask_labels(img_dict["seg"].squeeze(1).numpy())
            msk = torch.tensor(msk, dtype=torch.float32, device=device)

            opt.zero_grad()
            loss = crit(model(img), msk)
            loss.backward()
            opt.step()

            running += loss.item()
        losses.append(running / len(loader))          # epoch-mean

    return model.state_dict(), np.mean(losses)

# ────────────────────────────────────────────────────────────
# 2. FedAvg training loop (simple tqdm + clean prints)        │
# ────────────────────────────────────────────────────────────
EPOCHS, LOCAL_EPOCHS, LR, BATCH = 50, 1, 1e-4, 1          # dial as needed
idxs_users = list(train_datasets.keys())
sizes      = {k: len(ds) for k, ds in train_datasets.items()}
fractions  = [sizes[k] / sum(sizes.values()) for k in idxs_users]

global_model = ResUNet3D(4, 3).to(device)
print(f"Dice before training: {dice3(global_model, test_dataset):.4f}")

for rnd in trange(1, EPOCHS + 1, desc="Global rounds"):
    local_weights, client_losses = [], []

    for cid in tqdm(idxs_users, desc=" clients", leave=False):
        loader = DataLoader(train_datasets[cid], batch_size=BATCH, shuffle=True)
        # deep-copy so each user starts from the same global weights
        w, loss = local_train(copy.deepcopy(global_model), loader,
                              lr=LR, epochs=LOCAL_EPOCHS)
        local_weights.append(w)
        client_losses.append(loss)

    # FedAvg
    global_model.load_state_dict(average_weights(local_weights, fractions))

    mean_loss = np.mean(client_losses)
    mean_dice = dice3(global_model, test_dataset)
    print(f"Round {rnd:02d}:  mean-loss = {mean_loss:.4f}   mean-Dice = {mean_dice:.4f}")
