From 26193fb7008e127765054ec014eb0e90c2db43bc Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 8 Nov 2020 20:03:04 +0700 Subject: [PATCH 1/2] :zap: Update gradient accumulation :zap: update _ --- examples/conformer/config.yml | 14 +- .../conformer/tflite_subword_conformer.py | 2 +- examples/deepspeech2/config.yml | 15 +- examples/jasper/config.yml | 15 +- examples/streaming_transducer/config.yml | 10 +- setup.py | 9 +- tensorflow_asr/models/conformer.py | 2 + tensorflow_asr/optimizers/accumulation.py | 16 +-- tensorflow_asr/runners/base_runners.py | 1 + tensorflow_asr/runners/transducer_runners.py | 2 +- tests/test_conformer.py | 128 ++++++++++++++++++ tests/test_transducer.py | 7 +- 12 files changed, 173 insertions(+), 48 deletions(-) create mode 100644 tests/test_conformer.py diff --git a/examples/conformer/config.yml b/examples/conformer/config.yml index a6e214c400..e08ae239e1 100755 --- a/examples/conformer/config.yml +++ b/examples/conformer/config.yml @@ -68,12 +68,12 @@ learning_config: dataset_config: train_paths: - - /mnt/Data/ML/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv eval_paths: - - /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-other/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv test_paths: - - /mnt/Data/ML/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv tfrecords_dir: null optimizer_config: @@ -83,10 +83,10 @@ learning_config: epsilon: 1e-9 running_config: - batch_size: 2 - accumulation_steps: 1 + batch_size: 4 + accumulation_steps: 4 num_epochs: 20 - outdir: /mnt/Projects/asrk16/trained/local/librispeech/conformer + outdir: /mnt/d/SpeechProcessing/Trained/local/conformer log_interval_steps: 300 eval_interval_steps: 500 save_interval_steps: 1000 diff --git a/examples/conformer/tflite_subword_conformer.py b/examples/conformer/tflite_subword_conformer.py index 6ea372ee41..51222ce71e 100644 --- a/examples/conformer/tflite_subword_conformer.py +++ b/examples/conformer/tflite_subword_conformer.py @@ -58,7 +58,7 @@ # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) -conformer.load_weights(args.saved) +conformer.load_weights(args.saved, by_name=True) conformer.summary(line_length=150) conformer.add_featurizers(speech_featurizer, text_featurizer) diff --git a/examples/deepspeech2/config.yml b/examples/deepspeech2/config.yml index ee43e06404..e8d8cd6957 100755 --- a/examples/deepspeech2/config.yml +++ b/examples/deepspeech2/config.yml @@ -24,11 +24,11 @@ speech_config: normalize_per_feature: False decoder_config: - vocabulary: ./vocabularies/vietnamese.characters + vocabulary: null blank_at_zero: False beam_width: 500 lm_config: - model_path: /mnt/Data/ML/NLP/vntc_asr_5gram_trie.binary + model_path: null alpha: 2.0 beta: 1.0 @@ -53,12 +53,13 @@ learning_config: dataset_config: train_paths: - - /mnt/Data/ML/ASR/Preprocessed/Vivos/train/train_transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv eval_paths: - - /mnt/Data/ML/ASR/Preprocessed/Vivos/train/eval_transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv test_paths: - - /mnt/Data/ML/ASR/Preprocessed/Vivos/test/transcripts.tsv - tfrecords_dir: /mnt/Data/ML/ASR/Preprocessed/Vivos/TFRecords + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv + tfrecords_dir: null optimizer_config: class_name: adam @@ -68,7 +69,7 @@ learning_config: running_config: batch_size: 8 num_epochs: 20 - outdir: /mnt/Projects/asrk16/trained/local/vivos + outdir: /mnt/d/SpeechProcessing/Trained/local/deepspeech2 log_interval_steps: 400 save_interval_steps: 400 eval_interval_steps: 800 diff --git a/examples/jasper/config.yml b/examples/jasper/config.yml index d6e62d8dbb..8c792f2d6f 100755 --- a/examples/jasper/config.yml +++ b/examples/jasper/config.yml @@ -24,11 +24,11 @@ speech_config: normalize_per_feature: False decoder_config: - vocabulary: ./vocabularies/vietnamese.characters + vocabulary: null blank_at_zero: False beam_width: 500 lm_config: - model_path: /mnt/Data/ML/NLP/vntc_asr_5gram_trie.binary + model_path: null alpha: 2.0 beta: 1.0 @@ -60,12 +60,13 @@ learning_config: dataset_config: train_paths: - - /mnt/Data/ML/ASR/Preprocessed/Vivos/train/train_transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv eval_paths: - - /mnt/Data/ML/ASR/Preprocessed/Vivos/train/eval_transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv test_paths: - - /mnt/Data/ML/ASR/Preprocessed/Vivos/test/transcripts.tsv - tfrecords_dir: /mnt/Data/ML/ASR/Preprocessed/Vivos/TFRecords + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv + tfrecords_dir: null optimizer_config: class_name: adam @@ -75,7 +76,7 @@ learning_config: running_config: batch_size: 8 num_epochs: 20 - outdir: /mnt/Projects/asrk16/trained/local/jasper + outdir: /mnt/d/SpeechProcessing/Trained/local/jasper log_interval_steps: 400 save_interval_steps: 400 eval_interval_steps: 800 diff --git a/examples/streaming_transducer/config.yml b/examples/streaming_transducer/config.yml index 6b205e23b0..fc1663e4ce 100755 --- a/examples/streaming_transducer/config.yml +++ b/examples/streaming_transducer/config.yml @@ -63,12 +63,12 @@ learning_config: dataset_config: train_paths: - - /mnt/Data/ML/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv eval_paths: - - /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-other/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv test_paths: - - /mnt/Data/ML/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv + - /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv tfrecords_dir: null optimizer_config: @@ -80,7 +80,7 @@ learning_config: batch_size: 2 accumulation_steps: 1 num_epochs: 20 - outdir: /mnt/Projects/asrk16/trained/local/librispeech/streaming_transducer + outdir: /mnt/SpeechProcessing/Trained/local/streaming_transducer log_interval_steps: 300 eval_interval_steps: 500 save_interval_steps: 1000 diff --git a/setup.py b/setup.py index 1fbdda09f5..ea5a11f87c 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ long_description = fh.read() requirements = [ - "tensorflow>=2.3.0", + # "tensorflow>=2.3.0", "tensorflow-datasets>=3.2.1,<4.0.0", "tensorflow-addons>=0.10.0", "setuptools>=47.1.1", @@ -26,13 +26,14 @@ "soundfile>=0.10.3", "PyYAML>=5.3.1", "matplotlib>=3.2.1", - "numpy>=1.18.5,<1.19.0", + "numpy>=1.16.0,<1.19.0", "sox>=1.3.7", "nltk>=3.5", "numba==0.49.1", - "tqdm>=4.47.0", + "tqdm>=4.51.0", "colorama>=0.4.3", - "nlpaug>=1.0.1" + "nlpaug>=1.0.1", + "absl-py>=0.9,<0.11" ] setuptools.setup( diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py index 5a02fc2b87..a88fefd297 100755 --- a/tensorflow_asr/models/conformer.py +++ b/tensorflow_asr/models/conformer.py @@ -372,6 +372,7 @@ def __init__(self, num_heads: int = 4, mha_type: str = "relmha", kernel_size: int = 32, + depth_multiplier: int = 1, fc_factor: float = 0.5, dropout: float = 0, embed_dim: int = 512, @@ -395,6 +396,7 @@ def __init__(self, num_heads=num_heads, mha_type=mha_type, kernel_size=kernel_size, + depth_multiplier=depth_multiplier, fc_factor=fc_factor, dropout=dropout, kernel_regularizer=kernel_regularizer, diff --git a/tensorflow_asr/optimizers/accumulation.py b/tensorflow_asr/optimizers/accumulation.py index 5947f51983..a5f3e8950e 100644 --- a/tensorflow_asr/optimizers/accumulation.py +++ b/tensorflow_asr/optimizers/accumulation.py @@ -19,25 +19,15 @@ class GradientAccumulation: def __init__(self, trainable_variables): self.gradients = [ tf.Variable( - tf.zeros_like(self.flat_gradients(g)), + tf.zeros_like(g), synchronization=tf.VariableSynchronization.ON_READ ) for g in trainable_variables ] - @staticmethod - def flat_gradients(gradient): - """ Convert gradients if it's tf.IndexedSlices. """ - if type(gradient) == tf.IndexedSlices: - return tf.scatter_nd( - tf.expand_dims(gradient.indices, 1), - gradient.values, - gradient.dense_shape - ) - return gradient - def reset(self): for g in self.gradients: g.assign(tf.zeros_like(g)) def accumulate(self, step_gradients): for i, g in enumerate(step_gradients): - self.gradients[i].assign_add(self.flat_gradients(g)) + if g is None: continue + self.gradients[i].assign_add(g) diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index 76b0624979..7524d97b13 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -120,6 +120,7 @@ def set_train_data_loader(self, train_dataset, train_bs=None, train_acs=None): self.config.batch_size = train_bs # Update batch size fed from arguments if not train_acs: train_acs = self.config.accumulation_steps + assert train_bs % train_acs == 0, "Batch size must be a multiple of Accumulation Steps" self.accumulation_bs = train_bs // train_acs self.config.accumulation_steps = train_acs # update accum steps fed from arguments diff --git a/tensorflow_asr/runners/transducer_runners.py b/tensorflow_asr/runners/transducer_runners.py index d961a922db..9ead0662e8 100644 --- a/tensorflow_asr/runners/transducer_runners.py +++ b/tensorflow_asr/runners/transducer_runners.py @@ -96,7 +96,7 @@ def _train_step(self, batch): self.accumulation.reset() - for accum_step in range(self.config.accumulation_step): + for accum_step in range(self.config.accumulation_steps): indices = tf.expand_dims( tf.range( diff --git a/tests/test_conformer.py b/tests/test_conformer.py new file mode 100644 index 0000000000..fedcb66298 --- /dev/null +++ b/tests/test_conformer.py @@ -0,0 +1,128 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +# import datetime +import sys +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" +import tensorflow as tf + +from tensorflow_asr.models.conformer import Conformer +from tensorflow_asr.models.transducer import Transducer +from tensorflow_asr.models.layers.subsampling import Conv2dSubsampling +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer, read_raw_audio + +text_featurizer = CharFeaturizer({ + "vocabulary": None, + "blank_at_zero": True, + "beam_width": 5, + "norm_score": True +}) + +speech_featurizer = TFSpeechFeaturizer({ + "sample_rate": 16000, + "frame_ms": 25, + "stride_ms": 10, + "num_feature_bins": 80, + "feature_type": "log_mel_spectrogram", + "preemphasis": 0.97, + "normalize_signal": True, + "normalize_feature": True, + "normalize_per_feature": False +}) + +# i = tf.keras.Input(shape=[None, 80, 1]) +# o = Conv2dSubsampling(144)(i) + +# encoder = tf.keras.Model(inputs=i, outputs=o) +# model = Transducer(encoder=encoder, vocabulary_size=text_featurizer.num_classes) + +model = Conformer( + subsampling={"type": "conv2d", "filters": 144, "kernel_size": 3, + "strides": 2}, + num_blocks=1, + vocabulary_size=text_featurizer.num_classes) + +model._build(speech_featurizer.shape) +model.summary(line_length=150) + +model.save_weights("/tmp/transducer.h5") + +model.add_featurizers( + speech_featurizer=speech_featurizer, + text_featurizer=text_featurizer +) + +# features = tf.zeros(shape=[5, 50, 80, 1], dtype=tf.float32) +# pred = model.recognize(features) +# print(pred) +# pred = model.recognize_beam(features) +# print(pred) + +# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") +# logdir = '/tmp/logs/func/%s' % stamp +# writer = tf.summary.create_file_writer(logdir) +# +signal = read_raw_audio(sys.argv[1], speech_featurizer.sample_rate) +# +# tf.summary.trace_on(graph=True, profiler=True) +# hyps = model.recognize_tflite(signal, 0, tf.zeros([1, 2, 1, 320], dtype=tf.float32)) +# with writer.as_default(): +# tf.summary.trace_export( +# name="recognize_tflite", +# step=0, +# profiler_outdir=logdir) +# +# print(hyps[0]) +# +# # hyps = model.recognize_beam(features) +# +# + +# hyps = model.recognize_beam(tf.expand_dims(speech_featurizer.tf_extract(signal), 0)) + +# print(hyps) + +# hyps = model.recognize_beam_tflite(signal) + +# print(hyps.numpy().decode("utf-8")) + +concrete_func = model.make_tflite_function(greedy=True).get_concrete_function() +converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) +converter.optimizations = [tf.lite.Optimize.DEFAULT] +converter.experimental_new_converter = True +converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, + tf.lite.OpsSet.SELECT_TF_OPS] +tflite = converter.convert() + +tflitemodel = tf.lite.Interpreter(model_content=tflite) + +input_details = tflitemodel.get_input_details() +output_details = tflitemodel.get_output_details() +tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape) +tflitemodel.allocate_tensors() +tflitemodel.set_tensor(input_details[0]["index"], signal) +tflitemodel.set_tensor( + input_details[1]["index"], + tf.constant(text_featurizer.blank, dtype=tf.int32) +) +tflitemodel.set_tensor( + input_details[2]["index"], + tf.zeros([1, 2, 1, 320], dtype=tf.float32) +) +tflitemodel.invoke() +hyp = tflitemodel.get_tensor(output_details[0]["index"]) + +print(hyp) diff --git a/tests/test_transducer.py b/tests/test_transducer.py index 7007dee216..e221e76c7f 100644 --- a/tests/test_transducer.py +++ b/tests/test_transducer.py @@ -41,7 +41,8 @@ "normalize_per_feature": False }) -model = StreamingTransducer(vocabulary_size=text_featurizer.num_classes) +model = StreamingTransducer(vocabulary_size=text_featurizer.num_classes, + encoder_dmodel=320, encoder_nlayers=3) model._build(speech_featurizer.shape) model.summary(line_length=150) @@ -108,11 +109,11 @@ ) tflitemodel.set_tensor( input_details[2]["index"], - tf.zeros([8, 2, 1, 1024], dtype=tf.float32) + tf.zeros([3, 2, 1, 2048], dtype=tf.float32) ) tflitemodel.set_tensor( input_details[3]["index"], - tf.zeros([2, 2, 1, 1024], dtype=tf.float32) + tf.zeros([2, 2, 1, 2048], dtype=tf.float32) ) tflitemodel.invoke() hyp = tflitemodel.get_tensor(output_details[0]["index"]) From 263855effb35c7770cb7c00ee24bee1a85cf1051 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sat, 14 Nov 2020 23:07:39 +0700 Subject: [PATCH 2/2] :zap: Supported Gradients Accumulation --- README.md | 1 + examples/conformer/train_ga_conformer.py | 6 +- .../conformer/train_ga_subword_conformer.py | 6 +- .../train_ga_streaming_transducer.py | 6 +- .../train_ga_subword_streaming_transducer.py | 6 +- setup.py | 4 +- tensorflow_asr/optimizers/accumulation.py | 4 +- tensorflow_asr/runners/__init__.py | 4 +- tensorflow_asr/runners/base_runners.py | 10 +-- tensorflow_asr/runners/transducer_runners.py | 65 ++++++++----------- 10 files changed, 62 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index a66906e29c..921276be1e 100755 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as ## What's New? +- (11/14/2020) Supported Gradient Accumulation for Training in Larger Batch Size - (11/3/2020) Reduce differences between `librosa.stft` and `tf.signal.stft` - (10/31/2020) Update DeepSpeech2 and Supported Jasper [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288) - (10/18/2020) Supported Streaming Transducer [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621) diff --git a/examples/conformer/train_ga_conformer.py b/examples/conformer/train_ga_conformer.py index ec8c5404bc..ea30f45a2c 100644 --- a/examples/conformer/train_ga_conformer.py +++ b/examples/conformer/train_ga_conformer.py @@ -41,6 +41,9 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--acs", type=int, default=None, + help="Train accumulation steps") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -125,4 +128,5 @@ conformer_trainer.compile(model=conformer, optimizer=optimizer, max_to_keep=args.max_ckpts) -conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) +conformer_trainer.fit(train_dataset, eval_dataset, + train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs) diff --git a/examples/conformer/train_ga_subword_conformer.py b/examples/conformer/train_ga_subword_conformer.py index a384a14c14..52cd3f8ae3 100644 --- a/examples/conformer/train_ga_subword_conformer.py +++ b/examples/conformer/train_ga_subword_conformer.py @@ -41,6 +41,9 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--acs", type=int, default=None, + help="Train accumulation steps") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -141,4 +144,5 @@ conformer_trainer.compile(model=conformer, optimizer=optimizer, max_to_keep=args.max_ckpts) -conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) +conformer_trainer.fit(train_dataset, eval_dataset, + train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs) diff --git a/examples/streaming_transducer/train_ga_streaming_transducer.py b/examples/streaming_transducer/train_ga_streaming_transducer.py index b64c50c0ee..85452c7705 100644 --- a/examples/streaming_transducer/train_ga_streaming_transducer.py +++ b/examples/streaming_transducer/train_ga_streaming_transducer.py @@ -40,6 +40,9 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--acs", type=int, default=None, + help="Train accumulation steps") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -116,4 +119,5 @@ streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer, max_to_keep=args.max_ckpts) -streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) +streaming_transducer_trainer.fit(train_dataset, eval_dataset, + train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs) diff --git a/examples/streaming_transducer/train_ga_subword_streaming_transducer.py b/examples/streaming_transducer/train_ga_subword_streaming_transducer.py index 70e16cd38c..ce3d0c0dbd 100644 --- a/examples/streaming_transducer/train_ga_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_ga_subword_streaming_transducer.py @@ -40,6 +40,9 @@ parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") +parser.add_argument("--acs", type=int, default=None, + help="Train accumulation steps") + parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") @@ -132,4 +135,5 @@ streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer, max_to_keep=args.max_ckpts) -streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) +streaming_transducer_trainer.fit(train_dataset, eval_dataset, + train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs) diff --git a/setup.py b/setup.py index ea5a11f87c..53a742a3ff 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ long_description = fh.read() requirements = [ - # "tensorflow>=2.3.0", + "tensorflow>=2.3.0", "tensorflow-datasets>=3.2.1,<4.0.0", "tensorflow-addons>=0.10.0", "setuptools>=47.1.1", @@ -38,7 +38,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.2.10", + version="0.3.0", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/optimizers/accumulation.py b/tensorflow_asr/optimizers/accumulation.py index a5f3e8950e..79c37965e8 100644 --- a/tensorflow_asr/optimizers/accumulation.py +++ b/tensorflow_asr/optimizers/accumulation.py @@ -20,12 +20,14 @@ def __init__(self, trainable_variables): self.gradients = [ tf.Variable( tf.zeros_like(g), + trainable=False, synchronization=tf.VariableSynchronization.ON_READ ) for g in trainable_variables ] def reset(self): - for g in self.gradients: g.assign(tf.zeros_like(g)) + for i, g in enumerate(self.gradients): + self.gradients[i].assign(tf.zeros_like(g)) def accumulate(self, step_gradients): for i, g in enumerate(step_gradients): diff --git a/tensorflow_asr/runners/__init__.py b/tensorflow_asr/runners/__init__.py index 52e75e7ec5..b750f2499e 100644 --- a/tensorflow_asr/runners/__init__.py +++ b/tensorflow_asr/runners/__init__.py @@ -28,8 +28,8 @@ def save_from_checkpoint(func, max_to_keep: number of checkpoints to keep **kwargs: contains built models, optimizers """ - steps = tf.Variable(0, dtype=tf.int64) # Step must be int64 - epochs = tf.Variable(1) + steps = tf.Variable(0, trainable=False, dtype=tf.int64) # Step must be int64 + epochs = tf.Variable(1, trainable=False) checkpoint_dir = os.path.join(outdir, "checkpoints") if not os.path.exists(checkpoint_dir): raise ValueError(f"checkpoint directory not found: {checkpoint_dir}") diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index 7524d97b13..9565d027f8 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -72,7 +72,8 @@ def __init__(self, super(BaseTrainer, self).__init__(config) self.set_strategy(strategy) # Steps and Epochs start from 0 - self.steps = tf.Variable(0, dtype=tf.int64) # Step must be int64 to use tf.summary + # Step must be int64 to use tf.summary + self.steps = tf.Variable(0, trainable=False, dtype=tf.int64) self.train_steps_per_epoch = None self.eval_steps_per_epoch = None # Dataset @@ -120,13 +121,14 @@ def set_train_data_loader(self, train_dataset, train_bs=None, train_acs=None): self.config.batch_size = train_bs # Update batch size fed from arguments if not train_acs: train_acs = self.config.accumulation_steps - assert train_bs % train_acs == 0, "Batch size must be a multiple of Accumulation Steps" - self.accumulation_bs = train_bs // train_acs self.config.accumulation_steps = train_acs # update accum steps fed from arguments self.train_data = train_dataset.create(self.global_batch_size) self.train_data_loader = self.strategy.experimental_distribute_dataset(self.train_data) - self.train_steps_per_epoch = train_dataset.total_steps + if hasattr(self, "accumulation"): + self.train_steps_per_epoch = train_dataset.total_steps // self.config.accumulation_steps + else: + self.train_steps_per_epoch = train_dataset.total_steps def set_eval_data_loader(self, eval_dataset, eval_bs=None): """ Set eval data loader (MUST). diff --git a/tensorflow_asr/runners/transducer_runners.py b/tensorflow_asr/runners/transducer_runners.py index 9ead0662e8..6664c37098 100644 --- a/tensorflow_asr/runners/transducer_runners.py +++ b/tensorflow_asr/runners/transducer_runners.py @@ -90,48 +90,39 @@ def compile(self, class TransducerTrainerGA(TransducerTrainer): """ Transducer Trainer that uses Gradients Accumulation """ - @tf.function(experimental_relax_shapes=True) - def _train_step(self, batch): - _, bfeatures, binput_length, blabels, blabel_length, bpred_inp = batch - + @tf.function + def _train_function(self, iterator): + for _ in range(self.config.accumulation_steps): + batch = next(iterator) + self.strategy.run(self._train_step, args=(batch,)) + self.strategy.run(self._apply_gradients, args=()) + + @tf.function + def _apply_gradients(self): + self.optimizer.apply_gradients( + zip(self.accumulation.gradients, self.model.trainable_variables)) self.accumulation.reset() - for accum_step in range(self.config.accumulation_steps): + @tf.function(experimental_relax_shapes=True) + def _train_step(self, batch): + _, features, input_length, labels, label_length, pred_inp = batch - indices = tf.expand_dims( - tf.range( - accum_step * self.accumulation_bs, - (accum_step + 1) * self.accumulation_bs, - dtype=tf.int32 - ), - axis=-1 + with tf.GradientTape() as tape: + logits = self.model([features, pred_inp], training=True) + tape.watch(logits) + per_train_loss = rnnt_loss( + logits=logits, labels=labels, label_length=label_length, + logit_length=(input_length // self.model.time_reduction_factor), + blank=self.text_featurizer.blank + ) + train_loss = tf.nn.compute_average_loss( + per_train_loss, + global_batch_size=self.global_batch_size ) - features = tf.gather_nd(bfeatures, indices) - input_length = tf.gather_nd(binput_length, indices) - labels = tf.gather_nd(blabels, indices) - label_length = tf.gather_nd(blabel_length, indices) - pred_inp = tf.gather_nd(bpred_inp, indices) - - with tf.GradientTape() as tape: - logits = self.model([features, pred_inp], training=True) - tape.watch(logits) - per_train_loss = rnnt_loss( - logits=logits, labels=labels, label_length=label_length, - logit_length=(input_length // self.model.time_reduction_factor), - blank=self.text_featurizer.blank - ) - train_loss = tf.nn.compute_average_loss( - per_train_loss, - global_batch_size=self.global_batch_size - ) - - step_gradients = tape.gradient(train_loss, self.model.trainable_variables) - self.accumulation.accumulate(step_gradients) - self.train_metrics["transducer_loss"].update_state(per_train_loss) - - self.optimizer.apply_gradients( - zip(self.accumulation.gradients, self.model.trainable_variables)) + gradients = tape.gradient(train_loss, self.model.trainable_variables) + self.accumulation.accumulate(gradients) + self.train_metrics["transducer_loss"].update_state(per_train_loss) def compile(self, model: Transducer,