Commit 92b106f

update
iosmers committed May 23, 2024
1 parent 87e4c4f commit 92b106f
Showing 3 changed files with 37 additions and 8 deletions.
5 changes: 5 additions & 0 deletions llm/run_pretrain.py
@@ -223,6 +223,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "num_hidden_layers."},
     )
+    use_casual_mask: Optional[bool] = field(
+        default=True,
+        metadata={"help": "whether to use causal mask"},
+    )


 def create_pretrained_dataset(
@@ -476,6 +480,7 @@ def main():
     config.pp_recompute_interval = model_args.pp_recompute_interval
     config.recompute_use_reentrant = model_args.recompute_use_reentrant
     config.use_recompute = training_args.recompute
+    config.use_casual_mask = model_args.use_casual_mask

     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
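For context: the new ModelArguments field flows through PaddleNLP's PdArgumentParser like any other flag, so it can be toggled from the launch command. A minimal sketch of that pattern, assuming PdArgumentParser mirrors HfArgumentParser's parse_args_into_dataclasses signature (the MiniModelArguments class is illustrative, not from the repo):

from dataclasses import dataclass, field
from typing import Optional

from paddlenlp.trainer import PdArgumentParser


@dataclass
class MiniModelArguments:
    # mirrors the field added by this commit
    use_casual_mask: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to use causal mask"},
    )


# e.g. `python run_pretrain.py ... --use_casual_mask False`
(model_args,) = PdArgumentParser([MiniModelArguments]).parse_args_into_dataclasses(
    ["--use_casual_mask", "False"]
)
print(model_args.use_casual_mask)  # False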
16 changes: 15 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -1892,7 +1892,6 @@ def get_expected_keys(inputs, keys):
                     optimizer._set_broadcast_overlap(True, model)

                 self.optimizer = optimizer
-
         # pure tensor parallel mode, no pipeline_parallel, no sharding.
         if (
             not in_pipeline_parallel_mode
@@ -1908,6 +1907,21 @@ def get_expected_keys(inputs, keys):
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)

+        # stage1 has v1 and v2 versions
+        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
+            if "split_param" in self.args.sharding_parallel_config:
+                if (
+                    hasattr(self.optimizer, "_set_all_gather_overlap_forward")
+                    and "enable_stage1_allgather_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_all_gather_overlap_forward(True, model)
+            else:
+                if (
+                    hasattr(self.optimizer, "_set_broadcast_overlap")
+                    and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
+                ):
+                    self.optimizer._set_broadcast_overlap(True, model)
+
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
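In effect, the new block re-applies communication overlap after fleet.distributed_optimizer() wraps the optimizer: sharding stage1 v2 ("split_param") overlaps the forward all-gather, while v1 overlaps the parameter broadcast. A standalone sketch of the gating, assuming the wrapped optimizer exposes the two private hooks named in the diff:

def enable_stage1_overlap(optimizer, model, sharding_parallel_config):
    # optimizer: the fleet-wrapped optimizer; sharding_parallel_config: collection of flag strings
    if "split_param" in sharding_parallel_config:
        # stage1 v2: parameters are split, so overlap the forward all-gather
        if hasattr(optimizer, "_set_all_gather_overlap_forward") and (
            "enable_stage1_allgather_overlap" in sharding_parallel_config
        ):
            optimizer._set_all_gather_overlap_forward(True, model)
    else:
        # stage1 v1: whole parameters are broadcast, so overlap the broadcast
        if hasattr(optimizer, "_set_broadcast_overlap") and (
            "enable_stage1_broadcast_overlap" in sharding_parallel_config
        ):
            optimizer._set_broadcast_overlap(True, model)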
24 changes: 17 additions & 7 deletions paddlenlp/transformers/llama/modeling.py
@@ -118,8 +118,7 @@ def _get_interleave_power_of_2(n):
 def build_alibi_tensor(
     bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1
 ) -> Tensor:
-    attention_mask = bool_attention_mask.astype("float32")
-    batch_size, seq_length = attention_mask.shape[0], attention_mask.shape[-1]
+    batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1]
     slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32")
     alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand(
         [num_heads, -1, -1]
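The change above drops the astype("float32") cast because only the shape of bool_attention_mask is read; the ALiBi bias itself is slopes[head] * position. A NumPy sketch of that construction, using the explicit geometric slopes that _get_interleave produces for power-of-two head counts:

import numpy as np


def alibi_bias_sketch(num_heads: int, seq_length: int) -> np.ndarray:
    # slope for head i is 2^(-8/num_heads * (i+1)); matches _get_interleave for powers of two
    slopes = np.array([2.0 ** (-8.0 / num_heads * (i + 1)) for i in range(num_heads)])
    positions = np.arange(seq_length, dtype=np.float32)
    # [num_heads, 1, seq_length]: one linear penalty row per head, broadcast over queries
    return slopes[:, None, None] * positions[None, None, :]


print(alibi_bias_sketch(num_heads=4, seq_length=8).shape)  # (4, 1, 8)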
@@ -307,7 +306,7 @@ def is_casual_mask(attention_mask):

 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
-    Make causal mask used for self-attention
+    Make casual mask used for self-attention
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
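For reference, a causal mask lets query position i attend only to key positions <= i, with an all-visible block prepended for cached keys. A minimal paddle sketch consistent with the shapes above; this is an illustration, not the file's exact implementation:

import paddle


def make_causal_mask_sketch(input_ids_shape, past_key_values_length=0):
    batch_size, target_length = input_ids_shape  # target_length: seq_len
    # lower-triangular block: position i sees keys 0..i
    mask = paddle.tril(paddle.ones([target_length, target_length], dtype="bool"))
    if past_key_values_length > 0:
        # cached (past) positions are visible to every query
        past = paddle.ones([target_length, past_key_values_length], dtype="bool")
        mask = paddle.concat([past, mask], axis=-1)
    # [bs, 1, tgt_len, tgt_len + past], broadcastable over attention heads
    return mask.unsqueeze(axis=[0, 1]).expand(
        [batch_size, 1, target_length, target_length + past_key_values_length]
    )


print(make_causal_mask_sketch((2, 4), past_key_values_length=3).shape)  # [2, 1, 4, 7]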

@@ -1533,12 +1532,23 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
+        use_casual_mask = (
+            True if hasattr(self.config, "use_casual_mask") and self.config.use_casual_mask is True else False
+        )
+        if use_casual_mask:
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]

         is_casual = False

         if self.config.use_flash_attention and get_env_device() != "gcu":
-            is_casual = is_casual_mask(attention_mask)
+            if use_casual_mask:
+                is_casual = True
+            else:
+                is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
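Net effect of the forward() change: with use_casual_mask set, _prepare_decoder_attention_mask is skipped, attention_mask stays None, and flash attention applies causal masking internally (except on npu, where the dense mask is kept). A hedged usage sketch; LlamaConfig and LlamaForCausalLM are PaddleNLP classes, but the constructor kwargs shown are assumptions, not taken from this diff:

from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(num_hidden_layers=2)  # tiny model for illustration (kwarg assumed)
config.use_casual_mask = True       # new flag: skip building the dense [bs, 1, seq, seq] mask
config.use_flash_attention = True   # flash attention then handles causal masking itself
model = LlamaForCausalLM(config)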
