diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index 3d77934c30..d90a3d53a3 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1468,6 +1468,16 @@ def resolve_external_symbols( resolved_imports = {**imported} resolved_values_list = list(nonlocals.items()) + # Resolve function-like imports + func_externals = { + key: value for key, value in context.items() if isinstance(value, types.FunctionType) + } + for name, value in func_externals.items(): + if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): + GTScriptParser.annotate_definition(value) + for imported_name, imported_value in value._gtscript_["imported"].items(): + resolved_imports[imported_name] = imported_value + # Collect all imported and inlined values recursively through all the external symbols while resolved_imports or resolved_values_list: new_imports = {} @@ -1670,7 +1680,7 @@ class GTScriptFrontend(gt_frontend.Frontend): @classmethod def get_stencil_id(cls, qualified_name, definition, externals, options_id): - cls.prepare_stencil_definition(definition, externals) + cls.prepare_stencil_definition(definition, externals or {}) fingerprint = { "__main__": definition._gtscript_["canonical_ast"], "docstring": inspect.getdoc(definition), diff --git a/tests/test_unittest/test_gtscript_frontend.py b/tests/test_unittest/test_gtscript_frontend.py index cdacef0100..3d7b93348b 100644 --- a/tests/test_unittest/test_gtscript_frontend.py +++ b/tests/test_unittest/test_gtscript_frontend.py @@ -175,6 +175,64 @@ def definition_func(inout_field: gtscript.Field[float]): stmt = def_ir.computations[0].body.stmts[0] assert isinstance(stmt.value, gt_ir.ScalarLiteral) and stmt.value.value == 1 + def test_function_import(self, id_version): + module = f"TestInlinedExternals_test_recursive_imports_{id_version}" + + def some_function(): + from __externals__ import const + + return const + + def definition_func(inout_field: gtscript.Field[float]): + from __externals__ import some_call + + with computation(PARALLEL), interval(...): + inout_field = some_call() + + stencil_id, def_ir = compile_definition( + definition_func, + "test_recursive_imports", + module, + externals={"some_call": some_function, "const": GLOBAL_CONSTANT}, + ) + assert set(def_ir.externals.keys()) == { + "some_call", + "const", + } + + def test_recursive_imports(self, id_version): + module = f"TestInlinedExternals_test_recursive_imports_{id_version}" + + def func_nest2(): + from __externals__ import const + + return const + + def func_nest1(): + from __externals__ import other + + return other() + func_nest2() + + def definition_func(inout_field: gtscript.Field[float]): + from __externals__ import some_call + + with computation(PARALLEL), interval(...): + inout_field = func_nest1() + some_call() + + stencil_id, def_ir = compile_definition( + definition_func, + "test_recursive_imports", + module, + externals={"some_call": func_nest1, "other": func_nest2, "const": GLOBAL_CONSTANT}, + ) + assert set(def_ir.externals.keys()) == { + "some_call", + "const", + "other", + "func_nest1", + "tests.test_unittest.test_gtscript_frontend.func_nest1.func_nest2", + } + def test_decorated_freeze(self): A = 0