In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

from models.resnet import ResNet50
from utils.datasets import get_datasets
from utils.sampling import get_user_groups
from utils.reproducibility import make_it_reproducible
from feddyn.components import FedDynServer, FedDynClient

In [3]:
# Pick the GPU only when one is really present. `torch.cuda.is_available` must be
# *called*: the bare function object is always truthy, which would force 'cuda'
# even on CPU-only machines and make the device query below crash.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `get_device_name` only makes sense (and only works) when CUDA is in use.
if device == 'cuda':
    print(torch.cuda.get_device_name(device))
else:
    print("CUDA not available, running on CPU")

Tesla T4


In [5]:
# ---- Experiment configuration ------------------------------------------------
# Federated setup
ROUNDS = 50                 # number of communication rounds
alpha = 0.01                # FedDyn regularisation coefficient, passed to server and clients
tot_clients = 100           # total number of simulated clients
participation = 0.1         # fraction of clients sampled each round
cuda = (device == "cuda")   # convenience flag derived from the selected device
norm = "Batch Norm"         # normalisation layer variant passed to ResNet50
iid = True                  # IID vs non-IID split of the training data
unbalanced = False          # whether client shards have unequal sizes
seed = 0                    # global seed handed to make_it_reproducible / the server

# Local (client-side) optimisation
local_epochs = 5
lr = 0.01
weight_decay = 0.0004
momentum = 0
clip_value = 10

In [6]:
# Fix the random state before any data loading, model construction or client
# sampling happens; later cells (e.g. np.random.choice in the training loop)
# presumably rely on the global RNG state set here — TODO confirm in utils.
make_it_reproducible(seed)

In [7]:
# Load the train/test splits (augmentation enabled) and partition the training
# set across `tot_clients` simulated clients according to the iid/unbalanced flags.
trainset, testset = get_datasets(augmentation=True)
user_groups, _ = get_user_groups(
    trainset,
    iid=iid,
    unbalanced=unbalanced,
    tot_users=tot_clients,
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# The server owns the global ResNet50 and evaluates it on the test set.
server = FedDynServer(ResNet50(norm), alpha, tot_clients, device, testset, seed)

# One client per id, each bound to its own shard of the training data
# (user_groups[client_id]); built in id order so clients[i] has id i.
clients = [
    FedDynClient(
        device, lr, weight_decay, momentum, alpha, client_id, local_epochs,
        trainset, user_groups[client_id], clip_value,
    )
    for client_id in range(tot_clients)
]

In [None]:
# Federated training loop. Each round:
#   1. snapshot the current server state,
#   2. sample a subset of clients without replacement,
#   3. train each sampled client locally starting from that state,
#   4. aggregate the returned client states on the server and evaluate it.
# Per-client training metrics accumulate in `train`, per-round server metrics in
# `test`; both are checkpointed to CSV periodically and after the final round.
TRAIN_CSV = "/content/train_rn50_iid_old.csv"
TEST_CSV = "/content/test_rn50_iid_old.csv"
CHECKPOINT_EVERY = 5  # rounds between CSV checkpoints

train, test = [], []

# Clients sampled per round. This does not change between rounds, so compute it
# once; max(1, ...) guarantees at least one client even for tiny participation.
m = int(max(1, tot_clients * participation))

for com_round in range(1, ROUNDS + 1):
    print(f"Running communication round {com_round}...")

    server_state_dict = server.get_server_state()

    # Uses the global NumPy RNG (presumably seeded by make_it_reproducible — TODO confirm).
    chosen_users = np.random.choice(range(tot_clients), m, replace=False)

    active_clients_models = []
    for idx in chosen_users:
        # Each client receives a freshly constructed model plus the server state.
        state, metric = clients[idx].train(ResNet50(norm), server_state_dict, com_round)
        active_clients_models.append(state)
        train.append(metric)

    server.update_model(active_clients_models)
    test.append(server.evaluate(com_round))
    print("\n")

    # Checkpoint every CHECKPOINT_EVERY rounds AND on the last round, so trailing
    # rounds are not lost when ROUNDS is not a multiple of CHECKPOINT_EVERY.
    if com_round % CHECKPOINT_EVERY == 0 or com_round == ROUNDS:
        # Kept as module-level names in case later cells inspect them.
        df_train = pd.DataFrame(train)
        df_train.to_csv(TRAIN_CSV, index=False)
        df_test = pd.DataFrame(test)
        df_test.to_csv(TEST_CSV, index=False)

Running communication round 1...
Training client 18 ... done!	 loss=2.6529	 accuracy=0.1544
Training client 26 ... done!	 loss=2.6428	 accuracy=0.1404
Training client 58 ... done!	 loss=2.6068	 accuracy=0.1668
Training client 94 ... done!	 loss=2.6236	 accuracy=0.1536
Training client 84 ... done!	 loss=2.5851	 accuracy=0.1732
Training client 7 ... done!	 loss=2.6414	 accuracy=0.152
Training client 89 ... done!	 loss=2.6083	 accuracy=0.17
Training client 54 ... done!	 loss=2.6061	 accuracy=0.1572
Training client 62 ... done!	 loss=2.6334	 accuracy=0.1516
Training client 3 ... done!	 loss=2.6327	 accuracy=0.1468
Updating server model... done!
Evaluating server model at round 1 ... done!	 loss=2.4238	 accuracy=0.1018


Running communication round 2...
Training client 30 ... done!	 loss=2.3227	 accuracy=0.1948
Training client 34 ... done!	 loss=2.3367	 accuracy=0.1712
Training client 39 ... done!	 loss=2.2845	 accuracy=0.2188
Training client 37 ... done!	 loss=2.2816	 accuracy=0.216
Traini