From e111237a95969384adffd22d3c8904c4cdd7e8de Mon Sep 17 00:00:00 2001
From: Johannes M Dieterich
Date: Mon, 22 Oct 2018 13:01:23 -0500
Subject: [PATCH] Since upstream #9322 by @petrex, hipification resulted in illegal code.

The legal function cublasHandle_t cublas_handle() was hipified to the
clearly illegal rocblas_handle rocblas_handle(). This should not compile,
and it correctly fails with gcc as the host compiler because the member
name is ambiguous with the type name.

The function now hipifies to rocblas_handle rocblashandle().

Fixes a long-standing issue we have observed in PyTorch when the base
compiler is gcc.
---
 caffe2/core/hip/context_hip.h        |  2 +-
 caffe2/utils/hip/math_hip.cc         | 28 +++++++++----------
 .../pyHIPIFY/cuda_to_hip_mappings.py |  2 +-
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/caffe2/core/hip/context_hip.h b/caffe2/core/hip/context_hip.h
index ca2d03b1449d..463303a1f0df 100644
--- a/caffe2/core/hip/context_hip.h
+++ b/caffe2/core/hip/context_hip.h
@@ -175,7 +175,7 @@ class HIPContext final : public BaseContext {
     return hip_objects_.GetStream(gpu_id, stream_id);
   }
 
-  rocblas_handle rocblas_handle() {
+  rocblas_handle rocblashandle() {
     return hip_objects_.GetHandle(gpu_id_, stream_id_);
   }
 
diff --git a/caffe2/utils/hip/math_hip.cc b/caffe2/utils/hip/math_hip.cc
index d499393dbe81..bc6ee92f46e3 100644
--- a/caffe2/utils/hip/math_hip.cc
+++ b/caffe2/utils/hip/math_hip.cc
@@ -759,7 +759,7 @@ void Gemm(
       ? rocblas_operation_none
       : rocblas_operation_transpose;
   ROCBLAS_ENFORCE(rocblas_sgemm(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       cuTransB,
       cuTransA,
       N,
@@ -803,7 +803,7 @@ void Gemm(
       : rocblas_operation_transpose;
   if (math_type == TensorProto_DataType_FLOAT) {
     ROCBLAS_CHECK(rocblas_sgemmEx(
-        context->rocblas_handle(),
+        context->rocblashandle(),
         cuTransB,
         cuTransA,
         N,
@@ -828,7 +828,7 @@ void Gemm(
 
     // call cublasHgemm
     ROCBLAS_CHECK(cublasHgemm(
-        context->rocblas_handle(),
+        context->rocblashandle(),
         cuTransB,
         cuTransA,
         N,
@@ -933,7 +933,7 @@ void GemmStridedBatched(
       ? rocblas_operation_none
       : rocblas_operation_transpose;
   ROCBLAS_ENFORCE(rocblas_sgemm_strided_batched(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       cuTransB,
       cuTransA,
       N,
@@ -1004,7 +1004,7 @@ void GemmStridedBatched(
   __half alpha_fp16 = at::Half(alpha);
   __half beta_fp16 = at::Half(beta);
   ROCBLAS_ENFORCE(cublasHgemmStridedBatched(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       cuTransB,
       cuTransA,
       N,
@@ -1051,7 +1051,7 @@ void GemmEx(
       ? rocblas_operation_none
       : rocblas_operation_transpose;
   ROCBLAS_ENFORCE(rocblas_sgemm(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       cuTransB,
       cuTransA,
       N,
@@ -1083,7 +1083,7 @@ void Gemv(
       ? rocblas_operation_transpose
       : rocblas_operation_none;
   ROCBLAS_ENFORCE(rocblas_sgemv(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       cuTransA,
       N,
       M,
@@ -1170,7 +1170,7 @@ void Gemv(
 
   if (math_type == TensorProto_DataType_FLOAT) {
     ROCBLAS_CHECK(cublasSgemmEx(
-        context->rocblas_handle(),
+        context->rocblashandle(),
         cuTransA,
         rocblas_operation_none,
         m,
@@ -1192,7 +1192,7 @@ void Gemv(
     __half beta_fp16 = at::Half(beta);
 
     ROCBLAS_CHECK(cublasHgemm(
-        context->rocblas_handle(),
+        context->rocblashandle(),
         cuTransA,
         rocblas_operation_none,
         m,
@@ -1390,7 +1390,7 @@ void Dot(
     HIPContext* context) {
   float result;
   ROCBLAS_ENFORCE(
-      rocblas_sdot(context->rocblas_handle(), n, a, 1, b, 1, &result));
+      rocblas_sdot(context->rocblashandle(), n, a, 1, b, 1, &result));
   context->CopyFromCPU(1, &result, y);
 }
 
@@ -1406,7 +1406,7 @@ void Dot(
   at::Half result;
   // execute with 32-bit math
   ROCBLAS_CHECK(cublasDotEx(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       n,
       a,
       CUDA_R_16F,
@@ -1879,7 +1879,7 @@ void Axpy(
     float* Y,
     HIPContext* context) {
   ROCBLAS_ENFORCE(
-      rocblas_saxpy(context->rocblas_handle(), N, &alpha, X, 1, Y, 1));
+      rocblas_saxpy(context->rocblashandle(), N, &alpha, X, 1, Y, 1));
 }
 
 template <>
@@ -1891,7 +1891,7 @@ void Axpy(
     HIPContext* context) {
   double alpha_d{alpha};
   ROCBLAS_ENFORCE(
-      rocblas_daxpy(context->rocblas_handle(), N, &alpha_d, X, 1, Y, 1));
+      rocblas_daxpy(context->rocblashandle(), N, &alpha_d, X, 1, Y, 1));
 }
 
 template <>
@@ -1904,7 +1904,7 @@ void Axpy(
     CAFFE_THROW("Unsupported math type");
 #if ROCBLAS_FP16
   ROCBLAS_CHECK(cublasAxpyEx(
-      context->rocblas_handle(),
+      context->rocblashandle(),
       N,
       &alpha,
       CUDA_R_16F,
diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
index 2d1f722c8db2..4a8aef36eaae 100644
--- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
+++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
@@ -2230,7 +2230,7 @@
     ("HasCudaGPU" , ("HasHipGPU", API_CAFFE2)),
     ("__expf" , ("expf", API_CAFFE2)),
     ("CUBLAS_ENFORCE" , ("ROCBLAS_ENFORCE", API_CAFFE2)),
-    ("cublas_handle" , ("rocblas_handle", API_CAFFE2)),
+    ("cublas_handle" , ("rocblashandle", API_CAFFE2)),
     ("CURAND_ENFORCE" ,("HIPRAND_ENFORCE", API_CAFFE2)),
     ("curandGenerateUniform" , ("hiprandGenerateUniform", API_CAFFE2)),
     ("curand_generator" , ("hiprand_generator", API_CAFFE2)),
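
Illustration (not part of the patch): a minimal, self-contained C++ sketch of the name clash the rename avoids. The typedef below and the trimmed-down class are assumptions standing in for the real rocBLAS opaque handle type and Caffe2's HIPContext; they are not the actual headers.

    // Stand-in for the opaque rocBLAS handle typedef (assumption for this sketch).
    typedef struct _rocblas_handle* rocblas_handle;

    class HIPContext {
     public:
      // Pre-patch hipify output declared a member whose name equals its return
      // type. Inside the class the member name then shadows the type name, so
      // gcc rejects the declaration (it "changes the meaning" of
      // rocblas_handle within the class), while clang-based host compilers
      // apparently tolerated it:
      //
      //   rocblas_handle rocblas_handle() { return handle_; }
      //
      // Post-patch spelling: the type name and the member name no longer clash.
      rocblas_handle rocblashandle() { return handle_; }

     private:
      rocblas_handle handle_ = nullptr;
    };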