Skip to content

Commit

Permalink
synchronous run with logs
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed May 21, 2024
1 parent 1766839 commit ab53fad
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 18 deletions.
37 changes: 23 additions & 14 deletions metaflow/metaflow_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import sys
import time
import asyncio
import tempfile
from typing import Dict, Optional
from metaflow import Run
Expand Down Expand Up @@ -110,8 +109,7 @@ def __init__(
if profile:
self.env_vars["METAFLOW_PROFILE"] = profile
self.spm = SubprocessManager()
self.api = MetaflowAPI.from_cli(self.flow_file, start)
self.runner = self.api(**kwargs).run
self.api = MetaflowAPI.from_cli(self.flow_file, start)(**kwargs)

def __enter__(self):
return self
Expand All @@ -120,23 +118,34 @@ async def __aenter__(self):
return self

def run(self, **kwargs):
    """Launch the flow's ``run`` command synchronously and return a handle to it.

    A temporary file path is handed to the command via ``pathspec_file``;
    once a pathspec can be read back from that file, a ``Run`` object is
    built from it and wrapped together with the command object.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``self.api.run``.

    Returns
    -------
    ExecutingRun
        Wraps this runner, the command object and the ``Run`` object.

    Raises
    ------
    RuntimeError
        If the pathspec file is not ready within the timeout. The message
        carries the command line plus any captured stdout/stderr.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
        # Only the path is used from here on (presumably the launched command
        # writes the pathspec into it), so release our own handle right away
        # instead of leaking the descriptor. delete=False keeps the file.
        tfp_pathspec.close()
        command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)

        pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
        command_obj = self.spm.get(pid)

        try:
            pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
            run_object = Run(pathspec, _namespace_check=False)
            return ExecutingRun(self, command_obj, run_object)
        except TimeoutError as e:
            # Read the captured logs with context managers so the file
            # handles are closed even though we are about to raise.
            with open(command_obj.log_files["stdout"]) as stdout_file:
                stdout_log = stdout_file.read()
            with open(command_obj.log_files["stderr"]) as stderr_file:
                stderr_log = stderr_file.read()
            command_string = " ".join(command_obj.command)
            error_message = "Error executing: '%s':\n" % command_string
            if stdout_log.strip():
                error_message += "\nStdout:\n%s\n" % stdout_log
            if stderr_log.strip():
                error_message += "\nStderr:\n%s\n" % stderr_log
            raise RuntimeError(error_message) from e

async def async_run(self, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)

command = self.runner(pathspec_file=tfp_pathspec.name, **kwargs)

pid = await self.spm.run_command(
pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
)
command_obj = self.spm.get(pid)
Expand Down
79 changes: 75 additions & 4 deletions metaflow/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import asyncio
import tempfile
import threading
import subprocess
from typing import List, Dict, Optional, Callable

Expand Down Expand Up @@ -41,7 +42,20 @@ async def __aenter__(self) -> "SubprocessManager":
async def __aexit__(self, exc_type, exc_value, traceback):
self.cleanup()

async def run_command(
def run_command(
    self,
    command: List[str],
    env: Optional[Dict[str, str]] = None,
    cwd: Optional[str] = None,
) -> int:
    """Execute ``command`` synchronously and register it under its process ID.

    Parameters
    ----------
    command : List[str]
        Command line to execute.
    env : Optional[Dict[str, str]]
        Environment passed to the command manager.
    cwd : Optional[str]
        Working directory passed to the command manager.

    Returns
    -------
    int
        The value returned by ``CommandManager.run()``, which is also the
        key under which the manager is stored in ``self.commands``.
    """
    manager = CommandManager(command, env, cwd)
    process_id = manager.run()
    self.commands[process_id] = manager
    return process_id

async def async_run_command(
self,
command: List[str],
env: Optional[Dict[str, str]] = None,
Expand All @@ -50,7 +64,7 @@ async def run_command(
"""Run a command asynchronously and return its process ID."""

command_obj = CommandManager(command, env, cwd)
pid = await command_obj.run()
pid = await command_obj.async_run()
self.commands[pid] = command_obj
return pid

Expand Down Expand Up @@ -121,7 +135,64 @@ async def wait(
% (self.process.pid, command_string, timeout)
)

async def run(self):
def run(self):
    """Run the subprocess synchronously, echoing its output and logging it to files.

    The child's stdout and stderr are each drained by a dedicated thread
    that writes every line to ``sys.stdout`` and to a log file inside a
    newly created temporary directory. The call blocks until the process
    has exited and both reader threads have finished.

    Returns
    -------
    int or None
        The PID of the process. ``None`` if this object has already been
        run, or if an exception occurred while starting or waiting for the
        process (a message is printed and ``self.cleanup()`` is called).
    """
    if self.run_called:
        command_string = " ".join(self.command)
        print(
            "Command '%s' has already been called. Please create another CommandManager object."
            % command_string
        )
        return None

    self.temp_dir = tempfile.mkdtemp()
    stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
    stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

    def stream_to_stdout_and_file(pipe, log_file):
        # Close the pipe even when opening or writing the log file fails.
        # Otherwise the pipe stays open but undrained, and the child can
        # block on a full pipe buffer while the parent waits in
        # process.wait().
        with pipe:
            with open(log_file, "w") as file:
                for line in iter(pipe.readline, ""):
                    sys.stdout.write(line)
                    file.write(line)

    try:
        self.process = subprocess.Popen(
            self.command,
            cwd=self.cwd,
            env=self.env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            bufsize=1,  # line-buffered; valid because the pipes are in text mode
            universal_newlines=True,
        )

        self.log_files["stdout"] = stdout_logfile
        self.log_files["stderr"] = stderr_logfile

        self.run_called = True

        # One reader per pipe so neither stream can fill up and stall the child.
        stdout_thread = threading.Thread(
            target=stream_to_stdout_and_file,
            args=(self.process.stdout, stdout_logfile),
        )
        stderr_thread = threading.Thread(
            target=stream_to_stdout_and_file,
            args=(self.process.stderr, stderr_logfile),
        )

        stdout_thread.start()
        stderr_thread.start()

        self.process.wait()

        # Wait for the readers so the log files are complete on return.
        stdout_thread.join()
        stderr_thread.join()

        return self.process.pid
    except Exception as e:
        print("Error starting subprocess: %s" % e)
        self.cleanup()
        return None

async def async_run(self):
"""Run the subprocess, streaming the logs to temporary files"""

if not self.run_called:
Expand Down Expand Up @@ -243,7 +314,7 @@ async def main():

async with SubprocessManager() as spm:
# returns immediately
pid = await spm.run_command(cmd)
pid = await spm.async_run_command(cmd)
command_obj = spm.get(pid)

print(pid)
Expand Down

0 comments on commit ab53fad

Please sign in to comment.