diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 292caae3d2..fae1d5a572 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -83,7 +83,6 @@ rtc_library("rnn_vad_lp_residual") { rtc_library("rnn_vad_pitch") { sources = [ - "pitch_info.h", "pitch_search.cc", "pitch_search.h", "pitch_search_internal.cc", @@ -94,6 +93,7 @@ rtc_library("rnn_vad_pitch") { ":rnn_vad_common", "../../../../api:array_view", "../../../../rtc_base:checks", + "../../../../rtc_base:gtest_prod", "../../../../rtc_base:safe_compare", "../../../../rtc_base:safe_conversions", ] diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc index f6a4f42fd6..431c01fab3 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc @@ -20,7 +20,7 @@ namespace { constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. static_assert(1 << kAutoCorrelationFftOrder > - kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, + kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz, ""); } // namespace @@ -45,7 +45,7 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default; // pitch period. void AutoCorrelationCalculator::ComputeOnPitchBuffer( rtc::ArrayView pitch_buf, - rtc::ArrayView auto_corr) { + rtc::ArrayView auto_corr) { RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder; @@ -53,7 +53,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( static_assert(kConvolutionLength == kFrameSize20ms12kHz, "Mismatch between pitch buffer size, frame size and maximum " "pitch period."); - static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength, + static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength, "The FFT length is not sufficiently big to avoid cyclic " "convolution errors."); auto tmp = tmp_->GetView(); @@ -67,13 +67,12 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( // Compute the FFT for the sliding frames chunk. The sliding frames are // defined as pitch_buf[i:i+kConvolutionLength] where i in - // [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is - // defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength]. + // [0, kNumLags12kHz). The chunk includes all of them, hence it is + // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength]. std::copy(pitch_buf.begin(), - pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz, + pitch_buf.begin() + kConvolutionLength + kNumLags12kHz, tmp.begin()); - std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(), - 0.f); + std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f); fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); // Convolve in the frequency domain. @@ -84,7 +83,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( // Extract the auto-correlation coefficients. std::copy(tmp.begin() + kConvolutionLength - 1, - tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1, + tmp.begin() + kConvolutionLength + kNumLags12kHz - 1, auto_corr.begin()); } diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h index de7f453bc7..d58558ca2e 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h @@ -34,7 +34,7 @@ class AutoCorrelationCalculator { // |auto_corr| indexes are inverted lags. void ComputeOnPitchBuffer( rtc::ArrayView pitch_buf, - rtc::ArrayView auto_corr); + rtc::ArrayView auto_corr); private: Pffft fft_; diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index d6deff1556..36b366ad1d 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -36,7 +36,13 @@ constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz; static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, ""); static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, ""); static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); -constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; +// Number of (inverted) lags during the initial pitch search phase at 24 kHz. +constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; +// Number of (inverted) lags during the pitch search refinement phase at 24 kHz. +constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1; +static_assert( + kRefineNumLags24kHz > kInitialNumLags24kHz, + "The refinement step must search the pitch in an extended pitch range."); // 12 kHz analysis. constexpr int kSampleRate12kHz = 12000; @@ -47,8 +53,8 @@ constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2; static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); // The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, -// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|]. -constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; +// |kMaxPitch12kHz|] are in the range [0, |kNumLags12kHz|]. +constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; // 48 kHz constants. constexpr int kMinPitch48kHz = kMinPitch24kHz * 2; diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc index c207baeec0..cdbbbc311d 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc @@ -67,13 +67,12 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures( ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_); // Estimate pitch on the LP-residual and write the normalized pitch period // into the output vector (normalization based on training data stats). - pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_); - feature_vector[kFeatureVectorSize - 2] = - 0.01f * (pitch_info_48kHz_.period - 300); + pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_); + feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300); // Extract lagged frames (according to the estimated pitch period). - RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz); + RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz); auto lagged_frame = pitch_buf_24kHz_view_.subview( - kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz); + kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz); // Analyze reference and lagged frames checking if silence has been detected // and write the feature vector. return spectral_features_extractor_.CheckSilenceComputeFeatures( diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.h b/modules/audio_processing/agc2/rnn_vad/features_extraction.h index ce5cce1857..e2c77d2cf8 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.h +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.h @@ -16,7 +16,6 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/biquad_filter.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" -#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" #include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h" #include "modules/audio_processing/agc2/rnn_vad/spectral_features.h" @@ -53,7 +52,7 @@ class FeaturesExtractor { PitchEstimator pitch_estimator_; rtc::ArrayView reference_frame_view_; SpectralFeaturesExtractor spectral_features_extractor_; - PitchInfo pitch_info_48kHz_; + int pitch_period_48kHz_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_info.h b/modules/audio_processing/agc2/rnn_vad/pitch_info.h deleted file mode 100644 index c9fdd182b0..0000000000 --- a/modules/audio_processing/agc2/rnn_vad/pitch_info.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ -#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ - -namespace webrtc { -namespace rnn_vad { - -// Stores pitch period and gain information. The pitch gain measures the -// strength of the pitch (the higher, the stronger). -struct PitchInfo { - PitchInfo() : period(0), gain(0.f) {} - PitchInfo(int p, float g) : period(p), gain(g) {} - int period; - float gain; -}; - -} // namespace rnn_vad -} // namespace webrtc - -#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 85f67377e4..9d4c5a2d81 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -21,22 +21,22 @@ namespace rnn_vad { PitchEstimator::PitchEstimator() : pitch_buf_decimated_(kBufSize12kHz), pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), - auto_corr_(kNumInvertedLags12kHz), - auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) { + auto_corr_(kNumLags12kHz), + auto_corr_view_(auto_corr_.data(), kNumLags12kHz) { RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size()); - RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size()); + RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size()); } PitchEstimator::~PitchEstimator() = default; -PitchInfo PitchEstimator::Estimate( - rtc::ArrayView pitch_buf) { +int PitchEstimator::Estimate( + rtc::ArrayView pitch_buffer) { // Perform the initial pitch search at 12 kHz. - Decimate2x(pitch_buf, pitch_buf_decimated_view_); + Decimate2x(pitch_buffer, pitch_buf_decimated_view_); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, auto_corr_view_); - CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods( - auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); + CandidatePitchPeriods pitch_candidates_inverted_lags = + ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_); // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 @@ -44,12 +44,14 @@ PitchInfo PitchEstimator::Estimate( pitch_candidates_inverted_lags.best *= 2; pitch_candidates_inverted_lags.second_best *= 2; const int pitch_inv_lag_48kHz = - RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags); + ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags); // Look for stronger harmonics to find the final pitch period and its gain. RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); - last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain( - pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); - return last_pitch_48kHz_; + last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz( + pitch_buffer, + /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz, + last_pitch_48kHz_); + return last_pitch_48kHz_.period; } } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index 74133d0738..1e6b9ad706 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -17,8 +17,8 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" -#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" +#include "rtc_base/gtest_prod_util.h" namespace webrtc { namespace rnn_vad { @@ -30,17 +30,21 @@ class PitchEstimator { PitchEstimator(const PitchEstimator&) = delete; PitchEstimator& operator=(const PitchEstimator&) = delete; ~PitchEstimator(); - // Estimates the pitch period and gain. Returns the pitch estimation data for - // 48 kHz. - PitchInfo Estimate(rtc::ArrayView pitch_buf); + // Returns the estimated pitch period at 48 kHz. + int Estimate(rtc::ArrayView pitch_buffer); private: - PitchInfo last_pitch_48kHz_; + FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance); + float GetLastPitchStrengthForTesting() const { + return last_pitch_48kHz_.strength; + } + + PitchInfo last_pitch_48kHz_{}; AutoCorrelationCalculator auto_corr_calculator_; std::vector pitch_buf_decimated_; rtc::ArrayView pitch_buf_decimated_view_; std::vector auto_corr_; - rtc::ArrayView auto_corr_view_; + rtc::ArrayView auto_corr_view_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index d782a18d2f..8179dbd965 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -26,94 +26,88 @@ namespace webrtc { namespace rnn_vad { namespace { -// Converts a lag to an inverted lag (only for 24kHz). -int GetInvertedLag(int lag) { - RTC_DCHECK_LE(lag, kMaxPitch24kHz); - return kMaxPitch24kHz - lag; -} - -float ComputeAutoCorrelationCoeff(rtc::ArrayView pitch_buf, - int inv_lag, - int max_pitch_period) { - RTC_DCHECK_LT(inv_lag, pitch_buf.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - RTC_DCHECK_LE(inv_lag, max_pitch_period); +float ComputeAutoCorrelation( + int inverted_lag, + rtc::ArrayView pitch_buffer) { + RTC_DCHECK_LT(inverted_lag, kBufSize24kHz); + RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz); + static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - return std::inner_product(pitch_buf.begin() + max_pitch_period, - pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); + return std::inner_product(pitch_buffer.begin() + kMaxPitch24kHz, + pitch_buffer.end(), + pitch_buffer.begin() + inverted_lag, 0.f); } -// Given the auto-correlation coefficients for a lag and its neighbors, computes -// a pseudo-interpolation offset to be applied to the pitch period associated to -// the central auto-correlation coefficient |lag_auto_corr|. The output is a lag -// in {-1, 0, +1}. -// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it -// is relevant only if the spectral analysis works at a sample rate that is -// twice as that of the pitch buffer (not so important instead for the estimated -// pitch period feature fed into the RNN). -int GetPitchPseudoInterpolationOffset(float prev_auto_corr, - float lag_auto_corr, - float next_auto_corr) { - const float& a = prev_auto_corr; - const float& b = lag_auto_corr; - const float& c = next_auto_corr; - - int offset = 0; - if ((c - a) > 0.7f * (b - a)) { - offset = 1; // |c| is the largest auto-correlation coefficient. - } else if ((a - c) > 0.7f * (b - c)) { - offset = -1; // |a| is the largest auto-correlation coefficient. +// Given an auto-correlation coefficient `curr_auto_correlation` and its +// neighboring values `prev_auto_correlation` and `next_auto_correlation` +// computes a pseudo-interpolation offset to be applied to the pitch period +// associated to `curr`. The output is a lag in {-1, 0, +1}. +// TODO(bugs.webrtc.org/9076): Consider removing this method. +// `GetPitchPseudoInterpolationOffset()` it is relevant only if the spectral +// analysis works at a sample rate that is twice as that of the pitch buffer; +// In particular, it is not relevant for the estimated pitch period feature fed +// into the RNN. +int GetPitchPseudoInterpolationOffset(float prev_auto_correlation, + float curr_auto_correlation, + float next_auto_correlation) { + if ((next_auto_correlation - prev_auto_correlation) > + 0.7f * (curr_auto_correlation - prev_auto_correlation)) { + return 1; // |next_auto_correlation| is the largest auto-correlation + // coefficient. + } else if ((prev_auto_correlation - next_auto_correlation) > + 0.7f * (curr_auto_correlation - next_auto_correlation)) { + return -1; // |prev_auto_correlation| is the largest auto-correlation + // coefficient. } - return offset; + return 0; } // Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The // output sample rate is twice as that of |lag|. int PitchPseudoInterpolationLagPitchBuf( int lag, - rtc::ArrayView pitch_buf) { + rtc::ArrayView pitch_buffer) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. if (lag > 0 && lag < kMaxPitch24kHz) { + const int inverted_lag = kMaxPitch24kHz - lag; offset = GetPitchPseudoInterpolationOffset( - ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), - kMaxPitch24kHz), - ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), - kMaxPitch24kHz), - ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), - kMaxPitch24kHz)); + ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer), + ComputeAutoCorrelation(inverted_lag, pitch_buffer), + ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer)); } return 2 * lag + offset; } -// Refines a pitch period |inv_lag| encoded as inverted lag with +// Refines a pitch period |inverted_lag| encoded as inverted lag with // pseudo-interpolation. The output sample rate is twice as that of -// |inv_lag|. +// |inverted_lag|. int PitchPseudoInterpolationInvLagAutoCorr( - int inv_lag, - rtc::ArrayView auto_corr) { + int inverted_lag, + rtc::ArrayView auto_correlation) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (inv_lag > 0 && inv_lag < rtc::dchecked_cast(auto_corr.size()) - 1) { + if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) { offset = GetPitchPseudoInterpolationOffset( - auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]); + auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag], + auto_correlation[inverted_lag - 1]); } // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should - // be subtracted since |inv_lag| is an inverted lag but offset is a lag. - return 2 * inv_lag + offset; + // be subtracted since |inverted_lag| is an inverted lag but offset is a lag. + return 2 * inverted_lag + offset; } -// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when +// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when // looking for sub-harmonics. // The values have been chosen to serve the following algorithm. Given the // initial pitch period T, we examine whether one of its harmonics is the true // fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of -// these harmonics, in addition to the pitch gain of itself, we choose one +// these harmonics, in addition to the pitch strength of itself, we choose one // multiple of its pitch period, n*T/k, to validate it (by averaging their pitch -// gains). The multiplier n is chosen so that n*T/k is used only one time over -// all k. When for example k = 4, we should also expect a peak at 3*T/4. When -// k = 8 instead we don't want to look at 2*T/8, since we have already checked -// T/4 before. Instead, we look at T*3/8. +// strengths). The multiplier n is chosen so that n*T/k is used only one time +// over all k. When for example k = 4, we should also expect a peak at 3*T/4. +// When k = 8 instead we don't want to look at 2*T/8, since we have already +// checked T/4 before. Instead, we look at T*3/8. // The array can be generate in Python as follows: // from fractions import Fraction // # Smallest positive integer not in X. @@ -130,92 +124,168 @@ int PitchPseudoInterpolationInvLagAutoCorr( constexpr std::array kSubHarmonicMultipliers = { {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}}; -// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for -// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)]. -constexpr std::array kInitialPitchPeriodThresholds = { - {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; +struct Range { + int min; + int max; +}; -} // namespace +// Creates a pitch period interval centered in `inverted_lag` with hard-coded +// radius. Clipping is applied so that the interval is always valid for a 24 kHz +// pitch buffer. +Range CreateInvertedLagRange(int inverted_lag) { + constexpr int kRadius = 2; + return {std::max(inverted_lag - kRadius, 0), + std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)}; +} -void Decimate2x(rtc::ArrayView src, - rtc::ArrayView dst) { - // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. - static_assert(2 * dst.size() == src.size(), ""); - for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) { - dst[i] = src[2 * i]; +// Computes the auto correlation coefficients for the inverted lags in the +// closed interval `inverted_lags`. +void ComputeAutoCorrelation( + Range inverted_lags, + rtc::ArrayView pitch_buffer, + rtc::ArrayView auto_correlation) { + RTC_DCHECK_GE(inverted_lags.min, 0); + RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size()); + for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max; + ++inverted_lag) { + auto_correlation[inverted_lag] = + ComputeAutoCorrelation(inverted_lag, pitch_buffer); + } +} + +int FindBestPitchPeriods24kHz( + rtc::ArrayView auto_correlation, + rtc::ArrayView pitch_buffer) { + static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, ""); + static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); + // Initialize the sliding 20 ms frame energy. + // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. + float denominator = std::inner_product( + pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1, + pitch_buffer.begin(), 1.f); + // Search best pitch by looking at the scaled auto-correlation. + int best_inverted_lag = 0; // Pitch period. + float best_numerator = -1.f; // Pitch strength numerator. + float best_denominator = 0.f; // Pitch strength denominator. + for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz; + ++inverted_lag) { + // A pitch candidate must have positive correlation. + if (auto_correlation[inverted_lag] > 0.f) { + const float numerator = + auto_correlation[inverted_lag] * auto_correlation[inverted_lag]; + // Compare numerator/denominator ratios without using divisions. + if (numerator * best_denominator > best_numerator * denominator) { + best_inverted_lag = inverted_lag; + best_numerator = numerator; + best_denominator = denominator; + } + } + // Update |denominator| for the next inverted lag. + static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz, + ""); + const float y_old = pitch_buffer[inverted_lag]; + const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; + denominator -= y_old * y_old; + denominator += y_new * y_new; + denominator = std::max(0.f, denominator); } + return best_inverted_lag; } -float ComputePitchGainThreshold(int candidate_pitch_period, - int pitch_period_ratio, - int initial_pitch_period, - float initial_pitch_gain, - int prev_pitch_period, - float prev_pitch_gain) { - // Map arguments to more compact aliases. - const int& t1 = candidate_pitch_period; - const int& k = pitch_period_ratio; - const int& t0 = initial_pitch_period; - const float& g0 = initial_pitch_gain; - const int& t_prev = prev_pitch_period; - const float& g_prev = prev_pitch_gain; - - // Validate input. - RTC_DCHECK_GE(t1, 0); - RTC_DCHECK_GE(k, 2); - RTC_DCHECK_GE(t0, 0); - RTC_DCHECK_GE(t_prev, 0); - - // Compute a term that lowers the threshold when |t1| is close to the last - // estimated period |t_prev| - i.e., pitch tracking. - float lower_threshold_term = 0; - if (abs(t1 - t_prev) <= 1) { - // The candidate pitch period is within 1 sample from the previous one. - // Make the candidate at |t1| very easy to be accepted. - lower_threshold_term = g_prev; - } else if (abs(t1 - t_prev) == 2 && - t0 > kInitialPitchPeriodThresholds[k - 2]) { - // The candidate pitch period is 2 samples far from the previous one and the - // period |t0| (from which |t1| has been derived) is greater than a - // threshold. Make |t1| easy to be accepted. - lower_threshold_term = 0.5f * g_prev; +// Returns an alternative pitch period for `pitch_period` given a `multiplier` +// and a `divisor` of the period. +constexpr int GetAlternativePitchPeriod(int pitch_period, + int multiplier, + int divisor) { + RTC_DCHECK_GT(divisor, 0); + // Same as `round(multiplier * pitch_period / divisor)`. + return (2 * multiplier * pitch_period + divisor) / (2 * divisor); +} + +// Returns true if the alternative pitch period is stronger than the initial one +// given the last estimated pitch and the value of `period_divisor` used to +// compute the alternative pitch period via `GetAlternativePitchPeriod()`. +bool IsAlternativePitchStrongerThanInitial(PitchInfo last, + PitchInfo initial, + PitchInfo alternative, + int period_divisor) { + // Initial pitch period candidate thresholds for a sample rate of 24 kHz. + // Computed as [5*k*k for k in range(16)]. + constexpr std::array kInitialPitchPeriodThresholds = { + {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; + static_assert( + kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(), + ""); + RTC_DCHECK_GE(last.period, 0); + RTC_DCHECK_GE(initial.period, 0); + RTC_DCHECK_GE(alternative.period, 0); + RTC_DCHECK_GE(period_divisor, 2); + // Compute a term that lowers the threshold when |alternative.period| is close + // to the last estimated period |last.period| - i.e., pitch tracking. + float lower_threshold_term = 0.f; + if (std::abs(alternative.period - last.period) <= 1) { + // The candidate pitch period is within 1 sample from the last one. + // Make the candidate at |alternative.period| very easy to be accepted. + lower_threshold_term = last.strength; + } else if (std::abs(alternative.period - last.period) == 2 && + initial.period > + kInitialPitchPeriodThresholds[period_divisor - 2]) { + // The candidate pitch period is 2 samples far from the last one and the + // period |initial.period| (from which |alternative.period| has been + // derived) is greater than a threshold. Make |alternative.period| easy to + // be accepted. + lower_threshold_term = 0.5f * last.strength; } - // Set the threshold based on the gain of the initial estimate |t0|. Also - // reduce the chance of false positives caused by a bias towards high - // frequencies (originating from short-term correlations). - float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); - if (t1 < 3 * kMinPitch24kHz) { + // Set the threshold based on the strength of the initial estimate + // |initial.period|. Also reduce the chance of false positives caused by a + // bias towards high frequencies (originating from short-term correlations). + float threshold = + std::max(0.3f, 0.7f * initial.strength - lower_threshold_term); + if (alternative.period < 3 * kMinPitch24kHz) { // High frequency. - threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); - } else if (t1 < 2 * kMinPitch24kHz) { + threshold = std::max(0.4f, 0.85f * initial.strength - lower_threshold_term); + } else if (alternative.period < 2 * kMinPitch24kHz) { // Even higher frequency. - threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); + threshold = std::max(0.5f, 0.9f * initial.strength - lower_threshold_term); } - return threshold; + return alternative.strength > threshold; } -void ComputeSlidingFrameSquareEnergies( - rtc::ArrayView pitch_buf, - rtc::ArrayView yy_values) { - float yy = - ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); +} // namespace + +void Decimate2x(rtc::ArrayView src, + rtc::ArrayView dst) { + // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. + static_assert(2 * kBufSize12kHz == kBufSize24kHz, ""); + for (int i = 0; i < kBufSize12kHz; ++i) { + dst[i] = src[2 * i]; + } +} + +void ComputeSlidingFrameSquareEnergies24kHz( + rtc::ArrayView pitch_buffer, + rtc::ArrayView yy_values) { + float yy = ComputeAutoCorrelation(kMaxPitch24kHz, pitch_buffer); yy_values[0] = yy; - for (int i = 1; rtc::SafeLt(i, yy_values.size()); ++i) { - RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); - RTC_DCHECK_LE(i, kMaxPitch24kHz); - const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i]; - const float new_coeff = pitch_buf[kMaxPitch24kHz - i]; - yy -= old_coeff * old_coeff; - yy += new_coeff * new_coeff; + static_assert(kMaxPitch24kHz - (kRefineNumLags24kHz - 1) >= 0, ""); + static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, ""); + for (int lag = 1; lag < kRefineNumLags24kHz; ++lag) { + const int inverted_lag = kMaxPitch24kHz - lag; + const float y_old = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; + const float y_new = pitch_buffer[inverted_lag]; + yy -= y_old * y_old; + yy += y_new * y_new; yy = std::max(0.f, yy); - yy_values[i] = yy; + yy_values[lag] = yy; } } -CandidatePitchPeriods FindBestPitchPeriods( - rtc::ArrayView auto_corr, - rtc::ArrayView pitch_buf, - int max_pitch_period) { +CandidatePitchPeriods ComputePitchPeriod12kHz( + rtc::ArrayView pitch_buffer, + rtc::ArrayView auto_correlation) { + static_assert(kMaxPitch12kHz > kNumLags12kHz, ""); + static_assert(kMaxPitch12kHz < kBufSize12kHz, ""); + // Stores a pitch candidate period and strength information. struct PitchCandidate { // Pitch period encoded as inverted lag. @@ -231,28 +301,22 @@ CandidatePitchPeriods FindBestPitchPeriods( } }; - RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - const int frame_size = - rtc::dchecked_cast(pitch_buf.size()) - max_pitch_period; - RTC_DCHECK_GT(frame_size, 0); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - float yy = - std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1, - pitch_buf.begin(), 1.f); + float denominator = std::inner_product( + pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms12kHz + 1, + pitch_buffer.begin(), 1.f); // Search best and second best pitches by looking at the scaled // auto-correlation. - PitchCandidate candidate; PitchCandidate best; PitchCandidate second_best; second_best.period_inverted_lag = 1; - for (int inv_lag = 0; inv_lag < rtc::dchecked_cast(auto_corr.size()); - ++inv_lag) { + for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) { // A pitch candidate must have positive correlation. - if (auto_corr[inv_lag] > 0) { - candidate.period_inverted_lag = inv_lag; - candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag]; - candidate.strength_denominator = yy; + if (auto_correlation[inverted_lag] > 0.f) { + PitchCandidate candidate{ + inverted_lag, + auto_correlation[inverted_lag] * auto_correlation[inverted_lag], + denominator}; if (candidate.HasStrongerPitchThan(second_best)) { if (candidate.HasStrongerPitchThan(best)) { second_best = best; @@ -263,144 +327,144 @@ CandidatePitchPeriods FindBestPitchPeriods( } } // Update |squared_energy_y| for the next inverted lag. - const float old_coeff = pitch_buf[inv_lag]; - const float new_coeff = pitch_buf[inv_lag + frame_size]; - yy -= old_coeff * old_coeff; - yy += new_coeff * new_coeff; - yy = std::max(0.f, yy); + const float y_old = pitch_buffer[inverted_lag]; + const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms12kHz]; + denominator -= y_old * y_old; + denominator += y_new * y_new; + denominator = std::max(0.f, denominator); } return {best.period_inverted_lag, second_best.period_inverted_lag}; } -int RefinePitchPeriod48kHz( - rtc::ArrayView pitch_buf, - CandidatePitchPeriods pitch_candidates_inverted_lags) { +int ComputePitchPeriod48kHz( + rtc::ArrayView pitch_buffer, + CandidatePitchPeriods pitch_candidates) { // Compute the auto-correlation terms only for neighbors of the given pitch // candidates (similar to what is done in ComputePitchAutoCorrelation(), but // for a few lag values). - std::array auto_correlation; - auto_correlation.fill( - 0.f); // Zeros become ignored lags in FindBestPitchPeriods(). - auto is_neighbor = [](int i, int j) { - return ((i > j) ? (i - j) : (j - i)) <= 2; - }; - // TODO(https://crbug.com/webrtc/10480): Optimize by removing the loop. - for (int inverted_lag = 0; rtc::SafeLt(inverted_lag, auto_correlation.size()); - ++inverted_lag) { - if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) || - is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best)) - auto_correlation[inverted_lag] = - ComputeAutoCorrelationCoeff(pitch_buf, inverted_lag, kMaxPitch24kHz); + std::array auto_correlation{}; + const Range r1 = CreateInvertedLagRange(pitch_candidates.best); + const Range r2 = CreateInvertedLagRange(pitch_candidates.second_best); + RTC_DCHECK_LE(r1.min, r1.max); + RTC_DCHECK_LE(r2.min, r2.max); + if (r1.min <= r2.min && r1.max + 1 >= r2.min) { + // Overlapping or adjacent ranges (`r1` precedes `r2`). + RTC_DCHECK_LE(r1.max, r2.max); + ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation); + } else if (r1.min > r2.min && r2.max + 1 >= r1.min) { + // Overlapping or adjacent ranges (`r2` precedes `r1`). + RTC_DCHECK_LE(r2.max, r1.max); + ComputeAutoCorrelation({r2.min, r1.max}, pitch_buffer, auto_correlation); + } else { + // Disjoint ranges. + ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation); + ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation); } // Find best pitch at 24 kHz. - const CandidatePitchPeriods pitch_candidates_24kHz = - FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHz); + const int pitch_candidate_24kHz = + FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer); // Pseudo-interpolation. - return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best, + return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz, auto_correlation); } -PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( - rtc::ArrayView pitch_buf, +PitchInfo ComputeExtendedPitchPeriod48kHz( + rtc::ArrayView pitch_buffer, int initial_pitch_period_48kHz, - PitchInfo prev_pitch_48kHz) { + PitchInfo last_pitch_48kHz) { RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz); + // Stores information for a refined pitch candidate. struct RefinedPitchCandidate { - RefinedPitchCandidate() {} - RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy) - : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {} - int period_24kHz; - // Pitch strength information. - float gain; - // Additional pitch strength information used for the final estimation of - // pitch gain. + int period; + float strength; + // Additional strength data used for the final estimation of the strength. float xy; // Cross-correlation. float yy; // Auto-correlation. }; // Initialize. - std::array yy_values; - ComputeSlidingFrameSquareEnergies(pitch_buf, - {yy_values.data(), yy_values.size()}); + std::array yy_values; + // TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz(). + ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values); const float xx = yy_values[0]; - // Helper lambdas. - const auto pitch_gain = [](float xy, float yy, float xx) { - RTC_DCHECK_LE(0.f, xx * yy); + const auto pitch_strength = [](float xy, float yy, float xx) { + RTC_DCHECK_GE(xx * yy, 0.f); return xy / std::sqrt(1.f + xx * yy); }; - // Initial pitch candidate gain. + // Initial pitch candidate. RefinedPitchCandidate best_pitch; - best_pitch.period_24kHz = + best_pitch.period = std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); - best_pitch.xy = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); - best_pitch.yy = yy_values[best_pitch.period_24kHz]; - best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx); - - // Store the initial pitch period information. - const int initial_pitch_period = best_pitch.period_24kHz; - const float initial_pitch_gain = best_pitch.gain; - - // Given the initial pitch estimation, check lower periods (i.e., harmonics). - const auto alternative_period = [](int period, int k, int n) -> int { - RTC_DCHECK_GT(k, 0); - return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). - }; - // |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals - // kMinPitch24kHz. - const int max_k = (2 * initial_pitch_period) / (2 * kMinPitch24kHz - 1); - for (int k = 2; k <= max_k; ++k) { - int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1); - RTC_DCHECK_GE(candidate_pitch_period, kMinPitch24kHz); - // When looking at |candidate_pitch_period|, we also look at one of its + best_pitch.xy = + ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer); + best_pitch.yy = yy_values[best_pitch.period]; + best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx); + + // 24 kHz version of the last estimated pitch and copy of the initial + // estimation. + const PitchInfo last_pitch{last_pitch_48kHz.period / 2, + last_pitch_48kHz.strength}; + const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; + + // Find `max_period_divisor` such that the result of + // `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)` + // equals `kMinPitch24kHz`. + const int max_period_divisor = + (2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1); + for (int period_divisor = 2; period_divisor <= max_period_divisor; + ++period_divisor) { + PitchInfo alternative_pitch; + alternative_pitch.period = GetAlternativePitchPeriod( + initial_pitch.period, /*multiplier=*/1, period_divisor); + RTC_DCHECK_GE(alternative_pitch.period, kMinPitch24kHz); + // When looking at |alternative_pitch.period|, we also look at one of its // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. - // |k| == 2 is a special case since |candidate_pitch_secondary_period| might - // be greater than the maximum pitch period. - int candidate_pitch_secondary_period = alternative_period( - initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); - RTC_DCHECK_GT(candidate_pitch_secondary_period, 0); - if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) { - candidate_pitch_secondary_period = initial_pitch_period; + // |period_divisor| == 2 is a special case since |dual_alternative_period| + // might be greater than the maximum pitch period. + int dual_alternative_period = GetAlternativePitchPeriod( + initial_pitch.period, kSubHarmonicMultipliers[period_divisor - 2], + period_divisor); + RTC_DCHECK_GT(dual_alternative_period, 0); + if (period_divisor == 2 && dual_alternative_period > kMaxPitch24kHz) { + dual_alternative_period = initial_pitch.period; } - RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) + RTC_DCHECK_NE(alternative_pitch.period, dual_alternative_period) << "The lower pitch period and the additional sub-harmonic must not " "coincide."; // Compute an auto-correlation score for the primary pitch candidate - // |candidate_pitch_period| by also looking at its possible sub-harmonic - // |candidate_pitch_secondary_period|. - float xy_primary_period = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz); - float xy_secondary_period = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(candidate_pitch_secondary_period), - kMaxPitch24kHz); + // |alternative_pitch.period| by also looking at its possible sub-harmonic + // |dual_alternative_period|. + float xy_primary_period = ComputeAutoCorrelation( + kMaxPitch24kHz - alternative_pitch.period, pitch_buffer); + float xy_secondary_period = ComputeAutoCorrelation( + kMaxPitch24kHz - dual_alternative_period, pitch_buffer); float xy = 0.5f * (xy_primary_period + xy_secondary_period); - float yy = 0.5f * (yy_values[candidate_pitch_period] + - yy_values[candidate_pitch_secondary_period]); - float candidate_pitch_gain = pitch_gain(xy, yy, xx); + float yy = 0.5f * (yy_values[alternative_pitch.period] + + yy_values[dual_alternative_period]); + alternative_pitch.strength = pitch_strength(xy, yy, xx); // Maybe update best period. - float threshold = ComputePitchGainThreshold( - candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain, - prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain); - if (candidate_pitch_gain > threshold) { - best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy}; + if (IsAlternativePitchStrongerThanInitial( + last_pitch, initial_pitch, alternative_pitch, period_divisor)) { + best_pitch = {alternative_pitch.period, alternative_pitch.strength, xy, + yy}; } } - // Final pitch gain and period. + // Final pitch strength and period. best_pitch.xy = std::max(0.f, best_pitch.xy); RTC_DCHECK_LE(0.f, best_pitch.yy); - float final_pitch_gain = (best_pitch.yy <= best_pitch.xy) - ? 1.f - : best_pitch.xy / (best_pitch.yy + 1.f); - final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); + float final_pitch_strength = (best_pitch.yy <= best_pitch.xy) + ? 1.f + : best_pitch.xy / (best_pitch.yy + 1.f); + final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength); int final_pitch_period_48kHz = std::max( kMinPitch48kHz, - PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); + PitchPseudoInterpolationLagPitchBuf(best_pitch.period, pitch_buffer)); - return {final_pitch_period_48kHz, final_pitch_gain}; + return {final_pitch_period_48kHz, final_pitch_strength}; } } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index cab6286523..b16a2f438d 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -18,7 +18,6 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" -#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" namespace webrtc { namespace rnn_vad { @@ -27,56 +26,78 @@ namespace rnn_vad { void Decimate2x(rtc::ArrayView src, rtc::ArrayView dst); -// Computes a gain threshold for a candidate pitch period given the initial and -// the previous pitch period and gain estimates and the pitch period ratio used -// to derive the candidate pitch period from the initial period. -float ComputePitchGainThreshold(int candidate_pitch_period, - int pitch_period_ratio, - int initial_pitch_period, - float initial_pitch_gain, - int prev_pitch_period, - float prev_pitch_gain); - -// Computes the sum of squared samples for every sliding frame in the pitch -// buffer. |yy_values| indexes are lags. +// Key concepts and keywords used below in this file. +// +// The pitch estimation relies on a pitch buffer, which is an array-like data +// structured designed as follows: +// +// |....A....|.....B.....| +// +// The part on the left, named `A` contains the oldest samples, whereas `B` +// contains the most recent ones. The size of `A` corresponds to the maximum +// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms +// respectively). +// +// Pitch estimation is essentially based on the analysis of two 20 ms frames +// extracted from the pitch buffer. One frame, called `x`, is kept fixed and +// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called +// `y`, is extracted from different parts of the buffer instead. +// +// The offset between `x` and `y` corresponds to a specific pitch period. +// For instance, if `y` is positioned at the beginning of the pitch buffer, then +// the cross-correlation between `x` and `y` can be used as an indication of the +// strength for the maximum pitch. // -// The pitch buffer is structured as depicted below: -// |.........|...........| -// a b -// The part on the left, named "a" contains the oldest samples, whereas "b" the -// most recent ones. The size of "a" corresponds to the maximum pitch period, -// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively). -void ComputeSlidingFrameSquareEnergies( - rtc::ArrayView pitch_buf, - rtc::ArrayView yy_values); +// Such an offset can be encoded in two ways: +// - As a lag, which is the index in the pitch buffer for the first item in `y` +// - As an inverted lag, which is the number of samples from the beginning of +// `x` and the end of `y` +// +// |---->| lag +// |....A....|.....B.....| +// |<--| inverted lag +// |.....y.....| `y` 20 ms frame +// +// The inverted lag has the advantage of being directly proportional to the +// corresponding pitch period. + +// Computes the sum of squared samples for every sliding frame `y` in the pitch +// buffer. The indexes of `yy_values` are lags. +void ComputeSlidingFrameSquareEnergies24kHz( + rtc::ArrayView pitch_buffer, + rtc::ArrayView yy_values); -// Top-2 pitch period candidates. +// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags. struct CandidatePitchPeriods { int best; int second_best; }; -// Computes the candidate pitch periods given the auto-correlation coefficients -// stored according to ComputePitchAutoCorrelation() (i.e., using inverted -// lags). The return periods are inverted lags. -CandidatePitchPeriods FindBestPitchPeriods( - rtc::ArrayView auto_corr, - rtc::ArrayView pitch_buf, - int max_pitch_period); +// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz +// pitch buffer and the auto-correlation values (having inverted lags as +// indexes). +CandidatePitchPeriods ComputePitchPeriod12kHz( + rtc::ArrayView pitch_buffer, + rtc::ArrayView auto_correlation); -// Refines the pitch period estimation given the pitch buffer |pitch_buf| and -// the initial pitch period estimation |pitch_candidates_inverted_lags|. -// Returns an inverted lag at 48 kHz. -int RefinePitchPeriod48kHz( - rtc::ArrayView pitch_buf, - CandidatePitchPeriods pitch_candidates_inverted_lags); +// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer +// and the pitch period candidates at 24 kHz (encoded as inverted lag). +int ComputePitchPeriod48kHz( + rtc::ArrayView pitch_buffer, + CandidatePitchPeriods pitch_candidates_24kHz); + +struct PitchInfo { + int period; + float strength; +}; -// Refines the pitch period estimation and compute the pitch gain. Returns the -// refined pitch estimation data at 48 kHz. -PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( - rtc::ArrayView pitch_buf, +// Computes the pitch period at 48 kHz searching in an extended pitch range +// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation +// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch. +PitchInfo ComputeExtendedPitchPeriod48kHz( + rtc::ArrayView pitch_buffer, int initial_pitch_period_48kHz, - PitchInfo prev_pitch_48kHz); + PitchInfo last_pitch_48kHz); } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index fdbee68357..7acb046db1 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -31,138 +31,77 @@ constexpr float kTestPitchGainsHigh = 0.75f; } // namespace -class ComputePitchGainThresholdTest - : public ::testing::Test, - public ::testing::WithParamInterface> {}; - -// Checks that the computed pitch gain is within tolerance given test input -// data. -TEST_P(ComputePitchGainThresholdTest, WithinTolerance) { - const auto params = GetParam(); - const int candidate_pitch_period = std::get<0>(params); - const int pitch_period_ratio = std::get<1>(params); - const int initial_pitch_period = std::get<2>(params); - const float initial_pitch_gain = std::get<3>(params); - const int prev_pitch_period = std::get<4>(params); - const float prev_pitch_gain = std::get<5>(params); - const float threshold = std::get<6>(params); - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - EXPECT_NEAR( - threshold, - ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio, - initial_pitch_period, initial_pitch_gain, - prev_pitch_period, prev_pitch_gain), - 5e-7f); - } -} - -INSTANTIATE_TEST_SUITE_P( - RnnVadTest, - ComputePitchGainThresholdTest, - ::testing::Values( - std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f), - std::make_tuple(113, - 2, - 226, - 0.20967799f, - 219, - 0.40392199f, - 0.30000001f), - std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f), - std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f), - std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f), - std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f), - std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); - // Checks that the frame-wise sliding square energy function produces output // within tolerance given test input data. -TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) { +TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) { PitchTestData test_data; std::array computed_output; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(), - computed_output); - } + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + computed_output); auto square_energies_view = test_data.GetPitchBufSquareEnergiesView(); ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()}, computed_output, 3e-2f); } // Checks that the estimated pitch period is bit-exact given test input data. -TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { +TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); CandidatePitchPeriods pitch_candidates; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - pitch_candidates = FindBestPitchPeriods(auto_corr_view, pitch_buf_decimated, - kMaxPitch12kHz); - } + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); + pitch_candidates = + ComputePitchPeriod12kHz(pitch_buf_decimated, auto_corr_view); EXPECT_EQ(pitch_candidates.best, 140); EXPECT_EQ(pitch_candidates.second_best, 142); } // Checks that the refined pitch period is bit-exact given test input data. -TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { +TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) { PitchTestData test_data; // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), - /*pitch_candidates=*/{280, 284}), + EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{280, 284}), 560); - EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), - /*pitch_candidates=*/{260, 284}), + EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{260, 284}), 568); } -class CheckLowerPitchPeriodsAndComputePitchGainTest +class ComputeExtendedPitchPeriod48kHzTest : public ::testing::Test, - public ::testing::WithParamInterface> {}; + public ::testing::WithParamInterface< + std::tuple> { + protected: + int GetInitialPitchPeriod() const { return std::get<0>(GetParam()); } + int GetLastPitchPeriod() const { return std::get<1>(GetParam()); } + float GetLastPitchStrength() const { return std::get<2>(GetParam()); } + int GetExpectedPitchPeriod() const { return std::get<3>(GetParam()); } + float GetExpectedPitchStrength() const { return std::get<4>(GetParam()); } +}; // Checks that the computed pitch period is bit-exact and that the computed -// pitch gain is within tolerance given test input data. -TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, +// pitch strength is within tolerance given test input data. +TEST_P(ComputeExtendedPitchPeriod48kHzTest, PeriodBitExactnessGainWithinTolerance) { - const auto params = GetParam(); - const int initial_pitch_period = std::get<0>(params); - const int prev_pitch_period = std::get<1>(params); - const float prev_pitch_gain = std::get<2>(params); - const int expected_pitch_period = std::get<3>(params); - const float expected_pitch_gain = std::get<4>(params); PitchTestData test_data; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain( - test_data.GetPitchBufView(), initial_pitch_period, - {prev_pitch_period, prev_pitch_gain}); - EXPECT_EQ(expected_pitch_period, computed_output.period); - EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f); - } + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + const auto computed_output = ComputeExtendedPitchPeriod48kHz( + test_data.GetPitchBufView(), GetInitialPitchPeriod(), + {GetLastPitchPeriod(), GetLastPitchStrength()}); + EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period); + EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f); } INSTANTIATE_TEST_SUITE_P( RnnVadTest, - CheckLowerPitchPeriodsAndComputePitchGainTest, + ComputeExtendedPitchPeriod48kHzTest, ::testing::Values(std::make_tuple(kTestPitchPeriodsLow, kTestPitchPeriodsLow, kTestPitchGainsLow, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index fdecb92807..c57c8c24db 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -13,7 +13,6 @@ #include #include -#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -22,15 +21,14 @@ namespace webrtc { namespace rnn_vad { -namespace test { // Checks that the computed pitch period is bit-exact and that the computed // pitch gain is within tolerance given test input data. TEST(RnnVadTest, PitchSearchWithinTolerance) { - auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); + auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader(); const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s. std::vector lp_residual(kBufSize24kHz); - float expected_pitch_period, expected_pitch_gain; + float expected_pitch_period, expected_pitch_strength; PitchEstimator pitch_estimator; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -39,15 +37,15 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) { SCOPED_TRACE(i); lp_residual_reader.first->ReadChunk(lp_residual); lp_residual_reader.first->ReadValue(&expected_pitch_period); - lp_residual_reader.first->ReadValue(&expected_pitch_gain); - PitchInfo pitch_info = + lp_residual_reader.first->ReadValue(&expected_pitch_strength); + int pitch_period = pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz}); - EXPECT_EQ(expected_pitch_period, pitch_info.period); - EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); + EXPECT_EQ(expected_pitch_period, pitch_period); + EXPECT_NEAR(expected_pitch_strength, + pitch_estimator.GetLastPitchStrengthForTesting(), 1e-5f); } } } -} // namespace test } // namespace rnn_vad } // namespace webrtc