In [1]:
# ─────────────────────────────────────────────────────────────
#  Imports (minimal for post-hoc Shapley eval)
# ─────────────────────────────────────────────────────────────
import os
import random
from itertools import combinations

import numpy as np
import pandas as pd
import torch

from monai.config import print_config
from monai.utils import set_determinism
from monai.data import CacheDataset
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    NormalizeIntensityd,
    Spacingd,
    MapTransform,
)
from monai.networks.nets import SegResNet

from utils import shapley, least_core  # used inside run_shapley_eval if you keep that design

# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Log MONAI / dependency versions so the run can be tied to an environment.
print_config()

# Fix MONAI-managed random seeds so repeated evaluations are comparable.
set_determinism(seed=0)

# ─────────────────────────────────────────────────────────────
#  Label converter (same as training)
# ─────────────────────────────────────────────────────────────
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert FeTS/BraTS integer label maps into 3-channel multi-label masks.

    FeTS/BraTS label mapping (ints on disk): 0=background, 1=NCR/NET, 2=edema, 4=enhancing (ET)
    Build 3-channel multi-label [TC, WT, ET]:
      TC = (label==1) OR (label==4)
      WT = (label==1) OR (label==2) OR (label==4)
      ET = (label==4)

    Every entry of ``data`` selected by this transform's keys is replaced by a
    float tensor whose new leading axis holds the three masks in that order.
    A label that already carries a singleton channel axis, i.e. shape
    (1, H, W, D), is squeezed first so the result is (3, H, W, D) rather than
    (3, 1, H, W, D). Non-tensor labels (e.g. numpy arrays) are converted to
    tensors before the comparison.
    """
    def __call__(self, data):
        """Return a shallow copy of ``data`` with the selected labels converted."""
        d = dict(data)
        # key_iterator is MapTransform's standard way to walk the configured
        # keys; it honours allow_missing_keys and raises KeyError otherwise.
        for key in self.key_iterator(d):
            lab = d[key]
            if not isinstance(lab, torch.Tensor):
                lab = torch.as_tensor(lab)
            # Drop an existing singleton channel axis (assumes 3-D volumes) so
            # stacking below does not create a 5-D tensor.
            if lab.ndim == 4 and lab.shape[0] == 1:
                lab = lab.squeeze(0)
            tc = torch.logical_or(lab == 1, lab == 4)
            wt = torch.logical_or(tc, lab == 2)  # WT = TC plus edema
            et = (lab == 4)
            d[key] = torch.stack([tc, wt, et], dim=0).float()
        return d

# ─────────────────────────────────────────────────────────────
#  Paths & basic meta
# ─────────────────────────────────────────────────────────────
# Dataset root. Set the FETS_DATA_DIR environment variable to point the notebook
# at another machine's copy; an unset or empty variable keeps the original path.
BRATS_DIR = (
    os.environ.get("FETS_DATA_DIR")
    or "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
)
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"   # partition table, read with pandas below
MODALITIES = ["flair", "t1", "t1ce", "t2"]      # per-subject image file suffixes, in channel order
LABEL_KEY  = "seg"                              # file suffix of the segmentation volume

TOP_K_TRAIN_SITES = 6      # must match training logic
FRAC, SEED = 1.0, 42       # must match training script

# ─────────────────────────────────────────────────────────────
#  Partition map and train/val split (no train CacheDataset)
# ─────────────────────────────────────────────────────────────
part_df = pd.read_csv(CSV_PATH)

# 1) subject counts per site (number of distinct Subject_IDs per Partition_ID)
site_counts = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .nunique()
)

# 2) pick top-K most populated sites for training
# NOTE(review): if two sites tie at the K-th position, which one is kept
# depends on the sort's tie order — left unchanged so it keeps matching training.
TRAIN_CENTRES = set(
    site_counts.sort_values(ascending=False)
               .head(TOP_K_TRAIN_SITES)
               .index.tolist()
)

# 3) remaining sites are validation
VAL_CENTRES = set(site_counts.index) - TRAIN_CENTRES

# Report K from the constant so the header stays correct if TOP_K_TRAIN_SITES changes.
print(f"Train centres (top {TOP_K_TRAIN_SITES} by subject count):")
print(site_counts.loc[sorted(TRAIN_CENTRES)])
print("\nValidation centres (remaining):")
print(site_counts.loc[sorted(VAL_CENTRES)])

# 4) map centre → subject IDs (one list per Partition_ID, in CSV row order)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)

# Subject lists of the training centres only, keyed by centre ID.
train_partitions = {
    cid: sids for cid, sids in partition_map.items()
    if cid in TRAIN_CENTRES
}

# Flatten the validation centres into one subject list. Centres are visited in
# sorted order so the list is the same on every run (set iteration order is
# unspecified), and the comprehension avoids the quadratic cost of sum(..., []).
val_subjects = [
    sid
    for cid in sorted(VAL_CENTRES)
    for sid in partition_map[cid]
]

# Training clients, in ascending centre-ID order.
idxs_users = sorted(train_partitions)

# Per-client sample counts, recomputed exactly as in training:
#   k = max(1, int(len(subj_ids) * FRAC))
sizes = {
    cid: max(1, int(len(train_partitions[cid]) * FRAC))
    for cid in idxs_users
}

# Each client's share of the pooled sample count, aligned with idxs_users.
total_n = sum(sizes.values())
fractions = [sizes[cid] / total_n for cid in idxs_users]

print("\nclients:", idxs_users)
print("sizes:", sizes)
print("fractions:", fractions)

# ─────────────────────────────────────────────────────────────
#  Validation dataset only
# ─────────────────────────────────────────────────────────────
val_transform = Compose(
    transforms=[
        # Load the modality files (stacked under "image") and the segmentation.
        LoadImaged(keys=("image", "label")),
        # Only the image needs a channel axis here; the label receives its
        # channels from the multi-label conversion below.
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=("image", "label")),
        # Integer label map -> 3-channel [TC, WT, ET] float mask.
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Bring both volumes to RAS orientation and 1 mm isotropic spacing;
        # the label uses nearest-neighbour so mask values are not blended.
        Orientationd(keys=("image", "label"), axcodes="RAS"),
        Spacingd(keys=("image", "label"), pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        # Normalise each image channel separately over its non-zero voxels.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

def build_records(subject_ids, data_dir=None, modalities=None, label_key=None):
    """
    Build the list of file-path records consumed by the MONAI dataset.

    Args:
        subject_ids: iterable of subject identifiers; each one names both the
            subject's sub-directory and the prefix of its files.
        data_dir: dataset root directory. Defaults to the module-level BRATS_DIR.
        modalities: image file suffixes, in channel order. Defaults to MODALITIES.
        label_key: file suffix of the segmentation volume. Defaults to LABEL_KEY.

    Returns:
        A list with one dict per subject, in input order:
        ``{"image": [<one path per modality>], "label": <segmentation path>}``.
        An empty ``subject_ids`` yields an empty list.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are honoured, exactly as the original global lookups were.
    root = BRATS_DIR if data_dir is None else data_dir
    mods = MODALITIES if modalities is None else modalities
    seg_suffix = LABEL_KEY if label_key is None else label_key

    records = []
    for sid in subject_ids:
        subject_dir = f"{root}/{sid}"
        records.append(
            {
                "image": [f"{subject_dir}/{sid}_{m}.nii.gz" for m in mods],
                "label": f"{subject_dir}/{sid}_{seg_suffix}.nii.gz",
            }
        )
    return records

# cache_rate=1 caches every validation record, so the transform pipeline runs
# once up front instead of on each access.
val_dataset = CacheDataset(
    data=build_records(val_subjects),
    transform=val_transform,
    cache_rate=1,
)
print("validation size:", len(val_dataset))

# ─────────────────────────────────────────────────────────────
#  Global model architecture + loading best weights
# ─────────────────────────────────────────────────────────────
# Architecture must match the one used to produce the checkpoint: 4 input
# channels (one per modality) and 3 output channels (TC, WT, ET).
global_model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

submodel_dir = "submodels"
# "{}" is a placeholder to be filled in by the caller that receives this template.
submodel_file_template = os.path.join(submodel_dir, "submodel_{}.pth")
best_model_path = os.path.join(submodel_dir, "best_metric_model.pth")

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. It is already the default
# on recent PyTorch; passing it explicitly keeps older versions equally safe.
global_model.load_state_dict(
    torch.load(best_model_path, map_location=device, weights_only=True)
)
# Inference mode: disables dropout for the evaluation that follows.
global_model.eval()

print("Loaded best global model from:", best_model_path)


[0;93m2025-12-01 21:09:27.427192054 [W:onnxruntime:Default, device_discovery.cc:164 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:89 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"[m


MONAI version: 1.6.dev2542
Numpy version: 2.1.2
Pytorch version: 2.8.0+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 612f3dd3cba4d73cfcea4b5329079e20aa31523d
MONAI __file__: /home/<username>/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.3
Pillow version: 11.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.23.0+cu126
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.3.2
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For d

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


Train centres (top 6 by subject count):
Partition_ID
1     511
4      47
6      34
13     35
18    382
21     35
Name: Subject_ID, dtype: int64

Validation centres (remaining):
Partition_ID
2      6
3     15
5     22
7     12
8      8
9      4
10     8
11    14
12    11
14     6
15    13
16    30
17     9
19     4
20    33
22     7
23     5
Name: Subject_ID, dtype: int64

clients: [1, 4, 6, 13, 18, 21]
sizes: {1: 511, 4: 47, 6: 34, 13: 35, 18: 382, 21: 35}
fractions: [0.48946360153256707, 0.045019157088122604, 0.032567049808429116, 0.033524904214559385, 0.36590038314176243, 0.033524904214559385]


Loading dataset: 100%|████████████████████████████████████████████████████████████████| 207/207 [07:54<00:00,  2.29s/it]


validation size: 207
Loaded best global model from: submodels/best_metric_model.pth


In [2]:
# Sanity check: sorted IDs of the training centres selected in the setup cell.
idxs_users

[1, 4, 6, 13, 18, 21]

In [3]:
# Sanity check: each client's share of the pooled sample count, in idxs_users order.
fractions

[0.48946360153256707,
 0.045019157088122604,
 0.032567049808429116,
 0.033524904214559385,
 0.36590038314176243,
 0.033524904214559385]

In [4]:
# Sanity check: path template for the per-submodel checkpoint files.
submodel_file_template

'submodels/submodel_{}.pth'

In [1]:
# NOTE(review): wildcard import. compute_coalition_utilities and
# run_shapley_from_coalitions, called below, are presumably provided by this
# module — confirm, then import the needed names explicitly. Any other name the
# module exports can silently overwrite notebook variables defined above.
from shapley_eval import *

In [7]:
# 1. Long, resumable coalition evaluation
# Make sure the log directory exists before the long run starts, so results
# can be written to it.
os.makedirs("./logs", exist_ok=True)

accuracy_dict = compute_coalition_utilities(
    global_model=global_model,
    val_dataset=val_dataset,
    idxs_users=idxs_users,
    fractions=fractions,
    # Reuse the template defined next to the model paths instead of re-typing
    # the literal, so the two cannot drift apart.
    submodel_file_template=submodel_file_template,
    device=device,
    coalition_csv="./logs/coalition_utilities.csv",
)

[Resume] Loaded 62 coalitions from ./logs/coalition_utilities.csv
Skipping coalition (1,) (already evaluated).
Skipping coalition (4,) (already evaluated).
Skipping coalition (6,) (already evaluated).
Skipping coalition (13,) (already evaluated).
Skipping coalition (18,) (already evaluated).
Skipping coalition (21,) (already evaluated).
Skipping coalition (1, 4) (already evaluated).
Skipping coalition (1, 6) (already evaluated).
Skipping coalition (1, 13) (already evaluated).
Skipping coalition (1, 18) (already evaluated).
Skipping coalition (1, 21) (already evaluated).
Skipping coalition (4, 6) (already evaluated).
Skipping coalition (4, 13) (already evaluated).
Skipping coalition (4, 18) (already evaluated).
Skipping coalition (4, 21) (already evaluated).
Skipping coalition (6, 13) (already evaluated).
Skipping coalition (6, 18) (already evaluated).
Skipping coalition (6, 21) (already evaluated).
Skipping coalition (13, 18) (already evaluated).
Skipping coalition (13, 21) (already ev

Evaluating:   0%|          | 0/207 [00:00<?, ?it/s]

Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)


 Total coalition eval time: 95.1772s
 Grand-coalition validation utility (Dice): 0.7899

[Saved] Coalition utilities → ./logs/coalition_utilities.csv


In [2]:


# 2. Fast Shapley/LC — reads the saved coalition utilities, so it can be
#    re-run at any time without repeating the evaluation above.
(accuracy_dict, shapley_dict, lc_dict) = run_shapley_from_coalitions(
    coalition_csv="./logs/coalition_utilities.csv", allocation_csv="./logs/allocation_summary.csv"
)


contributor: 1 (1,)
marginal: 0.7209593653678894
contributor: 2 (2,)
marginal: 0.7006985545158386
contributor: 3 (3,)
marginal: 0.7113056182861328
contributor: 4 (4,)
marginal: 0.707489013671875
contributor: 5 (5,)
marginal: 0.6794240474700928
contributor: 6 (6,)
marginal: 0.4434252977371216
contributor: 1 (1, 2)
marginal: 0.02936464548110962
contributor: 2 (1, 2)
marginal: 0.009103834629058838
contributor: 1 (1, 3)
marginal: 0.017319917678833008
contributor: 3 (1, 3)
marginal: 0.007666170597076416
contributor: 1 (1, 4)
marginal: 0.01957017183303833
contributor: 4 (1, 4)
marginal: 0.006099820137023926
contributor: 1 (1, 5)
marginal: 0.06714439392089844
contributor: 5 (1, 5)
marginal: 0.025609076023101807
contributor: 1 (1, 6)
marginal: 0.2815920114517212
contributor: 6 (1, 6)
marginal: 0.004057943820953369
contributor: 2 (2, 3)
marginal: -0.0016062259674072266
contributor: 3 (2, 3)
marginal: 0.009000837802886963
contributor: 2 (2, 4)
marginal: 0.019704699516296387
contributor: 4 (2, 4)