update hybird window full attention #8467

Open · wants to merge 3 commits into base: develop

133 changes: 133 additions & 0 deletions llm/eval_ppl.py
@@ -0,0 +1,133 @@
import numpy as np
from paddlenlp.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
Llama3Tokenizer,
LlamaTokenizer,
)
from tqdm import tqdm
import paddle
import argparse
import math

def parse_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--device', type=str, default="gpu:7", help='device: gpu or cpu')
parser.add_argument('--batch_size', type=int, default=4, help='batch size during inference')
parser.add_argument('--base_model', type=str, default="/home/song_test/PaddleNLP/llm/trained_model/llama_lora_merge/")
parser.add_argument('--cache_dir', type=str, default="./cache")
parser.add_argument('--seq_len', type=int, default=4096, help='context length during evaluation')
parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
    parser.add_argument('--peft_model', type=str, default=None, help='optional path to a PEFT/LoRA adapter (not used in this script)')
    parser.add_argument('--sliding_window', type=int, default=256, help='stride between consecutive evaluation windows')
    parser.add_argument('--data_path', type=str, default="/home/LongLoRA/eval_dataset/Proof-pile/test_sampled_data.bin", help='path to the binary (uint16) token file to evaluate')
args = parser.parse_args()
return args

def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256):
all_ix = list(range(0, len(data) - seq_length, sliding_window))
all_ix.pop()

for idx in range(0, len(all_ix), batch_size):
ix = all_ix[idx:idx+batch_size]
assert all([idx + seq_length + 1 <= len(data) for idx in ix])
x = paddle.stack([paddle.to_tensor((data[i:i+seq_length]).astype(np.int64)) for i in ix])
y = paddle.stack([paddle.to_tensor((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix])
if device != 'cpu':
x, y = x.pin_memory().to(device, blocking=False), y.pin_memory().to(device, blocking=False)
yield x, y

def iceildiv(x, y):
return (x + y - 1) // y

def evaluate(model, data, args):
stats = {}

model.eval()

loss_list_val, acc_list = [], []
loss_step_list_val = []

with paddle.no_grad():
print(f"Using seq length {args.seq_len}")
paddle.set_printoptions(sci_mode=False)
for idx, (x, y) in tqdm(
enumerate(
get_as_batch(
data['val'],
args.seq_len,
args.batch_size,
device=args.device,
sliding_window=args.sliding_window
)
),
total = iceildiv(
iceildiv(len(data['val']), args.sliding_window),
args.batch_size
)
):
val_loss = 0.
acc = 0.
cnt = 0
for part_idx, i in enumerate(range(0, x.shape[1], args.seq_len)):
part_len = x[:, i:i + args.seq_len].shape[1]
outputs = model(
input_ids = x[:, i:i + args.seq_len],
labels = x[:, i:i + args.seq_len].contiguous(),
use_cache=False
)

val_loss = outputs[0] * part_len + val_loss

acc = ((outputs[1].argmax(-1) == y[:, i:i+args.seq_len]).sum()) + acc
cnt += part_len

while len(loss_step_list_val) <= part_idx:
loss_step_list_val.append([])
loss_step_list_val[part_idx].append(outputs[0].item())
val_loss /= cnt
acc /= cnt

loss_list_val.append(val_loss.item())
acc_list.append(acc.item())
        stats['val_acc'] = paddle.to_tensor(acc_list).mean().item()
        stats['val_loss'] = paddle.to_tensor(loss_list_val).mean().item()
        stats['val_perplexity'] = math.exp(stats['val_loss'])
        stats['val_perplexity_per_chunk'] = paddle.exp(paddle.to_tensor(loss_step_list_val).mean(axis=1))

return stats


def main(args):
paddle.device.set_device(args.device)
data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')}
print(f"Num validation tokens: {len(data['val'])}")
print("data path", args.data_path)
print("base model", args.base_model)

config = AutoConfig.from_pretrained(args.base_model, cache_dir=args.cache_dir)

# context_size = args.context_size if args.context_size > 0 else args.seq_len
# orig_ctx_len = getattr(config, 'max_position_embeddings', None)

# if orig_ctx_len and context_size > orig_ctx_len:
# scaling_factor = float(math.ceil(context_size / orig_ctx_len))
# config.rope_scaling = {"type": "linear", "factor": 2}

model = AutoModelForCausalLM.from_pretrained(
args.base_model,
config=config,
cache_dir=args.cache_dir,
)
model.resize_token_embeddings(32001)

stats = evaluate(model, data, args)
print(stats)

if __name__ == "__main__":
    args = parse_config()
    main(args)
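
For reference, a toy, self-contained sketch of the sliding-window slicing that get_as_batch performs above, and of how perplexity follows from the mean loss (the array and the loss value here are made up):

```python
import math

import numpy as np

# Stand-in for the memory-mapped uint16 token file loaded as data['val'].
data = np.arange(20, dtype=np.uint16)
seq_length, sliding_window, batch_size = 8, 4, 2

# Window starts every `sliding_window` tokens; the last start is dropped so
# that i + seq_length + 1 never runs past the end of the data.
all_ix = list(range(0, len(data) - seq_length, sliding_window))
all_ix.pop()

for idx in range(0, len(all_ix), batch_size):
    ix = all_ix[idx:idx + batch_size]
    x = [data[i:i + seq_length] for i in ix]          # model inputs
    y = [data[i + 1:i + 1 + seq_length] for i in ix]  # next-token targets
    print(ix, [w.tolist() for w in x], [w.tolist() for w in y])

# Perplexity is exp(mean token-level cross-entropy), as in val_perplexity above.
mean_loss = 1.25  # placeholder
print(math.exp(mean_loss))  # ~3.49
```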
119 changes: 100 additions & 19 deletions llm/finetune_generation.py
@@ -26,6 +26,7 @@
QuantArgument,
TrainingArguments,
)
from llama_attn_replace import replace_llama_attn
from data import get_convert_example
from utils import (
CausalLMTrainer,
@@ -36,7 +37,7 @@
init_chat_template,
)

from paddlenlp.data import DataCollatorForSeq2Seq
from paddlenlp.data import DataCollatorForSeq2Seq, DataCollatorForSupervisedDataset
from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
@@ -60,6 +61,33 @@ def docstring_decorator(fn):
return docstring_decorator


def tokenizer_fn_dev_redpajama(example, tokenizer, inference_length):
Reviewer (Contributor): Suggest moving this data-processing logic into data.py.


inputs = example["text"]
tokenized_source = tokenizer(inputs, padding=False, truncation=True, max_length=8192, return_dict=False)
input_ids = tokenized_source["input_ids"]
input_ids, labels = input_ids[:-1], input_ids[1:]
position_ids = list(range(len(input_ids)))
model_input = {"input_ids": input_ids, "position_ids": position_ids, "labels": labels}

return model_input


def tokenizer_fn_train_redpajama(example, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):
Reviewer (Contributor): model_max_position_embeddings does not appear to be used.

Reviewer (Contributor): Pretraining-style data like this does not need separate train and dev functions; a single def tokenizer_fn_pretrain() would do.

source = example['text']
tokenized_source = tokenizer(
source,
max_length=scaled_max_position_embeddings,
truncation=True,
truncation_side="left",
add_special_tokens=True,
)
ids = tokenized_source["input_ids"]
features = {"input_ids": ids, "labels": ids, "position_ids": list(range(len(ids)))}
Reviewer (Contributor): Shouldn't input_ids and labels be shifted by one position here? (A sketch of the shifted variant follows this function.)

return features


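To make the shift question above concrete, here is a minimal, hypothetical tokenizer_fn_pretrain-style variant (not part of this diff) that offsets labels by one token, mirroring what tokenizer_fn_dev_redpajama already does; it assumes the model's loss does not shift labels internally:

```python
def tokenizer_fn_pretrain_shifted(example, tokenizer, max_length):
    """Hypothetical variant with an explicit one-token offset between inputs and labels."""
    tokenized = tokenizer(
        example["text"],
        max_length=max_length,
        truncation=True,
        add_special_tokens=True,
        return_dict=False,
    )
    ids = tokenized["input_ids"]
    # Predict token t+1 from tokens up to t: drop the last id from the inputs
    # and the first id from the labels.
    input_ids, labels = ids[:-1], ids[1:]
    return {
        "input_ids": input_ids,
        "labels": labels,
        "position_ids": list(range(len(input_ids))),
    }
```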

@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class FinetuneArguments(TrainingArguments):
@@ -71,6 +99,14 @@ class FinetuneArguments(TrainingArguments):
default=False,
metadata={"help": "whether to output logits in distributed status"},
)
use_ssa: Optional[bool] = field(
Reviewer (Contributor): Move this into TrainingArguments in argument.py.

default=True,
metadata={"help": "whether to use ssa"},
)
use_hybird_window_full_attention: Optional[bool] = field(
default=True,
metadata={"help": "whether to use hybird window full attention"},
)


def read_local_dataset(path):
@@ -93,6 +129,13 @@ def main():
training_args.print_config(quant_args, "Quant")
training_args.print_config(gen_args, "Generation")

    # save the original full-attention forward so it can be restored on the last layer later
origin_forward = paddlenlp.transformers.llama.modeling.LlamaAttention.forward

# replace llama attention with shift sparse attention
if "llama" in model_args.model_name_or_path and training_args.use_ssa:
Reviewer (Contributor): It would be better to check the loaded model with isinstance(model, LlamaForCausalLM) instead.

replace_llama_attn()

if sum([quant_args.do_ptq, quant_args.do_qat, quant_args.do_gptq, training_args.do_train]) > 1:
raise ValueError(
"--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
@@ -149,6 +192,7 @@ def main():
if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention


model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
@@ -174,6 +218,23 @@ def main():
model_config.tensor_parallel_output = training_args.tensor_parallel_output
model_config.seq_length = data_args.max_length


# set the rope scaling factor for long context window size
orig_rope_scaling = getattr(model_config, "rope_scaling", None)
if orig_rope_scaling is None:
orig_rope_scaling = {"factor": 1}

orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1

orig_ctx_len = getattr(model_config, "max_position_embeddings", None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor

if data_args.max_length > orig_ctx_len:
scaling_factor = float(math.ceil(data_args.max_length / orig_ctx_len))
model_config.rope_scaling = {"type": "linear", "factor": scaling_factor}
model_config.max_position_embeddings = int(orig_ctx_len * scaling_factor)

if training_args.pipeline_parallel_degree > 1:
if data_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
@@ -198,6 +259,11 @@ def main():
else:
# NOTE(gongenlei): new add autotuner_benchmark
model = AutoModelForCausalLM.from_config(model_config, dtype=dtype)

    # restore the original full-attention forward on the last decoder layer
if training_args.use_hybird_window_full_attention:
Reviewer (Contributor): Wrap the replacement logic in one function guarded by if isinstance(model, LlamaForCausalLM) and training_args.use_ssa:, instead of scattering it across several places.

last_layer = model.llama.layers[-1].self_attn
last_layer.forward = origin_forward.__get__(last_layer, type(last_layer))
if training_args.do_train and model_args.neftune:
# Inspired by https://github.com/neelsjain/NEFTune
if hasattr(model, "get_input_embeddings"):
@@ -366,11 +432,17 @@ def neft_post_hook(module, input, output):
raise NotImplementedError(
"Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far."
)
train_ds = (
train_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding))
if train_ds is not None
else None
)
if training_args.use_hybird_window_full_attention:
Reviewer (Contributor): data_args could be used to tag the data type as either "pretrain" or "instruct_tuning"; the current default is "instruct_tuning", and long-text data should be treated as "pretrain".

train_ds = train_ds.map(
fn=partial(tokenizer_fn_train_redpajama, tokenizer=tokenizer, scaled_max_position_embeddings=8192, model_max_position_embeddings=4096), batched=False, num_workers=512,
)

else:
train_ds = (
train_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding))
if train_ds is not None
else None
)
ptq_ds = (
ptq_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding)) if ptq_ds is not None else None
)
@@ -380,11 +452,16 @@ def neft_post_hook(module, input, output):
"`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset."
)
eval_intokens = False
dev_ds = (
dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
if dev_ds is not None
else None
)
if training_args.use_hybird_window_full_attention:
Reviewer (Contributor): Same as above.

dev_ds = dev_ds.map(
fn=partial(tokenizer_fn_dev_redpajama, tokenizer=tokenizer, inference_length=8192), batched=False, num_workers=256
)
else:
dev_ds = (
dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
if dev_ds is not None
else None
)
if data_args.zero_padding:
if data_args.lazy:
intoken_dataset = InTokensIterableDataset
@@ -520,6 +597,17 @@ def compute_metrics_do_generation(eval_preds):
metrics = compute_metrics_do_generation
else:
metrics = compute_metrics
if training_args.use_hybird_window_full_attention:
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
Reviewer (Contributor): Why does this need a separate DataCollatorForSupervisedDataset? Which requirement does DataCollatorForSeq2Seq not cover? (A generic sketch of what such a collator does follows below.)

else:
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
max_length=max_length,
padding=padding,
max_label_length=max_length,
return_tensors="np",
pad_to_multiple_of=data_args.pad_to_multiple_of,
)

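For the collator question raised above, a generic, hypothetical sketch of what a supervised-dataset collator of this kind usually does, namely right-pad input_ids and labels to a common length per batch; PaddleNLP's actual DataCollatorForSupervisedDataset may differ:

```python
import numpy as np

def pad_supervised_batch(features, pad_token_id, label_pad_id=-100):
    """Hypothetical minimal collator: right-pads input_ids and labels so every
    example in the batch shares one rectangular shape."""
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids, labels = [], []
    for f in features:
        pad = max_len - len(f["input_ids"])
        input_ids.append(list(f["input_ids"]) + [pad_token_id] * pad)
        labels.append(list(f["labels"]) + [label_pad_id] * pad)  # -100 is a common ignore index
    return {
        "input_ids": np.asarray(input_ids, dtype="int64"),
        "labels": np.asarray(labels, dtype="int64"),
    }

# Example: two ragged sequences are padded to shape (2, 3).
batch = pad_supervised_batch(
    [{"input_ids": [1, 2, 3], "labels": [2, 3, 4]},
     {"input_ids": [5, 6], "labels": [6, 7]}],
    pad_token_id=0,
)
print(batch["input_ids"].shape, batch["labels"].tolist())
```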
trainer = CausalLMTrainer(
model=model,
@@ -528,14 +616,7 @@ def compute_metrics_do_generation(eval_preds):
eval_dataset=dev_ds,
tokenizer=tokenizer,
compute_metrics=metrics,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
max_length=max_length,
padding=padding,
max_label_length=max_length,
return_tensors="np",
pad_to_multiple_of=data_args.pad_to_multiple_of,
),
data_collator=data_collator,
do_generation=data_args.eval_with_do_generation,
callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
gen_args=gen_args,
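
To make the rope-scaling arithmetic in the @@ -174,6 +218,23 @@ hunk above concrete, a small worked example; the 4096/8192 values are illustrative and match the lengths hard-coded elsewhere in this PR:

```python
import math

# Illustrative inputs: a base model with 4096 position embeddings, no prior
# rope scaling, fine-tuned with data_args.max_length = 8192.
orig_ctx_len = 4096            # model_config.max_position_embeddings
orig_rope_scaling_factor = 1   # no existing rope_scaling on the base model
max_length = 8192              # data_args.max_length

orig_ctx_len *= orig_rope_scaling_factor
if max_length > orig_ctx_len:
    scaling_factor = float(math.ceil(max_length / orig_ctx_len))  # -> 2.0
    rope_scaling = {"type": "linear", "factor": scaling_factor}
    max_position_embeddings = int(orig_ctx_len * scaling_factor)  # -> 8192
    print(rope_scaling, max_position_embeddings)
```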
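Finally, a runnable toy illustration of the rebinding trick behind the hybrid window/full attention mode: replace_llama_attn() patches LlamaAttention.forward at class level with the windowed (shift sparse) variant, and origin_forward.__get__(last_layer, type(last_layer)) then rebinds the saved full-attention forward onto the last layer's instance only. The classes below are stand-ins, not PaddleNLP code:

```python
class Attention:
    def forward(self, x):
        return f"full attention on {x}"

# Keep a handle to the unpatched forward (analogous to origin_forward above).
origin_forward = Attention.forward

# Class-level patch (analogous to replace_llama_attn()): every instance now
# runs the windowed variant.
def window_forward(self, x):
    return f"window attention on {x}"

Attention.forward = window_forward

layers = [Attention() for _ in range(4)]

# Rebind the saved function as a bound method of the last layer only; __get__
# turns the plain function into a method bound to that specific instance, which
# shadows the class-level patch for this one object.
last = layers[-1]
last.forward = origin_forward.__get__(last, type(last))

print([layer.forward("x") for layer in layers])
# ['window attention on x', 'window attention on x',
#  'window attention on x', 'full attention on x']
```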