From 159abe94ac13241e27caead332b3d2e88172361b Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 01:04:34 +0700 Subject: [PATCH 01/12] :zap: add max length to text and speech featurizers --- tensorflow_asr/datasets/asr_dataset.py | 6 +-- .../featurizers/speech_featurizers.py | 20 +++++--- .../featurizers/text_featurizers.py | 47 ++++++++++++------- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index d5631e3ea0..709754cbe7 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing import os - import numpy as np import tensorflow as tf @@ -129,9 +127,9 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): tf.TensorShape([]), tf.TensorShape(self.speech_featurizer.shape), tf.TensorShape([]), - tf.TensorShape([None]), + tf.TensorShape(self.text_featurizer.shape), tf.TensorShape([]), - tf.TensorShape([None]), + tf.TensorShape(self.text_featurizer.prepand_shape), tf.TensorShape([]), ), padding_values=("", 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0), diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 8288b9f643..65b40eb076 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -177,7 +177,7 @@ def map_fn(elem): class SpeechFeaturizer(metaclass=abc.ABCMeta): - def __init__(self, speech_config: dict): + def __init__(self, speech_config: dict, tpu: bool = False): """ We should use TFSpeechFeaturizer for training to avoid differences between tf and librosa when converting to tflite in post-training stage @@ -207,6 +207,9 @@ def __init__(self, speech_config: dict): self.normalize_signal = speech_config.get("normalize_signal", True) self.normalize_feature = speech_config.get("normalize_feature", True) self.normalize_per_feature = speech_config.get("normalize_per_feature", False) + # Length + self.tpu = tpu + self.max_length = 0 @property def nfft(self) -> int: @@ -218,6 +221,9 @@ def shape(self) -> list: """ The shape of extracted features """ raise NotImplementedError() + def update_length(self, length: int): + self.max_length = max(self.max_length, length) + @abc.abstractclassmethod def stft(self, signal): raise NotImplementedError() @@ -233,8 +239,8 @@ def extract(self, signal): class NumpySpeechFeaturizer(SpeechFeaturizer): - def __init__(self, speech_config: dict): - super(NumpySpeechFeaturizer, self).__init__(speech_config) + def __init__(self, speech_config: dict, tpu: bool = False): + super(NumpySpeechFeaturizer, self).__init__(speech_config, tpu) self.delta = speech_config.get("delta", False) self.delta_delta = speech_config.get("delta_delta", False) self.pitch = speech_config.get("pitch", False) @@ -253,7 +259,9 @@ def shape(self) -> list: if self.pitch: channel_dim += 1 - return [None, self.num_feature_bins, channel_dim] + length = self.max_length if (self.max_length > 0 and self.tpu) else None + + return [length, self.num_feature_bins, channel_dim] def stft(self, signal): return np.square( @@ -383,8 +391,8 @@ def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray: class TFSpeechFeaturizer(SpeechFeaturizer): @property def shape(self) -> list: - # None for time dimension - return [None, self.num_feature_bins, 1] 
+ length = self.max_length if (self.max_length > 0 and self.tpu) else None + return [length, self.num_feature_bins, 1] def stft(self, signal): return tf.square( diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index ef467b08eb..7ec805cc90 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -30,13 +30,26 @@ class TextFeaturizer(metaclass=abc.ABCMeta): - def __init__(self, decoder_config: dict): + def __init__(self, decoder_config: dict, tpu: bool = False): self.scorer = None self.decoder_config = DecoderConfig(decoder_config) self.blank = None self.tokens2indices = {} self.tokens = [] self.num_classes = None + self.tpu = tpu + self.max_length = 0 + + @property + def shape(self) -> list: + return [self.max_length if (self.max_length > 0 and self.tpu) else None] + + @property + def prepand_shape(self) -> list: + return [self.max_length + 1 if (self.max_length > 0 and self.tpu) else None] + + def update_length(self, length: int): + self.max_length = max(self.max_length, length) def preprocess_text(self, text): text = unicodedata.normalize("NFC", text.lower()) @@ -84,7 +97,7 @@ class CharFeaturizer(TextFeaturizer): converted to a sequence of integer indexes. """ - def __init__(self, decoder_config: dict): + def __init__(self, decoder_config: dict, tpu: bool = False): """ decoder_config = { "vocabulary": str, @@ -95,7 +108,7 @@ def __init__(self, decoder_config: dict): } } """ - super(CharFeaturizer, self).__init__(decoder_config) + super(CharFeaturizer, self).__init__(decoder_config, tpu) self.__init_vocabulary() def __init_vocabulary(self): @@ -178,7 +191,7 @@ class SubwordFeaturizer(TextFeaturizer): converted to a sequence of integer indexes. 
""" - def __init__(self, decoder_config: dict, subwords=None): + def __init__(self, decoder_config: dict, subwords=None, tpu: bool = False): """ decoder_config = { "target_vocab_size": int, @@ -191,7 +204,7 @@ def __init__(self, decoder_config: dict, subwords=None): } } """ - super(SubwordFeaturizer, self).__init__(decoder_config) + super(SubwordFeaturizer, self).__init__(decoder_config, tpu) self.subwords = self.__load_subwords() if subwords is None else subwords self.blank = 0 # subword treats blank as 0 self.num_classes = self.subwords.vocab_size @@ -210,7 +223,7 @@ def __load_subwords(self): return tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) @classmethod - def build_from_corpus(cls, decoder_config: dict, corpus_files: list = None): + def build_from_corpus(cls, decoder_config: dict, corpus_files: list = None, tpu: bool = False): dconf = DecoderConfig(decoder_config.copy()) corpus_files = dconf.corpus_files if corpus_files is None or len(corpus_files) == 0 else corpus_files @@ -230,15 +243,15 @@ def corpus_generator(): dconf.max_corpus_chars, dconf.reserved_tokens ) - return cls(decoder_config, subwords) + return cls(decoder_config, subwords, tpu) @classmethod - def load_from_file(cls, decoder_config: dict, filename: str = None): + def load_from_file(cls, decoder_config: dict, filename: str = None, tpu: bool = False): dconf = DecoderConfig(decoder_config.copy()) filename = dconf.vocabulary if filename is None else preprocess_paths(filename) filename_prefix = os.path.splitext(filename)[0] subwords = tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) - return cls(decoder_config, subwords) + return cls(decoder_config, subwords, tpu) def save_to_file(self, filename: str = None): filename = self.decoder_config.vocabulary if filename is None else preprocess_paths(filename) @@ -316,8 +329,8 @@ class SentencePieceFeaturizer(TextFeaturizer): EOS_TOKEN, EOS_TOKEN_ID = "", 3 PAD_TOKEN, PAD_TOKEN_ID = "", 0 # unused, by default - def __init__(self, decoder_config: dict, model=None): - super().__init__(decoder_config) + def __init__(self, decoder_config: dict, model=None, tpu: bool = False): + super(SentencePieceFeaturizer, self).__init__(decoder_config, tpu) self.model = model self.blank = 0 # treats blank as 0 (pad) self.upoints = None @@ -333,11 +346,11 @@ def __init_upoints(self): self.upoints = self.upoints.to_tensor() # [num_classes, max_subword_length] @classmethod - def build_from_corpus(cls, decoder_config: dict): + def build_from_corpus(cls, decoder_config: dict, tpu: bool = False): """ - --model_prefix: output model name prefix. .model and .vocab are generated. + --model_prefix: output model name prefix. .model and .vocab are generated. --vocab_size: vocabulary size, e.g., 8000, 16000, or 32000 - --model_type: model type. Choose from unigram (default), bpe, char, or word. + --model_type: model type. Choose from unigram (default), bpe, char, or word. 
The input sentence must be pretokenized when using word type.""" decoder_cfg = DecoderConfig(decoder_config) # Train SentencePiece Model @@ -381,17 +394,17 @@ def corpus_iterator(): for _, s in sorted(vocab.items(), key=lambda x: x[0]): f_out.write(f"{s} 1\n") - return cls(decoder_config, processor) + return cls(decoder_config, processor, tpu) @classmethod - def load_from_file(cls, decoder_config: dict, filename: str = None): + def load_from_file(cls, decoder_config: dict, filename: str = None, tpu: bool = False): if filename is not None: filename_prefix = os.path.splitext(preprocess_paths(filename))[0] else: filename_prefix = decoder_config.get("output_path_prefix", None) processor = sp.SentencePieceProcessor() processor.load(filename_prefix + ".model") - return cls(decoder_config, processor) + return cls(decoder_config, processor, tpu) def extract(self, text: str) -> tf.Tensor: """ From b269ff797641596866353eab172cc2da927a2cd1 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 16:33:14 +0700 Subject: [PATCH 02/12] :zap: add compute max lengths and update create tfrecords --- examples/conformer/config.yml | 17 ++- .../conformer/train_tpu_subword_conformer.py | 124 ++++++++++++++++++ scripts/create_librispeech_trans.py | 8 +- scripts/create_mls_trans.py | 19 +-- tensorflow_asr/datasets/asr_dataset.py | 93 ++++++++++--- .../featurizers/speech_featurizers.py | 7 +- tensorflow_asr/models/ctc.py | 4 +- tensorflow_asr/models/keras/transducer.py | 14 +- tensorflow_asr/models/transducer.py | 10 +- tensorflow_asr/utils/utils.py | 4 + 10 files changed, 244 insertions(+), 56 deletions(-) create mode 100644 examples/conformer/train_tpu_subword_conformer.py diff --git a/examples/conformer/config.yml b/examples/conformer/config.yml index 594a61aa66..84b2636df3 100755 --- a/examples/conformer/config.yml +++ b/examples/conformer/config.yml @@ -31,9 +31,9 @@ decoder_config: beam_width: 5 norm_score: True corpus_files: - - /media/nlhuy/Data/ML/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-100/transcripts.tsv - - /media/nlhuy/Data/ML/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-360/transcripts.tsv - - /media/nlhuy/Data/ML/ASR/Raw/LibriSpeech/LibriSpeech/train-other-500/transcripts.tsv + - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-100/transcripts.tsv + - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-360/transcripts.tsv + - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-other-500/transcripts.tsv model_config: name: conformer @@ -77,32 +77,37 @@ learning_config: mask_factor: 27 data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-360/transcripts.tsv + - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-other-500/transcripts.tsv + tfrecords_dir: /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_4076 shuffle: True cache: True buffer_size: 100 drop_remainder: True + stage: train eval_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_4076 shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: eval 
test_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: test optimizer_config: warmup_steps: 40000 diff --git a/examples/conformer/train_tpu_subword_conformer.py b/examples/conformer/train_tpu_subword_conformer.py new file mode 100644 index 0000000000..ca9b643741 --- /dev/null +++ b/examples/conformer/train_tpu_subword_conformer.py @@ -0,0 +1,124 @@ +# Copyright 2021 M. Yusuf Sarıgöz (@monatis) and Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import math +import argparse +from tensorflow_asr.utils import setup_environment, setup_tpu + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") + +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +parser.add_argument("--bs", type=int, default=None, help="Common training and evaluation batch size per TPU core") + +parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab") + +parser.add_argument("--max_lengths_dir", type=str, default="~", + help="Path to file containing max lengths. 
Will be computed if not exists") + +parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths") + +parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") + +parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") + +parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer +from tensorflow_asr.runners.transducer_runners import TransducerTrainer +from tensorflow_asr.models.conformer import Conformer +from tensorflow_asr.optimizers.schedules import TransformerSchedule + +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.sentence_piece: + print("Loading SentencePiece model ...") + text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) +elif args.subwords and os.path.exists(args.subwords): + print("Loading subwords ...") + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) +else: + print("Generating subwords ...") + text_featurizer = SubwordFeaturizer.build_from_corpus( + config.decoder_config, + corpus_files=args.subwords_corpus + ) + text_featurizer.save_to_file(args.subwords) + +train_dataset = ASRTFRecordDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.train_dataset_config) +) + +eval_dataset = ASRTFRecordDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.eval_dataset_config) +) + +if args.compute_lengths: + train_dataset.update_lengths(args.max_lengths_dir) + eval_dataset.update_lengths(args.max_lengths_dir) + +# Update max lengths calculated from both train and eval datasets +train_dataset.load_max_lengths(args.max_lengths_dir) +eval_dataset.load_max_lengths(args.max_lengths_dir) + +strategy = setup_tpu(args.tpu_address) + +conformer_trainer = TransducerTrainer( + config=config.learning_config.running_config, + text_featurizer=text_featurizer, strategy=strategy +) + +with conformer_trainer.strategy.scope(): + # build model + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) + conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, + batch_size=args.bs if args.bs is not None else config.learning_config.running_config.batch_size) + conformer.summary(line_length=120) + + optimizer = tf.keras.optimizers.Adam( + TransformerSchedule( + d_model=conformer.dmodel, + warmup_steps=config.learning_config.optimizer_config["warmup_steps"], + max_lr=(0.05 / math.sqrt(conformer.dmodel)) + ), + beta_1=config.learning_config.optimizer_config["beta1"], + beta_2=config.learning_config.optimizer_config["beta2"], + epsilon=config.learning_config.optimizer_config["epsilon"] + ) + +conformer_trainer.compile(model=conformer, optimizer=optimizer, max_to_keep=args.max_ckpts) + +conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.bs, eval_bs=args.bs) 
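Illustrative sketch (not part of the patch series): the payoff of the max-length bookkeeping used by the script above is that, once update_lengths/load_max_lengths have run, the featurizers report static shapes instead of None in the time dimension, so padded_batch and _build see fully static tensors, which TPU/XLA compilation requires. A minimal example with made-up maxima, assuming both featurizers were constructed with tpu=True (this patch gates static shapes behind that flag; a later patch in the series drops it) and an 80-bin speech config:

    speech_featurizer.update_length(2974)   # hypothetical longest input, in feature frames
    text_featurizer.update_length(207)      # hypothetical longest label, in subword tokens
    speech_featurizer.shape                 # [2974, 80, 1] instead of [None, 80, 1]
    text_featurizer.shape                   # [207]
    text_featurizer.prepand_shape           # [208]: the prediction input carries one extra (blank) token
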
diff --git a/scripts/create_librispeech_trans.py b/scripts/create_librispeech_trans.py index 2c6c30de64..9a84cb4039 100644 --- a/scripts/create_librispeech_trans.py +++ b/scripts/create_librispeech_trans.py @@ -23,11 +23,9 @@ parser = argparse.ArgumentParser(prog="Setup LibriSpeech Transcripts") -parser.add_argument("--dir", "-d", type=str, - default=None, help="Directory of dataset") +parser.add_argument("--dir", "-d", type=str, default=None, help="Directory of dataset") -parser.add_argument("output", type=str, - default=None, help="The output .tsv transcript file path") +parser.add_argument("output", type=str, default=None, help="The output .tsv transcript file path") args = parser.parse_args() @@ -50,7 +48,7 @@ y, sr = librosa.load(audio_file, sr=None) duration = librosa.get_duration(y, sr) text = unicodedata.normalize("NFC", line[1].lower()) - transcripts.append(f"{audio_file}\t{duration:.2f}\t{text}\n") + transcripts.append(f"{audio_file}\t{duration}\t{text}\n") with open(args.output, "w", encoding="utf-8") as out: out.write("PATH\tDURATION\tTRANSCRIPT\n") diff --git a/scripts/create_mls_trans.py b/scripts/create_mls_trans.py index a196927269..1abaef74aa 100644 --- a/scripts/create_mls_trans.py +++ b/scripts/create_mls_trans.py @@ -42,11 +42,12 @@ chars = set() + def prepare_split(dataset_dir, split, opus=False): # Setup necessary paths split_home = os.path.join(dataset_dir, split) transcripts_infile = os.path.join(split_home, 'transcripts.txt') - transcripts_outfile = os.path.join(split_home, 'transcripts_tfasr.tsv') + transcripts_outfile = os.path.join(split_home, 'transcripts_tfasr.tsv') audio_home = os.path.join(split_home, "audio") extension = ".opus" if opus else ".flac" transcripts = [] @@ -59,7 +60,7 @@ def prepare_split(dataset_dir, split, opus=False): audio_path = os.path.join(audio_home, speaker_id, book_id, f"{file_id}{extension}") y, sr = librosa.load(audio_path, sr=None) duration = librosa.get_duration(y, sr) - transcripts.append(f"{audio_path}\t{duration:2f}\t{transcript}\n") + transcripts.append(f"{audio_path}\t{duration}\t{transcript}\n") for char in transcript: chars.add(char) @@ -83,10 +84,12 @@ def make_alphabet_file(filepath, chars_list, lang): if __name__ == "__main__": ap = argparse.ArgumentParser(description="Download and prepare MLS dataset in a given language") - ap.add_argument("--dataset-home", "-d", help="Path to home directory to download and prepare dataset. Default to ~/.keras", default=None, required=False) - ap.add_argument("--language", "-l", type=str, choices=langs, help="Any name of language included in MLS", default=None, required=True) - ap.add_argument("--opus", help="Whether to use dataset in opus format or not", default=False, action='store_true') - + ap.add_argument("--dataset-home", "-d", default=None, required=False, + help="Path to home directory to download and prepare dataset. Default to ~/.keras") + ap.add_argument("--language", "-l", type=str, choices=langs, default=None, required=True, + help="Any name of language included in MLS") + ap.add_argument("--opus", default=False, action="store_true", help="Whether to use dataset in opus format or not") + args = ap.parse_args() fname = "mls_{}{}.tar.gz".format(args.language, "_opus" if args.opus else "") subdir = fname[:-7] @@ -99,11 +102,11 @@ def make_alphabet_file(filepath, chars_list, lang): full_url, cache_subdir=dataset_home, extract=True - ) + ) print(f"Dataset extracted to {dataset_dir}. 
Preparing...") for split in splits: prepare_split(dataset_dir=dataset_dir, split=split, opus=args.opus) - make_alphabet_file(os.path.join(dataset_dir, "alphabet.txt"), chars, args.language) \ No newline at end of file + make_alphabet_file(os.path.join(dataset_dir, "alphabet.txt"), chars, args.language) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 709754cbe7..06665d82d5 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -13,6 +13,8 @@ # limitations under the License. import os +import json +import tqdm import numpy as np import tensorflow as tf @@ -20,7 +22,7 @@ from .base_dataset import BaseDataset, BUFFER_SIZE, TFRECORD_SHARDS, AUTOTUNE from ..featurizers.speech_featurizers import load_and_convert_to_wav, read_raw_audio, tf_read_raw_audio, SpeechFeaturizer from ..featurizers.text_featurizers import TextFeaturizer -from ..utils.utils import bytestring_feature, print_one_line, get_num_batches +from ..utils.utils import bytestring_feature, get_num_batches, preprocess_paths, get_nsamples_from_duration class ASRDataset(BaseDataset): @@ -46,7 +48,49 @@ def __init__(self, self.speech_featurizer = speech_featurizer self.text_featurizer = text_featurizer + # -------------------------------- MAX LENGTHS ------------------------------------- + + def compute_max_lengths(self): + self.read_entries() + for _, duration, indices in tqdm.tqdm(self.entries, desc=f"Computing max lengths for entries in {self.stage} dataset"): + nsamples = get_nsamples_from_duration(duration, sample_rate=self.speech_featurizer.sample_rate) + # https://www.tensorflow.org/api_docs/python/tf/signal/frame + input_length = -(-nsamples // self.speech_featurizer.frame_step) + label = str(indices).split() + label_length = len(label) + self.speech_featurizer.update_length(input_length) + self.text_featurizer.update_length(label_length) + + def save_max_lengths(self, max_lengths_dir: str = None): + if max_lengths_dir is None: return + max_lengths_path = os.path.join(preprocess_paths(max_lengths_dir), "max_lengths.json") + content = { + "max_input_length": self.speech_featurizer.max_length, + "max_label_length": self.text_featurizer.max_length + } + with tf.io.gfile.GFile(max_lengths_path, "w") as f: + f.write(json.dumps(content)) + print(f"Max lengths written to {max_lengths_path}") + + def load_max_lengths(self, max_lengths_dir: str = None): + if max_lengths_dir is None: return + max_lengths_path = os.path.join(preprocess_paths(max_lengths_dir), "max_lengths.json") + if tf.io.gfile.exists(max_lengths_path): + print(f"Loading max lengths from {max_lengths_path} ...") + with tf.io.gfile.GFile(max_lengths_path, "r") as f: + content = json.loads(f.read()) + self.speech_featurizer.update_length(int(content["max_input_length"])) + self.text_featurizer.update_length(int(content["max_label_length"])) + + def update_lengths(self, max_lengths_dir: str = None): + self.load_max_lengths(max_lengths_dir) + self.compute_max_lengths() + self.save_max_lengths(max_lengths_dir) + + # -------------------------------- ENTRIES ------------------------------------- + def read_entries(self): + if hasattr(self, 'entries') and len(self.entries) > 0: return self.entries = [] for file_path in self.data_paths: print(f"Reading {file_path} ...") @@ -62,6 +106,8 @@ def read_entries(self): if self.shuffle: np.random.shuffle(self.entries) # Mix transcripts.tsv self.total_steps = len(self.entries) + # -------------------------------- LOAD AND PREPROCESS 
------------------------------------- + def generator(self): for path, _, indices in self.entries: audio = load_and_convert_to_wav(path).numpy() @@ -111,6 +157,17 @@ def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): return path, features, input_length, label, label_length, prediction, prediction_length + @tf.function + def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): + """ + Returns: + path, features, input_lengths, labels, label_lengths, pred_inp + """ + if self.use_tf: return self.tf_preprocess(path, audio, indices) + return self.preprocess(path, audio, indices) + + # -------------------------------- CREATION ------------------------------------- + def process(self, dataset: tf.data.Dataset, batch_size: int): dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE) @@ -141,15 +198,6 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): self.total_steps = get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder) return dataset - @tf.function - def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): - """ - Returns: - path, features, input_lengths, labels, label_lengths, pred_inp - """ - if self.use_tf: return self.tf_preprocess(path, audio, indices) - return self.preprocess(path, audio, indices) - def create(self, batch_size: int): self.read_entries() if not self.total_steps or self.total_steps == 0: return None @@ -192,18 +240,25 @@ def __init__(self, @staticmethod def write_tfrecord_file(splitted_entries): shard_path, entries = splitted_entries - with tf.io.TFRecordWriter(shard_path, options='ZLIB') as out: - for path, _, indices in entries: - audio = load_and_convert_to_wav(path).numpy() + + def parse(record): + def fn(path, indices): + audio = load_and_convert_to_wav(path.decode("utf-8")).numpy() feature = { - "path": bytestring_feature([bytes(path, "utf-8")]), + "path": bytestring_feature([path]), "audio": bytestring_feature([audio]), - "indices": bytestring_feature([bytes(indices, "utf-8")]) + "indices": bytestring_feature([indices]) } example = tf.train.Example(features=tf.train.Features(feature=feature)) - out.write(example.SerializeToString()) - print_one_line("Processed:", path) - print(f"\nCreated {shard_path}") + return example.SerializeToString() + return tf.numpy_function(fn, inp=[record[0], record[2]], Tout=tf.string) + + dataset = tf.data.Dataset.from_tensor_slices(entries) + dataset = dataset.map(parse, num_parallel_calls=AUTOTUNE) + writer = tf.data.experimental.TFRecordWriter(shard_path, compression_type="ZLIB") + print(f"Processing {shard_path} ...") + writer.write(dataset) + print(f"Created {shard_path}") def create_tfrecords(self): if not tf.io.gfile.exists(self.tfrecords_dir): @@ -249,7 +304,7 @@ def create(self, batch_size: int): ignore_order = tf.data.Options() ignore_order.experimental_deterministic = False files_ds = files_ds.with_options(ignore_order) - dataset = tf.data.TFRecordDataset(files_ds, compression_type='ZLIB', num_parallel_reads=AUTOTUNE) + dataset = tf.data.TFRecordDataset(files_ds, compression_type="ZLIB", num_parallel_reads=AUTOTUNE) return self.process(dataset, batch_size) diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 65b40eb076..24e398b479 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -177,7 +177,7 @@ def map_fn(elem): class SpeechFeaturizer(metaclass=abc.ABCMeta): - def __init__(self, 
speech_config: dict, tpu: bool = False): + def __init__(self, speech_config: dict): """ We should use TFSpeechFeaturizer for training to avoid differences between tf and librosa when converting to tflite in post-training stage @@ -208,7 +208,6 @@ def __init__(self, speech_config: dict, tpu: bool = False): self.normalize_feature = speech_config.get("normalize_feature", True) self.normalize_per_feature = speech_config.get("normalize_per_feature", False) # Length - self.tpu = tpu self.max_length = 0 @property @@ -259,7 +258,7 @@ def shape(self) -> list: if self.pitch: channel_dim += 1 - length = self.max_length if (self.max_length > 0 and self.tpu) else None + length = self.max_length if self.max_length > 0 else None return [length, self.num_feature_bins, channel_dim] @@ -391,7 +390,7 @@ def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray: class TFSpeechFeaturizer(SpeechFeaturizer): @property def shape(self) -> list: - length = self.max_length if (self.max_length > 0 and self.tpu) else None + length = self.max_length if self.max_length > 0 else None return [length, self.num_feature_bins, 1] def stft(self, signal): diff --git a/tensorflow_asr/models/ctc.py b/tensorflow_asr/models/ctc.py index 4cbb163eaa..0e12c52c79 100644 --- a/tensorflow_asr/models/ctc.py +++ b/tensorflow_asr/models/ctc.py @@ -27,8 +27,8 @@ def __init__(self, **kwargs): super(CtcModel, self).__init__(**kwargs) self.time_reduction_factor = 1 - def _build(self, input_shape): - features = tf.keras.Input(input_shape, dtype=tf.float32) + def _build(self, input_shape, batch_size=None): + features = tf.keras.Input(input_shape, batch_size=batch_size, dtype=tf.float32) self(features, training=False) def add_featurizers(self, diff --git a/tensorflow_asr/models/keras/transducer.py b/tensorflow_asr/models/keras/transducer.py index 96f5741dfe..7f04c2cba5 100644 --- a/tensorflow_asr/models/keras/transducer.py +++ b/tensorflow_asr/models/keras/transducer.py @@ -24,17 +24,17 @@ class Transducer(BaseTransducer): """ Keras Transducer Model Warper """ - def _build(self, input_shape): - features = tf.keras.Input(shape=input_shape, dtype=tf.float32) - input_length = tf.keras.Input(shape=[], dtype=tf.int32) - pred = tf.keras.Input(shape=[None], dtype=tf.int32) - pred_length = tf.keras.Input(shape=[], dtype=tf.int32) + def _build(self, input_shape, prediction_shape=[None], batch_size=None): + inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) + input_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + pred = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) + pred_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) self({ - "input": features, + "input": inputs, "input_length": input_length, "prediction": pred, "prediction_length": pred_length - }, training=True) + }, training=False) def call(self, inputs, training=False, **kwargs): features = inputs["input"] diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index b4109660cc..b3cf810233 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -246,11 +246,11 @@ def __init__(self, ) self.time_reduction_factor = 1 - def _build(self, input_shape): - inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32) - input_length = tf.keras.Input(shape=[], dtype=tf.int32) - pred = tf.keras.Input(shape=[None], dtype=tf.int32) - pred_length = tf.keras.Input(shape=[], dtype=tf.int32) + def _build(self, 
input_shape, prediction_shape=[None], batch_size=None): + inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) + input_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + pred = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) + pred_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) self([inputs, input_length, pred, pred_length], training=False) def summary(self, line_length=None, **kwargs): diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 945adcaac6..8fb55a7624 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -206,3 +206,7 @@ def body(index, tfarray): index, tfarray = tf.while_loop(condition, body, loop_vars=[index, tfarray], swap_memory=False) return tfarray + + +def get_nsamples_from_duration(duration, sample_rate=16000): + return math.ceil(float(duration) * sample_rate) From 68774f7a3aa8c559d6f1dbfa0231e345a6abc4f6 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 16:43:02 +0700 Subject: [PATCH 03/12] :writing_hand: update speech and text featurizers --- .../featurizers/speech_featurizers.py | 4 +-- .../featurizers/text_featurizers.py | 35 +++++++++---------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 24e398b479..35e295f707 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -238,8 +238,8 @@ def extract(self, signal): class NumpySpeechFeaturizer(SpeechFeaturizer): - def __init__(self, speech_config: dict, tpu: bool = False): - super(NumpySpeechFeaturizer, self).__init__(speech_config, tpu) + def __init__(self, speech_config: dict): + super(NumpySpeechFeaturizer, self).__init__(speech_config) self.delta = speech_config.get("delta", False) self.delta_delta = speech_config.get("delta_delta", False) self.pitch = speech_config.get("pitch", False) diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index 7ec805cc90..270ed89d04 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -30,23 +30,22 @@ class TextFeaturizer(metaclass=abc.ABCMeta): - def __init__(self, decoder_config: dict, tpu: bool = False): + def __init__(self, decoder_config: dict): self.scorer = None self.decoder_config = DecoderConfig(decoder_config) self.blank = None self.tokens2indices = {} self.tokens = [] self.num_classes = None - self.tpu = tpu self.max_length = 0 @property def shape(self) -> list: - return [self.max_length if (self.max_length > 0 and self.tpu) else None] + return [self.max_length if self.max_length > 0 else None] @property def prepand_shape(self) -> list: - return [self.max_length + 1 if (self.max_length > 0 and self.tpu) else None] + return [self.max_length + 1 if self.max_length > 0 else None] def update_length(self, length: int): self.max_length = max(self.max_length, length) @@ -97,7 +96,7 @@ class CharFeaturizer(TextFeaturizer): converted to a sequence of integer indexes. 
""" - def __init__(self, decoder_config: dict, tpu: bool = False): + def __init__(self, decoder_config: dict): """ decoder_config = { "vocabulary": str, @@ -108,7 +107,7 @@ def __init__(self, decoder_config: dict, tpu: bool = False): } } """ - super(CharFeaturizer, self).__init__(decoder_config, tpu) + super(CharFeaturizer, self).__init__(decoder_config) self.__init_vocabulary() def __init_vocabulary(self): @@ -191,7 +190,7 @@ class SubwordFeaturizer(TextFeaturizer): converted to a sequence of integer indexes. """ - def __init__(self, decoder_config: dict, subwords=None, tpu: bool = False): + def __init__(self, decoder_config: dict, subwords=None): """ decoder_config = { "target_vocab_size": int, @@ -204,7 +203,7 @@ def __init__(self, decoder_config: dict, subwords=None, tpu: bool = False): } } """ - super(SubwordFeaturizer, self).__init__(decoder_config, tpu) + super(SubwordFeaturizer, self).__init__(decoder_config) self.subwords = self.__load_subwords() if subwords is None else subwords self.blank = 0 # subword treats blank as 0 self.num_classes = self.subwords.vocab_size @@ -223,7 +222,7 @@ def __load_subwords(self): return tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) @classmethod - def build_from_corpus(cls, decoder_config: dict, corpus_files: list = None, tpu: bool = False): + def build_from_corpus(cls, decoder_config: dict, corpus_files: list = None): dconf = DecoderConfig(decoder_config.copy()) corpus_files = dconf.corpus_files if corpus_files is None or len(corpus_files) == 0 else corpus_files @@ -243,15 +242,15 @@ def corpus_generator(): dconf.max_corpus_chars, dconf.reserved_tokens ) - return cls(decoder_config, subwords, tpu) + return cls(decoder_config, subwords) @classmethod - def load_from_file(cls, decoder_config: dict, filename: str = None, tpu: bool = False): + def load_from_file(cls, decoder_config: dict, filename: str = None): dconf = DecoderConfig(decoder_config.copy()) filename = dconf.vocabulary if filename is None else preprocess_paths(filename) filename_prefix = os.path.splitext(filename)[0] subwords = tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) - return cls(decoder_config, subwords, tpu) + return cls(decoder_config, subwords) def save_to_file(self, filename: str = None): filename = self.decoder_config.vocabulary if filename is None else preprocess_paths(filename) @@ -329,8 +328,8 @@ class SentencePieceFeaturizer(TextFeaturizer): EOS_TOKEN, EOS_TOKEN_ID = "", 3 PAD_TOKEN, PAD_TOKEN_ID = "", 0 # unused, by default - def __init__(self, decoder_config: dict, model=None, tpu: bool = False): - super(SentencePieceFeaturizer, self).__init__(decoder_config, tpu) + def __init__(self, decoder_config: dict, model=None): + super(SentencePieceFeaturizer, self).__init__(decoder_config) self.model = model self.blank = 0 # treats blank as 0 (pad) self.upoints = None @@ -346,7 +345,7 @@ def __init_upoints(self): self.upoints = self.upoints.to_tensor() # [num_classes, max_subword_length] @classmethod - def build_from_corpus(cls, decoder_config: dict, tpu: bool = False): + def build_from_corpus(cls, decoder_config: dict): """ --model_prefix: output model name prefix. .model and .vocab are generated. 
--vocab_size: vocabulary size, e.g., 8000, 16000, or 32000 @@ -394,17 +393,17 @@ def corpus_iterator(): for _, s in sorted(vocab.items(), key=lambda x: x[0]): f_out.write(f"{s} 1\n") - return cls(decoder_config, processor, tpu) + return cls(decoder_config, processor) @classmethod - def load_from_file(cls, decoder_config: dict, filename: str = None, tpu: bool = False): + def load_from_file(cls, decoder_config: dict, filename: str = None): if filename is not None: filename_prefix = os.path.splitext(preprocess_paths(filename))[0] else: filename_prefix = decoder_config.get("output_path_prefix", None) processor = sp.SentencePieceProcessor() processor.load(filename_prefix + ".model") - return cls(decoder_config, processor, tpu) + return cls(decoder_config, processor) def extract(self, text: str) -> tf.Tensor: """ From 846a4738c94eea06c0595889e8b703086e9fb445 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 17:06:58 +0700 Subject: [PATCH 04/12] :writing_hand: update configs --- examples/contextnet/config.yml | 9 ++++++--- examples/deepspeech2/config.yml | 9 ++++++--- examples/jasper/config.yml | 9 ++++++--- examples/streaming_transducer/config.yml | 9 ++++++--- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/examples/contextnet/config.yml b/examples/contextnet/config.yml index 7a32b08af9..5127dd1de6 100644 --- a/examples/contextnet/config.yml +++ b/examples/contextnet/config.yml @@ -208,32 +208,35 @@ learning_config: mask_factor: 27 data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: True cache: True buffer_size: 100 drop_remainder: True + stage: train eval_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: eval test_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: test optimizer_config: warmup_steps: 40000 diff --git a/examples/deepspeech2/config.yml b/examples/deepspeech2/config.yml index a82b787bbc..68a77d7bd4 100755 --- a/examples/deepspeech2/config.yml +++ b/examples/deepspeech2/config.yml @@ -53,32 +53,35 @@ learning_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: True cache: True buffer_size: 100 drop_remainder: True + stage: train eval_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 
drop_remainder: True + stage: eval test_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: test optimizer_config: class_name: adam diff --git a/examples/jasper/config.yml b/examples/jasper/config.yml index 257122e4cb..f6c158edce 100755 --- a/examples/jasper/config.yml +++ b/examples/jasper/config.yml @@ -60,32 +60,35 @@ learning_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: True cache: True buffer_size: 100 drop_remainder: True + stage: train eval_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: eval test_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: test optimizer_config: class_name: adam diff --git a/examples/streaming_transducer/config.yml b/examples/streaming_transducer/config.yml index 9c813a5688..47b0e41ae9 100755 --- a/examples/streaming_transducer/config.yml +++ b/examples/streaming_transducer/config.yml @@ -65,32 +65,35 @@ learning_config: mask_factor: 27 data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: True cache: True buffer_size: 100 drop_remainder: True + stage: train eval_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: eval test_dataset_config: use_tf: True data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 drop_remainder: True + stage: test optimizer_config: class_name: adam From 654f51c5c9688623b6d24bd33a1a228fb90a5d83 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 18:26:29 +0700 Subject: [PATCH 05/12] :zap: add script to generate max lengths --- .../conformer/train_tpu_subword_conformer.py | 11 ++-- scripts/create_tfrecords.py | 12 ++-- scripts/generate_max_lengths.py | 58 ++++++++++++++++++ setup.py | 2 +- 
tensorflow_asr/datasets/asr_dataset.py | 18 +++--- .../librispeech_train_4_1030.max_lengths.json | 1 + .../librispeech_train_4_1030.subwords | 0 .../librispeech_train_4_4076.max_lengths.json | 1 + .../librispeech_train_4_4076.subwords | 0 .../sentencepiece_librispeech_960_8000.model | Bin 10 files changed, 83 insertions(+), 20 deletions(-) create mode 100644 scripts/generate_max_lengths.py create mode 100644 vocabularies/librispeech/librispeech_train_4_1030.max_lengths.json rename vocabularies/{ => librispeech}/librispeech_train_4_1030.subwords (100%) create mode 100644 vocabularies/librispeech/librispeech_train_4_4076.max_lengths.json rename vocabularies/{ => librispeech}/librispeech_train_4_4076.subwords (100%) rename vocabularies/{ => librispeech}/sentencepiece_librispeech_960_8000.model (100%) diff --git a/examples/conformer/train_tpu_subword_conformer.py b/examples/conformer/train_tpu_subword_conformer.py index ca9b643741..131b35a2a4 100644 --- a/examples/conformer/train_tpu_subword_conformer.py +++ b/examples/conformer/train_tpu_subword_conformer.py @@ -36,8 +36,7 @@ parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab") -parser.add_argument("--max_lengths_dir", type=str, default="~", - help="Path to file containing max lengths. Will be computed if not exists") +parser.add_argument("--max_lengths_prefix", type=str, default=None, help="Path to file containing max lengths") parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths") @@ -87,12 +86,12 @@ ) if args.compute_lengths: - train_dataset.update_lengths(args.max_lengths_dir) - eval_dataset.update_lengths(args.max_lengths_dir) + train_dataset.update_lengths(args.max_lengths_prefix) + eval_dataset.update_lengths(args.max_lengths_prefix) # Update max lengths calculated from both train and eval datasets -train_dataset.load_max_lengths(args.max_lengths_dir) -eval_dataset.load_max_lengths(args.max_lengths_dir) +train_dataset.load_max_lengths(args.max_lengths_prefix) +eval_dataset.load_max_lengths(args.max_lengths_prefix) strategy = setup_tpu(args.tpu_address) diff --git a/scripts/create_tfrecords.py b/scripts/create_tfrecords.py index 81fee9e948..8fe48dcd0e 100644 --- a/scripts/create_tfrecords.py +++ b/scripts/create_tfrecords.py @@ -17,7 +17,7 @@ from tensorflow_asr.configs.config import Config from tensorflow_asr.utils.utils import preprocess_paths from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset -from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer parser = argparse.ArgumentParser(prog="TFRecords Creation") @@ -31,6 +31,8 @@ parser.add_argument("--shuffle", default=False, action="store_true", help="Shuffle data or not") +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") parser.add_argument("transcripts", nargs="+", type=str, default=None, help="Paths to transcript files") @@ -41,11 +43,13 @@ tfrecords_dir = preprocess_paths(args.tfrecords_dir) config = Config(args.config) -if args.subwords and os.path.exists(args.subwords): + +if args.sentence_piece: + print("Loading SentencePiece model ...") + text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) +elif args.subwords and 
os.path.exists(args.subwords): print("Loading subwords ...") text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) -else: - raise ValueError("subwords must be set") ASRTFRecordDataset( data_paths=transcripts, tfrecords_dir=tfrecords_dir, diff --git a/scripts/generate_max_lengths.py b/scripts/generate_max_lengths.py new file mode 100644 index 0000000000..d533546b4f --- /dev/null +++ b/scripts/generate_max_lengths.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from tensorflow_asr.configs.config import Config +from tensorflow_asr.utils.utils import preprocess_paths +from tensorflow_asr.datasets.asr_dataset import ASRDataset +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer + +parser = argparse.ArgumentParser(prog="TFRecords Creation") + +parser.add_argument("--config", type=str, default=None, help="The file path of model configuration file") + +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +parser.add_argument("--max_lengths_prefix", type=str, default=None, help="Path to file containing max lengths") + +parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") + +parser.add_argument("transcripts", nargs="+", type=str, default=None, help="Paths to transcript files") + +args = parser.parse_args() + +assert args.max_lengths_prefix is not None, "max_lengths_prefix must be defined" + +transcripts = preprocess_paths(args.transcripts) + +config = Config(args.config) + +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.sentence_piece: + print("Loading SentencePiece model ...") + text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) +elif args.subwords and os.path.exists(args.subwords): + print("Loading subwords ...") + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) + +dataset = ASRDataset( + data_paths=transcripts, + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + stage="train", shuffle=False, +) + +dataset.update_lengths(args.max_lengths_prefix) diff --git a/setup.py b/setup.py index b76298e8d4..f2535c4619 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.7.4", + version="0.7.5", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 06665d82d5..017506d2d1 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -61,9 +61,9 @@ def compute_max_lengths(self): self.speech_featurizer.update_length(input_length) 
self.text_featurizer.update_length(label_length) - def save_max_lengths(self, max_lengths_dir: str = None): - if max_lengths_dir is None: return - max_lengths_path = os.path.join(preprocess_paths(max_lengths_dir), "max_lengths.json") + def save_max_lengths(self, max_lengths_prefix: str = None): + if max_lengths_prefix is None: return + max_lengths_path = preprocess_paths(max_lengths_prefix) + ".max_lengths.json" content = { "max_input_length": self.speech_featurizer.max_length, "max_label_length": self.text_featurizer.max_length @@ -72,9 +72,9 @@ def save_max_lengths(self, max_lengths_dir: str = None): f.write(json.dumps(content)) print(f"Max lengths written to {max_lengths_path}") - def load_max_lengths(self, max_lengths_dir: str = None): - if max_lengths_dir is None: return - max_lengths_path = os.path.join(preprocess_paths(max_lengths_dir), "max_lengths.json") + def load_max_lengths(self, max_lengths_prefix: str = None): + if max_lengths_prefix is None: return + max_lengths_path = preprocess_paths(max_lengths_prefix) + ".max_lengths.json" if tf.io.gfile.exists(max_lengths_path): print(f"Loading max lengths from {max_lengths_path} ...") with tf.io.gfile.GFile(max_lengths_path, "r") as f: @@ -82,10 +82,10 @@ def load_max_lengths(self, max_lengths_dir: str = None): self.speech_featurizer.update_length(int(content["max_input_length"])) self.text_featurizer.update_length(int(content["max_label_length"])) - def update_lengths(self, max_lengths_dir: str = None): - self.load_max_lengths(max_lengths_dir) + def update_lengths(self, max_lengths_prefix: str = None): + self.load_max_lengths(max_lengths_prefix) self.compute_max_lengths() - self.save_max_lengths(max_lengths_dir) + self.save_max_lengths(max_lengths_prefix) # -------------------------------- ENTRIES ------------------------------------- diff --git a/vocabularies/librispeech/librispeech_train_4_1030.max_lengths.json b/vocabularies/librispeech/librispeech_train_4_1030.max_lengths.json new file mode 100644 index 0000000000..b1168dee5b --- /dev/null +++ b/vocabularies/librispeech/librispeech_train_4_1030.max_lengths.json @@ -0,0 +1 @@ +{"max_input_length": 2974, "max_label_length": 207} \ No newline at end of file diff --git a/vocabularies/librispeech_train_4_1030.subwords b/vocabularies/librispeech/librispeech_train_4_1030.subwords similarity index 100% rename from vocabularies/librispeech_train_4_1030.subwords rename to vocabularies/librispeech/librispeech_train_4_1030.subwords diff --git a/vocabularies/librispeech/librispeech_train_4_4076.max_lengths.json b/vocabularies/librispeech/librispeech_train_4_4076.max_lengths.json new file mode 100644 index 0000000000..bfc85a2346 --- /dev/null +++ b/vocabularies/librispeech/librispeech_train_4_4076.max_lengths.json @@ -0,0 +1 @@ +{"max_input_length": 2974, "max_label_length": 164} \ No newline at end of file diff --git a/vocabularies/librispeech_train_4_4076.subwords b/vocabularies/librispeech/librispeech_train_4_4076.subwords similarity index 100% rename from vocabularies/librispeech_train_4_4076.subwords rename to vocabularies/librispeech/librispeech_train_4_4076.subwords diff --git a/vocabularies/sentencepiece_librispeech_960_8000.model b/vocabularies/librispeech/sentencepiece_librispeech_960_8000.model similarity index 100% rename from vocabularies/sentencepiece_librispeech_960_8000.model rename to vocabularies/librispeech/sentencepiece_librispeech_960_8000.model From dbab4c6d97a9f5e43c0a4e57fead02d0344ffd93 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 18:28:49 +0700 
Subject: [PATCH 06/12] :writing_hand: update configs --- examples/conformer/config.yml | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/conformer/config.yml b/examples/conformer/config.yml index 84b2636df3..e370a1a9eb 100755 --- a/examples/conformer/config.yml +++ b/examples/conformer/config.yml @@ -31,9 +31,7 @@ decoder_config: beam_width: 5 norm_score: True corpus_files: - - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-100/transcripts.tsv - - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-360/transcripts.tsv - - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-other-500/transcripts.tsv + - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv model_config: name: conformer @@ -77,9 +75,7 @@ learning_config: mask_factor: 27 data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-360/transcripts.tsv - - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/LibriSpeech/train-other-500/transcripts.tsv - tfrecords_dir: /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_4076 + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: True cache: True buffer_size: 100 @@ -91,7 +87,7 @@ learning_config: data_paths: - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - tfrecords_dir: /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_4076 + tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords shuffle: False cache: True buffer_size: 100 From 4238f59c8c360c0f1d09131e993c9b28d41abd08 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 18:37:58 +0700 Subject: [PATCH 07/12] :writing_hand: update shape for keras datasets --- .../train_tpu_keras_subword_conformer.py | 128 ++++++++++++++++++ tensorflow_asr/datasets/keras/asr_dataset.py | 4 +- 2 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 examples/conformer/train_tpu_keras_subword_conformer.py diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py new file mode 100644 index 0000000000..ceaf267d5b --- /dev/null +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -0,0 +1,128 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import math +import argparse +from tensorflow_asr.utils import setup_environment, setup_tpu + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") + +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +parser.add_argument("--bs", type=int, default=None, help="Batch size per replica") + +parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab") + +parser.add_argument("--max_lengths_prefix", type=str, default=None, help="Path to file containing max lengths") + +parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths") + +parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") + +parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") + +parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_tpu(args.tpu_address) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer +from tensorflow_asr.models.keras.conformer import Conformer +from tensorflow_asr.optimizers.schedules import TransformerSchedule + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.sentence_piece: + print("Loading SentencePiece model ...") + text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) +elif args.subwords and os.path.exists(args.subwords): + print("Loading subwords ...") + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) +else: + print("Generating subwords ...") + text_featurizer = SubwordFeaturizer.build_from_corpus( + config.decoder_config, + corpus_files=args.subwords_corpus + ) + text_featurizer.save_to_file(args.subwords) + +train_dataset = ASRTFRecordDatasetKeras( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.train_dataset_config) +) +eval_dataset = ASRTFRecordDatasetKeras( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.eval_dataset_config) +) + +if args.compute_lengths: + train_dataset.update_lengths(args.max_lengths_prefix) + eval_dataset.update_lengths(args.max_lengths_prefix) + +# Update max lengths calculated from both train and eval datasets +train_dataset.load_max_lengths(args.max_lengths_prefix) +eval_dataset.load_max_lengths(args.max_lengths_prefix) + +with strategy.scope(): + global_batch_size = config.learning_config.running_config.batch_size + global_batch_size *= strategy.num_replicas_in_sync + # build model + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) + conformer._build(speech_featurizer.shape) + 
conformer.summary(line_length=120) + + optimizer = tf.keras.optimizers.Adam( + TransformerSchedule( + d_model=conformer.dmodel, + warmup_steps=config.learning_config.optimizer_config["warmup_steps"], + max_lr=(0.05 / math.sqrt(conformer.dmodel)) + ), + beta_1=config.learning_config.optimizer_config["beta1"], + beta_2=config.learning_config.optimizer_config["beta2"], + epsilon=config.learning_config.optimizer_config["epsilon"] + ) + + conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) + + train_data_loader = train_dataset.create(global_batch_size) + eval_data_loader = eval_dataset.create(global_batch_size) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) + ] + + conformer.fit( + train_data_loader, epochs=config.learning_config.running_config.num_epochs, + validation_data=eval_data_loader, callbacks=callbacks, + steps_per_epoch=train_dataset.total_steps + ) diff --git a/tensorflow_asr/datasets/keras/asr_dataset.py b/tensorflow_asr/datasets/keras/asr_dataset.py index e981fa1fd4..6f29bf95ee 100644 --- a/tensorflow_asr/datasets/keras/asr_dataset.py +++ b/tensorflow_asr/datasets/keras/asr_dataset.py @@ -67,11 +67,11 @@ def process(self, dataset, batch_size): "path": tf.TensorShape([]), "input": tf.TensorShape(self.speech_featurizer.shape), "input_length": tf.TensorShape([]), - "prediction": tf.TensorShape([None]), + "prediction": tf.TensorShape(self.text_featurizer.prepand_shape), "prediction_length": tf.TensorShape([]) }, { - "label": tf.TensorShape([None]), + "label": tf.TensorShape(self.text_featurizer.shape), "label_length": tf.TensorShape([]) }, ), From b0cc951f44c286c7d4d3e9deb180b7193682781f Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 18:40:22 +0700 Subject: [PATCH 08/12] :writing_hand: update tpu keras training script --- examples/conformer/train_tpu_keras_subword_conformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index ceaf267d5b..c45807eead 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -92,11 +92,12 @@ eval_dataset.load_max_lengths(args.max_lengths_prefix) with strategy.scope(): - global_batch_size = config.learning_config.running_config.batch_size + batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size + global_batch_size = batch_size global_batch_size *= strategy.num_replicas_in_sync # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) - conformer._build(speech_featurizer.shape) + conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=batch_size) conformer.summary(line_length=120) optimizer = tf.keras.optimizers.Adam( From b10cb03b75174c31356fd3fc8a1060182191f741 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 18:54:40 +0700 Subject: [PATCH 09/12] :writing_hand: tensorflow-io does not support tpu yet --- .../conformer/train_tpu_keras_subword_conformer.py | 1 - tensorflow_asr/featurizers/speech_featurizers.py | 10 +++++++--- tensorflow_asr/utils/utils.py | 6 ++++++ 3 files 
changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index c45807eead..fde7022c52 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -125,5 +125,4 @@ conformer.fit( train_data_loader, epochs=config.learning_config.running_config.num_epochs, validation_data=eval_data_loader, callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps ) diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 35e295f707..5d2d76b6e2 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -21,9 +21,11 @@ import tensorflow as tf import tensorflow_io as tfio -from ..utils.utils import log10 +from ..utils.utils import log10, has_tpu from .gammatone import fft_weights +tpu = has_tpu() + def load_and_convert_to_wav(path: str) -> tf.Tensor: wave, rate = librosa.load(os.path.expanduser(path), sr=None, mono=True) @@ -48,8 +50,10 @@ def read_raw_audio(audio, sample_rate=16000): def tf_read_raw_audio(audio: tf.Tensor, sample_rate=16000): wave, rate = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1) - resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate) - return tf.reshape(resampled, shape=[-1]) # reshape for using tf.signal + if not tpu: + resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate) + return tf.reshape(resampled, shape=[-1]) # reshape for using tf.signal + return tf.reshape(wave, shape=[-1]) # reshape for using tf.signal def slice_signal(signal, window_size, stride=0.5) -> np.ndarray: diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 8fb55a7624..1ade899227 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -169,6 +169,12 @@ def has_gpu_or_tpu(): return True +def has_tpu(): + tpus = tf.config.list_logical_devices("TPU") + if len(tpus) == 0: return False + return True + + def find_max_length_prediction_tfarray(tfarray: tf.TensorArray) -> tf.Tensor: with tf.name_scope("find_max_length_prediction_tfarray"): index = tf.constant(0, dtype=tf.int32) From 8e02e88ea46d1e3bafa98b9d582594acc4f095d3 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 19:10:32 +0700 Subject: [PATCH 10/12] :zap: add max length to predict network --- tensorflow_asr/datasets/asr_dataset.py | 2 +- tensorflow_asr/models/transducer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 017506d2d1..5ecc026f03 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -90,7 +90,7 @@ def update_lengths(self, max_lengths_prefix: str = None): # -------------------------------- ENTRIES ------------------------------------- def read_entries(self): - if hasattr(self, 'entries') and len(self.entries) > 0: return + if hasattr(self, "entries") and len(self.entries) > 0: return self.entries = [] for file_path in self.data_paths: print(f"Reading {file_path} ...") diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index b3cf810233..84148fcb26 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -93,10 +93,11 @@ def call(self, inputs, 
training=False, **kwargs): # inputs has shape [B, U] # use tf.gather_nd instead of tf.gather for tflite conversion outputs, prediction_length = inputs + if not hasattr(self, "max_length"): self.max_length = shape_list(outputs)[-1] outputs = self.embed(outputs, training=training) outputs = self.do(outputs, training=training) for rnn in self.rnns: - mask = tf.sequence_mask(prediction_length) + mask = tf.sequence_mask(prediction_length, maxlen=self.max_length) outputs = rnn["rnn"](outputs, training=training, mask=mask) outputs = outputs[0] if rnn["ln"] is not None: From 50d246c439f6fa9bdbac6818ce3342de71e49c5a Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 19:12:37 +0700 Subject: [PATCH 11/12] :writing_hand: update message --- tensorflow_asr/losses/rnnt_losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_asr/losses/rnnt_losses.py b/tensorflow_asr/losses/rnnt_losses.py index 8e48779de3..d6cbcb2b06 100644 --- a/tensorflow_asr/losses/rnnt_losses.py +++ b/tensorflow_asr/losses/rnnt_losses.py @@ -22,9 +22,9 @@ try: from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss use_warprnnt = True + print("Use RNNT loss in WarpRnnt") except ImportError: - print("Cannot import RNNT loss in warprnnt. Falls back to RNNT in TensorFlow") - print("Note: The RNNT in Tensorflow is not supported for CPU yet") + print("Use RNNT loss in TensorFlow") use_warprnnt = False From 4414d8158822ac0704b42b6cf16cb667cf67d00b Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Tue, 16 Feb 2021 20:23:07 +0700 Subject: [PATCH 12/12] :writing_hand: update training script --- examples/conformer/train_tpu_keras_subword_conformer.py | 2 +- tensorflow_asr/featurizers/speech_featurizers.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/conformer/train_tpu_keras_subword_conformer.py b/examples/conformer/train_tpu_keras_subword_conformer.py index fde7022c52..338f72f932 100644 --- a/examples/conformer/train_tpu_keras_subword_conformer.py +++ b/examples/conformer/train_tpu_keras_subword_conformer.py @@ -97,7 +97,7 @@ global_batch_size *= strategy.num_replicas_in_sync # build model conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) - conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=batch_size) + conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size) conformer.summary(line_length=120) optimizer = tf.keras.optimizers.Adam( diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 5d2d76b6e2..554b00063d 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -27,6 +27,14 @@ tpu = has_tpu() +# def tf_resample(signal, rate_in, rate_out): +# if rate_in == rate_out: return signal +# rate_in = tf.cast(rate_in, dtype=tf.float32) +# rate_out = tf.cast(rate_out, dtype=tf.float32) +# ratio = rate_out / rate_in +# nsamples = tf.math.ceil(tf.shape(signal)[0] * ratio) + + def load_and_convert_to_wav(path: str) -> tf.Tensor: wave, rate = librosa.load(os.path.expanduser(path), sr=None, mono=True) return tf.audio.encode_wav(tf.expand_dims(wave, axis=-1), sample_rate=rate)