<a href="https://colab.research.google.com/github/vs-152/FL-Contributions-Incentives-Project/blob/main/ISO_CIFAR10_OR_FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ─────────────────────────────────────────────────────────────
#  Imports
# ─────────────────────────────────────────────────────────────
import os
import copy
import time
import glob
import shutil
import tempfile
from itertools import chain, combinations

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedShuffleSplit
from scipy.special import comb
import matplotlib.pyplot as plt
from tqdm import tqdm
import nibabel as nib
import pulp
import onnxruntime
import random

# ─────────────────────────────────────────────────────────────
#  MONAI
# ─────────────────────────────────────────────────────────────
from monai.config import print_config
from monai.utils import set_determinism
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.apps import DecathlonDataset
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    ScaleIntensityd,
    Spacingd,
    SelectItemsd
)

# ─────────────────────────────────────────────────────────────
#  Custom Modules
# ─────────────────────────────────────────────────────────────
from utils import *

# ─────────────────────────────────────────────────────────────
#  Device & Setup
# ─────────────────────────────────────────────────────────────
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print_config()
set_determinism(seed=0)


[0;93m2025-11-17 16:17:35.717494114 [W:onnxruntime:Default, device_discovery.cc:164 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:89 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"[m


MONAI version: 1.6.dev2542
Numpy version: 2.1.2
Pytorch version: 2.8.0+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 612f3dd3cba4d73cfcea4b5329079e20aa31523d
MONAI __file__: /home/<username>/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.3
Pillow version: 11.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.23.0+cu126
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.3.2
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For d

In [2]:
# Corrected conversion for FeTS labels
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    FeTS/BraTS label mapping (ints on disk): 0=background, 1=NCR/NET, 2=edema, 4=enhancing (ET)
    Build 3-channel multi-label [TC, WT, ET]:
      TC = (label==1) OR (label==4)
      WT = (label==1) OR (label==2) OR (label==4)
      ET = (label==4)
    """
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            lab = d[key]
            tc = torch.logical_or(lab == 1, lab == 4)
            wt = torch.logical_or(torch.logical_or(lab == 1, lab == 2), lab == 4)
            et = (lab == 4)
            d[key] = torch.stack([tc, wt, et], dim=0).float()
        return d


train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------
# 0. paths & meta-data (unchanged) ---------------------------
# -----------------------------------------------------------
BRATS_DIR = "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY  = "seg"

# -----------------------------------------------------------
# 1. read partition file  ➜  { id : [subjects] } ------------
# -----------------------------------------------------------
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)                               # keys are 1 … 23

VAL_CENTRES = {16, 17, 18, 19, 20, 21, 22, 23}          # ← our hold-out set
# VAL_CENTRES = {22, 23}          # ← our sanity set

# split once, reuse everywhere
train_partitions = {cid: sids for cid, sids in partition_map.items()
                    if cid not in VAL_CENTRES}
val_subjects     = sum((partition_map[cid] for cid in VAL_CENTRES), [])

# -----------------------------------------------------------
# 2. helper to build MONAI-style record dicts ----------------
# -----------------------------------------------------------

def build_records(subject_ids):
    recs = []
    for sid in subject_ids:
        sdir = f"{BRATS_DIR}/{sid}"
        images = [f"{sdir}/{sid}_{m}.nii.gz" for m in MODALITIES]  # 4 modalities
        recs.append({"image": images, "label": f"{sdir}/{sid}_{LABEL_KEY}.nii.gz"})
    return recs


# -----------------------------------------------------------
# 3. MONAI CacheDatasets ------------------------------------
# -----------------------------------------------------------
# ── client-wise training sets ───────────────────────────────
CUT_OFF, FRAC, SEED = 16, 1.0, 42
rng = random.Random(SEED)

train_datasets = {}
for cid, subj_ids in train_partitions.items():
    if cid > CUT_OFF:                                    # keep your cap
        break
    k = max(1, int(len(subj_ids) * FRAC))                # e.g. 30 %
    sample_ids = rng.sample(subj_ids, k)
    train_datasets[cid] = CacheDataset(
        build_records(sample_ids), transform=train_transform, cache_rate=1
    )

# ── single validation dataset made from *all* val subjects ─
val_dataset = CacheDataset(
    build_records(val_subjects), transform=val_transform, cache_rate=1
)

print("train per-centre sizes:", {k: len(v) for k, v in train_datasets.items()})
print("validation size:", len(val_dataset))

Loading dataset: 100%|████████████████████████████████████████████████████████████████| 511/511 [19:48<00:00,  2.33s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████| 6/6 [00:16<00:00,  2.77s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 15/15 [00:33<00:00,  2.21s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 47/47 [02:06<00:00,  2.69s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 22/22 [00:53<00:00,  2.44s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 34/34 [01:18<00:00,  2.30s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 12/12 [00:29<00:00,  2.43s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████| 8/8 [00:20<00:00,  2.53s/it]
Loading dataset: 100%|██████████

train per-centre sizes: {1: 511, 2: 6, 3: 15, 4: 47, 5: 22, 6: 34, 7: 12, 8: 8, 9: 4, 10: 8, 11: 14, 12: 11, 13: 35, 14: 6, 15: 13}
validation size: 505





In [4]:
max_epochs = 300
val_interval = 1
VAL_AMP = True

# create SegResNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
global_model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(global_model.parameters(), 1e-4, weight_decay=1e-5)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# use amp to accelerate training
scaler = torch.GradScaler("cuda")
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

In [5]:
def evaluate_model(model, dataset, device, batch_size=1,
                   roi_size=(128, 128, 64), sw_batch_size=4):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    dice_metric.reset()
    dice_metric_batch.reset()
    model.eval()

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)          # [B, 3, D, H, W]

            logits = sliding_window_inference(
                inputs=inputs,
                roi_size=roi_size,
                sw_batch_size=sw_batch_size,
                predictor=model,                        # ← use THIS model
            )

            preds = torch.sigmoid(logits)
            preds = (preds > 0.5).float()

            dice_metric(y_pred=preds, y=labels)
            dice_metric_batch(y_pred=preds, y=labels)

    mean_dice = dice_metric.aggregate().item()
    metric_batch = dice_metric_batch.aggregate()
    metric_tc = metric_batch[0].item()
    metric_wt = metric_batch[1].item()
    metric_et = metric_batch[2].item()
    dice_metric.reset()
    dice_metric_batch.reset()

    return mean_dice, metric_tc, metric_wt, metric_et
    
print("Dice before any training:", evaluate_model(global_model, val_dataset, device)) # quick sanity check


Evaluating:   0%|                                                                               | 0/505 [00:00<?, ?it/s]Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 505/505 [03:41<00:00,  2.28it/s]


Dice before any training: (0.022014198824763298, 0.02844342216849327, 0.021195057779550552, 0.016651295125484467)


In [6]:
# ─────────────────────────────────────────────────────────────
#  Federation setup
# ─────────────────────────────────────────────────────────────
# train_datasets: dict[int -> MONAI CacheDataset]  (already built)
idxs_users = list(sorted(train_datasets.keys()))
N = len(idxs_users)
print(f"We got {N} clients")

# Fed hyperparams (align with your working pipeline)
ROUNDS       = 50            # you can raise later (e.g., 100)
LOCAL_EPOCHS = 1
LR           = 1e-4
BATCH        = 1

# Client sizes & FedAvg fractions
sizes     = {k: len(ds) for k, ds in train_datasets.items()}
total_n   = sum(sizes.values())
fractions = [sizes[k] / total_n for k in idxs_users]

# Where to persist submodels / global snapshots
submodel_dir = "submodels"
os.makedirs(submodel_dir, exist_ok=True)
submodel_file_template = os.path.join(submodel_dir, "submodel_{}.pth")
global_model_path      = os.path.join(submodel_dir, "global_model.pth")
best_model_path        = os.path.join(submodel_dir, "best_metric_model.pth")

# Save initial global (round 0) – useful for baselines
torch.save(global_model.state_dict(), global_model_path)

# For later Shapley steps
accuracy_dict = {}     # coalition -> utility (e.g., Dice on test set)
shapley_dict  = {}     # client -> shapley value (to be filled later)

# fast sanity check before any training
print("Dice before any training:", evaluate_model(global_model, val_dataset, device))


We got 15 clients


Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 505/505 [03:37<00:00,  2.32it/s]

Dice before any training: (0.022014198824763298, 0.02844342216849327, 0.021195057779550552, 0.016651295125484467)





In [7]:
from tqdm.auto import tqdm, trange   # trange == tqdm(range())
from collections import OrderedDict

def average_weights(state_dicts, fractions):
    """
    Federated averaging with client fractions (must sum to 1).
    state_dicts: list of state_dicts (same keys)
    fractions:   list of floats, same length, sum≈1
    """
    avg_sd = OrderedDict()
    for k in state_dicts[0].keys():
        avg = 0.0
        for sd, w in zip(state_dicts, fractions):
            avg += sd[k] * w
        avg_sd[k] = avg
    return avg_sd

# ────────────────────────────────────────────────────────────
# 1. one-client update (returns weights + mean loss)          │
# ────────────────────────────────────────────────────────────
def local_train(model, loader, device, lr=1e-4, epochs=1):
    """
    Train a local copy of the global model on one client's DataLoader.
    Uses your DiceLoss (multi-label, sigmoid) and full crops from transforms.
    """
    model = copy.deepcopy(model).to(device)
    model.train()

    # reuse your loss choice; or inline DiceLoss the same way
    crit = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
                    to_onehot_y=False, sigmoid=True).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []

    for _ in range(epochs):
        running = 0.0
        for batch in loader:
            img = batch["image"].to(device)   # [B, 4, D, H, W]
            msk = batch["label"].to(device)   # [B, 3, D, H, W]

            opt.zero_grad(set_to_none=True)
            logits = model(img)               # [B, 3, D, H, W]
            loss = crit(logits, msk)
            loss.backward()
            opt.step()

            running += float(loss.item())
        epoch_losses.append(running / max(1, len(loader)))

    return model.state_dict(), float(np.mean(epoch_losses))


# ─────────────────────────────────────────────────────────────
#  FedAvg training loop (with per-client snapshots each round)
# ─────────────────────────────────────────────────────────────
from tqdm.auto import tqdm, trange
from collections import OrderedDict

best_metric = -1
best_metric_round = -1
best_metrics_rounds_and_time = [[], [], []]   # best, round, seconds
round_loss_values = []
metric_values     = []
metric_values_tc  = []
metric_values_wt  = []
metric_values_et  = []

patience      = 5      # stop after 5 rounds with no improvement
no_improve    = 0
start_time    = time.time()
last_round_run = 0     # track actual last round (for logging)

for rnd in trange(1, ROUNDS + 1, desc="Global rounds", position=0, leave=True, dynamic_ncols=True):
    local_weights, client_losses = [], []

    # —— local updates per client ——
    for cid in tqdm(idxs_users, desc=" clients", position=1, leave=False, total=len(idxs_users), dynamic_ncols=True):
        loader = DataLoader(
            train_datasets[cid], batch_size=BATCH, shuffle=True,
            num_workers=4, pin_memory=True
        )
        w, loss = local_train(global_model, loader, device, lr=LR, epochs=LOCAL_EPOCHS)
        local_weights.append(w); client_losses.append(loss)

        # Persist this client's *latest* local model for Shapley / ablations
        torch.save(w, submodel_file_template.format(cid))

    # —— FedAvg (fraction-weighted) ——
    global_model.load_state_dict(average_weights(local_weights, fractions))

    # —— validation metrics on your current pipeline ——
    mean_dice, metric_tc, metric_wt, metric_et = evaluate_model(global_model, val_dataset, device)
    metric_values.append(mean_dice)
    metric_values_tc.append(metric_tc)
    metric_values_wt.append(metric_wt)
    metric_values_et.append(metric_et)

    mean_loss = float(np.mean(client_losses))
    round_loss_values.append(mean_loss)

    # —— track best & save ——
    if mean_dice > best_metric:
        best_metric = mean_dice
        best_metric_round = rnd
        best_metrics_rounds_and_time[0].append(best_metric)
        best_metrics_rounds_and_time[1].append(best_metric_round)
        best_metrics_rounds_and_time[2].append(time.time() - start_time)
        torch.save(global_model.state_dict(), best_model_path)
        print("saved new best metric model")
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience:
            print(f"Early stopping triggered at round {rnd} (no improvement for {patience} rounds).")
            break

    tqdm.write(
        f"Round {rnd:02d}: mean-loss={mean_loss:.4f} "
        f"mean-Dice={mean_dice:.4f}  "
        f"TC-Dice={metric_tc:.4f}  WT-Dice={metric_wt:.4f}  ET-Dice={metric_et:.4f}"
    )

# ── final val utility for the “grand coalition” (all clients) ─────────────
val_mean_dice, val_tc, val_wt, val_et = evaluate_model(global_model, val_dataset, device)
print(f"\nResults after {ROUNDS} global rounds:")
print(f"|---- Val Dice(mean): {val_mean_dice:.4f} | TC {val_tc:.4f} | WT {val_wt:.4f} | ET {val_et:.4f}")

# Store utility for coalition = all clients (tuple keeps order deterministic)
accuracy_dict[tuple(idxs_users)] = val_mean_dice



Global rounds:   0%|                                                                             | 0/50 [00:00…

 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]

saved new best metric model
Round 01: mean-loss=0.9413 mean-Dice=0.2763  TC-Dice=0.2632  WT-Dice=0.3711  ET-Dice=0.1961


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]

saved new best metric model
Round 02: mean-loss=0.8858 mean-Dice=0.4072  TC-Dice=0.3937  WT-Dice=0.4966  ET-Dice=0.3338


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]

saved new best metric model
Round 03: mean-loss=0.8346 mean-Dice=0.4843  TC-Dice=0.4470  WT-Dice=0.5792  ET-Dice=0.4291


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 04: mean-loss=0.7545 mean-Dice=0.6090  TC-Dice=0.5810  WT-Dice=0.6392  ET-Dice=0.6099


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 05: mean-loss=0.6483 mean-Dice=0.6870  TC-Dice=0.6855  WT-Dice=0.6921  ET-Dice=0.6867


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 06: mean-loss=0.5276 mean-Dice=0.7013  TC-Dice=0.7234  WT-Dice=0.7099  ET-Dice=0.6739


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 07: mean-loss=0.4187 mean-Dice=0.6419  TC-Dice=0.6541  WT-Dice=0.6299  ET-Dice=0.6455


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]

saved new best metric model
Round 08: mean-loss=0.3505 mean-Dice=0.7033  TC-Dice=0.7248  WT-Dice=0.6993  ET-Dice=0.6894


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 09: mean-loss=0.2932 mean-Dice=0.7239  TC-Dice=0.7421  WT-Dice=0.7482  ET-Dice=0.6837


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 10: mean-loss=0.2582 mean-Dice=0.6752  TC-Dice=0.6542  WT-Dice=0.7504  ET-Dice=0.6217


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 11: mean-loss=0.2418 mean-Dice=0.7219  TC-Dice=0.7317  WT-Dice=0.7570  ET-Dice=0.6786


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 12: mean-loss=0.2350 mean-Dice=0.6640  TC-Dice=0.6606  WT-Dice=0.7239  ET-Dice=0.6076


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 13: mean-loss=0.2226 mean-Dice=0.7469  TC-Dice=0.7807  WT-Dice=0.7218  ET-Dice=0.7416


 clients:   0%|                                                                                  | 0/15 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f61543bd3f0>
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f61543bd3f0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 14: mean-loss=0.2349 mean-Dice=0.7153  TC-Dice=0.7246  WT-Dice=0.7593  ET-Dice=0.6634


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 15: mean-loss=0.2247 mean-Dice=0.7374  TC-Dice=0.7385  WT-Dice=0.7847  ET-Dice=0.6904


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 16: mean-loss=0.2151 mean-Dice=0.7499  TC-Dice=0.7570  WT-Dice=0.7870  ET-Dice=0.7068


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 17: mean-loss=0.2217 mean-Dice=0.7096  TC-Dice=0.7148  WT-Dice=0.7456  ET-Dice=0.6695


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 18: mean-loss=0.2201 mean-Dice=0.7492  TC-Dice=0.7572  WT-Dice=0.7885  ET-Dice=0.7036


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 19: mean-loss=0.2146 mean-Dice=0.7641  TC-Dice=0.7914  WT-Dice=0.7666  ET-Dice=0.7364


 clients:   0%|                                                                                  | 0/15 [00:00…

IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 20: mean-loss=0.2095 mean-Dice=0.7991  TC-Dice=0.8295  WT-Dice=0.7947  ET-Dice=0.7762


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 21: mean-loss=0.2231 mean-Dice=0.7598  TC-Dice=0.7768  WT-Dice=0.7751  ET-Dice=0.7300


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 22: mean-loss=0.2151 mean-Dice=0.8007  TC-Dice=0.8273  WT-Dice=0.8022  ET-Dice=0.7766


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 23: mean-loss=0.2050 mean-Dice=0.7494  TC-Dice=0.7436  WT-Dice=0.7919  ET-Dice=0.7150


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 24: mean-loss=0.2190 mean-Dice=0.7649  TC-Dice=0.7728  WT-Dice=0.8003  ET-Dice=0.7234


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 25: mean-loss=0.2147 mean-Dice=0.7372  TC-Dice=0.7515  WT-Dice=0.7766  ET-Dice=0.6856


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 26: mean-loss=0.2133 mean-Dice=0.7336  TC-Dice=0.7415  WT-Dice=0.7842  ET-Dice=0.6770


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

saved new best metric model
Round 27: mean-loss=0.2015 mean-Dice=0.8039  TC-Dice=0.8244  WT-Dice=0.8231  ET-Dice=0.7669


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 28: mean-loss=0.2069 mean-Dice=0.6409  TC-Dice=0.6756  WT-Dice=0.6317  ET-Dice=0.6223


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 29: mean-loss=0.2188 mean-Dice=0.7943  TC-Dice=0.8214  WT-Dice=0.8110  ET-Dice=0.7531


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 30: mean-loss=0.2049 mean-Dice=0.7824  TC-Dice=0.8001  WT-Dice=0.8039  ET-Dice=0.7462


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Round 31: mean-loss=0.2053 mean-Dice=0.7799  TC-Dice=0.8005  WT-Dice=0.8103  ET-Dice=0.7313


 clients:   0%|                                                                                  | 0/15 [00:00…

Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

Early stopping triggered at round 32 (no improvement for 5 rounds).


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Results after 50 global rounds:
|---- Val Dice(mean): 0.8011 | TC 0.8273 | WT 0.8034 | ET 0.7767


In [None]:
# ── deterministic powerset over your sorted client IDs (exclude empty, include all) ──
clients = list(sorted(idxs_users))                       # e.g., [1,2,3,...] or arbitrary ints
powerset = [tuple(s) for r in range(1, len(clients)+1)   # nonempty subsets only
            for s in combinations(clients, r)]

# Build a client -> fraction map (global FedAvg fractions you already computed)
fraction_map = {cid: frac for cid, frac in zip(clients, fractions)}

# Helper: renormalize fractions within a subset so they sum to 1
def subset_weights_and_fracs(subset):
    w_list = [torch.load(submodel_file_template.format(cid)) for cid in subset]
    raw_fracs = np.array([fraction_map[cid] for cid in subset], dtype=float)
    raw_sum = float(raw_fracs.sum())
    if raw_sum <= 0:
        # fallback to uniform if something odd happens
        norm_fracs = [1.0 / len(subset)] * len(subset)
    else:
        norm_fracs = (raw_fracs / raw_sum).tolist()
    return w_list, norm_fracs

# Evaluate every proper coalition (exclude the grand coalition at first)
# If you want all, use `powerset`; if you want proper only, do `powerset[:-1]` as you had.
for subset in powerset[:-1]:
    # 1) aggregate weights
    if len(subset) == 1:
        subset_sd = torch.load(submodel_file_template.format(subset[0]))
    else:
        w_list, norm_fracs = subset_weights_and_fracs(subset)
        subset_sd = average_weights(w_list, norm_fracs)

    # 2) build a model with identical arch/buffers and load weights
    submodel = copy.deepcopy(global_model).to(device)
    submodel.load_state_dict(subset_sd)
    submodel.eval()

    # 3) evaluate with your current pipeline’s evaluator
    mean_dice, metric_tc, metric_wt, metric_et = evaluate_model(submodel, val_dataset, device)

    # 4) record utility
    accuracy_dict[subset] = float(mean_dice)

    print(f"\nCoalition {subset}: mean Val Dice={mean_dice:.4f} | TC={metric_tc:.4f} | WT={metric_wt:.4f} | ET={metric_et:.4f}")

    # free promptly
    del submodel
    torch.cuda.empty_cache()

# Optionally ensure the grand coalition utility is present (you stored it earlier, but just in case)
grand = tuple(clients)
if grand not in accuracy_dict:
    m, tc, wt, et = evaluate_model(global_model, val_dataset, device)
    accuracy_dict[grand] = float(m)



Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (1,): mean Val Dice=0.7566 | TC=0.7904 | WT=0.7422 | ET=0.7445


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (2,): mean Val Dice=0.8264 | TC=0.8627 | WT=0.8064 | ET=0.8141


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (3,): mean Val Dice=0.7991 | TC=0.8200 | WT=0.8093 | ET=0.7708


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (4,): mean Val Dice=0.7606 | TC=0.7708 | WT=0.7953 | ET=0.7180


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (5,): mean Val Dice=0.7591 | TC=0.7987 | WT=0.8146 | ET=0.6656


Evaluating:   0%|          | 0/505 [00:02<?, ?it/s]


Coalition (6,): mean Val Dice=0.8092 | TC=0.8338 | WT=0.8116 | ET=0.7850


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (7,): mean Val Dice=0.7158 | TC=0.7138 | WT=0.7619 | ET=0.6734


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (8,): mean Val Dice=0.6095 | TC=0.6096 | WT=0.6627 | ET=0.5594


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]


Coalition (9,): mean Val Dice=0.8296 | TC=0.8582 | WT=0.8292 | ET=0.8048


Evaluating:   0%|          | 0/505 [00:01<?, ?it/s]

In [None]:
# Sanity check
# Confirm submodels differ (they should):
import torch

def l2_diff(sd1, sd2):
    s = 0.0 
    for k in sd1:
        s += torch.sum((sd1[k] - sd2[k])**2).item()
    return s

w1 = torch.load(submodel_file_template.format(clients[0]))
w2 = torch.load(submodel_file_template.format(clients[1]))
print("L2 diff between client 1 and 2:", l2_diff(w1, w2))


In [None]:
# utility of the empty coalition is 0 by definition
accuracy_dict[()] = 0.0

# ── Shapley & Least Core ─────────────────────────────────────────────────
trainTime = time.time() - start_time
start_time = time.time()
shapley_dict = shapley(accuracy_dict, len(clients))
shapTime = time.time() - start_time

start_time = time.time()
lc_dict = least_core(accuracy_dict, len(clients))
LCTime = time.time() - start_time

totalShapTime = trainTime + shapTime
totalLCTime   = trainTime + LCTime
print(f"\n Grand-coalition validation utility (Dice): {accuracy_dict[grand]:.4f}")
print(f" Total Time Shapley: {totalShapTime:0.4f}s")
print(f" Total Time LC:      {totalLCTime:0.4f}s")


In [None]:
accuracy_dict

In [None]:
print("Shapley allocation:")
for cid, phi in shapley_dict.items():
    print(f" client {cid}: {phi:.4f}")

print("\nLeast-Core allocation:")
for var in lc_dict.variables():
    if var.name.startswith("x("):
        print(f" client {var.name[2:-1]}: {var.value():.4f}")
print(f" e (slack): {lc_dict.variablesDict()['e'].value():.4f}")


In [None]:
def stats(vector):
    n = len(vector)
    egal = np.array([1/n for i in range(n)])
    normalised = np.array(vector / vector.sum())
    msg = f'Original vector: {vector}\n'
    msg += f'Normalised vector: {normalised}\n'
    msg += f'Max Dif: {normalised.max()-normalised.min()}\n'
    msg += f'Distance: {np.linalg.norm(normalised-egal)}\n'

    msg += f'Budget: {vector.sum()}\n'
    print(msg)

In [None]:
stats(np.array(list(shapley_dict.values())))

In [None]:
stats(np.array([i.value() for i in lc_dict.variables()])[1:])