Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable metric aggregation through user-provided functions #1144

Merged
merged 11 commits into from
Mar 25, 2022

Conversation

danieljanes
Copy link
Member

@danieljanes danieljanes commented Mar 23, 2022

This PR enables users to customize metric aggregation in built-in strategies. This prevents users from having to customize the strategy just to customize the aggregation of metrics dicts.

Currently, users have to do this:

import flwr as fl
from typing import List, Optional


class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    def aggregate_evaluate(
        self,
        rnd: int,
        results: List,
        failures: List[BaseException],
    ) -> Optional[float]:
        """Aggregate evaluation losses using weighted average."""
        if not results:
            return None

        # Weigh accuracy of each client by number of examples used
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]

        # Aggregate and print custom metric
        accuracy_aggregated = sum(accuracies) / sum(examples)
        print(
            f"Round {rnd} accuracy aggregated from client results: {accuracy_aggregated}"
        )

        # Call aggregate_evaluate from base class (FedAvg)
        params, _ = super().aggregate_evaluate(rnd, results, failures)
        return params, {"accuracy": accuracy_aggregated}


# Define strategy
strategy = AggregateCustomMetricStrategy()

# Start server
fl.server.start_server(
    server_address="[::]:8080",
    config={"num_rounds": 3},
    strategy=strategy,
)

With this PR, users can instead do this:

import flwr as fl

# Define metric aggregation function
def agg(metrics):
    # Weigh accuracy of each client by number of examples used
    accuracies = [m["accuracy"] * n for m, n in metrics]
    examples = [n for _, n in metrics]

    # Aggregate and return custom metric
    return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=agg)

# Start Flower server
fl.server.start_server(
    server_address="[::]:8080",
    config={"num_rounds": 3},
    strategy=strategy,
)

tanertopal
tanertopal previously approved these changes Mar 24, 2022
Copy link
Member

@tanertopal tanertopal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some nitpicking but overall lgtm.

src/py/flwr/server/strategy/fedavg.py Outdated Show resolved Hide resolved
src/py/flwr/server/strategy/fedavg.py Outdated Show resolved Hide resolved
@danieljanes danieljanes marked this pull request as ready for review March 24, 2022 18:57
@danieljanes danieljanes merged commit 23617cd into main Mar 25, 2022
@danieljanes danieljanes deleted the metrics-aggregation-fn branch March 25, 2022 16:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants