diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
index b273c333f..dd27103b4 100755
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -43,7 +43,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
     def copy_for_cuda_graph(self, new_infer_state):
         for attr_name, attr_value in vars(new_infer_state).items():
             if isinstance(attr_value, torch.Tensor):
-                attr_ = getattr(self, attr_name)
+                attr_ = getattr(self, attr_name, None)
                 if attr_ is not None:
                     attr_.copy_(attr_value)
         return
diff --git a/lightllm/common/basemodel/layer_infer/base_layer_infer.py b/lightllm/common/basemodel/layer_infer/base_layer_infer.py
index f032dbee8..5846fd850 100644
--- a/lightllm/common/basemodel/layer_infer/base_layer_infer.py
+++ b/lightllm/common/basemodel/layer_infer/base_layer_infer.py
@@ -22,7 +22,7 @@ def splitfuse_forward(self, input_ids, infer_state: SplitFuseInferStateInfo, lay
     def alloc_tensor(
         self,
         shape: Union[torch.Size, Iterable[int]],
-        data_type: torch.dtype,
+        dtype: torch.dtype,
         device: str = "cuda",
         is_graph_out: bool = False,
     ) -> torch.Tensor:
@@ -30,4 +30,4 @@ def alloc_tensor(
         is_graph_out 用于标记是graph图推理中的最后一个tensor,该参数只会在开启cuda graph时生效。该tensor的复用有特殊的逻辑,用于降低显存
         占用
         """
-        return g_cache_manager.alloc_tensor(shape, data_type, device=device, is_graph_out=is_graph_out)
+        return g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=is_graph_out)
diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
index 9f3ad01d4..4d9c44891 100644
--- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
+++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
@@ -129,11 +129,14 @@ def cache_env_out(self):
     def alloc_tensor(
         self,
-        shape: Union[torch.Size, Iterable[int]],
+        shape: Union[torch.Size, Tuple[int, ...]],
         data_type: torch.dtype,
         device: str = "cuda",
         is_graph_out: bool = False,
     ) -> torch.Tensor:
+        # shape 类型转换
+        if isinstance(shape, list):
+            shape = torch.Size(shape)
         # 是 cuda graph的时候,由cuda graph manager 接管
         if self.is_cuda_graph:
             return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py
index 7a24ca8fe..d4695877e 100644
--- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py
@@ -27,7 +27,7 @@ def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight:
         input_ids = input_ids[0:total_token_num]
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
@@ -37,7 +37,7 @@ def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight:
     def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight):
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
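The hunks above rename the allocator's data_type parameter to dtype and let the cache manager accept a plain Python list as the shape by coercing it to torch.Size. A minimal, self-contained sketch of the resulting call pattern follows; the alloc_tensor below is a simplified stand-in that falls back to torch.empty rather than the real g_cache_manager, and the is_graph_out flag is ignored here.

    import torch
    from typing import Iterable, Union

    def alloc_tensor(
        shape: Union[torch.Size, Iterable[int]],
        dtype: torch.dtype,
        device: str = "cuda",
        is_graph_out: bool = False,
    ) -> torch.Tensor:
        # Simplified stand-in for g_cache_manager.alloc_tensor: coerce list shapes
        # to torch.Size as the patched cache_tensor_manager does, then allocate.
        # The real manager reuses cached buffers and treats is_graph_out specially.
        if isinstance(shape, list):
            shape = torch.Size(shape)
        return torch.empty(shape, dtype=dtype, device=device)

    a = alloc_tensor((4, 128), dtype=torch.float16, device="cpu")  # tuple shape
    b = alloc_tensor([4, 128], dtype=torch.float32, device="cpu")  # list shape, coerced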
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 7eaddfcd4..50972d7ae 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -151,6 +151,7 @@ def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateI
             self.qk_rope_head_dim,
             self.qk_nope_head_dim,
             self.softmax_scale,
+            alloc_tensor_func=self.alloc_tensor,
         )
     def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
diff --git a/lightllm/models/deepseek2/triton_kernel/flash_decoding.py b/lightllm/models/deepseek2/triton_kernel/flash_decoding.py
index 80acd2017..f0ba13b54 100644
--- a/lightllm/models/deepseek2/triton_kernel/flash_decoding.py
+++ b/lightllm/models/deepseek2/triton_kernel/flash_decoding.py
@@ -13,6 +13,7 @@ def token_decode_attention_flash_decoding(
     qk_nope_head_dim,
     softmax_scale,
     out=None,
+    alloc_tensor_func=torch.empty,
 ):
     if kv_lora_rank > 128:
         BLOCK_SEQ = 256 // (kv_lora_rank // 128)
@@ -24,20 +25,15 @@ def token_decode_attention_flash_decoding(
     from lightllm.models.deepseek2.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
     from lightllm.models.deepseek2.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
-    o_tensor = torch.empty_like(q_nope) if out is None else out
-
-    if getattr(infer_state, "mid_o", None) is None:
-        infer_state.mid_o = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank],
-            dtype=torch.float32,
-            device="cuda",
-        )
-        infer_state.mid_o_logexpsum = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
-        )
-
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
+    o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank],
+        dtype=torch.float32,
+        device="cuda",
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+    )
     flash_decode_stage1(
         q_nope.view(calcu_shape1),
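The same restructuring recurs in every flash-decoding wrapper below: instead of lazily caching mid_o and mid_o_logexpsum on infer_state, the intermediate buffers are requested from the injected alloc_tensor_func on every call (torch.empty by default), so a layer can pass self.alloc_tensor and let the cache manager, which has its own handling for CUDA graphs, reuse the memory. A standalone sketch of just that allocation step, using an illustrative helper name (alloc_mid_buffers) and example sizes, assuming the keyword-style signature shown above:

    import torch

    def alloc_mid_buffers(batch_size, q_head_num, max_len_in_batch, head_dim,
                          block_seq=256, alloc_tensor_func=torch.empty, device="cuda"):
        # One partial result per BLOCK_SEQ-sized chunk of the longest sequence,
        # allocated per call through the injected allocator (no infer_state caching).
        num_blocks = max_len_in_batch // block_seq + 1
        mid_o = alloc_tensor_func(
            [batch_size, q_head_num, num_blocks, head_dim], dtype=torch.float32, device=device
        )
        mid_o_logexpsum = alloc_tensor_func(
            [batch_size, q_head_num, num_blocks], dtype=torch.float32, device=device
        )
        return mid_o, mid_o_logexpsum

    mid_o, mid_o_logexpsum = alloc_mid_buffers(8, 32, 1024, 128, device="cpu")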
diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py
index 7cd7721a6..17555de03 100644
--- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py
@@ -25,7 +25,7 @@ def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight)
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight):
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
@@ -35,7 +35,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
     def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight):
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
diff --git a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py
index fb453194e..bf594f97a 100755
--- a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py
@@ -18,7 +18,7 @@ def _get_qkv(
         self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight
     ) -> torch.Tensor:
         input = input.view(-1, self.embed_dim_)
-        q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), data_type=input.dtype)
+        q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), dtype=input.dtype)
         torch.addmm(layer_weight.q_bias_, input, layer_weight.q_weight_, beta=1.0, alpha=1.0, out=q)
         torch.addmm(
             layer_weight.kv_bias_,
diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py
index 32033d312..a642a0fe0 100644
--- a/lightllm/models/llama/layer_infer/post_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -30,7 +30,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
         if infer_state.is_splitfuse:
             # for SplitFuse
             batch_size = infer_state.batch_size
-            last_input = self.alloc_tensor((batch_size, self.embed_dim_), data_type=input_embdings.dtype)
+            last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype)
             tmp_ = torch.cat(
                 [
                     torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
@@ -54,13 +54,13 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
                 select_token_num += 1
             last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
-            last_input = self.alloc_tensor((select_token_num, self.embed_dim_), data_type=input_embdings.dtype)
+            last_input = self.alloc_tensor((select_token_num, self.embed_dim_), dtype=input_embdings.dtype)
             last_input[:, :] = input_embdings[last_index, :]
             return last_input, select_token_num
         if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
             batch_size = infer_state.batch_size
-            last_input = self.alloc_tensor((batch_size, self.embed_dim_), data_type=input_embdings.dtype)
+            last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype)
             last_index = (
                 torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
             )
@@ -84,7 +84,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
         last_input = self._norm(last_input, infer_state, layer_weight)
         last_input = last_input.permute(1, 0).view(-1, token_num)
         logic_batch = self.alloc_tensor(
-            (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), data_type=last_input.dtype
+            (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype
         )
         torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch)
@@ -92,7 +92,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
         if self.world_size_ == 1:
             gather_data = logic_batch
         else:
-            gather_data = self.alloc_tensor((self.vocab_size_, token_num), data_type=input_embdings_dtype)
+            gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
             split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
            dist.all_gather(
                [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
@@ -101,7 +101,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
                async_op=False,
            )
            logic_batch = None
-        ans_logics = self.alloc_tensor((token_num, self.vocab_size_), data_type=torch.float32, is_graph_out=True)
+        ans_logics = self.alloc_tensor((token_num, self.vocab_size_), dtype=torch.float32, is_graph_out=True)
         ans_logics[:, :] = gather_data.permute(1, 0)
         gather_data = None
         return ans_logics
diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py
index ecda4b1d2..f60fa6127 100644
--- a/lightllm/models/llama/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -21,7 +21,7 @@ def __init__(self, tp_rank, world_size, network_config, mode):
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
@@ -30,7 +30,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
     def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index c7a7180e5..31ee52c7a 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -122,7 +122,7 @@ def _ffn_norm(
     def _get_qkv(
         self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), data_type=input.dtype)
+        q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), dtype=input.dtype)
         torch.mm(input, layer_weight.q_weight_, out=q)
         torch.mm(
             input,
@@ -503,7 +503,14 @@ def _token_decode_attention_ppl_fp16_flashdecoding(
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
         ]
         return token_decode_attention_flash_decoding(
-            q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
+            q,
+            infer_state,
+            self.tp_q_head_num_,
+            self.head_dim_,
+            cache_k,
+            cache_v,
+            out=out,
+            alloc_tensor_func=self.alloc_tensor,
         )
     def _token_decode_attention_ppl_int8kv_flashdecoding(
@@ -520,7 +527,16 @@ def _token_decode_attention_ppl_int8kv_flashdecoding(
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
         ]
         return token_decode_attention_flash_decoding(
-            q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
+            q,
+            infer_state,
+            self.tp_q_head_num_,
+            self.head_dim_,
+            cache_k,
+            cache_k_scale,
+            cache_v,
+            cache_v_scale,
+            out=out,
+            alloc_tensor_func=self.alloc_tensor,
         )
     def _token_decode_attention_ppl_int4kv_flashdecoding(
@@ -537,5 +553,14 @@ def _token_decode_attention_ppl_int4kv_flashdecoding(
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
         ]
         return token_decode_attention_flash_decoding(
-            q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
+            q,
+            infer_state,
+            self.tp_q_head_num_,
+            self.head_dim_,
+            cache_k,
+            cache_k_scale,
+            cache_v,
+            cache_v_scale,
+            out=out,
+            alloc_tensor_func=self.alloc_tensor,
         )
diff --git a/lightllm/models/llama/triton_kernel/embedding.py b/lightllm/models/llama/triton_kernel/embedding.py
index 2664c8dc7..6178d7a41 100644
--- a/lightllm/models/llama/triton_kernel/embedding.py
+++ b/lightllm/models/llama/triton_kernel/embedding.py
@@ -64,7 +64,7 @@ def embedding(input_ids, weight: torch.Tensor, vob_start_id, vob_end_id, out: to
     @torch.no_grad()
     def embedding_new(input_ids, weight, vob_start_id, vob_end_id):
-        # out = self.alloc_tensor((N_CTX, DIM), data_type=torch.float32)
+        # out = self.alloc_tensor((N_CTX, DIM), dtype=torch.float32)
         out = torch.empty((N_CTX, DIM), device="cuda", requires_grad=False)
         embedding(input_ids, weight, vob_start_id, vob_end_id, out)
diff --git a/lightllm/models/llama/triton_kernel/flash_decoding.py b/lightllm/models/llama/triton_kernel/flash_decoding.py
index f4a5d0404..e47e30886 100644
--- a/lightllm/models/llama/triton_kernel/flash_decoding.py
+++ b/lightllm/models/llama/triton_kernel/flash_decoding.py
@@ -14,16 +14,12 @@ def token_decode_attention_flash_decoding(
     o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-    if getattr(infer_state, "mid_o", None) is None:
-        infer_state.mid_o = alloc_tensor_func(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
-        )
-        infer_state.mid_o_logexpsum = alloc_tensor_func(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
-        )
-
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+    )
     flash_decode_stage1(
         q.view(calcu_shape1),
diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
index 52a767d00..67be7c968 100644
--- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
+++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
@@ -14,16 +14,12 @@ def gqa_token_decode_attention_flash_decoding(
     o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-    if getattr(infer_state, "mid_o", None) is None:
-        infer_state.mid_o = alloc_tensor_func(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
-        )
-        infer_state.mid_o_logexpsum = alloc_tensor_func(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
-        )
-
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+    )
     flash_decode_stage1(
         q.view(calcu_shape1),
diff --git a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py b/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py
index c7eaf4cf9..8fda08460 100644
--- a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py
+++ b/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py
@@ -1,6 +1,9 @@
 import torch
-def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None):
+
+def token_decode_attention_flash_decoding(
+    q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty
+):
     BLOCK_SEQ = 256
     batch_size = infer_state.batch_size
     max_len_in_batch = infer_state.max_len_in_batch
@@ -9,39 +12,28 @@ def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim,
     from lightllm_ppl_fp16_flashdecoding_kernel import fp16_flashdecoding_stage1
     from .flash_decoding_stage2 import flash_decode_stage2
-    o_tensor = torch.empty_like(q) if out is None else out
-
-    if getattr(infer_state, 'mid_o', None) is None:
-        infer_state.mid_o = torch.empty([batch_size,
-                                         q_head_num,
-                                         max_len_in_batch // BLOCK_SEQ + 1,
-                                         head_dim],
-                                        dtype=torch.float16,
-                                        device="cuda")
-        infer_state.mid_o_logexpsum = torch.empty([batch_size,
-                                                   q_head_num,
-                                                   max_len_in_batch // BLOCK_SEQ + 1],
-                                                  dtype=torch.float16,
-                                                  device="cuda")
-
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
+    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
+
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
+    )
+
+    fp16_flashdecoding_stage1(
+        BLOCK_SEQ,
+        mid_o,
+        mid_o_logexpsum,
+        1.0 / (head_dim ** 0.5),
+        q.view(calcu_shape1),
+        cache_k,
+        cache_v,
+        infer_state.req_manager.req_to_token_indexs,
+        infer_state.b_req_idx,
+        infer_state.b_seq_len,
+        infer_state.max_len_in_batch,
+    )
-    fp16_flashdecoding_stage1(BLOCK_SEQ,
-                              mid_o,
-                              mid_o_logexpsum,
-                              1.0 / (head_dim**0.5),
-                              q.view(calcu_shape1),
-                              cache_k,
-                              cache_v,
-                              infer_state.req_manager.req_to_token_indexs,
-                              infer_state.b_req_idx,
-                              infer_state.b_seq_len,
-                              infer_state.max_len_in_batch)
-
-    flash_decode_stage2(mid_o,
-                        mid_o_logexpsum,
-                        infer_state.b_seq_len,
-                        o_tensor.view(calcu_shape1),
-                        BLOCK_SEQ)
+    flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
     return o_tensor
diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py
index 95284cfa9..1e324bcc0 100644
--- a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py
+++ b/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py
@@ -2,7 +2,16 @@
 def token_decode_attention_flash_decoding(
-    q, infer_state, q_head_num, head_dim, cache_k, cache_k_scale, cache_v, cache_v_scale, out=None
+    q,
+    infer_state,
+    q_head_num,
+    head_dim,
+    cache_k,
+    cache_k_scale,
+    cache_v,
+    cache_v_scale,
+    out=None,
+    alloc_tensor_func=torch.empty,
 ):
     BLOCK_SEQ = 256
     batch_size = infer_state.batch_size
@@ -12,18 +21,15 @@ def token_decode_attention_flash_decoding(
     from lightllm_ppl_int4kv_flashdecoding_kernel import group8_int4kv_flashdecoding_stage1
     from .flash_decoding_stage2 import flash_decode_stage2
-    o_tensor = torch.empty_like(q) if out is None else out
+    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-    if getattr(infer_state, "mid_o", None) is None:
-        infer_state.mid_o = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
-        )
-        infer_state.mid_o_logexpsum = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
-        )
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
+    )
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
     group8_int4kv_flashdecoding_stage1(
         BLOCK_SEQ,
         mid_o,
diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py b/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py
index b72097151..8e9ca3e83 100644
--- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py
+++ b/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py
@@ -2,7 +2,16 @@
 def token_decode_attention_flash_decoding(
-    q, infer_state, q_head_num, head_dim, cache_k, cache_k_scale, cache_v, cache_v_scale, out=None
+    q,
+    infer_state,
+    q_head_num,
+    head_dim,
+    cache_k,
+    cache_k_scale,
+    cache_v,
+    cache_v_scale,
+    out=None,
+    alloc_tensor_func=torch.empty,
 ):
     BLOCK_SEQ = 256
     batch_size = infer_state.batch_size
@@ -12,18 +21,15 @@ def token_decode_attention_flash_decoding(
     from lightllm_ppl_int8kv_flashdecoding_kernel import group8_int8kv_flashdecoding_stage1
     from .flash_decoding_stage2 import flash_decode_stage2
-    o_tensor = torch.empty_like(q) if out is None else out
+    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-    if getattr(infer_state, "mid_o", None) is None:
-        infer_state.mid_o = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
-        )
-        infer_state.mid_o_logexpsum = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
-        )
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
+    )
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
     group8_int8kv_flashdecoding_stage1(
         BLOCK_SEQ,
         mid_o,
diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py
index 66614a3fd..7ac3019ce 100755
--- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py
@@ -86,5 +86,12 @@ def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateI
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
         ]
         return token_decode_attention_flash_decoding(
-            q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
+            q,
+            infer_state,
+            self.tp_q_head_num_,
+            self.head_dim_,
+            cache_k,
+            cache_v,
+            out=out,
+            alloc_tensor_func=self.alloc_tensor,
         )
diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding.py b/lightllm/models/phi3/triton_kernel/flash_decoding.py
index 94a3dacac..e47e30886 100644
--- a/lightllm/models/phi3/triton_kernel/flash_decoding.py
+++ b/lightllm/models/phi3/triton_kernel/flash_decoding.py
@@ -1,7 +1,9 @@
 import torch
-def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None):
+def token_decode_attention_flash_decoding(
+    q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty
+):
     BLOCK_SEQ = 256
     batch_size = infer_state.batch_size
     max_len_in_batch = infer_state.max_len_in_batch
@@ -10,18 +12,14 @@ def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim,
     from .flash_decoding_stage1 import flash_decode_stage1
     from .flash_decoding_stage2 import flash_decode_stage2
-    o_tensor = torch.empty_like(q) if out is None else out
+    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-    if getattr(infer_state, "mid_o", None) is None:
-        infer_state.mid_o = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
-        )
-        infer_state.mid_o_logexpsum = torch.empty(
-            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
-        )
-
-    mid_o = infer_state.mid_o
-    mid_o_logexpsum = infer_state.mid_o_logexpsum
+    mid_o = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
+    )
+    mid_o_logexpsum = alloc_tensor_func(
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+    )
     flash_decode_stage1(
         q.view(calcu_shape1),
diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py
index f9f38ed23..9c5f6f88d 100644
--- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py
@@ -27,14 +27,14 @@ def context_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer
         input_ids = input_ids[0:total_token_num]
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
             dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
         position_embeds = self.alloc_tensor(
-            (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), data_type=layer_weight.data_type_
+            (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(
             infer_state.position_ids, layer_weight.wpe_weight_, 0, layer_weight.wpe_weight_.shape[0], position_embeds
@@ -45,14 +45,14 @@ def context_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer
     def token_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer_weight: PreAndPostLayerWeight):
         # import ipdb;ipdb.set_trace()
         input_embdings = self.alloc_tensor(
-            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
+            (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
         if self.world_size_ > 1:
             dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
         position_embeds = self.alloc_tensor(
-            (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), data_type=layer_weight.data_type_
+            (infer_state.position_ids.shape[0], layer_weight.wpe_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(
             infer_state.position_ids, layer_weight.wpe_weight_, 0, layer_weight.wpe_weight_.shape[0], position_embeds
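Taken together, the call-site half of the change is uniform across models: each layer simply forwards its pooled allocator into the decode kernel. A minimal caller-side sketch under those assumptions (the kernel skeleton below is illustrative, not the full implementation; keyword arguments are used so the default torch.empty also accepts them):

    import torch

    def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim,
                                              cache_k, cache_v, out=None,
                                              alloc_tensor_func=torch.empty):
        # Skeleton only: the output buffer comes from the injected allocator
        # instead of torch.empty_like, mirroring the patched kernels above.
        o_tensor = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device) if out is None else out
        # ... stage1 / stage2 kernels would fill o_tensor here ...
        return o_tensor

    # Inside a layer, self.alloc_tensor is threaded through unchanged, e.g.:
    # return token_decode_attention_flash_decoding(
    #     q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v,
    #     out=out, alloc_tensor_func=self.alloc_tensor,
    # )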