Skip to content

Commit

Permalink
Improve InProcessClientAPIExecutor (#2536)
Browse files Browse the repository at this point in the history
* 1. rename ExeTaskFnWrapper class to TaskScriptRunner
2. Replace implementation of the inprocess function exection from calling a main() function to user runpy.run_path() which reduce the user requirements to have main() function
3. redirect print() to logger.info()

* 1. rename ExeTaskFnWrapper class to TaskScriptRunner
2. Replace implementation of the inprocess function exection from calling a main() function to user runpy.run_path() which reduce the user requirements to have main() function
3. redirect print() to logger.info()

* make result check and result pull use the same configurable variable

* rename exec_task_fn_wrapper to task_script_runner.py

* fix typo
  • Loading branch information
chesterxgchen committed Apr 30, 2024
1 parent b0faf52 commit ce79609
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 187 deletions.
79 changes: 0 additions & 79 deletions nvflare/app_common/executors/exec_task_fn_wrapper.py

This file was deleted.

16 changes: 8 additions & 8 deletions nvflare/app_common/executors/in_process_client_api_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.app_common.abstract.params_converter import ParamsConverter
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.exec_task_fn_wrapper import ExecTaskFuncWrapper
from nvflare.app_common.executors.task_script_runner import TaskScriptRunner
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.streaming import send_analytic_dxo
from nvflare.client.api_spec import CLIENT_API_KEY
Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(
submit_model_task_name: str = "submit_model",
):
super(InProcessClientAPIExecutor, self).__init__()
self._client_api = None
self._result_pull_interval = result_pull_interval
self._log_pull_interval = log_pull_interval
self._params_exchange_format = params_exchange_format
Expand Down Expand Up @@ -104,18 +105,17 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
self._fl_ctx = fl_ctx
self._init_converter(fl_ctx)

self._task_fn_path = self._task_script_path.replace(".py", ".main")
self._task_fn_wrapper = ExecTaskFuncWrapper(
task_fn_path=self._task_fn_path, task_main_args=self._task_script_args
self._task_fn_wrapper = TaskScriptRunner(
script_path=self._task_script_path, script_args=self._task_script_args
)

self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run)
self._task_fn_thread.start()

meta = self._prepare_task_meta(fl_ctx, None)
self.client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=0.5)
self.client_api.init()
self._data_bus.put_data(CLIENT_API_KEY, self.client_api)
self._client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=self._result_pull_interval)
self._client_api.init()
self._data_bus.put_data(CLIENT_API_KEY, self._client_api)

elif event_type == EventType.END_RUN:
self._event_manager.fire_event(TOPIC_STOP, "END_RUN received")
Expand All @@ -128,7 +128,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
fl_ctx.set_prop("abort_signal", abort_signal)

meta = self._prepare_task_meta(fl_ctx, task_name)
self.client_api.set_meta(meta)
self._client_api.set_meta(meta)

shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id())
shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name())
Expand Down
77 changes: 77 additions & 0 deletions nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import logging
import os
import sys
import traceback

print_fn = builtins.print


class TaskScriptRunner:
logger = logging.getLogger(__name__)

def __init__(self, script_path: str, script_args: str = None):
"""Wrapper for function given function path and args
Args:
script_path (str): script file name, such as train.py
script_args (str, Optional): script arguments to pass in.
"""
self.script_args = script_args
self.client_api = None
self.logger = logging.getLogger(self.__class__.__name__)
self.script_path = self.get_script_full_path(script_path)

def run(self):
"""Call the task_fn with any required arguments."""
self.logger.info(f"\n start task run() with {self.script_path}")
try:
import runpy

curr_argv = sys.argv
builtins.print = log_print
sys.argv = self.get_sys_argv()
runpy.run_path(self.script_path, run_name="__main__")
sys.argv = curr_argv

except Exception as e:
msg = traceback.format_exc()
self.logger.error(msg)
if self.client_api:
self.client_api.exec_queue.ask_abort(msg)
raise e
finally:
builtins.print = print_fn

def get_sys_argv(self):
args_list = [] if not self.script_args else self.script_args.split()
return [self.script_path] + args_list

def get_script_full_path(self, script_path) -> str:
target_files = None
for r, dirs, files in os.walk(os.getcwd()):
target_files = [os.path.join(r, f) for f in files if f == script_path]
if target_files:
break
if not target_files:
raise ValueError(f"{script_path} is not found")
return target_files[0]


def log_print(*args, logger=TaskScriptRunner.logger, **kwargs):
# Create a message from print arguments
message = " ".join(str(arg) for arg in args)
logger.info(message)
100 changes: 0 additions & 100 deletions tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py

This file was deleted.

28 changes: 28 additions & 0 deletions tests/unit_test/app_common/executors/task_script_runner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest

from nvflare.app_common.executors.task_script_runner import TaskScriptRunner


class TestExecTaskFuncWrapper(unittest.TestCase):
def test_app_scripts_and_args(self):
curr_dir = os.getcwd()
script_path = "cli.py"
script_args = "--batch_size 4"
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"])

0 comments on commit ce79609

Please sign in to comment.