Skip to content

[DeepSeek-V4] Implement Compressed Attention Layers#3866

Open
parambole wants to merge 5 commits into
dsv4-moe-routing-primitivesfrom
deepseek_v4_compressed_attention
Open

[DeepSeek-V4] Implement Compressed Attention Layers#3866
parambole wants to merge 5 commits into
dsv4-moe-routing-primitivesfrom
deepseek_v4_compressed_attention

Conversation

@parambole
Copy link
Copy Markdown
Collaborator

@parambole parambole commented May 11, 2026

Description

Implement compressed attention mechanisms and indexer modules required for DeepSeek-V4 integration into MaxText:

  • CSACompressor & HCACompressor: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
  • LightningIndexer: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
  • Configuration: Register attention compression hyperparameters (compress_ratios, index_head_dim, sliding_window) in types.py and base.yml.
  • Unit test suite (tests/unit/deepseek_v4_vs_reference_test.py) validating attention compression parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.

Tests

Tested on CPU

pytest  tests/unit/deepseek_v4_vs_reference_test.py

======================= 10 passed, 10 warnings in 20.42s =======================
tests/unit/deepseek_v4_vs_reference_test.py ..........                   [100%]

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 11, 2026

Codecov Report

❌ Patch coverage is 0% with 262 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/attention_compressed.py 0.00% 262 Missing ⚠️

📢 Thoughts on this report? Let us know!

@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 5f54827 to 07eb3e2 Compare May 11, 2026 19:39
@parambole parambole changed the base branch from deepseek_v4_core_primitives to dsv4-moe-routing-primitives May 11, 2026 20:29
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 37ee811 to 31329c5 Compare May 11, 2026 20:38
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 07eb3e2 to 4520166 Compare May 11, 2026 20:43
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 31329c5 to 22a57ff Compare May 12, 2026 17:23
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 4520166 to 10ca4f6 Compare May 12, 2026 17:23
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 22a57ff to 32869e5 Compare May 12, 2026 21:12
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 10ca4f6 to 31a5932 Compare May 12, 2026 21:13
@parambole parambole force-pushed the dsv4-moe-routing-primitives branch from 32869e5 to c92f2e0 Compare May 14, 2026 17:51
…ghtningIndexer)

Implement compressed attention mechanisms and indexer modules for DeepSeek-V4 integration into MaxText:

- CSACompressor & HCACompressor: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
- LightningIndexer: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
- Configuration: Register attention compression hyperparameters (compress_ratios, index_head_dim, sliding_window) in types.py and base.yml.
- Parity verification: Extended unit test suite (deepseek_v4_vs_reference_test.py) validating attention compression parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
@parambole parambole force-pushed the deepseek_v4_compressed_attention branch from 31a5932 to c98a34e Compare May 14, 2026 17:53
@parambole parambole changed the title Implement DeepSeek-V4 Compressed Attention Layers [DeepSeek-V4] Implement Compressed Attention Layers May 14, 2026
@github-actions
Copy link
Copy Markdown

🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions
Copy link
Copy Markdown

🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR implements the core compressed attention layers for DeepSeek-V4 integration, including HCA, CSA, and the Lightning Indexer. The implementation is technically sound, follows established patterns in MaxText, and includes comprehensive parity tests against PyTorch.

🔍 General Feedback

  • Efficiency: The main coordinator block uses jnp.repeat for broadcasting MQA keys/values, which is memory-intensive. Switching to jnp.einsum broadcasting is recommended.
  • Typo: A minor typo swaped was found in the indexer module.
  • Config: Ensure compress_ratios is properly documented as a required list when using these attention variants to avoid runtime IndexError.


# Compute attention logits
# logits: [B, H, S, S_kv]
logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Memory/Performance: `jnp.repeat` creates large intermediate tensors that are unnecessary given the shared MQA structure. You can use broadcasting in `jnp.einsum` to achieve the same result more efficiently.
Suggested change
logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling
# Broadcast key/value configurations to all heads using broadcasting in einsum
# k and v remain [B, 1, S_kv, D_head]
k = kv
v = kv
# Compute attention logits with head broadcasting: [B, H, S, S_kv]
logits = jnp.einsum("bhsd, b1kd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling


# Project attention weights onto values
# attn_output: [B, H, S, D_head]
attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Memory/Performance: Use broadcasting in `jnp.einsum` here as well to avoid the `jnp.repeat` from earlier.
Suggested change
attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)
# Project attention weights onto values with head broadcasting: [B, H, S, D_head]
attn_output = jnp.einsum("bhsk, b1kd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)

swaped_kv = jnp.swapaxes(compressed_kv, -1, -2)
swaped_kv = jnp.expand_dims(swaped_kv, axis=1)
# scores: [B, S, H, W]
scores = jnp.matmul(q, swaped_kv)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Typo: "swaped" should be "swapped".
Suggested change
scores = jnp.matmul(q, swaped_kv)
# swapped_kv: [B, 1, D_idx, W]
swapped_kv = jnp.swapaxes(compressed_kv, -1, -2)
swapped_kv = jnp.expand_dims(swapped_kv, axis=1)
# scores: [B, S, H, W]
scores = jnp.matmul(q, swapped_kv)

"""Configuration specific to DeepSeek-V4 stateless compressed attention layers."""

compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.")
compress_ratios: list[int] = Field(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Configuration: The default `compress_ratios` is an empty list. Since `DeepSeekV4Attention` relies on this list having at least `layer_idx + 1` elements when using compressed layer types, this will cause an `IndexError` at runtime unless the user provides a full list. It might be better to provide a default or add a check with a clear error message.

Copy link
Copy Markdown
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see sharding annotations. It will be good we start to add some of them in this PR? i.e. starting with those weights.

return compressed_kv, block_bias


class DeepSeekV4Indexer(nnx.Module):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you considered to re-use v3.2 Indexer?

Ref: doc

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with reusing. Can we at least have a BaseIndexer class to extract common logic (e.g., init, scoring)? Then v3.2 indexer and v4 indexer inherits from this.

Here is an example

class BaseIndexer(nnx.Module):
  """Base Lightning Indexer for Sparse Attention"""

  def __init__(
      self,
      config: Any,
      kernel_init: NdInitializer,
      quant: Optional[Quant] = None,
      weights_proj_in_features: Optional[int] = None,
      weights_proj_dtype: Optional[DType] = None,
      weights_proj_quant: Optional[Quant] = None,
      rngs: Optional[nnx.Rngs] = None,
  ):
    self.config = config
    self.dtype = config.dtype
    self.weight_dtype = config.weight_dtype
    self.n_heads = config.indexer_n_heads
    self.head_dim = config.indexer_head_dim
    self.indexer_topk = config.indexer_topk
    self.q_lora_rank = config.q_lora_rank
    self.softmax_scale = self.head_dim**-0.5
    self.weights_scaling = self.n_heads**-0.5 

    # Query Projection: Latent Query -> Indexer Query
    self.wq_b = DenseGeneral(
        in_features_shape=self.q_lora_rank,
        out_features_shape=(self.n_heads, self.head_dim),
        axis=-1,
        kernel_init=kernel_init,
        kernel_axes=("q_lora", "q_heads", "kv"),
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=quant,
        matmul_precision=config.matmul_precision,
        shard_mode=config.shard_mode,
        rngs=rngs,
    )

    wp_in_features = weights_proj_in_features if weights_proj_in_features is not None else config.emb_dim
    wp_dtype = weights_proj_dtype if weights_proj_dtype is not None else self.dtype

    # Projection: Input -> Importance Weights for Heads
    self.weights_proj = DenseGeneral(
        in_features_shape=wp_in_features,
        out_features_shape=self.n_heads,
        axis=-1,
        kernel_init=kernel_init,
        kernel_axes=("embed", "q_heads"),
        dtype=wp_dtype,
        weight_dtype=wp_dtype,
        quant=weights_proj_quant,
        matmul_precision=config.matmul_precision,
        shard_mode=config.shard_mode,
        rngs=rngs,
    )

  def prepare_query(self, low_rank_q: jnp.ndarray, inputs_positions: jnp.ndarray, apply_partial_rope) -> jnp.ndarray:
    bsz, seqlen, _ = low_rank_q.shape
    q = self.wq_b(low_rank_q)
    q = q.reshape(bsz, seqlen, self.n_heads, self.head_dim)
    q = apply_partial_rope(q, inputs_positions=inputs_positions)
    return q

  def compute_topk(self, q: jnp.ndarray, k: jnp.ndarray, inputs_q: jnp.ndarray, mask: Optional[jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
    logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision)
    logits = jax.nn.relu(logits)
    weights = self.weights_proj(inputs_q)
    weights = weights * self.weights_scaling * self.softmax_scale
    indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision)

    if mask is not None:
      indexer_score += mask

    _, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk)
    return topk_indices, indexer_score

# See the License for the specific language governing permissions and
# limitations under the License.

"""Compressed Attention layers and long-range compressors."""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For file name, may be compressed_attention.py as you mentioned in the comment here?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it is influenced byattention_mla.py?

Copy link
Copy Markdown
Collaborator

@shuningjin shuningjin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the implementation! Here are my high-level comments:

  1. Please replace nnx.Linear with MaxText DenseGeneral

  2. I recommend reusing some common logic from Indexer of V3.2, perhaps by abstracting into a base class

  3. It unclear to me how "Additional Branch of Sliding Window Attention" from Section 2.3.3 in paper is implemented. Could you explain?

Comment on lines +411 to +413
index_head_dim: 128
index_n_heads: 64
index_topk: 512
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can be removed. Please reuse existing indexer config:

# DeepSeek Sparse Attention (DSA)
# deepseek3.2 introduces indexer in MLA
use_indexer: False
indexer_head_dim: 128
indexer_n_heads: 64
indexer_topk: 2048

index_topk: 512
o_groups: 8
o_lora_rank: 1024
sliding_window: 128
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Can you make the name sliding_window more specific?

It can be easily confused with sliding_window_size for attention_type=local_sliding.

attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla
share_kv_projections: False # Note: Not compatible with attention_type='mla'
attention_bias: False # If True, adds a learnable bias to the query, key, and value projections
attention_sink: False
sliding_window_size: 0

  1. Where are we using config.sliding_window? I didn't see it in code.

Comment on lines 2240 to 2246
MlaAttention,
MoBa,
AttentionIndexer,
DeepSeekV4AttentionConfig,
Llama4Attention,
SplashAttention,
PagedAttention,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: DeepSeekV4AttentionConfig -> DeepSeekV4Attention, to be consistent with other class in types.py.

# See the License for the specific language governing permissions and
# limitations under the License.

"""Compressed Attention layers and long-range compressors."""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it is influenced byattention_mla.py?

rope_theta = config.compress_rope_theta

# Linear projection of inputs to key/value representation
self.kv_proj = nnx.Linear(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all occurrences of nnx.Linear in the PR, should be replaced with MaxText's DenseGeneral.

from maxtext.layers.linears import DenseGeneral

See these examples:

return DenseGeneral(
in_features_shape=self.convert_dense_general_inputs_shape(inputs_kv_shape),
out_features_shape=(self.num_kv_heads, self.head_dim),
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=kernel_axes,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
shard_mode=self.config.shard_mode,
matmul_precision=self.config.matmul_precision,
use_bias=self.use_bias_in_projections,
rngs=self.rngs,
)

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
"""Initializes the MLA-specific projections."""
# Assert required configuration parameters for MLA attention.
assert (
self.config.attention_type == AttentionType.MLA.value
), f"MLA requires MLA attention type {AttentionType.MLA.value}"
assert self.kv_lora_rank > 0, "KV LoRA rank must be > 0"
assert self.qk_nope_head_dim > 0, "QK NoPe head dim must be > 0"
assert self.qk_rope_head_dim > 0, "QK RoPE head dim must be > 0"
assert self.v_head_dim > 0, "V head dim must be > 0"
assert self.num_query_heads == self.num_kv_heads, "MLA requires equal number of query and kv heads"
assert not self.config.fused_qkv, "Fused QKV is not supported for MLA"
if self.q_lora_rank == 0:
# Standard Q projection (without LoRA).
self.query = DenseGeneral(
in_features_shape=self.config.emb_dim,
out_features_shape=(self.num_query_heads, self.qk_head_dim),
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("embed", "q_heads", "kv"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)

self.attention_type = attention_type
self.num_heads = num_heads
self.head_dim = head_dim
self.sliding_window = config.sliding_window
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is this used? where is "Additional Branch of Sliding Window Attention" in Sec 2.3.3?

self.layer_idx = layer_idx
self.layer_type = config.layer_types[layer_idx]
# Sliding-only layers use the "main" (plain θ=10000) rope; CSA/HCA layers
# share the same yarn-scaled "compress" rope as their compressor.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is yarn incorporated? I didn't see it in #3865

Comment on lines +801 to +802
# Projections for query extraction and low-rank compression
self.q_a_proj = nnx.Linear(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: The query low rank logic seems identical to v3. Maybe make the name consistent with v3:

q_a_proj -> wq_a
q_b_proj -> wq_b

if attention_mask is not None:
# Pad 4D attention mask along trailing key/value sequence axis.
# # [B, 1, Q, S + W] -> [B, 1, Q, align(S + W, sa_block_kv)]
attention_mask = jnp.pad(attention_mask, ((0, 0), (0, 0), (0, 0), (0, pad_len)), constant_values=0.0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pad with -jnp.inf instead of 0? or perhaps DEFAULT_MASK_VALUE

# A large negative mask value is used for masking to ensure that the
# softmax function assigns an extremely low probability to the masked positions.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


# Reconcile structural block bias masks with runtime attention masks.
if attention_mask is not None:
if block_bias is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if attention_mask is None and block_bias is not None?

Copy link
Copy Markdown
Collaborator

@entrpn entrpn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Please resolve other comments before merging.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants