Skip to content

Commit e517b6d

Browse files
blinxt authored and haijieg committed
tiled view atomic ops
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent a55c8c2 commit e517b6d

14 files changed

Lines changed: 677 additions & 175 deletions

changelog.d/atomic-add-bf16.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- `ct.atomic_add()` now supports `bfloat16` operands on Hopper (sm_90) and
5+
newer architectures.

changelog.d/tiled-view-atomic.md

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- New `TiledView.atomic_add`, `TiledView.atomic_max`, `TiledView.atomic_min`,
5+
`TiledView.atomic_and`, `TiledView.atomic_or`, and `TiledView.atomic_xor`
6+
methods for performing element-wise atomic read-modify-write operations on
7+
a tiled view at a given tile index.

src/cuda/tile/_ir/op_impl.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -158,8 +158,8 @@ def _check_version():
158158
if min_version is not None and cur_version < min_version:
159159
raise TileUnsupportedFeatureError(
160160
f"{stub.__name__} requires tileiras "
161-
f"{min_version.major()}.{min_version.minor()} or later. "
162-
f"Current version is {cur_version.major()}.{cur_version.minor()}."
161+
f"{min_version.as_string()} or later. "
162+
f"Current version is {cur_version.as_string()}."
163163
)
164164

165165
def decorate(func):

src/cuda/tile/_ir/ops.py

Lines changed: 210 additions & 72 deletions
Large diffs are not rendered by default.

src/cuda/tile/_ir/ops_utils.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,8 +160,8 @@ def check_rd_and_ftz(fn: str, rounding_mode: Optional[RoundingMode], flush_to_ze
160160
if cur_version < min_version:
161161
raise TileUnsupportedFeatureError(
162162
f'{fn} rounding_mode={rounding_mode.value} requires tileiras '
163-
f'{min_version.major()}.{min_version.minor()} or later. '
164-
f'Current version is {cur_version.major()}.{cur_version.minor()}.')
163+
f'{min_version.as_string()} or later. '
164+
f'Current version is {cur_version.as_string()}.')
165165
if not datatype.is_unrestricted_float(dtype):
166166
raise TileTypeError(
167167
f'Rounding mode can only be used for unrestricted float types, '
@@ -362,10 +362,21 @@ def validate_memory_order_and_scope(
362362
f"Invalid memory order for {opcode}. "
363363
f"Got {memory_order}, expected one of {formatted_expected}"
364364
)
365+
366+
if memory_scope not in operation_type.VALID_MEMORY_SCOPES:
367+
formatted_expected = ", ".join(
368+
str(scope) for scope in operation_type.VALID_MEMORY_SCOPES
369+
)
370+
raise TileTypeError(
371+
f"Invalid memory scope for {opcode}. "
372+
f"Got {memory_scope}, expected one of {formatted_expected}"
373+
)
374+
365375
if memory_order == MemoryOrder.WEAK and memory_scope != MemoryScope.NONE:
366376
raise TileTypeError(
367377
f"{opcode} with WEAK memory ordering cannot specify a memory scope"
368378
)
379+
369380
if memory_order != MemoryOrder.WEAK and memory_scope == MemoryScope.NONE:
370381
raise TileTypeError(
371382
f"{opcode} with {memory_order.name} memory ordering requires a memory scope"

src/cuda/tile/_ir/type.py

Lines changed: 5 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -451,15 +451,12 @@ def __str__(self):
451451

452452
# ============== TiledView Type ===============
453453

454-
454+
@dataclass(frozen=True)
455455
class TiledViewTy(Type):
456-
def __init__(self, array_ty: ArrayTy, tile_shape: tuple[int, ...],
457-
padding_mode: PaddingMode,
458-
traversal_steps: Optional[tuple[int, ...]] = None):
459-
self.array_ty = array_ty
460-
self.tile_shape = tile_shape
461-
self.padding_mode = padding_mode
462-
self.traversal_steps = traversal_steps
456+
array_ty: ArrayTy
457+
tile_shape: tuple[int, ...]
458+
padding_mode: PaddingMode
459+
traversal_steps: Optional[tuple[int, ...]] = None
463460

464461
def is_aggregate(self) -> bool:
465462
return True
@@ -480,17 +477,6 @@ def ndim(self):
480477
def dtype(self):
481478
return self.array_ty.dtype
482479

483-
def __eq__(self, other: "Type"):
484-
return (isinstance(other, TiledViewTy)
485-
and self.array_ty == other.array_ty
486-
and self.tile_shape == other.tile_shape
487-
and self.padding_mode == other.padding_mode
488-
and self.traversal_steps == other.traversal_steps)
489-
490-
def __hash__(self):
491-
return hash(("TiledViewTy", self.array_ty, self.tile_shape,
492-
self.padding_mode, self.traversal_steps))
493-
494480
def __str__(self):
495481
return (f"TiledView[{self.array_ty},tile_shape={self.tile_shape},"
496482
f"padding_mode={self.padding_mode},traversal_steps={self.traversal_steps}]")

src/cuda/tile/_passes/check_dtype_support.py

Lines changed: 38 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,11 @@
55
import math
66
from typing import Callable
77
from cuda.tile._ir.ir import Block
8-
from cuda.tile._ir.ops import TypedConst
8+
from cuda.tile._ir.ops import TileAtomicRMW, TileAtomicRedView, TypedConst, AtomicRMWMode
99
from cuda.tile._ir.type import TileTy, PointerTy, Type
10-
from cuda.tile._datatype import DType, float4_e2m1fn, float8_e4m3fn, float8_e5m2, float8_e8m0fnu
10+
from cuda.tile._datatype import (
11+
DType, float4_e2m1fn, float8_e4m3fn, float8_e5m2, float8_e8m0fnu, bfloat16
12+
)
1113
from cuda.tile._bytecode.version import BytecodeVersion
1214
from cuda.tile._exception import TileUnsupportedFeatureError, TileValueError
1315

@@ -65,8 +67,33 @@ def _check_const_value(op: TypedConst):
6567
raise TileValueError(msg, loc=op.loc)
6668

6769

68-
def _check_dtype(dtype: DType, sm_arch: str, version: BytecodeVersion, loc):
69-
sm_number = int(sm_arch.removeprefix("sm_"))
70+
def _check_atomic_rmw_dtype(op: TileAtomicRedView | TileAtomicRMW,
71+
sm_arch: str,
72+
sm_number: int,
73+
version: BytecodeVersion):
74+
dtypes = (_extract_dtypes(op.view.try_get_type())
75+
if isinstance(op, TileAtomicRedView) else
76+
_extract_dtypes(op.result_vars[0].try_get_type()))
77+
if not (op.mode == AtomicRMWMode.ADD_FLOAT and bfloat16 in dtypes):
78+
return
79+
80+
if sm_number < 90:
81+
raise TileUnsupportedFeatureError(
82+
f"{bfloat16} is not supported by atomic add on {sm_arch}",
83+
loc=op.loc
84+
)
85+
86+
min_version = BytecodeVersion.V_13_3
87+
if version < min_version:
88+
raise TileUnsupportedFeatureError(
89+
f"{bfloat16} on atomic add requires tileiras"
90+
f" {min_version.as_string()} or later."
91+
f" Current version is {version.as_string()}.",
92+
loc=op.loc
93+
)
94+
95+
96+
def _check_dtype(dtype: DType, sm_arch: str, sm_number: int, version: BytecodeVersion, loc):
7097
min_sm = _DTYPE_MIN_SM.get(dtype)
7198
if min_sm is not None and sm_number < min_sm:
7299
raise TileUnsupportedFeatureError(
@@ -78,17 +105,21 @@ def _check_dtype(dtype: DType, sm_arch: str, version: BytecodeVersion, loc):
78105
if min_version is not None and version < min_version:
79106
raise TileUnsupportedFeatureError(
80107
f"{dtype} requires tileiras"
81-
f" {min_version.major()}.{min_version.minor()} or later."
82-
f" Current version is {version.major()}.{version.minor()}.",
108+
f" {min_version.as_string()} or later."
109+
f" Current version is {version.as_string()}.",
83110
loc=loc,
84111
)
85112

86113

87114
def check_dtype_support(root_block: Block, sm_arch: str, version: BytecodeVersion) -> None:
115+
sm_number = int(sm_arch.removeprefix("sm_"))
88116
for op in root_block.traverse():
89117
if isinstance(op, TypedConst):
90118
_check_const_value(op)
91119

120+
if isinstance(op, (TileAtomicRedView, TileAtomicRMW)):
121+
_check_atomic_rmw_dtype(op, sm_arch, sm_number, version)
122+
92123
all_dtypes = set().union(*(_extract_dtypes(v.try_get_type()) for v in op.all_inputs()))
93124
for dtype in all_dtypes:
94-
_check_dtype(dtype, sm_arch, version, op.loc)
125+
_check_dtype(dtype, sm_arch, sm_number, version, op.loc)

src/cuda/tile/_passes/token_order.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from cuda.tile._ir.ops import (
1616
Break, Continue, EndBranch, IfElse,
1717
JoinTokens, Loop, MakeToken,
18-
TileAtomicCAS, TileAtomicRMW, LoadPointer,
18+
TileAtomicCAS, TileAtomicRMW, LoadPointer, TileAtomicRedView,
1919
TileLoad, StorePointer,
2020
TileStore, TileAssert, TilePrintf,
2121
)
@@ -132,7 +132,8 @@ def get_memory_effects(cur_op):
132132
return EMPTY_MEMORY_EFFECTS
133133

134134
has_acquire_order = False
135-
if isinstance(cur_op, (TileAtomicCAS, TileAtomicRMW, TileLoad, TileStore)):
135+
if isinstance(cur_op, (TileAtomicCAS, TileAtomicRMW, TileAtomicRedView,
136+
TileLoad, TileStore)):
136137
has_acquire_order = memory_order_has_acquire(cur_op.memory_order)
137138

138139
return MemoryEffects({dataflow_result[_get_input_var(cur_op).name].alias_set: effect},
@@ -214,7 +215,7 @@ def _to_token_order_in_block(block: Block,
214215
token_map[last_op_key] = result_tok
215216
token_map[last_store_key] = result_tok
216217

217-
elif isinstance(op, (TileAtomicCAS, TileAtomicRMW)):
218+
elif isinstance(op, (TileAtomicCAS, TileAtomicRMW, TileAtomicRedView)):
218219
alias_set = context.dataflow_result[_get_input_var(op).name].alias_set
219220
last_op_key = _last_op_key(alias_set)
220221
last_store_key = _last_store_key(alias_set)
@@ -224,7 +225,11 @@ def _to_token_order_in_block(block: Block,
224225
if maybe_input_tok_join_op:
225226
operations.append(maybe_input_tok_join_op)
226227

227-
_, result_tok = op.result_vars
228+
if isinstance(op, TileAtomicRedView):
229+
[result_tok] = op.result_vars
230+
else:
231+
_, result_tok = op.result_vars
232+
228233
operations.append(dataclasses.replace(op, token=input_tok))
229234

230235
token_map[last_op_key] = result_tok
@@ -460,7 +465,7 @@ def _get_parallel_stores(
460465
alias_set_to_mem_ops = defaultdict(list)
461466
for op in loop_op.body.operations:
462467
if isinstance(op, (TileLoad, StorePointer, LoadPointer, TileStore,
463-
TileAtomicCAS, TileAtomicRMW)):
468+
TileAtomicCAS, TileAtomicRMW, TileAtomicRedView)):
464469
alias_set = context.dataflow_result[_get_input_var(op).name].alias_set
465470
alias_set_to_mem_ops[alias_set].append(op)
466471

src/cuda/tile/_stub.py

Lines changed: 107 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -707,6 +707,57 @@ def __rmatmul__(self, other) -> "Tile":
707707
TileOrScalar = Union[Tile, Scalar]
708708

709709

710+
def _doc_tv_atomic_rmw_op(*, testoutput: int):
711+
def decorator(f):
712+
op_name = f.__name__
713+
f.__doc__ += f"""\
714+
715+
This method does not return a value.
716+
717+
For each individual element, the operation is performed atomically,
718+
but the operation as a whole is not atomic, and the order of individual
719+
writes is unspecified.
720+
721+
`update`'s shape must be broadcastable to :attr:`tile_shape`.
722+
723+
For a tile that partially extends beyond the tiled view boundaries,
724+
out-of-bound elements are ignored.
725+
If the tile lies entirely outside the tiled view, the behavior is
726+
undefined.
727+
728+
Use this operation instead of ct.{op_name} for better performance
729+
when modified value is not needed.
730+
731+
Args:
732+
index (tuple[int,...]): An index in the |tiled view|'s tile space.
733+
update (Tile): The update values.
734+
735+
Returns:
736+
None:
737+
738+
Examples:
739+
740+
.. testcode::
741+
:template: setup_only.py
742+
743+
@ct.kernel
744+
def kernel(x):
745+
tv = x.tiled_view(4)
746+
update = ct.full((4,), 1, dtype=ct.int32)
747+
tv.{op_name}(0, update)
748+
749+
x = torch.zeros(4, dtype=torch.int32, device='cuda')
750+
ct.launch(stream, (1,), kernel, (x,))
751+
print(x.tolist())
752+
753+
.. testoutput::
754+
755+
{[testoutput] * 4}
756+
"""
757+
return f
758+
return decorator
759+
760+
710761
class TiledView:
711762
"""Class for |tiled view| objects."""
712763

@@ -836,6 +887,62 @@ def kernel(x):
836887
[99, 99, 99, 99, 0, 0, 0, 0]
837888
"""
838889

890+
@_doc_tv_atomic_rmw_op(testoutput=(0 + 1))
891+
@stub
892+
def atomic_add(self, index: Shape, update: Tile) -> None:
893+
"""Atomically adds `update` to the |tiled view| at the given tile `index`.
894+
895+
If `update`'s dtype differs from the view's dtype, an implicit cast is
896+
performed.
897+
"""
898+
899+
@_doc_tv_atomic_rmw_op(testoutput=max(0, 1))
900+
@stub
901+
def atomic_max(self, index: Shape, update: Tile) -> None:
902+
"""Atomically applies element-wise maximum with update to the |tiled view|
903+
at the given tile `index`.
904+
905+
If `update`'s dtype differs from the view's dtype, an implicit cast is
906+
performed.
907+
"""
908+
909+
@_doc_tv_atomic_rmw_op(testoutput=min(0, 1))
910+
@stub
911+
def atomic_min(self, index: Shape, update: Tile) -> None:
912+
"""Atomically applies element-wise minimum with update to the |tiled view|
913+
at the given tile `index`.
914+
915+
If `update`'s dtype differs from the view's dtype, an implicit cast is
916+
performed.
917+
"""
918+
919+
@_doc_tv_atomic_rmw_op(testoutput=(0 & 1))
920+
@stub
921+
def atomic_and(self, index: Shape, update: Tile) -> None:
922+
"""Atomically applies bitwise AND of `update` to the |tiled view| at the given tile `index`.
923+
924+
`update`'s dtype must exactly match the view's dtype; no implicit cast is
925+
performed.
926+
"""
927+
928+
@_doc_tv_atomic_rmw_op(testoutput=(0 | 1))
929+
@stub
930+
def atomic_or(self, index: Shape, update: Tile) -> None:
931+
"""Atomically applies bitwise OR of `update` to the |tiled view| at the given tile `index`.
932+
933+
`update`'s dtype must exactly match the view's dtype; no implicit cast is
934+
performed.
935+
"""
936+
937+
@_doc_tv_atomic_rmw_op(testoutput=(0 ^ 1))
938+
@stub
939+
def atomic_xor(self, index: Shape, update: Tile) -> None:
940+
"""Atomically applies bitwise XOR of `update` to the |tiled view| at the given tile `index`.
941+
942+
`update`'s dtype must exactly match the view's dtype; no implicit cast is
943+
performed.
944+
"""
945+
839946

840947
###############################################################################
841948
# Constantness Hints

0 commit comments

Comments
 (0)