[DeepSeek-V4] Implement Compressed Attention Layers #3866
…ghtningIndexer)
Implement compressed attention mechanisms and indexer modules for DeepSeek-V4 integration into MaxText:
- CSACompressor & HCACompressor: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
- LightningIndexer: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
- Configuration: Register attention compression hyperparameters (compress_ratios, index_head_dim, sliding_window) in types.py and base.yml.
- Parity verification: Extended unit test suite (deepseek_v4_vs_reference_test.py) validating attention compression parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
This PR implements the core compressed attention layers for DeepSeek-V4 integration, including HCA, CSA, and the Lightning Indexer. The implementation is technically sound, follows established patterns in MaxText, and includes comprehensive parity tests against PyTorch.
🔍 General Feedback
- Efficiency: The main coordinator block uses `jnp.repeat` for broadcasting MQA keys/values, which is memory-intensive. Switching to `jnp.einsum` broadcasting is recommended.
- Typo: A minor typo, `swaped`, was found in the indexer module.
- Config: Ensure `compress_ratios` is properly documented as a required list when using these attention variants, to avoid a runtime `IndexError`.
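One way to surface the `compress_ratios` problem early is to validate the list at construction time rather than at first lookup. The sketch below is only illustrative: `CompressedAttentionConfig`, `num_decoder_layers`, `ratio_for_layer`, and the sample ratio values are hypothetical stand-ins, and it assumes one ratio per decoder layer, which the review implies but does not state.

```python
from dataclasses import dataclass, field


@dataclass
class CompressedAttentionConfig:
  """Hypothetical stand-in for the PR's config class; field names are illustrative."""

  num_decoder_layers: int
  compress_ratios: list[int] = field(default_factory=list)

  def __post_init__(self):
    # Assumption: one ratio per decoder layer, later read as compress_ratios[layer_idx].
    if len(self.compress_ratios) != self.num_decoder_layers:
      raise ValueError(
          f"compress_ratios needs {self.num_decoder_layers} entries, "
          f"got {len(self.compress_ratios)}"
      )


def ratio_for_layer(config: CompressedAttentionConfig, layer_idx: int) -> int:
  # Safe to index once the length check in __post_init__ has passed.
  return config.compress_ratios[layer_idx]


# Placeholder values, not taken from the PR.
config = CompressedAttentionConfig(num_decoder_layers=3, compress_ratios=[1, 4, 128])
```

With a check like this, a missing or short list fails with a descriptive `ValueError` at config load instead of an `IndexError` deep inside a layer constructor.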

    # Compute attention logits
    # logits: [B, H, S, S_kv]
    logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling

Suggested change:

    # Broadcast key/value configurations to all heads using broadcasting in einsum
    # k and v remain [B, 1, S_kv, D_head]
    k = kv
    v = kv
    # Compute attention logits with head broadcasting: [B, H, S, S_kv]
    logits = jnp.einsum("bhsd, b1kd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling


    # Project attention weights onto values
    # attn_output: [B, H, S, D_head]
    attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)

Suggested change:

    # Project attention weights onto values with head broadcasting: [B, H, S, D_head]
    attn_output = jnp.einsum("bhsk, b1kd -> bhsd", attn_weights, v, precision=self.config.matmul_precision)

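A note on the suggested subscripts: as far as I know, einsum subscripts must be letters, so `b1kd` would likely be rejected. A minimal sketch of the same idea, assuming `q` is `[B, H, S, D_head]` and the shared MQA tensor `kv` is `[B, 1, S_kv, D_head]`, is to squeeze the singleton head axis and omit `h` from that operand's subscripts. The toy sizes and random inputs below are illustrative only.

```python
import jax
import jax.numpy as jnp

B, H, S, K, D = 2, 4, 3, 5, 8  # toy sizes, chosen only for illustration
key_q, key_kv = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (B, H, S, D))    # per-head queries
kv = jax.random.normal(key_kv, (B, 1, K, D))  # single shared MQA key/value head

# Baseline: materialize H copies of the shared head, then contract.
kv_rep = jnp.repeat(kv, H, axis=1)  # [B, H, K, D]
logits_repeat = jnp.einsum("bhsd,bhkd->bhsk", q, kv_rep)

# Alternative: drop the singleton head axis and leave `h` out of the second
# operand's subscripts, so the same keys are reused for every head.
kv_shared = jnp.squeeze(kv, axis=1)  # [B, K, D]
logits_shared = jnp.einsum("bhsd,bkd->bhsk", q, kv_shared)

# The value projection follows the same pattern.
attn_weights = jax.nn.softmax(logits_shared, axis=-1)
out_repeat = jnp.einsum("bhsk,bhkd->bhsd", attn_weights, kv_rep)
out_shared = jnp.einsum("bhsk,bkd->bhsd", attn_weights, kv_shared)
```

Both paths compute the same logits and outputs; the second simply never builds the `[B, H, S_kv, D_head]` copy.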

    swaped_kv = jnp.swapaxes(compressed_kv, -1, -2)
    swaped_kv = jnp.expand_dims(swaped_kv, axis=1)
    # scores: [B, S, H, W]
    scores = jnp.matmul(q, swaped_kv)

Suggested change:

    # swapped_kv: [B, 1, D_idx, W]
    swapped_kv = jnp.swapaxes(compressed_kv, -1, -2)
    swapped_kv = jnp.expand_dims(swapped_kv, axis=1)
    # scores: [B, S, H, W]
    scores = jnp.matmul(q, swapped_kv)

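To make the shape comments in this hunk concrete, here is a small check, assuming `q` is `[B, S, H, D_idx]` and `compressed_kv` is `[B, W, D_idx]` (the shapes are inferred from the comments, not confirmed by the PR). It also shows an equivalent single `einsum`, which is offered as an alternative formulation rather than as what the PR does.

```python
import jax
import jax.numpy as jnp

B, S, H, W, D_idx = 2, 6, 4, 3, 8  # toy sizes for illustration only
key_q, key_kv = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (B, S, H, D_idx))
compressed_kv = jax.random.normal(key_kv, (B, W, D_idx))

swapped_kv = jnp.swapaxes(compressed_kv, -1, -2)  # [B, D_idx, W]
swapped_kv = jnp.expand_dims(swapped_kv, axis=1)  # [B, 1, D_idx, W]

# matmul broadcasts the leading (B, 1) against (B, S) and multiplies the
# trailing (H, D_idx) by (D_idx, W) matrices.
scores = jnp.matmul(q, swapped_kv)  # [B, S, H, W]

# Equivalent contraction without the explicit swap and expand.
scores_einsum = jnp.einsum("bshd,bwd->bshw", q, compressed_kv)
```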

    """Configuration specific to DeepSeek-V4 stateless compressed attention layers."""

    compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.")
    compress_ratios: list[int] = Field(

RissyRan left a comment
I didn't see any sharding annotations. Could we start adding some of them in this PR, e.g. starting with the weights?

    return compressed_kv, block_bias


    class DeepSeekV4Indexer(nnx.Module):

Have you considered reusing the v3.2 Indexer?
Ref: doc
I agree with reusing. Can we at least have a BaseIndexer class that extracts the common logic (e.g., init, scoring)? The v3.2 indexer and the v4 indexer would then inherit from it.
Here is an example:

    class BaseIndexer(nnx.Module):
      """Base Lightning Indexer for Sparse Attention"""

      def __init__(
          self,
          config: Any,
          kernel_init: NdInitializer,
          quant: Optional[Quant] = None,
          weights_proj_in_features: Optional[int] = None,
          weights_proj_dtype: Optional[DType] = None,
          weights_proj_quant: Optional[Quant] = None,
          rngs: Optional[nnx.Rngs] = None,
      ):
        self.config = config
        self.dtype = config.dtype
        self.weight_dtype = config.weight_dtype
        self.n_heads = config.indexer_n_heads
        self.head_dim = config.indexer_head_dim
        self.indexer_topk = config.indexer_topk
        self.q_lora_rank = config.q_lora_rank
        self.softmax_scale = self.head_dim**-0.5
        self.weights_scaling = self.n_heads**-0.5

        # Query Projection: Latent Query -> Indexer Query
        self.wq_b = DenseGeneral(
            in_features_shape=self.q_lora_rank,
            out_features_shape=(self.n_heads, self.head_dim),
            axis=-1,
            kernel_init=kernel_init,
            kernel_axes=("q_lora", "q_heads", "kv"),
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            quant=quant,
            matmul_precision=config.matmul_precision,
            shard_mode=config.shard_mode,
            rngs=rngs,
        )

        wp_in_features = weights_proj_in_features if weights_proj_in_features is not None else config.emb_dim
        wp_dtype = weights_proj_dtype if weights_proj_dtype is not None else self.dtype

        # Projection: Input -> Importance Weights for Heads
        self.weights_proj = DenseGeneral(
            in_features_shape=wp_in_features,
            out_features_shape=self.n_heads,
            axis=-1,
            kernel_init=kernel_init,
            kernel_axes=("embed", "q_heads"),
            dtype=wp_dtype,
            weight_dtype=wp_dtype,
            quant=weights_proj_quant,
            matmul_precision=config.matmul_precision,
            shard_mode=config.shard_mode,
            rngs=rngs,
        )

      def prepare_query(self, low_rank_q: jnp.ndarray, inputs_positions: jnp.ndarray, apply_partial_rope) -> jnp.ndarray:
        bsz, seqlen, _ = low_rank_q.shape
        q = self.wq_b(low_rank_q)
        q = q.reshape(bsz, seqlen, self.n_heads, self.head_dim)
        q = apply_partial_rope(q, inputs_positions=inputs_positions)
        return q

      def compute_topk(
          self, q: jnp.ndarray, k: jnp.ndarray, inputs_q: jnp.ndarray, mask: Optional[jnp.ndarray]
      ) -> tuple[jnp.ndarray, jnp.ndarray]:
        logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision)
        logits = jax.nn.relu(logits)
        weights = self.weights_proj(inputs_q)
        weights = weights * self.weights_scaling * self.softmax_scale
        indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision)
        if mask is not None:
          indexer_score += mask
        _, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk)
        return topk_indices, indexer_score

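The scoring step in `compute_topk` can be tried in isolation with plain JAX arrays. The sketch below mirrors that logic under a few assumptions: `indexer_topk` is a hypothetical free function, the head weights are random stand-ins for the output of `weights_proj`, and the causal additive mask with `-1e9` is a placeholder for whatever mask the caller would pass.

```python
import jax
import jax.numpy as jnp


def indexer_topk(q, k, head_weights, mask, topk):
  """Hypothetical stand-alone version of the scoring in compute_topk."""
  # Per-head relevance of every key s to every query t, clipped at zero.
  logits = jax.nn.relu(jnp.einsum("bthd,bsd->btsh", q, k))
  # Weighted sum over heads gives one score per (query, key) pair.
  score = jnp.einsum("btsh,bth->bts", logits, head_weights)
  if mask is not None:
    score = score + mask
  _, indices = jax.lax.top_k(score, k=topk)
  return indices, score


B, T, S, H, D, TOPK = 1, 4, 4, 2, 8, 2  # toy sizes for illustration only
kq, kk, kw = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, H, D))
k = jax.random.normal(kk, (B, S, D))
# Stand-in for weights_proj(inputs_q) * weights_scaling * softmax_scale.
head_weights = jax.random.normal(kw, (B, T, H)) * (H**-0.5) * (D**-0.5)

# Causal additive mask: 0 where key position <= query position.
causal = jnp.tril(jnp.ones((T, S), dtype=bool))
mask = jnp.where(causal, 0.0, -1e9)[None]  # [1, T, S]

topk_indices, indexer_score = indexer_topk(q, k, head_weights, mask, TOPK)
```

Once a query has at least `TOPK` visible keys, every selected index is at or before the query position; earlier queries necessarily pick up masked slots, which is the situation a sentinel value is typically meant to handle.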

    # See the License for the specific language governing permissions and
    # limitations under the License.

    """Compressed Attention layers and long-range compressors."""

For the file name, maybe compressed_attention.py, as you mentioned in the comment here?
Maybe it is influenced by attention_mla.py?
…ry expansion parameters for Trillium parity
shuningjin left a comment
Thanks for the implementation! Here are my high-level comments:
- Please replace `nnx.Linear` with MaxText `DenseGeneral`.
- I recommend reusing some common logic from the Indexer of V3.2, perhaps by abstracting it into a base class.
- It is unclear to me how the "Additional Branch of Sliding Window Attention" from Section 2.3.3 of the paper is implemented. Could you explain?

    index_head_dim: 128
    index_n_heads: 64
    index_topk: 512

These can be removed. Please reuse existing indexer config:
maxtext/src/maxtext/configs/base.yml
Lines 379 to 384 in 9d79b99

    index_topk: 512
    o_groups: 8
    o_lora_rank: 1024
    sliding_window: 128

- Can you make the name `sliding_window` more specific? It can easily be confused with `sliding_window_size` for `attention_type=local_sliding`.
  maxtext/src/maxtext/configs/base.yml
  Lines 358 to 362 in 9d79b99
- Where are we using `config.sliding_window`? I didn't see it in code.

    MlaAttention,
    MoBa,
    AttentionIndexer,
    DeepSeekV4AttentionConfig,
    Llama4Attention,
    SplashAttention,
    PagedAttention,

nit: DeepSeekV4AttentionConfig -> DeepSeekV4Attention, to be consistent with the other classes in types.py.

    rope_theta = config.compress_rope_theta

    # Linear projection of inputs to key/value representation
    self.kv_proj = nnx.Linear(

All occurrences of nnx.Linear in the PR should be replaced with MaxText's DenseGeneral.
from maxtext.layers.linears import DenseGeneral
See these examples:
maxtext/src/maxtext/layers/attentions.py
Lines 655 to 668 in 4110d13
maxtext/src/maxtext/layers/attention_mla.py
Lines 716 to 743 in 4110d13

    self.attention_type = attention_type
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.sliding_window = config.sliding_window

Where is this used? Where is the "Additional Branch of Sliding Window Attention" from Sec 2.3.3?

    self.layer_idx = layer_idx
    self.layer_type = config.layer_types[layer_idx]
    # Sliding-only layers use the "main" (plain θ=10000) rope; CSA/HCA layers
    # share the same yarn-scaled "compress" rope as their compressor.

Where is yarn incorporated? I didn't see it in #3865

    # Projections for query extraction and low-rank compression
    self.q_a_proj = nnx.Linear(

nit: The query low rank logic seems identical to v3. Maybe make the name consistent with v3:
q_a_proj -> wq_a
q_b_proj -> wq_b

    if attention_mask is not None:
      # Pad 4D attention mask along trailing key/value sequence axis.
      # # [B, 1, Q, S + W] -> [B, 1, Q, align(S + W, sa_block_kv)]
      attention_mask = jnp.pad(attention_mask, ((0, 0), (0, 0), (0, 0), (0, pad_len)), constant_values=0.0)

Pad with -jnp.inf instead of 0? Or perhaps DEFAULT_MASK_VALUE:
maxtext/src/maxtext/common/common_types.py
Lines 72 to 74 in 4110d13
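The concern can be illustrated with a toy case, assuming the mask is additive (0.0 means "visible", a large negative means "hidden"); if the PR's mask is boolean or multiplicative, padding with 0 may already be correct. `MASK_VALUE` below is a placeholder large finite negative in the spirit of `DEFAULT_MASK_VALUE`, not a value copied from MaxText.

```python
import jax
import jax.numpy as jnp

# Assumption: a large finite negative stands in for DEFAULT_MASK_VALUE.
MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)

logits = jnp.zeros((1, 1, 2, 3))          # [B, 1, Q, S + W], three real keys
attention_mask = jnp.zeros((1, 1, 2, 3))  # additive mask, all keys visible
pad_len = 1
pad_width = ((0, 0), (0, 0), (0, 0), (0, pad_len))

padded_logits = jnp.pad(logits, pad_width)  # the padded key slot gets logit 0
mask_zero = jnp.pad(attention_mask, pad_width, constant_values=0.0)
mask_neg = jnp.pad(attention_mask, pad_width, constant_values=MASK_VALUE)

# Padding the mask with 0 leaves the padded slot visible to softmax.
weights_zero = jax.nn.softmax(padded_logits + mask_zero, axis=-1)
# Padding with a large negative removes it from the distribution.
weights_neg = jax.nn.softmax(padded_logits + mask_neg, axis=-1)
```

In this toy case the zero-padded mask gives the padding slot the same weight as a real key, while the negative padding drives it to zero. A finite value is often preferred over `-jnp.inf` because a row that ends up fully masked with `-inf` produces NaNs in softmax.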

    # Reconcile structural block bias masks with runtime attention masks.
    if attention_mask is not None:
      if block_bias is not None:

What happens if attention_mask is None and block_bias is not None?
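One way to make all four combinations explicit is a small helper that treats either input as optional. This is a sketch only: `combine_additive_masks` is a hypothetical name, both inputs are assumed to be additive masks with entries that are zero or negative, and the element-wise minimum is a design choice made here, not the PR's logic.

```python
from typing import Optional

import jax.numpy as jnp


def combine_additive_masks(
    attention_mask: Optional[jnp.ndarray], block_bias: Optional[jnp.ndarray]
) -> Optional[jnp.ndarray]:
  """Hypothetical helper: merges two optional additive masks (0 = visible)."""
  if attention_mask is None:
    return block_bias  # may itself be None; this is the case asked about
  if block_bias is None:
    return attention_mask
  # Element-wise minimum keeps the more restrictive entry and avoids the
  # overflow that summing two very large negative values could cause.
  return jnp.minimum(attention_mask, block_bias)


NEG = -1e9  # placeholder mask value for this sketch
block_bias = jnp.array([[0.0, NEG], [0.0, 0.0]])
attention_mask = jnp.array([[0.0, 0.0], [NEG, 0.0]])

only_bias = combine_additive_masks(None, block_bias)
both = combine_additive_masks(attention_mask, block_bias)
```

With this shape of helper, a missing runtime mask falls through to the block bias instead of silently dropping it.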
entrpn left a comment
LGTM. Please resolve other comments before merging.
Description
Implement compressed attention mechanisms and indexer modules required for DeepSeek-V4 integration into MaxText:
- `CSACompressor` & `HCACompressor`: Long-range attention compressors supporting causal block bias and YaRN frequency scaling decoupling.
- `LightningIndexer`: Memory-efficient indexer module implementing sentinel masking and dynamic RoPE scaling.
- Configuration: Register attention compression hyperparameters (`compress_ratios`, `index_head_dim`, `sliding_window`) in `types.py` and `base.yml`.
- Parity verification: Extended unit test suite (`tests/unit/deepseek_v4_vs_reference_test.py`) validating attention compression parity against PyTorch reference implementations at `atol=1e-5, rtol=1e-5`.

Tests
Tested on CPU
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.