Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
193 commits
Select commit Hold shift + click to select a range
d6e8732
Add concat_where frontend and domain inference
SF-N Oct 25, 2024
69f6b11
Finish domain inference for (nested) concat_where and transform to as…
SF-N Oct 25, 2024
05e74c2
fix merge conflicts
havogt Jan 20, 2025
c3a18c4
Merge origin/main
tehrengruber Jan 29, 2025
ba8343b
Extend concat_where, now also working for nested concat_wheres and ex…
SF-N Jan 30, 2025
f90329e
Some fixes, tuples still not supported
SF-N Jan 31, 2025
401d9dd
Merge branch 'main' into GTIR_concat_where
SF-N Jan 31, 2025
b49a82d
Some updates for concat where, which were necessary when using it in …
SF-N Feb 5, 2025
2219314
Merge branch 'main' into GTIR_concat_where
SF-N Feb 5, 2025
9eb428a
Merge origin/main
tehrengruber Feb 14, 2025
d16bbd5
ITIR type inference: store param type in Lambda
tehrengruber Feb 15, 2025
aca4824
Merge branch 'main' into store_lambda_param_type
tehrengruber Feb 17, 2025
813f328
Flatten as_fieldop tuple arguments
tehrengruber Feb 18, 2025
3745461
Add support for scan and nested tuples
tehrengruber Feb 19, 2025
1f23e17
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 19, 2025
8bec9ab
Merge branch 'store_lambda_param_type' into GTIR_concat_where
tehrengruber Feb 19, 2025
06806fb
Preserve annex on new nodes
tehrengruber Feb 19, 2025
bab4fe1
Fix unnecessary import
tehrengruber Feb 19, 2025
6257a2b
Merge branch 'eve_annex_preserve_new_node' into GTIR_concat_where
tehrengruber Feb 19, 2025
14b4bf3
Cleanup
tehrengruber Feb 19, 2025
fc20d7c
Fix doctest
tehrengruber Feb 19, 2025
c5fba83
Fix failing tests
tehrengruber Feb 19, 2025
fa17228
Merge branch 'store_lambda_param_type' into collapse_tuple_as_fieldop…
tehrengruber Feb 19, 2025
04ae430
Fix tests
tehrengruber Feb 19, 2025
5136adc
Fix tests
tehrengruber Feb 19, 2025
5939618
Cleanup frontend type deduction
tehrengruber Feb 19, 2025
157b0e2
Cleanup frontend type deduction
tehrengruber Feb 19, 2025
435d057
Cleanup concat where:
tehrengruber Feb 20, 2025
5e5c66e
Merge branch 'eve_annex_preserve_new_node' into GTIR_concat_where
tehrengruber Feb 20, 2025
bd8dbaa
Fix iterator tests
tehrengruber Feb 20, 2025
2c14648
Fix infer domain ops
tehrengruber Feb 20, 2025
a7f3cac
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
1200803
Cleanup
tehrengruber Feb 20, 2025
cf0ffb2
Fix format
tehrengruber Feb 20, 2025
335e932
Fix broken scan (e.g. test_tuple_scalar_scan)
tehrengruber Feb 20, 2025
7518b9c
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
39652de
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Feb 20, 2025
ba03c7e
Merge remote-tracking branch 'origin/main' into collapse_tuple_as_fie…
tehrengruber Feb 20, 2025
71980af
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
c18b7ad
Fix failing tests
tehrengruber Feb 20, 2025
d399c65
Fix format
tehrengruber Feb 20, 2025
5ad7701
Fix failing tests
tehrengruber Feb 20, 2025
d3957bd
Fix format
tehrengruber Feb 20, 2025
e95fdf0
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
b52a07c
Cleanup
tehrengruber Feb 20, 2025
f8703b2
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
c5c3e5f
Fix pyproject.toml test marker
tehrengruber Feb 20, 2025
f59fabf
Remove unnecessary visits
tehrengruber Feb 20, 2025
c8e06bd
Cleanup trace shifts
tehrengruber Feb 20, 2025
f748da7
Fix type inference
tehrengruber Feb 20, 2025
45f8b09
Add concat_where transforms to field view transforms
tehrengruber Feb 20, 2025
b3647bf
Fix typo
tehrengruber Feb 20, 2025
6ea11e5
Add support for tuples
tehrengruber Feb 20, 2025
60d0d9a
Fixes
tehrengruber Feb 20, 2025
93a6d33
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Feb 21, 2025
132e576
Improve docs
tehrengruber Feb 21, 2025
e469075
Improve docs
tehrengruber Feb 21, 2025
24e2f57
Fix typo
tehrengruber Feb 21, 2025
d14fb21
Cleanup & improve test coverage
tehrengruber Feb 24, 2025
1e3ced5
Cleanup
tehrengruber Feb 24, 2025
595b675
Cleanup
tehrengruber Feb 24, 2025
59a1226
Improve type inference for concat_where tuple case
tehrengruber Feb 28, 2025
f832a19
Fix typo
tehrengruber Feb 28, 2025
75cc4f2
Fix bug in infer domain ops
tehrengruber Mar 2, 2025
6e85bd0
Address review comments
tehrengruber Mar 2, 2025
a8b9736
Merge remote-tracking branch 'origin_tehrengruber/store_lambda_param_…
tehrengruber Mar 2, 2025
9978a43
Address review comments
tehrengruber Mar 2, 2025
232d4b8
Address review comments
tehrengruber Mar 2, 2025
57abfaf
Merge remote-tracking branch 'origin/main' into store_lambda_param_type
tehrengruber Mar 2, 2025
f488b1a
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Mar 3, 2025
55dc611
Merge branch 'store_lambda_param_type' into GTIR_concat_where
tehrengruber Mar 3, 2025
2674f11
Merge origin/main
tehrengruber Mar 3, 2025
d0f93be
Fix deferred type in concat_where
tehrengruber Mar 3, 2025
cf50a37
Fix tuple concat_where (not fully done yet)
tehrengruber Mar 3, 2025
5fc42ce
Fix tuple concat_where (not fully done yet)
tehrengruber Mar 3, 2025
77edc98
Unclean fixes (revert tuple lowering)
tehrengruber Mar 11, 2025
1a4bf3a
Enable laplacian test
tehrengruber Mar 14, 2025
ac0625f
Merge origin/main
tehrengruber Mar 21, 2025
1af561e
Merge origin/main
tehrengruber Mar 21, 2025
1ab8c69
embedded concat_where
havogt Mar 21, 2025
73fba27
Merge branch 'GTIR_concat_where' of github.com:SF-N/gt4py into GTIR_c…
havogt Mar 21, 2025
ae07826
add support for more comparison operators
havogt Mar 21, 2025
a8fe04e
change Dimension comparison
havogt Mar 22, 2025
40cf33b
embedded: non-python int comparison
havogt Mar 23, 2025
8c4fc45
Fix import
tehrengruber Apr 14, 2025
ac6fbb4
Merge origin/main
tehrengruber Apr 14, 2025
2731432
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Apr 15, 2025
bf7ae21
Merge remote-tracking branch 'upstream/main' into GTIR_concat_where
havogt Apr 17, 2025
b8b80f8
Merge origin/main
tehrengruber Apr 24, 2025
1f49b71
Merge remote-tracking branch 'origin_sf_n/GTIR_concat_where' into GTI…
tehrengruber Apr 24, 2025
a41eb1a
feat[next]: GTIR concat_where frontend
havogt Apr 24, 2025
e2c053c
disable concat_where tests
havogt Apr 24, 2025
4b46fcd
one more it_ts.DomainType
havogt Apr 24, 2025
d77a4c0
add test for concat_where on scalars and fix typing
havogt Apr 24, 2025
f41c112
add test for chained comparison
havogt Apr 25, 2025
de287c2
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber May 1, 2025
3a46c71
Merge branch 'main' into GTIR_concat_where
tehrengruber May 2, 2025
13baa21
Merge branch 'main' into GTIR_concat_where
tehrengruber May 7, 2025
4d41d86
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber May 9, 2025
c88adf7
Merge origin/main
tehrengruber May 18, 2025
359a921
Fix broken merge
tehrengruber May 18, 2025
4764c2b
Simplify tuple lowering, unit tests, cleanup
tehrengruber May 18, 2025
62db9ff
Small fix
tehrengruber May 18, 2025
7053c39
Cleanup
tehrengruber May 18, 2025
d597a4d
Cleanup
tehrengruber May 18, 2025
45ccbbc
Cleanup
tehrengruber May 18, 2025
0326e80
Add more unit tests
tehrengruber May 20, 2025
d069b67
Merge branch 'main' into GTIR_concat_where
edopao May 20, 2025
1ce4ed4
Cleanup
tehrengruber May 22, 2025
62688b2
Cleanup
tehrengruber May 23, 2025
afdd60c
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber May 23, 2025
eb9adf9
Cleanup
tehrengruber May 23, 2025
dc98855
Cleanup
tehrengruber May 23, 2025
b4e5fd1
Cleanup
tehrengruber May 23, 2025
e1b9d88
Merge commit 'a41eb1a' into GTIR_concat_where
tehrengruber May 23, 2025
dfd2b14
Merge remote-tracking branch 'origin/main' into concat_where_frontend
tehrengruber May 23, 2025
ec6b305
Merge branch 'concat_where_frontend' into GTIR_concat_where (#1998)
tehrengruber May 23, 2025
0bd26ce
Cleanup
tehrengruber May 23, 2025
e5dbf4a
Cleanup
tehrengruber May 23, 2025
0e8faad
Fix dace
tehrengruber May 23, 2025
77b8efe
Merge branch 'main' into GTIR_concat_where
edopao May 26, 2025
c497022
Merge branch 'main' into GTIR_concat_where
edopao May 27, 2025
cfe389d
Merge branch 'main' into GTIR_concat_where
edopao May 28, 2025
599d8f1
Merge remote-tracking branch 'upstream/main' into concat_where_frontend
havogt Jun 3, 2025
2499389
remove unchain comparison (because doesn't make sense)
havogt Jun 4, 2025
398ec68
improve error messages
havogt Jun 4, 2025
f81393a
fix chain test
havogt Jun 4, 2025
eae7dc7
simplify typing
havogt Jun 4, 2025
16e1c65
rename
havogt Jun 4, 2025
5f7e251
add promotion tests
havogt Jun 5, 2025
b1e8f89
Fix small type inference bug
tehrengruber Jun 4, 2025
35c026e
Merge branch 'main' into GTIR_concat_where
tehrengruber Jun 5, 2025
fb01638
Merge branch 'concat_where_frontend' into GTIR_concat_where
tehrengruber Jun 5, 2025
06905b8
Merge branch 'main' into GTIR_concat_where
tehrengruber Jun 5, 2025
a66e5ca
Merge remote-tracking branch 'origin/main' into concat_where_frontend
tehrengruber Jun 5, 2025
d89cff6
Backport fixes from main PR
tehrengruber Jun 5, 2025
6f9ebff
Merge branch 'concat_where_frontend' into GTIR_concat_where
tehrengruber Jun 5, 2025
3dac495
Cleanup
tehrengruber Jun 5, 2025
506c2b5
Extract concat_where transformations
tehrengruber Jun 5, 2025
8ae99ae
Merge branch 'gtir_concat_where_passes' into GTIR_concat_where
tehrengruber Jun 5, 2025
af36bc9
Small fix
tehrengruber Jun 5, 2025
f1a99bd
Format
tehrengruber Jun 5, 2025
1f6b284
Format
tehrengruber Jun 5, 2025
45a2e23
Cleanup
tehrengruber Jun 5, 2025
721dde3
Merge branch 'gtir_concat_where_passes' into GTIR_concat_where
tehrengruber Jun 5, 2025
52c96ed
Cleanup
tehrengruber Jun 6, 2025
034c660
Cleanup
tehrengruber Jun 6, 2025
c460459
Cleanup
tehrengruber Jun 6, 2025
bbf0016
Cleanup
tehrengruber Jun 6, 2025
cbeee8e
Cleanup
tehrengruber Jun 6, 2025
9d179c6
Fix infer domain ops
tehrengruber Jun 6, 2025
724b0fe
Merge branch 'main' into gtir_concat_where_passes
havogt Jun 6, 2025
a421f79
Fix failing doctest
tehrengruber Jun 6, 2025
56dcbc4
Merge remote-tracking branch 'origin_tehrengruber/gtir_concat_where_p…
tehrengruber Jun 6, 2025
f770961
Merge branch 'main' into GTIR_concat_where
edopao Jun 11, 2025
aadf582
remove uses_concat_where from COMMON_SKIP_TEST_LIST
edopao Jun 11, 2025
0957ca9
Merge branch 'main' into GTIR_concat_where
edopao Jun 12, 2025
b056955
Merge branch 'main' into GTIR_concat_where
edopao Jun 13, 2025
fb074fe
Merge branch 'main' into GTIR_concat_where
edopao Jun 23, 2025
fc2df23
Merge branch 'main' into GTIR_concat_where
edopao Jun 25, 2025
2bd9d3b
add test cases for empty branches
edopao Jun 25, 2025
3c78ec1
Merge branch 'main' into GTIR_concat_where
edopao Jun 25, 2025
78a61ca
extend test case scalar_broadcast_on_empty_branch
edopao Jun 27, 2025
a9f146e
Merge branch 'main' into GTIR_concat_where
edopao Jun 27, 2025
31da410
pre-commit - format code
edopao Jun 27, 2025
71f46f3
Merge branch 'main' into GTIR_concat_where
edopao Jul 4, 2025
8c642b8
address review comments
havogt Jul 8, 2025
7877f6d
add domain_utils tests
havogt Jul 9, 2025
086910c
refactor domain ops
havogt Jul 9, 2025
73b333d
Merge remote-tracking branch 'upstream/main' into gtir_concat_where_p…
havogt Jul 9, 2025
58a0492
fix formatting
havogt Jul 9, 2025
81b9309
add type inference test
havogt Jul 9, 2025
a176174
cleanup
havogt Jul 9, 2025
6a1087d
delete an obsolete assert
havogt Jul 9, 2025
a3bdfed
Merge branch 'gtir_concat_where_passes' into GTIR_concat_where
havogt Jul 9, 2025
3678a38
remove embedded implementation
havogt Jul 9, 2025
94ef41a
address review comments
havogt Jul 9, 2025
8c0b1fc
Merge branch 'gtir_concat_where_passes' into GTIR_concat_where
havogt Jul 9, 2025
8a9a172
Merge remote-tracking branch 'upstream/main' into GTIR_concat_where
havogt Jul 9, 2025
de19ee9
fix merge conflict
havogt Jul 10, 2025
0feead0
document some tests
havogt Jul 10, 2025
7ba38d5
fix test structure in constant folding
havogt Jul 10, 2025
4722c05
remove resolved todos
havogt Jul 10, 2025
06a70e3
refactorings
havogt Jul 10, 2025
d3bcc56
Merge remote-tracking branch 'upstream/main' into GTIR_concat_where
havogt Jul 10, 2025
9a116c1
cleanup test
havogt Jul 10, 2025
8635b2c
fix[next]: symbol clash in inline_lambda
havogt Jul 10, 2025
ac442f5
Merge branch 'fix_symbol_clash_inline_lambda' into GTIR_concat_where
havogt Jul 10, 2025
c72f494
cleanup
havogt Jul 11, 2025
8fa1a75
Merge branch 'fix_symbol_clash_inline_lambda' into GTIR_concat_where
havogt Jul 11, 2025
9fe9c5c
address review comments
havogt Jul 11, 2025
8f71edf
cleanup todo
havogt Jul 11, 2025
47be113
Merge remote-tracking branch 'upstream/main' into GTIR_concat_where
havogt Jul 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ def create_if(

return im.let(cond_symref_name, cond_)(result)

_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
domain, true_branch, false_branch = self.visit(node.args, **kwargs)
return im.concat_where(domain, true_branch, false_branch)

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return im.call("broadcast")(*self.visit(node.args, **kwargs))
Expand Down Expand Up @@ -488,7 +490,7 @@ def _map(
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
if all(
isinstance(t, ts.ScalarType)
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,11 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()


@builtin_dispatch
def concat_where(*args):
raise BackendNotSelectedError()


UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
Expand Down Expand Up @@ -494,6 +499,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"scan",
"tuple_get",
"unstructured_domain",
"concat_where",
*ARITHMETIC_BUILTINS,
*TYPE_BUILTINS,
}
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,11 @@ def index(axis: common.Dimension) -> common.Field:
return IndexField(axis)


@builtins.concat_where.register(EMBEDDED)
def concat_where(*args):
raise NotImplementedError("To be implemented in frontend embedded.")


def closure(
domain_: runtime.CartesianDomain | runtime.UnstructuredDomain,
sten: Callable[..., Any],
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _with_altered_iterator_position_dims(
)


def _is_trivial_make_tuple_call(node: ir.Expr):
def _is_trivial_make_tuple_call(node: itir.Expr):
"""Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof."""
if not cpm.is_call_to(node, "make_tuple"):
return False
Expand Down Expand Up @@ -307,9 +307,10 @@ def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optiona
self.fp_transform(im.tuple_get(idx.value, expr.fun.expr), **kwargs)
)
)(*expr.args)
elif cpm.is_call_to(expr, "if_"):
elif cpm.is_call_to(expr, ("if_", "concat_where")):
fun = expr.fun
cond, true_branch, false_branch = expr.args
return im.if_(
return im.call(fun)(
cond,
self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs),
self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr:
"""
Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain.

`in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩`
-> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1`
pos = `{i, j, k}`, domain = `u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩`
-> `((i0 <= i) & (i < i1)) & ((j0 <= j) & (j < j1)) & ((k0 <= k)l & (k < k1))`
"""
ret = [
im.and_(
Expand Down
59 changes: 59 additions & 0 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class Transformation(enum.Flag):
# `if_(True, true_branch, false_branch)` -> `true_branch`
FOLD_IF = enum.auto()

FOLD_INFINITY_ARITHMETIC = enum.auto()

@classmethod
def all(self) -> ConstantFolding.Transformation:
return functools.reduce(operator.or_, self.__members__.values())
Expand Down Expand Up @@ -239,3 +241,60 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
assert node.args[0].value == "False"
return node.args[2]
return None

def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.Node]:
if cpm.is_call_to(node, "plus"):
# `a + +/-inf` -> `+/-inf`
a, b = node.args
assert not (isinstance(a, ir.InfinityLiteral) and isinstance(b, ir.InfinityLiteral))
for arg in a, b:
if isinstance(arg, ir.InfinityLiteral):
return arg

if cpm.is_call_to(node, "minimum"):
if ir.InfinityLiteral.NEGATIVE in node.args:
# `minimum(-inf, a)` -> `-inf`
return ir.InfinityLiteral.NEGATIVE
if ir.InfinityLiteral.POSITIVE in node.args:
# `minimum(inf, a)` -> `a`
a, b = node.args
return b if a == ir.InfinityLiteral.POSITIVE else a

if cpm.is_call_to(node, "maximum"):
if ir.InfinityLiteral.POSITIVE in node.args:
# `maximum(inf, a)` -> `inf`
return ir.InfinityLiteral.POSITIVE
if ir.InfinityLiteral.NEGATIVE in node.args:
# `maximum(-inf, a)` -> `a`
a, b = node.args
return b if a == ir.InfinityLiteral.NEGATIVE else a

if cpm.is_call_to(node, ("less", "less_equal")):
a, b = node.args
# we don't handle `inf < inf` or `-inf < -inf`.args
assert a != b or not isinstance(a, ir.InfinityLiteral)

# `-inf < v` -> `True`
# `v < inf` -> `True`
if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE:
return im.literal_from_value(True)
# `inf < v` -> `False`
# `v < -inf ` -> `False`
if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE:
return im.literal_from_value(False)

if cpm.is_call_to(node, ("greater", "greater_equal")):
a, b = node.args
# we don't handle `inf > inf` or `-inf > -inf`.args
assert a != b or not isinstance(a, ir.InfinityLiteral)

# `inf > v` -> `True`
# `v > -inf ` -> `True`
if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE:
return im.literal_from_value(True)
# `-inf > v` -> `False`
# `v > inf` -> `False`
if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE:
return im.literal_from_value(False)

return None
12 changes: 11 additions & 1 deletion src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,19 @@ def _is_collectable_expr(node: itir.Node) -> bool:
# backend (single pass eager depth first visit approach)
# do also not collect lifts or applied lifts as they become invisible to the lift inliner
# otherwise
if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node):
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(
node, ("lift", "shift", "reduce", "map_", "index")
) or cpm.is_applied_lift(node):
return False
return True
# do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(node, "make_tuple") and all(
cpm.is_call_to(arg, "index") for arg in node.args
):
return False
elif isinstance(node, itir.Lambda):
return True

Expand Down
3 changes: 0 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,4 @@ def visit(self, node, **kwargs):

node = super().visit(node, **kwargs)

if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"):
node.annex.domain = node.annex.domain

return node
11 changes: 8 additions & 3 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,16 @@ def create_global_tmps(
This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its
arguments into temporaries.
"""
offset_provider_type = common.offset_provider_to_type(offset_provider)
# TODO(tehrengruber): document why to keep existing domains and add test
program = infer_domain.infer_program(
program, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes
program,
offset_provider=offset_provider,
symbolic_domain_sizes=symbolic_domain_sizes,
keep_existing_domains=True,
)
program = type_inference.infer(
program, offset_provider_type=common.offset_provider_to_type(offset_provider)
)
program = type_inference.infer(program, offset_provider_type=offset_provider_type)

if not uids:
uids = eve_utils.UIDGenerator(prefix="__tmp")
Expand Down
13 changes: 13 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import (
concat_where,
dead_code_elimination,
fuse_as_fieldop,
global_tmps,
infer_domain,
infer_domain_ops,
inline_dynamic_shifts,
inline_fundefs,
inline_lifts,
Expand Down Expand Up @@ -81,13 +83,19 @@ def apply_common_transforms(
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
ir
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)

ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
ir = infer_domain.infer_program(
ir,
offset_provider=offset_provider,
symbolic_domain_sizes=symbolic_domain_sizes,
)
ir = remove_broadcast.RemoveBroadcast.apply(ir)

ir = concat_where.transform_to_as_fieldop(ir)

for _ in range(10):
inlined = ir

Expand Down Expand Up @@ -183,6 +191,11 @@ def apply_fieldview_transforms(
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
ir
) # domain inference does not support dynamic offsets yet

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
ir = remove_broadcast.RemoveBroadcast.apply(ir)
return ir
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ class TraceShifts(PreserveLocationVisitor, NodeTranslator):
def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any:
return Sentinel.VALUE

def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, ctx: dict[str, Any]):
return Sentinel.VALUE

def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any:
if node.id in ctx:
return ctx[node.id]
Expand Down
6 changes: 5 additions & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def _values_validator(
) -> None:
if not all(
isinstance(el, (SidFromScalar, SidComposite))
or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el)
or _is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, Literal))
or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")),
el,
)
for el in value
):
raise ValueError(
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ class Params:

run_gtfn_gpu = GTFNBackendFactory(gpu=True)

run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True)
run_gtfn_gpu_cached = GTFNBackendFactory(
gpu=True, cached=True, otf_workflow__cached_translation=True
)

run_gtfn_no_transforms = GTFNBackendFactory(
otf_workflow__bare_translation__enable_itir_transforms=False
Expand Down
7 changes: 5 additions & 2 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
(USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
# Markers to skip because of missing features in the domain inference
DOMAIN_INFERENCE_SKIP_LIST = [
Expand All @@ -161,6 +160,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE),
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
)
EMBEDDED_SKIP_LIST = [
Expand All @@ -179,6 +179,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = (
COMMON_SKIP_TEST_LIST
+ DOMAIN_INFERENCE_SKIP_LIST
Expand Down Expand Up @@ -219,5 +222,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(ALL, XFAIL, UNSUPPORTED_MESSAGE),
(USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE),
],
ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST,
ProgramBackendId.GTIR_EMBEDDED: GTIR_EMBEDDED_SKIP_LIST,
}
Loading