Skip to content

Commit c2360bd

Browse files
committed
Support GatherScatterView via advanced indexing
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 84956b9 commit c2360bd

11 files changed

Lines changed: 700 additions & 6 deletions

File tree

changelog.d/load-store-advanced.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- New `ct.load_advanced(array, indices)` and `ct.store_advanced(array, indices, tile)` for gathering/scattering along one dimension from/to a 2D or higher-rank array. A 1D integer `Tile` selects the sparse dimension; `ct.Slice(start, length)` selects a contiguous range along each dense dimension.

docs/source/data/slice.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.. SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
..
3+
.. SPDX-License-Identifier: Apache-2.0
4+
5+
.. currentmodule:: cuda.tile
6+
7+
.. _data-slice-cuda-tile-slice:
8+
9+
cuda.tile.Slice
10+
===============
11+
12+
.. autoclass:: Slice
13+
:members:
14+
:undoc-members:
15+
:special-members:
16+
:exclude-members: __annotations__, __dict__, __module__, __weakref__

docs/source/operations.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Load/Store
2222
num_tiles
2323
load
2424
store
25+
load_advanced
26+
store_advanced
2527
gather
2628
scatter
2729

@@ -234,12 +236,14 @@ Classes
234236

235237
Array
236238
TiledView
239+
Slice
237240

238241
.. toctree::
239242
:hidden:
240243

241244
data/array
242245
data/tiled_view
246+
data/slice
243247

244248

245249
.. _operations-enums:

src/cuda/tile/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
ListAnnotation,
6565
Scalar,
6666
ScalarInt64,
67+
Slice,
6768
Tile,
6869
TiledView,
6970

@@ -114,6 +115,7 @@
114115
less,
115116
less_equal,
116117
load,
118+
load_advanced,
117119
log,
118120
log2,
119121
matmul,
@@ -148,6 +150,7 @@
148150
static_eval,
149151
static_iter,
150152
store,
153+
store_advanced,
151154
sub,
152155
sum,
153156
tan,
@@ -222,6 +225,7 @@
222225
"ListAnnotation",
223226
"Scalar",
224227
"ScalarInt64",
228+
"Slice",
225229
"Tile",
226230
"TiledView",
227231

@@ -272,6 +276,7 @@
272276
"less",
273277
"less_equal",
274278
"load",
279+
"load_advanced",
275280
"log",
276281
"log2",
277282
"matmul",
@@ -306,6 +311,7 @@
306311
"static_eval",
307312
"static_iter",
308313
"store",
314+
"store_advanced",
309315
"sub",
310316
"sum",
311317
"tan",

src/cuda/tile/_ir/ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,15 @@ def as_tuple(self) -> tuple["Var", ...]:
309309
return (self.array,)
310310

311311

312+
@dataclass
313+
class IndexSliceValue(AggregateValue):
314+
start: Var
315+
length: Var
316+
317+
def as_tuple(self) -> tuple[Var, ...]:
318+
return (self.start, self.length)
319+
320+
312321
@dataclass
313322
class RawArrayMemoryValue(AggregateValue):
314323
base_ptr: Var

src/cuda/tile/_ir/ops.py

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
enter_nested_block, nested_block, PhiState, LoopVarState,
2828
TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState,
2929
ListValue, TiledViewValue, ClosureValue, MemoryEffect, attribute, operand,
30-
BlockRestriction, FormattedStringValue, RawArrayMemoryValue, DataclassValue, DataclassInfo
30+
BlockRestriction, FormattedStringValue, RawArrayMemoryValue, DataclassValue, DataclassInfo,
31+
IndexSliceValue
3132
)
3233
from .type import PointerTy
3334
from . import hir, hir_stubs
@@ -59,11 +60,12 @@
5960
typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value, get_dataclass_info,
6061
)
6162
from .type import (
62-
PartitionViewTy, StridedViewTy, TupleTy, TileTy, NoneType, BoundMethodTy, ArrayTy,
63+
PartitionViewTy, StridedViewTy, GatherScatterViewTy, TupleTy, TileTy, NoneType,
64+
BoundMethodTy, ArrayTy,
6365
ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type,
6466
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType,
6567
ClosureTy, LiveCapturedScope, TokenTy, TiledViewTy, FormattedStringTy,
66-
StringFormat, FormattedPiece, RawArrayMemoryTy, DataclassTy
68+
StringFormat, FormattedPiece, RawArrayMemoryTy, DataclassTy, IndexSliceTy
6769
)
6870
from cuda.tile._datatype import (
6971
DType, is_integral, is_float, is_signed, is_boolean,
@@ -2375,6 +2377,29 @@ def _materialize_tiled_view(array: Var,
23752377
return _make_partition_view(array, tile_shape, order, padding_mode)
23762378

23772379

2380+
@dataclass(eq=False)
2381+
class MakeGatherScatterView(Operation, opcode="make_gather_scatter_view"):
2382+
array: Var = operand()
2383+
2384+
@override
2385+
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
2386+
gs_view_ty = self.result_var.get_type()
2387+
return bc.encode_MakeGatherScatterViewOp(ctx.builder,
2388+
typeid(ctx.type_table, gs_view_ty),
2389+
ctx.get_value(self.array))
2390+
2391+
2392+
def make_gather_scatter_view(array: Var, tile_shape: Sequence[int],
2393+
sparse_dim: int,
2394+
padding_mode: PaddingMode) -> Var:
2395+
array_ty = array.get_type()
2396+
assert isinstance(array_ty, ArrayTy)
2397+
view_ty = GatherScatterViewTy(array_ty, tuple(tile_shape), sparse_dim, padding_mode)
2398+
ret = add_operation(MakeGatherScatterView, view_ty, array=array)
2399+
ret.set_aggregate(array.get_aggregate())
2400+
return ret
2401+
2402+
23782403
@dataclass(eq=False)
23792404
class TileLoad(Operation, opcode="tile_load", memory_effect=MemoryEffect.LOAD):
23802405
latency: Optional[int] = attribute()
@@ -4826,6 +4851,128 @@ def tiled_view_atomic_rmw_impl(int_mode: Optional[AtomicRMWMode],
48264851
view=view, index=index_items, update=update)
48274852

48284853

4854+
@impl(ct.Slice)
4855+
def slice_index_constructor_impl(start: Var, length: Var) -> Var:
4856+
start_ty = require_signed_integer_0d_tile_type(start)
4857+
length_ty = require_signed_integer_0d_tile_type(length)
4858+
res_type = IndexSliceTy(start_ty, length_ty)
4859+
res_loose_type = IndexSliceTy(start.get_loose_type(), length.get_loose_type())
4860+
return make_aggregate(IndexSliceValue(start, length), res_type, res_loose_type)
4861+
4862+
4863+
def _parse_advanced_index(indices: Var, ndim: int) -> tuple[int, tuple[int, ...], tuple[Var, ...]]:
4864+
"""Unpack, classify, validate, and build the gather scatter view index.
4865+
4866+
Returns (sparse_dim, tile_shape, gs_index).
4867+
"""
4868+
require_tuple_type(indices)
4869+
items = list(indices.get_aggregate().items)
4870+
if len(items) != ndim:
4871+
raise TileTypeError(
4872+
f"load_advanced/store_advanced index length {len(items)} does not "
4873+
f"match array rank {ndim}")
4874+
4875+
sparse_dims: list[int] = []
4876+
tile_shape: list[int] = []
4877+
gs_index: list[Var] = []
4878+
4879+
for dim, item in enumerate(items):
4880+
item_ty = item.get_type()
4881+
if isinstance(item_ty, TileTy):
4882+
if item_ty.ndim != 1:
4883+
raise TileTypeError(
4884+
f"Sparse index at dim {dim} must be a 1D integer tile, "
4885+
f"got {item_ty.ndim}D")
4886+
if not is_integral(item_ty.dtype):
4887+
raise TileTypeError(
4888+
f"Sparse index at dim {dim} must be an integer tile, "
4889+
f"got dtype {item_ty.dtype}")
4890+
sparse_dims.append(dim)
4891+
tile_shape.append(item_ty.shape[0])
4892+
gs_index.append(item)
4893+
elif isinstance(item_ty, IndexSliceTy):
4894+
length_var = item.get_aggregate().length
4895+
if not length_var.is_constant():
4896+
raise TileTypeError(
4897+
f"ct.Slice length at dim {dim} must be a compile-time constant "
4898+
f"in load_advanced/store_advanced")
4899+
length_val = length_var.get_constant()
4900+
if not isinstance(length_val, int) or length_val <= 0:
4901+
raise TileTypeError(
4902+
f"ct.Slice length at dim {dim} must be a positive integer, got {length_val}")
4903+
tile_shape.append(length_val)
4904+
gs_index.append(item.get_aggregate().start)
4905+
else:
4906+
raise TileTypeError(
4907+
f"load_advanced/store_advanced index at dim {dim} must be a "
4908+
f"1D integer Tile (sparse dim) or ct.Slice(start, length) "
4909+
f"(dense dim), got type {item_ty}")
4910+
4911+
if len(sparse_dims) == 0:
4912+
raise TileTypeError(
4913+
"load_advanced/store_advanced: exactly one index must be a 1D "
4914+
"integer Tile (the sparse dim); none found")
4915+
if len(sparse_dims) > 1:
4916+
raise TileTypeError(
4917+
f"load_advanced/store_advanced: exactly one index must be a 1D "
4918+
f"integer Tile (the sparse dim); found {len(sparse_dims)} at "
4919+
f"dims {sparse_dims}")
4920+
4921+
for dim, n in enumerate(tile_shape):
4922+
if not _is_power_of_2(n):
4923+
raise TileTypeError(
4924+
f"Index at dim {dim} has size {n}; must be a power of two")
4925+
4926+
return sparse_dims[0], tuple(tile_shape), tuple(gs_index)
4927+
4928+
4929+
@impl(ct.load_advanced, min_version=BytecodeVersion.V_13_3)
4930+
def load_advanced_impl(array: Var, indices: Var, padding_mode: Var,
4931+
latency: Var, allow_tma: Var) -> Var:
4932+
array_ty = require_array_type(array)
4933+
if array_ty.ndim < 2:
4934+
raise TileTypeError(
4935+
"load_advanced requires a 2D or higher-rank array; "
4936+
"use ct.gather() for 1D arrays")
4937+
sparse_dim, tile_shape, gs_index = _parse_advanced_index(indices, array_ty.ndim)
4938+
padding_mode_val = require_constant_enum(padding_mode, PaddingMode)
4939+
latency_val = require_optional_constant_int(latency)
4940+
allow_tma_val = require_optional_constant_bool(allow_tma)
4941+
_check_load_store_hints(latency_val, allow_tma_val)
4942+
4943+
view = make_gather_scatter_view(array, tile_shape, sparse_dim, padding_mode_val)
4944+
result, _token = add_operation(TileLoad, (make_tile_ty(array_ty.dtype, tile_shape), TokenTy()),
4945+
view=view, index=gs_index,
4946+
latency=latency_val, allow_tma=allow_tma_val)
4947+
return result
4948+
4949+
4950+
@impl(ct.store_advanced, min_version=BytecodeVersion.V_13_3)
4951+
def store_advanced_impl(array: Var, indices: Var, tile: Var,
4952+
latency: Var, allow_tma: Var):
4953+
array_ty = require_array_type(array)
4954+
if array_ty.ndim < 2:
4955+
raise TileTypeError(
4956+
"store_advanced requires a 2D or higher-rank array; "
4957+
"use ct.scatter() for 1D arrays")
4958+
sparse_dim, tile_shape, gs_index = _parse_advanced_index(indices, array_ty.ndim)
4959+
tile_ty = require_tile_type(tile)
4960+
if tile_ty.shape != tile_shape:
4961+
raise TileTypeError(
4962+
f"Tile shape {tile_ty.shape} does not match the shape implied by "
4963+
f"indices {tile_shape}")
4964+
tile = _implicit_cast(tile, array_ty.dtype,
4965+
"Stored tile dtype is incompatible with array dtype")
4966+
latency_val = require_optional_constant_int(latency)
4967+
allow_tma_val = require_optional_constant_bool(allow_tma)
4968+
_check_load_store_hints(latency_val, allow_tma_val)
4969+
4970+
view = make_gather_scatter_view(array, tile_shape, sparse_dim, PaddingMode.UNDETERMINED)
4971+
[_token] = add_operation(TileStore, (TokenTy(),),
4972+
view=view, index=gs_index, tile=tile,
4973+
latency=latency_val, allow_tma=allow_tma_val)
4974+
4975+
48294976
def store_var(local_idx: int, value: Var, loc: Loc | None = None):
48304977
scope = Scope.get_current()
48314978
new_var = scope.local.redefine(local_idx, loc or Builder.get_current().loc)

src/cuda/tile/_ir/type.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,58 @@ def __str__(self):
450450
f"padding_mode={self.padding_mode}]")
451451

452452

453+
# ============== GatherScatterView Type ===============
454+
455+
456+
@dataclass(frozen=True)
457+
class GatherScatterViewTy(Type):
458+
array_ty: ArrayTy
459+
tile_shape: tuple[int, ...]
460+
sparse_dim: int
461+
padding_mode: PaddingMode
462+
463+
def is_aggregate(self) -> bool:
464+
return True
465+
466+
def aggregate_item_types(self) -> tuple["Type", ...]:
467+
return self.array_ty.aggregate_item_types()
468+
469+
def make_aggregate_value(self, items: tuple["Var", ...]) -> "AggregateValue":
470+
return self.array_ty.make_aggregate_value(items)
471+
472+
@property
473+
def dtype(self):
474+
return self.array_ty.dtype
475+
476+
def __str__(self):
477+
return (f"GatherScatterView[{self.array_ty},tile_shape={self.tile_shape},"
478+
f"sparse_dim={self.sparse_dim},padding_mode={self.padding_mode}]")
479+
480+
481+
# ============== IndexSlice Type ===============
482+
483+
484+
@dataclass(frozen=True)
485+
class IndexSliceTy(Type):
486+
"""Type of a ct.Slice(start, length)."""
487+
start_ty: "Type"
488+
length_ty: "Type"
489+
490+
def is_aggregate(self) -> bool:
491+
return True
492+
493+
def aggregate_item_types(self) -> tuple["Type", ...]:
494+
return (self.start_ty, self.length_ty)
495+
496+
def make_aggregate_value(self, items: tuple["Var", ...]) -> "AggregateValue":
497+
from .ir import IndexSliceValue
498+
assert len(items) == 2
499+
return IndexSliceValue(items[0], items[1])
500+
501+
def __str__(self) -> str:
502+
return f"IndexSlice[start_ty={self.start_ty}, length_ty={self.length_ty}]"
503+
504+
453505
# ============== TiledView Type ===============
454506

455507
@dataclass(frozen=True)

src/cuda/tile/_ir/typing_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def get_signature(f) -> inspect.Signature:
118118

119119

120120
def is_supported_builtin_func(x: Any) -> bool:
121-
return _safe_get(BUILTIN_FUNCS, x) is not None
121+
return _safe_get(BUILTIN_FUNCS, x) is not None or getattr(x, '_cutile_is_builtin', False)
122122

123123

124124
def typeof_pyval(val) -> Type:

src/cuda/tile/_ir2bytecode.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
get_default_rounding_mode,
2121
)
2222
from cuda.tile._ir.type import (
23-
PartitionViewTy, StridedViewTy, Type, TileTy, PointerTy, TokenTy, TupleTy, ArrayTy,
24-
size_to_bytecode,
23+
PartitionViewTy, StridedViewTy, GatherScatterViewTy, Type, TileTy, PointerTy, TokenTy,
24+
TupleTy, ArrayTy, size_to_bytecode,
2525
)
2626

2727

@@ -61,6 +61,11 @@ def typeid(tt: bc.TypeTable, ty: Type) -> bc.TypeId:
6161
ty.order, padding_value)
6262
else:
6363
return tt.partition_view(ty.tile_shape, tv_id, ty.order, padding_value)
64+
elif isinstance(ty, GatherScatterViewTy):
65+
padding_value = padding_mode_to_bytecode[ty.padding_mode]
66+
assert isinstance(ty.array_ty, ArrayTy)
67+
tv_id = tensor_view_typeid(tt, ty.array_ty)
68+
return tt.gather_scatter_view(ty.tile_shape, tv_id, ty.sparse_dim, padding_value)
6469
else:
6570
raise NotImplementedError(f"Lowering type '{ty}' is not supported")
6671

0 commit comments

Comments
 (0)