Skip to content

Commit

Permalink
refactor datasets and oa private data selection
Browse files Browse the repository at this point in the history
  • Loading branch information
Sotirios Anagnostidis committed Feb 11, 2023
1 parent 23ee2f2 commit ac97943
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 33 deletions.
6 changes: 6 additions & 0 deletions model/supervised_finetuning/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ defaults:
log_wandb: true
samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within

oa_dataset_only:
datasets:
- oa_private:
data_path: .cache
val_split: 0.0

galactica-125m:
learning_rate: 5e-5
model_name: facebook/galactica-125m
Expand Down
57 changes: 24 additions & 33 deletions model/supervised_finetuning/custom_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"]


def train_val_dataset(dataset, val_split=0.2):
Expand All @@ -42,63 +42,54 @@ def train_val_dataset(dataset, val_split=0.2):
return Subset(dataset, train_idx), Subset(dataset, val_idx)


def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs):
    """Instantiate the (train, eval) dataset pair for a single named dataset.

    Args:
        conf: run configuration; only ``conf.cache_dir`` is read, as the
            fallback location for cached dataset files.
        dataset_name: case-insensitive dataset identifier.
        val_split: fraction held out for validation when the dataset does not
            ship a predefined validation split.
        data_path: optional per-dataset override of the cache directory;
            defaults to ``conf.cache_dir``.
        **kwargs: forwarded to ``train_val_dataset``.

    Returns:
        Tuple ``(train, eval)`` of dataset objects.

    Raises:
        ValueError: if ``dataset_name`` is not recognized.
    """
    data_path = data_path or conf.cache_dir
    dataset_name = dataset_name.lower()

    # Branches either set `train`/`eval` directly (predefined splits) or set
    # `dataset`, which is split into train/eval at the bottom.
    if dataset_name in QA_DATASETS:
        train = QADataset(dataset_name, data_path, "train")
        if not train.no_val:
            eval = QADataset(dataset_name, data_path, "validation")
    elif dataset_name in SUMMARIZATION_DATASETS:
        # debate_sum has no validation split; it is carved out below.
        train = SummarizationDataset(dataset_name, data_path, "train")
        if dataset_name != "debate_sum":
            eval = SummarizationDataset(dataset_name, data_path, "validation")
    elif "ted_trans" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        dataset = TEDTalk(pair=language_pair, split="train")
    elif "wmt2019" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        train = WMT2019(pair=language_pair, split="train")
        eval = WMT2019(pair=language_pair, split="validation")
    elif dataset_name == "dive_mt":
        dataset = DiveMT()
    elif dataset_name == "webgpt":
        dataset = WebGPT()
    elif dataset_name == "prompt_dialogue":
        dataset = PromptGeneratedDataset(data_path)
    elif dataset_name == "prosocial_dialogue":
        train = ProsocialDialogue(cache_dir=data_path, split="train")
        eval = ProsocialDialogue(cache_dir=data_path, split="validation")
    elif dataset_name == "explain_prosocial":
        train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
        eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
    elif dataset_name == "soda":
        dataset = SODA(data_path)
    elif dataset_name == "soda_dialogue":
        dataset = SODADialogue(data_path)
    elif dataset_name == "joke":
        dataset = JokeExplaination(data_path)
    elif dataset_name == "instruct_tuning":
        dataset = InstructionTuning(data_path)
    elif dataset_name == "private_tuning":
        dataset = PrivateInstructionTuning(data_path)
    elif dataset_name == "oa_translated":
        dataset = TranslatedQA(data_path)  # TODO make val_split lower..?
    else:
        # NOTE(review): "oa_private" is listed in OTHER but has no branch
        # here, so requesting it raises — TODO confirm where it is wired up.
        raise ValueError(f"Unknown dataset {dataset_name}")

    # If a branch produced only a single `dataset`, split it now.
    if "dataset" in locals():
        train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
    elif "eval" not in locals():
        # Datasets with a train split but no validation data (QA datasets with
        # no_val=True, debate_sum): split the train set so `eval` is always
        # defined before returning (previously a NameError).
        train, eval = train_val_dataset(train, val_split=val_split, **kwargs)

    return train, eval
60 changes: 60 additions & 0 deletions model/supervised_finetuning/custom_datasets/prompt_dialogue.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,71 @@
import json
import math
import os
import random
from collections import OrderedDict
from functools import reduce
from urllib.request import urlopen

import numpy as np
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from torch.utils.data import Dataset


class OAPrivate(Dataset):
    """Private Open-Assistant jsonl export, deterministically partitioned per task.

    The export is shuffled with a fixed seed and carved into disjoint slices
    so that supervised fine-tuning, reward modelling and RL never see the
    same conversations.
    """

    # Name of the jsonl export expected inside `data_path`.
    file = "2023-02-10_oasst_prod.jsonl"
    # Fraction of the shuffled file assigned to each task; must sum to 1.
    splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35)  # fractions per task

    def __init__(self, split="sft", data_path=".cache") -> None:
        """Load the slice of the export belonging to `split`.

        Args:
            split: one of ``self.splits`` keys ("sft", "reward_model", "rl").
            data_path: directory containing ``self.file``.

        Raises:
            ValueError: if `split` is not a known task name.
            FileNotFoundError: if the export file is missing.
        """
        super().__init__()

        if split not in self.splits:
            raise ValueError(f"Unknown split {split!r}; expected one of {list(self.splits)}")

        # Simpler and clearer than the previous reduce() whose lambda also
        # shadowed the `split` argument.
        total_prob = sum(self.splits.values())
        assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1"

        jsonl_file = os.path.join(data_path, self.file)

        with open(jsonl_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # Fixed seed so every task sees the same permutation and the slices
        # below stay disjoint across runs and processes.
        rng = np.random.default_rng(seed=0)
        indices = np.arange(len(lines)).astype(int)
        rng.shuffle(indices)

        # Cumulative split boundaries, e.g. [0, 0.25, 0.65, 1.0]. The original
        # wrapped the list in an extra bracket pair and relied on np.cumsum
        # implicitly flattening the 2-D input; a plain 1-D cumsum is intended.
        cumsums = np.cumsum([0] + list(self.splits.values()))
        split_index = list(self.splits.keys()).index(split)

        start_index, end_index = int(cumsums[split_index] * len(lines)), int(cumsums[split_index + 1] * len(lines))

        self.data = [json.loads(lines[index].strip()) for index in indices[start_index:end_index]]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Walk one random root-to-leaf path through a conversation tree.

        Returns a list of ``[prompter_text, reply_text]`` pairs. A reply is
        sampled uniformly at each level, so repeated access to the same index
        may yield different paths.
        """
        prompt = self.data[index]["prompt"]

        pairs = []

        while True:
            assert prompt["role"] == "prompter"
            prompter_text = prompt["text"]

            if len(prompt["replies"]) == 0:
                break

            reply = random.choice(prompt["replies"])
            pairs.append([prompter_text, reply["text"]])

            if len(reply["replies"]) == 0:
                break

            prompt = random.choice(reply["replies"])

        return pairs


class PromptGeneratedDataset(Dataset):
"""Generates from flan 11B
User: What are the best methods for preventing a slave trade?
Expand Down

0 comments on commit ac97943

Please sign in to comment.