Skip to content

Commit

Permalink
[ENG-803] Bypass Python's import when loading code (#308)
Browse files Browse the repository at this point in the history
Change how we load user code to bypass Python's import machinery.
Previously, if the user code was in `os.py`, it would not load because
`os` is already in `sys.modules`. Now the code is loaded directly from
the source file.
  • Loading branch information
tebeka committed May 22, 2024
1 parent e0c4c72 commit 0d38415
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 100 deletions.
5 changes: 1 addition & 4 deletions runtimes/pythonrt/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ ci: deps test-py

test-py:
ruff check .
python -m pytest \
--doctest-modules \
--ignore testdata \
-v .
python -m pytest

# You can set the TESTOPTS to pass options to `go test`
test-go:
Expand Down
129 changes: 34 additions & 95 deletions runtimes/pythonrt/ak_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
import sys
import tarfile
from base64 import b64decode, b64encode
from functools import wraps
from importlib.abc import Loader
from importlib.machinery import SourceFileLoader
from os import mkdir
from pathlib import Path
from socket import AF_UNIX, SOCK_STREAM, socket
Expand Down Expand Up @@ -70,103 +67,29 @@ def visit_Call(self, node):
return call


class AKLoader(Loader):
    """Custom file loader that rewrites function calls to actions.

    Wraps an existing source loader: the module source is read from the
    wrapped loader's path, run through ``Transformer`` and executed in the
    module namespace with ``action`` bound under the ``ACTION_NAME`` attribute.
    """

    def __init__(self, src_loader, action):
        # Only the path and the module name of the wrapped loader are kept.
        self.file_name = src_loader.path
        self.module_name = src_loader.name
        self.action = action

    def create_module(self, spec):
        # Returning None asks the import machinery to create the default module.
        return None

    def exec_module(self, module):
        """Read, transform and execute the module source inside ``module``.

        Raises:
            ImportError: if the source file cannot be read. The underlying
                ``OSError`` is attached as ``__cause__``.
        """
        try:
            with open(self.file_name) as fp:
                src = fp.read()
        except OSError as err:
            # Chain the original error so the real I/O failure stays visible
            # in the traceback.
            raise ImportError(
                f'cannot read {self.module_name!r} - {err}') from err

        mod = ast.parse(src, self.file_name, 'exec')
        trans = Transformer(self.file_name)
        out = trans.visit(mod)
        # Nodes created by the transformer have no line/column information.
        ast.fix_missing_locations(out)

        code = compile(out, self.file_name, 'exec')
        # The action must be in the namespace before the rewritten code runs.
        setattr(module, ACTION_NAME, self.action)
        exec(code, module.__dict__)


# There is an established way to add import hooks, but we want to change the behavior of
# the current FileFinder hook found in sys.path_hooks so it'll call our code when importing
# from the user directory. This is why you'll see all these monkey patches below.

def patch_finder(finder, action):
    """Patch ``finder`` in place so source modules load through ``AKLoader``.

    ``finder.find_spec`` is replaced by a wrapper that delegates to the
    original method and, when the resulting spec uses a ``SourceFileLoader``,
    swaps that loader for an ``AKLoader`` carrying ``action``. Specs that are
    ``None`` or use any other loader are returned untouched.
    """
    _orig_find_spec = finder.find_spec

    def find_spec(fullname, target=None):
        spec = _orig_find_spec(fullname, target)
        if spec is None or not isinstance(spec.loader, SourceFileLoader):
            return spec

        log.info('patching loader for %r', fullname)
        spec.loader = AKLoader(spec.loader, action)
        return spec

    finder.find_spec = find_spec


def wrap_hook(hook, user_dir, action):
    """Wrap a path hook so finders for the user code directory get patched.

    Args:
        hook: the original path hook; called with a path, returns a finder.
        user_dir: ``pathlib.Path`` of the user code directory.
        action: callable forwarded to ``patch_finder``.

    Returns:
        A callable with the same metadata as ``hook`` (via ``functools.wraps``)
        that returns the finder produced by ``hook``, patched when applicable.
    """
    @wraps(hook)
    def wrapper(path):
        finder = hook(path)
        # True when `path` is the user directory itself or one of its
        # ancestors. NOTE(review): this also patches finders for parent
        # directories of user_dir - confirm that is intended.
        if user_dir.is_relative_to(path):
            patch_finder(finder, action)
        return finder

    return wrapper


def patch_import_hooks(user_dir, action_fn):
    """Patch the standard file-finder import hook in ``sys.path_hooks``.

    The first hook named ``path_hook_for_FileFinder`` is replaced in place by
    a wrapper created with ``wrap_hook``; only that first match is patched.

    Args:
        user_dir: directory holding the user code (str or path-like).
        action_fn: callable forwarded to ``wrap_hook``.

    Raises:
        RuntimeError: if no matching hook is present in ``sys.path_hooks``.
    """
    user_dir = Path(user_dir)
    for i, hook in enumerate(sys.path_hooks):
        # Not every callable carries __name__ (e.g. functools.partial), so
        # look it up defensively instead of failing with AttributeError.
        # NOTE(review): wrap_hook preserves __name__, so calling this function
        # twice wraps the already-wrapped hook again - confirm single use.
        if getattr(hook, '__name__', None) == 'path_hook_for_FileFinder':
            sys.path_hooks[i] = wrap_hook(hook, user_dir, action_fn)
            return

    raise RuntimeError(f'cannot find import hook to patch in {sys.path_hooks}')


# Name of the attribute that marks a function as an activity.
ACTIVITY_ATTR = '__activity__'


def activity(fn):
    """Decorator: flag ``fn`` as an activity and return it unchanged."""
    setattr(fn, ACTIVITY_ATTR, True)
    return fn


def ak_module():
    """Create a new in-memory module named ``ak`` exposing ``activity``."""
    mod = ModuleType('ak')
    mod.activity = activity
    return mod
def load_code(root_path, action_fn, module_name):
"""Load user code into a module, instrumenting function calls."""
log.info('importing %r', module_name)
file_name = Path(root_path) / (module_name + '.py')
with open(file_name) as fp:
src = fp.read()

tree = ast.parse(src, file_name, 'exec')
trans = Transformer(file_name)
patched_tree = trans.visit(tree)
ast.fix_missing_locations(patched_tree)

def load_code(root_path, action_fn, module_name):
# Make 'ak' module available for imports.
mod = ak_module()
sys.modules[mod.__name__] = mod
ak = ak_module()
sys.modules[ak.__name__] = ak

patch_import_hooks(root_path, action_fn)
code = compile(patched_tree, file_name, 'exec')

# Make sure user code is first in import path.
sys.path.insert(0, str(root_path))
module = ModuleType(module_name)
setattr(module, ACTION_NAME, action_fn)
exec(code, module.__dict__)

log.info('importing %r', module_name)
mod = __import__(module_name)
return mod
return module


def run_code(mod, entry_point, data):
Expand Down Expand Up @@ -261,6 +184,19 @@ def extract_response(self, message):
return pickle.loads(b64decode(data))


# Name of the attribute that marks a function as an activity.
ACTIVITY_ATTR = '__activity__'


def activity(fn):
    """Decorator: flag ``fn`` as an activity and return it unchanged."""
    setattr(fn, ACTIVITY_ATTR, True)
    return fn


def ak_module():
    """Create a new in-memory module named ``ak`` exposing ``activity``."""
    mod = ModuleType('ak')
    mod.activity = activity  # type: ignore
    return mod


class AKCall:
"""Callable wrapping functions with activities."""
def __init__(self, module_name, comm: Comm):
Expand All @@ -269,6 +205,9 @@ def __init__(self, module_name, comm: Comm):
self.comm = comm

def should_run_as_activity(self, fn):
if self.in_activity:
return False

if getattr(fn, ACTIVITY_ATTR, False):
return True

Expand All @@ -281,7 +220,7 @@ def should_run_as_activity(self, fn):
return True

def __call__(self, func, *args, **kw):
if self.in_activity or not self.should_run_as_activity(func):
if not self.should_run_as_activity(func):
log.info(
'calling %s (args=%r, kw=%r) directly (in_activity=%s)',
func.__name__, args, kw, self.in_activity)
Expand Down
19 changes: 19 additions & 0 deletions runtimes/pythonrt/ak_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ def __call__(self, fn, *args, **kw):
assert name == 'json.loads'


def test_load_twice(tmp_path):
    """Loading the same module name twice must pick up changes on disk."""
    mod_name = 'x'
    file_name = tmp_path / (mod_name + '.py')
    var, val = 'x', 1
    with open(file_name, 'w') as out:
        print(f'{var} = {val}', file=out)

    mod = ak_runner.load_code(tmp_path, lambda x: x, mod_name)
    assert getattr(mod, var) == val

    # See that module is not cached: rewrite the file with a new value and
    # check that the second load reflects it.
    val += 1
    with open(file_name, 'w') as out:
        print(f'{var} = {val}', file=out)

    mod = ak_runner.load_code(tmp_path, lambda x: x, mod_name)
    assert getattr(mod, var) == val


def test_cmdline_help():
py_file = str(test_dir / 'ak_runner.py')
cmd = [sys.executable, py_file, '-h']
Expand Down
3 changes: 3 additions & 0 deletions runtimes/pythonrt/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool.pytest.ini_options]

addopts = "-v --doctest-modules --ignore testdata"
2 changes: 1 addition & 1 deletion runtimes/pythonrt/testdata/mod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Test file for loader
# Test file for loader, see ../ak_runner_test.py::test_load_code

import json

Expand Down

0 comments on commit 0d38415

Please sign in to comment.