Skip to content

Commit

Permalink
resume and async resume
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed May 23, 2024
1 parent 4d6460a commit c8b9ab2
Showing 1 changed file with 101 additions and 42 deletions.
143 changes: 101 additions & 42 deletions metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def __init__(
if profile:
self.env_vars["METAFLOW_PROFILE"] = profile
self.spm = SubprocessManager()
self.top_level_kwargs = kwargs
self.api = MetaflowAPI.from_cli(self.flow_file, start)

def __enter__(self) -> "Runner":
Expand All @@ -216,19 +217,35 @@ def __enter__(self) -> "Runner":
async def __aenter__(self) -> "Runner":
return self

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()

async def __aexit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()

def run(self, **kwargs) -> ExecutingRun:
def __get_executing_run(self, tfp_pathspec, command_obj):
    """
    Build the ExecutingRun for a command that has just been launched.

    Parameters
    ----------
    tfp_pathspec : file-like object exposing a `name` attribute
        Temporary file from which the pathspec of the run is read.
    command_obj : object
        Handle of the launched command. Its `log_files` mapping (keys
        'stdout' and 'stderr') and its `command` list are used to build
        the error report.

    Returns
    -------
    ExecutingRun
        Wraps this runner, `command_obj` and the `Run` built from the
        pathspec that was read.

    Raises
    ------
    RuntimeError
        If reading the pathspec raises TimeoutError (the read is attempted
        with `timeout=5`). The message contains the command line followed
        by the non-empty stdout/stderr logs; the TimeoutError is chained
        as the cause.
    """
    try:
        pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
        run_object = Run(pathspec, _namespace_check=False)
        return ExecutingRun(self, command_obj, run_object)
    except TimeoutError as e:
        # Read the logs through context managers so the handles are closed
        # deterministically, and decode as UTF-8 rather than relying on the
        # platform's locale encoding.
        with open(command_obj.log_files["stdout"], encoding="utf-8") as log:
            stdout_log = log.read()
        with open(command_obj.log_files["stderr"], encoding="utf-8") as log:
            stderr_log = log.read()

        command = " ".join(command_obj.command)
        error_message = "Error executing: '%s':\n" % command
        # Only include a log section when it has visible content.
        if stdout_log.strip():
            error_message += "\nStdout:\n%s\n" % stdout_log
        if stderr_log.strip():
            error_message += "\nStderr:\n%s\n" % stderr_log
        raise RuntimeError(error_message) from e

def run(self, show_output: bool = False, **kwargs) -> ExecutingRun:
"""
Synchronous execution of the run. This method will *block* until
the run has completed execution.
Parameters
----------
show_output : bool, default False
Suppress the 'stdout' and 'stderr' to the console by default.
They can be accessed later by reading the files present in the
ExecutingRun object (referenced as 'result' below) returned:
- result.stdout
- result.stderr
**kwargs : Any
Additional arguments that you would pass to `python ./myflow.py` after
the `run` command.
Expand All @@ -240,25 +257,51 @@ def run(self, **kwargs) -> ExecutingRun:
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
command = self.api(**self.top_level_kwargs).run(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
pid = self.spm.run_command(
[sys.executable, *command], env=self.env_vars, show_output=show_output
)
command_obj = self.spm.get(pid)

try:
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)
except TimeoutError as e:
stdout_log = open(command_obj.log_files["stdout"]).read()
stderr_log = open(command_obj.log_files["stderr"]).read()
command = " ".join(command_obj.command)
error_message = "Error executing: '%s':\n" % command
if stdout_log.strip():
error_message += "\nStdout:\n%s\n" % stdout_log
if stderr_log.strip():
error_message += "\nStderr:\n%s\n" % stderr_log
raise RuntimeError(error_message) from e
return self.__get_executing_run(tfp_pathspec, command_obj)

def resume(self, show_output: bool = False, **kwargs) -> "ExecutingRun":
    """
    Synchronous resume execution of the run.
    This method will *block* until the resumed run has completed execution.

    Parameters
    ----------
    show_output : bool, default False
        If True, show 'stdout' and 'stderr' on the console.
        By default they are suppressed; they can be accessed later by
        reading the files present in the ExecutingRun object
        (referenced as 'result' below) returned:
            - result.stdout
            - result.stderr
    **kwargs : Any
        Additional arguments that you would pass to `python ./myflow.py` after
        the `resume` command.

    Returns
    -------
    ExecutingRun
        ExecutingRun object for this resumed run.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
        # Only the path of this file is used from here on. Close our handle
        # right away so it is not left open while the command runs or when
        # the temporary directory is removed; delete=False keeps the file
        # on disk and `tfp_pathspec.name` stays valid after closing.
        tfp_pathspec.close()

        command = self.api(**self.top_level_kwargs).resume(
            pathspec_file=tfp_pathspec.name, **kwargs
        )

        pid = self.spm.run_command(
            [sys.executable, *command], env=self.env_vars, show_output=show_output
        )
        command_obj = self.spm.get(pid)

        # Must happen inside the `with` block: the pathspec file lives in
        # the temporary directory.
        return self.__get_executing_run(tfp_pathspec, command_obj)

async def async_run(self, **kwargs) -> ExecutingRun:
"""
Expand All @@ -278,32 +321,48 @@ async def async_run(self, **kwargs) -> ExecutingRun:
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
command = self.api(**self.top_level_kwargs).run(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
)
command_obj = self.spm.get(pid)

try:
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)
except TimeoutError as e:
stdout_log = open(
command_obj.log_files["stdout"], encoding="utf-8"
).read()
stderr_log = open(
command_obj.log_files["stderr"], encoding="utf-8"
).read()
command = " ".join(command_obj.command)
return self.__get_executing_run(tfp_pathspec, command_obj)

error_message = "Error executing: '%s':\n" % command
async def async_resume(self, **kwargs):
"""
Asynchronous resume execution of the run.
This method will return as soon as the resume has launched.
if stdout_log.strip():
error_message += "\nStdout:\n%s\n" % stdout_log
Parameters
----------
**kwargs : Any
Additional arguments that you would pass to `python ./myflow.py` after
the `resume` command.
if stderr_log.strip():
error_message += "\nStderr:\n%s\n" % stderr_log
Returns
-------
ExecutingRun
ExecutingRun object for this resumed run.
"""
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api(**self.top_level_kwargs).resume(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
)
command_obj = self.spm.get(pid)

raise RuntimeError(error_message) from e
return self.__get_executing_run(tfp_pathspec, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
    """
    Leave the synchronous `with` block.

    Delegates to `self.spm.cleanup()`. Nothing is returned, so an exception
    raised inside the `with` block is not suppressed.
    """
    subprocess_manager = self.spm
    subprocess_manager.cleanup()

async def __aexit__(self, exc_type, exc_value, traceback):
    """
    Leave the `async with` block.

    Delegates to `self.spm.cleanup()` (called without awaiting). Nothing is
    returned, so an exception raised inside the block is not suppressed.
    """
    subprocess_manager = self.spm
    subprocess_manager.cleanup()

0 comments on commit c8b9ab2

Please sign in to comment.