In [1]:
import argparse
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    FedTcgaBrca,
    NUM_CLIENTS,
    metric,
    get_nb_max_rounds
)
from flamby.utils import evaluate_model_on_tests
import warnings
import warnings
warnings.filterwarnings("ignore")
from flamby.datasets.fed_tcga_brca import FedTcgaBrca as FedDataset
from flamby.strategies.fed_avg_log import FedAvgWithLog as strat
from tqdm import tqdm

In [38]:
from flamby.utils import evaluate_model_on_tests

In [60]:
train_dataloaders = [
            torch.utils.data.DataLoader(
                FedDataset(center = i, train = True, pooled = False),
                batch_size = BATCH_SIZE,
                shuffle = True,
                num_workers = 0
            )
            for i in range(NUM_CLIENTS-1)
        ]

lossfunc = BaselineLoss()
m = Baseline()

In [80]:
# Federated Learning loop
# 2nd line of code to change to switch to another strategy (feed the FL strategy the right HPs)
args = {
            "training_dataloaders": train_dataloaders,
            "model": m,
            "loss": lossfunc,
            "optimizer_class": torch.optim.Adam,
            "learning_rate": LR / 20.0,
            "num_updates": 50,
# This helper function returns the number of rounds necessary to perform approximately as many
# epochs on each local dataset as with the pooled training
            "nrounds": get_nb_max_rounds(50),
        }

In [83]:
s = strat(**args)
results = []

for client_id in range(NUM_CLIENTS-1):
    test_dataset = FedTcgaBrca(center=client_id, train=False, pooled=False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )
    results.append((client_id,evaluate_model_on_tests(s.models_list[client_id].model, [test_dataloader], metric, use_tqdm=False)))
for rounds in tqdm(range(30)):
    s.perform_round()

    if (rounds+1) % 5 == 0: 
    # evaluation for clients
        for client_id in range(NUM_CLIENTS-1):
            test_dataset = FedTcgaBrca(center=client_id, train=False, pooled=False)
            test_dataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=4,
                drop_last=True,
            )
            
            results.append((client_id,evaluate_model_on_tests(s.models_list[client_id].model, [test_dataloader], metric, use_tqdm=False)))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:12<00:00,  2.44it/s]


In [84]:
results

[(0, {'client_test_0': 0.4166666666666667}),
 (1, {'client_test_0': 0.09090909090909091}),
 (2, {'client_test_0': 0.23404255319148937}),
 (3, {'client_test_0': 0.6666666666666666}),
 (4, {'client_test_0': 0.022222222222222223}),
 (0, {'client_test_0': 0.7583892617449665}),
 (1, {'client_test_0': 0.7272727272727273}),
 (2, {'client_test_0': 0.7083333333333334}),
 (3, {'client_test_0': 0.5625}),
 (4, {'client_test_0': 0.9333333333333333}),
 (0, {'client_test_0': 0.8606271777003485}),
 (1, {'client_test_0': 0.7045454545454546}),
 (2, {'client_test_0': 0.8431372549019608}),
 (3, {'client_test_0': 0.625}),
 (4, {'client_test_0': 0.9333333333333333}),
 (0, {'client_test_0': 0.865814696485623}),
 (1, {'client_test_0': 0.6590909090909091}),
 (2, {'client_test_0': 0.8602150537634409}),
 (3, {'client_test_0': 0.6666666666666666}),
 (4, {'client_test_0': 0.9111111111111111}),
 (0, {'client_test_0': 0.8571428571428571}),
 (1, {'client_test_0': 0.7045454545454546}),
 (2, {'client_test_0': 0.8679245