From b8cf44039731d22d81e23ab0e95ab260056ad9d1 Mon Sep 17 00:00:00 2001
From: Shuwen Fang
Date: Tue, 7 Apr 2026 22:48:50 +0000
Subject: [PATCH] update

Update remove dense matmul changes

update formatting

update fixes
---
 src/maxtext/layers/attention_op.py | 37 ++++++++++++------------------
 src/maxtext/layers/moe.py          | 12 +++++++++-
 src/maxtext/utils/sharding.py      | 13 +++++++++++
 3 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py
index d265459e69..0fea057876 100644
--- a/src/maxtext/layers/attention_op.py
+++ b/src/maxtext/layers/attention_op.py
@@ -32,7 +32,7 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 import jax.numpy as jnp
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
 from maxtext.common.common_types import (
     Array,
     AttentionType,
@@ -78,7 +78,7 @@
 from maxtext.layers.initializers import variable_to_logically_partitioned
 from maxtext.layers.quantizations import AqtQuantization as Quant
 from maxtext.utils import max_utils
-from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name
+from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
 import numpy as np
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask
@@ -1484,26 +1484,19 @@ def kernel_fn(q, k, v, d, s):

       return attention_output, None

-    def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
-      # decoder_segment_ids can be None
-      if pspec is None:
-        return None
-      sharding = NamedSharding(self.mesh, pspec)
-      return maybe_shard_with_name(
-          inputs,
-          sharding,
-          shard_mode=self.config.shard_mode,
-          debug_sharding=self.config.debug_sharding,
-          extra_stack_level=1,
-      )
-
-    query = _maybe_shard_with_pspec(query, axis_names_q)
-    key = _maybe_shard_with_pspec(key, axis_names_kv)
-    value = _maybe_shard_with_pspec(value, axis_names_kv)
-    decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
-    decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
-    sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
-    indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)
+    query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding)
+    key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
+    value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
+    decoder_segment_ids_q = maybe_shard_with_pspec(
+        decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding
+    )
+    decoder_segment_ids_kv = maybe_shard_with_pspec(
+        decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding
+    )
+    sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding)
+    indexer_mask = maybe_shard_with_pspec(
+        indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding
+    )

     ret = wrap_flash_attention(
         query,
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 314c450b03..5669ac5fae 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -36,7 +36,7 @@
 from maxtext.kernels import megablox as mblx
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
-from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding
+from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_pspec
 from maxtext.utils.sharding import logical_to_mesh_axes
 import numpy as np
 import qwix.pallas as qpl
@@ -1439,6 +1439,16 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes)
     pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes)

+    w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec)
+    w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec)
+    wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec)
+    if w0_bias is not None:
+      w0_bias = maybe_shard_with_pspec(w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec)
+    if w1_bias is not None:
+      w1_bias = maybe_shard_with_pspec(w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec)
+    if wo_bias is not None:
+      wo_bias = maybe_shard_with_pspec(wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec)
+
     return wrapper(
         inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs
     )
diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py
index 74b22548b0..5b8468c749 100644
--- a/src/maxtext/utils/sharding.py
+++ b/src/maxtext/utils/sharding.py
@@ -115,6 +115,19 @@ def maybe_shard_with_name(
   return jax.lax.with_sharding_constraint(inputs, named_sharding)


+def maybe_shard_with_pspec(inputs, mesh, shard_mode, pspec: jax.sharding.PartitionSpec | None, debug_sharding=False):
+  if pspec is None:
+    return None
+  sharding = NamedSharding(mesh, pspec)
+  return maybe_shard_with_name(
+      inputs,
+      sharding,
+      shard_mode=shard_mode,
+      debug_sharding=debug_sharding,
+      extra_stack_level=1,
+  )
+
+
 def maybe_shard_with_logical(
     inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
 ):
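
Usage sketch (editorial note appended after the patch, not applied by git): the new
maybe_shard_with_pspec in src/maxtext/utils/sharding.py builds a NamedSharding from a mesh
plus PartitionSpec and forwards it to maybe_shard_with_name, passing None through when no
pspec is given (e.g. decoder_segment_ids). The snippet below is a minimal, self-contained
approximation in plain JAX; the mesh layout, axis names, shapes, and the helper name
shard_with_pspec_sketch are assumptions, and MaxText's shard_mode / debug_sharding handling
is deliberately omitted.

# Illustrative sketch only; approximates the behavior of the new helper.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_with_pspec_sketch(inputs, mesh, pspec):
  # Mirrors the None passthrough in the patch: optional inputs may carry no
  # pspec, in which case no sharding constraint is applied.
  if pspec is None:
    return None
  return jax.lax.with_sharding_constraint(inputs, NamedSharding(mesh, pspec))


# Hypothetical 2-D mesh: all local devices on a "data" axis, a trivial "model" axis.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))
x = jnp.ones((jax.device_count() * 2, 128))


@jax.jit
def constrained_double(x):
  # Ask the compiler to keep the intermediate sharded over "data" on dim 0.
  x = shard_with_pspec_sketch(x, mesh, PartitionSpec("data", None))
  return x * 2.0


y = constrained_double(x)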