From e1189e9c1da8610e7fa77423d5bc135cd052ba69 Mon Sep 17 00:00:00 2001 From: darisoy Date: Tue, 2 Jun 2026 17:45:42 +0000 Subject: [PATCH] refactor(moe): Remove tokamax_gmm_autotune and unconditionally use custom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95 --- src/maxtext/layers/moe.py | 82 ++++++++++++++- src/maxtext/models/deepseek_batchsplit_fp8.py | 17 +++- tests/unit/moe_test.py | 99 +++++++++++++++++++ 3 files changed, 194 insertions(+), 4 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7be9db2290..25be00cf14 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -48,6 +48,57 @@ from qwix.contrib.sparsity import sparsity_module import qwix.pallas as qpl import tokamax +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import ( + PallasMosaicTpuRaggedDot, + Config, + DEFAULT_RAGGED_DOT_DIM_NUMS, + DLHS_RAGGED_DOT_DIM_NUMS, + DRHS_RAGGED_DOT_DIM_NUMS, +) +from tokamax._src.ops.ragged_dot import base +import dataclasses + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class PallasMosaicTpuRaggedDotCustom(PallasMosaicTpuRaggedDot): + """A custom PallasMosaicTpuRaggedDot subclass that overrides _get_heuristics_config.""" + + config: Config | None = None + fwd_tile: tuple[int, int, int] = (128, 128, 128) + dlhs_tile: tuple[int, int, int] = (128, 128, 128) + 
drhs_tile: tuple[int, int, int] = (128, 128, 128) + + def __post_init__(self): + qdtype = self.qdtype if self.qdtype is None else jnp.dtype(self.qdtype).name + if self.vjp is None: + + def fn(*args, **kw): + # pylint: disable=unexpected-keyword-arg + return PallasMosaicTpuRaggedDotCustom( + qdtype=qdtype, + interpret=self.interpret, + fwd_tile=self.fwd_tile, + dlhs_tile=self.dlhs_tile, + drhs_tile=self.drhs_tile, + )(*args, **kw) + + object.__setattr__( + self, + "vjp", + functools.partial(base.vjp, dlhs_ragged_dot=fn, drhs_ragged_dot=fn), + ) + + def _get_heuristics_config(self, ba) -> Config: + dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS) + if dims == DEFAULT_RAGGED_DOT_DIM_NUMS: + return Config(tile_m=self.fwd_tile[0], tile_k=self.fwd_tile[1], tile_n=self.fwd_tile[2]) + elif dims == DLHS_RAGGED_DOT_DIM_NUMS: + return Config(tile_m=self.dlhs_tile[0], tile_k=self.dlhs_tile[1], tile_n=self.dlhs_tile[2]) + elif dims == DRHS_RAGGED_DOT_DIM_NUMS: + return Config(tile_m=self.drhs_tile[0], tile_k=self.drhs_tile[1], tile_n=self.drhs_tile[2]) + return Config() + set_xla_metadata = xla_metadata.set_xla_metadata @@ -1084,6 +1135,18 @@ def sparse_matmul( wo_bias, ): """Perform sparse matrix multiplication of inputs and Experts.""" + config = self.config + + gmm_impl_wi = PallasMosaicTpuRaggedDotCustom( + fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim), + dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim), + drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim), + ) + gmm_impl_wo = PallasMosaicTpuRaggedDotCustom( + fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim), + dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim), + drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, 
config.wo_tile_drhs_embed_dim), + ) def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount): """Execute jax.lax.ragged_dot, with potential quantization""" @@ -1128,6 +1191,15 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, output *= scales return output + def get_gmm_group_sizes(inputs, kernel, ep): + # Calculates perfectly balanced group sizes where each local expert receives an equal + # share of local tokens, adjusted for expert parallelism. + # + # Note: This function assumes the inputs are ragged and padded to the worst-case size + # (which is generally a factor of EP larger than perfectly balanced). This is why we must + # divide by EP. + return (inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0] + def get_tokamax_group_sizes(group_sizes, inputs, kernel): # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm if self.config.use_qwix_quantization or ( @@ -1137,9 +1209,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel): elif self.config.attention == "vllm_rpa": return group_sizes else: + ep = self.get_expert_parallelism_size() return tokamax.RaggedDotGroupSizes( group_sizes, - (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0], + get_gmm_group_sizes(inputs, kernel, ep), ) def get_quantization_dtypes(): @@ -1150,7 +1223,7 @@ def get_quantization_dtypes(): rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype() return lhs_quantize_dtype, rhs_quantize_dtype - def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes): + def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, gmm_impl=None): if inputs.shape[0] != expert_assignments.shape[0]: raise ValueError("The number of input tokens must match the number of expert assignments!") @@ -1184,7 +1257,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a group_sizes=tokamax_group_sizes, 
precision=jax.lax.Precision.DEFAULT, preferred_element_type=self.dtype, - implementation="mosaic", + implementation="mosaic" if gmm_impl is None else [gmm_impl], ) elif self.config.megablox: # Older forked megablox output = mblx.gmm( @@ -1473,6 +1546,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, + gmm_impl=gmm_impl_wi, ) if self.get_tensor_transpose_parallelism_size() > 1: layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") @@ -1485,6 +1559,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, + gmm_impl=gmm_impl_wi, ) if self.get_tensor_transpose_parallelism_size() > 1: layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") @@ -1498,6 +1573,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes, + gmm_impl=gmm_impl_wo, ) if self.get_tensor_parallelism_size() > 1: intermediate_output = jax.lax.psum_scatter( diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py index cef7c0646f..990cfcc9ef 100644 --- a/src/maxtext/models/deepseek_batchsplit_fp8.py +++ b/src/maxtext/models/deepseek_batchsplit_fp8.py @@ -940,6 +940,16 @@ def unroute( def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh): """Processes routed tokens through the MLP.""" + gmm_impl_wi = moe_lib.PallasMosaicTpuRaggedDotCustom( + fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim), + dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim), + drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim), + ) + gmm_impl_wo = moe_lib.PallasMosaicTpuRaggedDotCustom( + fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim), + 
dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim), + drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, config.wo_tile_drhs_embed_dim), + ) def gmm( inputs, @@ -948,6 +958,7 @@ def gmm( group_sizes, preferred_element_type, weight_gather_axes, + gmm_impl=None, ): if config.use_qwix_quantization: output = megablox.gmm( @@ -968,7 +979,7 @@ def gmm( group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), precision=jax.lax.Precision.DEFAULT, preferred_element_type=preferred_element_type, - implementation="mosaic", + implementation="mosaic" if gmm_impl is None else [gmm_impl], ) return output @@ -1028,6 +1039,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): w01, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, + gmm_impl=gmm_impl_wi, ) layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1) else: @@ -1036,12 +1048,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, + gmm_impl=gmm_impl_wi, ) layer_w1 = gmm_fn( x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, + gmm_impl=gmm_impl_wi, ) layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0") layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1") @@ -1052,6 +1066,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes, + gmm_impl=gmm_impl_wo, ) return layer_wo diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 93547f7d2a..56fbdc9e90 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -32,6 +32,12 @@ from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path import pytest +from tokamax._src.ops import op +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import ( + DEFAULT_RAGGED_DOT_DIM_NUMS, + DLHS_RAGGED_DOT_DIM_NUMS, + DRHS_RAGGED_DOT_DIM_NUMS, +) class 
TokenDroppingTest(unittest.TestCase): @@ -1521,5 +1527,98 @@ def test_prefused_vs_sparse_softmax(self): self.assertIsNone(bias_updates) +class TokamaxCustomTilingTest(unittest.TestCase): + """Tests that the WI and WO instances of PallasMosaicTpuRaggedDotCustom apply manual tiling configs.""" + + def setUp(self): + super().setUp() + self.cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="custom_tiling_test", + enable_checkpointing=False, + model_name="deepseek3-tiny", + dtype="bfloat16", + base_emb_dim=256, + base_mlp_dim=512, + wi_tile_fwd_batch_seq=128, + wi_tile_fwd_embed_dim=128, + wi_tile_fwd_mlp_dim=128, + wi_tile_dlhs_batch_seq=256, + wi_tile_dlhs_embed_dim=256, + wi_tile_dlhs_mlp_dim=256, + wi_tile_drhs_batch_seq=512, + wi_tile_drhs_embed_dim=512, + wi_tile_drhs_mlp_dim=512, + wo_tile_fwd_batch_seq=11, + wo_tile_fwd_mlp_dim=22, + wo_tile_fwd_embed_dim=33, + wo_tile_dlhs_batch_seq=44, + wo_tile_dlhs_embed_dim=55, + wo_tile_dlhs_mlp_dim=66, + wo_tile_drhs_batch_seq=77, + wo_tile_drhs_mlp_dim=88, + wo_tile_drhs_embed_dim=99, + override_model_config=True, + ) + + def test_custom_heuristics_coverage(self): + """Directly executes the FWD, DLHS and DRHS branches of _get_heuristics_config to verify the configured tiles.""" + config = self.cfg + + op_wi = moe.PallasMosaicTpuRaggedDotCustom( + fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim), + dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim), + drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim), + ) + op_wo = moe.PallasMosaicTpuRaggedDotCustom( + fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim), + dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim), + drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, config.wo_tile_drhs_embed_dim), + ) + + def run_heuristics(op_instance, dims): 
+ ba = op.BoundArguments( + op=op_instance, + arguments={ + "ragged_dot_dimension_numbers": dims, + }, + ) + # pylint: disable=protected-access + return op_instance._get_heuristics_config(ba) + + # 1. FWD: + wi_fwd_config = run_heuristics(op_wi, DEFAULT_RAGGED_DOT_DIM_NUMS) + self.assertEqual(wi_fwd_config.tile_m, 128) + self.assertEqual(wi_fwd_config.tile_k, 128) + self.assertEqual(wi_fwd_config.tile_n, 128) + + wo_fwd_config = run_heuristics(op_wo, DEFAULT_RAGGED_DOT_DIM_NUMS) + self.assertEqual(wo_fwd_config.tile_m, 11) + self.assertEqual(wo_fwd_config.tile_k, 22) + self.assertEqual(wo_fwd_config.tile_n, 33) + + # 2. DLHS: + wi_dlhs_config = run_heuristics(op_wi, DLHS_RAGGED_DOT_DIM_NUMS) + self.assertEqual(wi_dlhs_config.tile_m, 256) + self.assertEqual(wi_dlhs_config.tile_k, 256) + self.assertEqual(wi_dlhs_config.tile_n, 256) + + wo_dlhs_config = run_heuristics(op_wo, DLHS_RAGGED_DOT_DIM_NUMS) + self.assertEqual(wo_dlhs_config.tile_m, 44) + self.assertEqual(wo_dlhs_config.tile_k, 55) + self.assertEqual(wo_dlhs_config.tile_n, 66) + + # 3. DRHS: + wi_drhs_config = run_heuristics(op_wi, DRHS_RAGGED_DOT_DIM_NUMS) + self.assertEqual(wi_drhs_config.tile_m, 512) + self.assertEqual(wi_drhs_config.tile_k, 512) + self.assertEqual(wi_drhs_config.tile_n, 512) + + wo_drhs_config = run_heuristics(op_wo, DRHS_RAGGED_DOT_DIM_NUMS) + self.assertEqual(wo_drhs_config.tile_m, 77) + self.assertEqual(wo_drhs_config.tile_k, 88) + self.assertEqual(wo_drhs_config.tile_n, 99) + + if __name__ == "__main__": unittest.main()