Skip to content

Commit 0d172bf

Browse files
committed
Expose use_fast_acc for fp8 mma
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 398bd57 commit 0d172bf

4 files changed

Lines changed: 65 additions & 17 deletions

File tree

changelog.d/mma-fast-acc.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Add `use_fast_acc` keyword argument to `ct.mma()` to enable fast accumulation
5+
mode for fp8 inputs (`float8_e4m3fn`, `float8_e5m2`) on Hopper GPUs.

src/cuda/tile/_ir/ops.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from cuda.tile import RoundingMode, MemoryOrder, MemoryScope
1919
from cuda.tile._mutex import tile_mutex
2020
from cuda.tile._exception import TileTypeError, TileSyntaxError, TileError, \
21-
TileStaticAssertionError, TileStaticEvalError, TileValueError
21+
TileStaticAssertionError, TileStaticEvalError, TileValueError, TileUnsupportedFeatureError
2222
from cuda.tile._ir.ir import (
2323
Operation, Var, Loc, Block,
2424
add_operation, Builder,
@@ -3224,6 +3224,7 @@ def _matmul_broadcast_shape(x_shape: _TileShape, y_shape: _TileShape) -> \
32243224

32253225
@dataclass(eq=False)
32263226
class TileMma(Operation, opcode="tile_mma"):
3227+
use_fast_acc: bool = attribute(default=False)
32273228
x: Var = operand()
32283229
y: Var = operand()
32293230
acc: Var = operand()
@@ -3243,13 +3244,13 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
32433244
return bc.encode_MmaIOp(ctx.builder, res_typeid, x_value, y_value,
32443245
acc_value, signedness_lhs, signedness_rhs)
32453246
else:
3246-
# TODO: consider expose fast_acc
32473247
return bc.encode_MmaFOp(ctx.builder, res_typeid, x_value, y_value,
3248-
acc_value, fast_acc=False)
3248+
acc_value, fast_acc=self.use_fast_acc)
32493249

32503250

32513251
@impl(ct.mma)
3252-
def mma_impl(x: Var, y: Var, acc: Var) -> Var:
3252+
def mma_impl(x: Var, y: Var, acc: Var, use_fast_acc: Var) -> Var:
3253+
use_fast_acc = require_constant_bool(use_fast_acc)
32533254
x_tile_type = require_tile_type(x)
32543255
y_tile_type = require_tile_type(y)
32553256
acc_tile_type = require_tile_type(acc)
@@ -3263,10 +3264,21 @@ def mma_impl(x: Var, y: Var, acc: Var) -> Var:
32633264
x_shape, y_shape, _, output_shape = _matmul_broadcast_shape(x_shape_orig, y_shape_orig)
32643265
if acc_shape_orig != output_shape:
32653266
raise TileTypeError(f'Expect acc shape to be {output_shape}, got {acc_shape_orig}')
3267+
if use_fast_acc:
3268+
if x_tile_type.dtype not in (datatype.float8_e4m3fn, datatype.float8_e5m2):
3269+
raise TileTypeError(
3270+
f'use_fast_acc is only supported for fp8 input dtypes '
3271+
f'(float8_e4m3fn, float8_e5m2), got {x_tile_type.dtype}')
3272+
cur_version = Builder.get_current().ir_ctx.tileiras_version
3273+
if cur_version < BytecodeVersion.V_13_3:
3274+
raise TileUnsupportedFeatureError(
3275+
f'use_fast_acc requires tileiras '
3276+
f'{BytecodeVersion.V_13_3.as_string()} or later. '
3277+
f'Current version is {cur_version.as_string()}.')
32663278
datatype._resolve_mma_supported_dtype(x_tile_type.dtype, y_tile_type.dtype, acc_tile_type.dtype)
32673279
x = _promote_and_broadcast_to(x, TileTy(x_tile_type.dtype, x_shape))
32683280
y = _promote_and_broadcast_to(y, TileTy(y_tile_type.dtype, y_shape))
3269-
return add_operation(TileMma, acc_tile_type, x=x, y=y, acc=acc)
3281+
return add_operation(TileMma, acc_tile_type, use_fast_acc=use_fast_acc, x=x, y=y, acc=acc)
32703282

32713283

32723284
@impl(ct.matmul)

src/cuda/tile/_stub.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1701,7 +1701,7 @@ def zeros(shape, dtype) -> Tile:
17011701

17021702

17031703
@stub
1704-
def mma(x, y, /, acc) -> Tile:
1704+
def mma(x, y, /, acc, *, use_fast_acc: bool = False) -> Tile:
17051705
"""Matrix multiply-accumulate.
17061706
17071707
Computes ``(x @ y) + acc`` as a single operation
@@ -1712,6 +1712,11 @@ def mma(x, y, /, acc) -> Tile:
17121712
x (Tile): LHS of the mma, 2D or 3D.
17131713
y (Tile): RHS of the mma, 2D or 3D.
17141714
acc (Tile): Accumulator of mma.
1715+
use_fast_acc (bool): Enable fast accumulation mode, which trades accumulator
1716+
precision for throughput. Requires fp8 input dtypes
1717+
(``float8_e4m3fn`` or ``float8_e5m2``). Currently only has an effect on
1718+
Hopper GPUs; silently ignored on other architectures. Default: ``False``
1719+
(since CTK 13.3).
17151720
17161721
Supported datatypes:
17171722

test/test_mma.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from util import (
1212
assert_close, assert_equal, require_hopper_or_newer, torch_to_tf32, is_ampere_or_ada
1313
)
14-
from conftest import dtype_id
14+
from conftest import dtype_id, get_tileiras_version
15+
from cuda.tile._bytecode.version import BytecodeVersion
1516
from cuda.tile._exception import TileTypeError, TileUnsupportedFeatureError
1617

1718

@@ -21,11 +22,12 @@
2122
def mma_kernel(A, B, C,
2223
tm: ct.Constant[int],
2324
tn: ct.Constant[int],
24-
tk: ct.Constant[int]):
25+
tk: ct.Constant[int],
26+
use_fast_acc: ct.Constant[bool]):
2527
tx = ct.load(A, index=(0, 0), shape=(tm, tk))
2628
ty = ct.load(B, index=(0, 0), shape=(tk, tn))
2729
acc = ct.load(C, index=(0, 0), shape=(tm, tn))
28-
acc = ct.mma(tx, ty, acc)
30+
acc = ct.mma(tx, ty, acc, use_fast_acc=use_fast_acc)
2931
ct.store(C, index=(0, 0), tile=acc)
3032

3133

@@ -110,32 +112,56 @@ def test_mma_regular_float(tile_size, case):
110112
C = torch.ones((m, n), dtype=case.acc_dtype, device="cuda")
111113
ref = torch.mm(A, B, out_dtype=C.dtype) + C
112114
ct.launch(torch.cuda.current_stream(), (1,), mma_kernel,
113-
(A, B, C, m, n, k))
115+
(A, B, C, m, n, k, False))
114116
atol, rtol = get_tolerance(A.dtype)
115117
assert_close(C, ref, atol=atol, rtol=rtol)
116118

117119

120+
@ct.kernel
121+
def mma_fast_acc_kernel(A, B, C,
122+
tm: ct.Constant[int],
123+
tn: ct.Constant[int],
124+
tk: ct.Constant[int]):
125+
tx = ct.load(A, index=(0, 0), shape=(tm, tk))
126+
ty = ct.load(B, index=(0, 0), shape=(tk, tn))
127+
acc = ct.load(C, index=(0, 0), shape=(tm, tn))
128+
acc = ct.mma(tx, ty, acc, use_fast_acc=True)
129+
ct.store(C, index=(0, 0), tile=acc)
130+
131+
118132
@require_hopper_or_newer()
119133
@pytest.mark.parametrize("tile_size", [(16, 16, 16)])
120134
@pytest.mark.parametrize("case", fp8_cases, ids=str)
121-
def test_mma_fp8(tile_size, case):
135+
@pytest.mark.parametrize("use_fast_acc", [True, False])
136+
def test_mma_fp8(tile_size, case, use_fast_acc):
137+
if use_fast_acc and get_tileiras_version() < BytecodeVersion.V_13_3:
138+
pytest.skip("use_fast_acc requires tileiras 13.3")
122139
m, n, k = tile_size
123140
A = torch.randn((m, k), dtype=torch.float32, device="cuda").to(case.dtype)
124141
B = torch.randn((n, k), dtype=torch.float32, device="cuda").to(case.dtype)
125142
C = torch.ones((m, n), dtype=case.acc_dtype, device="cuda")
126143
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
127144
try:
128-
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=C.dtype) + C
145+
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=C.dtype,
146+
use_fast_accum=use_fast_acc) + C
129147
except (RuntimeError, ValueError) as e:
130148
assert 'Multiplication of two Float8_e5m2 matrices is not supported' in str(e)
131149
ref = None
132150
ct.launch(torch.cuda.current_stream(), (1,), mma_kernel,
133-
(A, B.T, C, m, n, k))
151+
(A, B.T, C, m, n, k, use_fast_acc))
134152
if ref is not None:
135153
atol, rtol = get_tolerance(A.dtype)
136154
assert_close(C, ref, atol=atol, rtol=rtol)
137155

138156

157+
def test_mma_fast_acc_non_fp8_error():
158+
A = torch.randn((2, 4), dtype=torch.float16, device="cuda")
159+
B = torch.randn((4, 2), dtype=torch.float16, device="cuda")
160+
C = torch.zeros((2, 2), dtype=torch.float16, device="cuda")
161+
with pytest.raises(TileTypeError, match="use_fast_acc is only supported for fp8"):
162+
ct.launch(torch.cuda.current_stream(), (1,), mma_fast_acc_kernel, (A, B, C, 2, 2, 4))
163+
164+
139165
@pytest.mark.parametrize("tile_size", [(8, 2, 4)])
140166
def test_mma_tf32(tile_size):
141167
m, n, k = tile_size
@@ -163,7 +189,7 @@ def test_mma_int(tile_size, case):
163189
C = torch.ones((m, n), dtype=case.acc_dtype, device="cuda")
164190
ref = C + (A.to(torch.float32) @ B.to(torch.float32)).to(C.dtype)
165191
ct.launch(torch.cuda.current_stream(), (1,), mma_kernel,
166-
(A, B, C, m, n, k))
192+
(A, B, C, m, n, k, False))
167193
assert_equal(C, ref)
168194

169195

@@ -175,7 +201,7 @@ def test_mma_mixed_int_uint(tile_size):
175201
C = torch.ones((m, n), dtype=torch.int32, device="cuda")
176202
ref = C + (A.to(torch.float32) @ B.to(torch.float32)).to(C.dtype)
177203
ct.launch(torch.cuda.current_stream(), (1,), mma_kernel,
178-
(A, B, C, m, n, k))
204+
(A, B, C, m, n, k, False))
179205
assert_equal(C, ref)
180206

181207

@@ -229,7 +255,7 @@ def test_mma_dtype_error(case):
229255
with pytest.raises(TileTypeError, match=case.message):
230256
ct.launch(torch.cuda.current_stream(),
231257
(1,), mma_kernel,
232-
(A, B, C, 2, 2, 2))
258+
(A, B, C, 2, 2, 2, False))
233259

234260
# ================ ct.matmul =================
235261

@@ -405,4 +431,4 @@ def test_ampere_fp8_error(dtype):
405431
with pytest.raises(TileUnsupportedFeatureError,
406432
match="is not supported on sm_80"):
407433
ct.launch(torch.cuda.current_stream(), (1,), mma_kernel,
408-
(A, B, C, 16, 16, 16))
434+
(A, B, C, 16, 16, 16, False))

0 commit comments

Comments
 (0)