ND Crop layer #3570

Merged
merged 4 commits into from Mar 5, 2016
@@ -0,0 +1,67 @@
+#ifndef CAFFE_CROP_LAYER_HPP_
+#define CAFFE_CROP_LAYER_HPP_
+
+#include <utility>
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+/**
+ * @brief Takes a Blob and crops it to the shape specified by the second input
+ * Blob, across all dimensions after the specified axis.
+ *
+ * TODO(dox): thorough documentation for Forward, Backward, and proto params.
+ */
+
+template <typename Dtype>
+class CropLayer : public Layer<Dtype> {
+ public:
+ explicit CropLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "Crop"; }
+ virtual inline int ExactNumBottomBlobs() const { return 2; }
+ virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ vector<int> offsets;
+
+ private:
+ void crop_copy(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top,
+ const vector<int>& offsets,
+ vector<int> indices,
+ int cur_dim,
+ const Dtype* src_data,
+ Dtype* dest_data,
+ bool is_forward);
+
+ void crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top,
+ const vector<int>& offsets,
+ vector<int> indices,
+ int cur_dim,
+ const Dtype* src_data,
+ Dtype* dest_data,
+ bool is_forward);
+};
+} // namespace caffe
+
+#endif // CAFFE_CROP_LAYER_HPP_
@@ -0,0 +1,150 @@
+#include <algorithm>
+#include <functional>
+#include <map>
+#include <set>
+#include <vector>
+
+
+#include "caffe/layer.hpp"
+#include "caffe/layers/crop_layer.hpp"
+#include "caffe/net.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+void CropLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ // All logic that depends only on the number of dimensions is here,
+ // the rest is in Reshape because it depends on Blob size.
+ // bottom[0] supplies the data
+ // bottom[1] supplies the size
+ const CropParameter& param = this->layer_param_.crop_param();
+ CHECK_EQ(bottom.size(), 2) << "Wrong number of bottom blobs.";
+ int input_dim = bottom[0]->num_axes();
+ const int start_axis = bottom[0]->CanonicalAxisIndex(param.axis());
+ CHECK_LT(start_axis, input_dim) << "crop axis bigger than input dim";
+ if (param.offset_size() > 1) {
+    // the number of offset values specified must be equal to the number
+    // of dimensions following axis
+ CHECK_EQ(start_axis + param.offset_size(), input_dim)
+ << "number of offset values specified must be equal to the number of "
+ << "dimensions following axis.";
+ }
+}
+
+template <typename Dtype>
+void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const CropParameter& param = this->layer_param_.crop_param();
+ int input_dim = bottom[0]->num_axes();
+ const int start_axis = bottom[0]->CanonicalAxisIndex(param.axis());
+
+ // initialize all offsets to 0
+ offsets = vector<int>(input_dim, 0);
+ // initialize new shape to bottom[0]
+ vector<int> new_shape(bottom[0]->shape());
+
+ // apply crops
+ for (int i = 0; i < input_dim; ++i) {
+ int crop_offset = 0;
+ int new_size = bottom[0]->shape(i);
+ if (i >= start_axis) {
+ new_size = bottom[1]->shape(i);
+
+ if (param.offset_size() == 1) {
+ // if only one crop value is supplied, crop all dimensions after axis
+ // by this crop value
+ crop_offset = param.offset(0);
+ } else if (param.offset_size() > 1) {
+        // offset values specified must be equal to the number of dimensions
+        // following axis
+ crop_offset = param.offset(i - start_axis);
+ }
+ }
+    // Check that the blob being cropped, minus the offset in this dimension,
+    // is at least as large as the reference (destination) blob.
+ CHECK_GE(bottom[0]->shape(i) - crop_offset,
+ bottom[1]->shape(i))
+ << "invalid crop parameters in dimension: " << i;
+ // Now set new size and offsets
+ new_shape[i] = new_size;
+ offsets[i] = crop_offset;
+ }
+ top[0]->Reshape(new_shape);
+}
+
+// recursive copy function
+template <typename Dtype>
+void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top,
+ const vector<int>& offsets,
+ vector<int> indices,
+ int cur_dim,
+ const Dtype* src_data,
+ Dtype* dest_data,
+ bool is_forward) {
+ if (cur_dim + 1 < top[0]->num_axes()) {
+ // We are not yet at the final dimension, call copy recursively
+ for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
+ indices[cur_dim] = i;
+ crop_copy(bottom, top, offsets, indices, cur_dim+1,
+ src_data, dest_data, is_forward);
+ }
+ } else {
+    // We are at the last dimension, which is stored contiguously in memory
+ for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
+ // prepare index vector reduced(red) and with offsets(off)
+ std::vector<int> ind_red(cur_dim, 0);
+ std::vector<int> ind_off(cur_dim+1, 0);
+ for (int j = 0; j < cur_dim; ++j) {
+ ind_red[j] = indices[j];
+ ind_off[j] = indices[j] + offsets[j];
+ }
+ ind_off[cur_dim] = offsets[cur_dim];
+ // do the copy
+ if (is_forward) {
+ caffe_copy(top[0]->shape(cur_dim),
+ src_data + bottom[0]->offset(ind_off),
+ dest_data + top[0]->offset(ind_red));
+ } else {
+ // in the backwards pass the src_data is top_diff
+ // and the dest_data is bottom_diff
+ caffe_copy(top[0]->shape(cur_dim),
+ src_data + top[0]->offset(ind_red),
+ dest_data + bottom[0]->offset(ind_off));
+ }
+ }
+ }
+}
+
+template <typename Dtype>
+void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ std::vector<int> indices(top[0]->num_axes(), 0);
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ crop_copy(bottom, top, offsets, indices, 0, bottom_data, top_data, true);
+}
+
+template <typename Dtype>
+void CropLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+
+ if (propagate_down[0]) {
+ caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
+ std::vector<int> indices(top[0]->num_axes(), 0);
+ crop_copy(bottom, top, offsets, indices, 0, top_diff, bottom_diff, false);
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(CropLayer);
+#endif
+
+INSTANTIATE_CLASS(CropLayer);
+REGISTER_LAYER_CLASS(Crop);
+
+} // namespace caffe
@@ -0,0 +1,124 @@
+#include <vector>
+
+#include "caffe/layers/crop_layer.hpp"
+
+namespace caffe {
+
+// Copy (one line per thread) from one array to another, with arbitrary
+// strides in the last two dimensions.
+template <typename Dtype>
+__global__ void copy_kernel(const int n, const int height, const int width,
+ const int src_outer_stride, const int src_inner_stride,
+ const int dest_outer_stride, const int dest_inner_stride,
+ const Dtype* src, Dtype* dest) {
+ CUDA_KERNEL_LOOP(index, n) {
+ int src_start = index / height * src_outer_stride
+ + index % height * src_inner_stride;
+ int dest_start = index / height * dest_outer_stride
+ + index % height * dest_inner_stride;
+ for (int i = 0; i < width; ++i) {
+ dest[dest_start + i] = src[src_start + i];
+ }
+ }
+}
+
+// recursive copy function, this function is similar to crop_copy but loops
+// over all but the last two dimensions. It is implemented this way to allow
+// for ND cropping while still relying on a CUDA kernel for the innermost
+// two dimensions for performance reasons.
+// An alternative way to implement ND cropping relying more on the kernel
+// would require passing offsets to the kernel, which is a bit problematic
+ // because it is of variable length. Since in the standard (N,C,H,W) case
+// N,C are usually not cropped a speedup could be achieved by not looping
+// the application of the copy_kernel around these dimensions.
+template <typename Dtype>
+void CropLayer<Dtype>::crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top,
+ const vector<int>& offsets,
+ vector<int> indices,
+ int cur_dim,
+ const Dtype* src_data,
+ Dtype* dest_data,
+ bool is_forward) {
+ if (cur_dim + 2 < top[0]->num_axes()) {
+    // We are not yet at the final dimension, call copy recursively
+ for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
+ indices[cur_dim] = i;
+ crop_copy_gpu(bottom, top, offsets, indices, cur_dim+1,
+ src_data, dest_data, is_forward);
+ }
+ } else {
+    // We are at the last two dimensions, which are stored contiguously in
+    // memory.
+ // With (N,C,H,W)
+ // (0,1,2,3) cur_dim -> H
+ // cur_dim+1 -> W
+ const int lines = top[0]->shape(cur_dim);
+ const int height = top[0]->shape(cur_dim);
+ const int width = top[0]->shape(cur_dim+1);
+ std::vector<int> ind_off(cur_dim+2, 0);
+ for (int j = 0; j < cur_dim; ++j) {
+ ind_off[j] = indices[j] + offsets[j];
+ }
+ ind_off[cur_dim] = offsets[cur_dim];
+ ind_off[cur_dim+1] = offsets[cur_dim+1];
+ // Compute copy strides
+ const int src_outer_stride =
+ bottom[0]->shape(cur_dim)*bottom[0]->shape(cur_dim+1);
+ const int src_inner_stride = bottom[0]->shape(cur_dim+1);
+ const int dest_outer_stride =
+ top[0]->shape(cur_dim)*top[0]->shape(cur_dim+1);
+ const int dest_inner_stride = top[0]->shape(cur_dim+1);
+
+ if (is_forward) {
+ const Dtype* bottom_data = bottom[0]->gpu_data() +
+ bottom[0]->offset(ind_off);
+ Dtype* top_data = top[0]->mutable_gpu_data() +
+ top[0]->offset(indices);
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
+ lines, height, width,
+ src_outer_stride, src_inner_stride,
+ dest_outer_stride, dest_inner_stride,
+ bottom_data, top_data);
+
+ } else {
+ const Dtype* top_diff = top[0]->gpu_diff() +
+ top[0]->offset(indices);
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff() +
+ bottom[0]->offset(ind_off);
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
+ lines, height, width,
+ dest_outer_stride, dest_inner_stride,
+ src_outer_stride, src_inner_stride,
+ top_diff, bottom_diff);
+ }
+ }
+}
+
+template <typename Dtype>
+void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ std::vector<int> indices(top[0]->num_axes(), 0);
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ crop_copy_gpu(bottom, top, offsets, indices, 0, bottom_data, top_data, true);
+}
+
+template <typename Dtype>
+void CropLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+ if (propagate_down[0]) {
+ caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
+ std::vector<int> indices(top[0]->num_axes(), 0);
+ crop_copy_gpu(bottom, top, offsets, indices, 0, top_diff, bottom_diff,
+ false);
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(CropLayer);
+
+} // namespace caffe
@@ -306,7 +306,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 144 (last added: input_param)
+// LayerParameter next available layer-specific ID: 145 (last added: crop_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -360,6 +360,7 @@ message LayerParameter {
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
+ optional CropParameter crop_param = 144;
optional DataParameter data_param = 107;
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
@@ -598,6 +599,24 @@ message ConvolutionParameter {
optional bool force_nd_im2col = 17 [default = false];
}
+message CropParameter {
+ // To crop, elements of the first bottom are selected to fit the dimensions
+ // of the second, reference bottom. The crop is configured by
+ // - the crop `axis` to pick the dimensions for cropping
+ // - the crop `offset` to set the shift for all/each dimension
+ // to align the cropped bottom with the reference bottom.
+ // All dimensions up to but excluding `axis` are preserved, while
+ // the dimensions including and trailing `axis` are cropped.
+ // If only one `offset` is set, then all dimensions are offset by this amount.
+ // Otherwise, the number of offsets must equal the number of cropped axes to
+ // shift the crop in each dimension accordingly.
+ // Note: standard dimensions are N,C,H,W so the default is a spatial crop,
+ // and `axis` may be negative to index from the end (e.g., -1 for the last
+ // axis).
+ optional int32 axis = 1 [default = 2];
+ repeated uint32 offset = 2;
+}
+
message DataParameter {
enum DB {
LEVELDB = 0;
@@ -672,7 +691,7 @@ message EltwiseParameter {
// Message that stores parameters used by ELULayer
message ELUParameter {
// Described in:
- // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
+ // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
// Deep Network Learning by Exponential Linear Units (ELUs). arXiv
optional float alpha = 1 [default = 1];
}
Oops, something went wrong.