
Revert "Support Sharding Overlap (#8473)" (#8491)
This reverts commit 7aaa788.
SylarTiaNII committed May 24, 2024
1 parent 7aaa788 commit 0cd8fe7
Showing 2 changed files with 8 additions and 36 deletions.
16 changes: 1 addition & 15 deletions paddlenlp/trainer/trainer.py
@@ -1892,6 +1892,7 @@ def get_expected_keys(inputs, keys):
optimizer._set_broadcast_overlap(True, model)

self.optimizer = optimizer

# pure tesnor parallel mode, no pipeline_parallel, no sharding.
if (
not in_pipeline_parallel_mode
@@ -1907,21 +1908,6 @@ def get_expected_keys(inputs, keys):
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 version
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
if (
hasattr(self.optimizer, "_set_all_gather_overlap_forward")
and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_all_gather_overlap_forward(True, model)
else:
if (
hasattr(self.optimizer, "_set_broadcast_overlap")
and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_broadcast_overlap(True, model)

return model

def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
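For context, the block removed from trainer.py toggled sharding stage1 communication overlap based on entries in sharding_parallel_config. Below is a minimal sketch of that selection logic, assuming an optimizer object that exposes the private _set_all_gather_overlap_forward / _set_broadcast_overlap hooks shown in the diff; the helper name apply_stage1_overlap is hypothetical and not part of the repository.

def apply_stage1_overlap(optimizer, model, sharding_parallel_config):
    # Mirrors the reverted trainer.py block: pick the overlap hook depending on
    # whether stage1 v2 ("split_param") or stage1 v1 sharding is configured.
    if "split_param" in sharding_parallel_config:
        if (
            hasattr(optimizer, "_set_all_gather_overlap_forward")
            and "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    else:
        if (
            hasattr(optimizer, "_set_broadcast_overlap")
            and "enable_stage1_broadcast_overlap" in sharding_parallel_config
        ):
            optimizer._set_broadcast_overlap(True, model)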
28 changes: 7 additions & 21 deletions paddlenlp/transformers/llama/modeling.py
@@ -115,15 +115,11 @@ def _get_interleave_power_of_2(n):
)


def get_use_casual_mask():
"""Get the value of the 'USE_CASUAL_MASK' environment variable."""
return os.getenv("USE_CASUAL_MASK", "False")


def build_alibi_tensor(
bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
) -> Tensor:
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
[num_heads, -1, -1]
@@ -311,7 +307,7 @@ def is_casual_mask(attention_mask):

def _make_causal_mask(input_ids_shape, past_key_values_length):
"""
Make casual mask used for self-attention
Make causal mask used for self-attention
"""
batch_size, target_length = input_ids_shape # target_length: seq_len

@@ -1547,22 +1543,12 @@ def forward(
if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

use_casual_mask = get_use_casual_mask()

if use_casual_mask:
attention_mask = None
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
is_casual = False

if self.config.use_flash_attention and get_env_device() != "gcu":
if use_casual_mask:
is_casual = True
else:
is_casual = is_casual_mask(attention_mask)
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
if is_casual and alibi is None:
attention_mask = None
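The modeling.py hunks remove the USE_CASUAL_MASK environment-variable gate, so the decoder attention mask is always built via _prepare_decoder_attention_mask and the flash-attention path re-derives is_casual from is_casual_mask(attention_mask). The following is a minimal sketch of the kind of gate being reverted; the == "True" parsing is an assumption, since the diff shows only the os.getenv call.

import os

def get_use_casual_mask():
    """Read the 'USE_CASUAL_MASK' flag from the environment (assumed "True"/"False" strings)."""
    return os.getenv("USE_CASUAL_MASK", "False") == "True"

# Usage sketch of the reverted branch: when the flag is set, building the
# [bs, 1, seq_len, seq_len] mask is skipped and the flash-attention path
# treats the attention as causal directly.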
