diff --git a/model/model_training/check_dataset_appearances.py b/model/model_training/check_dataset_appearances.py
new file mode 100644
index 0000000000..d80b894e06
--- /dev/null
+++ b/model/model_training/check_dataset_appearances.py
@@ -0,0 +1,109 @@
+"""
+This script helps to detect keywords or other unwanted appearances in the datasets.
+RUN WITH:
+python check_dataset_appearances.py -d <datasets> --cache_dir <cache_dir> --mode <mode>
+
+e.g.:
+python check_dataset_appearances.py -d gpt4all webgpt --cache_dir .cache --mode sft
+"""
+import argparse
+import pprint
+from collections import defaultdict
+
+from model_training.custom_datasets import get_one_dataset
+from model_training.custom_datasets.entities import Mode
+from model_training.custom_datasets.formatting import DatasetEntry
+from model_training.custom_datasets.qa_datasets import (
+    re_reference_remove,
+    re_single_reference_remove,
+    re_whitespace_newline_match,
+)
+from model_training.custom_datasets.utils import FILTER_BY_WORDS
+
+RE_TO_CHECK = [re_whitespace_newline_match, re_reference_remove, re_single_reference_remove]
+STRINGS_TO_CHECK = [*FILTER_BY_WORDS]
+
+
+def argument_parsing():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-d",
+        "--datasets",
+        nargs="+",
+        required=True,
+        help="""
+        Multiple datasets can be passed.
+        For example, run as:
+
+            ./check_dataset_appearances.py --datasets math oasst_export_eu
+
+        to check the math and the oasst_export_eu datasets.
+        """,
+    )
+    parser.add_argument("--mode", dest="mode", type=Mode, choices=list(Mode))
+    parser.add_argument("--cache_dir", dest="cache_dir", type=str)
+
+    args, _ = parser.parse_known_args()
+
+    return args
+
+
+def check_in_dataset_row(row: str | list[str] | tuple[str], matched: dict[str, list]) -> dict[str, list]:
+    def _check_single_string(row: str, matched: dict[str, list]) -> dict[str, list]:
+        for exp in RE_TO_CHECK:
+            if exp.match(row) is not None:
+                matched[exp].append(row)
+        for string in STRINGS_TO_CHECK:
+            if string in row:
+                string_idx = row.index(string)
+                matched[string].append(row[max(string_idx - 50, 0) : string_idx + 50])
+        return matched
+
+    if isinstance(row, str):
+        matched = _check_single_string(row, matched)
+    elif isinstance(row, (list, tuple)):
+        for r in row:
+            if not isinstance(r, str):
+                raise ValueError(f"Unexpected type: {type(r)}")
+            matched = _check_single_string(r, matched)
+    elif isinstance(row, DatasetEntry):
+        formatted = row.get_formatted(mode=args.mode, eos_token="</s>")
+        for r in formatted:
+            if not isinstance(r, str):
+                raise ValueError(f"Unexpected type: {type(r)}")
+            matched = _check_single_string(
+                r.replace("<|assistant|>", "").replace("<|prompter|>", "").replace("</s>", ""), matched
+            )
+    else:
+        raise ValueError(f"Received unexpected type: {type(row)}.")
+    return matched
+
+
+def iterate_over_dataset(ds):
+    matched = defaultdict(list)
+    for row in ds:
+        check_in_dataset_row(row, matched)
+    return matched
+
+
+if __name__ == "__main__":
+    args = argument_parsing()
+    pp = pprint.PrettyPrinter(indent=4)
+
+    train_datasets, val_datasets = {}, {}
+    for dataset_name in args.datasets:
+        print(f"start with dataset {dataset_name}")
+        train, val = get_one_dataset(None, dataset_name, mode=args.mode.value, data_path=args.cache_dir)
+        train_datasets[dataset_name] = train
+        if val is not None:
+            val_datasets[dataset_name] = val
+        matched_train = iterate_over_dataset(train)
+        matched_val = iterate_over_dataset(val) if val is not None else {}
+        if len(matched_train) != 0:
+            pp.pprint(f"Found the following occurrences in TRAIN {dataset_name}:")
+            pp.pprint(dict(matched_train))
+        if len(matched_val) != 0:
+            pp.pprint(f"Found the following occurrences in VAL {dataset_name}:")
+            pp.pprint(dict(matched_val))
+        if len(matched_train) + len(matched_val) == 0:
+            print("Did not find any of the specified regular expressions or filter words.")
diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py
index 31dc4ca858..e79f8c3ddb 100644
--- a/model/model_training/custom_datasets/qa_datasets.py
+++ b/model/model_training/custom_datasets/qa_datasets.py
@@ -20,6 +20,8 @@
 
 # @agoryuno contributed this
 re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
+re_single_reference_remove = re.compile(r"\[\s?\d+\s?\]")
+re_whitespace_newline_match = re.compile(r"^[\s\n]*$")
 
 LINKING_CHARS = ["\n", "\n\n", " "]
 
@@ -486,35 +488,49 @@ class Vicuna(Dataset):
     name = "vicuna"
 
     @staticmethod
-    def process_vicuna_conversations(data: list[dict[str, None | str]], input_max_length: int) -> list[str] | None:
-        dialogue = []
+    def process_vicuna_conversations(
+        data: list[dict[str, None | str]], input_max_length: int
+    ) -> tuple[list[str], list[str]] | None:
         role = None
         messages = []
         # drop conversations that start with Bot
         if len(data["conversations"]) == 0 or data["conversations"][0]["from"] != "human":
             return None
+        questions = []
+        answers = []
         for line in data["conversations"]:
             speaker = line["from"]  # 'human' or 'gpt'
             message = line["value"]
-
+            if message is None or message == "":
+                if speaker == "gpt":
+                    return None
+                elif speaker == "human":
+                    # replace empty messages with one of the following
+                    message = random.choice(["...", "Please continue", "Go on", ""])
             # remove markdown escaping in revision 192ab2185289094fc556ec8ce5ce1e8e587154ca
             # python-markdownify with escape_asterisks & escape_underscores True is used
             # for pre-processing the dataset.
             # See also https://github.com/LAION-AI/Open-Assistant/issues/2510
             message = message.replace(r"\_", "_")
             message = message.replace(r"\*", "*")
+            message = re_single_reference_remove.sub("", message)
 
             if role != speaker:
                 if role is not None:
-                    dialogue.append("\n".join(messages))
+                    if role == "human":
+                        questions.append("\n".join(messages)[:input_max_length])
+                    if role == "gpt":
+                        answers.append("\n".join(messages)[:input_max_length])
                     messages = []
                 role = speaker
             messages.append(message.strip())
 
         if role is not None and len(messages) > 0:
-            dialogue.append("\n".join(messages))
-        dialogue_truncated = [k[:input_max_length] for k in dialogue]
-        return dialogue_truncated
+            if role == "human":
+                questions.append("\n".join(messages)[:input_max_length])
+            if role == "gpt":
+                answers.append("\n".join(messages)[:input_max_length])
+        return questions, answers
 
     def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: int = 2048) -> None:
         super().__init__()
@@ -530,20 +546,15 @@ def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: i
             revision="192ab2185289094fc556ec8ce5ce1e8e587154ca",
         )["train"]
         for data in dataset:
-            if (
-                processed_data := self.process_vicuna_conversations(data, input_max_length=input_max_length)
-            ) is not None:
-                self.pairs.append(processed_data)
+            if (qa := self.process_vicuna_conversations(data, input_max_length=input_max_length)) is not None:
+                self.pairs.append(DatasetEntry(questions=qa[0], answers=qa[1]))
 
     def __len__(self) -> int:
         return len(self.pairs)
 
-    def __getitem__(self, index: int) -> list[str] | tuple[str]:
-        dialogue: list[str] = self.pairs[index]
-        if self.mode == "sft":
-            return dialogue
-        elif self.mode == "rl":
-            return tuple(dialogue[:-1])
+    def __getitem__(self, index: int) -> DatasetEntry:
+        dialogue = self.pairs[index]
+        return dialogue
 
 
 class DatabricksDolly15k(Dataset):
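For reference, a minimal illustrative sketch of what the two regular expressions added to `qa_datasets.py` match; the snippet is not part of the change and only restates the patterns from the diff:

```python
import re

# Patterns as added in qa_datasets.py above.
re_single_reference_remove = re.compile(r"\[\s?\d+\s?\]")
re_whitespace_newline_match = re.compile(r"^[\s\n]*$")

# Single citation markers such as "[1]" or "[ 2 ]" are stripped from messages.
assert re_single_reference_remove.sub("", "See [1] and [ 2 ].") == "See  and ."

# Rows consisting only of whitespace/newlines are flagged by the checker.
assert re_whitespace_newline_match.match(" \n\t ") is not None
assert re_whitespace_newline_match.match("real content") is None
```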
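And a small usage sketch of the reworked Vicuna pipeline, assuming the `DatasetEntry` API used above; `Mode.sft` and the exact formatted output are assumptions inferred from the check script's `get_formatted` call, not verified here:

```python
from model_training.custom_datasets.entities import Mode
from model_training.custom_datasets.formatting import DatasetEntry

# process_vicuna_conversations now merges consecutive same-speaker messages
# and returns parallel question/answer lists, which __init__ wraps like this:
entry = DatasetEntry(
    questions=["Hi\nCan you help me with a regex?"],
    answers=["Sure, what should it match?"],
)

# __getitem__ returns the DatasetEntry itself in every mode; mode-specific
# formatting is deferred to get_formatted(), as the check script above does.
formatted = entry.get_formatted(mode=Mode.sft, eos_token="</s>")
print(formatted)  # presumably e.g. ["<|prompter|>Hi\n...</s>", "<|assistant|>Sure, ...</s>"]
```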