-
Notifications
You must be signed in to change notification settings - Fork 3
/
gather.go
231 lines (195 loc) · 6.56 KB
/
gather.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
package opset13
import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)
// Gather takes exactly two inputs: the data tensor and the indices tensor.
const (
	// MinGatherInputs is the fewest input tensors the gather operator accepts.
	MinGatherInputs = 2
	// MaxGatherInputs is the most input tensors the gather operator accepts.
	MaxGatherInputs = 2
)
// Gather represents the ONNX gather operator.
//
// Its single configurable attribute is the axis of the data tensor along
// which the indices select slices.
type Gather struct {
	// axis selects the data dimension to gather on. It defaults to 0 and may
	// be negative, in which case it is counted from the last dimension.
	axis int
}
// newGather constructs a gather operator that gathers along axis 0 until
// Init overrides the axis from the node attributes.
func newGather() ops.Operator {
	g := new(Gather)
	g.axis = 0

	return g
}
// Init configures the gather operator from the attributes of an ONNX node.
//
// At most one attribute is allowed, and it must be named "axis":
//   - more than one attribute yields ops.ErrInvalidAttributeCount;
//   - a single attribute with any other name yields ops.ErrInvalidAttribute;
//   - no attributes leaves the current axis untouched.
func (g *Gather) Init(n *onnx.NodeProto) error {
	attrs := n.GetAttribute()

	switch {
	case len(attrs) > 1:
		return ops.ErrInvalidAttributeCount(1, len(attrs), g)
	case len(attrs) == 0:
		return nil
	}

	only := attrs[0]
	if only.GetName() != "axis" {
		return ops.ErrInvalidAttribute(only.GetName(), g)
	}

	g.axis = int(only.GetI())

	return nil
}
// Apply applies the gather operator.
//
// inputs[0] is the data tensor and inputs[1] holds the indices (Int32 or
// Int64). The configured axis may be negative, in which case it is counted
// from the last dimension of the data tensor. Indices must lie in
// [-dimSize, dimSize-1] for the gathered dimension. Negative indices are
// translated with ops.OffsetTensorIfNegative before use.
//
// The output has the data tensor's shape, with dimension `axis` replaced by
// the complete shape of the indices tensor.
func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
	rawIndices := inputs[1]

	// Normalise the indices (Int32 or Int64) into a plain []int so they can be
	// range-checked and used directly as coordinates.
	flatIndices, err := ops.AnyToIntSlice(ops.IfScalarToSlice(rawIndices.Data()))
	if err != nil {
		return nil, err
	}

	indices := tensor.New(
		tensor.WithBacking(flatIndices),
		tensor.WithShape(rawIndices.Shape()...),
	)

	data := inputs[0]
	dataShape := data.Shape()
	rank := len(dataShape)

	// The axis must address an existing dimension of the data tensor.
	axis := g.axis
	if axis < -rank || axis >= rank {
		return nil, ops.ErrAxisOutOfRange(rank, rank, axis)
	}

	if axis < 0 {
		axis += rank
	}

	// Every index must address a valid position along the selected dimension.
	dimSize := dataShape[axis]
	if !ops.AllInRange(flatIndices, -dimSize, dimSize-1) {
		return nil, ops.ErrNotAllAxesInRange(dimSize, dimSize)
	}

	if err := ops.OffsetTensorIfNegative(indices, dimSize); err != nil {
		return nil, err
	}

	outShape := insertWithReplace(indices.Shape(), dataShape, axis)
	output := tensor.New(tensor.WithShape(outShape...), tensor.Of(data.Dtype()))

	if err := gather(output, data, indices, axis); err != nil {
		return nil, err
	}

	return []tensor.Tensor{output}, nil
}
// ValidateInputs checks the tensors that will be handed to Apply. It delegates
// to the shared ops.ValidateInputs helper and returns whatever that helper
// returns.
func (g *Gather) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
	validated, err := ops.ValidateInputs(g, inputs)

	return validated, err
}
// GetMinInputs reports the fewest input tensors Gather accepts
// (MinGatherInputs: the data tensor and the indices tensor).
func (g *Gather) GetMinInputs() int {
	return MinGatherInputs
}
// GetMaxInputs reports the most input tensors Gather accepts
// (MaxGatherInputs: the data tensor and the indices tensor).
func (g *Gather) GetMaxInputs() int {
	return MaxGatherInputs
}
// GetInputTypeConstraints returns, for each input position, the set of tensor
// dtypes that input may have. The data tensor (input 0) may be any type in
// ops.AllTypes. The indices tensor (input 1) must be Int32 or Int64.
func (g *Gather) GetInputTypeConstraints() [][]tensor.Dtype {
	dataTypes := ops.AllTypes
	indexTypes := []tensor.Dtype{tensor.Int32, tensor.Int64}

	return [][]tensor.Dtype{dataTypes, indexTypes}
}
// String implements fmt.Stringer. The returned name identifies this operator
// when errors or messages are formatted.
func (g *Gather) String() string {
	const name = "gather operator"

	return name
}
// gather performs the ONNX Gather operation, writing the result into out.
//
// Definition from the ONNX specification, where q is the rank of `indices`
// and r is the rank of `data`:
//
//	axis = 0:
//	  Let k = indices[i_{0}, ..., i_{q-1}]
//	  Then output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k, j_{0}, ..., j_{r-2}]
//
//	axis = 1:
//	  Let k = indices[i_{0}, ..., i_{q-1}]
//	  Then output[j_{0}, i_{0}, ..., i_{q-1}, j_{1}, ..., j_{r-2}] = input[j_{0}, k, j_{1}, ..., j_{r-2}]
//
// In general the i's occupy the output positions starting at `axis`, and the
// j's range over every dimension of `data` other than `axis`. Every valid
// combination of the i's and j's is evaluated.
//
// A small example of how such i/j statements are read (unrelated to gather):
// with x = [1, 2], y = [4, 5] and
//
//	l = x[i_0]
//	output[i_0, j_0] = y[j_0] - l
//
// evaluating every (i_0, j_0) in {(0,0), (0,1), (1,0), (1,1)} yields
//
//	output = [ 3 4 ]
//	         [ 2 3 ]
//
// Implementation: for each element of `indices` (and therefore each
// [i_0, ..., i_{q-1}]), k is read. All [j_0, ..., j_{r-2}] combinations are
// then assigned at once. Both tensors are sliced down to the matching blocks,
// which are assigned pairwise.
//
// Apply is the caller: it normalises axis into [0, rank(data)), range-checks
// and offsets the indices, and allocates out with shape
// insertWithReplace(indices.Shape(), data.Shape(), axis).
//
// An error is returned if an index cannot be read or is not an int, if either
// tensor cannot be sliced, if the block assignment fails, or if the iterator
// cannot advance.
func gather(out, data, indices tensor.Tensor, axis int) error {
	// The data rank does not change inside the loop, so compute it once.
	dataRank := len(data.Shape())

	it := indices.Iterator()
	it.Reset()

	for !it.Done() {
		coords := it.Coord()

		at, err := indices.At(coords...)
		if err != nil {
			return err
		}

		k, ok := at.(int)
		if !ok {
			return ops.ErrTypeAssert("int", at)
		}

		// Select k on the given axis and keep every other dimension whole.
		// Equivalent to: data[:, ..., :, k, :, ..., :] with k at position `axis`.
		dslices := make([]tensor.Slice, dataRank)
		dslices[axis] = ops.NewSlicer(k)

		dataSlice, err := data.Slice(dslices...)
		if err != nil {
			return err
		}

		// Select the current index coordinates starting at position `axis`,
		// leaving every other dimension whole. Equivalent to:
		//   out[:, ..., :, i_0, ..., i_{q-1}, :, ..., :]
		// where i_0 sits at position `axis` and k = indices[i_0, ..., i_{q-1}].
		oslices := make([]tensor.Slice, len(coords)+dataRank-1)
		for i, c := range coords {
			oslices[i+axis] = ops.NewSlicer(c)
		}

		outputSlice, err := out.Slice(oslices...)
		if err != nil {
			return err
		}

		if err := ops.PairwiseAssign(outputSlice, dataSlice); err != nil {
			return err
		}

		if _, err := it.Next(); err != nil {
			return err
		}
	}

	return nil
}
// insertWithReplace returns a new slice in which the element x[axis] is
// replaced by all elements of a, in order. Neither a nor x is modified, and
// the result never shares a backing array with either of them.
//
// For 0 <= axis < len(x) the result has length len(a) + len(x) - 1.
// When axis == len(x), a is appended to a copy of x.
//
// Example:
//
//	a = [-1, -2, -3]
//	x = [1, 2, 3, 4, 5, 6, 7]
//	insertWithReplace(a, x, 3) -> [1, 2, 3, -1, -2, -3, 5, 6, 7]
func insertWithReplace(a, x []int, axis int) []int {
	head := x[:axis]

	// Only take a tail when elements remain after x[axis]. This also keeps
	// axis == len(x) from slicing past the end of x.
	var tail []int
	if axis < len(x)-1 {
		tail = x[axis+1:]
	}

	out := make([]int, 0, len(head)+len(a)+len(tail))
	out = append(out, head...)
	out = append(out, a...)
	out = append(out, tail...)

	return out
}