From b12dca470935382c8f943846925df9474349ddaf Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sat, 2 Jan 2021 18:39:51 +0700 Subject: [PATCH 1/6] :rocket: update transducer beam search --- tensorflow_asr/models/transducer.py | 107 +++++++++++++++++----------- tensorflow_asr/utils/utils.py | 4 ++ 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index bfc901157d..bace2e371f 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -18,7 +18,7 @@ import tensorflow as tf from . import Model -from ..utils.utils import get_rnn, shape_list +from ..utils.utils import get_rnn, shape_list, count_non_blank from ..featurizers.speech_featurizers import SpeechFeaturizer from ..featurizers.text_featurizers import TextFeaturizer from .layers.embedding import Embedding @@ -551,7 +551,7 @@ def _perform_beam_search(self, lm: bool = False, parallel_iterations: int = 10, swap_memory: bool = False): - with tf.device("/CPU:0"), tf.name_scope(f"{self.name}_beam_search"): + with tf.name_scope(f"{self.name}_beam_search"): beam_width = tf.cond( tf.less(self.text_featurizer.decoder_config.beam_width, self.text_featurizer.num_classes), true_fn=lambda: self.text_featurizer.decoder_config.beam_width, @@ -562,22 +562,20 @@ def _perform_beam_search(self, def initialize_beam(dynamic=False): return BeamHypothesis( score=tf.TensorArray( - dtype=tf.float32, size=beam_width if not dynamic else 0, - dynamic_size=dynamic, element_shape=tf.TensorShape([]), clear_after_read=False + dtype=tf.float32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, + element_shape=tf.TensorShape([]), clear_after_read=False ), indices=tf.TensorArray( - dtype=tf.int32, size=beam_width if not dynamic else 0, - dynamic_size=dynamic, element_shape=tf.TensorShape([]), clear_after_read=False + dtype=tf.int32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, + element_shape=tf.TensorShape([]), clear_after_read=False ), prediction=tf.TensorArray( - dtype=tf.string, size=beam_width if not dynamic else 0, dynamic_size=dynamic, - element_shape=tf.TensorShape([]), clear_after_read=False + dtype=tf.int32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, + element_shape=None, clear_after_read=False ), states=tf.TensorArray( - dtype=tf.float32, size=beam_width if not dynamic else 0, - dynamic_size=dynamic, - element_shape=tf.TensorShape(shape_list(self.predict_net.get_initial_state())), - clear_after_read=False + dtype=tf.float32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, + element_shape=tf.TensorShape(shape_list(self.predict_net.get_initial_state())), clear_after_read=False ), ) @@ -585,7 +583,7 @@ def initialize_beam(dynamic=False): B = BeamHypothesis( score=B.score.write(0, 0.0), indices=B.indices.write(0, self.text_featurizer.blank), - prediction=B.prediction.write(0, ''), + prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank), states=B.states.write(0, self.predict_net.get_initial_state()) ) @@ -607,47 +605,74 @@ def body(time, total, B): def beam_condition(beam, beam_width, A, A_i, B): return tf.less(beam, beam_width) def beam_body(beam, beam_width, A, A_i, B): - y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1) + # get y_hat + y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1, sorted=True) y_hat_score = y_hat_score[0] y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index) y_hat_prediction = tf.gather_nd(A.prediction.stack(), y_hat_score_index) y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index) + # remove y_hat from A + remain_indices = tf.range(0, tf.shape(A.score.stack())[0], dtype=tf.int32) + remain_indices = tf.gather_nd(remain_indices, tf.where(tf.not_equal(remain_indices, y_hat_score_index[0]))) + remain_indices = tf.expand_dims(remain_indices, axis=-1) + A = BeamHypothesis( + score=A.score.unstack(tf.gather_nd(A.score.stack(), remain_indices)), + indices=A.indices.unstack(tf.gather_nd(A.indices.stack(), remain_indices)), + prediction=A.prediction.unstack(tf.gather_nd(A.prediction.stack(), remain_indices)), + states=A.states.unstack(tf.gather_nd(A.states.stack(), remain_indices)), + ) + A_i = tf.cond(tf.equal(A_i, 0), true_fn=lambda: A_i, false_fn=lambda: A_i - 1) + ytu, new_states = self.decoder_inference(encoded=encoded_t, predicted=y_hat_index, states=y_hat_states) def predict_condition(pred, A, A_i, B): return tf.less(pred, self.text_featurizer.num_classes) def predict_body(pred, A, A_i, B): new_score = y_hat_score + tf.gather_nd(ytu, tf.expand_dims(pred, axis=-1)) + + def true_fn(): + return ( + B.score.write(beam, new_score), + B.indices.write(beam, y_hat_index), + B.prediction.write(beam, y_hat_prediction), + B.states.write(beam, y_hat_states), + A.score, + A.indices, + A.prediction, + A.states, + A_i, + ) + + def false_fn(): + scatter_index = count_non_blank(y_hat_prediction, blank=self.text_featurizer.blank) + updated_prediction = tf.tensor_scatter_nd_update( + y_hat_prediction, + indices=tf.reshape(scatter_index, [1, 1]), + updates=tf.expand_dims(pred, axis=-1) + ) + return ( + B.score, + B.indices, + B.prediction, + B.states, + A.score.write(A_i, new_score), + A.indices.write(A_i, pred), + A.prediction.write(A_i, updated_prediction), + A.states.write(A_i, new_states), + A_i + 1 + ) + b_score, b_indices, b_prediction, b_states, \ a_score, a_indices, a_prediction, a_states, A_i = tf.cond( tf.equal(pred, self.text_featurizer.blank), - true_fn=lambda: ( - B.score.write(beam, new_score), - B.indices.write(beam, y_hat_index), - B.prediction.write(beam, y_hat_prediction), - B.states.write(beam, y_hat_states), - A.score, - A.indices, - A.prediction, - A.states, - A_i, - ), - false_fn=lambda: ( - B.score, - B.indices, - B.prediction, - B.states, - A.score.write(A_i, new_score), - A.indices.write(A_i, pred), - A.prediction.write(A_i, tf.strings.reduce_join( - [y_hat_prediction, tf.strings.format("{}", pred)], separator=" ")), - A.states.write(A_i, new_states), - A_i + 1 - ) + true_fn=true_fn, + false_fn=false_fn ) + B = BeamHypothesis(score=b_score, indices=b_indices, prediction=b_prediction, states=b_states) A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states) + return pred + 1, A, A_i, B _, A, A_i, B = tf.while_loop( @@ -674,7 +699,7 @@ def predict_body(pred, A, A_i, B): scores = B.score.stack() if self.text_featurizer.decoder_config.norm_score: - prediction_lengths = tf.strings.length(B.prediction.stack(), unit="UTF8_CHAR") + prediction_lengths = count_non_blank(B.prediction.stack(), blank=self.text_featurizer.blank, axis=1) scores /= tf.cast(prediction_lengths, dtype=scores.dtype) y_hat_score, y_hat_score_index = tf.math.top_k(scores, k=1) @@ -683,11 +708,7 @@ def predict_body(pred, A, A_i, B): y_hat_prediction = tf.gather_nd(B.prediction.stack(), y_hat_score_index) y_hat_states = tf.gather_nd(B.states.stack(), y_hat_score_index) - return Hypothesis( - index=y_hat_index, - prediction=tf.strings.to_number(tf.strings.split(y_hat_prediction), out_type=tf.int32), - states=y_hat_states - ) + return Hypothesis(index=y_hat_index, prediction=y_hat_prediction, states=y_hat_states) # -------------------------------- TFLITE ------------------------------------- diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index e3e35da211..101ee2d15a 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -153,3 +153,7 @@ def log10(x): def get_reduced_length(length, reduction_factor): return tf.cast(tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), dtype=tf.int32) + + +def count_non_blank(tensor: tf.Tensor, blank: int or tf.Tensor = 0, axis=None): + return tf.reduce_sum(tf.where(tf.not_equal(tensor, blank), x=tf.ones_like(tensor), y=tf.zeros_like(tensor)), axis=axis) From 79d22e1ae44acaaa3651f76dafe8e617e7c769d3 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 3 Jan 2021 13:59:26 +0700 Subject: [PATCH 2/6] :rocket: update batch decoding --- .../featurizers/text_featurizers.py | 5 +- tensorflow_asr/models/transducer.py | 60 +++++++++++-------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index 04e1c2c5a2..6f4385a622 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -270,12 +270,11 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor: clear_after_read=False, element_shape=tf.TensorShape([]) ) - def cond(batch, total, transcripts): return tf.less(batch, total) + def cond(batch, total, _): return tf.less(batch, total) def body(batch, total, transcripts): upoints = self.indices2upoints(indices[batch]) - _transcript = tf.strings.unicode_encode(upoints, "UTF-8") - transcripts = transcripts.write(batch, _transcript) + transcripts = transcripts.write(batch, tf.strings.unicode_encode(upoints, "UTF-8")) return batch + 1, total, transcripts _, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts]) diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index bace2e371f..15b12a6c88 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -402,18 +402,17 @@ def _perform_greedy_batch(self, encoded_length: tf.Tensor, parallel_iterations: int = 10, swap_memory: bool = False): - total = tf.shape(encoded)[0] + total_batch, total_time, _ = shape_list(encoded) batch = tf.constant(0, dtype=tf.int32) decoded = tf.TensorArray( - dtype=tf.string, - size=total, dynamic_size=False, - clear_after_read=False, element_shape=tf.TensorShape([]) + dtype=tf.int32, size=total_batch, dynamic_size=False, + clear_after_read=False, element_shape=None ) - def condition(batch, total, encoded, encoded_length, decoded): return tf.less(batch, total) + def condition(batch, _): return tf.less(batch, total_batch) - def body(batch, total, encoded, encoded_length, decoded): + def body(batch, decoded): hypothesis = self._perform_greedy( encoded=encoded[batch], encoded_length=encoded_length[batch], @@ -422,18 +421,22 @@ def body(batch, total, encoded, encoded_length, decoded): parallel_iterations=parallel_iterations, swap_memory=swap_memory ) - transcripts = self.text_featurizer.iextract(tf.expand_dims(hypothesis.prediction, axis=0)) - decoded = decoded.write(batch, tf.squeeze(transcripts)) - return batch + 1, total, encoded, encoded_length, decoded + prediction = tf.pad( + hypothesis.prediction, + paddings=[[0, total_time - encoded_length[batch]]], + mode="CONSTANT", constant_values=self.text_featurizer.blank + ) + decoded = decoded.write(batch, prediction) + return batch + 1, decoded - batch, total, _, _, decoded = tf.while_loop( + batch, decoded = tf.while_loop( condition, body, - loop_vars=[batch, total, encoded, encoded_length, decoded], + loop_vars=[batch, decoded], parallel_iterations=parallel_iterations, swap_memory=True, ) - return decoded.stack() + return self.text_featurizer.iextract(decoded.stack()) def _perform_greedy(self, encoded: tf.Tensor, @@ -518,32 +521,37 @@ def _perform_beam_search_batch(self, lm: bool = False, parallel_iterations: int = 10, swap_memory: bool = False): - total = tf.shape(encoded)[0] + total_batch, total_time, _ = shape_list(encoded) batch = tf.constant(0, dtype=tf.int32) decoded = tf.TensorArray( - dtype=tf.string, - size=total, dynamic_size=False, - clear_after_read=False, element_shape=tf.TensorShape([]) + dtype=tf.int32, size=total_batch, dynamic_size=False, + clear_after_read=False, element_shape=None ) - def condition(batch, total, encoded, encoded_length, decoded): return tf.less(batch, total) + def condition(batch, _): return tf.less(batch, total_batch) - def body(batch, total, encoded, encoded_length, decoded): - hypothesis = self._perform_beam_search(encoded[batch], encoded_length[batch], lm, - parallel_iterations=parallel_iterations, swap_memory=swap_memory) - transcripts = self.text_featurizer.iextract(tf.expand_dims(hypothesis.prediction, axis=0)) - decoded = decoded.write(batch, tf.squeeze(transcripts)) - return batch + 1, total, encoded, encoded_length, decoded + def body(batch, decoded): + hypothesis = self._perform_beam_search( + encoded[batch], encoded_length[batch], lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory + ) + prediction = tf.pad( + hypothesis.prediction, + paddings=[[0, total_time - encoded_length[batch]]], + mode="CONSTANT", constant_values=self.text_featurizer.blank + ) + decoded = decoded.write(batch, prediction) + return batch + 1, decoded - batch, total, _, _, decoded = tf.while_loop( + batch, decoded = tf.while_loop( condition, body, - loop_vars=[batch, total, encoded, encoded_length, decoded], + loop_vars=[batch, decoded], parallel_iterations=parallel_iterations, swap_memory=True, ) - return decoded.stack() + return self.text_featurizer.iextract(decoded.stack()) def _perform_beam_search(self, encoded: tf.Tensor, From df5462d7e800521a722b2c76988fdf24bf0801ef Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 3 Jan 2021 15:50:30 +0700 Subject: [PATCH 3/6] :writing_hand: fix tfrecord test dataset --- tensorflow_asr/datasets/asr_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 374d458942..695f530f5e 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -238,9 +238,9 @@ def create(self, batch_size): class ASRTFRecordTestDataset(ASRTFRecordDataset): - def preprocess(self, path, transcript): + def preprocess(self, path, audio, transcript): with tf.device("/CPU:0"): - signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate) + signal = read_raw_audio(audio, self.speech_featurizer.sample_rate) features = self.speech_featurizer.extract(signal) features = tf.convert_to_tensor(features, tf.float32) @@ -262,8 +262,8 @@ def parse(self, record): return tf.numpy_function( self.preprocess, - inp=[example["audio"], example["transcript"]], - Tout=(tf.string, tf.float32, tf.int32, tf.int32) + inp=[example["path"], example["audio"], example["transcript"]], + Tout=[tf.string, tf.float32, tf.int32, tf.int32] ) def process(self, dataset, batch_size): From 49841e518533c1159623c9531b29a096f41ce523 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 3 Jan 2021 16:16:03 +0700 Subject: [PATCH 4/6] :writing_hand: use tf.nn.swish for backward compatibility --- tensorflow_asr/models/conformer.py | 6 ++---- tensorflow_asr/models/contextnet.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py index 0b798525f0..cffa0c3efc 100755 --- a/tensorflow_asr/models/conformer.py +++ b/tensorflow_asr/models/conformer.py @@ -45,8 +45,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer ) - self.swish = tf.keras.layers.Activation( - tf.keras.activations.swish, name=f"{name}_swish_activation") + self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation") self.do1 = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout_1") self.ffn2 = tf.keras.layers.Dense( input_dim, name=f"{name}_dense_2", @@ -168,8 +167,7 @@ def __init__(self, gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer ) - self.swish = tf.keras.layers.Activation( - tf.keras.activations.swish, name=f"{name}_swish_activation") + self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation") self.pw_conv_2 = tf.keras.layers.Conv2D( filters=input_dim, kernel_size=1, strides=1, padding="valid", name=f"{name}_pw_conv_2", diff --git a/tensorflow_asr/models/contextnet.py b/tensorflow_asr/models/contextnet.py index 331b75b130..075cc471a5 100644 --- a/tensorflow_asr/models/contextnet.py +++ b/tensorflow_asr/models/contextnet.py @@ -23,7 +23,7 @@ def get_activation(activation: str = "silu"): activation = activation.lower() - if activation in ["silu", "swish"]: return tf.nn.silu + if activation in ["silu", "swish"]: return tf.nn.swish elif activation == "relu": return tf.nn.relu elif activation == "linear": return tf.keras.activations.linear else: raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'") From 66b79c7f6158eeeb84ff176e2a6d818318b8c291 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 3 Jan 2021 16:18:58 +0700 Subject: [PATCH 5/6] :writing_hand: fix total steps in gradient accumulation --- tensorflow_asr/runners/base_runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index 6e26b95889..c381969cad 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -125,7 +125,7 @@ def set_train_data_loader(self, train_dataset, train_bs=None, train_acs=None): self.train_data = train_dataset.create(self.global_batch_size) self.train_data_loader = self.strategy.experimental_distribute_dataset(self.train_data) - if hasattr(self, "accumulation"): + if hasattr(self, "accumulation") and train_dataset.total_steps is not None: self.train_steps_per_epoch = train_dataset.total_steps // self.config.accumulation_steps else: self.train_steps_per_epoch = train_dataset.total_steps From b11e464a2ffc24db16fd6fd9ac28a8b25e122629 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 3 Jan 2021 17:54:58 +0700 Subject: [PATCH 6/6] :rocket: add recognize with timestamp for streaming transducer --- setup.py | 2 +- tensorflow_asr/models/streaming_transducer.py | 41 +++++++-- tests/jasper/config.yml | 2 +- tests/streaming_transducer/config.yml | 87 +++++++++++++++++++ .../test_streaming_transducer.py | 57 ++++++++++++ 5 files changed, 181 insertions(+), 8 deletions(-) create mode 100644 tests/streaming_transducer/config.yml create mode 100644 tests/streaming_transducer/test_streaming_transducer.py diff --git a/setup.py b/setup.py index 21e1cc3025..66334675df 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.6.2", + version="0.6.3", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/models/streaming_transducer.py b/tensorflow_asr/models/streaming_transducer.py index b423e6c3e2..3ce396eaad 100644 --- a/tensorflow_asr/models/streaming_transducer.py +++ b/tensorflow_asr/models/streaming_transducer.py @@ -297,6 +297,36 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states) hypothesis.states ) + def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states): + features = self.speech_featurizer.tf_extract(signal) + encoded, new_encoder_states = self.encoder_inference(features, encoder_states) + hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) + indices = self.text_featurizer.normalize_indices(hypothesis.prediction) + upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length] + + num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) + total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step + + stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + non_blank = tf.where(tf.not_equal(upoints, 0)) + non_blank_transcript = tf.gather_nd(upoints, non_blank) + non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + + return ( + non_blank_transcript, + non_blank_stime, + non_blank_etime, + hypothesis.prediction, + new_encoder_states, + hypothesis.states + ) + # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function @@ -325,15 +355,14 @@ def recognize_beam(self, # -------------------------------- TFLITE ------------------------------------- - def make_tflite_function(self, greedy: bool = True): + def make_tflite_function(self, timestamp: bool = True): + tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite return tf.function( - self.recognize_tflite, + tflite_func, input_signature=[ tf.TensorSpec([None], dtype=tf.float32), tf.TensorSpec([], dtype=tf.int32), - tf.TensorSpec(self.encoder.get_initial_state().get_shape(), - dtype=tf.float32), - tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), - dtype=tf.float32) + tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32), + tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32) ] ) diff --git a/tests/jasper/config.yml b/tests/jasper/config.yml index a785862d06..686c10db10 100644 --- a/tests/jasper/config.yml +++ b/tests/jasper/config.yml @@ -40,7 +40,7 @@ model_config: first_additional_block_strides: 2 first_additional_block_dilation: 1 first_additional_block_dropout: 0.2 - nsubblocks: 3 + nsubblocks: 1 block_channels: [256, 384, 512, 640, 768] block_kernels: [11, 13, 17, 21, 25] block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3] diff --git a/tests/streaming_transducer/config.yml b/tests/streaming_transducer/config.yml new file mode 100644 index 0000000000..ff2c6a4ed5 --- /dev/null +++ b/tests/streaming_transducer/config.yml @@ -0,0 +1,87 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + preemphasis: 0.97 + normalize_signal: True + normalize_feature: True + normalize_per_feature: False + +decoder_config: + vocabulary: null + target_vocab_size: 1024 + max_subword_length: 4 + blank_at_zero: True + beam_width: 5 + norm_score: True + +model_config: + name: streaming_transducer + encoder_reductions: + 0: 3 + 1: 2 + encoder_dmodel: 320 + encoder_rnn_type: lstm + encoder_rnn_units: 1024 + encoder_nlayers: 2 + encoder_layer_norm: True + prediction_embed_dim: 320 + prediction_embed_dropout: 0.0 + prediction_num_rnns: 2 + prediction_rnn_units: 1024 + prediction_rnn_type: lstm + prediction_projection_units: 320 + prediction_layer_norm: True + joint_dim: 320 + joint_activation: tanh + +learning_config: + augmentations: + after: + time_masking: + num_masks: 10 + mask_factor: 100 + p_upperbound: 0.05 + freq_masking: + num_masks: 1 + mask_factor: 27 + + dataset_config: + train_paths: + - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv + eval_paths: + - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv + - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv + test_paths: + - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv + tfrecords_dir: null + + optimizer_config: + class_name: adam + config: + learning_rate: 0.0001 + + running_config: + batch_size: 2 + accumulation_steps: 1 + num_epochs: 20 + outdir: /mnt/Miscellanea/Models/local/streaming_transducer + log_interval_steps: 300 + eval_interval_steps: 500 + save_interval_steps: 1000 diff --git a/tests/streaming_transducer/test_streaming_transducer.py b/tests/streaming_transducer/test_streaming_transducer.py new file mode 100644 index 0000000000..0b4e524df7 --- /dev/null +++ b/tests/streaming_transducer/test_streaming_transducer.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.models.streaming_transducer import StreamingTransducer +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer + + +def test_streaming_transducer(): + config = Config(DEFAULT_YAML, learning=False) + + text_featurizer = CharFeaturizer(config.decoder_config) + + speech_featurizer = TFSpeechFeaturizer(config.speech_config) + + model = StreamingTransducer(vocabulary_size=text_featurizer.num_classes, **config.model_config) + + model._build(speech_featurizer.shape) + model.summary(line_length=150) + + model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) + + concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function() + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.experimental_new_converter = True + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] + converter.convert() + + print("Converted successfully with no timestamp") + + concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function() + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.experimental_new_converter = True + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] + converter.convert() + + print("Converted successfully with timestamp")