Skip to content

Commit

Permalink
[src] Fix issue with CUDA device initialization if 'wait' specified. …
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey authored and LvHang committed Apr 13, 2018
1 parent 3db69e1 commit f403df9
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/cudamatrix/cu-device.cc
Expand Up @@ -62,25 +62,29 @@ static bool GetCudaContext(int32 num_gpus, std::string *debug_str) {
// Our first attempt to get a device context is: we do cudaFree(0) and see if
// that returns no error code. If it succeeds then we have a device
// context. Apparently this is the canonical way to get a context.
if (cudaFree(0) == 0)
if (cudaFree(0) == 0) {
cudaGetLastError(); // Clear any error status.
return true;
}

// The rest of this code is how we used to get a device context; it is now
// kept mainly for debugging purposes.
std::ostringstream debug_stream;
debug_stream << "num-gpus=" << num_gpus << ". ";
for (int32 device = 0; device < num_gpus; device++) {
cudaSetDevice(device);
cudaError_t e = cudaDeviceSynchronize(); // << CUDA context gets created here.
cudaError_t e = cudaFree(0); // CUDA context gets created here.
if (e == cudaSuccess) {
*debug_str = debug_stream.str();
if (debug_str)
*debug_str = debug_stream.str();
cudaGetLastError(); // Make sure the error state doesn't get returned in
// the next cudaGetLastError().
return true;
}
debug_stream << "Device " << device << ": " << cudaGetErrorString(e) << ". ";
cudaGetLastError(); // Make sure the error state doesn't get returned in
// the next cudaGetLastError().
}
*debug_str = debug_stream.str();
if (debug_str)
*debug_str = debug_stream.str();
return false;
}

Expand Down Expand Up @@ -164,15 +168,15 @@ void CuDevice::SelectGpuId(std::string use_gpu) {
} else {
int32 num_times = 0;
BaseFloat wait_time = 0.0;
while (! got_context) {
while (!got_context) {
int32 sec_sleep = 5;
if (num_times == 0)
KALDI_WARN << "Will try again indefinitely every " << sec_sleep
<< " seconds to get a GPU.";
num_times++;
wait_time += sec_sleep;
Sleep(sec_sleep);
got_context = GetCudaContext(num_gpus, &debug_str);
got_context = GetCudaContext(num_gpus, NULL);
}

KALDI_WARN << "Waited " << wait_time
Expand Down

0 comments on commit f403df9

Please sign in to comment.