Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,57 @@
from qwix.contrib.sparsity import sparsity_module
import qwix.pallas as qpl
import tokamax
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import (
PallasMosaicTpuRaggedDot,
Config,
DEFAULT_RAGGED_DOT_DIM_NUMS,
DLHS_RAGGED_DOT_DIM_NUMS,
DRHS_RAGGED_DOT_DIM_NUMS,
)
from tokamax._src.ops.ragged_dot import base
import dataclasses


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class PallasMosaicTpuRaggedDotCustom(PallasMosaicTpuRaggedDot):
  """PallasMosaicTpuRaggedDot variant with explicit per-pass tiling.

  Overrides `_get_heuristics_config` so the forward, dLHS and dRHS ragged dots
  each receive a `Config` built from the matching `*_tile` field.
  """

  config: Config | None = None
  # Each tile is consumed as (tile_m, tile_k, tile_n) when a Config is built.
  fwd_tile: tuple[int, int, int] = (128, 128, 128)
  dlhs_tile: tuple[int, int, int] = (128, 128, 128)
  drhs_tile: tuple[int, int, int] = (128, 128, 128)

  def __post_init__(self):
    """Install a default `vjp` that reuses this op's tiles when none is set."""
    # The backward ops receive the quantization dtype by name (None stays None).
    qdtype = None if self.qdtype is None else jnp.dtype(self.qdtype).name
    if self.vjp is not None:
      return

    def backward_ragged_dot(*args, **kw):
      # Build a sibling op carrying the same tiles, qdtype and interpret flag,
      # then apply it to the arguments handed over by the vjp machinery.
      # pylint: disable=unexpected-keyword-arg
      sibling = PallasMosaicTpuRaggedDotCustom(
          qdtype=qdtype,
          interpret=self.interpret,
          fwd_tile=self.fwd_tile,
          dlhs_tile=self.dlhs_tile,
          drhs_tile=self.drhs_tile,
      )
      return sibling(*args, **kw)

    default_vjp = functools.partial(
        base.vjp,
        dlhs_ragged_dot=backward_ragged_dot,
        drhs_ragged_dot=backward_ragged_dot,
    )
    # The dataclass is frozen, so the attribute has to be set via object.__setattr__.
    object.__setattr__(self, "vjp", default_vjp)

  def _get_heuristics_config(self, ba) -> Config:
    """Return the tiling `Config` for the dimension numbers bound in `ba`.

    Args:
      ba: Bound arguments; `ba.arguments` may hold
        "ragged_dot_dimension_numbers" (the default dimension numbers are
        assumed when the key is absent).

    Returns:
      A `Config` built from `fwd_tile`, `dlhs_tile` or `drhs_tile` when the
      dimension numbers equal the default, dLHS or dRHS constants
      respectively; otherwise a default-constructed `Config`.
    """
    dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
    tile_by_dims = (
        (DEFAULT_RAGGED_DOT_DIM_NUMS, self.fwd_tile),
        (DLHS_RAGGED_DOT_DIM_NUMS, self.dlhs_tile),
        (DRHS_RAGGED_DOT_DIM_NUMS, self.drhs_tile),
    )
    for known_dims, tile in tile_by_dims:
      if dims == known_dims:
        return Config(tile_m=tile[0], tile_k=tile[1], tile_n=tile[2])
    return Config()


set_xla_metadata = xla_metadata.set_xla_metadata

Expand Down Expand Up @@ -1084,6 +1135,18 @@ def sparse_matmul(
wo_bias,
):
"""Perform sparse matrix multiplication of inputs and Experts."""
config = self.config

gmm_impl_wi = PallasMosaicTpuRaggedDotCustom(
fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim),
dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim),
drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim),
)
gmm_impl_wo = PallasMosaicTpuRaggedDotCustom(
fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim),
dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim),
drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, config.wo_tile_drhs_embed_dim),
)

def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount):
"""Execute jax.lax.ragged_dot, with potential quantization"""
Expand Down Expand Up @@ -1128,6 +1191,15 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments,
output *= scales
return output

def get_gmm_group_sizes(inputs, kernel, ep):
  """Return perfectly balanced group sizes, one entry per local expert.

  Every local expert (leading dimension of `kernel`) is given an equal share
  of the local tokens (leading dimension of `inputs`), adjusted for expert
  parallelism.

  Note: the inputs are assumed to be ragged and padded to the worst-case size,
  which is generally a factor of `ep` larger than a perfectly balanced layout;
  that is why the share is additionally divided by `ep`.

  Args:
    inputs: Array-like whose `shape[0]` is the (padded) local token count.
    kernel: Array-like whose `shape[0]` is the number of local experts.
    ep: Expert parallelism size.

  Returns:
    A tuple of length `kernel.shape[0]` with the same size repeated.
  """
  num_local_experts = kernel.shape[0]
  tokens_per_expert = inputs.shape[0] // num_local_experts // ep
  return (tokens_per_expert,) * num_local_experts

def get_tokamax_group_sizes(group_sizes, inputs, kernel):
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
if self.config.use_qwix_quantization or (
Expand All @@ -1137,9 +1209,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
elif self.config.attention == "vllm_rpa":
return group_sizes
else:
ep = self.get_expert_parallelism_size()
return tokamax.RaggedDotGroupSizes(
group_sizes,
(inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
get_gmm_group_sizes(inputs, kernel, ep),
)

def get_quantization_dtypes():
Expand All @@ -1150,7 +1223,7 @@ def get_quantization_dtypes():
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
return lhs_quantize_dtype, rhs_quantize_dtype

def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, gmm_impl=None):
if inputs.shape[0] != expert_assignments.shape[0]:
raise ValueError("The number of input tokens must match the number of expert assignments!")

Expand Down Expand Up @@ -1184,7 +1257,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
implementation="mosaic" if gmm_impl is None else [gmm_impl],
)
elif self.config.megablox: # Older forked megablox
output = mblx.gmm(
Expand Down Expand Up @@ -1473,6 +1546,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
gmm_impl=gmm_impl_wi,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
Expand All @@ -1485,6 +1559,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
gmm_impl=gmm_impl_wi,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
Expand All @@ -1498,6 +1573,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
wo,
tiling=wo_tile_size,
weight_gather_axes=wo_gather_axes,
gmm_impl=gmm_impl_wo,
)
if self.get_tensor_parallelism_size() > 1:
intermediate_output = jax.lax.psum_scatter(
Expand Down
17 changes: 16 additions & 1 deletion src/maxtext/models/deepseek_batchsplit_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,16 @@ def unroute(

def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh):
"""Processes routed tokens through the MLP."""
gmm_impl_wi = moe_lib.PallasMosaicTpuRaggedDotCustom(
fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim),
dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim),
drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim),
)
gmm_impl_wo = moe_lib.PallasMosaicTpuRaggedDotCustom(
fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim),
dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim),
drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, config.wo_tile_drhs_embed_dim),
)

def gmm(
inputs,
Expand All @@ -948,6 +958,7 @@ def gmm(
group_sizes,
preferred_element_type,
weight_gather_axes,
gmm_impl=None,
):
if config.use_qwix_quantization:
output = megablox.gmm(
Expand All @@ -968,7 +979,7 @@ def gmm(
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation="mosaic",
implementation="mosaic" if gmm_impl is None else [gmm_impl],
)
return output

Expand Down Expand Up @@ -1028,6 +1039,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w01,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
gmm_impl=gmm_impl_wi,
)
layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
else:
Expand All @@ -1036,12 +1048,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
gmm_impl=gmm_impl_wi,
)
layer_w1 = gmm_fn(
x,
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
gmm_impl=gmm_impl_wi,
)
layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
Expand All @@ -1052,6 +1066,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
wo,
tiling=wo_tile_size,
weight_gather_axes=wo_gather_axes,
gmm_impl=gmm_impl_wo,
)
return layer_wo

Expand Down
99 changes: 99 additions & 0 deletions tests/unit/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
from maxtext.utils import maxtext_utils
from tests.utils.test_helpers import get_test_config_path
import pytest
from tokamax._src.ops import op
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import (
DEFAULT_RAGGED_DOT_DIM_NUMS,
DLHS_RAGGED_DOT_DIM_NUMS,
DRHS_RAGGED_DOT_DIM_NUMS,
)


class TokenDroppingTest(unittest.TestCase):
Expand Down Expand Up @@ -1521,5 +1527,98 @@ def test_prefused_vs_sparse_softmax(self):
self.assertIsNone(bias_updates)


class TokamaxCustomTilingTest(unittest.TestCase):
  """Checks that the WI and WO custom Pallas ops turn manual tile settings into tiling configs."""

  def setUp(self):
    super().setUp()
    # Manual tile sizes for the WI and WO matmuls across fwd / dlhs / drhs passes.
    tile_overrides = {
        "wi_tile_fwd_batch_seq": 128,
        "wi_tile_fwd_embed_dim": 128,
        "wi_tile_fwd_mlp_dim": 128,
        "wi_tile_dlhs_batch_seq": 256,
        "wi_tile_dlhs_embed_dim": 256,
        "wi_tile_dlhs_mlp_dim": 256,
        "wi_tile_drhs_batch_seq": 512,
        "wi_tile_drhs_embed_dim": 512,
        "wi_tile_drhs_mlp_dim": 512,
        "wo_tile_fwd_batch_seq": 11,
        "wo_tile_fwd_mlp_dim": 22,
        "wo_tile_fwd_embed_dim": 33,
        "wo_tile_dlhs_batch_seq": 44,
        "wo_tile_dlhs_embed_dim": 55,
        "wo_tile_dlhs_mlp_dim": 66,
        "wo_tile_drhs_batch_seq": 77,
        "wo_tile_drhs_mlp_dim": 88,
        "wo_tile_drhs_embed_dim": 99,
    }
    self.cfg = pyconfig.initialize(
        [None, get_test_config_path()],
        run_name="custom_tiling_test",
        enable_checkpointing=False,
        model_name="deepseek3-tiny",
        dtype="bfloat16",
        base_emb_dim=256,
        base_mlp_dim=512,
        **tile_overrides,
        override_model_config=True,
    )

  def test_custom_heuristics_coverage(self):
    """Exercises the fwd, dlhs and drhs branches of _get_heuristics_config for both ops."""
    cfg = self.cfg

    wi_op = moe.PallasMosaicTpuRaggedDotCustom(
        fwd_tile=(cfg.wi_tile_fwd_batch_seq, cfg.wi_tile_fwd_embed_dim, cfg.wi_tile_fwd_mlp_dim),
        dlhs_tile=(cfg.wi_tile_dlhs_batch_seq, cfg.wi_tile_dlhs_mlp_dim, cfg.wi_tile_dlhs_embed_dim),
        drhs_tile=(cfg.wi_tile_drhs_batch_seq, cfg.wi_tile_drhs_embed_dim, cfg.wi_tile_drhs_mlp_dim),
    )
    wo_op = moe.PallasMosaicTpuRaggedDotCustom(
        fwd_tile=(cfg.wo_tile_fwd_batch_seq, cfg.wo_tile_fwd_mlp_dim, cfg.wo_tile_fwd_embed_dim),
        dlhs_tile=(cfg.wo_tile_dlhs_batch_seq, cfg.wo_tile_dlhs_embed_dim, cfg.wo_tile_dlhs_mlp_dim),
        drhs_tile=(cfg.wo_tile_drhs_batch_seq, cfg.wo_tile_drhs_mlp_dim, cfg.wo_tile_drhs_embed_dim),
    )

    def heuristics_for(op_instance, dims):
      # Bind only the dimension numbers; that is the sole argument the override reads.
      bound = op.BoundArguments(
          op=op_instance,
          arguments={
              "ragged_dot_dimension_numbers": dims,
          },
      )
      # pylint: disable=protected-access
      return op_instance._get_heuristics_config(bound)

    # (op under test, dimension numbers, expected (tile_m, tile_k, tile_n)).
    expectations = (
        # 1. FWD:
        (wi_op, DEFAULT_RAGGED_DOT_DIM_NUMS, (128, 128, 128)),
        (wo_op, DEFAULT_RAGGED_DOT_DIM_NUMS, (11, 22, 33)),
        # 2. DLHS:
        (wi_op, DLHS_RAGGED_DOT_DIM_NUMS, (256, 256, 256)),
        (wo_op, DLHS_RAGGED_DOT_DIM_NUMS, (44, 55, 66)),
        # 3. DRHS:
        (wi_op, DRHS_RAGGED_DOT_DIM_NUMS, (512, 512, 512)),
        (wo_op, DRHS_RAGGED_DOT_DIM_NUMS, (77, 88, 99)),
    )
    for op_instance, dims, (want_m, want_k, want_n) in expectations:
      tiling = heuristics_for(op_instance, dims)
      self.assertEqual(tiling.tile_m, want_m)
      self.assertEqual(tiling.tile_k, want_k)
      self.assertEqual(tiling.tile_n, want_n)


# Run the unit tests in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()
Loading