In [1]:
import flwr
from typing import List, Tuple, Dict, Union

def mean_fit_metrics(
    metrics: List[Tuple[int, Dict[str, Union[bool, bytes, float, int, str]]]],
) -> Dict[str, object]:
    """Aggregate the per-client fit metrics of one round.

    Passed to FedAvg as ``fit_metrics_aggregation_fn``; it receives one
    ``(num_examples, metrics_dict)`` pair per fit result.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per fit result.
            A tuple of pairs is accepted as well as a list.

    Returns:
        ``{}`` when ``metrics`` is empty; otherwise a dict with:

        * ``"aggregated_metrics"`` -- for every numeric (int/float, not bool)
          metric reported by at least one client, the mean over the clients
          that reported it, weighted by ``num_examples``. If those clients
          report zero examples in total, the unweighted mean is used instead.
          Non-numeric values (str/bytes) and bools are skipped.
        * ``"last_epoch_metrics"`` -- the metrics dict of the last result.
        * ``"parameters"`` -- the metrics dict of the first result.
        * ``"train_gen_length"`` -- the metrics dict of the second-to-last
          result, or ``None`` when there are fewer than two results.

        NOTE(review): ``"parameters"`` and ``"train_gen_length"`` are kept
        only so the return shape stays backward compatible; they hold client
        metric dicts, not model parameters or a dataset length.
    """
    if not metrics:
        return {}

    # Gather (num_examples, value) samples per metric name across *all*
    # results, so a metric missing from some client does not count as 0.
    samples: Dict[str, List[Tuple[int, float]]] = {}
    for num_examples, client_metrics in metrics:
        for name, value in client_metrics.items():
            # bool is a subclass of int, so exclude it explicitly; str/bytes
            # cannot be averaged and would make sum() raise.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                continue
            samples.setdefault(name, []).append((num_examples, value))

    aggregated_metrics: Dict[str, float] = {}
    for name, pairs in samples.items():
        total_examples = sum(n for n, _ in pairs)
        if total_examples > 0:
            aggregated_metrics[name] = sum(n * v for n, v in pairs) / total_examples
        else:
            # No example counts to weight by: fall back to the plain mean.
            aggregated_metrics[name] = sum(v for _, v in pairs) / len(pairs)

    return {
        "parameters": metrics[0][1],
        # Guarded: a single result used to raise IndexError here.
        "train_gen_length": metrics[-2][1] if len(metrics) >= 2 else None,
        "last_epoch_metrics": metrics[-1][1],
        "aggregated_metrics": aggregated_metrics,
    }


if __name__ == "__main__":
    # FedAvg strategy that aggregates client fit metrics with mean_fit_metrics.
    strategy = flwr.server.strategy.FedAvg(
        fraction_fit=0.5,
        min_fit_clients=2,
        min_available_clients=2,
        fit_metrics_aggregation_fn=mean_fit_metrics,
    )

    # Start the Flower server: 5 rounds, 60 s timeout per round.
    flwr.server.start_server(
        server_address="127.0.0.1:8890",
        config=flwr.server.ServerConfig(
            num_rounds=5, round_timeout=60
        ),
        strategy=strategy,
    )


[92mINFO [0m:      Starting Flower server, config: num_rounds=5, round_timeout=60s
[92mINFO [0m:      Flower ECE: gRPC server running (5 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Evaluating initial global parameters
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy samp