LongLoRA implementation #8341
base: develop
File: data.py
@@ -45,16 +45,53 @@ def get_convert_example(model):
     if base_model_prefix == "chatglm":
         return convert_example_chatglm
     elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
-        return convert_example_common
+        return convert_example_common_meta_text
     else:
         raise ValueError(
             f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
         )


-class DataFormatError(ValueError):
+class DataFormatError( ValueError):
     pass

+def tokenize_example_meta_text(tokenizer, example, data_args):
+    if "text" in example:
+        text = example["text"]
+        source = text
+        words = text.split(' ')
+        source = ' '.join(words)
+        if len(words) > 1:
+            # remove the first word in the sentence
+            target = ' '.join(words[1:])
+        else:
+            target = ''
+    else:
+        raise DataFormatError(
+            f"Example format is wrong, please check: {example} or rewrite tokenize_example in data.py "
+        )
+    tokenized_source = tokenizer(
+        source,
+        max_length=data_args.src_length,
+        truncation=True,
+        truncation_side="left",
+        add_special_tokens=True,
+    )
+    tgt_max_length = data_args.max_length - len(tokenized_source["input_ids"])
+    tokenized_target = tokenizer(
+        target,
+        max_length=tgt_max_length,
+        truncation=True,
+        truncation_side="left",
+        add_special_tokens=False,
+    )
+    tokenized_target_input_ids = tokenized_target["input_ids"]
+    # Add eos_token_id at the end of sequence if the sentence is not truncated.
+    # Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
+    if len(tokenized_target_input_ids) < tgt_max_length:
+        tokenized_target_input_ids += [tokenizer.eos_token_id]
+    return tokenized_source, tokenized_target_input_ids
Review comment: I don't understand the reason for processing the data this way.
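For readers puzzling over the same question, the new preprocessing amounts to this: the whole document is kept as the source and everything after the first word becomes the target, i.e. a plain language-modeling pair rather than an instruction-style src/tgt pair. A minimal sketch with a made-up example (illustrative only, not code from the PR):

```python
# What tokenize_example_meta_text does to {"text": ...} before tokenization.
text = "The quick brown fox jumps"
words = text.split(" ")
source = " ".join(words)                                # the full text
target = " ".join(words[1:]) if len(words) > 1 else ""  # drop the first word
print(source)  # The quick brown fox jumps
print(target)  # quick brown fox jumps
```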

 def tokenize_example(tokenizer, example, data_args):
     if "src" in example and "tgt" in example:
@@ -77,7 +114,7 @@ def tokenize_example(tokenizer, example, data_args):
         target,
         max_length=tgt_max_length,
         truncation=True,
-        truncation_side="right",
+        truncation_side="left",
         add_special_tokens=False,
     )
@@ -86,10 +123,31 @@ def tokenize_example(tokenizer, example, data_args):
     # Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
     if len(tokenized_target_input_ids) < tgt_max_length:
         tokenized_target_input_ids += [tokenizer.eos_token_id]

     return tokenized_source, tokenized_target_input_ids


+def convert_example_common_meta_text(example, tokenizer, data_args, is_test=True, intokens=False):
+    if tokenizer.chat_template is not None:
+        return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
+    tokenized_source, tokenized_target_input_ids = tokenize_example_meta_text(tokenizer, example, data_args)
+    if is_test:
+        return {
+            **tokenized_source,
+            "labels": tokenized_target_input_ids,
+        }
+    else:
+        input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
+        source_length = len(tokenized_source["input_ids"])
+        labels = [-100] * source_length + input_ids[source_length:]
+        # shift input_ids and labels
+        input_ids, labels = input_ids[:-1], labels[1:]
+        seq_length = len(input_ids)
+        features = {"input_ids": input_ids, "labels": labels}
+        if "position_ids" in tokenized_source:
+            features["position_ids"] = list(range(seq_length))
+        if intokens:
+            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+        return features

 def tokenize_rounds_example(tokenizer, example, data_args):
     """tokenize multi-rounds examples with chat_template.json

File: finetune_generation.py
@@ -16,7 +16,7 @@
 import sys
 from dataclasses import dataclass, field
 from functools import partial

+import math
 import paddle
 from argument import (
     DataArgument,
@@ -34,8 +34,9 @@
     get_prefix_tuning_params,
     init_chat_template,
 )

-from paddlenlp.data import DataCollatorForSeq2Seq
+from glm.utils import GLMTrainer
+# from llama_attn_replace_paddle import replace_llama_attn
+from paddlenlp.data import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
 from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
 from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
@@ -45,7 +46,6 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    Llama3Tokenizer,
     LlamaTokenizer,
 )
 from paddlenlp.utils.log import logger
@@ -66,6 +66,14 @@ class FinetuneArguments(TrainingArguments):
         default=0,
         metadata={"help": "The steps use to control the learing rate."},
     )
+    sparse: bool = field(
+        default=True,
+        metadata={"help": "The steps use to control the learing rate."},
+    )
+    trainable_params: str = field(
+        default="embed,norm",
+        metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
+    )


 def read_local_dataset(path):
@@ -115,6 +123,7 @@ def main():
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )

+
     # Load model
     if training_args.fp16_opt_level == "O2":
         if training_args.fp16:
@@ -193,10 +202,20 @@ def main():
     # Config for model using dropout, such as GPT.
     model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
     model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

     model_config.sep_parallel_degree = training_args.sep_parallel_degree
     model_config.tensor_parallel_output = True
     model_config.seq_length = data_args.max_length

+    # set RoPE scaling factor
+    orig_rope_scaling_factor = model_config.rope_scaling_factor
Review comment: This will be deprecated; going forward, it is recommended to use LongSequenceStrategies (https://github.com/PaddlePaddle/PaddleNLP/pull/8076/files), e.g. model.config.long_sequence_strategy_type = "attention_strategies"
+    orig_ctx_len = model_config.max_position_embeddings
+    if orig_ctx_len:
+        orig_ctx_len *= orig_rope_scaling_factor
+    if data_args.max_length > orig_ctx_len:
+        scaling_factor = float(math.ceil(data_args.max_length / orig_ctx_len))
+        model_config.rope_scaling_factor = scaling_factor
+        model_config.rope_scaling_type = "linear"

     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
@@ -233,7 +252,7 @@ def neft_post_hook(module, input, output):
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False

-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
+    if isinstance(tokenizer, LlamaTokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id

     if data_args.dataset_name_or_path is None:
@@ -297,11 +316,7 @@ def neft_post_hook(module, input, output):
     else:
         train_ds = None
     if training_args.do_eval:
-        dev_ds = load_dataset(
-            "json",
-            data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
-            lazy=data_args.lazy,
-        )[0]
+        dev_ds = load_dataset('/home/mg/proof-pile',splits=["test"],cache_dir='proof-pile')[0]
     else:
         dev_ds = None
     if quant_args.do_ptq or quant_args.do_gptq:
@@ -330,7 +345,7 @@ def neft_post_hook(module, input, output):
     else:
         train_ds = None
     if training_args.do_eval:
-        dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
+        dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["test"])[0]
     else:
         dev_ds = None
     if quant_args.do_ptq or quant_args.do_gptq:
@@ -361,7 +376,6 @@ def neft_post_hook(module, input, output):

     if training_args.pipeline_parallel_degree > 1:
-        from data import convert_example_common

         trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args)
     else:
         trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args)
@@ -388,6 +402,7 @@ def neft_post_hook(module, input, output):
             "`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset."
         )
         eval_intokens = False
+
     dev_ds = (
         dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
         if dev_ds is not None
@@ -458,7 +473,7 @@ def neft_post_hook(module, input, output):

     if model_args.lora:
         if model_args.lora_path is None:
-            target_modules = get_lora_target_modules(model)
+            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
             lora_config = LoRAConfig(
                 target_modules=target_modules,
Review comment: lora_config itself already has a parameter called …
                 r=model_args.lora_rank,
@@ -474,9 +489,27 @@ def neft_post_hook(module, input, output):
                 use_quick_lora=model_args.use_quick_lora,
             )
             model = LoRAModel(model, lora_config)

+            # model.mark_only_lora_as_trainable()
+            # model.print_trainable_parameters()
+            trainable_keywords = ["embed","norm"]
+            # set embedding and norm trainable
+            for name, param in model.named_parameters():
+                make_trainable = False
+                for keyword in trainable_keywords:
+                    if keyword in name:
+                        make_trainable = True
+                        break
+                if make_trainable:
+                    param.stop_gradient = False
+            model.config.use_cache = False
+
+            model.recompute_enable()
Review comment: Just configure "recompute": "true" in the JSON script used to run finetune_generation.py.
+            for param in model.parameters():
+                if not param.stop_gradient and param.grad is None:
+                    param.clear_gradient()
Review comment: What is this for?
         else:
             model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

         model.print_trainable_parameters()

     def compute_metrics_do_generation(eval_preds):
@@ -532,13 +565,18 @@ def compute_metrics_do_generation(eval_preds):
         eval_dataset=dev_ds,
         tokenizer=tokenizer,
         compute_metrics=metrics,
+        # data_collator=DataCollatorForLanguageModeling(
+        #     tokenizer=tokenizer,
+        #     return_tensors="np",
+        #     mlm=False,
+        #     pad_to_multiple_of=data_args.max_length,
+        # ),
         data_collator=DataCollatorForSeq2Seq(
             tokenizer=tokenizer,
-            max_length=max_length,
-            padding=padding,
             max_label_length=max_length,
             return_tensors="np",
-            pad_to_multiple_of=data_args.pad_to_multiple_of,
+            max_length=max_length,
Review comment: What is this change for?
+            padding=True,
+            pad_to_multiple_of=data_args.max_length,
         ),
         do_generation=data_args.eval_with_do_generation,
         callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
@@ -553,7 +591,8 @@ def compute_metrics_do_generation(eval_preds):
         checkpoint = training_args.resume_from_checkpoint
     elif last_checkpoint is not None:
         checkpoint = last_checkpoint
-    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    # train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    train_result = trainer.train()
Review comment: ? Don't change this line.
     if model_args.neftune:
         neft_post_hook_handle.remove()
     if training_args.benchmark:
Review comment: I suggest not changing the dataset-reading logic. First preprocess the data to match the LLM model format, then load it in the existing way.