Merged

Commits (28)
9a53c44
add some dist utils
wangruohui Jul 25, 2023
a3146f3
add model utils
wangruohui Jul 25, 2023
66ed2f8
add termio and basicstreamer
wangruohui Jul 25, 2023
a10f2ed
typo
wangruohui Jul 25, 2023
248f7b4
fix world size
wangruohui Jul 26, 2023
5417e96
refactor chat and tested llama1
wangruohui Jul 26, 2023
0322f40
add internlm adapter and support stoping criteria
wangruohui Jul 26, 2023
a272ed7
concat with id for internlm
wangruohui Jul 28, 2023
15a6025
update docstring
wangruohui Jul 28, 2023
1c9adc9
update and support llama2
wangruohui Jul 28, 2023
117b0ee
typo
wangruohui Jul 28, 2023
f59d565
move docs to docs
wangruohui Jul 28, 2023
8dfb221
update docstring of session manager
wangruohui Jul 28, 2023
b384b26
update docstring
wangruohui Jul 28, 2023
a7add95
update docs
wangruohui Jul 28, 2023
097fc38
fix accel none in model
wangruohui Aug 1, 2023
6492fdc
fix and add test for tensor broadcast
wangruohui Aug 1, 2023
4275366
fix session using typing to check type
wangruohui Aug 1, 2023
50b2571
add docstrings and comprehensive condition test
wangruohui Aug 1, 2023
b22ef94
unit test for dist
wangruohui Aug 1, 2023
d76499e
fix session
wangruohui Aug 1, 2023
879659b
split unittests of utils
wangruohui Aug 1, 2023
5d7a261
typo
wangruohui Aug 1, 2023
c8fbc0b
update control flow of accel
wangruohui Aug 1, 2023
9b451e3
move test model
wangruohui Aug 1, 2023
d3c3fe4
remove main in unittest
wangruohui Aug 1, 2023
ba0f771
remove some log
wangruohui Aug 3, 2023
e5bbda6
remove some comments
wangruohui Aug 7, 2023
12 changes: 7 additions & 5 deletions README.md
@@ -118,11 +118,7 @@ For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and

### Inference with PyTorch

You have to install deepspeed first before running with PyTorch.

```
pip install deepspeed
```
For detailed instructions on inference with PyTorch models, see [here](docs/en/pytorch.md).

#### Single GPU

@@ -145,6 +141,12 @@ deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
--seed 0
```

You need to install deepspeed first to use this feature.

```
pip install deepspeed
```

## Quantization

In fp16 mode, kv_cache int8 quantization can be enabled, and a single card can serve more users.
74 changes: 74 additions & 0 deletions docs/en/pytorch.md
@@ -0,0 +1,74 @@
# PyTorch

## Chat in command line

LMDeploy supports chatting with PyTorch models via the submodule `lmdeploy.pytorch.chat`.

This submodule allows users to chat with a language model through the command line, and optionally accelerate the model with backends such as deepspeed.

**Example 1**: Chat with default setting

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

**Example 2**: Disable sampling and chat history

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0 --max-history 0
```

**Example 3**: Accelerate with deepspeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use deepspeed, you need to install deepspeed first. If you hope to accelerate InternLM, you need a customized version of deepspeed: <https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0>

**Example 4**: Run the model with tensor parallelism on 2 GPUs

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

This module also allows the following control commands to change generation behavior during chat.

- `exit`: terminate and exit chat
- `config set key=value`: change generation config `key` to `value`, e.g. `config temperature=0` disables sampling for the following chats
- `clear`: clear chat history

### Simple diagram of components

```mermaid
graph LR;
subgraph model specific adapter
p((user_input))-->tokenize-->id((input_ids))-->decorate
tmpl_ids((template_ids))-->decorate;
end
subgraph generate
model[CausalLM_model.generate]-->gen_result(("gen_result"))
gen_result-->hid
gen_result-->attn((attention))
end
subgraph streamer
model-->s[streamer]--value-->decode_single--token-->output
end
subgraph session_manager
prepend_history-->fullid((complete_ids));
trim-->prepend_history
end
decorate-->prepend_history
hid((history_ids))-->trim;
attn-->trim;
fullid-->model
tokenizer((tokenizer))-->decode_single
tokenizer-->tokenize
p-->genconfig(GenConfig)-->model
```
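The diagram above can also be read as a plain HuggingFace loop. The sketch below is only an approximation of that reading and is not code from this PR: the model path, generation settings, and the decision to decode the whole completion at once (instead of streaming) are illustrative assumptions.

```python
# Rough approximation of the diagram: adapter -> session_manager -> generate -> decode.
# Not code from this PR; the model path and generation settings are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_path = '/path/to/llama/in/hf/format'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

history_ids = torch.empty((1, 0), dtype=torch.long)  # session_manager state
gen_config = GenerationConfig(max_new_tokens=128, do_sample=False)

while True:
    prompt = input('>>> ')
    if prompt == 'exit':
        break
    if prompt == 'clear':
        history_ids = torch.empty((1, 0), dtype=torch.long)
        continue
    # adapter: tokenize (and, for chat models, decorate with the template)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    # session_manager: prepend history to form the complete ids
    full_ids = torch.cat([history_ids, input_ids], dim=1)
    # generate: run the causal LM (streaming omitted; decode everything at once)
    output_ids = model.generate(full_ids, generation_config=gen_config)
    print(tokenizer.decode(output_ids[0, full_ids.shape[1]:],
                           skip_special_tokens=True))
    history_ids = output_ids
```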
39 changes: 39 additions & 0 deletions lmdeploy/pytorch/adapters/__init__.py
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.

import logging

import torch.nn as nn

from .base import BasicAdapter, BasicAdapterFast
from .internlm import InternLMAdapter
from .llama2 import Llama2Adapter

logger = logging.getLogger(__name__)


def _get_default_adapter(tokenizer):
    if tokenizer.is_fast:
        return BasicAdapterFast
    else:
        return BasicAdapter


def init_adapter(model: nn.Module, tokenizer, adapter=None):
    if adapter is None:
        for v in model.modules():
            if 'InternLMModel' in v.__class__.__name__:
                Adapter = InternLMAdapter
                break
            elif 'LlamaModel' in v.__class__.__name__:
                Adapter = Llama2Adapter
                break
        else:
            Adapter = _get_default_adapter(tokenizer)
    elif adapter == 'llama1':
        Adapter = _get_default_adapter(tokenizer)
    else:
        raise ValueError(f'Adapter {adapter} is not allowed.')

    logger.info(f'Using adapter {Adapter.__name__}')

    return Adapter(tokenizer)
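A possible way to call `init_adapter`, shown for illustration only; in this PR the chat module is what actually drives it, and the model path below is a placeholder.

```python
# Hypothetical usage of init_adapter; `model_path` is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.pytorch.adapters import init_adapter

model_path = '/path/to/hf/model'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# Detects InternLM / Llama2 by inspecting the model's modules,
# otherwise falls back to BasicAdapter / BasicAdapterFast based on the tokenizer.
adapter = init_adapter(model, tokenizer)
input_ids = adapter.encode_and_decorate('Hello!')
```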
127 changes: 127 additions & 0 deletions lmdeploy/pytorch/adapters/base.py
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Basic adapter suitable for general HuggingFace models."""

import logging
import re

from transformers import (PreTrainedTokenizer, PreTrainedTokenizerBase,
                          PreTrainedTokenizerFast)

logger = logging.getLogger(__name__)


class BaseAdapter:
    """Base class for all adapters.

    Note:
        Adapters coordinate with the session manager to prepare input_ids.
        The full sequence fed to the model is as follows:

        ```
        adapter.start_ids
        adapter.encode_and_decorate(user_input_1)
        output_1_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_2)
        output_2_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_3)
        ```

    Thus the adapter is responsible for providing the model-specific
    ``start_ids``, ``sep_ids``, and the method to encode a single prompt.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Model specific method to encode and decorate prompt."""
        raise NotImplementedError

    def decode(self, value):
        """Model specific method to decode a single value to string."""
        raise NotImplementedError

    @property
    def stopping_criteria(self):
        """Model specific stopping criteria for generation."""
        return None

    @property
    def start_ids(self):
        """Model specific start ids."""
        return [self.tokenizer.bos_token_id]

    @property
    def sep_ids(self):
        """Model specific separation ids."""
        return [self.tokenizer.bos_token_id]


class BasicAdapter(BaseAdapter):
    """Basic adapter for slow tokenizers."""

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Fallback decoding when the tokenizer is not fast."""

        self.tokenizer: PreTrainedTokenizer

        tok = self.tokenizer.decode(value)
        return tok + ' '


class BasicAdapterFast(BaseAdapter):
    """Basic adapter for fast tokenizers."""

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Decode with fast tokenizers."""

        self.tokenizer: PreTrainedTokenizerFast

        tok = self.tokenizer._convert_id_to_token(value)
        if tok.startswith('▁'):  # sentencepiece
            space = ' '
            tok = tok[1:]
        else:
            space = ''
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '\r':
            tok = '\n'

        tok = space + tok

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok
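The byte-fallback and sentencepiece handling in `BasicAdapterFast.decode` can be exercised on its own. The following standalone sketch mirrors the rules above; it is not part of the PR, and `decode_piece` is a hypothetical helper used only for illustration.

```python
# Standalone illustration of the decode rules in BasicAdapterFast (not part of the PR).
import re

hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')


def decode_piece(tok: str) -> str:
    space = ''
    if tok.startswith('▁'):  # sentencepiece marks a leading space with '▁'
        space, tok = ' ', tok[1:]
    if (m := hex_regex.match(tok)):
        tok = chr(int(m.group(1), 16))  # byte-fallback tokens such as '<0x0A>'
    if tok in ('</s>', '\r'):
        tok = '\n'
    return space + tok


assert decode_piece('▁Hello') == ' Hello'
assert decode_piece('<0x0A>') == '\n'
assert decode_piece('</s>') == '\n'
```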
81 changes: 81 additions & 0 deletions lmdeploy/pytorch/adapters/internlm.py
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import re

import torch
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
                          StoppingCriteriaList)

from .base import BaseAdapter

logger = logging.getLogger(__name__)


class InternLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for HF version of InternLM."""

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        return input_ids[0, -1] in [2, 103028]


class InternLMAdapter(BaseAdapter):
    r"""Adapter for InternLM.

    InternLM uses the following template, in which \n should be token id 13.

    <bos> (no actual newline here, just for better readability)
    <|User|>:{prompt}<eoh>\n
    <|Bot|>:{model_output}<eoa>\n
    <|User|>:{prompt}<eoh>\n
    <|Bot|>:{model_output}<eoa>\n
    ...
    <eos>
    """

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
    # ids of '<|User|>:'
    B_USER_ID = torch.tensor([[333, 352, 1621, 352, 27232]])
    # ids of '<eoh>\n<|Bot|>:'
    E_USER_ID = torch.tensor([[103027, 13, 333, 352, 23845, 352, 27232]])
    # ids of '<bos>'
    start_ids = [1]
    # ids of '\n'
    sep_ids = [13]

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt):
        r"""Encode prompt and decorate with template.

        Note:
            We leave <bos> and the chat history for the session manager to add,
            so we decorate input_ids to '<|User|>:{prompt}<eoh>\n<|Bot|>:'.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors='pt',
        )
        # This is f'<|User|>:{prompt}<eoh>\n<|Bot|>:'
        # but force \n to 13 instead of 364
        input_ids = torch.cat([self.B_USER_ID, input_ids, self.E_USER_ID],
                              dim=1)
        return input_ids

    def decode(self, value):
        """Decode generated tokens for InternLM."""

        tok = self.tokenizer.decode(value)
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '<eoa>' or tok == '\r':
            tok = '\n'

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok

    @property
    def stopping_criteria(self):
        return StoppingCriteriaList([InternLMStoppingCriteria()])
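To see how the adapter's pieces fit into a plain `generate` call, here is a hedged sketch: the model path is a placeholder, and the real chat loop in this PR goes through the session manager and streamer instead of calling `generate` directly.

```python
# Hedged sketch: wiring InternLMAdapter into a plain HF generate call.
# The model path is a placeholder; the PR's chat loop adds a session manager and streamer.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.pytorch.adapters.internlm import InternLMAdapter

model_path = 'internlm/internlm-chat-7b'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

adapter = InternLMAdapter(tokenizer)
prompt_ids = adapter.encode_and_decorate('Hello!')
# prepend <bos>; in the real flow the session manager also prepends history
full_ids = torch.cat([torch.tensor([adapter.start_ids]), prompt_ids], dim=1)

output = model.generate(full_ids,
                        max_new_tokens=64,
                        stopping_criteria=adapter.stopping_criteria)
print(tokenizer.decode(output[0, full_ids.shape[1]:]))
```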