-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More reductions #2395
More reductions #2395
Conversation
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build |
CI MESSAGE: [1737280]: BUILD STARTED |
#define MEAN_SQUARE_TYPES_MAP ( \ | ||
((uint8_t), (uint8_t, float)), \ | ||
((int8_t), (int8_t, float)), \ | ||
((uint16_t), (uint16_t, float)), \ | ||
((int16_t), (int16_t, float)), \ | ||
((uint32_t), (uint32_t, float)), \ | ||
((int32_t), (int32_t, float)), \ | ||
((uint64_t), (uint64_t, float)), \ | ||
((int64_t), (int64_t, float)), \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calculating mean square and staying in a small output type is asking for trouble - by squaring we're doubling the number of bits required. We also can support both signed and unsigned outputs for signed input just by reinterpreting the data - that would complicate the matters a bit, but would cut the number of instantiations without sacrificing functionality (compared to either supporting both signed and unsigned outputs or not supporting one of the type flavors).
#define REDUCE_WITH_MEAN_TYPES_MAP ( \ | ||
((uint8_t), (uint8_t, float)), \ | ||
((int8_t), (int8_t, float)), \ | ||
((uint16_t), (uint16_t, float)), \ | ||
((int16_t), (int16_t, float)), \ | ||
((uint32_t), (uint32_t, float)), \ | ||
((int32_t), (int32_t, float)), \ | ||
((uint64_t), (uint64_t, float)), \ | ||
((int64_t), (int64_t, float)), \ | ||
((float), (float))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likewise - the output types for Standard Deviation and for Variance would differ (variance needs a 2x larger output type).
0) | ||
.AddParent("ReduceWithOutputType"); | ||
|
||
DALI_SCHEMA(reductions__Std) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: I don't like this capitalization - I think it should be StD, but it would require some special handling in _to_snake_case (it's too short to add to special case mapping with acceptable risk).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, changed to STD
.NumOutput(1) | ||
.AddParent("ReduceWithOutputType"); | ||
|
||
DALI_SCHEMA(reductions__RootMeanSquare) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DALI_SCHEMA(reductions__RootMeanSquare) | |
DALI_SCHEMA(reductions__RMS) |
This is a well established term in the field. Also, if we keep it uppercase here it will convert to well to "rms" in fn.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
CI MESSAGE: [1737280]: BUILD PASSED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build |
CI MESSAGE: [1740843]: BUILD STARTED |
CI MESSAGE: [1740843]: BUILD PASSED |
Why we need this PR?
What happened in this PR?
Added new ops: Mean, MeanSquare, RootMeanSquare, Std, Variance
Added tests for new ops
Docs in op schema
JIRA TASK: [DALI-1675, DALI-1676]