[Refactor] Support multi-session chat #178
Merged
Changes from all commits (28 commits, all by wangruohui):
- `9a53c44` add some dist utils
- `a3146f3` add model utils
- `66ed2f8` add termio and basicstreamer
- `a10f2ed` typo
- `248f7b4` fix world size
- `5417e96` refactor chat and tested llama1
- `0322f40` add internlm adapter and support stoping criteria
- `a272ed7` concat with id for internlm
- `15a6025` update docstring
- `1c9adc9` update and support llama2
- `117b0ee` typo
- `f59d565` move docs to docs
- `8dfb221` update docstring of session manager
- `b384b26` update docstring
- `a7add95` update docs
- `097fc38` fix accel none in model
- `6492fdc` fix and add test for tensor broadcast
- `4275366` fix session using typing to check type
- `50b2571` add docstrings and comprehensive condition test
- `b22ef94` unit test for dist
- `d76499e` fix session
- `879659b` split unittests of utils
- `5d7a261` typo
- `c8fbc0b` update control flow of accel
- `9b451e3` move test model
- `d3c3fe4` remove main in unittest
- `ba0f771` remove some log
- `e5bbda6` remove some comments
New file (+74 lines):
# PyTorch

## Chat in command line

LMDeploy supports chatting with PyTorch models via the submodule `lmdeploy.pytorch.chat`.

This submodule allows users to chat with a language model through the command line, and optionally to accelerate the model with backends such as DeepSpeed.

**Example 1**: Chat with the default settings

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

**Example 2**: Disable sampling and chat history

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0 --max-history 0
```

**Example 3**: Accelerate with DeepSpeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use DeepSpeed, you need to install it first. To accelerate InternLM, you need a customized version: <https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0>

**Example 4**: Run the model tensor-parallel on 2 GPUs

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

This module also accepts the following control commands to change generation behavior during chat (see the sketch after this list):

- `exit`: terminate and exit the chat
- `config set key=value`: set generation config `key` to `value`, e.g. `config set temperature=0` disables sampling for subsequent turns
- `clear`: clear the chat history
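
For illustration only, one way such commands might be dispatched; `handle_command`, `session.clear_history`, and the `gen_config` object are hypothetical names, not the actual lmdeploy implementation:

```python
def handle_command(line, gen_config, session):
    """Hypothetical dispatcher for the control commands listed above."""
    if line == 'exit':
        raise SystemExit
    if line == 'clear':
        session.clear_history()  # hypothetical session-manager method
        return True
    if line.startswith('config set '):
        key, _, value = line[len('config set '):].partition('=')
        # cast the string value to the type of the existing config field
        old = getattr(gen_config, key)
        setattr(gen_config, key, type(old)(value))
        return True
    return False  # not a control command; treat as a normal prompt
```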
### Simple diagram of components

```mermaid
graph LR;
    subgraph model specific adapter
        p((user_input))-->tokenize-->id((input_ids))-->decorate
        tmpl_ids((template_ids))-->decorate;
    end
    subgraph generate
        model[CausalLM_model.generate]-->gen_result(("gen_result"))
        gen_result-->hid
        gen_result-->attn((attention))
    end
    subgraph streamer
        model-->s[streamer]--value-->decode_single--token-->output
    end
    subgraph session_manager
        prepend_history-->fullid((complete_ids));
        trim-->prepend_history
    end
    decorate-->prepend_history
    hid((history_ids))-->trim;
    attn-->trim;
    fullid-->model
    tokenizer((tokenizer))-->decode_single
    tokenizer-->tokenize
    p-->genconfig(GenConfig)-->model
```
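
Read as code, the diagram corresponds roughly to the per-turn flow below. This is a hand-written sketch with hypothetical names (`chat_turn`, `session`), not the actual lmdeploy API:

```python
def chat_turn(user_input, adapter, session, model, streamer, gen_config):
    # model specific adapter: tokenize, then decorate with template ids
    input_ids = adapter.encode_and_decorate(user_input)
    # session manager: trim history to fit the context, prepend it, and
    # produce the complete ids fed to the model
    complete_ids = session.prepend_history(input_ids)
    # generate: the streamer receives values as they are produced and
    # decodes them token by token for display
    gen_result = model.generate(complete_ids,
                                generation_config=gen_config,
                                streamer=streamer)
    # keep history ids (and attention) around for trimming next turn
    session.update(gen_result)
```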
New file (+39 lines):
```python
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn

from .base import BasicAdapter, BasicAdapterFast
from .internlm import InternLMAdapter
from .llama2 import Llama2Adapter

logger = logging.getLogger(__name__)


def _get_default_adapter(tokenizer):
    if tokenizer.is_fast:
        return BasicAdapterFast
    else:
        return BasicAdapter


def init_adapter(model: nn.Module, tokenizer, adapter=None):
    """Select and instantiate an adapter for the given model and tokenizer.

    If ``adapter`` is None, the adapter class is inferred from the class
    names of the model's submodules; otherwise only 'llama1' is accepted,
    which maps to the default (template-free) adapter.
    """
    if adapter is None:
        for v in model.modules():
            if 'InternLMModel' in v.__class__.__name__:
                Adapter = InternLMAdapter
                break
            elif 'LlamaModel' in v.__class__.__name__:
                Adapter = Llama2Adapter
                break
            else:
                Adapter = _get_default_adapter(tokenizer)
    elif adapter == 'llama1':
        Adapter = _get_default_adapter(tokenizer)
    else:
        raise ValueError(f'Adapter {adapter} is not allowed.')

    logger.info(f'Using adapter {Adapter.__name__}')

    return Adapter(tokenizer)
```
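
For context, a rough sketch of how `init_adapter` might be called when wiring up a chat session; the checkpoint path is a placeholder and the import path of the adapters package is an assumption:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical import path; the actual package layout may differ.
from lmdeploy.pytorch.adapters import init_adapter

model = AutoModelForCausalLM.from_pretrained('/path/to/hf_model',
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('/path/to/hf_model',
                                          trust_remote_code=True)

# The adapter class is inferred from the model's submodule class names.
adapter = init_adapter(model, tokenizer)
input_ids = adapter.encode_and_decorate('Hello!')
```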
New file (+127 lines):
````python
# Copyright (c) OpenMMLab. All rights reserved.
"""Basic adapters suitable for general HuggingFace models."""

import logging
import re

from transformers import (PreTrainedTokenizer, PreTrainedTokenizerBase,
                          PreTrainedTokenizerFast)

logger = logging.getLogger(__name__)


class BaseAdapter:
    """Base class for all adapters.

    Note:
        Adapters coordinate with the session manager to prepare input_ids.
        The full sequence fed to the model is as follows:

        ```
        adapter.start_ids
        adapter.encode_and_decorate(user_input_1)
        output_1_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_2)
        output_2_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_3)
        ```

    Thus the adapter is responsible for providing the model specific
    ``start_ids``, ``sep_ids``, and a method to encode a single prompt.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Model specific method to encode and decorate a prompt."""
        raise NotImplementedError

    def decode(self, value):
        """Model specific method to decode a single value to a string."""
        raise NotImplementedError

    @property
    def stopping_criteria(self):
        """Model specific stopping criteria for generation."""
        return None

    @property
    def start_ids(self):
        """Model specific start ids."""
        return [self.tokenizer.bos_token_id]

    @property
    def sep_ids(self):
        """Model specific separation ids."""
        return [self.tokenizer.bos_token_id]


class BasicAdapter(BaseAdapter):
    """Basic adapter for slow tokenizers."""

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Fallback decoding when the tokenizer is not fast."""

        self.tokenizer: PreTrainedTokenizer

        # Crude fallback: decode the single token and append a space,
        # since slow tokenizers cannot report subword boundaries here.
        tok = self.tokenizer.decode(value)
        return tok + ' '


class BasicAdapterFast(BaseAdapter):
    """Basic adapter for fast tokenizers."""

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Decode with fast tokenizers."""

        self.tokenizer: PreTrainedTokenizerFast

        tok = self.tokenizer._convert_id_to_token(value)
        if tok.startswith('▁'):  # sentencepiece word-start marker
            space = ' '
            tok = tok[1:]
        else:
            space = ''
        if res := self.hex_regex.match(tok):
            # byte-fallback tokens like <0x0A> map back to raw characters
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '\r':
            tok = '\n'

        tok = space + tok

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok
````
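
To make the contract in `BaseAdapter`'s docstring concrete, here is a minimal sketch of how a session manager might assemble the full id sequence from these adapter pieces; `build_full_ids` and the `history_ids` bookkeeping are hypothetical, not the actual lmdeploy session manager:

```python
import torch

def build_full_ids(adapter, history_ids, prompt):
    """Hypothetical helper mirroring the layout in BaseAdapter's docstring."""
    new_ids = adapter.encode_and_decorate(prompt)  # shape (1, L)
    if history_ids is None:
        # first turn: <start_ids> + decorated prompt
        prefix = torch.tensor([adapter.start_ids])
    else:
        # later turns: history + <sep_ids> + decorated prompt
        prefix = torch.cat(
            [history_ids, torch.tensor([adapter.sep_ids])], dim=1)
    return torch.cat([prefix, new_ids], dim=1)
```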
New file (+81 lines):
```python
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import re

import torch
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
                          StoppingCriteriaList)

from .base import BaseAdapter

logger = logging.getLogger(__name__)


class InternLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for the HF version of InternLM."""

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        # stop when the last token is <eos> (2) or <eoa> (103028)
        return input_ids[0, -1] in [2, 103028]


class InternLMAdapter(BaseAdapter):
    r"""Adapter for InternLM.

    InternLM uses the following template, in which '\n' must be encoded
    as token id 13:

    <bos> (no actual newline here, just for better readability)
    <|User|>:{prompt}<eoh>\n
    <|Bot|>:{model_output}<eoa>\n
    <|User|>:{prompt}<eoh>\n
    <|Bot|>:{model_output}<eoa>\n
    ...
    <eos>
    """

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
    # ids of '<|User|>:'
    B_USER_ID = torch.tensor([[333, 352, 1621, 352, 27232]])
    # ids of '<eoh>\n<|Bot|>:'
    E_USER_ID = torch.tensor([[103027, 13, 333, 352, 23845, 352, 27232]])
    # ids of '<bos>'
    start_ids = [1]
    # ids of '\n'
    sep_ids = [13]

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt):
        r"""Encode prompt and decorate with the template.

        Note:
            We leave <bos> and the chat history for the session manager
            to add, so we decorate input_ids to
            '<|User|>:{prompt}<eoh>\n<|Bot|>:'.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors='pt',
        )
        # This is f'<|User|>:{prompt}<eoh>\n<|Bot|>:'
        # but forces \n to 13 instead of 364
        input_ids = torch.cat([self.B_USER_ID, input_ids, self.E_USER_ID],
                              dim=1)
        return input_ids

    def decode(self, value):
        """Decode generated tokens for InternLM."""

        tok = self.tokenizer.decode(value)
        if res := self.hex_regex.match(tok):
            # byte-fallback tokens like <0x0A> map back to raw characters
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '<eoa>' or tok == '\r':
            tok = '\n'

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok

    @property
    def stopping_criteria(self):
        return StoppingCriteriaList([InternLMStoppingCriteria()])
```
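
For illustration, a minimal sketch of how this adapter's pieces plug into HuggingFace generation; `model` and `tokenizer` are assumed to be an InternLM checkpoint loaded via `transformers`, and `generate`'s `stopping_criteria` argument is standard `transformers` API:

```python
import torch

# Hypothetical wiring; model and tokenizer loading is omitted.
adapter = InternLMAdapter(tokenizer)

input_ids = adapter.encode_and_decorate('Hi, who are you?')
full_ids = torch.cat([torch.tensor([adapter.start_ids]), input_ids], dim=1)

output = model.generate(
    full_ids,
    max_new_tokens=128,
    stopping_criteria=adapter.stopping_criteria,  # stops on <eos>/<eoa>
)
```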