In [2]:
import argparse
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    FedTcgaBrca,
    NUM_CLIENTS,
    metric,
    get_nb_max_rounds
)
from flamby.utils import evaluate_model_on_tests
import warnings
import warnings
warnings.filterwarnings("ignore")
from flamby.datasets.fed_tcga_brca import FedTcgaBrca as FedDataset
from flamby.strategies.fed_avg_log import FedAvgWithLog as strat
from tqdm import tqdm

In [8]:
from flamby.utils import evaluate_model_on_tests

In [3]:
train_dataloaders = [
            torch.utils.data.DataLoader(
                FedDataset(center = i, train = True, pooled = False),
                batch_size = BATCH_SIZE,
                shuffle = True,
                num_workers = 0
            )
            for i in range(NUM_CLIENTS)
        ]

lossfunc = BaselineLoss()
m = Baseline()

In [20]:
# Federated Learning loop
# 2nd line of code to change to switch to another strategy (feed the FL strategy the right HPs)
args = {
            "training_dataloaders": train_dataloaders,
            "model": m,
            "loss": lossfunc,
            "optimizer_class": torch.optim.SGD,
            "learning_rate": LR / 10.0,
            "num_updates": 100,
# This helper function returns the number of rounds necessary to perform approximately as many
# epochs on each local dataset as with the pooled training
            "nrounds": get_nb_max_rounds(100),
        }

In [51]:
s = strat(**args)
results = []
for rounds in tqdm(range(10)):
    s.perform_round()

    # evaluation for clients
    for client_id in range(NUM_CLIENTS-1):
        test_dataset = FedTcgaBrca(center = client_id, train=False, pooled=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            # drop_last=True,
        )
        
        results.append((client_id,evaluate_model_on_tests(s.models_list[client_id].model, [test_dataloader], metric)))

  0%|                                                                 | 0/10 [00:00<?, ?it/s]
  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.56it/s][A

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.15it/s][A

  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.14it/s][A

  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.79it/s][A

  0%|                                                                  | 0/1 [00:00<?, ?it/s][A
100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.66it/s][A
 10%|█████▋                  

In [52]:
results

[(0, {'client_test_0': np.float64(0.5921787709497207)}),
 (1, {'client_test_0': np.float64(0.7727272727272727)}),
 (2, {'client_test_0': np.float64(0.6851851851851852)}),
 (3, {'client_test_0': np.float64(0.5625)}),
 (4, {'client_test_0': np.float64(0.9148936170212766)}),
 (0, {'client_test_0': np.float64(0.6061452513966481)}),
 (1, {'client_test_0': np.float64(0.7727272727272727)}),
 (2, {'client_test_0': np.float64(0.6759259259259259)}),
 (3, {'client_test_0': np.float64(0.5625)}),
 (4, {'client_test_0': np.float64(0.9148936170212766)}),
 (0, {'client_test_0': np.float64(0.6201117318435754)}),
 (1, {'client_test_0': np.float64(0.7727272727272727)}),
 (2, {'client_test_0': np.float64(0.6759259259259259)}),
 (3, {'client_test_0': np.float64(0.625)}),
 (4, {'client_test_0': np.float64(0.851063829787234)}),
 (0, {'client_test_0': np.float64(0.6145251396648045)}),
 (1, {'client_test_0': np.float64(0.7727272727272727)}),
 (2, {'client_test_0': np.float64(0.6944444444444444)}),
 (3, {'clien

In [18]:
train_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x750eaf5d14e0>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf4d5f60>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf514880>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf4e8130>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf3b4790>,
 <torch.utils.data.dataloader.DataLoader at 0x750eaf3ca8c0>]

In [19]:
NUM_CLIENTS

6

In [49]:
test_dataset = FedTcgaBrca(center = 4, train=False, pooled=False)
test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            # drop_last=True,
        )

In [50]:
for X, y in test_dataloader:
    print(X, y)

tensor([[86.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  0.,  1.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [66.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  1.,  0.,  1.,
          0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.],
        [40.,  1.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,
          1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.],
        [66.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  0.,  1.,
          0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [68.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
          0.,  1.,  0.,  0., 