Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformers] Prompt masking #2192

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/sparseml/export/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import os.path
from collections import OrderedDict
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional
from typing import OrderedDict as OrderedDictType
from typing import Union

import numpy
import onnx
Expand Down
20 changes: 16 additions & 4 deletions src/sparseml/transformers/finetune/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class TextGenerationDataset(RegistryMixin):
"""

PROMPT_KEY = "prompt"
MASK_KEY = "mask"

def __init__(
self,
Expand Down Expand Up @@ -125,6 +126,7 @@ def tokenize_fn(data):
padding=self.padding,
max_length=self.max_seq_length,
truncation=True,
return_offsets_mapping=True,
)

# store unpadded prompt so we can mask out correct number of elements
Expand Down Expand Up @@ -156,16 +158,29 @@ def group_text_fn(data):
def label_fn(data):
# if the dataset uses prompts, mask them out so they don't contribute
# to the loss calculation
labels = data["input_ids"].copy()
if "offset_mapping" in data:
offset_mapping = data["offset_mapping"]
# get the character level mask
mask = data.get("mask")
if mask is not None:
for i, (start, end) in enumerate(offset_mapping):
# if any char is to be filtered
if "0" in mask[start:end]:
labels[i] = LABELS_MASK_VALUE
horheynm marked this conversation as resolved.
Show resolved Hide resolved

prompt_len = 0
if self.PROMPT_KEY in data:
prompt_len = len(data[self.PROMPT_KEY])
data["labels"] = data["input_ids"].copy()

data["labels"] = labels
data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

# mask out padding in the labels as well
padding = len(data["attention_mask"]) - sum(data["attention_mask"])
if padding > 0:
data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding

return data

dataset = self.map(
Expand Down Expand Up @@ -206,8 +221,6 @@ def label_fn(data):
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Adding labels",
)
print(dataset.column_names)

return dataset

def map(
Expand All @@ -226,5 +239,4 @@ def map(
kwargs.pop("num_proc", None)
kwargs.pop("load_from_cache_file", None)
kwargs.pop("desc", None)

return dataset.map(**kwargs)
3 changes: 2 additions & 1 deletion src/sparseml/transformers/finetune/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
num_proc=self.data_args.preprocessing_num_workers,
desc="Removing unneeded columns",
)

return raw_dataset

def get_remove_columns_from_dataset(
Expand All @@ -108,5 +107,7 @@ def get_remove_columns_from_dataset(
remove_columns.remove(self.text_column)
if self.PROMPT_KEY in remove_columns:
remove_columns.remove(self.PROMPT_KEY)
if self.MASK_KEY in remove_columns:
remove_columns.remove(self.MASK_KEY)

return list(remove_columns)
51 changes: 51 additions & 0 deletions src/sparseml/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"ALL_TASK_NAMES",
"create_fake_dataloader",
"POSSIBLE_TOKENIZER_FILES",
"generate_mask",
]


Expand Down Expand Up @@ -556,3 +557,53 @@ def fetch_recipe_path(target: str):
recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME)

return recipe_path


def generate_mask(string: str, response: str, prompt: str = "") -> str:
"""
Generate a mask based on provided prompt and response strings to obscure
characters in the input string. Prompt will be masked and string in response
will be kept represented by 0 - remove and 1 - keep.
By default, non-reponse wrapped strings will be matched with 0

Args:
:param string: The input string to be masked.
:param prompt: The prompt string to identify characters to obscure.
:param response: The response string to identify characters to keep visible.

Returns:
str: A string representing the mask where '1' indicates visible
characters and '0' indicates obscured characters.

"""

mask = ["1"] * len(string)
Satrat marked this conversation as resolved.
Show resolved Hide resolved
is_prompt = False if string.startswith(response) else True
counter = 0
for i, char in enumerate(string):
if is_prompt:
mask[i] = "0"

if counter > 0:
if not is_prompt and len(prompt) > 1 and char == prompt[counter]:
counter += 1
elif is_prompt and char == response[counter]:
counter += 1
else:
counter = 0

if len(prompt) > 0 and counter == len(prompt) and not is_prompt:
mask[i - counter + 1 : i + 1] = ["0"] * counter

counter = 0
is_prompt = True

if counter == len(response) and is_prompt:
mask[i - counter + 1 : i + 1] = ["1"] * counter

counter = 0
is_prompt = False

if prompt.startswith(char) or response.startswith(char):
counter = 1
return "".join(mask)
4 changes: 4 additions & 0 deletions src/sparseml/transformers/utils/preprocessing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Dict

from sparseml.transformers.utils.helpers import generate_mask
from sparsezoo.utils.registry import RegistryMixin


Expand All @@ -26,4 +27,7 @@ def custom_evolved_codealpaca_dataset(data: Dict):
PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
data["prompt"] = PROMPT_DICT.format_map(data)
data["text"] = data["prompt"] + data["output"]
data["mask"] = generate_mask(
data["text"], prompt="[Instruction]", censor="[Response]"
)
return data
33 changes: 33 additions & 0 deletions tests/sparseml/transformers/finetune/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
oneshot,
train,
)
from sparseml.transformers.utils.helpers import generate_mask


def test_oneshot_and_finetune(tmp_path: Path):
Expand Down Expand Up @@ -319,3 +320,35 @@ def test_oneshot_with_modifier_object(tmp_path: Path):
splits=splits,
oneshot_device=device,
)


def test_finetune_wout_recipe_with_mask(tmp_path: Path):
recipe_str = None
model = "Xenova/llama2.c-stories15M"
device = "cuda:0"
if not torch.cuda.is_available():
device = "cpu"
dataset = "open_platypus"
concatenate_data = False
output_dir = tmp_path
max_steps = 50
splits = "train"

def preprocessing_func(example):
example["text"] = "[foo]" + example["text"] + "[bar] mask this"
example["mask"] = generate_mask(
example["text"], response="[bar]", prompt="[foo]"
)
return example

train(
model=model,
dataset=dataset,
output_dir=output_dir,
recipe=recipe_str,
max_steps=max_steps,
concatenate_data=concatenate_data,
splits=splits,
oneshot_device=device,
preprocessing_func=preprocessing_func,
)
53 changes: 53 additions & 0 deletions tests/sparseml/transformers/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from accelerate import init_empty_weights
from sparseml.transformers.utils.helpers import (
create_fake_dataloader,
generate_mask,
infer_recipe_from_model_path,
is_transformer_model,
resolve_recipe_file,
Expand Down Expand Up @@ -166,3 +167,55 @@ def test_save_zoo_directory(tmp_path, stub):
assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
shutil.rmtree(path_to_training_outputs)
shutil.rmtree(save_dir)


@pytest.mark.parametrize(
"string, response, prompt, expected_mask",
[
(
("[foo]hello\n\n" "[bar]world"),
"[bar]",
"[foo]",
("000000000000" "1111111111"),
horheynm marked this conversation as resolved.
Show resolved Hide resolved
),
(
(
"[Instruction]python is\n\n" # 24
"[Response]great\n\n" # 17
"[Instruction]What about Java" # 28
"[Response]Meh" # 13
),
"[Response]",
"[Instruction]",
(
"000000000000000000000000" # 24
"11111111111111111" # 17
"0000000000000000000000000000" # 28
"1111111111111" # 13
),
),
(
("[foo]hello\n\n" "[bar]world"),
"[bar]",
None,
("000000000000" "1111111111"),
),
(
("hello\n\n" "[bar]world"),
"[bar]",
None,
("0000000" "1111111111"),
),
(
("[bar]world" "[foo]hello\n\n" "[bar]world"),
"[bar]",
"[foo]",
("1111111111" "000000000000" "1111111111"),
),
],
)
def test_generate_mask(string, response, prompt, expected_mask):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
if prompt is not None:
assert generate_mask(string, response, prompt) == expected_mask
else:
assert generate_mask(string, response) == expected_mask