diff --git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py new file mode 100644 index 000000000..f9e517248 --- /dev/null +++ b/fast_llm/data/data/data_loader_wrapper.py @@ -0,0 +1,52 @@ +import torch.distributed +import torch.utils.data.dataloader + +from fast_llm.core.distributed import broadcast_object + + +class DistributedDataLoaderWrapper: + """ + Wraps a regular dataloader so that only the process group leader + loads data, and then broadcasts the batch to other ranks in the group. + """ + + def __init__( + self, + dataloader: torch.utils.data.dataloader.DataLoader | None, + rank: int, + process_group: torch.distributed.ProcessGroup | None, + ): + self.dataloader = dataloader + self.rank = rank + self.process_group = process_group + + assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None) + + def __iter__(self): + if self.rank == 0: + self.iterator = iter(self.dataloader) + if self.process_group is None: + return self.iterator + return self + + def __next__(self): + # TODO: + # Instead of broadcasting a general object, make this iterator yield an actual Batch class. + # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can + # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the + # entire Batch object, which is inefficient for tensors because it serializes + # (pickles) them before sending. + + if self.rank == 0: + try: + data = next(self.iterator) # may raise StopIteration + except Exception as e: + data = e + data = broadcast_object(data, self.process_group, 0) + else: + data = broadcast_object(None, self.process_group, 0) + + if isinstance(data, Exception): + raise data + + return data diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index de47ef761..7cd746ed0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,12 +8,14 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SampledDataset, SampledIterableDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.pipeline_rl import PipelineRLBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -86,7 +88,12 @@ def setup( dataset_name=dataset_name, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) - self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + if isinstance(dataset, SampledDataset): + self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + else: + # Do not set monitor for iterable dataset as monitor only works with map style datasets + assert isinstance(dataset, SampledIterableDataset) + self._datasets[dataset_name] = dataset safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True @@ -112,9 +119,11 @@ def 
get_iterator(
         Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
         log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
-        return iter(
-            torch.utils.data.DataLoader(
-                self._datasets[dataset_name],  # noqa
+        dataset = self._datasets[dataset_name]
+
+        if isinstance(dataset, SampledDataset):
+            data_loader = torch.utils.data.DataLoader(
+                dataset,  # noqa
                 batch_sampler=SampledDatasetIterator(
                     total_samples=len(self._datasets[dataset_name]),
                     begin_index=consumed_samples,
@@ -128,4 +137,27 @@ def get_iterator(
                 collate_fn=LanguageModelBatch.from_samples,
                 multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
             )
-        )
+
+        elif isinstance(dataset, SampledIterableDataset):
+            if (
+                self.distributed.model_and_sequence_data_group is None
+                or self.distributed.model_and_sequence_data_group.rank() == 0
+            ):
+                rank = 0
+                data_loader = torch.utils.data.DataLoader(
+                    dataset,  # noqa
+                    batch_size=batch_config.micro_batch_size,
+                    num_workers=0 if num_workers == 0 else 1,
+                    prefetch_factor=prefetch_factor,
+                    pin_memory=True,
+                    collate_fn=PipelineRLBatch.from_samples,
+                    multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
+                )
+            else:
+                rank = self.distributed.model_and_sequence_data_group.rank()
+                data_loader = None
+            data_loader = DistributedDataLoaderWrapper(
+                data_loader, rank, self.distributed.model_and_sequence_data_group
+            )
+
+        return iter(data_loader)
diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py
index 33942708b..0fe53b7ba 100644
--- a/fast_llm/data/dataset/abstract.py
+++ b/fast_llm/data/dataset/abstract.py
@@ -1,6 +1,8 @@
 import abc
 import typing
 
+import torch.utils.data
+
 from fast_llm.data.sample.abstract import Sample
 
 if typing.TYPE_CHECKING:
@@ -43,8 +45,28 @@ def __len__(self) -> int:
         pass
 
 
-class SamplableDataset[SampleType: Sample](Dataset[SampleType]):
+# NOTE: We need to inherit from IterableDataset, otherwise the torch DataLoader cannot detect it properly.
+class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]):
+    """
+    A sampled dataset class that provides an iterator over samples.
+    """
+
+    # NOTE: We add `name` here so the class stays compatible with the Fast-LLM `Dataset` interface.
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """
+        A name for the dataset to facilitate identification and debugging.
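+        Implementations should return a stable, human-readable identifier (e.g. the stream or underlying dataset it wraps).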
+ """ + + +class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass + + +class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): + @abc.abstractmethod + def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]: + pass diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7611b4a31..78969f5a7 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -8,13 +8,19 @@ import typing from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class -from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.dataset.abstract import ( + SamplableDataset, + SamplableIterableDataset, + SampledDataset, + SampledIterableDataset, +) from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset from fast_llm.data.sample.language_model import LanguageModelSample + from fast_llm.data.sample.pipeline_rl import PipelineRLSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -105,19 +111,23 @@ class DatasetConfig[SampleType: Sample](Config): @config_class(registry=True) class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ - A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. + A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample( + self, sampling: SamplingData + ) -> SampledDataset[SampleType] | SampledIterableDataset[SampleType]: raise NotImplementedError() @config_class() class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build(self) -> SamplableDataset[SampleType]: + def build(self) -> SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample( + self, sampling: SamplingData + ) -> SampledDataset[SampleType] | SampledIterableDataset[SampleType]: return self.build().sample(sampling) @@ -297,3 +307,61 @@ def build(self) -> "IndexedDataset[SampleType]": return LegacyMemmapDataset[SampleType](name, self.path) else: raise FileNotFoundError(self.path) + + +@config_class() +class RedisConfig(Config): + host: str = Field( + default="localhost", + desc="Hostname or IP address of the Redis server.", + hint=FieldHint.core, + ) + + port: int = Field( + default=6379, + desc="Port number on which the Redis server is running.", + hint=FieldHint.core, + ) + + stream_key: str = Field( + default=None, + desc="Name of the Redis stream to read data from.", + hint=FieldHint.core, + ) + + group_name: str = Field( + default="fast_llm_dp_group", + desc="Name of the Redis consumer group used for data-parallel reading.", + hint=FieldHint.core, + ) + + consumer_name_prefix: str = Field( + default="fast_llm_dp_group_consumer", + desc="Prefix used to generate unique consumer names for each rank.", + hint=FieldHint.core, + ) + + +@config_class(dynamic_type={SampledDatasetConfig: "streaming"}) +class 
StreamingDatasetConfig[SampleType: PipelineRLSample](SamplableDatasetConfig[SampleType]):
+    """
+    Configuration for a streaming dataset that reads training data from a Redis stream.
+    """
+
+    _abstract = False
+
+    redis: RedisConfig = Field(
+        desc="Redis connection and stream settings used to fetch incoming training data.",
+        hint=FieldHint.core,
+    )
+
+    data_key: str = Field(
+        default="data",
+        desc="The Redis message field containing the serialized sample payload.",
+        hint=FieldHint.core,
+    )
+
+    def build_and_sample(self, sampling: SamplingData) -> SampledIterableDataset[SampleType]:
+        from fast_llm.data.dataset.streaming import StreamingDataset
+
+        return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling)
diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py
index d51a68746..7ec659bd8 100644
--- a/fast_llm/data/dataset/sampled.py
+++ b/fast_llm/data/dataset/sampled.py
@@ -8,7 +8,7 @@
 import torch
 import yaml
 
-from fast_llm.data.dataset.abstract import SampledDataset
+from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset, SampledIterableDataset
 from fast_llm.data.dataset.config import SamplingData, ShufflingType
 from fast_llm.data.dataset.indexed import IndexedDataset
 from fast_llm.data.sample.abstract import Sample
@@ -429,3 +429,65 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         self._unshuffled_tokens = data["unshuffled_tokens"]
         self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch
+
+
+class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
+    def __init__(
+        self,
+        iterable_dataset: SamplableIterableDataset[SampleType],
+        sampling: SamplingData,
+    ):
+        self._dataset = iterable_dataset
+        self._config = sampling.config
+        self._parameters = sampling.parameters
+
+        assert not self._parameters.truncate_documents
+        assert self._config.shuffle == ShufflingType.disabled
+
+    def __iter__(self) -> typing.Iterator[SampleType]:
+        sample_length = self._parameters.sequence_length + self._parameters.extra_tokens
+        max_samples = self._parameters.num_samples
+        current_sample_length = 0
+        documents: list[SampleType] = []
+        num_samples = 0
+        for doc in self._dataset:
+            if len(doc) > sample_length:
+                logging.warning(f"Dropping document with length {len(doc)} greater than sample_length {sample_length}")
+                continue
+            if current_sample_length + len(doc) > sample_length:
+                padding_length = sample_length - current_sample_length
+                assert padding_length > 0
+                documents.append(documents[-1].get_padding(padding_length))
+
+                yield documents[0].from_documents(documents)
+
+                num_samples += 1
+                if num_samples >= max_samples:
+                    break
+
+                documents = [doc]
+                current_sample_length = len(doc)
+            else:
+                documents.append(doc)
+                current_sample_length += len(doc)
+
+            if current_sample_length == sample_length:
+                yield documents[0].from_documents(documents)
+
+                num_samples += 1
+                if num_samples >= max_samples:
+                    break
+
+                documents = []
+                current_sample_length = 0
+
+        if num_samples < max_samples and current_sample_length > 0:
+            padding_length = sample_length - current_sample_length
+            assert padding_length > 0
+            documents.append(documents[-1].get_padding(padding_length))
+
+            yield documents[0].from_documents(documents)
+
+    @property
+    def name(self) -> str:
+        return self._dataset.name
diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py
new file mode 100644
index 000000000..291886e2e
--- /dev/null
+++
b/fast_llm/data/dataset/streaming.py @@ -0,0 +1,125 @@ +import typing + +import torch.utils.data + +from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledIterableDataset +from fast_llm.data.dataset.config import SamplingData, StreamingDatasetConfig +from fast_llm.data.dataset.sampled import NaiveSampledIterableDataset +from fast_llm.data.sample.pipeline_rl import PipelineRLSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.distributed.distributed import Distributed + + +def dtype_from_string(name: str) -> torch.dtype: + try: + return getattr(torch, name) + except AttributeError: + raise ValueError(f"Unknown torch dtype: {name}") + + +class StreamingDataset[SampleType: PipelineRLSample](SamplableIterableDataset[SampleType]): + def __init__(self, config: StreamingDatasetConfig, distributed: Distributed): + super().__init__() + self._name = f"redis[{config.redis.host}:{config.redis.port}]({config.redis.stream_key}|{config.redis.group_name})[{config.data_key}]" + self._config = config + self.batch_data_rank = distributed.config.batch_data_rank + self.is_batch_data_group_leader = ( + distributed.model_and_sequence_data_group is None or distributed.model_and_sequence_data_group.rank() == 0 + ) + + @property + def name(self) -> str: + return self._name + + def sample(self, config: SamplingData) -> SampledIterableDataset[PipelineRLSample]: + # TODO: actually sample the dataset and not return docs + return NaiveSampledIterableDataset(self, config) + + def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, bool]: + return (self._name, self._config, self.batch_data_rank, self.is_batch_data_group_leader) + + def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool]): + name, config, batch_data_rank, is_batch_data_group_leader = state + self._name = name + self._config = config + self.batch_data_rank = batch_data_rank + self.is_batch_data_group_leader = is_batch_data_group_leader + + def __iter__(self) -> typing.Iterator[PipelineRLSample]: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise RuntimeError("StreamingDataset can work only with one instance per rank") + + if not self.is_batch_data_group_leader: + raise RuntimeError("Must be only called on the batch data group leader") + + import orjson + import redis + import redis.exceptions + + r = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) + + # Create the consumer group at the start of the stream ("0") + # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True + try: + r.xgroup_create( + name=self._config.redis.stream_key, groupname=self._config.redis.group_name, id="0", mkstream=True + ) + except redis.exceptions.ResponseError as e: + if "BUSYGROUP" in str(e): + # Consumer group already exists + pass + else: + raise + + while True: + # XREADGROUP reads from the consumer group + # COUNT: max number of messages to fetch at once + # BLOCK: wait for new messages (milliseconds) + messages = r.xreadgroup( + groupname=self._config.redis.group_name, + consumername=f"{self._config.redis.consumer_name_prefix}_{self.batch_data_rank}", + # ">" means read only new messages that were never delivered to this consumer + streams={self._config.redis.stream_key: ">"}, + count=1, + block=5000, # wait up to 5 seconds + ) + + data = None + if messages: + for stream_key, msgs in messages: + assert stream_key == 
self._config.redis.stream_key.encode()
+                    for msg_id, msg_data in msgs:
+                        r.xack(self._config.redis.stream_key, self._config.redis.group_name, msg_id)
+                        data = orjson.loads(msg_data[self._config.data_key.encode()])
+                        # NOTE: The EOF handling is only meant for testing with fakeredis:
+                        # with a real consumer group the EOF message is delivered to a single consumer only.
+                        if data.get("eof"):
+                            break
+                        yield self._sample_from_dict(data)
+                    if data.get("eof"):
+                        break
+            if data is not None and data.get("eof"):
+                break
+
+    def _sample_from_dict(self, data: dict) -> PipelineRLSample:
+        tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"]))
+        sample_size = len(tokens)
+        if "loss_masking_spans" in data:
+            loss_masking_spans = [tuple(el) for el in data["loss_masking_spans"]]
+        else:
+            loss_masking_spans = None
+        if "chosen_spans" in data:
+            chosen_spans = [tuple(el) for el in data["chosen_spans"]]
+        else:
+            chosen_spans = None
+        if "rejected_spans" in data:
+            rejected_spans = [tuple(el) for el in data["rejected_spans"]]
+        else:
+            rejected_spans = None
+        return PipelineRLSample(
+            TokenSample(tokens, [sample_size]),
+            RangeSample(loss_masking_spans, sample_size) if loss_masking_spans is not None else None,
+            RangeSample(chosen_spans, sample_size) if chosen_spans is not None else None,
+            RangeSample(rejected_spans, sample_size) if rejected_spans is not None else None,
+        )
diff --git a/fast_llm/data/sample/pipeline_rl.py b/fast_llm/data/sample/pipeline_rl.py
new file mode 100644
index 000000000..21c434ad9
--- /dev/null
+++ b/fast_llm/data/sample/pipeline_rl.py
@@ -0,0 +1,102 @@
+import typing
+
+import torch
+
+from fast_llm.data.sample.abstract import Batch, Sample
+from fast_llm.data.sample.language_model import _crop_optional, _merge_optional
+from fast_llm.data.sample.range import RangeBatch, RangeSample
+from fast_llm.data.sample.token import TokenBatch, TokenSample
+
+
+class PipelineRLSample(Sample):
+    def __init__(
+        self,
+        tokens: TokenSample,
+        loss_masking_spans: RangeSample | None = None,
+        chosen_spans: RangeSample | None = None,
+        rejected_spans: RangeSample | None = None,
+    ):
+        self.tokens = tokens
+        self.loss_masking_spans = loss_masking_spans
+        self.chosen_spans = chosen_spans
+        self.rejected_spans = rejected_spans
+
+    @classmethod
+    def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
+        return cls(
+            TokenSample.from_documents([document.tokens for document in documents]),
+            _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]),
+            _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]),
+            _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]),
+        )
+
+    def crop(self, begin: int, end: int) -> typing.Self:
+        return self.__class__(
+            self.tokens.crop(begin, end),
+            _crop_optional(self.loss_masking_spans, begin, end),
+            _crop_optional(self.chosen_spans, begin, end),
+            _crop_optional(self.rejected_spans, begin, end),
+        )
+
+    def __len__(self) -> int:
+        return len(self.tokens)
+
+    def get_padding(self, size: int) -> typing.Self:
+        return PipelineRLSample(
+            self.tokens.get_padding(size),
+            None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size),
+            None if self.chosen_spans is None else self.chosen_spans.get_padding(size),
+            None if self.rejected_spans is None else self.rejected_spans.get_padding(size),
+        )
+
+
+class PipelineRLBatch(Batch):
+    def __init__(
+
self, + tokens: TokenBatch, + loss_masking_spans: RangeBatch | None = None, + chosen_spans: RangeBatch | None = None, + rejected_spans: RangeBatch | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_samples(cls, samples: typing.Iterable[PipelineRLSample]) -> typing.Self: + return cls( + TokenBatch.from_samples([sample.tokens for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), + ) + + def to_samples(self) -> list[PipelineRLSample]: + return [ + PipelineRLSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) + for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( + self.tokens.to_samples(), + None if self.loss_masking_spans is None else self.loss_masking_spans.to_samples(), + None if self.chosen_spans is None else self.chosen_spans.to_samples(), + None if self.rejected_spans is None else self.rejected_spans.to_samples(), + strict=True, + ) + ] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index f4dab5a26..89b3756d3 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -1,3 +1,5 @@ +import collections +import copy import dataclasses import enum import logging @@ -97,6 +99,7 @@ class DistributedDimNames: sequence_data = "sequence_data" batch_data = "batch_data" tensor_and_sequence_data = "tensor_and_sequence_data" + model_and_sequence_data = "model_and_sequence_data" tensor_and_data = "tensor_and_data" @@ -243,6 +246,17 @@ class DistributedConfig(Config): ) def _validate(self) -> None: + self._init_ranks_and_sizes() + self._init_distributed_dims() + + super()._validate() + + if self.reference_config is not None: + self.compare(self.reference_config, ValueError) + Assert.in_range(self.rank, 0, self.world_size) + Assert.in_range(self.local_rank, 0, self.local_world_size) + + def _init_ranks_and_sizes(self): if self.world_size is None: self.world_size = self.default_world_size if self.rank is None: @@ -271,6 +285,7 @@ def _validate(self) -> None: if self.tensor_parallel == 1 and self.sequence_tensor_parallel: self.sequence_tensor_parallel = False + def _init_distributed_dims(self): if self.reference_config is not None: self.reference_config.validate() if self.reference_config.reference_config is not None: @@ -352,12 +367,33 @@ def _validate(self) -> None: ) ) - super()._validate() + rank, global_ranks = self._get_model_and_sequence_data_rank_and_global_ranks() + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.model_and_sequence_data, + size=self.sequence_data_parallel * 
self.model_parallel, + rank=rank, + global_ranks=global_ranks, + ) + ) - if self.reference_config is not None: - self.compare(self.reference_config, ValueError) - Assert.in_range(self.rank, 0, self.world_size) - Assert.in_range(self.local_rank, 0, self.local_world_size) + def _get_model_and_sequence_data_rank_and_global_ranks(self) -> tuple[int, tuple[int]]: + # NOTE: The mapping from global ranks to batch-data-parallel groups is not easily + # expressible with a simple arithmetic pattern (e.g., fixed striding). To determine + # the grouping, we simulate rank initialization for every possible rank and record + # how ranks are assigned. This lets us compute: + # - `rank`: the index of the current rank within its batch-data-parallel group + # - the full list of global ranks that belong to this batch-data-parallel group. + + cfg = copy.copy(self) + batch_data_groups = collections.defaultdict(list) + for i in range(self.world_size): + cfg.rank = i + cfg._init_ranks_and_sizes() + if i == self.rank: + rank = len(batch_data_groups[cfg.batch_data_rank]) + batch_data_groups[cfg.batch_data_rank].append(i) + return rank, tuple(batch_data_groups[self.batch_data_rank]) def _get_global_ranks(self, size: int, stride: int) -> range: start = self.rank // (size * stride) * size * stride + self.rank % stride diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 302cfcdce..7b95cecfb 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -171,6 +171,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor]) self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data]) self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data]) + # Global ranks wrong with pipeline first, so we hide the dims as a safety check. if not self._config.pipeline_first: self.tensor_and_sequence_data_group = self.add_group( @@ -180,6 +181,10 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self._config.distributed_dims[DistributedDimNames.tensor_and_data] ) + self.model_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.model_and_sequence_data] + ) + self._config.log_first_rank(f"Setting random seeds...") dp_shift = self._config.dp_seed_shift * self._config.data_rank diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index aa4f2d570..9737546ad 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -203,7 +203,9 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the model. 
with torch.no_grad(): log_main_rank("Setting up model...") - self._multi_stage.setup(distributed) + self._multi_stage.setup( + distributed, mode=StageMode.inference if self._is_evaluation_only else StageMode.training + ) for name, reference_model in self._reference_models.items(): log_main_rank(f"Setting up `{name}` reference model...") reference_model.fast_llm_model.setup(distributed, StageMode.inference) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 40c4cfa87..a80c031aa 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import TransformersKwargs, logging from transformers.utils.generic import ModelOutput from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -1252,9 +1252,6 @@ def forward( return output -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - class AprielHybridSSMPreTrainedModel(PreTrainedModel): config_class = AprielHybridSSMConfig base_model_prefix = "model" @@ -1383,7 +1380,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py index c8723af5d..a67a302ef 100644 --- a/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py +++ b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py @@ -706,7 +706,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs, # TODO: Kwargs for Diffusion? : Unpack[KwargsForCausalLM], + **kwargs, # TODO: Kwargs for Diffusion? : Unpack[TransformersKwargs], ) -> MaskedLMOutput: r""" # TODO: Update docstring for diffusion diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index 5ad99ff96..d0e1988f1 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -15,7 +15,7 @@ from transformers.processing_utils import Unpack from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( - LossKwargs, + TransformersKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging, @@ -761,9 +761,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- - class MTPLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -812,7 +809,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py new file mode 100644 index 000000000..092bbba90 --- /dev/null +++ b/tests/data/test_streaming.py @@ -0,0 +1,386 @@ +import threading + +import fakeredis +import orjson +import pytest +import torch + +from fast_llm.config import NoAutoValidate +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.config import ( + RedisConfig, + SamplingConfig, + SamplingData, + SamplingParameters, + ShufflingType, + StreamingDatasetConfig, +) +from fast_llm.data.dataset.streaming import StreamingDataset +from fast_llm.data.sample.pipeline_rl import PipelineRLSample +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import GPTBatchConfig + +# --------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------- + + +@pytest.fixture +def fake_redis(): + """Return a FakeRedis instance.""" + return fakeredis.FakeRedis() + + +@pytest.fixture +def monkeypatched_redis(monkeypatch, fake_redis): + """Monkeypatch redis.Redis globally (works even for imports inside functions).""" + import redis + + monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis) + return fake_redis + + +@pytest.fixture +def stream_config(): + return StreamingDatasetConfig( + redis=RedisConfig( + host="localhost", + port=6379, + stream_key="test_stream", + group_name="test_group", + consumer_name_prefix="consumer", + ), + data_key="data", + ) + + +@pytest.fixture +def fake_redis_server(stream_config): + server_address = (stream_config.redis.host, stream_config.redis.port) + server = fakeredis.TcpFakeServer(server_address, server_type="redis") + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + # Create a redis-py client pointing at the fake server + import redis + + client = redis.Redis(host=server_address[0], port=server_address[1]) + + yield stream_config, client + + # Everything after yield = teardown + server.shutdown() + server.server_close() + thread.join() + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +def push_msg(redis_client, config, tokens=None, is_eof=False): + """Push a message into FakeRedis stream.""" + if is_eof: + msg = {"eof": True} + else: + msg = { + "tokens": tokens, + "tokens_dtype": "int64", + } + redis_client.xadd(config.redis.stream_key, {config.data_key: orjson.dumps(msg)}) + + +def make_sampling(sequence_length, extra_tokens, num_samples, distributed): + return SamplingData( + parameters=SamplingParameters( + sequence_length=sequence_length, + extra_tokens=extra_tokens, + num_samples=num_samples, + truncate_documents=False, + ), + config=SamplingConfig(shuffle=ShufflingType.disabled), + 
distributed=distributed, + dataset_name="test", + cache_directory="/tmp", + ) + + +def generate_parallelism_variants(total_gpus: int): + """ + Generate all valid variants of (data_groups, tensor_parallel, pipeline_parallel, sequence_parallel) + for a number of GPUs up to the total_gpus. + If total_gpus is odd and > 1, fallback to nearest lower even number for decomposable parallelism. + """ + if total_gpus > 1 and total_gpus % 2 == 1: + total_gpus = total_gpus - 1 + + if total_gpus == 0: + return [ + { + "data_groups": 1, + "tensor_parallel": 1, + "pipeline_parallel": 1, + "sequence_data_parallel": 1, + "total_gpus": 0, + } + ] + + variants = [ + { + "data_groups": 1, + "tensor_parallel": 1, + "pipeline_parallel": 1, + "sequence_data_parallel": 1, + "total_gpus": 1, + } + ] + + for gpus in range(2, total_gpus + 1, 2): + # try all possible numbers of data groups (1..total_gpus) + for data_groups in range(1, gpus + 1): + if gpus % data_groups != 0: + continue # cannot evenly split + + gpus_per_group = gpus // data_groups + + # now find all decompositions of gpus_per_group into tp*pp*sp + for tp in range(1, gpus_per_group + 1): + if gpus_per_group % tp != 0: + continue + rem_after_tp = gpus_per_group // tp + for pp in range(1, rem_after_tp + 1): + if rem_after_tp % pp != 0: + continue + sp = rem_after_tp // pp + try: + # instead of repeating all safeguards here just try to instantiate distributed config to check if combination is valid + DistributedConfig( + tensor_parallel=tp, + pipeline_parallel=pp, + sequence_data_parallel=sp, + world_size=gpus, + rank=0, + ) + except Exception: + continue + variants.append( + { + "data_groups": data_groups, + "tensor_parallel": tp, + "pipeline_parallel": pp, + "sequence_data_parallel": sp, + "total_gpus": gpus, + } + ) + return variants + + +# --------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------- + + +def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_config): + """StreamingDataset should read a message and convert it into PipelineRLSample.""" + fake_redis = monkeypatched_redis + + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset = StreamingDataset(stream_config, distributed) + + # Insert a message + push_msg(fake_redis, stream_config, [1, 2, 3]) + + it = iter(dataset) + sample = next(it) + + assert isinstance(sample, PipelineRLSample) + assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert sample.tokens.lengths == [3] + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_config): + """StreamingDataset should read a message and convert it into PipelineRLSample.""" + fake_redis = monkeypatched_redis + + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset = StreamingDataset(stream_config, distributed) + + # Insert a message + push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, stream_config, [1, 2, 3]) + + it = iter(dataset) + for i in range(3): + sample = next(it) + + assert isinstance(sample, PipelineRLSample) + assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert sample.tokens.lengths == [3] + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans 
is None + + +def test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config): + """Docs exactly fill one sample.""" + fake_redis = monkeypatched_redis + + # Two rollouts: lengths 4 and 6 -> exactly 10 + push_msg(fake_redis, stream_config, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = list(sampler) + + assert len(out) == 1 + s = out[0] + assert isinstance(s, PipelineRLSample) + assert len(s) == 10 + assert s.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config): + """Docs exactly fill one sample.""" + fake_redis = monkeypatched_redis + + # Two rollouts: lengths 4 and 6 -> exactly 10 + push_msg(fake_redis, stream_config, [1, 2, 3, 4]) + push_msg(fake_redis, stream_config, [5, 6, 7, 8, 9, 10]) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = list(sampler) + + assert len(out) == 1 + s = out[0] + assert isinstance(s, PipelineRLSample) + assert len(s) == 10 + assert s.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stream_config): + """Rollout longer than sample_length must be dropped.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, list(range(20))) # skip: too long + push_msg(fake_redis, stream_config, list(range(8))) # usable + push_msg(fake_redis, stream_config, is_eof=True) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = list(sampler) + + assert len(out) == 1 + s = out[0] + assert len(s) == 10 + assert s.tokens.tokens.tolist() == list(range(8)) + [-100, -100] + + +def test_sampling_overflow_creates_two_and_padding_final(monkeypatched_redis, stream_config): + """A document overflowing the boundary triggers padding + next sample.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, list(range(6))) + push_msg(fake_redis, stream_config, list(range(6))) + push_msg(fake_redis, stream_config, is_eof=True) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 2, distributed)) + + out = list(sampler) + + assert len(out) == 2 + + # sample 1: 0..5 + pad(4) + assert out[0].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] + + # sample 2: 0..5 + pad(4) + assert out[1].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] + + +def distributed_gptdata_streaming_test( + stream_config, + sequence_length, + micro_batch_size, + batch_size, + tensor_parallel, + pipeline_parallel, + sequence_data_parallel, + total_gpus, +): + distributed = Distributed( + DistributedConfig( + tensor_parallel=tensor_parallel, + pipeline_parallel=pipeline_parallel, + sequence_data_parallel=sequence_data_parallel, + ), + use_cpu=total_gpus > 0, + ) + sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed) + + data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} + data_config = GPTDataConfig.from_dict(data_config) + + data = GPTData(data_config, distributed.config) + + data.setup(distributed, 
{"streaming1": sampling_data.parameters}, "/tmp") + + with NoAutoValidate(): + batch_config = GPTBatchConfig( + micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length + ) + batch_config.setup(distributed_config=distributed.config) + batch_config.validate() + + data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) + + batch = next(data_iter) + assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length) + + +variants = generate_parallelism_variants(torch.cuda.device_count()) + + +@pytest.mark.parametrize( + "variant", + variants, + ids=[ + f"dg{v['data_groups']}_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}_gpu{v['total_gpus']}" + for v in variants + ], +) +def test_gptdata_streaming(fake_redis_server, variant): + if variant["total_gpus"] > 1: + pytest.skip(f"Skipping, not implemented for gpu count {variant["total_gpus"]}") + + stream_config, fake_redis = fake_redis_server + + sequence_length = 10 + micro_batch_size = 2 + batch_size = micro_batch_size * variant["data_groups"] // variant["sequence_data_parallel"] + + for _ in range(batch_size): + push_msg(fake_redis, stream_config, list(range(sequence_length))) + + # TODO: call with torchrun.distributed for more than 1 gpu + distributed_gptdata_streaming_test( + stream_config, + sequence_length, + micro_batch_size, + batch_size, + variant["tensor_parallel"], + variant["pipeline_parallel"], + variant["sequence_data_parallel"], + variant["total_gpus"], + ) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 3c3bfb833..c5a5e1c5b 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -379,6 +379,9 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_ # TODO: Test beyond 2 gpu configs? import tests.models.distributed_test_checkpoint + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2") + script = [ "-m", tests.models.distributed_test_checkpoint.__name__, @@ -405,6 +408,7 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: @requires_cuda +# NOTE: Should it depend on test_model_distributed instead? 
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( @@ -425,6 +429,10 @@ def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config = distributed_save_load_config.resolve( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) + if torch.cuda.device_count() < distributed_save_load_config.num_gpus: + pytest.skip( + f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}" + ) report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus) load_and_compare_checkpoints( DistributedCheckpointFormat, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1ed99416e..056fbea70 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -626,6 +626,7 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + # TODO: Fix and bring back to `testing_groups` ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -639,7 +640,7 @@ def _update_and_add_testing_config( ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) - +# TODO: remove obsolete model _update_and_add_testing_config( # Tests hybrid discrete Mamba 2. "llama",