From f3322878a056a261bb0a5b148f96b7b9b7e2b0db Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Wed, 20 May 2026 16:19:58 +0800 Subject: [PATCH 1/3] add the getter and setter of skip_fp8_weight_update_tensor Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/graph.py | 9 ++------- transformer_engine/pytorch/quantization.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 075db1394b..568cf4308d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -324,12 +324,7 @@ def _make_graphed_callables( if cache_quantized_params: # Initialize flag that controls FP8 weight updates - qstate = FP8GlobalStateManager.quantization_state - if qstate.skip_fp8_weight_update_tensor is None: - qstate.skip_fp8_weight_update_tensor = torch.empty( - 1, dtype=torch.float32, device="cuda" - ) - qstate.skip_fp8_weight_update_tensor.fill_(False) + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) # Check callables for c in callables: @@ -841,7 +836,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: - FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_( + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor( skip_fp8_weight_update ) ctx.cuda_graph_stream = cuda_graph_stream diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..7900fcc014 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -409,6 +409,20 @@ class FP8GlobalStateManager: quantization_state = FP8GlobalState() + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """Set the skip fp8 weight update tensor""" + if cls.quantization_state.skip_fp8_weight_update_tensor is None: + cls.quantization_state.skip_fp8_weight_update_tensor = torch.empty( + 1, dtype=torch.float32, device="cuda" + ) + cls.quantization_state.skip_fp8_weight_update_tensor.fill_(skip) + + @classmethod + def get_skip_fp8_weight_update_tensor(cls) -> Union[torch.Tensor, None]: + """Get the skip fp8 weight update tensor""" + return cls.quantization_state.skip_fp8_weight_update_tensor + @classmethod def reset(cls) -> None: """Reset the global state""" From a5e984bbf04eddc17e09584d313370768d703142 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 May 2026 09:49:47 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 568cf4308d..86b8a4acf4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -836,9 +836,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: - FP8GlobalStateManager.set_skip_fp8_weight_update_tensor( - skip_fp8_weight_update - ) + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) ctx.cuda_graph_stream = cuda_graph_stream ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors From 25e01dd618a6a71f38bf4195af9f69efe8ae3074 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Wed, 20 May 2026 18:59:57 +0800 Subject: [PATCH 3/3] Update transformer_engine/pytorch/quantization.py return type fix Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> --- transformer_engine/pytorch/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 7900fcc014..41c72af661 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -419,7 +419,7 @@ def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: cls.quantization_state.skip_fp8_weight_update_tensor.fill_(skip) @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> Union[torch.Tensor, None]: + def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]: """Get the skip fp8 weight update tensor""" return cls.quantization_state.skip_fp8_weight_update_tensor