Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 78 additions & 30 deletions src/gt4py/next/program_processors/runners/dace/workflow/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@
_cb_device: Final[str] = "device"
_cb_sdfg_argtypes: Final[str] = "sdfg_argtypes"
_cb_sdfg_call_args: Final[str] = "sdfg_call_args"
_cb_neighbor_table: Final[str] = "table"
_cb_offset_provider: Final[str] = "offset_provider"


def _update_sdfg_array_ptr(code: codegen.TextBlock, arg: str, sdfg_arg_index: int) -> None:
code.append(f"assert field_utils.verify_device_field_type({arg}, {_cb_device})")
code.append(f"assert isinstance({_cb_sdfg_call_args}[{sdfg_arg_index}], ctypes.c_void_p)")
code.append(f"{_cb_sdfg_call_args}[{sdfg_arg_index}].value = {arg}.__gt_buffer_info__.data_ptr")


def _update_sdfg_array_strides(
code: codegen.TextBlock,
sdfg_arglist: dict[str, dace.data.Data],
arg: str,
sdfg_arg_desc: dace.data.Array,
sdfg_arg_index: int,
) -> None:
for i, array_stride in enumerate(sdfg_arg_desc.strides):
arg_stride = f"{arg}.__gt_buffer_info__.elem_strides[{i}]"
if isinstance(array_stride, int) or str(array_stride).isdigit():
# The array stride is set to constant value in this dimension.
code.append(
f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].strides[{i}] == {arg_stride}"
)
else:
# The strides of a global array are defined by a sequence of SDFG symbols.
_parse_gt_param(
param_name=array_stride.name,
param_type=gtx_dace_args.as_itir_type(array_stride.dtype),
arg=arg_stride,
code=code,
sdfg_arglist=sdfg_arglist,
)


def _update_sdfg_scalar_arg(
Expand Down Expand Up @@ -118,22 +151,14 @@ def _parse_gt_param(
)
else:
assert isinstance(sdfg_arg_desc, dace.data.Array)
code.append(f"assert field_utils.verify_device_field_type({arg}, {_cb_device})")
code.append(
f"assert isinstance({_cb_sdfg_call_args}[{sdfg_arg_index}], ctypes.c_void_p)"
)
code.append(f"{arg}_buffer_info = {arg}.__gt_buffer_info__")

code.append(
f"{_cb_sdfg_call_args}[{sdfg_arg_index}].value = {arg}_buffer_info.data_ptr"
)
_update_sdfg_array_ptr(code, arg, sdfg_arg_index)
for i, (dim, array_size) in enumerate(
zip(param_type.dims, sdfg_arg_desc.shape, strict=True)
):
if isinstance(array_size, int) or str(array_size).isdigit():
# The array shape in this dimension is set at compile-time.
code.append(
f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].shape[{i}] == {arg}_buffer_info.shape[{i}]"
f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].shape[{i}] == {arg}.__gt_buffer_info__.shape[{i}]"
)
else:
# The array shape is defined as a sequence of expressions
Expand All @@ -150,25 +175,7 @@ def _parse_gt_param(
code=code,
sdfg_arglist=sdfg_arglist,
)
for i, (dim, array_stride) in enumerate(
zip(param_type.dims, sdfg_arg_desc.strides, strict=True)
):
arg_stride = f"{arg}_buffer_info.elem_strides[{i}]"
if isinstance(array_stride, int) or str(array_stride).isdigit():
# The array stride is set to constant value in this dimension.
code.append(
f"assert {_cb_sdfg_argtypes}[{sdfg_arg_index}].strides[{i}] == {arg_stride}"
)
else:
# The strides of a global array are defined by a sequence of SDFG symbols.
assert array_stride == gtx_dace_args.field_stride_symbol(param_name, dim)
_parse_gt_param(
param_name=array_stride.name,
param_type=gtx_dace_args.as_itir_type(array_stride.dtype),
arg=arg_stride,
code=code,
sdfg_arglist=sdfg_arglist,
)
_update_sdfg_array_strides(code, sdfg_arglist, arg, sdfg_arg_desc, sdfg_arg_index)

elif isinstance(param_type, ts.ScalarType):
assert isinstance(sdfg_arg_desc, dace.data.Scalar)
Expand All @@ -183,6 +190,38 @@ def _parse_gt_param(
raise ValueError(f"Unexpected paramter type {param_type}")


def _parse_gt_connectivities(
code: codegen.TextBlock, sdfg_arglist: dict[str, dace.data.Data]
) -> None:
for sdfg_arg_index, (arg_name, arg_desc) in enumerate(sdfg_arglist.items()):
if gtx_dace_args.is_connectivity_identifier(arg_name):
assert isinstance(arg_desc, dace.data.Array)
assert len(arg_desc.shape) == 2
assert isinstance(arg_desc.shape[1], int) or str(arg_desc.shape[1]).isdigit()
origin_size_arg = arg_desc.shape[0]
assert len(origin_size_arg.free_symbols) == 1
origin_size_param = next(iter(origin_size_arg.free_symbols))
m = gtx_dace_args.CONNECTIVITY_INDENTIFIER_RE.match(arg_name)
assert m is not None
conn_arg = f"{_cb_neighbor_table}_{m[1]}"
code.append(f'{conn_arg} = {_cb_offset_provider}["{m[1]}"]')
_update_sdfg_array_ptr(code, conn_arg, sdfg_arg_index)
_parse_gt_param( # set the size in the horizontal dimension
param_name=origin_size_param,
param_type=gtx_dace_args.as_itir_type(gtx_dace_args.FIELD_SYMBOL_DTYPE),
arg=f"{conn_arg}.__gt_buffer_info__.shape[0]",
code=code,
sdfg_arglist=sdfg_arglist,
)
_update_sdfg_array_strides(
code,
sdfg_arglist,
conn_arg,
arg_desc,
sdfg_arg_index,
)


def _create_sdfg_bindings(
program_source: stages.ProgramSource[languages.SDFG, languages.LanguageSettings],
bind_func_name: str,
Expand Down Expand Up @@ -212,12 +251,13 @@ def _create_sdfg_bindings(
code.append("from gt4py.next import common as gtx_common, field_utils")
code.empty_line()
code.append(
"def {funname}({arg0}, {arg1}, {arg2}, {arg3}):".format(
"def {funname}({arg0}, {arg1}, {arg2}, {arg3}, {arg4}):".format(
funname=bind_func_name,
arg0=_cb_device,
arg1=_cb_sdfg_argtypes,
arg2=_cb_args,
arg3=_cb_sdfg_call_args,
arg4=_cb_offset_provider,
)
)

Expand All @@ -233,6 +273,14 @@ def _create_sdfg_bindings(
assert isinstance(param.type_, ts.DataType)
_parse_gt_param(param.name, param.type_, arg, code, sdfg_arglist)

# In the regular case, the connectivity fields are allocated at the beginning
# of the application and then used during its entire lifetime and never reallocated.
# However, this might not be the case all the time, for example in unit tests
# where, due to limited lifetime of the fixtures, the connectivity fields
# might get reallocated. In order to avoid problems, we update the connectivity
# arrays as well in SDFG fastcall.
_parse_gt_connectivities(code, sdfg_arglist)

src = codegen.format_python_source(code.text)
return stages.BindingSource(src, library_deps=tuple())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import factory

from gt4py._core import definitions as core_defs, locking
from gt4py.next import config
from gt4py.next import common, config
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.compilation import cache as gtx_cache
from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon
Expand All @@ -33,7 +33,13 @@ class CompiledDaceProgram(stages.CompiledProgram):

# The compiled program contains a callable object to update the SDFG arguments list.
update_sdfg_ctype_arglist: Callable[
[core_defs.DeviceType, Sequence[dace.dtypes.Data], Sequence[Any], MutableSequence[Any]],
[
core_defs.DeviceType,
Sequence[dace.dtypes.Data],
Sequence[Any],
MutableSequence[Any],
common.OffsetProvider,
],
None,
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def decorated_program(
# `fun.csdfg_args` is `None`
# TODO(phimuell, edopao): Think about refactor the code such that the update
# of the argument vector is a Method of the `CompiledDaceProgram`.
update_sdfg_call_args(args, fun.csdfg_argv) # type: ignore[arg-type] # Will error out in first call.
update_sdfg_call_args(args, fun.csdfg_argv, offset_provider) # type: ignore[arg-type] # Will error out in first call.

except TypeError:
# First call. Construct the initial argument vector of the `CompiledDaceProgram`.
Expand Down
Loading