Skip to content

Commit

Permalink
Make Flower server serialization-agnostic (#721)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored May 9, 2021
1 parent 4fb04e9 commit 79bcf95
Show file tree
Hide file tree
Showing 14 changed files with 252 additions and 157 deletions.
13 changes: 8 additions & 5 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

from .client import Client

DEPRECATION_WARNING_FIT = """DEPRECATION WARNING: deprecated return format
DEPRECATION_WARNING_FIT = """
DEPRECATION WARNING: deprecated return format
parameters, num_examples
Expand All @@ -46,7 +47,8 @@
instead. Note that the deprecated return format will be removed in a future
release.
"""
DEPRECATION_WARNING_EVALUATE_0 = """DEPRECATION WARNING: deprecated return format
DEPRECATION_WARNING_EVALUATE_0 = """
DEPRECATION WARNING: deprecated return format
num_examples, loss, accuracy
Expand All @@ -57,7 +59,8 @@
instead. Note that the deprecated return format will be removed in a future
release.
"""
DEPRECATION_WARNING_EVALUATE_1 = """DEPRECATION WARNING: deprecated return format
DEPRECATION_WARNING_EVALUATE_1 = """
DEPRECATION WARNING: deprecated return format
num_examples, loss, accuracy, {"custom_key": custom_val}
Expand Down Expand Up @@ -94,9 +97,9 @@ def fit(
Parameters
----------
parameters: List[numpy.ndarray]
parameters : List[numpy.ndarray]
The current (global) model parameters.
config: Dict[str, Scalar]
config : Dict[str, Scalar]
Configuration parameters which allow the
server to influence training on the client. It can be used to
communicate arbitrary values from the server to the client, for
Expand Down
72 changes: 47 additions & 25 deletions src/py/flwr/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
import concurrent.futures
import timeit
from logging import DEBUG, INFO, WARNING
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from flwr.common import (
Disconnect,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
Parameters,
Reconnect,
Scalar,
Weights,
parameters_to_weights,
weights_to_parameters,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
Expand Down Expand Up @@ -95,7 +96,9 @@ def __init__(
self, client_manager: ClientManager, strategy: Optional[Strategy] = None
) -> None:
self._client_manager: ClientManager = client_manager
self.weights: Weights = []
self.parameters: Parameters = Parameters(
tensors=[], tensor_type="numpy.ndarray"
)
self.strategy: Strategy = strategy if strategy is not None else FedAvg()

def set_strategy(self, strategy: Strategy) -> None:
Expand All @@ -113,9 +116,9 @@ def fit(self, num_rounds: int) -> History:

# Initialize parameters
log(INFO, "Getting initial parameters")
self.weights = self._get_initial_parameters()
self.parameters = self._get_initial_parameters()
log(INFO, "Evaluating initial parameters")
res = self.strategy.evaluate(weights=self.weights)
res = self.strategy.evaluate(parameters=self.parameters)
if res is not None:
log(
INFO,
Expand All @@ -134,12 +137,12 @@ def fit(self, num_rounds: int) -> History:
# Train model and replace previous global model
res_fit = self.fit_round(rnd=current_round)
if res_fit:
weights_prime, _, _ = res_fit # fit_metrics_aggregated
if weights_prime:
self.weights = weights_prime
parameters_prime, _, _ = res_fit # fit_metrics_aggregated
if parameters_prime:
self.parameters = parameters_prime

# Evaluate model using strategy implementation
res_cen = self.strategy.evaluate(weights=self.weights)
res_cen = self.strategy.evaluate(parameters=self.parameters)
if res_cen is not None:
loss_cen, metrics_cen = res_cen
log(
Expand Down Expand Up @@ -190,7 +193,7 @@ def evaluate_round(

# Get clients and their respective instructions from strategy
client_instructions = self.strategy.configure_evaluate(
rnd=rnd, weights=self.weights, client_manager=self._client_manager
rnd=rnd, parameters=self.parameters, client_manager=self._client_manager
)
if not client_instructions:
log(INFO, "evaluate_round: no clients selected, cancel")
Expand All @@ -212,24 +215,35 @@ def evaluate_round(
)

# Aggregate the evaluation results
aggregated_result = self.strategy.aggregate_evaluate(rnd, results, failures)
if isinstance(aggregated_result, float) or aggregated_result is None:
aggregated_result: Union[
Tuple[Optional[float], Dict[str, Scalar]],
Optional[float], # Deprecated
] = self.strategy.aggregate_evaluate(rnd, results, failures)

metrics_aggregated: Dict[str, Scalar] = {}
if aggregated_result is None:
# Backward-compatibility, this will be removed in a future update
log(WARNING, DEPRECATION_WARNING_EVALUATE_ROUND)
loss_aggregated = None
elif isinstance(aggregated_result, float):
# Backward-compatibility, this will be removed in a future update
log(WARNING, DEPRECATION_WARNING_EVALUATE_ROUND)
loss_aggregated = aggregated_result
metrics_aggregated: Dict[str, Scalar] = {}
else:
loss_aggregated, metrics_aggregated = aggregated_result

return loss_aggregated, metrics_aggregated, (results, failures)

def fit_round(
self, rnd: int
) -> Optional[Tuple[Optional[Weights], Dict[str, Scalar], FitResultsAndFailures]]:
) -> Optional[
Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
]:
"""Perform a single round of federated averaging."""

# Get clients and their respective instructions from strategy
client_instructions = self.strategy.configure_fit(
rnd=rnd, weights=self.weights, client_manager=self._client_manager
rnd=rnd, parameters=self.parameters, client_manager=self._client_manager
)
if not client_instructions:
log(INFO, "fit_round: no clients selected, cancel")
Expand All @@ -251,26 +265,35 @@ def fit_round(
)

# Aggregate training results
aggregated_result = self.strategy.aggregate_fit(rnd, results, failures)
if isinstance(aggregated_result, list) or aggregated_result is None:
aggregated_result: Union[
Tuple[Optional[Parameters], Dict[str, Scalar]],
Optional[Weights], # Deprecated
] = self.strategy.aggregate_fit(rnd, results, failures)

metrics_aggregated: Dict[str, Scalar] = {}
if aggregated_result is None:
# Backward-compatibility, this will be removed in a future update
log(WARNING, DEPRECATION_WARNING_FIT_ROUND)
weights_aggregated = aggregated_result
metrics_aggregated: Dict[str, Scalar] = {}
parameters_aggregated = None
elif isinstance(aggregated_result, list):
# Backward-compatibility, this will be removed in a future update
log(WARNING, DEPRECATION_WARNING_FIT_ROUND)
parameters_aggregated = weights_to_parameters(aggregated_result)
else:
weights_aggregated, metrics_aggregated = aggregated_result
return weights_aggregated, metrics_aggregated, (results, failures)
parameters_aggregated, metrics_aggregated = aggregated_result

return parameters_aggregated, metrics_aggregated, (results, failures)

def disconnect_all_clients(self) -> None:
"""Send shutdown signal to all clients."""
all_clients = self._client_manager.all()
_ = shutdown(clients=[all_clients[k] for k in all_clients.keys()])

def _get_initial_parameters(self) -> Weights:
def _get_initial_parameters(self) -> Parameters:
"""Get initial parameters from one of the available clients."""

# Server-side parameter initialization
parameters: Optional[Weights] = self.strategy.initialize_parameters(
parameters: Optional[Parameters] = self.strategy.initialize_parameters(
client_manager=self._client_manager
)
if parameters is not None:
Expand All @@ -280,9 +303,8 @@ def _get_initial_parameters(self) -> Weights:
# Get initial parameters from one of the clients
random_client = self._client_manager.sample(1)[0]
parameters_res = random_client.get_parameters()
parameters = parameters_to_weights(parameters_res.parameters)
log(INFO, "Received initial parameters from one random client")
return parameters
return parameters_res.parameters


def shutdown(clients: List[ClientProxy]) -> ReconnectResultsAndFailures:
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/strategy/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from typing import Callable, Dict, Optional, Tuple

from flwr.common import Scalar, Weights
from flwr.common import Parameters, Scalar, Weights

from .fedavg import FedAvg

Expand All @@ -45,7 +45,7 @@ def __init__(
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Weights] = None,
initial_parameters: Optional[Parameters] = None,
) -> None:
super().__init__(
fraction_fit=fraction_fit,
Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/server/strategy/fast_and_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
EvaluateRes,
FitIns,
FitRes,
Parameters,
Scalar,
Weights,
parameters_to_weights,
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(
r_slow: int = 1,
t_fast: int = 10,
t_slow: int = 10,
initial_parameters: Optional[Weights] = None,
initial_parameters: Optional[Parameters] = None,
) -> None:
super().__init__(
fraction_fit=fraction_fit,
Expand Down Expand Up @@ -109,7 +110,7 @@ def __repr__(self) -> str:

# pylint: disable=too-many-locals
def configure_fit(
self, rnd: int, weights: Weights, client_manager: ClientManager
self, rnd: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training."""

Expand Down Expand Up @@ -175,7 +176,6 @@ def configure_fit(
)

# Prepare parameters and config
parameters = weights_to_parameters(weights)
config = {}
if self.on_fit_config_fn is not None:
# Use custom fit config function if provided
Expand Down Expand Up @@ -302,7 +302,7 @@ def aggregate_fit(
rnd: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[BaseException],
) -> Tuple[Optional[Weights], Dict[str, Scalar]]:
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
Expand Down Expand Up @@ -347,7 +347,7 @@ def aggregate_fit(
)
self.durations.append(cid_duration)

return weights_prime, {}
return weights_to_parameters(weights_prime), {}

def aggregate_evaluate(
self,
Expand Down
16 changes: 12 additions & 4 deletions src/py/flwr/server/strategy/fault_tolerant_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@

from typing import Callable, Dict, List, Optional, Tuple

from flwr.common import EvaluateRes, FitRes, Scalar, Weights, parameters_to_weights
from flwr.common import (
EvaluateRes,
FitRes,
Parameters,
Scalar,
Weights,
parameters_to_weights,
weights_to_parameters,
)
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate, weighted_loss_avg
Expand All @@ -42,7 +50,7 @@ def __init__(
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
min_completion_rate_fit: float = 0.5,
min_completion_rate_evaluate: float = 0.5,
initial_parameters: Optional[Weights] = None,
initial_parameters: Optional[Parameters] = None,
) -> None:
super().__init__(
fraction_fit=fraction_fit,
Expand All @@ -67,7 +75,7 @@ def aggregate_fit(
rnd: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[BaseException],
) -> Tuple[Optional[Weights], Dict[str, Scalar]]:
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
Expand All @@ -81,7 +89,7 @@ def aggregate_fit(
(parameters_to_weights(fit_res.parameters), fit_res.num_examples)
for client, fit_res in results
]
return aggregate(weights_results), {}
return weights_to_parameters(aggregate(weights_results)), {}

def aggregate_evaluate(
self,
Expand Down
14 changes: 8 additions & 6 deletions src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import List, Optional, Tuple
from unittest.mock import MagicMock

from flwr.common import EvaluateRes, FitRes, Parameters, Weights
from flwr.common import EvaluateRes, FitRes, Parameters, Weights, parameters_to_weights
from flwr.server.client_proxy import ClientProxy

from .fault_tolerant_fedavg import FaultTolerantFedAvg
Expand All @@ -30,7 +30,7 @@ def test_aggregate_fit_no_results_no_failures() -> None:
strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.1)
results: List[Tuple[ClientProxy, FitRes]] = []
failures: List[BaseException] = []
expected: Optional[Weights] = None
expected: Optional[Parameters] = None

# Execute
actual, _ = strategy.aggregate_fit(1, results, failures)
Expand All @@ -45,7 +45,7 @@ def test_aggregate_fit_no_results() -> None:
strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.1)
results: List[Tuple[ClientProxy, FitRes]] = []
failures: List[BaseException] = [Exception()]
expected: Optional[Weights] = None
expected: Optional[Parameters] = None

# Execute
actual, _ = strategy.aggregate_fit(1, results, failures)
Expand All @@ -62,7 +62,7 @@ def test_aggregate_fit_not_enough_results() -> None:
(MagicMock(), FitRes(Parameters(tensors=[], tensor_type=""), 1, 1, 0.1))
]
failures: List[BaseException] = [Exception(), Exception()]
expected: Optional[Weights] = None
expected: Optional[Parameters] = None

# Execute
actual, _ = strategy.aggregate_fit(1, results, failures)
Expand All @@ -85,7 +85,8 @@ def test_aggregate_fit_just_enough_results() -> None:
actual, _ = strategy.aggregate_fit(1, results, failures)

# Assert
assert actual == expected
assert actual
assert parameters_to_weights(actual) == expected


def test_aggregate_fit_no_failures() -> None:
Expand All @@ -102,7 +103,8 @@ def test_aggregate_fit_no_failures() -> None:
actual, _ = strategy.aggregate_fit(1, results, failures)

# Assert
assert actual == expected
assert actual
assert parameters_to_weights(actual) == expected


def test_aggregate_evaluate_no_results_no_failures() -> None:
Expand Down
Loading

0 comments on commit 79bcf95

Please sign in to comment.