Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lightllm/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch
try:
    from numpy import from_dlpack as _np_from_dlpack
except ImportError:
    # This NumPy build does not export ``from_dlpack``; use Tensor.numpy().
    _np_from_dlpack = None


def np_from_tensor(tensor: torch.Tensor):
    """Return a NumPy array holding the contents of ``tensor``.

    The tensor is moved to host memory first, so both conversion paths
    accept tensors that live on an accelerator. ``Tensor.cpu()`` is
    documented to return the tensor itself when it is already on the CPU,
    so no extra copy is made in that case.

    Args:
        tensor: The tensor to convert.

    Returns:
        A ``numpy.ndarray`` produced by ``numpy.from_dlpack`` when NumPy
        provides it, otherwise by ``Tensor.numpy()``.
    """
    host_tensor = tensor.cpu()
    if _np_from_dlpack is None:
        return host_tensor.numpy()
    return _np_from_dlpack(host_tensor)
20 changes: 16 additions & 4 deletions lightllm/common/infer_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
def init_bloc(b_loc, b_seq_len, max_len_in_batch, alloc_mem_index):
import numpy as np
import torch
from typing import Union, List
from . import np_from_tensor


def init_bloc(b_loc: torch.Tensor,
              b_seq_len: Union[np.ndarray, List[int], torch.Tensor],
              max_len_in_batch: int,
              alloc_mem_index: torch.Tensor,
              ):
    """Write consecutive chunks of ``alloc_mem_index`` into ``b_loc`` rows.

    For row ``i`` with length ``b_seq_len[i]``, the next ``b_seq_len[i]``
    entries of ``alloc_mem_index`` are written, right-aligned, into the
    columns ``[max_len_in_batch - b_seq_len[i], max_len_in_batch)`` of
    ``b_loc[i]``. ``b_loc`` is modified in place.

    Args:
        b_loc: 2-D tensor that receives the indices.
        b_seq_len: Per-row lengths. A tensor is converted to a host-side
            NumPy array once; any other iterable of integers is used as is.
        max_len_in_batch: Exclusive end column of the written region.
        alloc_mem_index: 1-D tensor of indices consumed front to back.

    Returns:
        None.
    """
    if isinstance(b_seq_len, torch.Tensor):
        # Move to host before converting: the DLPack-based conversion cannot
        # import tensors that live on an accelerator.
        seq_lens = np_from_tensor(b_seq_len.cpu())
    else:
        seq_lens = b_seq_len

    start_index = 0
    for i, cur_seq_len in enumerate(seq_lens):
        # Plain Python ints keep the running offset from being accumulated
        # in a fixed-width NumPy integer type.
        cur_seq_len = int(cur_seq_len)
        end_index = start_index + cur_seq_len
        b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:end_index]
        start_index = end_index
    return
94 changes: 64 additions & 30 deletions lightllm/models/llama/layer_infer/model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
import os
import json
import torch
from lightllm.models.llama.layer_infer.pre_layer_inference import PreLayerInfer
from lightllm.models.llama.layer_infer.post_layer_inference import PostLayerInfer
from lightllm.models.llama.layer_infer.transformer_layer_inference import TransformerLayerInfer
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import *
from lightllm.models.llama.layer_weights.transformer_layer_weight import *
from lightllm.models.llama.layer_infer.infer_struct import InferStateInfo
from .pre_layer_inference import PreLayerInfer
from .post_layer_inference import PostLayerInfer
from .transformer_layer_inference import TransformerLayerInfer
from ..layer_weights.pre_and_post_layer_weight import *
from ..layer_weights.transformer_layer_weight import *
from .infer_struct import InferStateInfo
from lightllm.models.llama.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.infer_utils import init_bloc
from lightllm.common import np_from_tensor
from typing import List, Dict

class LlamaTpPartModel:
tp_rank_: int
world_size_: int
weight_dir_: str
config: Dict
mem_manager: MemoryManager
pre_post_weight: PreAndPostLayerWeight
trans_layers_weight: List[TransformerLayerWeight]
pre_infer: PreLayerInfer
post_infer: PostLayerInfer
layers_infer: List[TransformerLayerInfer]
head_num_: int
head_dim_: int
tp_head_num_: int
vocab_size: int
_cos_cached: torch.Tensor
_sin_cached: torch.Tensor

def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=""):
self.tp_rank_ = tp_rank
self.world_size_ = world_size
Expand Down Expand Up @@ -44,8 +63,9 @@ def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_wa
tp_rank=self.tp_rank_,
world_size=self.world_size_,
network_config=self.config,
mode=mode) for i in range(
self.config["num_hidden_layers"])]
mode=mode
) for i in range(self.config["num_hidden_layers"])
]

self.head_num_ = self.config["num_attention_heads"]
self.head_dim_ = self.config["hidden_size"] // self.head_num_
Expand All @@ -61,44 +81,54 @@ def init_to_get_rotary(self, base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
cos_sin_cached = torch.empty([2] + list(freqs.shape), dtype=torch.float16, device='cuda')
cos_sin_cached[0][:] = torch.cos(freqs).to(torch.float16)
cos_sin_cached[1][:] = torch.sin(freqs).to(torch.float16)
self._cos_sin_cached = cos_sin_cached
self._cos_cached = self._cos_sin_cached[0]
self._sin_cached = self._cos_sin_cached[1]
return

@torch.no_grad()
def forward(
self,
batch_size,
total_token_num,
max_len_in_batch,
input_ids,
b_loc,
b_start_loc,
b_seq_len,
batch_size: int,
total_token_num: int,
max_len_in_batch: int,
input_ids: torch.Tensor,
b_loc: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_len: torch.Tensor,
is_prefill=True):
if is_prefill:
assert (input_ids.shape[0] == total_token_num)
assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])
infer_state = InferStateInfo()
infer_state.is_prefill = is_prefill
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
assert (input_ids.shape[0] == total_token_num)
assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])

b_seq_len_numpy = b_seq_len.cpu().numpy()
position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
position_ids = None
infer_state.b_loc = b_loc
infer_state.b_start_loc = b_start_loc
infer_state.b_seq_len = b_seq_len
infer_state.mem_manager = self.mem_manager
infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num)

b_seq_len_cpu = b_seq_len.cpu()

_arange = np.arange(0, max_len_in_batch)
b_seq_len_numpy = np_from_tensor(b_seq_len_cpu)
position_ids = torch.from_numpy(np.concatenate([_arange[:x] for x in b_seq_len_numpy], axis=0)).cuda()
del _arange

position_cos_sin = torch.index_select(self._cos_sin_cached, 1, position_ids)
infer_state.position_cos = position_cos_sin[0].view(position_ids.shape[0], -1)
infer_state.position_sin = position_cos_sin[1].view(position_ids.shape[0], -1)
del position_ids

infer_state.prefill_key_buffer = torch.empty((infer_state.total_token_num, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index)
init_bloc(b_loc, b_seq_len_numpy, max_len_in_batch, infer_state.prefill_mem_index)

predict_logics = self._context_forward(input_ids, infer_state)
return predict_logics
Expand All @@ -108,9 +138,13 @@ def forward(
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])
infer_state.position_cos = torch.index_select(self._cos_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1)
assert (batch_size == b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])

_sel = b_seq_len - 1
position_cos_sin = torch.index_select(self._cos_sin_cached, 1, _sel)
infer_state.position_cos = position_cos_sin[0].view(batch_size, -1)
infer_state.position_sin = position_cos_sin[1].view(batch_size, -1)
del _sel
infer_state.b_loc = b_loc
infer_state.b_start_loc = b_start_loc
infer_state.b_seq_len = b_seq_len
Expand Down
41 changes: 23 additions & 18 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device
nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)


nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device=device)
torch.cumsum(nopad_b_seq_len[:-1], dim=0, dtype=torch.int32, out=nopad_b_start_loc[1:])
if len(requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
else:
input_ids = all_input_ids[0]

# Create tensors on device
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_ids = torch.from_numpy(input_ids).to(device=device)

return cls(
batch_id=batch_id,
Expand Down Expand Up @@ -140,9 +140,9 @@ def filter(self, request_ids: List[int]):

nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_seq_len_numpy = np.array(self.input_lengths, np.int32)
nopad_b_loc = torch.empty((len(request_ids), setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda')
nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda')
nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda')

left_idx = []
for i, request_id in enumerate(request_ids):
Expand All @@ -163,10 +163,13 @@ def filter(self, request_ids: List[int]):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)

nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
nopad_total_token_num = torch.sum(nopad_b_seq_len).item()
nopad_b_seq_len_numpy[:len(indices)] = nopad_b_seq_len_numpy[indices]
nopad_b_seq_len_numpy = nopad_b_seq_len_numpy[:len(indices)]
self.nopad_b_seq_len = self.nopad_b_seq_len[:len(indices)]
self.nopad_b_seq_len.copy_(torch.from_numpy(nopad_b_seq_len_numpy))
torch.cumsum(self.nopad_b_seq_len[:len(indices) - 1], dim=0, dtype=torch.int32, out=nopad_b_start_loc[1:])
nopad_max_len_in_batch = nopad_b_seq_len_numpy.max()
nopad_total_token_num = nopad_b_seq_len_numpy.sum()

nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[indices, (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1): (self.nopad_max_len_in_batch - 1)]
for i, request_id in enumerate(request_ids):
Expand Down Expand Up @@ -197,12 +200,11 @@ def filter(self, request_ids: List[int]):

@classmethod
@torch.no_grad()
def merge(cls, batch1, batch2):
def merge(cls, batch1: 'InferBatch', batch2: 'InferBatch'):
requests = batch1.requests + batch2.requests
requests_idx_mapping = {}
new_batch_size = len(batch1) + len(batch2)

input_ids = batch1.input_ids.new_empty(new_batch_size)
all_input_ids = []
input_lengths = []
out_token_id_counts=[]
Expand All @@ -213,8 +215,11 @@ def merge(cls, batch1, batch2):
nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2 .nopad_max_len_in_batch)

nopad_b_loc = torch.empty((new_batch_size, setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda')
nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda')
nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda')
nopad_b_start_loc = torch.empty(new_batch_size, dtype=torch.int32, device='cuda')

nopad_b_seq_len_cuda = torch.concat((batch1.nopad_b_seq_len, batch2.nopad_b_seq_len))
input_ids = torch.concat((batch1.input_ids, batch2.input_ids))

nopad_start_loc_len_temp = 0
batches = [batch1, batch2]
for i, batch in enumerate(batches):
Expand All @@ -225,10 +230,8 @@ def merge(cls, batch1, batch2):
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
input_ids[start_index:end_index] = batch.input_ids
nopad_b_seq_len[start_index: end_index] = batch.nopad_b_seq_len
nopad_b_start_loc[start_index: end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len_cuda[end_index - 1]
nopad_b_loc[start_index: end_index, nopad_max_len_in_batch - batch.nopad_max_len_in_batch: nopad_max_len_in_batch -
1] = batch.nopad_b_loc[:, :batch.nopad_max_len_in_batch - 1]

Expand All @@ -240,8 +243,10 @@ def merge(cls, batch1, batch2):
# Update
cumulative_batch_size += len(batch)

nopad_b_loc[:, nopad_max_len_in_batch - 1] = nopad_total_token_num - \
new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device='cuda')
offset = nopad_total_token_num - new_batch_size
torch.arange(offset, offset + new_batch_size, dtype=torch.int32, device='cuda',
out=nopad_b_loc[:, nopad_max_len_in_batch - 1])
batches[0].nopad_b_seq_len.set_shape(len(batch1) + len(batch2))
return InferBatch(
batch_id=batches[0].batch_id,
requests=requests,
Expand All @@ -253,7 +258,7 @@ def merge(cls, batch1, batch2):
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
nopad_b_seq_len=nopad_b_seq_len_cuda,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
mem_manager=batches[0].mem_manager
Expand Down
3 changes: 2 additions & 1 deletion lightllm/server/router/model_infer/model_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightllm.models.bloom.layer_infer.model import BloomTpPartModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.common import np_from_tensor
from .post_process import sample

class ModelRpcServer(rpyc.Service):
Expand Down Expand Up @@ -119,7 +120,7 @@ def forward(self, batch_id, is_prefill):
next_token_ids = sample(logits, batch)
output_dict = {}
new_input_ids = []
next_token_ids = next_token_ids.detach().cpu().numpy()
next_token_ids = np_from_tensor(next_token_ids.detach().cpu())
for i, (r, all_input_ids, next_token_id) in enumerate(zip(batch.requests, batch.all_input_ids, next_token_ids)):
# all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda")
all_input_ids.append(int(next_token_id))
Expand Down