From c37c7a35147fb46a8539644831ae9e222dfa1031 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Mon, 30 Nov 2020 21:46:50 -0800 Subject: [PATCH 1/7] Support external imports in functions --- src/gt4py/frontend/gtscript_frontend.py | 29 +++++++++++++++++-- tests/test_unittest/test_gtscript_frontend.py | 26 +++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index 3d77934c30..175d3ad30b 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1441,8 +1441,6 @@ def collect_external_symbols(definition): def eval_external(name: str, context: dict, loc=None): try: value = eval(name, context) - if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): - GTScriptParser.annotate_definition(value) assert ( value is None @@ -1468,6 +1466,33 @@ def resolve_external_symbols( resolved_imports = {**imported} resolved_values_list = list(nonlocals.items()) + # Resolve function-like imports recursively + func_externals = dict( + filter( + lambda name_value: isinstance(name_value[1], types.FunctionType), context.items() + ) + ) + while func_externals: + new_func_externals = {} + for name, value in func_externals.items(): + # Annotate if not already done + if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): + GTScriptParser.annotate_definition(value) + + # Get sub-context + sub_context, sub_unbound = gt_meta.get_closure( + value, included_nonlocals=True, include_builtins=False + ) + + # Resolve imports or add to list to recursively resolve + for imported_name, imported_value in value._gtscript_["imported"].items(): + if isinstance(imported_value, types.FunctionType): + add_func_externals[imported_name] = sub_context[imported_value] + else: + resolved_imports[imported_name] = imported_value + + func_externals = new_func_externals + # Collect all imported and inlined values recursively through all the external symbols while resolved_imports or resolved_values_list: new_imports = {} diff --git a/tests/test_unittest/test_gtscript_frontend.py b/tests/test_unittest/test_gtscript_frontend.py index cdacef0100..710e57a11a 100644 --- a/tests/test_unittest/test_gtscript_frontend.py +++ b/tests/test_unittest/test_gtscript_frontend.py @@ -175,6 +175,32 @@ def definition_func(inout_field: gtscript.Field[float]): stmt = def_ir.computations[0].body.stmts[0] assert isinstance(stmt.value, gt_ir.ScalarLiteral) and stmt.value.value == 1 + def test_recursive_imports(self, id_version): + module = f"TestInlinedExternals_test_recursive_imports_{id_version}" + + def func_nest2(): + from __externals__ import const + + return const + + def func_nest1(): + from __externals__ import other + + return other() + + def definition_func(inout_field: gtscript.Field[float]): + from __externals__ import func + + with computation(PARALLEL), interval(...): + inout_field = func() + + compile_definition( + definition_func, + "test_recursive_imports", + module, + externals={"func": func_nest1, "other": func_nest2, "const": GLOBAL_CONSTANT}, + ) + def test_decorated_freeze(self): A = 0 From 37abb5c907c9a16d6361cf405ec160e57a0a0e60 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Mon, 30 Nov 2020 22:18:45 -0800 Subject: [PATCH 2/7] Pass empty dict if no externals --- src/gt4py/frontend/gtscript_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index 175d3ad30b..eed65e06a1 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1695,7 +1695,7 @@ class GTScriptFrontend(gt_frontend.Frontend): @classmethod def get_stencil_id(cls, qualified_name, definition, externals, options_id): - cls.prepare_stencil_definition(definition, externals) + cls.prepare_stencil_definition(definition, externals or {}) fingerprint = { "__main__": definition._gtscript_["canonical_ast"], "docstring": inspect.getdoc(definition), From 8feffb1da07f51822978d2ab35499a927ef0cda5 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Tue, 1 Dec 2020 08:34:02 -0800 Subject: [PATCH 3/7] Apply suggestions from code review Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/frontend/gtscript_frontend.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index eed65e06a1..cc6377042a 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1467,11 +1467,7 @@ def resolve_external_symbols( resolved_values_list = list(nonlocals.items()) # Resolve function-like imports recursively - func_externals = dict( - filter( - lambda name_value: isinstance(name_value[1], types.FunctionType), context.items() - ) - ) + func_externals = {key: value for key, value in context.items() if isinstance(value, types.FunctionType)} while func_externals: new_func_externals = {} for name, value in func_externals.items(): @@ -1487,7 +1483,7 @@ def resolve_external_symbols( # Resolve imports or add to list to recursively resolve for imported_name, imported_value in value._gtscript_["imported"].items(): if isinstance(imported_value, types.FunctionType): - add_func_externals[imported_name] = sub_context[imported_value] + new_func_externals[imported_name] = sub_context[imported_value] else: resolved_imports[imported_name] = imported_value From 563d8cbde2db072c2f3a3a321474cc7172089584 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Tue, 1 Dec 2020 09:07:14 -0800 Subject: [PATCH 4/7] Formatted --- src/gt4py/frontend/gtscript_frontend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index cc6377042a..3bd84de5c0 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1467,7 +1467,9 @@ def resolve_external_symbols( resolved_values_list = list(nonlocals.items()) # Resolve function-like imports recursively - func_externals = {key: value for key, value in context.items() if isinstance(value, types.FunctionType)} + func_externals = { + key: value for key, value in context.items() if isinstance(value, types.FunctionType) + } while func_externals: new_func_externals = {} for name, value in func_externals.items(): From 91b98d8147d886d511eeccd015fb2947327c23b8 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Tue, 1 Dec 2020 11:04:02 -0800 Subject: [PATCH 5/7] Fix recursive import --- src/gt4py/frontend/gtscript_frontend.py | 30 ++++++------------- tests/test_unittest/test_gtscript_frontend.py | 2 +- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index 3bd84de5c0..dedadda47e 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1466,30 +1466,18 @@ def resolve_external_symbols( resolved_imports = {**imported} resolved_values_list = list(nonlocals.items()) - # Resolve function-like imports recursively + # Resolve function-like imports func_externals = { key: value for key, value in context.items() if isinstance(value, types.FunctionType) } - while func_externals: - new_func_externals = {} - for name, value in func_externals.items(): - # Annotate if not already done - if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): - GTScriptParser.annotate_definition(value) - - # Get sub-context - sub_context, sub_unbound = gt_meta.get_closure( - value, included_nonlocals=True, include_builtins=False - ) - - # Resolve imports or add to list to recursively resolve - for imported_name, imported_value in value._gtscript_["imported"].items(): - if isinstance(imported_value, types.FunctionType): - new_func_externals[imported_name] = sub_context[imported_value] - else: - resolved_imports[imported_name] = imported_value - - func_externals = new_func_externals + for name, value in func_externals.items(): + # Annotate if not already done + if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): + GTScriptParser.annotate_definition(value) + + # Resolve import + for imported_name, imported_value in value._gtscript_["imported"].items(): + resolved_imports[imported_name] = imported_value # Collect all imported and inlined values recursively through all the external symbols while resolved_imports or resolved_values_list: diff --git a/tests/test_unittest/test_gtscript_frontend.py b/tests/test_unittest/test_gtscript_frontend.py index 710e57a11a..413f4293ba 100644 --- a/tests/test_unittest/test_gtscript_frontend.py +++ b/tests/test_unittest/test_gtscript_frontend.py @@ -186,7 +186,7 @@ def func_nest2(): def func_nest1(): from __externals__ import other - return other() + return other() + func_nest2() def definition_func(inout_field: gtscript.Field[float]): from __externals__ import func From 48556129034cdb59e5b5646f8a80bbfaf2cd2412 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Tue, 1 Dec 2020 11:49:57 -0800 Subject: [PATCH 6/7] Move back annotation --- src/gt4py/frontend/gtscript_frontend.py | 6 ++---- tests/test_unittest/test_gtscript_frontend.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index dedadda47e..6d29b73cf6 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1441,6 +1441,8 @@ def collect_external_symbols(definition): def eval_external(name: str, context: dict, loc=None): try: value = eval(name, context) + if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): + GTScriptParser.annotate_definition(value) assert ( value is None @@ -1471,10 +1473,6 @@ def resolve_external_symbols( key: value for key, value in context.items() if isinstance(value, types.FunctionType) } for name, value in func_externals.items(): - # Annotate if not already done - if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): - GTScriptParser.annotate_definition(value) - # Resolve import for imported_name, imported_value in value._gtscript_["imported"].items(): resolved_imports[imported_name] = imported_value diff --git a/tests/test_unittest/test_gtscript_frontend.py b/tests/test_unittest/test_gtscript_frontend.py index 413f4293ba..f5bc5e5537 100644 --- a/tests/test_unittest/test_gtscript_frontend.py +++ b/tests/test_unittest/test_gtscript_frontend.py @@ -189,17 +189,20 @@ def func_nest1(): return other() + func_nest2() def definition_func(inout_field: gtscript.Field[float]): - from __externals__ import func - with computation(PARALLEL), interval(...): - inout_field = func() + inout_field = func_nest1() - compile_definition( + stencil_id, def_ir = compile_definition( definition_func, "test_recursive_imports", module, - externals={"func": func_nest1, "other": func_nest2, "const": GLOBAL_CONSTANT}, + externals={"other": func_nest2, "const": GLOBAL_CONSTANT}, ) + assert set(def_ir.externals.keys()) == { + "const", + "func_nest1", + "tests.test_unittest.test_gtscript_frontend.func_nest1.func_nest2", + } def test_decorated_freeze(self): A = 0 From c69e226185fc1bcff9a0a667ea0904ee5cbc7d92 Mon Sep 17 00:00:00 2001 From: Johann Dahm Date: Tue, 1 Dec 2020 12:21:42 -0800 Subject: [PATCH 7/7] Add another test for corner case --- src/gt4py/frontend/gtscript_frontend.py | 3 +- tests/test_unittest/test_gtscript_frontend.py | 33 +++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index 6d29b73cf6..d90a3d53a3 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -1473,7 +1473,8 @@ def resolve_external_symbols( key: value for key, value in context.items() if isinstance(value, types.FunctionType) } for name, value in func_externals.items(): - # Resolve import + if isinstance(value, types.FunctionType) and not hasattr(value, "_gtscript_"): + GTScriptParser.annotate_definition(value) for imported_name, imported_value in value._gtscript_["imported"].items(): resolved_imports[imported_name] = imported_value diff --git a/tests/test_unittest/test_gtscript_frontend.py b/tests/test_unittest/test_gtscript_frontend.py index f5bc5e5537..3d7b93348b 100644 --- a/tests/test_unittest/test_gtscript_frontend.py +++ b/tests/test_unittest/test_gtscript_frontend.py @@ -175,6 +175,31 @@ def definition_func(inout_field: gtscript.Field[float]): stmt = def_ir.computations[0].body.stmts[0] assert isinstance(stmt.value, gt_ir.ScalarLiteral) and stmt.value.value == 1 + def test_function_import(self, id_version): + module = f"TestInlinedExternals_test_recursive_imports_{id_version}" + + def some_function(): + from __externals__ import const + + return const + + def definition_func(inout_field: gtscript.Field[float]): + from __externals__ import some_call + + with computation(PARALLEL), interval(...): + inout_field = some_call() + + stencil_id, def_ir = compile_definition( + definition_func, + "test_recursive_imports", + module, + externals={"some_call": some_function, "const": GLOBAL_CONSTANT}, + ) + assert set(def_ir.externals.keys()) == { + "some_call", + "const", + } + def test_recursive_imports(self, id_version): module = f"TestInlinedExternals_test_recursive_imports_{id_version}" @@ -189,17 +214,21 @@ def func_nest1(): return other() + func_nest2() def definition_func(inout_field: gtscript.Field[float]): + from __externals__ import some_call + with computation(PARALLEL), interval(...): - inout_field = func_nest1() + inout_field = func_nest1() + some_call() stencil_id, def_ir = compile_definition( definition_func, "test_recursive_imports", module, - externals={"other": func_nest2, "const": GLOBAL_CONSTANT}, + externals={"some_call": func_nest1, "other": func_nest2, "const": GLOBAL_CONSTANT}, ) assert set(def_ir.externals.keys()) == { + "some_call", "const", + "other", "func_nest1", "tests.test_unittest.test_gtscript_frontend.func_nest1.func_nest2", }