Skip to content

Commit d10a5da

Browse files
committed
Rename load/store_advanced to load/store_advanced_indexing
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 9229551 commit d10a5da

6 files changed

Lines changed: 59 additions & 59 deletions

File tree

changelog.d/load-store-advanced.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
22
<!--- SPDX-License-Identifier: Apache-2.0 -->
33

4-
- New `ct.load_advanced(array, indices)` and `ct.store_advanced(array, indices, tile)` for gathering/scattering along one dimension from/to a 2D or higher-rank array. A 1D integer `Tile` selects the sparse dimension; `ct.Slice(start, length)` selects a contiguous range along each dense dimension.
4+
- New `ct.load_advanced_indexing(array, indices)` and `ct.store_advanced_indexing(array, indices, tile)` for gathering/scattering along one dimension from/to a 2D or higher-rank array. A 1D integer `Tile` selects the sparse dimension; `ct.Slice(start, length)` selects a contiguous range along each dense dimension.

docs/source/operations.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ Load/Store
2222
num_tiles
2323
load
2424
store
25-
load_advanced
26-
store_advanced
25+
load_advanced_indexing
26+
store_advanced_indexing
2727
gather
2828
scatter
2929

src/cuda/tile/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
less,
116116
less_equal,
117117
load,
118-
load_advanced,
118+
load_advanced_indexing,
119119
log,
120120
log2,
121121
matmul,
@@ -150,7 +150,7 @@
150150
static_eval,
151151
static_iter,
152152
store,
153-
store_advanced,
153+
store_advanced_indexing,
154154
sub,
155155
sum,
156156
tan,
@@ -276,7 +276,7 @@
276276
"less",
277277
"less_equal",
278278
"load",
279-
"load_advanced",
279+
"load_advanced_indexing",
280280
"log",
281281
"log2",
282282
"matmul",
@@ -311,7 +311,7 @@
311311
"static_eval",
312312
"static_iter",
313313
"store",
314-
"store_advanced",
314+
"store_advanced_indexing",
315315
"sub",
316316
"sum",
317317
"tan",

src/cuda/tile/_ir/ops.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4875,7 +4875,7 @@ def _parse_advanced_index(indices: Var, ndim: int) -> tuple[int, tuple[int, ...]
48754875
items = list(indices.get_aggregate().items)
48764876
if len(items) != ndim:
48774877
raise TileTypeError(
4878-
f"load_advanced/store_advanced index length {len(items)} does not "
4878+
f"load_advanced_indexing/store_advanced_indexing index length {len(items)} does not "
48794879
f"match array rank {ndim}")
48804880

48814881
sparse_dims: list[int] = []
@@ -4901,7 +4901,7 @@ def _parse_advanced_index(indices: Var, ndim: int) -> tuple[int, tuple[int, ...]
49014901
if not length_var.is_constant():
49024902
raise TileTypeError(
49034903
f"ct.Slice length at dim {dim} must be a compile-time constant "
4904-
f"in load_advanced/store_advanced")
4904+
f"in load_advanced_indexing/store_advanced_indexing")
49054905
length_val = length_var.get_constant()
49064906
if not isinstance(length_val, int) or length_val <= 0:
49074907
raise TileTypeError(
@@ -4910,17 +4910,17 @@ def _parse_advanced_index(indices: Var, ndim: int) -> tuple[int, tuple[int, ...]
49104910
gs_index.append(item.get_aggregate().start)
49114911
else:
49124912
raise TileTypeError(
4913-
f"load_advanced/store_advanced index at dim {dim} must be a "
4913+
f"load_advanced_indexing/store_advanced_indexing index at dim {dim} must be a "
49144914
f"1D integer Tile (sparse dim) or ct.Slice(start, length) "
49154915
f"(dense dim), got type {item_ty}")
49164916

49174917
if len(sparse_dims) == 0:
49184918
raise TileTypeError(
4919-
"load_advanced/store_advanced: exactly one index must be a 1D "
4919+
"load_advanced_indexing/store_advanced_indexing: exactly one index must be a 1D "
49204920
"integer Tile (the sparse dim); none found")
49214921
if len(sparse_dims) > 1:
49224922
raise TileTypeError(
4923-
f"load_advanced/store_advanced: exactly one index must be a 1D "
4923+
f"load_advanced_indexing/store_advanced_indexing: exactly one index must be a 1D "
49244924
f"integer Tile (the sparse dim); found {len(sparse_dims)} at "
49254925
f"dims {sparse_dims}")
49264926

@@ -4932,13 +4932,13 @@ def _parse_advanced_index(indices: Var, ndim: int) -> tuple[int, tuple[int, ...]
49324932
return sparse_dims[0], tuple(tile_shape), tuple(gs_index)
49334933

49344934

4935-
@impl(ct.load_advanced, min_version=BytecodeVersion.V_13_3)
4935+
@impl(ct.load_advanced_indexing, min_version=BytecodeVersion.V_13_3)
49364936
def load_advanced_impl(array: Var, indices: Var, padding_mode: Var,
49374937
latency: Var, allow_tma: Var) -> Var:
49384938
array_ty = require_array_type(array)
49394939
if array_ty.ndim < 2:
49404940
raise TileTypeError(
4941-
"load_advanced requires a 2D or higher-rank array; "
4941+
"load_advanced_indexing requires a 2D or higher-rank array; "
49424942
"use ct.gather() for 1D arrays")
49434943
sparse_dim, tile_shape, gs_index = _parse_advanced_index(indices, array_ty.ndim)
49444944
padding_mode_val = require_constant_enum(padding_mode, PaddingMode)
@@ -4953,13 +4953,13 @@ def load_advanced_impl(array: Var, indices: Var, padding_mode: Var,
49534953
return result
49544954

49554955

4956-
@impl(ct.store_advanced, min_version=BytecodeVersion.V_13_3)
4956+
@impl(ct.store_advanced_indexing, min_version=BytecodeVersion.V_13_3)
49574957
def store_advanced_impl(array: Var, indices: Var, tile: Var,
49584958
latency: Var, allow_tma: Var):
49594959
array_ty = require_array_type(array)
49604960
if array_ty.ndim < 2:
49614961
raise TileTypeError(
4962-
"store_advanced requires a 2D or higher-rank array; "
4962+
"store_advanced_indexing requires a 2D or higher-rank array; "
49634963
"use ct.scatter() for 1D arrays")
49644964
sparse_dim, tile_shape, gs_index = _parse_advanced_index(indices, array_ty.ndim)
49654965
tile_ty = require_tile_type(tile)

src/cuda/tile/_stub.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -948,11 +948,11 @@ class Slice:
948948
_cutile_is_builtin = True
949949
"""A start + length index for array dimensions.
950950
951-
Used as a dense-dimension entry in :func:`load_advanced` and
952-
:func:`store_advanced`. ``start`` is an integer giving the
951+
Used as a dense-dimension entry in :func:`load_advanced_indexing` and
952+
:func:`store_advanced_indexing`. ``start`` is an integer giving the
953953
element-space start offset along this dimension; ``length`` is the
954954
tile size, and must be a power of two and must be a
955-
compile-time constant at the :func:`load_advanced`/:func:`store_advanced`
955+
compile-time constant at the :func:`load_advanced_indexing`/:func:`store_advanced_indexing`
956956
call site.
957957
958958
Args:
@@ -1414,10 +1414,10 @@ def kernel(x):
14141414

14151415

14161416
@stub
1417-
def load_advanced(array: Array, indices, /, *,
1418-
padding_mode: PaddingMode = PaddingMode.UNDETERMINED,
1419-
latency: Optional[int] = None,
1420-
allow_tma: Optional[bool] = None) -> Tile:
1417+
def load_advanced_indexing(array: Array, indices, /, *,
1418+
padding_mode: PaddingMode = PaddingMode.UNDETERMINED,
1419+
latency: Optional[int] = None,
1420+
allow_tma: Optional[bool] = None) -> Tile:
14211421
"""Loads a tile from non-contiguous slices of `array`.
14221422
14231423
``indices`` is a tuple of length ``array.ndim``. Exactly one entry must
@@ -1458,7 +1458,7 @@ def load_advanced(array: Array, indices, /, *,
14581458
@ct.kernel
14591459
def kernel(x, y, col_start):
14601460
row_indices = ct.arange(4, dtype=ct.int32)
1461-
tile = ct.load_advanced(x, (row_indices, ct.Slice(col_start, 4)),
1461+
tile = ct.load_advanced_indexing(x, (row_indices, ct.Slice(col_start, 4)),
14621462
padding_mode=ct.PaddingMode.ZERO)
14631463
ct.store(y, (0, 0), tile)
14641464
@@ -1472,17 +1472,17 @@ def kernel(x, y, col_start):
14721472
[[2, 3, 4, 5], [10, 11, 12, 13], [18, 19, 20, 21], [26, 27, 28, 29]]
14731473
14741474
.. seealso::
1475-
- :func:`store_advanced`
1475+
- :func:`store_advanced_indexing`
14761476
"""
14771477

14781478

14791479
@stub
1480-
def store_advanced(array: Array, indices, tile: TileOrScalar, /, *,
1481-
latency: Optional[int] = None,
1482-
allow_tma: Optional[bool] = None) -> None:
1480+
def store_advanced_indexing(array: Array, indices, tile: TileOrScalar, /, *,
1481+
latency: Optional[int] = None,
1482+
allow_tma: Optional[bool] = None) -> None:
14831483
"""Stores a `tile` into non-contiguous slices of `array`.
14841484
1485-
Uses the same ``indices`` convention as :func:`load_advanced` — exactly
1485+
Uses the same ``indices`` convention as :func:`load_advanced_indexing` — exactly
14861486
one entry is a 1-D integer :class:`Tile` (sparse dim) and the rest are
14871487
:class:`Slice` objects (dense dims).
14881488
The tile's shape must exactly match the shape implied by the indices.
@@ -1491,7 +1491,7 @@ def store_advanced(array: Array, indices, tile: TileOrScalar, /, *,
14911491
14921492
Args:
14931493
array (Array): Array to store into.
1494-
indices (tuple): Same convention as :func:`load_advanced`.
1494+
indices (tuple): Same convention as :func:`load_advanced_indexing`.
14951495
tile (Tile): Tile to store. Shape must exactly match the shape
14961496
implied by ``indices``.
14971497
latency (int, optional): DRAM traffic hint (1 = low, 10 = high).
@@ -1506,7 +1506,7 @@ def store_advanced(array: Array, indices, tile: TileOrScalar, /, *,
15061506
def kernel(y):
15071507
row_indices = ct.arange(4, dtype=ct.int32) + 1
15081508
tile = ct.full((4, 4), 1, dtype=y.dtype)
1509-
ct.store_advanced(y, (row_indices, ct.Slice(0, 4)), tile)
1509+
ct.store_advanced_indexing(y, (row_indices, ct.Slice(0, 4)), tile)
15101510
15111511
y = torch.zeros(6, 4, device='cuda', dtype=torch.int32)
15121512
ct.launch(stream, (1,), kernel, (y,))
@@ -1517,7 +1517,7 @@ def kernel(y):
15171517
[[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]]
15181518
15191519
.. seealso::
1520-
- :func:`load_advanced`
1520+
- :func:`load_advanced_indexing`
15211521
"""
15221522

15231523

0 commit comments

Comments
 (0)