Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/en/user_guides/chat.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,11 @@
```shell
xtuner chat baichuan-inc/Baichuan-7B --adapter xtuner/Baichuan-7B-qlora-alpaca-enzh --prompt-template alpaca
```

## Chat with [CodeLlama](https://github.com/facebookresearch/codellama)

- CodeLlama-7B, Instruct

```shell
xtuner chat codellama/CodeLlama-7b-Instruct-hf --prompt-template code_llama_chat
```
8 changes: 8 additions & 0 deletions docs/zh_cn/user_guides/chat.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,11 @@
```shell
xtuner chat baichuan-inc/Baichuan-7B --adapter xtuner/Baichuan-7B-qlora-alpaca-enzh --prompt-template alpaca
```

## 与 [CodeLlama](https://github.com/facebookresearch/codellama) 对话

- CodeLlama-7B, Instruct

```shell
xtuner chat codellama/CodeLlama-7b-Instruct-hf --prompt-template code_llama_chat
```
15 changes: 14 additions & 1 deletion xtuner/engine/hooks/evaluate_chat_hook.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from transformers import StoppingCriteriaList
from transformers import GenerationConfig, StoppingCriteriaList

from xtuner.registry import BUILDER
from xtuner.utils import StopWordStoppingCriteria
Expand All @@ -26,6 +26,18 @@ def __init__(self,
self.max_new_tokens = max_new_tokens
self.tokenizer = BUILDER.build(tokenizer)
self.stop_criteria = StoppingCriteriaList()
# default generation config
self.gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.1,
top_p=0.75,
top_k=40,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None else
self.tokenizer.eos_token_id,
)
if stop_word is not None:
self.stop_criteria.append(
StopWordStoppingCriteria(self.tokenizer, stop_word))
Expand Down Expand Up @@ -55,6 +67,7 @@ def _generate_samples(self, runner, max_new_tokens=None):
generation_output = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
stopping_criteria=self.stop_criteria)
runner.logger.info(
f'Sample output:\n'
Expand Down
3 changes: 3 additions & 0 deletions xtuner/tools/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def main():
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
)

n_turn = 0
Expand Down
4 changes: 4 additions & 0 deletions xtuner/utils/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
'ensure that your responses are socially unbiased and positive in '
'nature. \n<</SYS>>\n\n{input} [/INST]'),
INSTRUCTION='[INST] {input} [/INST]'),
code_llama_chat=dict(
INSTRUCTION_START='[INST] {input} [/INST]',
INSTRUCTION='[INST] {input} [/INST]',
),
internlm_chat=dict(
INSTRUCTION_START='<|User|>:{input}<eoh>\n<|Bot|>:',
INSTRUCTION='<|User|>:{input}<eoh>\n<|Bot|>:'),
Expand Down