From 47176aee28885c136c281cd1f809bb5138e9c8ca Mon Sep 17 00:00:00 2001
From: Huy Le Nguyen
Date: Wed, 18 Nov 2020 22:53:32 +0700
Subject: [PATCH] :zap: Update demonstration example

---
 README.md                                     | 62 ++++++++++++++----
 examples/demonstration/conformer.py           | 63 ++++++++++++-------
 ...ormer.py => streaming_tflite_conformer.py} | 16 ++++-
 examples/demonstration/tflite_conformer.py    | 62 ++++++++++++++++++
 tensorflow_asr/utils/metrics.py               |  6 +-
 5 files changed, 173 insertions(+), 36 deletions(-)
 rename examples/demonstration/{streaming_conformer.py => streaming_tflite_conformer.py} (88%)
 create mode 100644 examples/demonstration/tflite_conformer.py

diff --git a/README.md b/README.md
index 921276be1e..2d630e5ae7 100755
--- a/README.md
+++ b/README.md
@@ -30,6 +30,28 @@ TensorFlowASR implements some automatic speech recognition architectures such as
 - Support `transducer` tflite greedy decoding (conversion and invocation)
 - Distributed training using `tf.distribute.MirroredStrategy`
 
+## Table of Contents
+
+
+- [What's New?](#whats-new)
+- [Table of Contents](#table-of-contents)
+- [:yum: Supported Models](#yum-supported-models)
+- [Installation](#installation)
+  - [Installing via PyPi](#installing-via-pypi)
+  - [Installing from source](#installing-from-source)
+- [Setup training and testing](#setup-training-and-testing)
+- [TFLite Conversion](#tflite-conversion)
+- [Features Extraction](#features-extraction)
+- [Augmentations](#augmentations)
+- [Training & Testing](#training--testing)
+- [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models)
+  - [English](#english)
+  - [Vietnamese](#vietnamese)
+  - [German](#german)
+- [References & Credits](#references--credits)
+
+
+
 ## :yum: Supported Models
 
 - **CTCModel** (End2end models using CTC Loss for training)
@@ -43,26 +65,44 @@ TensorFlowASR implements some automatic speech recognition architectures such as
 - **Streaming Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621)) See [examples/streaming_transducer](./examples/streaming_transducer)
 
-## Setup Environment and Datasets
+## Installation
 
-Install tensorflow: `pip3 install -U tensorflow` or `pip3 install tf-nightly` (for using tflite)
+For training and testing, you should install from source via `git clone`, so that the necessary packages from other authors (`ctc_decoders`, `rnnt_loss`, etc.) are installed as well.
-Install packages (choose _one_ of these options):
+### Installing via PyPi
 
-- Run `pip3 install -U TensorFlowASR`
-- Clone the repo and run `python3 setup.py install` in the repo's directory
+Run `pip3 install -U TensorFlowASR`
 
-For **setting up datasets**, see [datasets](./tensorflow_asr/datasets/README.md)
+### Installing from source
 
-- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
+```bash
+git clone https://github.com/TensorSpeech/TensorFlowASR.git
+cd TensorFlowASR
+python3 setup.py install
+```
 
-- For _training_ **Transducer Models**, export `CUDA_HOME` and run `./scripts/install_rnnt_loss.sh`
+For anaconda3:
+
+```bash
+conda create -y -n tfasr tensorflow-gpu python=3.7 # tensorflow if using CPU
+conda activate tfasr
+pip install -U tensorflow-gpu # upgrade to latest version of tensorflow
+git clone https://github.com/TensorSpeech/TensorFlowASR.git
+cd TensorFlowASR
+python setup.py install
+```
+
+## Setup training and testing
+
+- For datasets, see [datasets](./tensorflow_asr/datasets/README.md)
+
+- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
 
-- Method `tensorflow_asr.utils.setup_environment()` enable **mixed_precision** if available.
+- For _training_ **Transducer Models**, run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)
 
-- To enable XLA, run `TF_XLA_FLAGS=--tf_xla_auto_jit=2 $python_train_script`
+- For _mixed precision training_, use the flag `--mxp` when running python scripts from [examples](./examples)
 
-Clean up: `python3 setup.py clean --all` (this will remove `/build` contents)
+- For _enabling XLA_, run `TF_XLA_FLAGS=--tf_xla_auto_jit=2 python3 $path_to_py_script`
 
-## TFLite Convertion
+## TFLite Conversion
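Note for reviewers: the demonstration rewrite below depends on import order, since `setup_environment()` (which enables mixed precision if available) must run before TensorFlow is imported. A minimal sketch of the pattern the updated scripts adopt; the device id `0` and `cpu=False` are just example arguments:

```python
# Sketch of the setup pattern used by the updated demonstration scripts.
# setup_environment() must run before `import tensorflow` executes
# anywhere else in the process.
from tensorflow_asr.utils import setup_environment, setup_devices

setup_environment()
import tensorflow as tf  # deliberately imported after setup

setup_devices([0], cpu=False)  # restrict visible devices; cpu=True forces CPU-only
```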
diff --git a/examples/demonstration/conformer.py b/examples/demonstration/conformer.py
index e095c07ef4..c80ddc2ffa 100644
--- a/examples/demonstration/conformer.py
+++ b/examples/demonstration/conformer.py
@@ -13,44 +13,65 @@
 # limitations under the License.
 
 import argparse
-import tensorflow as tf
+from tensorflow_asr.utils import setup_environment, setup_devices
 
-from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
+setup_environment()
+import tensorflow as tf
 
 parser = argparse.ArgumentParser(prog="Conformer non streaming")
 
 parser.add_argument("filename", metavar="FILENAME",
                     help="audio file to be played back")
 
-parser.add_argument("--tflite", type=str, default=None,
-                    help="Path to conformer tflite")
+parser.add_argument("--config", type=str, default=None,
+                    help="Path to conformer config yaml")
+
+parser.add_argument("--saved", type=str, default=None,
+                    help="Path to conformer saved h5 weights")
 
 parser.add_argument("--blank", type=int, default=0,
-                    help="Path to conformer tflite")
+                    help="Blank index")
+
+parser.add_argument("--num_rnns", type=int, default=1,
+                    help="Number of RNN layers in prediction network")
+
+parser.add_argument("--nstates", type=int, default=2,
+                    help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
 
 parser.add_argument("--statesize", type=int, default=320,
-                    help="Path to conformer tflite")
+                    help="Size of RNN state in prediction network")
+
+parser.add_argument("--device", type=int, default=0,
+                    help="Device's id to run test on")
+
+parser.add_argument("--cpu", default=False, action="store_true",
+                    help="Whether to only use cpu")
 
 args = parser.parse_args()
 
-tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
+setup_devices([args.device], cpu=args.cpu)
+
+from tensorflow_asr.configs.config import Config
+from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.models.conformer import Conformer
+
+config = Config(args.config, learning=False)
+speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+text_featurizer = CharFeaturizer(config.decoder_config)
+
+# build model
+conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
+conformer._build(speech_featurizer.shape)
+conformer.load_weights(args.saved, by_name=True)
+conformer.summary(line_length=120)
+conformer.add_featurizers(speech_featurizer, text_featurizer)
 
 signal = read_raw_audio(args.filename)
+predicted = tf.constant(args.blank, dtype=tf.int32)
+states = tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)
 
-input_details = tflitemodel.get_input_details()
-output_details = tflitemodel.get_output_details()
-tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
-tflitemodel.allocate_tensors()
-tflitemodel.set_tensor(input_details[0]["index"], signal)
-tflitemodel.set_tensor(
-    input_details[1]["index"],
-    tf.constant(args.blank, dtype=tf.int32)
-)
-tflitemodel.set_tensor(
-    input_details[2]["index"],
-    tf.zeros([1, 2, 1, args.statesize], dtype=tf.float32)
-)
-tflitemodel.invoke()
-hyp = tflitemodel.get_tensor(output_details[0]["index"])
+hyp, _, _ = conformer.recognize_tflite(signal, predicted, states)
 
 print("".join([chr(u) for u in hyp]))
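The rewritten demo primes transducer greedy decoding with the blank token and zeroed prediction-network states, then maps the returned Unicode code points back to text. A small self-contained sketch of those two steps under the default flag values; the `hyp` below is a stand-in, not real model output:

```python
import tensorflow as tf

# Default flags: one RNN layer (num_rnns=1), LSTM state pairs (nstates=2),
# state width 320 -> states shape is [num_rnns, nstates, batch=1, statesize].
predicted = tf.constant(0, dtype=tf.int32)           # start from the blank index
states = tf.zeros([1, 2, 1, 320], dtype=tf.float32)  # zeroed prediction-network states

# The script prints the hypothesis by mapping Unicode code points to characters.
hyp = tf.constant([104, 101, 108, 108, 111], dtype=tf.int32)  # stand-in output
print("".join(chr(u) for u in hyp.numpy()))  # -> "hello"
```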
diff --git a/examples/demonstration/streaming_conformer.py b/examples/demonstration/streaming_tflite_conformer.py
similarity index 88%
rename from examples/demonstration/streaming_conformer.py
rename to examples/demonstration/streaming_tflite_conformer.py
index cc3713a302..193ec91dd8 100644
--- a/examples/demonstration/streaming_conformer.py
+++ b/examples/demonstration/streaming_tflite_conformer.py
@@ -57,6 +57,18 @@ def int_or_str(text):
 parser.add_argument("--tflite", type=str, default=None,
                     help="Path to conformer tflite")
 
+parser.add_argument("--blank", type=int, default=0,
+                    help="Blank index")
+
+parser.add_argument("--num_rnns", type=int, default=1,
+                    help="Number of RNN layers in prediction network")
+
+parser.add_argument("--nstates", type=int, default=2,
+                    help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
+
+parser.add_argument("--statesize", type=int, default=320,
+                    help="Size of RNN state in prediction network")
+
 args = parser.parse_args(remaining)
 
 if args.blocksize == 0:
@@ -92,8 +104,8 @@ def recognize(signal, lastid, states):
         text = "".join([chr(u) for u in upoints])
         return text, lastid, states
 
-    lastid = np.zeros(shape=[], dtype=np.int32)
-    states = np.zeros(shape=[1, 2, 1, 320], dtype=np.float32)
+    lastid = args.blank * np.ones(shape=[], dtype=np.int32)
+    states = np.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=np.float32)
     transcript = ""
 
     while True:
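The renamed streaming script threads `(lastid, states)` through consecutive audio blocks, so the new `--blank`, `--num_rnns`, `--nstates` and `--statesize` flags control how that state is initialized. A condensed sketch of that loop, where `recognize` stands for the TFLite-backed function defined in the script:

```python
import numpy as np

def stream(blocks, recognize, blank=0, num_rnns=1, nstates=2, statesize=320):
    # Prime decoding with the blank token and zeroed prediction-network states,
    # then carry both across blocks so the transcript stays consistent.
    lastid = blank * np.ones(shape=[], dtype=np.int32)
    states = np.zeros(shape=[num_rnns, nstates, 1, statesize], dtype=np.float32)
    transcript = ""
    for signal in blocks:  # each block is a float32 waveform chunk
        text, lastid, states = recognize(signal, lastid, states)
        transcript += text
    return transcript
```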
diff --git a/examples/demonstration/tflite_conformer.py b/examples/demonstration/tflite_conformer.py
new file mode 100644
index 0000000000..06b2eac7f9
--- /dev/null
+++ b/examples/demonstration/tflite_conformer.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import tensorflow as tf
+
+from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
+
+parser = argparse.ArgumentParser(prog="Conformer non-streaming")
+
+parser.add_argument("filename", metavar="FILENAME",
+                    help="Audio file to be played back")
+
+parser.add_argument("--tflite", type=str, default=None,
+                    help="Path to conformer tflite")
+
+parser.add_argument("--blank", type=int, default=0,
+                    help="Blank index")
+
+parser.add_argument("--num_rnns", type=int, default=1,
+                    help="Number of RNN layers in prediction network")
+
+parser.add_argument("--nstates", type=int, default=2,
+                    help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
+
+parser.add_argument("--statesize", type=int, default=320,
+                    help="Size of RNN state in prediction network")
+
+args = parser.parse_args()
+
+tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
+
+signal = read_raw_audio(args.filename)
+
+input_details = tflitemodel.get_input_details()
+output_details = tflitemodel.get_output_details()
+tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
+tflitemodel.allocate_tensors()
+tflitemodel.set_tensor(input_details[0]["index"], signal)
+tflitemodel.set_tensor(
+    input_details[1]["index"],
+    tf.constant(args.blank, dtype=tf.int32)
+)
+tflitemodel.set_tensor(
+    input_details[2]["index"],
+    tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)
+)
+tflitemodel.invoke()
+hyp = tflitemodel.get_tensor(output_details[0]["index"])
+
+print("".join([chr(u) for u in hyp]))
diff --git a/tensorflow_asr/utils/metrics.py b/tensorflow_asr/utils/metrics.py
index 336e0623dd..eb8049ebbc 100644
--- a/tensorflow_asr/utils/metrics.py
+++ b/tensorflow_asr/utils/metrics.py
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import Tuple
 
 import numpy as np
 import tensorflow as tf
 from nltk.metrics import distance
 
 from .utils import bytes_to_string
 
 
-def wer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
+def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
     """Word Error Rate
 
     Args:
@@ -43,7 +45,7 @@ def wer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
     return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
 
 
-def cer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
+def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
     """Character Error Rate
 
     Args: