From 3dc998871146c3866ab5a0acd2ab779159e3c501 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 3 Dec 2025 20:31:01 +0000 Subject: [PATCH] deepseek explicit split --- src/MaxText/common_types.py | 2 + src/MaxText/configs/types.py | 2 +- src/MaxText/layers/attention_mla.py | 90 +++++++++++------ src/MaxText/layers/attentions.py | 33 +++--- src/MaxText/layers/deepseek.py | 52 +++++++--- src/MaxText/layers/embeddings.py | 43 +++++++- src/MaxText/layers/moe.py | 115 +++++++++++++++------ src/MaxText/layers/normalizations.py | 7 +- src/MaxText/sharding.py | 4 + tests/attention_test.py | 145 +++++++++++++++++++++++++-- tests/check_gemma3_layers.py | 4 + tests/check_gpt_vs_reference.py | 3 + tests/llama_test.py | 28 ++++-- tests/moe_test.py | 2 + 14 files changed, 421 insertions(+), 109 deletions(-) diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py index f26d02cb1..8f7bb2101 100644 --- a/src/MaxText/common_types.py +++ b/src/MaxText/common_types.py @@ -37,7 +37,9 @@ PREFILL_LENGTH = "prefill_activation_length" Q_LENGTH = "activation_q_length" Q_LENGTH_NO_EXP = "activation_q_length_no_exp" +Q_LORA_UP_PROJ = "q_lora_up_proj" KV_LENGTH = "activation_kv_length" +KV_LORA_UP_PROJ = "kv_lora_up_proj" EMBED = "activation_embed" HEAD = "activation_heads" PREFILL_KV_BATCH = "activation_prefill_kv_batch" diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 86eb77919..724f1db8f 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -1920,7 +1920,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if self.packing: raise ValueError("For multimodal SFT, `packing` is not yet supported.") if self.shard_mode == ShardMode.EXPLICIT: - supported_decoders = {"simple", "simple_mlp", "llama2"} + supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"} if self.decoder_block.value not in supported_decoders: raise ValueError( f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. " diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py index 0baa01272..b34c20c9d 100644 --- a/src/MaxText/layers/attention_mla.py +++ b/src/MaxText/layers/attention_mla.py @@ -38,10 +38,12 @@ EMBED, EP_AS_CONTEXT, HEAD, + Q_LORA_UP_PROJ, KV_BATCH, KV_BATCH_NO_EXP, KV_HEAD, KV_HEAD_DIM, + KV_LORA_UP_PROJ, LENGTH, LENGTH_NO_EXP, MODEL_MODE_PREFILL, @@ -389,6 +391,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No weight_dtype=self.weight_dtype, quant=self.quant, matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) else: @@ -403,6 +406,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No weight_dtype=self.weight_dtype, quant=self.quant, matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) self.q_norm = RMSNorm( @@ -423,6 +427,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No weight_dtype=self.weight_dtype, quant=self.quant, matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) @@ -437,6 +442,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No weight_dtype=self.weight_dtype, quant=self.quant, matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) self.kv_norm = RMSNorm( @@ -460,6 +466,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No weight_dtype=self.weight_dtype, quant=self.quant, matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) @@ -498,6 +505,18 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array: """Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0.""" + # specify query logical name + if model_mode == MODEL_MODE_PREFILL: + query_logical_name = self.prefill_query_axis_names + wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ) + elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: + query_logical_name = self.ep_query_axis_names + wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ) + else: + query_logical_name = self.query_axis_names + wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ) + query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(query_logical_name)) + wqa_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wqa_logical_name)) # Set softmax scaling. self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim self.softmax_scale = self.qk_head_dim**-0.5 @@ -506,47 +525,49 @@ def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_m self.softmax_scale = self.softmax_scale * mscale * mscale if self.q_lora_rank == 0: - q = self.query(inputs_q) + q = self.query(inputs_q, out_sharding=query_sharding) else: # LoRA path - low_rank_q = self.wq_a(inputs_q) # [B, L, q_lora_rank] + low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank] low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank - q = self.wq_b(low_rank_q) # [B, L, n_heads * qk_head_dim] + q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads * qk_head_dim] # Split into non-positional and rotary parts. q_nope, q_pe = jnp.split(q, [self.qk_nope_head_dim], axis=-1) + q_nope = self._maybe_shard_with_logical(q_nope, query_logical_name) q_pe = self.apply_rotary_embedding(q_pe, inputs_positions=inputs_positions) + q_pe = self._maybe_shard_with_logical(q_pe, query_logical_name) # Query projection is scaled by self.softmax_scale to be consistent MaxText implementation. # DeepSeek v3 was doing it in attention score computation. query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale + query = self._maybe_shard_with_logical(query, query_logical_name) + return query + def mla_get_key_value(self, low_rank_main, key_rope, model_mode): + """get (key,value) pair from mla""" if model_mode == MODEL_MODE_PREFILL: - query = nn.with_logical_constraint(query, self.prefill_query_axis_names) + key_logical_name = self.prefill_key_axis_names + value_logical_name = self.prefill_value_axis_names elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - query = nn.with_logical_constraint(query, self.ep_query_axis_names) + key_logical_name = self.ep_key_axis_names + value_logical_name = self.ep_value_axis_names else: - query = nn.with_logical_constraint(query, self.query_axis_names) - return query + key_logical_name = self.key_axis_names + value_logical_name = self.value_axis_names - def mla_get_key_value(self, low_rank_main, key_rope, model_mode): - """get (key,value) pair from mla""" - kv_out = self.wkv_b(low_rank_main) + wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(key_logical_name)) + kv_out = self.wkv_b(low_rank_main, out_sharding=wkva_out_sharding) # Split kv_out into key_nope and value parts. key_nope, value = jnp.split(kv_out, [self.qk_nope_head_dim], axis=-1) key_rope = jnp.broadcast_to(key_rope, (key_nope.shape[0], key_nope.shape[1], self.num_query_heads, key_rope.shape[3])) + key_nope = self._maybe_shard_with_logical(key_nope, key_logical_name) + key_rope = self._maybe_shard_with_logical(key_rope, key_logical_name) key = jnp.concatenate([key_nope, key_rope], axis=-1) - if model_mode == MODEL_MODE_PREFILL: - key = nn.with_logical_constraint(key, self.prefill_key_axis_names) - value = nn.with_logical_constraint(value, self.prefill_value_axis_names) - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - key = nn.with_logical_constraint(key, self.ep_key_axis_names) - value = nn.with_logical_constraint(value, self.ep_value_axis_names) - else: - key = nn.with_logical_constraint(key, self.key_axis_names) - value = nn.with_logical_constraint(value, self.value_axis_names) + key = self._maybe_shard_with_logical(key, key_logical_name) + value = self._maybe_shard_with_logical(value, value_logical_name) return key, value def init_mla_kv_caches(self, inputs_kv_shape: Tuple): @@ -637,7 +658,14 @@ def update_mla_kv_caches(self, low_rank_main, key_rope, decoder_segment_ids, mod def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segment_ids, model_mode, previous_chunk): """MLA key/value projection with integrated rotary embedding.""" - low_rank = self.wkv_a(inputs) + if model_mode == MODEL_MODE_PREFILL: + wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ) + elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: + wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ) + else: + wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) + wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wka_logical_name)) + low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1) low_rank_main = self.kv_norm(low_rank_main) @@ -696,14 +724,17 @@ def __call__( MLA-attended outputs. """ if model_mode == MODEL_MODE_PREFILL: - inputs_q = nn.with_logical_constraint(inputs_q, self.prefill_input_axis_names) - inputs_kv = nn.with_logical_constraint(inputs_kv, self.prefill_input_axis_names) + inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names) + inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names) + out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV) elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - inputs_q = nn.with_logical_constraint(inputs_q, self.ep_input_axis_names) - inputs_kv = nn.with_logical_constraint(inputs_kv, self.ep_input_axis_names) + inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names) + inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names) + out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV) else: - inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) - inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names) + inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names) + inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) + out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) query = self.mla_query_projection(inputs_q, inputs_positions, model_mode) key, value, cached_values = self.mla_kv_projection( @@ -724,10 +755,11 @@ def __call__( out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values) if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - out = nn.with_logical_constraint(out, self.ep_out_axis_names) + out = self._maybe_shard_with_logical(out, self.ep_out_axis_names) else: - out = nn.with_logical_constraint(out, self.out_axis_names) + out = self._maybe_shard_with_logical(out, self.out_axis_names) - out = self.out_projection(out) + out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(out_logical_name)) + out = self.out_projection(out, out_sharding=out_sharding) out = checkpoint_name(out, "out_proj") return out, kv_cache diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index 0586953de..dd52038f1 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -748,14 +748,17 @@ def init_rotary_embedding(self): rotary_embedding = LLaMARotaryEmbedding( min_timescale=self.config.rope_min_timescale, max_timescale=self.config.rope_max_timescale, + mesh=self.mesh, embedding_dims=rope_embedding_dims, fprop_dtype=self.dtype, use_scale=rope_use_scale, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) elif rope_type.startswith("yarn"): rotary_embedding = YarnRotaryEmbedding( max_position_embeddings=self.config.max_position_embeddings, + mesh=self.mesh, original_max_position_embeddings=self.config.original_max_position_embeddings, beta_fast=self.config.beta_fast, beta_slow=self.config.beta_slow, @@ -766,16 +769,19 @@ def init_rotary_embedding(self): interleave=self.config.rope_interleave, truncate=self.config.rope_truncate, attention_scaling=self.config.rope_attention_scaling, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) elif self.is_qwen3_next: rotary_embedding = Qwen3NextRotaryEmbedding( min_timescale=self.config.rope_min_timescale, max_timescale=self.config.rope_max_timescale, + mesh=self.mesh, embedding_dims=self.config.head_dim, partial_rotary_factor=self.config.partial_rotary_factor, cast_as_fprop_dtype=True, fprop_dtype=self.config.dtype, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) else: @@ -792,9 +798,11 @@ def init_rotary_embedding(self): rotary_embedding = RotaryEmbedding( min_timescale=self.config.rope_min_timescale, max_timescale=max_timescale, + mesh=self.mesh, embedding_dims=rope_embedding_dims, fprop_dtype=self.dtype, rope_linear_scaling_factor=rope_linear_scaling_factor, + shard_mode=self.config.shard_mode, rngs=self.rngs, ) return rotary_embedding @@ -985,28 +993,25 @@ def __call__( output of shape `[batch, length, q_features]`. """ if model_mode == MODEL_MODE_PREFILL: - inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names) - inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names) + input_axis_names = self.prefill_input_axis_names elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names) - inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names) + input_axis_names = self.ep_input_axis_names elif model_mode == MODEL_MODE_TRAIN: - inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names) - inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) + input_axis_names = self.input_axis_names else: - inputs_q = self._maybe_shard_with_logical(inputs_q, self.decode_input_axis_names) - inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.decode_input_axis_names) + input_axis_names = self.decode_input_axis_names + + inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names) + inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names) + qkv_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(input_axis_names)) # apply projection. if self.config.fused_qkv: query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: - query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.query_axis_names)) - query = self.query_projection(inputs_q, out_sharding=query_sharding) - key_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.key_axis_names)) - key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=key_sharding) - value_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.value_axis_names)) - value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=value_sharding) + query = self.query_projection(inputs_q, out_sharding=qkv_sharding) + key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=qkv_sharding) + value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=qkv_sharding) gate = None if self.is_qwen3_next: diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py index 2db5690db..e87d8c5c3 100644 --- a/src/MaxText/layers/deepseek.py +++ b/src/MaxText/layers/deepseek.py @@ -16,8 +16,10 @@ # pylint: disable=arguments-differ # pylint: disable=no-name-in-module +from functools import partial + from jax.ad_checkpoint import checkpoint_name -from jax.sharding import Mesh +from jax.sharding import Mesh, NamedSharding import jax.numpy as jnp from flax import linen as nn @@ -30,6 +32,7 @@ from MaxText.layers import moe from MaxText.layers import quantizations from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.sharding import maybe_shard_with_logical from MaxText.inference import page_manager from MaxText.common_types import MODEL_MODE_PREFILL @@ -67,7 +70,13 @@ def self_attention_with_norm( else: logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") - lnx = nn.with_logical_constraint(lnx, logical_axis_names) + _maybe_shard_with_logical = partial( + maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + ) + lnx_sharding = NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axis_names)) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) attention_layer = attention_mla.mla_as_linen( config=cfg, @@ -106,12 +115,13 @@ def self_attention_with_norm( decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, + out_sharding=lnx_sharding, previous_chunk=previous_chunk, page_state=page_state, slot=slot, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) intermediate_inputs = inputs + attention_lnx # Normalization @@ -123,7 +133,7 @@ def self_attention_with_norm( kernel_axes=("norm",), epsilon=cfg.normalization_layer_epsilon, )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, logical_axis_names) + hidden_states = _maybe_shard_with_logical(hidden_states, logical_axis_names) return hidden_states, intermediate_inputs @@ -169,12 +179,19 @@ def __call__( cfg = self.config if model_mode == MODEL_MODE_PREFILL: logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + mlp_logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_mlp") else: logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp") + # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): inputs = inputs[0] - inputs = nn.with_logical_constraint(inputs, logical_axis_names) + + _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode) + lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names)) + mlp_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names)) + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") hidden_states, intermediate_inputs = self_attention_with_norm( @@ -201,12 +218,17 @@ def __call__( config=cfg, mesh=self.mesh, quant=self.quant, - )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) + )( + hidden_states, + deterministic=deterministic, + intermediate_sharding=mlp_intermediate_sharding, + out_sharding=lnx_out_sharding, + ) + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) layer_output = mlp_lnx + intermediate_inputs layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) - layer_output = nn.with_logical_constraint( + layer_output = _maybe_shard_with_logical( layer_output, logical_axis_names, ) @@ -241,13 +263,19 @@ def __call__( cfg = self.config if model_mode == MODEL_MODE_PREFILL: logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") + mlp_logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_mlp") else: logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + mlp_logical_axis_names = ("activation_batch", "activation_norm_length", "activation_mlp") # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): inputs = inputs[0] - inputs = nn.with_logical_constraint(inputs, logical_axis_names) + + _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode) + lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names)) + lnx_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names)) + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") hidden_states, intermediate_inputs = self_attention_with_norm( @@ -276,12 +304,12 @@ def __call__( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, quant=self.quant, - )(hidden_states) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) + )(hidden_states, intermediate_sharding=lnx_intermediate_sharding, out_sharding=lnx_out_sharding) + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) layer_output = mlp_lnx + intermediate_inputs layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) - layer_output = nn.with_logical_constraint( + layer_output = _maybe_shard_with_logical( layer_output, logical_axis_names, ) diff --git a/src/MaxText/layers/embeddings.py b/src/MaxText/layers/embeddings.py index 76fb1b8a7..af9020ba2 100644 --- a/src/MaxText/layers/embeddings.py +++ b/src/MaxText/layers/embeddings.py @@ -275,9 +275,11 @@ def __init__( self, min_timescale: int, max_timescale: int, + mesh: Mesh, embedding_dims: int = 0, cast_as_fprop_dtype: bool = True, fprop_dtype: DType = jnp.bfloat16, + shard_mode: ShardMode = ShardMode.AUTO, # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen. # TODO: Remove when bridge no longer needed rope_linear_scaling_factor: float = 1.0, @@ -297,9 +299,11 @@ def __init__( """ self.min_timescale = min_timescale self.max_timescale = max_timescale + self.mesh = mesh self.embedding_dims = embedding_dims self.cast_as_fprop_dtype = cast_as_fprop_dtype self.fprop_dtype = fprop_dtype + self.shard_mode = shard_mode self.rope_linear_scaling_factor = rope_linear_scaling_factor if self.embedding_dims % 2: @@ -396,10 +400,12 @@ def qwen3_next_rotary_embedding_as_linen( *, min_timescale: int, max_timescale: int, + mesh: Mesh, embedding_dims: int = 0, partial_rotary_factor: float = 0.25, cast_as_fprop_dtype: bool = True, fprop_dtype: DType = jnp.bfloat16, + shard_mode: ShardMode = ShardMode.AUTO, name: str | None = None, ): """Initializes the Qwen3NextRotaryEmbedding module and returns it as a Linen module. @@ -419,10 +425,12 @@ def qwen3_next_rotary_embedding_as_linen( Qwen3NextRotaryEmbedding, min_timescale=min_timescale, max_timescale=max_timescale, + mesh=mesh, embedding_dims=embedding_dims, partial_rotary_factor=partial_rotary_factor, cast_as_fprop_dtype=cast_as_fprop_dtype, fprop_dtype=fprop_dtype, + shard_mode=shard_mode, metadata_fn=variable_to_logically_partitioned, name=name, ) @@ -435,10 +443,12 @@ def __init__( self, min_timescale: int, max_timescale: int, + mesh: Mesh, embedding_dims: int = 0, cast_as_fprop_dtype: bool = True, fprop_dtype: DType = jnp.bfloat16, partial_rotary_factor: float = 0.25, + shard_mode: ShardMode = ShardMode.AUTO, rngs: nnx.Rngs = None, ): """Initializes the Qwen3NextRotaryEmbedding module. @@ -459,9 +469,11 @@ def __init__( super().__init__( min_timescale=min_timescale, max_timescale=max_timescale, + mesh=mesh, embedding_dims=self.rotary_dim, cast_as_fprop_dtype=cast_as_fprop_dtype, fprop_dtype=fprop_dtype, + shard_mode=shard_mode, rngs=rngs, ) @@ -490,10 +502,12 @@ def __init__( self, min_timescale: int, max_timescale: int, + mesh: Mesh, embedding_dims: int = 0, cast_as_fprop_dtype: bool = True, fprop_dtype: DType = jnp.bfloat16, use_scale: bool = True, + shard_mode: ShardMode = ShardMode.AUTO, # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen. # TODO: Remove when bridge no longer needed rngs: nnx.Rngs = None, @@ -514,9 +528,11 @@ def __init__( super().__init__( min_timescale=min_timescale, max_timescale=max_timescale, + mesh=mesh, embedding_dims=embedding_dims, cast_as_fprop_dtype=cast_as_fprop_dtype, fprop_dtype=fprop_dtype, + shard_mode=shard_mode, rngs=rngs, ) @@ -625,6 +641,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax. def yarn_rotary_embedding_as_linen( *, embedding_dims: int, + mesh: Mesh, max_position_embeddings: int = 4096 * 4, original_max_position_embeddings: int = 4096, beta_fast: float = 32, @@ -637,6 +654,7 @@ def yarn_rotary_embedding_as_linen( interleave: bool = True, truncate: bool = True, attention_scaling: bool = False, + shard_mode: ShardMode = ShardMode.AUTO, ): """Initializes the YarnRotaryEmbedding module and returns it as a Linen module. @@ -656,6 +674,7 @@ def yarn_rotary_embedding_as_linen( YarnRotaryEmbedding, embedding_dims=embedding_dims, max_position_embeddings=max_position_embeddings, + mesh=mesh, original_max_position_embeddings=original_max_position_embeddings, beta_fast=beta_fast, beta_slow=beta_slow, @@ -668,6 +687,7 @@ def yarn_rotary_embedding_as_linen( interleave=interleave, truncate=truncate, attention_scaling=attention_scaling, + shard_mode=shard_mode, ) @@ -697,6 +717,7 @@ class YarnRotaryEmbedding(nnx.Module): def __init__( self, embedding_dims: int, + mesh: Mesh, max_position_embeddings: int = 4096 * 4, original_max_position_embeddings: int = 4096, beta_fast: float = 32, @@ -705,6 +726,7 @@ def __init__( rope_factor: float = 40, cast_as_fprop_dtype: bool = True, fprop_dtype: DType = jnp.bfloat16, + shard_mode: ShardMode = ShardMode.AUTO, interleave=True, truncate=True, attention_scaling=False, @@ -724,8 +746,16 @@ def __init__( self.fprop_dtype = fprop_dtype self.interleave = interleave self.truncate = truncate + self.mesh = mesh + self.shard_mode = shard_mode self.attention_scaling = attention_scaling + self.freqs_sharding = ( + NamedSharding(mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "q_heads"))) + if shard_mode == ShardMode.EXPLICIT + else None + ) + if self.embedding_dims % 2: raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") @@ -829,7 +859,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array: # Lookup the precomputed frequencies using the position indices. # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0. # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads. - freqs = jnp.take(self.freqs_cis, position, axis=0) # shape: [B, S, half_dim] + freqs = self.freqs_cis.at[position].get(out_sharding=self.freqs_sharding) # shape: [B, S, half_dim] freqs = freqs[:, :, jnp.newaxis, :] # shape: [B, S, 1, half_dim] if self.interleave: @@ -846,7 +876,14 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array: inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim] # Apply the rotary transformation via complex multiplication. - rotated = inputs_complex * freqs # shape: [B, S, N, half_dim] + rotated_sharding = ( + NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None))) + if self.shard_mode == ShardMode.EXPLICIT + else None + ) + freqs = jnp.broadcast_to(freqs, inputs_complex.shape, out_sharding=rotated_sharding) + rotated = jnp.multiply(inputs_complex, freqs) # shape: [B, S, N, half_dim] + # Convert the complex result back to a real tensor. # Split the complex number into its real and imaginary parts. # [real1, real2, ..., img1, img2, ...] @@ -1025,7 +1062,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array: """ if len(inputs.shape) != 4: raise ValueError( - """Input is assumed to be a rank 4 tensor of shape [batch_size_times_tiles, num_patches_incl_cls, + """Input is assumed to be a rank 4 tensor of shape [batch_size_times_tiles, num_patches_incl_cls, num_heads, head_dim].""" ) diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index 9293c0a78..22fb5a9e6 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -27,10 +27,13 @@ import jax from jax import ad_checkpoint as adc from jax.experimental import xla_metadata +from jax.sharding import NamedSharding, Mesh import jax.numpy as jnp from MaxText import common_types as ctypes from MaxText import max_logging from MaxText import max_utils +from MaxText.common_types import ShardMode +from MaxText.sharding import maybe_shard_with_logical from MaxText.kernels import megablox as mblx from MaxText.layers import attentions, linears, nnx_wrappers, quantizations from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned @@ -138,6 +141,7 @@ def __init__( in_features_shape: Union[Iterable[int], int], out_features_shape: Union[Iterable[int], int], model_name: str, + mesh: Mesh, rngs: nnx.Rngs, axis: Union[Iterable[int], int] = -1, weight_dtype: ctypes.DType = jnp.float32, @@ -147,6 +151,7 @@ def __init__( use_bias: bool = False, score_func: str = "", quant: Optional[quantizations.AqtQuantization] = None, + shard_mode: ShardMode = ShardMode.AUTO, matmul_precision: str = "default", ): """Initializes the GateLogit module. @@ -171,6 +176,7 @@ def __init__( self.in_features_shape = linears.canonicalize_tuple(in_features_shape) self.out_features_shape = linears.canonicalize_tuple(out_features_shape) self.model_name = model_name + self.mesh = mesh self.axis = linears.canonicalize_tuple(axis) self.weight_dtype = weight_dtype self.dtype = dtype @@ -179,6 +185,7 @@ def __init__( self.use_bias = use_bias self.score_func = score_func self.quant = quant + self.shard_mode = shard_mode self.matmul_precision = matmul_precision # Parameter initialization @@ -238,6 +245,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. kernel = jnp.asarray(kernel, self.dtype) contract_ind = tuple(range(0, len(norm_axis))) + output_sharding = ( + NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length_no_exp", None))) + if self.shard_mode == ShardMode.EXPLICIT + else None + ) output = linears._compute_dot_general_nnx( inputs, kernel, @@ -246,6 +258,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. self.matmul_precision, self.quant_dot_general, _initializing, + out_sharding=output_sharding, ) pre_bias_logits = None @@ -315,6 +328,7 @@ def __init__( self.gate = GateLogit( in_features_shape=self.config.emb_dim, out_features_shape=self.num_experts, + mesh=self.mesh, model_name=self.config.model_name, dtype=self.dtype, weight_dtype=self.weight_dtype, @@ -324,6 +338,7 @@ def __init__( use_bias=self.config.routed_bias, score_func=self.config.routed_score_func, matmul_precision=self.config.matmul_precision, + shard_mode=config.shard_mode, rngs=self.rngs, ) @@ -394,6 +409,9 @@ def __init__( self.wi_1_bias = None self.wo_bias = None + def _maybe_shard_with_logical(self, inputs, logical_name): + return maybe_shard_with_logical(inputs, logical_name, mesh=self.mesh, shard_mode=self.config.shard_mode) + def get_expert_parallelism_size(self): return self.mesh.shape.get("expert", 1) @@ -1073,14 +1091,17 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r if self.config.mlp_bias: w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias) + def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): - if pspec_dim_axes is None: return [] + if pspec_dim_axes is None: + return [] axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes active = [] for ax in axes: if ax and self.mesh.shape.get(ax, 1) > 1: active.append((ax, tensor_dim_index)) return active + wi_gather_axes = [] wo_gather_axes = [] @@ -1227,11 +1248,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): if self.config.moe_fsdp_use_two_stage_all_gather: # Unshard on fsdp axis - w0_kernel = nn.with_logical_constraint(w0_kernel, ("exp", "embed_tensor_transpose", "mlp")) - w1_kernel = nn.with_logical_constraint(w1_kernel, ("exp", "embed_tensor_transpose", "mlp")) + w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp")) + w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp")) # Unshard on fsdp_transpose axis - wo_kernel = nn.with_logical_constraint(wo_kernel, ("exp", "mlp", "embed_tensor_transpose")) + wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp", "embed_tensor_transpose")) # Make sure XLA does not optimize by combining above All-Gather to unshard # on FSDP axis and the subsequent unshard on fsdp_transpose axis @@ -1240,9 +1261,24 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_kernel = jax.lax.optimization_barrier(wo_kernel) # Unshard on both fsdp and fsdp_transpose transpose - w0_kernel = nn.with_logical_constraint(w0_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp")) - w1_kernel = nn.with_logical_constraint(w1_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp")) - wo_kernel = nn.with_logical_constraint(wo_kernel, ("exp", "mlp_no_fsdp", "embed_tensor_transpose")) + w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp")) + w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp")) + wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp_no_fsdp", "embed_tensor_transpose")) + + if self.get_tensor_transpose_parallelism_size() > 1: + input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed") + else: + input_axes = (batch_logical_axis, "activation_norm_length", None) + + gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) + if self.config.model_name.startswith("deepseek3"): + pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) + else: + pre_bias_logits_axes = None + + inputs = self._maybe_shard_with_logical(inputs, input_axes) + gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) + pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) return wrapper( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs @@ -1254,11 +1290,21 @@ def reshape_and_update_weights(self, weights, indices): # output of updated weights: (batch_size, seq_len, num_experts) update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) index_update = ( - jnp.arange(weights.shape[0])[:, None, None], - jnp.arange(weights.shape[1])[:, None], + self._maybe_shard_with_logical( + jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None) + ), + self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)), indices, ) - update_weights = update_weights.at[index_update].set(weights) + weight_sharding = ( + NamedSharding( + self.mesh, + nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length_no_exp", None)), + ) + if self.config.shard_mode == ShardMode.EXPLICIT + else None + ) + update_weights = update_weights.at[index_update].set(weights, out_sharding=weight_sharding) return update_weights def get_context_partition_and_sub_seq(self, seq_len): @@ -1309,13 +1355,13 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( expert_token_count_fused, ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)), ) - expert_token_count = nn.with_logical_constraint( + expert_token_count = self._maybe_shard_with_logical( expert_token_count, ("activation_batch", "activation_norm_length", None, None, None), ) @@ -1397,13 +1443,13 @@ def generate_masks(self, top_k_indices, softmax_probs): expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) expert_token_count = jnp.reshape( expert_token_count_fused, ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)), ) - expert_token_count = nn.with_logical_constraint( + expert_token_count = self._maybe_shard_with_logical( expert_token_count, ("activation_batch", "activation_norm_length", None, None), ) @@ -1486,7 +1532,7 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism( # Otherwise compiler will handle communication automatically # esp. with int8 quantization, kernel will be all-gathered in int8 instead # of weight_dtype - kernel = nn.with_logical_constraint(kernel, kernel_axes) + kernel = self._maybe_shard_with_logical(kernel, kernel_axes) return kernel def dense_matmul( @@ -1503,10 +1549,12 @@ def dense_matmul( ) -> tuple[jax.Array, Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert - gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_norm_length", None)) + gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None)) if self.config.model_name.startswith("deepseek3"): # pre_bias_logits is None for non-DeepSeek v3 models - pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_norm_length", None)) + pre_bias_logits = self._maybe_shard_with_logical( + pre_bias_logits, ("activation_batch", "activation_norm_length", None) + ) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: @@ -1617,10 +1665,10 @@ def dense_matmul( output_einsum = "EBNCM,BNSEC -> BNSM" inputs = jnp.reshape(inputs, (batch_size, cp, sub_seq, inputs.shape[2])) - inputs = nn.with_logical_constraint(inputs, input_axis) + inputs = self._maybe_shard_with_logical(inputs, input_axis) - dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) - combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) + dispatch_mask = self._maybe_shard_with_logical(dispatch_mask, mask_axes) + combine_mask = self._maybe_shard_with_logical(combine_mask, mask_axes) with jax.named_scope("dispatch"): # only cp during prefill @@ -1628,7 +1676,7 @@ def dense_matmul( dispatch_eimsum, inputs, dispatch_mask, precision=matmul_precision ) if cp > 1: - dispatch = nn.with_logical_constraint( + dispatch = self._maybe_shard_with_logical( dispatch, ( None, @@ -1638,7 +1686,7 @@ def dense_matmul( "activation_embed", ), ) - dispatch = nn.with_logical_constraint( + dispatch = self._maybe_shard_with_logical( dispatch, dispatch_axis, ) @@ -1654,7 +1702,7 @@ def dense_matmul( if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) - layer_w0 = nn.with_logical_constraint( + layer_w0 = self._maybe_shard_with_logical( layer_w0, mlp_axis, ) @@ -1670,7 +1718,7 @@ def dense_matmul( layer_w1 = layer_w1 + w1_bias if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) - layer_w1 = nn.with_logical_constraint( + layer_w1 = self._maybe_shard_with_logical( layer_w1, mlp_axis, ) @@ -1691,7 +1739,7 @@ def dense_matmul( if self.config.activations_in_float32: intermediate_layer = intermediate_layer.astype(jnp.float32) if self.config.model_call_mode != "inference": - intermediate_layer = nn.with_logical_constraint( + intermediate_layer = self._maybe_shard_with_logical( intermediate_layer, ( "activation_exp", @@ -1720,7 +1768,7 @@ def dense_matmul( ) return output, loss else: - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) with jax.named_scope("wi_0"): layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision @@ -1797,7 +1845,9 @@ def retrieve_quantized_weight( wo_kernel = max_utils.unbox_logicallypartioned(wo_kernel) return w0_kernel, w1_kernel, wo_kernel - def __call__(self, inputs: jax.Array) -> tuple[jax.Array, Optional[jax.Array]]: + def __call__( + self, inputs: jax.Array, out_sharding: NamedSharding | None = None + ) -> tuple[jax.Array, Optional[jax.Array]]: cfg = self.config inputs = inputs.astype(cfg.dtype) gate_logits, pre_bias_logits = self.gate(inputs) @@ -1901,9 +1951,14 @@ def __init__( def routed_moe(self): return self.MoeBlock_0 - def __call__(self, inputs: jax.Array) -> jax.Array: - routed_experts, _ = self.routed_moe(inputs) - shared_experts = self.shared_experts(inputs) + def __call__( + self, + inputs: jax.Array, + intermediate_sharding: NamedSharding | None = None, + out_sharding: NamedSharding | None = None, + ) -> jax.Array: + routed_experts, _ = self.routed_moe(inputs, out_sharding=out_sharding) + shared_experts = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding) return routed_experts + shared_experts diff --git a/src/MaxText/layers/normalizations.py b/src/MaxText/layers/normalizations.py index d868abec7..2387cc523 100644 --- a/src/MaxText/layers/normalizations.py +++ b/src/MaxText/layers/normalizations.py @@ -77,8 +77,11 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = jnp.asarray(scale, self.dtype) effective_scale = scale + self.scale_offset # Apply offset - # broadcast 2nd input then element-wise mul - return jnp.einsum("i...k,...k->i...k", y, effective_scale, out_sharding=out_sharding) + # y: (B, S, E) + # effective_scale: (E,) -> (1, 1, E) -> (B, S, E) + effective_scale = jnp.expand_dims(effective_scale, axis=tuple(range(y.ndim - effective_scale.ndim))) + effective_scale = jnp.broadcast_to(effective_scale, y.shape, out_sharding=out_sharding) + return jnp.multiply(y, effective_scale) def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 8a5e7d338..aefe5dae6 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -38,6 +38,8 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode): In auto shardmode, this function hints inputs follow given named_sharding. In explicit shardmode, this function enforces inputs following named_sharding. """ + if inputs is None: + return None if shard_mode == ShardMode.EXPLICIT: return reshard(inputs, named_sharding) else: @@ -48,6 +50,8 @@ def maybe_shard_with_logical(inputs, logical_axes, mesh, shard_mode): """ A wrapper of maybe_shard_with_name when logical axes are inputs """ + if inputs is None: + return None named_sharding = NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axes)) return maybe_shard_with_name(inputs, named_sharding, shard_mode) diff --git a/tests/attention_test.py b/tests/attention_test.py index 1d7c755d0..3408003a9 100644 --- a/tests/attention_test.py +++ b/tests/attention_test.py @@ -27,7 +27,7 @@ import numpy as np -from jax.sharding import Mesh, NamedSharding, AxisType +from jax.sharding import Mesh, NamedSharding, AxisType, PartitionSpec as P import jax import jax.numpy as jnp @@ -42,10 +42,12 @@ MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, + EP_AS_CONTEXT, AttentionType, ShardMode, ) from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.sharding import maybe_shard_with_name from MaxText.layers.attentions import Attention from MaxText.layers.attention_op import ChunkedCausalMask, _make_bidirectional_block_mask, _generate_chunk_attention_mask from MaxText.layers.attention_mla import MLA @@ -547,6 +549,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): "context_parallel_load_balance": False, "ici_expert_parallelism": 1, "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", }, { "testcase_name": "cp_with_load_balance", @@ -554,6 +557,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): "context_parallel_load_balance": True, "ici_expert_parallelism": 1, "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", }, { "testcase_name": "cp_ep_no_load_balance", @@ -561,6 +565,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): "context_parallel_load_balance": False, "ici_expert_parallelism": 2, "expert_shard_attention_option": "context", + "shard_mode": "auto", }, { "testcase_name": "cp_ep_with_load_balance", @@ -568,6 +573,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): "context_parallel_load_balance": True, "ici_expert_parallelism": 2, "expert_shard_attention_option": "context", + "shard_mode": "auto", }, { "testcase_name": "ep_no_load_balance", @@ -575,6 +581,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): "context_parallel_load_balance": False, "ici_expert_parallelism": 4, "expert_shard_attention_option": "context", + "shard_mode": "auto", }, { "testcase_name": "ep_with_load_balance", @@ -582,6 +589,55 @@ def tpu_kernel_attention_helper(self, num_kv_heads): "context_parallel_load_balance": True, "ici_expert_parallelism": 4, "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_no_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_with_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_no_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_with_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_no_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_with_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", }, ) # TODO (b/454764135.) : This tests fails with new tokamax kernel @@ -592,6 +648,7 @@ def test_tpu_flash_attention_context_parallel( context_parallel_load_balance, ici_expert_parallelism, expert_shard_attention_option, + shard_mode, ): """Test equivalence between dot_product and flash attention + context/expert parallelism""" num_kv_heads = self.num_kv_heads @@ -615,9 +672,11 @@ def test_tpu_flash_attention_context_parallel( context_parallel_load_balance=context_parallel_load_balance, ici_expert_parallelism=ici_expert_parallelism, expert_shard_attention_option=expert_shard_attention_option, + shard_mode=shard_mode, ) devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) - axis_names = [AxisType.Auto for _ in cfg_cp.mesh_axes] + axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto + axis_names = [axis_type for _ in cfg_cp.mesh_axes] mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) attention_as_mha_flash_cp = Attention( config=cfg_cp, @@ -1346,6 +1405,7 @@ def test_projection_initialization(self): "context_parallel_load_balance": False, "ici_expert_parallelism": 1, "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", }, { "testcase_name": "cp_with_load_balance", @@ -1353,6 +1413,7 @@ def test_projection_initialization(self): "context_parallel_load_balance": True, "ici_expert_parallelism": 1, "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", }, { "testcase_name": "cp_ep_no_load_balance", @@ -1360,6 +1421,7 @@ def test_projection_initialization(self): "context_parallel_load_balance": False, "ici_expert_parallelism": 2, "expert_shard_attention_option": "context", + "shard_mode": "auto", }, { "testcase_name": "cp_ep_with_load_balance", @@ -1367,6 +1429,7 @@ def test_projection_initialization(self): "context_parallel_load_balance": True, "ici_expert_parallelism": 2, "expert_shard_attention_option": "context", + "shard_mode": "auto", }, { "testcase_name": "ep_no_load_balance", @@ -1374,6 +1437,7 @@ def test_projection_initialization(self): "context_parallel_load_balance": False, "ici_expert_parallelism": 4, "expert_shard_attention_option": "context", + "shard_mode": "auto", }, { "testcase_name": "ep_with_load_balance", @@ -1381,6 +1445,55 @@ def test_projection_initialization(self): "context_parallel_load_balance": True, "ici_expert_parallelism": 4, "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_no_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_with_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_no_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_with_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_no_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_with_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", }, ) # TODO (b/454764135.) : This tests fails with new tokamax kernel @@ -1391,6 +1504,7 @@ def test_tpu_flash_attention_context_parallel( context_parallel_load_balance, ici_expert_parallelism, expert_shard_attention_option, + shard_mode, ): """Test equivalence between dot_product and flash attention + context/expert parallelism""" @@ -1413,6 +1527,7 @@ def test_tpu_flash_attention_context_parallel( "qk_nope_head_dim": 128, "qk_rope_head_dim": 64, "v_head_dim": 128, + "shard_mode": shard_mode, } cfg, mla = self.init_mla(config_arguments, rope_type="default") @@ -1439,7 +1554,9 @@ def test_tpu_flash_attention_context_parallel( expert_shard_attention_option=expert_shard_attention_option, ) devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) - mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes) + axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto + axis_names = [axis_type for _ in cfg_cp.mesh_axes] + mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) attention_as_mla_flash_cp = MLA( config=cfg_cp, num_query_heads=cfg_cp.num_query_heads, @@ -1467,6 +1584,10 @@ def test_tpu_flash_attention_context_parallel( cfg_cp, mesh_cp, attention_as_mla_flash_cp, lnx, decoder_segment_ids, decoder_positions ) + # This removes all sharding information and makes them standard NumPy arrays. + mla_generic_output = jax.device_get(mla_generic_output) + mla_generic_flash_cp_output = jax.device_get(mla_generic_flash_cp_output) + self.assertTrue( jax.numpy.allclose(mla_generic_output, mla_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), msg="MLA Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" @@ -1489,12 +1610,16 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx, decoder_positions = reordered_batch["inputs_position"] # apply attention with sharding with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules): + if cfg_cp.expert_shard_attention_option == EP_AS_CONTEXT: + batch_axis = "activation_batch_no_exp" + length_axis = "activation_length" + else: + batch_axis = "activation_batch" + length_axis = "activation_length_no_exp" lnx_spec = nn_partitioning.logical_to_mesh_axes( - ("activation_batch_no_exp", "activation_length_no_exp", "activation_embed"), nn_partitioning.get_axis_rules() - ) - pos_spec = nn_partitioning.logical_to_mesh_axes( - ("activation_batch_no_exp", "activation_length_no_exp"), nn_partitioning.get_axis_rules() + (batch_axis, length_axis, "activation_embed"), nn_partitioning.get_axis_rules() ) + pos_spec = nn_partitioning.logical_to_mesh_axes((batch_axis, length_axis), nn_partitioning.get_axis_rules()) lnx_sharding = NamedSharding(mesh_cp, lnx_spec) pos_sharding = NamedSharding(mesh_cp, pos_spec) @@ -1510,7 +1635,11 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx, deterministic=True, model_mode=MODEL_MODE_TRAIN, ) - attention_cp_output = attention_cp_output[0] if isinstance(attention_cp_output, tuple) else attention_cp_output + + attention_cp_output = attention_cp_output[0] if isinstance(attention_cp_output, tuple) else attention_cp_output + # All-gather before re-shuffle to avoid re-order sharding confusion + repeat_sharding = NamedSharding(mesh_cp, P()) + attention_cp_output = maybe_shard_with_name(attention_cp_output, repeat_sharding, shard_mode=cfg_cp.shard_mode) # If load balanced cp, de-shuffle and gather along seq dim for output # Note training does not need post-shuffle. Since the target seq is also pre-shuffled, the loss remains correct diff --git a/tests/check_gemma3_layers.py b/tests/check_gemma3_layers.py index f812f905f..b95ce3269 100644 --- a/tests/check_gemma3_layers.py +++ b/tests/check_gemma3_layers.py @@ -27,6 +27,7 @@ import jax import jax.numpy as jnp +from jax.sharding import Mesh import numpy as np @@ -78,6 +79,8 @@ def __init__(self, config, device=None): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + self.mesh = Mesh(jax.devices(), "data") + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -186,6 +189,7 @@ def __init__(self, rope_theta, head_dim, max_position_embeddings): jax_rope = embeddings.RotaryEmbedding( min_timescale=min_timescale, max_timescale=max_timescale, + mesh=self.mesh, embedding_dims=head_dim, cast_as_fprop_dtype=False, fprop_dtype=jnp.float32, diff --git a/tests/check_gpt_vs_reference.py b/tests/check_gpt_vs_reference.py index a7fd27679..ea748719b 100644 --- a/tests/check_gpt_vs_reference.py +++ b/tests/check_gpt_vs_reference.py @@ -727,6 +727,8 @@ def setUp(self): "num_attention_heads": float("inf"), } self.pt_config = SimpleNamespace(**pt_config) + devices_array = maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) def test_yarn(self): """Validates the JAX Yarn RoPE implementation against the HF reference.""" @@ -738,6 +740,7 @@ def test_yarn(self): model_jax = embeddings.YarnRotaryEmbedding( max_position_embeddings=self.config.max_position_embeddings, original_max_position_embeddings=self.config.original_max_position_embeddings, + mesh=self.mesh, beta_fast=self.config.beta_fast, beta_slow=self.config.beta_slow, rope_theta=self.config.rope_max_timescale, diff --git a/tests/llama_test.py b/tests/llama_test.py index f2d5dd06a..2054898f3 100644 --- a/tests/llama_test.py +++ b/tests/llama_test.py @@ -20,15 +20,16 @@ import jax import jax.numpy as jnp +from jax.sharding import Mesh from MaxText.layers import embeddings -""" -An example reference jax_llama RoPE implementation from https://github.com/Sea-Snell/ -Users should feel free to change and optimize the RoPE implementation in MaxText defined in layers.py -as long as it passes our tests. But they shouldn't change the "reference" implementation in -llama_test.py which is only to be used for comparison purpose. +""" +An example reference jax_llama RoPE implementation from https://github.com/Sea-Snell/ +Users should feel free to change and optimize the RoPE implementation in MaxText defined in layers.py +as long as it passes our tests. But they shouldn't change the "reference" implementation in +llama_test.py which is only to be used for comparison purpose. """ @@ -76,6 +77,9 @@ def permute_to_match_maxtext_rope(arr): class RoPETest(unittest.TestCase): """Test for the RoPE implementation.""" + def setUp(self): + self.mesh = Mesh(jax.devices(), "data") + def test_rope(self): dim_per_head = 128 seq_len = 8 @@ -91,7 +95,7 @@ def test_rope(self): llama_output = apply_rotary_emb(jnp.asarray(x_q), jnp.asarray(x_k), freqs_cis) position = jnp.arange(seq_len, dtype=jnp.float32)[jnp.newaxis, :] - rope = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head) + rope = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head, mesh=self.mesh) query_proj = rope(permute_to_match_maxtext_rope(x_q), position) key_proj = rope(permute_to_match_maxtext_rope(x_k), position) @@ -108,7 +112,7 @@ def test_scaling_rope(self): position = jnp.arange(seq_len, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings and then scale - rope = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head) + rope = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head, mesh=self.mesh) query_proj_1 = rope(x_q, position=position) query_proj_1 = query_proj_1 * (dim_per_head**-0.5) @@ -127,13 +131,13 @@ def test_llama_rope_with_scaling(self): # Test LLaMARotaryEmbedding with scaling llama_rope_scaled = embeddings.LLaMARotaryEmbedding( - min_timescale=1, max_timescale=10000, embedding_dims=dim_per_head, use_scale=True + min_timescale=1, max_timescale=10000, embedding_dims=dim_per_head, use_scale=True, mesh=self.mesh ) query_proj_scaled = llama_rope_scaled(x_q, position) # Test LLaMARotaryEmbedding without scaling llama_rope_no_scale = embeddings.LLaMARotaryEmbedding( - min_timescale=1, max_timescale=10000, embedding_dims=dim_per_head, use_scale=False + min_timescale=1, max_timescale=10000, embedding_dims=dim_per_head, use_scale=False, mesh=self.mesh ) query_proj_no_scale = llama_rope_no_scale(x_q, position) @@ -150,7 +154,11 @@ def test_llama_rope_single_rotation(self): # Use LLaMARotaryEmbedding llama_rope = embeddings.LLaMARotaryEmbedding( - min_timescale=min_timescale, max_timescale=max_timescale, embedding_dims=dim_per_head, use_scale=False + min_timescale=min_timescale, + max_timescale=max_timescale, + embedding_dims=dim_per_head, + use_scale=False, + mesh=self.mesh, ) query_proj = llama_rope(x_q, position) diff --git a/tests/moe_test.py b/tests/moe_test.py index bab5c649a..848d9b558 100644 --- a/tests/moe_test.py +++ b/tests/moe_test.py @@ -297,10 +297,12 @@ def __init__( self.gate = moe.GateLogit( in_features_shape=self.inputs_shape[-1], out_features_shape=self.num_experts, + mesh=self.mesh, model_name=self.config.model_name, dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=self.kernel_axes, + shard_mode=config.shard_mode, rngs=rngs, ) for k in range(self.num_experts):