-
Notifications
You must be signed in to change notification settings - Fork 177
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ppl int8kv flashdecoding mode (#363)
- Loading branch information
1 parent
6d67fbb
commit c96e5d5
Showing
8 changed files
with
266 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
|
||
|
||
def token_decode_attention_flash_decoding(
    q, infer_state, q_head_num, head_dim, cache_k, cache_k_scale, cache_v, cache_v_scale, out=None
):
    """Decode-phase attention via two-stage int8-KV flash decoding.

    Stage 1 (PPL CUDA kernel) computes partial attention outputs and
    log-exp-sum terms per 256-token chunk of the KV cache; stage 2 reduces
    those partials into the final output.

    Args:
        q: query tensor, viewed as (batch, q_head_num, head_dim) before the kernel call.
        infer_state: per-batch inference state; supplies batch_size, max_len_in_batch,
            request/token index tables, and caches the stage-1 scratch buffers
            (mid_o, mid_o_logexpsum) across calls.
        q_head_num: number of query heads.
        head_dim: per-head dimension (also fixes the 1/sqrt(head_dim) softmax scale).
        cache_k, cache_k_scale, cache_v, cache_v_scale: int8-quantized KV cache
            tensors and their dequantization scales.
        out: optional preallocated output tensor; a fresh tensor shaped like q
            is allocated when None.

    Returns:
        The attention output tensor (``out`` if given, else newly allocated).
    """
    from lightllm_ppl_int8kv_flashdecoding_kernel import group8_int8kv_flashdecoding_stage1
    from .flash_decoding_stage2 import flash_decode_stage2

    BLOCK_SEQ = 256
    batch = infer_state.batch_size
    # One chunk slot per BLOCK_SEQ tokens, plus one for the partial tail chunk.
    num_seq_blocks = infer_state.max_len_in_batch // BLOCK_SEQ + 1
    qkv_shape = (batch, q_head_num, head_dim)

    if out is None:
        o_tensor = torch.empty_like(q)
    else:
        o_tensor = out

    # Lazily allocate stage-1 scratch buffers and cache them on infer_state
    # so repeated decode steps for the same batch reuse the same memory.
    if getattr(infer_state, "mid_o", None) is None:
        infer_state.mid_o = torch.empty(
            [batch, q_head_num, num_seq_blocks, head_dim],
            dtype=torch.float16,
            device="cuda",
        )
        infer_state.mid_o_logexpsum = torch.empty(
            [batch, q_head_num, num_seq_blocks],
            dtype=torch.float16,
            device="cuda",
        )

    # Stage 1: per-chunk partial attention over the int8 KV cache.
    group8_int8kv_flashdecoding_stage1(
        BLOCK_SEQ,
        infer_state.mid_o,
        infer_state.mid_o_logexpsum,
        1.0 / (head_dim ** 0.5),  # softmax temperature scale
        q.view(qkv_shape),
        cache_k,
        cache_k_scale,
        cache_v,
        cache_v_scale,
        infer_state.req_manager.req_to_token_indexs,
        infer_state.b_req_idx,
        infer_state.b_seq_len,
        infer_state.max_len_in_batch,
    )

    # Stage 2: log-sum-exp reduction of the chunk partials into o_tensor.
    flash_decode_stage2(
        infer_state.mid_o,
        infer_state.mid_o_logexpsum,
        infer_state.b_seq_len,
        o_tensor.view(qkv_shape),
        BLOCK_SEQ,
    )
    return o_tensor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters