diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index a3f1631d80..524c26359f 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -161,6 +161,10 @@ def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]: self._last_client = targets[-1] return targets + def _stop_workflow(self, task: Task): + self.cancel_task(task) + self._is_done = True + def _process_result(self, client_task: ClientTask, fl_ctx: FLContext): # submitted shareable is stored in client_task.result # we need to update task.data with that shareable so the next target @@ -176,7 +180,9 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext): self._last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx) except Exception as ex: if rc != ReturnCode.EARLY_TERMINATION: + self._stop_workflow(task) self.log_error(fl_ctx, f"exception {secure_format_exception(ex)} from shareable_to_learnable") + return else: self.log_warning( fl_ctx, @@ -186,15 +192,21 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext): if rc == ReturnCode.EARLY_TERMINATION: if self._allow_early_termination: # the workflow is done - self.cancel_task(task) + self._stop_workflow(task) self.log_info(fl_ctx, f"Stopping workflow due to {rc} from client {client_task.client.name}") - self._is_done = True return else: self.log_warning( fl_ctx, f"Ignored {rc} from client {client_task.client.name} because early termination is not allowed", ) + else: + self._stop_workflow(task) + self.log_error( + fl_ctx, + f"Stopping workflow due to result from client {client_task.client.name} is not a Shareable", + ) + return # prepare task shareable data for next client task.data = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx) diff --git a/tests/unit_test/app_common/workflow/cyclic_ctl_test.py b/tests/unit_test/app_common/workflow/cyclic_ctl_test.py index 4fb88236dd..9f17e310e2 100644 --- a/tests/unit_test/app_common/workflow/cyclic_ctl_test.py +++ b/tests/unit_test/app_common/workflow/cyclic_ctl_test.py @@ -55,6 +55,18 @@ ] +def gen_shareable(is_early_termination: bool = False, is_not_shareable: bool = False): + if is_not_shareable: + return [1, 2, 3] + return_result = Shareable() + if is_early_termination: + return_result.set_return_code(ReturnCode.EARLY_TERMINATION) + return return_result + + +PROCESS_RESULT_TEST_CASES = [gen_shareable(is_early_termination=True), gen_shareable(is_not_shareable=True)] + + class TestCyclicController: @pytest.mark.parametrize("order,active_clients,expected_result", ORDER_TEST_CASES) def test_get_relay_orders(self, order, active_clients, expected_result): @@ -90,7 +102,8 @@ def test_control_flow_call_relay_and_wait(self): mock_method.assert_called_once() - def test_process_result(self): + @pytest.mark.parametrize("return_result", PROCESS_RESULT_TEST_CASES) + def test_process_result(self, return_result): ctl = CyclicController( persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1, allow_early_termination=True ) @@ -101,7 +114,6 @@ def test_process_result(self): Client("site-2", SITE_2_ID), ] - abort_signal = Signal() fl_ctx = FLContext() with patch.object(ctl, "cancel_task") as mock_method, patch.object( ctl.shareable_generator, "learnable_to_shareable" @@ -109,9 +121,6 @@ def test_process_result(self): mock_method1.return_value = Shareable() mock_method2.return_value = Learnable() - return_result = Shareable() - return_result.set_return_code(ReturnCode.EARLY_TERMINATION) - client_task = ClientTask( client=Mock(), task=Task(