Skip to content

Commit

Permalink
[LLM] support Qwen2 (#8338)
Browse files Browse the repository at this point in the history
* add Qwen2Moe

* update default config

* update QWen2Moe modeling

* update modeling

* update ckpt name

* support same prefix model name for auto modeling

* update qwen2moe testing

* update qwen2moe modeling and config

* update qwen2moe import

* fix mlp hidden_size

* update qkv bias convert

* update modeling init_weight

* update _get_name_mappings

* update _get_name_mappings and _init_weight

* add tokenizer

* update modeling

* update modeling

* update  tokenizer

* update modeling and tokenizer

* fix index_add_ error

* fix

* update comments

* update lora weights

* add todo

* update Copyright

* update Moe to MoE

* update comment

* update Copyright

* update readme and json

* update __init__.py

* add qwen-1.5

* update QWen to Qwen

* update Qwen2MoE to Qwen2Moe

* update readme

* update qwen2moe sft and lora json

* update qwen2moe base name

* update qwen2

* update

* update readme

* update readme

* update readme
  • Loading branch information
DrownFish19 committed Jun 11, 2024
1 parent 909be01 commit 4609d07
Show file tree
Hide file tree
Showing 17 changed files with 4,543 additions and 14 deletions.
14 changes: 12 additions & 2 deletions llm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,21 @@ def get_convert_example(model):

if base_model_prefix == "chatglm":
return convert_example_chatglm
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
elif base_model_prefix in [
"chatglm_v2",
"llama",
"bloom",
"opt",
"qwen",
"mixtral",
"gemma",
"qwen2",
"qwen2_moe",
]:
return convert_example_common
else:
raise ValueError(
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe",
)


Expand Down
60 changes: 51 additions & 9 deletions llm/qwen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,57 @@
[通义千问(Qwen)](https://arxiv.org/abs/2205.01068) 是阿里云研发的通义千问大模型系列的模型, 有 70 亿和 140 亿两个规模。Qwen是基于Transformer的大语言模型, 在超大规模的预训练数据上进行训练得到。预训练数据类型多样,覆盖广泛,包括大量网络文本、专业书籍、代码等。

**支持模型权重:**
| Model |
|-------------------|
| qwen/qwen-7b |
| qwen/qwen-7b-chat |
| qwen/qwen-14b |
| qwen/qwen-14b-chat|
| qwen/qwen-72b |
| qwen/qwen-72b-chat|
| qwen/qwen1.5-moe-a2.7b|
| Model |
|--------------------|
| qwen/qwen-7b |
| qwen/qwen-7b-chat |
| qwen/qwen-14b |
| qwen/qwen-14b-chat |
| qwen/qwen-72b |
| qwen/qwen-72b-chat |



[通义千问(Qwen1.5)](https://qwenlm.github.io/blog/qwen1.5/) 是阿里云研发的通义千问系列模型升级版。Qwen1.5包括0.5B、1.8B、4B、7B、14B、32B、72B、110B和MoE共计9个不同规模的Base和Chat模型。

**支持模型权重:**
| Model (qwen-1.5) |
|-----------------------------|
| Qwen/Qwen1.5-0.5B |
| Qwen/Qwen1.5-0.5B-Chat |
| Qwen/Qwen1.5-1.8B |
| Qwen/Qwen1.5-1.8B-Chat |
| Qwen/Qwen1.5-4B |
| Qwen/Qwen1.5-4B-Chat |
| Qwen/Qwen1.5-7B |
| Qwen/Qwen1.5-7B-Chat |
| Qwen/Qwen1.5-14B |
| Qwen/Qwen1.5-14B-Chat |
| Qwen/Qwen1.5-32B |
| Qwen/Qwen1.5-32B-Chat |
| Qwen/Qwen1.5-72B |
| Qwen/Qwen1.5-72B-Chat |
| Qwen/Qwen1.5-110B |
| Qwen/Qwen1.5-110B-Chat |
| Qwen/Qwen1.5-MoE-A2.7B |
| Qwen/Qwen1.5-MoE-A2.7B-Chat |


[通义千问(Qwen2)](https://qwenlm.github.io/blog/qwen2/) 是阿里云研发的通义千问系列模型升级版。Qwen2包括0.5B、1.5B、7B、72B和MoE共计5个不同规模的Base和Chat模型。
**支持模型权重:**
| Model (qwen2) |
|------------------------------|
| Qwen/Qwen2-0.5B |
| Qwen/Qwen2-0.5B-Instruct |
| Qwen/Qwen2-1.5B |
| Qwen/Qwen2-1.5B-Instruct |
| Qwen/Qwen2-7B |
| Qwen/Qwen2-7B-Instruct |
| Qwen/Qwen2-72B |
| Qwen/Qwen2-72B-Instruct |
| Qwen/Qwen2-57B-A14B |
| Qwen/Qwen2-57B-A14B-Instruct |


## 2. 模型精调
请参考[LLM全流程工具介绍](../README.md)
32 changes: 32 additions & 0 deletions llm/qwen/lora_argument_qwen2moe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"model_name_or_path": "Qwen/Qwen1.5-MoE-A2.7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2moe_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 32768,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
30 changes: 30 additions & 0 deletions llm/qwen/sft_argument_qwen2moe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"model_name_or_path": "Qwen/Qwen1.5-MoE-A2.7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2moe_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 32768,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"sharding": "stage2",
"pipeline_parallel_degree": 1
}
13 changes: 12 additions & 1 deletion llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@


def compute_metrics(eval_preds):

flattened_preds = np.array(eval_preds.predictions).flatten()
flattened_labels = np.array(eval_preds.label_ids).flatten()
filtered_preds = flattened_preds[flattened_labels != -100]
Expand Down Expand Up @@ -157,10 +156,22 @@ def get_lora_target_modules(model):
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
# ".*gate.*", # TODO(DrownFish19): Does the gate weight require training?
".*w1.*",
".*w2.*",
".*w3.*",
]
elif model.base_model_prefix == "qwen2_moe":
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
# ".*gate.*", # TODO(DrownFish19): Does the gate weight require training?
".*gate_proj.*",
".*up_proj.*",
".*down_proj.*",
]
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
return target_modules
Expand Down
2 changes: 2 additions & 0 deletions paddlenlp/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@
from .deberta_v2.modeling import *
from .deberta_v2.tokenizer import *
from .deberta_v2.configuration import *
from .qwen2 import *
from .qwen2_moe import *

# For faster tokenizer
from ..utils.import_utils import is_fast_tokenizer_available
Expand Down
11 changes: 9 additions & 2 deletions paddlenlp/transformers/auto/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@
("Bloom", "bloom"),
("QWen", "qwen"),
("Mixtral", "mixtral"),
("Qwen2", "qwen2"),
("Qwen2Moe", "qwen2_moe"),
("Gemma", "gemma"),
]
)
Expand Down Expand Up @@ -215,15 +217,20 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file
else:
init_class = config.pop("init_class", None)
init_class = init_class[:-5] if init_class is not None and init_class.endswith("Model") else init_class

# Sort the MAPPING_NAMES to reorder the model class names with longest-first rule
# thus the names with same prefix can be correctly inferred
# such as QWen and QWen2MOE, QWen2MOE is the longest prefix of QWen2MOEModel
model_name = None
SORTED_MAPPING_NAMES = dict(sorted(MAPPING_NAMES.items(), key=lambda x: len(x[0]), reverse=True))
if init_class:
for model_flag, name in MAPPING_NAMES.items():
for model_flag, name in SORTED_MAPPING_NAMES.items():
if model_flag in init_class:
model_name = model_flag + "Model"
break
else:
# From pretrained_model_name_or_path
for model_flag, name in MAPPING_NAMES.items():
for model_flag, name in SORTED_MAPPING_NAMES.items():
if name in pretrained_model_name_or_path.lower():
model_name = model_flag + "Model"
break
Expand Down
18 changes: 18 additions & 0 deletions paddlenlp/transformers/qwen2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration import *
from .modeling import *
from .tokenizer import *
Loading

0 comments on commit 4609d07

Please sign in to comment.