-
Notifications
You must be signed in to change notification settings - Fork 648
Add out-of-bounds-policy (including pad support) to Slice/Crop #2000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d6f7b36
10a50ca
18f24e2
4d95692
0cf0e86
58d95ed
7e76c2b
76376a4
21357a0
58ccd68
47df793
5d4ee77
ef1a54c
7846d9c
e1d21c3
eb9ad1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -242,6 +242,22 @@ struct ArgsGen_CompletelyOutOfBounds{ | |
| } | ||
| }; | ||
|
|
||
| template <typename OutputType, int Dims = 3> | ||
|
||
| struct ArgsGen_SingleValuePad { | ||
| SliceArgs<OutputType, Dims> Get(const TensorShape<Dims>& input_shape) { | ||
| SliceArgs<OutputType, 3> args; | ||
| args.anchor[0] = -input_shape[0] / 2; | ||
| args.anchor[1] = -input_shape[1] / 2; | ||
| args.anchor[2] = 0; | ||
| args.shape[0] = 2 * input_shape[0]; | ||
| args.shape[1] = 2 * input_shape[1]; | ||
| args.shape[2] = input_shape[2]; | ||
| args.fill_values = {128}; | ||
| args.channel_dim = -1; | ||
|
||
| return args; | ||
| } | ||
| }; | ||
|
|
||
| template <typename OutputType, int Dims = 3> | ||
| struct ArgsGen_MultiChannelPad { | ||
| SliceArgs<OutputType, Dims> Get(const TensorShape<Dims>& input_shape) { | ||
|
|
@@ -332,6 +348,7 @@ using SLICE_TEST_TYPES = ::testing::Types< | |
| SliceTestArgs<int, int, 1, 1, 22, ArgsGen_RightSideOutOfBounds<int, 1>>, | ||
| SliceTestArgs<int, int, 2, 1, 22, ArgsGen_RightSideOutOfBounds<int, 2>>, | ||
| SliceTestArgs<int, int, 2, 1, 22, ArgsGen_CompletelyOutOfBounds<int, 2>>, | ||
| SliceTestArgs<int, int, 3, 1, 20, ArgsGen_SingleValuePad<int, 3>, 20, 20, 3>, | ||
| SliceTestArgs<int, int, 3, 1, 20, ArgsGen_MultiChannelPad<int, 3>, 20, 20, 3>, | ||
| SliceTestArgs<int, int, 3, 1, 20, ArgsGen_MultiChannelPad_ChFirst<int, 3>, 3, 20, 20>, | ||
| SliceTestArgs<int, int, 3, 1, 20, ArgsGen_PadAlsoChDim<int, 3>, 20, 20, 3>, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include <vector> | ||
| #include "dali/pipeline/operator/common.h" | ||
| #include "dali/pipeline/operator/operator.h" | ||
|
|
||
| namespace dali { | ||
|
|
||
| DALI_SCHEMA(OutOfBoundsAttr) | ||
| .DocStr(R"code(Out-of-bounds slicing attributes placeholder)code") | ||
| .AddOptionalArg("out_of_bounds_policy", | ||
| R"code(Determines the policy when slicing out of bounds of the input. | ||
| Supported values are: | ||
|
|
||
| - "error" (default) : Attempting to slice outside of the bounds of the image will produce an error. | ||
| - "pad": The input will be padded as needed with zeros or any other value specified with ``fill_values`` argument. | ||
| - "trim_to_shape": The slice window will be cut to the bounds of the input.))code", "error") | ||
| .AddOptionalArg("fill_values", | ||
| R"code(Determines padding values, only relevant if ``out_of_bounds_policy`` is set to "pad". | ||
| If a scalar is provided, it will be used for all the channels. If multiple values are given, there should be as many values as | ||
| channels (extent of dimension 'C' in the layout) in the output slice.)code", std::vector<float>{0.f}); | ||
|
|
||
| } // namespace dali |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #ifndef DALI_OPERATORS_GENERIC_SLICE_OUT_OF_BOUNDS_POLICY_H_ | ||
| #define DALI_OPERATORS_GENERIC_SLICE_OUT_OF_BOUNDS_POLICY_H_ | ||
|
|
||
| #include <string> | ||
| #include "dali/core/math_util.h" | ||
| #include "dali/core/tensor_shape.h" | ||
| #include "dali/core/tensor_shape_print.h" | ||
| #include "dali/pipeline/operator/common.h" | ||
|
|
||
| namespace dali { | ||
|
|
||
| template <bool inclusive_end> | ||
|
||
| DALI_HOST_DEV DALI_FORCEINLINE bool is_out_of_bounds(int64_t idx, int64_t data_extent) { | ||
| if (inclusive_end) // check idx is within [0, data_extent] | ||
| return static_cast<uint64_t>(idx) > static_cast<uint64_t>(data_extent); | ||
| else // check idx is within [0, data_extent) | ||
| return static_cast<uint64_t>(idx) >= static_cast<uint64_t>(data_extent); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Determines what to do if slice parameters point to outside of the input bounds | ||
| */ | ||
| enum class OutOfBoundsPolicy { | ||
| Error, // sampling out of bounds will throw an error | ||
| TrimToShape, // Slice shape will be trimmed to fit the input bounds (potentially empty output) | ||
| Pad, // Slicing out of bounds will result in padding with zeroes or any other provided value(s) | ||
| }; | ||
|
|
||
| inline OutOfBoundsPolicy GetOutOfBoundsPolicy(const OpSpec &spec) { | ||
| bool has_out_of_bounds_policy = spec.HasArgument("out_of_bounds_policy"); | ||
| OutOfBoundsPolicy policy = OutOfBoundsPolicy::Error; | ||
| if (has_out_of_bounds_policy) { | ||
| auto policy_str = spec.GetArgument<std::string>("out_of_bounds_policy"); | ||
| if (policy_str == "pad") { | ||
| policy = OutOfBoundsPolicy::Pad; | ||
| } else if (policy_str == "trim_to_shape") { | ||
| policy = OutOfBoundsPolicy::TrimToShape; | ||
| } else if (policy_str == "error") { | ||
| policy = OutOfBoundsPolicy::Error; | ||
| } else { | ||
| DALI_FAIL( | ||
| make_string("Not supported out_of_bounds_policy: ", policy_str, | ||
| ". Supported values are \"pad\", \"trim_to_shape\", \"error\" (default)")); | ||
| } | ||
| } | ||
| return policy; | ||
| } | ||
|
|
||
| template <int Dims> | ||
| void ApplySliceBoundsPolicy(OutOfBoundsPolicy policy, const TensorShape<Dims> &input_shape, | ||
| TensorShape<Dims> &slice_anchor, TensorShape<Dims> &slice_shape) { | ||
| DALI_ENFORCE( | ||
| input_shape.size() == slice_anchor.size() && input_shape.size() == slice_shape.size(), | ||
| "Slice arguments should have the same number of dimensions as the input"); | ||
| switch (policy) { | ||
| case OutOfBoundsPolicy::Pad: | ||
| // nothing to do | ||
| break; | ||
|
|
||
| case OutOfBoundsPolicy::TrimToShape: | ||
| for (int d = 0; d < input_shape.size(); d++) { | ||
| auto slice_start = clamp<int64_t>(slice_anchor[d], 0, input_shape[d]); | ||
| auto slice_end = clamp<int64_t>(slice_anchor[d] + slice_shape[d], 0, input_shape[d]); | ||
| assert(slice_end >= slice_start); | ||
| slice_anchor[d] = slice_start; | ||
| slice_shape[d] = slice_end - slice_start; | ||
| } | ||
| break; | ||
|
|
||
| case OutOfBoundsPolicy::Error: | ||
| default: | ||
| for (int d = 0; d < input_shape.size(); d++) { | ||
| // start within [0, extent), and end within [0, extent] | ||
| if (is_out_of_bounds<false>(slice_anchor[d], input_shape[d]) || | ||
| is_out_of_bounds<true>(slice_anchor[d] + slice_shape[d], input_shape[d])) { | ||
| DALI_FAIL(make_string( | ||
| "Slice can't be place out of bounds with current policy. Got: input_shape={", | ||
|
||
| input_shape, "}, slice_anchor={", slice_anchor, "}, slice_shape={", slice_shape, | ||
| "}")); | ||
| } | ||
| } | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| } // namespace dali | ||
|
|
||
| #endif // DALI_OPERATORS_GENERIC_SLICE_OUT_OF_BOUNDS_POLICY_H_ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes in this file are a bug fix