Merged
2 changes: 1 addition & 1 deletion lightllm/common/basemodel/infer_struct.py
@@ -43,7 +43,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
def copy_for_cuda_graph(self, new_infer_state):
for attr_name, attr_value in vars(new_infer_state).items():
if isinstance(attr_value, torch.Tensor):
attr_ = getattr(self, attr_name)
attr_ = getattr(self, attr_name, None)
if attr_ is not None:
attr_.copy_(attr_value)
return
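For context, a minimal self-contained sketch of what the added getattr default changes (the _State class and field names below are hypothetical stand-ins for InferStateInfo): tensor attributes present on the incoming state but never created on the graph-captured state are now skipped instead of raising AttributeError.

import torch

class _State:
    pass

def copy_for_cuda_graph(dst, src):
    # same loop as in the hunk above; the None default makes missing attributes a no-op
    for attr_name, attr_value in vars(src).items():
        if isinstance(attr_value, torch.Tensor):
            attr_ = getattr(dst, attr_name, None)
            if attr_ is not None:
                attr_.copy_(attr_value)

graph_state = _State()
graph_state.b_seq_len = torch.zeros(4, dtype=torch.int32)

new_state = _State()
new_state.b_seq_len = torch.arange(4, dtype=torch.int32)
new_state.extra_buf = torch.ones(4)  # absent on graph_state: skipped rather than an error

copy_for_cuda_graph(graph_state, new_state)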
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/layer_infer/base_layer_infer.py
@@ -22,12 +22,12 @@ def splitfuse_forward(self, input_ids, infer_state: SplitFuseInferStateInfo, lay
def alloc_tensor(
self,
shape: Union[torch.Size, Iterable[int]],
data_type: torch.dtype,
dtype: torch.dtype,
device: str = "cuda",
is_graph_out: bool = False,
) -> torch.Tensor:
"""
is_graph_out marks the last tensor produced in a CUDA graph inference pass; the flag only takes effect when CUDA graph is enabled. That tensor is reused via special logic to reduce GPU memory usage.
"""
return g_cache_manager.alloc_tensor(shape, data_type, device=device, is_graph_out=is_graph_out)
return g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=is_graph_out)
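A brief usage sketch under the renamed keyword (shapes and dtypes are examples only; the stand-in body just calls torch.empty, whereas the real method delegates to g_cache_manager and can reuse cached buffers):

import torch
from typing import Iterable, Union

def alloc_tensor(
    shape: Union[torch.Size, Iterable[int]],
    dtype: torch.dtype,
    device: str = "cuda",
    is_graph_out: bool = False,
) -> torch.Tensor:
    # stand-in body; is_graph_out only matters when CUDA graph is enabled and
    # marks the final output tensor of the graph so it can be reused specially
    return torch.empty(tuple(shape), dtype=dtype, device=device)

q = alloc_tensor((4, 128), dtype=torch.float16, device="cpu")  # keyword is now dtype, not data_type
logits = alloc_tensor((4, 32000), dtype=torch.float32, device="cpu", is_graph_out=True)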
@@ -129,11 +129,14 @@ def cache_env_out(self):

def alloc_tensor(
self,
shape: Union[torch.Size, Iterable[int]],
shape: Union[torch.Size, Tuple[int, ...]],
data_type: torch.dtype,
device: str = "cuda",
is_graph_out: bool = False,
) -> torch.Tensor:
# normalize shape to torch.Size
if isinstance(shape, list):
shape = torch.Size(shape)
# when CUDA graph is enabled, the CUDA graph manager takes over allocation
if self.is_cuda_graph:
return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
4 changes: 2 additions & 2 deletions lightllm/models/bloom/layer_infer/pre_layer_infer.py
@@ -27,7 +27,7 @@ def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight:
input_ids = input_ids[0:total_token_num]

input_embdings = self.alloc_tensor(
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1:
@@ -37,7 +37,7 @@ def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight:

def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight):
input_embdings = self.alloc_tensor(
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1:
@@ -151,6 +151,7 @@ def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateI
self.qk_rope_head_dim,
self.qk_nope_head_dim,
self.softmax_scale,
alloc_tensor_func=self.alloc_tensor,
)

def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
24 changes: 10 additions & 14 deletions lightllm/models/deepseek2/triton_kernel/flash_decoding.py
@@ -13,6 +13,7 @@ def token_decode_attention_flash_decoding(
qk_nope_head_dim,
softmax_scale,
out=None,
alloc_tensor_func=torch.empty,
):
if kv_lora_rank > 128:
BLOCK_SEQ = 256 // (kv_lora_rank // 128)
@@ -24,20 +25,15 @@
from lightllm.models.deepseek2.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
from lightllm.models.deepseek2.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

o_tensor = torch.empty_like(q_nope) if out is None else out

if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank],
dtype=torch.float32,
device="cuda",
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out
mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank],
dtype=torch.float32,
device="cuda",
)
mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

flash_decode_stage1(
q_nope.view(calcu_shape1),
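A minimal sketch of the kind of allocator the new alloc_tensor_func hook can accept (TensorPool is hypothetical; in the PR the layer passes self.alloc_tensor, which routes through g_cache_manager). The hook is called both positionally, as alloc_tensor_func(shape, dtype, device), and with dtype=/device= keywords, so a replacement should support both:

import torch

class TensorPool:
    """Hypothetical pooled allocator matching the hook's calling convention."""

    def __init__(self):
        self._cache = {}

    def alloc(self, shape, dtype, device="cuda"):
        key = (tuple(shape), dtype, str(device))
        buf = self._cache.get(key)
        if buf is None:
            buf = torch.empty(tuple(shape), dtype=dtype, device=device)
            self._cache[key] = buf
        # reused on later decode steps; a real manager must also avoid handing
        # the same buffer to two tensors that are alive in the same forward pass
        return buf

pool = TensorPool()
# e.g. token_decode_attention_flash_decoding(..., alloc_tensor_func=pool.alloc)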
4 changes: 2 additions & 2 deletions lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py
@@ -25,7 +25,7 @@ def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight)

def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight):
input_embdings = self.alloc_tensor(
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1:
@@ -35,7 +35,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei

def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight):
input_embdings = self.alloc_tensor(
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1:
@@ -18,7 +18,7 @@ def _get_qkv(
self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight
) -> torch.Tensor:
input = input.view(-1, self.embed_dim_)
q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), data_type=input.dtype)
q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), dtype=input.dtype)
torch.addmm(layer_weight.q_bias_, input, layer_weight.q_weight_, beta=1.0, alpha=1.0, out=q)
torch.addmm(
layer_weight.kv_bias_,
12 changes: 6 additions & 6 deletions lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -30,7 +30,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
if infer_state.is_splitfuse:
# for SplitFuse
batch_size = infer_state.batch_size
last_input = self.alloc_tensor((batch_size, self.embed_dim_), data_type=input_embdings.dtype)
last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype)
tmp_ = torch.cat(
[
torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
@@ -54,13 +54,13 @@
select_token_num += 1

last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
last_input = self.alloc_tensor((select_token_num, self.embed_dim_), data_type=input_embdings.dtype)
last_input = self.alloc_tensor((select_token_num, self.embed_dim_), dtype=input_embdings.dtype)
last_input[:, :] = input_embdings[last_index, :]
return last_input, select_token_num

if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
batch_size = infer_state.batch_size
last_input = self.alloc_tensor((batch_size, self.embed_dim_), data_type=input_embdings.dtype)
last_input = self.alloc_tensor((batch_size, self.embed_dim_), dtype=input_embdings.dtype)
last_index = (
torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
)
@@ -84,15 +84,15 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
last_input = self._norm(last_input, infer_state, layer_weight)
last_input = last_input.permute(1, 0).view(-1, token_num)
logic_batch = self.alloc_tensor(
(layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), data_type=last_input.dtype
(layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype
)
torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch)

last_input = None
if self.world_size_ == 1:
gather_data = logic_batch
else:
gather_data = self.alloc_tensor((self.vocab_size_, token_num), data_type=input_embdings_dtype)
gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
dist.all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
@@ -101,7 +101,7 @@
async_op=False,
)
logic_batch = None
ans_logics = self.alloc_tensor((token_num, self.vocab_size_), data_type=torch.float32, is_graph_out=True)
ans_logics = self.alloc_tensor((token_num, self.vocab_size_), dtype=torch.float32, is_graph_out=True)
ans_logics[:, :] = gather_data.permute(1, 0)
gather_data = None
return ans_logics
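For reference, a small sketch of how the vocab dimension is partitioned for the all_gather above (toy sizes; the real code launches dist.all_gather into these slices across tensor-parallel ranks):

import numpy as np
import torch

vocab_size_, world_size_, token_num = 32000, 4, 3
split_indexes = np.linspace(0, vocab_size_, world_size_ + 1, dtype=np.int64)  # [0, 8000, 16000, 24000, 32000]
gather_data = torch.empty((vocab_size_, token_num), dtype=torch.float32)
# each rank's logits land in its contiguous slice of the vocab dimension
shards = [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(world_size_)]
assert sum(s.shape[0] for s in shards) == vocab_size_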
4 changes: 2 additions & 2 deletions lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -21,7 +21,7 @@ def __init__(self, tp_rank, world_size, network_config, mode):

def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
input_embdings = self.alloc_tensor(
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1:
@@ -30,7 +30,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei

def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
input_embdings = self.alloc_tensor(
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), data_type=layer_weight.data_type_
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1:
33 changes: 29 additions & 4 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -122,7 +122,7 @@ def _ffn_norm(
def _get_qkv(
self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), data_type=input.dtype)
q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), dtype=input.dtype)
torch.mm(input, layer_weight.q_weight_, out=q)
torch.mm(
input,
@@ -503,7 +503,14 @@ def _token_decode_attention_ppl_fp16_flashdecoding(
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
q,
infer_state,
self.tp_q_head_num_,
self.head_dim_,
cache_k,
cache_v,
out=out,
alloc_tensor_func=self.alloc_tensor,
)

def _token_decode_attention_ppl_int8kv_flashdecoding(
@@ -520,7 +527,16 @@ def _token_decode_attention_ppl_int8kv_flashdecoding(
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
q,
infer_state,
self.tp_q_head_num_,
self.head_dim_,
cache_k,
cache_k_scale,
cache_v,
cache_v_scale,
out=out,
alloc_tensor_func=self.alloc_tensor,
)

def _token_decode_attention_ppl_int4kv_flashdecoding(
@@ -537,5 +553,14 @@ def _token_decode_attention_ppl_int4kv_flashdecoding(
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
q,
infer_state,
self.tp_q_head_num_,
self.head_dim_,
cache_k,
cache_k_scale,
cache_v,
cache_v_scale,
out=out,
alloc_tensor_func=self.alloc_tensor,
)
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/embedding.py
@@ -64,7 +64,7 @@ def embedding(input_ids, weight: torch.Tensor, vob_start_id, vob_end_id, out: to

@torch.no_grad()
def embedding_new(input_ids, weight, vob_start_id, vob_end_id):
# out = self.alloc_tensor((N_CTX, DIM), data_type=torch.float32)
# out = self.alloc_tensor((N_CTX, DIM), dtype=torch.float32)
out = torch.empty((N_CTX, DIM), device="cuda", requires_grad=False)

embedding(input_ids, weight, vob_start_id, vob_end_id, out)
16 changes: 6 additions & 10 deletions lightllm/models/llama/triton_kernel/flash_decoding.py
@@ -14,16 +14,12 @@ def token_decode_attention_flash_decoding(

o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out

if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
)
infer_state.mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
)
mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

flash_decode_stage1(
q.view(calcu_shape1),
16 changes: 6 additions & 10 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
@@ -14,16 +14,12 @@ def gqa_token_decode_attention_flash_decoding(

o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out

if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
)
infer_state.mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
)
mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

flash_decode_stage1(
q.view(calcu_shape1),
62 changes: 27 additions & 35 deletions lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py
@@ -1,6 +1,9 @@
import torch

def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None):

def token_decode_attention_flash_decoding(
q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty
):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
@@ -9,39 +12,28 @@ def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim,
from lightllm_ppl_fp16_flashdecoding_kernel import fp16_flashdecoding_stage1
from .flash_decoding_stage2 import flash_decode_stage2

o_tensor = torch.empty_like(q) if out is None else out

if getattr(infer_state, 'mid_o', None) is None:
infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
dtype=torch.float16,
device="cuda")
infer_state.mid_o_logexpsum = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1],
dtype=torch.float16,
device="cuda")

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out

mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
)
mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
)

fp16_flashdecoding_stage1(
BLOCK_SEQ,
mid_o,
mid_o_logexpsum,
1.0 / (head_dim ** 0.5),
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

fp16_flashdecoding_stage1(BLOCK_SEQ,
mid_o,
mid_o_logexpsum,
1.0 / (head_dim**0.5),
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch)

flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.b_seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)
flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
return o_tensor