From 6329cb9b6753e2d17245bd0fe50d7ba3d4369262 Mon Sep 17 00:00:00 2001 From: Jinman Xie Date: Tue, 21 Apr 2026 22:57:31 -0700 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20migrate=20cuda.tile=5Fexperimental.?= =?UTF-8?q?autotune=5Flaunch=20=E2=86=92=20cuda.tile.tune.exhaustive=5Fsea?= =?UTF-8?q?rch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tilegym/__init__.py | 14 +- src/tilegym/ops/cutile/attention.py | 189 ++++++++++++++---- src/tilegym/ops/cutile/attention_sink.py | 61 ++++-- src/tilegym/ops/cutile/bmm.py | 38 ++-- src/tilegym/ops/cutile/experimental/mhc.py | 14 +- .../ops/cutile/experimental/sparse_mla.py | 65 +++--- src/tilegym/ops/cutile/gemma_attention.py | 80 ++++++-- src/tilegym/ops/cutile/group_gemm.py | 57 ++++-- src/tilegym/ops/cutile/layer_norm_legacy.py | 35 +++- src/tilegym/ops/cutile/matmul.py | 105 +++++++--- src/tilegym/suites/unsloth/cutile/geglu.py | 72 +++++-- .../suites/unsloth/cutile/grouped_gemm.py | 175 ++++++++++++++-- .../suites/unsloth/cutile/layernorm.py | 58 ++++-- .../suites/unsloth/cutile/rms_layernorm.py | 57 ++++-- .../suites/unsloth/cutile/rope_embedding.py | 101 ++++++++-- src/tilegym/suites/unsloth/cutile/swiglu.py | 31 ++- 16 files changed, 892 insertions(+), 260 deletions(-) diff --git a/src/tilegym/__init__.py b/src/tilegym/__init__.py index 4fda9527..afb9f5e8 100644 --- a/src/tilegym/__init__.py +++ b/src/tilegym/__init__.py @@ -18,17 +18,15 @@ def _check_torch_dependencies(): def _check_ct_experimental_dependency(): - """Verify that cuda-tile-experimental is installed with helpful error message.""" + """Verify that cuda-tile with tune support is installed with helpful error message.""" try: - import cuda.tile_experimental # noqa: F401 + import cuda.tile.tune # noqa: F401 except (ImportError, ModuleNotFoundError): raise ImportError( - "\n\n[TileGym] cuda-tile-experimental is required but not installed.\n" - "It is not available on PyPI and must be installed from source:\n\n" - ' pip install "cuda-tile-experimental @ ' - 'git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental"\n\n' - "See: https://github.com/NVIDIA/cutile-python?tab=readme-ov-file" - "#experimental-features-optional\n" + "\n\n[TileGym] cuda.tile.tune is required but not available.\n" + "Please install or upgrade cuda-tile:\n\n" + " pip install cuda-tile\n\n" + "See: https://github.com/NVIDIA/cutile-python" ) from None diff --git a/src/tilegym/ops/cutile/attention.py b/src/tilegym/ops/cutile/attention.py index 92a97793..0d12e9c5 100644 --- a/src/tilegym/ops/cutile/attention.py +++ b/src/tilegym/ops/cutile/attention.py @@ -7,9 +7,9 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from tilegym.experimental import experimental_kernel @@ -17,6 +17,11 @@ from .utils import next_power_of_2 +# Module-level tune caches for fmha forward, dkdv backward, and dq backward +_fmha_fwd_tune_cache: dict = {} +_fmha_bwd_dkdv_tune_cache: dict = {} +_fmha_bwd_dq_tune_cache: dict = {} + logger = get_logger(__name__) INV_LOG_2 = 1.0 / math.log(2) @@ -740,15 +745,54 @@ def cutile_autotune_fmha( ), ) else: - ct_experimental.autotune_launch( + fwd_cache_key = ( + batch_size, + num_heads, + q_len, + hidden_size, + query_group_size, + is_causal, + EVEN_K, + q.dtype, + str(q.device), + ) + if fwd_cache_key not in _fmha_fwd_tune_cache: + result = exhaustive_search( + list(_fmha_autotune_configs(hidden_size)), + stream, + lambda cfg: (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1), + fmha_kernel, + lambda cfg: ( + q, + k, + v, + o, + sm_scale, + input_pos, + hidden_size, + num_heads, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + is_causal, + EVEN_K, + ), + ) + best_cfg = result.best.config + _fmha_fwd_tune_cache[fwd_cache_key] = ( + best_cfg, + ct.kernel( + fmha_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _fmha_fwd_tune_cache[fwd_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - math.ceil(q_len / cfg.TILE_M), - batch_size * num_heads, - 1, - ), - kernel=fmha_kernel, - args_fn=lambda cfg: ( + (math.ceil(q_len / best_cfg.TILE_M), batch_size * num_heads, 1), + tuned_kernel, + ( q, k, v, @@ -757,14 +801,12 @@ def cutile_autotune_fmha( input_pos, hidden_size, num_heads, - cfg.TILE_M, - cfg.TILE_N, + best_cfg.TILE_M, + best_cfg.TILE_N, query_group_size, is_causal, EVEN_K, ), - search_space=lambda: _fmha_autotune_configs(hidden_size), - max_iter=20, ) return o @@ -1059,15 +1101,55 @@ def fmha_backward( ), ) else: - ct_experimental.autotune_launch( + dkdv_cache_key = ( + batch_size, + num_heads, + num_head_kv, + q_len, + k_len, + hidden_size, + query_group_size, + is_causal, + q.dtype, + str(q.device), + ) + if dkdv_cache_key not in _fmha_bwd_dkdv_tune_cache: + result = exhaustive_search( + list(_fmha_bwd_dkdv_autotune_configs(hidden_size)), + stream, + lambda cfg: (math.ceil(k_len / cfg.TILE_N), batch_size * num_head_kv, 1), + fmha_bwd_dkdv_kernel, + lambda cfg: ( + q, + k, + v, + do, + dk, + dv, + lse_flat, + delta_flat, + sm_scale, + TILE_D, + num_heads, + num_head_kv, + padded_q_len, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + is_causal, + ), + ) + best_cfg = result.best.config + _fmha_bwd_dkdv_tune_cache[dkdv_cache_key] = ( + best_cfg, + ct.kernel(fmha_bwd_dkdv_kernel._pyfunc), + ) + best_cfg, tuned_kernel = _fmha_bwd_dkdv_tune_cache[dkdv_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - math.ceil(k_len / cfg.TILE_N), - batch_size * num_head_kv, - 1, - ), - kernel=fmha_bwd_dkdv_kernel, - args_fn=lambda cfg: ( + (math.ceil(k_len / best_cfg.TILE_N), batch_size * num_head_kv, 1), + tuned_kernel, + ( q, k, v, @@ -1081,13 +1163,11 @@ def fmha_backward( num_heads, num_head_kv, padded_q_len, - cfg.TILE_M, - cfg.TILE_N, + best_cfg.TILE_M, + best_cfg.TILE_N, query_group_size, is_causal, ), - search_space=lambda: _fmha_bwd_dkdv_autotune_configs(hidden_size), - max_iter=20, ) # Step 3: Compute dQ @@ -1119,15 +1199,52 @@ def fmha_backward( ), ) else: - ct_experimental.autotune_launch( + dq_cache_key = ( + batch_size, + num_heads, + q_len, + k_len, + hidden_size, + query_group_size, + is_causal, + q.dtype, + str(q.device), + ) + if dq_cache_key not in _fmha_bwd_dq_tune_cache: + result = exhaustive_search( + list(_fmha_bwd_dq_autotune_configs(hidden_size)), + stream, + lambda cfg: (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1), + fmha_bwd_dq_kernel, + lambda cfg: ( + q, + k, + v, + do, + dq, + lse_flat, + delta_flat, + sm_scale, + TILE_D, + num_heads, + padded_q_len, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + is_causal, + ), + ) + best_cfg = result.best.config + _fmha_bwd_dq_tune_cache[dq_cache_key] = ( + best_cfg, + ct.kernel(fmha_bwd_dq_kernel._pyfunc), + ) + best_cfg, tuned_kernel = _fmha_bwd_dq_tune_cache[dq_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - math.ceil(q_len / cfg.TILE_M), - batch_size * num_heads, - 1, - ), - kernel=fmha_bwd_dq_kernel, - args_fn=lambda cfg: ( + (math.ceil(q_len / best_cfg.TILE_M), batch_size * num_heads, 1), + tuned_kernel, + ( q, k, v, @@ -1139,13 +1256,11 @@ def fmha_backward( TILE_D, num_heads, padded_q_len, - cfg.TILE_M, - cfg.TILE_N, + best_cfg.TILE_M, + best_cfg.TILE_N, query_group_size, is_causal, ), - search_space=lambda: _fmha_bwd_dq_autotune_configs(hidden_size), - max_iter=20, ) return dq, dk, dv diff --git a/src/tilegym/ops/cutile/attention_sink.py b/src/tilegym/ops/cutile/attention_sink.py index 4e579d89..4ef089d8 100644 --- a/src/tilegym/ops/cutile/attention_sink.py +++ b/src/tilegym/ops/cutile/attention_sink.py @@ -7,13 +7,16 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import numpy as np import torch from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl +# Module-level tune cache: (batch_size, n_heads, n_ctx, head_dim, n_kv_ctx, bandwidth, dtype, device) -> (best_cfg, tuned_kernel) +_attention_sink_tune_cache: dict = {} + INV_LOG_2 = 1.0 / math.log(2) # Define type aliases for Constant integers and booleans @@ -193,15 +196,46 @@ def _cutile_autotune_attention_sink( """Autotuned kernel launch.""" batch_size, _, n_ctx, _ = q.shape - ct_experimental.autotune_launch( + cache_key = (batch_size, n_heads, n_ctx, head_dim, n_kv_ctx, bandwidth, q.dtype, str(q.device)) + if cache_key not in _attention_sink_tune_cache: + result = exhaustive_search( + list(_attention_sink_autotune_configs()), + stream, + lambda cfg: (math.ceil(n_ctx / cfg.TILE_M), batch_size * n_heads, 1), + attention_sink_kernel, + lambda cfg: ( + q, + k, + v, + sinks, + o, + start_q, + sm_scale, + head_dim, + n_heads, + n_kv_ctx, + cfg.TILE_M, + cfg.TILE_N, + repeat_kv, + bandwidth, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _attention_sink_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + attention_sink_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _attention_sink_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - math.ceil(n_ctx / cfg.TILE_M), - batch_size * n_heads, - 1, - ), - kernel=attention_sink_kernel, - args_fn=lambda cfg: ( + (math.ceil(n_ctx / best_cfg.TILE_M), batch_size * n_heads, 1), + tuned_kernel, + ( q, k, v, @@ -212,16 +246,11 @@ def _cutile_autotune_attention_sink( head_dim, n_heads, n_kv_ctx, - cfg.TILE_M, - cfg.TILE_N, + best_cfg.TILE_M, + best_cfg.TILE_N, repeat_kv, bandwidth, ), - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - "occupancy": cfg.occupancy, - }, - search_space=_attention_sink_autotune_configs, ) diff --git a/src/tilegym/ops/cutile/bmm.py b/src/tilegym/ops/cutile/bmm.py index 1d547c9d..03bbd805 100644 --- a/src/tilegym/ops/cutile/bmm.py +++ b/src/tilegym/ops/cutile/bmm.py @@ -7,11 +7,14 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl +# Module-level tune cache: (batch_size, M, N, K, transpose_a, transpose_b, dtype, device) -> (best_cfg, tuned_kernel) +_bmm_tune_cache: dict = {} + # CuTile implementation of BMM kernel @ct.kernel @@ -279,18 +282,27 @@ def grid_fn(cfg): return (grid_size,) # Call autotuner to find the best config and execute the kernel - ct_experimental.autotune_launch( - stream, - grid_fn=grid_fn, - kernel=ct_static_persistent_bmm_kernel, - args_fn=args_fn, - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - "occupancy": cfg.occupancy, - }, - search_space=_bmm_autotune_configs, - compiler_time_limit_sec=30, - ) + cache_key = (batch_size, M, N, K, transpose_a, transpose_b, a.dtype, str(a.device)) + if cache_key not in _bmm_tune_cache: + result = exhaustive_search( + list(_bmm_autotune_configs()), + stream, + grid_fn, + ct_static_persistent_bmm_kernel, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _bmm_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + ct_static_persistent_bmm_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _bmm_tune_cache[cache_key] + ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg)) return output diff --git a/src/tilegym/ops/cutile/experimental/mhc.py b/src/tilegym/ops/cutile/experimental/mhc.py index f6ef81f2..b21923fb 100644 --- a/src/tilegym/ops/cutile/experimental/mhc.py +++ b/src/tilegym/ops/cutile/experimental/mhc.py @@ -6,8 +6,8 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from tilegym.experimental import experimental_kernel @@ -263,15 +263,16 @@ def cutile_autotune_mhc_split_gemm_rms(stream, x, w, M, N, K, cfg=None): max_num_bid_n = max(ceil(N / cfg.TILE_SIZE_N) for cfg in configs) y_acc = torch.empty((M * max_split_k, N), device=x.device, dtype=torch.float32) r_acc = torch.empty((M * max_split_k, max_num_bid_n), device=x.device, dtype=torch.float32) - tuned = ct_experimental.autotune_launch( + result = exhaustive_search( + configs, stream, - grid_fn=lambda cfg: ( + lambda cfg: ( ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), cfg.SPLIT_K, 1, ), - kernel=mhc_split_gemm_rms_kernel, - args_fn=lambda cfg: ( + mhc_split_gemm_rms_kernel, + lambda cfg: ( x, w, y_acc, @@ -285,9 +286,8 @@ def cutile_autotune_mhc_split_gemm_rms(stream, x, w, M, N, K, cfg=None): cfg.SPLIT_K, cfg.GROUP_SIZE_M, ), - search_space=configs, ) - best_cfg = tuned.tuned_config + best_cfg = result.best.config # Re-run the winning config with fresh buffers. The autotuner reuses # a single y_acc/r_acc pair across all evaluated configs, so after diff --git a/src/tilegym/ops/cutile/experimental/sparse_mla.py b/src/tilegym/ops/cutile/experimental/sparse_mla.py index 176e70c7..c84ab0b4 100644 --- a/src/tilegym/ops/cutile/experimental/sparse_mla.py +++ b/src/tilegym/ops/cutile/experimental/sparse_mla.py @@ -7,9 +7,9 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from tilegym.experimental import experimental_kernel @@ -19,6 +19,9 @@ logger = get_logger(__name__) +# Module-level tune cache: (B, H, S, topk, D, D_PE, query_group_size, dtype, device) -> (best_cfg, tuned_kernel) +_sparse_mla_tune_cache: dict = {} + ConstInt = ct.Constant[int] ConstBool = ct.Constant[bool] @@ -232,7 +235,7 @@ def _launch_sparse_mla_fwd( Three mutually exclusive config selection modes (no fallback between paths): Path 1: Explicit kernel_configs — validated and launched directly. Path 2: DISABLE_AUTOTUNE=1 — first valid config from search space. - Path 3: Autotune — ct_experimental.autotune_launch over search space. + Path 3: Autotune — exhaustive_search over search space. """ B = q.shape[0] @@ -283,31 +286,39 @@ def _launch_with_cfg(cfg): # Path 3: Autotune — search over all valid (TILE_H, TILE_N) pairs. else: - ct_experimental.autotune_launch( - stream, - grid_fn=lambda cfg: (S, B * (H // cfg.TILE_H), 1), - kernel=sparse_mla_fwd_kernel, - args_fn=lambda cfg: ( - q, - k, - v, - indices, - qpe, - kpe, - o, - sm_scale, - D, - D_PE, - H, - cfg.TILE_N, - topk // cfg.TILE_N, - query_group_size, - cfg.TILE_H, - H // cfg.TILE_H, - ), - search_space=lambda: _sparse_mla_autotune_configs(topk, H, query_group_size), - max_iter=20, - ) + cache_key = (B, H, S, topk, D, D_PE, query_group_size, q.dtype, str(q.device)) + if cache_key not in _sparse_mla_tune_cache: + result = exhaustive_search( + list(_sparse_mla_autotune_configs(topk, H, query_group_size)), + stream, + lambda cfg: (S, B * (H // cfg.TILE_H), 1), + sparse_mla_fwd_kernel, + lambda cfg: ( + q, + k, + v, + indices, + qpe, + kpe, + o, + sm_scale, + D, + D_PE, + H, + cfg.TILE_N, + topk // cfg.TILE_N, + query_group_size, + cfg.TILE_H, + H // cfg.TILE_H, + ), + ) + best_cfg = result.best.config + _sparse_mla_tune_cache[cache_key] = ( + best_cfg, + ct.kernel(sparse_mla_fwd_kernel._pyfunc), + ) + best_cfg, tuned_kernel = _sparse_mla_tune_cache[cache_key] + _launch_with_cfg(best_cfg) return o diff --git a/src/tilegym/ops/cutile/gemma_attention.py b/src/tilegym/ops/cutile/gemma_attention.py index 8fdfa768..22ee89b3 100644 --- a/src/tilegym/ops/cutile/gemma_attention.py +++ b/src/tilegym/ops/cutile/gemma_attention.py @@ -16,12 +16,15 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl +# Module-level tune cache: (B, H, S_qo, S_kv, BLOCK_D, query_group_size, stage, window_size, soft_cap_val, has_soft_cap, dtype, device) -> (best_cfg, tuned_kernel) +_gemma_fmha_tune_cache: dict = {} + # Constants INV_LOG_2 = 1.0 / math.log(2) LOG_2 = math.log(2) @@ -302,15 +305,63 @@ def _cutile_autotune_gemma_fmha( has_soft_cap, ): """Launch gemma FMHA kernel with autotune.""" - ct_experimental.autotune_launch( + cache_key = ( + B, + H, + S_qo, + S_kv, + BLOCK_D, + query_group_size, + stage, + window_size, + soft_cap_val, + has_soft_cap, + q.dtype, + str(q.device), + ) + if cache_key not in _gemma_fmha_tune_cache: + result = exhaustive_search( + list(_gemma_fmha_autotune_configs()), + stream, + lambda cfg: (math.ceil(S_qo / cfg.BLOCK_M), B * H, 1), + gemma_fmha_kernel, + lambda cfg: ( + q, + k, + v, + o, + sm_scale, + B, + H, + S_qo, + S_kv, + BLOCK_D, + cfg.BLOCK_M, + cfg.BLOCK_N, + query_group_size, + stage, + window_size, + soft_cap_val, + has_soft_cap, + (S_kv % cfg.BLOCK_N) == 0, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _gemma_fmha_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + gemma_fmha_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _gemma_fmha_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - math.ceil(S_qo / cfg.BLOCK_M), - B * H, - 1, - ), - kernel=gemma_fmha_kernel, - args_fn=lambda cfg: ( + (math.ceil(S_qo / best_cfg.BLOCK_M), B * H, 1), + tuned_kernel, + ( q, k, v, @@ -321,20 +372,15 @@ def _cutile_autotune_gemma_fmha( S_qo, S_kv, BLOCK_D, - cfg.BLOCK_M, - cfg.BLOCK_N, + best_cfg.BLOCK_M, + best_cfg.BLOCK_N, query_group_size, stage, window_size, soft_cap_val, has_soft_cap, - (S_kv % cfg.BLOCK_N) == 0, + (S_kv % best_cfg.BLOCK_N) == 0, ), - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - "occupancy": cfg.occupancy, - }, - search_space=_gemma_fmha_autotune_configs, ) return o diff --git a/src/tilegym/ops/cutile/group_gemm.py b/src/tilegym/ops/cutile/group_gemm.py index 67a3fb3d..ca8b6edd 100644 --- a/src/tilegym/ops/cutile/group_gemm.py +++ b/src/tilegym/ops/cutile/group_gemm.py @@ -5,14 +5,17 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from tilegym.logger import get_logger logger = get_logger(__name__) +# Module-level tune cache: (group_shapes, transpose_b, dtype, device) -> (best_cfg, tuned_kernel) +_group_gemm_tune_cache: dict = {} + # Type aliases for constants ConstInt = ct.Constant[int] ConstBool = ct.Constant[bool] @@ -121,26 +124,50 @@ def _group_gemm_autotune_configs(): def cutile_autotune_group_gemm(stream, group_A, group_B, group_C, transpose_b, device): """Autotune group GEMM kernel.""" NUM_SMS = torch.cuda.get_device_properties(device).multi_processor_count - - ct_experimental.autotune_launch( + group_shapes = tuple((tuple(A.shape), tuple(B.shape)) for A, B in zip(group_A, group_B)) + cache_key = (group_shapes, transpose_b, group_A[0].dtype, str(group_A[0].device)) + if cache_key not in _group_gemm_tune_cache: + result = exhaustive_search( + list(_group_gemm_autotune_configs()), + stream, + lambda cfg: (NUM_SMS // cfg.num_ctas * cfg.occupancy, 1, 1), + group_gemm_kernel, + lambda cfg: ( + group_A, + group_B, + group_C, + cfg.TILE_M, + cfg.TILE_N, + cfg.TILE_K, + NUM_SMS // cfg.num_ctas * cfg.occupancy, + transpose_b, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _group_gemm_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + group_gemm_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _group_gemm_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (NUM_SMS // cfg.num_ctas * cfg.occupancy, 1, 1), - kernel=group_gemm_kernel, - args_fn=lambda cfg: ( + (NUM_SMS // best_cfg.num_ctas * best_cfg.occupancy, 1, 1), + tuned_kernel, + ( group_A, group_B, group_C, - cfg.TILE_M, - cfg.TILE_N, - cfg.TILE_K, - NUM_SMS // cfg.num_ctas * cfg.occupancy, + best_cfg.TILE_M, + best_cfg.TILE_N, + best_cfg.TILE_K, + NUM_SMS // best_cfg.num_ctas * best_cfg.occupancy, transpose_b, ), - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - "occupancy": cfg.occupancy, - }, - search_space=_group_gemm_autotune_configs, ) return group_C diff --git a/src/tilegym/ops/cutile/layer_norm_legacy.py b/src/tilegym/ops/cutile/layer_norm_legacy.py index 0154409d..8ba63c6a 100644 --- a/src/tilegym/ops/cutile/layer_norm_legacy.py +++ b/src/tilegym/ops/cutile/layer_norm_legacy.py @@ -6,13 +6,16 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from .utils import next_power_of_2 +# Module-level tune cache: (N, D, BLOCK_D, IS_SWISH, TRAINING, COMPUTE_MEAN_AND_RSTD, dtype, device) -> (best_cfg, tuned_kernel) +_layer_norm_legacy_tune_cache: dict = {} + PAD_ZERO = ct.PaddingMode.ZERO @@ -328,16 +331,26 @@ def grid_fn(cfg): grid_size = min(NUM_SM, num_row_blocks) return (grid_size, 1, 1) - ct_experimental.autotune_launch( - stream, - grid_fn=grid_fn, - kernel=_persistent_layer_norm_fwd_kernel, - args_fn=args_fn, - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - }, - search_space=search_space, - ) + cache_key = (N, D, BLOCK_D, IS_SWISH, TRAINING, COMPUTE_MEAN_AND_RSTD, x.dtype, str(x.device)) + if cache_key not in _layer_norm_legacy_tune_cache: + result = exhaustive_search( + pruned_configs, + stream, + grid_fn, + _persistent_layer_norm_fwd_kernel, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas}, + ) + best_cfg = result.best.config + _layer_norm_legacy_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + _persistent_layer_norm_fwd_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + ), + ) + best_cfg, tuned_kernel = _layer_norm_legacy_tune_cache[cache_key] + ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg)) def cutile_persistent_layer_norm_fwd( diff --git a/src/tilegym/ops/cutile/matmul.py b/src/tilegym/ops/cutile/matmul.py index 8fce891b..cb0c0a79 100644 --- a/src/tilegym/ops/cutile/matmul.py +++ b/src/tilegym/ops/cutile/matmul.py @@ -6,12 +6,16 @@ from types import SimpleNamespace import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from tilegym.logger import get_logger +# Module-level tune caches: (M, N, K, dtype, device) -> (best_cfg, tuned_kernel) +_matmul_tune_cache: dict = {} +_static_persistent_matmul_tune_cache: dict = {} + logger = get_logger(__name__) # Type aliases for constants @@ -217,20 +221,32 @@ def _matmul_autotune_configs(): def cutile_autotune_matmul(stream, a, b, c): M, N = c.shape - ct_experimental.autotune_launch( + K = a.shape[1] + cache_key = (M, N, K, a.dtype, str(a.device)) + if cache_key not in _matmul_tune_cache: + result = exhaustive_search( + list(_matmul_autotune_configs()), + stream, + lambda cfg: (ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), 1, 1), + matmul_kernel, + lambda cfg: (a, b, c, cfg.TILE_SIZE_M, cfg.TILE_SIZE_N, cfg.TILE_SIZE_K), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _matmul_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + matmul_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _matmul_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), - 1, - 1, - ), - kernel=matmul_kernel, - args_fn=lambda cfg: (a, b, c, cfg.TILE_SIZE_M, cfg.TILE_SIZE_N, cfg.TILE_SIZE_K), - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - "occupancy": cfg.occupancy, - }, - search_space=_matmul_autotune_configs, + (ceil(M / best_cfg.TILE_SIZE_M) * ceil(N / best_cfg.TILE_SIZE_N), 1, 1), + tuned_kernel, + (a, b, c, best_cfg.TILE_SIZE_M, best_cfg.TILE_SIZE_N, best_cfg.TILE_SIZE_K), ) return c @@ -269,33 +285,66 @@ def _static_persistent_matmul_autotune_configs(): def cutile_autotune_static_persistent_matmul(stream, a, b, c, M, N, K, trans_a, trans_b): NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - ct_experimental.autotune_launch( + cache_key = (M, N, K, trans_a, trans_b, a.dtype, str(a.device)) + if cache_key not in _static_persistent_matmul_tune_cache: + result = exhaustive_search( + list(_static_persistent_matmul_autotune_configs()), + stream, + lambda cfg: ( + min(NUM_SMS // cfg.num_ctas, ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N)) * cfg.occupancy, + 1, + 1, + ), + static_persistent_matmul_kernel, + lambda cfg: ( + a, + b, + c, + M, + N, + K, + cfg.TILE_SIZE_M, + cfg.TILE_SIZE_N, + cfg.TILE_SIZE_K, + trans_a, + trans_b, + cfg.GROUP_SIZE_M, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _static_persistent_matmul_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + static_persistent_matmul_kernel._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _static_persistent_matmul_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: ( - min(NUM_SMS // cfg.num_ctas, ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N)) * cfg.occupancy, + ( + min(NUM_SMS // best_cfg.num_ctas, ceil(M / best_cfg.TILE_SIZE_M) * ceil(N / best_cfg.TILE_SIZE_N)) + * best_cfg.occupancy, 1, 1, ), - kernel=static_persistent_matmul_kernel, - args_fn=lambda cfg: ( + tuned_kernel, + ( a, b, c, M, N, K, - cfg.TILE_SIZE_M, - cfg.TILE_SIZE_N, - cfg.TILE_SIZE_K, + best_cfg.TILE_SIZE_M, + best_cfg.TILE_SIZE_N, + best_cfg.TILE_SIZE_K, trans_a, trans_b, - cfg.GROUP_SIZE_M, + best_cfg.GROUP_SIZE_M, ), - hints_fn=lambda cfg: { - "num_ctas": cfg.num_ctas, - "occupancy": cfg.occupancy, - }, - search_space=_static_persistent_matmul_autotune_configs, ) return c diff --git a/src/tilegym/suites/unsloth/cutile/geglu.py b/src/tilegym/suites/unsloth/cutile/geglu.py index 461d325f..e4892a35 100644 --- a/src/tilegym/suites/unsloth/cutile/geglu.py +++ b/src/tilegym/suites/unsloth/cutile/geglu.py @@ -23,8 +23,8 @@ """ import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl @@ -34,6 +34,10 @@ ConstInt = ct.Constant[int] +# Module-level tune caches: (n_elements, LONG_INDEXING, dtype, device) -> tuned_kernel +_geglu_exact_fwd_tune_cache: dict = {} +_geglu_approx_fwd_tune_cache: dict = {} + # signed int32 max is 2**31-1 so num_elements cannot exceed 2**31 NUM_INT32_ELEMENTS = 2**31 SAFE_INT32_BUFFER_MULTIPLIER = 4 @@ -185,11 +189,34 @@ def geglu_exact_forward(gate, up): out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=gate.device) stream = torch.cuda.current_stream() LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1 - ct_experimental.autotune_launch( + cache_key = (n_elements, LONG_INDEXING, gate.dtype, str(gate.device)) + if cache_key not in _geglu_exact_fwd_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (cdiv(n_elements, BLOCK_SIZE_FWD),), + _exact_forward_ct, + lambda cfg: ( + gate.reshape(-1), + up.reshape(-1), + out.reshape(-1), + n_elements, + BLOCK_SIZE_FWD, + LONG_INDEXING, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _geglu_exact_fwd_tune_cache[cache_key] = ct.kernel( + _exact_forward_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_kernel = _geglu_exact_fwd_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (cdiv(n_elements, BLOCK_SIZE_FWD),), - kernel=_exact_forward_ct, - args_fn=lambda cfg: ( + (cdiv(n_elements, BLOCK_SIZE_FWD),), + tuned_kernel, + ( gate.reshape(-1), up.reshape(-1), out.reshape(-1), @@ -197,8 +224,6 @@ def geglu_exact_forward(gate, up): BLOCK_SIZE_FWD, LONG_INDEXING, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) return out @@ -225,11 +250,34 @@ def geglu_approx_forward(gate, up): out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=gate.device) stream = torch.cuda.current_stream() LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1 - ct_experimental.autotune_launch( + cache_key = (n_elements, LONG_INDEXING, gate.dtype, str(gate.device)) + if cache_key not in _geglu_approx_fwd_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (cdiv(n_elements, BLOCK_SIZE_FWD),), + _approx_forward_ct, + lambda cfg: ( + gate.reshape(-1), + up.reshape(-1), + out.reshape(-1), + n_elements, + BLOCK_SIZE_FWD, + LONG_INDEXING, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _geglu_approx_fwd_tune_cache[cache_key] = ct.kernel( + _approx_forward_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_kernel = _geglu_approx_fwd_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (cdiv(n_elements, BLOCK_SIZE_FWD),), - kernel=_approx_forward_ct, - args_fn=lambda cfg: ( + (cdiv(n_elements, BLOCK_SIZE_FWD),), + tuned_kernel, + ( gate.reshape(-1), up.reshape(-1), out.reshape(-1), @@ -237,8 +285,6 @@ def geglu_approx_forward(gate, up): BLOCK_SIZE_FWD, LONG_INDEXING, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) return out diff --git a/src/tilegym/suites/unsloth/cutile/grouped_gemm.py b/src/tilegym/suites/unsloth/cutile/grouped_gemm.py index 5ae90181..7c830f86 100644 --- a/src/tilegym/suites/unsloth/cutile/grouped_gemm.py +++ b/src/tilegym/suites/unsloth/cutile/grouped_gemm.py @@ -48,14 +48,19 @@ import math import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from .ct_ops import autotune_configs from .ct_ops import next_power_of_2 +# Module-level tune caches for forward, dX backward, and dW backward +_grouped_gemm_fwd_tune_cache: dict = {} +_grouped_gemm_dX_tune_cache: dict = {} +_grouped_gemm_dW_tune_cache: dict = {} + def _gemm_block_sizes(N, K, avg_tokens): """Select block sizes heuristically based on problem shape. @@ -563,11 +568,58 @@ def forward( if total_tokens > 0: NUM_SMS = _get_num_sms(X.device) - ct_experimental.autotune_launch( - torch.cuda.current_stream(), - grid_fn=lambda cfg: (NUM_SMS, 1, 1), - kernel=_grouped_gemm_fwd_kernel_ct, - args_fn=lambda cfg: ( + fwd_stream = torch.cuda.current_stream() + fwd_cache_key = ( + total_tokens, + N, + K, + num_experts, + BLOCK_M, + BLOCK_N, + BLOCK_K, + permute_x_flag, + permute_y_flag, + topk, + X.dtype, + str(X.device), + ) + if fwd_cache_key not in _grouped_gemm_fwd_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + fwd_stream, + lambda cfg: (NUM_SMS, 1, 1), + _grouped_gemm_fwd_kernel_ct, + lambda cfg: ( + X_2d, + W_flat, + Y, + m_sizes_i32, + gather_indices_i32, + N, + K, + total_tokens, + num_experts, + NUM_SMS, + BLOCK_M, + BLOCK_N, + BLOCK_K, + permute_x_flag, + permute_y_flag, + topk, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _grouped_gemm_fwd_tune_cache[fwd_cache_key] = ct.kernel( + _grouped_gemm_fwd_kernel_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_fwd_kernel = _grouped_gemm_fwd_tune_cache[fwd_cache_key] + ct.launch( + fwd_stream, + (NUM_SMS, 1, 1), + tuned_fwd_kernel, + ( X_2d, W_flat, Y, @@ -585,8 +637,6 @@ def forward( permute_y_flag, topk, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) ctx.save_for_backward(X, W, m_sizes, gather_indices) @@ -649,11 +699,55 @@ def backward(ctx, dY): dX = torch.zeros((total_tokens, K), device=dY.device, dtype=dY.dtype) if total_tokens > 0: - ct_experimental.autotune_launch( + dX_cache_key = ( + total_tokens, + N, + K, + num_experts, + BLOCK_M, + BLOCK_N, + BLOCK_K, + permute_x_flag, + permute_y_flag, + dY.dtype, + str(dY.device), + ) + if dX_cache_key not in _grouped_gemm_dX_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (NUM_SMS, 1, 1), + _grouped_gemm_dX_kernel_ct, + lambda cfg: ( + dY.view(-1, N), + W_flat, + dX, + m_sizes_i32, + gather_indices_i32, + N, + K, + total_tokens, + num_experts, + NUM_SMS, + BLOCK_M, + BLOCK_N, + BLOCK_K, + permute_x_flag, + permute_y_flag, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _grouped_gemm_dX_tune_cache[dX_cache_key] = ct.kernel( + _grouped_gemm_dX_kernel_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_dX_kernel = _grouped_gemm_dX_tune_cache[dX_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (NUM_SMS, 1, 1), - kernel=_grouped_gemm_dX_kernel_ct, - args_fn=lambda cfg: ( + (NUM_SMS, 1, 1), + tuned_dX_kernel, + ( dY.view(-1, N), W_flat, dX, @@ -670,8 +764,6 @@ def backward(ctx, dY): permute_x_flag, permute_y_flag, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) # topk > 1 with permute_x: multiple expert slots map to same token, @@ -690,11 +782,56 @@ def backward(ctx, dY): dW = torch.zeros((num_experts * N, K), device=dY.device, dtype=dY.dtype) if total_dw_tiles > 0: - ct_experimental.autotune_launch( + dW_cache_key = ( + total_tokens, + N, + K, + num_experts, + BLOCK_M, + BLOCK_N, + BLOCK_K, + permute_x_flag, + permute_y_flag, + topk, + dY.dtype, + str(dY.device), + ) + if dW_cache_key not in _grouped_gemm_dW_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (total_dw_tiles, 1, 1), + _grouped_gemm_dW_kernel_ct, + lambda cfg: ( + X_2d, + dY.view(-1, N), + dW, + m_sizes_i32, + gather_indices_i32, + num_experts, + N, + K, + total_tokens, + BLOCK_M, + BLOCK_N, + BLOCK_K, + permute_x_flag, + permute_y_flag, + topk, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _grouped_gemm_dW_tune_cache[dW_cache_key] = ct.kernel( + _grouped_gemm_dW_kernel_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_dW_kernel = _grouped_gemm_dW_tune_cache[dW_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (total_dw_tiles, 1, 1), - kernel=_grouped_gemm_dW_kernel_ct, - args_fn=lambda cfg: ( + (total_dw_tiles, 1, 1), + tuned_dW_kernel, + ( X_2d, dY.view(-1, N), dW, @@ -711,8 +848,6 @@ def backward(ctx, dY): permute_y_flag, topk, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) dW = dW.view(num_experts, N, K) diff --git a/src/tilegym/suites/unsloth/cutile/layernorm.py b/src/tilegym/suites/unsloth/cutile/layernorm.py index 3b8e4bef..87fbbc83 100644 --- a/src/tilegym/suites/unsloth/cutile/layernorm.py +++ b/src/tilegym/suites/unsloth/cutile/layernorm.py @@ -22,8 +22,8 @@ """ import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl @@ -33,6 +33,10 @@ ConstInt = ct.Constant[int] ConstFloat = ct.Constant[float] +# Module-level tune caches: (direction, n_rows, n_cols, dtype, TILE_N, device) -> tuned_kernel +_layernorm_fwd_tune_cache: dict = {} +_layernorm_bwd_tune_cache: dict = {} + def _layernorm_forward_ct_1d_body(Y, X, W, b, r, mu, n_cols, eps, TILE_N): """ @@ -203,13 +207,27 @@ def forward(ctx, X, W, b, eps): mu = torch.empty(n_rows, dtype=torch.float32, device=X.device) stream = torch.cuda.current_stream() - ct_experimental.autotune_launch( + fwd_cache_key = (n_rows, n_cols, X.dtype, TILE_N, str(X.device)) + if fwd_cache_key not in _layernorm_fwd_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (n_rows,), + _layernorm_forward_ct_1d, + lambda cfg: (Y, X, W, b, r, mu, n_cols, eps, TILE_N), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _layernorm_fwd_tune_cache[fwd_cache_key] = ct.kernel( + _layernorm_forward_ct_1d._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_fwd_kernel = _layernorm_fwd_tune_cache[fwd_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (n_rows,), - kernel=_layernorm_forward_ct_1d, - args_fn=lambda cfg: (Y, X, W, b, r, mu, n_cols, eps, TILE_N), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, + (n_rows,), + tuned_fwd_kernel, + (Y, X, W, b, r, mu, n_cols, eps, TILE_N), ) ctx.eps = eps @@ -231,13 +249,27 @@ def backward(ctx, dY): dX = torch.empty_like(dY) stream = torch.cuda.current_stream() - ct_experimental.autotune_launch( + bwd_cache_key = (n_rows, n_cols, dY.dtype, ctx.TILE_N, str(dY.device)) + if bwd_cache_key not in _layernorm_bwd_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (n_rows,), + _layernorm_backward_ct_1d, + lambda cfg: (dX, dY, X, W, r, mu, n_cols, ctx.TILE_N), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _layernorm_bwd_tune_cache[bwd_cache_key] = ct.kernel( + _layernorm_backward_ct_1d._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_bwd_kernel = _layernorm_bwd_tune_cache[bwd_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (n_rows,), - kernel=_layernorm_backward_ct_1d, - args_fn=lambda cfg: (dX, dY, X, W, r, mu, n_cols, ctx.TILE_N), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, + (n_rows,), + tuned_bwd_kernel, + (dX, dY, X, W, r, mu, n_cols, ctx.TILE_N), ) return dX.view(*shape), None, None, None, None diff --git a/src/tilegym/suites/unsloth/cutile/rms_layernorm.py b/src/tilegym/suites/unsloth/cutile/rms_layernorm.py index e4eb2c7a..427d99de 100644 --- a/src/tilegym/suites/unsloth/cutile/rms_layernorm.py +++ b/src/tilegym/suites/unsloth/cutile/rms_layernorm.py @@ -28,14 +28,17 @@ """ import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl from .ct_ops import autotune_configs from .ct_ops import calculate_settings +# Module-level tune cache: (direction, n_rows, n_cols, dtype, TILE_N, OFFSET, device) -> tuned_kernel +_rms_layernorm_tune_cache: dict = {} + ConstInt = ct.Constant[int] ConstFloat = ct.Constant[float] @@ -192,13 +195,27 @@ def forward(ctx, X, W, eps, gemma=False): r = torch.empty(n_rows, dtype=torch.float32, device=X.device) stream = torch.cuda.current_stream() - ct_experimental.autotune_launch( + fwd_cache_key = ("fwd", n_rows, n_cols, X.dtype, TILE_N, OFFSET, str(X.device)) + if fwd_cache_key not in _rms_layernorm_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (n_rows,), + _rms_layernorm_forward_ct_1d, + lambda cfg: (Y, X, W, r, n_cols, eps, OFFSET, TILE_N), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _rms_layernorm_tune_cache[fwd_cache_key] = ct.kernel( + _rms_layernorm_forward_ct_1d._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_fwd_kernel = _rms_layernorm_tune_cache[fwd_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (n_rows,), - kernel=_rms_layernorm_forward_ct_1d, - args_fn=lambda cfg: (Y, X, W, r, n_cols, eps, OFFSET, TILE_N), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, + (n_rows,), + tuned_fwd_kernel, + (Y, X, W, r, n_cols, eps, OFFSET, TILE_N), ) ctx.eps = eps @@ -222,13 +239,27 @@ def backward(ctx, dY): dX = torch.empty_like(dY) stream = torch.cuda.current_stream() - ct_experimental.autotune_launch( + bwd_cache_key = ("bwd", n_rows, n_cols, dY.dtype, ctx.TILE_N, ctx.OFFSET, str(dY.device)) + if bwd_cache_key not in _rms_layernorm_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (n_rows,), + _rms_layernorm_backward_ct_1d, + lambda cfg: (dX, dY, X, W, r, n_cols, ctx.OFFSET, ctx.TILE_N), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _rms_layernorm_tune_cache[bwd_cache_key] = ct.kernel( + _rms_layernorm_backward_ct_1d._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_bwd_kernel = _rms_layernorm_tune_cache[bwd_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (n_rows,), - kernel=_rms_layernorm_backward_ct_1d, - args_fn=lambda cfg: (dX, dY, X, W, r, n_cols, ctx.OFFSET, ctx.TILE_N), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, + (n_rows,), + tuned_bwd_kernel, + (dX, dY, X, W, r, n_cols, ctx.OFFSET, ctx.TILE_N), ) return dX.view(*shape), None, None, None diff --git a/src/tilegym/suites/unsloth/cutile/rope_embedding.py b/src/tilegym/suites/unsloth/cutile/rope_embedding.py index 68c60f81..ed51cc23 100644 --- a/src/tilegym/suites/unsloth/cutile/rope_embedding.py +++ b/src/tilegym/suites/unsloth/cutile/rope_embedding.py @@ -23,7 +23,7 @@ - _rope_embedding_QK_ct: joint Q+K RoPE for (batch, heads, seq, 2, half_dim) Performance notes: - - Autotune over occupancy=[1,2,4,8] via ct_experimental.autotune_launch. + - Autotune over occupancy=[1,2,4,8] via cuda.tile.tune.exhaustive_search. - Split-buffer pattern: kernel reads from Q_in and writes to Q_out. This allows autotune to re-run the kernel without corrupting input data (no clone needed). For backward (ct.launch, no autotune), Q_in=Q_out @@ -35,8 +35,8 @@ """ import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl @@ -46,6 +46,10 @@ ConstInt = ct.Constant[int] PAD_ZERO = ct.PaddingMode.ZERO +# Module-level tune caches: (n_rows, n_heads, seq_len, head_dim, TILE_HD, dtype, device) -> tuned_kernel +_rope_embedding_single_tune_cache: dict = {} +_rope_embedding_qk_tune_cache: dict = {} + # ---- CuTile kernel: joint Q+K RoPE (1D gather/scatter) ---- @@ -259,11 +263,39 @@ def forward(ctx, Q, cos, sin): # Autotune over occupancy=[1,2,4,8] (matching layernorm.py pattern). # Split-buffer: Q_flat_1d is read-only, Q_result is write-only. stream = torch.cuda.current_stream() - ct_experimental.autotune_launch( + single_cache_key = (n_rows, n_heads, seq_len, head_dim, TILE_HD, cos_row_stride, Q.dtype, str(Q.device)) + if single_cache_key not in _rope_embedding_single_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (n_rows, n_heads, 1), + _rope_embedding_ct, + lambda cfg: ( + Q_flat_1d, + Q_result, + cos_flat, + sin_flat, + seq_len, + n_heads, + head_dim, + cos_row_stride, + 0, + TILE_HD, + no_padding, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _rope_embedding_single_tune_cache[single_cache_key] = ct.kernel( + _rope_embedding_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_kernel = _rope_embedding_single_tune_cache[single_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (n_rows, n_heads, 1), - kernel=_rope_embedding_ct, - args_fn=lambda cfg: ( + (n_rows, n_heads, 1), + tuned_kernel, + ( Q_flat_1d, Q_result, cos_flat, @@ -272,12 +304,10 @@ def forward(ctx, Q, cos, sin): n_heads, head_dim, cos_row_stride, - 0, # BACKWARD_PASS = False + 0, TILE_HD, no_padding, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) ctx.TILE_HD = TILE_HD @@ -370,11 +400,54 @@ def forward(ctx, Q, K, cos, sin, rope_indices): sin_flat = sin.reshape(-1) stream = torch.cuda.current_stream() - ct_experimental.autotune_launch( + qk_cache_key = ( + n_rows, + n_heads_Q, + n_heads_K, + seq_len, + head_dim, + TILE_HD, + has_indices_int, + Q.dtype, + str(Q.device), + ) + if qk_cache_key not in _rope_embedding_qk_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (n_rows, n_heads_Q, 1), + _rope_embedding_QK_ct, + lambda cfg: ( + Q_flat, + Q_result, + K_flat, + K_result, + cos_flat, + sin_flat, + rope_ptr, + seq_len, + head_dim, + n_heads_Q, + n_heads_K, + cos_row_stride, + 0, + has_indices_int, + TILE_HD, + no_padding, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _rope_embedding_qk_tune_cache[qk_cache_key] = ct.kernel( + _rope_embedding_QK_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_qk_kernel = _rope_embedding_qk_tune_cache[qk_cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (n_rows, n_heads_Q, 1), - kernel=_rope_embedding_QK_ct, - args_fn=lambda cfg: ( + (n_rows, n_heads_Q, 1), + tuned_qk_kernel, + ( Q_flat, Q_result, K_flat, @@ -392,8 +465,6 @@ def forward(ctx, Q, K, cos, sin, rope_indices): TILE_HD, no_padding, ), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, ) ctx.TILE_HD = TILE_HD diff --git a/src/tilegym/suites/unsloth/cutile/swiglu.py b/src/tilegym/suites/unsloth/cutile/swiglu.py index 81fe5948..1e5d7501 100644 --- a/src/tilegym/suites/unsloth/cutile/swiglu.py +++ b/src/tilegym/suites/unsloth/cutile/swiglu.py @@ -15,8 +15,8 @@ """ import cuda.tile as ct -import cuda.tile_experimental as ct_experimental import torch +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl @@ -25,6 +25,9 @@ ConstInt = ct.Constant[int] +# Module-level tune cache: (n_elements, LONG_INDEXING, dtype, device) -> tuned_kernel +_swiglu_fg_tune_cache: dict = {} + # signed int32 max is 2**31-1 so num_elements cannot exceed 2**31 NUM_INT32_ELEMENTS = 2**31 SAFE_INT32_BUFFER_MULTIPLIER = 4 @@ -97,13 +100,27 @@ def swiglu_fg(e, g): h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=e.device) stream = torch.cuda.current_stream() LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1 - ct_experimental.autotune_launch( + cache_key = (n_elements, LONG_INDEXING, e.dtype, str(e.device)) + if cache_key not in _swiglu_fg_tune_cache: + result = exhaustive_search( + list(autotune_configs()), + stream, + lambda cfg: (cdiv(n_elements, BLOCK_SIZE_FWD),), + _fg_kernel_ct, + lambda cfg: (e.reshape(-1), g.reshape(-1), h.reshape(-1), n_elements, BLOCK_SIZE_FWD, LONG_INDEXING), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _swiglu_fg_tune_cache[cache_key] = ct.kernel( + _fg_kernel_ct._pyfunc, + occupancy=best_cfg.occupancy, + ) + tuned_kernel = _swiglu_fg_tune_cache[cache_key] + ct.launch( stream, - grid_fn=lambda cfg: (cdiv(n_elements, BLOCK_SIZE_FWD),), - kernel=_fg_kernel_ct, - args_fn=lambda cfg: (e.reshape(-1), g.reshape(-1), h.reshape(-1), n_elements, BLOCK_SIZE_FWD, LONG_INDEXING), - hints_fn=lambda cfg: {"occupancy": cfg.occupancy}, - search_space=autotune_configs, + (cdiv(n_elements, BLOCK_SIZE_FWD),), + tuned_kernel, + (e.reshape(-1), g.reshape(-1), h.reshape(-1), n_elements, BLOCK_SIZE_FWD, LONG_INDEXING), ) return h From 8b2392d8d9e2f8fa25c413381e3e967105bc8434 Mon Sep 17 00:00:00 2001 From: Yifei Song Date: Wed, 22 Apr 2026 00:25:19 -0700 Subject: [PATCH 2/6] feat(flashinfer): Add flashinfer kernel, support flashinfer --- src/tilegym/kernel_utils.py | 36 + src/tilegym/ops/cutile/utils.py | 5 + src/tilegym/suites/__init__.py | 5 + src/tilegym/suites/flashinfer/__init__.py | 52 + .../suites/flashinfer/cutile/__init__.py | 23 + .../flashinfer/cutile/fmha_decode_bsr.py | 1466 +++++++++++++++ .../flashinfer/cutile/fmha_prefill_bsr.py | 1660 +++++++++++++++++ .../suites/flashinfer/cutile/gemm/__init__.py | 8 + .../flashinfer/cutile/gemm/gemm_alpha_beta.py | 523 ++++++ .../flashinfer/cutile/gemm/masked_bmm.py | 444 +++++ .../cutile/gemm/ragged_block_scaled_bmm.py | 598 ++++++ .../flashinfer/cutile/gemm/ragged_bmm.py | 747 ++++++++ .../cutile/per_token_group_quant_8bit.py | 272 +++ .../flashinfer/cutile/rope_quantize_fp8.py | 379 ++++ src/tilegym/suites/flashinfer/ops.py | 468 +++++ tests/suites/flashinfer/__init__.py | 3 + .../flashinfer/test_flashinfer_attention.py | 606 ++++++ .../test_flashinfer_gemm_alpha_beta.py | 71 + .../flashinfer/test_flashinfer_masked_bmm.py | 157 ++ ...test_flashinfer_ragged_block_scaled_bmm.py | 271 +++ .../flashinfer/test_flashinfer_ragged_bmm.py | 298 +++ .../test_per_token_group_quant_8bit.py | 150 ++ .../flashinfer/test_rope_quantize_fp8.py | 215 +++ tests/test_utils/__init__.py | 3 + tests/test_utils/bsr_attention_sample.py | 175 ++ tests/test_utils/cudnn_decode.py | 345 ++++ tests/test_utils/cudnn_prefill.py | 510 +++++ 27 files changed, 9490 insertions(+) create mode 100644 src/tilegym/kernel_utils.py create mode 100644 src/tilegym/suites/flashinfer/__init__.py create mode 100644 src/tilegym/suites/flashinfer/cutile/__init__.py create mode 100644 src/tilegym/suites/flashinfer/cutile/fmha_decode_bsr.py create mode 100644 src/tilegym/suites/flashinfer/cutile/fmha_prefill_bsr.py create mode 100644 src/tilegym/suites/flashinfer/cutile/gemm/__init__.py create mode 100644 src/tilegym/suites/flashinfer/cutile/gemm/gemm_alpha_beta.py create mode 100644 src/tilegym/suites/flashinfer/cutile/gemm/masked_bmm.py create mode 100644 src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py create mode 100644 src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py create mode 100644 src/tilegym/suites/flashinfer/cutile/per_token_group_quant_8bit.py create mode 100644 src/tilegym/suites/flashinfer/cutile/rope_quantize_fp8.py create mode 100644 src/tilegym/suites/flashinfer/ops.py create mode 100644 tests/suites/flashinfer/__init__.py create mode 100644 tests/suites/flashinfer/test_flashinfer_attention.py create mode 100644 tests/suites/flashinfer/test_flashinfer_gemm_alpha_beta.py create mode 100644 tests/suites/flashinfer/test_flashinfer_masked_bmm.py create mode 100644 tests/suites/flashinfer/test_flashinfer_ragged_block_scaled_bmm.py create mode 100644 tests/suites/flashinfer/test_flashinfer_ragged_bmm.py create mode 100644 tests/suites/flashinfer/test_per_token_group_quant_8bit.py create mode 100644 tests/suites/flashinfer/test_rope_quantize_fp8.py create mode 100644 tests/test_utils/__init__.py create mode 100644 tests/test_utils/bsr_attention_sample.py create mode 100644 tests/test_utils/cudnn_decode.py create mode 100644 tests/test_utils/cudnn_prefill.py diff --git a/src/tilegym/kernel_utils.py b/src/tilegym/kernel_utils.py new file mode 100644 index 00000000..df035d5b --- /dev/null +++ b/src/tilegym/kernel_utils.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +"""Kernel utility functions for TileGym.""" + +from typing import Any +from typing import Dict +from typing import Optional + +from tilegym.logger import get_logger + + +def get_kernel_configs(default_configs: Dict[str, Any], provided_configs: Optional[Dict[str, Any]] = None): + """ + Merge default kernel configs with provided configs. + + Args: + default_configs: Default kernel configuration dictionary. + provided_configs: Optional user-provided configuration dictionary. + + Returns: + Merged configuration dictionary with provided configs overriding defaults. + """ + logger = get_logger(__name__) + + if provided_configs is None: + return default_configs + # log any differences between default_configs and provided_configs + for key, value in default_configs.items(): + if key not in provided_configs: + logger.warning(f"Provided kernel config {key} is not in default: {value}") + continue + if provided_configs[key] != value: + logger.info(f"Provided kernel config {key} differs from default: {value} -> {provided_configs[key]}") + return {**default_configs, **provided_configs} diff --git a/src/tilegym/ops/cutile/utils.py b/src/tilegym/ops/cutile/utils.py index 64333da4..58dc1c47 100644 --- a/src/tilegym/ops/cutile/utils.py +++ b/src/tilegym/ops/cutile/utils.py @@ -14,3 +14,8 @@ def next_power_of_2(n: int): n |= n >> 32 n += 1 return n + + +def is_power_of_2(n: int): + """Check if n is a power of 2""" + return n > 0 and (n & (n - 1)) == 0 diff --git a/src/tilegym/suites/__init__.py b/src/tilegym/suites/__init__.py index f7c88fa9..1689f83b 100644 --- a/src/tilegym/suites/__init__.py +++ b/src/tilegym/suites/__init__.py @@ -6,6 +6,9 @@ TileGym Suites - cutile implementations for external kernel libraries Usage: + # Import flashinfer suite + from tilegym.suites import flashinfer + output = flashinfer.ops.decode_attention_kv_paged(q, k_cache, v_cache, ...) # Import unsloth suite from tilegym.suites import unsloth @@ -19,8 +22,10 @@ def list_available() -> List[str]: """List all available suites""" available = [] try: + from . import flashinfer from . import unsloth + available.append("flashinfer") available.append("unsloth") except ImportError: pass diff --git a/src/tilegym/suites/flashinfer/__init__.py b/src/tilegym/suites/flashinfer/__init__.py new file mode 100644 index 00000000..c3157a05 --- /dev/null +++ b/src/tilegym/suites/flashinfer/__init__.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +TileGym FlashInfer Suite + +This suite contains implementations for FlashInfer-compatible operations. + +Usage: + from tilegym.suites import flashinfer + + # Use FlashInfer ops interface (recommended) + + + # Use cuTile backend (if available) + if flashinfer.cutile is not None: + output = flashinfer.cutile.decode_attention_kv_paged(...) + +Available backends: +- ops: FlashInfer operations interface (auto backend selection) +- cutile/: cuTile backend implementations (high-performance, if available) + +See USAGE.md for detailed documentation and examples. +""" + +import warnings + +from tilegym.backend import is_backend_available + +# Import ops interface first to register dispatch functions +from . import ops + +if is_backend_available("cutile"): + try: + from . import cutile + except (ImportError, RuntimeError): + cutile = None + warnings.warn("Cutile backend import failed in flashinfer suite, cutile operations will not be available") +else: + cutile = None + + +ref = None +__all__ = ["ops"] + + +if cutile is not None: + __all__.append("cutile") + +if ref is not None: + __all__.append("ref") diff --git a/src/tilegym/suites/flashinfer/cutile/__init__.py b/src/tilegym/suites/flashinfer/cutile/__init__.py new file mode 100644 index 00000000..9fcb4a0f --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +"""cuTile operations""" + +from . import gemm +from .fmha_decode_bsr import decode_attention_kv_paged +from .fmha_decode_bsr import decode_mla_kv_paged +from .fmha_prefill_bsr import prefill_attention_kv_paged +from .fmha_prefill_bsr import prefill_attention_kv_ragged +from .per_token_group_quant_8bit import per_token_group_quant_8bit +from .rope_quantize_fp8 import rope_quantize_fp8 + +__all__ = [ + "gemm", + "decode_attention_kv_paged", + "decode_mla_kv_paged", + "prefill_attention_kv_paged", + "prefill_attention_kv_ragged", + "per_token_group_quant_8bit", + "rope_quantize_fp8", +] diff --git a/src/tilegym/suites/flashinfer/cutile/fmha_decode_bsr.py b/src/tilegym/suites/flashinfer/cutile/fmha_decode_bsr.py new file mode 100644 index 00000000..bf1147e1 --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/fmha_decode_bsr.py @@ -0,0 +1,1466 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +import math +import os +from types import SimpleNamespace +from typing import Optional + +import cuda.tile as ct +import torch +from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl +from tilegym.ops.cutile.utils import is_power_of_2 +from tilegym.ops.cutile.utils import next_power_of_2 + +# Module-level tune caches for paged decode and MLA decode +_decode_kv_paged_tune_cache: dict = {} +_decode_mla_paged_tune_cache: dict = {} + +INV_LOG_2 = 1.0 / math.log(2) + +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] +ConstFloat = ct.Constant[float] + + +@ct.kernel +def _splitk_reduce_kernel( + attn_splitk_out, + lse_splitk_out, + attn_out, + actual_seq_lens, + num_heads: ConstInt, + num_kv_len_per_split: ConstInt, + NUM_KV_SPLITS: ConstInt, + NUM_KV_SPLITS_POW2: ConstInt, + BLOCK_D: ConstInt, +): + batch_id = ct.bid(0) + head_id = ct.bid(1) + dtype = attn_out.dtype + + seq_len_tile = ct.load(actual_seq_lens, (batch_id,), shape=(1,)) + seq_len = seq_len_tile.item() + actual_num_splits = (seq_len + num_kv_len_per_split - 1) // num_kv_len_per_split + actual_num_splits = ct.minimum(actual_num_splits, NUM_KV_SPLITS) + + lse_vals = ct.load(lse_splitk_out, (batch_id, head_id, 0), shape=(1, 1, NUM_KV_SPLITS_POW2)) + lse_vals = ct.reshape(lse_vals, (NUM_KV_SPLITS_POW2,)) + + split_indices = ct.arange(NUM_KV_SPLITS_POW2, dtype=ct.int32) + valid_mask = split_indices < actual_num_splits + lse_vals = ct.where(valid_mask, lse_vals, ct.full((NUM_KV_SPLITS_POW2,), -1e30, dtype=ct.float32)) + + lse_max = ct.max(lse_vals) + weights = ct.exp2(lse_vals - lse_max) + weights = ct.where(valid_mask, weights, ct.zeros((NUM_KV_SPLITS_POW2,), dtype=ct.float32)) + weights_sum = ct.sum(weights) + weights = weights / weights_sum + + out_all = ct.load(attn_splitk_out, (0, batch_id, head_id, 0), shape=(NUM_KV_SPLITS_POW2, 1, 1, BLOCK_D)) + out_all = ct.reshape(out_all, (NUM_KV_SPLITS_POW2, BLOCK_D)) + out_all = ct.astype(out_all, ct.float32) + + weights_row = ct.reshape(weights, (1, NUM_KV_SPLITS_POW2)) + acc = ct.mma(weights_row, out_all, ct.zeros((1, BLOCK_D), dtype=ct.float32)) + acc = ct.reshape(acc, (BLOCK_D,)) + + result = ct.astype(acc, dtype) + ct.store(attn_out, (batch_id, head_id, 0), ct.reshape(result, (1, 1, BLOCK_D))) + + +def splitk_reduce_with_seq_len(attn_splitk_out, lse_splitk_out, actual_seq_lens, num_kv_len_per_split, attn_out=None): + NUM_KV_SPLITS, B, num_heads, head_dim = attn_splitk_out.shape + + if attn_out is None: + attn_out = torch.empty((B, num_heads, head_dim), device=attn_splitk_out.device, dtype=attn_splitk_out.dtype) + + if NUM_KV_SPLITS == 1: + attn_out.copy_(attn_splitk_out[0]) + return attn_out + + if NUM_KV_SPLITS == 2: + lse_0, lse_1 = lse_splitk_out[:, :, 0], lse_splitk_out[:, :, 1] + lse_max = torch.maximum(lse_0, lse_1) + w0, w1 = torch.exp2(lse_0 - lse_max), torch.exp2(lse_1 - lse_max) + w_sum = w0 + w1 + result = ( + attn_splitk_out[0].float() * w0.unsqueeze(-1) + attn_splitk_out[1].float() * w1.unsqueeze(-1) + ) / w_sum.unsqueeze(-1) + attn_out.copy_(result.to(attn_out.dtype)) + return attn_out + + NUM_KV_SPLITS_POW2 = next_power_of_2(NUM_KV_SPLITS) + BLOCK_D = next_power_of_2(head_dim) + + if NUM_KV_SPLITS < NUM_KV_SPLITS_POW2: + lse_padded = torch.full( + (B, num_heads, NUM_KV_SPLITS_POW2), float("-inf"), device=lse_splitk_out.device, dtype=lse_splitk_out.dtype + ) + lse_padded[:, :, :NUM_KV_SPLITS] = lse_splitk_out + else: + lse_padded = lse_splitk_out + + ct.launch( + torch.cuda.current_stream(), + (B, num_heads, 1), + _splitk_reduce_kernel, + ( + attn_splitk_out, + lse_padded, + attn_out, + actual_seq_lens, + num_heads, + num_kv_len_per_split, + NUM_KV_SPLITS, + NUM_KV_SPLITS_POW2, + BLOCK_D, + ), + ) + return attn_out + + +def load_page( + cache, block_tables, page_table_offset, page, token, off_kv_h, NUM_PAGES, LOAD_BLOCK_N, BLOCK_D, _PAGE_SIZE +): + """ + Load data from paged cache via TMA. + + For single page, issues one TMA load. + For multiple pages, issues N independent TMA loads and concatenates via ct.cat. + Each individual load still uses TMA, loading one page at a time. + + Args: + cache: cache array [total_num_pages, PAGE_SIZE, N_KV_HEADS, BLOCK_D] + block_tables: flattened page table array + page_table_offset: offset into block_tables for current batch + page: starting page index in the page table + token: token offset within page + off_kv_h: KV head index + NUM_PAGES: number of pages to load (1, 2, or 4) + LOAD_BLOCK_N: tokens per page load (== PAGE_SIZE) + BLOCK_D: feature dimension + _PAGE_SIZE: tokens per page (unused; LOAD_BLOCK_N is used instead) + + Returns: + Loaded tensor of shape [NUM_PAGES * LOAD_BLOCK_N, BLOCK_D] + """ + if NUM_PAGES == 1: + page_id = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + data = ct.reshape( + ct.load( + cache, + index=(page_id, token // LOAD_BLOCK_N, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + elif NUM_PAGES == 2: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + data = ct.cat((d0, d1), 0) + elif NUM_PAGES == 4: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg2 = ct.gather(block_tables, (page_table_offset + page + 2,), padding_value=0).item() + d2 = ct.reshape( + ct.load( + cache, + index=(pg2, 0, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg3 = ct.gather(block_tables, (page_table_offset + page + 3,), padding_value=0).item() + d3 = ct.reshape( + ct.load( + cache, + index=(pg3, 0, off_kv_h, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + # ct.cat takes exactly a pair; chain for 4 pages + data = ct.cat((ct.cat((d0, d1), 0), ct.cat((d2, d3), 0)), 0) + return data + + +def load_page_wrapper( + curr_n, cache, block_tables, page_table_offset, off_kv_h, PAGE_SIZE, BLOCK_N, BLOCK_D, LOAD_BLOCK_N +): + """ + Load cache data (K or V) for current position. + + Computes page index and token offset from curr_n, then delegates to load_page. + """ + NUM_PAGES = BLOCK_N // LOAD_BLOCK_N + page = curr_n // PAGE_SIZE + token = curr_n % PAGE_SIZE + + data = load_page( + cache, block_tables, page_table_offset, page, token, off_kv_h, NUM_PAGES, LOAD_BLOCK_N, BLOCK_D, PAGE_SIZE + ) + return data + + +@ct.kernel +def _decode_attention_kv_paged_kernel( + q, + k_cache, + v_cache, + actual_seq_lens, + block_tables, + o_ptr, + lse_out, + num_batches: ConstInt, + total_num_pages: ConstInt, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + PAGE_SIZE: ConstInt, + BLOCK_H: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + NUM_KV_SPLITS: ConstInt, + KV_LEN_PER_SPLIT: ConstInt, + HAS_LSE_OUT: ConstBool, + stride_block_table: ConstInt, + NUM_HEAD_BLOCKS: ConstInt, + TRANS_QK: ConstBool, + LOAD_BLOCK_N: ConstInt, + NUM_PAGES_PER_BLOCK: ConstInt, +): + batch_id = ct.bid(0) + head_block_id = ct.bid(1) + kv_split_id = ct.bid(2) + + kv_head_id = head_block_id // NUM_HEAD_BLOCKS + hb = head_block_id % NUM_HEAD_BLOCKS + + seq_len_tile = ct.gather(actual_seq_lens, (batch_id,), padding_value=0) + seq_len = seq_len_tile.item() + + qk_scale = k_scale * INV_LOG_2 + page_table_offset = batch_id * stride_block_table + + if KV_LEN_PER_SPLIT > 0: + start_n = KV_LEN_PER_SPLIT * kv_split_id + end_n = min(start_n + KV_LEN_PER_SPLIT, seq_len) + else: + start_n = 0 + end_n = seq_len + + if start_n >= end_n: + return + + num_iters = (end_n - start_n + BLOCK_N - 1) // BLOCK_N + offs_n_base = ct.arange(BLOCK_N, dtype=ct.int32) + tail_n = start_n + ((end_n - start_n) // BLOCK_N) * BLOCK_N + + head_offset = kv_head_id * QUERY_GROUP_SIZE + hb * BLOCK_H + head_block_idx = head_offset // BLOCK_H + + q_tile = ct.load( + q, index=(batch_id, head_block_idx, 0), shape=(1, BLOCK_H, BLOCK_D), order=(0, 1, 2), allow_tma=True, latency=2 + ) + q_tile = ct.reshape(q_tile, (BLOCK_H, BLOCK_D)) + + neg_inf_h = ct.full((BLOCK_H,), -math.inf, dtype=ct.float32) + ones_h = ct.full((BLOCK_H,), 1.0, dtype=ct.float32) + + if TRANS_QK: + m_i = neg_inf_h + l_i = ct.full((BLOCK_N, BLOCK_H), 0.0, dtype=ct.float32) + acc = ct.full((BLOCK_D, BLOCK_H), 0.0, dtype=ct.float32) + qk_zeros = ct.full((BLOCK_N, BLOCK_H), 0.0, dtype=ct.float32) + mask_fill = ct.full((BLOCK_N, BLOCK_H), -1.0e6, dtype=ct.float32) + + for iter_idx in range(num_iters): + curr_n = start_n + iter_idx * BLOCK_N + + if NUM_PAGES_PER_BLOCK == 1: + # Single-page path: share page_id between K and V loads + page = curr_n // PAGE_SIZE + token = curr_n % PAGE_SIZE + page_id = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + token_idx = token // LOAD_BLOCK_N + + k = ct.reshape( + ct.load( + k_cache, + index=(page_id, token_idx, kv_head_id, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + else: + k = load_page_wrapper( + curr_n, + k_cache, + block_tables, + page_table_offset, + kv_head_id, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + ) + + qk = ct.mma(k, ct.transpose(q_tile), acc=qk_zeros) + + if curr_n >= tail_n: + offs_n = curr_n + offs_n_base + mask = ct.reshape(ct.less(offs_n, end_n), (BLOCK_N, 1)) + qk = ct.where(mask, qk, mask_fill) + + qk_max = ct.max(qk, axis=0, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (1, BLOCK_H)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, ct.reshape(alpha, (1, BLOCK_H)), flush_to_zero=True), p, flush_to_zero=True) + acc = ct.mul(acc, ct.reshape(alpha, (1, BLOCK_H)), flush_to_zero=True) + + if NUM_PAGES_PER_BLOCK == 1: + # Reuse page_id from K load + v = ct.reshape( + ct.load( + v_cache, + index=(page_id, token_idx, kv_head_id, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + else: + v = load_page_wrapper( + curr_n, + v_cache, + block_tables, + page_table_offset, + kv_head_id, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + ) + + acc = ct.mma(ct.transpose(v), ct.astype(p, q.dtype), acc=acc) + m_i = m_ij + + l_i_sum = ct.sum(l_i, axis=0, keepdims=False) + l_i_expanded = ct.reshape(l_i_sum, (1, BLOCK_H)) + acc = ct.truediv( + ct.mul(acc, v_scale, flush_to_zero=True), l_i_expanded, flush_to_zero=True, rounding_mode=RMd.APPROX + ) + acc_out = ct.astype(ct.transpose(acc), o_ptr.dtype) + else: + m_i = neg_inf_h + l_i = ones_h + acc = ct.full((BLOCK_H, BLOCK_D), 0.0, dtype=ct.float32) + qk_zeros = ct.full((BLOCK_H, BLOCK_N), 0.0, dtype=ct.float32) + mask_fill = ct.full((BLOCK_H, BLOCK_N), -1.0e6, dtype=ct.float32) + + for iter_idx in range(num_iters): + curr_n = start_n + iter_idx * BLOCK_N + + if NUM_PAGES_PER_BLOCK == 1: + # Single-page path: share page_id between K and V loads + page = curr_n // PAGE_SIZE + token = curr_n % PAGE_SIZE + page_id = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + token_idx = token // LOAD_BLOCK_N + + k = ct.reshape( + ct.load( + k_cache, + index=(page_id, token_idx, kv_head_id, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + else: + k = load_page_wrapper( + curr_n, + k_cache, + block_tables, + page_table_offset, + kv_head_id, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + ) + + qk = ct.mma(q_tile, ct.transpose(k), acc=qk_zeros) + + if curr_n >= tail_n: + offs_n = curr_n + offs_n_base + mask = ct.reshape(ct.less(offs_n, end_n), (1, BLOCK_N)) + qk = ct.where(mask, qk, mask_fill) + + qk_max = ct.max(qk, axis=1, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (BLOCK_H, 1)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, alpha, flush_to_zero=True), ct.sum(p, axis=1, keepdims=False), flush_to_zero=True) + + if NUM_PAGES_PER_BLOCK == 1: + # Reuse page_id from K load + v = ct.reshape( + ct.load( + v_cache, + index=(page_id, token_idx, kv_head_id, 0), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + else: + v = load_page_wrapper( + curr_n, + v_cache, + block_tables, + page_table_offset, + kv_head_id, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + ) + + acc = ct.mul(acc, ct.reshape(alpha, (BLOCK_H, 1)), flush_to_zero=True) + acc = ct.mma(ct.astype(p, q.dtype), v, acc=acc) + m_i = m_ij + + l_i_expanded = ct.reshape(l_i, (BLOCK_H, 1)) + acc = ct.truediv( + ct.mul(acc, v_scale, flush_to_zero=True), l_i_expanded, flush_to_zero=True, rounding_mode=RMd.APPROX + ) + acc_out = ct.astype(acc, o_ptr.dtype) + + acc_4d = ct.reshape(acc_out, (1, 1, BLOCK_H, BLOCK_D)) + ct.store( + o_ptr, + index=(kv_split_id, batch_id, head_block_idx, 0), + tile=acc_4d, + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ) + + if HAS_LSE_OUT: + lse = m_i + ct.log2(l_i if not TRANS_QK else ct.sum(l_i, axis=0, keepdims=False)) + offs_h = ct.arange(BLOCK_H, dtype=ct.int32) + lse_indices = (batch_id, head_offset + offs_h, kv_split_id) + ct.scatter(lse_out, lse_indices, lse) + + +def load_page_mla(cache, block_tables, page_table_offset, page, token, NUM_PAGES, LOAD_BLOCK_N, BLOCK_DIM, _PAGE_SIZE): + """ + Load data from paged MLA cache (3D: [total_num_pages, PAGE_SIZE, dim]) via TMA. + + For single page, issues one TMA load. + For multiple pages, issues N independent TMA loads and concatenates via ct.cat. + + Args: + cache: cache array [total_num_pages, PAGE_SIZE, dim] + block_tables: flattened page table array + page_table_offset: offset into block_tables for current batch + page: starting page index in the page table + token: token offset within page + NUM_PAGES: number of pages to load (1, 2, or 4) + LOAD_BLOCK_N: tokens per page load (== PAGE_SIZE) + BLOCK_DIM: feature dimension (BLOCK_D or BLOCK_R) + _PAGE_SIZE: tokens per page (unused; LOAD_BLOCK_N is used instead) + + Returns: + Loaded tensor of shape [NUM_PAGES * LOAD_BLOCK_N, BLOCK_DIM] + """ + if NUM_PAGES == 1: + page_id = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + data = ct.reshape( + ct.load( + cache, + index=(page_id, token // LOAD_BLOCK_N, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + elif NUM_PAGES == 2: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + data = ct.cat((d0, d1), 0) + elif NUM_PAGES == 4: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + pg2 = ct.gather(block_tables, (page_table_offset + page + 2,), padding_value=0).item() + d2 = ct.reshape( + ct.load( + cache, + index=(pg2, 0, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + pg3 = ct.gather(block_tables, (page_table_offset + page + 3,), padding_value=0).item() + d3 = ct.reshape( + ct.load( + cache, + index=(pg3, 0, 0), + shape=(1, LOAD_BLOCK_N, BLOCK_DIM), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (LOAD_BLOCK_N, BLOCK_DIM), + ) + # ct.cat takes exactly a pair; chain for 4 pages + data = ct.cat((ct.cat((d0, d1), 0), ct.cat((d2, d3), 0)), 0) + return data + + +def load_page_mla_wrapper(curr_n, cache, block_tables, page_table_offset, PAGE_SIZE, BLOCK_N, BLOCK_DIM, LOAD_BLOCK_N): + """ + Load MLA cache data (K, V, or K_rope) for current position. + + Computes page index and token offset from curr_n, then delegates to load_page_mla. + """ + NUM_PAGES = BLOCK_N // LOAD_BLOCK_N + page = curr_n // PAGE_SIZE + token = curr_n % PAGE_SIZE + + data = load_page_mla( + cache, block_tables, page_table_offset, page, token, NUM_PAGES, LOAD_BLOCK_N, BLOCK_DIM, PAGE_SIZE + ) + return data + + +@ct.kernel +def _decode_mla_kv_paged_kernel( + q_nope, + q_rope, + k_cache, + v_cache, + k_rope, + actual_seq_lens, + block_tables, + o_ptr, + lse_out, + num_batches: ConstInt, + total_num_pages: ConstInt, + k_scale: ConstFloat, + v_scale: ConstFloat, + PAGE_SIZE: ConstInt, + BLOCK_H: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + NUM_KV_SPLITS: ConstInt, + KV_LEN_PER_SPLIT: ConstInt, + HAS_LSE_OUT: ConstBool, + stride_block_table: ConstInt, + LOAD_BLOCK_N: ConstInt, + NUM_PAGES_PER_BLOCK: ConstInt, +): + batch_id = ct.bid(0) + head_block_id = ct.bid(1) + kv_split_id = ct.bid(2) + + seq_len_tile = ct.gather(actual_seq_lens, (batch_id,), padding_value=0) + seq_len = seq_len_tile.item() + + qk_scale = k_scale * INV_LOG_2 + page_table_offset = batch_id * stride_block_table + + if KV_LEN_PER_SPLIT > 0: + start_n = KV_LEN_PER_SPLIT * kv_split_id + end_n = min(start_n + KV_LEN_PER_SPLIT, seq_len) + else: + start_n = 0 + end_n = seq_len + + if start_n >= end_n: + return + + num_iters = (end_n - start_n + BLOCK_N - 1) // BLOCK_N + offs_n_base = ct.arange(BLOCK_N, dtype=ct.int32) + # tail_n = start position of last incomplete block (where masking is needed) + tail_n = start_n + ((end_n - start_n) // BLOCK_N) * BLOCK_N + + head_block_idx = head_block_id + head_offset = head_block_id * BLOCK_H + + q_nope_tile = ct.load( + q_nope, + index=(batch_id, head_block_idx, 0), + shape=(1, BLOCK_H, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ) + q_nope_tile = ct.reshape(q_nope_tile, (BLOCK_H, BLOCK_D)) + + q_rope_tile = ct.load( + q_rope, + index=(batch_id, head_block_idx, 0), + shape=(1, BLOCK_H, BLOCK_R), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ) + q_rope_tile = ct.reshape(q_rope_tile, (BLOCK_H, BLOCK_R)) + + m_i = ct.full((BLOCK_H,), -math.inf, dtype=ct.float32) + l_i = ct.full((BLOCK_H,), 1.0, dtype=ct.float32) + acc = ct.full((BLOCK_H, BLOCK_D), 0.0, dtype=ct.float32) + + for iter_idx in range(num_iters): + curr_n = start_n + iter_idx * BLOCK_N + + if NUM_PAGES_PER_BLOCK == 1: + # Single-page path: share page_id between K, K_rope, and V loads + page_idx = curr_n // PAGE_SIZE + token_block_idx = (curr_n % PAGE_SIZE) // BLOCK_N + page_id = ct.gather(block_tables, (page_table_offset + page_idx,), padding_value=0).item() + + k_tile = ct.reshape( + ct.load( + k_cache, + index=(page_id, token_block_idx, 0), + shape=(1, BLOCK_N, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (BLOCK_N, BLOCK_D), + ) + + k_rope_tile = ct.reshape( + ct.load( + k_rope, + index=(page_id, token_block_idx, 0), + shape=(1, BLOCK_N, BLOCK_R), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (BLOCK_N, BLOCK_R), + ) + else: + # Multi-page path: load BLOCK_N tokens across multiple pages + k_tile = load_page_mla_wrapper( + curr_n, + k_cache, + block_tables, + page_table_offset, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + ) + k_rope_tile = load_page_mla_wrapper( + curr_n, + k_rope, + block_tables, + page_table_offset, + PAGE_SIZE, + BLOCK_N, + BLOCK_R, + LOAD_BLOCK_N, + ) + + qk = ct.mma(q_nope_tile, ct.transpose(k_tile), acc=ct.full((BLOCK_H, BLOCK_N), 0.0, dtype=ct.float32)) + if BLOCK_R > 0: + qk = ct.mma(q_rope_tile, ct.transpose(k_rope_tile), acc=qk) + + if curr_n >= tail_n: + offs_n = curr_n + offs_n_base + mask = ct.reshape(ct.less(offs_n, end_n), (1, BLOCK_N)) + qk = ct.where(mask, qk, ct.full((BLOCK_H, BLOCK_N), -1.0e6, dtype=ct.float32)) + + qk_max = ct.max(qk, axis=1, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (BLOCK_H, 1)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, alpha, flush_to_zero=True), ct.sum(p, axis=1, keepdims=False), flush_to_zero=True) + + if NUM_PAGES_PER_BLOCK == 1: + # Reuse page_id from K load + v_tile = ct.reshape( + ct.load( + v_cache, + index=(page_id, token_block_idx, 0), + shape=(1, BLOCK_N, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + ), + (BLOCK_N, BLOCK_D), + ) + else: + v_tile = load_page_mla_wrapper( + curr_n, + v_cache, + block_tables, + page_table_offset, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + ) + + acc = ct.mul(acc, ct.reshape(alpha, (BLOCK_H, 1)), flush_to_zero=True) + acc = ct.mma(ct.astype(p, q_nope.dtype), v_tile, acc=acc) + m_i = m_ij + + l_i_expanded = ct.reshape(l_i, (BLOCK_H, 1)) + acc = ct.truediv( + ct.mul(acc, v_scale, flush_to_zero=True), l_i_expanded, flush_to_zero=True, rounding_mode=RMd.APPROX + ) + acc_out = ct.astype(acc, o_ptr.dtype) + + acc_4d = ct.reshape(acc_out, (1, 1, BLOCK_H, BLOCK_D)) + ct.store( + o_ptr, + index=(kv_split_id, batch_id, head_block_idx, 0), + tile=acc_4d, + order=(0, 1, 2, 3), + allow_tma=True, + latency=2, + ) + + if HAS_LSE_OUT: + lse = m_i + ct.log2(l_i) + offs_h = ct.arange(BLOCK_H, dtype=ct.int32) + lse_indices = (batch_id, head_offset + offs_h, kv_split_id) + ct.scatter(lse_out, lse_indices, lse) + + +_gqa_decode_autotune_cache = {} + + +def _get_gqa_decode_autotune_configs(query_group_size: int, page_size: int = 128): + """Get autotune configurations for GQA decode kernel.""" + cache_key = (query_group_size, page_size) + if cache_key not in _gqa_decode_autotune_cache: + configs = [] + for BLOCK_H in [8, 16, 32, 64]: + if BLOCK_H <= query_group_size and query_group_size % BLOCK_H == 0: + # Allow BLOCK_N values that are multiples of page_size for multi-page loading + for BLOCK_N in [32, 64, 128]: + if BLOCK_N < page_size: + continue + # BLOCK_N must be a multiple of page_size (or equal) + if BLOCK_N > page_size and BLOCK_N % page_size != 0: + continue + for occupancy in [1, 2]: + configs.append(SimpleNamespace(BLOCK_H=BLOCK_H, BLOCK_N=BLOCK_N, occupancy=occupancy)) + if not configs: + BLOCK_H = query_group_size if query_group_size > 0 else 1 + if not is_power_of_2(BLOCK_H): + BLOCK_H = next_power_of_2(BLOCK_H) // 2 + if BLOCK_H == 0: + BLOCK_H = 1 + configs.append(SimpleNamespace(BLOCK_H=BLOCK_H, BLOCK_N=min(64, page_size), occupancy=1)) + _gqa_decode_autotune_cache[cache_key] = configs + return _gqa_decode_autotune_cache[cache_key] + + +def _gqa_decode_autotune_base( + stream, + q, + k_cache, + v_cache, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + num_qo_heads, + num_kv_heads, + total_num_pages, + k_scale, + v_scale, + page_size, + head_dim_qk, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + TRANS_QK, +): + configs = _get_gqa_decode_autotune_configs(QUERY_GROUP_SIZE, page_size) + + cache_key = ( + num_batch, + num_qo_heads, + num_kv_heads, + total_num_pages, + page_size, + head_dim_qk, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + TRANS_QK, + q.dtype, + str(q.device), + ) + if cache_key not in _decode_kv_paged_tune_cache: + result = exhaustive_search( + list(configs), + stream, + lambda cfg: (num_batch, num_kv_heads * max(QUERY_GROUP_SIZE // cfg.BLOCK_H, 1), NUM_KV_SPLITS), + _decode_attention_kv_paged_kernel, + lambda cfg: ( + q, + k_cache, + v_cache, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + total_num_pages, + k_scale, + v_scale, + num_kv_heads, + page_size, + cfg.BLOCK_H, + cfg.BLOCK_N, + head_dim_qk, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + max(QUERY_GROUP_SIZE // cfg.BLOCK_H, 1), + TRANS_QK, + min(cfg.BLOCK_N, page_size), + max(cfg.BLOCK_N // page_size, 1), + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _decode_kv_paged_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + _decode_attention_kv_paged_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _decode_kv_paged_tune_cache[cache_key] + ct.launch( + stream, + (num_batch, num_kv_heads * max(QUERY_GROUP_SIZE // best_cfg.BLOCK_H, 1), NUM_KV_SPLITS), + tuned_kernel, + ( + q, + k_cache, + v_cache, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + total_num_pages, + k_scale, + v_scale, + num_kv_heads, + page_size, + best_cfg.BLOCK_H, + best_cfg.BLOCK_N, + head_dim_qk, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + max(QUERY_GROUP_SIZE // best_cfg.BLOCK_H, 1), + TRANS_QK, + min(best_cfg.BLOCK_N, page_size), + max(best_cfg.BLOCK_N // page_size, 1), + ), + ) + return Att_Out + + +@register_impl("flashinfer.attention.decode_attention_kv_paged", backend="cutile") +def decode_attention_kv_paged( + q, + k_cache, + v_cache, + actual_seq_lens, + block_tables, + k_scale, + v_scale, + max_seq_len: int = -1, + outputs: Optional[torch.Tensor] = None, + force_split_kv: bool = False, + force_persistent: bool = False, +): + num_batch = q.shape[0] + num_qo_heads = q.shape[1] + head_dim_qk = q.shape[-1] + + total_num_pages = k_cache.shape[0] + page_size = k_cache.shape[1] + num_kv_heads = k_cache.shape[2] + head_dim_vo = v_cache.shape[-1] + + QUERY_GROUP_SIZE = num_qo_heads // num_kv_heads + TRANS_QK = QUERY_GROUP_SIZE < 64 + + if not (is_power_of_2(head_dim_qk) and is_power_of_2(head_dim_vo)): + raise NotImplementedError( + f"CuTile decode attention requires power-of-2 dimensions. Got head_dim_qk={head_dim_qk}, head_dim_vo={head_dim_vo}." + ) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + LSE_Out = None + kv_len_per_split = -1 + NUM_KV_SPLITS = 1 + + if max_seq_len < 0: + max_pages_per_seq = block_tables.shape[1] if block_tables.dim() > 1 else block_tables.shape[0] + max_seq_len = max_pages_per_seq * page_size + + persistent = num_batch * num_kv_heads > NUM_SMS + if force_split_kv: + should_use_split_kv = True + else: + should_use_split_kv = not persistent and max_seq_len > 2048 + + if should_use_split_kv: + num_split_kv_estimated = max(NUM_SMS // num_batch, 1) + kv_len_per_split = 1 << ((max_seq_len // num_split_kv_estimated - 1).bit_length()) + kv_len_per_split = max(kv_len_per_split, 128) + NUM_KV_SPLITS = (max_seq_len + kv_len_per_split - 1) // kv_len_per_split + + if num_batch <= 4: + NUM_KV_SPLITS = min(NUM_KV_SPLITS, 4) + kv_len_per_split = (max_seq_len + NUM_KV_SPLITS - 1) // NUM_KV_SPLITS + # Align kv_len_per_split to next power of 2 (ensure alignment with BLOCK_N) + kv_len_per_split = 1 << (kv_len_per_split - 1).bit_length() + + # Initialize to 0 and -inf so empty splits (where start_n >= end_n) contribute nothing + Att_Out = torch.zeros((NUM_KV_SPLITS, num_batch, num_qo_heads, head_dim_vo), device=q.device, dtype=q.dtype) + LSE_Out = torch.full( + (num_batch, num_qo_heads, NUM_KV_SPLITS), float("-inf"), device=q.device, dtype=torch.float32 + ) + else: + outputs = ( + torch.empty((num_batch, num_qo_heads, head_dim_vo), device=q.device, dtype=q.dtype) + if outputs is None + else outputs + ) + Att_Out = outputs.reshape(NUM_KV_SPLITS, num_batch, num_qo_heads, head_dim_vo) + + actual_seq_lens_flat = actual_seq_lens.reshape(-1).contiguous() + block_tables_flat = block_tables.reshape(-1).contiguous() + stride_block_table = block_tables.shape[1] if block_tables.dim() > 1 else 1 + + LSE_Out_arg = LSE_Out if LSE_Out is not None else torch.zeros(1, device=q.device, dtype=torch.float32) + HAS_LSE_OUT = LSE_Out is not None + + _gqa_decode_autotune_base( + torch.cuda.current_stream(), + q, + k_cache, + v_cache, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + num_qo_heads, + num_kv_heads, + total_num_pages, + k_scale, + v_scale, + page_size, + head_dim_qk, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + TRANS_QK, + ) + + if should_use_split_kv: + return splitk_reduce_with_seq_len(Att_Out, LSE_Out, actual_seq_lens_flat, kv_len_per_split, outputs) + return outputs + + +def _mla_decode_autotune_configs(): + for bh in [16, 32]: + for bn in [16, 32, 64, 128]: + for occupancy in [1, 2]: + yield SimpleNamespace(BLOCK_H=bh, BLOCK_N=bn, occupancy=occupancy) + + +def _mla_decode_autotune_base( + stream, + q, + q_rope, + kv_cache, + k_rope, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + num_qo_heads, + total_num_pages, + k_scale, + v_scale, + page_size, + head_dim_qk, + head_dim_rope, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, +): + mla_cache_key = ( + num_batch, + num_qo_heads, + total_num_pages, + page_size, + head_dim_qk, + head_dim_rope, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + q.dtype, + str(q.device), + ) + if mla_cache_key not in _decode_mla_paged_tune_cache: + result = exhaustive_search( + list(_mla_decode_autotune_configs()), + stream, + lambda cfg: (num_batch, (num_qo_heads + cfg.BLOCK_H - 1) // cfg.BLOCK_H, NUM_KV_SPLITS), + _decode_mla_kv_paged_kernel, + lambda cfg: ( + q, + q_rope, + kv_cache, + kv_cache, + k_rope, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + total_num_pages, + k_scale, + v_scale, + page_size, + cfg.BLOCK_H, + min(cfg.BLOCK_N, page_size), + head_dim_qk, + head_dim_rope, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _decode_mla_paged_tune_cache[mla_cache_key] = ( + best_cfg, + ct.kernel( + _decode_mla_kv_paged_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _decode_mla_paged_tune_cache[mla_cache_key] + ct.launch( + stream, + (num_batch, (num_qo_heads + best_cfg.BLOCK_H - 1) // best_cfg.BLOCK_H, NUM_KV_SPLITS), + tuned_kernel, + ( + q, + q_rope, + kv_cache, + kv_cache, + k_rope, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + total_num_pages, + k_scale, + v_scale, + page_size, + best_cfg.BLOCK_H, + min(best_cfg.BLOCK_N, page_size), + head_dim_qk, + head_dim_rope, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + min(best_cfg.BLOCK_N, page_size), + max(best_cfg.BLOCK_N // page_size, 1), + ), + ) + return Att_Out + + +@register_impl("flashinfer.attention.decode_mla_kv_paged", backend="cutile") +def decode_mla_kv_paged( + q, + q_rope, + kv_cache, + k_rope, + actual_seq_lens, + block_tables, + k_scale, + v_scale, + max_seq_len: int = -1, + outputs: Optional[torch.Tensor] = None, + force_split_kv: bool = False, + force_persistent: bool = False, +): + num_qo_heads = q.shape[1] + head_dim_qk = q.shape[-1] + head_dim_rope = q_rope.shape[-1] + num_batch = q.shape[0] + total_num_pages = kv_cache.shape[0] + page_size = kv_cache.shape[1] + + QUERY_GROUP_SIZE = num_qo_heads + + use_autotune = os.environ.get("ENABLE_CUTILE_TUNE", "0") == "1" + if use_autotune: + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + num_head_blocks = max(QUERY_GROUP_SIZE // 32, 1) + total_work = num_batch * num_head_blocks + + if max_seq_len < 0: + max_pages_per_seq = block_tables.shape[1] if block_tables.dim() > 1 else block_tables.shape[0] + estimated_seq_len = max_pages_per_seq * page_size + else: + estimated_seq_len = max_seq_len + + # Split-KV heuristic: use split-KV when not persistent and seqlen > 256 + # This parallelizes across the sequence dimension for better SM utilization + if force_split_kv or (not force_persistent and estimated_seq_len > 256 and total_work < NUM_SMS): + should_use_split_kv = True + num_split_kv_estimated = max(NUM_SMS // num_batch, 1) + kv_len_per_split = estimated_seq_len // num_split_kv_estimated + kv_len_per_split = max(1 << (kv_len_per_split - 1).bit_length() if kv_len_per_split > 0 else 128, 128) + NUM_KV_SPLITS = (estimated_seq_len + kv_len_per_split - 1) // kv_len_per_split + max_seq_len = estimated_seq_len + else: + should_use_split_kv = False + NUM_KV_SPLITS = 1 + kv_len_per_split = -1 + + if should_use_split_kv: + # Initialize to 0 and -inf so empty splits contribute nothing + Att_Out = torch.zeros((NUM_KV_SPLITS, num_batch, num_qo_heads, head_dim_qk), device=q.device, dtype=q.dtype) + LSE_Out = torch.full( + (num_batch, num_qo_heads, NUM_KV_SPLITS), float("-inf"), device=q.device, dtype=torch.float32 + ) + else: + outputs = torch.empty_like(q) if outputs is None else outputs + Att_Out = outputs.reshape(1, num_batch, num_qo_heads, head_dim_qk) + LSE_Out = torch.zeros(1, device=q.device, dtype=torch.float32) + + actual_seq_lens_flat = actual_seq_lens.reshape(-1).contiguous() + block_tables_flat = block_tables.reshape(-1).contiguous() + stride_block_table = block_tables.shape[1] if block_tables.dim() > 1 else 1 + + HAS_LSE_OUT = should_use_split_kv + LSE_Out_arg = LSE_Out if should_use_split_kv else torch.zeros(1, device=q.device, dtype=torch.float32) + + _mla_decode_autotune_base( + torch.cuda.current_stream(), + q, + q_rope, + kv_cache, + k_rope, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + num_qo_heads, + total_num_pages, + k_scale, + v_scale, + page_size, + head_dim_qk, + head_dim_rope, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + ) + + if should_use_split_kv: + return splitk_reduce_with_seq_len(Att_Out, LSE_Out, actual_seq_lens_flat, kv_len_per_split, outputs) + return outputs + + use_large_block_h = num_batch >= 16 + + if use_large_block_h: + for bh in [128, 64, 32, 16, 8]: + if QUERY_GROUP_SIZE >= bh and QUERY_GROUP_SIZE % bh == 0: + BLOCK_H = bh + break + else: + BLOCK_H = max(QUERY_GROUP_SIZE, 1) + else: + for bh in [16, 8, 32]: + if QUERY_GROUP_SIZE >= bh and QUERY_GROUP_SIZE % bh == 0: + BLOCK_H = bh + break + else: + BLOCK_H = max(QUERY_GROUP_SIZE, 1) + + while QUERY_GROUP_SIZE % BLOCK_H != 0 and BLOCK_H > 1: + BLOCK_H = BLOCK_H // 2 + + if not is_power_of_2(BLOCK_H): + BLOCK_H = next_power_of_2(BLOCK_H) // 2 + if BLOCK_H == 0: + BLOCK_H = 1 + + BLOCK_N = page_size if page_size >= 16 else 16 + LOAD_BLOCK_N = min(BLOCK_N, page_size) + NUM_PAGES_PER_BLOCK = max(BLOCK_N // page_size, 1) + num_head_blocks = max(QUERY_GROUP_SIZE // BLOCK_H, 1) + + if not (is_power_of_2(head_dim_qk) and is_power_of_2(head_dim_rope) and is_power_of_2(BLOCK_H)): + raise NotImplementedError( + f"CuTile MLA decode requires power-of-2 dimensions. Got head_dim_qk={head_dim_qk}, head_dim_rope={head_dim_rope}, BLOCK_H={BLOCK_H}." + ) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + LSE_Out = None + kv_len_per_split = -1 + NUM_KV_SPLITS = 1 + + if max_seq_len < 0: + max_pages_per_seq = block_tables.shape[1] if block_tables.dim() > 1 else block_tables.shape[0] + estimated_seq_len = max_pages_per_seq * page_size + else: + estimated_seq_len = max_seq_len + + if force_split_kv and estimated_seq_len >= 1024: + num_split_kv_estimated = max(NUM_SMS // num_batch, 1) + kv_len_per_split = estimated_seq_len // num_split_kv_estimated + kv_len_per_split = max(next_power_of_2(kv_len_per_split), 128) + NUM_KV_SPLITS = (estimated_seq_len + kv_len_per_split - 1) // kv_len_per_split + should_use_split_kv = NUM_KV_SPLITS > 1 + max_seq_len = estimated_seq_len + elif not force_persistent and estimated_seq_len > 256 and num_batch * num_head_blocks < NUM_SMS: + num_split_kv_estimated = max(NUM_SMS // num_batch, 1) + kv_len_per_split = estimated_seq_len // num_split_kv_estimated + kv_len_per_split = max(next_power_of_2(kv_len_per_split), 128) + NUM_KV_SPLITS = (estimated_seq_len + kv_len_per_split - 1) // kv_len_per_split + should_use_split_kv = NUM_KV_SPLITS > 1 + max_seq_len = estimated_seq_len + else: + should_use_split_kv = False + kv_len_per_split = estimated_seq_len + + if should_use_split_kv: + # Initialize to 0 and -inf so empty splits contribute nothing + Att_Out = torch.zeros((NUM_KV_SPLITS, num_batch, num_qo_heads, head_dim_qk), device=q.device, dtype=q.dtype) + LSE_Out = torch.full( + (num_batch, num_qo_heads, NUM_KV_SPLITS), float("-inf"), device=q.device, dtype=torch.float32 + ) + grid = (num_batch, num_head_blocks, NUM_KV_SPLITS) + else: + outputs = torch.empty_like(q) if outputs is None else outputs + Att_Out = outputs.reshape(NUM_KV_SPLITS, num_batch, num_qo_heads, head_dim_qk) + grid = (num_batch, num_head_blocks, NUM_KV_SPLITS) + + actual_seq_lens_flat = actual_seq_lens.reshape(-1).contiguous() + block_tables_flat = block_tables.reshape(-1).contiguous() + stride_block_table = block_tables.shape[1] if block_tables.dim() > 1 else 1 + + LSE_Out_arg = LSE_Out if LSE_Out is not None else torch.zeros(1, device=q.device, dtype=torch.float32) + HAS_LSE_OUT = LSE_Out is not None + + kernel = _decode_mla_kv_paged_kernel + + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + q, + q_rope, + kv_cache, + kv_cache, + k_rope, + actual_seq_lens_flat, + block_tables_flat, + Att_Out, + LSE_Out_arg, + num_batch, + total_num_pages, + k_scale, + v_scale, + page_size, + BLOCK_H, + BLOCK_N, + head_dim_qk, + head_dim_rope, + QUERY_GROUP_SIZE, + NUM_KV_SPLITS, + kv_len_per_split, + HAS_LSE_OUT, + stride_block_table, + LOAD_BLOCK_N, + NUM_PAGES_PER_BLOCK, + ), + ) + + if should_use_split_kv: + return splitk_reduce_with_seq_len(Att_Out, LSE_Out, actual_seq_lens_flat, kv_len_per_split, outputs) + return outputs diff --git a/src/tilegym/suites/flashinfer/cutile/fmha_prefill_bsr.py b/src/tilegym/suites/flashinfer/cutile/fmha_prefill_bsr.py new file mode 100644 index 00000000..1be033b1 --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/fmha_prefill_bsr.py @@ -0,0 +1,1660 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +import math +import os +from types import SimpleNamespace +from typing import Optional + +import cuda.tile as ct +import torch +from cuda.tile import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl +from tilegym.ops.cutile.utils import is_power_of_2 +from tilegym.ops.cutile.utils import next_power_of_2 + +# Module-level tune caches for prefill kernels +_prefill_paged_lpt_tune_cache: dict = {} +_prefill_paged_tune_cache: dict = {} +_prefill_ragged_lpt_tune_cache: dict = {} +_prefill_ragged_tune_cache: dict = {} + +INV_LOG_2 = 1.0 / math.log(2) + +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] +ConstFloat = ct.Constant[float] + + +def _get_prefill_autotune_configs(page_size=None): + configs = [ + SimpleNamespace(BLOCK_M=128, BLOCK_N=32, occupancy=1, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=64, occupancy=1, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=128, occupancy=1, num_ctas=1), + ] + + if torch.cuda.get_device_capability()[0] != 9: + configs.extend( + [ + SimpleNamespace(BLOCK_M=256, BLOCK_N=128, occupancy=1, num_ctas=1), + SimpleNamespace(BLOCK_M=256, BLOCK_N=64, occupancy=1, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=16, occupancy=1, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=16, occupancy=2, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=32, occupancy=2, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=64, occupancy=2, num_ctas=1), + SimpleNamespace(BLOCK_M=128, BLOCK_N=128, occupancy=2, num_ctas=1), + ] + ) + + for cfg in configs: + if page_size is not None and cfg.BLOCK_N > page_size: + continue + yield cfg + + +def _load_page_prefill( + cache, + block_tables, + page_table_offset, + page, + token, + off_kv_h, + NUM_PAGES, + LOAD_BLOCK_N, + BLOCK_D, + _PAGE_SIZE, + dim3_offset=0, + LATENCY=3, +): + """ + Load data from paged cache via TMA for prefill attention. + + For single page, issues one TMA load. + For multiple pages, issues N independent TMA loads and concatenates via ct.cat. + """ + PAD_ZERO = ct.PaddingMode.ZERO + if NUM_PAGES == 1: + page_id = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + data = ct.reshape( + ct.load( + cache, + index=(page_id, token // LOAD_BLOCK_N, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + elif NUM_PAGES == 2: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + data = ct.cat((d0, d1), 0) + elif NUM_PAGES == 4: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg2 = ct.gather(block_tables, (page_table_offset + page + 2,), padding_value=0).item() + d2 = ct.reshape( + ct.load( + cache, + index=(pg2, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg3 = ct.gather(block_tables, (page_table_offset + page + 3,), padding_value=0).item() + d3 = ct.reshape( + ct.load( + cache, + index=(pg3, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + data = ct.cat((ct.cat((d0, d1), 0), ct.cat((d2, d3), 0)), 0) + elif NUM_PAGES == 8: + pg0 = ct.gather(block_tables, (page_table_offset + page,), padding_value=0).item() + d0 = ct.reshape( + ct.load( + cache, + index=(pg0, token // LOAD_BLOCK_N, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg1 = ct.gather(block_tables, (page_table_offset + page + 1,), padding_value=0).item() + d1 = ct.reshape( + ct.load( + cache, + index=(pg1, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg2 = ct.gather(block_tables, (page_table_offset + page + 2,), padding_value=0).item() + d2 = ct.reshape( + ct.load( + cache, + index=(pg2, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg3 = ct.gather(block_tables, (page_table_offset + page + 3,), padding_value=0).item() + d3 = ct.reshape( + ct.load( + cache, + index=(pg3, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg4 = ct.gather(block_tables, (page_table_offset + page + 4,), padding_value=0).item() + d4 = ct.reshape( + ct.load( + cache, + index=(pg4, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg5 = ct.gather(block_tables, (page_table_offset + page + 5,), padding_value=0).item() + d5 = ct.reshape( + ct.load( + cache, + index=(pg5, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg6 = ct.gather(block_tables, (page_table_offset + page + 6,), padding_value=0).item() + d6 = ct.reshape( + ct.load( + cache, + index=(pg6, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + pg7 = ct.gather(block_tables, (page_table_offset + page + 7,), padding_value=0).item() + d7 = ct.reshape( + ct.load( + cache, + index=(pg7, 0, off_kv_h, dim3_offset), + shape=(1, LOAD_BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2, 3), + allow_tma=True, + latency=LATENCY, + padding_mode=PAD_ZERO, + ), + (LOAD_BLOCK_N, BLOCK_D), + ) + data = ct.cat( + ( + ct.cat((ct.cat((d0, d1), 0), ct.cat((d2, d3), 0)), 0), + ct.cat((ct.cat((d4, d5), 0), ct.cat((d6, d7), 0)), 0), + ), + 0, + ) + return data + + +def _load_page_wrapper_prefill( + curr_n, + cache, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + dim3_offset=0, + LATENCY=3, +): + NUM_PAGES = BLOCK_N // LOAD_BLOCK_N + page = curr_n // PAGE_SIZE + token = curr_n % PAGE_SIZE + return _load_page_prefill( + cache, + block_tables, + page_table_offset, + page, + token, + off_kv_h, + NUM_PAGES, + LOAD_BLOCK_N, + BLOCK_D, + PAGE_SIZE, + dim3_offset, + LATENCY, + ) + + +def _prefill_attention_paged_body( + batch_id, + head_id, + seq_block_id, + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + block_tables, + o_ptr, + lse_ptr, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + PAGE_SIZE: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + stride_block_table: ConstInt, + IS_CAUSAL: ConstBool, + LOAD_BLOCK_N: ConstInt, +): + # Load sequence info + seq_start_idx_tile = ct.gather(batch_offsets, (batch_id,), padding_value=0) + seq_start_index = seq_start_idx_tile.item() + + seq_len_q_tile = ct.gather(actual_seq_lens_q, (batch_id,), padding_value=0) + seq_len_q = seq_len_q_tile.item() + + seq_len_kv_tile = ct.gather(actual_seq_lens_kv, (batch_id,), padding_value=0) + seq_len_kv = seq_len_kv_tile.item() + + start_m = BLOCK_M * seq_block_id + + if start_m >= seq_len_q: + return + + N_HEADS = N_KV_HEADS * QUERY_GROUP_SIZE + off_kv_h = head_id // QUERY_GROUP_SIZE + qk_scale = k_scale * INV_LOG_2 + PAD_ZERO = ct.PaddingMode.ZERO + + page_table_offset = batch_id * stride_block_table + + q_seq = q_ptr.slice(axis=0, start=seq_start_index, stop=seq_start_index + seq_len_q) + o_seq = o_ptr.slice(axis=0, start=seq_start_index, stop=seq_start_index + seq_len_q) + + q_tile = ct.load( + q_seq, + index=(seq_block_id, head_id, 0), + shape=(BLOCK_M, 1, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + q = ct.reshape(q_tile, (BLOCK_M, BLOCK_D)) + + q_pe = None + if BLOCK_R > 0: + q_pe_tile = ct.load( + q_seq, + index=(seq_block_id, head_id, BLOCK_D // BLOCK_R), + shape=(BLOCK_M, 1, BLOCK_R), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + q_pe = ct.reshape(q_pe_tile, (BLOCK_M, BLOCK_R)) + + # Initialize accumulators + m_i = ct.full((BLOCK_M,), -math.inf, dtype=ct.float32) + l_i = ct.full((BLOCK_M,), 1.0, dtype=ct.float32) + acc = ct.full((BLOCK_M, BLOCK_D), 0.0, dtype=ct.float32) + + # Pre-allocate zero accumulator for QK (hoisted outside loop) + qk_zeros = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + offs_n_base = ct.arange(BLOCK_N, dtype=ct.int32) + offs_m = start_m + ct.arange(BLOCK_M, dtype=ct.int32) + + # Compute bounds for two-stage causal split + if IS_CAUSAL: + # Off-band: everything before the diagonal block (fully unmasked) + off_band_hi = ct.minimum(seq_len_kv, start_m) + # On-band: the diagonal block itself + on_band_lo = start_m + on_band_hi = ct.minimum(seq_len_kv, start_m + BLOCK_M) + else: + # Non-causal: process everything in off-band loop + off_band_hi = seq_len_kv + on_band_lo = 0 + on_band_hi = 0 + + # ========== Stage 1: Off-band loop (NO causal mask) ========== + off_band_iters = (off_band_hi + BLOCK_N - 1) // BLOCK_N + for iter_idx in range(off_band_iters): + curr_n = iter_idx * BLOCK_N + + k = _load_page_wrapper_prefill( + curr_n, + k_cache_ptr, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + 0, + 3, + ) + + qk = ct.mma(q, ct.transpose(k), acc=qk_zeros) + + if BLOCK_R > 0: + k_pe = _load_page_wrapper_prefill( + curr_n, + k_cache_ptr, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_R, + LOAD_BLOCK_N, + BLOCK_D // BLOCK_R, + 3, + ) + qk = ct.mma(q_pe, ct.transpose(k_pe), acc=qk) + + qk_max = ct.max(qk, axis=1, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (BLOCK_M, 1)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, alpha, flush_to_zero=True), ct.sum(p, axis=1, keepdims=False), flush_to_zero=True) + acc = ct.mul(acc, ct.reshape(alpha, (BLOCK_M, 1)), flush_to_zero=True) + + v = _load_page_wrapper_prefill( + curr_n, + v_cache_ptr, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + 0, + 4, + ) + + acc = ct.mma(ct.astype(p, q.dtype), v, acc=acc) + m_i = m_ij + + # ========== Stage 2: On-band loop (WITH causal mask) ========== + if IS_CAUSAL: + on_band_iters = (on_band_hi - on_band_lo + BLOCK_N - 1) // BLOCK_N + for iter_idx in range(on_band_iters): + curr_n = on_band_lo + iter_idx * BLOCK_N + + k = _load_page_wrapper_prefill( + curr_n, + k_cache_ptr, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + 0, + 3, + ) + + qk = ct.mma(q, ct.transpose(k), acc=qk_zeros) + + if BLOCK_R > 0: + k_pe = _load_page_wrapper_prefill( + curr_n, + k_cache_ptr, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_R, + LOAD_BLOCK_N, + BLOCK_D // BLOCK_R, + 3, + ) + qk = ct.mma(q_pe, ct.transpose(k_pe), acc=qk) + + offs_n = curr_n + offs_n_base + causal_mask = ct.reshape(offs_m, (BLOCK_M, 1)) >= ct.reshape(offs_n, (1, BLOCK_N)) + qk = ct.where(causal_mask, qk, ct.full((BLOCK_M, BLOCK_N), -1.0e6, dtype=ct.float32)) + + qk_max = ct.max(qk, axis=1, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (BLOCK_M, 1)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, alpha, flush_to_zero=True), ct.sum(p, axis=1, keepdims=False), flush_to_zero=True) + acc = ct.mul(acc, ct.reshape(alpha, (BLOCK_M, 1)), flush_to_zero=True) + + v = _load_page_wrapper_prefill( + curr_n, + v_cache_ptr, + block_tables, + page_table_offset, + off_kv_h, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + LOAD_BLOCK_N, + 0, + 4, + ) + + acc = ct.mma(ct.astype(p, q.dtype), v, acc=acc) + m_i = m_ij + + # Epilogue: normalize and store with RMd.APPROX + l_i_rcp = ct.truediv(v_scale, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = ct.mul(acc, ct.reshape(l_i_rcp, (BLOCK_M, 1)), flush_to_zero=True) + lse = m_i + ct.log2(l_i) + + # Store output using TMA + acc_out = ct.astype(acc, o_ptr.dtype) + acc_3d = ct.reshape(acc_out, (BLOCK_M, 1, BLOCK_D)) + ct.store( + o_seq, + index=(seq_block_id, head_id, 0), + tile=acc_3d, + order=(0, 1, 2), + allow_tma=True, + latency=2, + ) + + # Store LSE - lse_ptr is 2D [total_tokens, num_heads] + lse_scaled = lse * (1.0 / INV_LOG_2) # multiply by constant instead of dividing + offs_m_store = ct.arange(BLOCK_M, dtype=ct.int32) + token_indices = seq_start_index + start_m + offs_m_store + head_indices = ct.full((BLOCK_M,), head_id, dtype=ct.int32) + lse_mask = ct.less(offs_m_store + start_m, seq_len_q) + token_indices_masked = ct.where(lse_mask, token_indices, ct.full((BLOCK_M,), -1, dtype=ct.int32)) + lse_indices = (token_indices_masked, head_indices) + ct.scatter(lse_ptr, lse_indices, lse_scaled) + + +@ct.kernel +def _prefill_attention_paged_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + block_tables, + o_ptr, + lse_ptr, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + PAGE_SIZE: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + stride_block_table: ConstInt, + IS_CAUSAL: ConstBool, + LOAD_BLOCK_N: ConstInt, +): + batch_id = ct.bid(0) + head_id = ct.bid(1) + seq_block_id = ct.bid(2) + + _prefill_attention_paged_body( + batch_id, + head_id, + seq_block_id, + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + block_tables, + o_ptr, + lse_ptr, + k_scale, + v_scale, + N_KV_HEADS, + PAGE_SIZE, + BLOCK_M, + BLOCK_N, + BLOCK_D, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_block_table, + IS_CAUSAL, + LOAD_BLOCK_N, + ) + + +@ct.kernel +def _prefill_attention_paged_lpt_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + block_tables, + o_ptr, + lse_ptr, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + PAGE_SIZE: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + stride_block_table: ConstInt, + IS_CAUSAL: ConstBool, + LOAD_BLOCK_N: ConstInt, + NUM_HEADS: ConstInt, + NUM_BATCH: ConstInt, + MAX_SEQ_LEN: ConstInt, + SWIZZLE: ConstInt, + NUM_HB_QUOTIENT: ConstInt, + NUM_HB_REMAINDER: ConstInt, +): + tile_idx = ct.bid(0) + NUM_BLOCKS = (MAX_SEQ_LEN + BLOCK_M - 1) // BLOCK_M + l2_major_blocks = SWIZZLE * NUM_BLOCKS + bidhb = tile_idx // l2_major_blocks + l2_mod = tile_idx % l2_major_blocks + if bidhb < NUM_HB_QUOTIENT: + block = l2_mod // SWIZZLE + bidhb_residual = l2_mod % SWIZZLE + else: + block = l2_mod // NUM_HB_REMAINDER + bidhb_residual = l2_mod % NUM_HB_REMAINDER + bidhb_actual = bidhb * SWIZZLE + bidhb_residual + batch_id = bidhb_actual // NUM_HEADS + head_id = bidhb_actual % NUM_HEADS + seq_block_id = NUM_BLOCKS - 1 - block + + if tile_idx >= NUM_BLOCKS * NUM_HEADS * NUM_BATCH or batch_id >= NUM_BATCH or head_id >= NUM_HEADS: + return + + _prefill_attention_paged_body( + batch_id, + head_id, + seq_block_id, + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + block_tables, + o_ptr, + lse_ptr, + k_scale, + v_scale, + N_KV_HEADS, + PAGE_SIZE, + BLOCK_M, + BLOCK_N, + BLOCK_D, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_block_table, + IS_CAUSAL, + LOAD_BLOCK_N, + ) + + +def _prefill_attention_ragged_body( + batch_id, + head_id, + seq_block_id, + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + o_ptr, + lse_ptr, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + stride_q0: ConstInt, + stride_q1: ConstInt, + stride_q2: ConstInt, + stride_o0: ConstInt, + stride_o1: ConstInt, + stride_o2: ConstInt, + stride_k_0: ConstInt, + stride_k_1: ConstInt, + stride_k_2: ConstInt, + stride_v_0: ConstInt, + stride_v_1: ConstInt, + stride_v_2: ConstInt, + lse_stride_token: ConstInt, + IS_CAUSAL: ConstBool, +): + # Load sequence info + seq_start_idx_tile = ct.gather(batch_offsets, (batch_id,), padding_value=0) + seq_start_index = seq_start_idx_tile.item() + + seq_len_q_tile = ct.gather(actual_seq_lens_q, (batch_id,), padding_value=0) + seq_len_q = seq_len_q_tile.item() + + seq_len_kv_tile = ct.gather(actual_seq_lens_kv, (batch_id,), padding_value=0) + seq_len_kv = seq_len_kv_tile.item() + + start_m = BLOCK_M * seq_block_id + + if start_m >= seq_len_q: + return + + N_HEADS = N_KV_HEADS * QUERY_GROUP_SIZE + off_kv_h = head_id // QUERY_GROUP_SIZE + qk_scale = k_scale * INV_LOG_2 + PAD_ZERO = ct.PaddingMode.ZERO + + # Create sliced views for ragged tensors - enables TMA with block indices + # Slice along axis 0 to offset base pointer by seq_start_index + q_seq = q_ptr.slice(axis=0, start=seq_start_index, stop=seq_start_index + seq_len_q) + k_seq = k_cache_ptr.slice(axis=0, start=seq_start_index, stop=seq_start_index + seq_len_kv) + v_seq = v_cache_ptr.slice(axis=0, start=seq_start_index, stop=seq_start_index + seq_len_kv) + o_seq = o_ptr.slice(axis=0, start=seq_start_index, stop=seq_start_index + seq_len_q) + + # Load Q tile using TMA - use seq_block_id as block index + # q_seq shape: [seq_len_q, num_heads, head_dim_qk + head_dim_rope] + q_tile = ct.load( + q_seq, + index=(seq_block_id, head_id, 0), + shape=(BLOCK_M, 1, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + q = ct.reshape(q_tile, (BLOCK_M, BLOCK_D)) + + # Load Q_PE if needed + q_pe = None + if BLOCK_R > 0: + q_pe_tile = ct.load( + q_seq, + index=(seq_block_id, head_id, BLOCK_D // BLOCK_R), + shape=(BLOCK_M, 1, BLOCK_R), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + q_pe = ct.reshape(q_pe_tile, (BLOCK_M, BLOCK_R)) + + # Initialize accumulators + m_i = ct.full((BLOCK_M,), -math.inf, dtype=ct.float32) + l_i = ct.full((BLOCK_M,), 1.0, dtype=ct.float32) + acc = ct.full((BLOCK_M, BLOCK_D), 0.0, dtype=ct.float32) + + # Pre-allocate zero accumulator for QK (hoisted outside loop) + qk_zeros = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + offs_n_base = ct.arange(BLOCK_N, dtype=ct.int32) + offs_m = start_m + ct.arange(BLOCK_M, dtype=ct.int32) + + # Compute bounds for two-stage causal split + if IS_CAUSAL: + # Off-band: everything before the diagonal block (fully unmasked) + off_band_hi = ct.minimum(seq_len_kv, start_m) + # On-band: the diagonal block itself + on_band_lo = start_m + on_band_hi = ct.minimum(seq_len_kv, start_m + BLOCK_M) + else: + # Non-causal: process everything in off-band loop + off_band_hi = seq_len_kv + on_band_lo = 0 + on_band_hi = 0 + + # ========== Stage 1: Off-band loop (NO causal mask) ========== + off_band_iters = (off_band_hi + BLOCK_N - 1) // BLOCK_N + for iter_idx in range(off_band_iters): + curr_n = iter_idx * BLOCK_N + + k_tile = ct.load( + k_seq, + index=(iter_idx, off_kv_h, 0), + shape=(BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + k = ct.reshape(k_tile, (BLOCK_N, BLOCK_D)) + + qk = ct.mma(q, ct.transpose(k), acc=qk_zeros) + + if BLOCK_R > 0: + k_pe_tile = ct.load( + k_seq, + index=(iter_idx, off_kv_h, BLOCK_D // BLOCK_R), + shape=(BLOCK_N, 1, BLOCK_R), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + k_pe = ct.reshape(k_pe_tile, (BLOCK_N, BLOCK_R)) + qk = ct.mma(q_pe, ct.transpose(k_pe), acc=qk) + + qk_max = ct.max(qk, axis=1, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (BLOCK_M, 1)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, alpha, flush_to_zero=True), ct.sum(p, axis=1, keepdims=False), flush_to_zero=True) + acc = ct.mul(acc, ct.reshape(alpha, (BLOCK_M, 1)), flush_to_zero=True) + + v_tile = ct.load( + v_seq, + index=(iter_idx, off_kv_h, 0), + shape=(BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + v = ct.reshape(v_tile, (BLOCK_N, BLOCK_D)) + + acc = ct.mma(ct.astype(p, q.dtype), v, acc=acc) + m_i = m_ij + + # ========== Stage 2: On-band loop (WITH causal mask) ========== + if IS_CAUSAL: + on_band_iters = (on_band_hi - on_band_lo + BLOCK_N - 1) // BLOCK_N + on_band_block_start = on_band_lo // BLOCK_N + for iter_idx in range(on_band_iters): + curr_n = on_band_lo + iter_idx * BLOCK_N + block_idx = on_band_block_start + iter_idx + + k_tile = ct.load( + k_seq, + index=(block_idx, off_kv_h, 0), + shape=(BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + k = ct.reshape(k_tile, (BLOCK_N, BLOCK_D)) + + qk = ct.mma(q, ct.transpose(k), acc=qk_zeros) + + if BLOCK_R > 0: + k_pe_tile = ct.load( + k_seq, + index=(block_idx, off_kv_h, BLOCK_D // BLOCK_R), + shape=(BLOCK_N, 1, BLOCK_R), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + k_pe = ct.reshape(k_pe_tile, (BLOCK_N, BLOCK_R)) + qk = ct.mma(q_pe, ct.transpose(k_pe), acc=qk) + + offs_n = curr_n + offs_n_base + causal_mask = ct.reshape(offs_m, (BLOCK_M, 1)) >= ct.reshape(offs_n, (1, BLOCK_N)) + qk = ct.where(causal_mask, qk, ct.full((BLOCK_M, BLOCK_N), -1.0e6, dtype=ct.float32)) + + qk_max = ct.max(qk, axis=1, keepdims=False) + m_ij = ct.maximum(m_i, ct.mul(qk_max, qk_scale, flush_to_zero=True)) + p = ct.exp2( + ct.sub(ct.mul(qk, qk_scale, flush_to_zero=True), ct.reshape(m_ij, (BLOCK_M, 1)), flush_to_zero=True), + flush_to_zero=True, + ) + + alpha = ct.exp2(ct.sub(m_i, m_ij, flush_to_zero=True), flush_to_zero=True) + l_i = ct.add(ct.mul(l_i, alpha, flush_to_zero=True), ct.sum(p, axis=1, keepdims=False), flush_to_zero=True) + acc = ct.mul(acc, ct.reshape(alpha, (BLOCK_M, 1)), flush_to_zero=True) + + v_tile = ct.load( + v_seq, + index=(block_idx, off_kv_h, 0), + shape=(BLOCK_N, 1, BLOCK_D), + order=(0, 1, 2), + allow_tma=True, + latency=2, + padding_mode=PAD_ZERO, + ) + v = ct.reshape(v_tile, (BLOCK_N, BLOCK_D)) + + acc = ct.mma(ct.astype(p, q.dtype), v, acc=acc) + m_i = m_ij + + l_i_rcp = ct.truediv(v_scale, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = ct.mul(acc, ct.reshape(l_i_rcp, (BLOCK_M, 1)), flush_to_zero=True) + lse = m_i + ct.log2(l_i) + + acc_out = ct.astype(acc, o_ptr.dtype) + acc_3d = ct.reshape(acc_out, (BLOCK_M, 1, BLOCK_D)) + ct.store( + o_seq, + index=(seq_block_id, head_id, 0), + tile=acc_3d, + order=(0, 1, 2), + allow_tma=True, + latency=2, + ) + + lse_scaled = lse * (1.0 / INV_LOG_2) + offs_m_store = ct.arange(BLOCK_M, dtype=ct.int32) + token_indices = seq_start_index + start_m + offs_m_store + head_indices = ct.full((BLOCK_M,), head_id, dtype=ct.int32) + lse_mask = ct.less(offs_m_store + start_m, seq_len_q) + token_indices_masked = ct.where(lse_mask, token_indices, ct.full((BLOCK_M,), -1, dtype=ct.int32)) + lse_indices = (token_indices_masked, head_indices) + ct.scatter(lse_ptr, lse_indices, lse_scaled) + + +@ct.kernel +def _prefill_attention_ragged_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + o_ptr, + lse_ptr, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + stride_q0: ConstInt, + stride_q1: ConstInt, + stride_q2: ConstInt, + stride_o0: ConstInt, + stride_o1: ConstInt, + stride_o2: ConstInt, + stride_k_0: ConstInt, + stride_k_1: ConstInt, + stride_k_2: ConstInt, + stride_v_0: ConstInt, + stride_v_1: ConstInt, + stride_v_2: ConstInt, + lse_stride_token: ConstInt, + IS_CAUSAL: ConstBool, +): + """ + Prefill attention kernel with ragged (contiguous) KV cache. + Optimized with two-stage causal loop split: + - Stage 1 (off-band): Fully unmasked region before diagonal - no causal mask needed + - Stage 2 (on-band): Diagonal block where causal mask matters + """ + seq_block_id = ct.bid(0) + batch_id = ct.bid(1) + head_id = ct.bid(2) + + _prefill_attention_ragged_body( + batch_id, + head_id, + seq_block_id, + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + o_ptr, + lse_ptr, + k_scale, + v_scale, + N_KV_HEADS, + BLOCK_M, + BLOCK_N, + BLOCK_D, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_q0, + stride_q1, + stride_q2, + stride_o0, + stride_o1, + stride_o2, + stride_k_0, + stride_k_1, + stride_k_2, + stride_v_0, + stride_v_1, + stride_v_2, + lse_stride_token, + IS_CAUSAL, + ) + + +@ct.kernel +def _prefill_attention_ragged_lpt_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + o_ptr, + lse_ptr, + k_scale: ConstFloat, + v_scale: ConstFloat, + N_KV_HEADS: ConstInt, + BLOCK_M: ConstInt, + BLOCK_N: ConstInt, + BLOCK_D: ConstInt, + BLOCK_R: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + stride_q0: ConstInt, + stride_q1: ConstInt, + stride_q2: ConstInt, + stride_o0: ConstInt, + stride_o1: ConstInt, + stride_o2: ConstInt, + stride_k_0: ConstInt, + stride_k_1: ConstInt, + stride_k_2: ConstInt, + stride_v_0: ConstInt, + stride_v_1: ConstInt, + stride_v_2: ConstInt, + lse_stride_token: ConstInt, + IS_CAUSAL: ConstBool, + NUM_HEADS: ConstInt, + NUM_BATCH: ConstInt, + MAX_SEQ_LEN: ConstInt, + SWIZZLE: ConstInt, + NUM_HB_QUOTIENT: ConstInt, + NUM_HB_REMAINDER: ConstInt, +): + tile_idx = ct.bid(0) + NUM_BLOCKS = (MAX_SEQ_LEN + BLOCK_M - 1) // BLOCK_M + l2_major_blocks = SWIZZLE * NUM_BLOCKS + bidhb = tile_idx // l2_major_blocks + l2_mod = tile_idx % l2_major_blocks + if bidhb < NUM_HB_QUOTIENT: + block = l2_mod // SWIZZLE + bidhb_residual = l2_mod % SWIZZLE + else: + block = l2_mod // NUM_HB_REMAINDER + bidhb_residual = l2_mod % NUM_HB_REMAINDER + bidhb_actual = bidhb * SWIZZLE + bidhb_residual + batch_id = bidhb_actual // NUM_HEADS + head_id = bidhb_actual % NUM_HEADS + seq_block_id = NUM_BLOCKS - 1 - block # LPT: reverse order + + if tile_idx >= NUM_BLOCKS * NUM_HEADS * NUM_BATCH or batch_id >= NUM_BATCH or head_id >= NUM_HEADS: + return + + _prefill_attention_ragged_body( + batch_id, + head_id, + seq_block_id, + q_ptr, + k_cache_ptr, + v_cache_ptr, + actual_seq_lens_q, + actual_seq_lens_kv, + batch_offsets, + o_ptr, + lse_ptr, + k_scale, + v_scale, + N_KV_HEADS, + BLOCK_M, + BLOCK_N, + BLOCK_D, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_q0, + stride_q1, + stride_q2, + stride_o0, + stride_o1, + stride_o2, + stride_k_0, + stride_k_1, + stride_k_2, + stride_v_0, + stride_v_1, + stride_v_2, + lse_stride_token, + IS_CAUSAL, + ) + + +@register_impl("flashinfer.attention.prefill_attention_kv_paged", backend="cutile") +def prefill_attention_kv_paged( + q, + k_cache, + v_cache, + actual_seq_lens_q, + actual_seq_lens_kv, + actual_seq_offset, + block_tables, + k_scale, + v_scale, + num_batch, + max_seq_len, + is_causal: bool = True, + outputs: Optional[torch.Tensor] = None, + out_lse: Optional[torch.Tensor] = None, + use_lpt_scheduler: bool = True, +): + """ + Prefill attention with paged KV cache (CuTile implementation). + """ + # KV cache [num_pages, page_size, num_kv_heads, head_dim_qk] + total_num_pages = k_cache.shape[0] + page_size = k_cache.shape[1] + num_kv_heads = k_cache.shape[2] + num_qo_heads = q.shape[1] + head_dim_qk = q.shape[-1] + head_dim_vo = v_cache.shape[-1] + + BLOCK_R = head_dim_qk - head_dim_vo + QUERY_GROUP_SIZE = num_qo_heads // num_kv_heads + + outputs = ( + torch.empty( + [q.shape[0], num_qo_heads, head_dim_vo], + dtype=q.dtype, + device=q.device, + ) + if outputs is None + else outputs + ) + out_lse = ( + torch.zeros([q.shape[0], num_qo_heads], dtype=torch.float32, device=q.device) if out_lse is None else out_lse + ) + + # Flatten tensors for kernel + actual_seq_lens_q_flat = actual_seq_lens_q.reshape(-1).contiguous() + actual_seq_lens_kv_flat = actual_seq_lens_kv.reshape(-1).contiguous() + batch_offsets_flat = actual_seq_offset.reshape(-1).contiguous() + block_tables_flat = block_tables.reshape(-1).contiguous() + stride_block_table = block_tables.shape[1] if block_tables.dim() > 1 else 1 + + if use_lpt_scheduler: + element_size = q.element_size() + size_one_kv_head = max_seq_len * (head_dim_qk + head_dim_vo) * element_size + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + if size_l2 < size_one_kv_head: + swizzle = 1 + else: + log2_floor = (size_l2 // size_one_kv_head).bit_length() - 1 + swizzle = 1 << log2_floor + num_hb_quotient = (num_qo_heads * num_batch) // swizzle + num_hb_remainder = (num_qo_heads * num_batch) % swizzle + + paged_lpt_stream = torch.cuda.current_stream() + paged_lpt_cache_key = ( + num_batch, + num_qo_heads, + num_kv_heads, + total_num_pages, + page_size, + head_dim_qk, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + max_seq_len, + is_causal, + swizzle, + q.dtype, + str(q.device), + ) + if paged_lpt_cache_key not in _prefill_paged_lpt_tune_cache: + result = exhaustive_search( + list(_get_prefill_autotune_configs(page_size)), + paged_lpt_stream, + lambda cfg: ((max_seq_len + cfg.BLOCK_M - 1) // cfg.BLOCK_M * num_qo_heads * num_batch, 1, 1), + _prefill_attention_paged_lpt_kernel, + lambda cfg: ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + block_tables_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + page_size, + cfg.BLOCK_M, + cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_block_table, + is_causal, + min(cfg.BLOCK_N, page_size), + num_qo_heads, + num_batch, + max_seq_len, + swizzle, + num_hb_quotient, + max(num_hb_remainder, 1), + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _prefill_paged_lpt_tune_cache[paged_lpt_cache_key] = ( + best_cfg, + ct.kernel( + _prefill_attention_paged_lpt_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _prefill_paged_lpt_tune_cache[paged_lpt_cache_key] + ct.launch( + paged_lpt_stream, + ((max_seq_len + best_cfg.BLOCK_M - 1) // best_cfg.BLOCK_M * num_qo_heads * num_batch, 1, 1), + tuned_kernel, + ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + block_tables_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + page_size, + best_cfg.BLOCK_M, + best_cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_block_table, + is_causal, + min(best_cfg.BLOCK_N, page_size), + num_qo_heads, + num_batch, + max_seq_len, + swizzle, + num_hb_quotient, + max(num_hb_remainder, 1), + ), + ) + else: + paged_stream = torch.cuda.current_stream() + paged_cache_key = ( + num_batch, + num_qo_heads, + num_kv_heads, + total_num_pages, + page_size, + head_dim_qk, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + max_seq_len, + is_causal, + q.dtype, + str(q.device), + ) + if paged_cache_key not in _prefill_paged_tune_cache: + result = exhaustive_search( + list(_get_prefill_autotune_configs(page_size)), + paged_stream, + lambda cfg: (num_batch, num_qo_heads, (max_seq_len + cfg.BLOCK_M - 1) // cfg.BLOCK_M), + _prefill_attention_paged_kernel, + lambda cfg: ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + block_tables_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + page_size, + cfg.BLOCK_M, + cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_block_table, + is_causal, + min(cfg.BLOCK_N, page_size), + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _prefill_paged_tune_cache[paged_cache_key] = ( + best_cfg, + ct.kernel( + _prefill_attention_paged_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _prefill_paged_tune_cache[paged_cache_key] + ct.launch( + paged_stream, + (num_batch, num_qo_heads, (max_seq_len + best_cfg.BLOCK_M - 1) // best_cfg.BLOCK_M), + tuned_kernel, + ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + block_tables_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + page_size, + best_cfg.BLOCK_M, + best_cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + stride_block_table, + is_causal, + min(best_cfg.BLOCK_N, page_size), + ), + ) + + return outputs, out_lse + + +@register_impl("flashinfer.attention.prefill_attention_kv_ragged", backend="cutile") +def prefill_attention_kv_ragged( + q, + k_cache, + v_cache, + actual_seq_lens_q, + actual_seq_lens_kv, + actual_seq_offset, + block_tables, + k_scale, + v_scale, + num_batch, + max_seq_len, + is_causal: bool = True, + outputs: Optional[torch.Tensor] = None, + out_lse: Optional[torch.Tensor] = None, + use_lpt_scheduler: bool = True, +): + """ + Prefill attention with ragged KV cache (CuTile implementation). + """ + # KV cache [total_num_tokens, num_kv_heads, head_dim_qk] + num_kv_heads = k_cache.shape[1] + num_qo_heads = q.shape[1] + head_dim_qk = q.shape[-1] + head_dim_vo = v_cache.shape[-1] + + BLOCK_R = head_dim_qk - head_dim_vo + QUERY_GROUP_SIZE = num_qo_heads // num_kv_heads + + outputs = ( + torch.empty( + [q.shape[0], num_qo_heads, head_dim_vo], + device=q.device, + dtype=q.dtype, + ) + if outputs is None + else outputs + ) + out_lse = ( + torch.zeros([q.shape[0], num_qo_heads], dtype=torch.float32, device=q.device) if out_lse is None else out_lse + ) + + # Flatten tensors for kernel + actual_seq_lens_q_flat = actual_seq_lens_q.reshape(-1).contiguous() + actual_seq_lens_kv_flat = actual_seq_lens_kv.reshape(-1).contiguous() + batch_offsets_flat = actual_seq_offset.reshape(-1).contiguous() + + autotune_key = ( + QUERY_GROUP_SIZE, + num_kv_heads, + BLOCK_R, + head_dim_vo, + k_scale, + v_scale, + max_seq_len, + num_batch, + 3 if is_causal else 1, # STAGE + ) + + if use_lpt_scheduler: + element_size = q.element_size() + size_one_kv_head = max_seq_len * (head_dim_qk + head_dim_vo) * element_size + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + if size_l2 < size_one_kv_head: + swizzle = 1 + else: + log2_floor = (size_l2 // size_one_kv_head).bit_length() - 1 + swizzle = 1 << log2_floor + num_hb_quotient = (num_qo_heads * num_batch) // swizzle + num_hb_remainder = (num_qo_heads * num_batch) % swizzle + + ragged_lpt_stream = torch.cuda.current_stream() + ragged_lpt_cache_key = (autotune_key, swizzle, str(q.device)) + if ragged_lpt_cache_key not in _prefill_ragged_lpt_tune_cache: + result = exhaustive_search( + list(_get_prefill_autotune_configs(None)), + ragged_lpt_stream, + lambda cfg: ((max_seq_len + cfg.BLOCK_M - 1) // cfg.BLOCK_M * num_qo_heads * num_batch, 1, 1), + _prefill_attention_ragged_lpt_kernel, + lambda cfg: ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + cfg.BLOCK_M, + cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + q.stride(0), + q.stride(1), + q.stride(2), + outputs.stride(0), + outputs.stride(1), + outputs.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + out_lse.stride(0), + is_causal, + num_qo_heads, + num_batch, + max_seq_len, + swizzle, + num_hb_quotient, + max(num_hb_remainder, 1), + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _prefill_ragged_lpt_tune_cache[ragged_lpt_cache_key] = ( + best_cfg, + ct.kernel( + _prefill_attention_ragged_lpt_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _prefill_ragged_lpt_tune_cache[ragged_lpt_cache_key] + ct.launch( + ragged_lpt_stream, + ((max_seq_len + best_cfg.BLOCK_M - 1) // best_cfg.BLOCK_M * num_qo_heads * num_batch, 1, 1), + tuned_kernel, + ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + best_cfg.BLOCK_M, + best_cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + q.stride(0), + q.stride(1), + q.stride(2), + outputs.stride(0), + outputs.stride(1), + outputs.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + out_lse.stride(0), + is_causal, + num_qo_heads, + num_batch, + max_seq_len, + swizzle, + num_hb_quotient, + max(num_hb_remainder, 1), + ), + ) + else: + ragged_stream = torch.cuda.current_stream() + ragged_cache_key = (autotune_key, str(q.device)) + if ragged_cache_key not in _prefill_ragged_tune_cache: + result = exhaustive_search( + list(_get_prefill_autotune_configs(None)), + ragged_stream, + lambda cfg: ((max_seq_len + cfg.BLOCK_M - 1) // cfg.BLOCK_M, num_batch, num_qo_heads), + _prefill_attention_ragged_kernel, + lambda cfg: ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + cfg.BLOCK_M, + cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + q.stride(0), + q.stride(1), + q.stride(2), + outputs.stride(0), + outputs.stride(1), + outputs.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + out_lse.stride(0), + is_causal, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _prefill_ragged_tune_cache[ragged_cache_key] = ( + best_cfg, + ct.kernel( + _prefill_attention_ragged_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _prefill_ragged_tune_cache[ragged_cache_key] + ct.launch( + ragged_stream, + ((max_seq_len + best_cfg.BLOCK_M - 1) // best_cfg.BLOCK_M, num_batch, num_qo_heads), + tuned_kernel, + ( + q, + k_cache, + v_cache, + actual_seq_lens_q_flat, + actual_seq_lens_kv_flat, + batch_offsets_flat, + outputs, + out_lse, + k_scale, + v_scale, + num_kv_heads, + best_cfg.BLOCK_M, + best_cfg.BLOCK_N, + head_dim_vo, + BLOCK_R, + QUERY_GROUP_SIZE, + q.stride(0), + q.stride(1), + q.stride(2), + outputs.stride(0), + outputs.stride(1), + outputs.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + out_lse.stride(0), + is_causal, + ), + ) + + return outputs, out_lse diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/__init__.py b/src/tilegym/suites/flashinfer/cutile/gemm/__init__.py new file mode 100644 index 00000000..01a0af36 --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/gemm/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +from . import gemm_alpha_beta +from . import masked_bmm +from . import ragged_block_scaled_bmm +from . import ragged_bmm diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/gemm_alpha_beta.py b/src/tilegym/suites/flashinfer/cutile/gemm/gemm_alpha_beta.py new file mode 100644 index 00000000..25afb890 --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/gemm/gemm_alpha_beta.py @@ -0,0 +1,523 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import os +from math import ceil +from types import SimpleNamespace + +import cuda.tile as ct +import torch +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl +from tilegym.kernel_utils import get_kernel_configs + +# Module-level tune cache: (M, N, K, transpose_a_int, transpose_b_int, dtype, num_sms, device) -> (best_cfg, tuned_kernel) +_gemm_alpha_beta_tune_cache: dict = {} + + +def cdiv(a, b): + """Ceiling division helper function.""" + return (a + b - 1) // b + + +@ct.kernel +def gemm_alpha_beta_kernel_cutile( + a_ptr, # Input matrix A [M, K] or [K, M] if transpose_a + b_ptr, # Input matrix B [K, N] or [N, K] if transpose_b + c_ptr, # Output/Input matrix C [M, N] - modified in place + alpha: ct.Constant[float], # Alpha scaling factor + beta: ct.Constant[float], # Beta scaling factor + M: ct.Constant[int], # M dimension + N: ct.Constant[int], # N dimension + K: ct.Constant[int], # K dimension + total_tiles: ct.Constant[int], # Total number of tiles + num_programs: ct.Constant[int], # Number of SMs + num_pid_m: ct.Constant[int], # Number of M tiles + num_pid_n: ct.Constant[int], # Number of N tiles + transpose_a: ct.Constant[int], # Whether A is transposed (0 or 1) + transpose_b: ct.Constant[int], # Whether B is transposed (0 or 1) + BLOCK_M: ct.Constant[int], + BLOCK_N: ct.Constant[int], + BLOCK_K: ct.Constant[int], + GROUP_SIZE_M: ct.Constant[int], + EPILOGUE_SUBTILE: ct.Constant[int], +): + """ + CuTile kernel for GEMM with alpha/beta scaling: C = alpha * A @ B + beta * C + + Features: + - Standard GEMM with alpha and beta scaling factors + - Supports transpose_a and transpose_b + - Uses persistent scheduling with SM-aware grid sizing + - Uses GROUP_SIZE_M based tile swizzling + - Optimized with latency hints for better pipelining + """ + pid = ct.bid(0) + + num_k_tiles = ct.cdiv(K, BLOCK_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + zero_pad = ct.PaddingMode.ZERO + + # Persistent scheduling loop + for current_pid in range(pid, total_tiles, num_programs): + # Calculate pid_m, pid_n with GROUP_SIZE_M swizzling + group_id = current_pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m_actual = ct.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (current_pid % group_size_m_actual) + pid_n = (current_pid % num_pid_in_group) // group_size_m_actual + + # Initialize accumulator + acc = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + # K-loop for matrix multiplication using tile indices + for k in range(num_k_tiles): + # Load A block based on transpose_a flag with latency hint for pipelining + if transpose_a == 1: + # A is [K, M], load [BLOCK_K, BLOCK_M] and transpose + a_block_kt = ct.load( + a_ptr, + index=(k, pid_m), # tile indices + shape=(BLOCK_K, BLOCK_M), + order=(0, 1), + padding_mode=zero_pad, + latency=3, + ) + a_block = ct.permute(a_block_kt, (1, 0)) # [BLOCK_M, BLOCK_K] + else: + # A is [M, K], load [BLOCK_M, BLOCK_K] + a_block = ct.load( + a_ptr, + index=(pid_m, k), # tile indices + shape=(BLOCK_M, BLOCK_K), + order=(0, 1), + padding_mode=zero_pad, + latency=3, + ) + + # Load B block based on transpose_b flag with latency hint + if transpose_b == 1: + # B is [N, K], load [BLOCK_N, BLOCK_K] and transpose + b_block_nt = ct.load( + b_ptr, + index=(pid_n, k), # tile indices + shape=(BLOCK_N, BLOCK_K), + order=(0, 1), + padding_mode=zero_pad, + latency=3, + ) + b_block = ct.permute(b_block_nt, (1, 0)) # [BLOCK_K, BLOCK_N] + else: + # B is [K, N], load [BLOCK_K, BLOCK_N] + b_block = ct.load( + b_ptr, + index=(k, pid_n), # tile indices + shape=(BLOCK_K, BLOCK_N), + order=(0, 1), + padding_mode=zero_pad, + latency=3, + ) + + # Matrix multiplication: A @ B + acc = ct.mma(a_block, b_block, acc=acc) + + if EPILOGUE_SUBTILE == 1: + # Split accumulator into two N/2 halves to reduce shared memory in epilogue + acc0 = ct.extract(acc, index=(0, 0), shape=(BLOCK_M, BLOCK_N // 2)) + acc1 = ct.extract(acc, index=(0, 1), shape=(BLOCK_M, BLOCK_N // 2)) + + c_load0 = ct.load( + c_ptr, + index=(pid_m, pid_n * 2), + shape=(BLOCK_M, BLOCK_N // 2), + order=(0, 1), + padding_mode=zero_pad, + ) + c_load0_f32 = ct.astype(c_load0, ct.float32) + result0 = alpha * acc0 + beta * c_load0_f32 + c_block0 = ct.astype(result0, c_ptr.dtype) + ct.store( + c_ptr, + index=(pid_m, pid_n * 2), + tile=c_block0, + order=(0, 1), + ) + + c_load1 = ct.load( + c_ptr, + index=(pid_m, pid_n * 2 + 1), + shape=(BLOCK_M, BLOCK_N // 2), + order=(0, 1), + padding_mode=zero_pad, + ) + c_load1_f32 = ct.astype(c_load1, ct.float32) + result1 = alpha * acc1 + beta * c_load1_f32 + c_block1 = ct.astype(result1, c_ptr.dtype) + ct.store( + c_ptr, + index=(pid_m, pid_n * 2 + 1), + tile=c_block1, + order=(0, 1), + ) + else: + c_load = ct.load( + c_ptr, + index=(pid_m, pid_n), + shape=(BLOCK_M, BLOCK_N), + order=(0, 1), + padding_mode=zero_pad, + ) + + c_load_f32 = ct.astype(c_load, ct.float32) + result = alpha * acc + beta * c_load_f32 + + c_block = ct.astype(result, c_ptr.dtype) + + ct.store( + c_ptr, + index=(pid_m, pid_n), + tile=c_block, + order=(0, 1), + ) + + +def _gemm_alpha_beta_autotune_configs(): + """ + Iterator of autotune configurations for gemm_alpha_beta kernel. + Returns configurations optimized for different GPU architectures. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability[0] >= 10: + # EPILOGUE_SUBTILE=1 only validated on SM100; disable on SM12x to avoid correctness issues + subtile_options = [0, 1] if gpu_capability == (10, 0) else [0] + for BM, BN, nc in [ + (64, 64, 1), + (64, 128, 1), + (128, 64, 1), + (128, 128, 1), + (256, 64, 1), + (256, 128, 1), + (256, 128, 2), + (256, 256, 2), + ]: + for BK in [64]: + for occupancy in [1, 2, 4, 8]: + for subtile in subtile_options: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=nc, + occupancy=occupancy, + EPILOGUE_SUBTILE=subtile, + ) + elif gpu_capability == (9, 0): + for BM, BN in [ + (128, 128), + (128, 256), + (64, 128), + (128, 64), + (256, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + EPILOGUE_SUBTILE=0, + ) + else: + for BM, BN in [ + (64, 64), + (128, 64), + (128, 128), + (128, 256), + (256, 128), + ]: + for BK in [64]: + for occupancy in [1, 2]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + EPILOGUE_SUBTILE=0, + ) + + +def _get_default_kernel_configs(): + """ + Get GPU-specific default kernel configs for non-autotune path. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability == (10, 0): + # Blackwell SM100 – aggressive config with epilogue subtiling + return { + "BLOCK_M": 256, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 2, + "occupancy": 2, + "EPILOGUE_SUBTILE": 1, + } + elif gpu_capability[0] >= 10: + # SM10x / SM12x – conservative default (autotuner finds best) + return { + "BLOCK_M": 128, + "BLOCK_N": 64, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 1, + "occupancy": 2, + "EPILOGUE_SUBTILE": 0, + } + elif gpu_capability == (9, 0): + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 1, + "occupancy": 1, + "EPILOGUE_SUBTILE": 0, + } + else: + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 1, + "occupancy": 1, + "EPILOGUE_SUBTILE": 0, + } + + +def _compute_grid_and_programs(M, N, BLOCK_M, BLOCK_N, num_sms, num_ctas, occupancy): + """Helper to compute grid size and number of programs.""" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + if num_sms is not None: + NUM_SMS = min(NUM_SMS, num_sms) + + num_pid_m = cdiv(M, BLOCK_M) + num_pid_n = cdiv(N, BLOCK_N) + total_tiles = num_pid_m * num_pid_n + # Ensure num_programs >= 1: cuTile requires positive step for range() in persistent kernel. + # When num_sms is very small (e.g., 1) and num_ctas > 1, NUM_SMS // num_ctas can be 0. + num_programs = max(1, min(NUM_SMS // num_ctas, total_tiles) * occupancy) + + return num_pid_m, num_pid_n, total_tiles, num_programs + + +@register_impl("flashinfer.gemm.gemm_alpha_beta", backend="cutile") +def gemm_alpha_beta( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + trans_a=False, + trans_b=True, + alpha=1.0, + beta=0.0, + num_sms=None, + **kwargs, +): + """ + CuTile implementation of GEMM with alpha/beta scaling. + + Computes: C = alpha * A @ B + beta * C + + Args: + a: Input matrix A [M, K] or [K, M] if trans_a + b: Input matrix B [K, N] or [N, K] if trans_b + c: Input/Output matrix C [M, N] - modified in place + trans_a: Whether A is transposed + trans_b: Whether B is transposed + alpha: Scaling factor for A @ B + beta: Scaling factor for existing C + num_sms: Number of SMs to use (for SM throttling) + + Returns: + Output tensor C [M, N] + """ + # Get dimensions + if trans_a: + K, M = a.shape + else: + M, K = a.shape + + if trans_b: + N, KB = b.shape + else: + KB, N = b.shape + + assert K == KB, "incompatible dimensions" + assert c.shape == (M, N), "C must have shape [M, N]" + assert a.is_contiguous(), "A matrix must be contiguous" + assert b.is_contiguous(), "B matrix must be contiguous" + assert c.is_contiguous(), "C matrix must be contiguous" + + # Convert boolean to int for ct.Constant + transpose_a_int = 1 if trans_a else 0 + transpose_b_int = 1 if trans_b else 0 + + # Check if autotune is requested + use_autotune = kwargs.get("use_autotune", True) + + # For very low SM counts (1-16), disable autotune and use fixed config + # to avoid autotune overhead dominating runtime + if num_sms is not None and num_sms <= 16: + use_autotune = False + + if use_autotune: + # Use exhaustive_search for automatic configuration selection + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + def grid_fn(cfg): + num_pid_m, num_pid_n, total_tiles, num_programs = _compute_grid_and_programs( + M, N, cfg.BLOCK_M, cfg.BLOCK_N, num_sms, cfg.num_ctas, cfg.occupancy + ) + return (num_programs, 1, 1) + + def args_fn(cfg): + num_pid_m, num_pid_n, total_tiles, num_programs = _compute_grid_and_programs( + M, N, cfg.BLOCK_M, cfg.BLOCK_N, num_sms, cfg.num_ctas, cfg.occupancy + ) + return ( + a, + b, + c.clone(), # Clone for tuning to avoid corrupting C + float(alpha), + float(beta), + M, + N, + K, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + transpose_a_int, + transpose_b_int, + cfg.BLOCK_M, + cfg.BLOCK_N, + cfg.BLOCK_K, + cfg.GROUP_SIZE_M, + cfg.EPILOGUE_SUBTILE, + ) + + def launch_args_fn(cfg): + num_pid_m, num_pid_n, total_tiles, num_programs = _compute_grid_and_programs( + M, N, cfg.BLOCK_M, cfg.BLOCK_N, num_sms, cfg.num_ctas, cfg.occupancy + ) + return ( + a, + b, + c, # Use actual C for final launch + float(alpha), + float(beta), + M, + N, + K, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + transpose_a_int, + transpose_b_int, + cfg.BLOCK_M, + cfg.BLOCK_N, + cfg.BLOCK_K, + cfg.GROUP_SIZE_M, + cfg.EPILOGUE_SUBTILE, + ) + + stream = torch.cuda.current_stream() + cache_key = (M, N, K, transpose_a_int, transpose_b_int, a.dtype, num_sms, str(a.device)) + if cache_key not in _gemm_alpha_beta_tune_cache: + result = exhaustive_search( + list(_gemm_alpha_beta_autotune_configs()), + stream, + grid_fn, + gemm_alpha_beta_kernel_cutile, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _gemm_alpha_beta_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + gemm_alpha_beta_kernel_cutile._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _gemm_alpha_beta_tune_cache[cache_key] + ct.launch(stream, grid_fn(best_cfg), tuned_kernel, launch_args_fn(best_cfg)) + + return c + + else: + # Fallback to non-autotune path + default_configs = _get_default_kernel_configs() + kernel_configs = get_kernel_configs(default_configs, kwargs.get("kernel_configs")) + + BLOCK_M = kernel_configs.get("BLOCK_M") + BLOCK_N = kernel_configs.get("BLOCK_N") + BLOCK_K = kernel_configs.get("BLOCK_K") + GROUP_SIZE_M = kernel_configs.get("GROUP_SIZE_M", 8) + num_ctas = kernel_configs.get("num_ctas", 1) + occupancy = kernel_configs.get("occupancy", 1) + epilogue_subtile = kernel_configs.get("EPILOGUE_SUBTILE", 0) + + num_pid_m, num_pid_n, total_tiles, num_programs = _compute_grid_and_programs( + M, N, BLOCK_M, BLOCK_N, num_sms, num_ctas, occupancy + ) + + # 1D grid for persistent scheduling + grid = (num_programs, 1, 1) + + # Build kernel with hints + + kernel = gemm_alpha_beta_kernel_cutile + + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + a, + b, + c, + float(alpha), + float(beta), + M, + N, + K, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + transpose_a_int, + transpose_b_int, + BLOCK_M, + BLOCK_N, + BLOCK_K, + GROUP_SIZE_M, + epilogue_subtile, + ), + ) + + return c diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/masked_bmm.py b/src/tilegym/suites/flashinfer/cutile/gemm/masked_bmm.py new file mode 100644 index 00000000..3648bc9b --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/gemm/masked_bmm.py @@ -0,0 +1,444 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import os +from math import ceil +from types import SimpleNamespace + +import cuda.tile as ct +import torch +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl +from tilegym.kernel_utils import get_kernel_configs + +# Module-level tune cache: (Q, M, N, K, transpose_a_int, transpose_b_int, dtype, device) -> (best_cfg, tuned_kernel) +_masked_bmm_tune_cache: dict = {} + + +def cdiv(a, b): + """Ceiling division helper function.""" + return (a + b - 1) // b + + +@ct.kernel +def masked_bmm_kernel_cutile( + a_ptr, # Input matrix A [Q, M, K] or [Q, K, M] if transpose_a + b_ptr, # Input matrix B [Q, K, N] or [Q, N, K] if transpose_b + c_ptr, # Output matrix C [Q, M, N] + masked_m, # Per-batch M mask [Q], int32 + total_tiles: ct.Constant[int], # Total number of tiles + num_programs: ct.Constant[int], # Number of SMs + num_pid_m: ct.Constant[int], # Number of M tiles per batch + num_pid_n: ct.Constant[int], # Number of N tiles per batch + tiles_per_batch: ct.Constant[int], # num_pid_m * num_pid_n + transpose_a: ct.Constant[int], # Whether A is transposed (0 or 1) + transpose_b: ct.Constant[int], # Whether B is transposed (0 or 1) + BLOCK_M: ct.Constant[int], + BLOCK_N: ct.Constant[int], + BLOCK_K: ct.Constant[int], + GROUP_SIZE_M: ct.Constant[int], +): + """ + CuTile kernel for masked batched matrix multiplication. + + Performs A @ B with per-batch M masking where: + - A is batched [Q, M, K] or [Q, K, M] if transpose_a + - B is batched [Q, K, N] or [Q, N, K] if transpose_b + - masked_m is per-batch M mask [Q] + - Output C is [Q, M, N] + + Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling. + """ + pid = ct.bid(0) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + zero_pad = ct.PaddingMode.ZERO + + # Compute num_k_tiles from tensor shape using ct.num_tiles + # For non-transposed A: shape is [Q, M, K], we tile K (axis=2) + # For transposed A: shape is [Q, K, M], we tile K (axis=1) + if transpose_a == 1: + num_k_tiles = ct.num_tiles(a_ptr, axis=1, shape=(1, BLOCK_K, BLOCK_M)) + else: + num_k_tiles = ct.num_tiles(a_ptr, axis=2, shape=(1, BLOCK_M, BLOCK_K)) + + # Persistent scheduling loop + for current_pid in range(pid, total_tiles, num_programs): + # Calculate pid_q, pid_m, pid_n with GROUP_SIZE_M swizzling + pid_q = current_pid // tiles_per_batch + pid_in_batch = current_pid % tiles_per_batch + + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m_actual = ct.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid_in_batch % group_size_m_actual) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m_actual + + # Load valid_m for this batch + valid_m_tile = ct.load(masked_m, index=(pid_q,), shape=(1,)) + valid_m = valid_m_tile.item() + + # Only process if this tile is within valid M range + if pid_m * BLOCK_M < valid_m: + # Initialize accumulator + acc = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + # K-loop for matrix multiplication using tile indices + for k in range(num_k_tiles): + # Load A block based on transpose_a flag + # Using tile-index based loading (k is tile index, not element offset) + if transpose_a == 1: + # A is [Q, K, M], load [1, BLOCK_K, BLOCK_M] using tile indices + a_block_3d = ct.load( + a_ptr, + index=(pid_q, k, pid_m), # tile indices + shape=(1, BLOCK_K, BLOCK_M), + order=(0, 1, 2), + padding_mode=zero_pad, + ) + a_block_km = ct.reshape(a_block_3d, (BLOCK_K, BLOCK_M)) + a_block = ct.permute(a_block_km, (1, 0)) # [BLOCK_M, BLOCK_K] + else: + # A is [Q, M, K], load [1, BLOCK_M, BLOCK_K] using tile indices + a_block_3d = ct.load( + a_ptr, + index=(pid_q, pid_m, k), # tile indices + shape=(1, BLOCK_M, BLOCK_K), + order=(0, 1, 2), + padding_mode=zero_pad, + ) + a_block = ct.reshape(a_block_3d, (BLOCK_M, BLOCK_K)) + + # Load B block based on transpose_b flag + if transpose_b == 1: + # B is [Q, N, K], load [1, BLOCK_N, BLOCK_K] using tile indices + b_block_3d = ct.load( + b_ptr, + index=(pid_q, pid_n, k), # tile indices + shape=(1, BLOCK_N, BLOCK_K), + order=(0, 1, 2), + padding_mode=zero_pad, + ) + b_block_nk = ct.reshape(b_block_3d, (BLOCK_N, BLOCK_K)) + b_block = ct.permute(b_block_nk, (1, 0)) # [BLOCK_K, BLOCK_N] + else: + # B is [Q, K, N], load [1, BLOCK_K, BLOCK_N] using tile indices + b_block_3d = ct.load( + b_ptr, + index=(pid_q, k, pid_n), # tile indices + shape=(1, BLOCK_K, BLOCK_N), + order=(0, 1, 2), + padding_mode=zero_pad, + ) + b_block = ct.reshape(b_block_3d, (BLOCK_K, BLOCK_N)) + + # Matrix multiplication: A @ B + acc = ct.mma(a_block, b_block, acc=acc) + + # Convert to output dtype and store + c_block = ct.astype(acc, c_ptr.dtype) + + # Reshape to 3D for store [1, BLOCK_M, BLOCK_N] + c_block_3d = ct.reshape(c_block, (1, BLOCK_M, BLOCK_N)) + + # Store to output C [Q, M, N] using tile indices + ct.store( + c_ptr, + index=(pid_q, pid_m, pid_n), # tile indices + tile=c_block_3d, + order=(0, 1, 2), + ) + + +def _masked_bmm_autotune_configs(): + """ + Iterator of autotune configurations for masked BMM kernel. + + IMPORTANT: Focus tuning on num_ctas and occupancy as requested. + - num_ctas: Number of CTAs in a CGA (valid: 1, 2, 4, 8, 16) + - occupancy: Expected active CTAs per SM (range: 1-32) + + For GEMM-like kernels: + - Higher num_ctas can improve L2 cache hit rate via CGA + - Occupancy affects latency hiding vs register pressure tradeoff + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability[0] == 10: + # B200 / GB200 (sm100 / sm103) + for BM, BN in [ + (128, 128), + (128, 256), + (256, 128), + (256, 256), + ]: + for BK in [64]: + # Focus on num_ctas tuning: 1, 2, 4 are most common for GEMM + for num_ctas in [1, 2, 4]: + # Focus on occupancy tuning: 1-4 for compute-bound GEMM + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=num_ctas, + occupancy=occupancy, + ) + elif gpu_capability in [(12, 0), (12, 1)]: + # RTX 5090 (sm120/sm121) + for BM, BN in [ + (128, 128), + (128, 256), + (256, 128), + (256, 256), + ]: + for BK in [64]: + for num_ctas in [1, 2, 4]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=num_ctas, + occupancy=occupancy, + ) + else: + # Default configurations + for BM, BN in [ + (128, 128), + (128, 256), + ]: + for BK in [64]: + for num_ctas in [1, 2]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=num_ctas, + occupancy=occupancy, + ) + + +def _get_default_kernel_configs(): + """ + Get GPU-specific default kernel configs for non-autotune path. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability[0] == 10: + # B200 / GB200 (sm100 / sm103) + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 1, + "occupancy": 2, + } + elif gpu_capability in [(12, 0), (12, 1)]: + # RTX 5090 (sm120/sm121) + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 1, + "occupancy": 2, + } + else: + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "num_ctas": 1, + "occupancy": 1, + } + + +def _masked_bmm_autotune_launch(stream, a, b, c, masked_m, Q, M, N, transpose_a, transpose_b): + NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count + + transpose_a_int = 1 if transpose_a else 0 + transpose_b_int = 1 if transpose_b else 0 + + def args_fn(cfg): + BM = cfg.BLOCK_M + BN = cfg.BLOCK_N + BK = cfg.BLOCK_K + GSM = cfg.GROUP_SIZE_M + + num_pid_m = cdiv(M, BM) + num_pid_n = cdiv(N, BN) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // cfg.num_ctas, total_tiles) * cfg.occupancy + + return ( + a, + b, + c, + masked_m, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + tiles_per_batch, + transpose_a_int, + transpose_b_int, + BM, + BN, + BK, + GSM, + ) + + def grid_fn(cfg): + BM = cfg.BLOCK_M + BN = cfg.BLOCK_N + num_pid_m = cdiv(M, BM) + num_pid_n = cdiv(N, BN) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // cfg.num_ctas, total_tiles) * cfg.occupancy + return (num_programs, 1, 1) + + K = a.shape[1] if transpose_a else a.shape[2] + cache_key = (Q, M, N, K, transpose_a_int, transpose_b_int, a.dtype, str(a.device)) + if cache_key not in _masked_bmm_tune_cache: + result = exhaustive_search( + list(_masked_bmm_autotune_configs()), + stream, + grid_fn, + masked_bmm_kernel_cutile, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _masked_bmm_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + masked_bmm_kernel_cutile._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _masked_bmm_tune_cache[cache_key] + ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg)) + + +@register_impl("flashinfer.gemm.masked_bmm", backend="cutile") +def masked_bmm( + a, + b, + masked_m, + transpose_a=False, + transpose_b=False, + static_persistent=None, + **kwargs, +): + """ + CuTile implementation of masked batched matrix multiplication. + + Performs A @ B with per-batch M masking where: + - A is batched [Q, M, K] or [Q, K, M] if transpose_a + - B is batched [Q, K, N] or [Q, N, K] if transpose_b + - masked_m is per-batch M mask [Q] + + Args: + a: Input matrix A, batched [Q, M, K] or [Q, K, M] if transpose_a + b: Input matrix B, batched [Q, K, N] or [Q, N, K] if transpose_b + masked_m: Per-batch M mask tensor [Q] + transpose_a: Whether A is transposed + transpose_b: Whether B is transposed + static_persistent: Whether to use static persistent (unused, for API compat) + + Returns: + Output tensor C [Q, M, N] + """ + # Get dimensions from input tensors + if transpose_a: + Q_A, K_A, M = a.shape + else: + Q_A, M, K_A = a.shape + + if transpose_b: + Q_B, N, K_B = b.shape + else: + Q_B, K_B, N = b.shape + + assert K_A == K_B, "incompatible dimensions" + assert Q_A == Q_B, "incompatible dimensions" + K = K_A + Q = Q_A + + assert a.is_contiguous(), "A matrix must be contiguous" + assert b.is_contiguous(), "B matrix must be contiguous" + assert masked_m.is_contiguous(), "Masked matrix must be contiguous" + assert masked_m.shape.numel() == Q, "Masked matrix must have the same shape as the number of batches" + + # Allocate output + c = torch.empty((Q, M, N), device=a.device, dtype=a.dtype) + + enable_autotune = os.environ.get("DISABLE_CUTILE_TUNE", "0") != "1" + + if enable_autotune: + _masked_bmm_autotune_launch(torch.cuda.current_stream(), a, b, c, masked_m, Q, M, N, transpose_a, transpose_b) + else: + default_configs = _get_default_kernel_configs() + kernel_configs = get_kernel_configs(default_configs, kwargs.get("kernel_configs")) + + BLOCK_M = kernel_configs.get("BLOCK_M") + BLOCK_N = kernel_configs.get("BLOCK_N") + BLOCK_K = kernel_configs.get("BLOCK_K") + GROUP_SIZE_M = kernel_configs.get("GROUP_SIZE_M", 8) + num_ctas = kernel_configs.get("num_ctas", 1) + occupancy = kernel_configs.get("occupancy", 1) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + num_pid_m = cdiv(M, BLOCK_M) + num_pid_n = cdiv(N, BLOCK_N) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // num_ctas, total_tiles) * occupancy + + grid = (num_programs, 1, 1) + + transpose_a_int = 1 if transpose_a else 0 + transpose_b_int = 1 if transpose_b else 0 + + # Build kernel with hints + + kernel = masked_bmm_kernel_cutile + + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + a, + b, + c, + masked_m, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + tiles_per_batch, + transpose_a_int, + transpose_b_int, + BLOCK_M, + BLOCK_N, + BLOCK_K, + GROUP_SIZE_M, + ), + ) + + return c diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py new file mode 100644 index 00000000..f31bca86 --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_block_scaled_bmm.py @@ -0,0 +1,598 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import os +from math import ceil +from types import SimpleNamespace + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl +from tilegym.kernel_utils import get_kernel_configs + + +def cdiv(a, b): + """Ceiling division helper function.""" + return (a + b - 1) // b + + +def _is_large_m(total_m, Q): + """Determine if average M is large enough for non-swapped configs.""" + average_m = total_m / Q + is_large_m = average_m >= 256 + return is_large_m + + +@ct.kernel +def ragged_block_scaled_bmm_kernel_cutile( + a_ptr, # Input matrix A [total_m, K] FP8 + b_ptr, # Input matrix B [Q, N, K] FP8 + a_scale_ptr, # Scale for A [total_m, k_tiles] FP32 + b_scale_ptr, # Scale for B [Q, n_tiles, k_tiles] FP32 + c_ptr, # Output matrix C [total_m, N] + m_indptr, # Segment offsets [Q+1], flattened 1D + Q: ct.Constant[int], # Number of batches + max_m: ct.Constant[int], # Max segment size + N: ct.Constant[int], # Output N dimension + K: ct.Constant[int], # K dimension + total_m: ct.Constant[int], # Total M (for bounds checking) + total_tiles: ct.Constant[int], # Total number of tiles + num_programs: ct.Constant[int], # Number of SMs + num_k_tiles: ct.Constant[int], # Number of K tiles + num_pid_m: ct.Constant[int], # Number of M tiles per batch + num_pid_n: ct.Constant[int], # Number of N tiles per batch + tiles_per_batch: ct.Constant[int], # num_pid_m * num_pid_n + stride_a0: ct.Constant[int], # Stride for A dim 0 + stride_a1: ct.Constant[int], # Stride for A dim 1 + stride_b0: ct.Constant[int], # Stride for B dim 0 + stride_b1: ct.Constant[int], # Stride for B dim 1 + stride_b2: ct.Constant[int], # Stride for B dim 2 + stride_sa0: ct.Constant[int], # Stride for a_scale dim 0 + stride_sa1: ct.Constant[int], # Stride for a_scale dim 1 + stride_sb0: ct.Constant[int], # Stride for b_scale dim 0 + stride_sb1: ct.Constant[int], # Stride for b_scale dim 1 + stride_sb2: ct.Constant[int], # Stride for b_scale dim 2 + stride_c0: ct.Constant[int], # Stride for C dim 0 + stride_c1: ct.Constant[int], # Stride for C dim 1 + has_a_scale: ct.Constant[int], # Whether a_scale is provided (0 or 1) + BLOCK_M: ct.Constant[int], + BLOCK_N: ct.Constant[int], + BLOCK_K: ct.Constant[int], + GROUP_SIZE_M: ct.Constant[int], +): + """ + CuTile kernel for ragged block-scaled batched matrix multiplication. + + Performs (A * a_scale) @ (B * b_scale)^T where: + - A is flattened FP8 with segment offsets (m_indptr defines boundaries) + - B is batched FP8 [Q, N, K] + - a_scale and b_scale are per-block scales + - Output C is [total_m, N] + + Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling. + Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + """ + pid = ct.bid(0) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + # Persistent scheduling loop + for current_pid in range(pid, total_tiles, num_programs): + # Calculate pid_q, pid_m, pid_n with GROUP_SIZE_M swizzling + # pid_q = batch index + pid_q = current_pid // tiles_per_batch + pid_in_batch = current_pid % tiles_per_batch + + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m_actual = ct.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid_in_batch % group_size_m_actual) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m_actual + + # Load segment boundaries using ct.load with dynamic index + m_start_tile = ct.load(m_indptr, index=(pid_q,), shape=(1,)) + m_start = m_start_tile.item() + m_end_tile = ct.load(m_indptr, index=(pid_q + 1,), shape=(1,)) + m_end = m_end_tile.item() + valid_m = m_end - m_start + + # Only process if this tile is within valid M range + if pid_m * BLOCK_M < valid_m: + # Create sliced views for A and C using Array.slice + Ai = a_ptr.slice(axis=0, start=m_start, stop=m_end) + Ci = c_ptr.slice(axis=0, start=m_start, stop=m_end) + + if has_a_scale == 1: + a_scale_i = a_scale_ptr.slice(axis=0, start=m_start, stop=m_end) + + # Initialize accumulator + acc = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + # N tile offset (element-level) for b_scale calculation + n_offset = pid_n * BLOCK_N + offs_bsn = n_offset // BLOCK_K + + # Zero accumulator for per-K MMA (reused each iteration) + mma_zeros = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + # K-loop for matrix multiplication + for k in range(num_k_tiles): + k_offset = k * BLOCK_K + + # Load A block using TMA + a_block = ct.load( + Ai, + index=(pid_m, k), + shape=(BLOCK_M, BLOCK_K), + padding_mode=ct.PaddingMode.ZERO, + ) + + # Load B block - B is [Q, N, K], we need [BLOCK_N, BLOCK_K] + b_block_3d = ct.load( + b_ptr, + index=(pid_q, n_offset // BLOCK_N, k_offset // BLOCK_K), + shape=(1, BLOCK_N, BLOCK_K), + order=(0, 1, 2), + padding_mode=ct.PaddingMode.ZERO, + ) + # Reshape to [BLOCK_N, BLOCK_K] then transpose to get [BLOCK_K, BLOCK_N] + b_block_nk = ct.reshape(b_block_3d, (BLOCK_N, BLOCK_K)) + b_block = ct.permute(b_block_nk, (1, 0)) # [BLOCK_K, BLOCK_N] + + # Matrix multiplication: A [BLOCK_M, BLOCK_K] @ B [BLOCK_K, BLOCK_N] = C [BLOCK_M, BLOCK_N] + c_mma = ct.mma(a_block, b_block, acc=mma_zeros) + + # Load and apply scales + if has_a_scale == 1: + # Load a_scale for this block using TMA + a_scale_block = ct.load( + a_scale_i, + index=(pid_m, k), + shape=(BLOCK_M, 1), + padding_mode=ct.PaddingMode.ZERO, + ) + + # Load b_scale - scalar at [pid_q, offs_bsn, k] + b_scale_block = ct.load( + b_scale_ptr, + index=(pid_q, offs_bsn, k), + shape=(1, 1, 1), + order=(0, 1, 2), + padding_mode=ct.PaddingMode.ZERO, + ) + b_scale_val = ct.reshape(b_scale_block, (1, 1)) + + # Combined scale: a_scale [BLOCK_M, 1] * b_scale [1, 1] = [BLOCK_M, 1] + scale_combined = a_scale_block * ct.broadcast_to(b_scale_val, (BLOCK_M, 1)) + scale_ab = ct.broadcast_to(scale_combined, (BLOCK_M, BLOCK_N)) + else: + # Only b_scale + b_scale_block = ct.load( + b_scale_ptr, + index=(pid_q, offs_bsn, k), + shape=(1, 1, 1), + order=(0, 1, 2), + padding_mode=ct.PaddingMode.ZERO, + ) + b_scale_val = ct.reshape(b_scale_block, (1, 1)) + scale_ab = ct.broadcast_to(b_scale_val, (BLOCK_M, BLOCK_N)) + + # Apply scale and accumulate + acc = acc + c_mma * scale_ab + + # Convert to output dtype + c_block = ct.astype(acc, c_ptr.dtype) + + # Store to output C using TMA + ct.store(Ci, index=(pid_m, pid_n), tile=c_block) + + +@ct.kernel +def ragged_block_scaled_bmm_kernel_cutile_swap_ab( + a_ptr, # Input matrix A [total_m, K] FP8 + b_ptr, # Input matrix B [Q, N, K] FP8 + a_scale_ptr, # Scale for A [total_m, k_tiles] FP32 + b_scale_ptr, # Scale for B [Q, n_tiles, k_tiles] FP32 + c_ptr, # Output matrix C [total_m, N] + m_indptr, # Segment offsets [Q+1], flattened 1D + Q: ct.Constant[int], + max_m: ct.Constant[int], + N: ct.Constant[int], + K: ct.Constant[int], + total_m: ct.Constant[int], + total_tiles: ct.Constant[int], + num_programs: ct.Constant[int], + num_k_tiles: ct.Constant[int], + num_pid_m: ct.Constant[int], + num_pid_n: ct.Constant[int], + tiles_per_batch: ct.Constant[int], + stride_a0: ct.Constant[int], + stride_a1: ct.Constant[int], + stride_b0: ct.Constant[int], + stride_b1: ct.Constant[int], + stride_b2: ct.Constant[int], + stride_sa0: ct.Constant[int], + stride_sa1: ct.Constant[int], + stride_sb0: ct.Constant[int], + stride_sb1: ct.Constant[int], + stride_sb2: ct.Constant[int], + stride_c0: ct.Constant[int], + stride_c1: ct.Constant[int], + has_a_scale: ct.Constant[int], + BLOCK_M: ct.Constant[int], + BLOCK_N: ct.Constant[int], + BLOCK_K: ct.Constant[int], + GROUP_SIZE_M: ct.Constant[int], +): + """ + CuTile kernel for ragged block-scaled BMM with swap_ab optimization. + Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + """ + pid = ct.bid(0) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + # Persistent scheduling loop + for current_pid in range(pid, total_tiles, num_programs): + pid_q = current_pid // tiles_per_batch + pid_in_batch = current_pid % tiles_per_batch + + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m_actual = ct.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid_in_batch % group_size_m_actual) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m_actual + + m_start_tile = ct.load(m_indptr, index=(pid_q,), shape=(1,)) + m_start = m_start_tile.item() + m_end_tile = ct.load(m_indptr, index=(pid_q + 1,), shape=(1,)) + m_end = m_end_tile.item() + valid_m = m_end - m_start + + if pid_m * BLOCK_M < valid_m: + # Create sliced views for A and C using Array.slice + Ai = a_ptr.slice(axis=0, start=m_start, stop=m_end) + Ci = c_ptr.slice(axis=0, start=m_start, stop=m_end) + + if has_a_scale == 1: + a_scale_i = a_scale_ptr.slice(axis=0, start=m_start, stop=m_end) + + acc = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + n_offset = pid_n * BLOCK_N + offs_bsn = n_offset // BLOCK_K + + # Zero accumulator for per-K MMA (reused each iteration) + mma_zeros = ct.full((BLOCK_N, BLOCK_M), 0.0, dtype=ct.float32) + + for k in range(num_k_tiles): + k_offset = k * BLOCK_K + + # Load A block using TMA + a_block = ct.load( + Ai, + index=(pid_m, k), + shape=(BLOCK_M, BLOCK_K), + padding_mode=ct.PaddingMode.ZERO, + ) + + # Load B block + b_block_3d = ct.load( + b_ptr, + index=(pid_q, n_offset // BLOCK_N, k_offset // BLOCK_K), + shape=(1, BLOCK_N, BLOCK_K), + order=(0, 1, 2), + padding_mode=ct.PaddingMode.ZERO, + ) + b_block_nk = ct.reshape(b_block_3d, (BLOCK_N, BLOCK_K)) + + # swap_ab: compute (B @ A^T)^T + a_block_t = ct.permute(a_block, (1, 0)) + c_swapped = ct.mma(b_block_nk, a_block_t, acc=mma_zeros) + c_mma = ct.permute(c_swapped, (1, 0)) + + # Load and apply scales + if has_a_scale == 1: + a_scale_block = ct.load( + a_scale_i, + index=(pid_m, k), + shape=(BLOCK_M, 1), + padding_mode=ct.PaddingMode.ZERO, + ) + + b_scale_block = ct.load( + b_scale_ptr, + index=(pid_q, offs_bsn, k), + shape=(1, 1, 1), + order=(0, 1, 2), + padding_mode=ct.PaddingMode.ZERO, + ) + b_scale_val = ct.reshape(b_scale_block, (1, 1)) + scale_combined = a_scale_block * ct.broadcast_to(b_scale_val, (BLOCK_M, 1)) + scale_ab = ct.broadcast_to(scale_combined, (BLOCK_M, BLOCK_N)) + else: + b_scale_block = ct.load( + b_scale_ptr, + index=(pid_q, offs_bsn, k), + shape=(1, 1, 1), + order=(0, 1, 2), + padding_mode=ct.PaddingMode.ZERO, + ) + b_scale_val = ct.reshape(b_scale_block, (1, 1)) + scale_ab = ct.broadcast_to(b_scale_val, (BLOCK_M, BLOCK_N)) + + acc = acc + c_mma * scale_ab + + c_block = ct.astype(acc, c_ptr.dtype) + + # Store to output C using TMA + ct.store(Ci, index=(pid_m, pid_n), tile=c_block) + + +def _ragged_block_scaled_bmm_autotune_configs(): + """ + Iterator of autotune configurations for ragged_block_scaled_bmm kernel. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability in [(12, 0), (12, 1)]: + for BM, BN, swap_ab in [ + (128, 128, False), + (64, 128, True), + (32, 128, True), + ]: + for BK in [128]: + for occupancy in [1, 2]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + swap_ab=swap_ab, + num_ctas=1, + occupancy=occupancy, + ) + elif gpu_capability == (9, 0): + for BM, BN, swap_ab in [ + (256, 128, False), + (128, 128, False), + (64, 128, True), + (32, 128, True), + (16, 256, True), + (32, 256, True), + (64, 256, True), + ]: + for BK in [128]: + for occupancy in [1, 2]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8 if not swap_ab else 4, + swap_ab=swap_ab, + num_ctas=2 if BM == 256 else 1, + occupancy=occupancy, + ) + else: + # Non-swapped configs (for large M) + for BM, nc, occ in [ + (256, 2, 1), + (128, 1, 1), + (128, 2, 2), # for small M + ]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=128, + BLOCK_K=128, + GROUP_SIZE_M=8, + swap_ab=False, + num_ctas=nc, + occupancy=occ, + ) + # Swapped configs (for small M) + for GM in [2, 4]: + for BM in [16, 32, 64]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=256, + BLOCK_K=128, + GROUP_SIZE_M=GM, + swap_ab=True, + num_ctas=1, + occupancy=1, + ) + + +def _get_default_kernel_configs(total_m, Q, VEC_SIZE): + """ + Get GPU-specific default kernel configs for non-autotune path. + """ + gpu_capability = torch.cuda.get_device_capability() + is_large_m = _is_large_m(total_m, Q) + + if gpu_capability in [(12, 0), (12, 1)]: + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": VEC_SIZE, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 1, + "occupancy": 2, + } + elif gpu_capability == (9, 0): + if is_large_m: + return { + "BLOCK_M": 256, + "BLOCK_N": 128, + "BLOCK_K": VEC_SIZE, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 2, + "occupancy": 1, + } + else: + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": VEC_SIZE, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 1, + "occupancy": 1, + } + else: + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": VEC_SIZE, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 1, + "occupancy": 1, + } + + +@register_impl("flashinfer.gemm.ragged_block_scaled_bmm", backend="cutile") +def ragged_block_scaled_bmm( + a, + b, + a_scale, + b_scale, + m_indptr, + max_m, + max_m_device=None, + transpose_a=False, + transpose_b=True, + out_dtype=None, + **kwargs, +): + """ + CuTile implementation of ragged block-scaled BMM. + """ + # Validate inputs + assert transpose_a == False and transpose_b == True, "Only NT layout is supported" + assert a.is_contiguous(), "A matrix must be contiguous" + assert b.is_contiguous(), "B matrix must be contiguous" + assert a_scale is None or a_scale.is_contiguous(), "A scale matrix must be contiguous" + assert b_scale.is_contiguous(), "B scale matrix must be contiguous" + assert m_indptr.is_contiguous(), "m_indptr must be contiguous" + + # Get dimensions + total_m, K_A = a.shape + Q, N, K_B = b.shape + + assert K_A == K_B, f"K dimensions must match: {K_A} != {K_B}" + assert m_indptr.shape[0] == Q + 1, "m_indptr must have Q+1 elements" + + # Validate scale dimensions + Q_SB, rnb, rkb = b_scale.shape + VEC_SIZE = K_B // rkb + + if a_scale is not None: + total_ma, rka = a_scale.shape + assert total_ma == total_m, "a_scale total_m dimension mismatch" + + assert Q_SB == Q, "b_scale Q dimension mismatch" + + # Determine output dtype + if out_dtype is None: + out_dtype = torch.bfloat16 + + # Allocate output + c = torch.empty((total_m, N), device=a.device, dtype=out_dtype) + + # Get kernel configs + default_configs = _get_default_kernel_configs(total_m, Q, VEC_SIZE) + kernel_configs = get_kernel_configs(default_configs, kwargs.get("kernel_configs")) + + BLOCK_M = kernel_configs.get("BLOCK_M") + BLOCK_N = kernel_configs.get("BLOCK_N") + BLOCK_K = kernel_configs.get("BLOCK_K", VEC_SIZE) + GROUP_SIZE_M = kernel_configs.get("GROUP_SIZE_M", 8) + swap_ab = kernel_configs.get("swap_ab", False) + num_ctas = kernel_configs.get("num_ctas", 1) + occupancy = kernel_configs.get("occupancy", 1) + + # Calculate grid size for persistent scheduling + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + num_pid_m = cdiv(max_m, BLOCK_M) + num_pid_n = cdiv(N, BLOCK_N) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // num_ctas, total_tiles) * occupancy + num_k_tiles = cdiv(K_A, BLOCK_K) + + grid = (num_programs, 1, 1) + + # Prepare strides + stride_a0 = a.stride(0) + stride_a1 = a.stride(1) + stride_b0 = b.stride(0) + stride_b1 = b.stride(1) + stride_b2 = b.stride(2) + stride_sa0 = a_scale.stride(0) if a_scale is not None else 0 + stride_sa1 = a_scale.stride(1) if a_scale is not None else 0 + stride_sb0 = b_scale.stride(0) + stride_sb1 = b_scale.stride(1) + stride_sb2 = b_scale.stride(2) + stride_c0 = c.stride(0) + stride_c1 = c.stride(1) + has_a_scale = 1 if a_scale is not None else 0 + + if a_scale is None: + a_scale_ptr = torch.empty(1, device=a.device, dtype=torch.float32) + else: + a_scale_ptr = a_scale + + kernel_fn = ragged_block_scaled_bmm_kernel_cutile_swap_ab if swap_ab else ragged_block_scaled_bmm_kernel_cutile + + kernel = kernel_fn + + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + a, + b, + a_scale_ptr, + b_scale, + c, + m_indptr, + Q, + max_m, + N, + K_A, + total_m, + total_tiles, + num_programs, + num_k_tiles, + num_pid_m, + num_pid_n, + tiles_per_batch, + stride_a0, + stride_a1, + stride_b0, + stride_b1, + stride_b2, + stride_sa0, + stride_sa1, + stride_sb0, + stride_sb1, + stride_sb2, + stride_c0, + stride_c1, + has_a_scale, + BLOCK_M, + BLOCK_N, + BLOCK_K, + GROUP_SIZE_M, + ), + ) + + return c diff --git a/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py new file mode 100644 index 00000000..a3235d8c --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/gemm/ragged_bmm.py @@ -0,0 +1,747 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import os +from math import ceil +from types import SimpleNamespace + +import cuda.tile as ct +import torch +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl +from tilegym.kernel_utils import get_kernel_configs + +# Module-level tune caches for standard and swap_ab ragged BMM +_ragged_bmm_standard_tune_cache: dict = {} +_ragged_bmm_swap_ab_tune_cache: dict = {} + + +def cdiv(a, b): + """Ceiling division helper function.""" + return (a + b - 1) // b + + +@ct.kernel +def ragged_bmm_kernel_cutile( + a_ptr, # Input matrix A [total_m, K] or [K, total_m] if transpose_a + b_ptr, # Input matrix B [Q, N, K] or [Q, K, N] + c_ptr, # Output matrix C [total_m, N] + m_indptr, # Segment offsets [Q+1], flattened 1D + Q: ct.Constant[int], # Number of batches + max_m: ct.Constant[int], # Max segment size + N: ct.Constant[int], # Output N dimension + total_m: ct.Constant[int], # Total M (for bounds checking) + total_tiles: ct.Constant[int], # Total number of tiles + num_programs: ct.Constant[int], # Number of SMs + num_pid_m: ct.Constant[int], # Number of M tiles per batch + num_pid_n: ct.Constant[int], # Number of N tiles per batch + tiles_per_batch: ct.Constant[int], # num_pid_m * num_pid_n + transpose_a: ct.Constant[int], # Whether A is transposed (0 or 1) + transpose_b: ct.Constant[int], # Whether B is transposed (0 or 1) + BLOCK_M: ct.Constant[int], + BLOCK_N: ct.Constant[int], + BLOCK_K: ct.Constant[int], + GROUP_SIZE_M: ct.Constant[int], +): + """ + CuTile kernel for ragged batched matrix multiplication. + + Performs A @ B^T or A @ B where: + - A is flattened with segment offsets (m_indptr defines boundaries) + - B is batched [Q, N, K] or [Q, K, N] + - Output C is [total_m, N] + + Uses persistent scheduling with static grid and GROUP_SIZE_M tile swizzling. + Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + """ + pid = ct.bid(0) + + # Get K dimension from tensor shape + if transpose_a == 1: + K = a_ptr.shape[0] + else: + K = a_ptr.shape[1] + + num_k_tiles = ct.cdiv(K, BLOCK_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + # Persistent scheduling loop + for current_pid in range(pid, total_tiles, num_programs): + # Calculate pid_q, pid_m, pid_n with GROUP_SIZE_M swizzling + pid_q = current_pid // tiles_per_batch + pid_in_batch = current_pid % tiles_per_batch + + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m_actual = ct.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid_in_batch % group_size_m_actual) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m_actual + + # Load segment boundaries using ct.load with dynamic index + m_start_tile = ct.load(m_indptr, index=(pid_q,), shape=(1,)) + m_start = m_start_tile.item() + m_end_tile = ct.load(m_indptr, index=(pid_q + 1,), shape=(1,)) + m_end = m_end_tile.item() + valid_m = m_end - m_start + + # Only process if this tile is within valid M range + if pid_m * BLOCK_M < valid_m: + # Create sliced views for A and C using Array.slice + # This enables TMA (ct.load/ct.store) on ragged data + if transpose_a == 1: + # A is [K, total_m], slice axis 1 for M dimension + Ai = a_ptr.slice(axis=1, start=m_start, stop=m_end) # shape: (K, valid_m) + else: + # A is [total_m, K], slice axis 0 for M dimension + Ai = a_ptr.slice(axis=0, start=m_start, stop=m_end) # shape: (valid_m, K) + + # Slice C along axis 0 for M dimension + Ci = c_ptr.slice(axis=0, start=m_start, stop=m_end) # shape: (valid_m, N) + + # Initialize accumulator + dot_acc = ct.full((BLOCK_M, BLOCK_N), 0.0, dtype=ct.float32) + + # K-loop for matrix multiplication using TMA for A + for k in range(num_k_tiles): + # Load A block based on transpose_a flag using ct.load on sliced array + if transpose_a == 1: + # Ai is [K, valid_m], load [BLOCK_K, BLOCK_M] then transpose + a_block_kt = ct.load( + Ai, + index=(k, pid_m), + shape=(BLOCK_K, BLOCK_M), + padding_mode=ct.PaddingMode.ZERO, + ) + a_block = ct.permute(a_block_kt, (1, 0)) # [BLOCK_M, BLOCK_K] + else: + # Ai is [valid_m, K], load [BLOCK_M, BLOCK_K] + a_block = ct.load( + Ai, + index=(pid_m, k), + shape=(BLOCK_M, BLOCK_K), + padding_mode=ct.PaddingMode.ZERO, + ) + + # Load B block based on transpose_b flag using tile indices + if transpose_b == 1: + # B is [Q, N, K], load [1, BLOCK_N, BLOCK_K] + b_block_3d = ct.load( + b_ptr, + index=(pid_q, pid_n, k), + shape=(1, BLOCK_N, BLOCK_K), + padding_mode=ct.PaddingMode.ZERO, + ) + b_block = ct.reshape(b_block_3d, (BLOCK_N, BLOCK_K)) + # Transpose B: [BLOCK_N, BLOCK_K] -> [BLOCK_K, BLOCK_N] + b_block_t = ct.permute(b_block, (1, 0)) + else: + # B is [Q, K, N], load [1, BLOCK_K, BLOCK_N] + b_block_3d = ct.load( + b_ptr, + index=(pid_q, k, pid_n), + shape=(1, BLOCK_K, BLOCK_N), + padding_mode=ct.PaddingMode.ZERO, + ) + b_block_t = ct.reshape(b_block_3d, (BLOCK_K, BLOCK_N)) + + # Matrix multiplication: A @ B + dot_acc = ct.mma(a_block, b_block_t, acc=dot_acc) + + # Convert to output dtype + c_block = ct.astype(dot_acc, c_ptr.dtype) + + # Store to output C using ct.store on sliced array + # Ci is [valid_m, N], store at tile index (pid_m, pid_n) + # padding_mode handles partial tiles at segment boundaries + ct.store(Ci, index=(pid_m, pid_n), tile=c_block) + + +@ct.kernel +def ragged_bmm_kernel_cutile_swap_ab( + a_ptr, # Input matrix A [total_m, K] or [K, total_m] if transpose_a + b_ptr, # Input matrix B [Q, N, K] or [Q, K, N] + c_ptr, # Output matrix C [total_m, N] + m_indptr, # Segment offsets [Q+1], flattened 1D + Q: ct.Constant[int], # Number of batches + max_m: ct.Constant[int], # Max segment size + N: ct.Constant[int], # Output N dimension + total_m: ct.Constant[int], # Total M (for bounds checking) + total_tiles: ct.Constant[int], # Total number of tiles + num_programs: ct.Constant[int], # Number of SMs + num_pid_m: ct.Constant[int], # Number of M tiles per batch + num_pid_n: ct.Constant[int], # Number of N tiles per batch + tiles_per_batch: ct.Constant[int], # num_pid_m * num_pid_n + transpose_a: ct.Constant[int], # Whether A is transposed (0 or 1) + transpose_b: ct.Constant[int], # Whether B is transposed (0 or 1) + BLOCK_M: ct.Constant[int], + BLOCK_N: ct.Constant[int], + BLOCK_K: ct.Constant[int], + GROUP_SIZE_M: ct.Constant[int], +): + """ + CuTile kernel for ragged batched matrix multiplication with swap_ab optimization. + + Uses swapped accumulator layout (BLOCK_N, BLOCK_M) for better performance + when M dimension is small. Equivalent to: dot(B^T.T, A.T).T = A @ B^T + + Uses Array.slice + TMA (ct.load/ct.store) for A and C access. + """ + pid = ct.bid(0) + + # Get K dimension from tensor shape + if transpose_a == 1: + K = a_ptr.shape[0] + else: + K = a_ptr.shape[1] + + num_k_tiles = ct.cdiv(K, BLOCK_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + # Persistent scheduling loop + for current_pid in range(pid, total_tiles, num_programs): + # Calculate pid_q, pid_m, pid_n with GROUP_SIZE_M swizzling + pid_q = current_pid // tiles_per_batch + pid_in_batch = current_pid % tiles_per_batch + + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m_actual = ct.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid_in_batch % group_size_m_actual) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m_actual + + # Load segment boundaries using ct.load with dynamic index + m_start_tile = ct.load(m_indptr, index=(pid_q,), shape=(1,)) + m_start = m_start_tile.item() + m_end_tile = ct.load(m_indptr, index=(pid_q + 1,), shape=(1,)) + m_end = m_end_tile.item() + valid_m = m_end - m_start + + # Only process if this tile is within valid M range + if pid_m * BLOCK_M < valid_m: + # Create sliced views for A and C using Array.slice + # This enables TMA (ct.load/ct.store) on ragged data + if transpose_a == 1: + # A is [K, total_m], slice axis 1 for M dimension + Ai = a_ptr.slice(axis=1, start=m_start, stop=m_end) # shape: (K, valid_m) + else: + # A is [total_m, K], slice axis 0 for M dimension + Ai = a_ptr.slice(axis=0, start=m_start, stop=m_end) # shape: (valid_m, K) + + # Slice C along axis 0 for M dimension + Ci = c_ptr.slice(axis=0, start=m_start, stop=m_end) # shape: (valid_m, N) + + # Initialize accumulator with swapped dimensions [BLOCK_N, BLOCK_M] + dot_acc = ct.full((BLOCK_N, BLOCK_M), 0.0, dtype=ct.float32) + + # K-loop for matrix multiplication using TMA for A + for k in range(num_k_tiles): + # Load A block based on transpose_a flag using ct.load on sliced array + if transpose_a == 1: + # Ai is [K, valid_m], load [BLOCK_K, BLOCK_M] then transpose + a_block_kt = ct.load( + Ai, + index=(k, pid_m), + shape=(BLOCK_K, BLOCK_M), + padding_mode=ct.PaddingMode.ZERO, + ) + a_block = ct.permute(a_block_kt, (1, 0)) # [BLOCK_M, BLOCK_K] + else: + # Ai is [valid_m, K], load [BLOCK_M, BLOCK_K] + a_block = ct.load( + Ai, + index=(pid_m, k), + shape=(BLOCK_M, BLOCK_K), + padding_mode=ct.PaddingMode.ZERO, + ) + + # Load B block based on transpose_b flag using tile indices + if transpose_b == 1: + # B is [Q, N, K], load [1, BLOCK_N, BLOCK_K] + b_block_3d = ct.load( + b_ptr, + index=(pid_q, pid_n, k), + shape=(1, BLOCK_N, BLOCK_K), + padding_mode=ct.PaddingMode.ZERO, + ) + b_block = ct.reshape(b_block_3d, (BLOCK_N, BLOCK_K)) + else: + # B is [Q, K, N], load [1, BLOCK_K, BLOCK_N] + b_block_3d = ct.load( + b_ptr, + index=(pid_q, k, pid_n), + shape=(1, BLOCK_K, BLOCK_N), + padding_mode=ct.PaddingMode.ZERO, + ) + b_block_kn = ct.reshape(b_block_3d, (BLOCK_K, BLOCK_N)) + b_block = ct.permute(b_block_kn, (1, 0)) # [BLOCK_N, BLOCK_K] + + # For swap_ab: compute B @ A^T = [BLOCK_N, BLOCK_K] @ [BLOCK_K, BLOCK_M] + a_block_t = ct.permute(a_block, (1, 0)) # [BLOCK_K, BLOCK_M] + + # Matrix multiplication: B @ A^T + dot_acc = ct.mma(b_block, a_block_t, acc=dot_acc) + + # Transpose back: [BLOCK_N, BLOCK_M] -> [BLOCK_M, BLOCK_N] + acc_transposed = ct.permute(dot_acc, (1, 0)) + + # Convert to output dtype + c_block = ct.astype(acc_transposed, c_ptr.dtype) + + # Store to output C using ct.store on sliced array + # Ci is [valid_m, N], store at tile index (pid_m, pid_n) + # padding_mode handles partial tiles at segment boundaries + ct.store(Ci, index=(pid_m, pid_n), tile=c_block) + + +def _ragged_bmm_autotune_configs_standard(): + """ + Iterator of autotune configurations for standard (non-swap_ab) kernel. + + with extended occupancy range for better workload coverage. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability[0] == 10: + # B200 / GB200 (sm100 / sm103) - expanded configs + for BM, BN in [ + (256, 256), + (128, 256), + (128, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + ) + elif gpu_capability in [(12, 0), (12, 1)]: + # RTX 5090 (sm120/sm121) - expanded configs + for BM, BN in [ + (256, 256), + (128, 256), + (128, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + ) + else: + # Default configurations - expanded + for BM, BN in [ + (128, 256), + (128, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + ) + + +def _ragged_bmm_autotune_configs_swap_ab(): + """ + Iterator of autotune configurations for swap_ab kernel. + Used when M dimension is small relative to N. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability[0] == 10: + # B200 / GB200 (sm100 / sm103) + for BM, BN in [ + (64, 256), + (64, 128), + (32, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + ) + elif gpu_capability in [(12, 0), (12, 1)]: + # RTX 5090 (sm120/sm121) + for BM, BN in [ + (64, 256), + (64, 128), + (32, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + ) + else: + # Default configurations + for BM, BN in [ + (64, 128), + (32, 128), + ]: + for BK in [64]: + for occupancy in [1, 2, 4]: + yield SimpleNamespace( + BLOCK_M=BM, + BLOCK_N=BN, + BLOCK_K=BK, + GROUP_SIZE_M=8, + num_ctas=1, + occupancy=occupancy, + ) + + +def _get_default_kernel_configs(): + """ + Get GPU-specific default kernel configs for non-autotune path. + """ + gpu_capability = torch.cuda.get_device_capability() + + if gpu_capability[0] == 10: + return { + "BLOCK_M": 256, + "BLOCK_N": 256, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 1, + "occupancy": 2, + } + elif gpu_capability in [(12, 0), (12, 1)]: + # RTX 5090 (sm120/sm121) + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 1, + "occupancy": 2, + } + else: + return { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_SIZE_M": 8, + "swap_ab": False, + "num_ctas": 1, + "occupancy": 2, + } + + +def _ragged_bmm_autotune_launch_standard(stream, a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b): + """ + Autotuned launch for standard ragged BMM kernel. + """ + NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count + + transpose_a_int = 1 if transpose_a else 0 + transpose_b_int = 1 if transpose_b else 0 + + def args_fn(cfg): + BM = cfg.BLOCK_M + BN = cfg.BLOCK_N + BK = cfg.BLOCK_K + GSM = cfg.GROUP_SIZE_M + + num_pid_m = cdiv(max_m, BM) + num_pid_n = cdiv(N, BN) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // cfg.num_ctas, total_tiles) * cfg.occupancy + + return ( + a, + b, + c, + m_indptr, + Q, + max_m, + N, + total_m, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + tiles_per_batch, + transpose_a_int, + transpose_b_int, + BM, + BN, + BK, + GSM, + ) + + def grid_fn(cfg): + BM = cfg.BLOCK_M + BN = cfg.BLOCK_N + num_pid_m = cdiv(max_m, BM) + num_pid_n = cdiv(N, BN) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // cfg.num_ctas, total_tiles) * cfg.occupancy + return (num_programs, 1, 1) + + cache_key = (Q, max_m, N, total_m, transpose_a_int, transpose_b_int, a.dtype, str(a.device)) + if cache_key not in _ragged_bmm_standard_tune_cache: + result = exhaustive_search( + list(_ragged_bmm_autotune_configs_standard()), + stream, + grid_fn, + ragged_bmm_kernel_cutile, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _ragged_bmm_standard_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + ragged_bmm_kernel_cutile._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _ragged_bmm_standard_tune_cache[cache_key] + ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg)) + + +def _ragged_bmm_autotune_launch_swap_ab(stream, a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b): + """ + Autotuned launch for swap_ab ragged BMM kernel. + """ + NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count + + transpose_a_int = 1 if transpose_a else 0 + transpose_b_int = 1 if transpose_b else 0 + + def args_fn(cfg): + BM = cfg.BLOCK_M + BN = cfg.BLOCK_N + BK = cfg.BLOCK_K + GSM = cfg.GROUP_SIZE_M + + num_pid_m = cdiv(max_m, BM) + num_pid_n = cdiv(N, BN) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // cfg.num_ctas, total_tiles) * cfg.occupancy + + return ( + a, + b, + c, + m_indptr, + Q, + max_m, + N, + total_m, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + tiles_per_batch, + transpose_a_int, + transpose_b_int, + BM, + BN, + BK, + GSM, + ) + + def grid_fn(cfg): + BM = cfg.BLOCK_M + BN = cfg.BLOCK_N + num_pid_m = cdiv(max_m, BM) + num_pid_n = cdiv(N, BN) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // cfg.num_ctas, total_tiles) * cfg.occupancy + return (num_programs, 1, 1) + + swap_cache_key = (Q, max_m, N, total_m, transpose_a_int, transpose_b_int, a.dtype, str(a.device)) + if swap_cache_key not in _ragged_bmm_swap_ab_tune_cache: + result = exhaustive_search( + list(_ragged_bmm_autotune_configs_swap_ab()), + stream, + grid_fn, + ragged_bmm_kernel_cutile_swap_ab, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _ragged_bmm_swap_ab_tune_cache[swap_cache_key] = ( + best_cfg, + ct.kernel( + ragged_bmm_kernel_cutile_swap_ab._pyfunc, + num_ctas=best_cfg.num_ctas, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _ragged_bmm_swap_ab_tune_cache[swap_cache_key] + ct.launch(stream, grid_fn(best_cfg), tuned_kernel, args_fn(best_cfg)) + + +@register_impl("flashinfer.gemm.ragged_bmm", backend="cutile") +def ragged_bmm( + a, + b, + m_indptr, + max_m, + max_m_device=None, + transpose_a=False, + transpose_b=True, + out_dtype=None, + **kwargs, +): + """ + CuTile implementation of ragged BMM with non-even M segments. + + Matrix A is flattened with m_indptr defining the boundaries. + Performs A @ B (or A @ B^T if transpose_b=True) where B is batched. + + Note: For CuTile implementation, segment offsets should ideally be + multiples of BLOCK_M for optimal performance and correctness. + + Args: + a: Input matrix A, flattened [total_m, K] or [K, total_m] if transpose_a + b: Input matrix B, batched [Q, N, K] or [Q, K, N] if not transpose_b + m_indptr: Segment offsets tensor [Q+1] + max_m: Maximum segment size + max_m_device: Optional device tensor with max_m (unused in CuTile, kept for API compatibility) + transpose_a: Whether A is transposed + transpose_b: Whether B is transposed + out_dtype: Output dtype + + Returns: + Output tensor C [total_m, N] + """ + # Get dimensions from flattened matrix a + if transpose_a: + K, total_m = a.shape + else: + total_m, K = a.shape + + if transpose_b: + Q, N, K_B = b.shape + else: + Q, K_B, N = b.shape + + assert K == K_B, "incompatible dimensions" + assert m_indptr.shape[0] == Q + 1, "m_indptr must have Q+1 elements" + assert a.is_contiguous(), "A matrix must be contiguous" + assert b.is_contiguous(), "B matrix must be contiguous" + assert m_indptr.is_contiguous(), "m_indptr must be contiguous" + + # Determine output dtype + if out_dtype is None: + out_dtype = a.dtype + + # Allocate output + c = torch.empty((total_m, N), device=a.device, dtype=out_dtype) + + # Check if autotune is enabled + enable_autotune = os.environ.get("DISABLE_CUTILE_TUNE", "0") != "1" + + # Decide whether to use swap_ab based on M vs N ratio + # swap_ab is beneficial when M is small relative to N + use_swap_ab = max_m <= 128 and N >= 256 + + if enable_autotune: + # Use autotune launch for optimal configuration selection + if use_swap_ab: + _ragged_bmm_autotune_launch_swap_ab( + torch.cuda.current_stream(), a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b + ) + else: + _ragged_bmm_autotune_launch_standard( + torch.cuda.current_stream(), a, b, c, m_indptr, Q, max_m, N, total_m, transpose_a, transpose_b + ) + else: + # Use fixed default configs + default_configs = _get_default_kernel_configs() + kernel_configs = get_kernel_configs(default_configs, kwargs.get("kernel_configs")) + + BLOCK_M = kernel_configs.get("BLOCK_M") + BLOCK_N = kernel_configs.get("BLOCK_N") + BLOCK_K = kernel_configs.get("BLOCK_K") + GROUP_SIZE_M = kernel_configs.get("GROUP_SIZE_M", 8) + swap_ab = kernel_configs.get("swap_ab", False) + num_ctas = kernel_configs.get("num_ctas", 1) + occupancy = kernel_configs.get("occupancy", 2) + + # Calculate grid size for persistent scheduling + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + num_pid_m = cdiv(max_m, BLOCK_M) + num_pid_n = cdiv(N, BLOCK_N) + tiles_per_batch = num_pid_m * num_pid_n + total_tiles = tiles_per_batch * Q + num_programs = min(NUM_SMS // num_ctas, total_tiles) * occupancy + + # 1D grid for persistent scheduling + grid = (num_programs, 1, 1) + + # Convert boolean to int for ct.Constant + transpose_a_int = 1 if transpose_a else 0 + transpose_b_int = 1 if transpose_b else 0 + + # Select kernel based on swap_ab + kernel_fn = ragged_bmm_kernel_cutile_swap_ab if swap_ab else ragged_bmm_kernel_cutile + + # Build kernel with hints + + kernel = kernel_fn + + ct.launch( + torch.cuda.current_stream(), + grid, + kernel, + ( + a, + b, + c, + m_indptr, + Q, + max_m, + N, + total_m, + total_tiles, + num_programs, + num_pid_m, + num_pid_n, + tiles_per_batch, + transpose_a_int, + transpose_b_int, + BLOCK_M, + BLOCK_N, + BLOCK_K, + GROUP_SIZE_M, + ), + ) + + return c diff --git a/src/tilegym/suites/flashinfer/cutile/per_token_group_quant_8bit.py b/src/tilegym/suites/flashinfer/cutile/per_token_group_quant_8bit.py new file mode 100644 index 00000000..6d0f1f66 --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/per_token_group_quant_8bit.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +from typing import Optional +from typing import Tuple + +import cuda.tile as ct +import torch + +from tilegym.backend import register_impl + +ConstInt = ct.Constant[int] + + +def next_power_of_2(n): + """Return the smallest power of 2 greater than or equal to n.""" + if n <= 0: + return 1 + return 1 << (n - 1).bit_length() + + +@ct.kernel +def _per_token_group_quant_8bit_kernel( + y_ptr, + y_q_ptr, + y_s_ptr, + y_stride: ConstInt, + N: ConstInt, + eps: ConstInt, + bit8_min: ConstInt, + bit8_max: ConstInt, + BLOCK: ConstInt, + y_array_size: ConstInt, + y_q_array_size: ConstInt, + y_s_array_size: ConstInt, +): + """Per-token-group quantization kernel (row-major scales).""" + g_id = ct.bid(0) + + # Compute base offsets for this group + y_base = g_id * y_stride + y_q_base = g_id * y_stride + + # Create column offsets and mask + cols = ct.arange(BLOCK, dtype=ct.int32) + mask = cols < N + + # Load input values with gather (element-level access with mask) + y_indices = y_base + cols + y = ct.gather(y_ptr, (y_indices,), check_bounds=True, padding_value=0.0) + y = ct.astype(y, ct.float32) + + # Compute absmax + abs_y = ct.abs(y) + _absmax = ct.max(abs_y, axis=0) + _absmax = ct.maximum(_absmax, eps) + + # Compute scale and inverse scale + y_s = _absmax / bit8_max + y_s_inv = 1.0 / y_s + + # Quantize: clamp(y * y_s_inv, bit8_min, bit8_max) + y_q = ct.minimum(ct.maximum(y * y_s_inv, bit8_min), bit8_max) + y_q = ct.astype(y_q, y_q_ptr.dtype) + + # Store quantized values with mask (use OOB offsets for invalid positions) + oob_offset = ct.full((BLOCK,), y_q_array_size, dtype=ct.int32) + y_q_indices = y_q_base + cols + y_q_indices_masked = ct.where(mask, y_q_indices, oob_offset) + ct.scatter(y_q_ptr, (y_q_indices_masked,), y_q, check_bounds=True) + + # Store scale (single scalar per group) + y_s_idx = g_id + oob_scalar = ct.full((), y_s_array_size, dtype=ct.int32) + s_idx_masked = ct.where(y_s_idx < y_s_array_size, y_s_idx, oob_scalar) + ct.scatter(y_s_ptr, (s_idx_masked,), y_s) + + +@ct.kernel +def _per_token_group_quant_8bit_colmajor_kernel( + y_ptr, + y_q_ptr, + y_s_ptr, + group_size: ConstInt, + y_num_columns: ConstInt, + y_row_stride: ConstInt, + y_s_col_stride: ConstInt, + eps: ConstInt, + bit8_min: ConstInt, + bit8_max: ConstInt, + scale_ue8m0: ConstInt, + BLOCK: ConstInt, + y_array_size: ConstInt, + y_q_array_size: ConstInt, + y_s_array_size: ConstInt, +): + """Per-token-group quantization kernel (column-major scales).""" + groups_per_row = y_num_columns // group_size + + g_id = ct.bid(0) + row = g_id // groups_per_row + group_id = g_id % groups_per_row + + # Compute base offsets + y_base = row * y_row_stride + group_id * group_size + y_q_base = g_id * group_size + y_s_offset = group_id * y_s_col_stride + row + + # Create column offsets and mask + cols = ct.arange(BLOCK, dtype=ct.int32) + mask = cols < group_size + + # Load input values + y_indices = y_base + cols + y = ct.gather(y_ptr, (y_indices,), check_bounds=True, padding_value=0.0) + y = ct.astype(y, ct.float32) + + # Compute absmax + abs_y = ct.abs(y) + _absmax = ct.max(abs_y, axis=0) + _absmax = ct.maximum(_absmax, eps) + + # Compute scale + y_s = _absmax / bit8_max + + # Optional: round scale to power of 2 (UE8M0) + if scale_ue8m0: + abs_y_s = ct.abs(y_s) + safe_y_s = ct.maximum(abs_y_s, 1e-10) + y_s = ct.exp2(ct.ceil(ct.log2(safe_y_s))) + + # Quantize: clamp(y / y_s, bit8_min, bit8_max) + y_q = ct.minimum(ct.maximum(y / y_s, bit8_min), bit8_max) + y_q = ct.astype(y_q, y_q_ptr.dtype) + + # Store quantized values with mask + oob_offset = ct.full((BLOCK,), y_q_array_size, dtype=ct.int32) + y_q_indices = y_q_base + cols + y_q_indices_masked = ct.where(mask, y_q_indices, oob_offset) + ct.scatter(y_q_ptr, (y_q_indices_masked,), y_q, check_bounds=True) + + # Store scale (single scalar) + oob_scalar = ct.full((), y_s_array_size, dtype=ct.int32) + s_idx_masked = ct.where(y_s_offset < y_s_array_size, y_s_offset, oob_scalar) + ct.scatter(y_s_ptr, (s_idx_masked,), y_s) + + +def _ceil_align(x: int, align: int) -> int: + return (x + align - 1) // align * align + + +@register_impl("flashinfer.quant.per_token_group_quant_8bit", backend="cutile") +def per_token_group_quant_8bit( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dst_dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + scale_ue8m0: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Per-token group 8-bit quantization (FP8 or INT8) - CuTile implementation.""" + if dst_dtype is None: + dst_dtype = torch.float8_e4m3fn + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" + ) + assert x.is_contiguous(), "`x` must be contiguous" + if scale_tma_aligned or scale_ue8m0: + assert column_major_scales, "scale_tma_aligned or scale_ue8m0 requires column_major_scales=True" + + if dst_dtype == torch.int8: + info = torch.iinfo(dst_dtype) + bit8_min = float(info.min) + bit8_max = float(info.max) + else: + info = torch.finfo(dst_dtype) + bit8_min = info.min + bit8_max = info.max + + x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) + M = x.numel() // group_size + N = group_size + + if column_major_scales: + num_groups = x.shape[-1] // group_size + num_tokens = x.shape[-2] if x.dim() >= 2 else x.shape[0] + if scale_tma_aligned: + # TMA-friendly layout: (num_groups, aligned_num_tokens), align to 4 floats (16B) + aligned_size = _ceil_align(num_tokens, 4) + x_s_raw = torch.empty( + (num_groups, aligned_size), + device=x.device, + dtype=torch.float32, + ) + x_s_col_stride = aligned_size + else: + shape = (num_groups,) + x.shape[:-1] + x_s_raw = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + x_s_col_stride = x_s_raw.stride(1) + x_s = x_s_raw + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = next_power_of_2(N) + + stream = torch.cuda.current_stream() + grid = (M, 1, 1) + + # Flatten tensors for gather/scatter access + x_flat = x.contiguous().view(-1) + x_q_flat = x_q.view(-1) + x_s_flat = x_s.contiguous().view(-1) if not column_major_scales else x_s + + if column_major_scales: + # Use a contiguous flat view of x_s for scatter + # x_s may be non-contiguous (permuted), so we use as_strided to get the raw storage + x_s_for_kernel = ( + x_s_raw.view(-1) + if x_s_raw.is_contiguous() + else torch.as_strided(x_s_raw, (x_s_raw.numel(),), (1,), storage_offset=x_s_raw.storage_offset()) + ) + + ct.launch( + stream, + grid, + _per_token_group_quant_8bit_colmajor_kernel, + ( + x_flat, + x_q_flat, + x_s_for_kernel, + group_size, + x.shape[-1], + x.stride(-2) if x.dim() >= 2 else x.shape[-1], + x_s_col_stride, + eps, + bit8_min, + bit8_max, + 1 if scale_ue8m0 else 0, + BLOCK, + x_flat.numel(), + x_q_flat.numel(), + x_s_for_kernel.numel(), + ), + ) + if scale_tma_aligned: + x_s = x_s_raw[:, :num_tokens].t().contiguous() + else: + assert not scale_ue8m0 + + ct.launch( + stream, + grid, + _per_token_group_quant_8bit_kernel, + ( + x_flat, + x_q_flat, + x_s_flat, + group_size, + N, + eps, + bit8_min, + bit8_max, + BLOCK, + x_flat.numel(), + x_q_flat.numel(), + x_s_flat.numel(), + ), + ) + + return x_q, x_s diff --git a/src/tilegym/suites/flashinfer/cutile/rope_quantize_fp8.py b/src/tilegym/suites/flashinfer/cutile/rope_quantize_fp8.py new file mode 100644 index 00000000..442e4b9e --- /dev/null +++ b/src/tilegym/suites/flashinfer/cutile/rope_quantize_fp8.py @@ -0,0 +1,379 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +import math +import os +from types import SimpleNamespace +from typing import Optional +from typing import Tuple + +import cuda.tile as ct +import torch +from cuda.tile.tune import exhaustive_search + +from tilegym.backend import register_impl + +ConstFloat = ct.Constant[float] +ConstInt = ct.Constant[int] +PAD_ZERO = ct.PaddingMode.ZERO + +# Module-level tune cache: (num_tokens, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, q_dtype, out_dtype, device) -> (best_cfg, tuned_kernel) +_rope_quantize_fp8_tune_cache: dict = {} + + +def _default_tokens_per_block(num_tokens: int) -> int: + if num_tokens <= 4: + return 1 + if num_tokens <= 32: + return 4 + if num_tokens <= 128: + return 8 + return 32 + + +def _rope_quantize_fp8_cutile_configs(num_tokens: int): + candidates = [tpb for tpb in (1, 4, 8, 16, 32) if tpb <= max(32, num_tokens)] + if os.getenv("ENABLE_CUTILE_TUNE", "0") != "1" or os.getenv("DISABLE_AUTOTUNE") == "1": + default_tpb = _default_tokens_per_block(num_tokens) + return [SimpleNamespace(TOKENS_PER_BLOCK=default_tpb, occupancy=4)] + + return [ + SimpleNamespace(TOKENS_PER_BLOCK=tokens_per_block, occupancy=occupancy) + for tokens_per_block in candidates + for occupancy in (1, 2, 4) + ] + + +def _cutile_autotune_rope_quantize_fp8( + stream, + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + q_rope_out, + k_rope_out, + q_nope_out, + k_nope_out, + quant_scale_q, + quant_scale_kv, + num_tokens, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + total_blocks_y, +): + cache_key = ( + num_tokens, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + q_rope.dtype, + q_rope_out.dtype, + str(q_rope.device), + ) + if cache_key not in _rope_quantize_fp8_tune_cache: + result = exhaustive_search( + list(_rope_quantize_fp8_cutile_configs(num_tokens)), + stream, + lambda cfg: (math.ceil(num_tokens / cfg.TOKENS_PER_BLOCK), total_blocks_y, 1), + rope_quantize_fp8_kernel, + lambda cfg: ( + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + q_rope_out, + k_rope_out, + q_nope_out, + k_nope_out, + quant_scale_q, + quant_scale_kv, + num_tokens, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + cfg.TOKENS_PER_BLOCK, + ), + lambda cfg: {"occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _rope_quantize_fp8_tune_cache[cache_key] = ( + best_cfg, + ct.kernel( + rope_quantize_fp8_kernel._pyfunc, + occupancy=best_cfg.occupancy, + ), + ) + best_cfg, tuned_kernel = _rope_quantize_fp8_tune_cache[cache_key] + ct.launch( + stream, + (math.ceil(num_tokens / best_cfg.TOKENS_PER_BLOCK), total_blocks_y, 1), + tuned_kernel, + ( + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + q_rope_out, + k_rope_out, + q_nope_out, + k_nope_out, + quant_scale_q, + quant_scale_kv, + num_tokens, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + best_cfg.TOKENS_PER_BLOCK, + ), + ) + + +def _load_rope_factors(pos_ids, cos_sin_cache, token_block, TOKENS_PER_BLOCK, HALF_DIM): + pos_tile = ct.load(pos_ids, index=token_block, shape=TOKENS_PER_BLOCK, padding_mode=PAD_ZERO) + pos_rows = ct.reshape(pos_tile, (TOKENS_PER_BLOCK, 1)) + half_cols = ct.reshape(ct.arange(HALF_DIM, dtype=ct.int32), (1, HALF_DIM)) + cos = ct.gather(cos_sin_cache, (pos_rows, half_cols), check_bounds=False) + sin = ct.gather(cos_sin_cache, (pos_rows, half_cols + HALF_DIM), check_bounds=False) + return cos, sin + + +def _apply_rope_interleave_batched(x_tile, cos, sin, out_dtype, quant_scale, TOKENS_PER_BLOCK, HALF_DIM): + x_3d = ct.reshape(ct.astype(x_tile, ct.float32), (TOKENS_PER_BLOCK, HALF_DIM, 2)) + x_even = ct.reshape( + ct.extract(x_3d, index=(0, 0, 0), shape=(TOKENS_PER_BLOCK, HALF_DIM, 1)), + (TOKENS_PER_BLOCK, HALF_DIM), + ) + x_odd = ct.reshape( + ct.extract(x_3d, index=(0, 0, 1), shape=(TOKENS_PER_BLOCK, HALF_DIM, 1)), + (TOKENS_PER_BLOCK, HALF_DIM), + ) + + out_even = (x_even * cos - x_odd * sin) * quant_scale + out_odd = (x_odd * cos + x_even * sin) * quant_scale + + return ct.cat( + ( + ct.reshape(ct.astype(out_even, out_dtype), (TOKENS_PER_BLOCK, HALF_DIM, 1)), + ct.reshape(ct.astype(out_odd, out_dtype), (TOKENS_PER_BLOCK, HALF_DIM, 1)), + ), + 2, + ) + + +def _quantize_batched_tile(x_tile, out_dtype, quant_scale): + return ct.astype(ct.astype(x_tile, ct.float32) * quant_scale, out_dtype) + + +@ct.kernel +def rope_quantize_fp8_kernel( + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + q_rope_out, + k_rope_out, + q_nope_out, + k_nope_out, + quant_scale_q: ConstFloat, + quant_scale_kv: ConstFloat, + num_tokens: ConstInt, + num_qo_heads: ConstInt, + num_kv_heads: ConstInt, + ROPE_DIM: ConstInt, + NO_ROPE_DIM: ConstInt, + TOKENS_PER_BLOCK: ConstInt, +): + pid_x = ct.bid(0) + pid_y = ct.bid(1) + + HALF_DIM: ConstInt = ROPE_DIM // 2 + no_rope_chunks: ConstInt = (NO_ROPE_DIM + ROPE_DIM - 1) // ROPE_DIM + + q_rope_end = num_qo_heads + k_rope_end = q_rope_end + num_kv_heads + k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks + + if pid_y < q_rope_end: + cos, sin = _load_rope_factors(pos_ids, cos_sin_cache, pid_x, TOKENS_PER_BLOCK, HALF_DIM) + head_idx = pid_y + q_tile = ct.load( + q_rope, + index=(pid_x, head_idx, 0), + shape=(TOKENS_PER_BLOCK, 1, ROPE_DIM), + padding_mode=PAD_ZERO, + ) + q_rot = _apply_rope_interleave_batched( + q_tile, + cos, + sin, + q_rope_out.dtype, + quant_scale_q, + TOKENS_PER_BLOCK, + HALF_DIM, + ) + ct.store(q_rope_out, index=(pid_x, head_idx, 0), tile=ct.reshape(q_rot, (TOKENS_PER_BLOCK, 1, ROPE_DIM))) + elif pid_y < k_rope_end: + cos, sin = _load_rope_factors(pos_ids, cos_sin_cache, pid_x, TOKENS_PER_BLOCK, HALF_DIM) + k_tile = ct.load( + k_rope, + index=(pid_x, 0), + shape=(TOKENS_PER_BLOCK, ROPE_DIM), + padding_mode=PAD_ZERO, + ) + k_rot = _apply_rope_interleave_batched( + k_tile, + cos, + sin, + k_rope_out.dtype, + quant_scale_kv, + TOKENS_PER_BLOCK, + HALF_DIM, + ) + ct.store(k_rope_out, index=(pid_x, 0), tile=ct.reshape(k_rot, (TOKENS_PER_BLOCK, ROPE_DIM))) + elif pid_y < k_nope_end: + chunk_idx = pid_y - k_rope_end + k_tile = ct.load( + k_nope, + index=(pid_x, chunk_idx), + shape=(TOKENS_PER_BLOCK, ROPE_DIM), + padding_mode=PAD_ZERO, + ) + ct.store( + k_nope_out, index=(pid_x, chunk_idx), tile=_quantize_batched_tile(k_tile, k_nope_out.dtype, quant_scale_kv) + ) + else: + task_idx = pid_y - k_nope_end + head_idx = task_idx // no_rope_chunks + chunk_idx = task_idx % no_rope_chunks + q_tile = ct.load( + q_nope, + index=(pid_x, head_idx, chunk_idx), + shape=(TOKENS_PER_BLOCK, 1, ROPE_DIM), + padding_mode=PAD_ZERO, + ) + ct.store( + q_nope_out, + index=(pid_x, head_idx, chunk_idx), + tile=_quantize_batched_tile(q_tile, q_nope_out.dtype, quant_scale_q), + ) + + +@register_impl("flashinfer.rope.rope_quantize_fp8", backend="cutile") +def rope_quantize_fp8( + q_rope: torch.Tensor, + k_rope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + is_neox: bool = True, + quantize_dtype: Optional[torch.dtype] = None, + quant_scale_q: float = 1.0, + quant_scale_kv: float = 1.0, + q_rope_out: Optional[torch.Tensor] = None, + k_rope_out: Optional[torch.Tensor] = None, + q_nope_out: Optional[torch.Tensor] = None, + k_nope_out: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if cos_sin_cache.dtype != torch.float32: + raise ValueError("cos_sin_cache should be float32") + + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + is_mla = k_rope.ndim == 2 + num_kv_heads = 1 if is_mla else k_rope.shape[1] + + if q_nope is None: + q_nope = torch.empty(nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + k_nope = torch.empty(nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device) + + if quantize_dtype is None: + for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out): + if out is not None: + quantize_dtype = out.dtype + break + else: + quantize_dtype = torch.float8_e4m3fn + + q_rope_out = q_rope_out if q_rope_out is not None else torch.empty_like(q_rope, dtype=quantize_dtype) + k_rope_out = k_rope_out if k_rope_out is not None else torch.empty_like(k_rope, dtype=quantize_dtype) + q_nope_out = q_nope_out if q_nope_out is not None else torch.empty_like(q_nope, dtype=quantize_dtype) + k_nope_out = k_nope_out if k_nope_out is not None else torch.empty_like(k_nope, dtype=quantize_dtype) + + num_tokens = q_rope.shape[0] + rope_dim = q_rope.shape[2] + num_kv_heads = 1 if k_rope.ndim == 2 else k_rope.shape[1] + no_rope_dim = q_nope.shape[2] if q_nope is not None else 0 + + no_rope_chunks = (no_rope_dim + rope_dim - 1) // rope_dim + total_blocks_y = num_qo_heads + num_kv_heads + num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks + + assert not is_neox, "is_neox should be False for rope_quantize_fp8" + + stream = torch.cuda.current_stream() + configs = _rope_quantize_fp8_cutile_configs(num_tokens) + if len(configs) == 1: + cfg = configs[0] + grid = (math.ceil(num_tokens / cfg.TOKENS_PER_BLOCK), total_blocks_y, 1) + kernel = ct.kernel(rope_quantize_fp8_kernel._pyfunc, occupancy=cfg.occupancy) + args = ( + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + q_rope_out, + k_rope_out, + q_nope_out, + k_nope_out, + quant_scale_q, + quant_scale_kv, + num_tokens, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + cfg.TOKENS_PER_BLOCK, + ) + ct.launch(stream, grid, kernel, args) + else: + _cutile_autotune_rope_quantize_fp8( + stream, + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + q_rope_out, + k_rope_out, + q_nope_out, + k_nope_out, + quant_scale_q, + quant_scale_kv, + num_tokens, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + total_blocks_y, + ) + return q_rope_out, k_rope_out, q_nope_out, k_nope_out diff --git a/src/tilegym/suites/flashinfer/ops.py b/src/tilegym/suites/flashinfer/ops.py new file mode 100644 index 00000000..5722e6bb --- /dev/null +++ b/src/tilegym/suites/flashinfer/ops.py @@ -0,0 +1,468 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +""" +FlashInfer Suite Operations Interface + +This module provides operation interfaces for FlashInfer-specific operations. +These operations are automatically dispatched to the appropriate backend implementation. +""" + +from typing import Optional +from typing import Tuple + +import torch + +from tilegym.backend import dispatch +from tilegym.backend import get_current_backend + +# ============================================================================ +# GEMM Operations +# ============================================================================ + + +@dispatch( + "flashinfer.gemm.gemm_alpha_beta", +) +def gemm_alpha_beta( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + trans_a: bool = False, + trans_b: bool = True, + alpha: float = 1.0, + beta: float = 0.0, + num_sms: int = 1, +): + """ + FlashInfer GEMM operation: Matrix multiplication with alpha/beta scaling. + + Computes: C = alpha * A @ B + beta * C + + Args: + a: Input matrix A [M, K] or [K, M] if trans_a=True + b: Input matrix B [K, N] or [N, K] if trans_b=True + c: Input/output matrix C [M, N] (modified in-place) + trans_a: Whether to transpose A + trans_b: Whether to transpose B + alpha: Scaling factor for A @ B + beta: Scaling factor for existing C + + Returns: + torch.Tensor: Modified C tensor (C = alpha * A @ B + beta * C) + """ + raise NotImplementedError(f"flashinfer.gemm.gemm_alpha_beta is not implemented for {get_current_backend()}") + + +@dispatch( + "flashinfer.gemm.masked_bmm", +) +def masked_bmm( + a: torch.Tensor, + b: torch.Tensor, + masked_m: torch.Tensor, + transpose_a: bool = False, + transpose_b: bool = False, + static_persistent: Optional[bool] = None, +): + """ + FlashInfer operation: Masked batch matrix multiplication. + + Performs batched matrix multiplication where each batch has a different valid M dimension. + + Args: + a: Input matrix A [Q, M, K] or [Q, K, M] if transpose_a=True + b: Input matrix B [Q, K, N] or [Q, N, K] if transpose_b=True + masked_m: Valid M dimensions for each batch [Q] + transpose_a: Whether to transpose A + transpose_b: Whether to transpose B + static_persistent: Whether to use static persistent scheduling + + Returns: + torch.Tensor: Output matrix C [Q, M, N] + """ + raise NotImplementedError(f"flashinfer.gemm.masked_bmm is not implemented for {get_current_backend()}") + + +@dispatch( + "flashinfer.gemm.ragged_bmm", +) +def ragged_bmm( + a: torch.Tensor, + b: torch.Tensor, + m_indptr: torch.Tensor, + max_m: int, + max_m_device: torch.Tensor = None, + transpose_a: bool = False, + transpose_b: bool = True, + out_dtype: torch.dtype = None, +): + """ + FlashInfer operation: Ragged batch matrix multiplication. + + Performs batched matrix multiplication where batches have non-uniform M dimensions. + Matrix A is flattened with m_indptr defining batch boundaries. + + Args: + a: Flattened input matrix A [total_m, K] or [K, total_m] if transpose_a=True + b: Input matrix B [Q, K, N] or [Q, N, K] if transpose_b=True + m_indptr: Segment offsets marking batch boundaries [Q+1] + max_m: Maximum M dimension across all batches + max_m_device: Optional device tensor containing max_m (for CUDA graph compatibility) + transpose_a: Whether to transpose A + transpose_b: Whether to transpose B + out_dtype: Optional output dtype override + + Returns: + torch.Tensor: Flattened output matrix C [total_m, N] + """ + raise NotImplementedError(f"flashinfer.gemm.ragged_bmm is not implemented for {get_current_backend()}") + + +@dispatch( + "flashinfer.gemm.ragged_block_scaled_bmm", +) +def ragged_block_scaled_bmm( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indptr: torch.Tensor, + max_m: int, + max_m_device: torch.Tensor = None, + transpose_a: bool = False, + transpose_b: bool = True, + out_dtype: torch.dtype = None, +): + """ + FlashInfer operation: Ragged batch matrix multiplication with block-wise scaling. + + Performs ragged batched matrix multiplication with per-block scaling factors. + + Args: + a: Flattened input matrix A [total_m, K] + b: Input matrix B [Q, N, K] or [Q, K, N] + a_scale: Scaling factors for A blocks + b_scale: Scaling factors for B blocks + m_indptr: Segment offsets marking batch boundaries [Q+1] + max_m: Maximum M dimension across all batches + max_m_device: Optional device tensor containing max_m (for CUDA graph compatibility) + transpose_a: Whether to transpose A + transpose_b: Whether to transpose B + out_dtype: Optional output dtype override + + Returns: + torch.Tensor: Flattened output matrix C [total_m, N] + """ + raise NotImplementedError(f"flashinfer.gemm.ragged_block_scaled_bmm is not implemented for {get_current_backend()}") + + +# ============================================================================ +# RoPE Operations +# ============================================================================ + + +@dispatch( + "flashinfer.rope.rope_quantize_fp8", +) +def rope_quantize_fp8( + q_rope: torch.Tensor, + k_rope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + is_neox: bool = True, + quantize_dtype: Optional[torch.dtype] = None, + quant_scale_q: float = 1.0, + quant_scale_kv: float = 1.0, + q_rope_out: Optional[torch.Tensor] = None, + k_rope_out: Optional[torch.Tensor] = None, + q_nope_out: Optional[torch.Tensor] = None, + k_nope_out: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + FlashInfer operation: RoPE (Rotary Position Embedding) with FP8 quantization. + + Applies RoPE to q_rope and k_rope tensors, then quantizes all outputs to FP8. + Supports both interleave mode (is_neox=False, GPT-J style) and non-interleave mode. + + Example (MLA with head_size=576, rope_dim=64, no_rope_dim=512): + q_in: [num_tokens, num_qo_heads, 576] # e.g., [N, 128, 576] + k_in: [num_tokens, 576] # MLA uses 2D key tensor + + flashinfer.ops.rope_quantize_fp8( + q_rope=q_in[..., :64], # [N, 128, 64] + k_rope=k_in[..., :64], # [N, 64] + q_nope=q_in[..., 64:], # [N, 128, 512] + k_nope=k_in[..., 64:], # [N, 512] + cos_sin_cache=cos_sin_cache, # [max_seq_len, rope_dim] + pos_ids=pos_ids, # [num_tokens] + is_neox=False, + q_rope_out=q_out[..., :64], + k_rope_out=k_out[..., :64], + q_nope_out=q_out[..., 64:], + k_nope_out=k_out[..., 64:], + ) + + Args: + q_rope: Query tensor for RoPE [num_tokens, num_qo_heads, rope_dim] + k_rope: Key tensor for RoPE [num_tokens, rope_dim] (MLA 2D) or [num_tokens, num_kv_heads, rope_dim] (3D) + q_nope: Query tensor without RoPE [num_tokens, num_qo_heads, no_rope_dim] or None + k_nope: Key tensor without RoPE [num_tokens, no_rope_dim] (MLA 2D) or None + cos_sin_cache: Precomputed cos/sin cache [max_seq_len, rope_dim] + pos_ids: Position IDs [num_tokens] + is_neox: Whether to use NeoX-style RoPE (True) or interleave/GPT-J style (False) + quantize_dtype: Output dtype for quantization (default: float8_e4m3fn) + quant_scale_q: Quantization scale for Q tensors + quant_scale_kv: Quantization scale for K tensors + q_rope_out: Optional pre-allocated output for q_rope + k_rope_out: Optional pre-allocated output for k_rope + q_nope_out: Optional pre-allocated output for q_nope + k_nope_out: Optional pre-allocated output for k_nope + enable_pdl: Whether to enable PDL (Persistent Data Layout) + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + (q_rope_out, k_rope_out, q_nope_out, k_nope_out) + """ + raise NotImplementedError(f"flashinfer.rope.rope_quantize_fp8 is not implemented for {get_current_backend()}") + + +# ============================================================================ +# Per-token group 8-bit quantization +# ============================================================================ + + +@dispatch( + "flashinfer.quant.per_token_group_quant_8bit", +) +def per_token_group_quant_8bit( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dst_dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + scale_ue8m0: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Per-token group 8-bit quantization (FP8 or INT8). + + Args: + x: Input tensor [num_tokens, hidden_dim], contiguous. + group_size: Group size for quantization. + eps: Minimum value to avoid division by zero. + dst_dtype: Output dtype (torch.float8_e4m3fn or torch.int8). + column_major_scales: If True, scale tensor layout is column-major. + scale_tma_aligned: If True, scale buffer is aligned for TMA (requires column_major_scales=True). + scale_ue8m0: If True, round scale to power of 2 (UE8M0) (requires column_major_scales=True). + + Returns: + (x_q, x_s): Quantized tensor and scale tensor. + """ + raise NotImplementedError( + f"flashinfer.quant.per_token_group_quant_8bit is not implemented for {get_current_backend()}" + ) + + +# ============================================================================ +# Attention Operations +# ============================================================================ + + +@dispatch( + "flashinfer.attention.decode_attention_kv_paged", +) +def decode_attention_kv_paged( + q, + k_cache, + v_cache, + actual_seq_lens, + block_tables, + k_scale, + v_scale, + max_seq_len: int = -1, + outputs: Optional[torch.Tensor] = None, + force_split_kv: bool = False, + force_persistent: bool = False, +): + """ + FlashInfer operation: Decode attention with paged KV cache. + + Args: + q: Query tensor + k_cache: Paged key cache + v_cache: Paged value cache + actual_seq_lens: Actual sequence lengths + block_tables: Block tables for paged attention + k_scale: Key scaling factor + v_scale: Value scaling factor + max_seq_len: Maximum sequence length + outputs: Optional pre-allocated output tensor + force_split_kv: Whether to force use split KV mode + force_persistent: Whether to force use persistent mode + + Returns: + torch.Tensor: Attention output + """ + raise NotImplementedError( + f"flashinfer.attention.decode_attention_kv_paged is not implemented for {get_current_backend()}" + ) + + +@dispatch( + "flashinfer.attention.decode_mla_kv_paged", +) +def decode_mla_kv_paged( + q, + q_rope, + kv_cache, + k_rope, + actual_seq_lens, + block_tables, + k_scale, + v_scale, + max_seq_len: int = -1, + outputs: Optional[torch.Tensor] = None, + force_split_kv: bool = False, + force_persistent: bool = False, +): + """ + FlashInfer operation: Decode MLA (Multi-Latent Attention) with paged KV cache. + + Args: + q: Query tensor + q_rope: Query RoPE embeddings + kv_cache: Paged KV cache + k_rope: Key RoPE embeddings + actual_seq_lens: Actual sequence lengths + block_tables: Block tables for paged attention + k_scale: Key scaling factor + v_scale: Value scaling factor + max_seq_len: Maximum sequence length + outputs: Optional pre-allocated output tensor + force_split_kv: Whether to force use split KV mode + force_persistent: Whether to force use persistent mode + + Returns: + torch.Tensor: MLA attention output + """ + raise NotImplementedError( + f"flashinfer.attention.decode_mla_kv_paged is not implemented for {get_current_backend()}" + ) + + +@dispatch( + "flashinfer.attention.prefill_attention_kv_paged", +) +def prefill_attention_kv_paged( + q, + k_cache, + v_cache, + actual_seq_lens_q, + actual_seq_lens_kv, + actual_seq_offset, + block_tables, + k_scale, + v_scale, + num_batch, + max_seq_len, + is_causal: bool = True, + outputs: Optional[torch.Tensor] = None, + out_lse: Optional[torch.Tensor] = None, + use_lpt_scheduler: bool = True, +): + """ + FlashInfer operation: Prefill attention with paged KV cache. + + Args: + q: Query tensor + k_cache: Paged key cache + v_cache: Paged value cache + actual_seq_lens_q: Actual query sequence lengths + actual_seq_lens_kv: Actual KV sequence lengths + actual_seq_offset: Sequence offsets + block_tables: Block tables for paged attention + k_scale: Key scaling factor + v_scale: Value scaling factor + num_batch: Number of batches + max_seq_len: Maximum sequence length + is_causal: Whether to apply causal masking + outputs: Optional pre-allocated output tensor + out_lse: Optional output log-sum-exp tensor + use_lpt_scheduler: Whether to use LPT scheduler + + Returns: + Tuple[torch.Tensor, torch.Tensor]: (attention output, lse output) + """ + raise NotImplementedError( + f"flashinfer.attention.prefill_attention_kv_paged is not implemented for {get_current_backend()}" + ) + + +@dispatch( + "flashinfer.attention.prefill_attention_kv_ragged", +) +def prefill_attention_kv_ragged( + q, + k_cache, + v_cache, + actual_seq_lens_q, + actual_seq_lens_kv, + actual_seq_offset, + block_tables, + k_scale, + v_scale, + num_batch, + max_seq_len, + is_causal: bool = True, + outputs: Optional[torch.Tensor] = None, + out_lse: Optional[torch.Tensor] = None, + use_lpt_scheduler: bool = True, +): + """ + FlashInfer operation: Prefill attention with ragged KV cache. + + Args: + q: Query tensor + k_cache: Ragged key cache + v_cache: Ragged value cache + actual_seq_lens_q: Actual query sequence lengths + actual_seq_lens_kv: Actual KV sequence lengths + actual_seq_offset: Sequence offsets + block_tables: Block tables (unused for ragged) + k_scale: Key scaling factor + v_scale: Value scaling factor + num_batch: Number of batches + max_seq_len: Maximum sequence length + is_causal: Whether to apply causal masking + outputs: Optional pre-allocated output tensor + out_lse: Optional output log-sum-exp tensor + use_lpt_scheduler: Whether to use LPT scheduler + + Returns: + Tuple[torch.Tensor, torch.Tensor]: (attention output, lse output) + """ + raise NotImplementedError( + f"flashinfer.attention.prefill_attention_kv_ragged is not implemented for {get_current_backend()}" + ) + + +__all__ = [ + # GEMM (cuTile backends available, no dot_scaled) + "gemm_alpha_beta", + "masked_bmm", + "ragged_bmm", + "ragged_block_scaled_bmm", + # RoPE + "rope_quantize_fp8", + # Quantization + "per_token_group_quant_8bit", + # Attention + "decode_attention_kv_paged", + "decode_mla_kv_paged", + "prefill_attention_kv_paged", + "prefill_attention_kv_ragged", +] diff --git a/tests/suites/flashinfer/__init__.py b/tests/suites/flashinfer/__init__.py new file mode 100644 index 00000000..87ef525d --- /dev/null +++ b/tests/suites/flashinfer/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT diff --git a/tests/suites/flashinfer/test_flashinfer_attention.py b/tests/suites/flashinfer/test_flashinfer_attention.py new file mode 100644 index 00000000..7e130739 --- /dev/null +++ b/tests/suites/flashinfer/test_flashinfer_attention.py @@ -0,0 +1,606 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import math +import os +import warnings + +import pytest +import torch + +import tilegym +from tests import common +from tests.test_utils import bsr_attention_sample +from tests.test_utils import cudnn_decode +from tests.test_utils import cudnn_prefill +from tilegym.suites.flashinfer import ops as flashinfer_ops + +# Check if cuDNN is available for tests that use it as reference +CUDNN_AVAILABLE = cudnn_prefill.CUDNN_AVAILABLE + + +def get_prefill_problem_configs(quick_run=False, full_run=False): + if quick_run: + return [(num_batch, s_kv, head_dim_qk) for num_batch in [4] for s_kv in [1024] for head_dim_qk in [128, 192]] + if full_run: + return [ + (num_batch, s_kv, head_dim_qk) + for num_batch in [1, 16, 32, 64, 100] + for s_kv in [256, 1024, 2048, 4096, 8192] + for head_dim_qk in [128, 192] + ] + else: + return ( + [ # small problem sizes + (1, s_kv, 128) for s_kv in [256, 1024] + ] + + [ # normal problem sizes + (16, 1024, head_dim_qk) for head_dim_qk in [128] + ] + + [ # large problem sizes + (100, 4096, head_dim_qk) for head_dim_qk in [128] + ] + ) + + +class Test_FlashInfer_PrefillPaged(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize("dtype", ["float16"]) + @pytest.mark.parametrize("page_size", [128]) + @pytest.mark.parametrize("num_batch, s_kv, head_dim_qk", get_prefill_problem_configs(quick_run=True)) + @pytest.mark.parametrize("backend", _backends) + def test_op( + self, + page_size, + num_batch, + s_kv, + head_dim_qk, + backend, + dtype, + monkeypatch, + ): + monkeypatch.setenv("DISABLE_AUTOTUNE", "1") + # Convert string dtype to torch dtype + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float8_e4m3fn": torch.float8_e4m3fn, + } + dtype = dtype_map[dtype] + self.setUp() + if backend != "pytorch" and tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + device = torch.device("cuda") + ( + q, + k_cache, + v_cache, + scale, + actual_seq_lens, + block_tables, + actual_seq_offset, + q_indptr, + k_indptr, + v_indptr, + o_indptr, + lse_indptr, + ) = bsr_attention_sample.generate_sample_data( + batch_size=num_batch, + max_seq_len=s_kv, + head_dim_qk=head_dim_qk, + page_size=page_size, + is_decode=False, + device=device, + dtype=torch.float16, + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + total_num_pages = k_cache.shape[0] + num_qo_heads = q.shape[1] + num_kv_heads = k_cache.shape[1] + head_dim_qk = q.shape[-1] + head_dim_vo = v_cache.shape[-1] + lse = torch.zeros([q.shape[0], num_qo_heads], device=device, dtype=torch.float32) + k_scale = 3.444 + v_scale = 2.444 + + # Compute reference if cuDNN is available + out_ref, lse_ref = None, None + if CUDNN_AVAILABLE: + out_ref, lse_ref = cudnn_prefill.cudnn_batch_prefill_with_kv_cache( + q, + k_cache, + v_cache * v_scale, + scale * k_scale, + workspace_buffer, + max_token_per_sequence=s_kv, + max_sequence_kv=s_kv, + actual_seq_lens_q=actual_seq_lens, + actual_seq_lens_kv=actual_seq_lens, + block_tables=block_tables, + causal=True, + return_lse=True, + lse=lse, + is_cuda_graph_compatible=True, + batch_offsets_q=q_indptr, + batch_offsets_o=o_indptr, + batch_offsets_stats=lse_indptr, + ) + else: + warnings.warn("cuDNN not available, skipping reference computation") + + try: + out_impl, lse_impl = flashinfer_ops.prefill_attention_kv_paged( + q, + k_cache.transpose(1, 2), + v_cache.transpose(1, 2), + actual_seq_lens, + actual_seq_lens, + actual_seq_offset, + block_tables, + scale * k_scale, + v_scale, + num_batch, + s_kv, + ) + except Exception as e: + raise e + + if out_ref is not None: + torch.testing.assert_close(out_impl, out_ref, atol=1e-2, rtol=2e-1) + torch.testing.assert_close(lse_impl, lse_ref, atol=1e-2, rtol=2e-1) + + +def get_prefill_ragged_problem_configs(quick_run=False, full_run=False): + if quick_run: + return [(num_batch, s_kv, head_dim_qk) for num_batch in [4] for s_kv in [1024] for head_dim_qk in [128, 192]] + if full_run: + return [ + (num_batch, s_kv, head_dim_qk) + for num_batch in [1, 16, 32, 64, 100] + for s_kv in [256, 1024, 2048, 4096, 8192] + for head_dim_qk in [128, 192] + ] + else: + return ( + [ # small problem sizes + (1, s_kv, 192) for s_kv in [256, 1024] + ] + + [ # normal problem sizes + (16, 1024, head_dim_qk) for head_dim_qk in [192] + ] + + [ # normal problem sizes + (16, 8192, head_dim_qk) for head_dim_qk in [192] + ] + + [ # large problem sizes + (32, 8192, head_dim_qk) for head_dim_qk in [192] + ] + + [ # large problem sizes + (100, 4096, head_dim_qk) for head_dim_qk in [192] + ] + ) + + +class Test_FlashInfer_PrefillRagged(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize("dtype", ["float16"]) + @pytest.mark.parametrize("num_batch, s_kv, head_dim_qk", get_prefill_ragged_problem_configs(quick_run=True)) + @pytest.mark.parametrize("backend", _backends) + def test_op( + self, + num_batch, + s_kv, + head_dim_qk, + dtype, + backend, + monkeypatch, + ): + monkeypatch.setenv("DISABLE_AUTOTUNE", "1") + # Convert string dtype to torch dtype + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float8_e4m3fn": torch.float8_e4m3fn, + } + dtype = dtype_map[dtype] + self.setUp() + if backend != "pytorch" and tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + device = torch.device("cuda") + ( + q, + k_cache, + v_cache, + scale, + actual_seq_lens, + block_tables, + actual_seq_offset, + q_indptr, + k_indptr, + v_indptr, + o_indptr, + lse_indptr, + ) = bsr_attention_sample.generate_sample_data( + batch_size=num_batch, + max_seq_len=s_kv, + head_dim_qk=head_dim_qk, + is_decode=False, + device=device, + dtype=dtype, + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + lse = torch.zeros([q.shape[0], q.shape[1]], device=device, dtype=torch.float32) + + k_scale = 2.414 + v_scale = 1.414 + + # Compute reference if cuDNN is available + out_ref, lse_ref = None, None + if CUDNN_AVAILABLE: + out_ref, lse_ref = cudnn_prefill.cudnn_batch_prefill_with_kv_cache( + q, + k_cache, + v_cache * v_scale, + scale * k_scale, + workspace_buffer, + max_token_per_sequence=s_kv, + max_sequence_kv=s_kv, + actual_seq_lens_q=actual_seq_lens, + actual_seq_lens_kv=actual_seq_lens, + block_tables=None, + causal=True, + return_lse=True, + lse=lse, + is_cuda_graph_compatible=True, + batch_offsets_q=q_indptr, + batch_offsets_o=o_indptr, + batch_offsets_k=k_indptr, + batch_offsets_v=v_indptr, + batch_offsets_stats=lse_indptr, + ) + else: + warnings.warn("cuDNN not available, skipping reference computation") + + try: + out_impl, lse_impl = flashinfer_ops.prefill_attention_kv_ragged( + q, + k_cache, + v_cache, + actual_seq_lens, + actual_seq_lens, + actual_seq_offset, + block_tables, + scale * k_scale, + v_scale, + num_batch, + s_kv, + ) + except Exception as e: + raise e + + if out_ref is not None: + torch.testing.assert_close(out_impl, out_ref, atol=1e-2, rtol=2e-1) + torch.testing.assert_close(lse_impl, lse_ref, atol=1e-2, rtol=2e-1) + + +def get_decoding_problem_configs(quick_run=False, full_run=False): + if quick_run: + return [(num_batch, s_kv, page_size) for num_batch in [4] for s_kv in [1024] for page_size in [128]] + if full_run: + return [ + (num_batch, s_kv, page_size) + for num_batch in [1, 16, 32, 64, 200] + for s_kv in [256, 1024, 2048, 4096, 8192] + for page_size in [64, 128, 256] + ] + else: + return ( + [ # small problem sizes + (1, skv, ps) for skv in [256, 1024] for ps in [32, 64] + ] + + [ # normal problem sizes + (16, skv, ps) for skv in [1024, 2048] for ps in [32, 64] + ] + + [ # large problem sizes + (200, skv, ps) for skv in [8192] for ps in [32, 64] + ] + ) + + +class Test_FlashInfer_DecodePaged(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize("dtype", ["float16"]) + @pytest.mark.parametrize("num_batch, s_kv, page_size", get_decoding_problem_configs(quick_run=True)) + @pytest.mark.parametrize("backend", _backends) + def test_op( + self, + num_batch, + s_kv, + page_size, + backend, + dtype, + ): + # Convert string dtype to torch dtype + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float8_e4m3fn": torch.float8_e4m3fn, + } + dtype = dtype_map[dtype] + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + ( + q, + k_cache, + v_cache, + scale, + actual_seq_lens, + block_tables, + actual_seq_offset, + q_indptr, + k_indptr, + v_indptr, + o_indptr, + lse_indptr, + ) = bsr_attention_sample.generate_sample_data( + batch_size=num_batch, + max_seq_len=s_kv, + page_size=page_size, + is_decode=True, + device=device, + dtype=dtype, + ) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + + k_scale = 3.678 + v_scale = 2.678 + + # Compute reference if cuDNN is available + out_ref = None + if CUDNN_AVAILABLE: + out_ref = cudnn_decode.cudnn_batch_decode_with_kv_cache( + q, + k_cache, + v_cache * v_scale, + scale * k_scale, + workspace_buffer, + max_sequence_kv=s_kv, + actual_seq_lens_kv=actual_seq_lens, + block_tables=block_tables, + is_cuda_graph_compatible=True, + batch_offsets_q=q_indptr, + batch_offsets_o=o_indptr, + ) + else: + warnings.warn("cuDNN not available, skipping reference computation") + + max_seq_len = actual_seq_lens.cpu().max().item() + + out_ref0 = flashinfer_ops.decode_attention_kv_paged( + q, + k_cache.transpose(1, 2), + v_cache.transpose(1, 2), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=False, + force_persistent=False, + ) + if out_ref is not None: + torch.testing.assert_close(out_ref0, out_ref, atol=1e-2, rtol=2e-1) + + out_ref1 = flashinfer_ops.decode_attention_kv_paged( + q, + k_cache.transpose(1, 2), + v_cache.transpose(1, 2), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=False, + force_persistent=True, + ) + if out_ref is not None: + torch.testing.assert_close(out_ref1, out_ref, atol=1e-2, rtol=2e-1) + + out_ref2 = flashinfer_ops.decode_attention_kv_paged( + q, + k_cache.transpose(1, 2), + v_cache.transpose(1, 2), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=True, + force_persistent=False, + ) + if out_ref is not None: + torch.testing.assert_close(out_ref2, out_ref, atol=1e-2, rtol=2e-1) + + out_ref3 = torch.empty_like(out_ref0) + out_ref3 = flashinfer_ops.decode_attention_kv_paged( + q, + k_cache.transpose(1, 2), + v_cache.transpose(1, 2), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=True, + force_persistent=False, + outputs=out_ref3, + ) + torch.testing.assert_close(out_ref2, out_ref3) + + +class Test_FlashInfer_MLADecodePaged(common.PyTestCase): + _backends = ["cutile"] + + @pytest.mark.parametrize("dtype", ["float16"]) + @pytest.mark.parametrize("num_batch, s_kv, page_size", get_decoding_problem_configs(quick_run=True)) + @pytest.mark.parametrize("num_heads", [32]) + @pytest.mark.parametrize("backend", _backends) + def test_op( + self, + num_batch, + s_kv, + page_size, + num_heads, + backend, + dtype, + ): + if torch.cuda.get_device_capability()[0] == 12: + pytest.xfail("Skip due to random result mismatch in sm120") + + # Convert string dtype to torch dtype + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float8_e4m3fn": torch.float8_e4m3fn, + } + dtype = dtype_map[dtype] + self.setUp() + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + device = torch.device("cuda") + head_dim_rope = 64 + head_dim_qk = 512 + num_qo_heads = num_heads + ( + (q, q_rope), + kv_cache, + k_rope, + scale, + actual_seq_lens, + block_tables, + actual_seq_offset, + q_indptr, + _, + _, + _, + _, + ) = bsr_attention_sample.generate_sample_data( + batch_size=num_batch, + max_seq_len=s_kv, + page_size=page_size, + heads=num_qo_heads, + group_size=num_qo_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_qk, + head_dim_rope=head_dim_rope, + is_decode=True, + device=device, + dtype=dtype, + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + + k_scale = 1.678 + v_scale = 2.821 + + # Compute reference if cuDNN is available + out_ref = None + if CUDNN_AVAILABLE: + out_ref = cudnn_decode.cudnn_batch_decode_with_kv_cache( + q, + kv_cache, + kv_cache * v_scale, + scale * k_scale, + workspace_buffer, + max_sequence_kv=s_kv, + actual_seq_lens_kv=actual_seq_lens, + block_tables=block_tables, + is_cuda_graph_compatible=True, + batch_offsets_q=q_indptr, + batch_offsets_o=q_indptr, + ) + else: + warnings.warn("cuDNN not available, skipping reference computation") + + max_seq_len = actual_seq_lens.cpu().max().item() + out_ref0 = flashinfer_ops.decode_mla_kv_paged( + q, + q_rope.zero_(), + kv_cache.reshape(-1, page_size, head_dim_qk), + k_rope.reshape(-1, page_size, head_dim_rope), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=False, + force_persistent=False, + ) + if out_ref is not None: + torch.testing.assert_close(out_ref0, out_ref, atol=1e-2, rtol=2e-1) + + out_ref1 = flashinfer_ops.decode_mla_kv_paged( + q, + q_rope.zero_(), + kv_cache.reshape(-1, page_size, head_dim_qk), + k_rope.reshape(-1, page_size, head_dim_rope), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=False, + force_persistent=True, + ) + if out_ref is not None: + torch.testing.assert_close(out_ref1, out_ref, atol=1e-2, rtol=2e-1) + + out_ref2 = flashinfer_ops.decode_mla_kv_paged( + q, + q_rope.zero_(), + kv_cache.reshape(-1, page_size, head_dim_qk), + k_rope.reshape(-1, page_size, head_dim_rope), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=True, + force_persistent=False, + ) + if out_ref is not None: + torch.testing.assert_close(out_ref2, out_ref, atol=1e-2, rtol=2e-1) + + out_ref3 = torch.empty_like(out_ref0) + out_ref3 = flashinfer_ops.decode_mla_kv_paged( + q, + q_rope.zero_(), + kv_cache.reshape(-1, page_size, head_dim_qk), + k_rope.reshape(-1, page_size, head_dim_rope), + actual_seq_lens, + block_tables, + scale * k_scale, + v_scale, + max_seq_len=max_seq_len, + force_split_kv=True, + force_persistent=False, + outputs=out_ref3, + ) + torch.testing.assert_close(out_ref2, out_ref3) diff --git a/tests/suites/flashinfer/test_flashinfer_gemm_alpha_beta.py b/tests/suites/flashinfer/test_flashinfer_gemm_alpha_beta.py new file mode 100644 index 00000000..3d826a94 --- /dev/null +++ b/tests/suites/flashinfer/test_flashinfer_gemm_alpha_beta.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import random + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites import flashinfer + + +class Test_FlashInfer_Matmul_Alpha_Beta(common.PyTestCase): + @staticmethod + def reference(a, b, c, trans_a=False, trans_b=True, alpha=1.0, beta=0.0, dtype=torch.bfloat16): + if trans_a: + a = a.t() + if trans_b: + b = b.t() + return torch.addmm(c, a.to(dtype), b.to(dtype), beta=beta, alpha=alpha, out=c).to(dtype) + + @staticmethod + def prepare_data(m, n, k, trans_a, trans_b, dtype): + device = torch.device("cuda") + + a_size = m * k + b_size = k * n + a = torch.rand(a_size, device=device, dtype=torch.float16).to(dtype) + b = torch.rand(b_size, device=device, dtype=torch.float16).to(dtype) + + if trans_a: + a = a.view(k, m) + else: + a = a.view(m, k) + if trans_b: + b = b.view(n, k) + else: + b = b.view(k, n) + + return a, b + + @pytest.mark.parametrize("framework", ["cutile"]) # cutile only; other backends OOM on all configs on this GPU + @pytest.mark.parametrize("m, n, k", [(4096, 4096, 4096), (8192, 8192, 8192)]) + @pytest.mark.parametrize( + "dtype, out_dtype", + [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16), (torch.float8_e4m3fn, torch.bfloat16)], + ) + @pytest.mark.parametrize("trans_a, trans_b", [(False, True)]) + @pytest.mark.parametrize("alpha, beta", [(1.0, 0.0), (1.5, 2.0)]) + def test_op(self, framework, m, n, k, dtype, out_dtype, trans_a, trans_b, alpha, beta): + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + torch.manual_seed(0) + random.seed(0) + a, b = self.prepare_data(m, n, k, trans_a, trans_b, dtype) + c = torch.rand((m, n), device=a.device, dtype=out_dtype) + ref_c = c.clone() + + framework_fn = lambda: flashinfer.ops.gemm_alpha_beta(a, b, c, trans_a, trans_b, alpha, beta) + self.assertCorrectness( + framework_fn, + lambda: self.reference(a, b, ref_c, trans_a, trans_b, alpha, beta, out_dtype), + kwargs={}, + atol=1e-2, + rtol=1e-2, + ) diff --git a/tests/suites/flashinfer/test_flashinfer_masked_bmm.py b/tests/suites/flashinfer/test_flashinfer_masked_bmm.py new file mode 100644 index 00000000..47af4b4f --- /dev/null +++ b/tests/suites/flashinfer/test_flashinfer_masked_bmm.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import os +import random +import sys + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites import flashinfer + + +def enumerate_m_grouped_masked(): + max_m = 4096 + + cases = [ + (6, 512), + # DeepGEMM default cases + (1, 1024), + (2, 512), + (4, 256), + ] + # more GB200 cases + num_experts = 288 + num_experts_per_token = 8 + for num_ranks in [4, 8, 16, 32, 36, 48, 72]: + for num_tokens in [64, 128, 256, 384, 512, 768, 1024]: + num_groups = num_experts // num_ranks + expected_m_per_group = num_tokens * num_experts_per_token // num_groups + cases.append((num_groups, expected_m_per_group)) + + for num_groups, expected_m_per_group in cases: + for n, k in ( + (4096, 7168), + (7168, 2048), + ): + yield dict( + num_groups=num_groups, + max_m=max_m, + expected_m_per_group=expected_m_per_group, + n=n, + k=k, + ) + + +def create_masked_m(num_groups, expected_m_per_group, max_m): + masked_m = torch.empty((num_groups,), dtype=torch.int32, device="cuda") + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + assert masked_m.amax().item() <= max_m + return masked_m + + +class Test_FlashInfer_MaskedBMM(common.PyTestCase): + @staticmethod + def reference(a, b, m_mask, trans_a=False, trans_b=False): + if trans_a: + a = torch.transpose(a, 1, 2) + if trans_b: + b = torch.transpose(b, 1, 2) + return torch.bmm(a, b) + + @staticmethod + def prepare_data(num_groups, max_m, expected_m_per_group, n, k, trans_a, trans_b, dtype): + device = torch.device("cuda") + q = num_groups + m = max_m + + if trans_a: + a_shape = (q, k, m) + else: + a_shape = (q, m, k) + + if trans_b: + b_shape = (q, n, k) + else: + b_shape = (q, k, n) + + m_mask = create_masked_m( + num_groups=num_groups, + expected_m_per_group=expected_m_per_group, + max_m=max_m, + ) + + a = torch.rand(a_shape, device=device, dtype=dtype, requires_grad=False) + b = torch.rand(b_shape, device=device, dtype=dtype, requires_grad=False) + + a_impl = a.clone() + b_impl = b.clone() + # Set all the element beyond the m_mask as 0 + if trans_a: + for i in range(num_groups): + a[i, :, m_mask[i] :] = 0 + else: + for i in range(num_groups): + a[i, m_mask[i] :, :] = 0 + + return a, b, a_impl, b_impl, m_mask + + @pytest.mark.parametrize( + "num_groups, max_m, expected_m_per_group, n, k", + [ + ( + case["num_groups"], + case["max_m"], + case["expected_m_per_group"], + case["n"], + case["k"], + ) + for case in list(enumerate_m_grouped_masked())[:2] # Use smaller set for correctness + ], + ) + @pytest.mark.parametrize("dtype", [torch.float16]) + @pytest.mark.parametrize("trans_a", [False, True]) + @pytest.mark.parametrize("trans_b", [False, True]) + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + def test_op( + self, + num_groups, + max_m, + expected_m_per_group, + n, + k, + dtype, + trans_a, + trans_b, + framework: str, + ): + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + torch.manual_seed(0) + random.seed(0) + a_ref, b_ref, a_impl, b_impl, m_mask = self.prepare_data( + num_groups, max_m, expected_m_per_group, n, k, trans_a, trans_b, dtype + ) + + ref_c = self.reference(a_ref, b_ref, m_mask, trans_a, trans_b) + c = flashinfer.ops.masked_bmm(a_impl, b_impl, m_mask, trans_a, trans_b) + + for i in range(num_groups): + c[i, m_mask[i] :, :] = 0 + torch.testing.assert_close(ref_c, c, atol=1e-2, rtol=1e-2) diff --git a/tests/suites/flashinfer/test_flashinfer_ragged_block_scaled_bmm.py b/tests/suites/flashinfer/test_flashinfer_ragged_block_scaled_bmm.py new file mode 100644 index 00000000..6bb9c5fd --- /dev/null +++ b/tests/suites/flashinfer/test_flashinfer_ragged_block_scaled_bmm.py @@ -0,0 +1,271 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import random + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites import flashinfer + + +class Test_FlashInfer_RaggedBlockScaledBMM(common.PyTestCase): + @staticmethod + def create_ragged_m_segments(num_groups, m, ELEM_PER_BYTE_A, alignment=16): + """Create non-even M segments for ragged BMM. + + Args: + num_groups: Number of groups/batches + m: Average segment size + ELEM_PER_BYTE_A: Elements per byte for A matrix + alignment: Segment size alignment (default 16, use 128 for CuTile) + """ + # Create random segment sizes that sum to approximately total_m + total_m = num_groups * m + segment_sizes = [] + num_items = alignment * ELEM_PER_BYTE_A + + # Generate random segment sizes + for i in range(num_groups - 1): + # Random size between 0.5x and 1.5x expected size + size = int(m * random.uniform(0.5, 1.5)) + size = (size // num_items) * num_items + segment_sizes.append(size) + + # Last segment gets the remaining size + remaining = total_m - sum(segment_sizes) + assert remaining > 0 and remaining % num_items == 0 + segment_sizes.append(remaining) + + # Create segment offsets + segment_offsets = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda") + for i in range(num_groups): + segment_offsets[i + 1] = segment_offsets[i] + segment_sizes[i] + + max_m = max(segment_sizes) + + return max_m, segment_offsets + + @staticmethod + def create_aligned_m_segments(num_groups, m, block_m=128): + """Create M segments aligned to BLOCK_M for CuTile. + + CuTile's tile-based indexing requires segment offsets to be + multiples of BLOCK_M for correct operation. + + Args: + num_groups: Number of groups/batches + m: Segment size (should be multiple of block_m) + block_m: Block size for M dimension (default 128) + """ + # Ensure m is a multiple of block_m + aligned_m = ((m + block_m - 1) // block_m) * block_m + total_m = num_groups * aligned_m + + # Create even segment offsets (all segments same size) + segment_offsets = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda") + for i in range(num_groups): + segment_offsets[i + 1] = segment_offsets[i] + aligned_m + + max_m = aligned_m + + return max_m, segment_offsets, aligned_m + + @staticmethod + def reference( + a_fp8, + b_fp8, + a_scale, + b_scale, + segment_offsets, + block_n, + block_k, + trans_a=False, + trans_b=True, + out_dtype=torch.bfloat16, + ): + """ + PyTorch reference for ragged BMM with non-even M segments. + Matrix a is flattened with segment_offsets defining the boundaries. + a_scale and b_scale are block-level scales that need to be expanded. + """ + + a = a_fp8.float() + b = b_fp8.float() + # Get dimensions + total_m, K = a.shape + Q, N, K_b = b.shape + + assert K == K_b, f"K dimensions must match: {K} != {K_b}" + + # Initialize output tensor + c = torch.zeros((total_m, N), device=a.device, dtype=out_dtype) + + # Process each segment + for q in range(Q): + start_offset = segment_offsets[q].item() + end_offset = segment_offsets[q + 1].item() + segment_size = end_offset - start_offset + assert segment_size > 0 + + # Extract segment from flattened matrix a + a_segment = a[start_offset:end_offset, :] # Shape: [segment_size, K] + a_scale_segment = a_scale[start_offset:end_offset, :] # Shape: [segment_size, k_tiles] + + b_segment = b[q, :, :] # Shape: [N, K] + b_scale_segment = b_scale[q, :, :] # Shape: [n_tiles, k_tiles] + + # Expand block-level scales to match data dimensions + # a_scale: [segment_size, k_tiles] -> [segment_size, K] + a_scale_expanded = torch.repeat_interleave(a_scale_segment, block_k, dim=1)[:, :K] + + # b_scale: [n_tiles, k_tiles] -> [N, K] + b_scale_expanded = torch.repeat_interleave(b_scale_segment, block_n, dim=0)[:N, :] + b_scale_expanded = torch.repeat_interleave(b_scale_expanded, block_k, dim=1)[:, :K] + + # Compute matrix multiplication for this segment + # (a * a_scale) @ (b * b_scale).T + c_segment = torch.mm(a_segment * a_scale_expanded, (b_segment * b_scale_expanded).t()).to(out_dtype) + + # Store the result in the output tensor + c[start_offset:end_offset, :] = c_segment + + return c + + @staticmethod + def prepare_data( + num_groups, + M, + N, + K, + trans_a=False, + trans_b=True, + out_dtype=torch.bfloat16, + use_aligned_segments=False, + ): + Q = num_groups + assert trans_a == False and trans_b == True, "Only NT layout is supported" + device = torch.device("cuda") + factor_for_scale = 1e-2 + block_n = 128 + block_k = 128 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + if use_aligned_segments: + # CuTile requires segment offsets aligned to BLOCK_M (128) + max_m, segment_offsets, aligned_m = Test_FlashInfer_RaggedBlockScaledBMM.create_aligned_m_segments( + num_groups=num_groups, + m=M, + block_m=128, # BLOCK_M for CuTile + ) + total_m = segment_offsets[-1].item() + actual_m = aligned_m + else: + # Supports non-aligned segments + max_m, segment_offsets = Test_FlashInfer_RaggedBlockScaledBMM.create_ragged_m_segments( + num_groups=num_groups, + m=M, + ELEM_PER_BYTE_A=1, + alignment=16, + ) + total_m = segment_offsets[-1].item() + actual_m = M + assert total_m == num_groups * M + + A_fp32 = ( + (torch.rand(total_m, K, dtype=torch.float32, device=device).normal_(mean=0.0, std=0.3) - 0.5) * 2 * fp8_max + ) + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = ( + (torch.rand(Q, N, K, dtype=torch.float32, device=device).normal_(mean=0.0, std=0.3) - 0.5) * 2 * fp8_max + ) + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(total_m, k_tiles, dtype=torch.float32, device=device) * factor_for_scale + Bs = torch.rand(Q, n_tiles, k_tiles, dtype=torch.float32, device=device) * factor_for_scale + + ref_c = Test_FlashInfer_RaggedBlockScaledBMM.reference( + A_fp8, + B_fp8, + As, + Bs, + segment_offsets, + block_n, + block_k, + trans_a, + trans_b, + out_dtype, + ) + + return A_fp8, B_fp8, As, Bs, ref_c, segment_offsets, max_m + + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + @pytest.mark.parametrize("num_groups", [4, 16]) + @pytest.mark.parametrize("m", [128, 512]) + @pytest.mark.parametrize("n, k", [(2048, 2048), (4096, 4096)]) + @pytest.mark.parametrize("dtype, out_dtype", [(torch.float8_e4m3fn, torch.bfloat16)]) + @pytest.mark.parametrize("trans_a, trans_b", [(False, True)]) + def test_op(self, framework, num_groups, m, n, k, dtype, out_dtype, trans_a, trans_b, arch): + if torch.cuda.get_device_capability(0) == (10, 3): + pytest.xfail("Skip on sm103: due to CUDA error: unspecified launch failure") + + self.setUp() + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + if arch == "sm80" and "float8" in dtype.__repr__(): + pytest.skip("FP8 is not supported on sm80 (Ampere).") + + use_aligned_segments = framework == "cutile" + + ( + a, + b, + a_scale, + b_scale, + ref_c, + segment_offsets, + max_m, + ) = self.prepare_data( + num_groups, + m, + n, + k, + trans_a, + trans_b, + out_dtype, + use_aligned_segments=use_aligned_segments, + ) + + c = flashinfer.ops.ragged_block_scaled_bmm( + a, + b, + a_scale, + b_scale, + segment_offsets, + max_m, + max_m_device=None, + transpose_a=trans_a, + transpose_b=trans_b, + out_dtype=out_dtype, + ) + + torch.testing.assert_close(ref_c, c, atol=1.0, rtol=1.0) diff --git a/tests/suites/flashinfer/test_flashinfer_ragged_bmm.py b/tests/suites/flashinfer/test_flashinfer_ragged_bmm.py new file mode 100644 index 00000000..5930ba46 --- /dev/null +++ b/tests/suites/flashinfer/test_flashinfer_ragged_bmm.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import os +import random +import sys + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites import flashinfer + + +def create_ragged_m_segments(num_groups, m, dtype, align_to=None): + """Create non-even M segments for ragged BMM + + Args: + num_groups: Number of batches/groups + m: Average segment size + dtype: Data type + align_to: If specified, align segment sizes to this value + """ + total_m = num_groups * m + segment_sizes = [] + itemsize = dtype.itemsize + num_items = 16 // itemsize + + # Use align_to if specified, otherwise use default alignment + alignment = align_to if align_to is not None else num_items + + # Generate random segment sizes + for i in range(num_groups - 1): + size = int(m * random.uniform(0.5, 1.5)) + size = (size // alignment) * alignment + if size < alignment: + size = alignment + segment_sizes.append(size) + + remaining = total_m - sum(segment_sizes) + remaining = (remaining // alignment) * alignment + if remaining < alignment: + remaining = alignment + segment_sizes.append(remaining) + + actual_total_m = sum(segment_sizes) + + segment_offsets = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda") + for i in range(num_groups): + segment_offsets[i + 1] = segment_offsets[i] + segment_sizes[i] + + max_m = max(segment_sizes) + return max_m, segment_offsets, actual_total_m + + +class Test_FlashInfer_RaggedBMM(common.PyTestCase): + @staticmethod + def reference(a, b, segment_offsets, trans_a=False, trans_b=True, out_dtype=None): + """ + PyTorch reference for ragged BMM with non-even M segments. + Matrix a is flattened with segment_offsets defining the boundaries. + """ + if trans_a: + a = torch.transpose(a, 0, 1) + if trans_b: + b = torch.transpose(b, 1, 2) + + total_m, K = a.shape + Q, K_b, N = b.shape + + if out_dtype is None: + out_dtype = a.dtype + + c = torch.zeros((total_m, N), device=a.device, dtype=out_dtype) + + for q in range(Q): + start_offset = segment_offsets[q].item() + end_offset = segment_offsets[q + 1].item() + segment_size = end_offset - start_offset + assert segment_size > 0 + a_segment = a[start_offset:end_offset, :] + b_segment = b[q, :, :] + c_segment = torch.mm(a_segment.to(out_dtype), b_segment.to(out_dtype)) + c[start_offset:end_offset, :] = c_segment + + return c + + @staticmethod + def prepare_data(num_groups, m, n, k, trans_a, trans_b, dtype, framework="cutile"): + device = torch.device("cuda") + + # For CuTile, we need segments aligned to BLOCK_M (128) + # This ensures segment offsets are multiples of the tile size + align_to = 128 if framework == "cutile" else None + + max_m, segment_offsets, actual_total_m = create_ragged_m_segments( + num_groups=num_groups, + m=m, + dtype=dtype, + align_to=align_to, + ) + + total_m = segment_offsets[-1].item() + + if trans_a: + a_shape = (k, total_m) + else: + a_shape = (total_m, k) + + if trans_b: + b_shape = (num_groups, n, k) + else: + b_shape = (num_groups, k, n) + + a = torch.rand(a_shape, device=device, dtype=torch.float16, requires_grad=False).to(dtype) + b = torch.rand(b_shape, device=device, dtype=torch.float16, requires_grad=False).to(dtype) + + return a, b, max_m, segment_offsets + + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + @pytest.mark.parametrize("trans_a", [False]) + @pytest.mark.parametrize("trans_b", [False, True]) + @pytest.mark.parametrize("dtype", [(torch.bfloat16)]) + @pytest.mark.parametrize("num_groups, m, n, k", [(4, 256, 256, 256), (2, 128, 128, 128), (4, 512, 512, 512)]) + def test_op_shapes(self, framework, trans_a, trans_b, dtype, num_groups, m, n, k): + # cutile kernel only supports (trans_a=False, trans_b=True) + if framework == "cutile" and (trans_a or not trans_b): + pytest.skip("ragged_bmm only supports trans_a=False, trans_b=True") + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + torch.manual_seed(0) + random.seed(0) + out_dtype = dtype + a, b, max_m, segment_offsets = self.prepare_data(num_groups, m, n, k, trans_a, trans_b, dtype, framework) + + framework_fn = lambda: flashinfer.ops.ragged_bmm( + a, + b, + segment_offsets, + max_m, + None, + transpose_a=trans_a, + transpose_b=trans_b, + out_dtype=out_dtype, + ) + self.assertCorrectness( + framework_fn, + lambda: self.reference(a, b, segment_offsets, trans_a, trans_b, out_dtype), + kwargs={}, + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + @pytest.mark.parametrize("dtype", [(torch.bfloat16)]) + @pytest.mark.parametrize("m, n, k", [(256, 256, 256)]) + @pytest.mark.parametrize("num_groups", [1, 4, 8]) + def test_op_num_groups(self, framework, dtype, m, n, k, num_groups): + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + torch.manual_seed(0) + random.seed(0) + trans_a = False + trans_b = True + out_dtype = dtype + a, b, max_m, segment_offsets = self.prepare_data(num_groups, m, n, k, trans_a, trans_b, dtype, framework) + + framework_fn = lambda: flashinfer.ops.ragged_bmm( + a, + b, + segment_offsets, + max_m, + None, + transpose_a=trans_a, + transpose_b=trans_b, + out_dtype=out_dtype, + ) + self.assertCorrectness( + framework_fn, + lambda: self.reference(a, b, segment_offsets, trans_a, trans_b, out_dtype), + kwargs={}, + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + @pytest.mark.parametrize("num_groups, m, n, k", [(4, 256, 256, 256)]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float8_e4m3fn]) + def test_op_dtypes(self, framework, num_groups, m, n, k, dtype, arch): + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + if arch == "sm80" and "float8" in dtype.__repr__(): + pytest.skip("FP8 is not supported on sm80 (Ampere).") + + torch.manual_seed(0) + random.seed(0) + trans_a = False + trans_b = True + out_dtype = torch.bfloat16 + a, b, max_m, segment_offsets = self.prepare_data(num_groups, m, n, k, trans_a, trans_b, dtype, framework) + + framework_fn = lambda: flashinfer.ops.ragged_bmm( + a, + b, + segment_offsets, + max_m, + None, + transpose_a=trans_a, + transpose_b=trans_b, + out_dtype=out_dtype, + ) + self.assertCorrectness( + framework_fn, + lambda: self.reference(a, b, segment_offsets, trans_a, trans_b, out_dtype), + kwargs={}, + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + @pytest.mark.parametrize("dtype", [(torch.bfloat16)]) + @pytest.mark.parametrize("num_groups, m, n, k", [(4, 256, 256, 256)]) + @pytest.mark.parametrize("trans_a", [False, True]) + @pytest.mark.parametrize("trans_b", [False, True]) + def test_op_transpose(self, framework, dtype, num_groups, m, n, k, trans_a, trans_b): + # cutile kernel only supports (trans_a=False, trans_b=True) + if framework == "cutile" and (trans_a or not trans_b): + pytest.skip("ragged_bmm only supports trans_a=False, trans_b=True") + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + + torch.manual_seed(0) + random.seed(0) + out_dtype = dtype + a, b, max_m, segment_offsets = self.prepare_data(num_groups, m, n, k, trans_a, trans_b, dtype, framework) + + framework_fn = lambda: flashinfer.ops.ragged_bmm( + a, + b, + segment_offsets, + max_m, + None, + transpose_a=trans_a, + transpose_b=trans_b, + out_dtype=out_dtype, + ) + self.assertCorrectness( + framework_fn, + lambda: self.reference(a, b, segment_offsets, trans_a, trans_b, out_dtype), + kwargs={}, + atol=1e-2, + rtol=1e-2, + ) diff --git a/tests/suites/flashinfer/test_per_token_group_quant_8bit.py b/tests/suites/flashinfer/test_per_token_group_quant_8bit.py new file mode 100644 index 00000000..3cca2103 --- /dev/null +++ b/tests/suites/flashinfer/test_per_token_group_quant_8bit.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +from typing import Tuple + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.flashinfer import ops as tilegym_flashinfer_ops + + +def _sota_impl_sgl_kernel_only( + x: torch.Tensor, + group_size: int, + eps: float, + dst_dtype: torch.dtype, + column_major_scales: bool, + scale_ue8m0: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Thin wrapper around sgl_kernel (no sglang dependency). Only supports FP8, row-major scales.""" + from sgl_kernel import sgl_per_token_group_quant_8bit + + assert dst_dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "sgl_kernel v1 only supports FP8" + assert not column_major_scales, "sgl_kernel v1 only supports row-major scales" + finfo = torch.finfo(dst_dtype) + fp8_min = finfo.min + fp8_max = finfo.max + x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0, enable_v2=False) + return x_q, x_s + + +def native_per_token_group_quant_8bit( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dst_dtype: torch.dtype = torch.float8_e4m3fn, + scale_ue8m0: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Reference implementation of per-token group 8-bit quantization using PyTorch.""" + assert x.shape[-1] % group_size == 0 + assert x.is_contiguous() + + if dst_dtype == torch.int8: + bit8_min = float(torch.iinfo(dst_dtype).min) + bit8_max = float(torch.iinfo(dst_dtype).max) + else: + bit8_min = torch.finfo(dst_dtype).min + bit8_max = torch.finfo(dst_dtype).max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / bit8_max + if scale_ue8m0: + min_val = torch.tensor(1e-10, dtype=x_s.dtype, device=x_s.device) + x_s = torch.exp2(torch.ceil(torch.log2(torch.maximum(x_s.abs(), min_val)))) + x_q = (x_ / x_s).clamp(min=bit8_min, max=bit8_max).to(dst_dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) + return x_q, x_s + + +class Test_Per_Token_Group_Quant_8bit(common.PyTestCase): + @pytest.mark.parametrize("num_tokens", [512, 513]) + @pytest.mark.parametrize("hidden_dim", [2048]) + @pytest.mark.parametrize("group_size", [16, 32, 64, 128]) + @pytest.mark.parametrize("dst_dtype", [torch.float8_e4m3fn, torch.int8]) + @pytest.mark.parametrize("column_major_scales", [False, True]) + @pytest.mark.parametrize("scale_tma_aligned", [False, True]) + @pytest.mark.parametrize("scale_ue8m0", [False, True]) + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + def test_op( + self, + num_tokens, + hidden_dim, + group_size, + dst_dtype, + column_major_scales, + scale_tma_aligned, + scale_ue8m0, + framework, + arch, + ): + if framework == "cutile": + if tilegym.is_backend_available("cutile"): + tilegym.set_backend("cutile") + else: + pytest.skip("CuTile backend is not available") + else: + pytest.skip(f"Framework {framework} not supported") + + # scale_ue8m0 is a Blackwell-only feature (arch > 9); skip when testing it on older archs. + arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device()) + if scale_ue8m0 and arch_major <= 9: + pytest.skip("scale_ue8m0 only relevant on Blackwell (arch > 9)") + if scale_ue8m0 and not column_major_scales: + pytest.skip("scale_ue8m0 requires column_major_scales=True") + if scale_tma_aligned and not column_major_scales: + pytest.skip("scale_tma_aligned requires column_major_scales=True") + # Ampere (sm80) has no native FP8 support; kernel/backend may fail or be unsupported. + if arch == "sm80" and "float8" in dst_dtype.__repr__(): + pytest.skip("FP8 is not supported on sm80 (Ampere).") + + device = "cuda:0" + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + x = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device=device) + + # Reference (torch) + x_q_ref, x_s_ref = native_per_token_group_quant_8bit( + x, + group_size=group_size, + eps=1e-10, + dst_dtype=dst_dtype, + scale_ue8m0=scale_ue8m0, + ) + + # Implementation under test + x_q, x_s = tilegym_flashinfer_ops.per_token_group_quant_8bit( + x, + group_size=group_size, + eps=1e-10, + dst_dtype=dst_dtype, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + scale_ue8m0=scale_ue8m0, + ) + + # Compare quantized values (allow tiny diff for fp8/int8 rounding) + torch.testing.assert_close(x_q_ref.float(), x_q.float(), atol=1e-2, rtol=2e-1) + torch.testing.assert_close( + x_s_ref.contiguous(), + x_s.contiguous(), + rtol=1e-3, + atol=1e-5, + ) diff --git a/tests/suites/flashinfer/test_rope_quantize_fp8.py b/tests/suites/flashinfer/test_rope_quantize_fp8.py new file mode 100644 index 00000000..85e6bbcc --- /dev/null +++ b/tests/suites/flashinfer/test_rope_quantize_fp8.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +from typing import Optional +from typing import Tuple +from typing import Union + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.suites.flashinfer import ops as tilegym_flashinfer_ops + + +# reference implementation of RotaryEmbedding +class RotaryEmbedding(torch.nn.Module): + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + device: str = "cuda:0", + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + self.device = device + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) / self.rotary_dim) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float, device=self.device) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def _apply_rotary_emb( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + + # Note: the is different from the vLLM's implementation, + # We added float32 conversion because float32 is required for the rotary embedding to work correctly for long contexts + query = query.to(torch.float32) + key = key.to(torch.float32) + + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = self._apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = self._apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + + query = query.to(self.dtype) + key = key.to(self.dtype) + return query, key + + +class Test_FlashInfer_RopeQuantizeFp8(common.PyTestCase): + @pytest.mark.parametrize("num_tokens", [1, 19, 128]) + @pytest.mark.parametrize("num_qo_heads", [128]) + @pytest.mark.parametrize("head_size", [576]) + @pytest.mark.parametrize("rotary_dim", [64]) + @pytest.mark.parametrize("max_position_embeddings", [4096]) + @pytest.mark.parametrize("base", [10000]) + @pytest.mark.parametrize("is_neox_style", [False]) + @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + @pytest.mark.parametrize( + "framework", + [ + "cutile", + ], + ) + def test_op( + self, + num_tokens, + num_qo_heads, + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + input_dtype, + quant_dtype, + framework, + arch, + ): + _impl_fw = ["cutile"] + if framework not in _impl_fw: + pytest.skip(f"Framework {framework} not supported") + if tilegym.is_backend_available(framework): + tilegym.set_backend(framework) + else: + pytest.skip(f"Backend {framework} is not available") + if framework == "cutile": + pytest.xfail("CuTile rope_kernel: TileAS lowering to NVVM fails with invalid use-def chain ") + + if arch == "sm80" and "float8" in quant_dtype.__repr__(): + pytest.skip("FP8 is not supported on sm80 (Ampere).") + + device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device) + k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device) + pos_ids = torch.arange(num_tokens, device=device) + + # reference implementation + rope_flashinfer = RotaryEmbedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + input_dtype, + device, + ) + q_out_f16_ref, k_out_f16_ref = rope_flashinfer.forward_native(pos_ids, q_in, k_in) + q_out_f8_ref, k_out_f8_ref = map( + lambda x: x.to(quant_dtype), + (q_out_f16_ref, k_out_f16_ref), + ) + + # kernel implementation + q_out = torch.empty_like(q_in, dtype=quant_dtype) + k_out = torch.empty_like(k_in, dtype=quant_dtype) + tilegym_flashinfer_ops.rope_quantize_fp8( + q_in[..., :rotary_dim], + k_in[..., :rotary_dim], + q_in[..., rotary_dim:], + k_in[..., rotary_dim:], + rope_flashinfer.cos_sin_cache, + pos_ids, + is_neox=is_neox_style, + q_rope_out=q_out[..., :rotary_dim], + k_rope_out=k_out[..., :rotary_dim], + q_nope_out=q_out[..., rotary_dim:], + k_nope_out=k_out[..., rotary_dim:], + quant_scale_q=1.0, + quant_scale_kv=1.0, + ) + torch.testing.assert_close(q_out_f8_ref.float(), q_out.float(), atol=1e-2, rtol=2e-1) + torch.testing.assert_close(k_out_f8_ref.float(), k_out.float(), atol=1e-2, rtol=2e-1) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 00000000..87ef525d --- /dev/null +++ b/tests/test_utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT diff --git a/tests/test_utils/bsr_attention_sample.py b/tests/test_utils/bsr_attention_sample.py new file mode 100644 index 00000000..c6e59285 --- /dev/null +++ b/tests/test_utils/bsr_attention_sample.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +# fmt: off +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/benchmarks + + +import torch + + +def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len): + """ + Get an array of actual sequence lengths for given batch size and max sequence length. + If random_actual_seq_len is True, sample actual sequence lengths randomly. + Otherwise, set all actual sequence lengths to max_seqlen. + + Args: + max_seqlen: Maximum sequence length. + batch_size: Batch size. + device: Device to sample on. + random_actual_seq_len: Whether to sample actual sequence lengths randomly. + + Returns: + actual_seq_lens: Actual sequence lengths for each batch. + """ + if random_actual_seq_len: + actual_seq_lens = torch.randint( + 1, max_seqlen + 1, (batch_size, 1, 1, 1), device=device, dtype=torch.int32 + ) + else: + actual_seq_lens = torch.full( + (batch_size, 1, 1, 1), max_seqlen, device=device, dtype=torch.int32 + ) + return actual_seq_lens + + +def get_paged_kv(batch_size, s_kv, num_kv_heads, page_size, dim_head_qk, dim_head_vo, head_dim_rope, kv_init_dtype, device): + + # Create KV cache + num_pages_per_seq = (s_kv + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + + # Now initialize the page tables + block_tables = torch.tensor( + [ + [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + for i in range(batch_size) + ], + dtype=torch.int32, + device=device, + ) + # Initialize KV cache with appropriate shape and stride + k_cache_shape = ( + total_num_pages, + num_kv_heads, + page_size, + dim_head_qk, + ) + k_cache = torch.randn(size=k_cache_shape, dtype=kv_init_dtype).to(device) + if head_dim_rope > 0: + k_rope_shape = ( + total_num_pages, + num_kv_heads, + page_size, + head_dim_rope, + ) + k_rope = torch.randn(size=k_rope_shape, dtype=kv_init_dtype).to(device) + return k_cache, k_rope, block_tables + else: + v_cache_shape = ( + total_num_pages, + num_kv_heads, + page_size, + dim_head_vo, + ) + v_cache = torch.randn(size=v_cache_shape, dtype=kv_init_dtype).to(device) + + return k_cache, v_cache, block_tables + +def generate_sample_data( + batch_size, + max_seq_len=1024, + page_size=None, + dtype=torch.bfloat16, + group_size=16, + heads=128, + head_dim_qk=128, + head_dim_vo=128, + head_dim_rope=0, + device="cuda", + is_decode=False, +): + q_init_dtype = torch.float16 + kv_init_dtype = torch.float16 + + s_kv = max_seq_len + num_qo_heads = heads + num_kv_heads = heads // group_size + + # Sample sequence lengths and create tensors + actual_seq_lens = sample_actual_seq_lens( + s_kv, batch_size, device, random_actual_seq_len=True + ) + cumsum_s_kv = torch.sum(actual_seq_lens) + q = torch.randn( + batch_size if is_decode else cumsum_s_kv, + num_qo_heads, + head_dim_qk, + device=device, + dtype=q_init_dtype, + ) + + actual_seq_offset = torch.arange(0, batch_size + 1, device=device) + actual_seq_offset[1:] = torch.cumsum(actual_seq_lens.view(-1), dim=0) + + if page_size is None: + k_cache = torch.randn( + cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype + ) + v_cache = torch.randn( + cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype + ) + block_tables = None + else: + k_cache, v_cache, block_tables = get_paged_kv( + batch_size, s_kv, num_kv_heads, page_size, head_dim_qk, head_dim_vo, head_dim_rope, kv_init_dtype, device) + + q_indptr = actual_seq_offset * (head_dim_qk * num_qo_heads) # For cuDNN + k_indptr = actual_seq_offset * (head_dim_qk * num_kv_heads) # For cuDNN + v_indptr = actual_seq_offset * (head_dim_vo * num_kv_heads) # For cuDNN + o_indptr = actual_seq_offset * (head_dim_vo * num_qo_heads) # For cuDNN + lse_indptr = actual_seq_offset * num_qo_heads # For cuDNN + q_indptr = q_indptr.long() + k_indptr = k_indptr.long() + v_indptr = v_indptr.long() + o_indptr = o_indptr.long() + lse_indptr = lse_indptr.long() + + head_dim = head_dim_qk + head_dim_rope + scale = float(1.0 / (head_dim**0.5)) + if head_dim_rope > 0 and page_size is not None: + q_rope = torch.randn( + batch_size if is_decode else cumsum_s_kv, num_qo_heads, head_dim_rope, device=device, dtype=kv_init_dtype + ) + return ( + (q.to(dtype), q_rope.to(dtype)), + k_cache.to(dtype), + v_cache.to(dtype), + scale, + actual_seq_lens, + block_tables, + actual_seq_offset, + q_indptr, + k_indptr, + v_indptr, + o_indptr, + lse_indptr, + ) + + else: + return ( + q.to(dtype), + k_cache.to(dtype), + v_cache.to(dtype), + scale, + actual_seq_lens, + block_tables, + actual_seq_offset, + q_indptr, + k_indptr, + v_indptr, + o_indptr, + lse_indptr, + ) diff --git a/tests/test_utils/cudnn_decode.py b/tests/test_utils/cudnn_decode.py new file mode 100644 index 00000000..78cfcc38 --- /dev/null +++ b/tests/test_utils/cudnn_decode.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +# fmt: off +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/cudnn/decode.py + +from enum import Enum +from typing import Optional + +import torch + +try: + import cudnn + + CUDNN_AVAILABLE = True +except Exception: + cudnn = None + CUDNN_AVAILABLE = False + +# Global cudnn handle. need to make it per device in future +_cudnn_handle = None + + +def _create_cudnn_handle(stream: torch.cuda.Stream): + global _cudnn_handle + if _cudnn_handle is None: + _cudnn_handle = cudnn.create_handle() + cudnn.set_stream(_cudnn_handle, stream.cuda_stream) + return _cudnn_handle + + +# Tensor ids +class UIDs(Enum): + RESERVED_INVALID_UID = 0 + + Q_UID = 1 # Query tensor + K_UID = 2 # Key cache tensor + V_UID = 3 # Value cache tensor + + ACTUAL_SEQ_LENS_Q_UID = 100 # Actual sequence lengths for query tensor + ACTUAL_SEQ_LENS_KV_UID = 101 # Actual sequence lengths for key/value tensor + + BLOCK_TABLES_UID = 200 # Block tables tensor + BLOCK_TABLES_K_UID = 201 # Block tables tensor for key + BLOCK_TABLES_V_UID = 202 # Block tables tensor for value + + RAGGED_Q_UID = 50 # Ragged query tensor + RAGGED_O_UID = 51 # Ragged output tensor + RAGGED_STATS_UID = 52 # Ragged stats tensor + + O_UID = 1000 # Output tensor + STATS_UID = 1001 # Stats tensor + + +def _sdpa_decode_key_fn( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + *, + max_sequence_kv: int, + block_size: Optional[int] = 1, + actual_seq_lens_q: Optional[torch.Tensor] = None, + actual_seq_lens_kv: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, +): + return ( + "decode", + max_sequence_kv, + tuple(q.shape), + tuple(k_cache.shape), + ) + + +if CUDNN_AVAILABLE: + + @cudnn.jit(heur_modes=[cudnn.heur_mode.A]) + @cudnn.graph_cache(key_fn=_sdpa_decode_key_fn) + def _build_decode_graph( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + *, + max_sequence_kv: int, + block_size: Optional[int] = 1, + actual_seq_lens_q: Optional[torch.Tensor] = None, + actual_seq_lens_kv: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + ): + handle = _create_cudnn_handle(torch.cuda.current_stream()) + + # WAR: override batch offsets for now, as it leads to a poor performance + batch_offsets_q = None + batch_offsets_o = None + dtype_map = { + torch.float16: cudnn.data_type.HALF, + torch.bfloat16: cudnn.data_type.BFLOAT16, + } + + with cudnn.graph(handle) as (g, _): + if q.dim() == 3: + s_qo = 1 + b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2] + elif q.dim() == 4: + b, h_qo, s_qo, d_qk = ( + q.shape[0], + q.shape[1], + q.shape[2], + q.shape[3], + ) + else: + raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}") + + assert s_qo == 1, "q must have a sequence length of 1" + assert k_cache.dim() == 4, "k_cache must have 4 dimensions" + + d_vo = v_cache.shape[3] + + cudnn_q = g.tensor( + name="q", + dim=(b, h_qo, s_qo, d_qk), + stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1), + data_type=dtype_map[q.dtype], + ) + if batch_offsets_q is not None: + ragged_q = g.tensor_like(batch_offsets_q) + ragged_q.set_uid(UIDs.RAGGED_Q_UID.value) + cudnn_q.set_ragged_offset(ragged_q) + + cudnn_k_cache = g.tensor_like(k_cache) + cudnn_v_cache = g.tensor_like(v_cache) + + cudnn_q.set_uid(UIDs.Q_UID.value) + cudnn_k_cache.set_uid(UIDs.K_UID.value) + cudnn_v_cache.set_uid(UIDs.V_UID.value) + + if block_tables is not None: + nd_block_tables = block_tables.reshape( + block_tables.shape[0], 1, block_tables.shape[1], 1 + ) + cudnn_k_block_tables = g.tensor_like(nd_block_tables) + cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value) + + cudnn_v_block_tables = g.tensor_like(nd_block_tables) + cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value) + + if actual_seq_lens_q is not None: + cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q) + cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value) + + if actual_seq_lens_kv is not None: + cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv) + cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value) + cudnn_actual_seq_lens_kv.set_is_pass_by_value(False) + + padding_mask = actual_seq_lens_kv is not None + kwargs = dict( + name="sdpa", + q=cudnn_q, + k=cudnn_k_cache, + v=cudnn_v_cache, + seq_len_q=( + cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None + ), + seq_len_kv=( + cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None + ), + use_padding_mask=padding_mask, + generate_stats=False, + attn_scale=scale, + paged_attention_k_table=cudnn_k_block_tables, + paged_attention_v_table=cudnn_v_block_tables, + paged_attention_max_seq_len_kv=max_sequence_kv, + compute_data_type=cudnn.data_type.FLOAT, + ) + if cudnn.__version__ < '1.13.0': + del kwargs['generate_stats'] + kwargs['is_inference'] = True + + O, _ = g.sdpa(**kwargs) + + if batch_offsets_o is not None: + ragged_o = g.tensor_like(batch_offsets_o) + ragged_o.set_uid(UIDs.RAGGED_O_UID.value) + O.set_ragged_offset(ragged_o) + + O.set_uid(UIDs.O_UID.value).set_output(True).set_dim( + [b, h_qo, s_qo, d_vo] + ).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type( + dtype_map[q.dtype] + ) + + tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O] + + if actual_seq_lens_q is not None: + tensors_to_return.append(cudnn_actual_seq_lens_q) + if actual_seq_lens_kv is not None: + tensors_to_return.append(cudnn_actual_seq_lens_kv) + + return g, tensors_to_return + +else: + # cuDNN not available - define stub function + def _build_decode_graph(*args, **kwargs): + raise RuntimeError("cuDNN is not available. Please install cudnn package.") + + +def _batch_decode_with_kv_cache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + workspace_buffer: torch.Tensor, + *, + max_sequence_kv: int, + actual_seq_lens_q: Optional[torch.Tensor] = None, + actual_seq_lens_kv: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + block_size: Optional[int] = 1, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + batch_offsets_k: Optional[torch.Tensor] = None, + batch_offsets_v: Optional[torch.Tensor] = None, + out: torch.Tensor, +) -> torch.Tensor: + graph, tensors = _build_decode_graph( + q=q, + k_cache=k_cache, + v_cache=v_cache, + scale=scale, + max_sequence_kv=max_sequence_kv, + actual_seq_lens_q=actual_seq_lens_q, + actual_seq_lens_kv=actual_seq_lens_kv, + block_tables=block_tables, + block_size=block_size, + batch_offsets_q=batch_offsets_q if batch_offsets_q is not None else None, + batch_offsets_o=batch_offsets_q if batch_offsets_q is not None else None, + ) + + handle_ = _create_cudnn_handle(torch.cuda.current_stream()) + + var_map = { + UIDs.Q_UID.value: q, + UIDs.K_UID.value: k_cache, + UIDs.V_UID.value: v_cache, + UIDs.O_UID.value: out, + } + if actual_seq_lens_q is not None: + var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q + if actual_seq_lens_kv is not None: + var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv + + if batch_offsets_q is not None: + var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q + if batch_offsets_o is not None: + var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o + + if block_tables is not None: + var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables + var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables + + graph.execute(var_map, workspace=workspace_buffer, handle=handle_) + + return out + + +def cudnn_batch_decode_with_kv_cache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + workspace_buffer: torch.Tensor, + *, + max_sequence_kv: int, + actual_seq_lens_kv: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + is_cuda_graph_compatible: bool = False, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + batch_offsets_k: Optional[torch.Tensor] = None, + batch_offsets_v: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Performs batched decode attention with paged KV cache using cuDNN. + + Args: + q: Query tensor of shape (batch_size, num_heads_qo, head_dim), seq_len_q is the maximum sequence length of queries in the batch + k_cache: Key cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim) + v_cache: Value cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim) + scale: Scaling factor for attention scores, typically 1/sqrt(head_dim) + workspace_buffer: Workspace buffer for cuDNN operations. Scales with batch size. 128 MB should be sufficient for most cases + max_sequence_kv: Maximum number of tokens per key/value sequence (s_kv_max) + actual_seq_lens_kv: Actual sequence lengths for key/values per batch, shape (batch_size,) on CPU + block_tables: Page table mapping for KV cache, shape (batch_size, num_pages_per_seq) on GPU + is_cuda_graph_compatible: Whether the decode operation is compatible with CUDA graph + batch_offsets: Optional batch offsets tensor of shape (batch_size,) on GPU + out: Optional pre-allocated output tensor + batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU + batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU + batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU + batch_offsets_v: Optional batch offsets for value tensor of shape (batch_size,) on GPU + + Returns: + Output tensor of shape (batch_size, num_heads_qo, head_dim) + + Note: + Currently only supports causal attention (causal must be True) + All tensors must be contiguous and on the same CUDA device + Query and KV heads can have different sizes (num_heads_qo >= num_heads_kv) + """ + + bs = q.shape[0] + h_qo = q.shape[1] + d_vo = v_cache.shape[3] + + if out is None: + out = torch.empty(bs, h_qo, d_vo, device=q.device, dtype=q.dtype) + actual_seq_lens_q = torch.ones( + (bs, 1, 1, 1), device=q.device, dtype=torch.int32 + ) + block_size = k_cache.shape[2] + + _batch_decode_with_kv_cache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + scale=scale, + workspace_buffer=workspace_buffer, + max_sequence_kv=max_sequence_kv, + actual_seq_lens_q=actual_seq_lens_q, + actual_seq_lens_kv=actual_seq_lens_kv, + block_tables=block_tables, + batch_offsets_q=batch_offsets_q, + batch_offsets_o=batch_offsets_o, + block_size=block_size, + out=out, + ) + + return out diff --git a/tests/test_utils/cudnn_prefill.py b/tests/test_utils/cudnn_prefill.py new file mode 100644 index 00000000..dff27369 --- /dev/null +++ b/tests/test_utils/cudnn_prefill.py @@ -0,0 +1,510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +# fmt: off +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/cudnn/prefill.py + +from enum import Enum +from typing import Optional + +import torch + +try: + import cudnn + + CUDNN_AVAILABLE = True +except Exception: + cudnn = None + CUDNN_AVAILABLE = False + +# Global cudnn handle. need to make it per device in future +_cudnn_handle = None + + +def _create_cudnn_handle(stream: torch.cuda.Stream): + global _cudnn_handle + if _cudnn_handle is None: + _cudnn_handle = cudnn.create_handle() + cudnn.set_stream(_cudnn_handle, stream.cuda_stream) + return _cudnn_handle + + +# Tensor ids +class UIDs(Enum): + RESERVED_INVALID_UID = 0 + + Q_UID = 1 # Query tensor + K_UID = 2 # Key cache tensor + V_UID = 3 # Value cache tensor + + ACTUAL_SEQ_LENS_Q_UID = 100 # Actual sequence lengths for query tensor + ACTUAL_SEQ_LENS_KV_UID = 101 # Actual sequence lengths for key/value tensor + + BLOCK_TABLES_UID = 200 # Block tables tensor + BLOCK_TABLES_K_UID = 201 # Block tables tensor for key + BLOCK_TABLES_V_UID = 202 # Block tables tensor for value + + RAGGED_Q_UID = 50 # Ragged query tensor + RAGGED_O_UID = 51 # Ragged output tensor + RAGGED_STATS_UID = 52 # Ragged stats tensor + RAGGED_K_UID = 53 # Ragged key tensor + RAGGED_V_UID = 54 # Ragged value tensor + + O_UID = 1000 # Output tensor + STATS_UID = 1001 # Stats tensor + + +def _sdpa_prefill_key_fn( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + *, + max_token_seq_q: Optional[int] = None, + max_sequence_kv: Optional[int] = None, + actual_seq_lens_q: Optional[torch.Tensor] = None, + actual_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + bottom_right_causal_mask: Optional[bool] = None, + return_lse: Optional[bool] = False, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + batch_offsets_k: Optional[torch.Tensor] = None, + batch_offsets_v: Optional[torch.Tensor] = None, + batch_offsets_stats: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, +): + graph_b = actual_seq_lens_q.shape[0] + if q.dim() == 3: + h_qo, d_qk = q.shape[1], q.shape[2] + elif q.dim() == 4: + h_qo, d_qk = q.shape[1], q.shape[3] + + if v_cache.dim() == 3: + h_kv, d_vo = k_cache.shape[1], k_cache.shape[2] + elif k_cache.dim() == 4: + h_kv, d_vo = k_cache.shape[1], k_cache.shape[3] + + if block_tables is not None: + page_size = k_cache.shape[2] + + key = ( + graph_b, + q.dim(), + k_cache.dim(), + max_token_seq_q, + max_sequence_kv, + h_qo, + d_qk, + h_kv, + d_vo, + block_tables is not None, + return_lse, + bottom_right_causal_mask, + page_size, + ) + return key + + +if CUDNN_AVAILABLE: + + @cudnn.jit(heur_modes=[cudnn.heur_mode.A]) + @cudnn.graph_cache(key_fn=_sdpa_prefill_key_fn) + def _build_prefill_graph( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + *, + max_token_seq_q: Optional[int] = None, + max_sequence_kv: Optional[int] = None, + actual_seq_lens_q: Optional[torch.Tensor] = None, + actual_seq_lens_kv: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + bottom_right_causal_mask: Optional[bool] = True, + return_lse: Optional[bool] = False, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + batch_offsets_k: Optional[torch.Tensor] = None, + batch_offsets_v: Optional[torch.Tensor] = None, + batch_offsets_stats: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + ): + handle = _create_cudnn_handle(torch.cuda.current_stream(q.device)) + + graph_b = actual_seq_lens_q.shape[0] + graph_s_qo = max_token_seq_q + graph_s_kv = max_sequence_kv + + dtype_map = { + torch.float16: cudnn.data_type.HALF, + torch.bfloat16: cudnn.data_type.BFLOAT16, + } + + with cudnn.graph(handle) as (g, _): + # Create tensors from the input tensors + if q.dim() == 3: + h_qo, d_qk = q.shape[1], q.shape[2] + elif q.dim() == 4: + h_qo, d_qk = q.shape[2], q.shape[3] + else: + raise ValueError(f"Invalid query tensor shape: {q.shape}") + + cudnn_q = g.tensor( + name="q", + dim=(graph_b, h_qo, graph_s_qo, d_qk), + stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1), + data_type=dtype_map[q.dtype], + ) + + if batch_offsets_q is not None: + ragged_q = g.tensor_like(batch_offsets_q) + ragged_q.set_uid(UIDs.RAGGED_Q_UID.value) + cudnn_q.set_ragged_offset(ragged_q) + + if v_cache.dim() == 3: + assert block_tables is None, ( + "block_tables needs 4 dimensions of kv cache" + ) + h_kv, d_vo = v_cache.shape[1], v_cache.shape[2] + elif v_cache.dim() == 4: + h_kv, d_vo = ( + v_cache.shape[1], + v_cache.shape[3], + ) + else: + raise ValueError(f"Invalid kv cache tensor shape: {k_cache.shape}") + + if k_cache.dim() == 3: + cudnn_k_cache = g.tensor( + name="k_cache", + dim=(graph_b, h_kv, graph_s_kv, d_qk), + stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1), + data_type=dtype_map[k_cache.dtype], + ) + + if batch_offsets_k is not None: + ragged_k = g.tensor_like(batch_offsets_k) + ragged_k.set_uid(UIDs.RAGGED_K_UID.value) + cudnn_k_cache.set_ragged_offset(ragged_k) + + cudnn_v_cache = g.tensor( + name="v_cache", + dim=(graph_b, h_kv, graph_s_kv, d_vo), + stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1), + data_type=dtype_map[v_cache.dtype], + ) + + if batch_offsets_v is not None: + ragged_v = g.tensor_like(batch_offsets_v) + ragged_v.set_uid(UIDs.RAGGED_V_UID.value) + cudnn_v_cache.set_ragged_offset(ragged_v) + + elif k_cache.dim() == 4: + cudnn_k_cache = g.tensor( + name="k_cache", + dim=k_cache.shape, + stride=k_cache.stride(), + data_type=dtype_map[k_cache.dtype], + ) + + cudnn_v_cache = g.tensor( + name="v_cache", + dim=v_cache.shape, + stride=v_cache.stride(), + data_type=dtype_map[v_cache.dtype], + ) + + cudnn_q.set_uid(UIDs.Q_UID.value) + cudnn_k_cache.set_uid(UIDs.K_UID.value) + cudnn_v_cache.set_uid(UIDs.V_UID.value) + + if block_tables is not None: + nd_block_tables = block_tables.reshape( + block_tables.shape[0], 1, block_tables.shape[1], 1 + ) + cudnn_k_block_tables = g.tensor_like(nd_block_tables) + cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value) + + cudnn_v_block_tables = g.tensor_like(nd_block_tables) + cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value) + + if actual_seq_lens_q is not None: + cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q) + cudnn_actual_seq_lens_q.set_name("actual_seq_lens_q") + cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value) + + if actual_seq_lens_kv is not None: + cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv) + cudnn_actual_seq_lens_kv.set_name("actual_seq_lens_kv") + cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value) + + padding_mask = ( + actual_seq_lens_q is not None and actual_seq_lens_kv is not None + ) + kwargs = dict( + name="sdpa", + q=cudnn_q, + k=cudnn_k_cache, + v=cudnn_v_cache, + seq_len_q=( + cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None + ), + seq_len_kv=( + cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None + ), + use_padding_mask=padding_mask, + attn_scale=scale, + generate_stats=return_lse, + use_causal_mask_bottom_right=bottom_right_causal_mask, + paged_attention_k_table=( + cudnn_k_block_tables if block_tables is not None else None + ), + paged_attention_v_table=( + cudnn_v_block_tables if block_tables is not None else None + ), + paged_attention_max_seq_len_kv=( + graph_s_kv if block_tables is not None else None + ), + compute_data_type=cudnn.data_type.FLOAT, + ) + if cudnn.__version__ < '1.13.0': + kwargs['is_inference'] = not kwargs['generate_stats'] + del kwargs['generate_stats'] + + O, Stats = g.sdpa(**kwargs) + + if batch_offsets_o is not None: + ragged_o = g.tensor_like(batch_offsets_o) + ragged_o.set_uid(UIDs.RAGGED_O_UID.value) + O.set_ragged_offset(ragged_o) + + if batch_offsets_stats is not None: + ragged_stats = g.tensor_like(batch_offsets_stats) + ragged_stats.set_uid(UIDs.RAGGED_STATS_UID.value) + Stats.set_ragged_offset(ragged_stats) + + O.set_uid(UIDs.O_UID.value).set_output(True).set_dim( + [graph_b, h_qo, graph_s_qo, d_vo] + ).set_stride( + [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1] + ).set_data_type(dtype_map[q.dtype]) + + if return_lse: + Stats.set_uid(UIDs.STATS_UID.value).set_output( + return_lse + ).set_data_type(cudnn.data_type.FLOAT).set_dim( + [graph_b, h_qo, graph_s_qo, 1] + ).set_stride([graph_s_qo * h_qo, 1, h_qo, 1]) + + tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O] + if return_lse: + tensors_to_return.append(Stats) + + if actual_seq_lens_q is not None: + tensors_to_return.append(cudnn_actual_seq_lens_q) + if actual_seq_lens_kv is not None: + tensors_to_return.append(cudnn_actual_seq_lens_kv) + + return g, tensors_to_return + +else: + # cuDNN not available - define stub function + def _build_prefill_graph(*args, **kwargs): + raise RuntimeError("cuDNN is not available. Please install cudnn package.") + + +def _batch_prefill_with_kv_cache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + workspace_buffer: torch.Tensor, + *, + max_token_per_sequence: int, + max_sequence_kv: int, + actual_seq_lens_q: torch.Tensor, + actual_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + causal: bool, + return_lse: bool, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + batch_offsets_k: Optional[torch.Tensor] = None, + batch_offsets_v: Optional[torch.Tensor] = None, + batch_offsets_stats: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + graph, tensors = _build_prefill_graph( + q=q, + k_cache=k_cache, + v_cache=v_cache, + scale=scale, + max_token_seq_q=max_token_per_sequence, + max_sequence_kv=max_sequence_kv, + actual_seq_lens_q=actual_seq_lens_q, + actual_seq_lens_kv=actual_seq_lens_kv, + block_tables=block_tables, + bottom_right_causal_mask=causal, + return_lse=return_lse, + batch_offsets_q=batch_offsets_q, + batch_offsets_o=batch_offsets_o, + batch_offsets_k=batch_offsets_k, + batch_offsets_v=batch_offsets_v, + batch_offsets_stats=batch_offsets_stats, + out=out, + lse=lse, + ) + + var_map = { + UIDs.Q_UID.value: q, + UIDs.K_UID.value: k_cache, + UIDs.V_UID.value: v_cache, + UIDs.O_UID.value: out, + } + + if actual_seq_lens_q is not None: + var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q + if actual_seq_lens_kv is not None: + var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv + + if batch_offsets_q is not None: + var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q + if batch_offsets_o is not None: + var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o + + if batch_offsets_k is not None: + var_map[UIDs.RAGGED_K_UID.value] = batch_offsets_k + if batch_offsets_v is not None: + var_map[UIDs.RAGGED_V_UID.value] = batch_offsets_v + + if block_tables is not None: + var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables + var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables + + if return_lse: + var_map[UIDs.STATS_UID.value] = lse + if batch_offsets_stats is not None: + var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats + + handle = _create_cudnn_handle(torch.cuda.current_stream(q.device)) + graph.execute(var_map, workspace=workspace_buffer, handle=handle) + + if return_lse: + return out, lse + else: + return out, None + + +def cudnn_batch_prefill_with_kv_cache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + workspace_buffer: torch.Tensor, + *, + max_token_per_sequence: int, + max_sequence_kv: int, + actual_seq_lens_q: torch.Tensor, + actual_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + causal: bool, + return_lse: bool, + batch_offsets_q: Optional[torch.Tensor] = None, + batch_offsets_o: Optional[torch.Tensor] = None, + batch_offsets_k: Optional[torch.Tensor] = None, + batch_offsets_v: Optional[torch.Tensor] = None, + batch_offsets_stats: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + is_cuda_graph_compatible: bool = False, + backend: Optional[str] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Performs batched prefill attention with paged KV cache using cuDNN. + + Args: + q: Query tensor of shape (Total number of tokens, num_heads_qo, head_dim) + k_cache: Key cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim) if paged kv cache is enabled else (Total sequence length of kv, num_heads_kv, d_qk) + v_cache: Value cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim) if paged kv cache is enabled else (Total sequence length of kv, num_heads_kv, d_vo) + scale: Scaling factor for attention scores, typically 1/sqrt(head_dim) + workspace_buffer: Workspace buffer for cuDNN operations. Scales with batch size. 128 MB should be sufficient for most cases + max_token_per_sequence: Maximum number of tokens per query sequence (s_qo_max) + max_sequence_kv: Maximum number of tokens per key/value sequence (s_kv_max) + actual_seq_lens_q: Actual number of tokens per query sequence shape (batch_size,) on cpu or device (cpu if cuda_graph is False) + actual_seq_lens_kv: Actual sequence lengths for key/values per batch, shape (batch_size,) on CPU or device (cpu if cuda_graph is False) + block_tables: Page table mapping for KV cache, shape (batch_size, num_pages_per_seq) on GPU + causal: Whether to apply causal masking + return_lse: Whether to return log-sum-exp values (must be True) + out: Optional pre-allocated output tensor + lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None + is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph + batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU + batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU + batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU + batch_offsets_v: Optional batch offsets for value tensor of shape (batch_size,) on GPU + + Returns: + Output tensor of shape (batch_size * seq_len_q, num_heads_qo, head_dim) + If return_lse is True, also returns log-sum-exp tensor of shape (batch_size, seq_len_q, num_heads_qo) + + Note: + Query and KV heads can have different sizes (num_heads_qo >= num_heads_kv) + When using cuda graph, actual_seq_lens_q and actual_seq_lens_kv must be on the same device as q + Head dimension of query and key must be 128 or 192 + Head dimension of value and output must be 128 + """ + + num_tokens = q.shape[0] + + num_sequences = actual_seq_lens_q.shape[0] + + if q.dim() == 3: + h_qo, d_qk = q.shape[1], q.shape[2] + elif q.dim() == 4: + h_qo, d_qk = q.shape[1], q.shape[3] + + if v_cache.dim() == 3: + d_vo = v_cache.shape[2] + elif v_cache.dim() == 4: + d_vo = v_cache.shape[3] + + if return_lse: + if lse is None: + lse = torch.empty( + num_sequences, + max_token_per_sequence, + h_qo, + device=q.device, + dtype=torch.float32, + ) + + if out is None: + out_shape = (num_tokens, h_qo, d_vo) + out = torch.empty(out_shape, device=q.device, dtype=q.dtype) + + return _batch_prefill_with_kv_cache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + scale=scale, + workspace_buffer=workspace_buffer, + max_token_per_sequence=max_token_per_sequence, + max_sequence_kv=max_sequence_kv, + actual_seq_lens_q=actual_seq_lens_q, + actual_seq_lens_kv=actual_seq_lens_kv, + block_tables=block_tables, + causal=causal, + return_lse=return_lse, + batch_offsets_q=batch_offsets_q, + batch_offsets_o=batch_offsets_o, + batch_offsets_k=batch_offsets_k, + batch_offsets_v=batch_offsets_v, + batch_offsets_stats=batch_offsets_stats, + out=out, + lse=lse, + ) From dbe9b79560594550ccb5c73ea2eb0f461f032280 Mon Sep 17 00:00:00 2001 From: Jinman Xie Date: Wed, 22 Apr 2026 04:33:45 -0700 Subject: [PATCH 3/6] Use cutile new autotuner for remaining kernels --- README.md | 14 +--- README_chs.md | 14 +--- README_cht.md | 14 +--- README_fr.md | 14 +--- README_ja.md | 14 +--- requirements.txt | 4 +- src/tilegym/ops/cutile/bmm.py | 1 - src/tilegym/ops/cutile/group_gemm.py | 1 - src/tilegym/ops/cutile/matmul.py | 1 - src/tilegym/ops/cutile/mla.py | 118 ++++++++++++++++++--------- tests/ops/test_bmm.py | 1 - 11 files changed, 91 insertions(+), 105 deletions(-) diff --git a/README.md b/README.md index bee84006..c69e4f10 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ We have verified that `torch==2.9.1` works. You can also get `triton` packages w #### 2. Install TileGym -TileGym uses [`cuda-tile`](https://github.com/nvidia/cutile-python) for GPU kernel programming, which depends on the `tileiras` compiler at runtime. +TileGym uses [`cuda-tile`](https://github.com/nvidia/cutile-python) (≥ 1.3.0) for GPU kernel programming, which depends on the `tileiras` compiler at runtime. ##### Install from PyPI (recommended) @@ -77,17 +77,7 @@ pip install .[tileiras] # or: pip install . (if you have system tileiras) For editable (development) mode, use `pip install -e .` or `pip install -e .[tileiras]`. -##### Install `cuda-tile-experimental` - -> ⚠️ **Required**: TileGym kernels use features from [`cuda-tile-experimental`](https://github.com/NVIDIA/cutile-python/tree/main/experimental) (e.g., the autotuner). This package is *not* available on PyPI and must be installed separately from source: -> -> ```bash -> pip install "cuda-tile-experimental @ git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental" -> ``` -> -> `cuda-tile-experimental` is maintained by the CUDA Tile team as a source-only experimental package. See more details in [experimental-features-optional](https://github.com/NVIDIA/cutile-python?tab=readme-ov-file#experimental-features-optional). - -All runtime dependencies (except `cuda-tile-experimental`) are declared in [`requirements.txt`](requirements.txt) and are installed automatically by both `pip install tilegym` and `pip install .`. +All runtime dependencies are declared in [`requirements.txt`](requirements.txt) and are installed automatically by both `pip install tilegym` and `pip install .`. We also provide Dockerfile, you can refer to [modeling/transformers/README.md](modeling/transformers/README.md). diff --git a/README_chs.md b/README_chs.md index 2d7481a0..0697ad4b 100644 --- a/README_chs.md +++ b/README_chs.md @@ -51,7 +51,7 @@ pip install --pre torch --index-url https://download.pytorch.org/whl/cu130 #### 2. 安装 TileGym -TileGym 使用 [`cuda-tile`](https://github.com/nvidia/cutile-python) 进行 GPU 内核编程,运行时依赖 `tileiras` 编译器。 +TileGym 使用 [`cuda-tile`](https://github.com/nvidia/cutile-python)(≥ 1.3.0)进行 GPU 内核编程,运行时依赖 `tileiras` 编译器。 ##### 从 PyPI 安装(推荐) @@ -77,17 +77,7 @@ pip install .[tileiras] # 或者: pip install . (如果您已有系统级 til 如需可编辑(开发)模式,请使用 `pip install -e .` 或 `pip install -e .[tileiras]`。 -##### 安装 `cuda-tile-experimental` - -> ⚠️ **必需**:TileGym 内核使用了 [`cuda-tile-experimental`](https://github.com/NVIDIA/cutile-python/tree/main/experimental) 中的功能(如自动调优器)。此包*不*在 PyPI 上提供,必须从源码单独安装: -> -> ```bash -> pip install "cuda-tile-experimental @ git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental" -> ``` -> -> `cuda-tile-experimental` 由 CUDA Tile 团队维护,仅提供源码安装。更多详情请参阅 [experimental-features-optional](https://github.com/NVIDIA/cutile-python?tab=readme-ov-file#experimental-features-optional)。 - -所有运行时依赖(`cuda-tile-experimental` 除外)均声明在 [`requirements.txt`](requirements.txt) 中,通过 `pip install tilegym` 和 `pip install .` 都会自动安装。 +所有运行时依赖均声明在 [`requirements.txt`](requirements.txt) 中,通过 `pip install tilegym` 和 `pip install .` 都会自动安装。 我们还提供了 Dockerfile,您可以参考 [modeling/transformers/README.md](modeling/transformers/README.md)。 diff --git a/README_cht.md b/README_cht.md index c900b996..6f670b7e 100644 --- a/README_cht.md +++ b/README_cht.md @@ -51,7 +51,7 @@ pip install --pre torch --index-url https://download.pytorch.org/whl/cu130 #### 2. 安裝 TileGym -TileGym 使用 [`cuda-tile`](https://github.com/nvidia/cutile-python) 進行 GPU 核心程式設計,執行時期依賴 `tileiras` 編譯器。 +TileGym 使用 [`cuda-tile`](https://github.com/nvidia/cutile-python)(≥ 1.3.0)進行 GPU 核心程式設計,執行時期依賴 `tileiras` 編譯器。 ##### 從 PyPI 安裝(建議) @@ -77,17 +77,7 @@ pip install .[tileiras] # 或者: pip install . (如果您已有系統級 til 如需可編輯(開發)模式,請使用 `pip install -e .` 或 `pip install -e .[tileiras]`。 -##### 安裝 `cuda-tile-experimental` - -> ⚠️ **必需**:TileGym 核心使用了 [`cuda-tile-experimental`](https://github.com/NVIDIA/cutile-python/tree/main/experimental) 中的功能(如自動調優器)。此套件*不*在 PyPI 上提供,必須從原始碼單獨安裝: -> -> ```bash -> pip install "cuda-tile-experimental @ git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental" -> ``` -> -> `cuda-tile-experimental` 由 CUDA Tile 團隊維護,僅提供原始碼安裝。更多詳情請參閱 [experimental-features-optional](https://github.com/NVIDIA/cutile-python?tab=readme-ov-file#experimental-features-optional)。 - -所有執行時期依賴(`cuda-tile-experimental` 除外)均宣告於 [`requirements.txt`](requirements.txt) 中,透過 `pip install tilegym` 和 `pip install .` 都會自動安裝。 +所有執行時期依賴均宣告於 [`requirements.txt`](requirements.txt) 中,透過 `pip install tilegym` 和 `pip install .` 都會自動安裝。 我們還提供了 Dockerfile,您可以參考 [modeling/transformers/README.md](modeling/transformers/README.md)。 diff --git a/README_fr.md b/README_fr.md index f6ab527e..53bb01c7 100644 --- a/README_fr.md +++ b/README_fr.md @@ -51,7 +51,7 @@ Nous avons vérifié que `torch==2.9.1` fonctionne. Vous pouvez également obten #### 2. Installer TileGym -TileGym utilise [`cuda-tile`](https://github.com/nvidia/cutile-python) pour la programmation de noyaux GPU, qui dépend du compilateur `tileiras` à l'exécution. +TileGym utilise [`cuda-tile`](https://github.com/nvidia/cutile-python) (≥ 1.3.0) pour la programmation de noyaux GPU, qui dépend du compilateur `tileiras` à l'exécution. ##### Installer depuis PyPI (recommandé) @@ -77,17 +77,7 @@ pip install .[tileiras] # ou : pip install . (si vous avez tileiras sur votre Pour le mode éditable (développement), utilisez `pip install -e .` ou `pip install -e .[tileiras]`. -##### Installer `cuda-tile-experimental` - -> ⚠️ **Requis** : Les noyaux TileGym utilisent des fonctionnalités de [`cuda-tile-experimental`](https://github.com/NVIDIA/cutile-python/tree/main/experimental) (par ex. l'auto-tuner). Ce paquet n'est *pas* disponible sur PyPI et doit être installé séparément depuis les sources : -> -> ```bash -> pip install "cuda-tile-experimental @ git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental" -> ``` -> -> `cuda-tile-experimental` est maintenu par l'équipe CUDA Tile comme un paquet expérimental disponible uniquement depuis les sources. Voir plus de détails dans [experimental-features-optional](https://github.com/NVIDIA/cutile-python?tab=readme-ov-file#experimental-features-optional). - -Toutes les dépendances d'exécution (sauf `cuda-tile-experimental`) sont déclarées dans [`requirements.txt`](requirements.txt) et sont installées automatiquement par `pip install tilegym` et `pip install .`. +Toutes les dépendances d'exécution sont déclarées dans [`requirements.txt`](requirements.txt) et sont installées automatiquement par `pip install tilegym` et `pip install .`. Nous fournissons également un Dockerfile, vous pouvez consulter [modeling/transformers/README.md](modeling/transformers/README.md). diff --git a/README_ja.md b/README_ja.md index 1b29a22e..f69b8093 100644 --- a/README_ja.md +++ b/README_ja.md @@ -51,7 +51,7 @@ pip install --pre torch --index-url https://download.pytorch.org/whl/cu130 #### 2. TileGym のインストール -TileGym は GPU カーネルプログラミングに [`cuda-tile`](https://github.com/nvidia/cutile-python) を使用しており、実行時に `tileiras` コンパイラに依存しています。 +TileGym は GPU カーネルプログラミングに [`cuda-tile`](https://github.com/nvidia/cutile-python)(≥ 1.3.0)を使用しており、実行時に `tileiras` コンパイラに依存しています。 ##### PyPI からインストール(推奨) @@ -77,17 +77,7 @@ pip install .[tileiras] # または: pip install . (システムに tileiras 編集可能(開発)モードの場合は、`pip install -e .` または `pip install -e .[tileiras]` を使用してください。 -##### `cuda-tile-experimental` のインストール - -> ⚠️ **必須**:TileGym カーネルは [`cuda-tile-experimental`](https://github.com/NVIDIA/cutile-python/tree/main/experimental) の機能(例:オートチューナー)を使用しています。このパッケージは PyPI では提供されて*おらず*、ソースから個別にインストールする必要があります: -> -> ```bash -> pip install "cuda-tile-experimental @ git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental" -> ``` -> -> `cuda-tile-experimental` は CUDA Tile チームによってソースのみの実験的パッケージとして管理されています。詳細は [experimental-features-optional](https://github.com/NVIDIA/cutile-python?tab=readme-ov-file#experimental-features-optional) をご覧ください。 - -すべてのランタイム依存関係(`cuda-tile-experimental` を除く)は [`requirements.txt`](requirements.txt) に宣言されており、`pip install tilegym` と `pip install .` の両方で自動的にインストールされます。 +すべてのランタイム依存関係は [`requirements.txt`](requirements.txt) に宣言されており、`pip install tilegym` と `pip install .` の両方で自動的にインストールされます。 Dockerfile も提供しています。[modeling/transformers/README.md](modeling/transformers/README.md) を参照してください。 diff --git a/requirements.txt b/requirements.txt index 9d3a4526..decd3388 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,9 +10,7 @@ huggingface_hub matplotlib pandas numpy -cuda-tile # Or use: pip install cuda-tile[tileiras] for bundled tileiras compiler -# cuda-tile-experimental is NOT on PyPI and must be installed separately from source: -# pip install "cuda-tile-experimental @ git+https://github.com/NVIDIA/cutile-python.git#subdirectory=experimental" +cuda-tile>=1.3.0 # Or use: pip install cuda-tile[tileiras] for bundled tileiras compiler filelock>=3.20.3 # CVE fix: GHSA-w853-jp5j-5j7f, GHSA-qmgc-5h2g-mvrw pillow>=12.1.1 # CVE fix: GHSA-cfh3-3jmp-rvhc # nvidia-ml-py # optional diff --git a/src/tilegym/ops/cutile/bmm.py b/src/tilegym/ops/cutile/bmm.py index 03bbd805..c5875677 100644 --- a/src/tilegym/ops/cutile/bmm.py +++ b/src/tilegym/ops/cutile/bmm.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: MIT -import os from math import ceil from types import SimpleNamespace diff --git a/src/tilegym/ops/cutile/group_gemm.py b/src/tilegym/ops/cutile/group_gemm.py index ca8b6edd..83fed720 100644 --- a/src/tilegym/ops/cutile/group_gemm.py +++ b/src/tilegym/ops/cutile/group_gemm.py @@ -196,7 +196,6 @@ def group_gemm( C = torch.empty((M, N), device=device, dtype=dtype) group_C.append(C) - # Autotune mode stream = torch.cuda.current_stream() cutile_autotune_group_gemm(stream, group_A, group_B, group_C, transpose_b, device) return group_C diff --git a/src/tilegym/ops/cutile/matmul.py b/src/tilegym/ops/cutile/matmul.py index cb0c0a79..ec69530e 100644 --- a/src/tilegym/ops/cutile/matmul.py +++ b/src/tilegym/ops/cutile/matmul.py @@ -375,7 +375,6 @@ def matmul( # Create output tensor c = torch.empty((M, N), device=a.device, dtype=a.dtype) - # Grid calculation stream = torch.cuda.current_stream() if static_persistent: cutile_autotune_static_persistent_matmul(stream, a, b, c, M, N, K, trans_a, trans_b) diff --git a/src/tilegym/ops/cutile/mla.py b/src/tilegym/ops/cutile/mla.py index 2c6dceb2..78d445be 100644 --- a/src/tilegym/ops/cutile/mla.py +++ b/src/tilegym/ops/cutile/mla.py @@ -9,9 +9,13 @@ import cuda.tile as ct import torch from cuda.tile._numeric_semantics import RoundingMode as RMd +from cuda.tile.tune import exhaustive_search from tilegym.backend import register_impl +# Module-level tune cache: (S_qo, TILE_D, TILE_KPE, H, query_group_size, dtype, device) -> (best_cfg, tuned_kernel) +_mla_tune_cache: dict = {} + def _mla_sm80_autotune_configs(): """Pre-SM90 autotune search space for MLA prefill — num_ctas=1 only.""" @@ -166,22 +170,18 @@ def forward(ctx, q, qpe, k, kpe, v, sm_scale, IS_CAUSAL, kernel_configs): else: assert H % num_head_kv == 0 query_group_size = int(H / num_head_kv) - # Launch fmha fwd kernel. - # Autotune runs when ENABLE_CUTILE_TUNE=1 AND caller did not supply explicit - # kernel_configs. Explicit kernel_configs always bypasses autotune so callers - # can pin a fixed config for controlled A/B comparisons. + # Launch fmha fwd kernel using autotune. _gpu_cap = torch.cuda.get_device_capability(q.device) - _use_autotune = os.environ.get("ENABLE_CUTILE_TUNE", "0") == "1" and not kernel_configs.get( - "_user_explicit", False - ) - if _use_autotune: - import cuda.tile_experimental as ct_experimental # lazy — may not be installed - - ct_experimental.autotune_launch( - torch.cuda.current_stream(), - grid_fn=lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), - kernel=prefill_mla, - args_fn=lambda cfg: ( + _configs_fn = _mla_sm80_autotune_configs if _gpu_cap[0] < 9 else _mla_sm90_autotune_configs + stream = torch.cuda.current_stream() + cache_key = (S_qo, TILE_D, TILE_KPE, H, query_group_size, q.dtype, str(q.device)) + if cache_key not in _mla_tune_cache: + result = exhaustive_search( + list(_configs_fn()), + stream, + lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), + prefill_mla, + lambda cfg: ( q, qpe, k, @@ -196,19 +196,20 @@ def forward(ctx, q, qpe, k, kpe, v, sm_scale, IS_CAUSAL, kernel_configs): cfg.TILE_N, query_group_size, ), - hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - search_space=list(_mla_sm80_autotune_configs() if _gpu_cap[0] < 9 else _mla_sm90_autotune_configs()), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, ) - else: - TILE_M = kernel_configs.get("TILE_M", 64 if _gpu_cap[0] < 9 else 256) - TILE_N = kernel_configs.get("TILE_N", 64 if _gpu_cap[0] < 9 else 128) - grid = (math.ceil(S_qo / TILE_M), B * H, 1) - ct.launch( - torch.cuda.current_stream(), - grid, - prefill_mla, - (q, qpe, k, kpe, v, o, sm_scale, TILE_D, TILE_KPE, H, TILE_M, TILE_N, query_group_size), + best_cfg = result.best.config + _mla_tune_cache[cache_key] = ( + best_cfg, + ct.kernel(prefill_mla._pyfunc, num_ctas=best_cfg.num_ctas, occupancy=best_cfg.occupancy), ) + best_cfg, tuned_kernel = _mla_tune_cache[cache_key] + ct.launch( + stream, + (math.ceil(S_qo / best_cfg.TILE_M), B * H, 1), + tuned_kernel, + (q, qpe, k, kpe, v, o, sm_scale, TILE_D, TILE_KPE, H, best_cfg.TILE_M, best_cfg.TILE_N, query_group_size), + ) ctx.save_for_backward(q, k, v, o) ctx.sm_scale = sm_scale ctx.shapes = (B, H, S_qo, S_kv) @@ -238,23 +239,64 @@ def __call__(self, q, k, v, sm_scale, qpe=None, kpe=None): return c +def cutile_autotune_mla(stream, q, qpe, k, kpe, v, o, sm_scale, H, query_group_size): + """Autotuned launch for prefill_mla kernel.""" + B, _, S_qo, TILE_D = q.shape + TILE_KPE = qpe.shape[3] + _gpu_cap = torch.cuda.get_device_capability(q.device) + _configs_fn = _mla_sm80_autotune_configs if _gpu_cap[0] < 9 else _mla_sm90_autotune_configs + cache_key = (S_qo, TILE_D, TILE_KPE, H, query_group_size, q.dtype, str(q.device)) + + if os.environ.get("DISABLE_AUTOTUNE", "0") == "1": + cfg = next(_configs_fn()) + ct.launch( + stream, + (math.ceil(S_qo / cfg.TILE_M), B * H, 1), + prefill_mla, + (q, qpe, k, kpe, v, o, sm_scale, TILE_D, TILE_KPE, H, cfg.TILE_M, cfg.TILE_N, query_group_size), + ) + return + + if cache_key not in _mla_tune_cache: + result = exhaustive_search( + list(_configs_fn()), + stream, + lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), + prefill_mla, + lambda cfg: (q, qpe, k, kpe, v, o, sm_scale, TILE_D, TILE_KPE, H, cfg.TILE_M, cfg.TILE_N, query_group_size), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) + best_cfg = result.best.config + _mla_tune_cache[cache_key] = ( + best_cfg, + ct.kernel(prefill_mla._pyfunc, num_ctas=best_cfg.num_ctas, occupancy=best_cfg.occupancy), + ) + best_cfg, tuned_kernel = _mla_tune_cache[cache_key] + ct.launch( + stream, + (math.ceil(S_qo / best_cfg.TILE_M), B * H, 1), + tuned_kernel, + (q, qpe, k, kpe, v, o, sm_scale, TILE_D, TILE_KPE, H, best_cfg.TILE_M, best_cfg.TILE_N, query_group_size), + ) + + def tile_mla(q, k, v, qpe, kpe, is_causal, scaling, **kwargs): + assert is_causal, "CuTile MLA only supports is_causal=True" if scaling is None: scaling = 1.0 / math.sqrt(q.size(-1) + qpe.size(-1)) - gpu_capability = torch.cuda.get_device_capability() - if gpu_capability[0] < 9: - defaults = {"TILE_M": 64, "TILE_N": 64} - else: - defaults = {"TILE_M": 256, "TILE_N": 128} - user_cfg = kwargs.get("kernel_configs") - if user_cfg is None: - kernel_configs = defaults + B, H, S_qo, TILE_D = q.shape + num_head_kv = k.shape[1] + o = torch.empty_like(q) + + if H == num_head_kv: + query_group_size = 0 else: - # Tag so forward() knows to bypass autotune and use the explicit config. - kernel_configs = {**defaults, **user_cfg, "_user_explicit": True} - attention = Attention(is_causal, kernel_configs) - o = attention(q, k, v, scaling, qpe, kpe) + assert H % num_head_kv == 0 + query_group_size = int(H / num_head_kv) + + stream = torch.cuda.current_stream() + cutile_autotune_mla(stream, q, qpe, k, kpe, v, o, scaling, H, query_group_size) return o diff --git a/tests/ops/test_bmm.py b/tests/ops/test_bmm.py index 91c13747..71063379 100644 --- a/tests/ops/test_bmm.py +++ b/tests/ops/test_bmm.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT import itertools -import os import pytest import torch From 423d79b35ec7e4e8a84e1bd3a07e7fa47bd6b9c8 Mon Sep 17 00:00:00 2001 From: Zhiwei Fang Date: Wed, 22 Apr 2026 14:22:32 -0700 Subject: [PATCH 4/6] Fix CUPTI flag --- tests/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/config.py b/tests/config.py index 8fac079c..d4ea581e 100644 --- a/tests/config.py +++ b/tests/config.py @@ -20,6 +20,8 @@ def __init__( default = os.environ[envvar] if nargs in ["+", "*"]: default = default.split(",") + elif "type" in kwargs and kwargs["type"] is bool: + default = default.lower() not in ("0", "false", "no", "") if required and default is not None: required = False if "choices" in kwargs: From 2cb19a08b256e1fbeae9c23f6182ba1a3c09d633 Mon Sep 17 00:00:00 2001 From: Hannah Li Date: Wed, 22 Apr 2026 20:44:33 -0700 Subject: [PATCH 5/6] Fix attention calling error --- src/tilegym/ops/cutile/attention.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/tilegym/ops/cutile/attention.py b/src/tilegym/ops/cutile/attention.py index 0d12e9c5..141292bd 100644 --- a/src/tilegym/ops/cutile/attention.py +++ b/src/tilegym/ops/cutile/attention.py @@ -779,14 +779,8 @@ def cutile_autotune_fmha( ), ) best_cfg = result.best.config - _fmha_fwd_tune_cache[fwd_cache_key] = ( - best_cfg, - ct.kernel( - fmha_kernel._pyfunc, - num_ctas=best_cfg.num_ctas, - occupancy=best_cfg.occupancy, - ), - ) + tuned_kernel = fmha_kernel + _fmha_fwd_tune_cache[fwd_cache_key] = (best_cfg, tuned_kernel) best_cfg, tuned_kernel = _fmha_fwd_tune_cache[fwd_cache_key] ct.launch( stream, From 161ef03d735b8c55df6365fb4e81b22035d2fc7d Mon Sep 17 00:00:00 2001 From: Hannah Li Date: Sat, 25 Apr 2026 09:18:49 +0800 Subject: [PATCH 6/6] ci, autotune: scope compile timeout to exhaustive_search and raise CI budgets The migration to cuda.tile.tune.exhaustive_search exhaustively searches the entire config space and has no built-in per-config compile timeout, so slow-to-compile configs on sm120 can stall CI. Scope the compile timeout to autotune only, and raise CI step/job budgets to absorb the longer adaptive-repeat measurement loop in the new tune API. - Wrap every cuda.tile.tune.exhaustive_search call site (13 across 10 op files) with `with ct.compiler_timeout(5):` so individual slow configs are killed and routed to result.failures while non-autotune ct.launch compiles remain unaffected. - Bump the test-benchmark job timeout 40 -> 70 min and the "Pull and run benchmarks" step timeout 35 -> 60 min. - Bump the per-benchmark subprocess timeout in run_all_json.py from 10 min -> 20 min. --- .github/workflows/tilegym-ci.yml | 4 +- src/tilegym/ops/cutile/attention.py | 141 +++++++++--------- src/tilegym/ops/cutile/attention_sink.py | 47 +++--- src/tilegym/ops/cutile/bmm.py | 17 ++- src/tilegym/ops/cutile/experimental/mhc.py | 49 +++--- .../ops/cutile/experimental/sparse_mla.py | 49 +++--- src/tilegym/ops/cutile/gemma_attention.py | 55 +++---- src/tilegym/ops/cutile/group_gemm.py | 35 ++--- src/tilegym/ops/cutile/layer_norm_legacy.py | 17 ++- src/tilegym/ops/cutile/matmul.py | 68 +++++---- src/tilegym/ops/cutile/mla.py | 76 ++++++---- tests/benchmark/run_all_json.py | 6 +- 12 files changed, 296 insertions(+), 268 deletions(-) diff --git a/.github/workflows/tilegym-ci.yml b/.github/workflows/tilegym-ci.yml index 3e0c219c..3c31e4cc 100644 --- a/.github/workflows/tilegym-ci.yml +++ b/.github/workflows/tilegym-ci.yml @@ -347,7 +347,7 @@ jobs: test-benchmark: name: test-benchmark needs: [config, build] - timeout-minutes: 40 + timeout-minutes: 70 if: | always() && needs.config.outputs.run_benchmark == 'true' && @@ -409,7 +409,7 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: Pull and run benchmarks - timeout-minutes: 35 + timeout-minutes: 60 run: | OWNER_LOWER=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]') IMAGE="ghcr.io/${OWNER_LOWER}/${{ needs.config.outputs.image_name }}:${{ needs.config.outputs.image_tag }}" diff --git a/src/tilegym/ops/cutile/attention.py b/src/tilegym/ops/cutile/attention.py index 141292bd..662c944c 100644 --- a/src/tilegym/ops/cutile/attention.py +++ b/src/tilegym/ops/cutile/attention.py @@ -757,27 +757,28 @@ def cutile_autotune_fmha( str(q.device), ) if fwd_cache_key not in _fmha_fwd_tune_cache: - result = exhaustive_search( - list(_fmha_autotune_configs(hidden_size)), - stream, - lambda cfg: (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1), - fmha_kernel, - lambda cfg: ( - q, - k, - v, - o, - sm_scale, - input_pos, - hidden_size, - num_heads, - cfg.TILE_M, - cfg.TILE_N, - query_group_size, - is_causal, - EVEN_K, - ), - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_fmha_autotune_configs(hidden_size)), + stream, + lambda cfg: (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1), + fmha_kernel, + lambda cfg: ( + q, + k, + v, + o, + sm_scale, + input_pos, + hidden_size, + num_heads, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + is_causal, + EVEN_K, + ), + ) best_cfg = result.best.config tuned_kernel = fmha_kernel _fmha_fwd_tune_cache[fwd_cache_key] = (best_cfg, tuned_kernel) @@ -1108,31 +1109,32 @@ def fmha_backward( str(q.device), ) if dkdv_cache_key not in _fmha_bwd_dkdv_tune_cache: - result = exhaustive_search( - list(_fmha_bwd_dkdv_autotune_configs(hidden_size)), - stream, - lambda cfg: (math.ceil(k_len / cfg.TILE_N), batch_size * num_head_kv, 1), - fmha_bwd_dkdv_kernel, - lambda cfg: ( - q, - k, - v, - do, - dk, - dv, - lse_flat, - delta_flat, - sm_scale, - TILE_D, - num_heads, - num_head_kv, - padded_q_len, - cfg.TILE_M, - cfg.TILE_N, - query_group_size, - is_causal, - ), - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_fmha_bwd_dkdv_autotune_configs(hidden_size)), + stream, + lambda cfg: (math.ceil(k_len / cfg.TILE_N), batch_size * num_head_kv, 1), + fmha_bwd_dkdv_kernel, + lambda cfg: ( + q, + k, + v, + do, + dk, + dv, + lse_flat, + delta_flat, + sm_scale, + TILE_D, + num_heads, + num_head_kv, + padded_q_len, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + is_causal, + ), + ) best_cfg = result.best.config _fmha_bwd_dkdv_tune_cache[dkdv_cache_key] = ( best_cfg, @@ -1205,29 +1207,30 @@ def fmha_backward( str(q.device), ) if dq_cache_key not in _fmha_bwd_dq_tune_cache: - result = exhaustive_search( - list(_fmha_bwd_dq_autotune_configs(hidden_size)), - stream, - lambda cfg: (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1), - fmha_bwd_dq_kernel, - lambda cfg: ( - q, - k, - v, - do, - dq, - lse_flat, - delta_flat, - sm_scale, - TILE_D, - num_heads, - padded_q_len, - cfg.TILE_M, - cfg.TILE_N, - query_group_size, - is_causal, - ), - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_fmha_bwd_dq_autotune_configs(hidden_size)), + stream, + lambda cfg: (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1), + fmha_bwd_dq_kernel, + lambda cfg: ( + q, + k, + v, + do, + dq, + lse_flat, + delta_flat, + sm_scale, + TILE_D, + num_heads, + padded_q_len, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + is_causal, + ), + ) best_cfg = result.best.config _fmha_bwd_dq_tune_cache[dq_cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/attention_sink.py b/src/tilegym/ops/cutile/attention_sink.py index 4ef089d8..116a3258 100644 --- a/src/tilegym/ops/cutile/attention_sink.py +++ b/src/tilegym/ops/cutile/attention_sink.py @@ -198,29 +198,30 @@ def _cutile_autotune_attention_sink( cache_key = (batch_size, n_heads, n_ctx, head_dim, n_kv_ctx, bandwidth, q.dtype, str(q.device)) if cache_key not in _attention_sink_tune_cache: - result = exhaustive_search( - list(_attention_sink_autotune_configs()), - stream, - lambda cfg: (math.ceil(n_ctx / cfg.TILE_M), batch_size * n_heads, 1), - attention_sink_kernel, - lambda cfg: ( - q, - k, - v, - sinks, - o, - start_q, - sm_scale, - head_dim, - n_heads, - n_kv_ctx, - cfg.TILE_M, - cfg.TILE_N, - repeat_kv, - bandwidth, - ), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_attention_sink_autotune_configs()), + stream, + lambda cfg: (math.ceil(n_ctx / cfg.TILE_M), batch_size * n_heads, 1), + attention_sink_kernel, + lambda cfg: ( + q, + k, + v, + sinks, + o, + start_q, + sm_scale, + head_dim, + n_heads, + n_kv_ctx, + cfg.TILE_M, + cfg.TILE_N, + repeat_kv, + bandwidth, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _attention_sink_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/bmm.py b/src/tilegym/ops/cutile/bmm.py index c5875677..2ae4e52c 100644 --- a/src/tilegym/ops/cutile/bmm.py +++ b/src/tilegym/ops/cutile/bmm.py @@ -283,14 +283,15 @@ def grid_fn(cfg): # Call autotuner to find the best config and execute the kernel cache_key = (batch_size, M, N, K, transpose_a, transpose_b, a.dtype, str(a.device)) if cache_key not in _bmm_tune_cache: - result = exhaustive_search( - list(_bmm_autotune_configs()), - stream, - grid_fn, - ct_static_persistent_bmm_kernel, - args_fn, - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_bmm_autotune_configs()), + stream, + grid_fn, + ct_static_persistent_bmm_kernel, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _bmm_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/experimental/mhc.py b/src/tilegym/ops/cutile/experimental/mhc.py index b21923fb..a1ad89ec 100644 --- a/src/tilegym/ops/cutile/experimental/mhc.py +++ b/src/tilegym/ops/cutile/experimental/mhc.py @@ -263,30 +263,31 @@ def cutile_autotune_mhc_split_gemm_rms(stream, x, w, M, N, K, cfg=None): max_num_bid_n = max(ceil(N / cfg.TILE_SIZE_N) for cfg in configs) y_acc = torch.empty((M * max_split_k, N), device=x.device, dtype=torch.float32) r_acc = torch.empty((M * max_split_k, max_num_bid_n), device=x.device, dtype=torch.float32) - result = exhaustive_search( - configs, - stream, - lambda cfg: ( - ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), - cfg.SPLIT_K, - 1, - ), - mhc_split_gemm_rms_kernel, - lambda cfg: ( - x, - w, - y_acc, - r_acc, - M, - N, - K, - cfg.TILE_SIZE_M, - cfg.TILE_SIZE_N, - cfg.TILE_SIZE_K, - cfg.SPLIT_K, - cfg.GROUP_SIZE_M, - ), - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + configs, + stream, + lambda cfg: ( + ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), + cfg.SPLIT_K, + 1, + ), + mhc_split_gemm_rms_kernel, + lambda cfg: ( + x, + w, + y_acc, + r_acc, + M, + N, + K, + cfg.TILE_SIZE_M, + cfg.TILE_SIZE_N, + cfg.TILE_SIZE_K, + cfg.SPLIT_K, + cfg.GROUP_SIZE_M, + ), + ) best_cfg = result.best.config # Re-run the winning config with fresh buffers. The autotuner reuses diff --git a/src/tilegym/ops/cutile/experimental/sparse_mla.py b/src/tilegym/ops/cutile/experimental/sparse_mla.py index c84ab0b4..bae468a0 100644 --- a/src/tilegym/ops/cutile/experimental/sparse_mla.py +++ b/src/tilegym/ops/cutile/experimental/sparse_mla.py @@ -288,30 +288,31 @@ def _launch_with_cfg(cfg): else: cache_key = (B, H, S, topk, D, D_PE, query_group_size, q.dtype, str(q.device)) if cache_key not in _sparse_mla_tune_cache: - result = exhaustive_search( - list(_sparse_mla_autotune_configs(topk, H, query_group_size)), - stream, - lambda cfg: (S, B * (H // cfg.TILE_H), 1), - sparse_mla_fwd_kernel, - lambda cfg: ( - q, - k, - v, - indices, - qpe, - kpe, - o, - sm_scale, - D, - D_PE, - H, - cfg.TILE_N, - topk // cfg.TILE_N, - query_group_size, - cfg.TILE_H, - H // cfg.TILE_H, - ), - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_sparse_mla_autotune_configs(topk, H, query_group_size)), + stream, + lambda cfg: (S, B * (H // cfg.TILE_H), 1), + sparse_mla_fwd_kernel, + lambda cfg: ( + q, + k, + v, + indices, + qpe, + kpe, + o, + sm_scale, + D, + D_PE, + H, + cfg.TILE_N, + topk // cfg.TILE_N, + query_group_size, + cfg.TILE_H, + H // cfg.TILE_H, + ), + ) best_cfg = result.best.config _sparse_mla_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/gemma_attention.py b/src/tilegym/ops/cutile/gemma_attention.py index 22ee89b3..2f54e16d 100644 --- a/src/tilegym/ops/cutile/gemma_attention.py +++ b/src/tilegym/ops/cutile/gemma_attention.py @@ -320,33 +320,34 @@ def _cutile_autotune_gemma_fmha( str(q.device), ) if cache_key not in _gemma_fmha_tune_cache: - result = exhaustive_search( - list(_gemma_fmha_autotune_configs()), - stream, - lambda cfg: (math.ceil(S_qo / cfg.BLOCK_M), B * H, 1), - gemma_fmha_kernel, - lambda cfg: ( - q, - k, - v, - o, - sm_scale, - B, - H, - S_qo, - S_kv, - BLOCK_D, - cfg.BLOCK_M, - cfg.BLOCK_N, - query_group_size, - stage, - window_size, - soft_cap_val, - has_soft_cap, - (S_kv % cfg.BLOCK_N) == 0, - ), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_gemma_fmha_autotune_configs()), + stream, + lambda cfg: (math.ceil(S_qo / cfg.BLOCK_M), B * H, 1), + gemma_fmha_kernel, + lambda cfg: ( + q, + k, + v, + o, + sm_scale, + B, + H, + S_qo, + S_kv, + BLOCK_D, + cfg.BLOCK_M, + cfg.BLOCK_N, + query_group_size, + stage, + window_size, + soft_cap_val, + has_soft_cap, + (S_kv % cfg.BLOCK_N) == 0, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _gemma_fmha_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/group_gemm.py b/src/tilegym/ops/cutile/group_gemm.py index 83fed720..f480d9df 100644 --- a/src/tilegym/ops/cutile/group_gemm.py +++ b/src/tilegym/ops/cutile/group_gemm.py @@ -127,23 +127,24 @@ def cutile_autotune_group_gemm(stream, group_A, group_B, group_C, transpose_b, d group_shapes = tuple((tuple(A.shape), tuple(B.shape)) for A, B in zip(group_A, group_B)) cache_key = (group_shapes, transpose_b, group_A[0].dtype, str(group_A[0].device)) if cache_key not in _group_gemm_tune_cache: - result = exhaustive_search( - list(_group_gemm_autotune_configs()), - stream, - lambda cfg: (NUM_SMS // cfg.num_ctas * cfg.occupancy, 1, 1), - group_gemm_kernel, - lambda cfg: ( - group_A, - group_B, - group_C, - cfg.TILE_M, - cfg.TILE_N, - cfg.TILE_K, - NUM_SMS // cfg.num_ctas * cfg.occupancy, - transpose_b, - ), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_group_gemm_autotune_configs()), + stream, + lambda cfg: (NUM_SMS // cfg.num_ctas * cfg.occupancy, 1, 1), + group_gemm_kernel, + lambda cfg: ( + group_A, + group_B, + group_C, + cfg.TILE_M, + cfg.TILE_N, + cfg.TILE_K, + NUM_SMS // cfg.num_ctas * cfg.occupancy, + transpose_b, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _group_gemm_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/layer_norm_legacy.py b/src/tilegym/ops/cutile/layer_norm_legacy.py index 8ba63c6a..21b5231e 100644 --- a/src/tilegym/ops/cutile/layer_norm_legacy.py +++ b/src/tilegym/ops/cutile/layer_norm_legacy.py @@ -333,14 +333,15 @@ def grid_fn(cfg): cache_key = (N, D, BLOCK_D, IS_SWISH, TRAINING, COMPUTE_MEAN_AND_RSTD, x.dtype, str(x.device)) if cache_key not in _layer_norm_legacy_tune_cache: - result = exhaustive_search( - pruned_configs, - stream, - grid_fn, - _persistent_layer_norm_fwd_kernel, - args_fn, - lambda cfg: {"num_ctas": cfg.num_ctas}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + pruned_configs, + stream, + grid_fn, + _persistent_layer_norm_fwd_kernel, + args_fn, + lambda cfg: {"num_ctas": cfg.num_ctas}, + ) best_cfg = result.best.config _layer_norm_legacy_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/matmul.py b/src/tilegym/ops/cutile/matmul.py index ec69530e..3c4f1422 100644 --- a/src/tilegym/ops/cutile/matmul.py +++ b/src/tilegym/ops/cutile/matmul.py @@ -224,14 +224,15 @@ def cutile_autotune_matmul(stream, a, b, c): K = a.shape[1] cache_key = (M, N, K, a.dtype, str(a.device)) if cache_key not in _matmul_tune_cache: - result = exhaustive_search( - list(_matmul_autotune_configs()), - stream, - lambda cfg: (ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), 1, 1), - matmul_kernel, - lambda cfg: (a, b, c, cfg.TILE_SIZE_M, cfg.TILE_SIZE_N, cfg.TILE_SIZE_K), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_matmul_autotune_configs()), + stream, + lambda cfg: (ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), 1, 1), + matmul_kernel, + lambda cfg: (a, b, c, cfg.TILE_SIZE_M, cfg.TILE_SIZE_N, cfg.TILE_SIZE_K), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _matmul_tune_cache[cache_key] = ( best_cfg, @@ -287,31 +288,32 @@ def cutile_autotune_static_persistent_matmul(stream, a, b, c, M, N, K, trans_a, NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count cache_key = (M, N, K, trans_a, trans_b, a.dtype, str(a.device)) if cache_key not in _static_persistent_matmul_tune_cache: - result = exhaustive_search( - list(_static_persistent_matmul_autotune_configs()), - stream, - lambda cfg: ( - min(NUM_SMS // cfg.num_ctas, ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N)) * cfg.occupancy, - 1, - 1, - ), - static_persistent_matmul_kernel, - lambda cfg: ( - a, - b, - c, - M, - N, - K, - cfg.TILE_SIZE_M, - cfg.TILE_SIZE_N, - cfg.TILE_SIZE_K, - trans_a, - trans_b, - cfg.GROUP_SIZE_M, - ), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_static_persistent_matmul_autotune_configs()), + stream, + lambda cfg: ( + min(NUM_SMS // cfg.num_ctas, ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N)) * cfg.occupancy, + 1, + 1, + ), + static_persistent_matmul_kernel, + lambda cfg: ( + a, + b, + c, + M, + N, + K, + cfg.TILE_SIZE_M, + cfg.TILE_SIZE_N, + cfg.TILE_SIZE_K, + trans_a, + trans_b, + cfg.GROUP_SIZE_M, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _static_persistent_matmul_tune_cache[cache_key] = ( best_cfg, diff --git a/src/tilegym/ops/cutile/mla.py b/src/tilegym/ops/cutile/mla.py index 78d445be..9c9224e9 100644 --- a/src/tilegym/ops/cutile/mla.py +++ b/src/tilegym/ops/cutile/mla.py @@ -176,28 +176,29 @@ def forward(ctx, q, qpe, k, kpe, v, sm_scale, IS_CAUSAL, kernel_configs): stream = torch.cuda.current_stream() cache_key = (S_qo, TILE_D, TILE_KPE, H, query_group_size, q.dtype, str(q.device)) if cache_key not in _mla_tune_cache: - result = exhaustive_search( - list(_configs_fn()), - stream, - lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), - prefill_mla, - lambda cfg: ( - q, - qpe, - k, - kpe, - v, - o, - sm_scale, - TILE_D, - TILE_KPE, - H, - cfg.TILE_M, - cfg.TILE_N, - query_group_size, - ), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_configs_fn()), + stream, + lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), + prefill_mla, + lambda cfg: ( + q, + qpe, + k, + kpe, + v, + o, + sm_scale, + TILE_D, + TILE_KPE, + H, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _mla_tune_cache[cache_key] = ( best_cfg, @@ -258,14 +259,29 @@ def cutile_autotune_mla(stream, q, qpe, k, kpe, v, o, sm_scale, H, query_group_s return if cache_key not in _mla_tune_cache: - result = exhaustive_search( - list(_configs_fn()), - stream, - lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), - prefill_mla, - lambda cfg: (q, qpe, k, kpe, v, o, sm_scale, TILE_D, TILE_KPE, H, cfg.TILE_M, cfg.TILE_N, query_group_size), - lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, - ) + with ct.compiler_timeout(5): + result = exhaustive_search( + list(_configs_fn()), + stream, + lambda cfg: (math.ceil(S_qo / cfg.TILE_M), B * H, 1), + prefill_mla, + lambda cfg: ( + q, + qpe, + k, + kpe, + v, + o, + sm_scale, + TILE_D, + TILE_KPE, + H, + cfg.TILE_M, + cfg.TILE_N, + query_group_size, + ), + lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + ) best_cfg = result.best.config _mla_tune_cache[cache_key] = ( best_cfg, diff --git a/tests/benchmark/run_all_json.py b/tests/benchmark/run_all_json.py index ec22ec4e..a3393537 100755 --- a/tests/benchmark/run_all_json.py +++ b/tests/benchmark/run_all_json.py @@ -199,7 +199,7 @@ def run_benchmark(benchmark_file: Path) -> Dict[str, Any]: [sys.executable, str(benchmark_file)], capture_output=True, text=True, - timeout=600, # 10 minute timeout per benchmark + timeout=1200, # 20 minute timeout per benchmark cwd=benchmark_file.parent, ) @@ -232,8 +232,8 @@ def run_benchmark(benchmark_file: Path) -> Dict[str, Any]: "benchmark_file": benchmark_file.name, "status": "TIMEOUT", "error_type": "TimeoutError", - "error_message": "Benchmark exceeded 10 minute timeout", - "error": "Benchmark exceeded 10 minute timeout", + "error_message": "Benchmark exceeded 20 minute timeout", + "error": "Benchmark exceeded 20 minute timeout", "benchmarks": [], } except Exception as e: