Implement dataloader constructor for piaf dataset. (#93)
* Implement dataloader constructor for piaf dataset.

* Fix for batched mapping.

* Fix for batched mapping.

* Black formatting.

* Prepare for rebasing.
Davidyz committed Jun 9, 2024
1 parent a64e345 commit afbb8c5
Showing 2 changed files with 115 additions and 5 deletions.
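For context, PIAF is a SQuAD-style French reading-comprehension dataset; each record carries a context, a question, and an answers dict whose text field may be empty. A minimal sketch of the record shape the new loader assumes (illustrative values, not taken from the actual dataset):

# Hypothetical PIAF record in the SQuAD-style schema (illustrative values).
entry = {
    "id": "p1",
    "title": "Paris",
    "context": "Paris est la capitale de la France.",
    "question": "Quelle est la capitale de la France ?",
    "answers": {"text": ["Paris"], "answer_start": [0]},
}
# The filter added in this commit keeps only entries with at least one answer:
keep = len(entry["answers"]["text"]) != 0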
46 changes: 41 additions & 5 deletions llm_unlearn_ucl/unlearn_harm.py
@@ -39,6 +39,7 @@
from utils import (
    compute_kl,
    create_mathqa_dataloader_from_dataset,
    create_piaf_dataloader_from_dataset,
    create_pku_dataloader_from_dataset,
    create_squad_dataloader_from_dataset,
    create_symbolic_dataloader_from_dataset,
@@ -315,6 +316,46 @@ def main(args) -> None:
        # ADDITIONALLY: create_truthfulqa_dataloader() is also using this pattern!!!
        question_prefix_str = "### Question:"
        answer_prefix_str = "### Answer:"
elif args.unlearning_dataset == "AgentPublic/piaf":
# filter entries with harmful responses and draw random samples from the remaining dataset.
full_bad_dataset = load_dataset("AgentPublic/piaf", split="train").filter(
lambda entry: len(entry["answers"]["text"]) != 0
)
if args.shuffle_seed:
# shuffle the dataset with a given seed for reproducibility
full_bad_dataset = full_bad_dataset.shuffle(seed=args.shuffle_seed)
if args.sequential > 0:
# NOTE: sequential/batch unlearning using sliced dataset.
train_bad_dataset = full_bad_dataset.select(range(args.samples_count))
else:
# NOTE: full dataset like bytedance.
train_bad_dataset = full_bad_dataset

Path(args.samples_save_dir).mkdir(exist_ok=True)
bad_sample_path = f"{args.samples_save_dir}/piaf_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(bad_sample_path, "w") as fin:
print(f"Writing bad samples to {bad_sample_path}")
json.dump(
[
train_bad_dataset[i]
for i in range(
args.samples_count
if args.sequential > 0
else len(train_bad_dataset)
)
],
fin,
)

train_bad_loaders = create_piaf_dataloader_from_dataset(
tokenizer,
train_bad_dataset,
batch_size=args.batch_size,
splits=max(args.sequential, 1),
)

question_prefix_str = "### Question:"
answer_prefix_str = "### Réponse:"
elif args.unlearning_dataset == "sail/symbolic-instruction-tuning":
# filter entries with harmful responses and draw random samples from the remaining dataset.
full_bad_dataset = load_dataset(
@@ -353,11 +394,6 @@ def main(args) -> None:
            splits=max(args.sequential, 1),
        )

        # XXX: for now this is the prefix that is added before each q and answer,
        # it is used by get_rand_ans_loss() to extract just the question part and
        # add a random answer to it.
        # !!!! Has additional side effect of model unlearning this pattern!!!!
        # ADDITIONALLY: create_truthfulqa_dataloader() is also using this pattern!!!
        question_prefix_str = "### Question:"
        answer_prefix_str = "### Answer:"
    elif args.unlearning_dataset == "math_qa":
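The prefix strings set in each branch are later used by get_rand_ans_loss() to recover just the question part of a formatted sample. A rough sketch of that extraction, assuming the "### Question: ...\n ### Réponse: ..." format produced above (the helper name and exact logic here are illustrative, not the repository's implementation):

def extract_question(text: str, question_prefix: str, answer_prefix: str) -> str:
    # "### Question: <q>\n ### Réponse: <a>" -> "<q>"
    after_q = text.split(question_prefix, 1)[1]
    return after_q.split(answer_prefix, 1)[0].strip()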
74 changes: 74 additions & 0 deletions llm_unlearn_ucl/utils.py
@@ -74,6 +74,80 @@ def preprocess(examples):
    return dataloaders


def create_piaf_dataloader_from_dataset(
    tokenizer, dataset, fraction=1.0, batch_size=4, splits: int = 1
):
    """
    Given the PIAF dataset, create dataloaders over the French Q&A pairs
    to be unlearned.

    Args:
        tokenizer: Tokenizer.
        dataset: Loaded PIAF dataset.
        fraction: <1 will do downsampling (currently unused).
        batch_size: Batch size used for each step.
        splits: The number of splits that the dataset will be sliced into.

    Returns:
        A list of DataLoaders of PIAF French Q&A pairs.
    """

    def preprocess(examples):
        """
        Input: Dict[List]
        Output: Dict[List]
        """
        results = {
            "input_ids": [],
            "attention_mask": [],
            "start_locs": [],
        }
        for i in range(len(examples["answers"])):
            prompt = examples["context"][i] + " " + examples["question"][i]
            response = examples["answers"][i]["text"][0]

            text = f"### Question: {prompt}\n ### Réponse: {response}"
            tokenized = tokenizer(text, truncation=True, padding="max_length")

            results["input_ids"].append(tokenized["input_ids"])
            results["attention_mask"].append(tokenized["attention_mask"])
            # Calculate the start idx of the answer: tokenize the prefix
            # WITHOUT padding, otherwise its length is always the model
            # max length and the start location is meaningless.
            test_text = f"### Question: {prompt}\n ### Réponse: "
            test_tokenized = tokenizer(test_text, truncation=True)
            results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)

        return results

    dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=[
            "answers",
            "context",
            "id",
            "question",
            "title",
        ],
    )
    dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "start_locs"],
    )

    # Add labels and make it a data loader.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Split lengths must sum to len(dataset), or random_split raises a
    # ValueError; give the first len(dataset) % splits parts one extra sample.
    split_lengths = [
        len(dataset) // splits + (1 if i < len(dataset) % splits else 0)
        for i in range(splits)
    ]
    dataloaders = [
        torch.utils.data.DataLoader(
            train_split_dataset, batch_size=batch_size, collate_fn=data_collator
        )
        for train_split_dataset in torch.utils.data.random_split(
            dataset, split_lengths
        )
    ]

    return dataloaders


def create_mathqa_dataloader_from_dataset(
    tokenizer, dataset, fraction=1.0, batch_size=4
):
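A minimal usage sketch for the new constructor, assuming a GPT-2 style Hugging Face tokenizer (the pad token assignment is needed because preprocessing pads to max length, and the checkpoint name is a placeholder, not fixed by this commit):

from datasets import load_dataset
from transformers import AutoTokenizer

from utils import create_piaf_dataloader_from_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

dataset = load_dataset("AgentPublic/piaf", split="train").filter(
    lambda entry: len(entry["answers"]["text"]) != 0
)
loaders = create_piaf_dataloader_from_dataset(
    tokenizer, dataset, batch_size=4, splits=2
)
batch = next(iter(loaders[0]))
print(batch["input_ids"].shape)  # (batch_size, tokenizer.model_max_length)

This mirrors the call site added in unlearn_harm.py above, with splits=2 standing in for max(args.sequential, 1).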
