diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py index 52cfc8869f..fbdc8229ec 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py @@ -22,6 +22,39 @@ _cb_device: Final[str] = "device" _cb_sdfg_argtypes: Final[str] = "sdfg_argtypes" _cb_sdfg_call_args: Final[str] = "sdfg_call_args" +_cb_neighbor_table: Final[str] = "table" +_cb_offset_provider: Final[str] = "offset_provider" + + +def _update_sdfg_array_ptr(code: codegen.TextBlock, arg: str, sdfg_arg_index: int) -> None: + code.append(f"assert field_utils.verify_device_field_type({arg}, {_cb_device})") + code.append(f"assert isinstance({_cb_sdfg_call_args}[{sdfg_arg_index}], ctypes.c_void_p)") + code.append(f"{_cb_sdfg_call_args}[{sdfg_arg_index}].value = {arg}.__gt_buffer_info__.data_ptr") + + +def _update_sdfg_array_strides( + code: codegen.TextBlock, + sdfg_arglist: dict[str, dace.data.Data], + arg: str, + sdfg_arg_desc: dace.data.Array, + sdfg_arg_index: int, +) -> None: + for i, array_stride in enumerate(sdfg_arg_desc.strides): + arg_stride = f"{arg}.__gt_buffer_info__.elem_strides[{i}]" + if isinstance(array_stride, int) or str(array_stride).isdigit(): + # The array stride is set to constant value in this dimension. + code.append( + f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].strides[{i}] == {arg_stride}" + ) + else: + # The strides of a global array are defined by a sequence of SDFG symbols. + _parse_gt_param( + param_name=array_stride.name, + param_type=gtx_dace_args.as_itir_type(array_stride.dtype), + arg=arg_stride, + code=code, + sdfg_arglist=sdfg_arglist, + ) def _update_sdfg_scalar_arg( @@ -118,22 +151,14 @@ def _parse_gt_param( ) else: assert isinstance(sdfg_arg_desc, dace.data.Array) - code.append(f"assert field_utils.verify_device_field_type({arg}, {_cb_device})") - code.append( - f"assert isinstance({_cb_sdfg_call_args}[{sdfg_arg_index}], ctypes.c_void_p)" - ) - code.append(f"{arg}_buffer_info = {arg}.__gt_buffer_info__") - - code.append( - f"{_cb_sdfg_call_args}[{sdfg_arg_index}].value = {arg}_buffer_info.data_ptr" - ) + _update_sdfg_array_ptr(code, arg, sdfg_arg_index) for i, (dim, array_size) in enumerate( zip(param_type.dims, sdfg_arg_desc.shape, strict=True) ): if isinstance(array_size, int) or str(array_size).isdigit(): # The array shape in this dimension is set at compile-time. code.append( - f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].shape[{i}] == {arg}_buffer_info.shape[{i}]" + f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].shape[{i}] == {arg}.__gt_buffer_info__.shape[{i}]" ) else: # The array shape is defined as a sequence of expressions @@ -150,25 +175,7 @@ def _parse_gt_param( code=code, sdfg_arglist=sdfg_arglist, ) - for i, (dim, array_stride) in enumerate( - zip(param_type.dims, sdfg_arg_desc.strides, strict=True) - ): - arg_stride = f"{arg}_buffer_info.elem_strides[{i}]" - if isinstance(array_stride, int) or str(array_stride).isdigit(): - # The array stride is set to constant value in this dimension. - code.append( - f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].strides[{i}] == {arg_stride}" - ) - else: - # The strides of a global array are defined by a sequence of SDFG symbols. - assert array_stride == gtx_dace_args.field_stride_symbol(param_name, dim) - _parse_gt_param( - param_name=array_stride.name, - param_type=gtx_dace_args.as_itir_type(array_stride.dtype), - arg=arg_stride, - code=code, - sdfg_arglist=sdfg_arglist, - ) + _update_sdfg_array_strides(code, sdfg_arglist, arg, sdfg_arg_desc, sdfg_arg_index) elif isinstance(param_type, ts.ScalarType): assert isinstance(sdfg_arg_desc, dace.data.Scalar) @@ -183,6 +190,38 @@ def _parse_gt_param( raise ValueError(f"Unexpected paramter type {param_type}") +def _parse_gt_connectivities( + code: codegen.TextBlock, sdfg_arglist: dict[str, dace.data.Data] +) -> None: + for sdfg_arg_index, (arg_name, arg_desc) in enumerate(sdfg_arglist.items()): + if gtx_dace_args.is_connectivity_identifier(arg_name): + assert isinstance(arg_desc, dace.data.Array) + assert len(arg_desc.shape) == 2 + assert isinstance(arg_desc.shape[1], int) or str(arg_desc.shape[1]).isdigit() + origin_size_arg = arg_desc.shape[0] + assert len(origin_size_arg.free_symbols) == 1 + origin_size_param = next(iter(origin_size_arg.free_symbols)) + m = gtx_dace_args.CONNECTIVITY_INDENTIFIER_RE.match(arg_name) + assert m is not None + conn_arg = f"{_cb_neighbor_table}_{m[1]}" + code.append(f'{conn_arg} = {_cb_offset_provider}["{m[1]}"]') + _update_sdfg_array_ptr(code, conn_arg, sdfg_arg_index) + _parse_gt_param( # set the size in the horizontal dimension + param_name=origin_size_param, + param_type=gtx_dace_args.as_itir_type(gtx_dace_args.FIELD_SYMBOL_DTYPE), + arg=f"{conn_arg}.__gt_buffer_info__.shape[0]", + code=code, + sdfg_arglist=sdfg_arglist, + ) + _update_sdfg_array_strides( + code, + sdfg_arglist, + conn_arg, + arg_desc, + sdfg_arg_index, + ) + + def _create_sdfg_bindings( program_source: stages.ProgramSource[languages.SDFG, languages.LanguageSettings], bind_func_name: str, @@ -212,12 +251,13 @@ def _create_sdfg_bindings( code.append("from gt4py.next import common as gtx_common, field_utils") code.empty_line() code.append( - "def {funname}({arg0}, {arg1}, {arg2}, {arg3}):".format( + "def {funname}({arg0}, {arg1}, {arg2}, {arg3}, {arg4}):".format( funname=bind_func_name, arg0=_cb_device, arg1=_cb_sdfg_argtypes, arg2=_cb_args, arg3=_cb_sdfg_call_args, + arg4=_cb_offset_provider, ) ) @@ -233,6 +273,14 @@ def _create_sdfg_bindings( assert isinstance(param.type_, ts.DataType) _parse_gt_param(param.name, param.type_, arg, code, sdfg_arglist) + # In the regular case, the connectivity fields are allocated at the beginning + # of the application and then used during its entire lifetime and never reallocated. + # However, this might not be the case all the time, for example in unit tests + # where, due to limited lifetime of the fixtures, the connectivity fields + # might get reallocated. In order to avoid problems, we update the connectivity + # arrays as well in SDFG fastcall. + _parse_gt_connectivities(code, sdfg_arglist) + src = codegen.format_python_source(code.text) return stages.BindingSource(src, library_deps=tuple()) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 0f845ba7e9..d5e42ef181 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -18,7 +18,7 @@ import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import config +from gt4py.next import common, config from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon @@ -33,7 +33,13 @@ class CompiledDaceProgram(stages.CompiledProgram): # The compiled program contains a callable object to update the SDFG arguments list. update_sdfg_ctype_arglist: Callable[ - [core_defs.DeviceType, Sequence[dace.dtypes.Data], Sequence[Any], MutableSequence[Any]], + [ + core_defs.DeviceType, + Sequence[dace.dtypes.Data], + Sequence[Any], + MutableSequence[Any], + common.OffsetProvider, + ], None, ] diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 996ba7a095..a6cc6a0a87 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -50,7 +50,7 @@ def decorated_program( # `fun.csdfg_args` is `None` # TODO(phimuell, edopao): Think about refactor the code such that the update # of the argument vector is a Method of the `CompiledDaceProgram`. - update_sdfg_call_args(args, fun.csdfg_argv) # type: ignore[arg-type] # Will error out in first call. + update_sdfg_call_args(args, fun.csdfg_argv, offset_provider) # type: ignore[arg-type] # Will error out in first call. except TypeError: # First call. Construct the initial argument vector of the `CompiledDaceProgram`. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py index 9d4aa778b9..6c538af4d6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py @@ -19,6 +19,8 @@ from gt4py.next.otf import languages, stages from gt4py.next.program_processors.runners import dace as dace_runner from gt4py.next.program_processors.runners.dace import workflow as dace_workflow +from gt4py.next import neighbor_sum +from next_tests.integration_tests.cases import E2V, E2VDim, V2E, V2EDim from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests import ffront_test_utils @@ -36,7 +38,7 @@ """ -def _binding_source(use_metrics: bool) -> str: +def _binding_source_cartesian(use_metrics: bool) -> str: # In this SDFG 'sdfg_call_args[2]' is used to collect the stencil compute time. # Note that the position of 'gt_compute_time' in the SDFG arguments list is # defined by dace, based on alphabetical order from index 0 ('a', 'b', 'gt_compute_time'). @@ -47,7 +49,7 @@ def _binding_source(use_metrics: bool) -> str: return ( _bind_header + f"""\ -def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args): +def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provider): ( args_0, args_1, @@ -67,41 +69,38 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args): args_0_1_2, ) = args_0_1 sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0_1_0) - args_0_1_1_buffer_info = args_0_1_1.__gt_buffer_info__ - sdfg_call_args[{idx[2]}].value = args_0_1_1_buffer_info.data_ptr + sdfg_call_args[{idx[2]}].value = args_0_1_1.__gt_buffer_info__.data_ptr sdfg_call_args[{idx[3]}] = ctypes.c_int(args_0_1_1.domain.ranges[0].start) sdfg_call_args[{idx[4]}] = ctypes.c_int(args_0_1_1.domain.ranges[1].start) sdfg_call_args[{idx[5]}] = ctypes.c_int(args_0_1_1.domain.ranges[2].start) - sdfg_call_args[{idx[6]}] = ctypes.c_int(args_0_1_1_buffer_info.elem_strides[0]) - sdfg_call_args[{idx[7]}] = ctypes.c_int(args_0_1_1_buffer_info.elem_strides[1]) - sdfg_call_args[{idx[8]}] = ctypes.c_int(args_0_1_1_buffer_info.elem_strides[2]) + sdfg_call_args[{idx[6]}] = ctypes.c_int(args_0_1_1.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[7]}] = ctypes.c_int(args_0_1_1.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[8]}] = ctypes.c_int(args_0_1_1.__gt_buffer_info__.elem_strides[2]) ( args_1_0, args_1_1, ) = args_1 (args_1_0_0,) = args_1_0 - args_1_0_0_buffer_info = args_1_0_0.__gt_buffer_info__ - sdfg_call_args[{idx[9]}].value = args_1_0_0_buffer_info.data_ptr + sdfg_call_args[{idx[9]}].value = args_1_0_0.__gt_buffer_info__.data_ptr sdfg_call_args[{idx[10]}] = ctypes.c_int(args_1_0_0.domain.ranges[0].start) sdfg_call_args[{idx[11]}] = ctypes.c_int(args_1_0_0.domain.ranges[1].start) sdfg_call_args[{idx[12]}] = ctypes.c_int(args_1_0_0.domain.ranges[2].start) - sdfg_call_args[{idx[13]}] = ctypes.c_int(args_1_0_0_buffer_info.elem_strides[0]) - sdfg_call_args[{idx[14]}] = ctypes.c_int(args_1_0_0_buffer_info.elem_strides[1]) - sdfg_call_args[{idx[15]}] = ctypes.c_int(args_1_0_0_buffer_info.elem_strides[2]) + sdfg_call_args[{idx[13]}] = ctypes.c_int(args_1_0_0.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[14]}] = ctypes.c_int(args_1_0_0.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[15]}] = ctypes.c_int(args_1_0_0.__gt_buffer_info__.elem_strides[2]) sdfg_call_args[{idx[16]}] = ctypes.c_int(args_1_1) - args_5_buffer_info = args_5.__gt_buffer_info__ - sdfg_call_args[{idx[17]}].value = args_5_buffer_info.data_ptr + sdfg_call_args[{idx[17]}].value = args_5.__gt_buffer_info__.data_ptr sdfg_call_args[{idx[18]}] = ctypes.c_int(args_5.domain.ranges[0].start) sdfg_call_args[{idx[19]}] = ctypes.c_int(args_5.domain.ranges[1].start) sdfg_call_args[{idx[20]}] = ctypes.c_int(args_5.domain.ranges[2].start) - sdfg_call_args[{idx[21]}] = ctypes.c_int(args_5_buffer_info.elem_strides[0]) - sdfg_call_args[{idx[22]}] = ctypes.c_int(args_5_buffer_info.elem_strides[1]) - sdfg_call_args[{idx[23]}] = ctypes.c_int(args_5_buffer_info.elem_strides[2])\ + sdfg_call_args[{idx[21]}] = ctypes.c_int(args_5.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[22]}] = ctypes.c_int(args_5.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[23]}] = ctypes.c_int(args_5.__gt_buffer_info__.elem_strides[2])\ """ ) -def _binding_source_with_zero_origin(use_metrics: bool) -> str: +def _binding_source_cartesian_with_zero_origin(use_metrics: bool) -> str: # In this SDFG 'sdfg_call_args[2]' is used to collect the stencil compute time. # Note that the position of 'gt_compute_time' in the SDFG arguments list is # defined by dace, based on alphabetical order from index 0 ('a', 'b', 'gt_compute_time'). @@ -112,7 +111,7 @@ def _binding_source_with_zero_origin(use_metrics: bool) -> str: return ( _bind_header + f"""\ -def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args): +def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provider): ( args_0, args_1, @@ -132,34 +131,93 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args): args_0_1_2, ) = args_0_1 sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0_1_0) - args_0_1_1_buffer_info = args_0_1_1.__gt_buffer_info__ - sdfg_call_args[{idx[2]}].value = args_0_1_1_buffer_info.data_ptr - sdfg_call_args[{idx[3]}] = ctypes.c_int(args_0_1_1_buffer_info.elem_strides[0]) - sdfg_call_args[{idx[4]}] = ctypes.c_int(args_0_1_1_buffer_info.elem_strides[1]) - sdfg_call_args[{idx[5]}] = ctypes.c_int(args_0_1_1_buffer_info.elem_strides[2]) + sdfg_call_args[{idx[2]}].value = args_0_1_1.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[3]}] = ctypes.c_int(args_0_1_1.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[4]}] = ctypes.c_int(args_0_1_1.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[5]}] = ctypes.c_int(args_0_1_1.__gt_buffer_info__.elem_strides[2]) ( args_1_0, args_1_1, ) = args_1 (args_1_0_0,) = args_1_0 - args_1_0_0_buffer_info = args_1_0_0.__gt_buffer_info__ - sdfg_call_args[{idx[6]}].value = args_1_0_0_buffer_info.data_ptr - sdfg_call_args[{idx[7]}] = ctypes.c_int(args_1_0_0_buffer_info.elem_strides[0]) - sdfg_call_args[{idx[8]}] = ctypes.c_int(args_1_0_0_buffer_info.elem_strides[1]) - sdfg_call_args[{idx[9]}] = ctypes.c_int(args_1_0_0_buffer_info.elem_strides[2]) + sdfg_call_args[{idx[6]}].value = args_1_0_0.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[7]}] = ctypes.c_int(args_1_0_0.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[8]}] = ctypes.c_int(args_1_0_0.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[9]}] = ctypes.c_int(args_1_0_0.__gt_buffer_info__.elem_strides[2]) sdfg_call_args[{idx[10]}] = ctypes.c_int(args_1_1) - args_5_buffer_info = args_5.__gt_buffer_info__ - sdfg_call_args[{idx[11]}].value = args_5_buffer_info.data_ptr - sdfg_call_args[{idx[12]}] = ctypes.c_int(args_5_buffer_info.elem_strides[0]) - sdfg_call_args[{idx[13]}] = ctypes.c_int(args_5_buffer_info.elem_strides[1]) - sdfg_call_args[{idx[14]}] = ctypes.c_int(args_5_buffer_info.elem_strides[2])\ + sdfg_call_args[{idx[11]}].value = args_5.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[12]}] = ctypes.c_int(args_5.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[13]}] = ctypes.c_int(args_5.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[14]}] = ctypes.c_int(args_5.__gt_buffer_info__.elem_strides[2])\ +""" + ) + + +def _binding_source_unstructured(use_metrics: bool) -> str: + metrics_arg_index = 2 + idx = [0, 4, 1, 5, 6, 7, 2, 9, 8, 3, 11, 10] + if use_metrics: + idx = [idx + 1 if idx >= metrics_arg_index else idx for idx in idx] + return ( + _bind_header + + f"""\ +def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provider): + ( + args_0, + args_1, + ) = args + sdfg_call_args[{idx[0]}].value = args_0.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[2]}].value = args_1.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[3]}] = ctypes.c_int(args_1.domain.ranges[0].start) + sdfg_call_args[{idx[4]}] = ctypes.c_int(args_1.domain.ranges[0].stop) + sdfg_call_args[{idx[5]}] = ctypes.c_int(args_1.__gt_buffer_info__.elem_strides[0]) + table_E2V = offset_provider["E2V"] + sdfg_call_args[{idx[6]}].value = table_E2V.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[7]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[8]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[1]) + table_V2E = offset_provider["V2E"] + sdfg_call_args[{idx[9]}].value = table_V2E.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[10]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[11]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[1])\ +""" + ) + + +def _binding_source_unstructured_with_zero_origin(use_metrics: bool) -> str: + metrics_arg_index = 2 + idx = [0, 4, 1, 5, 6, 2, 8, 7, 3, 10, 9] + if use_metrics: + idx = [idx + 1 if idx >= metrics_arg_index else idx for idx in idx] + return ( + _bind_header + + f"""\ +def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provider): + ( + args_0, + args_1, + ) = args + sdfg_call_args[{idx[0]}].value = args_0.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[2]}].value = args_1.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[3]}] = ctypes.c_int(args_1.domain.ranges[0].stop) + sdfg_call_args[{idx[4]}] = ctypes.c_int(args_1.__gt_buffer_info__.elem_strides[0]) + table_E2V = offset_provider["E2V"] + sdfg_call_args[{idx[5]}].value = table_E2V.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[6]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[7]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[1]) + table_V2E = offset_provider["V2E"] + sdfg_call_args[{idx[8]}].value = table_V2E.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[9]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[10]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[1])\ """ ) # The difference between the two bindings versions is that one uses field domain # with zero origin, therefore the range-start symbols are not present in the SDFG. -assert _binding_source_with_zero_origin != _binding_source +assert _binding_source_cartesian_with_zero_origin != _binding_source_cartesian +assert _binding_source_unstructured_with_zero_origin != _binding_source_unstructured _dace_compile_call = dace_workflow.compilation.DaCeCompiler.__call__ @@ -168,8 +226,7 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args): def mocked_compile_call( self, inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], - use_metrics: bool, - use_zero_origin: bool, + binding_source_ref: str, ): assert len(inp.library_deps) == 0 @@ -179,17 +236,41 @@ def mocked_compile_call( for line in inp.binding_source.source_code.splitlines() if not line.lstrip().startswith("assert") ) - - binding_source_ref = _binding_source_with_zero_origin if use_zero_origin else _binding_source - assert binding_source_pruned == binding_source_ref(use_metrics) + assert binding_source_pruned == binding_source_ref return _dace_compile_call(self, inp) +def mocked_compile_call_cartesian( + self, + inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + use_metrics: bool, + use_zero_origin: bool, +): + binding_ref_fun = ( + _binding_source_cartesian_with_zero_origin if use_zero_origin else _binding_source_cartesian + ) + return mocked_compile_call(self, inp, binding_ref_fun(use_metrics)) + + +def mocked_compile_call_unstructured( + self, + inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + use_metrics: bool, + use_zero_origin: bool, +): + binding_ref_fun = ( + _binding_source_unstructured_with_zero_origin + if use_zero_origin + else _binding_source_unstructured + ) + return mocked_compile_call(self, inp, binding_ref_fun(use_metrics)) + + @pytest.mark.parametrize("use_metrics", [False, True], ids=["no_metrics", "use_metrics"]) @pytest.mark.parametrize( "use_zero_origin", [False, True], ids=["no_zero_origin", "use_zero_origin"] ) -def test_bind_sdfg(use_metrics, use_zero_origin, monkeypatch): +def test_cartesian_bind_sdfg(use_metrics, use_zero_origin, monkeypatch): M, N, K = (30, 20, 10) @gtx.field_operator @@ -222,17 +303,10 @@ def testee( dace_workflow.compilation.DaCeCompiler, "__call__", functools.partialmethod( - mocked_compile_call, use_metrics=use_metrics, use_zero_origin=use_zero_origin + mocked_compile_call_cartesian, use_metrics=use_metrics, use_zero_origin=use_zero_origin ), ) - static_args = {"M": [M], "N": [N], "K": [K]} - program = ( - testee.with_grid_type(gtx_common.GridType.CARTESIAN) - .with_backend(backend) - .compile(enable_jit=False, offset_provider={}, **static_args) - ) - test_case = cases.Case.from_cartesian_grid_descriptor( ffront_test_utils.simple_cartesian_grid(), backend=backend, @@ -249,5 +323,68 @@ def testee( a[0] + 2 * a[1][0] + 3 * a[1][1].asnumpy() + 4 * b[0][0].asnumpy() + 5 * b[1] )[1 : M - 1, 2 : N - 2, 3 : K - 3] + static_args = {"M": [M], "N": [N], "K": [K]} + program = ( + testee.with_grid_type(gtx_common.GridType.CARTESIAN) + .with_backend(backend) + .compile(enable_jit=False, offset_provider={}, **static_args) + ) program(a, b, out=c, M=M, N=N, K=K) assert np.all(c.asnumpy() == ref) + + +@pytest.mark.parametrize("use_metrics", [False, True], ids=["no_metrics", "use_metrics"]) +@pytest.mark.parametrize( + "use_zero_origin", [False, True], ids=["no_zero_origin", "use_zero_origin"] +) +def test_unstructured_bind_sdfg(use_metrics, use_zero_origin, monkeypatch): + @gtx.field_operator + def testee_op(a: cases.VField) -> cases.VField: + tmp = neighbor_sum(a(E2V), axis=E2VDim) + tmp_2 = neighbor_sum(tmp(V2E), axis=V2EDim) + return tmp_2 + + @gtx.program + def testee(a: cases.VField, b: cases.VField): + testee_op(a, out=b) + + backend = dace_runner.make_dace_backend( + gpu=False, + cached=False, + auto_optimize=True, + use_metrics=use_metrics, + use_zero_origin=use_zero_origin, + ) + monkeypatch.setattr( + dace_workflow.compilation.DaCeCompiler, + "__call__", + functools.partialmethod( + mocked_compile_call_unstructured, + use_metrics=use_metrics, + use_zero_origin=use_zero_origin, + ), + ) + + SIMPLE_MESH = ffront_test_utils.simple_mesh(None) + offset_provider = SIMPLE_MESH.offset_provider + + test_case = cases.Case.from_mesh_descriptor(SIMPLE_MESH, backend=backend, allocator=backend) + + a = cases.allocate(test_case, testee, "a")() + b = cases.allocate(test_case, testee, "b")() + + ref = np.sum( + np.sum(a.asnumpy()[offset_provider["E2V"].asnumpy()], axis=1, initial=0)[ + offset_provider["V2E"].asnumpy() + ], + axis=1, + ) + + static_args = {} + program = ( + testee.with_grid_type(gtx_common.GridType.UNSTRUCTURED) + .with_backend(backend) + .compile(enable_jit=False, offset_provider=offset_provider, **static_args) + ) + program(a, b, offset_provider=offset_provider) + assert np.all(b.asnumpy() == ref)