update hybird window full attention #8467
base: develop
@@ -0,0 +1,133 @@
import argparse

import numpy as np
import paddle
from paddlenlp.transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Llama3Tokenizer,
    LlamaTokenizer,
)
from tqdm import tqdm


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--device', type=str, default="gpu:7", help='device: gpu or cpu')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size during inference')
    parser.add_argument('--base_model', type=str, default="/home/song_test/PaddleNLP/llm/trained_model/llama_lora_merge/")
    parser.add_argument('--cache_dir', type=str, default="./cache")
    parser.add_argument('--seq_len', type=int, default=4096, help='context length during evaluation')
    parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
    parser.add_argument('--peft_model', type=str, default=None, help='')
    parser.add_argument('--sliding_window', type=int, default=256, help='')
    parser.add_argument('--data_path', type=str, default="/home/LongLoRA/eval_dataset/Proof-pile/test_sampled_data.bin", help='')
    args = parser.parse_args()
    return args


def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256):
    # Walk the token stream with a stride of `sliding_window`, yielding
    # (input, target) batches of length `seq_length`.
    all_ix = list(range(0, len(data) - seq_length, sliding_window))
    all_ix.pop()

    for idx in range(0, len(all_ix), batch_size):
        ix = all_ix[idx:idx + batch_size]
        assert all([i + seq_length + 1 <= len(data) for i in ix])
        x = paddle.stack([paddle.to_tensor(data[i:i + seq_length].astype(np.int64)) for i in ix])
        y = paddle.stack([paddle.to_tensor(data[i + 1:i + 1 + seq_length].astype(np.int64)) for i in ix])
        if device != 'cpu':
            x, y = x.pin_memory().to(device, blocking=False), y.pin_memory().to(device, blocking=False)
        yield x, y


def iceildiv(x, y):
    # Integer ceiling division.
    return (x + y - 1) // y


def evaluate(model, data, args):
    stats = {}

    model.eval()

    loss_list_val, acc_list = [], []
    loss_step_list_val = []

    with paddle.no_grad():
        print(f"Using seq length {args.seq_len}")
        paddle.set_printoptions(sci_mode=False)
        for idx, (x, y) in tqdm(
            enumerate(
                get_as_batch(
                    data['val'],
                    args.seq_len,
                    args.batch_size,
                    device=args.device,
                    sliding_window=args.sliding_window,
                )
            ),
            total=iceildiv(
                iceildiv(len(data['val']), args.sliding_window),
                args.batch_size,
            ),
        ):
            val_loss = 0.
            acc = 0.
            cnt = 0
            for part_idx, i in enumerate(range(0, x.shape[1], args.seq_len)):
                part_len = x[:, i:i + args.seq_len].shape[1]
                outputs = model(
                    input_ids=x[:, i:i + args.seq_len],
                    labels=x[:, i:i + args.seq_len].contiguous(),
                    use_cache=False,
                )

                val_loss = outputs[0] * part_len + val_loss

                acc = ((outputs[1].argmax(-1) == y[:, i:i + args.seq_len]).sum()) + acc
                cnt += part_len

                while len(loss_step_list_val) <= part_idx:
                    loss_step_list_val.append([])
                loss_step_list_val[part_idx].append(outputs[0].item())
            val_loss /= cnt
            acc /= cnt

            loss_list_val.append(val_loss.item())
            acc_list.append(acc.item())

    stats['val_acc'] = paddle.to_tensor(acc_list).mean().item()
    stats['val_loss'] = paddle.to_tensor(loss_list_val).mean().item()
    stats['val_perplexity'] = float(np.exp(stats['val_loss']))
    stats['val_perplexity_per_chunk'] = paddle.exp(paddle.to_tensor(loss_step_list_val).mean(axis=1))

    return stats


def main(args):
    paddle.device.set_device(args.device)
    data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')}
    print(f"Num validation tokens: {len(data['val'])}")
    print("data path", args.data_path)
    print("base model", args.base_model)

    config = AutoConfig.from_pretrained(args.base_model, cache_dir=args.cache_dir)

    # context_size = args.context_size if args.context_size > 0 else args.seq_len
    # orig_ctx_len = getattr(config, 'max_position_embeddings', None)

    # if orig_ctx_len and context_size > orig_ctx_len:
    #     scaling_factor = float(math.ceil(context_size / orig_ctx_len))
    #     config.rope_scaling = {"type": "linear", "factor": 2}

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        config=config,
        cache_dir=args.cache_dir,
    )
    model.resize_token_embeddings(32001)

    stats = evaluate(model, data, args)
    print(stats)


if __name__ == "__main__":
    args = parse_config()
    main(args)
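For a quick sense of how the sliding-window evaluation above walks the token stream, here is a small self-contained sketch of the index arithmetic behind get_as_batch and the total passed to tqdm. The token count is a made-up toy value, not from any real dataset:

# Toy illustration of the chunking used by get_as_batch / evaluate above.
num_tokens = 100_000      # hypothetical size of the validation token stream
seq_len = 4096            # --seq_len
sliding_window = 256      # --sliding_window
batch_size = 4            # --batch_size

def iceildiv(x, y):
    return (x + y - 1) // y

# A new chunk starts every `sliding_window` tokens; the final start index is dropped.
starts = list(range(0, num_tokens - seq_len, sliding_window))
starts.pop()
print(len(starts))  # number of overlapping evaluation chunks
# tqdm's progress total is an upper-bound estimate based on the full stream length.
print(iceildiv(iceildiv(num_tokens, sliding_window), batch_size))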
@@ -26,6 +26,7 @@
    QuantArgument,
    TrainingArguments,
)
from llama_attn_replace import replace_llama_attn
from data import get_convert_example
from utils import (
    CausalLMTrainer,

@@ -36,7 +37,7 @@
    init_chat_template,
)

from paddlenlp.data import DataCollatorForSeq2Seq
from paddlenlp.data import DataCollatorForSeq2Seq, DataCollatorForSupervisedDataset
from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM

@@ -60,6 +61,33 @@ def docstring_decorator(fn):
    return docstring_decorator

def tokenizer_fn_dev_redpajama(example, tokenizer, inference_length):

    inputs = example["text"]
    tokenized_source = tokenizer(inputs, padding=False, truncation=True, max_length=8192, return_dict=False)
    input_ids = tokenized_source["input_ids"]
    input_ids, labels = input_ids[:-1], input_ids[1:]
    position_ids = list(range(len(input_ids)))
    model_input = {"input_ids": input_ids, "position_ids": position_ids, "labels": labels}

    return model_input


def tokenizer_fn_train_redpajama(example, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):

Review comment: model_max_position_embeddings looks unused here.

Review comment: For pretrain-style data there is no need to separate train and dev; a single def tokenizer_fn_pretrain() would be enough.

    source = example['text']
    tokenized_source = tokenizer(
        source,
        max_length=scaled_max_position_embeddings,
        truncation=True,
        truncation_side="left",
        add_special_tokens=True,
    )
    ids = tokenized_source["input_ids"]
    features = {"input_ids": ids, "labels": ids, "position_ids": list(range(len(ids)))}

Review comment: Shouldn't input_ids and labels be shifted by one position here?

    return features
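Taking the review comments above together, one possible shape for a merged tokenizer is sketched below: a single pretrain-style function shared by train and dev, with the unused model_max_position_embeddings argument dropped and labels shifted by one token, as in tokenizer_fn_dev_redpajama. The name tokenizer_fn_pretrain and the 8192 default are assumptions for illustration, not the PR's final code:

def tokenizer_fn_pretrain(example, tokenizer, scaled_max_position_embeddings=8192):
    # One tokenizer for pretrain-style (plain text) data, used for both train and dev splits.
    tokenized = tokenizer(
        example["text"],
        max_length=scaled_max_position_embeddings,
        truncation=True,
        truncation_side="left",
        add_special_tokens=True,
        return_dict=False,
    )
    ids = tokenized["input_ids"]
    # Shift by one token so labels[t] is the next-token target for input_ids[t].
    input_ids, labels = ids[:-1], ids[1:]
    return {
        "input_ids": input_ids,
        "labels": labels,
        "position_ids": list(range(len(input_ids))),
    }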


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class FinetuneArguments(TrainingArguments):
@@ -71,6 +99,14 @@ class FinetuneArguments(TrainingArguments):
        default=False,
        metadata={"help": "whether to output logits in distributed status"},
    )
    use_ssa: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to use ssa"},
    )

Review comment: Move these options into TrainingArguments in argument.py.

    use_hybird_window_full_attention: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to use hybird window full attention"},
    )


def read_local_dataset(path):
@@ -93,6 +129,13 @@ def main():
    training_args.print_config(quant_args, "Quant")
    training_args.print_config(gen_args, "Generation")

    # save the origin full attention
    origin_forward = paddlenlp.transformers.llama.modeling.LlamaAttention.forward

    # replace llama attention with shift sparse attention
    if "llama" in model_args.model_name_or_path and training_args.use_ssa:

Review comment: Better to check the loaded model with isinstance(model, LlamaForCausalLM) instead of matching the model name.

        replace_llama_attn()

    if sum([quant_args.do_ptq, quant_args.do_qat, quant_args.do_gptq, training_args.do_train]) > 1:
        raise ValueError(
            "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
@@ -149,6 +192,7 @@ def main():
    if hasattr(model_config, "use_flash_attention"):
        model_config.use_flash_attention = model_args.use_flash_attention

    model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
    model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
    model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
@@ -174,6 +218,23 @@ def main():
    model_config.tensor_parallel_output = training_args.tensor_parallel_output
    model_config.seq_length = data_args.max_length

    # set the rope scaling factor for long context window size
    orig_rope_scaling = getattr(model_config, "rope_scaling", None)

    if orig_rope_scaling is None:
        orig_rope_scaling = {"factor": 1}

    orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1

    orig_ctx_len = getattr(model_config, "max_position_embeddings", None)
    if orig_ctx_len:
        orig_ctx_len *= orig_rope_scaling_factor

    if data_args.max_length > orig_ctx_len:
        scaling_factor = float(math.ceil(data_args.max_length / orig_ctx_len))
        model_config.rope_scaling = {"type": "linear", "factor": scaling_factor}
        model_config.max_position_embeddings = int(orig_ctx_len * scaling_factor)
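For a concrete instance of what this block computes: with a base context window of 4096 (existing rope factor of 1) and data_args.max_length of 8192, the length used by the new tokenizer functions, the linear RoPE scaling works out as follows. The numbers are illustrative only:

import math

orig_ctx_len = 4096   # max_position_embeddings, already multiplied by an existing rope factor of 1
max_length = 8192     # data_args.max_length

scaling_factor = float(math.ceil(max_length / orig_ctx_len))        # 2.0
rope_scaling = {"type": "linear", "factor": scaling_factor}
new_max_position_embeddings = int(orig_ctx_len * scaling_factor)    # 8192
print(scaling_factor, rope_scaling, new_max_position_embeddings)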

    if training_args.pipeline_parallel_degree > 1:
        if data_args.eval_with_do_generation and training_args.do_eval:
            raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
@@ -198,6 +259,11 @@ def main():
    else:
        # NOTE(gongenlei): new add autotuner_benchmark
        model = AutoModelForCausalLM.from_config(model_config, dtype=dtype)

    # set the last layer with full attention
    if training_args.use_hybird_window_full_attention:

Review comment: Suggest wrapping the attention replacement in a single function guarded by if isinstance(model, LlamaForCausalLM) and training_args.use_ssa: instead of scattering it across several places.

        last_layer = model.llama.layers[-1].self_attn
        last_layer.forward = origin_forward.__get__(last_layer, type(last_layer))
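A rough sketch of the reviewer's suggestion, collecting the SSA swap and the last-layer full-attention restore into one helper keyed on the loaded model type. The helper name and call site are illustrative, not part of the PR:

from llama_attn_replace import replace_llama_attn  # provided by this PR
from paddlenlp.transformers import LlamaForCausalLM


def maybe_replace_attention(model, training_args, origin_forward):
    """Swap in shift sparse attention, keeping full attention on the last decoder layer.

    origin_forward is the unpatched LlamaAttention.forward saved before any replacement.
    """
    if not (isinstance(model, LlamaForCausalLM) and training_args.use_ssa):
        return
    # Patch LlamaAttention.forward with the shift-sparse variant for every layer.
    replace_llama_attn()
    if training_args.use_hybird_window_full_attention:
        # Rebind the original (full) attention forward on the last layer only.
        last_attn = model.llama.layers[-1].self_attn
        last_attn.forward = origin_forward.__get__(last_attn, type(last_attn))

Called once right after the model is loaded, this would keep the replacement logic in a single place.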
    if training_args.do_train and model_args.neftune:
        # Inspired by https://github.com/neelsjain/NEFTune
        if hasattr(model, "get_input_embeddings"):
@@ -366,11 +432,17 @@ def neft_post_hook(module, input, output):
            raise NotImplementedError(
                "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far."
            )
    train_ds = (
        train_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding))
        if train_ds is not None
        else None
    )
    if training_args.use_hybird_window_full_attention:

Review comment: You could use data_args to tag the dataset type as "pretrain" or "instruct_tuning"; keep "instruct_tuning" as the default and treat the long-context data as "pretrain".

        train_ds = train_ds.map(
            fn=partial(tokenizer_fn_train_redpajama, tokenizer=tokenizer, scaled_max_position_embeddings=8192, model_max_position_embeddings=4096), batched=False, num_workers=512,
        )

    else:
        train_ds = (
            train_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding))
            if train_ds is not None
            else None
        )
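One way to act on the review comment above, sketched as a drop-in replacement for the branch on use_hybird_window_full_attention. It assumes a hypothetical data_args.dataset_type field (default "instruct_tuning", with long-context data marked "pretrain") and the tokenizer_fn_pretrain sketch from earlier; it is a fragment meant to sit at this point of main(), not standalone code from the PR:

# Hypothetical DataArgument field: dataset_type: str = "instruct_tuning"
if getattr(data_args, "dataset_type", "instruct_tuning") == "pretrain":
    # Long-context plain-text data goes through the pretrain-style tokenizer.
    train_ds = (
        train_ds.map(
            fn=partial(tokenizer_fn_pretrain, tokenizer=tokenizer, scaled_max_position_embeddings=8192),
            batched=False,
        )
        if train_ds is not None
        else None
    )
else:
    # Default instruction-tuning path, unchanged.
    train_ds = (
        train_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding))
        if train_ds is not None
        else None
    )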
    ptq_ds = (
        ptq_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding)) if ptq_ds is not None else None
    )
@@ -380,11 +452,16 @@ def neft_post_hook(module, input, output):
                "`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset."
            )
            eval_intokens = False
    dev_ds = (
        dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
        if dev_ds is not None
        else None
    )
    if training_args.use_hybird_window_full_attention:

Review comment: Same as above.

        dev_ds = dev_ds.map(
            fn=partial(tokenizer_fn_dev_redpajama, tokenizer=tokenizer, inference_length=8192), batched=False, num_workers=256
        )
    else:
        dev_ds = (
            dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
            if dev_ds is not None
            else None
        )
    if data_args.zero_padding:
        if data_args.lazy:
            intoken_dataset = InTokensIterableDataset
@@ -520,6 +597,17 @@ def compute_metrics_do_generation(eval_preds):
        metrics = compute_metrics_do_generation
    else:
        metrics = compute_metrics
    if training_args.use_hybird_window_full_attention:
        data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

Review comment: Why is a separate DataCollatorForSupervisedDataset needed? Which requirement does DataCollatorForSeq2Seq fail to meet?

    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            max_length=max_length,
            padding=padding,
            max_label_length=max_length,
            return_tensors="np",
            pad_to_multiple_of=data_args.pad_to_multiple_of,
        )
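For context on the question above: the pretrain-style features built in this PR contain only input_ids, labels and position_ids, so the collator mainly has to pad those three keys to a common length per batch. The sketch below is an assumption about what such a collator needs to do, written as a standalone illustration; it is not the actual DataCollatorForSupervisedDataset from paddlenlp.data:

from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleSupervisedCollator:
    """Pad input_ids / labels / position_ids to the longest sample in the batch."""

    tokenizer: object
    label_pad_token_id: int = -100  # common "ignore" id for loss masking (assumed)

    def __call__(self, features):
        max_len = max(len(f["input_ids"]) for f in features)
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
        batch = {"input_ids": [], "labels": [], "position_ids": []}
        for f in features:
            pad = max_len - len(f["input_ids"])
            batch["input_ids"].append(list(f["input_ids"]) + [pad_id] * pad)
            batch["labels"].append(list(f["labels"]) + [self.label_pad_token_id] * pad)
            batch["position_ids"].append(list(f["position_ids"]) + [0] * pad)
        return {k: np.asarray(v, dtype="int64") for k, v in batch.items()}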

    trainer = CausalLMTrainer(
        model=model,
@@ -528,14 +616,7 @@ def compute_metrics_do_generation(eval_preds):
        eval_dataset=dev_ds,
        tokenizer=tokenizer,
        compute_metrics=metrics,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            max_length=max_length,
            padding=padding,
            max_label_length=max_length,
            return_tensors="np",
            pad_to_multiple_of=data_args.pad_to_multiple_of,
        ),
        data_collator=data_collator,
        do_generation=data_args.eval_with_do_generation,
        callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
        gen_args=gen_args,
Review comment: Suggest moving the data-processing logic into data.py.