diff --git a/nvflare/app_common/workflows/base_fedavg.py b/nvflare/app_common/workflows/base_fedavg.py index 7587fa40d5..f691034cbe 100644 --- a/nvflare/app_common/workflows/base_fedavg.py +++ b/nvflare/app_common/workflows/base_fedavg.py @@ -107,7 +107,7 @@ def _check_results(results: List[FLModel]): raise ValueError(f"Result from client(s) {empty_clients} is empty!") @staticmethod - def _aggregate_fn(results: List[FLModel]) -> FLModel: + def aggregate_fn(results: List[FLModel]) -> FLModel: aggregation_helper = WeightedAggregationHelper() for _result in results: aggregation_helper.add( @@ -141,7 +141,7 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel: self._check_results(results) if not aggregate_fn: - aggregate_fn = self._aggregate_fn + aggregate_fn = self.aggregate_fn self.info(f"aggregating {len(results)} update(s) at round {self.current_round}") try: diff --git a/nvflare/app_common/workflows/fedavg.py b/nvflare/app_common/workflows/fedavg.py index d03ae8999b..404871582d 100644 --- a/nvflare/app_common/workflows/fedavg.py +++ b/nvflare/app_common/workflows/fedavg.py @@ -57,8 +57,8 @@ def run(self) -> None: results = self.send_model_and_wait(targets=clients, data=model) aggregate_results = self.aggregate( - results, aggregate_fn=None - ) # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used + results, aggregate_fn=self.aggregate_fn + ) # using default aggregate_fn with `WeightedAggregationHelper`. Can overwrite self.aggregrate_fn with signature Callable[List[FLModel], FLModel] model = self.update_model(model, aggregate_results)