Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OF_CUDA_CHECK/OF_CUDNN_CHECK/OF_CUBLAS_CHECK/OF_CURAND_CHECK #3446

Merged
merged 12 commits into from
Aug 8, 2020
4 changes: 0 additions & 4 deletions oneflow/core/device/cuda_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ namespace oneflow {

#ifdef WITH_CUDA

namespace {

const char* CublasGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
Expand Down Expand Up @@ -62,8 +60,6 @@ const char* CurandGetErrorString(curandStatus_t error) {
return "Unknown curand status";
}

} // namespace

void InitGlobalCudaDeviceProp() {
CHECK(Global<cudaDeviceProp>::Get() == nullptr) << "initialized Global<cudaDeviceProp> twice";
Global<cudaDeviceProp>::New();
Expand Down
32 changes: 32 additions & 0 deletions oneflow/core/device/cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@ limitations under the License.

namespace oneflow {

const char* CublasGetErrorString(cublasStatus_t error);

const char* CurandGetErrorString(curandStatus_t error);

#define OF_CUDA_CHECK(condition) \
for (cudaError_t _of_cuda_check_status = (condition); _of_cuda_check_status != cudaSuccess;) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用这个for做临时作用域来定义临时变量的技巧非常有意思

LOG(FATAL) << "Check failed: " #condition " : " << cudaGetErrorString(_of_cuda_check_status) \
<< " (" << _of_cuda_check_status << ") "

#define OF_CUDNN_CHECK(condition) \
for (cudnnStatus_t _of_cudnn_check_status = (condition); \
_of_cudnn_check_status != CUDNN_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << cudnnGetErrorString(_of_cudnn_check_status) \
<< " (" << _of_cudnn_check_status << ") "

#define OF_CUBLAS_CHECK(condition) \
for (cublasStatus_t _of_cublas_check_status = (condition); \
_of_cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << CublasGetErrorString(_of_cublas_check_status) \
<< " (" << _of_cublas_check_status << ") "

#define OF_CURAND_CHECK(condition) \
for (curandStatus_t _of_curand_check_status = (condition); \
_of_curand_check_status != CURAND_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << CurandGetErrorString(_of_curand_check_status) \
<< " (" << _of_curand_check_status << ") "

#define OF_NCCL_CHECK(condition) \
for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \
LOG(FATAL) << "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) \
<< " (" << _of_nccl_check_status << ") "

template<typename T>
void CudaCheck(T error);

Expand Down