In [1]:
import client
import model
import dataset
import flwr as fl
from flwr.common import Metrics

import torch

from typing import List, Tuple

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client evaluation accuracy into one example-weighted mean.

    Each client's accuracy is weighted by the number of examples it evaluated
    on, so clients with more data contribute proportionally more.

    Args:
        metrics: One ``(num_examples, client_metrics)`` pair per client; each
            ``client_metrics`` mapping must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when the clients
        evaluated on zero examples in total (no meaningful average exists).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Empty round or clients without data: avoid ZeroDivisionError and
        # report no aggregated metric rather than a fabricated value.
        return {}

    # Multiply accuracy of each client by the number of examples it used
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

In [3]:
# Single source of truth for the simulation size: the strategy minimums below
# must agree with num_clients, or the server can wait forever for clients.
NUM_CLIENTS = 10
NUM_ROUNDS = 5

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=NUM_CLIENTS,  # Never sample fewer than all clients for training
    min_evaluate_clients=NUM_CLIENTS // 2,  # Never sample fewer than half the clients for evaluation
    min_available_clients=NUM_CLIENTS,  # Wait until all clients are available
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Start simulation; keep the returned History so later cells can inspect it
history = fl.simulation.start_simulation(
    client_fn=client.client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
)
history

INFO flower 2022-12-06 12:42:21,197 | app.py:145 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flower 2022-12-06 12:42:24,363 | app.py:179 | Flower VCE: Ray initialized with resources: {'memory': 5886507419.0, 'object_store_memory': 2943253708.0, 'accelerator_type:G': 1.0, 'GPU': 1.0, 'CPU': 8.0, 'node:128.179.187.61': 1.0}
INFO flower 2022-12-06 12:42:24,369 | server.py:86 | Initializing global parameters
INFO flower 2022-12-06 12:42:24,372 | server.py:270 | Requesting initial parameters from one random client
[2m[36m(launch_and_get_parameters pid=223355)[0m   return torch._C._cuda_getDeviceCount() > 0
INFO flower 2022-12-06 12:42:26,110 | server.py:274 | Received initial parameters from one random client
INFO flower 2022-12-06 12:42:26,111 | server.py:88 | Evaluating initial parameters
INFO flower 2022-12-06 12:42:26,112 | server.py:101 | FL starting
DEBUG flower 2022-12-06 12:42:26,113 | server.py:220 | fit_round 1: strategy sampled 10 

History (loss, distributed):
	round 1: 0.06731279722849529
	round 2: 0.042655788878599804
	round 3: 0.019700912460684776
	round 4: 0.010846911819030842
	round 5: 0.028645703474680585
History (metrics, distributed):
{'accuracy': [(1, 0.6020000000000001), (2, 0.5656666666666667), (3, 0.902), (4, 0.9526666666666667), (5, 0.713)]}

In [4]:
def test_cnn_size_mnist() -> None:
    """Test number of parameters with MNIST-sized inputs.

    Pins ``model.Net``'s total parameter count so that accidental
    architecture changes are caught immediately.
    """
    # Prepare
    net = model.Net()
    expected = 1_663_370

    # Execute: count every element of every parameter tensor (no temporary list)
    actual = sum(p.numel() for p in net.parameters())

    # Assert
    assert actual == expected, f"expected {expected:,} parameters, got {actual:,}"


test_cnn_size_mnist()