Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
c4d9475
Initial implementation for numpy and gt_backends
jdahm Apr 14, 2021
3375985
Fix numpy backend
jdahm Apr 14, 2021
5cfb0c1
Add while loop to AST
jdahm Apr 14, 2021
f5e8e4c
Remove bad check on parallel axes
jdahm Apr 14, 2021
aba0b1a
Merge branch 'variable-offsets' into while-loop
jdahm Apr 14, 2021
9aa8da3
Numpy implementation
jdahm Apr 14, 2021
b498f5b
Fix detection of self.block_info.has_variable_koffset
jdahm Apr 14, 2021
2b4fc60
Merge branch 'variable-offsets' into while-loop
jdahm Apr 14, 2021
9939947
Treat case where K offset does not exist
jdahm Apr 14, 2021
e183f15
Merge branch 'variable-offsets' into while-loop
jdahm Apr 14, 2021
2e131cc
Visit offsets
jdahm Apr 20, 2021
9c0a0ff
Some hacks
jdahm Apr 21, 2021
660abe9
Correct numpy backend
jdahm Apr 21, 2021
bebf2fd
Remove prints
jdahm Apr 22, 2021
d6325a2
Fix variable offsets
jdahm Apr 23, 2021
8146fe7
Merge branch 'master' into variable-offsets
jdahm Apr 23, 2021
342e121
Potential fix
jdahm Apr 28, 2021
833387d
Fix parallel axes
jdahm Apr 28, 2021
278e72f
Merge branch 'master' into variable-offsets
jdahm Apr 28, 2021
46a8459
Merge branch 'variable-offsets' into while-loop
jdahm Apr 28, 2021
64ea4ac
Fix parallel axes in numpy backend
jdahm Apr 29, 2021
2798a5d
Debug implementation
jdahm Apr 29, 2021
1107630
Merge branch 'master' into variable-offsets
jdahm Apr 29, 2021
f28473d
Merge branch 'variable-offsets' into while-loop
jdahm Apr 29, 2021
71532d6
Debug implementation and test
jdahm Apr 30, 2021
755f250
Fix code generation in gt_backends
jdahm Apr 30, 2021
2ea91f0
Merge branch 'variable-offsets' into while-loop
jdahm Apr 30, 2021
ae9ac6a
Remove verbose=True that broke test
jdahm May 3, 2021
64688aa
Merge branch 'variable-offsets' into while-loop
jdahm May 3, 2021
a8b1932
Fix numpy codegen
jdahm May 4, 2021
40bfac4
Restore GreedyMerging pass in GtCpp
May 26, 2021
f2b2dd3
Merge remote-tracking branch 'jdahm/while-loop' into while-loops-gtc
May 26, 2021
2cba5d2
Add while loops to GtCpp and CUDA backends
May 26, 2021
68e6417
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Jun 1, 2021
b75b132
Apply formatting
Jun 1, 2021
771ca8b
Propagate 'MaskStmt.is_loop' attribute in MaskStmtMerging pass
Jun 1, 2021
83bf084
Make ternary if
Jun 1, 2021
701fabb
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Jun 2, 2021
b976dc3
Add GTIR and OIR tests
Jun 2, 2021
a9851f7
Revert GtCpp backend changes
Jun 2, 2021
31b3efc
Copy 'is_loop' from OIR to CUIR MaskStmt
Jun 2, 2021
fab0092
Add 'is_loop' check to MaskStmtMerging
Jun 2, 2021
3b2cb99
Add validator for accumulated extents in while
Jun 4, 2021
062d3a9
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Jun 18, 2021
ed3049c
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Jul 15, 2021
d08d3e7
oir.While node replaces MaskStmt.is_loop
jdahm Jul 23, 2021
b02e926
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Jul 26, 2021
d28f7a4
Add While node to CUIR and update codegen
Jul 26, 2021
418857f
Remove remaining 'is_loop' refs
Jul 26, 2021
69f0f1d
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Jul 27, 2021
5b66d8c
Implement while loops in npir
Jul 27, 2021
579412c
Simplify npir_gen.visit_While
Jul 28, 2021
97ec9a3
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Aug 5, 2021
f4a4f2c
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Sep 14, 2021
6d02133
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Oct 25, 2021
2b34e4b
Merge remote-tracking branch 'gridtools/master' into while-loops-gtc
Nov 19, 2021
133d990
Add 'visit_While' to TaskletCodegen for DaCe support
Nov 19, 2021
0d85150
Add tests to known_first_party isort config (#563)
jdahm Nov 25, 2021
be52ee1
Cache origin and domain normalization (#539)
jdahm Dec 2, 2021
25b7023
Clean-up
jdahm Dec 2, 2021
3f59442
Merge branch 'master' into while-loops-gtc
jdahm Dec 2, 2021
3ecf95b
Fix test
jdahm Dec 3, 2021
841e430
Fix gtir_to_oir lowering and cleanup
jdahm Dec 4, 2021
93a914e
Add While support to MaskStmtMerging
jdahm Dec 6, 2021
0d2236a
Add temp declarations for masks
jdahm Dec 6, 2021
6c6bcc2
Correct gtir_to_oir
jdahm Dec 9, 2021
bf16455
Merge branch 'master' into while-loops-gtc
jdahm Dec 9, 2021
9f037ee
Fix NativeFuncCall template
jdahm Dec 15, 2021
3291e06
Minor change to trigger tests again
jdahm Dec 15, 2021
4b729e4
Merge branch 'master' into while-loops-gtc
jdahm Dec 16, 2021
951c043
Fix temps addition for unit tests
jdahm Dec 16, 2021
23035ee
Clean up with ctx
jdahm Dec 17, 2021
3df21bb
Add ctx to a few visit calls to fix tests
jdahm Dec 17, 2021
7119daf
test_variable_offsets_and_while_loop on all backends
jdahm Dec 21, 2021
45647e5
Merge branch 'master' into while-loops-gtc
jdahm Dec 21, 2021
35c86d2
Merge branch 'master' into while-loops-gtc
jdahm Dec 21, 2021
89db5f5
Remove while loop workaround
jdahm Dec 21, 2021
d7250da
Fix npir backend While loops
jdahm Dec 22, 2021
e984970
Disable gtc:dace backend on integration test
jdahm Dec 22, 2021
fba7b62
Merge branch 'master' into while-loops-gtc
jdahm Jan 13, 2022
1575746
Address @havogt initial comments
jdahm Jan 13, 2022
2eac6dc
Fix test
jdahm Jan 13, 2022
40b8e89
Merge branch 'master' into while-loops-gtc
jdahm Feb 7, 2022
3378fc5
Merge #628
jdahm Feb 7, 2022
b919320
Address review comments by @havogt
jdahm Feb 7, 2022
20fd3f7
Add While to new gtc:numpy backend
jdahm Feb 7, 2022
237bf80
Merge branch 'master' into while-loops-gtc
jdahm Feb 11, 2022
0a73bb5
Fix while loop parsing and add codegen test
jdahm Feb 11, 2022
e16534d
Merge branch 'master' into while-loops-gtc
jdahm Feb 18, 2022
267c882
Add While support to gtir_access_kind
jdahm Feb 21, 2022
19885e3
Do not k-cache fields with any var-k-offset read
jdahm Feb 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repos:
- id: check-case-conflict
# - id: check-json
- id: check-merge-conflict
exclude_types: [rst]
- id: check-toml
- id: check-yaml
- id: debug-statements
Expand Down
26 changes: 17 additions & 9 deletions src/gt4py/backend/gtc_backend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import numbers
from typing import Any, Dict, List, Union, cast
from typing import Any, Dict, List, Tuple, Union, cast

from gt4py.ir import IRNodeVisitor
from gt4py.ir.nodes import (
Expand Down Expand Up @@ -45,6 +45,7 @@
UnaryOpExpr,
VarDecl,
VarRef,
While,
)
from gtc import common, gtir
from gtc.common import ExprKind
Expand Down Expand Up @@ -245,15 +246,15 @@ def visit_NativeFuncCall(self, node: NativeFuncCall) -> gtir.NativeFuncCall:
loc=common.location_to_source_location(node.loc),
)

def visit_FieldRef(self, node: FieldRef):
def visit_FieldRef(self, node: FieldRef) -> gtir.FieldAccess:
return gtir.FieldAccess(
name=node.name,
offset=self.transform_offset(node.offset),
data_index=[self.visit(index) for index in node.data_index],
loc=common.location_to_source_location(node.loc),
)

def visit_If(self, node: If):
def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]:
cond = self.visit(node.condition)
if cond.kind == ExprKind.FIELD:
return gtir.FieldIfStmt(
Expand All @@ -274,19 +275,26 @@ def visit_If(self, node: If):
loc=common.location_to_source_location(node.loc),
)

def visit_While(self, node: While) -> gtir.While:
return gtir.While(
cond=self.visit(node.condition),
body=self.visit(node.body),
loc=common.location_to_source_location(node.loc),
)

def visit_VarRef(self, node: VarRef, **kwargs):
return gtir.ScalarAccess(name=node.name, loc=common.location_to_source_location(node.loc))

def visit_AxisInterval(self, node: AxisInterval):
def visit_AxisInterval(self, node: AxisInterval) -> Tuple[gtir.AxisBound, gtir.AxisBound]:
return self.visit(node.start), self.visit(node.end)

def visit_AxisBound(self, node: AxisBound):
def visit_AxisBound(self, node: AxisBound) -> gtir.AxisBound:
# TODO(havogt) add support VarRef
return gtir.AxisBound(
level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level], offset=node.offset
)

def visit_FieldDecl(self, node: FieldDecl):
def visit_FieldDecl(self, node: FieldDecl) -> gtir.FieldDecl:
dimension_names = ["I", "J", "K"]
dimensions = [dim in node.axes for dim in dimension_names]
# datatype conversion works via same ID
Expand All @@ -298,7 +306,7 @@ def visit_FieldDecl(self, node: FieldDecl):
loc=common.location_to_source_location(node.loc),
)

def visit_VarDecl(self, node: VarDecl):
def visit_VarDecl(self, node: VarDecl) -> gtir.ScalarDecl:
# datatype conversion works via same ID
return gtir.ScalarDecl(
name=node.name,
Expand All @@ -308,10 +316,10 @@ def visit_VarDecl(self, node: VarDecl):

def transform_offset(
self, offset: Dict[str, Union[int, Expr]], **kwargs: Any
) -> Union[gtir.CartesianOffset, gtir.VariableKOffset]:
) -> Union[common.CartesianOffset, gtir.VariableKOffset]:
k_val = offset.get("K", 0)
if isinstance(k_val, numbers.Integral):
return gtir.CartesianOffset(i=offset.get("I", 0), j=offset.get("J", 0), k=k_val)
return common.CartesianOffset(i=offset.get("I", 0), j=offset.get("J", 0), k=k_val)
elif isinstance(k_val, Expr):
return gtir.VariableKOffset(k=self.visit(k_val, **kwargs))
else:
Expand Down
51 changes: 31 additions & 20 deletions src/gt4py/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def __init__(
self.extra_temp_decls = extra_temp_decls or {}
self.parsing_context = None
self.iteration_order = None
self.if_decls_stack = []
self.decls_stack = []
gt_ir.NativeFunction.PYTHON_SYMBOL_TO_IR_OP = {
"abs": gt_ir.NativeFunction.ABS,
"min": gt_ir.NativeFunction.MIN,
Expand Down Expand Up @@ -1237,7 +1237,7 @@ def visit_IfExp(self, node: ast.IfExp) -> gt_ir.TernaryOpExpr:
return result

def visit_If(self, node: ast.If) -> list:
self.if_decls_stack.append([])
self.decls_stack.append([])

main_stmts = []
for stmt in node.body:
Expand All @@ -1251,11 +1251,11 @@ def visit_If(self, node: ast.If) -> list:
assert all(isinstance(item, gt_ir.Statement) for item in else_stmts)

result = []
if len(self.if_decls_stack) == 1:
result.extend(self.if_decls_stack.pop())
elif len(self.if_decls_stack) > 1:
self.if_decls_stack[-2].extend(self.if_decls_stack[-1])
self.if_decls_stack.pop()
if len(self.decls_stack) == 1:
result.extend(self.decls_stack.pop())
elif len(self.decls_stack) > 1:
self.decls_stack[-2].extend(self.decls_stack[-1])
self.decls_stack.pop()

result.append(
gt_ir.If(
Expand All @@ -1270,17 +1270,28 @@ def visit_If(self, node: ast.If) -> list:

return result

def visit_While(self, node: ast.While) -> gt_ir.While:
if node.orelse:
raise GTScriptSyntaxError("orelse is not supported on while loops")
stmts = []
for stmt in node.body:
stmts.extend(self.visit(stmt))
return gt_ir.While(
condition=self.visit(node.test),
loc=gt_ir.Location.from_ast_node(node),
body=gt_ir.BlockStmt(stmts=stmts, loc=gt_ir.Location.from_ast_node(node)),
)
def visit_While(self, node: ast.While) -> list:
loc = gt_ir.Location.from_ast_node(node)

self.decls_stack.append([])
stmts = gt_utils.flatten([self.visit(stmt) for stmt in node.body])
assert all(isinstance(item, gt_ir.Statement) for item in stmts)

result = [
gt_ir.While(
condition=self.visit(node.test),
loc=gt_ir.Location.from_ast_node(node),
body=gt_ir.BlockStmt(stmts=stmts, loc=loc),
)
]

if len(self.decls_stack) == 1:
result.extend(self.decls_stack.pop())
elif len(self.decls_stack) > 1:
self.decls_stack[-2].extend(self.decls_stack[-1])
self.decls_stack.pop()

return result

def visit_Call(self, node: ast.Call):
native_fcn = gt_ir.NativeFunction.PYTHON_SYMBOL_TO_IR_OP[node.func.id]
Expand Down Expand Up @@ -1371,8 +1382,8 @@ def visit_Assign(self, node: ast.Assign) -> list:
# layout_id=t.id,
is_api=False,
)
if len(self.if_decls_stack):
self.if_decls_stack[-1].append(field_decl)
if len(self.decls_stack):
self.decls_stack[-1].append(field_decl)
else:
result.append(field_decl)
self.fields[field_decl.name] = field_decl
Expand Down
15 changes: 15 additions & 0 deletions src/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,21 @@ def condition_is_boolean(cls, cond: Expr) -> Expr:
return verify_condition_is_boolean(cls, cond)


class While(GenericNode, Generic[StmtT, ExprT]):
"""
Generic while loop.

Verifies that `cond` is a boolean expr (if `dtype` is set).
"""

cond: ExprT
body: List[StmtT]

@validator("cond")
def condition_is_boolean(cls, cond: Expr) -> Expr:
return verify_condition_is_boolean(cls, cond)


class AssignStmt(GenericNode, Generic[TargetT, ExprT]):
left: TargetT
right: ExprT
Expand Down
12 changes: 11 additions & 1 deletion src/gtc/cuir/cuir.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,17 @@ class KCacheAccess(common.FieldAccess[Expr, VariableKOffset], Expr):
k_cache_is_different_from_field_access = True

@validator("offset")
def zero_ij_offset(cls, v: CartesianOffset) -> CartesianOffset:
def has_no_ij_offset(cls, v: Union[CartesianOffset, VariableKOffset]) -> CartesianOffset:
if not v.i == v.j == 0:
raise ValueError("No ij-offset allowed")
return v

@validator("offset")
def not_variable_offset(cls, v: Union[CartesianOffset, VariableKOffset]) -> CartesianOffset:
if isinstance(v, VariableKOffset):
raise ValueError("Cannot k-cache a variable k offset")
return v

@validator("data_index")
def no_additional_dimensions(cls, v: List[int]) -> List[int]:
if v:
Expand All @@ -92,6 +98,10 @@ class MaskStmt(Stmt):
body: List[Stmt]


class While(common.While[Stmt, Expr], Stmt):
pass


class UnaryOp(common.UnaryOp[Expr], Expr):
pass

Expand Down
8 changes: 8 additions & 0 deletions src/gtc/cuir/cuir_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class CUIRCodegen(codegen.TemplatedGenerator):
"""
)

While = as_mako(
"""
while (${cond}) {
${'\\n'.join(body)}
}
"""
)

def visit_FieldAccess(self, node: cuir.FieldAccess, **kwargs: Any):
def maybe_const(s):
try:
Expand Down
5 changes: 5 additions & 0 deletions src/gtc/cuir/oir_to_cuir.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> cuir.MaskStmt:
mask=self.visit(node.mask, **kwargs), body=self.visit(node.body, **kwargs)
)

def visit_While(self, node: oir.While, **kwargs: Any) -> cuir.While:
return cuir.While(
cond=self.visit(node.cond, **kwargs), body=self.visit(node.body, **kwargs)
)

def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> cuir.Cast:
return cuir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs))

Expand Down
7 changes: 7 additions & 0 deletions src/gtc/dace/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs):
body_code = [indent + b for b in body_code]
return "\n".join([mask_str] + body_code)

def visit_While(self, node: oir.While, **kwargs):
body = self.visit(node.body, **kwargs)
cond = self.visit(node.cond, is_target=False, **kwargs)
indent = " " * 4
delim = f"\n{indent}"
return f"while {cond}:\n{indent}{delim.join(body)}"

class RemoveCastInIndexVisitor(eve.NodeTranslator):
def visit_FieldAccess(self, node: oir.FieldAccess):
if node.data_index:
Expand Down
13 changes: 8 additions & 5 deletions src/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,19 @@ def domain() -> "CartesianIterationSpace":
)

@staticmethod
def from_offset(offset: CartesianOffset) -> "CartesianIterationSpace":
def from_offset(
offset: Union[CartesianOffset, oir.VariableKOffset]
) -> "CartesianIterationSpace":

dict_offsets = offset.to_dict()
return CartesianIterationSpace(
i_interval=oir.Interval(
start=oir.AxisBound.from_start(min(0, offset.i)),
end=oir.AxisBound.from_end(max(0, offset.i)),
start=oir.AxisBound.from_start(min(0, dict_offsets["i"])),
end=oir.AxisBound.from_end(max(0, dict_offsets["i"])),
),
j_interval=oir.Interval(
start=oir.AxisBound.from_start(min(0, offset.j)),
end=oir.AxisBound.from_end(max(0, offset.j)),
start=oir.AxisBound.from_start(min(0, dict_offsets["j"])),
end=oir.AxisBound.from_end(max(0, dict_offsets["j"])),
),
)

Expand Down
4 changes: 4 additions & 0 deletions src/gtc/gtcpp/gtcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class IfStmt(common.IfStmt[Stmt, Expr], Stmt):
pass


class While(common.While[Stmt, Expr], Stmt):
pass


class UnaryOp(common.UnaryOp[Expr], Expr):
pass

Expand Down
2 changes: 2 additions & 0 deletions src/gtc/gtcpp/gtcpp_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def visit_LoopOrder(self, looporder: LoopOrder, **kwargs: Any) -> str:
"""
)

While = as_mako("while(${cond}) {${''.join(body)}}")

BlockStmt = as_mako("{${''.join(body)}}")

def visit_GTComputationCall(
Expand Down
5 changes: 5 additions & 0 deletions src/gtc/gtcpp/oir_to_gtcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> gtcpp.IfStmt:
true_branch=gtcpp.BlockStmt(body=self.visit(node.body, **kwargs)),
)

def visit_While(self, node: oir.While, **kwargs: Any) -> gtcpp.While:
return gtcpp.While(
cond=self.visit(node.cond, **kwargs), body=self.visit(node.body, **kwargs)
)

def visit_HorizontalExecution(
self,
node: oir.HorizontalExecution,
Expand Down
Loading