diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 4b7fdd968..88655954f 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -62,7 +62,7 @@ model: multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 seed: 984059 run: experiment_dir: mistral_example diff --git a/fast_llm/__init__.py b/fast_llm/__init__.py index d3ec452c3..493f7415d 100644 --- a/fast_llm/__init__.py +++ b/fast_llm/__init__.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/fast_llm/config.py b/fast_llm/config.py index 4d3858fd7..9644df9c1 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -759,28 +759,8 @@ def from_dict( return cls._from_dict(default, strict) @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # TODO v0.3: Remove flat format - return cls._from_dict(default, strict, True) - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.3: Remove flat format + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: out_arg_dict = {"_from_dict_check": True} - - # TODO v0.3: Remove backward compatibility fix - if "__class__" in default: - del default["__class__"] - try: actual_cls = cls.get_subclass(default.get("type")) except KeyError: @@ -788,29 +768,23 @@ def _from_dict( actual_cls = cls if actual_cls is not None and actual_cls is not cls: - return actual_cls._from_dict(default, strict=strict, flat=flat) + return actual_cls._from_dict(default, strict=strict) # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): if not field.init or field._field_type != dataclasses._FIELD: # noqa continue - if flat: - if isinstance(field.type, type) and issubclass(field.type, Config): - out_arg_dict[name] = field.type._from_dict(default, False, True) - elif name in default: - out_arg_dict[name] = default.pop(name) - else: - # Check for nested configs to instantiate. - try: - value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict) - if value is not MISSING: - out_arg_dict[name] = value - except FieldTypeError as e: - raise FieldTypeError( - f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " - + ", ".join(e.args) - ) + # Check for nested configs to instantiate. 
+ try: + value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict) + if value is not MISSING: + out_arg_dict[name] = value + except FieldTypeError as e: + raise FieldTypeError( + f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " + + ", ".join(e.args) + ) out = cls(**out_arg_dict) # noqa if strict and default: out._unknown_fields = default.copy() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 405d1c672..efee46959 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,23 +1,16 @@ import logging -import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import ( - GPTLegacyConfig, - GPTLegacyDatasetConfig, - GPTSampledDatasetConfig, - GPTSamplingConfig, -) -from fast_llm.engine.distributed.config import PhaseType +from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @config_class() -class GPTDataConfig(DataConfig, GPTLegacyConfig): +class GPTDataConfig(DataConfig): """ Configuration for the dataset(s), split and sampling. Currently hard-coded to a GPT dataset. @@ -48,32 +41,3 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) - - def _validate(self) -> None: - if not self.datasets: - logger.warning( - "Using the legacy dataset definition format." " Specify it through `data.datasets` instead." - ) - self.datasets = { - phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False) - for phase in (PhaseType.training, PhaseType.validation, PhaseType.test) - } - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - if "datasets" in default: - for phase in PhaseType: - if phase.value in default["datasets"]: - rename = phase.value.lower() - logger.warning(f"Renaming dataset {phase.value} to {rename}") - assert rename not in default["datasets"] - default["datasets"][rename] = default["datasets"].pop(phase.value) - - return super()._from_dict(default, strict, flat) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 5e3ced8a4..0c1b0cd09 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -204,11 +204,6 @@ class BlendedDatasetConfig(SampledDatasetConfig): desc="The blending weight of each dataset.", hint=FieldHint.core, ) - legacy: bool = Field( - default=False, - desc="Use the legacy formulas for sub-dataset seeds and sample sizes.", - hint=FieldHint.deprecated, - ) def _validate(self) -> None: self.weights = normalize_probabilities(self.weights) @@ -231,20 +226,10 @@ def build_and_sample( sampling, parameters=dataclasses.replace( sampling.parameters, - num_samples=( - math.ceil( - weight - * ( - sampling.parameters.num_samples - + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5 - ) - ) - if self.legacy - else math.ceil(weight * sampling.parameters.num_samples) + 1 - ), + num_samples=math.ceil(weight * sampling.parameters.num_samples) + 1, ), # TODO: Seed may not be unique for nested blended datasets. 
- config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}), + config=sampling.config.to_copy({"seed": sampling.config.seed + i * 697}), ), ) for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ef2efedc9..656cd7d24 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,10 +1,8 @@ import dataclasses import enum -import json import pathlib import time import typing -import warnings import yaml @@ -22,8 +20,7 @@ SamplingData, SamplingParameters, ) -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset @@ -41,7 +38,6 @@ class ShufflingType(str, enum.Enum): skip_first_epoch = "skip_first_epoch" # Disable shuffling entirely. disabled = "disabled" - legacy = "legacy" @config_class() @@ -222,45 +218,6 @@ def _convert_paths(self, config): return config -# Add user-friendly names for the configs. -@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"}) -class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): - # TODO v0.3: Remove. - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to a dataset directory.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - warnings.warn("`concatenated_memmap` dataset is deprecated. Use `file` instead.", DeprecationWarning) - super()._validate() - - def build(self) -> "GPTConcatenatedDataset": - - assert self.path.is_dir() - index_path = self.path / "index.txt" - - if index_path.is_file(): - prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()] - else: - warnings.warn( - f"The dataset path {self.path} points to a directory." - " The dataset will be indexed automatically, which may be unsafe." - " We recommend using an index file instead." - ) - prefixes = [ - path.with_suffix("") - for path in self.path.iterdir() - if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file() - ] - dataset_config = GPTConcatenatedDatasetConfig.from_dict( - {"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]} - ) - return dataset_config.build() - - @config_class() class FimConfig(Config): """ @@ -268,7 +225,7 @@ class FimConfig(Config): """ rate: float = Field( - # TODO: Use meaningful default now that fim is a wrapper? (bad for legacy config) + # TODO: Use meaningful default now that fim is a wrapper? default=0.0, desc="FIM rate for each sample.", hint=FieldHint.core, @@ -352,131 +309,6 @@ def build_and_sample( return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) -class LegacyDatasetSource(str, enum.Enum): - """ - An enum for the different ways to load datasets. 
- """ - - list = "list" - file = "file" - random = "random" - - -def _validate_split(value: list[int]) -> list[int]: - Assert.leq(len(value), 3) - return value + [0] * (len(value) - 3) - - -def _validate_path(value: str | list[str]) -> list[str]: - return [value] if isinstance(value, str) else value - - -@config_class() -class GPTLegacyConfig(Config): - split: list[float] = Field( - default_factory=lambda: [969, 30, 1], - desc="Split ratio for train, valid and test datasets.", - hint=FieldHint.deprecated, - valid=_validate_split, - ) - format: LegacyDatasetSource = Field( - default=LegacyDatasetSource.list, - desc="Format for the dataset definition.", - hint=FieldHint.deprecated, - ) - path: list[str] = Field( - default_factory=list, - desc="Path or list of paths and weights.", - hint=FieldHint.deprecated, - valid=_validate_path, - ) - fim: FimConfig = Field( - desc="Configuration for Fill In the Middle (FIM).", - hint=FieldHint.feature, - ) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"}) -class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): - _abstract: typing.ClassVar[bool] = False - - def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: - - if self.format == LegacyDatasetSource.random: - Assert.eq(len(self.path), 0) - dataset_config = GPTRandomDatasetConfig() - else: - if self.format == LegacyDatasetSource.file: - Assert.eq(len(self.path), 1) - data_path = pathlib.Path(self.path[0]) - dataset_defs = json.load(data_path.open("r")) - data_base_path = data_path.parent - dataset_prefixes = [ - (data_base_path / dataset_def["prefix"]).resolve() for dataset_def in dataset_defs["datasets"] - ] - dataset_weights = normalize_probabilities( - [dataset_def["weight"] for dataset_def in dataset_defs["datasets"]] - ) - elif self.format == LegacyDatasetSource.list: - Assert.geq(len(self.path), 1) - if len(self.path) == 1: - dataset_prefixes, dataset_weights = [self.path[0].strip()], [1.0] - else: - Assert.custom(lambda x: x % 2 == 0, len(self.path)) - dataset_prefixes = [pathlib.Path(x.strip()).resolve() for x in self.path[1::2]] - assert len(dataset_prefixes) == len(set(dataset_prefixes)) - dataset_weights = normalize_probabilities([float(x) for x in self.path[::2]]) - else: - raise NotImplementedError(self.format) - - phase_splits = padded_cumsum(normalize_probabilities(self.split)) - - phase_index = { - PhaseType.training.value.lower(): 0, - PhaseType.validation.value.lower(): 1, - PhaseType.test.value.lower(): 2, - }[sampling.dataset_name] - - dataset_configs = [ - { - "type": "slice", - # TODO: this duplicates memmap datasets for each phase. 
- "dataset": {"type": "memmap", "path": prefix}, - "begin": float(phase_splits[phase_index]), - "end": float(phase_splits[phase_index + 1]), - } - for prefix in dataset_prefixes - ] - dataset_config = ( - { - "type": "blended", - "name": "blended", - "datasets": dataset_configs, - "weights": dataset_weights, - "legacy": True, - } - if len(dataset_configs) > 1 - else dataset_configs[0] - ) - if self.fim.rate > 0: - dataset_config = { - "type": "fim", - "dataset": dataset_config, - **self.fim.to_dict(), - } - # Legacy sampling config - dataset_config = { - "type": "sampled", - "dataset": dataset_config, - "sampling": { - "seed": sampling.distributed.config.seed, - "shuffle": "legacy", - }, - } - - return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling) - - @config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): """ diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a70..896229772 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -3,7 +3,7 @@ import numpy as np -from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset if typing.TYPE_CHECKING: @@ -26,13 +26,9 @@ def get_document_size(self, index: int) -> int: """ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - return ( - LegacyGPTSampledIndexedDataset(self, sampling) - if sampling.config.shuffle == ShufflingType.legacy - else GPTSampledIndexedDataset(self, sampling) - ) + return GPTSampledIndexedDataset(self, sampling) class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 6a06002cb..95006f18e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -17,7 +17,7 @@ from fast_llm.utils import Assert try: - from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa + from fast_llm.csrc.data import build_padded_token_cumsum # noqa _extension_available = True except ImportError: @@ -531,160 +531,3 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch - - -class LegacyGPTSampledIndexedDataset(SampledDataset): - """ - A GPT dataset augmented with a sampling, i.e., - a pre-computed, shuffled list of samples to be indexed sequentially (as-is) during training. - The sampling exactly matches Megatron-LM with matching parameters. - Supports optional post-processing with FIM. - """ - - def __init__( - self, - indexed_dataset: GPTIndexedDataset, - sampling: GPTSamplingData, - ): - assert isinstance(sampling, GPTSamplingData) - self._indexed_dataset = indexed_dataset - if not sampling.parameters.truncate_documents: - raise NotImplementedError( - "Legacy sampling only supports document truncation. Please use the latest dataset format." 
- ) - self._config = sampling.config - self._parameters = sampling.parameters - if self._parameters.use_preference_loss_spans: - raise NotImplementedError("Legacy sampling does not support preference loss masking.") - - if sampling.cache_directory is None: - log_main_rank( - " > No dataset cache directory provided, building the index map on all ranks." - "This may be very inefficient...", - log_fn=logger.warning, - ) - base_path = None - else: - base_path = ( - sampling.cache_directory - / f"{self.name}_ns_{self._parameters.num_samples}_sl_{self._parameters.sequence_length}" - f"_s_{self._config.seed}" - ) - - self._doc_idx = MemmapArray( - None if base_path is None else base_path.with_name(base_path.name + "_doc_idx.npy") - ) - self._sample_idx = MemmapArray( - None if base_path is None else base_path.with_name(base_path.name + "_sample_idx.npy") - ) - self._shuffle_idx = MemmapArray( - None if base_path is None else base_path.with_name(base_path.name + "_shuffle_idx.npy") - ) - - # Build the indexed mapping if it doesn't exist. - if base_path is None or ( - sampling.distributed.config.rank == sampling.get_next_rank() - and not (self._doc_idx.exists() and self._sample_idx.exists() and self._shuffle_idx.exists()) - ): - self._sample() - - def _sample(self) -> None: - """ - Create a `GPTSampledDataset` with the requested parameters. - """ - logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() - num_documents = len(document_sizes) - num_tokens = document_sizes.sum() - np_rng = np.random.RandomState(seed=self._config.seed) - - num_epochs = math.ceil((self._parameters.sequence_length * self._parameters.num_samples + 1) / num_tokens) - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._parameters.sequence_length - last_epoch_samples = self._parameters.num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // self._parameters.sequence_length - separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch - - doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) - if separate_last_epoch: - np_rng.shuffle(doc_idx[:-num_documents]) - np_rng.shuffle(doc_idx[-num_documents:]) - else: - np_rng.shuffle(doc_idx) - - assert _extension_available, ( - "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." - ) - - sample_idx = build_sample_idx( - document_sizes, - doc_idx, - self._parameters.sequence_length, - num_epochs, - num_tokens, - True, - ) - - total_size = sample_idx.shape[0] - 1 - shuffle_idx = np.arange( - 0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32 - ) - if separate_last_epoch: - np_rng.shuffle(shuffle_idx[:main_epochs_samples]) - np_rng.shuffle(shuffle_idx[main_epochs_samples:]) - else: - np_rng.shuffle(shuffle_idx) - - Assert.geq(len(shuffle_idx), self._parameters.num_samples) - self._doc_idx.save(doc_idx) - self._sample_idx.save(sample_idx) - self._shuffle_idx.save(shuffle_idx[: self._parameters.num_samples]) - - def __len__(self) -> int: - return self._parameters.num_samples - - def __getitem__(self, idx: int) -> typing.Any: - """ - Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) - with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). - """ - # Get the shuffled index. 
- shuffled_idx = self._shuffle_idx[idx] - # Start and end documents and offsets. - doc_f, offset_f = self._sample_idx[shuffled_idx] - doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get( - self._doc_idx[doc].item(), - offset=(doc == doc_f) * offset_f, - length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - ) - for doc in range(doc_f, doc_l + 1) - ] - token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) - Assert.eq(len(token_ids), self._parameters.sequence_length + 1) - - if self._parameters.use_loss_masking_spans: - spans = [] - offset = 0 - for sample in sample_list: - for span in sample.loss_masking_spans: - spans.append(span + offset) - offset += len(sample.token_ids) - spans = np.stack(spans, dtype=np.int32) if spans else np.array([]) - else: - spans = None - sequence_lengths = ( - np.array( - [sample.token_ids.size - (idx == len(sample_list) - 1) for idx, sample in enumerate(sample_list)], - dtype=np.int32, - ) - if not self._parameters.cross_document_attention - else None - ) - return GPTSample(token_ids=token_ids, loss_masking_spans=spans, sequence_lengths=sequence_lengths) - - @property - def name(self) -> str: - return self._indexed_dataset.name diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index c878cec0a..3f1970538 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -4,7 +4,6 @@ import logging import pathlib import typing -import warnings import yaml @@ -58,9 +57,7 @@ def __fast_llm_serialize__(cls) -> str: class DistributedCheckpointFormat(CheckpointFormat): - # TODO v0.3: Add `enforce_version_match` name: typing.ClassVar[str] = "distributed" - enforce_architecture_match: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type["DistributedCheckpointHandler"]: @@ -125,17 +122,6 @@ class CheckpointStateConfigBase(CheckpointConfigBase): model_weights: bool = Field(default=True, hint=FieldHint.feature) optimizer_state: bool = Field(default=None, hint=FieldHint.feature) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "load_weights", "model_weights") - cls._handle_renamed_field(default, "load_optimizer", "optimizer_state") - return super()._from_dict(default, strict, flat) - @config_class() class CheckpointSaveConfigBase(CheckpointConfigBase): @@ -204,23 +190,6 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): hint=FieldHint.core, ) - def _validate(self) -> None: - if self.load_config == "architecture": - raise NotImplementedError("load_config==`architecture` is no longer supported.") - super()._validate() - if ( - self.format in (DistributedCheckpointFormat, FastLLMCheckpointFormat) - and "load_config" not in self._explicit_fields - ): - warnings.warn( - "The default behaviour for model configuration loading has changed (May 2025)." - "All model parameters are now loaded, not just the architecture parameters." - "Please make sure this doesn't lead to unexpected breaking changes." 
- "Suppress this warning by setting `load_config = model` explicitly.", - ) - if self.format.enforce_architecture_match: - assert self.load_config.load_base_model - @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 7faf599f7..c2f4d8cdd 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -71,18 +71,9 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: framework="pt", device=str(self._model.distributed.device), ) as f: - if "state_shard" in f.keys(): - # Old format `state_shard` with shape `(num_shards, shard_size) - # TODO v0.3: Use checkpoint version? Drop support? - log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) - for shard_name in shard_names: - self._model.get_shard(shard_name).copy_( - f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] - ) - else: - # TODO: Does this copy twice? - for shard_name in shard_names: - self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard")) + # TODO: Does this copy twice? + for shard_name in shard_names: + self._model.get_shard(shard_name).copy_(f.get_tensor(f"{shard_name}_shard")) else: log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info) @@ -105,18 +96,7 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Lazy loading? with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard - if "state_shard" in f.keys(): - # Old format `state_shard` with shape `(num_shards, shard_size) - # TODO v0.3: Use checkpoint version? Drop support? - log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) - loaded_shards = { - shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] - for shard_name in shard_names - } - else: - loaded_shards = { - shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names - } + loaded_shards = {shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names} self._copy_shard_overlaps(loaded_model, loaded_shards, context) diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index 7fefda4b0..2f12a45d2 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -26,16 +26,11 @@ class InitializationConfig(Config, Initialization): is_default: typing.ClassVar[bool] = False @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. 
- return DefaultInitializationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return DefaultInitializationConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class() diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1fc0c626d..1737f4308 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,7 +20,7 @@ @config_class() class RunConfig(Config): tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) - # TODO v0.3: Adjust (now only affects logging to file). + # TODO: Adjust (now only affects logging to file). structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging ) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 9ec63517c..602c44a4e 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -104,7 +104,7 @@ class DistributedConfig(Config): """ Configuration for the distributed setup. Also include variables for global settings such as data types, random seeds, initialization parameters. - TODO v0.3: Move these unrelated variables elsewhere. + TODO: Move these unrelated variables elsewhere. TODO: Avoid hard-coding distributed dims (use derived class?) TODO: Separate distributed space from config? """ @@ -181,19 +181,19 @@ class DistributedConfig(Config): valid=check_field(Assert.gt, 0), ) seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional) - # TODO v0.3: Rename to compute_dtype (not just for training), move elsewhere - training_dtype: DataType = Field( + # TODO: Rename to compute_dtype (not just for training), move elsewhere + compute_dtype: DataType = Field( default=DataType.float32, desc="The data type used for the forward and backward passes.", hint=FieldHint.core, ) - # TODO v0.3: move elsewhere + # TODO : move elsewhere optimization_dtype: DataType = Field( default=DataType.float32, desc="The data type used for the optimizer.", hint=FieldHint.expert, ) - # TODO v0.3: move random state elsewhere + # TODO: move random state elsewhere # Extra seed parameters (can usually be left alone) dp_seed_shift: int = Field( default=_BIG_PRIMES[0], desc="Seed shift for extra randomness.", hint=FieldHint.optional @@ -378,13 +378,3 @@ def _log_on_rank[ def log_first_rank[T](self, *message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info): return self._log_on_rank(*message, rank=0, log_fn=log_fn) - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "distributed_timeout", "timeout") - return super()._from_dict(default, strict, flat) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index dc41539c0..2e2f9d401 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -238,7 +238,7 @@ def check_config(self, config: DistributedConfig) -> None: def set_step(self, step: int, phase: PhaseType) -> None: """ Reseed pytorch for a given training step. - TODO v0.3: Move unrelated content elsewhere. + TODO: Move unrelated content elsewhere. 
""" seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase] self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED) diff --git a/fast_llm/engine/evaluation/__init__.py b/fast_llm/engine/evaluation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 04e4227f1..4eb5d71df 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -27,16 +27,11 @@ class EvaluatorConfig(EvaluatorConfigBase): _abstract: typing.ClassVar[bool] = True @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - if not "type" in default: - default["type"] = "loss" - return super()._from_dict(default, strict, flat) + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is EvaluatorConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return LossEvaluatorConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(dynamic_type={EvaluatorConfig: "loss"}) diff --git a/fast_llm/engine/evaluation/lm_eval/__init__.py b/fast_llm/engine/evaluation/lm_eval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index b414323e4..d19e2478d 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -25,7 +25,7 @@ def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): self.use_cache = kwargs.pop("use_cache", True) super().__init__(**kwargs) if self.torch_dtype is not None: - assert self.torch_dtype == self.fast_llm_config.distributed.training_dtype.torch + assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None: # Hack the method to save at the right place. @@ -90,7 +90,7 @@ def _get_config_dict( updates = {} torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: - updates[("distributed", "training_dtype")] = torch_dtype + updates[("distributed", "compute_dtype")] = torch_dtype fast_llm_config = cls.model_config_class.from_metadata( pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates ) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 719088057..aa18f5052 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -351,48 +351,28 @@ class CheckpointMetadata(Config): def _validate(self) -> None: if isinstance(self.fast_llm_version, str): self.fast_llm_version = packaging.version.Version(self.fast_llm_version) - + code_version = packaging.version.Version(__version__) self.format = self.model.get_checkpoint_format(self.format) super()._validate() - if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2): - raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}") + if self.fast_llm_version > code_version: + raise ValueError(f"Unknown checkpoint version: {self.fast_llm_version}") + if self.fast_llm_version < packaging.version.Version("0.3.0"): + raise ValueError( + f"Checkpoint version {self.fast_llm_version} is no longer supported." 
+ " If you really need this checkpoint," + " please convert it to an external model first using a compatible Fast-LLM version." + ) Assert.eq(self.config.__class__, self.model) @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.3: Remove backward compatibility. - cls._handle_renamed_field(default, "checkpoint_type", "format") - cls._handle_renamed_field(default, "checkpoint_version", "fast_llm_version") - cls._handle_renamed_field(default, "fast_llm_config", "config") - cls._handle_renamed_field(default, "state_shard_names", "shards") - if "model" not in default: - default["model"] = "gpt" - if "format" not in default: - default["format"] = DistributedCheckpointFormat - if "fast_llm_version" not in default: - default["fast_llm_version"] = "0" - + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: model_config_class = default["model"] if isinstance(model_config_class, str): model_config_class = FastLLMModelConfig.get_subclass(default["model"]) default["model"] = model_config_class - # TODO v0.3: Remove backward compatibility. - if "config" not in default: - default["config"] = { - "base_model": model_config_class.get_base_model_config_class().from_flat_dict( - default.pop("model_config", {}) - ), - "multi_stage": default.pop("multi_stage_config", {}), - "distributed": default.pop("distributed_config", {}), - } # Instantiate the config with the appropriate class config = default.get("config", {}) if isinstance(config, dict): default["config"] = model_config_class.from_dict(config) - return super()._from_dict(default, strict, flat) + return super()._from_dict(default, strict) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 09ee788e6..6a6223cb7 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -54,14 +54,10 @@ def from_pretrained( metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: - # TODO v0.3: Make metadata.shards mandatory? - if metadata.shards: - if optimizer_state_names is None: - optimizer_state_names = metadata.shards[1:] - else: - Assert.eq(optimizer_state_names, metadata.shards[1:]) - elif optimizer_state_names is None: - raise ValueError("`optimizer_state_names` is required") + if optimizer_state_names is None: + optimizer_state_names = metadata.shards[1:] + else: + Assert.eq(optimizer_state_names, metadata.shards[1:]) else: assert optimizer_state_names is None optimizer_state_names = () diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index cb0a02a67..868cc2db4 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -84,7 +84,7 @@ def __init__( dtype=( self._distributed_config.optimization_dtype if full_precision_shards - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch, ) # TODO: Distinguish grad and optimizer shard? 
@@ -94,13 +94,13 @@ def __init__( dtype=( self._distributed_config.optimization_dtype if full_precision_shards - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch, ) self._weight_buffer_meta = TensorMeta.from_dims( (TensorDim("weight_buffer", weight_shard_dim.size * self._fsdp_dim.size),), tensor_name=f"{self._name}_weight_buffer", - dtype=self._distributed_config.training_dtype.torch, + dtype=self._distributed_config.compute_dtype.torch, ) self._grad_buffer_meta = TensorMeta.from_dims( (TensorDim("grad_buffer", weight_shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), @@ -108,7 +108,7 @@ def __init__( dtype=( self._distributed_config.optimization_dtype if full_precision_gradient_buffer - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch, ) diff --git a/fast_llm/engine/optimizer/optimizer.py b/fast_llm/engine/optimizer/optimizer.py index e72901e6e..0dd094390 100644 --- a/fast_llm/engine/optimizer/optimizer.py +++ b/fast_llm/engine/optimizer/optimizer.py @@ -19,7 +19,7 @@ def get_grad_scaler(config: GradientScalerConfig, distributed: Distributed) -> " initial_scale=config.constant, distributed=distributed, ) - elif distributed.config.training_dtype == DataType.float16: + elif distributed.config.compute_dtype == DataType.float16: return DynamicGradScaler( initial_scale=config.initial, min_scale=config.minimum, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 8c9e035d9..531bc206b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -170,17 +170,6 @@ def get_evaluator( return TrainingEvaluator(name, self, batch_config, data_load_num_proc, train_iters) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - cls._handle_renamed_field(default, "iterations", ("evaluator", "iterations")) - return super()._from_dict(default, strict, flat) - @config_class() class TrainingCheckpointBaseConfig(IntervalConfig): @@ -234,10 +223,7 @@ class TrainingCheckpointConfig(TrainingCheckpointBaseConfig): keep: int | None = FieldUpdate(default=5) def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: - # TODO v0.3: Remove backward compatibility. - old_path = experiment_directory / "checkpoints" - new_path = experiment_directory / "checkpoint" - return old_path if old_path.is_dir() and not new_path.is_dir() else new_path + return experiment_directory / "checkpoint" def get_save_config(self, path: pathlib.Path, timeout: float | None) -> CheckpointSaveConfig: return CheckpointSaveConfig( @@ -329,18 +315,6 @@ class TrainingConfig(Config): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. 
- cls._handle_renamed_field(default, "validation", ("evaluators", "validation")) - cls._handle_renamed_field(default, "evaluations", ("evaluators")) - return super()._from_dict(default, strict, flat) - def _validate(self) -> None: super()._validate() self.shutdown.assert_sub_interval(self.checkpoint) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 7db9b1fc3..a752bec28 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -541,7 +541,7 @@ def _prepare_training_state(self) -> None: def _save_checkpoint( self, config: TrainingCheckpointBaseConfig, metrics: dict[str, dict[str, float | int]] | None ) -> None: - # TODO v0.3: Move barrier, ok file to FastLLMModel + # TODO: Move barrier, ok file to FastLLMModel checkpoint_base_directory = config.get_save_directory(self._run.experiment_directory) checkpoint_directory = checkpoint_base_directory / str(self._completed_steps) @@ -600,7 +600,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] else: self._completed_steps = metadata["completed_steps"] - # TODO v0.3: Move barrier, ok file to FastLLMModel + # TODO: Move barrier, ok file to FastLLMModel safe_barrier( self._distributed.world_group, f"load {config.save_name} {iteration} exit", diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 214bb7729..2910c7c76 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -120,7 +120,7 @@ def layer_class(self) -> "type[Attention]": return Attention def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py index 2326b1bf7..204c08ad2 100644 --- a/fast_llm/layers/attention/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -39,8 +39,8 @@ def _create_tensors(self, sequence_length: int, device: torch.device) -> None: self._mask.triu_(-self._config.window_size + 1) self._mask_value = torch.full( [], - torch.finfo(self._distributed_config.training_dtype.torch).min, - dtype=self._distributed_config.training_dtype.torch, + torch.finfo(self._distributed_config.compute_dtype.torch).min, + dtype=self._distributed_config.compute_dtype.torch, device=device, ) @@ -80,7 +80,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( (scalar_dim,), tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._distributed_config.training_dtype.torch, + dtype=self._distributed_config.compute_dtype.torch, ) diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 43bae8c54..5bd7a9b87 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -18,16 +18,11 @@ class RotaryConfig(BaseModelConfig): # TODO: Move rotary to its own submodule. 
@classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is RotaryConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return NoRotaryConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return NoRotaryConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) def get_layer(self, head_size_dim: TensorDim) -> "Rotary": return self._get_configurable_class()(self, head_size_dim) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 7df2705fa..df5bd8181 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -69,18 +69,13 @@ class BlockConfig(BaseBlockConfig): """ @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: from fast_llm.layers.decoder.config import DecoderBlockConfig # Default subclass. - return DecoderBlockConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return DecoderBlockConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @property def layer_class(self) -> "type[Block]": @@ -107,16 +102,11 @@ def get_block( @config_class(registry=True) class BlockSequenceConfig(BaseModelConfig): @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return FixedBlockSequenceConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return FixedBlockSequenceConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @abc.abstractmethod def __len__(self) -> int: diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 33cbd9768..c1ced10df 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -52,16 +52,11 @@ def get_layer( return out @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. 
- return LayerNormalizationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return LayerNormalizationConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(dynamic_type={NormalizationConfig: "none"}) @@ -107,20 +102,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): def module_class(self): pass - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - cls._handle_renamed_field(default, "normalization_implementation", "implementation") - cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") - return super()._from_dict(default, strict, flat) - @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index d0af61cee..6c7656839 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -15,16 +15,11 @@ class PeftConfig(Config): _abstract = True @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is PeftConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return NoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return NoPeftConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) def apply_linear( self, diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 2d8cc71fd..5f8131b5c 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -45,18 +45,13 @@ class MLPBaseConfig(BlockWithBiasConfig): _abstract = True @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: from fast_llm.layers.decoder.mlp.config import MLPConfig # Default subclass. - return MLPConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return MLPConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(registry=True) @@ -66,18 +61,13 @@ class MixerConfig(BlockWithBiasConfig): """ @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: from fast_llm.layers.attention.config import AttentionConfig # Default subclass. 
- return AttentionConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + return AttentionConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) @config_class(dynamic_type={BlockConfig: "decoder"}) diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index fe4879e73..9dd17d698 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -87,7 +87,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c dims = (input_.dims[0], kwargs[AttentionKwargs.sequence_q_dim], self._intermediate_2_dim) # Also adjust the dtype in case of full-precision residual layer_2_input = TensorMeta.from_dims( - dims, tensor_name="intermediate_1", dtype=self._distributed_config.training_dtype.torch + dims, tensor_name="intermediate_1", dtype=self._distributed_config.compute_dtype.torch ) # TODO: Add marginal compute? (ex. activation, gate + up) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 849e09aa9..f59b4cffd 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -350,20 +350,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.testing, ) - @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # The backward compatibility fix in `NormalizationArchitectureConfig` - # won't work for older checkpoints saved with a flat config. - # TODO v0.3: Remove flat format - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - return super().from_flat_dict(default, strict) - def __len__(self) -> int: return len(self.decoder) + 2 * self.output_layer.prediction_heads diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index e0661cfa2..1d1e13a5b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -51,7 +51,7 @@ def __init__( self._residual_dtype = ( self._distributed_config.optimization_dtype if self._config.full_precision_residual - else self._distributed_config.training_dtype + else self._distributed_config.compute_dtype ).torch self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and self._config.vocab_parallel diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 9cd77ff37..8fbb99cad 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -88,26 +88,6 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.3: Remove backward compatibility fix - if "transposed_mlp_weight" in default: - assert default.pop("transposed_mlp_weight") - if "match_megatron" in default: - assert "use_megatron_initialization" not in default - default["use_megatron_initialization"] = default.pop("match_megatron") - if "layer_norm_impl" in default: - assert "normalization_implementation" not in default - default["normalization_implementation"] = 
default.pop("layer_norm_impl")
-        if "fused_mlp" in default:
-            del default["fused_mlp"]
-        return super()._from_dict(default, strict, flat)
-
 
 
 @config_class(dynamic_type={FastLLMModelConfig: "gpt"})
 class GPTModelConfig(FastLLMModelConfig):
@@ -197,29 +177,6 @@ def _validate(self) -> None:
             )
         Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads)
 
-    @classmethod
-    def _from_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-        flat: bool = False,
-    ) -> typing.Self:
-        # TODO v0.x: Remove backward compatibility.
-        cls._handle_renamed_field(
-            default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans")
-        )
-        if "truncate_documents" in default.get("data", {}):
-            # Backward compatibility for the legacy truncate_documents field.
-            # TODO v0.x: Remove backward compatibility.
-            logger.warning(
-                "`data.truncate_documents` field is deprecated. " "Please use `batch.truncate_documents` instead."
-            )
-            assert "truncate_documents" not in default.get("batch", {})
-            if "batch" not in default:
-                default["batch"] = {}
-            default["batch"]["truncate_documents"] = default["data"].pop("truncate_documents")
-        return super()._from_dict(default, strict, flat)
-
     @classmethod
     def get_trainer_class(cls) -> type["GPTTrainer"]:
         from fast_llm.models.gpt.trainer import GPTTrainer
diff --git a/tests/data/common.py b/tests/data/common.py
index 6614accce..d8cc6fff2 100644
--- a/tests/data/common.py
+++ b/tests/data/common.py
@@ -77,9 +77,8 @@ def get_test_data_and_compare_samples(
     sequence_length: int = 512,
     vocab_size=TEST_VOCAB_SIZE,
     expected_samples: dict[str, list[list[int]]] | list[list[int]],
-    legacy: bool = False,
 ) -> GPTData:
-    distributed_config = DistributedConfig(seed=seed if legacy else 87522)
+    distributed_config = DistributedConfig(seed=87522)
     distributed = Distributed(distributed_config, use_cpu=True)
     if isinstance(samples_per_dataset, int):
         samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset}
@@ -97,11 +96,7 @@ def get_test_data_and_compare_samples(
         expected_samples = {PhaseType.training.value.lower(): expected_samples}
 
     assert "sampling" not in config
-    config["sampling"] = GPTSamplingConfig(
-        seed=87522 if legacy else seed,
-        gpu=gpu,
-        shuffle=shuffle,
-    )
+    config["sampling"] = GPTSamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle)
     data = GPTData(GPTDataConfig.from_dict(config), distributed_config)
     data.setup(distributed, sampling_parameters, cache_directory)
     with NoAutoValidate():
diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py
index 312807aad..e64b47020 100644
--- a/tests/data/test_blending.py
+++ b/tests/data/test_blending.py
@@ -46,17 +46,6 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray,
     [3036, 253, 207, 2968, 4536, 1178],
 ]
 
-GPT_BLENDED_LEGACY_SAMPLES = [
-    [1725, 74, 207, 1635, 4440, 2774],
-    [359, 489, 4266, 2052, 5351, 80],
-    [328, 80, 263, 890, 1797, 88],
-    [374, 7534, 87, 1073, 79, 480],
-    [8008, 498, 71, 727, 80, 315],
-    [2210, 8179, 73, 2582, 897, 1178],
-    [1852, 71, 776, 7878, 7390, 80],
-    [409, 5091, 328, 1378, 5483, 88],
-]
-
 GPT_BLENDED_MIXED_SAMPLES = [
     [4709, 819, 79, 207, 277, 1790],
     [916, 6683, 7685, 1277, 5106, 378],
@@ -144,7 +133,7 @@ def test_gpt_blended_data():
     get_test_data_and_compare_samples(
         {
             "datasets": {
-                "Training": {
+                "training": {
                     "type": "blended",
                     "datasets": [
                         {"type": "memmap", "path": DATASET_PREFIX},
@@ -160,22 +149,6 @@
     )
 
 
-def test_gpt_blended_data_legacy():
-    get_test_dataset()
-    _get_test_dataset_mix_1()
-    get_test_data_and_compare_samples(
-        {
-            "format": "list",
-            "path": ["0.75", str(DATASET_PREFIX), "0.25", str(_DATASET_PREFIX_MIX_1)],
-            "split": [1, 0, 0],
-        },
-        8,
-        sequence_length=5,
-        expected_samples=GPT_BLENDED_LEGACY_SAMPLES,
-        legacy=True,
-    )
-
-
 def test_gpt_blended_mixed():
     # Make sure dataset blending works and check for unintended changes in behavior.
     get_test_dataset()
@@ -198,7 +171,7 @@ def test_gpt_blended_mixed_data():
     get_test_data_and_compare_samples(
         {
             "datasets": {
-                "Training": {
+                "training": {
                     "type": "blended",
                     "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}],
                     "weights": [0.6, 0.4],
diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py
index 6cc5d639a..2c025cbaf 100644
--- a/tests/data/test_concatenate.py
+++ b/tests/data/test_concatenate.py
@@ -44,7 +44,7 @@ def test_gpt_concatenate_data():
     get_test_data_and_compare_samples(
         {
             "datasets": {
-                "Training": {
+                "training": {
                     "type": "concatenated",
                     "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)],
                 }
diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py
deleted file mode 100644
index 35d93d9d5..000000000
--- a/tests/data/test_concatenated_memmap.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import pytest
-
-from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig
-from tests.data.common import (
-    compare_indexed_dataset,
-    get_dataset_config,
-    get_sampling_data,
-    get_test_data_and_compare_samples,
-    validate_indexed_dataset_sampling,
-)
-from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES
-from tests.utils.dataset import get_test_concatenated_memmap_dataset
-from tests.utils.global_variables import DATASET_CACHE
-
-_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap"
-
-
-def _get_test_dataset_concatenated_memmap():
-    return get_test_concatenated_memmap_dataset(_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, 4)
-
-
-CONCATENATED_MEMMAP_DATASET_LENGTH = 24806
-CONCATENATED_MEMMAP_DATASET_TOKENS = 2033639
-CONCATENATED_MEMMAP_DATASET_SAMPLES = {
-    **MEMMAP_DATASET_SAMPLES,
-    6930: [65, 2327],
-    11962: [7078, 2713, 1431],
-    15958: [207],
-    19362: [69],
-    24098: [555, 668, 70],
-}
-CONCATENATED_MEMMAP_SAMPLES = [
-    [7554, 80, 5970, 87, 477, 4119],
-    [4119, 6506, 74, 447, 87, 277],
-    [277, 320, 2597, 4117, 301, 727],
-    [727, 330, 3067, 2740, 81, 417],
-    [417, 1486, 542, 248, 540, 1364],
-    [1364, 7072, 2516, 2455, 79, 207],
-    [207, 727, 2204, 2379, 540, 1322],
-    [1322, 365, 2009, 72, 489, 1886],
-]
-
-
-def test_gpt_concatenated_memmap():
-    # Make sure dataset splitting works and check for unintended changes in behavior.
-    _get_test_dataset_concatenated_memmap()
-    # samples[9:18]
-    with pytest.warns(DeprecationWarning):
-        dataset = get_dataset_config(
-            {"type": "concatenated_memmap", "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP},
-            GPTConcatenatedMemmapConfig,
-        ).build()
-    compare_indexed_dataset(
-        dataset,
-        CONCATENATED_MEMMAP_DATASET_LENGTH,
-        CONCATENATED_MEMMAP_DATASET_TOKENS,
-        CONCATENATED_MEMMAP_DATASET_SAMPLES,
-    )
-    sampled = dataset.sample(get_sampling_data(8, sequence_length=5))
-    validate_indexed_dataset_sampling(sampled, CONCATENATED_MEMMAP_SAMPLES)
-
-
-def test_gpt_concatenated_memmap_data():
-    _get_test_dataset_concatenated_memmap()
-    with pytest.warns(DeprecationWarning):
-        get_test_data_and_compare_samples(
-            {
-                "datasets": {
-                    "Training": {
-                        "type": "concatenated_memmap",
-                        "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP,
-                    }
-                }
-            },
-            8,
-            sequence_length=5,
-            expected_samples=CONCATENATED_MEMMAP_SAMPLES,
-        )
diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py
index 551134fd2..c9212d6e3 100644
--- a/tests/data/test_fim.py
+++ b/tests/data/test_fim.py
@@ -21,17 +21,6 @@
     [86, 49152, 89, 542, 395, 89],
 ]
 
-GPT_FIM_SAMPLES_LEGACY = [
-    [1725, 74, 207, 1635, 4440, 2774],
-    [359, 489, 4266, 2052, 5351, 80],
-    [86, 49152, 89, 22255, 1073, 79],
-    [8008, 498, 71, 727, 80, 315],
-    [2210, 8179, 73, 2582, 897, 1178],
-    [86, 89, 88, 49152, 87, 49152],
-    [86, 49152, 83, 744, 89, 64],
-    [86, 89, 1461, 49152, 87, 49152],
-]
-
 
 def test_gpt_fim():
     # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior.
@@ -63,7 +52,7 @@ def test_gpt_fim_data():
     get_test_data_and_compare_samples(
         {
             "datasets": {
-                "Training": {
+                "training": {
                     "type": "fim",
                     "dataset": {"type": "memmap", "path": DATASET_PREFIX},
                     "rate": 0.5,
@@ -80,21 +69,3 @@
         expected_samples=GPT_FIM_SAMPLES,
         vocab_size=49157,
     )
-
-
-def test_gpt_fim_data_legacy():
-    get_test_dataset()
-    get_test_data_and_compare_samples(
-        {
-            "format": "list",
-            "path": [str(DATASET_PREFIX)],
-            "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"},
-            "tokenizer": {"path": TOKENIZER_PATH},
-            "split": [1, 0, 0],
-        },
-        8,
-        sequence_length=5,
-        expected_samples=GPT_FIM_SAMPLES_LEGACY,
-        legacy=True,
-        vocab_size=49157,
-    )
diff --git a/tests/data/test_random.py b/tests/data/test_random.py
index 72a6080a7..8e5c61904 100644
--- a/tests/data/test_random.py
+++ b/tests/data/test_random.py
@@ -26,7 +26,7 @@ def test_gpt_random_data():
     get_test_data_and_compare_samples(
         {
             "datasets": {
-                "Training": {
+                "training": {
                     "type": "random",
                 }
             }
@@ -35,13 +35,3 @@
         sequence_length=7,
         expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES,
     )
-
-
-def test_gpt_random_data_legacy():
-    get_test_data_and_compare_samples(
-        {"format": "random"},
-        4,
-        sequence_length=7,
-        expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES,
-        legacy=True,
-    )
diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py
index a2996aa1c..6a2be3dcc 100644
--- a/tests/data/test_sampling.py
+++ b/tests/data/test_sampling.py
@@ -34,16 +34,6 @@
     [1178, 3291, 317, 277, 2679, 89],
     [89, 542, 395, 583, 684, 554],
 ]
-GPT_MEMMAP_SAMPLES_LEGACY = [
-    [1725, 74, 207, 1635, 4440, 2774],
-    [359, 489, 4266, 2052, 5351, 80],
-    [374, 7534, 87, 1073, 79, 480],
-    [8008, 498, 71, 727, 80, 315],
-    [2210, 8179, 73, 2582, 897, 1178],
-    [409, 5091, 328, 1378, 5483, 88],
-    [83, 4457, 3316, 333, 489, 317],
-    [330, 155, 2449, 1136, 1106, 5370],
-]
 
 
 def test_gpt_sampled():
@@ -60,7 +50,7 @@ def test_gpt_sampled_data():
     get_test_data_and_compare_samples(
         {
             "datasets": {
-                "Training": {
+                "training": {
                     "type": "memmap",
                     "path": DATASET_PREFIX,
                 }
@@ -72,16 +62,6 @@
     )
 
 
-def test_gpt_sampled_data_legacy():
-    get_test_data_and_compare_samples(
-        {"format": "list", "path": [str(DATASET_PREFIX)], "split": [1, 0, 0]},
-        8,
-        sequence_length=5,
-        expected_samples=GPT_MEMMAP_SAMPLES_LEGACY,
-        legacy=True,
-    )
-
-
 class SimpleGPTIndexedDataset(GPTIndexedDataset):
     # TODO: worth adding to the main codebase?
     def __init__(self, samples):
diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py
index 1440614cb..1fc8df1eb 100644
--- a/tests/data/test_slice.py
+++ b/tests/data/test_slice.py
@@ -27,23 +27,6 @@
     [3712, 86, 476, 80, 2547, 7390],
 ]
 
-GPT_SLICE_TRAINING_SAMPLES_LEGACY = [
-    [2625, 76, 2625, 2639, 74, 243],
-    [207, 481, 5546, 74, 414, 498],
-    [74, 333, 1963, 310, 5337, 3628],
-    [79, 2361, 80, 2012, 84, 480],
-]
-GPT_SLICE_VALIDATION_SAMPLES_LEGACY = [
-    [2352, 3687, 2311, 4900, 542, 3732],
-    [2551, 5283, 900, 3140, 328, 68],
-    [7979, 2283, 329, 727, 2740, 2818],
-    [4117, 8056, 79, 1798, 243, 498],
-    [243, 542, 387, 6476, 6686, 785],
-    [95, 6641, 207, 279, 2304, 602],
-    [89, 4446, 947, 293, 947, 1544],
-    [243, 3712, 86, 476, 80, 2547],
-]
-
 
 def test_gpt_slice():
     # Make sure dataset splitting works and check for unintended changes in behavior.
@@ -89,17 +72,3 @@
             "validation": GPT_SLICE_VALIDATION_SAMPLES,
         },
     )
-
-
-def test_gpt_slice_data_legacy():
-    get_test_dataset()
-    get_test_data_and_compare_samples(
-        {"format": "list", "path": [str(DATASET_PREFIX)], "split": [0.0015, 0.0015, 0.997]},
-        {"training": 4, "validation": 8, "test": 5},
-        sequence_length=5,
-        expected_samples={
-            "training": GPT_SLICE_TRAINING_SAMPLES_LEGACY,
-            "validation": GPT_SLICE_VALIDATION_SAMPLES_LEGACY,
-        },
-        legacy=True,
-    )
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index d52564cc0..f14f028e1 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -104,8 +104,8 @@ def _lm_head(
     ("config_dict", "distributed_config_dict", "loss_masking"),
     (
         ({}, {}, False),
-        ({}, {"training_dtype": DataType.bfloat16}, False),
-        ({"embeddings_layer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False),
+        ({}, {"compute_dtype": DataType.bfloat16}, False),
+        ({"embeddings_layer": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False),
         ({"sequence_first": True}, {}, False),
         ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False),
         ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False),
@@ -195,7 +195,7 @@ def test_lm_head(
         dtype=(
             distributed.config.optimization_dtype.torch
             if config.embeddings_layer.full_precision_residual
-            else distributed.config.training_dtype.torch
+            else distributed.config.compute_dtype.torch
         ),
         device=distributed.device,
         requires_grad=True,
@@ -239,7 +239,7 @@ def test_lm_head(
     if config.output_layer.tied_weight or config.output_layer.prediction_heads > 1:
         logit_weight = (
             torch.empty(
-                VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device
+                VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device
             )
             .normal_(config.embeddings_layer.hidden_size**-0.5)
             .requires_grad_(True)
@@ -302,9 +302,9 @@ def test_lm_head(
     output, context = stage.forward(head_input, kwargs, losses)
     stage.backward(output_grad, context)
 
-    threshold = 1e-5 if distributed.config.training_dtype == DataType.float32 else 5e-3
+    threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3
     min_threshold = (
-        1e-5 if distributed.config.training_dtype == DataType.float32 else 1e-4
+        1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4
     ) * config.output_layer.logits_scale_factor
 
     Assert.eq(losses.keys(), loss_keys)
diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py
index ad0de47e6..bce77d4f2 100644
--- a/tests/models/test_generate.py
+++ b/tests/models/test_generate.py
@@ -61,11 +61,11 @@ def _get_fast_llm_model(
     updates = {}
     if use_flash_attention:
         updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = True
-        updates[("distributed", "training_dtype")] = "bf16"
+        updates[("distributed", "compute_dtype")] = "bf16"
     else:
         updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = False
         if use_bf16:
-            updates[("distributed", "training_dtype")] = "bf16"
+            updates[("distributed", "compute_dtype")] = "bf16"
     return HuggingfaceGPTModelForCausalLM.from_pretrained(
         CheckpointLoadConfig(
             path=model_path,
@@ -87,11 +87,11 @@ def _get_fast_llm_model_from_model(
 
     if use_flash_attention:
         updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = True
-        updates[("model", "distributed", "training_dtype")] = "bf16"
+        updates[("model", "distributed", "compute_dtype")] = "bf16"
     else:
         updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = False
         if use_bf16:
-            updates[("model", "distributed", "training_dtype")] = "bf16"
+            updates[("model", "distributed", "compute_dtype")] = "bf16"
 
     config = PretrainedGPTModelConfig.from_dict({}, updates)
     multi_stage = config.model.get_model_class()(config.model)
diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py
index fdb908b0d..6aa541b8c 100644
--- a/tests/models/test_match_megatron.py
+++ b/tests/models/test_match_megatron.py
@@ -1,7 +1,15 @@
 import os
+import typing
 
+import numpy as np
 import pytest
 
+from fast_llm.config import Field, FieldHint, config_class
+from fast_llm.data.dataset.abstract import SampledDataset
+from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData
+from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
+from fast_llm.data.dataset.gpt.sampled import GPTSample, logger
+from fast_llm.utils import Assert
 from tests.utils.compare_tensor_logs import CompareConfig
 from tests.utils.dataset import get_model_test_dataset
 from tests.utils.distributed_configs import DistributedTestingConfig
@@ -9,6 +17,13 @@
 from tests.utils.model_configs import ModelTestingGroup
 from tests.utils.utils import requires_cuda
 
+try:
+    from fast_llm.csrc.data import build_sample_idx  # noqa
+
+    _extension_available = True
+except ImportError:
+    _extension_available = False
+
 
 @requires_cuda
 @pytest.mark.model_testing_group(ModelTestingGroup.megatron)
@@ -51,9 +66,9 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co
         name="match_megatron",
         compare="megatron",
         config_args=[
-            "model.distributed.training_dtype=fp32",
-            "data.datasets={}",
-            f"data.path={MODEL_DATASET_PREFIX}",
+            "model.distributed.compute_dtype=fp32",
+            f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PREFIX}}}',
+            "data.sampling.seed=1234",
             "model.base_model.use_megatron_initialization=True",
         ],
        num_gpus=1,
@@ -62,3 +77,83 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co
 
     run_test_script_for_all_models(distributed_testing_config)
     compare_results_for_all_models(distributed_testing_config)
+
+
+@config_class(dynamic_type={GPTSampledDatasetConfig: "megatron"})
+class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig):
+    _abstract: typing.ClassVar[bool] = False
+    path: str = Field(
+        desc="Dataset path (prefix).",
+        hint=FieldHint.core,
+    )
+
+    def build(self) -> "GPTMemmapDataset":
+        return GPTMegatronMemmapDataset(
+            str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens
+        )
+
+
+class GPTMegatronMemmapDataset(GPTMemmapDataset):
+    def sample(self, sampling: GPTSamplingData) -> "MegatronGPTSampledIndexedDataset":
+        return MegatronGPTSampledIndexedDataset(self, sampling)
+
+
+class MegatronGPTSampledIndexedDataset(SampledDataset):
+    """
+    A GPT sampled dataset that exactly matches Megatron-LM, for testing purposes.
+    Minimalistic implementation, implements only the required features.
+    """
+
+    def __init__(
+        self,
+        indexed_dataset: GPTMegatronMemmapDataset,
+        sampling: GPTSamplingData,
+    ):
+        assert isinstance(sampling, GPTSamplingData)
+        self._indexed_dataset = indexed_dataset
+        self._num_samples = sampling.parameters.num_samples
+        self._sequence_length = sampling.parameters.sequence_length
+
+        logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...")
+        document_sizes = self._indexed_dataset.get_document_sizes()
+        num_documents = len(document_sizes)
+        num_tokens = document_sizes.sum()
+        np_rng = np.random.RandomState(seed=sampling.config.seed)
+
+        # Assume less than one epoch.
+        Assert.lt(self._sequence_length * self._num_samples, num_tokens)
+
+        self._doc_idx = np.arange(num_documents, dtype=np.int32)
+        np_rng.shuffle(self._doc_idx)
+
+        assert _extension_available, (
+            "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly."
+        )
+
+        self._sample_idx = build_sample_idx(document_sizes, self._doc_idx, self._sequence_length, 1, num_tokens, True)
+        self._shuffle_idx = np.arange(0, self._sample_idx.shape[0] - 1, dtype=np.uint32)
+        np_rng.shuffle(self._shuffle_idx)
+
+    def __len__(self) -> int:
+        return self._num_samples
+
+    def __getitem__(self, idx: int) -> typing.Any:
+        shuffled_idx = self._shuffle_idx[idx]
+        doc_f, offset_f = self._sample_idx[shuffled_idx]
+        doc_l, offset_l = self._sample_idx[shuffled_idx + 1]
+        sample_list = [
+            self._indexed_dataset.get(
+                self._doc_idx[doc].item(),
+                offset=(doc == doc_f) * offset_f,
+                length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None,
+            )
+            for doc in range(doc_f, doc_l + 1)
+        ]
+        token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64)
+        Assert.eq(len(token_ids), self._sequence_length + 1)
+
+        return GPTSample(token_ids=token_ids)
+
+    @property
+    def name(self) -> str:
+        return self._indexed_dataset.name
diff --git a/tests/test_attention.py b/tests/test_attention.py
index 62c34d3c0..dceaa8282 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -29,7 +29,7 @@ def test_varlen_preprocessor():
     micro_sequence_length = 12
     sequence_length = 36
     varlen_preprocessor = FlashAttnVarlenPreprocessor(
-        AttentionConfig(head_size=64), DistributedConfig(training_dtype="bfloat16")
+        AttentionConfig(head_size=64), DistributedConfig(compute_dtype="bfloat16")
     )
     for micro_seq_idx in range(int(sequence_length / micro_sequence_length)):
         kwargs = {
diff --git a/tests/test_config.py b/tests/test_config.py
index 4e73569b3..6d2583ba3 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -95,7 +95,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path):
                 "output_layer": {"tied_weight": False},
             },
             "multi_stage": {"zero_stage": 3},
-            "distributed": {"training_dtype": "bfloat16"},
+            "distributed": {"compute_dtype": "bfloat16"},
         }
     )
     with NoAutoValidate():
@@ -121,7 +121,7 @@
         {
             "model": {
                 "base_model": base_model_update,
-                "distributed": {"seed": 1234, "training_dtype": "float16"},
+                "distributed": {"seed": 1234, "compute_dtype": "float16"},
            },
             "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config},
         }
@@ -131,7 +131,7 @@
 
     if load_config == ModelConfigType.fast_llm:
         expected_config["multi_stage"] = {"zero_stage": 3}
-    expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"})
+    expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"})
     if load_config in (ModelConfigType.fast_llm, ModelConfigType.model):
         expected_config["base_model"] = {
             "embeddings_layer": {
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py
index e4cce2935..680faa931 100644
--- a/tests/utils/dataset.py
+++ b/tests/utils/dataset.py
@@ -66,25 +66,3 @@ def get_model_test_dataset(
     vocab_size: int = MODEL_TEST_VOCAB_SIZE,
 ):
     return get_test_dataset(prefix=prefix, vocab_size=vocab_size)
-
-
-def get_test_concatenated_memmap_dataset(
-    path: pathlib.Path,
-    num_files: int,
-    seed: int = 1234,
-    num_tokens: int = TEST_DATASET_TOKENS,
-    characters: str = TEST_CHARACTERS,
-    vocab_size: int = TEST_VOCAB_SIZE,
-    seed_shift: int = 55,
-):
-    index_file = path / "index.txt"
-    if not index_file.is_file():
-        for i in range(num_files):
-            get_test_dataset(
-                prefix=path / f"dataset_{i}",
-                seed=seed + i * seed_shift,
-                num_tokens=num_tokens,
-                characters=characters,
-                vocab_size=vocab_size,
-            )
-        index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)])
diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py
index 306beadf8..863be2cae 100644
--- a/tests/utils/distributed_configs.py
+++ b/tests/utils/distributed_configs.py
@@ -87,14 +87,14 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     DistributedTestingConfig(
         name="bf16",
         compare="simple",
-        config_args=["model.distributed.training_dtype=bf16"],
+        config_args=["model.distributed.compute_dtype=bf16"],
         num_gpus=1,
         compare_config=_bf16_compare,
     ),
     DistributedTestingConfig(
         name="fp16",
         compare="simple",
-        config_args=["model.distributed.training_dtype=fp16"],
+        config_args=["model.distributed.compute_dtype=fp16"],
         num_gpus=1,
         compare_config=_fp16_compare,
     ),
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 55ac4ae74..aa8100126 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -279,6 +279,7 @@ def _update_and_add_testing_config(
         # Megatron messes with the vocab size, so we have to subtract 1.
         f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}",
         f"--data-path={MODEL_DATASET_PREFIX}",
+        "--split=1,0,0",
         "--lr-decay-style=constant",
         # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron)
         "--use-mcore-models",