Skip to content

Commit

Permalink
Load assignments when safeloading a namespace (#13781)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Jun 4, 2024
1 parent a23a2a6 commit 9c584e8
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/prefect/utilities/importtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def safe_load_namespace(source_code: str):
"""
parsed_code = ast.parse(source_code)

namespace = {}
namespace = {"__name__": "prefect_safe_namespace_loader"}

# Walk through the AST and find all import statements
for node in ast.walk(parsed_code):
Expand Down Expand Up @@ -413,17 +413,17 @@ def safe_load_namespace(source_code: str):
except ImportError as e:
logger.debug("Failed to import from %s: %s", node.module, e)

# Handle local class definitions
# Handle local definitions
for node in ast.walk(parsed_code):
if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.Assign)):
try:
# Compile and execute each class and function definition locally
# Compile and execute each class and function definition and assignment
code = compile(
ast.Module(body=[node], type_ignores=[]),
filename="<ast>",
mode="exec",
)
exec(code, namespace)
except Exception as e:
logger.debug("Failed to compile class definition: %s", e)
logger.debug("Failed to compile: %s", e)
return namespace
56 changes: 56 additions & 0 deletions tests/utilities/test_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,62 @@ def f(x):
"required": ["x"],
}

def test_handles_dynamically_created_models(self, tmp_path: Path):
source_code = dedent(
"""
from pydantic import BaseModel, create_model, Field
def get_model() -> BaseModel:
return create_model(
"MyModel",
param=(
int,
Field(
title="param",
default=1,
),
),
)
MyModel = get_model()
def f(
param: MyModel,
) -> None:
pass
"""
)
tmp_path.joinpath("test.py").write_text(source_code)
schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f")
assert schema.dict() == {
"title": "Parameters",
"type": "object",
"properties": {
"param": {
"allOf": [{"$ref": "#/definitions/MyModel"}],
"position": 0,
"title": "param",
}
},
"required": ["param"],
"definitions": {
"MyModel": {
"properties": {
"param": {
"default": 1,
"title": "param",
"type": "integer",
}
},
"title": "MyModel",
"type": "object",
}
},
}

def test_function_with_kwargs_only(self, tmp_path: Path):
source_code = dedent(
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/utilities/test_importtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ def my_fn():
assert "math" in namespace
assert "datetime" in namespace
assert "BaseModel" in namespace
# module-level variables should not be present
assert "x" not in namespace
assert "y" not in namespace
assert "now" not in namespace
# module-level variables should be present
assert "x" in namespace
assert "y" in namespace
assert "now" in namespace
# module-level classes should be present
assert "MyModel" in namespace
# module-level functions should be present
Expand Down

0 comments on commit 9c584e8

Please sign in to comment.