Skip to content

Commit

Permalink
support tensor indexer with negative (#8938)
Browse files Browse the repository at this point in the history
* support tensor index with negative

* remove unnecessary code

* add check msg

* add validation check in cuda code

Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 22, 2022
1 parent e39e580 commit 78ba55c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
15 changes: 13 additions & 2 deletions oneflow/user/kernels/nd_index_slice_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_
#define ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_

#include "oneflow/core/common/shape.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/ndarray/xpu_util.h"

Expand Down Expand Up @@ -91,11 +92,21 @@ OF_DEVICE_FUNC int64_t OffsetInSliceToOffsetInDense(int64_t slice_size, int64_t
const int64_t* dense_shape, const I* indices,
int64_t n) {
int64_t slice_idx = n / slice_size;
const I* cur_nd_index_ptr = indices + slice_idx * index_ndims;
const I* nd_index = indices + slice_idx * index_ndims;
int64_t offset = 0;
int64_t product = 1;
int64_t shifted_index = 0;
for (int64_t i = index_ndims - 1; i >= 0; --i) {
offset += cur_nd_index_ptr[i] * product;
#if defined(__CUDACC__)
assert(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i] && "index out of bounds");
#else
CHECK(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i])
<< "IndexError: index " << nd_index[i] << " is out of bounds for dimension " << i
<< " with size " << dense_shape[i];
#endif
shifted_index = nd_index[i] < 0 && nd_index[i] >= -dense_shape[i] ? nd_index[i] + dense_shape[i]
: nd_index[i];
offset += shifted_index * product;
product *= dense_shape[i];
}
return offset * slice_size + n % slice_size;
Expand Down
7 changes: 3 additions & 4 deletions python/oneflow/test/tensor/test_tensor_indexing2.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,10 @@ def get_set_tensor(indexed, indexer):
# # weird shape
# [slice(None), [[0, 1],
# [2, 3]]],
# BUG(wyg): It has bug when using negative indexing(setitem and getitem)
# negatives
# [[-1], [0]],
# [[0, 2], [-1]],
# [slice(None), [-1]],
[[-1], [0]],
[[0, 2], [-1]],
[slice(None), [-1]],
]

# test getitem
Expand Down

0 comments on commit 78ba55c

Please sign in to comment.