Skip to content

[PyTorch] CPU Overhead Micro-optimizations#2146

Closed
zhongbozhu wants to merge 7 commits into
NVIDIA:mainfrom
zhongbozhu:zhongbo/cpu_overhead
Closed

[PyTorch] CPU Overhead Micro-optimizations#2146
zhongbozhu wants to merge 7 commits into
NVIDIA:mainfrom
zhongbozhu:zhongbo/cpu_overhead

Conversation

@zhongbozhu
Copy link
Copy Markdown
Collaborator

Description

Motivation: #2053

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zhongbozhu zhongbozhu requested a review from timmoon10 September 2, 2025 19:42
@zhongbozhu zhongbozhu self-assigned this Sep 2, 2025
@@ -641,11 +641,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't handle the case where we have multiple GPUs with different archs. We could add an arg for the device ID, but that just pushes the CPU overhead problem somewhere else.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but we didn't really support this case anyway?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For topology like 1 CPU 8/4GPUs with homogenous GPU arch, we can cache the TN layout check.

with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
if is_first_microbatch is None or is_first_microbatch:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we assume we can skip setting the device if is_first_microbatch=False?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that the device won't change across microbatches in a global batch.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since in a CPU bounded fwd only case, skipping set device for every single forward pass could account for 10% perf difference.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach is really ad hoc. Personally, I think it would be better to not to support the multi-device case (basically revert #1974) than to have inconsistent multi-device support.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with it, but not sure if there are any potential impact for customers using this feature?

@zhongbozhu
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

1 similar comment
@zhongbozhu
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@nvMelissa
Copy link
Copy Markdown
Collaborator

@zhongbozhu - this is out of date with base branch . Woudl you please retry with latest ?

Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu force-pushed the zhongbo/cpu_overhead branch from f1f6891 to 89aeccf Compare October 21, 2025 21:35
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu
Copy link
Copy Markdown
Collaborator Author

/te_ci L1

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements CPU overhead micro-optimizations for PyTorch TransformerEngine modules to address performance bottlenecks on Grace-based systems (GB200/GH200). The changes introduce conditional device context management: instead of always wrapping forward passes in torch.cuda.device() context managers, the code now checks should_set_cuda_device_every_batch() and uses contextlib.nullcontext() when device setting is unnecessary. A new environment variable NVTE_SET_CUDA_DEVICE controls this behavior globally (defaulting to True for backward compatibility), with runtime overrides available via set_cuda_device_every_batch(). Additionally, the C++ layer caches the result of nvte_is_non_tn_fp8_gemm_supported() using std::call_once to eliminate redundant device capability queries. These optimizations target the ~2x higher CPU overhead observed on Grace systems compared to Intel-based H100 systems (issue #2053), where execution can become CPU-bound despite available GPU capacity. The changes apply uniformly across Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules, maintaining functional equivalence while reducing per-batch CPU costs. A new benchmark script (benchmark_linear_cpu_overhead.py) enables reproducible measurement of the improvements.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/utils.py 4/5 Adds global configuration mechanism for device context management with environment variable NVTE_SET_CUDA_DEVICE and runtime control functions
transformer_engine/pytorch/module/linear.py 4/5 Conditionally skips torch.cuda.device() context manager in forward pass using new utility function
transformer_engine/pytorch/module/layernorm_linear.py 4/5 Applies same conditional device context optimization pattern to LayerNormLinear forward pass
transformer_engine/pytorch/module/layernorm_mlp.py 5/5 Implements device context optimization for LayerNormMLP module with clean integration
transformer_engine/pytorch/module/grouped_linear.py 5/5 Adds conditional device context handling to GroupedLinear forward pass
transformer_engine/common/transformer_engine.cpp 5/5 Caches device compute capability check using std::call_once to eliminate redundant queries
benchmarks/linear/benchmark_linear_cpu_overhead.py 4/5 New benchmark script for measuring CPU overhead with proper warmup/timing separation and is_first_microbatch handling

Confidence score: 3/5

  • This PR addresses a documented performance issue but introduces behavioral changes that could affect multi-GPU and heterogeneous-GPU scenarios
  • Score reflects concerns about the device-setting optimization's safety in multi-GPU environments and the lack of clarity around when device context can be safely skipped; previous reviewers correctly identified that the assumption about skipping device setting may not hold across all multi-GPU configurations
  • Pay close attention to transformer_engine/pytorch/utils.py for the initialization logic and all module forward passes for the conditional device context pattern, particularly testing on multi-GPU systems with different architectures

Sequence Diagram

sequenceDiagram
    participant User
    participant Linear
    participant _Linear
    participant Quantizer
    participant GEMM
    participant TPComm as TP Communication
    
    User->>Linear: forward(inp, is_first_microbatch)
    activate Linear
    
    Linear->>Linear: should_set_cuda_device_every_batch()
    Linear->>Linear: prepare_forward(inp)
    Linear->>Linear: _get_weight_and_bias_tensors()
    Linear->>Linear: _get_quantizers(fp8_output, fp8_grad)
    
    alt FP8 enabled
        Linear->>Quantizer: set_usage(rowwise, columnwise)
    end
    
    Linear->>_Linear: apply(weight, inp, bias, ...)
    activate _Linear
    
    alt Sequence Parallel with Column Mode
        _Linear->>Quantizer: set_usage() for input
        _Linear->>TPComm: gather_along_first_dim(inputmat, tp_group)
        TPComm-->>_Linear: inputmat_total
    else No All-Gather
        _Linear->>Quantizer: quantize input
        Quantizer-->>_Linear: inputmat (quantized)
    end
    
    alt FP8 Weight
        _Linear->>Linear: get_weight_workspace()
        Linear->>Quantizer: quantize weight
        Quantizer-->>Linear: weightmat (FP8)
        Linear-->>_Linear: weightmat
    end
    
    _Linear->>Quantizer: set_usage() for output
    _Linear->>GEMM: general_gemm(weightmat, inputmat_total)
    GEMM-->>_Linear: gemm_out
    
    alt Row Parallel Mode
        _Linear->>TPComm: reduce_scatter or allreduce
        TPComm-->>_Linear: out
    else Reduce-Scatter with Userbuffers
        note over _Linear,TPComm: Communication overlapped with GEMM
        _Linear-->>_Linear: out from reduce_scatter_out
    end
    
    alt is_grad_enabled
        _Linear->>_Linear: save_for_backward(inputmat, weightmat, weight, bias)
        _Linear->>_Linear: ctx.quantizers = quantizers
        _Linear->>_Linear: ctx.tp_group = tp_group
    end
    
    _Linear-->>Linear: out
    deactivate _Linear
    
    alt gemm_bias_unfused_add
        Linear->>Linear: out = out + bias
    end
    
    Linear-->>User: out
    deactivate Linear
    
    note over User: Backward pass begins
    
    User->>_Linear: backward(grad_output)
    activate _Linear
    
    _Linear->>_Linear: restore_from_saved(saved_tensors)
    
    alt Column Parallel with dgrad
        _Linear->>TPComm: gather_along_first_dim(inputmat)
        TPComm-->>_Linear: inputmat_total
    end
    
    _Linear->>Quantizer: set_usage() for grad_output
    _Linear->>Linear: grad_output_preprocess()
    Linear->>Quantizer: quantize grad_output
    Quantizer-->>Linear: grad_output (quantized)
    Linear-->>_Linear: grad_output, grad_bias
    
    alt requires_dgrad
        _Linear->>Quantizer: set_usage() for grad_input
        _Linear->>GEMM: general_gemm(weight, grad_output)
        GEMM-->>_Linear: gemm_out
        
        alt Column Parallel Mode
            _Linear->>TPComm: reduce_scatter or allreduce
            TPComm-->>_Linear: dgrad
        end
    end
    
    alt requires_wgrad
        _Linear->>Quantizer: set_usage() for input/grad_output
        
        alt delay_wgrad_compute
            _Linear->>_Linear: wgrad_store.put([inputmat_total, grad_output])
        else Compute Now
            _Linear->>GEMM: general_gemm(inputmat_total, grad_output)
            GEMM-->>_Linear: wgrad, grad_bias
        end
    end
    
    alt reduce_and_update_bwd_fp8_tensors
        _Linear->>FP8GlobalStateManager: reduce_and_update_fp8_tensors(forward=False)
    end
    
    _Linear-->>User: dgrad, wgrad, grad_bias
    deactivate _Linear
Loading

7 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

args = parser.parse_args()

x = torch.randn(
(args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: requires_grad=True is set but the benchmark runs under torch.no_grad() context. Setting requires_grad=False would be more consistent.

Comment on lines +1537 to +1542
if should_set_cuda_device_every_batch():
device_ctx = torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
)
else:
device_ctx = contextlib.nullcontext()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The device is retrieved by accessing the first parameter via list(self.named_parameters())[0][0]. This creates a list from the iterator on every forward pass. Consider caching the device during initialization (e.g., in __init__ or reset_parameters) to avoid repeated computation. Is there a reason the device cannot be cached during initialization?

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The PR introduces micro-optimizations to reduce CPU overhead in Transformer Engine's PyTorch modules, specifically targeting Grace-based ARM systems (GB200/GH200) where CPU bottlenecks limit GPU utilization. The primary optimization involves conditionally disabling CUDA device context management (torch.cuda.device()) on every forward pass through a new configurable flag NVTE_SET_CUDA_DEVICE. Additionally, the C++ function nvte_is_non_tn_fp8_gemm_supported() was optimized by caching the compute capability query result using std::call_once, eliminating redundant CUDA API calls. A new benchmark script was added to measure and validate these CPU overhead improvements. These changes integrate with the existing TE module architecture by wrapping forward passes in conditional contexts (contextlib.nullcontext() when device setting is unnecessary), maintaining backward compatibility while providing performance gains in single-GPU and latency-sensitive scenarios.

Important Files Changed

Filename Score Overview
transformer_engine/common/transformer_engine.cpp 4/5 Added thread-safe caching to nvte_is_non_tn_fp8_gemm_supported() to avoid repeated compute capability queries, but doesn't handle multi-device heterogeneous scenarios.
transformer_engine/pytorch/utils.py 5/5 Introduced configurable device context management via NVTE_SET_CUDA_DEVICE environment variable and runtime API functions; clean implementation with proper defaults.
transformer_engine/pytorch/module/linear.py 4/5 Conditionally skips device context creation based on should_set_cuda_device_every_batch(), but still creates a list from parameters iterator when device context is needed.
transformer_engine/pytorch/module/layernorm_linear.py 5/5 Applies the same conditional device context optimization as other linear modules; well-integrated with existing code structure.
transformer_engine/pytorch/module/layernorm_mlp.py 4/5 Conditionally disables device context switching; follows the same pattern as other modules with the same parameter retrieval inefficiency.
transformer_engine/pytorch/module/grouped_linear.py 4/5 Implements conditional device context management consistently with other linear modules; parameter device retrieval still has overhead.
benchmarks/linear/benchmark_linear_cpu_overhead.py 3/5 New benchmark script for measuring CPU overhead with inconsistent requires_grad=True setting under torch.no_grad() context.

Confidence score: 4/5

  • This PR is generally safe to merge with attention to specific concerns around multi-device scenarios and parameter device retrieval.
  • Score reflects that the core optimizations are sound and well-implemented, but there are two notable issues: (1) the C++ caching doesn't handle multi-GPU systems with heterogeneous architectures (as noted in previous reviews), and (2) the parameter device retrieval via list(self.named_parameters())[0][0] still creates overhead when device context is enabled (previously noted style issue remains unaddressed).
  • Pay close attention to transformer_engine/common/transformer_engine.cpp for potential correctness issues in heterogeneous multi-GPU environments, and consider whether the parameter device retrieval pattern across all linear modules should be optimized before merging.

7 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +771 to +776
if should_set_cuda_device_every_batch():
device_ctx = torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
)
else:
device_ctx = contextlib.nullcontext()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Retrieving device via list(self.named_parameters())[0][0] creates a list every forward pass. Consider caching during initialization to avoid repeated computation.

use_te=use_te,
)

total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: FLOP calculation uses avg_gpu_time_per_round which is the sum over num_rounds, but total_ops is scaled by timing_iters only. The division should use avg_gpu_time_per_round / num_rounds or total_ops * num_rounds.

@ksivaman
Copy link
Copy Markdown
Member

Closing in favor of #2377 and #2400

@ksivaman ksivaman closed this Nov 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants