Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 69 additions & 48 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,56 +29,77 @@ weight_dtype: bfloat16
# -------------- Logical Axis Rules --------------
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
logical_axis_rules: [
['activation_batch', ['data']],
['activation_batch_moe', ['data']],
['activation_batch_attn', ['data']],
['activation_embed_and_logits_batch', ['data', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
['activation_heads', ['model', 'expert']],
['activation_kv_heads', ['model', 'expert']],
['activation_length_attn', []],
['activation_length', []],
['activation_length_moe', []],
['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_embed_attn', 'model'],
# Expert is explicitly omitted from activation_embed despite using TP.
# We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
# due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
['activation_embed', ['model', 'attn_dp']],
['activation_embed_moe', ['model', 'attn_dp']],
['activation_mlp', ['model', 'attn_dp']],
['activation_mlp_moe', ['model', 'attn_dp']],
['activation_kv', ['model']],
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
['activation_kv_batch', ['data']],
['activation_kv_head_dim', ['model']],
['activation_vocab', ['model', 'attn_dp']],
['activation_norm_length', []],
['activation_norm_length_moe', []],
['activation_exp', ['expert', 'attn_dp_expert']],
['decode_batch', ['data']],
['decode_batch_moe', ['data']],
['decode_length', []],
['mlp', ['model', 'attn_dp']],
['mlp_moe', ['model', 'attn_dp']],
['mlp_no_fsdp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
# Expert is intended to act like TP for attention.
# We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
['heads', ['model', 'expert']],
['q_heads', ['model', 'expert']],
['kv_heads', ['model', 'expert']],
['kv_head_dim', []],
# ==========================================
# Vocabulary Embedding
# ==========================================
# Vocab Activations
['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_vocab', ['expert', 'model']],
# Vocab Weights
['vocab', []],
['embed_vocab', []],
# ==========================================
# Attention
# ==========================================
# Attention Activations
['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_heads', ['expert', 'model']],
['activation_kv_heads', ['expert', 'model']],
['activation_embed_attn', []],
['activation_kv', []],
['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_kv_head_dim', []],
# Attention Weights
['heads', ['expert', 'model']],
['q_heads', ['expert', 'model']],
['kv_heads', ['expert', 'model']],
['qkv', []],
['kv', []],
['embed', []],
['kv_head_dim', []],
['q_lora', []],
["q_lora_up_proj", []],
['kv_lora', []],
["kv_lora_up_proj", []],
# ==========================================
# Mixture of Experts (MoE)
# ==========================================
# MoE Activations
['activation_batch_moe', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_embed_moe', ['attn_dp', 'model']],
['activation_mlp_moe', ['attn_dp', 'model']],
['activation_exp', ['attn_dp_expert', 'expert']],
# MoE Weights
['exp', ['attn_dp_expert', 'expert']],
['mlp_moe', ['attn_dp', 'model']],
['embed_moe', []],
['embed_tensor_transpose', ['attn_dp', 'model']],
['q_lora', ['expert', 'attn_dp_expert']],
['kv_lora', ['expert', 'attn_dp_expert']],
# ==========================================
# Standard MLP / Dense Layers / Model Structure
# ==========================================
# Dense Activations
['activation_mlp', ['attn_dp', 'model']],
# Note: the activation batch and length axes are also used in attention and vocab.
['activation_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['activation_embed', []],
# General Weights
['mlp', ['attn_dp', 'model']],
['embed', []],
['norm', []],
['cache_heads', ['model']],
['exp', ['expert', 'attn_dp_expert']],
['paged_kv_heads', ['model']],
]
# ==========================================
# Inference (Prefill, Decode, Cache)
# ==========================================
['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
['cache_heads', ['expert', 'model']],
['paged_kv_heads', ['expert', 'model']],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_kv', []],
['cache_sequence', []],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
]
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
return

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
model, _ = model_creation_utils.create_nnx_model(
model = model_creation_utils.from_pretrained(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this part of the PR? Is this some artifact of rebasing?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not from rebasing; vLLM decoding currently fails at HEAD, so this fix is included in the same PR to allow the decoding test to run.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind adding this info to the PR description?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind adding this info to the PR description?

Done

self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def generate_and_save_data(config, local_args):

# Loading teacher model and dataset iterator
max_logging.log(f"Loading Teacher Model from {config.load_parameters_path}...")
teacher_model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
teacher_model = model_creation_utils.from_pretrained(config, mesh=mesh)
train_iter, _ = input_pipeline_interface.create_data_iterator(config, mesh)

# Determine start_step for resuming
Expand Down
8 changes: 4 additions & 4 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

# pylint: disable=bare-except, consider-using-generator
""" Utils that are only interesting for creating a model in MaxText. """
"""Utils that are only interesting for creating a model in MaxText."""

import dataclasses
import collections
Expand Down Expand Up @@ -226,9 +226,9 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None):
"""Returns (_create_model_partial, abstract_model) for AOT compilation.

Unlike create_nnx_model, this does not shard parameters or load checkpoints.
It only builds the abstract shape/dtype structure needed by get_abstract_state
and optimizer construction (e.g. Muon).
This does not shard parameters or load checkpoints. It only builds the
abstract shape/dtype structure needed by get_abstract_state and optimizer
construction (e.g. Muon).

Args:
config: the configuration
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/model_creation_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def _make_nnx_metadata_mock(self):
@patch("maxtext.utils.model_creation_utils.ocp")
def test_load_nnx_checkpoint(self, mock_ocp):
"""NNX-format checkpoint: restored values are wrapped under a 'value' key."""
# Echo back the `item` argument passed by create_nnx_model to ckptr.restore.
# Echo back the `item` argument passed by from_pretrained to ckptr.restore.
# For NNX checkpoints, item IS already {leaf: {"value": array}, ...}, so
# returning it directly gives a correctly-structured restored dict that
# matches the model's own state — regardless of the exact leaf count.
Expand All @@ -364,7 +364,7 @@ def test_load_nnx_checkpoint(self, mock_ocp):
@patch("maxtext.utils.model_creation_utils.ocp")
def test_load_linen_checkpoint(self, mock_ocp):
"""Linen-format checkpoint: restored values are nested under 'params'/'params'."""
# Echo back the `item` argument passed by create_nnx_model to ckptr.restore.
# Echo back the `item` argument passed by from_pretrained to ckptr.restore.
# For Linen checkpoints, item IS already {"params": {"params": arrays}}, so
# returning it directly gives a correctly-structured restored dict that
# matches the model's own state — regardless of the exact leaf count.
Expand Down
Loading