Skip to content

Commit

Permalink
add Qwen2Moe
Browse files Browse the repository at this point in the history
  • Loading branch information
DrownFish19 committed Apr 16, 2024
1 parent 110983d commit 113b883
Show file tree
Hide file tree
Showing 9 changed files with 2,204 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def get_convert_example(model):

if base_model_prefix == "chatglm":
return convert_example_chatglm
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]:
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "qwen2moe"]:
return convert_example_common
else:
raise ValueError(
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral"
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, qwen2moe"
)


Expand Down
32 changes: 32 additions & 0 deletions llm/qwen2moe/lora_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"model_name_or_path": "Qwen/Qwen1.5-MoE-A2.7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2moe_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
30 changes: 30 additions & 0 deletions llm/qwen2moe/sft_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"model_name_or_path": "Qwen/Qwen1.5-MoE-A2.7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2moe_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"sharding": "stage2",
"pipeline_parallel_degree": 1
}
1 change: 1 addition & 0 deletions paddlenlp/transformers/auto/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
("Bloom", "bloom"),
("QWen", "qwen"),
("Mixtral", "mixtral"),
("QWen2Moe", "qwen2moe"),
]
)

Expand Down
16 changes: 16 additions & 0 deletions paddlenlp/transformers/qwen2moe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration import Qwen2MoeConfig
from .modeling import Qwen2MoeForCausalLM
203 changes: 203 additions & 0 deletions paddlenlp/transformers/qwen2moe/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Qwen2MoE model configuration"""

from paddlenlp.transformers.configuration_utils import PretrainedConfig

__all__ = [
"Qwen2MoeConfig",
]


class Qwen2MoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a
Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B").
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 60):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `False`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
```python
>>> from paddlenlp.transformers import Qwen2MoeModel, Qwen2MoeConfig
>>> # Initializing a Qwen2MoE style configuration
>>> configuration = Qwen2MoeConfig()
>>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration
>>> model = Qwen2MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "qwen2_moe"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16,
hidden_act="silu",
max_position_embeddings=32768,
seq_length=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
use_recompute=False,
recompute_granularity="full",
no_recompute_layers=None,
use_flash_attention=False,
attention_dropout=0.0,
use_fused_rope=False,
rope_theta=10000.0,
tensor_parallel_output=True,
sequence_parallel=False,
fuse_sequence_parallel_allreduce=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
decoder_sparse_step=1,
moe_intermediate_size=1408,
shared_expert_intermediate_size=5632,
num_experts_per_tok=4,
num_experts=60,
norm_topk_prob=False,
output_router_logits=False,
router_aux_loss_coef=0.001,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.seq_length = seq_length
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act

self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps

self.use_cache = use_cache
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.no_recompute_layers = no_recompute_layers
self.use_flash_attention = use_flash_attention
self.tensor_parallel_output = tensor_parallel_output
self.sequence_parallel = sequence_parallel
self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce

self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.use_fused_rope = use_fused_rope
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
tensor_parallel_output=tensor_parallel_output,
**kwargs,
)
Loading

0 comments on commit 113b883

Please sign in to comment.