23 changes: 23 additions & 0 deletions lmdeploy/model.py
@@ -87,6 +87,10 @@ def stop_words(self):
         """Return the stop-words' token ids."""
         return None

+    def update_input_ids(self, input_ids: List[int]):
+        """Further modify input ids of the prompt."""
+        return input_ids
+

 @MODELS.register_module(name='vicuna')
 class Vicuna(BaseModel):
@@ -481,6 +485,25 @@ def stop_words(self):
         return [151645]  # <|im_end|>


+@MODELS.register_module(name='chatglm2-6b')
+class ChatGLM2(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.count = 0
+
+    def get_prompt(self, prompt, sequence_start=True):
+        # needs more checking, see
+        # https://github.com/THUDM/ChatGLM2-6B/issues/48
+        # [64790, 64792] to be prepended
+        self.count += 1
+        return f'[Round {self.count}]\n\n问:{prompt}\n\n答:'
+
+    def update_input_ids(self, input_ids: List):
+        input_ids = [64790, 64792] + input_ids
+        return input_ids
+
+
 def main(model_name: str = 'test'):
     assert model_name in MODELS.module_dict.keys(), \
         f"'{model_name}' is not supported. " \
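
As a quick sanity check, here is a minimal usage sketch of the new template (assuming the registry exposes the usual mmengine-style get(), consistent with the MODELS.module_dict lookup in main() above; the ids [101, 102] stand in for real tokenizer output):

    from lmdeploy.model import MODELS

    # Instantiate the newly registered template (mmengine-style lookup).
    model = MODELS.get('chatglm2-6b')()

    prompt = model.get_prompt('hi')  # internal round counter becomes 1
    assert prompt == '[Round 1]\n\n问:hi\n\n答:'

    # The two special ids from the diff are prepended to the encoded prompt.
    ids = model.update_input_ids([101, 102])
    assert ids == [64790, 64792, 101, 102]
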
10 changes: 7 additions & 3 deletions lmdeploy/pytorch_poc/chat.py
@@ -33,6 +33,9 @@ def main(
     model_path,
     model_name: str,  # cannot get model_name from the hf model
     session_id: int = 1,
+    top_k=40,
+    top_p=0.8,
+    temperature=0.8,
     repetition_penalty: float = 1.0,
     tp: int = 1,
     stream_output=True):
@@ -73,12 +76,13 @@ def main(
             continue
         prompt = model.get_prompt(prompt, nth_round == 1)
         input_ids = tokenizer.encode(prompt)
+        input_ids = model.update_input_ids(input_ids)
         print(f'{prompt} ', end='', flush=True)
         response_size = 0
         sampling_param = SamplingParam(
-            top_k=40,
-            top_p=0.8,
-            temperature=0.8,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=repetition_penalty,
             ignore_eos=False,
             random_seed=seed,
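
With the sampling knobs now surfaced on main(), they can be tuned per run instead of being hard-coded. A hedged example of calling the entrypoint directly (the checkpoint path is a placeholder; the parameter names match the new signature above):

    from lmdeploy.pytorch_poc.chat import main

    # Placeholder path to a local HF checkpoint; adjust as needed.
    main('./chatglm2-6b',
         model_name='chatglm2-6b',
         top_k=50,
         top_p=0.95,
         temperature=0.7,
         repetition_penalty=1.02)
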
27 changes: 19 additions & 8 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -447,14 +447,25 @@ def __init__(
         cache_config = CacheConfig(block_size=64,
                                    num_cpu_blocks=0,
                                    num_gpu_blocks=0)
-        model_config = ModelConfig(
-            hf_config.hidden_size,
-            hf_config.num_hidden_layers,
-            hf_config.num_attention_heads,
-            bos_token_id=hf_config.bos_token_id,
-            eos_token_id=hf_config.eos_token_id,
-            dtype=torch_dtype,
-        )
+        if 'chatglm' in model_path:
+            model_config = ModelConfig(
+                hf_config.hidden_size // hf_config.num_attention_heads *
+                hf_config.multi_query_group_num,
+                hf_config.num_layers,
+                hf_config.multi_query_group_num,
+                bos_token_id=hf_config.bos_token_id,
+                eos_token_id=hf_config.eos_token_id,
+                dtype=torch_dtype,
+            )
+        else:
+            model_config = ModelConfig(
+                hf_config.hidden_size,
+                hf_config.num_hidden_layers,
+                hf_config.num_attention_heads,
+                bos_token_id=hf_config.bos_token_id,
+                eos_token_id=hf_config.eos_token_id,
+                dtype=torch_dtype,
+            )

         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
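
The 'chatglm' branch feeds ModelConfig the key/value width rather than the full hidden size, since ChatGLM2 uses multi-query attention: only multi_query_group_num KV head groups are cached per layer. A worked example with the stock chatglm2-6b config values (assumed here; they are not part of this diff):

    # Assumed chatglm2-6b config values (not taken from this diff).
    hidden_size = 4096
    num_attention_heads = 32
    multi_query_group_num = 2  # KV head groups under multi-query attention

    head_dim = hidden_size // num_attention_heads       # 128
    kv_width = head_dim * multi_query_group_num         # 256

    # ModelConfig presumably receives 256 (the per-layer KV width) and 2
    # (the KV head count), so the cache is sized for 2 heads, not 32;
    # num_layers likewise comes from hf_config.num_layers (28 for
    # chatglm2-6b) rather than num_hidden_layers.
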