Add alpaca reverse augmentation possibility #2342

Merged: 10 commits, Apr 7, 2023
3 changes: 2 additions & 1 deletion model/model_training/configs/config.yaml
@@ -256,8 +256,9 @@ llama-30b:
  save_total_limit: 4
  use_flash_attention: true

pythia:
pythia-70m-deduped:
  learning_rate: 8e-6
  # model_name: EleutherAI/pythia-1b-deduped
  model_name: EleutherAI/pythia-70m-deduped
  weight_decay: 0.0
  max_length: 520
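For completeness, a small sketch (not part of the PR) showing that the renamed top-level key is how the section is addressed once the file is parsed. It assumes the post-merge config.yaml exists at the path shown; the usage is illustrative only.

import yaml

# Assumes the post-merge config file; path and lookup are illustrative only.
with open("model/model_training/configs/config.yaml") as f:
    conf = yaml.safe_load(f)
print(conf["pythia-70m-deduped"]["model_name"])  # EleutherAI/pythia-70m-deduped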
9 changes: 3 additions & 6 deletions model/model_training/custom_datasets/__init__.py
@@ -10,14 +10,13 @@
from model_training.custom_datasets.prompt_dialogue import Gpt4All, load_oig_file
from model_training.custom_datasets.qa_datasets import (
    SODA,
    Alpaca,
    CodeAlpaca,
    JokeExplaination,
    QADataset,
    SODADialogue,
    TranslatedQA,
    Vicuna,
    WebGPT,
    load_alpaca_dataset,
)
from model_training.custom_datasets.rank_datasets import AugmentedOA
from model_training.custom_datasets.summarization import HFSummary, HFSummaryPairs, SummarizationDataset
@@ -118,10 +117,8 @@ def get_one_dataset(
        dataset = DiveMT()
    elif dataset_name == "webgpt":
        dataset = WebGPT(mode=mode)
    elif dataset_name == "alpaca":
        dataset = Alpaca(mode=mode, cache_dir=data_path)
    elif dataset_name == "code_alpaca":
        dataset = CodeAlpaca(mode=mode, cache_dir=data_path)
    elif dataset_name in ("alpaca", "code_alpaca"):
        train, eval = load_alpaca_dataset(dataset_name, val_split=val_split, cache_dir=data_path, **kwargs)
    elif dataset_name == "gpt4all":
        dataset = Gpt4All(mode=mode, cache_dir=data_path)
    elif dataset_name == "prosocial_dialogue":
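The behavioral change in this hunk: "alpaca" and "code_alpaca" now share one branch that calls a single loader and receives a train/eval pair instead of constructing one dataset object, with extra options such as reverse_augmentation forwarded through **kwargs. A minimal sketch of calling the loader directly, assuming the package is importable and the Hugging Face datasets can be downloaded; the argument values are illustrative only and not taken from this diff.

from model_training.custom_datasets.qa_datasets import load_alpaca_dataset

# Illustrative values, not from the PR.
train_ds, eval_ds = load_alpaca_dataset(
    "code_alpaca",
    val_split=0.05,             # fraction of rows held out for evaluation
    cache_dir=".cache",
    mode="sft",
    reverse_augmentation=True,  # the option this PR adds, forwarded via **kwargs
)
print(len(train_ds), len(eval_ds))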
76 changes: 55 additions & 21 deletions model/model_training/custom_datasets/qa_datasets.py
@@ -12,7 +12,8 @@

import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
from torch import Generator
from torch.utils.data import Dataset, Subset, random_split

# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
@@ -420,40 +421,73 @@ def __getitem__(self, index):
        return self.pairs[index]


class AlpacaBase(Dataset):
    def __init__(self, dataset_name: str, mode: str, cache_dir: str = None) -> None:
class AlpacaBaseDataset(Dataset):
    def __init__(self, data: list, mode: str):
        super().__init__()
        self.data = data
        if mode not in ("sft", "rl"):
            raise NotImplementedError(
                f"Alpaca Dataset for mode {mode} is not implemented. Currently supported modes are 'sft' and 'rl'."
            )
        self.mode = mode
        dataset = load_dataset(dataset_name, cache_dir=cache_dir)
        rows = []
        for row in dataset["train"]:
            question = row["instruction"]
            if len(row["input"]) > 0:
                input_ = "{}\n{}".format(question, row["input"])
            else:
                input_ = question
            rows.append((input_, row["output"]))
        self.rows = rows

    def __len__(self):
        return len(self.rows)
        return len(self.data)

    def __getitem__(self, index):
        question, answer = self.rows[index]
        question, answer = self.data[index]
        if self.mode == "sft":
            return (question, answer)
        elif self.mode == "rl":
            return (question,)


class Alpaca(AlpacaBase):
    def __init__(self, mode: str = "sft", cache_dir: str = None) -> None:
        super().__init__(dataset_name="yahma/alpaca-cleaned", mode=mode, cache_dir=cache_dir)
def load_alpaca_dataset(
    dataset_name: str,
    val_split: float,
    cache_dir: str,
    mode: str = "sft",
    manual_seed: int = 287631038922,
    reverse_augmentation: bool = False,
    keep_unreversed: bool = True,
) -> tuple[AlpacaBaseDataset, AlpacaBaseDataset]:
    generator = Generator()
    generator.manual_seed(manual_seed)

    def process_split(
        dataset: Subset, reverse_augmentation: bool = False, keep_unreversed: bool = True
    ) -> list[tuple[str, str]]:
        data = []
        for row in dataset:
            question = row["instruction"]
            if len(row["input"]) > 0:
                input_ = "{}\n{}".format(question, row["input"])
            else:
                input_ = question
            if reverse_augmentation:
                data.append((row["output"], input_))
                # in case of reverse augmentation we just keep both, reversed and unreversed data
                if keep_unreversed:
                    data.append((input_, row["output"]))
            else:
                data.append((input_, row["output"]))
        return data

    if dataset_name == "alpaca":
        dataset = load_dataset("yahma/alpaca-cleaned", cache_dir=cache_dir)
    elif dataset_name == "code_alpaca":
        dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir)
    else:
        raise ValueError(f"Expected dataset_name to be 'alpaca' or 'code_alpaca'. Received {dataset_name}.")

class CodeAlpaca(AlpacaBase):
    def __init__(self, mode: str = "sft", cache_dir: str = None) -> None:
        super().__init__(dataset_name="sahil2801/CodeAlpaca-20k", mode=mode, cache_dir=cache_dir)
    splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator)
    train = AlpacaBaseDataset(
        process_split(splits[0], reverse_augmentation=reverse_augmentation, keep_unreversed=keep_unreversed), mode=mode
    )
    val = AlpacaBaseDataset(
        process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode
    )
    return train, val


class Vicuna(Dataset):
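To make the augmentation semantics concrete, here is a small sketch (not part of the diff) that mirrors the process_split logic on a single hand-written row with reverse_augmentation=True and keep_unreversed=True; the example row is made up.

# Illustrative row, not from either dataset.
row = {"instruction": "Translate to French", "input": "Good morning", "output": "Bonjour"}

input_ = "{}\n{}".format(row["instruction"], row["input"]) if row["input"] else row["instruction"]
pairs = [(row["output"], input_)]      # reversed pair: the answer becomes the prompt
pairs.append((input_, row["output"]))  # unreversed pair kept because keep_unreversed=True

# pairs == [("Bonjour", "Translate to French\nGood morning"),
#           ("Translate to French\nGood morning", "Bonjour")]

Note that the eval split is always built with reverse_augmentation=False, so validation examples keep the normal instruction-to-answer direction.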