[PyTorch] Add support for FP8 current scaling in operation-based API#1858

Merged
ksivaman merged 11 commits into NVIDIA:main from timmoon10:te-sequential-new-recipes
Jun 13, 2025

Conversation

@timmoon10
Collaborator

Description

This PR makes some minor changes needed for te.Sequential to support FP8 current scaling. It also modifies the te.Sequential unit tests to more gracefully handle quantization schemes other than FP8 delayed scaling.
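
For context on the two recipes: delayed scaling derives the FP8 scale factor from a history of past amax values, while current scaling computes it from the amax of the tensor being quantized right now. A minimal plain-Python sketch of that difference (illustrative only, not Transformer Engine code; 448.0 is the FP8 E4M3 max representable magnitude):

```python
# Illustrative sketch of the two FP8 per-tensor scaling recipes.
# Function names and structure are hypothetical, not Transformer Engine's.

FP8_E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def current_scaling_scale(tensor):
    """Current scaling: derive the scale from this tensor's own amax."""
    amax = max(abs(x) for x in tensor)
    return FP8_E4M3_MAX / amax if amax > 0 else 1.0

def delayed_scaling_scale(amax_history):
    """Delayed scaling: derive the scale from a history of past amaxes."""
    amax = max(amax_history)  # e.g. the "max" amax_compute_algo
    return FP8_E4M3_MAX / amax if amax > 0 else 1.0

print(current_scaling_scale([0.5, -2.0, 1.25]))   # 448.0 / 2.0 = 224.0
print(delayed_scaling_scale([1.0, 4.0, 2.0]))     # 448.0 / 4.0 = 112.0
```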

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Testing

Changes

  • Support FP8 current scaling in operation-based API
  • Generalize support for quantizers in tests for operation-based API

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 requested review from ksivaman and ptrendx June 6, 2025 21:49
@timmoon10 timmoon10 added enhancement New feature or request testing Improvements to tests or testing infrastructure labels Jun 6, 2025
@timmoon10
Collaborator Author

/te-ci pytorch L1

Comment on lines +211 to +213
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
Member

Suggested change
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
if input_quantizer._get_compatible_recipe().float8_per_tensor_scaling():

Collaborator Author

@timmoon10 timmoon10 Jun 12, 2025

I think this fix adds overhead and makes the code less clear. _get_compatible_recipe (added in #1724) returns a type since it's used for type comparisons. If we want to use float8_per_tensor_scaling, then we need to instantiate a new DelayedScaling or Float8CurrentScaling that doesn't match the current recipe. Alternatively, we can compare classes with is or isinstance, which is basically what we're already doing.
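
The trade-off being discussed can be sketched with dummy stand-in classes (hypothetical sketch, not Transformer Engine's actual class hierarchy): a direct `isinstance` check against the two per-tensor-scaling quantizer classes answers the question without instantiating a recipe object just to query a flag.

```python
# Dummy stand-ins for the quantizer classes under discussion.

class Quantizer: ...
class Float8Quantizer(Quantizer): ...               # delayed scaling
class Float8CurrentScalingQuantizer(Quantizer): ... # current scaling
class MXFP8Quantizer(Quantizer): ...                # block scaling

def is_per_tensor_fp8(quantizer):
    # The approach kept in the PR: a direct class check, with no need
    # to construct a recipe instance just to call a feature-flag method.
    return isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    )

print(is_per_tensor_fp8(Float8CurrentScalingQuantizer()))  # True
print(is_per_tensor_fp8(MXFP8Quantizer()))                 # False
```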

Comment on lines +356 to +358
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
Member

Suggested change
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
if (
    weight_quantizer._get_compatible_recipe().float8_per_tensor_scaling()
    and isinstance(weight, Float8TensorBase)
):

Collaborator Author

I don't think _get_compatible_recipe benefits us here.


# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
Member

Note: This being required here is unfortunate. Most of what I would suggest here is probably out of scope for this PR.

Collaborator Author

I don't think there's a good way around this. The amax reduction logic is specific to linear layers with FP8 current scaling, so there's no other logical place to put it.
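
The reduction in question can be sketched abstractly: with current scaling, every data-parallel rank must end up with the same scale factor, so the per-rank amaxes are max-reduced before the scale is computed. A plain-Python sketch, simulating the all-reduce with a max over a list of per-rank shards (illustrative only, not the actual distributed implementation):

```python
FP8_E4M3_MAX = 448.0  # FP8 E4M3 max representable magnitude

def local_amax(tensor_shard):
    """Per-rank amax of a local tensor shard."""
    return max((abs(x) for x in tensor_shard), default=0.0)

def reduced_scale(shards):
    # Simulate an amax all-reduce across data-parallel ranks: every
    # rank contributes its local amax, and all ranks adopt the max,
    # so they all compute the identical scale.
    global_amax = max(local_amax(s) for s in shards)
    return FP8_E4M3_MAX / global_amax if global_amax > 0 else 1.0

shards = [[0.1, -0.5], [3.5, 0.2], [-1.0]]  # one list per "rank"
print(reduced_scale(shards))  # 448.0 / 3.5 = 128.0
```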

elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
Member

Orthogonal, but why does this device arg exist here?

Collaborator Author

@timmoon10 timmoon10 Jun 9, 2025

Mostly for completeness. We have a ref_device option since I prefer computing the reference impl on CPU, which helps catch CUDA-related bugs.


# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
Member

Not a fan of this, but I see we currently do this for other tests as well, so it's OK for this PR. In a separate refactor we should convert the tests dir to also be a module.

Collaborator Author

I agree. Ideally we would use relative imports like from .utils import dtype_tols, but Python doesn't allow that if you are running the script directly, which is how Pytest runs the tests.

@timmoon10 timmoon10 force-pushed the te-sequential-new-recipes branch from 5e2490c to 0496d1f Compare June 12, 2025 01:11
@timmoon10
Collaborator Author

/te-ci pytorch L1

ksivaman
ksivaman previously approved these changes Jun 12, 2025
Member

@ksivaman ksivaman left a comment

LGTM

@timmoon10
Collaborator Author

/te-ci pytorch L1

@ksivaman ksivaman merged commit e963e4a into NVIDIA:main Jun 13, 2025
25 of 27 checks passed