diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py
index 22f23174b..0b7b14ab1 100644
--- a/fast_llm/functional/config.py
+++ b/fast_llm/functional/config.py
@@ -15,7 +15,7 @@ class TritonConfig:
     MAX_BLOCK_SIZE_BYTES = 65536


-class MLPRecomputeLevel(str, enum.Enum):
+class MLPRecomputeLevel(enum.StrEnum):
     none = "none"
     activation = "activation"
     activation_and_input = "activation_and_input"
diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py
index 258a2578b..7c803689c 100644
--- a/fast_llm/functional/triton/sparse_copy.py
+++ b/fast_llm/functional/triton/sparse_copy.py
@@ -11,10 +11,15 @@
 @dataclasses.dataclass()
 class SparseMap:
     sparse_rows: torch.Tensor
+    # The end row for each expert, including padding. `expert_ends[i] = expert_begins[i] + padded_tokens_per_expert[i]`
     expert_ends: torch.Tensor
+    # The end row for each expert, excluding padding. `expert_pad_begins[i] = expert_begins[i] + unpadded_tokens_per_expert[i]`
     expert_pad_begins: torch.Tensor
+    # The number of rows in the dense tensor, i.e., the number of tokens.
     num_rows_dense: int
+    # The number of sparse rows, including padding. `num_rows = expert_ends[-1]`
     num_rows: int
+    # The number of sparse rows, excluding padding. `num_rows_unpadded = num_rows_dense * num_experts_per_token`
     num_rows_unpadded: int
     num_experts: int
     num_experts_per_token: int
diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py
index 9a0864944..ae46655ea 100644
--- a/fast_llm/functional/triton/sparse_linear.py
+++ b/fast_llm/functional/triton/sparse_linear.py
@@ -1,10 +1,12 @@
+import os
+
 import torch

 from fast_llm.functional.triton import TritonConfig, tl, tl_constexpr, triton, triton_autotune, triton_jit
 from fast_llm.functional.triton.sparse_copy import SparseMap
 from fast_llm.utils import Assert, div

-autotune_configs = [
+autotune_configs = (
     TritonConfig(
         {"block_size_row": 128, "block_size_col": 256, "block_size_inner": 64, "group_size_row": 8},
         num_stages=3,
@@ -45,7 +47,10 @@
         num_stages=5,
         num_warps=2,
     ),
-]
+)
+
+if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"):
+    autotune_configs = (autotune_configs[2],)


 @triton_autotune(
@@ -255,13 +260,13 @@ def output_sparse_matmul_kernel(
 def output_sparse_matmul(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
-    sparse_map: SparseMap | None,
+    sparse_map: SparseMap | None = None,
     out: torch.Tensor | None = None,
     accumulate: bool = False,
 ) -> torch.Tensor:
     """
-    Output-sparse matrix multiplication with a sparse column dimension,
-    i.e., with a mapping row_index -> sparse_index (obtained from expert_ends).
+    Output-sparse matrix multiplication with a sparse column dimension
+    and a mapping row_index -> sparse_index (obtained from expert_ends).
     Ex.: MLP layer 1 forward (Y = X x W1^T), MLP layer 2 input grad (gY = gZ x W2).
     Formula: out[i, js] = sum_k(lhs[i, k] * rhs[k, jd]), where jd = js + col_sparse_dim * sparse_index[i]
         sparse_index[i] = sum(expert_ends <= i)
@@ -381,13 +386,13 @@ def input_inner_sparse_matmul_kernel(
 def input_inner_sparse_matmul(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
-    sparse_map: SparseMap | None,
+    sparse_map: SparseMap | None = None,
     out: torch.Tensor | None = None,
     accumulate: bool = False,
 ) -> torch.Tensor:
     """
-    Left-input-sparse matrix multiplication with a sparse inner dimension,
-    i.e., with a mapping row_index -> sparse_index (obtained from expert_ends).
+    Left-input-sparse matrix multiplication with a sparse inner dimension
+    and a mapping row_index -> sparse_index (obtained from expert_ends).
     Ex.: MLP layer 2 forward (Z = Y x W2^T), MLP layer 1 input grad (gX = gY x W1).
     Formula: out[i, j] = sum_ks(lhs[i, ks] * rhs[kd, j]), where kd = ks + inner_sparse_dim * sparse_index[i]
         sparse_index[i] = sum(expert_ends <= i)
@@ -511,13 +516,13 @@ def input_row_sparse_matmul_kernel(
 def input_row_sparse_matmul(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
-    sparse_map: SparseMap | None,
+    sparse_map: SparseMap | None = None,
     out: torch.Tensor | None = None,
     accumulate: bool = False,
 ) -> torch.Tensor:
     """
-    Left-input-sparse matrix multiplication with a sparse row dimension,
-    i.e., with a mapping inner_index -> sparse_index.
+    Left-input-sparse matrix multiplication with a sparse row dimension
+    and a mapping inner_index -> sparse_index.
     Ex.: MLP layer 1 weight grad (gW1 = gY^T x X), MLP layer 2 weight grad (gW2^T = Y^T x gZ).
     Formula: out[id, j] = sum_ks(lhs[is, ks] * rhs[ks, j]),
         where sparse_begin[sparse_index[id]] <= ks < sparse_end[sparse_index[id]],
diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py
index 526d66c01..d6a2f7e1a 100644
--- a/fast_llm/models/ssm/model.py
+++ b/fast_llm/models/ssm/model.py
@@ -3,14 +3,13 @@

 from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
 from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
 from fast_llm.layers.language_model.head import LanguageModelHead
 from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
 from fast_llm.layers.ssm.llamba_block import LlambaBlock
 from fast_llm.layers.ssm.mamba_layer import MambaLayer
 from fast_llm.layers.transformer.transformer import TransformerLayer
-from fast_llm.models.gpt.model import GPTBaseModel
+from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
 from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType

 logger = logging.getLogger(__name__)
@@ -135,7 +134,7 @@ def get_layers(self) -> list[Layer]:
         return layers


-class HybridSSMModel[ConfigType: HybridSSMModelConfig](FastLLMModel[ConfigType]):
+class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]):
     """
     A hybrid model that combines Transformer and SSM blocks.
     """
diff --git a/setup.cfg b/setup.cfg
index b3b1df036..b1e44e814 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -24,6 +24,8 @@ CORE =
     safetensors>=0.5.3
     # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
     flash-attn==2.7.3
+    # Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released.
+    triton==3.1.0


 # Small packages required for some optional features and tools.
@@ -57,7 +59,7 @@ DEV =
     pytest-xdist>=3.7.0
     # Somehow needed for Megatron to work with base image 24.11
     setuptools>=80.9.0
-    # dependency manager needs it.
+    # Dependency manager needs colorama to show colors.
     colorama>=0.4.6

 # Required for building the documentation
diff --git a/tests/conftest.py b/tests/conftest.py
index 0d25fc5aa..11757176e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -113,6 +113,9 @@ def pytest_configure(config):
             rendezvous_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1,
         )

+    # Skip slow autotune for tests. The default config has the highest block size, so this shouldn't hide any bug.
+    os.environ["FAST_LLM_SKIP_TRITON_AUTOTUNE"] = "TRUE"
+

 @pytest.hookimpl(trylast=True)
 def pytest_collection_modifyitems(config, items: list[pytest.Function]):
diff --git a/tests/functional/__init__.py b/tests/functional/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_functional.py b/tests/functional/test_functional.py
similarity index 98%
rename from tests/test_functional.py
rename to tests/functional/test_functional.py
index 9211259c2..3ddd5d4fe 100644
--- a/tests/test_functional.py
+++ b/tests/functional/test_functional.py
@@ -224,8 +224,6 @@ def test_mlp_recomputation(gated, activation_type):
 @pytest.mark.slow
 @requires_cuda
 def test_dropless_mlp():
-    # TODO: Fix dropless MOE
-    pytest.fail("Test fails, aborting to avoid breaking cuda", False)
     num_experts = 4
     experts_per_token = 4
     tokens = 256
@@ -273,7 +271,7 @@ def test_dropless_mlp():
     sparse_map = get_sparse_map(top_experts, num_experts)

     for i, recompute_level in enumerate(MLPRecomputeLevel):
-        print(recompute_level.value)  # noqa
+        print("recompute_level", recompute_level)  # noqa
         input_.grad = None
         scores.grad = None
         for param in params:
diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py
new file mode 100644
index 000000000..899dad967
--- /dev/null
+++ b/tests/functional/test_sparse_matmul.py
@@ -0,0 +1,154 @@
+import dataclasses
+import functools
+
+import pytest
+import torch
+
+from fast_llm.functional.triton.sparse_copy import SparseMap
+from fast_llm.functional.triton.sparse_linear import (
+    dense_matmul,
+    input_inner_sparse_matmul,
+    input_row_sparse_matmul,
+    output_sparse_matmul,
+)
+from fast_llm.utils import Assert
+from tests.utils.utils import requires_cuda
+
+
+@dataclasses.dataclass
+class _SparseTestData:
+    dense_dim: int
+    sparse_dim: int
+    expert_ends: tuple[int, ...]
+    tokens_per_expert: tuple[int, ...]
+    std: float = 0.125
+
+    @functools.cached_property
+    def expert_begins(self) -> tuple[int, ...]:
+        return (0,) + self.expert_ends[:-1]
+
+    @functools.cached_property
+    def expert_pad_begins(self) -> tuple[int, ...]:
+        return tuple(
+            expert_begin + expert_tokens
+            for expert_begin, expert_tokens in zip(self.expert_begins, self.tokens_per_expert, strict=True)
+        )
+
+    @functools.cached_property
+    def token_dim(self) -> int:
+        return self.expert_ends[-1]
+
+    @property
+    def sparse_dim_expanded(self) -> int:
+        return self.sparse_dim * self.num_experts
+
+    @functools.cached_property
+    def num_experts(self) -> int:
+        return len(self.expert_begins)
+
+    @functools.cached_property
+    def sparse_map(self) -> SparseMap:
+        return SparseMap(
+            num_experts=self.num_experts,
+            expert_ends=torch.tensor(self.expert_ends, device="cuda"),
+            expert_pad_begins=torch.tensor(self.expert_pad_begins, device="cuda"),
+            num_rows=self.expert_ends[-1],
+            # Not needed
+            sparse_rows=None,
+            num_rows_dense=None,
+            num_rows_unpadded=None,
+            num_experts_per_token=None,
+        )
+
+    def normal(self, dim_0: int, dim_1: int) -> torch.Tensor:
+        return torch.normal(0, self.std, (dim_0, dim_1), device="cuda")
+
+
+_SPARSE_TEST_DATAS = (
+    _SparseTestData(
+        dense_dim=384,
+        sparse_dim=256,
+        expert_ends=(128, 384, 512),
+        tokens_per_expert=(78, 256, 54),
+    ),
+    _SparseTestData(
+        dense_dim=256,
+        sparse_dim=512,
+        expert_ends=(128, 256, 256, 384),
+        tokens_per_expert=(52, 125, 0, 97),
+    ),
+)
+
+
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
+def test_dense_matmul(sparse_test_data):
+    lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim)
+    rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim)
+
+    output = dense_matmul(lhs, rhs)
+    output_ref = torch.matmul(lhs, rhs)
+    Assert.rms_close(output, output_ref, 1e-3)
+
+
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
+def test_output_sparse_matmul(sparse_test_data):
+    lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim)
+    rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded)
+
+    # Randomly initialize the output to ensure padded values have no effect.
+    out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim)
+    output = output_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map, out)
+
+    output_ref = torch.zeros_like(output)
+    for i in range(sparse_test_data.num_experts):
+        # Padded tokens are treated like regular ones.
+        output_ref[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]] = torch.matmul(
+            lhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]],
+            rhs[:, i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim],
+        )
+
+    Assert.rms_close(output, output_ref, 1e-3)
+
+
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
+def test_input_inner_sparse_matmul(sparse_test_data):
+    lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim)
+    rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim)
+
+    output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map)
+
+    output_ref = torch.zeros_like(output)
+    for i in range(sparse_test_data.num_experts):
+        # Padded tokens are treated like regular ones.
+        output_ref[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]] = torch.matmul(
+            lhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]],
+            rhs[i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim],
+        )
+
+    Assert.rms_close(output, output_ref, 1e-3)
+
+
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
+def test_input_row_sparse_matmul(sparse_test_data):
+    lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim)
+    rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim)
+
+    output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map)
+
+    output_ref = torch.zeros_like(output)
+    for i in range(sparse_test_data.num_experts):
+        # Padded tokens are excluded from the sum.
+        output_ref[i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim] = torch.matmul(
+            lhs[:, sparse_test_data.expert_begins[i] : sparse_test_data.expert_pad_begins[i]],
+            rhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_pad_begins[i]],
+        )
+
+    Assert.rms_close(output, output_ref, 1e-3)
diff --git a/tests/test_triton_kernels.py b/tests/functional/test_triton_kernels.py
similarity index 100%
rename from tests/test_triton_kernels.py
rename to tests/functional/test_triton_kernels.py
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 39fd0840e..aff7d991f 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -30,7 +30,7 @@ def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_confi
         + [
             "training.checkpoint.interval=1",
             "training.evaluators.validation.interval=2",
-            "training.evaluators.validation.evaluators.iterations=1",
+            "training.evaluators.validation.evaluator.iterations=1",
         ],
     )

@@ -63,7 +63,7 @@ def test_resume(run_test_script_for_all_models):
         [
             "training.checkpoint.interval=1",
             "training.evaluators.validation.interval=2",
-            "training.evaluators.validation.evaluators.iterations=1",
+            "training.evaluators.validation.evaluator.iterations=1",
         ],
         compare=f"test_checkpoint_and_eval",
         prepare_fn=_prepare_resume_fn,
@@ -79,7 +79,7 @@ def test_resume_frozen(run_test_script_for_all_models):
         [
             "training.checkpoint.interval=1",
             "training.evaluators.validation.interval=2",
-            "training.evaluators.validation.evaluators.iterations=1",
+            "training.evaluators.validation.evaluator.iterations=1",
             "model.base_model.transformer.mlp_lr_scale=0.",
         ],
         compare="test_checkpoint_and_eval",
@@ -442,7 +442,12 @@ def test_run_converted_model(model_testing_config, convert_paths):
     )
     errors = []
     compare = CompareConfig()
-    model_as_hf = transformers.AutoModel.from_pretrained(
+    auto_model = (
+        transformers.AutoModel
+        if model_testing_config.name in ("diffusion_llama", "dream")
+        else transformers.AutoModelForCausalLM
+    )
+    model_as_hf = auto_model.from_pretrained(
         convert_paths["huggingface_0"], trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code
     ).cuda()
     for name, model in zip(