diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index f664ceaf19..4864122569 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,7 +1,7 @@ """ High level functions for model training """ -from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset +from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination @@ -32,7 +32,7 @@ "debate_sum", "tldr_news", ] -OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning"] def train_val_dataset(dataset, val_split=0.2): @@ -92,6 +92,9 @@ def get_one_dataset(conf, dataset_name): elif dataset_name == "instruct_tuning": dataset = InstructionTuning(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "private_tuning": + dataset = PrivateInstructionTuning(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "translate_qa": dataset = TranslatedQA(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.01) diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 1c82393484..4aac265549 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -2,7 +2,7 @@ import os from urllib.request import urlopen -from custom_datasets.formatting import format_pair +from 
custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from torch.utils.data import Dataset @@ -102,3 +102,45 @@ def __len__(self): def __getitem__(self, index): return format_pair(self.pairs[index]) + + +class PrivateInstructionTuning(Dataset): + """ + We have seen some promising capabilities from instruction tuning + with the following mix of datasets that are derived from datasets + available online. + The files for this data are in jsonl format, one json object per line + whose "text" field holds a dialogue of User/Assistant turns + + Not to be confused with unnatural instructions + """ + + name = "private_tuning" + filename = "oa_v3_fixed_plus_safety.jsonl" + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + + self.pairs = [] + for file_link in [self.filename]: + basename = file_link.split("/")[-1] + instruction_tune_file = os.path.join(cache_dir, basename) + + with open(instruction_tune_file, "r", encoding="utf-8") as f: + for line in f: + row = json.loads(line) + prefix = "" + for _, convo in enumerate(row["text"].split("User:")): + if "Assistant" in convo: + prompt, answer = convo.split("Assistant:", maxsplit=1) + answer = answer.replace("<|endoftext|>", "").strip() + self.pairs.append((prefix + QA_SPECIAL_TOKENS["Question"] + prompt, answer)) + prefix += "".join(format_pair((prompt, answer))) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + prompt, answer = self.pairs[index] + return "{}{}".format(prompt, QA_SPECIAL_TOKENS["Answer"]), answer diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 2acf910675..7876a920e9 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -3,7 +3,6 @@ """ import json import os -import random import re from urllib.request import urlopen @@ -116,7 +115,7 @@ class 
QADataset(Dataset): "reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"}, } - def __init__(self, dataset, cache_dir, split, mix_prob=0.2): + def __init__(self, dataset, cache_dir, split): self.no_val = False if dataset in self.DATASET_FORMAT_MAPPING: context = self.DATASET_FORMAT_MAPPING[dataset] @@ -139,23 +138,11 @@ def __init__(self, dataset, cache_dir, split, mix_prob=0.2): else: raise ValueError("Unknown dataset : " + dataset) self.length = len(self.dataset) - self.mix_prob = mix_prob def __len__(self): return self.length def __getitem__(self, idx): - if self.mix_prob > 0 and random.random() < self.mix_prob and idx > 5 and idx < (self.length - 5): - - additional = random.randint(0, 10) - 5 - while additional == idx: - additional = random.randint(0, 10) - 5 - - answer_pair = self.index_fn(self.dataset[additional + idx]) - history_text = "".join(format_pair(answer_pair)) - question, answer = self.index_fn(self.dataset[idx]) - question = history_text + question - return format_pair((question, answer)) data = self.dataset[idx] return format_pair(self.index_fn(data)) @@ -312,9 +299,8 @@ class JokeExplaination(Dataset): name = "joke" url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" - def __init__(self, cache_dir, mix_prob=0.2) -> None: + def __init__(self, cache_dir) -> None: super().__init__() - self.mix_prob = mix_prob os.makedirs(cache_dir, exist_ok=True) joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl") if not os.path.exists(joke_explain_filename): @@ -341,16 +327,6 @@ def __len__(self): return self.length def __getitem__(self, index): - if random.random() < self.mix_prob and index > 5 and index < (self.length - 5): - additional = random.randint(0, 10) - 5 - while additional == index: - additional = random.randint(0, 10) - 5 - - history_text = "".join(format_pair(self.pairs[additional + index])) - question, answer 
= self.pairs[index] - question = history_text + question - return format_pair((question, answer)) - return format_pair(self.pairs[index]) @@ -358,9 +334,8 @@ class TranslatedQA(Dataset): name = "oa_translated" - def __init__(self, cache_dir, mix_prob=0.2) -> None: + def __init__(self, cache_dir) -> None: super().__init__() - self.mix_prob = mix_prob os.makedirs(cache_dir, exist_ok=True) path = os.path.join(cache_dir, "oa_translated") os.makedirs(path, exist_ok=True) @@ -383,14 +358,4 @@ def __len__(self): return self.length def __getitem__(self, index): - if random.random() < self.mix_prob and index > 5 and index < (self.length - 5): - additional = random.randint(0, 10) - 5 - while additional == index: - additional = random.randint(0, 10) - 5 - - history_text = "".join(format_pair(self.pairs[additional + index])) - question, answer = self.pairs[index] - question = history_text + question - return format_pair((question, answer)) - return format_pair(self.pairs[index])