Basic cuDNN v3 support (update) #3160

Merged
merged 1 commit into from Oct 16, 2015
Jump to file or symbol
Failed to load files and symbols.
+692 −56
Split
@@ -304,13 +304,24 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;
+
+ // algorithms for forward and backwards convolutions
+ cudnnConvolutionFwdAlgo_t *fwd_algo_;
+ cudnnConvolutionBwdFilterAlgo_t *bwd_filter_algo_;
+ cudnnConvolutionBwdDataAlgo_t *bwd_data_algo_;
+
vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, bias_offset_;
- size_t workspaceSizeInBytes;
- void *workspace;
+
+ size_t *workspace_fwd_sizes_;
+ size_t *workspace_bwd_data_sizes_;
+ size_t *workspace_bwd_filter_sizes_;
+ size_t workspaceSizeInBytes; // size of underlying storage
+ void *workspaceData; // underlying storage
+ void **workspace; // aliases into workspaceData
};
#endif
@@ -442,6 +453,65 @@ class LRNLayer : public Layer<Dtype> {
vector<Blob<Dtype>*> product_bottom_vec_;
};
+#ifdef USE_CUDNN
+
+/**
+ * @brief cuDNN implementation of cross-channel LRN.
+ *
+ * Selected by the layer factory when engine == CUDNN and the norm region
+ * is ACROSS_CHANNELS (see GetLRNLayer); otherwise LRNLayer is used.
+ */
+template <typename Dtype>
+class CuDNNLRNLayer : public LRNLayer<Dtype> {
+ public:
+ explicit CuDNNLRNLayer(const LayerParameter& param)
+ : LRNLayer<Dtype>(param), handles_setup_(false) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual ~CuDNNLRNLayer();
+
+ protected:
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ // False until the cuDNN handle and descriptors below are created;
+ // presumably set in LayerSetUp and checked by the destructor before
+ // tearing them down — confirm against the .cu implementation.
+ bool handles_setup_;
+ cudnnHandle_t handle_;
+ cudnnLRNDescriptor_t norm_desc_;
+ cudnnTensorDescriptor_t bottom_desc_, top_desc_;
+
+ // LRN parameters cached from LRNParameter: window size and the
+ // alpha/beta/k coefficients of the normalization formula.
+ int size_;
+ Dtype alpha_, beta_, k_;
+};
+
+/**
+ * @brief cuDNN implementation of within-channel LRN (local contrast
+ *        normalization).
+ *
+ * Selected by the layer factory when engine == CUDNN and the norm region
+ * is WITHIN_CHANNEL (see GetLRNLayer).
+ */
+template <typename Dtype>
+class CuDNNLCNLayer : public LRNLayer<Dtype> {
+ public:
+ explicit CuDNNLCNLayer(const LayerParameter& param)
+ : LRNLayer<Dtype>(param), handles_setup_(false), tempDataSize(0),
+ tempData1(NULL), tempData2(NULL) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual ~CuDNNLCNLayer();
+
+ protected:
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ // False until the cuDNN handle and descriptors below are created;
+ // presumably set in LayerSetUp and checked by the destructor before
+ // tearing them down — confirm against the .cu implementation.
+ bool handles_setup_;
+ cudnnHandle_t handle_;
+ cudnnLRNDescriptor_t norm_desc_;
+ cudnnTensorDescriptor_t bottom_desc_, top_desc_;
+
+ // LRN parameters cached from LRNParameter; pre_pad_ is presumably the
+ // leading padding derived from the window size — confirm in the .cu file.
+ int size_, pre_pad_;
+ Dtype alpha_, beta_, k_;
+
+ // Scratch device buffers (and their common byte size) used by the
+ // divisive-normalization calls; zero/NULL until first Reshape.
+ size_t tempDataSize;
+ void *tempData1, *tempData2;
+};
+
+#endif
/**
* @brief Pools the input image by taking the max, average, etc. within regions.
@@ -54,10 +54,8 @@ shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
return shared_ptr<Layer<Dtype> >(new PoolingLayer<Dtype>(param));
#ifdef USE_CUDNN
} else if (engine == PoolingParameter_Engine_CUDNN) {
- PoolingParameter p_param = param.pooling_param();
- if (p_param.pad() || p_param.pad_h() || p_param.pad_w() ||
- param.top_size() > 1) {
- LOG(INFO) << "CUDNN does not support padding or multiple tops. "
+ if (param.top_size() > 1) {
+ LOG(INFO) << "cuDNN does not support multiple tops. "
<< "Using Caffe's own pooling layer.";
return shared_ptr<Layer<Dtype> >(new PoolingLayer<Dtype>(param));
}
@@ -70,6 +68,43 @@ shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
REGISTER_LAYER_CREATOR(Pooling, GetPoolingLayer);
+// Get LRN layer according to engine.
+// Returns a cuDNN-backed layer when possible, falling back to Caffe's own
+// LRNLayer when cuDNN is unavailable or cannot handle the configuration.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetLRNLayer(const LayerParameter& param) {
+ LRNParameter_Engine engine = param.lrn_param().engine();
+
+ // With no explicit engine requested, prefer cuDNN when compiled in.
+ if (engine == LRNParameter_Engine_DEFAULT) {
+#ifdef USE_CUDNN
+ engine = LRNParameter_Engine_CUDNN;
+#else
+ engine = LRNParameter_Engine_CAFFE;
+#endif
+ }
+
+ if (engine == LRNParameter_Engine_CAFFE) {
+ return shared_ptr<Layer<Dtype> >(new LRNLayer<Dtype>(param));
+#ifdef USE_CUDNN
+ } else if (engine == LRNParameter_Engine_CUDNN) {
+ const LRNParameter& lrn_param = param.lrn_param();
+
+ if (lrn_param.norm_region() == LRNParameter_NormRegion_WITHIN_CHANNEL) {
+ // Within-channel normalization maps to cuDNN's LCN path.
+ return shared_ptr<Layer<Dtype> >(new CuDNNLCNLayer<Dtype>(param));
+ } else {
+ // local size is too big to be handled through cuDNN
+ if (lrn_param.local_size() > CUDNN_LRN_MAX_N) {
+ return shared_ptr<Layer<Dtype> >(new LRNLayer<Dtype>(param));
+ } else {
+ return shared_ptr<Layer<Dtype> >(new CuDNNLRNLayer<Dtype>(param));
+ }
+ }
+#endif
+ } else {
+ // LOG(FATAL) aborts, so no return is needed on this path.
+ LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
+ }
+}
+
+REGISTER_LAYER_CREATOR(LRN, GetLRNLayer);
+
// Get relu layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetReLULayer(const LayerParameter& param) {
@@ -1,4 +1,5 @@
#ifdef USE_CUDNN
+#include <algorithm>
#include <vector>
#include "caffe/filler.hpp"
@@ -24,13 +25,38 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
// Initialize CUDA streams and cuDNN.
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+ // Initialize algorithm arrays
+ fwd_algo_ = new cudnnConvolutionFwdAlgo_t[bottom.size()];
+ bwd_filter_algo_= new cudnnConvolutionBwdFilterAlgo_t[bottom.size()];
+ bwd_data_algo_ = new cudnnConvolutionBwdDataAlgo_t[bottom.size()];
+
+ // initialize size arrays
+ workspace_fwd_sizes_ = new size_t[bottom.size()];
+ workspace_bwd_filter_sizes_ = new size_t[bottom.size()];
+ workspace_bwd_data_sizes_ = new size_t[bottom.size()];
+
+ // workspace data
workspaceSizeInBytes = 0;
- workspace = NULL;
+ workspaceData = NULL;
+ workspace = new void*[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+ for (size_t i = 0; i < bottom.size(); ++i) {
+ // initialize all to default algorithms
+ fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0;
+ bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
+ bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0;
+ // default algorithms don't require workspace
+ workspace_fwd_sizes_[i] = 0;
+ workspace_bwd_data_sizes_[i] = 0;
+ workspace_bwd_filter_sizes_[i] = 0;
+ }
for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
CUDNN_CHECK(cudnnCreate(&handle_[g]));
CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
+ workspace[g] = NULL;
}
// Set the indexing parameters.
@@ -86,6 +112,10 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
const int stride_h = stride_data[0];
const int stride_w = stride_data[1];
+ // Specify workspace limit for kernels directly until we have a
+ // planning strategy and a rewrite of Caffe's GPU memory management
+ size_t workspace_limit_bytes = 8*1024*1024;
+
for (int i = 0; i < bottom.size(); i++) {
cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
this->num_,
@@ -98,7 +128,104 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
this->num_output_ * this->out_spatial_dim_,
this->out_spatial_dim_, width_out, 1);
cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
- filter_desc_, pad_h, pad_w, stride_h, stride_w);
+ filter_desc_, pad_h, pad_w,
+ stride_h, stride_w);
+
+ // choose forward and backward algorithms + workspace(s)
+ CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[0],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes,
+ &fwd_algo_[i]));
+
+ CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[0],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ fwd_algo_[i],
+ &(workspace_fwd_sizes_[i])));
+
+ // choose backward algorithm for filter
+ CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle_[0],
+ bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
+ CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes, &bwd_filter_algo_[i]) );
+
+ // get workspace for backwards filter algorithm
+ CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_[0],
+ bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
+ bwd_filter_algo_[i], &workspace_bwd_filter_sizes_[i]));
+
+ // choose backward algo for data
+ CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle_[0],
+ filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
+ CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes, &bwd_data_algo_[i]));
+
+ // get workspace size
+ CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_[0],
+ filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
+ bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]) );
+ }
+
+ // reduce over all workspace sizes to get a maximum to allocate / reallocate
+ size_t total_workspace_fwd = 0;
+ size_t total_workspace_bwd_data = 0;
+ size_t total_workspace_bwd_filter = 0;
+
+ for (size_t i = 0; i < bottom.size(); i++) {
+ total_workspace_fwd = std::max(total_workspace_fwd,
+ workspace_fwd_sizes_[i]);
+ total_workspace_bwd_data = std::max(total_workspace_bwd_data,
+ workspace_bwd_data_sizes_[i]);
+ total_workspace_bwd_filter = std::max(total_workspace_bwd_filter,
+ workspace_bwd_filter_sizes_[i]);
+ }
+ // get max over all operations
+ size_t max_workspace = std::max(total_workspace_fwd,
+ total_workspace_bwd_data);
+ max_workspace = std::max(max_workspace, total_workspace_bwd_filter);
+ // ensure all groups have enough workspace
+ size_t total_max_workspace = max_workspace *
+ (this->group_ * CUDNN_STREAMS_PER_GROUP);
+
+ // this is the total amount of storage needed over all groups + streams
+ if (total_max_workspace > workspaceSizeInBytes) {
+ LOG(INFO) << "Reallocating workspace storage: " << total_max_workspace;
+ workspaceSizeInBytes = total_max_workspace;
+
+ // free the existing workspace and allocate a new (larger) one
+ cudaFree(this->workspaceData);
+
+ cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes);
+ if (err != cudaSuccess) {
+ // force zero memory path
+ for (int i = 0; i < bottom.size(); i++) {
+ workspace_fwd_sizes_[i] = 0;
+ workspace_bwd_filter_sizes_[i] = 0;
+ workspace_bwd_data_sizes_[i] = 0;
+ fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ bwd_filter_algo_[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+ bwd_data_algo_[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+ }
+
+ // NULL out all workspace pointers
+ for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+ workspace[g] = NULL;
+ }
+ // NULL out underlying data
+ workspaceData = NULL;
+ workspaceSizeInBytes = 0;
+ }
+
+ // if we succeed in the allocation, set pointer aliases for workspaces
+ for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+ workspace[g] = reinterpret_cast<char *>(workspaceData) + g*max_workspace;
+ }
}
// Tensor descriptor for bias.
@@ -128,8 +255,15 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
cudnnDestroy(handle_[g]);
}
+ cudaFree(workspaceData);
delete [] stream_;
delete [] handle_;
+ delete [] fwd_algo_;
+ delete [] bwd_filter_algo_;
+ delete [] bwd_data_algo_;
+ delete [] workspace_fwd_sizes_;
+ delete [] workspace_bwd_data_sizes_;
+ delete [] workspace_bwd_filter_sizes_;
}
INSTANTIATE_CLASS(CuDNNConvolutionLayer);
Oops, something went wrong.