@@ -1,4 +1,5 @@
#ifdef USE_CUDNN
+#include <algorithm>
#include <vector>

#include "caffe/filler.hpp"
@@ -24,13 +25,39 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
// Initialize CUDA streams and cuDNN.
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+ // Initialize algorithm arrays
+ fwd_algo_ = new cudnnConvolutionFwdAlgo_t[bottom.size()];
+ bwd_filter_algo_ = new cudnnConvolutionBwdFilterAlgo_t[bottom.size()];
+ bwd_data_algo_ = new cudnnConvolutionBwdDataAlgo_t[bottom.size()];
+
+ // initialize size arrays
+ workspace_fwd_sizes_ = new size_t[bottom.size()];
+ workspace_bwd_filter_sizes_ = new size_t[bottom.size()];
+ workspace_bwd_data_sizes_ = new size_t[bottom.size()];
+
+ // workspace data
workspaceSizeInBytes = 0;
- workspace = NULL;
+ workspaceData = NULL;
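+ // one pointer per group/stream, each aliasing into the shared workspaceData buffer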
+ workspace = new void*[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+ for (size_t i = 0; i < bottom.size(); ++i) {
+ // initialize all to default algorithms
+ fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0;
+ bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
+ bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0;
+ // default algorithms don't require workspace
+ workspace_fwd_sizes_[i] = 0;
+ workspace_bwd_data_sizes_[i] = 0;
+ workspace_bwd_filter_sizes_[i] = 0;
+ }

for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
CUDNN_CHECK(cudnnCreate(&handle_[g]));
CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
+ workspace[g] = NULL;
}

// Set the indexing parameters.
@@ -86,6 +113,11 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
const int stride_h = stride_data[0];
const int stride_w = stride_data[1];

+ // Specify workspace limit for kernels directly until we have a
+ // planning strategy and a rewrite of Caffe's GPU memory management
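+ // (cuDNN returns the fastest algorithm whose workspace fits this limit)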
+ size_t workspace_limit_bytes = 8*1024*1024;
+
for (int i = 0; i < bottom.size(); i++) {
cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
this->num_,
@@ -98,7 +130,108 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
this->num_output_ * this->out_spatial_dim_,
this->out_spatial_dim_, width_out, 1);
cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
- filter_desc_, pad_h, pad_w, stride_h, stride_w);
+ filter_desc_, pad_h, pad_w,
+ stride_h, stride_w);
+
+ // choose forward and backward algorithms + workspace(s)
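+ // (the Get* calls below are lightweight queries, so one handle suffices)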
+ CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[0],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes,
+ &fwd_algo_[i]));
+
+ CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[0],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ fwd_algo_[i],
+ &(workspace_fwd_sizes_[i])));
+
+ // choose backward algorithm for filter
+ CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle_[0],
+ bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
+ CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes, &bwd_filter_algo_[i]));
+
+ // get workspace for backwards filter algorithm
+ CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_[0],
+ bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
+ bwd_filter_algo_[i], &workspace_bwd_filter_sizes_[i]));
+
+ // choose backward algo for data
+ CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle_[0],
+ filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
+ CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes, &bwd_data_algo_[i]));
+
+ // get workspace size
+ CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_[0],
+ filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
+ bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]));
+ }
+
+ // reduce over all workspace sizes to get a maximum to allocate / reallocate
+ size_t total_workspace_fwd = 0;
+ size_t total_workspace_bwd_data = 0;
+ size_t total_workspace_bwd_filter = 0;
+
+ for (size_t i = 0; i < bottom.size(); i++) {
+ total_workspace_fwd = std::max(total_workspace_fwd,
+ workspace_fwd_sizes_[i]);
+ total_workspace_bwd_data = std::max(total_workspace_bwd_data,
+ workspace_bwd_data_sizes_[i]);
+ total_workspace_bwd_filter = std::max(total_workspace_bwd_filter,
+ workspace_bwd_filter_sizes_[i]);
+ }
+ // get max over all operations
+ size_t max_workspace = std::max(total_workspace_fwd,
+ total_workspace_bwd_data);
+ max_workspace = std::max(max_workspace, total_workspace_bwd_filter);
+ // ensure all groups have enough workspace
+ size_t total_max_workspace = max_workspace *
+ (this->group_ * CUDNN_STREAMS_PER_GROUP);
+
+ // this is the total amount of storage needed over all groups + streams
+ if (total_max_workspace > workspaceSizeInBytes) {
+ LOG(INFO) << "Reallocating workspace storage: " << total_max_workspace;
+ workspaceSizeInBytes = total_max_workspace;
+
+ // free the existing workspace and allocate a new (larger) one
+ cudaFree(this->workspaceData);
+
+ cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes);
+ if (err != cudaSuccess) {
+ // force zero memory path
+ for (size_t i = 0; i < bottom.size(); i++) {
+ workspace_fwd_sizes_[i] = 0;
+ workspace_bwd_filter_sizes_[i] = 0;
+ workspace_bwd_data_sizes_[i] = 0;
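+ // IMPLICIT_GEMM and the ALGO_0 variants run with zero scratch space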
+ fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ bwd_filter_algo_[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+ bwd_data_algo_[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+ }
+
+ // NULL out all workspace pointers
+ for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+ workspace[g] = NULL;
+ }
+ // NULL out underlying data
+ workspaceData = NULL;
+ workspaceSizeInBytes = 0;
+ max_workspace = 0;  // keeps the alias loop below from offsetting a NULL base
+ }
+
+ // if we succeed in the allocation, set pointer aliases for workspaces
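+ // (each group/stream gets its own max_workspace-byte slice of workspaceData)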
+ for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+ workspace[g] = reinterpret_cast<char *>(workspaceData) + g*max_workspace;
+ }
}

// Tensor descriptor for bias.
@@ -128,8 +261,15 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
cudnnDestroy(handle_[g]);
}

+ cudaFree(workspaceData);
delete [] stream_;
delete [] handle_;
+ delete [] fwd_algo_;
+ delete [] bwd_filter_algo_;
+ delete [] bwd_data_algo_;
+ delete [] workspace_fwd_sizes_;
+ delete [] workspace_bwd_data_sizes_;
+ delete [] workspace_bwd_filter_sizes_;
}

INSTANTIATE_CLASS(CuDNNConvolutionLayer);