-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve InProcessClientAPIExecutor (#2536)
* 1. rename ExeTaskFnWrapper class to TaskScriptRunner 2. Replace implementation of the inprocess function exection from calling a main() function to user runpy.run_path() which reduce the user requirements to have main() function 3. redirect print() to logger.info() * 1. rename ExeTaskFnWrapper class to TaskScriptRunner 2. Replace implementation of the inprocess function exection from calling a main() function to user runpy.run_path() which reduce the user requirements to have main() function 3. redirect print() to logger.info() * make result check and result pull use the same configurable variable * rename exec_task_fn_wrapper to task_script_runner.py * fix typo
- Loading branch information
1 parent
b0faf52
commit ce79609
Showing
5 changed files
with
113 additions
and
187 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import builtins | ||
import logging | ||
import os | ||
import sys | ||
import traceback | ||
|
||
print_fn = builtins.print | ||
|
||
|
||
class TaskScriptRunner: | ||
logger = logging.getLogger(__name__) | ||
|
||
def __init__(self, script_path: str, script_args: str = None): | ||
"""Wrapper for function given function path and args | ||
Args: | ||
script_path (str): script file name, such as train.py | ||
script_args (str, Optional): script arguments to pass in. | ||
""" | ||
self.script_args = script_args | ||
self.client_api = None | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
self.script_path = self.get_script_full_path(script_path) | ||
|
||
def run(self): | ||
"""Call the task_fn with any required arguments.""" | ||
self.logger.info(f"\n start task run() with {self.script_path}") | ||
try: | ||
import runpy | ||
|
||
curr_argv = sys.argv | ||
builtins.print = log_print | ||
sys.argv = self.get_sys_argv() | ||
runpy.run_path(self.script_path, run_name="__main__") | ||
sys.argv = curr_argv | ||
|
||
except Exception as e: | ||
msg = traceback.format_exc() | ||
self.logger.error(msg) | ||
if self.client_api: | ||
self.client_api.exec_queue.ask_abort(msg) | ||
raise e | ||
finally: | ||
builtins.print = print_fn | ||
|
||
def get_sys_argv(self): | ||
args_list = [] if not self.script_args else self.script_args.split() | ||
return [self.script_path] + args_list | ||
|
||
def get_script_full_path(self, script_path) -> str: | ||
target_files = None | ||
for r, dirs, files in os.walk(os.getcwd()): | ||
target_files = [os.path.join(r, f) for f in files if f == script_path] | ||
if target_files: | ||
break | ||
if not target_files: | ||
raise ValueError(f"{script_path} is not found") | ||
return target_files[0] | ||
|
||
|
||
def log_print(*args, logger=TaskScriptRunner.logger, **kwargs): | ||
# Create a message from print arguments | ||
message = " ".join(str(arg) for arg in args) | ||
logger.info(message) |
100 changes: 0 additions & 100 deletions
100
tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py
This file was deleted.
Oops, something went wrong.
28 changes: 28 additions & 0 deletions
28
tests/unit_test/app_common/executors/task_script_runner_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import unittest | ||
|
||
from nvflare.app_common.executors.task_script_runner import TaskScriptRunner | ||
|
||
|
||
class TestExecTaskFuncWrapper(unittest.TestCase): | ||
def test_app_scripts_and_args(self): | ||
curr_dir = os.getcwd() | ||
script_path = "cli.py" | ||
script_args = "--batch_size 4" | ||
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args) | ||
|
||
self.assertTrue(wrapper.script_path.endswith(script_path)) | ||
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"]) |