Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
768e458
First draft
tehrengruber Mar 2, 2025
ac7db53
Remove debugging leftovers
tehrengruber Mar 2, 2025
ff6e23b
Merge branch 'main' into get_domain_builtin
tehrengruber Jul 17, 2025
3731810
Get domain from tuple element
tehrengruber Jul 30, 2025
4bd8a51
Merge origin/main
tehrengruber Jul 30, 2025
0e4eb57
Rename get_domain to get_domain_range
tehrengruber Aug 18, 2025
d916337
Remove compile time args
tehrengruber Aug 18, 2025
25e24e9
Fix format
tehrengruber Aug 18, 2025
e856a18
Fix failing tests
tehrengruber Aug 18, 2025
00a11a6
Fix failing tests
tehrengruber Aug 18, 2025
fd509c8
Merge remote-tracking branch 'origin/main' into get_domain_builtin
tehrengruber Aug 20, 2025
f0bb72f
Extend prototype for multiple output domains
SF-N Aug 22, 2025
8f8e228
Merge branch 'main' into multiple_output_domains
SF-N Aug 22, 2025
d865514
Fix some tests
SF-N Aug 22, 2025
65a831b
Start working on direct fo calls with multiple output domains
SF-N Aug 22, 2025
510ea69
Add tests
SF-N Aug 22, 2025
b910167
Fix embedded domain promotion
SF-N Aug 22, 2025
1066d60
Merge branch 'main' into multiple_output_domains
SF-N Aug 25, 2025
8b02398
Add more tests and extend type deduction
SF-N Aug 29, 2025
de0e1df
Merge branch 'main' into multiple_output_domains
SF-N Aug 29, 2025
5f75aff
Merge tehrengruber:get_domain_builtin
SF-N Aug 29, 2025
1135e00
Merge branch 'multiple_output_domains' of github.com:SF-N/gt4py into …
SF-N Aug 29, 2025
17bf432
Extend for nested tuples
SF-N Sep 1, 2025
87fa841
Cleanup tests
SF-N Sep 1, 2025
54dde30
Extend to also work for out arg that is a tuple
SF-N Sep 2, 2025
f6e6b35
Merge branch 'main' into multiple_output_domains
SF-N Sep 5, 2025
c280514
Add tests with restricted domain and extend to construct domain tuple…
SF-N Sep 5, 2025
3d91e4c
Merge branch 'main' into multiple_output_domains
SF-N Sep 7, 2025
3123d3f
Merge main
SF-N Sep 9, 2025
002b4c8
Clean up
SF-N Sep 9, 2025
1e86519
Extend and refactor to fix tests
SF-N Sep 23, 2025
b86770c
Merge branch 'main' into multiple_output_domains
SF-N Sep 23, 2025
ef9a29e
Fix several tests
SF-N Sep 23, 2025
5c1edae
Enable multiple output domains in direct fo calls and fix some tests
SF-N Sep 23, 2025
e357b89
Refactor and make slices work
SF-N Sep 24, 2025
93290f8
Merge main
SF-N Sep 30, 2025
3a50b60
Remove num_levels from unstructured meshes and reformat
SF-N Sep 30, 2025
5b1bf86
Remove num_levels from MeshDescriptor
SF-N Oct 1, 2025
0440779
Try to refactor Domain vs DomainLike
SF-N Oct 1, 2025
4e72b82
Update tests and address TODO
SF-N Oct 1, 2025
9658d67
Revert unintensional change and pdate tests
SF-N Oct 1, 2025
21b8ca6
Minor
SF-N Oct 2, 2025
98420e6
fix global tmps tuple splitting
havogt Oct 7, 2025
88401fe
nested direct fop call and cleanups
havogt Oct 7, 2025
008a437
Merge remote-tracking branch 'upstream/main' into multiple_output_dom…
havogt Oct 7, 2025
a07294d
cleanup
havogt Oct 7, 2025
0cc0bf4
improve past type deduction
havogt Oct 7, 2025
bfbfd36
fix domain type deduction
havogt Oct 7, 2025
93087d8
cleanup tree_map like operations
havogt Oct 7, 2025
5c771e2
Merge remote-tracking branch 'upstream/main' into multiple_output_dom…
havogt Oct 7, 2025
ddf95d0
refactor past_to_itir
havogt Oct 8, 2025
2b33bc5
Simplify and cleanup past_to_itir
tehrengruber Oct 19, 2025
35a8966
Fix doctest
tehrengruber Oct 19, 2025
afc3aa4
Merge branch 'main' into multiple_output_domains
SF-N Oct 20, 2025
638f9e2
Merge remote-tracking branch 'origin/main' into multiple_output_domains
tehrengruber Oct 22, 2025
e434146
Merge remote-tracking branch 'origin_sf_n/multiple_output_domains' in…
tehrengruber Oct 22, 2025
152e584
Merge branch 'main' into multiple_output_domains
edopao Oct 29, 2025
1888c2d
SDFG lowering of multiple output domains (#16)
edopao Oct 29, 2025
8229f34
address review comments in SDFG lowering
edopao Oct 30, 2025
c65107e
Merge branch 'main' into multiple_output_domains
edopao Oct 31, 2025
9d89093
edit comment
edopao Oct 31, 2025
6eddd52
apply review comments
edopao Oct 31, 2025
899b9a2
Merge branch 'main' into multiple_output_domains
SF-N Oct 31, 2025
40dccd7
remove normalize_domain
havogt Nov 3, 2025
97e0a33
rename
havogt Nov 3, 2025
ecdfcda
Merge branch 'main' into multiple_output_domains
SF-N Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.next import common, errors, field_utils, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.field_utils import get_array_ns
Expand Down Expand Up @@ -108,7 +109,9 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) ->

domain = kwargs.pop("domain", None)

out_domain = common.domain(domain) if domain is not None else _get_out_domain(out)
out_domain = (
utils.tree_map(common.domain)(domain) if domain is not None else _get_out_domain(out)
)

new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain)

Expand All @@ -128,6 +131,7 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) ->
return op(*args, **kwargs)


@utils.tree_map
def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType:
vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL]
assert len(vertical_dim_filtered) <= 1
Expand All @@ -137,17 +141,19 @@ def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.Nothin
def _tuple_assign_field(
target: tuple[common.MutableField | tuple, ...] | common.MutableField,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: common.Domain,
domain: xtyping.MaybeNestedInTuple[common.Domain],
) -> None:
@utils.tree_map
def impl(target: common.MutableField, source: common.Field) -> None:
def impl(target: common.MutableField, source: common.Field, domain: common.Domain) -> None:
if isinstance(source, common.Field):
target[domain] = source[domain]
else:
assert core_defs.is_scalar_type(source)
target[domain] = source

impl(target, source)
if not isinstance(domain, tuple):
domain = utils.tree_map(lambda _: domain)(target)
impl(target, source, domain)


def _intersect_scan_args(
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
if "domain" in kwargs:
domain = common.domain(kwargs.pop("domain"))
out = utils.tree_map(lambda f: f[domain])(out)
domain = utils.tree_map(common.domain)(kwargs.pop("domain"))
if not isinstance(domain, tuple):
domain = utils.tree_map(lambda _: domain)(out)
out = utils.tree_map(lambda f, dom: f[dom])(out, domain)

args, kwargs = type_info.canonicalize_arguments(
self.foast_stage.foast_node.type, args, kwargs
Expand Down
93 changes: 65 additions & 28 deletions src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,53 @@ def _is_integral_scalar(expr: past.Expr) -> bool:
return isinstance(expr.type, ts.ScalarType) and type_info.is_integral(expr.type)


def _validate_domain_out(
dom: past.Dict | past.TupleExpr,
out: ts.TypeSpec,
is_nested: bool = False,
) -> None:
if isinstance(dom, past.Dict):
# Only reject tuple outputs if nested
if is_nested and isinstance(out, ts.TupleType):
raise ValueError("Domain dict cannot map to tuple outputs.")
assert not (is_nested and isinstance(out, past.TupleExpr))

if len(dom.keys_) == 0:
raise ValueError("Empty domain not allowed.")

for dim in dom.keys_:
if not isinstance(dim.type, ts.DimensionType):
raise ValueError(
f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'."
)

for domain_values in dom.values_:
if len(domain_values.elts) != 2:
raise ValueError(
f"Only 2 values allowed in domain range, got {len(domain_values.elts)}."
)
if any(not _is_integral_scalar(el) for el in domain_values.elts):
raise ValueError(
f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'."
)

elif isinstance(dom, past.TupleExpr):
if isinstance(out, ts.TupleType):
out_elts = out.types
else:
raise ValueError(f"Tuple domain requires tuple output, got {type(out)}.")

if len(dom.elts) != len(out_elts):
raise ValueError("Mismatched tuple lengths between domain and output.")

for d, o in zip(dom.elts, out_elts, strict=True):
assert isinstance(d, (past.Dict, past.TupleExpr))
_validate_domain_out(d, o, is_nested=True)

else:
raise ValueError(f"'domain' must be Dict or TupleExpr, got {type(dom)}.")


def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None:
"""
Perform checks for domain and output field types.
Expand All @@ -53,32 +100,11 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None:

if "out" not in new_kwargs:
raise ValueError("Missing required keyword argument 'out'.")
if "domain" in new_kwargs:
if (domain := new_kwargs.get("domain")) is not None:
_ensure_no_sliced_field(new_kwargs["out"])

domain_kwarg = new_kwargs["domain"]
if not isinstance(domain_kwarg, past.Dict):
raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.")

if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0:
raise ValueError("Empty domain not allowed.")

for dim in domain_kwarg.keys_:
if not isinstance(dim.type, ts.DimensionType):
raise ValueError(
f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'."
)
for domain_values in domain_kwarg.values_:
if len(domain_values.elts) != 2:
raise ValueError(
f"Only 2 values allowed in domain range, got {len(domain_values.elts)}."
)
if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar(
domain_values.elts[1]
):
raise ValueError(
f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'."
)
out = new_kwargs["out"]
assert isinstance(out, past.Expr) and out.type is not None
_validate_domain_out(domain, out.type)


class ProgramTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTranslator):
Expand Down Expand Up @@ -131,11 +157,22 @@ def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute
type=getattr(new_value.type, node.attr),
)

def visit_Dict(self, node: past.Dict, **kwargs: Any) -> past.Dict:
# the only supported dict for now is in domain specification
keys = self.visit(node.keys_, **kwargs)
assert all(isinstance(key.type, ts.DimensionType) for key in keys)
return past.Dict(
keys_=keys,
values_=self.visit(node.values_, **kwargs),
location=node.location,
type=ts.DomainType(dims=[key.type.dim for key in keys]),
)

def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr:
elts = self.visit(node.elts, **kwargs)
return past.TupleExpr(
elts=elts, type=ts.TupleType(types=[el.type for el in elts]), location=node.location
)
ttype = ts.TupleType(types=[elt.type for elt in elts])

return past.TupleExpr(elts=elts, type=ttype, location=node.location)

def _deduce_binop_type(
self, node: past.BinOp, *, left: past.Expr, right: past.Expr, **kwargs: Any
Expand Down
Loading