/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_indexing_op.cc
*/
#include "./np_indexing_op.h"
namespace mxnet {
namespace op {
struct AdvancedIndexingTakeCPU {
// assume that idx has been flattened to a 1-D tensor (N,)
// assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
// M is the number of columns of in_data and out_data
// K is the number of rows of in_data
// i is the index of out_data
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
const IType* idx, const size_t M, const int64_t K) {
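// map idx[i] into [0, K): the C++ % operator can return a negative
// remainder, so negative results are shifted up by K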
int64_t j = static_cast<int64_t>(idx[i]);
j = j % K;
j += (j < 0) ? K : 0;
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
std::memcpy(out_data + i * M, in_data + j * M, M * sizeof(DType));
#pragma GCC diagnostic pop
}
};
struct AdvancedIndexingTakeMultiDimensionCPU {
// assume that idx has been flattened to a 1-D tensor (N,)
// assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
// M is the number of columns of in_data and out_data
// K is the number of rows of in_data
// i is the index of out_data
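// unlike AdvancedIndexingTakeCPU, out row i gathers from the i-th block of K
// rows of in_data, i.e. in_data is addressed as an (N, K, M) tensor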
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
const IType* idx, const size_t M, const int64_t K) {
int64_t j = static_cast<int64_t>(idx[i]);
j = j % K;
j += (j < 0) ? K : 0;
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
std::memcpy(out_data + i * M, in_data + (i * K + j) * M, M * sizeof(DType));
#pragma GCC diagnostic pop
}
};
struct AdvancedIndexingBooleanMaskBackwardCPUWriteKernel {
template<typename DType>
static void Map(int i,
DType* igrad,
const OpReqType /*req*/,
const DType* ograd,
const int32_t* idx,
const size_t col_size) {
// i is row id already
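// idx holds the inclusive prefix sum of the boolean mask: prev != curr means
// mask[i] was set, so igrad row i copies ograd row `prev`; otherwise the row
// is zero-filled (write semantics)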
int32_t prev = (i == 0) ? 0 : idx[i - 1];
int32_t curr = idx[i];
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
if (prev != curr) {
std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * sizeof(DType));
} else {
std::memset(igrad + i * col_size, 0, col_size * sizeof(DType));
}
#pragma GCC diagnostic pop
}
};
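// returns true iff every element of data_ptr lies in [min, max]; used to
// validate take indices (a single-threaded scan is cheap for small index arrays)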
template<typename DType>
bool CheckIndexOutOfBound(const DType* data_ptr, size_t data_size,
const DType min, const DType max) {
bool is_valid = true;
for (size_t i = 0; i < data_size; i++) {
if (data_ptr[i] > max || data_ptr[i] < min) {
is_valid = false;
break;
}
}
return is_valid;
}
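// validates gather_nd indices on CPU: the kernel records, for each of the M
// index dimensions, an out-of-range entry (if any) in is_valid_dim_ptr, and
// the loop below raises IndexError on the first axis whose index falls
// outside the valid range [-size, size - 1]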
template<typename DType>
void GatherNDCheckBoundCPU(mshadow::Stream<cpu> *s, const DType* idx_ptr, index_t N,
index_t M, const mshadow::Shape<10> mshape, DType* is_valid_dim_ptr) {
using namespace mxnet_op;
Kernel<set_zero, cpu>::Launch(s, M, is_valid_dim_ptr);
Kernel<is_valid_check_gather_nd, cpu>::Launch(s, M, is_valid_dim_ptr, idx_ptr, N, mshape);
for (int m = 0; m < M; m++) {
if (is_valid_dim_ptr[m] > mshape[m] - 1 || is_valid_dim_ptr[m] < - mshape[m]) {
LOG(FATAL) << "IndexError: index " << is_valid_dim_ptr[m] << " is out of bounds for axis "
<< m << " with size " << mshape[m];
}
}
}
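// type inference for advanced indexing: the index dtype must already be known,
// and the output dtype is tied to the data dtype (in both directions)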
inline bool AdvancedIndexingOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_NE((*in_attrs)[1], -1) << "Index type must be set for take operator";
TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
return (*in_attrs)[0] != -1;
}
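// storage inference for the forward pass: only dense (default) storage is
// supported, and dispatch is forced to FComputeEx so the forward can size its
// output NDArray dynamically (the boolean path only knows the output shape
// after scanning the mask)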
bool AdvancedIndexingOpStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
for (int &attr : *in_attrs) {
CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
}
for (int &attr : *out_attrs) {
attr = kDefaultStorage;
}
*dispatch_mode = DispatchMode::kFComputeEx;
return true;
}
bool AdvancedIndexingOpBackStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3);
CHECK_EQ(out_attrs->size(), 2);
for (int &attr : *in_attrs) {
CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
}
for (int &attr : *out_attrs) {
attr = kDefaultStorage;
}
*dispatch_mode = DispatchMode::kFComputeEx;
return true;
}
template<>
void AdvancedIndexingOpForward<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace mxnet_op;
if (req[np_indexing_::kOut] == kNullOp) return;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (inputs[np_indexing_::kIdx].dtype() == mshadow::kBool) {
CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
const NDArray &data = inputs[0];
const NDArray &idx = inputs[1];
const NDArray &out = outputs[0];
CHECK_EQ(data.shape()[0], idx.shape()[0]);
CHECK_EQ(idx.shape().ndim(), 1U); // idx is required to be 1-d.
// count the number of 1s in `idx` to determine the output's first dimension
size_t idx_size = idx.shape()[0];
std::vector<int32_t> prefix_sum(idx_size, 0);
size_t valid_num = 0;
// Calculate prefix sum
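// e.g. idx = [1, 0, 1, 1] gives prefix_sum = [1, 1, 2, 3] and valid_num = 3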
bool* idx_dptr = idx.data().dptr<bool>();
for (size_t i = 0; i < idx_size; i++) {
prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
}
valid_num = prefix_sum[idx_size - 1];
// set the output shape forcefully
mxnet::TShape s = data.shape();
s[0] = valid_num;
const_cast<NDArray &>(out).Init(s);
// do the copy
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.dtype(), DType, {
size_t input_size = data.shape().Size();
size_t col_size = input_size / idx_size;
mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
mxnet_op::Kernel<BooleanMaskForwardCPUKernel, cpu>::Launch(
stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
prefix_sum.data(), col_size);
});
} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) {
using namespace mshadow;
const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape();
const mxnet::TShape& arrshape = inputs[np_indexing_::kArr].shape();
if (idxshape.Size() == 0) {
return;
}
mxnet::TShape oshape(idxshape.ndim() + arrshape.ndim() - 1, -1);
for (index_t i = 0; i < idxshape.ndim(); ++i) {
oshape[i] = idxshape[i];
}
// axis is fixed at 0 here: oshape is the index shape followed by the
// remaining (non-indexed) data dimensions
for (index_t i = 1; i < arrshape.ndim(); i++) {
oshape[i + idxshape.ndim() - 1] = arrshape[i];
}
const NDArray &out = outputs[0];
const_cast<NDArray &>(out).Init(oshape);
Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[np_indexing_::kOut].dtype(), DType, { // output data type
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[np_indexing_::kIdx].dtype(), IType, { // index data type
IType min = 0;
IType max = static_cast<IType>(arrshape[0] - 1);
// check with single thread is faster since data is small
IType* idx_ptr = inputs[np_indexing_::kIdx].data().dptr<IType>();
size_t idx_size = idxshape.Size();
bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
CHECK(is_valid) << "take operator contains indices out of bound";
Kernel<AdvancedIndexingTakeCPU, cpu>::Launch(s, idxshape.Size(),
outputs[np_indexing_::kOut].data().dptr<DType>(),
inputs[np_indexing_::kArr].data().dptr<DType>(),
inputs[np_indexing_::kIdx].data().dptr<IType>(),
oshape.Size()/idxshape.Size(), arrshape[0]);
});
});
} else {
LOG(FATAL)
<< "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
}
}
template<>
void AdvancedIndexingOpBackward<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);
if (req[0] == kNullOp) return;
if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kBool) {
// inputs: {ograd, data, idx}
// outputs: {igrad_data, igrad_idx}
const NDArray& ograd = inputs[0];
const NDArray& idx = inputs[2];
const NDArray& igrad_data = outputs[0];
MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, {
size_t input_size = igrad_data.shape().Size();
size_t idx_size = idx.shape()[0];
size_t col_size = input_size / idx_size;
std::vector<int32_t> prefix_sum(idx_size, 0);
bool* idx_dptr = idx.data().dptr<bool>();
for (size_t i = 0; i < idx_size; i++) {
prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
}
mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
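// kAddTo accumulates into existing gradients; any other write req must also
// zero the rows whose mask bit was unset, which the CPU write kernel handles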
if (req[0] == kAddTo) {
mxnet_op::Kernel<BooleanMaskBackwardKernel, cpu>::Launch(
stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
} else {
mxnet_op::Kernel<AdvancedIndexingBooleanMaskBackwardCPUWriteKernel, cpu>::Launch(
stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
}
});
});
} else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_NE(req[np_indexing_::kIdx], kAddTo)
<< "take layer doesn't support gradient of req type kAddTo to index";
// grad_out is the gradient of the outputs in the feed-forward
// grad_in is the gradient of the inputs in the feed-forward
Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[2].dtype(), IType, { // index data type
// inputs are specified in the .cc file, which are the gradients from
// the upper layer and the input index
// outputs are the gradients of inputs in the feed-forward pass
const mxnet::TShape& idxshape = inputs[2].shape();
const mxnet::TShape& arrshape = outputs[0].shape();
const mxnet::TShape& oshape = inputs[0].shape();
if (idxshape.Size() == 0) {
return;
}
if (req[np_indexing_::kIdx] != kNullOp) {
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
s, idxshape.Size(), outputs[np_indexing_::kIdx].data().dptr<IType>());
}
int idxndim = idxshape.ndim();
Tensor<cpu, 1, IType> idx = inputs[2].data().get_with_shape<cpu, 1, IType>(
Shape1(idxshape.ProdShape(0, idxndim)), s);
Tensor<cpu, 2, DType> grad_out = inputs[0].data().get_with_shape<cpu, 2, DType>(
Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s);
Tensor<cpu, 2, DType> grad_in = outputs[0].data().get_with_shape<cpu, 2, DType>(
Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s);
// re-using the previous code for axis = 0 case
if (req[np_indexing_::kArr] == kWriteTo || req[np_indexing_::kArr] == kAddTo) {
if (req[np_indexing_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
AddTakeGrad<false>(grad_in, idx, grad_out);
} else {
LOG(FATAL) << "wrong req";
}
});
});
} else {
LOG(FATAL)
<< "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
}
}
void AdvancedIndexingMultipleForwardCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (req[0] == kNullOp) return;
if (inputs[np_indexing_::kIdx].type_flag_ == mshadow::kBool) {
LOG(FATAL)
<< "Multi-dimension boolean indexing is not supported.";
} else if (inputs[np_indexing_::kIdx].type_flag_ == mshadow::kInt8 ||
inputs[np_indexing_::kIdx].type_flag_ == mshadow::kInt16 ||
inputs[np_indexing_::kIdx].type_flag_ == mshadow::kInt32 ||
inputs[np_indexing_::kIdx].type_flag_ == mshadow::kInt64) {
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
const mxnet::TShape& dshape = inputs[0].shape_;
const mxnet::TShape& ishape = inputs[1].shape_;
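// M: number of indexed dimensions (ishape[0]); N: number of index tuples;
// K: size of the trailing data slice copied per tuple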
int M = ishape[0];
int N = ishape.Size() / M;
int K = dshape.ProdShape(M, dshape.ndim());
mshadow::Shape<10> strides;
mshadow::Shape<10> mshape;
for (int i = M-1, stride = K; i >= 0; stride *= dshape[i], --i) {
strides[i] = stride;
mshape[i] = dshape[i];
}
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { // data/output dtype switch
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
// check whether indices are out of bound
IType* idx_ptr = inputs[1].dptr<IType>();
Tensor<cpu, 1, IType> workspace =
ctx.requested[0].get_space_typed<cpu, 1, IType>(Shape1(M), s);
IType* is_valid_dim_ptr = reinterpret_cast<IType*>(workspace.dptr_);
GatherNDCheckBoundCPU(s, idx_ptr, N, M, mshape, is_valid_dim_ptr);
Kernel<gather_nd, cpu>::Launch(
s, N, req[0], N, M, K, strides, mshape, outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>());
});
});
} else {
LOG(FATAL)
<< "arrays used as indices must be explictly declared as integer (or boolean) type."
<< "Use np.astype() to cast indices to integer or boolean.";
}
}
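// backward of gather_nd: scatter-adds each row of `data` (the output gradient)
// into `out` at the offset encoded by the corresponding index tuple; duplicate
// index tuples accumulate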
template<typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
GatherNDBackwardImpl(index_t N, index_t M, index_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
#pragma omp parallel for
for (index_t i = 0; i < N; i++) {
index_t offset = 0;
for (index_t j = 0; j < M; ++j) {
offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
}
for (index_t j = 0; j < K; ++j) {
#pragma omp atomic
out[offset + j] += data[i * K + j];
}
}
}
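// half_t specialization: `#pragma omp atomic` does not support
// mshadow::half::half_t, so this overload accumulates single-threaded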
template<typename DType, typename IType>
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
GatherNDBackwardImpl(index_t N, index_t M, index_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
for (index_t i = 0; i < N; i++) {
index_t offset = 0;
for (index_t j = 0; j < M; ++j) {
offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
}
for (index_t j = 0; j < K; ++j) {
out[offset + j] += data[i * K + j];
}
}
}
NNVM_REGISTER_OP(_npi_advanced_indexing)
.describe(R"code(
Combination of boolean indexing and advanced ndarray indexing
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<nnvm::FInferType>("FInferType", AdvancedIndexingOpType)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdvancedIndexingOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_np_advanced_indexing"})
.set_attr<FInferStorageType>("FInferStorageType", AdvancedIndexingOpStorageType)
.add_argument("data", "NDArray-or-Symbol", "Data")
.add_argument("indices", "NDArray-or-Symbol", "Indices");
NNVM_REGISTER_OP(_backward_np_advanced_indexing)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", AdvancedIndexingOpBackStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdvancedIndexingOpBackward<cpu>);
NNVM_REGISTER_OP(_npi_advanced_indexing_multiple)
.describe(R"code(
Advanced indexing with multiple index arrays (multi-dimensional integer indexing)
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", GatherNDShape)
.set_attr<nnvm::FInferType>("FInferType", GatherNDType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", AdvancedIndexingMultipleForwardCPU)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_backward_np_advanced_indexing_multiple");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(p);
ret.emplace_back(zero);
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices");
NNVM_REGISTER_OP(_backward_np_advanced_indexing_multiple)
.describe(R"code(Accumulates data according to indices and get the result. It's the backward of
`_npi_advanced_indexing_multiple`.
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", AdvancedIndexingMultipleBackward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_npi_advanced_indexing_multiple");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(p);
ret.emplace_back(zero);
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
} // namespace op
} // namespace mxnet