Skip to content

Commit

Permalink
RNN VAD: pitch search optimizations (part 1)
Browse files Browse the repository at this point in the history
TL;DR: this CL improves efficiency and includes several code
readability improvements, mainly triggered by the comments on
patch set #10.

Highlights:
- Split `FindBestPitchPeriods()` into 12 and 24 kHz versions
  to hard-code the input size and simplify the 24 kHz version
- Loop in `ComputePitchPeriod48kHz()` (new name for
  `RefinePitchPeriod48kHz()`) removed since there are only a few lags
  for which we need to compute the auto-correlation
- `ComputePitchGainThreshold()` was only used in unit tests; it's been
  moved into the anon ns and the test removed

This CL makes `ComputePitchPeriod48kHz()` about 10% faster (measured
with https://webrtc-review.googlesource.com/c/src/+/191320/4/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc).
The realtime factor has improved by about 14x (roughly +4%) on average across the runs below.

Benchmarked as follows:
```
out/release/modules_unittests \
  --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
  --gtest_also_run_disabled_tests --logs
```

Results:

      | baseline             | this CL
------+----------------------+------------------------
run 1 | 24.0231 +/- 0.591016 | 23.568 +/- 0.990788
      | 370.06x              | 377.207x
------+----------------------+------------------------
run 2 | 24.0485 +/- 0.957498 | 23.3714 +/- 0.857523
      | 369.67x              | 380.379x
------+----------------------+------------------------
run 3 | 25.4091 +/- 2.6123   | 23.709 +/- 1.04477
      | 349.875x             | 374.963x

Bug: webrtc:10480
Change-Id: I9a3e9164b2442114b928de506c92a547c273882f
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191320
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32568}
  • Loading branch information
alebzk authored and Commit Bot committed Nov 9, 2020
1 parent c95b939 commit 9da3e17
Show file tree
Hide file tree
Showing 13 changed files with 452 additions and 450 deletions.
2 changes: 1 addition & 1 deletion modules/audio_processing/agc2/rnn_vad/BUILD.gn
Expand Up @@ -83,7 +83,6 @@ rtc_library("rnn_vad_lp_residual") {

rtc_library("rnn_vad_pitch") {
sources = [
"pitch_info.h",
"pitch_search.cc",
"pitch_search.h",
"pitch_search_internal.cc",
Expand All @@ -94,6 +93,7 @@ rtc_library("rnn_vad_pitch") {
":rnn_vad_common",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:gtest_prod",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
]
Expand Down
17 changes: 8 additions & 9 deletions modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
Expand Up @@ -20,7 +20,7 @@ namespace {

constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");

} // namespace
Expand All @@ -45,15 +45,15 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
"Mismatch between pitch buffer size, frame size and maximum "
"pitch period.");
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
"The FFT length is not sufficiently big to avoid cyclic "
"convolution errors.");
auto tmp = tmp_->GetView();
Expand All @@ -67,13 +67,12 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(

// Compute the FFT for the sliding frames chunk. The sliding frames are
// defined as pitch_buf[i:i+kConvolutionLength] where i in
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
std::copy(pitch_buf.begin(),
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
tmp.begin());
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
0.f);
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);

// Convolve in the frequency domain.
Expand All @@ -84,7 +83,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(

// Extract the auto-correlation coefficients.
std::copy(tmp.begin() + kConvolutionLength - 1,
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
auto_corr.begin());
}

Expand Down
2 changes: 1 addition & 1 deletion modules/audio_processing/agc2/rnn_vad/auto_correlation.h
Expand Up @@ -34,7 +34,7 @@ class AutoCorrelationCalculator {
// |auto_corr| indexes are inverted lags.
void ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
rtc::ArrayView<float, kNumLags12kHz> auto_corr);

private:
Pffft fft_;
Expand Down
12 changes: 9 additions & 3 deletions modules/audio_processing/agc2/rnn_vad/common.h
Expand Up @@ -36,7 +36,13 @@ constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
static_assert(
kRefineNumLags24kHz > kInitialNumLags24kHz,
"The refinement step must search the pitch in an extended pitch range.");

// 12 kHz analysis.
constexpr int kSampleRate12kHz = 12000;
Expand All @@ -47,8 +53,8 @@ constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// |kMaxPitch12kHz|] are in the range [0, |kNumLags12kHz|].
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;

// 48 kHz constants.
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
Expand Down
9 changes: 4 additions & 5 deletions modules/audio_processing/agc2/rnn_vad/features_extraction.cc
Expand Up @@ -67,13 +67,12 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
// Estimate pitch on the LP-residual and write the normalized pitch period
// into the output vector (normalization based on training data stats).
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] =
0.01f * (pitch_info_48kHz_.period - 300);
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
// Extract lagged frames (according to the estimated pitch period).
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
auto lagged_frame = pitch_buf_24kHz_view_.subview(
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
// Analyze reference and lagged frames checking if silence has been detected
// and write the feature vector.
return spectral_features_extractor_.CheckSilenceComputeFeatures(
Expand Down
3 changes: 1 addition & 2 deletions modules/audio_processing/agc2/rnn_vad/features_extraction.h
Expand Up @@ -16,7 +16,6 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
Expand Down Expand Up @@ -53,7 +52,7 @@ class FeaturesExtractor {
PitchEstimator pitch_estimator_;
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
SpectralFeaturesExtractor spectral_features_extractor_;
PitchInfo pitch_info_48kHz_;
int pitch_period_48kHz_;
};

} // namespace rnn_vad
Expand Down
29 changes: 0 additions & 29 deletions modules/audio_processing/agc2/rnn_vad/pitch_info.h

This file was deleted.

26 changes: 14 additions & 12 deletions modules/audio_processing/agc2/rnn_vad/pitch_search.cc
Expand Up @@ -21,35 +21,37 @@ namespace rnn_vad {
PitchEstimator::PitchEstimator()
: pitch_buf_decimated_(kBufSize12kHz),
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
auto_corr_(kNumInvertedLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
auto_corr_(kNumLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
}

PitchEstimator::~PitchEstimator() = default;

PitchInfo PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
int PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods(
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
CandidatePitchPeriods pitch_candidates_inverted_lags =
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
// Refine the pitch period estimation.
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inverted_lags| from
// 12 to 24 kHz.
pitch_candidates_inverted_lags.best *= 2;
pitch_candidates_inverted_lags.second_best *= 2;
const int pitch_inv_lag_48kHz =
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
return last_pitch_48kHz_;
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
pitch_buffer,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
last_pitch_48kHz_);
return last_pitch_48kHz_.period;
}

} // namespace rnn_vad
Expand Down
16 changes: 10 additions & 6 deletions modules/audio_processing/agc2/rnn_vad/pitch_search.h
Expand Up @@ -17,8 +17,8 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "rtc_base/gtest_prod_util.h"

namespace webrtc {
namespace rnn_vad {
Expand All @@ -30,17 +30,21 @@ class PitchEstimator {
PitchEstimator(const PitchEstimator&) = delete;
PitchEstimator& operator=(const PitchEstimator&) = delete;
~PitchEstimator();
// Estimates the pitch period and gain. Returns the pitch estimation data for
// 48 kHz.
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
// Returns the estimated pitch period at 48 kHz.
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);

private:
PitchInfo last_pitch_48kHz_;
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
float GetLastPitchStrengthForTesting() const {
return last_pitch_48kHz_.strength;
}

PitchInfo last_pitch_48kHz_{};
AutoCorrelationCalculator auto_corr_calculator_;
std::vector<float> pitch_buf_decimated_;
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
std::vector<float> auto_corr_;
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
};

} // namespace rnn_vad
Expand Down

0 comments on commit 9da3e17

Please sign in to comment.