diff --git a/paddlenlp/transformers/qwen/configuration.py b/paddlenlp/transformers/qwen/configuration.py
index d61a93fd6b7..cc0004b978d 100644
--- a/paddlenlp/transformers/qwen/configuration.py
+++ b/paddlenlp/transformers/qwen/configuration.py
@@ -43,6 +43,8 @@ def __init__(
         use_flash_attention=False,
         use_fused_rms_norm=False,
         use_fused_rope=False,
+        fuse_attention_ffn=False,
+        sequence_parallel=False,
         intermediate_size=22016,
         tensor_parallel_output=True,
         no_bias=True,
@@ -77,6 +79,8 @@ def __init__(
         self.use_flash_attention = use_flash_attention
         self.use_fused_rms_norm = use_fused_rms_norm
         self.use_fused_rope = use_fused_rope
+        self.fuse_attention_ffn = fuse_attention_ffn
+        self.sequence_parallel = sequence_parallel
         self.no_bias = no_bias
 
         self.long_sequence_strategy_type = long_sequence_strategy_type
diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py
index 056ae08cbe1..54897ccf5f3 100755
--- a/paddlenlp/transformers/qwen/modeling.py
+++ b/paddlenlp/transformers/qwen/modeling.py
@@ -26,6 +26,16 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import
 
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPast,
@@ -35,6 +45,7 @@
 from paddlenlp.utils.log import logger
 
 from ...utils.converter import StateDictNameMapping, init_name_mappings
+from .. import linear_utils
 from ..model_outputs import ModelOutput
 from .configuration import QWenConfig
 
@@ -329,37 +340,60 @@ class QWenMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
         ff_dim_in = config.intermediate_size // 2
+        self.fuse_attention_ffn = config.fuse_attention_ffn
+
+        if config.sequence_parallel:
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
+        else:
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear
+
         if config.tensor_parallel_degree > 1:
-            self.w1 = mpu.ColumnParallelLinear(
-                config.hidden_size,
-                ff_dim_in,
-                gather_output=False,
-                has_bias=False,
-            )
-            self.w2 = mpu.ColumnParallelLinear(
-                config.hidden_size,
-                ff_dim_in,
-                gather_output=False,
-                has_bias=False,
-            )
-            self.c_proj = mpu.RowParallelLinear(
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = ColumnParallelLinear(
+                    config.hidden_size,
+                    ff_dim_in * 2,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            else:
+                self.w1 = ColumnParallelLinear(
+                    config.hidden_size,
+                    ff_dim_in,
+                    gather_output=False,
+                    has_bias=False,
+                )
+                self.w2 = ColumnParallelLinear(
+                    config.hidden_size,
+                    ff_dim_in,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            self.c_proj = RowParallelLinear(
                 ff_dim_in,
                 config.hidden_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
         else:
-            self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
-            self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = nn.Linear(config.hidden_size, ff_dim_in * 2, bias_attr=not config.no_bias)
+            else:
+                self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
+                self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
             self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)
 
     def forward(self, hidden_states):
         # up
-        a1 = self.w1(hidden_states)
-        # gate
-        a2 = self.w2(hidden_states)
-        intermediate_parallel = a1 * F.silu(a2)
-        # down
+        # a1 = self.w1(hidden_states)
+        # # gate
+        # a2 = self.w2(hidden_states)
+        # intermediate_parallel = a1 * F.silu(a2)
+        if self.fuse_attention_ffn:
+            intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))
+        else:
+            intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
         output = self.c_proj(intermediate_parallel)
         return output
 
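
Note on the swiglu fallback and the fused path: the standalone sketch below is not part of the PR. It only checks the math implied by the fallback swiglu defined in the diff, namely that pushing a single gate/up projection through swiglu reproduces the original formulation w1(h) * F.silu(w2(h)). The toy sizes, the local gate_up layer, and the gate-first column layout of the fused weight are illustrative assumptions; how the real fused checkpoints for gate_up_fused_proj are laid out is determined by the weight conversion, which this diff does not show.

# Sketch under the assumptions above; verifies the fused vs. unfused SwiGLU math.
import paddle
import paddle.nn.functional as F


def swiglu(x, y=None):
    # Same fallback as in the diff: split the fused input into (gate, up) when y is omitted.
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y


hidden_size, ff_dim_in = 8, 16
h = paddle.randn([2, 4, hidden_size])

w1 = paddle.nn.Linear(hidden_size, ff_dim_in, bias_attr=False)  # up
w2 = paddle.nn.Linear(hidden_size, ff_dim_in, bias_attr=False)  # gate

# Unfused path, as in the else branch of QWenMLP.forward.
ref = swiglu(w2(h), w1(h))

# Fused path: one projection whose output chunks into (gate, up).
# Paddle Linear weights are [in_features, out_features], so concatenate
# along the output axis with the gate weights first (assumed layout).
gate_up = paddle.nn.Linear(hidden_size, ff_dim_in * 2, bias_attr=False)
gate_up.weight.set_value(paddle.concat([w2.weight, w1.weight], axis=-1))
fused = swiglu(gate_up(h))

assert paddle.allclose(ref, fused, atol=1e-6).item()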