Skip to content

Commit b2d3f82

Browse files
committed
exp rounding mode
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent ce9f125 commit b2d3f82

6 files changed

Lines changed: 81 additions & 11 deletions

File tree

changelog.d/exp-rounding-mode.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added an optional `rounding_mode` keyword argument to `ct.exp()`. Supported values are `RoundingMode.FULL` and `RoundingMode.APPROX` (both float32 only); `RoundingMode.APPROX` requires CTK 13.3 or newer.

src/cuda/tile/_ir/ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,8 +1522,7 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
15221522
case "neg", False: return bc.encode_NegIOp(ctx.builder, res_type_id, x,
15231523
bc.IntegerOverflow.NONE)
15241524
case "exp", True: return bc.encode_ExpOp(ctx.builder, res_type_id, x,
1525-
# TODO: expose rounding mode in ct.exp
1526-
rounding_mode=bc.RoundingMode.FULL)
1525+
rounding_mode=rounding_mode)
15271526
case "exp2", True: return bc.encode_Exp2Op(ctx.builder, res_type_id, x,
15281527
flush_to_zero=flush_to_zero)
15291528
case "sin", True: return bc.encode_SinOp(ctx.builder, res_type_id, x)
@@ -1653,7 +1652,6 @@ def pos_impl(x: Var):
16531652
@impl(ct.sinh, fixed_args=["sinh", _UNARY_FLOAT])
16541653
@impl(ct.cos, fixed_args=["cos", _UNARY_FLOAT])
16551654
@impl(ct.cosh, fixed_args=["cosh", _UNARY_FLOAT])
1656-
@impl(ct.exp, fixed_args=["exp", _UNARY_FLOAT])
16571655
@impl(ct.bitwise_not, fixed_args=["bitwise_not", _UNARY_BOOL_INT])
16581656
@impl(ct.floor, fixed_args=["floor", _UNARY_STRICT_FLOAT])
16591657
@impl(ct.ceil, fixed_args=["ceil", _UNARY_STRICT_FLOAT])
@@ -1682,6 +1680,7 @@ def unary_impl_with_rd_and_ftz(fn: str, behavior: _UnaryBehavior,
16821680

16831681

16841682
@impl(ct.tanh, fixed_args=["tanh", _UNARY_FLOAT])
1683+
@impl(ct.exp, fixed_args=["exp", _UNARY_FLOAT])
16851684
def unary_impl_with_rd(fn: str, behavior: _UnaryBehavior, x: Var, rounding_mode: Var) -> Var:
16861685
rounding_mode = require_optional_constant_enum(rounding_mode, RoundingMode)
16871686
return unary(fn, behavior, x, rounding_mode=rounding_mode)

src/cuda/tile/_ir/ops_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class MathOpDef:
4545
_RD_TRUEDIV = {**_RD_BASIC, RoundingMode.FULL: None, RoundingMode.APPROX: None}
4646
_RD_SQRT = {**_RD_BASIC, RoundingMode.APPROX: None}
4747
_RD_TANH = {RoundingMode.FULL: None, RoundingMode.APPROX: BytecodeVersion.V_13_2}
48+
_RD_EXP = {RoundingMode.FULL: None, RoundingMode.APPROX: BytecodeVersion.V_13_3}
4849

4950
BINOP_REGISTRY = {
5051
"add": MathOpDef(lambda x, y: x + y, _RD_BASIC, support_flush_to_zero=True),
@@ -91,7 +92,7 @@ def _invert(x: int | bool, bool_action: Literal['raise'] | Literal['not']):
9192
UNARYOP_REGISTRY = {
9293
"abs": MathOpDef(abs),
9394
"neg": MathOpDef(lambda x: -x),
94-
"exp": MathOpDef(math.exp),
95+
"exp": MathOpDef(math.exp, _RD_EXP),
9596
"exp2": MathOpDef(lambda x: 2 ** x, support_flush_to_zero=True),
9697
"sin": MathOpDef(math.sin),
9798
"sinh": MathOpDef(math.sinh),
@@ -113,7 +114,7 @@ def _invert(x: int | bool, bool_action: Literal['raise'] | Literal['not']):
113114

114115

115116
def get_default_rounding_mode(opname: Optional[str] = None):
116-
return RoundingMode.FULL if opname == 'tanh' else RoundingMode.RN
117+
return RoundingMode.FULL if opname in ('tanh', 'exp') else RoundingMode.RN
117118

118119

119120
rounding_mode_to_bytecode = {

src/cuda/tile/_stub.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3255,10 +3255,23 @@ def wrapped(*args, **kwargs):
32553255
return wrapped
32563256

32573257

3258-
@_doc_unary_op
32593258
@stub
3260-
def exp(x, /) -> TileOrScalar:
3259+
def exp(x, /, *, rounding_mode: Optional[RoundingMode] = None) -> TileOrScalar:
32613260
"""
3261+
Perform `exp` on a tile.
3262+
3263+
Args:
3264+
x (Tile):
3265+
rounding_mode (RoundingMode): Supported values:
3266+
3267+
- ``RoundingMode.FULL`` (f32 only)
3268+
- ``RoundingMode.APPROX`` (f32 only)
3269+
3270+
``RoundingMode.APPROX`` requires CTK 13.3 or newer.
3271+
3272+
Returns:
3273+
Tile:
3274+
32623275
Examples:
32633276
32643277
.. testcode::
@@ -3474,8 +3487,10 @@ def tanh(x, /, *, rounding_mode: Optional[RoundingMode] = None) -> TileOrScalar:
34743487
x (Tile):
34753488
rounding_mode (RoundingMode): Supported values:
34763489
3477-
- ``RoundingMode.FULL``
3478-
- ``RoundingMode.APPROX`` (since CTK 13.2)
3490+
- ``RoundingMode.FULL`` (f32 only)
3491+
- ``RoundingMode.APPROX`` (f32 only)
3492+
3493+
``RoundingMode.APPROX`` requires CTK 13.2 or newer.
34793494
34803495
Returns:
34813496
Tile:

test/test_bytecode_version_compat.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,22 @@ def kernel(x, y):
5353

5454
# Should not raise version error
5555
compile_with_version(kernel, (tensor(), tensor()), "13.1")
56+
57+
58+
def test_exp_rounding_mode_requires_13_3():
59+
def kernel(x, y):
60+
tx = ct.load(x, 0, shape=64)
61+
ct.store(y, 0, tile=ct.exp(tx, rounding_mode=RoundingMode.APPROX))
62+
63+
with pytest.raises(TileUnsupportedFeatureError,
64+
match=r"exp rounding_mode=approx requires tileiras 13\.3"):
65+
compile_with_version(kernel, (tensor(), tensor()), "13.2")
66+
67+
68+
def test_exp_without_rounding_mode_works_on_13_1():
69+
def kernel(x, y):
70+
tx = ct.load(x, 0, shape=64)
71+
ct.store(y, 0, tile=ct.exp(tx))
72+
73+
# Should not raise version error
74+
compile_with_version(kernel, (tensor(), tensor()), "13.1")

test/test_unary_elementwise.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def test_scalar_rounding(shape, tile, is_constant, dtype, op, tmp_path):
336336
@requires_tileiras(BytecodeVersion.V_13_2)
337337
@pytest.mark.use_mlir
338338
@pytest.mark.parametrize("dtype", float_dtypes, ids=dtype_id)
339-
@pytest.mark.parametrize("rounding_mode", [RMd.FULL, RMd.APPROX])
339+
@pytest.mark.parametrize("rounding_mode", [RMd.FULL, RMd.APPROX, None])
340340
def test_array_tanh_rounding_mode(shape, tile, dtype, rounding_mode, tmp_path):
341341
should_raise_dtype = rounding_mode in [RMd.FULL, RMd.APPROX] and dtype != torch.float32
342342
x = make_tensor(shape, dtype=dtype, device='cuda')
@@ -353,7 +353,7 @@ def test_array_tanh_rounding_mode(shape, tile, dtype, rounding_mode, tmp_path):
353353
launch_unary(kernel, x, y, tile)
354354
else:
355355
bytecode = get_bytecode(kernel, (x, y, tile))
356-
if rounding_mode is RMd.FULL:
356+
if rounding_mode in (RMd.FULL, None):
357357
# FULL is the default, not included in mlir text
358358
check_directive = "// CHECK: %[[RES:.*]] = tanh %[[A:.*]]{{[[:space:]]*}}:"
359359
else:
@@ -363,3 +363,35 @@ def test_array_tanh_rounding_mode(shape, tile, dtype, rounding_mode, tmp_path):
363363
filecheck(bytecode, check_directive)
364364
launch_unary(kernel, x, y, tile)
365365
assert_close(y, y_ref)
366+
367+
368+
@requires_tileiras(BytecodeVersion.V_13_3)
369+
@pytest.mark.use_mlir
370+
@pytest.mark.parametrize("dtype", float_dtypes, ids=dtype_id)
371+
@pytest.mark.parametrize("rounding_mode", [RMd.FULL, RMd.APPROX, None])
372+
def test_array_exp_rounding_mode(shape, tile, dtype, rounding_mode, tmp_path):
373+
should_raise_dtype = rounding_mode in [RMd.FULL, RMd.APPROX] and dtype != torch.float32
374+
x = make_tensor(shape, dtype=dtype, device='cuda')
375+
y_ref = torch.exp(x)
376+
y = torch.zeros_like(y_ref, device="cuda")
377+
kernel = array_kernel("exp_rounding_mode",
378+
f"ty = ct.exp(tx, rounding_mode={rounding_mode})",
379+
tmp_path,
380+
globals={"RoundingMode": RMd})
381+
if should_raise_dtype:
382+
with pytest.raises(TileTypeError,
383+
match=fr"Rounding mode {rounding_mode.value} can only be used for "
384+
"float32 type"):
385+
launch_unary(kernel, x, y, tile)
386+
else:
387+
bytecode = get_bytecode(kernel, (x, y, tile))
388+
if rounding_mode in (RMd.FULL, None):
389+
# FULL is the default, not included in mlir text
390+
check_directive = "// CHECK: %[[RES:.*]] = exp %[[A:.*]]{{[[:space:]]*}}:"
391+
else:
392+
check_directive = (
393+
f"// CHECK: %[[RES:.*]] = exp %[[A:.*]] rounding<{rounding_mode.value}>"
394+
)
395+
filecheck(bytecode, check_directive)
396+
launch_unary(kernel, x, y, tile)
397+
assert_close(y, y_ref)

0 commit comments

Comments
 (0)