diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15b47dbc64..7eed05a841 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,7 @@ repos: - id: check-case-conflict # - id: check-json - id: check-merge-conflict + exclude_types: [rst] - id: check-toml - id: check-yaml - id: debug-statements diff --git a/src/gt4py/backend/gtc_backend/defir_to_gtir.py b/src/gt4py/backend/gtc_backend/defir_to_gtir.py index 74379b9e57..75c1fbcf29 100644 --- a/src/gt4py/backend/gtc_backend/defir_to_gtir.py +++ b/src/gt4py/backend/gtc_backend/defir_to_gtir.py @@ -15,7 +15,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import numbers -from typing import Any, Dict, List, Union, cast +from typing import Any, Dict, List, Tuple, Union, cast from gt4py.ir import IRNodeVisitor from gt4py.ir.nodes import ( @@ -45,6 +45,7 @@ UnaryOpExpr, VarDecl, VarRef, + While, ) from gtc import common, gtir from gtc.common import ExprKind @@ -245,7 +246,7 @@ def visit_NativeFuncCall(self, node: NativeFuncCall) -> gtir.NativeFuncCall: loc=common.location_to_source_location(node.loc), ) - def visit_FieldRef(self, node: FieldRef): + def visit_FieldRef(self, node: FieldRef) -> gtir.FieldAccess: return gtir.FieldAccess( name=node.name, offset=self.transform_offset(node.offset), @@ -253,7 +254,7 @@ def visit_FieldRef(self, node: FieldRef): loc=common.location_to_source_location(node.loc), ) - def visit_If(self, node: If): + def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]: cond = self.visit(node.condition) if cond.kind == ExprKind.FIELD: return gtir.FieldIfStmt( @@ -274,19 +275,26 @@ def visit_If(self, node: If): loc=common.location_to_source_location(node.loc), ) + def visit_While(self, node: While) -> gtir.While: + return gtir.While( + cond=self.visit(node.condition), + body=self.visit(node.body), + loc=common.location_to_source_location(node.loc), + ) + def visit_VarRef(self, node: VarRef, **kwargs): return gtir.ScalarAccess(name=node.name, loc=common.location_to_source_location(node.loc)) - def visit_AxisInterval(self, node: AxisInterval): + def visit_AxisInterval(self, node: AxisInterval) -> Tuple[gtir.AxisBound, gtir.AxisBound]: return self.visit(node.start), self.visit(node.end) - def visit_AxisBound(self, node: AxisBound): + def visit_AxisBound(self, node: AxisBound) -> gtir.AxisBound: # TODO(havogt) add support VarRef return gtir.AxisBound( level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level], offset=node.offset ) - def visit_FieldDecl(self, node: FieldDecl): + def visit_FieldDecl(self, node: FieldDecl) -> gtir.FieldDecl: dimension_names = ["I", "J", "K"] dimensions = [dim in node.axes for dim in dimension_names] # datatype conversion works via same ID @@ -298,7 +306,7 @@ def visit_FieldDecl(self, node: FieldDecl): loc=common.location_to_source_location(node.loc), ) - def visit_VarDecl(self, node: VarDecl): + def visit_VarDecl(self, node: VarDecl) -> gtir.ScalarDecl: # datatype conversion works via same ID return gtir.ScalarDecl( name=node.name, @@ -308,10 +316,10 @@ def visit_VarDecl(self, node: VarDecl): def transform_offset( self, offset: Dict[str, Union[int, Expr]], **kwargs: Any - ) -> Union[gtir.CartesianOffset, gtir.VariableKOffset]: + ) -> Union[common.CartesianOffset, gtir.VariableKOffset]: k_val = offset.get("K", 0) if isinstance(k_val, numbers.Integral): - return gtir.CartesianOffset(i=offset.get("I", 0), j=offset.get("J", 0), k=k_val) + return common.CartesianOffset(i=offset.get("I", 0), j=offset.get("J", 0), k=k_val) elif isinstance(k_val, Expr): return gtir.VariableKOffset(k=self.visit(k_val, **kwargs)) else: diff --git a/src/gt4py/frontend/gtscript_frontend.py b/src/gt4py/frontend/gtscript_frontend.py index 1b3d88c303..6a4b95be5f 100644 --- a/src/gt4py/frontend/gtscript_frontend.py +++ b/src/gt4py/frontend/gtscript_frontend.py @@ -737,7 +737,7 @@ def __init__( self.extra_temp_decls = extra_temp_decls or {} self.parsing_context = None self.iteration_order = None - self.if_decls_stack = [] + self.decls_stack = [] gt_ir.NativeFunction.PYTHON_SYMBOL_TO_IR_OP = { "abs": gt_ir.NativeFunction.ABS, "min": gt_ir.NativeFunction.MIN, @@ -1237,7 +1237,7 @@ def visit_IfExp(self, node: ast.IfExp) -> gt_ir.TernaryOpExpr: return result def visit_If(self, node: ast.If) -> list: - self.if_decls_stack.append([]) + self.decls_stack.append([]) main_stmts = [] for stmt in node.body: @@ -1251,11 +1251,11 @@ def visit_If(self, node: ast.If) -> list: assert all(isinstance(item, gt_ir.Statement) for item in else_stmts) result = [] - if len(self.if_decls_stack) == 1: - result.extend(self.if_decls_stack.pop()) - elif len(self.if_decls_stack) > 1: - self.if_decls_stack[-2].extend(self.if_decls_stack[-1]) - self.if_decls_stack.pop() + if len(self.decls_stack) == 1: + result.extend(self.decls_stack.pop()) + elif len(self.decls_stack) > 1: + self.decls_stack[-2].extend(self.decls_stack[-1]) + self.decls_stack.pop() result.append( gt_ir.If( @@ -1270,17 +1270,28 @@ def visit_If(self, node: ast.If) -> list: return result - def visit_While(self, node: ast.While) -> gt_ir.While: - if node.orelse: - raise GTScriptSyntaxError("orelse is not supported on while loops") - stmts = [] - for stmt in node.body: - stmts.extend(self.visit(stmt)) - return gt_ir.While( - condition=self.visit(node.test), - loc=gt_ir.Location.from_ast_node(node), - body=gt_ir.BlockStmt(stmts=stmts, loc=gt_ir.Location.from_ast_node(node)), - ) + def visit_While(self, node: ast.While) -> list: + loc = gt_ir.Location.from_ast_node(node) + + self.decls_stack.append([]) + stmts = gt_utils.flatten([self.visit(stmt) for stmt in node.body]) + assert all(isinstance(item, gt_ir.Statement) for item in stmts) + + result = [ + gt_ir.While( + condition=self.visit(node.test), + loc=gt_ir.Location.from_ast_node(node), + body=gt_ir.BlockStmt(stmts=stmts, loc=loc), + ) + ] + + if len(self.decls_stack) == 1: + result.extend(self.decls_stack.pop()) + elif len(self.decls_stack) > 1: + self.decls_stack[-2].extend(self.decls_stack[-1]) + self.decls_stack.pop() + + return result def visit_Call(self, node: ast.Call): native_fcn = gt_ir.NativeFunction.PYTHON_SYMBOL_TO_IR_OP[node.func.id] @@ -1371,8 +1382,8 @@ def visit_Assign(self, node: ast.Assign) -> list: # layout_id=t.id, is_api=False, ) - if len(self.if_decls_stack): - self.if_decls_stack[-1].append(field_decl) + if len(self.decls_stack): + self.decls_stack[-1].append(field_decl) else: result.append(field_decl) self.fields[field_decl.name] = field_decl diff --git a/src/gtc/common.py b/src/gtc/common.py index 6f92c74ebd..ed58c28d27 100644 --- a/src/gtc/common.py +++ b/src/gtc/common.py @@ -394,6 +394,21 @@ def condition_is_boolean(cls, cond: Expr) -> Expr: return verify_condition_is_boolean(cls, cond) +class While(GenericNode, Generic[StmtT, ExprT]): + """ + Generic while loop. + + Verifies that `cond` is a boolean expr (if `dtype` is set). + """ + + cond: ExprT + body: List[StmtT] + + @validator("cond") + def condition_is_boolean(cls, cond: Expr) -> Expr: + return verify_condition_is_boolean(cls, cond) + + class AssignStmt(GenericNode, Generic[TargetT, ExprT]): left: TargetT right: ExprT diff --git a/src/gtc/cuir/cuir.py b/src/gtc/cuir/cuir.py index 26fc741b7e..6082660c78 100644 --- a/src/gtc/cuir/cuir.py +++ b/src/gtc/cuir/cuir.py @@ -69,11 +69,17 @@ class KCacheAccess(common.FieldAccess[Expr, VariableKOffset], Expr): k_cache_is_different_from_field_access = True @validator("offset") - def zero_ij_offset(cls, v: CartesianOffset) -> CartesianOffset: + def has_no_ij_offset(cls, v: Union[CartesianOffset, VariableKOffset]) -> CartesianOffset: if not v.i == v.j == 0: raise ValueError("No ij-offset allowed") return v + @validator("offset") + def not_variable_offset(cls, v: Union[CartesianOffset, VariableKOffset]) -> CartesianOffset: + if isinstance(v, VariableKOffset): + raise ValueError("Cannot k-cache a variable k offset") + return v + @validator("data_index") def no_additional_dimensions(cls, v: List[int]) -> List[int]: if v: @@ -92,6 +98,10 @@ class MaskStmt(Stmt): body: List[Stmt] +class While(common.While[Stmt, Expr], Stmt): + pass + + class UnaryOp(common.UnaryOp[Expr], Expr): pass diff --git a/src/gtc/cuir/cuir_codegen.py b/src/gtc/cuir/cuir_codegen.py index 50ef3bf9ae..0726de2865 100644 --- a/src/gtc/cuir/cuir_codegen.py +++ b/src/gtc/cuir/cuir_codegen.py @@ -46,6 +46,14 @@ class CUIRCodegen(codegen.TemplatedGenerator): """ ) + While = as_mako( + """ + while (${cond}) { + ${'\\n'.join(body)} + } + """ + ) + def visit_FieldAccess(self, node: cuir.FieldAccess, **kwargs: Any): def maybe_const(s): try: diff --git a/src/gtc/cuir/oir_to_cuir.py b/src/gtc/cuir/oir_to_cuir.py index 8e5d5e7d06..4b50ea98a0 100644 --- a/src/gtc/cuir/oir_to_cuir.py +++ b/src/gtc/cuir/oir_to_cuir.py @@ -119,6 +119,11 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> cuir.MaskStmt: mask=self.visit(node.mask, **kwargs), body=self.visit(node.body, **kwargs) ) + def visit_While(self, node: oir.While, **kwargs: Any) -> cuir.While: + return cuir.While( + cond=self.visit(node.cond, **kwargs), body=self.visit(node.body, **kwargs) + ) + def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> cuir.Cast: return cuir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs)) diff --git a/src/gtc/dace/expansion.py b/src/gtc/dace/expansion.py index a40a24924e..5cee8b0173 100644 --- a/src/gtc/dace/expansion.py +++ b/src/gtc/dace/expansion.py @@ -171,6 +171,13 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs): body_code = [indent + b for b in body_code] return "\n".join([mask_str] + body_code) + def visit_While(self, node: oir.While, **kwargs): + body = self.visit(node.body, **kwargs) + cond = self.visit(node.cond, is_target=False, **kwargs) + indent = " " * 4 + delim = f"\n{indent}" + return f"while {cond}:\n{indent}{delim.join(body)}" + class RemoveCastInIndexVisitor(eve.NodeTranslator): def visit_FieldAccess(self, node: oir.FieldAccess): if node.data_index: diff --git a/src/gtc/dace/utils.py b/src/gtc/dace/utils.py index 26ebd3cefa..a19f294d09 100644 --- a/src/gtc/dace/utils.py +++ b/src/gtc/dace/utils.py @@ -232,16 +232,19 @@ def domain() -> "CartesianIterationSpace": ) @staticmethod - def from_offset(offset: CartesianOffset) -> "CartesianIterationSpace": + def from_offset( + offset: Union[CartesianOffset, oir.VariableKOffset] + ) -> "CartesianIterationSpace": + dict_offsets = offset.to_dict() return CartesianIterationSpace( i_interval=oir.Interval( - start=oir.AxisBound.from_start(min(0, offset.i)), - end=oir.AxisBound.from_end(max(0, offset.i)), + start=oir.AxisBound.from_start(min(0, dict_offsets["i"])), + end=oir.AxisBound.from_end(max(0, dict_offsets["i"])), ), j_interval=oir.Interval( - start=oir.AxisBound.from_start(min(0, offset.j)), - end=oir.AxisBound.from_end(max(0, offset.j)), + start=oir.AxisBound.from_start(min(0, dict_offsets["j"])), + end=oir.AxisBound.from_end(max(0, dict_offsets["j"])), ), ) diff --git a/src/gtc/gtcpp/gtcpp.py b/src/gtc/gtcpp/gtcpp.py index d8a74a52fa..819d0dd886 100644 --- a/src/gtc/gtcpp/gtcpp.py +++ b/src/gtc/gtcpp/gtcpp.py @@ -77,6 +77,10 @@ class IfStmt(common.IfStmt[Stmt, Expr], Stmt): pass +class While(common.While[Stmt, Expr], Stmt): + pass + + class UnaryOp(common.UnaryOp[Expr], Expr): pass diff --git a/src/gtc/gtcpp/gtcpp_codegen.py b/src/gtc/gtcpp/gtcpp_codegen.py index 28fc8e0ba5..18c6570ae8 100644 --- a/src/gtc/gtcpp/gtcpp_codegen.py +++ b/src/gtc/gtcpp/gtcpp_codegen.py @@ -203,6 +203,8 @@ def visit_LoopOrder(self, looporder: LoopOrder, **kwargs: Any) -> str: """ ) + While = as_mako("while(${cond}) {${''.join(body)}}") + BlockStmt = as_mako("{${''.join(body)}}") def visit_GTComputationCall( diff --git a/src/gtc/gtcpp/oir_to_gtcpp.py b/src/gtc/gtcpp/oir_to_gtcpp.py index dfc95159b7..de2d29a353 100644 --- a/src/gtc/gtcpp/oir_to_gtcpp.py +++ b/src/gtc/gtcpp/oir_to_gtcpp.py @@ -180,6 +180,11 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> gtcpp.IfStmt: true_branch=gtcpp.BlockStmt(body=self.visit(node.body, **kwargs)), ) + def visit_While(self, node: oir.While, **kwargs: Any) -> gtcpp.While: + return gtcpp.While( + cond=self.visit(node.cond, **kwargs), body=self.visit(node.body, **kwargs) + ) + def visit_HorizontalExecution( self, node: oir.HorizontalExecution, diff --git a/src/gtc/gtir.py b/src/gtc/gtir.py index 520228b4a9..a0cc4b1f7b 100644 --- a/src/gtc/gtir.py +++ b/src/gtc/gtir.py @@ -57,10 +57,6 @@ class Literal(common.Literal, Expr): # type: ignore pass -class CartesianOffset(common.CartesianOffset): - pass - - class VariableKOffset(common.VariableKOffset[Expr]): pass @@ -149,6 +145,19 @@ def verify_scalar_condition(cls, cond: Expr) -> Expr: return cond +class While(common.While[Stmt, Expr], Stmt): + """While loop with a field or scalar expression as condition.""" + + @validator("body") + def _no_write_and_read_with_horizontal_offset_all( + cls, body: List[Stmt] + ) -> RootValidatorValuesType: + """In a while loop all variables must not be written and read with a horizontal offset.""" + if names := _written_and_read_with_offset(body): + raise ValueError(f"Illegal write and read with horizontal offset detected for {names}.") + return body + + class UnaryOp(common.UnaryOp[Expr], Expr): pass @@ -201,43 +210,16 @@ class VerticalLoop(LocNode): body: List[Stmt] @root_validator(skip_on_failure=True) - def no_write_and_read_with_horizontal_offset( + def _no_write_and_read_with_horizontal_offset( cls, values: RootValidatorValuesType ) -> RootValidatorValuesType: """ In the same VerticalLoop a field must not be written and read with a horizontal offset. - Temporaries don't have this contraint. Backends are required to implement temporaries with block-private halos. + Temporaries don't have this contraint. Backends are required to implement + them using block-private halos. """ - # TODO(havogt): either move to eve or will be removed in the attr-based eve if a List[Node] is represented as a CollectionNode - @utils.as_xiter - def _collection_iter_tree( - collection: List[Node], - ) -> Generator[TreeIterationItem, None, None]: - for elem in collection: - yield from elem.iter_tree() - - def _writes(stmts: List[Stmt]) -> Set[str]: - result = set() - for left in _collection_iter_tree(stmts).if_isinstance(ParAssignStmt).getattr("left"): - result |= left.iter_tree().if_isinstance(FieldAccess).getattr("name").to_set() - return result - - def _reads_with_offset(stmts: List[Stmt]) -> Set[str]: - return ( - _collection_iter_tree(stmts) - .filter(_cartesian_fieldaccess) - .filter( - lambda acc: acc.offset.i != 0 or acc.offset.j != 0 - ) # writes always have zero offset - .getattr("name") - .to_set() - ) - - writes = _writes(values["body"]) - reads_with_offset = _reads_with_offset(values["body"]) - - intersec = writes.intersection(reads_with_offset) + intersec = _written_and_read_with_offset(values["body"]) non_tmp_fields = { acc for acc in intersec if acc not in {tmp.name for tmp in values["temporaries"]} } @@ -268,3 +250,37 @@ def _cartesian_fieldaccess(node) -> bool: def _variablek_fieldaccess(node) -> bool: return isinstance(node, FieldAccess) and isinstance(node.offset, VariableKOffset) + + +def _written_and_read_with_offset( + stmts: List[Stmt], +) -> RootValidatorValuesType: + """Return a list of names that are written to and read with offset.""" + # TODO(havogt): either move to eve or will be removed in the attr-based eve if a List[Node] is represented as a CollectionNode + @utils.as_xiter + def _collection_iter_tree( + collection: List[Node], + ) -> Generator[TreeIterationItem, None, None]: + for elem in collection: + yield from elem.iter_tree() + + def _writes(stmts: List[Stmt]) -> Set[str]: + result = set() + for left in _collection_iter_tree(stmts).if_isinstance(ParAssignStmt).getattr("left"): + result |= left.iter_tree().if_isinstance(FieldAccess).getattr("name").to_set() + return result + + def _reads_with_offset(stmts: List[Stmt]) -> Set[str]: + return ( + _collection_iter_tree(stmts) + .filter(_cartesian_fieldaccess) + .filter( + lambda acc: acc.offset.i != 0 or acc.offset.j != 0 + ) # writes always have zero offset + .getattr("name") + .to_set() + ) + + writes = _writes(stmts) + reads_with_offset = _reads_with_offset(stmts) + return writes & reads_with_offset diff --git a/src/gtc/gtir_to_oir.py b/src/gtc/gtir_to_oir.py index a5d6841b87..84cebd9b9d 100644 --- a/src/gtc/gtir_to_oir.py +++ b/src/gtc/gtir_to_oir.py @@ -18,72 +18,21 @@ from typing import Any, List from eve import NodeTranslator -from gtc import gtir, oir +from gtc import common, gtir, oir, utils from gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator -def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary: - mask_field_decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True)) - ctx.add_decl(mask_field_decl) - - fill_mask_field = oir.HorizontalExecution( - body=[ - oir.AssignStmt( - left=oir.FieldAccess( - name=mask_field_decl.name, - offset=CartesianOffset.zero(), - dtype=mask_field_decl.dtype, - ), - right=cond, - ) - ], - declarations=[], - ) - ctx.add_horizontal_execution(fill_mask_field) - return mask_field_decl - - class GTIRToOIR(NodeTranslator): @dataclass class Context: - """ - Context for Stmts. - - `Stmt` nodes create `Temporary` nodes and `HorizontalExecution` nodes. - All visit()-methods for `Stmt` have no return value, - they attach their result to the Context object. - """ - - decls: List = field(default_factory=list) - horizontal_executions: List = field(default_factory=list) - - def add_decl(self, decl: oir.Decl) -> "GTIRToOIR.Context": - self.decls.append(decl) - return self - - def add_horizontal_execution( - self, horizontal_execution: oir.HorizontalExecution - ) -> "GTIRToOIR.Context": - self.horizontal_executions.append(horizontal_execution) - return self + local_scalars: List[oir.ScalarDecl] = field(default_factory=list) + temp_fields: List[oir.FieldDecl] = field(default_factory=list) - def visit_ParAssignStmt( - self, node: gtir.ParAssignStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any - ) -> None: - body = [ - oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right), loc=node.loc) - ] - if mask is not None: - body = [oir.MaskStmt(body=body, mask=mask)] - ctx.add_horizontal_execution( - oir.HorizontalExecution( - body=body, - declarations=[], - loc=node.loc, - ), - ) + def reset_local_scalars(self): + self.local_scalars = [] - def visit_FieldAccess(self, node: gtir.FieldAccess, **kwargs: Any) -> oir.FieldAccess: + # --- Exprs --- + def visit_FieldAccess(self, node: gtir.FieldAccess) -> oir.FieldAccess: return oir.FieldAccess( name=node.name, offset=self.visit(node.offset), @@ -92,26 +41,26 @@ def visit_FieldAccess(self, node: gtir.FieldAccess, **kwargs: Any) -> oir.FieldA loc=node.loc, ) - def visit_VariableKOffset( - self, node: gtir.VariableKOffset, **kwargs: Any - ) -> oir.VariableKOffset: + def visit_VariableKOffset(self, node: gtir.VariableKOffset) -> oir.VariableKOffset: return oir.VariableKOffset(k=self.visit(node.k)) - def visit_ScalarAccess(self, node: gtir.ScalarAccess, **kwargs: Any) -> oir.ScalarAccess: + def visit_ScalarAccess(self, node: gtir.ScalarAccess) -> oir.ScalarAccess: return oir.ScalarAccess(name=node.name, dtype=node.dtype, loc=node.loc) - def visit_Literal(self, node: gtir.Literal, **kwargs: Any) -> oir.Literal: - return oir.Literal(value=self.visit(node.value), dtype=node.dtype, kind=node.kind) + def visit_Literal(self, node: gtir.Literal) -> oir.Literal: + return oir.Literal( + value=self.visit(node.value), dtype=node.dtype, kind=node.kind, loc=node.loc + ) - def visit_UnaryOp(self, node: gtir.UnaryOp, **kwargs: Any) -> oir.UnaryOp: + def visit_UnaryOp(self, node: gtir.UnaryOp) -> oir.UnaryOp: return oir.UnaryOp(op=node.op, expr=self.visit(node.expr), loc=node.loc) - def visit_BinaryOp(self, node: gtir.BinaryOp, **kwargs: Any) -> oir.BinaryOp: + def visit_BinaryOp(self, node: gtir.BinaryOp) -> oir.BinaryOp: return oir.BinaryOp( op=node.op, left=self.visit(node.left), right=self.visit(node.right), loc=node.loc ) - def visit_TernaryOp(self, node: gtir.TernaryOp, **kwargs: Any) -> oir.TernaryOp: + def visit_TernaryOp(self, node: gtir.TernaryOp) -> oir.TernaryOp: return oir.TernaryOp( cond=self.visit(node.cond), true_expr=self.visit(node.true_expr), @@ -119,10 +68,10 @@ def visit_TernaryOp(self, node: gtir.TernaryOp, **kwargs: Any) -> oir.TernaryOp: loc=node.loc, ) - def visit_Cast(self, node: gtir.Cast, **kwargs: Any) -> oir.Cast: + def visit_Cast(self, node: gtir.Cast) -> oir.Cast: return oir.Cast(dtype=node.dtype, expr=self.visit(node.expr), loc=node.loc) - def visit_FieldDecl(self, node: gtir.FieldDecl, **kwargs: Any) -> oir.FieldDecl: + def visit_FieldDecl(self, node: gtir.FieldDecl) -> oir.FieldDecl: return oir.FieldDecl( name=node.name, dtype=node.dtype, @@ -131,10 +80,10 @@ def visit_FieldDecl(self, node: gtir.FieldDecl, **kwargs: Any) -> oir.FieldDecl: loc=node.loc, ) - def visit_ScalarDecl(self, node: gtir.ScalarDecl, **kwargs: Any) -> oir.ScalarDecl: + def visit_ScalarDecl(self, node: gtir.ScalarDecl) -> oir.ScalarDecl: return oir.ScalarDecl(name=node.name, dtype=node.dtype, loc=node.loc) - def visit_NativeFuncCall(self, node: gtir.NativeFuncCall, **kwargs: Any) -> oir.NativeFuncCall: + def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: return oir.NativeFuncCall( func=node.func, args=self.visit(node.args), @@ -143,10 +92,52 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall, **kwargs: Any) -> oir. loc=node.loc, ) + # --- Stmts --- + def visit_ParAssignStmt( + self, node: gtir.ParAssignStmt, *, mask: oir.Expr = None, **kwargs: Any + ) -> oir.AssignStmt: + stmt = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) + if mask is not None: + # Wrap inside MaskStmt + stmt = oir.MaskStmt(body=[stmt], mask=mask, loc=node.loc) + return stmt + + def visit_While(self, node: gtir.While, *, mask: oir.Expr = None, **kwargs: Any): + body_stmts = [] + for stmt in node.body: + stmt_or_stmts = self.visit(stmt, **kwargs) + if isinstance(stmt_or_stmts, oir.Stmt): + body_stmts.append(stmt_or_stmts) + else: + body_stmts.extend(stmt_or_stmts) + + cond = self.visit(node.cond) + if mask: + cond = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=cond) + stmt = oir.While(cond=cond, body=body_stmts, loc=node.loc) + if mask is not None: + stmt = oir.MaskStmt(body=[stmt], mask=mask, loc=node.loc) + return stmt + def visit_FieldIfStmt( self, node: gtir.FieldIfStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any - ) -> None: - mask_field_decl = _create_mask(ctx, f"mask_{id(node)}", self.visit(node.cond)) + ) -> List[oir.Stmt]: + mask_field_decl = oir.Temporary( + name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True) + ) + ctx.temp_fields.append(mask_field_decl) + stmts = [ + oir.AssignStmt( + left=oir.FieldAccess( + name=mask_field_decl.name, + offset=CartesianOffset.zero(), + dtype=DataType.BOOL, + loc=node.loc, + ), + right=self.visit(node.cond), + ) + ] + current_mask = oir.FieldAccess( name=mask_field_decl.name, offset=CartesianOffset.zero(), @@ -156,101 +147,77 @@ def visit_FieldIfStmt( combined_mask = current_mask if mask: combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, - left=mask, - right=combined_mask, - loc=node.loc, + op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc ) - self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx) + stmts.extend(self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask, loc=node.loc) + combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask) if mask: combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, - left=mask, - right=combined_mask, - loc=node.loc, + op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc ) - self.visit( - node.false_branch.body, - mask=combined_mask, - ctx=ctx, - ) + stmts.extend(self.visit(node.false_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) + + return stmts # For now we represent ScalarIf (and FieldIf) both as masks on the HorizontalExecution. # This is not meant to be set in stone... def visit_ScalarIfStmt( self, node: gtir.ScalarIfStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any - ) -> None: + ) -> List[oir.Stmt]: current_mask = self.visit(node.cond) combined_mask = current_mask if mask: combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, - left=mask, - right=current_mask, - loc=node.loc, + op=LogicalOperator.AND, left=mask, right=current_mask, loc=node.loc ) - self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx) + stmts = self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs) if node.false_branch: combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask, loc=node.loc) if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, - left=mask, - right=combined_mask, - loc=node.loc, - ) - self.visit( - node.false_branch.body, - mask=combined_mask, - ctx=ctx, - ) - - def visit_Interval(self, node: gtir.Interval, **kwargs: Any) -> oir.Interval: - return oir.Interval( - start=self.visit(node.start), - end=self.visit(node.end), - loc=node.loc, - ) - - def visit_VerticalLoop( - self, node: gtir.VerticalLoop, *, ctx: Context, **kwargs: Any - ) -> oir.VerticalLoop: - ctx.horizontal_executions.clear() - self.visit(node.body, ctx=ctx) - - for temp in node.temporaries: - ctx.add_decl( - oir.Temporary( - name=temp.name, - dtype=temp.dtype, - dimensions=temp.dimensions, - loc=node.loc, - ) - ) + combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) + stmts.extend(self.visit(node.false_branch.body, mask=combined_mask, ctx=ctx, **kwargs)) + + return stmts + + # --- Misc --- + def visit_Interval(self, node: gtir.Interval) -> oir.Interval: + return oir.Interval(start=self.visit(node.start), end=self.visit(node.end), loc=node.loc) + + # --- Control flow --- + def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop: + horiz_execs: List[oir.HorizontalExecution] = [] + for stmt in node.body: + ctx.reset_local_scalars() + ret = self.visit(stmt, ctx=ctx) + stmts = utils.flatten_list([ret] if isinstance(ret, oir.Stmt) else ret) + horiz_execs.append(oir.HorizontalExecution(body=stmts, declarations=ctx.local_scalars)) + + ctx.temp_fields += [ + oir.Temporary(name=temp.name, dtype=temp.dtype, dimensions=temp.dimensions) + for temp in node.temporaries + ] return oir.VerticalLoop( loop_order=node.loop_order, sections=[ oir.VerticalLoopSection( - interval=self.visit(node.interval, **kwargs), - horizontal_executions=ctx.horizontal_executions, + interval=self.visit(node.interval), + horizontal_executions=horiz_execs, loc=node.loc, ) ], - caches=[], - loc=node.loc, ) - def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> oir.Stencil: + def visit_Stencil(self, node: gtir.Stencil) -> oir.Stencil: ctx = self.Context() + vertical_loops = self.visit(node.vertical_loops, ctx=ctx) return oir.Stencil( name=node.name, params=self.visit(node.params), - vertical_loops=self.visit(node.vertical_loops, ctx=ctx), - declarations=ctx.decls, + vertical_loops=vertical_loops, + declarations=ctx.temp_fields, loc=node.loc, ) diff --git a/src/gtc/numpy/npir.py b/src/gtc/numpy/npir.py index 5b5aa42c2d..4d54b6120d 100644 --- a/src/gtc/numpy/npir.py +++ b/src/gtc/numpy/npir.py @@ -161,8 +161,14 @@ class NativeFuncCall(common.NativeFuncCall[Expr], Expr): _dtype_propagation = common.native_func_call_dtype_propagation(strict=True) +# --- Statements --- +@eve.utils.noninstantiable +class Stmt(eve.Node): + pass + + # --- Statement --- -class VectorAssign(common.AssignStmt[VectorLValue, Expr]): +class VectorAssign(common.AssignStmt[VectorLValue, Expr], Stmt): left: VectorLValue right: Expr mask: Optional[Expr] = None @@ -176,10 +182,14 @@ def right_is_field_kind(cls, right: Expr) -> Expr: _dtype_validation = common.assign_stmt_dtype_validation(strict=True) +class While(common.While[Stmt, Expr], Stmt): + pass + + # --- Control Flow --- class HorizontalBlock(common.LocNode, eve.SymbolTableTrait): declarations: List[ScalarDecl] - body: List[VectorAssign] + body: List[Stmt] extent: HorizontalExtent diff --git a/src/gtc/numpy/npir_codegen.py b/src/gtc/numpy/npir_codegen.py index 93ff78639e..a50918c993 100644 --- a/src/gtc/numpy/npir_codegen.py +++ b/src/gtc/numpy/npir_codegen.py @@ -280,6 +280,23 @@ def visit_Broadcast( Broadcast = FormatTemplate("np.full({shape}, {expr})") + While = JinjaTemplate( + textwrap.dedent( + """\ + while {{ cond }}: + {% for stmt in body %}{{ stmt }} + {% endfor %} + """ + ) + ) + + def visit_While(self, node: npir.While, **kwargs: Any) -> str: + cond = self.visit(node.cond, **kwargs) + body = [] + for stmt in self.visit(node.body, **kwargs): + body.extend(stmt.split("\n")) + return self.While.render(cond=cond, body=body) + def visit_VerticalPass(self, node: npir.VerticalPass, **kwargs): is_serial = node.direction != common.LoopOrder.PARALLEL has_variable_k = bool(node.iter_tree().if_isinstance(npir.VarKOffset).to_list()) @@ -319,7 +336,7 @@ def visit_HorizontalBlock( i, I = _di_ - {{ lower[0] }}, _dI_ + {{ upper[0] }} j, J = _dj_ - {{ lower[1] }}, _dJ_ + {{ upper[1] }} - {% for assign in body %}{{ assign }} + {% for stmt in body %}{{ stmt }} {% endfor -%} # --- end horizontal block -- diff --git a/src/gtc/numpy/oir_to_npir.py b/src/gtc/numpy/oir_to_npir.py index a6eec3e758..f638abf5be 100644 --- a/src/gtc/numpy/oir_to_npir.py +++ b/src/gtc/numpy/oir_to_npir.py @@ -165,6 +165,16 @@ def visit_AssignStmt( return npir.VectorAssign(left=left, right=right, mask=mask) # --- Control Flow --- + def visit_While( + self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any + ) -> npir.While: + cond = self.visit(node.cond, mask=mask, **kwargs) + if mask: + mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond) + else: + mask = cond + return npir.While(cond=cond, body=self.visit(node.body, mask=mask, **kwargs)) + def visit_HorizontalExecution( self, node: oir.HorizontalExecution, diff --git a/src/gtc/oir.py b/src/gtc/oir.py index eab2d5f8e3..00f67d0f31 100644 --- a/src/gtc/oir.py +++ b/src/gtc/oir.py @@ -99,6 +99,10 @@ class NativeFuncCall(common.NativeFuncCall[Expr], Expr): _dtype_propagation = common.native_func_call_dtype_propagation(strict=True) +class While(common.While[Stmt, Expr], Stmt): + pass + + class Decl(LocNode): name: SymbolName dtype: common.DataType @@ -185,7 +189,7 @@ class VerticalLoopSection(LocNode): class VerticalLoop(LocNode): loop_order: common.LoopOrder sections: List[VerticalLoopSection] - caches: List[CacheDesc] + caches: List[CacheDesc] = [] @validator("sections") def nonempty_loop(cls, v: List[VerticalLoopSection]) -> List[VerticalLoopSection]: diff --git a/src/gtc/passes/gtir_access_kind.py b/src/gtc/passes/gtir_access_kind.py index 828b2f0b21..995b704a5a 100644 --- a/src/gtc/passes/gtir_access_kind.py +++ b/src/gtc/passes/gtir_access_kind.py @@ -50,6 +50,10 @@ def visit_ScalarIfStmt(self, node: gtir.ScalarIfStmt, **kwargs: Any) -> None: def visit_FieldIfStmt(self, node: gtir.FieldIfStmt, **kwargs: Any) -> None: self._visit_If(node, **kwargs) + def visit_While(self, node: gtir.While, **kwargs: Any) -> None: + self.visit(node.cond, kind=AccessKind.READ, **kwargs) + self.visit(node.body, **kwargs) + def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> None: self.visit(node.right, kind=AccessKind.READ, **kwargs) self.visit(node.left, kind=AccessKind.WRITE, **kwargs) diff --git a/src/gtc/passes/gtir_legacy_extents.py b/src/gtc/passes/gtir_legacy_extents.py index f6d2c1f898..8f4f45bbf6 100644 --- a/src/gtc/passes/gtir_legacy_extents.py +++ b/src/gtc/passes/gtir_legacy_extents.py @@ -4,7 +4,7 @@ from eve import NodeVisitor from eve.utils import XIterable from gt4py.definitions import Extent -from gtc import gtir +from gtc import common, gtir def _iter_field_names(node: Union[gtir.Stencil, gtir.ParAssignStmt]) -> XIterable[gtir.FieldAccess]: @@ -15,7 +15,7 @@ def _iter_assigns(node: gtir.Stencil) -> XIterable[gtir.ParAssignStmt]: return node.iter_tree().if_isinstance(gtir.ParAssignStmt) -def _ext_from_off(offset: Union[gtir.CartesianOffset, gtir.VariableKOffset]) -> Extent: +def _ext_from_off(offset: Union[common.CartesianOffset, gtir.VariableKOffset]) -> Extent: if isinstance(offset, gtir.VariableKOffset): return Extent(((0, 0), (0, 0), (0, 0))) return Extent(((offset.i, offset.i), (offset.j, offset.j), (0, 0))) diff --git a/src/gtc/passes/oir_optimizations/caches.py b/src/gtc/passes/oir_optimizations/caches.py index 19b819f7e2..0fbd520f9d 100644 --- a/src/gtc/passes/oir_optimizations/caches.py +++ b/src/gtc/passes/oir_optimizations/caches.py @@ -114,6 +114,13 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> oir.Verti ): return self.generic_visit(node, **kwargs) + all_accesses = AccessCollector.apply(node) + fields_with_variable_reads = { + field + for field, offsets in all_accesses.offsets().items() + if any(off[2] is None for off in offsets) + } + def accessed_more_than_once(offsets: Set[Any]) -> bool: return len(offsets) > 1 @@ -127,11 +134,15 @@ def has_horizontal_offset(offsets: Set[Tuple[int, int, int]]) -> bool: def offsets_within_limits(offsets: Set[Tuple[int, int, int]]) -> bool: return all(abs(offset[2]) <= self.max_cacheable_offset for offset in offsets) - accesses = AccessCollector.apply(node).cartesian_accesses().offsets() + def has_variable_offset_reads(field: str) -> bool: + return field in fields_with_variable_reads + + accesses = all_accesses.cartesian_accesses().offsets() cacheable = { field for field, offsets in accesses.items() if not already_cached(field) + and not has_variable_offset_reads(field) and accessed_more_than_once(offsets) and not has_horizontal_offset(offsets) and offsets_within_limits(offsets) diff --git a/src/gtc/passes/oir_optimizations/mask_stmt_merging.py b/src/gtc/passes/oir_optimizations/mask_stmt_merging.py index 1acb919fc2..22f92e21a0 100644 --- a/src/gtc/passes/oir_optimizations/mask_stmt_merging.py +++ b/src/gtc/passes/oir_optimizations/mask_stmt_merging.py @@ -48,5 +48,16 @@ def visit_HorizontalExecution(self, node: oir.HorizontalExecution) -> oir.Horizo loc=node.loc, ) + # Stmt node types with lists of Stmts within them: + def visit_MaskStmt(self, node: oir.MaskStmt) -> oir.MaskStmt: return oir.MaskStmt(mask=node.mask, body=self._merge(node.body), loc=node.loc) + + def visit_While(self, node: oir.While) -> oir.While: + body_nodes = [] + for stmt in node.body: + if isinstance(stmt, oir.MaskStmt) and node.cond == stmt.mask: + body_nodes.extend(stmt.body) + else: + body_nodes.append(stmt) + return oir.While(cond=self.visit(node.cond), body=self.visit(body_nodes), loc=node.loc) diff --git a/src/gtc/passes/oir_optimizations/utils.py b/src/gtc/passes/oir_optimizations/utils.py index e14adc3da0..70a1e18144 100644 --- a/src/gtc/passes/oir_optimizations/utils.py +++ b/src/gtc/passes/oir_optimizations/utils.py @@ -91,6 +91,10 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> None: self.visit(node.mask, is_write=False, **kwargs) self.visit(node.body, **kwargs) + def visit_While(self, node: oir.While, **kwargs: Any) -> None: + self.visit(node.cond, is_write=False, **kwargs) + self.visit(node.body, **kwargs) + @dataclass class GenericAccessCollection(Generic[AccessT, OffsetT]): _ordered_accesses: List[AccessT] @@ -102,15 +106,15 @@ def _offset_dict(accesses: XIterable) -> Dict[str, Set[OffsetT]]: ) def offsets(self) -> Dict[str, Set[OffsetT]]: - """Get a dictonary, mapping all accessed fields' names to sets of offset tuples.""" + """Get a dictionary, mapping all accessed fields' names to sets of offset tuples.""" return self._offset_dict(xiter(self._ordered_accesses)) def read_offsets(self) -> Dict[str, Set[OffsetT]]: - """Get a dictonary, mapping read fields' names to sets of offset tuples.""" + """Get a dictionary, mapping read fields' names to sets of offset tuples.""" return self._offset_dict(xiter(self._ordered_accesses).filter(lambda x: x.is_read)) def write_offsets(self) -> Dict[str, Set[OffsetT]]: - """Get a dictonary, mapping written fields' names to sets of offset tuples.""" + """Get a dictionary, mapping written fields' names to sets of offset tuples.""" return self._offset_dict(xiter(self._ordered_accesses).filter(lambda x: x.is_write)) def fields(self) -> Set[str]: diff --git a/tests/test_integration/test_code_generation.py b/tests/test_integration/test_code_generation.py index 75efcbf8b5..9c5f2d5952 100644 --- a/tests/test_integration/test_code_generation.py +++ b/tests/test_integration/test_code_generation.py @@ -377,7 +377,10 @@ def stencil_ijk( out_field = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] -@pytest.mark.skip("While loop not yet supported") +# TODO: Enable DaCe +@pytest.mark.parametrize( + "backend", [backend for backend in ALL_BACKENDS if backend.values[0] != "gtc:dace"] +) def test_variable_offsets_and_while_loop(backend): @gtscript.stencil(backend=backend) def stencil( @@ -387,7 +390,7 @@ def stencil( qout: gtscript.Field[np.float_], lev: gtscript.Field[gtscript.IJ, np.int_], ): - with computation(FORWARD), interval(...): + with computation(FORWARD), interval(0, -1): if pe2[0, 0, 1] <= pe1[0, 0, lev]: qout = qin[0, 0, 1] else: @@ -398,6 +401,24 @@ def stencil( qout = qsum / (pe2[0, 0, 1] - pe2) +# TODO: Enable DaCe +@pytest.mark.parametrize( + "backend", [backend for backend in ALL_BACKENDS if backend.values[0] != "gtc:dace"] +) +def test_nested_while_loop(backend): + @gtscript.stencil(backend=backend) + def stencil( + field_a: gtscript.Field[np.float_], + field_b: gtscript.Field[np.int_], + ): + with computation(PARALLEL), interval(...): + while field_a < 1: + add = 0 + while field_a + field_b < 1: + add += 1 + field_a += add + + @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): @gtscript.stencil(backend, externals={"mord": 5}) diff --git a/tests/test_unittest/test_gtc/gtir_utils.py b/tests/test_unittest/test_gtc/gtir_utils.py index 93b87eaaae..c10320a11c 100644 --- a/tests/test_unittest/test_gtc/gtir_utils.py +++ b/tests/test_unittest/test_gtc/gtir_utils.py @@ -77,7 +77,7 @@ class BlockStmtFactory(factory.Factory): class Meta: model = gtir.BlockStmt - body = [] + body: List[gtir.Stmt] = factory.List([factory.SubFactory(ParAssignStmtFactory)]) class FieldIfStmtFactory(factory.Factory): @@ -98,6 +98,14 @@ class Meta: false_branch = None +class WhileFactory(factory.Factory): + class Meta: + model = gtir.While + + cond = factory.SubFactory(FieldAccessFactory, dtype=common.DataType.BOOL) + body = factory.List([factory.SubFactory(ParAssignStmtFactory)]) + + class IntervalFactory(factory.Factory): class Meta: model = gtir.Interval @@ -129,7 +137,7 @@ class Meta: interval = factory.SubFactory(IntervalFactory) loop_order = common.LoopOrder.PARALLEL - temporaries = [] + temporaries: List[gtir.FieldDecl] = [] body = factory.List([factory.SubFactory(ParAssignStmtFactory)]) diff --git a/tests/test_unittest/test_gtc/oir_utils.py b/tests/test_unittest/test_gtc/oir_utils.py index ab34315877..ff3fccdb09 100644 --- a/tests/test_unittest/test_gtc/oir_utils.py +++ b/tests/test_unittest/test_gtc/oir_utils.py @@ -80,6 +80,14 @@ class Meta: body = factory.List([factory.SubFactory(AssignStmtFactory)]) +class WhileFactory(factory.Factory): + class Meta: + model = oir.While + + cond = factory.SubFactory(FieldAccessFactory, dtype=common.DataType.BOOL) + body = factory.List([factory.SubFactory(AssignStmtFactory)]) + + class NativeFuncCallFactory(factory.Factory): class Meta: model = oir.NativeFuncCall diff --git a/tests/test_unittest/test_gtc/test_gtir.py b/tests/test_unittest/test_gtc/test_gtir.py index 8cf9619b07..3c6e10e3ed 100644 --- a/tests/test_unittest/test_gtc/test_gtir.py +++ b/tests/test_unittest/test_gtc/test_gtir.py @@ -18,7 +18,7 @@ from pydantic.error_wrappers import ValidationError from eve import SourceLocation -from gtc.common import ArithmeticOperator, DataType, LevelMarker, LoopOrder +from gtc.common import ArithmeticOperator, ComparisonOperator, DataType, LevelMarker, LoopOrder from gtc.gtir import ( AxisBound, Decl, @@ -42,6 +42,7 @@ StencilFactory, VariableKOffsetFactory, VerticalLoopFactory, + WhileFactory, ) @@ -207,6 +208,35 @@ def test_indirect_address_data_dims(): FieldAccessFactory(data_index=[ScalarAccessFactory(dtype=DataType.FLOAT32)]) +def test_while_without_boolean_condition(): + with pytest.raises(ValueError, match=r"Condition in.*must be boolean."): + WhileFactory( + cond=BinaryOpFactory( + left__name="foo", + right__name="bar", + ), + dtype=DataType.FLOAT32, + ) + + +def test_while_with_accumulated_extents(): + with pytest.raises( + ValueError, match=r"Illegal write and read with horizontal offset detected for.*" + ): + WhileFactory( + cond=BinaryOpFactory( + left__name="a", + right__name="b", + op=ComparisonOperator.LT, + dtype=DataType.BOOL, + ), + body=[ + ParAssignStmtFactory(left__name="a", right__name="b", right__offset__i=1), + ParAssignStmtFactory(left__name="b", right__name="a"), + ], + ) + + def test_variable_k_offset_in_access(): # Integer expressions are OK FieldAccessFactory(offset=VariableKOffsetFactory()) diff --git a/tests/test_unittest/test_gtc/test_gtir_to_oir.py b/tests/test_unittest/test_gtc/test_gtir_to_oir.py index 50344075be..321413210c 100644 --- a/tests/test_unittest/test_gtc/test_gtir_to_oir.py +++ b/tests/test_unittest/test_gtc/test_gtir_to_oir.py @@ -17,17 +17,16 @@ from typing import Type from eve import Node -from gtc import gtir, gtir_to_oir, oir -from gtc.common import DataType +from gtc import oir from gtc.gtir_to_oir import GTIRToOIR -from . import oir_utils from .gtir_utils import ( - BlockStmtFactory, - FieldAccessFactory, FieldIfStmtFactory, + ParAssignStmtFactory, ScalarIfStmtFactory, + StencilFactory, VariableKOffsetFactory, + WhileFactory, ) @@ -39,69 +38,60 @@ def isinstance_and_return(node: Node, expected_type: Type[Node]): def test_visit_ParAssignStmt(): out_name = "out" in_name = "in" - testee = gtir.ParAssignStmt( - left=FieldAccessFactory(name=out_name), right=FieldAccessFactory(name=in_name) - ) - - ctx = GTIRToOIR.Context() - GTIRToOIR().visit(testee, ctx=ctx) - result_horizontal_executions = ctx.horizontal_executions - - assert len(result_horizontal_executions) == 1 - assign = isinstance_and_return(result_horizontal_executions[0].body[0], oir.AssignStmt) + testee = ParAssignStmtFactory(left__name=out_name, right__name=in_name) + assign = GTIRToOIR().visit(testee) left = isinstance_and_return(assign.left, oir.FieldAccess) right = isinstance_and_return(assign.right, oir.FieldAccess) assert left.name == out_name assert right.name == in_name -def test_create_mask(): - mask_name = "mask" - cond = oir_utils.FieldAccessFactory(dtype=DataType.BOOL) - ctx = GTIRToOIR.Context() - result_decl = gtir_to_oir._create_mask(ctx, mask_name, cond) - result_assign = ctx.horizontal_executions[0] - - assert isinstance(result_decl, oir.Temporary) - assert result_decl.name == mask_name +def test_visit_gtir_Stencil(): + out_name = "out" + in_name = "in" - horizontal_exec = isinstance_and_return(result_assign, oir.HorizontalExecution) - assign = isinstance_and_return(horizontal_exec.body[0], oir.AssignStmt) + testee = StencilFactory( + vertical_loops__0__body__0=ParAssignStmtFactory(left__name=out_name, right__name=in_name) + ) + oir_stencil = GTIRToOIR().visit(testee) + hexecs = oir_stencil.vertical_loops[0].sections[0].horizontal_executions + assert len(hexecs) == 1 + assert len(hexecs[0].body) == 1 + assign = hexecs[0].body[0] left = isinstance_and_return(assign.left, oir.FieldAccess) right = isinstance_and_return(assign.right, oir.FieldAccess) - - assert left.name == mask_name - assert right == cond - - -def test_visit_Assign_VariableKOffset(): - testee = gtir.ParAssignStmt( - left=FieldAccessFactory(), right=FieldAccessFactory(offset=VariableKOffsetFactory()) - ) - ctx = GTIRToOIR.Context() - GTIRToOIR().visit(testee, ctx=ctx) - - assert len(ctx.horizontal_executions) == 1 - assert ctx.horizontal_executions[0].iter_tree().if_isinstance(oir.VariableKOffset).to_list() + assert left.name == out_name + assert right.name == in_name def test_visit_FieldIfStmt(): - testee = FieldIfStmtFactory(false_branch=BlockStmtFactory()) - GTIRToOIR().visit(testee, ctx=GTIRToOIR.Context()) - + testee = FieldIfStmtFactory(true_branch__body__0=ParAssignStmtFactory()) + mask_stmts = GTIRToOIR().visit(testee, ctx=GTIRToOIR.Context()) -def test_visit_FieldIfStmt_no_else(): - testee = FieldIfStmtFactory(false_branch=None) - GTIRToOIR().visit(testee, ctx=GTIRToOIR.Context()) + assert len(mask_stmts) == 2 + assert "mask" in mask_stmts[0].left.name + assert testee.cond.name == mask_stmts[0].right.name + assert mask_stmts[1].body[0].left.name == testee.true_branch.body[0].left.name def test_visit_FieldIfStmt_nesting(): - testee = FieldIfStmtFactory(true_branch=BlockStmtFactory(body=[FieldIfStmtFactory()])) + testee = FieldIfStmtFactory(true_branch__body__0=FieldIfStmtFactory()) GTIRToOIR().visit(testee, ctx=GTIRToOIR.Context()) def test_visit_ScalarIfStmt(): testee = ScalarIfStmtFactory() GTIRToOIR().visit(testee, ctx=GTIRToOIR.Context()) + + +def test_visit_Assign_VariableKOffset(): + testee = ParAssignStmtFactory(right__offset=VariableKOffsetFactory()) + assign_stmt = GTIRToOIR().visit(testee) + assert assign_stmt.iter_tree().if_isinstance(oir.VariableKOffset).to_list() + + +def test_visit_While(): + testee = WhileFactory() + GTIRToOIR().visit(testee)