diff --git a/metaflow/metaflow_runner.py b/metaflow/metaflow_runner.py index 0717d9c669d..b522c812e6c 100644 --- a/metaflow/metaflow_runner.py +++ b/metaflow/metaflow_runner.py @@ -109,7 +109,8 @@ def __init__( if profile: self.env_vars["METAFLOW_PROFILE"] = profile self.spm = SubprocessManager() - self.api = MetaflowAPI.from_cli(self.flow_file, start)(**kwargs) + self.top_level_kwargs = kwargs + self.api = MetaflowAPI.from_cli(self.flow_file, start) def __enter__(self): return self @@ -117,53 +118,73 @@ def __enter__(self): async def __aenter__(self): return self + def __get_executing_run(self, tfp_pathspec, command_obj): + try: + pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5) + run_object = Run(pathspec, _namespace_check=False) + return ExecutingRun(self, command_obj, run_object) + except TimeoutError as e: + stdout_log = open(command_obj.log_files["stdout"]).read() + stderr_log = open(command_obj.log_files["stderr"]).read() + command = " ".join(command_obj.command) + error_message = "Error executing: '%s':\n" % command + if stdout_log.strip(): + error_message += "\nStdout:\n%s\n" % stdout_log + if stderr_log.strip(): + error_message += "\nStderr:\n%s\n" % stderr_log + raise RuntimeError(error_message) from e + def run(self, **kwargs): with tempfile.TemporaryDirectory() as temp_dir: tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) - command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs) + command = self.api(**self.top_level_kwargs).run( + pathspec_file=tfp_pathspec.name, **kwargs + ) pid = self.spm.run_command([sys.executable, *command], env=self.env_vars) command_obj = self.spm.get(pid) - try: - pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5) - run_object = Run(pathspec, _namespace_check=False) - return ExecutingRun(self, command_obj, run_object) - except TimeoutError as e: - stdout_log = open(command_obj.log_files["stdout"]).read() - stderr_log = open(command_obj.log_files["stderr"]).read() - command = " ".join(command_obj.command) - error_message = "Error executing: '%s':\n" % command - if stdout_log.strip(): - error_message += "\nStdout:\n%s\n" % stdout_log - if stderr_log.strip(): - error_message += "\nStderr:\n%s\n" % stderr_log - raise RuntimeError(error_message) from e + return self.__get_executing_run(tfp_pathspec, command_obj) + + def resume(self, **kwargs): + with tempfile.TemporaryDirectory() as temp_dir: + tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) + command = self.api(**self.top_level_kwargs).resume( + pathspec_file=tfp_pathspec.name, **kwargs + ) + + pid = self.spm.run_command([sys.executable, *command], env=self.env_vars) + command_obj = self.spm.get(pid) + + return self.__get_executing_run(tfp_pathspec, command_obj) async def async_run(self, **kwargs): with tempfile.TemporaryDirectory() as temp_dir: tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) - command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs) + command = self.api(**self.top_level_kwargs).run( + pathspec_file=tfp_pathspec.name, **kwargs + ) + + pid = await self.spm.async_run_command( + [sys.executable, *command], env=self.env_vars + ) + command_obj = self.spm.get(pid) + + return self.__get_executing_run(tfp_pathspec, command_obj) + + async def async_resume(self, **kwargs): + with tempfile.TemporaryDirectory() as temp_dir: + tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) + command = self.api(**self.top_level_kwargs).resume( + pathspec_file=tfp_pathspec.name, **kwargs + ) pid = await self.spm.async_run_command( [sys.executable, *command], env=self.env_vars ) command_obj = self.spm.get(pid) - try: - pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5) - run_object = Run(pathspec, _namespace_check=False) - return ExecutingRun(self, command_obj, run_object) - except TimeoutError as e: - stdout_log = open(command_obj.log_files["stdout"]).read() - stderr_log = open(command_obj.log_files["stderr"]).read() - command = " ".join(command_obj.command) - error_message = "Error executing: '%s':\n" % command - if stdout_log.strip(): - error_message += "\nStdout:\n%s\n" % stdout_log - if stderr_log.strip(): - error_message += "\nStderr:\n%s\n" % stderr_log - raise RuntimeError(error_message) from e + return self.__get_executing_run(tfp_pathspec, command_obj) def __exit__(self, exc_type, exc_value, traceback): self.spm.cleanup()