diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py
index 569e340276..cb7d3089a5 100644
--- a/inference/worker/__main__.py
+++ b/inference/worker/__main__.py
@@ -34,7 +34,7 @@ def main():
         tokenizer = None
     else:
         tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")
+        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {len(tokenizer)}")

     inference_http = utils.HttpClient(
         base_url=settings.inference_server_url,
diff --git a/inference/worker/basic_hf_server.py b/inference/worker/basic_hf_server.py
index b97fb8d83e..052680f01f 100644
--- a/inference/worker/basic_hf_server.py
+++ b/inference/worker/basic_hf_server.py
@@ -138,7 +138,7 @@ def load_models():
     hf_config = transformers.AutoConfig.from_pretrained(model_config.model_id)
     logger.warning(f"Loading model {model_config.model_id}...")
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-    logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {tokenizer.vocab_size}")
+    logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {len(tokenizer)}")

     # see `decode_token` method, taken from HF text-generation-inference
     tokenizer.add_special_tokens({"additional_special_tokens": [""]})
diff --git a/inference/worker/utils.py b/inference/worker/utils.py
index ba1175a02e..d5ae793f5c 100644
--- a/inference/worker/utils.py
+++ b/inference/worker/utils.py
@@ -95,11 +95,14 @@ def get_max_input_length(worker_config: inference.WorkerConfig, plugin_used: boo
     return max_input_length


-def get_tokens_until(tokens: list[int], target: int | list[int]) -> list[int]:
-    if isinstance(target, int):
-        return tokens[: tokens.index(target)]
-    else:
-        return next((i for i in range(len(tokens) - len(target) + 1) if tokens[i : i + len(target)] == target))
+def get_tokens_until(tokens: list[int], target: list[int]) -> list[int]:
+    if len(target) == 1:
+        return tokens[: tokens.index(target[0])]
+
+    for i in range(len(tokens) - len(target)):
+        if tokens[i : i + len(target)] == target:
+            break
+    return tokens[:i]


 def truncate_prompt(
@@ -118,8 +121,8 @@ def truncate_prompt(
     """
     with shared_tokenizer_lock:
         ids = tokenizer.encode(prompt)
-        # prompter_prefix_ids could be int or list of ints
-        prompter_prefix_ids = tokenizer.convert_tokens_to_ids(special_tokens["prompter"])
+        # list of int IDs
+        prompter_prefix_ids = tokenizer.encode(special_tokens["prompter"])

     system_prompt: str | None = None
     system_tokens: list[int] | None = None
@@ -134,7 +137,9 @@ def truncate_prompt(
         num_system_tokens = len(system_tokens) if system_tokens else 0

         # Maximum token allowed for the conversation, ex system prompt
-        max_conversation_length = max_input_length - num_system_tokens
+        # We incorporate a buffer to allow for final inference tokenization differing from ours
+        # This is a slightly hacky workaround and it would be better to find a cleaner solution
+        max_conversation_length = max_input_length - num_system_tokens - int(0.01 * max_input_length)
         ids = ids[-(max_conversation_length - 1) :]

         with shared_tokenizer_lock:
diff --git a/oasst-shared/oasst_shared/model_configs.py b/oasst-shared/oasst_shared/model_configs.py
index 9f1da4fc8f..bd85e4461e 100644
--- a/oasst-shared/oasst_shared/model_configs.py
+++ b/oasst-shared/oasst_shared/model_configs.py
@@ -150,4 +150,9 @@ def compat_hash(self) -> str:
         max_input_length=3072,
         max_total_length=4096,
     ),
+    "OA_SFT_CodeLlama_13B_10": ModelConfig(
+        model_id="OpenAssistant/codellama-13b-oasst-sft-v10",
+        max_input_length=8192,
+        max_total_length=12288,
+    ),
 }
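
Illustrative note on the `len(tokenizer)` change (a minimal sketch, not part of the patch; the "gpt2" model id and the "<|custom|>" token are placeholders chosen only for the example): in Hugging Face tokenizers, `vocab_size` reports the base vocabulary only, while `len(tokenizer)` also counts tokens registered via `add_special_tokens()`/`add_tokens()`, which is the size the logged warning is meant to reflect.

```python
import transformers

# Base tokenizer: vocab_size and len() agree before any tokens are added.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # example model id
print(tokenizer.vocab_size)  # 50257 -- base vocabulary only

# Registering an additional special token grows len() but not vocab_size.
tokenizer.add_special_tokens({"additional_special_tokens": ["<|custom|>"]})
print(tokenizer.vocab_size)  # still 50257
print(len(tokenizer))        # 50258 -- includes the added special token
```

On the truncation side, the new 1% buffer only shaves a small margin off the conversation budget: with `max_input_length = 8192` (as in the added CodeLlama config) and, say, a 100-token system prompt, `max_conversation_length` becomes 8192 - 100 - int(0.01 * 8192) = 8011.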