Skip to content

Commit

Permalink
Merge pull request #1640 from zh794390558/frontend
Browse files Browse the repository at this point in the history
[speechx] Frontend refactor
  • Loading branch information
zh794390558 committed Apr 2, 2022
2 parents 36df70c + f83ec41 commit 2e94e0f
Show file tree
Hide file tree
Showing 26 changed files with 210 additions and 178 deletions.
2 changes: 1 addition & 1 deletion speechx/examples/decoder/offline_decoder_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/data_cache.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/data_cache.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
Expand Down
26 changes: 12 additions & 14 deletions speechx/examples/feat/linear_spectrogram_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@

// todo refactor, replace with gtest

#include "frontend/linear_spectrogram.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio_cache.h"
#include "frontend/data_cache.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"

#include <glog/logging.h>
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"

DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
Expand Down Expand Up @@ -170,13 +169,13 @@ int main(int argc, char* argv[]) {
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache

// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache());

ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));

ppspeech::LinearSpectrogramOptions opt;
Expand All @@ -185,12 +184,11 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;

std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));

std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram)));
std::unique_ptr<ppspeech::FrontendInterface> cmvn(new ppspeech::CMVN(
FLAGS_cmvn_write_path, std::move(linear_spectrogram)));

ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
Expand Down
10 changes: 1 addition & 9 deletions speechx/speechx/frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,2 @@
project(frontend)

add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
)

target_link_libraries(frontend PUBLIC kaldi-matrix)
add_subdirectory(audio)
11 changes: 11 additions & 0 deletions speechx/speechx/frontend/audio/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Build rules for the frontend sources that live in frontend/audio.
project(frontend)

# Static library `frontend`, compiled from the sources listed below.
add_library(frontend STATIC
cmvn.cc
db_norm.cc
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
)

# PUBLIC: kaldi-matrix is also propagated to targets that link `frontend`.
target_link_libraries(frontend PUBLIC kaldi-matrix)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "frontend/audio_cache.h"
#include "frontend/audio/audio_cache.h"
#include "kaldi/base/timer.h"

namespace ppspeech {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#pragma once

#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"

namespace ppspeech {

// waves cache
class AudioCache : public FeatureExtractorInterface {
class AudioCache : public FrontendInterface {
public:
explicit AudioCache(int buffer_size = kint16max);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.


#include "frontend/normalizer.h"
#include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"

Expand All @@ -26,73 +26,8 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;

// Builds a decibel normalizer that sits on top of `base_extractor` and takes
// ownership of it. A copy of `opts` is kept, and dim_ is fixed to 1.
DecibelNormalizer::DecibelNormalizer(
    const DecibelNormalizerOptions& opts,
    std::unique_ptr<FeatureExtractorInterface> base_extractor) {
    dim_ = 1;
    opts_ = opts;
    base_extractor_ = std::move(base_extractor);
}

// Forwards `waves` unchanged to the wrapped base extractor; no processing is
// done here (normalization is applied in Read()).
void DecibelNormalizer::Accept(const kaldi::VectorBase<BaseFloat>& waves) {
base_extractor_->Accept(waves);
}

// Pulls the next chunk from the base extractor into `waves` and runs
// Compute() on it.
// Returns false when the base extractor has nothing to deliver or delivers
// an empty vector. The result of Compute() is not checked, so true is
// returned even when Compute() rejected the normalization.
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
    const bool got_data = base_extractor_->Read(waves);
    if (!got_data) {
        return false;
    }
    if (waves->Dim() == 0) {
        return false;
    }
    Compute(waves);
    return true;
}

// Normalizes `waves` in place so that its RMS level becomes opts_.target_db.
//
// When opts_.convert_int_float is set, every sample is first multiplied by
// 1 / 2^15; that scaled value is what gets measured and written back.
//
// Returns false and leaves `waves` untouched when the gain needed to reach
// opts_.target_db is larger than opts_.max_gain_db. Returns true otherwise,
// including for an empty vector, which is left as is.
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
    const int num_samples = waves->Dim();
    if (num_samples == 0) {
        // Nothing to measure: avoids a 0/0 mean and touching empty storage.
        return true;
    }

    // 1 / 2^(16 - 1)
    const BaseFloat wave_float_normlization = 1.0f / 32768.0f;
    const bool rescale = opts_.convert_int_float;

    // mean of the squared (optionally rescaled) samples
    BaseFloat mean_square = 0.0;
    for (int i = 0; i < num_samples; ++i) {
        BaseFloat d = (*waves)(i);
        if (rescale) {
            d *= wave_float_normlization;
        }
        mean_square += d * d;
    }
    mean_square /= num_samples;

    // calculate db rms and the gain required to reach the target level
    const BaseFloat rms_db = 10 * std::log10(mean_square);
    const BaseFloat gain = opts_.target_db - rms_db;

    if (gain > opts_.max_gain_db) {
        LOG(ERROR)
            << "Unable to normalize segment to " << opts_.target_db << " dB, "
            << "because the probable gain exceeds opts_.max_gain_db ("
            << opts_.max_gain_db << " dB).";
        return false;
    }

    // python: item *= 10.0 ** (gain / 20.0)
    // The factor is the same for every sample, so compute it once.
    const auto scale = std::pow(10.0, gain / 20.0);

    // Note that this is an in-place transformation.
    for (int i = 0; i < num_samples; ++i) {
        BaseFloat d = (*waves)(i);
        if (rescale) {
            d *= wave_float_normlization;
        }
        (*waves)(i) = static_cast<BaseFloat>(d * scale);
    }
    return true;
}

CMVN::CMVN(std::string cmvn_file,
unique_ptr<FeatureExtractorInterface> base_extractor)
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) {
base_extractor_ = std::move(base_extractor);
bool binary;
Expand Down Expand Up @@ -185,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm_, feats);
}

} // namespace ppspeech
} // namespace ppspeech
48 changes: 48 additions & 0 deletions speechx/speechx/frontend/audio/cmvn.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"

namespace ppspeech {

// Frontend stage that wraps an upstream FrontendInterface and normalizes the
// features it produces with the statistics held in stats_ (see ApplyCMVN).
class CMVN : public FrontendInterface {
  public:
    // cmvn_file: presumably the path of the CMVN statistics file; it is
    //   consumed by the constructor defined in cmvn.cc -- TODO confirm.
    // base_extractor: upstream stage; ownership moves into this object.
    explicit CMVN(std::string cmvn_file,
                  std::unique_ptr<FrontendInterface> base_extractor);
    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);

    // the length of feats = feature_row * feature_dim,
    // the Matrix is squashed into Vector
    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
    // the dim_ is the feature dim.
    virtual size_t Dim() const { return dim_; }
    // The three calls below are forwarded to the upstream stage unchanged.
    virtual void SetFinished() { base_extractor_->SetFinished(); }
    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
    virtual void Reset() { base_extractor_->Reset(); }

  private:
    void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
    // BaseFloat is spelled kaldi::BaseFloat like every other declaration in
    // this header, so the header does not rely on a using-declaration.
    void ApplyCMVN(kaldi::MatrixBase<kaldi::BaseFloat>* feats);
    kaldi::Matrix<double> stats_;
    std::unique_ptr<FrontendInterface> base_extractor_;
    // Defaulted so it is never indeterminate before the constructor sets it.
    size_t dim_ = 0;
    bool var_norm_;
};

} // namespace ppspeech
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@


#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"


namespace ppspeech {
// A data source for testing different frontend module.
// It accepts waves or feats.
class DataCache : public FeatureExtractorInterface {
class DataCache : public FrontendInterface {
public:
explicit DataCache() { finished_ = false; }

Expand Down
95 changes: 95 additions & 0 deletions speechx/speechx/frontend/audio/db_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "frontend/audio/db_norm.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"

namespace ppspeech {

using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;

// Builds a decibel normalizer that sits on top of `base_extractor` and takes
// ownership of it. A copy of `opts` is kept, and dim_ is fixed to 1.
DecibelNormalizer::DecibelNormalizer(
    const DecibelNormalizerOptions& opts,
    std::unique_ptr<FrontendInterface> base_extractor) {
    dim_ = 1;
    opts_ = opts;
    base_extractor_ = std::move(base_extractor);
}

// Forwards `waves` unchanged to the wrapped base extractor; no processing is
// done here (normalization is applied in Read()).
void DecibelNormalizer::Accept(const kaldi::VectorBase<BaseFloat>& waves) {
base_extractor_->Accept(waves);
}

// Pulls the next chunk from the base extractor into `waves` and runs
// Compute() on it.
// Returns false when the base extractor has nothing to deliver or delivers
// an empty vector. The result of Compute() is not checked, so true is
// returned even when Compute() rejected the normalization.
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
    const bool got_data = base_extractor_->Read(waves);
    if (!got_data) {
        return false;
    }
    if (waves->Dim() == 0) {
        return false;
    }
    Compute(waves);
    return true;
}

// Normalizes `waves` in place so that its RMS level becomes opts_.target_db.
//
// When opts_.convert_int_float is set, every sample is first multiplied by
// 1 / 2^15; that scaled value is what gets measured and written back.
//
// Returns false and leaves `waves` untouched when the gain needed to reach
// opts_.target_db is larger than opts_.max_gain_db. Returns true otherwise,
// including for an empty vector, which is left as is.
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
    const int num_samples = waves->Dim();
    if (num_samples == 0) {
        // Nothing to measure: avoids a 0/0 mean and touching empty storage.
        return true;
    }

    // 1 / 2^(16 - 1)
    const BaseFloat wave_float_normlization = 1.0f / 32768.0f;
    const bool rescale = opts_.convert_int_float;

    // mean of the squared (optionally rescaled) samples
    BaseFloat mean_square = 0.0;
    for (int i = 0; i < num_samples; ++i) {
        BaseFloat d = (*waves)(i);
        if (rescale) {
            d *= wave_float_normlization;
        }
        mean_square += d * d;
    }
    mean_square /= num_samples;

    // calculate db rms and the gain required to reach the target level
    const BaseFloat rms_db = 10 * std::log10(mean_square);
    const BaseFloat gain = opts_.target_db - rms_db;

    if (gain > opts_.max_gain_db) {
        LOG(ERROR)
            << "Unable to normalize segment to " << opts_.target_db << " dB, "
            << "because the probable gain exceeds opts_.max_gain_db ("
            << opts_.max_gain_db << " dB).";
        return false;
    }

    // python: item *= 10.0 ** (gain / 20.0)
    // The factor is the same for every sample, so compute it once.
    const auto scale = std::pow(10.0, gain / 20.0);

    // Note that this is an in-place transformation.
    for (int i = 0; i < num_samples; ++i) {
        BaseFloat d = (*waves)(i);
        if (rescale) {
            d *= wave_float_normlization;
        }
        (*waves)(i) = static_cast<BaseFloat>(d * scale);
    }
    return true;
}


} // namespace ppspeech

0 comments on commit 2e94e0f

Please sign in to comment.