In [None]:
import os
import copy
import numpy as np
from tqdm import tqdm
import random

import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms

from models.resnet import ResNet50
from utils.reproducibility import make_it_reproducible, seed_worker
from utils.fedavg_utils import get_datasets, get_user_groups, average_weights
from utils.fedavg_local import LocalUpdate

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Reproducibility: seed every global RNG *here*, before the later cells that
# partition the data (get_user_groups) and build the model (ResNet50()).
# Those steps presumably draw from the global python/numpy/torch RNGs, so
# seeding only right before training leaves the client split and the initial
# weights non-deterministic.
seed = 0

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Dedicated generator handed to DataLoaders so their sampling order is
# reproducible independently of the global torch RNG.
g = torch.Generator()
g.manual_seed(seed);  # trailing `;` suppresses the Generator repr in the notebook

In [None]:
# Experiment configuration.

# How the training set is split across clients.
iid = True
unbalanced = False

# Federated schedule.
ROUNDS = 30               # number of communication rounds
tot_users = 100           # total number of simulated clients
selection_fraction = 0.1  # share of clients sampled in each round

# Per-client local training.
local_batch_size = 10
local_epochs = 1

In [None]:
# Datasets and loaders.
# The training set is split into per-client index groups (presumably keyed by
# client id 0..tot_users-1, since the training loop indexes user_groups[idx]).
# The test set is kept whole and only used to evaluate the global model.
trainset, testset = get_datasets()
user_groups = get_user_groups(
    trainset,
    iid=iid,
    unbalanced=unbalanced,
    tot_users=tot_users,
)

testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Server-side global model. The training loop overwrites its weights every
# round with the average of the clients' locally trained weights.
# NOTE: the project also offers a ResNet50("Group Norm") variant; construct
# that instead here to experiment with it.
global_net = ResNet50().to(device)
global_net.train()

global_weights = global_net.state_dict()

In [None]:
# Per-round metric histories; the training loop appends one entry per round.
train_loss = []
test_accuracy = []

# Re-seed immediately before training (presumably seeds the global RNGs — see
# utils.reproducibility) so client sampling and local updates are repeatable.
make_it_reproducible(seed)

In [None]:
# Federated averaging loop. Each round:
#   1. sample a subset of clients,
#   2. train a copy of the global model on each sampled client's data,
#   3. average the resulting weights into the global model,
#   4. evaluate the global model on the test set.

# Number of clients sampled per round. It does not change between rounds,
# so compute it once (and always sample at least one client).
users_per_round = max(int(selection_fraction * tot_users), 1)

# `round_idx` rather than `round`: the loop variable leaks into the notebook
# namespace, and `round` would shadow the builtin in every later cell.
for round_idx in tqdm(range(ROUNDS)):
    local_weights, local_losses = [], []

    global_net.train()
    selected_users = np.random.choice(range(tot_users), users_per_round, replace=False)

    for idx in selected_users:
        # Pass the seed_worker function itself, exactly as the test loader does.
        # The previous `seed_worker(seed)` called it once in the main process and
        # forwarded its return value instead of a per-worker init function. If
        # seed_worker reseeds the global RNGs (the usual PyTorch recipe), that
        # call also froze np.random, so every round sampled the same clients.
        local_net = LocalUpdate(dataset=trainset, idxs=user_groups[idx],
                                local_batch_size=local_batch_size,
                                local_epochs=local_epochs,
                                worker_init_fn=seed_worker, generator=g)
        w, loss = local_net.update_weights(model=copy.deepcopy(global_net),
                                           global_round=round_idx)

        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

    # users_per_round >= 1, so local_losses is never empty here.
    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)

    # Aggregate the client models into the global model.
    global_weights = average_weights(local_weights)
    global_net.load_state_dict(global_weights)

    # Evaluate the aggregated model on the held-out test set.
    global_net.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            predicted = global_net(x).argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    test_accuracy.append(correct / total)

    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"At round {round_idx + 1} we had: test_accuracy={test_accuracy[-1]} "
               f"and average_local_loss={train_loss[-1]}")