Skip to content

Commit

Permalink
Implement FedMedian (#1461)
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Oct 31, 2022
1 parent a163c9c commit f32f96a
Show file tree
Hide file tree
Showing 4 changed files with 370 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .fedavg import FedAvg as FedAvg
from .fedavg_android import FedAvgAndroid as FedAvgAndroid
from .fedavgm import FedAvgM as FedAvgM
from .fedmedian import FedMedian as FedMedian
from .fedopt import FedOpt as FedOpt
from .fedyogi import FedYogi as FedYogi
from .qfedavg import QFedAvg as QFedAvg
Expand All @@ -36,5 +37,6 @@
"FedOpt",
"FedYogi",
"QFedAvg",
"FedMedian",
"Strategy",
]
12 changes: 12 additions & 0 deletions src/py/flwr/server/strategy/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
return weights_prime


def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays:
"""Compute median."""
# Create a list of weights and ignore the number of examples
weights = [weights for weights, _ in results]

# Compute median weight of each layer
median_w: NDArrays = [
np.median(np.asarray(layer), axis=0) for layer in zip(*weights) # type: ignore
]
return median_w


def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
"""Aggregate evaluation results obtained from multiple clients."""
num_total_evaluation_examples = sum([num_examples for num_examples, _ in results])
Expand Down
157 changes: 157 additions & 0 deletions src/py/flwr/server/strategy/fedmedian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Median (FedMedian) [Yin et al., 2018] strategy.
Paper: https://arxiv.org/pdf/1803.01498v1.pdf
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate_median
from .fedavg import FedAvg

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""

# flake8: noqa: E501
class FedMedian(FedAvg):
"""Configurable FedAvg with Momentum strategy implementation."""

# pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
def __init__(
self,
*,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
evaluate_fn: Optional[
Callable[
[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
) -> None:
"""Configurable FedMedian strategy.
Implementation based on https://arxiv.org/pdf/1803.01498v1.pdf
Parameters
----------
fraction_fit : float, optional
Fraction of clients used during training. Defaults to 0.1.
fraction_evaluate : float, optional
Fraction of clients used during validation. Defaults to 0.1.
min_fit_clients : int, optional
Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients : int, optional
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure validation. Defaults to None.
accept_failures : bool, optional
Whether or not accept rounds containing failures. Defaults to True.
initial_parameters : Parameters, optional
Initial global model parameters.
"""

if (
min_fit_clients > min_available_clients
or min_evaluate_clients > min_available_clients
):
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_fit_clients=min_fit_clients,
min_evaluate_clients=min_evaluate_clients,
min_available_clients=min_available_clients,
evaluate_fn=evaluate_fn,
on_fit_config_fn=on_fit_config_fn,
on_evaluate_config_fn=on_evaluate_config_fn,
accept_failures=accept_failures,
initial_parameters=initial_parameters,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
)
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn

def __repr__(self) -> str:
rep = f"FedMedian(accept_failures={self.accept_failures})"
return rep

def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using median."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}

# Convert results
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
parameters_aggregated = ndarrays_to_parameters(
aggregate_median(weights_results)
)

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")

return parameters_aggregated, metrics_aggregated
199 changes: 199 additions & 0 deletions src/py/flwr/server/strategy/fedmedian_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FedMedian tests."""

from typing import List, Tuple
from unittest.mock import MagicMock

from numpy import array, float32

from flwr.common import (
Code,
FitRes,
NDArrays,
Parameters,
Status,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.grpc_server.grpc_client_proxy import GrpcClientProxy

from .fedmedian import FedMedian


def test_fedmedian_num_fit_clients_20_available() -> None:
"""Test num_fit_clients function."""
# Prepare
strategy = FedMedian()
expected = 20

# Execute
actual, _ = strategy.num_fit_clients(num_available_clients=20)

# Assert
assert expected == actual


def test_fedmedian_num_fit_clients_19_available() -> None:
"""Test num_fit_clients function."""
# Prepare
strategy = FedMedian()
expected = 19

# Execute
actual, _ = strategy.num_fit_clients(num_available_clients=19)

# Assert
assert expected == actual


def test_fedmedian_num_fit_clients_10_available() -> None:
"""Test num_fit_clients function."""
# Prepare
strategy = FedMedian()
expected = 10

# Execute
actual, _ = strategy.num_fit_clients(num_available_clients=10)

# Assert
assert expected == actual


def test_fedmedian_num_fit_clients_minimum() -> None:
"""Test num_fit_clients function."""
# Prepare
strategy = FedMedian()
expected = 9

# Execute
actual, _ = strategy.num_fit_clients(num_available_clients=9)

# Assert
assert expected == actual


def test_fedmedian_num_evaluation_clients_40_available() -> None:
"""Test num_evaluation_clients function."""
# Prepare
strategy = FedMedian(fraction_evaluate=0.05)
expected = 2

# Execute
actual, _ = strategy.num_evaluation_clients(num_available_clients=40)

# Assert
assert expected == actual


def test_fedmedian_num_evaluation_clients_39_available() -> None:
"""Test num_evaluation_clients function."""
# Prepare
strategy = FedMedian(fraction_evaluate=0.05)
expected = 2

# Execute
actual, _ = strategy.num_evaluation_clients(num_available_clients=39)

# Assert
assert expected == actual


def test_fedmedian_num_evaluation_clients_20_available() -> None:
"""Test num_evaluation_clients function."""
# Prepare
strategy = FedMedian(fraction_evaluate=0.05)
expected = 2

# Execute
actual, _ = strategy.num_evaluation_clients(num_available_clients=20)

# Assert
assert expected == actual


def test_fedmedian_num_evaluation_clients_minimum() -> None:
"""Test num_evaluation_clients function."""
# Prepare
strategy = FedMedian(fraction_evaluate=0.05)
expected = 2

# Execute
actual, _ = strategy.num_evaluation_clients(num_available_clients=19)

# Assert
assert expected == actual


def test_aggregate_fit() -> None:
"""Tests if FedMedian is aggregating correctly."""
# Prepare
previous_weights: NDArrays = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)]
strategy = FedMedian(
initial_parameters=ndarrays_to_parameters(previous_weights),
)
param_0: Parameters = ndarrays_to_parameters(
[array([0.2, 0.2, 0.2, 0.2], dtype=float32)]
)
param_1: Parameters = ndarrays_to_parameters(
[array([1.0, 1.0, 1.0, 1.0], dtype=float32)]
)
param_2: Parameters = ndarrays_to_parameters(
[array([0.5, 0.5, 0.5, 0.5], dtype=float32)]
)
bridge = MagicMock()
client_0 = GrpcClientProxy(cid="0", bridge=bridge)
client_1 = GrpcClientProxy(cid="1", bridge=bridge)
client_2 = GrpcClientProxy(cid="2", bridge=bridge)
results: List[Tuple[ClientProxy, FitRes]] = [
(
client_0,
FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=param_0,
num_examples=5,
metrics={},
),
),
(
client_1,
FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=param_1,
num_examples=5,
metrics={},
),
),
(
client_2,
FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=param_2,
num_examples=5,
metrics={},
),
),
]
expected: NDArrays = [array([0.5, 0.5, 0.5, 0.5], dtype=float32)]

# Execute
actual_aggregated, _ = strategy.aggregate_fit(
server_round=1, results=results, failures=[]
)
if actual_aggregated:
actual_list = parameters_to_ndarrays(actual_aggregated)
actual = actual_list[0]
assert (actual == expected[0]).all()

0 comments on commit f32f96a

Please sign in to comment.