From 1a18929d9e82fb8f636ff98941d75725c18448d1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 14 Oct 2025 22:52:54 -0400 Subject: [PATCH 01/45] Dataset interface --- fast_llm/config.py | 21 +++- fast_llm/data/config.py | 11 +- fast_llm/data/data/abstract.py | 3 +- fast_llm/data/data/gpt/config.py | 12 +- fast_llm/data/data/gpt/data.py | 44 ++------ fast_llm/data/dataset/abstract.py | 12 +- fast_llm/data/dataset/blended.py | 41 +++---- fast_llm/data/dataset/config.py | 55 +++++---- fast_llm/data/dataset/gpt/config.py | 104 +++++------------- fast_llm/data/dataset/gpt/fim.py | 22 ++-- fast_llm/data/dataset/gpt/indexed.py | 60 ---------- fast_llm/data/dataset/gpt/memmap.py | 71 ++++++------ fast_llm/data/dataset/gpt/random.py | 13 ++- fast_llm/data/dataset/indexed.py | 75 ++++++++++--- fast_llm/data/dataset/monitor.py | 14 +-- fast_llm/data/dataset/{gpt => }/sampled.py | 54 +++++---- .../data/preparator/gpt_memmap/prepare.py | 50 +++++---- fast_llm/data/sample/__init__.py | 0 fast_llm/data/sample/abstract.py | 10 ++ fast_llm/data/sample/gpt.py | 25 +++++ fast_llm/engine/config_utils/data_type.py | 14 ++- fast_llm/models/gpt/huggingface.py | 2 +- fast_llm/models/gpt/model.py | 2 +- tests/data/common.py | 66 ++++++----- tests/data/test_blending.py | 7 +- tests/data/test_concatenate.py | 5 +- tests/data/test_fim.py | 6 +- tests/data/test_prepare_gpt_memmap.py | 42 ++++--- tests/data/test_sampling.py | 28 +++-- tests/data/test_slice.py | 5 +- tests/models/test_match_megatron.py | 17 +-- tests/utils/dataset.py | 8 +- 32 files changed, 448 insertions(+), 451 deletions(-) delete mode 100644 fast_llm/data/dataset/gpt/indexed.py rename fast_llm/data/dataset/{gpt => }/sampled.py (93%) create mode 100644 fast_llm/data/sample/__init__.py create mode 100644 fast_llm/data/sample/abstract.py create mode 100644 fast_llm/data/sample/gpt.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 9644df9c1..658ad5666 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -492,6 +492,10 @@ def _validate_element(cls, value, type_, name: str): value = cls._validate_dict(value, type_, name) elif origin is type: value = cls._validate_type(value, type_, name) + elif issubclass(origin, Config): + # TODO: Validate arguments for config generics. + cls._validate_element_type(value, type_.__origin__, strict=False) + value.validate(_is_validating=True) else: raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): @@ -806,6 +810,8 @@ def _from_dict_nested(cls, value, type_, strict: bool): value = cls._from_dict_array(value, type_, strict) elif issubclass(origin, dict): value = cls._from_dict_dict(value, type_, strict) + elif issubclass(origin, Config): + value = cls._from_dict_config(value, type_, strict) elif origin is type: pass else: @@ -813,10 +819,15 @@ def _from_dict_nested(cls, value, type_, strict: bool): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type: {type_}.") elif issubclass(type_, Config): - if value is MISSING: - value = {} - if isinstance(value, dict): - value = type_._from_dict(value, strict) + value = cls._from_dict_config(value, type_, strict) + return value + + @classmethod + def _from_dict_config(cls, value, type_, strict: bool): + if value is MISSING: + value = {} + if isinstance(value, dict): + value = type_._from_dict(value, strict) return value @classmethod @@ -938,6 +949,7 @@ def __init_subclass__(cls): We need to postpone validation until the class has been processed by the dataclass wrapper. 
""" Assert.eq(cls.__name__, cls.__qualname__) + super().__init_subclass__() for base_class in cls.__mro__: if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( @@ -1006,6 +1018,7 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set `config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. + super().__init_subclass__() try: config_class = None for base in types.get_original_bases(cls): diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 4c041945d..633367c80 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,9 +1,13 @@ import enum import pathlib +import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.tokenizer import Tokenizer + class MultiprocessingContext(str, enum.Enum): # Fast but risk of segfaults due to interactions with triton @@ -29,7 +33,7 @@ class TokenizerConfig(Config): hint=FieldHint.deprecated, valid=check_field(Assert.eq, TokenizerFromFile), ) - path: pathlib.Path | None = Field( + path: pathlib.Path = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, @@ -39,3 +43,8 @@ class TokenizerConfig(Config): desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", hint=FieldHint.core, ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.tokenizer import Tokenizer + + return Tokenizer(self) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e24d39985..c67dc0321 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,6 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig @@ -47,5 +48,5 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[Batch]: pass diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index efee46959..5083c5121 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,9 +1,11 @@ import logging from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class -from fast_llm.data.config import MultiprocessingContext, TokenizerConfig +from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig +from fast_llm.data.dataset.config import SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -19,12 +21,8 @@ class GPTDataConfig(DataConfig): _abstract = False - tokenizer: TokenizerConfig = Field( - desc="Configuration for the tokenizer (for FIM).", - hint=FieldHint.feature, - ) # TODO: Review field. Move closer to phase definition in training config? 
- datasets: dict[str, GPTSampledDatasetConfig] = Field( + datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..2a18afd50 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,11 +1,9 @@ -import dataclasses import logging import pathlib import typing import warnings from functools import partial -import numpy as np import torch import torch.utils.data @@ -14,43 +12,32 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.sample.gpt import GPTBatch, GPTSample from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTBatch: - token_ids: torch.Tensor - loss_masking_spans: list[torch.Tensor] | None = None - sequence_lengths: list[torch.Tensor] | None = None - chosen_spans: list[torch.Tensor] | None = None - rejected_spans: list[torch.Tensor] | None = None - - def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None stacked_chosen_spans = None stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: - stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + stacked_spans = [sample.loss_masking_spans for sample in batch] if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] - stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] + stacked_chosen_spans = [sample.chosen_span for sample in batch] + stacked_rejected_spans = [sample.rejected_span for sample in batch] if not sampling_parameters.cross_document_attention: - sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + sequence_lengths = [sample.sequence_lengths for sample in batch] return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), + token_ids=torch.stack([sample.token_ids for sample in batch]), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, @@ -67,7 +54,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] - _tokenizer: Tokenizer | None _is_setup: bool = False def __init__( @@ -108,7 +94,6 @@ def setup( ) log_main_rank(f"Preparing dataset. 
This may take several minutes.") - self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) if self._cache_directory is None: # TODO: Avoid this @@ -116,11 +101,6 @@ def setup( self._datasets = {} for dataset_name, sampling_parameters in self._sampling_parameters.items(): - if self._tokenizer is not None: - # NOTE: Some models like Qwen2-1.5B-Instruct - # have vocab_size bigger in model config than in tokenizer - # TODO: Still, is it too constraining? - Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size) if sampling_parameters.num_samples > 0: sampling = GPTSamplingData( config=self._config.sampling, @@ -128,7 +108,6 @@ def setup( cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, - tokenizer=self._tokenizer, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -136,21 +115,16 @@ def setup( safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True - @property - def tokenizer(self) -> Tokenizer: - assert self._is_setup - return self._tokenizer - def get_iterator( self, - batch_config: BatchConfig, + batch_config: GPTBatchConfig, dataset_name: str, *, consumed_samples: int, num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[GPTBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index b470c0159..d57135ede 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,11 +1,13 @@ import abc import typing +from fast_llm.data.sample.abstract import Sample + if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData -class Dataset(abc.ABC): +class Dataset[SampleType: Sample](abc.ABC): """ A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ @@ -18,14 +20,14 @@ def name(self) -> str: """ -class SampledDataset(Dataset): +class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) 
""" @abc.abstractmethod - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: pass @abc.abstractmethod @@ -33,8 +35,8 @@ def __len__(self) -> int: pass -class SamplableDataset(Dataset): +class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod - def sample(self, config: "SamplingData") -> SampledDataset: + def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 24b0fa76f..264eb373d 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,16 +1,16 @@ import logging -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities logger = logging.getLogger(__name__) -class BlendedDataset(SampledDataset): +class BlendedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -21,7 +21,7 @@ class BlendedDataset(SampledDataset): def __init__( self, name: str, - datasets: list[SampledDataset], + datasets: list[SampledDataset[SampleType]], weights: list[float], sampling_config: SamplingData, ): @@ -29,51 +29,52 @@ def __init__( assert len(datasets) > 0 Assert.eq(len(datasets), len(weights)) self._datasets = datasets - self._weights = np.array(normalize_probabilities(weights)) + self._weights = torch.from_numpy(normalize_probabilities(weights, return_array=True)) self._num_samples = sampling_config.parameters.num_samples def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Blending is typically done in one of the following iterative way (ex. in Megatron datasets): ```python dataset_index=np.zeros(num_samples) sample_index=np.zeros(num_samples) sampled=np.zeros(len(weights)) - for idx in range(num_samples): - error = weights * (idx + 1) - sampled + for index in range(num_samples): + error = weights * (index + 1) - sampled dataset_index_ = np.argmax(error) - dataset_index[idx] = dataset_index_ - sample_index[idx] = sampled[dataset_index_] + dataset_index[index] = dataset_index_ + sample_index[index] = sampled[dataset_index_] sampled[dataset_index_] +=1 ``` I.e. it iteratively picks samples to minimize the error `weights * sum(sampled) - sampled`. This implementation computes values on the fly instead of pre-computing them all. """ # We find the number of samples taken from each dataset prior to this point. - sampled = self._get_sampled(idx) + sampled = self._get_sampled(index) # Then get the present sample. - dataset_index = self._get_next_dataset(idx, sampled) - return self._datasets[dataset_index][sampled[dataset_index]] + dataset_index = self._get_next_dataset(index, sampled) + return self._datasets[dataset_index][sampled[dataset_index].item()] - def _get_sampled(self, num_samples: int): + def _get_sampled(self, num_samples: int) -> torch.Tensor: # First we determine a lower bound. 
# This is indeed a lower bound because a lower value for one dataset would involve more sampling below, # and it would be from that same dataset because it would have the highest error, - sampled = np.floor(self._weights * num_samples).astype(int) + + sampled = (self._weights * num_samples).to(torch.int64) # Then we sample until we reach the target number of samples. # This may not match the actual sampling order, but the final value of `sampled` is correct. - for idx in range(sampled.sum(), num_samples): - dataset_index = self._get_next_dataset(idx, sampled) + for index in range(sampled.sum().item(), num_samples): + dataset_index = self._get_next_dataset(index, sampled) sampled[dataset_index] += 1 return sampled - def _get_next_dataset(self, idx, sampled): + def _get_next_dataset(self, index: int, sampled: torch.Tensor) -> int: # The next sample is the one with the highest error. - return (self._weights * (idx + 1) - sampled).argmax() + return (self._weights * (index + 1) - sampled).argmax().item() @property - def name(self): + def name(self) -> str: return self._name diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 0c1b0cd09..7a8d3567d 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -7,6 +7,7 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -64,37 +65,38 @@ def get_next_rank(self) -> int: @config_class() -class DatasetConfig(Config): +class DatasetConfig[SampleType: Sample](Config): _abstract: typing.ClassVar[bool] = True -@config_class() -class SampledDatasetConfig(DatasetConfig): +@config_class(registry=True) +class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + # TODO: ====== `SamplingData` contains more than needed (ex. `num_samples`) raise NotImplementedError() @config_class() -class SamplableDatasetConfig(SampledDatasetConfig): - def build(self) -> SamplableDataset: +class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): + def build(self) -> SamplableDataset[SampleType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: return self.build().sample(sampling) @config_class() -class IndexedDatasetConfig(SamplableDatasetConfig): - def _build(self) -> "IndexedDataset": +class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): + def build(self) -> "IndexedDataset[SampleType]": raise NotImplementedError() -@config_class() -class ConcatenatedDatasetConfig(SamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "concatenated"}) +class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Concatenate multiple indexed datasets as if they were one. TODO: Make a post-sampling version? 
(staged training) @@ -106,7 +108,7 @@ class ConcatenatedDatasetConfig(SamplableDatasetConfig): desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[IndexedDatasetConfig] = Field( + datasets: list[IndexedDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to concatenate.", hint=FieldHint.core, @@ -122,8 +124,8 @@ def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: return cls(self.name, [dataset.build() for dataset in self.datasets]) -@config_class() -class DatasetSliceConfig(SamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "slice"}) +class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Use a fraction of an indexed dataset, specified by the range (begin, end). Typically used to subsample a dataset, or to reserve part of the dataset for validation and/or testing. @@ -133,7 +135,7 @@ class DatasetSliceConfig(SamplableDatasetConfig): """ _abstract = False - dataset: IndexedDatasetConfig = Field( + dataset: IndexedDatasetConfig[SampleType] = Field( default=None, desc="The dataset to split.", hint=FieldHint.core, @@ -152,12 +154,9 @@ class DatasetSliceConfig(SamplableDatasetConfig): def build(self) -> "DatasetSlice": from fast_llm.data.dataset.indexed import DatasetSlice - return self._build(DatasetSlice) - - def _build[T: DatasetSlice](self, cls: type[T]) -> T: dataset = self.dataset.build() size = len(dataset) - return cls( + return DatasetSlice[SampleType]( f"{dataset.name}_{self.begin}_{self.end}", dataset, round(self.begin * size), @@ -165,8 +164,8 @@ def _build[T: DatasetSlice](self, cls: type[T]) -> T: ) -@config_class() -class SampledDatasetUpdateConfig(SampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "sampled"}) +class SampledDatasetUpdateConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): """ Wrap a dataset to explicitly sample from it and optionally update its configuration parameters. Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. @@ -177,24 +176,24 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( desc="The dataset to sample from.", hint=FieldHint.core, ) - def build_and_sample(self, data: SamplingData) -> SampledDataset: + def build_and_sample(self, data: SamplingData) -> SampledDataset[SampleType]: return self.dataset.build_and_sample(data.update_config(self.sampling)) -@config_class() -class BlendedDatasetConfig(SampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "blended"}) +class BlendedDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): _abstract = False name: str = Field( default="blended", desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[SampledDatasetConfig] = Field( + datasets: list[SampledDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to blend.", hint=FieldHint.core, @@ -214,7 +213,7 @@ def _validate(self) -> None: def build_and_sample( self, sampling: SamplingData, - ) -> SampledDataset: + ) -> SampledDataset[SampleType]: from fast_llm.data.dataset.blended import BlendedDataset # Build and sample the datasets. @@ -235,7 +234,7 @@ def build_and_sample( for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) ] # Blend the datasets. 
- return BlendedDataset( + return BlendedDataset[SampleType]( self.name, sampled_datasets, self.weights, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 656cd7d24..36412b6ce 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -6,27 +6,23 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import ( - BlendedDatasetConfig, - ConcatenatedDatasetConfig, - DatasetSliceConfig, IndexedDatasetConfig, SamplableDatasetConfig, SampledDatasetConfig, - SampledDatasetUpdateConfig, SamplingConfig, SamplingData, SamplingParameters, ) +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset - from fast_llm.data.tokenizer import Tokenizer class ShufflingType(str, enum.Enum): @@ -86,27 +82,10 @@ class GPTSamplingData(SamplingData): config: GPTSamplingConfig parameters: GPTSamplingParameters - tokenizer: "Tokenizer" -@config_class(registry=True) -class GPTSampledDatasetConfig(SampledDatasetConfig): - pass - - -@config_class() -class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig): - pass - - -@config_class() -class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig, IndexedDatasetConfig): - def build(self) -> "GPTIndexedDataset": - raise NotImplementedError() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "random"}) +class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -120,8 +99,8 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -145,43 +124,8 @@ def build(self) -> "GPTMemmapDataset": return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) -@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) -class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() - - def build(self) -> "GPTConcatenatedDataset": - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset - - return self._build(GPTConcatenatedDataset) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) -class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - dataset: GPTIndexedDatasetConfig = FieldUpdate() - - def build(self) -> 
"GPTDatasetSlice": - from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice - - return self._build(GPTDatasetSlice) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) -class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): - _abstract = False - sampling: GPTSamplingConfig = FieldUpdate() - dataset: GPTSampledDatasetConfig = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) -class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "file"}) +class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -189,18 +133,18 @@ class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: config = self._load_config() return config.build_and_sample(sampling) - def build(self) -> SamplableDataset: + def build(self) -> SamplableDataset[SampleType]: config = self._load_config() - assert isinstance(config, GPTSamplableDatasetConfig) + assert isinstance(config, SamplableDatasetConfig) return config.build() - def _load_config(self): + def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." - return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. @@ -224,6 +168,10 @@ class FimConfig(Config): Configuration for FIM. """ + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + hint=FieldHint.feature, + ) rate: float = Field( # TODO: Use meaningful default now that fim is a wrapper? default=0.0, @@ -286,15 +234,15 @@ class FimConfig(Config): ) -@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): +@config_class(dynamic_type={SampledDatasetConfig: "fim"}) +class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig): """ Configuration for FIM. 
""" _abstract: typing.ClassVar[bool] = False - dataset: GPTSampledDatasetConfig = Field( + dataset: SampledDatasetConfig = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -302,15 +250,15 @@ class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): def build_and_sample( self, - sampling: GPTSamplingData, + sampling: SamplingData, ) -> SampledDataset: from fast_llm.data.dataset.gpt.fim import GPTFimDataset return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) -@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) +class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. """ @@ -323,8 +271,8 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: assert sampling.distributed.config.world_size > 1 if sampling.distributed.config.rank == 0: time.sleep(self.sleep) - return GPTRandomDatasetConfig().build_and_sample(sampling) + return GPTRandomDatasetConfig[SampleType]().build_and_sample(sampling) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..175a0e549 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -1,12 +1,13 @@ import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset(SampledDataset): +class GPTFimDataset[SampleType: GPTSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -15,7 +16,7 @@ class GPTFimDataset(SampledDataset): def __init__( self, config: FimConfig, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], sampling: GPTSamplingData, ): if sampling.parameters.use_loss_masking_spans: @@ -26,7 +27,7 @@ def __init__( self._dataset = dataset self._seed = sampling.config.seed - self._tokenizer = sampling.tokenizer + self._tokenizer = self._config.tokenizer.get_tokenizer() if self._tokenizer is None: raise ValueError("Fim requires a tokenizer") self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = ( @@ -40,11 +41,15 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx: int) -> np.ndarray: - fim_token_ids = self._fim( - self._dataset[idx].token_ids, np.random.RandomState(seed=(self._seed + idx) % MAX_SEED) + def __getitem__(self, index: int) -> SampleType: + # TODO: Use torch methods to avoid back and forth. 
+ return GPTSample( + torch.from_numpy( + self._fim( + self._dataset[index].token_ids.numpy(), np.random.RandomState(seed=(self._seed + index) % MAX_SEED) + ) + ) ) - return GPTSample(fim_token_ids) @property def name(self) -> str: @@ -55,6 +60,7 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # TODO: permute segments in sample_list, before concatenating. sample_len = sample.shape[0] eod = self._tokenizer.eod + # TODO: Available through `tokens.lengths` segment_breaks = np.argwhere(sample == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py deleted file mode 100644 index 896229772..000000000 --- a/fast_llm/data/dataset/gpt/indexed.py +++ /dev/null @@ -1,60 +0,0 @@ -import abc -import typing - -import numpy as np - -from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - -if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - -class GPTIndexedDataset(IndexedDataset): - @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - @abc.abstractmethod - def get_document_size(self, index: int) -> int: - """ - The size of a document in the dataset. - """ - - def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - return GPTSampledIndexedDataset(self, sampling) - - -class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): - """ - A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. - """ - - _dataset: GPTIndexedDataset - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] - - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) - - -class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( - ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset -): - _datasets: list[GPTIndexedDataset] - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get_document_size(self, index: int) -> int: - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f4..c78805380 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -3,15 +3,17 @@ import typing import numpy as np +import torch -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER +from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -class GPTMemmapDataset(GPTIndexedDataset): +class GPTMemmapDataset[SampleType: GPTSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -142,37 +144,34 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get( - self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, - ) -> GPTSample: + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + ) -> SampleType: + if end is None: + end = self.get_document_size(index) token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + count=end - begin, + offset=self._pointers[index] + begin * np.dtype(self._dtype).itemsize, ) sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] + if parameters is not None and parameters.use_loss_masking_spans: + assert self._spans is not None + sample_spans = self._spans[index] # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] + sample_spans = sample_spans[(sample_spans[:, 0] < begin + len(token_ids)) & (sample_spans[:, 1] >= begin)] # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], begin) - begin # offset + sample_spans[:, 1] = np.minimum(sample_spans[:, 1], begin + len(token_ids) - 1) - begin + sample_spans = torch.from_numpy(sample_spans) chosen_span = None rejected_span = None - if use_preference_loss_spans: + if parameters is not None and parameters.use_preference_loss_spans: if not self._has_preference_spans: raise ValueError("No preference spans found in memmap dataset.") elif self._has_preference_spans and self._chosen_spans is None: @@ -180,28 +179,30 @@ def get( elif self._has_preference_spans and self._rejected_spans is None: 
raise ValueError("Failed to read rejected spans from memmap dataset.") else: - chosen_span = self._chosen_spans[idx] + chosen_span = self._chosen_spans[index] # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + chosen_span = chosen_span[(chosen_span[0] < begin + len(token_ids)) & (chosen_span[1] >= begin)][0] # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + chosen_span[0] = np.maximum(chosen_span[0], begin) - begin # offset + chosen_span[1] = np.minimum(chosen_span[1], begin + len(token_ids) - 1) - begin + chosen_span = torch.from_numpy(chosen_span) - rejected_span = self._rejected_spans[idx] + rejected_span = self._rejected_spans[index] # filter spans that are outside the range of the selected tokens in the document rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) + (rejected_span[0] < begin + len(token_ids)) & (rejected_span[1] >= begin) ][0] # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + rejected_span[0] = np.maximum(rejected_span[0], begin) - begin # offset + rejected_span[1] = np.minimum(rejected_span[1], begin + len(token_ids) - 1) - begin + rejected_span = torch.from_numpy(rejected_span) return GPTSample( - token_ids=token_ids, + token_ids=torch.from_numpy(token_ids), loss_masking_spans=sample_spans, chosen_span=chosen_span, rejected_span=rejected_span, @@ -218,13 +219,13 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self) -> torch.Tensor: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return torch.from_numpy(self._document_sizes) def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() @@ -258,7 +259,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." 
# Write document to binary file - bin_stream.write(document.token_ids.tobytes(order="C")) + bin_stream.write(document.token_ids.numpy().tobytes(order="C")) # Update metadata doc_length = len(document.token_ids) @@ -271,7 +272,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans.append(document.chosen_span) if document.rejected_span is not None: rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * dtype.itemsize num_documents += 1 # Finalize metadata arrays @@ -297,7 +298,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Flag to indicate whether preference loss-masking spans are present idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) # Data type - idx_stream.write(struct.pack(" str: return self._name -class GPTRandomSampledDataset(SampledDataset): +class GPTRandomSampledDataset[SampleType: GPTSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed @@ -32,10 +33,12 @@ def __init__(self, sampling: GPTSamplingData, name: str): def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx) -> np.ndarray: + def __getitem__(self, index: int) -> SampleType: return GPTSample( - np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + ) ) ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 09ed52779..c6eac9e28 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -1,20 +1,37 @@ import abc -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.data.dataset.config import SamplingData, SamplingParameters +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, padded_cumsum -class IndexedDataset(SamplableDataset): +class IndexedDataset[SampleType: Sample](SamplableDataset[SampleType]): """ A dataset containing a list of samples. TODO: Move sampling responsibility here? """ @abc.abstractmethod - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + """ + The size of a document in the dataset. + """ + + @abc.abstractmethod + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: pass @abc.abstractmethod @@ -23,13 +40,18 @@ def __len__(self) -> int: Number of samples in the dataset. 
""" + def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": + from fast_llm.data.dataset.sampled import SampledIndexedDataset + + return SampledIndexedDataset(self, sampling) -class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset): + +class DatasetSlice[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - dataset: IndexedDataset, + dataset: IndexedDataset[SampleType], begin: int | None = None, end: int | None = None, ): @@ -46,15 +68,22 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def get( - self, document: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] + + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: """ Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length + optionally subsampled to a specific offset (starting point) and maximum length (end = min(offset + length, sample_length). """ - return self._dataset.get(document + self._begin, offset, length, use_loss_masking_spans) + return self._dataset.get_document(index + self._begin, begin, end, parameters) def __len__(self) -> int: return self._end - self._begin @@ -64,24 +93,36 @@ def name(self) -> str: return self._name -class ConcatenatedDataset[IndexedDatasetType: IndexedDataset](IndexedDataset): +class ConcatenatedDataset[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - datasets: list[IndexedDataset], + datasets: list[IndexedDataset[SampleType]], ): self._name = name self._datasets = datasets sizes = [len(dataset) for dataset in self._datasets] - self._dataset_splits = padded_cumsum(sizes) + self._dataset_splits = torch.from_numpy(padded_cumsum(sizes)) def __len__(self) -> int: return self._dataset_splits[-1].item() - def get(self, index: int, *args, **kwargs): - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get(index - self._dataset_splits[dataset].item(), *args, **kwargs) + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. 
+ return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) + + def get_document_size(self, index: int) -> int: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document( + index - self._dataset_splits[dataset].item(), begin, end, parameters + ) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080fe..01f3195e4 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -1,8 +1,8 @@ import logging import time -import typing from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.sample.abstract import Sample try: from fast_llm.csrc.data import build_blending_indices # noqa @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class DatasetMonitor(SampledDataset): +class DatasetMonitor[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -24,7 +24,7 @@ class DatasetMonitor(SampledDataset): def __init__( self, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], data_sample_warn_time_ms: float, ): self._dataset = dataset @@ -33,19 +33,19 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: start_time = time.perf_counter() try: - sample = self._dataset[idx] + sample = self._dataset[index] sample_time = (time.perf_counter() - start_time) * 1000 if sample_time > self._data_sample_warn_time_ms: logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" + f"Sample {index} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" ) return sample except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + logger.error(f"Failed to get sample {index} from dataset {self._dataset.name}") raise @property diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/sampled.py similarity index 93% rename from fast_llm/data/dataset/gpt/sampled.py rename to fast_llm/data/dataset/sampled.py index 95006f18e..238e99bca 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -1,4 +1,3 @@ -import dataclasses import logging import math import pathlib @@ -11,7 +10,9 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.abstract import Sample +from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -26,15 +27,6 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTSample: - token_ids: np.ndarray - 
loss_masking_spans: np.ndarray | None = None - chosen_span: np.ndarray | None = None - rejected_span: np.ndarray | None = None - sequence_lengths: np.ndarray | None = None - - class MemmapArray: """ An array with lazy loading in memmap mode. @@ -75,14 +67,15 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class GPTSampledIndexedDataset(SampledDataset): +class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A sampled GPT dataset. """ def __init__( self, - indexed_dataset: GPTIndexedDataset, + indexed_dataset: IndexedDataset[SampleType], + # TODO: ====== Remove gpt-specific stuff ====== sampling: GPTSamplingData, ): assert isinstance(sampling, GPTSamplingData) @@ -133,7 +126,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes = self._indexed_dataset.get_document_sizes().to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() @@ -375,7 +368,7 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - def __len__(self) -> int: return self._parameters.num_samples - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. @@ -391,12 +384,11 @@ def __getitem__(self, index: int) -> typing.Any: self._document_shuffling[index - self._unshuffled_documents].item() ] - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, + sample = self._indexed_dataset.get_document( + document_index.item(), + begin=0, + end=self._document_sizes[document_index].item(), + parameters=self._parameters, ) chosen_span_end = sample.chosen_span[1] + 1 @@ -412,7 +404,7 @@ def __getitem__(self, index: int) -> typing.Any: sample.token_ids = padding if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) + sample.sequence_lengths = torch.tensor(sequence_lengths) return sample @@ -474,11 +466,11 @@ def __getitem__(self, index: int) -> typing.Any: # Determine which part of the document belong to the sample, and add it to the list. 
token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( + sample = self._indexed_dataset.get_document( document_index, - offset=token_start_index_in_document, - length=token_end_index_in_document - token_start_index_in_document, - use_loss_masking_spans=self._parameters.use_loss_masking_spans, + begin=token_start_index_in_document, + end=token_end_index_in_document, + parameters=self._parameters, ) token_ids.append(sample.token_ids) if self._parameters.use_loss_masking_spans: @@ -496,19 +488,23 @@ def __getitem__(self, index: int) -> typing.Any: token_count += document_size sequence_lengths = ( - np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + torch.tensor([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) if not self._parameters.cross_document_attention else None ) token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( - (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) + torch.from_numpy(np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) if self._parameters.use_loss_masking_spans else None ) Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=torch.from_numpy(token_ids), + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + ) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 33c40bf8f..a8ff187ae 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -14,17 +14,17 @@ import transformers import yaml -from fast_llm.data.dataset.gpt.config import ( - GPTBlendedDatasetConfig, - GPTDatasetSliceConfig, - GPTIndexedDatasetConfig, - GPTMemmapDatasetConfig, - GPTSampledDatasetConfig, +from fast_llm.data.dataset.config import ( + BlendedDatasetConfig, + DatasetSliceConfig, + IndexedDatasetConfig, + SampledDatasetConfig, ) +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.sample.gpt import GPTSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,6 +37,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None + _sample_type: typing.ClassVar[type[GPTSample]] = GPTSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ @@ -144,8 +145,8 @@ def _document_generator(): if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( - 
np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), ) elif ( "chosen_token_spans" in shard_dataset.column_names @@ -155,13 +156,13 @@ def _document_generator(): ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), + token_ids=torch.tensor(item["input_ids"], dtype=self._data_type.torch), + chosen_span=torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), + rejected_span=torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + yield GPTSample(torch.tensor(item["input_ids"], dtype=self._data_type.torch)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -376,7 +377,9 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa torch.distributed.destroy_process_group() @classmethod - def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: + def _save_dataset_config( + cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path + ) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( dataset_config.to_dict(), @@ -384,10 +387,12 @@ def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_pa ) @classmethod - def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) -> GPTIndexedDatasetConfig: + def _blend_dataset_configs( + cls, dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]] + ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] - return GPTSampledDatasetConfig.from_dict( + return SampledDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": dataset_configs, @@ -397,8 +402,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path - ) -> dict[str, GPTSampledDatasetConfig]: + cls, + dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]], + splits: dict[str, int | float], + output_path: pathlib.Path, + ) -> dict[str, SampledDatasetConfig[_sample_type]]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) @@ -427,13 +435,13 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). 
dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() + sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) if end_index > begin_index: datasets_in_split.append( - GPTDatasetSliceConfig.from_dict( + DatasetSliceConfig[cls._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], @@ -455,7 +463,7 @@ def _split_and_blend_dataset_configs( elif len(datasets_in_split) == 1: dataset_splits[split_name] = datasets_in_split[0] else: - dataset_splits[split_name] = GPTBlendedDatasetConfig.from_dict( + dataset_splits[split_name] = BlendedDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": datasets_in_split, diff --git a/fast_llm/data/sample/__init__.py b/fast_llm/data/sample/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py new file mode 100644 index 000000000..0c640b9b3 --- /dev/null +++ b/fast_llm/data/sample/abstract.py @@ -0,0 +1,10 @@ +import abc + + +class Sample(abc.ABC): + pass + + +class Batch(abc.ABC): + # TODO: Relate to `BatchConfig`? + pass diff --git a/fast_llm/data/sample/gpt.py b/fast_llm/data/sample/gpt.py new file mode 100644 index 000000000..4bf740462 --- /dev/null +++ b/fast_llm/data/sample/gpt.py @@ -0,0 +1,25 @@ +import dataclasses +import typing + +from fast_llm.data.sample.abstract import Batch, Sample + +if typing.TYPE_CHECKING: + import torch + + +@dataclasses.dataclass +class GPTSample(Sample): + token_ids: "torch.Tensor" + loss_masking_spans: "torch.Tensor | None" = None + chosen_span: "torch.Tensor | None" = None + rejected_span: "torch.Tensor | None" = None + sequence_lengths: "torch.Tensor | None" = None + + +@dataclasses.dataclass +class GPTBatch(Batch): + token_ids: "torch.Tensor" + loss_masking_spans: "list[torch.Tensor] | None" = None + sequence_lengths: "list[torch.Tensor] | None" = None + chosen_spans: "list[torch.Tensor] | None" = None + rejected_spans: "list[torch.Tensor] | None" = None diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index f4a2cfd6c..add121c50 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -23,8 +23,10 @@ class DataType(enum.StrEnum): int32 = "int32" int16 = "int16" int8 = "int8" - uint8 = "uint8" + uint64 = "uint64" + uint32 = "uint32" uint16 = "uint16" + uint8 = "uint8" @classmethod def _missing_(cls, dtype: str) -> "DataType": @@ -105,6 +107,9 @@ def _set_torch_dtype_map() -> None: DataType.int32: torch.int32, DataType.int16: torch.int16, DataType.int8: torch.int8, + DataType.uint64: torch.uint64, + DataType.uint32: torch.uint32, + DataType.uint16: torch.uint16, DataType.uint8: torch.uint8, } _TORCH_DTYPE_MAP_INV = {y: x for x, y in _TORCH_DTYPE_MAP.items()} @@ -127,8 +132,10 @@ def _set_numpy_dtype_map() -> None: DataType.int32: np.int32, DataType.int16: np.int16, DataType.int8: np.int8, - DataType.uint8: np.uint8, + DataType.uint64: np.uint64, + DataType.uint32: np.uint32, DataType.uint16: np.uint16, + DataType.uint8: np.uint8, } _NUMPY_DTYPE_MAP_INV = {y: x for x, y in _NUMPY_DTYPE_MAP.items()} @@ -151,6 +158,9 @@ def _set_triton_dtype_map() -> None: 
DataType.int32: tl.int32, DataType.int16: tl.int16, DataType.int8: tl.int8, + DataType.uint64: tl.uint64, + DataType.uint32: tl.uint32, + DataType.uint16: tl.uint16, DataType.uint8: tl.uint8, } diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 9215e6dc7..a76c3712e 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -5,7 +5,7 @@ import torch import transformers.modeling_outputs -from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.data.sample.gpt import GPTBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ecb..bd3c91a38 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,7 +3,7 @@ import torch -from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.data.sample.gpt import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType diff --git a/tests/data/common.py b/tests/data/common.py index d8cc6fff2..3ade0e9bf 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,17 +8,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import ( - GPTIndexedDatasetConfig, - GPTSampledDatasetConfig, - GPTSamplingConfig, - GPTSamplingData, - GPTSamplingParameters, - ShufflingType, -) -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig, SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig, GPTSamplingData, GPTSamplingParameters, ShufflingType +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.dataset.sampled import SampledIndexedDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig @@ -34,7 +28,6 @@ def get_sampling_data( phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, - tokenizer: Tokenizer | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, @@ -56,13 +49,12 @@ def get_sampling_data( cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, - tokenizer=tokenizer, ) -def get_dataset_config[T: GPTSampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: - dataset_config = GPTSampledDatasetConfig.from_dict(config) - Assert.custom(isinstance, dataset_config, cls) +def get_dataset_config[T: SampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: + dataset_config = SampledDatasetConfig.from_dict(config) + Assert.custom(isinstance, dataset_config, getattr(cls, "__origin__", cls)) return typing.cast(cls, dataset_config) @@ -115,7 +107,7 @@ def get_test_data_and_compare_samples( def compare_indexed_dataset( - 
dataset: GPTIndexedDataset, + dataset: IndexedDataset, length: int, num_tokens: int, expected_samples: dict[int, list[int]], @@ -125,26 +117,30 @@ def compare_indexed_dataset( sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get_document(i).token_ids) for i in range(min(len(dataset), 100))], + sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) + Assert.all_equal(dataset.get_document(i).token_ids, np.array(expected_sample, dtype=np.uint16)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): Assert.all_equal( - dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, + dataset.get_document( + i, + parameters=GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True + ), + ).loss_masking_spans, np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2), ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples) + Assert.all_equal(torch.stack([sampled[i].token_ids for i in range(len(expected_samples))]), expected_samples) -def validate_indexed_dataset_sampling( - sampled: GPTSampledIndexedDataset, expected_samples: list[list[int]] | None = None -): +def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): """ Compare `GPTSampledIndexedDataset` sampling against a more basic approach """ @@ -165,7 +161,7 @@ def validate_indexed_dataset_sampling( ) seen_tokens = 0 for document_index in document_sampling: - document = sampled._indexed_dataset.get(document_index).token_ids + document = sampled._indexed_dataset.get_document(document_index).token_ids all_tokens[seen_tokens : seen_tokens + len(document)] = document[: num_tokens - seen_tokens] seen_tokens += len(document) @@ -176,7 +172,7 @@ def validate_indexed_dataset_sampling( all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] for index in range(sampled._parameters.num_samples) ] - token_ids = [sampled[i].token_ids for i in range(len(sampled))] + token_ids = torch.stack([sampled[i].token_ids for i in range(len(sampled))]) Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: @@ -184,8 +180,8 @@ def validate_indexed_dataset_sampling( return token_ids -@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"}) -class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"}) +class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False num_documents: int | None = Field( default=None, @@ -199,15 +195,15 @@ class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): ) path: pathlib.Path = Field(default=".") - def build(self) -> "GPTIndexedDataset": - return MockGPTMemmapDataset(self) + def build(self) -> "IndexedDataset": + return MockMemmapDataset(self) @property def num_tokens(self) -> int: return self.num_documents * self.num_tokens_per_document -class MockGPTMemmapDataset(GPTIndexedDataset): +class MockMemmapDataset[SampleType: 
Sample](IndexedDataset[SampleType]): def __init__(self, config: MockGPTMemmapDatasetConfig): self._config = config @@ -218,11 +214,13 @@ def name(self) -> str: def __len__(self) -> int: return self._config.num_documents - def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + def get_document_sizes(self) -> torch.Tensor: + return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: raise NotImplementedError() diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index e64b47020..678bffa21 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig +from fast_llm.data.dataset.config import BlendedDatasetConfig +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -122,7 +123,7 @@ def test_gpt_blended(): ], "weights": [0.75, 0.25], }, - GPTBlendedDatasetConfig, + BlendedDatasetConfig[GPTSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) @@ -161,7 +162,7 @@ def test_gpt_blended_mixed(): ], "weights": [0.6, 0.4], }, - GPTBlendedDatasetConfig, + BlendedDatasetConfig[GPTSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 2c025cbaf..bb4905cb6 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,4 +1,5 @@ -from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig +from fast_llm.data.dataset.config import ConcatenatedDatasetConfig +from fast_llm.data.sample.gpt import GPTSample from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -27,7 +28,7 @@ def test_gpt_concatenate(): get_test_dataset() dataset = get_dataset_config( {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, - GPTConcatenatedDatasetConfig, + ConcatenatedDatasetConfig[GPTSample], ).build() compare_indexed_dataset( dataset, diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index c9212d6e3..438c5e7e3 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -1,6 +1,4 @@ -from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig -from fast_llm.data.tokenizer import Tokenizer from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -29,13 +27,13 @@ def test_gpt_fim(): sampling_config = get_sampling_data( 8, sequence_length=5, - tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})), vocab_size=49157, ) sampled = get_dataset_config( { "type": "fim", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -55,6 +53,7 @@ def test_gpt_fim_data(): "training": { "type": "fim", "dataset": 
{"type": "memmap", "path": DATASET_PREFIX}, + "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -62,7 +61,6 @@ def test_gpt_fim_data(): "suffix_token": "z", } }, - "tokenizer": {"path": TOKENIZER_PATH}, }, 8, sequence_length=5, diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 17ba5de01..388726bfb 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -4,12 +4,14 @@ import numpy as np import pytest +import torch -from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig +from fast_llm.data.dataset.config import IndexedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -28,22 +30,25 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): - documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] + documents = [ + GPTSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) + for _ in range(100) + ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): assert np.array_equal( - dataset.get(i).token_ids, document.token_ids, equal_nan=True - ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." + dataset.get_document(i).token_ids, document.token_ids, equal_nan=True + ), f"Mismatch for document {i}: {document} != {dataset.get_document(i)}." @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_preference_dataset(dtype): def generate_valid_span(max_seq_length): span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) - return np.sort(span) + return torch.from_numpy(np.sort(span)) vocab_size = 1000 max_seq_length = 8192 @@ -51,7 +56,7 @@ def generate_valid_span(max_seq_length): documents = [ GPTSample( - token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), + token_ids=torch.from_numpy(np.random.randint(vocab_size, size=max_seq_length).astype(dtype)), chosen_span=generate_valid_span(max_seq_length=max_seq_length), rejected_span=generate_valid_span(max_seq_length=max_seq_length), ) @@ -62,18 +67,23 @@ def generate_valid_span(max_seq_length): GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): - dataset_item = dataset.get(i, use_preference_loss_spans=True) + dataset_item = dataset.get_document( + i, + parameters=GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True + ), + ) assert np.array_equal( dataset_item.token_ids, document.token_ids, equal_nan=True - ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." 
+ ), f"Token ids mismatch for document {i}: {document} != {dataset.get_document(i)}." assert np.array_equal( dataset_item.chosen_span, document.chosen_span, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}." + ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get_document(i).chosen_span}." assert np.array_equal( dataset_item.rejected_span, document.rejected_span, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}." + ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get_document(i).rejected_span}." def test_load_metadata_from_hub(): @@ -126,7 +136,7 @@ def test_absent_metadata_local(): def test_split_dataset(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], {"training": 3, "validation": 1}, @@ -154,8 +164,8 @@ def test_split_dataset(): def test_split_datasets_0(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 1, "validation": 1}, @@ -173,8 +183,8 @@ def test_split_datasets_0(): def test_split_datasets_1(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 6a2be3dcc..d7b3021fe 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -1,11 +1,10 @@ -import typing - import numpy as np import pytest +import torch -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters, ShufflingType +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -62,24 +61,23 @@ def test_gpt_sampled_data(): ) -class SimpleGPTIndexedDataset(GPTIndexedDataset): +class SimpleGPTIndexedDataset[SampleType: GPTSample](IndexedDataset[SampleType]): # TODO: worth adding to the main codebase? 
def __init__(self, samples): self._samples = samples - def get(self, index: int, offset=0, length=None, use_loss_masking_spans: bool = False) -> typing.Any: - if length is None: - length = len(self._samples[index]) - assert not use_loss_masking_spans - return GPTSample( - token_ids=np.array(self._samples[index][offset : offset + length], dtype=np.int64), loss_masking_spans=None - ) + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + ) -> SampleType: + if end is None: + end = len(self._samples[index]) + return GPTSample(token_ids=torch.tensor(self._samples[index][begin:end], dtype=torch.int64)) def __len__(self) -> int: return len(self._samples) - def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + def get_document_sizes(self) -> torch.Tensor: + return torch.tensor([self.get_document_size(index) for index in range(len(self))], dtype=torch.int64) def get_document_size(self, index: int) -> int: return len(self._samples[index]) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 1fc8df1eb..e83387a24 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,4 +1,5 @@ -from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig +from fast_llm.data.dataset.config import DatasetSliceConfig +from fast_llm.data.sample.gpt import GPTSample from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -34,7 +35,7 @@ def test_gpt_slice(): # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, - GPTDatasetSliceConfig, + DatasetSliceConfig[GPTSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 6aa541b8c..f057c037f 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,12 +3,15 @@ import numpy as np import pytest +import torch from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData +from fast_llm.data.dataset.config import SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample, logger +from fast_llm.data.dataset.sampled import logger +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset @@ -79,7 +82,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare_results_for_all_models(distributed_testing_config) -@config_class(dynamic_type={GPTSampledDatasetConfig: "megatron"}) +@config_class(dynamic_type={SampledDatasetConfig: "megatron"}) class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): _abstract: typing.ClassVar[bool] = False path: str = Field( @@ -142,14 +145,14 @@ def __getitem__(self, idx: int) -> typing.Any: doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = 
self._sample_idx[shuffled_idx + 1] sample_list = [ - self._indexed_dataset.get( + self._indexed_dataset.get_document( self._doc_idx[doc].item(), - offset=(doc == doc_f) * offset_f, - length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, + begin=(doc == doc_f) * offset_f, + end=offset_l + 1 if doc == doc_l else None, ) for doc in range(doc_f, doc_l + 1) ] - token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) + token_ids = torch.cat([sample.token_ids for sample in sample_list]) Assert.eq(len(token_ids), self._sequence_length + 1) return GPTSample(token_ids=token_ids) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 680faa931..b43923f4d 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -2,10 +2,11 @@ import random import numpy as np +import torch import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.sample.gpt import GPTSample from tests.utils.global_variables import ( DATASET_PREFIX, MODEL_DATASET_PREFIX, @@ -46,14 +47,15 @@ def get_test_dataset( tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts + GPTSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)) + for document in texts ] if max_spans > 0: lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) for sample, span in zip(samples, spans): span = np.unique(span) - sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) + sample.loss_masking_spans = torch.from_numpy(span[: len(span) // 2 * 2].reshape(-1, 2)) GPTMemmapDataset.write_dataset(prefix, samples) yaml.safe_dump( From fd63846895b64c5f8755c5289e9a74d0f1364535 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 14 Oct 2025 22:57:17 -0400 Subject: [PATCH 02/45] misc --- fast_llm/engine/evaluation/config.py | 4 ++++ fast_llm/engine/evaluation/lm_eval/evaluator.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 4f035e174..f8dfd4825 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -2,6 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert @@ -63,6 +64,9 @@ def get_evaluator( class LmEvalEvaluatorConfig(EvaluatorConfig): _abstract: typing.ClassVar[bool] = False + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + ) cli_args: list[str] = Field( default_factory=lambda: [], desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 14aed65c4..5bfb544ed 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -60,7 +60,7 @@ def setup( self._flm_wrapper = FastLLMLmEvalWrapper( model=self._hf_model, - tokenizer=self._data.tokenizer.tokenizer, + 
tokenizer=self._config.tokenizer.get_tokenizer(), truncation=self._config.truncation, logits_cache=self._config.logits_cache, add_bos_token=self._config.add_bos_token, From 2486cafba2e3dfc31b219a253b39ece0ff2b8d77 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 15 Oct 2025 16:21:42 -0400 Subject: [PATCH 03/45] fix --- fast_llm/data/dataset/abstract.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index d57135ede..33942708b 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -19,6 +19,14 @@ def name(self) -> str: A name for the dataset to facilitate identification and debugging. """ + def __getstate__(self): + state = super().__getstate__() + # Pickling sometimes fails with bound `SampleType`. + # This is not needed at runtime, so we just drop it. + if "__orig_class__" in state: + del state["__orig_class__"] + return state + class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ From 92e93e8f03d9db34a9621c9ecb832b16efbb2cf3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 15 Oct 2025 23:51:33 -0400 Subject: [PATCH 04/45] Language model sample --- fast_llm/data/data/gpt/config.py | 8 +- fast_llm/data/data/gpt/data.py | 33 +--- fast_llm/data/dataset/config.py | 29 ++++ fast_llm/data/dataset/gpt/config.py | 71 +++------ fast_llm/data/dataset/gpt/fim.py | 16 +- fast_llm/data/dataset/gpt/memmap.py | 118 +++++++-------- fast_llm/data/dataset/gpt/random.py | 29 ++-- fast_llm/data/dataset/sampled.py | 134 +++-------------- .../data/preparator/gpt_memmap/prepare.py | 25 ++- fast_llm/data/sample/abstract.py | 31 +++- fast_llm/data/sample/gpt.py | 25 --- fast_llm/data/sample/language_model.py | 99 ++++++++++++ fast_llm/data/sample/range.py | 46 ++++++ fast_llm/data/sample/token.py | 70 +++++++++ fast_llm/functional/dpo.py | 71 +++------ fast_llm/models/gpt/huggingface.py | 7 +- fast_llm/models/gpt/model.py | 142 ++++++------------ tests/data/common.py | 26 ++-- tests/data/test_blending.py | 4 +- tests/data/test_concatenate.py | 4 +- tests/data/test_memmap.py | 4 +- tests/data/test_prepare_gpt_memmap.py | 69 ++++----- tests/data/test_sampling.py | 10 +- tests/data/test_slice.py | 4 +- tests/models/test_match_megatron.py | 5 +- tests/test_config.py | 4 +- tests/utils/dataset.py | 33 +++- 27 files changed, 576 insertions(+), 541 deletions(-) delete mode 100644 fast_llm/data/sample/gpt.py create mode 100644 fast_llm/data/sample/language_model.py create mode 100644 fast_llm/data/sample/range.py create mode 100644 fast_llm/data/sample/token.py diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 5083c5121..c7f16c936 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,11 +1,10 @@ import logging -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SampledDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -22,12 +21,11 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. 
Move closer to phase definition in training config? - datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field( + datasets: dict[str, SampledDatasetConfig[LanguageModelSample]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 2a18afd50..de47ef761 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -2,7 +2,6 @@ import pathlib import typing import warnings -from functools import partial import torch import torch.utils.data @@ -14,7 +13,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator -from fast_llm.data.sample.gpt import GPTBatch, GPTSample +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -24,32 +23,9 @@ logger = logging.getLogger(__name__) -def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - stacked_spans = None - sequence_lengths = None - stacked_chosen_spans = None - stacked_rejected_spans = None - if sampling_parameters.use_loss_masking_spans: - stacked_spans = [sample.loss_masking_spans for sample in batch] - if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [sample.chosen_span for sample in batch] - stacked_rejected_spans = [sample.rejected_span for sample in batch] - if not sampling_parameters.cross_document_attention: - sequence_lengths = [sample.sequence_lengths for sample in batch] - return GPTBatch( - token_ids=torch.stack([sample.token_ids for sample in batch]), - loss_masking_spans=stacked_spans, - sequence_lengths=sequence_lengths, - chosen_spans=stacked_chosen_spans, - rejected_spans=stacked_rejected_spans, - ) - - class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. - Currently hard-coded to a GPT dataset. - TODO: Separate generic and GPT classes. 
""" _datasets: dict[str, SampledDataset] @@ -124,7 +100,7 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[GPTBatch]: + ) -> typing.Iterator[LanguageModelBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, @@ -149,10 +125,7 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=partial( - gpt_data_collate_fn, - sampling_parameters=sampling_parameters, - ), + collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7a8d3567d..e93e5865a 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -1,4 +1,5 @@ import dataclasses +import enum import functools import itertools import math @@ -15,6 +16,17 @@ from fast_llm.engine.distributed.distributed import Distributed +class ShufflingType(str, enum.Enum): + # Shuffle all epochs together. Not extendable. + full = "full" + # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. + epoch = "epoch" + # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. + skip_first_epoch = "skip_first_epoch" + # Disable shuffling entirely. + disabled = "disabled" + + @config_class() class SamplingConfig(Config): """ @@ -26,6 +38,18 @@ class SamplingConfig(Config): desc="Seed for random sampling.", hint=FieldHint.feature, ) + gpu: bool = Field( + default=True, + desc="Enable fast sampling on GPU." + " Note that random sampling works differently on GPU," + " so the sample won't match the CPU equivalent.", + hint=FieldHint.feature, + ) + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, + desc="Shuffling strategy.", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) @@ -34,7 +58,12 @@ class SamplingParameters: Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ + sequence_length: int num_samples: int + truncate_documents: bool = True + # How many extra tokens to add to the sequence length. + # This is used to provide labels even for the last tokens in the sequence. + extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 36412b6ce..15f54ec80 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,5 +1,4 @@ import dataclasses -import enum import pathlib import time import typing @@ -13,64 +12,27 @@ IndexedDatasetConfig, SamplableDatasetConfig, SampledDatasetConfig, - SamplingConfig, SamplingData, SamplingParameters, ) -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset -class ShufflingType(str, enum.Enum): - # Shuffle all epochs together. Not extendable. - full = "full" - # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. - epoch = "epoch" - # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. 
- skip_first_epoch = "skip_first_epoch" - # Disable shuffling entirely. - disabled = "disabled" - - -@config_class() -class GPTSamplingConfig(SamplingConfig): - """ - A dataset-dependent configuration for sampling. - """ - - gpu: bool = Field( - default=True, - desc="Enable fast sampling on GPU." - " Note that random sampling works differently on GPU," - " so the sample won't match the CPU equivalent.", - hint=FieldHint.feature, - ) - shuffle: ShufflingType = Field( - default=ShufflingType.epoch, - desc="Shuffling strategy.", - hint=FieldHint.feature, - ) - - @dataclasses.dataclass(kw_only=True) class GPTSamplingParameters(SamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - sequence_length: int vocab_size: int use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False - cross_document_attention: bool = True - truncate_documents: bool = True - # How many extra tokens to add to the sequence length. - # This is used to provide labels even for the last tokens in the sequence. - extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -80,12 +42,11 @@ class GPTSamplingData(SamplingData): usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. """ - config: GPTSamplingConfig parameters: GPTSamplingParameters @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): +class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -93,14 +54,14 @@ class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[Sampl hint=FieldHint.core, ) - def build(self) -> "GPTRandomDataset": + def build(self) -> "GPTRandomDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomDataset - return GPTRandomDataset(self.name) + return GPTRandomDataset[SampleType](self.name) @config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleType]): +class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -118,14 +79,16 @@ class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleT hint=FieldHint.optional, ) - def build(self) -> "GPTMemmapDataset": + def build(self) -> "GPTMemmapDataset[SampleType]": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset[SampleType]( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens + ) @config_class(dynamic_type={SampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): +class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -235,14 +198,14 @@ class FimConfig(Config): @config_class(dynamic_type={SampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig): +class GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType], FimConfig): """ 
Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -250,15 +213,15 @@ class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[Sam def build_and_sample( self, - sampling: SamplingData, - ) -> SampledDataset: + sampling: GPTSamplingData, + ) -> "GPTFimDataset[SampleType]": from fast_llm.data.dataset.gpt.fim import GPTFimDataset - return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) + return GPTFimDataset[SampleType](self, self.dataset.build_and_sample(sampling), sampling) @config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]): +class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. """ diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 175a0e549..e7ca863b4 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -3,11 +3,12 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset[SampleType: GPTSample](SampledDataset[SampleType]): +class GPTFimDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -43,10 +44,13 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> SampleType: # TODO: Use torch methods to avoid back and forth. - return GPTSample( - torch.from_numpy( - self._fim( - self._dataset[index].token_ids.numpy(), np.random.RandomState(seed=(self._seed + index) % MAX_SEED) + return LanguageModelSample( + TokenSample( + torch.from_numpy( + self._fim( + self._dataset[index].tokens.tokens.numpy(), + np.random.RandomState(seed=(self._seed + index) % MAX_SEED), + ) ) ) ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c78805380..e6b650621 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -8,12 +8,14 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -class GPTMemmapDataset[SampleType: GPTSample](IndexedDataset[SampleType]): +class GPTMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. 
a pair of numpy file containing @@ -47,7 +49,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None if self._version >= 3: self._has_preference_spans = struct.unpack(" SampleType: if end is None: end = self.get_document_size(index) - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=end - begin, - offset=self._pointers[index] + begin * np.dtype(self._dtype).itemsize, + sample_size = self._document_sizes[index].item() + assert 0 <= begin <= end <= sample_size, (0, begin, end, sample_size) + token_ids = ( + torch.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=end - begin, + offset=self._pointers[index].item() + begin * self._dtype.itemsize, + ) + if end > begin + else torch.empty(0, dtype=self._dtype) ) - sample_spans = None if parameters is not None and parameters.use_loss_masking_spans: assert self._spans is not None - sample_spans = self._spans[index] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[(sample_spans[:, 0] < begin + len(token_ids)) & (sample_spans[:, 1] >= begin)] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], begin) - begin # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], begin + len(token_ids) - 1) - begin - sample_spans = torch.from_numpy(sample_spans) - - chosen_span = None - rejected_span = None + # TODO: ====== Store in range format (begin, end) ====== + sample_spans = RangeSample( + [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size + ).crop(begin, end) + else: + sample_spans = None if parameters is not None and parameters.use_preference_loss_spans: if not self._has_preference_spans: @@ -178,34 +179,25 @@ def get_document( raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[index] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < begin + len(token_ids)) & (chosen_span[1] >= begin)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], begin) - begin # offset - chosen_span[1] = np.minimum(chosen_span[1], begin + len(token_ids) - 1) - begin - chosen_span = torch.from_numpy(chosen_span) - - rejected_span = self._rejected_spans[index] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < begin + len(token_ids)) & (rejected_span[1] >= begin) - ][0] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], begin) - begin # offset - rejected_span[1] = np.minimum(rejected_span[1], begin + len(token_ids) - 1) - begin - rejected_span = torch.from_numpy(rejected_span) + elif begin != 0 or end != sample_size: + raise ValueError("Samples with preference spans should not be cropped.") + # TODO: ====== Store in range format ====== + chosen_spans = RangeSample( + [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], + sample_size, + ) + rejected_spans = RangeSample( + [(self._rejected_spans[index][0].item(), self._rejected_spans[index][1].item() + 1)], + sample_size, + ) + else: + chosen_spans = rejected_spans = None - return GPTSample( - 
token_ids=torch.from_numpy(token_ids), + return LanguageModelSample( + tokens=TokenSample(token_ids), loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, ) @property @@ -231,7 +223,11 @@ def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() @classmethod - def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): + def write_dataset( + cls, + prefix: pathlib.Path | str, + documents: typing.Iterable[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]], + ) -> None: # Initialize metadata dtype = None num_documents = 0 @@ -249,29 +245,29 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write the binary data file (.bin) lazily with prefix.with_suffix(".bin").open("wb") as bin_stream: - for document in documents: + for token_ids, loss_masking_spans, chosen_span, rejected_span in documents: # Infer dtype from the first document if dtype is None: - dtype = document.token_ids.dtype + dtype = token_ids.dtype assert dtype is not None, "Document dtype could not be inferred from the data." # Ensure all documents have the same dtype - assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." + assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." # Write document to binary file - bin_stream.write(document.token_ids.numpy().tobytes(order="C")) + bin_stream.write(token_ids.numpy().tobytes(order="C")) # Update metadata - doc_length = len(document.token_ids) + doc_length = len(token_ids) lengths.append(doc_length) pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) + if loss_masking_spans is not None: + num_spans.append(len(loss_masking_spans)) + spans.append(loss_masking_spans) + if chosen_span is not None: + chosen_spans.append(chosen_span) + if rejected_span is not None: + rejected_spans.append(rejected_span) offset += doc_length * dtype.itemsize num_documents += 1 diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index c12e4adcc..62c4311ab 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -3,10 +3,11 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample -class GPTRandomDataset(SamplableDataset): +class GPTRandomDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): """ A dummy dataset that always returns the same random sample, for debugging purposes. 
""" @@ -22,22 +23,28 @@ def name(self) -> str: return self._name -class GPTRandomSampledDataset[SampleType: GPTSample](SampledDataset[SampleType]): +class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed - self._sequence_length = sampling.parameters.sequence_length - self._vocab_size = sampling.parameters.vocab_size - self._num_samples = sampling.parameters.num_samples + self._parameters = sampling.parameters + # TODO: Support? + assert not self._parameters.use_loss_masking_spans + assert not self._parameters.use_preference_loss_spans def __len__(self) -> int: - return self._num_samples + return self._parameters.num_samples def __getitem__(self, index: int) -> SampleType: - return GPTSample( - torch.from_numpy( - np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + return LanguageModelSample( + TokenSample( + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, + self._parameters.vocab_size, + size=(self._parameters.sequence_length + self._parameters.extra_tokens,), + dtype=np.int64, + ) ) ) ) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 238e99bca..441dfafae 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -9,10 +9,9 @@ import yaml from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType +from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample -from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -69,16 +68,14 @@ def _lazy_load(self): class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ - A sampled GPT dataset. + A sampled dataset. """ def __init__( self, indexed_dataset: IndexedDataset[SampleType], - # TODO: ====== Remove gpt-specific stuff ====== - sampling: GPTSamplingData, + sampling: SamplingData, ): - assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset self._config = sampling.config self._parameters = sampling.parameters @@ -108,22 +105,15 @@ def __init__( self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") - # keep document sizes and len filtered docs for preference loss masking - if self._parameters.use_preference_loss_spans: - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( - base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") - ) - # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() # No barrier yet to allow running in parallel. - # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. + # There needs to be one before calling `__getitem__`, normally handled through `Data`. def _sample(self) -> None: """ - Create a `GPTSampledDataset` with the requested parameters. 
+ Create a `SampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. document_sizes = self._indexed_dataset.get_document_sizes().to(self._device) @@ -152,10 +142,7 @@ def _sample(self) -> None: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. - if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) - elif self._truncate_documents: + if self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -259,24 +246,6 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - # To get a sample on the fly we need to know where it begins, # and this is a non-trivial information because the documents have variable length. # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. @@ -372,42 +341,10 @@ def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). + The returned sample is ready to be concatenated, then fed to a `Model`. 
""" self._lazy_load() - if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get_document( - document_index.item(), - begin=0, - end=self._document_sizes[document_index].item(), - parameters=self._parameters, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = torch.tensor(sequence_lengths) - - return sample - # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -432,8 +369,7 @@ def __getitem__(self, index: int) -> SampleType: token_count = token_start_array[token_start_cumsum_index] - token_ids = [] - loss_masking_spans = [] + documents: list[SampleType] = [] while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -453,8 +389,8 @@ def __getitem__(self, index: int) -> SampleType: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: - # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + # TODO: ====== Handle padding ====== + documents.append(PaddingSample(padding_size)) Assert.eq(token_count + padding_size, token_end) break else: @@ -466,45 +402,21 @@ def __getitem__(self, index: int) -> SampleType: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get_document( - document_index, - begin=token_start_index_in_document, - end=token_end_index_in_document, - parameters=self._parameters, + documents.append( + self._indexed_dataset.get_document( + document_index, + begin=token_start_index_in_document, + end=token_end_index_in_document, + parameters=self._parameters, + ) ) - token_ids.append(sample.token_ids) - if self._parameters.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: - span = np.clip( - loss_masking_span + token_count - token_start, - 0, - self._parameters.sequence_length + self._parameters.extra_tokens, - ) - if span[1] >= span[0]: - loss_masking_spans.append(span) # Go to the next document. 
document_sampling_index += 1 token_count += document_size - sequence_lengths = ( - torch.tensor([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - if not self._parameters.cross_document_attention - else None - ) - token_ids = np.concatenate(token_ids, dtype=np.int64) - loss_masking_spans = ( - torch.from_numpy(np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) - if self._parameters.use_loss_masking_spans - else None - ) - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - - return GPTSample( - token_ids=torch.from_numpy(token_ids), - loss_masking_spans=loss_masking_spans, - sequence_lengths=sequence_lengths, - ) + # TODO: ====== Better way to get the class method? ====== + return documents[0].from_documents(documents) @property def name(self) -> str: @@ -517,13 +429,5 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if self._parameters.use_preference_loss_spans: - data["unshuffled_tokens"] = 0 # not used, ignore - elif "unshuffled_tokens" not in data: - # Backward compatibility - # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a8ff187ae..73dba6ccc 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -24,7 +24,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,7 +37,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None - _sample_type: typing.ClassVar[type[GPTSample]] = GPTSample + _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ @@ -142,11 +142,14 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): + # TODO: Yield `LanguageModelSample` if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( + yield ( torch.tensor(item["input_ids"], dtype=self._data_type.torch), torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), + None, + None, ) elif ( "chosen_token_spans" in shard_dataset.column_names @@ -155,14 +158,20 @@ def _document_generator(): and self._config.dataset.rejected_text is not None ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard 
{shard_idx}", unit="docs"): - yield GPTSample( - token_ids=torch.tensor(item["input_ids"], dtype=self._data_type.torch), - chosen_span=torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), - rejected_span=torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), + yield ( + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + None, + torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), + torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(torch.tensor(item["input_ids"], dtype=self._data_type.torch)) + yield ( + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + None, + None, + None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 0c640b9b3..c1bebe166 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -1,10 +1,37 @@ import abc +import typing + +import torch class Sample(abc.ABC): - pass + @classmethod + @abc.abstractmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + pass + + @abc.abstractmethod + def crop(self, begin: int, end: int) -> typing.Self: + pass + + @abc.abstractmethod + def __len__(self) -> int: + pass class Batch(abc.ABC): # TODO: Relate to `BatchConfig`? - pass + @classmethod + @abc.abstractmethod + def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: + pass + + @abc.abstractmethod + def to_samples(self) -> list[Sample]: + pass + + def crop(self, begin: int, end: int) -> typing.Self: + return self.from_samples(sample.crop(begin, end) for sample in self.to_samples()) + + def to_device_(self, device: "torch.device | str"): + pass diff --git a/fast_llm/data/sample/gpt.py b/fast_llm/data/sample/gpt.py deleted file mode 100644 index 4bf740462..000000000 --- a/fast_llm/data/sample/gpt.py +++ /dev/null @@ -1,25 +0,0 @@ -import dataclasses -import typing - -from fast_llm.data.sample.abstract import Batch, Sample - -if typing.TYPE_CHECKING: - import torch - - -@dataclasses.dataclass -class GPTSample(Sample): - token_ids: "torch.Tensor" - loss_masking_spans: "torch.Tensor | None" = None - chosen_span: "torch.Tensor | None" = None - rejected_span: "torch.Tensor | None" = None - sequence_lengths: "torch.Tensor | None" = None - - -@dataclasses.dataclass -class GPTBatch(Batch): - token_ids: "torch.Tensor" - loss_masking_spans: "list[torch.Tensor] | None" = None - sequence_lengths: "list[torch.Tensor] | None" = None - chosen_spans: "list[torch.Tensor] | None" = None - rejected_spans: "list[torch.Tensor] | None" = None diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py new file mode 100644 index 000000000..57a9a30f6 --- /dev/null +++ b/fast_llm/data/sample/language_model.py @@ -0,0 +1,99 @@ +import typing + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.data.sample.range import RangeBatch, RangeSample +from fast_llm.data.sample.token import TokenBatch, TokenSample + + +class LanguageModelSample(Sample): + def __init__( + self, + tokens: TokenSample, + loss_masking_spans: RangeSample | None = None, + chosen_spans: RangeSample | None = None, + rejected_spans: RangeSample | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = 
chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + TokenSample.from_documents(document.tokens for document in documents), + _merge_optional(RangeSample.from_documents, (document.loss_masking_spans for document in documents)), + _merge_optional(RangeSample.from_documents, (document.chosen_spans for document in documents)), + _merge_optional(RangeSample.from_documents, (document.rejected_spans for document in documents)), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def __len__(self) -> int: + return len(self.tokens) + + +class LanguageModelBatch(Batch): + def __init__( + self, + tokens: TokenBatch, + loss_masking_spans: RangeBatch | None = None, + chosen_spans: RangeBatch | None = None, + rejected_spans: RangeBatch | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: + return cls( + TokenBatch.from_samples(sample.tokens for sample in samples), + _merge_optional(RangeBatch.from_samples, (sample.loss_masking_spans for sample in samples)), + _merge_optional(RangeBatch.from_samples, (sample.chosen_spans for sample in samples)), + _merge_optional(RangeBatch.from_samples, (sample.rejected_spans for sample in samples)), + ) + + def to_samples(self) -> list[LanguageModelSample]: + return [ + LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) + for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( + self.tokens.to_samples(), + self.loss_masking_spans.to_samples(), + self.chosen_spans.to_samples(), + self.rejected_spans.to_samples(), + strict=True, + ) + ] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) + + +def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: + return None if any(arg is None for arg in args) else fn(args) + + +def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: + return None if sample_or_batch is None else sample_or_batch.crop(begin, end) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py new file mode 100644 index 000000000..ec71b98d6 --- /dev/null +++ b/fast_llm/data/sample/range.py @@ -0,0 +1,46 @@ +import typing + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.utils import get_unique + + +class RangeSample(Sample): + """ + A reusable component holding a set of ranges in a sample. 
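Ranges are half-open `(begin, end)` token index pairs within the sample; concatenating documents is meant to offset each document's ranges by the tokens preceding it, and cropping clips ranges to the new window, dropping any that fall outside. A small illustration of `crop` (made-up values):

    spans = RangeSample([(2, 5), (8, 12)], sample_size=16)
    spans.crop(4, 10).ranges   # [(0, 1), (4, 6)] within a cropped sample of size 6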
+ """ + + def __init__(self, ranges: list[tuple[int, int]], sample_size: int): + self.ranges = ranges + self.sample_size = sample_size + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + document: RangeSample + ranges = [] + sample_size = 0 + for document in documents: + for begin, end in document.ranges: + ranges.extend((begin + sample_size, end + sample_size)) + sample_size += document.sample_size + return cls(ranges, sample_size) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges) + return self.__class__([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) + + def __len__(self) -> int: + return self.sample_size + + +class RangeBatch(Batch): + def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int): + self.sample_size = sample_size + self.ranges = ranges + + @classmethod + def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: + return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples)) + + def to_samples(self) -> list[RangeSample]: + return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges] diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py new file mode 100644 index 000000000..db853341a --- /dev/null +++ b/fast_llm/data/sample/token.py @@ -0,0 +1,70 @@ +import typing + +import torch + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.utils import Assert + + +class TokenSample(Sample): + def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None): + self.tokens = tokens + # Length of each document in the sample. TODO: Use cumsums instead? + if lengths is None: + lengths = [len(tokens)] + else: + Assert.eq(sum(lengths), len(tokens)) + self.lengths = lengths + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + torch.cat([document.tokens for document in documents]), + sum((document.lengths for document in documents), []), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + if self.lengths == [len(self.tokens)]: + # Shortcut for the common case of a single document. 
+ lengths = [sample_size] + else: + begin_ = 0 + lengths = [] + for length in self.lengths: + end_ = begin + length + cropped_length = max(begin_ - begin, 0) - min(end_ - begin, end) + if cropped_length > 0: + lengths.append(cropped_length) + if end_ > end: + break + begin_ = end_ + return self.__class__(self.tokens[begin:end], lengths) + + def __len__(self) -> int: + return len(self.tokens) + + +class TokenBatch(Batch): + def __init__(self, tokens: torch.Tensor, lengths: list[list[int]]) -> None: + self.tokens = tokens + self.lengths = lengths + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: + return cls( + torch.stack([sample.tokens for sample in samples]), + [sample.lengths for sample in samples], + ) + + def to_samples(self) -> list[TokenSample]: + return [TokenSample(tokens, lengths) for tokens, lengths in zip(self.tokens, self.lengths, strict=True)] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens[:, begin:end], [sample.crop(begin, end).lengths for sample in self.to_samples()] + ) + + def to_device_(self, device: "torch.device | str"): + # Also standardize the dtype while we're here. + self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 3a70f308f..7ab0b9ff6 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,51 +1,25 @@ import torch -def _compute_logprobs_for_preference_spans( - logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor -): - assert torch.all(targets < logits.size(-1)), "Target out of vocab range" +def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - # gather log probabilities corresponding to the target tokens - selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - # apply chosen mask - chosen_logp = 0 - for idx, span in enumerate(chosen_spans): - chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - # apply rejected mask - rejected_logp = 0 - for idx, span in enumerate(rejected_spans): - rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - return chosen_logp, rejected_logp, selected_log_probs - - -def _compute_dpo_loss( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta: float, -): - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - diff_logratios = pi_logratios - ref_logratios - - losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) - return losses +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans + ) def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, reference_model_logits: torch.Tensor, - chosen_spans: torch.Tensor, - rejected_spans: torch.Tensor, + chosen_spans: list[list[tuple[int, int]]], + rejected_spans: 
list[list[tuple[int, int]]], beta: float, grad_output: float | None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -53,21 +27,18 @@ def compute_dpo_loss( logits_ = logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( - logits_, targets, chosen_spans, rejected_spans - ) + policy_log_probabilities = _get_target_log_probabilities(logits_, targets) + policy_log_ratios = _get_target_log_probability_for_spans( + policy_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( - reference_model_logits_, targets, chosen_spans, rejected_spans - ) + reference_log_probabilities = _get_target_log_probabilities(reference_model_logits_, targets) + reference_log_ratios = _get_target_log_probability_for_spans( + reference_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - losses = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - beta=beta, - ) + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? + losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: loss = None diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index a76c3712e..34e38469a 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -5,7 +5,8 @@ import torch import transformers.modeling_outputs -from fast_llm.data.sample.gpt import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.token import TokenBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM @@ -80,7 +81,9 @@ def inner_forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) batch = self.fast_llm_base_model.preprocess_batch( - GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration + LanguageModelBatch(TokenBatch(input_ids, lengths=sequence_lenghts)), + phase=PhaseType.inference, + iteration=iteration, ) ((input_, kwargs),) = batch diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bd3c91a38..3e50d1ed1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,7 +3,7 @@ import torch -from fast_llm.data.sample.gpt import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType @@ -40,7 +40,7 @@ def __init__( param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( - self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + self, batch_meta: GPTBatchConfig | LanguageModelBatch, phase: PhaseType ) -> 
list[tuple[TensorMeta, dict]]: # TODO Remove (Move batch splitting elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence @@ -51,7 +51,7 @@ def preprocess_meta( micro_sequence_length = batch_meta.micro_sequence_length truncate_documents = batch_meta.truncate_documents else: - micro_batch_size, sequence_length = batch_meta.shape + micro_batch_size, sequence_length = batch_meta.tokens.tokens.shape if phase != PhaseType.inference: sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length @@ -151,7 +151,7 @@ def preprocess_meta( def preprocess_batch( self, - batch: GPTBatch, + batch: LanguageModelBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, @@ -161,19 +161,10 @@ def preprocess_batch( # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup - if preprocessed_meta is None: - preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) - - _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size - sequence_first = common_kwargs[AttentionKwargs.sequence_first] - max_prediction_distance = self._config.head.max_prediction_distance + batch.to_device_(self._distributed.device) - batch.token_ids = batch.token_ids.to( - device=self._distributed.device, - dtype=torch.int64, - non_blocking=True, - ) + if preprocessed_meta is None: + preprocessed_meta = self.preprocess_meta(batch, phase) reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -191,103 +182,60 @@ def preprocess_batch( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] - token_ids = batch.token_ids - if sequence_first: - # Move the sequence dimension first to make sequence parallel ops more efficient. - token_ids = token_ids.transpose(0, 1).contiguous() - preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size - if sequence_first: - tokens = token_ids[sequence_k - sequence_q : sequence_k] - else: - # TODO: Avoid multiple contiguous calls? - tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() - if batch.sequence_lengths is not None: - kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths - if batch.chosen_spans is not None: - kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans - if batch.rejected_spans is not None: - kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans + tokens_end = kwargs_meta[AttentionKwargs.sequence_k_dim].size + tokens_begin = tokens_end - kwargs_meta[AttentionKwargs.sequence_q_dim].size + cropped_tokens = batch.tokens.crop(tokens_begin, tokens_end) # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. 
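With the batch types above, each micro-sequence is obtained by cropping the batch rather than slicing raw token tensors, and the label window is the same crop shifted right by one (and extended by the prediction distance). A minimal sketch with made-up shapes, assuming two samples of ten tokens and a single prediction head:

    import torch
    from fast_llm.data.sample.token import TokenBatch

    batch = TokenBatch(torch.arange(20).reshape(2, 10), [[10], [10]])
    inputs = batch.crop(0, 8).tokens   # positions 0..7, shape (2, 8)
    labels = batch.crop(1, 9).tokens   # positions 1..8, shape (2, 8): next-token targets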
pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] - kwargs = { + + kwargs: dict[str, typing.Any] = { **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, + # TODO: ====== Use only if wanted ====== + AttentionKwargs.sequence_lengths: cropped_tokens.lengths, + **reference_logits[i], } + if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - if sequence_first: - labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] - else: - # TODO: Avoid multiple contiguous calls? - labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() - # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss - # TODO: take ignore_index from config + labels_begin = tokens_begin + 1 + labels_end = tokens_end + self._config.head.max_prediction_distance + + labels = batch.tokens.crop(labels_begin, labels_end).tokens + if batch.loss_masking_spans is not None: - # avoid changing input tokens - labels = labels.clone() - for idx, spans in enumerate(batch.loss_masking_spans): - if not spans.numel(): - continue - valid_spans = spans[ - (spans[:, 0] <= sequence_k + max_prediction_distance - 1) - & (spans[:, 1] >= sequence_offset) - ] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) - valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) - for start, end in valid_spans: - if sequence_first: - loss_mask[start : end + 1, idx] = False - else: - loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) - kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + loss_masking_spans = batch.loss_masking_spans.crop(labels_begin, labels_end) + loss_mask = torch.ones_like(labels, dtype=torch.bool) + for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): + for begin, end in loss_masking_spans: + loss_mask[sample_index, begin:end] = False + if self._config.output_layer.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) + + kwargs[LanguageModelKwargs.labels] = ( + labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels + ).contiguous() if batch.chosen_spans is not None: - chosen_valid_spans = [] - for spans in batch.chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_valid_spans = [] - for spans in batch.rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= 
sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - + kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges + + if batch.rejected_spans is not None: + kwargs[LanguageModelKwargs.rejected_spans] = batch.rejected_spans.crop( + labels_begin, labels_end + ).ranges + + tokens = ( + cropped_tokens.tokens.transpose(0, 1) + if kwargs[AttentionKwargs.sequence_first] + else cropped_tokens.tokens + ).contiguous() self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) diff --git a/tests/data/common.py b/tests/data/common.py index 3ade0e9bf..5102bfbcd 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,8 +8,14 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig, SamplingParameters -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig, GPTSamplingData, GPTSamplingParameters, ShufflingType +from fast_llm.data.dataset.config import ( + IndexedDatasetConfig, + SampledDatasetConfig, + SamplingConfig, + SamplingParameters, + ShufflingType, +) +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset from fast_llm.data.sample.abstract import Sample @@ -35,7 +41,7 @@ def get_sampling_data( # Config with convenient defaults. 
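The loss-masking path above applies half-open `(begin, end)` ranges directly to the cropped label window. A self-contained sketch of the same pattern with made-up spans (`-100` matches the `ignore_index` convention of `torch.nn.CrossEntropyLoss`):

    import torch

    labels = torch.arange(16).reshape(2, 8)
    loss_mask = torch.ones_like(labels, dtype=torch.bool)
    for sample_index, sample_ranges in enumerate([[(0, 2)], [(3, 6)]]):
        for begin, end in sample_ranges:
            loss_mask[sample_index, begin:end] = False
    masked_labels = torch.where(loss_mask, labels, -100)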
distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( - config=GPTSamplingConfig( + config=SamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -88,7 +94,7 @@ def get_test_data_and_compare_samples( expected_samples = {PhaseType.training.value.lower(): expected_samples} assert "sampling" not in config - config["sampling"] = GPTSamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) + config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) data.setup(distributed, sampling_parameters, cache_directory) with NoAutoValidate(): @@ -117,21 +123,21 @@ def compare_indexed_dataset( sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get_document(i).token_ids) for i in range(min(len(dataset), 100))], + [len(dataset.get_document(i).tokens.tokens) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get_document(i).token_ids, np.array(expected_sample, dtype=np.uint16)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.uint16)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): - Assert.all_equal( + Assert.eq( dataset.get_document( i, parameters=GPTSamplingParameters( num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True ), - ).loss_masking_spans, - np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2), + ).loss_masking_spans.ranges, + loss_masking_spans[i], ) @@ -161,7 +167,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s ) seen_tokens = 0 for document_index in document_sampling: - document = sampled._indexed_dataset.get_document(document_index).token_ids + document = sampled._indexed_dataset.get_document(document_index).tokens.tokens all_tokens[seen_tokens : seen_tokens + len(document)] = document[: num_tokens - seen_tokens] seen_tokens += len(document) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 678bffa21..b64465d55 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,7 +4,7 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -123,7 +123,7 @@ def test_gpt_blended(): ], "weights": [0.75, 0.25], }, - BlendedDatasetConfig[GPTSample], + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index bb4905cb6..5335e01c0 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,5 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -28,7 +28,7 @@ def test_gpt_concatenate(): get_test_dataset() dataset = get_dataset_config( {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, - ConcatenatedDatasetConfig[GPTSample], + 
ConcatenatedDatasetConfig[LanguageModelSample], ).build() compare_indexed_dataset( dataset, diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 1286bddd7..d718f089c 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -27,8 +27,8 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [[0, 4], [6, 8]], - 13: [[1, 2]], + 10: [(0, 5), (6, 9)], + 13: [(1, 3)], 15: [], } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 388726bfb..90610381a 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -11,7 +11,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -31,59 +31,44 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): documents = [ - GPTSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) + (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) for _ in range(100) ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): - assert np.array_equal( - dataset.get_document(i).token_ids, document.token_ids, equal_nan=True - ), f"Mismatch for document {i}: {document} != {dataset.get_document(i)}." 
+ for i, (tokens, _, _, _) in enumerate(documents): + Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens) -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - def generate_valid_span(max_seq_length): - span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) - return torch.from_numpy(np.sort(span)) +def _generate_valid_span(max_seq_length): + return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() - vocab_size = 1000 - max_seq_length = 8192 - num_samples = 100 +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): documents = [ - GPTSample( - token_ids=torch.from_numpy(np.random.randint(vocab_size, size=max_seq_length).astype(dtype)), - chosen_span=generate_valid_span(max_seq_length=max_seq_length), - rejected_span=generate_valid_span(max_seq_length=max_seq_length), + ( + torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), + None, + _generate_valid_span(100), + _generate_valid_span(100), ) - for _ in range(num_samples) + for _ in range(50) ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): - dataset_item = dataset.get_document( - i, - parameters=GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True - ), - ) - assert np.array_equal( - dataset_item.token_ids, document.token_ids, equal_nan=True - ), f"Token ids mismatch for document {i}: {document} != {dataset.get_document(i)}." - - assert np.array_equal( - dataset_item.chosen_span, document.chosen_span, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get_document(i).chosen_span}." - - assert np.array_equal( - dataset_item.rejected_span, document.rejected_span, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get_document(i).rejected_span}." 
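Preference spans are stored with inclusive ends in the memmap file and surfaced as half-open ranges by `get_document`, so read-back now goes through the sample wrappers rather than raw tensors. A rough sketch of the access pattern used in the rewritten test below (`prefix` is whatever `write_dataset` produced):

    parameters = GPTSamplingParameters(
        num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True
    )
    document = GPTMemmapDataset(name="foo", prefix=prefix).get_document(0, parameters=parameters)
    document.tokens.tokens          # token ids as a torch.Tensor
    document.chosen_spans.ranges    # list of half-open (begin, end) tuples
    document.rejected_spans.ranges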
+ parameters = GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True + ) + for i, (token_ids, _, chosen_spans, rejected_spans) in enumerate(documents): + document = dataset.get_document(i, parameters=parameters) + Assert.all_equal(document.tokens.tokens, token_ids) + Assert.all_equal(document.chosen_spans.ranges, chosen_spans) + Assert.all_equal(document.rejected_spans.ranges, rejected_spans) def test_load_metadata_from_hub(): @@ -136,7 +121,7 @@ def test_absent_metadata_local(): def test_split_dataset(): - dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], {"training": 3, "validation": 1}, @@ -164,8 +149,8 @@ def test_split_dataset(): def test_split_datasets_0(): - dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 1, "validation": 1}, @@ -183,8 +168,8 @@ def test_split_datasets_0(): def test_split_datasets_1(): - dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index d7b3021fe..4019dd909 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,9 +2,11 @@ import pytest import torch -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters, ShufflingType +from fast_llm.data.dataset.config import ShufflingType +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -61,7 +63,7 @@ def test_gpt_sampled_data(): ) -class SimpleGPTIndexedDataset[SampleType: GPTSample](IndexedDataset[SampleType]): +class SimpleGPTIndexedDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): # TODO: worth adding to the main codebase? 
def __init__(self, samples): self._samples = samples @@ -71,7 +73,7 @@ def get_document( ) -> SampleType: if end is None: end = len(self._samples[index]) - return GPTSample(token_ids=torch.tensor(self._samples[index][begin:end], dtype=torch.int64)) + return LanguageModelSample(TokenSample(torch.tensor(self._samples[index][begin:end], dtype=torch.int64))) def __len__(self) -> int: return len(self._samples) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index e83387a24..3c6ae10d4 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,5 +1,5 @@ from fast_llm.data.dataset.config import DatasetSliceConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -35,7 +35,7 @@ def test_gpt_slice(): # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, - DatasetSliceConfig[GPTSample], + DatasetSliceConfig[LanguageModelSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index f057c037f..e9690a3c5 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -11,7 +11,8 @@ from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.sampled import logger -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset @@ -155,7 +156,7 @@ def __getitem__(self, idx: int) -> typing.Any: token_ids = torch.cat([sample.token_ids for sample in sample_list]) Assert.eq(len(token_ids), self._sequence_length + 1) - return GPTSample(token_ids=token_ids) + return LanguageModelSample(TokenSample(token_ids)) @property def name(self) -> str: diff --git a/tests/test_config.py b/tests/test_config.py index 63f2606f1..9a1f542a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ import yaml from fast_llm.config import NoAutoValidate -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -60,7 +60,7 @@ def test_validate_example_config(): GPTTrainerConfig.from_dict(fast_llm_config_dict) -@pytest.mark.parametrize("cls", (GPTSamplingConfig, GPTModelConfig)) +@pytest.mark.parametrize("cls", (SamplingConfig, GPTModelConfig)) def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. 
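The test helpers above make it straightforward to exercise the whole sampling path against an in-memory dataset; a rough sketch (helper signatures as in tests/data/common.py):

    dataset = SimpleGPTIndexedDataset([[1, 2, 3, 4, 5], [6, 7, 8]])
    sampled = dataset.sample(get_sampling_data(8, sequence_length=5))
    sample = sampled[0]              # a LanguageModelSample
    tokens = sample.tokens.tokens    # sequence_length + extra_tokens token ids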
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b43923f4d..7d084b5ab 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -6,7 +6,6 @@ import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.sample.gpt import GPTSample from tests.utils.global_variables import ( DATASET_PREFIX, MODEL_DATASET_PREFIX, @@ -47,15 +46,35 @@ def get_test_dataset( tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - GPTSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)) + ( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), + None, + None, + None, + ) for document in texts ] if max_spans > 0: - lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) - for sample, span in zip(samples, spans): - span = np.unique(span) - sample.loss_masking_spans = torch.from_numpy(span[: len(span) // 2 * 2].reshape(-1, 2)) + spans = np.sort( + np.random.RandomState(seed + 3847).randint( + 0, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), [len(samples), max_spans] + ) + ) + samples = ( + (tokens, np.unique(spans_).tolist()) for (tokens, _, _, _), spans_ in zip(samples, spans, strict=True) + ) + samples = [ + ( + tokens, + torch.tensor( + [[begin, end] for begin, end in zip(spans_[::2], spans_[1::2], strict=False)], + dtype=torch.int32, + ).reshape(-1, 2), + None, + None, + ) + for tokens, spans_ in samples + ] GPTMemmapDataset.write_dataset(prefix, samples) yaml.safe_dump( From d6f6944860e088ee10247a01dba9019aeea6e72b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 00:19:56 -0400 Subject: [PATCH 05/45] fix --- fast_llm/data/dataset/gpt/memmap.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index e6b650621..c47caef79 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -179,17 +179,15 @@ def get_document( raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - elif begin != 0 or end != sample_size: - raise ValueError("Samples with preference spans should not be cropped.") # TODO: ====== Store in range format ====== chosen_spans = RangeSample( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], sample_size, - ) + ).crop(begin, end) rejected_spans = RangeSample( [(self._rejected_spans[index][0].item(), self._rejected_spans[index][1].item() + 1)], sample_size, - ) + ).crop(begin, end) else: chosen_spans = rejected_spans = None From 5c802fadd358d4e4f5c3c90bffa1c3db8d34bd97 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 18:22:28 -0400 Subject: [PATCH 06/45] fixes --- fast_llm/data/data/gpt/config.py | 6 +- fast_llm/data/dataset/gpt/fim.py | 20 +-- fast_llm/data/dataset/gpt/memmap.py | 4 + fast_llm/data/dataset/gpt/random.py | 6 +- fast_llm/data/sample/abstract.py | 3 +- fast_llm/data/sample/language_model.py | 16 +- fast_llm/data/sample/token.py | 10 +- fast_llm/engine/config_utils/data_type.py | 1 + fast_llm/layers/attention/attention.py | 94 ++++++----- fast_llm/layers/attention/config.py | 20 ++- fast_llm/models/gpt/config.py | 6 - 
fast_llm/models/gpt/model.py | 1 - fast_llm/models/gpt/trainer.py | 1 - tests/data/common.py | 16 +- tests/data/test_blending.py | 2 +- tests/data/test_memmap.py | 4 +- tests/data/test_prepare_gpt_memmap.py | 10 +- tests/data/test_sampling.py | 2 +- tests/functional/test_functional.py | 197 ++++++---------------- tests/test_attention.py | 4 +- tests/utils/dataset.py | 25 +-- 21 files changed, 198 insertions(+), 250 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index c7f16c936..ba5be883a 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,12 +1,14 @@ import logging +import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SampledDatasetConfig -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.sample.language_model import LanguageModelSample logger = logging.getLogger(__name__) @@ -21,7 +23,7 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? - datasets: dict[str, SampledDatasetConfig[LanguageModelSample]] = Field( + datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index e7ca863b4..1fde74530 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -83,19 +83,19 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) new_samples.append(permuted) - sample = np.concatenate(new_samples) + fim_sample = np.concatenate(new_samples) else: - sample = self._fim_split_and_permute_sequence(sample, np_rng) + fim_sample = self._fim_split_and_permute_sequence(sample, np_rng) # Truncate or pad sequence to max-length - diff = sample.shape[0] - sample_len + diff = fim_sample.shape[0] - sample_len if diff > 0: # too long - sample = sample[:sample_len] + fim_sample = fim_sample[:sample_len] elif diff < 0: # too short - sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) + fim_sample = np.concatenate([fim_sample, np.full((-1 * diff), self._pad_tok_id)]) # noqa - assert sample.shape[0] == sample_len - return sample + assert fim_sample.shape[0] == sample_len + return fim_sample.astype(sample.dtype) def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: """ @@ -168,9 +168,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype) + middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype) + suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype) # here we truncate each given segment to fit the same 
length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c47caef79..486afee1d 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -163,6 +163,9 @@ def get_document( if end > begin else torch.empty(0, dtype=self._dtype) ) + if not self._dtype.is_signed: + # Needed because torch doesn't yet support type promotion between signed and unsigned types. TODO: Remove when supported. + token_ids = token_ids.to(torch.int64) if parameters is not None and parameters.use_loss_masking_spans: assert self._spans is not None # TODO: ====== Store in range format (begin, end) ====== @@ -275,6 +278,7 @@ def write_dataset( num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: spans = np.vstack(spans, dtype=np.int32) + print("JEFNEW", spans[:50].tolist()) else: spans = np.array(spans, dtype=np.int32) chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 62c4311ab..463c5a7d6 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -5,6 +5,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type class GPTRandomDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): @@ -31,11 +32,13 @@ def __init__(self, sampling: GPTSamplingData, name: str): # TODO: Support? assert not self._parameters.use_loss_masking_spans assert not self._parameters.use_preference_loss_spans + self._dtype = get_unsigned_integer_type(self._parameters.vocab_size).torch def __len__(self) -> int: return self._parameters.num_samples def __getitem__(self, index: int) -> SampleType: + # TODO: Sample in self._dtype (breaking) return LanguageModelSample( TokenSample( torch.from_numpy( @@ -43,9 +46,8 @@ def __getitem__(self, index: int) -> SampleType: 0, self._parameters.vocab_size, size=(self._parameters.sequence_length + self._parameters.extra_tokens,), - dtype=np.int64, ) - ) + ).to(self._dtype), ) ) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index c1bebe166..b2cb42cfe 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -1,7 +1,8 @@ import abc import typing -import torch +if typing.TYPE_CHECKING: + import torch class Sample(abc.ABC): diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 57a9a30f6..0a4efa47b 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -21,10 +21,10 @@ def __init__( @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: return cls( - TokenSample.from_documents(document.tokens for document in documents), - _merge_optional(RangeSample.from_documents, (document.loss_masking_spans for document in documents)), - _merge_optional(RangeSample.from_documents, (document.chosen_spans for document in documents)), - _merge_optional(RangeSample.from_documents, (document.rejected_spans for document in documents)), + TokenSample.from_documents([document.tokens for document in documents]), + _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), + 
_merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), ) def crop(self, begin: int, end: int) -> typing.Self: @@ -55,10 +55,10 @@ def __init__( @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: return cls( - TokenBatch.from_samples(sample.tokens for sample in samples), - _merge_optional(RangeBatch.from_samples, (sample.loss_masking_spans for sample in samples)), - _merge_optional(RangeBatch.from_samples, (sample.chosen_spans for sample in samples)), - _merge_optional(RangeBatch.from_samples, (sample.rejected_spans for sample in samples)), + TokenBatch.from_samples([sample.tokens for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), ) def to_samples(self) -> list[LanguageModelSample]: diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index db853341a..d12b27fa0 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -26,14 +26,14 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: def crop(self, begin: int, end: int) -> typing.Self: sample_size = end - begin if self.lengths == [len(self.tokens)]: - # Shortcut for the common case of a single document. + # Shortcut for the frequent case of a single document. lengths = [sample_size] else: begin_ = 0 lengths = [] for length in self.lengths: - end_ = begin + length - cropped_length = max(begin_ - begin, 0) - min(end_ - begin, end) + end_ = begin_ + length + cropped_length = min(end_, end) - max(begin_, begin) if cropped_length > 0: lengths.append(cropped_length) if end_ > end: @@ -46,8 +46,10 @@ def __len__(self) -> int: class TokenBatch(Batch): - def __init__(self, tokens: torch.Tensor, lengths: list[list[int]]) -> None: + def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None: self.tokens = tokens + if lengths is None: + lengths = [[tokens.size(1)]] * tokens.size(0) self.lengths = lengths @classmethod diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index add121c50..1a0fed91b 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -168,6 +168,7 @@ def _set_triton_dtype_map() -> None: def get_unsigned_integer_type(max_size: int) -> DataType: + # TODO: Use uint types (recently added for torch, not enough methods supported yet) if max_size < 2**8: return DataType.uint8 elif max_size < 2**15: diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 167184193..4e9b1b5b5 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -5,11 +5,12 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import 
DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -79,7 +80,12 @@ def __init__( peft=peft, return_bias=return_bias, ) - self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + self._implementation = self._config.implementation + if self._implementation == AttentionImplementation.auto: + if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16): + self._implementation = AttentionImplementation.flash + else: + self._implementation = AttentionImplementation.backup self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -209,8 +215,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._distributed.tp_generator): - attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -329,47 +334,52 @@ def _forward( window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) - if self._use_flash_attention: - assert _flash_available - with set_generator(self._distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: - out_dims = query.size() - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - input_ = _flash_attn_varlen_func( + with set_generator(self._distributed.tp_generator): + if self._implementation == AttentionImplementation.flash_varlen: + assert _flash_available + out_dims = query.size() + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + input_ = ( + _flash_attn_varlen_func( query, key, value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + cu_seqlens_q=kwargs[AttentionKwargs.cu_seqlens_q], + cu_seqlens_k=kwargs[AttentionKwargs.cu_seqlens_k], + max_seqlen_q=kwargs[AttentionKwargs.max_seqlen_q], + max_seqlen_k=kwargs[AttentionKwargs.max_seqlen_k], dropout_p=self._config.dropout if self.training else 0.0, window_size=window_size, causal=self._config.causal, softmax_scale=self._softmax_scale, - ).view(*out_dims) - else: - input_ = _flash_attn_func( - query, - key, - value, - window_size=window_size, - dropout_p=self._config.dropout if self.training else 0.0, - causal=self._config.causal, - softmax_scale=self._softmax_scale, ) - input_ = input_.flatten(-2) - else: - # TODO: Avoid the flattens. 
- input_ = self._attn_fused( - query.flatten(-2), - key.flatten(-2), - value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], - ) + .view(*out_dims) + .flatten(-2) + ) + elif self._implementation == AttentionImplementation.flash: + assert _flash_available + input_ = _flash_attn_func( + query, + key, + value, + window_size=window_size, + dropout_p=self._config.dropout if self.training else 0.0, + causal=self._config.causal, + softmax_scale=self._softmax_scale, + ).flatten(-2) + elif self._implementation == AttentionImplementation.backup: + # TODO: Avoid the flattens. + input_ = self._attn_fused( + query.flatten(-2), + key.flatten(-2), + value.flatten(-2), + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], + ) + else: + raise NotImplementedError(self._implementation) if self._debug.enabled: self._debug(query, "query", self._query_dims, kwargs) @@ -413,8 +423,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._use_flash_attention: + if (not config.hardware) or self._implementation in ( + AttentionImplementation.flash, + AttentionImplementation.flash_varlen, + ): # Remove non-causal part. (TODO: Support non-causal) + # For varlen implementation, compute is overestimated as we include cross-document attention. attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 if self._config.window_size is not None: @@ -439,9 +453,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(batch, kwargs) - if not self._use_flash_attention: + if self._implementation == AttentionImplementation.backup: self._preprocess_for_backup_attention(batch, kwargs) - elif AttentionKwargs.sequence_lengths in kwargs: + elif self._implementation == AttentionImplementation.flash_varlen: self._preprocess_for_varlen(batch, kwargs) def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 68b6dde91..c02b67293 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,10 +1,9 @@ +import enum import logging import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs @@ -32,6 +31,13 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" +class AttentionImplementation(enum.StrEnum): + auto = "auto" + flash = "flash" + flash_varlen = "flash_varlen" + backup = "backup" + + @config_class(dynamic_type={MixerConfig: "attention"}) class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. 
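For reference, the new `auto` value is resolved once at construction time in `Attention.__init__` (earlier in this patch). Below is a minimal standalone sketch of that selection rule; the helper name is hypothetical and `flash_available` stands in for the module-level `_flash_available` check, so treat it as an illustration rather than code from the patch.

from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.layers.attention.config import AttentionImplementation

def resolve_attention_implementation(
    requested: AttentionImplementation, flash_available: bool, compute_dtype: DataType
) -> AttentionImplementation:
    # Explicit choices are kept as-is; only `auto` needs resolving.
    if requested != AttentionImplementation.auto:
        return requested
    # Flash attention requires the kernels to be installed and a half-precision compute dtype.
    if flash_available and compute_dtype in (DataType.float16, DataType.bfloat16):
        return AttentionImplementation.flash
    return AttentionImplementation.backup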
@@ -107,6 +113,13 @@ class AttentionConfig(MixerConfig): " Under muP (if scaling number of heads instead of head_size): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + implementation: AttentionImplementation = Field( + default=AttentionImplementation.auto, + desc="The implementation to use for the attention layer.", + doc="Use `flash_varlen` to enable the varlen version of Flash Attention and prevent cross-document attention. " + "Default: `flash` if supported, otherwise `backup`,", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -121,6 +134,3 @@ def layer_class(self) -> "type[Attention]": from fast_llm.layers.attention.attention import Attention return Attention - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a0466..c1ee246f7 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,12 +48,6 @@ class GPTBatchConfig(BatchConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - # TODO: Find a better place for these? - cross_document_attention: bool = Field( - default=True, - desc="Applies attention to tokens from other documents in the packed sequence. Set to False for masking attention to other documents.", - hint=FieldHint.feature, - ) use_loss_masking_spans: bool = Field( default=False, desc="Read loss masking spans from the dataset.", diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 3e50d1ed1..56438670f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -198,7 +198,6 @@ def preprocess_batch( **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, - # TODO: ====== Use only if wanted ====== AttentionKwargs.sequence_lengths: cropped_tokens.lengths, **reference_logits[i], } diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54ea13dc4..b8fb22ebb 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -27,7 +27,6 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, # OK since DPO is not supported for MTP. 
"use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), - "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } diff --git a/tests/data/common.py b/tests/data/common.py index 5102bfbcd..e6ab8a265 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -103,12 +103,15 @@ def get_test_data_and_compare_samples( batch_config.validate() tokens = { phase: torch.stack( - [batch.token_ids[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + [ + batch.tokens.tokens[0] + for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0) + ] ) for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - Assert.all_equal(tokens[phase], expected_samples_) + Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) return data @@ -127,9 +130,10 @@ def compare_indexed_dataset( sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.uint16)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): + print(i) Assert.eq( dataset.get_document( i, @@ -143,7 +147,9 @@ def compare_indexed_dataset( def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal(torch.stack([sampled[i].token_ids for i in range(len(expected_samples))]), expected_samples) + Assert.all_equal( + torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples + ) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -178,7 +184,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] for index in range(sampled._parameters.num_samples) ] - token_ids = torch.stack([sampled[i].token_ids for i in range(len(sampled))]) + token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index b64465d55..0099cb50b 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -162,7 +162,7 @@ def test_gpt_blended_mixed(): ], "weights": [0.6, 0.4], }, - BlendedDatasetConfig[GPTSample], + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index d718f089c..ca887f3c1 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -27,8 +27,8 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [(0, 5), (6, 9)], - 13: [(1, 3)], + 10: [(0, 2), (2, 7), (7, 10)], + 13: [(0, 2)], 15: [], } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 90610381a..601abcf99 100644 --- 
a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -39,7 +39,7 @@ def test_write_memmap_dataset(dtype): GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, (tokens, _, _, _) in enumerate(documents): - Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens) + Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) def _generate_valid_span(max_seq_length): @@ -64,11 +64,11 @@ def test_write_memmap_preference_dataset(dtype): parameters = GPTSamplingParameters( num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True ) - for i, (token_ids, _, chosen_spans, rejected_spans) in enumerate(documents): + for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): document = dataset.get_document(i, parameters=parameters) - Assert.all_equal(document.tokens.tokens, token_ids) - Assert.all_equal(document.chosen_spans.ranges, chosen_spans) - Assert.all_equal(document.rejected_spans.ranges, rejected_spans) + Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) + Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) + Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) def test_load_metadata_from_hub(): diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 4019dd909..58f4d3dab 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -180,4 +180,4 @@ def test_gpt_sample_padding(): else: sampled = dataset.sample(sampling) for idx in range(len(expected_samples)): - Assert.all_equal(sampled[idx].token_ids, np.array(expected_samples[idx])) + Assert.all_equal(sampled[idx].tokens.tokens, np.array(expected_samples[idx])) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3fae970f8..489f5e1c1 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,167 +1,80 @@ -import random - import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert +from tests.utils.dataset import get_random_spans from tests.utils.utils import requires_cuda -def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: - if temperature != 1.0: - logits.div_(temperature) - batch_dim = logits.shape[:-1] - last_dim = logits.shape[-1] - - output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") - log_probs_labels = -output.view(*batch_dim) - - return log_probs_labels - - -def ref_packed_get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - attention_mask, - prompt_id_lens, - packed_seq_lens, -) -> torch.FloatTensor: - labels = labels[:, 1:] - logits = logits[:, :-1, :] - per_token_logps = ref_log_probs_from_logits(logits, labels) - - loss_masks = attention_mask.clone().bool() - - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - loss_masks[0, index : index + prompt_id_lens[i]] = False - index = index + seq_len - - loss_masks = 
loss_masks[:, 1:] - - logprobs_sums = [] - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - seq = per_token_logps[0, index : index + seq_len - 1] - mask = loss_masks[0, index : index + seq_len - 1] - logprobs_sums.append((seq * mask).sum()) - index = index + seq_len - chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] - rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] - - return torch.tensor(chosen_logps), torch.tensor(rejected_logps) - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("batch_size", "seq_length", "vocab_size"), - ( - (2, 32, 50), - (1, 32, 50), - (2, 100, 50), - (2, 32, 200), - ), -) -def test_preference_logps(batch_size, seq_length, vocab_size): - random.seed(0) - torch.manual_seed(0) - - def random_split(seq_length): - min_val = int(seq_length * 0.3) - max_val = int(seq_length * 0.7) - - if max_val < min_val: - max_val = min_val - - a = random.randint(min_val, max_val) - b = seq_length - a - return [a, b] - - logits = torch.randn(batch_size, seq_length, vocab_size) - targets = torch.randint(0, vocab_size, (batch_size, seq_length)) - packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths - prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 75% prompt 25% generation - attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) - - chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due to label shifting - rejected_span = ( - torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 - ) # shift by 1 due to label shifting - - ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( - logits, targets, attention_mask, prompt_id_lens, packed_seq_lens +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans ) - chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( - logits=logits, - targets=targets[:, 1:], - chosen_spans=chosen_span, - rejected_spans=rejected_span, - ) - - ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) - - # check all logps - Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) - # check chosen and rejected summed logps - Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) - Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) - - -def ref_dpo_loss_fcn( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta=1, - label_smoothing=0, +def reference_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, ) -> torch.Tensor: + # TODO: Too similar to the actual implementation. 
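+    # Reference computation: per-token log-probabilities are taken with log_softmax + gather on the
+    # targets, summed over the chosen and rejected spans of each sample, and the resulting policy and
+    # reference log-ratios are combined in the standard DPO loss, -logsigmoid(beta * (pi_logratios - ref_logratios)).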
+ policy_log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + ) + policy_chosen_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + policy_rejected_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + reference_log_probs = ( + torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) + .gather(dim=-1, index=targets.unsqueeze(-1)) + .squeeze(-1) + ) + reference_chosen_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + reference_rejected_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps - logits = pi_logratios - ref_logratios - - # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) - losses = ( - -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) - - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing - ) - - loss = losses.mean() - - return loss + return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() def test_dpo_loss(): torch.manual_seed(0) + logits = torch.randn((10, 50, 100), requires_grad=True) + reference_model_logits = torch.randn((10, 50, 100)) + targets = torch.randint(0, 100, (10, 50)) - NUM_SAMPLES = 20 - policy_chosen_logps = torch.rand(NUM_SAMPLES) - policy_rejected_logps = torch.rand(NUM_SAMPLES) - reference_chosen_logps = torch.rand(NUM_SAMPLES) - reference_rejected_logps = torch.rand(NUM_SAMPLES) - betas = torch.rand(NUM_SAMPLES) + spans = get_random_spans(10, 10, 50) - for i in range(NUM_SAMPLES): - fastllm_dpo_loss = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps[i], - policy_rejected_logps=policy_rejected_logps[i], - reference_chosen_logps=reference_chosen_logps[i], - reference_rejected_logps=reference_rejected_logps[i], - beta=betas[i].item(), - ) - ref_dpo_loss = ref_dpo_loss_fcn( - policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), - policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), - reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), - reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), - beta=betas[i].item(), - ) - Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) + fastllm_loss, fast_llm_grad = compute_dpo_loss( + logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 + ) + reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) + reference_loss.backward() + Assert.rms_close(fastllm_loss, reference_loss, 1e-5) + Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) @requires_cuda diff --git a/tests/test_attention.py b/tests/test_attention.py index a19cba8f0..ff1a700fe 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -3,7 +3,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.attention import Attention 
-from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert @@ -29,7 +29,7 @@ def test_varlen_preprocessing(): micro_sequence_length = 12 sequence_length = 36 attention = Attention( - AttentionConfig(head_size=64), + AttentionConfig(head_size=64, implementation=AttentionImplementation.flash_varlen), DistributedConfig(compute_dtype="bfloat16"), hidden_dim=TensorDim("", 1), lr_scale=None, diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 7d084b5ab..428dec56b 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -25,6 +25,15 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) +def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int, seed: int = 0): + spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths, [num_samples, max_spans * 2])) + spans = [np.unique(sample_spans).tolist() for sample_spans in spans] + return [ + [(begin, end) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] + for sample_spans in spans + ] + + def get_test_dataset( prefix: pathlib.Path = DATASET_PREFIX, seed: int = 1234, @@ -55,25 +64,17 @@ def get_test_dataset( for document in texts ] if max_spans > 0: - spans = np.sort( - np.random.RandomState(seed + 3847).randint( - 0, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), [len(samples), max_spans] - ) - ) - samples = ( - (tokens, np.unique(spans_).tolist()) for (tokens, _, _, _), spans_ in zip(samples, spans, strict=True) + spans = get_random_spans( + len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed ) samples = [ ( tokens, - torch.tensor( - [[begin, end] for begin, end in zip(spans_[::2], spans_[1::2], strict=False)], - dtype=torch.int32, - ).reshape(-1, 2), + torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2), None, None, ) - for tokens, spans_ in samples + for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True) ] GPTMemmapDataset.write_dataset(prefix, samples) From 95d1840f92c9c25c22e8fab94577541bfafc0fc0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 18:49:50 -0400 Subject: [PATCH 07/45] test --- fast_llm/layers/attention/attention.py | 44 ++++++++++++++------------ 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 4e9b1b5b5..f00bc3c2d 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -215,7 +215,8 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) + with set_generator(self._distributed.tp_generator): + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -334,13 +335,13 @@ def _forward( window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) - with set_generator(self._distributed.tp_generator): - if self._implementation == 
AttentionImplementation.flash_varlen: - assert _flash_available - out_dims = query.size() - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) + if self._implementation == AttentionImplementation.flash_varlen: + assert _flash_available + out_dims = query.size() + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + with set_generator(self._distributed.tp_generator): input_ = ( _flash_attn_varlen_func( query, @@ -358,8 +359,9 @@ def _forward( .view(*out_dims) .flatten(-2) ) - elif self._implementation == AttentionImplementation.flash: - assert _flash_available + elif self._implementation == AttentionImplementation.flash: + assert _flash_available + with set_generator(self._distributed.tp_generator): input_ = _flash_attn_func( query, key, @@ -369,17 +371,17 @@ def _forward( causal=self._config.causal, softmax_scale=self._softmax_scale, ).flatten(-2) - elif self._implementation == AttentionImplementation.backup: - # TODO: Avoid the flattens. - input_ = self._attn_fused( - query.flatten(-2), - key.flatten(-2), - value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], - ) - else: - raise NotImplementedError(self._implementation) + elif self._implementation == AttentionImplementation.backup: + # TODO: Avoid the flattens. + input_ = self._attn_fused( + query.flatten(-2), + key.flatten(-2), + value.flatten(-2), + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], + ) + else: + raise NotImplementedError(self._implementation) if self._debug.enabled: self._debug(query, "query", self._query_dims, kwargs) From eafd9cbbb9ff77fe6586e0152c6071c9142f5e8e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 22:29:35 -0400 Subject: [PATCH 08/45] fixes --- fast_llm/data/dataset/sampled.py | 3 +- fast_llm/data/sample/abstract.py | 4 + fast_llm/data/sample/language_model.py | 8 ++ fast_llm/data/sample/range.py | 3 + fast_llm/data/sample/token.py | 3 + fast_llm/engine/config_utils/run.py | 3 +- fast_llm/layers/attention/attention.py | 96 +++++++++------------ fast_llm/layers/attention/config.py | 11 ++- fast_llm/layers/language_model/config.py | 7 ++ fast_llm/layers/language_model/embedding.py | 7 +- fast_llm/models/gpt/model.py | 2 +- tests/models/test_match_megatron.py | 24 +++--- tests/test_attention.py | 2 +- 13 files changed, 95 insertions(+), 78 deletions(-) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 441dfafae..46a518cd0 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -389,8 +389,7 @@ def __getitem__(self, index: int) -> SampleType: # Document belongs to the next sample, need to account for padding. 
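                    # Padding is now built by the last real document itself (`Sample.get_padding`, added below),
                    # so the token and span fields each receive a placeholder of the right type and size.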
padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: - # TODO: ====== Handle padding ====== - documents.append(PaddingSample(padding_size)) + documents.append(documents[-1].get_padding(padding_size)) Assert.eq(token_count + padding_size, token_end) break else: diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index b2cb42cfe..031002101 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -19,6 +19,10 @@ def crop(self, begin: int, end: int) -> typing.Self: def __len__(self) -> int: pass + @abc.abstractmethod + def get_padding(self, size: int) -> typing.Self: + pass + class Batch(abc.ABC): # TODO: Relate to `BatchConfig`? diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 0a4efa47b..f30188553 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -38,6 +38,14 @@ def crop(self, begin: int, end: int) -> typing.Self: def __len__(self) -> int: return len(self.tokens) + def get_padding(self, size: int) -> typing.Self: + return LanguageModelSample( + self.tokens.get_padding(size), + None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), + None if self.chosen_spans is None else self.chosen_spans.get_padding(size), + None if self.rejected_spans is None else self.rejected_spans.get_padding(size), + ) + class LanguageModelBatch(Batch): def __init__( diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index ec71b98d6..d121a38b6 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -32,6 +32,9 @@ def crop(self, begin: int, end: int) -> typing.Self: def __len__(self) -> int: return self.sample_size + def get_padding(self, size: int) -> typing.Self: + return RangeSample([], size) + class RangeBatch(Batch): def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int): diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index d12b27fa0..62d1c0e67 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -44,6 +44,9 @@ def crop(self, begin: int, end: int) -> typing.Self: def __len__(self) -> int: return len(self.tokens) + def get_padding(self, size: int) -> typing.Self: + return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + class TokenBatch(Batch): def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None: diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1849a2316..3fca005f5 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -44,7 +44,8 @@ class RunConfig(Config): # Enable torch compile. torch_dynamo_enable: bool = Field( default=True, - desc="Set to False to disable torch compile entirely. Not recommended unless there is a good reason to do so.", + desc="Set to False to " + "disable torch compile entirely. 
Not recommended unless there is a good reason to do so.", hint=FieldHint.expert, ) enable_triton_kernels: bool = Field( diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index f00bc3c2d..ffbe9955e 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -215,8 +215,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._distributed.tp_generator): - attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -334,54 +333,48 @@ def _forward( query, key = self._rotary(query, key, kwargs) window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) - - if self._implementation == AttentionImplementation.flash_varlen: - assert _flash_available - out_dims = query.size() - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - with set_generator(self._distributed.tp_generator): - input_ = ( - _flash_attn_varlen_func( + with set_generator(self._distributed.tp_generator): + if self._implementation == AttentionImplementation.flash: + assert _flash_available + if self._config.cross_document_attention: + input_ = _flash_attn_func( query, key, value, - cu_seqlens_q=kwargs[AttentionKwargs.cu_seqlens_q], - cu_seqlens_k=kwargs[AttentionKwargs.cu_seqlens_k], - max_seqlen_q=kwargs[AttentionKwargs.max_seqlen_q], - max_seqlen_k=kwargs[AttentionKwargs.max_seqlen_k], - dropout_p=self._config.dropout if self.training else 0.0, window_size=window_size, + dropout_p=self._config.dropout if self.training else 0.0, causal=self._config.causal, softmax_scale=self._softmax_scale, + ).flatten(-2) + else: + input_ = ( + _flash_attn_varlen_func( + query.view(-1, query.size(-2), query.size(-1)), + key.view(-1, key.size(-2), key.size(-1)), + value.view(-1, value.size(-2), value.size(-1)), + cu_seqlens_q=kwargs.get(AttentionKwargs.cu_seqlens_q), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + dropout_p=self._config.dropout if self.training else 0.0, + window_size=window_size, + causal=self._config.causal, + softmax_scale=self._softmax_scale, + ) + .view(query.size()) + .flatten(-2) ) - .view(*out_dims) - .flatten(-2) + elif self._implementation == AttentionImplementation.backup: + # TODO: Avoid the flattens. + input_ = self._attn_fused( + query.flatten(-2), + key.flatten(-2), + value.flatten(-2), + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], ) - elif self._implementation == AttentionImplementation.flash: - assert _flash_available - with set_generator(self._distributed.tp_generator): - input_ = _flash_attn_func( - query, - key, - value, - window_size=window_size, - dropout_p=self._config.dropout if self.training else 0.0, - causal=self._config.causal, - softmax_scale=self._softmax_scale, - ).flatten(-2) - elif self._implementation == AttentionImplementation.backup: - # TODO: Avoid the flattens. 
- input_ = self._attn_fused( - query.flatten(-2), - key.flatten(-2), - value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], - ) - else: - raise NotImplementedError(self._implementation) + else: + raise NotImplementedError(self._implementation) if self._debug.enabled: self._debug(query, "query", self._query_dims, kwargs) @@ -425,12 +418,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._implementation in ( - AttentionImplementation.flash, - AttentionImplementation.flash_varlen, - ): + if (not config.hardware) or self._implementation in AttentionImplementation.flash: # Remove non-causal part. (TODO: Support non-causal) - # For varlen implementation, compute is overestimated as we include cross-document attention. + # TODO: Compute is overestimated without cross-document attention. attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 if self._config.window_size is not None: @@ -457,8 +447,8 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self._rotary.preprocess(batch, kwargs) if self._implementation == AttentionImplementation.backup: self._preprocess_for_backup_attention(batch, kwargs) - elif self._implementation == AttentionImplementation.flash_varlen: - self._preprocess_for_varlen(batch, kwargs) + elif self._implementation == AttentionImplementation.flash: + self._preprocess_for_flash_attention(batch, kwargs) def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: if ( @@ -487,11 +477,11 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + if not self._config.cross_document_attention: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths + for sample_lens in kwargs[AttentionKwargs.sequence_lengths] ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) @@ -501,7 +491,7 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str ) kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_flash_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -511,7 +501,7 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. 
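        For example, if a micro-sequence packs documents of lengths [4, 2], the cumulative lengths are
        [0, 4, 6] (illustration only; the exact tensors depend on the micro-sequence boundaries).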
""" - if AttentionKwargs.sequence_lengths not in kwargs: + if self._config.cross_document_attention: return sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index c02b67293..206fa6e6f 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -34,7 +34,6 @@ class AttentionKwargs(BlockKwargs): class AttentionImplementation(enum.StrEnum): auto = "auto" flash = "flash" - flash_varlen = "flash_varlen" backup = "backup" @@ -115,9 +114,13 @@ class AttentionConfig(MixerConfig): ) implementation: AttentionImplementation = Field( default=AttentionImplementation.auto, - desc="The implementation to use for the attention layer.", - doc="Use `flash_varlen` to enable the varlen version of Flash Attention and prevent cross-document attention. " - "Default: `flash` if supported, otherwise `backup`,", + desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.", + hint=FieldHint.feature, + ) + cross_document_attention: bool = Field( + default=True, + desc="Allow for cross-document attention.", + doc="Disable to prevent attention between tokens belonging to different documents.", hint=FieldHint.feature, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 25fa2d91e..18c64acc4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -53,6 +53,13 @@ class LanguageModelEmbeddingsConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + cross_document_position_embeddings: bool = Field( + default=True, + desc="Allow for cross-document position embeddings.", + doc="Disable to reset position ids at the beginning of each document.", + hint=FieldHint.feature, + ) + dropout: float = Field( default=0.0, desc="Dropout applied to the embedding layer.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 0ad3225c8..61ca1cfc0 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -136,9 +136,12 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: + if not self._config.cross_document_position_embeddings: position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + [ + torch.cat([torch.arange(x) for x in sample_lens]) + for sample_lens in kwargs[LanguageModelKwargs.sequence_lengths] + ] ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 56438670f..3295295f6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -198,7 +198,7 @@ def preprocess_batch( **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, - AttentionKwargs.sequence_lengths: cropped_tokens.lengths, + AttentionKwargs.sequence_lengths: batch.tokens.lengths, 
**reference_logits[i], } diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index e9690a3c5..7447e395a 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,7 +3,6 @@ import numpy as np import pytest -import torch from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset @@ -12,7 +11,6 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.sampled import logger from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset @@ -145,18 +143,16 @@ def __getitem__(self, idx: int) -> typing.Any: shuffled_idx = self._shuffle_idx[idx] doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get_document( - self._doc_idx[doc].item(), - begin=(doc == doc_f) * offset_f, - end=offset_l + 1 if doc == doc_l else None, - ) - for doc in range(doc_f, doc_l + 1) - ] - token_ids = torch.cat([sample.token_ids for sample in sample_list]) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return LanguageModelSample(TokenSample(token_ids)) + return LanguageModelSample.from_documents( + [ + self._indexed_dataset.get_document( + self._doc_idx[doc].item(), + begin=(doc == doc_f) * offset_f, + end=offset_l + 1 if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] + ) @property def name(self) -> str: diff --git a/tests/test_attention.py b/tests/test_attention.py index ff1a700fe..b86cc95fa 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -29,7 +29,7 @@ def test_varlen_preprocessing(): micro_sequence_length = 12 sequence_length = 36 attention = Attention( - AttentionConfig(head_size=64, implementation=AttentionImplementation.flash_varlen), + AttentionConfig(head_size=64, implementation=AttentionImplementation.flash, cross_document_attention=False), DistributedConfig(compute_dtype="bfloat16"), hidden_dim=TensorDim("", 1), lr_scale=None, From c56df69e9764eeae5411571a626a1f181963beb5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 23:29:31 -0400 Subject: [PATCH 09/45] cleanup --- fast_llm/data/dataset/gpt/memmap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 486afee1d..06d8d7acc 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -278,7 +278,6 @@ def write_dataset( num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: spans = np.vstack(spans, dtype=np.int32) - print("JEFNEW", spans[:50].tolist()) else: spans = np.array(spans, dtype=np.int32) chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) From 7f437e1341605717d630afd14cac07e7ab3617c5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 23:59:35 -0400 Subject: [PATCH 10/45] misc --- fast_llm/data/dataset/config.py | 5 +---- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index e93e5865a..20e40b66e 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -147,10 +147,7 @@ class 
ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl def build(self) -> "ConcatenatedDataset": from fast_llm.data.dataset.indexed import ConcatenatedDataset - return self._build(ConcatenatedDataset) - - def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: - return cls(self.name, [dataset.build() for dataset in self.datasets]) + return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) @config_class(dynamic_type={SampledDatasetConfig: "slice"}) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 73dba6ccc..274bbf1b0 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -250,7 +250,7 @@ def run(self) -> None: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) + self._tokenizer = self._config.tokenizer.get_tokenizer() # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( From dfd27f54f949eebc54e2604d01f26a4769f2b652 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 17 Oct 2025 00:18:14 -0400 Subject: [PATCH 11/45] misc --- fast_llm/engine/config_utils/run.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 3fca005f5..1849a2316 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -44,8 +44,7 @@ class RunConfig(Config): # Enable torch compile. torch_dynamo_enable: bool = Field( default=True, - desc="Set to False to " - "disable torch compile entirely. Not recommended unless there is a good reason to do so.", + desc="Set to False to disable torch compile entirely. 
Not recommended unless there is a good reason to do so.", hint=FieldHint.expert, ) enable_triton_kernels: bool = Field( From 90cd0096309d44b60000a667888759d261c1338a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 18 Oct 2025 00:07:22 -0400 Subject: [PATCH 12/45] Memmap dataset --- fast_llm/data/dataset/config.py | 32 ++ fast_llm/data/dataset/gpt/config.py | 36 +-- .../gpt/{memmap.py => legacy_memmap.py} | 119 +------ fast_llm/data/dataset/memmap.py | 100 ++++++ fast_llm/data/preparator/gpt_memmap/config.py | 53 +++- .../data/preparator/gpt_memmap/prepare.py | 290 ++++++------------ fast_llm/data/sample/abstract.py | 152 +++++++++ fast_llm/data/sample/language_model.py | 186 ++++++++++- fast_llm/data/sample/range.py | 81 ++++- fast_llm/data/sample/token.py | 86 +++++- fast_llm/data/tokenizer.py | 50 +-- tests/data/test_blending.py | 10 +- tests/data/test_concatenate.py | 6 +- tests/data/test_dataset_from_file.py | 4 +- tests/data/test_fim.py | 6 +- tests/data/test_memmap.py | 14 +- tests/data/test_prepare_gpt_memmap.py | 44 +-- tests/data/test_sampling.py | 10 +- tests/data/test_slice.py | 10 +- tests/models/test_match_megatron.py | 28 +- tests/utils/dataset.py | 50 ++- tests/utils/global_variables.py | 5 +- tests/utils/model_configs.py | 10 +- tools/concatenate_dataset.py | 60 ---- 24 files changed, 899 insertions(+), 543 deletions(-) rename fast_llm/data/dataset/gpt/{memmap.py => legacy_memmap.py} (61%) create mode 100644 fast_llm/data/dataset/memmap.py delete mode 100644 tools/concatenate_dataset.py diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 20e40b66e..f1bc3472a 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -2,6 +2,7 @@ import enum import functools import itertools +import logging import math import pathlib import typing @@ -9,12 +10,15 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.sample.abstract import Sample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class ShufflingType(str, enum.Enum): # Shuffle all epochs together. Not extendable. @@ -266,3 +270,31 @@ def build_and_sample( self.weights, sampling, ) + + +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): + _abstract: typing.ClassVar[bool] = False + path: pathlib.Path = Field( + default=None, + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + def build(self) -> "IndexedDataset[SampleType]": + name = str(self.path).replace("/", "__") + if self.path.is_file(): + from fast_llm.data.dataset.memmap import MemmapDataset + + return MemmapDataset[SampleType](name, self.path) + elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): + logger.warning( + "Using the legacy memmap dataset format." + " This format is deprecated and will be removed in a future release." + " Please recreate the dataset in the new memmap format." 
+ ) + from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset + + return LegacyMemmapDataset[SampleType](name, self.path) + else: + raise FileNotFoundError(self.path) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 15f54ec80..9ff6654c2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -8,19 +8,12 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.config import ( - IndexedDatasetConfig, - SamplableDatasetConfig, - SampledDatasetConfig, - SamplingData, - SamplingParameters, -) +from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset @@ -60,33 +53,6 @@ def build(self) -> "GPTRandomDataset[SampleType]": return GPTRandomDataset[SampleType](self.name) -@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", - hint=FieldHint.core, - ) - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.optional, - ) - num_tokens: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - - def build(self) -> "GPTMemmapDataset[SampleType]": - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - - return GPTMemmapDataset[SampleType]( - str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens - ) - - @config_class(dynamic_type={SampledDatasetConfig: "file"}) class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py similarity index 61% rename from fast_llm/data/dataset/gpt/memmap.py rename to fast_llm/data/dataset/gpt/legacy_memmap.py index 06d8d7acc..d8c63e9f9 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -1,21 +1,19 @@ import pathlib import struct -import typing import numpy as np import torch from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -class GPTMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): +class 
LegacyMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -28,12 +26,10 @@ def __init__( self, name: str, prefix: pathlib.Path | str, - num_documents: int | None = None, - num_tokens: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init(self, name: str, prefix: pathlib.Path | str) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -54,9 +50,6 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None _ = struct.unpack(" tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._prefix) - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): @@ -168,7 +159,7 @@ def get_document( token_ids = token_ids.to(torch.int64) if parameters is not None and parameters.use_loss_masking_spans: assert self._spans is not None - # TODO: ====== Store in range format (begin, end) ====== + # Convert to in range format (begin, end). sample_spans = RangeSample( [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size ).crop(begin, end) @@ -182,7 +173,7 @@ def get_document( raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - # TODO: ====== Store in range format ====== + # Convert to in range format (begin, end). chosen_spans = RangeSample( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], sample_size, @@ -222,95 +213,3 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() - - @classmethod - def write_dataset( - cls, - prefix: pathlib.Path | str, - documents: typing.Iterable[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]], - ) -> None: - # Initialize metadata - dtype = None - num_documents = 0 - lengths = [] - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] - - prefix = pathlib.Path(prefix) - prefix.parent.mkdir(parents=True, exist_ok=True) - - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: - for token_ids, loss_masking_spans, chosen_span, rejected_span in documents: - # Infer dtype from the first document - if dtype is None: - dtype = token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." 
- - # Write document to binary file - bin_stream.write(token_ids.numpy().tobytes(order="C")) - - # Update metadata - doc_length = len(token_ids) - lengths.append(doc_length) - pointers.append(offset) - if loss_masking_spans is not None: - num_spans.append(len(loss_masking_spans)) - spans.append(loss_masking_spans) - if chosen_span is not None: - chosen_spans.append(chosen_span) - if rejected_span is not None: - rejected_spans.append(rejected_span) - offset += doc_length * dtype.itemsize - num_documents += 1 - - # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) - - # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Data type - idx_stream.write(struct.pack(" None: + super().__init__() + self._name = name + self._path = path + + with self._path.open("rb") as stream: + # Very file type. + assert stream.read(len(FILE_HEADER)) == FILE_HEADER + # Go to reader configs. + stream.seek(int.from_bytes(stream.read(4), signed=False)) + # Read the reader config. + reader_config = MemmapIndexDatasetReaderConfig.from_dict( + json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) + ) + + self._memmap = np.memmap(self._path, mode="r") + # TODO: ===== Check num_documents, num_tokens ====== + self._reader = reader_config.get_reader(memoryview(self._memmap)) + + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._path) + + def __setstate__(self, state: tuple[str, pathlib.Path]): + self._init(*state) + + def __del__(self): + if hasattr(self, "_memmap"): + self._memmap._mmap.close() # noqa + del self._memmap + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + return self._reader.get_document(index, begin, end) + + @property + def name(self) -> str: + return self._name + + def __len__(self) -> int: + return self._reader + + # TODO: ====== needed? ====== + # @property + # def num_tokens(self) -> int: + # return self._reader.num_tokens + + def get_document_sizes(self) -> torch.Tensor: + return self._reader.get_document_sizes() + + def get_document_size(self, index: int) -> int: + return self._reader.get_document_size(index) + + @classmethod + def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter]): + # TODO: Match `writer_class` with `SampleType`? + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as stream: + # Write the file type header. + stream.write(FILE_HEADER) + # Leave space for a pointer to the reader config. + # We write the config at the end since we don't know it yet. + start = stream.tell() + stream.seek(start + 4) + # Write the data. 
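+            # Resulting file layout, as read back by `_init` above:
+            #   [FILE_HEADER][4-byte unsigned offset of the serialized reader config]
+            #   [document data written by `writer_class`]
+            #   [4-byte unsigned length of the reader config][JSON-serialized reader config]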
+            reader_config = writer_class.write_dataset(stream, documents)
+            # Write the reader config.
+            config_offset = stream.tell()
+            reader_config_bytes = json.dumps(reader_config.to_dict()).encode("utf-8")
+            stream.write(len(reader_config_bytes).to_bytes(4, signed=False))
+            stream.write(reader_config_bytes)
+            # Write a pointer to the reader config.
+            stream.seek(start)
+            stream.write(config_offset.to_bytes(4, signed=False))
diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index d2aaee5e2..c193cf942 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,3 +1,4 @@
+import functools
 import os
 import pathlib
 import typing
@@ -25,14 +26,9 @@
 MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"


-@config_class(registry=True)
-class SourceSchemaConfig(Config):
-    pass
-
-
-@config_class(dynamic_type={SourceSchemaConfig: "text_column"})
-class TextColumnConfig(SourceSchemaConfig):
-    input_column: str = Field(
+@config_class()
+class LanguageModelSourceConfig(Config):
+    text_column: str = Field(
         default="text",
         desc="Field of the dataset to use.",
         hint=FieldHint.optional,
@@ -40,6 +36,38 @@
     loss_masking_spans_column: None | str = Field(
         default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
     )
+    chosen_spans_column: None | str = Field(
+        default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
+    )
+    rejected_spans_column: None | str = Field(
+        default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
+    )
+
+    @functools.cached_property
+    def columns(self) -> list[str]:
+        columns = [self.text_column]
+        if self.has_loss_masking_span:
+            columns.append(self.loss_masking_spans_column)
+        if self.has_preference_spans:
+            columns.extend([self.chosen_spans_column, self.rejected_spans_column])
+        return columns
+
+    @functools.cached_property
+    def has_loss_masking_span(self) -> bool:
+        return self.loss_masking_spans_column is not None
+
+    @functools.cached_property
+    def has_preference_spans(self) -> bool:
+        Assert.eq(self.chosen_spans_column is None, self.rejected_spans_column is None)
+        return self.chosen_spans_column is not None
+
+    def _validate(self):
+        super()._validate()
+        if (self.chosen_spans_column is None) != (self.rejected_spans_column is None):
+            raise ValueError("Both the chosen and rejected span columns must be specified if one is specified.")
+        if self.has_preference_spans and self.has_loss_masking_span:
+            # TODO: ====== Still needed? ======
+            raise ValueError("Can not enable both loss masking and preference spans.")


 @config_class()
@@ -69,16 +97,10 @@ class GPTHuggingfaceDatasetConfig(Config):
         desc="Split of the dataset to use.",
         hint=FieldHint.optional,
     )
-    source_schema: SourceSchemaConfig = Field(
+    source_schema: LanguageModelSourceConfig = Field(
         desc="Configuration for the data source.",
         hint=FieldHint.optional,
     )
-    chosen_text: None | str = Field(
-        default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
-    )
-    rejected_text: None | str = Field(
-        default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
-    )
     data_type: DataType | None = Field(
         default=None,
         desc="Data type of the dataset field."
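As a rough, illustrative sketch (not part of the patch): the flattened source schema above could be configured as follows. The column names are placeholders, and `from_dict` is used the same way this patch uses it for other configs.

    from fast_llm.data.preparator.gpt_memmap.config import LanguageModelSourceConfig

    # Plain text column only (the default schema).
    schema = LanguageModelSourceConfig.from_dict({"text_column": "text"})

    # Text plus character-level loss-masking spans.
    schema = LanguageModelSourceConfig.from_dict(
        {"text_column": "text", "loss_masking_spans_column": "spans"}
    )

    # Preference data: chosen and rejected columns must be set together, and cannot be
    # combined with loss-masking spans (see `_validate` above).
    schema = LanguageModelSourceConfig.from_dict(
        {"text_column": "prompt", "chosen_spans_column": "chosen", "rejected_spans_column": "rejected"}
    )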
@@ -133,7 +155,6 @@ def _validate(self) -> None: @config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): preparator_name: typing.ClassVar[str] = "gpt_memmap" - output_path: pathlib.Path = Field( default=None, desc="Output directory for the processed dataset.", diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 274bbf1b0..18ab2d787 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,5 +1,6 @@ import json import logging +import math import multiprocessing import pathlib import shutil @@ -18,13 +19,15 @@ BlendedDatasetConfig, DatasetSliceConfig, IndexedDatasetConfig, + MemmapDatasetConfig, SampledDatasetConfig, ) -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -35,154 +38,24 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): _tokenizer: Tokenizer _data_type: DataType - _text_column: str - _loss_masking_spans_column: str | None _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample + _config: GPTMemmapDatasetPreparatorConfig - def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans, dtype=np.int32).reshape(-1, 2), - ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) - ] - ] - ), - ) - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "token_spans": token_spans, - "num_tokens": num_tokens, - } - - def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - packed_texts = [] - chosen_spans = [] - rejected_spans = [] - - for conv_history, chosen_text, rejected_text in zip( - batch[self._config.dataset.field], - batch[self._config.dataset.chosen_text], - batch[self._config.dataset.rejected_text], - ): - # compute chosen span - full_chosen_text = conv_history + 
chosen_text + self._tokenizer.tokenizer.eos_token - chosen_span = [len(conv_history), len(full_chosen_text) - 1] - offset = len(full_chosen_text) - chosen_spans.append(chosen_span) + def __init__(self, config: ConfigType): + super().__init__(config) + self._source_shema: LanguageModelSourceConfig = self._config.dataset.source_shema - # compute rejected span - full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text - rejected_span = [ - offset + len(self._tokenizer.tokenizer.bos_token + conv_history), - offset + len(full_rejected_text) - 1, - ] - rejected_spans.append(rejected_span) + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> MemmapDatasetConfig: + shard_index, shard_dataset = args + file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" - # pack texts - packed_text = full_chosen_text + full_rejected_text - - assert ( - packed_text[chosen_span[0] : chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token - ), f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" - - assert ( - packed_text[rejected_span[0] : rejected_span[1] + 1] == rejected_text - ), f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" - packed_texts.append(packed_text) - - # tokenize with spans - input_ids, chosen_token_spans, rejected_token_spans = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans[0], dtype=np.int32), - np.array( - [token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32 - ), # adding 1 to end for eos token - ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) - for text, chosen_span, rejected_span in zip(packed_texts, chosen_spans, rejected_spans) - ] - ] - ), + MemmapDataset.write_dataset( + self._config.output_path / file_name, + tqdm.tqdm((sample["sample"] for sample in shard_dataset), desc=f"Saving shard {shard_index}", unit="docs"), + LanguageModelWriter, ) - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "chosen_token_spans": chosen_token_spans, - "rejected_token_spans": rejected_token_spans, - "num_tokens": num_tokens, - } - - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: - shard_idx, shard_dataset = args - prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" - shard_output_path = self._config.output_path / prefix - - def _document_generator(): - # TODO: Yield `LanguageModelSample` - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), - None, - None, - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - None, - torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), - torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard 
{shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - None, - None, - None, - ) - - GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) - - return GPTMemmapDatasetConfig.from_dict( - { - "type": "memmap", - "path": prefix, - "num_documents": len(shard_dataset), # Use the length of the shard dataset directly - "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), - } - ) + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}) def _load_dataset(self) -> datasets.Dataset: dataset = datasets.load_dataset( @@ -270,7 +143,11 @@ def run(self) -> None: # Prepare output directory self._config.output_path.mkdir(parents=True, exist_ok=True) - if pathlib.Path(self._config.dataset.path).is_dir(): + downloaded = pathlib.Path(self._config.dataset.path).is_dir() + if self._config.distributed.world_size > 1: + torch.distributed.barrier() + + if downloaded: # Dataset is already downloaded, load from disk dataset = self._load_dataset() else: @@ -296,54 +173,24 @@ def run(self) -> None: index=self._config.distributed.rank, ) - # Set data column and loss masking spans column based on source schema - if isinstance(self._config.dataset.source_schema, TextColumnConfig): - self._text_column = self._config.dataset.source_schema.input_column - self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column - else: - raise ValueError( - f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." - ) - - if self._text_column not in dataset.column_names: - raise ValueError(f"Dataset does not have field '{self._text_column}'.") - - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( - self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None - ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") - if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): - raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - - # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans - elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: - if self._config.dataset.chosen_text not in dataset.column_names: - raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") - if self._config.dataset.rejected_text not in dataset.column_names: - raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") - tokenize_fn = self._tokenize_preference_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + for column_name in self._source_shema.columns: + if column_name not in dataset.column_names: + raise ValueError(f"Dataset does not have field '{column_name}'.") # Tokenize the dataset in parallel - tokenized_dataset = dataset.map( - tokenize_fn, + prepared_dataset = dataset.map( + self._prepare_batch, batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", ) - # Calculate total number of tokens - total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting 
tokens", unit="tokens")) - # Split dataset into shards based on number of tokens - num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) + num_shards = math.ceil( + sum(len(sample) for sample in prepared_dataset["samples"]) / self._config.tokens_per_shard + ) shards = [ - (i, tokenized_dataset.shard(num_shards=num_shards, index=i)) + (i, prepared_dataset.shard(num_shards=num_shards, index=i)) for i in tqdm.tqdm(range(num_shards), desc="Creating shards") ] @@ -353,7 +200,67 @@ def run(self) -> None: self.generate_config_yaml_for_sharded_dst(dataset_configs) - def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDatasetConfig]) -> None: + def _prepare_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[LanguageModelSample]]: + # Gather values by sample using zip* + sample_column_values = zip(*(batch[column_name] for column_name in self._source_shema.columns)) + # Convert to dicts using column names. + sample_dicts = ( + {column_name: column_value for column_name, column_value in zip(self._source_shema.columns, sample_data)} + for sample_data in sample_column_values + ) + # Prepare each sample, wrap in dict for the `Dataset` interface + return {"samples": [self._prepare_sample(sample_dict) for sample_dict in sample_dicts]} + + def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: + text = sample[self._source_shema.text_column] + all_spans = [] + if self._source_shema.has_loss_masking_span: + # TODO: ====== What is the input format? ====== + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (begin, last + 1) + for begin, last in np.array(sample[self._source_shema.loss_masking_spans_column], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) + + if self._source_shema.has_preference_spans: + # TODO: ===== Was `self._config.dataset.field` (bug?) 
====== + full_chosen_text = ( + text + sample[self._source_shema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token + ) + full_rejected_text = ( + self._tokenizer.tokenizer.bos_token + text + sample[self._source_shema.rejected_spans_column] + ) + # compute chosen span + chosen_spans = [[len(text), len(full_chosen_text)]] + + # compute rejected span + rejected_span = [ + [ + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), + ] + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + tokens = torch.tensor( + self._tokenizer.tokenize_with_spans(text, True, True, spans=_sort_spans(all_spans)), + dtype=self._data_type.torch, + ) + sample_size = len(tokens) + + return LanguageModelSample( + TokenSample(tokens, [sample_size]), + RangeSample(loss_masking_spans, sample_size) if self._source_shema.has_loss_masking_span else None, + RangeSample(chosen_spans, sample_size) if self._source_shema.has_preference_spans else None, + RangeSample(rejected_span, sample_size) if self._source_shema.has_preference_spans else None, + ) + + def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[MemmapDatasetConfig]) -> None: # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: @@ -397,7 +304,7 @@ def _save_dataset_config( @classmethod def _blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]] + cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]] ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] @@ -412,10 +319,11 @@ def _blend_dataset_configs( @classmethod def _split_and_blend_dataset_configs( cls, - dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]], + dataset_configs: list[MemmapDatasetConfig[_sample_type]], splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: + # TODO: ====== Missing `num_tokens`, `num_documents`. ====== split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) @@ -483,6 +391,10 @@ def _split_and_blend_dataset_configs( return dataset_splits +def _sort_spans(spans: typing.Iterable[tuple[int, int]]) -> list[tuple[int, int]]: + return sorted(spans, key=lambda span: span[0]) + + def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: left = cumsum.searchsorted(value, side="right") if left == len(cumsum): diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 031002101..f122100f9 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -1,6 +1,11 @@ import abc +import io +import pathlib import typing +from fast_llm.config import Config, Configurable, Field, config_class +from fast_llm.utils import Assert + if typing.TYPE_CHECKING: import torch @@ -40,3 +45,150 @@ def crop(self, begin: int, end: int) -> typing.Self: def to_device_(self, device: "torch.device | str"): pass + + +@config_class(registry=True) +class MemmapReaderBaseConfig(Config): + """ + Configuration for a memmap reader or reader-like object. + Note: `MemmapDataset` requires a `MemmapIndexedDatasetReader`. 
+    Other readers need to be nested within a `MemmapIndexedDatasetReader`.
+    Note: Reader configs are not typical configs, and do not need to be located in a separate `config.py` file.
+    """
+
+    _abstract = True
+
+    @classmethod
+    def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+        if cls is MemmapReaderBaseConfig and cls.get_subclass(default.get("type")) is None:
+            # Default subclass, necessary for loading configs where some components could be absent.
+            return NullReaderConfig._from_dict(default, strict)
+        return super()._from_dict(default, strict=strict)
+
+    def get_reader(self, buffer: memoryview) -> "MemmapReader|None":
+        raise NotImplementedError()
+
+    @property
+    def expected_buffer_size(self) -> int:
+        """
+        The expected buffer size in bytes. Used for self-validation.
+        """
+        raise NotImplementedError()
+
+
+@config_class(dynamic_type={MemmapReaderBaseConfig: "none"})
+class NullReaderConfig(MemmapReaderBaseConfig):
+    """
+    Configuration for a dynamically disabled reader.
+    """
+
+    _abstract = False
+
+    def get_reader(self, buffer: memoryview) -> None:
+        return None
+
+    @property
+    def expected_buffer_size(self) -> int:
+        return 0
+
+
+@config_class(registry=True)
+class MemmapReaderConfig(MemmapReaderBaseConfig):
+    """
+    Configuration for a standard memmap reader.
+    """
+
+    begin: int = Field()
+    end: int = Field()
+
+    @property
+    def reader_class(self) -> "type[MemmapReader]":
+        raise NotImplementedError()
+
+    def get_reader(self, buffer: memoryview) -> "MemmapReader":
+        return self.reader_class(self, buffer[self.begin : self.end])
+
+    @property
+    def writer_class(self) -> "type[MemmapWriter]":
+        raise NotImplementedError()
+
+    def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter":
+        return self.writer_class(stream)
+
+    def _validate(self):
+        super()._validate()
+        Assert.eq(self.end - self.begin, self.expected_buffer_size)
+
+
+@config_class()
+class MemmapIndexDatasetReaderConfig(MemmapReaderConfig):
+    """
+    Configuration for a standard memmap reader matching the indexed dataset interface, i.e.,
+    consisting of a list of documents of known lengths.
+ """ + + @property + def reader_class(self) -> "type[MemmapIndexedDatasetReader]": + raise NotImplementedError() + + def get_reader( + self, + buffer: memoryview, + ) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer[self.begin : self.end]) + + +class MemmapReader[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config) + self._buffer = buffer[self._config.begin : self._config.end] + + @abc.abstractmethod + def get_document(self, index: int, begin: int, end: int) -> Sample: + pass + + +class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + @abc.abstractmethod + def get_document_sizes(self) -> "torch.Tensor": + pass + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + pass + + +class MemmapWriter: + def __init__(self, stream: io.BufferedWriter | pathlib.Path): + self._owns_stream = isinstance(stream, pathlib.Path) + if self._owns_stream: + stream = stream.open("wb") + self._stream = stream + + def __enter__(self): + self._begin = self._stream.tell() + return self + + def write(self, document: Sample): + assert hasattr(self, "_begin") and not hasattr(self, "_end") + + def __exit__(self, exc_type, exc_val, exc_tb): + self._end = self._stream.tell() + if self._owns_stream: + self._stream.close() + + def get_config(self, offset: int = 0) -> MemmapReaderConfig: + assert hasattr(self, "_end") + return self._get_config(self._begin + offset, self._end + offset) + + @abc.abstractmethod + def _get_config(self, begin: int, end: int): + pass + + @classmethod + def write_dataset(cls, stream: io.BufferedWriter, documents: typing.Iterable[Sample]) -> MemmapReaderConfig: + with cls(stream) as writer: + for document in documents: + writer.write(document) + return writer.get_config() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index f30188553..3d6964b30 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -1,8 +1,23 @@ +import io +import pathlib +import tempfile import typing -from fast_llm.data.sample.abstract import Batch, Sample -from fast_llm.data.sample.range import RangeBatch, RangeSample -from fast_llm.data.sample.token import TokenBatch, TokenSample +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapIndexDatasetReaderConfig, + MemmapIndexedDatasetReader, + MemmapReaderBaseConfig, + MemmapWriter, + NullReaderConfig, + Sample, +) +from fast_llm.data.sample.range import RangeBatch, RangeSample, RangeWriter +from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter +from fast_llm.utils import Assert class LanguageModelSample(Sample): @@ -105,3 +120,168 @@ def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.I def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: return None if sample_or_batch is None else sample_or_batch.crop(begin, end) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) +class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): + _abstract = False + tokens: TokenReaderConfig = Field() + # Using dynamic type for optional readers for enabling/disabling + loss_masking_spans: MemmapReaderBaseConfig = Field() + chosen_spans: MemmapReaderBaseConfig = Field() + rejected_spans: MemmapReaderBaseConfig = 
Field()
+
+    @property
+    def reader_class(self) -> "type[LanguageModelReader]":
+        return LanguageModelReader
+
+    @property
+    def writer_class(self) -> "type[LanguageModelWriter]":
+        return LanguageModelWriter
+
+    @property
+    def expected_buffer_size(self) -> int:
+        return (
+            self.tokens.expected_buffer_size
+            + self.loss_masking_spans.expected_buffer_size
+            + self.chosen_spans.expected_buffer_size
+            + self.rejected_spans.expected_buffer_size
+        )
+
+
+class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
+    def __init__(self, config: ConfigType, buffer: memoryview):
+        super().__init__(config, buffer)
+        # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global.
+        self._tokens = self._config.tokens.get_reader(buffer)
+        # Optional readers may be disabled (`NullReaderConfig`), in which case `get_reader` returns `None`.
+        self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer)
+        self._chosen_spans = self._config.chosen_spans.get_reader(buffer)
+        self._rejected_spans = self._config.rejected_spans.get_reader(buffer)
+
+    def get_document(self, index: int, begin: int, end: int) -> Sample:
+        return LanguageModelSample(
+            self._tokens.get_document(index, begin, end),
+            None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end),
+            None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end),
+            None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end),
+        )
+
+    def get_document_sizes(self) -> torch.Tensor:
+        return self._tokens.get_document_sizes()
+
+    def get_document_size(self, index: int) -> int:
+        return self._tokens.get_document_size(index)
+
+
+class LanguageModelWriter(MemmapWriter):
+    _has_loss_masking_spans: bool | None = None
+    _has_preference_spans: bool | None = None
+
+    def __enter__(self):
+        super().__enter__()
+        self._size_cumsum = [0]
+        self._data_type = None
+
+        self._directory = tempfile.TemporaryDirectory()
+        self._path = pathlib.Path(self._directory.name)
+        # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times.
+        self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__()
+        self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__()
+        self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__()
+        self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__()
+        return self
+
+    def write(self, document: LanguageModelSample):
+        # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ======
+        super().write(document)
+        # Write tokens.
+        self._token_writer.write(document.tokens)
+
+        # Ensure either all samples have loss masking spans or none of them do.
+        if self._has_loss_masking_spans is None:
+            self._has_loss_masking_spans = document.loss_masking_spans is not None
+        else:
+            Assert.eq(self._has_loss_masking_spans, document.loss_masking_spans is not None)
+
+        # Write loss masking spans.
+        if self._has_loss_masking_spans:
+            self._loss_masking_span_writer.write(document.loss_masking_spans)
+
+        # All samples must either have both chosen and rejected spans, or neither.
+        if self._has_preference_spans is None:
+            self._has_preference_spans = document.chosen_spans is not None
+        else:
+            Assert.eq(self._has_preference_spans, document.chosen_spans is not None)
+            Assert.eq(self._has_preference_spans, document.rejected_spans is not None)
+
+        # Write preference spans.
+ if self._has_preference_spans: + self._chosen_spans_writer.write(document.chosen_spans) + self._rejected_spans_writer.write(document.rejected_spans) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._token_writer.__exit__(exc_type, exc_val, exc_tb) + self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) + self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) + self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) + + # A dummy config so we can verify the begin and end offsets. + config = self._get_config(self._begin, None) + _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) + + if self._has_loss_masking_spans: + _copy_chunked( + self._path.joinpath("loss_masking_spans"), + self._stream, + config.loss_masking_spans.begin, + config.loss_masking_spans.end, + ) + if self._has_preference_spans: + _copy_chunked( + self._path.joinpath("chosen_spans"), self._stream, config.chosen_spans.begin, config.chosen_spans.end + ) + _copy_chunked( + self._path.joinpath("rejected_spans"), + self._stream, + config.rejected_spans.begin, + config.rejected_spans.end, + ) + + self._directory.cleanup() + super().__exit__(exc_type, exc_val, exc_tb) + + def _get_config(self, begin: int, end: int | None): + tokens = self._token_writer.get_config(begin) + offset = tokens.end + if self._has_loss_masking_spans: + loss_masking_spans = self._loss_masking_span_writer.get_config(offset) + offset = loss_masking_spans.end + else: + loss_masking_spans = NullReaderConfig() + if self._has_preference_spans: + chosen_spans = self._chosen_spans_writer.get_config(offset) + offset = chosen_spans.end + rejected_spans = self._rejected_spans_writer.get_config(offset) + offset = rejected_spans.end + else: + chosen_spans = NullReaderConfig() + rejected_spans = NullReaderConfig() + + if end is None: + end = offset + + return LanguageModelReaderConfig( + begin=begin, + end=end, + tokens=tokens, + loss_masking_spans=loss_masking_spans, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + ) + + +def _copy_chunked(path: pathlib.Path, stream: io.BufferedWriter, expected_begin: int, expected_end: int): + # Copy temporary file content in chunks of 100 MB. 
+ Assert.eq(stream.tell(), expected_begin) + with path.open("rb") as input_stream: + while data := input_stream.read(100000000): + stream.write(data) + Assert.eq(stream.tell(), expected_end) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index d121a38b6..88dd1352d 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -1,6 +1,17 @@ import typing -from fast_llm.data.sample.abstract import Batch, Sample +import numpy as np +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapReader, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) from fast_llm.utils import get_unique @@ -47,3 +58,71 @@ def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: def to_samples(self) -> list[RangeSample]: return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges] + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) +class RangeReaderConfig(MemmapReaderConfig): + _abstract = False + num_documents: int = Field() + num_ranges: int = Field() + + @property + def reader_class(self) -> "type[RangeReader]": + return RangeReader + + @property + def writer_class(self) -> "type[RangeWriter]": + return RangeWriter + + @property + def expected_buffer_size(self) -> int: + return (self.num_ranges + 1) * torch.uint32.itemsize * 2 + (self.num_documents + 1) * torch.uint32.itemsize + + +class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._ranges = torch.frombuffer( + self._buffer, + dtype=torch.uint32, + count=self._config.num_ranges, + ).reshape(-1, 2) + self._count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.uint32, + count=self._config.num_documents + 1, + offset=self._ranges.nbytes, + ) + + def get(self, index: int, begin: int, end: int) -> RangeSample: + sample_size = end - begin + cropped_ranges = ( + (max(begin_ - begin, 0), min(end_ - begin, sample_size)) + for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist() + ) + return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) + + +class RangeWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._count_cumsum = [0] + return self + + def write(self, document: RangeSample): + # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== + super().write(document) + self._stream.write(np.array(document.ranges, dtype=np.uint32).tobytes(order="C")) + self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._stream.write(np.array(self._count_cumsum, dtype=np.uint32).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + def _get_config(self, begin: int, end: int): + return RangeReaderConfig( + begin=begin, + end=end, + num_documents=len(self._count_cumsum) - 1, + num_ranges=self._count_cumsum[-1], + ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 62d1c0e67..98ee9a2a1 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -1,8 +1,18 @@ import typing +import numpy as np import torch -from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, 
+ MemmapIndexedDatasetReader, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -73,3 +83,77 @@ def crop(self, begin: int, end: int) -> typing.Self: def to_device_(self, device: "torch.device | str"): # Also standardize the dtype while we're here. self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) +class TokenReaderConfig(MemmapReaderConfig): + _abstract = False + num_documents: int = Field() + num_tokens: int = Field() + data_type: DataType = Field() + + @property + def reader_class(self) -> "type[TokenReader]": + return TokenReader + + @property + def writer_class(self) -> "type[TokenWriter]": + return TokenWriter + + @property + def expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.uint64.itemsize + + +class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._tokens = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens, + ) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.uint64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + return TokenSample(torch.from_numpy(self._tokens[begin_ + begin : begin_ + end]), [end - begin]) + + def get_document_sizes(self) -> torch.Tensor: + return self._size_cumsums[1:] - self._size_cumsums[:-1] + + def get_document_size(self, index: int) -> int: + return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item() + + +class TokenWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + return self + + def write(self, document: TokenSample): + # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== + super().write(document) + if self._data_type is None: + self._data_type = document.tokens.dtype + else: + Assert.eq(self._data_type, document.tokens.dtype) + self._stream.write(document.tokens.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._stream.write(np.array(self._size_cumsum, dtype=np.uint64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + def _get_config(self, begin: int, end: int): + return TokenReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + data_type=DataType.from_torch(self._data_type), + ) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c74586207..71219a2bf 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -4,6 +4,7 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.utils import Assert class Tokenizer: @@ -41,7 +42,7 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def tokenize(self, text: str, begin: bool = True, end: bool = True) -> list[int]: return ( ([self.bod_id] 
if begin else [])
             + self.tokenizer.encode(text, add_special_tokens=False)
@@ -49,36 +50,35 @@
         )

     def tokenize_with_spans(
-        self, text: str, char_spans: list[tuple[int, int]]
+        self, text: str, begin: bool = True, end: bool = True, *, spans: list[tuple[int, int]]
     ) -> tuple[list[int], list[tuple[int, int]]]:
         """
         Perform span-aware tokenization and return the tokenized input_ids along with token spans.
         """
+        if not spans:
+            return self.tokenize(text, begin, end), []
+        input_ids, token_splits = self.tokenize_with_splits(
+            text, begin, end, text_splits=[split for splits in spans for split in splits]
+        )
+        return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)]
+
+    def tokenize_with_splits(
+        self, text: str, begin: bool = True, end: bool = True, *, text_splits: list[int]
+    ) -> tuple[list[int], list[int]]:
+        Assert.eq(sorted(text_splits), text_splits)
         input_ids = []
-        token_spans = []
-        char_pos = 0
-        beginning_of_text = True
+        text_splits = [0, *text_splits, len(text)]
+        token_splits = []
+
+        for split_begin, split_end in zip(text_splits[:-1], text_splits[1:]):
+            input_ids.extend(
+                self.tokenize(
+                    text[split_begin:split_end], begin=begin and split_begin == 0, end=end and split_end == len(text)
+                )
+            )
+            token_splits.append(len(input_ids))

-        for start, end in char_spans:
-            if char_pos < start:
-                curr_text = text[char_pos:start]
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False)
-                beginning_of_text = False
-                input_ids.extend(tokenized_text)
-            curr_text = text[start : end + 1]
-            if end >= len(text) - 1:
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True)
-            else:
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False)
-                beginning_of_text = False
-            token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1))
-            input_ids.extend(tokenized_text)
-            char_pos = end + 1
-        if char_pos < len(text):
-            curr_text = text[char_pos:]
-            tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True)
-            input_ids.extend(tokenized_text)
-        return input_ids, token_spans
+        return input_ids, token_splits[:-1]

     def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str:
         return self.tokenizer.decode(token_ids)
diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py
index 0099cb50b..49eceee0b 100644
--- a/tests/data/test_blending.py
+++ b/tests/data/test_blending.py
@@ -13,7 +13,7 @@
     get_test_data_and_compare_samples,
 )
 from tests.utils.dataset import get_test_dataset
-from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX
+from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH


 _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset"
@@ -118,7 +118,7 @@ def test_gpt_blended():
         {
             "type": "blended",
             "datasets": [
-                {"type": "memmap", "path": DATASET_PREFIX},
+                {"type": "memmap", "path": DATASET_PATH},
                 {"type": "memmap", "path": _DATASET_PREFIX_MIX_1},
             ],
             "weights": [0.75, 0.25],
@@ -137,7 +137,7 @@ def test_gpt_blended_data():
             "training": {
                 "type": "blended",
                 "datasets": [
-                    {"type": "memmap", "path": DATASET_PREFIX},
+                    {"type": "memmap", "path": DATASET_PATH},
                     {"type": "memmap", "path": _DATASET_PREFIX_MIX_1},
                 ],
                 "weights": [0.75, 0.25],
@@ -157,7 +157,7 @@ def test_gpt_blended_mixed():
         {
             "type": "blended",
             "datasets": [
-                {"type": "memmap", "path": DATASET_PREFIX},
{"type": "random"}, ], "weights": [0.6, 0.4], @@ -174,7 +174,7 @@ def test_gpt_blended_mixed_data(): "datasets": { "training": { "type": "blended", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}], + "datasets": [{"type": "memmap", "path": DATASET_PATH}, {"type": "random"}], "weights": [0.6, 0.4], } } diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 5335e01c0..7b009bbf6 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -9,7 +9,7 @@ ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], @@ -27,7 +27,7 @@ def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. get_test_dataset() dataset = get_dataset_config( - {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, + {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], ).build() compare_indexed_dataset( @@ -47,7 +47,7 @@ def test_gpt_concatenate_data(): "datasets": { "training": { "type": "concatenated", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)], + "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)], } } }, diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index c149e1395..af91df1e2 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -2,11 +2,11 @@ from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH def test_dataset_from_file(): get_test_dataset() - dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} + dataset_config = {"type": "file", "path": str(DATASET_PATH.parent.joinpath("fast_llm_config.yaml"))} dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 438c5e7e3..b9dc7fe32 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -6,7 +6,7 @@ get_test_data_and_compare_samples, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH +from tests.utils.global_variables import DATASET_PATH, TOKENIZER_PATH GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], @@ -32,7 +32,7 @@ def test_gpt_fim(): sampled = get_dataset_config( { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", @@ -52,7 +52,7 @@ def test_gpt_fim_data(): "datasets": { "training": { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "tokenizer": {"path": 
TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index ca887f3c1..419b67903 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -2,10 +2,10 @@ import pytest -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig +from fast_llm.data.dataset.config import MemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE +from tests.utils.global_variables import DATASET_PATH, DATASET_SAMPLING_CACHE, DATASET_WITH_SPANS_PATH MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 @@ -21,7 +21,7 @@ def test_gpt_memmap(cache_directory): # Make sure the memmap dataset works and check for unintended changes in behavior. get_test_dataset() - dataset = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() + dataset = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) @@ -32,17 +32,15 @@ def test_gpt_memmap(cache_directory): 15: [], } -_DATASET_PREFIX_SPANS = DATASET_CACHE / "with_spans" / "dataset" - def test_gpt_data_with_spans(): - get_test_dataset(prefix=_DATASET_PREFIX_SPANS, max_spans=5) + get_test_dataset(DATASET_WITH_SPANS_PATH, max_spans=5) dataset = get_dataset_config( { "type": "memmap", - "path": _DATASET_PREFIX_SPANS, + "path": DATASET_WITH_SPANS_PATH, }, - GPTMemmapDatasetConfig, + MemmapDatasetConfig, ).build() compare_indexed_dataset( dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 601abcf99..1608bb48c 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -8,10 +8,12 @@ from fast_llm.data.dataset.config import IndexedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -31,15 +33,17 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): documents = [ - (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) + LanguageModelSample( + TokenSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) + ) for _ in range(100) ] with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, (tokens, _, _, 
_) in enumerate(documents): - Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) + path = pathlib.Path(temp_dir) / "dataset" + MemmapDataset.write_dataset(path, documents, LanguageModelWriter) + dataset = MemmapDataset("dataset", path) + for i, document in enumerate(documents): + Assert.all_equal(dataset.get_document(i).tokens.tokens, document.tokens.tokens.to(torch.int64)) def _generate_valid_span(max_seq_length): @@ -49,26 +53,26 @@ def _generate_valid_span(max_seq_length): @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_preference_dataset(dtype): documents = [ - ( - torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), + LanguageModelSample( + TokenSample(torch.from_numpy(np.random.randint(1000, size=100).astype(dtype))), None, - _generate_valid_span(100), - _generate_valid_span(100), + RangeSample(_generate_valid_span(100), 100), + RangeSample(_generate_valid_span(100), 100), ) for _ in range(50) ] with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) + path = pathlib.Path(temp_dir) / "dataset" + MemmapDataset.write_dataset(path, documents, LanguageModelWriter) + dataset = MemmapDataset("dataset", path) parameters = GPTSamplingParameters( num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True ) - for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): - document = dataset.get_document(i, parameters=parameters) - Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) - Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) - Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) + for i, document in enumerate(documents): + dataset_document = dataset.get_document(i, parameters=parameters) + Assert.all_equal(dataset_document.tokens.tokens, document.tokens.tokens.to(torch.int64)) + Assert.eq(dataset_document.chosen_spans.ranges, document.chosen_spans.ranges) + Assert.eq(dataset_document.rejected_spans.ranges, document.rejected_spans.ranges) def test_load_metadata_from_hub(): diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 58f4d3dab..c171d15dd 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,8 +2,8 @@ import pytest import torch -from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import MemmapDatasetConfig, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -15,7 +15,7 @@ validate_indexed_dataset_sampling, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -40,7 +40,7 @@ def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. 
get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build_and_sample( + sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build_and_sample( get_sampling_data(8, sequence_length=5) ) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) @@ -53,7 +53,7 @@ def test_gpt_sampled_data(): "datasets": { "training": { "type": "memmap", - "path": DATASET_PREFIX, + "path": DATASET_PATH, } } }, diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 3c6ae10d4..3a6b999cd 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -9,7 +9,7 @@ ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], @@ -34,7 +34,7 @@ def test_gpt_slice(): get_test_dataset() # samples[9:18] dataset = get_dataset_config( - {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, + {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003}, DatasetSliceConfig[LanguageModelSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) @@ -48,19 +48,19 @@ def test_gpt_slice_data(): "datasets": { "training": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0, "end": 0.0015, }, "validation": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003, }, "test": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.003, "end": 1, }, diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 7447e395a..42a7c1f0d 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -6,16 +6,16 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import SampledDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.config import MemmapDatasetConfig, SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX +from tests.utils.global_variables import MODEL_DATASET_PATH from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -69,7 +69,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare="megatron", config_args=[ "model.distributed.compute_dtype=fp32", - 
f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PREFIX}}}', + f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PATH}}}', "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], @@ -82,25 +82,23 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co @config_class(dynamic_type={SampledDatasetConfig: "megatron"}) -class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): +class MegatronDatasetConfig[SampleType: LanguageModelSample](MemmapDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: str = Field( desc="Dataset path (prefix).", hint=FieldHint.core, ) - def build(self) -> "GPTMemmapDataset": - return GPTMegatronMemmapDataset( - str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens - ) + def build(self) -> "LegacyMemmapDataset[SampleType]": + return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path) -class GPTMegatronMemmapDataset(GPTMemmapDataset): - def sample(self, sampling: GPTSamplingData) -> "MegatronGPTSampledIndexedDataset": - return MegatronGPTSampledIndexedDataset(self, sampling) +class MegatronMemmapDataset(LegacyMemmapDataset): + def sample(self, sampling: GPTSamplingData) -> "MegatronSampledIndexedDataset": + return MegatronSampledIndexedDataset(self, sampling) -class MegatronGPTSampledIndexedDataset(SampledDataset): +class MegatronSampledIndexedDataset(SampledDataset): """ A GPT sampled dataset that exactly matches Megatron-LM, for testing purposes. Minimalistic implementation, implements only the required features. @@ -108,7 +106,7 @@ class MegatronGPTSampledIndexedDataset(SampledDataset): def __init__( self, - indexed_dataset: GPTMegatronMemmapDataset, + indexed_dataset: MegatronMemmapDataset, sampling: GPTSamplingData, ): assert isinstance(sampling, GPTSamplingData) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 428dec56b..baff00b80 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -5,10 +5,13 @@ import torch import yaml -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from tests.utils.global_variables import ( - DATASET_PREFIX, - MODEL_DATASET_PREFIX, + DATASET_PATH, + MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE, TEST_CHARACTERS, TEST_DATASET_TOKENS, @@ -35,7 +38,7 @@ def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int def get_test_dataset( - prefix: pathlib.Path = DATASET_PREFIX, + path: pathlib.Path = DATASET_PATH, seed: int = 1234, num_tokens: int = TEST_DATASET_TOKENS, characters: str = TEST_CHARACTERS, @@ -43,48 +46,35 @@ def get_test_dataset( max_spans: int = 0, ): download_santacoder_tokenizer() + config_path = path.parent.joinpath("fast_llm_config.yaml") - if not ( - prefix.with_suffix(".idx").is_file() - and prefix.with_suffix(".bin").is_file() - and prefix.parent.joinpath("fast_llm_config.yaml").is_file() - ): + if not (path.is_file() and config_path.is_file()): import transformers texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - ( - torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), - None, - None, - None, + 
LanguageModelSample( + TokenSample( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) + ), ) for document in texts ] if max_spans > 0: spans = get_random_spans( - len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed + len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed ) - samples = [ - ( - tokens, - torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2), - None, - None, - ) - for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True) - ] + for sample, sample_spans in zip(samples, spans, strict=True): + sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) - GPTMemmapDataset.write_dataset(prefix, samples) - yaml.safe_dump( - {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") - ) + MemmapDataset.write_dataset(path, samples, LanguageModelWriter) + yaml.safe_dump({"type": "memmap", "path": path.name}, config_path.open("w")) def get_model_test_dataset( - prefix: pathlib.Path = MODEL_DATASET_PREFIX, + path: pathlib.Path = MODEL_DATASET_PATH, vocab_size: int = MODEL_TEST_VOCAB_SIZE, ): - return get_test_dataset(prefix=prefix, vocab_size=vocab_size) + return get_test_dataset(path, vocab_size=vocab_size) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 42e588911..c62903a6c 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -37,12 +37,13 @@ def set_testing_global_variables(): TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common_dataset" +DATASET_PATH = DATASET_CACHE / "common_dataset.fast_llm_dataset" +DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans.fast_llm_dataset" DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset.fast_llm_dataset" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7b..adcf84b18 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -22,7 +22,7 @@ Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -234,18 +234,18 @@ def _update_and_add_testing_config( "data": { "datasets": { "training": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "end": 0.969, }, "validation": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "begin": 0.969, "end": 0.999, }, "test": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "begin": 0.999, "end": 1, @@ -279,7 +279,7 @@ def 
_update_and_add_testing_config( "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", - f"--data-path={MODEL_DATASET_PREFIX}", + f"--data-path={MODEL_DATASET_PATH}", "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py deleted file mode 100644 index bbfa4b21a..000000000 --- a/tools/concatenate_dataset.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -import logging -import pathlib - -from fast_llm.config import Field, config_class -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.engine.config_utils.runnable import RunnableConfig - -logger = logging.getLogger(__name__) - - -@config_class() -class ConcatenateDatasetConfig(RunnableConfig): - directory: pathlib.Path = Field() - output_name: str = Field(default="fast_llm_dataset.json") - # A lower bound on the number of tokens in a dataset. - # Normally we would like each dataset split to contain at least a few samples, - # i.e. we want num_tokens >= sequence_length * min_split * min_samples_per_split. - # For example with a (999, 1, 0) split , 8K sequence length, we need at least 8M tokens - # for a single validation sample, possibly more if the split is imperfect. - min_tokens: int | None = Field(default=None) - - def run(self): - self.to_logs() - assert self.directory.is_dir() - output_file = self.directory / self.output_name - assert not output_file.exists(), str(output_file) - datasets = [] - - logger.info(f"Loading datasets from {self.directory}") - for path in self.directory.glob("**/*.idx"): - prefix = path.with_suffix("") - logger.info(str(prefix)) - dataset = GPTMemmapDataset("dataset", prefix) - dataset_dict = { - "prefix": str(prefix.relative_to(self.directory)), - "num_documents": len(dataset), - "num_tokens": dataset.num_tokens, - } - if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: - logger.info( - f"Ignoring dataset {dataset_dict['prefix']} with {dataset_dict['num_tokens']:,} tokens" - f" (requiring at least {self.min_tokens:,} tokens)" - ) - else: - datasets.append(dataset_dict) - total_documents = sum(dataset["num_documents"] for dataset in datasets) - total_tokens = sum(dataset["num_tokens"] for dataset in datasets) - logger.info(f"Found {total_documents:,} documents, {total_tokens:,} tokens in {len(datasets)} dataset files") - for dataset in datasets: - dataset["weight"] = dataset["num_tokens"] / total_tokens - logger.info( - f'{dataset["prefix"]}: documents = {dataset["num_documents"]:,}, tokens = {dataset["num_tokens"]:,}, weight = {dataset["weight"]:.6f}' - ) - logger.info(f"Saving merged dataset to {output_file}") - json.dump({"datasets": datasets}, output_file.open("w")) - - -if __name__ == "__main__": - ConcatenateDatasetConfig.parse_and_run() From acfd30ea476c0c11a5ff8233aaa71ea2e5814956 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Oct 2025 17:57:02 -0400 Subject: [PATCH 13/45] fixes --- fast_llm/data/dataset/config.py | 2 +- fast_llm/data/dataset/indexed.py | 19 ++++- fast_llm/data/dataset/memmap.py | 16 ++-- fast_llm/data/preparator/gpt_memmap/config.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 80 +++++++++++-------- fast_llm/data/sample/abstract.py | 58 ++++++++++++-- fast_llm/data/sample/language_model.py | 31 +++++-- fast_llm/data/sample/range.py | 23 
+++--- fast_llm/data/sample/token.py | 20 +++-- tests/data/common.py | 17 ++-- tests/data/test_blending.py | 8 +- tests/data/test_memmap.py | 4 +- tests/data/test_prepare_gpt_memmap.py | 15 ++-- 13 files changed, 207 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index f1bc3472a..f60decd81 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -10,11 +10,11 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.sample.abstract import Sample -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index c6eac9e28..5d6636f7f 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -34,11 +34,20 @@ def get_document( ) -> SampleType: pass - @abc.abstractmethod def __len__(self) -> int: """ - Number of samples in the dataset. + Number of documents in the dataset. + Note: this default implementation is slow and should be overridden when possible. + """ + return len(self.get_document_sizes()) + + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + Note: this default implementation is slow and should be overridden when possible. """ + return self.get_document_sizes().sum().item() def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.sampled import SampledIndexedDataset @@ -108,6 +117,12 @@ def __init__( def __len__(self) -> int: return self._dataset_splits[-1].item() + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + """ + return sum(len(dataset) for dataset in self._datasets) + def get_document_sizes(self) -> torch.Tensor: # TODO: This can be really big. return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index e2aeda077..ffb2bc6d1 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -57,6 +57,8 @@ def __del__(self): def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None ) -> SampleType: + if end is None: + end = self._reader.get_document_size(index) return self._reader.get_document(index, begin, end) @property @@ -64,12 +66,11 @@ def name(self) -> str: return self._name def __len__(self) -> int: - return self._reader + return len(self._reader) - # TODO: ====== needed? 
====== - # @property - # def num_tokens(self) -> int: - # return self._reader.num_tokens + @property + def num_tokens(self) -> int: + return self._reader.num_tokens def get_document_sizes(self) -> torch.Tensor: return self._reader.get_document_sizes() @@ -78,7 +79,9 @@ def get_document_size(self, index: int) -> int: return self._reader.get_document_size(index) @classmethod - def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter]): + def write_dataset( + cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter] + ) -> MemmapIndexDatasetReaderConfig: # TODO: Match `writer_class` with `SampleType`? path.parent.mkdir(parents=True, exist_ok=True) with path.open("wb") as stream: @@ -98,3 +101,4 @@ def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], w # Write a pointer to the reader config. stream.seek(start) stream.write(config_offset.to_bytes(4, signed=False)) + return reader_config diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index c193cf942..7dd520ec3 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -65,7 +65,7 @@ def _validate(self): super()._validate() if self.has_loss_masking_span != self.rejected_spans_column is not None: raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - if self.has_preference_spans == self.has_loss_masking_span: + if self.has_preference_spans and self.has_loss_masking_span: # TODO: ====== Still needed? ====== raise ValueError(f"Can not enable both loss masking and preference spans.") diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 18ab2d787..06a4bd517 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -25,6 +25,7 @@ from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -43,19 +44,21 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def __init__(self, config: ConfigType): super().__init__(config) - self._source_shema: LanguageModelSourceConfig = self._config.dataset.source_shema + self._source_schema: LanguageModelSourceConfig = self._config.dataset.source_schema - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> MemmapDatasetConfig: + def _save_shard( + self, args: tuple[int, datasets.Dataset] + ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: shard_index, shard_dataset = args file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" - MemmapDataset.write_dataset( + reader_config = MemmapDataset.write_dataset( self._config.output_path / file_name, tqdm.tqdm((sample["sample"] for sample in shard_dataset), desc=f"Saving shard {shard_index}", unit="docs"), LanguageModelWriter, ) - return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}) + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": 
file_name}), reader_config def _load_dataset(self) -> datasets.Dataset: dataset = datasets.load_dataset( @@ -173,7 +176,7 @@ def run(self) -> None: index=self._config.distributed.rank, ) - for column_name in self._source_shema.columns: + for column_name in self._source_schema.columns: if column_name not in dataset.column_names: raise ValueError(f"Dataset does not have field '{column_name}'.") @@ -196,42 +199,42 @@ def run(self) -> None: # Use multiprocessing to save each shard in parallel on all ranks with multiprocessing.Pool(processes=self._config.saving_workers) as pool: - dataset_configs = pool.map(self._save_shard, shards) + dataset_and_reader_configs = pool.map(self._save_shard, shards) - self.generate_config_yaml_for_sharded_dst(dataset_configs) + self.generate_config_yaml_for_sharded_dst(dataset_and_reader_configs) def _prepare_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[LanguageModelSample]]: # Gather values by sample using zip* - sample_column_values = zip(*(batch[column_name] for column_name in self._source_shema.columns)) + sample_column_values = zip(*(batch[column_name] for column_name in self._source_schema.columns)) # Convert to dicts using column names. sample_dicts = ( - {column_name: column_value for column_name, column_value in zip(self._source_shema.columns, sample_data)} + {column_name: column_value for column_name, column_value in zip(self._source_schema.columns, sample_data)} for sample_data in sample_column_values ) # Prepare each sample, wrap in dict for the `Dataset` interface return {"samples": [self._prepare_sample(sample_dict) for sample_dict in sample_dicts]} def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_shema.text_column] + text = sample[self._source_schema.text_column] all_spans = [] - if self._source_shema.has_loss_masking_span: + if self._source_schema.has_loss_masking_span: # TODO: ====== What is the input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (begin, last + 1) - for begin, last in np.array(sample[self._source_shema.loss_masking_spans_column], dtype=np.int32) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans_column], dtype=np.int32) .reshape(-1, 2) .tolist() ) all_spans.extend(loss_masking_spans) - if self._source_shema.has_preference_spans: + if self._source_schema.has_preference_spans: # TODO: ===== Was `self._config.dataset.field` (bug?) 
====== full_chosen_text = ( - text + sample[self._source_shema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token + text + sample[self._source_schema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token ) full_rejected_text = ( - self._tokenizer.tokenizer.bos_token + text + sample[self._source_shema.rejected_spans_column] + self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_spans_column] ) # compute chosen span chosen_spans = [[len(text), len(full_chosen_text)]] @@ -255,33 +258,37 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: return LanguageModelSample( TokenSample(tokens, [sample_size]), - RangeSample(loss_masking_spans, sample_size) if self._source_shema.has_loss_masking_span else None, - RangeSample(chosen_spans, sample_size) if self._source_shema.has_preference_spans else None, - RangeSample(rejected_span, sample_size) if self._source_shema.has_preference_spans else None, + RangeSample(loss_masking_spans, sample_size) if self._source_schema.has_loss_masking_span else None, + RangeSample(chosen_spans, sample_size) if self._source_schema.has_preference_spans else None, + RangeSample(rejected_span, sample_size) if self._source_schema.has_preference_spans else None, ) - def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[MemmapDatasetConfig]) -> None: + def generate_config_yaml_for_sharded_dst( + self, dataset_and_reader_configs: list[tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]] + ) -> None: # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: - all_dataset_configs = [None] * self._config.distributed.world_size - torch.distributed.gather_object(dataset_configs, all_dataset_configs, dst=0) - dataset_configs = [item for sublist in all_dataset_configs for item in sublist] + all_dataset_and_reader_configs = [None] * self._config.distributed.world_size + torch.distributed.gather_object(dataset_and_reader_configs, all_dataset_and_reader_configs, dst=0) + dataset_and_reader_configs = [item for sublist in all_dataset_and_reader_configs for item in sublist] else: - torch.distributed.gather_object(dataset_configs, [], dst=0) + torch.distributed.gather_object(dataset_and_reader_configs, [], dst=0) if self._config.distributed.rank == 0: # Create the config file(s) on rank 0 + dataset_configs, reader_configs = zip(*dataset_and_reader_configs) if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, reader_configs, self._config.splits, self._config.output_path ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" ) else: self._save_dataset_config( - self._blend_dataset_configs(dataset_configs), self._config.output_path / f"fast_llm_config.yaml" + self._blend_dataset_configs(dataset_configs, reader_configs), + self._config.output_path / f"fast_llm_config.yaml", ) # Save metadata on rank 0 @@ -304,7 +311,9 @@ def _save_dataset_config( @classmethod def _blend_dataset_configs( - cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]] + cls, + dataset_configs: list[MemmapDatasetConfig[_sample_type]], + reader_configs: list[MemmapIndexDatasetReaderConfig], ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] @@ -312,7 +321,7 @@ def _blend_dataset_configs( { "type": "blended", "datasets": 
dataset_configs, - "weights": [dataset_config.num_tokens for dataset_config in dataset_configs], + "weights": [reader_config.num_tokens for reader_config in reader_configs], } ) @@ -320,12 +329,13 @@ def _blend_dataset_configs( def _split_and_blend_dataset_configs( cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]], + reader_configs: list[MemmapIndexDatasetReaderConfig], splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: # TODO: ====== Missing `num_tokens`, `num_documents`. ====== split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() - dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] + dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) dataset_cumsums = padded_cumsum(dataset_probabilities).tolist() dataset_splits = {} @@ -333,7 +343,9 @@ def _split_and_blend_dataset_configs( for split_index, split_name in enumerate(splits): datasets_in_split = [] dataset_tokens_in_split = [] - for dataset_index, dataset_config in enumerate(dataset_configs): + for dataset_index, (dataset_config, reader_config) in enumerate( + zip(dataset_configs, reader_configs, strict=True) + ): split_begin_in_dataset = max( (split_cumsum[split_index] - dataset_cumsums[dataset_index]) / dataset_probabilities[dataset_index], @@ -353,17 +365,17 @@ def _split_and_blend_dataset_configs( # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + Assert.eq(sizes_cumsum[-1], reader_config.num_tokens) + begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens) + end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens) if end_index > begin_index: datasets_in_split.append( DatasetSliceConfig[cls._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], - "begin": begin_index / dataset_config.num_documents, - "end": end_index / dataset_config.num_documents, + "begin": begin_index / len(reader_config), + "end": end_index / len(reader_config), } ) ) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index f122100f9..9afc6124c 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -71,7 +71,7 @@ def get_reader(self, buffer: memoryview) -> "MemmapReader|None": @property def expected_buffer_size(self) -> int: """ - The expected buffer size in bytes. Used for self-validation. + The expected buffer size in bytes, including header and footer. Used for self-validation. """ raise NotImplementedError() @@ -98,15 +98,33 @@ class MemmapReaderConfig(MemmapReaderBaseConfig): Configuration for a standard memmap reader. """ + # Data location in the file. begin: int = Field() end: int = Field() + # Constant strings for alignment safety. 
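+    # Each reader subclass defines its own `header`/`footer` byte strings; the reader
+    # compares them against the mapped buffer before slicing it, so a wrong offset or
+    # size fails fast instead of silently mis-reading data.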
+ header: typing.ClassVar[bytes] + footer: typing.ClassVar[bytes] @property def reader_class(self) -> "type[MemmapReader]": raise NotImplementedError() def get_reader(self, buffer: memoryview) -> "MemmapReader": - return self.reader_class(self, buffer[self.begin : self.end]) + return self.reader_class(self, buffer) + + @property + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, including header and footer. Used for self-validation. + """ + return self._expected_buffer_size + len(self.header) + len(self.footer) + + @property + def _expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, excluding header and footer. Used for self-validation. + """ + raise NotImplementedError() @property def writer_class(self) -> "type[MemmapWriter]": @@ -117,7 +135,6 @@ def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter": def _validate(self): super()._validate() - print("AAAAA", self.__class__.__name__, self.begin, self.end, self.expected_buffer_size) Assert.eq(self.end - self.begin, self.expected_buffer_size) @@ -128,6 +145,15 @@ class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): consisting of a list of documents of known lengths. """ + @abc.abstractmethod + def __len__(self) -> int: + pass + + @property + @abc.abstractmethod + def num_tokens(self) -> int: + pass + @property def reader_class(self) -> "type[MemmapIndexedDatasetReader]": raise NotImplementedError() @@ -136,13 +162,17 @@ def get_reader( self, buffer: memoryview, ) -> "MemmapIndexedDatasetReader": - return self.reader_class(self, buffer[self.begin : self.end]) + return self.reader_class(self, buffer) -class MemmapReader[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): +class MemmapReader[ConfigType: MemmapReaderConfig](Configurable[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview): super().__init__(config) - self._buffer = buffer[self._config.begin : self._config.end] + buffer_begin = self._config.begin + len(self._config.header) + buffer_end = self._config.end - len(self._config.footer) + Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) + Assert.eq(buffer[buffer_end : self._config.end].tobytes(), self._config.footer) + self._buffer = buffer[buffer_begin:buffer_end] @abc.abstractmethod def get_document(self, index: int, begin: int, end: int) -> Sample: @@ -150,6 +180,13 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + def __len__(self) -> int: + return len(self._config) + + @property + def num_tokens(self) -> int: + return self._config.num_tokens + @abc.abstractmethod def get_document_sizes(self) -> "torch.Tensor": pass @@ -159,7 +196,7 @@ def get_document_size(self, index: int) -> int: pass -class MemmapWriter: +class MemmapWriter(abc.ABC): def __init__(self, stream: io.BufferedWriter | pathlib.Path): self._owns_stream = isinstance(stream, pathlib.Path) if self._owns_stream: @@ -168,16 +205,23 @@ def __init__(self, stream: io.BufferedWriter | pathlib.Path): def __enter__(self): self._begin = self._stream.tell() + self._stream.write(self._get_config_class().header) return self def write(self, document: Sample): assert hasattr(self, "_begin") and not hasattr(self, "_end") def __exit__(self, exc_type, exc_val, exc_tb): + self._stream.write(self._get_config_class().footer) self._end = self._stream.tell() if self._owns_stream: self._stream.close() + @classmethod + 
@abc.abstractmethod + def _get_config_class(cls) -> type[MemmapReaderConfig]: + pass + def get_config(self, offset: int = 0) -> MemmapReaderConfig: assert hasattr(self, "_end") return self._get_config(self._begin + offset, self._end + offset) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 3d6964b30..d6f737c7b 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -125,12 +125,21 @@ def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) @config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): _abstract = False + header: typing.ClassVar[bytes] = b"lm begin" + footer: typing.ClassVar[bytes] = b"lm end" tokens: TokenReaderConfig = Field() # Using dynamic type for optional readers for enabling/disabling loss_masking_spans: MemmapReaderBaseConfig = Field() chosen_spans: MemmapReaderBaseConfig = Field() rejected_spans: MemmapReaderBaseConfig = Field() + def __len__(self) -> int: + return len(self.tokens) + + @property + def num_tokens(self) -> int: + return self.tokens.num_tokens + @property def reader_class(self) -> "type[LanguageModelReader]": return LanguageModelReader @@ -140,7 +149,7 @@ def writer_class(self) -> "type[LanguageModelWriter]": return LanguageModelWriter @property - def expected_buffer_size(self) -> int: + def _expected_buffer_size(self) -> int: return ( self.tokens.expected_buffer_size + self.loss_masking_spans.expected_buffer_size @@ -155,13 +164,19 @@ def __init__(self, config: ConfigType, buffer: memoryview): # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. self._tokens = self._config.tokens.get_reader(buffer) self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) - self._preference_spans = self._config.preference_spans.get_reader(buffer) + self._chosen_spans = self._config.chosen_spans.get_reader(buffer) + self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + + @property + def num_tokens(self) -> int: + return self._config.tokens.num_tokens def get_document(self, index: int, begin: int, end: int) -> Sample: return LanguageModelSample( self._tokens.get_document(index, begin, end), - self._loss_masking_spans.get_document(index, begin, end), - self._preference_spans.get_document(index, begin, end), + None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end), + None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end), + None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end), ) def get_document_sizes(self) -> torch.Tensor: @@ -248,8 +263,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) + @classmethod + def _get_config_class(cls) -> type[LanguageModelReaderConfig]: + return LanguageModelReaderConfig + def _get_config(self, begin: int, end: int | None): - tokens = self._token_writer.get_config(begin) + tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) offset = tokens.end if self._has_loss_masking_spans: loss_masking_spans = self._loss_masking_span_writer.get_config(offset) @@ -266,7 +285,7 @@ def _get_config(self, begin: int, end: int | None): rejected_spans = NullReaderConfig() if end is None: - end = offset + end = offset + len(LanguageModelReaderConfig.footer) return 
LanguageModelReaderConfig( begin=begin, diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 88dd1352d..92d5ce7fc 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -63,6 +63,8 @@ def to_samples(self) -> list[RangeSample]: @config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) class RangeReaderConfig(MemmapReaderConfig): _abstract = False + header: typing.ClassVar[bytes] = b"range begin" + footer: typing.ClassVar[bytes] = b"range end" num_documents: int = Field() num_ranges: int = Field() @@ -75,8 +77,8 @@ def writer_class(self) -> "type[RangeWriter]": return RangeWriter @property - def expected_buffer_size(self) -> int: - return (self.num_ranges + 1) * torch.uint32.itemsize * 2 + (self.num_documents + 1) * torch.uint32.itemsize + def _expected_buffer_size(self) -> int: + return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): @@ -84,17 +86,17 @@ def __init__(self, config: ConfigType, buffer: memoryview): super().__init__(config, buffer) self._ranges = torch.frombuffer( self._buffer, - dtype=torch.uint32, - count=self._config.num_ranges, + dtype=torch.int32, + count=self._config.num_ranges * 2, ).reshape(-1, 2) self._count_cumsums = torch.frombuffer( self._buffer, - dtype=torch.uint32, + dtype=torch.int32, count=self._config.num_documents + 1, offset=self._ranges.nbytes, ) - def get(self, index: int, begin: int, end: int) -> RangeSample: + def get_document(self, index: int, begin: int, end: int) -> Sample: sample_size = end - begin cropped_ranges = ( (max(begin_ - begin, 0), min(end_ - begin, sample_size)) @@ -110,15 +112,18 @@ def __enter__(self): return self def write(self, document: RangeSample): - # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== super().write(document) - self._stream.write(np.array(document.ranges, dtype=np.uint32).tobytes(order="C")) + self._stream.write(np.array(document.ranges, dtype=np.int32).tobytes(order="C")) self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._count_cumsum, dtype=np.uint32).tobytes(order="C")) + self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) + @classmethod + def _get_config_class(cls) -> type[RangeReaderConfig]: + return RangeReaderConfig + def _get_config(self, begin: int, end: int): return RangeReaderConfig( begin=begin, diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 98ee9a2a1..0e57209c5 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -91,6 +91,11 @@ class TokenReaderConfig(MemmapReaderConfig): num_documents: int = Field() num_tokens: int = Field() data_type: DataType = Field() + header: typing.ClassVar[bytes] = b"token begin" + footer: typing.ClassVar[bytes] = b"token end" + + def __len__(self) -> int: + return self.num_documents @property def reader_class(self) -> "type[TokenReader]": @@ -101,8 +106,8 @@ def writer_class(self) -> "type[TokenWriter]": return TokenWriter @property - def expected_buffer_size(self) -> int: - return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.uint64.itemsize + def _expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * 
torch.int64.itemsize class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): @@ -114,12 +119,13 @@ def __init__(self, config: ConfigType, buffer: memoryview): count=self._config.num_tokens, ) self._size_cumsums = torch.frombuffer( - self._buffer, dtype=torch.uint64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._tokens.nbytes ) def get_document(self, index: int, begin: int, end: int) -> Sample: begin_ = self._size_cumsums[index].item() - return TokenSample(torch.from_numpy(self._tokens[begin_ + begin : begin_ + end]), [end - begin]) + # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. + return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin]) def get_document_sizes(self) -> torch.Tensor: return self._size_cumsums[1:] - self._size_cumsums[:-1] @@ -146,9 +152,13 @@ def write(self, document: TokenSample): self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._size_cumsum, dtype=np.uint64).tobytes(order="C")) + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) + @classmethod + def _get_config_class(cls) -> type[TokenReaderConfig]: + return TokenReaderConfig + def _get_config(self, begin: int, end: int): return TokenReaderConfig( begin=begin, diff --git a/tests/data/common.py b/tests/data/common.py index e6ab8a265..7053666b8 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -111,7 +111,7 @@ def get_test_data_and_compare_samples( for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) + Assert.all_equal(tokens[phase], expected_samples_) return data @@ -130,7 +130,7 @@ def compare_indexed_dataset( sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): print(i) @@ -147,9 +147,7 @@ def compare_indexed_dataset( def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal( - torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples - ) + Assert.all_equal(torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]), expected_samples) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -210,6 +208,9 @@ class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): def build(self) -> "IndexedDataset": return MockMemmapDataset(self) + def __len__(self) -> int: + return self.num_documents + @property def num_tokens(self) -> int: return self.num_documents * self.num_tokens_per_document @@ -224,7 +225,11 @@ def name(self) -> str: return "mock_memmap" def __len__(self) -> int: - return self._config.num_documents + return len(self._config) + + @property + def num_tokens(self) -> int: + return self._config.num_tokens def 
get_document_sizes(self) -> torch.Tensor: return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 49eceee0b..b2b2f0117 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -15,11 +15,11 @@ from tests.utils.dataset import get_test_dataset from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH -_DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" +_DATASET_PATH_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" def _get_test_dataset_mix_1(): - return get_test_dataset(prefix=_DATASET_PREFIX_MIX_1, seed=2345) + return get_test_dataset(_DATASET_PATH_MIX_1, seed=2345) def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: @@ -119,7 +119,7 @@ def test_gpt_blended(): "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + {"type": "memmap", "path": _DATASET_PATH_MIX_1}, ], "weights": [0.75, 0.25], }, @@ -138,7 +138,7 @@ def test_gpt_blended_data(): "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + {"type": "memmap", "path": _DATASET_PATH_MIX_1}, ], "weights": [0.75, 0.25], } diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 419b67903..b11f84d9c 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -27,8 +27,8 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [(0, 2), (2, 7), (7, 10)], - 13: [(0, 2)], + 10: [(0, 1), (2, 6), (7, 9)], + 13: [(0, 1)], 15: [], } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 1608bb48c..9647264e7 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -46,8 +46,8 @@ def test_write_memmap_dataset(dtype): Assert.all_equal(dataset.get_document(i).tokens.tokens, document.tokens.tokens.to(torch.int64)) -def _generate_valid_span(max_seq_length): - return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() +def _generate_valid_span(max_seq_length) -> tuple[int, int]: + return tuple(np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist()) @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) @@ -56,8 +56,8 @@ def test_write_memmap_preference_dataset(dtype): LanguageModelSample( TokenSample(torch.from_numpy(np.random.randint(1000, size=100).astype(dtype))), None, - RangeSample(_generate_valid_span(100), 100), - RangeSample(_generate_valid_span(100), 100), + RangeSample([_generate_valid_span(100)], 100), + RangeSample([_generate_valid_span(100)], 100), ) for _ in range(50) ] @@ -128,6 +128,7 @@ def test_split_dataset(): dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], + [dataset_config_0], # Mock reader config. {"training": 3, "validation": 1}, pathlib.Path("."), ) @@ -157,6 +158,7 @@ def test_split_datasets_0(): dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], + [dataset_config_0, dataset_config_1], # Mock reader configs. 
{"training": 1, "validation": 1}, pathlib.Path("."), ) @@ -175,7 +177,10 @@ def test_split_datasets_1(): dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") + [dataset_config_0, dataset_config_1], + [dataset_config_0, dataset_config_1], # Mock reader configs. + {"training": 3, "validation": 1}, + pathlib.Path("."), ) config = {key: value.to_dict() for key, value in config.items()} From 34939e930b2d2e1bd3d636c05d2f91e303bbafa1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Oct 2025 19:05:31 -0400 Subject: [PATCH 14/45] fixes --- fast_llm/data/dataset/config.py | 1 - fast_llm/data/dataset/gpt/legacy_memmap.py | 14 ++- fast_llm/data/dataset/memmap.py | 11 +-- fast_llm/data/dataset/sampled.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 17 +--- .../data/preparator/gpt_memmap/prepare.py | 3 +- fast_llm/data/sample/abstract.py | 6 +- fast_llm/data/sample/language_model.py | 1 - fast_llm/data/sample/token.py | 1 - fast_llm/functional/dpo.py | 2 +- tests/data/test_prepare_gpt_memmap.py | 3 +- tests/models/test_match_megatron.py | 89 +++++++++++++++++-- tests/utils/dataset.py | 56 +++++++----- tests/utils/global_variables.py | 6 +- tests/utils/model_configs.py | 1 - 15 files changed, 149 insertions(+), 63 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index f60decd81..7611b4a31 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -109,7 +109,6 @@ class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: - # TODO: ====== `SamplingData` contains more than needed (ex. 
`num_samples`) raise NotImplementedError() diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index d8c63e9f9..2a23e378b 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -6,12 +6,24 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div +MEMMAP_DTYPES = { + 1: DataType.uint8, + 2: DataType.int8, + 3: DataType.int16, + 4: DataType.int32, + 5: DataType.int64, + 6: DataType.float32, + 7: DataType.float64, + 8: DataType.uint16, +} +MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" + class LegacyMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index ffb2bc6d1..e51dfb40d 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -40,14 +40,15 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: ) self._memmap = np.memmap(self._path, mode="r") - # TODO: ===== Check num_documents, num_tokens ====== self._reader = reader_config.get_reader(memoryview(self._memmap)) - def __getstate__(self) -> tuple[str, pathlib.Path]: - return (self._name, self._path) + def __getstate__(self) -> tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]: + # We pass the reader config to force its import in data loader workers. + return self._name, self._path, self._reader.config - def __setstate__(self, state: tuple[str, pathlib.Path]): - self._init(*state) + def __setstate__(self, state: tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]): + name, path, _ = state + self._init(name, path) def __del__(self): if hasattr(self, "_memmap"): diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 46a518cd0..d51a68746 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -414,7 +414,6 @@ def __getitem__(self, index: int) -> SampleType: document_sampling_index += 1 token_count += document_size - # TODO: ====== Better way to get the class method? ====== return documents[0].from_documents(documents) @property diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 7dd520ec3..a54465080 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -12,18 +12,6 @@ if typing.TYPE_CHECKING: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -MEMMAP_DTYPES = { - 1: DataType.uint8, - 2: DataType.int8, - 3: DataType.int16, - 4: DataType.int32, - 5: DataType.int64, - 6: DataType.float32, - 7: DataType.float64, - 8: DataType.uint16, -} -MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()} -MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" @config_class() @@ -66,7 +54,6 @@ def _validate(self): if self.has_loss_masking_span != self.rejected_spans_column is not None: raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") if self.has_preference_spans and self.has_loss_masking_span: - # TODO: ====== Still needed? 
====== raise ValueError(f"Can not enable both loss masking and preference spans.") @@ -204,10 +191,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): ) def _validate(self) -> None: - assert self.tokenizer.path is not None - if self.dataset.data_type is not None: - Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() + assert self.tokenizer.path is not None @classmethod def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 06a4bd517..d3d15fa64 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -218,7 +218,7 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: text = sample[self._source_schema.text_column] all_spans = [] if self._source_schema.has_loss_masking_span: - # TODO: ====== What is the input format? ====== + # TODO: ====== What is the exact input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (begin, last + 1) @@ -333,7 +333,6 @@ def _split_and_blend_dataset_configs( splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: - # TODO: ====== Missing `num_tokens`, `num_documents`. ====== split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 9afc6124c..0b2e324c3 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -145,14 +145,12 @@ class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): consisting of a list of documents of known lengths. """ - @abc.abstractmethod def __len__(self) -> int: - pass + raise NotImplementedError() @property - @abc.abstractmethod def num_tokens(self) -> int: - pass + raise NotImplementedError() @property def reader_class(self) -> "type[MemmapIndexedDatasetReader]": diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index d6f737c7b..77cc6e8a2 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -205,7 +205,6 @@ def __enter__(self): return self def write(self, document: LanguageModelSample): - # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== super().write(document) # Write tokens. 
self._token_writer.write(document.tokens) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 0e57209c5..ae190658f 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -142,7 +142,6 @@ def __enter__(self): return self def write(self, document: TokenSample): - # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== super().write(document) if self._data_type is None: self._data_type = document.tokens.dtype diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 7ab0b9ff6..c5ae48eba 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -37,7 +37,7 @@ def compute_dpo_loss( reference_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - # TODO: ====== Shouldn't the sigmoid be computed independently for each document? + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? ======= losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9647264e7..09a91d6a8 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -8,8 +8,9 @@ from fast_llm.data.dataset.config import IndexedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter from fast_llm.data.sample.range import RangeSample diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 42a7c1f0d..4b057dabd 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -1,21 +1,25 @@ import os +import pathlib +import struct import typing import numpy as np import pytest +import yaml from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import MemmapDatasetConfig, SampledDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset +from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import get_model_test_dataset +from tests.utils.dataset import get_test_dataset_samples from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PATH +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -26,6 +30,20 @@ except ImportError: _extension_available = 
False +MEGATRON_DATASET_PREFIX = DATASET_CACHE / "megatron_dataset/dataset" + + +def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): + if not ( + prefix.with_suffix(".idx").is_file() + and prefix.with_suffix(".bin").is_file() + and prefix.parent.joinpath("fast_llm_config.yaml").is_file() + ): + MegatronMemmapDataset.write_dataset(prefix, get_test_dataset_samples(vocab_size=MODEL_TEST_VOCAB_SIZE)) + yaml.safe_dump( + {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") + ) + @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.megatron) @@ -35,11 +53,12 @@ def test_megatron(run_distributed_script, model_testing_config, run_test_script_ # Prevent Megatron from complaining. env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env["NVTE_FLASH_ATTN"] = "0" - get_model_test_dataset() + get_megatron_test_dataset() run_distributed_script( [ "Megatron-LM/pretrain_gpt.py", *model_testing_config.megatron_args, + f"--data-path={MEGATRON_DATASET_PREFIX}", f"--structured-logs-dir={path}", f"--data-cache-path={path}", ], @@ -69,7 +88,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare="megatron", config_args=[ "model.distributed.compute_dtype=fp32", - f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PATH}}}', + f'data.datasets.training={{"type":"megatron","path":{MEGATRON_DATASET_PREFIX}}}', "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], @@ -97,6 +116,66 @@ class MegatronMemmapDataset(LegacyMemmapDataset): def sample(self, sampling: GPTSamplingData) -> "MegatronSampledIndexedDataset": return MegatronSampledIndexedDataset(self, sampling) + @classmethod + def write_dataset( + cls, + prefix: pathlib.Path | str, + documents: typing.Iterable[LanguageModelSample], + ) -> None: + # Initialize metadata + dtype = None + num_documents = 0 + lengths = [] + pointers = [] + offset = 0 + + prefix = pathlib.Path(prefix) + prefix.parent.mkdir(parents=True, exist_ok=True) + + # Write the binary data file (.bin) lazily + with prefix.with_suffix(".bin").open("wb") as bin_stream: + for document in documents: + token_ids = document.tokens.tokens + # Infer dtype from the first document + if dtype is None: + dtype = token_ids.dtype + assert dtype is not None, "Document dtype could not be inferred from the data." + + # Ensure all documents have the same dtype + assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." 
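+                # Legacy Megatron layout: the .bin file stores all token arrays back to back;
+                # `pointers` records each document's byte offset and `lengths` its token count
+                # so the .idx index written below can reconstruct document boundaries.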
+ + # Write document to binary file + bin_stream.write(token_ids.numpy().tobytes(order="C")) + + # Update metadata + doc_length = len(token_ids) + lengths.append(doc_length) + pointers.append(offset) + offset += doc_length * dtype.itemsize + num_documents += 1 + + # Finalize metadata arrays + lengths = np.array(lengths, dtype=np.int32) + pointers = np.array(pointers, dtype=np.int64) + + # Write the index file (.idx) + with prefix.with_suffix(".idx").open("wb") as idx_stream: + idx_stream.write(MEMMAP_INDEX_HEADER) + # Version + idx_stream.write(struct.pack(" list[LanguageModelSample]: + import transformers + + download_santacoder_tokenizer() + + texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() + tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) + + samples = [ + LanguageModelSample( + TokenSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)), + ) + for document in texts + ] + if max_spans > 0: + spans = get_random_spans( + len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed + ) + for sample, sample_spans in zip(samples, spans, strict=True): + sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) + return samples + + def get_test_dataset( path: pathlib.Path = DATASET_PATH, seed: int = 1234, @@ -45,29 +74,16 @@ def get_test_dataset( vocab_size: int = TEST_VOCAB_SIZE, max_spans: int = 0, ): - download_santacoder_tokenizer() config_path = path.parent.joinpath("fast_llm_config.yaml") if not (path.is_file() and config_path.is_file()): - import transformers - - texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - - samples = [ - LanguageModelSample( - TokenSample( - torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) - ), - ) - for document in texts - ] - if max_spans > 0: - spans = get_random_spans( - len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed - ) - for sample, sample_spans in zip(samples, spans, strict=True): - sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) + samples = get_test_dataset_samples( + seed=seed, + num_tokens=num_tokens, + characters=characters, + vocab_size=vocab_size, + max_spans=max_spans, + ) MemmapDataset.write_dataset(path, samples, LanguageModelWriter) yaml.safe_dump({"type": "memmap", "path": path.name}, config_path.open("w")) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index c62903a6c..ea770be0a 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -37,13 +37,13 @@ def set_testing_global_variables(): TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PATH = DATASET_CACHE / "common_dataset.fast_llm_dataset" -DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans.fast_llm_dataset" +DATASET_PATH = DATASET_CACHE / "common_dataset/dataset.fast_llm_dataset" +DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans/dataset.fast_llm_dataset" DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PATH = DATASET_CACHE / 
"model_dataset.fast_llm_dataset" +MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset/dataset.fast_llm_dataset" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index adcf84b18..ee9c2b730 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -279,7 +279,6 @@ def _update_and_add_testing_config( "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", - f"--data-path={MODEL_DATASET_PATH}", "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) From c5fa07214aab5fd230d9e62aaf6bd0a38e5e1588 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Oct 2025 19:46:03 -0400 Subject: [PATCH 15/45] int64 --- fast_llm/data/dataset/memmap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index e51dfb40d..4b1930dd3 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -33,7 +33,7 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: # Very file type. assert stream.read(len(FILE_HEADER)) == FILE_HEADER # Go to reader configs. - stream.seek(int.from_bytes(stream.read(4), signed=False)) + stream.seek(int.from_bytes(stream.read(8), signed=False)) # Read the reader config. reader_config = MemmapIndexDatasetReaderConfig.from_dict( json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) @@ -91,7 +91,7 @@ def write_dataset( # Leave space for a pointer to the reader config. # We write the config at the end since we don't know it yet. start = stream.tell() - stream.seek(start + 4) + stream.seek(start + 8) # Write the data. reader_config = writer_class.write_dataset(stream, documents) # Write the reader config. @@ -101,5 +101,5 @@ def write_dataset( stream.write(reader_config_bytes) # Write a pointer to the reader config. 
stream.seek(start) - stream.write(config_offset.to_bytes(4, signed=False)) + stream.write(config_offset.to_bytes(8, signed=False)) return reader_config From cd286766f08b0a470da007628ddae27ddcc9f583 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Nov 2025 21:29:42 -0500 Subject: [PATCH 16/45] Test and fix preparator --- fast_llm/data/config.py | 41 ---- fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/fim.py | 8 +- fast_llm/data/dataset/gpt/random.py | 1 + fast_llm/data/dataset/indexed.py | 3 +- fast_llm/data/preparator/config.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 60 +++-- .../data/preparator/gpt_memmap/prepare.py | 189 +++++++-------- fast_llm/data/preprocessing/__init__.py | 0 fast_llm/data/preprocessing/tokenizer.py | 196 ++++++++++++++++ fast_llm/data/sample/abstract.py | 5 +- fast_llm/data/sample/language_model.py | 52 ++-- fast_llm/data/sample/range.py | 8 +- fast_llm/data/sample/token.py | 7 +- fast_llm/engine/config_utils/data_type.py | 6 +- fast_llm/engine/config_utils/runnable.py | 2 +- fast_llm/engine/evaluation/config.py | 2 +- fast_llm/utils.py | 23 +- tests/data/common.py | 87 +------ tests/data/test_blending.py | 90 +++---- tests/data/test_concatenate.py | 50 ++-- tests/data/test_dataset_from_file.py | 12 - tests/data/test_fim.py | 51 ++-- tests/data/test_loss_masking_spans.py | 78 ++++++ tests/data/test_memmap.py | 47 ---- tests/data/test_preference_spans.py | 105 +++++++++ tests/data/test_preparator.py | 197 ++++++++++++++++ tests/data/test_prepare_gpt_memmap.py | 211 ----------------- tests/data/test_random.py | 16 +- tests/data/test_sampling.py | 45 ++-- tests/data/test_slice.py | 56 ++--- tests/functional/test_functional.py | 11 +- tests/models/test_match_megatron.py | 20 +- tests/utils/dataset.py | 222 ++++++++++++------ tests/utils/global_variables.py | 14 +- tests/utils/model_configs.py | 8 +- 36 files changed, 1075 insertions(+), 854 deletions(-) create mode 100644 fast_llm/data/preprocessing/__init__.py create mode 100644 fast_llm/data/preprocessing/tokenizer.py delete mode 100644 tests/data/test_dataset_from_file.py create mode 100644 tests/data/test_loss_masking_spans.py delete mode 100644 tests/data/test_memmap.py create mode 100644 tests/data/test_preference_spans.py create mode 100644 tests/data/test_preparator.py delete mode 100644 tests/data/test_prepare_gpt_memmap.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 633367c80..78bc20636 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,12 +1,4 @@ import enum -import pathlib -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.data.tokenizer import Tokenizer class MultiprocessingContext(str, enum.Enum): @@ -15,36 +7,3 @@ class MultiprocessingContext(str, enum.Enum): fork = "fork" # Safe but much slower. spawn = "spawn" - - -TokenizerFromFile = "TokenizerFromFile" - - -@config_class() -class TokenizerConfig(Config): - """ - Configuration for the tokenizer. - The tokenizer is needed for FIM and dataset preparation. 
- """ - - format: str = Field( - default="TokenizerFromFile", - desc="Unused.", - hint=FieldHint.deprecated, - valid=check_field(Assert.eq, TokenizerFromFile), - ) - path: pathlib.Path = Field( - default=None, - desc="Path to the tokenizer file.", - hint=FieldHint.core, - ) - bos_token: str | None = Field( - default=None, - desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", - hint=FieldHint.core, - ) - - def get_tokenizer(self) -> "Tokenizer": - from fast_llm.data.tokenizer import Tokenizer - - return Tokenizer(self) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9ff6654c2..7583345c3 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -6,9 +6,9 @@ import yaml from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert @@ -23,7 +23,8 @@ class GPTSamplingParameters(SamplingParameters): Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - vocab_size: int + # TODO: Only used for random dataset. Remove? Or use as safety check? + vocab_size: int | None = None use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 1fde74530..d36384ee5 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -5,6 +5,7 @@ from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import MAX_SEED @@ -168,9 +169,10 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype) + data_type = DataType.from_numpy(sequence.dtype) + prefix = self._tokenizer.tokenize(prefix, end=False, data_type=data_type).numpy() + middle = self._tokenizer.tokenize(middle, begin=False, end=False, data_type=data_type).numpy() + suffix = self._tokenizer.tokenize(suffix, begin=False, data_type=data_type).numpy() # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 463c5a7d6..f1e73c595 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -29,6 +29,7 @@ def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed self._parameters = sampling.parameters + assert self._parameters.vocab_size is not None # TODO: Support? 
assert not self._parameters.use_loss_masking_spans assert not self._parameters.use_preference_loss_spans diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 5d6636f7f..b2e6f7e3d 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -117,11 +117,12 @@ def __init__( def __len__(self) -> int: return self._dataset_splits[-1].item() + @property def num_tokens(self) -> int: """ Number of tokens in the dataset. """ - return sum(len(dataset) for dataset in self._datasets) + return sum(dataset.num_tokens for dataset in self._datasets) def get_document_sizes(self) -> torch.Tensor: # TODO: This can be really big. diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index 160fccafc..a774fc3de 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -7,7 +7,6 @@ @config_class(registry=True, dynamic_type={RunnableConfig: "prepare"}) class DatasetPreparatorConfig(RunnableConfig): - preparator_name: typing.ClassVar[str] @classmethod def get_dataset_preparator_class(cls) -> type["DatasetPreparator"]: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index a54465080..9bf292033 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -4,8 +4,8 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.utils import Assert @@ -16,50 +16,53 @@ @config_class() class LanguageModelSourceConfig(Config): - text_column: str = Field( + """ + A schema holding the name of each relevant column in the dataset. + Setting optional entries will enable the associated feature. 
+ """ + + text: str = Field( default="text", desc="Field of the dataset to use.", hint=FieldHint.optional, ) - loss_masking_spans_column: None | str = Field( + loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) - chosen_spans_column: None | str = Field( + chosen_span: None | str = Field( default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional ) - rejected_spans_column: None | str = Field( + rejected_span: None | str = Field( default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional ) @functools.cached_property def columns(self) -> list[str]: - columns = [self.text_column] + columns = [self.text] if self.has_loss_masking_span: - columns.append(self.loss_masking_spans_column) + columns.append(self.loss_masking_spans) if self.has_preference_spans: - columns.extend([self.chosen_spans_column, self.rejected_spans_column]) + columns.extend([self.chosen_span, self.rejected_span]) return columns @functools.cached_property def has_loss_masking_span(self) -> bool: - return self.loss_masking_spans_column is not None + return self.loss_masking_spans is not None @functools.cached_property def has_preference_spans(self) -> bool: - Assert.eq(self.chosen_spans_column is None, self.rejected_spans_column is None) - return self.chosen_spans_column is not None + Assert.eq(self.chosen_span is None, self.rejected_span is None) + return self.chosen_span is not None def _validate(self): super()._validate() - if self.has_loss_masking_span != self.rejected_spans_column is not None: - raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") if self.has_preference_spans and self.has_loss_masking_span: raise ValueError(f"Can not enable both loss masking and preference spans.") @config_class() class GPTHuggingfaceDatasetConfig(Config): - path: str = Field( + path: str | pathlib.Path = Field( default=None, desc="Name or path of the dataset.", hint=FieldHint.core, @@ -104,6 +107,11 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Disable disk space check. 
Useful for environments where disk space is not accurately reported.", hint=FieldHint.optional, ) + load_from_disk: bool = Field( + default=False, + desc="Use the `load_from_disk` method for datasets saved with `save_to_disk`.", + hint=FieldHint.feature, + ) @config_class() @@ -141,7 +149,6 @@ def _validate(self) -> None: @config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): - preparator_name: typing.ClassVar[str] = "gpt_memmap" output_path: pathlib.Path = Field( default=None, desc="Output directory for the processed dataset.", @@ -151,27 +158,14 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for distributed processing.", hint=FieldHint.feature, ) - tokens_per_shard: int = Field( - default=10**9, - desc="Approximate number of tokens per shard.", + documents_per_shard: int = Field( + default=10**6, + desc="Target number of documents per shard.", hint=FieldHint.feature, - valid=check_field(Assert.geq, 10**5), - ) - loading_workers: int = Field( - default=1, - desc="Number of workers in load_dataset() call.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) - tokenize_workers: int = Field( - default=1, - desc="Number of workers for tokenization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), ) - saving_workers: int = Field( + num_workers: int = Field( default=1, - desc="Number of processes for saving the data.", + desc="Number of parallel workers.", hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d3d15fa64..18d4d46e2 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import collections +import enum import json import logging import math @@ -25,17 +27,24 @@ from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preprocessing.tokenizer import Tokenizer from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample -from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type +from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum logger = logging.getLogger(__name__) +class SpanType(enum.StrEnum): + loss_masking = "loss_masking" + chosen = "chosen" + rejected = "rejected" + + class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): _tokenizer: Tokenizer _data_type: DataType @@ -46,30 +55,19 @@ def __init__(self, config: ConfigType): super().__init__(config) self._source_schema: LanguageModelSourceConfig = self._config.dataset.source_schema - def _save_shard( - self, args: tuple[int, datasets.Dataset] - ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: - shard_index, shard_dataset = args - file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" - - reader_config = 
MemmapDataset.write_dataset( - self._config.output_path / file_name, - tqdm.tqdm((sample["sample"] for sample in shard_dataset), desc=f"Saving shard {shard_index}", unit="docs"), - LanguageModelWriter, - ) - - return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config - def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + if self._config.dataset.load_from_disk: + dataset = datasets.load_from_disk(self._config.dataset.path)[self._config.dataset.split] + else: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.num_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) assert isinstance(dataset, datasets.Dataset) return dataset @@ -137,6 +135,7 @@ def run(self) -> None: # Initialize distributed processing if self._config.distributed.world_size > 1: + log_main_rank(f"> Initializing distributed process groups ...") torch.distributed.init_process_group( backend=self._config.distributed.backend, rank=self._config.distributed.rank, @@ -146,31 +145,18 @@ def run(self) -> None: # Prepare output directory self._config.output_path.mkdir(parents=True, exist_ok=True) - downloaded = pathlib.Path(self._config.dataset.path).is_dir() - if self._config.distributed.world_size > 1: - torch.distributed.barrier() - - if downloaded: - # Dataset is already downloaded, load from disk + log_main_rank(f"> Loading dataset `{self._config.dataset.path}` ...") + if self._config.distributed.world_size == 1: + dataset = self._load_dataset() + elif self._config.distributed.rank == 0: + # Load first on rank 0 to prevent parallel downloads. 
dataset = self._load_dataset() + torch.distributed.barrier() else: - # Dataset is not downloaded, download on rank 0 - if self._config.distributed.rank == 0: - dataset = self._load_dataset() - - # Synchronize processes to wait for the download to finish on rank 0 - if self._config.distributed.world_size > 1: - torch.distributed.barrier() - + torch.distributed.barrier() # Load the downloaded dataset on remaining ranks - if self._config.distributed.rank != 0: - dataset = self._load_dataset() - - # Synchronize processes to wait for the dataset to load on remaining ranks - if self._config.distributed.world_size > 1: - torch.distributed.barrier() + dataset = self._load_dataset() - assert isinstance(dataset, datasets.Dataset) dataset = dataset.shard( num_shards=self._config.distributed.world_size, index=self._config.distributed.rank, @@ -180,49 +166,45 @@ def run(self) -> None: if column_name not in dataset.column_names: raise ValueError(f"Dataset does not have field '{column_name}'.") - # Tokenize the dataset in parallel - prepared_dataset = dataset.map( - self._prepare_batch, - batched=True, - num_proc=self._config.tokenize_workers, - desc="Tokenizing batches", - ) - # Split dataset into shards based on number of tokens - num_shards = math.ceil( - sum(len(sample) for sample in prepared_dataset["samples"]) / self._config.tokens_per_shard - ) - shards = [ - (i, prepared_dataset.shard(num_shards=num_shards, index=i)) - for i in tqdm.tqdm(range(num_shards), desc="Creating shards") - ] + num_shards = math.ceil(len(dataset) / self._config.documents_per_shard) + shards = [(i, dataset.shard(num_shards=num_shards, index=i)) for i in range(num_shards)] + + log_main_rank(f"> Preparing samples on {self._config.num_workers} workers ...") # Use multiprocessing to save each shard in parallel on all ranks - with multiprocessing.Pool(processes=self._config.saving_workers) as pool: - dataset_and_reader_configs = pool.map(self._save_shard, shards) + with multiprocessing.Pool(processes=self._config.num_workers) as pool: + dataset_and_reader_configs = pool.map(self._prepare_shard, shards) + log_main_rank(f"> Generating dataset config ...") self.generate_config_yaml_for_sharded_dst(dataset_and_reader_configs) - def _prepare_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[LanguageModelSample]]: - # Gather values by sample using zip* - sample_column_values = zip(*(batch[column_name] for column_name in self._source_schema.columns)) - # Convert to dicts using column names. 
- sample_dicts = ( - {column_name: column_value for column_name, column_value in zip(self._source_schema.columns, sample_data)} - for sample_data in sample_column_values + def _prepare_shard( + self, args: tuple[int, datasets.Dataset] + ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: + shard_index, shard_dataset = args + file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" + + reader_config = MemmapDataset.write_dataset( + self._config.output_path / file_name, + ( + self._prepare_sample(sample) + for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") + ), + LanguageModelWriter, ) - # Prepare each sample, wrap in dict for the `Dataset` interface - return {"samples": [self._prepare_sample(sample_dict) for sample_dict in sample_dicts]} + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_schema.text_column] + # TODO: ======= Extract so we can use elsewhere? (ex. inference) ====== + text = sample[self._source_schema.text] all_spans = [] if self._source_schema.has_loss_masking_span: # TODO: ====== What is the exact input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( - (begin, last + 1) - for begin, last in np.array(sample[self._source_schema.loss_masking_spans_column], dtype=np.int32) + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) .reshape(-1, 2) .tolist() ) @@ -230,37 +212,58 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: if self._source_schema.has_preference_spans: # TODO: ===== Was `self._config.dataset.field` (bug?) ====== - full_chosen_text = ( - text + sample[self._source_schema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token - ) - full_rejected_text = ( - self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_spans_column] - ) + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] # compute chosen span - chosen_spans = [[len(text), len(full_chosen_text)]] + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] # compute rejected span rejected_span = [ - [ - len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), - len(full_chosen_text) + len(full_rejected_text), - ] + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), + ), + ) ] # pack texts text = full_chosen_text + full_rejected_text all_spans.extend(chosen_spans + rejected_span) - tokens = torch.tensor( - self._tokenizer.tokenize_with_spans(text, True, True, spans=_sort_spans(all_spans)), - dtype=self._data_type.torch, + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. 
+ tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type ) + + # Gather token spans by type. + token_spans_by_type = collections.defaultdict(list) + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) + sample_size = len(tokens) return LanguageModelSample( TokenSample(tokens, [sample_size]), - RangeSample(loss_masking_spans, sample_size) if self._source_schema.has_loss_masking_span else None, - RangeSample(chosen_spans, sample_size) if self._source_schema.has_preference_spans else None, - RangeSample(rejected_span, sample_size) if self._source_schema.has_preference_spans else None, + ( + RangeSample(token_spans_by_type[SpanType.loss_masking], sample_size) + if self._source_schema.has_loss_masking_span + else None + ), + ( + RangeSample(token_spans_by_type[SpanType.chosen], sample_size) + if self._source_schema.has_preference_spans + else None + ), + ( + # `tokenize_with_spans` excludes the final eod token from the rejected span, but we want to include it. + RangeSample([(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]], sample_size) + if self._source_schema.has_preference_spans + else None + ), ) def generate_config_yaml_for_sharded_dst( @@ -402,8 +405,8 @@ def _split_and_blend_dataset_configs( return dataset_splits -def _sort_spans(spans: typing.Iterable[tuple[int, int]]) -> list[tuple[int, int]]: - return sorted(spans, key=lambda span: span[0]) +def _sort_spans(spans: typing.Iterable[tuple[SpanType, tuple[int, int]]]) -> list[tuple[SpanType, tuple[int, int]]]: + return sorted(spans, key=lambda span: span[1][0]) def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: diff --git a/fast_llm/data/preprocessing/__init__.py b/fast_llm/data/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py new file mode 100644 index 000000000..70291bcaa --- /dev/null +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -0,0 +1,196 @@ +import pathlib +import typing + +from fast_llm.config import Config, Configurable, Field, FieldHint, config_class +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import numpy as np + import torch + + +@config_class() +class TokenizerConfig(Config): + """ + Configuration for the tokenizer. + The tokenizer is needed for FIM and dataset preparation. + """ + + path: pathlib.Path = Field( + default=None, + desc="Path to the tokenizer file.", + hint=FieldHint.core, + ) + bos_token: str | None = Field( + default=None, + desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", + hint=FieldHint.core, + ) + max_vocab_size: int | None = Field( + default=None, + desc="Constrain output tokens to a specific range. Used for testing.", + hint=FieldHint.testing, + ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.preprocessing.tokenizer import Tokenizer + + return Tokenizer(self) + + +class Tokenizer[ConfigType: TokenizerConfig](Configurable[ConfigType]): + """ + A wrapper around Huggingface (transformers) tokenizer. 
+ """ + + def __init__(self, config: ConfigType): + super().__init__(config) + from transformers import AutoTokenizer + + log_main_rank(f"> loading tokenizer from {config.path} ...") + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=self._config.path, + errors="replace", + max_len=None, + trust_remote_code=True, + use_fast=True, + ) + if self._config.bos_token is not None: + self.tokenizer.bos_token = self._config.bos_token + if self.tokenizer.eos_token_id is None: + raise ValueError("Tokenizer does not have an EOS token.") + if self.tokenizer.bos_token_id is None: + raise ValueError("Tokenizer does not have an BOS token.") + self.eod_id = self.tokenizer.eos_token_id + self.bod_id = self.tokenizer.bos_token_id + + @property + def vocab_size(self) -> int: + return len(self.tokenizer) + + @property + def vocab(self) -> dict[str, int]: + return self.tokenizer.vocab + + @property + def inv_vocab(self) -> dict[int, str]: + return self._inv_vocab + + def tokenize( + self, text: str, begin: bool = True, end: bool = True, data_type: DataType = DataType.int64 + ) -> "torch.Tensor": + import torch + + tokens = torch.tensor( + ([self.bod_id] if begin else []) + + self.tokenizer.encode(text, add_special_tokens=False) + + ([self.eod_id] if end else []), + dtype=data_type.torch, + ) + if self._config.max_vocab_size is not None: + tokens %= self._config.max_vocab_size + return tokens + + def tokenize_with_spans( + self, + text: str, + begin: bool = True, + end: bool = True, + *, + text_spans: list[tuple[int, int]], + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Perform span-aware tokenization and return the tokenized input_ids along with token spans. + """ + if not text_spans: + return self.tokenize(text, begin, end, data_type=data_type), [] + input_ids, token_splits = self.tokenize_with_splits( + text, begin, end, text_splits=[split for splits in text_spans for split in splits], data_type=data_type + ) + return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)] + + def tokenize_with_splits( + self, + text: str, + begin: bool = True, + end: bool = True, + *, + text_splits: list[int], + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[int]]: + if not text_splits: + return self.tokenize(text, begin, end, data_type=data_type), [] + import torch + + Assert.eq(sorted(text_splits), text_splits) + input_ids = [] + text_splits = [0, *text_splits, len(text)] + token_splits = [] + total_tokens = 0 + + for i, (split_begin, split_end) in enumerate(zip(text_splits[:-1], text_splits[1:])): + input_ids.append( + split_tokens := self.tokenize( + text[split_begin:split_end], + begin and i == 0, + end and i == len(text_splits) - 2, + data_type=data_type, + ) + ) + total_tokens += len(split_tokens) + token_splits.append(total_tokens) + + return torch.cat(input_ids), token_splits[:-1] + + def detokenize( + self, tokens: "int | list[int] | np.ndarray | torch.Tensor", begin: bool = False, end: bool = False + ) -> str: + tokens = self._remove_delimiters(tokens, begin, end) + return self.tokenizer.decode(tokens) + + def detokenize_with_spans( + self, tokens: "torch.Tensor", begin: bool = False, end: bool = False, *, token_spans: list[tuple[int, int]] + ) -> tuple[str, list[tuple[int, int]]]: + if not token_spans: + return self.detokenize(tokens, begin, end), [] + text, text_splits = self.detokenize_with_splits( + tokens, begin, end, token_splits=[split for splits in 
token_spans for split in splits] + ) + return text, [(begin, end) for begin, end in zip(text_splits[::2], text_splits[1::2], strict=True)] + + def detokenize_with_splits( + self, tokens: "torch.Tensor", begin: bool = False, end: bool = False, *, token_splits: list[int] + ) -> tuple[str, list[int]]: + if not token_splits: + return self.detokenize(tokens, begin, end), [] + Assert.eq(sorted(token_splits), token_splits) + tokens = self._remove_delimiters(tokens, begin, end) + texts = [] + token_splits = [0, *(token_split - begin for token_split in token_splits), len(tokens)] + text_splits = [] + total_characters = 0 + + for i, (split_begin, split_end) in enumerate(zip(token_splits[:-1], token_splits[1:])): + texts.append(split_text := self.detokenize(tokens[split_begin:split_end])) + total_characters += len(split_text) + text_splits.append(total_characters) + + return "".join(texts), text_splits[:-1] + + def _remove_delimiters( + self, token_ids: "int | list[int] | np.ndarray | torch.Tensor", begin: bool = False, end: bool = False + ): + if begin: + Assert.eq(token_ids[0], self.bod_id) + token_ids = token_ids[1:] + if end: + Assert.eq(token_ids[-1], self.eod_id) + token_ids = token_ids[:-1] + return token_ids + + @property + def eod(self): + return self.eod_id diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 0b2e324c3..aaa321efd 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -210,8 +210,9 @@ def write(self, document: Sample): assert hasattr(self, "_begin") and not hasattr(self, "_end") def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(self._get_config_class().footer) - self._end = self._stream.tell() + if exc_type is None: + self._stream.write(self._get_config_class().footer) + self._end = self._stream.tell() if self._owns_stream: self._stream.close() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 77cc6e8a2..6f485bf84 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -89,9 +89,9 @@ def to_samples(self) -> list[LanguageModelSample]: LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( self.tokens.to_samples(), - self.loss_masking_spans.to_samples(), - self.chosen_spans.to_samples(), - self.rejected_spans.to_samples(), + None if self.loss_masking_spans is None else self.loss_masking_spans.to_samples(), + None if self.chosen_spans is None else self.chosen_spans.to_samples(), + None if self.rejected_spans is None else self.rejected_spans.to_samples(), strict=True, ) ] @@ -237,27 +237,31 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) - # A dummy config so we can verify the begin and end offsets. 
- config = self._get_config(self._begin, None) - _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) - - if self._has_loss_masking_spans: - _copy_chunked( - self._path.joinpath("loss_masking_spans"), - self._stream, - config.loss_masking_spans.begin, - config.loss_masking_spans.end, - ) - if self._has_preference_spans: - _copy_chunked( - self._path.joinpath("chosen_spans"), self._stream, config.chosen_spans.begin, config.chosen_spans.end - ) - _copy_chunked( - self._path.joinpath("rejected_spans"), - self._stream, - config.rejected_spans.begin, - config.rejected_spans.end, - ) + if exc_type is None: + # A dummy config so we can verify the begin and end offsets. + config = self._get_config(self._begin, None) + _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) + + if self._has_loss_masking_spans: + _copy_chunked( + self._path.joinpath("loss_masking_spans"), + self._stream, + config.loss_masking_spans.begin, + config.loss_masking_spans.end, + ) + if self._has_preference_spans: + _copy_chunked( + self._path.joinpath("chosen_spans"), + self._stream, + config.chosen_spans.begin, + config.chosen_spans.end, + ) + _copy_chunked( + self._path.joinpath("rejected_spans"), + self._stream, + config.rejected_spans.begin, + config.rejected_spans.end, + ) self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 92d5ce7fc..c3a035376 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -12,7 +12,7 @@ MemmapWriter, Sample, ) -from fast_llm.utils import get_unique +from fast_llm.utils import Assert, get_unique class RangeSample(Sample): @@ -88,7 +88,7 @@ def __init__(self, config: ConfigType, buffer: memoryview): self._buffer, dtype=torch.int32, count=self._config.num_ranges * 2, - ).reshape(-1, 2) + ).view(-1, 2) self._count_cumsums = torch.frombuffer( self._buffer, dtype=torch.int32, @@ -117,7 +117,9 @@ def write(self, document: RangeSample): self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) + if exc_type is None: + Assert.lt(self._count_cumsum[-1], np.iinfo(np.int32).max) + self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) @classmethod diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index ae190658f..706b5053a 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -88,11 +88,11 @@ def to_device_(self, device: "torch.device | str"): @config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) class TokenReaderConfig(MemmapReaderConfig): _abstract = False + header: typing.ClassVar[bytes] = b"token begin" + footer: typing.ClassVar[bytes] = b"token end" num_documents: int = Field() num_tokens: int = Field() data_type: DataType = Field() - header: typing.ClassVar[bytes] = b"token begin" - footer: typing.ClassVar[bytes] = b"token end" def __len__(self) -> int: return self.num_documents @@ -151,7 +151,8 @@ def write(self, document: TokenSample): self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, 
dtype=np.int64).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) @classmethod diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 1a0fed91b..27709a8bb 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -50,9 +50,13 @@ def from_torch(cls, dtype: "torch.dtype") -> "DataType": return _TORCH_DTYPE_MAP_INV[dtype] @classmethod - def from_numpy(cls, dtype: "np.dtype") -> "DataType": + def from_numpy(cls, dtype: "np.dtype | type[np.number]") -> "DataType": + import numpy as np + if not _NUMPY_DTYPE_MAP_INV: _set_numpy_dtype_map() + if isinstance(dtype, np.dtype): + dtype = dtype.type return _NUMPY_DTYPE_MAP_INV[dtype] @classmethod diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 051163084..163a9459c 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -106,7 +106,7 @@ def _get_runnable(self) -> typing.Callable[[], None]: return self.run def run(self) -> None: - raise NotImplementedError() + self._get_runnable()() def _show[ T diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index f8dfd4825..df7ab0f51 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -2,7 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 1f9feceb4..83675ac74 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -161,21 +161,22 @@ def rms_close_relative(x, y, threshold, min_threshold=0): assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" @staticmethod - def all_equal(x, y): + def all_equal(x, *args): import torch # Make it work for lists and numpy arrays. 
x = torch.as_tensor(x) - y = torch.as_tensor(y) - - Assert.eq(x.shape, y.shape) - neq = x != y - if neq.any().item(): # noqa - index = None if x.numel() == 1 else torch.where(neq) # noqa - raise AssertionError( - f"Tensors have {index[0].numel()} different entries out of " - f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}" - ) + for arg in args: + arg = torch.as_tensor(arg) + + Assert.eq(x.shape, arg.shape) + neq = x != arg + if neq.any().item(): # noqa + index = None if x.numel() == 1 else torch.where(neq) # noqa + raise AssertionError( + f"Tensors have {index[0].numel()} different entries out of " + f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + ) @staticmethod def all_different(x, y): diff --git a/tests/data/common.py b/tests/data/common.py index 7053666b8..ac8d8023c 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -4,26 +4,18 @@ import numpy as np import torch -from fast_llm.config import Field, FieldHint, NoAutoValidate, config_class +from fast_llm.config import NoAutoValidate from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import ( - IndexedDatasetConfig, - SampledDatasetConfig, - SamplingConfig, - SamplingParameters, - ShufflingType, -) +from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, ShufflingType from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset -from fast_llm.data.sample.abstract import Sample from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.utils.global_variables import TEST_VOCAB_SIZE def get_sampling_data( @@ -33,7 +25,7 @@ def get_sampling_data( cache_directory: pathlib.Path | None = None, phase=PhaseType.training, sequence_length: int = 512, - vocab_size=TEST_VOCAB_SIZE, + vocab_size: int | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, @@ -73,7 +65,7 @@ def get_test_data_and_compare_samples( shuffle: ShufflingType = ShufflingType.epoch, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, - vocab_size=TEST_VOCAB_SIZE, + vocab_size: int | None = None, expected_samples: dict[str, list[list[int]]] | list[list[int]], ) -> GPTData: distributed_config = DistributedConfig(seed=87522) @@ -115,34 +107,21 @@ def get_test_data_and_compare_samples( return data -def compare_indexed_dataset( +def compare_indexed_dataset_tokens( dataset: IndexedDataset, length: int, num_tokens: int, expected_samples: dict[int, list[int]], - loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() - # Assert.eq(sizes.sum(), num_tokens) + Assert.eq(sizes.sum(), num_tokens, dataset.num_tokens) Assert.all_equal( [len(dataset.get_document(i).tokens.tokens) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample)) - if loss_masking_spans: - for i, loss_masking_span in loss_masking_spans.items(): - print(i) - 
Assert.eq( - dataset.get_document( - i, - parameters=GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True - ), - ).loss_masking_spans.ranges, - loss_masking_spans[i], - ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: @@ -183,61 +162,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s for index in range(sampled._parameters.num_samples) ] token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) - Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) return token_ids - - -@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"}) -class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.core, - ) - num_tokens_per_document: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - path: pathlib.Path = Field(default=".") - - def build(self) -> "IndexedDataset": - return MockMemmapDataset(self) - - def __len__(self) -> int: - return self.num_documents - - @property - def num_tokens(self) -> int: - return self.num_documents * self.num_tokens_per_document - - -class MockMemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): - def __init__(self, config: MockGPTMemmapDatasetConfig): - self._config = config - - @property - def name(self) -> str: - return "mock_memmap" - - def __len__(self) -> int: - return len(self._config) - - @property - def num_tokens(self) -> int: - return self._config.num_tokens - - def get_document_sizes(self) -> torch.Tensor: - return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) - - def get_document_size(self, index: int) -> int: - return self._config.num_tokens_per_document - - def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: - raise NotImplementedError() diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index b2b2f0117..88ecf2c99 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -12,17 +12,11 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH - -_DATASET_PATH_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" - - -def _get_test_dataset_mix_1(): - return get_test_dataset(_DATASET_PATH_MIX_1, seed=2345) +from tests.utils.dataset import get_alt_test_dataset, get_common_test_dataset def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: + # Alternate implementation for blending. 
probs = np.array(probs) dataset_index = np.zeros(num_samples) sample_index = np.zeros(num_samples) @@ -37,25 +31,25 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, GPT_BLENDED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [4628, 7392, 920, 79, 1322, 387], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [387, 4224, 87, 2713, 423, 324], - [3036, 253, 207, 2968, 4536, 1178], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [49152, 83, 80, 20452, 45, 93], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [93, 90, 39, 6, 75, 9], + [58, 22885, 93, 37, 92, 76], ] GPT_BLENDED_MIXED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], + [49152, 46, 10, 819, 19, 45], [916, 6683, 7685, 1277, 5106, 378], - [1790, 80, 6506, 1735, 542, 88], + [45, 69, 17, 86, 38826, 15], [3359, 6803, 780, 4561, 669, 7878], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], [6920, 2218, 2921, 3963, 7606, 6904], - [207, 4700, 549, 79, 417, 3036], + [1793, 1, 1746, 38, 27, 58], ] @@ -112,38 +106,21 @@ def test_blending(probs): def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. - get_test_dataset() - _get_test_dataset_mix_1() + _, config, _ = get_common_test_dataset() + _, alt_config, _ = get_alt_test_dataset() sampled = get_dataset_config( - { + dataset_config := { "type": "blended", - "datasets": [ - {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PATH_MIX_1}, - ], + "datasets": [config, alt_config], "weights": [0.75, 0.25], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) - -def test_gpt_blended_data(): - get_test_dataset() - _get_test_dataset_mix_1() + # Test in data. get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "blended", - "datasets": [ - {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PATH_MIX_1}, - ], - "weights": [0.75, 0.25], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, @@ -152,34 +129,25 @@ def test_gpt_blended_data(): def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() sampled = get_dataset_config( - { + dataset_config := { "type": "blended", "datasets": [ - {"type": "memmap", "path": DATASET_PATH}, + config, {"type": "random"}, ], "weights": [0.6, 0.4], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) - -def test_gpt_blended_mixed_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "blended", - "datasets": [{"type": "memmap", "path": DATASET_PATH}, {"type": "random"}], - "weights": [0.6, 0.4], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, + vocab_size=8192, expected_samples=GPT_BLENDED_MIXED_SAMPLES, ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 7b009bbf6..d7e750c8b 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,56 +1,48 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( - compare_indexed_dataset, + compare_indexed_dataset_tokens, compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_SAMPLES, COMMON_DATASET_TOKENS +from tests.utils.dataset import get_common_test_dataset GPT_CONCATENATED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [3036, 253, 207, 2968, 4536, 1178], - [1178, 3291, 317, 277, 2679, 89], - [89, 542, 395, 583, 684, 554], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [58, 22885, 93, 37, 92, 76], + [76, 29, 19, 17365, 93, 46], + [46, 83, 17211, 1, 785, 1023], ] def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() + memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() dataset = get_dataset_config( - {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)]}, + dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], ).build() - compare_indexed_dataset( + compare_indexed_dataset_tokens( dataset, - 3 * MEMMAP_DATASET_LENGTH, - 3 * MEMMAP_DATASET_TOKENS, - {j * MEMMAP_DATASET_LENGTH + i: sample for j in range(3) for i, sample in MEMMAP_DATASET_SAMPLES.items()}, + 3 * COMMON_DATASET_LENGTH, + 3 * COMMON_DATASET_TOKENS, + {j * COMMON_DATASET_LENGTH + i: sample for j in range(3) for i, sample in COMMON_DATASET_SAMPLES.items()}, ) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_CONCATENATED_SAMPLES) - -def test_gpt_concatenate_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "concatenated", - "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py deleted file mode 100644 index af91df1e2..000000000 --- a/tests/data/test_dataset_from_file.py +++ /dev/null @@ -1,12 +0,0 @@ -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH - - -def test_dataset_from_file(): - get_test_dataset() - dataset_config = {"type": "file", "path": str(DATASET_PATH.parent.joinpath("fast_llm_config.yaml"))} - dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() - compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index b9dc7fe32..0600c5258 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -5,34 +5,30 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH, TOKENIZER_PATH +from tests.utils.dataset import get_common_test_dataset +from tests.utils.global_variables import TOKENIZER_PATH GPT_FIM_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [86, 89, 7876, 80, 49152, 87], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [86, 89, 1178, 49152, 87, 49152], - [86, 49152, 1178, 64, 89, 900], - [86, 49152, 89, 542, 395, 89], + [46, 10, 819, 19, 45, 88], + [45, 69, 17, 86, 38826, 15], + [86, 89, 32348, 64, 49152, 87], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [86, 89, 37, 92, 76, 49152], + [86, 49152, 76, 29, 19, 89], + [86, 49152, 46, 83, 17211, 1], ] def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. 
- sampling_config = get_sampling_data( - 8, - sequence_length=5, - vocab_size=49157, - ) + sampling_config = get_sampling_data(8, sequence_length=5) sampled = get_dataset_config( - { + dataset_config := { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PATH}, + "dataset": config, "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", @@ -44,26 +40,9 @@ def test_gpt_fim(): ).build_and_sample(sampling_config) compare_sampled_dataset(sampled, GPT_FIM_SAMPLES) - -def test_gpt_fim_data(): - get_test_dataset() get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PATH}, - "tokenizer": {"path": TOKENIZER_PATH}, - "rate": 0.5, - "prefix_token": "w", - "middle_token": "x", - "pad_token": "y", - "suffix_token": "z", - } - }, - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_FIM_SAMPLES, - vocab_size=49157, ) diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py new file mode 100644 index 000000000..521eaf2a9 --- /dev/null +++ b/tests/data/test_loss_masking_spans.py @@ -0,0 +1,78 @@ +import datasets + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_TEXT +from tests.utils.dataset import get_test_dataset_with_loss_masking_spans +from tests.utils.global_variables import TOKENIZER_NAME + +DATASET_WITH_SPAN_TOKENS = 46199 +DATASET_WITH_SPAN_SAMPLES = { + 27: [49152, 63, 82, 11, 84, 71, 49152], + 30: [49152, 31, 85, 78, 27, 34, 46, 62, 43, 49152], + 31: [49152, 60, 55, 80, 30, 85, 22, 18, 49152], + 77: [49152, 73, 80, 85, 52, 22, 46, 5, 88, 78, 49152], + 87: [49152, 52, 89, 75, 11, 71, 49152], +} +HF_LOSS_MASKING_SPANS = { + 27: [[0, 1], [3, 3]], + 30: [[0, 0], [2, 2], [5, 5]], + 31: [[0, 0], [2, 2], [4, 4]], + 77: [[0, 0], [3, 5], [7, 7]], + 87: [[1, 1], [3, 3]], +} +TOKEN_LOSS_MASKING_SPANS = { + 27: [(1, 3), (4, 5)], + 30: [(1, 2), (3, 4), (6, 7)], + 31: [(1, 2), (3, 4), (5, 6)], + 77: [(1, 2), (4, 7), (8, 9)], + 87: [(2, 3), (4, 5)], +} + + +def test_gpt_data_with_spans(): + _, config, hf_path = get_test_dataset_with_loss_masking_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + + hf_dataset = datasets.load_from_disk(hf_path)["train"] + tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() + + # Check global stats. + Assert.eq(len(dataset), len(hf_dataset), COMMON_DATASET_LENGTH) + Assert.eq(dataset.num_tokens, DATASET_WITH_SPAN_TOKENS) + + for index in range(0, 200, 8): + expected_text = hf_dataset[index]["text"] + expected_text_spans = [(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]] + expected_tokens, expected_spans = tokenizer.tokenize_with_spans( + hf_dataset[index]["text"], + text_spans=[(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]], + ) + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) + ) + + # Compare tokens and token spans. 
+ Assert.all_equal(document.tokens.tokens, expected_tokens) + Assert.eq(document.loss_masking_spans.ranges, expected_spans) + + # Compare text. + text, text_spans = tokenizer.detokenize_with_spans( + document.tokens.tokens, True, True, token_spans=document.loss_masking_spans.ranges + ) + Assert.eq(text, expected_text) + Assert.eq(text_spans, expected_text_spans) + + # Check some numerical values. + for index in DATASET_WITH_SPAN_SAMPLES: + Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) + Assert.eq(hf_dataset[index]["loss_masking_spans"], HF_LOSS_MASKING_SPANS[index]) + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) + ) + Assert.all_equal(document.tokens.tokens, DATASET_WITH_SPAN_SAMPLES[index]) + Assert.all_equal(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py deleted file mode 100644 index b11f84d9c..000000000 --- a/tests/data/test_memmap.py +++ /dev/null @@ -1,47 +0,0 @@ -import pathlib - -import pytest - -from fast_llm.data.dataset.config import MemmapDatasetConfig -from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH, DATASET_SAMPLING_CACHE, DATASET_WITH_SPANS_PATH - -MEMMAP_DATASET_LENGTH = 6153 -MEMMAP_DATASET_TOKENS = 508327 -MEMMAP_DATASET_SAMPLES = { - 9: [], - 10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], - 13: [78, 727, 74, 317, 1358, 89], - 15: [78], -} - - -@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_SAMPLING_CACHE) / "test_memmap")) -def test_gpt_memmap(cache_directory): - # Make sure the memmap dataset works and check for unintended changes in behavior. 
- get_test_dataset() - dataset = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build() - compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) - - -MEMMAP_DATASET_SPANS = { - 9: [], - 10: [(0, 1), (2, 6), (7, 9)], - 13: [(0, 1)], - 15: [], -} - - -def test_gpt_data_with_spans(): - get_test_dataset(DATASET_WITH_SPANS_PATH, max_spans=5) - dataset = get_dataset_config( - { - "type": "memmap", - "path": DATASET_WITH_SPANS_PATH, - }, - MemmapDatasetConfig, - ).build() - compare_indexed_dataset( - dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS - ) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py new file mode 100644 index 000000000..7b570c5a1 --- /dev/null +++ b/tests/data/test_preference_spans.py @@ -0,0 +1,105 @@ +import datasets +import numpy as np +import torch + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH +from tests.utils.dataset import get_test_dataset_with_preference_spans +from tests.utils.global_variables import TOKENIZER_NAME + +DATASET_WITH_PREFERENCE_SPAN_TOKENS = 62163 +DATASET_WITH_PREFERENCE_SPAN_TEXT = { + 27: ["`", "s,", "uh"], + 30: ["@v", "o{hf_dataset[index]["answer"]}<|endoftext|>", + ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py deleted file mode 100644 index 09a91d6a8..000000000 --- a/tests/data/test_prepare_gpt_memmap.py +++ /dev/null @@ -1,211 +0,0 @@ -import json -import pathlib -import tempfile - -import numpy as np -import pytest -import torch - -from fast_llm.data.dataset.config import IndexedDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES -from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter -from fast_llm.data.sample.range import RangeSample -from fast_llm.data.sample.token import TokenSample -from fast_llm.utils import Assert -from tests.data.common import MockGPTMemmapDatasetConfig # Noqa - - -def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator: - config = GPTMemmapDatasetPreparatorConfig.from_dict( - { - "output_path": output_path, - "dataset": {"path": dataset_path_name}, - "tokenizer": {"path": "no_tokenizer"}, - }, - {}, - ) - return config.get_dataset_preparator_class()(config=config) - - -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_dataset(dtype): - documents = [ - LanguageModelSample( - TokenSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) - ) - for _ in range(100) - ] - with tempfile.TemporaryDirectory() as temp_dir: - path = pathlib.Path(temp_dir) / "dataset" - MemmapDataset.write_dataset(path, documents, LanguageModelWriter) - dataset = MemmapDataset("dataset", path) - for 
i, document in enumerate(documents): - Assert.all_equal(dataset.get_document(i).tokens.tokens, document.tokens.tokens.to(torch.int64)) - - -def _generate_valid_span(max_seq_length) -> tuple[int, int]: - return tuple(np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist()) - - -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - documents = [ - LanguageModelSample( - TokenSample(torch.from_numpy(np.random.randint(1000, size=100).astype(dtype))), - None, - RangeSample([_generate_valid_span(100)], 100), - RangeSample([_generate_valid_span(100)], 100), - ) - for _ in range(50) - ] - with tempfile.TemporaryDirectory() as temp_dir: - path = pathlib.Path(temp_dir) / "dataset" - MemmapDataset.write_dataset(path, documents, LanguageModelWriter) - dataset = MemmapDataset("dataset", path) - parameters = GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True - ) - for i, document in enumerate(documents): - dataset_document = dataset.get_document(i, parameters=parameters) - Assert.all_equal(dataset_document.tokens.tokens, document.tokens.tokens.to(torch.int64)) - Assert.eq(dataset_document.chosen_spans.ranges, document.chosen_spans.ranges) - Assert.eq(dataset_document.rejected_spans.ranges, document.rejected_spans.ranges) - - -def test_load_metadata_from_hub(): - with tempfile.TemporaryDirectory(suffix="test") as local_folder: - get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata() - croissant_path = pathlib.Path(local_folder) / "croissant.json" - assert croissant_path.is_file() - metadata = json.load(croissant_path.open("r")) - assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1" - - -def test_absent_metadata_from_hub(): - with tempfile.TemporaryDirectory(suffix="test") as local_folder: - get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata() - assert not (pathlib.Path(local_folder) / "croissant.json").is_file() - - -def test_load_metadata_local(): - with ( - tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder, - tempfile.TemporaryDirectory(suffix="test") as local_folder, - ): - metadata = {"name": "test"} - json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w")) - get_preparator(local_folder, dataset_folder)._save_croissant_metadata() - croissant_path = pathlib.Path(local_folder) / "croissant.json" - assert croissant_path.is_file() - assert json.loads(croissant_path.open("r").read()) == metadata - - -def test_absent_metadata_local(): - with ( - tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder, - tempfile.TemporaryDirectory(suffix="test") as local_folder, - ): - get_preparator(local_folder, dataset_folder)._save_croissant_metadata() - assert not (pathlib.Path(local_folder) / "croissant.json").is_file() - - -DATASET_DICT_0 = { - "type": "mock_memmap", - "num_documents": 500, - "num_tokens_per_document": 300, -} -DATASET_DICT_1 = { - "type": "mock_memmap", - "num_documents": 1500, - "num_tokens_per_document": 100, -} - - -def test_split_dataset(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0], - [dataset_config_0], # Mock reader config. 
- {"training": 3, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": { - "type": "slice", - "dataset": dataset_config_0.to_dict(), - "begin": 0, - "end": 0.75, - }, - "validation": { - "type": "slice", - "dataset": dataset_config_0.to_dict(), - "begin": 0.75, - "end": 1, - }, - }, - ) - - -def test_split_datasets_0(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], - [dataset_config_0, dataset_config_1], # Mock reader configs. - {"training": 1, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": dataset_config_0.to_dict(), - "validation": dataset_config_1.to_dict(), - }, - ) - - -def test_split_datasets_1(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], - [dataset_config_0, dataset_config_1], # Mock reader configs. - {"training": 3, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": { - "type": "blended", - "datasets": [ - dataset_config_0.to_dict(), - { - "type": "slice", - "dataset": dataset_config_1.to_dict(), - "begin": 0, - "end": 0.5, - }, - ], - "weights": [2 / 3, 1 / 3], - }, - "validation": { - "type": "slice", - "dataset": dataset_config_1.to_dict(), - "begin": 0.5, - "end": 1, - }, - }, - ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 8e5c61904..7a31358b9 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -16,22 +16,16 @@ def test_gpt_random_dataset(): # Make sure the random dataset works and check for unintended changes in behavior. - sampled = get_dataset_config({"type": "random"}, GPTRandomDatasetConfig).build_and_sample( - get_sampling_data(4, sequence_length=7) + sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( + get_sampling_data(4, sequence_length=7, vocab_size=8192) ) compare_sampled_dataset(sampled, RANDOM_DATASET_EXPECTED_SAMPLES) - -def test_gpt_random_data(): + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "random", - } - } - }, + {"datasets": {"training": config}}, 4, sequence_length=7, + vocab_size=8192, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index c171d15dd..2d102be01 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,8 +2,8 @@ import pytest import torch -from fast_llm.data.dataset.config import MemmapDatasetConfig, ShufflingType -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.config import ShufflingType +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -14,8 +14,7 @@ get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH +from tests.utils.dataset import get_common_test_dataset try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -26,37 +25,28 @@ GPT_MEMMAP_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [3036, 253, 207, 2968, 4536, 1178], - [1178, 3291, 317, 277, 2679, 89], - [89, 542, 395, 583, 684, 554], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [58, 22885, 93, 37, 92, 76], + [76, 29, 19, 17365, 93, 46], + [46, 83, 17211, 1, 785, 1023], ] def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. - get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build_and_sample( - get_sampling_data(8, sequence_length=5) - ) + _, config, _ = get_common_test_dataset() + sampled = get_dataset_config( + dataset_config := config, GPTDatasetFromFileConfig[LanguageModelSample] + ).build_and_sample(get_sampling_data(8, sequence_length=5)) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) - -def test_gpt_sampled_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "memmap", - "path": DATASET_PATH, - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_MEMMAP_SAMPLES, @@ -169,7 +159,6 @@ def test_gpt_sample_padding(): sampling = get_sampling_data( num_samples=len(expected_samples), sequence_length=sequence_length, - vocab_size=vocab_size, seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 3a6b999cd..224b18270 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,67 +1,67 @@ from fast_llm.data.dataset.config import DatasetSliceConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( - compare_indexed_dataset, + compare_indexed_dataset_tokens, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH +from tests.data.test_preparator import COMMON_DATASET_SAMPLES +from tests.utils.dataset import get_common_test_dataset GPT_SLICE_TRAINING_SAMPLES = [ - [80, 268, 79, 260, 207, 3086], - [3086, 80, 413, 4872, 4602, 207], - [207, 7208, 1489, 776, 3514, 269], - [269, 73, 7367, 267, 477, 3126], + [49152, 20, 59, 81, 15, 54], + [54, 76, 7909, 44, 41, 1], + [1, 71, 28, 10, 42, 15963], + [15963, 80, 59, 86, 4, 74], ] GPT_SLICE_VALIDATION_SAMPLES = [ - [1886, 317, 5621, 3173, 330, 284], - [284, 2846, 706, 89, 80, 2047], - [2047, 207, 2449, 1423, 65, 985], - [985, 683, 4917, 87, 477, 481], - [481, 695, 947, 5871, 2344, 87], - [87, 489, 207, 489, 269, 356], - [356, 727, 7800, 4078, 243, 3712], - [3712, 86, 476, 80, 2547, 7390], + [49152, 3, 5621, 27, 7859, 13009], + [13009, 73, 32, 29, 32, 3], + [3, 89, 15, 45, 25, 75], + [75, 52, 13366, 88, 54, 19], + [19, 2, 74, 23, 92, 24747], + [24747, 42, 6, 477, 21, 47], + [47, 92, 31, 30, 463, 64], + [64, 23, 11, 56, 23555, 85], ] def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() + memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() # samples[9:18] dataset = get_dataset_config( - {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003}, + {"type": "slice", "dataset": memmap_config, "begin": 0.025, "end": 0.1}, DatasetSliceConfig[LanguageModelSample], ).build() - compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) + compare_indexed_dataset_tokens(dataset, 75, 3399, {i - 25: sample for i, sample in COMMON_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) validate_indexed_dataset_sampling(sampled, GPT_SLICE_VALIDATION_SAMPLES) - -def test_gpt_slice_data(): + # Test in data with multiple phases. 
get_test_data_and_compare_samples( { "datasets": { "training": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PATH}, + "dataset": memmap_config, "begin": 0, - "end": 0.0015, + "end": 0.025, }, "validation": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PATH}, - "begin": 0.0015, - "end": 0.003, + "dataset": memmap_config, + "begin": 0.025, + "end": 0.1, }, "test": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PATH}, - "begin": 0.003, + "dataset": memmap_config, + "begin": 0.1, "end": 1, }, } diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 489f5e1c1..3a90745eb 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -61,12 +62,12 @@ def reference_dpo_loss( def test_dpo_loss(): - torch.manual_seed(0) - logits = torch.randn((10, 50, 100), requires_grad=True) - reference_model_logits = torch.randn((10, 50, 100)) - targets = torch.randint(0, 100, (10, 50)) + random_state = np.random.RandomState(0) + logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32).requires_grad_() + reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) + targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) - spans = get_random_spans(10, 10, 50) + spans = get_random_spans(10, 10, 50, random_state) fastllm_loss, fast_llm_grad = compute_dpo_loss( logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 4b057dabd..f3ce65966 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,8 +3,10 @@ import struct import typing +import datasets import numpy as np import pytest +import torch import yaml from fast_llm.config import Field, FieldHint, config_class @@ -13,13 +15,15 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import get_test_dataset_samples +from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -39,7 +43,17 @@ def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): and prefix.with_suffix(".bin").is_file() and prefix.parent.joinpath("fast_llm_config.yaml").is_file() ): - MegatronMemmapDataset.write_dataset(prefix, get_test_dataset_samples(vocab_size=MODEL_TEST_VOCAB_SIZE)) + _, _, hf_path = get_common_test_dataset() + hf_dataset = datasets.load_from_disk(hf_path)["train"] + tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() + samples = [ + LanguageModelSample( + 
TokenSample((tokenizer.tokenize(document["text"]) % MODEL_TEST_VOCAB_SIZE).to(torch.uint16)) + ) + for document in hf_dataset + ] + + MegatronMemmapDataset.write_dataset(prefix, samples) yaml.safe_dump( {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 7f2c9290a..b21bda1ea 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -1,24 +1,12 @@ import pathlib -import random +import typing +import datasets import numpy as np -import torch -import yaml - -from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter -from fast_llm.data.sample.range import RangeSample -from fast_llm.data.sample.token import TokenSample -from tests.utils.global_variables import ( - DATASET_PATH, - MODEL_DATASET_PATH, - MODEL_TEST_VOCAB_SIZE, - TEST_CHARACTERS, - TEST_DATASET_TOKENS, - TEST_VOCAB_SIZE, - TOKENIZER_FILE, - TOKENIZER_PATH, -) + +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.utils import padded_cumsum +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH def download_santacoder_tokenizer(): @@ -28,69 +16,165 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) -def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int, seed: int = 0): - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths, [num_samples, max_spans * 2])) - spans = [np.unique(sample_spans).tolist() for sample_spans in spans] +def get_random_spans( + num_documents: int, + max_spans: int, + lengths: np.ndarray | int, + random_state: np.random.RandomState = np.random, + use_last_format: bool = False, + variable_length: bool = True, +): + if variable_length: + spans = random_state.randint( + 0, lengths[:, None] if isinstance(lengths, np.ndarray) else lengths, [num_documents, max_spans * 2] + ) + else: + spans = [ + random_state.choice(range(length), max_spans * 2, replace=False) + for length in (lengths if isinstance(lengths, np.ndarray) else (lengths for _ in range(num_documents))) + ] + spans = [np.unique(sample_spans).tolist() for sample_spans in np.sort(spans)] return [ - [(begin, end) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] + [(begin, end - use_last_format) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] for sample_spans in spans ] -def get_test_dataset_samples( - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - max_spans: int = 0, -) -> list[LanguageModelSample]: - import transformers +def get_random_preference_spans(texts, random_state: np.random.RandomState = np.random) -> dict[str, str]: + texts_ = [] + chosen_spans = [] + rejected_spans = [] + for text in texts: + # Split in three non-empty_chunks + splits = np.sort(random_state.choice(range(1, len(text) - 1), 2, replace=False)).tolist() + texts_.append(text[: splits[0]]) + chosen_spans.append(text[splits[0] : splits[1]]) + rejected_spans.append(text[splits[1] :]) + return {"text": texts_, "chosen_span": chosen_spans, "rejected_span": rejected_spans} - download_santacoder_tokenizer() - texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = 
transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) +def _get_hf_test_dataset( + seed: int = 1234, + num_documents: int = 1000, + min_document_size: int = 5, + max_document_size: int = 100, + max_loss_masking_spans: int = 0, + has_preference_spans: bool = False, +): + random_state = np.random.RandomState(seed) + # Generate random document sizes (character count). + document_sizes = random_state.randint(min_document_size, max_document_size, num_documents) + size_cumsums = padded_cumsum(document_sizes) + # Generate random ascii characters. + random_text = random_state.randint(32, 127, document_sizes.sum(), dtype=np.uint8).tobytes().decode() + texts = [random_text[begin:end] for begin, end in zip(size_cumsums[:-1], size_cumsums[1:])] + + if has_preference_spans: + dataset_dict = get_random_preference_spans(texts, random_state) + else: + dataset_dict: dict[str, typing.Any] = {"text": texts} + + if max_loss_masking_spans > 0: + dataset_dict["loss_masking_spans"] = get_random_spans( + num_documents, max_loss_masking_spans, document_sizes, random_state, use_last_format=True + ) - samples = [ - LanguageModelSample( - TokenSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)), + return datasets.Dataset.from_dict(dataset_dict) + + +def _get_test_dataset( + path: pathlib.Path, + seed: int, + tokenizer_path: str = TOKENIZER_PATH, + vocab_size: int | None = None, + documents_per_shard: int = 10**6, + num_documents: int = 1000, + min_document_size: int = 5, + max_document_size: int = 100, + max_loss_masking_spans: int = 0, + has_preference_spans: bool = False, + splits: dict[str, float] | None = None, +): + config_paths = ( + [path / "fast_llm_config.yaml"] + if splits is None + else [path / f"fast_llm_config_{split}.yaml" for split in splits] + ) + hf_path = path / "hf" + + if not (path.is_file() and all(config_path.is_file() for config_path in config_paths)): + dataset = _get_hf_test_dataset( + seed, num_documents, min_document_size, max_document_size, max_loss_masking_spans, has_preference_spans ) - for document in texts - ] - if max_spans > 0: - spans = get_random_spans( - len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed + datasets.DatasetDict({"train": dataset}).save_to_disk(hf_path) + source_schema = {"text": "text"} + if max_loss_masking_spans > 0: + source_schema["loss_masking_spans"] = "loss_masking_spans" + if has_preference_spans: + source_schema["chosen_span"] = "chosen_span" + source_schema["rejected_span"] = "rejected_span" + + download_santacoder_tokenizer() + preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( + { + "dataset": { + "path": hf_path, + "load_from_disk": True, + "source_schema": source_schema, + }, + "tokenizer": {"path": tokenizer_path, "max_vocab_size": vocab_size}, + "output_path": path, + "documents_per_shard": documents_per_shard, + "splits": splits, + } ) - for sample, sample_spans in zip(samples, spans, strict=True): - sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) - return samples + preparator_config.run() + config = ( + {"type": "file", "path": config_paths[0]} + if splits is None + else { + split: {"type": "file", "path": config_path} + for split, config_path in zip(splits, config_paths, strict=True) + } + ) + return path, config, hf_path -def get_test_dataset( - path: pathlib.Path = DATASET_PATH, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - max_spans: 
int = 0, -): - config_path = path.parent.joinpath("fast_llm_config.yaml") - - if not (path.is_file() and config_path.is_file()): - samples = get_test_dataset_samples( - seed=seed, - num_tokens=num_tokens, - characters=characters, - vocab_size=vocab_size, - max_spans=max_spans, - ) - MemmapDataset.write_dataset(path, samples, LanguageModelWriter) - yaml.safe_dump({"type": "memmap", "path": path.name}, config_path.open("w")) +def get_common_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "common_dataset", seed=1234) -def get_model_test_dataset( - path: pathlib.Path = MODEL_DATASET_PATH, - vocab_size: int = MODEL_TEST_VOCAB_SIZE, -): - return get_test_dataset(path, vocab_size=vocab_size) +def get_alt_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "other_dataset", seed=2345) + + +def get_sharded_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "common_dataset_sharded", seed=1234, documents_per_shard=350) + + +def get_split_test_dataset(): + return _get_test_dataset( + DATASET_CACHE / "common_dataset_split", seed=1234, splits={"training": 1, "validation": 1} + ) + + +def get_split_sharded_test_dataset(): + return _get_test_dataset( + DATASET_CACHE / "common_dataset_split_sharded", + seed=1234, + documents_per_shard=350, + splits={"training": 1, "validation": 1}, + ) + + +def get_test_dataset_with_loss_masking_spans(): + return _get_test_dataset(DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, max_loss_masking_spans=5) + + +def get_test_dataset_with_preference_spans(): + return _get_test_dataset(DATASET_CACHE / "dataset_with_preference_spans", seed=1234, has_preference_spans=True) + + +def get_model_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "model_dataset", seed=1234, vocab_size=MODEL_TEST_VOCAB_SIZE) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index ea770be0a..20a0c7219 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -5,7 +5,6 @@ import os import pathlib -import string from fast_llm.utils import set_global_variables @@ -36,14 +35,11 @@ def set_testing_global_variables(): # TODO: Fixtures TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +TOKENIZER_NAME = "bigcode/santacoder" + DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PATH = DATASET_CACHE / "common_dataset/dataset.fast_llm_dataset" -DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans/dataset.fast_llm_dataset" -DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" -TEST_VOCAB_SIZE = 8192 -# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset/dataset.fast_llm_dataset" +MODEL_DATASET_SHARD_PATH = DATASET_CACHE / "model_dataset/shard_0_0.fast_llm_dataset" + +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ee9c2b730..956aaea5a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -22,7 +22,7 @@ Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import MODEL_DATASET_SHARD_PATH, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # 
isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -234,18 +234,18 @@ def _update_and_add_testing_config( "data": { "datasets": { "training": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "end": 0.969, }, "validation": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "begin": 0.969, "end": 0.999, }, "test": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "begin": 0.999, "end": 1, From 435d21491acb8357a8b49377c7c809e8b8d703d1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Nov 2025 21:32:10 -0500 Subject: [PATCH 17/45] fix --- fast_llm/data/tokenizer.py | 88 -------------------------------------- 1 file changed, 88 deletions(-) delete mode 100644 fast_llm/data/tokenizer.py diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py deleted file mode 100644 index 71219a2bf..000000000 --- a/fast_llm/data/tokenizer.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import torch -from transformers import AutoTokenizer - -from fast_llm.data.config import TokenizerConfig -from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.utils import Assert - - -class Tokenizer: - """ - A wrapper around Huggingface (transformers) tokenizer. - """ - - def __init__(self, config: TokenizerConfig): - log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=config.path, - errors="replace", - max_len=None, - trust_remote_code=True, - use_fast=True, - ) - if config.bos_token is not None: - self.tokenizer.bos_token = config.bos_token - if self.tokenizer.eos_token_id is None: - raise ValueError("Tokenizer does not have an EOS token.") - if self.tokenizer.bos_token_id is None: - raise ValueError("Tokenizer does not have an BOS token.") - self.eod_id = self.tokenizer.eos_token_id - self.bod_id = self.tokenizer.bos_token_id - - @property - def vocab_size(self) -> int: - return len(self.tokenizer) - - @property - def vocab(self) -> dict[str, int]: - return self.tokenizer.vocab - - @property - def inv_vocab(self) -> dict[int, str]: - return self._inv_vocab - - def tokenize(self, text: str, begin: bool = True, end: bool = True) -> list[int]: - return ( - ([self.bod_id] if begin else []) - + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.eod_id] if end else []) - ) - - def tokenize_with_spans( - self, text: str, begin: bool = True, end: bool = True, *, spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
- """ - if not spans: - return self.tokenize(text, begin, end), [] - input_ids, token_splits = self.tokenize_with_splits( - text, begin, end, text_splits=[split for splits in spans for split in splits] - ) - return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)] - - def tokenize_with_splits( - self, text: str, begin: bool = True, end: bool = True, *, text_splits: list[int] - ) -> tuple[list[int], list[int]]: - Assert.eq(sorted(text_splits), text_splits) - input_ids = [] - text_splits = [0, *text_splits, len(text_splits)] - token_splits = [] - - for split_begin, split_end in zip(text_splits[:-1], text_splits[1:]): - input_ids.extend( - self.tokenize( - text[split_begin:split_end], begin=begin and split_begin == 0, end=end and split_end == len(text) - ) - ) - token_splits.append(len(input_ids)) - - return input_ids, token_splits[:-1] - - def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: - return self.tokenizer.decode(token_ids) - - @property - def eod(self): - return self.eod_id From f6bef55fb25d4c0c85f3bde2763e2ec55baaf416 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Nov 2025 19:22:49 -0500 Subject: [PATCH 18/45] fix --- tests/utils/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b21bda1ea..ba19916ee 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -102,7 +102,7 @@ def _get_test_dataset( ) hf_path = path / "hf" - if not (path.is_file() and all(config_path.is_file() for config_path in config_paths)): + if not all(config_path.is_file() for config_path in config_paths): dataset = _get_hf_test_dataset( seed, num_documents, min_document_size, max_document_size, max_loss_masking_spans, has_preference_spans ) From e05d9a1d0bc6fbfcdb7229982623653d7f1a7082 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Nov 2025 19:47:25 -0500 Subject: [PATCH 19/45] fix --- fast_llm/data/auto.py | 12 ++++++++++++ fast_llm/models/auto.py | 1 + 2 files changed, 13 insertions(+) diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index c44e538fa..22ab3d731 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -2,4 +2,16 @@ Import these submodules to ensure classes are added to the dynamic class registry. """ +from fast_llm.data.dataset.config import ( # isort: skip + BlendedDatasetConfig, + ConcatenatedDatasetConfig, + DatasetSliceConfig, + MemmapDatasetConfig, + SampledDatasetUpdateConfig, +) +from fast_llm.data.dataset.gpt.config import ( # isort: skip + GPTDatasetFromFileConfig, + GPTFimSampledDatasetConfig, + GPTRandomDatasetConfig, +) from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 322932664..414314627 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,6 +2,7 @@ Import these submodules to ensure classes are added to the dynamic class registry. 
""" +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip From 9ba8d1bb6aaf7008cf5d5d9ded24ae19b9795233 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Nov 2025 19:59:34 -0500 Subject: [PATCH 20/45] fix --- fast_llm/data/dataset/gpt/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 7583345c3..2334d1173 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -9,12 +9,12 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters from fast_llm.data.preprocessing.tokenizer import TokenizerConfig -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset + from fast_llm.data.sample.language_model import LanguageModelSample @dataclasses.dataclass(kw_only=True) From b35b297678bba8672e9c81ae52e135ccbd6382eb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 6 Nov 2025 02:40:20 -0500 Subject: [PATCH 21/45] fixes --- fast_llm/data/dataset/gpt/config.py | 1 + tests/data/test_loss_masking_spans.py | 34 +++++++------ tests/data/test_preference_spans.py | 6 ++- tests/data/test_preparator.py | 2 +- tests/utils/dataset.py | 70 ++++++++++++++++++--------- 5 files changed, 70 insertions(+), 43 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 2334d1173..8dd4098a3 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -25,6 +25,7 @@ class GPTSamplingParameters(SamplingParameters): # TODO: Only used for random dataset. Remove? Or use as safety check? 
vocab_size: int | None = None + # TODO: ====== Get these to memmap dataset (currently ignored) ====== use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index 521eaf2a9..443a26819 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -1,4 +1,5 @@ import datasets +import pytest from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters from fast_llm.data.dataset.memmap import MemmapDataset @@ -10,30 +11,31 @@ from tests.utils.dataset import get_test_dataset_with_loss_masking_spans from tests.utils.global_variables import TOKENIZER_NAME -DATASET_WITH_SPAN_TOKENS = 46199 +DATASET_WITH_SPAN_TOKENS = 45577 DATASET_WITH_SPAN_SAMPLES = { - 27: [49152, 63, 82, 11, 84, 71, 49152], - 30: [49152, 31, 85, 78, 27, 34, 46, 62, 43, 49152], + 27: [49152, 63, 82, 11, 27799, 49152], + 30: [49152, 31, 85, 78, 27, 1448, 62, 43, 49152], 31: [49152, 60, 55, 80, 30, 85, 22, 18, 49152], 77: [49152, 73, 80, 85, 52, 22, 46, 5, 88, 78, 49152], - 87: [49152, 52, 89, 75, 11, 71, 49152], + 87: [49152, 52, 42536, 11, 71, 49152], } HF_LOSS_MASKING_SPANS = { - 27: [[0, 1], [3, 3]], - 30: [[0, 0], [2, 2], [5, 5]], - 31: [[0, 0], [2, 2], [4, 4]], - 77: [[0, 0], [3, 5], [7, 7]], - 87: [[1, 1], [3, 3]], + 27: [[0, 1]], + 30: [[0, 1]], + 31: [[0, 0], [2, 2], [5, 5]], + 77: [[0, 0], [2, 2], [5, 5], [7, 7]], + 87: [[0, 0], [3, 3]], } TOKEN_LOSS_MASKING_SPANS = { - 27: [(1, 3), (4, 5)], - 30: [(1, 2), (3, 4), (6, 7)], - 31: [(1, 2), (3, 4), (5, 6)], - 77: [(1, 2), (4, 7), (8, 9)], - 87: [(2, 3), (4, 5)], + 27: [(1, 3)], + 30: [(1, 3)], + 31: [(1, 2), (3, 4), (6, 7)], + 77: [(1, 2), (3, 4), (6, 7), (8, 9)], + 87: [(1, 2), (3, 4)], } +@pytest.mark.slow def test_gpt_data_with_spans(): _, config, hf_path = get_test_dataset_with_loss_masking_spans() dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() @@ -74,5 +76,5 @@ def test_gpt_data_with_spans(): document = dataset.get_document( index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) ) - Assert.all_equal(document.tokens.tokens, DATASET_WITH_SPAN_SAMPLES[index]) - Assert.all_equal(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) + Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_SPAN_SAMPLES[index]) + Assert.eq(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py index 7b570c5a1..ef18337eb 100644 --- a/tests/data/test_preference_spans.py +++ b/tests/data/test_preference_spans.py @@ -1,5 +1,6 @@ import datasets import numpy as np +import pytest import torch from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters @@ -36,6 +37,7 @@ } +@pytest.mark.slow def test_gpt_data_with_spans(): _, config, hf_path = get_test_dataset_with_preference_spans() dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() @@ -101,5 +103,5 @@ def test_gpt_data_with_spans(): document = dataset.get_document( index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) ) - Assert.all_equal(document.tokens.tokens, DATASET_WITH_PREFERENCE_SPAN_SAMPLES[index]) - Assert.all_equal(document.chosen_spans.ranges + document.rejected_spans.ranges, 
TOKEN_PREFERENCE_SPANS[index]) + Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_PREFERENCE_SPAN_SAMPLES[index]) + Assert.eq(document.chosen_spans.ranges + document.rejected_spans.ranges, TOKEN_PREFERENCE_SPANS[index]) diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 235135156..729888d9c 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -72,7 +72,7 @@ def test_common_prepared_dataset(): for index in COMMON_DATASET_SAMPLES: Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) document = dataset.get_document(index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0)) - Assert.all_equal(document.tokens.tokens, COMMON_DATASET_SAMPLES[index]) + Assert.eq(document.tokens.tokens.tolist(), COMMON_DATASET_SAMPLES[index]) @pytest.mark.slow diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index ba19916ee..28d28bd94 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -16,27 +16,45 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) +def get_random_text( + num_documents: int = 1000, + min_document_size: int = 5, + max_document_size: int = 99, + random_state: np.random.RandomState = np.random, +): + # Randomize document sizes + document_sizes = random_state.randint(min_document_size, max_document_size + 1, num_documents) + size_cumsums = padded_cumsum(document_sizes) + # Generate random ascii characters. + random_text = random_state.randint(32, 127, document_sizes.sum(), dtype=np.uint8).tobytes().decode() + # Gather text by documents. + texts = [ + random_text[size_cumsums[document_index] : size_cumsums[document_index + 1]] + for document_index in range(num_documents) + ] + return texts, document_sizes + + def get_random_spans( - num_documents: int, + document_sizes: np.ndarray, + min_spans: int, max_spans: int, - lengths: np.ndarray | int, random_state: np.random.RandomState = np.random, use_last_format: bool = False, - variable_length: bool = True, ): - if variable_length: - spans = random_state.randint( - 0, lengths[:, None] if isinstance(lengths, np.ndarray) else lengths, [num_documents, max_spans * 2] - ) - else: - spans = [ - random_state.choice(range(length), max_spans * 2, replace=False) - for length in (lengths if isinstance(lengths, np.ndarray) else (lengths for _ in range(num_documents))) - ] - spans = [np.unique(sample_spans).tolist() for sample_spans in np.sort(spans)] + # Randomize span counts. Actual count may be lower for small documents. + span_counts = random_state.randint(min_spans, max_spans + 1, len(document_sizes)) + # Generate random spans. return [ - [(begin, end - use_last_format) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] - for sample_spans in spans + [ + (begin, end - use_last_format) + for begin, end in np.sort( + random_state.choice(range(length), min(num_spans, length // 2) * 2, replace=False) + ) + .reshape([-1, 2]) + .tolist() + ] + for length, num_spans in zip(document_sizes, span_counts, strict=True) ] @@ -57,17 +75,14 @@ def _get_hf_test_dataset( seed: int = 1234, num_documents: int = 1000, min_document_size: int = 5, - max_document_size: int = 100, + max_document_size: int = 99, + min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, ): random_state = np.random.RandomState(seed) # Generate random document sizes (character count). 
- document_sizes = random_state.randint(min_document_size, max_document_size, num_documents) - size_cumsums = padded_cumsum(document_sizes) - # Generate random ascii characters. - random_text = random_state.randint(32, 127, document_sizes.sum(), dtype=np.uint8).tobytes().decode() - texts = [random_text[begin:end] for begin, end in zip(size_cumsums[:-1], size_cumsums[1:])] + texts, document_sizes = get_random_text(num_documents, min_document_size, max_document_size, random_state) if has_preference_spans: dataset_dict = get_random_preference_spans(texts, random_state) @@ -76,7 +91,7 @@ def _get_hf_test_dataset( if max_loss_masking_spans > 0: dataset_dict["loss_masking_spans"] = get_random_spans( - num_documents, max_loss_masking_spans, document_sizes, random_state, use_last_format=True + document_sizes, min_loss_masking_spans, max_loss_masking_spans, random_state, use_last_format=True ) return datasets.Dataset.from_dict(dataset_dict) @@ -90,7 +105,8 @@ def _get_test_dataset( documents_per_shard: int = 10**6, num_documents: int = 1000, min_document_size: int = 5, - max_document_size: int = 100, + max_document_size: int = 99, + min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, splits: dict[str, float] | None = None, @@ -104,7 +120,13 @@ def _get_test_dataset( if not all(config_path.is_file() for config_path in config_paths): dataset = _get_hf_test_dataset( - seed, num_documents, min_document_size, max_document_size, max_loss_masking_spans, has_preference_spans + seed=seed, + num_documents=num_documents, + min_document_size=min_document_size, + max_document_size=max_document_size, + min_loss_masking_spans=min_loss_masking_spans, + max_loss_masking_spans=max_loss_masking_spans, + has_preference_spans=has_preference_spans, ) datasets.DatasetDict({"train": dataset}).save_to_disk(hf_path) source_schema = {"text": "text"} From abe23579fa7bb01181234741d015fb1bd5ed0e54 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Nov 2025 20:04:01 -0500 Subject: [PATCH 22/45] misc --- fast_llm/data/sample/abstract.py | 5 +--- fast_llm/data/sample/range.py | 13 +++++++---- fast_llm/data/sample/token.py | 40 ++++++++++++++++---------------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index aaa321efd..11f5d187c 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -37,11 +37,8 @@ def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: pass @abc.abstractmethod - def to_samples(self) -> list[Sample]: - pass - def crop(self, begin: int, end: int) -> typing.Self: - return self.from_samples(sample.crop(begin, end) for sample in self.to_samples()) + pass def to_device_(self, device: "torch.device | str"): pass diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index c3a035376..b7be4efe1 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -15,6 +15,11 @@ from fast_llm.utils import Assert, get_unique +def crop_ranges(ranges: list[tuple[int, int]], begin: int, end: int) -> list[tuple[int, int]]: + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in ranges) + return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] + + class RangeSample(Sample): """ A reusable component holding a set of ranges in a sample. 
@@ -36,9 +41,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
         return cls(ranges, sample_size)

     def crop(self, begin: int, end: int) -> typing.Self:
-        sample_size = end - begin
-        cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges)
-        return self.__class__([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)
+        return self.__class__(crop_ranges(self.ranges, begin, end), end - begin)

     def __len__(self) -> int:
         return self.sample_size
@@ -56,8 +59,8 @@ def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int):
     def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self:
         return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples))

-    def to_samples(self) -> list[RangeSample]:
-        return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges]
+    def crop(self, begin: int, end: int) -> typing.Self:
+        return self.__class__([crop_ranges(sample_ranges, begin, end) for sample_ranges in self.ranges], end - begin)


 @config_class(dynamic_type={MemmapReaderBaseConfig: "range"})
diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py
index 706b5053a..0944f5689 100644
--- a/fast_llm/data/sample/token.py
+++ b/fast_llm/data/sample/token.py
@@ -16,6 +16,23 @@
 from fast_llm.utils import Assert


+def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]:
+    if len(lengths) == 1:
+        # Shortcut for the frequent case of a single document.
+        return [end - begin]
+    begin_ = 0
+    cropped_lengths = []
+    for length in lengths:
+        end_ = begin_ + length
+        cropped_length = min(end_, end) - max(begin_, begin)
+        if cropped_length > 0:
+            cropped_lengths.append(cropped_length)
+        if end_ > end:
+            break
+        begin_ = end_
+    return cropped_lengths
+
+
 class TokenSample(Sample):
     def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None):
         self.tokens = tokens
@@ -34,22 +51,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
         )

     def crop(self, begin: int, end: int) -> typing.Self:
-        sample_size = end - begin
-        if self.lengths == [len(self.tokens)]:
-            # Shortcut for the frequent case of a single document.
- lengths = [sample_size] - else: - begin_ = 0 - lengths = [] - for length in self.lengths: - end_ = begin_ + length - cropped_length = min(end_, end) - max(begin_, begin) - if cropped_length > 0: - lengths.append(cropped_length) - if end_ > end: - break - begin_ = end_ - return self.__class__(self.tokens[begin:end], lengths) + return self.__class__(self.tokens[begin:end], crop_lengths(self.lengths, begin, end)) def __len__(self) -> int: return len(self.tokens) @@ -72,12 +74,10 @@ def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: [sample.lengths for sample in samples], ) - def to_samples(self) -> list[TokenSample]: - return [TokenSample(tokens, lengths) for tokens, lengths in zip(self.tokens, self.lengths, strict=True)] - def crop(self, begin: int, end: int) -> typing.Self: return self.__class__( - self.tokens[:, begin:end], [sample.crop(begin, end).lengths for sample in self.to_samples()] + self.tokens[:, begin:end], + [crop_lengths(lengths, begin, end) for lengths in self.lengths], ) def to_device_(self, device: "torch.device | str"): From 1801d873d49c4056927d3d61e5808153c7f6e896 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Nov 2025 20:53:39 -0500 Subject: [PATCH 23/45] fix --- tests/functional/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3a90745eb..05fafe7a9 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -67,7 +67,7 @@ def test_dpo_loss(): reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) - spans = get_random_spans(10, 10, 50, random_state) + spans = get_random_spans(np.full(10, 50), 0, 10, random_state) fastllm_loss, fast_llm_grad = compute_dpo_loss( logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 From 2223b85fd195a0e02601b9274b484ab1dea967f5 Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 13 Nov 2025 13:40:52 +0000 Subject: [PATCH 24/45] fix right stage mode --- fast_llm/engine/training/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index aa4f2d570..9737546ad 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -203,7 +203,9 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the model. 
with torch.no_grad(): log_main_rank("Setting up model...") - self._multi_stage.setup(distributed) + self._multi_stage.setup( + distributed, mode=StageMode.inference if self._is_evaluation_only else StageMode.training + ) for name, reference_model in self._reference_models.items(): log_main_rank(f"Setting up `{name}` reference model...") reference_model.fast_llm_model.setup(distributed, StageMode.inference) From a9a4ace43151d228df2ae7ef1a571c32ff51f961 Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 13 Nov 2025 13:42:01 +0000 Subject: [PATCH 25/45] newer transformers fixes --- .../apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py | 7 ++----- .../diffusion_llama/modeling_diffusion_llama.py | 2 +- fast_llm_external_models/mtp_llama/modeling_mtp_llama.py | 7 ++----- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 40c4cfa87..a80c031aa 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import TransformersKwargs, logging from transformers.utils.generic import ModelOutput from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -1252,9 +1252,6 @@ def forward( return output -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - class AprielHybridSSMPreTrainedModel(PreTrainedModel): config_class = AprielHybridSSMConfig base_model_prefix = "model" @@ -1383,7 +1380,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py index c8723af5d..a67a302ef 100644 --- a/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py +++ b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py @@ -706,7 +706,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs, # TODO: Kwargs for Diffusion? : Unpack[KwargsForCausalLM], + **kwargs, # TODO: Kwargs for Diffusion? 
: Unpack[TransformersKwargs], ) -> MaskedLMOutput: r""" # TODO: Update docstring for diffusion diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index 5ad99ff96..d0e1988f1 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -15,7 +15,7 @@ from transformers.processing_utils import Unpack from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( - LossKwargs, + TransformersKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging, @@ -761,9 +761,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -812,7 +809,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): From 97f2b60e9683d31386f6f9b9451171047aca7be3 Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 13 Nov 2025 13:43:43 +0000 Subject: [PATCH 26/45] fix distributed tests skip on single gpu --- tests/models/test_checkpoint.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 3c3bfb833..c5a5e1c5b 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -379,6 +379,9 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_ # TODO: Test beyond 2 gpu configs? import tests.models.distributed_test_checkpoint + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2") + script = [ "-m", tests.models.distributed_test_checkpoint.__name__, @@ -405,6 +408,7 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: @requires_cuda +# NOTE: Should it depend on test_model_distributed instead? 
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( @@ -425,6 +429,10 @@ def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config = distributed_save_load_config.resolve( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) + if torch.cuda.device_count() < distributed_save_load_config.num_gpus: + pytest.skip( + f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}" + ) report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus) load_and_compare_checkpoints( DistributedCheckpointFormat, From 0fdc978eebed4d0c8d323179865dfdb227108670 Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 13 Nov 2025 13:44:27 +0000 Subject: [PATCH 27/45] set mamba 2 style model conversions to broke --- tests/utils/model_configs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7b..075530bbb 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -637,7 +637,8 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, @@ -650,7 +651,7 @@ def _update_and_add_testing_config( ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) - +# TODO: remove obsolete model _update_and_add_testing_config( # Tests hybrid discrete Mamba 2. 
"llama", @@ -682,7 +683,7 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement From 224c2ec7fde6cea6d0a1fab12864b429febdde6b Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 17 Nov 2025 07:38:02 +0000 Subject: [PATCH 28/45] mmaba2 enable conversion tests --- tests/utils/model_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 075530bbb..3e2748aa9 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -638,7 +638,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, From 00bba27b9834befeb56381a3455ca1638f3102ec Mon Sep 17 00:00:00 2001 From: bigximik Date: Sun, 23 Nov 2025 16:20:13 +0000 Subject: [PATCH 29/45] added model_and_sequence_data_group --- fast_llm/engine/distributed/config.py | 46 +++++++++++++++++++--- fast_llm/engine/distributed/distributed.py | 3 ++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 602c44a4e..4bfefbc04 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -1,3 +1,5 @@ +import collections +import copy import dataclasses import enum import logging @@ -97,6 +99,7 @@ class DistributedDimNames: sequence_data = "sequence_data" batch_data = "batch_data" tensor_and_sequence_data = "tensor_and_sequence_data" + model_and_sequence_data = "model_and_sequence_data" @config_class() @@ -242,6 +245,17 @@ class DistributedConfig(Config): ) def _validate(self) -> None: + self._init_ranks_and_sizes() + self._init_distributed_dims() + + super()._validate() + + if self.reference_config is not None: + self.compare(self.reference_config, ValueError) + Assert.in_range(self.rank, 0, self.world_size) + Assert.in_range(self.local_rank, 0, self.local_world_size) + + def _init_ranks_and_sizes(self): if self.world_size is None: self.world_size = self.default_world_size if self.rank is None: @@ -272,6 +286,7 @@ def _validate(self) -> None: if self.tensor_parallel == 1 and self.sequence_tensor_parallel: self.sequence_tensor_parallel = False + def _init_distributed_dims(self): if self.reference_config is not None: self.reference_config.validate() if self.reference_config.reference_config is not None: @@ -343,12 +358,33 @@ def _validate(self) -> None: ) ) - super()._validate() + rank, global_ranks = self._get_model_and_sequence_data_rank_and_global_ranks() + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.model_and_sequence_data, + size=self.sequence_data_parallel * self.model_parallel, + rank=rank, + global_ranks=global_ranks, + ) + ) - if self.reference_config 
is not None: - self.compare(self.reference_config, ValueError) - Assert.in_range(self.rank, 0, self.world_size) - Assert.in_range(self.local_rank, 0, self.local_world_size) + def _get_model_and_sequence_data_rank_and_global_ranks(self) -> tuple[int, tuple[int]]: + # NOTE: The mapping from global ranks to batch-data-parallel groups is not easily + # expressible with a simple arithmetic pattern (e.g., fixed striding). To determine + # the grouping, we simulate rank initialization for every possible rank and record + # how ranks are assigned. This lets us compute: + # - `rank`: the index of the current rank within its batch-data-parallel group + # - the full list of global ranks that belong to this batch-data-parallel group. + + cfg = copy.copy(self) + batch_data_groups = collections.defaultdict(list) + for i in range(self.world_size): + cfg.rank = i + cfg._init_ranks_and_sizes() + if i == self.rank: + rank = len(batch_data_groups[cfg.batch_data_rank]) + batch_data_groups[cfg.batch_data_rank].append(i) + return rank, tuple(batch_data_groups[self.batch_data_rank]) def _get_global_ranks(self, size: int, stride: int) -> range: start = self.rank // (size * stride) * size * stride + self.rank % stride diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 2e2f9d401..85e1b29fa 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -174,6 +174,9 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.tensor_and_sequence_data_group = self.add_group( self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] ) + self.model_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.model_and_sequence_data] + ) self._config.log_first_rank(f"Setting random seeds...") From 5b20276ed84400042fb5cbab2b425f9a34070922 Mon Sep 17 00:00:00 2001 From: bigximik Date: Sun, 23 Nov 2025 16:22:43 +0000 Subject: [PATCH 30/45] added Iterable dataset base classes --- fast_llm/data/dataset/abstract.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 33942708b..f8705d470 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -43,8 +43,30 @@ def __len__(self) -> int: pass +class SampledIterableDataset[SampleType: Sample](Dataset[SampleType], typing.Iterable[SampleType]): + """ + A sampled dataset class that provides an iterator over samples. + (See the `Sampler` class below.) 
+ """ + + @abc.abstractmethod + def __iter__(self) -> typing.Iterator[SampleType]: + """Return an iterator over samples.""" + + class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass + + +class SamplableIterableDataset[SampleType: Sample](Dataset[SampleType]): + + @abc.abstractmethod + def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]: + pass + + @abc.abstractmethod + def __iter__(self) -> typing.Iterator[SampleType]: + """Return an iterator over documents or samples.""" From 978a68f20e7edeb83406e88441cb8c90c126ab2f Mon Sep 17 00:00:00 2001 From: bigximik Date: Sun, 23 Nov 2025 16:23:57 +0000 Subject: [PATCH 31/45] added naive sampled iterable dataset --- fast_llm/data/dataset/sampled.py | 58 +++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index d51a68746..7a11be6a4 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -8,7 +8,7 @@ import torch import yaml -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset, SampledIterableDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample @@ -429,3 +429,59 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch + + +class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): + def __init__( + self, + iterable_dataset: SamplableIterableDataset[SampleType], + sampling: SamplingData, + ): + self._dataset = iterable_dataset + self._config = sampling + assert self._config.parameters.truncate_documents == False + assert self._config.config.shuffle == ShufflingType.disabled + + def __iter__(self) -> typing.Iterator[SampleType]: + sample_length = self._config.parameters.sequence_length + self._config.parameters.extra_tokens + max_samples = self._config.parameters.num_samples + current_sample_length = 0 + documents: list[SampleType] = [] + num_samples = 0 + for doc in self._dataset: + if len(doc) > sample_length: + logging.warning(f"Dropping doc with length {len(doc)} higher then sample_length {sample_length}") + continue + if current_sample_length + len(doc) > sample_length: + padding_length = sample_length - current_sample_length + assert padding_length > 0 + documents.append(documents[-1].get_padding(padding_length)) + + yield documents[0].from_documents(documents) + + num_samples += 1 + if num_samples >= max_samples: + break + + documents = [doc] + current_sample_length = len(doc) + else: + documents.append(doc) + current_sample_length += len(doc) + + if current_sample_length == sample_length: + yield documents[0].from_documents(documents) + + num_samples += 1 + if num_samples >= max_samples: + break + + documents = [] + current_sample_length = 0 + + if num_samples < max_samples and current_sample_length > 0: + padding_length = sample_length - current_sample_length + assert padding_length > 0 + documents.append(documents[-1].get_padding(padding_length)) + + yield documents[0].from_documents(documents) From 066a0bfae44a9522067afa7f5958c5327d06456d Mon Sep 17 00:00:00 2001 From: bigximik 
Date: Sun, 23 Nov 2025 16:26:10 +0000 Subject: [PATCH 32/45] added iterable dataset configs, streaming dataset and PipelineRL sample and batch placeholder --- fast_llm/data/dataset/config.py | 86 +++++++++++++++++++- fast_llm/data/dataset/streaming.py | 117 ++++++++++++++++++++++++++++ fast_llm/data/sample/pipeline_rl.py | 102 ++++++++++++++++++++++++ 3 files changed, 304 insertions(+), 1 deletion(-) create mode 100644 fast_llm/data/dataset/streaming.py create mode 100644 fast_llm/data/sample/pipeline_rl.py diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7611b4a31..86b666b45 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -8,13 +8,20 @@ import typing from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class -from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.dataset.abstract import ( + SamplableDataset, + SamplableIterableDataset, + SampledDataset, + SampledIterableDataset, +) from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.dataset.streaming import StreamingDataset from fast_llm.data.sample.language_model import LanguageModelSample + from fast_llm.data.sample.pipeline_rl import PipelineRLSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -121,6 +128,25 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType] return self.build().sample(sampling) +@config_class(registry=True) +class SampledIterableDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): + """ + A sampled iterable dataset returning a prepared samples to be accessed sequentially (as-is) during training. 
+ """ + + def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]: + raise NotImplementedError() + + +@config_class() +class SamplableIterableDatasetConfig[SampleType: Sample](SampledIterableDatasetConfig[SampleType]): + def build(self, distributed: "Distributed") -> SamplableIterableDataset[SampleType]: + raise NotImplementedError() + + def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]: + return self.build(sampling.distributed).sample(sampling) + + @config_class() class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): def build(self) -> "IndexedDataset[SampleType]": @@ -297,3 +323,61 @@ def build(self) -> "IndexedDataset[SampleType]": return LegacyMemmapDataset[SampleType](name, self.path) else: raise FileNotFoundError(self.path) + + +@config_class() +class RedisConfig(Config): + host: str = Field( + default="localhost", + desc="Hostname or IP address of the Redis server.", + hint=FieldHint.core, + ) + + port: int = Field( + default=6379, + desc="Port number on which the Redis server is running.", + hint=FieldHint.core, + ) + + stream_key: str = Field( + default=None, + desc="Name of the Redis stream to read data from.", + hint=FieldHint.core, + ) + + group_name: str = Field( + default="fast_llm_dp_group", + desc="Name of the Redis consumer group used for data-parallel reading.", + hint=FieldHint.core, + ) + + consumer_name_prefix: str = Field( + default="fast_llm_dp_group_consumer", + desc="Prefix used to generate unique consumer names for each rank.", + hint=FieldHint.core, + ) + + +@config_class(dynamic_type={SampledIterableDatasetConfig: "streaming"}) +class StreamingDatasetConfig[SampleType: PipelineRLSample](SamplableIterableDatasetConfig[SampleType]): + """ + Configuration for a streaming dataset that reads training data from a Redis stream. 
+ """ + + _abstract = False + + redis: RedisConfig = Field( + desc="Redis connection and stream settings used to fetch incoming training data.", + hint=FieldHint.core, + ) + + data_key: str = Field( + default="data", + desc="The Redis message field containing the serialized sample payload.", + hint=FieldHint.core, + ) + + def build(self, distributed: "Distributed") -> "StreamingDataset": + from fast_llm.data.dataset.streaming import StreamingDataset + + return StreamingDataset[SampleType](self, distributed) diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py new file mode 100644 index 000000000..0914c725d --- /dev/null +++ b/fast_llm/data/dataset/streaming.py @@ -0,0 +1,117 @@ +import typing + +import torch.utils.data + +from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledIterableDataset +from fast_llm.data.dataset.config import SamplingData, StreamingDatasetConfig +from fast_llm.data.dataset.sampled import NaiveSampledIterableDataset +from fast_llm.data.sample.pipeline_rl import PipelineRLSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.distributed.distributed import Distributed + + +def dtype_from_string(name: str) -> torch.dtype: + try: + return getattr(torch, name) + except AttributeError: + raise ValueError(f"Unknown torch dtype: {name}") + + +class StreamingDataset[SampleType: PipelineRLSample](SamplableIterableDataset[SampleType]): + def __init__(self, config: StreamingDatasetConfig, distributed: Distributed): + super().__init__() + self._name = f"redis[{config.redis.host}:{config.redis.port}]({config.redis.stream_key}|{config.redis.group_name})[{config.data_key}]" + self._config = config + self.batch_data_rank = distributed.config.batch_data_rank + self.is_batch_data_group_leader = ( + distributed.model_and_sequence_data_group is None or distributed.model_and_sequence_data_group.rank() == 0 + ) + + @property + def name(self) -> str: + return self._name + + def sample(self, config: SamplingData) -> SampledIterableDataset[PipelineRLSample]: + # TODO: actually sample the dataset and not return docs + return NaiveSampledIterableDataset(self, config) + + def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, bool]: + return (self._name, self._config, self.batch_data_rank, self.is_batch_data_group_leader) + + def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool]): + name, config, batch_data_rank, is_batch_data_group_leader = state + self._name = name + self._config = config + self.batch_data_rank = batch_data_rank + self.is_batch_data_group_leader = is_batch_data_group_leader + + def __iter__(self) -> typing.Iterator[PipelineRLSample]: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise RuntimeError("StreamingDataset can work only with one instance per rank") + + if not self.is_batch_data_group_leader: + raise RuntimeError("Must be only called on the batch data group leader") + + import orjson + import redis + import redis.exceptions + + r = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) + + # Create the consumer group at the start of the stream ("0") + # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True + try: + r.xgroup_create( + name=self._config.redis.stream_key, groupname=self._config.redis.group_name, id="0", mkstream=True + ) + except redis.exceptions.ResponseError as e: + if 
"BUSYGROUP" in str(e): + # Consumer group already exists + pass + else: + raise + + while True: + # XREADGROUP reads from the consumer group + # COUNT: max number of messages to fetch at once + # BLOCK: wait for new messages (milliseconds) + messages = r.xreadgroup( + groupname=self._config.redis.group_name, + consumername=f"{self._config.redis.consumer_name_prefix}_{self.batch_data_rank}", + # ">" means read only new messages that were never delivered to this consumer + streams={self._config.redis.stream_key: ">"}, + count=1, + block=5000, # wait up to 5 seconds + ) + + if messages: + for stream_key, msgs in messages: + assert stream_key == self._config.redis.stream_key.encode() + for msg_id, msg_data in msgs: + r.xack(self._config.redis.stream_key, self._config.redis.group_name, msg_id) + data = orjson.loads(msg_data[self._config.data_key.encode()]) + yield self._sample_from_dict(data) + + def _sample_from_dict(cls, data: dict) -> PipelineRLSample: + tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) + sample_size = len(tokens) + if "loss_masking_spans" in data: + loss_masking_spans = [tuple(el) for el in data["loss_masking_spans"]] + else: + loss_masking_spans = None + if "chosen_spans" in data: + chosen_spans = [tuple(el) for el in data["chosen_spans"]] + else: + chosen_spans = None + if "rejected_spans" in data: + rejected_spans = [tuple(el) for el in data["rejected_spans"]] + else: + rejected_spans = None + return PipelineRLSample( + TokenSample(tokens, [sample_size]), + RangeSample(loss_masking_spans, sample_size) if loss_masking_spans is not None else None, + RangeSample(chosen_spans, sample_size) if chosen_spans is not None else None, + RangeSample(rejected_spans, sample_size) if rejected_spans is not None else None, + ) diff --git a/fast_llm/data/sample/pipeline_rl.py b/fast_llm/data/sample/pipeline_rl.py new file mode 100644 index 000000000..21c434ad9 --- /dev/null +++ b/fast_llm/data/sample/pipeline_rl.py @@ -0,0 +1,102 @@ +import typing + +import torch + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.data.sample.language_model import _crop_optional, _merge_optional +from fast_llm.data.sample.range import RangeBatch, RangeSample +from fast_llm.data.sample.token import TokenBatch, TokenSample + + +class PipelineRLSample(Sample): + def __init__( + self, + tokens: TokenSample, + loss_masking_spans: RangeSample | None = None, + chosen_spans: RangeSample | None = None, + rejected_spans: RangeSample | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + TokenSample.from_documents([document.tokens for document in documents]), + _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def __len__(self) -> int: + return len(self.tokens) + + def get_padding(self, size: int) -> typing.Self: + 
return PipelineRLSample( + self.tokens.get_padding(size), + None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), + None if self.chosen_spans is None else self.chosen_spans.get_padding(size), + None if self.rejected_spans is None else self.rejected_spans.get_padding(size), + ) + + +class PipelineRLBatch(Batch): + def __init__( + self, + tokens: TokenBatch, + loss_masking_spans: RangeBatch | None = None, + chosen_spans: RangeBatch | None = None, + rejected_spans: RangeBatch | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_samples(cls, samples: typing.Iterable[PipelineRLSample]) -> typing.Self: + return cls( + TokenBatch.from_samples([sample.tokens for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), + ) + + def to_samples(self) -> list[PipelineRLSample]: + return [ + PipelineRLSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) + for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( + self.tokens.to_samples(), + None if self.loss_masking_spans is None else self.loss_masking_spans.to_samples(), + None if self.chosen_spans is None else self.chosen_spans.to_samples(), + None if self.rejected_spans is None else self.rejected_spans.to_samples(), + strict=True, + ) + ] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) From 68b3d65198ba547f2c592f9334520ce6e8503fb2 Mon Sep 17 00:00:00 2001 From: bigximik Date: Sun, 23 Nov 2025 16:26:52 +0000 Subject: [PATCH 33/45] added distributed data loader wrapper --- fast_llm/data/data/data_loader_wrapper.py | 50 +++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 fast_llm/data/data/data_loader_wrapper.py diff --git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py new file mode 100644 index 000000000..e7ef470ba --- /dev/null +++ b/fast_llm/data/data/data_loader_wrapper.py @@ -0,0 +1,50 @@ +import torch.distributed +import torch.utils.data.dataloader + +from fast_llm.core.distributed import broadcast_object + + +class DistributedDataLoaderWrapper: + """ + Wraps a regular dataloader so that only the process group leader + loads data, and then broadcasts the batch to other ranks in the group. 
+ """ + + def __init__( + self, + dataloader: torch.utils.data.dataloader.DataLoader | None, + rank: int, + process_group: torch.distributed.ProcessGroup | None, + ): + self.dataloader = dataloader + self.rank = rank + self.process_group = process_group + + assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None) + + def __iter__(self): + if self.rank == 0: + self.iterator = iter(self.dataloader) + if self.process_group is None: + return self.iterator + return self + + def __next__(self): + # TODO: + # Instead of broadcasting a general object, make this iterator yield actual batches. + # Add batch data to a state dict or a dedicated Batch class, so we can efficiently + # broadcast tensors directly. This avoids using `broadcast_object` on entire objects, + # which is inefficient for tensors since it serializes them (pickles) before sending. + if self.rank == 0: + try: + data = next(self.iterator) # may raise StopIteration + except Exception as e: + data = e + data = broadcast_object(data, self.process_group, 0) + else: + data = broadcast_object(None, self.process_group, 0) + + if isinstance(data, Exception): + raise data + + return data From 2fbfe99123cb8b4d7c2f8bf8745bc3ea51432bfb Mon Sep 17 00:00:00 2001 From: bigximik Date: Sun, 23 Nov 2025 16:27:53 +0000 Subject: [PATCH 34/45] added iterable dataset to gpt data --- fast_llm/data/data/gpt/config.py | 7 ++++-- fast_llm/data/data/gpt/data.py | 43 +++++++++++++++++++++++++++----- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index ba5be883a..a19846c3b 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -4,11 +4,12 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.config import SampledDatasetConfig +from fast_llm.data.dataset.config import SampledDatasetConfig, SampledIterableDatasetConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.sample.language_model import LanguageModelSample + from fast_llm.data.sample.pipeline_rl import PipelineRLSample logger = logging.getLogger(__name__) @@ -23,7 +24,9 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? 
- datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( + datasets: dict[ + str, SampledDatasetConfig["LanguageModelSample"] | SampledIterableDatasetConfig["PipelineRLSample"] + ] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index de47ef761..f856a58a3 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,8 +8,9 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SampledDataset, SampledIterableDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator @@ -86,7 +87,12 @@ def setup( dataset_name=dataset_name, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) - self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + if isinstance(dataset, SampledDataset): + self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + else: + # Do not set monitor for iterable dataset as monitor only works with map style datasets + assert isinstance(dataset, SampledIterableDataset) + self._datasets[dataset_name] = dataset safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True @@ -112,9 +118,11 @@ def get_iterator( Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - return iter( - torch.utils.data.DataLoader( - self._datasets[dataset_name], # noqa + dataset = self._datasets[dataset_name] + + if isinstance(dataset, SampledDataset): + data_loader = torch.utils.data.DataLoader( + dataset, # noqa batch_sampler=SampledDatasetIterator( total_samples=len(self._datasets[dataset_name]), begin_index=consumed_samples, @@ -128,4 +136,27 @@ def get_iterator( collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) - ) + + elif isinstance(dataset, SampledIterableDataset): + if ( + self.distributed.model_and_sequence_data_group is None + or self.distributed.model_and_sequence_data_group.rank() == 0 + ): + rank = 0 + data_loader = torch.utils.data.DataLoader( + dataset, # noqa + batch_size=batch_config.micro_batch_size, + num_workers=0 if num_workers == 0 else 1, + prefetch_factor=prefetch_factor, + pin_memory=True, + collate_fn=LanguageModelBatch.from_samples, + multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, + ) + else: + rank = self.distributed.model_and_sequence_data_group.rank() + data_loader = None + data_loader = DistributedDataLoaderWrapper( + data_loader, rank, self.distributed.model_and_sequence_data_group + ) + + return iter(data_loader) From 08925237c66273eb99eea34833450f3fdfe2ad92 Mon Sep 17 00:00:00 2001 From: bigximik Date: Sun, 23 Nov 2025 16:42:56 +0000 Subject: [PATCH 35/45] appended comment --- fast_llm/data/data/data_loader_wrapper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff 
--git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py index e7ef470ba..f9e517248 100644 --- a/fast_llm/data/data/data_loader_wrapper.py +++ b/fast_llm/data/data/data_loader_wrapper.py @@ -31,10 +31,12 @@ def __iter__(self): def __next__(self): # TODO: - # Instead of broadcasting a general object, make this iterator yield actual batches. - # Add batch data to a state dict or a dedicated Batch class, so we can efficiently - # broadcast tensors directly. This avoids using `broadcast_object` on entire objects, - # which is inefficient for tensors since it serializes them (pickles) before sending. + # Instead of broadcasting a general object, make this iterator yield an actual Batch class. + # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can + # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the + # entire Batch object, which is inefficient for tensors because it serializes + # (pickles) them before sending. + if self.rank == 0: try: data = next(self.iterator) # may raise StopIteration From 54fadb4c015b015e331161d5a2fdcb1089c67f47 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 18:16:48 +0000 Subject: [PATCH 36/45] changed base classes for iterable dataset configs --- fast_llm/data/data/gpt/config.py | 7 ++---- fast_llm/data/dataset/config.py | 40 ++++++++++---------------------- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index a19846c3b..ba5be883a 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -4,12 +4,11 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.config import SampledDatasetConfig, SampledIterableDatasetConfig +from fast_llm.data.dataset.config import SampledDatasetConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.sample.language_model import LanguageModelSample - from fast_llm.data.sample.pipeline_rl import PipelineRLSample logger = logging.getLogger(__name__) @@ -24,9 +23,7 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? - datasets: dict[ - str, SampledDatasetConfig["LanguageModelSample"] | SampledIterableDatasetConfig["PipelineRLSample"] - ] = Field( + datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 86b666b45..78969f5a7 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -19,7 +19,6 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - from fast_llm.data.dataset.streaming import StreamingDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.pipeline_rl import PipelineRLSample from fast_llm.engine.distributed.distributed import Distributed @@ -112,41 +111,26 @@ class DatasetConfig[SampleType: Sample](Config): @config_class(registry=True) class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ - A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. 
+ A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample( + self, sampling: SamplingData + ) -> SampledDataset[SampleType] | SampledIterableDataset[SampleType]: raise NotImplementedError() @config_class() class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build(self) -> SamplableDataset[SampleType]: + def build(self) -> SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample( + self, sampling: SamplingData + ) -> SampledDataset[SampleType] | SampledIterableDataset[SampleType]: return self.build().sample(sampling) -@config_class(registry=True) -class SampledIterableDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): - """ - A sampled iterable dataset returning a prepared samples to be accessed sequentially (as-is) during training. - """ - - def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]: - raise NotImplementedError() - - -@config_class() -class SamplableIterableDatasetConfig[SampleType: Sample](SampledIterableDatasetConfig[SampleType]): - def build(self, distributed: "Distributed") -> SamplableIterableDataset[SampleType]: - raise NotImplementedError() - - def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]: - return self.build(sampling.distributed).sample(sampling) - - @config_class() class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): def build(self) -> "IndexedDataset[SampleType]": @@ -358,8 +342,8 @@ class RedisConfig(Config): ) -@config_class(dynamic_type={SampledIterableDatasetConfig: "streaming"}) -class StreamingDatasetConfig[SampleType: PipelineRLSample](SamplableIterableDatasetConfig[SampleType]): +@config_class(dynamic_type={SampledDatasetConfig: "streaming"}) +class StreamingDatasetConfig[SampleType: PipelineRLSample](SamplableDatasetConfig[SampleType]): """ Configuration for a streaming dataset that reads training data from a Redis stream. 
""" @@ -377,7 +361,7 @@ class StreamingDatasetConfig[SampleType: PipelineRLSample](SamplableIterableData hint=FieldHint.core, ) - def build(self, distributed: "Distributed") -> "StreamingDataset": + def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]: from fast_llm.data.dataset.streaming import StreamingDataset - return StreamingDataset[SampleType](self, distributed) + return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling) From 4e11bf3287d51e7ca01a1f3168eb8e78e7222b15 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 18:17:47 +0000 Subject: [PATCH 37/45] fix batch type --- fast_llm/data/data/gpt/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index f856a58a3..7cd746ed0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -15,6 +15,7 @@ from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.pipeline_rl import PipelineRLBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -149,7 +150,7 @@ def get_iterator( num_workers=0 if num_workers == 0 else 1, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=LanguageModelBatch.from_samples, + collate_fn=PipelineRLBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) else: From 8428df8f420607c2a1bf82bd3122e0d773f27343 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 18:18:32 +0000 Subject: [PATCH 38/45] fix added name property to the class --- fast_llm/data/dataset/sampled.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 7a11be6a4..5708e3dbe 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -485,3 +485,7 @@ def __iter__(self) -> typing.Iterator[SampleType]: documents.append(documents[-1].get_padding(padding_length)) yield documents[0].from_documents(documents) + + @property + def name(self) -> str: + return self._dataset.name From 04ee4d71aa641839c2e49dbc8008b3feed303472 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 18:19:38 +0000 Subject: [PATCH 39/45] add eof for tests --- fast_llm/data/dataset/streaming.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 0914c725d..291886e2e 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -86,13 +86,21 @@ def __iter__(self) -> typing.Iterator[PipelineRLSample]: block=5000, # wait up to 5 seconds ) + data = None if messages: for stream_key, msgs in messages: assert stream_key == self._config.redis.stream_key.encode() for msg_id, msg_data in msgs: r.xack(self._config.redis.stream_key, self._config.redis.group_name, msg_id) data = orjson.loads(msg_data[self._config.data_key.encode()]) + # NOTE: for testing with fakeredis only as on real consumer group it will be delivered only to one consumer + if "eof" in data and data["eof"] == True: + break yield self._sample_from_dict(data) + if "eof" in data and data["eof"] == True: + break + if data is not None and "eof" in data and 
data["eof"] == True: + break def _sample_from_dict(cls, data: dict) -> PipelineRLSample: tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) From 121799866c2e2c66ba6b8f64caa6ecb1267bb256 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 18:20:32 +0000 Subject: [PATCH 40/45] change base class to torch iterable --- fast_llm/data/dataset/abstract.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index f8705d470..5eddcd788 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,6 +1,8 @@ import abc import typing +import torch.utils.data.dataset + from fast_llm.data.sample.abstract import Sample if typing.TYPE_CHECKING: @@ -43,30 +45,28 @@ def __len__(self) -> int: pass -class SampledIterableDataset[SampleType: Sample](Dataset[SampleType], typing.Iterable[SampleType]): +# NOTE: We need to inherit from IterableDataset overwise torch data loader can not detect it properly +class SampledIterableDataset[SampleType: Sample](torch.utils.data.dataset.IterableDataset[SampleType]): """ A sampled dataset class that provides an iterator over samples. - (See the `Sampler` class below.) """ + # NOTE: We add name here so it is compatible with Fast-LLM Dataset + @property @abc.abstractmethod - def __iter__(self) -> typing.Iterator[SampleType]: - """Return an iterator over samples.""" + def name(self) -> str: + """ + A name for the dataset to facilitate identification and debugging. + """ class SamplableDataset[SampleType: Sample](Dataset[SampleType]): - @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass -class SamplableIterableDataset[SampleType: Sample](Dataset[SampleType]): - +class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]: pass - - @abc.abstractmethod - def __iter__(self) -> typing.Iterator[SampleType]: - """Return an iterator over documents or samples.""" From c542dac24459702a7f1653323d92ff2f1ded4fa6 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 18:21:28 +0000 Subject: [PATCH 41/45] added straming dataset, sampling and base data tests --- tests/data/test_streaming.py | 249 +++++++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 tests/data/test_streaming.py diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py new file mode 100644 index 000000000..54352a79d --- /dev/null +++ b/tests/data/test_streaming.py @@ -0,0 +1,249 @@ +import fakeredis +import orjson +import pytest +import torch + +from fast_llm.config import NoAutoValidate +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.config import ( + RedisConfig, + SamplingConfig, + SamplingData, + SamplingParameters, + ShufflingType, + StreamingDatasetConfig, +) +from fast_llm.data.dataset.streaming import StreamingDataset +from fast_llm.data.sample.pipeline_rl import PipelineRLSample +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import GPTBatchConfig + +# --------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------- + + 
+@pytest.fixture +def fake_redis(): + """Return a FakeRedis instance.""" + return fakeredis.FakeRedis() + + +@pytest.fixture +def monkeypatched_redis(monkeypatch, fake_redis): + """Monkeypatch redis.Redis globally (works even for imports inside functions).""" + import redis + + monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis) + return fake_redis + + +@pytest.fixture +def stream_config(): + return StreamingDatasetConfig( + redis=RedisConfig( + host="localhost", + port=6379, + stream_key="test_stream", + group_name="test_group", + consumer_name_prefix="consumer", + ), + data_key="data", + ) + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +def push_msg(redis_client, config, tokens=None, is_eof=False): + """Push a message into FakeRedis stream.""" + if is_eof: + msg = {"eof": True} + else: + msg = { + "tokens": tokens, + "tokens_dtype": "int64", + } + redis_client.xadd(config.redis.stream_key, {config.data_key: orjson.dumps(msg)}) + + +def make_sampling(sequence_length, extra_tokens, num_samples, distributed): + return SamplingData( + parameters=SamplingParameters( + sequence_length=sequence_length, + extra_tokens=extra_tokens, + num_samples=num_samples, + truncate_documents=False, + ), + config=SamplingConfig(shuffle=ShufflingType.disabled), + distributed=distributed, + dataset_name="test", + cache_directory="/tmp", + ) + + +# --------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------- + + +def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_config): + """StreamingDataset should read a message and convert it into PipelineRLSample.""" + fake_redis = monkeypatched_redis + + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset = StreamingDataset(stream_config, distributed) + + # Insert a message + push_msg(fake_redis, stream_config, [1, 2, 3]) + + it = iter(dataset) + sample = next(it) + + assert isinstance(sample, PipelineRLSample) + assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert sample.tokens.lengths == [3] + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_config): + """StreamingDataset should read a message and convert it into PipelineRLSample.""" + fake_redis = monkeypatched_redis + + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset = StreamingDataset(stream_config, distributed) + + # Insert a message + push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, stream_config, [1, 2, 3]) + + it = iter(dataset) + for i in range(3): + sample = next(it) + + assert isinstance(sample, PipelineRLSample) + assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert sample.tokens.lengths == [3] + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +def test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config): + """Docs exactly fill one sample.""" + fake_redis = monkeypatched_redis + + # Two rollouts: lengths 4 and 6 -> exactly 10 + push_msg(fake_redis, stream_config, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + distributed = 
Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = list(sampler) + + assert len(out) == 1 + s = out[0] + assert isinstance(s, PipelineRLSample) + assert len(s) == 10 + assert s.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config): + """Docs exactly fill one sample.""" + fake_redis = monkeypatched_redis + + # Two rollouts: lengths 4 and 6 -> exactly 10 + push_msg(fake_redis, stream_config, [1, 2, 3, 4]) + push_msg(fake_redis, stream_config, [5, 6, 7, 8, 9, 10]) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = list(sampler) + + assert len(out) == 1 + s = out[0] + assert isinstance(s, PipelineRLSample) + assert len(s) == 10 + assert s.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stream_config): + """Rollout longer than sample_length must be dropped.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, list(range(20))) # skip: too long + push_msg(fake_redis, stream_config, list(range(8))) # usable + push_msg(fake_redis, stream_config, is_eof=True) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = list(sampler) + + assert len(out) == 1 + s = out[0] + assert len(s) == 10 + assert s.tokens.tokens.tolist() == list(range(8)) + [-100, -100] + + +def test_sampling_overflow_creates_two_and_padding_final(monkeypatched_redis, stream_config): + """A document overflowing the boundary triggers padding + next sample.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, list(range(6))) + push_msg(fake_redis, stream_config, list(range(6))) + push_msg(fake_redis, stream_config, is_eof=True) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 2, distributed)) + + out = list(sampler) + + assert len(out) == 2 + + # sample 1: 0..5 + pad(4) + assert out[0].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] + + # sample 2: 0..5 + pad(4) + assert out[1].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] + + +def test_data_single_consumer(monkeypatched_redis, stream_config): + fake_redis = monkeypatched_redis + + sequence_length = 10 + samples_count = 2 + + push_msg(fake_redis, stream_config, list(range(sequence_length))) + push_msg(fake_redis, stream_config, list(range(sequence_length))) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampling_data = make_sampling(sequence_length, 0, samples_count, distributed) + + data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} + data_config = GPTDataConfig.from_dict(data_config) + + data = GPTData(data_config, distributed.config) + + data.setup(distributed, {"streaming1": sampling_data.parameters}, "/tmp") + + with NoAutoValidate(): + batch_config = GPTBatchConfig( + micro_batch_size=samples_count, batch_size=samples_count, sequence_length=sequence_length + ) + batch_config.setup(distributed_config=distributed.config) + batch_config.validate() + + # TODO: check why is not working with num_workers == 1 
+ data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=0, prefetch_factor=None) + + batch = next(data_iter) + assert batch.tokens.tokens.shape == (2, 10) From a1556f8c34ed0b275d7a79b6a9781bacb37e7590 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 24 Nov 2025 19:32:13 +0000 Subject: [PATCH 42/45] change import --- fast_llm/data/dataset/abstract.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 5eddcd788..0fe53b7ba 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,7 +1,7 @@ import abc import typing -import torch.utils.data.dataset +import torch.utils.data from fast_llm.data.sample.abstract import Sample @@ -46,7 +46,7 @@ def __len__(self) -> int: # NOTE: We need to inherit from IterableDataset overwise torch data loader can not detect it properly -class SampledIterableDataset[SampleType: Sample](torch.utils.data.dataset.IterableDataset[SampleType]): +class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]): """ A sampled dataset class that provides an iterator over samples. """ From 63737b15767990896c3344a7b55cb73803f5951d Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 25 Nov 2025 12:52:45 +0000 Subject: [PATCH 43/45] fix iterable sampler for spawn, add fake redis server to multi process tests --- fast_llm/data/dataset/sampled.py | 12 +++++++----- tests/data/test_streaming.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 5708e3dbe..7ec659bd8 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -438,13 +438,15 @@ def __init__( sampling: SamplingData, ): self._dataset = iterable_dataset - self._config = sampling - assert self._config.parameters.truncate_documents == False - assert self._config.config.shuffle == ShufflingType.disabled + self._config = sampling.config + self._parameters = sampling.parameters + + assert self._parameters.truncate_documents == False + assert self._config.shuffle == ShufflingType.disabled def __iter__(self) -> typing.Iterator[SampleType]: - sample_length = self._config.parameters.sequence_length + self._config.parameters.extra_tokens - max_samples = self._config.parameters.num_samples + sample_length = self._parameters.sequence_length + self._parameters.extra_tokens + max_samples = self._parameters.num_samples current_sample_length = 0 documents: list[SampleType] = [] num_samples = 0 diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 54352a79d..bcf3ac043 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -1,3 +1,5 @@ +import threading + import fakeredis import orjson import pytest @@ -54,6 +56,27 @@ def stream_config(): ) +@pytest.fixture +def fake_redis_server(stream_config): + server_address = (stream_config.redis.host, stream_config.redis.port) + server = fakeredis.TcpFakeServer(server_address, server_type="redis") + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + # Create a redis-py client pointing at the fake server + import redis + + client = redis.Redis(host=server_address[0], port=server_address[1]) + + yield stream_config, client + + # Everything after yield = teardown + server.shutdown() + server.server_close() + thread.join() + + # 
--------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------- @@ -216,8 +239,8 @@ def test_sampling_overflow_creates_two_and_padding_final(monkeypatched_redis, st assert out[1].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] -def test_data_single_consumer(monkeypatched_redis, stream_config): - fake_redis = monkeypatched_redis +def test_data_single_consumer(fake_redis_server): + stream_config, fake_redis = fake_redis_server sequence_length = 10 samples_count = 2 @@ -242,8 +265,7 @@ def test_data_single_consumer(monkeypatched_redis, stream_config): batch_config.setup(distributed_config=distributed.config) batch_config.validate() - # TODO: check why is not working with num_workers == 1 - data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=0, prefetch_factor=None) + data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) batch = next(data_iter) assert batch.tokens.tokens.shape == (2, 10) From e843c8e4f1b23753e557506df700beee69361629 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 25 Nov 2025 15:02:21 +0000 Subject: [PATCH 44/45] preparation for multi gpu tests --- tests/data/test_streaming.py | 141 +++++++++++++++++++++++++++++++---- 1 file changed, 128 insertions(+), 13 deletions(-) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index bcf3ac043..092bbba90 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -109,6 +109,76 @@ def make_sampling(sequence_length, extra_tokens, num_samples, distributed): ) +def generate_parallelism_variants(total_gpus: int): + """ + Generate all valid variants of (data_groups, tensor_parallel, pipeline_parallel, sequence_parallel) + for a number of GPUs up to the total_gpus. + If total_gpus is odd and > 1, fallback to nearest lower even number for decomposable parallelism. 
+ """ + if total_gpus > 1 and total_gpus % 2 == 1: + total_gpus = total_gpus - 1 + + if total_gpus == 0: + return [ + { + "data_groups": 1, + "tensor_parallel": 1, + "pipeline_parallel": 1, + "sequence_data_parallel": 1, + "total_gpus": 0, + } + ] + + variants = [ + { + "data_groups": 1, + "tensor_parallel": 1, + "pipeline_parallel": 1, + "sequence_data_parallel": 1, + "total_gpus": 1, + } + ] + + for gpus in range(2, total_gpus + 1, 2): + # try all possible numbers of data groups (1..total_gpus) + for data_groups in range(1, gpus + 1): + if gpus % data_groups != 0: + continue # cannot evenly split + + gpus_per_group = gpus // data_groups + + # now find all decompositions of gpus_per_group into tp*pp*sp + for tp in range(1, gpus_per_group + 1): + if gpus_per_group % tp != 0: + continue + rem_after_tp = gpus_per_group // tp + for pp in range(1, rem_after_tp + 1): + if rem_after_tp % pp != 0: + continue + sp = rem_after_tp // pp + try: + # instead of repeating all safeguards here just try to instantiate distributed config to check if combination is valid + DistributedConfig( + tensor_parallel=tp, + pipeline_parallel=pp, + sequence_data_parallel=sp, + world_size=gpus, + rank=0, + ) + except Exception: + continue + variants.append( + { + "data_groups": data_groups, + "tensor_parallel": tp, + "pipeline_parallel": pp, + "sequence_data_parallel": sp, + "total_gpus": gpus, + } + ) + return variants + + # --------------------------------------------------------------------- # Tests # --------------------------------------------------------------------- @@ -239,17 +309,25 @@ def test_sampling_overflow_creates_two_and_padding_final(monkeypatched_redis, st assert out[1].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] -def test_data_single_consumer(fake_redis_server): - stream_config, fake_redis = fake_redis_server - - sequence_length = 10 - samples_count = 2 - - push_msg(fake_redis, stream_config, list(range(sequence_length))) - push_msg(fake_redis, stream_config, list(range(sequence_length))) - - distributed = Distributed(DistributedConfig(), use_cpu=True) - sampling_data = make_sampling(sequence_length, 0, samples_count, distributed) +def distributed_gptdata_streaming_test( + stream_config, + sequence_length, + micro_batch_size, + batch_size, + tensor_parallel, + pipeline_parallel, + sequence_data_parallel, + total_gpus, +): + distributed = Distributed( + DistributedConfig( + tensor_parallel=tensor_parallel, + pipeline_parallel=pipeline_parallel, + sequence_data_parallel=sequence_data_parallel, + ), + use_cpu=total_gpus > 0, + ) + sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed) data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} data_config = GPTDataConfig.from_dict(data_config) @@ -260,7 +338,7 @@ def test_data_single_consumer(fake_redis_server): with NoAutoValidate(): batch_config = GPTBatchConfig( - micro_batch_size=samples_count, batch_size=samples_count, sequence_length=sequence_length + micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length ) batch_config.setup(distributed_config=distributed.config) batch_config.validate() @@ -268,4 +346,41 @@ def test_data_single_consumer(fake_redis_server): data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) batch = next(data_iter) - assert batch.tokens.tokens.shape == (2, 10) + assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length) + + 
+variants = generate_parallelism_variants(torch.cuda.device_count()) + + +@pytest.mark.parametrize( + "variant", + variants, + ids=[ + f"dg{v['data_groups']}_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}_gpu{v['total_gpus']}" + for v in variants + ], +) +def test_gptdata_streaming(fake_redis_server, variant): + if variant["total_gpus"] > 1: + pytest.skip(f"Skipping, not implemented for gpu count {variant["total_gpus"]}") + + stream_config, fake_redis = fake_redis_server + + sequence_length = 10 + micro_batch_size = 2 + batch_size = micro_batch_size * variant["data_groups"] // variant["sequence_data_parallel"] + + for _ in range(batch_size): + push_msg(fake_redis, stream_config, list(range(sequence_length))) + + # TODO: call with torchrun.distributed for more than 1 gpu + distributed_gptdata_streaming_test( + stream_config, + sequence_length, + micro_batch_size, + batch_size, + variant["tensor_parallel"], + variant["pipeline_parallel"], + variant["sequence_data_parallel"], + variant["total_gpus"], + ) From d5ce3f280f634f6c6ae23bd00f8e5a3371e82b97 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 26 Nov 2025 15:19:57 +0000 Subject: [PATCH 45/45] added multi gpu gptdata streaming test --- tests/conftest.py | 1 + tests/data/gptdata_streaming_test.py | 105 +++++++++++ tests/data/test_streaming.py | 272 +++++++++++++++++---------- tests/utils/run_test_script.py | 11 ++ 4 files changed, 294 insertions(+), 95 deletions(-) create mode 100644 tests/data/gptdata_streaming_test.py diff --git a/tests/conftest.py b/tests/conftest.py index 58301919f..1bd89bdd9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ from tests.utils.run_test_script import ( # isort: skip compare_results_for_all_models, run_distributed_script, + run_distributed_script_lean, run_test_script_base_path, run_test_script_for_all_models, ) diff --git a/tests/data/gptdata_streaming_test.py b/tests/data/gptdata_streaming_test.py new file mode 100644 index 000000000..7c225fbfd --- /dev/null +++ b/tests/data/gptdata_streaming_test.py @@ -0,0 +1,105 @@ +import argparse +import pathlib + +import cloudpickle + +from fast_llm.config import NoAutoValidate +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import GPTBatchConfig +from tests.data.test_streaming import get_stream_config, make_sampling + + +def distributed_gptdata_streaming_test( + sequence_length, + micro_batch_size, + batch_size, + tensor_parallel, + pipeline_parallel, + sequence_data_parallel, + total_gpus, + redis_port, + result_path, +): + stream_config = get_stream_config() + stream_config = stream_config.from_dict(stream_config.to_dict(), {("redis", "port"): redis_port}) + + distributed = Distributed( + DistributedConfig( + tensor_parallel=tensor_parallel, + pipeline_parallel=pipeline_parallel, + sequence_data_parallel=sequence_data_parallel, + ), + use_cpu=total_gpus == 0, + ) + sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed) + + data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} + data_config = GPTDataConfig.from_dict(data_config) + + data = GPTData(data_config, distributed.config) + + data.setup(distributed, {"streaming1": sampling_data.parameters}, "/tmp") + + with NoAutoValidate(): + batch_config = 
GPTBatchConfig( + micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length + ) + batch_config.setup(distributed_config=distributed.config) + batch_config.validate() + + data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) + + batch = next(data_iter) + # TODO: save result per batch_data_group and rank + assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length) + + result_path = ( + pathlib.Path(result_path) + / ( + f"{distributed.config.batch_data_rank}_" + f"{distributed.model_and_sequence_data_group.rank() if distributed.model_and_sequence_data_group is not None else 0}" + ) + / "batch.pkl" + ) + result_path.parent.mkdir(exist_ok=True, parents=True) + with result_path.open("wb") as f: + cloudpickle.dump(batch, f) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run distributed GPT data streaming test.") + + parser.add_argument("--sequence-length", type=int, required=True, help="Sequence length of the model input.") + parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size.") + parser.add_argument("--batch-size", type=int, required=True, help="Global batch size.") + parser.add_argument("--tensor-parallel", type=int, required=True, help="Tensor parallel degree.") + parser.add_argument("--pipeline-parallel", type=int, required=True, help="Pipeline parallel degree.") + parser.add_argument("--sequence-data-parallel", type=int, required=True, help="Sequence data parallel degree.") + parser.add_argument("--total-gpus", type=int, required=True, help="Total number of GPUs available.") + parser.add_argument("--redis-port", type=int, required=True, help="Redis port to connect to.") + parser.add_argument("--result-path", type=str, required=True, help="Path to save test results.") + + return parser.parse_args() + + +def main(): + args = parse_args() + + distributed_gptdata_streaming_test( + sequence_length=args.sequence_length, + micro_batch_size=args.micro_batch_size, + batch_size=args.batch_size, + tensor_parallel=args.tensor_parallel, + pipeline_parallel=args.pipeline_parallel, + sequence_data_parallel=args.sequence_data_parallel, + total_gpus=args.total_gpus, + redis_port=args.redis_port, + result_path=args.result_path, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 092bbba90..06c14d02a 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -1,3 +1,7 @@ +import logging +import os +import pickle +import socket import threading import fakeredis @@ -5,9 +9,6 @@ import pytest import torch -from fast_llm.config import NoAutoValidate -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.config import ( RedisConfig, SamplingConfig, @@ -20,7 +21,10 @@ from fast_llm.data.sample.pipeline_rl import PipelineRLSample from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.models.gpt.config import GPTBatchConfig +from tests.utils.utils import requires_cuda + +logger = logging.getLogger(__name__) + # --------------------------------------------------------------------- # Fixtures @@ -44,20 +48,14 @@ def monkeypatched_redis(monkeypatch, fake_redis): @pytest.fixture def stream_config(): - return StreamingDatasetConfig( - redis=RedisConfig( - host="localhost", - port=6379, - 
stream_key="test_stream", - group_name="test_group", - consumer_name_prefix="consumer", - ), - data_key="data", - ) + return get_stream_config() @pytest.fixture def fake_redis_server(stream_config): + # We search for free port as port from previous test can still be not free even after server shutdown + stream_config = stream_config.from_dict(stream_config.to_dict(), {("redis", "port"): find_free_port()}) + server_address = (stream_config.redis.host, stream_config.redis.port) server = fakeredis.TcpFakeServer(server_address, server_type="redis") @@ -82,6 +80,26 @@ def fake_redis_server(stream_config): # --------------------------------------------------------------------- +def find_free_port(): + """Find a free TCP port and return it.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_stream_config(): + return StreamingDatasetConfig( + redis=RedisConfig( + host="localhost", + port=6379, + stream_key="test_stream", + group_name="test_group", + consumer_name_prefix="consumer", + ), + data_key="data", + ) + + def push_msg(redis_client, config, tokens=None, is_eof=False): """Push a message into FakeRedis stream.""" if is_eof: @@ -118,26 +136,12 @@ def generate_parallelism_variants(total_gpus: int): if total_gpus > 1 and total_gpus % 2 == 1: total_gpus = total_gpus - 1 - if total_gpus == 0: - return [ - { - "data_groups": 1, - "tensor_parallel": 1, - "pipeline_parallel": 1, - "sequence_data_parallel": 1, - "total_gpus": 0, - } - ] + if total_gpus < 2: + # No gpu and one gpu tests are the same, + # so no need of creation of variant for a single gpu + return [] - variants = [ - { - "data_groups": 1, - "tensor_parallel": 1, - "pipeline_parallel": 1, - "sequence_data_parallel": 1, - "total_gpus": 1, - } - ] + variants = [] for gpus in range(2, total_gpus + 1, 2): # try all possible numbers of data groups (1..total_gpus) @@ -157,19 +161,24 @@ def generate_parallelism_variants(total_gpus: int): continue sp = rem_after_tp // pp try: - # instead of repeating all safeguards here just try to instantiate distributed config to check if combination is valid - DistributedConfig( + # instead of repeating all safeguards here just try to + # instantiate distributed config to check if combination is valid + dist_config = DistributedConfig( tensor_parallel=tp, pipeline_parallel=pp, sequence_data_parallel=sp, world_size=gpus, + # TODO: works only on one node + local_world_size=gpus, rank=0, ) except Exception: continue + variants.append( { "data_groups": data_groups, + "batch_data_parallel": dist_config.batch_data_parallel, "tensor_parallel": tp, "pipeline_parallel": pp, "sequence_data_parallel": sp, @@ -179,6 +188,113 @@ def generate_parallelism_variants(total_gpus: int): return variants +def run_distributed_gptdata_streaming_test( + fake_redis_server, + variant, + run_distributed_script, + result_path, + request, +): + import tests.data.gptdata_streaming_test + + stream_config, fake_redis = fake_redis_server + + sequence_length = 10 + micro_batch_size = 2 + batch_size = micro_batch_size * variant["batch_data_parallel"] + tensor_parallel = variant["tensor_parallel"] + pipeline_parallel = variant["pipeline_parallel"] + sequence_data_parallel = variant["sequence_data_parallel"] + total_gpus = variant["total_gpus"] + redis_port = stream_config.redis.port + + for i in range(batch_size): + push_msg(fake_redis, stream_config, [i] * sequence_length) + + result_path = result_path / "distributed_gptdata_streaming_test" / request.node.name + + if 
total_gpus > 0: + script = [ + "-m", + tests.data.gptdata_streaming_test.__name__, + "--sequence-length", + str(sequence_length), + "--micro-batch-size", + str(micro_batch_size), + "--batch-size", + str(batch_size), + "--tensor-parallel", + str(tensor_parallel), + "--pipeline-parallel", + str(pipeline_parallel), + "--sequence-data-parallel", + str(sequence_data_parallel), + "--total-gpus", + str(total_gpus), + "--result-path", + str(result_path), + "--redis-port", + str(redis_port), + ] + # TODO: distributed_capture is ignored now inside the script + if request.config.getoption("distributed_capture"): + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." + ) + else: + script.append("--no-distributed-capture") + + env = os.environ.copy() + env["PYTHONHASHSEED"] = "42" + run_distributed_script(script, num_gpus=total_gpus, env=env) + else: + tests.data.gptdata_streaming_test.distributed_gptdata_streaming_test( + sequence_length=sequence_length, + micro_batch_size=micro_batch_size, + batch_size=batch_size, + tensor_parallel=tensor_parallel, + pipeline_parallel=pipeline_parallel, + sequence_data_parallel=sequence_data_parallel, + total_gpus=total_gpus, + redis_port=redis_port, + result_path=result_path, + ) + + check_distributed_gptdata_streaming_test_results( + result_path=result_path, + micro_batch_size=micro_batch_size, + batch_data_parallel=variant["batch_data_parallel"], + total_gpu=variant["total_gpus"], + ) + + +def check_distributed_gptdata_streaming_test_results( + result_path, + micro_batch_size, + batch_data_parallel, + total_gpu, +): + batch_data_parallel_size = total_gpu // batch_data_parallel if total_gpu > 0 else 1 + sample_idx = set() + for i in range(batch_data_parallel): + ref_batch = None + for j in range(batch_data_parallel_size): + with (result_path / f"{i}_{j}" / "batch.pkl").open("rb") as f: + batch = pickle.load(f) + if ref_batch is None: + ref_batch = batch + else: + # batches for same batch_data_parallel_group must be equal on all ranks + assert torch.equal(batch.tokens.tokens, ref_batch.tokens.tokens) + for j in range(micro_batch_size): + val = ref_batch.tokens.tokens[j, 0].item() + # all samples in batches between groups and in the batch must be unique + assert val not in sample_idx + sample_idx.add(val) + # unique sample count must be the same as global batch size + assert len(sample_idx) == micro_batch_size * batch_data_parallel + + # --------------------------------------------------------------------- # Tests # --------------------------------------------------------------------- @@ -309,49 +425,29 @@ def test_sampling_overflow_creates_two_and_padding_final(monkeypatched_redis, st assert out[1].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] -def distributed_gptdata_streaming_test( - stream_config, - sequence_length, - micro_batch_size, - batch_size, - tensor_parallel, - pipeline_parallel, - sequence_data_parallel, - total_gpus, -): - distributed = Distributed( - DistributedConfig( - tensor_parallel=tensor_parallel, - pipeline_parallel=pipeline_parallel, - sequence_data_parallel=sequence_data_parallel, - ), - use_cpu=total_gpus > 0, - ) - sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed) - - data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} - data_config = GPTDataConfig.from_dict(data_config) - - data = GPTData(data_config, distributed.config) +def 
test_gptdata_streaming_single_consumer(fake_redis_server, run_distributed_script_lean, result_path, request): - data.setup(distributed, {"streaming1": sampling_data.parameters}, "/tmp") - - with NoAutoValidate(): - batch_config = GPTBatchConfig( - micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length - ) - batch_config.setup(distributed_config=distributed.config) - batch_config.validate() - - data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) - - batch = next(data_iter) - assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length) + run_distributed_gptdata_streaming_test( + fake_redis_server=fake_redis_server, + variant={ + "data_groups": 1, + "tensor_parallel": 1, + "pipeline_parallel": 1, + "sequence_data_parallel": 1, + "total_gpus": 0, + "batch_data_parallel": 1, + }, + run_distributed_script=run_distributed_script_lean, + result_path=result_path, + request=request, + ) variants = generate_parallelism_variants(torch.cuda.device_count()) +@pytest.mark.slow +@requires_cuda @pytest.mark.parametrize( "variant", variants, @@ -360,27 +456,13 @@ def distributed_gptdata_streaming_test( for v in variants ], ) -def test_gptdata_streaming(fake_redis_server, variant): - if variant["total_gpus"] > 1: - pytest.skip(f"Skipping, not implemented for gpu count {variant["total_gpus"]}") - - stream_config, fake_redis = fake_redis_server - - sequence_length = 10 - micro_batch_size = 2 - batch_size = micro_batch_size * variant["data_groups"] // variant["sequence_data_parallel"] - - for _ in range(batch_size): - push_msg(fake_redis, stream_config, list(range(sequence_length))) - - # TODO: call with torchrun.distributed for more than 1 gpu - distributed_gptdata_streaming_test( - stream_config, - sequence_length, - micro_batch_size, - batch_size, - variant["tensor_parallel"], - variant["pipeline_parallel"], - variant["sequence_data_parallel"], - variant["total_gpus"], +def test_gptdata_streamin_gpus(fake_redis_server, variant, run_distributed_script_lean, result_path, request): + # TODO: make tests on the same number of gpu as subtests similar to how it is done in the test_model + # for speed + run_distributed_gptdata_streaming_test( + fake_redis_server=fake_redis_server, + variant=variant, + run_distributed_script=run_distributed_script_lean, + result_path=result_path, + request=request, ) diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 5a24e5936..f8d30db27 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -59,6 +59,17 @@ def run_distributed_script( ) +@pytest.fixture(scope="session") +def run_distributed_script_lean( + worker_resources: "WorkerResources", +): + return functools.partial( + do_run_distributed_script, + rendezvous_port=worker_resources.rendezvous_port, + torchrun_port=worker_resources.torchrun_port, + ) + + @pytest.fixture(scope="session") def run_test_script_base_path(model_testing_config, result_path, request): return result_path / "models" / model_testing_config.name
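Note on the sampling semantics the streaming tests above assert: the iterable sampler packs variable-length rollouts into fixed-length samples of sequence_length + extra_tokens tokens, drops any rollout that is longer than that length on its own, and pads a partial sample with -100 whenever a rollout would overflow the boundary or the stream signals EOF. The snippet below is a rough, self-contained approximation of that packing loop for illustration only; the function and parameter names (pack_rollouts, padding_index) are invented here and are not the identifiers used in fast_llm/data/dataset/sampled.py.

    import typing


    def pack_rollouts(
        rollouts: typing.Iterable[list[int]],
        sample_length: int,
        max_samples: int,
        padding_index: int = -100,
    ) -> list[list[int]]:
        samples: list[list[int]] = []
        current: list[int] = []
        for tokens in rollouts:
            if len(tokens) > sample_length:
                # Rollout can never fit in a single sample: drop it.
                continue
            if len(current) + len(tokens) > sample_length:
                # Overflow: pad the in-progress sample and start a new one.
                current.extend([padding_index] * (sample_length - len(current)))
                samples.append(current)
                current = []
                if len(samples) >= max_samples:
                    return samples
            current.extend(tokens)
            if len(current) == sample_length:
                samples.append(current)
                current = []
                if len(samples) >= max_samples:
                    return samples
        if current:
            # End of stream (EOF): pad the final partial sample.
            current.extend([padding_index] * (sample_length - len(current)))
            samples.append(current)
        return samples


    if __name__ == "__main__":
        # Mirrors test_sampling_skips_too_long_doc_and_padding_final: the 20-token
        # rollout is dropped and the 8-token rollout is padded to length 10.
        assert pack_rollouts([list(range(20)), list(range(8))], 10, 1) == [list(range(8)) + [-100, -100]]

Running the module executes the inline assertion, which reproduces the expected output of the too-long-rollout test; the other sampling tests above can be checked against the same helper in the same way.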