Skip to content

Commit

Permalink
enhance abort, set peer props
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed May 2, 2024
1 parent 88d26a2 commit 95a9be5
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
2 changes: 2 additions & 0 deletions job_templates/sag_cse_ccwf_pt/config_fed_server.conf
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
id = "svr_ctl"
path = "nvflare.app_common.ccwf.server_ctl.ServerSideController"
args {
# the prefix for task names of this workflow
task_name_prefix = "wf1"
# the maximum amount of time a client may go without sending a status report
max_status_report_interval = 300
# policy for choosing the client that runs the controller logic
starting_client_policy = "random"
Expand Down
3 changes: 3 additions & 0 deletions nvflare/apis/impl/wf_comm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def broadcast_and_wait(
client_task.result = error_reply
break

if not reply.get_peer_props() and fl_ctx.get_peer_context():
reply.set_peer_props(fl_ctx.get_peer_context().get_all_public_props())

# assign the reply to the client task in preparation for the result_received_cb
client_task.result = reply

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,6 @@ def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
return False

contributor_name = shareable.get_peer_prop(key=ReservedKey.IDENTITY_NAME, default="?")

peer_ctx = fl_ctx.get_peer_context()
if contributor_name == "?" and peer_ctx:
contributor_name = peer_ctx.get_identity_name("?")

contribution_round = shareable.get_cookie(AppConstants.CONTRIBUTION_ROUND)

rc = shareable.get_return_code()
Expand Down
26 changes: 24 additions & 2 deletions nvflare/app_common/ccwf/client_controller_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.common import Constant, StatusReport, make_task_name, topic_for_end_workflow
from nvflare.fuel.utils.validation_utils import check_number_range
from nvflare.security.logging import secure_format_exception


class ClientControllerExecutor(Executor):
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
self.me = None
self.is_starting_client = False
self.workflow_done = False
self.fatal_system_error = False

def get_config_prop(self, name: str, default=None):
"""
Expand Down Expand Up @@ -149,6 +151,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
self.asked_to_stop = True
self.finalize(fl_ctx)

elif event_type == EventType.FATAL_SYSTEM_ERROR:
if self.is_starting_client and not self.fatal_system_error:
self.fatal_system_error = True
self.fire_fed_event(EventType.FATAL_SYSTEM_ERROR, Shareable(), fl_ctx)

def _add_status_report(self, report: StatusReport, fl_ctx: FLContext):
reports = fl_ctx.get_prop(Constant.STATUS_REPORTS)
if not reports:
Expand Down Expand Up @@ -182,6 +189,10 @@ def finalize(self, fl_ctx: FLContext):
self.workflow_done = True

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
if self.workflow_done:
self.log_error(fl_ctx, f"ClientControllerExecutor is finalized, not executing task {task_name}.")
return make_reply(ReturnCode.ERROR)

if task_name == self.configure_task_name:
self.config = shareable[Constant.CONFIG]
my_wf_id = self.get_config_prop(FLContextKey.WORKFLOW)
Expand Down Expand Up @@ -210,7 +221,15 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
self.controller = self.initialize_controller(controller_id, fl_ctx)
self.log_info(fl_ctx, f"Starting control flow {self.controller.name}")

res = self.controller.control_flow(abort_signal, fl_ctx)
try:
res = self.controller.control_flow(abort_signal, fl_ctx)
except Exception as e:
error_msg = f"{controller_id} control_flow exception: {secure_format_exception(e)}"
self.log_error(fl_ctx, error_msg)
self.system_panic(error_msg, fl_ctx)

if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

if hasattr(self.controller, "persistor"):
self.broadcast_final_result(self.controller.persistor.load(fl_ctx), fl_ctx)
Expand All @@ -219,7 +238,10 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort

self.log_info(fl_ctx, f"Finished control flow {self.controller.name}")

self.update_status(action="finished_start_task", error=None, all_done=True)
self.update_status(action=f"finished_{controller_id}", error=None, all_done=True)

self.update_status(action=f"finished_start_task", error=None, all_done=True)

return make_reply(ReturnCode.OK)

elif task_name == self.report_final_result_task_name:
Expand Down

0 comments on commit 95a9be5

Please sign in to comment.