Skip to content

Commit

Permalink
Add more error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Mar 19, 2024
1 parent 098a942 commit 1282de9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
16 changes: 14 additions & 2 deletions nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
self._last_client = targets[-1]
return targets

def _stop_workflow(self, task: Task):
    """Cancel ``task`` and flag this controller's workflow as done.

    Args:
        task: the task passed on to ``self.cancel_task``.
    """
    # Cancel first; the done flag is only set once the cancel call returns.
    self.cancel_task(task)
    self._is_done = True

def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
# submitted shareable is stored in client_task.result
# we need to update task.data with that shareable so the next target
Expand All @@ -176,7 +180,9 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
self._last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx)
except Exception as ex:
if rc != ReturnCode.EARLY_TERMINATION:
self._stop_workflow(task)
self.log_error(fl_ctx, f"exception {secure_format_exception(ex)} from shareable_to_learnable")
return
else:
self.log_warning(
fl_ctx,
Expand All @@ -186,15 +192,21 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
if rc == ReturnCode.EARLY_TERMINATION:
if self._allow_early_termination:
# the workflow is done
self.cancel_task(task)
self._stop_workflow(task)
self.log_info(fl_ctx, f"Stopping workflow due to {rc} from client {client_task.client.name}")
self._is_done = True
return
else:
self.log_warning(
fl_ctx,
f"Ignored {rc} from client {client_task.client.name} because early termination is not allowed",
)
else:
self._stop_workflow(task)
self.log_error(
fl_ctx,
f"Stopping workflow due to result from client {client_task.client.name} is not a Shareable",
)
return

# prepare task shareable data for next client
task.data = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx)
Expand Down
19 changes: 14 additions & 5 deletions tests/unit_test/app_common/workflow/cyclic_ctl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@
]


def gen_shareable(is_early_termination: bool = False, is_not_shareable: bool = False):
    """Build a stand-in client result for tests.

    Args:
        is_early_termination: when true, the returned ``Shareable`` has its
            return code set to ``ReturnCode.EARLY_TERMINATION``.
        is_not_shareable: when true, a plain list is returned instead of a
            ``Shareable``; this flag wins over ``is_early_termination``.

    Returns:
        A new ``Shareable``, or the list ``[1, 2, 3]`` when
        ``is_not_shareable`` is set.
    """
    if not is_not_shareable:
        shareable = Shareable()
        if is_early_termination:
            shareable.set_return_code(ReturnCode.EARLY_TERMINATION)
        return shareable

    # A plain list plays the role of a result that is not a Shareable.
    return [1, 2, 3]


# Results fed to the parametrized ``test_process_result``: one Shareable that
# carries EARLY_TERMINATION, and one value that is not a Shareable at all.
PROCESS_RESULT_TEST_CASES = [
    gen_shareable(is_early_termination=True),
    gen_shareable(is_not_shareable=True),
]


class TestCyclicController:
@pytest.mark.parametrize("order,active_clients,expected_result", ORDER_TEST_CASES)
def test_get_relay_orders(self, order, active_clients, expected_result):
Expand Down Expand Up @@ -90,7 +102,8 @@ def test_control_flow_call_relay_and_wait(self):

mock_method.assert_called_once()

def test_process_result(self):
@pytest.mark.parametrize("return_result", PROCESS_RESULT_TEST_CASES)
def test_process_result(self, return_result):
ctl = CyclicController(
persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1, allow_early_termination=True
)
Expand All @@ -101,17 +114,13 @@ def test_process_result(self):
Client("site-2", SITE_2_ID),
]

abort_signal = Signal()
fl_ctx = FLContext()
with patch.object(ctl, "cancel_task") as mock_method, patch.object(
ctl.shareable_generator, "learnable_to_shareable"
) as mock_method1, patch.object(ctl.shareable_generator, "shareable_to_learnable") as mock_method2:
mock_method1.return_value = Shareable()
mock_method2.return_value = Learnable()

return_result = Shareable()
return_result.set_return_code(ReturnCode.EARLY_TERMINATION)

client_task = ClientTask(
client=Mock(),
task=Task(
Expand Down

0 comments on commit 1282de9

Please sign in to comment.