In [6]:
import torch
import flwr as fl
from torch.utils.data import Dataset, Subset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor
from client import get_client_generator, weighted_average_accuracy
from dataset import partition_dataset
from util import seed_everything
from flwr.server.strategy import FedAvg

# Convert PIL images to tensors, then normalize the single channel with the
# commonly used MNIST mean/std constants (presumably dataset-wide stats).
transform = Compose([
    ToTensor(),
    Normalize(mean=(0.1307,), std=(0.3081,)),
])
# Training split is downloaded on demand into ./mnist; the test split is read
# from the same root and serves as the validation data.
train_dataset = MNIST(root="./mnist", train=True, download=True, transform=transform)
val_dataset = MNIST(root="./mnist", train=False, transform=transform)

def complete_run(seed, num_clients=10, num_rounds=5):
    """Run one seeded FedAvg simulation over partitioned MNIST.

    The module-level ``train_dataset`` and ``val_dataset`` are each split into
    ``num_clients`` partitions with ``partition_dataset``. Each client gets one
    training DataLoader (batch size 16, shuffled) and one validation DataLoader
    (batch size 16). A FedAvg simulation is then run for ``num_rounds`` rounds.

    Args:
        seed: Value passed to ``seed_everything`` before partitioning and also
            forwarded to the simulation, so a single argument controls the
            whole run. (Previously the simulation always received 1234.)
        num_clients: Number of simulated clients / dataset partitions.
        num_rounds: Number of federated rounds to run.

    Returns:
        The history object returned by ``fl.simulation.start_simulation``.
    """
    seed_everything(seed)

    train_datasets = partition_dataset(train_dataset, num_clients)
    val_datasets = partition_dataset(val_dataset, num_clients)
    train_dataloaders = [
        torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
        for dataset in train_datasets
    ]
    val_dataloaders = [
        torch.utils.data.DataLoader(dataset, batch_size=16)
        for dataset in val_datasets
    ]
    client_fn = get_client_generator(train_dataloaders, val_dataloaders)

    client_config = {
        "lr": 0.05,
        "epochs": 1,
    }

    def _round_config(_server_round):
        # Return a fresh copy each round so no consumer can mutate the shared
        # template and leak changes into later rounds.
        return dict(client_config)

    strategy = FedAvg(
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=2,
        fraction_fit=0.5,
        fraction_evaluate=0.5,
        on_fit_config_fn=_round_config,
        on_evaluate_config_fn=_round_config,
        evaluate_metrics_aggregation_fn=weighted_average_accuracy,
    )
    return fl.simulation.start_simulation(
        client_fn=client_fn,
        # NOTE(review): `seed_fn`/`seed` do not appear to be upstream
        # `start_simulation` parameters; presumably a patched Flower build — confirm.
        seed_fn=seed_everything,
        seed=seed,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        client_resources=None,
        strategy=strategy,
    )

In [7]:
# First simulation, seeded with 0.
run1 = complete_run(seed=0)

INFO flwr 2023-03-08 14:55:05,530 | app.py:148 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2023-03-08 14:55:09,639	INFO worker.py:1553 -- Started a local Ray instance.
INFO flwr 2023-03-08 14:55:11,133 | app.py:182 | Flower VCE: Ray initialized with resources: {'object_store_memory': 2147483648.0, 'node:127.0.0.1': 1.0, 'CPU': 10.0, 'memory': 17186784871.0}
INFO flwr 2023-03-08 14:55:11,135 | server.py:96 | Initializing global parameters
INFO flwr 2023-03-08 14:55:11,136 | server.py:289 | Requesting initial parameters from one random client
INFO flwr 2023-03-08 14:55:11,917 | server.py:293 | Received initial parameters from one random client
INFO flwr 2023-03-08 14:55:11,918 | server.py:98 | Evaluating initial parameters
INFO flwr 2023-03-08 14:55:11,918 | server.py:111 | FL starting
DEBUG flwr 2023-03-08 14:55:11,919 | server.py:230 | fit_round 1: strategy sampled 9 clients (out of 18)
DEBUG flwr 2023-03-08 14:55:19,004 | server.py:244 | fit_ro

In [None]:
# Second simulation with the same seed; it is meant to reproduce run1.
run2 = complete_run(seed=0)

INFO flwr 2023-03-08 14:54:24,045 | app.py:148 | Starting Flower simulation, config: ServerConfig(num_rounds=1, round_timeout=None)
2023-03-08 14:54:28,528	INFO worker.py:1553 -- Started a local Ray instance.
INFO flwr 2023-03-08 14:54:30,082 | app.py:182 | Flower VCE: Ray initialized with resources: {'object_store_memory': 2147483648.0, 'CPU': 10.0, 'memory': 16810904781.0, 'node:127.0.0.1': 1.0}
INFO flwr 2023-03-08 14:54:30,084 | server.py:96 | Initializing global parameters
INFO flwr 2023-03-08 14:54:30,084 | server.py:289 | Requesting initial parameters from one random client
INFO flwr 2023-03-08 14:54:30,788 | server.py:293 | Received initial parameters from one random client
INFO flwr 2023-03-08 14:54:30,788 | server.py:98 | Evaluating initial parameters
INFO flwr 2023-03-08 14:54:30,789 | server.py:111 | FL starting
DEBUG flwr 2023-03-08 14:54:30,789 | server.py:230 | fit_round 1: strategy sampled 9 clients (out of 18)
DEBUG flwr 2023-03-08 14:54:38,156 | server.py:244 | fit_ro

In [None]:
# Display the history of the first run (per-round distributed loss and metrics).
run1

History (loss, distributed):
	round 1: 0.40468900497972216
History (metrics, distributed):
{'accuracy': [(1, 0.8957791558311662)]}

In [None]:
# Both runs were started with seed 0, so their histories should be identical.
# Check that explicitly instead of comparing the printed reprs by eye.
assert run2.losses_distributed == run1.losses_distributed, "distributed losses differ between runs"
assert run2.metrics_distributed == run1.metrics_distributed, "distributed metrics differ between runs"
run2

History (loss, distributed):
	round 1: 0.40468900497972216
History (metrics, distributed):
{'accuracy': [(1, 0.8957791558311662)]}