Added Stride to Subscript and Slice Kernel #5007

Merged Aug 24, 2023. The diff below shows the changes from all 25 commits.

Commits
0152190
add docker-build folder to gitignore, clang-format slice_cpu.h, level…
5had3z Jul 15, 2023
a9f7387
clang-format + black format, removed slice notimpl errors
5had3z Jul 15, 2023
898ba89
add step to slice args, multiply in stride by step, clang-format
5had3z Jul 15, 2023
d2cc8e6
update default pyver and add new runtime images
5had3z Jul 29, 2023
a936d16
fix build script
5had3z Jul 29, 2023
c84adfb
added .devcontainer and dockerfile
5had3z Aug 4, 2023
53fce55
remove deps post-compile, move pre-commit install
5had3z Aug 4, 2023
5eec656
added more devcontainer components, add step arg (can't easily handle…
5had3z Aug 5, 2023
734f5cd
step > 1 works (-ve not), add nsight to devctr
5had3z Aug 5, 2023
170df2f
Add nvjpeg2k and nvcomp to image
5had3z Aug 6, 2023
996d88d
remove dimension inlining and anchor embedding to enable stepping to …
5had3z Aug 19, 2023
aeec1f2
fix default values for step to be 1, clang-format
5had3z Aug 20, 2023
32f8471
add more tests for hi/lo, fix last element logic for reverse stride
5had3z Aug 20, 2023
831c572
added more tests, updated docs
5had3z Aug 20, 2023
51168dc
remove devcontainer and revert docker/build.sh
5had3z Aug 21, 2023
d9c9553
re-added dimension flattening with fixed logic + conditions
5had3z Aug 21, 2023
55cc4a7
re-added slicenopad flatten w/ step + anchor cond
5had3z Aug 21, 2023
b00063c
preapply anchor and step if no padding
5had3z Aug 21, 2023
e30c269
Update dali/kernels/slice/slice_kernel_utils.h
5had3z Aug 22, 2023
3c0f993
Add UnitCubeShape utility.
mzient Aug 22, 2023
948639d
added helper function to TensorShape to create filled tensor
5had3z Aug 22, 2023
fe0ac97
fix assertions, fix missing template param for ndim, removed unnessec…
5had3z Aug 23, 2023
c9ac7a8
Simplify step alongside anchor and shape.
mzient Aug 23, 2023
8e372b2
Add a targeted test for collapsing untouched dims.
mzient Aug 23, 2023
6d0fcc0
Restore formatting and comments.
mzient Aug 23, 2023
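
For orientation before the file diff: this PR extends the slice arguments from (anchor, shape) to (anchor, shape, step), so along each dimension the output takes every step-th input element starting at anchor, and a negative step reads backwards. A minimal standalone C++ sketch of those semantics, with hypothetical names and no relation to the DALI sources below:

#include <cstdint>
#include <iostream>
#include <vector>

// out[i] = in[anchor + i * step]; the caller must keep indices in bounds.
std::vector<int> StridedSlice1D(const std::vector<int> &in,
                                int64_t anchor, int64_t shape, int64_t step) {
  std::vector<int> out(shape);
  for (int64_t i = 0; i < shape; i++)
    out[i] = in[anchor + i * step];
  return out;
}

int main() {
  std::vector<int> in = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  for (int v : StridedSlice1D(in, 1, 4, 2))   // expect: 1 3 5 7
    std::cout << v << " ";
  std::cout << "\n";
  for (int v : StridedSlice1D(in, 9, 5, -2))  // expect: 9 7 5 3 1
    std::cout << v << " ";
  std::cout << "\n";
}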
1 change: 1 addition & 0 deletions .gitignore
@@ -26,3 +26,4 @@ docs/op_autodoc
 docs/fn_autodoc
 docs/nvidia.ico
 .DS_Store
+build-docker-*
113 changes: 62 additions & 51 deletions dali/kernels/slice/slice_cpu.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -40,11 +40,12 @@ namespace slice_impl {
 template <typename OutputType, typename InputType, bool OutOfBounds, bool NeedPad>
 void SliceKernelImplChannelLast(OutputType *output,
                                 const InputType *input,
-                                const int64_t* out_strides,
-                                const int64_t* in_strides,
-                                const int64_t* out_shape,
-                                const int64_t* in_shape,
-                                const int64_t* anchor,
+                                const int64_t *out_strides,
+                                const int64_t *in_strides,
+                                const int64_t *out_shape,
+                                const int64_t *in_shape,
+                                const int64_t *anchor,
+                                const int64_t *step,
                                 const OutputType *fill_values,
                                 int channel_dim,  // negative if no channel dim or already processed
                                 std::integral_constant<bool, OutOfBounds>,
@@ -104,7 +105,7 @@ void SliceKernelImplChannelLast(OutputType *output,
       output[out_c] = fill_values[out_c];
 
     output += out_nchannels;
-    input += in_nchannels;
+    input += in_nchannels * step[d];
   }
 }
 
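The hunk above is the channel-last core of the feature: one output pixel still copies all of its channels, but moving to the next pixel now jumps step[d] whole pixels in the input. A standalone sketch of that pointer arithmetic, assuming an HWC layout and hypothetical names:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t width = 6, nchannels = 3, step = 2;
  std::vector<int> in(width * nchannels);
  for (size_t i = 0; i < in.size(); i++) in[i] = static_cast<int>(i);

  const int *src = in.data();
  for (int64_t x = 0; x < width; x += step) {
    for (int64_t c = 0; c < nchannels; c++)
      std::cout << src[c] << " ";  // copy one whole pixel (all channels)
    std::cout << "| ";
    src += nchannels * step;       // skip `step` pixels at once
  }
  std::cout << "\n";               // prints pixels 0, 2 and 4
}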
@@ -129,11 +130,12 @@ void SliceKernelImplChannelLast(OutputType *output,
 template <typename OutputType, typename InputType, bool OutOfBounds, bool NeedPad>
 void SliceKernelImpl(OutputType *output,
                      const InputType *input,
-                     const int64_t* out_strides,
-                     const int64_t* in_strides,
-                     const int64_t* out_shape,
-                     const int64_t* in_shape,
-                     const int64_t* anchor,
+                     const int64_t *out_strides,
+                     const int64_t *in_strides,
+                     const int64_t *out_shape,
+                     const int64_t *in_shape,
+                     const int64_t *anchor,
+                     const int64_t *step,
                      const OutputType *fill_values,
                      int channel_dim,  // negative if no channel dim or already processed
                      std::integral_constant<int, 1>,
@@ -151,24 +153,26 @@ void SliceKernelImpl(OutputType *output,
   int out_idx = 0;
 
   if (NeedPad) {
-    // out of bounds (left side)
-    for (; in_idx < 0 && out_idx < out_shape[d]; in_idx++, out_idx++) {
+    // out of bounds (left side of output)
+    for (; (in_idx < 0 || in_idx >= in_shape[d]) && out_idx < out_shape[d];
+         in_idx += step[d], out_idx++) {
       output[out_idx] = *fill_values;
       if (d == channel_dim)
         fill_values++;
     }
   }
 
   // within input bounds
-  for (; in_idx < in_shape[d] && out_idx < out_shape[d]; in_idx++, out_idx++) {
+  for (; (0 <= in_idx && in_idx < in_shape[d]) && out_idx < out_shape[d];
+       in_idx += step[d], out_idx++) {
     output[out_idx] = clamp<OutputType>(input[in_idx]);
     if (NeedPad && d == channel_dim)
       fill_values++;
   }
 
   if (NeedPad) {
-    // out of bounds (right side)
-    for (; out_idx < out_shape[d]; in_idx++, out_idx++) {
+    // out of bounds (right side of output)
+    for (; out_idx < out_shape[d]; in_idx += step[d], out_idx += out_strides[d]) {
       output[out_idx] = *fill_values;
       if (d == channel_dim)
         fill_values++;
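
This hunk is the heart of the 1-D kernel: three loops cover left padding, the in-bounds region, and right padding. Widening the first condition from in_idx < 0 to (in_idx < 0 || in_idx >= in_shape[d]) is what lets a negative step start beyond the high edge of the input. A simplified standalone sketch of the same three-phase structure, with hypothetical names and a single fill value:

#include <cstdint>
#include <iostream>
#include <vector>

void SlicePad1D(int *out, const int *in, int64_t out_shape, int64_t in_shape,
                int64_t anchor, int64_t step, int fill_value) {
  int64_t in_idx = anchor;
  int64_t out_idx = 0;
  // out of bounds (left side of output); also entered when a reversed
  // slice starts past the high edge of the input
  for (; (in_idx < 0 || in_idx >= in_shape) && out_idx < out_shape;
       in_idx += step, out_idx++)
    out[out_idx] = fill_value;
  // within input bounds
  for (; 0 <= in_idx && in_idx < in_shape && out_idx < out_shape;
       in_idx += step, out_idx++)
    out[out_idx] = in[in_idx];
  // out of bounds (right side of output)
  for (; out_idx < out_shape; out_idx++)
    out[out_idx] = fill_value;
}

int main() {
  std::vector<int> in = {10, 11, 12, 13, 14};
  std::vector<int> out(6);
  SlicePad1D(out.data(), in.data(), out.size(), in.size(), 6, -2, -1);
  for (int v : out) std::cout << v << " ";  // expect: -1 14 12 10 -1 -1
  std::cout << "\n";
}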
@@ -180,11 +184,12 @@ void SliceKernelImpl(OutputType *output,
 template <typename OutputType, typename InputType, bool OutOfBounds, bool NeedPad, int DimsLeft>
 void SliceKernelImpl(OutputType *output,
                      const InputType *input,
-                     const int64_t* out_strides,
-                     const int64_t* in_strides,
-                     const int64_t* out_shape,
-                     const int64_t* in_shape,
-                     const int64_t* anchor,
+                     const int64_t *out_strides,
+                     const int64_t *in_strides,
+                     const int64_t *out_shape,
+                     const int64_t *in_shape,
+                     const int64_t *anchor,
+                     const int64_t *step,
                      const OutputType *fill_values,
                      int channel_dim,  // negative if no channel dim or already processed
                      std::integral_constant<int, DimsLeft>,
@@ -193,7 +198,7 @@ void SliceKernelImpl(OutputType *output,
   // Special case for last 2 dimensions with channel-last configuration
   if (DimsLeft == 2 && channel_dim == 1) {
     SliceKernelImplChannelLast(output, input, out_strides, in_strides, out_shape, in_shape, anchor,
-                               fill_values, channel_dim,
+                               step, fill_values, channel_dim,
                                std::integral_constant<bool, OutOfBounds>(),
                                std::integral_constant<bool, NeedPad>());
     return;
@@ -207,10 +212,11 @@ void SliceKernelImpl(OutputType *output,
     input += anchor[d] * in_strides[d];
 
   if (NeedPad) {
-    // out of bounds (left side)
-    for (; in_idx < 0 && out_idx < out_shape[d]; in_idx++, out_idx++) {
+    // out of bounds (left side of output)
+    for (; (in_idx < 0 || in_idx >= in_shape[d]) && out_idx < out_shape[d];
+         in_idx += step[d], out_idx++) {
       SliceKernelImpl(output, input, out_strides + 1, in_strides + 1, out_shape + 1,
-                      in_shape + 1, anchor + 1, fill_values, channel_dim - 1,
+                      in_shape + 1, anchor + 1, step + 1, fill_values, channel_dim - 1,
                       std::integral_constant<int, DimsLeft - 1>(),
                       std::integral_constant<bool, true>(),
                       std::integral_constant<bool, NeedPad>());
@@ -221,24 +227,25 @@ void SliceKernelImpl(OutputType *output,
   }
 
   // within input bounds
-  for (; in_idx < in_shape[d] && out_idx < out_shape[d]; in_idx++, out_idx++) {
+  for (; (0 <= in_idx && in_idx < in_shape[d]) && out_idx < out_shape[d];
+       in_idx += step[d], out_idx++) {
     SliceKernelImpl(output, input, out_strides + 1, in_strides + 1, out_shape + 1,
-                    in_shape + 1, anchor + 1, fill_values, channel_dim - 1,
+                    in_shape + 1, anchor + 1, step + 1, fill_values, channel_dim - 1,
                     std::integral_constant<int, DimsLeft - 1>(),
                     std::integral_constant<bool, OutOfBounds>(),
                     std::integral_constant<bool, NeedPad>());
     output += out_strides[d];
     if (!OutOfBounds)
-      input += in_strides[d];
+      input += in_strides[d] * step[d];
     if (NeedPad && d == channel_dim)
       fill_values++;
   }
 
   if (NeedPad) {
-    // out of bounds (right side)
-    for (; out_idx < out_shape[d]; in_idx++, out_idx++) {
+    // out of bounds (right side of output)
+    for (; out_idx < out_shape[d]; out_idx++) {
       SliceKernelImpl(output, input, out_strides + 1, in_strides + 1, out_shape + 1,
-                      in_shape + 1, anchor + 1, fill_values, channel_dim - 1,
+                      in_shape + 1, anchor + 1, step + 1, fill_values, channel_dim - 1,
                       std::integral_constant<int, DimsLeft - 1>(),
                       std::integral_constant<bool, true>(),
                       std::integral_constant<bool, NeedPad>());
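
For the outer dimensions, stepping is pure pointer arithmetic: per output element the input pointer advances by in_strides[d] * step[d] while the recursion handles the remaining dimensions. A standalone 2-D sketch of that layout math, with hypothetical names and no padding:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t in_h = 4, in_w = 5;  // 4x5 input, row-major
  std::vector<int> in(in_h * in_w);
  for (size_t i = 0; i < in.size(); i++) in[i] = static_cast<int>(i);

  const int64_t anchor[2] = {0, 1}, step[2] = {2, 2}, out_shape[2] = {2, 2};
  const int64_t in_strides[2] = {in_w, 1};

  const int *src = in.data() + anchor[0] * in_strides[0] + anchor[1] * in_strides[1];
  for (int64_t y = 0; y < out_shape[0]; y++) {
    const int *row = src;
    for (int64_t x = 0; x < out_shape[1]; x++) {
      std::cout << *row << " ";
      row += in_strides[1] * step[1];  // step along the inner dimension
    }
    std::cout << "\n";
    src += in_strides[0] * step[0];    // step along the outer dimension
  }
  // prints: 1 3
  //         11 13
}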
@@ -260,20 +267,21 @@ void SliceKernel(OutputType *output,
                  const TensorShape<Dims> &out_shape,
                  const TensorShape<Dims> &in_shape,
                  const TensorShape<Dims> &anchor,
+                 const TensorShape<Dims> &step,
                  const OutputType *fill_values,
                  int channel_dim = -1) {  // negative if no channel dim or already processed
   bool need_pad = NeedPad(Dims, anchor.data(), in_shape.data(), out_shape.data());
   if (need_pad) {
     slice_impl::SliceKernelImpl(
-        output, input, out_strides.data(), in_strides.data(), out_shape.data(),
-        in_shape.data(), anchor.data(), fill_values, channel_dim,
+        output, input, out_strides.data(), in_strides.data(), out_shape.data(), in_shape.data(),
+        anchor.data(), step.data(), fill_values, channel_dim,
         std::integral_constant<int, Dims>(),
         std::integral_constant<bool, false>(),
         std::integral_constant<bool, true>());
   } else {
     slice_impl::SliceKernelImpl(
-        output, input, out_strides.data(), in_strides.data(), out_shape.data(),
-        in_shape.data(), anchor.data(), fill_values, channel_dim,
+        output, input, out_strides.data(), in_strides.data(), out_shape.data(), in_shape.data(),
+        anchor.data(), step.data(), fill_values, channel_dim,
         std::integral_constant<int, Dims>(),
         std::integral_constant<bool, false>(),
         std::integral_constant<bool, false>());
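
The dispatch above turns the runtime need_pad flag into a compile-time template parameter, so the padded and unpadded variants are separate instantiations and the hot loop carries no dead bounds checks. A generic sketch of that std::integral_constant dispatch pattern (not the DALI signatures):

#include <iostream>
#include <type_traits>

template <bool NeedPad>
void Impl(std::integral_constant<bool, NeedPad>) {
  if (NeedPad)  // constant per instantiation; compilers fold the dead branch
    std::cout << "padded path\n";
  else
    std::cout << "fast path, no fills\n";
}

void Run(bool need_pad) {  // runtime bool -> compile-time constant
  if (need_pad)
    Impl(std::integral_constant<bool, true>());
  else
    Impl(std::integral_constant<bool, false>());
}

int main() {
  Run(true);
  Run(false);
}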
@@ -287,7 +295,7 @@ void SliceKernel(OutputType *output,
 template <typename ExecutionEngine, typename OutputType, typename InputType, int Dims>
 DLL_LOCAL  // workaround for GCC bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80947
 void SliceKernel(ExecutionEngine &exec_engine,
-                 OutputType* out_data,
+                 OutputType *out_data,
                  const InputType *in_data,
                  const TensorShape<Dims> &out_strides,
                  const TensorShape<Dims> &in_strides,
@@ -326,28 +334,30 @@ void SliceKernel(ExecutionEngine &exec_engine,
 
   if (nblocks == 1) {
     exec_engine.AddWork([=](int) {
-      SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape,
-                  args.anchor, GetPtr<OutputType>(args.fill_values), args.channel_dim);
+      SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape, args.anchor,
+                  args.step, GetPtr<OutputType>(args.fill_values), args.channel_dim);
     }, kSliceCost * volume(out_shape), false);  // do not start work immediately
     return;
   }
 
   TensorShape<Dims> start;  // zero-filled
-  const auto& end = out_shape;
+  const auto &end = out_shape;
   ForEachBlock(
       start, end, split_factor, 0, LastSplitDim(split_factor),
       [&](const TensorShape<Dims> &blk_start, const TensorShape<Dims> &blk_end) {
         auto output_ptr = out_data;
         TensorShape<Dims> blk_anchor;
         TensorShape<Dims> blk_shape;
+        TensorShape<Dims> blk_step;
         for (int d = 0; d < Dims; d++) {
           output_ptr += blk_start[d] * out_strides[d];
           blk_shape[d] = blk_end[d] - blk_start[d];
           blk_anchor[d] = args.anchor[d] + blk_start[d];
+          blk_step[d] = args.step[d];
         }
         exec_engine.AddWork([=](int) {
           SliceKernel(output_ptr, in_data, out_strides, in_strides, blk_shape, in_shape,
-                      blk_anchor, GetPtr<OutputType>(args.fill_values), args.channel_dim);
+                      blk_anchor, blk_step, GetPtr<OutputType>(args.fill_values), args.channel_dim);
         }, kSliceCost * volume(blk_shape), false);  // do not start work immediately
       });
   // scheduled work does not start until user calls Run()
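
Block splitting treats every output block as an independent slice: the anchor is shifted by the block's start within the output, and the per-dimension step is carried through unchanged (blk_step[d] = args.step[d]). A 1-D sketch of the anchor bookkeeping, assuming unit step and a hypothetical even splitter:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t out_shape = 10, anchor = 3, nblocks = 3;
  for (int64_t b = 0; b < nblocks; b++) {
    int64_t blk_start = b * out_shape / nblocks;
    int64_t blk_end = (b + 1) * out_shape / nblocks;
    // Mirrors blk_anchor[d] = args.anchor[d] + blk_start[d] above.
    std::cout << "block " << b << ": anchor=" << anchor + blk_start
              << " shape=" << blk_end - blk_start << "\n";
  }
}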
@@ -359,15 +369,15 @@ void SliceKernel(ExecutionEngine &exec_engine,
  */
 template <typename OutputType, typename InputType, int Dims>
 void SliceKernel(SequentialExecutionEngine &exec_engine,
-                 OutputType* out_data,
+                 OutputType *out_data,
                  const InputType *in_data,
                  const TensorShape<Dims> &out_strides,
                  const TensorShape<Dims> &in_strides,
                  const TensorShape<Dims> &out_shape,
                  const TensorShape<Dims> &in_shape,
                  const SliceArgs<OutputType, Dims> &args,
                  int /* min_blk_sz */ = -1, int /* req_nblocks */ = -1) {
-  (void) exec_engine;
+  (void)exec_engine;
 
   // If the output and input data type is the same and the slice arguments take the whole extent
   // of the input, then we can simply run memcpy.
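
The comment above describes a fast path: when the output type matches the input type and the slice spans the whole input, the kernel degenerates to a plain copy. A hedged sketch of what such a check could look like once step is part of the arguments (a hypothetical helper, not the library's own check):

#include <cstdint>
#include <iostream>
#include <type_traits>

template <typename Out, typename In>
bool CanJustMemcpy(const int64_t *anchor, const int64_t *step,
                   const int64_t *out_shape, const int64_t *in_shape, int ndim) {
  if (!std::is_same<Out, In>::value)
    return false;  // a converting copy still needs the element-wise path
  for (int d = 0; d < ndim; d++)
    if (anchor[d] != 0 || step[d] != 1 || out_shape[d] != in_shape[d])
      return false;  // any real slicing, stepping or padding disqualifies
  return true;
}

int main() {
  int64_t anchor[2] = {0, 0}, step[2] = {1, 1}, shape[2] = {4, 5};
  std::cout << CanJustMemcpy<float, float>(anchor, step, shape, shape, 2)    // 1
            << CanJustMemcpy<float, uint8_t>(anchor, step, shape, shape, 2)  // 0
            << "\n";
}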
@@ -377,17 +387,16 @@ void SliceKernel(SequentialExecutionEngine &exec_engine,
     return;
   }
 
-  SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape,
-              args.anchor, GetPtr<OutputType>(args.fill_values), args.channel_dim);
+  SliceKernel(out_data, in_data, out_strides, in_strides, out_shape, in_shape, args.anchor,
+              args.step, GetPtr<OutputType>(args.fill_values), args.channel_dim);
 }
 
 template <typename OutputType, typename InputType, int Dims>
 class SliceCPU {
  public:
   static_assert(Dims >= 0, "Dims must be >= 0");
 
-  KernelRequirements Setup(KernelContext &context,
-                           const InTensorCPU<InputType, Dims> &in,
+  KernelRequirements Setup(KernelContext &context, const InTensorCPU<InputType, Dims> &in,
                            const SliceArgs<OutputType, Dims> &slice_args) {
     KernelRequirements req;
     auto shape = GetOutputShape(in.shape, slice_args);
@@ -404,8 +413,9 @@ class SliceCPU {
    * The user is responsible to synchronize with the execution engine.
    *
    * For execution engines other than SequentialExecutionEngine, the algorithm will try
-   * to split the slice into similar sized blocks until we either reach a minimum of ``req_nblocks``
-   * or the block volume is smaller than the minimum practical size, ``min_blk_sz``.
+   * to split the slice into similar sized blocks until we either reach a minimum of
+   * ``req_nblocks`` or the block volume is smaller than the minimum practical size,
+   * ``min_blk_sz``.
    * @param context Kernel context
    * @param out Output tensor view
    * @param in Input tensor view
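
The doc-comment states the splitting policy: split until a minimum of req_nblocks is reached or the block volume would fall below min_blk_sz. A toy sketch of one way to realize that stopping rule (a hypothetical heuristic, not DALI's actual one):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t volume = 1 << 20;  // total output elements
  const int64_t min_blk_sz = 16000;
  const int64_t req_nblocks = 128;
  int64_t nblocks = 1;
  // Keep splitting while we are short of req_nblocks and halving would
  // still leave each block at or above the minimum practical volume.
  while (nblocks < req_nblocks && volume / (nblocks * 2) >= min_blk_sz)
    nblocks *= 2;
  std::cout << "volume=" << volume << " -> nblocks=" << nblocks
            << " (~" << volume / nblocks << " elems/block)\n";
}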
@@ -420,13 +430,14 @@ class SliceCPU {
            InTensorCPU<InputType, Dims> in,
            const SliceArgs<OutputType, Dims> &args,
            ExecutionEngine &exec_engine,
-           int min_blk_sz = kSliceMinBlockSize, int req_nblocks = -1) {
+           int min_blk_sz = kSliceMinBlockSize,
+           int req_nblocks = -1) {
     auto out_strides = GetStrides(out.shape);
     auto in_strides = GetStrides(in.shape);
 
     // fill values should not be empty. It should be left default if not used
     assert(!args.fill_values.empty());
-    const OutputType* fill_values = args.fill_values.data();
+    const OutputType *fill_values = args.fill_values.data();
     int fill_values_size = args.fill_values.size();
     if (fill_values_size > 1) {
       DALI_ENFORCE(args.channel_dim >= 0 && args.channel_dim < Dims,
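
The DALI_ENFORCE above guards the multi-channel fill case: more than one fill value only makes sense when a valid channel_dim identifies the extent the values run along, e.g. padding an RGB image with a solid color. A standalone sketch of per-channel fill, with hypothetical data:

#include <iostream>

int main() {
  const int out_w = 5, in_w = 3, nchannels = 3;
  const int in[in_w * nchannels] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  const int fill[nchannels] = {255, 0, 0};  // pad with solid red
  for (int x = 0; x < out_w; x++) {
    // in-bounds pixels copy the input; padded pixels take fill[c]
    for (int c = 0; c < nchannels; c++)
      std::cout << (x < in_w ? in[x * nchannels + c] : fill[c]) << " ";
    std::cout << "| ";
  }
  std::cout << "\n";  // 1 2 3 | 4 5 6 | 7 8 9 | 255 0 0 | 255 0 0 |
}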