Merged
155 changes: 134 additions & 21 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
indexer_max_chunk_size: int
# Topk for sparse MLA
sparse_mla_topk: int
# max number of draft tokens
max_draft_tokens: int = 0

def __init__(self, *args, **kwargs):
self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -432,6 +434,64 @@ def __post_init__(self):
dtype=torch.int32,
capture_graph=capture_graph,
)
self.create_expanded_buffers(capture_graph=capture_graph)

# TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
def create_expanded_buffers(self, capture_graph=False):
self.kv_lens_expanded_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences * (1 + self.max_draft_tokens), ),
cache_name="kv_lens_expanded_cuda",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.kv_lens_expanded_host = torch.zeros_like(
self.kv_lens_expanded_cuda,
device='cpu',
pin_memory=True,
)
self.block_table_expanded = self.get_empty(
self.cuda_graph_buffers,
[
self.max_num_sequences * (1 + self.max_draft_tokens),
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="block_table_expanded",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.host_block_table_expanded = torch.zeros_like(
self.block_table_expanded,
device='cpu',
pin_memory=True,
)
self.scheduler_metadata_buffer_expanded = self.get_empty(
self.cuda_graph_buffers,
(self.num_sms + 1, 2),
cache_name="scheduler_metadata_buffer_expanded",
dtype=torch.int32,
capture_graph=capture_graph,
)

# This override exists only to recreate the expanded buffers when max_draft_tokens changes.
# TODO: remove this function when fp8_paged_mqa_logits supports MTP > 1.
def update_spec_dec_param(
self,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
):
super().update_spec_dec_param(is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens, spec_decoding_tensor)
self.max_draft_tokens = max_draft_tokens
init_shape = self.kv_lens_expanded_host.shape[0]
if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
capture_graph = torch.cuda.is_current_stream_capturing()
self.create_expanded_buffers(capture_graph=capture_graph)

def prepare(self):
super().prepare()
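
For context on the sizing above, a small arithmetic sketch with made-up numbers (max_num_sequences, max_draft_tokens, and max_blocks_per_seq are assumed values, not taken from this PR):

# Hypothetical sizing: 8 max sequences, MTP with 3 draft tokens per request.
max_num_sequences, max_draft_tokens = 8, 3
max_blocks_per_seq = 128  # assumed value from the KV cache manager config

# One row per (target or draft) token of every generation request.
expanded_rows = max_num_sequences * (1 + max_draft_tokens)            # 32
kv_lens_expanded_shape = (expanded_rows,)                             # (32,)
block_table_expanded_shape = (expanded_rows, max_blocks_per_seq)      # (32, 128)

# update_spec_dec_param() compares this row count against the current buffer
# size and recreates the expanded buffers only when it changes.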
@@ -535,6 +595,41 @@ def prepare(self):
else:
self.max_gen_seq_len = 0

# fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot handle MTP > 1.
# To work around this, when MTP > 1 we flatten the q tensor and expand kv_lens and
# block_table so that fp8_paged_mqa_logits can still be used.
# TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
if self.max_draft_tokens > 1:
# Expand kv_lens_cuda (only generation)
num_tokens = self.num_generations * (1 + self.max_draft_tokens)
gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
(1 + self.max_draft_tokens),
dim=0)
gen_kv_lens_expanded = gen_kv_lens_expanded.transpose(
0, 1).contiguous().flatten()
self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded)
self.kv_lens_expanded_cuda[:num_tokens].copy_(
self.kv_lens_expanded_host[:num_tokens], non_blocking=True)

# Expand indexer_k_cache_block_offsets (only generation)
if self.kv_cache_manager is not None:
block_ids = self.kv_cache_manager.get_batch_cache_indices(
self.request_ids)
gen_block_ids = block_ids[self.num_contexts:]
if len(gen_block_ids) > 0:
# Find max length and create padded tensor
max_len = max(len(bid) for bid in gen_block_ids)
gen_block_tensor = self.host_indexer_k_cache_block_offsets[
self.num_contexts:self.num_seqs, :max_len]
expanded_blocks = gen_block_tensor.repeat_interleave(
1 + self.max_draft_tokens, dim=0)
self.host_block_table_expanded[:num_tokens, :max_len].copy_(
expanded_blocks, non_blocking=True)
self.block_table_expanded[:num_tokens].copy_(
self.host_block_table_expanded[:num_tokens],
non_blocking=True)

# Prepare metadata for indexer
Indexer.prepare(metadata=self)
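
A minimal standalone sketch of the expansion performed above (all values are hypothetical; repeat_interleave is used for brevity and produces the same interleaving as the stack/transpose/flatten in the diff):

import torch

# Hypothetical values: 2 generation requests, MTP with 3 draft tokens each.
max_draft_tokens = 3
gen_kv_lens = torch.tensor([10, 20], dtype=torch.int32)              # per-request KV lengths
gen_block_table = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)  # per-request block ids

# Each request now contributes (1 + max_draft_tokens) length-1 "sequences",
# so its kv_len and block-table row are repeated that many times.
kv_lens_expanded = gen_kv_lens.repeat_interleave(1 + max_draft_tokens)
block_table_expanded = gen_block_table.repeat_interleave(1 + max_draft_tokens, dim=0)

print(kv_lens_expanded.tolist())       # [10, 10, 10, 10, 20, 20, 20, 20]
print(block_table_expanded.shape)      # torch.Size([8, 2])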

@@ -799,12 +894,22 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
if num_generations > 0:
# Prepare schedule metadata for fp8_paged_mqa_logits
# This is a preprocessing step that computes scheduling information for the kernel
gen_seq_lens = metadata.kv_lens_cuda_runtime[
num_contexts:num_contexts + num_generations]
scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
gen_seq_lens, tokens_per_block, metadata.num_sms)
metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
non_blocking=True)
if metadata.max_draft_tokens <= 1:
gen_seq_lens = metadata.kv_lens_cuda_runtime[
num_contexts:num_contexts + num_generations]
scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
gen_seq_lens, tokens_per_block, metadata.num_sms)
metadata.scheduler_metadata_buffer.copy_(
scheduler_metadata_buffer, non_blocking=True)
else:
# Expand schedule metadata buffer (only generation)
num_tokens = metadata.num_generations * (
1 + metadata.max_draft_tokens)
kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens]
scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
kv_lens_expanded, tokens_per_block, metadata.num_sms)
metadata.scheduler_metadata_buffer_expanded.copy_(
scheduler_metadata_buffer_expanded, non_blocking=True)

# Compute slot_mapping for all requests (both context and generation)
# This maps each token to its flat cache position for vectorized KV cache updates
@@ -1053,9 +1158,24 @@ def sparse_attn_indexer(
# Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
...]
q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
batch_size = q_decode.shape[0]
next_n = q_decode.shape[1]
batch_size = num_generations
next_n = num_gen_tokens // num_generations
# fp8_paged_mqa_logits cannot support next_n > 2, so we flatten the q_decode tensor
# and use the corresponding expanded metadata.
if next_n <= 2:
q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
context_lens = metadata.kv_lens_cuda_runtime[
num_contexts:num_contexts + num_generations]
block_table = metadata.indexer_k_cache_block_offsets[
num_contexts:num_contexts + num_generations]
scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
else:
q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
num_tokens = num_generations * (1 + metadata.max_draft_tokens)
context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
block_table = metadata.block_table_expanded[:num_tokens]
scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded

assert num_gen_tokens == batch_size * next_n
weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
num_gen_tokens, ...]
@@ -1064,18 +1184,11 @@
# [num_blocks, tokens_per_block, 1, head_dim + scale_size]
k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
self.layer_idx)
logits_decode = fp8_paged_mqa_logits(
q_decode,
k_cache,
weights_decode,
metadata.kv_lens_cuda_runtime[
num_contexts:num_contexts +
num_generations], # context_lens prepared in prepare()
metadata.indexer_k_cache_block_offsets[
num_contexts:num_contexts +
num_generations], # Only pass generation request block tables
metadata.scheduler_metadata_buffer,
max_seq_len)
logits_decode = fp8_paged_mqa_logits(q_decode, k_cache,
weights_decode, context_lens,
block_table,
scheduler_metadata_buffer,
max_seq_len)

if use_custom_topk:
# Kernel expects kv_lens (total cache length), not seq_lens (new tokens)
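
For the decode path above, a small sketch (made-up shapes, fp32 stand-in for the fp8 query) of how the query layout changes when next_n > 2: every draft position becomes its own batch entry of length 1, matching the expanded kv_lens/block_table rows passed to fp8_paged_mqa_logits:

import torch

# Hypothetical shapes: 2 generation requests, next_n = 4 (1 target + 3 draft tokens),
# 16 query heads, head_dim 64.
num_generations, next_n, heads, dim = 2, 4, 16, 64
q_decode = torch.randn(num_generations * next_n, heads, dim)

# next_n <= 2: keep the [batch, next_n, heads, dim] layout the kernel supports natively.
q_grouped = q_decode.view(num_generations, -1, heads, dim)   # [2, 4, 16, 64]

# next_n > 2: flatten so each token is a length-1 "sequence"; row i of the expanded
# kv_lens / block_table corresponds to row i of this flattened batch.
q_flat = q_decode.view(-1, 1, heads, dim)                    # [8, 1, 16, 64]

assert q_flat.shape[0] == num_generations * next_n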
4 changes: 2 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2380,7 +2380,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
(8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT"),
(8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT"),
(8, 1, 8, 0, True, True, True, True, 24, "_DEFAULT"),
(8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
(8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
],
ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
@@ -2448,7 +2448,7 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
(8, 1, 8, 0, False, True, True, True, 24, "CUTLASS"),
(8, 1, 8, 1, False, True, True, True, 24, "CUTLASS"),
(8, 1, 8, 0, True, True, True, True, 24, "CUTLASS"),
(8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
(8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
],
ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
38 changes: 29 additions & 9 deletions tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
@@ -381,7 +381,8 @@ def _create_mock_metadata(request_ids,
cache_manager,
num_ctx_tokens,
num_tokens,
indexer_max_chunk_size=8194):
indexer_max_chunk_size=8194,
max_draft_tokens=0):
"""Helper to create mock metadata for testing."""

class MockKVCacheParams:
@@ -396,6 +397,7 @@ def __init__(self):
self.request_ids = request_ids
self.num_contexts = num_contexts
self.num_generations = num_generations
self.max_draft_tokens = max_draft_tokens
# Keep seq_lens on CPU for split_prefill_chunks and other CPU operations
# CUDA kernels will convert to CUDA as needed
self.seq_lens = seq_lens.cpu() if seq_lens.is_cuda else seq_lens
@@ -826,6 +828,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
cache_manager=cache_manager,
num_ctx_tokens=total_context_tokens,
num_tokens=total_context_tokens,
max_draft_tokens=next_n - 1,
)
Indexer.prepare(metadata_context)

@@ -851,6 +854,7 @@
cache_manager=cache_manager,
num_ctx_tokens=0,
num_tokens=batch_size * num_gen_tokens,
max_draft_tokens=next_n - 1,
)
Indexer.prepare(metadata_gen)

@@ -1418,6 +1422,7 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
cache_manager=cache_manager,
num_ctx_tokens=total_context_tokens,
num_tokens=total_context_tokens,
max_draft_tokens=next_n - 1,
)
Indexer.prepare(metadata_context)
indexer._update_k_cache(k_context_fp8, k_context_scale, metadata_context)
@@ -1450,16 +1455,24 @@
cache_manager=cache_manager,
num_ctx_tokens=0,
num_tokens=num_gen_tokens,
max_draft_tokens=next_n - 1,
)
Indexer.prepare(metadata_gen_write)
indexer._update_k_cache(k_fp8, k_scale, metadata_gen_write)

# Test with custom CUDA kernel
metadata_custom = _create_mock_metadata(request_ids, batch_size, 0,
batch_size, seq_lens.clone(),
metadata_custom = _create_mock_metadata(request_ids,
batch_size,
0,
batch_size,
seq_lens.clone(),
final_lens.clone(),
num_cached_tokens, cache_manager, 0,
num_gen_tokens, max_model_len)
num_cached_tokens,
cache_manager,
0,
num_gen_tokens,
max_model_len,
max_draft_tokens=next_n - 1)

Indexer.prepare(metadata_custom)
indexer._update_k_cache(k_fp8, k_scale, metadata_custom)
@@ -1476,11 +1489,18 @@
pytest.skip(f"Custom topk not available: {e}")

# Test with PyTorch fallback
metadata_fallback = _create_mock_metadata(request_ids, batch_size, 0,
batch_size, seq_lens.clone(),
metadata_fallback = _create_mock_metadata(request_ids,
batch_size,
0,
batch_size,
seq_lens.clone(),
final_lens.clone(),
num_cached_tokens, cache_manager,
0, num_gen_tokens, max_model_len)
num_cached_tokens,
cache_manager,
0,
num_gen_tokens,
max_model_len,
max_draft_tokens=next_n - 1)

Indexer.prepare(metadata_fallback)
indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)