diff --git a/dali/kernels/imgproc/resample/resampling_windows.h b/dali/kernels/imgproc/resample/resampling_windows.h index 4a1a662cff..51139e4e3a 100644 --- a/dali/kernels/imgproc/resample/resampling_windows.h +++ b/dali/kernels/imgproc/resample/resampling_windows.h @@ -18,6 +18,7 @@ #include #include #include "dali/kernels/kernel.h" +#include "dali/core/math_util.h" namespace dali { namespace kernels { @@ -40,10 +41,6 @@ inline __host__ __device__ float RectangularWindow(float x) { return -0.5f <= x && x < 0.5f ? 1 : 0; } -inline __host__ __device__ float sinc(float x) { - return x ? sinf(x * M_PI) / (x * M_PI) : 1; -} - inline __host__ __device__ float LanczosWindow(float x, float a) { if (fabsf(x) >= a) return 0.0f; diff --git a/dali/kernels/signal/downmixing.h b/dali/kernels/signal/downmixing.h new file mode 100644 index 0000000000..2e86e2c035 --- /dev/null +++ b/dali/kernels/signal/downmixing.h @@ -0,0 +1,135 @@ +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_KERNELS_SIGNAL_DOWNMIXING_H_ +#define DALI_KERNELS_SIGNAL_DOWNMIXING_H_ + +#include +#include +#include "dali/core/convert.h" +#include "dali/core/small_vector.h" +#include "dali/core/span.h" +#include "dali/core/static_switch.h" + +namespace dali { +namespace kernels { +namespace signal { + +/** + * @brief Downmix interleaved signals to a single channel. 
+ *
+ * @param out output buffer (single channel)
+ * @param in input buffer (interleaved multiple channels)
+ * @param samples number of samples in each channel
+ * @param channels number of input channels
+ * @param weights weights used for downmixing
+ * @param normalize_weights if true, the weights are normalized so their sum is 1
+ * @tparam Out output sample type - if integral, the intermediate floating point representation
+ *             is stretched so that 0..1 or -1..1 range occupies the whole Out range.
+ * @tparam In input sample type - if integral, it's normalized to 0..1 or -1..1 range
+ * @tparam static_channels compile-time number of channels; negative means "use `channels`"
+ *
+ * Downmix interleaved signals to a single channel, using the weights provided.
+ * If `normalize_weights` is true, the weights are copied into intermediate buffer
+ * and divided by their sum.
+ *
+ * @remarks The operation can be done in place if output and input are of the same type.
+ */
+template <int static_channels = -1, typename Out, typename In>
+void DownmixChannels(
+    Out *out, const In *in, int64_t samples, int channels,
+    const float *weights, bool normalize_weights = false) {
+  SmallVector<float, 8> normalized_weights;  // 8 channels should be enough for 7.1 audio
+  static_assert(static_channels != 0, "Number of channels cannot be zero. "
+                                      "Use negative values to use run-time value");
+  const int actual_channels = static_channels < 0 ? channels : static_channels;
+  assert(actual_channels == channels);
+  assert(actual_channels > 0);
+  if (normalize_weights) {
+    double sum = 0;
+    for (int i = 0; i < actual_channels; i++)
+      sum += weights[i];
+    normalized_weights.resize(actual_channels);
+    for (int i = 0; i < actual_channels; i++)
+      normalized_weights[i] = weights[i] / sum;
+    weights = normalized_weights.data();  // use this pointer now
+  }
+  for (int64_t o = 0, i = 0; o < samples; o++, i += actual_channels) {
+    float sum = ConvertNorm<float>(in[i]) * weights[0];
+    for (int c = 1; c < actual_channels; c++)
+      sum += ConvertNorm<float>(in[i + c]) * weights[c];
+    out[o] = ConvertSatNorm<Out>(sum);
+  }
+}
+
+/**
+ * @brief Downmix data to a single channel.
+ *
+ * @param out output buffer (single channel)
+ * @param in input buffer (interleaved multiple channels)
+ * @param samples number of samples in each channel
+ * @param channels number of input channels
+ * @param weights weights used for downmixing
+ * @param normalize_weights if true, the weights are normalized so their sum is 1
+ * @tparam Out output sample type - if integral, the intermediate floating point representation
+ *             is stretched so that 0..1 or -1..1 range occupies the whole Out range.
+ * @tparam In input sample type - if integral, it's normalized to 0..1 or -1..1 range
+ *
+ * Downmix interleaved signals to a single channel, using the weights provided.
+ * If `normalize_weights` is true, the weights are copied into intermediate buffer
+ * and divided by their sum.
+ *
+ * @remarks The operation can be done in place if output and input are of the same type.
+ */
+template <typename Out, typename In>
+void Downmix(
+    Out *out, const In *in, int64_t samples, int channels,
+    const float *weights, bool normalize_weights = false) {
+  VALUE_SWITCH(channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8),
+    (DownmixChannels<static_channels>(out, in, samples, static_channels,
+                                      weights, normalize_weights);),
+    (DownmixChannels(out, in, samples, channels, weights, normalize_weights);)
+  );  // NOLINT
+}
+
+/// @brief Downmix interleaved data to a single channel, weighting all channels equally.
+template <typename Out, typename In>
+void Downmix(Out *out, const In *in, int64_t num_samples, int num_channels) {
+  SmallVector<float, 8> weights;
+  weights.resize(num_channels, 1.0f / num_channels);
+  Downmix(out, in, num_samples, num_channels, weights.data());
+}
+
+/// @brief Span overload; `weights.size()` is the channel count, `in.size()` a multiple of it.
+template <typename Out, typename In>
+void Downmix(span<Out> out, span<const In> in,
+             const std::vector<float> &weights, bool normalize_weights = false) {
+  int num_channels = weights.size();
+  assert(num_channels > 0 && in.size() % num_channels == 0);
+  Downmix(out.data(), in.data(), in.size() / num_channels,
+          num_channels, weights.data(), normalize_weights);
+}
+
+/// @brief Span overload with equal weights; `in.size()` must be a multiple of `num_channels`.
+template <typename Out, typename In>
+void Downmix(span<Out> out, span<const In> in, int num_channels) {
+  assert(num_channels > 0 && in.size() % num_channels == 0);
+  Downmix(out.data(), in.data(), in.size() / num_channels, num_channels);
+}
+
+}  // namespace signal
+}  // namespace kernels
+}  // namespace dali
+
+#endif  // DALI_KERNELS_SIGNAL_DOWNMIXING_H_
diff --git a/dali/kernels/signal/downmixing_test.cc b/dali/kernels/signal/downmixing_test.cc
new file mode 100644
index 0000000000..02cd9debf9
--- /dev/null
+++ b/dali/kernels/signal/downmixing_test.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <numeric>
+#include <vector>
+#include "dali/kernels/signal/downmixing.h"
+
+namespace dali {
+namespace kernels {
+namespace signal {
+
+// Weighted downmix through the raw-pointer overload, with weight normalization requested.
+TEST(SignalDownmixingTest, RawPointer_Weighted) {
+  std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  int nchannels = 3;
+  std::vector<float> weights = {3, 2, 1};
+  // 0.0f keeps the accumulator a float; an int literal would truncate fractional weights
+  float sum = std::accumulate(weights.begin(), weights.end(), 0.0f);
+  std::vector<float> ref = {
+    (1 * 3 + 2 * 2 + 3) / sum,
+    (4 * 3 + 5 * 2 + 6) / sum,
+    (7 * 3 + 8 * 2 + 9) / sum,
+    (10 * 3 + 11 * 2 + 12) / sum
+  };
+  std::vector<float> out(ref.size());
+
+  Downmix(out.data(), in.data(), in.size() / nchannels, nchannels, weights.data(), true);
+
+  for (size_t i = 0; i < ref.size(); i++)
+    EXPECT_FLOAT_EQ(out[i], ref[i]);
+}
+
+// Equal-weight downmix through the span overload, writing the result over the input buffer.
+TEST(SignalDownmixingTest, Span_DefaultWeights) {
+  std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  int nchannels = 3;
+  std::vector<float> ref = {2, 5, 8, 11};
+
+  Downmix(make_span(in), make_cspan(in), nchannels);
+
+  for (size_t i = 0; i < ref.size(); i++)
+    EXPECT_FLOAT_EQ(in[i], ref[i]);
+}
+
+}  // namespace signal
+}  // namespace kernels
+}  // namespace dali
diff --git a/dali/kernels/signal/resampling.h b/dali/kernels/signal/resampling.h
new file mode 100644
index 0000000000..0d5d1f48e2
--- /dev/null
+++ b/dali/kernels/signal/resampling.h
@@ -0,0 +1,226 @@
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DALI_KERNELS_SIGNAL_RESAMPLING_H_
+#define DALI_KERNELS_SIGNAL_RESAMPLING_H_
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <functional>
+#include <tuple>
+#include <utility>
+#include <vector>
+#include "dali/core/math_util.h"
+#include "dali/core/small_vector.h"
+#include "dali/core/convert.h"
+#include "dali/core/static_switch.h"
+
+namespace dali {
+namespace kernels {
+namespace signal {
+
+namespace resampling {
+
+/// @brief Hann window: 1 at x == 0, 0 at x == +/-1
+inline double Hann(double x) {
+  return 0.5 * (1 + std::cos(x * M_PI));
+}
+
+/// @brief Filter kernel tabulated in `lookup` and evaluated with linear interpolation
+struct ResamplingWindow {
+  /// @brief First and last integer position within `lobes` of x
+  inline std::pair<int, int> input_range(float x) const {
+    int i0 = std::ceil(x) - lobes;
+    int i1 = std::floor(x) + lobes;
+    return {i0, i1};
+  }
+
+  /// @brief Kernel value at x, interpolated linearly between adjacent lookup entries
+  inline float operator()(float x) const {
+    float fi = x * scale + center;
+    int i = std::floor(fi);
+    float di = fi - i;
+    assert(i >= 0 && i + 1 < static_cast<int>(lookup.size()));  // lookup[i + 1] is read
+    return lookup[i] + di * (lookup[i + 1] - lookup[i]);
+  }
+
+  float scale = 1, center = 1;
+  int lobes = 0, coeffs = 0;
+  std::vector<float> lookup;
+};
+
+/// @brief Tabulates sinc(x) * envelope for x spanning about [-lobes, lobes], zero-padded
+inline void windowed_sinc(ResamplingWindow &window,
+    int coeffs, int lobes, std::function<double(double)> envelope = Hann) {
+  assert(coeffs > 1 && lobes > 0 && "Degenerate parameters specified.");
+  float scale = 2.0f * lobes / (coeffs - 1);
+  float scale_envelope = 2.0f / coeffs;
+  window.coeffs = coeffs;
+  window.lobes = lobes;
+  window.lookup.assign(coeffs + 2, 0.0f);  // add zeros; also discards any previous table
+  int center = (coeffs - 1) * 0.5f;
+  for (int i = 0; i < coeffs; i++) {
+    float x = (i - center) * scale;
+    float y = (i - center) * scale_envelope;
+    float w = sinc(x) * envelope(y);
+    window.lookup[i + 1] = w;
+  }
+  window.center = center + 1;  // allow for leading zero
+  window.scale = 1 / scale;
+}
+
+/// @brief Number of output samples produced from in_length input samples (rounded up)
+inline int64_t resampled_length(int64_t in_length, double in_rate, double out_rate) {
+  return std::ceil(in_length * out_rate / in_rate);
+}
+
+struct Resampler {
+  ResamplingWindow window;
+
+  /// @brief Builds the windowed-sinc lookup table; call before any Resample
+  void Initialize(int lobes = 16, int lookup_size = 2048) {
+    windowed_sinc(window, lookup_size, lobes);
+  }
+
+  /**
+   * @brief Resample single-channel signal and convert to Out
+   *
+   * Computes output samples [out_begin, out_end) and stores them at out[out_begin...].
+   * Filter taps that fall outside [0, n_in) are skipped, i.e. the input is zero-extended.
+   * To reuse memory and still simulate chunk processing, adjust the in/out pointers.
+   */
+  template <typename Out>
+  void Resample(
+      Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate,
+      const float *__restrict__ in, int64_t n_in, double in_rate) const {
+    assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive");
+    int64_t block = 1 << 10;  // still leaves 13 significant bits for fractional part
+    double scale = in_rate / out_rate;
+    float fscale = scale;
+    for (int64_t out_block = out_begin; out_block < out_end; out_block += block) {
+      int64_t block_end = std::min(out_block + block, out_end);
+      double in_block_f = out_block * scale;
+      int64_t in_block_i = std::floor(in_block_f);
+      float in_pos = in_block_f - in_block_i;
+      const float *__restrict__ in_block_ptr = in + in_block_i;
+      for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) {
+        int i0, i1;
+        std::tie(i0, i1) = window.input_range(in_pos);
+        if (i0 + in_block_i < 0)
+          i0 = -in_block_i;
+        if (i1 + in_block_i >= n_in)
+          i1 = n_in - 1 - in_block_i;
+        float f = 0;
+        float x = i0 - in_pos;
+        for (int i = i0; i <= i1; i++, x++) {
+          assert(in_block_ptr + i >= in && in_block_ptr + i < in + n_in);
+          float w = window(x);
+          f += in_block_ptr[i] * w;
+        }
+        assert(out_pos >= out_begin && out_pos < out_end);
+        out[out_pos] = ConvertSatNorm<Out>(f);
+      }
+    }
+  }
+
+  /**
+   * @brief Resample interleaved multi-channel signal and convert to Out
+   *
+   * Calculates a range of resampled signal. Lengths and positions are per-channel samples.
+   * The function can seamlessly resample the input and produce the result in chunks.
+   * To reuse memory and still simulate chunk processing, adjust the in/out pointers.
+   *
+   * @tparam static_channels number of channels, if known at compile time, or -1
+   */
+  template <int static_channels, typename Out>
+  void Resample(
+      Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate,
+      const float *__restrict__ in, int64_t n_in, double in_rate,
+      int dynamic_num_channels) const {
+    static_assert(static_channels != 0, "Static number of channels must be positive (use static) "
+                  "or negative (use dynamic).");
+    assert(out_rate > 0 && in_rate > 0 && "Sampling rate must be positive");
+    if (dynamic_num_channels == 1)  // fast path
+      return Resample<Out>(out, out_begin, out_end, out_rate, in, n_in, in_rate);
+    // num_channels is a compile-time constant if static_channels > 0, a run-time value otherwise
+    const int num_channels = static_channels < 0 ? dynamic_num_channels : static_channels;
+    assert(num_channels > 0);
+
+    int64_t block = 1 << 10;  // still leaves 13 significant bits for fractional part
+    double scale = in_rate / out_rate;
+    float fscale = scale;
+    SmallVector<float, 8> tmp;
+    tmp.resize(num_channels);
+    for (int64_t out_block = out_begin; out_block < out_end; out_block += block) {
+      int64_t block_end = std::min(out_block + block, out_end);
+      double in_block_f = out_block * scale;
+      int64_t in_block_i = std::floor(in_block_f);
+      float in_pos = in_block_f - in_block_i;
+      const float *__restrict__ in_block_ptr = in + in_block_i * num_channels;
+      for (int64_t out_pos = out_block; out_pos < block_end; out_pos++, in_pos += fscale) {
+        int i0, i1;
+        std::tie(i0, i1) = window.input_range(in_pos);
+        if (i0 + in_block_i < 0)
+          i0 = -in_block_i;
+        if (i1 + in_block_i >= n_in)
+          i1 = n_in - 1 - in_block_i;
+
+        for (int c = 0; c < num_channels; c++)
+          tmp[c] = 0;
+
+        float x = i0 - in_pos;
+        int ofs0 = i0 * num_channels;
+        int ofs1 = i1 * num_channels;
+        for (int in_ofs = ofs0; in_ofs <= ofs1; in_ofs += num_channels, x++) {
+          float w = window(x);
+          for (int c = 0; c < num_channels; c++) {
+            assert(in_block_ptr + in_ofs + c >= in &&
+                   in_block_ptr + in_ofs + c < in + n_in * num_channels);
+            tmp[c] += in_block_ptr[in_ofs + c] * w;
+          }
+        }
+        assert(out_pos >= out_begin && out_pos < out_end);
+        for (int c = 0; c < num_channels; c++)
+          out[out_pos * num_channels + c] = ConvertSatNorm<Out>(tmp[c]);
+      }
+    }
+  }
+
+  /**
+   * @brief Resample multi-channel signal and convert to Out
+   *
+   * Dispatches to an implementation specialized for 1-8 channels, or to the generic one.
+   * The function can seamlessly resample the input and produce the result in chunks.
+   * To reuse memory and still simulate chunk processing, adjust the in/out pointers.
+   */
+  template <typename Out>
+  void Resample(
+      Out *__restrict__ out, int64_t out_begin, int64_t out_end, double out_rate,
+      const float *__restrict__ in, int64_t n_in, double in_rate,
+      int num_channels) const {
+    VALUE_SWITCH(num_channels, static_channels, (1, 2, 3, 4, 5, 6, 7, 8),
+      (Resample<static_channels, Out>(out, out_begin, out_end, out_rate,
+        in, n_in, in_rate, static_channels);),
+      (Resample<-1, Out>(out, out_begin, out_end, out_rate,
+        in, n_in, in_rate, num_channels);));
+  }
+};
+
+}  // namespace resampling
+}  // namespace signal
+}  // namespace kernels
+}  // namespace dali
+
+#endif  // DALI_KERNELS_SIGNAL_RESAMPLING_H_
diff --git a/dali/kernels/signal/resampling_test.cc b/dali/kernels/signal/resampling_test.cc
new file mode 100644
index 0000000000..9dd2233867
--- /dev/null
+++ b/dali/kernels/signal/resampling_test.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "dali/kernels/signal/resampling.h" + +namespace dali { +namespace kernels { +namespace signal { +namespace resampling { + +namespace { + +double HannWindow(int i, int n) { + assert(n > 0); + return Hann(2.0*i / n - 1); +} + +template +void TestWave(T *out, int n, int stride, float freq) { + for (int i = 0; i < n; i++) { + float x = i * freq; + float f = std::sin(i* freq) * HannWindow(i, n); + out[i*stride] = ConvertSatNorm(f); + } +} + +} // namespace + +TEST(ResampleSinc, SingleChannel) { + int n_in = 22050, n_out = 16000; // typical downsampling + std::vector in(n_in); + std::vector out(n_out); + std::vector ref(out.size()); + float f_in = 0.1f; + float f_out = f_in * n_in / n_out; + double in_rate = n_in; + double out_rate = n_out; + TestWave(in.data(), n_in, 1, f_in); + TestWave(ref.data(), n_out, 1, f_out); + Resampler R; + R.Initialize(16); + R.Resample(out.data(), 0, n_out, out_rate, in.data(), n_in, in_rate); + + double err = 0, max_diff = 0; + for (int i = 0; i < n_out; i++) { + ASSERT_NEAR(out[i], ref[i], 1e-3) << "Sample error too big @" << i << std::endl; + float diff = std::abs(out[i] - ref[i]); + if (diff > max_diff) + max_diff = diff; + err += diff*diff; + } + err = std::sqrt(err/n_out); + EXPECT_LE(err, 1e-3) << "Average error too big"; + std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" + "\n max difference vs fresh signal: " << max_diff << + "\n RMS error: " << err << std::endl; +} + +TEST(ResampleSinc, MultiChannel) { + int n_in = 22050, n_out = 22053; // some weird upsampling + int ch = 5; + std::vector in(n_in * ch); + std::vector out(n_out * ch); + std::vector ref(out.size()); + double in_rate = n_in; + double out_rate = n_out; + for (int c = 0; c < ch; c++) { + float f_in = 0.1f * (1 + c * 0.012345); // different signal in each channel + float f_out = f_in * n_in / 
n_out; + TestWave(in.data() + c, n_in, ch, f_in); + TestWave(ref.data() + c, n_out, ch, f_out); + } + Resampler R; + R.Initialize(16); + R.Resample(out.data(), 0, n_out, out_rate, in.data(), n_in, in_rate, ch); + + double err = 0, max_diff = 0; + for (int i = 0; i < n_out * ch; i++) { + ASSERT_NEAR(out[i], ref[i], 2e-3) << "Sample error too big @" << i << std::endl; + float diff = std::abs(out[i] - ref[i]); + if (diff > max_diff) + max_diff = diff; + err += diff*diff; + } + err = std::sqrt(err/(n_out * ch)); + EXPECT_LE(err, 1e-3) << "Average error too big"; + std::cerr << "Resampling with Hann-windowed sinc filter and 16 zero crossings" + "\n max difference vs fresh signal: " << max_diff << + "\n RMS error: " << err << std::endl; +} + +} // namespace resampling +} // namespace signal +} // namespace kernels +} // namespace dali diff --git a/dali/operators/decoder/audio/audio_decoder.h b/dali/operators/decoder/audio/audio_decoder.h index 31b6ba5223..62231d9dd3 100644 --- a/dali/operators/decoder/audio/audio_decoder.h +++ b/dali/operators/decoder/audio/audio_decoder.h @@ -21,8 +21,10 @@ namespace dali { struct AudioMetadata { - int length; - int sample_rate; /// [Hz] + /// @brief Length, in (multi-channel) samples, of the recording + int64_t length; + /// @brief Sampling rate, in Hz + int sample_rate; int channels; bool channels_interleaved; }; @@ -40,7 +42,11 @@ class AudioDecoderBase { } - virtual void Decode(span raw_output) = 0; + /** + * @brief Decode audio data and store it in the supplied buffer + * @return Number of (multi-channel) samples actually read + */ + virtual ptrdiff_t Decode(span raw_output) = 0; virtual ~AudioDecoderBase() = default; @@ -53,13 +59,12 @@ class AudioDecoderBase { template class TypedAudioDecoderBase : public AudioDecoderBase { public: - void Decode(span raw_output) override { + ptrdiff_t Decode(span raw_output) override { int max_samples = static_cast(raw_output.size() / sizeof(SampleType)); - 
DecodeTyped({reinterpret_cast(raw_output.data()), max_samples}); + return DecodeTyped({reinterpret_cast(raw_output.data()), max_samples}); } - - virtual void DecodeTyped(span typed_output) = 0; + virtual ptrdiff_t DecodeTyped(span typed_output) = 0; }; diff --git a/dali/operators/decoder/audio/audio_decoder_op.cc b/dali/operators/decoder/audio/audio_decoder_op.cc index 83d73a349e..6fa16f85b7 100644 --- a/dali/operators/decoder/audio/audio_decoder_op.cc +++ b/dali/operators/decoder/audio/audio_decoder_op.cc @@ -14,11 +14,12 @@ #include "dali/operators/decoder/audio/audio_decoder_op.h" #include "dali/pipeline/operator/op_schema.h" +#include "dali/pipeline/data/views.h" namespace dali { DALI_SCHEMA(AudioDecoder) - .DocStr(R"code(Decode audio data. + .DocStr(R"code(Decode audio data. This operator is a generic way of handling encoded data in DALI. It supports most of well-known audio formats (wav, flac, ogg). @@ -26,43 +27,46 @@ This operator produces two outputs: * output[0]: batch of decoded data * output[1]: batch of sampling rates [Hz] - -Sample rate (output[1]) at index `i` corresponds to sample (output[0]) at index `i`. -On the event more metadata will appear, we reserve a right to change this behaviour.)code") - .NumInput(1) - .NumOutput(detail::kNumOutputs) - .AddOptionalArg( - detail::kOutputTypeName, - "Type of the output data. Supports types: `INT16`, `INT32`, `FLOAT`", - DALI_INT16); +)code") + .NumInput(1) + .NumOutput(2) + .AddOptionalArg("sample_rate", + "If specified, the target sample rate, in Hz, to which the audio is resampled.", + 0.0f, true) + .AddOptionalArg("quality", + "Resampling quality, 0 is lowest, 100 is highest.\n\n" + "0 corresponds to 3 lobes of the sinc filter;\n" + "50 gives 16 lobes and 100 gives 64 lobes.", + 50.0f, false) + .AddOptionalArg("downmix", + "If True, downmix all input channels to mono.", false) + .AddOptionalArg("dtype", + "Type of the output data. 
Supports types: `INT16`, `INT32`, `FLOAT`", DALI_FLOAT); DALI_REGISTER_OPERATOR(AudioDecoder, AudioDecoderCpu, CPU); - bool AudioDecoderCpu::SetupImpl(std::vector &output_desc, const workspace_t &ws) { + GetPerSampleArgument(target_sample_rates_, "sample_rate", ws); auto &input = ws.template InputRef(0); const auto batch_size = input.shape().num_samples(); for (int i = 0; i < batch_size; i++) { - DALI_ENFORCE(input.shape()[i].size() == 1, "Input must be 1D encoded byte data"); + DALI_ENFORCE(input.shape()[i].size() == 1, "Raw input must be 1D encoded byte data"); } - DALI_ENFORCE(IsType(input.type()), "Input must be stored as uint8_t data."); - TypeInfo type; - TypeInfo type_i32; - type_i32.SetType(DALI_INT32); + DALI_ENFORCE(IsType(input.type()), "Raw files must be stored as uint8 data."); decoders_.resize(batch_size); + intermediate_buffers_.resize(batch_size); sample_meta_.resize(batch_size); files_names_.resize(batch_size); - TYPE_SWITCH(output_type_, type2id, OutputType, (int16_t, int32_t, float), ( - for (int i=0; i < batch_size; i++) { - decoders_[i] = std::make_unique>(); - } - type.SetType(output_type_); + decode_type_ = use_resampling_ || downmix_ ? DALI_FLOAT : output_type_; + TYPE_SWITCH(decode_type_, type2id, OutputType, (int16_t, int32_t, float), ( + for (int i=0; i < batch_size; i++) + decoders_[i] = std::make_unique>(); ), DALI_FAIL("Unsupported output type")) // NOLINT - output_desc.resize(detail::kNumOutputs); + output_desc.resize(2); // Currently, metadata is only the sampling rate. // On the event something else would emerge, @@ -74,38 +78,112 @@ AudioDecoderCpu::SetupImpl(std::vector &output_desc, const workspace auto meta = decoders_[i]->Open({reinterpret_cast(input[i].raw_mutable_data()), input[i].shape().num_elements()}); sample_meta_[i] = meta; - shape_data.set_tensor_shape(i, {meta.length, meta.channels}); + int64_t out_length = OutputLength(meta.length, meta.sample_rate, i); + TensorShape<> data_sample_shape = { out_length, downmix_ ? 
1 : meta.channels, }; + + shape_data.set_tensor_shape(i, data_sample_shape); shape_rate.set_tensor_shape(i, {1}); files_names_[i] = input[i].GetSourceInfo(); } - output_desc[0] = {shape_data, type}; - output_desc[1] = {shape_rate, type_i32}; + + output_desc[0] = { shape_data, TypeTable::GetTypeInfo(output_type_) }; + output_desc[1] = { shape_rate, TypeTable::GetTypeInfo(DALI_FLOAT) }; return true; } -void AudioDecoderCpu::RunImpl(workspace_t &ws) { - auto &decoded_output = ws.template OutputRef(0); - auto &sample_rate_output = ws.template OutputRef(1); +template +span as_raw_span(T *buffer, ptrdiff_t length) { + return make_span(reinterpret_cast(buffer), length*sizeof(T)); +} + +template +void AudioDecoderCpu::DecodeSample(const TensorView &audio, + int thread_idx, int sample_idx) { + const AudioMetadata &meta = sample_meta_[sample_idx]; + + auto &tmp_buf = intermediate_buffers_[thread_idx]; + double output_rate = meta.sample_rate; + if (use_resampling_) { + output_rate = target_sample_rates_[sample_idx]; + DALI_ENFORCE(meta.sample_rate > 0, make_string("Unknown or invalid input sampling rate.")); + DALI_ENFORCE(output_rate > 0, make_string( + "Output sampling rate must be positive; got ", output_rate)); + } + bool should_resample = meta.sample_rate != output_rate; + bool should_downmix = meta.channels > 1 && downmix_; + if (should_resample || should_downmix || output_type_ != decode_type_) { + assert(decode_type_ == DALI_FLOAT); + int64_t tmp_size = should_downmix && should_resample + ? 
meta.length * (meta.channels + 1) // downmix to intermediate buffer, then resample + : meta.length * meta.channels; // decode to intermediate, then resample or downmix + // directly to the output + + tmp_buf.resize(tmp_size); + decoders_[sample_idx]->Decode(as_raw_span(tmp_buf.data(), meta.length * meta.channels)); + + if (should_downmix) { + if (should_resample) { + // downmix and resample + float *downmixed = tmp_buf.data() + meta.length * meta.channels; + assert(downmixed + meta.length <= tmp_buf.data() + tmp_buf.size()); + kernels::signal::Downmix(downmixed, tmp_buf.data(), meta.length, meta.channels); + resampler_.Resample(audio.data, 0, audio.shape[0], output_rate, + downmixed, meta.length, meta.sample_rate); + } else { + // downmix only + kernels::signal::Downmix(audio.data, tmp_buf.data(), meta.length, meta.channels); + } + } else if (should_resample) { + // multi-channel resample + resampler_.Resample(audio.data, 0, audio.shape[0], output_rate, + tmp_buf.data(), meta.length, meta.sample_rate, meta.channels); + + } else { + // convert or copy only - this will only happen if resampling is specified, but this + // recording's sampling rate and number of channels coincides with the target + int64_t len = std::min(volume(audio.shape), meta.length*meta.channels); + for (int64_t ofs = 0; ofs < len; ofs++) { + audio.data[ofs] = ConvertSatNorm(tmp_buf[ofs]); + } + } + } else { + assert(!should_downmix && !should_resample); + decoders_[sample_idx]->Decode(as_raw_span(audio.data, volume(audio.shape))); + } +} + +template +void AudioDecoderCpu::DecodeBatch(workspace_t &ws) { + auto decoded_output = view(ws.template OutputRef(0)); + auto sample_rate_output = view(ws.template OutputRef(1)); + int batch_size = decoded_output.shape.num_samples(); auto &tp = ws.GetThreadPool(); - auto batch_size = decoded_output.shape().num_samples(); + + intermediate_buffers_.resize(tp.size()); for (int i = 0; i < batch_size; i++) { tp.DoWorkWithID([&, i](int thread_id) { - auto &decoder 
= decoders_[i]; - auto &output = decoded_output[i]; - try { - decoder->Decode({reinterpret_cast(output.raw_mutable_data()), - static_cast(output.type().size() * output.shape().num_elements())}); - auto sample_rate_ptr = sample_rate_output[i].mutable_data(); - *sample_rate_ptr = sample_meta_[i].sample_rate; - } catch (const DALIException &e) { - DALI_FAIL(make_string("Error decoding file.\nError: ", e.what(), "\nFile: ", - files_names_[i], "\n")); - } + try { + DecodeSample(decoded_output[i], thread_id, i); + sample_rate_output[i].data[0] = use_resampling_ + ? target_sample_rates_[i] + : sample_meta_[i].sample_rate; + } catch (const DALIException &e) { + DALI_FAIL(make_string("Error decoding file.\nError: ", e.what(), "\nFile: ", + files_names_[i], "\n")); + } }); } + tp.WaitForWork(); } + +void AudioDecoderCpu::RunImpl(workspace_t &ws) { + TYPE_SWITCH(output_type_, type2id, OutputType, (int16_t, int32_t, float), ( + DecodeBatch(ws); + ), DALI_FAIL("Unsupported output type")) // NOLINT +} + } // namespace dali diff --git a/dali/operators/decoder/audio/audio_decoder_op.h b/dali/operators/decoder/audio/audio_decoder_op.h index 2f1d9435ec..fad73d1b6f 100644 --- a/dali/operators/decoder/audio/audio_decoder_op.h +++ b/dali/operators/decoder/audio/audio_decoder_op.h @@ -1,6 +1,3 @@ - - - // Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,16 +25,12 @@ #include "dali/pipeline/workspace/workspace.h" #include "dali/pipeline/operator/operator.h" #include "dali/pipeline/workspace/host_workspace.h" +#include "dali/kernels/signal/resampling.h" +#include "dali/kernels/signal/downmixing.h" +#include "dali/core/tensor_view.h" namespace dali { -namespace detail { - -const std::string kOutputTypeName = "dtype"; // NOLINT -const int kNumOutputs = 2; - -} // namespace detail - class AudioDecoderCpu : public Operator { private: using Backend = CPUBackend; @@ -45,7 +38,19 @@ class AudioDecoderCpu : public Operator { public: explicit inline AudioDecoderCpu(const OpSpec &spec) : Operator(spec), - output_type_(spec.GetArgument(detail::kOutputTypeName)) {} + output_type_(spec.GetArgument("dtype")), + downmix_(spec.GetArgument("downmix")), + use_resampling_(spec.HasArgument("sample_rate")), + quality_(spec.GetArgument("quality")) { + if (use_resampling_) { + double q = quality_; + DALI_ENFORCE(q >= 0 && q <= 100, "Resampling quality must be in [0..100] range"); + // this should give 3 lobes for q = 0, 16 lobes for q = 50 and 64 lobes for q = 100 + int lobes = std::round(0.007 * q * q - 0.09 * q + 3); + resampler_.Initialize(lobes, lobes * 64 + 1); + } + } + inline ~AudioDecoderCpu() override = default; @@ -54,15 +59,38 @@ class AudioDecoderCpu : public Operator { void RunImpl(workspace_t &ws) override; + bool CanInferOutputs() const override { return true; } private: - DALIDataType output_type_; + template + void DecodeSample( + const TensorView &audio, + int thread_idx, + int sample_idx); + + template + void DecodeBatch(workspace_t &ws); + + int64_t OutputLength(int64_t in_length, double in_rate, int sample_idx) const { + if (use_resampling_) { + return kernels::signal::resampling::resampled_length( + in_length, in_rate, target_sample_rates_[sample_idx]); + } else { + return in_length; + } + } + + std::vector target_sample_rates_; + 
kernels::signal::resampling::Resampler resampler_; + DALIDataType output_type_, decode_type_; + const bool downmix_ = false, use_resampling_ = false; + const float quality_ = 50.0f; std::vector files_names_; std::vector sample_meta_; - using sample_rate_t = decltype(AudioMetadata::sample_rate); + std::vector> intermediate_buffers_; std::vector> decoders_; }; diff --git a/dali/operators/decoder/audio/generic_decoder.cc b/dali/operators/decoder/audio/generic_decoder.cc index b6ca0c89b2..3ecc81fb9c 100644 --- a/dali/operators/decoder/audio/generic_decoder.cc +++ b/dali/operators/decoder/audio/generic_decoder.cc @@ -143,8 +143,8 @@ sf_count_t Tell(void *self) { template -void GenericAudioDecoder::DecodeTyped(span output) { - impl_->DecodeTyped(output); +ptrdiff_t GenericAudioDecoder::DecodeTyped(span output) { + return impl_->DecodeTyped(output); } @@ -172,8 +172,8 @@ GenericAudioDecoder::GenericAudioDecoder() : template struct GenericAudioDecoder::Impl { - void DecodeTyped(span output) { - ReadSamples(sound_, output); + ptrdiff_t DecodeTyped(span output) { + return ReadSamples(sound_, output); } @@ -195,7 +195,7 @@ struct GenericAudioDecoder::Impl { throw DALIException(make_string("Failed to open encoded data: ", sf_strerror(sound_))); } - ret.length = sf_info_.frames * sf_info_.channels; + ret.length = sf_info_.frames; ret.channels = sf_info_.channels; ret.sample_rate = sf_info_.samplerate; ret.channels_interleaved = true; diff --git a/dali/operators/decoder/audio/generic_decoder.h b/dali/operators/decoder/audio/generic_decoder.h index 68d1a91b70..74cc1a3d78 100644 --- a/dali/operators/decoder/audio/generic_decoder.h +++ b/dali/operators/decoder/audio/generic_decoder.h @@ -35,7 +35,7 @@ class DLL_PUBLIC GenericAudioDecoder : public TypedAudioDecoderBase public: DLL_PUBLIC GenericAudioDecoder(); - DLL_PUBLIC void DecodeTyped(span output) override; + DLL_PUBLIC ptrdiff_t DecodeTyped(span output) override; DLL_PUBLIC ~GenericAudioDecoder() override; diff --git 
a/dali/test/python/test_operator_audio_decoder.py b/dali/test/python/test_operator_audio_decoder.py new file mode 100644 index 0000000000..bb7b1ebcac --- /dev/null +++ b/dali/test/python/test_operator_audio_decoder.py @@ -0,0 +1,147 @@ +from __future__ import print_function +from __future__ import division +from nvidia.dali.pipeline import Pipeline +import nvidia.dali.ops as ops +import nvidia.dali.types as types +import scipy.io.wavfile +import numpy as np +import math +import librosa + +# generate sinewaves with given frequencies, +# add Hann envelope and store in channel-last layout +def generate_waveforms(length, frequencies): + n = int(math.ceil(length)) + X = np.arange(n, dtype=np.float32) + def window(x): + x = 2 * x / length - 1 + np.clip(x, -1, 1, out=x) + return 0.5 * (1 + np.cos(x * math.pi)) + + return np.sin(X[:,np.newaxis] * (np.array(frequencies) * (2 * math.pi))) * window(X)[:,np.newaxis] + +names = [ + "/tmp/dali_test_1C.wav", + "/tmp/dali_test_2C.wav", + "/tmp/dali_test_4C.wav" +] + +freqs = [ + np.array([0.02]), + np.array([0.01, 0.012]), + np.array([0.01, 0.012, 0.013, 0.014]) +] +rates = [ 16000, 22050, 12347 ] +lengths = [ 10000, 54321, 12345 ] + +def create_test_files(): + for i in range(len(names)): + wave = generate_waveforms(lengths[i], freqs[i]) + wave = (wave * 32767).round().astype(np.int16) + scipy.io.wavfile.write(names[i], rates[i], wave) + + +create_test_files() + +rate1 = 16000 +rate2 = 12999 + +class DecoderPipeline(Pipeline): + def __init__(self): + super(DecoderPipeline, self).__init__(batch_size=8, num_threads=3, device_id=0, + exec_async=True, exec_pipelined=True) + self.file_source = ops.ExternalSource() + self.plain_decoder = ops.AudioDecoder(dtype = types.INT16) + self.resampling_decoder = ops.AudioDecoder(sample_rate=rate1, dtype = types.INT16) + self.downmixing_decoder = ops.AudioDecoder(downmix=True, dtype = types.INT16) + self.resampling_downmixing_decoder = ops.AudioDecoder(sample_rate=rate2, downmix=True, + 
quality=50, dtype = types.FLOAT) + + def define_graph(self): + self.raw_file = self.file_source() + dec_plain, rates_plain = self.plain_decoder(self.raw_file) + dec_res, rates_res = self.resampling_decoder(self.raw_file) + dec_mix, rates_mix = self.downmixing_decoder(self.raw_file) + dec_res_mix, rates_res_mix = self.resampling_downmixing_decoder(self.raw_file) + out = [dec_plain, dec_res, dec_mix, dec_res_mix, + rates_plain, rates_res, rates_mix, rates_res_mix] + return out + + def iter_setup(self): + list = [] + for i in range(self.batch_size): + idx = i % len(names) + with open(names[idx], mode = "rb") as f: + list.append(np.array(bytearray(f.read()), np.uint8)) + self.feed_input(self.raw_file, list) + + +def rosa_resample(input, in_rate, out_rate): + if input.shape[1] == 1: + return librosa.resample(input[:,0], in_rate, out_rate)[:,np.newaxis] + + channels = [librosa.resample(np.array(input[:,c]), in_rate, out_rate) for c in range(input.shape[1])] + ret = np.zeros(shape = [channels[0].shape[0], len(channels)], dtype=channels[0].dtype) + for c, a in enumerate(channels): + ret[:,c] = a + + return ret + + + +def test_decoded_vs_generated(): + pipeline = DecoderPipeline() + pipeline.build(); + idx = 0 + for iter in range(1): + out = pipeline.run() + for i in range(len(out[0])): + plain = out[0].at(i) + res = out[1].at(i) + mix = out[2].at(i) + res_mix = out[3].at(i) + + ref_len = [0,0,0,0] + ref_len[0] = lengths[idx] + ref_len[1] = lengths[idx] * rate1 / rates[idx] + ref_len[2] = lengths[idx] + ref_len[3] = lengths[idx] * rate2 / rates[idx] + + ref0 = generate_waveforms(ref_len[0], freqs[idx]) * 32767 + ref1 = generate_waveforms(ref_len[1], freqs[idx] * (rates[idx] / rate1)) * 32767 + ref2 = generate_waveforms(ref_len[2], freqs[idx]) * 32767 + ref2 = ref2.mean(axis = 1, keepdims = 1) + ref3 = generate_waveforms(ref_len[3], freqs[idx] * (rates[idx] / rate2)) + ref3 = ref3.mean(axis = 1, keepdims = 1) + + assert(out[4].at(i)[0] == rates[idx]) + assert(out[5].at(i)[0] 
== rate1) + assert(out[6].at(i)[0] == rates[idx]) + assert(out[7].at(i)[0] == rate2) + + # just reading - allow only for rounding + assert np.allclose(plain, ref0, rtol = 0, atol=0.5) + # resampling - allow for 1e-3 dynamic range error + assert np.allclose(res, ref1, rtol = 0, atol=32767 * 1e-3) + # downmixing - allow for 2 bits of error + # - one for quantization of channels, one for quantization of result + assert np.allclose(mix, ref2, rtol = 0, atol=2) + # resampling with weird ratio - allow for 3e-3 dynamic range error + assert np.allclose(res_mix, ref3, rtol = 0, atol=3e-3) + + rosa_in1 = plain.astype(np.float32) + rosa1 = rosa_resample(rosa_in1, rates[idx], rate1) + rosa_in3 = rosa_in1 / 32767; + rosa3 = rosa_resample(rosa_in3.mean(axis = 1, keepdims = 1), rates[idx], rate2) + + assert np.allclose(res, rosa1, rtol = 0, atol=32767 * 1e-3) + assert np.allclose(res_mix, rosa3, rtol = 0, atol=3e-3) + + idx = (idx + 1) % len(names) + +def main(): + test_decoded_vs_generated() + + +if __name__ == '__main__': + main() diff --git a/include/dali/core/math_util.h b/include/dali/core/math_util.h index 348ef8a663..ed348264bd 100644 --- a/include/dali/core/math_util.h +++ b/include/dali/core/math_util.h @@ -167,6 +167,25 @@ constexpr double rad2deg(double rad) { return rad * r2d; } +/// @brief Calculates normalized sinc i.e. `sin(pi * x) / (pi * x)` +DALI_HOST_DEV DALI_FORCEINLINE +double sinc(double x) { + x *= M_PI; + if (std::abs(x) < 1e-8) + return 1.0 - x * x * (1.0 / 6); // remove singularity by using Taylor expansion + return std::sin(x) / x; +} + +/// @brief Calculates normalized sinc i.e. `sin(pi * x) / (pi * x)` +DALI_HOST_DEV DALI_FORCEINLINE +float sinc(float x) { + x *= M_PI; + if (std::abs(x) < 1e-5f) + return 1.0f - x * x * (1.0f / 6); // remove singularity by using Taylor expansion + return std::sin(x) / x; +} + + } // namespace dali #endif // DALI_CORE_MATH_UTIL_H_