In [1]:
# ─────────────────────────────────────────────────────────────
#  Imports (minimal for post-hoc Shapley eval)
# ─────────────────────────────────────────────────────────────
import os
import random
from itertools import combinations

import numpy as np
import pandas as pd
import torch

from monai.config import print_config
from monai.utils import set_determinism
from monai.data import CacheDataset
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    NormalizeIntensityd,
    Spacingd,
    MapTransform,
)
from monai.networks.nets import SegResNet

from utils import shapley, least_core  # used inside run_shapley_eval if you keep that design

# Use the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Print the MONAI environment report, then request deterministic behaviour (seed 0).
print_config()
set_determinism(seed=0)

# ─────────────────────────────────────────────────────────────
#  Label converter (same as training)
# ─────────────────────────────────────────────────────────────
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Turn an integer FeTS/BraTS label map into three binary region channels.

    Label values on disk: 0=background, 1=NCR/NET, 2=edema, 4=enhancing (ET).
    The result stacks three masks along a new leading axis, in this order:
      channel 0 -> TC = label in {1, 4}
      channel 1 -> WT = label in {1, 2, 4}
      channel 2 -> ET = label == 4
    The stacked masks are returned as a float tensor.
    """
    def __call__(self, data):
        # Shallow copy so the mapping handed in by the caller is left untouched.
        out = dict(data)
        for name in self.keys:
            label = out[name]
            # One boolean mask per raw label value that participates in a region.
            is_necrotic = label == 1
            is_edema = label == 2
            is_enhancing = label == 4
            tumour_core = torch.logical_or(is_necrotic, is_enhancing)
            whole_tumour = torch.logical_or(
                torch.logical_or(is_necrotic, is_edema), is_enhancing
            )
            out[name] = torch.stack(
                (tumour_core, whole_tumour, is_enhancing), dim=0
            ).float()
        return out

# ─────────────────────────────────────────────────────────────
#  Paths & basic meta
# ─────────────────────────────────────────────────────────────
# Dataset root. Export FETS_DATA_DIR to point the notebook at another copy of
# the data without editing this cell; when the variable is unset or empty the
# original location below is used.
BRATS_DIR = (
    os.environ.get("FETS_DATA_DIR")
    or "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
)
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"   # partition table read below
MODALITIES = ["flair", "t1", "t1ce", "t2"]      # per-subject image file suffixes
LABEL_KEY  = "seg"                              # per-subject label file suffix

VAL_CENTRES = {16, 17, 18, 19, 20, 21, 22, 23}   # same as training
CUT_OFF, FRAC, SEED = 16, 1.0, 42               # must match training script

# ─────────────────────────────────────────────────────────────
#  Partition map and train/val split (no train CacheDataset)
# ─────────────────────────────────────────────────────────────
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)   # Partition_ID -> list of Subject_ID (keys are 1 … 23)

# Fail early, with a readable message, if a validation centre is not in the CSV.
missing_val_centres = sorted(VAL_CENTRES.difference(partition_map))
if missing_val_centres:
    raise KeyError(
        f"validation centres missing from {CSV_PATH}: {missing_val_centres}"
    )

train_partitions = {cid: sids for cid, sids in partition_map.items()
                    if cid not in VAL_CENTRES}

# Flatten the validation centres in ascending centre order. Iterating the
# sorted IDs keeps the subject order independent of set iteration order, and
# a flat comprehension avoids the quadratic copying of sum(lists, []).
val_subjects = [
    sid
    for cid in sorted(VAL_CENTRES)
    for sid in partition_map[cid]
]

# clients actually used in training (<= CUT_OFF)
idxs_users = sorted(cid for cid in train_partitions if cid <= CUT_OFF)

# recompute *sizes* exactly as in training: k = max(1, int(len(subj_ids)*FRAC))
sizes = {
    cid: max(1, int(len(train_partitions[cid]) * FRAC))
    for cid in idxs_users
}

# Per-client share of the total sample count, in the same order as idxs_users.
total_n = sum(sizes.values())
fractions = [sizes[cid] / total_n for cid in idxs_users]

print("clients:", idxs_users)
print("sizes:", sizes)
print("fractions:", fractions)

# ─────────────────────────────────────────────────────────────
#  Validation dataset only
# ─────────────────────────────────────────────────────────────
# Validation preprocessing, listed in the order the steps are applied.
_PAIRED_KEYS = ["image", "label"]
_val_steps = [
    LoadImaged(keys=_PAIRED_KEYS),
    EnsureChannelFirstd(keys="image"),
    EnsureTyped(keys=_PAIRED_KEYS),
    # Label map -> 3-channel [TC, WT, ET] masks (class defined in this notebook).
    ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    Orientationd(keys=_PAIRED_KEYS, axcodes="RAS"),
    # Resample to 1.0 spacing on every axis; the two modes pair up with the
    # two keys (bilinear for "image", nearest for "label").
    Spacingd(
        keys=_PAIRED_KEYS,
        pixdim=(1.0, 1.0, 1.0),
        mode=("bilinear", "nearest"),
    ),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
val_transform = Compose(_val_steps)

def build_records(subject_ids, *, data_dir=None, modalities=None, label_key=None):
    """
    Build one {"image": [...], "label": ...} path record per subject.

    Args:
        subject_ids: iterable of subject identifiers; each one is used both as
            the sub-directory name and as the file-name prefix.
        data_dir: dataset root; defaults to the notebook-level BRATS_DIR.
        modalities: iterable of image suffixes; defaults to MODALITIES.
        label_key: label file suffix; defaults to LABEL_KEY.

    Returns:
        A list of dicts. "image" holds one ".nii.gz" path per modality (in the
        order given) and "label" holds the single label path. An empty
        ``subject_ids`` yields an empty list.
    """
    # Resolve the fall-backs at call time so explicit arguments never touch the
    # notebook globals.
    root = BRATS_DIR if data_dir is None else data_dir
    # Materialise once so a one-shot iterable is not exhausted after subject 1.
    suffixes = tuple(MODALITIES if modalities is None else modalities)
    seg_suffix = LABEL_KEY if label_key is None else label_key

    records = []
    for sid in subject_ids:
        case_dir = f"{root}/{sid}"
        records.append(
            {
                "image": [f"{case_dir}/{sid}_{m}.nii.gz" for m in suffixes],
                "label": f"{case_dir}/{sid}_{seg_suffix}.nii.gz",
            }
        )
    return records

# Build the validation records and cache all of them (cache_rate=1) with the
# validation transform attached.
val_dataset = CacheDataset(
    data=build_records(val_subjects),
    transform=val_transform,
    cache_rate=1,
)
print("validation size:", len(val_dataset))

# ─────────────────────────────────────────────────────────────
#  Global model architecture + loading best weights
# ─────────────────────────────────────────────────────────────
# Network definition: 4 input channels (one per modality), 3 output channels
# (one per region mask). NOTE(review): hyper-parameters presumably mirror the
# training script — confirm before changing any of them.
global_model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

# Checkpoint locations; the "{}" in the template is filled in by the consumer
# of submodel_file_template.
submodel_dir = "submodels"
submodel_file_template = os.path.join(submodel_dir, "submodel_{}.pth")
best_model_path = os.path.join(submodel_dir, "best_metric_model.pth")

# weights_only=True makes the restricted unpickler explicit, so the checkpoint
# cannot execute arbitrary pickled code even on PyTorch builds where this is
# not yet the default.
global_model.load_state_dict(
    torch.load(best_model_path, map_location=device, weights_only=True)
)
global_model.eval()  # inference mode (e.g. disables dropout)

print("Loaded best global model from:", best_model_path)


[0;93m2025-11-18 15:58:59.994504266 [W:onnxruntime:Default, device_discovery.cc:164 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:89 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"[m


MONAI version: 1.6.dev2542
Numpy version: 2.1.2
Pytorch version: 2.8.0+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 612f3dd3cba4d73cfcea4b5329079e20aa31523d
MONAI __file__: /home/<username>/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.3
Pillow version: 11.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.23.0+cu126
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.3.2
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For d

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


clients: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
sizes: {1: 511, 2: 6, 3: 15, 4: 47, 5: 22, 6: 34, 7: 12, 8: 8, 9: 4, 10: 8, 11: 14, 12: 11, 13: 35, 14: 6, 15: 13}
fractions: [0.6849865951742627, 0.00804289544235925, 0.020107238605898123, 0.06300268096514745, 0.029490616621983913, 0.045576407506702415, 0.0160857908847185, 0.010723860589812333, 0.005361930294906166, 0.010723860589812333, 0.01876675603217158, 0.014745308310991957, 0.04691689008042895, 0.00804289544235925, 0.01742627345844504]


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 505/505 [31:29<00:00,  3.74s/it]


validation size: 505
Loaded best global model from: submodels/best_metric_model.pth


In [None]:
from shapley_eval import run_shapley_eval

# Gather the evaluation inputs first so the call itself stays a one-liner.
_shapley_eval_kwargs = dict(
    global_model=global_model,
    val_dataset=val_dataset,
    idxs_users=idxs_users,                          # same list you used in training
    fractions=fractions,                            # same FedAvg fractions as training
    submodel_file_template=submodel_file_template,  # e.g. "submodels/submodel_{}.pth"
    device=device,
    coalition_csv="coalition_utilities.csv",
    allocation_csv="allocation_summary.csv",
)

accuracy_dict, shapley_dict, lc_dict = run_shapley_eval(**_shapley_eval_kwargs)


Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]

Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)



Coalition (1,): mean Val Dice=0.7566 | TC=0.7904 | WT=0.7422 | ET=0.7445


Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]


Coalition (2,): mean Val Dice=0.8264 | TC=0.8627 | WT=0.8064 | ET=0.8141


Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]


Coalition (3,): mean Val Dice=0.7991 | TC=0.8200 | WT=0.8093 | ET=0.7708


Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]


Coalition (4,): mean Val Dice=0.7606 | TC=0.7708 | WT=0.7953 | ET=0.7179


Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]


Coalition (5,): mean Val Dice=0.7591 | TC=0.7988 | WT=0.8146 | ET=0.6656


Evaluating:   0%|          | 0/505 [00:00<?, ?it/s]