diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py
index a00c45601..cc66c74bd 100644
--- a/lightllm/models/deepseek2/flashinfer_struct.py
+++ b/lightllm/models/deepseek2/flashinfer_struct.py
@@ -3,7 +3,7 @@ import numpy as np
 import torch.distributed as dist
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, disable_trtllm_ragged_prefill
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
@@ -68,22 +68,25 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if get_env_start_args().enable_flashinfer_prefill:
             q_starts = self.b1_cu_q_seq_len.int()
             kv_starts = self.b1_kv_start_loc.int()
-            if self.prefill_wrapper is None:
-                self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_extra_state.workspace_buffer, "NHD"
+            if disable_trtllm_ragged_prefill():
+                if self.prefill_wrapper is None:
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+                        self.flashinfer_extra_state.workspace_buffer, "NHD"
+                    )
+                self.prefill_wrapper.plan(
+                    qo_indptr=q_starts,
+                    kv_indptr=kv_starts,
+                    num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+                    + self.flashinfer_extra_state.qk_rope_head_dim,
+                    head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
+                    q_data_type=self.flashinfer_extra_state.q_data_type,
+                    causal=True,
+                    sm_scale=self.flashinfer_extra_state.softmax_scale,
                 )
-            self.prefill_wrapper.plan(
-                qo_indptr=q_starts,
-                kv_indptr=kv_starts,
-                num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
-                num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
-                head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
-                + self.flashinfer_extra_state.qk_rope_head_dim,
-                head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
-                q_data_type=self.flashinfer_extra_state.q_data_type,
-                causal=True,
-                sm_scale=self.flashinfer_extra_state.softmax_scale,
-            )
+            else:
+                self.prefill_wrapper = None
         return

     def copy_for_cuda_graph(self, new_infer_state):
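# --- Illustrative note (not part of the patch) ------------------------------
# A minimal sketch of what the gated branch above sets up when
# DISABLE_TRTLLM_RAGGED_PREFILL is enabled: the flashinfer ragged-prefill
# wrapper is created once and re-planned per batch. All concrete numbers here
# (head count 16, DeepSeek-style dims qk_nope=128 / qk_rope=64, workspace
# size, the toy indptr tensors, the sm_scale value) are placeholder
# assumptions; the real values come from flashinfer_extra_state.
import torch
import flashinfer

workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")

q_starts = torch.tensor([0, 5, 9], dtype=torch.int32, device="cuda")  # cumulative q lens for 2 reqs
kv_starts = q_starts  # ragged prefill runs over the same decompressed tokens
wrapper.plan(
    qo_indptr=q_starts,
    kv_indptr=kv_starts,
    num_qo_heads=16,       # tp_q_head_num
    num_kv_heads=16,       # MLA kv is decompressed to one head per q head
    head_dim_qk=128 + 64,  # qk_nope_head_dim + qk_rope_head_dim
    head_dim_vo=128,       # output carries only the nope dim
    q_data_type=torch.bfloat16,
    causal=True,
    sm_scale=(128 + 64) ** -0.5,  # placeholder; the model supplies softmax_scale
)
# On the default (non-disabled) path, prefill_wrapper is left as None and the
# attention layer calls the TRT-LLM ragged kernel directly (see next file).
# -----------------------------------------------------------------------------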
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 30d37d1df..f80d66e27 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -4,6 +4,7 @@ import torch.distributed as dist
 import numpy as np
 import triton
+import flashinfer
 from typing import Tuple
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
@@ -26,7 +27,7 @@ from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, disable_trtllm_ragged_prefill
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -113,9 +114,14 @@ def _bind_attention(self):
         if self.enable_cc_method:
             if "triton_fp8kv" in self.mode:
                 if get_env_start_args().enable_flashinfer_prefill:
-                    self._context_attention_kernel = partial(
-                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
-                    )
+                    if disable_trtllm_ragged_prefill():
+                        self._context_attention_kernel = partial(
+                            Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
+                        )
+                    else:
+                        self._context_attention_kernel = partial(
+                            Deepseek2TransformerLayerInfer._context_attention_trtllm_ragged_with_CC_fp8, self
+                        )
                 else:
                     self._context_attention_kernel = partial(
                         Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
@@ -126,9 +132,14 @@ def _bind_attention(self):
                     Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self
                 )
             elif get_env_start_args().enable_flashinfer_prefill:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
-                )
+                if disable_trtllm_ragged_prefill():
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_trtllm_ragged_with_CC, self
+                    )
             else:
                 self._context_attention_kernel = partial(
                     Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
@@ -466,6 +477,104 @@ def _context_attention_flashinfer_kernel_with_CC_fp8(
         infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
         return o_tensor
+
+    def _context_attention_trtllm_ragged_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2FlashInferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            False,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b1_kv_start_loc,
+        )
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+
+        seq_lens = infer_state.b_seq_len.int()
+        cum_seq_lens = infer_state.b1_cu_q_seq_len.int()
+        max_seq_len = int(seq_lens.max().item())
+
+        o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            key=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            value=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
+            workspace_buffer=infer_state.flashinfer_extra_state.workspace_buffer,
+            seq_lens=seq_lens,
+            max_q_len=max_seq_len,
+            max_kv_len=max_seq_len,
+            bmm1_scale=self.softmax_scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=infer_state.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=cum_seq_lens,
+            cum_seq_lens_kv=cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=False,
+        )
+        o_tensor.copy_(o)
+        return o_tensor
+
+    def _context_attention_trtllm_ragged_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2FlashInferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            True,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b1_kv_start_loc,
+        )
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+
+        seq_lens = infer_state.b_seq_len.int()
+        cum_seq_lens = infer_state.b1_cu_q_seq_len.int()
+        max_seq_len = int(seq_lens.max().item())
+
+        o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            key=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            value=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
+            workspace_buffer=infer_state.flashinfer_extra_state.workspace_buffer,
+            seq_lens=seq_lens,
+            max_q_len=max_seq_len,
+            max_kv_len=max_seq_len,
+            bmm1_scale=self.softmax_scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=infer_state.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=cum_seq_lens,
+            cum_seq_lens_kv=cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=False,
+        )
+        o_tensor.copy_(o)
+        return o_tensor
+
     def _context_attention_kernel_with_CC(
         self,
         q: torch.Tensor,
         kv,
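# --- Illustrative note (not part of the patch) ------------------------------
# Shape contract for the new TRT-LLM path, mirroring the kwargs used in
# _context_attention_trtllm_ragged_with_CC above. Tensor sizes, head counts
# and scales are toy assumptions: q/k carry nope(128)+rope(64)=192 dims per
# head after the single rope head is repeat_interleave'd across all q heads,
# while v and the output carry only the 128 nope dims, matching the o_tensor
# the kernel copies into.
import torch
import flashinfer

tokens, heads = 9, 16
q = torch.randn(tokens, heads, 192, dtype=torch.bfloat16, device="cuda")
k = torch.randn(tokens, heads, 192, dtype=torch.bfloat16, device="cuda")
v = torch.randn(tokens, heads, 128, dtype=torch.bfloat16, device="cuda")
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")
seq_lens = torch.tensor([5, 4], dtype=torch.int32, device="cuda")
cu_seq_lens = torch.tensor([0, 5, 9], dtype=torch.int32, device="cuda")

o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
    query=q,
    key=k,
    value=v,
    workspace_buffer=workspace_buffer,
    seq_lens=seq_lens,
    max_q_len=5,
    max_kv_len=5,
    bmm1_scale=192 ** -0.5,  # placeholder for the model's softmax_scale
    bmm2_scale=1.0,
    o_sf_scale=1.0,
    batch_size=2,
    window_left=-1,          # no sliding window
    cum_seq_lens_q=cu_seq_lens,
    cum_seq_lens_kv=cu_seq_lens,
    enable_pdl=False,
    is_causal=True,
    return_lse=False,
)
# Expected: o.shape == (tokens, heads, 128), i.e. (total tokens, tp_q_head_num,
# qk_nope_head_dim).
# -----------------------------------------------------------------------------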
diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index 7c221c574..2ba46051e 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -194,3 +194,8 @@ def enable_radix_tree_timer_merge() -> bool:
 @lru_cache(maxsize=None)
 def get_radix_tree_merge_update_delta() -> int:
     return int(os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000))
+
+
+@lru_cache(maxsize=None)
+def disable_trtllm_ragged_prefill() -> bool:
+    return enable_env_vars("DISABLE_TRTLLM_RAGGED_PREFILL")
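# --- Illustrative note (not part of the patch) ------------------------------
# A self-contained stand-in showing how the new flag behaves. enable_env_vars
# is lightllm's existing truthy-env helper; the simplified version below is an
# assumption about its semantics, not a copy of it.
import os
from functools import lru_cache


def enable_env_vars(name: str) -> bool:  # simplified stand-in for the lightllm helper
    return os.getenv(name, "False").upper() in ("ON", "TRUE", "1")


@lru_cache(maxsize=None)
def disable_trtllm_ragged_prefill() -> bool:
    return enable_env_vars("DISABLE_TRTLLM_RAGGED_PREFILL")


# Usage: the TRT-LLM ragged prefill kernel is the new default whenever
# enable_flashinfer_prefill is on; export DISABLE_TRTLLM_RAGGED_PREFILL=true
# to fall back to the BatchPrefillWithRaggedKVCacheWrapper path. Because of
# lru_cache, the flag is read once per process, so it must be set before the
# server starts.
# -----------------------------------------------------------------------------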