add fuse_attention_ffn support for qwen (#8526)
deepllz committed Jun 6, 2024
1 parent 79e8b6e commit d06b327
Showing 2 changed files with 58 additions and 20 deletions.
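In short, the change lets QWenMLP replace its two feed-forward projections (w1 for the up branch, w2 for the gate branch) with a single gate_up_fused_proj whose output is split and recombined by SwiGLU. The standalone sketch below illustrates that pattern with plain paddle.nn.Linear layers; the projection names mirror the diff, but the FusedSwiGLUMLP class and the sizes are illustrative only and not part of the commit.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class FusedSwiGLUMLP(nn.Layer):
    """Illustrative fused feed-forward block in the style this commit adds.

    A single projection produces both the gate and the up branch in one GEMM;
    SwiGLU splits the result and computes silu(gate) * up before the
    down-projection c_proj.
    """

    def __init__(self, hidden_size, ff_dim_in):
        super().__init__()
        # Weight is equivalent to concatenating the gate (w2) and up (w1) weights.
        self.gate_up_fused_proj = nn.Linear(hidden_size, ff_dim_in * 2, bias_attr=False)
        self.c_proj = nn.Linear(ff_dim_in, hidden_size, bias_attr=False)

    def forward(self, hidden_states):
        gate_up = self.gate_up_fused_proj(hidden_states)
        gate, up = paddle.chunk(gate_up, chunks=2, axis=-1)
        return self.c_proj(F.silu(gate) * up)


# Quick shape check with illustrative sizes.
mlp = FusedSwiGLUMLP(hidden_size=128, ff_dim_in=256)
print(mlp(paddle.randn([2, 16, 128])).shape)  # [2, 16, 128]

Computing both projections in one matmul reduces kernel launches and keeps the fused weight contiguous; numerically it matches the unfused silu(w2(x)) * w1(x) path, which is exactly what the swiglu fallback in the modeling diff below computes.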
4 changes: 4 additions & 0 deletions paddlenlp/transformers/qwen/configuration.py
@@ -43,6 +43,8 @@ def __init__(
use_flash_attention=False,
use_fused_rms_norm=False,
use_fused_rope=False,
fuse_attention_ffn=False,
sequence_parallel=False,
intermediate_size=22016,
tensor_parallel_output=True,
no_bias=True,
@@ -77,6 +79,8 @@ def __init__(
self.use_flash_attention = use_flash_attention
self.use_fused_rms_norm = use_fused_rms_norm
self.use_fused_rope = use_fused_rope
self.fuse_attention_ffn = fuse_attention_ffn
self.sequence_parallel = sequence_parallel
self.no_bias = no_bias

self.long_sequence_strategy_type = long_sequence_strategy_type
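Both new flags default to False, so existing QWen configs keep their previous behaviour. A minimal check, assuming a PaddleNLP build that includes this commit:

from paddlenlp.transformers.qwen.configuration import QWenConfig

# Defaults preserve the old, unfused behaviour.
config = QWenConfig()
print(config.fuse_attention_ffn, config.sequence_parallel)  # False False

# Opt in to the fused feed-forward path.
config = QWenConfig(fuse_attention_ffn=True)
print(config.fuse_attention_ffn)  # True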
74 changes: 54 additions & 20 deletions paddlenlp/transformers/qwen/modeling.py
@@ -26,6 +26,16 @@
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:

def swiglu(x, y=None):
if y is None:
x, y = paddle.chunk(x, chunks=2, axis=-1)
return F.silu(x) * y


from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPast,
@@ -35,6 +45,7 @@
from paddlenlp.utils.log import logger

from ...utils.converter import StateDictNameMapping, init_name_mappings
from .. import linear_utils
from ..model_outputs import ModelOutput
from .configuration import QWenConfig

@@ -329,37 +340,60 @@ class QWenMLP(nn.Layer):
def __init__(self, config):
super().__init__()
ff_dim_in = config.intermediate_size // 2
self.fuse_attention_ffn = config.fuse_attention_ffn

if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
self.w1 = mpu.ColumnParallelLinear(
config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.w2 = mpu.ColumnParallelLinear(
config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.c_proj = mpu.RowParallelLinear(
if self.fuse_attention_ffn:
self.gate_up_fused_proj = ColumnParallelLinear(
config.hidden_size,
ff_dim_in * 2,
gather_output=False,
has_bias=False,
)
else:
self.w1 = ColumnParallelLinear(
config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.w2 = ColumnParallelLinear(
config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.c_proj = RowParallelLinear(
ff_dim_in,
config.hidden_size,
input_is_parallel=True,
has_bias=False,
)
else:
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
if self.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(config.hidden_size, ff_dim_in * 2, bias_attr=not config.no_bias)
else:
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)

def forward(self, hidden_states):
# up
a1 = self.w1(hidden_states)
# gate
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
# down
# a1 = self.w1(hidden_states)
# # gate
# a2 = self.w2(hidden_states)
# intermediate_parallel = a1 * F.silu(a2)
if self.fuse_attention_ffn:
intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))
else:
intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
output = self.c_proj(intermediate_parallel)
return output

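To see which projections the flag selects, a hedged usage sketch (again assuming a PaddleNLP build that includes this commit; hidden_size and intermediate_size below are deliberately small and purely illustrative):

import paddle

from paddlenlp.transformers.qwen.configuration import QWenConfig
from paddlenlp.transformers.qwen.modeling import QWenMLP

config = QWenConfig(hidden_size=128, intermediate_size=512, fuse_attention_ffn=True)

mlp = QWenMLP(config)
# With the flag set, the fused projection replaces the separate w1/w2 layers.
print(hasattr(mlp, "gate_up_fused_proj"), hasattr(mlp, "w1"))  # True False

out = mlp(paddle.randn([2, 16, config.hidden_size]))
print(out.shape)  # [2, 16, 128]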
