diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index ce22694519..fbc6736f9a 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -29,12 +29,14 @@ def __init__(self,
                  temperature=0.8,
                  repetition_penalty=1.0,
                  capability='chat',
+                 stop_words=None,
                  **kwargs):
         self.session_len = session_len
         self.top_p = top_p
         self.top_k = top_k
         self.temperature = temperature
         self.repetition_penalty = repetition_penalty
+        self.stop_words = stop_words
         self.capability = capability
 
     def get_prompt(self, prompt, sequence_start=True):
@@ -101,11 +103,6 @@ def messages2prompt(self, messages, sequence_start=True):
             return self.get_prompt(messages)
         # chat history processing in derived classes
 
-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return None
-
     @property
     def sampling_param(self):
         return SamplingParam(top_p=self.top_p,
@@ -185,6 +182,7 @@ def __init__(
             eoh='<eoh>',
             eoa='<eoa>',
             assistant='<|Bot|>',
+            stop_words=['<eoa>'],
             **kwargs):
         super().__init__(**kwargs)
         self.system = system
@@ -193,6 +191,7 @@ def __init__(
         self.eoh = eoh
         self.eoa = eoa
         self.assistant = assistant
+        self.stop_words = stop_words
 
     def decorate_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -227,7 +226,8 @@ def messages2prompt(self, messages, sequence_start=True):
         if isinstance(messages, str):
             return self.get_prompt(messages, sequence_start)
         system, users, assistants = self._translate_messages(messages)
-        ret = ''
+        system = self.meta_instruction if not system else system
+        ret = f'{self.system}:{system}\n'
         for user, assistant in zip(users, assistants):
             if assistant:
                 ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:' \
@@ -236,11 +236,6 @@ def messages2prompt(self, messages, sequence_start=True):
             ret += f'{self.user}:{user}{self.eoh}\n{self.assistant}:'
         return ret
 
-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return [103028]
-
 
 @MODELS.register_module(name='internlm-chat-20b')
 @MODELS.register_module(name='internlm-chat-7b-8k')
@@ -339,12 +334,14 @@ def __init__(self,
                  eoh='',
                  assistant='',
                  eoa='',
+                 stop_words=None,
                  **kwargs):
         super().__init__(**kwargs)
         self.meta_instruction = meta_instruction
         self.system = system
         self.user = user
         self.assistant = assistant
+        self.stop_words = stop_words
         self.eosys = eosys
         self.eoh = eoh
         self.eoa = eoa
@@ -382,11 +379,6 @@ def messages2prompt(self, messages, sequence_start=True):
             ret += f'{self.user}{user}{self.eoh}{self.assistant}'
         return ret
 
-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return [45623]
-
 
 @MODELS.register_module(name='llama2')
 class Llama2(BaseModel):
@@ -468,6 +460,7 @@ def __init__(self,
                  im_start='<|im_start|>',
                  im_end='<|im_end|>',
                  system='You are a helpful assistant.',
+                 stop_words=['<|im_end|>'],
                  **kwargs):
         super().__init__(**kwargs)
         self.session_len = session_len
@@ -478,6 +471,7 @@ def __init__(self,
         self.im_start = im_start
         self.im_end = im_end
         self.system = system
+        self.stop_words = stop_words
 
     def decorate_prompt(self, prompt, sequence_start=True):
         assert self.capability == 'chat', \
@@ -513,11 +507,6 @@ def messages2prompt(self, messages, sequence_start=True):
                    f'\n{self.im_start}assistant\n'
         return ret
 
-    @property
-    def stop_words(self):
-        """Return the stop-words' token ids."""
-        return [151645]  # <|im_end|>
-
 
 @MODELS.register_module(name='codellama')
 class CodeLlama(Llama2):
@@ -526,6 +515,7 @@ def __init__(self,
                  system='',
                  session_len=4096,
                  suffix_first=False,
+                 stop_words=None,
                  **kwargs):
         super().__init__(**kwargs)
         caps = ['completion', 'infilling', 'chat', 'python']
@@ -535,6 +525,7 @@ def __init__(self,
         self.default_sys_prompt = system
         self.session_len = session_len
         self.suffix_first = suffix_first
+        self.stop_words = stop_words
 
         # The following sampling parameters refers to https://github.com/facebookresearch/codellama  # noqa: E501
         if self.capability == 'completion' or self.capability == 'python':
@@ -546,6 +537,8 @@ def __init__(self,
         elif self.capability == 'infilling':
             self.top_p = kwargs.get('top_p', 0.9)
             self.temperature = kwargs.get('temperature', 0.0)
+            if self.stop_words is None:
+                self.stop_words = ['<EOT>']
 
     def decorate_prompt(self, prompt, sequence_start=True):
         if self.capability == 'infilling':
@@ -574,14 +567,6 @@ def _get_prompt(self, prompt, sequence_start):
 
         return f'{self.b_inst} {prompt} {self.e_inst}'
 
-    @property
-    def stop_words(self):
-        if self.capability == 'infilling':
-            # EOT ID
-            return [32010]
-        else:
-            return None
-
     def messages2prompt(self, messages, sequence_start=True):
         assert self.capability == 'chat', \
             f'codellama message2prompt only supports chat mode ' \
diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py
index eb532e2602..cc12fcff3b 100644
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -18,6 +18,7 @@
 from lmdeploy.model import MODELS
 from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
                                             prepare_tensor)
+from lmdeploy.utils import filter_suffix
 
 
 @dataclass
@@ -157,6 +158,8 @@ def stream_infer(self,
                                                       request_output_len,
                                                       sequence_start,
                                                       sequence_end):
+            if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
+                res = filter_suffix(res, self.model.stop_words)
             if status.value < 0:
                 break
             else:
@@ -346,6 +349,8 @@ def infer(self,
                                                       sequence_end):
             if status.value < 0:
                 break
+        if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
+            res = filter_suffix(res, self.model.stop_words)
         if status.value == 0:
             self._session.histories = \
                 self._session.histories + self._session.prompt + \
@@ -386,16 +391,23 @@ def _get_eos(self):
         token_ids, _ = self.preprocess('<EOS>')
         return token_ids[0][0]
 
-    def _stop_words(self, stop_words: List[int]):
+    def _stop_words(self, stop_words: List[str]):
         """return stop-words' token ids."""
         if stop_words is None:
             return None
         assert isinstance(stop_words, List) and \
-            all(isinstance(elem, int) for elem in stop_words), \
+            all(isinstance(elem, str) for elem in stop_words), \
             f'stop_words must be a list but got {type(stop_words)}'
         # each id in stop_words represents a stop word
         # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
         # detailed explanation about turbomind's stop_words
+        stop_words = [
+            int(self.preprocess(stop_word)[0][0][-1])
+            for stop_word in stop_words
+        ]
+        assert isinstance(stop_words, List) and \
+            all(isinstance(elem, int) for elem in stop_words), \
+            'invalid stop_words'
         stop_word_offsets = range(1, len(stop_words) + 1)
         stop_words = np.array([[stop_words,
                                 stop_word_offsets]]).astype(np.int32)
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 45760a309a..f8a7444546 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -14,6 +14,7 @@
 
 import lmdeploy
 from lmdeploy.model import MODELS
+from lmdeploy.turbomind import Tokenizer
 from lmdeploy.utils import get_logger
 
 # TODO: find another way import _turbomind
@@ -22,14 +23,16 @@
 import _turbomind as _tm  # noqa: E402
 
 
-def _stop_words(stop_words: List[int]):
+def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     """return list of stop-words to numpy.ndarray."""
     if stop_words is None:
         return None
     assert isinstance(stop_words, List) and \
-        all(isinstance(elem, int) for elem in stop_words), \
+        all(isinstance(elem, str) for elem in stop_words), \
         f'stop_words must be a list but got {type(stop_words)}'
-
+    stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
+    assert isinstance(stop_words, List) and all(
+        isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
     # each id in stop_words represents a stop word
     # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
     # detailed explanation about fastertransformer's stop_words
@@ -106,7 +109,10 @@ def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
         self.model_name = parser.get(section_name, 'model_name')
         data_type = parser.get(section_name, 'weight_type')
         model = MODELS.get(self.model_name)()
-        self.stop_words = _stop_words(model.stop_words)
+        tokenizer_model_path = osp.join(model_path, 'triton_models',
+                                        'tokenizer')
+        tokenizer = Tokenizer(tokenizer_model_path)
+        self.stop_words = _stop_words(model.stop_words, tokenizer)
 
         # params
         self.node_id = node_id
@@ -162,6 +168,8 @@ def __init__(self, tm_model, cuda_stream_id=0):
         self.gpu_count = tm_model.gpu_count
 
         self.stop_words = tm_model.stop_words
+        self.stop_tokens = [] if self.stop_words is None else \
+            self.stop_words.flatten().tolist()
         self.eos_id = tm_model.eos_id
         self.session_len = tm_model.session_len
 
@@ -346,6 +354,8 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
                 output, len_ = output, len_.item()
                 if len(output) > 0 and output[-1].item() == self.eos_id:
                     outputs.append((output[:-1], len_ - 1))
+                elif len(output) > 0 and output[-1].item() in self.stop_tokens:
+                    outputs.append((output[:-1], len_))
                 else:
                     outputs.append((output, len_))
 
diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py
index 7b6d51a01a..e284f50075 100644
--- a/lmdeploy/utils.py
+++ b/lmdeploy/utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
-from typing import Optional
+from typing import List, Optional
 
 logger_initialized = {}
 
@@ -77,3 +77,21 @@ def get_logger(name: str,
     logger_initialized[name] = True
 
     return logger
+
+
+def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
+    """Filter response with suffixes.
+
+    Args:
+        response (str): generated response by LLMs.
+        suffixes (List[str]): a list of suffixes to be deleted.
+
+    Returns:
+        str: a clean response.
+    """
+    if suffixes is None:
+        return response
+    for item in suffixes:
+        if response.endswith(item):
+            response = response[:len(response) - len(item)]
+    return response
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
index dcf04d5c28..d07e1f1f73 100644
--- a/tests/test_lmdeploy/test_model.py
+++ b/tests/test_lmdeploy/test_model.py
@@ -133,7 +133,7 @@ def test_codellama_infilling():
 '''
     _prompt = model.get_prompt(prompt)
    assert _prompt.find('<FILL>') == -1
-    assert model.stop_words == [32010]
+    assert model.stop_words == ['<EOT>']
 
     model = MODELS.get('codellama')(capability='infilling', suffix_first=True)
     _prompt = model.get_prompt(prompt)
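For reference, here is a minimal, self-contained sketch of the two mechanisms this patch combines: turning string stop words into the trailing token ids that turbomind expects (last id of the encoded stop word, as in `_stop_words`), and trimming a trailing stop word from the decoded text (same logic as the new `lmdeploy.utils.filter_suffix`). `ToyTokenizer` and `stop_words_to_array` are hypothetical stand-ins introduced only to keep the example runnable; the real code uses `lmdeploy.turbomind.Tokenizer` and the `_stop_words` helpers above. The ids in the toy vocab are the ones the removed properties used to hard-code.

```python
from typing import List, Optional

import numpy as np


class ToyTokenizer:
    """Hypothetical stand-in for lmdeploy.turbomind.Tokenizer."""

    vocab = {'<eoa>': 103028, '<|im_end|>': 151645, '<EOT>': 32010}

    def encode(self, s: str) -> List[int]:
        # A real tokenizer may prepend BOS or split the string into several
        # ids; only the last id is kept as the stop token, as in the patch.
        return [1, self.vocab.get(s, 0)]


def stop_words_to_array(stop_words: Optional[List[str]],
                        tokenizer: ToyTokenizer) -> Optional[np.ndarray]:
    """Mirror _stop_words: encode each stop word, keep its last token id."""
    if stop_words is None:
        return None
    ids = [tokenizer.encode(word)[-1] for word in stop_words]
    offsets = list(range(1, len(ids) + 1))
    return np.array([[ids, offsets]]).astype(np.int32)


def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
    """Same logic as the filter_suffix added to lmdeploy/utils.py."""
    if suffixes is None:
        return response
    for item in suffixes:
        if response.endswith(item):
            response = response[:len(response) - len(item)]
    return response


if __name__ == '__main__':
    tokenizer = ToyTokenizer()
    # int32 array of shape (1, 2, 1): stop-word ids and their offsets
    print(stop_words_to_array(['<eoa>'], tokenizer))
    # -> 'Hello!' with the textual stop word stripped
    print(filter_suffix('Hello!<eoa>', ['<eoa>']))
```

Keeping stop words as strings in the chat templates lets one template definition work with whichever tokenizer is loaded: the id lookup is deferred to `_stop_words` at engine construction time, and `filter_suffix` keeps the textual stop word out of the returned response.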