update hybird window full attention #8467
base: develop
Conversation
Thanks for your contribution!
@@ -60,6 +61,33 @@ def docstring_decorator(fn):
    return docstring_decorator


def tokenizer_fn_dev_redpajama(example, tokenizer, inference_length):
We suggest moving the data-processing logic into data.py.
    return model_input


def tokenizer_fn_train_redpajama(example, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):
model_max_position_embeddings does not appear to be used.
        add_special_tokens=True,
    )
    ids = tokenized_source["input_ids"]
    features = {"input_ids": ids, "labels": ids, "position_ids": list(range(len(ids)))}
Shouldn't input_ids and labels be offset by one position here?
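If the shift is indeed needed (some causal-LM losses already shift internally, in which case identical input_ids and labels are fine), a minimal sketch of the offset construction, reusing the variable names from the diff:

```python
# Hypothetical sketch, only needed if the model's loss does not already shift
# labels internally: token t is trained to predict token t+1.
ids = tokenized_source["input_ids"]
features = {
    "input_ids": ids[:-1],
    "labels": ids[1:],
    "position_ids": list(range(len(ids) - 1)),
}
```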
@@ -71,6 +99,14 @@ class FinetuneArguments(TrainingArguments):
        default=False,
        metadata={"help": "whether to output logits in distributed status"},
    )
    use_ssa: Optional[bool] = field(
This should go into the TrainingArguments in argument.py.
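A minimal sketch of moving the flags, assuming argument.py defines (or would define) a TrainingArguments subclass; the base-class import, field names, and defaults below are assumptions mirroring the PR:

```python
from dataclasses import dataclass, field
from typing import Optional

from paddlenlp.trainer import TrainingArguments as BaseTrainingArguments


@dataclass
class TrainingArguments(BaseTrainingArguments):
    # Assumed fields mirroring the PR's flags; defaults are guesses.
    use_ssa: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to replace LlamaAttention forward with shift sparse attention."},
    )
    use_hybird_window_full_attention: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to keep full attention in the last layer only."},
    )
```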
    origin_forward = paddlenlp.transformers.llama.modeling.LlamaAttention.forward

    # replace llama attention with shift sparse attention
    if "llama" in model_args.model_name_or_path and training_args.use_ssa:
We suggest checking the loaded model with isinstance(model, LlamaForCausalLM) instead of matching a substring of the model path.
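A sketch of the suggested check, assuming the model has already been loaded at this point and that ssa_forward is the PR's patched forward:

```python
import paddlenlp
from paddlenlp.transformers import LlamaForCausalLM

# Gate the patch on the loaded model's class rather than on the path string.
if isinstance(model, LlamaForCausalLM) and training_args.use_ssa:
    paddlenlp.transformers.llama.modeling.LlamaAttention.forward = ssa_forward
```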
@@ -198,6 +259,11 @@ def main():
    else:
        # NOTE(gongenlei): new add autotuner_benchmark
        model = AutoModelForCausalLM.from_config(model_config, dtype=dtype)

    # set the last layer with full attention
    if training_args.use_hybird_window_full_attention:
We suggest consolidating this replacement logic into one larger function, gated by `if isinstance(model, LlamaForCausalLM) and training_args.use_ssa:`, rather than spreading it across several places.
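A rough sketch of that consolidation, assuming the PaddleNLP attribute path model.llama.layers[-1].self_attn and reusing the ssa_forward name from the diff:

```python
import paddlenlp
from paddlenlp.transformers import LlamaForCausalLM


def maybe_replace_llama_attention(model, training_args):
    """Apply all attention patching in one place instead of scattering it through main()."""
    if not (isinstance(model, LlamaForCausalLM) and training_args.use_ssa):
        return
    origin_forward = paddlenlp.transformers.llama.modeling.LlamaAttention.forward
    # Every layer uses shift sparse attention...
    paddlenlp.transformers.llama.modeling.LlamaAttention.forward = ssa_forward
    if training_args.use_hybird_window_full_attention:
        # ...except the last decoder layer, which keeps the original full attention.
        last_attn = model.llama.layers[-1].self_attn
        last_attn.forward = origin_forward.__get__(last_attn)
```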
        if train_ds is not None
        else None
    )
    if training_args.use_hybird_window_full_attention:
You could use data_args to mark whether the dataset is "pretrain" or "instruct_tuning". The current default type is "instruct_tuning"; long-text data should count as "pretrain".
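A minimal sketch of tagging the data style in data_args (the field name and allowed values below are assumptions):

```python
from dataclasses import dataclass, field


@dataclass
class DataArguments:
    # Hypothetical field: "instruct_tuning" stays the default; long-text data is "pretrain".
    dataset_type: str = field(
        default="instruct_tuning",
        metadata={"help": "One of 'instruct_tuning' (default) or 'pretrain' for long-text data."},
    )
```

The tokenization branches here would then test `data_args.dataset_type == "pretrain"` instead of `training_args.use_hybird_window_full_attention`.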
        if dev_ds is not None
        else None
    )
    if training_args.use_hybird_window_full_attention:
Same as above.
@@ -520,6 +597,17 @@ def compute_metrics_do_generation(eval_preds):
        metrics = compute_metrics_do_generation
    else:
        metrics = compute_metrics
    if training_args.use_hybird_window_full_attention:
        data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
Why is a separate DataCollatorForSupervisedDataset needed? Which part of DataCollatorForSeq2Seq does not meet the requirement?
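For comparison, a hedged sketch of using the stock collator (parameters assumed from its Hugging-Face-style interface); if the custom collator only adds padding for extra keys such as position_ids, that would be the gap worth calling out:

```python
from paddlenlp.data import DataCollatorForSeq2Seq

# Sketch: the stock seq2seq collator already pads input_ids and labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True, label_pad_token_id=-100)
```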
import paddle.nn.functional as F
group_size_ratio = 1/4

def ssa_forward(
We suggest building this forward on top of the existing LlamaAttention forward. The current implementation is unlikely to generalize to other model configurations, e.g. when the QKV projections are fused (self.fuse_attention_qkv is True), and it also cannot use FA2, which saves memory and speeds up training. A simpler option would be to modify LlamaAttention's scaled_dot_product_attention function.
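A rough, generic sketch of the idea (not the PR's code): leave LlamaAttention.forward and its fused-QKV / FA2 paths untouched, and apply the group shift only around the attention kernel, which is roughly what hooking scaled_dot_product_attention would amount to. The [batch, seq_len, num_heads, head_dim] layout and the half-group head shift follow the S2-Attn scheme; the wrapper names and signature here are assumptions.

```python
import paddle

group_size_ratio = 1 / 4


def shift_heads(x, group_size, num_heads, reverse=False):
    """Shift the second half of the heads by half a group along the sequence axis."""
    direction = group_size // 2 if reverse else -(group_size // 2)
    return paddle.concat(
        [x[:, :, : num_heads // 2], paddle.roll(x[:, :, num_heads // 2 :], direction, axis=1)],
        axis=2,
    )


def grouped_attention(query, key, value, attention_fn):
    """Run an unchanged attention kernel group-by-group, S2-Attn style.

    query/key/value: [bsz, q_len, num_heads, head_dim]; attention_fn is the
    existing kernel (eager or FlashAttention-2) operating on the same layout.
    """
    bsz, q_len, num_heads, head_dim = query.shape
    group_size = int(q_len * group_size_ratio)
    num_groups = q_len // group_size

    def to_groups(x):
        x = shift_heads(x, group_size, num_heads)
        return x.reshape([bsz * num_groups, group_size, num_heads, head_dim])

    out = attention_fn(to_groups(query), to_groups(key), to_groups(value))
    out = out.reshape([bsz, q_len, num_heads, head_dim])
    return shift_heads(out, group_size, num_heads, reverse=True)
```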
PR types
PR changes
Description