DeepSeek3.2: Onboard sparse attention #2933
base: main
Conversation
RissyRan
left a comment
Thanks for the change! I took a look at the indexer part, and overall it looks good for functionality. It also has an indexer logit kernel for performance; I will take a look at that.
I will take a look at the MLA part shortly.
Force-pushed from 94b73d8 to fe2ea34
RissyRan
left a comment
Thanks for the change! Great work! A few comments.
@@ -0,0 +1,59 @@
# Copyright 2023–2025 Google LLC
Nit: should this be updated to 2026?
v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")


class AttentionIndexer(BaseModel):
Could you add those 4 configs to the MoE config README doc? Could be a follow-up PR.
| """Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer.""" | ||
|
|
||
| use_sparse_indexer: bool = Field(False, description="If True, enables sparse indexer.") | ||
| index_head_dim: NonNegativeInt = Field(128, description="head dim for indexer") |
Nit: Capitalize the first letter of the description to align with the others. Similar comment for the following fields.
class Indexer(nnx.Module):
  """
  Indexer for DeepSeek Sparse Attention (DSA).
  Introduced by DeepSeek V3.2: https://arxiv.org/pdf/2512.02556.
Do you think we could also attach the reference implementation here along with the paper?
out_features_shape=(self.n_heads, self.head_dim),
axis=-1,
kernel_init=self.kernel_init,
# TODO(shuningjin): double check kernel axes
Do you have any concerns? I see it aligns with MLA:
maxtext/src/MaxText/layers/attention_mla.py
Line 425 in d4a259d
kernel_axes=("q_lora", "q_heads", "kv"),
We could start with this.
# Indexer Logic
index_mask = None
if self.use_sparse_indexer:
  if self.q_lora_rank == 0:
Can we shift this logic earlier? I think validations are in the type.py file. Similar comment if you validate model_mode.
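For illustration, a minimal sketch of what such an up-front validation might look like if moved into the config checks; the function name is hypothetical and the exact check semantics are assumed from this review thread, not taken from the PR:

def validate_sparse_indexer(config):
  """Hypothetical config-time checks for the sparse indexer (assumed semantics)."""
  if config.use_sparse_indexer:
    # Assumption: the indexer path requires the low-rank Q projection.
    if config.q_lora_rank == 0:
      raise ValueError("use_sparse_indexer=True requires q_lora_rank > 0.")
    # Assumption: only the dot_product kernel supports the index mask for now.
    if config.attention != "dot_product":
      raise ValueError("use_sparse_indexer=True currently requires attention='dot_product'.")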
length = query.shape[-3]
target_hardware = self.mesh.devices[(0,) * self.mesh.devices.ndim].platform

if index_mask is not None and self.attention_kernel != "dot_product":
Similar comment: validate the case where sparse attention is enabled with another attention type.
| "megablox=True", | ||
| "per_device_batch_size=1", | ||
| "max_target_length=1024", | ||
| "attention=dot_product", # only support dot product now |
You could write a TODO: update to flash attention when it's available.
index_topk = 4


SEQ_LEN = 8
Could we also try SEQ_LEN < index_topk to make sure it works end to end?
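As an illustration of handling that case, here is a small sketch that clamps k before the top-k call; it assumes jax.lax.top_k is what the indexer uses for selection, and the names are illustrative rather than taken from the PR:

import jax

def select_topk_indices(scores, index_topk):
  # scores: [batch, q_len, kv_len] indexer logits.
  kv_len = scores.shape[-1]
  # jax.lax.top_k requires k <= kv_len, so clamp when the sequence is shorter than index_topk.
  k = min(index_topk, kv_len)
  _, topk_indices = jax.lax.top_k(scores, k)  # [batch, q_len, k]
  return topk_indices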
print("torch out", pt_out)
print("jax out", jax_out)
# np.testing.assert_allclose(to_jax(pt_out / pt_out.sum()), jax_out / jax_out.sum(), rtol=1e-3, atol=1e-2)
It seems like this is a normalization? I suggest removing it if it's not needed, or keeping it in the code (not as a comment).
Also, don't forget to squash commits :)
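If the normalization is needed, the check could simply be made active rather than left commented out (same call and tolerances as the original line; to_jax is the test's existing helper):

np.testing.assert_allclose(to_jax(pt_out / pt_out.sum()), jax_out / jax_out.sum(), rtol=1e-3, atol=1e-2)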
# 2. Broadcast compare against [b, t, k] to get [b, t, k, s]
# 3. Use .any() to see if an s-index is present in any of the k slots
is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
Isn't this a really large tensor [b, t, k, s] that can cause OOM?
I think you can use jnp.put_along_axis (or jax.lax.scatter) to construct the mask directly without materializing the [b, t, k, s] tensor.
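For illustration, a minimal sketch of the scatter-style alternative the comment suggests, building the [b, t, s] mask directly so peak memory stays at [b, t, s] rather than [b, t, k, s]; this is an assumed rewrite, not the PR's code:

import jax.numpy as jnp

def topk_mask_via_scatter(topk_indices, s):
  # topk_indices: [b, t, k] integer indices into the key axis of length s.
  b, t, _ = topk_indices.shape
  batch_idx = jnp.arange(b)[:, None, None]  # [b, 1, 1]
  time_idx = jnp.arange(t)[None, :, None]   # [1, t, 1]
  mask = jnp.zeros((b, t, s), dtype=bool)
  # Scatter True at the selected positions; broadcasting expands the index arrays to [b, t, k].
  return mask.at[batch_idx, time_idx, topk_indices].set(True)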
Description
Main author: @shuningjin
Background
DeepSeek V3.2 differs from DeepSeek V3 only in the attention mechanism, aiming for efficiency in long-context scenarios. While DeepSeek V3 uses Multi-head Latent Attention (MLA), DeepSeek V3.2 uses DeepSeek Sparse Attention (DSA). DSA augments MLA with two components: an indexer and top-k token selection.
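For intuition, a rough sketch of how the indexer's top-k mask could restrict the MLA attention weights; the shapes and the where-based masking are simplifications assumed for illustration, not the PR's implementation:

import jax
import jax.numpy as jnp

def masked_attention_weights(attn_logits, index_mask):
  # attn_logits: [b, h, t, s] MLA attention logits; index_mask: [b, t, s] bool from the indexer.
  neg_inf = jnp.finfo(attn_logits.dtype).min
  # Positions outside the indexer's top-k are excluded before the softmax.
  masked = jnp.where(index_mask[:, None, :, :], attn_logits, neg_inf)
  return jax.nn.softmax(masked, axis=-1)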
What this PR does
1. Naive implementation of DeepSeek Sparse Attention (DSA)
   - Indexer
   - Top-k selection for qkv attention: training only (no prefill / decode)
   - See changes in attention_mla.py, attention_op.py
2. Onboard deepseek3.2-671b config
   - deepseek3.2-671b.yml: v3 has 671026419200 (671.03B) parameters and v3.2 has 671877944064 (671.88B) parameters
3. Unit test: ahead-of-time train compile for deepseek3.2-671b
4. Unit test: compare output against torch code for Indexer and MLA
   - check_deepseek32_vs_reference.py
Reference
Future work
Tests
Unit test against torch code (adapted from reference): indexer, MLA
Unit test for train compile
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.