From 5667c9899155052dc9ba3d678d259579c1792ec3 Mon Sep 17 00:00:00 2001
From: shihaobai <42648726+shihaobai@users.noreply.github.com>
Date: Tue, 26 Mar 2024 17:37:41 +0800
Subject: [PATCH] add sampling params: skip_special_tokens spaces_between_special_tokens (#373)

---
 lightllm/server/api_server.py                    | 11 +----------
 lightllm/server/detokenization/decode.py         | 17 ++++++++++-------
 lightllm/server/detokenization/manager.py        | 12 ++++--------
 lightllm/server/io_struct.py                     | 12 ++++++++++--
 lightllm/server/router/manager.py                |  1 +
 lightllm/server/router/model_infer/model_rpc.py  |  2 +-
 .../server/router/model_infer/post_process.py    |  4 ++--
 lightllm/server/sampling_params.py               |  8 +++++++-
 8 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index f0557cf3..70b6c313 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -322,7 +322,7 @@ def main():
         default=None,
         help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM",
     )
-    parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id")
+    parser.add_argument("--eos_id", nargs='+', type=int, default=2, help="eos stop token id")
     parser.add_argument(
         "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
     )
@@ -364,15 +364,6 @@ def main():
         "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
     )
 
-    parser.add_argument(
-        "--no_skipping_special_tokens", action="store_true", help="whether to skip special tokens when decoding"
-    )
-    parser.add_argument(
-        "--no_spaces_between_special_tokens",
-        action="store_true",
-        help="whether to add spaces between special tokens when decoding",
-    )
-
     parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test")
 
     parser.add_argument("--splitfuse_mode", action="store_true", help="use splitfuse mode")
diff --git a/lightllm/server/detokenization/decode.py b/lightllm/server/detokenization/decode.py
index e42b61e2..d58d7572 100644
--- a/lightllm/server/detokenization/decode.py
+++ b/lightllm/server/detokenization/decode.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, List
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 
@@ -9,21 +9,24 @@ def decode_token(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     req: ReqDetokenizationState,
     new_token_id: int,
-    skip_special_tokens: bool,
-    spaces_between_special_tokens: bool,
+    eos_id: List[int],
 ) -> str:
     new_token = tokenizer.convert_ids_to_tokens(
-        new_token_id, skip_special_tokens=skip_special_tokens)
+        new_token_id, skip_special_tokens=req.skip_special_tokens)
     req.output_tokens.append(new_token)
-
-    if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
+
+    is_eos_id = new_token_id in eos_id
+    if is_eos_id and not req.print_eos_token:
+        return req.output_str
+
+    if req.skip_special_tokens and new_token_id in tokenizer.all_special_ids and not is_eos_id:
         return req.output_str
 
     if not getattr(tokenizer, "added_tokens_encoder", {}):
         output_text = tokenizer.convert_tokens_to_string(req.output_tokens)
         return output_text
 
-    sep = " " if spaces_between_special_tokens else ""
+    sep = " " if req.add_spaces_between_special_tokens else ""
 
     if new_token in tokenizer.added_tokens_encoder:
         if req.current_sub_text:
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index 2609012d..094692a5 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -17,13 +17,12 @@ class DeTokenizationManager:
 
     def __init__(
         self,
+        eos_id,
         model_weightdir,
         tokenizor_mode,
         detokenization_port,
         httpserver_port,
        trust_remote_code,
-        skip_special_tokens,
-        spaces_between_special_tokens,
     ):
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
@@ -35,9 +34,8 @@ def __init__(
         self.tokenizer = get_tokenizer(model_weightdir, tokenizor_mode, trust_remote_code=trust_remote_code)
         self.all_special_ids = set(self.tokenizer.all_special_ids)
         self.req_id_to_out = {}
+        self.eos_id = eos_id
 
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
 
     async def handle_loop(self):
         while True:
@@ -67,8 +65,7 @@ async def handle_loop(self):
                     self.tokenizer,
                     req_out,
                     new_token_id,
-                    skip_special_tokens=self.skip_special_tokens,
-                    spaces_between_special_tokens=self.spaces_between_special_tokens,
+                    self.eos_id,
                 )
 
                 if out_text.endswith(u'\ufffd'):
@@ -92,13 +89,12 @@ async def handle_loop(self):
 def start_detokenization_process(args, detokenization_port, httpserver_port, pipe_writer):
     try:
         router = DeTokenizationManager(
+            args.eos_id,
             args.model_dir,
             args.tokenizer_mode,
             detokenization_port=detokenization_port,
             httpserver_port=httpserver_port,
             trust_remote_code=args.trust_remote_code,
-            skip_special_tokens=not args.no_skipping_special_tokens,
-            spaces_between_special_tokens=not args.no_spaces_between_special_tokens,
         )
     except Exception as e:
         pipe_writer.send(str(e))
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
index deed1b80..954aadce 100644
--- a/lightllm/server/io_struct.py
+++ b/lightllm/server/io_struct.py
@@ -64,7 +64,9 @@ def to_rpc_obj(self):
 
     def to_req_detokenization_state(self):
         out = ReqDetokenizationState(
-            self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos
+            self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos,
+            self.sample_params.skip_special_tokens, self.sample_params.add_spaces_between_special_tokens,
+            self.sample_params.print_eos_token
         )
         # if self.output_metadata_list: # looks like no use
         #     out.gen_metadata.update(self.output_metadata_list[-1])
@@ -226,6 +228,9 @@ def __init__(
         prompt_ids: List[int],
         max_output_len: int,
         ignore_eos: bool,
+        skip_special_tokens: bool,
+        add_spaces_between_special_tokens: bool,
+        print_eos_token: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt_ids = prompt_ids
@@ -237,6 +242,9 @@ def __init__(
         self.max_output_len = max_output_len
         self.ignore_eos = ignore_eos
         self.gen_metadata = {}
+        self.skip_special_tokens = skip_special_tokens
+        self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
+        self.print_eos_token = print_eos_token
 
 
 class Batch:
@@ -265,7 +273,7 @@ def mark_and_get_finished_req_and_preupdate_status(self, eos_id):
         for req in self.reqs:
             if req.stop_sequences_matched():
                 req.finish_status = FinishStatus.FINISHED_STOP
-            elif len(req.output_ids) >= 1 and req.output_ids[-1] == eos_id and req.sample_params.ignore_eos is False:
+            elif len(req.output_ids) >= 1 and req.output_ids[-1] in eos_id and req.sample_params.ignore_eos is False:
                 req.finish_status = FinishStatus.FINISHED_STOP
             elif len(req.output_ids) >= req.max_output_len:
                 req.finish_status = FinishStatus.FINISHED_LENGTH
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 9f8af1af..8ed402a8 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -75,6 +75,7 @@ async def wait_to_model_ready(self):
                 "splitfuse_block_size": self.splitfuse_block_size,
                 "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs,
                 "use_dynamic_prompt_cache": self.args.use_dynamic_prompt_cache,
+                "eos_id": self.eos_id,
             }
             init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))
 
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 6423dcce..1b4bd143 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -66,6 +66,7 @@ def exposed_init_model(self, kvargs):
         self.splitfuse_block_size = kvargs.get("splitfuse_block_size", None)
         self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
+        self.eos_id = kvargs.get("eos_id", [2])
         self.cache = {}
         self.logger = init_logger(__name__)
 
@@ -101,7 +102,6 @@ def exposed_init_model(self, kvargs):
 
         try:
             self.model_type = model_cfg.get("model_type", "")
-            self.eos_id = model_cfg.get("eos_token_id", 2)
             if self.model_type == "bloom":
                 self.model = BloomTpPartModel(model_kvargs)
             elif self.model_type == "llama":
diff --git a/lightllm/server/router/model_infer/post_process.py b/lightllm/server/router/model_infer/post_process.py
index cb695b22..a609c1a3 100644
--- a/lightllm/server/router/model_infer/post_process.py
+++ b/lightllm/server/router/model_infer/post_process.py
@@ -5,7 +5,7 @@
 from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
 
 
-def sample(logits, reqs, eos_id=2):
+def sample(logits, reqs, eos_id:List[int]=[2]):
     logits = logits.contiguous()
     (
         presence_penalties,
@@ -33,7 +33,7 @@ def sample(logits, reqs, eos_id=2):
         p_max_len_in_batch,
     )
     logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * (
-        torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1,)) - 1
+        torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1, 1)) - 1
     )
     logits.div_(temperatures.view((-1, 1)))
     probs = torch.softmax(logits, dim=-1)
diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py
index 28e1d7a3..28b62117 100644
--- a/lightllm/server/sampling_params.py
+++ b/lightllm/server/sampling_params.py
@@ -18,7 +18,10 @@ def __init__(
         top_k: int = -1, # -1 is for all
         ignore_eos: bool = False,
         max_new_tokens: int = 16,
-        stop_sequences: Optional[Union[str, List[str]]] = None # 停止句子条件
+        stop_sequences: Optional[Union[str, List[str]]] = None, # 停止句子条件
+        skip_special_tokens: bool = True, # whether to skip special tokens when decoding
+        add_spaces_between_special_tokens: bool = True, # whether to add spaces between special tokens when decoding
+        print_eos_token: bool = False, # eos_id will be always ignored except the value is set to True
     ) -> None:
         self.do_sample = do_sample
         self.presence_penalty = presence_penalty
@@ -31,6 +34,9 @@ def __init__(
         self.ignore_eos = ignore_eos
         self.max_new_tokens = max_new_tokens
         self.stop_sequences = stop_sequences
+        self.skip_special_tokens = skip_special_tokens
+        self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
+        self.print_eos_token = print_eos_token
         if self.do_sample == False:
             self.temperature = 1.0
             self.top_p = 1.0
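
Illustrative usage (a minimal sketch, not part of the patch itself): after this change, special-token handling is chosen per request through SamplingParams rather than the removed server-wide flags, and --eos_id accepts one or more token ids. The snippet below assumes a running lightllm api_server at localhost:8000 with a /generate endpoint whose "parameters" keys map one-to-one onto SamplingParams fields; adjust names and addresses to your deployment.

# Hypothetical client-side sketch exercising the new per-request sampling params.
import json
from urllib import request

payload = {
    "inputs": "What is AI?",
    "parameters": {
        "max_new_tokens": 64,
        "skip_special_tokens": False,               # keep special tokens in the decoded text
        "add_spaces_between_special_tokens": True,  # join consecutive special tokens with a space
        "print_eos_token": True,                    # also emit the eos token's text
    },
}

req = request.Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))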