diff --git a/src/tilegym/ops/cutile/__init__.py b/src/tilegym/ops/cutile/__init__.py
index 85ae8f8..a5a2d61 100644
--- a/src/tilegym/ops/cutile/__init__.py
+++ b/src/tilegym/ops/cutile/__init__.py
@@ -19,6 +19,7 @@
     from . import flash_decode
     from . import group_gemm
     from . import matmul
+    from . import mhc
     from . import mla
     from . import mla_decoding
     from . import mla_decoding_split_kv
@@ -33,6 +34,9 @@
     # Import specific functions for direct access
     from .flash_decode import fmha_decode
+    from .mhc import mhc_apply_residual
+    from .mhc import mhc_gemm_rms_scale
+    from .mhc import mhc_sinkhorn
     from .moe import fused_moe_kernel as invoke_fused_moe_kernel
     from .moe_align_block import moe_align_block_size
     from .rms_norm import get_rms_norm_module
@@ -60,6 +64,9 @@
         "get_apply_rope_func",
         "get_rms_norm_module",
         "rms_norm",
+        "mhc_gemm_rms_scale",
+        "mhc_apply_residual",
+        "mhc_sinkhorn",
         "silu_and_mul",
         "dropout",
         "softmax",
@@ -73,6 +80,7 @@
         "bmm",
         "matmul",
         "group_gemm",
+        "mhc",
     ]
 else:
     __all__ = []
diff --git a/src/tilegym/ops/cutile/mhc.py b/src/tilegym/ops/cutile/mhc.py
new file mode 100644
index 0000000..a5fe2a7
--- /dev/null
+++ b/src/tilegym/ops/cutile/mhc.py
@@ -0,0 +1,554 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+from math import ceil
+from types import SimpleNamespace
+
+import cuda.tile as ct
+import cuda.tile_experimental as ct_experimental
+import torch
+
+from tilegym.backend import register_impl
+from tilegym.logger import get_logger
+
+logger = get_logger(__name__)
+
+# Type aliases for constants
+ConstInt = ct.Constant[int]
+LOG2E = 1.4426950408889634
+
+
+def _compute_bid(tile_id, num_bid_in_group, num_bid_m, GROUP_SIZE_M):
+    group_id = tile_id // num_bid_in_group
+    first_bid_m = group_id * GROUP_SIZE_M
+    group_size_m = ct.minimum(num_bid_m - first_bid_m, GROUP_SIZE_M)
+    bid_m = first_bid_m + (tile_id % group_size_m)
+    bid_n = (tile_id % num_bid_in_group) // group_size_m
+    return bid_m, bid_n
+
+
+def _sigmoid(x):
+    return 1.0 / (1.0 + ct.exp(-x))
+
+
+@ct.kernel
+def mhc_split_gemm_rms_kernel(
+    X,
+    W,
+    Y_acc,
+    R_acc,
+    M: int,
+    N: int,
+    K: int,
+    TILE_SIZE_M: ConstInt,
+    TILE_SIZE_N: ConstInt,
+    TILE_SIZE_K: ConstInt,
+    SPLIT_K: ConstInt,
+    GROUP_SIZE_M: ConstInt,
+):
+    """Split-K fused GEMM + RMS compute kernel for mHC.
+
+    Key optimization: every block accumulates the partial sum of squares of X
+    for its own K-tile range alongside the GEMM, reusing the X tile it already
+    holds in registers instead of re-reading X in a second pass. The partials
+    are then reduced in the finalize kernel.
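+
+    Accumulator layout (matching the allocation in
+    cutile_autotune_mhc_split_gemm_rms):
+
+        Y_acc (M * SPLIT_K, N): row block bid_m + bid_k * num_bid_m, column
+            block bid_n holds this block's partial GEMM result.
+        R_acc (M * SPLIT_K, num_bid_n): the same row block, column bid_n,
+            holds the per-row partial sums of squares of X for K-split bid_k.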
+    """
+    tile_id = ct.bid(0)
+    bid_k = ct.bid(1)
+    zero_pad = ct.PaddingMode.ZERO
+
+    num_bid_m = ct.cdiv(M, TILE_SIZE_M)
+    num_bid_n = ct.cdiv(N, TILE_SIZE_N)
+    num_bid_in_group = GROUP_SIZE_M * num_bid_n
+    bid_m, bid_n = _compute_bid(tile_id, num_bid_in_group, num_bid_m, GROUP_SIZE_M)
+    k_tiles = ct.cdiv(K, TILE_SIZE_K)
+    k_tiles_per_split = ct.cdiv(k_tiles, SPLIT_K)
+    k_tile_start = bid_k * k_tiles_per_split
+    k_tile_end = ct.minimum(k_tile_start + k_tiles_per_split, k_tiles)
+
+    rms_acc = ct.full((TILE_SIZE_M,), 0.0, dtype=ct.float32)
+    accumulator = ct.full((TILE_SIZE_M, TILE_SIZE_N), 0.0, dtype=ct.float32)
+    mma_dtype = ct.tfloat32 if (X.dtype == ct.float32 or W.dtype == ct.float32) else X.dtype
+
+    for k_tile in range(k_tile_start, k_tile_end):
+        a = ct.load(
+            X,
+            index=(bid_m, k_tile),
+            shape=(TILE_SIZE_M, TILE_SIZE_K),
+            padding_mode=zero_pad,
+            allow_tma=True,
+        )
+        b = ct.load(
+            W,
+            index=(k_tile, bid_n),
+            shape=(TILE_SIZE_K, TILE_SIZE_N),
+            padding_mode=zero_pad,
+            allow_tma=True,
+        )
+
+        a_mma = ct.astype(a, mma_dtype)
+        b_mma = ct.astype(b, mma_dtype)
+        accumulator = ct.mma(a_mma, b_mma, acc=accumulator)
+
+        a_fp32 = ct.astype(a, ct.float32)
+        rms_acc = rms_acc + ct.sum(a_fp32 * a_fp32, axis=1, keepdims=False)
+
+    bid_m_k = bid_m + bid_k * num_bid_m
+    ct.store(Y_acc, index=(bid_m_k, bid_n), tile=accumulator)
+
+    # Store the partial sum of squares for this block's K range. The finalize
+    # kernel reduces these across split-K only; every bid_n block writes an
+    # identical partial, so bid_n is merely an addressing dimension here.
+    ct.store(R_acc, index=(bid_m_k, bid_n), tile=ct.reshape(rms_acc, (TILE_SIZE_M, 1)))
+
+
+@ct.kernel
+def mhc_finalize_scale_bias_sigmoid_kernel(
+    Y_acc,
+    R_acc,
+    Y,
+    R,
+    n: int,
+    alpha_pre: float,
+    alpha_post: float,
+    alpha_res: float,
+    Bias,
+    M: int,
+    N: int,
+    K: int,
+    TILE_SIZE_M: ConstInt,
+    TILE_SIZE_N: ConstInt,
+    SPLIT_K: ConstInt,
+):
+    """Finalize split-K + fused scale/bias/sigmoid kernel for mHC."""
+    bid_m = ct.bid(0)
+    bid_n = ct.bid(1)
+
+    num_bid_m = ct.cdiv(M, TILE_SIZE_M)
+    num_bid_n = ct.cdiv(N, TILE_SIZE_N)
+
+    y_accum = ct.full((TILE_SIZE_M, TILE_SIZE_N), 0.0, dtype=ct.float32)
+    r_accum = ct.full((TILE_SIZE_M, 1), 0.0, dtype=ct.float32)
+
+    # Sum across the split-K dimension
+    for split_idx in range(SPLIT_K):
+        bid_m_k = bid_m + split_idx * num_bid_m
+        y_tile = ct.load(
+            Y_acc,
+            index=(bid_m_k, bid_n),
+            shape=(TILE_SIZE_M, TILE_SIZE_N),
+            padding_mode=ct.PaddingMode.ZERO,
+        )
+        y_accum = y_accum + y_tile
+
+        # RMS is independent of bid_n; each bid_n block stores the same partial RMS.
+        # Loading the current bid_n avoids over-counting when num_bid_n > 1.
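+        # (For example, with num_bid_n = 2 the stores satisfy
+        # R_acc[i, 0] == R_acc[i, 1] for every row block i; summing both
+        # columns would double the mean of squares and inflate the RMS
+        # denominator by sqrt(2).)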
+ r_tile = ct.load( + R_acc, + index=(bid_m_k, bid_n), + shape=(TILE_SIZE_M, 1), + padding_mode=ct.PaddingMode.ZERO, + ) + r_tile = ct.astype(r_tile, ct.float32) + r_accum = r_accum + r_tile + + denom = ct.full((TILE_SIZE_M, 1), K * 1.0, dtype=ct.float32) + mean = ct.truediv(r_accum, denom) + rstd = ct.rsqrt(mean) + ones = ct.full((TILE_SIZE_M, 1), 1.0, dtype=ct.float32) + r = ct.truediv(ones, rstd) + if bid_n == 0: + r_out = ct.astype(r, R.dtype) + ct.store(R, index=(bid_m, 0), tile=r_out) + + offsets = ct.arange(TILE_SIZE_N, dtype=ct.int32) + col_ids = bid_n * TILE_SIZE_N + offsets + bias = ct.load(Bias, index=(bid_n,), shape=(TILE_SIZE_N,), padding_mode=ct.PaddingMode.ZERO) + bias = ct.reshape(bias, (1, TILE_SIZE_N)) + + one = ct.full((TILE_SIZE_N,), 1.0, dtype=ct.float32) + zero = ct.full((TILE_SIZE_N,), 0.0, dtype=ct.float32) + mask_pre = ct.where(ct.less(col_ids, n), one, zero) + mask_post = ct.where(ct.less(col_ids, 2 * n), one, zero) + mask_post = mask_post - mask_pre + mask_res = one - mask_pre - mask_post + + scale = alpha_pre * mask_pre + alpha_post * mask_post + alpha_res * mask_res + scale = ct.reshape(scale, (1, TILE_SIZE_N)) + + linear = ct.truediv(y_accum * scale, r) + ct.astype(bias, ct.float32) + sigmoid_linear = _sigmoid(linear) + two_sigmoid = sigmoid_linear * 2.0 + + mask_pre = ct.reshape(mask_pre, (1, TILE_SIZE_N)) + mask_post = ct.reshape(mask_post, (1, TILE_SIZE_N)) + mask_res = ct.reshape(mask_res, (1, TILE_SIZE_N)) + + out = linear * mask_res + sigmoid_linear * mask_pre + two_sigmoid * mask_post + out = ct.astype(out, Y.dtype) + ct.store(Y, index=(bid_m, bid_n), tile=out) + + +def _mhc_split_gemm_rms_autotune_configs(): + tile_ms = (64, 128) + tile_ks = (64, 128) + split_ks = (1, 2, 4, 8, 16) + group_size_ms = (8, 16) + tile_n = 32 + for tile_m in tile_ms: + for tile_k in tile_ks: + for split_k in split_ks: + for group_size_m in group_size_ms: + yield SimpleNamespace( + TILE_SIZE_M=tile_m, + TILE_SIZE_N=tile_n, + TILE_SIZE_K=tile_k, + SPLIT_K=split_k, + GROUP_SIZE_M=group_size_m, + ) + + +def cutile_autotune_mhc_split_gemm_rms(stream, x, w, M, N, K, cfg=None): + if cfg is not None: + if isinstance(cfg, dict): + cfg = SimpleNamespace(**cfg) + if not hasattr(cfg, "TILE_SIZE_M") and hasattr(cfg, "m"): + cfg.TILE_SIZE_M = cfg.m + if not hasattr(cfg, "TILE_SIZE_N") and hasattr(cfg, "n"): + cfg.TILE_SIZE_N = cfg.n + if not hasattr(cfg, "TILE_SIZE_K") and hasattr(cfg, "k"): + cfg.TILE_SIZE_K = cfg.k + if not hasattr(cfg, "SPLIT_K") and hasattr(cfg, "split_k"): + cfg.SPLIT_K = cfg.split_k + if not hasattr(cfg, "GROUP_SIZE_M") and hasattr(cfg, "group_size_m"): + cfg.GROUP_SIZE_M = cfg.group_size_m + + num_bid_n = ceil(N / cfg.TILE_SIZE_N) + y_acc = torch.empty((M * cfg.SPLIT_K, N), device=x.device, dtype=torch.float32) + # R_acc now stores partial RMS for all N blocks + r_acc = torch.empty((M * cfg.SPLIT_K, num_bid_n), device=x.device, dtype=torch.float32) + grid = ( + ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), + cfg.SPLIT_K, + 1, + ) + ct.launch( + stream, + grid, + mhc_split_gemm_rms_kernel, + ( + x, + w, + y_acc, + r_acc, + M, + N, + K, + cfg.TILE_SIZE_M, + cfg.TILE_SIZE_N, + cfg.TILE_SIZE_K, + cfg.SPLIT_K, + cfg.GROUP_SIZE_M, + ), + ) + return y_acc, r_acc, cfg + + configs = list(_mhc_split_gemm_rms_autotune_configs()) + max_split_k = max(cfg.SPLIT_K for cfg in configs) + # Need max num_bid_n across all configs + max_num_bid_n = max(ceil(N / cfg.TILE_SIZE_N) for cfg in configs) + y_acc = torch.empty((M * max_split_k, N), device=x.device, 
dtype=torch.float32) + r_acc = torch.empty((M * max_split_k, max_num_bid_n), device=x.device, dtype=torch.float32) + tuned = ct_experimental.autotune_launch( + stream, + grid_fn=lambda cfg: ( + ceil(M / cfg.TILE_SIZE_M) * ceil(N / cfg.TILE_SIZE_N), + cfg.SPLIT_K, + 1, + ), + kernel=mhc_split_gemm_rms_kernel, + args_fn=lambda cfg: ( + x, + w, + y_acc, + r_acc, + M, + N, + K, + cfg.TILE_SIZE_M, + cfg.TILE_SIZE_N, + cfg.TILE_SIZE_K, + cfg.SPLIT_K, + cfg.GROUP_SIZE_M, + ), + search_space=configs, + ) + best_cfg = tuned.tuned_config + return y_acc, r_acc, best_cfg + + +def mhc_split_gemm_rms(x: torch.Tensor, w: torch.Tensor, **kwargs): + M, K = x.shape + KB, N = w.shape + assert K == KB, f"Incompatible matrices: K dimension of X is {K}, K dimension of W is {KB}" + + cfg = kwargs.pop("cfg", None) + kwargs.pop("w_nt", None) + w = w.contiguous() + + stream = torch.cuda.current_stream() + return cutile_autotune_mhc_split_gemm_rms(stream, x, w, M, N, K, cfg=cfg) + + +def mhc_finalize_scale_bias_sigmoid( + y_acc: torch.Tensor, + r_acc: torch.Tensor, + n: int, + alpha_pre: float, + alpha_post: float, + alpha_res: float, + bias: torch.Tensor, + M: int, + K: int, + **kwargs, +): + cfg = kwargs.pop("cfg", None) + split_k = kwargs.pop("split_k", None) + tile_m = kwargs.pop("tile_m", None) + tile_n = kwargs.pop("tile_n", None) + if cfg is not None: + tile_m = cfg.TILE_SIZE_M + tile_n = cfg.TILE_SIZE_N + split_k = cfg.SPLIT_K + + y_acc = y_acc.contiguous() + r_acc = r_acc.contiguous() + bias = bias.contiguous() + N = y_acc.shape[1] + + y = torch.empty((M, N), device=y_acc.device, dtype=bias.dtype) + r = torch.empty((M, 1), device=y_acc.device, dtype=torch.float32) + + grid = (ceil(M / tile_m), ceil(N / tile_n), 1) + ct.launch( + torch.cuda.current_stream(), + grid, + mhc_finalize_scale_bias_sigmoid_kernel, + ( + y_acc, + r_acc, + y, + r, + n, + float(alpha_pre), + float(alpha_post), + float(alpha_res), + bias, + M, + N, + K, + tile_m, + tile_n, + split_k, + ), + ) + return y, r + + +@register_impl("mhc_gemm_rms_scale", backend="cutile") +def mhc_gemm_rms_scale( + x: torch.Tensor, + w: torch.Tensor, + n: int, + alpha_pre: float, + alpha_post: float, + alpha_res: float, + bias: torch.Tensor, + **kwargs, +): + cfg = kwargs.pop("cfg", None) + kwargs.pop("w_nt", None) + w = w.contiguous() + + M, K = x.shape + _, N = w.shape + y_acc, r_acc, cfg = cutile_autotune_mhc_split_gemm_rms( + torch.cuda.current_stream(), + x, + w, + M, + N, + K, + cfg=cfg, + ) + return mhc_finalize_scale_bias_sigmoid( + y_acc, + r_acc, + n, + alpha_pre, + alpha_post, + alpha_res, + bias, + M, + K, + cfg=cfg, + ) + + +@ct.kernel +def mhc_apply_residual_kernel( + X, + F_out, + Y_post, + Y_res, + Out, + C: int, + n: ct.Constant[int], + TILE_SIZE_C: ConstInt, +): + """Apply H_res and H_post to residual stream (in-place on Out).""" + # Shapes: + # - X: [B, n, C] view of residual stream + # - F_out: [B, C] + # - Y_post: [B, n] + # - Y_res: [B, n, n] + # - Out: [B, n, C] + row = ct.bid(0) + c_tile = ct.bid(1) + compute_dtype = ( + ct.float32 if (X.dtype == ct.float32 or F_out.dtype == ct.float32 or Y_post.dtype == ct.float32) else X.dtype + ) + + f_tile = ct.load( + F_out, + index=(row, c_tile), + shape=(1, TILE_SIZE_C), + padding_mode=ct.PaddingMode.ZERO, + ) + f_tile = ct.astype(f_tile, compute_dtype) + + h_post = ct.load( + Y_post, + index=(row, 0), + shape=(1, n), + padding_mode=ct.PaddingMode.ZERO, + ) + h_post = ct.reshape(h_post, (n, 1)) + h_post = ct.astype(h_post, compute_dtype) + + h_res = ct.load( + Y_res, + index=(row, 0, 0), + 
shape=(1, n, n), + padding_mode=ct.PaddingMode.ZERO, + ) + h_res = ct.reshape(h_res, (n, n)) + h_res = ct.astype(h_res, compute_dtype) + + acc = ct.full((n, TILE_SIZE_C), 0.0, dtype=compute_dtype) + for j in range(n): + x_row = ct.load( + X, + index=(row, j, c_tile), + shape=(1, 1, TILE_SIZE_C), + padding_mode=ct.PaddingMode.ZERO, + ) + x_row = ct.reshape(x_row, (1, TILE_SIZE_C)) + x_row = ct.astype(x_row, compute_dtype) + h_col = ct.extract(h_res, (0, j), shape=(n, 1)) + x_row = ct.broadcast_to(x_row, (n, TILE_SIZE_C)) + h_col = ct.broadcast_to(h_col, (n, TILE_SIZE_C)) + prod = h_col * x_row + acc = acc + prod + h_post = ct.broadcast_to(h_post, (n, TILE_SIZE_C)) + f_tile = ct.broadcast_to(f_tile, (n, TILE_SIZE_C)) + x_post = h_post * f_tile + out_tile = acc + x_post + out_tile = ct.astype(out_tile, Out.dtype) + out_tile = ct.reshape(out_tile, (1, n, TILE_SIZE_C)) + ct.store(Out, index=(row, 0, c_tile), tile=out_tile) + + +@register_impl("mhc_apply_residual", backend="cutile") +def mhc_apply_residual( + x: torch.Tensor, + f_out: torch.Tensor, + y: torch.Tensor, + n: int, + **kwargs, +): + x = x.contiguous() + f_out = f_out.contiguous() + y = y.contiguous() + B, nC = x.shape + C = f_out.shape[1] + # Use view for [B, n, C] without changing external layout. + x_view = x.view(B, n, C) + y_post = y.narrow(1, n, n) + y_res = y.narrow(1, 2 * n, n * n).view(B, n, n) + out = torch.empty_like(x) + out_view = out.view(B, n, C) + + TILE_SIZE_C = 1024 + grid = (B, C // TILE_SIZE_C, 1) + ct.launch( + torch.cuda.current_stream(), + grid, + mhc_apply_residual_kernel, + ( + x_view, + f_out, + y_post, + y_res, + out_view, + C, + n, + TILE_SIZE_C, + ), + ) + return out + + +@ct.kernel +def mhc_sinkhorn_kernel( + Y, + n: ct.Constant[int], +): + """Sinkhorn-Knopp normalization for residual block (in-place on Y).""" + row = ct.bid(0) + total = n * n + mat = ct.load(Y, index=(row, 0), shape=(1, total)) + mat = ct.reshape(mat, (n, n)) + mat = ct.astype(mat, ct.float32) + mat = ct.exp2(mat * LOG2E) + + for _ in range(20): + row_sum = ct.sum(mat, axis=1, keepdims=True) + mat = ct.truediv(mat, row_sum) + col_sum = ct.sum(mat, axis=0, keepdims=True) + mat = ct.truediv(mat, col_sum) + + mat = ct.reshape(mat, (1, total)) + mat = ct.astype(mat, Y.dtype) + ct.store(Y, index=(row, 0), tile=mat) + + +@register_impl("mhc_sinkhorn", backend="cutile") +def mhc_sinkhorn( + y: torch.Tensor, + n: int, + **kwargs, +): + y = y.contiguous() + M, _ = y.shape + y_view = y.narrow(1, 2 * n, n * n) + grid = (M,) + ct.launch( + torch.cuda.current_stream(), + grid, + mhc_sinkhorn_kernel, + ( + y_view, + n, + ), + ) + return y diff --git a/src/tilegym/ops/ops.py b/src/tilegym/ops/ops.py index 9f45a35..aee39f1 100644 --- a/src/tilegym/ops/ops.py +++ b/src/tilegym/ops/ops.py @@ -417,6 +417,86 @@ def splitk_reduce( raise NotImplementedError(f"splitk_reduce is not implemented for {get_current_backend()}") +@dispatch( + "mhc_gemm_rms_scale", +) +def mhc_gemm_rms_scale( + x: torch.Tensor, + w: torch.Tensor, + n: int, + alpha_pre: float, + alpha_post: float, + alpha_res: float, + bias: torch.Tensor, + **kwargs: Any, +): + """ + GEMM + RMS reduce + scale/bias/sigmoid for mHC. 
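+
+    Reference semantics (matching the torch implementations exercised in
+    tests/ops/test_mhc.py, with rms(X) = sqrt(mean(X**2, dim=1, keepdim=True))
+    and scale taking the value alpha_pre on the first n columns, alpha_post on
+    the next n, and alpha_res on the remaining n**2):
+
+        linear = (X @ W) * scale / rms(X) + bias
+        Y[:, :n]     = sigmoid(linear[:, :n])
+        Y[:, n:2*n]  = 2 * sigmoid(linear[:, n:2*n])
+        Y[:, 2*n:]   = linear[:, 2*n:]
+        R = rms(X)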
+ + Args: + x: Input matrix X (M, K) + w: Weight matrix W (K, N) + n: Expansion factor + alpha_pre: Scalar for pre mixing + alpha_post: Scalar for post mixing + alpha_res: Scalar for residual mixing + bias: Bias vector of shape (N,) + **kwargs: Additional arguments for backend-specific configurations + + Returns: + Tuple[torch.Tensor, torch.Tensor]: (Y, R) + """ + raise NotImplementedError(f"mhc_gemm_rms_scale is not implemented for {get_current_backend()}") + + +@dispatch( + "mhc_apply_residual", +) +def mhc_apply_residual( + x: torch.Tensor, + f_out: torch.Tensor, + y: torch.Tensor, + n: int, + **kwargs: Any, +): + """ + Apply H_res and H_post to residual stream. + + Args: + x: Input tensor X with shape (B, nC) + f_out: Output tensor from block with shape (B, C) + y: Coefficients tensor with shape (B, n^2 + 2n) + n: Expansion factor + **kwargs: Additional arguments for backend-specific configurations + + Returns: + torch.Tensor: Output tensor with shape (B, nC) + """ + raise NotImplementedError(f"mhc_apply_residual is not implemented for {get_current_backend()}") + + +@dispatch( + "mhc_sinkhorn", +) +def mhc_sinkhorn( + y: torch.Tensor, + n: int, + **kwargs: Any, +): + """ + Sinkhorn-Knopp normalization for residual block (in-place on Y). + + Args: + y: Input/output matrix Y (M, N), modified in-place + n: Expansion factor + **kwargs: Additional arguments for backend-specific configurations + + Returns: + torch.Tensor: Output matrix (M, N) + """ + raise NotImplementedError(f"mhc_sinkhorn is not implemented for {get_current_backend()}") + + # ============================================================================ # Linear Algebra Operations # ============================================================================ diff --git a/tests/benchmark/bench_mhc.py b/tests/benchmark/bench_mhc.py new file mode 100644 index 0000000..8aa5bfc --- /dev/null +++ b/tests/benchmark/bench_mhc.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: MIT + +import torch +import triton + +import tilegym +from tilegym.backend import is_backend_available +from tilegym.backend import register_impl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +X_DTYPE = torch.bfloat16 +W_DTYPE = torch.float32 +OTHER_DTYPE = torch.float32 + + +def _is_fp8_dtype(dtype): + return "float8" in str(dtype) + + +def _get_benchmark_dtypes(): + dtypes = [X_DTYPE] + return dtypes + + +def _randn(shape, dtype, device): + if _is_fp8_dtype(dtype): + return torch.randn(shape, device=device, dtype=torch.float16).to(dtype) + return torch.randn(shape, device=device, dtype=dtype) + + +BENCH_DTYPES = _get_benchmark_dtypes() + + +def reference_mhc_apply_residual( + x: torch.Tensor, + f_out: torch.Tensor, + y: torch.Tensor, + n: int, +): + B, nC = x.shape + C = nC // n + x_view = x.view(B, n, C) + f_view = f_out + h_post = y[:, n : 2 * n] + h_res = y[:, 2 * n : 2 * n + n * n].view(B, n, n) + x_res = torch.matmul(h_res, x_view) + x_post = h_post.unsqueeze(-1) * f_view.unsqueeze(1) + x_next = x_res + x_post + return x_next.to(x.dtype).view(B, nC) + + +def reference_mhc_gemm_rms_scale( + x: torch.Tensor, + w: torch.Tensor, + n: int, + alpha_pre: float, + alpha_post: float, + alpha_res: float, + bias: torch.Tensor, + w_nt: torch.Tensor = None, +): + x_fp32 = x.to(torch.float32) + if w_nt is not None: + w_nt_fp32 = w_nt.to(torch.float32) + y = x_fp32 @ w_nt_fp32.transpose(0, 1) + else: + w_fp32 = w.to(torch.float32) + y = x_fp32 @ w_fp32 + rms = torch.sqrt(x_fp32.pow(2).mean(dim=1, keepdim=True)) + scale = torch.empty((y.shape[1],), dtype=torch.float32, device=y.device) + scale[:n] = alpha_pre + scale[n : 2 * n] = alpha_post + scale[2 * n :] = alpha_res + linear = y * scale / rms + bias.to(torch.float32) + out = linear.clone() + out[:, :n] = torch.sigmoid(linear[:, :n]) + out[:, n : 2 * n] = 2.0 * torch.sigmoid(linear[:, n : 2 * n]) + return out.to(bias.dtype), rms + + +def reference_mhc_sinkhorn( + y: torch.Tensor, + n: int, +): + start = 2 * n + end = start + n * n + mat = y[:, start:end].to(torch.float32).reshape(-1, n, n) + mat = torch.exp(mat) + for _ in range(20): + mat = mat / mat.sum(dim=2, keepdim=True) + mat = mat / mat.sum(dim=1, keepdim=True) + y[:, start:end] = mat.reshape(-1, n * n).to(y.dtype) + return y + + +register_impl("mhc_apply_residual", "torch")(reference_mhc_apply_residual) +register_impl("mhc_gemm_rms_scale", "torch")(reference_mhc_gemm_rms_scale) +register_impl("mhc_sinkhorn", "torch")(reference_mhc_sinkhorn) + + +ALL_BACKENDS = [ + ("cutile", "CuTile", ("orange", "-")) if is_backend_available("cutile") else None, + ("torch", "PyTorch", ("green", "-")), +] + +# Try to add deepgemm backend if available +try: + import deep_gemm + + DEEPGEMM_BACKEND = ("deepgemm", "DeepGemm", ("blue", "-.")) + DEEPGEMM_AVAILABLE = True +except ImportError: + DEEPGEMM_BACKEND = None + DEEPGEMM_AVAILABLE = False + + +def get_supported_backends(): + return [p for p in ALL_BACKENDS if p is not None] + + +def get_supported_backends_split_gemm_rms(): + backends = [ + ("cutile", "CuTile", ("orange", "-")) if is_backend_available("cutile") else None, + ("torch", "PyTorch", ("green", "-")), + DEEPGEMM_BACKEND if DEEPGEMM_AVAILABLE else None, + ] + return [p for p in backends if p is not None] + + +def create_gemm_rms_scale_benchmark_config(dtype): + available_backends = get_supported_backends() + if not available_backends: + return None + + backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] + 
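+    # Problem sizes follow the mHC projection: K = n * 7168 is the flattened
+    # n-way residual width and N = n**2 + 2*n packs the mixing coefficients,
+    # so N is tiny (24 for n = 4) and the GEMM does little work per byte of X
+    # read; the benchmark therefore reports GB/s rather than TFLOP/s.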
n = 4 + return triton.testing.Benchmark( + x_names=["M"], + x_vals=[8192], + line_arg="backend", + line_vals=list(backends), + line_names=list(names), + styles=list(styles), + ylabel="GB/s", + plot_name=f"mhc-gemm-rms-scale-performance-{dtype_name}-GBps", + args={ + "dtype": dtype, + "K": n * 7168, + "N": n * n + 2 * n, + "n": n, + }, + ) + + +def create_sinkhorn_benchmark_config(dtype): + available_backends = get_supported_backends() + if not available_backends: + return None + + backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] + n = 4 + return triton.testing.Benchmark( + x_names=["M"], + x_vals=[8192], + line_arg="backend", + line_vals=list(backends), + line_names=list(names), + styles=list(styles), + ylabel="GB/s", + plot_name=f"mhc-sinkhorn-performance-{dtype_name}-GBps", + args={ + "dtype": dtype, + "N": n * n + 2 * n, + "n": n, + }, + ) + + +def create_apply_residual_benchmark_config(dtype): + available_backends = get_supported_backends() + if not available_backends: + return None + + backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] + n = 4 + return triton.testing.Benchmark( + x_names=["M"], + x_vals=[8192], + line_arg="backend", + line_vals=list(backends), + line_names=list(names), + styles=list(styles), + ylabel="GB/s", + plot_name=f"mhc-apply-residual-performance-{dtype_name}-GBps", + args={ + "dtype": dtype, + "C": 7168, + "n": n, + }, + ) + + +def create_split_gemm_rms_benchmark_config(dtype): + available_backends = get_supported_backends_split_gemm_rms() + if not available_backends: + return None + + backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] + n = 4 + return triton.testing.Benchmark( + x_names=["M"], + x_vals=[8192], + line_arg="backend", + line_vals=list(backends), + line_names=list(names), + styles=list(styles), + ylabel="GB/s", + plot_name=f"mhc-split-gemm-rms-performance-{dtype_name}-GBps", + args={ + "dtype": dtype, + "K": n * 7168, + "N": n * n + 2 * n, + "n": n, + }, + y_log=False, + ) + + +@triton.testing.perf_report([create_split_gemm_rms_benchmark_config(dtype) for dtype in BENCH_DTYPES]) +def bench_mhc_split_gemm_rms(M, backend, dtype, K, N, n, device=DEVICE): + """Benchmark for split GEMM + RMS kernel only (compared with deepgemm)""" + x = _randn((M, K), dtype=X_DTYPE, device=device) + w = _randn((K, N), dtype=W_DTYPE, device=device) + w_nt = w.transpose(0, 1).contiguous() + + cfg = None + if is_backend_available("cutile"): + from tilegym.ops.cutile import mhc as mhc_cutile + + _, _, cfg = mhc_cutile.cutile_autotune_mhc_split_gemm_rms( + torch.cuda.current_stream(), + x, + w, + M, + N, + K, + ) + + if backend == "deepgemm": + if not DEEPGEMM_AVAILABLE: + return 0.0 + + import deep_gemm + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + split_k = 4 + dg_d = torch.empty((split_k, M, N), dtype=torch.float32, device=device) + dg_s = torch.empty((split_k, M), dtype=torch.float32, device=device) + + fn = lambda: deep_gemm.tf32_hc_prenorm_gemm(x, w_nt, dg_d, dg_s, num_splits=split_k) + ms = triton.testing.do_bench_cudagraph(fn) + + # Output bytes: dg_d and dg_s (split outputs) + x_bytes = M * K * x.element_size() + w_bytes = K * N * w.element_size() + out_d_bytes = split_k * M * N * 4 # float32 + out_s_bytes = split_k * M * 4 # float32 + total_bytes = x_bytes + w_bytes + out_d_bytes + out_s_bytes + gb_per_s = total_bytes * 1e-9 / (ms * 1e-3) + return gb_per_s + + if backend == "torch": + if cfg is 
None: + split_k = 4 + tile_size_n = 128 + else: + split_k = cfg.SPLIT_K + tile_size_n = cfg.TILE_SIZE_N + + num_bid_n = triton.cdiv(N, tile_size_n) + y_acc = torch.empty((split_k, M, N), device=x.device, dtype=torch.float32) + r_acc = torch.empty((split_k, M, num_bid_n), device=x.device, dtype=torch.float32) + + k_per_split = triton.cdiv(K, split_k) + x_fp32 = x.to(torch.float32) + w_fp32 = w.to(torch.float32) + + def run(): + for s in range(split_k): + k_start = s * k_per_split + k_end = min(k_start + k_per_split, K) + x_s = x_fp32[:, k_start:k_end] + w_s = w_fp32[k_start:k_end, :] + y_acc[s].copy_(x_s @ w_s) + rms_partial = (x_s * x_s).sum(dim=1, keepdim=True) # [M, 1] + r_acc[s].copy_(rms_partial.expand(M, num_bid_n)) + + ms = triton.testing.do_bench_cudagraph(run) + total_bytes = M * K * x.element_size() + K * N * w.element_size() + y_acc.numel() * 4 + r_acc.numel() * 4 + gb_per_s = total_bytes * 1e-9 / (ms * 1e-3) + return gb_per_s + + elif backend == "cutile": + import cuda.tile as ct + + from tilegym.ops.cutile import mhc as mhc_cutile + + y_acc = torch.empty((M * cfg.SPLIT_K, N), device=x.device, dtype=torch.float32) + r_acc = torch.empty((M * cfg.SPLIT_K, triton.cdiv(N, cfg.TILE_SIZE_N)), device=x.device, dtype=torch.float32) + + grid = ( + triton.cdiv(M, cfg.TILE_SIZE_M) * triton.cdiv(N, cfg.TILE_SIZE_N), + cfg.SPLIT_K, + 1, + ) + fn = lambda: ct.launch( + torch.cuda.current_stream(), + grid, + mhc_cutile.mhc_split_gemm_rms_kernel, + ( + x, + w, + y_acc, + r_acc, + M, + N, + K, + cfg.TILE_SIZE_M, + cfg.TILE_SIZE_N, + cfg.TILE_SIZE_K, + cfg.SPLIT_K, + cfg.GROUP_SIZE_M, + ), + ) + + ms = triton.testing.do_bench_cudagraph(fn) + + # Calculate bandwidth + x_bytes = M * K * x.element_size() + w_bytes = K * N * w.element_size() + y_acc_bytes = M * cfg.SPLIT_K * N * 4 # float32 + r_acc_bytes = r_acc.numel() * 4 # float32 + total_bytes = x_bytes + w_bytes + y_acc_bytes + r_acc_bytes + gb_per_s = total_bytes * 1e-9 / (ms * 1e-3) + return gb_per_s + + return 0.0 + + +@triton.testing.perf_report([create_gemm_rms_scale_benchmark_config(dtype) for dtype in BENCH_DTYPES]) +def bench_mhc_gemm_rms_scale(M, backend, dtype, K, N, n, device=DEVICE): + """Benchmark for full MHC GEMM+RMS+Scale operation (compared with torch)""" + x = _randn((M, K), dtype=dtype, device=device) + w = _randn((K, N), dtype=W_DTYPE, device=device) + bias = _randn((N,), dtype=OTHER_DTYPE, device=device) + alpha_pre = 0.8 + alpha_post = 1.1 + alpha_res = 0.9 + + kwargs = {} + if backend == "cutile": + from tilegym.ops.cutile import mhc as mhc_cutile + + _, _, cfg = mhc_cutile.cutile_autotune_mhc_split_gemm_rms( + torch.cuda.current_stream(), + x, + w, + M, + N, + K, + ) + kwargs["cfg"] = cfg + + fn = lambda: tilegym.ops.mhc_gemm_rms_scale( + x, w, n, alpha_pre, alpha_post, alpha_res, bias, backend=backend, **kwargs + ) + + ms = triton.testing.do_bench_cudagraph(fn) + + # Calculate bandwidth (GB/s) + # Input: x (M*K), w (K*N), bias (N) + # Output: y (M*N), r (M*1) + x_bytes = M * K * x.element_size() + w_bytes = K * N * w.element_size() + bias_bytes = N * bias.element_size() + y_bytes = M * N * bias.element_size() # output dtype same as bias + r_bytes = M * 1 * 4 # r is float32 + total_bytes = x_bytes + w_bytes + bias_bytes + y_bytes + r_bytes + gb_per_s = total_bytes * 1e-9 / (ms * 1e-3) + return gb_per_s + + +@triton.testing.perf_report([create_sinkhorn_benchmark_config(dtype) for dtype in BENCH_DTYPES]) +def bench_mhc_sinkhorn(M, backend, dtype, N, n, device=DEVICE): + """Benchmark for MHC Sinkhorn operation (compared with 
torch)""" + y = _randn((M, N), dtype=dtype, device=device) + + y_test = y.clone() + tilegym.ops.mhc_sinkhorn(y_test, n, backend=backend) + + fn = lambda: tilegym.ops.mhc_sinkhorn(y_test, n, backend=backend) + ms = triton.testing.do_bench_cudagraph(fn) + + bytes_per_row = n * n * y.element_size() + total_bytes = y.shape[0] * bytes_per_row * 2 + gb_per_s = total_bytes * 1e-9 / (ms * 1e-3) + return gb_per_s + + +@triton.testing.perf_report([create_apply_residual_benchmark_config(dtype) for dtype in BENCH_DTYPES]) +def bench_mhc_apply_residual(M, backend, dtype, C, n, device=DEVICE): + """Benchmark for MHC apply residual operation (compared with torch)""" + nC = n * C + out_n = n * n + 2 * n + x = _randn((M, nC), dtype=dtype, device=device) + f_out = _randn((M, C), dtype=dtype, device=device) + y = _randn((M, out_n), dtype=dtype, device=device) + + fn = lambda: tilegym.ops.mhc_apply_residual(x, f_out, y, n, backend=backend) + out = fn() + ms = triton.testing.do_bench_cudagraph(fn) + + total_bytes = ( + x.numel() * x.element_size() + + f_out.numel() * f_out.element_size() + + y.numel() * y.element_size() + + out.numel() * out.element_size() + ) + gb_per_s = total_bytes * 1e-9 / (ms * 1e-3) + return gb_per_s + + +if __name__ == "__main__": + bench_mhc_split_gemm_rms.run(print_data=True) + bench_mhc_gemm_rms_scale.run(print_data=True) + bench_mhc_sinkhorn.run(print_data=True) + bench_mhc_apply_residual.run(print_data=True) diff --git a/tests/ops/test_mhc.py b/tests/ops/test_mhc.py new file mode 100644 index 0000000..edb543e --- /dev/null +++ b/tests/ops/test_mhc.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import pytest +import torch + +import tilegym +from tests import common + +X_DTYPE = torch.bfloat16 +W_DTYPE = torch.float32 +OTHER_DTYPE = torch.float32 + + +class Test_MHC(common.PyTestCase): + @staticmethod + def reference_gemm_rmsnorm(x, w): + x_fp32 = x.to(torch.float32) + w_fp32 = w.to(torch.float32) + y = x_fp32 @ w_fp32 + rms = torch.sqrt(x_fp32.pow(2).mean(dim=1, keepdim=True)) + return y, rms + + @staticmethod + def reference_scale_bias_sigmoid(y, r, n, alpha_pre, alpha_post, alpha_res, bias): + scale = torch.empty((y.shape[1],), dtype=torch.float32, device=y.device) + scale[:n] = alpha_pre + scale[n : 2 * n] = alpha_post + scale[2 * n :] = alpha_res + bias = bias.to(torch.float32) + linear = y.to(torch.float32) * scale / r + bias + out = linear.clone() + out[:, :n] = torch.sigmoid(linear[:, :n]) + out[:, n : 2 * n] = 2.0 * torch.sigmoid(linear[:, n : 2 * n]) + return out.to(y.dtype) + + @staticmethod + def reference_sinkhorn(y, n): + y_out = y.clone() + start = 2 * n + end = start + n * n + mat = y_out[:, start:end].to(torch.float32).reshape(-1, n, n) + mat = torch.exp(mat) + for _ in range(20): + mat = mat / mat.sum(dim=2, keepdim=True) + mat = mat / mat.sum(dim=1, keepdim=True) + y_out[:, start:end] = mat.reshape(-1, n * n).to(y.dtype) + return y_out + + @staticmethod + def reference_apply_residual(x, f_out, y, n): + B, nC = x.shape + C = nC // n + x_view = x.view(B, n, C).to(torch.float32) + f_view = f_out + h_post = y[:, n : 2 * n] + h_res = y[:, 2 * n : 2 * n + n * n].view(B, n, n) + x_res = torch.matmul(h_res, x_view) + x_post = h_post.unsqueeze(-1) * f_view.unsqueeze(1) + x_next = x_res + x_post + return x_next.to(x.dtype).view(B, nC) + + _backends = ["cutile"] + + @pytest.mark.parametrize( + "m, k, n", + [ + (128, 1024, 4), + (128, 1024, 8), + (128, 1000, 8), + ], 
+ ) + @pytest.mark.parametrize("backend", _backends) + def test_op_mhc_gemm_rms_scale_bf16_precision(self, m, k, n, backend): + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + self.setUp() + device = torch.device("cuda") + out_n = n * n + 2 * n + + x = torch.randn((m, k), dtype=X_DTYPE, device=device, requires_grad=False) + w = torch.randn((k, out_n), dtype=W_DTYPE, device=device, requires_grad=False) + bias = torch.randn((out_n,), dtype=OTHER_DTYPE, device=device, requires_grad=False) + alpha_pre = 0.8 + alpha_post = 1.1 + alpha_res = 0.9 + + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True + y_ref, r_ref = self.reference_gemm_rmsnorm(x, w) + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + y_ref = self.reference_scale_bias_sigmoid(y_ref, r_ref, n, alpha_pre, alpha_post, alpha_res, bias) + y_out, r_out = tilegym.ops.mhc_gemm_rms_scale( + x, + w, + n, + alpha_pre, + alpha_post, + alpha_res, + bias, + ) + torch.testing.assert_close(y_out, y_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(r_out, r_ref, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize( + "m, n, dtype", + [ + (256, 4, OTHER_DTYPE), + (256, 8, OTHER_DTYPE), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_mhc_sinkhorn(self, m, n, dtype, backend): + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + self.setUp() + device = torch.device("cuda") + out_n = n * n + 2 * n + + y = torch.randn((m, out_n), dtype=dtype, device=device, requires_grad=False) + y_ref = self.reference_sinkhorn(y, n) + y_test = y.clone() + y_out = tilegym.ops.mhc_sinkhorn(y_test, n) + tol = common.get_dtype_tolerances(dtype) + out_close, msg = common.compare_tensors(y_out, y_ref, rtol=tol["rtol"], atol=tol["atol"]) + assert out_close, "\n".join(msg) + + @pytest.mark.parametrize( + "m, n, c", + [ + (128, 4, 1024), + (64, 8, 2048), + (128, 2, 1024), + ], + ) + @pytest.mark.parametrize("backend", _backends) + def test_op_mhc_apply_residual(self, m, n, c, backend): + if tilegym.is_backend_available(backend): + tilegym.set_backend(backend) + else: + pytest.skip(f"Backend {backend} is not available") + + self.setUp() + device = torch.device("cuda") + out_n = n * n + 2 * n + nC = n * c + + x = torch.randn((m, nC), dtype=X_DTYPE, device=device, requires_grad=False) + f_out = torch.randn((m, c), dtype=X_DTYPE, device=device, requires_grad=False) + y = torch.randn((m, out_n), dtype=OTHER_DTYPE, device=device, requires_grad=False) + + y_ref = self.reference_apply_residual(x, f_out, y, n) + y_out = tilegym.ops.mhc_apply_residual(x, f_out, y, n) + tol = common.get_dtype_tolerances(X_DTYPE) + out_close, msg = common.compare_tensors(y_out, y_ref, rtol=tol["rtol"], atol=tol["atol"]) + assert out_close, "\n".join(msg)
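
A minimal end-to-end sketch of how the three new ops compose (hypothetical
sizes; the call sequence mirrors the reference implementations in
tests/ops/test_mhc.py and is illustration only, not part of the diff):

    import torch

    import tilegym

    tilegym.set_backend("cutile")

    B, C, n = 128, 1024, 4        # rows, hidden size, mHC expansion factor
    K, N = n * C, n * n + 2 * n   # flattened residual width, coefficient width
    device = torch.device("cuda")

    x = torch.randn(B, K, dtype=torch.bfloat16, device=device)  # residual stream
    w = torch.randn(K, N, dtype=torch.float32, device=device)   # mixing projection
    bias = torch.randn(N, dtype=torch.float32, device=device)

    # 1) Fused projection: split-K GEMM + RMS normalization + scale/bias/sigmoid.
    y, r = tilegym.ops.mhc_gemm_rms_scale(x, w, n, 0.8, 1.1, 0.9, bias)

    # 2) Sinkhorn-Knopp on the n x n residual block (columns 2n : 2n + n*n of y).
    y = tilegym.ops.mhc_sinkhorn(y, n)

    # 3) Residual mixing: out = H_res @ x_view + h_post * f_out, flattened back.
    f_out = torch.randn(B, C, dtype=torch.bfloat16, device=device)
    out = tilegym.ops.mhc_apply_residual(x, f_out, y, n)  # shape (B, n * C)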