146 changes: 108 additions & 38 deletions cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu
@@ -20,7 +20,7 @@ namespace tensorrt_llm
{
namespace kernels
{
template <int THREADS_PER_BLOCK>
template <int THREADS_PER_BLOCK, int MAX_NUM_PAGES>
__global__ void gatherKvPageOffsetsKernel(
int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
@@ -32,23 +32,33 @@ __global__ void gatherKvPageOffsetsKernel(
// Each CUDA block processes one sequence from the batch for one head.
int32_t const head_idx = blockIdx.x;
int32_t const batch_idx = blockIdx.y;
int32_t const indices_block_size = sparse_params.sparse_attn_indices_block_size;
if (batch_idx >= batch_size)
{
return;
}

// Shared memory for reduction.
__shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
using BlockReduce = cub::BlockReduce<Pair, THREADS_PER_BLOCK>;

__shared__ typename BlockScan::TempStorage temp_storage_scan;
__shared__ typename BlockReduce::TempStorage temp_storage_reduce;

__shared__ int32_t s_page_mask[MAX_NUM_PAGES];
__shared__ int32_t s_cu_page_mask[MAX_NUM_PAGES];
__shared__ int32_t s_scan_total; // Store total count from scan

// Get the range of sparse indices and the sequence length.
int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
int32_t const num_sparse_pages = end_offset - start_offset;
int32_t const sparse_attn_indices_stride = sparse_params.sparse_attn_indices_stride;
int32_t const num_sparse_indices = end_offset - start_offset;
int32_t const original_seq_len = seq_lengths[batch_idx];
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
int32_t const page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) / MAX_NUM_PAGES;

// Get global sparse index.
int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
int32_t const sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset;

// Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
@@ -58,56 +68,119 @@ __global__ void gatherKvPageOffsetsKernel(
int32_t local_max_page_index = -1;
int32_t local_num_valid_pages = 0;

// Perform the gather operation.
for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
int32_t src_page_idx_offset = 0;
int32_t dst_page_idx_offset = 0;
for (int32_t loop_idx = 0; loop_idx < page_loops; loop_idx++)
{
// Get the source idx and offset.
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
if (src_idx < 0)
src_page_idx_offset = loop_idx * MAX_NUM_PAGES;
int32_t loop_num_valid_pages = min(MAX_NUM_PAGES, ori_valid_pages - src_page_idx_offset);
for (int32_t i = threadIdx.x; i < MAX_NUM_PAGES; i += blockDim.x)
{
s_page_mask[i] = 0;
}
__syncthreads();

for (int32_t i = threadIdx.x; i < num_sparse_indices; i += blockDim.x)
{
continue;
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
int32_t const src_idx_start = src_idx * indices_block_size;
int32_t const src_idx_end = min(src_idx_start + indices_block_size, original_seq_len);
for (int32_t j = src_idx_start; j < src_idx_end; j++)
{
int32_t const src_page_idx = j / tokens_per_page;
if (src_page_idx >= src_page_idx_offset && src_page_idx < src_page_idx_offset + loop_num_valid_pages)
{
atomicExch(&s_page_mask[src_page_idx - src_page_idx_offset], 1);
}
}
}
__syncthreads();

// Handle case when loop_num_valid_pages > blockDim.x by processing in chunks
int32_t scan_offset = 0;
int32_t const scan_chunks = (loop_num_valid_pages + blockDim.x - 1) / blockDim.x;

// Update the local max page index.
local_max_page_index = max(local_max_page_index, src_idx);
local_num_valid_pages++;
for (int32_t chunk_idx = 0; chunk_idx < scan_chunks; chunk_idx++)
{
int32_t const chunk_start = chunk_idx * blockDim.x;
int32_t const chunk_size = min((int32_t) blockDim.x, loop_num_valid_pages - chunk_start);

int32_t thread_data = (threadIdx.x < chunk_size) ? s_page_mask[chunk_start + threadIdx.x] : 0;
int32_t thread_output;
int32_t aggregate;

BlockScan(temp_storage_scan).ExclusiveSum(thread_data, thread_output, aggregate);
__syncthreads();

if (threadIdx.x < chunk_size)
{
s_cu_page_mask[chunk_start + threadIdx.x] = thread_output + scan_offset;
}
__syncthreads();

// Get the source and destination offsets.
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
// Update scan offset for next chunk
scan_offset += aggregate;
}

// Perform the gather operation: read from the sparse location and write to the dense location.
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
if (threadIdx.x == 0)
{
s_scan_total = scan_offset;
}
__syncthreads();

// Perform the gather operation.
for (int32_t i = threadIdx.x; i < loop_num_valid_pages; i += blockDim.x)
{
// Skip if the page is not valid.
if (s_page_mask[i] == 0)
{
continue;
}

int32_t const src_idx = src_page_idx_offset + i;
int32_t const dst_idx = dst_page_idx_offset + s_cu_page_mask[i];

local_max_page_index = max(local_max_page_index, src_idx);
local_num_valid_pages++;

size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + dst_idx;
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + dst_idx;

output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
}
__syncthreads();

// Update dst offset using the total count from scan
dst_page_idx_offset += s_scan_total;
}

// Reduce the local max page indices and number of valid pages.
Pair local_pair = {local_max_page_index, local_num_valid_pages};
Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
Pair result = BlockReduce(temp_storage_reduce).Reduce(local_pair, PairReduceOp());

// Update sequence length for this head and batch.
if (threadIdx.x == 0)
{
int32_t const max_page_index = result.max_val;
int32_t const num_valid_pages = result.sum_val;
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
int32_t seq_len = 0;
if (num_valid_pages > 0)
{
int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
int32_t seq_len_remain = original_seq_len % tokens_per_page;
if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
if (max_page_index == ori_valid_pages - 1)
{
seq_len += tokens_per_page - seq_len_remain;
seq_len = (num_valid_pages - 1) * tokens_per_page
+ (original_seq_len - (ori_valid_pages - 1) * tokens_per_page);
}
else
{
seq_len = num_valid_pages * tokens_per_page;
}
output_seq_lengths[seq_len_offset] = seq_len;
}
else
{
output_seq_lengths[seq_len_offset] = 0;
}
output_seq_lengths[seq_len_offset] = seq_len;
}
}

@@ -121,11 +194,8 @@ void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_
dim3 grid(num_head_kv, batch_size, 1);
// The block.
dim3 block(256, 1, 1);
// Shared memory size.
size_t smem_size = sizeof(Pair) * 256;

// Launch the kernel.
gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
gatherKvPageOffsetsKernel<256, 512><<<grid, block, 0, stream>>>(output_kv_page_offsets, output_seq_lengths,
kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
}
} // namespace kernels
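For reviewers, the per-(head, batch) work of the rewritten kernel is: expand each selected index into its token block, mark every KV page that block touches, exclusive-scan the mask to assign dense destination slots, gather those pages, and recompute the sequence length so that only a kept final page contributes a partial tail. Below is a minimal host-side reference of that logic, intended as a unit-test oracle; the names gatherPagesRef and GatherResultRef are illustrative and not part of this PR.

#include <algorithm>
#include <cstdint>
#include <vector>

struct GatherResultRef
{
    std::vector<int32_t> dst_to_src_page; // dense destination slot -> original page index
    int32_t new_seq_len{0};
};

// Host mirror of gatherKvPageOffsetsKernel for one (head, batch) pair.
GatherResultRef gatherPagesRef(std::vector<int32_t> const& head_batch_indices, // entries [start_offset, end_offset)
    int32_t indices_block_size, int32_t original_seq_len, int32_t tokens_per_page)
{
    int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;

    // 1) Mark every page touched by a selected token block (mirrors s_page_mask).
    std::vector<int32_t> page_mask(ori_valid_pages, 0);
    for (int32_t src_idx : head_batch_indices)
    {
        if (src_idx < 0) // skip any padding entries, if present
        {
            continue;
        }
        int32_t const tok_begin = src_idx * indices_block_size;
        int32_t const tok_end = std::min(tok_begin + indices_block_size, original_seq_len);
        for (int32_t tok = tok_begin; tok < tok_end; ++tok)
        {
            page_mask[tok / tokens_per_page] = 1;
        }
    }

    // 2) Compact the kept pages; the position within dst_to_src_page is the exclusive prefix
    //    sum that BlockScan computes on the device.
    GatherResultRef result;
    int32_t max_page_index = -1;
    for (int32_t page = 0; page < ori_valid_pages; ++page)
    {
        if (page_mask[page] != 0)
        {
            result.dst_to_src_page.push_back(page);
            max_page_index = page;
        }
    }

    // 3) Recompute the per-head sequence length: only a kept last page can be partial.
    int32_t const num_valid_pages = static_cast<int32_t>(result.dst_to_src_page.size());
    if (num_valid_pages > 0)
    {
        result.new_seq_len = (max_page_index == ori_valid_pages - 1)
            ? (num_valid_pages - 1) * tokens_per_page + (original_seq_len - (ori_valid_pages - 1) * tokens_per_page)
            : num_valid_pages * tokens_per_page;
    }
    return result;
}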
7 changes: 6 additions & 1 deletion cpp/tensorrt_llm/kernels/sparseAttentionKernels.h
@@ -35,6 +35,9 @@ struct SparseAttentionParams
int32_t sparse_mla_topk{0}; // for DSA attention
void* sparse_mla_kv_cache_pool{nullptr}; // for DSA attention

int32_t sparse_attn_indices_block_size{1};
int32_t sparse_attn_indices_stride{0};

std::string toString() const
{
std::stringstream ss;
@@ -43,7 +46,9 @@ struct SparseAttentionParams
<< "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
<< "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
<< "sparse_mla_topk: " << this->sparse_mla_topk << std::endl
<< "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl;
<< "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl
<< "sparse_attn_indices_block_size: " << this->sparse_attn_indices_block_size << std::endl
<< "sparse_attn_indices_stride: " << this->sparse_attn_indices_stride << std::endl;
return ss.str();
}
};
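The two new fields decouple the granularity of sparse_attn_indices from the KV cache page size: each index now selects a block of sparse_attn_indices_block_size tokens, which the gather kernel expands into whole pages before compacting them. A short worked example of how that interacts with the recomputed sequence length, using illustrative numbers that are not taken from this PR:

#include <cassert>
#include <cstdint>

// Illustrative numbers only: tokens_per_page = 32, sparse_attn_indices_block_size = 48,
// original_seq_len = 100, so there are 4 original pages and the last one holds 4 tokens.
// Selecting indices {1, 2} covers tokens [48, 100) and therefore marks pages {1, 2, 3}.
int main()
{
    int32_t const tokens_per_page = 32;
    int32_t const original_seq_len = 100;
    int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page; // 4

    int32_t const num_valid_pages = 3; // pages 1, 2, 3 survive the gather
    int32_t const max_page_index = 3;  // the last original page is among them

    // Mirrors the tail handling in gatherKvPageOffsetsKernel: only a kept last page is partial.
    int32_t const seq_len = (max_page_index == ori_valid_pages - 1)
        ? (num_valid_pages - 1) * tokens_per_page + (original_seq_len - (ori_valid_pages - 1) * tokens_per_page)
        : num_valid_pages * tokens_per_page;
    assert(seq_len == 68);
    return 0;
}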
11 changes: 6 additions & 5 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
@@ -64,10 +64,11 @@ void initBindings(nb::module_& m)
nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices") = std::nullopt,
nb::arg("sparse_kv_offsets") = std::nullopt, nb::arg("sparse_attn_indices") = std::nullopt,
nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_mla_topk") = std::nullopt,
nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
"Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_attn_indices_block_size"),
nb::arg("sparse_mla_topk") = std::nullopt, nb::arg("cu_q_seqlens") = std::nullopt,
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
}
} // namespace tensorrt_llm::nanobind::thop
11 changes: 6 additions & 5 deletions cpp/tensorrt_llm/pybind/thop/bindings.cpp
@@ -64,10 +64,11 @@ void initBindings(pybind11::module_& m)
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_mla_topk") = std::nullopt,
py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
"Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
py::arg("sparse_mla_topk") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::thop
37 changes: 21 additions & 16 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -86,10 +86,11 @@ class RunnerBase
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer) const
= 0;
};

@@ -146,10 +147,11 @@ class Runner : public RunnerBase
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -395,6 +397,9 @@ class Runner : public RunnerBase
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_offsets
= sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_indices_block_size = sparse_attn_indices_block_size;
op.mRuntimeSparseAttentionParams.sparse_attn_indices_stride
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().size(-1) : 0;
if (op.isMLAEnabled() && op.mUseSparseAttention)
{
op.mRuntimeSparseAttentionParams.sparse_mla_topk = sparse_mla_topk;
@@ -589,10 +594,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer)
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
@@ -847,8 +852,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
quant_q_buffer);
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}

if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
Expand All @@ -866,8 +871,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
quant_q_buffer);
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}

TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
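The stride recorded in the op layer is simply the size of the last dimension of the sparse_attn_indices tensor, and sparse_attn_offsets[batch_idx] / [batch_idx + 1] delimit the valid entries per sequence. This replaces the previous head_idx * total_pages addressing, which assumed the indices were packed to exactly total_pages entries per head; using the tensor's own row length makes the layout explicit. A sketch of the resulting indexing; the helper name is illustrative and not part of this PR:

#include <cstdint>

// Flat element offset that gatherKvPageOffsetsKernel reads for entry i of (head_idx, batch_idx),
// matching sparse_idx_global + i in the kernel above.
inline int64_t sparseAttnIndicesElement(
    int32_t head_idx, int32_t sparse_attn_indices_stride, int32_t start_offset, int32_t i)
{
    // Head-major layout: each KV head owns a row of sparse_attn_indices_stride entries.
    return static_cast<int64_t>(head_idx) * sparse_attn_indices_stride + start_offset + i;
}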