62 changes: 51 additions & 11 deletions README.md
Expand Up @@ -30,6 +30,28 @@ TensorFlowASR implements some automatic speech recognition architectures such as
- Support `transducer` tflite greedy decoding (conversion and invocation)
- Distributed training using `tf.distribute.MirroredStrategy`

## Table of Contents
<!-- TOC -->

- [What's New?](#whats-new)
- [Table of Contents](#table-of-contents)
- [:yum: Supported Models](#yum-supported-models)
- [Installation](#installation)
- [Installing via PyPi](#installing-via-pypi)
- [Installing from source](#installing-from-source)
- [Setup training and testing](#setup-training-and-testing)
- [TFLite Convertion](#tflite-convertion)
- [Features Extraction](#features-extraction)
- [Augmentations](#augmentations)
- [Training & Testing](#training--testing)
- [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models)
- [English](#english)
- [Vietnamese](#vietnamese)
- [German](#german)
- [References & Credits](#references--credits)

<!-- /TOC -->

## :yum: Supported Models

- **CTCModel** (End2end models using CTC Loss for training)
Expand All @@ -43,26 +65,44 @@ TensorFlowASR implements some automatic speech recognition architectures such as
- **Streaming Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621))
See [examples/streaming_transducer](./examples/streaming_transducer)

## Setup Environment and Datasets
## Installation

Install tensorflow: `pip3 install -U tensorflow` or `pip3 install tf-nightly` (for using tflite)
For training and testing, you should use `git clone` for installing necessary packages from other authors (`ctc_decoders`, `rnnt_loss`, etc.)

Install packages (choose _one_ of these options):
### Installing via PyPi

- Run `pip3 install -U TensorFlowASR`
- Clone the repo and run `python3 setup.py install` in the repo's directory
Run `pip3 install -U TensorFlowASR`

For **setting up datasets**, see [datasets](./tensorflow_asr/datasets/README.md)
### Installing from source

- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
```bash
git clone https://github.com/TensorSpeech/TensorFlowASR.git
cd TensorFlowASR
python3 setup.py install
```

- For _training_ **Transducer Models**, export `CUDA_HOME` and run `./scripts/install_rnnt_loss.sh`
For anaconda3:

```bash
conda create -y -n tfasr tensorflow-gpu python=3.7 # tensorflow if using CPU
conda activate tfasr
pip install -U tensorflow-gpu # upgrade to latest version of tensorflow
git clone https://github.com/TensorSpeech/TensorFlowASR.git
cd TensorFlowASR
python setup.py install
```

## Setup training and testing

- For datasets, see [datasets](./tensorflow_asr/datasets/README.md)

- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`

- Method `tensorflow_asr.utils.setup_environment()` enable **mixed_precision** if available.
- For _training_ **Transducer Models**, run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)

- To enable XLA, run `TF_XLA_FLAGS=--tf_xla_auto_jit=2 $python_train_script`
- For _mixed precision training_, use flag `--mxp` when running python scripts from [examples](./examples)

Clean up: `python3 setup.py clean --all` (this will remove `/build` contents)
- For _enabling XLA_, run `TF_XLA_FLAGS=--tf_xla_auto_jit=2 python3 $path_to_py_script`
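
The XLA bullet above sets an environment variable for a single command. As a rough sketch of the same idea from Python, the snippet below builds an environment with `TF_XLA_FLAGS` set and passes it to a child process. The helper name `xla_env` and the script path in the usage comment are hypothetical and not part of this repository.

```python
import os
import subprocess
import sys


def xla_env(base=None):
    """Return a copy of the environment with XLA auto-clustering enabled.

    `base` defaults to the current process environment. The helper name is
    an assumption for illustration; only the flag value comes from the README.
    """
    env = dict(os.environ if base is None else base)
    env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"
    return env


# Usage (the script path is a placeholder, not a file from this repo):
# subprocess.run([sys.executable, "path/to/train_script.py"], env=xla_env())
```

Passing the variable through `env=` keeps it scoped to the child process, which mirrors the one-off `VAR=value command` form used in the shell.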

## TFLite Convertion

Expand Down
63 changes: 42 additions & 21 deletions examples/demonstration/conformer.py
Expand Up @@ -13,44 +13,65 @@
# limitations under the License.

import argparse
import tensorflow as tf
from tensorflow_asr.utils import setup_environment, setup_devices

from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
setup_environment()
import tensorflow as tf

parser = argparse.ArgumentParser(prog="Conformer non streaming")

parser.add_argument("filename", metavar="FILENAME",
help="audio file to be played back")

parser.add_argument("--tflite", type=str, default=None,
help="Path to conformer tflite")
parser.add_argument("--config", type=str, default=None,
help="Path to conformer config yaml")

parser.add_argument("--saved", type=str, default=None,
help="Path to conformer saved h5 weights")

parser.add_argument("--blank", type=int, default=0,
help="Blank index")

parser.add_argument("--num_rnns", type=int, default=1,
help="Number of RNN layers in prediction network")

parser.add_argument("--nstates", type=int, default=2,
help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")

parser.add_argument("--statesize", type=int, default=320,
help="Path to conformer tflite")
help="Size of RNN state in prediction network")

parser.add_argument("--device", type=int, default=0,
help="Device's id to run test on")

parser.add_argument("--cpu", default=False, action="store_true",
help="Whether to only use cpu")

args = parser.parse_args()

tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.conformer import Conformer

config = Config(args.config, learning=False)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)
predicted = tf.constant(args.blank, dtype=tf.int32)
states = tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)

input_details = tflitemodel.get_input_details()
output_details = tflitemodel.get_output_details()
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
tflitemodel.allocate_tensors()
tflitemodel.set_tensor(input_details[0]["index"], signal)
tflitemodel.set_tensor(
input_details[1]["index"],
tf.constant(args.blank, dtype=tf.int32)
)
tflitemodel.set_tensor(
input_details[2]["index"],
tf.zeros([1, 2, 1, args.statesize], dtype=tf.float32)
)
tflitemodel.invoke()
hyp = tflitemodel.get_tensor(output_details[0]["index"])
hyp, _, _ = conformer.recognize_tflite(signal, predicted, states)

print("".join([chr(u) for u in hyp]))
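
The script above replaces the hard-coded `[1, 2, 1, 320]` state shape with one derived from `--num_rnns`, `--nstates` and `--statesize`. A minimal NumPy sketch of how those options map to the initial decoder inputs follows; the function name `initial_decoder_inputs` is hypothetical, and only the shape layout and dtypes are taken from the diff.

```python
import numpy as np


def initial_decoder_inputs(blank=0, num_rnns=1, nstates=2, statesize=320):
    """Build the starting token and zeroed RNN states for greedy decoding.

    Defaults mirror the argparse defaults in the diff. nstates is 2 for
    LSTM (hidden and cell state) and 1 for GRU, per the help text.
    """
    # Scalar int32 holding the blank index, used as the first "previous token".
    predicted = np.array(blank, dtype=np.int32)
    # Layout: [layers, states per layer, batch of 1, state size].
    states = np.zeros([num_rnns, nstates, 1, statesize], dtype=np.float32)
    return predicted, states


predicted, states = initial_decoder_inputs()
print(predicted.shape, states.shape)
```

For a model whose prediction network uses a single GRU layer, calling `initial_decoder_inputs(nstates=1)` would give a `[1, 1, 1, 320]` state tensor instead.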
Expand Up @@ -57,6 +57,18 @@ def int_or_str(text):
parser.add_argument("--tflite", type=str, default=None,
help="Path to conformer tflite")

parser.add_argument("--blank", type=int, default=0,
help="Blank index")

parser.add_argument("--num_rnns", type=int, default=1,
help="Number of RNN layers in prediction network")

parser.add_argument("--nstates", type=int, default=2,
help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")

parser.add_argument("--statesize", type=int, default=320,
help="Size of RNN state in prediction network")

args = parser.parse_args(remaining)

if args.blocksize == 0:
Expand Down Expand Up @@ -92,8 +104,8 @@ def recognize(signal, lastid, states):
text = "".join([chr(u) for u in upoints])
return text, lastid, states

lastid = np.zeros(shape=[], dtype=np.int32)
states = np.zeros(shape=[1, 2, 1, 320], dtype=np.float32)
lastid = args.blank * np.ones(shape=[], dtype=np.int32)
states = np.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=np.float32)
transcript = ""

while True:
Expand Down
62 changes: 62 additions & 0 deletions examples/demonstration/tflite_conformer.py
@@ -0,0 +1,62 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import tensorflow as tf

from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio

parser = argparse.ArgumentParser(prog="Conformer non streaming")

parser.add_argument("filename", metavar="FILENAME",
help="Audio file to be played back")

parser.add_argument("--tflite", type=str, default=None,
help="Path to conformer tflite")

parser.add_argument("--blank", type=int, default=0,
help="Blank index")

parser.add_argument("--num_rnns", type=int, default=1,
help="Number of RNN layers in prediction network")

parser.add_argument("--nstates", type=int, default=2,
help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")

parser.add_argument("--statesize", type=int, default=320,
help="Size of RNN state in prediction network")

args = parser.parse_args()

tflitemodel = tf.lite.Interpreter(model_path=args.tflite)

signal = read_raw_audio(args.filename)

input_details = tflitemodel.get_input_details()
output_details = tflitemodel.get_output_details()
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
tflitemodel.allocate_tensors()
tflitemodel.set_tensor(input_details[0]["index"], signal)
tflitemodel.set_tensor(
input_details[1]["index"],
tf.constant(args.blank, dtype=tf.int32)
)
tflitemodel.set_tensor(
input_details[2]["index"],
tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)
)
tflitemodel.invoke()
hyp = tflitemodel.get_tensor(output_details[0]["index"])

print("".join([chr(u) for u in hyp]))
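
The final `print` treats the TFLite output as a sequence of Unicode code points and joins them into text. A small standalone sketch of that step, assuming the output is a 1-D integer array (the helper name `upoints_to_text` is hypothetical):

```python
import numpy as np


def upoints_to_text(upoints):
    """Convert a sequence of Unicode code points into a Python string."""
    # int() makes this work for NumPy integer scalars as well as plain ints.
    return "".join(chr(int(u)) for u in upoints)


# Example with made-up code points standing in for an interpreter output.
hyp = np.array([104, 101, 108, 108, 111], dtype=np.int32)
print(upoints_to_text(hyp))
```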
6 changes: 4 additions & 2 deletions tensorflow_asr/utils/metrics.py
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
import numpy as np
import tensorflow as tf
from nltk.metrics import distance
from .utils import bytes_to_string


def wer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
"""Word Error Rate

Args:
Expand All @@ -43,7 +45,7 @@ def wer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)


def cer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
"""Character Error Rate

Args:
Expand Down
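
The annotation change above replaces `(tf.Tensor, tf.Tensor)`, which is a tuple of classes rather than a type, with `Tuple[tf.Tensor, tf.Tensor]`. The sketch below shows the same return shape, a pair of accumulated edit distance and reference length, using only the standard library. It is not the repository's implementation: `edit_distance` and `wer_counts` are hypothetical names, and plain floats stand in for the tensors.

```python
from typing import List, Sequence, Tuple


def edit_distance(a: Sequence, b: Sequence) -> int:
    """Levenshtein distance between two sequences (two-row dynamic programming)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            cost = 0 if x == y else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def wer_counts(decoded: List[str], targets: List[str]) -> Tuple[float, float]:
    """Return (total word-level edit distance, total reference word count)."""
    dis, length = 0.0, 0.0
    for dec, tar in zip(decoded, targets):
        dis += edit_distance(dec.split(), tar.split())
        length += len(tar.split())
    return dis, length


dis, length = wer_counts(["the cat sat"], ["the cat sat down"])
print(dis / length)  # word error rate for this toy batch
```

Returning the two counts separately, rather than a ratio, lets a caller sum them over many batches before dividing.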