ND convolution with im2col #2049

Merged
Merged 3 commits into the base branch on Sep 19, 2015 (branch names lost in page extraction).
View
@@ -219,6 +219,7 @@ class Blob {
const Dtype* cpu_data() const;
void set_cpu_data(Dtype* data);
+ const int* gpu_shape() const;
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
@@ -268,6 +269,7 @@ class Blob {
protected:
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
+ shared_ptr<SyncedMemory> shape_data_;
vector<int> shape_;
int count_;
int capacity_;
@@ -4,24 +4,48 @@
namespace caffe {
template <typename Dtype>
+void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
+ const int* im_shape, const int* col_shape,
+ const int* kernel_shape, const int* pad, const int* stride,
+ Dtype* data_col);
+
+template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, Dtype* data_col);
template <typename Dtype>
+void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
+ const int* im_shape, const int* col_shape,
+ const int* kernel_shape, const int* pad, const int* stride,
+ Dtype* data_im);
+
+template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
const int height, const int width, const int patch_h, const int patch_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, Dtype* data_im);
template <typename Dtype>
+void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
+ const int col_size, const int* im_shape, const int* col_shape,
+ const int* kernel_shape, const int* pad, const int* stride,
+ Dtype* data_col);
+
+template <typename Dtype>
void im2col_gpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, Dtype* data_col);
template <typename Dtype>
+void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
+ const int im_size, const int* im_shape, const int* col_shape,
+ const int* kernel_shape, const int* pad, const int* stride,
+ Dtype* data_im);
+
+template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int patch_h, const int patch_w,
const int pad_h, const int pad_w, const int stride_h,
@@ -64,46 +64,101 @@ class BaseConvolutionLayer : public Layer<Dtype> {
// Compute height_out_ and width_out_ from other parameters.
virtual void compute_output_shape() = 0;
- int kernel_h_, kernel_w_;
- int stride_h_, stride_w_;
+ /// @brief The spatial dimensions of a filter kernel.
+ Blob<int> kernel_shape_;
+ /// @brief The spatial dimensions of the stride.
+ Blob<int> stride_;
+ /// @brief The spatial dimensions of the padding.
+ Blob<int> pad_;
+ /// @brief The spatial dimensions of the convolution input.
+ Blob<int> conv_input_shape_;
+ /// @brief The spatial dimensions of the input.
+ Blob<int> input_shape_;
+ /// @brief The spatial dimensions of the col_buffer.
+ vector<int> col_buffer_shape_;
+ /// @brief The spatial dimensions of the output.
+ vector<int> output_shape_;
+
+ int num_spatial_axes_;
+ int bottom_dim_;
+ int top_dim_;
+
+ int channel_axis_;
int num_;
int channels_;
- int pad_h_, pad_w_;
- int height_, width_;
int group_;
+ int out_spatial_dim_;
+ int weight_offset_;
int num_output_;
- int height_out_, width_out_;
bool bias_term_;
bool is_1x1_;
+ bool force_nd_im2col_;
private:
// wrap im2col/col2im so we don't have to remember the (long) argument lists
inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) {
- im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+ if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+ im2col_cpu(data, conv_in_channels_,
+ conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+ kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+ pad_.cpu_data()[0], pad_.cpu_data()[1],
+ stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+ } else {
+ im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
+ col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+ pad_.cpu_data(), stride_.cpu_data(), col_buff);
+ }
}
inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
- col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+ if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+ col2im_cpu(col_buff, conv_in_channels_,
+ conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+ kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+ pad_.cpu_data()[0], pad_.cpu_data()[1],
+ stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+ } else {
+ col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
+ col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+ pad_.cpu_data(), stride_.cpu_data(), data);
+ }
}
#ifndef CPU_ONLY
inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) {
- im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+ if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+ im2col_gpu(data, conv_in_channels_,
+ conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+ kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+ pad_.cpu_data()[0], pad_.cpu_data()[1],
+ stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+ } else {
+ im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
+ conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+ kernel_shape_.gpu_data(), pad_.gpu_data(),
+ stride_.gpu_data(), col_buff);
+ }
}
inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
- col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+ if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+ col2im_gpu(col_buff, conv_in_channels_,
+ conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+ kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+ pad_.cpu_data()[0], pad_.cpu_data()[1],
+ stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+ } else {
+ col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
+ conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+ kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+ data);
+ }
}
#endif
+ int num_kernels_im2col_;
+ int num_kernels_col2im_;
int conv_out_channels_;
int conv_in_channels_;
int conv_out_spatial_dim_;
- int conv_in_height_;
- int conv_in_width_;
int kernel_dim_;
- int weight_offset_;
int col_offset_;
int output_offset_;
@@ -250,7 +305,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
- int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
+ int bottom_offset_, top_offset_, bias_offset_;
size_t workspaceSizeInBytes;
void *workspace;
};
@@ -287,11 +342,22 @@ class Im2colLayer : public Layer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
- int kernel_h_, kernel_w_;
- int stride_h_, stride_w_;
+ /// @brief The spatial dimensions of a filter kernel.
+ Blob<int> kernel_shape_;
+ /// @brief The spatial dimensions of the stride.
+ Blob<int> stride_;
+ /// @brief The spatial dimensions of the padding.
+ Blob<int> pad_;
+
+ int num_spatial_axes_;
+ int bottom_dim_;
+ int top_dim_;
+
+ int channel_axis_;
+ int num_;
int channels_;
- int height_, width_;
- int pad_h_, pad_w_;
+
+ bool force_nd_im2col_;
};
// Forward declare PoolingLayer and SplitLayer for use in LRNLayer.
View
@@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) {
CHECK_LE(shape.size(), kMaxBlobAxes);
count_ = 1;
shape_.resize(shape.size());
+ if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
+ shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
+ }
+ int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
for (int i = 0; i < shape.size(); ++i) {
CHECK_GE(shape[i], 0);
CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
count_ *= shape[i];
shape_[i] = shape[i];
+ shape_data[i] = shape[i];
}
if (count_ > capacity_) {
capacity_ = count_;
@@ -68,6 +73,12 @@ Blob<Dtype>::Blob(const vector<int>& shape)
}
template <typename Dtype>
+const int* Blob<Dtype>::gpu_shape() const {
+ CHECK(shape_data_);
+ return (const int*)shape_data_->gpu_data();
+}
+
+template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
return (const Dtype*)data_->cpu_data();
[Page error: "Oops, something went wrong." — the remainder of the pull-request diff failed to render and is truncated here.]