Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor indexing #3195

Merged
merged 24 commits into from Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
218c1ba
CPU tensor indexing.
mzient Jul 22, 2021
039fd7a
Add python wrapper for slicing.
mzient Jul 28, 2021
48e3326
* Add GPU subscript operator
mzient Jul 28, 2021
24e4c9c
Reduce max. non-collapsible dims to 16.
mzient Jul 28, 2021
771da06
Fix output layout calculation.
mzient Jul 28, 2021
6269261
Enable tensor arguments.
mzient Jul 28, 2021
887cf40
Add tests. Fix ExpandDims for 0D input and nonempty new_axis_names.
mzient Jul 28, 2021
6f28e43
Add cpu-only and variable batch size tests.
mzient Jul 29, 2021
bb15eb9
Fix behaviour with swapped ends. Fix a bug in TensorVector - type not…
mzient Jul 29, 2021
0ac474e
Fix out-of-range error message. Add tests for out-of-range and incons…
mzient Jul 29, 2021
847f53e
Add missing file with tests.
mzient Jul 29, 2021
f4146fa
Fix review issues.
mzient Jul 29, 2021
1d8f677
Prevent iteration over DataNode.
mzient Jul 29, 2021
86ba294
Review issues.
mzient Jul 29, 2021
601935e
Add a hack to suppress DataNode being detected as an instance of Iter…
mzient Jul 30, 2021
d0e3369
Add check for the number of indices. Fix some review issues.
mzient Aug 2, 2021
2391452
Add more tests.
mzient Aug 2, 2021
981923a
Add comments.
mzient Aug 2, 2021
2726dd1
Fix Python LGTM.
mzient Aug 2, 2021
f9eab5a
* Add message check to error tests.
mzient Aug 3, 2021
6e878c8
Remove print.
mzient Aug 3, 2021
d93320a
Fix trailin full-range handling.
mzient Aug 3, 2021
a6d6e43
Add comments.
mzient Aug 3, 2021
13961ae
Add subscript_dim_check to cpu_only and variable_batch_size tests.
mzient Aug 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions dali/operators/generic/expand_dims.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,9 +29,9 @@ DALI_SCHEMA(ExpandDims)
.DocStr(R"code(Insert new dimension(s) with extent 1 to the data shape.

The new dimensions are inserted at the positions specified by ``axes``.
If ``new_axis_names`` is provided, the new dimension names will be inserted in the data layout,
at the positions specified by ``axes``. If ``new_axis_names`` is not provided, the output data

If ``new_axis_names`` is provided, the new dimension names will be inserted in the data layout,
at the positions specified by ``axes``. If ``new_axis_names`` is not provided, the output data
layout will be empty.")code")
.NumInput(1)
.NumOutput(1)
Expand All @@ -42,7 +42,7 @@ layout will be empty.")code")
.AddArg("axes", R"code(Indices at which the new dimensions are inserted.)code",
DALI_INT_VEC, true)
.AddOptionalArg("new_axis_names", R"code(Names of the new dimensions in the data layout.

The length of ``new_axis_names`` must match the length of ``axes``.
If argument isn't be provided, the layout will be cleared.)code", TensorLayout(""));

Expand Down Expand Up @@ -125,7 +125,7 @@ void ExpandDims<Backend>::GenerateSrcDims(const Workspace &ws) {
out_layout += in_layout.empty() ? 0 : in_layout[d];
this->src_dims_.push_back(d++);
}
if (!in_layout.empty()) {
if (!in_layout.empty() || ndim == 0) {
this->layout_ = use_new_axis_names_arg_ ? out_layout : TensorLayout();
}
}
Expand Down
134 changes: 134 additions & 0 deletions dali/operators/generic/slice/subscript.cc
@@ -0,0 +1,134 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/operators/generic/slice/subscript.h"
#include "dali/kernels/common/type_erasure.h"
#include "dali/kernels/slice/slice_cpu.h"

namespace dali {

#define INDEX_ARGS(idx) \
AddOptionalArg<int>("at_" #idx, "Position index", nullptr, true) \
.AddOptionalArg<int>("lo_" #idx, "Range start", nullptr, true) \
.AddOptionalArg<int>("hi_" #idx, "Range end", nullptr, true) \
.AddOptionalArg<int>("step_" #idx, "Range step", nullptr, true)


DALI_SCHEMA(TensorSubscript)
.MakeDocHidden()
.DocStr(R"(Applies NumPy-like indexing to a batch of tensors.)")
.NumInput(1)
.NumOutput(1)
.AddOptionalArg<int>("num_subscripts",
"Number of subscripts supplied, including full-range - used for validation only.", nullptr)
.INDEX_ARGS(0)
.INDEX_ARGS(1)
.INDEX_ARGS(2)
.INDEX_ARGS(3)
.INDEX_ARGS(4)
.INDEX_ARGS(5)
.INDEX_ARGS(6)
.INDEX_ARGS(7)
.INDEX_ARGS(8)
.INDEX_ARGS(9)
.INDEX_ARGS(10)
.INDEX_ARGS(11)
.INDEX_ARGS(12)
.INDEX_ARGS(13)
.INDEX_ARGS(14)
.INDEX_ARGS(15)
.INDEX_ARGS(16)
.INDEX_ARGS(17)
.INDEX_ARGS(18)
.INDEX_ARGS(19)
.INDEX_ARGS(20)
.INDEX_ARGS(21)
.INDEX_ARGS(22)
.INDEX_ARGS(23)
.INDEX_ARGS(24)
.INDEX_ARGS(25)
.INDEX_ARGS(26)
.INDEX_ARGS(27)
.INDEX_ARGS(28)
.INDEX_ARGS(29)
.INDEX_ARGS(30)
.INDEX_ARGS(31);

template <>
template <int ndim, int element_size>
void TensorSubscript<CPUBackend>::RunTyped(HostWorkspace &ws) {
auto &input = ws.template InputRef<CPUBackend>(0);
auto &output = ws.template OutputRef<CPUBackend>(0);
int N = input.ntensor();
using T = kernels::type_of_size<element_size>;
ThreadPool &tp = ws.GetThreadPool();

kernels::SliceCPU<T, T, ndim> K;
TensorView<StorageCPU, const T, ndim> tv_in;
TensorView<StorageCPU, T, ndim> tv_out;

kernels::KernelContext ctx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is having one ctx for all kernels here ok?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's completely stateless. As you can see, I'm not running setup nor providing a scratchpad.

for (int i = 0; i < N; i++) {
tv_in.shape = simplified_in_shape_[i];
tv_in.data = static_cast<const T*>(input.raw_tensor(i));
tv_out.shape = simplified_out_shape_[i];
tv_out.data = static_cast<T*>(output.raw_mutable_tensor(i));
kernels::SliceArgs<T, ndim> args;
args.anchor = simplified_anchor_[i].to_static<ndim>();
args.shape = tv_out.shape;
K.Schedule(ctx, tv_out, tv_in, args, tp);
}
tp.RunAll();
}

DALI_REGISTER_OPERATOR(TensorSubscript, TensorSubscript<CPUBackend>, CPU);

DALI_SCHEMA(SubscriptDimCheck)
.MakeDocHidden()
.DocStr(R"(Checks that the input has at least `num_subscripts` dimensions.

This operator is used internally when all indices are empty (:) and just verifieis
that the input has sufficient number of dimensions and passes through the input.)")
.NumInput(1)
.NumOutput(1)
.PassThrough({{0, 0}})
.AddArg("num_subscripts",
"Number of subscripts supplied, which is the minimum required in the input.", DALI_INT32);


template <typename Backend>
struct SubscriptDimCheck : public Operator<Backend> {
explicit SubscriptDimCheck(const OpSpec &spec) : Operator<Backend>(spec) {
num_subscripts_ = spec.GetArgument<int>("num_subscripts");
}

virtual bool SetupImpl(vector<OutputDesc> &desc, const workspace_t<Backend> &ws) {
return false;
}

void RunImpl(workspace_t<Backend> &ws) override {
auto &in = ws.template InputRef<Backend>(0);
DALI_ENFORCE(num_subscripts_ <= in.sample_dim(), make_string("Too many indices (",
num_subscripts_, ") for a ", in.sample_dim(), "-D tensor."));
auto &out = ws.template OutputRef<Backend>(0);
out.ShareData(&in);
}

int num_subscripts_ = 0;
};

DALI_REGISTER_OPERATOR(SubscriptDimCheck, SubscriptDimCheck<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(SubscriptDimCheck, SubscriptDimCheck<GPUBackend>, GPU);

} // namespace dali
63 changes: 63 additions & 0 deletions dali/operators/generic/slice/subscript.cu
@@ -0,0 +1,63 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>
#include "dali/operators/generic/slice/subscript.h"
#include "dali/kernels/common/type_erasure.h"
#include "dali/kernels/slice/slice_gpu.cuh"

namespace dali {

template <>
template <int ndim, int element_size>
void TensorSubscript<GPUBackend>::RunTyped(DeviceWorkspace &ws) {
auto &input = ws.template InputRef<GPUBackend>(0);
auto &output = ws.template OutputRef<GPUBackend>(0);
int N = input.ntensor();
using T = kernels::type_of_size<element_size>;
using Kernel = kernels::SliceGPU<T, T, ndim>;
kmgr_.Resize<Kernel>(1, 1);

struct Ctx {
TensorListView<StorageGPU, const T, ndim> tmp_in;
TensorListView<StorageGPU, T, ndim> tmp_out;
vector<kernels::SliceArgs<T, ndim>> args;
};
Ctx *ctx = any_cast<Ctx>(&ctx_);
if (!ctx) {
ctx_ = Ctx();
ctx = &any_cast<Ctx&>(ctx_);
}

ctx->tmp_in.resize(N);
ctx->tmp_out.resize(N);
ctx->args.resize(N);
ctx->tmp_in.shape = simplified_in_shape_.to_static<ndim>();
ctx->tmp_out.shape = simplified_out_shape_.to_static<ndim>();
for (int i = 0; i < N; i++) {
ctx->tmp_in.data[i] = static_cast<const T *>(input.raw_tensor(i));
ctx->tmp_out.data[i] = static_cast<T *>(output.raw_mutable_tensor(i));
ctx->args[i].shape = ctx->tmp_out.shape[i];
ctx->args[i].anchor = simplified_anchor_[i];
}

kernels::KernelContext kctx;
kctx.gpu.stream = ws.stream();
kmgr_.Setup<Kernel>(0, kctx, ctx->tmp_in, ctx->args);
kmgr_.Run<Kernel>(0, 0, kctx, ctx->tmp_out, ctx->tmp_in, ctx->args);
}

DALI_REGISTER_OPERATOR(TensorSubscript, TensorSubscript<GPUBackend>, GPU);

} // namespace dali