In [1]:
import torch
import flwr as fl
import random
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor
from client import get_client_generator, weighted_average_accuracy
from dataset import partition_dataset
from flwr.server.strategy import FedAvg
from flwr.simulation.backend.multiprocessing import MultiProcessingBackend
from flwr.simulation.backend.deterministic import SingleThreadedBackend
# Request deterministic kernels up front so that seeded runs can be compared
# bit-for-bit. The cudnn switches only matter when a CUDA device is used.
# NOTE(review): these are process-global settings of *this* kernel; whether they
# also apply inside Ray / multiprocessing worker processes depends on client
# code not shown here -- confirm.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Shared preprocessing for both MNIST splits: convert to tensor, then normalise
# (presumably with the commonly used MNIST mean/std constants).
transform = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])

# Both splits live under the same root; only the training split requests a download.
train_dataset = MNIST(root="./mnist", train=True, transform=transform, download=True)
val_dataset = MNIST(root="./mnist", train=False, transform=transform)

def seed_everything(seed):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``numpy.random.seed``
            and ``torch.manual_seed`` (in that order).

    Returns:
        None.
    """
    # Local imports keep the function self-contained; presumably this matters
    # because it is also handed to the simulation as ``seed_fn`` -- confirm.
    import random as py_random
    import numpy as numpy_mod
    import torch as torch_mod

    for seeder in (py_random.seed, numpy_mod.random.seed, torch_mod.manual_seed):
        seeder(seed)

def complete_run(seed=None, mode="ray", num_clients=5, num_rounds=5, batch_size=16):
    """Run one Flower FedAvg simulation on partitioned MNIST and return its history.

    Args:
        seed: If not None, ``seed_everything(seed)`` is called before the data is
            partitioned, and ``seed_everything``/``seed`` are forwarded to
            ``fl.simulation.start_simulation`` as ``seed_fn``/``seed``.
        mode: Simulation backend. ``"ray"`` passes ``backend=None`` (Flower's
            default, which starts a local Ray instance), ``"multiprocessing"``
            uses ``MultiProcessingBackend`` and ``"single_threaded"`` uses
            ``SingleThreadedBackend``.
        num_clients: Number of dataset partitions / simulated clients (default 5).
            Must be at least 2 because the strategy requires 2 available clients.
        num_rounds: Number of federated rounds (default 5).
        batch_size: Batch size of every client's train and validation loader
            (default 16).

    Returns:
        The History object returned by ``fl.simulation.start_simulation``.

    Raises:
        ValueError: If ``mode`` is not a supported backend name. This is checked
            before any seeding, partitioning or DataLoader construction.
    """
    # Fail fast on a typo instead of after the (comparatively slow) data prep.
    supported_modes = ("ray", "multiprocessing", "single_threaded")
    if mode not in supported_modes:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of: {', '.join(supported_modes)}"
        )

    if seed is not None:
        seed_everything(seed)

    train_datasets = partition_dataset(train_dataset, num_clients)
    val_datasets = partition_dataset(val_dataset, num_clients)
    # NOTE(review): shuffle=True without an explicit ``generator`` presumably
    # draws its shuffle order from torch's global RNG whenever a loader is
    # iterated. If several clients train concurrently in one process (as may be
    # the case for the "single_threaded" backend), the order of those draws can
    # depend on scheduling. That is a plausible cause of the differing
    # single_threaded histories recorded in this notebook -- TODO confirm.
    train_dataloaders = [
        torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in train_datasets
    ]
    val_dataloaders = [
        torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        for dataset in val_datasets
    ]
    client_fn = get_client_generator(train_dataloaders, val_dataloaders)

    # The same config dict is returned for every fit and evaluate round.
    client_config = {
        "lr": 0.05,
        "epochs": 1,
    }
    # With the default 5 clients, these settings sample 2 clients per round
    # (the logs show "strategy sampled 2 clients (out of 5)").
    strategy = FedAvg(
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=2,
        fraction_fit=0.25,
        fraction_evaluate=0.25,
        on_fit_config_fn=lambda _: client_config,
        on_evaluate_config_fn=lambda _: client_config,
        evaluate_metrics_aggregation_fn=weighted_average_accuracy,
    )

    # ``mode`` has already been validated, so the final branch is "single_threaded".
    if mode == "ray":
        backend = None  # start_simulation's default backend (Ray)
    elif mode == "multiprocessing":
        backend = MultiProcessingBackend()
    else:
        backend = SingleThreadedBackend()

    hist = fl.simulation.start_simulation(
        client_fn=client_fn,
        seed_fn=seed_everything if seed is not None else None,
        seed=seed,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        backend=backend,
    )

    return hist

In [2]:
def test_mode(mode):
    run1 = complete_run(0, mode)
    run2 = complete_run(0, mode)
    print(run1)
    print(run2)

In [3]:
%timeit -r 1 -n 1  test_mode("ray")

INFO flwr 2023-03-11 16:34:44,384 | app.py:199 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2023-03-11 16:34:46,295	INFO worker.py:1553 -- Started a local Ray instance.
INFO flwr 2023-03-11 16:34:47,802 | ray_backend.py:54 | Flower VCE: Ray initialized with resources: {'memory': 18016106906.0, 'object_store_memory': 2147483648.0, 'CPU': 10.0, 'node:127.0.0.1': 1.0}
INFO flwr 2023-03-11 16:34:47,928 | server.py:100 | Initializing global parameters
INFO flwr 2023-03-11 16:34:47,928 | server.py:291 | Requesting initial parameters from one random client
INFO flwr 2023-03-11 16:34:49,328 | server.py:295 | Received initial parameters from one random client
INFO flwr 2023-03-11 16:34:49,329 | server.py:102 | Evaluating initial parameters
INFO flwr 2023-03-11 16:34:49,329 | server.py:115 | FL starting
DEBUG flwr 2023-03-11 16:34:49,330 | server.py:232 | fit_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-11 16:35:03,037 | server.py:246 

History (loss, distributed):
	round 1: 0.23777819456905128
	round 2: 0.15710546273272485
	round 3: 0.11228240823280067
	round 4: 0.0846113607659936
	round 5: 0.08301721393363551
History (metrics, distributed):
{'accuracy': [(1, 0.933), (2, 0.955), (3, 0.96675), (4, 0.97375), (5, 0.974)]}
History (loss, distributed):
	round 1: 0.23777819456905128
	round 2: 0.15710546273272485
	round 3: 0.11228240823280067
	round 4: 0.0846113607659936
	round 5: 0.08301721393363551
History (metrics, distributed):
{'accuracy': [(1, 0.933), (2, 0.955), (3, 0.96675), (4, 0.97375), (5, 0.974)]}
2min 36s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [4]:
%timeit -r 1 -n 1  test_mode("multiprocessing")

INFO flwr 2023-03-11 16:37:20,922 | app.py:199 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flwr 2023-03-11 16:37:20,981 | server.py:100 | Initializing global parameters
INFO flwr 2023-03-11 16:37:20,983 | server.py:291 | Requesting initial parameters from one random client
INFO flwr 2023-03-11 16:37:22,293 | server.py:295 | Received initial parameters from one random client
INFO flwr 2023-03-11 16:37:22,293 | server.py:102 | Evaluating initial parameters
INFO flwr 2023-03-11 16:37:22,293 | server.py:115 | FL starting
DEBUG flwr 2023-03-11 16:37:22,294 | server.py:232 | fit_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-11 16:37:40,242 | server.py:246 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2023-03-11 16:37:40,251 | server.py:179 | evaluate_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-11 16:37:42,375 | server.py:193 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023

History (loss, distributed):
	round 1: 0.2330465930700302
	round 2: 0.15029081416130066
	round 3: 0.11393135458696634
	round 4: 0.09129968865960836
	round 5: 0.09463400622643531
History (metrics, distributed):
{'accuracy': [(1, 0.93175), (2, 0.95625), (3, 0.96575), (4, 0.9705), (5, 0.97)]}
History (loss, distributed):
	round 1: 0.2330465930700302
	round 2: 0.15029081416130066
	round 3: 0.11393135458696634
	round 4: 0.09129968865960836
	round 5: 0.09463400622643531
History (metrics, distributed):
{'accuracy': [(1, 0.93175), (2, 0.95625), (3, 0.96575), (4, 0.9705), (5, 0.97)]}
3min 15s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [3]:
%timeit -r 1 -n 1  test_mode("single_threaded")

INFO flwr 2023-03-11 16:51:50,042 | app.py:199 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flwr 2023-03-11 16:51:50,050 | server.py:100 | Initializing global parameters
INFO flwr 2023-03-11 16:51:50,051 | server.py:291 | Requesting initial parameters from one random client
INFO flwr 2023-03-11 16:51:50,059 | server.py:295 | Received initial parameters from one random client
INFO flwr 2023-03-11 16:51:50,060 | server.py:102 | Evaluating initial parameters
INFO flwr 2023-03-11 16:51:50,061 | server.py:115 | FL starting
DEBUG flwr 2023-03-11 16:51:50,062 | server.py:232 | fit_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-11 16:52:10,316 | server.py:246 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2023-03-11 16:52:10,325 | server.py:179 | evaluate_round 1: strategy sampled 2 clients (out of 5)
DEBUG flwr 2023-03-11 16:52:11,887 | server.py:193 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023

History (loss, distributed):
	round 1: 0.23897419881820678
	round 2: 0.14989428592845797
	round 3: 0.1124828438712284
	round 4: 0.09364412524877117
	round 5: 0.08764808896835893
History (metrics, distributed):
{'accuracy': [(1, 0.93275), (2, 0.95475), (3, 0.966), (4, 0.97025), (5, 0.97075)]}
History (loss, distributed):
	round 1: 0.2513150454834104
	round 2: 0.14717471213638783
	round 3: 0.10913577037304639
	round 4: 0.09507927848678083
	round 5: 0.08463426199601963
History (metrics, distributed):
{'accuracy': [(1, 0.9345), (2, 0.9545), (3, 0.96925), (4, 0.969), (5, 0.97325)]}
3min 34s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
