Support Sharding Overlap #8473

Merged
merged 3 commits on May 24, 2024
Changes from 2 commits
5 changes: 5 additions & 0 deletions llm/run_pretrain.py
@@ -223,6 +223,10 @@ class ModelArguments:
        default=None,
        metadata={"help": "num_hidden_layers."},
    )
    use_casual_mask: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to use casual mask"},
    )


def create_pretrained_dataset(
@@ -476,6 +480,7 @@ def main():
    config.pp_recompute_interval = model_args.pp_recompute_interval
    config.recompute_use_reentrant = model_args.recompute_use_reentrant
    config.use_recompute = training_args.recompute
    config.use_casual_mask = model_args.use_casual_mask

    config.tensor_parallel_degree = training_args.tensor_parallel_degree
    config.tensor_parallel_rank = training_args.tensor_parallel_rank
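For readers following the flag end to end: the two hunks above add a use_casual_mask field to ModelArguments and copy it onto the model config inside main(). Below is a minimal, self-contained sketch of that pattern; the dataclass field mirrors the diff, while the parser wiring and the config object are stand-ins rather than the actual run_pretrain.py code.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    # Field copied from the diff above; "casual" follows the identifier used in the PR.
    use_casual_mask: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to use casual mask"},
    )


class DummyConfig:
    """Stand-in for the pretrained model config that main() populates."""


model_args = ModelArguments()  # in run_pretrain.py this is parsed from the command line
config = DummyConfig()
config.use_casual_mask = model_args.use_casual_mask  # mirrors the config assignment in the diff
print(config.use_casual_mask)  # True unless overridden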
16 changes: 15 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -1892,7 +1892,6 @@
                    optimizer._set_broadcast_overlap(True, model)

            self.optimizer = optimizer

        # pure tensor parallel mode, no pipeline_parallel, no sharding.
        if (
            not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@
                self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)

        # stage1 has v1 and v2 version
        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
            if "split_param" in self.args.sharding_parallel_config:
                if (
                    hasattr(self.optimizer, "_set_all_gather_overlap_forward")
                    and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
                ):
                    self.optimizer._set_all_gather_overlap_forward(True, model)
            else:
                if (
                    hasattr(self.optimizer, "_set_broadcast_overlap")
                    and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
                ):
                    self.optimizer._set_broadcast_overlap(True, model)

        return model

    def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
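The block added above picks between the two sharding stage1 implementations: when "split_param" appears in sharding_parallel_config (stage1 v2) it enables forward all-gather overlap, otherwise (stage1 v1) it enables broadcast overlap, and in both cases only if the optimizer exposes the hook and the matching enable_stage1_*_overlap option is set. The following is a standalone sketch of that selection logic with a stand-in optimizer class; only the two hook names and option strings come from the diff.

class FakeShardingOptimizer:
    """Stand-in exposing the two overlap hooks named in the diff."""

    def _set_all_gather_overlap_forward(self, flag, model):
        print("stage1 v2: all-gather overlap ->", flag)

    def _set_broadcast_overlap(self, flag, model):
        print("stage1 v1: broadcast overlap ->", flag)


def enable_stage1_overlap(optimizer, model, sharding_parallel_config):
    # Mirrors the added trainer block: choose the overlap hook from the sharding config.
    if "split_param" in sharding_parallel_config:  # stage1 v2: parameters are split across ranks
        if hasattr(optimizer, "_set_all_gather_overlap_forward") and (
            "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    else:  # stage1 v1
        if hasattr(optimizer, "_set_broadcast_overlap") and (
            "enable_stage1_broadcast_overlap" in sharding_parallel_config
        ):
            optimizer._set_broadcast_overlap(True, model)


enable_stage1_overlap(
    FakeShardingOptimizer(),
    model=None,
    sharding_parallel_config="split_param enable_stage1_allgather_overlap",
)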
24 changes: 17 additions & 7 deletions paddlenlp/transformers/llama/modeling.py
Expand Up @@ -118,8 +118,7 @@
def build_alibi_tensor(
bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
) -> Tensor:
attention_mask = bool_attention_mask.astype("float32")
batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]

Check warning on line 121 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L121

Added line #L121 was not covered by tests
slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
[num_heads, -1, -1]
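In this hunk the float32 copy created by the removed lines is dropped and batch_size and seq_length are read straight from the boolean mask. A tiny illustration of that shape access, with arbitrary shapes, just to show the pattern:

import paddle

# Boolean attention mask of shape [batch_size, seq_length]; values are arbitrary here.
bool_attention_mask = paddle.ones([2, 128], dtype="bool")
batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
print(batch_size, seq_length)  # 2 128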
@@ -307,7 +306,7 @@

def _make_causal_mask(input_ids_shape, past_key_values_length):
    """
    Make causal mask used for self-attention
    Make casual mask used for self-attention
    """
    batch_size, target_length = input_ids_shape  # target_length: seq_len
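For context on what _make_causal_mask builds: a causal mask is lower-triangular, letting position i attend only to positions up to i. A minimal sketch of such a mask in paddle follows; the real helper also accounts for past_key_values_length and the batch dimension, which are omitted here.

import paddle

target_length = 4  # seq_len for this toy example
# Lower-triangular pattern: row i has ones in columns 0..i and zeros afterwards.
causal = paddle.tril(paddle.ones([target_length, target_length], dtype="int32"))
print(causal)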

@@ -1533,12 +1532,23 @@
        if position_ids is None:
            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
        )  # [bs, 1, seq_len, seq_len]
        use_casual_mask = (
            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
        )
        if use_casual_mask:
            attention_mask = None
        else:
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
            )  # [bs, 1, seq_len, seq_len]

        is_casual = False

        if self.config.use_flash_attention and get_env_device() != "gcu":
            is_casual = is_casual_mask(attention_mask)
Collaborator comment on the line above: Don't delete this. if hasattr(self.config, "casual_mask")

            if use_casual_mask:
                is_casual = True
            else:
                is_casual = is_casual_mask(attention_mask)
            if get_env_device() != "npu":
                if is_casual and alibi is None:
                    attention_mask = None
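Taken together, the modeling.py changes mean that when the config enables use_casual_mask, the dense [bs, 1, seq_len, seq_len] mask is never built: attention_mask stays None and, on the flash-attention path, is_casual is forced to True so the kernel's built-in causal handling is used. The sketch below condenses that decision flow; the config, helpers, and the alibi/device checks of the real forward pass are simplified to stand-ins, and only the flag names come from the diff.

def resolve_attention_mask(config, attention_mask, prepare_dense_mask, is_casual_mask):
    # Condensed version of the logic added above; the real code also checks
    # alibi and the device type before dropping the mask.
    use_casual_mask = getattr(config, "use_casual_mask", False) is True
    if use_casual_mask:
        attention_mask = None  # skip building the dense [bs, 1, seq_len, seq_len] mask
    else:
        attention_mask = prepare_dense_mask(attention_mask)

    is_casual = False
    if getattr(config, "use_flash_attention", False):
        is_casual = True if use_casual_mask else is_casual_mask(attention_mask)
        if is_casual:
            attention_mask = None  # let the flash-attention kernel apply causal masking itself
    return attention_mask, is_casual


# Toy usage with stand-in helpers:
cfg = type("Cfg", (), {"use_casual_mask": True, "use_flash_attention": True})()
mask, causal = resolve_attention_mask(cfg, None, prepare_dense_mask=lambda m: m, is_casual_mask=lambda m: False)
print(mask, causal)  # None True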