6 changes: 0 additions & 6 deletions fast_llm/engine/schedule/config.py
@@ -1,6 +1,5 @@
import enum
import functools
import warnings

from fast_llm.config import Config, Field, FieldHint, check_field, config_class, test_field
from fast_llm.engine.distributed.config import DistributedConfig
@@ -105,11 +104,6 @@ def _validate(self) -> None:

if self._distributed.pipeline_parallel > 1 and self.depth_first_micro_batches > 1:
raise NotImplementedError("Depth-first pipeline parallelism not yet implemented")
if self.depth_first_micro_batches > 1 and self.breadth_first_micro_batches > 1:
warnings.warn(
"Mixing of breadth-first and depth-first gradient accumulation is not thoroughly tested."
" Use at your own risk."
)
super()._validate()


3 changes: 3 additions & 0 deletions fast_llm/layers/ssm/discrete_mamba2.py
@@ -7,6 +7,7 @@
from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.layers.common.linear import Linear
from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
from fast_llm.layers.transformer.config import TransformerKwargs
from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_
from fast_llm.utils import get_lr_scale

@@ -157,6 +158,8 @@ def forward(self, hidden_states, kwargs):
outputs["hidden_states"]: (B, L, D).
outputs["state"]: inference cache.
"""
if kwargs[TransformerKwargs.sequence_first]:
raise NotImplementedError(f"Sequence-first not supported for SSMs.")

assert _mamba_available
input_ = hidden_states
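The new guard rejects sequence-first inputs because this Mamba path assumes the batch-first (B, L, D) layout described in the docstring above. A minimal standalone sketch of the two layouts being distinguished (illustrative only, not Fast-LLM code):

import torch

# Batch-first (batch, sequence, hidden): the layout this forward pass expects.
B, L, D = 2, 16, 8
hidden_batch_first = torch.randn(B, L, D)

# Sequence-first (sequence, batch, hidden): the layout the new check rejects.
hidden_sequence_first = hidden_batch_first.transpose(0, 1)
assert hidden_sequence_first.shape == (L, B, D)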
1 change: 1 addition & 0 deletions fast_llm/logging.py
@@ -137,6 +137,7 @@ def log_tensor[
) -> (T | None):
if level < 1:
return
tensor = tensor.detach()
save_stats = TensorLogs.config.save
shape = tuple(tensor.shape)
_, dtype = str(tensor.dtype).split("torch.")
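Detaching before computing statistics keeps `log_tensor` off the autograd graph. A small illustration (not Fast-LLM code) of the kind of failure the added `detach()` avoids when a logged tensor still requires grad:

import torch

t = torch.randn(4, 4, requires_grad=True).exp()  # non-leaf tensor that requires grad
try:
    t.numpy()  # raises RuntimeError: numpy() is not allowed on a tensor that requires grad
except RuntimeError as error:
    print(error)

detached = t.detach()  # shares storage, but lives outside the autograd graph
print(detached.mean().item(), detached.std().item())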
6 changes: 6 additions & 0 deletions fast_llm/models/ssm/config.py
@@ -197,6 +197,12 @@ def _validate(self):
logger.warning(
"HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected."
)
if (
self.base_model.sequence_first
or self.distributed.sequence_data_parallel > 1
or self.distributed.sequence_tensor_parallel
):
raise NotImplementedError(f"Sequence-first not supported for SSMs.")
super()._validate()


2 changes: 1 addition & 1 deletion fast_llm/utils.py
@@ -145,7 +145,7 @@ def multiple(x, y):

@staticmethod
def rms_close(x, y, threshold):
rms = rms_diff(x, y).item()
rms = rms_diff(x, y).detach().item()
assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}"

@staticmethod
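`rms_close` asserts on the scalar RMS difference, and the added `detach()` keeps that read-out off the autograd graph. A rough sketch of the intent, assuming `rms_diff` is essentially the root-mean-square of the elementwise difference (the real helper lives in `fast_llm.utils`; the stand-in below is an assumption):

import torch

def rms_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for fast_llm.utils.rms_diff.
    return torch.sqrt(torch.mean((x - y) ** 2))

x = torch.randn(8, requires_grad=True)
y = x.detach() + 1e-5 * torch.randn(8)
rms = rms_diff(x, y).detach().item()  # detach first, then read the scalar
assert rms <= 1e-3, f"Rms diff too big ({rms:.3e})"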
6 changes: 2 additions & 4 deletions tests/conftest.py
@@ -95,10 +95,8 @@ def pytest_configure(config):
else:
worker_id = 0

# TODO: Remove the whole `TEST_RESULTS_PATH` once `get_test_dataset` is parallel-safe.
model_result_path = TEST_RESULTS_PATH / "models"
if model_result_path.exists():
shutil.rmtree(model_result_path)
if TEST_RESULTS_PATH.exists():
shutil.rmtree(TEST_RESULTS_PATH)

num_gpus = torch.cuda.device_count()
if num_gpus > 0 and is_parallel:
36 changes: 20 additions & 16 deletions tests/data/test_concatenated_memmap.py
@@ -1,3 +1,5 @@
import pytest

from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig
from tests.data.common import (
compare_indexed_dataset,
@@ -42,10 +44,11 @@ def test_gpt_concatenated_memmap():
# Make sure dataset splitting works and check for unintended changes in behavior.
_get_test_dataset_concatenated_memmap()
# samples[9:18]
dataset = get_dataset_config(
{"type": "concatenated_memmap", "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP},
GPTConcatenatedMemmapConfig,
).build()
with pytest.warns(DeprecationWarning):
dataset = get_dataset_config(
{"type": "concatenated_memmap", "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP},
GPTConcatenatedMemmapConfig,
).build()
compare_indexed_dataset(
dataset,
CONCATENATED_MEMMAP_DATASET_LENGTH,
@@ -58,16 +61,17 @@ def test_gpt_concatenated_memmap():

def test_gpt_concatenated_memmap_data():
_get_test_dataset_concatenated_memmap()
get_test_data_and_compare_samples(
{
"datasets": {
"Training": {
"type": "concatenated_memmap",
"path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP,
with pytest.warns(DeprecationWarning):
get_test_data_and_compare_samples(
{
"datasets": {
"Training": {
"type": "concatenated_memmap",
"path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP,
}
}
}
},
8,
sequence_length=5,
expected_samples=CONCATENATED_MEMMAP_SAMPLES,
)
},
8,
sequence_length=5,
expected_samples=CONCATENATED_MEMMAP_SAMPLES,
)
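Both tests now wrap the deprecated `concatenated_memmap` path in `pytest.warns`, which fails the test if no matching warning is emitted. A self-contained sketch of the pattern (the helper below is a placeholder, not the Fast-LLM API):

import warnings

import pytest


def build_legacy_dataset():
    # Placeholder for a code path that emits a deprecation warning.
    warnings.warn("`concatenated_memmap` is deprecated.", DeprecationWarning)
    return []


def test_warns_on_legacy_dataset():
    with pytest.warns(DeprecationWarning, match="concatenated_memmap"):
        build_legacy_dataset()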
2 changes: 2 additions & 0 deletions tests/models/distributed_test_model.py
@@ -27,6 +27,8 @@ def main(args: list[str] | None = None) -> None:
group = pool.get_process_group(range(world_size), rank)

for name, config in DISTRIBUTED_TESTING_CONFIGS.items():
if model_testing_config.should_skip(config):
continue
if world_size < config.num_gpus:
logger.warning(f"{name} {f"SKIPPED (not enough GPUs: {world_size} < {config.num_gpus})"})")
continue
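`should_skip` lets each model testing config veto distributed testing configs it cannot run, mirroring the skips added in `tests/models/test_model.py` below. The real criteria live in `tests/utils/model_configs.py`; the sketch below only illustrates the shape of such a check, and its field names are hypothetical:

import dataclasses


@dataclasses.dataclass
class TestingConfig:
    # Illustrative stand-in for DistributedTestingConfig.
    name: str
    num_gpus: int = 1
    requires_sequence_first: bool = False


@dataclasses.dataclass
class ModelConfig:
    # Illustrative stand-in for ModelTestingConfig.
    name: str
    supports_sequence_first: bool = True

    def should_skip(self, config: TestingConfig) -> bool:
        # Skip configs that exercise features this model does not support,
        # for example sequence-first layouts for the SSM models in this PR.
        return config.requires_sequence_first and not self.supports_sequence_first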
13 changes: 7 additions & 6 deletions tests/models/test_checkpoint.py
@@ -19,7 +19,7 @@
from fast_llm.engine.checkpoint.convert import ConvertConfig
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName
from fast_llm.utils import Assert
from tests.utils.compare_tensor_logs import CompareConfig, compare_logged_tensor
from tests.utils.compare_tensor_logs import CompareConfig
from tests.utils.distributed_configs import DistributedTestingConfig
from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup
from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig
@@ -65,12 +65,15 @@ def do_prepare_resume(distributed_testing_config: DistributedTestingConfig):
@pytest.mark.model_testing_group(ModelTestingGroup.checkpoint)
def test_resume(run_test_script_for_all_models, compare_results_for_all_models, prepare_resume):
distributed_testing_config = DistributedTestingConfig(
name="resume", compare="checkpoint_and_eval", config_args=_CHECKPOINT_AND_EVAL_ARGS
name="resume",
compare="checkpoint_and_eval",
config_args=_CHECKPOINT_AND_EVAL_ARGS,
compare_config=CompareConfig(sub_configs={(("init", "train_1"), None): CompareConfig(ignore_tensors=True)}),
)
prepare_resume(distributed_testing_config)
# Resume from iteration=1 and compare outputs with the baseline run.
run_test_script_for_all_models(distributed_testing_config)
compare_results_for_all_models(distributed_testing_config, ("train_2",))
compare_results_for_all_models(distributed_testing_config)


@requires_cuda
@@ -304,7 +307,6 @@ def test_huggingface_model(model_testing_config, get_convert_path):
)
)
errors = []
compare = CompareConfig()
auto_model = (
transformers.AutoModel
if model_testing_config.name in ("diffusion_llama", "dream")
@@ -320,13 +322,12 @@ def test_huggingface_model(model_testing_config, get_convert_path):
print(name)
output = model(test_input)
# TODO: Make a generic comparison util.
compare_logged_tensor(
CompareConfig().compare_tensors(
{"samples": output_ref.logits, "shape": output_ref.logits.shape, "step": 0},
{"samples": output.logits, "shape": output.logits.shape, "step": 0},
errors,
name,
"logits",
compare,
)

if errors:
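The usage above suggests `CompareConfig.sub_configs` maps a (step selector, tensor-name selector) key to an override config, with `None` acting as a wildcard; the resume test uses it to ignore every tensor from the `init` and `train_1` steps. A standalone sketch of that lookup pattern (an inference about the semantics, not the actual `tests.utils.compare_tensor_logs` implementation):

def select_override(step: str, tensor_name: str, sub_configs: dict) -> dict | None:
    # Keys are (step selector, name selector) pairs: None matches anything,
    # while a tuple of substrings matches if any substring occurs in the value.
    for (step_selector, name_selector), override in sub_configs.items():
        step_match = step_selector is None or step in step_selector
        name_match = name_selector is None or any(s in tensor_name for s in name_selector)
        if step_match and name_match:
            return override
    return None


sub_configs = {(("init", "train_1"), None): {"ignore_tensors": True}}
assert select_override("init", "logits", sub_configs) == {"ignore_tensors": True}
assert select_override("train_2", "logits", sub_configs) is None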
14 changes: 7 additions & 7 deletions tests/models/test_match_megatron.py
@@ -3,7 +3,7 @@
import pytest

from tests.utils.compare_tensor_logs import CompareConfig
from tests.utils.dataset import DATASET_PREFIX, get_test_dataset
from tests.utils.dataset import MODEL_DATASET_PREFIX, get_model_test_dataset
from tests.utils.distributed_configs import DistributedTestingConfig
from tests.utils.model_configs import ModelTestingGroup
from tests.utils.utils import requires_cuda
@@ -17,7 +17,7 @@ def test_megatron(run_distributed_script, model_testing_config, run_test_script_
# Prevent Megatron from complaining.
env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
env["NVTE_FLASH_ATTN"] = "0"
get_test_dataset()
get_model_test_dataset()
run_distributed_script(
[
"Megatron-LM/pretrain_gpt.py",
@@ -36,27 +36,27 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, compare_results_for_all_models):
def test_match_megatron(run_test_script_for_all_models, model_testing_config, compare_results_for_all_models):
assert model_testing_config.megatron_args is not None

ignore_tensors = [
ignore_tensors = (
".self_attn.query_key_value.",
".self_attn.query.",
".self_attn.key_value.",
".mlp.layer_2.weight",
".mlp.experts.",
]
)
if model_testing_config.name == "mixtral":
ignore_tensors.extend([".mlp.experts.", ".mlp.layer_1.weight"])
ignore_tensors += (".mlp.experts.", ".mlp.layer_1.weight")

distributed_testing_config = DistributedTestingConfig(
name="match_megatron",
compare="megatron",
config_args=[
"model.distributed.training_dtype=fp32",
"data.datasets={}",
f"data.path={DATASET_PREFIX}",
f"data.path={MODEL_DATASET_PREFIX}",
"model.base_model.use_megatron_initialization=True",
],
num_gpus=1,
compare_config=CompareConfig(ignore_tensors=ignore_tensors),
compare_config=CompareConfig(sub_configs={(None, ignore_tensors): CompareConfig(ignore_tensors=True)}),
)

run_test_script_for_all_models(distributed_testing_config)
15 changes: 12 additions & 3 deletions tests/models/test_model.py
@@ -28,10 +28,16 @@ def test_model_simple(run_test_script_for_all_models, run_test_script_base_path)
# Parametrize with config name so it shows in test name.
@pytest.mark.parametrize("config_name", SINGLE_GPU_TESTING_CONFIGS)
def test_and_compare_model(
run_test_script_for_all_models, compare_results_for_all_models, config_name, run_test_script_base_path
run_test_script_for_all_models,
compare_results_for_all_models,
config_name,
run_test_script_base_path,
model_testing_config,
):
# We can expect tests to respect the ordering of `SINGLE_GPU_TESTING_CONFIGS`, so compare should have run already.
config = SINGLE_GPU_TESTING_CONFIGS[config_name]
if model_testing_config.should_skip(config):
pytest.skip(f"Configuration not supported.")
if config.compare is not None:
check_subtest_success(run_test_script_base_path / config.compare)
# A baseline config (single-gpu, bf16, flash-attn).
@@ -40,7 +46,7 @@ def test_and_compare_model(
set_subtest_success(run_test_script_base_path / config.name)

if config.compare is not None:
compare_results_for_all_models(config, ("init", "train_1", "train_2"))
compare_results_for_all_models(config)


@requires_cuda
@@ -73,12 +79,15 @@ def test_model_distributed(
config_name,
run_test_script_base_path,
report_subtest,
model_testing_config,
):
config = DISTRIBUTED_TESTING_CONFIGS[config_name]
if model_testing_config.should_skip(config):
pytest.skip(f"Configuration not supported.")
if torch.cuda.device_count() < config.num_gpus:
pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < {config.num_gpus}")
report_subtest(run_test_script_base_path / config.name, config.num_gpus)
if config.compare is not None:
if not check_subtest_success(run_test_script_base_path / config.compare):
pytest.fail(f"Test {config.compare} failed", pytrace=False)
compare_results_for_all_models(config, ("init", "train_1", "train_2"))
compare_results_for_all_models(config)