Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,16 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
py::object ln_out, py::handle quantizer, DType otype,
const int sm_margin, const bool zero_centered_gamma);

/***************************************************************************************************
* Memory allocation
**************************************************************************************************/

// Allocates tensors all backed by a single contiguous buffer.
std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
const std::vector<at::ScalarType> &dtypes,
std::optional<c10::Device> device = std::nullopt,
std::optional<std::vector<size_t>> alignments = std::nullopt);

/***************************************************************************************************
* Cast
**************************************************************************************************/
Expand Down
86 changes: 86 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/allocate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <memory>
#include <vector>

#include "../extensions.h"

namespace transformer_engine {
namespace pytorch {

/*! \brief Allocate several tensors as views into one contiguous byte buffer.
 *
 * \param shapes      Per-tensor shapes; shapes.size() determines how many
 *                    tensors are returned.
 * \param dtypes      Per-tensor scalar types; must have the same length as
 *                    shapes (checked below).
 * \param device      Target device. Defaults to CUDA when not provided.
 * \param alignments  Per-tensor byte alignments. Defaults to each dtype's
 *                    element size when not provided; must match shapes in
 *                    length when provided (checked below).
 *
 * \return One tensor per shape/dtype pair. Non-empty tensors alias a single
 *         uint8 backing buffer; each view's deleter captures a shared_ptr to
 *         the buffer, so the buffer lives until the last view is destroyed.
 *         Zero-byte tensors are allocated independently (see workaround
 *         comment below).
 */
std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
const std::vector<at::ScalarType> &dtypes,
std::optional<c10::Device> device,
std::optional<std::vector<size_t>> alignments) {
// Check shapes and dtypes
const size_t n = shapes.size();
NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes.");
// Message args dereference `alignments`; safe only because the condition is
// false (and the message evaluated) solely when `alignments` is engaged.
NVTE_CHECK(!alignments || alignments->size() == n, "Got ", shapes.size(), " shapes and ",
alignments->size(), " alignments.");

// Return immediately if no tensors are needed
if (n == 0) return {};

// Set defaults for optional arguments
if (!device) {
device = c10::Device(c10::kCUDA);
}
if (!alignments) {
// Default alignment: the natural (element-size) alignment of each dtype.
alignments = std::vector<size_t>{};
alignments->reserve(n);
for (const auto &dtype : dtypes) {
alignments->push_back(c10::elementSize(dtype));
}
}

// Compute offsets in base buffer
// Offsets are laid out against an idealized base address that satisfies
// every requested alignment; the real pointer is fixed up after allocation.
// NOTE(review): a rounded-to-alignments[i] offset remains aligned after the
// base pointer is rounded up to base_alignment only when alignments[i]
// divides base_alignment — holds for power-of-two alignments (dtype element
// sizes are), but arbitrary caller-supplied alignments could break this.
std::vector<size_t> byte_sizes(n);
std::vector<size_t> offsets(n);
size_t base_byte_size = 0;
size_t base_alignment = 1;
for (size_t i = 0; i < n; ++i) {
byte_sizes[i] = product(shapes[i]) * at::elementSize(dtypes[i]);
offsets[i] = roundup(base_byte_size, (*alignments)[i]);
base_byte_size = offsets[i] + byte_sizes[i];
base_alignment = std::max(base_alignment, (*alignments)[i]);
}
if (base_alignment > 1) {
// Pad in case data pointer is not aligned
base_byte_size += base_alignment;
}

// Allocate base buffer
// Held via shared_ptr so each view's deleter (below) can share ownership.
auto base_buffer = std::make_shared<at::Tensor>(
at::empty({static_cast<int64_t>(base_byte_size)}, at::device(*device).dtype(torch::kUInt8)));
uint8_t *base_ptr = base_buffer->data_ptr<uint8_t>();
// Round the actual data pointer up to the strictest alignment; the padding
// added above guarantees the buffer still covers base_byte_size bytes.
base_ptr =
reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr), base_alignment));

// Create views into base buffer
std::vector<at::Tensor> out;
out.reserve(n);
std::vector<int64_t> shape_int64;  // reused scratch: shapes arrive as size_t, ATen wants int64_t
for (size_t i = 0; i < n; ++i) {
shape_int64.assign(shapes[i].begin(), shapes[i].end());
if (byte_sizes[i] == 0) {
// Work around problems with from_blob when constructing an
// empty tensor. Passing a null pointer fails because it checks
// that the pointer is on GPU. Passing a non-null pointer can
// cause bugs in TE kernels.
out.emplace_back(at::empty(shape_int64, at::device(*device).dtype(dtypes[i])));
} else {
// Construct tensor with custom deleter to keep base buffer alive
// (the no-op lambda's by-value capture of the shared_ptr is the
// ownership mechanism — it frees nothing itself).
out.emplace_back(at::from_blob(
base_ptr + offsets[i], shape_int64, [base_buffer](void *) {},
at::device(*device).dtype(dtypes[i])));
}
}
return out;
}

} // namespace pytorch
} // namespace transformer_engine
Loading
Loading