diff --git a/examples/conformer/save_conformer_from_weights.py b/examples/conformer/save_conformer_from_weights.py index 9285603c66..3d51abfc49 100644 --- a/examples/conformer/save_conformer_from_weights.py +++ b/examples/conformer/save_conformer_from_weights.py @@ -46,23 +46,20 @@ setup_devices([args.device], cpu=args.cpu) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.models.conformer import Conformer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) tf.random.set_seed(0) assert args.saved # build model -conformer = Conformer( - vocabulary_size=text_featurizer.num_classes, - **config["model_config"] -) +conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.load_weights(args.saved, by_name=True) conformer.summary(line_length=150) diff --git a/examples/conformer/test_conformer.py b/examples/conformer/test_conformer.py index 565e1d1a3f..0998e6667a 100755 --- a/examples/conformer/test_conformer.py +++ b/examples/conformer/test_conformer.py @@ -52,48 +52,45 @@ setup_devices([args.device], cpu=args.cpu) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from 
tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.conformer import Conformer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) tf.random.set_seed(0) assert args.saved if args.tfrecords: test_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.test_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], + data_paths=config.learning_config.dataset_config.test_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) # build model -conformer = Conformer( - vocabulary_size=text_featurizer.num_classes, - **config["model_config"] -) +conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.load_weights(args.saved, by_name=True) conformer.summary(line_length=120) conformer.add_featurizers(speech_featurizer, text_featurizer) conformer_tester = BaseTester( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, output_name=args.output_name ) conformer_tester.compile(conformer) diff --git a/examples/conformer/test_subword_conformer.py b/examples/conformer/test_subword_conformer.py index 62b920348d..47ea1b09f0 100755 --- a/examples/conformer/test_subword_conformer.py +++ 
b/examples/conformer/test_subword_conformer.py @@ -55,19 +55,19 @@ setup_devices([args.device], cpu=args.cpu) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.conformer import Conformer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: raise ValueError("subwords must be set") @@ -76,32 +76,29 @@ if args.tfrecords: test_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.test_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], + data_paths=config.learning_config.dataset_config.test_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) # build model -conformer = Conformer( - vocabulary_size=text_featurizer.num_classes, - **config["model_config"] -) +conformer = 
Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.load_weights(args.saved, by_name=True) conformer.summary(line_length=120) conformer.add_featurizers(speech_featurizer, text_featurizer) conformer_tester = BaseTester( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, output_name=args.output_name ) conformer_tester.compile(conformer) diff --git a/examples/conformer/tflite_conformer.py b/examples/conformer/tflite_conformer.py index 2f002cfa81..1dc1f9cacf 100644 --- a/examples/conformer/tflite_conformer.py +++ b/examples/conformer/tflite_conformer.py @@ -19,7 +19,7 @@ setup_environment() import tensorflow as tf -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.models.conformer import Conformer @@ -43,15 +43,12 @@ assert args.saved and args.output -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) # build model -conformer = Conformer( - **config["model_config"], - vocabulary_size=text_featurizer.num_classes -) +conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.load_weights(args.saved) conformer.summary(line_length=150) diff --git a/examples/conformer/tflite_subword_conformer.py b/examples/conformer/tflite_subword_conformer.py index 53d0485602..6ea372ee41 100644 --- a/examples/conformer/tflite_subword_conformer.py +++ 
b/examples/conformer/tflite_subword_conformer.py @@ -19,7 +19,7 @@ setup_environment() import tensorflow as tf -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer from tensorflow_asr.models.conformer import Conformer @@ -46,20 +46,17 @@ assert args.saved and args.output -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: raise ValueError("subwords must be set") # build model -conformer = Conformer( - **config["model_config"], - vocabulary_size=text_featurizer.num_classes -) +conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.load_weights(args.saved) conformer.summary(line_length=150) diff --git a/examples/conformer/train_conformer.py b/examples/conformer/train_conformer.py index 8ff0b06b38..af7897f888 100644 --- a/examples/conformer/train_conformer.py +++ b/examples/conformer/train_conformer.py @@ -56,7 +56,7 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer @@ -64,61 +64,58 @@ from 
tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + 
data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) conformer_trainer = TransducerTrainer( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with conformer_trainer.strategy.scope(): # build model - conformer = Conformer( - **config["model_config"], - vocabulary_size=text_featurizer.num_classes - ) + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.summary(line_length=120) - optimizer_config = config["learning_config"]["optimizer_config"] + optimizer_config = config.learning_config.optimizer_config optimizer = tf.keras.optimizers.Adam( TransformerSchedule( - d_model=config["model_config"]["dmodel"], + d_model=config.model_config["dmodel"], warmup_steps=optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"])) + max_lr=(0.05 / math.sqrt(config.model_config["dmodel"])) ), beta_1=optimizer_config["beta1"], beta_2=optimizer_config["beta2"], diff --git a/examples/conformer/train_ga_conformer.py b/examples/conformer/train_ga_conformer.py index 34f617d5b4..ec8c5404bc 100644 --- a/examples/conformer/train_ga_conformer.py +++ b/examples/conformer/train_ga_conformer.py @@ -56,7 +56,7 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer @@ -64,61 +64,58 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = 
UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, 
stage="eval", cache=args.cache, shuffle=True ) conformer_trainer = TransducerTrainerGA( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with conformer_trainer.strategy.scope(): # build model - conformer = Conformer( - **config["model_config"], - vocabulary_size=text_featurizer.num_classes - ) + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.summary(line_length=120) - optimizer_config = config["learning_config"]["optimizer_config"] + optimizer_config = config.learning_config.optimizer_config optimizer = tf.keras.optimizers.Adam( TransformerSchedule( - d_model=config["model_config"]["dmodel"], + d_model=config.model_config["dmodel"], warmup_steps=optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"])) + max_lr=(0.05 / math.sqrt(config.model_config["dmodel"])) ), beta_1=optimizer_config["beta1"], beta_2=optimizer_config["beta2"], diff --git a/examples/conformer/train_ga_subword_conformer.py b/examples/conformer/train_ga_subword_conformer.py index adefbf88c6..a384a14c14 100644 --- a/examples/conformer/train_ga_subword_conformer.py +++ b/examples/conformer/train_ga_subword_conformer.py @@ -62,7 +62,7 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer @@ -70,71 +70,68 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = 
TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: print("Generating subwords ...") text_featurizer = SubwordFeaturizer.build_from_corpus( - config["decoder_config"], + config.decoder_config, corpus_files=args.subwords_corpus ) text_featurizer.save_to_file(args.subwords) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + 
augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) conformer_trainer = TransducerTrainerGA( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with conformer_trainer.strategy.scope(): # build model - conformer = Conformer( - **config["model_config"], - vocabulary_size=text_featurizer.num_classes - ) + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.summary(line_length=120) - optimizer_config = config["learning_config"]["optimizer_config"] + optimizer_config = config.learning_config.optimizer_config optimizer = tf.keras.optimizers.Adam( TransformerSchedule( - d_model=config["model_config"]["dmodel"], + d_model=config.model_config["dmodel"], warmup_steps=optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"])) + max_lr=(0.05 / math.sqrt(config.model_config["dmodel"])) ), beta_1=optimizer_config["beta1"], beta_2=optimizer_config["beta2"], diff --git a/examples/conformer/train_subword_conformer.py b/examples/conformer/train_subword_conformer.py index 1b0018c657..ca205287c7 100644 --- a/examples/conformer/train_subword_conformer.py +++ b/examples/conformer/train_subword_conformer.py @@ -62,7 +62,7 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from 
tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer @@ -70,71 +70,68 @@ from tensorflow_asr.models.conformer import Conformer from tensorflow_asr.optimizers.schedules import TransformerSchedule -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: print("Generating subwords ...") text_featurizer = SubwordFeaturizer.build_from_corpus( - config["decoder_config"], + config.decoder_config, corpus_files=args.subwords_corpus ) text_featurizer.save_to_file(args.subwords) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - 
data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) conformer_trainer = TransducerTrainer( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with conformer_trainer.strategy.scope(): # build model - conformer = Conformer( - **config["model_config"], - vocabulary_size=text_featurizer.num_classes - ) + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) conformer._build(speech_featurizer.shape) conformer.summary(line_length=120) - optimizer_config = config["learning_config"]["optimizer_config"] + optimizer_config = config.learning_config.optimizer_config optimizer = tf.keras.optimizers.Adam( TransformerSchedule( - d_model=config["model_config"]["dmodel"], + d_model=config.model_config["dmodel"], warmup_steps=optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"])) + max_lr=(0.05 / math.sqrt(config.model_config["dmodel"])) ), beta_1=optimizer_config["beta1"], beta_2=optimizer_config["beta2"], diff --git a/examples/deepspeech2/test_ds2.py b/examples/deepspeech2/test_ds2.py index f8741d32fe..3e16f04a24 100644 --- a/examples/deepspeech2/test_ds2.py +++ b/examples/deepspeech2/test_ds2.py @@ -49,7 +49,7 @@ setup_devices([args.device]) -from tensorflow_asr.configs.user_config import UserConfig +from 
tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer @@ -59,11 +59,11 @@ tf.random.set_seed(0) assert args.export -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) # Build DS2 model -ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes) +ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes) ds2_model._build(speech_featurizer.shape) ds2_model.load_weights(args.saved, by_name=True) ds2_model.summary(line_length=120) @@ -71,22 +71,22 @@ if args.tfrecords: test_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.test_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], + data_paths=config.learning_config.dataset_config.test_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) ctc_tester = BaseTester( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, output_name=args.output_name ) ctc_tester.compile(ds2_model) diff --git a/examples/deepspeech2/train_ds2.py 
b/examples/deepspeech2/train_ds2.py index 21b74f3868..33a410a4fd 100644 --- a/examples/deepspeech2/train_ds2.py +++ b/examples/deepspeech2/train_ds2.py @@ -55,29 +55,29 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.runners.ctc_runners import CTCTrainer from tensorflow_asr.models.deepspeech2 import DeepSpeech2 -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", 
cache=args.cache, shuffle=True @@ -86,25 +86,25 @@ train_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - augmentations=config["learning_config"]["augmentations"], + data_paths=config.learning_config.dataset_config.train_paths, + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, stage="eval", cache=args.cache, shuffle=True ) -ctc_trainer = CTCTrainer(text_featurizer, config["learning_config"]["running_config"]) +ctc_trainer = CTCTrainer(text_featurizer, config.learning_config.running_config) # Build DS2 model with ctc_trainer.strategy.scope(): - ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes) + ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes) ds2_model._build(speech_featurizer.shape) ds2_model.summary(line_length=120) # Compile -ctc_trainer.compile(ds2_model, config["learning_config"]["optimizer_config"], +ctc_trainer.compile(ds2_model, config.learning_config.optimizer_config, max_to_keep=args.max_ckpts) ctc_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) diff --git a/examples/jasper/test_jasper.py b/examples/jasper/test_jasper.py index c50294580a..0169f6cb0b 100644 --- a/examples/jasper/test_jasper.py +++ b/examples/jasper/test_jasper.py @@ -49,7 +49,7 @@ setup_devices([args.device]) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import 
TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer @@ -59,11 +59,11 @@ tf.random.set_seed(0) assert args.export -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) # Build DS2 model -jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes) +jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) jasper._build(speech_featurizer.shape) jasper.load_weights(args.saved, by_name=True) jasper.summary(line_length=120) @@ -71,22 +71,22 @@ if args.tfrecords: test_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.test_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], + data_paths=config.learning_config.dataset_config.test_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) ctc_tester = BaseTester( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, output_name=args.output_name ) ctc_tester.compile(jasper) diff --git a/examples/jasper/train_jasper.py b/examples/jasper/train_jasper.py index 9e65f2c739..698733d62a 100644 --- a/examples/jasper/train_jasper.py +++ b/examples/jasper/train_jasper.py @@ -55,29 +55,29 @@ strategy = setup_strategy(args.devices) -from 
tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.runners.ctc_runners import CTCTrainer from tensorflow_asr.models.jasper import Jasper -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True @@ -86,25 +86,25 @@ train_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - 
augmentations=config["learning_config"]["augmentations"], + data_paths=config.learning_config.dataset_config.train_paths, + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, stage="eval", cache=args.cache, shuffle=True ) -ctc_trainer = CTCTrainer(text_featurizer, config["learning_config"]["running_config"]) +ctc_trainer = CTCTrainer(text_featurizer, config.learning_config.running_config) # Build DS2 model with ctc_trainer.strategy.scope(): - jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes) + jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) jasper._build(speech_featurizer.shape) jasper.summary(line_length=120) # Compile -ctc_trainer.compile(jasper, config["learning_config"]["optimizer_config"], +ctc_trainer.compile(jasper, config.learning_config.optimizer_config, max_to_keep=args.max_ckpts) ctc_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) diff --git a/examples/streaming_transducer/test_streaming_transducer.py b/examples/streaming_transducer/test_streaming_transducer.py index 8cefe2b365..c9e79d34f0 100755 --- a/examples/streaming_transducer/test_streaming_transducer.py +++ b/examples/streaming_transducer/test_streaming_transducer.py @@ -52,31 +52,31 @@ setup_devices([args.device], cpu=args.cpu) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.runners.base_runners import 
BaseTester from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) tf.random.set_seed(0) assert args.saved if args.tfrecords: test_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.test_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], + data_paths=config.learning_config.dataset_config.test_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False @@ -85,7 +85,7 @@ # build model streaming_transducer = StreamingTransducer( vocabulary_size=text_featurizer.num_classes, - **config["model_config"] + **config.model_config ) streaming_transducer._build(speech_featurizer.shape) streaming_transducer.load_weights(args.saved, by_name=True) @@ -93,7 +93,7 @@ streaming_transducer.add_featurizers(speech_featurizer, text_featurizer) streaming_transducer_tester = BaseTester( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, output_name=args.output_name ) streaming_transducer_tester.compile(streaming_transducer) diff --git a/examples/streaming_transducer/test_subword_streaming_transducer.py b/examples/streaming_transducer/test_subword_streaming_transducer.py index 1c30615b4a..d422d69237 100755 --- 
a/examples/streaming_transducer/test_subword_streaming_transducer.py +++ b/examples/streaming_transducer/test_subword_streaming_transducer.py @@ -55,19 +55,19 @@ setup_devices([args.device], cpu=args.cpu) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: raise ValueError("subwords must be set") @@ -76,15 +76,15 @@ if args.tfrecords: test_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.test_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["test_paths"], + data_paths=config.learning_config.dataset_config.test_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False @@ -93,7 +93,7 @@ # build model 
streaming_transducer = StreamingTransducer( vocabulary_size=text_featurizer.num_classes, - **config["model_config"] + **config.model_config ) streaming_transducer._build(speech_featurizer.shape) streaming_transducer.load_weights(args.saved, by_name=True) @@ -101,7 +101,7 @@ streaming_transducer.add_featurizers(speech_featurizer, text_featurizer) streaming_transducer_tester = BaseTester( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, output_name=args.output_name ) streaming_transducer_tester.compile(streaming_transducer) diff --git a/examples/streaming_transducer/tflite_streaming_transducer.py b/examples/streaming_transducer/tflite_streaming_transducer.py index 29a5c9e58c..eacb4ba584 100644 --- a/examples/streaming_transducer/tflite_streaming_transducer.py +++ b/examples/streaming_transducer/tflite_streaming_transducer.py @@ -19,7 +19,7 @@ setup_environment() import tensorflow as tf -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.models.streaming_transducer import StreamingTransducer @@ -43,13 +43,13 @@ assert args.saved and args.output -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) # build model streaming_transducer = StreamingTransducer( - **config["model_config"], + **config.model_config, vocabulary_size=text_featurizer.num_classes ) streaming_transducer._build(speech_featurizer.shape) diff --git a/examples/streaming_transducer/tflite_subword_streaming_transducer.py 
b/examples/streaming_transducer/tflite_subword_streaming_transducer.py index c920efa3a2..8bd3d0511b 100644 --- a/examples/streaming_transducer/tflite_subword_streaming_transducer.py +++ b/examples/streaming_transducer/tflite_subword_streaming_transducer.py @@ -19,7 +19,7 @@ setup_environment() import tensorflow as tf -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer from tensorflow_asr.models.streaming_transducer import StreamingTransducer @@ -46,18 +46,18 @@ assert args.saved and args.output -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: raise ValueError("subwords must be set") # build model streaming_transducer = StreamingTransducer( - **config["model_config"], + **config.model_config, vocabulary_size=text_featurizer.num_classes ) streaming_transducer._build(speech_featurizer.shape) diff --git a/examples/streaming_transducer/train_ga_streaming_transducer.py b/examples/streaming_transducer/train_ga_streaming_transducer.py index 82ffbbb344..c3186126d9 100644 --- a/examples/streaming_transducer/train_ga_streaming_transducer.py +++ b/examples/streaming_transducer/train_ga_streaming_transducer.py @@ -55,57 +55,57 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import 
ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + 
augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) streaming_transducer_trainer = TransducerTrainerGA( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with streaming_transducer_trainer.strategy.scope(): # build model streaming_transducer = StreamingTransducer( - **config["model_config"], + **config.model_config, vocabulary_size=text_featurizer.num_classes ) streaming_transducer._build(speech_featurizer.shape) diff --git a/examples/streaming_transducer/train_ga_subword_streaming_transducer.py b/examples/streaming_transducer/train_ga_subword_streaming_transducer.py index c779e125ff..8b354fc891 100644 --- a/examples/streaming_transducer/train_ga_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_ga_subword_streaming_transducer.py @@ -61,67 +61,67 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if 
args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: print("Generating subwords ...") text_featurizer = SubwordFeaturizer.build_from_corpus( - config["decoder_config"], + config.decoder_config, corpus_files=args.subwords_corpus ) text_featurizer.save_to_file(args.subwords) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - 
data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) streaming_transducer_trainer = TransducerTrainerGA( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with streaming_transducer_trainer.strategy.scope(): # build model streaming_transducer = StreamingTransducer( - **config["model_config"], + **config.model_config, vocabulary_size=text_featurizer.num_classes ) streaming_transducer._build(speech_featurizer.shape) diff --git a/examples/streaming_transducer/train_streaming_transducer.py b/examples/streaming_transducer/train_streaming_transducer.py index cac1004ab6..7d22482405 100644 --- a/examples/streaming_transducer/train_streaming_transducer.py +++ b/examples/streaming_transducer/train_streaming_transducer.py @@ -55,57 +55,57 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.runners.transducer_runners import TransducerTrainer from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) -text_featurizer = CharFeaturizer(config["decoder_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) if args.tfrecords: train_dataset = ASRTFRecordDataset( - 
data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) streaming_transducer_trainer = TransducerTrainer( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with streaming_transducer_trainer.strategy.scope(): # build model streaming_transducer = StreamingTransducer( - **config["model_config"], + 
**config.model_config, vocabulary_size=text_featurizer.num_classes ) streaming_transducer._build(speech_featurizer.shape) diff --git a/examples/streaming_transducer/train_subword_streaming_transducer.py b/examples/streaming_transducer/train_subword_streaming_transducer.py index 88053b4460..8319431c46 100644 --- a/examples/streaming_transducer/train_subword_streaming_transducer.py +++ b/examples/streaming_transducer/train_subword_streaming_transducer.py @@ -61,67 +61,67 @@ strategy = setup_strategy(args.devices) -from tensorflow_asr.configs.user_config import UserConfig +from tensorflow_asr.configs.config import Config from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer from tensorflow_asr.runners.transducer_runners import TransducerTrainer from tensorflow_asr.models.streaming_transducer import StreamingTransducer -config = UserConfig(DEFAULT_YAML, args.config, learning=True) -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) +config = Config(args.config, learning=True) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords) + text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: print("Generating subwords ...") text_featurizer = SubwordFeaturizer.build_from_corpus( - config["decoder_config"], + config.decoder_config, corpus_files=args.subwords_corpus ) text_featurizer.save_to_file(args.subwords) if args.tfrecords: train_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + 
data_paths=config.learning_config.dataset_config.train_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRTFRecordDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], + data_paths=config.learning_config.dataset_config.eval_paths, + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) else: train_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["train_paths"], + data_paths=config.learning_config.dataset_config.train_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, - augmentations=config["learning_config"]["augmentations"], + augmentations=config.learning_config.augmentations, stage="train", cache=args.cache, shuffle=True ) eval_dataset = ASRSliceDataset( - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], + data_paths=config.learning_config.dataset_config.eval_paths, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="eval", cache=args.cache, shuffle=True ) streaming_transducer_trainer = TransducerTrainer( - config=config["learning_config"]["running_config"], + config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with streaming_transducer_trainer.strategy.scope(): # build model streaming_transducer = StreamingTransducer( - **config["model_config"], + **config.model_config, vocabulary_size=text_featurizer.num_classes ) streaming_transducer._build(speech_featurizer.shape) diff --git a/setup.py b/setup.py index 
1433f9acb7..76f20fcf05 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.2.8", + version="0.2.9", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/augmentations/augments.py b/tensorflow_asr/augmentations/augments.py index 4def7572b3..1cd45473d8 100755 --- a/tensorflow_asr/augmentations/augments.py +++ b/tensorflow_asr/augmentations/augments.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import UserDict import nlpaug.flow as naf from .signal_augment import SignalCropping, SignalLoudness, SignalMask, SignalNoise, \ @@ -34,15 +33,11 @@ } -class UserAugmentation(UserDict): +class Augmentation: def __init__(self, config: dict = None): if not config: config = {} - config["before"] = self.parse(config.get("before", {})) - config["after"] = self.parse(config.get("after", {})) - super(UserAugmentation, self).__init__(config) - - def __missing__(self, key): - return None + self.before = self.parse(config.get("before", {})) + self.after = self.parse(config.get("after", {})) @staticmethod def parse(config: dict) -> list: diff --git a/tensorflow_asr/configs/__init__.py b/tensorflow_asr/configs/__init__.py index e69de29bb2..f4d5510355 100644 --- a/tensorflow_asr/configs/__init__.py +++ b/tensorflow_asr/configs/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import yaml + + +def load_yaml(path): + # Fix yaml numbers https://stackoverflow.com/a/30462009/11037553 + loader = yaml.SafeLoader + loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + with open(path, "r", encoding="utf-8") as file: + return yaml.load(file, Loader=loader) diff --git a/tensorflow_asr/configs/config.py b/tensorflow_asr/configs/config.py new file mode 100644 index 0000000000..0c7d5fc47d --- /dev/null +++ b/tensorflow_asr/configs/config.py @@ -0,0 +1,59 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
import load_yaml +from ..augmentations.augments import Augmentation +from ..utils.utils import preprocess_paths + + +class DatasetConfig: + def __init__(self, config: dict = None): + if not config: config = {} + self.train_paths = config.get("train_paths", None) + self.eval_paths = config.get("eval_paths", None) + self.test_paths = config.get("test_paths", None) + self.tfrecords_dir = config.get("tfrecords_dir", None) + + +class RunningConfig: + def __init__(self, config: dict = None): + if not config: config = {} + self.batch_size = config.get("batch_size", 1) + self.accumulation_steps = config.get("accumulation_steps", 1) + self.num_epochs = config.get("num_epochs", 20) + self.outdir = preprocess_paths(config.get("outdir", None)) + self.log_interval_steps = config.get("log_interval_steps", 500) + self.save_interval_steps = config.get("save_interval_steps", 500) + self.eval_interval_steps = config.get("eval_interval_steps", 1000) + + +class LearningConfig: + def __init__(self, config: dict = None): + if not config: config = {} + self.augmentations = Augmentation(config.get("augmentations")) + self.dataset_config = DatasetConfig(config.get("dataset_config")) + self.optimizer_config = config.get("optimizer_config", {}) + self.running_config = RunningConfig(config.get("running_config")) + + +class Config: + """ User config class for training, testing or infering """ + + def __init__(self, path: str, learning: bool): + config = load_yaml(preprocess_paths(path)) + self.speech_config = config.get("speech_config", {}) + self.decoder_config = config.get("decoder_config", {}) + self.model_config = config.get("model_config", {}) + if learning: + self.learning_config = LearningConfig(config.get("learning_config")) diff --git a/tensorflow_asr/configs/user_config.py b/tensorflow_asr/configs/user_config.py deleted file mode 100644 index be9e2dd089..0000000000 --- a/tensorflow_asr/configs/user_config.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@usimarit) 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import re -import yaml -from collections import UserDict - -from ..utils.utils import preprocess_paths, append_default_keys_dict, check_key_in_dict - - -def load_yaml(path): - # Fix yaml numbers https://stackoverflow.com/a/30462009/11037553 - loader = yaml.SafeLoader - loader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
- |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')) - with open(preprocess_paths(path), "r", encoding="utf-8") as file: - return yaml.load(file, Loader=loader) - - -class UserConfig(UserDict): - """ User config class for training, testing or infering """ - - def __init__(self, default: str, custom: str, learning: bool = True): - assert default, "Default dict for config must be set" - default = load_yaml(default) - custom = append_default_keys_dict(default, load_yaml(custom)) - super(UserConfig, self).__init__(custom) - if not learning and self.data.get("learning_config", None) is not None: - # No need to have learning_config on Inferencer - del self.data["learning_config"] - elif learning: - # Check keys - check_key_in_dict( - self.data["learning_config"], - ["augmentations", "dataset_config", "running_config"] - ) - check_key_in_dict( - self.data["learning_config"]["dataset_config"], - ["train_paths", "eval_paths", "test_paths"] - ) - check_key_in_dict( - self.data["learning_config"]["running_config"], - ["batch_size", "num_epochs", "outdir", "log_interval_steps", - "save_interval_steps", "eval_interval_steps"] - ) - - def __missing__(self, key): - return None diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 06d607b4fc..b157a26dcc 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -19,6 +19,7 @@ import numpy as np import tensorflow as tf +from ..augmentations.augments import Augmentation from .base_dataset import BaseDataset from ..featurizers.speech_featurizers import read_raw_audio, SpeechFeaturizer from ..featurizers.text_featurizers import TextFeaturizer @@ -55,7 +56,7 @@ def __init__(self, speech_featurizer: SpeechFeaturizer, text_featurizer: TextFeaturizer, data_paths: list, - augmentations: dict = None, + augmentations: Augmentation = Augmentation(None), cache: bool = False, 
shuffle: bool = False): super(ASRDataset, self).__init__(data_paths, augmentations, cache, shuffle, stage) @@ -82,11 +83,11 @@ def preprocess(self, audio, transcript): with tf.device("/CPU:0"): signal = read_raw_audio(audio, self.speech_featurizer.sample_rate) - signal = self.augmentations["before"].augment(signal) + signal = self.augmentations.before.augment(signal) features = self.speech_featurizer.extract(signal) - features = self.augmentations["after"].augment(features) + features = self.augmentations.after.augment(features) label = self.text_featurizer.extract(transcript.decode("utf-8")) label_length = tf.cast(tf.shape(label)[0], tf.int32) @@ -148,7 +149,7 @@ def __init__(self, speech_featurizer: SpeechFeaturizer, text_featurizer: TextFeaturizer, stage: str, - augmentations: dict = None, + augmentations: Augmentation = Augmentation(None), cache: bool = False, shuffle: bool = False): super(ASRTFRecordDataset, self).__init__( diff --git a/tensorflow_asr/datasets/base_dataset.py b/tensorflow_asr/datasets/base_dataset.py index 049f10222a..864eee8864 100644 --- a/tensorflow_asr/datasets/base_dataset.py +++ b/tensorflow_asr/datasets/base_dataset.py @@ -13,8 +13,7 @@ # limitations under the License. 
import abc -from ..augmentations.augments import UserAugmentation -from ..utils.utils import preprocess_paths +from ..augmentations.augments import Augmentation class BaseDataset(metaclass=abc.ABCMeta): @@ -22,12 +21,12 @@ class BaseDataset(metaclass=abc.ABCMeta): def __init__(self, data_paths: list, - augmentations: dict = None, + augmentations: Augmentation = Augmentation(None), cache: bool = False, shuffle: bool = False, stage: str = "train"): - self.data_paths = preprocess_paths(data_paths) if data_paths else [] - self.augmentations = UserAugmentation(augmentations) # apply augmentation + self.data_paths = data_paths + self.augmentations = augmentations # apply augmentation self.cache = cache # whether to cache WHOLE transformed dataset to memory self.shuffle = shuffle # whether to shuffle tf.data.Dataset self.stage = stage # for defining tfrecords files diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 26d2e1d403..03e4470787 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -183,17 +183,17 @@ def __init__(self, speech_config: dict): } """ # Samples - self.sample_rate = speech_config["sample_rate"] - self.frame_length = int(self.sample_rate * (speech_config["frame_ms"] / 1000)) - self.frame_step = int(self.sample_rate * (speech_config["stride_ms"] / 1000)) + self.sample_rate = speech_config.get("sample_rate", 16000) + self.frame_length = int(self.sample_rate * (speech_config.get("frame_ms", 25) / 1000)) + self.frame_step = int(self.sample_rate * (speech_config.get("stride_ms", 10) / 1000)) # Features - self.num_feature_bins = speech_config["num_feature_bins"] - self.feature_type = speech_config["feature_type"] - self.preemphasis = speech_config["preemphasis"] + self.num_feature_bins = speech_config.get("num_feature_bins", 80) + self.feature_type = speech_config.get("feature_type", "log_mel_spectrogram") + self.preemphasis = 
speech_config.get("preemphasis", None) # Normalization - self.normalize_signal = speech_config["normalize_signal"] - self.normalize_feature = speech_config["normalize_feature"] - self.normalize_per_feature = speech_config["normalize_per_feature"] + self.normalize_signal = speech_config.get("normalize_signal", True) + self.normalize_feature = speech_config.get("normalize_feature", True) + self.normalize_per_feature = speech_config.get("normalize_per_feature", False) @property def nfft(self) -> int: diff --git a/tensorflow_asr/runners/base_runners.py b/tensorflow_asr/runners/base_runners.py index 34dddb619c..76b0624979 100644 --- a/tensorflow_asr/runners/base_runners.py +++ b/tensorflow_asr/runners/base_runners.py @@ -21,36 +21,27 @@ import numpy as np import tensorflow as tf -from ..utils.utils import preprocess_paths, get_num_batches, bytes_to_string +from ..configs.config import RunningConfig +from ..utils.utils import get_num_batches, bytes_to_string from ..utils.metrics import ErrorRate, wer, cer class BaseRunner(metaclass=abc.ABCMeta): """ Customized runner module for all models """ - def __init__(self, config: dict): - """ - running_config: - batch_size: 8 - num_epochs: 20 - outdir: ... 
- log_interval_steps: 200 - eval_interval_steps: 200 - save_interval_steps: 200 - """ + def __init__(self, config: RunningConfig): self.config = config - self.config["outdir"] = preprocess_paths(self.config["outdir"]) # Writers self.writers = { "train": tf.summary.create_file_writer( - os.path.join(self.config["outdir"], "tensorboard", "train")), + os.path.join(self.config.outdir, "tensorboard", "train")), "eval": tf.summary.create_file_writer( - os.path.join(self.config["outdir"], "tensorboard", "eval")) + os.path.join(self.config.outdir, "tensorboard", "eval")) } def add_writer(self, stage: str): self.writers[stage] = tf.summary.create_file_writer( - os.path.join(self.config["outdir"], "tensorboard", stage)) + os.path.join(self.config.outdir, "tensorboard", stage)) def _write_to_tensorboard(self, list_metrics: dict, @@ -75,12 +66,8 @@ class BaseTrainer(BaseRunner): """Customized trainer module for all models.""" def __init__(self, - config: dict, - strategy=None): - """ - Args: - config: the 'learning_config' part in YAML config file - """ + config: RunningConfig, + strategy: tf.distribute.Strategy = None): # Configurations super(BaseTrainer, self).__init__(config) self.set_strategy(strategy) @@ -99,7 +86,7 @@ def __init__(self, @property def total_train_steps(self): if self.train_steps_per_epoch is None: return None - return self.config["num_epochs"] * self.train_steps_per_epoch + return self.config.num_epochs * self.train_steps_per_epoch @property def epochs(self): @@ -128,13 +115,13 @@ def set_strategy(self, strategy=None): def set_train_data_loader(self, train_dataset, train_bs=None, train_acs=None): """ Set train data loader (MUST). 
""" - if not train_bs: train_bs = self.config["batch_size"] + if not train_bs: train_bs = self.config.batch_size self.global_batch_size = train_bs * self.strategy.num_replicas_in_sync - self.config["batch_size"] = train_bs # Update batch size fed from arguments + self.config.batch_size = train_bs # Update batch size fed from arguments - if not train_acs: train_acs = self.config.get("accumulation_steps", 1) + if not train_acs: train_acs = self.config.accumulation_steps self.accumulation_bs = train_bs // train_acs - self.config["accumulation_steps"] = train_acs + self.config.accumulation_steps = train_acs # update accum steps fed from arguments self.train_data = train_dataset.create(self.global_batch_size) self.train_data_loader = self.strategy.experimental_distribute_dataset(self.train_data) @@ -147,7 +134,7 @@ def set_eval_data_loader(self, eval_dataset, eval_bs=None): self.eval_data = None self.eval_data_loader = None return - if not eval_bs: eval_bs = self.config["batch_size"] + if not eval_bs: eval_bs = self.config.batch_size self.eval_data = eval_dataset.create(eval_bs * self.strategy.num_replicas_in_sync) self.eval_data_loader = self.strategy.experimental_distribute_dataset(self.eval_data) self.eval_steps_per_epoch = eval_dataset.total_steps @@ -158,7 +145,7 @@ def create_checkpoint_manager(self, max_to_keep=10, **kwargs): """Create checkpoint management.""" with self.strategy.scope(): self.ckpt = tf.train.Checkpoint(steps=self.steps, **kwargs) - checkpoint_dir = os.path.join(self.config["outdir"], "checkpoints") + checkpoint_dir = os.path.join(self.config.outdir, "checkpoints") if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) self.ckpt_manager = tf.train.CheckpointManager( @@ -232,7 +219,7 @@ def _train_epoch(self): # Print epoch info self.train_progbar.set_description_str( - f"[Train] [Epoch {self.epochs}/{self.config['num_epochs']}]") + f"[Train] [Epoch {self.epochs}/{self.config.num_epochs}]") # Print train info to progress bar 
self._print_train_metrics(self.train_progbar) @@ -325,7 +312,7 @@ def fit(self, train_dataset, eval_dataset=None, train_bs=None, train_acs=None, e def _check_log_interval(self): """Save log interval.""" - if self.steps % self.config["log_interval_steps"] == 0 or \ + if self.steps % self.config.log_interval_steps == 0 or \ self.steps >= self.total_train_steps: self._write_to_tensorboard(self.train_metrics, self.steps, stage="train") """Reset train metrics after save it to tensorboard.""" @@ -334,14 +321,14 @@ def _check_log_interval(self): def _check_save_interval(self): """Save log interval.""" - if self.steps % self.config["save_interval_steps"] == 0 or \ + if self.steps % self.config.save_interval_steps == 0 or \ self.steps >= self.total_train_steps: self.save_checkpoint() self.save_model_weights() def _check_eval_interval(self): """Save log interval.""" - if self.steps % self.config["eval_interval_steps"] == 0: + if self.steps % self.config.eval_interval_steps == 0: self._eval_epoch() # -------------------------------- UTILS ------------------------------------- @@ -367,12 +354,14 @@ class BaseTester(BaseRunner): After writing finished, it will calculate testing metrics """ - def __init__(self, config: dict, output_name: str = "test"): + def __init__(self, + config: RunningConfig, + output_name: str = "test"): super(BaseTester, self).__init__(config) self.test_data_loader = None self.processed_records = 0 - self.output_file_path = os.path.join(self.config["outdir"], f"{output_name}.tsv") + self.output_file_path = os.path.join(self.config.outdir, f"{output_name}.tsv") self.test_metrics = { "beam_wer": ErrorRate(func=wer, name="test_beam_wer", dtype=tf.float32), "beam_cer": ErrorRate(func=cer, name="test_beam_cer", dtype=tf.float32), diff --git a/tensorflow_asr/runners/ctc_runners.py b/tensorflow_asr/runners/ctc_runners.py index a7f8b1eb0a..7826542367 100644 --- a/tensorflow_asr/runners/ctc_runners.py +++ b/tensorflow_asr/runners/ctc_runners.py @@ -15,6 +15,7 @@ 
import os import tensorflow as tf +from ..configs.config import RunningConfig from ..featurizers.text_featurizers import TextFeaturizer from ..losses.ctc_losses import ctc_loss from .base_runners import BaseTrainer @@ -25,7 +26,7 @@ class CTCTrainer(BaseTrainer): def __init__(self, text_featurizer: TextFeaturizer, - config: dict, + config: RunningConfig, strategy: tf.distribute.Strategy = None): self.text_featurizer = text_featurizer super(CTCTrainer, self).__init__(config=config, strategy=strategy) @@ -42,7 +43,7 @@ def set_eval_metrics(self): def save_model_weights(self): with self.strategy.scope(): - self.model.save_weights(os.path.join(self.config["outdir"], "latest.h5")) + self.model.save_weights(os.path.join(self.config.outdir, "latest.h5")) @tf.function(experimental_relax_shapes=True) def _train_step(self, batch): diff --git a/tensorflow_asr/runners/transducer_runners.py b/tensorflow_asr/runners/transducer_runners.py index cd368d6d71..d961a922db 100644 --- a/tensorflow_asr/runners/transducer_runners.py +++ b/tensorflow_asr/runners/transducer_runners.py @@ -15,6 +15,7 @@ import os import tensorflow as tf +from ..configs.config import RunningConfig from ..optimizers.accumulation import GradientAccumulation from .base_runners import BaseTrainer from ..losses.rnnt_losses import rnnt_loss @@ -24,7 +25,7 @@ class TransducerTrainer(BaseTrainer): def __init__(self, - config: dict, + config: RunningConfig, text_featurizer: TextFeaturizer, strategy: tf.distribute.Strategy = None): self.text_featurizer = text_featurizer @@ -41,7 +42,7 @@ def set_eval_metrics(self): } def save_model_weights(self): - self.model.save_weights(os.path.join(self.config["outdir"], "latest.h5")) + self.model.save_weights(os.path.join(self.config.outdir, "latest.h5")) @tf.function(experimental_relax_shapes=True) def _train_step(self, batch): @@ -95,7 +96,7 @@ def _train_step(self, batch): self.accumulation.reset() - for accum_step in range(self.config.get("accumulation_steps", 1)): + for 
accum_step in range(self.config.accumulation_steps): indices = tf.expand_dims( tf.range( diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 4b03d5c978..1007e1a3fd 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -14,6 +14,8 @@ import os import sys import math +from typing import Union, List + import numpy as np import tensorflow as tf @@ -45,7 +47,7 @@ def check_key_in_dict(dictionary, keys): raise ValueError("{} must be defined".format(key)) -def preprocess_paths(paths): +def preprocess_paths(paths: Union[List, str]): if isinstance(paths, list): return [os.path.abspath(os.path.expanduser(path)) for path in paths] return os.path.abspath(os.path.expanduser(paths)) if paths else None