-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CombineTransforms operator #2317
Changes from 38 commits
dc385a8
81d2ead
82938cf
6975415
67f7b70
3d731e1
3243dd6
1c66aab
d9e54a4
a293d14
f6fe233
966ae17
0702de7
17a370d
a2e0a26
3c0044e
0910809
1417f72
dd4db07
f3804e9
84200d5
4cf8245
03f94de
2af9f64
9848fdf
890bdd4
642e97a
b4c7b38
edc99e1
fa087a6
df60cd7
42ac1ce
d1ddbd8
d3c1bce
e1c65c5
3ae3dee
4cf13db
fb6e3d7
c1a369a
5324e43
964515e
9a08245
b55954b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
#include "dali/core/format.h" | ||
#include "dali/core/geom/mat.h" | ||
#include "dali/core/static_switch.h" | ||
#include "dali/kernels/kernel_manager.h" | ||
#include "dali/pipeline/data/types.h" | ||
#include "dali/pipeline/operator/op_spec.h" | ||
#include "dali/pipeline/workspace/workspace.h" | ||
#include "dali/pipeline/operator/operator.h" | ||
|
||
#define TRANSFORM_INPUT_TYPES (float) | ||
|
||
namespace dali { | ||
|
||
// Compile-time list of supported matrix dimensionalities, used to
// recursively dispatch RunImplTyped to a statically-sized implementation.
template <int... values>
using dims = std::integer_sequence<int, values...>;

// Square (ndim+1)x(ndim+1) matrix representing an affine transform in
// homogeneous (augmented) form; the last row is the constant (0, ..., 0, 1).
template <typename T, int mat_dim>
using affine_mat_t = mat<mat_dim, mat_dim, T>;
|
||
// Schema for transforms.Combine: takes 2..99 affine-transform inputs and
// produces one combined transform. NOTE(review): the "reverse_order" argument
// read by the operator is presumably declared by the TransformAttr parent
// schema - confirm against its definition.
DALI_SCHEMA(transforms__Combine)
  .DocStr(R"code(Combines two or more affine transforms.)code")
  .NumInput(2, 99)
  .NumOutput(1)
  .AddParent("TransformAttr");
|
||
/**
 * @brief CPU operator that combines N input affine transforms into one.
 *
 * Each input is a batch of (ndim, ndim+1) matrices. Per sample, the matrices
 * are promoted to homogeneous (ndim+1)x(ndim+1) form and multiplied together;
 * the order of multiplication is controlled by the "reverse_order" argument.
 */
class CombineTransformsCPU : public Operator<CPUBackend> {
 public:
  explicit CombineTransformsCPU(const OpSpec &spec) :
      Operator<CPUBackend>(spec),
      reverse_order_(spec.GetArgument<bool>("reverse_order")) {
  }

  bool CanInferOutputs() const override { return true; }

 protected:
  bool SetupImpl(std::vector<OutputDesc> &output_descs,
                 const workspace_t<CPUBackend> &ws) override {
    assert(ws.NumInput() > 1);
    TensorListShape<> in0_shape = ws.template InputRef<CPUBackend>(0).shape();

    // Validate *before* reading in0_shape[0]: with an empty batch the original
    // order would index sample 0 first and fail with an obscure error instead
    // of this message.
    DALI_ENFORCE(in0_shape.size() > 0 &&
                 in0_shape.sample_dim() == 2 &&
                 in0_shape[0][1] == (in0_shape[0][0] + 1),
                 make_string(
                   "The input, if provided, is expected to be a 2D tensor with dimensions "
                   "(ndim, ndim+1) representing an affine transform. Got: ", in0_shape));

    ndim_ = in0_shape[0][0];
    nsamples_ = in0_shape.size();

    // All inputs must match the first input's shape exactly (same batch size
    // and the same (ndim, ndim+1) per-sample shape).
    for (int i = 0; i < ws.NumInput(); i++) {
      const auto &shape = ws.template InputRef<CPUBackend>(i).shape();
      DALI_ENFORCE(shape == in0_shape,
        make_string("All input transforms are expected to have the same shape. Got: ",
                    in0_shape, " and ", shape, " for the ", i, "-th input."));
    }

    output_descs.resize(1);  // only one output
    output_descs[0].type = TypeTable::GetTypeInfo(dtype_);
    output_descs[0].shape = uniform_list_shape(nsamples_, {ndim_, ndim_ + 1});
    return true;
  }

  // Recursion terminator: none of the supported static dims matched ndim_.
  template <typename T>
  void RunImplTyped(workspace_t<CPUBackend> &ws, dims<>) {
    DALI_FAIL(make_string("Unsupported number of dimensions ", ndim_));
  }

  template <typename T, int ndim, int... ndims>
  void RunImplTyped(workspace_t<CPUBackend> &ws, dims<ndim, ndims...>) {
    if (ndim_ != ndim) {
      RunImplTyped<T>(ws, dims<ndims...>());
      return;
    }

    constexpr int mat_dim = ndim + 1;
    auto &out = ws.template OutputRef<CPUBackend>(0);
    out.SetLayout({});  // no layout
    auto out_view = view<T>(out);

    // Obtain each input's view once, up front. Creating the views inside the
    // per-sample loop (as before) caused O(nsamples * ninputs) view
    // constructions/allocations.
    std::vector<decltype(view<T>(ws.template InputRef<CPUBackend>(0)))> in_views;
    in_views.reserve(ws.NumInput());
    for (int input_idx = 0; input_idx < ws.NumInput(); input_idx++) {
      auto &in = ws.template InputRef<CPUBackend>(input_idx);
      in_views.push_back(view<T>(in));
    }

    // The last row of the homogeneous matrix is always (0, ..., 0, 1) and the
    // loops below only rewrite the top ndim rows, so `next_mat` can be created
    // once and reused across inputs and samples.
    auto next_mat = affine_mat_t<T, mat_dim>::identity();

    for (int sample_idx = 0; sample_idx < nsamples_; sample_idx++) {
      auto combined = affine_mat_t<T, mat_dim>::identity();
      for (const auto &in_view : in_views) {
        // Copy the (ndim, ndim+1) sample into the top rows of next_mat.
        for (int i = 0, k = 0; i < ndim; i++)
          for (int j = 0; j < ndim + 1; j++, k++)
            next_mat(i, j) = in_view[sample_idx].data[k];
        // Accumulate; "reverse_order" flips the side of the multiplication.
        combined = reverse_order_ ? combined * next_mat : next_mat * combined;
      }

      // Write back only the affine part (the constant last row is dropped).
      for (int i = 0, k = 0; i < ndim; i++) {
        for (int j = 0; j < ndim + 1; j++, k++) {
          out_view[sample_idx].data[k] = combined(i, j);
        }
      }
    }
  }

  void RunImpl(workspace_t<CPUBackend> &ws) override {
    TYPE_SWITCH(dtype_, type2id, T, TRANSFORM_INPUT_TYPES, (
      RunImplTyped<T>(ws, SupportedDims());
    ), DALI_FAIL(make_string("Unsupported data type: ", dtype_)));  // NOLINT
  }

 private:
  using SupportedDims = dims<1, 2, 3, 4, 5, 6>;
  DALIDataType dtype_ = DALI_FLOAT;
  int ndim_ = -1;  // will be inferred from the arguments or the input
  int nsamples_ = -1;
  bool reverse_order_ = false;
};
|
||
// Register the CPU implementation under the transforms.Combine schema name.
DALI_REGISTER_OPERATOR(transforms__Combine, CombineTransformsCPU, CPU);
|
||
} // namespace dali |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import sys | ||
import types | ||
|
||
def get_submodule(root, path):
    """Gets or creates submodule(s) of `root`.

    If the module path contains multiple parts, multiple modules are
    traversed or created.

    Parameters
    ----------
    `root`
        module object or name of the root module
    `path`
        period-separated path of the submodule or a list/tuple of submodule names

    Returns
    -------
    The innermost submodule; any missing intermediate modules are created
    and registered in ``sys.modules``.

    Raises
    ------
    RuntimeError
        If an existing attribute on the path is not a module.
    """

    if isinstance(root, str):
        root = sys.modules[root]

    if not path:
        return root

    if isinstance(path, str):
        # BUG FIX: the original wrote `if str == ''`, comparing the builtin
        # type `str` to '' (always False). The intended check is on `path`;
        # it is already covered by the `not path` guard above, but is kept
        # here (corrected) for clarity.
        if path == '':
            return root
        path = path.split('.')

    module_name = root.__name__
    for part in path:
        m = getattr(root, part, None)
        module_name += '.' + part
        if m is None:
            # Create the missing submodule and register it in sys.modules so
            # the regular import machinery can also find it.
            m = sys.modules[module_name] = types.ModuleType(module_name)
            setattr(root, part, m)
        elif not isinstance(m, types.ModuleType):
            raise RuntimeError(
                "The module {} already contains an attribute \"{}\", "
                "which is not a module, but {}".format(root, part, m))
        root = m
    return root
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not applicable anymore