Conversation

@RissyRan (Collaborator) commented Jan 13, 2026

Description

Main author: @shuningjin

Background

DeepSeek V3.2 differs from DeepSeek V3 solely in the attention mechanism, aiming for efficiency in long-context scenarios. While DeepSeek V3 uses Multi-head Latent Attention (MLA), DeepSeek V3.2 uses DeepSeek Sparse Attention (DSA). DSA augments MLA with two components:

  • Indexer: parametric; a qk product produces index scores
  • Top-k token selection: non-parametric; selects the top-k keys/values for each query, introducing sparsity into qkv attention (see the sketch after this list)
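
A minimal sketch of this mechanism, for orientation only: index scores from a qk product, top-k selection over key positions, and the resulting mask applied to the attention logits. Shapes and function names are illustrative assumptions, not the MaxText implementation, and the paper's per-head weighting/activation of the index scores is omitted for brevity.

import jax
import jax.numpy as jnp

def dsa_index_mask(q_idx, k_idx, topk):
  """q_idx: [b, t, h, d] indexer queries; k_idx: [b, s, d] indexer keys.
  Returns a boolean mask [b, t, s], True at the top-k key positions per query."""
  # qk product: index scores, summed over indexer heads -> [b, t, s]
  scores = jnp.einsum("bthd,bsd->bths", q_idx, k_idx).sum(axis=2)
  # top-k key positions for each query token -> [b, t, topk]
  _, topk_indices = jax.lax.top_k(scores, topk)
  # boolean mask over key positions -> [b, t, s]
  s = k_idx.shape[1]
  return (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)

def apply_index_mask(attn_logits, index_mask):
  """attn_logits: [b, n_heads, t, s]; mask out non-selected key positions before softmax."""
  neg_inf = jnp.finfo(attn_logits.dtype).min
  return jnp.where(index_mask[:, None, :, :], attn_logits, neg_inf)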

What this PR does

1. Naive implementation of DeepSeek Sparse Attention (DSA)

  • Indexer:

    • qk product: currently implemented with a dot product to get index scores. To be optimized.
    • (minor) RoPE: the indexer applies partial RoPE to q and k based on the YaRN extension. It uses the same YaRN frequencies as MLA, but with a concatenated layout rather than an interleaved layout (see the layout sketch after this list).
    • Based on the index scores, get the top-k indices and the index mask.
  • Top-k selection for qkv attention:

    • Currently implemented inside dot-product attention by adding the index mask to the regular attention mask. To be optimized.
  • Training only (no prefill / decode).

  • See changes in attention_mla.py and attention_op.py.
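
For the RoPE bullet above, a generic illustration of interleaved vs. concatenated (split-half) rotation layouts, assuming cos/sin of shape [..., d/2] precomputed from the same YaRN frequencies; these helpers are illustrative sketches, not the PR's functions.

import jax.numpy as jnp

def rope_interleaved(x, cos, sin):
  # Rotated pairs are (x[..., 2i], x[..., 2i+1]).
  x1, x2 = x[..., 0::2], x[..., 1::2]
  r1 = x1 * cos - x2 * sin
  r2 = x1 * sin + x2 * cos
  return jnp.stack([r1, r2], axis=-1).reshape(x.shape)

def rope_concatenated(x, cos, sin):
  # Rotated pairs are (x[..., i], x[..., i + d/2]): same frequencies, different layout.
  half = x.shape[-1] // 2
  x1, x2 = x[..., :half], x[..., half:]
  r1 = x1 * cos - x2 * sin
  r2 = x1 * sin + x2 * cos
  return jnp.concatenate([r1, r2], axis=-1)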

2. Onboard deepseek3.2-671b config

  • deepseek3.2-671b.yml
  • DeepSeek V3.2 vs. V3 HF config diff: additional config keys for the indexer:
"index_head_dim": 128, "index_n_heads": 64, "index_topk": 2048,
  • Number of parameters: (1) As with V3, the HF safetensors of V3.2 contain an extra layer for MTP, which we omit. (2) Note that the indexer contains extra parameters. (3) By counting, V3 has 671026419200 (671.03B) parameters and V3.2 has 671877944064 (671.88B) parameters (see the rough breakdown after this list).

3. Unit test: ahead-of-time train compile for deepseek3.2-671b

4. Unit test: compare output against torch code for Indexer and MLA
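
For the parameter counts in item 2, a rough back-of-the-envelope check. The per-layer indexer shapes below (wq_b, wk, k_norm, weights_proj) are assumptions taken from the DeepSeek reference implementation, not from this PR; with those shapes, the extra indexer parameters account exactly for the gap between the two totals above.

# Assumed dims: hidden=7168, q_lora_rank=1536, 61 transformer layers (MTP layer omitted).
hidden, q_lora_rank, n_layers = 7168, 1536, 61
index_n_heads, index_head_dim = 64, 128

indexer_per_layer = (
    q_lora_rank * index_n_heads * index_head_dim  # wq_b: q_lora -> n_heads * head_dim
    + hidden * index_head_dim                     # wk: hidden -> head_dim
    + hidden * index_n_heads                      # weights_proj: hidden -> n_heads
    + 2 * index_head_dim                          # k_norm (scale + bias)
)
extra = indexer_per_layer * n_layers
print(indexer_per_layer)     # 13959424
print(extra)                 # 851524864
print(671026419200 + extra)  # 671877944064, the v3.2 total above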

Reference

Future work

  • verify end-to-end training logits for deepseek3.2
  • more efficient implementation of DSA

Tests

Unit test against torch code (adapted from the reference): Indexer, MLA

python3 -m pytest -v --pyargs tests.check_deepseek32_vs_reference -rP -s

Unit test for train compile

python3 -m pytest -v --pyargs tests.train_compile_test -rP -s -k "test_deepseek32"

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 51.25000% with 39 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/MaxText/layers/attention_mla.py | 53.94% | 31 Missing and 4 partials ⚠️
src/MaxText/layers/attention_op.py | 0.00% | 2 Missing and 2 partials ⚠️


@RissyRan (Collaborator, Author) left a comment

Thanks for the change! I took a look at the indexer part, and overall it looks good functionally. There is also an indexer logits kernel for performance; I will take a look at that.

I will take a look at the MLA part shortly.

@shuningjin changed the title [DO NO MERGE] Draft for sparse DeepSeek3.2: Onboard sparse attention Jan 17, 2026
@shuningjin self-assigned this Jan 17, 2026
@RissyRan (Collaborator, Author) left a comment

Thanks for the change! Great work! A few comments.

@@ -0,0 +1,59 @@
# Copyright 2023–2025 Google LLC

Collaborator Author: Nit: update the copyright year to 2026.

v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")


class AttentionIndexer(BaseModel):

Collaborator Author: Could you add these 4 configs to the MoE config readme doc? This could be a follow-up PR.

"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""

use_sparse_indexer: bool = Field(False, description="If True, enables sparse indexer.")
index_head_dim: NonNegativeInt = Field(128, description="head dim for indexer")

Collaborator Author: Nit: capitalize the first letter of the description to align with the others. Similar comment for the following fields.

class Indexer(nnx.Module):
"""
Indexer for DeepSeek Sparse Attention (DSA).
Introduced by DeepSeek V3.2: https://arxiv.org/pdf/2512.02556.

Collaborator Author: Do you think we could also attach the reference implementation here along with the paper?

out_features_shape=(self.n_heads, self.head_dim),
axis=-1,
kernel_init=self.kernel_init,
# TODO(shuningjin): double check kernel axes

Collaborator Author: Do you have any concerns? I see it is aligned with MLA:

kernel_axes=("q_lora", "q_heads", "kv"),

We could start with this.

# Indexer Logic
index_mask = None
if self.use_sparse_indexer:
if self.q_lora_rank == 0:

Collaborator Author: Can we shift this logic earlier? I think validations are in the type.py file. A similar comment applies if you validate model_mode.

length = query.shape[-3]
target_hardware = self.mesh.devices[(0,) * self.mesh.devices.ndim].platform

if index_mask is not None and self.attention_kernel != "dot_product":

Collaborator Author: Similar comment for the case where sparse attention is enabled with another attention type.

"megablox=True",
"per_device_batch_size=1",
"max_target_length=1024",
"attention=dot_product", # only support dot product now

Collaborator Author: You could write a TODO: update to flash attention when it's available.

index_topk = 4


SEQ_LEN = 8

Collaborator Author: Could we also try SEQ_LEN < index_topk to make sure it works end to end?


print("torch out", pt_out)
print("jax out", jax_out)
# np.testing.assert_allclose(to_jax(pt_out / pt_out.sum()), jax_out / jax_out.sum(), rtol=1e-3, atol=1e-2)

Collaborator Author: It seems this is a normalization? I suggest removing it if it is not needed, or keeping it in the code (not as a comment).

@RissyRan (Collaborator, Author) commented

Also, don't forget to squash commits :)

Comment on lines +189 to +191
# 2. Broadcast compare against [b, t, k] to get [b, t, k, s]
# 3. Use .any() to see if a s-index is present in any of the k slots
is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)

Collaborator: Isn't this a really large tensor [b, t, k, s] that could cause OOM?

I think you can use jnp.put_along_axis (or jax.lax.scatter) to construct the mask directly without materializing the [b, t, k, s] tensor.
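
A minimal sketch of that suggestion, assuming topk_indices has shape [b, t, k] and key length s: scatter True directly into a [b, t, s] mask instead of materializing the [b, t, k, s] comparison tensor. Names are illustrative.

import jax.numpy as jnp

def topk_mask_scatter(topk_indices, s):
  """topk_indices: [b, t, k] integer key positions; returns a bool mask of shape [b, t, s]."""
  b, t, _ = topk_indices.shape
  mask = jnp.zeros((b, t, s), dtype=bool)
  # Scatter True at the selected key positions along the last axis.
  # inplace=False is required because JAX arrays are immutable.
  return jnp.put_along_axis(mask, topk_indices, True, axis=-1, inplace=False)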
