129 changes: 83 additions & 46 deletions custom_ops/gpu_ops/noauxtc_kernel.h
@@ -17,8 +17,8 @@
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>
#include "helper.h"

namespace cg = cooperative_groups;

@@ -64,7 +64,9 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline) {
}

template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
__forceinline__ __device__ bool is_better_than(T val,
T baseline,
idxT index,
idxT baseline_index) {
bool res = (val > baseline && greater) || (val < baseline && !greater);
if (val == baseline) {
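A minimal host-side sketch of this comparator's semantics, for readers skimming the hunk. The tie-break on equal values (lower index wins, which is what makes the sort stable) is an assumption, since the diff truncates the body right after the equality test.

#include <cstdint>

template <bool greater, typename T, typename IdxT>
bool is_better_than_ref(T val, T baseline, IdxT index, IdxT baseline_index) {
  bool res = (val > baseline && greater) || (val < baseline && !greater);
  if (val == baseline) {
    res = index < baseline_index;  // assumed stable tie-break on index
  }
  return res;
}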
@@ -82,7 +84,11 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
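A hedged reconstruction of the shared-memory budget computed above: with n candidate slots block-wide, the value array is rounded up to a 256-byte boundary so that the index array placed after it stays aligned. The definition n = num_of_warp * k is an assumption consistent with the visible return expression.

#include <cstdint>

template <int alignment>
constexpr int64_t round_up_to_multiple_of(int64_t x) {
  return (x + alignment - 1) / alignment * alignment;
}

template <typename T, typename idxT>
int calc_smem_size_ref(int num_of_warp, int64_t k) {
  int64_t n = static_cast<int64_t>(num_of_warp) * k;  // assumed definition
  return static_cast<int>(round_up_to_multiple_of<256>(n * sizeof(T)) +
                          n * sizeof(idxT));
}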

template <int size, bool ascending, bool reverse, typename T, typename idxT,
template <int size,
bool ascending,
bool reverse,
typename T,
typename idxT,
bool is_stable>
struct BitonicMerge {
// input should be a bitonic sequence, and sort it to be a monotonic sequence
@@ -99,8 +105,8 @@ struct BitonicMerge {
T& other_val = val_arr[other_i];
bool is_better;
if constexpr (is_stable) {
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
idx_arr[other_i]);
is_better = is_better_than<ascending>(
val, other_val, idx_arr[i], idx_arr[other_i]);
} else {
is_better = is_better_than<ascending>(val, other_val);
}
@@ -182,7 +188,10 @@ struct BitonicSort<32, ascending, T, idxT, is_stable> {
}
};

template <bool ascending, bool reverse, typename T, typename idxT,
template <bool ascending,
bool reverse,
typename T,
typename idxT,
bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
__device__ static void merge(T* __restrict__ val_arr,
@@ -234,7 +243,8 @@ class WarpSort {

// load and merge k sorted values
__device__ void load_sorted(T const* __restrict__ in,
idxT const* __restrict__ in_idx, idxT start) {
idxT const* __restrict__ in_idx,
idxT start) {
idxT idx = start + WARP_SIZE - 1 - lane_;
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
if (idx < start + k_) {
@@ -456,8 +466,7 @@ __device__ void topk_with_k2(T* output,

template <typename T>
__global__ void topk_with_k2_kernel(T* output,
T* input,
int64_t const num_tokens,
const T* input,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
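This hunk drops the redundant num_tokens parameter and marks the input buffer read-only. A hedged sketch of the case decomposition that makes num_tokens unnecessary inside the kernel; the ceil-div grid formula is an assumption for illustration, not the file's exact computation.

#include <cstdint>

// One "case" per (token, group) pair: each case reduces one expert group,
// so the flat case count already encodes num_tokens.
int64_t cases_for(int64_t num_tokens, int64_t n_group) {
  return num_tokens * n_group;
}

// Assumed grid sizing: one thread per case, rounded up to whole blocks.
int64_t blocks_for(int64_t num_cases, int64_t block_size) {
  return (num_cases + block_size - 1) / block_size;
}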
@@ -484,11 +493,11 @@ __global__ void topk_with_k2_kernel(T* output,

template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(
T* scores,
const T* scores,
T const* group_scores,
T* topk_values,
IdxT* topk_indices,
T* scores_with_bias,
const T* scores_with_bias,
int64_t const num_tokens,
int64_t const n_group,
int64_t const topk_group,
@@ -550,14 +559,17 @@ __global__ void group_idx_and_topk_idx_kernel(
value = neg_inf<T>();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == neg_inf<T>())));
count_equal_to_top_value =
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();

warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
warp_topk::WarpSelect</*capability*/ WARP_SIZE,
/*greater*/ true,
T,
int32_t,
/* is_stable */ true>
queue((int32_t)topk, neg_inf<T>());
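The counting loop a few lines up uses a common warp-voting idiom; a self-contained illustration, assuming FULL_WARP_MASK is the all-lanes mask 0xffffffff:

// Each lane votes via __ballot_sync; __popc counts the set bits, i.e. how
// many lanes in the warp currently hold the sentinel value.
__device__ int count_lanes_equal_to(float value, float sentinel) {
  unsigned ballot = __ballot_sync(0xffffffffu, value == sentinel);
  return __popc(ballot);
}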

@@ -602,19 +614,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}

__syncthreads();

if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__syncwarp();

if (case_id < num_tokens) {
if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
@@ -625,7 +631,6 @@
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
scores[s_topk_idx[i]] = value;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = cuda_cast<T, float>(value);
}
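The behavioural half of this hunk is what is gone: the kernel no longer zeroes `scores` and scatters the normalized weights back into it (the buffer is now `const T*`), so consumers are expected to read the compact `topk_values`/`topk_indices` pair instead. A hedged consumer-side sketch, assuming a row-major [num_tokens, topk] layout:

#include <cstdint>

// Sums the routed weights of one token from the compact output buffer.
// With renormalize enabled the result should be ~1.0 (assumption).
template <typename T>
T sum_routed_weights(const T* topk_values, int64_t token, int64_t topk) {
  T total = T(0);
  for (int64_t i = 0; i < topk; ++i) {
    total += topk_values[token * topk + i];
  }
  return total;
}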
@@ -662,7 +667,11 @@ void invokeNoAuxTc(T* scores,

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
group_scores,
scores_with_bias,
num_cases,
n_group,
num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
@@ -675,8 +684,13 @@
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
cudaLaunchKernelEx(&config,
kernel_instance1,
group_scores,
scores_with_bias,
num_cases,
n_group,
num_experts / n_group);
#endif
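For reference, a self-contained sketch of the extended-launch pattern used on this non-METAX path (requires CUDA 11.8+). The kernel and launch sizes are placeholders; the attribute setup mirrors the lines above, which opt out of programmatic stream serialization so the launch keeps ordinary stream ordering.

__global__ void placeholder_kernel(int /*unused*/) {}

void launch_with_config(cudaStream_t stream) {
  cudaLaunchConfig_t config{};
  config.gridDim = dim3(1);     // placeholder grid
  config.blockDim = dim3(256);  // placeholder block
  config.dynamicSmemBytes = 0;
  config.stream = stream;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 0;
  config.numAttrs = 1;
  config.attrs = attrs;

  int arg = 0;
  cudaLaunchKernelEx(&config, placeholder_kernel, arg);
}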

int64_t topk_with_k_group_num_blocks =
@@ -686,10 +700,22 @@
topk);

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
renormalize,
routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
@@ -700,26 +726,37 @@
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
cudaLaunchKernelEx(&config,
kernel_instance2,
scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
renormalize,
routed_scaling_factor);
#endif
}

#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>(T * scores, \
T * group_scores, \
T* topk_values, \
IdxT* topk_indices, \
T * scores_with_bias, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
bool const renormalize, \
double const routed_scaling_factor, \
cudaStream_t const stream);
T * group_scores, \
T * topk_values, \
IdxT * topk_indices, \
T * scores_with_bias, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
bool const renormalize, \
double const routed_scaling_factor, \
cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);
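A hedged usage sketch of the instantiated entry point. Buffer shapes follow the parameter names; the [num_tokens, num_experts] layout and the sample sizes are assumptions for illustration.

#include <cstdint>

void run_noaux_tc(float* d_scores,            // [num_tokens, num_experts]
                  float* d_group_scores,      // [num_tokens, n_group] scratch
                  float* d_topk_values,       // [num_tokens, topk]
                  int32_t* d_topk_indices,    // [num_tokens, topk]
                  float* d_scores_with_bias,  // [num_tokens, num_experts]
                  cudaStream_t stream) {
  int64_t const num_tokens = 128, num_experts = 256, n_group = 8;
  int64_t const topk_group = 4, topk = 8;
  invokeNoAuxTc<float, int32_t>(d_scores, d_group_scores, d_topk_values,
                                d_topk_indices, d_scores_with_bias, num_tokens,
                                num_experts, n_group, topk_group, topk,
                                /*renormalize=*/true,
                                /*routed_scaling_factor=*/1.0, stream);
}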
@@ -256,7 +256,7 @@ def apply(
if topk_method == "noaux_tc":
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

gate_out, _, _ = get_moe_scores(
_, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
@@ -265,8 +265,6 @@
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)

topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,