diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h
index e83d3e50ba3..7758c3b64b7 100644
--- a/custom_ops/gpu_ops/noauxtc_kernel.h
+++ b/custom_ops/gpu_ops/noauxtc_kernel.h
@@ -17,8 +17,8 @@
 #pragma once
 #include 
 #include 
-#include "helper.h"
 #include 
+#include "helper.h"
 namespace cg = cooperative_groups;
@@ -64,7 +64,9 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline) {
 }
 template 
-__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
+__forceinline__ __device__ bool is_better_than(T val,
+                                               T baseline,
+                                               idxT index,
                                                idxT baseline_index) {
   bool res = (val > baseline && greater) || (val < baseline && !greater);
   if (val == baseline) {
@@ -82,7 +84,11 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
              round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
 }
-template 
 struct BitonicMerge {
   // input should be a bitonic sequence, and sort it to be a monotonic sequence
@@ -99,8 +105,8 @@ struct BitonicMerge {
       T& other_val = val_arr[other_i];
       bool is_better;
       if constexpr (is_stable) {
-        is_better = is_better_than(val, other_val, idx_arr[i],
-                                   idx_arr[other_i]);
+        is_better = is_better_than(
+            val, other_val, idx_arr[i], idx_arr[other_i]);
       } else {
         is_better = is_better_than(val, other_val);
       }
@@ -182,7 +188,10 @@ struct BitonicSort<32, ascending, T, idxT, is_stable> {
   }
 };
-template 
 struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
   __device__ static void merge(T* __restrict__ val_arr,
@@ -234,7 +243,8 @@ class WarpSort {
   // load and merge k sorted values
   __device__ void load_sorted(T const* __restrict__ in,
-                              idxT const* __restrict__ in_idx, idxT start) {
+                              idxT const* __restrict__ in_idx,
+                              idxT start) {
     idxT idx = start + WARP_SIZE - 1 - lane_;
     for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
       if (idx < start + k_) {
@@ -456,8 +466,7 @@ __device__ void topk_with_k2(T* output,
 template 
 __global__ void topk_with_k2_kernel(T* output,
-                                    T* input,
-                                    int64_t const num_tokens,
+                                    const T* input,
                                     int64_t const num_cases,
                                     int64_t const n_group,
                                     int64_t const num_experts_per_group) {
@@ -484,11 +493,11 @@ __global__ void topk_with_k2_kernel(T* output,
 template 
 __global__ void group_idx_and_topk_idx_kernel(
-    T* scores,
+    const T* scores,
     T const* group_scores,
     T* topk_values,
     IdxT* topk_indices,
-    T* scores_with_bias,
+    const T* scores_with_bias,
     int64_t const num_tokens,
     int64_t const n_group,
     int64_t const topk_group,
@@ -550,14 +559,17 @@ __global__ void group_idx_and_topk_idx_kernel(
         value = neg_inf();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
-      count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == neg_inf())));
+      count_equal_to_top_value =
+          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
   __syncthreads();
-  warp_topk::WarpSelect queue((int32_t)topk, neg_inf());
@@ -602,19 +614,13 @@ __global__ void group_idx_and_topk_idx_kernel(
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += cg::reduce(tile, cuda_cast(value), cg::plus());
+      topk_sum +=
+          cg::reduce(tile, cuda_cast(value), cg::plus());
     }
   }
   __syncthreads();
-  if (case_id < num_tokens && if_proceed_next_topk) {
-    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
-      scores[i] = 0;
-    }
-  }
-  __syncwarp();
-
   if (case_id < num_tokens) {
     if (if_proceed_next_topk) {
       for (int i = lane_id; i < topk; i += WARP_SIZE) {
@@ -625,7 +631,6 @@ __global__ void group_idx_and_topk_idx_kernel(
         } else {
           value = cuda_cast(s_topk_value[i]) * routed_scaling_factor;
         }
-        scores[s_topk_idx[i]] = value;
         topk_indices[i] = s_topk_idx[i];
         topk_values[i] = cuda_cast(value);
       }
@@ -662,7 +667,11 @@ void invokeNoAuxTc(T* scores,
 #ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
   topk_with_k2_kernel<<>>(
-      group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
+      group_scores,
+      scores_with_bias,
+      num_cases,
+      n_group,
+      num_experts / n_group);
 #else
   auto* kernel_instance1 = &topk_with_k2_kernel;
   cudaLaunchConfig_t config;
@@ -675,8 +684,13 @@ void invokeNoAuxTc(T* scores,
   attrs[0].val.programmaticStreamSerializationAllowed = false;
   config.numAttrs = 1;
   config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
-                     num_tokens, num_cases, n_group, num_experts / n_group);
+  cudaLaunchKernelEx(&config,
+                     kernel_instance1,
+                     group_scores,
+                     scores_with_bias,
+                     num_cases,
+                     n_group,
+                     num_experts / n_group);
 #endif
   int64_t topk_with_k_group_num_blocks =
@@ -686,10 +700,22 @@ void invokeNoAuxTc(T* scores,
       topk);
 #ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
-  group_idx_and_topk_idx_kernel<<>>(
-      scores, group_scores, topk_values, topk_indices, scores_with_bias,
-      num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
-      renormalize, routed_scaling_factor);
+  group_idx_and_topk_idx_kernel<<>>(scores,
+                                    group_scores,
+                                    topk_values,
+                                    topk_indices,
+                                    scores_with_bias,
+                                    num_tokens,
+                                    n_group,
+                                    topk_group,
+                                    topk,
+                                    num_experts,
+                                    num_experts / n_group,
+                                    renormalize,
+                                    routed_scaling_factor);
 #else
   auto* kernel_instance2 = &group_idx_and_topk_idx_kernel;
   config.gridDim = topk_with_k_group_num_blocks;
@@ -700,26 +726,37 @@ void invokeNoAuxTc(T* scores,
   attrs[0].val.programmaticStreamSerializationAllowed = false;
   config.numAttrs = 1;
   config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
-                     topk_values, topk_indices, scores_with_bias, num_tokens,
-                     n_group, topk_group, topk, num_experts,
-                     num_experts / n_group, renormalize, routed_scaling_factor);
+  cudaLaunchKernelEx(&config,
+                     kernel_instance2,
+                     scores,
+                     group_scores,
+                     topk_values,
+                     topk_indices,
+                     scores_with_bias,
+                     num_tokens,
+                     n_group,
+                     topk_group,
+                     topk,
+                     num_experts,
+                     num_experts / n_group,
+                     renormalize,
+                     routed_scaling_factor);
 #endif
 }
 #define INSTANTIATE_NOAUX_TC(T, IdxT) \
   template void invokeNoAuxTc(T * scores, \
-                              T * group_scores, \
-                              T* topk_values, \
-                              IdxT* topk_indices, \
-                              T * scores_with_bias, \
-                              int64_t const num_tokens, \
-                              int64_t const num_experts, \
-                              int64_t const n_group, \
-                              int64_t const topk_group, \
-                              int64_t const topk, \
-                              bool const renormalize, \
-                              double const routed_scaling_factor, \
-                              cudaStream_t const stream);
+                              T * group_scores, \
+                              T * topk_values, \
+                              IdxT * topk_indices, \
+                              T * scores_with_bias, \
+                              int64_t const num_tokens, \
+                              int64_t const num_experts, \
+                              int64_t const n_group, \
+                              int64_t const topk_group, \
+                              int64_t const topk, \
+                              bool const renormalize, \
+                              double const routed_scaling_factor, \
+                              cudaStream_t const stream);
 INSTANTIATE_NOAUX_TC(float, int32_t);
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
index 7277b9697fb..ca2f4bd2527 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
@@ -256,7 +256,7 @@ def apply(
         if topk_method == "noaux_tc":
             from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 
-            gate_out, _, _ = get_moe_scores(
+            _, topk_weights, topk_ids = get_moe_scores(
                 gate_out,
                 layer.n_group,
                 layer.topk_group,
@@ -265,8 +265,6 @@ def apply(
                 layer.gate_correction_bias,
                 getattr(layer, "renormalize", True),
             )
-
-            topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
         else:
            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
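
For context, the routing that these two kernels fuse can be summarized off-GPU. The sketch below is a minimal NumPy reference, assuming the usual noaux_tc scheme (group scores from per-group top-2 sums of the bias-corrected logits, expert selection on the biased scores, weights taken from the unbiased scores, optional renormalization, then scaling by routed_scaling_factor); the helper name noaux_tc_reference and its exact tie-breaking are illustrative assumptions, not the kernel's definitive semantics. It shows why the Marlin backend can now consume topk_weights/topk_ids straight from get_moe_scores: once the kernel returns the selected values and indices directly, there is no need to scatter normalized weights back into scores (the write this diff removes) and re-derive them with paddle.topk.

import numpy as np


def noaux_tc_reference(gate_out, bias, n_group, topk_group, topk,
                       renormalize=True, routed_scaling_factor=1.0):
    """CPU reference for no-aux top-k routing (illustrative, not the CUDA path).

    gate_out: [num_tokens, num_experts] gating scores.
    bias:     [num_experts] gate correction bias, used only for selection.
    Returns (topk_values, topk_indices), analogous to the kernel outputs.
    """
    num_tokens, num_experts = gate_out.shape
    experts_per_group = num_experts // n_group

    # Selection runs on bias-corrected scores (scores_with_bias in the kernel).
    scores_with_bias = gate_out + bias
    grouped = scores_with_bias.reshape(num_tokens, n_group, experts_per_group)

    # Group score = sum of the top-2 biased scores in each group
    # (what topk_with_k2_kernel writes into group_scores).
    group_scores = np.sort(grouped, axis=-1)[..., -2:].sum(axis=-1)

    # Keep the best topk_group groups, mask the other groups' experts to -inf.
    kept = np.argsort(-group_scores, axis=-1)[:, :topk_group]
    mask = np.zeros((num_tokens, n_group), dtype=bool)
    np.put_along_axis(mask, kept, True, axis=-1)
    masked = np.where(mask[..., None], grouped, -np.inf)
    masked = masked.reshape(num_tokens, num_experts)

    # Top-k experts by biased score; weights come from the unbiased scores.
    topk_indices = np.argsort(-masked, axis=-1)[:, :topk]
    topk_values = np.take_along_axis(gate_out, topk_indices, axis=-1)

    if renormalize:
        topk_values = topk_values / topk_values.sum(axis=-1, keepdims=True)
    return topk_values * routed_scaling_factor, topk_indices


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gate_out = rng.random((4, 64)).astype(np.float32)
    bias = rng.random(64).astype(np.float32)
    w, ids = noaux_tc_reference(gate_out, bias, n_group=8, topk_group=4, topk=8)
    print(ids.shape, w.shape, w.sum(axis=-1))  # each row sums to ~1.0 here

With the weights produced inside the fused kernel, scores and scores_with_bias become read-only inputs (hence the const T* signatures), and dropping the scores[...] write-back avoids an extra global-memory pass plus a second top-k over the full expert dimension on the Python side.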