Resampling decoder #1582
3932c70
to
efe0d7a
Compare
list.append(np.array(bytearray(f.read()), np.uint8)) | ||
self.feed_input(self.raw_file, list) | ||
|
||
def test_decoded_vs_generated(): |
I would still compare with librosa to be aware how close we are just in case we need to debug some discrepancies.
OK, I've checked. Librosa does a bad thing here - namely, they do this:
- calculate output length as ceil(in_length * out_rate / in_rate)
- resample at a slightly different effective rate: in_rate * out_length / in_length
So, they're changing the sampling rate instead of doing what is asked of them. At least it looks like it - I calculated the "corrected" (or rather librosa-like-distorted) rate and generated the new reference signal. Librosa matches this one much better than the original "ground truth".
I don't think it's our goal to reproduce this particular bug - because it is a bug; among other things, it precludes resampling in variable length chunks, because each output chunk will be resampled at a potentially different rate.
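To make the effect concrete, here is a minimal sketch of the arithmetic described above, in plain Python. It does not call librosa; the function name `librosa_like_effective_rate` and the chunk lengths are made up for illustration, and the two formulas are the ones quoted in this comment.

```python
import math


def librosa_like_effective_rate(in_length, in_rate, out_rate):
    """Return (out_length, effective_rate) for the rounding scheme described above.

    Hypothetical helper for illustration only; it is not part of librosa or DALI.
    """
    # Output length is rounded up to a whole number of samples...
    out_length = math.ceil(in_length * out_rate / in_rate)
    # ...so the rate that is actually realized differs from the requested one.
    effective_rate = in_rate * out_length / in_length
    return out_length, effective_rate


# Two chunks of different length, both resampled from 16000 Hz to 12999 Hz.
long_chunk = librosa_like_effective_rate(10000, 16000, 12999)   # (8125, 13000.0)
short_chunk = librosa_like_effective_rate(1000, 16000, 12999)   # (813, 13008.0)
```

The two chunks end up at 13000 Hz and 13008 Hz instead of the requested 12999 Hz, which is why chunked resampling would not be seamless under this scheme.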
I can add some tests with generous epsilon - or abandon the weird rate of 12999 Hz, but I don't think it does much good.
Anyway, since there's a random playback speed distortion as an augmentation (and a ton of others), I don't think anyone will look into a 0.01% change in sampling rate vs some other library.
> OK, I've checked. Librosa does a bad thing here - namely, they do this:
> calculate output length as ceil(in_length * out_rate / in_rate)

Don't they have an option called `fix` to turn this off?
And maybe we should have a switch to support this. I believe they just wanted to make sure that the calculated number of output samples is an integer number.
We might want to keep the librosa comparison for just the typical conversions (22050 to 16000 and such).
I don't know what happens to the rate then, but the output length was different, so I couldn't compare the arrays directly.
!build |
CI MESSAGE: [1037901]: BUILD STARTED |
CI MESSAGE: [1037901]: BUILD PASSED |
include/dali/core/math_util.h
Outdated
DALI_HOST_DEV DALI_FORCEINLINE | ||
double sinc(double x) { | ||
x *= M_PI; | ||
if (std::abs(x) < 1e-10) | ||
return 1 - x * x * 0.25; // approximate by a parabola near the pole | ||
return std::sin(x) / x; | ||
} |
Two things here:
- There should be documentation specifying what x is (deg or rad)
- Out of curiosity, why this approximation?
- What really deserves a comment is that the function calculates the so-called "normalized sinc". That definition already covers what `x` is.
- For x == 0 there's a removable singularity, so we need a check anyway. For very small values of x, calculating sin(x)/x may lose more precision than using an approximation.

By the way, the approximation should be `1 - x * x * (1.0f / 6)` - and perhaps the limit 1e-10f should be set even closer to 0.
Checked: for double, around x = 1e-8 the approximation with 1-x^2/6 is more accurate than sin(x)/x. I used long double version as a reference.
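As a rough illustration of the point above, the following Python sketch implements a normalized sinc with a Taylor fallback near zero and compares the two candidate parabolas. It is a sketch, not the DALI implementation: the function name `normalized_sinc` and the default threshold of 1e-8 (taken from the double-precision figure mentioned in this thread) are assumptions, and the comparison point t = 1e-2 is chosen only because the difference is easy to see there in double precision.

```python
import math


def normalized_sinc(x, threshold=1e-8):
    """Normalized sinc, sin(pi * x) / (pi * x), with a parabola near zero.

    Illustrative sketch; the threshold is an assumed value for doubles.
    """
    t = math.pi * x
    if abs(t) < threshold:
        # Second-order Taylor expansion of sin(t) / t around t = 0.
        return 1.0 - t * t / 6.0
    return math.sin(t) / t


# Compare both parabolas against sin(t) / t at a point where they differ visibly.
t = 1e-2
exact = math.sin(t) / t
error_sixth = abs((1.0 - t * t / 6.0) - exact)     # on the order of t**4 / 120
error_quarter = abs((1.0 - t * t * 0.25) - exact)  # on the order of t**2 / 12
```

The `1 - t * t / 6` form follows from the series sin(t)/t = 1 - t²/6 + t⁴/120 - ..., so its error shrinks with t⁴, while the `0.25` coefficient leaves an error proportional to t².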
include/dali/core/math_util.h
Outdated
DALI_HOST_DEV DALI_FORCEINLINE | ||
float sinc(float x) { | ||
x *= M_PI; | ||
if (std::abs(x) < 1e-10f) | ||
return 1 - x * x * 0.25f; // approximate by a parabola near the pole | ||
return std::sin(x) / x; | ||
} |
As above, please specify what's x
int length; | ||
int sample_rate; /// [Hz] | ||
/// @brief Length, in samples, of the recording | ||
int64_t length; |
In DALI we tend to use "sample" to name a tensor in a batch. While here it's pretty obvious from the context that this is certainly not a tensor in a batch, I'd propose naming it "audio sample" to distinguish the two.
If you're talking about the comment, I'll replace it with "(multi-channel) samples", because that's what it is. In libsnd they use `frames` for multi-channel samples, but that's not that common, and even their own documentation uses "multi-channel samples" in some places.
template <int static_channels = -1, typename Out, typename In> | ||
void DownmixChannels( | ||
Out *out, const In *in, int64_t samples, int channels, | ||
const float *weights, bool normalize_weights = false) { |
What's the benefit of having channels statically specified in this context?
It helps the compiler unroll the loop over channels, opening new optimization opportunities. For 2D resampling this amounted to a 20-30% speedup; I assumed it wouldn't be much different here.
for (size_t i = 0; i < ref.size(); i++) { | ||
EXPECT_NEAR(out[i], ref[i], 1e-6); | ||
} |
How about using EXPECT_FLOAT_EQ?
Should be OK, I guess. This one is a tad faster and still the dynamic range of values is small enough for a fixed epsilon.
template <int static_channels = -1, typename Out, typename In> | ||
void DownmixChannels( | ||
Out *out, const In *in, int64_t samples, int channels, | ||
const float *weights, bool normalize_weights = false) { |
I'd add a point in the doc, that this can be done in-place
Not if you want to vectorize the loop.
Can we add a `restricted` qualifier to `in` and `out`?
This is not a C++ qualifier, unfortunately. The __restrict keyword is in C and in CUDA.
GCC and Clang support `__restrict__`; do you think we can use it, or do we want to stick to the standard?
> Not if you want to vectorize the loop.

Then is `Span_DefaultWeights` incorrect? It does downmixing in-place.
I'm not adding it here, but I will in resampling which is utterly impossible to do in-place.
OK, after some thought I'll add the in-place info - obviously with the caveat that Out and In should be of the same type - otherwise the compiler assumes no aliasing.
|
||
namespace dali { | ||
|
||
DALI_SCHEMA(AudioDecoder) | ||
.DocStr(R"code(Decode audio data. | ||
.DocStr(R"code(Decode audio data. | ||
This operator is a generic way of handling encoded data in DALI. | ||
It supports most of well-known audio formats (wav, flac, ogg). | ||
|
||
This operator produces two outputs: | ||
|
||
* output[0]: batch of decoded data |
* output[0]: batch of decoded data | |
* output[0]: batch of decoded data<br/> |
"Resampling quality, 0 is lowest, 100 is highest.\n\n" | ||
"0 corresponds to 3 lobes of the sinc filter;\n" | ||
"50 gives 16 lobes and 100 gives 64 lobes.", |
From my perspective, this:
> 0 corresponds to 3 lobes of the sinc filter; 50 gives 16 lobes and 100 gives 64 lobes.

is an implementation detail. I know that "quality" doesn't mean too much, but neither does the lobe count. Especially since it can change in the future, e.g. we can add another constraint to the "quality". I'd remove this and leave only the relative measure 0..100.
Actually, librosa has some quality enum/string:
"kaiser_best" - 64-zero crossings
"kaiser_fast" - 16-zero crossings
I just followed that example.
Please note that I'm not specifying the envelope/window for the sinc, so we still have some freedom here (e.g. we can switch to a Kaiser window when we find some implementation of Bessel functions in C++).
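For readers unfamiliar with the term, here is a small Python sketch of what a sinc filter limited to a given number of lobes could look like. The Hann window is an assumption made only for this illustration (the comment above deliberately leaves the window unspecified), and `windowed_sinc` is a hypothetical name, not a DALI function.

```python
import math


def windowed_sinc(x, lobes):
    """Sinc kernel truncated to `lobes` zero crossings on each side of 0.

    Illustrative sketch: the Hann window below is an assumed choice.
    """
    if abs(x) >= lobes:
        return 0.0  # outside the filter support
    t = math.pi * x
    sinc = 1.0 if t == 0.0 else math.sin(t) / t
    # Hann window spanning [-lobes, lobes]; equals 1 at x = 0 and 0 at the edges.
    window = 0.5 * (1.0 + math.cos(math.pi * x / lobes))
    return sinc * window


# A 3-lobe kernel sampled at integer offsets: 1 at the center, ~0 elsewhere.
kernel = [windowed_sinc(float(k), 3) for k in range(-3, 4)]
```

A larger lobe count widens the support of the kernel, which costs more multiply-adds per output sample but brings the filter closer to the ideal low-pass response.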
} | ||
|
||
std::vector<float> target_sample_rates_; | ||
kernels::signal::resampling::Resampler resampler_; |
Shouldn't the resampler_ be per-sample?
Not unless per-sample quality is needed - and I have quite a strong opinion that it won't be. I wouldn't waste memory and cache locality on that assumption unless we see at least a potential use case (i.e. if we find it in an actual network or at least in offline augmentation for that network).
dali/kernels/signal/resampling.h
Outdated
* The function can seamlessly resample the input and produce the result in chunks. | ||
* To reuse memory and still simulate chunk processing, adjust the in/out pointers. | ||
*/ | ||
template <typename Out> |
Have you considered the static definition of a number of channels? Would it help with the perf?
It would. I can do the VALUE_SWITCH trick here, no problem.
!build |
CI MESSAGE: [1040143]: BUILD STARTED |
An aside on the tests: I recently found out that gtest discourages using underscores in test names: |
CI MESSAGE: [1040143]: BUILD PASSED |
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. |
I'd rename this file to downmix.h
Out *out, const In *in, int64_t samples, int channels, | ||
const float *weights, bool normalize_weights = false) { | ||
SmallVector<float, 8> normalized_weights; // 8 channels should be enough for 7.1 audio | ||
int actual_channels = static_channels < 0 ? channels : static_channels; |
How about the case when `static_channels == 0`?
weights = normalized_weights.data(); // use this pointer now | ||
} | ||
for (int64_t o = 0, i = 0; o < samples; o++, i += channels) { | ||
float sum = ConvertNorm<float>(in[i]) * weights[0]; |
how about
float sum = ConvertNorm<float>(in[i]) * weights[0]; | |
float sum = 0; |
and start the loop at c = 0?
That way you get one more addition. Given the simplicity of the inner loop, it may actually matter.
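The loop structure under discussion can be sketched in plain Python as follows. This is only an illustration of the idea, not the DALI kernel: the name `downmix` is hypothetical, and normalizing the weights by dividing by their sum is an assumption about what `normalize_weights` does.

```python
def downmix(interleaved, channels, weights, normalize_weights=False):
    """Mix interleaved multi-channel audio down to one channel.

    Illustrative sketch; weight normalization by the sum is an assumption.
    """
    if normalize_weights:
        total = sum(weights)
        weights = [w / total for w in weights]
    out = []
    for i in range(0, len(interleaved), channels):
        # Seed the accumulator with the first channel's term, so the inner
        # loop performs channels - 1 additions rather than channels.
        acc = interleaved[i] * weights[0]
        for c in range(1, channels):
            acc += interleaved[i + c] * weights[c]
        out.append(acc)
    return out


# Stereo input (L, R, L, R) mixed with equal, normalized weights.
mono = downmix([1.0, 3.0, 2.0, 4.0], channels=2, weights=[1.0, 1.0],
               normalize_weights=True)
```

Starting the accumulator at zero and looping from `c = 0` gives the same result with one extra addition per output sample, which is the trade-off mentioned in the reply above.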
dali/kernels/signal/downmixing.h
Outdated
} | ||
|
||
/** | ||
* Downmix data to a single channel. |
* Downmix data to a single channel. | |
* @brief Downmix data to a single channel. |
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_H_ |
Why not `resample.h`?
Header / library names are usually nouns. Moreover, what's inside is the `Resampler` class, not a free function - though `resampler.h` also doesn't sound good.
audio.data[ofs] = ConvertSatNorm<OutputType>(tmp_buf[ofs]); | ||
} | ||
} | ||
} else { |
suggestion: I feel there are too many if/else levels here. We could at least get rid of the outer-most one by having this "decode-only" case at the top with an early return
I'd find it a bit confusing. Early returns are great for checks like "do we have to do anything?" or checking for error conditions - this is neither.
I thought about decoding and then postprocessing - but it's impossible, because we're potentially decoding to different buffers.
double q = quality_; | ||
DALI_ENFORCE(q >= 0 && q <= 100, "Resampling quality must be in [0..100] range"); | ||
// this should give 3 lobes for q = 0, 16 lobes for q = 50 and 64 lobes for q = 100 | ||
int lobes = std::round(0.007 * q * q - 0.09 * q + 3); |
maybe a wikipedia reference here to know where the formula comes from?
It comes from nowhere - it's a solution of a linear system:
a * 0 + b * 0 + c = 3
a * 50^2 + b * 50 + c = 16
a * 100^2 + b * 100 + c = 64
It does exactly what's written - it gives you 3 at 0, 16 at 50 and 64 at 100.
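Solving the system by hand: the first equation gives c = 3; subtracting twice the second from the third leaves 5000a = 35, so a = 0.007, and then b = (13 - 2500a) / 50 = -0.09. A short Python check of the resulting mapping follows; `lobes_for_quality` is a hypothetical name used for illustration, and rounding is written as `floor(x + 0.5)` to mimic `std::round` for the positive values involved, since Python's built-in `round` rounds halves to even.

```python
import math


def lobes_for_quality(q):
    """Map resampling quality in [0, 100] to a number of sinc lobes.

    Illustrative re-statement of the quadratic from the snippet under review.
    """
    if not 0 <= q <= 100:
        raise ValueError("Resampling quality must be in [0..100] range")
    value = 0.007 * q * q - 0.09 * q + 3
    # Round half up, matching std::round for positive values.
    return math.floor(value + 0.5)


anchors = [lobes_for_quality(q) for q in (0, 50, 100)]  # [3, 16, 64]
```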
So that `quality` concept is invented by us?
@@ -53,13 +59,12 @@ class AudioDecoderBase { | |||
template<typename SampleType> | |||
class TypedAudioDecoderBase : public AudioDecoderBase { | |||
public: | |||
void Decode(span<char> raw_output) override { | |||
ptrdiff_t Decode(span<char> raw_output) override { |
why not just int64_t?
That's the return type of `span::size()`.
@@ -167,6 +167,25 @@ constexpr double rad2deg(double rad) { | |||
return rad * r2d; | |||
} | |||
|
|||
/// @brief Calculates normalized sinc i.e. `sin(pi * x) / (pi * x)` | |||
DALI_HOST_DEV DALI_FORCEINLINE |
Unnecessarily duplicated code. You can templatize this and replace `1.0` and `1.0f` by `T(1)`.
The threshold for transition to approximation is different - and putting `std::is_same<T, float>::value ? T(1e-5) : T(1e-8)` doesn't sound like a simplification to me.
Going further, it would require having some template magic for integer arguments or long double.
In principle, I could, but they would be harder to read. When I type "Span_DefaultWeights" , it's pretty clear that it's about some "Span" and "DefaultWeights". "SpanDefaultWeights" sounds like something spans default weights (which makes no sense in this context) or "Span, Default, Weights", whatever is "Default", etc. SpanAndDefaultWeights is also so-so at best. |
Signed-off-by: Michał Szołucha <mszolucha@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
* Add multi-channel resampling. Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
* Add output conversion to resampling. * Essentially rewrite the operator: * Add per-sample sampling rate * Add explicit downmixing option * Add multi-channel resampling * Remove dangerous global buffers (BUG!) Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Fix window function generator. Fix multi-channel resampling. Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Fix bugs. Lint. Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
* Add Python test against librosa resampling * Change sample center to sample zero Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
!build |
CI MESSAGE: [1040438]: BUILD STARTED |
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
b489ef9
to
979f5fd
Compare
!build |
CI MESSAGE: [1040494]: BUILD STARTED |
CI MESSAGE: [1040494]: BUILD FAILED |
CI MESSAGE: [1040494]: BUILD PASSED |
Why we need this PR?
What happened in this PR?
if
s there...THIS PR SUPERSEDES #1574
JIRA TASK: * don't remember