In [2]:
import argparse
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    FedTcgaBrca,
    NUM_CLIENTS,
    metric,
    get_nb_max_rounds
)
from flamby.utils import evaluate_model_on_tests
import warnings
import warnings
warnings.filterwarnings("ignore")
from flamby.datasets.fed_tcga_brca import FedTcgaBrca as FedDataset
from flamby.strategies.fed_avg_log import FedAvgWithLog as strat
from tqdm import tqdm

In [8]:
from flamby.utils import evaluate_model_on_tests

In [3]:
train_dataloaders = [
            torch.utils.data.DataLoader(
                FedDataset(center = i, train = True, pooled = False),
                batch_size = BATCH_SIZE,
                shuffle = True,
                num_workers = 0
            )
            for i in range(NUM_CLIENTS)
        ]

lossfunc = BaselineLoss()
m = Baseline()

In [15]:
# Federated Learning loop
# 2nd line of code to change to switch to another strategy (feed the FL strategy the right HPs)
args = {
            "training_dataloaders": train_dataloaders,
            "model": m,
            "loss": lossfunc,
            "optimizer_class": torch.optim.SGD,
            "learning_rate": LR / 10.0,
            "num_updates": 100,
# This helper function returns the number of rounds necessary to perform approximately as many
# epochs on each local dataset as with the pooled training
            "nrounds": get_nb_max_rounds(100),
        }
s = strat(**args)

In [16]:
results = []
for rounds in tqdm(range(10)):
    s.perform_round()
    test_dataset = FedTcgaBrca(train=False, pooled=True)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        # drop_last=True,
    )
    results.append(evaluate_model_on_tests(s.models_list[0].model, [test_dataloader], metric))

  0%|                                                                 | 0/10 [00:00<?, ?it/s]
  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.00it/s][A
 10%|█████▋                                                   | 1/10 [00:03<00:28,  3.14s/it]
  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.10it/s][A
 20%|███████████▍                                             | 2/10 [00:06<00:24,  3.08s/it]
  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.42it/s][A
 30%|█████████████████                                        | 3/10 [00:09<00:21,  3.09s/it]
  0%|                                     

In [12]:
test_dataloaders = [
            torch.utils.data.DataLoader(
                FedDataset(train = False, pooled = True),
                batch_size = BATCH_SIZE,
                shuffle = False,
                num_workers = 0,
            )
        ]
dict_cindex = evaluate_model_on_tests(m, test_dataloaders, metric)
print(dict_cindex)

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.44it/s]

{'client_test_0': np.float64(0.7025641025641025)}





In [17]:
results

[{'client_test_0': np.float64(0.6915750915750916)},
 {'client_test_0': np.float64(0.7102564102564103)},
 {'client_test_0': np.float64(0.7161172161172161)},
 {'client_test_0': np.float64(0.7256410256410256)},
 {'client_test_0': np.float64(0.7333333333333333)},
 {'client_test_0': np.float64(0.7454212454212454)},
 {'client_test_0': np.float64(0.7446886446886447)},
 {'client_test_0': np.float64(0.7538461538461538)},
 {'client_test_0': np.float64(0.773992673992674)},
 {'client_test_0': np.float64(0.7732600732600733)}]

In [18]:
train_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x750eaf5d14e0>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf4d5f60>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf514880>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf4e8130>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf3b4790>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf3ca8c0>]