Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fix scalar batch handling in arithmetic ops (#1449)
Adjust the test that decides whether a Tensor should
be considered a scalar: it now requires a uniform shape
of 1-dim, single-element samples.
Allow a scalar batch to be broadcast properly in the
operation — no tile offsets for scalar-like data; it
follows the same code path as constants.

Return shape is always a batch - fixed for Constant/Scalar
inputs.

Add C++ unit tests for Scalar Batch.
Test if Pipeline behaves properly when switching between
scalar and non-scalar inputs.

The Python test was extended to cover scalar inputs,
and the size of some of the tests was reduced.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Nov 8, 2019
1 parent a5c275a commit b33c5ac
Show file tree
Hide file tree
Showing 6 changed files with 400 additions and 128 deletions.
16 changes: 9 additions & 7 deletions dali/operators/expressions/arithmetic.h
Expand Up @@ -145,7 +145,8 @@ inline std::vector<ExprImplTask> CreateExecutionTasks(const ExprNode &expr, Expr
return result;
}

inline TensorListShape<> ShapePromotion(std::string op, span<const TensorListShape<> *> shapes) {
inline TensorListShape<> ShapePromotion(std::string op, span<const TensorListShape<> *> shapes,
int batch_size) {
const TensorListShape<> *out_shape = nullptr;
for (int i = 0; i < shapes.size(); i++) {
if (IsScalarLike(*shapes[i]))
Expand All @@ -159,14 +160,15 @@ inline TensorListShape<> ShapePromotion(std::string op, span<const TensorListSha
*out_shape, ", ", *shapes[i], ")."));
}
}
return out_shape ? *out_shape : TensorListShape<>{{1}};
return out_shape ? *out_shape : uniform_list_shape(batch_size, {1});
}

template <typename Backend>
DLL_PUBLIC inline const TensorListShape<> &PropagateShapes(ExprNode &expr,
const workspace_t<Backend> &ws) {
const workspace_t<Backend> &ws,
int batch_size) {
if (expr.GetNodeType() == NodeType::Constant) {
expr.SetShape(TensorListShape<>{{1}});
expr.SetShape(uniform_list_shape(batch_size, {1}));
return expr.GetShape();
}
if (expr.GetNodeType() == NodeType::Tensor) {
Expand All @@ -182,9 +184,9 @@ DLL_PUBLIC inline const TensorListShape<> &PropagateShapes(ExprNode &expr,
SmallVector<const TensorListShape<> *, kMaxArity> shapes;
shapes.resize(subexpression_count);
for (int i = 0; i < subexpression_count; i++) {
shapes[i] = &PropagateShapes<Backend>(func[i], ws);
shapes[i] = &PropagateShapes<Backend>(func[i], ws, batch_size);
}
func.SetShape(ShapePromotion(func.GetFuncName(), make_span(shapes)));
func.SetShape(ShapePromotion(func.GetFuncName(), make_span(shapes), batch_size));
return func.GetShape();
}

Expand Down Expand Up @@ -242,7 +244,7 @@ class ArithmeticGenericOp : public Operator<Backend> {
types_layout_inferred_ = true;
}

result_shape_ = PropagateShapes<Backend>(*expr_, ws);
result_shape_ = PropagateShapes<Backend>(*expr_, ws, batch_size_);
AllocateIntermediateNodes();
exec_order_ = CreateExecutionTasks<Backend>(*expr_, cache_, ws.has_stream() ? ws.stream() : 0);

Expand Down
2 changes: 1 addition & 1 deletion dali/operators/expressions/arithmetic_meta.h
Expand Up @@ -429,7 +429,7 @@ inline ArithmeticOp NameToOp(const std::string &op_name) {
}

/// @brief Checks whether a batch described by @p shape is scalar-like.
///
/// A batch is treated as scalar-like when every sample has the same shape
/// (uniform batch) and each sample is a 1-dimensional tensor containing
/// exactly one element. Such batches can be broadcast in arithmetic
/// expressions the same way as constants.
///
/// @param shape The shape of the batch (TensorListShape) to inspect.
/// @return true if every sample is a 1-dim, single-element tensor.
inline bool IsScalarLike(const TensorListShape<> &shape) {
  // With a uniform shape it suffices to inspect sample 0:
  // rank must be 1 and the only extent must be 1.
  return is_uniform(shape) && shape.sample_dim() == 1 && shape.tensor_shape_span(0)[0] == 1;
}

} // namespace dali
Expand Down

0 comments on commit b33c5ac

Please sign in to comment.