In [6]:
import os
import random

import flwr as fl
import numpy as np
import torch
from flwr.server.strategy import FedAvg
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

from client import get_client_generator, weighted_average_accuracy
from dataset import partition_dataset

# Normalise MNIST images (0.1307 / 0.3081 are the mean/std values commonly used for MNIST).
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
# download=True on both splits: it is harmless when the files are already
# present and keeps the validation split loadable on its own.
train_dataset = MNIST("./mnist", train=True, download=True, transform=transform)
val_dataset = MNIST("./mnist", train=False, download=True, transform=transform)

# Determinism settings for reproducible runs.
# With use_deterministic_algorithms(True), CUDA cuBLAS ops raise a RuntimeError
# unless CUBLAS_WORKSPACE_CONFIG is set; setdefault keeps any value the user
# already exported. It must be set before the first CUDA matmul.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

def seed_everything(seed):
    """Seed every global RNG this notebook relies on.

    Applies the same ``seed`` to Python's ``random`` module, NumPy's legacy
    global generator (``np.random.seed``) and torch's default generator
    (``torch.manual_seed``). Returns None.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def complete_run(seed=None, single_threaded=False, num_clients=5, num_rounds=5,
                 batch_size=16, lr=0.05, epochs=1):
    """Run one FedAvg simulation on MNIST split across ``num_clients`` clients.

    Args:
        seed: If not None, ``seed_everything(seed)`` is called before the
            datasets are partitioned. In the Ray-based path, ``seed`` and
            ``seed_everything`` are also forwarded to the simulation.
        single_threaded: If True, use
            ``fl.simulation.start_simulation_single_threaded``; otherwise
            ``fl.simulation.start_simulation``.
        num_clients: Number of dataset partitions / simulated clients.
        num_rounds: Number of federated rounds.
        batch_size: Batch size of both the train and validation DataLoaders.
        lr: Learning rate sent to clients in the fit/evaluate config.
        epochs: Local epochs sent to clients in the fit/evaluate config.

    Returns:
        The History object returned by the Flower simulation call.
    """
    if seed is not None:
        seed_everything(seed)

    train_datasets = partition_dataset(train_dataset, num_clients)
    val_datasets = partition_dataset(val_dataset, num_clients)
    train_dataloaders = [
        torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True)
        for ds in train_datasets
    ]
    val_dataloaders = [
        torch.utils.data.DataLoader(ds, batch_size=batch_size)
        for ds in val_datasets
    ]
    client_fn = get_client_generator(train_dataloaders, val_dataloaders)

    client_config = {"lr": lr, "epochs": epochs}

    def config_fn(_server_round):
        # Return a fresh copy each round so a mutation of the config by Flower
        # or a client cannot leak into later rounds.
        return dict(client_config)

    # Never require more clients than exist, otherwise the server would wait
    # forever for a quorum when num_clients < 2.
    min_clients = min(2, num_clients)
    strategy = FedAvg(
        min_fit_clients=min_clients,
        min_evaluate_clients=min_clients,
        min_available_clients=min_clients,
        fraction_fit=0.25,
        fraction_evaluate=0.25,
        on_fit_config_fn=config_fn,
        on_evaluate_config_fn=config_fn,
        evaluate_metrics_aggregation_fn=weighted_average_accuracy,
    )
    server_config = fl.server.ServerConfig(num_rounds=num_rounds)

    if single_threaded:
        return fl.simulation.start_simulation_single_threaded(
            client_fn=client_fn,
            num_clients=num_clients,
            config=server_config,
            strategy=strategy,
        )
    return fl.simulation.start_simulation(
        client_fn=client_fn,
        seed_fn=seed_everything if seed is not None else None,
        seed=seed,
        num_clients=num_clients,
        config=server_config,
        # None leaves per-client resource allocation to Flower's defaults.
        client_resources=None,
        strategy=strategy,
    )

In [2]:
# Two runs with the same seed. Assert that their histories are identical
# instead of relying on a visual comparison of the two outputs below.
run1 = complete_run(0)
run2 = complete_run(0)
assert vars(run1) == vars(run2), "runs with the same seed produced different histories"

INFO flwr 2023-03-09 17:04:03,201 | app.py:179 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2023-03-09 17:04:05,353	INFO worker.py:1553 -- Started a local Ray instance.
INFO flwr 2023-03-09 17:04:07,025 | app.py:213 | Flower VCE: Ray initialized with resources: {'CPU': 10.0, 'node:127.0.0.1': 1.0, 'object_store_memory': 2147483648.0, 'memory': 18667405312.0}
INFO flwr 2023-03-09 17:04:07,154 | server.py:100 | Initializing global parameters
INFO flwr 2023-03-09 17:04:07,154 | server.py:291 | Requesting initial parameters from one random client
INFO flwr 2023-03-09 17:04:08,718 | server.py:295 | Received initial parameters from one random client
INFO flwr 2023-03-09 17:04:08,718 | server.py:102 | Evaluating initial parameters
INFO flwr 2023-03-09 17:04:08,719 | server.py:115 | FL starting
DEBUG flwr 2023-03-09 17:04:08,719 | server.py:232 | fit_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-09 17:04:23,172 | server.py:246 | fit_r

In [3]:
# Display the History of the first seed-0 run.
run1

History (loss, distributed):
	round 1: 0.23777819456905128
	round 2: 0.15710546273272485
	round 3: 0.11228240823280067
	round 4: 0.0846113607659936
	round 5: 0.08301721393363551
History (metrics, distributed):
{'accuracy': [(1, 0.933), (2, 0.955), (3, 0.96675), (4, 0.97375), (5, 0.974)]}

In [4]:
# Display the History of the second seed-0 run; it should match run1 above.
run2

History (loss, distributed):
	round 1: 0.23777819456905128
	round 2: 0.15710546273272485
	round 3: 0.11228240823280067
	round 4: 0.0846113607659936
	round 5: 0.08301721393363551
History (metrics, distributed):
{'accuracy': [(1, 0.933), (2, 0.955), (3, 0.96675), (4, 0.97375), (5, 0.974)]}

In [7]:
# Single-threaded simulation. It uses seed 4, so its numbers are not directly
# comparable to run1/run2 (seed 0).
complete_run(seed=4, single_threaded=True)

INFO flwr 2023-03-09 17:37:59,285 | app.py:276 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flwr 2023-03-09 17:37:59,286 | server.py:100 | Initializing global parameters
INFO flwr 2023-03-09 17:37:59,286 | server.py:291 | Requesting initial parameters from one random client
INFO flwr 2023-03-09 17:37:59,300 | server.py:295 | Received initial parameters from one random client
INFO flwr 2023-03-09 17:37:59,301 | server.py:102 | Evaluating initial parameters
INFO flwr 2023-03-09 17:37:59,302 | server.py:115 | FL starting
DEBUG flwr 2023-03-09 17:37:59,303 | server.py:232 | fit_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-09 17:38:20,345 | server.py:246 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2023-03-09 17:38:20,353 | server.py:179 | evaluate_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-09 17:38:22,059 | server.py:193 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023

History (loss, distributed):
	round 1: 0.24160061959177256
	round 2: 0.15777746839076282
	round 3: 0.102664565872401
	round 4: 0.09126850824570283
	round 5: 0.07899820515606552
History (metrics, distributed):
{'accuracy': [(1, 0.93075), (2, 0.9525), (3, 0.96675), (4, 0.97325), (5, 0.97525)]}