diff --git a/fluid/DeepASR/decoder/decoder.cc b/fluid/DeepASR/decoder/decoder.cc deleted file mode 100644 index a99f972e2f..0000000000 --- a/fluid/DeepASR/decoder/decoder.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "decoder.h" - -std::string decode(std::vector> probs_mat) { - // Add decoding logic here - - return "example decoding result"; -} diff --git a/fluid/DeepASR/decoder/decoder.h b/fluid/DeepASR/decoder/decoder.h deleted file mode 100644 index 4a67fa366a..0000000000 --- a/fluid/DeepASR/decoder/decoder.h +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -std::string decode(std::vector> probs_mat); diff --git a/fluid/DeepASR/decoder/post_decode_faster.cc b/fluid/DeepASR/decoder/post_decode_faster.cc new file mode 100644 index 0000000000..d7f1d1ab34 --- /dev/null +++ b/fluid/DeepASR/decoder/post_decode_faster.cc @@ -0,0 +1,144 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "post_decode_faster.h" + +typedef kaldi::int32 int32; +using fst::SymbolTable; +using fst::VectorFst; +using fst::StdArc; + +Decoder::Decoder(std::string word_syms_filename, + std::string fst_in_filename, + std::string logprior_rxfilename) { + const char* usage = + "Decode, reading log-likelihoods (of transition-ids or whatever symbol " + "is on the graph) as matrices."; + + kaldi::ParseOptions po(usage); + binary = true; + acoustic_scale = 1.5; + allow_partial = true; + kaldi::FasterDecoderOptions decoder_opts; + decoder_opts.Register(&po, true); // true == include obscure settings. + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("allow-partial", + &allow_partial, + "Produce output even when final state was not reached"); + po.Register("acoustic-scale", + &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + + word_syms = NULL; + if (word_syms_filename != "") { + word_syms = fst::SymbolTable::ReadText(word_syms_filename); + if (!word_syms) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + } + + std::ifstream is_logprior(logprior_rxfilename); + logprior.Read(is_logprior, false); + + // It's important that we initialize decode_fst after loglikes_reader, as it + // can prevent crashes on systems installed without enough virtual memory. + // It has to do with what happens on UNIX systems if you call fork() on a + // large process: the page-table entries are duplicated, which requires a + // lot of virtual memory. + decode_fst = fst::ReadFstKaldi(fst_in_filename); + + decoder = new kaldi::FasterDecoder(*decode_fst, decoder_opts); +} + + +Decoder::~Decoder() { + if (!word_syms) delete word_syms; + delete decode_fst; + delete decoder; +} + +std::string Decoder::decode( + std::string key, + const std::vector>& log_probs) { + size_t num_frames = log_probs.size(); + size_t dim_label = log_probs[0].size(); + + kaldi::Matrix loglikes( + num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols); + for (size_t i = 0; i < num_frames; ++i) { + memcpy(loglikes.Data() + i * dim_label, + log_probs[i].data(), + sizeof(kaldi::BaseFloat) * dim_label); + } + + return decode(key, loglikes); +} + + +std::vector Decoder::decode(std::string posterior_rspecifier) { + kaldi::SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier); + std::vector decoding_results; + + for (; !posterior_reader.Done(); posterior_reader.Next()) { + std::string key = posterior_reader.Key(); + kaldi::Matrix loglikes(posterior_reader.Value()); + + decoding_results.push_back(decode(key, loglikes)); + } + + return decoding_results; +} + + +std::string Decoder::decode(std::string key, + kaldi::Matrix& loglikes) { + std::string decoding_result; + + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << key; + } + KALDI_ASSERT(loglikes.NumCols() == logprior.Dim()); + + loglikes.ApplyLog(); + loglikes.AddVecToRows(-1.0, logprior); + + kaldi::DecodableMatrixScaled decodable(loglikes, acoustic_scale); + decoder->Decode(&decodable); + + VectorFst decoded; // linear FST. + + if ((allow_partial || decoder->ReachedFinal()) && + decoder->GetBestPath(&decoded)) { + if (!decoder->ReachedFinal()) + KALDI_WARN << "Decoder did not reach end-state, outputting partial " + "traceback."; + + std::vector alignment; + std::vector words; + kaldi::LatticeWeight weight; + + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + + if (word_syms != NULL) { + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + decoding_result += s; + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + } + } + } + + return decoding_result; +} diff --git a/fluid/DeepASR/decoder/post_decode_faster.h b/fluid/DeepASR/decoder/post_decode_faster.h new file mode 100644 index 0000000000..2e31a1c19e --- /dev/null +++ b/fluid/DeepASR/decoder/post_decode_faster.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "base/kaldi-common.h" +#include "base/timer.h" +#include "decoder/decodable-matrix.h" +#include "decoder/faster-decoder.h" +#include "fstext/fstext-lib.h" +#include "hmm/transition-model.h" +#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc +#include "tree/context-dep.h" +#include "util/common-utils.h" + + +class Decoder { +public: + Decoder(std::string word_syms_filename, + std::string fst_in_filename, + std::string logprior_rxfilename); + ~Decoder(); + + // Interface to accept the scores read from specifier and return + // the batch decoding results + std::vector decode(std::string posterior_rspecifier); + + // Accept the scores of one utterance and return the decoding result + std::string decode( + std::string key, + const std::vector> &log_probs); + +private: + // For decoding one utterance + std::string decode(std::string key, + kaldi::Matrix &loglikes); + + fst::SymbolTable *word_syms; + fst::VectorFst *decode_fst; + kaldi::FasterDecoder *decoder; + kaldi::Vector logprior; + + bool binary; + kaldi::BaseFloat acoustic_scale; + bool allow_partial; +}; diff --git a/fluid/DeepASR/decoder/pybind.cc b/fluid/DeepASR/decoder/pybind.cc index 8cd65903ea..56439d1802 100644 --- a/fluid/DeepASR/decoder/pybind.cc +++ b/fluid/DeepASR/decoder/pybind.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,25 @@ limitations under the License. */ #include #include -#include "decoder.h" +#include "post_decode_faster.h" namespace py = pybind11; -PYBIND11_MODULE(decoder, m) { - m.doc() = "Decode function for Deep ASR model"; - - m.def("decode", - &decode, - "Decode one input probability matrix " - "and return the transcription"); +PYBIND11_MODULE(post_decode_faster, m) { + m.doc() = "Decoder for Deep ASR model"; + + py::class_(m, "Decoder") + .def(py::init()) + .def("decode", + (std::vector (Decoder::*)(std::string)) & + Decoder::decode, + "Decode for the probability matrices in specifier " + "and return the transcriptions.") + .def( + "decode", + (std::string (Decoder::*)( + std::string, const std::vector>&)) & + Decoder::decode, + "Decode one input probability matrix " + "and return the transcription."); } diff --git a/fluid/DeepASR/decoder/setup.py b/fluid/DeepASR/decoder/setup.py index cedd5d644e..a98c0b4cc1 100644 --- a/fluid/DeepASR/decoder/setup.py +++ b/fluid/DeepASR/decoder/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,27 +13,57 @@ # limitations under the License. import os +import glob from distutils.core import setup, Extension from distutils.sysconfig import get_config_vars -args = ['-std=c++11'] +try: + kaldi_root = os.environ['KALDI_ROOT'] +except: + raise ValueError("Enviroment variable 'KALDI_ROOT' is not defined. Please " + "install kaldi and export KALDI_ROOT= .") + +args = [ + '-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable', + '-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable', + '-Wno-deprecated-declarations', '-Wno-unused-function' +] # remove warning about -Wstrict-prototypes (opt, ) = get_config_vars('OPT') os.environ['OPT'] = " ".join(flag for flag in opt.split() if flag != '-Wstrict-prototypes') +os.environ['CC'] = 'g++' + +LIBS = [ + 'fst', 'kaldi-base', 'kaldi-util', 'kaldi-matrix', 'kaldi-tree', + 'kaldi-hmm', 'kaldi-fstext', 'kaldi-decoder', 'kaldi-lat' +] + +LIB_DIRS = [ + 'tools/openfst/lib', 'src/base', 'src/matrix', 'src/util', 'src/tree', + 'src/hmm', 'src/fstext', 'src/decoder', 'src/lat' +] +LIB_DIRS = [os.path.join(kaldi_root, path) for path in LIB_DIRS] +LIB_DIRS = [os.path.abspath(path) for path in LIB_DIRS] ext_modules = [ Extension( - 'decoder', - ['pybind.cc', 'decoder.cc'], - include_dirs=['pybind11/include', '.'], + 'post_decode_faster', + ['pybind.cc', 'post_decode_faster.cc'], + include_dirs=[ + 'pybind11/include', '.', os.path.join(kaldi_root, 'src'), + os.path.join(kaldi_root, 'tools/openfst/src/include') + ], language='c++', + libraries=LIBS, + library_dirs=LIB_DIRS, + runtime_library_dirs=LIB_DIRS, extra_compile_args=args, ), ] setup( - name='decoder', + name='post_decode_faster', version='0.0.1', author='Paddle', author_email='', diff --git a/fluid/DeepASR/decoder/setup.sh b/fluid/DeepASR/decoder/setup.sh index 71fd6626ef..1471f85f41 100644 --- a/fluid/DeepASR/decoder/setup.sh +++ b/fluid/DeepASR/decoder/setup.sh @@ -1,4 +1,4 @@ - +set -e if [ ! -d pybind11 ]; then git clone https://github.com/pybind/pybind11.git diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py index edaa2b5ac4..f267f67498 100644 --- a/fluid/DeepASR/infer_by_ckpt.py +++ b/fluid/DeepASR/infer_by_ckpt.py @@ -13,7 +13,7 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_splice as trans_splice import data_utils.async_data_reader as reader -import decoder.decoder as decoder +from decoder.post_decode_faster import Decoder from data_utils.util import lodtensor_to_ndarray from model_utils.model import stacked_lstmp_model from data_utils.util import split_infer_result @@ -91,6 +91,21 @@ def parse_args(): type=str, default='./checkpoint', help="The checkpoint path to init model. (default: %(default)s)") + parser.add_argument( + '--vocabulary', + type=str, + default='./decoder/graph/words.txt', + help="The path to vocabulary. (default: %(default)s)") + parser.add_argument( + '--graphs', + type=str, + default='./decoder/graph/TLG.fst', + help="The path to TLG graphs for decoding. (default: %(default)s)") + parser.add_argument( + '--log_prior', + type=str, + default="./decoder/logprior", + help="The log prior probs for training data. (default: %(default)s)") args = parser.parse_args() return args @@ -165,8 +180,9 @@ def infer_from_ckpt(args): probs, lod = lodtensor_to_ndarray(results[0]) infer_batch = split_infer_result(probs, lod) for index, sample in enumerate(infer_batch): - print("Decoding %d: " % (batch_id * args.batch_size + index), - decoder.decode(sample)) + key = "utter#%d" % (batch_id * args.batch_size + index) + print(key, ": ", decoder.decode(key, sample), "\n") + print(np.mean(infer_costs), np.mean(infer_accs))