Add dataset loader for fanfics-10k-50k dataset (#3585)
To use the
[atom-in-the-universe/fanfics-10k-50k](https://huggingface.co/datasets/atom-in-the-universe/fanfics-10k-50k)
dataset during training, add `fanfics` to the dataset configuration, e.g.:
```
  datasets:
    - fanfics
```
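
For a quick sanity check outside the trainer, the new loader can also be instantiated directly. A minimal sketch, assuming a local cache directory (the path below is only an example):

```
# Build the FanFics loader directly and inspect one entry.
# Constructor arguments mirror the class added in pretrain_datasets.py;
# ".cache" is an arbitrary example cache directory.
from model_training.custom_datasets.pretrain_datasets import FanFics

ds = FanFics(cache_dir=".cache", mode="sft", char_max_len=65536, random_offset=False)
print(len(ds))            # number of documents in the train split
print(ds[0].text[:200])   # first 200 characters of the first entry
```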
andreaskoepf committed Jul 19, 2023
1 parent 280d3bb commit 64ece88
Showing 3 changed files with 47 additions and 5 deletions.
1 change: 1 addition & 0 deletions model/model_training/configs/config.yaml
@@ -88,6 +88,7 @@ defaults:
deepspeed_config: configs/zero_config.json
peft_model: false
peft_type: "lora"
superhot: false

use_system_tag:
use_system_tag: True
4 changes: 3 additions & 1 deletion model/model_training/custom_datasets/__init__.py
@@ -7,7 +7,7 @@
from model_training.custom_datasets.extra_rm_datasets import load_anthropic_rlhf, load_hellaswag, load_shp
from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
from model_training.custom_datasets.oasst_dataset import load_oasst_export
from model_training.custom_datasets.pretrain_datasets import RedPajama
from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
from model_training.custom_datasets.prompt_dialogue import Gpt4All, OrcaChat, load_oig_file
from model_training.custom_datasets.qa_datasets import (
SODA,
@@ -170,6 +170,8 @@ def get_one_dataset(
dataset = AlpacaGpt4(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "red_pajama":
dataset = RedPajama(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "fanfics":
dataset = FanFics(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "gpteacher_roleplay":
dataset = GPTeacher_Roleplay(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "orca-chat":
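The change above extends the `elif` chain in `get_one_dataset`; the pattern is a simple name-to-class dispatch that forwards per-dataset options as keyword arguments. A self-contained sketch of that pattern (`FanFicsStub` and `build_dataset` are made-up stand-ins, not the project's API):

```
# Illustrative stand-in for the dataset-name dispatch shown above.
class FanFicsStub:
    def __init__(self, cache_dir, mode="sft", **kwargs):
        self.cache_dir, self.mode, self.kwargs = cache_dir, mode, kwargs

LOADERS = {"fanfics": FanFicsStub}

def build_dataset(dataset_name, data_path, mode="sft", **kwargs):
    if dataset_name not in LOADERS:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    # cache_dir, mode and any extra per-dataset options reach the loader here.
    return LOADERS[dataset_name](cache_dir=data_path, mode=mode, **kwargs)

ds = build_dataset("fanfics", data_path=".cache", char_max_len=32768)
```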
47 changes: 43 additions & 4 deletions model/model_training/custom_datasets/pretrain_datasets.py
@@ -1,7 +1,9 @@
"""
Datasets for LM objective pre-training aimed to prevent catastrophic forgetting during fine-tuning
"""
import random
from pathlib import Path
from typing import Optional

from datasets import load_dataset
from model_training.custom_datasets.formatting import DatasetEntryLm
@@ -11,17 +13,54 @@
class RedPajama(Dataset):
name = "red_pajama"

def __init__(self, cache_dir: str | Path, mode: str = "sft", char_max_len: str = 9216) -> None:
def __init__(
self,
cache_dir: str | Path,
mode: str = "sft",
char_max_len: Optional[int] = 65536,
random_offset: bool = False,
) -> None:
super().__init__()
self.mode = mode
assert mode in ("sft", "rm", "rl")
self.char_max_len = char_max_len

self.random_offset = random_offset
self.dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=cache_dir)["train"]

def __len__(self) -> int:
return len(self.dataset)

def __getitem__(self, index) -> DatasetEntryLm:
dialogue = DatasetEntryLm(text=self.dataset[index]["text"][: self.char_max_len])
return dialogue
text = self.dataset[index]["text"]
if self.char_max_len and len(text) > self.char_max_len:
offset = 0 if not self.random_offset else random.randrange(len(text) - self.char_max_len)
text = text[offset : offset + self.char_max_len]
return DatasetEntryLm(text=text)


class FanFics(Dataset):
name = "fanfics"

def __init__(
self,
cache_dir: str | Path,
mode: str = "sft",
char_max_len: Optional[int] = 65536,
random_offset: bool = False,
) -> None:
super().__init__()
self.mode = mode
assert mode in ("sft", "rm", "rl")
self.char_max_len = char_max_len
self.random_offset = random_offset
self.dataset = load_dataset("atom-in-the-universe/fanfics-10k-50k", cache_dir=cache_dir)["train"]

def __len__(self) -> int:
return len(self.dataset)

def __getitem__(self, index) -> DatasetEntryLm:
text = self.dataset[index]["TEXT"]
if self.char_max_len and len(text) > self.char_max_len:
offset = 0 if not self.random_offset else random.randrange(len(text) - self.char_max_len)
text = text[offset : offset + self.char_max_len]
return DatasetEntryLm(text=text)
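
Both loaders now share the same truncation behaviour: keep at most `char_max_len` characters, taken either from the start of the document or, with `random_offset=True`, from a random window. A small sketch of that slicing logic on a toy string (the `truncate` helper is made up for illustration):

```
import random

# Hypothetical helper mirroring the __getitem__ logic above.
def truncate(text: str, char_max_len: int | None, random_offset: bool) -> str:
    if char_max_len and len(text) > char_max_len:
        offset = 0 if not random_offset else random.randrange(len(text) - char_max_len)
        return text[offset : offset + char_max_len]
    return text

sample = "x" * 100_000
print(len(truncate(sample, 65536, False)))  # 65536 chars from the beginning
print(len(truncate(sample, 65536, True)))   # 65536 chars from a random offset
```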
