Merged
114 commits
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
5eea938
fix
jlamypoirier Jul 28, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
f0c04cf
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
e536af9
Concatenated dim
jlamypoirier Jul 28, 2025
017f5cc
fixes
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
02f8af5
Block interface
jlamypoirier Jul 29, 2025
6bf06d6
fix
jlamypoirier Jul 29, 2025
2ddc3a7
fix
jlamypoirier Jul 29, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
b2f4476
Merge branch 'tp_mamba' into block_interface
jlamypoirier Jul 29, 2025
ce70b16
fixes
jlamypoirier Jul 29, 2025
a9f733d
fix
jlamypoirier Jul 29, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
a5eb076
stuff
jlamypoirier Jul 31, 2025
ab484ac
Revert "stuff"
jlamypoirier Jul 31, 2025
b68d360
stuff
jlamypoirier Jul 31, 2025
82c9dbd
misc
jlamypoirier Jul 31, 2025
9fbb9ff
misc
jlamypoirier Jul 31, 2025
44df195
misc
jlamypoirier Jul 31, 2025
3bb03cb
misc
jlamypoirier Jul 31, 2025
98bae95
misc
jlamypoirier Jul 31, 2025
fd731ef
fixes
jlamypoirier Aug 1, 2025
f483321
fixes
jlamypoirier Aug 1, 2025
5a0eabc
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Aug 8, 2025
dd288df
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 8, 2025
defd6e0
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 8, 2025
8abf258
fixes
jlamypoirier Aug 8, 2025
c16c00f
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 8, 2025
07c9211
stuff
jlamypoirier Aug 8, 2025
be99372
Merge branch 'main' into debug_mamba
jlamypoirier Aug 12, 2025
a505f3a
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 12, 2025
0cc859a
Merge remote-tracking branch 'origin/main' into concatenated_dim
jlamypoirier Aug 12, 2025
bd4ff0d
doc
jlamypoirier Aug 12, 2025
fd3307d
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 12, 2025
0e2e124
stuff
jlamypoirier Aug 12, 2025
0a5e458
Remove tensor space, fixes
jlamypoirier Aug 14, 2025
797bd73
stuff
jlamypoirier Aug 14, 2025
c0a3782
stuff
jlamypoirier Aug 15, 2025
e60ded4
stuff
jlamypoirier Aug 15, 2025
1483bcc
stuff
jlamypoirier Aug 15, 2025
4deb501
misc
jlamypoirier Aug 15, 2025
fc809e0
Misc, tests pass
jlamypoirier Aug 15, 2025
cdb6710
misc
jlamypoirier Aug 20, 2025
9ce72e0
Move files
jlamypoirier Aug 20, 2025
065b34f
misc
jlamypoirier Aug 20, 2025
4510b7b
misc
jlamypoirier Aug 20, 2025
9a2a7a2
Pr comments
jlamypoirier Aug 21, 2025
8c382a9
Cleanup
jlamypoirier Aug 21, 2025
019e43d
Cleanup
jlamypoirier Aug 21, 2025
3e0f3e5
Cleanup
jlamypoirier Aug 21, 2025
90a3c98
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
39960ce
Cleanup
jlamypoirier Aug 21, 2025
1abdd19
fixes
jlamypoirier Aug 21, 2025
7c24292
fixes
jlamypoirier Aug 21, 2025
af2964b
fixes
jlamypoirier Aug 21, 2025
0e62f7d
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
654aeeb
Fix merge
jlamypoirier Aug 21, 2025
3f4a8ba
fix
jlamypoirier Aug 27, 2025
9741ba0
stuff
jlamypoirier Aug 27, 2025
be69677
fixes
jlamypoirier Aug 27, 2025
82a70aa
Simplify bias options
jlamypoirier Aug 27, 2025
680980a
stuff
jlamypoirier Aug 29, 2025
3ef7860
Dynamic mlp and block layer creation
jlamypoirier Aug 29, 2025
ecad96b
stuff
jlamypoirier Sep 3, 2025
3fd092c
fix
jlamypoirier Sep 3, 2025
1a3497c
stuff
jlamypoirier Sep 3, 2025
b6e7fce
stuff
jlamypoirier Sep 4, 2025
4dfe2a4
stuff
jlamypoirier Sep 9, 2025
4185741
misc
jlamypoirier Sep 9, 2025
7763296
stuff
jlamypoirier Sep 17, 2025
8249f8a
fix
jlamypoirier Sep 17, 2025
188587e
Merge branch 'main' into concatenated_dim
jlamypoirier Sep 17, 2025
e111509
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Sep 17, 2025
95e0231
Merge branch 'tp_mamba' into block_interface
jlamypoirier Sep 17, 2025
e076c7a
Merge remote-tracking branch 'origin/main' into block_interface
jlamypoirier Sep 18, 2025
2315ac4
Merge branch 'block_interface' into block_interface_weight
jlamypoirier Sep 18, 2025
79356f7
Merge remote-tracking branch 'origin/main' into block_interface_weight
jlamypoirier Sep 18, 2025
e4198a6
Merge branch 'block_interface_weight' into block_interface_mixer_mlp_…
jlamypoirier Sep 18, 2025
7abf263
Merge branch 'block_interface_mixer_mlp_config' into block_interface_…
jlamypoirier Sep 18, 2025
bfc9f84
Merge branch 'block_interface_fine_grained' into block_interface_tflops
jlamypoirier Sep 18, 2025
e68f96c
Merge branch 'block_interface_tflops' into block_interface_convert
jlamypoirier Sep 18, 2025
870afd3
v0.3.0
jlamypoirier Sep 18, 2025
e977336
Merge remote-tracking branch 'origin/main' into v0.3.0
jlamypoirier Sep 18, 2025
examples/mistral.yaml (2 changes: 1 addition & 1 deletion)
@@ -62,7 +62,7 @@ model:
multi_stage:
zero_stage: 2
distributed:
training_dtype: bf16
compute_dtype: bf16
seed: 984059
run:
experiment_dir: mistral_example
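
The only change in this example config is the rename of `training_dtype` to `compute_dtype` under `distributed`. Below is a minimal migration sketch for dicts loaded from older YAML files; the helper name is hypothetical and not part of Fast-LLM.

```python
# Hypothetical helper (not part of Fast-LLM): rename the old `training_dtype`
# key to `compute_dtype` in a `distributed` config section loaded from YAML.
def migrate_distributed_section(distributed: dict) -> dict:
    if "training_dtype" in distributed and "compute_dtype" not in distributed:
        distributed["compute_dtype"] = distributed.pop("training_dtype")
    return distributed


assert migrate_distributed_section({"training_dtype": "bf16", "seed": 984059}) == {
    "seed": 984059,
    "compute_dtype": "bf16",
}
```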
fast_llm/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
__version__ = "0.2.0"
__version__ = "0.3.0"
fast_llm/config.py (50 changes: 12 additions & 38 deletions)
@@ -759,58 +759,32 @@ def from_dict(
return cls._from_dict(default, strict)

@classmethod
def from_flat_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
) -> typing.Self:
# TODO v0.3: Remove flat format
return cls._from_dict(default, strict, True)

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
# TODO v0.3: Remove flat format
def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
out_arg_dict = {"_from_dict_check": True}

# TODO v0.3: Remove backward compatibility fix
if "__class__" in default:
del default["__class__"]

try:
actual_cls = cls.get_subclass(default.get("type"))
except KeyError:
# Try to postpone error to validation.
actual_cls = cls

if actual_cls is not None and actual_cls is not cls:
return actual_cls._from_dict(default, strict=strict, flat=flat)
return actual_cls._from_dict(default, strict=strict)

# Do not validate yet in case the root class sets cross-dependencies in validation.
with NoAutoValidate():
for name, field in cls.fields():
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
if flat:
if isinstance(field.type, type) and issubclass(field.type, Config):
out_arg_dict[name] = field.type._from_dict(default, False, True)
elif name in default:
out_arg_dict[name] = default.pop(name)
else:
# Check for nested configs to instantiate.
try:
value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict)
if value is not MISSING:
out_arg_dict[name] = value
except FieldTypeError as e:
raise FieldTypeError(
f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: "
+ ", ".join(e.args)
)
# Check for nested configs to instantiate.
try:
value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict)
if value is not MISSING:
out_arg_dict[name] = value
except FieldTypeError as e:
raise FieldTypeError(
f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: "
+ ", ".join(e.args)
)
out = cls(**out_arg_dict) # noqa
if strict and default:
out._unknown_fields = default.copy()
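
With `from_flat_dict` and the `flat` argument removed, configs are built only from nested dicts: each sub-config is read from its own key rather than pulled out of a single flat dict. A standalone sketch of that nested-only recursion, using plain dataclasses instead of Fast-LLM's actual `Config` machinery:

```python
# Standalone sketch of nested-only config loading (illustrative, not Fast-LLM's
# actual implementation): each sub-config is read from its own nested dict
# instead of being pulled out of a single flat dict.
import dataclasses


@dataclasses.dataclass
class DistributedConfig:
    compute_dtype: str = "float32"
    seed: int = 0


@dataclasses.dataclass
class ModelConfig:
    zero_stage: int = 1
    distributed: DistributedConfig = dataclasses.field(default_factory=DistributedConfig)


def from_nested_dict(cls, data: dict):
    kwargs = {}
    for field in dataclasses.fields(cls):
        if field.name not in data:
            continue  # fall back to the field default
        value = data[field.name]
        if dataclasses.is_dataclass(field.type):
            value = from_nested_dict(field.type, value)  # recurse into nested configs
        kwargs[field.name] = value
    return cls(**kwargs)


config = from_nested_dict(ModelConfig, {"zero_stage": 2, "distributed": {"compute_dtype": "bf16"}})
assert config.zero_stage == 2
assert config.distributed.compute_dtype == "bf16"
```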
fast_llm/data/data/gpt/config.py (40 changes: 2 additions & 38 deletions)
@@ -1,23 +1,16 @@
import logging
import typing

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import (
GPTLegacyConfig,
GPTLegacyDatasetConfig,
GPTSampledDatasetConfig,
GPTSamplingConfig,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@config_class()
class GPTDataConfig(DataConfig, GPTLegacyConfig):
class GPTDataConfig(DataConfig):
"""
Configuration for the dataset(s), split and sampling.
Currently hard-coded to a GPT dataset.
@@ -48,32 +41,3 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)

def _validate(self) -> None:
if not self.datasets:
logger.warning(
"Using the legacy dataset definition format." " Specify it through `data.datasets` instead."
)
self.datasets = {
phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False)
for phase in (PhaseType.training, PhaseType.validation, PhaseType.test)
}
super()._validate()

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
# TODO v0.x: Remove backward compatibility.
if "datasets" in default:
for phase in PhaseType:
if phase.value in default["datasets"]:
rename = phase.value.lower()
logger.warning(f"Renaming dataset {phase.value} to {rename}")
assert rename not in default["datasets"]
default["datasets"][rename] = default["datasets"].pop(phase.value)

return super()._from_dict(default, strict, flat)
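
With `GPTLegacyConfig` and the phase-name renaming shim gone, datasets must be declared explicitly under `data.datasets`, keyed by lowercase names such as `training`, `validation` and `test`. A hedged sketch of such a mapping; the paths are placeholders and the inner dataset structure is illustrative:

```python
# Hedged sketch (paths and exact inner structure are illustrative): explicit
# per-phase dataset definitions replacing the removed legacy `split`/`format`/`path` fields.
data_config = {
    "datasets": {
        "training": {"type": "memmap", "path": "/data/train_prefix"},
        "validation": {"type": "memmap", "path": "/data/valid_prefix"},
        "test": {"type": "memmap", "path": "/data/test_prefix"},
    },
}
```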
fast_llm/data/dataset/config.py (19 changes: 2 additions & 17 deletions)
@@ -204,11 +204,6 @@ class BlendedDatasetConfig(SampledDatasetConfig):
desc="The blending weight of each dataset.",
hint=FieldHint.core,
)
legacy: bool = Field(
default=False,
desc="Use the legacy formulas for sub-dataset seeds and sample sizes.",
hint=FieldHint.deprecated,
)

def _validate(self) -> None:
self.weights = normalize_probabilities(self.weights)
@@ -231,20 +226,10 @@ def build_and_sample(
sampling,
parameters=dataclasses.replace(
sampling.parameters,
num_samples=(
math.ceil(
weight
* (
sampling.parameters.num_samples
+ 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5
)
)
if self.legacy
else math.ceil(weight * sampling.parameters.num_samples) + 1
),
num_samples=math.ceil(weight * sampling.parameters.num_samples) + 1,
),
# TODO: Seed may not be unique for nested blended datasets.
config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}),
config=sampling.config.to_copy({"seed": sampling.config.seed + i * 697}),
),
)
for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True))
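
With the `legacy` flag removed, every blended sub-dataset gets `ceil(weight * num_samples) + 1` samples and a seed offset of `697 * i`. A small worked example of that arithmetic; the seed and weights are made up for illustration:

```python
# Worked example of the remaining (non-legacy) blending formulas from this diff.
import math

num_samples = 1000          # total samples requested from the blended dataset
weights = [0.7, 0.2, 0.1]   # assumed already normalized, as `normalize_probabilities` does
base_seed = 1234            # illustrative sampling seed

for i, weight in enumerate(weights):
    sub_samples = math.ceil(weight * num_samples) + 1  # the +1 gives a small surplus over the exact share
    sub_seed = base_seed + i * 697                      # distinct seed per sub-dataset
    print(f"dataset {i}: {sub_samples} samples, seed {sub_seed}")
# dataset 0: 701 samples, seed 1234
# dataset 1: 201 samples, seed 1931
# dataset 2: 101 samples, seed 2628
```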
fast_llm/data/dataset/gpt/config.py (172 changes: 2 additions & 170 deletions)
@@ -1,10 +1,8 @@
import dataclasses
import enum
import json
import pathlib
import time
import typing
import warnings

import yaml

@@ -22,8 +20,7 @@
SamplingData,
SamplingParameters,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -41,7 +38,6 @@ class ShufflingType(str, enum.Enum):
skip_first_epoch = "skip_first_epoch"
# Disable shuffling entirely.
disabled = "disabled"
legacy = "legacy"


@config_class()
@@ -222,53 +218,14 @@ def _convert_paths(self, config):
return config


# Add user-friendly names for the configs.
@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"})
class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
# TODO v0.3: Remove.
_abstract: typing.ClassVar[bool] = False
path: pathlib.Path = Field(
default=None,
desc="The path to a dataset directory.",
hint=FieldHint.core,
)

def _validate(self) -> None:
warnings.warn("`concatenated_memmap` dataset is deprecated. Use `file` instead.", DeprecationWarning)
super()._validate()

def build(self) -> "GPTConcatenatedDataset":

assert self.path.is_dir()
index_path = self.path / "index.txt"

if index_path.is_file():
prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()]
else:
warnings.warn(
f"The dataset path {self.path} points to a directory."
" The dataset will be indexed automatically, which may be unsafe."
" We recommend using an index file instead."
)
prefixes = [
path.with_suffix("")
for path in self.path.iterdir()
if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file()
]
dataset_config = GPTConcatenatedDatasetConfig.from_dict(
{"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]}
)
return dataset_config.build()


@config_class()
class FimConfig(Config):
"""
Configuration for FIM.
"""

rate: float = Field(
# TODO: Use meaningful default now that fim is a wrapper? (bad for legacy config)
# TODO: Use meaningful default now that fim is a wrapper?
default=0.0,
desc="FIM rate for each sample.",
hint=FieldHint.core,
@@ -352,131 +309,6 @@ def build_and_sample(
return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling)


class LegacyDatasetSource(str, enum.Enum):
"""
An enum for the different ways to load datasets.
"""

list = "list"
file = "file"
random = "random"


def _validate_split(value: list[int]) -> list[int]:
Assert.leq(len(value), 3)
return value + [0] * (len(value) - 3)


def _validate_path(value: str | list[str]) -> list[str]:
return [value] if isinstance(value, str) else value


@config_class()
class GPTLegacyConfig(Config):
split: list[float] = Field(
default_factory=lambda: [969, 30, 1],
desc="Split ratio for train, valid and test datasets.",
hint=FieldHint.deprecated,
valid=_validate_split,
)
format: LegacyDatasetSource = Field(
default=LegacyDatasetSource.list,
desc="Format for the dataset definition.",
hint=FieldHint.deprecated,
)
path: list[str] = Field(
default_factory=list,
desc="Path or list of paths and weights.",
hint=FieldHint.deprecated,
valid=_validate_path,
)
fim: FimConfig = Field(
desc="Configuration for Fill In the Middle (FIM).",
hint=FieldHint.feature,
)


@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"})
class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig):
_abstract: typing.ClassVar[bool] = False

def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:

if self.format == LegacyDatasetSource.random:
Assert.eq(len(self.path), 0)
dataset_config = GPTRandomDatasetConfig()
else:
if self.format == LegacyDatasetSource.file:
Assert.eq(len(self.path), 1)
data_path = pathlib.Path(self.path[0])
dataset_defs = json.load(data_path.open("r"))
data_base_path = data_path.parent
dataset_prefixes = [
(data_base_path / dataset_def["prefix"]).resolve() for dataset_def in dataset_defs["datasets"]
]
dataset_weights = normalize_probabilities(
[dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]
)
elif self.format == LegacyDatasetSource.list:
Assert.geq(len(self.path), 1)
if len(self.path) == 1:
dataset_prefixes, dataset_weights = [self.path[0].strip()], [1.0]
else:
Assert.custom(lambda x: x % 2 == 0, len(self.path))
dataset_prefixes = [pathlib.Path(x.strip()).resolve() for x in self.path[1::2]]
assert len(dataset_prefixes) == len(set(dataset_prefixes))
dataset_weights = normalize_probabilities([float(x) for x in self.path[::2]])
else:
raise NotImplementedError(self.format)

phase_splits = padded_cumsum(normalize_probabilities(self.split))

phase_index = {
PhaseType.training.value.lower(): 0,
PhaseType.validation.value.lower(): 1,
PhaseType.test.value.lower(): 2,
}[sampling.dataset_name]

dataset_configs = [
{
"type": "slice",
# TODO: this duplicates memmap datasets for each phase.
"dataset": {"type": "memmap", "path": prefix},
"begin": float(phase_splits[phase_index]),
"end": float(phase_splits[phase_index + 1]),
}
for prefix in dataset_prefixes
]
dataset_config = (
{
"type": "blended",
"name": "blended",
"datasets": dataset_configs,
"weights": dataset_weights,
"legacy": True,
}
if len(dataset_configs) > 1
else dataset_configs[0]
)
if self.fim.rate > 0:
dataset_config = {
"type": "fim",
"dataset": dataset_config,
**self.fim.to_dict(),
}
# Legacy sampling config
dataset_config = {
"type": "sampled",
"dataset": dataset_config,
"sampling": {
"seed": sampling.distributed.config.seed,
"shuffle": "legacy",
},
}

return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling)


@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"})
class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
"""
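
The deleted `GPTLegacyDatasetConfig.build_and_sample` generated a nested slice/blended/fim config from the legacy `split`, `path` and `format` fields. The same shape can now be written out explicitly; a hedged sketch mirroring the structure the removed code used to emit, with placeholder paths, split boundaries and weights:

```python
# Hedged sketch of an explicit dataset config mirroring what the removed legacy
# code generated: per-file phase slices, blended with weights, optionally FIM-wrapped.
dataset_config = {
    "type": "blended",
    "datasets": [
        {
            "type": "slice",
            "dataset": {"type": "memmap", "path": "/data/dataset_a"},
            "begin": 0.0,    # training split boundary (placeholder)
            "end": 0.969,
        },
        {
            "type": "slice",
            "dataset": {"type": "memmap", "path": "/data/dataset_b"},
            "begin": 0.0,
            "end": 0.969,
        },
    ],
    "weights": [0.8, 0.2],   # placeholder blending weights
}
```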