diff --git a/examples/conformer/tflite_subword_conformer.py b/examples/conformer/tflite_subword_conformer.py index 51222ce71e..1d3c7844e5 100644 --- a/examples/conformer/tflite_subword_conformer.py +++ b/examples/conformer/tflite_subword_conformer.py @@ -64,9 +64,9 @@ concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function() converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) +converter.experimental_new_converter = True converter.optimizations = [tf.lite.Optimize.DEFAULT] -converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, - tf.lite.OpsSet.SELECT_TF_OPS] +converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] tflite_model = converter.convert() if not os.path.exists(os.path.dirname(args.output)): diff --git a/setup.py b/setup.py index 8f723380e3..8890072241 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.5.1", + version="0.5.2", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index 50490d92f4..5f3e415d48 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -417,73 +417,51 @@ def __perform_greedy(self, with tf.name_scope(f"{self.name}_greedy"): time = tf.constant(0, dtype=tf.int32) total = encoded_length - # Initialize prediction with a blank - # Prediction can not be longer than the encoded of audio plus blank - prediction = tf.TensorArray( - dtype=tf.int32, - size=(total + 1), - dynamic_size=False, - element_shape=tf.TensorShape([]), - clear_after_read=False - ) hypothesis = Hypothesis( index=tf.constant(0, dtype=tf.int32), - prediction=prediction.write(0, predicted), + prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank, states=states ) def condition(time, total, encoded, hypothesis): return tf.less(time, total) def body(time, total, encoded, hypothesis): + predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1)) + ytu, new_states = self.decoder_inference( # avoid using [index] in tflite encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)), - predicted=hypothesis.prediction.read(hypothesis.index), + predicted=predicted, states=hypothesis.states ) - char = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] - - index, char, new_states = tf.cond( - tf.equal(char, self.text_featurizer.blank), - true_fn=lambda: ( - hypothesis.index, - hypothesis.prediction.read(hypothesis.index), - hypothesis.states - ), - false_fn=lambda: ( - hypothesis.index + 1, - char, - new_states - ) + new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] + + index, new_predicted, new_states = tf.cond( + tf.equal(new_predicted, self.text_featurizer.blank), + true_fn=lambda: (hypothesis.index, predicted, hypothesis.states), + false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states) ) hypothesis = Hypothesis( index=index, - prediction=hypothesis.prediction.write(index, char), + prediction=tf.tensor_scatter_nd_update( + hypothesis.prediction, + indices=tf.reshape(index, [1, 1]), + updates=tf.expand_dims(new_predicted, axis=-1) + ), states=new_states ) return time + 1, total, encoded, hypothesis time, total, encoded, hypothesis = tf.while_loop( - condition, - body, + condition, body, loop_vars=(time, total, encoded, hypothesis), parallel_iterations=parallel_iterations, swap_memory=swap_memory ) - # Gather predicted sequence - hypothesis = Hypothesis( - index=hypothesis.index, - prediction=tf.gather_nd( - params=hypothesis.prediction.stack(), - indices=tf.expand_dims(tf.range(hypothesis.index + 1), axis=-1) - ), - states=hypothesis.states - ) - return hypothesis # -------------------------------- BEAM SEARCH -------------------------------------