Implement QK-Clip (Muon-Clip) for MLA attention #2968
Conversation
📋 Review Summary
This pull request introduces QK-Clip, a training stabilization technique for MLA attention models. The implementation is well-structured, touching configuration, model layers, and the training loop as expected. The code is clear and includes comprehensive unit and integration tests, ensuring the feature is robust.
🔍 General Feedback
- The feature is gated by a configuration flag (`use_qk_clip`), which is good practice.
- The new utility file, `src/maxtext/utils/qk_clip_utils.py`, is well-written and isolates the core logic effectively.
- The test coverage is excellent, including checks for edge cases, consistency between attention mechanisms (dot-product vs. splash), and graceful handling of missing data.
RissyRan
left a comment
Thanks for the feature! I have a few comments. I recall you mentioned we need to run some convergence tests; are those sharding-related?
```python
# 1. Attempt to locate corresponding max_logits in intermediate_outputs
curr = intermediate_outputs
try:
```
Gemini suggested:
```python
def _get_key(node):
  """Unwraps JAX DictKey or SequenceKey to a raw string/int."""
  return getattr(node, "key", getattr(node, "idx", node))

path_keys = [_get_key(p) for p in path]
```
shuningjin
left a comment
Thanks for the implementation and careful testing! Overall looks good to me.
This PR correctly implements the QK-Clip algorithm as a mechanism to stabilize MLA attention training by managing excessively large attention logits. While the mathematical logic (including clipping heads based on global max values, scaling Query/Key projections, and correctly ignoring shared rotary keys/values) is elegantly implemented, there are a few critical issues with JAX compilation tracing and Linen state dict structures that would cause silent failures or runtime compilation errors if deployed in their current form.
🔍 General Feedback
- **Mathematical Correctness:** The approach faithfully reproduces the core clipping constraints from the paper and accurately computes both local and global $S_{max}$ using GSPMD-compatible JAX primitives.
- **Test Coverage:** Great addition of comprehensive unit tests, although relying exclusively on CPU execution for the tests meant the JAX compilation-specific `ConcretizationTypeError` was unfortunately masked. Consider ensuring test scenarios run with `@jax.jit` or equivalent test decorators to catch tracing issues.
- **State Traversal:** The logic applied to fetch `max_logits` effectively avoids relying on internal hard-coded layer names, which is a big plus for maintainability, but it just needs a small tweak to accommodate the `intermediates` nested root from Flax outputs.
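To illustrate the reviewer's point about eager CPU tests masking tracing failures, here is a minimal, self-contained sketch; the `TAU` constant and both function names are illustrative, not taken from the PR:

```python
import jax
import jax.numpy as jnp

TAU = 100.0  # illustrative clipping threshold, not the PR's config value


def clip_factor_buggy(s_max):
  # A Python `if` on a traced value works in eager mode but fails under jit.
  if s_max > TAU:
    return TAU / s_max
  return 1.0


def clip_factor_traceable(s_max):
  # gamma = min(1, tau / S_max), expressed with a traceable primitive.
  return jnp.minimum(1.0, TAU / s_max)


# Eager execution hides the bug:
assert clip_factor_buggy(jnp.float32(50.0)) == 1.0

# Compiling the same function surfaces it as a ConcretizationTypeError:
try:
  jax.jit(clip_factor_buggy)(jnp.float32(50.0))
except jax.errors.ConcretizationTypeError:
  print("tracing error caught under jit")
```

Running the test suite with the clipping path wrapped in `jax.jit` (as in a real `train_step`) would have exercised the tracing behavior and caught this class of bug.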
shuningjin
left a comment
Thanks for your thoughtful implementation!
Imported from GitHub PR #2968

# Description
Implements QK-Clip, a training stabilization technique for MLA attention models, as described in the Kimi K2 Technical Report.

**Changes:**
* **Core Logic:** Added `src/MaxText/utils/qk_clip_utils.py` containing `apply_qk_clip` and `calculate_max_logit_metric`.
* **Layers:** Updated `AttentionOp` to `sow` max logits statistics and `AttentionMLA` to enable this when configured.
* **Training:** Integrated the clipping step and `max_logits` metric reporting into `src/MaxText/train.py`.
* **Tests:** Added `tests/qk_clip_test.py`.

**Context:**
QK-Clip mitigates training instability by preventing attention logits from growing excessively. This implementation:
1. Calculates the global max logit ($S_{max}$) using GSPMD-compatible `jnp.max`.
2. Computes the per-head scaling factor $\gamma = \min(1, \tau / S_{max})$.
3. Scales $W_q$ and $W_k$ while explicitly leaving shared rotary keys ($k^R$) and values ($W_v$) untouched.
4. Leverages Flax `sow` to pass statistics from layers to the training loop efficiently.

# Tests
* **Unit Tests:** Ran `python3 tests/qk_clip_test.py`. Verified:
  * Correct scaling of $W_q$ and $W_k$.
  * Heads below threshold are not clipped.
  * Shared keys and values remain untouched.
  * Global `max_logits` metric calculation.
  * Error handling for non-MLA attention types.
* **Integration:** Verified `train_step` executes without shape mismatches or runtime errors.

# Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files).

Copybara import of the project:

-- 0db9d22 by Gagik Amirkhanyan <agagik@google.com>:

Implement QK-Clip (Muon-Clip) functionality

add tests for QK-Clip logic

Merging this change closes #2968

COPYBARA_INTEGRATE_REVIEW=#2968 from AI-Hypercomputer:agagik-qk-clip 0db9d22

PiperOrigin-RevId: 874946094
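As a rough illustration of the clipping math in the description above: this is a hand-written sketch, not the PR's `apply_qk_clip`; the shapes, names, default `tau`, and the even sqrt-split of $\gamma$ between $W_q$ and $W_k$ are all assumptions.

```python
import jax.numpy as jnp


def qk_clip_sketch(w_q, w_k, s_max, tau=100.0):
  """Scales per-head query/key projections so max logits stay <= tau.

  w_q, w_k: [emb, num_heads, head_dim]; s_max: [num_heads] max attention
  logit observed per head. Splitting gamma evenly (sqrt on each side)
  scales the q.k product by exactly gamma.
  """
  gamma = jnp.minimum(1.0, tau / s_max)    # gamma = min(1, tau / S_max)
  scale = jnp.sqrt(gamma)[None, :, None]   # broadcast over emb / head_dim
  return w_q * scale, w_k * scale


w_q = jnp.ones((4, 2, 3))
w_k = jnp.ones((4, 2, 3))
s_max = jnp.array([50.0, 400.0])           # head 0 below tau, head 1 above
new_q, new_k = qk_clip_sketch(w_q, w_k, s_max)
```

Head 0 ($S_{max} \le \tau$) is left untouched since $\gamma = 1$; head 1 is scaled by $\sqrt{100/400} = 0.5$ on each side, so its logits shrink by a factor of 0.25 and the new per-head max is exactly $\tau$.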