[NeMo-UX] Add data module #9133

Merged: 3 commits, May 8, 2024
54 changes: 52 additions & 2 deletions nemo/lightning/base.py
@@ -1,20 +1,70 @@
import gc
import inspect
import os
from pathlib import Path
from typing import Optional
from typing import Generic, Optional, Type, TypeVar

import torch
import torch.distributed
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from torch import nn

from nemo import io

DEFAULT_NEMO_CACHE_HOME = Path.home() / ".cache" / "nemo"
NEMO_CACHE_HOME = Path(os.getenv("NEMO_HOME", DEFAULT_NEMO_CACHE_HOME))
DEFAULT_NEMO_DATASETS_CACHE = NEMO_CACHE_HOME / "datasets"
NEMO_DATASETS_CACHE = Path(os.getenv("NEMO_DATASETS_CACHE", DEFAULT_NEMO_DATASETS_CACHE))
DEFAULT_NEMO_MODELS_CACHE = NEMO_CACHE_HOME / "models"
NEMO_MODELS_CACHE = Path(os.getenv("NEMO_MODELS_CACHE", DEFAULT_NEMO_MODELS_CACHE))

#
# @dataclass
# class DataConfig:
#     seq_length: int
#     micro_batch_size: int = 4
#     global_batch_size: int = 8
#     rampup_batch_size: Optional[List[int]] = None
#     train_drop_last: bool = True
#     val_drop_last: bool = True
#     test_drop_last: bool = True
#     num_workers: int = 8
#     pin_memory: bool = True
#     persistent_workers: bool = False
#
#     @property
#     def num_microbatches(self) -> int:
#         from apex.transformer.pipeline_parallel.utils import get_num_microbatches
#
#         return get_num_microbatches()
#
#
ModelT = TypeVar("ModelT", bound=LightningModule)


class ModelConfig(Generic[ModelT], io.IOMixin):
    def model_cls(self) -> Type[ModelT]:
        raise NotImplementedError("Must be implemented by subclass")

    @property
    def model_type(self) -> Type[ModelT]:
        return self.model_cls()

    def init(self, *args, data=None, cpu: bool = False, **kwargs) -> ModelT:
        model_cls = self.model_cls()
        if data:
            kwargs.update(data.model_kwargs())

        signature = inspect.signature(model_cls.__init__)
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}

        model = model_cls(self, *args, **filtered_kwargs)

        if not cpu:
            model.cuda(torch.cuda.current_device())

        return model


def get_vocab_size(config, vocab_size: int, make_vocab_size_divisible_by: int = 128,) -> int:
    from nemo.utils import logging
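For orientation, a minimal sketch of how the new ModelConfig generic above could be subclassed. MyGPTModel, MyGPTConfig, and their arguments are hypothetical and not part of this PR, and any extra requirements imposed by io.IOMixin are not visible in this diff:

from typing import Optional, Type

import pytorch_lightning as pl

from nemo.lightning.base import ModelConfig


class MyGPTModel(pl.LightningModule):
    # Hypothetical model; ModelConfig.init() passes the config object as the first positional argument.
    def __init__(self, config, tokenizer: Optional[object] = None):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer


class MyGPTConfig(ModelConfig[MyGPTModel]):
    def model_cls(self) -> Type[MyGPTModel]:
        return MyGPTModel


# init() inspects MyGPTModel.__init__ and silently drops kwargs the model does not accept.
model = MyGPTConfig().init(cpu=True, tokenizer=None, unused_kwarg=123)

Passing cpu=True skips the .cuda() move, which keeps the sketch runnable without a GPU.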
6 changes: 5 additions & 1 deletion nemo/llm/gpt/data/__init__.py
@@ -1,3 +1,7 @@
from nemo.llm.gpt.data.dolly import DollyDataModule
from nemo.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.llm.gpt.data.mock import MockDataModule
from nemo.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.llm.gpt.data.squad import SquadDataModule

__all__ = ["MockDataModule"]
__all__ = ["FineTuningDataModule", "SquadDataModule", "DollyDataModule", "MockDataModule", "PreTrainingDataModule"]
57 changes: 57 additions & 0 deletions nemo/llm/gpt/data/core.py
@@ -0,0 +1,57 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from nemo.lightning.base import NEMO_DATASETS_CACHE

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec
    from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset


def get_dataset_root(name: str) -> Path:
    output = Path(NEMO_DATASETS_CACHE) / name
    output.mkdir(parents=True, exist_ok=True)

    return output


def create_sft_dataset(
    path: Path,
    tokenizer: "TokenizerSpec",
    seq_length: int = 2048,
    add_bos: bool = False,
    add_eos: bool = True,
    add_sep: bool = False,
    seed: int = 1234,
    label_key: str = 'output',
    answer_only_loss: bool = True,
    truncation_field: str = 'input',
    pad_to_max_length: bool = False,
    index_mapping_dir: Optional[str] = None,
    prompt_template: str = '{input} {output}',
    truncation_method: str = 'right',
    memmap_workers: int = 2,
    hf_dataset: bool = False,
    **kwargs
) -> "GPTSFTDataset":
    from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset

    return GPTSFTDataset(
        file_path=str(path),
        tokenizer=tokenizer,
        max_seq_length=seq_length,
        memmap_workers=memmap_workers,
        hf_dataset=hf_dataset,
        add_bos=add_bos,
        add_eos=add_eos,
        add_sep=add_sep,
        seed=seed,
        label_key=label_key,
        answer_only_loss=answer_only_loss,
        truncation_field=truncation_field,
        pad_to_max_length=pad_to_max_length,
        index_mapping_dir=index_mapping_dir,
        prompt_template=prompt_template,
        truncation_method=truncation_method,
        **kwargs
    )
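As a usage note, a hedged sketch of calling create_sft_dataset directly; "my_sft_data" is a hypothetical dataset name and the jsonl file is assumed to already exist:

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.llm.gpt.data.core import create_sft_dataset, get_dataset_root

# Same default tokenizer that FineTuningDataModule (below) falls back to.
tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")

# Hypothetical dataset name; training.jsonl must already exist under this root.
train_file = get_dataset_root("my_sft_data") / "training.jsonl"

dataset = create_sft_dataset(
    path=train_file,
    tokenizer=tokenizer,
    seq_length=2048,
    prompt_template='{input} {output}',
)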
122 changes: 122 additions & 0 deletions nemo/llm/gpt/data/dolly.py
@@ -0,0 +1,122 @@
import json
import shutil
from typing import TYPE_CHECKING, List, Optional

import numpy as np
from datasets import load_dataset

from nemo.llm.gpt.data.core import get_dataset_root
from nemo.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.utils import logging

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec


class DollyDataModule(FineTuningDataModule):
    """A data module for fine-tuning on the Dolly dataset.

    This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the
    "databricks/databricks-dolly-15k" dataset. It handles data download, preprocessing, splitting, and preparing the data
    in a format suitable for training, validation, and testing.

    Args:
        force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. Defaults to False.
        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. Defaults to True.
        See FineTuningDataModule for the remaining arguments.
    """

    def __init__(
        self,
        seq_length: int = 2048,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        force_redownload: bool = False,
        delete_raw: bool = True,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        self.force_redownload = force_redownload
        self.delete_raw = delete_raw

        super().__init__(
            dataset_root=get_dataset_root("dolly"),
            seq_length=seq_length,
            tokenizer=tokenizer,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
            seed=seed,
            memmap_workers=memmap_workers,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    def prepare_data(self) -> None:
        # If the training file already exists locally, skip downloading.
        if self.train_path.exists() and not self.force_redownload:
            return

        dset = self._download_data()
        self._preprocess_and_split_data(dset)

    def _download_data(self):
        logging.info(f"Downloading {self.__class__.__name__}...")
        return load_dataset(
            "databricks/databricks-dolly-15k",
            cache_dir=str(self.dataset_root),
            download_mode="force_redownload" if self.force_redownload else None,
        )

    def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):
        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")

        test_ratio = 1 - train_ratio - val_ratio
        save_splits = {}
        dataset = dset.get('train')
        split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
        split_dataset2 = split_dataset['test'].train_test_split(
            test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
        )
        save_splits['training'] = split_dataset['train']
        save_splits['validation'] = split_dataset2['train']
        save_splits['test'] = split_dataset2['test']

        for split_name, dataset in save_splits.items():
            output_file = self.dataset_root / f"{split_name}.jsonl"
            with output_file.open("w", encoding="utf-8") as f:
                for example in dataset:
                    context = example["context"].strip()
                    if context != "":
                        # Randomize context and instruction order.
                        context_first = np.random.randint(0, 2) == 0
                        if context_first:
                            instruction = example["instruction"].strip()
                            assert instruction != ""
                            _input = f"{context}\n\n{instruction}"
                            _output = example["response"]
                        else:
                            instruction = example["instruction"].strip()
                            assert instruction != ""
                            _input = f"{instruction}\n\n{context}"
                            _output = example["response"]
                    else:
                        _input = example["instruction"]
                        _output = example["response"]

                    f.write(json.dumps({"input": _input, "output": _output, "category": example["category"]}) + "\n")

            logging.info(f"{split_name} split saved to {output_file}")

        if self.delete_raw:
            for p in self.dataset_root.iterdir():
                if p.is_dir():
                    shutil.rmtree(p)
                elif '.jsonl' not in str(p.name):
                    p.unlink()
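A minimal usage sketch for the module above; the batch sizes are illustrative and a GPU is not required for the data preparation step itself:

from nemo.llm.gpt.data import DollyDataModule

data = DollyDataModule(seq_length=2048, micro_batch_size=4, global_batch_size=8)

# First call downloads databricks/databricks-dolly-15k into the "dolly" dataset root,
# writes training/validation/test jsonl splits, and (with delete_raw=True) removes the raw files.
data.prepare_data()

# Later calls reuse the cached training.jsonl unless force_redownload=True.
train_loader = data.train_dataloader()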
105 changes: 105 additions & 0 deletions nemo/llm/gpt/data/fine_tuning.py
@@ -0,0 +1,105 @@
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.llm.gpt.data.core import create_sft_dataset

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec


class FineTuningDataModule(pl.LightningDataModule):
    """Base class for fine-tuning an LLM.

    This class provides a foundation for building custom data modules for fine-tuning NeMo NLP models. It inherits from
    `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch creation
    for training, validation, and testing.

    Args:
        dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data.
        seq_length (int, optional): The maximum sequence length for the input and output text. Defaults to 2048.
        tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. Defaults to None.
            If not provided, a Megatron GPT2 BPE tokenizer will be used.
        micro_batch_size (int, optional): The micro batch size for training. Defaults to 4.
        global_batch_size (int, optional): The global batch size for training. Defaults to 8.
        rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. Defaults to None.
        seed (int, optional): The random seed for data shuffling. Defaults to 1234.
        memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. Defaults to 1.
        num_workers (int, optional): The number of worker processes for data loading. Defaults to 8.
        pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. Defaults to True.
        persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False.
    """

    def __init__(
        self,
        dataset_root: Union[str, Path],
        seq_length: int = 2048,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.seq_length = seq_length
        self.seed = seed
        self.dataset_root = Path(dataset_root)

        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
        self.memmap_workers = memmap_workers
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
        )

    def train_dataloader(self) -> DataLoader:
        return self._create_dataloader(self._create_dataset(str(self.train_path)))

    def val_dataloader(self) -> DataLoader:
        return self._create_dataloader(self._create_dataset(str(self.validation_path)))

    def test_dataloader(self) -> DataLoader:
        return self._create_dataloader(self._create_dataset(str(self.test_path), tokens_to_generate=32, is_test=True,))

    @lru_cache
    def _create_dataset(self, path, **kwargs):
        return create_sft_dataset(
            path, tokenizer=self.tokenizer, seq_length=self.seq_length, memmap_workers=self.memmap_workers, **kwargs
        )

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            **kwargs
        )

    @property
    def train_path(self) -> Path:
        return self.dataset_root / "training.jsonl"

    @property
    def validation_path(self) -> Path:
        return self.dataset_root / "validation.jsonl"

    @property
    def test_path(self) -> Path:
        return self.dataset_root / "test.jsonl"
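To reuse the base class for another instruction-tuning corpus, a subclass mainly needs to point dataset_root at a directory containing training.jsonl, validation.jsonl, and test.jsonl. A hedged sketch; MyCorpusDataModule and the "my_corpus" dataset name are hypothetical:

from nemo.llm.gpt.data.core import get_dataset_root
from nemo.llm.gpt.data.fine_tuning import FineTuningDataModule


class MyCorpusDataModule(FineTuningDataModule):
    """Fine-tunes on pre-split jsonl files under the hypothetical my_corpus dataset root."""

    def __init__(self, seq_length: int = 2048, micro_batch_size: int = 4, global_batch_size: int = 8):
        super().__init__(
            dataset_root=get_dataset_root("my_corpus"),  # hypothetical dataset name
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
        )

    def prepare_data(self) -> None:
        # Produce training.jsonl, validation.jsonl, and test.jsonl under self.dataset_root here,
        # following the same pattern as DollyDataModule above.
        ...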