-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Resampling decoder #1582
Resampling decoder #1582
Changes from all commits
44bf740
9968961
707cf7e
ef75c1b
337497f
379df4f
235a2f9
fd05452
8dc81b6
6623e11
9fa390f
04bcee3
4f1db59
03c214a
71e322c
db31a6e
dab07a4
526c6a9
979f5fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,135 @@ | ||||||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||||||
// | ||||||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
// you may not use this file except in compliance with the License. | ||||||
// You may obtain a copy of the License at | ||||||
// | ||||||
// http://www.apache.org/licenses/LICENSE-2.0 | ||||||
// | ||||||
// Unless required by applicable law or agreed to in writing, software | ||||||
// distributed under the License is distributed on an "AS IS" BASIS, | ||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
// See the License for the specific language governing permissions and | ||||||
// limitations under the License. | ||||||
|
||||||
#ifndef DALI_KERNELS_SIGNAL_DOWNMIXING_H_ | ||||||
#define DALI_KERNELS_SIGNAL_DOWNMIXING_H_ | ||||||
|
||||||
#include <cassert> | ||||||
#include <vector> | ||||||
#include "dali/core/convert.h" | ||||||
#include "dali/core/small_vector.h" | ||||||
#include "dali/core/span.h" | ||||||
#include "dali/core/static_switch.h" | ||||||
|
||||||
namespace dali { | ||||||
namespace kernels { | ||||||
namespace signal { | ||||||
|
||||||
/** | ||||||
* @brief Downmix interleaved signals to a single channel. | ||||||
* | ||||||
* @param out output buffer (single channel) | ||||||
* @param in input buffer (interleaved multiple channels) | ||||||
* @param num_samples number of samples in each channel | ||||||
* @param channels number of input channels | ||||||
* @param weights weights used for downmixing | ||||||
* @param normalize_weights if true, the weights are normalized so their sum is 1 | ||||||
* @tparam Out output sample type - if integral, the intermediate floating point representation | ||||||
* is stretched so that 0..1 or -1..1 range occupies the whole Out range. | ||||||
* @tparam In input sample type - if integral, it's normalized to 0..1 or -1..1 range | ||||||
* @tparam static_channels compile-time number of channels | ||||||
* | ||||||
* Downmix interleaved signals to a single channel, using the weights provided. | ||||||
* If `normalize_weights` is true, the weights are copied into intermediate buffer | ||||||
* and divided by their sum. | ||||||
* | ||||||
* @remarks The operation can be done in place if output and input are of the same type. | ||||||
*/ | ||||||
template <int static_channels = -1, typename Out, typename In> | ||||||
void DownmixChannels( | ||||||
Out *out, const In *in, int64_t samples, int channels, | ||||||
const float *weights, bool normalize_weights = false) { | ||||||
Comment on lines
+49
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the benefit of having channels statically specified in this context? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It helps the compiler unroll the loop over channels, opening new optimization opportunities. For 2D resamping this amounted to 20-30% speedup, I assumed it wouldn't be much different here.
Comment on lines
+49
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd add a point in the doc, that this can be done in-place There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not if you want to vectorize the loop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a C++ qualifier, unfortunately. The __restrict keyword is in C and in CUDA. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GCC and Clang supports There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Then is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not adding it here, but I will in resampling which is utterly impossible to do in-place. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, after a thought I'll add the in-place info - obviously with the caveat that Out and In should be of the same time - otherwise the compiler assumes no aliasing. |
||||||
SmallVector<float, 8> normalized_weights; // 8 channels should be enough for 7.1 audio | ||||||
static_assert(static_channels != 0, "Number of channels cannot be zero." | ||||||
"Use negative values to use run-time value"); | ||||||
int actual_channels = static_channels < 0 ? channels : static_channels; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about the case when |
||||||
assert(actual_channels == channels); | ||||||
assert(actual_channels > 0); | ||||||
if (normalize_weights) { | ||||||
double sum = 0; | ||||||
for (int i = 0; i < channels; i++) | ||||||
sum += weights[i]; | ||||||
normalized_weights.resize(channels); | ||||||
for (int i = 0; i < channels; i++) { | ||||||
normalized_weights[i] = weights[i] / sum; | ||||||
} | ||||||
weights = normalized_weights.data(); // use this pointer now | ||||||
} | ||||||
for (int64_t o = 0, i = 0; o < samples; o++, i += channels) { | ||||||
float sum = ConvertNorm<float>(in[i]) * weights[0]; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about
Suggested change
and start the loot at There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That way you get one more addition. Given the simplicity of the inner loop, it may actually matter. |
||||||
for (int c = 1; c < channels; c++) { | ||||||
sum += ConvertNorm<float>(in[i + c]) * weights[c]; | ||||||
} | ||||||
out[o] = ConvertSatNorm<Out>(sum); | ||||||
} | ||||||
} | ||||||
|
||||||
/** | ||||||
* @brief Downmix data to a single channel. | ||||||
* | ||||||
* @param out output buffer (single channel) | ||||||
* @param in input buffer (interleaved multiple channels) | ||||||
* @param num_samples number of samples in each channel | ||||||
* @param channels number of input channels | ||||||
* @param weights weights used for downmixing | ||||||
* @param normalize_weights if true, the weights are normalized so their sum is 1 | ||||||
* @tparam Out output sample type - if integral, the intermediate floating point representation | ||||||
* is stretched so that 0..1 or -1..1 range occupies the whole Out range. | ||||||
* @tparam In input sample type - if integral, it's normalized to 0..1 or -1..1 range | ||||||
* | ||||||
* Downmix interleaved signals to a single channel, using the weights provided. | ||||||
* If `normalize_weights` is true, the weights are copied into intermediate buffer | ||||||
* and divided by their sum. | ||||||
* | ||||||
* @remarks The operation can be done in place if output and input are of the same type. | ||||||
*/ | ||||||
template <typename Out, typename In> | ||||||
void Downmix( | ||||||
Out *out, const In *in, int64_t samples, int channels, | ||||||
const float *weights, bool normalize_weights = false) { | ||||||
VALUE_SWITCH(channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8), | ||||||
(DownmixChannels<static_channels>(out, in, samples, static_channels, | ||||||
weights, normalize_weights);), | ||||||
(DownmixChannels(out, in, samples, channels, weights, normalize_weights);) | ||||||
); // NOLINT | ||||||
} | ||||||
|
||||||
template <typename Out, typename In> | ||||||
void Downmix(Out *out, const In *in, int64_t num_samples, int num_channels) { | ||||||
SmallVector<float, 8> weights; | ||||||
weights.resize(num_channels, 1.0f / num_channels); | ||||||
Downmix(out, in, num_samples, num_channels, weights.data()); | ||||||
} | ||||||
|
||||||
|
||||||
template <typename Out, typename In> | ||||||
void Downmix(span<Out> out, span<const In> in, | ||||||
const std::vector<float> &weights, bool normalize_weights = false) { | ||||||
int num_channels = weights.size(); | ||||||
assert(in.size() % num_channels == 0); | ||||||
Downmix(out.data(), in.data(), in.size() / num_channels, weights, normalize_weights); | ||||||
} | ||||||
|
||||||
|
||||||
template <typename Out, typename In> | ||||||
void Downmix(span<Out> out, span<const In> in, int num_channels) { | ||||||
assert(in.size() % num_channels == 0); | ||||||
Downmix(out.data(), in.data(), in.size() / num_channels, num_channels); | ||||||
} | ||||||
|
||||||
} // namespace signal | ||||||
} // namespace kernels | ||||||
} // namespace dali | ||||||
|
||||||
#endif // DALI_KERNELS_SIGNAL_DOWNMIXING_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <gtest/gtest.h> | ||
#include <vector> | ||
#include <numeric> | ||
#include "dali/kernels/signal/downmixing.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
|
||
TEST(SignalDownmixingTest, RawPointer_Weighted) { | ||
std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; | ||
int nchannels = 3; | ||
std::vector<float> weights = {3, 2, 1}; | ||
float sum = std::accumulate(weights.begin(), weights.end(), 0); | ||
std::vector<float> ref = { | ||
(1 * 3 + 2 * 2 + 3) / sum, | ||
(4 * 3 + 5 * 2 + 6) / sum, | ||
(7 * 3 + 8 * 2 + 9) / sum, | ||
(10 * 3 + 11 * 2 + 12) / sum | ||
}; | ||
std::vector<float> out; | ||
out.resize(ref.size()); | ||
|
||
Downmix(out.data(), in.data(), in.size() / nchannels, nchannels, weights.data(), true); | ||
|
||
for (size_t i = 0; i < ref.size(); i++) { | ||
EXPECT_FLOAT_EQ(out[i], ref[i]); | ||
} | ||
} | ||
|
||
TEST(SignalDownmixingTest, Span_DefaultWeights) { | ||
std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; | ||
int nchannels = 3; | ||
std::vector<float> ref = {2, 5, 8, 11}; | ||
|
||
Downmix(make_span(in), make_cspan(in), nchannels); | ||
|
||
for (size_t i = 0; i < ref.size(); i++) { | ||
EXPECT_FLOAT_EQ(in[i], ref[i]); | ||
} | ||
} | ||
|
||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rename this file to
downmix.h