From 61abdbc6fc9282e74cef276526042939cd2c2863 Mon Sep 17 00:00:00 2001
From: Nicolas Grande
Date: Mon, 22 Dec 2025 16:44:32 +0000
Subject: [PATCH] Simplify the MaxText vLLM adapter implementation and update
 the vllm_decode example

---
 .../vllm/maxtext_vllm_adapter/adapter.py | 149 +++++-------------
 src/MaxText/layers/decoders.py           |   7 +-
 src/MaxText/layers/models.py             |   4 +-
 src/MaxText/vllm_decode.py               |  18 +--
 4 files changed, 58 insertions(+), 120 deletions(-)

diff --git a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py
index 221a068e23..e29ba370a5 100644
--- a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py
+++ b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -14,9 +14,8 @@
 """vLLM adapter for MaxText models."""
 
-import jax
-import jax.numpy as jnp
 import os
+import jax
 
 from flax import nnx
 import flax.linen as nn
@@ -72,16 +71,17 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   return maxtext_config
 
 
-class MaxTextDecoderModel(nnx.Module):
-  """A vLLM-compatible decoder model wrapper for MaxText.
+class MaxTextForCausalLM(nnx.Module):
+  """A vLLM-compatible causal language model wrapper for MaxText.
 
-  This class adapts a MaxText model for use within the vLLM framework,
-  handling configuration generation, model initialization, and execution
+  This class serves as the primary interface for integrating MaxText models
+  into the vLLM serving framework, specifically for causal language modeling
+  tasks. It handles configuration generation, model initialization, and execution
   of the decoding step.
   """
 
-  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
-    """Initializes the MaxTextDecoderModel.
+  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
+    """Initializes the MaxTextForCausalLM model.
 
     Args:
       vllm_config: The vLLM configuration object.
@@ -89,20 +89,20 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> N
       mesh: The JAX mesh device for model sharding.
     """
     self.vllm_config = vllm_config
+    self.cfg = vllm_config.model_config
     self.maxtext_config = generate_maxtext_config(vllm_config)
 
     # Model configuration
     self.mesh = mesh
     self.model_mode = MODEL_MODE_AUTOREGRESSIVE
+    self.is_text_generation_model = True
 
     # Model creation
     self.model: nnx.Module | None = None
-    self.logits: jax.Array | None = None
 
     # Handle dummy weight loading during initialization
     if vllm_config.load_config.load_format == "dummy":
-      with self.mesh:
-        self.load_weights(rng_key)
+      self.load_weights(rng_key)
     elif self.maxtext_config.load_parameters_path is None:
       max_logging.log("Warning: No load_parameters_path provided. The model will be initialized with random weights.")
 
@@ -115,7 +115,7 @@ def __call__(
     *args,
     **kwargs,
   ) -> tuple[list[jax.Array], jax.Array, list[jax.Array]]:
-    """Performs a forward pass through the decoder model.
+    """Performs a forward pass through the causal language model.
 
     Args:
       kv_caches: A list of JAX arrays representing the KV caches.
@@ -127,7 +127,7 @@ def __call__(
     Returns:
       A tuple containing:
         - updated_kv_caches: A list of updated KV caches.
-        - hidden: The hidden states (Q, d_model).
+        - hidden: The hidden states.
        - aux_hidden_states: A list of auxiliary hidden states.
    Raises:
@@ -137,15 +137,15 @@
       raise ValueError("Model must be an instance of type nnx.Module.")
 
     if input_ids.ndim < 2:
-      input_ids = jnp.expand_dims(input_ids, axis=0)
+      input_ids = input_ids[None, :]
 
     input_positions = attention_metadata.input_positions
     if input_positions.ndim < 2:
-      input_positions = jnp.expand_dims(input_positions, axis=0)
+      input_positions = input_positions[None, :]
 
-    with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
+    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       aux_hidden_states = []
-      logits, hidden, kv_caches = self.model(
+      hidden, kv_caches = self.model(
         decoder_input_tokens=input_ids,
         decoder_positions=input_positions,
         kv_caches=kv_caches,
@@ -154,88 +154,9 @@
       **kwargs,
     )
 
-    if hidden.ndim > 1:
-      hidden = jnp.squeeze(hidden, axis=0)
-      logits = jnp.squeeze(logits, axis=0)
-
-    self.logits = nnx.data(logits)  # cache logits for compute_logits call
-
-    return kv_caches, hidden, aux_hidden_states
-
-  def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
-    """Computes the logits from the hidden states.
-
-    Args:
-      hidden_states: A JAX array of hidden states.
-
-    Returns:
-      A JAX array of logits (Q, vocab_size).
-    """
-    if self.logits is not None:
-      return self.logits
-
-    with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
-      embeddings = self.model.token_embedder
-      return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
-
-  def load_weights(self, rng_key: jax.Array) -> None:
-    """Loads model parameters on the provided mesh.
-
-    Args:
-      rng_key: A JAX random key for model initialization.
-    """
-    if self.model is not None:
-      return
-
-    with nn.logical_axis_rules(""):
-      model, _ = model_creation_utils.create_nnx_model(
-          self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
-      )
-      self.model = nnx.data(model)
-
-
-class MaxTextForCausalLM(nnx.Module):
-  """A vLLM-compatible causal language model wrapper for MaxText.
-
-  This class serves as the primary interface for integrating MaxText models
-  into the vLLM serving framework, specifically for causal language modeling
-  tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected
-  by vLLM.
-  """
-
-  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
-    """Initializes the MaxTextForCausalLM model.
+    if hidden.ndim > 1:
+      hidden = hidden.squeeze(0)
 
-    Args:
-      vllm_config: The vLLM configuration object.
-      rng_key: A JAX random key for model initialization.
-      mesh: The JAX mesh device for model sharding.
-    """
-    self.cfg = vllm_config.model_config
-    self.mesh = mesh
-    self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh)
-    self.is_text_generation_model = True
-
-  def __call__(
-      self, kv_caches: list[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, *args, **kwargs
-  ) -> tuple[list[jax.Array], jax.Array]:
-    """Performs a forward pass through the causal language model.
-
-    Args:
-      kv_caches: A list of JAX arrays representing the KV caches.
-      input_ids: A JAX array of input token IDs.
-      attention_metadata: Attention metadata for the decoding process.
-      *args: Variable length argument list.
-      **kwargs: Arbitrary keyword arguments.
-
-    Returns:
-      A tuple containing:
-        - updated_kv_caches: A list of updated KV caches.
-        - hidden: The hidden states.
-        - aux_hidden_states: A list of auxiliary hidden states.
-    """
-    with self.mesh:
-      kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
     return kv_caches, hidden, aux_hidden_states
 
   def forward(self, *args, **kwargs):
@@ -256,8 +177,11 @@ def get_input_embeddings(self) -> jax.Array:
     Returns:
       A JAX array representing the input embeddings.
     """
-    with self.mesh:
-      return self.model.model.token_embedder.embedding
+    if not isinstance(self.model, nnx.Module):
+      raise ValueError("Model is not initialized.")
+
+    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
+      return self.model.token_embedder.embedding
 
   def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
     """Embeds the input token IDs using the model's token embedder.
@@ -268,8 +192,11 @@ def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
     Returns:
       A JAX array of embedded input tokens.
     """
-    with self.mesh:
-      return self.model.model.token_embedder(input_ids)
+    if not isinstance(self.model, nnx.Module):
+      raise ValueError("Model is not initialized.")
+
+    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
+      return self.model.token_embedder(input_ids)
 
   def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
    """Computes the logits from the hidden states using the underlying decoder model.
@@ -280,8 +207,12 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
     Returns:
       A JAX array of logits.
     """
-    with self.mesh:
-      return self.model.compute_logits(hidden_states)
+    if not isinstance(self.model, nnx.Module):
+      raise ValueError("Model is not initialized.")
+
+    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
+      embeddings = self.model.token_embedder
+      return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
 
   def load_weights(self, rng_key: jax.Array) -> None:
     """Loads model weights using the underlying decoder model.
@@ -289,5 +220,11 @@ def load_weights(self, rng_key: jax.Array) -> None:
     Args:
       rng_key: A JAX random key for model initialization.
     """
-    with self.mesh:
-      self.model.load_weights(rng_key)
+    if self.model is not None:
+      return
+
+    with self.mesh, nn.logical_axis_rules(""):
+      model, _ = model_creation_utils.create_nnx_model(
+          self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
+      )
+      self.model = nnx.data(model)
diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py
index fc063415f4..4f712d84be 100644
--- a/src/MaxText/layers/decoders.py
+++ b/src/MaxText/layers/decoders.py
@@ -902,11 +902,16 @@ def __call__(
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
     hidden_state = y
 
+    # When invoked from vLLM with RPA attention, logit computation is deferred to a later stage.
+    if cfg.attention == "vllm_rpa":
+      logits = None
+
     # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
     # Instead, we keep track on the hidden states, which has smaller size compared to full logits
-    if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
+    elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
       logits = None
       self.sow("intermediates", "hidden_states", hidden_state)
+
     else:
       logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
 
diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py
index 07c46be53e..d79e121412 100644
--- a/src/MaxText/layers/models.py
+++ b/src/MaxText/layers/models.py
@@ -211,7 +211,7 @@ def __call__(
 
     if self.config.attention == "vllm_rpa":
       # In vLLM, logits are computed separately after updating the KV cache.
-      return logits, hidden_state, kv_caches
+      return hidden_state, kv_caches
 
     return logits
 
@@ -514,7 +514,7 @@ def __call__(
 
     if self.config.attention == "vllm_rpa":
       # In vLLM, logits are computed separately after updating the KV cache.
-      return logits, hidden_state, kv_caches
+      return hidden_state, kv_caches
 
     return logits
 
diff --git a/src/MaxText/vllm_decode.py b/src/MaxText/vllm_decode.py
index 69c889e466..3f80e7268f 100644
--- a/src/MaxText/vllm_decode.py
+++ b/src/MaxText/vllm_decode.py
@@ -26,17 +26,13 @@
 Or without Tunix using the MaxText vLLM integration:
 
 python3 -m MaxText.vllm_decode \
-  --model-name qwen3-30b-a3b \
-  --hf-model-name Qwen/Qwen3-30B-A3B \
-  --hf-config-path src/MaxText/integration/vllm/maxtext_vllm_adapter \
-  --load-parameters-path \
-  --ici_data_parallelism 1 \
-  --ici-tensor-parallelism 4 \
-  --ici-expert-parallelism 1 \
-  --max-model-len 4096 \
-  --max-num-batched-tokens 262144 \
-  --gpu-memory-utilization 0.5 \
-  --prompt "Suggest some famous landmarks in London." \
+  --model_name qwen3-30b-a3b \
+  --hf_model_name Qwen/Qwen3-30B-A3B \
+  --hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
+  --load_parameters_path \
+  --ici_tensor_parallelism 4 \
+  --gpu_memory_utilization 0.5 \
+  --prompt "Suggest some famous landmarks in London."
 """
 
 import os
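
For reference, a minimal sketch of the decode flow this patch converges on, to make the new call pattern concrete. The setup helpers (make_vllm_config, init_kv_caches, make_attention_metadata) and the mesh shape are hypothetical placeholders for the real vLLM engine plumbing; only the MaxTextForCausalLM methods shown come from the patched adapter.

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh

    # Placeholder setup; in practice vLLM constructs these objects.
    vllm_config = make_vllm_config("qwen3-30b-a3b")   # hypothetical helper
    mesh = Mesh(jax.devices(), axis_names=("data",))  # illustrative mesh shape
    rng = jax.random.PRNGKey(0)

    model = MaxTextForCausalLM(vllm_config, rng, mesh)
    model.load_weights(rng)  # returns early if __init__ already built the model

    kv_caches = init_kv_caches()                    # hypothetical helper
    attention_metadata = make_attention_metadata()  # hypothetical helper
    input_ids = jnp.array([1, 2, 3])                # 1-D; __call__ adds a batch dim

    # After this patch, __call__ returns hidden states rather than logits...
    kv_caches, hidden, aux_hidden_states = model(kv_caches, input_ids, attention_metadata)
    # ...and logits are computed separately, after the KV cache update.
    logits = model.compute_logits(hidden)
    next_token = jnp.argmax(logits[-1])

Deferring compute_logits keeps the vocab-sized output projection out of the cache-updating forward pass, which is the same motivation behind the new vllm_rpa branches in decoders.py and models.py.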