forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
im2col_op.cc
104 lines (94 loc) · 3.75 KB
/
im2col_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include "caffe2/operators/im2col_op.h"
namespace caffe2 {
// Register the float CPU implementations of Im2Col and Col2Im. The operator
// templates are not defined in this file; presumably they come from the
// included caffe2/operators/im2col_op.h.
REGISTER_CPU_OPERATOR(Im2Col, Im2ColOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(Col2Im, Col2ImOp<float, CPUContext>);
// Gradient maker for Im2Col.
//
// Emits a single Col2Im op that consumes the output gradient dY together with
// the forward input X, and writes the input gradient dX.
class GetIm2ColGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    // Col2Im inputs: (dY, X); X is passed alongside dY so Col2Im has the
    // forward input available for producing dX.
    const std::vector<string> col2im_inputs{GO(0), I(0)};
    const std::vector<string> col2im_outputs{GI(0)};
    return SingleGradientDef("Col2Im", "", col2im_inputs, col2im_outputs);
  }
};
REGISTER_GRADIENT(Im2Col, GetIm2ColGradient);
// Gradient maker for Col2Im.
//
// Emits a single Im2Col op applied to the output gradient dY. Only input 0
// of Col2Im receives a gradient; input 1 gets none from this maker.
class GetCol2ImGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    const std::vector<string> im2col_inputs{GO(0)};
    const std::vector<string> im2col_outputs{GI(0)};
    return SingleGradientDef("Im2Col", "", im2col_inputs, im2col_outputs);
  }
};
REGISTER_GRADIENT(Col2Im, GetCol2ImGradient);
// Schema for Im2Col.
//
// Arguments read by shape inference (each per-axis *_h / *_w value falls back
// to the scalar form): pad (default 0); kernel / kernel_h / kernel_w
// (default 0, must be set); dilation / dilation_h / dilation_w (default 1);
// stride / stride_h / stride_w (default 1); order ("NCHW" default, or "NHWC").
//
// Each output spatial size is
//   (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1
// and the channel dimension is multiplied by kernel_h * kernel_w.
OPERATOR_SCHEMA(Im2Col)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc("The Im2Col operator from Matlab.")
    .TensorInferenceFunction(
        [](const OperatorDef& def, const vector<TensorShape>& in) {
          ArgumentHelper helper(def);
          const auto pad = helper.GetSingleArgument<int>("pad", 0);
          const auto kernel_h = helper.GetSingleArgument<int>(
              "kernel_h", helper.GetSingleArgument<int>("kernel", 0));
          const auto kernel_w = helper.GetSingleArgument<int>(
              "kernel_w", helper.GetSingleArgument<int>("kernel", 0));
          const auto dilation_h = helper.GetSingleArgument<int>(
              "dilation_h", helper.GetSingleArgument<int>("dilation", 1));
          const auto dilation_w = helper.GetSingleArgument<int>(
              "dilation_w", helper.GetSingleArgument<int>("dilation", 1));
          const auto stride_h = helper.GetSingleArgument<int>(
              "stride_h", helper.GetSingleArgument<int>("stride", 1));
          const auto stride_w = helper.GetSingleArgument<int>(
              "stride_w", helper.GetSingleArgument<int>("stride", 1));
          const auto order = StringToStorageOrder(
              helper.GetSingleArgument<string>("order", "NCHW"));

          // Reject degenerate arguments before using them: a zero stride
          // would otherwise divide by zero in the output-size computation,
          // and a non-positive kernel (the default when unset) or dilation
          // yields a meaningless output shape.
          CAFFE_ENFORCE_GT(kernel_h, 0, "Im2Col: kernel_h must be positive.");
          CAFFE_ENFORCE_GT(kernel_w, 0, "Im2Col: kernel_w must be positive.");
          CAFFE_ENFORCE_GT(
              dilation_h, 0, "Im2Col: dilation_h must be positive.");
          CAFFE_ENFORCE_GT(
              dilation_w, 0, "Im2Col: dilation_w must be positive.");
          CAFFE_ENFORCE_GT(stride_h, 0, "Im2Col: stride_h must be positive.");
          CAFFE_ENFORCE_GT(stride_w, 0, "Im2Col: stride_w must be positive.");
          CAFFE_ENFORCE_GE(pad, 0, "Im2Col: pad must be non-negative.");

          // The layout switch below reads four dimensions, so validate the
          // input count and rank instead of indexing past the end.
          CAFFE_ENFORCE(in.size() == 1, "Im2Col expects exactly one input.");
          const TensorShape& X = in[0];
          CAFFE_ENFORCE_EQ(
              X.dims_size(), 4, "Im2Col expects a 4-D input tensor.");

          int N = 0;
          int C = 0;
          int H = 0;
          int W = 0;
          switch (order) {
            case StorageOrder::NCHW:
              N = static_cast<int>(X.dims(0));
              C = static_cast<int>(X.dims(1));
              H = static_cast<int>(X.dims(2));
              W = static_cast<int>(X.dims(3));
              break;
            case StorageOrder::NHWC:
              N = static_cast<int>(X.dims(0));
              H = static_cast<int>(X.dims(1));
              W = static_cast<int>(X.dims(2));
              C = static_cast<int>(X.dims(3));
              break;
            default:
              CAFFE_THROW("Unknown storage order: ", order);
          }

          // Effective (dilated) kernel extent along each spatial axis.
          const int dkernel_h = dilation_h * (kernel_h - 1) + 1;
          const int dkernel_w = dilation_w * (kernel_w - 1) + 1;
          CAFFE_ENFORCE_GE(
              H, dkernel_h, "Im2Col: input height is smaller than the kernel.");
          CAFFE_ENFORCE_GE(
              W, dkernel_w, "Im2Col: input width is smaller than the kernel.");
          const int out_h = (H + 2 * pad - dkernel_h) / stride_h + 1;
          const int out_w = (W + 2 * pad - dkernel_w) / stride_w + 1;

          vector<TensorShape> out(1);
          switch (order) {
            case StorageOrder::NCHW:
              out[0] = CreateTensorShape(
                  vector<int>{N, C * kernel_h * kernel_w, out_h, out_w},
                  TensorProto::FLOAT);
              break;
            case StorageOrder::NHWC:
              out[0] = CreateTensorShape(
                  vector<int>{N, out_h, out_w, kernel_h * kernel_w * C},
                  TensorProto::FLOAT);
              break;
            default:
              CAFFE_THROW("Unknown storage order: ", order);
          }
          return out;
        })
    .Input(0, "X", "4-tensor in NCHW or NHWC.")
    .Output(
        0,
        "Y",
        "4-tensor. For NCHW: N x (C x kH x kW) x outH x outW. "
        "For NHWC: N x outH x outW x (kH x kW x C).");
// Schema for Col2Im, which folds columns back into an image-shaped tensor.
// GetIm2ColGradient uses it as the gradient of Im2Col: it is fed (dY, X) and
// produces dX, so the output takes the type and shape of input 1.
OPERATOR_SCHEMA(Col2Im)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(1)
    .SetDoc(
        "The Col2Im operator. Folds a column tensor back into an image "
        "tensor; used to compute the gradient of Im2Col.")
    .Input(0, "X", "Column tensor laid out like the output of Im2Col.")
    .Input(
        1,
        "Z",
        "Tensor whose type and shape the output takes (the Im2Col input "
        "when used as a gradient).")
    .Output(0, "Y", "Tensor with the type and shape of Z.");
} // namespace caffe2