From 519d82593f1dfba812b8fd7ff95f637f9f570fe4 Mon Sep 17 00:00:00 2001 From: Fisher Yu Date: Sun, 6 Dec 2015 20:04:43 -0500 Subject: [PATCH 1/4] add support for 2D dilated convolution --- include/caffe/layers/base_conv_layer.hpp | 14 ++-- include/caffe/layers/conv_layer.hpp | 3 + include/caffe/layers/im2col_layer.hpp | 2 + include/caffe/util/im2col.hpp | 12 ++-- src/caffe/layer_factory.cpp | 17 ++++- src/caffe/layers/base_conv_layer.cpp | 14 ++++ src/caffe/layers/conv_layer.cpp | 4 +- src/caffe/layers/im2col_layer.cpp | 21 +++++- src/caffe/layers/im2col_layer.cu | 2 + src/caffe/proto/caffe.proto | 1 + src/caffe/test/test_convolution_layer.cpp | 24 +++++-- src/caffe/test/test_im2col_kernel.cu | 17 ++++- src/caffe/test/test_im2col_layer.cpp | 3 +- src/caffe/util/im2col.cpp | 34 ++++++---- src/caffe/util/im2col.cu | 80 ++++++++++++----------- 15 files changed, 177 insertions(+), 71 deletions(-) diff --git a/include/caffe/layers/base_conv_layer.hpp b/include/caffe/layers/base_conv_layer.hpp index f3def16c039..db471b586da 100644 --- a/include/caffe/layers/base_conv_layer.hpp +++ b/include/caffe/layers/base_conv_layer.hpp @@ -68,6 +68,8 @@ class BaseConvolutionLayer : public Layer { Blob stride_; /// @brief The spatial dimensions of the padding. Blob pad_; + /// @brief The spatial dimensions of the dilation. + Blob dilation_; /// @brief The spatial dimensions of the convolution input. Blob conv_input_shape_; /// @brief The spatial dimensions of the col_buffer. @@ -99,7 +101,8 @@ class BaseConvolutionLayer : public Layer { conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], - stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff); + stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], col_buff); } else { im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(), col_buffer_shape_.data(), kernel_shape_.cpu_data(), @@ -112,7 +115,8 @@ class BaseConvolutionLayer : public Layer { conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], - stride_.cpu_data()[0], stride_.cpu_data()[1], data); + stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], data); } else { col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(), col_buffer_shape_.data(), kernel_shape_.cpu_data(), @@ -126,7 +130,8 @@ class BaseConvolutionLayer : public Layer { conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], - stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff); + stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], col_buff); } else { im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_, conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), @@ -140,7 +145,8 @@ class BaseConvolutionLayer : public Layer { conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], - stride_.cpu_data()[0], stride_.cpu_data()[1], data); + stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], data); } else { col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_, conv_input_shape_.gpu_data(), 
col_buffer_.gpu_shape(), diff --git a/include/caffe/layers/conv_layer.hpp b/include/caffe/layers/conv_layer.hpp index 15574766de5..93a618ddd72 100644 --- a/include/caffe/layers/conv_layer.hpp +++ b/include/caffe/layers/conv_layer.hpp @@ -44,6 +44,9 @@ class ConvolutionLayer : public BaseConvolutionLayer { * convolution, given by pad for equal dimensions or pad_h and pad_w for * different padding. Input padding is computed implicitly instead of * actually padding. + * - dilation (\b optional, default 1). The filter + * dilation, given by dilation_size for equal dimensions for different + * dilation. By default the convolution has dilation 1. * - group (\b optional, default 1). The number of filter groups. Group * convolution is a method for reducing parameterization by selectively * connecting input and output channels. The input and output channel dimensions must be divisible diff --git a/include/caffe/layers/im2col_layer.hpp b/include/caffe/layers/im2col_layer.hpp index 1d3b2eb67d1..71e32f7427f 100644 --- a/include/caffe/layers/im2col_layer.hpp +++ b/include/caffe/layers/im2col_layer.hpp @@ -46,6 +46,8 @@ class Im2colLayer : public Layer { Blob stride_; /// @brief The spatial dimensions of the padding. Blob pad_; + /// @brief The spatial dimensions of the dilation. + Blob dilation_; int num_spatial_axes_; int bottom_dim_; diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp index d3eb6ccd6fc..748b65c4f36 100644 --- a/include/caffe/util/im2col.hpp +++ b/include/caffe/util/im2col.hpp @@ -13,7 +13,8 @@ template void im2col_cpu(const Dtype* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_col); + const int stride_w, const int dilation_h, const int dilation_w, + Dtype* data_col); template void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, @@ -25,7 +26,8 @@ template void col2im_cpu(const Dtype* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_im); + const int stride_w, const int dilation_h, const int dilation_w, + Dtype* data_im); template void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes, @@ -37,7 +39,8 @@ template void im2col_gpu(const Dtype* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_col); + const int stride_w, const int dilation_h, const int dilation_w, + Dtype* data_col); template void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes, @@ -49,7 +52,8 @@ template void col2im_gpu(const Dtype* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_im); + const int stride_w, const int dilation_h, const int dilation_w, + Dtype* data_im); } // namespace caffe diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index 76d851af9a2..6b1d1c1a5f5 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -37,17 +37,30 @@ namespace caffe { template shared_ptr > GetConvolutionLayer( const LayerParameter& param) { - ConvolutionParameter_Engine engine = param.convolution_param().engine(); + ConvolutionParameter conv_param = 
param.convolution_param(); + ConvolutionParameter_Engine engine = conv_param.engine(); + bool use_dilation = false; + for (int i = 0; i < conv_param.dilation_size(); ++i) { + if (conv_param.dilation(i) > 1) { + use_dilation = true; + } + } if (engine == ConvolutionParameter_Engine_DEFAULT) { engine = ConvolutionParameter_Engine_CAFFE; #ifdef USE_CUDNN - engine = ConvolutionParameter_Engine_CUDNN; + if (!use_dilation) { + engine = ConvolutionParameter_Engine_CUDNN; + } #endif } if (engine == ConvolutionParameter_Engine_CAFFE) { return shared_ptr >(new ConvolutionLayer(param)); #ifdef USE_CUDNN } else if (engine == ConvolutionParameter_Engine_CUDNN) { + if (use_dilation) { + LOG(FATAL) << "CuDNN doesn't support the dilated convolution at Layer " + << param.name(); + } return shared_ptr >(new CuDNNConvolutionLayer(param)); #endif } else { diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index f6f14cd0f17..0e8f5c98f50 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -92,6 +92,20 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, conv_param.pad((num_pad_dims == 1) ? 0 : i); } } + // Setup dilation dimensions (dilation_). + dilation_.Reshape(spatial_dim_blob_shape); + int* dilation_data = dilation_.mutable_cpu_data(); + const int num_dilation_dims = conv_param.dilation_size(); + CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 || + num_dilation_dims == num_spatial_axes_) + << "dilation must be specified once, or once per spatial dimension " + << "(dilation specified " << num_dilation_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultDilation = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation : + conv_param.dilation((num_dilation_dims == 1) ? 0 : i); + } // Special case: im2col is the identity for 1x1 convolution with stride 1 // and no padding, so flag for skipping the buffer and transformation. is_1x1_ = true; diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index cff09783945..b4aca54ee23 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -9,11 +9,13 @@ void ConvolutionLayer::compute_output_shape() { const int* kernel_shape_data = this->kernel_shape_.cpu_data(); const int* stride_data = this->stride_.cpu_data(); const int* pad_data = this->pad_.cpu_data(); + const int* dilation_data = this->dilation_.cpu_data(); this->output_shape_.clear(); for (int i = 0; i < this->num_spatial_axes_; ++i) { // i + 1 to skip channel axis const int input_dim = this->input_shape(i + 1); - const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i]) + int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; + const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent) / stride_data[i] + 1; this->output_shape_.push_back(output_dim); } diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index c12e4f52a10..2b6fc9c2867 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -87,6 +87,20 @@ void Im2colLayer::LayerSetUp(const vector*>& bottom, conv_param.pad((num_pad_dims == 1) ? 0 : i); } } + // Setup dilation dimensions (dilation_). 
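The output-shape changes in conv_layer.cpp and im2col_layer.cpp both rest on the effective kernel extent, dilation * (kernel - 1) + 1. A minimal standalone sketch of that arithmetic (illustrative only, not part of the patch; the helper name and sizes are made up):

#include <cassert>

// Output size of a dilated convolution along one spatial axis. A kernel of
// size `kernel` with dilation `dilation` spans dilation * (kernel - 1) + 1
// input pixels, so the usual output-size formula applies with that extent
// in place of the raw kernel size.
int dilated_output_dim(int input_dim, int kernel, int pad, int stride,
                       int dilation) {
  const int kernel_extent = dilation * (kernel - 1) + 1;
  return (input_dim + 2 * pad - kernel_extent) / stride + 1;
}

int main() {
  // 10-pixel input, kernel 3, dilation 2: extent 5, so 6 output positions.
  assert(dilated_output_dim(10, 3, 0, 1, 2) == 6);
  // Dilation 1 reduces to the standard formula: 8 output positions.
  assert(dilated_output_dim(10, 3, 0, 1, 1) == 8);
  return 0;
}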
+ dilation_.Reshape(dim_blob_shape); + int* dilation_data = dilation_.mutable_cpu_data(); + const int num_dilation_dims = conv_param.dilation_size(); + CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 || + num_dilation_dims == num_spatial_axes_) + << "dilation must be specified once, or once per spatial dimension " + << "(dilation specified " << num_dilation_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultDilation = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation : + conv_param.dilation((num_dilation_dims == 1) ? 0 : i); + } } template @@ -96,10 +110,12 @@ void Im2colLayer::Reshape(const vector*>& bottom, const int* kernel_shape_data = kernel_shape_.cpu_data(); const int* stride_data = stride_.cpu_data(); const int* pad_data = pad_.cpu_data(); + const int* dilation_data = dilation_.cpu_data(); for (int i = 0; i < num_spatial_axes_; ++i) { top_shape[channel_axis_] *= kernel_shape_data[i]; const int input_dim = bottom[0]->shape(channel_axis_ + i + 1); - const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i]) + int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; + const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent) / stride_data[i] + 1; top_shape[channel_axis_ + i + 1] = output_dim; } @@ -122,6 +138,7 @@ void Im2colLayer::Forward_cpu(const vector*>& bottom, DCHECK_EQ(kernel_shape_.count(), num_spatial_axes_); DCHECK_EQ(pad_.count(), num_spatial_axes_); DCHECK_EQ(stride_.count(), num_spatial_axes_); + DCHECK_EQ(dilation_.count(), num_spatial_axes_); if (!force_nd_im2col_ && num_spatial_axes_ == 2) { im2col_cpu(bottom_data + n * bottom_dim_, channels_, bottom[0]->shape(channel_axis_ + 1), @@ -129,6 +146,7 @@ void Im2colLayer::Forward_cpu(const vector*>& bottom, kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], top_data + n * top_dim_); } else { im2col_nd_cpu(bottom_data + n * bottom_dim_, num_spatial_axes_, @@ -153,6 +171,7 @@ void Im2colLayer::Backward_cpu(const vector*>& top, kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], bottom_diff + n * bottom_dim_); } else { col2im_nd_cpu(top_diff + n * top_dim_, num_spatial_axes_, diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu index 517b4220cb9..d90075d4304 100644 --- a/src/caffe/layers/im2col_layer.cu +++ b/src/caffe/layers/im2col_layer.cu @@ -19,6 +19,7 @@ void Im2colLayer::Forward_gpu(const vector*>& bottom, kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], top_data + n * top_dim_); } else { im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_, @@ -43,6 +44,7 @@ void Im2colLayer::Backward_gpu(const vector*>& top, kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], pad_.cpu_data()[0], pad_.cpu_data()[1], stride_.cpu_data()[0], stride_.cpu_data()[1], + dilation_.cpu_data()[0], dilation_.cpu_data()[1], bottom_diff + n * bottom_dim_); } else { col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_, diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 
787369f7cff..87c46629baf 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -518,6 +518,7 @@ message ConvolutionParameter { repeated uint32 pad = 3; // The padding size; defaults to 0 repeated uint32 kernel_size = 4; // The kernel size repeated uint32 stride = 6; // The stride; defaults to 1 + repeated uint32 dilation = 18; // The dilation; defaults to 1 // For 2D convolution only, the *_h and *_w versions may also be used to // specify both spatial dimensions. diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index e2d43f31b6a..373aa7a7398 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -46,13 +46,17 @@ void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, } else { stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1; } - int kernel_d, pad_d, stride_d; + int dilation_h, dilation_w; + dilation_h = dilation_w = conv_param->dilation_size() ? + conv_param->dilation(0) : 1; + int kernel_d, pad_d, stride_d, dilation_d; if (has_depth) { kernel_d = kernel_h; stride_d = stride_h; pad_d = pad_h; + dilation_d = dilation_h; } else { - kernel_d = stride_d = 1; + kernel_d = stride_d = dilation_d = 1; pad_d = 0; } // Groups @@ -77,9 +81,9 @@ void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, for (int r = 0; r < kernel_d; r++) { for (int p = 0; p < kernel_h; p++) { for (int q = 0; q < kernel_w; q++) { - int in_z = z * stride_d - pad_d + r; - int in_y = y * stride_h - pad_h + p; - int in_x = x * stride_w - pad_w + q; + int in_z = z * stride_d - pad_d + r * dilation_d; + int in_y = y * stride_h - pad_h + p * dilation_h; + int in_x = x * stride_w - pad_w + q * dilation_w; if (in_z >= 0 && in_z < (has_depth ? 
in->shape(2) : 1) && in_y >= 0 && in_y < in->shape(2 + has_depth) && in_x >= 0 && in_x < in->shape(3 + has_depth)) { @@ -195,6 +199,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSetup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -233,6 +238,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -319,6 +325,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -352,6 +359,7 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(1); convolution_param->add_stride(1); + convolution_param->add_dilation(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -379,6 +387,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -418,6 +427,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(1); convolution_param->set_bias_term(false); shared_ptr > layer( @@ -620,6 +630,7 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient) { this->blob_top_vec_.push_back(this->blob_top_2_); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -648,6 +659,7 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient3D) { } convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -666,6 +678,7 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Gradient) { this->blob_top_vec_.push_back(this->blob_top_2_); convolution_param->add_kernel_size(1); convolution_param->add_stride(1); + convolution_param->add_dilation(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -682,6 +695,7 @@ 
TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu index 3f97cf6d5ae..15e06aa8583 100644 --- a/src/caffe/test/test_im2col_kernel.cu +++ b/src/caffe/test/test_im2col_kernel.cu @@ -18,6 +18,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int height_col, const int width_col, Dtype* data_col); @@ -38,6 +39,7 @@ class Im2colKernelTest : public GPUDeviceTest { blob_kernel_shape_(new Blob()), blob_stride_(new Blob()), blob_pad_(new Blob()), + blob_dilation_(new Blob()), blob_top_(new Blob()), blob_top_cpu_(new Blob()) { FillerParameter filler_param; @@ -47,20 +49,25 @@ class Im2colKernelTest : public GPUDeviceTest { blob_kernel_shape_->Reshape(dim_blob_shape); blob_stride_->Reshape(dim_blob_shape); blob_pad_->Reshape(dim_blob_shape); + blob_dilation_->Reshape(dim_blob_shape); height_ = blob_bottom_->height(); width_ = blob_bottom_->width(); channels_ = blob_bottom_->channels(); pad_ = 0; stride_ = 2; + dilation_ = 1; kernel_size_ = 3; - height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1; - width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1; + height_col_ = (height_ + 2 * pad_ - + (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1; + width_col_ = (width_ + 2 * pad_ - + (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1; for (int i = 0; i < 2; ++i) { blob_kernel_shape_->mutable_cpu_data()[i] = kernel_size_; blob_stride_->mutable_cpu_data()[i] = stride_; blob_pad_->mutable_cpu_data()[i] = pad_; + blob_dilation_->mutable_cpu_data()[i] = dilation_; } } @@ -71,11 +78,13 @@ class Im2colKernelTest : public GPUDeviceTest { delete blob_kernel_shape_; delete blob_stride_; delete blob_pad_; + delete blob_dilation_; } Blob* const blob_kernel_shape_; Blob* const blob_stride_; Blob* const blob_pad_; + Blob* const blob_dilation_; Blob* const blob_bottom_; Blob* const blob_top_; Blob* const blob_top_cpu_; @@ -84,6 +93,7 @@ class Im2colKernelTest : public GPUDeviceTest { int channels_; int pad_; int stride_; + int dilation_; int kernel_size_; int height_col_; int width_col_; @@ -112,7 +122,7 @@ TYPED_TEST(Im2colKernelTest, Test2D) { im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n), this->channels_, this->height_, this->width_, this->kernel_size_, this->kernel_size_, this->pad_, this->pad_, - this->stride_, this->stride_, + this->stride_, this->stride_, this->dilation_, this->dilation_, cpu_data + this->blob_top_cpu_->offset(n)); } @@ -129,6 +139,7 @@ TYPED_TEST(Im2colKernelTest, Test2D) { num_kernels, bottom_data + this->blob_bottom_->offset(n), this->height_, this->width_, this->kernel_size_, this->kernel_size_, this->pad_, this->pad_, this->stride_, this->stride_, + this->dilation_, this->dilation_, this->height_col_, this->width_col_, top_data + this->blob_top_->offset(n)); CUDA_POST_KERNEL_CHECK; diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp index 8274dd48971..932d3f21ae9 100644 --- a/src/caffe/test/test_im2col_layer.cpp 
+++ b/src/caffe/test/test_im2col_layer.cpp @@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; protected: Im2colLayerTest() - : blob_bottom_(new Blob(2, 3, 6, 5)), + : blob_bottom_(new Blob(2, 3, 10, 9)), blob_top_(new Blob()) { // fill the values Caffe::set_random_seed(1701); @@ -75,6 +75,7 @@ TYPED_TEST(Im2colLayerTest, TestGradient) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(3); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index 27e5b7c0928..373207382b6 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -10,9 +10,12 @@ void im2col_cpu(const Dtype* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, Dtype* data_col) { - const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; - const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; + const int height_col = (height + 2 * pad_h - + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_col = (width + 2 * pad_w - + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; const int channels_col = channels * kernel_h * kernel_w; for (int c_col = 0; c_col < channels_col; ++c_col) { int w_offset = c_col % kernel_w; @@ -20,8 +23,8 @@ void im2col_cpu(const Dtype* data_im, const int channels, int c_im = c_col / kernel_h / kernel_w; for (int h_col = 0; h_col < height_col; ++h_col) { for (int w_col = 0; w_col < width_col; ++w_col) { - int h_im = h_col * stride_h - pad_h + h_offset; - int w_im = w_col * stride_w - pad_w + w_offset; + int h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + int w_im = w_col * stride_w - pad_w + w_offset * dilation_w; data_col[(c_col * height_col + h_col) * width_col + w_col] = (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ? 
data_im[(c_im * height + h_im) * width + w_im] : 0; @@ -34,11 +37,13 @@ void im2col_cpu(const Dtype* data_im, const int channels, template void im2col_cpu(const float* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, float* data_col); + const int stride_w, const int dilation_h, const int dilation_w, + float* data_col); template void im2col_cpu(const double* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, double* data_col); + const int stride_w, const int dilation_h, const int dilation_w, + double* data_col); template inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col, @@ -137,10 +142,13 @@ void col2im_cpu(const Dtype* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, Dtype* data_im) { caffe_set(height * width * channels, Dtype(0), data_im); - const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; - const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; + const int height_col = (height + 2 * pad_h - + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_col = (width + 2 * pad_w - + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; const int channels_col = channels * kernel_h * kernel_w; for (int c_col = 0; c_col < channels_col; ++c_col) { int w_offset = c_col % kernel_w; @@ -148,8 +156,8 @@ void col2im_cpu(const Dtype* data_col, const int channels, int c_im = c_col / kernel_h / kernel_w; for (int h_col = 0; h_col < height_col; ++h_col) { for (int w_col = 0; w_col < width_col; ++w_col) { - int h_im = h_col * stride_h - pad_h + h_offset; - int w_im = w_col * stride_w - pad_w + w_offset; + int h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + int w_im = w_col * stride_w - pad_w + w_offset * dilation_w; if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) data_im[(c_im * height + h_im) * width + w_im] += data_col[(c_col * height_col + h_col) * width_col + w_col]; @@ -162,11 +170,13 @@ void col2im_cpu(const Dtype* data_col, const int channels, template void col2im_cpu(const float* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, float* data_im); + const int stride_w, const int dilation_h, const int dilation_w, + float* data_im); template void col2im_cpu(const double* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, double* data_im); + const int stride_w, const int dilation_h, const int dilation_w, + double* data_im); template void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index 49354ab7aa1..af32cbd13e3 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -10,6 +10,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int 
dilation_w, const int height_col, const int width_col, Dtype* data_col) { CUDA_KERNEL_LOOP(index, n) { @@ -26,11 +27,11 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, data_im_ptr += (c_im * height + h_offset) * width + w_offset; for (int i = 0; i < kernel_h; ++i) { for (int j = 0; j < kernel_w; ++j) { - int h_im = h_offset + i; - int w_im = w_offset + j; + int h_im = h_offset + i * dilation_h; + int w_im = w_offset + j * dilation_w; *data_col_ptr = (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ? - data_im_ptr[i * width + j] : 0; + data_im_ptr[i * dilation_h * width + j * dilation_w] : 0; data_col_ptr += height_col * width_col; } } @@ -42,17 +43,20 @@ void im2col_gpu(const Dtype* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, Dtype* data_col) { // We are going to launch channels * height_col * width_col kernels, each // kernel responsible for copying a single-channel grid. - int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; - int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; + int height_col = (height + 2 * pad_h - + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; int num_kernels = channels * height_col * width_col; // NOLINT_NEXT_LINE(whitespace/operators) im2col_gpu_kernel<<>>( num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h, - pad_w, stride_h, stride_w, height_col, + pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col, width_col, data_col); CUDA_POST_KERNEL_CHECK; } @@ -61,11 +65,11 @@ void im2col_gpu(const Dtype* data_im, const int channels, template void im2col_gpu(const float* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, - float* data_col); + const int dilation_h, const int dilation_w, float* data_col); template void im2col_gpu(const double* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, - double* data_col); + const int dilation_h, const int dilation_w, double* data_col); template __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, @@ -223,6 +227,7 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int height_col, const int width_col, Dtype* data_im) { CUDA_KERNEL_LOOP(index, n) { @@ -230,33 +235,27 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, const int w_im = index % width + pad_w; const int h_im = (index / width) % height + pad_h; const int c_im = index / (width * height); + int kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + int kernel_extent_h = (kernel_h - 1) * dilation_h + 1; // compute the start and end of the output const int w_col_start = - (w_im < kernel_w) ? 0 : (w_im - kernel_w) / stride_w + 1; - const int w_col_end = - min(w_im / stride_w + 1, width_col); + (w_im < kernel_extent_w) ? 
0 : (w_im - kernel_extent_w) / stride_w + 1; + const int w_col_end = min(w_im / stride_w + 1, width_col); const int h_col_start = - (h_im < kernel_h) ? 0 : (h_im - kernel_h) / stride_h + 1; - const int h_col_end = - min(h_im / stride_h + 1, height_col); - /* - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = c_im * kernel_h * kernel_w - + (h_im - h_col * stride_h) * kernel_w + (w_im - w_col * stride_w); - val += data_col[(c_col * height_col + h_col) * width_col + w_col]; - } - } - */ - // equivalent implementation - int offset = (c_im * kernel_h * kernel_w + h_im * kernel_w + w_im) - * height_col * width_col; - int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col; - int coeff_w_col = (1 - stride_w * height_col * width_col); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; + (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; + const int h_col_end = min(h_im / stride_h + 1, height_col); + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int h_k = (h_im - h_col * stride_h); + int w_k = (w_im - w_col * stride_w); + if (h_k % dilation_h == 0 && w_k % dilation_w == 0) { + h_k /= dilation_h; + w_k /= dilation_w; + int data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) * + height_col + h_col) * width_col + w_col; + val += data_col[data_col_index]; + } } } data_im[index] = val; @@ -267,9 +266,12 @@ template void col2im_gpu(const Dtype* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, Dtype* data_im) { - int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; - int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; + const int stride_w, const int dilation_h, const int dilation_w, + Dtype* data_im) { + int height_col = (height + 2 * pad_h - + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; int num_kernels = channels * height * width; // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. 
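The rewritten col2im_gpu_kernel above drops the old closed-form offset trick in favor of an explicit search over column positions, because with dilation only kernel offsets whose dilated position lands exactly on the current pixel contribute. A plain CPU sketch of the same gather for a single image pixel (variable names mirror the kernel; the free function itself is illustrative, not part of the patch):

#include <algorithm>

float col2im_pixel(const float* data_col, int c_im, int h_im, int w_im,
                   int height_col, int width_col, int kernel_h, int kernel_w,
                   int pad_h, int pad_w, int stride_h, int stride_w,
                   int dilation_h, int dilation_w) {
  // Work in padded coordinates, as the kernel does.
  h_im += pad_h;
  w_im += pad_w;
  const int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
  const int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
  const int h_col_start =
      (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
  const int h_col_end = std::min(h_im / stride_h + 1, height_col);
  const int w_col_start =
      (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
  const int w_col_end = std::min(w_im / stride_w + 1, width_col);
  float val = 0;
  for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
      int h_k = h_im - h_col * stride_h;
      int w_k = w_im - w_col * stride_w;
      // Only offsets that are exact multiples of the dilation correspond to
      // a real kernel tap; everything else was never written by im2col.
      if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
        h_k /= dilation_h;
        w_k /= dilation_w;
        const int col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) *
                               height_col + h_col) * width_col + w_col;
        val += data_col[col_index];
      }
    }
  }
  return val;
}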
@@ -277,7 +279,7 @@ void col2im_gpu(const Dtype* data_col, const int channels, col2im_gpu_kernel<<>>( num_kernels, data_col, height, width, channels, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col, width_col, data_im); CUDA_POST_KERNEL_CHECK; } @@ -286,11 +288,13 @@ void col2im_gpu(const Dtype* data_col, const int channels, template void col2im_gpu(const float* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, float* data_im); + const int stride_w, const int dilation_h, const int dilation_w, + float* data_im); template void col2im_gpu(const double* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, double* data_im); + const int stride_w, const int dilation_h, const int dilation_w, + double* data_im); template __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col, From fddd021d876d63000edfc9f4a5ca168a8004ca9a Mon Sep 17 00:00:00 2001 From: Fisher Yu Date: Mon, 7 Dec 2015 15:32:48 -0500 Subject: [PATCH 2/4] add support for N-D dilated convolution also add safeguard to avoid unused variable warning and clean code format --- include/caffe/layers/base_conv_layer.hpp | 8 +- include/caffe/util/im2col.hpp | 8 +- src/caffe/layer_factory.cpp | 2 + src/caffe/layers/base_conv_layer.cpp | 24 +-- src/caffe/layers/conv_layer.cpp | 2 +- src/caffe/layers/im2col_layer.cpp | 6 +- src/caffe/layers/im2col_layer.cu | 4 +- src/caffe/test/test_convolution_layer.cpp | 2 - src/caffe/test/test_im2col_kernel.cu | 9 +- src/caffe/test/test_im2col_layer.cpp | 8 +- src/caffe/util/im2col.cpp | 21 +-- src/caffe/util/im2col.cu | 174 ++++++++++++++-------- 12 files changed, 166 insertions(+), 102 deletions(-) diff --git a/include/caffe/layers/base_conv_layer.hpp b/include/caffe/layers/base_conv_layer.hpp index db471b586da..0160a833dd2 100644 --- a/include/caffe/layers/base_conv_layer.hpp +++ b/include/caffe/layers/base_conv_layer.hpp @@ -106,7 +106,7 @@ class BaseConvolutionLayer : public Layer { } else { im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(), col_buffer_shape_.data(), kernel_shape_.cpu_data(), - pad_.cpu_data(), stride_.cpu_data(), col_buff); + pad_.cpu_data(), stride_.cpu_data(), dilation_.cpu_data(), col_buff); } } inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) { @@ -120,7 +120,7 @@ class BaseConvolutionLayer : public Layer { } else { col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(), col_buffer_shape_.data(), kernel_shape_.cpu_data(), - pad_.cpu_data(), stride_.cpu_data(), data); + pad_.cpu_data(), stride_.cpu_data(), dilation_.cpu_data(), data); } } #ifndef CPU_ONLY @@ -136,7 +136,7 @@ class BaseConvolutionLayer : public Layer { im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_, conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), kernel_shape_.gpu_data(), pad_.gpu_data(), - stride_.gpu_data(), col_buff); + stride_.gpu_data(), dilation_.gpu_data(), col_buff); } } inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) { @@ -151,7 +151,7 @@ class BaseConvolutionLayer : public Layer { col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_, conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), - data); + 
dilation_.gpu_data(), data); } } #endif diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp index 748b65c4f36..a35bc6e0b1c 100644 --- a/include/caffe/util/im2col.hpp +++ b/include/caffe/util/im2col.hpp @@ -7,7 +7,7 @@ template void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_col); + const int* dilation, Dtype* data_col); template void im2col_cpu(const Dtype* data_im, const int channels, @@ -20,7 +20,7 @@ template void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_im); + const int* dilation, Dtype* data_im); template void col2im_cpu(const Dtype* data_col, const int channels, @@ -33,7 +33,7 @@ template void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes, const int col_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_col); + const int* dilation, Dtype* data_col); template void im2col_gpu(const Dtype* data_im, const int channels, @@ -46,7 +46,7 @@ template void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes, const int im_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_im); + const int* dilation, Dtype* data_im); template void col2im_gpu(const Dtype* data_col, const int channels, diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index 6b1d1c1a5f5..4d912d28351 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -39,12 +39,14 @@ shared_ptr > GetConvolutionLayer( const LayerParameter& param) { ConvolutionParameter conv_param = param.convolution_param(); ConvolutionParameter_Engine engine = conv_param.engine(); +#ifdef USE_CUDNN bool use_dilation = false; for (int i = 0; i < conv_param.dilation_size(); ++i) { if (conv_param.dilation(i) > 1) { use_dilation = true; } } +#endif if (engine == ConvolutionParameter_Engine_DEFAULT) { engine = ConvolutionParameter_Engine_CAFFE; #ifdef USE_CUDNN diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index 0e8f5c98f50..af89b4dcd53 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -28,15 +28,15 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK_EQ(num_spatial_axes_, 2) << "kernel_h & kernel_w can only be used for 2D convolution."; CHECK_EQ(0, conv_param.kernel_size_size()) - << "Either kernel_size or kernel_h/w should be specified; not both."; + << "Either kernel_size or kernel_h/w should be specified, not both."; kernel_shape_data[0] = conv_param.kernel_h(); kernel_shape_data[1] = conv_param.kernel_w(); } else { const int num_kernel_dims = conv_param.kernel_size_size(); CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) << "kernel_size must be specified once, or once per spatial dimension " - << "(kernel_size specified " << num_kernel_dims << " times; " - << num_spatial_axes_ << " spatial dims);"; + << "(kernel_size specified " << num_kernel_dims << " times " + << num_spatial_axes_ << " spatial dims)."; for (int i = 0; i < num_spatial_axes_; ++i) { kernel_shape_data[i] = conv_param.kernel_size((num_kernel_dims == 1) ? 
0 : i); @@ -52,7 +52,7 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK_EQ(num_spatial_axes_, 2) << "stride_h & stride_w can only be used for 2D convolution."; CHECK_EQ(0, conv_param.stride_size()) - << "Either stride or stride_h/w should be specified; not both."; + << "Either stride or stride_h/w should be specified, not both."; stride_data[0] = conv_param.stride_h(); stride_data[1] = conv_param.stride_w(); } else { @@ -60,8 +60,8 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK(num_stride_dims == 0 || num_stride_dims == 1 || num_stride_dims == num_spatial_axes_) << "stride must be specified once, or once per spatial dimension " - << "(stride specified " << num_stride_dims << " times; " - << num_spatial_axes_ << " spatial dims);"; + << "(stride specified " << num_stride_dims << " times " + << num_spatial_axes_ << " spatial dims)."; const int kDefaultStride = 1; for (int i = 0; i < num_spatial_axes_; ++i) { stride_data[i] = (num_stride_dims == 0) ? kDefaultStride : @@ -76,7 +76,7 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK_EQ(num_spatial_axes_, 2) << "pad_h & pad_w can only be used for 2D convolution."; CHECK_EQ(0, conv_param.pad_size()) - << "Either pad or pad_h/w should be specified; not both."; + << "Either pad or pad_h/w should be specified, not both."; pad_data[0] = conv_param.pad_h(); pad_data[1] = conv_param.pad_w(); } else { @@ -84,8 +84,8 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK(num_pad_dims == 0 || num_pad_dims == 1 || num_pad_dims == num_spatial_axes_) << "pad must be specified once, or once per spatial dimension " - << "(pad specified " << num_pad_dims << " times; " - << num_spatial_axes_ << " spatial dims);"; + << "(pad specified " << num_pad_dims << " times " + << num_spatial_axes_ << " spatial dims)"; const int kDefaultPad = 0; for (int i = 0; i < num_spatial_axes_; ++i) { pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : @@ -98,9 +98,9 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, const int num_dilation_dims = conv_param.dilation_size(); CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 || num_dilation_dims == num_spatial_axes_) - << "dilation must be specified once, or once per spatial dimension " - << "(dilation specified " << num_dilation_dims << " times; " - << num_spatial_axes_ << " spatial dims);"; + << "dilation must be specified once, or once per spatial dimension " + << "(dilation specified " << num_dilation_dims << " times " + << num_spatial_axes_ << " spatial dims)."; const int kDefaultDilation = 1; for (int i = 0; i < num_spatial_axes_; ++i) { dilation_data[i] = (num_dilation_dims == 0) ? 
kDefaultDilation : diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index b4aca54ee23..5d522ab31f2 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -14,7 +14,7 @@ void ConvolutionLayer::compute_output_shape() { for (int i = 0; i < this->num_spatial_axes_; ++i) { // i + 1 to skip channel axis const int input_dim = this->input_shape(i + 1); - int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; + const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent) / stride_data[i] + 1; this->output_shape_.push_back(output_dim); diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index 2b6fc9c2867..d7dbb7db192 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -114,7 +114,7 @@ void Im2colLayer::Reshape(const vector*>& bottom, for (int i = 0; i < num_spatial_axes_; ++i) { top_shape[channel_axis_] *= kernel_shape_data[i]; const int input_dim = bottom[0]->shape(channel_axis_ + i + 1); - int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; + const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent) / stride_data[i] + 1; top_shape[channel_axis_ + i + 1] = output_dim; @@ -153,7 +153,7 @@ void Im2colLayer::Forward_cpu(const vector*>& bottom, bottom[0]->shape().data() + channel_axis_, top[0]->shape().data() + channel_axis_, kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(), - top_data + n * top_dim_); + dilation_.cpu_data(), top_data + n * top_dim_); } } } @@ -178,7 +178,7 @@ void Im2colLayer::Backward_cpu(const vector*>& top, bottom[0]->shape().data() + channel_axis_, top[0]->shape().data() + channel_axis_, kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(), - bottom_diff + n * bottom_dim_); + dilation_.cpu_data(), bottom_diff + n * bottom_dim_); } } } diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu index d90075d4304..792c97f70f9 100644 --- a/src/caffe/layers/im2col_layer.cu +++ b/src/caffe/layers/im2col_layer.cu @@ -26,7 +26,7 @@ void Im2colLayer::Forward_gpu(const vector*>& bottom, num_kernels, bottom[0]->gpu_shape() + channel_axis_, top[0]->gpu_shape() + channel_axis_, kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), - top_data + n * top_dim_); + dilation_.gpu_data(), top_data + n * top_dim_); } } } @@ -51,7 +51,7 @@ void Im2colLayer::Backward_gpu(const vector*>& top, bottom[0]->gpu_shape() + channel_axis_, top[0]->gpu_shape() + channel_axis_, kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), - bottom_diff + n * bottom_dim_); + dilation_.gpu_data(), bottom_diff + n * bottom_dim_); } } } diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index 373aa7a7398..533b9dac837 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -199,7 +199,6 @@ TYPED_TEST(ConvolutionLayerTest, TestSetup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -238,7 +237,6 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { 
layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu index 15e06aa8583..5d8f01f1713 100644 --- a/src/caffe/test/test_im2col_kernel.cu +++ b/src/caffe/test/test_im2col_kernel.cu @@ -26,7 +26,7 @@ template __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_col); + const int* dilation, Dtype* data_col); extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; @@ -35,7 +35,7 @@ class Im2colKernelTest : public GPUDeviceTest { protected: Im2colKernelTest() // big so launches > 1024 threads - : blob_bottom_(new Blob(5, 500, 10, 10)), + : blob_bottom_(new Blob(5, 500, 15, 15)), blob_kernel_shape_(new Blob()), blob_stride_(new Blob()), blob_pad_(new Blob()), @@ -56,7 +56,7 @@ class Im2colKernelTest : public GPUDeviceTest { channels_ = blob_bottom_->channels(); pad_ = 0; stride_ = 2; - dilation_ = 1; + dilation_ = 3; kernel_size_ = 3; height_col_ = (height_ + 2 * pad_ - (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1; @@ -176,6 +176,7 @@ TYPED_TEST(Im2colKernelTest, TestND) { this->blob_top_cpu_->shape().data() + 1, this->blob_kernel_shape_->cpu_data(), this->blob_pad_->cpu_data(), this->blob_stride_->cpu_data(), + this->blob_dilation_->cpu_data(), top_data_cpu + this->blob_top_cpu_->offset(n)); } @@ -194,7 +195,7 @@ TYPED_TEST(Im2colKernelTest, TestND) { num_kernels, bottom_data_gpu + this->blob_bottom_->offset(n), this->blob_bottom_->gpu_shape() + 1, this->blob_top_->gpu_shape() + 1, this->blob_kernel_shape_->gpu_data(), this->blob_pad_->gpu_data(), - this->blob_stride_->gpu_data(), + this->blob_stride_->gpu_data(), this->blob_dilation_->gpu_data(), top_data_gpu + this->blob_top_->offset(n)); CUDA_POST_KERNEL_CHECK; } diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp index 932d3f21ae9..24885e6b706 100644 --- a/src/caffe/test/test_im2col_layer.cpp +++ b/src/caffe/test/test_im2col_layer.cpp @@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; protected: Im2colLayerTest() - : blob_bottom_(new Blob(2, 3, 10, 9)), + : blob_bottom_(new Blob(2, 3, 10, 11)), blob_top_(new Blob()) { // fill the values Caffe::set_random_seed(1701); @@ -43,12 +43,13 @@ TYPED_TEST(Im2colLayerTest, TestSetup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(3); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_->num(), 2); EXPECT_EQ(this->blob_top_->channels(), 27); EXPECT_EQ(this->blob_top_->height(), 2); - EXPECT_EQ(this->blob_top_->width(), 2); + EXPECT_EQ(this->blob_top_->width(), 3); } TYPED_TEST(Im2colLayerTest, TestForward) { @@ -89,6 +90,7 @@ TYPED_TEST(Im2colLayerTest, TestGradientForceND) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); + convolution_param->add_dilation(3); convolution_param->set_force_nd_im2col(true); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); @@ -123,6 +125,8 @@ 
TYPED_TEST(Im2colLayerTest, TestRectGradient) { convolution_param->set_kernel_h(5); convolution_param->set_kernel_w(3); convolution_param->add_stride(2); + convolution_param->add_dilation(1); + convolution_param->add_dilation(3); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index 373207382b6..40c84a7aaef 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -49,7 +49,7 @@ template inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_output) { + const int* dilation, Dtype* data_output) { if (!im2col) { int im_size = im_shape[0]; for (int i = 0; i < num_spatial_axes; ++i) { @@ -81,7 +81,8 @@ inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col, bool is_padding = false; for (int d_i = 0; d_i < num_spatial_axes; ++d_i) { const int d = d_iter[d_i]; - const int d_im = d * stride[d_i] - pad[d_i] + d_offset[d_i]; + const int d_im = d * stride[d_i] - pad[d_i] + + d_offset[d_i] * dilation[d_i]; is_padding |= d_im < 0 || d_im >= im_shape[d_i + 1]; index_col *= col_shape[d_i + 1]; index_col += d; @@ -119,10 +120,10 @@ template void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_col) { + const int* dilation, Dtype* data_col) { const bool kIm2Col = true; im2col_nd_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); } // Explicit instantiation @@ -130,12 +131,12 @@ template void im2col_nd_cpu(const float* data_im, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - float* data_col); + const int* dilation, float* data_col); template void im2col_nd_cpu(const double* data_im, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - double* data_col); + const int* dilation, double* data_col); template void col2im_cpu(const Dtype* data_col, const int channels, @@ -182,10 +183,10 @@ template void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_im) { + const int* dilation, Dtype* data_im) { const bool kIm2Col = false; im2col_nd_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); } // Explicit instantiation @@ -193,12 +194,12 @@ template void col2im_nd_cpu(const float* data_col, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - float* data_im); + const int* dilation, float* data_im); template void col2im_nd_cpu(const double* data_col, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - double* data_im); + const int* dilation, double* data_im); } // namespace caffe diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index af32cbd13e3..ae0a8077424 100644 --- 
a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -75,9 +75,29 @@ template __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_col) { + const int* dilation, Dtype* data_col) { int d_temp[num_axes]; // NOLINT(runtime/arrays) int d_iter[num_axes]; // NOLINT(runtime/arrays) + + __shared__ int shared_dilation[num_axes]; + __shared__ int shared_kernel_shape[num_axes]; + __shared__ int shared_pad[num_axes]; + __shared__ int shared_stride[num_axes]; + __shared__ int shared_col_shape[num_axes + 1]; + __shared__ int shared_im_shape[num_axes + 1]; + + if (threadIdx.x < num_axes) { + shared_dilation[threadIdx.x] = dilation[threadIdx.x]; + shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x]; + shared_pad[threadIdx.x] = pad[threadIdx.x]; + shared_stride[threadIdx.x] = stride[threadIdx.x]; + } + if (threadIdx.x < num_axes + 1) { + shared_col_shape[threadIdx.x] = col_shape[threadIdx.x]; + shared_im_shape[threadIdx.x] = im_shape[threadIdx.x]; + } + __syncthreads(); + int i; CUDA_KERNEL_LOOP(index, n) { // Initialize channel_in, computed in the loop below, with intermediate @@ -85,19 +105,19 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, int channel_in = index; int channel_out = 1; for (i = num_axes - 1; i >= 0; --i) { - d_temp[i] = channel_in % col_shape[i + 1]; - channel_in /= col_shape[i + 1]; - channel_out *= kernel_shape[i]; + d_temp[i] = channel_in % shared_col_shape[i + 1]; + channel_in /= shared_col_shape[i + 1]; + channel_out *= shared_kernel_shape[i]; } channel_out *= channel_in; int data_col_inc = 1; for (i = 0; i < num_axes; ++i) { - channel_out *= col_shape[i + 1]; + channel_out *= shared_col_shape[i + 1]; channel_out += d_temp[i]; - d_temp[i] = d_temp[i] * stride[i] - pad[i]; - channel_in *= im_shape[i + 1]; + d_temp[i] = d_temp[i] * shared_stride[i] - shared_pad[i]; + channel_in *= shared_im_shape[i + 1]; channel_in += d_temp[i]; - data_col_inc *= col_shape[i + 1]; + data_col_inc *= shared_col_shape[i + 1]; d_iter[i] = 0; } Dtype* data_col_ptr = data_col + channel_out; @@ -106,15 +126,15 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, do { bool in_range = true; for (i = 0; i < num_axes; ++i) { - const int d_iter_im = d_iter[i] + d_temp[i]; - in_range &= d_iter_im >= 0 && d_iter_im < im_shape[i + 1]; + const int d_iter_im = d_iter[i] * shared_dilation[i] + d_temp[i]; + in_range &= d_iter_im >= 0 && d_iter_im < shared_im_shape[i + 1]; if (!in_range) { break; } } if (in_range) { - int data_im_offset = d_iter[0]; + int data_im_offset = d_iter[0] * shared_dilation[0]; for (i = 1; i < num_axes; ++i) { - data_im_offset *= im_shape[i + 1]; - data_im_offset += d_iter[i]; + data_im_offset *= shared_im_shape[i + 1]; + data_im_offset += d_iter[i] * shared_dilation[i]; } *data_col_ptr = data_im_ptr[data_im_offset]; } else { @@ -123,7 +143,7 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, data_col_ptr += data_col_inc; incremented = false; for (i = num_axes - 1; i >= 0; --i) { - const int d_max = kernel_shape[i]; + const int d_max = shared_kernel_shape[i]; if (d_iter[i] == d_max - 1) { d_iter[i] = 0; } else { // d_iter[i] < d_max - 1 @@ -140,67 +160,69 @@ template void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes, const int num_kernels, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* 
data_col) { + const int* dilation, Dtype* data_col) { + // num_axes should be smaller than block size + DCHECK_LT(10, CAFFE_CUDA_NUM_THREADS); switch (num_spatial_axes) { case 1: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 2: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 3: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 4: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 5: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 6: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 7: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 8: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 9: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; case 10: im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( num_kernels, data_im, im_shape, col_shape, - kernel_shape, pad, stride, data_col); + kernel_shape, pad, stride, dilation, data_col); break; default: LOG(FATAL) << "im2col_nd_gpu does not support computation with " @@ -214,12 +236,12 @@ template void im2col_nd_gpu(const float* data_im, const int num_spatial_axes, const int col_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - float* data_col); + const int* dilation, float* data_col); template void im2col_nd_gpu(const double* data_im, const int num_spatial_axes, const int col_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - double* data_col); + const int* dilation, double* data_col); template __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, @@ -268,10 +290,10 @@ void col2im_gpu(const Dtype* data_col, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, Dtype* data_im) { - int height_col = (height + 2 * pad_h - - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - int width_col = (width + 2 * pad_w - - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + int height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / + stride_h + 1; + int 
width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / + stride_w + 1; int num_kernels = channels * height * width; // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. @@ -300,27 +322,50 @@ template __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_im) { + const int* dilation, Dtype* data_im) { int d_im[num_axes]; // NOLINT(runtime/arrays) int d_col_iter[num_axes]; // NOLINT(runtime/arrays) int d_col_start[num_axes]; // NOLINT(runtime/arrays) int d_col_end[num_axes]; // NOLINT(runtime/arrays) + + __shared__ int shared_dilation[num_axes]; + __shared__ int shared_kernel_shape[num_axes]; + __shared__ int shared_pad[num_axes]; + __shared__ int shared_stride[num_axes]; + __shared__ int shared_col_shape[num_axes + 1]; + __shared__ int shared_im_shape[num_axes + 1]; + + if (threadIdx.x < num_axes) { + shared_dilation[threadIdx.x] = dilation[threadIdx.x]; + shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x]; + shared_pad[threadIdx.x] = pad[threadIdx.x]; + shared_stride[threadIdx.x] = stride[threadIdx.x]; + } + if (threadIdx.x < num_axes + 1) { + shared_col_shape[threadIdx.x] = col_shape[threadIdx.x]; + shared_im_shape[threadIdx.x] = im_shape[threadIdx.x]; + } + __syncthreads(); + CUDA_KERNEL_LOOP(index, n) { // Initialize channel_in, computed in the loop below, with intermediate // computations used to compute the spatial indices. int c_im = index; // Calculate d_im (image dimensions). for (int i = num_axes - 1; i >= 0; --i) { - d_im[i] = c_im % im_shape[i + 1] + pad[i]; - c_im /= im_shape[i + 1]; + d_im[i] = c_im % shared_im_shape[i + 1] + shared_pad[i]; + c_im /= shared_im_shape[i + 1]; } // Calculate col start/end indices. bool done = false; for (int i = 0; i < num_axes; ++i) { + const int kernel_extent = + shared_dilation[i] * (shared_kernel_shape[i] - 1) + 1; d_col_start[i] = d_col_iter[i] = - (d_im[i] < kernel_shape[i]) ? - 0 : (d_im[i] - kernel_shape[i]) / stride[i] + 1; - d_col_end[i] = min(d_im[i] / stride[i] + 1, col_shape[i + 1]); + (d_im[i] < kernel_extent) ? 0 : + (d_im[i] - kernel_extent) / shared_stride[i] + 1; + d_col_end[i] = + min(d_im[i] / shared_stride[i] + 1, shared_col_shape[i + 1]); if (d_col_start[i] >= d_col_end[i]) { // Skip computation if the dimension is 0 at any spatial axis -- // final val will be 0. @@ -335,21 +380,32 @@ __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col, // Loop over the col to compute the output val. Dtype val = 0; bool incremented = true; + bool skip = false; do { // Compute the final offset. 
int final_offset = 0; int kernel_shape_prod = 1; + int kernel_index; for (int i = num_axes - 1; i >= 0; --i) { - final_offset += - (d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod; - kernel_shape_prod *= kernel_shape[i]; + kernel_index = d_im[i] - d_col_iter[i] * shared_stride[i]; + if (kernel_index % shared_dilation[i]) { + skip = true; + break; + } else { + kernel_index /= shared_dilation[i]; + final_offset += kernel_index * kernel_shape_prod; + kernel_shape_prod *= shared_kernel_shape[i]; + } } - final_offset += kernel_shape_prod * c_im; - for (int i = 0; i < num_axes; ++i) { - final_offset *= col_shape[i + 1]; - final_offset += d_col_iter[i]; + if (!skip) { + final_offset += kernel_shape_prod * c_im; + for (int i = 0; i < num_axes; ++i) { + final_offset *= shared_col_shape[i + 1]; + final_offset += d_col_iter[i]; + } + val += data_col[final_offset]; } - val += data_col[final_offset]; + skip = false; incremented = false; for (int i = num_axes - 1; i >= 0; --i) { const int d_max = d_col_end[i]; @@ -370,67 +426,69 @@ template void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes, const int im_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - Dtype* data_im) { + const int* dilation, Dtype* data_im) { + // num_axes should be smaller than block size + DCHECK_LT(10, CAFFE_CUDA_NUM_THREADS); switch (num_spatial_axes) { case 1: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 2: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 3: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 4: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 5: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 6: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 7: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 8: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 9: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; case 10: col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<>>( im_size, data_col, im_shape, col_shape, - kernel_shape, pad, stride, data_im); + kernel_shape, pad, stride, dilation, data_im); break; default: 
LOG(FATAL) << "col2im_nd_gpu does not support computation with " @@ -444,11 +502,11 @@ template void col2im_nd_gpu(const float* data_col, const int num_spatial_axes, const int im_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - float* data_im); + const int* dilation, float* data_im); template void col2im_nd_gpu(const double* data_col, const int num_spatial_axes, const int im_size, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - double* data_im); + const int* dilation, double* data_im); } // namespace caffe From 8fdc7ee6671eb860ed0fce3c06c43ade1856077b Mon Sep 17 00:00:00 2001 From: Fisher Yu Date: Thu, 17 Dec 2015 12:16:58 -0500 Subject: [PATCH 3/4] improve test cases for dilated convolution --- src/caffe/solvers/adam_solver.cpp | 2 +- src/caffe/test/test_convolution_layer.cpp | 123 ++++++++++++++++++++-- src/caffe/test/test_im2col_layer.cpp | 54 ++++++++-- src/caffe/util/im2col.cpp | 8 +- src/caffe/util/im2col.cu | 8 +- 5 files changed, 172 insertions(+), 23 deletions(-) diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp index cb0fbfe2f78..c3378d3890a 100644 --- a/src/caffe/solvers/adam_solver.cpp +++ b/src/caffe/solvers/adam_solver.cpp @@ -30,7 +30,7 @@ void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { Blob* val_v = this->history_[param_id + update_history_offset].get(); Blob* val_t = this->temp_[param_id].get(); - const int t = this->iter_ + 1; + const int t = this->iter_ + 1; const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / (Dtype(1.) - pow(beta1, t)); const int N = net_params[param_id]->count(); diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index 533b9dac837..9bb19d13592 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -264,6 +264,50 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { } } +TYPED_TEST(ConvolutionLayerTest, TestDilatedConvolution) { + typedef typename TypeParam::Dtype Dtype; + vector bottom_shape; + bottom_shape.push_back(2); + bottom_shape.push_back(3); + bottom_shape.push_back(8); + bottom_shape.push_back(7); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + } + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->add_kernel_size(3); + convolution_param->add_dilation(2); + convolution_param->set_num_output(4); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("constant"); + convolution_param->mutable_bias_filler()->set_value(0.1); + shared_ptr > layer( + new ConvolutionLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. 
+ const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); + top_data = this->blob_top_2_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -323,7 +367,6 @@ TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -350,6 +393,53 @@ TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) { } } +TYPED_TEST(ConvolutionLayerTest, TestDilated3DConvolution) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + vector bottom_shape(5); + bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0); + bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1); + bottom_shape[2] = 6; + bottom_shape[3] = 7; + bottom_shape[4] = 8; + FillerParameter filler_param; + GaussianFiller filler(filler_param); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + filler.Fill(this->blob_bottom_vec_[i]); + } + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->add_kernel_size(3); + convolution_param->add_dilation(2); + convolution_param->set_num_output(4); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + shared_ptr > layer( + new ConvolutionLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. 
+ const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); + top_data = this->blob_top_2_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -357,7 +447,6 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(1); convolution_param->add_stride(1); - convolution_param->add_dilation(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -385,7 +474,6 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -425,7 +513,6 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(1); convolution_param->set_bias_term(false); shared_ptr > layer( @@ -628,7 +715,6 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient) { this->blob_top_vec_.push_back(this->blob_top_2_); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -638,6 +724,30 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient) { this->blob_top_vec_); } +TYPED_TEST(ConvolutionLayerTest, TestDilatedGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + vector bottom_shape; + bottom_shape.push_back(2); + bottom_shape.push_back(3); + bottom_shape.push_back(5); + bottom_shape.push_back(6); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + } + convolution_param->add_kernel_size(3); + convolution_param->add_dilation(2); + convolution_param->set_num_output(2); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + ConvolutionLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + TYPED_TEST(ConvolutionLayerTest, TestGradient3D) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -657,7 +767,6 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient3D) { } convolution_param->add_kernel_size(3); 
convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -676,7 +785,6 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Gradient) { this->blob_top_vec_.push_back(this->blob_top_2_); convolution_param->add_kernel_size(1); convolution_param->add_stride(1); - convolution_param->add_dilation(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -693,7 +801,6 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp index 24885e6b706..a7faf18f972 100644 --- a/src/caffe/test/test_im2col_layer.cpp +++ b/src/caffe/test/test_im2col_layer.cpp @@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; protected: Im2colLayerTest() - : blob_bottom_(new Blob(2, 3, 10, 11)), + : blob_bottom_(new Blob(2, 3, 6, 5)), blob_top_(new Blob()) { // fill the values Caffe::set_random_seed(1701); @@ -41,6 +41,12 @@ TYPED_TEST(Im2colLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); + vector bottom_shape; + bottom_shape.push_back(2); + bottom_shape.push_back(3); + bottom_shape.push_back(10); + bottom_shape.push_back(11); + this->blob_bottom_->Reshape(bottom_shape); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); convolution_param->add_dilation(3); @@ -76,21 +82,39 @@ TYPED_TEST(Im2colLayerTest, TestGradient) { layer_param.mutable_convolution_param(); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); - convolution_param->add_dilation(3); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, this->blob_top_vec_); } -TYPED_TEST(Im2colLayerTest, TestGradientForceND) { +TYPED_TEST(Im2colLayerTest, TestDilatedGradient) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); + vector bottom_shape; + bottom_shape.push_back(2); + bottom_shape.push_back(3); + bottom_shape.push_back(10); + bottom_shape.push_back(9); + this->blob_bottom_->Reshape(bottom_shape); convolution_param->add_kernel_size(3); convolution_param->add_stride(2); convolution_param->add_dilation(3); + Im2colLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(Im2colLayerTest, TestGradientForceND) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_force_nd_im2col(true); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); @@ -98,6 +122,27 @@ TYPED_TEST(Im2colLayerTest, TestGradientForceND) { 
this->blob_top_vec_); } +TYPED_TEST(Im2colLayerTest, TestDilatedGradientForceND) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + vector bottom_shape; + bottom_shape.push_back(2); + bottom_shape.push_back(3); + bottom_shape.push_back(10); + bottom_shape.push_back(9); + this->blob_bottom_->Reshape(bottom_shape); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); + convolution_param->add_dilation(3); + convolution_param->set_force_nd_im2col(true); + Im2colLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + TYPED_TEST(Im2colLayerTest, TestRect) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -116,7 +161,6 @@ TYPED_TEST(Im2colLayerTest, TestRect) { } } - TYPED_TEST(Im2colLayerTest, TestRectGradient) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -125,8 +169,6 @@ TYPED_TEST(Im2colLayerTest, TestRectGradient) { convolution_param->set_kernel_h(5); convolution_param->set_kernel_w(3); convolution_param->add_stride(2); - convolution_param->add_dilation(1); - convolution_param->add_dilation(3); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index 40c84a7aaef..6e5ea875757 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -13,9 +13,9 @@ void im2col_cpu(const Dtype* data_im, const int channels, const int dilation_h, const int dilation_w, Dtype* data_col) { const int height_col = (height + 2 * pad_h - - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; const int width_col = (width + 2 * pad_w - - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; const int channels_col = channels * kernel_h * kernel_w; for (int c_col = 0; c_col < channels_col; ++c_col) { int w_offset = c_col % kernel_w; @@ -147,9 +147,9 @@ void col2im_cpu(const Dtype* data_col, const int channels, Dtype* data_im) { caffe_set(height * width * channels, Dtype(0), data_im); const int height_col = (height + 2 * pad_h - - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; const int width_col = (width + 2 * pad_w - - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; const int channels_col = channels * kernel_h * kernel_w; for (int c_col = 0; c_col < channels_col; ++c_col) { int w_offset = c_col % kernel_w; diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index ae0a8077424..5bf0ceebfd9 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -48,9 +48,9 @@ void im2col_gpu(const Dtype* data_im, const int channels, // We are going to launch channels * height_col * width_col kernels, each // kernel responsible for copying a single-channel grid. 
int height_col = (height + 2 * pad_h - - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; int width_col = (width + 2 * pad_w - - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; int num_kernels = channels * height_col * width_col; // NOLINT_NEXT_LINE(whitespace/operators) im2col_gpu_kernel<< Date: Fri, 25 Dec 2015 15:55:51 -0500 Subject: [PATCH 4/4] add attribution for dilation in convolution --- src/caffe/layers/base_conv_layer.cpp | 3 +++ src/caffe/proto/caffe.proto | 7 +++++++ src/caffe/solvers/adam_solver.cpp | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index af89b4dcd53..106ea726aba 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -105,6 +105,9 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, for (int i = 0; i < num_spatial_axes_; ++i) { dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation : conv_param.dilation((num_dilation_dims == 1) ? 0 : i); + if (reverse_dimensions()) { + CHECK_EQ(dilation_data[i], 1) << "Deconvolution doesn't support dilation"; + } } // Special case: im2col is the identity for 1x1 convolution with stride 1 // and no padding, so flag for skipping the buffer and transformation. diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 87c46629baf..5d306b768ea 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -518,6 +518,13 @@ message ConvolutionParameter { repeated uint32 pad = 3; // The padding size; defaults to 0 repeated uint32 kernel_size = 4; // The kernel size repeated uint32 stride = 6; // The stride; defaults to 1 + // Properties of dilated convolution are described in + // F. Yu et al., Multi-Scale Context Aggregation by Dilated Convolutions, 2015. + // "dilation" was first proposed and used in the à trous algorithm, as described in + // Holschneider et al., A Real-Time Algorithm for Signal Analysis + // with the Help of the Wavelet Transform, 1987. + // It was called filter rarefaction in Long*, Shelhamer*, and Darrell, + // Fully Convolutional Networks for Semantic Segmentation, 2014. repeated uint32 dilation = 18; // The dilation; defaults to 1 // For 2D convolution only, the *_h and *_w versions may also be used to diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp index c3378d3890a..cb0fbfe2f78 100644 --- a/src/caffe/solvers/adam_solver.cpp +++ b/src/caffe/solvers/adam_solver.cpp @@ -30,7 +30,7 @@ void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { Blob* val_v = this->history_[param_id + update_history_offset].get(); Blob* val_t = this->temp_[param_id].get(); - const int t = this->iter_ + 1; + const int t = this->iter_ + 1; const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / (Dtype(1.) - pow(beta1, t)); const int N = net_params[param_id]->count();
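
For reference, the dilation arithmetic introduced by these patches can be checked outside of Caffe. The standalone C++ sketch below is illustrative only and is not part of the patch series; the sample kernel/pad/stride/dilation values are assumptions. It reproduces the dilated output-extent formula (in + 2*pad - (dilation*(kernel-1)+1)) / stride + 1 and the input-index mapping h_im = h_col*stride - pad + k*dilation used by the im2col_cpu/im2col_gpu paths above.

// Standalone check of the dilated im2col arithmetic added by these patches.
// Illustrative only: the concrete kernel/pad/stride/dilation values below are
// assumptions, not taken from any Caffe test.
#include <cstdio>

// Output extent along one spatial axis, matching the expression
//   (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1
// used in im2col_cpu, col2im_cpu, and im2col_gpu.
int conv_out_extent(int in, int kernel, int pad, int stride, int dilation) {
  const int kernel_extent = dilation * (kernel - 1) + 1;
  return (in + 2 * pad - kernel_extent) / stride + 1;
}

int main() {
  const int height = 8, kernel_h = 3, pad_h = 1, stride_h = 1, dilation_h = 2;
  const int height_col =
      conv_out_extent(height, kernel_h, pad_h, stride_h, dilation_h);
  std::printf("height_col = %d\n", height_col);  // 6 for these values

  // For every output row and kernel tap, the sampled input row is
  //   h_im = h_col * stride_h - pad_h + k * dilation_h,
  // i.e. taps are spaced dilation_h pixels apart; out-of-range rows are
  // treated as zero padding by im2col.
  for (int h_col = 0; h_col < height_col; ++h_col) {
    for (int k = 0; k < kernel_h; ++k) {
      const int h_im = h_col * stride_h - pad_h + k * dilation_h;
      std::printf("h_col=%d k=%d -> h_im=%d%s\n", h_col, k, h_im,
                  (h_im < 0 || h_im >= height) ? " (padding)" : "");
    }
  }
  return 0;
}

With these sample values the kernel extent is dilation_h * (kernel_h - 1) + 1 = 5, so height_col is (8 + 2 - 5) / 1 + 1 = 6, and the first and last output rows each touch one padded input row.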