diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index 39e4bfbf8..0baa01272 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -672,7 +672,9 @@ def __call__(
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Optional[Any] = None,
       rope_kwargs: dict | None = None,
-  ) -> Array:
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
+  ) -> tuple[Array, Optional[Array]]:
     """Forward pass for MLA, reusing `AttentionOp` for the actual attention.

     Args:
@@ -686,6 +688,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional key-value cache used when serving models with vLLM.
+      attention_metadata: Optional attention-related metadata used when serving models with vLLM.

     Returns:
       A tensor of shape [batch, length, embed_dim] containing the
@@ -726,4 +730,4 @@ def __call__(
     out = self.out_projection(out)
     out = checkpoint_name(out, "out_proj")

-    return out
+    return out, kv_cache
diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py
index 59a2d413d..13d8c8c29 100644
--- a/src/MaxText/layers/attention_op.py
+++ b/src/MaxText/layers/attention_op.py
@@ -847,11 +847,14 @@ def apply_attention(
         raise NotImplementedError(target_hardware)
       return impl(query, key, value, lengths, self.ragged_block_size)
+    # 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM
+    # ragged paged attention kernel in `Attention.__call__`.
     elif (
         self.attention_kernel == "dot_product"
         or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
         or (self.attention_kernel == "autoselected" and length < 128)
         or (self.attention_kernel == "paged")
+        or (self.attention_kernel == "vllm_rpa")
     ):
       return self.apply_attention_dot(
           query,
diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py
index 44e4160e0..0586953de 100644
--- a/src/MaxText/layers/attentions.py
+++ b/src/MaxText/layers/attentions.py
@@ -889,6 +889,51 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
     )
     return [prefill_kv_cache, ar_kv_cache]

+  def forward_serve_vllm(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+      rpa_kv_cache: list[Array] | None = None,
+      rpa_metadata: dict[str, Any] | None = None,
+  ) -> tuple[list[Array], Array]:
+    """Forward function for vLLM serving with RPA attention."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      # pytype: disable=import-error
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    if rpa_kv_cache is None or rpa_metadata is None:
+      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
+
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
@@ -904,6 +949,8 @@ def __call__(
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
       rope_kwargs: dict | None = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.

@@ -931,6 +978,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.

     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1026,6 +1075,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
+    elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1054,4 +1112,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache
diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py
index 655ff0ac4..90a7557ff 100644
--- a/src/MaxText/layers/decoders.py
+++ b/src/MaxText/layers/decoders.py
@@ -87,6 +87,8 @@ def __call__(
       previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
+      kv_cache: jax.Array | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
         model_mode=model_mode,
     )

-    attention_lnx = attention_layer(
+    attention_lnx, kv_cache = attention_layer(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,10 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )

-    return layer_output, None if cfg.scan_layers else layer_output
+    if cfg.scan_layers:
+      return layer_output, None
+    else:
+      return layer_output, kv_cache


 class SequentialBlockDecoderLayers(nn.Module):
@@ -691,6 +698,8 @@ def __call__(
       bidirectional_mask: None | Any = None,
       image_embeddings: None | jnp.ndarray = None,
       image_masks: None | jnp.ndarray = None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -844,7 +853,8 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
-          y = layer(
+          kv_cache = kv_caches[index] if kv_caches is not None else None
+          y, kv_cache = layer(
               config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
           )(
               y,
@@ -855,7 +865,11 @@ def __call__(
               previous_chunk=previous_chunk,
               page_state=page_state,
               slot=slot,
+              kv_cache=kv_cache,
+              attention_metadata=attention_metadata,
           )
+          if kv_caches is not None and kv_cache is not None:
+            kv_caches[index] = kv_cache
     else:
       for lyr in range(cfg.num_decoder_layers):
         RemattedBlockLayer = RemattedBlockLayers[0]
@@ -877,7 +891,8 @@ def __call__(
         layer = RemattedBlockLayer(
            config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
         )
-        y = layer(
+        kv_cache = kv_caches[lyr] if kv_caches is not None else None
+        y, kv_cache = layer(
             y,
             decoder_segment_ids,
             decoder_positions,
@@ -886,8 +901,12 @@ def __call__(
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
             **layer_call_kwargs,
         )
+        if kv_caches is not None and kv_cache is not None:
+          kv_caches[lyr] = kv_cache

     assert isinstance(y, jax.Array)

@@ -904,7 +923,7 @@ def __call__(

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
-    return logits, hidden_state
+    return logits, hidden_state, kv_caches

   def _apply_gemma3_scanned_blocks(
       self,
@@ -957,10 +976,9 @@ def _apply_gemma3_scanned_blocks(
     if num_remaining_layers > 0:
       # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
       rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
-      # pytype: disable=wrong-keyword-args
       layer = RemattedGemma3Block(
           config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
-      )
+      )  # pytype: disable=wrong-keyword-args
       y, _ = layer(
           y,
           decoder_segment_ids,
diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py
index 0883dee09..5eb91fc5f 100644
--- a/src/MaxText/layers/deepseek.py
+++ b/src/MaxText/layers/deepseek.py
@@ -99,7 +99,7 @@ def self_attention_with_norm(
       model_mode=model_mode,
   )

-  attention_lnx = attention_layer(
+  attention_lnx, _ = attention_layer(
       lnx,
       lnx,
       decoder_positions,
@@ -127,7 +127,7 @@ def self_attention_with_norm(
   return hidden_states, intermediate_inputs


-def post_process(cfg, layer_output, sow):
+def post_process(cfg, layer_output, sow, kv_cache=None):
   """postprocessing."""
   if cfg.record_internal_nn_metrics:
     sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -141,7 +141,7 @@ def post_process(cfg, layer_output, sow):
   if cfg.scan_layers:
     return layer_output, None
   else:
-    return layer_output
+    return layer_output, kv_cache


 class DeepSeekDenseLayer(nn.Module):
@@ -163,6 +163,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
@@ -230,6 +232,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     if model_mode == MODEL_MODE_PREFILL:
diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py
index bfea1fc04..5917b522c 100644
--- a/src/MaxText/layers/deepseek_batchsplit.py
+++ b/src/MaxText/layers/deepseek_batchsplit.py
@@ -58,6 +58,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     x = self.with_logical_constraint(inputs)
     x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")
@@ -74,7 +76,7 @@ def __call__(
     x += self.mlp(self.post_attention_norm(x), deterministic)
     x = self.dropout(x, deterministic)

-    return self.post_process(x)
+    return self.post_process(x, kv_cache)

   def setup(self):
     self.pre_attention_norm_op = self.rms_norm_layer("pre_attention_layer_norm")
@@ -177,7 +179,7 @@ def attention(
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
-        )
+        )[0]
     )

   def mlp_layer(self):
@@ -194,7 +196,7 @@ def dropout(self, x, deterministic):
         self.dropout_op(x, deterministic=deterministic)
     )

-  def post_process(self, x):
+  def post_process(self, x, kv_cache=None):
     """Collect statistics about the output of the layer."""
     if self.config.record_internal_nn_metrics:
       self.sow("intermediates", "activation_mean", jnp.mean(x))
@@ -208,7 +210,7 @@ def post_process(self, x):
     if self.config.scan_layers:
       return x, None
     else:
-      return x
+      return x, kv_cache


 class DeepSeekDenseLayer(DeepSeekGenericLayer):
@@ -245,6 +247,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
       split_factor: int = 2,
   ):
     x = self.with_logical_constraint(inputs)
@@ -289,7 +293,7 @@ def _moe(x):
     x = _merge(x)
     x = self.dropout(x, deterministic)

-    return self.post_process(x)
+    return self.post_process(x, kv_cache)

   def init(self, *args, **kwargs):
     # Calls the parent init method for testing parity.
diff --git a/src/MaxText/layers/gemma.py b/src/MaxText/layers/gemma.py
index 8c9aa51ff..dcd237162 100644
--- a/src/MaxText/layers/gemma.py
+++ b/src/MaxText/layers/gemma.py
@@ -129,6 +129,8 @@ def __call__(
       page_manager=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(

     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/gemma2.py b/src/MaxText/layers/gemma2.py
index fe58c463c..3d0d39efe 100644
--- a/src/MaxText/layers/gemma2.py
+++ b/src/MaxText/layers/gemma2.py
@@ -223,6 +223,8 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -230,13 +232,15 @@ def __call__(
     lnx = self.pre_self_attention_norm_local(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention_local(
+    attention_lnx, kv_cache = self.self_attention_local(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if self.config.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm_local(attention_lnx)
@@ -268,7 +272,7 @@ def __call__(
     lnx = self.pre_self_attention_norm_global(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention_global(
+    attention_lnx, kv_cache = self.self_attention_global(
         lnx,
         lnx,
         decoder_positions,
@@ -311,7 +315,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/gemma3.py b/src/MaxText/layers/gemma3.py
index 645637b2c..fdf78f09e 100644
--- a/src/MaxText/layers/gemma3.py
+++ b/src/MaxText/layers/gemma3.py
@@ -189,6 +189,8 @@ def __call__(
       page_state=None,
       slot=None,
       bidirectional_mask=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
@@ -198,7 +200,7 @@ def __call__(
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

     # Self-attention block
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
@@ -206,6 +208,8 @@ def __call__(
         deterministic=deterministic,
         model_mode=model_mode,
         bidirectional_mask=bidirectional_mask,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if cfg.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm(attention_lnx)
@@ -240,7 +244,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 Gemma3DecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/gpt3.py b/src/MaxText/layers/gpt3.py
index d961e395b..831677583 100644
--- a/src/MaxText/layers/gpt3.py
+++ b/src/MaxText/layers/gpt3.py
@@ -271,6 +271,8 @@ def __call__(
       *,
       model_mode: str = MODEL_MODE_TRAIN,
       deterministic: bool = False,
+      kv_cache: Array | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
     if self.fused_qkv:
@@ -312,7 +314,7 @@ def __call__(
     # apply output projection, output dim is set to the input dim.
     out = self.out_projection(inputs_q.shape[-1], out)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache


 # -----------------------------------------
@@ -339,6 +341,8 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -381,8 +385,13 @@ def __call__(
         kv_quant=quantizations.configure_kv_quant(cfg),
     )

-    attention_lnx = attention_layer(
-        lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic
+    attention_lnx, kv_cache = attention_layer(
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        model_mode=model_mode,
+        deterministic=deterministic,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(
@@ -428,4 +437,4 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache
diff --git a/src/MaxText/layers/gpt_oss.py b/src/MaxText/layers/gpt_oss.py
index 432738cd9..1301a46b9 100644
--- a/src/MaxText/layers/gpt_oss.py
+++ b/src/MaxText/layers/gpt_oss.py
@@ -145,6 +145,8 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config

@@ -154,13 +156,15 @@ def __call__(
     lnx = self.pre_self_attention_layer_norm(inputs)
     lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

-    attention_lnx = self.GptOssAttention(
+    attention_lnx, kv_cache = self.GptOssAttention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(
@@ -201,7 +205,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 GptOssDecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/llama2.py b/src/MaxText/layers/llama2.py
index 05be82a2e..7148b1a5b 100644
--- a/src/MaxText/layers/llama2.py
+++ b/src/MaxText/layers/llama2.py
@@ -147,6 +147,8 @@ def __call__(
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
       previous_chunk=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config

@@ -157,7 +159,7 @@ def __call__(
     lnx = self._maybe_shard_with_logical(lnx, self.activation_axis_names)

     # Self-attention block
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
@@ -168,6 +170,8 @@ def __call__(
         page_state=page_state,
         previous_chunk=previous_chunk,
         out_sharding=lnx_sharding,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = self._maybe_shard_with_logical(attention_lnx, self.activation_axis_names)
@@ -206,7 +210,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/llama4.py b/src/MaxText/layers/llama4.py
index 90dab1d78..9db76ae5b 100644
--- a/src/MaxText/layers/llama4.py
+++ b/src/MaxText/layers/llama4.py
@@ -448,6 +448,8 @@ def __call__(
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
       previous_chunk=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."
@@ -459,7 +461,7 @@ def __call__(
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

     # Self-attention block
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
@@ -469,6 +471,8 @@ def __call__(
         slot=slot,
         page_state=page_state,
         previous_chunk=previous_chunk,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
     intermediate_inputs = inputs + attention_lnx
@@ -499,7 +503,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 Llama4DecoderLayerToLinen = nnx_wrappers.to_linen_class(
@@ -654,7 +658,7 @@ def __call__(
   ):
     residual = hidden_states
     hidden_states = self.input_layer_norm(hidden_states)
-    hidden_states = self.self_attention_vision(
+    hidden_states, _ = self.self_attention_vision(
         inputs_q=hidden_states,
         inputs_kv=hidden_states,
         deterministic=deterministic,
diff --git a/src/MaxText/layers/mistral.py b/src/MaxText/layers/mistral.py
index db0888f96..643fecaae 100644
--- a/src/MaxText/layers/mistral.py
+++ b/src/MaxText/layers/mistral.py
@@ -132,6 +132,8 @@ def __call__(
      page_state: None | int = None,
      slot: None | int = None,
      previous_chunk=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config

@@ -141,7 +143,7 @@ def __call__(
     lnx = self.pre_self_attention_layer_norm(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
@@ -151,6 +153,8 @@ def __call__(
         slot=slot,
         page_state=page_state,
         previous_chunk=previous_chunk,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -180,7 +184,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 MistralDecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/mixtral.py b/src/MaxText/layers/mixtral.py
index 1a2cfa7e3..8d23e72d3 100644
--- a/src/MaxText/layers/mixtral.py
+++ b/src/MaxText/layers/mixtral.py
@@ -137,6 +137,8 @@ def __call__(
      previous_chunk=None,
      page_state=None,
      slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)

@@ -145,7 +147,7 @@ def __call__(
     lnx = self.pre_self_attention_layer_norm(inputs)

     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
@@ -153,6 +155,8 @@ def __call__(
         deterministic=deterministic,
         model_mode=model_mode,
         previous_chunk=previous_chunk,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -188,7 +192,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 MixtralDecoderLayerToLinen = nnx_wrappers.to_linen_class(
diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py
index 0a736b538..d4722af76 100644
--- a/src/MaxText/layers/models.py
+++ b/src/MaxText/layers/models.py
@@ -19,6 +19,7 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
+from typing import Any

 from flax import linen as nn
 from flax import nnx
@@ -127,6 +128,8 @@ def __call__(
       decoder_target_tokens: None | jnp.ndarray = None,
       decoder_target_mask: None | jnp.ndarray = None,
       nnx_method=None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     """Applies Transformer decoder-branch on encoded-input and target.

@@ -154,7 +157,7 @@ def __call__(
     elif self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
       bidirectional_mask = decoder_input_tokens == multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN

-    logits, hidden_state = self.decoder(
+    logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.shared_embedding,
         decoder_input_tokens=decoder_input_tokens,
         decoder_positions=decoder_positions,
@@ -167,6 +170,8 @@ def __call__(
         bidirectional_mask=bidirectional_mask,
         image_embeddings=image_embeddings,
         image_masks=encoder_image_masks,
+        kv_caches=kv_caches,
+        attention_metadata=attention_metadata,
     )

     # If we are initializing the model AND MTP is enabled, we must create
@@ -201,6 +206,10 @@ def __call__(
           model_mode=model_mode,
       )

+    if self.config.attention == "vllm_rpa":
+      # In vLLM, logits are computed separately after updating the KV cache.
+      return logits, hidden_state, kv_caches
+
     return logits


@@ -306,10 +315,30 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str
     dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
     dummy_decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

+    if self.config.attention == "vllm_rpa":
+      try:
+        # pylint: disable=import-outside-toplevel
+        # pytype: disable=import-error
+        from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+      except ImportError as e:
+        raise ImportError(
+            "vLLM RPA attention requires the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+        ) from e
+      dummy_attention_metadata = AttentionMetadata(
+          input_positions=jnp.ones((batch_size * seq_len,), dtype=jnp.int32),
+          block_tables=jnp.ones((seq_len,), dtype=jnp.int32),
+          seq_lens=jnp.ones((1), dtype=jnp.int32),
+          query_start_loc=jnp.ones((2), dtype=jnp.int32),
+          request_distribution=jnp.ones((3), dtype=jnp.int32),
+      )
+    else:
+      dummy_attention_metadata = None
+
     self.decoder.lazy_init(
         shared_embedding=self.token_embedder,
         decoder_input_tokens=dummy_decoder_input_tokens,
         decoder_positions=dummy_decoder_positions,
+        attention_metadata=dummy_attention_metadata,
     )

     # If MTP is enabled via config, set up the MTP block.
@@ -368,6 +397,8 @@ def __call__(
       page_state: page_manager.PageState | None = None,
       decoder_target_tokens: jax.Array | None = None,
       decoder_target_mask: jax.Array | None = None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     """Applies the Zero-1 FSDP wrapped Transformer model.

@@ -388,9 +419,11 @@ def __call__(
       decoder_target_tokens: Target tokens for the decoder (optional, used in MTP).
       decoder_target_mask: Target mask for the decoder (optional, used in MTP).
       nnx_method: Method to call on the NNX module (optional).
+      kv_caches: List of KV caches for each attention layer, used when invoking from vLLM (optional).
+      attention_metadata: Mapping to store attention metadata, used when invoking from vLLM (optional).

     Returns:
-      Logits from the Transformer model.
+      Logits from the Transformer model, or a tuple of (logits, hidden_state, kv_caches) when called from vLLM.
     """
     if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
       raise ValueError(
@@ -410,7 +443,7 @@ def __call__(
     elif self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
       bidirectional_mask = decoder_input_tokens == multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN

-    logits, hidden_state = self.decoder(
+    logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.token_embedder,
         decoder_input_tokens=decoder_input_tokens,
         decoder_positions=decoder_positions,
@@ -423,6 +456,8 @@ def __call__(
         bidirectional_mask=bidirectional_mask,
         image_embeddings=image_embeddings,
         image_masks=encoder_image_masks,
+        kv_caches=kv_caches,
+        attention_metadata=attention_metadata,
     )

     # Materialize hidden state when vocab tiling is enabled
@@ -461,6 +496,10 @@ def __call__(
           model_mode=model_mode,
       )

+    if self.config.attention == "vllm_rpa":
+      # In vLLM, logits are computed separately after updating the KV cache.
+      return logits, hidden_state, kv_caches
+
     return logits


diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index 77190c36f..7cb4df407 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -386,16 +386,16 @@ def __init__(
       self.wo_bias = None

   def get_expert_parallelism_size(self):
-    return self.mesh.shape["expert"]
+    return self.mesh.shape.get("expert", 1)

   def get_tensor_parallelism_size(self):
-    return self.mesh.shape["tensor"]
+    return self.mesh.shape.get("tensor", 1)

   def get_tensor_transpose_parallelism_size(self):
-    return self.mesh.shape["tensor_transpose"]
+    return self.mesh.shape.get("tensor_transpose", 1)

   def get_context_autoregressive_parallelism_size(self):
-    return self.mesh.shape["context_autoregressive"]
+    return self.mesh.shape.get("context_autoregressive", 1)

   def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
     """get topk."""
@@ -940,7 +940,11 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
       expert_axis_name = "expert"
-      expert_shard_id = jax.lax.axis_index(expert_axis_name)
+      num_expert_parallelism = self.get_expert_parallelism_size()
+      if num_expert_parallelism > 1:
+        expert_shard_id = jax.lax.axis_index(expert_axis_name)
+      else:
+        expert_shard_id = 0
       num_expert_parallelism = self.get_expert_parallelism_size()
       if self.config.use_ring_of_experts:
         # The ring-of-experts strategy first duplicates the inputs to all
diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py
index 3bb21e41e..5898d422f 100644
--- a/src/MaxText/layers/qwen3.py
+++ b/src/MaxText/layers/qwen3.py
@@ -16,7 +16,7 @@
 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module

-from typing import cast
+from typing import Any, cast

 import jax
 import jax.nn
@@ -540,16 +540,20 @@ def __call__(
       decoder_positions: None | jnp.ndarray,
       deterministic: bool,
       model_mode: str,
+      kv_cache: None | jnp.ndarray = None,
+      attention_metadata: None | dict[str, Any] = None,
   ):
-    attention_output = self.attention(
+    attention_output, kv_cache = self.attention(
         inputs_q=inputs,
         inputs_kv=inputs,
         inputs_positions=decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
-    return attention_output
+    return attention_output, kv_cache


 class Qwen3NextSparseMoeBlock(nnx.Module):
@@ -795,6 +799,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache: None | jnp.ndarray = None,
+      attention_metadata: None | dict[str, Any] = None,
   ):
     residual = inputs

@@ -804,12 +810,14 @@ def __call__(

     # Conditionally apply either the Linear Attention or Full Attention block.
     if isinstance(self.attention, Qwen3NextFullAttention):
-      attention_output = cast(Qwen3NextFullAttention, self.attention)(
+      attention_output, kv_cache = cast(Qwen3NextFullAttention, self.attention)(
          hidden_states,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
+          kv_cache=kv_cache,
+          attention_metadata=attention_metadata,
       )
     elif isinstance(self.attention, Qwen3NextGatedDeltaNet):
       attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(hidden_states)
@@ -842,7 +850,7 @@ def __call__(
         self.activation_axis_names,
     )

-    return layer_output
+    return layer_output, kv_cache


 # -----------------------------------------
@@ -922,6 +930,8 @@ def apply_attention_with_norm(
       decoder_positions: None | jnp.ndarray,
       deterministic: bool,
       model_mode: str,
+      kv_cache: None | jnp.ndarray = None,
+      attention_metadata: None | dict[str, Any] = None,
   ):
     """Applies self-attention with pre and post-layer normalization."""
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
@@ -930,13 +940,15 @@ def apply_attention_with_norm(
     lnx = self.pre_self_attention_layer_norm(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
     # Self attention
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
     # Residual connection after attention
@@ -944,7 +956,7 @@ def apply_attention_with_norm(
     # Post attention norm
     hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
     hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)
-    return hidden_states, intermediate_inputs
+    return hidden_states, intermediate_inputs, kv_cache


 # -----------------------------------------
@@ -986,9 +998,17 @@ def __call__(
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
+      kv_cache: None | jnp.ndarray = None,
+      attention_metadata: None | dict[str, Any] = None,
   ):
-    hidden_states, intermediate_inputs = self.apply_attention_with_norm(
-        inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode
+    hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
+        inputs,
+        decoder_segment_ids,
+        decoder_positions,
+        deterministic,
+        model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
@@ -1000,7 +1020,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 # -----------------------------------------
@@ -1042,9 +1062,17 @@ def __call__(
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
+      kv_cache: None | jnp.ndarray = None,
+      attention_metadata: None | dict[str, Any] = None,
   ):
-    hidden_states, intermediate_inputs = self.apply_attention_with_norm(
-        inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode
+    hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
+        inputs,
+        decoder_segment_ids,
+        decoder_positions,
+        deterministic,
+        model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )

     mlp_lnx, load_balance_loss = self.moe_block(hidden_states)
@@ -1058,7 +1086,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache


 class Qwen3OmniMoeVisionPatchMerger(nnx.Module):
@@ -1391,7 +1419,7 @@ def __call__(
         "height": height,
         "width": width,
     }
-    output = self.attn(
+    output, _ = self.attn(
         inputs_q=hidden_states,
         inputs_kv=hidden_states,
         deterministic=deterministic,
diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py
index 5b4c0df85..9f47b15c5 100644
--- a/src/MaxText/pyconfig_deprecated.py
+++ b/src/MaxText/pyconfig_deprecated.py
@@ -99,7 +99,15 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None:


 def validate_attention_kernel(s: str) -> None:
-  valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te", "cudnn_flash_jax", "paged")
+  valid_attention_kernels = (
+      "autoselected",
+      "dot_product",
+      "flash",
+      "cudnn_flash_te",
+      "cudnn_flash_jax",
+      "paged",
+      "vllm_rpa",
+  )
   if s not in valid_attention_kernels:  # currently supported attention
     raise ValueError("Invalid attention kernel was passed. Valid options ", valid_attention_kernels)

diff --git a/tests/attention_test.py b/tests/attention_test.py
index e9709a429..1d7c755d0 100644
--- a/tests/attention_test.py
+++ b/tests/attention_test.py
@@ -19,6 +19,7 @@
 import random
 import sys
 import unittest
+from unittest import mock

 import pytest

@@ -373,7 +374,7 @@ def test_autoregression(self):
     decode_total_length = self.cfg.max_target_length
     lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)

-    mha_full = self._attention_as_mha_generic(
+    mha_full, _ = self._attention_as_mha_generic(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -386,7 +387,7 @@ def test_autoregression(self):
     decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
     decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

-    mha_prefill = self._attention_as_mha_generic(
+    mha_prefill, _ = self._attention_as_mha_generic(
         lnx_prefill,
         lnx_prefill,
         decoder_segment_ids=decoder_segment_ids_prefill,
@@ -402,7 +403,7 @@ def test_autoregression(self):
     for idx in range(prefill_length, decode_total_length):
       lnx_idx = lnx[:, idx : idx + 1, :]
       decoder_positions_idx = decoder_positions[:, idx : idx + 1]
-      mha_idx = self._attention_as_mha_generic(
+      mha_idx, _ = self._attention_as_mha_generic(
          lnx_idx,
          lnx_idx,
          inputs_positions=decoder_positions_idx,
@@ -450,7 +451,7 @@ def _test_model_mode_prefill_dtype(self, dtype):
         rngs=self.nnx_rng,
     )

-    mha_prefill = attention_as_mha_generic(
+    mha_prefill, _ = attention_as_mha_generic(
         lnx_prefill,
         lnx_prefill,
         decoder_segment_ids=decoder_segment_ids_prefill,
@@ -498,7 +499,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):

     generic_state = nnx.state(attention_as_mha_generic)

-    mha_generic_output = attention_as_mha_generic(
+    mha_generic_output, _ = attention_as_mha_generic(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -526,7 +527,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
     )
     nnx.update(attention_as_mha_flash, generic_state)

-    mha_generic_flash_output = attention_as_mha_flash(
+    mha_generic_flash_output, _ = attention_as_mha_flash(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -596,7 +597,7 @@ def test_tpu_flash_attention_context_parallel(
     num_kv_heads = self.num_kv_heads
     lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype)
     # Dot product
-    mha_generic_output = self._attention_as_mha_generic(
+    mha_generic_output, _ = self._attention_as_mha_generic(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -714,7 +715,7 @@ def _dot_product_attention(
         model_mode=MODEL_MODE_PREFILL,
         rngs=self.nnx_rng,
     )

-    attention_w_layout_full = attention_w_layout(
+    attention_w_layout_full, _ = attention_w_layout(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -723,7 +724,7 @@ def _dot_product_attention(
         model_mode=MODEL_MODE_TRAIN,
     )

-    attention_w_layout_prefill = attention_w_layout(
+    attention_w_layout_prefill, _ = attention_w_layout(
         lnx_prefill,
         lnx_prefill,
         decoder_segment_ids=decoder_segment_ids_prefill,
@@ -739,7 +740,7 @@ def _dot_product_attention(
       lnx_idx = lnx[:, idx : idx + 1, :]
       decoder_positions_idx = decoder_positions[:, idx : idx + 1]

-      attention_w_layout_idx = attention_w_layout(
+      attention_w_layout_idx, _ = attention_w_layout(
          lnx_idx,
          lnx_idx,
          inputs_positions=decoder_positions_idx,
@@ -828,7 +829,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
     attention_wo_reshape_q_state = nnx.state(attention_wo_reshape_q)
     nnx.update(attention_w_reshape_q, attention_wo_reshape_q_state)

-    attention_wo_reshape_q_full = attention_wo_reshape_q(
+    attention_wo_reshape_q_full, _ = attention_wo_reshape_q(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -837,7 +838,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
         model_mode=MODEL_MODE_TRAIN,
     )

-    attention_w_reshape_q_full = attention_w_reshape_q(
+    attention_w_reshape_q_full, _ = attention_w_reshape_q(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -846,7 +847,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
         model_mode=MODEL_MODE_TRAIN,
     )

-    attention_wo_reshape_q_prefill = attention_wo_reshape_q(
+    attention_wo_reshape_q_prefill, _ = attention_wo_reshape_q(
         lnx_prefill,
         lnx_prefill,
         decoder_segment_ids=decoder_segment_ids_prefill,
@@ -860,7 +861,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
       )
     )

-    attention_w_reshape_q_prefill = attention_w_reshape_q(
+    attention_w_reshape_q_prefill, _ = attention_w_reshape_q(
         lnx_prefill,
         lnx_prefill,
         decoder_segment_ids=decoder_segment_ids_prefill,
@@ -887,7 +888,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
       lnx_idx = lnx[:, idx : idx + 1, :]
       decoder_positions_idx = decoder_positions[:, idx : idx + 1]

-      attention_wo_reshape_q_idx = attention_wo_reshape_q(
+      attention_wo_reshape_q_idx, _ = attention_wo_reshape_q(
          lnx_idx,
          lnx_idx,
          inputs_positions=decoder_positions_idx,
@@ -903,7 +904,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
        )
      )

-      attention_w_reshape_q_idx = attention_w_reshape_q(
+      attention_w_reshape_q_idx, _ = attention_w_reshape_q(
          lnx_idx,
          lnx_idx,
          inputs_positions=decoder_positions_idx,
@@ -974,7 +975,7 @@ def test_sliding_window_attention(self):
     sliding_attn_state = nnx.state(sliding_attn)
     nnx.update(global_attn, sliding_attn_state)

-    global_attn_output = global_attn(
+    global_attn_output, _ = global_attn(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -983,7 +984,7 @@ def test_sliding_window_attention(self):
         model_mode=MODEL_MODE_TRAIN,
     )

-    sliding_window_output = sliding_attn(
+    sliding_window_output, _ = sliding_attn(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -1022,7 +1023,7 @@ def test_sliding_window_attention(self):

     nnx.update(sliding_attn_full_window, sliding_attn_state)

-    sliding_window_output_full = sliding_attn_full_window(
+    sliding_window_output_full, _ = sliding_attn_full_window(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -1044,6 +1045,83 @@ def test_sliding_window_attention(self):
       )
     )

+  @pytest.mark.skip(reason="Requires `vllm-tpu` package which is not yet a MaxText dependency.")
+  @pytest.mark.tpu_only
+  @mock.patch("tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", create=True)
+  def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention):
+    """Tests the forward_serve_vllm method with mocked RPA attention."""
+    # Setup config for vLLM RPA
+    vllm_config_arguments = self.config_arguments.copy()
+    vllm_config_arguments["attention"] = "vllm_rpa"
+    vllm_config_arguments["chunk_attn_window_size"] = 128
+    config = pyconfig.initialize(
+        [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        **vllm_config_arguments,
+    )
+
+    seq_len = self.max_target_length
+
+    # Create Attention instance
+    dummy_inputs_q = jnp.ones((self.global_batch_size, seq_len, self.embed_dim))
+    dummy_inputs_kv = jnp.ones((self.global_batch_size, seq_len, self.embed_dim))
+    attention_vllm = Attention(
+        config=config,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.max_prefill_predict_length,
+        inputs_q_shape=dummy_inputs_q.shape,
+        inputs_kv_shape=dummy_inputs_kv.shape,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        model_mode=MODEL_MODE_AUTOREGRESSIVE,
+        rngs=self.nnx_rng,
+    )
+
+    # Prepare inputs
+    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
+    mock_kv_cache = [jnp.ones((1,))]
+
+    mock_attention_metadata = mock.Mock()
+    mock_attention_metadata.seq_lens = jnp.array([1] * self.global_batch_size)
+    mock_attention_metadata.block_tables = jnp.array([[0]] * self.global_batch_size)
+    mock_attention_metadata.query_start_loc = jnp.array(list(range(self.global_batch_size)))
+    mock_attention_metadata.request_distribution = jnp.array([self.global_batch_size])
+
+    # Mock the return value of sharded_ragged_paged_attention
+    total_tokens = self.global_batch_size * seq_len
+    mock_output_shape = (total_tokens, self.num_query_heads, self.head_dim)
+    mock_output = jnp.ones(mock_output_shape, dtype=self.dtype)
+    mock_updated_kv_cache = [jnp.zeros((1,))]
+
+    mock_callable = mock.Mock(return_value=(mock_output, mock_updated_kv_cache))
+    mock_sharded_ragged_paged_attention.return_value = mock_callable
+
+    # Call the attention layer
+    output, updated_kv_cache = attention_vllm(
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_AUTOREGRESSIVE,
+        kv_cache=mock_kv_cache,
+        attention_metadata=mock_attention_metadata,
+    )
+
+    # Assertions
+    mock_sharded_ragged_paged_attention.assert_called_once()
+    mock_callable.assert_called_once()
+    self.assertEqual(updated_kv_cache, mock_updated_kv_cache)
+
+    # The output of forward_serve_vllm is reshaped back to (batch, seq, ...)
+    reshaped_mock_output = mock_output.reshape(self.global_batch_size, seq_len, self.num_query_heads, self.head_dim)
+    expected_output = attention_vllm.out_projection(reshaped_mock_output)
+    self.assertTrue(jnp.allclose(output, expected_output))
+    self.assertEqual(output.shape, (self.global_batch_size, seq_len, self.embed_dim))
+

 class MLATest(parameterized.TestCase):
   """Test for the Multi-Headed Latent Attention"""
@@ -1164,7 +1242,7 @@ def test_autoregression(self, rope_type):
     decode_total_length = cfg.max_target_length
     lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype)

-    mla_full = mla(
+    mla_full, _ = mla(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -1177,7 +1255,7 @@ def test_autoregression(self, rope_type):
     decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
     decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

-    mla_prefill = mla(
+    mla_prefill, _ = mla(
         lnx_prefill,
         lnx_prefill,
         decoder_segment_ids=decoder_segment_ids_prefill,
@@ -1193,7 +1271,7 @@ def test_autoregression(self, rope_type):
     for idx in range(prefill_length, decode_total_length):
       lnx_idx = lnx[:, idx : idx + 1, :]
       decoder_positions_idx = decoder_positions[:, idx : idx + 1]
-      mla_idx = mla(
+      mla_idx, _ = mla(
          lnx_idx,
          lnx_idx,
          inputs_positions=decoder_positions_idx,
@@ -1340,7 +1418,7 @@ def test_tpu_flash_attention_context_parallel(
     cfg, mla = self.init_mla(config_arguments, rope_type="default")
     lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg, cfg.dtype)
     # Dot product
-    mla_generic_output = mla(
+    mla_generic_output, _ = mla(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -1424,7 +1502,7 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx,
     decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding)
     decoder_positions = jax.device_put(decoder_positions, pos_sharding)

-    attention_cp_output = attention_cp(
+    attention_cp_output, _ = attention_cp(
         lnx,
         lnx,
         decoder_segment_ids=decoder_segment_ids,
@@ -1432,6 +1510,8 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx,
         deterministic=True,
         model_mode=MODEL_MODE_TRAIN,
     )
+    attention_cp_output = attention_cp_output[0] if isinstance(attention_cp_output, tuple) else attention_cp_output
+
     # If load balanced cp, de-shuffle and gather along seq dim for output
     # Note training does not need post-shuffle. Since the target seq is also pre-shuffled, the loss remains correct
     if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance:
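
Usage sketch (illustrative only, not applied by this patch): how a vLLM-style serving step might drive the new kv_caches/attention_metadata plumbing added above. The names model, params, tokens, positions, and kv_caches are placeholders for this example, and the metadata shapes simply mirror the dummy AttentionMetadata built in models.py; they are assumptions, not values defined by the patch.

    # Illustrative sketch; assumes a MaxText Transformer built with config.attention == "vllm_rpa"
    # and the tpu_inference (vllm-tpu) package installed.
    import jax.numpy as jnp
    from tpu_inference.layers.common.attention_metadata import AttentionMetadata
    from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE

    batch, seq_len = 1, 16
    metadata = AttentionMetadata(
        input_positions=jnp.zeros((batch * seq_len,), dtype=jnp.int32),
        block_tables=jnp.zeros((seq_len,), dtype=jnp.int32),
        seq_lens=jnp.array([seq_len], dtype=jnp.int32),
        query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
        request_distribution=jnp.ones((3,), dtype=jnp.int32),
    )

    # With attention == "vllm_rpa", Transformer.__call__ returns (logits, hidden_state, kv_caches)
    # instead of logits alone; the updated caches are threaded back in on the next decode step.
    # `model`, `params`, `tokens`, `positions`, and `kv_caches` are placeholders here.
    logits, hidden_state, kv_caches = model.apply(
        params,
        decoder_input_tokens=tokens,
        decoder_positions=positions,
        model_mode=MODEL_MODE_AUTOREGRESSIVE,
        kv_caches=kv_caches,
        attention_metadata=metadata,
    )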