custom_ops/gpu_ops/append_attn/append_attention_func.cuh (8 additions, 6 deletions)
@@ -2414,7 +2414,8 @@ template <typename T,
uint32_t bdy,
uint32_t HEAD_DIM,
typename OutT = T,
- bool ENABLE_PREFILL = true>
+ bool ENABLE_PREFILL = true,
+ bool DECODE_ONLY = true>
__global__ void merge_multi_chunks_v2_kernel(
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
@@ -2458,15 +2459,16 @@ __global__ void merge_multi_chunks_v2_kernel(
if (ENABLE_PREFILL) {
seq_len_kv += seq_len_q;
if (seq_len_kv == 0) continue;
-
- const int seq_len_enc = seq_lens_encoder[bid];
- if (seq_len_enc <= 0) {
- continue;
- }
} else {
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
}
+ if constexpr (DECODE_ONLY) {
+ const int seq_len_enc = seq_lens_encoder[bid];
+ if (seq_len_enc > 0) {
+ continue;
+ }
+ }
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
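Taken together, the changes in this file give merge_multi_chunks_v2_kernel a second boolean template parameter, DECODE_ONLY, and replace the seq_lens_encoder check inside the ENABLE_PREFILL branch with an if constexpr (DECODE_ONLY) block that skips tokens whose batch still has a positive encoder (prefill) length. Below is a minimal, compilable sketch of that compile-time gating pattern; only DECODE_ONLY, seq_lens_encoder and the "> 0" skip come from the diff, while the kernel name, argument list and the stand-in merge work are illustrative.

// Minimal sketch of the DECODE_ONLY gate added above (compile with nvcc, C++17).
// Assumption: only DECODE_ONLY, seq_lens_encoder and the "> 0" skip mirror the
// PR; the kernel name, arguments and the stand-in merge work are illustrative.
#include <cstdio>
#include <cuda_runtime.h>

template <bool DECODE_ONLY>
__global__ void merge_gate_sketch(const int *__restrict__ seq_lens_encoder,
                                  const int *__restrict__ batch_id_per_token,
                                  float *__restrict__ out,
                                  int token_num) {
  for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < token_num;
       tid += gridDim.x * blockDim.x) {
    if constexpr (DECODE_ONLY) {
      // Decode-only instantiations skip tokens whose batch is still running
      // its encoder (prefill) pass; this branch compiles away when false.
      if (seq_lens_encoder[batch_id_per_token[tid]] > 0) continue;
    }
    out[tid] = 1.0f;  // stand-in for the real chunk-merge arithmetic
  }
}

int main() {
  // token_num == 0, so neither instantiation touches the null pointers.
  merge_gate_sketch<true><<<1, 32>>>(nullptr, nullptr, nullptr, 0);
  merge_gate_sketch<false><<<1, 32>>>(nullptr, nullptr, nullptr, 0);
  printf("launch status: %d\n", static_cast<int>(cudaDeviceSynchronize()));
  return 0;
}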
custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh (55 additions, 90 deletions)
@@ -441,6 +441,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
+ const int *__restrict__ seq_lens_encoder,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
@@ -501,6 +502,10 @@ multi_query_append_attention_warp1_4_kernel(
}
kv_len += q_len;
}
+ const int seq_len_enc = seq_lens_encoder[batch_id];
+ if (seq_len_enc > 0) {
+ return;
+ }
const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size);
if (chunk_idx >= num_chunks_this_seq) {
return;
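In the kernel body, the new seq_lens_encoder argument feeds a per-block early exit: right after kv_len is adjusted and before the chunk count is computed, the whole thread block returns when its batch has a positive encoder length. A small, self-contained sketch of the same skip, assuming illustrative names for everything except seq_lens_encoder, batch_ids and the "> 0" test:

// Sketch of the per-block early exit added above (compile with nvcc).
// Assumption: only seq_lens_encoder, batch_ids and the "> 0" test mirror the
// PR; the kernel name, block-to-batch mapping and output are illustrative.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void decode_block_skip_sketch(const int *__restrict__ seq_lens_encoder,
                                         const int *__restrict__ batch_ids,
                                         float *__restrict__ out) {
  const int batch_id = batch_ids[blockIdx.x];
  // Every thread in the block resolves the same batch_id, so the whole block
  // retires here when that batch is still in its encoder (prefill) phase.
  if (seq_lens_encoder[batch_id] > 0) {
    return;
  }
  if (threadIdx.x == 0) out[blockIdx.x] = 1.0f;  // stand-in for attention work
}

int main() {
  const int h_enc[2] = {128, 0};  // batch 0 is prefilling, batch 1 is decoding
  const int h_bid[2] = {0, 1};    // one block per (batch, tile) pair
  int *d_enc, *d_bid;
  float *d_out;
  cudaMalloc((void **)&d_enc, sizeof(h_enc));
  cudaMalloc((void **)&d_bid, sizeof(h_bid));
  cudaMalloc((void **)&d_out, 2 * sizeof(float));
  cudaMemcpy(d_enc, h_enc, sizeof(h_enc), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bid, h_bid, sizeof(h_bid), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, 2 * sizeof(float));
  decode_block_skip_sketch<<<2, 32>>>(d_enc, d_bid, d_out);
  float h_out[2];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("out = %.1f %.1f\n", h_out[0], h_out[1]);  // expect 0.0 (skipped) and 1.0
  cudaFree(d_enc); cudaFree(d_bid); cudaFree(d_out);
  return 0;
}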
@@ -1050,95 +1055,52 @@ void MultiQueryAppendAttention(
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
- if (is_decoder) {
- constexpr int blockx = HEAD_DIM / vec_size;
- constexpr int blocky = (128 + blockx - 1) / blockx;
- dim3 grids_merge(bsz, num_heads);
- dim3 blocks_merge(blockx, blocky);
- auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
- vec_size,
- blocky,
- HEAD_DIM,
- OUT_NV_TYPE,
- ENABLE_PREFILL>;
- launchWithPdlWhenEnabled(
- kernelFn,
- grids_merge,
- blocks_merge,
- 0,
- stream,
- reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
- static_cast<float *>(tmp_m->ptr()),
- static_cast<float *>(tmp_d->ptr()),
- seq_lens_q.data<int>(),
- seq_lens_kv.data<int>(),
- seq_lens_encoder.data<int>(),
- cu_seqlens_q.data<int>(),
- shift_bias ? reinterpret_cast<NV_TYPE *>(
- const_cast<T *>(shift_bias.get().data<T>()))
- : nullptr,
- smooth_weight ? reinterpret_cast<NV_TYPE *>(
- const_cast<T *>(smooth_weight.get().data<T>()))
- : nullptr,
- sinks ? reinterpret_cast<NV_TYPE *>(
- const_cast<T *>(sinks.get().data<T>()))
- : nullptr,
- reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
- quant_max_bound,
- quant_min_bound,
- in_scale,
- max_seq_len,
- num_chunks,
- num_heads,
- chunk_size,
- HEAD_DIM);
- } else {
- constexpr int blockx = HEAD_DIM / vec_size;
- constexpr int blocky = (128 + blockx - 1) / blockx;
- dim3 grids_merge(min(sm_count * 4, token_num),
- num_heads); // 128k is too large
- dim3 blocks_merge(blockx, blocky);
- auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
- vec_size,
- blocky,
- HEAD_DIM,
- OUT_NV_TYPE,
- ENABLE_PREFILL>;
- launchWithPdlWhenEnabled(
- kernelFn,
- grids_merge,
- blocks_merge,
- 0,
- stream,
- reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
- static_cast<float *>(tmp_m->ptr()),
- static_cast<float *>(tmp_d->ptr()),
- seq_lens_q.data<int>(),
- seq_lens_kv.data<int>(),
- seq_lens_encoder.data<int>(),
- batch_id_per_token.data<int>(),
- cu_seqlens_q.data<int>(),
- shift_bias ? reinterpret_cast<NV_TYPE *>(
- const_cast<T *>(shift_bias.get().data<T>()))
- : nullptr,
- smooth_weight ? reinterpret_cast<NV_TYPE *>(
- const_cast<T *>(smooth_weight.get().data<T>()))
- : nullptr,
- sinks ? reinterpret_cast<NV_TYPE *>(
- const_cast<T *>(sinks.get().data<T>()))
- : nullptr,
- reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
- quant_max_bound,
- quant_min_bound,
- in_scale,
- max_seq_len,
- num_chunks,
- num_heads,
- chunk_size,
- HEAD_DIM,
- token_num,
- speculate_max_draft_token_num);
- }
+ constexpr int blockx = HEAD_DIM / vec_size;
+ constexpr int blocky = (128 + blockx - 1) / blockx;
+ dim3 grids_merge(min(sm_count * 4, token_num),
+ num_heads); // 128k is too large
+ dim3 blocks_merge(blockx, blocky);
+ auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
+ vec_size,
+ blocky,
+ HEAD_DIM,
+ OUT_NV_TYPE,
+ ENABLE_PREFILL,
+ false>;
+ launchWithPdlWhenEnabled(
+ kernelFn,
+ grids_merge,
+ blocks_merge,
+ 0,
+ stream,
+ reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
+ static_cast<float *>(tmp_m->ptr()),
+ static_cast<float *>(tmp_d->ptr()),
+ seq_lens_q.data<int>(),
+ seq_lens_kv.data<int>(),
+ seq_lens_encoder.data<int>(),
+ batch_id_per_token.data<int>(),
+ cu_seqlens_q.data<int>(),
+ shift_bias ? reinterpret_cast<NV_TYPE *>(
+ const_cast<T *>(shift_bias.get().data<T>()))
+ : nullptr,
+ smooth_weight ? reinterpret_cast<NV_TYPE *>(
+ const_cast<T *>(smooth_weight.get().data<T>()))
+ : nullptr,
+ sinks ? reinterpret_cast<NV_TYPE *>(
+ const_cast<T *>(sinks.get().data<T>()))
+ : nullptr,
+ reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
+ quant_max_bound,
+ quant_min_bound,
+ in_scale,
+ max_seq_len,
+ num_chunks,
+ num_heads,
+ chunk_size,
+ HEAD_DIM,
+ token_num,
+ speculate_max_draft_token_num);
}
} else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
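The lines removed above are the old is_decoder special case, which dispatched merge_multi_chunks_decoder_kernel on a (bsz, num_heads) grid; the lines added are essentially the former else branch made unconditional, now instantiating merge_multi_chunks_v2_kernel with DECODE_ONLY = false and capping the grid's x dimension at min(sm_count * 4, token_num) so a very long batch (the 128k in the comment) does not become one block per token. A host-side sketch of that grid-sizing choice; only the min(sm_count * 4, token_num) expression comes from the diff, the rest is illustrative:

// Host-side sketch of the capped merge grid used above (compile with nvcc).
// Assumption: only the min(sm_count * 4, token_num) sizing mirrors the PR;
// the function name, constants and printout are illustrative.
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

dim3 make_merge_grid(int sm_count, int token_num, int num_heads) {
  // Cap blocks in x so a 128k-token batch does not launch 128k blocks;
  // the merge kernel is expected to grid-stride over the remaining tokens.
  return dim3(std::min(sm_count * 4, token_num), num_heads);
}

int main() {
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
  const dim3 grid = make_merge_grid(sm_count, /*token_num=*/131072, /*num_heads=*/32);
  printf("grid = (%u, %u, %u)\n", grid.x, grid.y, grid.z);
  return 0;
}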
@@ -1222,6 +1184,7 @@ void MultiQueryAppendAttention(
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
+ seq_lens_encoder.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
@@ -1303,6 +1266,7 @@ void MultiQueryAppendAttention(
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
+ seq_lens_encoder.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
@@ -1380,7 +1344,8 @@ void MultiQueryAppendAttention(
blocky,
HEAD_DIM,
OUT_NV_TYPE,
- ENABLE_PREFILL>;
+ ENABLE_PREFILL,
+ true>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
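This second merge launch site, sitting in the branch whose attention-kernel launches gained the seq_lens_encoder argument in the two hunks above (matching the new warp1_4 kernel parameter), keeps DECODE_ONLY = true, while the earlier launch site passes false, presumably because that other path no longer needs the merge kernel to filter prefill batches. A short sketch of selecting between the two instantiations through a kernel pointer, mirroring the auto *kernelFn pattern used here; merge_stub stands in for merge_multi_chunks_v2_kernel and its template parameter list is reduced to the one boolean that differs:

// Sketch of picking a DECODE_ONLY instantiation via a kernel pointer, mirroring
// the "auto *kernelFn = ..." pattern above (compile with nvcc).
// Assumption: merge_stub stands in for merge_multi_chunks_v2_kernel and the
// template parameter list is abbreviated to the one boolean that changed.
#include <cuda_runtime.h>

template <bool DECODE_ONLY>
__global__ void merge_stub(const int *__restrict__ seq_lens_encoder, int token_num) {
  // Real merge work omitted; only the instantiation/selection pattern is shown.
}

void launch_merge(bool decode_only, dim3 grid, dim3 block, cudaStream_t stream,
                  const int *seq_lens_encoder, int token_num) {
  // Both instantiations share a signature, so a plain function pointer can
  // select between them at the launch site.
  auto *kernelFn = decode_only ? merge_stub<true> : merge_stub<false>;
  kernelFn<<<grid, block, 0, stream>>>(seq_lens_encoder, token_num);
}

int main() {
  launch_merge(true, dim3(1), dim3(32), nullptr, nullptr, 0);
  launch_merge(false, dim3(1), dim3(32), nullptr, nullptr, 0);
  return static_cast<int>(cudaDeviceSynchronize());
}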