
Commit

Added Stride to Subscript and Slice Kernel (NVIDIA#5007)
Enables NumPy-style slicing with strides in the tensor subscript operator by supporting a `steps` member in the slice params.

Signed-off-by: Bryce Ferenczi <frenzi@hotmail.com.au>
Co-authored-by: Michal Zientkiewicz <michalz@nvidia.com>
2 people authored and JanuszL committed Oct 13, 2023
1 parent 71a6f17 commit 8b0ee6d
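
The commit description refers to NumPy-style strided slicing. As a hand-written illustration (not code from this change), a one-dimensional slice `start:stop:step` with a positive step maps onto the per-dimension anchor/shape/step parameters consumed by the kernel below roughly as follows; the struct and helper names here are invented for the example:

```cpp
#include <cstdint>

// Illustration only: how a Python-style slice start:stop:step (positive step)
// could translate into per-dimension slice-kernel parameters.
struct DimSlice {
  int64_t anchor;  // first input index to read (start)
  int64_t shape;   // number of output elements
  int64_t step;    // distance between consecutive input reads
};

DimSlice MakeDimSlice(int64_t start, int64_t stop, int64_t step) {
  DimSlice s;
  s.anchor = start;
  s.step = step;
  // ceil((stop - start) / step) output elements, as in NumPy
  s.shape = (stop - start + step - 1) / step;
  return s;
}

// x[1:10:3] -> anchor = 1, step = 3, shape = 3 (input indices 1, 4, 7)
```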
Showing 12 changed files with 354 additions and 208 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -26,3 +26,4 @@ docs/op_autodoc
docs/fn_autodoc
docs/nvidia.ico
.DS_Store
build-docker-*
113 changes: 62 additions & 51 deletions dali/kernels/slice/slice_cpu.h
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -40,11 +40,12 @@ namespace slice_impl {
template <typename OutputType, typename InputType, bool OutOfBounds, bool NeedPad>
void SliceKernelImplChannelLast(OutputType *output,
const InputType *input,
const int64_t* out_strides,
const int64_t* in_strides,
const int64_t* out_shape,
const int64_t* in_shape,
const int64_t* anchor,
const int64_t *out_strides,
const int64_t *in_strides,
const int64_t *out_shape,
const int64_t *in_shape,
const int64_t *anchor,
const int64_t *step,
const OutputType *fill_values,
int channel_dim, // negative if no channel dim or already processed
std::integral_constant<bool, OutOfBounds>,
@@ -104,7 +105,7 @@ void SliceKernelImplChannelLast(OutputType *output,
output[out_c] = fill_values[out_c];

output += out_nchannels;
input += in_nchannels;
input += in_nchannels * step[d];
}
}

@@ -129,11 +130,12 @@ void SliceKernelImplChannelLast(OutputType *output,
template <typename OutputType, typename InputType, bool OutOfBounds, bool NeedPad>
void SliceKernelImpl(OutputType *output,
const InputType *input,
const int64_t* out_strides,
const int64_t* in_strides,
const int64_t* out_shape,
const int64_t* in_shape,
const int64_t* anchor,
const int64_t *out_strides,
const int64_t *in_strides,
const int64_t *out_shape,
const int64_t *in_shape,
const int64_t *anchor,
const int64_t *step,
const OutputType *fill_values,
int channel_dim, // negative if no channel dim or already processed
std::integral_constant<int, 1>,
@@ -151,24 +153,26 @@ void SliceKernelImpl(OutputType *output,
int out_idx = 0;

if (NeedPad) {
// out of bounds (left side)
for (; in_idx < 0 && out_idx < out_shape[d]; in_idx++, out_idx++) {
// out of bounds (left side of output)
for (; (in_idx < 0 || in_idx >= in_shape[d]) && out_idx < out_shape[d];
in_idx += step[d], out_idx++) {
output[out_idx] = *fill_values;
if (d == channel_dim)
fill_values++;
}
}

// within input bounds
for (; in_idx < in_shape[d] && out_idx < out_shape[d]; in_idx++, out_idx++) {
for (; (0 <= in_idx && in_idx < in_shape[d]) && out_idx < out_shape[d];
in_idx += step[d], out_idx++) {
output[out_idx] = clamp<OutputType>(input[in_idx]);
if (NeedPad && d == channel_dim)
fill_values++;
}

if (NeedPad) {
// out of bounds (right side)
for (; out_idx < out_shape[d]; in_idx++, out_idx++) {
// out of bounds (right side of output)
for (; out_idx < out_shape[d]; in_idx += step[d], out_idx += out_strides[d]) {
output[out_idx] = *fill_values;
if (d == channel_dim)
fill_values++;
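
Stepping back from the diff for a moment, the net effect of the new `step` argument on this innermost case can be shown with a self-contained sketch. It is a simplification (one dimension, no channel dimension, positive step, and the separate pad/copy/pad loops above collapsed into a single bounds check), not the kernel itself:

```cpp
#include <cstdint>
#include <vector>

// Simplified 1-D model of the innermost slice loop: the output index always
// advances by one, while the input index advances by `step`; out-of-bounds
// input positions produce the fill value (padding).
std::vector<float> Slice1D(const std::vector<float> &in, int64_t anchor,
                           int64_t out_size, int64_t step, float fill_value) {
  std::vector<float> out(out_size);
  int64_t in_idx = anchor;
  for (int64_t out_idx = 0; out_idx < out_size; out_idx++, in_idx += step) {
    bool in_bounds = in_idx >= 0 && in_idx < static_cast<int64_t>(in.size());
    out[out_idx] = in_bounds ? in[in_idx] : fill_value;
  }
  return out;
}

// Slice1D({0, 1, 2, 3, 4, 5}, /*anchor=*/-1, /*out_size=*/4, /*step=*/2, -1.f)
// -> {-1, 1, 3, 5}: one padded element on the left, then every second input.
```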
@@ -180,11 +184,12 @@ void SliceKernelImpl(OutputType *output,
template <typename OutputType, typename InputType, bool OutOfBounds, bool NeedPad, int DimsLeft>
void SliceKernelImpl(OutputType *output,
const InputType *input,
const int64_t* out_strides,
const int64_t* in_strides,
const int64_t* out_shape,
const int64_t* in_shape,
const int64_t* anchor,
const int64_t *out_strides,
const int64_t *in_strides,
const int64_t *out_shape,
const int64_t *in_shape,
const int64_t *anchor,
const int64_t *step,
const OutputType *fill_values,
int channel_dim, // negative if no channel dim or already processed
std::integral_constant<int, DimsLeft>,
@@ -193,7 +198,7 @@ void SliceKernelImpl(OutputType *output,
// Special case for last 2 dimensions with channel-last configuration
if (DimsLeft == 2 && channel_dim == 1) {
SliceKernelImplChannelLast(output, input, out_strides, in_strides, out_shape, in_shape, anchor,
fill_values, channel_dim,
step, fill_values, channel_dim,
std::integral_constant<bool, OutOfBounds>(),
std::integral_constant<bool, NeedPad>());
return;
@@ -207,10 +212,11 @@ void SliceKernelImpl(OutputType *output,
input += anchor[d] * in_strides[d];

if (NeedPad) {
// out of bounds (left side)
for (; in_idx < 0 && out_idx < out_shape[d]; in_idx++, out_idx++) {
// out of bounds (left side of output)
for (; (in_idx < 0 || in_idx >= in_shape[d]) && out_idx < out_shape[d];
in_idx += step[d], out_idx++) {
SliceKernelImpl(output, input, out_strides + 1, in_strides + 1, out_shape + 1,
in_shape + 1, anchor + 1, fill_values, channel_dim - 1,
in_shape + 1, anchor + 1, step + 1, fill_values, channel_dim - 1,
std::integral_constant<int, DimsLeft - 1>(),
std::integral_constant<bool, true>(),
std::integral_constant<bool, NeedPad>());
@@ -221,24 +227,25 @@ }
}

// within input bounds
for (; in_idx < in_shape[d] && out_idx < out_shape[d]; in_idx++, out_idx++) {
for (; (0 <= in_idx && in_idx < in_shape[d]) && out_idx < out_shape[d];
in_idx += step[d], out_idx++) {
SliceKernelImpl(output, input, out_strides + 1, in_strides + 1, out_shape + 1,
in_shape + 1, anchor + 1, fill_values, channel_dim - 1,
in_shape + 1, anchor + 1, step + 1, fill_values, channel_dim - 1,
std::integral_constant<int, DimsLeft - 1>(),
std::integral_constant<bool, OutOfBounds>(),
std::integral_constant<bool, NeedPad>());
output += out_strides[d];
if (!OutOfBounds)
input += in_strides[d];
input += in_strides[d] * step[d];
if (NeedPad && d == channel_dim)
fill_values++;
}

if (NeedPad) {
// out of bounds (right side)
for (; out_idx < out_shape[d]; in_idx++, out_idx++) {
// out of bounds (right side of output)
for (; out_idx < out_shape[d]; out_idx++) {
SliceKernelImpl(output, input, out_strides + 1, in_strides + 1, out_shape + 1,
in_shape + 1, anchor + 1, fill_values, channel_dim - 1,
in_shape + 1, anchor + 1, step + 1, fill_values, channel_dim - 1,
std::integral_constant<int, DimsLeft - 1>(),
std::integral_constant<bool, true>(),
std::integral_constant<bool, NeedPad>());
@@ -260,20 +267,21 @@ void SliceKernel(OutputType *output,
const TensorShape<Dims> &out_shape,
const TensorShape<Dims> &in_shape,
const TensorShape<Dims> &anchor,
const TensorShape<Dims> &step,
const OutputType *fill_values,
int channel_dim = -1) { // negative if no channel dim or already processed
bool need_pad = NeedPad(Dims, anchor.data(), in_shape.data(), out_shape.data());
if (need_pad) {
slice_impl::SliceKernelImpl(
output, input, out_strides.data(), in_strides.data(), out_shape.data(),
in_shape.data(), anchor.data(), fill_values, channel_dim,
output, input, out_strides.data(), in_strides.data(), out_shape.data(), in_shape.data(),
anchor.data(), step.data(), fill_values, channel_dim,
std::integral_constant<int, Dims>(),
std::integral_constant<bool, false>(),
std::integral_constant<bool, true>());
} else {
slice_impl::SliceKernelImpl(
output, input, out_strides.data(), in_strides.data(), out_shape.data(),
in_shape.data(), anchor.data(), fill_values, channel_dim,
output, input, out_strides.data(), in_strides.data(), out_shape.data(), in_shape.data(),
anchor.data(), step.data(), fill_values, channel_dim,
std::integral_constant<int, Dims>(),
std::integral_constant<bool, false>(),
std::integral_constant<bool, false>());
@@ -287,7 +295,7 @@ void SliceKernel(OutputType *output,
template <typename ExecutionEngine, typename OutputType, typename InputType, int Dims>
DLL_LOCAL // workaround for GCC bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80947
void SliceKernel(ExecutionEngine &exec_engine,
OutputType* out_data,
OutputType *out_data,
const InputType *in_data,
const TensorShape<Dims> &out_strides,
const TensorShape<Dims> &in_strides,
@@ -326,28 +334,30 @@ void SliceKernel(ExecutionEngine &exec_engine,

if (nblocks == 1) {
exec_engine.AddWork([=](int) {
SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape,
args.anchor, GetPtr<OutputType>(args.fill_values), args.channel_dim);
SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape, args.anchor,
args.step, GetPtr<OutputType>(args.fill_values), args.channel_dim);
}, kSliceCost * volume(out_shape), false); // do not start work immediately
return;
}

TensorShape<Dims> start; // zero-filled
const auto& end = out_shape;
const auto &end = out_shape;
ForEachBlock(
start, end, split_factor, 0, LastSplitDim(split_factor),
[&](const TensorShape<Dims> &blk_start, const TensorShape<Dims> &blk_end) {
auto output_ptr = out_data;
TensorShape<Dims> blk_anchor;
TensorShape<Dims> blk_shape;
TensorShape<Dims> blk_step;
for (int d = 0; d < Dims; d++) {
output_ptr += blk_start[d] * out_strides[d];
blk_shape[d] = blk_end[d] - blk_start[d];
blk_anchor[d] = args.anchor[d] + blk_start[d];
blk_step[d] = args.step[d];
}
exec_engine.AddWork([=](int) {
SliceKernel(output_ptr, in_data, out_strides, in_strides, blk_shape, in_shape,
blk_anchor, GetPtr<OutputType>(args.fill_values), args.channel_dim);
blk_anchor, blk_step, GetPtr<OutputType>(args.fill_values), args.channel_dim);
}, kSliceCost * volume(blk_shape), false); // do not start work immediately
});
// scheduled work does not start until user calls Run()
@@ -359,15 +369,15 @@ void SliceKernel(ExecutionEngine &exec_engine,
*/
template <typename OutputType, typename InputType, int Dims>
void SliceKernel(SequentialExecutionEngine &exec_engine,
OutputType* out_data,
OutputType *out_data,
const InputType *in_data,
const TensorShape<Dims> &out_strides,
const TensorShape<Dims> &in_strides,
const TensorShape<Dims> &out_shape,
const TensorShape<Dims> &in_shape,
const SliceArgs<OutputType, Dims> &args,
int /* min_blk_sz */ = -1, int /* req_nblocks */ = -1) {
(void) exec_engine;
(void)exec_engine;

// If the output and input data type is the same and the slice arguments take the whole extent
// of the input, then we can simply run memcpy.
@@ -377,17 +387,16 @@ void SliceKernel(SequentialExecutionEngine &exec_engine,
return;
}

SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape,
args.anchor, GetPtr<OutputType>(args.fill_values), args.channel_dim);
SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape, args.anchor,
args.step, GetPtr<OutputType>(args.fill_values), args.channel_dim);
}
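
The whole-extent fast path mentioned at the top of this function now also has to exclude strided slices. The actual condition sits in the collapsed part of the hunk above and may differ; a hypothetical sketch of what such a check needs to cover (same element type, zero anchor, full extent, unit step), using the `TensorShape`/`SliceArgs` types from this file, is:

```cpp
#include <type_traits>

// Hypothetical check (not taken from this commit): a plain memcpy is only
// valid when the element types match and the slice reads the whole input
// contiguously, i.e. zero anchor, full extent and unit step in every dimension.
template <typename OutputType, typename InputType, int Dims>
bool CanUsePlainCopy(const TensorShape<Dims> &out_shape,
                     const TensorShape<Dims> &in_shape,
                     const SliceArgs<OutputType, Dims> &args) {
  if (!std::is_same<OutputType, InputType>::value)
    return false;
  for (int d = 0; d < Dims; d++) {
    if (args.anchor[d] != 0 || args.step[d] != 1 || out_shape[d] != in_shape[d])
      return false;
  }
  return true;
}
```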

template <typename OutputType, typename InputType, int Dims>
class SliceCPU {
public:
static_assert(Dims >= 0, "Dims must be >= 0");

KernelRequirements Setup(KernelContext &context,
const InTensorCPU<InputType, Dims> &in,
KernelRequirements Setup(KernelContext &context, const InTensorCPU<InputType, Dims> &in,
const SliceArgs<OutputType, Dims> &slice_args) {
KernelRequirements req;
auto shape = GetOutputShape(in.shape, slice_args);
@@ -404,8 +413,9 @@ class SliceCPU {
* The user is responsible to synchronize with the execution engine.
*
* For execution engines other than SequentialExecutionEngine, the algorithm will try
* to split the slice into similar sized blocks until we either reach a minimum of ``req_nblocks``
* or the block volume is smaller than the minimum practical size, ``min_blk_sz``.
* to split the slice into similar sized blocks until we either reach a minimum of
* ``req_nblocks`` or the block volume is smaller than the minimum practical size,
* ``min_blk_sz``.
* @param context Kernel context
* @param out Output tensor view
* @param in Input tensor view
@@ -420,13 +430,14 @@
InTensorCPU<InputType, Dims> in,
const SliceArgs<OutputType, Dims> &args,
ExecutionEngine &exec_engine,
int min_blk_sz = kSliceMinBlockSize, int req_nblocks = -1) {
int min_blk_sz = kSliceMinBlockSize,
int req_nblocks = -1) {
auto out_strides = GetStrides(out.shape);
auto in_strides = GetStrides(in.shape);

// fill values should not be empty. It should be left default if not used
assert(!args.fill_values.empty());
const OutputType* fill_values = args.fill_values.data();
const OutputType *fill_values = args.fill_values.data();
int fill_values_size = args.fill_values.size();
if (fill_values_size > 1) {
DALI_ENFORCE(args.channel_dim >= 0 && args.channel_dim < Dims,
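
For context, a rough usage sketch of the CPU kernel with the new step field follows. It is a hand-written illustration, not code from this commit: the `Setup`/`Run` calls and the `anchor`/`step`/`channel_dim` members follow the signatures visible in this diff, while `args.shape`, the view construction, and the concrete values are assumptions.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical usage: take every second column of an HWC uint8 image and
// convert the values to float. fill_values is left at its default, as the
// comment in Run() above suggests when padding is not used.
using namespace dali;

int64_t height = 480, width = 640, channels = 3;
std::vector<uint8_t> in_data(height * width * channels);
std::vector<float> out_data(height * (width / 2) * channels);

// TensorView construction via a (pointer, shape) constructor is assumed here.
kernels::InTensorCPU<uint8_t, 3> in_view(in_data.data(), {height, width, channels});
kernels::OutTensorCPU<float, 3> out_view(out_data.data(), {height, width / 2, channels});

kernels::SliceCPU<float, uint8_t, 3> kernel;
kernels::KernelContext ctx;

kernels::SliceArgs<float, 3> args;
args.anchor = {0, 0, 0};                      // slice start, in input coordinates
args.shape = {height, width / 2, channels};   // output extent (assumed member name)
args.step = {1, 2, 1};                        // new in this commit: stride per dimension
args.channel_dim = 2;

auto req = kernel.Setup(ctx, in_view, args);
SequentialExecutionEngine engine;
kernel.Run(ctx, out_view, in_view, args, engine);
```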
