add long sequence strategies #8076

Merged (30 commits) on Mar 26, 2024
9 changes: 9 additions & 0 deletions paddlenlp/transformers/bloom/configuration.py
@@ -125,6 +125,10 @@ def __init__(
use_recompute=False,
use_pure_fp16=False,
use_flash_attention=False,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs,
):

@@ -150,3 +154,8 @@ def __init__(
self.use_recompute = use_recompute
self.use_pure_fp16 = use_pure_fp16
self.use_flash_attention = use_flash_attention

self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
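
Note: every configuration touched by this PR gains the same four fields. A minimal usage sketch follows; the strategy type/name strings are an assumption based on the module and class names added under paddlenlp/transformers/long_sequence_strategies later in this PR, since the factory's accepted values are not visible in this diff.

from paddlenlp.transformers.bloom.configuration import BloomConfig

config = BloomConfig(
    use_long_sequence_strategies=True,
    long_sequence_strategy_type="attention_strategies",     # assumed: selects the strategy module
    long_sequence_strategy_name="AttentionWithLinearBias",  # assumed: selects the strategy class
    long_sequence_init_args={},  # forwarded to the strategy constructor as **kwargs
)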
24 changes: 21 additions & 3 deletions paddlenlp/transformers/bloom/modeling.py
100644 → 100755
@@ -26,6 +26,7 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -944,10 +945,27 @@
attention_mask = paddle.cast(attention_mask, "bool")
if len(attention_mask.shape) > 2:
_attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
self.config.long_sequence_strategy_type,
self.config.long_sequence_strategy_name,

**self.config.long_sequence_init_args,
)
alibi = alibi_layer(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
alibi = paddle.squeeze(alibi)
else:
alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)

else:
alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
self.config.long_sequence_strategy_type,
self.config.long_sequence_strategy_name,

**self.config.long_sequence_init_args,
)
alibi = alibi_layer(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
alibi = paddle.squeeze(alibi)
else:
alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

if self.config.tensor_parallel_degree > 1:
block_size = self.config.n_head // self.config.tensor_parallel_degree
alibi = alibi[
9 changes: 8 additions & 1 deletion paddlenlp/transformers/chatglm/configuration.py
@@ -21,7 +21,6 @@
"CHATGLM_PRETRAINED_RESOURCE_FILES_MAP",
]


CHATGLM_PRETRAINED_RESOURCE_FILES_MAP = {
"model_state": {
"THUDM/chatglm-6b": "https://paddlenlp.bj.bcebos.com/models/community/THUDM/chatglm-6b/model_state.pdparams",
@@ -104,6 +103,10 @@ def __init__(
activation="gelu",
num_image_tokens=0,
use_flash_attention=False,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -129,3 +132,7 @@ def __init__(
self.activation = activation
self.num_image_tokens = num_image_tokens
self.use_flash_attention = use_flash_attention
self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
37 changes: 28 additions & 9 deletions paddlenlp/transformers/chatglm/modeling.py
100644 → 100755
@@ -27,6 +27,8 @@
from paddle.distributed.fleet.utils import recompute
from paddle.utils import map_structure

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

from ...utils.env import CONFIG_NAME
from ...utils.log import logger
from .. import PretrainedModel, register_base_model
@@ -442,12 +444,21 @@
# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False
self.num_attention_heads = config.num_attention_heads
self.rotary_embeddings = RotaryEmbeddings(
self.hidden_size // (self.num_attention_heads * 2)
if self.position_encoding_2d
else self.hidden_size // self.num_attention_heads,
base=10000.0,
)

if config.use_long_sequence_strategies:
self.rotary_embeddings = LongSequenceStrategies.build_long_sequence_strategy(

config.long_sequence_strategy_type,
config.long_sequence_strategy_name,
**config.long_sequence_init_args,
)

else:
self.rotary_embeddings = RotaryEmbeddings(
self.hidden_size // (self.num_attention_heads * 2)
if self.position_encoding_2d
else self.hidden_size // self.num_attention_heads,
base=10000.0,
)
# self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)

if self.config.tensor_parallel_degree > 1:
@@ -530,7 +541,6 @@
cache: Optional[Tensor] = None,
use_cache: bool = False,
):

if input_ids is not None and inputs_embeds is not None:
input_ids = None
logger.warning("Specify both input_ids and inputs_embeds at the same time, will use inputs_embeds")
@@ -544,8 +554,17 @@
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
inputs_embeds = inputs_embeds.transpose([1, 0, 2])

rotary_embeds = self.rotary_embeddings(position_ids)
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_embeddings(seq_len=seq_length)
block_position_ids = position_ids[:, 1, :].transpose([1, 0])
position_ids = position_ids[:, 0, :].transpose([1, 0])
block_rotary_embeds = paddle.stack(

[cos[block_position_ids].unsqueeze(2), sin[block_position_ids].unsqueeze(2)]
)
position_rotary_embeds = paddle.stack([cos[position_ids].unsqueeze(2), sin[position_ids].unsqueeze(2)])
rotary_embeds = paddle.stack([position_rotary_embeds, block_rotary_embeds], axis=0)

else:
rotary_embeds = self.rotary_embeddings(position_ids)

if cache is None:
if self.config.pre_seq_len is not None:
8 changes: 8 additions & 0 deletions paddlenlp/transformers/chatglm_v2/configuration.py
@@ -56,6 +56,10 @@ def __init__(
eos_token_id=2,
pad_token_id=0,
use_flash_attention=False,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -81,3 +85,7 @@ def __init__(
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.use_flash_attention = use_flash_attention
self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
22 changes: 19 additions & 3 deletions paddlenlp/transformers/chatglm_v2/modeling.py
@@ -21,6 +21,8 @@
from paddle.distributed.fleet.utils import recompute
from paddle.utils import map_structure

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

from ...utils.converter import StateDictNameMapping, init_name_mappings
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
@@ -650,7 +652,15 @@
rotary_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2)
if config.use_long_sequence_strategies:
self.config = config
self.rotary_pos_emb = LongSequenceStrategies.build_long_sequence_strategy(

config.long_sequence_strategy_type,
config.long_sequence_strategy_name,
**config.long_sequence_init_args,
)
else:
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2)
self.encoder = GLMTransformer(config)
self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False)

@@ -677,7 +687,6 @@
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
@@ -686,7 +695,14 @@
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.max_sequence_length)
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_pos_emb(seq_len=self.max_sequence_length)
cos, cos = paddle.chunk(cos, 2, axis=-1)
sin, sin = paddle.chunk(sin, 2, axis=-1)
rotary_pos_emb = paddle.stack([cos, sin], axis=-1)

else:
rotary_pos_emb = self.rotary_pos_emb(self.max_sequence_length)

if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
9 changes: 9 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
@@ -168,6 +168,10 @@ def __init__(
alibi=False,
rope_scaling_factor=1.0,
rope_scaling_type=None,
long_sequence_strategy_type=None,
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -208,6 +212,11 @@ def __init__(
self.rope_scaling_factor = rope_scaling_factor
self.rope_scaling_type = rope_scaling_type

self.long_sequence_strategy_type = long_sequence_strategy_type
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
33 changes: 30 additions & 3 deletions paddlenlp/transformers/llama/modeling.py
100644 → 100755
@@ -50,6 +50,7 @@
StateDictNameMapping,
init_name_mappings,
)
from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -762,7 +763,14 @@
)

if config.rope:
self._init_rope()
if config.use_long_sequence_strategies:
self.rotary_emb = LongSequenceStrategies.build_long_sequence_strategy(

config.long_sequence_strategy_type,
config.long_sequence_strategy_name,
**config.long_sequence_init_args,
)
else:
self._init_rope()

self.reshard_layer = None
if config.sep_parallel_degree > 1:
@@ -971,7 +979,17 @@
use_neox_rotary_style=False,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
cos, sin = (

cos.cast(value_states.dtype) if cos.dtype != value_states.dtype else cos,
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# [bs, seq_len, num_head, head_dim]
@@ -1324,6 +1342,7 @@
self.sequence_parallel = config.sequence_parallel
self.recompute_granularity = config.recompute_granularity
self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
self.config = config

# Recompute defaults to False and is controlled by Trainer
self.enable_recompute = False
@@ -1476,7 +1495,15 @@
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
if self.config.alibi:
alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)
if self.config.use_long_sequence_strategies:
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(

self.config.long_sequence_strategy_type,
self.config.long_sequence_strategy_name,
**self.config.long_sequence_init_args,
)
alibi = alibi_layer(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)

else:
alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)

if self.config.tensor_parallel_degree > 1:
block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree
alibi = alibi[
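
Note on the rotary branch above: with use_long_sequence_strategies enabled, the layer is called as self.rotary_emb(seq_len=kv_seq_len), the returned tables are reshaped via cos[None, :, None, :] / sin[None, :, None, :], and cast to the value dtype. A minimal sketch of an embedding strategy compatible with that call pattern follows; it is an illustration only, not the PR's embedding_strategies implementation (which is not part of this diff), and the class name and defaults are assumptions.

import paddle
from paddle import nn

class PlainRotaryEmbedding(nn.Layer):  # hypothetical strategy class
    def __init__(self, head_dim, max_position_embeddings=2048, base=10000.0, **init_args):
        super().__init__()
        # Standard RoPE tables: inverse frequencies over even dims, cached up to max_position_embeddings.
        inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
        t = paddle.arange(max_position_embeddings, dtype="float32")
        freqs = t.unsqueeze(1) * inv_freq.unsqueeze(0)   # [max_pos, head_dim // 2]
        emb = paddle.concat([freqs, freqs], axis=-1)     # [max_pos, head_dim]
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, seq_len=None):
        # Caller broadcasts these as cos[None, :, None, :] / sin[None, :, None, :].
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]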
18 changes: 18 additions & 0 deletions paddlenlp/transformers/long_sequence_strategies/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .attention_strategies import *
from .embedding_strategies import *
from .long_sequence_strategies import *
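
The LongSequenceStrategies factory itself lives in long_sequence_strategies.py, which is not shown in this diff. Based on the call sites above, a hedged usage sketch; the accepted type/name strings are assumed to be the module and class names re-exported here.

from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies

# Assumed argument values; any long_sequence_init_args from the config are
# forwarded to the strategy constructor as **kwargs.
alibi_layer = LongSequenceStrategies.build_long_sequence_strategy(
    "attention_strategies",      # long_sequence_strategy_type
    "AttentionWithLinearBias",   # long_sequence_strategy_name
)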
51 changes: 51 additions & 0 deletions paddlenlp/transformers/long_sequence_strategies/attention_strategies.py
@@ -0,0 +1,51 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle
from paddle import Tensor, nn

__all__ = ["AttentionWithLinearBias"]


class AttentionWithLinearBias(nn.Layer):
def __init__(self, **init_args):
super().__init__()


def _get_interleave(self, n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return np.array([start * start**i for i in range(n)]).astype(np.float32)


if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)

Check warning on line 34 in paddlenlp/transformers/long_sequence_strategies/attention_strategies.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/long_sequence_strategies/attention_strategies.py#L33-L34

Added lines #L33 - L34 were not covered by tests
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (

_get_interleave_power_of_2(closest_power_of_2)
+ self._get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)

def forward(self, bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype):
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
slopes = paddle.to_tensor(self._get_interleave(num_heads), dtype="float32")
with paddle.amp.auto_cast(enable=False):
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(

axis=[0, 1]
).expand([num_heads, -1, -1])
alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1])
return paddle.cast(alibi, dtype)

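
A small usage sketch for the class above, mirroring how the BLOOM and LLaMA branches call the built strategy; the output has shape [batch_size, num_heads, 1, seq_length], and the values here are illustrative.

import paddle
from paddlenlp.transformers.long_sequence_strategies import AttentionWithLinearBias

alibi_layer = AttentionWithLinearBias()
attention_mask = paddle.ones([2, 16], dtype="bool")      # [batch_size, seq_length]
alibi = alibi_layer(attention_mask, 8, dtype="float16")  # -> tensor of shape [2, 8, 1, 16]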