Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OF_CUDA_CHECK/OF_CUDNN_CHECK/OF_CUBLAS_CHECK/OF_CURAND_CHECK #3446

Merged
merged 12 commits into from
Aug 8, 2020
6 changes: 3 additions & 3 deletions oneflow/core/common/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ OF_PP_FOR_EACH_TUPLE(CBLAS_TEMPLATE, BLAS_NAME_SEQ);
#define CUBLAS_TEMPLATE(name) \
template<typename T, typename... Args> \
typename std::enable_if<std::is_same<T, float>::value>::type cublas_##name(Args&&... args) { \
CudaCheck(cublasS##name(std::forward<Args>(args)...)); \
OF_CUBLAS_CHECK(cublasS##name(std::forward<Args>(args)...)); \
} \
template<typename T, typename... Args> \
typename std::enable_if<std::is_same<T, double>::value>::type cublas_##name(Args&&... args) { \
CudaCheck(cublasD##name(std::forward<Args>(args)...)); \
OF_CUBLAS_CHECK(cublasD##name(std::forward<Args>(args)...)); \
} \
template<typename T, typename... Args> \
typename std::enable_if<std::is_same<T, half>::value>::type cublas_##name(Args&&... args) { \
CudaCheck(cublasH##name(std::forward<Args>(args)...)); \
OF_CUBLAS_CHECK(cublasH##name(std::forward<Args>(args)...)); \
}

OF_PP_FOR_EACH_TUPLE(CUBLAS_TEMPLATE, BLAS_NAME_SEQ);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/device/cuda_device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class CudaDeviceCtx : public DeviceCtx {
}
const cudnnHandle_t& cudnn_handle() const override { return *(cuda_handler_->cudnn_handle()); }

void SyncDevice() override { CudaCheck(cudaStreamSynchronize(cuda_stream())); }
void SyncDevice() override { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream())); }

void AddCallBack(std::function<void()> callback) const override {
cuda_handler_->AddCallBack(callback);
Expand Down
40 changes: 20 additions & 20 deletions oneflow/core/device/cuda_stream_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,70 +25,70 @@ namespace oneflow {
const cudaStream_t* CudaStreamHandle::cuda_stream() {
if (!cuda_stream_) {
cuda_stream_.reset(new cudaStream_t);
CudaCheck(cudaStreamCreate(cuda_stream_.get()));
OF_CUDA_CHECK(cudaStreamCreate(cuda_stream_.get()));
}
return cuda_stream_.get();
}

const cublasHandle_t* CudaStreamHandle::cublas_pmh_handle() {
if (!cublas_pmh_handle_) {
cublas_pmh_handle_.reset(new cublasHandle_t);
CudaCheck(cublasCreate(cublas_pmh_handle_.get()));
CudaCheck(cublasSetStream(*cublas_pmh_handle_, *cuda_stream()));
OF_CUBLAS_CHECK(cublasCreate(cublas_pmh_handle_.get()));
OF_CUBLAS_CHECK(cublasSetStream(*cublas_pmh_handle_, *cuda_stream()));
}
return cublas_pmh_handle_.get();
}

const cublasHandle_t* CudaStreamHandle::cublas_pmd_handle() {
if (!cublas_pmd_handle_) {
cublas_pmd_handle_.reset(new cublasHandle_t);
CudaCheck(cublasCreate(cublas_pmd_handle_.get()));
CudaCheck(cublasSetStream(*cublas_pmd_handle_, *cuda_stream()));
CudaCheck(cublasSetPointerMode(*cublas_pmd_handle_, CUBLAS_POINTER_MODE_DEVICE));
OF_CUBLAS_CHECK(cublasCreate(cublas_pmd_handle_.get()));
OF_CUBLAS_CHECK(cublasSetStream(*cublas_pmd_handle_, *cuda_stream()));
OF_CUBLAS_CHECK(cublasSetPointerMode(*cublas_pmd_handle_, CUBLAS_POINTER_MODE_DEVICE));
}
return cublas_pmd_handle_.get();
}

const cublasHandle_t* CudaStreamHandle::cublas_tensor_op_math_handle() {
if (!cublas_tensor_op_math_handle_) {
cublas_tensor_op_math_handle_.reset(new cublasHandle_t);
CudaCheck(cublasCreate(cublas_tensor_op_math_handle_.get()));
CudaCheck(cublasSetStream(*cublas_tensor_op_math_handle_, *cuda_stream()));
CudaCheck(cublasSetMathMode(*cublas_tensor_op_math_handle_, CUBLAS_TENSOR_OP_MATH));
OF_CUBLAS_CHECK(cublasCreate(cublas_tensor_op_math_handle_.get()));
OF_CUBLAS_CHECK(cublasSetStream(*cublas_tensor_op_math_handle_, *cuda_stream()));
OF_CUBLAS_CHECK(cublasSetMathMode(*cublas_tensor_op_math_handle_, CUBLAS_TENSOR_OP_MATH));
}
return cublas_tensor_op_math_handle_.get();
}

const cudnnHandle_t* CudaStreamHandle::cudnn_handle() {
if (!cudnn_handle_) {
if (IsCuda9OnTuringDevice()) {
CudaCheck(cudaDeviceSynchronize());
CudaCheck(cudaGetLastError());
OF_CUDA_CHECK(cudaDeviceSynchronize());
OF_CUDA_CHECK(cudaGetLastError());
}
cudnn_handle_.reset(new cudnnHandle_t);
CudaCheck(cudnnCreate(cudnn_handle_.get()));
OF_CUDNN_CHECK(cudnnCreate(cudnn_handle_.get()));
if (IsCuda9OnTuringDevice()) {
CudaCheck(cudaDeviceSynchronize());
OF_CUDA_CHECK(cudaDeviceSynchronize());
cudaGetLastError();
}
CudaCheck(cudnnSetStream(*cudnn_handle_, *cuda_stream()));
OF_CUDNN_CHECK(cudnnSetStream(*cudnn_handle_, *cuda_stream()));
}
return cudnn_handle_.get();
}

void CudaStreamHandle::AddCallBack(std::function<void()> callback) {
CudaCBEvent cb_event;
cb_event.callback = std::move(callback);
CudaCheck(cudaEventCreateWithFlags(&(cb_event.event), cudaEventDisableTiming));
CudaCheck(cudaEventRecord(cb_event.event, *cuda_stream()));
OF_CUDA_CHECK(cudaEventCreateWithFlags(&(cb_event.event), cudaEventDisableTiming));
OF_CUDA_CHECK(cudaEventRecord(cb_event.event, *cuda_stream()));
cb_event_chan_->Send(cb_event);
}

CudaStreamHandle::~CudaStreamHandle() {
if (cudnn_handle_) { CudaCheck(cudnnDestroy(*cudnn_handle_)); }
if (cublas_pmh_handle_) { CudaCheck(cublasDestroy(*cublas_pmh_handle_)); }
if (cublas_pmd_handle_) { CudaCheck(cublasDestroy(*cublas_pmd_handle_)); }
if (cuda_stream_) { CudaCheck(cudaStreamDestroy(*cuda_stream_)); }
if (cudnn_handle_) { OF_CUDNN_CHECK(cudnnDestroy(*cudnn_handle_)); }
if (cublas_pmh_handle_) { OF_CUBLAS_CHECK(cublasDestroy(*cublas_pmh_handle_)); }
if (cublas_pmd_handle_) { OF_CUBLAS_CHECK(cublasDestroy(*cublas_pmd_handle_)); }
if (cuda_stream_) { OF_CUDA_CHECK(cudaStreamDestroy(*cuda_stream_)); }
}

#endif // WITH_CUDA
Expand Down
18 changes: 7 additions & 11 deletions oneflow/core/device/cuda_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ namespace oneflow {

#ifdef WITH_CUDA

namespace {

const char* CublasGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
Expand Down Expand Up @@ -62,8 +60,6 @@ const char* CurandGetErrorString(curandStatus_t error) {
return "Unknown curand status";
}

} // namespace

void InitGlobalCudaDeviceProp() {
CHECK(Global<cudaDeviceProp>::Get() == nullptr) << "initialized Global<cudaDeviceProp> twice";
Global<cudaDeviceProp>::New();
Expand Down Expand Up @@ -147,8 +143,8 @@ void ParseCpuMask(const std::string& cpu_mask, cpu_set_t* cpu_set) {

std::string CudaDeviceGetCpuMask(int32_t dev_id) {
std::vector<char> pci_bus_id_buf(sizeof("0000:00:00.0"));
CudaCheck(cudaDeviceGetPCIBusId(pci_bus_id_buf.data(), static_cast<int>(pci_bus_id_buf.size()),
dev_id));
OF_CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id_buf.data(),
static_cast<int>(pci_bus_id_buf.size()), dev_id));
for (int32_t i = 0; i < pci_bus_id_buf.size(); ++i) {
pci_bus_id_buf[i] = std::tolower(pci_bus_id_buf[i]);
}
Expand Down Expand Up @@ -182,7 +178,7 @@ void NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size) {
cpu_set_t saved_cpu_set;
CHECK_EQ(sched_getaffinity(0, sizeof(cpu_set_t), &saved_cpu_set), 0);
CHECK_EQ(sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set), 0);
CudaCheck(cudaMallocHost(ptr, size));
OF_CUDA_CHECK(cudaMallocHost(ptr, size));
CHECK_EQ(sched_setaffinity(0, sizeof(cpu_set_t), &saved_cpu_set), 0);
#else
UNIMPLEMENTED();
Expand All @@ -198,13 +194,13 @@ cudaDataType_t GetCudaDataType(DataType val) {
}

CudaCurrentDeviceGuard::CudaCurrentDeviceGuard(int32_t dev_id) {
CudaCheck(cudaGetDevice(&saved_dev_id_));
CudaCheck(cudaSetDevice(dev_id));
OF_CUDA_CHECK(cudaGetDevice(&saved_dev_id_));
OF_CUDA_CHECK(cudaSetDevice(dev_id));
}

CudaCurrentDeviceGuard::CudaCurrentDeviceGuard() { CudaCheck(cudaGetDevice(&saved_dev_id_)); }
CudaCurrentDeviceGuard::CudaCurrentDeviceGuard() { OF_CUDA_CHECK(cudaGetDevice(&saved_dev_id_)); }

CudaCurrentDeviceGuard::~CudaCurrentDeviceGuard() { CudaCheck(cudaSetDevice(saved_dev_id_)); }
CudaCurrentDeviceGuard::~CudaCurrentDeviceGuard() { OF_CUDA_CHECK(cudaSetDevice(saved_dev_id_)); }

#endif // WITH_CUDA

Expand Down
32 changes: 32 additions & 0 deletions oneflow/core/device/cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@ limitations under the License.

namespace oneflow {

const char* CublasGetErrorString(cublasStatus_t error);

const char* CurandGetErrorString(curandStatus_t error);

#define OF_CUDA_CHECK(condition) \
for (cudaError_t _of_cuda_check_status = (condition); _of_cuda_check_status != cudaSuccess;) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用这个for做临时作用域来定义临时变量的技巧非常有意思

LOG(FATAL) << "Check failed: " #condition " : " << cudaGetErrorString(_of_cuda_check_status) \
<< " (" << _of_cuda_check_status << ") "

#define OF_CUDNN_CHECK(condition) \
for (cudnnStatus_t _of_cudnn_check_status = (condition); \
_of_cudnn_check_status != CUDNN_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << cudnnGetErrorString(_of_cudnn_check_status) \
<< " (" << _of_cudnn_check_status << ") "

#define OF_CUBLAS_CHECK(condition) \
for (cublasStatus_t _of_cublas_check_status = (condition); \
_of_cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << CublasGetErrorString(_of_cublas_check_status) \
<< " (" << _of_cublas_check_status << ") "

#define OF_CURAND_CHECK(condition) \
for (curandStatus_t _of_curand_check_status = (condition); \
_of_curand_check_status != CURAND_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << CurandGetErrorString(_of_curand_check_status) \
<< " (" << _of_curand_check_status << ") "

#define OF_NCCL_CHECK(condition) \
for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \
LOG(FATAL) << "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) \
<< " (" << _of_nccl_check_status << ") "

template<typename T>
void CudaCheck(T error);

Expand Down