Skip to content

[PyTorch] [torch.compile] transformer_engine.pytorch.autocast suport inside torch.compile#2759

Merged
pggPL merged 19 commits into
NVIDIA:mainfrom
pggPL:torch_compile_autocast
Apr 15, 2026
Merged

[PyTorch] [torch.compile] transformer_engine.pytorch.autocast suport inside torch.compile#2759
pggPL merged 19 commits into
NVIDIA:mainfrom
pggPL:torch_compile_autocast

Conversation

@pggPL
Copy link
Copy Markdown
Collaborator

@pggPL pggPL commented Mar 13, 2026

Description

Enable torch.compile(fullgraph=True) for FP8 autocast by moving compile-visible mutable state off class attributes, avoiding tracing through support checks, and adding test.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Move mutable FP8 autocast state from direct cls attribute writes to a dataclass-backed singleton object, because torch.compile does not support writes directly to class attributes.
  • Replace lru_cache-based support checks with explicit module-level caches and mark the wrapper functions with @torch.compiler.assume_constant_result so torch.compile does not trace into check_*_support().
  • Add torch.compile coverage for FP8 autocast using a custom test module; the test is more involved because there is currently no simple TE layer that supports both FP8 and torch.compile.
  • Make DelayedScaling explicitly unsupported under torch.compile and raise a clear error for that case.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 3 commits March 13, 2026 14:02
Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the torch_compile_autocast branch from 338ddae to b5d46fd Compare March 13, 2026 13:03
pggPL and others added 4 commits March 13, 2026 14:24
Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Resolve conflicts in the FP8 torch.compile changes while preserving the upstream updates in graph.py and module/base.py.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review March 13, 2026 14:09
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Mar 13, 2026

Greptile Summary

This PR enables torch.compile(fullgraph=True) for FP8 autocast by migrating mutable global state from class-level attributes to a FP8GlobalState dataclass instance (quantization_state) and replacing lru_cache-based support checks with @torch.compiler.assume_constant_result-decorated wrappers. The mechanical refactor across modules is clean and consistent, and the ToyLinear + opaque-type test infrastructure is a thoughtful solution to the difficulty of testing FP8 + compile together.

Several P1 concerns surfaced in prior review rounds remain open — most critically, autocast_enter unconditionally calls get_default_fp8_recipe() when no recipe is supplied (even when enabled=False), which will AssertionError under torch.compile; reset() swaps the singleton object rather than mutating it in-place, invalidating compiled-graph guards on quantization_state; and the null-guard on skip_fp8_weight_update_tensor at graph replay time is still missing in graph.py.

Confidence Score: 4/5

  • The core state-migration refactor is correct and the torch.compile guard approach is sound, but several P1 issues from prior review rounds remain unaddressed and should be fixed before merge.
  • Score reflects that the mechanical refactor across 10 files is clean and consistent, but open issues from prior rounds (unconditional get_default_fp8_recipe() crash path, reset() breaking compiled-graph guards, missing null guard in graph replay, assert strippable under -O) are confirmed present in the submitted code and represent real defect paths in the compile-enabled flow this PR is trying to unlock.
  • transformer_engine/pytorch/quantization.py (autocast_enter recipe path, reset(), fp8_graph_capturing assert) and transformer_engine/pytorch/graph.py (null guard on skip_fp8_weight_update_tensor at line 841)

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantization.py Core of the PR: replaces class-level attribute writes with a FP8GlobalState dataclass singleton and adds @torch.compiler.assume_constant_result to support checks. Several open P1 concerns from prior review rounds remain unresolved (unconditional get_default_fp8_recipe() in autocast_enter, reset() swapping the singleton object, assert strippable under -O).
tests/pytorch/test_torch_compile.py New test file with sophisticated ToyLinear + opaque-type infrastructure for testing torch.compile; however test_autocast_sanity exercises only torch.nn.Linear (not TE ops), no DelayedScaling-under-compile error test, and private torch._opaque_base/torch._library.opaque_object APIs are used without stability guards.
transformer_engine/pytorch/graph.py Moves skip_fp8_weight_update_tensor initialization inline with a null check at setup time; the inner call site on line 841 still has no null guard and will crash with AttributeError if reset() is called between graph capture and replay.
transformer_engine/pytorch/module/base.py Correctly redirects amax buffer writes to quantization_state; update order of global_amax_history_buffer and global_amax_buffer is reversed compared to ops/op.py (cosmetic inconsistency, not a correctness bug).

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class FP8GlobalState {
        +bool fp8_enabled
        +bool fp8_calibration
        +Recipe fp8_recipe
        +dist_group_type fp8_distributed_group
        +bool fp8_parameters
        +bool high_precision_init_val
        +bool is_first_fp8_module
        +bool fp8_graph_capturing
        +int autocast_depth
        +Dict global_amax_buffer
        +Dict global_amax_history_buffer
        +Dict global_scale_buffer
        +list fp8_tensors_recompute_buffer
        +Dict autocast_arguments
        +Tensor skip_fp8_weight_update_tensor
    }

    class FP8GlobalStateManager {
        +FP8GlobalState quantization_state
        +reset() void
        +autocast_enter() void
        +autocast_exit() void
        +is_fp8_enabled() bool
        +fp8_graph_capturing() bool
        +is_first_fp8_module() bool
        +get_fp8_recipe() Recipe
        +get_autocast_state() tuple
        +set_autocast_state() void
        +reduce_and_update_fp8_tensors() void
    }

    class ModuleLevelCaches {
        +_FP8_SUPPORT: Optional[Tuple]
        +_MXFP8_SUPPORT: Optional[Tuple]
        +_NVFP4_SUPPORT: Optional[Tuple]
        +_FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple]
    }

    class SupportChecks {
        +check_fp8_support() [assume_constant_result]
        +check_mxfp8_support() [assume_constant_result]
        +check_nvfp4_support() [assume_constant_result]
        +check_fp8_block_scaling_support() [assume_constant_result]
    }

    FP8GlobalStateManager --> FP8GlobalState : quantization_state (singleton)
    SupportChecks --> ModuleLevelCaches : reads/writes
    FP8GlobalStateManager --> SupportChecks : calls
Loading

Reviews (8): Last reviewed commit: "Merge branch 'main' into torch_compile_a..." | Re-trigger Greptile

Comment thread tests/pytorch/test_torch_compile.py Outdated
pggPL and others added 2 commits March 23, 2026 12:12
Replace custom-op-based ToyLinear with a minimal version using F.linear.
Add test_autocast_sanity (parametrized over all recipes including NVFP4)
and test_autocast_nested_sanity with CustomRecipes. Both verify
fullgraph=True compilation without graph breaks.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Comment thread transformer_engine/pytorch/quantization.py
Comment thread tests/pytorch/test_torch_compile.py Outdated
Comment thread transformer_engine/pytorch/quantization.py
Comment thread transformer_engine/pytorch/quantization.py Outdated
Verify that te.autocast(recipe=DelayedScaling(), enabled=True) raises
a clear RuntimeError when used inside torch.compile.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Comment thread tests/pytorch/test_torch_compile.py Outdated
Comment thread transformer_engine/pytorch/quantization.py
pggPL and others added 2 commits March 23, 2026 12:46
Use str(recipe) for content-based recipe keying (avoids unbounded growth
when identical recipes are constructed inline) and id(group) for process
group identity (same semantics as the old hash(group) which was id-based).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Comment thread transformer_engine/pytorch/quantization.py
Comment thread transformer_engine/pytorch/module/base.py Outdated
Comment thread transformer_engine/pytorch/quantization.py Outdated
@pggPL
Copy link
Copy Markdown
Collaborator Author

pggPL commented Mar 23, 2026

/te-ci pytorch L1

pggPL and others added 2 commits March 25, 2026 17:46
Replace custom_op-based approach with torch.library.define/impl/register_fake
using get_opaque_type_name() in the schema, which allows Inductor to properly
handle opaque value types. Add ToyQuantizer as an opaque value-type wrapper
around Float8CurrentScalingQuantizer with proper __eq__/__hash__/__fx_repr__.

test_autocast_nested_custom validates that nested te.autocast with 3 distinct
CustomRecipe instances passes the correct quantizers in both forward and backward.
test_autocast_sanity is a smoke test for all hardware-supported built-in recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Comment thread tests/pytorch/test_torch_compile.py Outdated
Comment thread tests/pytorch/test_torch_compile.py
@pggPL
Copy link
Copy Markdown
Collaborator Author

pggPL commented Mar 30, 2026

/te-ci pytorch

@pggPL pggPL changed the title [PyTorch] transformer_engine.pytorch.autocast suport inside torch.compile [PyTorch] [torch.compile] transformer_engine.pytorch.autocast suport inside torch.compile Mar 30, 2026
@pggPL
Copy link
Copy Markdown
Collaborator Author

pggPL commented Apr 13, 2026

/te-ci pytorch

Comment thread transformer_engine/pytorch/graph.py
Comment thread tests/pytorch/test_torch_compile.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_mlp.py Outdated
Comment thread transformer_engine/pytorch/quantization.py Outdated
Comment thread transformer_engine/pytorch/quantization.py
Comment thread transformer_engine/pytorch/quantization.py
Comment thread transformer_engine/pytorch/quantization.py Outdated
pggPL and others added 4 commits April 14, 2026 11:42
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented Apr 14, 2026

/te-ci pytorch

@pggPL pggPL merged commit 17aa2e4 into NVIDIA:main Apr 15, 2026
19 of 24 checks passed
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
…inside torch.compile (NVIDIA#2759)

* Improve torch.compile behavior around FP8 autocast.

Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Remove temporary global state experiment tests.

Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Clean up FP8 global state naming.

Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Minimize FP8 global state diff.

Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Remove unused FP8 state fields.

Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Simplify torch.compile autocast tests

Replace custom-op-based ToyLinear with a minimal version using F.linear.
Add test_autocast_sanity (parametrized over all recipes including NVFP4)
and test_autocast_nested_sanity with CustomRecipes. Both verify
fullgraph=True compilation without graph breaks.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add test for DelayedScaling rejection under torch.compile

Verify that te.autocast(recipe=DelayedScaling(), enabled=True) raises
a clear RuntimeError when used inside torch.compile.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor

* Use content-based autocast key with id() for group

Use str(recipe) for content-based recipe keying (avoids unbounded growth
when identical recipes are constructed inline) and id(group) for process
group identity (same semantics as the old hash(group) which was id-based).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rewrite torch.compile tests with opaque value-type quantizers

Replace custom_op-based approach with torch.library.define/impl/register_fake
using get_opaque_type_name() in the schema, which allows Inductor to properly
handle opaque value types. Add ToyQuantizer as an opaque value-type wrapper
around Float8CurrentScalingQuantizer with proper __eq__/__hash__/__fx_repr__.

test_autocast_nested_custom validates that nested te.autocast with 3 distinct
CustomRecipe instances passes the correct quantizers in both forward and backward.
test_autocast_sanity is a smoke test for all hardware-supported built-in recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* apply suggestions

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants