From 08d16f55a0dc82a28d69beaa18e859429ecbeb51 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 6 Nov 2025 03:13:20 +0000 Subject: [PATCH 01/12] add triton_support_tensor_descriptor --- lightllm/utils/device_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 66a25e93b..9438587ac 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -227,3 +227,15 @@ def set_sm_limit(percent: int, gpu_index=0): os.environ["CUDA_MPS_ACTIVE_THREAD_PERCENTAGE"] = str(percent) logger.info(f"Set CUDA_MPS_ACTIVE_THREAD_PERCENTAGE to {percent}% for GPU {gpu_index}.") return True + + +@lru_cache(maxsize=None) +def triton_support_tensor_descriptor(): + try: + from triton.tools.tensor_descriptor import TensorDescriptor + + logger.info("triton support tensor_descriptor") + return True + except: + logger.info("triton not support tensor_descriptor") + return False From 2760f6bf4653474271104ad8a3361210684d918d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 6 Nov 2025 07:15:11 +0000 Subject: [PATCH 02/12] add mblocks_to_tuple_info infos --- .../common/fused_moe/grouped_fused_moe.py | 43 +++++++++++-------- lightllm/utils/device_utils.py | 2 +- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 600e8b84a..d07926664 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -24,12 +24,7 @@ from typing import Any, Callable, Dict, Optional, Tuple from lightllm.utils.log_utils import init_logger from lightllm.utils.vllm_utils import vllm_ops -from lightllm.utils.device_utils import ( - get_device_sm_count, - get_device_sm_regs_num, - get_device_sm_shared_mem_num, - get_device_warp_size, -) +from lightllm.utils.device_utils import triton_support_tensor_descriptor from .moe_kernel_configs import MoeGroupedGemmKernelConfig from .moe_silu_and_mul import silu_and_mul_fwd from .moe_sum_reduce import moe_sum_reduce @@ -307,8 +302,8 @@ def moe_align_fused( @triton.jit def moe_align2_kernel( experts_token_num_ptr, # [expert_num,] - mblocks_to_expert_id, # [max_num_m_blocks,] - mblocks_to_m_index, # [max_num_m_blocks,] + mblocks_to_tuple_info, # [max_num_m_blocks, 3], tuple for (expert_id, m_index, token_start_index) + mblocks_to_tuple_info_stride_0, expert_num, max_num_m_blocks, BLOCK_M: tl.constexpr, @@ -318,6 +313,10 @@ def moe_align2_kernel( expert_id = tl.program_id(axis=0) off_expert = tl.arange(0, BLOCK_EXPERT) expert_to_token_num = tl.load(experts_token_num_ptr + off_expert, mask=off_expert < expert_num, other=0) + token_start_index = tl.sum( + tl.where(off_expert == expert_id, tl.cumsum(expert_to_token_num) - expert_to_token_num, 0) + ) + expert_to_block_num = tl.cdiv(expert_to_token_num, BLOCK_M) block_starts = tl.cumsum(expert_to_block_num) - expert_to_block_num block_start = tl.sum(tl.where(off_expert == expert_id, block_starts, 0)) @@ -328,20 +327,25 @@ def moe_align2_kernel( block_off = tl.arange(0, 128) for start_loc in range(0, cur_block_num, 128): tl.store( - mblocks_to_expert_id + block_start + start_loc + block_off, + mblocks_to_tuple_info + (block_start + start_loc + block_off) * mblocks_to_tuple_info_stride_0 + 0, expert_id, mask=start_loc + block_off < cur_block_num, ) tl.store( - mblocks_to_m_index + block_start + start_loc + block_off, + mblocks_to_tuple_info + (block_start + start_loc + block_off) * 
mblocks_to_tuple_info_stride_0 + 1, start_loc + block_off, mask=start_loc + block_off < cur_block_num, ) + tl.store( + mblocks_to_tuple_info + (block_start + start_loc + block_off) * mblocks_to_tuple_info_stride_0 + 2, + token_start_index + (start_loc + block_off) * BLOCK_M, + mask=start_loc + block_off < cur_block_num, + ) if expert_id == expert_num - 1: for extra_fill_start in range(block_start + cur_block_num, max_num_m_blocks, 128): tl.store( - mblocks_to_expert_id + extra_fill_start + block_off, + mblocks_to_tuple_info + (extra_fill_start + block_off) * mblocks_to_tuple_info_stride_0 + 0, -1, mask=extra_fill_start + block_off < max_num_m_blocks, ) @@ -355,24 +359,25 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo """ max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m) - mblocks_to_expert_id = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") - mblocks_to_m_index = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") + # first is expert, second is m_index, third is token_start_index + mblocks_to_tuple_info = torch.empty((max_num_m_blocks, 3), dtype=torch.int32, device="cuda") + expert_num = exports_token_num.shape[0] grid = (expert_num,) moe_align2_kernel[grid]( - exports_token_num, - mblocks_to_expert_id, - mblocks_to_m_index, - expert_num, - max_num_m_blocks, + experts_token_num_ptr=exports_token_num, + mblocks_to_tuple_info=mblocks_to_tuple_info, + mblocks_to_tuple_info_stride_0=mblocks_to_tuple_info.stride(0), + expert_num=expert_num, + max_num_m_blocks=max_num_m_blocks, BLOCK_M=block_m, BLOCK_EXPERT=triton.next_power_of_2(expert_num), num_warps=4, num_stages=1, ) - return mblocks_to_expert_id, mblocks_to_m_index + return mblocks_to_tuple_info @triton.jit diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 9438587ac..6f848f381 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -230,7 +230,7 @@ def set_sm_limit(percent: int, gpu_index=0): @lru_cache(maxsize=None) -def triton_support_tensor_descriptor(): +def triton_support_tensor_descriptor() -> bool: try: from triton.tools.tensor_descriptor import TensorDescriptor From 90d53959d13ad6054038ba3ccb90329821fe479f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 6 Nov 2025 07:37:03 +0000 Subject: [PATCH 03/12] fix moe call --- .../common/fused_moe/grouped_fused_moe.py | 83 ++++++++++--------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index d07926664..421a271cf 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -382,8 +382,8 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo @triton.jit def grouped_matmul_kernel( - mblocks_to_expert_id, # [max_m_block_size] - mblocks_to_m_index, # [max_m_block_size] + mblocks_to_tuple_info, # [max_m_block_size, 3] tuple for (expert_id, m_index, token_start_index) + mblocks_to_tuple_info_stride_0, # int k, # int n, # int topk_num, # int @@ -444,14 +444,17 @@ def grouped_matmul_kernel( pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m - expert_id = tl.load(mblocks_to_expert_id + pid_m) + expert_id = tl.load(mblocks_to_tuple_info + pid_m * 
mblocks_to_tuple_info_stride_0 + 0) if expert_id == -1: return - tile_m_idx = tl.load(mblocks_to_m_index + pid_m) + tile_m_idx = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 1) tile_n_idx = pid_n + # get token start index in inputs + # token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2) + # get the gemm size of the current problem cur_m = tl.load(expert_to_token_num + expert_id) @@ -693,59 +696,57 @@ def grouped_matmul( ) if reused_mblock_infos is None: - mblocks_to_expert_id, mblocks_to_m_index = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M) + mblocks_to_tuple_info = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M) else: # when up group gemm and down group gemm use same BLOCK_SIZE_M, # can reuse (mblocks_to_expert_id, mblocks_to_m_index) created by moe_align2 kernel. - mblocks_to_expert_id, mblocks_to_m_index, reused_block_size_m = reused_mblock_infos + mblocks_to_tuple_info, reused_block_size_m = reused_mblock_infos if reused_block_size_m != BLOCK_SIZE_M: - mblocks_to_expert_id, mblocks_to_m_index = moe_align2( - token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M - ) + mblocks_to_tuple_info = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M) - block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0] + block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_tuple_info.shape[0] grid = (block_num,) NEED_K_MASK = (k % BLOCK_SIZE_K) != 0 grouped_matmul_kernel[grid]( - mblocks_to_expert_id, - mblocks_to_m_index, - k, - n, - topk_num, - token_input_scale, - expert_to_weights_scale, - expert_to_weights_scale.stride(0) + mblocks_to_tuple_info=mblocks_to_tuple_info, + mblocks_to_tuple_info_stride_0=mblocks_to_tuple_info.stride(0), + k=k, + n=n, + topk_num=topk_num, + token_scale_ptr=token_input_scale, + weight_scale_ptr=expert_to_weights_scale, + weight_scale_stride0=expert_to_weights_scale.stride(0) if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1 else 0, - expert_to_weights_scale.stride(1) + weight_scale_stride1=expert_to_weights_scale.stride(1) if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2 else 0, - expert_to_weights_scale.stride(2) + weight_scale_stride2=expert_to_weights_scale.stride(2) if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3 else 0, - token_inputs, - token_inputs.stride(0), - token_inputs.stride(1), - expert_weights, - expert_weights.stride(0), - expert_weights.stride(1), - expert_weights.stride(2), - bias, - bias.stride(0) if bias is not None else 0, - bias.stride(1) if bias is not None and bias.ndim >= 2 else 0, - expert_to_weights, - expert_to_weights.stride(0), - expert_to_weights.stride(1), - expert_to_token_num, - expert_to_token_index, - expert_to_token_index.stride(0), - out, - out.stride(0), - out.stride(1), - m_block_num=mblocks_to_expert_id.shape[0], + token_ptr=token_inputs, + token_stride_0=token_inputs.stride(0), + token_stride_1=token_inputs.stride(1), + weights_ptr=expert_weights, + weight_stride_0=expert_weights.stride(0), + weight_stride_1=expert_weights.stride(1), + weight_stride_2=expert_weights.stride(2), + bias_ptr=bias, + bias_stride_0=bias.stride(0) if bias is not None else 0, + bias_stride_1=bias.stride(1) if bias is not None and bias.ndim >= 2 else 0, + expert_to_weights_ptr=expert_to_weights, + expert_to_weights_stride0=expert_to_weights.stride(0), + expert_to_weights_stride1=expert_to_weights.stride(1), + 
expert_to_token_num=expert_to_token_num, + expert_to_token_index=expert_to_token_index, + expert_to_token_index_stride_0=expert_to_token_index.stride(0), + out_ptr=out, + out_stride_0=out.stride(0), + out_stride_1=out.stride(1), + m_block_num=mblocks_to_tuple_info.shape[0], n_block_num=triton.cdiv(n, BLOCK_SIZE_N), compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, @@ -762,7 +763,7 @@ def grouped_matmul( num_stages=num_stages, ADD_BIAS=bias is not None, ) - return (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M) + return (mblocks_to_tuple_info, BLOCK_SIZE_M) def fused_experts_impl( From 151a1d8512a05f6b7448cba1e1a1aeedce31a0a7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 6 Nov 2025 08:00:49 +0000 Subject: [PATCH 04/12] add out_sorted --- lightllm/common/fused_moe/grouped_fused_moe.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 421a271cf..b4fb5c3ce 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -427,6 +427,7 @@ def grouped_matmul_kernel( NEED_K_MASK: tl.constexpr = True, NEED_TRANS: tl.constexpr = False, ADD_BIAS: tl.constexpr = False, + OUT_SORTED: tl.constexpr = False, ): pid = tl.program_id(0) @@ -452,8 +453,9 @@ def grouped_matmul_kernel( tile_m_idx = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 1) tile_n_idx = pid_n - # get token start index in inputs - # token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2) + if OUT_SORTED: + # get token start index in inputs + token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2) # get the gemm size of the current problem cur_m = tl.load(expert_to_token_num + expert_id) @@ -557,8 +559,13 @@ def grouped_matmul_kernel( c = accumulator.to(compute_type) offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = out_ptr + a_m_index[:, None] * out_stride_0 + offs_cn[None, :] - tl.store(c_ptrs, c, mask=(token_mask[:, None]) & (offs_cn[None, :] < n)) + + if OUT_SORTED: + c_ptrs = out_ptr + (token_start_index + tl.arange(0, BLOCK_SIZE_M))[:, None] * out_stride_0 + offs_cn[None, :] + tl.store(c_ptrs, c, mask=(token_mask[:, None]) & (offs_cn[None, :] < n)) + else: + c_ptrs = out_ptr + a_m_index[:, None] * out_stride_0 + offs_cn[None, :] + tl.store(c_ptrs, c, mask=(token_mask[:, None]) & (offs_cn[None, :] < n)) return From e5657d3b87a8f91e5a224450cf0721182907f491 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 6 Nov 2025 09:08:15 +0000 Subject: [PATCH 05/12] add tma kernel --- .../common/fused_moe/grouped_fused_moe.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index b4fb5c3ce..a6e1b7b3d 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -393,9 +393,11 @@ def grouped_matmul_kernel( weight_scale_stride1, weight_scale_stride2, token_ptr, # [token_num, hidden_dim] + token_desc, # triton tensor describdescriptor token_stride_0, token_stride_1, weights_ptr, # [expert_num, N, K] + weight_desc, # triton tensor describdescriptor weight_stride_0, weight_stride_1, weight_stride_2, @@ -428,6 +430,8 @@ def grouped_matmul_kernel( NEED_TRANS: tl.constexpr = False, ADD_BIAS: tl.constexpr = False, OUT_SORTED: 
tl.constexpr = False, + TOKEN_INPUT_USE_TMA: tl.constexpr = False, + WEIGHT_USE_TMA: tl.constexpr = False, ): pid = tl.program_id(0) @@ -453,7 +457,8 @@ def grouped_matmul_kernel( tile_m_idx = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 1) tile_n_idx = pid_n - if OUT_SORTED: + if OUT_SORTED or TOKEN_INPUT_USE_TMA: + assert OUT_SORTED and TOKEN_INPUT_USE_TMA is False # get token start index in inputs token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2) @@ -490,33 +495,44 @@ def grouped_matmul_kernel( b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for step_k in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + for k_start in range(0, k, BLOCK_SIZE_K): # hint to Triton compiler to do proper loop pipelining # tl.multiple_of(a_ptrs, [16, 16]) # tl.multiple_of(b_ptrs, [16, 16]) if NEED_TRANS: - if NEED_K_MASK: - a = tl.load( - a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - step_k * BLOCK_SIZE_K), other=0.0 - ) - b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0) + if TOKEN_INPUT_USE_TMA: + a = token_desc.load([token_start_index, k_start]).T + elif NEED_K_MASK: + a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - k_start), other=0.0) else: a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0) + + if WEIGHT_USE_TMA: + weight_desc.load([expert_id, tile_n_idx * BLOCK_SIZE_N, k_start]).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + elif NEED_K_MASK: + b = tl.load(b_ptrs, mask=(offs_k[None, :] < k - k_start), other=0.0) + else: b = tl.load(b_ptrs) + else: - if NEED_K_MASK: - a = tl.load( - a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - step_k * BLOCK_SIZE_K), other=0.0 - ) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < k), other=0.0) + if TOKEN_INPUT_USE_TMA: + a = token_desc.load([token_start_index, k_start]) + elif NEED_K_MASK: + a = tl.load(a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - k_start), other=0.0) else: a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0) + + if WEIGHT_USE_TMA: + weight_desc.load([expert_id, tile_n_idx * BLOCK_SIZE_N, k_start]).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K).T + elif NEED_K_MASK: + b = tl.load(b_ptrs, mask=(offs_k[:, None] < k - k_start), other=0.0) + else: b = tl.load(b_ptrs) if use_fp8_w8a8: if block_size_k > 0 and block_size_n > 0: - offs_ks = step_k * BLOCK_SIZE_K // block_size_k + offs_ks = k_start // block_size_k a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0) b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2) if NEED_TRANS: @@ -735,9 +751,11 @@ def grouped_matmul( if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3 else 0, token_ptr=token_inputs, + token_desc=None, token_stride_0=token_inputs.stride(0), token_stride_1=token_inputs.stride(1), weights_ptr=expert_weights, + weight_desc=None, weight_stride_0=expert_weights.stride(0), weight_stride_1=expert_weights.stride(1), weight_stride_2=expert_weights.stride(2), From 817dab04a13fda252cf72adbba4576bd31bfd831 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 6 Nov 2025 10:08:15 +0000 Subject: [PATCH 06/12] fix moe kernel --- .../common/fused_moe/grouped_fused_moe.py | 64 +++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index a6e1b7b3d..69dace003 100644 --- 
a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -478,7 +478,13 @@ def grouped_matmul_kernel( if use_fp8_w8a8: if block_size_k > 0 and block_size_n > 0: - a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * (token_stride_0 // block_size_k) + token_scale_stride0 = token_stride_0 // block_size_k + if TOKEN_INPUT_USE_TMA: + assert MUL_ROUTED_WEIGHT is True + a_scale_ptrs = token_scale_ptr + (token_start_index + tl.arange(0, BLOCK_SIZE_M)) * token_scale_stride0 + else: + a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * token_scale_stride0 + offs_bsn = offs_bn // block_size_n b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1 else: @@ -509,7 +515,9 @@ def grouped_matmul_kernel( a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0) if WEIGHT_USE_TMA: - weight_desc.load([expert_id, tile_n_idx * BLOCK_SIZE_N, k_start]).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + b = weight_desc.load([expert_id, tile_n_idx * BLOCK_SIZE_N, k_start]).reshape( + BLOCK_SIZE_N, BLOCK_SIZE_K + ) elif NEED_K_MASK: b = tl.load(b_ptrs, mask=(offs_k[None, :] < k - k_start), other=0.0) else: @@ -524,7 +532,11 @@ def grouped_matmul_kernel( a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0) if WEIGHT_USE_TMA: - weight_desc.load([expert_id, tile_n_idx * BLOCK_SIZE_N, k_start]).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K).T + b = ( + weight_desc.load([expert_id, tile_n_idx * BLOCK_SIZE_N, k_start]) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .T + ) elif NEED_K_MASK: b = tl.load(b_ptrs, mask=(offs_k[:, None] < k - k_start), other=0.0) else: @@ -727,6 +739,45 @@ def grouped_matmul( if reused_block_size_m != BLOCK_SIZE_M: mblocks_to_tuple_info = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M) + support_tma = triton_support_tensor_descriptor() + + if support_tma: + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + # moe 分为 up 和 down 两次计算,当 mul_routed_weight 为 False 的时候为 up + is_up_moe = not mul_routed_weight + + if is_up_moe: + TOKEN_INPUT_USE_TMA = False + WEIGHT_USE_TMA = support_tma + OUT_SORTED = support_tma + else: + TOKEN_INPUT_USE_TMA = support_tma + WEIGHT_USE_TMA = support_tma + OUT_SORTED = False + + if TOKEN_INPUT_USE_TMA: + from triton.tools.tensor_descriptor import TensorDescriptor + + token_desc = TensorDescriptor( + token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K] + ) + else: + token_desc = None + + if WEIGHT_USE_TMA: + from triton.tools.tensor_descriptor import TensorDescriptor + + weight_desc = TensorDescriptor( + expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K] + ) + else: + weight_desc = None + block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_tuple_info.shape[0] grid = (block_num,) @@ -751,11 +802,11 @@ def grouped_matmul( if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3 else 0, token_ptr=token_inputs, - token_desc=None, + token_desc=token_desc, token_stride_0=token_inputs.stride(0), token_stride_1=token_inputs.stride(1), weights_ptr=expert_weights, - weight_desc=None, + weight_desc=weight_desc, weight_stride_0=expert_weights.stride(0), weight_stride_1=expert_weights.stride(1), weight_stride_2=expert_weights.stride(2), @@ -787,6 +838,9 @@ def grouped_matmul( num_warps=num_warps, 
num_stages=num_stages, ADD_BIAS=bias is not None, + OUT_SORTED=OUT_SORTED, + TOKEN_INPUT_USE_TMA=TOKEN_INPUT_USE_TMA, + WEIGHT_USE_TMA=WEIGHT_USE_TMA, ) return (mblocks_to_tuple_info, BLOCK_SIZE_M) From 692bfbd012d13fe3b4dd426a99052b172f678d8c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 7 Nov 2025 01:43:49 +0000 Subject: [PATCH 07/12] fix tma support --- lightllm/utils/device_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 6f848f381..b4d1ba629 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -234,8 +234,12 @@ def triton_support_tensor_descriptor() -> bool: try: from triton.tools.tensor_descriptor import TensorDescriptor - logger.info("triton support tensor_descriptor") - return True + support_tma = torch.cuda.get_device_capability() >= (9, 0) + if support_tma: + logger.info("triton support tensor_descriptor") + return True + else: + assert False except: logger.info("triton not support tensor_descriptor") return False From 3f3a23c662e124f3d8163102e61e900eaf29ee92 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 7 Nov 2025 02:42:34 +0000 Subject: [PATCH 08/12] fix kernel, remove dead code --- .../common/fused_moe/grouped_fused_moe.py | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 69dace003..68cc065a9 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -457,10 +457,8 @@ def grouped_matmul_kernel( tile_m_idx = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 1) tile_n_idx = pid_n - if OUT_SORTED or TOKEN_INPUT_USE_TMA: - assert OUT_SORTED and TOKEN_INPUT_USE_TMA is False - # get token start index in inputs - token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2) + # get token start index in inputs + token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2) # get the gemm size of the current problem cur_m = tl.load(expert_to_token_num + expert_id) @@ -468,11 +466,14 @@ def grouped_matmul_kernel( # do regular gemm here offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) token_mask = offs_am < cur_m - a_m_index = tl.load( - expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am, - mask=token_mask, - other=0, - ) + + if not OUT_SORTED or not TOKEN_INPUT_USE_TMA: + a_m_index = tl.load( + expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am, + mask=token_mask, + other=0, + ) + offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -493,12 +494,17 @@ def grouped_matmul_kernel( ab_scale = a_scale * b_scale if NEED_TRANS: - a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None] - b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1 + if not TOKEN_INPUT_USE_TMA: + a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None] + if not WEIGHT_USE_TMA: + b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1 accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32) else: - a_ptrs = token_ptr + (a_m_index // topk_num)[:, None] * token_stride_0 + offs_k[None, :] - 
b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1 + if not TOKEN_INPUT_USE_TMA: + a_ptrs = token_ptr + (a_m_index // topk_num)[:, None] * token_stride_0 + offs_k[None, :] + if not WEIGHT_USE_TMA: + b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k_start in range(0, k, BLOCK_SIZE_K): @@ -559,8 +565,10 @@ def grouped_matmul_kernel( else: accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K + if not TOKEN_INPUT_USE_TMA: + a_ptrs += BLOCK_SIZE_K + if not WEIGHT_USE_TMA: + b_ptrs += BLOCK_SIZE_K if NEED_TRANS: accumulator = accumulator.T From 89ccf4c22dc691d9c9c6275fcb970f8e9d8a0254 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 7 Nov 2025 03:03:42 +0000 Subject: [PATCH 09/12] fix --- lightllm/common/fused_moe/grouped_fused_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 68cc065a9..83e06ccd7 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -467,6 +467,8 @@ def grouped_matmul_kernel( offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) token_mask = offs_am < cur_m + assert (OUT_SORTED and TOKEN_INPUT_USE_TMA) is False + if not OUT_SORTED or not TOKEN_INPUT_USE_TMA: a_m_index = tl.load( expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am, From 0766e3de4c76e8d104a63dd2faa6b0712c3d3b32 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 7 Nov 2025 05:26:05 +0000 Subject: [PATCH 10/12] fix unittest --- unit_tests/common/fused_moe/test_grouped_fused_moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unit_tests/common/fused_moe/test_grouped_fused_moe.py b/unit_tests/common/fused_moe/test_grouped_fused_moe.py index 40b724422..9a613f6f7 100644 --- a/unit_tests/common/fused_moe/test_grouped_fused_moe.py +++ b/unit_tests/common/fused_moe/test_grouped_fused_moe.py @@ -77,14 +77,14 @@ def test_moe_align2(): experts_token_num[2] = 60 experts_token_num[3] = 16 - blocks_to_expert_id, mblocks_to_m_index = moe_align2(100, experts_token_num, block_m=16) - assert blocks_to_expert_id.shape[0] == triton.cdiv(100 + 4 * (16 - 1), 16) + mblocks_to_tuple_info = moe_align2(100, experts_token_num, block_m=16) + assert mblocks_to_tuple_info.shape[0] == triton.cdiv(100 + 4 * (16 - 1), 16) assert torch.allclose( - blocks_to_expert_id, + mblocks_to_tuple_info[:, 0], torch.tensor([0, 2, 2, 2, 2, 3, -1, -1, -1, -1], device="cuda", dtype=torch.int32), ) assert torch.allclose( - mblocks_to_m_index, torch.tensor([0, 0, 1, 2, 3, 0, 0, 0, 0, 0], device="cuda", dtype=torch.int32) + mblocks_to_tuple_info[:, 1], torch.tensor([0, 0, 1, 2, 3, 0, 0, 0, 0, 0], device="cuda", dtype=torch.int32) ) From bb1b3f29a43c9371f88123b13cf72a299b005282 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 7 Nov 2025 07:20:50 +0000 Subject: [PATCH 11/12] update kernel configs --- ..._num=1,use_fp8_w8a8=true}_NVIDIA_H200.json | 39 +++++------ ..._num=1,use_fp8_w8a8=true}_NVIDIA_H200.json | 33 ++++------ ..._num=8,use_fp8_w8a8=true}_NVIDIA_H200.json | 53 +++++++-------- ..._num=9,use_fp8_w8a8=true}_NVIDIA_H200.json | 65 ++++++++----------- .../{topk_num=8}_NVIDIA_H200.json | 4 ++ .../{topk_num=9}_NVIDIA_H200.json | 4 ++ ...orch.bfloat16,topk_num=8}_NVIDIA_H200.json | 6 ++ 
...orch.bfloat16,topk_num=9}_NVIDIA_H200.json | 6 ++ ...=16,dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 18 +++++ 11 files changed, 127 insertions(+), 113 deletions(-) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json index 25ac4bdd2..c804b31de 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -3,7 +3,7 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "NEED_TRANS": true, "num_stages": 2, "num_warps": 4 @@ -17,15 +17,6 @@ "num_stages": 2, "num_warps": 4 }, - "131072": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, - "NEED_TRANS": false, - "num_stages": 3, - "num_warps": 4 - }, "16384": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, @@ -36,12 +27,12 @@ "num_warps": 4 }, "2048": { - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "NEED_TRANS": true, - "num_stages": 2, + "num_stages": 3, "num_warps": 4 }, "256": { @@ -53,15 +44,6 @@ "num_stages": 2, "num_warps": 4 }, - "32": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, - "NEED_TRANS": true, - "num_stages": 2, - "num_warps": 4 - }, "32768": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, @@ -89,13 +71,22 @@ "num_stages": 2, "num_warps": 4 }, + "67584": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "8": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "NEED_TRANS": true, - "num_stages": 3, + "num_stages": 2, "num_warps": 4 }, "800": { diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json index 72d10ee77..11ab7d645 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ 
-17,20 +17,11 @@ "num_stages": 2, "num_warps": 4 }, - "147456": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, - "NEED_TRANS": false, - "num_stages": 3, - "num_warps": 4 - }, "18432": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "NEED_TRANS": false, "num_stages": 3, "num_warps": 4 @@ -53,15 +44,6 @@ "num_stages": 2, "num_warps": 4 }, - "36": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, - "NEED_TRANS": true, - "num_stages": 2, - "num_warps": 4 - }, "36864": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, @@ -89,13 +71,22 @@ "num_stages": 2, "num_warps": 4 }, + "76032": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "9": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 64, "NEED_TRANS": true, - "num_stages": 3, + "num_stages": 2, "num_warps": 4 }, "900": { @@ -111,7 +102,7 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "NEED_TRANS": false, "num_stages": 3, "num_warps": 4 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json index 8f24968c2..f55446dc3 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -12,16 +12,16 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "NEED_TRANS": true, - "num_stages": 4, + "num_stages": 5, "num_warps": 4 }, "1024": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "NEED_TRANS": false, "num_stages": 4, "num_warps": 4 @@ -30,7 +30,7 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "NEED_TRANS": true, "num_stages": 5, "num_warps": 4 @@ -38,26 +38,17 @@ "16": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 32, "NEED_TRANS": true, - "num_stages": 4, - "num_warps": 4 - }, - "16384": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, - "NEED_TRANS": false, - "num_stages": 4, + "num_stages": 3, "num_warps": 4 }, "2048": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "NEED_TRANS": false, "num_stages": 4, "num_warps": 4 @@ -66,27 +57,18 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "NEED_TRANS": true, "num_stages": 3, "num_warps": 4 }, "32": { "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, - 
"GROUP_SIZE_M": 32, - "NEED_TRANS": true, - "num_stages": 4, - "num_warps": 4 - }, - "4": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "NEED_TRANS": true, - "num_stages": 5, + "num_stages": 3, "num_warps": 4 }, "4096": { @@ -102,9 +84,9 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "NEED_TRANS": true, - "num_stages": 4, + "num_stages": 3, "num_warps": 4 }, "8": { @@ -113,6 +95,15 @@ "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 32, "NEED_TRANS": true, + "num_stages": 5, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, "num_stages": 4, "num_warps": 4 } diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=true}_NVIDIA_H200.json index 93f028981..0004a3788 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=9,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -5,7 +5,7 @@ "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1, "NEED_TRANS": true, - "num_stages": 4, + "num_stages": 5, "num_warps": 4 }, "100": { @@ -18,39 +18,30 @@ "num_warps": 4 }, "1024": { - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, "NEED_TRANS": false, - "num_stages": 4, + "num_stages": 5, "num_warps": 4 }, "128": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, "NEED_TRANS": true, - "num_stages": 3, + "num_stages": 5, "num_warps": 4 }, "16": { "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, "NEED_TRANS": true, - "num_stages": 3, - "num_warps": 4 - }, - "16384": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, - "NEED_TRANS": false, - "num_stages": 4, + "num_stages": 5, "num_warps": 4 }, "2048": { @@ -64,29 +55,20 @@ }, "256": { "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "NEED_TRANS": true, "num_stages": 4, "num_warps": 4 }, "32": { "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 16, "NEED_TRANS": true, - "num_stages": 4, - "num_warps": 4 - }, - "4": { - "BLOCK_SIZE_K": 128, - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 32, - "NEED_TRANS": true, - "num_stages": 5, + "num_stages": 3, "num_warps": 4 }, "4096": { @@ -102,17 +84,26 @@ "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "NEED_TRANS": true, "num_stages": 3, "num_warps": 4 }, "8": { "BLOCK_SIZE_K": 128, - 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, "num_stages": 4, "num_warps": 4 } diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json index f5f5ab3fc..7478f6489 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json @@ -46,5 +46,9 @@ "8": { "BLOCK_SIZE": 256, "num_warps": 8 + }, + "8448": { + "BLOCK_SIZE": 128, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=9}_NVIDIA_H200.json index dc9ee0a47..925305439 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=9}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=9}_NVIDIA_H200.json @@ -46,5 +46,9 @@ "8": { "BLOCK_SIZE": 256, "num_warps": 8 + }, + "8448": { + "BLOCK_SIZE": 128, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json index acab99ff2..577986e8c 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -76,5 +76,11 @@ "BLOCK_M": 1, "NUM_STAGE": 1, "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json index 1c9898a34..0ced0a7a8 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=7168,out_dtype=torch.bfloat16,topk_num=9}_NVIDIA_H200.json @@ -76,5 +76,11 @@ "BLOCK_M": 1, "NUM_STAGE": 1, "num_warps": 1 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 } } 
\ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=16,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=16,dtype=torch.bfloat16}_NVIDIA_H200.json index 15fbe40bc..5d1f0db0f 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=16,dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=16,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -76,5 +76,11 @@ "HEAD_PARALLEL_NUM": 16, "num_stages": 5, "num_warps": 1 + }, + "8448": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 4, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2304,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2304,out_dtype=torch.bfloat16}_NVIDIA_H200.json index a832b8e01..8103b3206 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2304,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2304,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -76,5 +76,11 @@ "BLOCK_N": 64, "NUM_STAGES": 2, "num_warps": 4 + }, + "8448": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 8ec84e27f..b3657e491 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -143,12 +143,24 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "67584": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "72": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 4, "num_warps": 4 }, + "76032": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 64, @@ -167,6 +179,12 @@ "NUM_STAGES": 1, "num_warps": 4 }, + "8448": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "9": { "BLOCK_M": 1, "BLOCK_N": 32, From b81de0e934ca323e02b5de21590da07d4c28241c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 7 Nov 2025 07:36:24 +0000 Subject: [PATCH 12/12] fix kernel --- .../common/fused_moe/grouped_fused_moe.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 83e06ccd7..758d83ba3 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -481,6 
+481,7 @@ def grouped_matmul_kernel( if use_fp8_w8a8: if block_size_k > 0 and block_size_n > 0: + assert BLOCK_SIZE_K <= block_size_k token_scale_stride0 = token_stride_0 // block_size_k if TOKEN_INPUT_USE_TMA: assert MUL_ROUTED_WEIGHT is True @@ -488,7 +489,12 @@ def grouped_matmul_kernel( else: a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * token_scale_stride0 - offs_bsn = offs_bn // block_size_n + if BLOCK_SIZE_N > block_size_n: + offs_bsn = offs_bn // block_size_n + else: + # single b scale + offs_bsn = (tile_n_idx * BLOCK_SIZE_N) // block_size_n + b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1 else: a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last") @@ -556,9 +562,16 @@ def grouped_matmul_kernel( a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0) b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2) if NEED_TRANS: - accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :] + if BLOCK_SIZE_N > block_size_n: + accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :] + else: + # single b scale + accumulator += tl.dot(b, a) * (a_scale[None, :] * b_scale) else: - accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + if BLOCK_SIZE_N > block_size_n: + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale) else: if NEED_TRANS: accumulator = tl.dot(b, a, acc=accumulator)
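
The TMA path that patches 05 through 08 introduce follows one host-side pattern in grouped_matmul(): probe for tensor-descriptor support once, register a global scratch allocator, and build one TensorDescriptor per operand with the block shape the kernel will load. The sketch below condenses that pattern from the diffs; it is a minimal illustration rather than a drop-in replacement for the patched function, the helper name build_tma_descriptors is invented for this note, and the remaining names (triton_support_tensor_descriptor, token_inputs, expert_weights, the BLOCK_SIZE_* constants) are the ones used in the series.

    import torch
    import triton
    from typing import Optional
    from lightllm.utils.device_utils import triton_support_tensor_descriptor


    def build_tma_descriptors(token_inputs, expert_weights, mul_routed_weight,
                              BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
        # Fall back to plain pointer loads when the installed Triton has no
        # tensor_descriptor module or the GPU is below SM90 (patch 07 checks both).
        if not triton_support_tensor_descriptor():
            return None, None

        from triton.tools.tensor_descriptor import TensorDescriptor

        # TMA descriptors need a global scratch buffer; Triton requests it through this hook.
        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            return torch.empty(size, device="cuda", dtype=torch.int8)

        triton.set_allocator(alloc_fn)

        # The MoE layer runs two grouped GEMMs: "up" (mul_routed_weight=False) and
        # "down" (mul_routed_weight=True). Only the down GEMM reads its activations
        # through TMA, because only there are the rows contiguous per expert (the up
        # GEMM wrote them in expert-sorted order via OUT_SORTED).
        token_desc = None
        if mul_routed_weight:
            token_desc = TensorDescriptor(
                token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]
            )

        # Expert weights are [expert_num, N, K]; one expert slice is fetched per tile.
        weight_desc = TensorDescriptor(
            expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K]
        )
        return token_desc, weight_desc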
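
The [max_num_m_blocks, 3] buffer that moe_align2 now returns (patches 02 and 04) packs, per launch block, the expert it serves, the block's m index within that expert, and the row at which its tokens start once outputs are written in expert-sorted order. A worked example makes the layout concrete; the per-expert counts below are hypothetical (the unit test in patch 10 only fixes experts 2 and 3 to 60 and 16 tokens), and block_m is 16 as in that test.

    import torch
    from lightllm.common.fused_moe.grouped_fused_moe import moe_align2

    # Hypothetical routing result: experts 0..3 receive 10, 0, 60 and 16 tokens (sum 86).
    counts = torch.tensor([10, 0, 60, 16], dtype=torch.int32, device="cuda")
    info = moe_align2(86, counts, block_m=16)

    # info.shape == (10, 3): cdiv(86 + 4 * 15, 16) == 10 rows, 6 of them populated.
    # Row by row, (expert_id, m_index, token_start_index) comes out as
    #   (0, 0,  0)
    #   (2, 0, 10)  (2, 1, 26)  (2, 2, 42)  (2, 3, 58)    <- expert 2's tokens start at row 10
    #   (3, 0, 70)
    # The four remaining padding rows carry expert_id == -1 (their other two columns are
    # left unwritten), which is what lets grouped_matmul_kernel return early for surplus blocks.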
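
Patch 12's single-scale branch in the fp8 path is an indexing observation: with block-wise quantization, weight-scale column j covers output columns [j * block_size_n, (j + 1) * block_size_n), so whenever BLOCK_SIZE_N <= block_size_n an entire N tile falls inside one scale column and a scalar b_scale can multiply the whole accumulator update. A small check with illustrative numbers (not taken from the shipped kernel configs):

    # Quantization granularity 128 columns, GEMM tile width 64, tile index 3 of the row.
    block_size_n, BLOCK_SIZE_N, tile_n_idx = 128, 64, 3
    cols = range(tile_n_idx * BLOCK_SIZE_N, (tile_n_idx + 1) * BLOCK_SIZE_N)   # columns 192..255
    assert len({c // block_size_n for c in cols}) == 1        # every column maps to scale column 1
    assert (tile_n_idx * BLOCK_SIZE_N) // block_size_n == 1   # the scalar offs_bsn patch 12 computes
    # Hence acc += tl.dot(a, b) * (a_scale[:, None] * b_scale) with a scalar b_scale, while the
    # BLOCK_SIZE_N > block_size_n case keeps the per-column vector offs_bsn = offs_bn // block_size_n.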