From 035bd81ed0aadefbc6b8892b739159d76b7c3d0d Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Thu, 23 Apr 2020 12:26:19 +0200 Subject: [PATCH 1/6] Add CoordFlip CPU operator Signed-off-by: Joaquin Anton --- dali/operators/CMakeLists.txt | 1 + dali/operators/coord/CMakeLists.txt | 17 +++ dali/operators/coord/coord_flip.cc | 113 +++++++++++++++++++ dali/operators/coord/coord_flip.h | 60 ++++++++++ dali/test/python/test_operator_coord_flip.py | 90 +++++++++++++++ 5 files changed, 281 insertions(+) create mode 100644 dali/operators/coord/CMakeLists.txt create mode 100644 dali/operators/coord/coord_flip.cc create mode 100644 dali/operators/coord/coord_flip.h create mode 100644 dali/test/python/test_operator_coord_flip.py diff --git a/dali/operators/CMakeLists.txt b/dali/operators/CMakeLists.txt index 94f0b22c80..6aa5a77328 100644 --- a/dali/operators/CMakeLists.txt +++ b/dali/operators/CMakeLists.txt @@ -16,6 +16,7 @@ project(dali_operator CUDA CXX C) add_subdirectory(audio) add_subdirectory(bbox) +add_subdirectory(coord) add_subdirectory(debug) add_subdirectory(decoder) add_subdirectory(generic) diff --git a/dali/operators/coord/CMakeLists.txt b/dali/operators/coord/CMakeLists.txt new file mode 100644 index 0000000000..bf41aa0c73 --- /dev/null +++ b/dali/operators/coord/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) +collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) diff --git a/dali/operators/coord/coord_flip.cc b/dali/operators/coord/coord_flip.cc new file mode 100644 index 0000000000..da05488c65 --- /dev/null +++ b/dali/operators/coord/coord_flip.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/operators/coord/coord_flip.h" + +namespace dali { + +DALI_SCHEMA(CoordFlip) + .DocStr( + R"code(Transforms normalized coordinates (range [0.0, 1.0]) so that they map to the same place after +horizontal or/and vertical flip of the input they refer to.)code") + .NumInput(1) + .NumOutput(1) + .AddOptionalArg( + "layout", + R"code(Determines the layout of the coordinates. + Possible values are: + + ``x`` (horizontal position), ``y`` (vertical position), ``z`` (depthwise position), + +Note: If left empty, ``"xy"`` or ``"xyz"`` will be assumed, depending on the number of dimensions. +)code", + TensorLayout{""}) + .AddOptionalArg("horizontal", R"code(Perform flip along horizontal axis.)code", 1, true) + .AddOptionalArg("vertical", R"code(Perform flip along vertical axis.)code", 0, true) + .AddOptionalArg("depthwise", R"code(Perform flip along depthwise axis.)code", 0, true); + +class CoordFlipCPU : public CoordFlip { + public: + explicit CoordFlipCPU(const OpSpec &spec) + : CoordFlip(spec) {} + + ~CoordFlipCPU() override = default; + DISABLE_COPY_MOVE_ASSIGN(CoordFlipCPU); + + void RunImpl(workspace_t &ws) override; + + USE_OPERATOR_MEMBERS(); + using Operator::RunImpl; + using CoordFlip::layout_; +}; + +void CoordFlipCPU::RunImpl(workspace_t &ws) { + const auto &input = ws.InputRef(0); + DALI_ENFORCE(input.type().id() == DALI_FLOAT, "Input is expected to be float"); + + auto &output = ws.OutputRef(0); + auto &thread_pool = ws.GetThreadPool(); + + if (layout_.empty()) { + layout_ = ndim_ == 2 ? "xy" : "xyz"; + } + + int x_dim = layout_.find('x'); + DALI_ENFORCE(x_dim >= 0, "Dimension \"x\" not found in the layout"); + + int y_dim = layout_.find('y'); + if (ndim_ > 1) + DALI_ENFORCE(y_dim >= 0, "Dimension \"y\" not found in the layout"); + + int z_dim = layout_.find('z'); + if (ndim_ > 2) + DALI_ENFORCE(z_dim >= 0, "Dimension \"z\" not found in the layout"); + + for (int sample_id = 0; sample_id < batch_size_; sample_id++) { + bool horizontal_flip = spec_.GetArgument("horizontal", &ws, sample_id); + bool vertical_flip = spec_.GetArgument("vertical", &ws, sample_id); + bool depthwise_flip = spec_.GetArgument("depthwise", &ws, sample_id); + std::array flip_dim = {false, false, false}; + + if (horizontal_flip) { + flip_dim[x_dim] = horizontal_flip; + } + + if (vertical_flip) { + flip_dim[y_dim] = vertical_flip; + } + + if (depthwise_flip) { + flip_dim[z_dim] = depthwise_flip; + } + + thread_pool.DoWorkWithID( + [this, &input, &output, sample_id, flip_dim](int thread_id) { + const auto *in = input[sample_id].data(); + auto *out = output[sample_id].mutable_data(); + auto in_size = volume(input[sample_id].shape()); + int d = 0; + int64_t i = 0; + for (; i < in_size; i++, d++) { + if (d == ndim_) d = 0; + assert(in[i] >= 0.0f && in[i] <= 1.0f); + out[i] = flip_dim[d] ? 1.0f - in[i] : in[i]; + } + }); + } + thread_pool.WaitForWork(); +} + +DALI_REGISTER_OPERATOR(CoordFlip, CoordFlipCPU, CPU); + +} // namespace dali diff --git a/dali/operators/coord/coord_flip.h b/dali/operators/coord/coord_flip.h new file mode 100644 index 0000000000..2044a3f8c5 --- /dev/null +++ b/dali/operators/coord/coord_flip.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_COORD_COORD_FLIP_H_ +#define DALI_OPERATORS_COORD_COORD_FLIP_H_ + +#include +#include + +#include "dali/pipeline/operator/common.h" +#include "dali/pipeline/operator/operator.h" + +namespace dali { + +template +class CoordFlip : public Operator { + public: + explicit CoordFlip(const OpSpec &spec) + : Operator(spec) + , layout_(spec.GetArgument("layout")) {} + + ~CoordFlip() override = default; + DISABLE_COPY_MOVE_ASSIGN(CoordFlip); + + protected: + bool CanInferOutputs() const override { + return true; + } + + bool SetupImpl(std::vector &output_desc, const workspace_t &ws) override { + const auto &input = ws.template InputRef(0); + output_desc.resize(1); + auto in_shape = input.shape(); + output_desc[0].shape = in_shape; + output_desc[0].type = input.type(); + + DALI_ENFORCE(in_shape[0].size() == 2); + ndim_ = in_shape[0][1]; + DALI_ENFORCE(ndim_ >= 1 && ndim_ <= 3, make_string("Unexpected number of dimensions ", ndim_)); + return true; + } + + TensorLayout layout_; + int ndim_; +}; + +} // namespace dali + +#endif // DALI_OPERATORS_COORD_COORD_FLIP_H_ diff --git a/dali/test/python/test_operator_coord_flip.py b/dali/test/python/test_operator_coord_flip.py new file mode 100644 index 0000000000..62f1f60526 --- /dev/null +++ b/dali/test/python/test_operator_coord_flip.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvidia.dali.pipeline import Pipeline +import nvidia.dali.ops as ops +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali as dali +import numpy as np +from numpy.testing import assert_array_equal, assert_allclose +from functools import partial +from test_utils import check_batch +from test_utils import compare_pipelines +from test_utils import RandomDataIterator +import math +from nose.tools import * + +class CoordFlipPipeline(Pipeline): + def __init__(self, device, batch_size, iterator, layout, + num_threads=1, device_id=0): + super(CoordFlipPipeline, self).__init__(batch_size, num_threads, device_id) + self.device = device + self.iterator = iterator + self.coord_flip = ops.CoordFlip(device = self.device, layout=layout) + self.flip_h = ops.CoinFlip(probability = 0.5) + self.flip_v = ops.CoinFlip(probability = 0.5) + self.flip_d = ops.CoinFlip(probability = 0.5) if len(layout) == 3 else None + + def define_graph(self): + inputs = fn.external_source(lambda: next(self.iterator)) + inputs = 0.5 + inputs # Make it fit the range [0.0, 1.0] + out = inputs.gpu() if self.device == 'gpu' else inputs + h = self.flip_h() + v = self.flip_v() + d = self.flip_d() if self.flip_d is not None else None + out = self.coord_flip(out, horizontal=h, vertical=v, depthwise=d) + outputs = [inputs, out, h, v] + if d is not None: + outputs.append(d) + return outputs + +def check_operator_coord_flip(device, batch_size, layout, shape): + eii1 = RandomDataIterator(batch_size, shape=shape, dtype=np.float32) + pipe = CoordFlipPipeline(device, batch_size, iter(eii1), layout) + pipe.build() + for i in range(30): + outputs = pipe.run() + for sample in range(batch_size): + in_coords = outputs[0].at(sample) + out_coords = outputs[1].at(sample) + h = outputs[2].at(sample) + v = outputs[3].at(sample) + d = None + if len(layout) == 3: + d = outputs[4].at(sample) + npoints, ndim = in_coords.shape + + flip_dim = [h[0], v[0]] + if ndim == 3: + flip_dim.append(d[0]) + + expected_out_coords = np.copy(in_coords) + for d in range(ndim): + if flip_dim[d]: + expected_out_coords[:, d] = 1.0 - in_coords[:, d] + np.testing.assert_allclose(out_coords[:, d], expected_out_coords[:, d]) + +def test_operator_coord_flip(): + for device in ['cpu']: + for batch_size in [1, 3]: + for layout, shape in [("xy", (10, 2)), ("xyz", (10, 3))]: + yield check_operator_coord_flip, device, batch_size, layout, shape + +def main(): + for test in test_operator_coord_flip(): + test[0](*test[1:]) + +if __name__ == '__main__': + main() From 8fbd4779606b81dff8198b5eb7ca8159eef36430 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Fri, 24 Apr 2020 09:20:36 +0200 Subject: [PATCH 2/6] Code review fixes Signed-off-by: Joaquin Anton --- dali/operators/coord/coord_flip.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dali/operators/coord/coord_flip.cc b/dali/operators/coord/coord_flip.cc index da05488c65..5d0a574a28 100644 --- a/dali/operators/coord/coord_flip.cc +++ b/dali/operators/coord/coord_flip.cc @@ -100,8 +100,10 @@ void CoordFlipCPU::RunImpl(workspace_t &ws) { int64_t i = 0; for (; i < in_size; i++, d++) { if (d == ndim_) d = 0; - assert(in[i] >= 0.0f && in[i] <= 1.0f); - out[i] = flip_dim[d] ? 1.0f - in[i] : in[i]; + auto in_val = in[i]; + DALI_ENFORCE(in_val >= 0.0f && in_val <= 1.0f, + "Input expected to be within the range [0.0, 1.0]"); + out[i] = flip_dim[d] ? 1.0f - in_val : in_val; } }); } From c2482b335d5e809365c0b2c36959842710c15131 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 27 Apr 2020 16:43:18 +0200 Subject: [PATCH 3/6] Add flip center argument Signed-off-by: Joaquin Anton --- dali/operators/bbox/bb_flip.cc | 6 ++-- dali/operators/coord/coord_flip.cc | 47 ++++++++++++------------------ dali/operators/coord/coord_flip.h | 17 +++++++++++ dali/operators/generic/flip.cc | 8 ++--- 4 files changed, 42 insertions(+), 36 deletions(-) diff --git a/dali/operators/bbox/bb_flip.cc b/dali/operators/bbox/bb_flip.cc index 86647c94ed..574e929ee0 100644 --- a/dali/operators/bbox/bb_flip.cc +++ b/dali/operators/bbox/bb_flip.cc @@ -26,7 +26,7 @@ const std::string kVerticalArgName = "vertical"; // NOLINT DALI_REGISTER_OPERATOR(BbFlip, BbFlip, CPU); DALI_SCHEMA(BbFlip) - .DocStr(R"code(Operator for horizontal flip (mirror) of bounding box. + .DocStr(R"code(Operator for horizontal or vertical flip (mirror) of bounding boxes. Input: Bounding box coordinates; in either `[x, y, w, h]` or `[left, top, right, bottom]` format. All coordinates are in the image coordinate system (i.e. 0.0-1.0))code") @@ -37,10 +37,10 @@ in the image coordinate system (i.e. 0.0-1.0))code") False for for width-height representation.)code", false, false) .AddOptionalArg(kHorizontalArgName, - R"code(Perform flip along horizontal axis.)code", + R"code(Flip horizontal dimension.)code", 1, true) .AddOptionalArg(kVerticalArgName, - R"code(Perform flip along vertical axis.)code", + R"code(Flip vertical dimension.)code", 0, true); BbFlip::BbFlip(const dali::OpSpec &spec) diff --git a/dali/operators/coord/coord_flip.cc b/dali/operators/coord/coord_flip.cc index 5d0a574a28..cccbc7664b 100644 --- a/dali/operators/coord/coord_flip.cc +++ b/dali/operators/coord/coord_flip.cc @@ -18,8 +18,8 @@ namespace dali { DALI_SCHEMA(CoordFlip) .DocStr( - R"code(Transforms normalized coordinates (range [0.0, 1.0]) so that they map to the same place after -horizontal or/and vertical flip of the input they refer to.)code") + R"code(Transforms coordinates so that they are flipped (point reflected) with respect +to a center point.)code") .NumInput(1) .NumOutput(1) .AddOptionalArg( @@ -32,9 +32,13 @@ horizontal or/and vertical flip of the input they refer to.)code") Note: If left empty, ``"xy"`` or ``"xyz"`` will be assumed, depending on the number of dimensions. )code", TensorLayout{""}) - .AddOptionalArg("horizontal", R"code(Perform flip along horizontal axis.)code", 1, true) - .AddOptionalArg("vertical", R"code(Perform flip along vertical axis.)code", 0, true) - .AddOptionalArg("depthwise", R"code(Perform flip along depthwise axis.)code", 0, true); + .AddOptionalArg("horizontal", R"code(Flip horizontal dimension.)code", 1, true) + .AddOptionalArg("vertical", R"code(Flip vertical dimension.)code", 0, true) + .AddOptionalArg("depthwise", R"code(Flip depthwise dimension.)code", 0, true) + .AddOptionalArg("center_x", R"code(Flip center on horizontal dimension.)code", 0.5f, true) + .AddOptionalArg("center_y", R"code(Flip center on vertical dimension.)code", 0.5f, true) + .AddOptionalArg("center_z", R"code(Flip center on depthwise dimension.)code", 0.5f, true); + class CoordFlipCPU : public CoordFlip { public: @@ -53,15 +57,9 @@ class CoordFlipCPU : public CoordFlip { void CoordFlipCPU::RunImpl(workspace_t &ws) { const auto &input = ws.InputRef(0); - DALI_ENFORCE(input.type().id() == DALI_FLOAT, "Input is expected to be float"); - auto &output = ws.OutputRef(0); auto &thread_pool = ws.GetThreadPool(); - if (layout_.empty()) { - layout_ = ndim_ == 2 ? "xy" : "xyz"; - } - int x_dim = layout_.find('x'); DALI_ENFORCE(x_dim >= 0, "Dimension \"x\" not found in the layout"); @@ -74,25 +72,18 @@ void CoordFlipCPU::RunImpl(workspace_t &ws) { DALI_ENFORCE(z_dim >= 0, "Dimension \"z\" not found in the layout"); for (int sample_id = 0; sample_id < batch_size_; sample_id++) { - bool horizontal_flip = spec_.GetArgument("horizontal", &ws, sample_id); - bool vertical_flip = spec_.GetArgument("vertical", &ws, sample_id); - bool depthwise_flip = spec_.GetArgument("depthwise", &ws, sample_id); std::array flip_dim = {false, false, false}; + flip_dim[x_dim] = spec_.GetArgument("horizontal", &ws, sample_id); + flip_dim[y_dim] = spec_.GetArgument("vertical", &ws, sample_id); + flip_dim[z_dim] = spec_.GetArgument("depthwise", &ws, sample_id); - if (horizontal_flip) { - flip_dim[x_dim] = horizontal_flip; - } - - if (vertical_flip) { - flip_dim[y_dim] = vertical_flip; - } - - if (depthwise_flip) { - flip_dim[z_dim] = depthwise_flip; - } + std::array mirrored_origin = {1.0f, 1.0f, 1.0f}; + mirrored_origin[x_dim] = 2.0f * spec_.GetArgument("center_x", &ws, sample_id); + mirrored_origin[y_dim] = 2.0f * spec_.GetArgument("center_y", &ws, sample_id); + mirrored_origin[z_dim] = 2.0f * spec_.GetArgument("center_z", &ws, sample_id); thread_pool.DoWorkWithID( - [this, &input, &output, sample_id, flip_dim](int thread_id) { + [this, &input, &output, sample_id, flip_dim, mirrored_origin](int thread_id) { const auto *in = input[sample_id].data(); auto *out = output[sample_id].mutable_data(); auto in_size = volume(input[sample_id].shape()); @@ -101,9 +92,7 @@ void CoordFlipCPU::RunImpl(workspace_t &ws) { for (; i < in_size; i++, d++) { if (d == ndim_) d = 0; auto in_val = in[i]; - DALI_ENFORCE(in_val >= 0.0f && in_val <= 1.0f, - "Input expected to be within the range [0.0, 1.0]"); - out[i] = flip_dim[d] ? 1.0f - in_val : in_val; + out[i] = flip_dim[d] ? mirrored_origin[d] - in_val : in_val; } }); } diff --git a/dali/operators/coord/coord_flip.h b/dali/operators/coord/coord_flip.h index 2044a3f8c5..780c27070a 100644 --- a/dali/operators/coord/coord_flip.h +++ b/dali/operators/coord/coord_flip.h @@ -40,6 +40,8 @@ class CoordFlip : public Operator { bool SetupImpl(std::vector &output_desc, const workspace_t &ws) override { const auto &input = ws.template InputRef(0); + DALI_ENFORCE(input.type().id() == DALI_FLOAT, "Input is expected to be float"); + output_desc.resize(1); auto in_shape = input.shape(); output_desc[0].shape = in_shape; @@ -48,6 +50,21 @@ class CoordFlip : public Operator { DALI_ENFORCE(in_shape[0].size() == 2); ndim_ = in_shape[0][1]; DALI_ENFORCE(ndim_ >= 1 && ndim_ <= 3, make_string("Unexpected number of dimensions ", ndim_)); + + if (layout_.empty()) { + switch (ndim_) { + case 1: + layout_ = "x"; + break; + case 2: + layout_ = "xy"; + break; + case 3: + default: + layout_ = "xyz"; + break; + } + } return true; } diff --git a/dali/operators/generic/flip.cc b/dali/operators/generic/flip.cc index 245e7f4d28..8eebdd5cf2 100644 --- a/dali/operators/generic/flip.cc +++ b/dali/operators/generic/flip.cc @@ -23,12 +23,12 @@ namespace dali { DALI_SCHEMA(Flip) - .DocStr(R"code(Flip the image over the horizontal and/or vertical axes.)code") + .DocStr(R"code(Flip selected dimensions (horizontal, vertical, depthwise).)code") .NumInput(1) .NumOutput(1) - .AddOptionalArg("horizontal", R"code(Perform a horizontal flip.)code", 1, true) - .AddOptionalArg("vertical", R"code(Perform a vertical flip.)code", 0, true) - .AddOptionalArg("depthwise", R"code(Perform a depthwise flip.)code", 0, true) + .AddOptionalArg("horizontal", R"code(Flip horizontal dimension.)code", 1, true) + .AddOptionalArg("vertical", R"code(Flip vertical dimension.)code", 0, true) + .AddOptionalArg("depthwise", R"code(Flip depthwise dimension.)code", 0, true) .InputLayout({"FDHWC", "FHWC", "DHWC", "HWC", "FCDHW", "FCHW", "CDHW", "CHW"}) .AllowSequences() .SupportVolumetric(); From 49075820486335167f8a667d6f0df6279e58bed4 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Mon, 27 Apr 2020 16:56:46 +0200 Subject: [PATCH 4/6] Code review fixes Signed-off-by: Joaquin Anton --- dali/operators/coord/coord_flip.cc | 21 +++++----- dali/test/python/test_operator_coord_flip.py | 40 +++++++++++--------- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/dali/operators/coord/coord_flip.cc b/dali/operators/coord/coord_flip.cc index cccbc7664b..829de9576a 100644 --- a/dali/operators/coord/coord_flip.cc +++ b/dali/operators/coord/coord_flip.cc @@ -24,17 +24,18 @@ to a center point.)code") .NumOutput(1) .AddOptionalArg( "layout", - R"code(Determines the layout of the coordinates. - Possible values are: + R"code(Determines the layout of the coordinates. Possible values are: - ``x`` (horizontal position), ``y`` (vertical position), ``z`` (depthwise position), + - ``x`` (horizontal position), + - ``y`` (vertical position), + - ``z`` (depthwise position), -Note: If left empty, ``"xy"`` or ``"xyz"`` will be assumed, depending on the number of dimensions. +Note: If left empty, ``"x"``, ``"xy"`` or ``"xyz"`` will be assumed, depending on the number of dimensions. )code", TensorLayout{""}) - .AddOptionalArg("horizontal", R"code(Flip horizontal dimension.)code", 1, true) - .AddOptionalArg("vertical", R"code(Flip vertical dimension.)code", 0, true) - .AddOptionalArg("depthwise", R"code(Flip depthwise dimension.)code", 0, true) + .AddOptionalArg("flip_x", R"code(Flip horizontal (x) dimension.)code", 1, true) + .AddOptionalArg("flip_y", R"code(Flip vertical (y) dimension.)code", 0, true) + .AddOptionalArg("flip_z", R"code(Flip depthwise (z) dimension.)code", 0, true) .AddOptionalArg("center_x", R"code(Flip center on horizontal dimension.)code", 0.5f, true) .AddOptionalArg("center_y", R"code(Flip center on vertical dimension.)code", 0.5f, true) .AddOptionalArg("center_z", R"code(Flip center on depthwise dimension.)code", 0.5f, true); @@ -73,9 +74,9 @@ void CoordFlipCPU::RunImpl(workspace_t &ws) { for (int sample_id = 0; sample_id < batch_size_; sample_id++) { std::array flip_dim = {false, false, false}; - flip_dim[x_dim] = spec_.GetArgument("horizontal", &ws, sample_id); - flip_dim[y_dim] = spec_.GetArgument("vertical", &ws, sample_id); - flip_dim[z_dim] = spec_.GetArgument("depthwise", &ws, sample_id); + flip_dim[x_dim] = spec_.GetArgument("flip_x", &ws, sample_id); + flip_dim[y_dim] = spec_.GetArgument("flip_y", &ws, sample_id); + flip_dim[z_dim] = spec_.GetArgument("flip_z", &ws, sample_id); std::array mirrored_origin = {1.0f, 1.0f, 1.0f}; mirrored_origin[x_dim] = 2.0f * spec_.GetArgument("center_x", &ws, sample_id); diff --git a/dali/test/python/test_operator_coord_flip.py b/dali/test/python/test_operator_coord_flip.py index 62f1f60526..cc95c7352f 100644 --- a/dali/test/python/test_operator_coord_flip.py +++ b/dali/test/python/test_operator_coord_flip.py @@ -33,21 +33,21 @@ def __init__(self, device, batch_size, iterator, layout, self.device = device self.iterator = iterator self.coord_flip = ops.CoordFlip(device = self.device, layout=layout) - self.flip_h = ops.CoinFlip(probability = 0.5) - self.flip_v = ops.CoinFlip(probability = 0.5) - self.flip_d = ops.CoinFlip(probability = 0.5) if len(layout) == 3 else None + self.flip_x = ops.CoinFlip(probability = 0.5) + self.flip_y = ops.CoinFlip(probability = 0.5) + self.flip_z = ops.CoinFlip(probability = 0.5) if len(layout) == 3 else None def define_graph(self): inputs = fn.external_source(lambda: next(self.iterator)) inputs = 0.5 + inputs # Make it fit the range [0.0, 1.0] out = inputs.gpu() if self.device == 'gpu' else inputs - h = self.flip_h() - v = self.flip_v() - d = self.flip_d() if self.flip_d is not None else None - out = self.coord_flip(out, horizontal=h, vertical=v, depthwise=d) - outputs = [inputs, out, h, v] - if d is not None: - outputs.append(d) + flip_x = self.flip_x() + flip_y = self.flip_y() + flip_z = self.flip_z() if self.flip_z is not None else None + out = self.coord_flip(out, flip_x=flip_x, flip_y=flip_y, flip_z=flip_z) + outputs = [inputs, out, flip_x, flip_y] + if flip_z is not None: + outputs.append(flip_z) return outputs def check_operator_coord_flip(device, batch_size, layout, shape): @@ -58,17 +58,21 @@ def check_operator_coord_flip(device, batch_size, layout, shape): outputs = pipe.run() for sample in range(batch_size): in_coords = outputs[0].at(sample) - out_coords = outputs[1].at(sample) - h = outputs[2].at(sample) - v = outputs[3].at(sample) - d = None + out_coords = outputs[1].as_cpu().at(sample) if device == 'gpu' else outputs[1].at(sample) + if in_coords.shape == (): + assert(out_coords.shape == ()) + continue + + flip_x = outputs[2].at(sample) + flip_y = outputs[3].at(sample) + flip_z = None if len(layout) == 3: - d = outputs[4].at(sample) + flip_z = outputs[4].at(sample) npoints, ndim = in_coords.shape - flip_dim = [h[0], v[0]] + flip_dim = [flip_x[0], flip_y[0]] if ndim == 3: - flip_dim.append(d[0]) + flip_dim.append(flip_z[0]) expected_out_coords = np.copy(in_coords) for d in range(ndim): @@ -79,7 +83,7 @@ def check_operator_coord_flip(device, batch_size, layout, shape): def test_operator_coord_flip(): for device in ['cpu']: for batch_size in [1, 3]: - for layout, shape in [("xy", (10, 2)), ("xyz", (10, 3))]: + for layout, shape in [("x", (10, 1)), ("xy", (10, 2)), ("xyz", (10, 3)), ("xy", (0, 2))]: yield check_operator_coord_flip, device, batch_size, layout, shape def main(): From 5347fc0cd88b399d2395893a44c6f23e6627d157 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 28 Apr 2020 10:43:43 +0200 Subject: [PATCH 5/6] Add tests for custom coord flip center Signed-off-by: Joaquin Anton --- dali/test/python/test_operator_coord_flip.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/dali/test/python/test_operator_coord_flip.py b/dali/test/python/test_operator_coord_flip.py index cc95c7352f..33bfefe1d3 100644 --- a/dali/test/python/test_operator_coord_flip.py +++ b/dali/test/python/test_operator_coord_flip.py @@ -28,11 +28,13 @@ class CoordFlipPipeline(Pipeline): def __init__(self, device, batch_size, iterator, layout, + center_x = None, center_y = None, center_z = None, num_threads=1, device_id=0): super(CoordFlipPipeline, self).__init__(batch_size, num_threads, device_id) self.device = device self.iterator = iterator - self.coord_flip = ops.CoordFlip(device = self.device, layout=layout) + self.coord_flip = ops.CoordFlip(device = self.device, layout=layout, + center_x=center_x, center_y=center_y, center_z=center_z) self.flip_x = ops.CoinFlip(probability = 0.5) self.flip_y = ops.CoinFlip(probability = 0.5) self.flip_z = ops.CoinFlip(probability = 0.5) if len(layout) == 3 else None @@ -50,9 +52,10 @@ def define_graph(self): outputs.append(flip_z) return outputs -def check_operator_coord_flip(device, batch_size, layout, shape): +def check_operator_coord_flip(device, batch_size, layout, shape, center_x, center_y, center_z): eii1 = RandomDataIterator(batch_size, shape=shape, dtype=np.float32) - pipe = CoordFlipPipeline(device, batch_size, iter(eii1), layout) + pipe = CoordFlipPipeline(device, batch_size, iter(eii1), + layout, center_x, center_y, center_z) pipe.build() for i in range(30): outputs = pipe.run() @@ -74,17 +77,22 @@ def check_operator_coord_flip(device, batch_size, layout, shape): if ndim == 3: flip_dim.append(flip_z[0]) + center_dim = [center_x, center_y] + if ndim == 3: + center_dim.append(center_z) + expected_out_coords = np.copy(in_coords) for d in range(ndim): if flip_dim[d]: - expected_out_coords[:, d] = 1.0 - in_coords[:, d] + expected_out_coords[:, d] = 2 * center_dim[d] - in_coords[:, d] np.testing.assert_allclose(out_coords[:, d], expected_out_coords[:, d]) def test_operator_coord_flip(): for device in ['cpu']: for batch_size in [1, 3]: for layout, shape in [("x", (10, 1)), ("xy", (10, 2)), ("xyz", (10, 3)), ("xy", (0, 2))]: - yield check_operator_coord_flip, device, batch_size, layout, shape + for center_x, center_y, center_z in [(0.5, 0.5, 0.5), (0.0, 1.0, -0.5)]: + yield check_operator_coord_flip, device, batch_size, layout, shape, center_x, center_y, center_z def main(): for test in test_operator_coord_flip(): From 70de8280b1190f76a882a2edb00b6cb41ac7d260 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Tue, 28 Apr 2020 11:52:18 +0200 Subject: [PATCH 6/6] fixes Signed-off-by: Joaquin Anton --- dali/operators/coord/coord_flip.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dali/operators/coord/coord_flip.cc b/dali/operators/coord/coord_flip.cc index 829de9576a..ece77ba237 100644 --- a/dali/operators/coord/coord_flip.cc +++ b/dali/operators/coord/coord_flip.cc @@ -64,13 +64,16 @@ void CoordFlipCPU::RunImpl(workspace_t &ws) { int x_dim = layout_.find('x'); DALI_ENFORCE(x_dim >= 0, "Dimension \"x\" not found in the layout"); - int y_dim = layout_.find('y'); - if (ndim_ > 1) + int y_dim = 1; + if (ndim_ > 1) { + y_dim = layout_.find('y'); DALI_ENFORCE(y_dim >= 0, "Dimension \"y\" not found in the layout"); - - int z_dim = layout_.find('z'); - if (ndim_ > 2) + } + int z_dim = 2; + if (ndim_ > 2) { + z_dim = layout_.find('z'); DALI_ENFORCE(z_dim >= 0, "Dimension \"z\" not found in the layout"); + } for (int sample_id = 0; sample_id < batch_size_; sample_id++) { std::array flip_dim = {false, false, false};