From b05fd9b664bc7d806a5057c4f1817752a09d9b87 Mon Sep 17 00:00:00 2001 From: Tobias Pitters <31857876+CloseChoice@users.noreply.github.com> Date: Sat, 22 Apr 2023 00:13:11 +0200 Subject: [PATCH] Add poem instruction ds (#2813) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I checked the poem instruction for empty strings, references and appearances of "openai". As expected for a poem dataset, none of these things was found. --------- Co-authored-by: Andreas Köpf --- model/model_training/custom_datasets/__init__.py | 4 ++-- model/model_training/custom_datasets/instruction.py | 1 + model/model_training/custom_datasets/qa_datasets.py | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py index 4cbf314fe1..13029d750f 100644 --- a/model/model_training/custom_datasets/__init__.py +++ b/model/model_training/custom_datasets/__init__.py @@ -162,9 +162,9 @@ def get_one_dataset( elif dataset_name == "hellaswag": train, eval = load_hellaswag() elif dataset_name == "dolly15k": - dataset = DatabricksDolly15k(cache_dir=data_path) + dataset = DatabricksDolly15k(cache_dir=data_path, mode=mode, **kwargs) elif dataset_name == "alpaca_gpt4": - dataset = AlpacaGpt4(cache_dir=data_path, **kwargs) + dataset = AlpacaGpt4(cache_dir=data_path, mode=mode, **kwargs) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/model_training/custom_datasets/instruction.py b/model/model_training/custom_datasets/instruction.py index 3ca022d62f..05aa492836 100644 --- a/model/model_training/custom_datasets/instruction.py +++ b/model/model_training/custom_datasets/instruction.py @@ -18,6 +18,7 @@ "zhihu-kol": "wangrui6/zhihu-kol", "minimath": "kentsui/minimath", "oa_wiki_qa_bart_10000row": "michaelthwan/oa_wiki_qa_bart_10000row", + "poem_instructions": "checkai/instruction-poems", } diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index e79f8c3ddb..a0b2714a82 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -21,6 +21,8 @@ # @agoryuno contributed this re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") re_single_reference_remove = re.compile(r"\[\s?\d+\s?\]") + +# check if the whole string is just a combination of (multiple) whitespaces and newlines re_whitespace_newline_match = re.compile(r"^[\s\n]*$") @@ -450,6 +452,7 @@ def process_split( dataset: Subset, reverse_augmentation: bool = False, keep_unreversed: bool = True ) -> list[tuple[str, str]]: data = [] + for row in dataset: question = row["instruction"] if len(row["input"]) > 0: