forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gather_op.cc
151 lines (119 loc) · 4.18 KB
/
gather_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "gather_op.h"
namespace caffe2 {
// Register the CPU implementation of the Gather operator.
REGISTER_CPU_OPERATOR(Gather, GatherOp<CPUContext>);
// Schema: Gather(DATA, INDICES) -> OUTPUT.
// Gathers slices of DATA along its outer-most dimension (or along `axis`,
// see the shape-inference lambda below) at the positions given by INDICES.
OPERATOR_SCHEMA(Gather)
.NumInputs(2)
.NumOutputs(1)
.SetDoc(R"DOC(
The *Gather* op accepts a *DATA* tensor of rank $r >= 1$ and *INDICES* tensor of rank $q$ as inputs. It then gathers entries of the outer-most dimension of *DATA*, indexed by *INDICES*, and concatenate them in an output tensor of rank $q + (r - 1)$.
Github Links:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.cc
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.h
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"Gather",
["DATA", "INDICES"],
["OUTPUT"]
)
data = np.array([[1., 1.2],[2.3, 3.4],[4.5, 5.7]])
print("DATA:\n",data)
inds = np.array([[0, 1],[1, 2]])
print("INDICES:\n",inds)
// Feed X into workspace
workspace.FeedBlob("DATA", data.astype(np.float32))
workspace.FeedBlob("INDICES", inds.astype(np.int32))
workspace.RunOperatorOnce(op)
print("OUTPUT:\n", workspace.FetchBlob("OUTPUT"))
```
**Result**
```
DATA:
[[1. 1.2]
[2.3 3.4]
[4.5 5.7]]
INDICES:
[[0 1]
[1 2]]
OUTPUT:
[[[1. 1.2]
[2.3 3.4]]
[[2.3 3.4]
[4.5 5.7]]]
```
</details>
)DOC")
.Input(0, "DATA", "Input data tensor of rank $r>=1$")
.Input(
1,
"INDICES",
"Input indices tensor of rank $q$. This tensor must contain integers.")
.Output(0, "OUTPUT", "Output tensor of rank $q+(r-1)$")
// Static shape inference: delegates the output-shape computation to
// gather_helper::calc_output_shape_vector, honoring the optional `axis`
// (default 0) and `match_outer` (default false) arguments. The output
// keeps DATA's element type.
.TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
ArgumentHelper helper(def);
const int axis = helper.GetSingleArgument<int>("axis", 0);
const bool match_outer =
helper.GetSingleArgument<bool>("match_outer", false);
// in[0] is DATA, in[1] is INDICES (matches .Input declarations above).
const auto& data_dims = GetDimsVector(in[0]);
const auto& indices_dims = GetDimsVector(in[1]);
vector<int> output_dims =
caffe2::gather_helper::calc_output_shape_vector<int>(
data_dims, indices_dims, axis, match_outer);
vector<TensorShape> out(1);
out[0] = CreateTensorShape(output_dims, in[0].data_type());
return out;
})
.InheritOnnxSchema();
// Gradient maker for Gather.
//
// axis == 0 with dense_gradient: emits SparseToDense to scatter the output
// gradient back into a dense DATA gradient. axis == 0 without it: registers
// a sparse gradient for DATA instead of emitting any ops. axis > 0: falls
// back to BatchGatherGradient, which only supports dense gradients.
class GetGatherGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper helper(def_);
    const int axis = helper.GetSingleArgument<int>("axis", 0);
    const bool dense_gradient =
        helper.GetSingleArgument<bool>("dense_gradient", false);
    // TBD: While it hasn't been used yet, we need to add wrap_indices support
    // to gradients next.
    // if (helper.HasArgument("wrap_indices_")) {
    // }
    using Op = GatherOp<CPUContext>;
    if (axis != 0) {
      // TBD: This is misleading to use dense_gradient by default for axis 0
      // and not otherwise....
      if (helper.HasArgument("dense_gradient")) {
        CAFFE_ENFORCE(
            dense_gradient == true,
            "Gather with axis > 0 must use dense_gradient");
      }
      Argument axis_arg = MakeArgument<int>("axis", axis);
      return SingleGradientDef(
          "BatchGatherGradient",
          "",
          // This is the order as expected by BatchGatherGradient indices,
          // different from SparseToDense below.
          vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
          vector<string>{GI(0)},
          std::vector<Argument>{axis_arg});
    }
    if (!dense_gradient) {
      // For now we don't do any reshaping as the consumer of this op would
      // probably be ScatterUpdate which intentionally ignores shapes. We
      // might need to revisit it in the future for correctness purposes. The
      // right shape for the output would be to flatten INDICES and collapse
      // first X dims of GRAD.
      SetSparse(Op::DATA, I(Op::INDICES), GO(0));
      return vector<OperatorDef>();
    }
    return vector<OperatorDef>{CreateOperatorDef(
        "SparseToDense",
        "",
        vector<string>{I(Op::INDICES), GO(0), I(Op::DATA)},
        vector<string>{GI(Op::DATA)})};
  }
};
REGISTER_GRADIENT(Gather, GetGatherGradient);
} // namespace caffe2