[PyTorch] CPU Overhead Micro-optimizations#2146
Conversation
81adf6d to
dc2532a
Compare
| @@ -641,11 +641,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { | |||
| } | |||
|
|
|||
| int nvte_is_non_tn_fp8_gemm_supported() { | |||
There was a problem hiding this comment.
This doesn't handle the case where we have multiple GPUs with different archs. We could add an arg for the device ID, but that just pushes the CPU overhead problem somewhere else.
There was a problem hiding this comment.
Yeah, but we didn't really support this case anyway?
There was a problem hiding this comment.
For topology like 1 CPU 8/4GPUs with homogenous GPU arch, we can cache the TN layout check.
| with torch.cuda.device( | ||
| getattr(self, list(self.named_parameters())[0][0]).device | ||
| ), self.prepare_forward( | ||
| if is_first_microbatch is None or is_first_microbatch: |
There was a problem hiding this comment.
How do we assume we can skip setting the device if is_first_microbatch=False?
There was a problem hiding this comment.
I assume that the device won't change across microbatches in a global batch.
There was a problem hiding this comment.
Since in a CPU bounded fwd only case, skipping set device for every single forward pass could account for 10% perf difference.
There was a problem hiding this comment.
This approach is really ad hoc. Personally, I think it would be better to not to support the multi-device case (basically revert #1974) than to have inconsistent multi-device support.
There was a problem hiding this comment.
I agree with it, but not sure if there are any potential impact for customers using this feature?
|
/te-ci pytorch L1 |
1 similar comment
|
/te-ci pytorch L1 |
|
@zhongbozhu - this is out of date with base branch . Woudl you please retry with latest ? |
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
f1f6891 to
89aeccf
Compare
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
|
/te_ci L1 |
There was a problem hiding this comment.
Greptile Overview
Greptile Summary
This PR implements CPU overhead micro-optimizations for PyTorch TransformerEngine modules to address performance bottlenecks on Grace-based systems (GB200/GH200). The changes introduce conditional device context management: instead of always wrapping forward passes in torch.cuda.device() context managers, the code now checks should_set_cuda_device_every_batch() and uses contextlib.nullcontext() when device setting is unnecessary. A new environment variable NVTE_SET_CUDA_DEVICE controls this behavior globally (defaulting to True for backward compatibility), with runtime overrides available via set_cuda_device_every_batch(). Additionally, the C++ layer caches the result of nvte_is_non_tn_fp8_gemm_supported() using std::call_once to eliminate redundant device capability queries. These optimizations target the ~2x higher CPU overhead observed on Grace systems compared to Intel-based H100 systems (issue #2053), where execution can become CPU-bound despite available GPU capacity. The changes apply uniformly across Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules, maintaining functional equivalence while reducing per-batch CPU costs. A new benchmark script (benchmark_linear_cpu_overhead.py) enables reproducible measurement of the improvements.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/utils.py | 4/5 | Adds global configuration mechanism for device context management with environment variable NVTE_SET_CUDA_DEVICE and runtime control functions |
| transformer_engine/pytorch/module/linear.py | 4/5 | Conditionally skips torch.cuda.device() context manager in forward pass using new utility function |
| transformer_engine/pytorch/module/layernorm_linear.py | 4/5 | Applies same conditional device context optimization pattern to LayerNormLinear forward pass |
| transformer_engine/pytorch/module/layernorm_mlp.py | 5/5 | Implements device context optimization for LayerNormMLP module with clean integration |
| transformer_engine/pytorch/module/grouped_linear.py | 5/5 | Adds conditional device context handling to GroupedLinear forward pass |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Caches device compute capability check using std::call_once to eliminate redundant queries |
| benchmarks/linear/benchmark_linear_cpu_overhead.py | 4/5 | New benchmark script for measuring CPU overhead with proper warmup/timing separation and is_first_microbatch handling |
Confidence score: 3/5
- This PR addresses a documented performance issue but introduces behavioral changes that could affect multi-GPU and heterogeneous-GPU scenarios
- Score reflects concerns about the device-setting optimization's safety in multi-GPU environments and the lack of clarity around when device context can be safely skipped; previous reviewers correctly identified that the assumption about skipping device setting may not hold across all multi-GPU configurations
- Pay close attention to transformer_engine/pytorch/utils.py for the initialization logic and all module forward passes for the conditional device context pattern, particularly testing on multi-GPU systems with different architectures
Sequence Diagram
sequenceDiagram
participant User
participant Linear
participant _Linear
participant Quantizer
participant GEMM
participant TPComm as TP Communication
User->>Linear: forward(inp, is_first_microbatch)
activate Linear
Linear->>Linear: should_set_cuda_device_every_batch()
Linear->>Linear: prepare_forward(inp)
Linear->>Linear: _get_weight_and_bias_tensors()
Linear->>Linear: _get_quantizers(fp8_output, fp8_grad)
alt FP8 enabled
Linear->>Quantizer: set_usage(rowwise, columnwise)
end
Linear->>_Linear: apply(weight, inp, bias, ...)
activate _Linear
alt Sequence Parallel with Column Mode
_Linear->>Quantizer: set_usage() for input
_Linear->>TPComm: gather_along_first_dim(inputmat, tp_group)
TPComm-->>_Linear: inputmat_total
else No All-Gather
_Linear->>Quantizer: quantize input
Quantizer-->>_Linear: inputmat (quantized)
end
alt FP8 Weight
_Linear->>Linear: get_weight_workspace()
Linear->>Quantizer: quantize weight
Quantizer-->>Linear: weightmat (FP8)
Linear-->>_Linear: weightmat
end
_Linear->>Quantizer: set_usage() for output
_Linear->>GEMM: general_gemm(weightmat, inputmat_total)
GEMM-->>_Linear: gemm_out
alt Row Parallel Mode
_Linear->>TPComm: reduce_scatter or allreduce
TPComm-->>_Linear: out
else Reduce-Scatter with Userbuffers
note over _Linear,TPComm: Communication overlapped with GEMM
_Linear-->>_Linear: out from reduce_scatter_out
end
alt is_grad_enabled
_Linear->>_Linear: save_for_backward(inputmat, weightmat, weight, bias)
_Linear->>_Linear: ctx.quantizers = quantizers
_Linear->>_Linear: ctx.tp_group = tp_group
end
_Linear-->>Linear: out
deactivate _Linear
alt gemm_bias_unfused_add
Linear->>Linear: out = out + bias
end
Linear-->>User: out
deactivate Linear
note over User: Backward pass begins
User->>_Linear: backward(grad_output)
activate _Linear
_Linear->>_Linear: restore_from_saved(saved_tensors)
alt Column Parallel with dgrad
_Linear->>TPComm: gather_along_first_dim(inputmat)
TPComm-->>_Linear: inputmat_total
end
_Linear->>Quantizer: set_usage() for grad_output
_Linear->>Linear: grad_output_preprocess()
Linear->>Quantizer: quantize grad_output
Quantizer-->>Linear: grad_output (quantized)
Linear-->>_Linear: grad_output, grad_bias
alt requires_dgrad
_Linear->>Quantizer: set_usage() for grad_input
_Linear->>GEMM: general_gemm(weight, grad_output)
GEMM-->>_Linear: gemm_out
alt Column Parallel Mode
_Linear->>TPComm: reduce_scatter or allreduce
TPComm-->>_Linear: dgrad
end
end
alt requires_wgrad
_Linear->>Quantizer: set_usage() for input/grad_output
alt delay_wgrad_compute
_Linear->>_Linear: wgrad_store.put([inputmat_total, grad_output])
else Compute Now
_Linear->>GEMM: general_gemm(inputmat_total, grad_output)
GEMM-->>_Linear: wgrad, grad_bias
end
end
alt reduce_and_update_bwd_fp8_tensors
_Linear->>FP8GlobalStateManager: reduce_and_update_fp8_tensors(forward=False)
end
_Linear-->>User: dgrad, wgrad, grad_bias
deactivate _Linear
7 files reviewed, 2 comments
| args = parser.parse_args() | ||
|
|
||
| x = torch.randn( | ||
| (args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True |
There was a problem hiding this comment.
style: requires_grad=True is set but the benchmark runs under torch.no_grad() context. Setting requires_grad=False would be more consistent.
| if should_set_cuda_device_every_batch(): | ||
| device_ctx = torch.cuda.device( | ||
| getattr(self, list(self.named_parameters())[0][0]).device | ||
| ) | ||
| else: | ||
| device_ctx = contextlib.nullcontext() |
There was a problem hiding this comment.
style: The device is retrieved by accessing the first parameter via list(self.named_parameters())[0][0]. This creates a list from the iterator on every forward pass. Consider caching the device during initialization (e.g., in __init__ or reset_parameters) to avoid repeated computation. Is there a reason the device cannot be cached during initialization?
There was a problem hiding this comment.
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The PR introduces micro-optimizations to reduce CPU overhead in Transformer Engine's PyTorch modules, specifically targeting Grace-based ARM systems (GB200/GH200) where CPU bottlenecks limit GPU utilization. The primary optimization involves conditionally disabling CUDA device context management (torch.cuda.device()) on every forward pass through a new configurable flag NVTE_SET_CUDA_DEVICE. Additionally, the C++ function nvte_is_non_tn_fp8_gemm_supported() was optimized by caching the compute capability query result using std::call_once, eliminating redundant CUDA API calls. A new benchmark script was added to measure and validate these CPU overhead improvements. These changes integrate with the existing TE module architecture by wrapping forward passes in conditional contexts (contextlib.nullcontext() when device setting is unnecessary), maintaining backward compatibility while providing performance gains in single-GPU and latency-sensitive scenarios.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
transformer_engine/common/transformer_engine.cpp |
4/5 | Added thread-safe caching to nvte_is_non_tn_fp8_gemm_supported() to avoid repeated compute capability queries, but doesn't handle multi-device heterogeneous scenarios. |
transformer_engine/pytorch/utils.py |
5/5 | Introduced configurable device context management via NVTE_SET_CUDA_DEVICE environment variable and runtime API functions; clean implementation with proper defaults. |
transformer_engine/pytorch/module/linear.py |
4/5 | Conditionally skips device context creation based on should_set_cuda_device_every_batch(), but still creates a list from parameters iterator when device context is needed. |
transformer_engine/pytorch/module/layernorm_linear.py |
5/5 | Applies the same conditional device context optimization as other linear modules; well-integrated with existing code structure. |
transformer_engine/pytorch/module/layernorm_mlp.py |
4/5 | Conditionally disables device context switching; follows the same pattern as other modules with the same parameter retrieval inefficiency. |
transformer_engine/pytorch/module/grouped_linear.py |
4/5 | Implements conditional device context management consistently with other linear modules; parameter device retrieval still has overhead. |
benchmarks/linear/benchmark_linear_cpu_overhead.py |
3/5 | New benchmark script for measuring CPU overhead with inconsistent requires_grad=True setting under torch.no_grad() context. |
Confidence score: 4/5
- This PR is generally safe to merge with attention to specific concerns around multi-device scenarios and parameter device retrieval.
- Score reflects that the core optimizations are sound and well-implemented, but there are two notable issues: (1) the C++ caching doesn't handle multi-GPU systems with heterogeneous architectures (as noted in previous reviews), and (2) the parameter device retrieval via
list(self.named_parameters())[0][0]still creates overhead when device context is enabled (previously noted style issue remains unaddressed). - Pay close attention to
transformer_engine/common/transformer_engine.cppfor potential correctness issues in heterogeneous multi-GPU environments, and consider whether the parameter device retrieval pattern across all linear modules should be optimized before merging.
7 files reviewed, 2 comments
| if should_set_cuda_device_every_batch(): | ||
| device_ctx = torch.cuda.device( | ||
| getattr(self, list(self.named_parameters())[0][0]).device | ||
| ) | ||
| else: | ||
| device_ctx = contextlib.nullcontext() |
There was a problem hiding this comment.
style: Retrieving device via list(self.named_parameters())[0][0] creates a list every forward pass. Consider caching during initialization to avoid repeated computation.
| use_te=use_te, | ||
| ) | ||
|
|
||
| total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters |
There was a problem hiding this comment.
logic: FLOP calculation uses avg_gpu_time_per_round which is the sum over num_rounds, but total_ops is scaled by timing_iters only. The division should use avg_gpu_time_per_round / num_rounds or total_ops * num_rounds.
Description
Motivation: #2053
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: