diff --git a/lightllm/common/__init__.py b/lightllm/common/__init__.py index e69de29bb2..01475492a6 100644 --- a/lightllm/common/__init__.py +++ b/lightllm/common/__init__.py @@ -0,0 +1,7 @@ +import torch +try: + from numpy import from_dlpack as _np_from_dlpack + np_from_tensor = _np_from_dlpack +except: + def np_from_tensor(tensor: torch.Tensor): + return tensor.cpu().numpy() diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index f712d3c80c..2dfc1c31a7 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,8 +1,20 @@ -def init_bloc(b_loc, b_seq_len, max_len_in_batch, alloc_mem_index): +import numpy as np +import torch +from typing import Union, List +from . import np_from_tensor + + +def init_bloc(b_loc: torch.Tensor, + b_seq_len: Union[np.ndarray, List[int], torch.Tensor], + max_len_in_batch: int, + alloc_mem_index: torch.Tensor, + ): start_index = 0 - b_seq_len_numpy = b_seq_len.cpu().numpy() - for i in range(len(b_seq_len)): - cur_seq_len = b_seq_len_numpy[i] + if isinstance(b_seq_len, (np.ndarray, list)): + b_seq_len_lst = b_seq_len + elif isinstance(b_seq_len, torch.Tensor): + b_seq_len_lst = np_from_tensor(b_seq_len) + for i, cur_seq_len in enumerate(b_seq_len_lst): b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + cur_seq_len] start_index += cur_seq_len return \ No newline at end of file diff --git a/lightllm/models/llama/layer_infer/model.py b/lightllm/models/llama/layer_infer/model.py index 593ed3a23c..667dfe76ed 100644 --- a/lightllm/models/llama/layer_infer/model.py +++ b/lightllm/models/llama/layer_infer/model.py @@ -1,17 +1,36 @@ import os import json import torch -from lightllm.models.llama.layer_infer.pre_layer_inference import PreLayerInfer -from lightllm.models.llama.layer_infer.post_layer_inference import PostLayerInfer -from lightllm.models.llama.layer_infer.transformer_layer_inference import TransformerLayerInfer -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import * -from lightllm.models.llama.layer_weights.transformer_layer_weight import * -from lightllm.models.llama.layer_infer.infer_struct import InferStateInfo +from .pre_layer_inference import PreLayerInfer +from .post_layer_inference import PostLayerInfer +from .transformer_layer_inference import TransformerLayerInfer +from ..layer_weights.pre_and_post_layer_weight import * +from ..layer_weights.transformer_layer_weight import * +from .infer_struct import InferStateInfo from lightllm.models.llama.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.mem_manager import MemoryManager from lightllm.common.infer_utils import init_bloc +from lightllm.common import np_from_tensor +from typing import List, Dict class LlamaTpPartModel: + tp_rank_: int + world_size_: int + weight_dir_: str + config: Dict + mem_manager: MemoryManager + pre_post_weight: PreAndPostLayerWeight + trans_layers_weight: List[TransformerLayerWeight] + pre_infer: PreLayerInfer + post_infer: PostLayerInfer + layers_infer: List[TransformerLayerInfer] + head_num_: int + head_dim_: int + tp_head_num_: int + vocab_size: int + _cos_cached: torch.Tensor + _sin_cached: torch.Tensor + def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=""): self.tp_rank_ = tp_rank self.world_size_ = world_size @@ -44,8 +63,9 @@ def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_wa tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, - mode=mode) for i in range( - self.config["num_hidden_layers"])] + mode=mode + ) for i in range(self.config["num_hidden_layers"]) + ] self.head_num_ = self.config["num_attention_heads"] self.head_dim_ = self.config["hidden_size"] // self.head_num_ @@ -61,44 +81,54 @@ def init_to_get_rotary(self, base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + cos_sin_cached = torch.empty([2] + list(freqs.shape), dtype=torch.float16, device='cuda') + cos_sin_cached[0][:] = torch.cos(freqs).to(torch.float16) + cos_sin_cached[1][:] = torch.sin(freqs).to(torch.float16) + self._cos_sin_cached = cos_sin_cached + self._cos_cached = self._cos_sin_cached[0] + self._sin_cached = self._cos_sin_cached[1] return @torch.no_grad() def forward( self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - b_loc, - b_start_loc, - b_seq_len, + batch_size: int, + total_token_num: int, + max_len_in_batch: int, + input_ids: torch.Tensor, + b_loc: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, is_prefill=True): if is_prefill: + assert (input_ids.shape[0] == total_token_num) + assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]) infer_state = InferStateInfo() infer_state.is_prefill = is_prefill infer_state.batch_size = batch_size infer_state.total_token_num = total_token_num infer_state.max_len_in_batch = max_len_in_batch - assert (input_ids.shape[0] == total_token_num) - assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]) - - b_seq_len_numpy = b_seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) - position_ids = None infer_state.b_loc = b_loc infer_state.b_start_loc = b_start_loc infer_state.b_seq_len = b_seq_len infer_state.mem_manager = self.mem_manager infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num) + + b_seq_len_cpu = b_seq_len.cpu() + + _arange = np.arange(0, max_len_in_batch) + b_seq_len_numpy = np_from_tensor(b_seq_len_cpu) + position_ids = torch.from_numpy(np.concatenate([_arange[:x] for x in b_seq_len_numpy], axis=0)).cuda() + del _arange + + position_cos_sin = torch.index_select(self._cos_sin_cached, 1, position_ids) + infer_state.position_cos = position_cos_sin[0].view(position_ids.shape[0], -1) + infer_state.position_sin = position_cos_sin[1].view(position_ids.shape[0], -1) + del position_ids + infer_state.prefill_key_buffer = torch.empty((infer_state.total_token_num, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index) + init_bloc(b_loc, b_seq_len_numpy, max_len_in_batch, infer_state.prefill_mem_index) predict_logics = self._context_forward(input_ids, infer_state) return predict_logics @@ -108,9 +138,13 @@ def forward( infer_state.batch_size = batch_size infer_state.total_token_num = total_token_num infer_state.max_len_in_batch = max_len_in_batch - assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]) - infer_state.position_cos = torch.index_select(self._cos_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, b_seq_len - 1).view(b_seq_len.shape[0], -1) + assert (batch_size == b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]) + + _sel = b_seq_len - 1 + position_cos_sin = torch.index_select(self._cos_sin_cached, 1, _sel) + infer_state.position_cos = position_cos_sin[0].view(batch_size, -1) + infer_state.position_sin = position_cos_sin[1].view(batch_size, -1) + del _sel infer_state.b_loc = b_loc infer_state.b_start_loc = b_start_loc infer_state.b_seq_len = b_seq_len diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 011c4a4221..592379ac7a 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -90,15 +90,15 @@ def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) - nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") - nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device=device) + torch.cumsum(nopad_b_seq_len[:-1], dim=0, dtype=torch.int32, out=nopad_b_start_loc[1:]) if len(requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) else: input_ids = all_input_ids[0] # Create tensors on device - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + input_ids = torch.from_numpy(input_ids).to(device=device) return cls( batch_id=batch_id, @@ -140,9 +140,9 @@ def filter(self, request_ids: List[int]): nopad_total_token_num = 0 nopad_max_len_in_batch = 0 + nopad_b_seq_len_numpy = np.array(self.input_lengths, np.int32) nopad_b_loc = torch.empty((len(request_ids), setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda') - nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda') left_idx = [] for i, request_id in enumerate(request_ids): @@ -163,10 +163,13 @@ def filter(self, request_ids: List[int]): idx = self.requests_idx_mapping[request_id] indices.append(idx) - nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] - nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() - nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] - nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + nopad_b_seq_len_numpy[:len(indices)] = nopad_b_seq_len_numpy[indices] + nopad_b_seq_len_numpy = nopad_b_seq_len_numpy[:len(indices)] + self.nopad_b_seq_len = self.nopad_b_seq_len[:len(indices)] + self.nopad_b_seq_len.copy_(torch.from_numpy(nopad_b_seq_len_numpy)) + torch.cumsum(self.nopad_b_seq_len[:len(indices) - 1], dim=0, dtype=torch.int32, out=nopad_b_start_loc[1:]) + nopad_max_len_in_batch = nopad_b_seq_len_numpy.max() + nopad_total_token_num = nopad_b_seq_len_numpy.sum() nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[indices, (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1): (self.nopad_max_len_in_batch - 1)] for i, request_id in enumerate(request_ids): @@ -197,12 +200,11 @@ def filter(self, request_ids: List[int]): @classmethod @torch.no_grad() - def merge(cls, batch1, batch2): + def merge(cls, batch1: 'InferBatch', batch2: 'InferBatch'): requests = batch1.requests + batch2.requests requests_idx_mapping = {} new_batch_size = len(batch1) + len(batch2) - input_ids = batch1.input_ids.new_empty(new_batch_size) all_input_ids = [] input_lengths = [] out_token_id_counts=[] @@ -213,8 +215,11 @@ def merge(cls, batch1, batch2): nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2 .nopad_max_len_in_batch) nopad_b_loc = torch.empty((new_batch_size, setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') - nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda') - nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda') + nopad_b_start_loc = torch.empty(new_batch_size, dtype=torch.int32, device='cuda') + + nopad_b_seq_len_cuda = torch.concat((batch1.nopad_b_seq_len, batch2.nopad_b_seq_len)) + input_ids = torch.concat((batch1.input_ids, batch2.input_ids)) + nopad_start_loc_len_temp = 0 batches = [batch1, batch2] for i, batch in enumerate(batches): @@ -225,10 +230,8 @@ def merge(cls, batch1, batch2): requests_idx_mapping[k] = v + cumulative_batch_size start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - input_ids[start_index:end_index] = batch.input_ids - nopad_b_seq_len[start_index: end_index] = batch.nopad_b_seq_len nopad_b_start_loc[start_index: end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp - nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len_cuda[end_index - 1] nopad_b_loc[start_index: end_index, nopad_max_len_in_batch - batch.nopad_max_len_in_batch: nopad_max_len_in_batch - 1] = batch.nopad_b_loc[:, :batch.nopad_max_len_in_batch - 1] @@ -240,8 +243,10 @@ def merge(cls, batch1, batch2): # Update cumulative_batch_size += len(batch) - nopad_b_loc[:, nopad_max_len_in_batch - 1] = nopad_total_token_num - \ - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device='cuda') + offset = nopad_total_token_num - new_batch_size + torch.arange(offset, offset + new_batch_size, dtype=torch.int32, device='cuda', + out=nopad_b_loc[:, nopad_max_len_in_batch - 1]) + batches[0].nopad_b_seq_len.set_shape(len(batch1) + len(batch2)) return InferBatch( batch_id=batches[0].batch_id, requests=requests, @@ -253,7 +258,7 @@ def merge(cls, batch1, batch2): nopad_max_len_in_batch=nopad_max_len_in_batch, nopad_b_loc=nopad_b_loc, nopad_b_start_loc=nopad_b_start_loc, - nopad_b_seq_len=nopad_b_seq_len, + nopad_b_seq_len=nopad_b_seq_len_cuda, out_token_id_counts=out_token_id_counts, sampling_param_list=sampling_param_list, mem_manager=batches[0].mem_manager diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 183e426c33..f5af34ee9c 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -12,6 +12,7 @@ from lightllm.models.bloom.layer_infer.model import BloomTpPartModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.common import np_from_tensor from .post_process import sample class ModelRpcServer(rpyc.Service): @@ -119,7 +120,7 @@ def forward(self, batch_id, is_prefill): next_token_ids = sample(logits, batch) output_dict = {} new_input_ids = [] - next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_ids = np_from_tensor(next_token_ids.detach().cpu()) for i, (r, all_input_ids, next_token_id) in enumerate(zip(batch.requests, batch.all_input_ids, next_token_ids)): # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") all_input_ids.append(int(next_token_id))