From 45bde08ef7601215b61616af47d97906c6824526 Mon Sep 17 00:00:00 2001 From: Jun Shi Date: Fri, 22 Jan 2016 09:50:31 -0800 Subject: [PATCH] add check and grab GPU utilities --- include/caffe/common.hpp | 4 ++++ src/caffe/common.cpp | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 1df6b9a14fb..1e5b52905f2 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -149,6 +149,10 @@ class Caffe { static void SetDevice(const int device_id); // Prints the current GPU status. static void DeviceQuery(); + // Check if specified device is available + static bool CheckDevice(const int device_id); + // Get the first available device id since start_id + static int GrabDevice(const int start_id = 0); // Parallel training info inline static int solver_count() { return Get().solver_count_; } inline static void set_solver_count(int val) { Get().solver_count_ = val; } diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 299d67d4bec..93854e31a3f 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -70,6 +70,15 @@ void Caffe::DeviceQuery() { NO_GPU; } +bool Caffe::CheckDevice(const int device_id) { + NO_GPU; + return false; +} + +int Caffe::GrabDevice(const int start_id) { + NO_GPU; + return -1; +} class Caffe::RNG::Generator { public: @@ -192,6 +201,26 @@ void Caffe::DeviceQuery() { return; } +bool Caffe::CheckDevice(const int device_id) { + bool r = ((cudaSuccess == cudaSetDevice(device_id)) && + (cudaSuccess == cudaFree(0))); + cudaGetLastError(); + return r; +} + +int Caffe::GrabDevice(const int start_id) { + int count = 0; + int r = -1; + if (cudaSuccess == cudaGetDeviceCount(&count)) { + for (int i = start_id; i < count; i++) { + if (CheckDevice(i)) { + r = i; + break; + } + } // for + } + return r; +} class Caffe::RNG::Generator { public: