Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/gt4py/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,16 @@ def resolve_external_symbols(
resolved_imports = {**imported}
resolved_values_list = list(nonlocals.items())

# Resolve function-like imports
func_externals = {
key: value for key, value in context.items() if isinstance(value, types.FunctionType)
}
for name, value in func_externals.items():
if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"):
GTScriptParser.annotate_definition(value)
for imported_name, imported_value in value._gtscript_["imported"].items():
resolved_imports[imported_name] = imported_value

# Collect all imported and inlined values recursively through all the external symbols
while resolved_imports or resolved_values_list:
new_imports = {}
Expand Down Expand Up @@ -1670,7 +1680,7 @@ class GTScriptFrontend(gt_frontend.Frontend):

@classmethod
def get_stencil_id(cls, qualified_name, definition, externals, options_id):
cls.prepare_stencil_definition(definition, externals)
cls.prepare_stencil_definition(definition, externals or {})
fingerprint = {
"__main__": definition._gtscript_["canonical_ast"],
"docstring": inspect.getdoc(definition),
Expand Down
58 changes: 58 additions & 0 deletions tests/test_unittest/test_gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,64 @@ def definition_func(inout_field: gtscript.Field[float]):
stmt = def_ir.computations[0].body.stmts[0]
assert isinstance(stmt.value, gt_ir.ScalarLiteral) and stmt.value.value == 1

def test_function_import(self, id_version):
module = f"TestInlinedExternals_test_recursive_imports_{id_version}"

def some_function():
from __externals__ import const

return const

def definition_func(inout_field: gtscript.Field[float]):
from __externals__ import some_call

with computation(PARALLEL), interval(...):
inout_field = some_call()

stencil_id, def_ir = compile_definition(
definition_func,
"test_recursive_imports",
module,
externals={"some_call": some_function, "const": GLOBAL_CONSTANT},
)
assert set(def_ir.externals.keys()) == {
"some_call",
"const",
}
Comment on lines +178 to +201
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reasons that still elude me, this test is necessary. I thought this was covered below, but this uncovered the need for the additional annotate call at gtscript_frontend.py:1477.


def test_recursive_imports(self, id_version):
module = f"TestInlinedExternals_test_recursive_imports_{id_version}"

def func_nest2():
from __externals__ import const

return const

def func_nest1():
from __externals__ import other

return other() + func_nest2()

def definition_func(inout_field: gtscript.Field[float]):
from __externals__ import some_call

with computation(PARALLEL), interval(...):
inout_field = func_nest1() + some_call()

stencil_id, def_ir = compile_definition(
definition_func,
"test_recursive_imports",
module,
externals={"some_call": func_nest1, "other": func_nest2, "const": GLOBAL_CONSTANT},
)
assert set(def_ir.externals.keys()) == {
"some_call",
"const",
"other",
"func_nest1",
"tests.test_unittest.test_gtscript_frontend.func_nest1.func_nest2",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Externals captured implicitly get the fully qualified name.

}

def test_decorated_freeze(self):
A = 0

Expand Down