feat[next]: iterator.embedded with new Field implementation (#1308)
Makes iterator.embedded work with the new `common.Field` implementation.

- Temporarily, `iterator.embedded.np_as_located_field` returns a `common.Field`. As a next step, all users should switch to a different allocation function.
- `iterator.embedded` wraps the `common.Field` in a `LocatedField` wrapper that translates between the `common.Field` interface and the `field_getitem`/`field_setitem` accessors used by `iterator.embedded` (a rough sketch of this idea follows after this list). This layer should eventually be removed.
- Adds minimal implementations of `IndexField` and `ConstantField` (covering only the requirements of `iterator.embedded`) to `iterator.embedded`. These can later be generalized into proper `Field`s suitable for embedded field view execution.
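
To make the second bullet concrete, here is a rough sketch of the adapter idea, assuming hypothetical names and a simplified interface; it is illustrative only, not the actual gt4py implementation:

```python
from typing import Any, Sequence


class FieldToLocatedFieldAdapter:
    """Hypothetical adapter (illustration only, not the gt4py code).

    Holds an object satisfying the ``common.Field`` protocol (i.e. it exposes
    ``__gt_dims__`` and an ``ndarray`` buffer) and forwards the positional
    ``field_getitem``/``field_setitem`` access used by ``iterator.embedded``
    to that buffer.
    """

    def __init__(self, field: Any) -> None:
        self._field = field

    @property
    def __gt_dims__(self) -> Sequence[Any]:
        return self._field.__gt_dims__

    def field_getitem(self, position: tuple[int, ...]) -> Any:
        # Translate iterator-level positional access into a buffer lookup.
        return self._field.ndarray[position]

    def field_setitem(self, position: tuple[int, ...], value: Any) -> None:
        self._field.ndarray[position] = value
```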
havogt authored Sep 14, 2023
1 parent b64fdab commit 5157007
Showing 25 changed files with 556 additions and 566 deletions.
562 changes: 344 additions & 218 deletions src/gt4py/next/iterator/embedded.py

Large diffs are not rendered by default.
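
The 562-line diff to `embedded.py` is not shown here; as an illustration of the third bullet of the commit description, a minimal sketch of what `IndexField`/`ConstantField` could look like follows. The call-based access and everything besides the two class names are hypothetical simplifications, not the actual gt4py code:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ConstantField:
    """Yields the same value at every position (sketch only)."""

    value: Any

    def __call__(self, *_position: int) -> Any:
        return self.value


@dataclass(frozen=True)
class IndexField:
    """Yields the index itself along its dimension (sketch only)."""

    dimension: Any  # assumed: a gt4py-style Dimension

    def __call__(self, index: int) -> int:
        return index
```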

6 changes: 2 additions & 4 deletions src/gt4py/next/iterator/tracing.py
@@ -21,7 +21,6 @@
from gt4py.eve import Node
from gt4py.next import common, iterator
from gt4py.next.iterator import builtins, ir_makers as im
from gt4py.next.iterator.embedded import LocatedField
from gt4py.next.iterator.ir import (
AxisLiteral,
Expr,
@@ -253,9 +252,8 @@ def _contains_tuple_dtype_field(arg):
# various implementations have different behaviour (some return e.g. `np.dtype("int32")`
# other `np.int32`). We just ignore the error here and postpone fixing this to when
# the new storages land (The implementation here works for LocatedFieldImpl).
return isinstance(arg, LocatedField) and (
arg.dtype.fields is not None or any(dim is None for dim in arg.__gt_dims__)
)

return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__)


def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
@@ -18,7 +18,8 @@
import numpy as np

import gt4py.next.iterator.ir as itir
from gt4py.next.iterator.embedded import LocatedField, NeighborTableOffsetProvider
from gt4py.next import common
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
@@ -29,11 +30,12 @@


def convert_arg(arg: Any):
if isinstance(arg, LocatedField):
if common.is_field(arg):
sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value)
ndim = len(sorted_dims)
dim_indices = [dim[0] for dim in sorted_dims]
return np.moveaxis(np.asarray(arg), range(ndim), dim_indices)
assert isinstance(arg.ndarray, np.ndarray)
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
return arg


7 changes: 3 additions & 4 deletions src/gt4py/next/program_processors/runners/gtfn_cpu.py
@@ -14,7 +14,6 @@

from typing import Any

import numpy as np
import numpy.typing as npt

from gt4py.eve.utils import content_hash
@@ -32,9 +31,9 @@
def convert_arg(arg: Any) -> Any:
if isinstance(arg, tuple):
return tuple(convert_arg(a) for a in arg)
if hasattr(arg, "__array__") and hasattr(arg, "__gt_dims__"):
arr = np.asarray(arg)
origin = getattr(arg, "__gt_origin__", tuple([0] * arr.ndim))
if common.is_field(arg):
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin
else:
return arg
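
A hedged illustration of the new `convert_arg` behaviour, assuming the function above is in scope (the field case is shown only as a comment because it needs a real `common.Field` instance):

```python
# Non-field arguments pass through unchanged; tuples are converted element-wise.
assert convert_arg(3.14) == 3.14
assert convert_arg((1, 2)) == (1, 2)
# For a field `f`, the result is a (buffer, origin) pair, roughly:
#   (f.ndarray, getattr(f, "__gt_origin__", (0,) * len(f.domain)))
```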
21 changes: 10 additions & 11 deletions src/gt4py/next/program_processors/runners/roundtrip.py
@@ -30,6 +30,13 @@
from gt4py.next.program_processors.processor_interface import program_executor


def _create_tmp(axes, origin, shape, dtype):
if isinstance(dtype, tuple):
return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype)},)"
else:
return f"gtx.np_as_located_field({axes}, origin={origin})(np.empty({shape}, dtype=np.dtype('{dtype}')))"


class EmbeddedDSL(codegen.TemplatedGenerator):
Sym = as_fmt("{id}")
SymRef = as_fmt("{id}")
@@ -60,14 +67,7 @@ def ${id}(${','.join(params)}):
def visit_FencilWithTemporaries(self, node, **kwargs):
params = self.visit(node.params)

def np_dtype(dtype):
if isinstance(dtype, int):
return params[dtype] + ".dtype"
if isinstance(dtype, tuple):
return "np.dtype([" + ", ".join(f"('', {np_dtype(d)})" for d in dtype) + "])"
return f"np.dtype('{dtype}')"

tmps = "\n ".join(self.visit(node.tmps, np_dtype=np_dtype))
tmps = "\n ".join(self.visit(node.tmps))
args = ", ".join(params + [tmp.id for tmp in node.tmps])
params = ", ".join(params)
fencil = self.visit(node.fencil)
@@ -79,7 +79,7 @@ def np_dtype(dtype):
+ f"\n {node.fencil.id}({args}, **kwargs)\n"
)

def visit_Temporary(self, node, *, np_dtype, **kwargs):
def visit_Temporary(self, node, **kwargs):
assert isinstance(node.domain, itir.FunCall) and node.domain.fun.id in (
"cartesian_domain",
"unstructured_domain",
@@ -92,8 +92,7 @@ def visit_Temporary(self, node, *, np_dtype, **kwargs):
axes = ", ".join(label for label, _, _ in domain_ranges)
origin = "{" + ", ".join(f"{label}: -{start}" for label, start, _ in domain_ranges) + "}"
shape = "(" + ", ".join(f"{stop}-{start}" for _, start, stop in domain_ranges) + ")"
dtype = np_dtype(node.dtype)
return f"{node.id} = gtx.np_as_located_field({axes}, origin={origin})(np.empty({shape}, dtype={dtype}))"
return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}"


_BACKEND_NAME = "roundtrip"
6 changes: 2 additions & 4 deletions src/gt4py/next/type_system/type_translation.py
@@ -161,8 +161,6 @@ def from_type_hint(
def from_value(value: Any) -> ts.TypeSpec:
# TODO(tehrengruber): use protocol from gt4py.next.common when available
# instead of importing from the embedded implementation
from gt4py.next.iterator.embedded import LocatedField

"""Make a symbol node from a Python value."""
# TODO(tehrengruber): What we expect here currently is a GTCallable. Maybe
# we should check for the protocol in the future?
@@ -185,9 +183,9 @@ def from_value(value: Any) -> ts.TypeSpec:
return candidate_type
elif isinstance(value, common.Dimension):
symbol_type = ts.DimensionType(dim=value)
elif isinstance(value, LocatedField):
elif common.is_field(value):
dims = list(value.__gt_dims__)
dtype = from_type_hint(value.dtype.type)
dtype = from_type_hint(value.dtype.scalar_type)
symbol_type = ts.FieldType(dims=dims, dtype=dtype)
elif isinstance(value, tuple):
# Since the elements of the tuple might be one of the special cases
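A hedged usage sketch of the new protocol-based branch in `from_value`; field construction mirrors the test code elsewhere in this commit, while the top-level re-export of `Dimension` and the printed type shown in the comment are assumptions:

```python
import numpy as np

import gt4py.next as gtx
from gt4py.next.type_system import type_translation

IDim = gtx.Dimension("IDim")  # assumed: Dimension is re-exported at the gtx top level
field = gtx.np_as_located_field(IDim)(np.zeros(10, dtype=np.float64))

spec = type_translation.from_value(field)
# Expected (roughly): FieldType(dims=[IDim], dtype=ScalarType(kind=ScalarKind.FLOAT64))
print(spec)
```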
15 changes: 8 additions & 7 deletions tests/next_tests/integration_tests/cases.py
@@ -29,7 +29,6 @@
from gt4py.eve.extended_typing import Self
from gt4py.next import common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import embedded
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_specifications as ts, type_translation

@@ -75,7 +74,7 @@
C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim))

ScalarValue: TypeAlias = np.int32 | np.int64 | np.float32 | np.float64 | np.generic
FieldValue: TypeAlias = gtx.Field | embedded.LocatedFieldImpl
FieldValue: TypeAlias = gtx.Field
FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...]
FieldViewInout: TypeAlias = FieldValue | tuple["FieldViewInout", ...]
ReferenceValue: TypeAlias = (
@@ -341,7 +340,9 @@ def allocate(
Useful for shifted fields, which must start off bigger
than the output field in the shifted dimension.
"""
sizes = extend_sizes(case.default_sizes | (sizes or {}), extend)
sizes = extend_sizes(
case.default_sizes | (sizes or {}), extend
) # TODO: this should take into account the Domain of the allocated field
arg_type = get_param_types(fieldview_prog)[name]
if strategy is None:
if name in ["out", RETURN]:
@@ -421,8 +422,8 @@ def verify(
out_comp = out or inout
out_comp_str = str(out_comp)
assert out_comp is not None
if hasattr(out_comp, "array"):
out_comp_str = str(out_comp.array())
if hasattr(out_comp, "ndarray"):
out_comp_str = str(out_comp.ndarray)
assert comparison(ref, out_comp), (
f"Verification failed:\n"
f"\tcomparison={comparison.__name__}(ref, out)\n"
@@ -447,12 +448,12 @@ def verify_with_default_data(
case: The test case.
fieldview_prog: The field operator or program to be verified.
ref: A callable which will be called with all the input arguments
of the fieldview code, after applying ``.array()`` on the fields.
of the fieldview code, after applying ``.ndarray`` on the fields.
comparison: A comparison function, which will be called as
``comparison(ref, <out | inout>)`` and should return a boolean.
"""
inps, kwfields = get_default_data(case, fieldop)
ref_args = tuple(i.array() if hasattr(i, "array") else i for i in inps)
ref_args = tuple(i.ndarray if hasattr(i, "ndarray") else i for i in inps)
verify(
case,
fieldop,
@@ -158,7 +158,7 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
b = cases.allocate(cartesian_case, testee, "b").extend({cases.IDim: (0, 2)})()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()

cases.verify(cartesian_case, testee, a, b, out=out, ref=a[1:] + b[2:])
cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:])


def test_tuples(cartesian_case): # noqa: F811 # fixtures
@@ -223,7 +223,7 @@ def testee(a: cases.IJKField, b: int32) -> cases.IJKField:
a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, 1)})()
b = cases.allocate(cartesian_case, testee, "b")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()
ref = a.array()[1:] * b
ref = a[1:] * b

cases.verify(cartesian_case, testee, a, b, out=out, ref=ref)

@@ -250,7 +250,7 @@ def testee(size: gtx.IndexType, out: gtx.Field[[IDim], gtx.IndexType]):
testee,
size,
out=out,
ref=np.full_like(out.array(), size, dtype=gtx.IndexType),
ref=np.full_like(out, size, dtype=gtx.IndexType),
)


@@ -410,7 +410,7 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD
comparison=lambda out, ref: np.all(out == ref),
)

assert np.allclose(out.array(), ref)
assert np.allclose(out, ref)


def test_nested_tuple_return(cartesian_case):
@@ -524,7 +524,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I
return 3 * a[0][0] + a[0][1] + a[1]

cases.verify_with_default_data(
cartesian_case, testee, ref=lambda a: 3 * a[0][0].array() + a[0][1].array() + a[1].array()
cartesian_case, testee, ref=lambda a: 3 * a[0][0] + a[0][1] + a[1]
)


@@ -695,9 +695,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]):
cartesian_case,
testee,
ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)),
comparison=lambda ref, out: np.all(out[0].array() == ref[0])
and np.all(out[1][0].array() == ref[1][0])
and np.all(out[1][1].array() == ref[1][1]),
comparison=lambda ref, out: np.all(out[0] == ref[0])
and np.all(out[1][0] == ref[1][0])
and np.all(out[1][1] == ref[1][1]),
)


@@ -754,7 +754,7 @@ def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cas

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a.array() + a.array() + 1
ref = a + a + 1
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
@@ -773,9 +773,7 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

cases.verify(
cartesian_case, program_domain, a, out, inout=out.array()[1:9], ref=a.array()[1:9] * 2
)
cases.verify(cartesian_case, program_domain, a, out, inout=out[1:9], ref=a[1:9] * 2)


def test_domain_input_bounds(cartesian_case):
Expand Down Expand Up @@ -809,8 +807,8 @@ def program_domain(
out,
lower_i,
upper_i,
inout=out.array()[lower_i : int(upper_i / 2)],
ref=inp.array()[lower_i : int(upper_i / 2)] * 2,
inout=out[lower_i : int(upper_i / 2)],
ref=inp[lower_i : int(upper_i / 2)] * 2,
)


Expand Down Expand Up @@ -851,8 +849,8 @@ def program_domain(
upper_i,
lower_j,
upper_j,
inout=out.array()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j],
ref=a.array()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2,
inout=out[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j],
ref=a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2,
)


Expand Down Expand Up @@ -888,7 +886,7 @@ def program_domain_tuple(
out0,
out1,
inout=(out0[1:9, 4:6], out1[1:9, 4:6]),
ref=(inp0.array()[1:9, 4:6] + inp1.array()[1:9, 4:6], inp1.array()[1:9, 4:6]),
ref=(inp0[1:9, 4:6] + inp1[1:9, 4:6], inp1[1:9, 4:6]),
)


@@ -46,6 +46,8 @@
ids=["positive_values", "negative_values"],
)
def test_maxover_execution_(unstructured_case, strategy):
if unstructured_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: reductions")
if unstructured_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]:
pytest.xfail("`maxover` broken in gtfn, see #1289.")

@@ -58,11 +60,14 @@ def testee(edge_f: cases.EField) -> cases.VField:
out = cases.allocate(unstructured_case, testee, cases.RETURN)()

v2e_table = unstructured_case.offset_provider["V2E"].table
ref = np.max(inp[v2e_table], axis=1)
ref = np.max(inp.ndarray[v2e_table], axis=1)
cases.verify(unstructured_case, testee, inp, ref=ref, out=out)


def test_minover_execution(unstructured_case):
if unstructured_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: reductions")

@gtx.field_operator
def minover(edge_f: cases.EField) -> cases.VField:
out = min_over(edge_f(V2E), axis=V2EDim)
@@ -100,7 +100,7 @@ def mod_fieldop(inp1: cases.IField) -> cases.IField:
inp1 = gtx.np_as_located_field(IDim)(np.asarray(range(10), dtype=int32) - 5)
out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)()

cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1.array() % 2)
cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2)


def test_bit_xor(cartesian_case):
Expand All @@ -117,7 +117,7 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie
cases.ConstInitializer(bool_field)
)()
out = cases.allocate(cartesian_case, binary_xor, cases.RETURN)()
cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1.array() ^ inp2.array())
cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2)


def test_bit_and(cartesian_case):
Expand All @@ -134,7 +134,7 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField:
cases.ConstInitializer(bool_field)
)()
out = cases.allocate(cartesian_case, bit_and, cases.RETURN)()
cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1.array() & inp2.array())
cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2)


def test_bit_or(cartesian_case):
Expand All @@ -151,7 +151,7 @@ def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField:
cases.ConstInitializer(bool_field)
)()
out = cases.allocate(cartesian_case, bit_or, cases.RETURN)()
cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1.array() | inp2.array())
cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1 | inp2)


# Unary builtins
@@ -176,7 +176,7 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField:
cases.ConstInitializer(bool_field)
)()
out = cases.allocate(cartesian_case, tilde_fieldop, cases.RETURN)()
cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1.array())
cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1)


def test_unary_not(cartesian_case):
Expand All @@ -190,7 +190,7 @@ def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField:
cases.ConstInitializer(bool_field)
)()
out = cases.allocate(cartesian_case, not_fieldop, cases.RETURN)()
cases.verify(cartesian_case, not_fieldop, inp1, out=out, ref=~inp1.array())
cases.verify(cartesian_case, not_fieldop, inp1, out=out, ref=~inp1)


# Trig builtins
@@ -75,8 +75,8 @@ def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatFie
shift_by_one_program,
in_field,
out_field,
inout=out_field.array()[:-1],
ref=in_field.array()[1:-1],
inout=out_field[:-1],
ref=in_field[1:-1],
)


@@ -184,7 +184,7 @@ def prog(

cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={})

assert np.allclose((a.array()[1:], b.array()[1:]), (out_a.array()[1:], out_b.array()[1:]))
assert np.allclose((a[1:], b[1:]), (out_a[1:], out_b[1:]))
assert out_a[0] == 0 and out_b[0] == 0

