-
Notifications
You must be signed in to change notification settings - Fork 618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CoordFlip CPU operator #1894
Changes from all commits
035bd81
8fbd477
c2482b3
4907582
5347fc0
70de828
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
collect_headers(DALI_INST_HDRS PARENT_SCOPE) | ||
collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) | ||
collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "dali/operators/coord/coord_flip.h"

#include <array>
|
||
namespace dali { | ||
|
||
// Operator schema: one input, one output, an optional layout and per-axis flip / center
// arguments. The trailing `true` on the flip_* / center_* arguments is passed exactly as
// before; the CPU implementation reads these arguments once per sample.
DALI_SCHEMA(CoordFlip)
    .DocStr(
        R"code(Transforms coordinates so that they are flipped (mirrored) with respect
to a center point.

Each axis is flipped independently: a coordinate ``c`` along a flipped axis becomes
``2 * center - c``, while coordinates along the remaining axes are left unchanged.

The input is expected to be a 2D ``float`` tensor of shape ``(N, D)``, that is ``N`` points
with ``D`` coordinates each, where ``D`` is 1, 2 or 3.)code")
    .NumInput(1)
    .NumOutput(1)
    .AddOptionalArg<TensorLayout>(
        "layout",
        R"code(Determines the layout of the coordinates. Possible values are:

- ``x`` (horizontal position),
- ``y`` (vertical position),
- ``z`` (depthwise position),

Note: If left empty, ``"x"``, ``"xy"`` or ``"xyz"`` will be assumed, depending on the number of dimensions.
)code",
        TensorLayout{""})
    .AddOptionalArg("flip_x", R"code(Flip horizontal (x) dimension.)code", 1, true)
    .AddOptionalArg("flip_y", R"code(Flip vertical (y) dimension.)code", 0, true)
    .AddOptionalArg("flip_z", R"code(Flip depthwise (z) dimension.)code", 0, true)
    .AddOptionalArg("center_x", R"code(Flip center on horizontal dimension.)code", 0.5f, true)
    .AddOptionalArg("center_y", R"code(Flip center on vertical dimension.)code", 0.5f, true)
    .AddOptionalArg("center_z", R"code(Flip center on depthwise dimension.)code", 0.5f, true);
|
||
|
||
/**
 * @brief CPU backend of the CoordFlip operator.
 *
 * Construction and output setup are inherited from CoordFlip<CPUBackend>;
 * this class only contributes the CPU RunImpl, which is defined out of line.
 */
class CoordFlipCPU : public CoordFlip<CPUBackend> {
 public:
  explicit CoordFlipCPU(const OpSpec &spec) : CoordFlip<CPUBackend>(spec) {}
  ~CoordFlipCPU() override = default;
  DISABLE_COPY_MOVE_ASSIGN(CoordFlipCPU);

  // Names pulled in from the base classes.
  USE_OPERATOR_MEMBERS();
  using Operator<CPUBackend>::RunImpl;
  using CoordFlip<CPUBackend>::layout_;

  // Flips the coordinates of every sample in the batch.
  void RunImpl(workspace_t<CPUBackend> &ws) override;
};
|
||
/**
 * @brief Flips the coordinates of every sample according to the per-sample arguments.
 *
 * For each sample the `flip_x/y/z` and `center_x/y/z` arguments are read and one task
 * is submitted to the workspace thread pool. The task walks the flat sample data, where
 * every consecutive group of `ndim_` values is one point, and replaces a value on a
 * flipped axis with `2 * center - value`.
 *
 * @param ws Workspace providing the input, the output and the thread pool.
 *
 * Fails (DALI_ENFORCE) when an axis required by `ndim_` is missing from the layout or
 * is placed at a position that does not exist in `ndim_`-dimensional coordinates.
 */
void CoordFlipCPU::RunImpl(workspace_t<CPUBackend> &ws) {
  const auto &input = ws.InputRef<CPUBackend>(0);
  auto &output = ws.OutputRef<CPUBackend>(0);
  auto &thread_pool = ws.GetThreadPool();

  const int x_dim = layout_.find('x');
  DALI_ENFORCE(x_dim >= 0, "Dimension \"x\" not found in the layout");

  // -1 marks an axis that does not exist for this dimensionality; such an axis is
  // skipped below instead of being mapped to a placeholder slot.
  int y_dim = -1;
  if (ndim_ > 1) {
    y_dim = layout_.find('y');
    DALI_ENFORCE(y_dim >= 0, "Dimension \"y\" not found in the layout");
  }
  int z_dim = -1;
  if (ndim_ > 2) {
    z_dim = layout_.find('z');
    DALI_ENFORCE(z_dim >= 0, "Dimension \"z\" not found in the layout");
  }

  // The element loop only visits positions [0, ndim_). An axis located beyond that
  // range (e.g. layout "yx" with 1-D data) would silently never be flipped.
  DALI_ENFORCE(x_dim < ndim_ && y_dim < ndim_ && z_dim < ndim_,
               make_string("The layout does not match the number of coordinate dimensions (",
                           ndim_, ")"));

  for (int sample_id = 0; sample_id < batch_size_; sample_id++) {
    std::array<bool, 3> flip_dim = {false, false, false};
    // mirrored_origin[d] holds 2 * center, so a flipped value is mirrored_origin[d] - value.
    std::array<float, 3> mirrored_origin = {1.0f, 1.0f, 1.0f};

    flip_dim[x_dim] = spec_.GetArgument<int>("flip_x", &ws, sample_id) != 0;
    mirrored_origin[x_dim] = 2.0f * spec_.GetArgument<float>("center_x", &ws, sample_id);
    if (y_dim >= 0) {
      flip_dim[y_dim] = spec_.GetArgument<int>("flip_y", &ws, sample_id) != 0;
      mirrored_origin[y_dim] = 2.0f * spec_.GetArgument<float>("center_y", &ws, sample_id);
    }
    if (z_dim >= 0) {
      flip_dim[z_dim] = spec_.GetArgument<int>("flip_z", &ws, sample_id) != 0;
      mirrored_origin[z_dim] = 2.0f * spec_.GetArgument<float>("center_z", &ws, sample_id);
    }

    thread_pool.DoWorkWithID(
        [this, &input, &output, sample_id, flip_dim, mirrored_origin](int thread_id) {
          const auto *in = input[sample_id].data<float>();
          auto *out = output[sample_id].mutable_data<float>();
          const auto num_values = volume(input[sample_id].shape());
          // `dim` cycles through 0 .. ndim_ - 1 and tracks the axis of element `idx`.
          int dim = 0;
          for (int64_t idx = 0; idx < num_values; idx++, dim++) {
            if (dim == ndim_) dim = 0;
            const auto in_val = in[idx];
            out[idx] = flip_dim[dim] ? mirrored_origin[dim] - in_val : in_val;
          }
        });
  }
  thread_pool.WaitForWork();
}
|
||
DALI_REGISTER_OPERATOR(CoordFlip, CoordFlipCPU, CPU); | ||
|
||
} // namespace dali |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_OPERATORS_COORD_COORD_FLIP_H_ | ||
#define DALI_OPERATORS_COORD_COORD_FLIP_H_ | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "dali/pipeline/operator/common.h" | ||
#include "dali/pipeline/operator/operator.h" | ||
|
||
namespace dali { | ||
|
||
template <typename Backend> | ||
class CoordFlip : public Operator<Backend> { | ||
public: | ||
explicit CoordFlip(const OpSpec &spec) | ||
: Operator<Backend>(spec) | ||
, layout_(spec.GetArgument<TensorLayout>("layout")) {} | ||
|
||
~CoordFlip() override = default; | ||
DISABLE_COPY_MOVE_ASSIGN(CoordFlip); | ||
|
||
protected: | ||
bool CanInferOutputs() const override { | ||
return true; | ||
} | ||
|
||
bool SetupImpl(std::vector<OutputDesc> &output_desc, const workspace_t<Backend> &ws) override { | ||
const auto &input = ws.template InputRef<Backend>(0); | ||
DALI_ENFORCE(input.type().id() == DALI_FLOAT, "Input is expected to be float"); | ||
|
||
output_desc.resize(1); | ||
auto in_shape = input.shape(); | ||
output_desc[0].shape = in_shape; | ||
output_desc[0].type = input.type(); | ||
|
||
DALI_ENFORCE(in_shape[0].size() == 2); | ||
ndim_ = in_shape[0][1]; | ||
DALI_ENFORCE(ndim_ >= 1 && ndim_ <= 3, make_string("Unexpected number of dimensions ", ndim_)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It could work for 0-dim, right? Just return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what would be the meaning of a 0D coordinate? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will never be 0. Hovever, it's going to be possible to have |
||
|
||
if (layout_.empty()) { | ||
switch (ndim_) { | ||
case 1: | ||
layout_ = "x"; | ||
break; | ||
case 2: | ||
layout_ = "xy"; | ||
break; | ||
case 3: | ||
default: | ||
layout_ = "xyz"; | ||
break; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
TensorLayout layout_; | ||
int ndim_; | ||
}; | ||
|
||
} // namespace dali | ||
|
||
#endif // DALI_OPERATORS_COORD_COORD_FLIP_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nvidia.dali.pipeline import Pipeline | ||
import nvidia.dali.ops as ops | ||
import nvidia.dali.fn as fn | ||
import nvidia.dali.types as types | ||
import nvidia.dali as dali | ||
import numpy as np | ||
from numpy.testing import assert_array_equal, assert_allclose | ||
from functools import partial | ||
from test_utils import check_batch | ||
from test_utils import compare_pipelines | ||
from test_utils import RandomDataIterator | ||
import math | ||
from nose.tools import * | ||
|
||
class CoordFlipPipeline(Pipeline):
    """Test pipeline that feeds external coordinates through ``CoordFlip``.

    The flip decisions come from ``CoinFlip`` operators (probability 0.5). A z-axis
    coin flip is created only when ``layout`` has three characters. The pipeline
    returns the shifted inputs, the flipped coordinates and the flip decisions.
    """

    def __init__(self, device, batch_size, iterator, layout,
                 center_x=None, center_y=None, center_z=None,
                 num_threads=1, device_id=0):
        super(CoordFlipPipeline, self).__init__(batch_size, num_threads, device_id)
        self.device = device
        self.iterator = iterator
        self.coord_flip = ops.CoordFlip(device=self.device, layout=layout,
                                        center_x=center_x, center_y=center_y,
                                        center_z=center_z)
        self.flip_x = ops.CoinFlip(probability=0.5)
        self.flip_y = ops.CoinFlip(probability=0.5)
        if len(layout) == 3:
            self.flip_z = ops.CoinFlip(probability=0.5)
        else:
            self.flip_z = None

    def define_graph(self):
        """Build the graph: external source -> shift by 0.5 -> CoordFlip."""
        raw = fn.external_source(lambda: next(self.iterator))
        inputs = 0.5 + raw  # Make it fit the range [0.0, 1.0]
        coords = inputs.gpu() if self.device == 'gpu' else inputs
        flip_x = self.flip_x()
        flip_y = self.flip_y()
        flip_z = None
        if self.flip_z is not None:
            flip_z = self.flip_z()
        flipped = self.coord_flip(coords, flip_x=flip_x, flip_y=flip_y, flip_z=flip_z)
        if flip_z is None:
            return [inputs, flipped, flip_x, flip_y]
        return [inputs, flipped, flip_x, flip_y, flip_z]
|
||
def check_operator_coord_flip(device, batch_size, layout, shape, center_x, center_y, center_z):
    """Run ``CoordFlipPipeline`` 30 times and compare its output with a NumPy reference.

    For every sample and every coordinate axis, the expected value is
    ``2 * center - coordinate`` when the axis was flipped and the unchanged
    coordinate otherwise.
    """
    data_source = RandomDataIterator(batch_size, shape=shape, dtype=np.float32)
    pipe = CoordFlipPipeline(device, batch_size, iter(data_source),
                             layout, center_x, center_y, center_z)
    pipe.build()
    has_flip_z = len(layout) == 3
    for _ in range(30):
        outputs = pipe.run()
        for sample in range(batch_size):
            in_coords = outputs[0].at(sample)
            if device == 'gpu':
                out_coords = outputs[1].as_cpu().at(sample)
            else:
                out_coords = outputs[1].at(sample)

            # Scalar-shaped samples carry no coordinates; only the shape is compared.
            if in_coords.shape == ():
                assert(out_coords.shape == ())
                continue

            flip_x = outputs[2].at(sample)
            flip_y = outputs[3].at(sample)
            flip_z = outputs[4].at(sample) if has_flip_z else None
            npoints, ndim = in_coords.shape

            flip_dim = [flip_x[0], flip_y[0]]
            if ndim == 3:
                flip_dim.append(flip_z[0])
            center_dim = [center_x, center_y]
            if ndim == 3:
                center_dim.append(center_z)

            expected_out_coords = np.copy(in_coords)
            for d in range(ndim):
                if flip_dim[d]:
                    expected_out_coords[:, d] = 2 * center_dim[d] - in_coords[:, d]
                np.testing.assert_allclose(out_coords[:, d], expected_out_coords[:, d])
|
||
def test_operator_coord_flip():
    """Yield one ``check_operator_coord_flip`` case per parameter combination.

    Combinations cover the CPU device, two batch sizes, 1-D / 2-D / 3-D coordinates,
    an input with zero points, and two sets of flip centers.
    """
    devices = ['cpu']
    batch_sizes = [1, 3]
    layouts_and_shapes = [("x", (10, 1)), ("xy", (10, 2)), ("xyz", (10, 3)), ("xy", (0, 2))]
    centers = [(0.5, 0.5, 0.5), (0.0, 1.0, -0.5)]
    for device in devices:
        for batch_size in batch_sizes:
            for layout, shape in layouts_and_shapes:
                for center_x, center_y, center_z in centers:
                    yield (check_operator_coord_flip, device, batch_size, layout, shape,
                           center_x, center_y, center_z)
|
||
def main():
    """Run every case produced by ``test_operator_coord_flip`` without a test runner."""
    for case in test_operator_coord_flip():
        check, args = case[0], case[1:]
        check(*args)
|
||
if __name__ == '__main__': | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`std::array` has no specialization for `bool`. Maybe we'd be better with `std::vector<bool>` here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vector of bool should be killed with fire and purged from existence.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just want 3 bools, no need for dynamic allocation