Add opt-in MXFP8 LM-head output projection#4825
Conversation
Introduces an opt-in TE-based LM-head ColumnParallelLinear that runs under the MXFP8 autocast context. Controlled by the new `fp8_output_proj` config flag; active only when `fp8=True` and `fp8_recipe='mxfp8'`. Defaults to off, so existing flows are unaffected. This is the main-branch equivalent of NVIDIA#4484 and NVIDIA#4489 (merged to `26.04-alpha`), ported as a self-contained module instead of layering on the alpha-only `LinearCrossEntropyModule` wrapper.
Covers:
* is_te_mxfp8_output_proj_active branches (pure-Python, no GPU)
* TELinearCrossEntropyModule constructor validations (early raises,
no GPU init required since each fires before super().__init__())
* GPTModel default uses ColumnParallelLinear
* GPTModel uses TELinearCrossEntropyModule when fp8_output_proj is
enabled under the mxfp8 recipe (Blackwell-only, skipped otherwise)
Reject fp8_output_proj=True when fp8 is off or fp8_recipe is not 'mxfp8' at config-construction time, so misconfiguration fails fast instead of only being caught by the runtime RuntimeError in TELinearCrossEntropyModule. The runtime check is retained as defense-in-depth against later mutation.
|
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click Ready for Review to begin the review process. This will:
See the contribution guide for more details. |
|
/claude review |
|
/ok to test c292e98 |
| # so GPTModel.sharded_state_dict's no-extra-state invariant still holds. | ||
| return None | ||
|
|
||
| def set_extra_state(self, state): |
There was a problem hiding this comment.
Add a warning here in case someone attempts to call set_extra_state?
|
/ok to test 859fd91 |
|
/ok to test 13976e4 |
I'll add support in a follow-up PR for this. There are a few other features that also need to be added to |
|
/ok to test 5bc6d74 |
|
/ok to test 2b624cc |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26600989893 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26603411168 |
What does this PR do ?
Adds an opt-in TE-based LM-head ColumnParallelLinear that runs under the MXFP8 autocast context. Controlled by the new
fp8_output_projconfig flag; active only whenfp8=Trueandfp8_recipe='mxfp8'. Defaults to off, so existing flows are unaffected.This is the main-branch equivalent of #4484 and #4489 (merged to
26.04-alpha), ported as a self-contained module instead of layering on the alpha-onlyLinearCrossEntropyModulewrapper.Changes
megatron/core/transformer/mxfp8_output_proj.pywithTELinearCrossEntropyModuleandis_te_mxfp8_output_proj_active.GPTModel.__init__conditionally swaps the output layer when the gate is on.TransformerConfig.fp8_output_projfield (defaultFalse).Tests
tests/unit_tests/transformer/test_mxfp8_output_proj.py:ColumnParallelLinear.TELinearCrossEntropyModuleunder mxfp8 (Blackwell-gated).Contribution process
Pre-checks