diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 87e266e1889..c5114ed8b22 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -62,8 +62,10 @@ static bool GetCudaContext(int32 num_gpus, std::string *debug_str) { // Our first attempt to get a device context is: we do cudaFree(0) and see if // that returns no error code. If it succeeds then we have a device // context. Apparently this is the canonical way to get a context. - if (cudaFree(0) == 0) + if (cudaFree(0) == 0) { + cudaGetLastError(); // Clear any error status. return true; + } // The rest of this code represents how we used to get a device context, but // now its purpose is mainly a debugging one. @@ -71,16 +73,18 @@ static bool GetCudaContext(int32 num_gpus, std::string *debug_str) { debug_stream << "num-gpus=" << num_gpus << ". "; for (int32 device = 0; device < num_gpus; device++) { cudaSetDevice(device); - cudaError_t e = cudaDeviceSynchronize(); // << CUDA context gets created here. + cudaError_t e = cudaFree(0); // CUDA context gets created here. if (e == cudaSuccess) { - *debug_str = debug_stream.str(); + if (debug_str) + *debug_str = debug_stream.str(); + cudaGetLastError(); // Make sure the error state doesn't get returned in + // the next cudaGetLastError(). return true; } debug_stream << "Device " << device << ": " << cudaGetErrorString(e) << ". "; - cudaGetLastError(); // Make sure the error state doesn't get returned in - // the next cudaGetLastError(). } - *debug_str = debug_stream.str(); + if (debug_str) + *debug_str = debug_stream.str(); return false; } @@ -164,7 +168,7 @@ void CuDevice::SelectGpuId(std::string use_gpu) { } else { int32 num_times = 0; BaseFloat wait_time = 0.0; - while (! got_context) { + while (!got_context) { int32 sec_sleep = 5; if (num_times == 0) KALDI_WARN << "Will try again indefinitely every " << sec_sleep @@ -172,7 +176,7 @@ void CuDevice::SelectGpuId(std::string use_gpu) { num_times++; wait_time += sec_sleep; Sleep(sec_sleep); - got_context = GetCudaContext(num_gpus, &debug_str); + got_context = GetCudaContext(num_gpus, NULL); } KALDI_WARN << "Waited " << wait_time