Support all types with --disable_group_texts 1
- Data types such as "text2text", as well as future data types, are now supported
- Fix a bug that occurs when model_max_length is not provided (in which case it is effectively +inf)
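
For context, a minimal sketch of the failure mode the second point refers to, assuming standard Hugging Face transformers behavior (the model name below is a placeholder, not from this commit): when a tokenizer config omits model_max_length, transformers substitutes a huge sentinel value, so asking the tokenizer to pad to "max_length" effectively means padding to +inf.

```python
from transformers import AutoTokenizer

# "some/model" is a placeholder; the behavior applies to any tokenizer
# whose config does not set model_max_length.
tokenizer = AutoTokenizer.from_pretrained("some/model")

# transformers falls back to VERY_LARGE_INTEGER (int(1e30)) when
# model_max_length is missing, i.e. effectively +inf.
print(tokenizer.model_max_length)  # e.g. 1000000000000000019884624838656

# With padding="max_length", the tokenizer would try to pad every sample
# toward that sentinel. The commit instead caps the target length:
block_size = 512
max_length = min(block_size, tokenizer.model_max_length)  # -> 512
```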
research4pan committed Sep 22, 2023
1 parent 6fae940 commit e90101c
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions src/lmflow/models/hf_decoder_model.py
@@ -500,9 +500,6 @@ def tokenize(self, dataset, add_special_tokens=True, *args, **kwargs):
         if model_args.use_lora or data_args.disable_group_texts:
             use_truncation = True
 
-        # Whether to pad short sequences to max_length
-        padding = "max_length" if data_args.disable_group_texts else False
-
         def tokenize_function(examples):
             num_example = len(examples[column_names[0]])
             token_dict = {
@@ -516,7 +513,6 @@ def tokenize_function(examples):
                         examples[column_name],
                         add_special_tokens=add_special_tokens,
                         truncation=use_truncation,
-                        padding=padding,
                     )
 
                     if column_name in label_columns:
@@ -536,6 +532,34 @@ def tokenize_function(examples):
                             )
                             token_dict["labels"][i].extend(labels[i])
 
+            if data_args.disable_group_texts:
+                for i in range(num_example):
+                    block_size = data_args.block_size
+                    max_length = min(block_size, self.get_max_length())
+                    pad_length = max_length - len(token_dict["input_ids"][i])
+                    if block_size < self.get_max_length():
+                        logger.warning(
+                            f"block_size {block_size} < model_max_length"
+                            f" {self.get_max_length()}, use block_size"
+                            " for maximum tokenized sequence length"
+                        )
+                    if pad_length < 0:
+                        # Truncates too long samples
+                        for key in ["input_ids", "attention_mask", "labels"]:
+                            token_dict[key][i] = token_dict[key][i][:pad_length]
+                    else:
+                        # Pads too short samples
+                        pad_token_id = self.tokenizer.pad_token_id
+                        token_dict["input_ids"][i].extend(
+                            [pad_token_id for _ in range(pad_length)]
+                        )
+                        token_dict["attention_mask"][i].extend(
+                            [0 for _ in range(pad_length)]
+                        )
+                        token_dict["labels"][i].extend(
+                            [-100 for _ in range(pad_length)]
+                        )
+
             # clm input could be much much longer than block_size
             if "Token indices sequence length is longer than the" in cl.out:
                 tok_logger.warning(
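
To see the added logic in isolation, here is a self-contained sketch of the per-sample pad/truncate step (the function name, the IGNORE_INDEX constant, and the toy batch are illustrative stand-ins; in the commit the values come from self.get_max_length() and self.tokenizer.pad_token_id):

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def pad_or_truncate(token_dict, block_size, model_max_length, pad_token_id):
    # Cap the target length by both block_size and the model's limit
    max_length = min(block_size, model_max_length)
    for i in range(len(token_dict["input_ids"])):
        pad_length = max_length - len(token_dict["input_ids"][i])
        if pad_length < 0:
            # Sample is too long: cut every field down to max_length
            for key in ["input_ids", "attention_mask", "labels"]:
                token_dict[key][i] = token_dict[key][i][:pad_length]
        else:
            # Sample is too short: pad inputs, mask out the padding,
            # and mark padded labels so the loss ignores them
            token_dict["input_ids"][i] += [pad_token_id] * pad_length
            token_dict["attention_mask"][i] += [0] * pad_length
            token_dict["labels"][i] += [IGNORE_INDEX] * pad_length
    return token_dict

# Example: one short and one long sample with block_size 4
batch = {
    "input_ids": [[5, 6], [1, 2, 3, 4, 7, 8]],
    "attention_mask": [[1, 1], [1, 1, 1, 1, 1, 1]],
    "labels": [[5, 6], [1, 2, 3, 4, 7, 8]],
}
out = pad_or_truncate(batch, block_size=4, model_max_length=1024, pad_token_id=0)
print(out["input_ids"])       # [[5, 6, 0, 0], [1, 2, 3, 4]]
print(out["attention_mask"])  # [[1, 1, 0, 0], [1, 1, 1, 1]]
print(out["labels"])          # [[5, 6, -100, -100], [1, 2, 3, 4]]
```

Note the slicing trick: when pad_length is negative, token_dict[key][i][:pad_length] drops exactly -pad_length trailing tokens, so the same variable drives both truncation and padding.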
