53 changes: 49 additions & 4 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -48,6 +48,39 @@ void destroyCublasLtHandle(cublasLtHandle_t handle) {
}

using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;

// ugly hack until hipblasSetWorkspace exists
#include <rocblas/rocblas.h>

static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
  switch (error) {
    case rocblas_status_size_unchanged:
    case rocblas_status_size_increased:
    case rocblas_status_success:
      return HIPBLAS_STATUS_SUCCESS;
    case rocblas_status_invalid_handle:
      return HIPBLAS_STATUS_NOT_INITIALIZED;
    case rocblas_status_not_implemented:
      return HIPBLAS_STATUS_NOT_SUPPORTED;
    case rocblas_status_invalid_pointer:
    case rocblas_status_invalid_size:
    case rocblas_status_invalid_value:
      return HIPBLAS_STATUS_INVALID_VALUE;
    case rocblas_status_memory_error:
      return HIPBLAS_STATUS_ALLOC_FAILED;
    case rocblas_status_internal_error:
      return HIPBLAS_STATUS_INTERNAL_ERROR;
  }
  TORCH_CHECK(false, "HIPBLAS_STATUS_INVALID_ENUM");
}

static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, void* addr, size_t size) {
  return rocBLASStatusToHIPStatus(rocblas_set_workspace((rocblas_handle)handle, addr, size));
}

// hipify mappings file correctly maps this but the function doesn't exist yet
#define hipblasSetWorkspace hipblasSetWorkspace_replacement

#endif

std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
@@ -77,17 +110,29 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace

void clearCublasWorkspaces() {
-#if !defined(USE_ROCM)
-  cublas_handle_stream_to_workspace().clear();
-#endif
+  cublas_handle_stream_to_workspace().clear();
}

size_t parseChosenWorkspaceSize() {
  const char* val = getenv("CUBLAS_WORKSPACE_CONFIG");
#ifdef USE_ROCM
  if (!val) {
    val = getenv("HIPBLAS_WORKSPACE_CONFIG");
  }
  if (!val) {
    // for extra convenience
    val = getenv("ROCBLAS_WORKSPACE_CONFIG");
  }
  /* 32MiB default, 128MiB for MI300 */
  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
  const bool gfx94 = properties != nullptr && properties->major == 9 && properties->minor == 4;
  const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
#else
  /* :4096:2:16:8 default, 32MiB for Hopper */
  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
  const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0;
  const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8;
#endif

  if (val) {
    size_t total_size = 0;
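The ``:[SIZE]:[COUNT]`` arithmetic this parser implements can be illustrated with a small sketch; the helper below is hypothetical, and the authoritative logic is the C++ parsing loop above.

# Hypothetical Python helper mirroring the :[SIZE]:[COUNT] arithmetic of
# parseChosenWorkspaceSize(); illustrative only, not PyTorch code.
def total_workspace_bytes(config: str) -> int:
    fields = [int(f) for f in config.split(":") if f]  # ":4096:2:16:8" -> [4096, 2, 16, 8]
    sizes_kib, counts = fields[::2], fields[1::2]      # sizes in KiB, paired with counts
    return sum(size * 1024 * count for size, count in zip(sizes_kib, counts))

# Matches the CUDA default expression above: 4096 * 1024 * 2 + 16 * 1024 * 8.
assert total_workspace_bytes(":4096:2:16:8") == 4096 * 1024 * 2 + 16 * 1024 * 8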
@@ -156,7 +201,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
  auto handle = myPoolWindow->reserve(device);
  auto stream = c10::cuda::getCurrentCUDAStream();
  TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
-#if !defined(USE_ROCM)
  // We explicitly set the cublas workspace even though CUDA 12.2+ fixed the
  // issue where memory usage increased during graph capture.
  // original issue: https://github.com/pytorch/pytorch/pull/83461
@@ -171,6 +215,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
    workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
  }
  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
+#if !defined(USE_ROCM)
  // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
  // FP32 data type calculations based on the value of the allow_tf32 flag.
  // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
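The net effect of the two hunks above is that the explicit workspace path now runs on ROCm as well, with only the TF32 handling still guarded as CUDA-only. A hedged sketch of the observable behavior (assumes a GPU build; each (handle, stream) pair caches its own workspace until cleared):

import torch

# Sketch: GEMMs issued on two streams may leave two cached workspaces,
# one per (BLAS handle, stream) pair, until explicitly cleared.
a = torch.randn(512, 512, device="cuda")
for stream in (torch.cuda.Stream(), torch.cuda.Stream()):
    with torch.cuda.stream(stream):
        a @ a  # may allocate a workspace for this handle/stream pair
torch.cuda.synchronize()
torch._C._cuda_clearCublasWorkspaces()  # same entry point on CUDA and ROCm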
19 changes: 18 additions & 1 deletion docs/source/notes/hip.rst
@@ -103,7 +103,24 @@ complete snapshot of the memory allocator state via
underlying allocation patterns produced by your code.

To debug memory errors, set
-``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
+``PYTORCH_NO_HIP_MEMORY_CACHING=1`` in your environment to disable caching.
+``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` is also accepted for ease of porting.

.. _hipblas-workspaces:

hipBLAS workspaces
------------------

For each combination of hipBLAS handle and HIP stream, a hipBLAS workspace will be allocated if that
handle and stream combination executes a hipBLAS kernel that requires a workspace. To avoid
repeatedly allocating workspaces, these workspaces are not deallocated unless
``torch._C._cuda_clearCublasWorkspaces()`` is called; note that it is the same function for both
CUDA and HIP. The workspace size per allocation can be specified via the environment variable
``HIPBLAS_WORKSPACE_CONFIG`` with the format ``:[SIZE]:[COUNT]``. As an example, the environment
variable ``HIPBLAS_WORKSPACE_CONFIG=:4096:2:16:8`` specifies a total size of ``2 * 4096 + 8 * 16
KiB``, or approximately 8 MiB. The default workspace size is 32 MiB; MI300 and newer default to 128
MiB. To force hipBLAS to avoid using workspaces, set ``HIPBLAS_WORKSPACE_CONFIG=:0:0``. For
convenience, ``CUBLAS_WORKSPACE_CONFIG`` is also accepted.
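A possible end-to-end use of these knobs (an illustrative sketch, not part of the docs change; it assumes the variable is set before the process issues its first BLAS call):

import os

# ":4096:2" requests two 4 MiB workspace chunks per handle/stream pair
# (8 MiB total); must be set before the first hipBLAS/cuBLAS call.
os.environ["HIPBLAS_WORKSPACE_CONFIG"] = ":4096:2"

import torch

a = torch.randn(256, 256, device="cuda")
b = a @ a  # the first GEMM allocates the configured workspace
torch._C._cuda_clearCublasWorkspaces()  # frees cached workspaces on CUDA and HIP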

.. _hipfft-plan-cache:

28 changes: 16 additions & 12 deletions test/test_cuda.py
@@ -30,7 +30,6 @@
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_cuda import (
    _create_scaling_case,
-    _get_torch_cuda_version,
    TEST_CUDNN,
    TEST_MULTIGPU,
)
@@ -55,6 +54,7 @@
    parametrize,
    run_tests,
    serialTest,
    setBlasBackendsToDefaultFinally,
    skipCUDAMemoryLeakCheckIf,
    skipCUDANonDefaultStreamIf,
    skipIfRocm,
@@ -364,19 +364,23 @@ def test_serialization_array_with_storage(self):
        q_copy[1].fill_(10)
        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

-    @unittest.skipIf(
-        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
-    )
-    @unittest.skipIf(
-        _get_torch_cuda_version() >= (12, 2),
-        "skipped as explicit workspace allocation is removed",
-    )
+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
+    @setBlasBackendsToDefaultFinally
     def test_cublas_workspace_explicit_allocation(self):
+        torch.backends.cuda.preferred_blas_library("cublas")
         a = torch.randn(7, 7, device="cuda", requires_grad=False)
-        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
-        # different size (32 MiB) expected on Hopper GPU
-        if torch.cuda.get_device_capability() == (9, 0):
-            default_workspace_size = 4096 * 8 * 1024
+        if torch.version.hip:
+            default_workspace_size = 1024 * 32 * 1024  # :1024:32 32MiB
+            # different size (128 MiB) expected on MI300 GPU
+            if torch.cuda.get_device_capability() >= (9, 4):
+                default_workspace_size = 1024 * 128 * 1024  # :1024:128
+        else:
+            default_workspace_size = (
+                4096 * 2 * 1024 + 16 * 8 * 1024
+            )  # :4096:2:16:8 8MiB
+            # different size (32 MiB) expected on Hopper GPU
+            if torch.cuda.get_device_capability() == (9, 0):
+                default_workspace_size = 4096 * 8 * 1024

        def check_workspace_size(inp):
            torch._C._cuda_clearCublasWorkspaces()
1 change: 1 addition & 0 deletions torch/utils/hipify/cuda_to_hip_mappings.py
@@ -6685,6 +6685,7 @@
"cublasGetVersion_v2",
("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
),
("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)),
("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)),
("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)),
("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)),