In [None]:
import os
import copy
import numpy as np
from tqdm import tqdm
import random

import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms

from local import LocalUpdate
from resnet import ResNet50
# from resnetgn import ResNet50GN
from utils import get_datasets, get_user_groups, average_weights

Code to make this notebook reproducible

In [None]:
# Base seed for the notebook: passed to make_it_reproducible() and used to seed
# the torch.Generator `g` that is handed to the data loaders.
seed = 0

def make_it_reproducible(seed):
    """Seed every RNG used by the notebook and switch PyTorch to deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG and
            PyTorch (CPU and all CUDA devices).

    Side effects:
        * Sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable.
        * Disables cuDNN autotuning and forces deterministic cuDNN kernels.
        * Enables ``torch.use_deterministic_algorithms``, so operations without a
          deterministic implementation raise instead of running silently.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator plus every CUDA device (the *_all variant already
    # covers the current device, so a separate torch.cuda.manual_seed is not needed).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: no benchmark-driven kernel selection, deterministic kernels only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Set the cuBLAS workspace config before enabling deterministic algorithms;
    # PyTorch's reproducibility notes list ":4096:8" as one of the accepted values.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)

def seed_worker(worker_id):
    """Reseed NumPy and Python's ``random`` from the current torch seed.

    Meant to be passed (uncalled) as a DataLoader ``worker_init_fn``. The
    ``worker_id`` argument is accepted to match that callback signature but is
    not used; the seed is taken from ``torch.initial_seed()`` and reduced
    modulo 2**32.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Dedicated, explicitly seeded generator handed to the data loaders.
# Generator.manual_seed() returns the generator itself, so it can be chained.
g = torch.Generator().manual_seed(seed)

Get datasets

In [None]:
# Train/test splits as returned by the project helper.
trainset, testset = get_datasets()

# Fixed-order evaluation loader.
# `worker_init_fn` must receive the function object: writing `seed_worker(seed)`
# runs it once in the main process and passes its return value (None), so the
# worker processes would never be reseeded.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)

Define federated parameters

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Federated schedule: number of communication rounds and size of the user pool.
rounds, tot_users = 30, 100
# Fraction of the user pool sampled in each round.
selection_fraction = 0.1
# Settings forwarded to each LocalUpdate.
local_batch_size, local_epochs = 10, 1

In [None]:
# Seed all RNGs *before* constructing the model: the network's initial weights
# are drawn at construction time, so seeding only afterwards leaves them
# different on every run.
make_it_reproducible(seed)

global_net = ResNet50()
# global_net = ResNet50GN()
global_net.to(device)
global_net.train()

# Snapshot of the initial global parameters (replaced after each round's averaging).
global_weights = global_net.state_dict()

In [None]:
# Per-round history: mean local training loss and global test accuracy.
train_loss = []
test_accuracy = []

# Reset every RNG right before training starts.
make_it_reproducible(seed)

In [None]:
# Federated training: in every round a subset of users trains a copy of the
# global model locally, the resulting weights are averaged into the global
# model, and the global model is evaluated on the test set.
#
# NOTE(review): `user_groups` is read below but is not assigned in any visible
# cell (`get_user_groups` is imported but never called) — confirm it is defined
# before this cell runs, otherwise a fresh-kernel run raises NameError.
#
# The loop variable is `round_idx` rather than `round` so the builtin `round()`
# is not shadowed for the rest of the notebook.
for round_idx in tqdm(range(rounds)):
    local_weights, local_losses = [], []

    global_net.train()
    # Sample a fraction of the users (at least one) without replacement.
    m = max(int(selection_fraction * tot_users), 1)
    selected_users = np.random.choice(range(tot_users), m, replace=False)

    for idx in selected_users:
        # Pass `seed_worker` itself, not `seed_worker(seed)`: the call form hands
        # LocalUpdate the return value None instead of the callback.
        # Presumably LocalUpdate forwards it to its DataLoader — TODO confirm.
        local_net = LocalUpdate(
            dataset=trainset,
            idxs=user_groups[idx],
            local_batch_size=local_batch_size,
            local_epochs=local_epochs,
            worker_init_fn=seed_worker,
            generator=g,
        )
        # Train on a private copy so the global model is untouched until averaging.
        w, loss = local_net.update_weights(
            model=copy.deepcopy(global_net), global_round=round_idx
        )

        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)

    # Aggregate the local models into the new global model.
    global_weights = average_weights(local_weights)
    global_net.load_state_dict(global_weights)

    # Evaluate the aggregated model on the held-out test set.
    global_net.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            yhat = global_net(x)
            # Index of the largest output along dim 1 is the predicted class.
            _, predicted = torch.max(yhat, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    test_accuracy.append(correct / total)

    print(f"\nAt round {round_idx + 1} we had: test_accuracy={correct/total} and average_local_loss={train_loss[-1]}")