In [1]:
import os, copy, time, random, torch, numpy as np                                 # ← your own
import glob
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader
from monai.data import CacheDataset
import glob, nibabel as nib, pandas as pd
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityd,
    RandFlipd, RandSpatialCropd, Compose, SelectItemsd
)

from utils import *
from models import *
  
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------
# 0. paths & meta-data (unchanged) ---------------------------
# -----------------------------------------------------------
BRATS_DIR = "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY  = "seg"

# -----------------------------------------------------------
# 1. read partition file  ➜  { id : [subjects] } ------------
# -----------------------------------------------------------
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)                               # keys are 1 … 23

VAL_CENTRES = {18, 19, 20, 21, 22, 23}          # ← our hold-out set
# VAL_CENTRES = {22, 23}          # ← our sanity set

# split once, reuse everywhere
train_partitions = {cid: sids for cid, sids in partition_map.items()
                    if cid not in VAL_CENTRES}
val_subjects     = sum((partition_map[cid] for cid in VAL_CENTRES), [])

# -----------------------------------------------------------
# 2. helper to build MONAI-style record dicts ----------------
# -----------------------------------------------------------
def build_records(subject_ids):
    recs = []
    for sid in subject_ids:
        sdir = f"{BRATS_DIR}/{sid}"
        rec  = {m: f"{sdir}/{sid}_{m}.nii.gz" for m in MODALITIES}
        rec["seg"] = f"{sdir}/{sid}_{LABEL_KEY}.nii.gz"
        recs.append(rec)
    return recs

# -----------------------------------------------------------
# 3. transforms (unchanged) ---------------------------------
# -----------------------------------------------------------
IMG_KEYS = MODALITIES + [LABEL_KEY]
train_tf = Compose([
    LoadImaged(keys=IMG_KEYS), EnsureChannelFirstd(keys=IMG_KEYS),
    Orientationd(keys=IMG_KEYS, axcodes="RAS"),
    ScaleIntensityd(keys=IMG_KEYS, minv=-1.0, maxv=1.0),
    SelectItemsd(keys=IMG_KEYS),
])
val_tf = Compose([
    LoadImaged(keys=IMG_KEYS), EnsureChannelFirstd(keys=IMG_KEYS),
    Orientationd(keys=IMG_KEYS, axcodes="RAS"),
    ScaleIntensityd(keys=MODALITIES, minv=-1.0, maxv=1.0),   # masks untouched
    SelectItemsd(keys=IMG_KEYS),
])

# -----------------------------------------------------------
# 4. MONAI CacheDatasets ------------------------------------
# -----------------------------------------------------------
# ── client-wise training sets ───────────────────────────────
CUT_OFF, FRAC, SEED = 18, 1, 42
rng = random.Random(SEED)

train_datasets = {}
for cid, subj_ids in train_partitions.items():
    if cid > CUT_OFF:                                    # keep your cap
        break
    k = max(1, int(len(subj_ids) * FRAC))                # e.g. 30 %
    sample_ids = rng.sample(subj_ids, k)
    train_datasets[cid] = CacheDataset(
        build_records(sample_ids), transform=train_tf, cache_rate=1
    )

# ── single validation dataset made from *all* val subjects ─
test_dataset = CacheDataset(
    build_records(val_subjects), transform=val_tf, cache_rate=1
)

print("train per-centre sizes:", {k: len(v) for k, v in train_datasets.items()})
print("validation size:", len(test_dataset))


  from .autonotebook import tqdm as notebook_tqdm
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 511/511 [12:28<00:00,  1.46s/it]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:07<00:00,  1.30s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:19<00:00,  1.32s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47/47 [01:08<00:00,  1.46s/it]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:30<00:00,  1.39s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████

train per-centre sizes: {1: 511, 2: 6, 3: 15, 4: 47, 5: 22, 6: 34, 7: 12, 8: 8, 9: 4, 10: 8, 11: 14, 12: 11, 13: 35, 14: 6, 15: 13, 16: 30, 17: 9}
validation size: 466





In [2]:
print("validation size:", len(test_dataset))


validation size: 466


In [16]:
from seg_models import *      # adjust path / PYTHONPATH
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ResUNet3D(4,3).to(device)

# quick dummy forward
with torch.no_grad():
    dummy = torch.randn(1, 4, 96, 96, 96, device=device)
    logits = model(dummy)
print("logits shape:", logits.shape)


logits shape: torch.Size([1, 3, 96, 96, 96])


In [17]:
# ----------------------------------------------------------------------------------
# 2. helper: Dice on whole/TC/ET averaged to a scalar --------------------------------
# ----------------------------------------------------------------------------------
@torch.no_grad()
def dice3(model, test_dataset):
    model.eval()
    if len(test_dataset) == 0:
        raise RuntimeError(f"No validation cases found – check {VAL_DIR} and glob pattern.")
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    dsum = torch.zeros(3, device=device)
    for batch in loader:
        img = torch.cat([batch[k] for k in ("flair","t1","t1ce","t2")],1).to(device)
        raw = batch["seg"].squeeze(1).cpu().numpy()
        target = torch.tensor(preprocess_mask_labels(raw),
                              dtype=torch.float32, device=device)
        logits = model(img)
        pred = torch.nn.functional.one_hot(
                   torch.argmax(logits,1), num_classes=3
               ).permute(0,4,1,2,3).float()
        inter = 2*(pred*target).sum((2,3,4))
        denom = (pred+target).sum((2,3,4))+1e-6
        dsum += (inter/denom).squeeze(0)


    return (dsum/len(loader)).mean().item()

global_model = ResUNet3D(4,3).to(device)
print("Dice before any training:", dice3(global_model, test_dataset))

Dice before any training: 0.011921419762074947


In [None]:
from tqdm.auto import tqdm, trange   # trange == tqdm(range())

# ────────────────────────────────────────────────────────────
# 1. one-client update (returns weights + mean loss)          │
# ────────────────────────────────────────────────────────────
def local_train(model, loader, lr=1e-4, epochs=1):
    crit = BCEDiceLoss().to(device)
    opt  = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    losses = []
    for _ in range(epochs):
        running = 0.0
        for img_dict in loader:
            img = torch.cat([img_dict[k] for k in ("flair", "t1", "t1ce", "t2")], 1).to(device)
            msk = preprocess_mask_labels(img_dict["seg"].squeeze(1).numpy())
            msk = torch.tensor(msk, dtype=torch.float32, device=device)

            opt.zero_grad()
            loss = crit(model(img), msk)
            loss.backward()
            opt.step()

            running += loss.item()
        losses.append(running / len(loader))          # epoch-mean

    return model.state_dict(), np.mean(losses)

# ────────────────────────────────────────────────────────────
# 2. FedAvg training loop (simple tqdm + clean prints)        │
# ────────────────────────────────────────────────────────────
EPOCHS, LOCAL_EPOCHS, LR, BATCH = 50, 1, 1e-4, 1          # dial as needed
idxs_users = list(train_datasets.keys())
sizes      = {k: len(ds) for k, ds in train_datasets.items()}
fractions  = [sizes[k] / sum(sizes.values()) for k in idxs_users]

global_model = ResUNet3D(4, 3).to(device)
print(f"Dice before training: {dice3(global_model, test_dataset):.4f}")

for rnd in trange(1, EPOCHS + 1, desc="Global rounds"):
    local_weights, client_losses = [], []

    for cid in tqdm(idxs_users, desc=" clients", leave=False):
        loader = DataLoader(train_datasets[cid], batch_size=BATCH, shuffle=True)
        # deep-copy so each user starts from the same global weights
        w, loss = local_train(copy.deepcopy(global_model), loader,
                              lr=LR, epochs=LOCAL_EPOCHS)
        local_weights.append(w)
        client_losses.append(loss)

    # FedAvg
    global_model.load_state_dict(average_weights(local_weights, fractions))

    mean_loss = np.mean(client_losses)
    mean_dice = dice3(global_model, test_dataset)
    print(f"Round {rnd:02d}:  mean-loss = {mean_loss:.4f}   mean-Dice = {mean_dice:.4f}")
