diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py
index 4c66d06008..f1b8c92082 100644
--- a/model/model_training/custom_datasets/__init__.py
+++ b/model/model_training/custom_datasets/__init__.py
@@ -13,7 +13,7 @@
 )
 from model_training.custom_datasets.oasst_dataset import load_oasst_export
 from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
-from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
+from model_training.custom_datasets.prompt_dialogue import BestOfMegacode, DolphinMix, Gpt4All, OrcaChat, load_oig_file
 from model_training.custom_datasets.qa_datasets import (
     SODA,
     AlpacaGpt4,
@@ -188,6 +188,8 @@ def get_one_dataset(
         dataset = DolphinMix(cache_dir=data_path, **kwargs)
     elif dataset_name in RAG_DATASETS.keys():
         dataset = RAGDataset(dataset_name, cache_dir=data_path, **kwargs)
+    elif dataset_name == "bestofmegacode":
+        dataset = BestOfMegacode(cache_dir=data_path, **kwargs)
     else:
         raise ValueError(f"Unknown dataset {dataset_name}")

diff --git a/model/model_training/custom_datasets/prompt_dialogue.py b/model/model_training/custom_datasets/prompt_dialogue.py
index 1d30458cb3..d573798141 100644
--- a/model/model_training/custom_datasets/prompt_dialogue.py
+++ b/model/model_training/custom_datasets/prompt_dialogue.py
@@ -11,7 +11,9 @@
 from model_training.custom_datasets.utils import _filter_by_words
 from torch import Generator, randperm
 from torch.utils.data import Dataset, random_split
-
+import datasets
+# Monkeypatch: skip the datasets library's free-disk-space check, which can misreport on some filesystems.
+datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory='.': True
 
 def load_oig_file(
     source_url: str,
@@ -172,15 +174,18 @@ def __getitem__(self, index: int) -> list[str] | tuple[str]:
 class OrcaChat(Dataset):
     name = "orca-chat"
 
-    def __init__(self, data_files: Union[List[str], str] = "orca-chat-gpt4.json", cache_dir: str = None) -> None:
-        self.dataset = load_dataset("shahules786/orca-chat", split="train", data_files=data_files, cache_dir=cache_dir)
-
+    def __init__(self, rows_per_conv: int = 1, use_auth_token: Optional[Union[bool, str]] = None, cache_dir: str = None) -> None:
+        self.dataset = load_dataset(
+            "shahules786/orca-best", split="train", use_auth_token=use_auth_token, cache_dir=cache_dir
+        )
+        self.rows_per_conv = rows_per_conv
+
     def __len__(self):
         return len(self.dataset)
 
     def __getitem__(self, idx):
         conversation, instruction = [self.dataset[idx][key] for key in ("conversation", "instruction")]
-        conversation = [(item["input"], item["output"]) for item in conversation]
+        conversation = [(item["input"], item["output"]) for item in conversation["samples"][: self.rows_per_conv]]
         conversation = list(sum(conversation, ()))
         conv_utt: list[Utterance] = [
             (
@@ -195,6 +200,34 @@ def __getitem__(self, idx):
         return DatasetEntrySft(conversation=conv_utt, system_message=instruction)
 
 
+class BestOfMegacode(Dataset):
+    name = "bestofmegacode"
+
+    def __init__(self, rows_per_conv: int = 1, use_auth_token: Optional[Union[bool, str]] = None, cache_dir: str = None) -> None:
+        self.dataset = load_dataset(
+            "shahules786/megacode-best", split="train", use_auth_token=use_auth_token, cache_dir=cache_dir
+        )
+        self.rows_per_conv = rows_per_conv
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        conversation = self.dataset[idx]["conversation"]
+        conversation = [(item["USER"], item["ASSISTANT"]) for item in conversation["samples"][: self.rows_per_conv]]
+        conversation = list(sum(conversation, ()))
+        conv_utt: list[Utterance] = [
+            (
+                Utterance(
+                    text=conv,
+                    role=Role.prompter if i % 2 == 0 else Role.assistant,
+                )
+            )
+            for i, conv in enumerate(conversation)
+        ]
+
+        return DatasetEntrySft(conversation=conv_utt)
+
 class DolphinMix(Dataset):
     name = "dophin-mix"
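
Usage note: a minimal sketch of how the new dataset can be exercised once this branch is checked out. The class name, dataset id, and the `rows_per_conv` parameter come from the diff above; instantiating the class directly and the ".cache" path are illustrative assumptions, not part of the change.

    # Illustrative sketch, not part of the diff: load the new dataset directly.
    # Assumes the PR branch is importable; ".cache" is an arbitrary example path.
    from model_training.custom_datasets.prompt_dialogue import BestOfMegacode

    ds = BestOfMegacode(rows_per_conv=1, cache_dir=".cache")
    print(len(ds))  # one entry per conversation in shahules786/megacode-best

    entry = ds[0]  # DatasetEntrySft with alternating prompter/assistant utterances
    for utterance in entry.conversation:
        print(utterance.role, utterance.text[:80])

Equivalently, listing "bestofmegacode" among a training config's datasets routes through the new branch added to get_one_dataset in __init__.py.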