From 41a24d8969b3b460f308e466fb513d43ec16b6bb Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 07:48:45 +0200 Subject: [PATCH 1/8] add alpaca reverse augmentation functionality --- model/model_training/configs/config.yaml | 8 +++++++- .../model_training/custom_datasets/__init__.py | 4 ++-- .../custom_datasets/qa_datasets.py | 18 ++++++++++++------ 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index e38c6ec1d2..006d698f8f 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -256,8 +256,9 @@ llama-30b: save_total_limit: 4 use_flash_attention: true -pythia: +pythia-70m-deduped: learning_rate: 8e-6 + # model_name: EleutherAI/pythia-1b-deduped model_name: EleutherAI/pythia-70m-deduped weight_decay: 0.0 max_length: 520 @@ -267,6 +268,11 @@ pythia: per_device_train_batch_size: 2 per_device_eval_batch_size: 4 output_dir: pythia_model + datasets: + - alpaca: + reverse_augmentation: True + - webgpt + pythia-1B: learning_rate: 8e-6 diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index 1a8aff33d8..5a92bd3f37 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -93,9 +93,9 @@ def get_one_dataset( elif dataset_name == "webgpt": dataset = WebGPT(mode=mode) elif dataset_name == "alpaca": - dataset = Alpaca(mode=mode, cache_dir=data_path) + dataset = Alpaca(mode=mode, cache_dir=data_path, **kwargs) elif dataset_name == "code_alpaca": - dataset = CodeAlpaca(mode=mode, cache_dir=data_path) + dataset = CodeAlpaca(mode=mode, cache_dir=data_path, **kwargs) elif dataset_name == "gpt4all": dataset = Gpt4All(mode=mode, cache_dir=data_path) elif dataset_name == "prosocial_dialogue": diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index 8347f589fe..a0f3b428cf 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -422,18 +422,22 @@ def __getitem__(self, index): class AlpacaBase(Dataset): - def __init__(self, dataset_name: str, mode: str, cache_dir: str = None) -> None: + def __init__(self, dataset_name: str, mode: str, reverse_augmentation: bool = False, cache_dir: str = None) -> None: super().__init__() self.mode = mode dataset = load_dataset(dataset_name, cache_dir=cache_dir) rows = [] + import pdb; pdb.set_trace() for row in dataset["train"]: question = row["instruction"] if len(row["input"]) > 0: input_ = "{}\n{}".format(question, row["input"]) else: input_ = question - rows.append((input_, row["output"])) + if reverse_augmentation: + rows.append((row["output"], input_)) + else: + rows.append((input_, row["output"])) self.rows = rows def __len__(self): @@ -445,13 +449,15 @@ def __getitem__(self, index): return (question, answer) elif self.mode == "rl": return (question,) + else: + raise NotImplementedError(f"Alpaca Dataset for mode {self.mode} is not implemented. Currently supported modes are 'sft' and 'rl'.") class Alpaca(AlpacaBase): - def __init__(self, mode: str = "sft", cache_dir: str = None) -> None: - super().__init__(dataset_name="yahma/alpaca-cleaned", mode=mode, cache_dir=cache_dir) + def __init__(self, mode: str = "sft", cache_dir: str = None, **kwargs) -> None: + super().__init__(dataset_name="yahma/alpaca-cleaned", mode=mode, cache_dir=cache_dir, **kwargs) class CodeAlpaca(AlpacaBase): - def __init__(self, mode: str = "sft", cache_dir: str = None) -> None: - super().__init__(dataset_name="sahil2801/CodeAlpaca-20k", mode=mode, cache_dir=cache_dir) + def __init__(self, mode: str = "sft", cache_dir: str = None, **kwargs) -> None: + super().__init__(dataset_name="sahil2801/CodeAlpaca-20k", mode=mode, cache_dir=cache_dir, **kwargs) From ebd0a620f9afeb086d45e4770347efb1a1ef1fc7 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 07:50:30 +0200 Subject: [PATCH 2/8] remove debug statement --- model/model_training/custom_datasets/qa_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index aedc0c45dc..b073bdfc2f 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -425,7 +425,6 @@ def __init__(self, dataset_name: str, mode: str, reverse_augmentation: bool = Fa self.mode = mode dataset = load_dataset(dataset_name, cache_dir=cache_dir) rows = [] - import pdb; pdb.set_trace() for row in dataset["train"]: question = row["instruction"] if len(row["input"]) > 0: From 076030a8392f163b803c254fe5c756c1daba3c41 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 08:03:10 +0200 Subject: [PATCH 3/8] update qa datasets and config --- model/model_training/configs/config.yaml | 1 - model/model_training/custom_datasets/qa_datasets.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index 006d698f8f..645ed68bab 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -273,7 +273,6 @@ pythia-70m-deduped: reverse_augmentation: True - webgpt - pythia-1B: learning_rate: 8e-6 model_name: EleutherAI/pythia-1b-deduped diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index b073bdfc2f..be0e6eebe5 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -447,7 +447,9 @@ def __getitem__(self, index): elif self.mode == "rl": return (question,) else: - raise NotImplementedError(f"Alpaca Dataset for mode {self.mode} is not implemented. Currently supported modes are 'sft' and 'rl'.") + raise NotImplementedError( + f"Alpaca Dataset for mode {self.mode} is not implemented. Currently supported modes are 'sft' and 'rl'." + ) class Alpaca(AlpacaBase): From eb07ecff43e74dde4d6385da07b7620e5c399dae Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 09:58:32 +0200 Subject: [PATCH 4/8] update alpaca datasets to include train and eval --- .../custom_datasets/__init__.py | 9 +-- .../custom_datasets/qa_datasets.py | 77 +++++++++++-------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index 1af7bf66c0..83e80e60c6 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -10,13 +10,12 @@ from model_training.custom_datasets.prompt_dialogue import Gpt4All, load_oig_file from model_training.custom_datasets.qa_datasets import ( SODA, - Alpaca, - CodeAlpaca, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT, + load_alpaca_dataset, ) from model_training.custom_datasets.rank_datasets import AugmentedOA from model_training.custom_datasets.summarization import HFSummary, SummarizationDataset @@ -114,10 +113,8 @@ def get_one_dataset( dataset = DiveMT() elif dataset_name == "webgpt": dataset = WebGPT(mode=mode) - elif dataset_name == "alpaca": - dataset = Alpaca(mode=mode, cache_dir=data_path, **kwargs) - elif dataset_name == "code_alpaca": - dataset = CodeAlpaca(mode=mode, cache_dir=data_path, **kwargs) + elif dataset_name in ["alpaca", "code_alpace"]: + train, eval = load_alpaca_dataset(dataset_name, val_split=val_split, cache_dir=data_path, **kwargs) elif dataset_name == "gpt4all": dataset = Gpt4All(mode=mode, cache_dir=data_path) elif dataset_name == "prosocial_dialogue": diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index be0e6eebe5..2b14062eb4 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -11,7 +11,8 @@ import numpy as np from datasets import load_dataset -from torch.utils.data import Dataset +from torch import Generator +from torch.utils.data import Dataset, Subset, random_split # @agoryuno contributed this re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") @@ -419,44 +420,60 @@ def __getitem__(self, index): return self.pairs[index] -class AlpacaBase(Dataset): - def __init__(self, dataset_name: str, mode: str, reverse_augmentation: bool = False, cache_dir: str = None) -> None: +class AlpacaDataset(Dataset): + def __init__(self, data: list, mode: str): super().__init__() + self.data = data + if mode not in ["sft", "rl"]: + raise NotImplementedError( + f"Alpaca Dataset for mode {self.mode} is not implemented. Currently supported modes are 'sft' and 'rl'." + ) self.mode = mode - dataset = load_dataset(dataset_name, cache_dir=cache_dir) - rows = [] - for row in dataset["train"]: - question = row["instruction"] - if len(row["input"]) > 0: - input_ = "{}\n{}".format(question, row["input"]) - else: - input_ = question - if reverse_augmentation: - rows.append((row["output"], input_)) - else: - rows.append((input_, row["output"])) - self.rows = rows def __len__(self): - return len(self.rows) + return len(self.data) def __getitem__(self, index): - question, answer = self.rows[index] + question, answer = self.data[index] if self.mode == "sft": return (question, answer) elif self.mode == "rl": return (question,) - else: - raise NotImplementedError( - f"Alpaca Dataset for mode {self.mode} is not implemented. Currently supported modes are 'sft' and 'rl'." - ) -class Alpaca(AlpacaBase): - def __init__(self, mode: str = "sft", cache_dir: str = None, **kwargs) -> None: - super().__init__(dataset_name="yahma/alpaca-cleaned", mode=mode, cache_dir=cache_dir, **kwargs) - - -class CodeAlpaca(AlpacaBase): - def __init__(self, mode: str = "sft", cache_dir: str = None, **kwargs) -> None: - super().__init__(dataset_name="sahil2801/CodeAlpaca-20k", mode=mode, cache_dir=cache_dir, **kwargs) +def load_alpaca_dataset( + dataset_name: str, + val_split: float, + cache_dir: str, + mode: str = "sft", + manual_seed: int = 287631038922, + reverse_augmentation: bool = False, +) -> tuple[AlpacaDataset, AlpacaDataset]: + # split on tree basis, messages from same tree must not end up in different splits + generator = Generator() + generator.manual_seed(manual_seed) + + def process_split(dataset: Subset, reverse_augmentation: bool = False) -> list[tuple[str, str]]: + data = [] + for row in dataset: + question = row["instruction"] + if len(row["input"]) > 0: + input_ = "{}\n{}".format(question, row["input"]) + else: + input_ = question + if reverse_augmentation: + data.append((row["output"], input_)) + else: + data.append((input_, row["output"])) + return data + + assert dataset_name in ["alpaca", "code_alpaca"] + if dataset_name == "alpaca": + dataset = load_dataset("yahma/alpaca-cleaned", cache_dir=cache_dir) + elif dataset_name == "code_alpaca": + dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir) + + splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator) + train = AlpacaDataset(process_split(splits[0], reverse_augmentation=reverse_augmentation), mode=mode) + val = AlpacaDataset(process_split(splits[1], reverse_augmentation=False), mode=mode) + return train, val From 42877bb818219a941fc4e5e0bc6815751a024497 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 10:08:22 +0200 Subject: [PATCH 5/8] add keep_unreversed keyword --- .../custom_datasets/qa_datasets.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index 2b14062eb4..c4d866da0a 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -448,12 +448,14 @@ def load_alpaca_dataset( mode: str = "sft", manual_seed: int = 287631038922, reverse_augmentation: bool = False, + keep_unreversed: bool = True, ) -> tuple[AlpacaDataset, AlpacaDataset]: - # split on tree basis, messages from same tree must not end up in different splits generator = Generator() generator.manual_seed(manual_seed) - def process_split(dataset: Subset, reverse_augmentation: bool = False) -> list[tuple[str, str]]: + def process_split( + dataset: Subset, reverse_augmentation: bool = False, keep_unreversed: bool = True + ) -> list[tuple[str, str]]: data = [] for row in dataset: question = row["instruction"] @@ -463,6 +465,9 @@ def process_split(dataset: Subset, reverse_augmentation: bool = False) -> list[t input_ = question if reverse_augmentation: data.append((row["output"], input_)) + # in case of reverse augmentation we just keep both, reversed and unreversed data + if keep_unreversed: + data.append((input_, row["output"])) else: data.append((input_, row["output"])) return data @@ -474,6 +479,10 @@ def process_split(dataset: Subset, reverse_augmentation: bool = False) -> list[t dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir) splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator) - train = AlpacaDataset(process_split(splits[0], reverse_augmentation=reverse_augmentation), mode=mode) - val = AlpacaDataset(process_split(splits[1], reverse_augmentation=False), mode=mode) + train = AlpacaDataset( + process_split(splits[0], reverse_augmentation=reverse_augmentation, keep_unreversed=keep_unreversed), mode=mode + ) + val = AlpacaDataset( + process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode + ) return train, val From 3939377b7ee9dac9c68631d287615e56e0e79b07 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 10:17:05 +0200 Subject: [PATCH 6/8] use different classnames for alpaca and codealpaca --- .../custom_datasets/qa_datasets.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index c4d866da0a..1e2e447b07 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -420,7 +420,7 @@ def __getitem__(self, index): return self.pairs[index] -class AlpacaDataset(Dataset): +class AlpacaBaseDataset(Dataset): def __init__(self, data: list, mode: str): super().__init__() self.data = data @@ -441,6 +441,14 @@ def __getitem__(self, index): return (question,) +class AlpacaDataset(AlpacaBaseDataset): + pass + + +class CodeAlpacaDataset(AlpacaBaseDataset): + pass + + def load_alpaca_dataset( dataset_name: str, val_split: float, @@ -449,7 +457,7 @@ def load_alpaca_dataset( manual_seed: int = 287631038922, reverse_augmentation: bool = False, keep_unreversed: bool = True, -) -> tuple[AlpacaDataset, AlpacaDataset]: +) -> tuple[AlpacaDataset, AlpacaDataset] | tuple[CodeAlpacaDataset, CodeAlpacaDataset]: generator = Generator() generator.manual_seed(manual_seed) @@ -475,14 +483,14 @@ def process_split( assert dataset_name in ["alpaca", "code_alpaca"] if dataset_name == "alpaca": dataset = load_dataset("yahma/alpaca-cleaned", cache_dir=cache_dir) + cls = AlpacaDataset elif dataset_name == "code_alpaca": dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir) + cls = CodeAlpacaDataset splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator) - train = AlpacaDataset( + train = cls( process_split(splits[0], reverse_augmentation=reverse_augmentation, keep_unreversed=keep_unreversed), mode=mode ) - val = AlpacaDataset( - process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode - ) + val = cls(process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode) return train, val From 6895461d967728b4d3b0b5fa16ac8b27729eae92 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Thu, 6 Apr 2023 10:29:27 +0200 Subject: [PATCH 7/8] fix issues --- model/model_training/custom_datasets/__init__.py | 2 +- model/model_training/custom_datasets/qa_datasets.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index 83e80e60c6..ad73a08401 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -113,7 +113,7 @@ def get_one_dataset( dataset = DiveMT() elif dataset_name == "webgpt": dataset = WebGPT(mode=mode) - elif dataset_name in ["alpaca", "code_alpace"]: + elif dataset_name in ["alpaca", "code_alpaca"]: train, eval = load_alpaca_dataset(dataset_name, val_split=val_split, cache_dir=data_path, **kwargs) elif dataset_name == "gpt4all": dataset = Gpt4All(mode=mode, cache_dir=data_path) diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index 1e2e447b07..c7f9f2bf42 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -480,13 +480,14 @@ def process_split( data.append((input_, row["output"])) return data - assert dataset_name in ["alpaca", "code_alpaca"] if dataset_name == "alpaca": dataset = load_dataset("yahma/alpaca-cleaned", cache_dir=cache_dir) cls = AlpacaDataset elif dataset_name == "code_alpaca": dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir) cls = CodeAlpacaDataset + else: + raise ValueError(f"Expected dataset_name to be 'alapaca' or 'code_alpaca'. Received {dataset_name}.") splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator) train = cls( From d447a6879e64f6629754d5c3ca1de66e8d9caa65 Mon Sep 17 00:00:00 2001 From: "tobias.pitters" Date: Fri, 7 Apr 2023 08:39:39 +0200 Subject: [PATCH 8/8] updates due to PR discussions --- model/model_training/configs/config.yaml | 4 ---- .../custom_datasets/__init__.py | 2 +- .../custom_datasets/qa_datasets.py | 20 ++++++------------- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml index 645ed68bab..f918c00ad9 100644 --- a/model/model_training/configs/config.yaml +++ b/model/model_training/configs/config.yaml @@ -268,10 +268,6 @@ pythia-70m-deduped: per_device_train_batch_size: 2 per_device_eval_batch_size: 4 output_dir: pythia_model - datasets: - - alpaca: - reverse_augmentation: True - - webgpt pythia-1B: learning_rate: 8e-6 diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index ad73a08401..862299e5b5 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -113,7 +113,7 @@ def get_one_dataset( dataset = DiveMT() elif dataset_name == "webgpt": dataset = WebGPT(mode=mode) - elif dataset_name in ["alpaca", "code_alpaca"]: + elif dataset_name in ("alpaca", "code_alpaca"): train, eval = load_alpaca_dataset(dataset_name, val_split=val_split, cache_dir=data_path, **kwargs) elif dataset_name == "gpt4all": dataset = Gpt4All(mode=mode, cache_dir=data_path) diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index c7f9f2bf42..bb8eee5659 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -424,7 +424,7 @@ class AlpacaBaseDataset(Dataset): def __init__(self, data: list, mode: str): super().__init__() self.data = data - if mode not in ["sft", "rl"]: + if mode not in ("sft", "rl"): raise NotImplementedError( f"Alpaca Dataset for mode {self.mode} is not implemented. Currently supported modes are 'sft' and 'rl'." ) @@ -441,14 +441,6 @@ def __getitem__(self, index): return (question,) -class AlpacaDataset(AlpacaBaseDataset): - pass - - -class CodeAlpacaDataset(AlpacaBaseDataset): - pass - - def load_alpaca_dataset( dataset_name: str, val_split: float, @@ -457,7 +449,7 @@ def load_alpaca_dataset( manual_seed: int = 287631038922, reverse_augmentation: bool = False, keep_unreversed: bool = True, -) -> tuple[AlpacaDataset, AlpacaDataset] | tuple[CodeAlpacaDataset, CodeAlpacaDataset]: +) -> tuple[AlpacaBaseDataset, AlpacaBaseDataset]: generator = Generator() generator.manual_seed(manual_seed) @@ -482,16 +474,16 @@ def process_split( if dataset_name == "alpaca": dataset = load_dataset("yahma/alpaca-cleaned", cache_dir=cache_dir) - cls = AlpacaDataset elif dataset_name == "code_alpaca": dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir) - cls = CodeAlpacaDataset else: raise ValueError(f"Expected dataset_name to be 'alapaca' or 'code_alpaca'. Received {dataset_name}.") splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator) - train = cls( + train = AlpacaBaseDataset( process_split(splits[0], reverse_augmentation=reverse_augmentation, keep_unreversed=keep_unreversed), mode=mode ) - val = cls(process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode) + val = AlpacaBaseDataset( + process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode + ) return train, val