2 changes: 2 additions & 0 deletions src/MaxText/common_types.py
@@ -37,7 +37,9 @@
PREFILL_LENGTH = "prefill_activation_length"
Q_LENGTH = "activation_q_length"
Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
Q_LORA_UP_PROJ = "q_lora_up_proj"
KV_LENGTH = "activation_kv_length"
KV_LORA_UP_PROJ = "kv_lora_up_proj"
EMBED = "activation_embed"
HEAD = "activation_heads"
PREFILL_KV_BATCH = "activation_prefill_kv_batch"
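Note on the two new logical axis names: they give the LoRA bottleneck dimension of the MLA query and key/value down-projections its own sharding rule. The sketch below shows, under assumed rules (not MaxText's actual rule table), how such a name resolves to a mesh PartitionSpec and then to the NamedSharding objects used later in this PR.

```python
# Minimal sketch of how the new logical axis names are consumed downstream.
# The rule set here is illustrative only, not MaxText's actual sharding rules.
import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

KV_BATCH = "activation_kv_batch"
KV_LENGTH = "activation_kv_length"
KV_LORA_UP_PROJ = "kv_lora_up_proj"  # the constant added above

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "tensor"))

# Hypothetical rules: shard the batch axis over "data", replicate the rest.
rules = ((KV_BATCH, "data"), (KV_LENGTH, None), (KV_LORA_UP_PROJ, None))

with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes((KV_BATCH, KV_LENGTH, KV_LORA_UP_PROJ))
sharding = NamedSharding(mesh, spec)  # same pattern used for out_sharding below
print(spec)  # PartitionSpec('data', None, None)
```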
2 changes: 1 addition & 1 deletion src/MaxText/configs/types.py
@@ -1920,7 +1920,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
if self.packing:
raise ValueError("For multimodal SFT, `packing` is not yet supported.")
if self.shard_mode == ShardMode.EXPLICIT:
supported_decoders = {"simple", "simple_mlp", "llama2"}
supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
if self.decoder_block.value not in supported_decoders:
raise ValueError(
f"Decoder '{self.decoder_block.value}' is not supported with 'explicit' sharding. "
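For reference, the check above boils down to the standalone snippet below; the set literal is taken from the diff, the wrapper function is only illustrative.

```python
# Distilled version of the explicit-sharding decoder check changed above.
SUPPORTED_EXPLICIT_DECODERS = {"simple", "simple_mlp", "llama2", "deepseek"}

def check_explicit_sharding(decoder_block: str, shard_mode: str) -> None:
  """Raise if the decoder block cannot be used with explicit sharding."""
  if shard_mode == "explicit" and decoder_block not in SUPPORTED_EXPLICIT_DECODERS:
    raise ValueError(
        f"Decoder '{decoder_block}' is not supported with 'explicit' sharding."
    )

check_explicit_sharding("deepseek", "explicit")  # now passes with this PR
```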
90 changes: 61 additions & 29 deletions src/MaxText/layers/attention_mla.py
@@ -38,10 +38,12 @@
EMBED,
EP_AS_CONTEXT,
HEAD,
Q_LORA_UP_PROJ,
KV_BATCH,
KV_BATCH_NO_EXP,
KV_HEAD,
KV_HEAD_DIM,
KV_LORA_UP_PROJ,
LENGTH,
LENGTH_NO_EXP,
MODEL_MODE_PREFILL,
@@ -389,6 +391,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
else:
@@ -403,6 +406,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
self.q_norm = RMSNorm(
@@ -423,6 +427,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)

@@ -437,6 +442,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
self.kv_norm = RMSNorm(
@@ -460,6 +466,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)

@@ -498,6 +505,18 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No

def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
"""Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
# specify query logical name
if model_mode == MODEL_MODE_PREFILL:
query_logical_name = self.prefill_query_axis_names
wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
query_logical_name = self.ep_query_axis_names
wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ)
else:
query_logical_name = self.query_axis_names
wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(query_logical_name))
wqa_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wqa_logical_name))
# Set softmax scaling.
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.softmax_scale = self.qk_head_dim**-0.5
@@ -506,47 +525,49 @@ def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_m
self.softmax_scale = self.softmax_scale * mscale * mscale

if self.q_lora_rank == 0:
q = self.query(inputs_q)
q = self.query(inputs_q, out_sharding=query_sharding)
else:
# LoRA path
low_rank_q = self.wq_a(inputs_q) # [B, L, q_lora_rank]
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank
q = self.wq_b(low_rank_q) # [B, L, n_heads * qk_head_dim]
q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads * qk_head_dim]

# Split into non-positional and rotary parts.
q_nope, q_pe = jnp.split(q, [self.qk_nope_head_dim], axis=-1)
q_nope = self._maybe_shard_with_logical(q_nope, query_logical_name)
q_pe = self.apply_rotary_embedding(q_pe, inputs_positions=inputs_positions)
q_pe = self._maybe_shard_with_logical(q_pe, query_logical_name)
# The query projection is scaled by self.softmax_scale to be consistent with the MaxText implementation.
# DeepSeek v3 applied this scaling in the attention score computation instead.
query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
query = self._maybe_shard_with_logical(query, query_logical_name)
return query

def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
"""get (key,value) pair from mla"""
if model_mode == MODEL_MODE_PREFILL:
query = nn.with_logical_constraint(query, self.prefill_query_axis_names)
key_logical_name = self.prefill_key_axis_names
value_logical_name = self.prefill_value_axis_names
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
query = nn.with_logical_constraint(query, self.ep_query_axis_names)
key_logical_name = self.ep_key_axis_names
value_logical_name = self.ep_value_axis_names
else:
query = nn.with_logical_constraint(query, self.query_axis_names)
return query
key_logical_name = self.key_axis_names
value_logical_name = self.value_axis_names

def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
"""get (key,value) pair from mla"""
kv_out = self.wkv_b(low_rank_main)
wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(key_logical_name))
kv_out = self.wkv_b(low_rank_main, out_sharding=wkva_out_sharding)

# Split kv_out into key_nope and value parts.
key_nope, value = jnp.split(kv_out, [self.qk_nope_head_dim], axis=-1)
key_rope = jnp.broadcast_to(key_rope, (key_nope.shape[0], key_nope.shape[1], self.num_query_heads, key_rope.shape[3]))
key_nope = self._maybe_shard_with_logical(key_nope, key_logical_name)
key_rope = self._maybe_shard_with_logical(key_rope, key_logical_name)

key = jnp.concatenate([key_nope, key_rope], axis=-1)

if model_mode == MODEL_MODE_PREFILL:
key = nn.with_logical_constraint(key, self.prefill_key_axis_names)
value = nn.with_logical_constraint(value, self.prefill_value_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
key = nn.with_logical_constraint(key, self.ep_key_axis_names)
value = nn.with_logical_constraint(value, self.ep_value_axis_names)
else:
key = nn.with_logical_constraint(key, self.key_axis_names)
value = nn.with_logical_constraint(value, self.value_axis_names)
key = self._maybe_shard_with_logical(key, key_logical_name)
value = self._maybe_shard_with_logical(value, value_logical_name)
return key, value

def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
@@ -637,7 +658,14 @@ def update_mla_kv_caches(self, low_rank_main, key_rope, decoder_segment_ids, mod

def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segment_ids, model_mode, previous_chunk):
"""MLA key/value projection with integrated rotary embedding."""
low_rank = self.wkv_a(inputs)
if model_mode == MODEL_MODE_PREFILL:
wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
else:
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wka_logical_name))
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
low_rank_main = self.kv_norm(low_rank_main)

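A tiny numeric sketch of the split performed in mla_kv_projection above; the sizes are DeepSeek-V3-like but chosen here purely for illustration, not read from any config.

```python
import jax.numpy as jnp

kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative sizes
low_rank = jnp.zeros((2, 8, kv_lora_rank + qk_rope_head_dim))  # [B, L, 576]
low_rank_main, low_rank_rope = jnp.split(low_rank, [kv_lora_rank], axis=-1)
print(low_rank_main.shape, low_rank_rope.shape)  # (2, 8, 512) (2, 8, 64)
```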
@@ -696,14 +724,17 @@ def __call__(
MLA-attended outputs.
"""
if model_mode == MODEL_MODE_PREFILL:
inputs_q = nn.with_logical_constraint(inputs_q, self.prefill_input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.prefill_input_axis_names)
inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
inputs_q = nn.with_logical_constraint(inputs_q, self.ep_input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.ep_input_axis_names)
inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV)
else:
inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)

query = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
key, value, cached_values = self.mla_kv_projection(
@@ -724,10 +755,11 @@ def __call__(
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)

if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
out = nn.with_logical_constraint(out, self.ep_out_axis_names)
out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
else:
out = nn.with_logical_constraint(out, self.out_axis_names)
out = self._maybe_shard_with_logical(out, self.out_axis_names)

out = self.out_projection(out)
out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(out_logical_name))
out = self.out_projection(out, out_sharding=out_sharding)
out = checkpoint_name(out, "out_proj")
return out, kv_cache
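The diff leans heavily on `_maybe_shard_with_logical`, whose implementation is not shown here. One plausible shape for such a helper is sketched below: keep the existing `nn.with_logical_constraint` hint when sharding is implicit, and pin a concrete layout when `shard_mode` is explicit. This is an assumption for illustration, not the helper's actual code, and it assumes the caller has installed Flax logical axis rules (e.g. via `nn.logical_axis_rules`) as MaxText does.

```python
# Sketch only: a mode-aware stand-in for the `_maybe_shard_with_logical`
# calls used throughout attention_mla.py. The real helper may differ.
import jax
from flax import linen as nn
from jax.sharding import NamedSharding

def maybe_shard_with_logical(y, logical_names, mesh, shard_mode: str):
  if shard_mode == "explicit":
    # Pin the layout as a hard constraint derived from the logical names.
    spec = nn.logical_to_mesh_axes(logical_names)
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, spec))
  # Implicit/auto sharding: keep the usual logical-constraint hint.
  return nn.with_logical_constraint(y, logical_names)
```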
33 changes: 19 additions & 14 deletions src/MaxText/layers/attentions.py
@@ -748,14 +748,17 @@ def init_rotary_embedding(self):
rotary_embedding = LLaMARotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
max_timescale=self.config.rope_max_timescale,
mesh=self.mesh,
embedding_dims=rope_embedding_dims,
fprop_dtype=self.dtype,
use_scale=rope_use_scale,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
elif rope_type.startswith("yarn"):
rotary_embedding = YarnRotaryEmbedding(
max_position_embeddings=self.config.max_position_embeddings,
mesh=self.mesh,
original_max_position_embeddings=self.config.original_max_position_embeddings,
beta_fast=self.config.beta_fast,
beta_slow=self.config.beta_slow,
@@ -766,16 +769,19 @@ def init_rotary_embedding(self):
interleave=self.config.rope_interleave,
truncate=self.config.rope_truncate,
attention_scaling=self.config.rope_attention_scaling,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
elif self.is_qwen3_next:
rotary_embedding = Qwen3NextRotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
max_timescale=self.config.rope_max_timescale,
mesh=self.mesh,
embedding_dims=self.config.head_dim,
partial_rotary_factor=self.config.partial_rotary_factor,
cast_as_fprop_dtype=True,
fprop_dtype=self.config.dtype,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
else:
@@ -792,9 +798,11 @@ def init_rotary_embedding(self):
rotary_embedding = RotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
max_timescale=max_timescale,
mesh=self.mesh,
embedding_dims=rope_embedding_dims,
fprop_dtype=self.dtype,
rope_linear_scaling_factor=rope_linear_scaling_factor,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
return rotary_embedding
@@ -985,28 +993,25 @@ def __call__(
output of shape `[batch, length, q_features]`.
"""
if model_mode == MODEL_MODE_PREFILL:
inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
input_axis_names = self.prefill_input_axis_names
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
input_axis_names = self.ep_input_axis_names
elif model_mode == MODEL_MODE_TRAIN:
inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
input_axis_names = self.input_axis_names
else:
inputs_q = self._maybe_shard_with_logical(inputs_q, self.decode_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.decode_input_axis_names)
input_axis_names = self.decode_input_axis_names

inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names)
qkv_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(input_axis_names))

# apply projection.
if self.config.fused_qkv:
query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj")
else:
query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.query_axis_names))
query = self.query_projection(inputs_q, out_sharding=query_sharding)
key_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.key_axis_names))
key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=key_sharding)
value_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.value_axis_names))
value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=value_sharding)
query = self.query_projection(inputs_q, out_sharding=qkv_sharding)
key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=qkv_sharding)
value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=qkv_sharding)

gate = None
if self.is_qwen3_next:
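The q/k/v projections above now take a single `out_sharding` built once from the selected input axis names. The snippet below sketches the general pattern of a projection that honours such an argument; it is a simplified stand-in, not MaxText's DenseGeneral.

```python
# Illustrative projection that applies an optional out_sharding, in the same
# spirit as the query/key/value projections above (not the real DenseGeneral).
from typing import Optional
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding

def project(x: jax.Array, w: jax.Array,
            out_sharding: Optional[NamedSharding] = None) -> jax.Array:
  y = jnp.einsum("bld,df->blf", x, w)  # [batch, length, features]
  if out_sharding is not None:
    # Constrain the output layout; under explicit shard_mode this is where the
    # NamedSharding built from logical axis names takes effect.
    y = jax.lax.with_sharding_constraint(y, out_sharding)
  return y
```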