35 changes: 19 additions & 16 deletions lightllm/models/deepseek2/flashinfer_struct.py
@@ -3,7 +3,7 @@
import numpy as np
import torch.distributed as dist
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import get_env_start_args, disable_trtllm_ragged_prefill
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index


@@ -68,22 +68,25 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
if get_env_start_args().enable_flashinfer_prefill:
q_starts = self.b1_cu_q_seq_len.int()
kv_starts = self.b1_kv_start_loc.int()
if self.prefill_wrapper is None:
self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_extra_state.workspace_buffer, "NHD"
if disable_trtllm_ragged_prefill():
if self.prefill_wrapper is None:
self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_extra_state.workspace_buffer, "NHD"
)
self.prefill_wrapper.plan(
qo_indptr=q_starts,
kv_indptr=kv_starts,
num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+ self.flashinfer_extra_state.qk_rope_head_dim,
head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
q_data_type=self.flashinfer_extra_state.q_data_type,
causal=True,
sm_scale=self.flashinfer_extra_state.softmax_scale,
)
self.prefill_wrapper.plan(
qo_indptr=q_starts,
kv_indptr=kv_starts,
num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+ self.flashinfer_extra_state.qk_rope_head_dim,
head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
q_data_type=self.flashinfer_extra_state.q_data_type,
causal=True,
sm_scale=self.flashinfer_extra_state.softmax_scale,
)
else:
self.prefill_wrapper = None
return

def copy_for_cuda_graph(self, new_infer_state):
124 changes: 117 additions & 7 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -4,6 +4,7 @@
import torch.distributed as dist
import numpy as np
import triton
import flashinfer
from typing import Tuple
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
@@ -26,7 +27,7 @@
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import get_env_start_args, disable_trtllm_ragged_prefill
from lightllm.utils.dist_utils import get_global_world_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -113,9 +114,14 @@ def _bind_attention(self):
if self.enable_cc_method:
if "triton_fp8kv" in self.mode:
if get_env_start_args().enable_flashinfer_prefill:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
)
if disable_trtllm_ragged_prefill():
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_trtllm_ragged_with_CC_fp8, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
@@ -126,9 +132,14 @@
Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self
)
elif get_env_start_args().enable_flashinfer_prefill:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
)
if disable_trtllm_ragged_prefill():
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_trtllm_ragged_with_CC, self
)
else:
self._context_attention_kernel = partial(
Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
@@ -466,6 +477,105 @@ def _context_attention_flashinfer_kernel_with_CC_fp8(
infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
return o_tensor

def _context_attention_trtllm_ragged_with_CC(
self,
q: torch.Tensor,
kv,
infer_state: Deepseek2FlashInferStateInfo,
layer_weight: Deepseek2TransformerLayerWeight,
out=None,
) -> torch.Tensor:
k_nope, k_rope, v = self._decompress_kv(
kv,
infer_state,
layer_weight,
False,
infer_state.total_token_num,
infer_state.b_seq_len,
infer_state.max_value_in_b_seq_len,
infer_state.b1_kv_start_loc,
)
o_tensor = (
self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
)
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)

seq_lens = infer_state.b_seq_len.int()
cum_seq_lens = infer_state.b1_cu_q_seq_len.int()
max_seq_len = int(seq_lens.max().item())

o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
key=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
value=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
workspace_buffer=infer_state.flashinfer_extra_state.workspace_buffer,
seq_lens=seq_lens,
max_q_len=max_seq_len,
max_kv_len=max_seq_len,
bmm1_scale=self.softmax_scale,
bmm2_scale=1.0,
o_sf_scale=1.0,
batch_size=infer_state.batch_size,
window_left=-1,
cum_seq_lens_q=cum_seq_lens,
cum_seq_lens_kv=cum_seq_lens,
enable_pdl=False,
is_causal=True,
return_lse=False,
)
o_tensor.copy_(o)
return o_tensor

def _context_attention_trtllm_ragged_with_CC_fp8(
self,
q: torch.Tensor,
kv,
infer_state: Deepseek2FlashInferStateInfo,
layer_weight: Deepseek2TransformerLayerWeight,
out=None,
) -> torch.Tensor:
k_nope, k_rope, v = self._decompress_kv(
kv,
infer_state,
layer_weight,
True,
infer_state.total_token_num,
infer_state.b_seq_len,
infer_state.max_value_in_b_seq_len,
infer_state.b1_kv_start_loc,
)
o_tensor = (
self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
)
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)

seq_lens = infer_state.b_seq_len.int()
cum_seq_lens = infer_state.b1_cu_q_seq_len.int()
max_seq_len = int(seq_lens.max().item())

o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
key=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
value=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
workspace_buffer=infer_state.flashinfer_extra_state.workspace_buffer,
seq_lens=seq_lens,
max_q_len=max_seq_len,
max_kv_len=max_seq_len,
bmm1_scale=self.softmax_scale,
bmm2_scale=1.0,
o_sf_scale=1.0,
batch_size=infer_state.batch_size,
window_left=-1,
cum_seq_lens_q=cum_seq_lens,
cum_seq_lens_kv=cum_seq_lens,
enable_pdl=False,
is_causal=True,
return_lse=False,
)
o_tensor.copy_(o)
return o_tensor
Comment on lines +529 to +577 (Contributor, severity: medium):
The method _context_attention_trtllm_ragged_with_CC_fp8 is nearly identical to _context_attention_trtllm_ragged_with_CC. The only difference is the boolean value True passed for the is_fp8 parameter to self._decompress_kv. This significant code duplication makes the code harder to maintain and more prone to errors if one function is updated and the other is not.

To improve maintainability, consider refactoring these two methods into a single private helper that accepts an is_fp8 boolean parameter.
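
A minimal sketch of the suggested refactor, intended to live inside Deepseek2TransformerLayerInfer and assuming the attributes and the _decompress_kv signature shown in this diff; the helper name _context_attention_trtllm_ragged_impl is hypothetical, not part of the PR:

    def _context_attention_trtllm_ragged_impl(
        self,
        q: torch.Tensor,
        kv,
        infer_state: Deepseek2FlashInferStateInfo,
        layer_weight: Deepseek2TransformerLayerWeight,
        is_fp8: bool,
        out=None,
    ) -> torch.Tensor:
        # Shared body of the two ragged-prefill variants; only the is_fp8 flag differs.
        k_nope, k_rope, v = self._decompress_kv(
            kv,
            infer_state,
            layer_weight,
            is_fp8,
            infer_state.total_token_num,
            infer_state.b_seq_len,
            infer_state.max_value_in_b_seq_len,
            infer_state.b1_kv_start_loc,
        )
        o_tensor = (
            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
        )
        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)

        seq_lens = infer_state.b_seq_len.int()
        cum_seq_lens = infer_state.b1_cu_q_seq_len.int()
        max_seq_len = int(seq_lens.max().item())

        o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
            query=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
            key=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
            value=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
            workspace_buffer=infer_state.flashinfer_extra_state.workspace_buffer,
            seq_lens=seq_lens,
            max_q_len=max_seq_len,
            max_kv_len=max_seq_len,
            bmm1_scale=self.softmax_scale,
            bmm2_scale=1.0,
            o_sf_scale=1.0,
            batch_size=infer_state.batch_size,
            window_left=-1,
            cum_seq_lens_q=cum_seq_lens,
            cum_seq_lens_kv=cum_seq_lens,
            enable_pdl=False,
            is_causal=True,
            return_lse=False,
        )
        o_tensor.copy_(o)
        return o_tensor

    def _context_attention_trtllm_ragged_with_CC(self, q, kv, infer_state, layer_weight, out=None):
        return self._context_attention_trtllm_ragged_impl(q, kv, infer_state, layer_weight, is_fp8=False, out=out)

    def _context_attention_trtllm_ragged_with_CC_fp8(self, q, kv, infer_state, layer_weight, out=None):
        return self._context_attention_trtllm_ragged_impl(q, kv, infer_state, layer_weight, is_fp8=True, out=out)

The partial(...) bindings in _bind_attention would be unchanged, since both public method names are preserved.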


def _context_attention_kernel_with_CC(
self,
q: torch.Tensor,
5 changes: 5 additions & 0 deletions lightllm/utils/envs_utils.py
@@ -194,3 +194,8 @@ def enable_radix_tree_timer_merge() -> bool:
@lru_cache(maxsize=None)
def get_radix_tree_merge_update_delta() -> int:
return int(os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000))


@lru_cache(maxsize=None)
def disable_trtllm_ragged_prefill():
return enable_env_vars("DISABLE_TRTLLM_RAGGED_PREFILL")
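
Usage note: the TRT-LLM ragged prefill path is on by default and this flag opts out of it. A minimal sketch, assuming enable_env_vars accepts the usual LightLLM truthy strings (e.g. "1"/"TRUE"/"ON"); the exact convention is defined by enable_env_vars, which is not part of this diff:

    import os

    # Hypothetical opt-out: fall back from trtllm_ragged_attention_deepseek to the
    # FlashInfer BatchPrefillWithRaggedKVCacheWrapper prefill path.
    # The accepted truthy values are an assumption; check enable_env_vars for the
    # exact convention. Set this before the first call, since the result is lru_cache-d.
    os.environ["DISABLE_TRTLLM_RAGGED_PREFILL"] = "1"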