Support Sharding Overlap #8473

Merged · 3 commits · May 24, 2024
16 changes: 15 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -1892,7 +1892,6 @@
optimizer._set_broadcast_overlap(True, model)

self.optimizer = optimizer

# pure tensor parallel mode, no pipeline_parallel, no sharding.
if (
not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 versions
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
if (
hasattr(self.optimizer, "_set_all_gather_overlap_forward")
and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_all_gather_overlap_forward(True, model)
else:
if (
hasattr(self.optimizer, "_set_broadcast_overlap")
and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_broadcast_overlap(True, model)

return model

def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
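For context on the new branch above: a minimal sketch of how these overlap options could be enabled from a training script, assuming PaddleNLP's TrainingArguments exposes sharding and sharding_parallel_config as space-separated option strings (which matches how the code checks them); the option names come from the diff, everything else is illustrative.

from paddlenlp.trainer import TrainingArguments

# Sketch: stage1 "v2" sharding splits parameters ("split_param") and can overlap
# the forward all-gather; stage1 "v1" (no split_param) overlaps the parameter
# broadcast instead.
args = TrainingArguments(
    output_dir="./checkpoints",
    sharding="stage1",
    sharding_parallel_config="split_param enable_stage1_allgather_overlap",
    # stage1 v1 alternative:
    # sharding_parallel_config="enable_stage1_broadcast_overlap",
)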
28 changes: 21 additions & 7 deletions paddlenlp/transformers/llama/modeling.py
@@ -115,11 +115,15 @@
)


def get_use_casual_mask():
"""Get the value of the 'USE_CASUAL_MASK' environment variable."""
return os.getenv("USE_CASUAL_MASK", "False") == "True"


def build_alibi_tensor(
bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
) -> Tensor:
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
[num_heads, -1, -1]
@@ -307,7 +311,7 @@

def _make_causal_mask(input_ids_shape, past_key_values_length):
"""
Make causal mask used for self-attention
Make casual mask used for self-attention
"""
batch_size, target_length = input_ids_shape # target_length: seq_len

@@ -1533,12 +1537,22 @@
if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
use_casual_mask = get_use_casual_mask()
if use_casual_mask:
attention_mask = None
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]

is_casual = False

if self.config.use_flash_attention and get_env_device() != "gcu":
is_casual = is_casual_mask(attention_mask)
Collaborator review comment: Don't delete this. if hasattr(self.config, "casual_mask")

if use_casual_mask:
is_casual = True
else:
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
if is_casual and alibi is None:
attention_mask = None
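The USE_CASUAL_MASK switch above is read during LlamaModel.forward; a minimal sketch of how a script might opt in, assuming the environment variable is set before the first forward pass (the checkpoint name and surrounding script are illustrative, not part of this PR):

import os

# Setting "True" makes LlamaModel skip building the dense
# [bs, 1, seq_len, seq_len] attention mask and mark attention as causal,
# so flash attention can rely on its implicit causal masking.
os.environ["USE_CASUAL_MASK"] = "True"

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/llama-7b")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")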