In [1]:
import client
import model
import dataset
import flwr as fl
from flwr.common import Metrics

import torch

from typing import List, Tuple

In [2]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client accuracy into an example-weighted average.

    Args:
        metrics: One ``(num_examples, client_metrics)`` pair per client. Each
            ``client_metrics`` mapping must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when there is nothing
        to average (no pairs were supplied, or the reported example counts sum
        to zero) rather than raising ``ZeroDivisionError``.

    Raises:
        KeyError: If a client's metrics lack the ``"accuracy"`` key.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)

    # Without any examples the weighted mean is undefined; report no metric
    # instead of dividing by zero.
    if not total_examples:
        return {}

    # Weight each client's accuracy by the number of examples it evaluated on.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

In [3]:
# Simulation size knobs, named once so the strategy thresholds and the
# simulated client pool cannot drift apart.
NUM_CLIENTS = 10
NUM_ROUNDS = 5

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=NUM_CLIENTS,  # Never sample fewer than all clients for training
    min_evaluate_clients=NUM_CLIENTS // 2,  # Never sample fewer than half for evaluation
    min_available_clients=NUM_CLIENTS,  # Wait until all clients are available
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Start simulation and keep the returned History so later cells can inspect
# or plot the per-round loss and accuracy.
history = fl.simulation.start_simulation(
    client_fn=client.client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
)

# Last expression of the cell: still renders the History summary as output.
history

INFO flower 2022-12-05 15:45:06,915 | app.py:143 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2022-12-05 15:45:08,898	INFO worker.py:1528 -- Started a local Ray instance.
INFO flower 2022-12-05 15:45:10,208 | app.py:177 | Flower VCE: Ray initialized with resources: {'object_store_memory': 2553102336.0, 'node:172.22.22.55': 1.0, 'CPU': 8.0, 'memory': 5106204672.0, 'accelerator_type:G': 1.0, 'GPU': 1.0}
INFO flower 2022-12-05 15:45:10,210 | server.py:86 | Initializing global parameters
INFO flower 2022-12-05 15:45:10,211 | server.py:270 | Requesting initial parameters from one random client
INFO flower 2022-12-05 15:45:11,316 | server.py:274 | Received initial parameters from one random client
INFO flower 2022-12-05 15:45:11,317 | server.py:88 | Evaluating initial parameters
INFO flower 2022-12-05 15:45:11,318 | server.py:101 | FL starting
DEBUG flower 2022-12-05 15:45:11,319 | server.py:220 | fit_round 1: strategy sampled 10 clients (out of 10)
[

History (loss, distributed):
	round 1: 0.06668758010864259
	round 2: 0.04564744263887405
	round 3: 0.03340614410241445
	round 4: 0.029205763339996332
	round 5: 0.010537337633470693
History (metrics, distributed):
{'accuracy': [(1, 0.5733333333333334), (2, 0.5646666666666667), (3, 0.7026666666666667), (4, 0.7116666666666667), (5, 0.954)]}

In [171]:
def test_cnn_size_mnist() -> None:
    """Test number of parameters with MNIST-sized inputs."""
    # Prepare
    net = model.Net()
    expected = 1_663_370

    # Execute
    actual = sum([p.numel() for p in net.parameters()])

    # Assert
    assert actual == expected

# Run the check right away so this cell raises AssertionError if the parameter
# count of model.Net no longer matches the expected value.
test_cnn_size_mnist()