
[PyTorch] Reduce CPU overheads#2377

Merged
ksivaman merged 5 commits into NVIDIA:main from ksivaman:reduce_framework_cpu_overheads
Nov 17, 2025

@ksivaman
Member

@ksivaman ksivaman commented Nov 13, 2025

Description

Based on single-GPU profiling of the GroupedLinear module, implement several optimizations to reduce the CPU overhead attributable to PyTorch.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Optimization/performance

Changes

  • Consolidate creation and caching of the workspace in the GEMM logic. Fix the workspace device for cases where an incorrectly cached tensor is used.
  • Reduce the number of arguments to the PyTorch autograd function in order to reduce overheads due to functions such as unwrap_dead_wrappers.
  • Use nvtx context manager only when enabled via envvar.
  • Move the torch.cuda.device context manager to C++.
  • Minor refactors such that delayed scaling recipe checks are grouped together to reduce unnecessary checks for other recipes.
  • Reduce number of calls to torch.is_grad_enabled and misc other torch calls.
  • Add quantizer-specific copy implementations in order to avoid copy.copy().

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman ksivaman marked this pull request as draft November 13, 2025 16:13
@greptile-apps
Contributor

greptile-apps Bot commented Nov 13, 2025

Greptile Summary

  • Consolidates workspace creation/caching and moves device context management to C++ to reduce CPU overhead in PyTorch operations
  • Reduces PyTorch autograd function arguments by grouping non-tensor args into tuples, minimizing overhead from unwrap_dead_wrappers and arg validation
  • Optimizes recipe checks by caching torch.is_grad_enabled() results and grouping delayed scaling checks, plus implements custom quantizer copy() methods to avoid copy.copy() overhead
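
To make the last point concrete, here is a minimal sketch of a class-specific copy() next to the generic copy.copy(). ExampleQuantizer and its attributes are hypothetical and do not mirror the real quantizer classes. The generic path has to look up a copier and rebuild the object through the pickle-style __reduce_ex__ protocol, while the dedicated method calls the constructor directly.

```python
import copy


class ExampleQuantizer:
    """Hypothetical quantizer holding a few plain attributes."""

    def __init__(self, dtype, rowwise=True, columnwise=False):
        self.dtype = dtype
        self.rowwise = rowwise
        self.columnwise = columnwise

    def copy(self):
        """Shallow copy that bypasses copy.copy()'s generic machinery."""
        return ExampleQuantizer(self.dtype, self.rowwise, self.columnwise)


q = ExampleQuantizer("float8_e4m3")
fast = q.copy()          # direct constructor call
generic = copy.copy(q)   # generic path via __reduce_ex__
```

Both results are distinct objects with the same attribute values; only the amount of Python-level work differs.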

Confidence Score: 3/5

  • This PR has critical bugs that must be fixed before merging
  • The workspace caching and CPU overhead optimizations are solid improvements. However, get_tensor_device() has a critical bug: .device.index can be None for CUDA devices created without an explicit index (e.g., torch.device('cuda')), which will cause runtime errors when creating tensors
  • Pay close attention to transformer_engine/pytorch/cpp_extensions/gemm.py - the device index handling must be fixed
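
The failure mode flagged above follows from the fact that torch.device('cuda').index is None. One possible guard is sketched below. FakeDevice is a stand-in for torch.device so that the snippet runs without PyTorch, and resolve_device_index is a hypothetical helper, not the function in gemm.py. In real code the fallback value would presumably come from torch.cuda.current_device().

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FakeDevice:
    """Stand-in exposing the same .type/.index attributes as torch.device."""
    type: str
    index: Optional[int] = None


def resolve_device_index(device, current_index):
    """Return a concrete device index, falling back when .index is None."""
    # Compare against None explicitly: index 0 is falsy but perfectly valid.
    if device.index is not None:
        return device.index
    return current_index


implicit = resolve_device_index(FakeDevice("cuda"), current_index=2)
explicit = resolve_device_index(FakeDevice("cuda", 0), current_index=2)
```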

Important Files Changed

Filename | Overview
transformer_engine/pytorch/cpp_extensions/gemm.py | Added workspace caching and device detection, but .device.index can return None, causing issues with device specification

Sequence Diagram

sequenceDiagram
    participant User
    participant Linear/GroupedLinear
    participant AutogradFunction
    participant general_gemm
    participant CUDAGuard
    participant cuBLAS/cuDNN

    User->>Linear/GroupedLinear: forward(input)
    Linear/GroupedLinear->>Linear/GroupedLinear: Cache torch.is_grad_enabled()
    Linear/GroupedLinear->>Linear/GroupedLinear: Consolidate args into non_tensor_args tuple
    Linear/GroupedLinear->>AutogradFunction: forward(tensors, non_tensor_args)
    AutogradFunction->>general_gemm: gemm(A, B, quantization_params)
    general_gemm->>general_gemm: get_cublas_workspace(device, ub, grouped)
    general_gemm->>CUDAGuard: Set correct device context
    CUDAGuard->>cuBLAS/cuDNN: Execute GEMM with correct device
    cuBLAS/cuDNN-->>general_gemm: output
    general_gemm-->>AutogradFunction: result
    AutogradFunction-->>Linear/GroupedLinear: output
    Linear/GroupedLinear-->>User: result

Contributor

@greptile-apps greptile-apps Bot left a comment


27 files reviewed, 5 comments


Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
@yaox12 yaox12 requested review from yaox12 and zhongbozhu November 14, 2025 03:40
@ksivaman ksivaman marked this pull request as ready for review November 14, 2025 19:46
@ksivaman ksivaman requested a review from vthumbe1503 November 14, 2025 19:46
Contributor

@greptile-apps greptile-apps Bot left a comment


27 files reviewed, no comments


@vthumbe1503
Collaborator

Reviewed with Kirthi offline. LGTM.

if ub:
    return torch.empty(
        get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device
    ).repeat(_NUM_MAX_UB_STREAMS)
Collaborator

@vthumbe1503 vthumbe1503 Nov 15, 2025


A minor further optimization, which can also be done in a later PR: calling empty and then repeat costs two torch operations. You can instead allocate the full size directly:

torch.empty(
    get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS, dtype=torch.uint8, device=device
)

@ksivaman
Member Author

/te-ci

@ksivaman ksivaman merged commit e1edaae into NVIDIA:main Nov 17, 2025
9 of 12 checks passed
Contributor

@greptile-apps greptile-apps Bot left a comment


27 files reviewed, no comments


@pggPL pggPL mentioned this pull request Nov 18, 2025
13 tasks
KshitijLakhani pushed a commit that referenced this pull request Nov 20, 2025
Initial changes to remove pytorch overheads

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
timmoon10 referenced this pull request in NVIDIA-NeMo/RL May 4, 2026
Signed-off-by: Anna Shors <ashors@nvidia.com>