In [None]:
import os, copy, time, random, torch, numpy as np                                 # ← your own
import glob
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader
from monai.data import CacheDataset
import glob, nibabel as nib, pandas as pd
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityd,
    RandFlipd, RandSpatialCropd, Compose, SelectItemsd
)

from utils import *
from models import *
  
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------
# 0. paths & meta-data ---------------------------------------
# -----------------------------------------------------------
# The dataset root may be overridden through the BRATS_DIR environment
# variable so the notebook is not tied to one machine's absolute path; the
# previous hard-coded location remains the default.
BRATS_DIR = os.environ.get(
    "BRATS_DIR", "/mnt/c/Datasets/MICCAI_FeTS2022_TrainingData"
)
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY  = "seg"

# -----------------------------------------------------------
# 1. read partition file  ➜  { id : [subjects] } ------------
# -----------------------------------------------------------
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)                               # keys are 1 … 23

VAL_CENTRES = {18, 19, 20, 21, 22, 23}          # ← our hold-out set

# split once, reuse everywhere
train_partitions = {
    cid: sids
    for cid, sids in partition_map.items()
    if cid not in VAL_CENTRES
}
# Iterate the hold-out centres in sorted order so the subject order does not
# depend on set iteration order, and flatten in one pass instead of repeatedly
# concatenating lists with sum(..., []).
val_subjects = [
    sid
    for cid in sorted(VAL_CENTRES)
    for sid in partition_map[cid]
]

# -----------------------------------------------------------
# 2. helper to build MONAI-style record dicts ----------------
# -----------------------------------------------------------
def build_records(subject_ids):
    """Build one MONAI-style record dict per subject.

    Every record maps each name in ``MODALITIES`` plus ``LABEL_KEY`` to the
    path ``{BRATS_DIR}/{sid}/{sid}_{name}.nii.gz``.

    Args:
        subject_ids: iterable of subject identifiers; each one is used both as
            the sub-directory name under ``BRATS_DIR`` and as the file prefix.

    Returns:
        list[dict[str, str]]: one record per subject, in input order. An empty
        input yields an empty list.
    """
    # The label entry is keyed by LABEL_KEY (not a literal "seg") so that the
    # records always carry exactly the keys listed in IMG_KEYS, which is built
    # from MODALITIES + [LABEL_KEY].
    record_keys = (*MODALITIES, LABEL_KEY)
    records = []
    for sid in subject_ids:
        subject_dir = f"{BRATS_DIR}/{sid}"
        records.append(
            {name: f"{subject_dir}/{sid}_{name}.nii.gz" for name in record_keys}
        )
    return records

# -----------------------------------------------------------
# 3. transforms (unchanged) ---------------------------------
# -----------------------------------------------------------
IMG_KEYS = MODALITIES + [LABEL_KEY]


def _make_preprocessing():
    """Return the deterministic load/orient/scale pipeline shared by both splits.

    All keys in ``IMG_KEYS`` are loaded, made channel-first and re-oriented to
    RAS. Intensity scaling to [-1, 1] is applied to the image modalities only:
    the segmentation mask must keep its original label values, otherwise the
    label ids are rescaled before they reach ``preprocess_mask_labels``.
    """
    return Compose([
        LoadImaged(keys=IMG_KEYS),
        EnsureChannelFirstd(keys=IMG_KEYS),
        Orientationd(keys=IMG_KEYS, axcodes="RAS"),
        ScaleIntensityd(keys=MODALITIES, minv=-1.0, maxv=1.0),   # masks untouched
        SelectItemsd(keys=IMG_KEYS),
    ])


# Training previously scaled the mask as well (keys=IMG_KEYS), which disagreed
# with validation; both splits now use the same mask-preserving pipeline.
train_tf = _make_preprocessing()
val_tf   = _make_preprocessing()

# -----------------------------------------------------------
# 4. MONAI CacheDatasets ------------------------------------
# -----------------------------------------------------------
# ── client-wise training sets ───────────────────────────────
# CUT_OFF : highest centre id used for training in this run
# FRAC    : fraction of each centre's subjects that is sampled
# SEED    : seed of the private RNG, so the subject sample is reproducible
CUT_OFF, FRAC, SEED = 4, 0.3, 42
rng = random.Random(SEED)

train_datasets = {}
# Walk the centres in sorted order and skip (rather than break on) ids above
# the cap, so the selection and the RNG draw order do not depend on the
# insertion order of the partition dict.
for cid in sorted(train_partitions):
    if cid > CUT_OFF:                                    # keep your cap
        continue
    subj_ids = train_partitions[cid]
    k = max(1, int(len(subj_ids) * FRAC))                # e.g. 30 %
    sample_ids = rng.sample(subj_ids, k)
    train_datasets[cid] = CacheDataset(
        build_records(sample_ids), transform=train_tf, cache_rate=1
    )

# ── single validation dataset made from *all* val subjects ─
test_dataset = CacheDataset(
    build_records(val_subjects), transform=val_tf, cache_rate=1
)

print("train per-centre sizes:", {k: len(v) for k, v in train_datasets.items()})
# The hold-out set is stored in `test_dataset`; `val_dataset` was never defined.
print("validation size:", len(test_dataset))


Loading dataset:  52%|███████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                              | 80/153 [01:11<00:56,  1.30it/s]

In [None]:
# ----------------------------------------------------------------------------------
# 2. helper: Dice on whole/TC/ET averaged to a scalar --------------------------------
# ----------------------------------------------------------------------------------
@torch.no_grad()
def dice3(model, test_dataset):
    """Return the mean Dice score of ``model`` over ``test_dataset``.

    For every case the four modality tensors ("flair", "t1", "t1ce", "t2")
    are concatenated along dim 1 and fed to the model. The prediction is the
    one-hot encoding (3 classes) of the arg-max over dim 1. Per channel, Dice
    is ``2 * sum(pred * target) / (sum(pred + target) + 1e-6)`` over dims
    2, 3 and 4. The per-channel scores are averaged over all cases and then
    over the three channels.

    Relies on the module-level ``device``, ``DataLoader`` and
    ``preprocess_mask_labels``.

    Args:
        model: network called as ``model(img)``; it is switched to eval mode
            and left in that mode.
        test_dataset: dataset whose items provide the four modality keys and
            a "seg" key.

    Returns:
        float: scalar mean Dice.

    Raises:
        RuntimeError: if ``test_dataset`` is empty.
    """
    if len(test_dataset) == 0:
        # The old message interpolated an undefined VAL_DIR, so this guard
        # raised NameError instead of the intended RuntimeError.
        raise RuntimeError(
            "No validation cases found – test_dataset is empty; "
            "check the data directory and the partition file."
        )
    model.eval()
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    dice_sum = torch.zeros(3, device=device)
    for batch in loader:
        img = torch.cat(
            [batch[k] for k in ("flair", "t1", "t1ce", "t2")], 1
        ).to(device)
        raw = batch["seg"].squeeze(1).cpu().numpy()
        # assumes preprocess_mask_labels returns a 3-channel mask laid out like
        # the one-hot prediction below — TODO confirm against utils
        target = torch.tensor(preprocess_mask_labels(raw),
                              dtype=torch.float32, device=device)
        logits = model(img)
        # NOTE(review): arg-max makes the predicted channels mutually
        # exclusive; if the target channels are overlapping regions this
        # under-reports Dice — confirm the intended label encoding.
        pred = torch.nn.functional.one_hot(
            torch.argmax(logits, 1), num_classes=3
        ).permute(0, 4, 1, 2, 3).float()
        inter = 2 * (pred * target).sum((2, 3, 4))
        denom = (pred + target).sum((2, 3, 4)) + 1e-6   # epsilon avoids 0/0
        dice_sum += (inter / denom).squeeze(0)           # batch_size is 1

    return (dice_sum / len(loader)).mean().item()

# Untrained global model, moved to the selected device, plus a baseline score.
# presumably ResUNet3D(4, 3) means 4 input channels / 3 output channels —
# confirm against models.ResUNet3D
global_model = ResUNet3D(4, 3)
global_model = global_model.to(device)
print(
    "Dice before any training:",
    dice3(global_model, test_dataset),
)

In [None]:


# ----------------------------------------------------------------------------------
# 3. local update (single epoch, no AMP, no fancy stuff) -----------------------------
# ----------------------------------------------------------------------------------
def local_train(model, loader, lr=1e-4, ep=1):
    """Train ``model`` in place on one client's data and report the mean loss.

    Uses ``BCEDiceLoss`` and a fresh Adam optimiser. For each batch the four
    modality tensors ("flair", "t1", "t1ce", "t2") are concatenated along
    dim 1, and the "seg" entry is converted with ``preprocess_mask_labels``
    into the float32 training target.

    Args:
        model: network to update; it is put into training mode.
        loader: iterable of batches providing the four modality keys and "seg".
        lr: Adam learning rate.
        ep: number of passes over ``loader``.

    Returns:
        float: mean loss over all processed batches, or NaN when no batch was
        processed (``ep == 0`` or an empty loader). Earlier versions returned
        ``None``; callers that ignore the result are unaffected.
    """
    crit = BCEDiceLoss().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    loss_total, n_batches = 0.0, 0
    for _ in range(ep):
        for batch in loader:
            img = torch.cat(
                [batch[k] for k in ("flair", "t1", "t1ce", "t2")], 1
            ).to(device)
            # .cpu() mirrors dice3 and keeps .numpy() valid even if the batch
            # tensors do not live in host memory.
            raw = batch["seg"].squeeze(1).cpu().numpy()
            msk = torch.tensor(preprocess_mask_labels(raw),
                               dtype=torch.float32, device=device)
            opt.zero_grad()
            loss = crit(model(img), msk)
            loss.backward()
            opt.step()
            loss_total += loss.item()
            n_batches += 1

    if n_batches == 0:
        return float("nan")
    return loss_total / n_batches
# ----------------------------------------------------------------------------------
# 4. FedAvg training loop ------------------------------------------------------------
# ----------------------------------------------------------------------------------
# EPOCHS   : number of federated rounds
# LOCAL_EP : local epochs per client and round
# LR       : client learning rate
# BATCH    : client batch size
EPOCHS, LOCAL_EP, LR, BATCH = 2, 1, 1e-4, 2        # keep tiny for sanity run
idxs_users = list(train_datasets.keys())
sizes = {k: len(v) for k, v in train_datasets.items()}
# Each client's share of all training samples, in the same order as idxs_users.
total_size = sum(sizes.values())
frac = [sizes[k] / total_size for k in idxs_users]


for rnd in tqdm(range(1, EPOCHS + 1)):
    local_w = []
    for cid in idxs_users:
        loader = DataLoader(train_datasets[cid], batch_size=BATCH, shuffle=True)
        # Every client starts the round from a private copy of the global model.
        mdl = copy.deepcopy(global_model)
        local_train(mdl, loader, lr=LR, ep=LOCAL_EP)
        local_w.append(mdl.state_dict())
    # FedAvg
    # presumably average_weights combines the client state_dicts weighted by
    # `frac` — confirm against utils
    global_model.load_state_dict(average_weights(local_w, frac))
    # dice3 requires the evaluation dataset as its second argument; calling it
    # with the model alone raised TypeError at the end of the first round.
    print(f"Round {rnd:2d} – mean Dice: {dice3(global_model, test_dataset):.4f}")