-
Notifications
You must be signed in to change notification settings - Fork 618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add squeeze operator #2792
Add squeeze operator #2792
Conversation
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
7cb1dc6
to
5f66bb0
Compare
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
dali/operators/generic/squeeze.cc
Outdated
namespace dali { | ||
|
||
DALI_SCHEMA(Squeeze) | ||
.DocStr(R"code(Collapses the dimensions given as axes or axis_names. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.DocStr(R"code(Collapses the dimensions given as axes or axis_names. | |
.DocStr(R"code(Removes the dimensions given as ``axes`` or ``axis_names``. |
"Collapse" could be misunderstood as "flatten" or "combine" - which, I believe, would make for a very useful operator, but it's not this one.
dali/operators/generic/squeeze.cc
Outdated
DALI_SCHEMA(Squeeze) | ||
.DocStr(R"code(Collapses the dimensions given as axes or axis_names. | ||
|
||
It's an error to collapse a dimension that would cause the total volume to change.)code") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's an error to collapse a dimension that would cause the total volume to change.)code") | |
It's an error to remove a dimension that would cause the total volume to change.)code") |
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
dali/operators/generic/squeeze.cc
Outdated
All indices must be in the range of valid dimensions of the input)code", std::vector<int>(), true) | ||
.AddOptionalArg("axis_names", R"code(Layout columns which should be removed. | ||
|
||
All squeezed dimensions should have size 1 unless the total volume of the tensor is 0 before and after squeeze. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All squeezed dimensions should have size 1 unless the total volume of the tensor is 0 before and after squeeze. | |
All squeezed dimensions should have size 1, unless the total volume of the tensor is 0 before and after squeeze. |
nitpick
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
dali/operators/generic/squeeze.cc
Outdated
DALI_ENFORCE(spec.HasArgument("axes") + spec.HasArgument("axis_names") == 1, | ||
make_string("Provided both axes and axis_names argument")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DALI_ENFORCE(spec.HasArgument("axes") + spec.HasArgument("axis_names") == 1, | |
make_string("Provided both axes and axis_names argument")); | |
DALI_ENFORCE(spec.HasArgument("axes") + spec.HasArgument("axis_names") == 1, | |
spec.HasArgument("axes") ? "Provided both ``axes`` and ``axis_names`` arguments" | |
: "Missing argument ``axes`` or ``axis_names``." ); |
dali/operators/generic/squeeze.cc
Outdated
this->SetOutputType(ws); | ||
|
||
GenerateSrcDims(ws); | ||
Reshape<Backend>::CalculateOutputShape(ws); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would invoke one from the base class in case you overrode it - and you wouldn't even know. Use this->
to enable dependent name lookup.
Reshape<Backend>::CalculateOutputShape(ws); | |
this->CalculateOutputShape(ws); |
dali/operators/generic/squeeze.cc
Outdated
DALI_ENFORCE(in_layout.size() == ndim || in_layout.empty(), | ||
make_string("Layout for data has ", | ||
in_layout.size(), " elements but data has ", ndim, " dimensions.")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DALI_ENFORCE(in_layout.size() == ndim || in_layout.empty(), | |
make_string("Layout for data has ", | |
in_layout.size(), " elements but data has ", ndim, " dimensions.")); |
This is checked in the executor.
dali/operators/generic/squeeze.h
Outdated
private: | ||
void GenerateSrcDims(const Workspace &ws); | ||
|
||
std::vector<int> axes_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::vector<int> axes_; | |
SmallVector<int, 6> axes_; |
This will avoid calling to_vector
and a needless dynamic allocation.
dali/operators/generic/squeeze.cc
Outdated
in_layout.size(), " elements but data has ", ndim, " dimensions.")); | ||
|
||
this->src_dims_.clear(); | ||
auto axes = axis_names_.empty() ? axes_ : GetDimIndices(in_layout, axis_names_).to_vector(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you convert axes_ to SmallVector, you can skip to_vector and avoid allocating a handful of numbers from the heap.
auto axes = axis_names_.empty() ? axes_ : GetDimIndices(in_layout, axis_names_).to_vector(); | |
auto axes = axis_names_.empty() ? axes_ : GetDimIndices(in_layout, axis_names_); |
dali/operators/generic/squeeze.cc
Outdated
this->src_dims_.clear(); | ||
auto axes = axis_names_.empty() ? axes_ : GetDimIndices(in_layout, axis_names_).to_vector(); | ||
std::sort(axes.begin(), axes.end()); | ||
axes.erase(std::unique(axes.begin(), axes.end()), axes.end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that specifying an axis multiple times should be an error, not something swept under the carpet.
Signed-off-by: Rafal Maj <rmaj@nvidia.com>
if (!in_layout.empty()) { | ||
out_layout += in_layout[d]; | ||
} | ||
} | ||
this->layout_ = out_layout; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather a nitpick:
if (!in_layout.empty()) { | |
out_layout += in_layout[d]; | |
} | |
} | |
this->layout_ = out_layout; | |
} | |
if (!in_layout.empty()) | |
this->layout_ = permute(in_layout, this->src_dims_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it does something different then what I meant. I will skip it.
!build |
CI MESSAGE: [2171679]: BUILD STARTED |
CI MESSAGE: [2171679]: BUILD FAILED |
!build |
CI MESSAGE: [2171766]: BUILD STARTED |
CI MESSAGE: [2171766]: BUILD PASSED |
Why we need this PR?
What happened in this PR?
NA
JIRA TASK: DALI-1851