Triton RMSNorm Optimizations #593

Merged
Micky774 merged 11 commits into dev from zain/rms-opt
Jun 2, 2026

Micky774 (Contributor) commented May 20, 2026

Description

Optimizes the Triton RMSNorm forward and backward kernels and adds an LDS-tiled FP8 transpose path. Across a representative shape sweep, measured improvements are 10%-50% for bf16 (both with no quantization and with FP8 quantization) and 3x-8x for FP8 transposed outputs.

Benchmarks generated by this script.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Loop-invariant hoisting.
    • Fwd non-blocked path: gamma load + ZERO_CENTERED_GAMMA adjustment + 1/n_cols hoisted outside the persistent row loop.
    • Bwd non-blocked path: same gamma hoist; inv_n_cols hoisted.
    • Bwd both paths: per-row c_scalar = nf*nf*grad_sum*inv_n_cols computed once before the dx/dg loop; dx expression refactored to nf * (dz*g - c*x) (saves one multiply per element).
  • Autotune wiring for bwd kernels. _rmsnorm_bwd_triton and _rmsnorm_bwd_dg_reduce_triton now follow the impl + autotune-wrapper dispatch pattern already used by the fwd kernel. te_rmsnorm_bwd_triton takes an autotune: bool = True kwarg; when off it uses the previously-hardcoded num_warps=8 + fixed BLOCK_SIZE_M/N=128/64 reduce tile.
  • External LDS-tiled FP8 transpose kernel. New _fp8_transpose_2d_impl (+ autotune wrapper) replaces the in-kernel out_transpose_ptr + cols * stride + row_idx strided byte stores that were uncoalesced (one byte per thread to a different cache line). The new kernel does a coalesced (BLOCK_M, BLOCK_N) read, tl.trans() for LDS-staged transpose, then coalesced strided write.
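
To make the dx refactor in the backward bullet concrete, here is a minimal pure-Python sketch for a single row. It is not the Triton kernel: the function names are hypothetical, `g` is assumed to already include any ZERO_CENTERED_GAMMA adjustment, `grad_sum` is assumed to mean `sum(dz * g * x)` over the row, and the placement and value of `eps` are assumptions as well. It only shows that `nf * (dz*g - c*x)` with `c = nf*nf*grad_sum*inv_n_cols` matches the unfactored gradient.

```python
import math

def rmsnorm_bwd_row_reference(x, g, dz, eps=1e-5):
    """Unfactored form: dx = nf*dz*g - nf^3 * x * sum(dz*g*x) / n."""
    n = len(x)
    nf = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    grad_sum = sum(d * w * v for d, w, v in zip(dz, g, x))  # assumed definition
    return [nf * d * w - nf * nf * nf * grad_sum / n * v
            for d, w, v in zip(dz, g, x)]

def rmsnorm_bwd_row_hoisted(x, g, dz, eps=1e-5):
    """Refactored form: per-row scalar c computed once, then nf * (dz*g - c*x)."""
    inv_n_cols = 1.0 / len(x)  # loop-invariant, computed once
    nf = 1.0 / math.sqrt(sum(v * v for v in x) * inv_n_cols + eps)
    grad_sum = sum(d * w * v for d, w, v in zip(dz, g, x))
    c = nf * nf * grad_sum * inv_n_cols  # per-row scalar, outside the element loop
    return [nf * (d * w - c * v) for d, w, v in zip(dz, g, x)]

# Illustrative inputs only; not taken from the benchmark sweep.
x = [0.5, -1.25, 2.0, 0.75]
g = [1.0, 0.9, 1.1, 1.2]
dz = [0.1, -0.2, 0.3, 0.05]
ref = rmsnorm_bwd_row_reference(x, g, dz)
opt = rmsnorm_bwd_row_hoisted(x, g, dz)
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(ref, opt))
```

The two forms agree up to floating-point rounding; the refactored one keeps the row-level factors in a single scalar so the per-element loop does less work.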

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Micky774 Micky774 added the ci-level 3 label May 21, 2026
@Micky774 Micky774 requested a review from alextmagro May 29, 2026 17:13
alextmagro (Contributor) left a comment

LGTM!

@Micky774 Micky774 requested a review from aris134 May 29, 2026 19:17
aris134 (Contributor) left a comment

LGTM!

Comment thread README.rst Outdated
* NVTE_USE_CAST_TRANSPOSE_TRITON=1 can be used to enable cast transpose (bgrad) triton kernels;
* NVTE_USE_LAYERNORM_TRITON=1 can be used to enable layernorm triton kernels.
* NVTE_USE_RMSNORM_TRITON=1 can be used to enable rmsnorm triton kernels.
* NVTE_RMS_EXTERNAL_TRANSPOSE=0 disables external transpose in RMSNorm Triton kernels and
Collaborator

It is not used in code

Collaborator

@ipanfilo Could you check if the comments have been addressed?

Contributor Author

Updated!

@wenchenvincent (Collaborator)

@alextmagro @aris134 I see you have approved the PR. For the inline comments, let's also resolve the conversations if the comments have been addressed.

@Micky774 Micky774 requested a review from ipanfilo June 2, 2026 17:12
@Micky774 Micky774 merged commit 4bfe12d into dev Jun 2, 2026
6 of 9 checks passed
@Micky774 Micky774 deleted the zain/rms-opt branch June 2, 2026 20:05