Add more datasets and some fixes #1455

Merged (13 commits) on Feb 11, 2023
42 changes: 42 additions & 0 deletions model/reward/instructor/rank_datasets.py
@@ -18,6 +18,7 @@


"""
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

@@ -298,3 +299,44 @@ def __getitem__(self, index):
context, pair = self.pairs[index]

return context, [pair]


class OAPrivate(Dataset):
"""
{
"prompt": <prompt string>,
"history": [("prompt1", "answer2"), ("prompt2", "answer2")],
"pos": <pos answer string>,
"neg_replies": [list of bad answers]
}
"""

split_name_mapping = {
"train": "rm_train.jsonl",
"test": "rm_test.jsonl",
"val": "rm_val.jsonl",
}

def __init__(self, split="train", sep_token="<sep>", data_path=".cache") -> None:
super().__init__()
import json

jsonl_file = os.path.join(data_path, self.split_name_mapping[split])
self.pairs = []
with open(jsonl_file, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
prefix = sep_token.join([sep_token.join(p) for p in data["history"][-2:]])
prefix += sep_token + data["prompt"]
pair = []
for neg_text in data["neg_replies"]:
pair.append((data["pos"], neg_text))
self.pairs.append((prefix, pair))

def __len__(self):
return len(self.pairs)

def __getitem__(self, index):
context, pair = self.pairs[index]

return context, pair
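
As a reading aid (not part of the diff): a minimal sketch of how OAPrivate turns one rm_*.jsonl record into a training example, assuming the record shape given in the docstring above; the field values are invented for illustration.

import json

line = json.dumps({
    "prompt": "What is 2+2?",
    "history": [["Hi", "Hello!"], ["How are you?", "Fine, thanks."]],
    "pos": "2+2 equals 4.",
    "neg_replies": ["It is 5.", "No idea."],
})
data = json.loads(line)
sep = "<sep>"

# the last two history turns are flattened, then the current prompt is appended
prefix = sep.join([sep.join(p) for p in data["history"][-2:]])
prefix += sep + data["prompt"]
# one (positive, negative) pair per bad reply
pairs = [(data["pos"], neg) for neg in data["neg_replies"]]
# __getitem__ would return:
# ("Hi<sep>Hello!<sep>How are you?<sep>Fine, thanks.<sep>What is 2+2?",
#  [("2+2 equals 4.", "It is 5."), ("2+2 equals 4.", "No idea.")])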
8 changes: 7 additions & 1 deletion model/reward/instructor/utils.py
@@ -115,7 +115,7 @@ def argument_parsing(parser):


def get_datasets(dataset_list: List[AnyStr], tokenizer):
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, OAPrivate, WebGPT
from torch.utils.data import ConcatDataset

train_datasets, evals = [], {}
@@ -141,5 +141,11 @@ def get_datasets(dataset_list: List[AnyStr], tokenizer):
eval = AnthropicRLHF("test", tokenizer.sep_token)
train_datasets.append(train)
evals["anthropic_rlhf"] = eval
elif "oa_private" == dataset_name:
train = OAPrivate(split="train", sep_token=tokenizer.sep_token)
eval = OAPrivate(split="val", sep_token=tokenizer.sep_token)
train_datasets.append(train)
evals["oa_private"] = eval

train = ConcatDataset(train_datasets)
return train, evals
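
A hedged usage sketch (not in the diff): how the new dataset name flows through get_datasets. It assumes a tokenizer whose sep_token is set and that the rm_*.jsonl files already exist under the default ".cache" directory.

# assuming we are inside model/reward/instructor/
from transformers import AutoTokenizer
from utils import get_datasets

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # sep_token is "[SEP]"
train, evals = get_datasets(["oa_private"], tokenizer)
print(len(train), list(evals.keys()))  # e.g. <N> ['oa_private']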
12 changes: 9 additions & 3 deletions model/supervised_finetuning/custom_datasets/__init__.py
@@ -1,8 +1,8 @@
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
@@ -32,7 +32,7 @@
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]


def train_val_dataset(dataset, val_split=0.2):
@@ -92,6 +92,12 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "instruct_tuning":
dataset = InstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "private_tuning":
dataset = PrivateInstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "oa_translated":
dataset = TranslatedQA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.01)
else:
raise ValueError(f"Unknown dataset {dataset_name}")

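Illustrative only, with an assumed minimal config object (the real conf carries more fields): selecting the two new datasets by name. The cache_dir must already contain the expected files.

from argparse import Namespace
from custom_datasets import get_one_dataset

conf = Namespace(cache_dir=".cache")
train, eval_ = get_one_dataset(conf, "private_tuning")  # 80/20 train/val split
train, eval_ = get_one_dataset(conf, "oa_translated")   # 99/1 train/val split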
106 changes: 106 additions & 0 deletions model/supervised_finetuning/custom_datasets/dialogue_collator.py
@@ -1,3 +1,4 @@
import random
from dataclasses import dataclass
from typing import Optional, Union

@@ -76,3 +77,108 @@ def __call__(self, features):
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)

return batch


@dataclass
class TrainDialogueDataCollator:
"""
Expects each feature to be a list of strings alternating [question, answer, question, answer, ...].
"""

tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
mix_length_threshold: Optional[int] = 256
mix_probability: Optional[float] = 0.6
pad_to_multiple_of: Optional[int] = None

def __call__(self, features):
flatten_messages = []
label_masks = []
total_short_context = 0
for messages in features:
messages = list(messages)

# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(self.tokenizer.eos_token)
[Review comment · Collaborator] For this collator to work, we need to replace this eos token with the tag. Let me know if I am wrong.

[Review comment · Collaborator] I would change it anyway, for both this collator and the default collator.


flatten_message = self.tokenizer(
"".join(messages),
truncation=True,
max_length=self.max_length,
return_offsets_mapping=True,
)

message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
# Label mask is true when predicting a token that is part of the answer, false otherwise.
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
# LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0

# if next() finds no message boundary, we are predicting the final termination token(s)
message_indices = list(
map(
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
list(map(lambda x: x[1], flatten_message["offset_mapping"])),
)
)
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
try:
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
except IndexError:
# due to truncation, we might not have the last termination token
label_mask[-1] = False

label_masks.append(label_mask)
if len(flatten_message["input_ids"]) < self.mix_length_threshold:
total_short_context += len(flatten_message["input_ids"])
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
# packing
if total_short_context > 2:
_flatten_messages, _label_masks = [], []
prev_short_msg, prev_short_mask = None, None
for flatten_msg, label_mask in zip(flatten_messages, label_masks):
if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > self.mix_probability:
if prev_short_msg is not None:
for key in flatten_msg.keys():
flatten_msg[key] += prev_short_msg[key]
flatten_msg[key] = flatten_msg[key][: self.max_length]
label_mask = np.concatenate([label_mask, prev_short_mask])
_label_masks.append(label_mask[: self.max_length])
_flatten_messages.append(flatten_msg)
# reset
prev_short_msg, prev_short_mask = None, None
else:
# prime
prev_short_msg, prev_short_mask = flatten_msg, label_mask
else:
_label_masks.append(label_mask)
_flatten_messages.append(flatten_msg)
if prev_short_msg is not None:
# a short message was primed but never merged into anything;
# append it as its own example instead of re-appending the last message
_label_masks.append(prev_short_mask[: self.max_length])
_flatten_messages.append(prev_short_msg)

label_masks = _label_masks
flatten_messages = _flatten_messages

batch = self.tokenizer.pad(
flatten_messages,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
dim = batch["input_ids"].shape[-1]

batch["label_masks"] = torch.stack(
[F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
)
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)

return batch
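
A minimal usage sketch (assumptions: a fast tokenizer, since return_offsets_mapping is required, and alternating question/answer strings as input). This is an illustration, not part of the diff.

from transformers import AutoTokenizer
from custom_datasets.dialogue_collator import TrainDialogueDataCollator

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

collator = TrainDialogueDataCollator(tokenizer=tokenizer, max_length=128)
features = [
    ["Q: Hi. ", "A: Hello! ", "Q: Bye. ", "A: Goodbye. "],
    ["Q: What is 2+2? ", "A: 4. "],
]
batch = collator(features)
# label_masks is True only on answer tokens (odd-indexed messages)
print(batch["input_ids"].shape, batch["label_masks"].shape)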
44 changes: 43 additions & 1 deletion model/supervised_finetuning/custom_datasets/prompt_dialogue.py
@@ -2,7 +2,7 @@
import os
from urllib.request import urlopen

from custom_datasets.formatting import format_pair
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from torch.utils.data import Dataset


@@ -102,3 +102,45 @@ def __len__(self):

def __getitem__(self, index):
return format_pair(self.pairs[index])


class PrivateInstructionTuning(Dataset):
"""
We have seen some promising capabilities from instruction tuning
with the following mix of datasets derived from data available online.
The files for this data are in JSONL format, where each record is a
(source, instruction_response_pair) tuple.

Not to be confused with unnatural instructions.
"""

name = "private_tuning"
filename = "oa_v3_fixed_plus_safety.jsonl"

def __init__(self, cache_dir) -> None:
super().__init__()
os.makedirs(cache_dir, exist_ok=True)

self.pairs = []
for file_link in [self.filename]:
basename = file_link.split("/")[-1]
instruction_tune_file = os.path.join(cache_dir, basename)

with open(instruction_tune_file, "r", encoding="utf-8") as f:
for line in f:
row = json.loads(line)
prefix = ""
for convo in row["text"].split("User:"):
if "Assistant" in convo:
prompt, answer = convo.split("Assistant:", maxsplit=1)
answer = answer.replace("<|endoftext|>", "").strip()
self.pairs.append((prefix + QA_SPECIAL_TOKENS["Question"] + prompt, answer))
prefix += "".join(format_pair((prompt, answer)))

def __len__(self):
return len(self.pairs)

def __getitem__(self, index):
prompt, answer = self.pairs[index]
return "{}{}".format(prompt, QA_SPECIAL_TOKENS["Answer"]), answer
70 changes: 67 additions & 3 deletions model/supervised_finetuning/custom_datasets/qa_datasets.py
@@ -1,6 +1,7 @@
"""
Open / close book QA datasets
"""
import glob
import json
import os
import re
@@ -137,11 +138,13 @@ def __init__(self, dataset, cache_dir, split):
self.dataset = load_dataset(context["name"], **context["params"])
else:
raise ValueError("Unknown dataset : " + dataset)
self.length = len(self.dataset)

def __len__(self):
return len(self.dataset)
return self.length

def __getitem__(self, idx):

data = self.dataset[idx]
return format_pair(self.index_fn(data))

@@ -311,14 +314,75 @@ def __init__(self, cache_dir) -> None:
for line in f:
data = json.loads(line)
joke = data["joke"]
explanation = data["explanation"]
# DO NOT change this key name:
# it's the source data itself that contains the misspelling ("explaination")
explanation = data["explaination"]
# only keep non-empty joke/explanation pairs
if len(joke) > 0 and len(explanation) > 0:
self.pairs.append((joke, explanation))
self.length = len(self.pairs)

def __len__(self):
return len(self.pairs)
return self.length

def __getitem__(self, index):
return format_pair(self.pairs[index])


class TranslatedQA(Dataset):
"""
Translated OA v3 results:
a list of non-English translations of OA v3 instruction-generated text,
in JSONL format; each line has the form:
{
"text": "User: ... Assistant: ....",
"meta": {"source": ... },
"translate": [
{ "round": 1, "human":"...", "answer": "..."},
...
{ "round": K, "human":"...", "answer": "..."},
]
}
Since OA contains some code, we need to reference the original text to skip those records
"""

name = "oa_translated"

def __init__(self, cache_dir) -> None:
super().__init__()
os.makedirs(cache_dir, exist_ok=True)
path = os.path.join(cache_dir, self.name)
os.makedirs(path, exist_ok=True)
self.pairs = []
for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")):
with open(translated_jsonl, "r") as fin:
for line in fin:
data = json.loads(line)
if "Python " in data["text"]:
# translation currently doesn't skip code blocks, so we check
# the original text and drop records that appear to contain code
continue
prefix = ""
for convo_round in data["translate"]:
human, answer = format_pair((convo_round["human"], convo_round["answer"]))
if convo_round["round"] > 2:
self.pairs.append(("{}{}{}".format(prefix, "<sep>", human), answer))
else:
self.pairs.append((human, answer))

prefix += "{}{}{}{}".format(
QA_SPECIAL_TOKENS["Question"],
convo_round["human"],
QA_SPECIAL_TOKENS["Answer"],
convo_round["answer"],
)

self.length = len(self.pairs)

def __len__(self):
return self.length

def __getitem__(self, index):
return self.pairs[index]
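
A simplified illustration of the round handling above, with format_pair and the QA special tokens replaced by plain stand-in markers (format_pair's real output differs); the record values are invented:

record = {
    "translate": [
        {"round": 1, "human": "Bonjour", "answer": "Hello"},
        {"round": 2, "human": "Ça va ?", "answer": "How are you?"},
        {"round": 3, "human": "Merci", "answer": "Thanks"},
    ]
}
prefix, pairs = "", []
for r in record["translate"]:
    human, answer = "<Q>" + r["human"], "<A>" + r["answer"]  # stand-in for format_pair
    if r["round"] > 2:
        # later rounds see the accumulated conversation as context
        pairs.append(("{}{}{}".format(prefix, "<sep>", human), answer))
    else:
        pairs.append((human, answer))
    prefix += "<Q>{}<A>{}".format(r["human"], r["answer"])
# pairs[2][0] now starts with the round-1 and round-2 exchanges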
model/supervised_finetuning/custom_datasets/summarization.py
@@ -57,7 +57,7 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
split = "test"
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
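For context on the one-line fix above (an editorial note, not part of the diff): in Python, & binds more tightly than comparison operators, so the old condition parsed as a chained comparison involving ["billsum", "tldr_news"] & split, which raises a TypeError for a list and a string. A tiny reproduction of the corrected logic:

dataset, split = "billsum", "validation"
# old form: ... & split == "validation" raises TypeError (list & str)
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
    split = "test"  # fall back to the test split for these datasets
print(split)  # prints: test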