Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve InProcessClientAPIExecutor #2536

Merged
merged 7 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions nvflare/app_common/executors/exec_task_fn_wrapper.py

This file was deleted.

16 changes: 8 additions & 8 deletions nvflare/app_common/executors/in_process_client_api_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.app_common.abstract.params_converter import ParamsConverter
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.exec_task_fn_wrapper import ExecTaskFuncWrapper
from nvflare.app_common.executors.task_script_runner import TaskScriptRunner
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.streaming import send_analytic_dxo
from nvflare.client.api_spec import CLIENT_API_KEY
Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(
submit_model_task_name: str = "submit_model",
):
super(InProcessClientAPIExecutor, self).__init__()
self._client_api = None
self._result_pull_interval = result_pull_interval
self._log_pull_interval = log_pull_interval
self._params_exchange_format = params_exchange_format
Expand Down Expand Up @@ -104,18 +105,17 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
self._fl_ctx = fl_ctx
self._init_converter(fl_ctx)

self._task_fn_path = self._task_script_path.replace(".py", ".main")
self._task_fn_wrapper = ExecTaskFuncWrapper(
task_fn_path=self._task_fn_path, task_main_args=self._task_script_args
self._task_fn_wrapper = TaskScriptRunner(
script_path=self._task_script_path, script_args=self._task_script_args
)

self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run)
self._task_fn_thread.start()

meta = self._prepare_task_meta(fl_ctx, None)
self.client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=0.5)
self.client_api.init()
self._data_bus.put_data(CLIENT_API_KEY, self.client_api)
self._client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=self._result_pull_interval)
self._client_api.init()
self._data_bus.put_data(CLIENT_API_KEY, self._client_api)

elif event_type == EventType.END_RUN:
self._event_manager.fire_event(TOPIC_STOP, "END_RUN received")
Expand All @@ -128,7 +128,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
fl_ctx.set_prop("abort_signal", abort_signal)

meta = self._prepare_task_meta(fl_ctx, task_name)
self.client_api.set_meta(meta)
self._client_api.set_meta(meta)

shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id())
shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name())
Expand Down
77 changes: 77 additions & 0 deletions nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import logging
import os
import sys
import traceback

print_fn = builtins.print


class TaskScriptRunner:
logger = logging.getLogger(__name__)

def __init__(self, script_path: str, script_args: str = None):
"""Wrapper for function given function path and args

Args:
script_path (str): script file name, such as train.py
script_args (str, Optional): script arguments to pass in.
"""
self.script_args = script_args
self.client_api = None
self.logger = logging.getLogger(self.__class__.__name__)
self.script_path = self.get_script_full_path(script_path)

def run(self):
"""Call the task_fn with any required arguments."""
self.logger.info(f"\n start task run() with {self.script_path}")
try:
import runpy

curr_argv = sys.argv
builtins.print = log_print
sys.argv = self.get_sys_argv()
runpy.run_path(self.script_path, run_name="__main__")
sys.argv = curr_argv

except Exception as e:
msg = traceback.format_exc()
self.logger.error(msg)
if self.client_api:
self.client_api.exec_queue.ask_abort(msg)
raise e
finally:
builtins.print = print_fn

def get_sys_argv(self):
args_list = [] if not self.script_args else self.script_args.split()
return [self.script_path] + args_list

def get_script_full_path(self, script_path) -> str:
target_files = None
for r, dirs, files in os.walk(os.getcwd()):
target_files = [os.path.join(r, f) for f in files if f == script_path]
if target_files:
break
if not target_files:
raise ValueError(f"{script_path} is not found")
return target_files[0]


def log_print(*args, logger=TaskScriptRunner.logger, **kwargs):
# Create a message from print arguments
message = " ".join(str(arg) for arg in args)
logger.info(message)
100 changes: 0 additions & 100 deletions tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py

This file was deleted.

28 changes: 28 additions & 0 deletions tests/unit_test/app_common/executors/task_script_runner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest

from nvflare.app_common.executors.task_script_runner import TaskScriptRunner


class TestExecTaskFuncWrapper(unittest.TestCase):
def test_app_scripts_and_args(self):
curr_dir = os.getcwd()
script_path = "cli.py"
script_args = "--batch_size 4"
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"])
Loading