diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py
index 29eb24f3b8..50f0a1bf6e 100644
--- a/model/model_training/custom_datasets/__init__.py
+++ b/model/model_training/custom_datasets/__init__.py
@@ -5,7 +5,12 @@
 import numpy as np
 
 from model_training.custom_datasets.extra_rm_datasets import load_anthropic_rlhf, load_hellaswag, load_shp
-from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
+from model_training.custom_datasets.instruction import (
+    INSTRUCTION_DATASETS,
+    RAG_DATASETS,
+    InstructionDataset,
+    RAGDataset,
+)
 from model_training.custom_datasets.oasst_dataset import load_oasst_export
 from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
 from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
@@ -181,6 +186,8 @@
         dataset = OrcaChat(cache_dir=data_path, **kwargs)
     elif dataset_name == "dolphin-mix":
         dataset = DolphinMix(cache_dir=data_path, **kwargs)
+    elif dataset_name in RAG_DATASETS:
+        dataset = RAGDataset(dataset_name, cache_dir=data_path, **kwargs)
     else:
         raise ValueError(f"Unknown dataset {dataset_name}")
 
diff --git a/model/model_training/custom_datasets/instruction.py b/model/model_training/custom_datasets/instruction.py
index 7b6ad39787..e932d26259 100644
--- a/model/model_training/custom_datasets/instruction.py
+++ b/model/model_training/custom_datasets/instruction.py
@@ -124,3 +124,40 @@
             answers=answers,
             lang=lang,
         )
+
+
+# Maps a short dataset key to its Hugging Face Hub repository id.
+RAG_DATASETS = {
+    "multi-chapter-summaries": "shahules786/Multi-chapter-summaries",
+}
+
+
+class RAGDataset(Dataset):
+    """SFT dataset built from retrieval-augmented-generation style data:
+    each item pairs a prompt and its supporting context with a reference response."""
+
+    def __init__(
+        self,
+        dataset: str,
+        split: str = "train",
+        cache_dir: str = ".cache/",
+    ):
+        if dataset not in RAG_DATASETS:
+            raise ValueError(f"Invalid dataset {dataset}")
+
+        # Column names holding the prompt/context/response fields for this dataset.
+        if dataset == "multi-chapter-summaries":
+            self.prompt, self.context, self.response = "prompt", "context", "summary"
+
+        self.dataset = load_dataset(RAG_DATASETS[dataset], cache_dir=cache_dir)[split]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        row = self.dataset[idx]
+        prompt, context, response = row[self.prompt], row[self.context], row[self.response]
+
+        # NOTE(review): prompt and context are joined without a separator; confirm the
+        # prompt template makes the boundary unambiguous to the tokenizer.
+        return create_dataset_entry_qa(mode="sft", questions=[prompt + context], answers=[response])