In [None]:
import torch
import numpy as np
import pandas as pd

from models.resnet import ResNet50
from utils.datasets import get_datasets
from utils.sampling import get_user_groups
from feddyn.components import FedDynServer, FedDynClient

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Seed for this run; it is handed to FedDynServer in the training cell.
# Other seeds the author noted (presumably for repeat runs): 128, 479.
seed = 0

In [None]:
# Experiment hyper-parameters
client_number = 100        # total number of clients created in the training cell
participation_frac = 0.1   # fraction of clients sampled in each communication round
ROUNDS = 3                 # number of communication rounds to run
alpha = 0.01               # FedDyn alpha, passed to both the server and every client
lr = 0.01                  # learning rate passed to each FedDynClient (same value as 1e-2)
local_epochs = 5           # passed to each FedDynClient; presumably local epochs per round

In [None]:
# Federated training: build the FedDyn server and clients, run ROUNDS
# communication rounds, then report the server's test metrics.

# One entry per (round, client) training call; "metrics" holds whatever
# FedDynClient.train returns as its second value.
metric = []

# Data-partitioning / model configuration for this experiment
iid = True
unbalanced = False
norm = "Batch Norm"  # descriptive label only; not read anywhere in this cell

trainset, testset = get_datasets(augmentation=True)

# Server: the seed is handed to the server, which is expected to handle reproducibility.
server = FedDynServer(ResNet50(), alpha, client_number, device, testset, seed)

# Clients
# NOTE(review): a single ResNet50 instance is shared by every client; this presumably
# relies on each client loading the server state before training. Confirm in FedDynClient.
client_model = ResNet50()
user_groups, _ = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=client_number)
clients = [
    FedDynClient(client_model, device, lr, alpha, idx, local_epochs, trainset, user_groups[idx])
    for idx in range(client_number)
]

# Number of clients sampled per round (at least one). It does not change between rounds.
clients_per_round = max(int(participation_frac * client_number), 1)

# Use `round_idx` rather than `round`: a top-level loop variable named `round`
# would shadow the builtin round() for every later cell in the kernel.
for round_idx in range(ROUNDS):
    print("Communication round: ", round_idx + 1)
    # NOTE(review): sampling uses NumPy's global RNG. It is only reproducible if that
    # RNG has been seeded (e.g. inside FedDynServer). Confirm.
    active_users = np.random.choice(range(client_number), clients_per_round, replace=False)
    print(f"Chosen users: {active_users}")

    act_server_state = server.get_server_state()
    active_clients_params = []
    for idx in active_users:
        client_params, metrics = clients[idx].train(act_server_state, round_idx)
        active_clients_params.append(client_params)
        # Keep the per-client metrics instead of discarding them.
        metric.append({"round": round_idx, "client": int(idx), "metrics": metrics})

    server.update_model(active_clients_params)
    server.evaluate(round_idx)

test_metrics = server.get_test_metrics()
print(test_metrics)