support sparse mask for sft #8412

Open · wants to merge 1 commit into base: develop
1 change: 1 addition & 0 deletions llm/argument.py
@@ -48,6 +48,7 @@ class DataArgument:
    dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
    task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
    zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
    use_sparse_mask: bool = field(default=False, metadata={"help": "Whether to use Column Sparse Mask"})
    src_length: int = field(default=1024, metadata={"help": "The maximum length of source(context) tokens."})
    max_length: int = field(
        default=2048,
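Note: use_sparse_mask defaults to False, so existing configs are unaffected. When enabled, the SFT pipeline swaps in a sparse-mask example converter (llm/data.py) and a sparse-mask packing dataset (paddlenlp/datasets/intokens_dataset.py); both changes are shown below.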
24 changes: 24 additions & 0 deletions llm/data.py
@@ -188,6 +188,30 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens

        return features

def convert_example_using_sparse_mask(example, tokenizer, data_args, is_test=True, intokens=False):
    if tokenizer.chat_template is not None:
        # TODO: route to convert_rounds_example_common once sparse mask supports multi-round data
        raise NotImplementedError("use_sparse_mask does not support chat_template (multi-round) examples yet")

    tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
    if is_test:
        return {
            **tokenized_source,
            "labels": tokenized_target_input_ids,
        }
    else:
        input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
        source_length = len(tokenized_source["input_ids"])
        labels = [-100] * source_length + input_ids[source_length:]
        # shift input_ids and labels
        input_ids, labels = input_ids[:-1], labels[1:]
        seq_length = len(input_ids)
        features = {"input_ids": input_ids, "labels": labels}
        if "position_ids" in tokenized_source:
            features["position_ids"] = list(range(seq_length))

        return features


def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
    """convert multi-rounds conversation example
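A minimal sketch of the label shift performed above, with toy token ids (not real tokenizer output):

source_ids = [1, 2, 3]                           # prompt tokens
target_ids = [4, 5, 6]                           # response tokens
input_ids = source_ids + target_ids              # [1, 2, 3, 4, 5, 6]
labels = [-100] * len(source_ids) + input_ids[len(source_ids):]  # [-100, -100, -100, 4, 5, 6]
input_ids, labels = input_ids[:-1], labels[1:]   # [1, 2, 3, 4, 5] / [-100, -100, 4, 5, 6]
# After the shift, token 3 is trained to predict 4, 4 to predict 5, and so on;
# prompt positions stay masked at -100 and contribute no loss.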
19 changes: 16 additions & 3 deletions llm/finetune_generation.py
@@ -35,7 +35,7 @@
)

 from paddlenlp.data import DataCollatorForSeq2Seq
-from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
+from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, InTokensSparseMaskMapDataset, load_dataset
 from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint
@@ -182,6 +182,7 @@ def neft_post_hook(module, input, output):

    # Load tokenizer & dataset
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
    # init chat_template for tokenizer
    init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)

@@ -315,7 +316,11 @@ def neft_post_hook(module, input, output):
)
train_ds = train_ds.skip(consumed_samples)

-    if training_args.pipeline_parallel_degree > 1:
+    if data_args.use_sparse_mask:
+        from data import convert_example_using_sparse_mask
+
+        trans_func = partial(convert_example_using_sparse_mask, tokenizer=tokenizer, data_args=data_args)
+    elif training_args.pipeline_parallel_degree > 1:
         from data import convert_example_common
 
         trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args)
@@ -350,11 +355,14 @@ def neft_post_hook(module, input, output):
else None
)
     if data_args.zero_padding:
-        if data_args.lazy:
+        if data_args.use_sparse_mask:
+            intoken_dataset = InTokensSparseMaskMapDataset
+        elif data_args.lazy:
             intoken_dataset = InTokensIterableDataset
         else:
             intoken_dataset = InTokensMapDataset
         logger.info("Creating Zero Padding Data Stream. This may take a few minutes.")
         train_ds = (
             intoken_dataset(
                 train_ds,
@@ -471,6 +479,10 @@ def compute_metrics_do_generation(eval_preds):
    else:
        metrics = compute_metrics

    padding = "max_length"
    max_length = data_args.max_length
    # expose max_length on training_args so the new tokens-per-second metric can read it
    training_args.max_length = max_length
    # the sparse-mask stream carries attn_mask_start_row_indices, so no dense attention mask is needed
    return_attention_mask = not data_args.use_sparse_mask
    trainer = CausalLMTrainer(
        model=model,
        args=training_args,
@@ -484,6 +496,7 @@ def compute_metrics_do_generation(eval_preds):
            padding=padding,
            max_label_length=max_length,
            return_tensors="np",
            return_attention_mask=return_attention_mask,
        ),
        do_generation=data_args.eval_with_do_generation,
        callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
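Note: with use_sparse_mask enabled, the collator is told not to build a dense attention mask, because the packed stream already carries attn_mask_start_row_indices (see the dataset changes below). training_args.max_length is also set here so that the trainer's new tokens-per-second metric can use it.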
8 changes: 4 additions & 4 deletions llm/llama/sft_argument.json
@@ -26,7 +26,7 @@
"save_total_limit": 1,
"tensor_parallel_degree": 4,
"pipeline_parallel_degree": 1,
"intokens": true,
"zero_padding": false,
"use_flash_attention": false
}
"zero_padding": true,
"use_sparse_mask": true,
"use_flash_attention": true
}
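Note: use_sparse_mask only takes effect together with "zero_padding": true, since the sparse-mask dataset is selected inside the zero-padding branch of finetune_generation.py. use_flash_attention is flipped on as well, presumably because the column-sparse row indices are meant to feed the flash-attention kernel rather than a dense mask.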
79 changes: 79 additions & 0 deletions paddlenlp/datasets/intokens_dataset.py
@@ -57,6 +57,44 @@ def _pad_batch_records(cls, batch_records):
batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist()
return batched_features

class InTokensSparseMask:
    required_input_keys = ["input_ids", "labels"]
    required_output_keys = ["input_ids", "labels", "attn_mask_start_row_indices"]
    # Only the following keys are supported for InTokens. Keys outside of this set will be ignored.
    supported_input_keys = ["input_ids", "labels", "position_ids"]

    @classmethod
    def _pad_batch_records(cls, batch_records):
        # Only consider supported input keys
        input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]

        # Check required keys
        for key in cls.required_input_keys:
            if key not in input_keys:
                raise ValueError(f"feature `{key}` is required for InTokensDataset")
        # Output features must include all required output keys
        for key in cls.required_output_keys:
            if key not in input_keys:
                input_keys.append(key)

        batched_features = {key: [] for key in input_keys}
        cur_len_so_far = 0
        for record in batch_records:
            batched_features["input_ids"].extend(record["input_ids"])
            batched_features["labels"].extend(record["labels"])
            seq_length = len(record["input_ids"])
            cur_len_so_far += seq_length
            # Every column of this record stores the cumulative end offset of its segment,
            # i.e. the first row that may no longer attend to it.
            batched_features["attn_mask_start_row_indices"].extend(seq_length * [cur_len_so_far])
            # NOTE: position_ids is optional and not required by every model.
            # We append instead of extend here to accommodate 2D position ids.
            if "position_ids" in record:
                batched_features["position_ids"].append(record["position_ids"])
        # convert to 2-D [num_head(1), seq_length]
        batched_features["attn_mask_start_row_indices"] = np.array(
            batched_features["attn_mask_start_row_indices"], dtype=np.int32
        ).reshape([1, -1])
        if "position_ids" in batched_features:
            # Accommodate both 1D and 2D position ids
            batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist()
        return batched_features
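A worked example of the indices built above, as a minimal standalone sketch (toy records; real ones come from convert_example_using_sparse_mask):

import numpy as np

records = [{"input_ids": [11, 12, 13]}, {"input_ids": [21, 22]}]
indices, cur_len_so_far = [], 0
for record in records:
    seq_length = len(record["input_ids"])
    cur_len_so_far += seq_length
    indices.extend(seq_length * [cur_len_so_far])
print(np.array(indices, dtype=np.int32).reshape([1, -1]))
# [[3 3 3 5 5]] -- columns 0-2 (first sequence) are masked from row 3 onward,
# columns 3-4 (second sequence) from row 5 onward, so attention never crosses packed segments.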

class InTokensMapDataset(InTokens, Dataset):
    def __init__(self, data, tokenizer, max_length):
@@ -131,3 +169,44 @@ def __iter__(self):
        if batch_records:
            padded_list = self._pad_batch_records(batch_records)
            yield padded_list

class InTokensSparseMaskMapDataset(InTokensSparseMask, Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.new_data = self._create_intokens_data(data)

    def _create_intokens_data(self, data):
        batch_records, max_len = [], 0
        cur_len_so_far = 0

        total_data = []
        for i in range(len(data)):
            record = data[i]
            max_len = max(max_len, len(record["input_ids"]))
            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
            if to_append:
                batch_records.append(record)
                cur_len_so_far += len(record["input_ids"])
            else:
                # exceeds max_length: flush the current pack
                padded_list = self._pad_batch_records(batch_records)
                total_data.append(padded_list)
                # reset
                batch_records, max_len = [], 0
                cur_len_so_far = 0
                # start a new pack with the current record
                batch_records.append(record)
                cur_len_so_far += len(record["input_ids"])

        # flush the remaining records
        if batch_records:
            padded_list = self._pad_batch_records(batch_records)
            total_data.append(padded_list)
        return total_data

    def __getitem__(self, idx):
        return self.new_data[idx]

    def __len__(self):
        return len(self.new_data)
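A hypothetical usage sketch (train_ds and tokenizer are placeholders; the class name and returned keys are as introduced in this PR):

from paddlenlp.datasets import InTokensSparseMaskMapDataset

packed_ds = InTokensSparseMaskMapDataset(train_ds, tokenizer=tokenizer, max_length=2048)
sample = packed_ds[0]
# sample["input_ids"] and sample["labels"] are flat lists covering every packed sequence;
# sample["attn_mask_start_row_indices"] has shape [1, seq_length] and stands in for the dense attention mask.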
16 changes: 16 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -264,6 +264,8 @@ def __init__(
        self.args = args
        self.is_in_train = False
        # self.do_grad_scaling = args.fp16
        self._tokens_per_sec_per_card_buffer = []
        self.skip_first_tokens_per_sec_per_card_buffer = True

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -1226,6 +1228,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)


total_train_batch_size = (
self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
)
@@ -1239,6 +1242,13 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                )
            )

            # NOTE: assumes every sample is padded/packed to exactly max_length tokens;
            # training_args.max_length is set in finetune_generation.py.
            logs["tokens_per_sec_per_card"] = round(
                self.args.max_length * logs["interval_samples_per_second"] / self.args.world_size, 2
            )

            self._tokens_per_sec_per_card_buffer.append(logs["tokens_per_sec_per_card"])
            logs["tokens_per_sec_per_card_average"] = round(np.mean(self._tokens_per_sec_per_card_buffer), 2)

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self._globalstep_last_start_time = time.time()
@@ -1261,6 +1271,10 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,

            self.log(logs, **kwargs)

            # Drop the first 5 measurements once, so warmup steps do not skew the running average.
            if self.skip_first_tokens_per_sec_per_card_buffer and len(self._tokens_per_sec_per_card_buffer) == 5:
                self._tokens_per_sec_per_card_buffer = []
                self.skip_first_tokens_per_sec_per_card_buffer = False

        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.optimizer, GroupShardedOptimizerStage2) and self.optimizer._broadcast_overlap:
@@ -1283,6 +1297,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
            self._save_checkpoint(model, metrics=metrics)
            logger.info(f"{self.runtime_timer.log()}")
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            # Saving stalls the training loop, so restart the throughput measurement afterwards.
            self._tokens_per_sec_per_card_buffer = []
            self.skip_first_tokens_per_sec_per_card_buffer = True

    def _get_learning_rate(self):
        return self.optimizer.get_lr()
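A quick sanity check of the throughput formula above, with assumed numbers (not from this PR):

max_length = 2048                    # data_args.max_length, patched onto training_args
interval_samples_per_second = 16.0   # reported by the trainer between logging steps
world_size = 8                       # number of cards
tokens_per_sec_per_card = round(max_length * interval_samples_per_second / world_size, 2)
print(tokens_per_sec_per_card)       # 4096.0

Since padding="max_length", every sample contributes exactly max_length tokens, so samples per second converts directly to tokens per second.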