[PyTorch] [torch.compile] transformer_engine.pytorch.autocast suport inside torch.compile#2759
Conversation
Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
338ddae to
b5d46fd
Compare
Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Resolve conflicts in the FP8 torch.compile changes while preserving the upstream updates in graph.py and module/base.py. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR enables Several P1 concerns surfaced in prior review rounds remain open — most critically, Confidence Score: 4/5
Important Files Changed
Class Diagram%%{init: {'theme': 'neutral'}}%%
classDiagram
class FP8GlobalState {
+bool fp8_enabled
+bool fp8_calibration
+Recipe fp8_recipe
+dist_group_type fp8_distributed_group
+bool fp8_parameters
+bool high_precision_init_val
+bool is_first_fp8_module
+bool fp8_graph_capturing
+int autocast_depth
+Dict global_amax_buffer
+Dict global_amax_history_buffer
+Dict global_scale_buffer
+list fp8_tensors_recompute_buffer
+Dict autocast_arguments
+Tensor skip_fp8_weight_update_tensor
}
class FP8GlobalStateManager {
+FP8GlobalState quantization_state
+reset() void
+autocast_enter() void
+autocast_exit() void
+is_fp8_enabled() bool
+fp8_graph_capturing() bool
+is_first_fp8_module() bool
+get_fp8_recipe() Recipe
+get_autocast_state() tuple
+set_autocast_state() void
+reduce_and_update_fp8_tensors() void
}
class ModuleLevelCaches {
+_FP8_SUPPORT: Optional[Tuple]
+_MXFP8_SUPPORT: Optional[Tuple]
+_NVFP4_SUPPORT: Optional[Tuple]
+_FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple]
}
class SupportChecks {
+check_fp8_support() [assume_constant_result]
+check_mxfp8_support() [assume_constant_result]
+check_nvfp4_support() [assume_constant_result]
+check_fp8_block_scaling_support() [assume_constant_result]
}
FP8GlobalStateManager --> FP8GlobalState : quantization_state (singleton)
SupportChecks --> ModuleLevelCaches : reads/writes
FP8GlobalStateManager --> SupportChecks : calls
Reviews (8): Last reviewed commit: "Merge branch 'main' into torch_compile_a..." | Re-trigger Greptile |
Replace custom-op-based ToyLinear with a minimal version using F.linear. Add test_autocast_sanity (parametrized over all recipes including NVFP4) and test_autocast_nested_sanity with CustomRecipes. Both verify fullgraph=True compilation without graph breaks. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
for more information, see https://pre-commit.ci
Verify that te.autocast(recipe=DelayedScaling(), enabled=True) raises a clear RuntimeError when used inside torch.compile. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
Use str(recipe) for content-based recipe keying (avoids unbounded growth when identical recipes are constructed inline) and id(group) for process group identity (same semantics as the old hash(group) which was id-based). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
for more information, see https://pre-commit.ci
|
/te-ci pytorch L1 |
Replace custom_op-based approach with torch.library.define/impl/register_fake using get_opaque_type_name() in the schema, which allows Inductor to properly handle opaque value types. Add ToyQuantizer as an opaque value-type wrapper around Float8CurrentScalingQuantizer with proper __eq__/__hash__/__fx_repr__. test_autocast_nested_custom validates that nested te.autocast with 3 distinct CustomRecipe instances passes the correct quantizers in both forward and backward. test_autocast_sanity is a smoke test for all hardware-supported built-in recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
|
/te-ci pytorch |
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
…inside torch.compile (NVIDIA#2759) * Improve torch.compile behavior around FP8 autocast. Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * Remove temporary global state experiment tests. Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * Clean up FP8 global state naming. Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * Minimize FP8 global state diff. Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * Remove unused FP8 state fields. Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify torch.compile autocast tests Replace custom-op-based ToyLinear with a minimal version using F.linear. Add test_autocast_sanity (parametrized over all recipes including NVFP4) and test_autocast_nested_sanity with CustomRecipes. Both verify fullgraph=True compilation without graph breaks. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add test for DelayedScaling rejection under torch.compile Verify that te.autocast(recipe=DelayedScaling(), enabled=True) raises a clear RuntimeError when used inside torch.compile. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor * Use content-based autocast key with id() for group Use str(recipe) for content-based recipe keying (avoids unbounded growth when identical recipes are constructed inline) and id(group) for process group identity (same semantics as the old hash(group) which was id-based). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rewrite torch.compile tests with opaque value-type quantizers Replace custom_op-based approach with torch.library.define/impl/register_fake using get_opaque_type_name() in the schema, which allows Inductor to properly handle opaque value types. Add ToyQuantizer as an opaque value-type wrapper around Float8CurrentScalingQuantizer with proper __eq__/__hash__/__fx_repr__. test_autocast_nested_custom validates that nested te.autocast with 3 distinct CustomRecipe instances passes the correct quantizers in both forward and backward. test_autocast_sanity is a smoke test for all hardware-supported built-in recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * apply suggestions Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> --------- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
Enable
torch.compile(fullgraph=True)for FP8 autocast by moving compile-visible mutable state off class attributes, avoiding tracing through support checks, and adding test.Type of change
Changes
Please list the changes introduced in this PR:
clsattribute writes to a dataclass-backed singleton object, becausetorch.compiledoes not support writes directly to class attributes.lru_cache-based support checks with explicit module-level caches and mark the wrapper functions with@torch.compiler.assume_constant_resultsotorch.compiledoes not trace intocheck_*_support().torch.compilecoverage for FP8 autocast using a custom test module; the test is more involved because there is currently no simple TE layer that supports both FP8 andtorch.compile.DelayedScalingexplicitly unsupported undertorch.compileand raise a clear error for that case.Checklist:
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: