42 changes: 22 additions & 20 deletions examples/conformer/train_conformer.py
@@ -26,29 +26,25 @@
 
 parser = argparse.ArgumentParser(prog="Conformer Training")
 
-parser.add_argument("--config", type=str, default=DEFAULT_YAML,
-                    help="The file path of model configuration file")
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
 
-parser.add_argument("--max_ckpts", type=int, default=10,
-                    help="Max number of checkpoints to keep")
+parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
 
-parser.add_argument("--tfrecords", default=False, action="store_true",
-                    help="Whether to use tfrecords")
+parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
 
-parser.add_argument("--tbs", type=int, default=None,
-                    help="Train batch size per replica")
+parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
 
-parser.add_argument("--ebs", type=int, default=None,
-                    help="Evaluation batch size per replica")
+parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
 
-parser.add_argument("--devices", type=int, nargs="*", default=[0],
-                    help="Devices' ids to apply distributed training")
+parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
 
-parser.add_argument("--mxp", default=False, action="store_true",
-                    help="Enable mixed precision")
+parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
 
-parser.add_argument("--cache", default=False, action="store_true",
-                    help="Enable caching for dataset")
+parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
 
+parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
+
+parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
 
 args = parser.parse_args()

@@ -75,28 +71,34 @@
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        tfrecords_shards=args.tfrecords_shards,
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRTFRecordDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
+        tfrecords_shards=args.tfrecords_shards,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
 else:
     train_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.train_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
 
 conformer_trainer = TransducerTrainer(
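
Note: the new --bfs flag is forwarded as buffer_size into the dataset's shuffle step. Below is a minimal sketch of what that buffer size means in a tf.data pipeline; the ASRTFRecordDataset/ASRSliceDataset internals are not part of this diff, so make_pipeline is a hypothetical stand-in, not the repo's implementation.

import tensorflow as tf

def make_pipeline(elements, bfs=100, batch_size=4):
    # Hypothetical stand-in for the dataset internals: `bfs` elements are
    # held in a buffer and each output element is drawn uniformly from it,
    # so a larger buffer mixes better at the cost of memory and startup time.
    ds = tf.data.Dataset.from_tensor_slices(elements)
    ds = ds.shuffle(buffer_size=bfs, reshuffle_each_iteration=True)
    return ds.batch(batch_size)

batches = make_pipeline(tf.range(1000), bfs=100)

With the default bfs=100, only a 100-element window is randomized at a time, which is why the flag is worth raising for large corpora.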
45 changes: 23 additions & 22 deletions examples/conformer/train_ga_conformer.py
@@ -26,32 +26,27 @@
 
 parser = argparse.ArgumentParser(prog="Conformer Training")
 
-parser.add_argument("--config", type=str, default=DEFAULT_YAML,
-                    help="The file path of model configuration file")
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
 
-parser.add_argument("--max_ckpts", type=int, default=10,
-                    help="Max number of checkpoints to keep")
+parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
 
-parser.add_argument("--tfrecords", default=False, action="store_true",
-                    help="Whether to use tfrecords")
+parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
 
-parser.add_argument("--tbs", type=int, default=None,
-                    help="Train batch size per replica")
+parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
 
-parser.add_argument("--ebs", type=int, default=None,
-                    help="Evaluation batch size per replica")
+parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
 
-parser.add_argument("--acs", type=int, default=None,
-                    help="Train accumulation steps")
+parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
 
-parser.add_argument("--devices", type=int, nargs="*", default=[0],
-                    help="Devices' ids to apply distributed training")
+parser.add_argument("--acs", type=int, default=None, help="Train accumulation steps")
 
-parser.add_argument("--mxp", default=False, action="store_true",
-                    help="Enable mixed precision")
+parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
 
-parser.add_argument("--cache", default=False, action="store_true",
-                    help="Enable caching for dataset")
+parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
 
+parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
+
+parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
 
 args = parser.parse_args()

@@ -78,28 +73,34 @@
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        tfrecords_shards=args.tfrecords_shards,
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRTFRecordDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
+        tfrecords_shards=args.tfrecords_shards,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
 else:
     train_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.train_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
 
 conformer_trainer = TransducerTrainerGA(
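
Note: --tfrecords_shards (default 16) sets how many TFRecord files the converted dataset is split across. The writer below is a hypothetical illustration of what a shard count means; the repo's actual TFRecord creation code and file-naming scheme are outside this diff.

import tensorflow as tf

def write_shards(serialized_examples, out_prefix, num_shards=16):
    # Spread already-serialized tf.train.Example records round-robin
    # across `num_shards` files (illustrative naming, not the repo's).
    writers = [
        tf.io.TFRecordWriter("%s-%05d-of-%05d.tfrecord" % (out_prefix, i, num_shards))
        for i in range(num_shards)
    ]
    for i, example in enumerate(serialized_examples):
        writers[i % num_shards].write(example)
    for w in writers:
        w.close()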
54 changes: 26 additions & 28 deletions examples/conformer/train_ga_subword_conformer.py
@@ -26,41 +26,33 @@
 
 parser = argparse.ArgumentParser(prog="Conformer Training")
 
-parser.add_argument("--config", type=str, default=DEFAULT_YAML,
-                    help="The file path of model configuration file")
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
 
-parser.add_argument("--max_ckpts", type=int, default=10,
-                    help="Max number of checkpoints to keep")
+parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
 
-parser.add_argument("--tfrecords", default=False, action="store_true",
-                    help="Whether to use tfrecords")
+parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
 
-parser.add_argument("--tbs", type=int, default=None,
-                    help="Train batch size per replica")
+parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
 
-parser.add_argument("--ebs", type=int, default=None,
-                    help="Evaluation batch size per replica")
+parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
 
-parser.add_argument("--acs", type=int, default=None,
-                    help="Train accumulation steps")
+parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
 
-parser.add_argument("--sentence_piece", default=False, action="store_true",
-                    help="Whether to use `SentencePiece` model")
+parser.add_argument("--acs", type=int, default=None, help="Train accumulation steps")
 
-parser.add_argument("--devices", type=int, nargs="*", default=[0],
-                    help="Devices' ids to apply distributed training")
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
 
-parser.add_argument("--mxp", default=False, action="store_true",
-                    help="Enable mixed precision")
+parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
 
-parser.add_argument("--cache", default=False, action="store_true",
-                    help="Enable caching for dataset")
+parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
 
-parser.add_argument("--subwords", type=str, default=None,
-                    help="Path to file that stores generated subwords")
+parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
 
-parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
-                    help="Transcript files for generating subwords")
+parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
 
+parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
+
+parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
 
 args = parser.parse_args()

@@ -100,28 +92,34 @@
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        tfrecords_shards=args.tfrecords_shards,
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRTFRecordDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
+        tfrecords_shards=args.tfrecords_shards,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
    )
 else:
     train_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.train_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
 
 conformer_trainer = TransducerTrainerGA(
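
Note: the usual payoff of multiple shards is parallel reading. A sketch of consuming sharded records with interleave follows; the file pattern and tuning values are assumptions for illustration, not taken from this PR.

import tensorflow as tf

# Shuffle shard order each epoch, then read several shards concurrently.
files = tf.data.Dataset.list_files("tfrecords_dir/train*.tfrecord", shuffle=True)
ds = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,                       # number of shards read at once
    num_parallel_calls=tf.data.AUTOTUNE,  # let tf.data size the thread pool
)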
51 changes: 25 additions & 26 deletions examples/conformer/train_subword_conformer.py
@@ -26,38 +26,31 @@
 
 parser = argparse.ArgumentParser(prog="Conformer Training")
 
-parser.add_argument("--config", type=str, default=DEFAULT_YAML,
-                    help="The file path of model configuration file")
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
 
-parser.add_argument("--max_ckpts", type=int, default=10,
-                    help="Max number of checkpoints to keep")
+parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
 
-parser.add_argument("--tfrecords", default=False, action="store_true",
-                    help="Whether to use tfrecords")
+parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
 
-parser.add_argument("--sentence_piece", default=False, action="store_true",
-                    help="Whether to use `SentencePiece` model")
+parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
 
-parser.add_argument("--tbs", type=int, default=None,
-                    help="Train batch size per replica")
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
 
-parser.add_argument("--ebs", type=int, default=None,
-                    help="Evaluation batch size per replica")
+parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
 
-parser.add_argument("--devices", type=int, nargs="*", default=[0],
-                    help="Devices' ids to apply distributed training")
+parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
 
-parser.add_argument("--mxp", default=False, action="store_true",
-                    help="Enable mixed precision")
+parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
 
-parser.add_argument("--cache", default=False, action="store_true",
-                    help="Enable caching for dataset")
+parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
 
-parser.add_argument("--subwords", type=str, default=None,
-                    help="Path to file that stores generated subwords")
+parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
 
-parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
-                    help="Transcript files for generating subwords")
+parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
 
+parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
+
+parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
 
 args = parser.parse_args()

@@ -97,28 +90,34 @@
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        tfrecords_shards=args.tfrecords_shards,
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
     eval_dataset = ASRTFRecordDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
+        tfrecords_shards=args.tfrecords_shards,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
    )
 else:
     train_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.train_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
-        stage="train", cache=args.cache, shuffle=True
+        stage="train", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
    )
    eval_dataset = ASRSliceDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
-        stage="eval", cache=args.cache, shuffle=True
+        stage="eval", cache=args.cache,
+        shuffle=True, buffer_size=args.bfs,
     )
 
 conformer_trainer = TransducerTrainer(
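
Note: as a quick self-check of the flag surface added across these four scripts, the snippet below re-declares a few of the new arguments with the names and defaults copied from the diffs above; the explicit argv passed to parse_args is purely for illustration.

import argparse

parser = argparse.ArgumentParser(prog="Conformer Training")
parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")

args = parser.parse_args(["--tfrecords", "--tfrecords_shards", "32", "--devices", "0", "1"])
assert args.tfrecords and args.tfrecords_shards == 32
assert args.devices == [0, 1] and args.bfs == 100  # --bfs keeps its default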