Skip to content

Commit

Permalink
Add comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
  • Loading branch information
mzient committed Aug 2, 2021
1 parent 2391452 commit 981923a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
21 changes: 21 additions & 0 deletions dali/operators/generic/slice/subscript.h
Expand Up @@ -174,6 +174,9 @@ class TensorSubscript : public Operator<Backend> {
return true;
}

/**
* @brief Calculates the input ranges from the arguments and the input shape.
*/
void GetRanges(const workspace_t<Backend> &ws, const TensorListShape<> &in_shape) {
int nsub = subscripts_.size();
int ndim = in_shape.sample_dim();
Expand All @@ -194,6 +197,9 @@ class TensorSubscript : public Operator<Backend> {
}
}

/**
* @brief Calculates the input range for one dimension `d`
*/
void GetRange(SubscriptInfo &s, int d, const TensorListShape<> &in_shape) {
int nsamples = in_shape.num_samples();

Expand Down Expand Up @@ -237,6 +243,21 @@ class TensorSubscript : public Operator<Backend> {
}
}

/**
* @brief Produces output shape as well as some intermediate shapes and anchors
*
* There are three shape spaces:
* - Input shape space
* - Output shape space - with the dimensions indexed by scalars removed
* - Simplified shape space - the scalar-indexed dimensions are kept, but the
* adjacent non-sliced dimensions are collapsed to facilitate processing.
*
* The output shape is, obviously, in output space.
* There's also a smplified output shape, simplified input shape and anchors, all in the
* simplified shape space.
* This function calculates the collapsed groups for simplification as well as calculates
* all the aforementioned shapes and achors.
*/
void ProcessShapes(TensorListShape<> &out_shape, const TensorListShape<> &in_shape) {
int in_dims = in_shape.sample_dim();
int nsub = subscripts_.size();
Expand Down
5 changes: 4 additions & 1 deletion dali/test/python/test_operator_subscript.py
Expand Up @@ -154,21 +154,24 @@ def _test_too_many_indices(device):
data = [np.uint8([1,2,3]),np.uint8([1,2])]
src = fn.external_source(lambda: data, device=device)
pipe = index_pipe(src, lambda x: x[1,:])

# Verified by tensor_subscript
with assert_raises(RuntimeError):
pipe.build()
_ = pipe.run()

# Verified by subscript_dim_check
pipe = index_pipe(src, lambda x: x[:,:])
with assert_raises(RuntimeError):
pipe.build()
_ = pipe.run()

# Verified by expand_dims
pipe = index_pipe(src, lambda x: x[:,:,dali.newaxis])
with assert_raises(RuntimeError):
pipe.build()
_ = pipe.run()


def test_too_many_indices():
for device in ["cpu", "gpu"]:
yield _test_too_many_indices, device
Expand Down

0 comments on commit 981923a

Please sign in to comment.