From 288584a914f1caa53150b3ad80ff9d49d80f2266 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sat, 19 Dec 2020 18:10:56 +0700 Subject: [PATCH] :zap: Update batch for faster testing --- examples/conformer/masking/masking.py | 6 +- examples/conformer/masking/trainer.py | 5 +- setup.py | 2 +- tensorflow_asr/datasets/README.md | 2 +- tensorflow_asr/datasets/asr_dataset.py | 32 ++- tensorflow_asr/models/__init__.py | 8 + tensorflow_asr/models/contextnet.py | 11 +- tensorflow_asr/models/ctc.py | 33 ++-- tensorflow_asr/models/streaming_transducer.py | 73 ++++--- tensorflow_asr/models/transducer.py | 182 +++++++++++++----- tensorflow_asr/runners/base_runners.py | 24 +-- tensorflow_asr/utils/utils.py | 1 + tests/plot_learning_rate.py | 4 +- tests/speech_featurizer_test.py | 2 +- tests/test_pos_enc.py | 4 +- 15 files changed, 250 insertions(+), 139 deletions(-) diff --git a/examples/conformer/masking/masking.py b/examples/conformer/masking/masking.py index c93aaabd10..69f8e0b01a 100644 --- a/examples/conformer/masking/masking.py +++ b/examples/conformer/masking/masking.py @@ -1,5 +1,5 @@ import tensorflow as tf -from tensorflow_asr.utils.utils import shape_list +from tensorflow_asr.utils.utils import shape_list, get_reduced_length def create_padding_mask(features, input_length, time_reduction_factor): @@ -14,10 +14,10 @@ def create_padding_mask(features, input_length, time_reduction_factor): [tf.Tensor]: with shape [B, Tquery, Tkey] """ batch_size, padded_time, _, _ = shape_list(features) - reduced_padded_time = tf.math.ceil(padded_time / time_reduction_factor) + reduced_padded_time = get_reduced_length(padded_time, time_reduction_factor) def create_mask(length): - reduced_length = tf.math.ceil(length / time_reduction_factor) + reduced_length = get_reduced_length(length, time_reduction_factor) mask = tf.ones([reduced_length, reduced_length], dtype=tf.float32) return tf.pad( mask, diff --git a/examples/conformer/masking/trainer.py b/examples/conformer/masking/trainer.py index 346b3d43cb..b860eafb2c 100644 --- a/examples/conformer/masking/trainer.py +++ b/examples/conformer/masking/trainer.py @@ -3,6 +3,7 @@ from masking import create_padding_mask from tensorflow_asr.runners.transducer_runners import TransducerTrainer, TransducerTrainerGA from tensorflow_asr.losses.rnnt_losses import rnnt_loss +from tensorflow_asr.utils.utils import get_reduced_length class TrainerWithMasking(TransducerTrainer): @@ -17,7 +18,7 @@ def _train_step(self, batch): tape.watch(logits) per_train_loss = rnnt_loss( logits=logits, labels=labels, label_length=label_length, - logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32), + logit_length=get_reduced_length(input_length, self.model.time_reduction_factor), blank=self.text_featurizer.blank ) train_loss = tf.nn.compute_average_loss(per_train_loss, @@ -41,7 +42,7 @@ def _train_step(self, batch): tape.watch(logits) per_train_loss = rnnt_loss( logits=logits, labels=labels, label_length=label_length, - logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32), + logit_length=get_reduced_length(input_length, self.model.time_reduction_factor), blank=self.text_featurizer.blank ) train_loss = tf.nn.compute_average_loss( diff --git a/setup.py b/setup.py index 4d3f7256c3..8f723380e3 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.5.0", + version="0.5.1", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/datasets/README.md b/tensorflow_asr/datasets/README.md index 6a6a0fe396..17a06da227 100644 --- a/tensorflow_asr/datasets/README.md +++ b/tensorflow_asr/datasets/README.md @@ -53,5 +53,5 @@ Where `prediction` and `prediction_length` are the label prepanded by blank and **Outputs when iterating in test step** ```python -(path, signals, labels) +(path, features, input_lengths, labels) ``` \ No newline at end of file diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 5600b0bb34..702791768b 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -241,8 +241,15 @@ class ASRTFRecordTestDataset(ASRTFRecordDataset): def preprocess(self, path, transcript): with tf.device("/CPU:0"): signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate) + + features = self.speech_featurizer.extract(signal) + features = tf.convert_to_tensor(features, tf.float32) + input_length = tf.cast(tf.shape(features)[0], tf.int32) + label = self.text_featurizer.extract(transcript.decode("utf-8")) - return path, signal, tf.convert_to_tensor(label, dtype=tf.int32) + label = tf.convert_to_tensor(label, dtype=tf.int32) + + return path, features, input_length, label @tf.function def parse(self, record): @@ -256,7 +263,7 @@ def parse(self, record): return tf.numpy_function( self.preprocess, inp=[example["audio"], example["transcript"]], - Tout=(tf.string, tf.float32, tf.int32) + Tout=(tf.string, tf.float32, tf.int32, tf.int32) ) def process(self, dataset, batch_size): @@ -273,10 +280,11 @@ def process(self, dataset, batch_size): batch_size=batch_size, padded_shapes=( tf.TensorShape([]), - tf.TensorShape([None]), + tf.TensorShape(self.speech_featurizer.shape), + tf.TensorShape([]), tf.TensorShape([None]), ), - padding_values=("", 0.0, self.text_featurizer.blank), + padding_values=("", 0.0, 0, self.text_featurizer.blank), drop_remainder=True ) @@ -304,15 +312,22 @@ class ASRSliceTestDataset(ASRDataset): def preprocess(self, path, transcript): with tf.device("/CPU:0"): signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate) + + features = self.speech_featurizer.extract(signal) + features = tf.convert_to_tensor(features, tf.float32) + input_length = tf.cast(tf.shape(features)[0], tf.int32) + label = self.text_featurizer.extract(transcript.decode("utf-8")) - return path, signal, tf.convert_to_tensor(label, dtype=tf.int32) + label = tf.convert_to_tensor(label, dtype=tf.int32) + + return path, features, input_length, label @tf.function def parse(self, record): return tf.numpy_function( self.preprocess, inp=[record[0], record[1]], - Tout=[tf.string, tf.float32, tf.int32] + Tout=[tf.string, tf.float32, tf.int32, tf.int32] ) def process(self, dataset, batch_size): @@ -329,10 +344,11 @@ def process(self, dataset, batch_size): batch_size=batch_size, padded_shapes=( tf.TensorShape([]), - tf.TensorShape([None]), + tf.TensorShape(self.speech_featurizer.shape), + tf.TensorShape([]), tf.TensorShape([None]), ), - padding_values=("", 0.0, self.text_featurizer.blank), + padding_values=("", 0.0, 0, self.text_featurizer.blank), drop_remainder=True ) diff --git a/tensorflow_asr/models/__init__.py b/tensorflow_asr/models/__init__.py index b2e67d3a44..84955496d0 100644 --- a/tensorflow_asr/models/__init__.py +++ b/tensorflow_asr/models/__init__.py @@ -27,3 +27,11 @@ def _build(self, *args, **kwargs): @abc.abstractmethod def call(self, inputs, training=False, **kwargs): raise NotImplementedError() + + @abc.abstractmethod + def recognize(self, features, input_lengths, **kwargs): + pass + + @abc.abstractmethod + def recognize_beam(self, features, input_lengths, **kwargs): + pass diff --git a/tensorflow_asr/models/contextnet.py b/tensorflow_asr/models/contextnet.py index 90d54672d4..0d80b02dc4 100644 --- a/tensorflow_asr/models/contextnet.py +++ b/tensorflow_asr/models/contextnet.py @@ -13,7 +13,7 @@ # limitations under the License. """ Ref: https://github.com/iankur/ContextNet """ -from typing import List +from typing import List, Optional import tensorflow as tf from .transducer import Transducer from ..utils.utils import merge_two_last_dims, get_reduced_length @@ -234,8 +234,7 @@ def __init__(self, ) self.dmodel = self.encoder.blocks[-1].dmodel self.time_reduction_factor = 1 - for block in self.encoder.blocks: - self.time_reduction_factor *= block.time_reduction_factor + for block in self.encoder.blocks: self.time_reduction_factor *= block.time_reduction_factor def call(self, inputs, training=False, **kwargs): features, input_length, prediction, prediction_length = inputs @@ -244,8 +243,12 @@ def call(self, inputs, training=False, **kwargs): outputs = self.joint_net([enc, pred], training=training, **kwargs) return outputs - def encoder_inference(self, features): + def encoder_inference(self, + features: tf.Tensor, + input_length: Optional[tf.Tensor] = None, + with_batch: bool = False): with tf.name_scope(f"{self.name}_encoder"): + if with_batch: return self.encoder([features, input_length], training=False) input_length = tf.expand_dims(tf.shape(features)[0], axis=0) outputs = tf.expand_dims(features, axis=0) outputs = self.encoder([outputs, input_length], training=False) diff --git a/tensorflow_asr/models/ctc.py b/tensorflow_asr/models/ctc.py index 9abbbbe161..a715685aad 100644 --- a/tensorflow_asr/models/ctc.py +++ b/tensorflow_asr/models/ctc.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional import numpy as np import tensorflow as tf from . import Model from ..featurizers.speech_featurizers import TFSpeechFeaturizer from ..featurizers.text_featurizers import TextFeaturizer -from ..utils.utils import shape_list +from ..utils.utils import shape_list, get_reduced_length class CtcModel(Model): @@ -41,20 +42,15 @@ def call(self, inputs, training=False, **kwargs): # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, signals): - - def extract_fn(signal): return self.speech_featurizer.tf_extract(signal) - - features = tf.map_fn(extract_fn, signals, - fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32)) + def recognize(self, features: tf.Tensor, input_length: Optional[tf.Tensor]): logits = self(features, training=False) probs = tf.nn.softmax(logits) - def map_fn(prob): return tf.numpy_function(self.perform_greedy, inp=[prob], Tout=tf.string) + def map_fn(prob): return tf.numpy_function(self.__perform_greedy, inp=[prob], Tout=tf.string) return tf.map_fn(map_fn, probs, fn_output_signature=tf.TensorSpec([], dtype=tf.string)) - def perform_greedy(self, probs: np.ndarray): + def __perform_greedy(self, probs: np.ndarray): from ctc_decoders import ctc_greedy_decoder decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.vocab_array) return tf.convert_to_tensor(decoded, dtype=tf.string) @@ -71,7 +67,7 @@ def recognize_tflite(self, signal): features = self.speech_featurizer.tf_extract(signal) features = tf.expand_dims(features, axis=0) input_length = shape_list(features)[1] - input_length = input_length // self.base_model.time_reduction_factor + input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor) input_length = tf.expand_dims(input_length, axis=0) logits = self(features, training=False) probs = tf.nn.softmax(logits) @@ -85,25 +81,20 @@ def recognize_tflite(self, signal): # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, signals, lm=False): - - def extract_fn(signal): return self.speech_featurizer.tf_extract(signal) - - features = tf.map_fn(extract_fn, signals, - fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32)) + def recognize_beam(self, features: tf.Tensor, input_length: Optional[tf.Tensor], lm: bool = False): logits = self(features, training=False) probs = tf.nn.softmax(logits) - def map_fn(prob): return tf.numpy_function(self.perform_beam_search, inp=[prob, lm], Tout=tf.string) + def map_fn(prob): return tf.numpy_function(self.__perform_beam_search, inp=[prob, lm], Tout=tf.string) return tf.map_fn(map_fn, probs, dtype=tf.string) - def perform_beam_search(self, probs: np.ndarray, lm: bool = False): + def __perform_beam_search(self, probs: np.ndarray, lm: bool = False): from ctc_decoders import ctc_beam_search_decoder decoded = ctc_beam_search_decoder( probs_seq=probs, vocabulary=self.text_featurizer.vocab_array, - beam_size=self.text_featurizer.decoder_config["beam_width"], + beam_size=self.text_featurizer.decoder_config.beam_width, ext_scoring_func=self.text_featurizer.scorer if lm else None ) decoded = decoded[0][-1] @@ -122,13 +113,13 @@ def recognize_beam_tflite(self, signal): features = self.speech_featurizer.tf_extract(signal) features = tf.expand_dims(features, axis=0) input_length = shape_list(features)[1] - input_length = input_length // self.base_model.time_reduction_factor + input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor) input_length = tf.expand_dims(input_length, axis=0) logits = self(features, training=False) probs = tf.nn.softmax(logits) decoded = tf.keras.backend.ctc_decode( y_pred=probs, input_length=input_length, greedy=False, - beam_width=self.text_featurizer.decoder_config["beam_width"] + beam_width=self.text_featurizer.decoder_config.beam_width ) decoded = tf.cast(decoded[0][0][0], dtype=tf.int32) transcript = self.text_featurizer.indices2upoints(decoded) diff --git a/tensorflow_asr/models/streaming_transducer.py b/tensorflow_asr/models/streaming_transducer.py index 0c1c302c7c..d80c8f356f 100644 --- a/tensorflow_asr/models/streaming_transducer.py +++ b/tensorflow_asr/models/streaming_transducer.py @@ -13,6 +13,7 @@ # limitations under the License. """ http://arxiv.org/abs/1811.06621 """ +from typing import Optional import tensorflow as tf from .layers.subsampling import TimeReduction @@ -222,23 +223,24 @@ def __init__(self, ) self.time_reduction_factor = self.encoder.time_reduction_factor - def summary(self, line_length=None, **kwargs): - for block in self.encoder.blocks: - block.summary(line_length=line_length, **kwargs) - super(StreamingTransducer, self).summary(line_length=line_length, **kwargs) - - def encoder_inference(self, features, states): + def encoder_inference(self, + features: tf.Tensor, + states: tf.Tensor, + input_length: Optional[tf.Tensor] = None, + with_batch: bool = False): """Infer function for encoder (or encoders) Args: features (tf.Tensor): features with shape [T, F, C] states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P] + with_batch (bool): indicates whether the features included batch dim or not Returns: tf.Tensor: output of encoders with shape [T, E] tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P] """ with tf.name_scope(f"{self.name}_encoder"): + if with_batch: return self.encoder.recognize(features, states) outputs = tf.expand_dims(features, axis=0) outputs, new_states = self.encoder.recognize(outputs, states) return tf.squeeze(outputs, axis=0), new_states @@ -246,28 +248,26 @@ def encoder_inference(self, features, states): # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, signals): + def recognize(self, + features: tf.Tensor, + input_length: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = True): """ RNN Transducer Greedy decoding Args: - signals (tf.Tensor): a batch of padded signals + features (tf.Tensor): a batch of padded extracted features Returns: tf.Tensor: a batch of decoded transcripts """ - def execute(signal: tf.Tensor): - features = self.speech_featurizer.tf_extract(signal) - encoded, _ = self.encoder_inference(features, self.encoder.get_initial_state()) - hypothesis = self.perform_greedy( - encoded, - predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32), - states=self.predict_net.get_initial_state(), - swap_memory=True - ) - transcripts = self.text_featurizer.iextract(tf.expand_dims(hypothesis.prediction, axis=0)) - return tf.squeeze(transcripts) # reshape from [1] to [] - - return tf.map_fn(execute, signals, fn_output_signature=tf.TensorSpec([], dtype=tf.string)) + encoded, _ = self.encoder_inference( + features, + self.encoder.get_initial_state(), + input_length=input_length, with_batch=True + ) + return self.__perform_greedy_batch(encoded, input_length, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) def recognize_tflite(self, signal, predicted, encoder_states, prediction_states): """ @@ -286,7 +286,7 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states) """ features = self.speech_featurizer.tf_extract(signal) encoded, new_encoder_states = self.encoder_inference(features, encoder_states) - hypothesis = self.perform_greedy(encoded, predicted, prediction_states, swap_memory=False) + hypothesis = self.__perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) return ( transcript, @@ -298,29 +298,28 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states) # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, signals, lm=False): + def recognize_beam(self, + features: tf.Tensor, + input_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = True): """ RNN Transducer Beam Search Args: - signals (tf.Tensor): a batch of padded signals + features (tf.Tensor): a batch of padded extracted features lm (bool, optional): whether to use language model. Defaults to False. Returns: tf.Tensor: a batch of decoded transcripts """ - def execute(signal: tf.Tensor): - features = self.speech_featurizer.tf_extract(signal) - encoded, _ = self.encoder_inference(features, self.encoder.get_initial_state()) - hypothesis = self.perform_beam_search(encoded, lm) - prediction = tf.map_fn( - lambda x: tf.strings.to_number(x, tf.int32), - tf.strings.split(hypothesis.prediction), - fn_output_signature=tf.TensorSpec([], dtype=tf.int32) - ) - transcripts = self.text_featurizer.iextract(tf.expand_dims(prediction, axis=0)) - return tf.squeeze(transcripts) # reshape from [1] to [] - - return tf.map_fn(execute, signals, fn_output_signature=tf.TensorSpec([], dtype=tf.string)) + encoded, _ = self.encoder_inference( + features, + self.encoder.get_initial_state(), + input_length=input_length, with_batch=True + ) + return self.__perform_beam_search_batch(encoded, input_length, lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) # -------------------------------- TFLITE ------------------------------------- diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index eebbcad170..50490d92f4 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -14,11 +14,12 @@ """ https://arxiv.org/pdf/1811.06621.pdf """ import collections +from typing import Optional import tensorflow as tf from . import Model from ..utils.utils import get_rnn, shape_list -from ..featurizers.speech_featurizers import TFSpeechFeaturizer +from ..featurizers.speech_featurizers import SpeechFeaturizer from ..featurizers.text_featurizers import TextFeaturizer from .layers.embedding import Embedding @@ -126,8 +127,7 @@ def recognize(self, inputs, states): outputs = self.do(outputs, training=False) new_states = [] for i, rnn in enumerate(self.rnns): - outputs = rnn["rnn"](outputs, training=False, - initial_state=tf.unstack(states[i], axis=0)) + outputs = rnn["rnn"](outputs, training=False, initial_state=tf.unstack(states[i], axis=0)) new_states.append(tf.stack(outputs[1:])) outputs = outputs[0] if rnn["ln"] is not None: @@ -248,7 +248,7 @@ def summary(self, line_length=None, **kwargs): super(Transducer, self).summary(line_length=line_length, **kwargs) def add_featurizers(self, - speech_featurizer: TFSpeechFeaturizer, + speech_featurizer: SpeechFeaturizer, text_featurizer: TextFeaturizer): """ Function to add featurizer to model to convert to end2end tflite @@ -280,21 +280,27 @@ def call(self, inputs, training=False, **kwargs): outputs = self.joint_net([enc, pred], training=training, **kwargs) return outputs - def encoder_inference(self, features): + def encoder_inference(self, + features: tf.Tensor, + input_length: Optional[tf.Tensor] = None, + with_batch: Optional[bool] = False): """Infer function for encoder (or encoders) Args: features (tf.Tensor): features with shape [T, F, C] + input_length (tf.Tensor): optional features length with shape [] + with_batch (bool): indicates whether the features included batch dim or not Returns: tf.Tensor: output of encoders with shape [T, E] """ with tf.name_scope(f"{self.name}_encoder"): + if with_batch: return self.encoder(features, training=False) outputs = tf.expand_dims(features, axis=0) outputs = self.encoder(outputs, training=False) return tf.squeeze(outputs, axis=0) - def decoder_inference(self, encoded, predicted, states): + def decoder_inference(self, encoded: tf.Tensor, predicted: tf.Tensor, states: tf.Tensor): """Infer function for decoder Args: @@ -322,28 +328,23 @@ def get_config(self): # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, signals): + def recognize(self, + features: tf.Tensor, + input_length: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = True): """ RNN Transducer Greedy decoding Args: - signals (tf.Tensor): a batch of padded signals + features (tf.Tensor): a batch of extracted features + input_length (tf.Tensor): a batch of extracted features length Returns: tf.Tensor: a batch of decoded transcripts """ - def execute(signal: tf.Tensor): - features = self.speech_featurizer.tf_extract(signal) - encoded = self.encoder_inference(features) - hypothesis = self.perform_greedy( - encoded, - predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32), - states=self.predict_net.get_initial_state(), - swap_memory=True - ) - transcripts = self.text_featurizer.iextract(tf.expand_dims(hypothesis.prediction, axis=0)) - return tf.squeeze(transcripts) # reshape from [1] to [] - - return tf.map_fn(execute, signals, fn_output_signature=tf.TensorSpec([], dtype=tf.string)) + encoded = self.encoder_inference(features, input_length, with_batch=True) + return self.__perform_greedy_batch(encoded, input_length, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) def recognize_tflite(self, signal, predicted, states): """ @@ -360,7 +361,7 @@ def recognize_tflite(self, signal, predicted, states): """ features = self.speech_featurizer.tf_extract(signal) encoded = self.encoder_inference(features) - hypothesis = self.perform_greedy(encoded, predicted, states, swap_memory=False) + hypothesis = self.__perform_greedy(encoded, tf.shape(encoded)[0], predicted, states) transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) return ( transcript, @@ -368,10 +369,54 @@ def recognize_tflite(self, signal, predicted, states): hypothesis.states ) - def perform_greedy(self, encoded, predicted, states, swap_memory=False): + def __perform_greedy_batch(self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = False): + total = tf.shape(encoded)[0] + batch = tf.constant(0, dtype=tf.int32) + + decoded = tf.TensorArray( + dtype=tf.string, + size=total, dynamic_size=False, + clear_after_read=False, element_shape=tf.TensorShape([]) + ) + + def condition(batch, total, encoded, encoded_length, decoded): return tf.less(batch, total) + + def body(batch, total, encoded, encoded_length, decoded): + hypothesis = self.__perform_greedy( + encoded=encoded[batch], + encoded_length=encoded_length[batch], + predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32), + states=self.predict_net.get_initial_state(), + parallel_iterations=parallel_iterations, + swap_memory=swap_memory + ) + transcripts = self.text_featurizer.iextract(tf.expand_dims(hypothesis.prediction, axis=0)) + decoded = decoded.write(batch, tf.squeeze(transcripts)) + return batch + 1, total, encoded, encoded_length, decoded + + batch, total, _, _, decoded = tf.while_loop( + condition, body, + loop_vars=(batch, total, encoded, encoded_length, decoded), + parallel_iterations=parallel_iterations, + swap_memory=True, + ) + + return decoded.stack() + + def __perform_greedy(self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + predicted: tf.Tensor, + states: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = False): with tf.name_scope(f"{self.name}_greedy"): time = tf.constant(0, dtype=tf.int32) - total = tf.shape(encoded)[0] + total = encoded_length # Initialize prediction with a blank # Prediction can not be longer than the encoded of audio plus blank prediction = tf.TensorArray( @@ -425,6 +470,7 @@ def body(time, total, encoded, hypothesis): condition, body, loop_vars=(time, total, encoded, hypothesis), + parallel_iterations=parallel_iterations, swap_memory=swap_memory ) @@ -443,38 +489,71 @@ def body(time, total, encoded, hypothesis): # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, signals, lm=False): + def recognize_beam(self, + features: tf.Tensor, + input_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = True): """ RNN Transducer Beam Search Args: - signals (tf.Tensor): a batch of padded signals + features (tf.Tensor): a batch of padded extracted features lm (bool, optional): whether to use language model. Defaults to False. Returns: tf.Tensor: a batch of decoded transcripts """ - def execute(signal: tf.Tensor): - features = self.speech_featurizer.tf_extract(signal) - encoded = self.encoder_inference(features) - hypothesis = self.perform_beam_search(encoded, lm) - prediction = tf.map_fn( - lambda x: tf.strings.to_number(x, tf.int32), - tf.strings.split(hypothesis.prediction), - fn_output_signature=tf.TensorSpec([], dtype=tf.int32) - ) - transcripts = self.text_featurizer.iextract(tf.expand_dims(prediction, axis=0)) - return tf.squeeze(transcripts) # reshape from [1] to [] + encoded = self.encoder_inference(features, input_length, with_batch=True) + return self.__perform_beam_search_batch(encoded, input_length, lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) + + def __perform_beam_search_batch(self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = False): + total = tf.shape(encoded)[0] + batch = tf.constant(0, dtype=tf.int32) + + decoded = tf.TensorArray( + dtype=tf.string, + size=total, dynamic_size=False, + clear_after_read=False, element_shape=tf.TensorShape([]) + ) + + def condition(batch, total, encoded, encoded_length, decoded): return tf.less(batch, total) - return tf.map_fn(execute, signals, fn_output_signature=tf.TensorSpec([], dtype=tf.string)) + def body(batch, total, encoded, encoded_length, decoded): + hypothesis = self.__perform_beam_search(encoded[batch], encoded_length[batch], lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) + transcripts = self.text_featurizer.iextract(tf.expand_dims(hypothesis.prediction, axis=0)) + decoded = decoded.write(batch, tf.squeeze(transcripts)) + return batch + 1, total, encoded, encoded_length, decoded + + batch, total, _, _, decoded = tf.while_loop( + condition, body, + loop_vars=(batch, total, encoded, encoded_length, decoded), + parallel_iterations=parallel_iterations, + swap_memory=True, + ) + + return decoded.stack() - def perform_beam_search(self, encoded, lm=False): + def __perform_beam_search(self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = False): with tf.device("/CPU:0"), tf.name_scope(f"{self.name}_beam_search"): beam_width = tf.cond( tf.less(self.text_featurizer.decoder_config.beam_width, self.text_featurizer.num_classes), true_fn=lambda: self.text_featurizer.decoder_config.beam_width, false_fn=lambda: self.text_featurizer.num_classes - 1 ) - total = tf.shape(encoded)[0] + total = encoded_length def initialize_beam(dynamic=False): return BeamHypothesis( @@ -567,15 +646,27 @@ def predict_body(pred, A, A_i, B): A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states) return pred + 1, A, A_i, B - _, A, A_i, B = tf.while_loop(predict_condition, predict_body, loop_vars=(0, A, A_i, B)) + _, A, A_i, B = tf.while_loop( + predict_condition, predict_body, + loop_vars=(0, A, A_i, B), + parallel_iterations=parallel_iterations, swap_memory=swap_memory + ) return beam + 1, beam_width, A, A_i, B - _, _, A, A_i, B = tf.while_loop(beam_condition, beam_body, loop_vars=(0, beam_width, A, A_i, B)) + _, _, A, A_i, B = tf.while_loop( + beam_condition, beam_body, + loop_vars=(0, beam_width, A, A_i, B), + parallel_iterations=parallel_iterations, swap_memory=swap_memory + ) return time + 1, total, B - _, _, B = tf.while_loop(condition, body, loop_vars=(0, total, B)) + _, _, B = tf.while_loop( + condition, body, + loop_vars=(0, total, B), + parallel_iterations=parallel_iterations, swap_memory=swap_memory + ) scores = B.score.stack() if self.text_featurizer.decoder_config.norm_score: @@ -590,7 +681,7 @@ def predict_body(pred, A, A_i, B): return Hypothesis( index=y_hat_index, - prediction=y_hat_prediction, + prediction=tf.strings.to_number(tf.strings.split(y_hat_prediction), out_type=tf.int32), states=y_hat_states ) @@ -602,7 +693,6 @@ def make_tflite_function(self, greedy: bool = True): input_signature=[ tf.TensorSpec([None], dtype=tf.float32), tf.TensorSpec([], dtype=tf.int32), - tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), - dtype=tf.float32) + tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32) ] ) diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index a71015cdfb..6e26b95889 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -22,7 +22,7 @@ import tensorflow as tf from ..configs.config import RunningConfig -from ..utils.utils import get_num_batches, bytes_to_string +from ..utils.utils import get_num_batches, bytes_to_string, get_reduced_length from ..utils.metrics import ErrorRate, wer, cer @@ -386,8 +386,9 @@ def set_output_file(self): with open(self.output_file_path, "w") as out: out.write("PATH\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\tBEAMSEARCHLM\n") - def set_test_data_loader(self, test_dataset, batch_size=1): + def set_test_data_loader(self, test_dataset, batch_size=None): """Set train data loader (MUST).""" + if not batch_size: batch_size = self.config.batch_size self.test_data_loader = test_dataset.create(batch_size) self.total_steps = test_dataset.total_steps @@ -399,7 +400,7 @@ def compile(self, trained_model: tf.keras.Model): raise AttributeError("Please do 'add_featurizers' before testing") self.model = trained_model - def run(self, test_dataset, batch_size=1): + def run(self, test_dataset, batch_size=None): self.set_output_file() self.set_test_data_loader(test_dataset, batch_size=batch_size) self._test_epoch() @@ -419,7 +420,7 @@ def _test_epoch(self): except tf.errors.OutOfRangeError: break - decoded = [d.numpy() for d in decoded] + decoded = [None if d is None else d.numpy() for d in decoded] self._append_to_file(*decoded) progbar.update(1) @@ -440,15 +441,16 @@ def _test_step(self, batch): Returns: (file_paths, groundtruth, greedy, beamsearch, beamsearch_lm) each has shape [B] """ - file_paths, signals, labels = batch + file_paths, features, input_length, labels = batch labels = self.model.text_featurizer.iextract(labels) - greed_pred = self.model.recognize(signals) - beam_pred = beam_lm_pred = tf.constant([""], dtype=tf.string) + input_length = get_reduced_length(input_length, self.model.time_reduction_factor) + greed_pred = self.model.recognize(features, input_length) + beam_pred = beam_lm_pred = None if self.model.text_featurizer.decoder_config.beam_width > 0: - beam_pred = self.model.recognize_beam(signals, lm=False) + beam_pred = self.model.recognize_beam(features, input_length, lm=False) if self.model.text_featurizer.decoder_config.lm_config: - beam_lm_pred = self.model.recognize_beam(signals, lm=True) + beam_lm_pred = self.model.recognize_beam(features, input_length, lm=True) return file_paths, labels, greed_pred, beam_pred, beam_lm_pred @@ -492,8 +494,8 @@ def _append_to_file(self, file_path = bytes_to_string(file_path) groundtruth = bytes_to_string(groundtruth) greedy = bytes_to_string(greedy) - beamsearch = bytes_to_string(beamsearch) - beamsearch_lm = bytes_to_string(beamsearch_lm) + beamsearch = bytes_to_string(beamsearch) if beamsearch is not None else ["" for _ in file_path] + beamsearch_lm = bytes_to_string(beamsearch_lm) if beamsearch_lm is not None else ["" for _ in file_path] with open(self.output_file_path, "a", encoding="utf-8") as out: for i, path in enumerate(file_path): line = f"{groundtruth[i]}\t{greedy[i]}\t{beamsearch[i]}\t{beamsearch_lm[i]}" diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 976600c00c..b98049817d 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -58,6 +58,7 @@ def nan_to_zero(input_tensor): def bytes_to_string(array: np.ndarray, encoding: str = "utf-8"): + if array is None: return None return [transcript.decode(encoding) for transcript in array] diff --git a/tests/plot_learning_rate.py b/tests/plot_learning_rate.py index 96bcbad620..65e3f41770 100755 --- a/tests/plot_learning_rate.py +++ b/tests/plot_learning_rate.py @@ -21,11 +21,11 @@ plt.plot(lr(tf.range(40000, dtype=tf.float32))) plt.ylabel("Learning Rate") plt.xlabel("Train Step") -plt.show() +# plt.show() lr = TransformerSchedule(d_model=144, warmup_steps=10000) plt.plot(lr(tf.range(2000000, dtype=tf.float32))) plt.ylabel("Learning Rate") plt.xlabel("Train Step") -plt.show() +# plt.show() diff --git a/tests/speech_featurizer_test.py b/tests/speech_featurizer_test.py index b4ae71ba39..3424560750 100755 --- a/tests/speech_featurizer_test.py +++ b/tests/speech_featurizer_test.py @@ -59,7 +59,7 @@ def main(argv): plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) plt.tight_layout() # plt.savefig(argv[3]) - plt.show() + # plt.show() # plt.figure(figsize=(15, 5)) # for i in range(4): # plt.subplot(2, 2, i + 1) diff --git a/tests/test_pos_enc.py b/tests/test_pos_enc.py index 66eb9bed4c..26543a03c0 100755 --- a/tests/test_pos_enc.py +++ b/tests/test_pos_enc.py @@ -25,7 +25,7 @@ plt.xlim((0, 144)) plt.ylabel('Position') plt.colorbar() -plt.show() +# plt.show() rel = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])[None, None, ...] rel_shift = RelPositionMultiHeadAttention.relative_shift(rel) @@ -40,4 +40,4 @@ plt.subplot(2, 1, 2) plt.imshow(rel_shift[0][0]) plt.colorbar() -plt.show() +# plt.show()