[PyTorch] Use consistent API for fused norm kernels #1560

Merged
timmoon10 merged 13 commits into NVIDIA:main from timmoon10:debug-mxfp8-norms on Mar 22, 2025

@timmoon10 commented Mar 12, 2025

Description

There are multiple redundant code paths for suppressing fused norm kernels:

  • The tex norm functions check an envvar to suppress cuDNN MXFP8 norm kernels
  • The Python wrapper around the tex norm function checks an envvar to suppress cuDNN MXFP8 norm kernels
  • LayerNormLinear and LayerNormMLP disable FP8 norm kernels if FP8 current-scaling is enabled

This PR consolidates this logic into the tex functions.
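
To make the consolidation concrete, here is a minimal Python sketch of what a single decision point could look like. It is not the code from this PR: the real logic lives in the C++ `tex` norm functions, and `Recipe`, `cudnn_norm_enabled`, and `fuse_quantize_into_norm` are hypothetical names used only for illustration. The default value of the environment variable and the way it is parsed are also assumptions.

```python
import os
from enum import Enum


class Recipe(Enum):
    """Hypothetical stand-in for the quantization scheme of the norm output."""
    DELAYED_SCALING = "delayed"
    CURRENT_SCALING = "current"
    MXFP8 = "mxfp8"


def cudnn_norm_enabled(env=None):
    """Read NVTE_NORM_FWD_USE_CUDNN in exactly one place.

    Assumption: the variable is off unless explicitly set to "1".
    """
    env = os.environ if env is None else env
    return env.get("NVTE_NORM_FWD_USE_CUDNN", "0") == "1"


def fuse_quantize_into_norm(recipe, env=None):
    """Decide once whether the norm kernel should also emit quantized output.

    Callers (the Python wrapper and the module classes in this sketch)
    would call this instead of repeating their own checks.
    """
    if recipe is Recipe.CURRENT_SCALING:
        # Mirrors the behavior described above for current scaling:
        # the fused FP8 norm kernel is not used.
        return False
    if recipe is Recipe.MXFP8:
        # MXFP8 fusion is gated only by the cuDNN environment variable.
        return cudnn_norm_enabled(env)
    # Assumption: other recipes keep using the fused kernel.
    return True
```

With one function owning the decision, the wrapper and the `LayerNormLinear` / `LayerNormMLP` modules no longer need their own copies of the check, which is the redundancy this PR removes.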

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Remove redundant logic for suppressing cuDNN MXFP8 norm kernels
  • Control cuDNN MXFP8 norm kernels with NVTE_NORM_FWD_USE_CUDNN environment variable
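
For A/B comparisons it can be handy to toggle the variable from Python. The helper below is a hypothetical convenience, not part of Transformer Engine. It assumes that "1" enables and "0" disables the cuDNN path, and that the library reads the variable at call time rather than caching it at import; if either assumption does not hold, set the variable in the shell before launching the process instead.

```python
import os
from contextlib import contextmanager


@contextmanager
def norm_fwd_cudnn(enabled):
    """Temporarily set NVTE_NORM_FWD_USE_CUDNN, restoring the old value on exit.

    Hypothetical helper; "1"/"0" as on/off values is an assumption.
    """
    name = "NVTE_NORM_FWD_USE_CUDNN"
    previous = os.environ.get(name)
    os.environ[name] = "1" if enabled else "0"
    try:
        yield
    finally:
        if previous is None:
            # The variable was unset before, so unset it again.
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous
```

A forward pass wrapped in `with norm_fwd_cudnn(True):` would then run with the variable set, and the surrounding environment is left unchanged afterwards.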

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 added the bug label Mar 12, 2025

@timmoon10 (Member, Author)

/te-ci pytorch

@timmoon10 (Member, Author)

/te-ci pytorch

@timmoon10 marked this pull request as ready for review March 14, 2025 01:05
@timmoon10 requested a review from ksivaman March 14, 2025 01:15
@timmoon10 removed the bug label Mar 14, 2025
@timmoon10 changed the title from "[PyTorch] Debug MXFP8 norms" to "[PyTorch] Use consistent API for fused norm kernels" Mar 14, 2025
timmoon10 and others added 2 commits March 14, 2025 19:35

@timmoon10 (Member, Author)

/te-ci pytorch

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

@timmoon10 (Member, Author)

/te-ci pytorch

@timmoon10 merged commit e80fbd7 into NVIDIA:main Mar 22, 2025
@timmoon10 deleted the debug-mxfp8-norms branch March 24, 2025 20:47
KshitijLakhani pushed a commit that referenced this pull request Mar 26, 2025
* Do not suppress MXFP8 norm in Python wrapper func

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support FP8 current scaling in tex norm functions

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use single envvar to enable cuDNN MXFP8 norm kernels

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug compilation error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix full-tile requirement for MXFP8 norm kernels

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lhb8125 pushed a commit to lhb8125/TransformerEngine that referenced this pull request Apr 8, 2025