From 9a9f1a6c32c9a7cbb9f6b72939132cb4ddfde196 Mon Sep 17 00:00:00 2001
From: Huy Le Nguyen
Date: Sun, 27 Dec 2020 13:15:57 +0700
Subject: [PATCH 1/2] :rocket: add unittest using pytest

---
 examples/conformer/config.yml                 |   1 +
 examples/contextnet/config.yml                |   1 +
 examples/deepspeech2/README.md                |   3 +-
 examples/jasper/README.md                     |   3 +-
 examples/streaming_transducer/config.yml      |   1 +
 tensorflow_asr/models/conformer.py            |   2 +
 tensorflow_asr/models/contextnet.py           |   4 +-
 tensorflow_asr/models/ctc.py                  |   5 +-
 tensorflow_asr/models/streaming_transducer.py |   2 +
 tensorflow_asr/models/transducer.py           |  12 +-
 tests/conformer/config.yml                    |  95 +++++++
 tests/conformer/test_conformer.py             |  57 +++++
 tests/contextnet/config.yml                   | 231 ++++++++++++++++++
 tests/contextnet/test_contextnet.py           |  60 +++++
 tests/deepspeech2/config.yml                  |  76 ++++++
 tests/deepspeech2/test_ds2.py                 |  57 +++++
 tests/jasper/config.yml                       |  83 +++++++
 tests/jasper/test_jasper.py                   |  57 +++++
 tests/plot_learning_rate.py                   |  31 ---
 tests/specaugment_test.py                     |  53 ----
 tests/test_conformer.py                       | 128 ----------
 tests/test_ctc.py                             | 104 --------
 tests/test_dataset.py                         |  64 -----
 tests/test_pos_enc.py                         |  43 ----
 ...izer_test.py => test_speech_featurizer.py} |   9 +-
 tests/test_subword.py                         |  33 ---
 tests/test_text_featurizer.py                 |   7 -
 tests/test_transducer.py                      | 121 ---------
 28 files changed, 745 insertions(+), 598 deletions(-)
 create mode 100644 tests/conformer/config.yml
 create mode 100644 tests/conformer/test_conformer.py
 create mode 100644 tests/contextnet/config.yml
 create mode 100644 tests/contextnet/test_contextnet.py
 create mode 100644 tests/deepspeech2/config.yml
 create mode 100644 tests/deepspeech2/test_ds2.py
 create mode 100644 tests/jasper/config.yml
 create mode 100644 tests/jasper/test_jasper.py
 delete mode 100755 tests/plot_learning_rate.py
 delete mode 100755 tests/specaugment_test.py
 delete mode 100644 tests/test_conformer.py
 delete mode 100644 tests/test_ctc.py
 delete mode 100644 tests/test_dataset.py
 delete mode 100755 tests/test_pos_enc.py
 rename tests/{speech_featurizer_test.py => test_speech_featurizer.py} (90%)
 mode change 100755 => 100644
 delete mode 100644 tests/test_subword.py
 delete mode 100644 tests/test_text_featurizer.py
 delete mode 100644 tests/test_transducer.py

diff --git a/examples/conformer/config.yml b/examples/conformer/config.yml
index ed05251409..db47374810 100755
--- a/examples/conformer/config.yml
+++ b/examples/conformer/config.yml
@@ -56,6 +56,7 @@ model_config:
   prediction_layer_norm: True
   prediction_projection_units: 0
   joint_dim: 320
+  joint_activation: tanh
 
 learning_config:
   augmentations:
diff --git a/examples/contextnet/config.yml b/examples/contextnet/config.yml
index a68f93d43d..7b5d8d2333 100644
--- a/examples/contextnet/config.yml
+++ b/examples/contextnet/config.yml
@@ -192,6 +192,7 @@ model_config:
   prediction_layer_norm: True
   prediction_projection_units: 0
   joint_dim: 640
+  joint_activation: tanh
 
 learning_config:
   augmentations:
diff --git a/examples/deepspeech2/README.md b/examples/deepspeech2/README.md
index 7f9e072e6c..0e79f8bf34 100755
--- a/examples/deepspeech2/README.md
+++ b/examples/deepspeech2/README.md
@@ -29,4 +29,5 @@ model_config:
 
 See `python examples/deepspeech2/train_*.py --help`
 
-See `python examples/deepspeech2/test_*.py --help`
\ No newline at end of file
+See `python examples/deepspeech2/test_*.py --help`
+
diff --git a/examples/jasper/README.md b/examples/jasper/README.md
index 2c3266894f..f666a5cdba 100755
--- a/examples/jasper/README.md
+++ b/examples/jasper/README.md
@@ -37,4 +37,5 @@ model_config:
 
 See `python examples/jasper/train_*.py --help`
 
-See `python examples/jasper/test_*.py --help`
\ No newline at end of file
+See `python examples/jasper/test_*.py --help`
+
diff --git a/examples/streaming_transducer/config.yml b/examples/streaming_transducer/config.yml
index 01eaf281d4..54551e3c5b 100755
--- a/examples/streaming_transducer/config.yml
+++ b/examples/streaming_transducer/config.yml
@@ -49,6 +49,7 @@ model_config:
   prediction_projection_units: 320
   prediction_layer_norm: True
   joint_dim: 320
+  joint_activation: tanh
 
 learning_config:
   augmentations:
diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py
index f8860180cd..0b798525f0 100755
--- a/tensorflow_asr/models/conformer.py
+++ b/tensorflow_asr/models/conformer.py
@@ -384,6 +384,7 @@ def __init__(self,
                  prediction_layer_norm: bool = True,
                  prediction_projection_units: int = 0,
                  joint_dim: int = 1024,
+                 joint_activation: str = "tanh",
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
                  name: str = "conformer_transducer",
@@ -414,6 +415,7 @@ def __init__(self,
             layer_norm=prediction_layer_norm,
             projection_units=prediction_projection_units,
             joint_dim=joint_dim,
+            joint_activation=joint_activation,
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer,
             name=name, **kwargs
diff --git a/tensorflow_asr/models/contextnet.py b/tensorflow_asr/models/contextnet.py
index 0d80b02dc4..331b75b130 100644
--- a/tensorflow_asr/models/contextnet.py
+++ b/tensorflow_asr/models/contextnet.py
@@ -196,7 +196,7 @@ class ContextNet(Transducer):
     def __init__(self,
                  vocabulary_size: int,
                  encoder_blocks: List[dict],
-                 encoder_alpha: float,
+                 encoder_alpha: float = 0.5,
                  prediction_embed_dim: int = 512,
                  prediction_embed_dropout: int = 0,
                  prediction_num_rnns: int = 1,
@@ -206,6 +206,7 @@ def __init__(self,
                  prediction_layer_norm: bool = True,
                  prediction_projection_units: int = 0,
                  joint_dim: int = 1024,
+                 joint_activation: str = "tanh",
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
                  name: str = "contextnet",
@@ -228,6 +229,7 @@ def __init__(self,
             layer_norm=prediction_layer_norm,
             projection_units=prediction_projection_units,
             joint_dim=joint_dim,
+            joint_activation=joint_activation,
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer,
             name=name, **kwargs
diff --git a/tensorflow_asr/models/ctc.py b/tensorflow_asr/models/ctc.py
index 2335e7c74b..4cbb163eaa 100644
--- a/tensorflow_asr/models/ctc.py
+++ b/tensorflow_asr/models/ctc.py
@@ -25,6 +25,7 @@ class CtcModel(Model):
 
     def __init__(self, **kwargs):
         super(CtcModel, self).__init__(**kwargs)
+        self.time_reduction_factor = 1
 
     def _build(self, input_shape):
         features = tf.keras.Input(input_shape, dtype=tf.float32)
@@ -67,7 +68,7 @@ def recognize_tflite(self, signal):
         features = self.speech_featurizer.tf_extract(signal)
         features = tf.expand_dims(features, axis=0)
         input_length = shape_list(features)[1]
-        input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
+        input_length = get_reduced_length(input_length, self.time_reduction_factor)
         input_length = tf.expand_dims(input_length, axis=0)
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
@@ -113,7 +114,7 @@ def recognize_beam_tflite(self, signal):
         features = self.speech_featurizer.tf_extract(signal)
         features = tf.expand_dims(features, axis=0)
         input_length = shape_list(features)[1]
-        input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
+        input_length = get_reduced_length(input_length, self.time_reduction_factor)
         input_length = tf.expand_dims(input_length, axis=0)
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
diff --git a/tensorflow_asr/models/streaming_transducer.py b/tensorflow_asr/models/streaming_transducer.py
index 4b9d14f3c2..c2b0c26150 100644
--- a/tensorflow_asr/models/streaming_transducer.py
+++ b/tensorflow_asr/models/streaming_transducer.py
@@ -192,6 +192,7 @@ def __init__(self,
                  prediction_layer_norm: bool = True,
                  prediction_projection_units: int = 640,
                  joint_dim: int = 640,
+                 joint_activation: str = "tanh",
                  kernel_regularizer = None,
                  bias_regularizer = None,
                  name = "StreamingTransducer",
@@ -217,6 +218,7 @@ def __init__(self,
             layer_norm=prediction_layer_norm,
             projection_units=prediction_projection_units,
             joint_dim=joint_dim,
+            joint_activation=joint_activation,
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer,
             name=name, **kwargs
diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py
index ec40de3469..11263a66bb 100755
--- a/tensorflow_asr/models/transducer.py
+++ b/tensorflow_asr/models/transducer.py
@@ -146,11 +146,19 @@ class TransducerJoint(tf.keras.Model):
     def __init__(self,
                  vocabulary_size: int,
                  joint_dim: int = 1024,
+                 activation: str = "tanh",
                  kernel_regularizer=None,
                  bias_regularizer=None,
                  name="tranducer_joint",
                  **kwargs):
         super(TransducerJoint, self).__init__(name=name, **kwargs)
+
+        activation = activation.lower()
+        if activation == "linear": self.activation = tf.keras.activations.linear
+        elif activation == "relu": self.activation = tf.nn.relu
+        elif activation == "tanh": self.activation = tf.nn.tanh
+        else: raise ValueError("activation must be either 'linear', 'relu' or 'tanh'")
+
         self.ffn_enc = tf.keras.layers.Dense(
             joint_dim, name=f"{name}_enc",
             kernel_regularizer=kernel_regularizer,
@@ -174,7 +182,7 @@ def call(self, inputs, training=False, **kwargs):
         pred_out = self.ffn_pred(pred_out, training=training)  # [B, U, P] => [B, U, V]
         enc_out = tf.expand_dims(enc_out, axis=2)
         pred_out = tf.expand_dims(pred_out, axis=1)
-        outputs = tf.nn.tanh(enc_out + pred_out)  # => [B, T, U, V]
+        outputs = self.activation(enc_out + pred_out)  # => [B, T, U, V]
         outputs = self.ffn_out(outputs, training=training)
         return outputs
 
@@ -200,6 +208,7 @@ def __init__(self,
                  layer_norm: bool = True,
                  projection_units: int = 0,
                  joint_dim: int = 1024,
+                 joint_activation: str = "tanh",
                  kernel_regularizer=None,
                  bias_regularizer=None,
                  name="transducer",
@@ -223,6 +232,7 @@ def __init__(self,
         self.joint_net = TransducerJoint(
             vocabulary_size=vocabulary_size,
             joint_dim=joint_dim,
+            activation=joint_activation,
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer,
             name=f"{name}_joint"
diff --git a/tests/conformer/config.yml b/tests/conformer/config.yml
new file mode 100644
index 0000000000..db47374810
--- /dev/null
+++ b/tests/conformer/config.yml
@@ -0,0 +1,95 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+speech_config:
+  sample_rate: 16000
+  frame_ms: 25
+  stride_ms: 10
+  num_feature_bins: 80
+  feature_type: log_mel_spectrogram
+  preemphasis: 0.97
+  normalize_signal: True
+  normalize_feature: True
+  normalize_per_feature: False
+
+decoder_config:
+  vocabulary: null
+  target_vocab_size: 1024
+  max_subword_length: 4
+  blank_at_zero: True
+  beam_width: 5
+  norm_score: True
+
+model_config:
+  name: conformer
+  encoder_subsampling:
+    type: conv2d
+    filters: 144
+    kernel_size: 3
+    strides: 2
+  encoder_positional_encoding: sinusoid_concat
+  encoder_dmodel: 144
+  encoder_num_blocks: 16
+  encoder_head_size: 36
+  encoder_num_heads: 4
+  encoder_mha_type: relmha
+  encoder_kernel_size: 32
+  encoder_fc_factor: 0.5
+  encoder_dropout: 0.1
+  prediction_embed_dim: 320
+  prediction_embed_dropout: 0
+  prediction_num_rnns: 1
+  prediction_rnn_units: 320
+  prediction_rnn_type: lstm
+  prediction_rnn_implementation: 1
+  prediction_layer_norm: True
+  prediction_projection_units: 0
+  joint_dim: 320
+  joint_activation: tanh
+
+learning_config:
+  augmentations:
+    after:
+      time_masking:
+        num_masks: 10
+        mask_factor: 100
+        p_upperbound: 0.05
+      freq_masking:
+        num_masks: 1
+        mask_factor: 27
+
+  dataset_config:
+    train_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
+    eval_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
+    test_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
+    tfrecords_dir: null
+
+  optimizer_config:
+    warmup_steps: 40000
+    beta1: 0.9
+    beta2: 0.98
+    epsilon: 1e-9
+
+  running_config:
+    batch_size: 2
+    accumulation_steps: 4
+    num_epochs: 20
+    outdir: /mnt/Miscellanea/Models/local/conformer
+    log_interval_steps: 300
+    eval_interval_steps: 500
+    save_interval_steps: 1000
diff --git a/tests/conformer/test_conformer.py b/tests/conformer/test_conformer.py
new file mode 100644
index 0000000000..9a2b55201a
--- /dev/null
+++ b/tests/conformer/test_conformer.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
+
+from tensorflow_asr.configs.config import Config
+from tensorflow_asr.models.conformer import Conformer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+
+
+def test_conformer():
+    config = Config(DEFAULT_YAML, learning=False)
+
+    text_featurizer = CharFeaturizer(config.decoder_config)
+
+    speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+
+    model = Conformer(vocabulary_size=text_featurizer.num_classes, **config.model_config)
+
+    model._build(speech_featurizer.shape)
+    model.summary(line_length=150)
+
+    model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer)
+
+    concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with no timestamp")
+
+    concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with timestamp")
diff --git a/tests/contextnet/config.yml b/tests/contextnet/config.yml
new file mode 100644
index 0000000000..7b5d8d2333
--- /dev/null
+++ b/tests/contextnet/config.yml
@@ -0,0 +1,231 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+speech_config:
+  sample_rate: 16000
+  frame_ms: 25
+  stride_ms: 10
+  num_feature_bins: 80
+  feature_type: log_mel_spectrogram
+  preemphasis: 0.97
+  normalize_signal: True
+  normalize_feature: True
+  normalize_per_feature: False
+
+decoder_config:
+  vocabulary: null
+  target_vocab_size: 1024
+  max_subword_length: 4
+  blank_at_zero: True
+  beam_width: 5
+  norm_score: True
+
+model_config:
+  name: contextnet
+  encoder_alpha: 0.5
+  encoder_blocks:
+    # C0
+    - nlayers: 1
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: False
+      activation: silu
+    # C1-C2
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    # C3
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 2
+      residual: True
+      activation: silu
+    # C4-C6
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    # C7
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 2
+      residual: True
+      activation: silu
+    # C8 - C10
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 256
+      strides: 1
+      residual: True
+      activation: silu
+    # C11 - C13
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    # C14
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 2
+      residual: True
+      activation: silu
+    # C15 - C21
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    - nlayers: 5
+      kernel_size: 5
+      filters: 512
+      strides: 1
+      residual: True
+      activation: silu
+    # C22
+    - nlayers: 1
+      kernel_size: 5
+      filters: 640
+      strides: 1
+      residual: False
+      activation: silu
+  prediction_embed_dim: 640
+  prediction_embed_dropout: 0
+  prediction_num_rnns: 1
+  prediction_rnn_units: 640
+  prediction_rnn_type: lstm
+  prediction_rnn_implementation: 1
+  prediction_layer_norm: True
+  prediction_projection_units: 0
+  joint_dim: 640
+  joint_activation: tanh
+
+learning_config:
+  augmentations:
+    after:
+      time_masking:
+        num_masks: 10
+        mask_factor: 100
+        p_upperbound: 0.05
+      freq_masking:
+        num_masks: 1
+        mask_factor: 27
+
+  dataset_config:
+    train_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
+    eval_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
+    test_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
+    tfrecords_dir: null
+
+  optimizer_config:
+    warmup_steps: 40000
+    beta1: 0.9
+    beta2: 0.98
+    epsilon: 1e-9
+
+  running_config:
+    batch_size: 2
+    accumulation_steps: 4
+    num_epochs: 20
+    outdir: /mnt/Miscellanea/Models/local/contextnet
+    log_interval_steps: 300
+    eval_interval_steps: 500
+    save_interval_steps: 1000
diff --git a/tests/contextnet/test_contextnet.py b/tests/contextnet/test_contextnet.py
new file mode 100644
index 0000000000..4a9328b12f
--- /dev/null
+++ b/tests/contextnet/test_contextnet.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
+
+from tensorflow_asr.configs.config import Config
+from tensorflow_asr.models.contextnet import ContextNet
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+
+
+def test_contextnet():
+    config = Config(DEFAULT_YAML, learning=False)
+
+    text_featurizer = CharFeaturizer(config.decoder_config)
+
+    speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+
+    model = ContextNet(vocabulary_size=text_featurizer.num_classes, **config.model_config)
+
+    model._build(speech_featurizer.shape)
+    model.summary(line_length=150)
+
+    model.add_featurizers(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer
+    )
+
+    concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with no timestamp")
+
+    concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with timestamp")
diff --git a/tests/deepspeech2/config.yml b/tests/deepspeech2/config.yml
new file mode 100644
index 0000000000..3c6f6d12f5
--- /dev/null
+++ b/tests/deepspeech2/config.yml
@@ -0,0 +1,76 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+speech_config:
+  sample_rate: 16000
+  frame_ms: 25
+  stride_ms: 10
+  num_feature_bins: 80
+  feature_type: spectrogram
+  preemphasis: 0.97
+  normalize_signal: True
+  normalize_feature: True
+  normalize_per_feature: False
+
+decoder_config:
+  vocabulary: null
+  blank_at_zero: False
+  beam_width: 500
+  lm_config:
+    model_path: null
+    alpha: 2.0
+    beta: 1.0
+
+model_config:
+  name: deepspeech2
+  conv_type: conv2d
+  conv_kernels: [[11, 41], [11, 21], [11, 11]]
+  conv_strides: [[2, 2], [1, 2], [1, 2]]
+  conv_filters: [32, 32, 96]
+  conv_dropout: 0.1
+  rnn_nlayers: 5
+  rnn_type: lstm
+  rnn_units: 512
+  rnn_bidirectional: True
+  rnn_rowconv: 0
+  rnn_dropout: 0.1
+  fc_nlayers: 0
+  fc_units: 1024
+
+learning_config:
+  augmentations: null
+
+  dataset_config:
+    train_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
+    eval_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
+    test_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
+    tfrecords_dir: null
+
+  optimizer_config:
+    class_name: adam
+    config:
+      learning_rate: 0.0001
+
+  running_config:
+    batch_size: 4
+    num_epochs: 20
+    accumulation_steps: 8
+    outdir: /mnt/Miscellanea/Models/local/deepspeech2
+    log_interval_steps: 400
+    save_interval_steps: 400
+    eval_interval_steps: 800
diff --git a/tests/deepspeech2/test_ds2.py b/tests/deepspeech2/test_ds2.py
new file mode 100644
index 0000000000..9b8a994ae0
--- /dev/null
+++ b/tests/deepspeech2/test_ds2.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
+
+from tensorflow_asr.configs.config import Config
+from tensorflow_asr.models.deepspeech2 import DeepSpeech2
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+
+
+def test_ds2():
+    config = Config(DEFAULT_YAML, learning=False)
+
+    text_featurizer = CharFeaturizer(config.decoder_config)
+
+    speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+
+    model = DeepSpeech2(vocabulary_size=text_featurizer.num_classes, **config.model_config)
+
+    model._build(speech_featurizer.shape)
+    model.summary(line_length=150)
+
+    model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer)
+
+    concrete_func = model.make_tflite_function(greedy=False).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with beam search")
+
+    concrete_func = model.make_tflite_function(greedy=True).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with greedy")
diff --git a/tests/jasper/config.yml b/tests/jasper/config.yml
new file mode 100644
index 0000000000..a785862d06
--- /dev/null
+++ b/tests/jasper/config.yml
@@ -0,0 +1,83 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+speech_config:
+  sample_rate: 16000
+  frame_ms: 25
+  stride_ms: 10
+  num_feature_bins: 80
+  feature_type: log_mel_spectrogram
+  preemphasis: 0.97
+  normalize_signal: True
+  normalize_feature: True
+  normalize_per_feature: False
+
+decoder_config:
+  vocabulary: null
+  blank_at_zero: False
+  beam_width: 500
+  lm_config:
+    model_path: null
+    alpha: 2.0
+    beta: 1.0
+
+model_config:
+  name: jasper
+  dense: True
+  first_additional_block_channels: 256
+  first_additional_block_kernels: 11
+  first_additional_block_strides: 2
+  first_additional_block_dilation: 1
+  first_additional_block_dropout: 0.2
+  nsubblocks: 3
+  block_channels: [256, 384, 512, 640, 768]
+  block_kernels: [11, 13, 17, 21, 25]
+  block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3]
+  second_additional_block_channels: 896
+  second_additional_block_kernels: 1
+  second_additional_block_strides: 1
+  second_additional_block_dilation: 2
+  second_additional_block_dropout: 0.4
+  third_additional_block_channels: 1024
+  third_additional_block_kernels: 1
+  third_additional_block_strides: 1
+  third_additional_block_dilation: 1
+  third_additional_block_dropout: 0.4
+
+learning_config:
+  augmentations: null
+
+  dataset_config:
+    train_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
+    eval_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
+    test_paths:
+      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
+    tfrecords_dir: null
+
+  optimizer_config:
+    class_name: adam
+    config:
+      learning_rate: 0.0001
+
+  running_config:
+    batch_size: 4
+    num_epochs: 20
+    accumulation_steps: 8
+    outdir: /mnt/Miscellanea/Models/local/jasper
+    log_interval_steps: 400
+    save_interval_steps: 400
+    eval_interval_steps: 800
diff --git a/tests/jasper/test_jasper.py b/tests/jasper/test_jasper.py
new file mode 100644
index 0000000000..536c91c23b
--- /dev/null
+++ b/tests/jasper/test_jasper.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
+
+from tensorflow_asr.configs.config import Config
+from tensorflow_asr.models.jasper import Jasper
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+
+
+def test_jasper():
+    config = Config(DEFAULT_YAML, learning=False)
+
+    text_featurizer = CharFeaturizer(config.decoder_config)
+
+    speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+
+    model = Jasper(vocabulary_size=text_featurizer.num_classes, **config.model_config)
+
+    model._build(speech_featurizer.shape)
+    model.summary(line_length=150)
+
+    model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer)
+
+    concrete_func = model.make_tflite_function(greedy=False).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with beam search")
+
+    concrete_func = model.make_tflite_function(greedy=True).get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.experimental_new_converter = True
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
+    converter.convert()
+
+    print("Converted successfully with greedy")
diff --git a/tests/plot_learning_rate.py b/tests/plot_learning_rate.py
deleted file mode 100755
index 65e3f41770..0000000000
--- a/tests/plot_learning_rate.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-import matplotlib.pyplot as plt
-from tensorflow_asr.optimizers.schedules import SANSchedule, TransformerSchedule
-
-lr = SANSchedule(lamb=0.05, d_model=512, warmup_steps=4000)
-
-plt.plot(lr(tf.range(40000, dtype=tf.float32)))
-plt.ylabel("Learning Rate")
-plt.xlabel("Train Step")
-# plt.show()
-
-lr = TransformerSchedule(d_model=144, warmup_steps=10000)
-
-plt.plot(lr(tf.range(2000000, dtype=tf.float32)))
-plt.ylabel("Learning Rate")
-plt.xlabel("Train Step")
-# plt.show()
diff --git a/tests/specaugment_test.py b/tests/specaugment_test.py
deleted file mode 100755
index 2c8b906cbe..0000000000
--- a/tests/specaugment_test.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# from __future__ import absolute_import
-# import matplotlib.pyplot as plt
-# from augmentations.augments import TimeWarping, TimeMasking, FreqMasking
-# from featurizers.speech_featurizers import SpeechFeaturizer
-# import tensorflow as tf
-#
-# import sys
-# import os.path as o
-# sys.path.append(o.abspath(o.join(o.dirname(sys.modules[__name__].__file__), "..")))
-#
-#
-# def main(argv):
-#     fm = FreqMasking(num_freq_mask=2)
-#     tm = TimeMasking()
-#     tw = TimeWarping()
-#
-#     speech_file = argv[1]
-#     sf = SpeechFeaturizer(sample_rate=16000, frame_ms=20, stride_ms=10, num_feature_bins=128)
-#     ft = sf.compute_speech_features(speech_file)
-#
-#     plt.figure(figsize=(15, 5))
-#
-#     plt.subplot(2, 1, 1)
-#     plt.imshow(tf.transpose(tf.squeeze(ft)))
-#
-#     ft = fm(ft)
-#
-#     print(ft)
-#
-#     ft = tf.squeeze(ft)
-#     ft = tf.transpose(ft)
-#
-#     plt.subplot(2, 1, 2)
-#     plt.imshow(ft)
-#     plt.show()
-#
-#
-# if __name__ == "__main__":
-#     main(sys.argv)
diff --git a/tests/test_conformer.py b/tests/test_conformer.py
deleted file mode 100644
index c1d41a16ec..0000000000
--- a/tests/test_conformer.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-# import datetime
-import sys
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-import tensorflow as tf
-
-from tensorflow_asr.models.conformer import Conformer
-# from tensorflow_asr.models.transducer import Transducer
-# from tensorflow_asr.models.layers.subsampling import Conv2dSubsampling
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer, read_raw_audio
-
-text_featurizer = CharFeaturizer({
-    "vocabulary": None,
-    "blank_at_zero": True,
-    "beam_width": 5,
-    "norm_score": True
-})
-
-speech_featurizer = TFSpeechFeaturizer({
-    "sample_rate": 16000,
-    "frame_ms": 25,
-    "stride_ms": 10,
-    "num_feature_bins": 80,
-    "feature_type": "log_mel_spectrogram",
-    "preemphasis": 0.97,
-    "normalize_signal": True,
-    "normalize_feature": True,
-    "normalize_per_feature": False
-})
-
-# i = tf.keras.Input(shape=[None, 80, 1])
-# o = Conv2dSubsampling(144)(i)
-
-# encoder = tf.keras.Model(inputs=i, outputs=o)
-# model = Transducer(encoder=encoder, vocabulary_size=text_featurizer.num_classes)
-
-model = Conformer(
-    subsampling={"type": "conv2d", "filters": 144, "kernel_size": 3,
-                 "strides": 2},
-    num_blocks=1,
-    vocabulary_size=text_featurizer.num_classes)
-
-model._build(speech_featurizer.shape)
-model.summary(line_length=150)
-
-model.save_weights("/tmp/transducer.h5")
-
-model.add_featurizers(
-    speech_featurizer=speech_featurizer,
-    text_featurizer=text_featurizer
-)
-
-# features = tf.zeros(shape=[5, 50, 80, 1], dtype=tf.float32)
-# pred = model.recognize(features)
-# print(pred)
-# pred = model.recognize_beam(features)
-# print(pred)
-
-# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-# logdir = '/tmp/logs/func/%s' % stamp
-# writer = tf.summary.create_file_writer(logdir)
-#
-signal = read_raw_audio(sys.argv[1], speech_featurizer.sample_rate)
-#
-# tf.summary.trace_on(graph=True, profiler=True)
-# hyps = model.recognize_tflite(signal, 0, tf.zeros([1, 2, 1, 320], dtype=tf.float32))
-# with writer.as_default():
-#     tf.summary.trace_export(
-#         name="recognize_tflite",
-#         step=0,
-#         profiler_outdir=logdir)
-#
-# print(hyps[0])
-#
-# # hyps = model.recognize_beam(features)
-#
-#
-
-# hyps = model.recognize_beam(tf.expand_dims(speech_featurizer.tf_extract(signal), 0))
-
-# print(hyps)
-
-# hyps = model.recognize_beam_tflite(signal)
-
-# print(hyps.numpy().decode("utf-8"))
-
-concrete_func = model.make_tflite_function(greedy=True).get_concrete_function()
-converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.experimental_new_converter = True
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
-                                       tf.lite.OpsSet.SELECT_TF_OPS]
-tflite = converter.convert()
-
-tflitemodel = tf.lite.Interpreter(model_content=tflite)
-
-input_details = tflitemodel.get_input_details()
-output_details = tflitemodel.get_output_details()
-tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
-tflitemodel.allocate_tensors()
-tflitemodel.set_tensor(input_details[0]["index"], signal)
-tflitemodel.set_tensor(
-    input_details[1]["index"],
-    tf.constant(text_featurizer.blank, dtype=tf.int32)
-)
-tflitemodel.set_tensor(
-    input_details[2]["index"],
-    tf.zeros([1, 2, 1, 320], dtype=tf.float32)
-)
-tflitemodel.invoke()
-hyp = tflitemodel.get_tensor(output_details[0]["index"])
-
-print(hyp)
diff --git a/tests/test_ctc.py b/tests/test_ctc.py
deleted file mode 100644
index 026f67ea79..0000000000
--- a/tests/test_ctc.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import tensorflow as tf
-
-from ctc_decoders import Scorer
-from tensorflow_asr.models.ctc import CtcModel
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.utils.utils import bytes_to_string, merge_two_last_dims
-
-decoder_config = {
-    "vocabulary": "/mnt/Projects/asrk16/TiramisuASR/vocabularies/vietnamese.txt",
-    "beam_width": 100,
-    "blank_at_zero": False,
-    "lm_config": {
-        "model_path": "/mnt/Data/ML/NLP/vntc_asrtrain_5gram_trie.binary",
-        "alpha": 2.0,
-        "beta": 2.0
-    }
-}
-text_featurizer = CharFeaturizer(decoder_config)
-text_featurizer.add_scorer(Scorer(**decoder_config["lm_config"],
-                                  vocabulary=text_featurizer.vocab_array))
-speech_featurizer = TFSpeechFeaturizer({
-    "sample_rate": 16000,
-    "frame_ms": 25,
-    "stride_ms": 10,
-    "num_feature_bins": 80,
-    "feature_type": "spectrogram",
-    "preemphasis": 0.97,
-    # "delta": True,
-    # "delta_delta": True,
-    "normalize_signal": True,
-    "normalize_feature": True,
-    "normalize_per_feature": False,
-    # "pitch": False,
-})
-
-inp = tf.keras.Input(shape=[None, 80, 3])
-
-
-class BaseModel(tf.keras.Model):
-    def __init__(self, name="basemodel", **kwargs):
-        super().__init__(name=name, **kwargs)
-        self.dense = tf.keras.layers.Dense(350)
-        self.lstm = tf.keras.layers.LSTM(350, return_sequences=True)
-        self.time_reduction_factor = 1
-
-    @tf.function
-    def call(self, inputs, training=False, **kwargs):
-        outputs = merge_two_last_dims(inputs)
-        outputs = self.lstm(outputs, training=training)
-        return self.dense(outputs, training=training)
-
-
-model = CtcModel(base_model=BaseModel(), num_classes=text_featurizer.num_classes)
-
-model._build(speech_featurizer.shape)
-model.summary(line_length=100)
-model.add_featurizers(
-    speech_featurizer=speech_featurizer,
-    text_featurizer=text_featurizer
-)
-
-features = tf.random.normal(shape=[5, 50, 80, 1], dtype=tf.float32)
-hyp = model.recognize(features)
-print(bytes_to_string(hyp.numpy()))
-
-hyp = model.recognize_beam(features)
-print(bytes_to_string(hyp.numpy()))
-
-hyp = model.recognize_beam(features, lm=True)
-print(bytes_to_string(hyp.numpy()))
-
-# signal = read_raw_audio("/home/nlhuy/Desktop/test/11003.wav", speech_featurizer.sample_rate)
-signal = tf.random.normal(shape=[500], dtype=tf.float32)
-
-hyp = model.recognize_tflite(signal)
-print(hyp.numpy())
-
-hyp = model.recognize_beam_tflite(signal)
-print(hyp.numpy())
-#
-# hyp = model.recognize_beam_tflite(signal, lm=True)
-# print(hyp.numpy().decode("utf-8"))
-
-concrete_func = model.make_tflite_function(greedy=False).get_concrete_function()
-converter = tf.lite.TFLiteConverter.from_concrete_functions(
-    [concrete_func]
-)
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
-                                       tf.lite.OpsSet.SELECT_TF_OPS]
-tflite = converter.convert()
-
-tflitemodel = tf.lite.Interpreter(model_content=tflite)
-
-input_details = tflitemodel.get_input_details()
-output_details = tflitemodel.get_output_details()
-tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
-tflitemodel.allocate_tensors()
-tflitemodel.set_tensor(input_details[0]["index"], signal)
-tflitemodel.invoke()
-hyp = tflitemodel.get_tensor(output_details[0]["index"])
-
-print(hyp)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
deleted file mode 100644
index 21815e8aeb..0000000000
--- a/tests/test_dataset.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-import psutil
-process = psutil.Process(os.getpid())
-
-from tensorflow_asr.utils import setup_environment
-setup_environment()
-from tensorflow_asr.datasets.asr_dataset import ASRSliceDataset
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-
-
-augments = {
-    "before": {
-        "loudness": {
-            "zone": (0.3, 0.7)
-        },
-        "speed": None,
-        "noise": {
-            "noises": "/mnt/Data/ML/ASR/Preprocessed/Noises/train"
-        }
-    },
-    "after": {
-        "time_masking": {
-            "num_masks": 10,
-            "mask_factor": 100,
-            "p_upperbound": 0.05
-        },
-        "freq_masking": {
-            "mask_factor": 27
-        }
-    },
-    "include_original": False
-}
-
-data = "/mnt/Data/ML/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv"
-
-text_featurizer = CharFeaturizer({
-    "vocabulary": None,
-    "blank_at_zero": True,
-    "beam_width": 5,
-    "norm_score": True
-})
-
-speech_featurizer = TFSpeechFeaturizer({
-    "sample_rate": 16000,
-    "frame_ms": 25,
-    "stride_ms": 10,
-    "num_feature_bins": 80,
-    "feature_type": "log_mel_spectrogram",
-    "preemphasis": 0.97,
-    "normalize_signal": True,
-    "normalize_feature": True,
-    "normalize_per_feature": False
-})
-
-
-dataset = ASRSliceDataset(stage="train", speech_featurizer=speech_featurizer,
-                          text_featurizer=text_featurizer, data_paths=[data],
-                          augmentations=augments, shuffle=True).create(4).take(100)
-
-while True:
-    print("--------------------------------------------")
-    for i, batch in enumerate(dataset):
-        print(process.memory_info().rss)
diff --git a/tests/test_pos_enc.py b/tests/test_pos_enc.py
deleted file mode 100755
index 26543a03c0..0000000000
--- a/tests/test_pos_enc.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-import matplotlib.pyplot as plt
-from tensorflow_asr.models.layers.positional_encoding import PositionalEncodingConcat
-from tensorflow_asr.models.layers.multihead_attention import RelPositionMultiHeadAttention
-
-pos_encoding = PositionalEncodingConcat.encode(500, 144)
-print(pos_encoding.shape)
-
-plt.pcolormesh(pos_encoding[0], cmap='RdBu')
-plt.xlabel('Depth')
-plt.xlim((0, 144))
-plt.ylabel('Position')
-plt.colorbar()
-# plt.show()
-
-rel = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])[None, None, ...]
-rel_shift = RelPositionMultiHeadAttention.relative_shift(rel)
-print(tf.reduce_all(tf.equal(rel, rel_shift)))
-
-plt.figure(figsize=(15, 5))
-
-plt.subplot(2, 1, 1)
-plt.imshow(rel[0][0])
-plt.colorbar()
-
-plt.subplot(2, 1, 2)
-plt.imshow(rel_shift[0][0])
-plt.colorbar()
-# plt.show()
diff --git a/tests/speech_featurizer_test.py b/tests/test_speech_featurizer.py
old mode 100755
new mode 100644
similarity index 90%
rename from tests/speech_featurizer_test.py
rename to tests/test_speech_featurizer.py
index 3424560750..40a060a890
--- a/tests/speech_featurizer_test.py
+++ b/tests/test_speech_featurizer.py
@@ -11,20 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# from __future__ import absolute_import, print_function
-#
-# import os.path as o
 import sys
 
 from tensorflow_asr.utils import setup_environment
 setup_environment()
 import librosa
 import numpy as np
-# sys.path.append(o.abspath(o.join(o.dirname(sys.modules[__name__].__file__), "..")))
-# import matplotlib.pyplot as plt
-from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio, \
-    TFSpeechFeaturizer, NumpySpeechFeaturizer
+from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio, TFSpeechFeaturizer, NumpySpeechFeaturizer
 
 
 def main(argv):
diff --git a/tests/test_subword.py b/tests/test_subword.py
deleted file mode 100644
index 7e06a0571e..0000000000
--- a/tests/test_subword.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import argparse
-import tensorflow as tf
-
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
-
-parser = argparse.ArgumentParser(prog="test subword")
-
-parser.add_argument("transcripts", nargs="+", type=str, default=[None])
-
-args = parser.parse_args()
-
-config = {
-    "vocabulary": None,
-    "target_vocab_size": 1024,
-    "max_subword_length": 4,
-    "blank_at_zero": True,
-    "beam_width": 5,
-    "norm_score": True
-}
-
-text_featurizer = SubwordFeaturizer.build_from_corpus(config, args.transcripts)
-
-print(len(text_featurizer.subwords.subwords))
-print(text_featurizer.upoints)
-print(text_featurizer.num_classes)
-
-a = text_featurizer.extract("hello world")
-
-print(a)
-
-b = text_featurizer.indices2upoints(a)
-
-tf.print(tf.strings.unicode_encode(b, "UTF-8"))
diff --git a/tests/test_text_featurizer.py b/tests/test_text_featurizer.py
deleted file mode 100644
index 6e9a969f77..0000000000
--- a/tests/test_text_featurizer.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-
-txf = CharFeaturizer(None, blank_at_zero=True)
-
-a = txf.extract("fkaff aksfbfnak kcjhoiu")
-
-print(a)
diff --git a/tests/test_transducer.py b/tests/test_transducer.py
deleted file mode 100644
index e221e76c7f..0000000000
--- a/tests/test_transducer.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-# import datetime
-import sys
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-import tensorflow as tf
-
-from tensorflow_asr.models.streaming_transducer import StreamingTransducer
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer, read_raw_audio
-
-text_featurizer = CharFeaturizer({
-    "vocabulary": None,
-    "blank_at_zero": True,
-    "beam_width": 5,
-    "norm_score": True
-})
-
-speech_featurizer = TFSpeechFeaturizer({
-    "sample_rate": 16000,
-    "frame_ms": 25,
-    "stride_ms": 10,
-    "num_feature_bins": 80,
-    "feature_type": "log_mel_spectrogram",
-    "preemphasis": 0.97,
-    "normalize_signal": True,
-    "normalize_feature": True,
-    "normalize_per_feature": False
-})
-
-model = StreamingTransducer(vocabulary_size=text_featurizer.num_classes,
-                            encoder_dmodel=320, encoder_nlayers=3)
-
-model._build(speech_featurizer.shape)
-model.summary(line_length=150)
-
-model.save_weights("/tmp/transducer.h5")
-
-model.add_featurizers(
-    speech_featurizer=speech_featurizer,
-    text_featurizer=text_featurizer
-)
-
-features = tf.zeros(shape=[5, 50, 80, 1], dtype=tf.float32)
-pred = model.recognize(features)
-print(pred)
-pred = model.recognize_beam(features)
-print(pred)
-
-# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-# logdir = '/tmp/logs/func/%s' % stamp
-# writer = tf.summary.create_file_writer(logdir)
-#
-signal = read_raw_audio(sys.argv[1], speech_featurizer.sample_rate)
-#
-# tf.summary.trace_on(graph=True, profiler=True)
-# hyps = model.recognize_tflite(signal, 0, tf.zeros([1, 2, 1, 320], dtype=tf.float32))
-# with writer.as_default():
-#     tf.summary.trace_export(
-#         name="recognize_tflite",
-#         step=0,
-#         profiler_outdir=logdir)
-#
-# print(hyps[0])
-#
-# # hyps = model.recognize_beam(features)
-#
-#
-
-# hyps = model.recognize_beam(tf.expand_dims(speech_featurizer.tf_extract(signal), 0))

-# print(hyps)
-
-# hyps = model.recognize_beam_tflite(signal)
-
-# print(hyps.numpy().decode("utf-8"))
-
-concrete_func = model.make_tflite_function(greedy=True).get_concrete_function()
-converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.experimental_new_converter = True
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
-                                       tf.lite.OpsSet.SELECT_TF_OPS]
-tflite = converter.convert()
-
-tflitemodel = tf.lite.Interpreter(model_content=tflite)
-
-input_details = tflitemodel.get_input_details()
-output_details = tflitemodel.get_output_details()
-tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
-tflitemodel.allocate_tensors()
-tflitemodel.set_tensor(input_details[0]["index"], signal)
-tflitemodel.set_tensor(
-    input_details[1]["index"],
-    tf.constant(text_featurizer.blank, dtype=tf.int32)
-)
-tflitemodel.set_tensor(
-    input_details[2]["index"],
-    tf.zeros([3, 2, 1, 2048], dtype=tf.float32)
-)
-tflitemodel.set_tensor(
-    input_details[3]["index"],
-    tf.zeros([2, 2, 1, 2048], dtype=tf.float32)
-)
-tflitemodel.invoke()
-hyp = tflitemodel.get_tensor(output_details[0]["index"])
-
-print(hyp)

From 4a97b8748349c12b8161da21db91b5d63e836687 Mon Sep 17 00:00:00 2001
From: Huy Le Nguyen
Date: Sun, 27 Dec 2020 13:17:02 +0700
Subject: [PATCH 2/2] :rocket: release v0.6.1

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 3bfb61767d..a56eb5a197 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@
 setuptools.setup(
     name="TensorFlowASR",
-    version="0.6.0",
+    version="0.6.1",
     author="Huy Le Nguyen",
     author_email="nlhuy.cs.16@gmail.com",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
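
Note on the `joint_activation` option introduced in PATCH 1/2: TransducerJoint resolves the
config string to a callable once in __init__ and then applies it to the sum of the projected
encoder and prediction outputs (enc_out + pred_out) before the final vocabulary projection.
The sketch below restates that mapping as a standalone function so it can be tried in
isolation; `resolve_joint_activation` is an illustrative helper name, not part of the patch.

import tensorflow as tf

def resolve_joint_activation(activation: str = "tanh"):
    # Same mapping as the one added to TransducerJoint.__init__ in PATCH 1/2:
    # a config string ("linear", "relu" or "tanh") selects the joint activation.
    activation = activation.lower()
    if activation == "linear": return tf.keras.activations.linear
    if activation == "relu": return tf.nn.relu
    if activation == "tanh": return tf.nn.tanh
    raise ValueError("activation must be either 'linear', 'relu' or 'tanh'")

# e.g. `joint_activation: tanh` in config.yml flows through Transducer into
# TransducerJoint(activation="tanh"):
act = resolve_joint_activation("tanh")
print(act(tf.constant([-2.0, 0.0, 2.0])))  # tanh bounds the joint logits to (-1, 1)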