
Commit d0e3369
Add check for the number of indices. Fix some review issues.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
mzient committed Aug 2, 2021
1 parent 601935e commit d0e3369
Showing 8 changed files with 85 additions and 36 deletions.
37 changes: 37 additions & 0 deletions dali/operators/generic/slice/subscript.cc
@@ -94,4 +94,41 @@ void TensorSubscript<CPUBackend>::RunTyped(HostWorkspace &ws) {

DALI_REGISTER_OPERATOR(TensorSubscript, TensorSubscript<CPUBackend>, CPU);

DALI_SCHEMA(SubscriptDimCheck)
  .MakeDocHidden()
  .DocStr(R"(Checks that the input has at least `num_subscripts` dimensions.
This operator is used internally when all indices are empty (:); it just verifies
that the input has a sufficient number of dimensions and passes the input through.)")
  .NumInput(1)
  .NumOutput(1)
  .PassThrough({{0, 0}})
  .AddArg("num_subscripts",
      "Number of subscripts supplied, which is the minimum required in the input.", DALI_INT32);


template <typename Backend>
struct SubscriptDimCheck : public Operator<Backend> {
  explicit SubscriptDimCheck(const OpSpec &spec) : Operator<Backend>(spec) {
    num_subscripts_ = spec.GetArgument<int>("num_subscripts");
  }

  bool SetupImpl(vector<OutputDesc> &desc, const workspace_t<Backend> &ws) override {
    return false;
  }

  void RunImpl(workspace_t<Backend> &ws) override {
    auto &in = ws.template InputRef<Backend>(0);
    DALI_ENFORCE(num_subscripts_ <= in.sample_dim(), make_string("Too many indices (",
                 num_subscripts_, ") for a ", in.sample_dim(), "-D tensor."));
    auto &out = ws.template OutputRef<Backend>(0);
    out.ShareData(&in);
  }

  int num_subscripts_ = 0;
};

DALI_REGISTER_OPERATOR(SubscriptDimCheck, SubscriptDimCheck<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(SubscriptDimCheck, SubscriptDimCheck<GPUBackend>, GPU);

} // namespace dali
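For context, this is how the new operator surfaces through the Python subscript syntax: when every index is a full-range ':', the expression builds a SubscriptDimCheck node, which validates the dimensionality at run time and passes its input through unchanged. A minimal sketch (the pipeline name and parameters are made up for the example; the error text is the one produced by RunImpl above):

import numpy as np
import nvidia.dali as dali
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=1, num_threads=1, device_id=0)
def dim_check_pipe():
    data = dali.types.Constant(np.float32([1, 2, 3]))  # a 1-D tensor
    return data[:, :]  # two subscripts -> SubscriptDimCheck(num_subscripts=2)

pipe = dim_check_pipe()
pipe.build()
pipe.run()  # raises RuntimeError: Too many indices (2) for a 1-D tensor.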
24 changes: 12 additions & 12 deletions dali/operators/generic/slice/subscript.h
@@ -181,8 +181,7 @@ class TensorSubscript : public Operator<Backend> {
make_string("Too many indices (", nsub_declared_, ") for a ", ndim, "D tensor."));

int nsamples = in_shape.num_samples();
-    lo_.resize(nsamples, ndim);
-    hi_ = in_shape;
+    start_.resize(nsamples, ndim);
step_.resize(nsamples, ndim);
for (auto &s : step_.shapes)
s = 1;
@@ -207,8 +206,7 @@
DALI_FAIL(make_string("Index ", at, " is out of range "
"for axis ", d, " of length ", in_extent, "\n"
"Detected while processing sample #", i, " of shape (", in_shape[i], ")"));
-        lo_.tensor_shape_span(i)[d] = idx;
-        hi_.tensor_shape_span(i)[d] = idx + 1;
+        start_.tensor_shape_span(i)[d] = idx;
shape_.tensor_shape_span(i)[d] = 1;
}
}
@@ -218,15 +216,18 @@
int64_t lo = s.lo.IsDefined() ? s.lo.values[i] : 0;
int64_t hi = s.hi.IsDefined() ? s.hi.values[i] : in_extent;
int64_t step = s.step.IsDefined() ? s.step.values[i] : 1;
// TODO(michalz) Remove when strides are supported
DALI_ENFORCE(step == 1, "Indexing with non-unit step is not implemented");
if (lo < 0) lo += in_extent;
if (hi < 0) hi += in_extent;
lo = clamp(lo, 0_i64, in_extent);
hi = clamp(hi, 0_i64, in_extent);
-      lo_.tensor_shape_span(i)[d] = lo;
-      hi_.tensor_shape_span(i)[d] = hi;
+      start_.tensor_shape_span(i)[d] = lo;
       step_.tensor_shape_span(i)[d] = step;

// NOTE: this code is currently not used, since the underlying kernels
// don't support strides.
// TODO(michalz): Remove this comment when strides are supported.
int64_t out_extent = step > 0 ? div_ceil(hi - lo, step)
: div_ceil(lo - hi, -step);
if (out_extent < 0)
@@ -265,14 +266,14 @@

collapse_dims(simplified_in_shape_, in_shape, collapsed_dims_);
collapse_dims(simplified_out_shape_, shape_, collapsed_dims_);
-  collapse_dims(simplified_anchor_, lo_, collapsed_dims_);
+  collapse_dims(simplified_anchor_, start_, collapsed_dims_);

out_shape.resize(in_shape.num_samples(), out_dim_map_.size());
for (int i = 0; i < out_shape.num_samples(); i++) {
auto out_sample_shape = out_shape.tensor_shape_span(i);
-    auto sh = shape_.tensor_shape_span(i);
+    auto sample_shape = shape_.tensor_shape_span(i);
     for (int d = 0; d < out_shape.sample_dim(); d++) {
-      out_sample_shape[d] = sh[out_dim_map_[d]];
+      out_sample_shape[d] = sample_shape[out_dim_map_[d]];
}
}
}
@@ -300,7 +301,7 @@
(RunTyped<ndim, element_size>(ws);),
(DALI_FAIL(make_string("Unsupported input type: ", input.type().id()));))),
(DALI_FAIL("Subscript too complex.\n"
"The subscript operator supports up to 32 total and up to 16 non-collapsible dimensions.\n"
"The subscript operator supports up to 32 total and up to 16 non-collapsible dimensions.\n"
"Adjacent dimensions from which no index or slice is taken can be collapsed.");)
); // NOLINT
}
@@ -315,7 +316,7 @@

// Ranges, steps and output shapes in input space - that is, not including
// the dimensions which are removed by indexing or ones collapsed as a result of simplification.
-  TensorListShape<> lo_, hi_, step_, shape_;
+  TensorListShape<> start_, step_, shape_;

// Grouping of indices, used for simplification
SmallVector<std::pair<int, int>, 6> collapsed_dims_;
@@ -331,7 +332,6 @@
any ctx_;
};


} // namespace dali

#endif // DALI_OPERATORS_GENERIC_SLICE_SUBSCRIPT_H_
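As a reading aid, the slice arithmetic in SetupImpl above can be modeled in a few lines of Python (a sketch of the same math, not DALI code; div_ceil is integer ceiling division):

def slice_extent(lo, hi, step, in_extent):
    """Mirrors the start/step/extent math in TensorSubscript::SetupImpl."""
    # Missing bounds default to the full range; a missing step defaults to 1.
    lo = 0 if lo is None else lo
    hi = in_extent if hi is None else hi
    step = 1 if step is None else step
    # Negative indices count from the end, then the range is clamped.
    if lo < 0:
        lo += in_extent
    if hi < 0:
        hi += in_extent
    lo = min(max(lo, 0), in_extent)
    hi = min(max(hi, 0), in_extent)
    # Ceiling division; a negative step walks the range downward; empty ranges yield 0.
    return max(0, -((lo - hi) // step)) if step > 0 else max(0, -((hi - lo) // -step))

assert slice_extent(None, None, None, 5) == 5  # x[:]
assert slice_extent(1, -1, None, 5) == 3       # x[1:-1]
assert slice_extent(None, None, 2, 5) == 3     # x[::2] - rejected until strides land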
2 changes: 1 addition & 1 deletion dali/pipeline/data/tensor_vector.cc
@@ -353,7 +353,7 @@ void TensorVector<Backend>::ShareData(TensorVector<Backend> *tv) {
state_ = tv->state_;
pinned_ = tv->is_pinned();

-  if (tv->tl_->raw_data()) {
+  if (IsValidType(tv->tl_->type())) {
tl_->ShareData(tv->tl_.get());
} else {
tl_->Reset();
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/_utils/hacks.py
@@ -37,7 +37,7 @@ def _hook_iterable_check():
_not_iterable = (_NotIterable,)

def not_iterable(cls, add_iter=True):
"""Makes an object non-iterable by raising a TypeError in add_iter and suppressing
"""Makes an object non-iterable by raising a TypeError in __iter__ and suppressing
the detection of the object as an instance of collections.abc.Iterable.
"""
_hook_iterable_check()
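For readers unfamiliar with this hack, the mechanism is roughly the following (a simplified, self-contained sketch; the actual implementation in hacks.py differs in detail):

import collections.abc

_blocked = set()
_orig_hook = collections.abc.Iterable.__subclasshook__.__func__

def _patched_hook(cls, C):
    # Classes registered as non-iterable never match the Iterable ABC,
    # even though they still define __iter__.
    if C in _blocked:
        return False
    return _orig_hook(cls, C)

collections.abc.Iterable.__subclasshook__ = classmethod(_patched_hook)

def not_iterable(cls):
    def _iter(self):
        raise TypeError(f"'{cls.__name__}' object is not iterable")
    cls.__iter__ = _iter  # iterating such an object now raises TypeError
    _blocked.add(cls)

class X:
    def __iter__(self):
        pass

not_iterable(X)
assert not isinstance(X(), collections.abc.Iterable)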
8 changes: 6 additions & 2 deletions dali/python/nvidia/dali/data_node.py
@@ -192,9 +192,13 @@ def process_index(idx, dim):

import nvidia.dali.fn
 if len(slice_args) == 0:
-    sliced = self
+    if len(new_axes) == 0 or new_axes[-1] < len(idxs):
+        print("Adding dim check for ", len(idxs))
+        sliced = nvidia.dali.fn.subscript_dim_check(self, num_subscripts=len(idxs))
+    else:
+        print("Dim check will be performed by expand_dims.")
+        sliced = self
 else:
     print("num_subscripts", len(idxs))
     sliced = nvidia.dali.fn.tensor_subscript(self, **slice_args, num_subscripts=len(idxs))
if len(new_axes) == 0:
return sliced
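The effect of this dispatch, illustrated on a 2-D input (an illustrative sketch; the pipeline and the constant input are made up for the example):

import numpy as np
import nvidia.dali as dali
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=1, num_threads=1, device_id=0)
def dispatch_pipe():
    x = dali.types.Constant(np.zeros((2, 3), dtype=np.float32))  # 2-D input
    a = x[0]                   # real subscript -> fn.tensor_subscript
    b = x[:, :]                # only ':' -> fn.subscript_dim_check (pass-through)
    c = x[:, :, dali.newaxis]  # trailing newaxis -> fn.expand_dims does the check
    return a, b, c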
16 changes: 12 additions & 4 deletions dali/test/python/test_operator_subscript.py
@@ -78,7 +78,7 @@ def test_swapped_ends():
 def test_noop():
     node = dali.types.Constant(np.float32([1,2,2]))
     indexed = node[:]
-    assert node is indexed
+    assert "SubscriptDimCheck" in indexed.name

def test_runtime_indexing():
def data_gen():
@@ -151,16 +151,24 @@ def test_out_of_range():


 def _test_too_many_indices(device):
-    # NOTE: There's a known issue that the number of dimensions is not checked if the
-    # tensor is indexed only with full-range subscripts (':') - in this case, the whole operation
-    # is a no-op.
     data = [np.uint8([1,2,3]), np.uint8([1,2])]
     src = fn.external_source(lambda: data, device=device)
     pipe = index_pipe(src, lambda x: x[1,:])
     with assert_raises(RuntimeError):
         pipe.build()
         _ = pipe.run()
 
+    pipe = index_pipe(src, lambda x: x[:,:])
+    with assert_raises(RuntimeError):
+        pipe.build()
+        _ = pipe.run()
+
+    pipe = index_pipe(src, lambda x: x[:,:,dali.newaxis])
+    with assert_raises(RuntimeError):
+        pipe.build()
+        _ = pipe.run()


def test_too_many_indices():
for device in ["cpu", "gpu"]:
yield _test_too_many_indices, device
17 changes: 16 additions & 1 deletion dali/test/python/test_pipeline.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1609,6 +1609,21 @@ def my_pipe():

p = my_pipe(device_id=0, seed=1234, num_threads=3, set_affinity=True, py_num_workers=3)

def test_not_iterable():
    import nvidia.dali._utils.hacks as hacks
    import collections.abc
    class X:
        def __iter__(self):
            pass
    class Y:
        def __iter__(self):
            pass
    assert isinstance(X(), collections.abc.Iterable)
    hacks.not_iterable(X)
    assert not isinstance(X(), collections.abc.Iterable)
    assert isinstance(Y(), collections.abc.Iterable)
    hacks.not_iterable(Y)
    assert not isinstance(Y(), collections.abc.Iterable)

@pipeline_def(batch_size=1, num_threads=1, device_id=0)
def _identity_pipe():
15 changes: 0 additions & 15 deletions dali/test/python/test_utils.py
@@ -18,7 +18,6 @@
import nvidia.dali.types as types
import nvidia.dali as dali
from nvidia.dali.backend_impl import TensorListGPU, TensorGPU, TensorListCPU
-import nvidia.dali._utils.hacks as hacks

import tempfile
import subprocess
@@ -27,20 +26,6 @@
import random
import re

def test_not_iterable():
    import collections.abc
    class X:
        def __iter__(self):
            pass
    class Y:
        def __iter__(self):
            pass
    assert isinstance(X(), collections.abc.Iterable)
    hacks.not_iterable(X)
    assert not isinstance(X(), collections.abc.Iterable)
    assert isinstance(Y(), collections.abc.Iterable)


def get_dali_extra_path():
try:
dali_extra_path = os.environ['DALI_EXTRA_PATH']
