149 changes: 43 additions & 106 deletions src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -14,9 +14,8 @@

"""vLLM adapter for MaxText models."""

import jax
import jax.numpy as jnp
import os
import jax

from flax import nnx
import flax.linen as nn
@@ -72,37 +71,38 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
return maxtext_config


class MaxTextDecoderModel(nnx.Module):
"""A vLLM-compatible decoder model wrapper for MaxText.
class MaxTextForCausalLM(nnx.Module):
"""A vLLM-compatible causal language model wrapper for MaxText.

This class adapts a MaxText model for use within the vLLM framework,
handling configuration generation, model initialization, and execution
This class serves as the primary interface for integrating MaxText models
into the vLLM serving framework, specifically for causal language modeling
tasks. It handles configuration generation, model initialization, and execution
of the decoding step.
"""

def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
"""Initializes the MaxTextDecoderModel.
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""Initializes the MaxTextForCausalLM model.

Args:
vllm_config: The vLLM configuration object.
rng_key: A JAX random key for model initialization.
mesh: The JAX mesh device for model sharding.
"""
self.vllm_config = vllm_config
self.cfg = vllm_config.model_config
self.maxtext_config = generate_maxtext_config(vllm_config)

# Model configuration
self.mesh = mesh
self.model_mode = MODEL_MODE_AUTOREGRESSIVE
self.is_text_generation_model = True

# Model creation
self.model: nnx.Module | None = None
self.logits: jax.Array | None = None

# Handle dummy weight loading during initialization
if vllm_config.load_config.load_format == "dummy":
with self.mesh:
self.load_weights(rng_key)
self.load_weights(rng_key)

elif self.maxtext_config.load_parameters_path is None:
max_logging.log("Warning: No load_parameters_path provided. The model will be initialized with random weights.")
@@ -115,7 +115,7 @@ def __call__(
*args,
**kwargs,
) -> tuple[list[jax.Array], jax.Array, list[jax.Array]]:
"""Performs a forward pass through the decoder model.
"""Performs a forward pass through the causal language model.

Args:
kv_caches: A list of JAX arrays representing the KV caches.
@@ -127,7 +127,7 @@
Returns:
A tuple containing:
- updated_kv_caches: A list of updated KV caches.
- hidden: The hidden states (Q, d_model).
- hidden: The hidden states.
- aux_hidden_states: A list of auxiliary hidden states.

Raises:
@@ -137,15 +137,15 @@
raise ValueError("Model must be an instance of type nnx.Module.")

if input_ids.ndim < 2:
input_ids = jnp.expand_dims(input_ids, axis=0)
input_ids = input_ids[None, :]

input_positions = attention_metadata.input_positions
if input_positions.ndim < 2:
input_positions = jnp.expand_dims(input_positions, axis=0)
input_positions = input_positions[None, :]

with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
aux_hidden_states = []
logits, hidden, kv_caches = self.model(
hidden, kv_caches = self.model(
decoder_input_tokens=input_ids,
decoder_positions=input_positions,
kv_caches=kv_caches,
@@ -154,88 +154,9 @@ def __call__(
**kwargs,
)

if hidden.ndim > 1:
hidden = jnp.squeeze(hidden, axis=0)
logits = jnp.squeeze(logits, axis=0)

self.logits = nnx.data(logits) # cache logits for compute_logits call

return kv_caches, hidden, aux_hidden_states

def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
"""Computes the logits from the hidden states.

Args:
hidden_states: A JAX array of hidden states.

Returns:
A JAX array of logits (Q, vocab_size).
"""
if self.logits is not None:
return self.logits

with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
embeddings = self.model.token_embedder
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)

def load_weights(self, rng_key: jax.Array) -> None:
"""Loads model parameters on the provided mesh.

Args:
rng_key: A JAX random key for model initialization.
"""
if self.model is not None:
return

with nn.logical_axis_rules(""):
model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)


class MaxTextForCausalLM(nnx.Module):
"""A vLLM-compatible causal language model wrapper for MaxText.

This class serves as the primary interface for integrating MaxText models
into the vLLM serving framework, specifically for causal language modeling
tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected
by vLLM.
"""

def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""Initializes the MaxTextForCausalLM model.
if hidden.ndim > 1:
hidden = hidden.squeeze(0)

Args:
vllm_config: The vLLM configuration object.
rng_key: A JAX random key for model initialization.
mesh: The JAX mesh device for model sharding.
"""
self.cfg = vllm_config.model_config
self.mesh = mesh
self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh)
self.is_text_generation_model = True

def __call__(
self, kv_caches: list[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, *args, **kwargs
) -> tuple[list[jax.Array], jax.Array]:
"""Performs a forward pass through the causal language model.

Args:
kv_caches: A list of JAX arrays representing the KV caches.
input_ids: A JAX array of input token IDs.
attention_metadata: Attention metadata for the decoding process.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.

Returns:
A tuple containing:
- updated_kv_caches: A list of updated KV caches.
- hidden: The hidden states.
- aux_hidden_states: A list of auxiliary hidden states.
"""
with self.mesh:
kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
return kv_caches, hidden, aux_hidden_states

def forward(self, *args, **kwargs):
@@ -256,8 +177,11 @@ def get_input_embeddings(self) -> jax.Array:
Returns:
A JAX array representing the input embeddings.
"""
with self.mesh:
return self.model.model.token_embedder.embedding
if not isinstance(self.model, nnx.Module):
raise ValueError("Model is not initialized.")

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
return self.model.token_embedder.embedding

def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
"""Embeds the input token IDs using the model's token embedder.
@@ -268,8 +192,11 @@ def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
Returns:
A JAX array of embedded input tokens.
"""
with self.mesh:
return self.model.model.token_embedder(input_ids)
if not isinstance(self.model, nnx.Module):
raise ValueError("Model is not initialized.")

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
return self.model.token_embedder(input_ids)

def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
"""Computes the logits from the hidden states using the underlying decoder model.
@@ -280,14 +207,24 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
Returns:
A JAX array of logits.
"""
with self.mesh:
return self.model.compute_logits(hidden_states)
if not isinstance(self.model, nnx.Module):
raise ValueError("Model is not initialized.")

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
embeddings = self.model.token_embedder
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)

def load_weights(self, rng_key: jax.Array) -> None:
"""Loads model weights using the underlying decoder model.

Args:
rng_key: A JAX random key for model initialization.
"""
with self.mesh:
self.model.load_weights(rng_key)
if self.model is not None:
return

with self.mesh, nn.logical_axis_rules(""):
model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
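
For orientation, here is a minimal sketch (not part of the diff) of how the merged adapter class is exercised from a vLLM-side runner. It assumes MaxTextForCausalLM from the adapter above is in scope and that the runner supplies vllm_config, mesh, kv_caches, input_ids, and attention_metadata; the per-step flow simply chains the adapter methods shown above.

# Illustrative sketch only; external objects are assumed to come from the vLLM runner.
import jax


def run_decode(vllm_config, mesh, kv_caches, input_ids, attention_metadata):
  # Initialization: builds (or randomly initializes) the underlying nnx model.
  rng_key = jax.random.PRNGKey(0)
  model = MaxTextForCausalLM(vllm_config, rng_key, mesh)
  model.load_weights(rng_key)  # returns immediately if the model is already initialized

  # One decode step: the adapter now returns hidden states and updated KV caches;
  # logits are computed afterwards by applying the output head.
  kv_caches, hidden, _aux_hidden_states = model(kv_caches, input_ids, attention_metadata)
  logits = model.compute_logits(hidden)
  return logits, kv_caches
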
7 changes: 6 additions & 1 deletion src/MaxText/layers/decoders.py
@@ -902,11 +902,16 @@ def __call__(
# After the final transformer layer, `y` holds the raw, un-normalized hidden state.
hidden_state = y

# When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
if cfg.attention == "vllm_rpa":
logits = None

# When vocab tiling is enabled in training mode, we skip generating the full logits to reduce memory.
# Instead, we keep track of the hidden states, which are smaller than the full logits.
if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow("intermediates", "hidden_states", hidden_state)

else:
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)

4 changes: 2 additions & 2 deletions src/MaxText/layers/models.py
@@ -211,7 +211,7 @@ def __call__(

if self.config.attention == "vllm_rpa":
# In vLLM, logits are computed separately after updating the KV cache.
return logits, hidden_state, kv_caches
return hidden_state, kv_caches

return logits

@@ -514,7 +514,7 @@ def __call__(

if self.config.attention == "vllm_rpa":
# In vLLM, logits are computed separately after updating the KV cache.
return logits, hidden_state, kv_caches
return hidden_state, kv_caches

return logits

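
A short sketch (not part of the diff) of what this return-signature change means for callers: under attention == "vllm_rpa" the Transformer now yields two values instead of three. The transformer, input_ids, input_positions, and kv_caches names below are assumptions for illustration; any extra attention kwargs are passed through unchanged.

# Sketch only: callers on the vllm_rpa path unpack (hidden_state, kv_caches).
def forward_vllm_rpa(transformer, input_ids, input_positions, kv_caches, **attention_kwargs):
  hidden_state, kv_caches = transformer(
      decoder_input_tokens=input_ids,
      decoder_positions=input_positions,
      kv_caches=kv_caches,
      **attention_kwargs,
  )
  # Logits are no longer returned here; MaxTextForCausalLM.compute_logits()
  # applies the output head to hidden_state afterwards.
  return hidden_state, kv_caches
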
18 changes: 7 additions & 11 deletions src/MaxText/vllm_decode.py
@@ -26,17 +26,13 @@

Or without Tunix using the MaxText vLLM integration:
python3 -m MaxText.vllm_decode \
--model-name qwen3-30b-a3b \
--hf-model-name Qwen/Qwen3-30B-A3B \
--hf-config-path src/MaxText/integration/vllm/maxtext_vllm_adapter \
--load-parameters-path <your_checkpoint_path> \
--ici_data_parallelism 1 \
--ici-tensor-parallelism 4 \
--ici-expert-parallelism 1 \
--max-model-len 4096 \
--max-num-batched-tokens 262144 \
--gpu-memory-utilization 0.5 \
--prompt "Suggest some famous landmarks in London." \
--model_name qwen3-30b-a3b \
--hf_model_name Qwen/Qwen3-30B-A3B \
--hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
--load_parameters_path <your_checkpoint_path> \
--ici_tensor_parallelism 4 \
--gpu_memory_utilization 0.5 \
--prompt "Suggest some famous landmarks in London."
"""

import os