add ppl int8kv flashdecoding mode (#363)
hiworldwzj committed Mar 19, 2024
1 parent 6d67fbb commit c96e5d5
Showing 8 changed files with 266 additions and 18 deletions.
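In short, the commit threads two settings through the stack: a new "ppl_int8kv_flashdecoding" mode string that routes KV-cache storage and decode attention to the PPL int8 flash-decoding kernels, and a "use_dynamic_prompt_cache" flag that is now copied onto every infer-state object. A minimal sketch of the two knobs as the model init code below reads them (values and the surrounding setup are illustrative assumptions, not part of this commit):

# Hypothetical sketch: the two settings this commit threads through, read the same
# way basemodel.py reads its kvargs below. Values are illustrative defaults only.
kvargs = {
    "mode": ["ppl_int8kv_flashdecoding"],   # selects the new int8-KV flash-decoding kernels
    "use_dynamic_prompt_cache": False,      # copied onto every InferStateInfo
}
mode = kvargs.get("mode", [])
use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
print(mode, use_dynamic_prompt_cache)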
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/basemodel.py
@@ -44,6 +44,7 @@ def __init__(self, kvargs):
self.max_req_num = kvargs.get("max_req_num", 1000)
self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)

self._init_config()
self._verify_must()
@@ -192,6 +193,7 @@ def _prefill(
infer_state = self.infer_state_class()
infer_state.is_prefill = True
infer_state.return_all_prompt_logprobs = self.return_all_prompt_logprobs
infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
@@ -251,6 +253,7 @@ def _decode(
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
assert b_req_idx.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]
infer_state.b_req_idx = b_req_idx
infer_state.b_start_loc = b_start_loc
@@ -301,6 +304,7 @@ def splitfuse_forward(
):

infer_state = self.splitfuse_infer_state_class()
infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
infer_state.batch_size = decode_req_num + prefill_req_num

infer_state.decode_req_num = decode_req_num
1 change: 1 addition & 0 deletions lightllm/common/basemodel/infer_struct.py
@@ -32,6 +32,7 @@ def __init__(self):

self.is_splitfuse = False
self.return_all_prompt_logprobs = False
self.use_dynamic_prompt_cache = False
self.multimodal_params = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/splitfuse_infer_struct.py
@@ -12,6 +12,8 @@ class SplitFuseInferStateInfo:
inner_decode_infer_state_class = InferStateInfo

def __init__(self):
self.use_dynamic_prompt_cache = False

self.batch_size = None

self.decode_req_num = None
7 changes: 4 additions & 3 deletions lightllm/common/mem_utils.py
@@ -5,15 +5,16 @@

logger = init_logger(__name__)


def select_mem_manager_class(mode):
logger.info(f"mode setting params: {mode}")
if "ppl_int8kv" in mode:
if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode:
memory_manager_class = PPLINT8KVMemoryManager
logger.info("Model kv cache using mode ppl int8kv")
logger.info(f"Model kv cache using mode {mode}")
elif "triton_int8kv" in mode:
memory_manager_class = INT8KVMemoryManager
logger.info("Model kv cache using mode triton int8kv")
else:
memory_manager_class = MemoryManager
logger.info("Model kv cache using mode normal")
    return memory_manager_class
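A quick sanity sketch of the selection above, assuming mode is the list of mode strings passed to the server (which the membership tests suggest); with a list, "ppl_int8kv" in mode does not match the new string, hence the explicit second check:

# Sketch only: expected routing of select_mem_manager_class, assuming mode is a list of strings.
from lightllm.common.mem_utils import select_mem_manager_class

print(select_mem_manager_class(["ppl_int8kv"]).__name__)                # PPLINT8KVMemoryManager
print(select_mem_manager_class(["ppl_int8kv_flashdecoding"]).__name__)  # PPLINT8KVMemoryManager (new branch)
print(select_mem_manager_class(["triton_int8kv"]).__name__)             # INT8KVMemoryManager
print(select_mem_manager_class([]).__name__)                            # MemoryManager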
67 changes: 52 additions & 15 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -7,7 +7,10 @@
import triton

from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
+from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
+    context_attention_fwd,
+    context_attention_fwd_no_prompt_cache,
+)
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
@@ -55,6 +58,11 @@ def _bind_attention(self):
if "ppl_int8kv" in self.mode:
self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self)
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
elif "ppl_int8kv_flashdecoding" in self.mode:
self._token_attention_kernel = partial(
LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self
)
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
elif "ppl_fp16" in self.mode:
self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self)
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
@@ -124,21 +132,33 @@ def _context_attention_kernel(
) -> torch.Tensor:
o_tensor = torch.empty_like(q) if out is None else out
import triton

if triton.__version__ >= "2.1.0":
-            context_attention_fwd(
-                q.view(-1, self.tp_q_head_num_, self.head_dim_),
-                infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
-                infer_state.mem_manager.kv_buffer[self.layer_num_][
-                    :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-                ],
-                o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
-                infer_state.b_req_idx,
-                infer_state.b_start_loc,
-                infer_state.b_seq_len,
-                infer_state.b_ready_cache_len,
-                infer_state.max_len_in_batch,
-                infer_state.req_manager.req_to_token_indexs,
-            )
+            if infer_state.use_dynamic_prompt_cache:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                context_attention_fwd(
+                    q.view(-1, self.tp_q_head_num_, self.head_dim_),
+                    kv[:, 0 : self.tp_k_head_num_, :],
+                    kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+                    o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+                    infer_state.b_req_idx,
+                    infer_state.b_start_loc,
+                    infer_state.b_seq_len,
+                    infer_state.b_ready_cache_len,
+                    infer_state.max_len_in_batch,
+                    infer_state.req_manager.req_to_token_indexs,
+                )
+            else:
+                context_attention_fwd_no_prompt_cache(
+                    q.view(-1, self.tp_q_head_num_, self.head_dim_),
+                    kv[:, 0 : self.tp_k_head_num_, :],
+                    kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+                    o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+                    infer_state.b_start_loc,
+                    infer_state.b_seq_len,
+                    infer_state.max_len_in_batch,
+                )

elif triton.__version__ == "2.0.0":
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
@@ -487,3 +507,20 @@ def _token_decode_attention_ppl_fp16_flashdecoding(
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
)

def _token_decode_attention_ppl_int8kv_flashdecoding(
self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
):
from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import token_decode_attention_flash_decoding

cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
)
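For orientation, a shape sketch of the buffers being sliced above: per layer, the int8 KV pool holds the K heads and V heads concatenated along dim 1, with a parallel scale buffer for dequantization. The group-of-8 scale width is an assumption based on the group8_int8kv kernel name, not something this diff shows.

# Shape sketch (assumptions noted above); mirrors the cache_k / cache_v slicing in
# _token_decode_attention_ppl_int8kv_flashdecoding.
import torch

total_tokens, k_heads, v_heads, head_dim = 16, 8, 8, 128
kv_buffer = torch.empty(total_tokens, k_heads + v_heads, head_dim, dtype=torch.int8)
scale_buffer = torch.empty(total_tokens, k_heads + v_heads, head_dim // 8, dtype=torch.float16)

cache_k = kv_buffer[:, 0:k_heads, :]                        # quantized keys
cache_k_scale = scale_buffer[:, 0:k_heads, :]               # key dequant scales
cache_v = kv_buffer[:, k_heads : k_heads + v_heads, :]      # quantized values
cache_v_scale = scale_buffer[:, k_heads : k_heads + v_heads, :]
print(cache_k.shape, cache_k_scale.shape)                   # (16, 8, 128) (16, 8, 16)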
158 changes: 158 additions & 0 deletions lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@@ -178,6 +178,164 @@ def context_attention_fwd(
)
return

@triton.jit
def _fwd_kernel_no_prompt_cache(
Q,
K,
V,
sm_scale,
B_Start_Loc,
    B_Seqlen,  # B_Start_Loc records each request's real start offset in the packed batch; B_Seqlen records each request's real length
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
kv_group_num,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)

cur_kv_head = cur_head // kv_group_num

cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

block_start_loc = BLOCK_M * start_m

# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd

q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

k_ptrs = K + off_k
v_ptrs = V + off_v

# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)

p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return

@torch.no_grad()
def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128 if not TESLA else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128, 256}

    sm_scale = 1.0 / (Lq ** 0.5)  # softmax scale factor
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, seq blocks)

num_warps = 4 if Lk <= 64 else 8
_fwd_kernel_no_prompt_cache[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return

elif triton.__version__ == "2.0.0":

@triton.jit
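A smoke-test sketch for the new no-prompt-cache entry point (needs a CUDA device and triton >= 2.1.0). The layout follows the wrapper above: q/k/v/o are packed as [total_tokens, heads, head_dim], with per-request start offsets in b_start_loc and lengths in b_seq_len.

# Sketch only: exercising context_attention_fwd_no_prompt_cache on a toy packed batch.
import torch
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
    context_attention_fwd_no_prompt_cache,
)

seq_lens = [5, 7]
heads, head_dim = 8, 128
total = sum(seq_lens)
q = torch.randn(total, heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(total, heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(total, heads, head_dim, dtype=torch.float16, device="cuda")
o = torch.empty_like(q)
b_start_loc = torch.tensor([0, seq_lens[0]], dtype=torch.int32, device="cuda")
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")

context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max(seq_lens))
print(o.shape)  # torch.Size([12, 8, 128])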
44 changes: 44 additions & 0 deletions lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py
@@ -0,0 +1,44 @@
import torch


def token_decode_attention_flash_decoding(
q, infer_state, q_head_num, head_dim, cache_k, cache_k_scale, cache_v, cache_v_scale, out=None
):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
calcu_shape1 = (batch_size, q_head_num, head_dim)

from lightllm_ppl_int8kv_flashdecoding_kernel import group8_int8kv_flashdecoding_stage1
from .flash_decoding_stage2 import flash_decode_stage2

o_tensor = torch.empty_like(q) if out is None else out

if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
group8_int8kv_flashdecoding_stage1(
BLOCK_SEQ,
mid_o,
mid_o_logexpsum,
1.0 / (head_dim ** 0.5),
q.view(calcu_shape1),
cache_k,
cache_k_scale,
cache_v,
cache_v_scale,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
return o_tensor
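For reference, what the stage-2 reduction conceptually computes from mid_o and mid_o_logexpsum, assuming, as in the other flash-decoding paths in this repo, that stage 1 writes per-block softmax-normalized partial outputs plus their log-sum-exp; the actual work is done by the fused group8_int8kv_flashdecoding_stage1 and flash_decode_stage2 kernels imported above.

# Conceptual sketch only (not the Triton implementation): merging per-block partial
# outputs with a numerically stable log-sum-exp weighting. Assumes, for simplicity,
# that every request in the batch has the same seq_len.
import torch

def reference_stage2(mid_o, mid_o_logexpsum, seq_len, block_seq=256):
    # mid_o:           [batch, q_head_num, num_blocks, head_dim]
    # mid_o_logexpsum: [batch, q_head_num, num_blocks]
    used = (seq_len + block_seq - 1) // block_seq          # blocks actually written by stage 1
    o_part = mid_o[:, :, :used].float()
    lse = mid_o_logexpsum[:, :, :used].float()
    weight = torch.softmax(lse, dim=-1)                    # exp(lse_b - max) / sum over blocks
    return torch.einsum("bhn,bhnd->bhd", weight, o_part)   # weighted merge of block outputs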
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -96,6 +96,7 @@ def exposed_init_model(self, kvargs):
"max_req_num": kvargs.get("max_req_num", 1000),
"max_seq_length": kvargs.get("max_seq_length", 1024 * 5),
"return_all_prompt_logprobs": self.return_all_prompt_logprobs,
"use_dynamic_prompt_cache": self.use_dynamic_prompt_cache,
}

try:
