update hybird window full attention #8467

Open
wants to merge 3 commits into base: develop

Conversation

micelvrice

PR types

PR changes

Description


paddle-bot bot commented May 20, 2024

Thanks for your contribution!

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@@ -60,6 +61,33 @@ def docstring_decorator(fn):
return docstring_decorator


def tokenizer_fn_dev_redpajama(example, tokenizer, inference_length):

Suggest moving the data-processing logic into data.py.

return model_input


def tokenizer_fn_train_redpajama(example, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):

model_max_position_embeddings does not appear to be used.

add_special_tokens=True,
)
ids = tokenized_source["input_ids"]
features = {"input_ids": ids, "labels": ids, "position_ids": list(range(len(ids)))}

Shouldn't input_ids and labels be offset by one position here?
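
A minimal sketch of the offset the reviewer is asking about, assuming the loss in this script does not already shift labels internally; names follow the snippet above, and this is illustrative rather than the PR's code:

# Feed the model everything but the last token and predict the next token,
# so labels are offset from input_ids by one position.
ids = tokenized_source["input_ids"]
features = {
    "input_ids": ids[:-1],
    "labels": ids[1:],
    "position_ids": list(range(len(ids) - 1)),
}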

@@ -71,6 +99,14 @@ class FinetuneArguments(TrainingArguments):
default=False,
metadata={"help": "whether to output logits in distributed status"},
)
use_ssa: Optional[bool] = field(

Move this into TrainingArguments in argument.py.

origin_forward = paddlenlp.transformers.llama.modeling.LlamaAttention.forward

# replace llama attention with shift sparse attention
if "llama" in model_args.model_name_or_path and training_args.use_ssa:

Suggest checking the loaded model with isinstance(model, LlamaForCausalLM) instead.

@@ -198,6 +259,11 @@ def main():
else:
# NOTE(gongenlei): new add autotuner_benchmark
model = AutoModelForCausalLM.from_config(model_config, dtype=dtype)

# set the last layer with full attention
if training_args.use_hybird_window_full_attention:

Suggest wrapping all of this replacement logic in one function guarded by if isinstance(model, LlamaForCausalLM) and training_args.use_ssa:, rather than scattering it across several places.
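
A hedged sketch of that consolidation; the helper name and the attribute path model.llama.layers[-1].self_attn are assumptions based on PaddleNLP's Llama layout, not code from this PR:

from paddlenlp.transformers import LlamaForCausalLM
from paddlenlp.transformers.llama import modeling as llama_modeling

def maybe_patch_llama_attention(model, training_args, ssa_forward, origin_forward):
    """Keep every attention patch behind one isinstance check."""
    if not (isinstance(model, LlamaForCausalLM) and training_args.use_ssa):
        return model
    # Swap every LlamaAttention.forward for the shift sparse attention variant.
    llama_modeling.LlamaAttention.forward = ssa_forward
    if training_args.use_hybird_window_full_attention:
        # Restore full attention on the last decoder layer only (path is an assumption).
        last_attn = model.llama.layers[-1].self_attn
        last_attn.forward = origin_forward.__get__(last_attn)
    return model

# usage: model = maybe_patch_llama_attention(model, training_args, ssa_forward, origin_forward)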

if train_ds is not None
else None
)
if training_args.use_hybird_window_full_attention:

You could use data_args to mark whether the dataset is of type "pretrain" or "instruct_tuning". The current default is "instruct_tuning", and long-text data should count as "pretrain".
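
One way to express that tag, assuming a field is added to the data-arguments dataclass in argument.py (the class, field, and default below are illustrative):

from dataclasses import dataclass, field

@dataclass
class DataArgument:
    # Illustrative field: keep "instruct_tuning" as the default and tag the
    # long-text (hybrid window / full attention) datasets as "pretrain".
    dataset_type: str = field(
        default="instruct_tuning",
        metadata={"help": "Either 'pretrain' (long-text data) or 'instruct_tuning'."},
    )

The branches here could then test data_args.dataset_type == "pretrain" instead of training_args.use_hybird_window_full_attention.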

if dev_ds is not None
else None
)
if training_args.use_hybird_window_full_attention:

Same as above.

@@ -520,6 +597,17 @@ def compute_metrics_do_generation(eval_preds):
metrics = compute_metrics_do_generation
else:
metrics = compute_metrics
if training_args.use_hybird_window_full_attention:
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

Why is a separate DataCollatorForSupervisedDataset needed? Which part of the requirement does DataCollatorForSeq2Seq not cover?
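
For comparison, a hedged usage sketch of the existing collator; the keyword values are illustrative, and this may already cover the padding the custom collator performs:

from paddlenlp.data import DataCollatorForSeq2Seq

# Pads "input_ids" with the tokenizer's pad token and "labels" with
# label_pad_token_id so the loss ignores the padded positions.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,
    label_pad_token_id=-100,
)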

import paddle.nn.functional as F
group_size_ratio = 1/4

def ssa_forward(

Suggest building this forward on top of the existing LlamaAttention forward. The current implementation will most likely not generalize to other model configurations, for example models whose qkv projections are fused (self.fuse_attention_qkv is True). It also cannot use FlashAttention 2 (FA2), which saves GPU memory and speeds up training. A simpler implementation could instead modify LlamaAttention's scaled_dot_product_attention function.
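
For context, a minimal sketch of the group-shifting idea behind shift sparse attention (LongLoRA-style S2-attn); tensor shapes and helper names are assumptions, independent of the PR's actual ssa_forward:

import paddle

def shift_for_ssa(x, group_size):
    """Roll the second half of the heads by half a group so adjacent groups
    overlap; x has shape [batch, seq_len, num_heads, head_dim]."""
    num_heads = x.shape[2]
    front = x[:, :, : num_heads // 2, :]
    back = paddle.roll(x[:, :, num_heads // 2 :, :], shifts=-group_size // 2, axis=1)
    return paddle.concat([front, back], axis=2)

def to_groups(x, group_size):
    """Fold the group dimension into the batch so a standard attention kernel
    only attends within each group: [b, s, h, d] -> [b * s // g, g, h, d]."""
    b, s, h, d = x.shape
    return x.reshape([b * s // group_size, group_size, h, d])

# Usage sketch inside an attention forward, before the attention scores:
#   group_size = int(seq_len * group_size_ratio)
#   q, k, v = (to_groups(shift_for_ssa(t, group_size), group_size) for t in (q, k, v))
# After attention, the grouping is undone and the shifted heads are rolled back;
# the last layer can skip this entirely when use_hybird_window_full_attention keeps it full.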
