[Tokenization] RM text_only and text2text data support
Yizhen committed Jun 12, 2024
1 parent d18aad4 commit 58ec809
Showing 2 changed files with 135 additions and 12 deletions.
46 changes: 35 additions & 11 deletions src/lmflow/models/hf_text_regression_model.py
@@ -31,10 +31,12 @@
from lmflow.models.interfaces.tunable import Tunable
from lmflow.models.hf_model_mixin import HFModelMixin
from lmflow.models.text_regression_model import TextRegressionModel
from lmflow.tokenization.hf_text_regression_model import tokenize_function
from lmflow.tokenization.hf_text_regression_model import paired_conversation_tokenize_function, tokenize_function
from lmflow.utils.conversation_template import PRESET_TEMPLATES
from lmflow.utils.constants import (
PAIRED_CONVERSATION_DATASET_DESCRIPTION,
TEXT2TEXT_DATASET_DESCRIPTION,
TEXT_ONLY_DATASET_DESCRIPTION,
CONVERSATION_ROLE_NAMES,
)

@@ -135,14 +137,28 @@ def tokenize(
raw_datasets = dataset
hf_raw_datasets = dataset.get_backend_dataset()
column_names = list(hf_raw_datasets.features) # e.g. for paired conversation, these would be 'chosen' and 'rejected'

# Since tokenize_function will be pickled, force the logger to load
# beforehand to avoid a _LazyModule error in the datasets Hasher.
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

data_args = raw_datasets.get_data_args()

if dataset_type == "paired_conversation":

# Requires three pieces of information for tokenizing different datasets:
# 1) Which fields require tokenization, e.g.
# "text2float": "text", but not "float"
# "text2text": both "input" and "output"
# 2) How the tokenized sequences will be concatenated together, e.g.
# "text_only": "text" -> "text"
# "text2text": "input", "output" -> "input" + "output"
# 3) Which fields require loss in the final computation, e.g.
# "text_only": "text"
# "text2text": "output" only
tokenized_column_order = None # Handles 1) and 2)
label_columns = None # Handles 3)
if dataset_type == "text_only":
tokenized_column_order = ["text"]
label_columns = ["text"]
elif dataset_type == "text2text":
tokenized_column_order = ["input", "output"]
label_columns = ["output"]
add_special_tokens = False
elif dataset_type == "paired_conversation":
if data_args.conversation_template:
if data_args.conversation_template in PRESET_TEMPLATES.keys():
conversation_template = PRESET_TEMPLATES[data_args.conversation_template]
@@ -159,21 +175,29 @@ def tokenize(
raise NotImplementedError(
f"Dataset type \"{dataset_type}\" is not supported, currently"
" only support following data types for HFTextRegressionModel:\n"
f" {PAIRED_CONVERSATION_DATASET_DESCRIPTION}\n"
f" 1) {TEXT_ONLY_DATASET_DESCRIPTION}\n"
f" 2) {TEXT2TEXT_DATASET_DESCRIPTION}\n"
f" 3) {PAIRED_CONVERSATION_DATASET_DESCRIPTION}\n"
)

# Whether to truncate long sequences to fit into max_length
use_truncation = False
if model_args.use_lora or data_args.disable_group_texts:
use_truncation = True

tokenize_fn = tokenize_function
tokenize_fn = paired_conversation_tokenize_function if "conversation" in dataset_type else tokenize_function
tokenize_fn_kwargs = {
"data_args": data_args,
"tokenizer": self.tokenizer,
"column_names": column_names,
"conversation_template": conversation_template
}
if "conversation" in dataset_type:
tokenize_fn_kwargs["conversation_template"] = conversation_template
else:
tokenize_fn_kwargs["label_columns"] = label_columns
tokenize_fn_kwargs["tokenized_column_order"] = tokenized_column_order
tokenize_fn_kwargs["add_special_tokens"] = add_special_tokens
tokenize_fn_kwargs["use_truncation"] = use_truncation

tokenize_kwargs = {}
if not data_args.streaming:
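For reference, the three dataset layouts this dispatch accepts could look as follows. This is a sketch assuming LMFlow's JSON dataset convention of a top-level "type" plus an "instances" list; the conversation message schema shown for paired_conversation is illustrative, not taken from this diff:

sample_text_only = {
    "type": "text_only",
    "instances": [{"text": "The quick brown fox jumps over the lazy dog."}],
}
sample_text2text = {
    "type": "text2text",
    "instances": [{"input": "Translate to French: hello", "output": "bonjour"}],
}
sample_paired_conversation = {
    "type": "paired_conversation",
    "instances": [{
        # 'chosen' and 'rejected' match the column names noted above;
        # the message format is a hypothetical example.
        "chosen": [{"role": "user", "content": "Hi"},
                   {"role": "assistant", "content": "Hello! How can I help?"}],
        "rejected": [{"role": "user", "content": "Hi"},
                     {"role": "assistant", "content": "What do you want?"}],
    }],
}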
101 changes: 100 additions & 1 deletion src/lmflow/tokenization/hf_text_regression_model.py
@@ -17,7 +17,7 @@
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")


def tokenize_function(
def paired_conversation_tokenize_function(
examples,
data_args,
tokenizer,
@@ -116,4 +116,103 @@ def blocking_paired(
" for maximum tokenized sequence length"
)

return token_dict


def blocking(
token_dict: Dict,
block_size: int,
model_max_length: int,
pad_token_id: int,
) -> Dict:
block_size_warning_num = 0
num_example = len(token_dict[list(token_dict.keys())[0]])
for i in range(num_example):
max_length = min(block_size, model_max_length)
pad_length = max_length - len(token_dict["input_ids"][i])
if block_size < model_max_length:
block_size_warning_num += 1
if pad_length < 0:
# Truncates too long samples
for key in ["input_ids", "attention_mask", "labels"]:
token_dict[key][i] = token_dict[key][i][:pad_length]
else:
# Pads too short samples
token_dict["input_ids"][i].extend(
[pad_token_id for _ in range(pad_length)]
)
token_dict["attention_mask"][i].extend(
[0 for _ in range(pad_length)]
)
token_dict["labels"][i].extend(
[-100 for _ in range(pad_length)]
)
if block_size_warning_num > 0:
logger.warning(
f"There are {block_size_warning_num} of {num_example} samples where"
f"block_size {block_size} < model_max_length"
f" {model_max_length}, use block_size"
" for maximum tokenized sequence length"
)

return token_dict
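
To make the padding and truncation behaviour of blocking concrete, here is a toy call with invented token values. With block_size=4, the first sample is right-padded and the second truncated:

toy = {
    "input_ids":      [[1, 2],       [3, 4, 5, 6, 7]],
    "attention_mask": [[1, 1],       [1, 1, 1, 1, 1]],
    "labels":         [[1, 2],       [3, 4, 5, 6, 7]],
}
out = blocking(toy, block_size=4, model_max_length=512, pad_token_id=0)
# out["input_ids"]      == [[1, 2, 0, 0],       [3, 4, 5, 6]]
# out["attention_mask"] == [[1, 1, 0, 0],       [1, 1, 1, 1]]
# out["labels"]         == [[1, 2, -100, -100], [3, 4, 5, 6]]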


def tokenize_function(
examples,
data_args,
tokenizer,
column_names,
label_columns,
tokenized_column_order,
add_special_tokens,
use_truncation,
) -> Dict:
"""Handels text_only and text2text datasets tokenization
"""
num_example = len(examples[column_names[0]])
token_dict = {
"input_ids": [[] for _ in range(num_example)],
"attention_mask": [[] for _ in range(num_example)],
"labels": [[] for _ in range(num_example)],
}
with CaptureLogger(tok_logger) as cl:
for column_name in tokenized_column_order:
encoding = tokenizer(
examples[column_name],
add_special_tokens=add_special_tokens,
truncation=use_truncation,
)

if column_name in label_columns:
labels = encoding["input_ids"].copy()
else:
labels = [
[-100] * len(encoding["input_ids"][i])
for i in range(num_example)
]

for i in range(num_example):
token_dict["input_ids"][i].extend(
encoding["input_ids"][i]
)
token_dict["attention_mask"][i].extend(
encoding["attention_mask"][i]
)
token_dict["labels"][i].extend(labels[i])

if data_args.disable_group_texts:
token_dict = blocking(
token_dict=token_dict,
block_size=data_args.block_size,
model_max_length=tokenizer.model_max_length,
pad_token_id=tokenizer.pad_token_id,
)

# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
" before being passed to the model."
)
return token_dict
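
As a usage sketch for the text2text path (the StubArgs class is a hypothetical stand-in for the real data_args and carries only the one field read here; any Hugging Face tokenizer works):

from transformers import AutoTokenizer
from lmflow.tokenization.hf_text_regression_model import tokenize_function

tokenizer = AutoTokenizer.from_pretrained("gpt2")
examples = {"input": ["Q: 2 + 2 = "], "output": ["4"]}

class StubArgs:  # hypothetical stand-in for data_args
    disable_group_texts = False

out = tokenize_function(
    examples,
    data_args=StubArgs(),
    tokenizer=tokenizer,
    column_names=["input", "output"],
    label_columns=["output"],
    tokenized_column_order=["input", "output"],
    add_special_tokens=False,
    use_truncation=False,
)
# out["input_ids"] is the "input" ids followed by the "output" ids;
# out["labels"] is -100 over the "input" span and the real ids over the
# "output" span, so only the output tokens contribute to the loss.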
