Skip to content

Commit

Permalink
Pad operator: Add support for per-sample shape and alignment requirem…
Browse files Browse the repository at this point in the history
…ents

Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Nov 3, 2020
1 parent 359a6a5 commit 2427804
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 66 deletions.
6 changes: 4 additions & 2 deletions dali/operators/generic/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ If an integer value is provided, the alignment restrictions are applied to all t
To use alignment only, that is without any default or explicit padding behavior,
set the minimum ``shape`` to 1 for the specified axis.)code",
std::vector<int>())
std::vector<int>(), true)
.AddOptionalArg<int>("shape",
R"code(The extents of the output shape in the axes specified by the ``axes`` or ``axis_names``.
Expand All @@ -180,7 +180,7 @@ the aligned size of the largest sample in the batch.
If the provided extent is smaller than the one of the samples, padding will be applied
only to match the required alignment. For example, to disable padding in an axis, except
for the necessary for alignment, you can specify a value of 1.)code",
vector<int>());
std::vector<int>(), true);

template <>
bool Pad<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
Expand All @@ -193,6 +193,8 @@ bool Pad<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
int nsamples = in_shape.num_samples();
auto nthreads = ws.GetThreadPool().size();

ReadArguments(spec_, ws);

TYPE_SWITCH(input.type().id(), type2id, T, PAD_SUPPORTED_TYPES, (
VALUE_SWITCH(ndim, Dims, PAD_SUPPORTED_NDIMS, (
using Kernel = kernels::SliceCPU<T, T, Dims>;
Expand Down
2 changes: 2 additions & 0 deletions dali/operators/generic/pad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ bool Pad<GPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
int ndim = in_shape.sample_dim();
int nsamples = in_shape.num_samples();

this->ReadArguments(spec_, ws);

TYPE_SWITCH(input.type().id(), type2id, T, PAD_SUPPORTED_TYPES, (
VALUE_SWITCH(ndim, Dims, PAD_SUPPORTED_NDIMS, (
using Kernel = kernels::SliceGPU<T, T, Dims>;
Expand Down
133 changes: 82 additions & 51 deletions dali/operators/generic/pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define DALI_OPERATORS_GENERIC_PAD_H_

#include <cstring>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -46,15 +47,6 @@ class Pad : public Operator<Backend> {
} else if (has_axes) {
axes_ = spec.GetRepeatedArgument<int>("axes");
}

if (spec.HasArgument("align")) {
align_ = spec.GetRepeatedArgument<int>("align");
}

if (spec.HasArgument("shape")) {
auto shape_arg = spec.GetRepeatedArgument<int>("shape");
shape_ = std::vector<int64_t>{std::begin(shape_arg), std::end(shape_arg)};
}
}

protected:
Expand All @@ -68,64 +60,98 @@ class Pad : public Operator<Backend> {
}

private:
static inline int64_t aligned_extent(int64_t extent, int alignment) {
int64_t remainder = extent % alignment;
if (remainder > 0)
extent += alignment - remainder;
return extent;
void ReadShapeListArg(std::vector<TensorShape<>> &out, const std::string &arg_name,
const OpSpec &spec, const ArgumentWorkspace &ws,
int nsamples) {
out.clear();
out.resize(nsamples);
if (spec.HasTensorArgument(arg_name)) {
auto &arg_in = ws.ArgumentInput(arg_name);
DALI_ENFORCE(nsamples == static_cast<int>(arg_in.size()),
make_string("Expected ", nsamples, ", got ", arg_in.size()));
int nsamples = arg_in.size();
auto arg_view = view<const int>(arg_in);
for (int i = 0; i < nsamples; i++) {
out[i] = TensorShape<>(make_cspan(arg_view[i].data, arg_view[i].shape[0]));
}
} else if (spec.HasArgument(arg_name)) {
auto arg = spec.GetRepeatedArgument<int>(arg_name);
TensorShape<> sh(make_cspan(arg));
for (int i = 0; i < nsamples; i++) {
out[i] = sh;
}
}
assert(static_cast<int>(out.size()) == nsamples);
}

template <typename Args>
std::vector<Args>& FillArgs(TensorListShape<> in_shape, TensorLayout in_layout) {
int nsamples = in_shape.num_samples();
void ReadArguments(const OpSpec &spec, const workspace_t<Backend> &ws) {
const auto &input = ws.template InputRef<Backend>(0);
auto in_shape = input.shape();
auto in_layout = input.GetLayout();
int ndim = in_shape.sample_dim();
int nsamples = in_shape.num_samples();

if (!axis_names_.empty()) {
axes_ = GetDimIndices(in_layout, axis_names_).to_vector();
}

for (auto axis : axes_) {
DALI_ENFORCE(axis >= 0 && axis < ndim,
make_string("specified axis is out of bounds. axis=", axis, ", ndim=", ndim));
}

if (!axis_names_.empty()) {
axes_ = GetDimIndices(in_layout, axis_names_).to_vector();
}

// If no axes are provided, use all
if (axes_.empty()) {
axes_.resize(ndim);
std::iota(axes_.begin(), axes_.end(), 0);
}

// If a single *align* value is provided, use this value for all the axes
for (auto &align : align_) {
DALI_ENFORCE(align > 0, "Values of `align` argument must be positive.");
}
if (align_.size() == 1 && axes_.size() > 1) {
align_.resize(axes_.size(), align_[0]);
}

TensorShape<> padded_shape;
padded_shape.resize(ndim);
ReadShapeListArg(shape_, "shape", spec, ws, nsamples);
ReadShapeListArg(align_, "align", spec, ws, nsamples);

// If no shape arg is provided, mark all axis as -1 (automatic shape)
if (shape_.empty()) {
shape_ = std::vector<int64_t>(axes_.size(), -1);
// If a single *align* value is provided, use this value for all the axes
for (int i = 0; i < nsamples; i++) {
auto &align = align_[i];
for (auto &a : align) {
DALI_ENFORCE(a > 0, "Values of `align` argument must be positive.");
}
if (align.size() == 1 && axes_.size() > 1) {
align = std::vector<int64_t>(axes_.size(), align[0]);
}
auto &shape = shape_[i];
// If no shape arg is provided, mark all axis as -1 (automatic shape)
if (shape.empty()) {
shape = std::vector<int64_t>(axes_.size(), -1);
}
DALI_ENFORCE(
static_cast<int>(axes_.size()) == shape.size(),
make_string(
"If explicit shape is provided, there should be a value per axis to be padded. "
"Expected ",
axes_.size(), " values for shape, got ", shape.size()));
}
}

DALI_ENFORCE(static_cast<int>(axes_.size()) == shape_.size(),
make_string("If explicit shape is provided, there should be a value per axis to be padded. "
"Expected ", axes_.size(), " values for shape, got ", shape_.size()));
static inline int64_t aligned_extent(int64_t extent, int alignment) {
int64_t remainder = extent % alignment;
if (remainder > 0)
extent += alignment - remainder;
return extent;
}

template <typename Args>
std::vector<Args>& FillArgs(TensorListShape<> in_shape, TensorLayout in_layout) {
int nsamples = in_shape.num_samples();
int ndim = in_shape.sample_dim();
int naxes = axes_.size();
assert(naxes <= ndim);

TensorShape<> largest_shape = std::vector<int64_t>(ndim, -1);
for (int i = 0; i < naxes; i++) {
auto axis = axes_[i];
padded_shape[axis] = shape_[i];
if (padded_shape[axis] < 0) {
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
auto data_shape = in_shape[sample_idx];
if (data_shape[axis] > padded_shape[axis])
padded_shape[axis] = data_shape[axis];
}
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
auto data_shape = in_shape.tensor_shape_span(sample_idx);
largest_shape[axis] = std::max(largest_shape[axis], data_shape[axis]);
}
}

Expand All @@ -139,7 +165,9 @@ class Pad : public Operator<Backend> {

for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
auto &sample_args = kernel_sample_args[sample_idx];
const auto &sample_shape = in_shape[sample_idx];
const auto &sample_shape = in_shape.tensor_shape_span(sample_idx);
const auto &req_shape = shape_[sample_idx];
const auto &req_align = align_[sample_idx];
for (int d = 0; d < sample_args.anchor.size(); d++) {
sample_args.anchor[d] = 0;
sample_args.shape[d] = sample_shape[d];
Expand All @@ -152,11 +180,14 @@ class Pad : public Operator<Backend> {
auto &extent = sample_args.shape[axis];
// Adjust padded extent only if it is bigger than the sample's extent
// That is, we are not cropping the image
if (padded_shape[axis] > extent)
extent = padded_shape[axis];
if (req_shape[axis] > 0) { // explicitly requested padded shape
extent = std::max(req_shape[axis], extent);
} else { // pad to largest
extent = std::max(largest_shape[axis], extent);
}
// Adjust alignment if required
if (!align_.empty())
extent = aligned_extent(extent, align_[i]);
if (!req_align.empty())
extent = aligned_extent(extent, req_align[i]);
}
}

Expand All @@ -166,8 +197,8 @@ class Pad : public Operator<Backend> {
float fill_value_;
TensorLayout axis_names_;
std::vector<int> axes_;
std::vector<int> align_;
TensorShape<> shape_;
std::vector<TensorShape<>> align_;
std::vector<TensorShape<>> shape_;
kernels::KernelManager kmgr_;
any kernel_sample_args_;

Expand Down
2 changes: 1 addition & 1 deletion dali/operators/random/uniform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ This argument is mutually exclusive with ``values``.
This argument is mutually exclusive with ``range``.
)code", std::vector<float>({}))
.AddOptionalArg("shape",
R"code(Shape of the samples.)code", std::vector<int>{1});
R"code(Shape of the samples.)code", std::vector<int>{1}, true);


} // namespace dali
25 changes: 17 additions & 8 deletions dali/operators/random/uniform.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ class Uniform : public Operator<CPUBackend> {
DALI_ENFORCE(range_[0] < range_[1],
"Invalid range. It shall be left-closed [a, b), where a < b");
}

std::vector<int> shape_arg{};
if (spec.HasArgument("shape"))
shape_arg = spec.GetRepeatedArgument<int>("shape");
shape_ = std::vector<int64_t>{std::begin(shape_arg), std::end(shape_arg)};
}

inline ~Uniform() override = default;
Expand All @@ -63,12 +58,27 @@ class Uniform : public Operator<CPUBackend> {

bool SetupImpl(std::vector<OutputDesc> &output_desc, const HostWorkspace &ws) override {
output_desc.resize(1);
output_desc[0].shape = uniform_list_shape(batch_size_, shape_);
output_desc[0].type = TypeTable::GetTypeInfo(DALI_FLOAT);
auto& sh = output_desc[0].shape;
if (spec_.HasTensorArgument("shape")) {
auto &sh_arg_in = ws.ArgumentInput("shape");
int nsamples = sh_arg_in.size();
assert(nsamples > 0);
auto sh_view = view<const int>(sh_arg_in);
DALI_ENFORCE(is_uniform(sh_view.shape) && sh_view.shape[0].size() == 1,
"Shapes are expected to have the same number of dimensions");
int ndim = sh_view.shape.tensor_shape_span(0)[0];
sh.resize(nsamples, ndim);
for (int i = 0; i < nsamples; i++) {
sh.set_tensor_shape(i, TensorShape<>(make_cspan(sh_view[i].data, sh_view[i].shape[0])));
}
} else {
auto shape_arg = spec_.GetRepeatedArgument<int>("shape");
sh = uniform_list_shape(batch_size_, TensorShape<>(make_cspan(shape_arg)));
}
return true;
}


void RunImpl(HostWorkspace &ws) override;

private:
Expand All @@ -79,7 +89,6 @@ class Uniform : public Operator<CPUBackend> {
std::mt19937 rng_;
const bool discrete_mode_; // mode can't change throughout lifetime of this op, due to RNG
std::vector<float> range_, set_;
TensorShape<> shape_;
};

} // namespace dali
Expand Down
52 changes: 48 additions & 4 deletions dali/test/python/test_operator_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali as dali
from nvidia.dali.backend_impl import TensorListGPU
Expand Down Expand Up @@ -46,7 +47,7 @@ def iter_setup(self):
self.feed_input(self.data, data, layout=self.layout)


def check_pad_synth_data(device, batch_size, input_max_shape, axes, axis_names, align, shape_arg):
def check_pad(device, batch_size, input_max_shape, axes, axis_names, align, shape_arg):
eii = RandomlyShapedDataIterator(batch_size, max_shape=input_max_shape)
layout = "HWC"
pipe = PadSynthDataPipeline(device, batch_size, iter(eii), axes=axes, axis_names=axis_names,
Expand Down Expand Up @@ -107,7 +108,7 @@ def check_pad_synth_data(device, batch_size, input_max_shape, axes, axis_names,
expected_extent = expected_extent - remainder + align_val
assert(output_shape[dim] == expected_extent)

def test_slice_synth_data_vs_numpy():
def test_pad():
for device in ["cpu", "gpu"]:
for batch_size in {1, 8}:
for input_max_shape, axes, axis_names, align, shape_arg in \
Expand All @@ -131,9 +132,9 @@ def test_slice_synth_data_vs_numpy():
((200, 400, 3), None, None, None, (-1, -1, 4)),
((25, 100, 3), (0,), None, None, (25,)),
((200, 400, 3), (0, 1), None, (4, 16), (1, 200))]:
yield check_pad_synth_data, device, batch_size, input_max_shape, axes, axis_names, align, shape_arg
yield check_pad, device, batch_size, input_max_shape, axes, axis_names, align, shape_arg

def test_pad_fail():
def test_pad_error():
batch_size = 2
input_max_shape = (5, 5, 3)
device = "cpu"
Expand All @@ -149,3 +150,46 @@ def test_pad_fail():

pipe.build()
assert_raises(RuntimeError, pipe.run)

def is_aligned(sh, align):
assert len(sh) == len(align)
for i in range(len(sh)):
if sh[i] % align[i] > 0:
return False
return True

def check_pad_per_sample_shapes_and_alignment(device='cpu', batch_size=3, ndim=2, num_iter=3):
pipe = Pipeline(batch_size=batch_size, num_threads=3, device_id=0, seed=1234)
with pipe:
in_shape = fn.cast(fn.uniform(range=(10, 20), shape=(ndim,)), dtype=types.INT32)
in_data = fn.uniform(range=(0., 1.), shape=in_shape)
req_shape = fn.cast(fn.uniform(range=(21, 30), shape=(ndim,)), dtype=types.INT32)
req_align = fn.cast(fn.uniform(range=(3, 5), shape=(ndim,)), dtype=types.INT32)
out_pad_shape = fn.pad(in_data, axes=[0, 1], axis_names=None, align=None, shape=req_shape)
out_pad_align = fn.pad(in_data, axes=[0, 1], axis_names=None, align=req_align, shape=None)
out_pad_both = fn.pad(in_data, axes=[0, 1], axis_names=None, align=req_align, shape=req_shape)
pipe.set_outputs(in_shape, in_data, req_shape, req_align, out_pad_shape, out_pad_align, out_pad_both)
pipe.build()
for _ in range(num_iter):
outs = pipe.run()
for i in range(batch_size):
in_shape, in_data, req_shape, req_align, out_pad_shape, out_pad_align, out_pad_both = \
[outs[out_idx].at(i) for out_idx in range(len(outs))]
assert (in_shape == in_data.shape).all()

# Pad to explicit shape
assert (out_pad_shape.shape >= in_shape).all()
assert (req_shape == out_pad_shape.shape).all()

# Alignment only
assert (out_pad_align.shape >= in_shape).all()
assert is_aligned(out_pad_align.shape, req_align)

# Explicit shape + alignment
assert (out_pad_both.shape >= in_shape).all()
assert (req_shape <= out_pad_both.shape).all()
assert is_aligned(out_pad_both.shape, req_align)

def test_pad_per_sample_shapes_and_alignment():
yield check_pad_per_sample_shapes_and_alignment, 'cpu'
yield check_pad_per_sample_shapes_and_alignment, 'gpu'
7 changes: 7 additions & 0 deletions include/dali/core/tensor_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,15 @@ struct TensorShape<DynamicDimensions>
: public TensorShapeBase<DynamicTensorShapeContainer, DynamicDimensions> {
using Base = TensorShapeBase<DynamicTensorShapeContainer, DynamicDimensions>;

TensorShape(span<const int64_t> s) // NOLINT
: Base(DynamicTensorShapeContainer(s.begin(), s.end())) {}

TensorShape(span<const int> s) // NOLINT
: Base(DynamicTensorShapeContainer(s.begin(), s.end())) {}

TensorShape(const std::vector<int64_t> &s) // NOLINT
: Base(DynamicTensorShapeContainer(s.data(), s.size())) {}

TensorShape(const DynamicTensorShapeContainer &s) : Base(s) {} // NOLINT

template <size_t N>
Expand Down

0 comments on commit 2427804

Please sign in to comment.