<a href="https://colab.research.google.com/github/vs-152/FL-Contributions-Incentives-Project/blob/main/ISO_CIFAR10_OR_FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
import torch
import torch.nn as nn
import numpy as np
import pulp
import copy
import time
from sklearn.model_selection import StratifiedShuffleSplit
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from itertools import chain, combinations
from tqdm import tqdm
from scipy.special import comb
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [43]:
from utils import *
from models import *


In [33]:
# -----------------------------------------------------------
# 0.  Paths & meta-data
# -----------------------------------------------------------
import glob, nibabel as nib, pandas as pd
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityd,
    RandFlipd, RandSpatialCropd, Compose, SelectItemsd
)

BRATS_DIR   = "/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData"
VAL_DIR     = "/home/locolinux2/datasets/MICCAI_FeTS2022_ValidationData"
CSV_PATH    = f"{BRATS_DIR}/partitioning_1.csv"     # pick 1, 2 … or sanity
MODALITIES  = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY   = "seg"  # BraTS tumour mask filename ending

# -----------------------------------------------------------
# 1.  Read partition file → mapping   {client_id: [subjIDs]}
# -----------------------------------------------------------
part_df          = pd.read_csv(CSV_PATH)
partition_map    = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list)
           .to_dict()
)
NUM_CLIENTS = len(partition_map)

# -----------------------------------------------------------
# 2.  Build a list of dicts – one per subject
# -----------------------------------------------------------
def build_records(subject_ids):
    recs = []
    for sid in subject_ids:
        subj_dir = f"{BRATS_DIR}/{sid}"
        rec = {m: f"{subj_dir}/{sid}_{m}.nii.gz"
               for m in MODALITIES}
        rec["seg"] = f"{subj_dir}/{sid}_{LABEL_KEY}.nii.gz"
        recs.append(rec)
    return recs

def build_val_records(val_dir):
    subjects = sorted(glob.glob(f"{val_dir}/FeTS2022_*_flair.nii.gz"))
    recs = []
    for flair_path in subjects:
        sid = flair_path.split("/")[-1].split("_flair")[0]
        subj_dir = f"{val_dir}/{sid}"
        rec = {m: f"{subj_dir}/{sid}_{m}.nii.gz" for m in MODALITIES}
        recs.append(rec)
    return recs

# -----------------------------------------------------------
# 3.  MONAI transform pipelines  (fixed)
# -----------------------------------------------------------
IMG_KEYS   = [m for m in MODALITIES]
ALL_KEYS   = IMG_KEYS + [LABEL_KEY]

train_tf = Compose([
    LoadImaged(keys=ALL_KEYS),
    EnsureChannelFirstd(keys=ALL_KEYS),
    Orientationd(keys=ALL_KEYS, axcodes="RAS"),
    ScaleIntensityd(keys=ALL_KEYS, minv=-1.0, maxv=1.0), # scale to [-1,1]. Diffusion Models do better if centered on a 0 mean
    SelectItemsd(keys=ALL_KEYS),
])

val_tf = Compose([
    LoadImaged(keys=MODALITIES),
    EnsureChannelFirstd(keys=MODALITIES),
    Orientationd(keys=MODALITIES, axcodes="RAS"),
    ScaleIntensityd(keys=MODALITIES, minv=-1.0, maxv=1.0),
    SelectItemsd(keys=MODALITIES),
])

In [17]:
partition_map.items()

dict_items([(1, ['FeTS2022_01341', 'FeTS2022_01333', 'FeTS2022_01077', 'FeTS2022_01054', 'FeTS2022_00285', 'FeTS2022_01308', 'FeTS2022_01363', 'FeTS2022_01091', 'FeTS2022_01273', 'FeTS2022_01108', 'FeTS2022_01255', 'FeTS2022_01301', 'FeTS2022_00219', 'FeTS2022_00380', 'FeTS2022_01349', 'FeTS2022_00251', 'FeTS2022_01276', 'FeTS2022_01407', 'FeTS2022_01344', 'FeTS2022_01405', 'FeTS2022_00218', 'FeTS2022_01327', 'FeTS2022_01252', 'FeTS2022_01132', 'FeTS2022_01036', 'FeTS2022_01039', 'FeTS2022_01366', 'FeTS2022_00262', 'FeTS2022_01279', 'FeTS2022_00839', 'FeTS2022_01322', 'FeTS2022_00389', 'FeTS2022_00390', 'FeTS2022_00431', 'FeTS2022_00222', 'FeTS2022_00373', 'FeTS2022_00288', 'FeTS2022_00284', 'FeTS2022_01088', 'FeTS2022_00311', 'FeTS2022_00387', 'FeTS2022_00258', 'FeTS2022_01389', 'FeTS2022_00321', 'FeTS2022_01249', 'FeTS2022_01230', 'FeTS2022_00836', 'FeTS2022_00348', 'FeTS2022_01205', 'FeTS2022_00246', 'FeTS2022_00314', 'FeTS2022_01404', 'FeTS2022_01102', 'FeTS2022_00379', 'FeTS2022_0

In [18]:
records

[{'flair': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01147/FeTS2022_01147_flair.nii.gz',
  't1': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01147/FeTS2022_01147_t1.nii.gz',
  't1ce': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01147/FeTS2022_01147_t1ce.nii.gz',
  't2': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01147/FeTS2022_01147_t2.nii.gz',
  'seg': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01147/FeTS2022_01147_seg.nii.gz'},
 {'flair': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01149/FeTS2022_01149_flair.nii.gz',
  't1': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01149/FeTS2022_01149_t1.nii.gz',
  't1ce': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01149/FeTS2022_01149_t1ce.nii.gz',
  't2': '/home/locolinux2/datasets/MICCAI_FeTS2022_TrainingData/FeTS2022_01149/FeTS2022_01149_t2.nii.gz',
  'seg': '/home/locolin

In [25]:
# -----------------------------------------------------------
# 4.  Build per-client datasets & dataloaders
# -----------------------------------------------------------
train_datasets = {}     # {client_id: monai CacheDataset}
for cid, subj_list in partition_map.items():
    records = build_records(subj_list)
    train_datasets[cid] = CacheDataset(data=records, transform=train_tf, cache_rate=1.0)


Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 511/511 [04:37<00:00,  1.84it/s]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:03<00:00,  1.81it/s]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:07<00:00,  1.88it/s]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47/47 [00:24<00:00,

In [35]:
# -----------------------------------------------------------
# 7.  Build test dataset & dataloader
# -----------------------------------------------------------
val_records  = build_val_records(VAL_DIR)
test_dataset  = CacheDataset(data=val_records, transform=val_tf, cache_rate=1.0)
# test_loader   = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

In [23]:
# class CustomTensorDataset(Dataset):
#     """TensorDataset with support of transforms.
#     """
#     def __init__(self, tensors, transform=None):
#         self.tensors = tensors
#         self.transform = transform

#     def __getitem__(self, index):
#         x = self.tensors[0][index]

#         if self.transform:
#             x = self.transform(x)

#         y = self.tensors[1][index]

#         return x, y

#     def __len__(self):
#         return self.tensors[0].shape[0]

def test_inference(model, test_dataset):

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    criterion = nn.CrossEntropyLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=200, shuffle=False)

    for _, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        correct += torch.sum(torch.eq(pred_labels, labels)).item()
        total += len(labels)
    accuracy = correct / total

    return accuracy, loss

In [36]:
# N = 10 #srch
N = list(train_datasets.keys())[-1]
print(f"We got {N} clients")
local_bs = 512
lr = 0.01
local_ep = 5
EPOCHS = 5

# noise_rates = np.linspace(0, 1, N, endpoint=False)
# split_dset = mnist_iid(trainset, N)
# user_groups = {i: 0 for i in range(1, N+1)}
# noise_idx = {i: 0 for i in range(1, N+1)}
# train_datasets = {i: 0 for i in range(1, N+1)}
# for n in range(N):
#     user_groups[n+1] = np.array(list(split_dset[n]), dtype=np.int)
#     user_train_x, user_train_y = x_train[user_groups[n+1]], y_train[user_groups[n+1]]
#     user_noisy_y, noise_idx[n+1] = noisify_MNIST(noise_rates[n], 'symmetric', user_train_x, user_train_y)
    
#     train_datasets[n+1] = CustomTensorDataset((user_train_x, user_noisy_y), transform_train)

def fixfuckingbn(subset_weights, global_model_state_dict):
    for pair_1, pair_2 in zip(subset_weights.items(), global_model_state_dict.items()):
        if ('running' in pair_1[0]) or ('batches' in pair_1[0]):
            subset_weights[pair_1[0]] = global_model_state_dict[pair_1[0]]
    
    return subset_weights

global_model = ResNet9().to(device)
global_model.to(device)
global_model.train()

global_weights = global_model.state_dict()
powerset = list(powersettool(range(1, N+1)))
submodel_dict = {}  
submodel_dict[()] = copy.deepcopy(global_model)
accuracy_dict = {}
shapley_dict = {}

We got 23 clients


In [39]:
train_datasets

{1: <monai.data.dataset.CacheDataset at 0x7fc0d9586f70>,
 2: <monai.data.dataset.CacheDataset at 0x7fc2c40a9eb0>,
 3: <monai.data.dataset.CacheDataset at 0x7fc0cc62b070>,
 4: <monai.data.dataset.CacheDataset at 0x7fc0daabae20>,
 5: <monai.data.dataset.CacheDataset at 0x7fc0d9470700>,
 6: <monai.data.dataset.CacheDataset at 0x7fc0d94a7ee0>,
 7: <monai.data.dataset.CacheDataset at 0x7fc0d887da30>,
 8: <monai.data.dataset.CacheDataset at 0x7fc0cc17e0d0>,
 9: <monai.data.dataset.CacheDataset at 0x7fc0cd58fe20>,
 10: <monai.data.dataset.CacheDataset at 0x7fc0cd5aae80>,
 11: <monai.data.dataset.CacheDataset at 0x7fc0d94e1be0>,
 12: <monai.data.dataset.CacheDataset at 0x7fc0d9586160>,
 13: <monai.data.dataset.CacheDataset at 0x7fc0d94c7490>,
 14: <monai.data.dataset.CacheDataset at 0x7fc0d94dabe0>,
 15: <monai.data.dataset.CacheDataset at 0x7fc0d963ae20>,
 16: <monai.data.dataset.CacheDataset at 0x7fc0d93d0c40>,
 17: <monai.data.dataset.CacheDataset at 0x7fc0d94c77f0>,
 18: <monai.data.datase

In [45]:
import torch
import torch.nn as nn
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ResNet9(nn.Module):
    def __init__(self):
        super(ResNet9, self).__init__()
        self.prep = self.convbnrelu(channels=3, filters=64)
        self.layer1 = self.convbnrelu(64, 128)
        self.layer_pool = nn.MaxPool2d(2, 2, 0, 1, ceil_mode=False)
        self.layer1r1 = self.convbnrelu(128, 128)
        self.layer1r2 = self.convbnrelu(128, 128)
        self.layer2 = self.convbnrelu(128, 256)
        self.layer3 = self.convbnrelu(256, 512)
        self.layer3r1 = self.convbnrelu(512, 512)
        self.layer3r2 = self.convbnrelu(512, 512)
        self.out_pool = nn.MaxPool2d(kernel_size=4, stride=4, ceil_mode=False)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=512, out_features=10, bias=False)

    def convbnrelu(self, channels, filters):
        layers = []
        layers.append(nn.Conv2d(channels, filters, (3, 3),
                                (1, 1), (1, 1), bias=False))
        layers.append(nn.BatchNorm2d(filters, track_running_stats=False))
        layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer_pool(self.layer1(x))
        r1 = self.layer1r2(self.layer1r1(x)) 
        x = x + r1
        x = self.layer_pool(self.layer2(x))
        x = self.layer_pool(self.layer3(x))
        r3 = self.layer3r2(self.layer3r1(x))
        x = x + r3
        out = self.out_pool(x)
        out = self.flatten(out)
        out = self.linear(out)
        out = out * 0.125

        return out
        
class LocalUpdate(object):

    def __init__(self, lr, local_ep, trainloader):
        self.lr = lr
        self.local_ep = local_ep
        self.trainloader = trainloader

    def update_weights(self, model):

        model.train()
        epoch_loss = []
        optimizer = torch.optim.Adam(model.parameters())
        criterion = nn.CrossEntropyLoss().to(device)
        for iter in range(self.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(device), labels.to(device)
                model.zero_grad()   
                log_probs = model(images)
                loss = criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

In [None]:
start_time = time.time()

for subset in range(1, N+1):
    submodel_dict[(subset,)] = copy.deepcopy(global_model)
    submodel_dict[(subset,)].to(device)
    submodel_dict[(subset,)].train() 
 
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
print_every = 1

idxs_users = np.arange(1, N+1)
# total_data = sum(len(user_groups[i]) for i in range(1, N+1))
# fraction = [len(user_groups[i])/total_data for i in range(1, N+1)]

# ── collect dataset sizes ──────────────────────────────────────────────────
# MONAI's CacheDataset inherits __len__, so `len(ds)` is cheap:
sizes = {k: len(ds) for k, ds in train_datasets.items()}

# ── total samples across all clients ───────────────────────────────────────
total_data = sum(sizes.values())

# ── FedAvg weight (a.k.a. fraction) for each client ────────────────────────
# Keep the list in key order 1…N so it lines up with your loops later.
fraction = [sizes[i] / total_data for i in range(1, N + 1)]

# ───────────────────────────────────────────────────────────────────────────

for epoch in tqdm(range(EPOCHS)):
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n')
    global_model.train()
    for idx in idxs_users:
        trainloader = DataLoader(train_datasets[idx], batch_size=local_bs, shuffle=True)
        local_model = LocalUpdate(lr, local_ep, trainloader)
        w, loss = local_model.update_weights(model=copy.deepcopy(global_model))
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))
    global_weights = average_weights(local_weights, fraction) 
    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)

    gradients = calculate_gradients(local_weights, global_model.state_dict()) 
    for i in range(1, N+1):
        subset_weights = update_weights_from_gradients(gradients[i-1], submodel_dict[(i,)].state_dict()) 
        subset_weights = fixfuckingbn(subset_weights, global_model.state_dict())
        submodel_dict[(i,)].load_state_dict(subset_weights)

    global_model.load_state_dict(global_weights)
    global_model.eval()

    if (epoch+1) % print_every == 0:
        print(f' \nAvg Training Stats after {epoch+1} global rounds:')
        print(f'Training Loss : {np.mean(np.array(train_loss))}')
        # print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

test_acc, test_loss = test_inference(global_model, test_dataset)
print(f' \n Results after {EPOCHS} global rounds of training:')
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

accuracy_dict[powerset[-1]] = test_acc

# ADJUSTED-OR APPROX
for subset in powerset[:-1]: 
    if len(subset) > 1:
        # calculate the average of the subset of weights from list of all the weights
        subset_weights = average_weights([submodel_dict[(i,)].state_dict() for i in subset], [fraction[i-1] for i in subset]) 
        submodel = copy.deepcopy(submodel_dict[()])
        submodel.load_state_dict(subset_weights)
        
        test_acc, test_loss = test_inference(submodel,test_dataset)
        print(f' \n Results after {EPOCHS} global rounds of training (for OR): ')
        print("|---- Test Accuracy for {}: {:.2f}%".format(subset, 100*test_acc))
        accuracy_dict[subset] = test_acc
    else: 
        test_acc, test_loss = test_inference(submodel_dict[subset], test_dataset)
        accuracy_dict[subset] = test_acc

trainTime = time.time() - start_time
start_time = time.time()
shapley_dict = shapley(accuracy_dict, N)
shapTime = time.time() - start_time
start_time = time.time()
lc_dict = least_core(accuracy_dict, N)
LCTime = time.time() - start_time
totalShapTime = trainTime + shapTime
totalLCTime = trainTime + LCTime
print(f'\n ACCURACY: {accuracy_dict[powerset[-1]]}')
print('\n Total Time Shapley: {0:0.4f}'.format(totalShapTime))
print('\n Total Time LC: {0:0.4f}'.format(totalLCTime))

  0%|                                                                                                                                                                                                                                 | 0/5 [00:00<?, ?it/s]


 | Global Training Round : 1 |



In [None]:
def stats(vector):
    n = len(vector)
    egal = np.array([1/n for i in range(n)])
    normalised = np.array(vector / vector.sum())
    msg = f'Original vector: {vector}\n'
    msg += f'Normalised vector: {normalised}\n'
    msg += f'Max Dif: {normalised.max()-normalised.min()}\n'
    msg += f'Distance: {np.linalg.norm(normalised-egal)}\n'

    msg += f'Budget: {vector.sum()}\n'
    print(msg)

In [None]:
stats(np.array(list(shapley_dict.values())))

Original vector: [ 0.12156075  0.10917627  0.10981044  0.10419635  0.07705849  0.07902048
  0.04776313  0.03351619 -0.03449246 -0.12100964]
Normalised vector: [ 0.23084078  0.20732296  0.20852722  0.19786622  0.14633212  0.15005787
  0.09070098  0.06364639 -0.0655003  -0.22979423]
Max Dif: 0.4606350110622801
Distance: 0.4384159898612196
Budget: 0.5266



In [None]:
stats(np.array([i.value() for i in lc_dict.variables()])[1:])

Original vector: [0.100125 0.       0.091025 0.107525 0.101125 0.0763   0.0948   0.0561
 0.       0.      ]
Normalised vector: [0.159689   0.         0.14517544 0.17149123 0.16128389 0.12169059
 0.15119617 0.08947368 0.         0.        ]
Max Dif: 0.17149122807017544
Distance: 0.21834065249685256
Budget: 0.627

