diff --git a/examples/conformer/config.yml b/examples/conformer/config.yml index db47374810..e538ab6383 100755 --- a/examples/conformer/config.yml +++ b/examples/conformer/config.yml @@ -52,7 +52,7 @@ model_config: prediction_num_rnns: 1 prediction_rnn_units: 320 prediction_rnn_type: lstm - prediction_rnn_implementation: 1 + prediction_rnn_implementation: 2 prediction_layer_norm: True prediction_projection_units: 0 joint_dim: 320 @@ -77,7 +77,7 @@ learning_config: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv test_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords optimizer_config: warmup_steps: 40000 @@ -93,3 +93,16 @@ learning_config: log_interval_steps: 300 eval_interval_steps: 500 save_interval_steps: 1000 + checkpoint: + filepath: /mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5 + save_best_only: True + save_weights_only: False + save_freq: epoch + states_dir: /mnt/Miscellanea/Models/local/conformer/states + tensorboard: + log_dir: /mnt/Miscellanea/Models/local/conformer/tensorboard + histogram_freq: 1 + write_graph: True + write_images: True + update_freq: 'epoch' + profile_batch: 2 diff --git a/examples/conformer/masking/train_ga_masking_conformer.py b/examples/conformer/masking/train_ga_masking_conformer.py index 32e847c5c5..62a0deb240 100644 --- a/examples/conformer/masking/train_ga_masking_conformer.py +++ b/examples/conformer/masking/train_ga_masking_conformer.py @@ -67,7 +67,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/masking/train_ga_masking_subword_conformer.py b/examples/conformer/masking/train_ga_masking_subword_conformer.py index 3fd076236c..1e74f9a68b 100644 --- a/examples/conformer/masking/train_ga_masking_subword_conformer.py +++ b/examples/conformer/masking/train_ga_masking_subword_conformer.py @@ -73,7 +73,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/conformer/masking/train_masking_conformer.py b/examples/conformer/masking/train_masking_conformer.py index 683282d6eb..82dbbda9ec 100644 --- a/examples/conformer/masking/train_masking_conformer.py +++ b/examples/conformer/masking/train_masking_conformer.py @@ -64,7 +64,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/masking/train_masking_subword_conformer.py b/examples/conformer/masking/train_masking_subword_conformer.py index e7dd1c743a..be99ec3ceb 100644 --- a/examples/conformer/masking/train_masking_subword_conformer.py +++ b/examples/conformer/masking/train_masking_subword_conformer.py @@ -70,7 +70,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/conformer/save_conformer_from_weights.py b/examples/conformer/save_conformer_from_weights.py index 3d51abfc49..f070731b25 100644 --- a/examples/conformer/save_conformer_from_weights.py +++ b/examples/conformer/save_conformer_from_weights.py @@ -51,7 +51,7 @@ from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.models.conformer import Conformer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/test_conformer.py b/examples/conformer/test_conformer.py index 6f84b5782b..26e470f21f 100755 --- a/examples/conformer/test_conformer.py +++ b/examples/conformer/test_conformer.py @@ -59,7 +59,7 @@ from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.conformer import Conformer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/test_subword_conformer.py b/examples/conformer/test_subword_conformer.py index c2ff26ee6c..4c7f0edf0b 100755 --- a/examples/conformer/test_subword_conformer.py +++ b/examples/conformer/test_subword_conformer.py @@ -64,7 +64,7 @@ from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.conformer import Conformer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.sentence_piece: diff --git a/examples/conformer/tflite_conformer.py b/examples/conformer/tflite_conformer.py index 4dd90d03a9..a44997a3be 100644 --- a/examples/conformer/tflite_conformer.py +++ b/examples/conformer/tflite_conformer.py @@ -43,7 +43,7 @@ assert args.saved and args.output -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/tflite_subword_conformer.py b/examples/conformer/tflite_subword_conformer.py index 6205c08a5d..1adc5e75d4 100644 --- a/examples/conformer/tflite_subword_conformer.py +++ b/examples/conformer/tflite_subword_conformer.py @@ -46,7 +46,7 @@ assert args.saved and args.output -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/conformer/train_conformer.py b/examples/conformer/train_conformer.py index 5697130463..0cc601af60 100644 --- a/examples/conformer/train_conformer.py +++ b/examples/conformer/train_conformer.py @@ -60,7 +60,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/train_ga_conformer.py b/examples/conformer/train_ga_conformer.py index d5be194c4d..81bbdc3074 100644 --- a/examples/conformer/train_ga_conformer.py +++ b/examples/conformer/train_ga_conformer.py @@ -62,7 +62,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/conformer/train_ga_subword_conformer.py b/examples/conformer/train_ga_subword_conformer.py index 609c27e6ad..519b14d136 100644 --- a/examples/conformer/train_ga_subword_conformer.py +++ b/examples/conformer/train_ga_subword_conformer.py @@ -68,7 +68,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.sentence_piece: diff --git a/examples/conformer/train_keras_subword_conformer.py b/examples/conformer/train_keras_subword_conformer.py new file mode 100644 index 0000000000..2c87fa6da7 --- /dev/null +++ b/examples/conformer/train_keras_subword_conformer.py @@ -0,0 +1,156 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import math +import argparse +from tensorflow_asr.utils import setup_environment, setup_strategy + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") + +parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords") + +parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards") + +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica") + +parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") + +parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset") + +parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") + +parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") + +parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer +from tensorflow_asr.models.keras.conformer import Conformer +from tensorflow_asr.optimizers.schedules import TransformerSchedule + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.sentence_piece: + print("Loading SentencePiece model ...") + text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) +elif args.subwords and os.path.exists(args.subwords): + print("Loading subwords ...") + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) +else: + print("Generating subwords ...") + text_featurizer = SubwordFeaturizer.build_from_corpus( + config.decoder_config, + corpus_files=args.subwords_corpus + ) + text_featurizer.save_to_file(args.subwords) + +if args.tfrecords: + train_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + tfrecords_shards=args.tfrecords_shards, + stage="train", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) + eval_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + tfrecords_shards=args.tfrecords_shards, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) +else: + train_dataset = ASRSliceDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) + eval_dataset = ASRSliceDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) + +with strategy.scope(): + global_batch_size = config.learning_config.running_config.batch_size + global_batch_size *= strategy.num_replicas_in_sync + # build model + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) + conformer._build(speech_featurizer.shape) + conformer.summary(line_length=120) + + optimizer = tf.keras.optimizers.Adam( + TransformerSchedule( + d_model=conformer.dmodel, + warmup_steps=config.learning_config.optimizer_config["warmup_steps"], + max_lr=(0.05 / math.sqrt(conformer.dmodel)) + ), + beta_1=config.learning_config.optimizer_config["beta1"], + beta_2=config.learning_config.optimizer_config["beta2"], + epsilon=config.learning_config.optimizer_config["epsilon"] + ) + + conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + + train_data_loader = train_dataset.create(global_batch_size) + eval_data_loader = eval_dataset.create(global_batch_size) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) + ] + + conformer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps + ) diff --git a/examples/conformer/train_subword_conformer.py b/examples/conformer/train_subword_conformer.py index 07abf8c8cd..c6a54a1339 100644 --- a/examples/conformer/train_subword_conformer.py +++ b/examples/conformer/train_subword_conformer.py @@ -66,7 +66,7 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.sentence_piece: diff --git a/examples/contextnet/config.yml b/examples/contextnet/config.yml index 7b5d8d2333..f4b5764b65 100644 --- a/examples/contextnet/config.yml +++ b/examples/contextnet/config.yml @@ -213,7 +213,7 @@ learning_config: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv test_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords optimizer_config: warmup_steps: 40000 @@ -229,3 +229,16 @@ learning_config: log_interval_steps: 300 eval_interval_steps: 500 save_interval_steps: 1000 + checkpoint: + filepath: /mnt/Miscellanea/Models/local/contextnet/checkpoints/{epoch:02d}.h5 + save_best_only: True + save_weights_only: False + save_freq: epoch + states_dir: /mnt/Miscellanea/Models/local/contextnet/states + tensorboard: + log_dir: /mnt/Miscellanea/Models/local/contextnet/tensorboard + histogram_freq: 1 + write_graph: True + write_images: True + update_freq: 'epoch' + profile_batch: 2 diff --git a/examples/contextnet/test_contextnet.py b/examples/contextnet/test_contextnet.py index 591082d348..2f88848089 100644 --- a/examples/contextnet/test_contextnet.py +++ b/examples/contextnet/test_contextnet.py @@ -59,7 +59,7 @@ from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.contextnet import ContextNet -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/contextnet/test_subword_contextnet.py b/examples/contextnet/test_subword_contextnet.py index 2992ece5b5..8c3437b695 100644 --- a/examples/contextnet/test_subword_contextnet.py +++ b/examples/contextnet/test_subword_contextnet.py @@ -62,7 +62,7 @@ from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.contextnet import ContextNet -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/contextnet/tflite_contextnet.py b/examples/contextnet/tflite_contextnet.py index 9427ceee6e..9eaf279b1d 100644 --- a/examples/contextnet/tflite_contextnet.py +++ b/examples/contextnet/tflite_contextnet.py @@ -43,7 +43,7 @@ assert args.saved and args.output -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/contextnet/tflite_subword_contextnet.py b/examples/contextnet/tflite_subword_contextnet.py index 610648e2a1..6e96c83f21 100644 --- a/examples/contextnet/tflite_subword_contextnet.py +++ b/examples/contextnet/tflite_subword_contextnet.py @@ -46,7 +46,7 @@ assert args.saved and args.output -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/contextnet/train_contextnet.py b/examples/contextnet/train_contextnet.py index e52ab642b0..a0cef5c581 100644 --- a/examples/contextnet/train_contextnet.py +++ b/examples/contextnet/train_contextnet.py @@ -60,7 +60,7 @@ from tensorflow_asr.models.contextnet import ContextNet from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/contextnet/train_ga_contextnet.py b/examples/contextnet/train_ga_contextnet.py index d5e969ae54..7b7815fd99 100644 --- a/examples/contextnet/train_ga_contextnet.py +++ b/examples/contextnet/train_ga_contextnet.py @@ -62,7 +62,7 @@ from tensorflow_asr.models.contextnet import ContextNet from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/contextnet/train_ga_subword_contextnet.py b/examples/contextnet/train_ga_subword_contextnet.py index 32cf6a75df..bc7499143e 100644 --- a/examples/contextnet/train_ga_subword_contextnet.py +++ b/examples/contextnet/train_ga_subword_contextnet.py @@ -66,7 +66,7 @@ from tensorflow_asr.models.contextnet import ContextNet from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/contextnet/train_keras_subword_contextnet.py b/examples/contextnet/train_keras_subword_contextnet.py new file mode 100644 index 0000000000..a4346e2294 --- /dev/null +++ b/examples/contextnet/train_keras_subword_contextnet.py @@ -0,0 +1,151 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import math +import argparse +from tensorflow_asr.utils import setup_environment, setup_strategy + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="ContextNet Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") + +parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords") + +parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards") + +parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica") + +parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") + +parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset") + +parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") + +parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") + +parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer +from tensorflow_asr.models.keras.contextnet import ContextNet +from tensorflow_asr.optimizers.schedules import TransformerSchedule + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.subwords and os.path.exists(args.subwords): + print("Loading subwords ...") + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) +else: + print("Generating subwords ...") + text_featurizer = SubwordFeaturizer.build_from_corpus( + config.decoder_config, + corpus_files=args.subwords_corpus + ) + text_featurizer.save_to_file(args.subwords) + +if args.tfrecords: + train_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + tfrecords_shards=args.tfrecords_shards, + stage="train", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) + eval_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + tfrecords_shards=args.tfrecords_shards, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) +else: + train_dataset = ASRSliceDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) + eval_dataset = ASRSliceDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, + shuffle=True, buffer_size=args.bfs, + ) + +with strategy.scope(): + global_batch_size = config.learning_config.running_config.batch_size + global_batch_size *= strategy.num_replicas_in_sync + # build model + contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) + contextnet._build(speech_featurizer.shape) + contextnet.summary(line_length=120) + + optimizer = tf.keras.optimizers.Adam( + TransformerSchedule( + d_model=contextnet.dmodel, + warmup_steps=config.learning_config.optimizer_config["warmup_steps"], + max_lr=(0.05 / math.sqrt(contextnet.dmodel)) + ), + beta_1=config.learning_config.optimizer_config["beta1"], + beta_2=config.learning_config.optimizer_config["beta2"], + epsilon=config.learning_config.optimizer_config["epsilon"] + ) + + contextnet.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + + train_data_loader = train_dataset.create(global_batch_size) + eval_data_loader = eval_dataset.create(global_batch_size) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) + ] + + contextnet.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps + ) diff --git a/examples/contextnet/train_subword_contextnet.py b/examples/contextnet/train_subword_contextnet.py index 2d186fbe42..04a617a59a 100644 --- a/examples/contextnet/train_subword_contextnet.py +++ b/examples/contextnet/train_subword_contextnet.py @@ -64,7 +64,7 @@ from tensorflow_asr.models.contextnet import ContextNet from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/deepspeech2/config.yml b/examples/deepspeech2/config.yml index 3c6f6d12f5..c3507c1756 100755 --- a/examples/deepspeech2/config.yml +++ b/examples/deepspeech2/config.yml @@ -59,7 +59,7 @@ learning_config: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv test_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords optimizer_config: class_name: adam @@ -74,3 +74,16 @@ learning_config: log_interval_steps: 400 save_interval_steps: 400 eval_interval_steps: 800 + checkpoint: + filepath: /mnt/Miscellanea/Models/local/deepspeech2/checkpoints/{epoch:02d}.h5 + save_best_only: True + save_weights_only: False + save_freq: epoch + states_dir: /mnt/Miscellanea/Models/local/deepspeech2/states + tensorboard: + log_dir: /mnt/Miscellanea/Models/local/deepspeech2/tensorboard + histogram_freq: 1 + write_graph: True + write_images: True + update_freq: 'epoch' + profile_batch: 2 diff --git a/examples/deepspeech2/test_ds2.py b/examples/deepspeech2/test_ds2.py index e84475c38c..6b1baf7ea4 100644 --- a/examples/deepspeech2/test_ds2.py +++ b/examples/deepspeech2/test_ds2.py @@ -59,7 +59,7 @@ tf.random.set_seed(0) assert args.export -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) # Build DS2 model diff --git a/examples/deepspeech2/train_ds2.py b/examples/deepspeech2/train_ds2.py index 33a410a4fd..44ab5f2ebb 100644 --- a/examples/deepspeech2/train_ds2.py +++ b/examples/deepspeech2/train_ds2.py @@ -62,7 +62,7 @@ from tensorflow_asr.runners.ctc_runners import CTCTrainer from tensorflow_asr.models.deepspeech2 import DeepSpeech2 -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/deepspeech2/train_ga_ds2.py b/examples/deepspeech2/train_ga_ds2.py index d73e105587..4ff23683d3 100644 --- a/examples/deepspeech2/train_ga_ds2.py +++ b/examples/deepspeech2/train_ga_ds2.py @@ -65,7 +65,7 @@ from tensorflow_asr.runners.ctc_runners import CTCTrainerGA from tensorflow_asr.models.deepspeech2 import DeepSpeech2 -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/deepspeech2/train_keras_ds2.py b/examples/deepspeech2/train_keras_ds2.py new file mode 100644 index 0000000000..eab4b2bbb4 --- /dev/null +++ b/examples/deepspeech2/train_keras_ds2.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from tensorflow_asr.utils import setup_environment, setup_strategy + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Deep Speech 2 Training") + +parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML, + help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, + help="Max number of checkpoints to keep") + +parser.add_argument("--tbs", type=int, default=None, + help="Train batch size per replicas") + +parser.add_argument("--ebs", type=int, default=None, + help="Evaluation batch size per replicas") + +parser.add_argument("--tfrecords", default=False, action="store_true", + help="Whether to use tfrecords dataset") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], + help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", + help="Enable mixed precision") + +parser.add_argument("--cache", default=False, action="store_true", + help="Enable caching for dataset") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer +from tensorflow_asr.models.keras.deepspeech2 import DeepSpeech2 + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) + +if args.tfrecords: + train_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, shuffle=True + ) + eval_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, shuffle=True + ) +else: + train_dataset = ASRSliceDatasetKeras( + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + data_paths=config.learning_config.dataset_config.train_paths, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, shuffle=True + ) + eval_dataset = ASRSliceDatasetKeras( + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + data_paths=config.learning_config.dataset_config.eval_paths, + stage="eval", cache=args.cache, shuffle=True + ) + +# Build DS2 model +with strategy.scope(): + global_batch_size = config.learning_config.running_config.batch_size + global_batch_size *= strategy.num_replicas_in_sync + + ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes) + ds2_model._build(speech_featurizer.shape) + ds2_model.summary(line_length=120) + + ds2_model.compile(optimizer=config.learning_config.optimizer_config, + global_batch_size=global_batch_size, blank=text_featurizer.blank) + + train_data_loader = train_dataset.create(global_batch_size) + eval_data_loader = eval_dataset.create(global_batch_size) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) + ] + + ds2_model.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps + ) diff --git a/examples/demonstration/conformer.py b/examples/demonstration/conformer.py index 668a3bc5e5..8d37b4838e 100644 --- a/examples/demonstration/conformer.py +++ b/examples/demonstration/conformer.py @@ -48,7 +48,7 @@ from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer from tensorflow_asr.models.conformer import Conformer -config = Config(args.config, learning=False) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") diff --git a/examples/jasper/config.yml b/examples/jasper/config.yml index a785862d06..7011e62f4f 100755 --- a/examples/jasper/config.yml +++ b/examples/jasper/config.yml @@ -81,3 +81,16 @@ learning_config: log_interval_steps: 400 save_interval_steps: 400 eval_interval_steps: 800 + checkpoint: + filepath: /mnt/Miscellanea/Models/local/jasper/checkpoints/{epoch:02d}.h5 + save_best_only: True + save_weights_only: False + save_freq: epoch + states_dir: /mnt/Miscellanea/Models/local/jasper/states + tensorboard: + log_dir: /mnt/Miscellanea/Models/local/jasper/tensorboard + histogram_freq: 1 + write_graph: True + write_images: True + update_freq: 'epoch' + profile_batch: 2 diff --git a/examples/jasper/test_jasper.py b/examples/jasper/test_jasper.py index 6889c89b52..8b9cd1f974 100644 --- a/examples/jasper/test_jasper.py +++ b/examples/jasper/test_jasper.py @@ -59,7 +59,7 @@ tf.random.set_seed(0) assert args.export -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) # Build DS2 model diff --git a/examples/jasper/train_ga_jasper.py b/examples/jasper/train_ga_jasper.py index 41e1810011..b9c0d10af9 100644 --- a/examples/jasper/train_ga_jasper.py +++ b/examples/jasper/train_ga_jasper.py @@ -65,7 +65,7 @@ from tensorflow_asr.runners.ctc_runners import CTCTrainerGA from tensorflow_asr.models.jasper import Jasper -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/jasper/train_jasper.py b/examples/jasper/train_jasper.py index 698733d62a..c50f89b556 100644 --- a/examples/jasper/train_jasper.py +++ b/examples/jasper/train_jasper.py @@ -62,7 +62,7 @@ from tensorflow_asr.runners.ctc_runners import CTCTrainer from tensorflow_asr.models.jasper import Jasper -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/jasper/train_keras_jasper.py b/examples/jasper/train_keras_jasper.py new file mode 100644 index 0000000000..ed31f0c168 --- /dev/null +++ b/examples/jasper/train_keras_jasper.py @@ -0,0 +1,123 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from tensorflow_asr.utils import setup_environment, setup_strategy + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Jasper Training") + +parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML, + help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, + help="Max number of checkpoints to keep") + +parser.add_argument("--tbs", type=int, default=None, + help="Train batch size per replicas") + +parser.add_argument("--ebs", type=int, default=None, + help="Evaluation batch size per replicas") + +parser.add_argument("--tfrecords", default=False, action="store_true", + help="Whether to use tfrecords dataset") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], + help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", + help="Enable mixed precision") + +parser.add_argument("--cache", default=False, action="store_true", + help="Enable caching for dataset") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer +from tensorflow_asr.models.keras.jasper import Jasper + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) + +if args.tfrecords: + train_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, shuffle=True + ) + eval_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, shuffle=True + ) +else: + train_dataset = ASRSliceDatasetKeras( + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + data_paths=config.learning_config.dataset_config.train_paths, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, shuffle=True + ) + eval_dataset = ASRSliceDatasetKeras( + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + data_paths=config.learning_config.dataset_config.eval_paths, + stage="eval", cache=args.cache, shuffle=True + ) + +with strategy.scope(): + global_batch_size = config.learning_config.running_config.batch_size + global_batch_size *= strategy.num_replicas_in_sync + + jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) + jasper._build(speech_featurizer.shape) + jasper.summary(line_length=120) + + jasper.compile(optimizer=config.learning_config.optimizer_config, + global_batch_size=global_batch_size, blank=text_featurizer.blank) + + train_data_loader = train_dataset.create(global_batch_size) + eval_data_loader = eval_dataset.create(global_batch_size) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) + ] + + jasper.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps + ) diff --git a/examples/streaming_transducer/config.yml b/examples/streaming_transducer/config.yml index 54551e3c5b..510e1fd140 100755 --- a/examples/streaming_transducer/config.yml +++ b/examples/streaming_transducer/config.yml @@ -70,7 +70,7 @@ learning_config: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv test_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords optimizer_config: class_name: adam @@ -85,3 +85,17 @@ learning_config: log_interval_steps: 300 eval_interval_steps: 500 save_interval_steps: 1000 + checkpoint: + filepath: /mnt/Miscellanea/Models/local/streaming_transducer/checkpoints/{epoch:02d}.h5 + save_best_only: True + save_weights_only: False + save_freq: epoch + states_dir: /mnt/Miscellanea/Models/local/streaming_transducer/states + tensorboard: + log_dir: /mnt/Miscellanea/Models/local/streaming_transducer/tensorboard + histogram_freq: 1 + write_graph: True + write_images: True + update_freq: 'epoch' + profile_batch: 2 + diff --git a/examples/streaming_transducer/test_streaming_transducer.py b/examples/streaming_transducer/test_streaming_transducer.py index 921c56486e..896b6edfa7 100755 --- a/examples/streaming_transducer/test_streaming_transducer.py +++ b/examples/streaming_transducer/test_streaming_transducer.py @@ -59,7 +59,7 @@ from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/streaming_transducer/test_subword_streaming_transducer.py b/examples/streaming_transducer/test_subword_streaming_transducer.py index 6d8b2027be..b5e113d701 100755 --- a/examples/streaming_transducer/test_subword_streaming_transducer.py +++ b/examples/streaming_transducer/test_subword_streaming_transducer.py @@ -62,7 +62,7 @@ from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/streaming_transducer/tflite_streaming_transducer.py b/examples/streaming_transducer/tflite_streaming_transducer.py index eacb4ba584..b2a2ed6dfb 100644 --- a/examples/streaming_transducer/tflite_streaming_transducer.py +++ b/examples/streaming_transducer/tflite_streaming_transducer.py @@ -43,7 +43,7 @@ assert args.saved and args.output -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/streaming_transducer/tflite_subword_streaming_transducer.py b/examples/streaming_transducer/tflite_subword_streaming_transducer.py index 8bd3d0511b..1930345960 100644 --- a/examples/streaming_transducer/tflite_subword_streaming_transducer.py +++ b/examples/streaming_transducer/tflite_subword_streaming_transducer.py @@ -46,7 +46,7 @@ assert args.saved and args.output -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/streaming_transducer/train_ga_streaming_transducer.py b/examples/streaming_transducer/train_ga_streaming_transducer.py index 85452c7705..6f7868edc2 100644 --- a/examples/streaming_transducer/train_ga_streaming_transducer.py +++ b/examples/streaming_transducer/train_ga_streaming_transducer.py @@ -65,7 +65,7 @@ from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/streaming_transducer/train_ga_subword_streaming_transducer.py b/examples/streaming_transducer/train_ga_subword_streaming_transducer.py index ce3d0c0dbd..2385bc9129 100644 --- a/examples/streaming_transducer/train_ga_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_ga_subword_streaming_transducer.py @@ -71,7 +71,7 @@ from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/examples/streaming_transducer/train_keras_subword_streaming_transducer.py b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py new file mode 100644 index 0000000000..284943c0d6 --- /dev/null +++ b/examples/streaming_transducer/train_keras_subword_streaming_transducer.py @@ -0,0 +1,143 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from tensorflow_asr.utils import setup_environment, setup_strategy + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, + help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, + help="Max number of checkpoints to keep") + +parser.add_argument("--tfrecords", default=False, action="store_true", + help="Whether to use tfrecords") + +parser.add_argument("--tbs", type=int, default=None, + help="Train batch size per replica") + +parser.add_argument("--ebs", type=int, default=None, + help="Evaluation batch size per replica") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], + help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", + help="Enable mixed precision") + +parser.add_argument("--cache", default=False, action="store_true", + help="Enable caching for dataset") + +parser.add_argument("--subwords", type=str, default=None, + help="Path to file that stores generated subwords") + +parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], + help="Transcript files for generating subwords") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer +from tensorflow_asr.models.keras.streaming_transducer import StreamingTransducer + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.subwords and os.path.exists(args.subwords): + print("Loading subwords ...") + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) +else: + print("Generating subwords ...") + text_featurizer = SubwordFeaturizer.build_from_corpus( + config.decoder_config, + corpus_files=args.subwords_corpus + ) + text_featurizer.save_to_file(args.subwords) + +if args.tfrecords: + train_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, shuffle=True + ) + eval_dataset = ASRTFRecordDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, shuffle=True + ) +else: + train_dataset = ASRSliceDatasetKeras( + data_paths=config.learning_config.dataset_config.train_paths, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + augmentations=config.learning_config.augmentations, + stage="train", cache=args.cache, shuffle=True + ) + eval_dataset = ASRSliceDatasetKeras( + data_paths=config.learning_config.dataset_config.eval_paths, + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer, + stage="eval", cache=args.cache, shuffle=True + ) + +with strategy.scope(): + global_batch_size = config.learning_config.running_config.batch_size + global_batch_size *= strategy.num_replicas_in_sync + # build model + streaming_transducer = StreamingTransducer( + **config.model_config, + vocabulary_size=text_featurizer.num_classes + ) + streaming_transducer._build(speech_featurizer.shape) + streaming_transducer.summary(line_length=150) + + optimizer = tf.keras.optimizers.get(config.learning_config.optimizer_config) + + streaming_transducer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + + train_data_loader = train_dataset.create(global_batch_size) + eval_data_loader = eval_dataset.create(global_batch_size) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) + ] + + streaming_transducer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps + ) diff --git a/examples/streaming_transducer/train_streaming_transducer.py b/examples/streaming_transducer/train_streaming_transducer.py index 6852dc9580..84351d5ad3 100644 --- a/examples/streaming_transducer/train_streaming_transducer.py +++ b/examples/streaming_transducer/train_streaming_transducer.py @@ -62,7 +62,7 @@ from tensorflow_asr.runners.transducer_runners import TransducerTrainer from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/examples/streaming_transducer/train_subword_streaming_transducer.py b/examples/streaming_transducer/train_subword_streaming_transducer.py index 5971b4735d..42fea7b566 100644 --- a/examples/streaming_transducer/train_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_subword_streaming_transducer.py @@ -68,7 +68,7 @@ from tensorflow_asr.runners.transducer_runners import TransducerTrainer from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = Config(args.config, learning=True) +config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): diff --git a/scripts/generate_vocab_sentencepiece.py b/scripts/generate_vocab_sentencepiece.py index 040133e488..f5f69a0cef 100644 --- a/scripts/generate_vocab_sentencepiece.py +++ b/scripts/generate_vocab_sentencepiece.py @@ -24,7 +24,7 @@ from tensorflow_asr.configs.config import Config from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer -config = Config(args.config, learning=True) +config = Config(args.config) print("Generating subwords ...") text_featurizer = SentencePieceFeaturizer.build_from_corpus( diff --git a/setup.py b/setup.py index 6f58ed0ba5..f7a9c54213 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.6.4", + version="0.7.0", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/configs/config.py b/tensorflow_asr/configs/config.py index e6e068da09..3613a56568 100644 --- a/tensorflow_asr/configs/config.py +++ b/tensorflow_asr/configs/config.py @@ -20,60 +20,59 @@ class DecoderConfig: def __init__(self, config: dict = None): if not config: config = {} - self.vocabulary = preprocess_paths(config.get("vocabulary", None)) - self.beam_width = config.get("beam_width", 0) - self.blank_at_zero = config.get("blank_at_zero", True) - self.target_vocab_size = config.get("target_vocab_size", 1024) - self.max_subword_length = config.get("max_subword_length", 4) - self.norm_score = config.get("norm_score", True) - self.lm_config = config.get("lm_config", {}) - self.additional_properties = config.get("additional_properties", {}) - self.output_path_prefix = preprocess_paths(config.get("output_path_prefix", None)) - self.model_type = config.get("model_type", None) - self.corpus_files = config.get("corpus_files", None) + self.vocabulary = preprocess_paths(config.pop("vocabulary", None)) + self.beam_width = config.pop("beam_width", 0) + self.blank_at_zero = config.pop("blank_at_zero", True) + self.target_vocab_size = config.pop("target_vocab_size", 1024) + self.max_subword_length = config.pop("max_subword_length", 4) + self.norm_score = config.pop("norm_score", True) + self.lm_config = config.pop("lm_config", {}) + self.output_path_prefix = preprocess_paths(config.pop("output_path_prefix", None)) + self.model_type = config.pop("model_type", None) + self.corpus_files = config.pop("corpus_files", None) + for k, v in config.items(): setattr(self, k, v) class DatasetConfig: def __init__(self, config: dict = None): if not config: config = {} - self.train_paths = preprocess_paths(config.get("train_paths", None)) - self.eval_paths = preprocess_paths(config.get("eval_paths", None)) - self.test_paths = preprocess_paths(config.get("test_paths", None)) - self.tfrecords_dir = preprocess_paths(config.get("tfrecords_dir", None)) - self.additional_properties = config.get("additional_properties", {}) + self.train_paths = preprocess_paths(config.pop("train_paths", None)) + self.eval_paths = preprocess_paths(config.pop("eval_paths", None)) + self.test_paths = preprocess_paths(config.pop("test_paths", None)) + self.tfrecords_dir = preprocess_paths(config.pop("tfrecords_dir", None)) + for k, v in config.items(): setattr(self, k, v) class RunningConfig: def __init__(self, config: dict = None): if not config: config = {} - self.batch_size = config.get("batch_size", 1) - self.accumulation_steps = config.get("accumulation_steps", 1) - self.num_epochs = config.get("num_epochs", 20) - self.outdir = preprocess_paths(config.get("outdir", None)) - self.log_interval_steps = config.get("log_interval_steps", 500) - self.save_interval_steps = config.get("save_interval_steps", 500) - self.eval_interval_steps = config.get("eval_interval_steps", 1000) - self.additional_properties = config.get("additional_properties", {}) + self.batch_size = config.pop("batch_size", 1) + self.accumulation_steps = config.pop("accumulation_steps", 1) + self.num_epochs = config.pop("num_epochs", 20) + self.outdir = preprocess_paths(config.pop("outdir", None)) + self.log_interval_steps = config.pop("log_interval_steps", 500) + self.save_interval_steps = config.pop("save_interval_steps", 500) + self.eval_interval_steps = config.pop("eval_interval_steps", 1000) + for k, v in config.items(): setattr(self, k, v) class LearningConfig: def __init__(self, config: dict = None): if not config: config = {} - self.augmentations = Augmentation(config.get("augmentations")) - self.dataset_config = DatasetConfig(config.get("dataset_config")) - self.optimizer_config = config.get("optimizer_config", {}) - self.running_config = RunningConfig(config.get("running_config")) - self.additional_properties = config.get("additional_properties", {}) + self.augmentations = Augmentation(config.pop("augmentations", {})) + self.dataset_config = DatasetConfig(config.pop("dataset_config", {})) + self.optimizer_config = config.pop("optimizer_config", {}) + self.running_config = RunningConfig(config.pop("running_config", {})) + for k, v in config.items(): setattr(self, k, v) class Config: """ User config class for training, testing or infering """ - def __init__(self, path: str, learning: bool): + def __init__(self, path: str): config = load_yaml(preprocess_paths(path)) - self.speech_config = config.get("speech_config", {}) - self.decoder_config = config.get("decoder_config", {}) - self.model_config = config.get("model_config", {}) - self.additional_properties = config.get("additional_properties", {}) - if learning: - self.learning_config = LearningConfig(config.get("learning_config")) + self.speech_config = config.pop("speech_config", {}) + self.decoder_config = config.pop("decoder_config", {}) + self.model_config = config.pop("model_config", {}) + self.learning_config = LearningConfig(config.pop("learning_config", {})) + for k, v in config.items(): setattr(self, k, v) diff --git a/tensorflow_asr/datasets/__init__.py b/tensorflow_asr/datasets/__init__.py index e69de29bb2..18f3bb5640 100644 --- a/tensorflow_asr/datasets/__init__.py +++ b/tensorflow_asr/datasets/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_dataset import BaseDataset +from .asr_dataset import ASRTFRecordDataset, ASRSliceDataset, ASRTFRecordTestDataset, ASRSliceTestDataset +__all__ = ['BaseDataset', 'ASRTFRecordDataset', 'ASRSliceDataset', 'ASRTFRecordTestDataset', 'ASRSliceTestDataset'] diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index ab2e8690fc..e0d9832e74 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import abc import multiprocessing import os diff --git a/tensorflow_asr/datasets/keras/__init__.py b/tensorflow_asr/datasets/keras/__init__.py new file mode 100644 index 0000000000..c0a7bc02e1 --- /dev/null +++ b/tensorflow_asr/datasets/keras/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .asr_dataset import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras +__all__ = ['ASRTFRecordDatasetKeras', 'ASRSliceDatasetKeras'] diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py new file mode 100644 index 0000000000..a7cb310c08 --- /dev/null +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -0,0 +1,199 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import multiprocessing +import tensorflow as tf +import numpy as np + +from ..asr_dataset import ASRDataset, AUTOTUNE, TFRECORD_SHARDS, write_tfrecord_file +from ..base_dataset import BUFFER_SIZE +from ...featurizers.speech_featurizers import SpeechFeaturizer +from ...featurizers.text_featurizers import TextFeaturizer +from ...utils.utils import get_num_batches +from ...augmentations.augments import Augmentation + + +class ASRDatasetKeras(ASRDataset): + def process(self, dataset, batch_size): + dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE) + + if self.cache: + dataset = dataset.cache() + + if self.shuffle: + dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + + # PADDED BATCH the dataset + dataset = dataset.padded_batch( + batch_size=batch_size, + padded_shapes=( + { + "input": tf.TensorShape(self.speech_featurizer.shape), + "input_length": tf.TensorShape([]), + "prediction": tf.TensorShape([None]), + "prediction_length": tf.TensorShape([]) + }, + { + "label": tf.TensorShape([None]), + "label_length": tf.TensorShape([]) + }, + ), + padding_values=( + { + "input": 0., + "input_length": 0, + "prediction": self.text_featurizer.blank, + "prediction_length": 0 + }, + { + "label": self.text_featurizer.blank, + "label_length": 0 + } + ), + drop_remainder=True + ) + + # PREFETCH to improve speed of input length + dataset = dataset.prefetch(AUTOTUNE) + self.total_steps = get_num_batches(self.total_steps, batch_size) + return dataset + + +class ASRTFRecordDatasetKeras(ASRDatasetKeras): + """ Keras Dataset for ASR using TFRecords """ + + def __init__(self, + data_paths: list, + tfrecords_dir: str, + speech_featurizer: SpeechFeaturizer, + text_featurizer: TextFeaturizer, + stage: str, + augmentations: Augmentation = Augmentation(None), + tfrecords_shards: int = TFRECORD_SHARDS, + cache: bool = False, + shuffle: bool = False, + buffer_size: int = BUFFER_SIZE): + super(ASRTFRecordDatasetKeras, self).__init__( + stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, buffer_size=buffer_size + ) + self.tfrecords_dir = tfrecords_dir + if tfrecords_shards <= 0: raise ValueError("tfrecords_shards must be positive") + self.tfrecords_shards = tfrecords_shards + if not tf.io.gfile.exists(self.tfrecords_dir): + tf.io.gfile.makedirs(self.tfrecords_dir) + + def create_tfrecords(self): + if not tf.io.gfile.exists(self.tfrecords_dir): + tf.io.gfile.makedirs(self.tfrecords_dir) + + if tf.io.gfile.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")): + print(f"TFRecords're already existed: {self.stage}") + return True + + print(f"Creating {self.stage}.tfrecord ...") + + entries = self.read_entries() + if len(entries) <= 0: + return False + + def get_shard_path(shard_id): + return os.path.join(self.tfrecords_dir, f"{self.stage}_{shard_id}.tfrecord") + + shards = [get_shard_path(idx) for idx in range(1, self.tfrecords_shards + 1)] + + splitted_entries = np.array_split(entries, self.tfrecords_shards) + with multiprocessing.Pool(self.tfrecords_shards) as pool: + pool.map(write_tfrecord_file, zip(shards, splitted_entries)) + + return True + + @tf.function + def parse(self, record): + feature_description = { + "path": tf.io.FixedLenFeature([], tf.string), + "audio": tf.io.FixedLenFeature([], tf.string), + "transcript": tf.io.FixedLenFeature([], tf.string) + } + example = tf.io.parse_single_example(record, feature_description) + + features, input_length, label, label_length, \ + prediction, prediction_length = tf.numpy_function( + self.preprocess, + inp=[example["audio"], example["transcript"]], + Tout=[tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32] + ) + + return ( + { + "input": features, + "input_length": input_length, + "prediction": prediction, + "prediction_length": prediction_length + }, + { + "label": label, + "label_length": label_length + } + ) + + def create(self, batch_size): + # Create TFRecords dataset + have_data = self.create_tfrecords() + if not have_data: return None + + pattern = os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord") + files_ds = tf.data.Dataset.list_files(pattern) + ignore_order = tf.data.Options() + ignore_order.experimental_deterministic = False + files_ds = files_ds.with_options(ignore_order) + dataset = tf.data.TFRecordDataset(files_ds, compression_type='ZLIB', num_parallel_reads=AUTOTUNE) + + return self.process(dataset, batch_size) + + +class ASRSliceDatasetKeras(ASRDatasetKeras): + """ Keras Dataset for ASR using Slice """ + + def preprocess(self, path, transcript): + return super(ASRSliceDatasetKeras, self).preprocess(path.decode("utf-8"), transcript) + + @tf.function + def parse(self, record): + features, input_length, label, label_length, \ + prediction, prediction_length = tf.numpy_function( + self.preprocess, + inp=[record[0], record[1]], + Tout=[tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32] + ) + return ( + { + "input": features, + "input_length": input_length, + "prediction": prediction, + "prediction_length": prediction_length + }, + { + "label": label, + "label_length": label_length + } + ) + + def create(self, batch_size): + entries = self.read_entries() + if len(entries) == 0: return None + entries = np.delete(entries, 1, 1) # Remove unused duration + dataset = tf.data.Dataset.from_tensor_slices(entries) + return self.process(dataset, batch_size) diff --git a/tensorflow_asr/losses/__init__.py b/tensorflow_asr/losses/__init__.py index e69de29bb2..f9ae63d25d 100644 --- a/tensorflow_asr/losses/__init__.py +++ b/tensorflow_asr/losses/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .ctc_losses import ctc_loss +from .rnnt_losses import rnnt_loss +__all__ = ['ctc_loss', 'rnnt_loss'] diff --git a/tensorflow_asr/losses/keras/__init__.py b/tensorflow_asr/losses/keras/__init__.py new file mode 100644 index 0000000000..4b667418c3 --- /dev/null +++ b/tensorflow_asr/losses/keras/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .rnnt_losses import RnntLoss +from .ctc_losses import CtcLoss +__all__ = ['RnntLoss', 'CtcLoss'] diff --git a/tensorflow_asr/losses/keras/ctc_losses.py b/tensorflow_asr/losses/keras/ctc_losses.py new file mode 100644 index 0000000000..e8eddbc46c --- /dev/null +++ b/tensorflow_asr/losses/keras/ctc_losses.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow.python.keras.utils import losses_utils + +from .. import ctc_loss + + +class CtcLoss(tf.keras.losses.Loss): + def __init__(self, blank=0, global_batch_size=None, reduction=losses_utils.ReductionV2.NONE, name=None): + super(CtcLoss, self).__init__(reduction=reduction, name=name) + self.blank = blank + self.global_batch_size = global_batch_size + + def call(self, y_true, y_pred): + logits = y_pred["logit"] + logit_length = y_pred["logit_length"] + labels = y_true["label"] + label_length = y_true["label_length"] + loss = ctc_loss(labels, logits, logit_length, label_length, blank=self.blank) + return tf.nn.compute_average_loss(loss, global_batch_size=self.global_batch_size) diff --git a/tensorflow_asr/losses/keras/rnnt_losses.py b/tensorflow_asr/losses/keras/rnnt_losses.py new file mode 100644 index 0000000000..c100bb07df --- /dev/null +++ b/tensorflow_asr/losses/keras/rnnt_losses.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow.python.keras.utils import losses_utils + +from .. import rnnt_loss + + +class RnntLoss(tf.keras.losses.Loss): + def __init__(self, blank=0, global_batch_size=None, reduction=losses_utils.ReductionV2.NONE, name=None): + super(RnntLoss, self).__init__(reduction=reduction, name=name) + self.blank = blank + self.global_batch_size = global_batch_size + + def call(self, y_true, y_pred): + logits = y_pred["logit"] + logit_length = y_pred["logit_length"] + labels = y_true["label"] + label_length = y_true["label_length"] + loss = rnnt_loss(logits, labels, label_length, logit_length, blank=self.blank) + return tf.nn.compute_average_loss(loss, global_batch_size=self.global_batch_size) diff --git a/tensorflow_asr/models/keras/__init__.py b/tensorflow_asr/models/keras/__init__.py new file mode 100644 index 0000000000..c494840752 --- /dev/null +++ b/tensorflow_asr/models/keras/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transducer import Transducer +from .conformer import Conformer +__all__ = ['Transducer', 'Conformer'] diff --git a/tensorflow_asr/models/keras/conformer.py b/tensorflow_asr/models/keras/conformer.py new file mode 100644 index 0000000000..160d395ada --- /dev/null +++ b/tensorflow_asr/models/keras/conformer.py @@ -0,0 +1,79 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transducer import Transducer +from ..conformer import ConformerEncoder, L2 + + +class Conformer(Transducer): + def __init__(self, + vocabulary_size: int, + encoder_subsampling: dict, + encoder_positional_encoding: str = "sinusoid", + encoder_dmodel: int = 144, + encoder_num_blocks: int = 16, + encoder_head_size: int = 36, + encoder_num_heads: int = 4, + encoder_mha_type: str = "relmha", + encoder_kernel_size: int = 32, + encoder_depth_multiplier: int = 1, + encoder_fc_factor: float = 0.5, + encoder_dropout: float = 0, + prediction_embed_dim: int = 512, + prediction_embed_dropout: int = 0, + prediction_num_rnns: int = 1, + prediction_rnn_units: int = 320, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_layer_norm: bool = True, + prediction_projection_units: int = 0, + joint_dim: int = 1024, + joint_activation: str = "tanh", + kernel_regularizer=L2, + bias_regularizer=L2, + name: str = "conformer_transducer", + **kwargs): + super(Conformer, self).__init__( + encoder=ConformerEncoder( + subsampling=encoder_subsampling, + positional_encoding=encoder_positional_encoding, + dmodel=encoder_dmodel, + num_blocks=encoder_num_blocks, + head_size=encoder_head_size, + num_heads=encoder_num_heads, + mha_type=encoder_mha_type, + kernel_size=encoder_kernel_size, + depth_multiplier=encoder_depth_multiplier, + fc_factor=encoder_fc_factor, + dropout=encoder_dropout, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer + ), + vocabulary_size=vocabulary_size, + embed_dim=prediction_embed_dim, + embed_dropout=prediction_embed_dropout, + num_rnns=prediction_num_rnns, + rnn_units=prediction_rnn_units, + rnn_type=prediction_rnn_type, + rnn_implementation=prediction_rnn_implementation, + layer_norm=prediction_layer_norm, + projection_units=prediction_projection_units, + joint_dim=joint_dim, + joint_activation=joint_activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=name, **kwargs + ) + self.dmodel = encoder_dmodel + self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor diff --git a/tensorflow_asr/models/keras/contextnet.py b/tensorflow_asr/models/keras/contextnet.py new file mode 100644 index 0000000000..a3489c559c --- /dev/null +++ b/tensorflow_asr/models/keras/contextnet.py @@ -0,0 +1,168 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +import tensorflow as tf + +from .transducer import Transducer +from ..contextnet import ContextNetEncoder, L2 +from ...utils.utils import get_reduced_length + + +class ContextNet(Transducer): + def __init__(self, + vocabulary_size: int, + encoder_blocks: List[dict], + encoder_alpha: float = 0.5, + prediction_embed_dim: int = 512, + prediction_embed_dropout: int = 0, + prediction_num_rnns: int = 1, + prediction_rnn_units: int = 320, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_layer_norm: bool = True, + prediction_projection_units: int = 0, + joint_dim: int = 1024, + joint_activation: str = "tanh", + kernel_regularizer=L2, + bias_regularizer=L2, + name: str = "contextnet", + **kwargs): + super(ContextNet, self).__init__( + encoder=ContextNetEncoder( + blocks=encoder_blocks, + alpha=encoder_alpha, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{name}_encoder" + ), + vocabulary_size=vocabulary_size, + embed_dim=prediction_embed_dim, + embed_dropout=prediction_embed_dropout, + num_rnns=prediction_num_rnns, + rnn_units=prediction_rnn_units, + rnn_type=prediction_rnn_type, + rnn_implementation=prediction_rnn_implementation, + layer_norm=prediction_layer_norm, + projection_units=prediction_projection_units, + joint_dim=joint_dim, + joint_activation=joint_activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=name, **kwargs + ) + self.dmodel = self.encoder.blocks[-1].dmodel + self.time_reduction_factor = 1 + for block in self.encoder.blocks: self.time_reduction_factor *= block.time_reduction_factor + + def call(self, inputs, training=False, **kwargs): + enc = self.encoder([inputs['input'], inputs['input_length']], training=training, **kwargs) + pred = self.predict_net([inputs['prediction'], inputs['prediction_length']], training=training, **kwargs) + outputs = self.joint_net([enc, pred], training=training, **kwargs) + return { + 'logit': outputs, + 'logit_length': get_reduced_length(inputs['input_length'], self.time_reduction_factor) + } + + def encoder_inference(self, features: tf.Tensor, input_length: tf.Tensor): + with tf.name_scope(f"{self.name}_encoder"): + input_length = tf.expand_dims(tf.shape(features)[0], axis=0) + outputs = tf.expand_dims(features, axis=0) + outputs = self.encoder([outputs, input_length], training=False) + return tf.squeeze(outputs, axis=0) + + # -------------------------------- GREEDY ------------------------------------- + + @tf.function + def recognize(self, + features: tf.Tensor, + input_length: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = True): + """ + RNN Transducer Greedy decoding + Args: + features (tf.Tensor): a batch of padded extracted features + + Returns: + tf.Tensor: a batch of decoded transcripts + """ + encoded = self.encoder([features, input_length], training=False) + return self._perform_greedy_batch(encoded, input_length, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) + + def recognize_tflite(self, signal, predicted, prediction_states): + """ + Function to convert to tflite using greedy decoding (default streaming mode) + Args: + signal: tf.Tensor with shape [None] indicating a single audio signal + predicted: last predicted character with shape [] + prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P] + + Return: + transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32 + predicted: last predicted character with shape [] + encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P] + prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P] + """ + features = self.speech_featurizer.tf_extract(signal) + encoded = self.encoder_inference(features, tf.shape(features)[0]) + hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) + transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) + return transcript, hypothesis.index, hypothesis.states + + def recognize_tflite_with_timestamp(self, signal, predicted, states): + features = self.speech_featurizer.tf_extract(signal) + encoded = self.encoder_inference(features, tf.shape(features)[0]) + hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states) + indices = self.text_featurizer.normalize_indices(hypothesis.prediction) + upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length] + + num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) + total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step + + stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + non_blank = tf.where(tf.not_equal(upoints, 0)) + non_blank_transcript = tf.gather_nd(upoints, non_blank) + non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + + return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states + + # -------------------------------- BEAM SEARCH ------------------------------------- + + @tf.function + def recognize_beam(self, + features: tf.Tensor, + input_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = True): + """ + RNN Transducer Beam Search + Args: + features (tf.Tensor): a batch of padded extracted features + lm (bool, optional): whether to use language model. Defaults to False. + + Returns: + tf.Tensor: a batch of decoded transcripts + """ + encoded = self.encoder([features, input_length], training=False) + return self._perform_beam_search_batch(encoded, input_length, lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) diff --git a/tensorflow_asr/models/keras/ctc.py b/tensorflow_asr/models/keras/ctc.py new file mode 100644 index 0000000000..6b4cc8c304 --- /dev/null +++ b/tensorflow_asr/models/keras/ctc.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow.keras import mixed_precision as mxp + +from ..ctc import CtcModel as BaseCtcModel +from ...utils.utils import get_reduced_length +from ...losses.keras.ctc_losses import CtcLoss + + +class CtcModel(BaseCtcModel): + """ Keras CTC Model Warper """ + + def compile(self, optimizer, global_batch_size, blank=0, + loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs): + loss = CtcLoss(blank=blank, global_batch_size=global_batch_size) + optimizer_with_scale = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic') + super(CtcModel, self).compile( + optimizer=optimizer_with_scale, loss=loss, + loss_weights=loss_weights, weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + **kwargs + ) + + def train_step(self, batch): + x, y_true = batch + with tf.GradientTape() as tape: + logit = self(x['input'], training=True) + y_pred = { + 'logit': logit, + 'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor) + } + loss = self.loss(y_true, y_pred) + scaled_loss = self.optimizer.get_scaled_loss(loss) + scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights) + gradients = self.optimizer.get_unscaled_gradients(scaled_gradients) + self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + return {"train_ctc_loss": loss} + + def test_step(self, batch): + x, y_true = batch + logit = self(x, training=False) + y_pred = { + 'logit': logit, + 'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor) + } + loss = self.loss(y_true, y_pred) + return {"val_ctc_loss": loss} diff --git a/tensorflow_asr/models/keras/deepspeech2.py b/tensorflow_asr/models/keras/deepspeech2.py new file mode 100644 index 0000000000..0c685e87c5 --- /dev/null +++ b/tensorflow_asr/models/keras/deepspeech2.py @@ -0,0 +1,86 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .ctc import CtcModel +from ..deepspeech2 import ConvModule, RnnModule, FcModule + + +class DeepSpeech2(CtcModel): + def __init__(self, + vocabulary_size: int, + conv_type: str = "conv2d", + conv_kernels: list = [[11, 41], [11, 21], [11, 21]], + conv_strides: list = [[2, 2], [1, 2], [1, 2]], + conv_filters: list = [32, 32, 96], + conv_dropout: float = 0.1, + rnn_nlayers: int = 5, + rnn_type: str = "lstm", + rnn_units: int = 1024, + rnn_bidirectional: bool = True, + rnn_rowconv: int = 0, + rnn_dropout: float = 0.1, + fc_nlayers: int = 0, + fc_units: int = 1024, + fc_dropout: float = 0.1, + name: str = "deepspeech2", + **kwargs): + super(DeepSpeech2, self).__init__(name=name, **kwargs) + + self.conv_module = ConvModule( + conv_type=conv_type, + kernels=conv_kernels, + strides=conv_strides, + filters=conv_filters, + dropout=conv_dropout, + name=f"{self.name}_conv_module" + ) + + self.rnn_module = RnnModule( + nlayers=rnn_nlayers, + rnn_type=rnn_type, + units=rnn_units, + bidirectional=rnn_bidirectional, + rowconv=rnn_rowconv, + dropout=rnn_dropout, + name=f"{self.name}_rnn_module" + ) + + self.fc_module = FcModule( + nlayers=fc_nlayers, + units=fc_units, + dropout=fc_dropout, + vocabulary_size=vocabulary_size, + name=f"{self.name}_fc_module" + ) + + self.time_reduction_factor = self.conv_module.reduction_factor + + def call(self, inputs, training=False, **kwargs): + outputs = self.conv_module(inputs, training=training, **kwargs) + outputs = self.rnn_module(outputs, training=training, **kwargs) + outputs = self.fc_module(outputs, training=training, **kwargs) + return outputs + + def summary(self, line_length=100, **kwargs): + self.conv_module.summary(line_length=line_length, **kwargs) + self.rnn_module.summary(line_length=line_length, **kwargs) + self.fc_module.summary(line_length=line_length, **kwargs) + super(DeepSpeech2, self).summary(line_length=line_length, **kwargs) + + def get_config(self): + conf = super(DeepSpeech2, self).get_config() + conf.update(self.conv_module.get_config()) + conf.update(self.rnn_module.get_config()) + conf.update(self.fc_module.get_config()) + return conf diff --git a/tensorflow_asr/models/keras/jasper.py b/tensorflow_asr/models/keras/jasper.py new file mode 100644 index 0000000000..b19010d1b0 --- /dev/null +++ b/tensorflow_asr/models/keras/jasper.py @@ -0,0 +1,137 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from .ctc import CtcModel +from ..jasper import Reshape, JasperBlock, JasperSubBlock + + +class Jasper(CtcModel): + def __init__(self, + vocabulary_size: int, + dense: bool = False, + first_additional_block_channels: int = 256, + first_additional_block_kernels: int = 11, + first_additional_block_strides: int = 2, + first_additional_block_dilation: int = 1, + first_additional_block_dropout: int = 0.2, + nsubblocks: int = 5, + block_channels: list = [256, 384, 512, 640, 768], + block_kernels: list = [11, 13, 17, 21, 25], + block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], + second_additional_block_channels: int = 896, + second_additional_block_kernels: int = 1, + second_additional_block_strides: int = 1, + second_additional_block_dilation: int = 2, + second_additional_block_dropout: int = 0.4, + third_additional_block_channels: int = 1024, + third_additional_block_kernels: int = 1, + third_additional_block_strides: int = 1, + third_additional_block_dilation: int = 1, + third_additional_block_dropout: int = 0.4, + kernel_regularizer=None, + bias_regularizer=None, + name: str = "jasper", + **kwargs): + super(Jasper, self).__init__(name=name, **kwargs) + + assert len(block_channels) == len(block_kernels) == len(block_dropout) + + self.reshape = Reshape(name=f"{self.name}_reshape") + + self.first_additional_block = JasperSubBlock( + channels=first_additional_block_channels, + kernels=first_additional_block_kernels, + strides=first_additional_block_strides, + dropout=first_additional_block_dropout, + dilation=first_additional_block_dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_first_block" + ) + + self.blocks = [ + JasperBlock( + nsubblocks=nsubblocks, + channels=block_channels[i], + kernels=block_kernels[i], + dropout=block_dropout[i], + dense=dense, + nresiduals=(i + 1) if dense else 1, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_block_{i}" + ) for i in range(len(block_channels)) + ] + + self.second_additional_block = JasperSubBlock( + channels=second_additional_block_channels, + kernels=second_additional_block_kernels, + strides=second_additional_block_strides, + dropout=second_additional_block_dropout, + dilation=second_additional_block_dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_second_block" + ) + + self.third_additional_block = JasperSubBlock( + channels=third_additional_block_channels, + kernels=third_additional_block_kernels, + strides=third_additional_block_strides, + dropout=third_additional_block_dropout, + dilation=third_additional_block_dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_third_block" + ) + + self.last_block = tf.keras.layers.Conv1D( + filters=vocabulary_size, kernel_size=1, + strides=1, padding="same", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_last_block" + ) + + self.time_reduction_factor = self.first_additional_block.reduction_factor + self.time_reduction_factor *= self.second_additional_block.reduction_factor + self.time_reduction_factor *= self.third_additional_block.reduction_factor + + def call(self, inputs, training=False, **kwargs): + outputs = self.reshape(inputs) + outputs = self.first_additional_block(outputs, training=training, **kwargs) + + residuals = [] + for block in self.blocks: + outputs, residuals = block([outputs, residuals], training=training, **kwargs) + + outputs = self.second_additional_block(outputs, training=training, **kwargs) + outputs = self.third_additional_block(outputs, training=training, **kwargs) + outputs = self.last_block(outputs, training=training, **kwargs) + return outputs + + def summary(self, line_length=100, **kwargs): + super(Jasper, self).summary(line_length=line_length, **kwargs) + + def get_config(self): + conf = self.reshape.get_config() + conf.update(self.first_additional_block.get_config()) + for block in self.blocks: + conf.update(block.get_config()) + conf.update(self.second_additional_block.get_config()) + conf.update(self.third_additional_block.get_config()) + conf.update(self.last_block.get_config()) + return conf diff --git a/tensorflow_asr/models/keras/streaming_transducer.py b/tensorflow_asr/models/keras/streaming_transducer.py new file mode 100644 index 0000000000..928d37d76d --- /dev/null +++ b/tensorflow_asr/models/keras/streaming_transducer.py @@ -0,0 +1,186 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + + +from .transducer import Transducer +from ..streaming_transducer import StreamingTransducerEncoder + + +class StreamingTransducer(Transducer): + def __init__(self, + vocabulary_size: int, + encoder_reductions: dict = {0: 3, 1: 2}, + encoder_dmodel: int = 640, + encoder_nlayers: int = 8, + encoder_rnn_type: str = "lstm", + encoder_rnn_units: int = 2048, + encoder_layer_norm: bool = True, + prediction_embed_dim: int = 320, + prediction_embed_dropout: float = 0, + prediction_num_rnns: int = 2, + prediction_rnn_units: int = 2048, + prediction_rnn_type: str = "lstm", + prediction_layer_norm: bool = True, + prediction_projection_units: int = 640, + joint_dim: int = 640, + joint_activation: str = "tanh", + kernel_regularizer = None, + bias_regularizer = None, + name = "StreamingTransducer", + **kwargs): + super(StreamingTransducer, self).__init__( + encoder=StreamingTransducerEncoder( + reductions=encoder_reductions, + dmodel=encoder_dmodel, + nlayers=encoder_nlayers, + rnn_type=encoder_rnn_type, + rnn_units=encoder_rnn_units, + layer_norm=encoder_layer_norm, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{name}_encoder" + ), + vocabulary_size=vocabulary_size, + embed_dim=prediction_embed_dim, + embed_dropout=prediction_embed_dropout, + num_rnns=prediction_num_rnns, + rnn_units=prediction_rnn_units, + rnn_type=prediction_rnn_type, + layer_norm=prediction_layer_norm, + projection_units=prediction_projection_units, + joint_dim=joint_dim, + joint_activation=joint_activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=name, **kwargs + ) + self.time_reduction_factor = self.encoder.time_reduction_factor + + def encoder_inference(self, features: tf.Tensor, states: tf.Tensor): + """Infer function for encoder (or encoders) + + Args: + features (tf.Tensor): features with shape [T, F, C] + states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P] + + Returns: + tf.Tensor: output of encoders with shape [T, E] + tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P] + """ + with tf.name_scope(f"{self.name}_encoder"): + outputs = tf.expand_dims(features, axis=0) + outputs, new_states = self.encoder.recognize(outputs, states) + return tf.squeeze(outputs, axis=0), new_states + + # -------------------------------- GREEDY ------------------------------------- + + @tf.function + def recognize(self, + features: tf.Tensor, + input_length: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = True): + """ + RNN Transducer Greedy decoding + Args: + features (tf.Tensor): a batch of padded extracted features + + Returns: + tf.Tensor: a batch of decoded transcripts + """ + encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state()) + return self._perform_greedy_batch(encoded, input_length, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) + + def recognize_tflite(self, signal, predicted, encoder_states, prediction_states): + """ + Function to convert to tflite using greedy decoding (default streaming mode) + Args: + signal: tf.Tensor with shape [None] indicating a single audio signal + predicted: last predicted character with shape [] + encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P] + prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P] + + Return: + transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32 + predicted: last predicted character with shape [] + encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P] + prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P] + """ + features = self.speech_featurizer.tf_extract(signal) + encoded, new_encoder_states = self.encoder_inference(features, encoder_states) + hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) + transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) + return transcript, hypothesis.index, new_encoder_states, hypothesis.states + + def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states): + features = self.speech_featurizer.tf_extract(signal) + encoded, new_encoder_states = self.encoder_inference(features, encoder_states) + hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) + indices = self.text_featurizer.normalize_indices(hypothesis.prediction) + upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length] + + num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) + total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step + + stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + non_blank = tf.where(tf.not_equal(upoints, 0)) + non_blank_transcript = tf.gather_nd(upoints, non_blank) + non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + + return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, new_encoder_states, hypothesis.states + + # -------------------------------- BEAM SEARCH ------------------------------------- + + @tf.function + def recognize_beam(self, + features: tf.Tensor, + input_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = True): + """ + RNN Transducer Beam Search + Args: + features (tf.Tensor): a batch of padded extracted features + lm (bool, optional): whether to use language model. Defaults to False. + + Returns: + tf.Tensor: a batch of decoded transcripts + """ + encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state()) + return self._perform_beam_search_batch(encoded, input_length, lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) + + # -------------------------------- TFLITE ------------------------------------- + + def make_tflite_function(self, timestamp: bool = True): + tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite + return tf.function( + tflite_func, + input_signature=[ + tf.TensorSpec([None], dtype=tf.float32), + tf.TensorSpec([], dtype=tf.int32), + tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32), + tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32) + ] + ) diff --git a/tensorflow_asr/models/keras/transducer.py b/tensorflow_asr/models/keras/transducer.py new file mode 100644 index 0000000000..59e206aaa1 --- /dev/null +++ b/tensorflow_asr/models/keras/transducer.py @@ -0,0 +1,77 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" https://arxiv.org/pdf/1811.06621.pdf """ + +import tensorflow as tf +from tensorflow.keras import mixed_precision as mxp + +from ..transducer import Transducer as BaseTransducer +from ...utils.utils import get_reduced_length +from ...losses.keras.rnnt_losses import RnntLoss + + +class Transducer(BaseTransducer): + """ Keras Transducer Model Warper """ + + def _build(self, input_shape): + features = tf.keras.Input(shape=input_shape, dtype=tf.float32) + input_length = tf.keras.Input(shape=[], dtype=tf.int32) + pred = tf.keras.Input(shape=[None], dtype=tf.int32) + pred_length = tf.keras.Input(shape=[], dtype=tf.int32) + self({ + "input": features, + "input_length": input_length, + "prediction": pred, + "prediction_length": pred_length + }, training=True) + + def call(self, inputs, training=False, **kwargs): + features = inputs["input"] + prediction = inputs["prediction"] + prediction_length = inputs["prediction_length"] + enc = self.encoder(features, training=training, **kwargs) + pred = self.predict_net([prediction, prediction_length], training=training, **kwargs) + outputs = self.joint_net([enc, pred], training=training, **kwargs) + return { + "logit": outputs, + "logit_length": get_reduced_length(inputs["input_length"], self.time_reduction_factor) + } + + def compile(self, optimizer, global_batch_size, blank=0, + loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs): + loss = RnntLoss(blank=blank, global_batch_size=global_batch_size) + optimizer_with_scale = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic') + super(Transducer, self).compile( + optimizer=optimizer_with_scale, loss=loss, + loss_weights=loss_weights, weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + **kwargs + ) + + def train_step(self, batch): + x, y_true = batch + with tf.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.loss(y_true, y_pred) + scaled_loss = self.optimizer.get_scaled_loss(loss) + scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights) + gradients = self.optimizer.get_unscaled_gradients(scaled_gradients) + self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + return {"train_rnnt_loss": loss} + + def test_step(self, batch): + x, y_true = batch + y_pred = self(x, training=False) + loss = self.loss(y_true, y_pred) + return {"val_rnnt_loss": loss} diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index 4806d07536..2469c56767 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -284,6 +284,8 @@ def call(self, inputs, training=False, **kwargs): outputs = self.joint_net([enc, pred], training=training, **kwargs) return outputs + # -------------------------------- INFERENCES------------------------------------- + def encoder_inference(self, features: tf.Tensor): """Infer function for encoder (or encoders) diff --git a/tests/conformer/test_conformer.py b/tests/conformer/test_conformer.py index 61f4d1f21a..bdc3fbf011 100644 --- a/tests/conformer/test_conformer.py +++ b/tests/conformer/test_conformer.py @@ -25,7 +25,7 @@ def test_conformer(): - config = Config(DEFAULT_YAML, learning=False) + config = Config(DEFAULT_YAML) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/tests/contextnet/test_contextnet.py b/tests/contextnet/test_contextnet.py index 0ee8fe21e5..41ded97da9 100644 --- a/tests/contextnet/test_contextnet.py +++ b/tests/contextnet/test_contextnet.py @@ -25,7 +25,7 @@ def test_contextnet(): - config = Config(DEFAULT_YAML, learning=False) + config = Config(DEFAULT_YAML) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/tests/deepspeech2/test_ds2.py b/tests/deepspeech2/test_ds2.py index 9b8a994ae0..49fc3a0807 100644 --- a/tests/deepspeech2/test_ds2.py +++ b/tests/deepspeech2/test_ds2.py @@ -25,7 +25,7 @@ def test_ds2(): - config = Config(DEFAULT_YAML, learning=False) + config = Config(DEFAULT_YAML) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/tests/jasper/test_jasper.py b/tests/jasper/test_jasper.py index 536c91c23b..3c0de7679d 100644 --- a/tests/jasper/test_jasper.py +++ b/tests/jasper/test_jasper.py @@ -25,7 +25,7 @@ def test_jasper(): - config = Config(DEFAULT_YAML, learning=False) + config = Config(DEFAULT_YAML) text_featurizer = CharFeaturizer(config.decoder_config) diff --git a/tests/streaming_transducer/test_streaming_transducer.py b/tests/streaming_transducer/test_streaming_transducer.py index 20dddf6257..2b20fccdbc 100644 --- a/tests/streaming_transducer/test_streaming_transducer.py +++ b/tests/streaming_transducer/test_streaming_transducer.py @@ -25,7 +25,7 @@ def test_streaming_transducer(): - config = Config(DEFAULT_YAML, learning=False) + config = Config(DEFAULT_YAML) text_featurizer = CharFeaturizer(config.decoder_config)