diff --git a/examples/deepspeech2/README.md b/examples/deepspeech2/README.md
index 7ececc8b3e..1b46916bf5 100755
--- a/examples/deepspeech2/README.md
+++ b/examples/deepspeech2/README.md
@@ -6,22 +6,19 @@ References: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595)
```yaml
model_config:
- conv_conf:
- conv_type: 2
- conv_kernels: [[11, 41], [11, 21], [11, 11]]
- conv_strides: [[2, 2], [1, 2], [1, 2]]
- conv_filters: [32, 32, 96]
- conv_dropout: 0
- rnn_conf:
- rnn_layers: 5
- rnn_type: lstm
- rnn_units: 512
- rnn_bidirectional: True
- rnn_rowconv: False
- rnn_dropout: 0
- fc_conf:
- fc_units: [1024]
- fc_dropout: 0
+ conv_type: conv2d
+ conv_kernels: [[11, 41], [11, 21], [11, 11]]
+ conv_strides: [[2, 2], [1, 2], [1, 2]]
+ conv_filters: [32, 32, 96]
+ conv_dropout: 0.1
+ rnn_nlayers: 5
+ rnn_type: lstm
+ rnn_units: 512
+ rnn_bidirectional: True
+ rnn_rowconv: 0
+ rnn_dropout: 0.1
+ fc_nlayers: 0
+ fc_units: 1024
```
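+
+Note that the config is now flat: these keys map one-to-one onto the `DeepSpeech2` constructor, so the scripts simply unpack the dict. A minimal sketch of building the model this way (the `config.yml` path here is a placeholder):
+
+```python
+from tensorflow_asr.configs.user_config import UserConfig
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.models.deepspeech2 import DeepSpeech2
+
+config = UserConfig("config.yml", "config.yml", learning=False)
+speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
+text_featurizer = CharFeaturizer(config["decoder_config"])
+
+# model_config keys become keyword arguments; the output dimension comes from the vocabulary
+model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
+model._build(speech_featurizer.shape)
+model.summary(line_length=120)
+```
+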
## Architecture
@@ -30,24 +27,6 @@ model_config:
## Training and Testing
-See `python examples/deepspeech2/run_ds2.py --help`
+See `python examples/deepspeech2/train_ds2.py --help`
-## Results on VIVOS Dataset
-
-* Features: Spectrogram with `80` frequency channels
-* KenLM: `alpha = 2.0` and `beta = 1.0`
-* Epochs: `20`
-* Train set split ratio: `90:10`
-* Augmentation: `None`
-* Model architecture: same as [vivos.yaml](./configs/vivos.yml)
-
-**CTC Loss**
-
-
-
-**Error rates**
-
-| | WER (%) | CER (%) |
-| :-------------- | :------------: | :------------: |
-| *BeamSearch* | 43.75243 | 17.991581 |
-| *BeamSearch LM* | **20.7561836** | **11.0304441** |
\ No newline at end of file
+See `python examples/deepspeech2/test_ds2.py --help`
\ No newline at end of file
diff --git a/examples/deepspeech2/configs/vivos.yml b/examples/deepspeech2/config.yml
similarity index 76%
rename from examples/deepspeech2/configs/vivos.yml
rename to examples/deepspeech2/config.yml
index 60ecbb785e..ee43e06404 100755
--- a/examples/deepspeech2/configs/vivos.yml
+++ b/examples/deepspeech2/config.yml
@@ -24,7 +24,7 @@ speech_config:
normalize_per_feature: False
decoder_config:
- vocabulary: /mnt/Projects/asrk16/TiramisuASR/vocabularies/vietnamese.txt
+ vocabulary: ./vocabularies/vietnamese.characters
blank_at_zero: False
beam_width: 500
lm_config:
@@ -33,21 +33,20 @@ decoder_config:
beta: 1.0
model_config:
- conv_conf:
- conv_type: 2
- conv_kernels: [[11, 41], [11, 21], [11, 11]]
- conv_strides: [[2, 2], [1, 2], [1, 2]]
- conv_filters: [32, 32, 96]
- conv_dropout: 0
- rnn_conf:
- rnn_layers: 5
- rnn_type: lstm
- rnn_units: 512
- rnn_bidirectional: True
- rnn_rowconv: False
- rnn_dropout: 0
- fc_conf:
- fc_units: null
+ name: deepspeech2
+ conv_type: conv2d
+ conv_kernels: [[11, 41], [11, 21], [11, 11]]
+ conv_strides: [[2, 2], [1, 2], [1, 2]]
+ conv_filters: [32, 32, 96]
+ conv_dropout: 0.1
+ rnn_nlayers: 5
+ rnn_type: lstm
+ rnn_units: 512
+ rnn_bidirectional: True
+ rnn_rowconv: 0
+ rnn_dropout: 0.1
+ fc_nlayers: 0
+ fc_units: 1024
learning_config:
augmentations: null
diff --git a/examples/deepspeech2/figs/ds2_vivos_ctc_loss.svg b/examples/deepspeech2/figs/ds2_vivos_ctc_loss.svg
deleted file mode 100755
index 9af3bd046d..0000000000
--- a/examples/deepspeech2/figs/ds2_vivos_ctc_loss.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/examples/deepspeech2/model.py b/examples/deepspeech2/model.py
deleted file mode 100755
index f25e0ab858..0000000000
--- a/examples/deepspeech2/model.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Read https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM
-to use cuDNN-LSTM
-"""
-import numpy as np
-import tensorflow as tf
-
-from tensorflow_asr.utils.utils import append_default_keys_dict, get_rnn
-from tensorflow_asr.models.layers.row_conv_1d import RowConv1D
-from tensorflow_asr.models.layers.sequence_wise_bn import SequenceBatchNorm
-from tensorflow_asr.models.layers.transpose_time_major import TransposeTimeMajor
-from tensorflow_asr.models.layers.merge_two_last_dims import Merge2LastDims
-from tensorflow_asr.models.ctc import CtcModel
-
-DEFAULT_CONV = {
- "conv_type": 2,
- "conv_kernels": ((11, 41), (11, 21), (11, 21)),
- "conv_strides": ((2, 2), (1, 2), (1, 2)),
- "conv_filters": (32, 32, 96),
- "conv_dropout": 0.2
-}
-
-DEFAULT_RNN = {
- "rnn_layers": 3,
- "rnn_type": "gru",
- "rnn_units": 350,
- "rnn_activation": "tanh",
- "rnn_bidirectional": True,
- "rnn_rowconv": False,
- "rnn_rowconv_context": 2,
- "rnn_dropout": 0.2
-}
-
-DEFAULT_FC = {
- "fc_units": (1024,),
- "fc_dropout": 0.2
-}
-
-
-def create_ds2(input_shape: list, arch_config: dict, name: str = "deepspeech2"):
- conv_conf = append_default_keys_dict(DEFAULT_CONV, arch_config.get("conv_conf", {}))
- rnn_conf = append_default_keys_dict(DEFAULT_RNN, arch_config.get("rnn_conf", {}))
- fc_conf = append_default_keys_dict(DEFAULT_FC, arch_config.get("fc_conf", {}))
- assert len(conv_conf["conv_strides"]) == \
- len(conv_conf["conv_filters"]) == len(conv_conf["conv_kernels"])
- assert conv_conf["conv_type"] in [1, 2]
- assert rnn_conf["rnn_type"] in ["lstm", "gru", "rnn"]
- assert conv_conf["conv_dropout"] >= 0.0 and rnn_conf["rnn_dropout"] >= 0.0
-
- features = tf.keras.Input(shape=input_shape, name="features")
- layer = features
-
- if conv_conf["conv_type"] == 2:
- conv = tf.keras.layers.Conv2D
- else:
- layer = Merge2LastDims("conv1d_features")(layer)
- conv = tf.keras.layers.Conv1D
- ker_shape = np.shape(conv_conf["conv_kernels"])
- stride_shape = np.shape(conv_conf["conv_strides"])
- filter_shape = np.shape(conv_conf["conv_filters"])
- assert len(ker_shape) == 1 and len(stride_shape) == 1 and len(filter_shape) == 1
-
- # CONV Layers
- for i, fil in enumerate(conv_conf["conv_filters"]):
- layer = conv(filters=fil, kernel_size=conv_conf["conv_kernels"][i],
- strides=conv_conf["conv_strides"][i], padding="same",
- activation=None, dtype=tf.float32, name=f"cnn_{i}")(layer)
- layer = tf.keras.layers.BatchNormalization(name=f"cnn_bn_{i}")(layer)
- layer = tf.keras.layers.ReLU(name=f"cnn_relu_{i}")(layer)
- layer = tf.keras.layers.Dropout(conv_conf["conv_dropout"],
- name=f"cnn_dropout_{i}")(layer)
-
- if conv_conf["conv_type"] == 2:
- layer = Merge2LastDims("reshape_conv2d_to_rnn")(layer)
-
- rnn = get_rnn(rnn_conf["rnn_type"])
-
- # To time major
- if rnn_conf["rnn_bidirectional"]:
- layer = TransposeTimeMajor("transpose_to_time_major")(layer)
-
- # RNN layers
- for i in range(rnn_conf["rnn_layers"]):
- if rnn_conf["rnn_bidirectional"]:
- layer = tf.keras.layers.Bidirectional(
- rnn(rnn_conf["rnn_units"], activation=rnn_conf["rnn_activation"],
- time_major=True, dropout=rnn_conf["rnn_dropout"],
- return_sequences=True, use_bias=True),
- name=f"b{rnn_conf['rnn_type']}_{i}")(layer)
- layer = SequenceBatchNorm(time_major=True, name=f"sequence_wise_bn_{i}")(layer)
- else:
- layer = rnn(rnn_conf["rnn_units"], activation=rnn_conf["rnn_activation"],
- dropout=rnn_conf["rnn_dropout"], return_sequences=True, use_bias=True,
- name=f"{rnn_conf['rnn_type']}_{i}")(layer)
- layer = SequenceBatchNorm(time_major=False, name=f"sequence_wise_bn_{i}")(layer)
- if rnn_conf["rnn_rowconv"]:
- layer = RowConv1D(filters=rnn_conf["rnn_units"],
- future_context=rnn_conf["rnn_rowconv_context"],
- name=f"row_conv_{i}")(layer)
-
- # To batch major
- if rnn_conf["rnn_bidirectional"]:
- layer = TransposeTimeMajor("transpose_to_batch_major")(layer)
-
- # FC Layers
- if fc_conf["fc_units"]:
- assert fc_conf["fc_dropout"] >= 0.0
-
- for idx, units in enumerate(fc_conf["fc_units"]):
- layer = tf.keras.layers.Dense(units=units, activation=None,
- use_bias=True, name=f"hidden_fc_{idx}")(layer)
- layer = tf.keras.layers.BatchNormalization(name=f"hidden_fc_bn_{idx}")(layer)
- layer = tf.keras.layers.ReLU(name=f"hidden_fc_relu_{idx}")(layer)
- layer = tf.keras.layers.Dropout(fc_conf["fc_dropout"],
- name=f"hidden_fc_dropout_{idx}")(layer)
-
- return tf.keras.Model(inputs=features, outputs=layer, name=name)
-
-
-class DeepSpeech2(CtcModel):
- def __init__(self,
- input_shape: list,
- arch_config: dict,
- num_classes: int,
- name: str = "deepspeech2"):
- super(DeepSpeech2, self).__init__(
- base_model=create_ds2(input_shape=input_shape,
- arch_config=arch_config,
- name=name),
- num_classes=num_classes,
- name=f"{name}_ctc"
- )
- self.time_reduction_factor = 1
- for s in arch_config["conv_conf"]["conv_strides"]:
- self.time_reduction_factor *= s[0]
diff --git a/examples/deepspeech2/test_ds2.py b/examples/deepspeech2/test_ds2.py
index f86ae1dfd1..f8741d32fe 100644
--- a/examples/deepspeech2/test_ds2.py
+++ b/examples/deepspeech2/test_ds2.py
@@ -19,7 +19,7 @@
setup_environment()
import tensorflow as tf
-DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "configs", "vivos.yml")
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
tf.keras.backend.clear_session()
@@ -54,7 +54,7 @@
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.base_runners import BaseTester
-from model import DeepSpeech2
+from tensorflow_asr.models.deepspeech2 import DeepSpeech2
tf.random.set_seed(0)
-assert args.export
+assert args.saved
@@ -63,13 +63,10 @@
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])
# Build DS2 model
-ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
- arch_config=config["model_config"],
- num_classes=text_featurizer.num_classes,
- name="deepspeech2")
+ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
ds2_model._build(speech_featurizer.shape)
ds2_model.load_weights(args.saved, by_name=True)
-ds2_model.summary(line_length=150)
+ds2_model.summary(line_length=120)
ds2_model.add_featurizers(speech_featurizer, text_featurizer)
if args.tfrecords:
diff --git a/examples/deepspeech2/train_ds2.py b/examples/deepspeech2/train_ds2.py
index 3d5f15bd5b..21b74f3868 100644
--- a/examples/deepspeech2/train_ds2.py
+++ b/examples/deepspeech2/train_ds2.py
@@ -19,7 +19,7 @@
setup_environment()
import tensorflow as tf
-DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "configs", "vivos.yml")
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
tf.keras.backend.clear_session()
@@ -60,7 +60,7 @@
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.ctc_runners import CTCTrainer
-from model import DeepSpeech2
+from tensorflow_asr.models.deepspeech2 import DeepSpeech2
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
@@ -100,12 +100,9 @@
ctc_trainer = CTCTrainer(text_featurizer, config["learning_config"]["running_config"])
# Build DS2 model
with ctc_trainer.strategy.scope():
- ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
- arch_config=config["model_config"],
- num_classes=text_featurizer.num_classes,
- name="deepspeech2")
+ ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
ds2_model._build(speech_featurizer.shape)
- ds2_model.summary(line_length=150)
+ ds2_model.summary(line_length=120)
# Compile
ctc_trainer.compile(ds2_model, config["learning_config"]["optimizer_config"],
max_to_keep=args.max_ckpts)
diff --git a/examples/jasper/README.md b/examples/jasper/README.md
new file mode 100755
index 0000000000..1adfadeeeb
--- /dev/null
+++ b/examples/jasper/README.md
@@ -0,0 +1,20 @@
+# Jasper
+
+References: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288)
+
+## Model YAML Config Structure
+
+```yaml
+model_config:
+  dense: True
+  first_additional_block_channels: 256
+  first_additional_block_kernels: 11
+  first_additional_block_strides: 2
+  first_additional_block_dilation: 1
+  first_additional_block_dropout: 0.2
+  nsubblocks: 3
+  block_channels: [256, 384, 512, 640, 768]
+  block_kernels: [11, 13, 17, 21, 25]
+  block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3]
+  second_additional_block_channels: 896
+  second_additional_block_kernels: 1
+  second_additional_block_strides: 1
+  second_additional_block_dilation: 2
+  second_additional_block_dropout: 0.4
+  third_additional_block_channels: 1024
+  third_additional_block_kernels: 1
+  third_additional_block_strides: 1
+  third_additional_block_dilation: 1
+  third_additional_block_dropout: 0.4
+```
+
+## Architecture
+
+![Jasper architecture](./figs/jasper_arch.png)
+
+## Training and Testing
+
+See `python examples/jasper/train_jasper.py --help`
+
+See `python examples/jasper/test_jasper.py --help`
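+
+As with DeepSpeech 2, the flattened `model_config` keys are unpacked straight into the `Jasper` constructor. A minimal sketch (the `config.yml` path here is a placeholder):
+
+```python
+from tensorflow_asr.configs.user_config import UserConfig
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
+from tensorflow_asr.models.jasper import Jasper
+
+config = UserConfig("config.yml", "config.yml", learning=False)
+speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
+text_featurizer = CharFeaturizer(config["decoder_config"])
+
+jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
+jasper._build(speech_featurizer.shape)
+jasper.summary(line_length=120)
+```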
diff --git a/examples/sadeepspeech2/config.yml b/examples/jasper/config.yml
similarity index 57%
rename from examples/sadeepspeech2/config.yml
rename to examples/jasper/config.yml
index e76f4cd8c6..d6e62d8dbb 100755
--- a/examples/sadeepspeech2/config.yml
+++ b/examples/jasper/config.yml
@@ -24,7 +24,7 @@ speech_config:
normalize_per_feature: False
decoder_config:
- vocabulary: /mnt/Projects/asrk16/TiramisuASR/vocabularies/vietnamese.txt
+ vocabulary: ./vocabularies/vietnamese.characters
blank_at_zero: False
beam_width: 500
lm_config:
@@ -33,20 +33,27 @@ decoder_config:
beta: 1.0
model_config:
- subsampling:
- filters: 144
- kernel_size: 32
- strides: 2
- att:
- layers: 16
- head_size: 36
- num_heads: 4
- ffn_size: 1024
- dropout: 0
- rnn:
- layers: 1
- units: 320
- dropout: 0
+ name: jasper
+ dense: True
+ first_additional_block_channels: 256
+ first_additional_block_kernels: 11
+ first_additional_block_strides: 2
+ first_additional_block_dilation: 1
+ first_additional_block_dropout: 0.2
+ nsubblocks: 3
+ block_channels: [256, 384, 512, 640, 768]
+ block_kernels: [11, 13, 17, 21, 25]
+ block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3]
+ second_additional_block_channels: 896
+ second_additional_block_kernels: 1
+ second_additional_block_strides: 1
+ second_additional_block_dilation: 2
+ second_additional_block_dropout: 0.4
+ third_additional_block_channels: 1024
+ third_additional_block_kernels: 1
+ third_additional_block_strides: 1
+ third_additional_block_dilation: 1
+ third_additional_block_dropout: 0.4
learning_config:
augmentations: null
@@ -61,14 +68,14 @@ learning_config:
tfrecords_dir: /mnt/Data/ML/ASR/Preprocessed/Vivos/TFRecords
optimizer_config:
- name: transformer_adam
+ class_name: adam
config:
- warmup_steps: 10000
+ learning_rate: 0.0001
running_config:
- batch_size: 2
+ batch_size: 8
num_epochs: 20
- outdir: /mnt/Projects/asrk16/trained/local/vivos_self_att_ds2
- log_interval_steps: 500
- save_interval_steps: 500
- eval_interval_steps: 700
+ outdir: /mnt/Projects/asrk16/trained/local/jasper
+ log_interval_steps: 400
+ save_interval_steps: 400
+ eval_interval_steps: 800
diff --git a/examples/jasper/figs/jasper_arch.png b/examples/jasper/figs/jasper_arch.png
new file mode 100644
index 0000000000..3ec60a7189
Binary files /dev/null and b/examples/jasper/figs/jasper_arch.png differ
diff --git a/examples/sadeepspeech2/test_sadeepspeech2.py b/examples/jasper/test_jasper.py
similarity index 66%
rename from examples/sadeepspeech2/test_sadeepspeech2.py
rename to examples/jasper/test_jasper.py
index 9040eee45d..c50294580a 100644
--- a/examples/sadeepspeech2/test_sadeepspeech2.py
+++ b/examples/jasper/test_jasper.py
@@ -1,3 +1,17 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_devices
@@ -7,16 +21,18 @@
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
-parser = argparse.ArgumentParser(prog="Self Attention DS2")
+tf.keras.backend.clear_session()
+
+parser = argparse.ArgumentParser(prog="Jasper Testing")
parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
help="The file path of model configuration file")
parser.add_argument("--saved", type=str, default=None,
- help="Path to saved model")
+                    help="Path to saved model weights")
parser.add_argument("--tfrecords", default=False, action="store_true",
- help="Whether to use tfrecords")
+ help="Whether to use tfrecords dataset")
parser.add_argument("--mxp", default=False, action="store_true",
help="Enable mixed precision")
@@ -33,34 +49,25 @@
setup_devices([args.device])
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
-from model import SelfAttentionDS2
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.base_runners import BaseTester
-from ctc_decoders import Scorer
+from tensorflow_asr.models.jasper import Jasper
tf.random.set_seed(0)
assert args.saved
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])
-
-text_featurizer.add_scorer(Scorer(**text_featurizer.decoder_config["lm_config"],
- vocabulary=text_featurizer.vocab_array))
-
-# Build DS2 model
+# Build Jasper model
-satt_ds2_model = SelfAttentionDS2(
- input_shape=speech_featurizer.shape,
- arch_config=config["model_config"],
- num_classes=text_featurizer.num_classes
-)
-satt_ds2_model._build(speech_featurizer.shape)
-satt_ds2_model.load_weights(args.saved, by_name=True)
-satt_ds2_model.summary(line_length=150)
-satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)
+jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
+jasper._build(speech_featurizer.shape)
+jasper.load_weights(args.saved, by_name=True)
+jasper.summary(line_length=120)
+jasper.add_featurizers(speech_featurizer, text_featurizer)
if args.tfrecords:
test_dataset = ASRTFRecordDataset(
@@ -82,5 +89,5 @@
config=config["learning_config"]["running_config"],
output_name=args.output_name
)
-ctc_tester.compile(satt_ds2_model)
+ctc_tester.compile(jasper)
ctc_tester.run(test_dataset)
diff --git a/examples/sadeepspeech2/train_sadeepspeech2.py b/examples/jasper/train_jasper.py
old mode 100755
new mode 100644
similarity index 81%
rename from examples/sadeepspeech2/train_sadeepspeech2.py
rename to examples/jasper/train_jasper.py
index 188404f646..9e65f2c739
--- a/examples/sadeepspeech2/train_sadeepspeech2.py
+++ b/examples/jasper/train_jasper.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_strategy
@@ -22,7 +23,7 @@
tf.keras.backend.clear_session()
-parser = argparse.ArgumentParser(prog="Self Attention DS2")
+parser = argparse.ArgumentParser(prog="Jasper Training")
parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
help="The file path of model configuration file")
@@ -30,15 +31,15 @@
parser.add_argument("--max_ckpts", type=int, default=10,
help="Max number of checkpoints to keep")
-parser.add_argument("--tfrecords", default=False, action="store_true",
- help="Whether to use tfrecords")
-
parser.add_argument("--tbs", type=int, default=None,
help="Train batch size per replicas")
parser.add_argument("--ebs", type=int, default=None,
help="Evaluation batch size per replicas")
+parser.add_argument("--tfrecords", default=False, action="store_true",
+ help="Whether to use tfrecords dataset")
+
parser.add_argument("--devices", type=int, nargs="*", default=[0],
help="Devices' ids to apply distributed training")
@@ -59,17 +60,12 @@
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.ctc_runners import CTCTrainer
-from model import SelfAttentionDS2
-from optimizer import create_optimizer
+from tensorflow_asr.models.jasper import Jasper
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])
-ctc_trainer = CTCTrainer(text_featurizer,
- config["learning_config"]["running_config"],
- strategy=strategy)
-
if args.tfrecords:
train_dataset = ASRTFRecordDataset(
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
@@ -88,34 +84,27 @@
)
else:
train_dataset = ASRSliceDataset(
- data_paths=config["learning_config"]["dataset_config"]["train_paths"],
speech_featurizer=speech_featurizer,
text_featurizer=text_featurizer,
+ data_paths=config["learning_config"]["dataset_config"]["train_paths"],
augmentations=config["learning_config"]["augmentations"],
stage="train", cache=args.cache, shuffle=True
)
eval_dataset = ASRSliceDataset(
- data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
speech_featurizer=speech_featurizer,
text_featurizer=text_featurizer,
+ data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
stage="eval", cache=args.cache, shuffle=True
)
+ctc_trainer = CTCTrainer(text_featurizer, config["learning_config"]["running_config"])
-# Build DS2 model
+# Build Jasper model
with ctc_trainer.strategy.scope():
- satt_ds2_model = SelfAttentionDS2(
- input_shape=speech_featurizer.shape,
- arch_config=config["model_config"],
- num_classes=text_featurizer.num_classes
- )
- satt_ds2_model._build(speech_featurizer.shape)
- satt_ds2_model.summary(line_length=150)
- optimizer = create_optimizer(
- name=config["learning_config"]["optimizer_config"]["name"],
- d_model=config["model_config"]["att"]["head_size"],
- **config["learning_config"]["optimizer_config"]["config"]
- )
+ jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
+ jasper._build(speech_featurizer.shape)
+ jasper.summary(line_length=120)
# Compile
-ctc_trainer.compile(satt_ds2_model, optimizer, max_to_keep=args.max_ckpts)
+ctc_trainer.compile(jasper, config["learning_config"]["optimizer_config"],
+ max_to_keep=args.max_ckpts)
ctc_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
diff --git a/examples/sadeepspeech2/README.md b/examples/sadeepspeech2/README.md
deleted file mode 100755
index b7e14c9f11..0000000000
--- a/examples/sadeepspeech2/README.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Custom Self Attention Deep Speech 2
-
-My customized model based on Deep Speech 2 by adding Multihead Self Attention layers (With Additional Residual) between each Convolution and Batch Norm
-
-
-
-## Model YAML Config Structure
-
-```yaml
-model_config:
- subsampling:
- filters: 144
- kernel_size: 32
- strides: 2
- att:
- layers: 16
- head_size: 36
- num_heads: 4
- ffn_size: 1024
- dropout: 0.1
- rnn:
- layers: 1
- units: 320
- dropout: 0
-```
-
-## Training and Testing
-
-See `python examples/self_attention_ds2/run_sattds2.py --help`
-
-## Results on Vivos
-
-* Features: Log Mel Spectrogram with `80` frequency channels
-* KenLM: `alpha = 2.0` and `beta = 1.0`
-* Epochs: `20`
-* Train set split ratio: `90:10`
-* Augmentation: `None`
-* Model architecture: same as above YAML
-
-**CTC Loss**
-
-
-
-**Error rates**
-
-| | WER (%) | CER (%) |
-| :-------------- | :------------: | :------------: |
-| *BeamSearch* | 45.09 | 17.775 |
-| *BeamSearch LM* | **19.36** | **9.94** |
diff --git a/examples/sadeepspeech2/each_layer_histogram.py b/examples/sadeepspeech2/each_layer_histogram.py
deleted file mode 100644
index d692dd69e8..0000000000
--- a/examples/sadeepspeech2/each_layer_histogram.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import os
-import argparse
-
-from tensorflow_asr.utils import setup_environment
-
-setup_environment()
-import matplotlib.pyplot as plt
-import tensorflow as tf
-import numpy as np
-
-from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer, read_raw_audio
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-from tensorflow_asr.configs.user_config import UserConfig
-from tensorflow_asr.utils.utils import bytes_to_string
-from ctc_decoders import Scorer
-from model import SelfAttentionDS2
-
-
-def main():
- parser = argparse.ArgumentParser(prog="SelfAttentionDS2 Histogram")
-
- parser.add_argument("--config", type=str, default=None,
- help="Config file")
-
- parser.add_argument("--audio", type=str, default=None,
- help="Audio file")
-
- parser.add_argument("--saved_model", type=str, default=None,
- help="Saved model")
-
- parser.add_argument("--from_weights", type=bool, default=False,
- help="Load from weights")
-
- parser.add_argument("--output", type=str, default=None,
- help="Output dir storing histograms")
-
- args = parser.parse_args()
-
- config = UserConfig(args.config, args.config, learning=False)
- speech_featurizer = SpeechFeaturizer(config["speech_config"])
- text_featurizer = CharFeaturizer(config["decoder_config"])
- text_featurizer.add_scorer(Scorer(**text_featurizer.decoder_config["lm_config"],
- vocabulary=text_featurizer.vocab_array))
-
- f, c = speech_featurizer.compute_feature_dim()
- satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
- arch_config=config["model_config"],
- num_classes=text_featurizer.num_classes)
- satt_ds2_model._build([1, 50, f, c])
-
- if args.from_weights:
- satt_ds2_model.load_weights(args.saved_model)
- else:
- saved_model = tf.keras.models.load_model(args.saved_model)
- satt_ds2_model.set_weights(saved_model.get_weights())
-
- satt_ds2_model.summary(line_length=100)
-
- satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)
-
- signal = read_raw_audio(args.audio, speech_featurizer.sample_rate)
- features = speech_featurizer.extract(signal)
- decoded = satt_ds2_model.recognize_beam(tf.expand_dims(features, 0), lm=True)
- print(bytes_to_string(decoded.numpy()))
-
- for i in range(1, len(satt_ds2_model.base_model.layers)):
- func = tf.keras.backend.function([satt_ds2_model.base_model.input],
- [satt_ds2_model.base_model.layers[i].output])
- data = func([np.expand_dims(features, 0), 1])[0][0]
- print(data.shape)
- data = data.flatten()
- plt.hist(data, 200, color='green', histtype="stepfilled")
- plt.title(f"Output of {satt_ds2_model.base_model.layers[i].name}", fontweight="bold")
- plt.savefig(os.path.join(
- args.output, f"{i}_{satt_ds2_model.base_model.layers[i].name}.png"))
- plt.clf()
- plt.cla()
- plt.close()
-
- fc = satt_ds2_model(tf.expand_dims(features, 0), training=False)
- plt.hist(fc[0].numpy().flatten(), 200, color="green", histtype="stepfilled")
- plt.title(f"Output of {satt_ds2_model.layers[-1].name}", fontweight="bold")
- plt.savefig(os.path.join(args.output, f"{satt_ds2_model.layers[-1].name}.png"))
- plt.clf()
- plt.cla()
- plt.close()
- fc = tf.nn.softmax(fc)
- plt.hist(fc[0].numpy().flatten(), 10, color="green", histtype="stepfilled")
- plt.title("Output of softmax", fontweight="bold")
- plt.savefig(os.path.join(args.output, "softmax_hist.png"))
- plt.clf()
- plt.cla()
- plt.close()
- plt.hist(features.flatten(), 200, color="green", histtype="stepfilled")
- plt.title("Log Mel Spectrogram", fontweight="bold")
- plt.savefig(os.path.join(args.output, "log_mel_spectrogram.png"))
- plt.clf()
- plt.cla()
- plt.close()
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/sadeepspeech2/each_layer_imshow.py b/examples/sadeepspeech2/each_layer_imshow.py
deleted file mode 100644
index aa85acba6a..0000000000
--- a/examples/sadeepspeech2/each_layer_imshow.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import os
-import argparse
-
-from tensorflow_asr.utils import setup_environment
-
-setup_environment()
-import matplotlib.pyplot as plt
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-import tensorflow as tf
-import numpy as np
-
-from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer, read_raw_audio
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
-from tensorflow_asr.configs.user_config import UserConfig
-from tensorflow_asr.utils.utils import bytes_to_string
-from ctc_decoders import Scorer
-from model import SelfAttentionDS2
-
-
-def main():
- parser = argparse.ArgumentParser(prog="SelfAttentionDS2 Histogram")
-
- parser.add_argument("--config", type=str, default=None,
- help="Config file")
-
- parser.add_argument("--audio", type=str, default=None,
- help="Audio file")
-
- parser.add_argument("--saved_model", type=str, default=None,
- help="Saved model")
-
- parser.add_argument("--from_weights", type=bool, default=False,
- help="Load from weights")
-
- parser.add_argument("--output", type=str, default=None,
- help="Output dir storing histograms")
-
- args = parser.parse_args()
-
- config = UserConfig(args.config, args.config, learning=False)
- speech_featurizer = SpeechFeaturizer(config["speech_config"])
- text_featurizer = CharFeaturizer(config["decoder_config"])
- text_featurizer.add_scorer(Scorer(**text_featurizer.decoder_config["lm_config"],
- vocabulary=text_featurizer.vocab_array))
-
- f, c = speech_featurizer.compute_feature_dim()
- satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
- arch_config=config["model_config"],
- num_classes=text_featurizer.num_classes)
- satt_ds2_model._build([1, 50, f, c])
-
- if args.from_weights:
- satt_ds2_model.load_weights(args.saved_model)
- else:
- saved_model = tf.keras.models.load_model(args.saved_model)
- satt_ds2_model.set_weights(saved_model.get_weights())
-
- satt_ds2_model.summary(line_length=100)
-
- satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)
-
- signal = read_raw_audio(args.audio, speech_featurizer.sample_rate)
- features = speech_featurizer.extract(signal)
- decoded = satt_ds2_model.recognize_beam(tf.expand_dims(features, 0), lm=True)
- print(bytes_to_string(decoded.numpy()))
-
- # for i in range(1, len(satt_ds2_model.base_model.layers)):
- # func = tf.keras.backend.function([satt_ds2_model.base_model.input],
- # [satt_ds2_model.base_model.layers[i].output])
- # data = func([np.expand_dims(features, 0), 1])[0][0]
- # print(data.shape)
- # plt.figure(figsize=(16, 5))
- # ax = plt.gca()
- # im = ax.imshow(data.T, origin="lower", aspect="auto")
- # ax.set_title(f"{satt_ds2_model.base_model.layers[i].name}", fontweight="bold")
- # divider = make_axes_locatable(ax)
- # cax = divider.append_axes("right", size="5%", pad=0.05)
- # plt.colorbar(im, cax=cax)
- # plt.savefig(os.path.join(
- # args.output, f"{i}_{satt_ds2_model.base_model.layers[i].name}.png"))
- # plt.clf()
- # plt.cla()
- # plt.close()
-
- fc = satt_ds2_model(tf.expand_dims(features, 0), training=False)
- plt.figure(figsize=(16, 5))
- ax = plt.gca()
- ax.set_title(f"{satt_ds2_model.layers[-1].name}", fontweight="bold")
- im = ax.imshow(fc[0].numpy().T, origin="lower", aspect="auto")
- divider = make_axes_locatable(ax)
- cax = divider.append_axes("right", size="5%", pad=0.05)
- plt.colorbar(im, cax=cax)
- plt.savefig(os.path.join(args.output, f"{satt_ds2_model.layers[-1].name}.png"))
- plt.clf()
- plt.cla()
- plt.close()
- fc = tf.nn.softmax(fc)
- plt.figure(figsize=(16, 5))
- ax = plt.gca()
- ax.set_title("Softmax", fontweight="bold")
- im = ax.imshow(fc[0].numpy().T, origin="lower", aspect="auto")
- divider = make_axes_locatable(ax)
- cax = divider.append_axes("right", size="5%", pad=0.05)
- plt.colorbar(im, cax=cax)
- plt.savefig(os.path.join(args.output, "softmax.png"))
- plt.clf()
- plt.cla()
- plt.close()
- plt.figure(figsize=(16, 5))
- ax = plt.gca()
- ax.set_title("Log Mel Spectrogram", fontweight="bold")
- im = ax.imshow(features[:, :, 0].T, origin="lower", aspect="auto")
- divider = make_axes_locatable(ax)
- cax = divider.append_axes("right", size="5%", pad=0.05)
- plt.colorbar(im, cax=cax)
- plt.savefig(os.path.join(args.output, "features.png"))
- plt.clf()
- plt.cla()
- plt.close()
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/sadeepspeech2/figs/arch.png b/examples/sadeepspeech2/figs/arch.png
deleted file mode 100755
index 44c36ad7e7..0000000000
Binary files a/examples/sadeepspeech2/figs/arch.png and /dev/null differ
diff --git a/examples/sadeepspeech2/figs/vivos_ctc_loss.svg b/examples/sadeepspeech2/figs/vivos_ctc_loss.svg
deleted file mode 100644
index 314c3fb954..0000000000
--- a/examples/sadeepspeech2/figs/vivos_ctc_loss.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/examples/sadeepspeech2/model.py b/examples/sadeepspeech2/model.py
deleted file mode 100755
index ce0038d4e7..0000000000
--- a/examples/sadeepspeech2/model.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Read https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM
-to use cuDNN-LSTM
-"""
-import tensorflow as tf
-import tensorflow_addons as tfa
-
-from tensorflow_asr.models.layers.positional_encoding import PositionalEncoding
-from tensorflow_asr.models.layers.point_wise_ffn import PointWiseFFN
-from tensorflow_asr.models.layers.sequence_wise_bn import SequenceBatchNorm
-from tensorflow_asr.utils.utils import merge_two_last_dims
-from tensorflow_asr.models.ctc import CtcModel
-
-ARCH_CONFIG = {
- "subsampling": {
- "filters": 256,
- "kernel_size": 32,
- "strides": 2
- },
- "att": {
- "layers": 2,
- "head_size": 256,
- "num_heads": 4,
- "ffn_size": 1024,
- "dropout": 0.1
- },
- "rnn": {
- "layers": 2,
- "units": 512,
- "dropout": 0.0
- },
-}
-
-
-def create_sattds2(input_shape: list,
- arch_config: dict,
- name: str = "self_attention_ds2"):
- features = tf.keras.Input(shape=input_shape, name="features")
- layer = merge_two_last_dims(features)
-
- layer = tf.keras.layers.Conv1D(filters=arch_config["subsampling"]["filters"],
- kernel_size=arch_config["subsampling"]["kernel_size"],
- strides=arch_config["subsampling"]["strides"],
- padding="same")(layer)
- layer = tf.keras.layers.BatchNormalization()(layer)
- layer = tf.keras.layers.ReLU()(layer)
-
- for i in range(arch_config["att"]["layers"]):
- ffn = tf.keras.layers.LayerNormalization()(layer)
-
- ffn = PointWiseFFN(size=arch_config["att"]["ffn_size"],
- output_size=layer.shape[-1],
- dropout=arch_config["att"]["dropout"],
- name=f"ffn1_{i}")(ffn)
- layer = tf.keras.layers.Add()([layer, 0.5 * ffn])
- layer = tf.keras.layers.LayerNormalization()(layer)
- pe = PositionalEncoding(name=f"pos_enc_{i}")(layer)
- att = tf.keras.layers.Add(name=f"pos_enc_add_{i}")([layer, pe])
- att = tfa.layers.MultiHeadAttention(head_size=arch_config["att"]["head_size"],
- num_heads=arch_config["att"]["num_heads"],
- name=f"mulhead_satt_{i}")([att, att, att])
- att = tf.keras.layers.Dropout(arch_config["att"]["dropout"],
- name=f"mhsa_dropout_{i}")(att)
- layer = tf.keras.layers.Add()([layer, att])
- ffn = tf.keras.layers.LayerNormalization()(layer)
- ffn = PointWiseFFN(size=arch_config["att"]["ffn_size"],
- output_size=layer.shape[-1],
- dropout=arch_config["att"]["dropout"],
- name=f"ffn2_{i}")(ffn)
- layer = tf.keras.layers.Add()([layer, 0.5 * ffn])
-
- output = tf.keras.layers.LayerNormalization()(layer)
-
- for i in range(arch_config["rnn"]["layers"]):
- output = tf.keras.layers.Bidirectional(
- tf.keras.layers.LSTM(units=arch_config["rnn"]["units"],
- dropout=arch_config["rnn"]["dropout"],
- return_sequences=True))(output)
- output = SequenceBatchNorm(time_major=False, name=f"seq_bn_{i}")(output)
-
- return tf.keras.Model(inputs=features, outputs=output, name=name)
-
-
-class SelfAttentionDS2(CtcModel):
- def __init__(self,
- input_shape: list,
- arch_config: dict,
- num_classes: int,
- name: str = "self_attention_ds2"):
- super(SelfAttentionDS2, self).__init__(
- base_model=create_sattds2(input_shape=input_shape,
- arch_config=arch_config,
- name=name),
- num_classes=num_classes,
- name=f"{name}_ctc"
- )
- self.time_reduction_factor = arch_config["subsampling"]["strides"]
diff --git a/examples/sadeepspeech2/optimizer.py b/examples/sadeepspeech2/optimizer.py
deleted file mode 100755
index fbfe220063..0000000000
--- a/examples/sadeepspeech2/optimizer.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-
-from tensorflow_asr.optimizers.schedules import TransformerSchedule, SANSchedule
-
-
-def create_optimizer(name, d_model, lamb=0.05, warmup_steps=4000):
- if name == "transformer_adam":
- learning_rate = TransformerSchedule(d_model, warmup_steps)
- optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9,
- beta_2=0.98, epsilon=1e-9)
-
- elif name == "transformer_sgd":
- learning_rate = TransformerSchedule(d_model, warmup_steps)
- optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.99, nesterov=True)
-
- elif name == "san":
- learning_rate = SANSchedule(lamb, d_model, warmup_steps)
- optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.99, nesterov=True)
-
- else:
- raise ValueError("optimizer name must be either 'transformer' or 'san'")
-
- return optimizer
diff --git a/tensorflow_asr/models/ctc.py b/tensorflow_asr/models/ctc.py
index 35f98377c7..1431ea1a7b 100644
--- a/tensorflow_asr/models/ctc.py
+++ b/tensorflow_asr/models/ctc.py
@@ -15,48 +15,28 @@
import numpy as np
import tensorflow as tf
-from ctc_decoders import ctc_greedy_decoder, ctc_beam_search_decoder
-
+from . import Model
from ..featurizers.speech_featurizers import TFSpeechFeaturizer
from ..featurizers.text_featurizers import TextFeaturizer
from ..utils.utils import shape_list
-class CtcModel(tf.keras.Model):
- def __init__(self,
- base_model: tf.keras.Model,
- num_classes: int,
- name="ctc_model",
- **kwargs):
- super(CtcModel, self).__init__(name=name, **kwargs)
- self.base_model = base_model
- # Fully connected layer
- self.fc = tf.keras.layers.Dense(units=num_classes, activation="linear",
- use_bias=True, name=f"{name}_fc")
+class CtcModel(Model):
+ def __init__(self, **kwargs):
+ super(CtcModel, self).__init__(**kwargs)
def _build(self, input_shape):
features = tf.keras.Input(input_shape, dtype=tf.float32)
self(features, training=False)
- def summary(self, line_length=None, **kwargs):
- self.base_model.summary(line_length=line_length, **kwargs)
- super(CtcModel, self).summary(line_length, **kwargs)
-
def add_featurizers(self,
speech_featurizer: TFSpeechFeaturizer,
text_featurizer: TextFeaturizer):
self.speech_featurizer = speech_featurizer
self.text_featurizer = text_featurizer
- def call(self, inputs, training=False, **kwargs):
- outputs = self.base_model(inputs, training=training)
- outputs = self.fc(outputs, training=training)
- return outputs
-
- def get_config(self):
- config = self.base_model.get_config()
- config.update(self.fc.get_config())
- return config
+ def call(self, inputs, training=False):
+ raise NotImplementedError()
# -------------------------------- GREEDY -------------------------------------
@@ -72,6 +52,7 @@ def map_fn(prob):
return tf.map_fn(map_fn, probs, dtype=tf.string)
def perform_greedy(self, probs: np.ndarray):
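+        # deferred import: ctc_decoders is only required when this decode path actually runs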
+ from ctc_decoders import ctc_greedy_decoder
decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.vocab_array)
return tf.convert_to_tensor(decoded, dtype=tf.string)
@@ -114,6 +95,7 @@ def map_fn(prob):
def perform_beam_search(self,
probs: np.ndarray,
lm: bool = False):
+ from ctc_decoders import ctc_beam_search_decoder
decoded = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=self.text_featurizer.vocab_array,
@@ -125,6 +107,14 @@ def perform_beam_search(self,
return tf.convert_to_tensor(decoded, dtype=tf.string)
def recognize_beam_tflite(self, signal):
+ """
+        Function for converting this model to TFLite, decoding with beam search
+ Args:
+ signal: tf.Tensor with shape [None] indicating a single audio signal
+
+        Returns:
+ transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32
+ """
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
diff --git a/tensorflow_asr/models/deepspeech2.py b/tensorflow_asr/models/deepspeech2.py
new file mode 100644
index 0000000000..83f3e6b91e
--- /dev/null
+++ b/tensorflow_asr/models/deepspeech2.py
@@ -0,0 +1,311 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from ..utils.utils import get_rnn, get_conv, merge_two_last_dims
+from .layers.row_conv_1d import RowConv1D
+from .layers.sequence_wise_bn import SequenceBatchNorm
+from .ctc import CtcModel
+
+
+class Reshape(tf.keras.layers.Layer):
+ def call(self, inputs): return merge_two_last_dims(inputs)
+
+
+class ConvBlock(tf.keras.layers.Layer):
+ def __init__(self,
+ conv_type: str = "conv2d",
+ kernels: list = [11, 41],
+ strides: list = [2, 2],
+ filters: int = 32,
+ dropout: float = 0.1,
+ **kwargs):
+ super(ConvBlock, self).__init__(**kwargs)
+
+ CNN = get_conv(conv_type)
+ self.conv = CNN(filters=filters, kernel_size=kernels,
+ strides=strides, padding="same",
+ dtype=tf.float32, name=f"{self.name}_{conv_type}")
+ self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn")
+ self.relu = tf.keras.layers.ReLU(name=f"{self.name}_relu")
+ self.do = tf.keras.layers.Dropout(dropout, name=f"{self.name}_dropout")
+
+ def call(self, inputs, training=False):
+ outputs = self.conv(inputs, training=training)
+ outputs = self.bn(outputs, training=training)
+ outputs = self.relu(outputs, training=training)
+ outputs = self.do(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = super(ConvBlock, self).get_config()
+ conf.update(self.conv.get_config())
+ conf.update(self.bn.get_config())
+ conf.update(self.relu.get_config())
+ conf.update(self.do.get_config())
+ return conf
+
+
+class ConvModule(tf.keras.Model):
+ def __init__(self,
+ conv_type: str = "conv2d",
+ kernels: list = [[11, 41], [11, 21], [11, 21]],
+ strides: list = [[2, 2], [1, 2], [1, 2]],
+ filters: list = [32, 32, 96],
+ dropout: float = 0.1,
+ **kwargs):
+ super(ConvModule, self).__init__(**kwargs)
+
+ assert len(kernels) == len(strides) == len(filters)
+ assert dropout >= 0.0
+
+ self.preprocess = None # reshape from [B, T, F, C] to [B, T, F * C]
+ if conv_type == "conv1d": self.preprocess = Reshape(name=f"{self.name}_preprocess")
+
+ self.blocks = [
+ ConvBlock(
+ conv_type=conv_type,
+ kernels=kernels[i],
+ strides=strides[i],
+ filters=filters[i],
+ dropout=dropout,
+ name=f"{self.name}_block_{i}"
+ ) for i in range(len(filters))
+ ]
+
+ self.postprocess = None # reshape from [B, T, F, C] to [B, T, F * C]
+ if conv_type == "conv2d": self.postprocess = Reshape(name=f"{self.name}_postprocess")
+
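+        # overall time reduction is the product of each conv block's stride along the time axis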
+ self.reduction_factor = 1
+ for s in strides: self.reduction_factor *= s[0]
+
+ def call(self, inputs, training=False):
+ outputs = inputs
+ if self.preprocess is not None: outputs = self.preprocess(outputs)
+ for block in self.blocks:
+ outputs = block(outputs, training=training)
+ if self.postprocess is not None: outputs = self.postprocess(outputs)
+ return outputs
+
+ def get_config(self):
+ conf = {}
+        if self.preprocess is not None: conf.update(self.preprocess.get_config())
+ for block in self.blocks:
+ conf.update(block.get_config())
+        if self.postprocess is not None: conf.update(self.postprocess.get_config())
+ return conf
+
+
+class RnnBlock(tf.keras.layers.Layer):
+ def __init__(self,
+ rnn_type: str = "lstm",
+ units: int = 1024,
+ bidirectional: bool = True,
+ rowconv: int = 0,
+ dropout: float = 0.1,
+ **kwargs):
+ super(RnnBlock, self).__init__(**kwargs)
+
+ RNN = get_rnn(rnn_type)
+ self.rnn = RNN(units, dropout=dropout, return_sequences=True,
+ use_bias=True, name=f"{self.name}_{rnn_type}")
+ if bidirectional:
+ self.rnn = tf.keras.layers.Bidirectional(self.rnn, name=f"{self.name}_b{rnn_type}")
+ self.bn = SequenceBatchNorm(time_major=False, name=f"{self.name}_bn")
+ self.rowconv = None
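+        # row convolution (future-context lookahead from the DS2 paper) only applies to unidirectional RNNs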
+ if not bidirectional and rowconv > 0:
+ self.rowconv = RowConv1D(filters=units, future_context=rowconv,
+ name=f"{self.name}_rowconv")
+
+ def call(self, inputs, training=False):
+ outputs = self.rnn(inputs, training=training)
+ outputs = self.bn(outputs, training=training)
+ if self.rowconv is not None:
+ outputs = self.rowconv(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = super(RnnBlock, self).get_config()
+ conf.update(self.rnn.get_config())
+ conf.update(self.bn.get_config())
+ if self.rowconv is not None:
+ conf.update(self.rowconv.get_config())
+ return conf
+
+
+class RnnModule(tf.keras.Model):
+ def __init__(self,
+ nlayers: int = 5,
+ rnn_type: str = "lstm",
+ units: int = 1024,
+ bidirectional: bool = True,
+ rowconv: int = 0,
+ dropout: float = 0.1,
+ **kwargs):
+ super(RnnModule, self).__init__(**kwargs)
+
+ self.blocks = [
+ RnnBlock(
+ rnn_type=rnn_type,
+ units=units,
+ bidirectional=bidirectional,
+ rowconv=rowconv,
+ dropout=dropout,
+ name=f"{self.name}_block_{i}"
+ ) for i in range(nlayers)
+ ]
+
+ def call(self, inputs, training=False):
+ outputs = inputs
+ for block in self.blocks:
+ outputs = block(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = {}
+ for block in self.blocks:
+ conf.update(block.get_config())
+ return conf
+
+
+class FcBlock(tf.keras.layers.Layer):
+ def __init__(self,
+ units: int = 1024,
+ dropout: float = 0.1,
+ **kwargs):
+ super(FcBlock, self).__init__(**kwargs)
+
+ self.fc = tf.keras.layers.Dense(units, name=f"{self.name}_fc")
+ self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn")
+ self.relu = tf.keras.layers.ReLU(name=f"{self.name}_relu")
+ self.do = tf.keras.layers.Dropout(dropout, name=f"{self.name}_dropout")
+
+ def call(self, inputs, training=False):
+ outputs = self.fc(inputs, training=training)
+ outputs = self.bn(outputs, training=training)
+ outputs = self.relu(outputs, training=training)
+ outputs = self.do(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = super(FcBlock, self).get_config()
+ conf.update(self.fc.get_config())
+ conf.update(self.bn.get_config())
+ conf.update(self.relu.get_config())
+ conf.update(self.do.get_config())
+ return conf
+
+
+class FcModule(tf.keras.Model):
+ def __init__(self,
+ nlayers: int = 0,
+ units: int = 1024,
+ dropout: float = 0.1,
+ **kwargs):
+ super(FcModule, self).__init__(**kwargs)
+
+ self.blocks = [
+ FcBlock(
+ units=units,
+ dropout=dropout,
+ name=f"{self.name}_block_{i}"
+ ) for i in range(nlayers)
+ ]
+
+ def call(self, inputs, training=False):
+ outputs = inputs
+ for block in self.blocks:
+ outputs = block(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = {}
+ for block in self.blocks:
+ conf.update(block.get_config())
+ return conf
+
+
+class DeepSpeech2(CtcModel):
+ def __init__(self,
+ vocabulary_size: int,
+ conv_type: str = "conv2d",
+ conv_kernels: list = [[11, 41], [11, 21], [11, 21]],
+ conv_strides: list = [[2, 2], [1, 2], [1, 2]],
+ conv_filters: list = [32, 32, 96],
+ conv_dropout: float = 0.1,
+ rnn_nlayers: int = 5,
+ rnn_type: str = "lstm",
+ rnn_units: int = 1024,
+ rnn_bidirectional: bool = True,
+ rnn_rowconv: int = 0,
+ rnn_dropout: float = 0.1,
+ fc_nlayers: int = 0,
+ fc_units: int = 1024,
+ fc_dropout: float = 0.1,
+ name: str = "deepspeech2",
+ **kwargs):
+ super(DeepSpeech2, self).__init__(name=name, **kwargs)
+
+ self.conv_module = ConvModule(
+ conv_type=conv_type,
+ kernels=conv_kernels,
+ strides=conv_strides,
+ filters=conv_filters,
+ dropout=conv_dropout,
+ name=f"{self.name}_conv_module"
+ )
+
+ self.rnn_module = RnnModule(
+ nlayers=rnn_nlayers,
+ rnn_type=rnn_type,
+ units=rnn_units,
+ bidirectional=rnn_bidirectional,
+ rowconv=rnn_rowconv,
+ dropout=rnn_dropout,
+ name=f"{self.name}_rnn_module"
+ )
+
+ self.fc_module = FcModule(
+ nlayers=fc_nlayers,
+ units=fc_units,
+ dropout=fc_dropout,
+ name=f"{self.name}_fc_module"
+ )
+
+ # Fully connected layer
+ self.fc = tf.keras.layers.Dense(units=vocabulary_size, activation="linear",
+ use_bias=True, name=f"{name}_fc")
+
+ self.time_reduction_factor = self.conv_module.reduction_factor
+
+ def call(self, inputs, training=False):
+ outputs = self.conv_module(inputs, training=training)
+ outputs = self.rnn_module(outputs, training=training)
+ outputs = self.fc_module(outputs, training=training)
+ outputs = self.fc(outputs, training=training)
+ return outputs
+
+ def summary(self, line_length=100, **kwargs):
+ self.conv_module.summary(line_length=line_length, **kwargs)
+ self.rnn_module.summary(line_length=line_length, **kwargs)
+ self.fc_module.summary(line_length=line_length, **kwargs)
+ super(DeepSpeech2, self).summary(line_length=line_length, **kwargs)
+
+ def get_config(self):
+ conf = super(DeepSpeech2, self).get_config()
+ conf.update(self.conv_module.get_config())
+ conf.update(self.rnn_module.get_config())
+ conf.update(self.fc_module.get_config())
+ return conf
diff --git a/tensorflow_asr/models/jasper.py b/tensorflow_asr/models/jasper.py
new file mode 100644
index 0000000000..810940c585
--- /dev/null
+++ b/tensorflow_asr/models/jasper.py
@@ -0,0 +1,314 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from ..utils.utils import merge_two_last_dims
+from .ctc import CtcModel
+
+
+class Reshape(tf.keras.layers.Layer):
+ def call(self, inputs): return merge_two_last_dims(inputs)
+
+
+class JasperSubBlock(tf.keras.layers.Layer):
+ def __init__(self,
+ channels: int = 256,
+ kernels: int = 11,
+ strides: int = 1,
+ dropout: float = 0.1,
+ dilation: int = 1,
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs):
+ super(JasperSubBlock, self).__init__(**kwargs)
+ self.conv1d = tf.keras.layers.Conv1D(
+ filters=channels, kernel_size=kernels,
+ strides=strides, dilation_rate=dilation, padding="same",
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_conv1d"
+ )
+ self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn")
+ self.relu = tf.keras.layers.ReLU(name=f"{self.name}_relu")
+ self.do = tf.keras.layers.Dropout(dropout, name=f"{self.name}_dropout")
+ self.reduction_factor = strides
+
+ def call(self, inputs, training=False):
+ outputs = inputs
+ outputs = self.conv1d(outputs, training=training)
+ outputs = self.bn(outputs, training=training)
+ outputs = self.relu(outputs, training=training)
+ outputs = self.do(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = super(JasperSubBlock, self).get_config()
+ conf.update(self.conv1d.get_config())
+ conf.update(self.bn.get_config())
+ conf.update(self.relu.get_config())
+ conf.update(self.do.get_config())
+ return conf
+
+
+class JasperResidual(tf.keras.layers.Layer):
+ def __init__(self,
+ channels: int = 256,
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs):
+ super(JasperResidual, self).__init__(**kwargs)
+ self.pointwise_conv1d = tf.keras.layers.Conv1D(
+ filters=channels, kernel_size=1,
+ strides=1, padding="same",
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_pointwise_conv1d"
+ )
+ self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn")
+
+ def call(self, inputs, training=False):
+ outputs = self.pointwise_conv1d(inputs, training=training)
+ outputs = self.bn(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = super(JasperResidual, self).get_config()
+ conf.update(self.pointwise_conv1d.get_config())
+ conf.update(self.bn.get_config())
+ return conf
+
+
+class JasperSubBlockResidual(JasperSubBlock):
+ def __init__(self,
+ channels: int = 256,
+ kernels: int = 11,
+ strides: int = 1,
+ dropout: float = 0.1,
+ dilation: int = 1,
+ nresiduals: int = 1,
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs):
+ super(JasperSubBlockResidual, self).__init__(
+ channels=channels, kernels=kernels,
+ strides=strides, dropout=dropout,
+ dilation=dilation, kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer, **kwargs
+ )
+
+ self.residuals = [
+ JasperResidual(
+ channels=channels,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_residual_{i}"
+ ) for i in range(nresiduals)
+ ]
+
+ self.add = tf.keras.layers.Add(name=f"{self.name}_add")
+
+ def call(self, inputs, training=False):
+ outputs, residuals = inputs
+ outputs = self.conv1d(outputs, training=training)
+ outputs = self.bn(outputs, training=training)
+ for i, res in enumerate(residuals):
+ res = self.residuals[i](res, training=training)
+ outputs = self.add([outputs, res], training=training)
+ outputs = self.relu(outputs, training=training)
+ outputs = self.do(outputs, training=training)
+ return outputs
+
+ def get_config(self):
+ conf = super(JasperSubBlockResidual, self).get_config()
+        for res in self.residuals: conf.update(res.get_config())
+ conf.update(self.add.get_config())
+ return conf
+
+
+class JasperBlock(tf.keras.Model):
+ def __init__(self,
+ nsubblocks: int = 3,
+ channels: int = 256,
+ kernels: int = 11,
+ dropout: float = 0.1,
+ dense: bool = False,
+ nresiduals: int = 1,
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs):
+ super(JasperBlock, self).__init__(**kwargs)
+
+ self.dense = dense
+
+ self.subblocks = [
+ JasperSubBlock(
+ channels=channels,
+ kernels=kernels,
+ dropout=dropout,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_subordinate_{i}"
+ ) for i in range(nsubblocks - 1)
+ ]
+
+ self.subblock_residual = JasperSubBlockResidual(
+ channels=channels,
+ kernels=kernels,
+ dropout=dropout,
+ nresiduals=nresiduals,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_subordinate_{nsubblocks - 1}"
+ )
+
+ self.reduction_factor = 1
+
+ def call(self, inputs, training=False):
+ inputs, residuals = inputs
+ outputs = inputs
+ for subblock in self.subblocks:
+ outputs = subblock(outputs, training=training)
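+        # dense residual mode keeps every block's input, so the last subblock's residual
+        # add sums pointwise projections of all preceding block inputs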
+ if self.dense:
+ residuals.append(inputs)
+ outputs = self.subblock_residual([outputs, residuals], training=training)
+ else:
+ outputs = self.subblock_residual([outputs, [inputs]], training=training)
+ return outputs, residuals
+
+ def get_config(self):
+ conf = self.subblock_residual.get_config()
+ conf.update({"dense": self.dense})
+ for subblock in self.subblocks:
+ conf.update(subblock.get_config())
+ return conf
+
+
+class Jasper(CtcModel):
+ def __init__(self,
+ vocabulary_size: int,
+ dense: bool = False,
+ first_additional_block_channels: int = 256,
+ first_additional_block_kernels: int = 11,
+ first_additional_block_strides: int = 2,
+ first_additional_block_dilation: int = 1,
+                 first_additional_block_dropout: float = 0.2,
+ nsubblocks: int = 5,
+ block_channels: list = [256, 384, 512, 640, 768],
+ block_kernels: list = [11, 13, 17, 21, 25],
+ block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3],
+ second_additional_block_channels: int = 896,
+ second_additional_block_kernels: int = 1,
+ second_additional_block_strides: int = 1,
+ second_additional_block_dilation: int = 2,
+                 second_additional_block_dropout: float = 0.4,
+ third_additional_block_channels: int = 1024,
+ third_additional_block_kernels: int = 1,
+ third_additional_block_strides: int = 1,
+ third_additional_block_dilation: int = 1,
+                 third_additional_block_dropout: float = 0.4,
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ name: str = "jasper",
+ **kwargs):
+ super(Jasper, self).__init__(name=name, **kwargs)
+
+ assert len(block_channels) == len(block_kernels) == len(block_dropout)
+
+ self.reshape = Reshape(name=f"{self.name}_reshape")
+
+ self.first_additional_block = JasperSubBlock(
+ channels=first_additional_block_channels,
+ kernels=first_additional_block_kernels,
+ strides=first_additional_block_strides,
+ dropout=first_additional_block_dropout,
+ dilation=first_additional_block_dilation,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_first_block"
+ )
+
+ self.blocks = [
+ JasperBlock(
+ nsubblocks=nsubblocks,
+ channels=block_channels[i],
+ kernels=block_kernels[i],
+ dropout=block_dropout[i],
+ dense=dense,
+ nresiduals=(i + 1) if dense else 1,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_block_{i}"
+ ) for i in range(len(block_channels))
+ ]
+
+ self.second_additional_block = JasperSubBlock(
+ channels=second_additional_block_channels,
+ kernels=second_additional_block_kernels,
+ strides=second_additional_block_strides,
+ dropout=second_additional_block_dropout,
+ dilation=second_additional_block_dilation,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_second_block"
+ )
+
+ self.third_additional_block = JasperSubBlock(
+ channels=third_additional_block_channels,
+ kernels=third_additional_block_kernels,
+ strides=third_additional_block_strides,
+ dropout=third_additional_block_dropout,
+ dilation=third_additional_block_dilation,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_third_block"
+ )
+
+ self.last_block = tf.keras.layers.Conv1D(
+ filters=vocabulary_size, kernel_size=1,
+ strides=1, padding="same",
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ name=f"{self.name}_last_block"
+ )
+
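+        # only the additional blocks downsample time; JasperBlock keeps stride 1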
+ self.time_reduction_factor = self.first_additional_block.reduction_factor
+ self.time_reduction_factor *= self.second_additional_block.reduction_factor
+ self.time_reduction_factor *= self.third_additional_block.reduction_factor
+
+ def call(self, inputs, training=False):
+ outputs = self.reshape(inputs)
+ outputs = self.first_additional_block(outputs, training=training)
+
+ residuals = []
+ for block in self.blocks:
+ outputs, residuals = block([outputs, residuals], training=training)
+
+ outputs = self.second_additional_block(outputs, training=training)
+ outputs = self.third_additional_block(outputs, training=training)
+ outputs = self.last_block(outputs, training=training)
+ return outputs
+
+ def summary(self, line_length=100, **kwargs):
+ super(Jasper, self).summary(line_length=line_length, **kwargs)
+
+ def get_config(self):
+ conf = self.reshape.get_config()
+ conf.update(self.first_additional_block.get_config())
+ for block in self.blocks:
+ conf.update(block.get_config())
+ conf.update(self.second_additional_block.get_config())
+ conf.update(self.third_additional_block.get_config())
+ conf.update(self.last_block.get_config())
+ return conf
diff --git a/tensorflow_asr/models/layers/merge_two_last_dims.py b/tensorflow_asr/models/layers/merge_two_last_dims.py
deleted file mode 100755
index fc75086b63..0000000000
--- a/tensorflow_asr/models/layers/merge_two_last_dims.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-from ...utils.utils import merge_two_last_dims
-
-
-class Merge2LastDims(tf.keras.layers.Layer):
- def __init__(self, name: str = "merge_two_last_dims", **kwargs):
- super(Merge2LastDims, self).__init__(name=name, **kwargs)
-
- def call(self, inputs, **kwargs):
- return merge_two_last_dims(inputs)
-
- def get_config(self):
- config = super(Merge2LastDims, self).get_config()
- return config
diff --git a/tensorflow_asr/models/layers/transpose_time_major.py b/tensorflow_asr/models/layers/transpose_time_major.py
deleted file mode 100755
index 558184fe2a..0000000000
--- a/tensorflow_asr/models/layers/transpose_time_major.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@usimarit)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-
-
-class TransposeTimeMajor(tf.keras.layers.Layer):
- def __init__(self, name: str = "transpose_time_major", **kwargs):
- super(TransposeTimeMajor, self).__init__(name=name, **kwargs)
-
- def call(self, inputs, **kwargs):
- return tf.transpose(inputs, perm=[1, 0, 2])
-
- def get_config(self):
- config = super(TransposeTimeMajor, self).get_config()
- return config
diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py
index 61c6cce52b..4b03d5c978 100755
--- a/tensorflow_asr/utils/utils.py
+++ b/tensorflow_asr/utils/utils.py
@@ -82,6 +82,15 @@ def get_rnn(rnn_type):
return tf.keras.layers.SimpleRNN
+def get_conv(conv_type):
+ assert conv_type in ["conv1d", "conv2d"]
+
+ if conv_type == "conv1d":
+ return tf.keras.layers.Conv1D
+
+ return tf.keras.layers.Conv2D
+
+
def print_one_line(*args):
tf.print("\033[K", end="")
tf.print("\r", *args, sep="", end=" ", output_stream=sys.stdout)