
Implement QK-Clip (Muon-Clip) for MLA attention #2968

Closed
gagika wants to merge 1 commit into main from agagik-qk-clip

Conversation

Collaborator

@gagika gagika commented Jan 20, 2026

Description

Implements QK-Clip, a training stabilization technique for MLA attention models, as described in the Kimi K2 Technical Report.

Changes:

  • Core Logic: Added src/MaxText/utils/qk_clip_utils.py containing apply_qk_clip and calculate_max_logit_metric.
  • Layers: Updated AttentionOp to sow max logits statistics and AttentionMLA to enable this when configured (a minimal sow sketch follows this list).
  • Training: Integrated the clipping step and max_logits metric reporting into src/MaxText/train.py.
  • Tests: Added tests/qk_clip_test.py.
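
For context on the sow plumbing mentioned above, here is a minimal sketch of how a layer can record a max-logit statistic into the standard Flax "intermediates" collection and how the caller can read it back. The module and variable names are illustrative, not the PR's actual AttentionOp code.

# Minimal sketch of the sow / mutable round trip; TinyAttn is illustrative.
import jax.numpy as jnp
from flax import linen as nn

class TinyAttn(nn.Module):
  @nn.compact
  def __call__(self, logits):
    # Record the per-call max attention logit into the "intermediates" collection.
    self.sow("intermediates", "max_logits", jnp.max(logits))
    return logits

model = TinyAttn()
logits = jnp.arange(6.0).reshape(2, 3)
out, state = model.apply({}, logits, mutable=["intermediates"])
# sow appends into a tuple under the "intermediates" root:
# state == {'intermediates': {'max_logits': (Array(5., dtype=float32),)}}
print(state["intermediates"]["max_logits"][0])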

Context:
QK-Clip mitigates training instability by preventing attention logits from growing excessively. This implementation:

  1. Calculates global max logit ($S_{max}$) using GSPMD-compatible jnp.max.
  2. Computes per-head scaling factor $\gamma = \min(1, \tau / S_{max})$ (see the code sketch after this list).
  3. Scales $W_q$ and $W_k$ while explicitly leaving shared rotary keys ($k^R$) and values ($W_v$) untouched.
  4. Leverages Flax sow to pass statistics from layers to the training loop efficiently.
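
A minimal sketch of the clipping math in steps 1-3 follows. It is illustrative only: the function name, assumed weight shapes, and the square-root split of $\gamma$ between $W_q$ and $W_k$ are assumptions, and the real apply_qk_clip additionally routes around the shared rotary key and value weights.

# Illustrative sketch of steps 1-3 above, not the PR's apply_qk_clip.
# Assumed shapes: w_q, w_k -> [emb, num_heads, head_dim]; max_logits -> [num_heads].
import jax.numpy as jnp

def qk_clip_sketch(w_q, w_k, max_logits, tau):
  # Per-head gamma = min(1, tau / S_max); heads under the threshold get gamma == 1.
  gamma = jnp.minimum(1.0, tau / max_logits)
  # Splitting gamma as sqrt(gamma) on each projection keeps q·k scaled by gamma overall
  # (an assumption here; the paper treats shared and per-head components differently).
  scale = jnp.sqrt(gamma)[None, :, None]
  return w_q * scale, w_k * scale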

Tests

  • Unit Tests: Ran python3 tests/qk_clip_test.py (a hedged test sketch follows this list). Verified:
    • Correct scaling of $W_q$ and $W_k$.
    • Heads below threshold are not clipped.
    • Shared keys and values remain untouched.
    • Global max_logits metric calculation.
    • Error handling for non-MLA attention types.
  • Integration: Verified train_step executes without shape mismatches or runtime errors.
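
A hedged sketch of the "heads below threshold are not clipped" check, written against qk_clip_sketch above rather than the real tests/qk_clip_test.py:

# Checks that heads whose max logit stays under tau keep their weights bit-exact,
# while clipped heads shrink. Uses qk_clip_sketch from the sketch above.
import numpy as np

def test_heads_below_threshold_unchanged():
  rng = np.random.default_rng(0)
  w_q = rng.normal(size=(8, 4, 16)).astype(np.float32)   # [emb, heads, head_dim]
  w_k = rng.normal(size=(8, 4, 16)).astype(np.float32)
  tau = 100.0
  max_logits = np.array([50.0, 120.0, 99.0, 200.0], dtype=np.float32)

  new_q, new_k = qk_clip_sketch(w_q, w_k, max_logits, tau)

  below = max_logits <= tau                               # heads 0 and 2
  np.testing.assert_array_equal(np.asarray(new_q)[:, below, :], w_q[:, below, :])
  np.testing.assert_array_equal(np.asarray(new_k)[:, below, :], w_k[:, below, :])
  # Clipped heads must not grow in magnitude.
  assert np.all(np.abs(np.asarray(new_q)[:, ~below, :]) <= np.abs(w_q[:, ~below, :]))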

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@gagika gagika force-pushed the agagik-qk-clip branch 9 times, most recently from 3d75b2c to 522ef04 on February 1, 2026 17:11

codecov Bot commented Feb 1, 2026

Codecov Report

❌ Patch coverage is 66.93548% with 41 lines in your changes missing coverage. Please review.

Files with missing lines                   Patch %   Lines
src/maxtext/layers/attention_op.py         60.00%    19 Missing and 3 partials ⚠️
src/maxtext/utils/qk_clip_utils.py         77.04%    7 Missing and 7 partials ⚠️
src/maxtext/trainers/pre_train/train.py    16.66%    4 Missing and 1 partial ⚠️



github-actions Bot commented Feb 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

This pull request introduces QK-Clip, a training stabilization technique for MLA attention models. The implementation is well-structured, touching configuration, model layers, and the training loop as expected. The code is clear and includes comprehensive unit and integration tests, ensuring the feature is robust.

🔍 General Feedback

  • The feature is gated by a configuration flag (use_qk_clip), which is good practice.
  • The new utility file, src/maxtext/utils/qk_clip_utils.py, is well-written and isolates the core logic effectively.
  • The test coverage is excellent, including checks for edge cases, consistency between attention mechanisms (dot-product vs. splash), and graceful handling of missing data.


github-actions Bot commented Feb 5, 2026

🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.

Collaborator

@RissyRan RissyRan left a comment


Thanks for the feature! Have some comments. I recall you mentioned we need to run some convergence tests. Wondering if those are sharding related?

Comment thread src/maxtext/configs/base.yml
Comment thread src/maxtext/layers/attention_mla.py
Comment thread src/maxtext/layers/attention_op.py
Comment thread src/MaxText/layers/attention_op.py Outdated
Comment thread src/MaxText/layers/attention_op.py Outdated
Comment thread src/maxtext/layers/attention_op.py Outdated
Comment thread tests/unit/qk_clip_test.py
Comment thread src/maxtext/utils/qk_clip_utils.py Outdated
Comment thread src/maxtext/utils/qk_clip_utils.py Outdated

# 1. Attempt to locate corresponding max_logits in intermediate_outputs
curr = intermediate_outputs
try:

Gemini suggested:

def _get_key(node):
    """Unwraps JAX DictKey or SequenceKey to a raw string/int."""
    return getattr(node, "key", getattr(node, "idx", node))

path_keys = [_get_key(p) for p in path]
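
For reference, the path being unwrapped here is the kind of key path produced by jax.tree_util.tree_flatten_with_path (or tree_map_with_path); a small hedged illustration using the _get_key helper above, with an invented example tree:

# Illustration only: key paths from tree_flatten_with_path contain DictKey /
# SequenceKey / GetAttrKey nodes, which _get_key reduces to plain strings/ints.
import jax

tree = {"intermediates": {"decoder": [{"max_logits": 3.0}]}}
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    path_keys = [_get_key(p) for p in path]
    print(path_keys, leaf)   # ['intermediates', 'decoder', 0, 'max_logits'] 3.0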

Comment thread src/maxtext/utils/qk_clip_utils.py Outdated
Collaborator

@shuningjin shuningjin left a comment


Thanks for the implementation and careful testing! Overall looks good to me.

Comment thread src/maxtext/utils/qk_clip_utils.py Outdated
Comment thread tests/unit/qk_clip_test.py
Comment thread tests/utils/qk_clip_test.py Outdated
Comment thread tests/unit/qk_clip_test.py
Comment thread src/maxtext/layers/attention_mla.py
Comment thread src/maxtext/layers/attention_mla.py
Comment thread src/maxtext/layers/attention_op.py
@gagika gagika force-pushed the agagik-qk-clip branch 5 times, most recently from 607d146 to f3873ff on February 23, 2026 06:34
@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

This PR correctly implements the QK-Clip algorithm as a mechanism to stabilize MLA attention training by managing excessively large attention logits. While the mathematical logic (including clipping heads based on global max values, scaling Query/Key projections, and correctly ignoring shared rotary keys/values) is elegantly implemented, there are a few critical issues with JAX compilation tracing and Linen state dict structures that would cause silent failures or runtime compilation errors if deployed in their current form.

🔍 General Feedback

  • Mathematical Correctness: The approach faithfully reproduces the core clipping constraints from the paper and accurately computes both local and global $S_{max}$ using GSPMD-compatible JAX primitives.
  • Test Coverage: Great addition of comprehensive unit tests, although relying exclusively on CPU execution for the tests meant the JAX compilation-specific ConcretizationTypeError was unfortunately masked. Consider ensuring test scenarios run with @jax.jit or equivalent test decorators to catch tracing issues (see the sketch after this list).
  • State Traversal: The logic applied to fetch max_logits effectively avoids relying on internal hard-coded layer names, which is a big plus for maintainability, but it just needs a small tweak to accommodate the intermediates nested root from Flax outputs.
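
On the @jax.jit point above, a hedged sketch of the kind of check that surfaces tracing bugs in unit tests (illustrative names, not the PR's test file): a data-dependent Python if on a traced value raises ConcretizationTypeError under jit, whereas jnp.minimum and jnp.where stay traceable.

# Illustrative only: running the clip math under jax.jit in a test catches
# ConcretizationTypeError from any Python-level branching on traced values.
import jax
import jax.numpy as jnp

def clip_scale(max_logits, tau):
  # Trace-friendly: no Python `if max_logits > tau:` on traced arrays.
  return jnp.sqrt(jnp.minimum(1.0, tau / max_logits))

def test_clip_scale_is_jittable():
  max_logits = jnp.array([50.0, 120.0, 99.0, 200.0])
  scale = jax.jit(clip_scale)(max_logits, 100.0)   # would fail here if the fn branched on values
  assert scale.shape == max_logits.shape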

Comment thread src/maxtext/utils/qk_clip_utils.py Outdated
Comment thread src/maxtext/layers/attention_op.py
Comment thread src/maxtext/layers/attention_op.py
Comment thread src/maxtext/layers/attention_op.py
@gagika gagika force-pushed the agagik-qk-clip branch 4 times, most recently from 135a9cb to ab896ea on February 23, 2026 21:30
Collaborator

@shuningjin shuningjin left a comment


Thanks for your thoughtful implementation!

Comment thread src/maxtext/layers/attention_op.py
@gagika gagika force-pushed the agagik-qk-clip branch 3 times, most recently from 2321811 to 0db9d22 on February 24, 2026 01:54
Collaborator

@RissyRan RissyRan left a comment


LGTM, thanks!

Comment thread src/maxtext/layers/attention_mla.py
Shuwen-Fang pushed a commit that referenced this pull request Mar 13, 2026
Imported from GitHub PR #2968


Copybara import of the project:

--
0db9d22 by Gagik Amirkhanyan <agagik@google.com>:

Implement QK-Clip (Muon-Clip) functionality and add tests for QK-Clip logic

Merging this change closes #2968

COPYBARA_INTEGRATE_REVIEW=#2968 from AI-Hypercomputer:agagik-qk-clip 0db9d22
PiperOrigin-RevId: 874946094
