Draft
Changes from all commits (47 commits)
1a18929  Dataset interface (jlamypoirier, Oct 15, 2025)
fd63846  misc (jlamypoirier, Oct 15, 2025)
2486caf  fix (jlamypoirier, Oct 15, 2025)
92e93e8  Language model sample (jlamypoirier, Oct 16, 2025)
d6f6944  fix (jlamypoirier, Oct 16, 2025)
5c802fa  fixes (jlamypoirier, Oct 16, 2025)
95d1840  test (jlamypoirier, Oct 16, 2025)
eafd9cb  fixes (jlamypoirier, Oct 17, 2025)
c56df69  cleanup (jlamypoirier, Oct 17, 2025)
7f437e1  misc (jlamypoirier, Oct 17, 2025)
dfd27f5  misc (jlamypoirier, Oct 17, 2025)
90cd009  Memmap dataset (jlamypoirier, Oct 18, 2025)
acfd30e  fixes (jlamypoirier, Oct 29, 2025)
34939e9  fixes (jlamypoirier, Oct 29, 2025)
c5fa072  int64 (jlamypoirier, Oct 29, 2025)
cd28676  Test and fix preparator (jlamypoirier, Nov 5, 2025)
435d214  fix (jlamypoirier, Nov 5, 2025)
f6bef55  fix (jlamypoirier, Nov 6, 2025)
e05d9a1  fix (jlamypoirier, Nov 6, 2025)
9ba8d1b  fix (jlamypoirier, Nov 6, 2025)
b35b297  fixes (jlamypoirier, Nov 6, 2025)
abe2357  misc (jlamypoirier, Nov 11, 2025)
1801d87  fix (jlamypoirier, Nov 11, 2025)
2223b85  fix right stage mode (bigximik, Nov 13, 2025)
a9a4ace  newer transformers fixes (bigximik, Nov 13, 2025)
97f2b60  fix distributed tests skip on single gpu (bigximik, Nov 13, 2025)
0fdc978  set mamba 2 style model conversions to broke (bigximik, Nov 13, 2025)
665deb5  Merge branch 'jlp/dataset_interface' of github.com:ServiceNow/Fast-LL… (bigximik, Nov 17, 2025)
4d03889  Merge branch 'jlp/lm_sample' of github.com:ServiceNow/Fast-LLM into d… (bigximik, Nov 17, 2025)
224c2ec  mmaba2 enable conversion tests (bigximik, Nov 17, 2025)
f1afbf2  Merge branch 'jlp/memmap_dataset' of github.com:ServiceNow/Fast-LLM i… (bigximik, Nov 17, 2025)
00bba27  added model_and_sequence_data_group (bigximik, Nov 23, 2025)
5b20276  added Iterable dataset base classes (bigximik, Nov 23, 2025)
978a68f  added naive sampled iterable dataset (bigximik, Nov 23, 2025)
066a0bf  added iterable dataset configs, streaming dataset and PipelineRL samp… (bigximik, Nov 23, 2025)
68b3d65  added distributed data loader wrapper (bigximik, Nov 23, 2025)
2fbfe99  added iterable dataset to gpt data (bigximik, Nov 23, 2025)
0892523  appended comment (bigximik, Nov 23, 2025)
54fadb4  changed base classes for iterable dataset configs (bigximik, Nov 24, 2025)
4e11bf3  fix batch type (bigximik, Nov 24, 2025)
8428df8  fix added name property to the class (bigximik, Nov 24, 2025)
04ee4d7  add eof for tests (bigximik, Nov 24, 2025)
1217998  change base class to torch iterable (bigximik, Nov 24, 2025)
c542dac  added straming dataset, sampling and base data tests (bigximik, Nov 24, 2025)
3999a8e  merge from main (bigximik, Nov 24, 2025)
c6ef780  merge from main (bigximik, Nov 24, 2025)
a1556f8  change import (bigximik, Nov 24, 2025)
52 changes: 52 additions & 0 deletions fast_llm/data/data/data_loader_wrapper.py
@@ -0,0 +1,52 @@
import torch.distributed
import torch.utils.data.dataloader

from fast_llm.core.distributed import broadcast_object


class DistributedDataLoaderWrapper:
"""
Wraps a regular dataloader so that only the process group leader
loads data, and then broadcasts the batch to other ranks in the group.
"""

def __init__(
self,
dataloader: torch.utils.data.dataloader.DataLoader | None,
rank: int,
process_group: torch.distributed.ProcessGroup | None,
):
self.dataloader = dataloader
self.rank = rank
self.process_group = process_group

assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None)

def __iter__(self):
if self.rank == 0:
self.iterator = iter(self.dataloader)
if self.process_group is None:
return self.iterator
return self

def __next__(self):
# TODO:
# Instead of broadcasting a general object, make this iterator yield an actual Batch class.
# Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
# efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
# entire Batch object, which is inefficient for tensors because it serializes
# (pickles) them before sending.

if self.rank == 0:
try:
data = next(self.iterator) # may raise StopIteration
except Exception as e:
data = e
data = broadcast_object(data, self.process_group, 0)
else:
data = broadcast_object(None, self.process_group, 0)

if isinstance(data, Exception):
raise data

return data
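The wrapper is constructed asymmetrically across the group: the leader owns a real `DataLoader`, everyone else passes `None`. A minimal usage sketch under that assumption (the helper function and batch size below are illustrative, not part of this PR):

```python
# Illustrative sketch: only the group leader owns a DataLoader; the other
# ranks pass None and receive each batch via broadcast_object.
import torch.distributed
import torch.utils.data

from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper


def make_group_iterator(dataset, group: torch.distributed.ProcessGroup | None):
    rank = 0 if group is None else group.rank()
    # The assertion in __init__ enforces this asymmetry: a loader on the
    # leader, None everywhere else.
    loader = torch.utils.data.DataLoader(dataset, batch_size=4) if rank == 0 else None
    return iter(DistributedDataLoaderWrapper(loader, rank, group))
```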
44 changes: 38 additions & 6 deletions fast_llm/data/data/gpt/data.py
@@ -8,12 +8,14 @@

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.abstract import SampledDataset, SampledIterableDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.data.sample.pipeline_rl import PipelineRLBatch
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
@@ -86,7 +88,12 @@ def setup(
dataset_name=dataset_name,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
if isinstance(dataset, SampledDataset):
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
else:
# Do not set a monitor for iterable datasets, as the monitor only works with map-style datasets
assert isinstance(dataset, SampledIterableDataset)
self._datasets[dataset_name] = dataset

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True
@@ -112,9 +119,11 @@ def get_iterator(
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")

return iter(
torch.utils.data.DataLoader(
self._datasets[dataset_name], # noqa
dataset = self._datasets[dataset_name]

if isinstance(dataset, SampledDataset):
data_loader = torch.utils.data.DataLoader(
dataset, # noqa
batch_sampler=SampledDatasetIterator(
total_samples=len(self._datasets[dataset_name]),
begin_index=consumed_samples,
@@ -128,4 +137,27 @@ def get_iterator(
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
)

elif isinstance(dataset, SampledIterableDataset):
if (
self.distributed.model_and_sequence_data_group is None
or self.distributed.model_and_sequence_data_group.rank() == 0
):
rank = 0
data_loader = torch.utils.data.DataLoader(
dataset, # noqa
batch_size=batch_config.micro_batch_size,
num_workers=0 if num_workers == 0 else 1,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=PipelineRLBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
else:
rank = self.distributed.model_and_sequence_data_group.rank()
data_loader = None
data_loader = DistributedDataLoaderWrapper(
data_loader, rank, self.distributed.model_and_sequence_data_group
)

return iter(data_loader)
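The TODO in `data_loader_wrapper.py` proposes broadcasting tensors directly instead of pickling the whole batch. A hedged sketch of that idea, assuming a proposed `get_state_dict()` that returns a flat `{name: tensor}` dict whose shapes and dtypes are already known on all ranks; neither method exists yet in the Batch classes:

```python
# Sketch of the proposed optimization, not part of this PR: broadcast each
# tensor's storage directly rather than pickling the whole Batch object.
import torch
import torch.distributed


def broadcast_batch_tensors(
    state_dict: dict[str, torch.Tensor],
    group: torch.distributed.ProcessGroup,
    src: int = 0,
) -> dict[str, torch.Tensor]:
    # Assumes every rank pre-allocated tensors with matching shapes and
    # dtypes, e.g. derived from the batch configuration. Note that `src`
    # is a global rank, as torch.distributed.broadcast expects.
    for name in sorted(state_dict):  # fixed order so all ranks agree
        torch.distributed.broadcast(state_dict[name], src=src, group=group)
    return state_dict
```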
24 changes: 23 additions & 1 deletion fast_llm/data/dataset/abstract.py
@@ -1,6 +1,8 @@
import abc
import typing

import torch.utils.data

from fast_llm.data.sample.abstract import Sample

if typing.TYPE_CHECKING:
@@ -43,8 +45,28 @@ def __len__(self) -> int:
pass


class SamplableDataset[SampleType: Sample](Dataset[SampleType]):
# NOTE: We need to inherit from IterableDataset, otherwise the torch data loader cannot detect it properly
class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]):
"""
A sampled dataset class that provides an iterator over samples.
"""

# NOTE: We add `name` here so it stays compatible with the Fast-LLM `Dataset` interface
@property
@abc.abstractmethod
def name(self) -> str:
"""
A name for the dataset to facilitate identification and debugging.
"""


class SamplableDataset[SampleType: Sample](Dataset[SampleType]):
@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass


class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]:
pass
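For reference, a minimal concrete `SamplableIterableDataset` built on a toy in-memory source; the class is illustrative and not part of this PR:

```python
# Illustrative subclass sketch; InMemoryIterableDataset is hypothetical.
import typing

from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledIterableDataset
from fast_llm.data.sample.abstract import Sample

if typing.TYPE_CHECKING:
    from fast_llm.data.dataset.config import SamplingData


class InMemoryIterableDataset[SampleType: Sample](SamplableIterableDataset[SampleType]):
    def __init__(self, name: str, samples: list[SampleType]):
        self._name = name
        self._samples = samples

    @property
    def name(self) -> str:
        return self._name

    def __iter__(self) -> typing.Iterator[SampleType]:
        yield from self._samples

    def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]:
        # Reuse the naive packer added in fast_llm/data/dataset/sampled.py.
        from fast_llm.data.dataset.sampled import NaiveSampledIterableDataset

        return NaiveSampledIterableDataset(self, config)
```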
78 changes: 73 additions & 5 deletions fast_llm/data/dataset/config.py
@@ -8,13 +8,19 @@
import typing

from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.data.dataset.abstract import (
SamplableDataset,
SamplableIterableDataset,
SampledDataset,
SampledIterableDataset,
)
from fast_llm.data.sample.abstract import Sample
from fast_llm.utils import Assert, normalize_probabilities

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.pipeline_rl import PipelineRLSample
from fast_llm.engine.distributed.distributed import Distributed

logger = logging.getLogger(__name__)
@@ -105,19 +111,23 @@ class DatasetConfig[SampleType: Sample](Config):
@config_class(registry=True)
class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]):
"""
A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training.
A sampled dataset containing a prepared list or stream of samples to be consumed sequentially (as-is) during training.
"""

def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
def build_and_sample(
self, sampling: SamplingData
) -> SampledDataset[SampleType] | SampledIterableDataset[SampleType]:
raise NotImplementedError()


@config_class()
class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]):
def build(self) -> SamplableDataset[SampleType]:
def build(self) -> SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]:
raise NotImplementedError()

def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
def build_and_sample(
self, sampling: SamplingData
) -> SampledDataset[SampleType] | SampledIterableDataset[SampleType]:
return self.build().sample(sampling)


@@ -297,3 +307,61 @@ def build(self) -> "IndexedDataset[SampleType]":
return LegacyMemmapDataset[SampleType](name, self.path)
else:
raise FileNotFoundError(self.path)


@config_class()
class RedisConfig(Config):
host: str = Field(
default="localhost",
desc="Hostname or IP address of the Redis server.",
hint=FieldHint.core,
)

port: int = Field(
default=6379,
desc="Port number on which the Redis server is running.",
hint=FieldHint.core,
)

stream_key: str = Field(
default=None,
desc="Name of the Redis stream to read data from.",
hint=FieldHint.core,
)

group_name: str = Field(
default="fast_llm_dp_group",
desc="Name of the Redis consumer group used for data-parallel reading.",
hint=FieldHint.core,
)

consumer_name_prefix: str = Field(
default="fast_llm_dp_group_consumer",
desc="Prefix used to generate unique consumer names for each rank.",
hint=FieldHint.core,
)


@config_class(dynamic_type={SampledDatasetConfig: "streaming"})
class StreamingDatasetConfig[SampleType: PipelineRLSample](SamplableDatasetConfig[SampleType]):
"""
Configuration for a streaming dataset that reads training data from a Redis stream.
"""

_abstract = False

redis: RedisConfig = Field(
desc="Redis connection and stream settings used to fetch incoming training data.",
hint=FieldHint.core,
)

data_key: str = Field(
default="data",
desc="The Redis message field containing the serialized sample payload.",
hint=FieldHint.core,
)

def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]:
from fast_llm.data.dataset.streaming import StreamingDataset

return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling)
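The `StreamingDataset` implementation itself is not shown in this diff. As a rough sketch of how a consumer-group read matching these config fields could look with redis-py (the function and the acknowledgement policy are assumptions, not the PR's actual implementation):

```python
# Hypothetical consumer sketch matching RedisConfig's fields; the actual
# StreamingDataset implementation is not part of this diff.
import redis


def read_samples(config, rank: int):
    client = redis.Redis(host=config.host, port=config.port)
    try:
        # Create the consumer group once; ignore the error if it already exists.
        client.xgroup_create(config.stream_key, config.group_name, id="0", mkstream=True)
    except redis.exceptions.ResponseError:
        pass
    consumer = f"{config.consumer_name_prefix}_{rank}"
    while True:
        # ">" requests messages never delivered to this group before, so each
        # data-parallel rank receives a disjoint slice of the stream.
        entries = client.xreadgroup(
            config.group_name, consumer, {config.stream_key: ">"}, count=1, block=1000
        )
        for _stream, messages in entries or []:
            for message_id, fields in messages:
                yield fields[config.data_key.encode()]  # serialized sample payload
                client.xack(config.stream_key, config.group_name, message_id)
```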
62 changes: 61 additions & 1 deletion fast_llm/data/dataset/sampled.py
@@ -8,7 +8,7 @@
import torch
import yaml

from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset, SampledIterableDataset
from fast_llm.data.dataset.config import SamplingData, ShufflingType
from fast_llm.data.dataset.indexed import IndexedDataset
from fast_llm.data.sample.abstract import Sample
@@ -429,3 +429,63 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:

self._unshuffled_tokens = data["unshuffled_tokens"]
self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch


class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
def __init__(
self,
iterable_dataset: SamplableIterableDataset[SampleType],
sampling: SamplingData,
):
self._dataset = iterable_dataset
self._config = sampling
assert not self._config.parameters.truncate_documents
assert self._config.config.shuffle == ShufflingType.disabled

def __iter__(self) -> typing.Iterator[SampleType]:
sample_length = self._config.parameters.sequence_length + self._config.parameters.extra_tokens
max_samples = self._config.parameters.num_samples
current_sample_length = 0
documents: list[SampleType] = []
num_samples = 0
for doc in self._dataset:
if len(doc) > sample_length:
logging.warning(f"Dropping document with length {len(doc)} greater than sample_length {sample_length}")
continue
if current_sample_length + len(doc) > sample_length:
padding_length = sample_length - current_sample_length
assert padding_length > 0
documents.append(documents[-1].get_padding(padding_length))

yield documents[0].from_documents(documents)

num_samples += 1
if num_samples >= max_samples:
break

documents = [doc]
current_sample_length = len(doc)
else:
documents.append(doc)
current_sample_length += len(doc)

if current_sample_length == sample_length:
yield documents[0].from_documents(documents)

num_samples += 1
if num_samples >= max_samples:
break

documents = []
current_sample_length = 0

if num_samples < max_samples and current_sample_length > 0:
padding_length = sample_length - current_sample_length
assert padding_length > 0
documents.append(documents[-1].get_padding(padding_length))

yield documents[0].from_documents(documents)

@property
def name(self) -> str:
return self._dataset.name
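To make the packing behavior concrete, a small simulation over document lengths only (illustrative; `pack_lengths` is not part of the PR and ignores `num_samples` and the padding samples' contents):

```python
def pack_lengths(lengths: list[int], sample_length: int) -> list[list[int]]:
    """Simulate the naive packer on document lengths alone (illustrative)."""
    samples, buffer, current = [], [], 0
    for n in lengths:
        if n > sample_length:
            continue  # dropped with a warning, as in NaiveSampledIterableDataset
        if current + n > sample_length:
            buffer.append(sample_length - current)  # pad and flush the full sample
            samples.append(buffer)
            buffer, current = [n], n
        else:
            buffer.append(n)
            current += n
        if current == sample_length:  # exact fit: flush without padding
            samples.append(buffer)
            buffer, current = [], 0
    if current > 0:  # stream exhausted: pad and flush the final partial sample
        buffer.append(sample_length - current)
        samples.append(buffer)
    return samples


# Documents of lengths 3, 4, 2, 5 packed into samples of length 8:
# [3, 4] overflows on the 2, so it is padded with 1 token; [2, 5] is
# padded at end of stream.
assert pack_lengths([3, 4, 2, 5], 8) == [[3, 4, 1], [2, 5, 1]]
```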