Skip to content

Commit bfb2960

Browse files
committed
num_worker_warps hint
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 7e19e59 commit bfb2960

6 files changed

Lines changed: 75 additions & 5 deletions

File tree

src/cuda/tile/_bytecode/attribute.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,17 @@ def encode_dictionary_untagged(items: Sequence[tuple[str, TaggedAttribute]],
6262
class EntryHints:
6363
num_cta_in_cga: Optional[int] = None
6464
occupancy: Optional[int] = None
65+
num_worker_warps_per_cta: Optional[int] = None
6566

6667
def as_dictionary(self) -> Dictionary:
    """Encode the entry hints that are set as a ``Dictionary`` attribute.

    Hints whose value is ``None`` are left out. Every remaining hint is
    stored under its field name as an i32 ``Integer``, in the fixed order
    ``num_cta_in_cga``, ``occupancy``, ``num_worker_warps_per_cta``.
    """
    # (dictionary key, current value) pairs, in emission order.
    candidates = (
        ("num_cta_in_cga", self.num_cta_in_cga),
        ("occupancy", self.occupancy),
        ("num_worker_warps_per_cta", self.num_worker_warps_per_cta),
    )
    return Dictionary([
        (key, Integer.create_i32(value))
        for key, value in candidates
        if value is not None
    ])
7377

7478

src/cuda/tile/_compiler_options.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class CompilerOptions:
1616
num_ctas: None | int | ByTarget[int] = None
1717
occupancy: None | int | ByTarget[int] = None
1818
opt_level: int | ByTarget[int] = 3
19+
num_worker_warps: None | int | ByTarget[int] = None
1920

2021
def __post_init__(self):
2122
for field in dataclasses.fields(self):
@@ -59,3 +60,10 @@ def _validate_opt_level(opt_level: None | int):
5960
if opt_level is not None:
6061
if opt_level < 0 or opt_level > 3:
6162
raise ValueError(f'opt_level should be [0, 3], got {opt_level}')
63+
64+
65+
def _validate_num_worker_warps(num_worker_warps: None | int):
66+
if num_worker_warps is not None:
67+
if num_worker_warps not in (4, 8):
68+
raise ValueError(f'num_worker_warps should be either 4 or 8,'
69+
f' got {num_worker_warps}')

src/cuda/tile/_execution.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ class kernel(TileDispatcher):
7575
Default: None (auto).
7676
occupancy: Expected number of active CTAs per SM, [1, 32]. Default: None (auto).
7777
opt_level: Optimization level [0, 3], default 3.
78+
num_worker_warps: Number of warps in the CUDA core warp groups in a
79+
warp-specialized kernel. The compiler may add warps
80+
(e.g., for asynchronous memory transfers) that are not counted here.
81+
This value does not represent the total warp count.
82+
It's worth tuning when a warp-specialized kernel has high register pressure
83+
that other approaches cannot resolve.
84+
Normalization-style kernels with large tiles are the canonical cases.
85+
Must be either 4 or 8.
86+
Default: None (auto).
87+
Requires CTK 13.3 or later; on older versions the hint is ignored and a warning is emitted.
7888
7989
Target-specific values for the compiler options above can be provided
8090
using a :py:class:`ByTarget` object.
@@ -92,7 +102,8 @@ def __init__(self,
92102
/, *,
93103
num_ctas: None | int | ByTarget[int] = None,
94104
occupancy: None | int | ByTarget[int] = None,
95-
opt_level: None | int | ByTarget[int] = 3):
105+
opt_level: None | int | ByTarget[int] = 3,
106+
num_worker_warps: None | int | ByTarget[int] = None):
96107
if not isinstance(function, FunctionType):
97108
raise TypeError("`kernel` decorator must be applied to a Python function")
98109

@@ -103,7 +114,8 @@ def __init__(self,
103114
compiler_options = CompilerOptions(
104115
num_ctas=num_ctas,
105116
occupancy=occupancy,
106-
opt_level=opt_level
117+
opt_level=opt_level,
118+
num_worker_warps=num_worker_warps
107119
)
108120
super().__init__(ann_func.constant_parameter_mask, ann_func.int64_index_parameter_mask,
109121
ann_func.int64_parameter_mask)

src/cuda/tile/_ir2bytecode.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import os
77
from contextlib import contextmanager
88
from typing import Dict, Tuple, Any, Optional
9+
import warnings
910

1011
from cuda.tile import _datatype as datatype
1112
from cuda.tile._bytecode.attribute import make_load_store_hints
13+
from cuda.tile._bytecode.version import BytecodeVersion
1214
from cuda.tile._datatype import get_signedness
1315
from cuda.tile import DType
1416
import cuda.tile._bytecode as bc
@@ -419,8 +421,18 @@ def generate_bytecode_for_kernel(func_body: Block,
419421
writer: bc.BytecodeWriter,
420422
anonymize_debug_attr: bool):
421423
target_options = compiler_options.specialize_for_target(sm_arch)
424+
num_worker_warps = target_options.num_worker_warps
425+
if num_worker_warps is not None and writer.version < BytecodeVersion.V_13_3:
426+
warnings.warn(
427+
f"num_worker_warps requires tileiras "
428+
f"{BytecodeVersion.V_13_3.as_string()} or later; ignoring "
429+
f"(current version is {writer.version.as_string()}).",
430+
)
431+
num_worker_warps = None
432+
422433
entry_hints = bc.EntryHints(num_cta_in_cga=target_options.num_ctas,
423-
occupancy=target_options.occupancy)
434+
occupancy=target_options.occupancy,
435+
num_worker_warps_per_cta=num_worker_warps)
424436

425437
param_type_ids = [typeid(writer.type_table, p.get_type()) for p in func_body.params]
426438
debug_attr_map = DebugAttrMap(writer.debug_attr_table, symbol, anonymize_debug_attr)

test/test_bytecode_version_compat.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
from cuda.tile.compilation import CallingConvention, KernelSignature
1414

1515

16-
def compile_with_version(pyfunc, args, version: str):
17-
kernel = ct.kernel(pyfunc)
16+
def compile_with_version(kernel, args, version: str):
1817
cconv = CallingConvention.cutile_python_v1()
1918
sig = KernelSignature.from_kernel_args(kernel, args, cconv)
2019
ct.compilation.export_kernel(kernel, [sig], output_file=BytesIO(),
@@ -27,6 +26,7 @@ def tensor(dtype=torch.float32):
2726

2827

2928
def test_atan2_requires_13_2():
29+
@ct.kernel
3030
def kernel(x, y, z):
3131
tx = ct.load(x, 0, shape=64)
3232
ty = ct.load(y, 0, shape=64)
@@ -37,6 +37,7 @@ def kernel(x, y, z):
3737

3838

3939
def test_tanh_rounding_mode_requires_13_2():
40+
@ct.kernel
4041
def kernel(x, y):
4142
tx = ct.load(x, 0, shape=64)
4243
ct.store(y, 0, tile=ct.tanh(tx, rounding_mode=RoundingMode.APPROX))
@@ -47,6 +48,7 @@ def kernel(x, y):
4748

4849

4950
def test_tanh_without_rounding_mode_works_on_13_1():
51+
@ct.kernel
5052
def kernel(x, y):
5153
tx = ct.load(x, 0, shape=64)
5254
ct.store(y, 0, tile=ct.tanh(tx))
@@ -56,6 +58,7 @@ def kernel(x, y):
5658

5759

5860
def test_exp_rounding_mode_requires_13_3():
61+
@ct.kernel
5962
def kernel(x, y):
6063
tx = ct.load(x, 0, shape=64)
6164
ct.store(y, 0, tile=ct.exp(tx, rounding_mode=RoundingMode.APPROX))
@@ -66,9 +69,21 @@ def kernel(x, y):
6669

6770

6871
def test_exp_without_rounding_mode_works_on_13_1():
    """``ct.exp`` without ``rounding_mode`` must compile for bytecode 13.1."""
    @ct.kernel
    def kernel(x, y):
        tx = ct.load(x, 0, shape=64)
        ct.store(y, 0, tile=ct.exp(tx))

    # Should not raise version error
    inputs = (tensor(), tensor())
    compile_with_version(kernel, inputs, "13.1")
79+
80+
81+
def test_num_worker_warps_warns_below_13_3():
    """Compiling a ``num_worker_warps`` kernel for 13.1 must emit a UserWarning."""
    expected_message = r"num_worker_warps requires tileiras 13\.3"

    @ct.kernel(num_worker_warps=8)
    def kernel(x, y):
        tx = ct.load(x, 0, shape=64)
        ct.store(y, 0, tile=tx)

    with pytest.warns(UserWarning, match=expected_message):
        compile_with_version(kernel, (tensor(), tensor()), "13.1")

test/test_compiler_options.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,22 @@ def test_invalid_target_name():
1111
err = r"Invalid GPU architecture name: sm100, expected sm_<major><minor>"
1212
with pytest.raises(ValueError, match=err):
1313
ct.ByTarget(sm100=4)
14+
15+
16+
def _dummy():
17+
pass
18+
19+
20+
@pytest.mark.parametrize("value", [None, 4, 8])
def test_num_worker_warps_accepts_valid(value):
    """Creating a kernel with ``None``, 4 or 8 worker warps must not raise."""
    ct.kernel(_dummy, num_worker_warps=value)
23+
24+
25+
@pytest.mark.parametrize("value", [3, 7, 10])
def test_num_worker_warps_rejects_invalid(value):
    """Values other than 4 or 8 must be rejected with a ``ValueError``."""
    expected_message = "num_worker_warps should be either 4 or 8"
    with pytest.raises(ValueError, match=expected_message):
        ct.kernel(_dummy, num_worker_warps=value)
29+
30+
31+
def test_num_worker_warps_accepts_by_target():
    """A ``ByTarget`` mapping of valid values must be accepted."""
    per_target = ct.ByTarget(sm_100=8, default=4)
    ct.kernel(_dummy, num_worker_warps=per_target)

0 commit comments

Comments
 (0)