LongLoRA implementation #8341

Open
wants to merge 3 commits into base: develop
68 changes: 63 additions & 5 deletions llm/data.py
@@ -45,16 +45,53 @@ def get_convert_example(model):
if base_model_prefix == "chatglm":
return convert_example_chatglm
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
return convert_example_common
return convert_example_common_meta_text
Contributor:
Suggestion: do not change the dataset reading logic. Preprocess the data into the format the llm models already expect first, and then load it the existing way.
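A minimal sketch of that suggestion (the script name, file paths, and JSON-lines layout are assumptions for illustration, not part of this PR): convert the raw "text" records into the {"src", "tgt"} format that llm/data.py already consumes, then load them through the existing pipeline.

# preprocess_meta_text.py (hypothetical helper, not in this PR)
import json

def convert_file(in_path, out_path):
    with open(in_path, "r", encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            text = json.loads(line)["text"]
            words = text.split(" ")
            # Mirror tokenize_example_meta_text: full text as source,
            # text minus its first word as target.
            example = {"src": text, "tgt": " ".join(words[1:]) if len(words) > 1 else ""}
            fout.write(json.dumps(example, ensure_ascii=False) + "\n")

convert_file("data/raw_train.jsonl", "data/train.json")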

else:
raise ValueError(
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
)


class DataFormatError(ValueError):
class DataFormatError( ValueError):
pass

def tokenize_example_meta_text(tokenizer, example, data_args):
if "text" in example:
text = example["text"]
source = text
words = text.split(' ')
source = ' '.join(words)
if len(words) > 1:
# remove the first word in the sentence
target = ' '.join(words[1:])
else:
target = ''
else:
raise DataFormatError(
f"Example format is wrong, please check: {example} or rewrite tokenize_example in data.py "
)
tokenized_source = tokenizer(
source,
max_length=data_args.src_length,
truncation=True,
truncation_side="left",
add_special_tokens=True,
)
tgt_max_length = data_args.max_length - len(tokenized_source["input_ids"])
tokenized_target = tokenizer(
target,
max_length=tgt_max_length,
truncation=True,
truncation_side="left",
add_special_tokens=False,
)
tokenized_target_input_ids = tokenized_target["input_ids"]
# Add eos_token_id at the end of sequence if the sentence is not truncated.
# Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
if len(tokenized_target_input_ids) < tgt_max_length:
tokenized_target_input_ids += [tokenizer.eos_token_id]
return tokenized_source, tokenized_target_input_ids
Contributor:
The reason for processing the data this way is not clear to me.
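For reference, a small illustration of the source/target split that tokenize_example_meta_text performs on a raw "text" record (the sample string is made up):

text = "the quick brown fox"
words = text.split(" ")
source = " ".join(words)       # "the quick brown fox" -> tokenized as the source
target = " ".join(words[1:])   # "quick brown fox"     -> tokenized as the target

Since convert_example_common_meta_text later concatenates the two and masks the source positions with -100, this appears to amount to next-token prediction over plain text, matching the plain-text fine-tuning LongLoRA uses for corpora such as proof-pile.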



def tokenize_example(tokenizer, example, data_args):
if "src" in example and "tgt" in example:
@@ -77,7 +114,7 @@ def tokenize_example(tokenizer, example, data_args):
target,
max_length=tgt_max_length,
truncation=True,
truncation_side="right",
truncation_side="left",
add_special_tokens=False,
)

@@ -86,10 +123,31 @@ def tokenize_example(tokenizer, example, data_args):
# Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
if len(tokenized_target_input_ids) < tgt_max_length:
tokenized_target_input_ids += [tokenizer.eos_token_id]

return tokenized_source, tokenized_target_input_ids


def convert_example_common_meta_text(example, tokenizer, data_args, is_test=True, intokens=False):
if tokenizer.chat_template is not None:
return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
tokenized_source, tokenized_target_input_ids = tokenize_example_meta_text(tokenizer, example, data_args)
if is_test:
return {
**tokenized_source,
"labels": tokenized_target_input_ids,
}
else:
input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
source_length = len(tokenized_source["input_ids"])
labels = [-100] * source_length + input_ids[source_length:]
# shift input_ids and labels
input_ids, labels = input_ids[:-1], labels[1:]
seq_length = len(input_ids)
features = {"input_ids": input_ids, "labels": labels}
if "position_ids" in tokenized_source:
features["position_ids"] = list(range(seq_length))
if intokens:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
return features

def tokenize_rounds_example(tokenizer, example, data_args):
"""tokenize multi-rounds examples with chat_template.json

79 changes: 59 additions & 20 deletions llm/finetune_generation.py
@@ -16,7 +16,7 @@
import sys
from dataclasses import dataclass, field
from functools import partial

import math
import paddle
from argument import (
DataArgument,
@@ -34,8 +34,9 @@
get_prefix_tuning_params,
init_chat_template,
)

from paddlenlp.data import DataCollatorForSeq2Seq
from glm.utils import GLMTrainer
# from llama_attn_replace_paddle import replace_llama_attn
from paddlenlp.data import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
@@ -45,7 +46,6 @@
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
Llama3Tokenizer,
LlamaTokenizer,
)
from paddlenlp.utils.log import logger
@@ -66,6 +66,14 @@ class FinetuneArguments(TrainingArguments):
default=0,
metadata={"help": "The steps use to control the learing rate."},
)
sparse: bool = field(
default=True,
metadata={"help": "The steps use to control the learing rate."},
)
trainable_params: str = field(
default="embed,norm",
metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
)


def read_local_dataset(path):
@@ -115,6 +123,7 @@ def main():
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)


# Load model
if training_args.fp16_opt_level == "O2":
if training_args.fp16:
@@ -193,10 +202,20 @@ def main():
# Config for model using dropout, such as GPT.
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = True
model_config.seq_length = data_args.max_length

#set RoPE scaling factor
orig_rope_scaling_factor = model_config.rope_scaling_factor
Contributor:
This will be deprecated; going forward it is recommended to use LongSequenceStrategies instead: https://github.com/PaddlePaddle/PaddleNLP/pull/8076/files

model.config.long_sequence_strategy_type = "attention_strategies"
model.config.long_sequence_strategy_name = "LinearScalingRotaryEmbedding"
model.config.long_sequence_init_args = {
    "head_dim": head_dim,
    "max_position_embeddings": max_position_embeddings,
    "rope_scaling_type": rope_scaling_type,
    "rope_scaling_factor": rope_scaling_factor,
}
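For context, the linear RoPE scaling the PR adds just below boils down to the following arithmetic (the lengths here are illustrative examples, not values taken from this PR):

import math

orig_ctx_len = 4096              # model_config.max_position_embeddings (example value)
orig_rope_scaling_factor = 1     # factor already present in the checkpoint config
max_length = 32768               # data_args.max_length for long-context fine-tuning
orig_ctx_len *= orig_rope_scaling_factor
scaling_factor = float(math.ceil(max_length / orig_ctx_len))  # -> 8.0
# The config then gets rope_scaling_factor=8.0 and rope_scaling_type="linear".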

orig_ctx_len = model_config.max_position_embeddings
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if data_args.max_length > orig_ctx_len:
scaling_factor = float(math.ceil(data_args.max_length / orig_ctx_len))
model_config.rope_scaling_factor = scaling_factor
model_config.rope_scaling_type = "linear"

if not training_args.autotuner_benchmark:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
@@ -233,7 +252,7 @@ def neft_post_hook(module, input, output):
if tokenizer.chat_template is not None:
data_args.eval_with_do_generation = False

if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
if isinstance(tokenizer, LlamaTokenizer):
tokenizer.pad_token_id = tokenizer.eos_token_id

if data_args.dataset_name_or_path is None:
@@ -297,11 +316,7 @@ def neft_post_hook(module, input, output):
else:
train_ds = None
if training_args.do_eval:
dev_ds = load_dataset(
"json",
data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
lazy=data_args.lazy,
)[0]
dev_ds = load_dataset('/home/mg/proof-pile',splits=["test"],cache_dir='proof-pile')[0]
else:
dev_ds = None
if quant_args.do_ptq or quant_args.do_gptq:
@@ -330,7 +345,7 @@ def neft_post_hook(module, input, output):
else:
train_ds = None
if training_args.do_eval:
dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["test"])[0]
else:
dev_ds = None
if quant_args.do_ptq or quant_args.do_gptq:
@@ -361,7 +376,6 @@ def neft_post_hook(module, input, output):

if training_args.pipeline_parallel_degree > 1:
from data import convert_example_common

trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args)
else:
trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args)
@@ -388,6 +402,7 @@ def neft_post_hook(module, input, output):
"`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset."
)
eval_intokens = False

dev_ds = (
dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
if dev_ds is not None
@@ -458,7 +473,7 @@ def neft_post_hook(module, input, output):

if model_args.lora:
if model_args.lora_path is None:
target_modules = get_lora_target_modules(model)
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
lora_config = LoRAConfig(
target_modules=target_modules,
Contributor:
lora_config already has a parameter for this, trainable_modules, e.g. [".*embed.*", ".*norm.*"]:
https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/peft/lora/lora_config.py#L47
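A minimal sketch of that suggestion (trainable_modules comes from the linked lora_config.py; the remaining arguments are the ones this PR already passes):

lora_config = LoRAConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    trainable_modules=[".*embed.*", ".*norm.*"],  # replaces the manual stop_gradient loop below
    r=model_args.lora_rank,
    # ... other arguments unchanged ...
)
model = LoRAModel(model, lora_config)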

r=model_args.lora_rank,
@@ -474,9 +489,27 @@
use_quick_lora=model_args.use_quick_lora,
)
model = LoRAModel(model, lora_config)

# model.mark_only_lora_as_trainable()
# model.print_trainable_parameters()
trainable_keywords = ["embed","norm"]
# set embedding and norm trainable
for name, param in model.named_parameters():
make_trainable = False
for keyword in trainable_keywords:
if keyword in name:
make_trainable = True
break
if make_trainable:
param.stop_gradient = False
model.config.use_cache = False

model.recompute_enable()
Contributor:
Set "recompute": "true" directly in the JSON config used to run finetune_generation.py instead.

for param in model.parameters():
if not param.stop_gradient and param.grad is None:
param.clear_gradient()
Contributor:
What is this used for?

else:
model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

model.print_trainable_parameters()

def compute_metrics_do_generation(eval_preds):
@@ -532,13 +565,18 @@ def compute_metrics_do_generation(eval_preds):
eval_dataset=dev_ds,
tokenizer=tokenizer,
compute_metrics=metrics,
# data_collator=DataCollatorForLanguageModeling(
# tokenizer=tokenizer,
# return_tensors="np",
# mlm=False,
# pad_to_multiple_of=data_args.max_length,
# ),
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
max_length=max_length,
padding=padding,
max_label_length=max_length,
return_tensors="np",
pad_to_multiple_of=data_args.pad_to_multiple_of,
max_length=max_length,
Contributor:
What is this change for?

padding=True,
pad_to_multiple_of=data_args.max_length,
),
do_generation=data_args.eval_with_do_generation,
callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
@@ -553,7 +591,8 @@
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# train_result = trainer.train(resume_from_checkpoint=checkpoint)
train_result = trainer.train()
Contributor:
? This line should not be changed.

if model_args.neftune:
neft_post_hook_handle.remove()
if training_args.benchmark: