[ROCm] add hipblaslt support (pytorch#114329)
Disabled by default. Enable with env var DISABLE_ADDMM_HIP_LT=0. Tested on both ROCm 5.7 and 6.0.

Pull Request resolved: pytorch#114329
Approved by: https://github.com/malfet
jeffdaily authored and pytorchmergebot committed Dec 14, 2023
1 parent 04ef21f commit bb2bb8c
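
As context for the env vars mentioned above, here is a minimal usage sketch (not part of this commit) of how the new path can be opted into from a ROCm >= 5.7 build of PyTorch; the variable names come from this patch, while the tensor shapes and the workspace value are illustrative:

// Hedged sketch: opt into the hipBLASLt-backed addmm path on a ROCm build.
#include <cstdlib>
#include <ATen/ATen.h>

int main() {
  // The Lt path is opt-in on ROCm: it stays disabled unless this is set to "0".
  setenv("DISABLE_ADDMM_HIP_LT", "0", /*overwrite=*/1);
  // Optional: request a larger hipBLASLt workspace (value is in KiB, default 1024).
  setenv("HIPBLASLT_WORKSPACE_SIZE", "4096", 1);

  auto opts = at::device(at::kCUDA).dtype(at::kHalf);  // HIP devices map to kCUDA
  auto bias = at::randn({128}, opts);                  // 1-D, contiguous
  auto mat1 = at::randn({64, 256}, opts);
  auto mat2 = at::randn({256, 128}, opts);
  // With a 1-D bias whose length matches mat2's second dim, addmm is eligible
  // for the gemm_and_bias (hipblasLt) path added in this change.
  auto out = at::addmm(bias, mat1, mat2);
  return 0;
}

The same variables can equally be exported in the shell before launching the process.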
Showing 7 changed files with 284 additions and 20 deletions.
103 changes: 93 additions & 10 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -13,12 +13,13 @@

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#include <cublasLt.h>
#endif

#ifdef USE_ROCM
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
@@ -64,6 +65,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#ifndef HIPBLAS_V2
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_64F HIPBLAS_R_64F
@@ -81,6 +83,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif
#endif

#define CUDABLAS_POSINT_CHECK(FD, X) \
TORCH_CHECK( \
@@ -167,6 +170,7 @@ static void _cublasAdjustLdLevel3(
}
}

#ifndef USE_ROCM
uint32_t _getAlignment(uintptr_t address) {
// alignment are in bytes
uint32_t alignment = 256;
@@ -176,18 +180,25 @@ uint32_t _getAlignment(uintptr_t address) {
}
}
}
#endif

static size_t _parseChosenWorkspaceSize() {
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
if (!val) {
// accept either env var
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
}
#endif
size_t workspace_size = 1024; /* default size in KiB according to #73328 */
if (val) {
try {
workspace_size = std::stoi(val);
} catch(std::invalid_argument const& e) {
TORCH_WARN("invalid CUBLAS_LT_WORKSPACE_SIZE,",
TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
" using default workspace size of ", workspace_size, " bytes.");
} catch(std::out_of_range const& e) {
TORCH_WARN("CUBLAS_LT_WORKSPACE_SIZE out of range,",
TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
" using default workspace size of ", workspace_size, " bytes.");
}
}
@@ -346,7 +357,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
(int)num_batches,
#if defined(USE_ROCM) && ROCM_VERSION >= 60000
CUBLAS_COMPUTE_32F,
#else
CUDA_R_32F,
#endif
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

template <>
@@ -536,12 +553,66 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
c,
CUDA_R_16BF,
ldc,
#if defined(USE_ROCM) && ROCM_VERSION >= 60000
CUBLAS_COMPUTE_32F,
#else
CUDA_R_32F,
#endif
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}

#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
// to hipify correctly without this change.
#define hipDataType hipblasDatatype_t
#endif

// hipblaslt custom types were a temporary work-around
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_DATA_TYPE
hipblasltDatatype_t hipToLt(hipDataType type) {
switch (type) {
case HIP_R_32F: return HIPBLASLT_R_32F;
case HIP_R_64F: return HIPBLASLT_R_64F;
case HIP_R_16F: return HIPBLASLT_R_16F;
case HIP_R_8I: return HIPBLASLT_R_8I;
case HIP_C_32F: return HIPBLASLT_C_32F;
case HIP_C_64F: return HIPBLASLT_C_64F;
case HIP_C_16F: return HIPBLASLT_C_16F;
case HIP_C_8I: return HIPBLASLT_C_8I;
case HIP_R_8U: return HIPBLASLT_R_8U;
case HIP_C_8U: return HIPBLASLT_C_8U;
case HIP_R_32I: return HIPBLASLT_R_32I;
case HIP_C_32I: return HIPBLASLT_C_32I;
case HIP_R_32U: return HIPBLASLT_R_32U;
case HIP_C_32U: return HIPBLASLT_C_32U;
case HIP_R_16BF: return HIPBLASLT_R_16B;
case HIP_C_16BF: return HIPBLASLT_C_16B;
default: TORCH_CHECK(false);
}
}
#define HIPTOLT(type) hipToLt(type)
#else
#define HIPTOLT(type) type
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_COMPUTE_TYPE
hipblasLtComputeType_t hipblasToLt(hipblasComputeType_t type) {
switch (type) {
case HIPBLAS_COMPUTE_32F: return HIPBLASLT_COMPUTE_F32;
case HIPBLAS_COMPUTE_32F_FAST_16F: return HIPBLASLT_COMPUTE_F32_FAST_F16;
case HIPBLAS_COMPUTE_32F_FAST_TF32: return HIPBLASLT_COMPUTE_F32_FAST_XF32;
case HIPBLAS_COMPUTE_64F: return HIPBLASLT_COMPUTE_F64;
case HIPBLAS_COMPUTE_32I: return HIPBLASLT_COMPUTE_I32;
default: TORCH_CHECK(false);
}
}
#define HIPCOMPTOLT(type) hipblasToLt(type)
#else
#define HIPCOMPTOLT(type) type
#endif

namespace {
// Following the pattern of CuSparseDescriptor
@@ -580,7 +651,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
cudaDataType_t scale_type) {
cublasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
cublasLtMatmulDescCreate(&raw_descriptor, HIPCOMPTOLT(compute_type), HIPTOLT(scale_type)));
descriptor_.reset(raw_descriptor);
}
template <typename T>
@@ -601,7 +672,7 @@ class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
bool t = false) {
cublasLtMatrixLayout_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
cublasLtMatrixLayoutCreate(&raw_descriptor, HIPTOLT(type), t ? cols : rows, t ? rows : cols, ld));
descriptor_.reset(raw_descriptor);
}
};
@@ -645,13 +716,19 @@ void gemm_and_bias(
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
if constexpr (std::is_same_v<Dtype, double>) {
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
#else
TORCH_CHECK(false, "gemm_and_bias is only supported for double type on ROCm 6.0 and above");
#endif
} else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
#endif
abcType = CUDA_R_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
abcType = CUDA_R_16F;
@@ -668,7 +745,7 @@ void gemm_and_bias(
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
}
@@ -685,6 +762,7 @@ void gemm_and_bias(
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

#ifndef USE_ROCM
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
@@ -693,6 +771,7 @@ void gemm_and_bias(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
@@ -952,6 +1031,7 @@ void int8_gemm(
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld) {
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)

cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
cudaDataType_t scaleType = CUDA_R_32I;
@@ -1022,11 +1102,14 @@ void int8_gemm(
computeType,
" scaleType ",
scaleType);
#else
TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
}
#endif // !defined(USE_ROCM) && !defined(_MSC_VER)
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#if defined(USE_ROCM) && ROCM_VERSION <= 56000
#if defined(USE_ROCM) && ROCM_VERSION <= 50600
#define ROCM_CONST_BUG
#else
#define ROCM_CONST_BUG const
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -62,7 +62,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
@@ -149,7 +149,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));

#if defined(USE_ROCM) && ROCM_VERSION <= 55000
#if defined(USE_ROCM) && ROCM_VERSION <= 50500
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
53 changes: 48 additions & 5 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -153,7 +153,7 @@ enum class Activation {
GELU,
};

#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@@ -171,12 +171,40 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa

static bool getDisableAddmmCudaLt() {
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
#ifdef USE_ROCM
// allow both CUDA and HIP env var names for ROCm builds
// also, the Lt path is currently disabled by default for ROCm builds
if (env_value == nullptr) {
env_value = std::getenv("DISABLE_ADDMM_HIP_LT");
}
if (env_value != nullptr && strcmp(env_value, "0") == 0) {
return false;
}
return true;
#else
if (env_value != nullptr && strcmp(env_value, "1") == 0) {
return true;
}
return false;
#endif
}
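
The gating above is deliberately asymmetric: on CUDA builds the Lt path is used unless DISABLE_ADDMM_CUDA_LT=1, while on ROCm builds it is skipped unless the variable (either spelling) is explicitly "0". A small standalone sketch of that behavior, mirroring only the string comparisons above (everything else is illustrative and not part of this diff):

// Illustrative only: the effective env-var gating for the Lt path.
#include <cstdio>
#include <cstring>

static bool disabled_on_rocm(const char* v) {
  return !(v != nullptr && std::strcmp(v, "0") == 0);  // opt-in: default disabled
}
static bool disabled_on_cuda(const char* v) {
  return v != nullptr && std::strcmp(v, "1") == 0;     // opt-out: default enabled
}

int main() {
  std::printf("ROCm, unset -> disabled: %d\n", disabled_on_rocm(nullptr));  // 1
  std::printf("ROCm, \"0\"   -> disabled: %d\n", disabled_on_rocm("0"));      // 0
  std::printf("CUDA, unset -> disabled: %d\n", disabled_on_cuda(nullptr));  // 0
  std::printf("CUDA, \"1\"   -> disabled: %d\n", disabled_on_cuda("1"));      // 1
  return 0;
}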

#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
for (std::string arch : archs) {
size_t substring = device_arch.find(arch);
if (substring != std::string::npos) {
return true;
}
}
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
return false;
}
#endif
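
For reference, a hedged standalone sketch (not part of this diff) of the same eligibility check using the plain HIP runtime; the gfx architecture list mirrors the one above, the rest is illustrative:

// Assumed example: report whether device 0 is eligible for the hipBLASLt path.
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstring>

int main() {
  hipDeviceProp_t prop;
  if (hipGetDeviceProperties(&prop, /*deviceId=*/0) != hipSuccess) return 1;
  const char* supported[] = {"gfx90a", "gfx940", "gfx941", "gfx942"};
  bool eligible = false;
  for (const char* arch : supported) {
    if (std::strstr(prop.gcnArchName, arch) != nullptr) { eligible = true; break; }
  }
  std::printf("%s: %s\n", prop.gcnArchName, eligible ? "eligible" : "not eligible");
  return 0;
}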

Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
@@ -198,7 +226,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -211,10 +239,17 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
isSupportedHipLtROCmArch(self.device().index()) &&
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
@@ -234,6 +269,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
}
self__sizes = self_->sizes();
} else {
#if defined(USE_ROCM) && ROCM_VERSION >= 50700
useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) &&
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16);
#endif
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
@@ -277,7 +320,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
if (useLtInterface) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
@@ -299,7 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
activation_to_gemm_and_blas_arg(activation)
#else
// GELU is not supported (and does not compile!) prior
@@ -357,7 +400,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
if (useLtInterface && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
