
[PyTorch] Expose function to bulk-allocate tensors backed by the same buffer #2900

Merged

timmoon10 merged 14 commits into NVIDIA:main from timmoon10:tmoon/group-mlp-bulk-allocate on May 12, 2026

Conversation

@timmoon10
Collaborator

Description

Allocating PyTorch tensors carries non-trivial CPU overhead, which becomes especially painful when allocating per-expert tensors in the grouped linear layer. This PR generalizes the bulk-allocation approach used in the split-quantize functions (see #1793): allocate a single large buffer and create tensor subviews with at::from_blob. We add a dedicated bulk-allocation function that is exposed to Python, refactor the split-quantize functions to use it, and bulk-allocate wgrads in the grouped linear implementations.

This is incremental progress toward #2897. When I run the grouped MLP benchmark discussed in that issue, I see a runtime reduction of 150 us (for reference, the backward pass takes ~2.1 ms).
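
To make the approach concrete, here is a minimal pure-PyTorch sketch of the same idea: one flat byte buffer, with each requested tensor created as an aligned view into it. This is illustrative only; the actual implementation lives in C++ (allocate.cpp) and builds the views with at::from_blob, and the function name and defaults below are hypothetical.

    import math
    import torch

    def bulk_allocate_sketch(shapes, dtypes, device="cuda", alignment=256):
        """Allocate one flat byte buffer and return per-tensor views into it."""
        offsets, total = [], 0
        for shape, dtype in zip(shapes, dtypes):
            # Round the running offset up to the requested alignment.
            total = (total + alignment - 1) // alignment * alignment
            offsets.append(total)
            total += math.prod(shape) * torch.empty((), dtype=dtype).element_size()

        # One allocation instead of one torch.empty call per tensor.
        buf = torch.empty(total, dtype=torch.uint8, device=device)

        views = []
        for shape, dtype, off in zip(shapes, dtypes, offsets):
            nbytes = math.prod(shape) * torch.empty((), dtype=dtype).element_size()
            # Reinterpret an aligned byte slice as the requested dtype and shape.
            # All views share buf's storage, so the buffer stays alive as long as
            # any view exists (the C++ version achieves this with a shared_ptr
            # captured in each from_blob deleter).
            views.append(buf[off:off + nbytes].view(dtype).view(shape))
        return views

The real extension additionally handles zero-size tensors and unaligned base pointers, as described in the review summary below.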

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add dedicated function for bulk-allocating PyTorch tensors
  • Bulk-allocate wgrad tensors in grouped linear implementations

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 4 commits April 18, 2026 01:19
[PyTorch] Add bulk_allocate utility and use it in quantized tensor allocators

Introduces transformer_engine/pytorch/csrc/extensions/allocate.cpp with a
general-purpose bulk_allocate function: given parallel lists of shapes,
dtypes, and per-tensor byte alignments, it computes a packed layout, does
a single CUDA allocation, and returns at::from_blob views whose deleters
keep the backing buffer alive.

The three internal bulk_allocate_*_tensors helpers in cast.cpp are
refactored to call bulk_allocate instead of each owning a copy of the
make_torch_view lambda and the offset-computation loops (~120 lines
removed). The new function is also exposed via pybind11 so Python can
allocate packed CUDA buffers directly without going through a quantizer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

/te-ci pytorch

@greptile-apps
Contributor

greptile-apps Bot commented Apr 18, 2026

Greptile Summary

This PR introduces a dedicated bulk_allocate C++ function that allocates a single contiguous CUDA buffer and returns per-tensor views via at::from_blob, eliminating the CPU overhead of individual torch.empty calls in the grouped linear backward passes. The existing split-quantize helpers in cast.cpp are refactored to delegate to this new primitive, and all three grouped linear backward implementations are updated to use it for weight gradient allocation.

  • New allocate.cpp: Computes per-tensor byte offsets with configurable alignment, pads the base buffer for pointer alignment, and uses a shared_ptr<at::Tensor> deleter to keep the backing buffer alive as long as any view exists; zero-size tensors fall back to a standalone at::empty to avoid from_blob edge-case bugs.
  • cast.cpp refactor: Removes three copies of the bespoke buffer-management lambda and delegates to bulk_allocate; the contiguous_data_and_scale pre-check logic is correctly rewritten as a per-tensor size divisibility test that is semantically equivalent to the original cumulative offset check.
  • Python call-sites: Consistently pass device as the third positional argument and [256] * num_groups as alignments, matching the pybind11 signature.

Confidence Score: 5/5

The change is safe to merge: it is a pure performance optimization with well-contained scope — a new allocation helper replaces equivalent per-tensor allocations, and the refactored cast.cpp paths preserve the same memory layout semantics.

The core bulk_allocate logic is straightforward: offset arithmetic, a single CUDA allocation, and reference-counted views. The alignment padding is conservative and correct. All three Python call-sites pass arguments in the right positional order. The cast.cpp refactoring replaces identical local lambdas with equivalent calls to the new primitive, and the contiguous_data_and_scale rewrite is semantically equivalent to the original cumulative-offset check.
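
As a rough sketch of that conservative padding (hypothetical names, not the code in allocate.cpp): over-allocate the base buffer by one base_alignment and skip forward to the first aligned byte before carving out the views.

    import torch

    def aligned_base(total_bytes, base_alignment=256, device="cuda"):
        # Over-allocate so an aligned start is guaranteed to fit in the buffer,
        # then compute the byte offset that realigns the data pointer.
        buf = torch.empty(total_bytes + base_alignment, dtype=torch.uint8, device=device)
        start = (-buf.data_ptr()) % base_alignment
        return buf, start  # per-tensor views are carved out of buf[start:]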

No files require special attention; all changes are self-contained and the critical path is covered by the three Python call-sites.

Important Files Changed

  • transformer_engine/pytorch/csrc/extensions/allocate.cpp: New file implementing bulk_allocate: allocates a single contiguous CUDA buffer and creates at::from_blob views for each tensor, with per-tensor alignment support and a correct empty-tensor workaround.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp: Refactors the bulk_allocate_fp8/fp4 helpers to delegate to the new bulk_allocate function; eliminates duplicated buffer-management code while preserving the contiguous_data_and_scale semantics.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Exposes bulk_allocate to Python with py::call_guard<py::gil_scoped_release>(), consistent with other allocation-adjacent bindings in the file.
  • transformer_engine/pytorch/module/grouped_linear.py: Replaces the per-expert torch.empty loop with tex.bulk_allocate for wgrad tensors, using 256-byte alignment and the context device.
  • transformer_engine/pytorch/ops/basic/grouped_linear.py: Replaces the per-expert torch.empty loop with tex.bulk_allocate for wgrad tensors; aligned with the module/grouped_linear.py change.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: Replaces the per-expert torch.empty loop with tex.bulk_allocate in _compute_grad_params; correctly passes device as the third argument and [256]*num_groups as alignments.

Sequence Diagram

sequenceDiagram
    participant Py as Python (grouped_linear backward)
    participant PB as pybind11
    participant BA as bulk_allocate (C++)
    participant Alloc as CUDA Allocator

    Py->>PB: tex.bulk_allocate(shapes, dtypes, device, alignments)
    Note over PB: GIL released
    PB->>BA: bulk_allocate(shapes, dtypes, device, alignments)
    BA->>BA: "compute per-tensor offsets & base_alignment"
    BA->>BA: "base_byte_size += base_alignment (padding)"
    BA->>Alloc: "at::empty({base_byte_size}, kUInt8, device)"
    Alloc-->>BA: "base_buffer (shared_ptr<at::Tensor>)"
    BA->>BA: align base_ptr to base_alignment
    loop for each tensor i
        alt "byte_sizes[i] == 0"
            BA->>Alloc: at::empty(shape_i, dtype_i)
            Alloc-->>BA: standalone empty tensor
        else
            BA->>BA: at::from_blob(base_ptr+offset[i], shape_i, deleter, dtype_i)
        end
    end
    BA-->>PB: "vector<at::Tensor>"
    PB-->>Py: List[torch.Tensor] (wgrad_list)
    Note over Py: base_buffer kept alive via shared_ptr in each view deleter
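
Going by this diagram and the file summaries above, the wgrad allocation in the grouped-linear backward reduces to a single call along the following lines; the import path, the surrounding weights list, and the exact argument types are assumptions rather than verified code:

    import transformer_engine_torch as tex  # pybind11 extension; import path assumed

    num_groups = len(weights)  # `weights`: hypothetical list of per-expert weight tensors
    wgrad_list = tex.bulk_allocate(
        [tuple(w.shape) for w in weights],  # one shape per expert wgrad
        [w.dtype for w in weights],         # matching dtypes
        weights[0].device,                  # device as the third positional argument
        [256] * num_groups,                 # 256-byte alignment per tensor
    )
    # wgrad_list is a list of torch.Tensors, all backed by one CUDA buffer.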

Reviews (7): Last reviewed commit: "Merge branch 'main' into tmoon/group-mlp..."

Comment on lines +756 to +758
// Check whether data and scales can be packed in contiguous
// buffer. Amaxes are not contiguous since they are aligned to
// 16B.
Member


I don't think I understand the logic here. What is the case where we would not be able to pack those buffers together if we can control the alignment requirements for the individual tensors in the allocation?

Collaborator Author

@timmoon10 timmoon10 Apr 22, 2026


I think there was some alignment requirement for the weight tensors, and some implementation difference depending on if the weights are contiguous. In any case, this is the logic in the existing implementation and I'm trying to avoid functional changes.

timmoon10 and others added 4 commits April 22, 2026 23:14
Make optional args for device and alignment. Handle case where base data_ptr is unaligned. Align grouped linear wgrad buffers to 256B.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 force-pushed the tmoon/group-mlp-bulk-allocate branch from befa6f6 to 16806b4 on April 22, 2026 at 23:48
@timmoon10
Collaborator Author

/te-ci pytorch

timmoon10 and others added 3 commits May 11, 2026 20:21
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

/te-ci pytorch

vthumbe1503 previously approved these changes May 11, 2026
Collaborator

@vthumbe1503 vthumbe1503 left a comment


LGTM

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 timmoon10 merged commit cb59ef1 into NVIDIA:main May 12, 2026
21 of 24 checks passed
@timmoon10 timmoon10 deleted the tmoon/group-mlp-bulk-allocate branch May 13, 2026 00:33
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
[PyTorch] Expose function to bulk-allocate tensors backed by the same buffer (NVIDIA#2900)

* [PyTorch] Add bulk_allocate utility and use it in quantized tensor allocators

Introduces transformer_engine/pytorch/csrc/extensions/allocate.cpp with a
general-purpose bulk_allocate function: given parallel lists of shapes,
dtypes, and per-tensor byte alignments, it computes a packed layout, does
a single CUDA allocation, and returns at::from_blob views whose deleters
keep the backing buffer alive.

The three internal bulk_allocate_*_tensors helpers in cast.cpp are
refactored to call bulk_allocate instead of each owning a copy of the
make_torch_view lambda and the offset-computation loops (~120 lines
removed). The new function is also exposed via pybind11 so Python can
allocate packed CUDA buffers directly without going through a quantizer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Bulk-allocate wgrads in grouped linear impls

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply review suggestions

Make optional args for device and alignment. Handle case where base data_ptr is unaligned. Align grouped linear wgrad buffers to 256B.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Nits from Claude

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix incorrect call to `bulk_allocate`

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix ambiguous return type

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use c10::Device consistently

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>