41 changes: 12 additions & 29 deletions README.md
@@ -44,7 +44,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as
- [TFLite Conversion](#tflite-convertion)
- [Features Extraction](#features-extraction)
- [Augmentations](#augmentations)
- [Training & Testing](#training--testing)
- [Training & Testing Tutorial](#training--testing-tutorial)
- [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models)
- [English](#english)
- [Vietnamese](#vietnamese)
@@ -164,34 +164,17 @@ See [features_extraction](./tensorflow_asr/featurizers/README.md)

See [augmentations](./tensorflow_asr/augmentations/README.md)

## Training & Testing

**Example YAML Config Structure**

```yaml
speech_config: ...
model_config: ...
decoder_config: ...
learning_config:
train_dataset_config:
augmentation_config: ...
data_paths: ...
tfrecords_dir: ...
eval_dataset_config:
augmentation_config: ...
data_paths: ...
tfrecords_dir: ...
test_dataset_config:
augmentation_config: ...
data_paths: ...
tfrecords_dir: ...
optimizer_config: ...
running_config:
batch_size: 8
num_epochs: 20
outdir: ...
log_interval_steps: 500
```
## Training & Testing Tutorial

1. Define a config YAML file; see the `config.yml` files in the [example folder](./examples) for reference (you can copy them and modify values such as parameters and paths to match your local machine).
2. Download your corpus (a.k.a. datasets) and write a script that generates `transcripts.tsv` files from it (this is the common format used in this project, since every dataset ships in its own format); a minimal sketch of such a script is shown after this list. For more detail, see [datasets](./tensorflow_asr/datasets/README.md). **Note:** Make sure your data contain only characters of your language; for example, English uses `a` to `z` and `'`. **Do not use `cache` if your dataset does not fit in RAM**.
3. [Optional] Generate TFRecords with the script [create_tfrecords.py](./scripts/create_tfrecords.py) so that `tf.data.TFRecordDataset` can be used for better performance.
4. Create a vocabulary file (characters or subwords/wordpieces) by defining `language.characters` and using the scripts [generate_vocab_subwords.py](./scripts/generate_vocab_subwords.py) or [generate_vocab_sentencepiece.py](./scripts/generate_vocab_sentencepiece.py). There are predefined ones in [vocabularies](./vocabularies).
5. [Optional] Generate a metadata file for your dataset with the script [generate_metadata.py](./scripts/generate_metadata.py). This file contains the maximum lengths computed from your `config.yml` and the total number of elements in each dataset, which enables static-shape training and precalculated steps per epoch.
6. For training, see the `train_*.py` files in the [example folder](./examples) for the available options.
7. For testing, see the `test_*.py` files in the [example folder](./examples) for the available options. **Note:** Testing is currently not supported on TPUs. It prints nothing except the progress bar in the console, but it stores the predicted transcripts in the file `output_name.tsv` inside the `outdir` defined in the config YAML file. After testing finishes, the metrics (WER and CER) are calculated from `output_name.tsv`. **If you reuse the same `output_name`, testing resumes from the last tested batch; if testing has already finished, only the metrics are recalculated. To run a fresh test, choose a new `output_name` whose file does not exist or contains only the header.**
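
The following is a minimal, hypothetical sketch of the kind of script meant in step 2. It assumes a corpus of `.wav` files with matching `.txt` transcripts and a tab-separated `PATH` / `DURATION` / `TRANSCRIPT` layout; treat all paths as placeholders and check [datasets](./tensorflow_asr/datasets/README.md) for the exact columns your version expects.

```python
# Hypothetical sketch -- adapt to your corpus layout and to the exact column
# format described in tensorflow_asr/datasets/README.md.
import glob
import os

import soundfile as sf  # assumed dependency, only used to read durations

corpus_dir = "/path/to/my_corpus"        # placeholder path
output_tsv = "/path/to/transcripts.tsv"  # placeholder path

with open(output_tsv, "w", encoding="utf-8") as out:
    out.write("PATH\tDURATION\tTRANSCRIPT\n")  # header row
    # Assumes each audio.wav has a matching audio.txt transcript next to it.
    for wav_path in sorted(glob.glob(os.path.join(corpus_dir, "**", "*.wav"), recursive=True)):
        txt_path = os.path.splitext(wav_path)[0] + ".txt"
        with open(txt_path, encoding="utf-8") as f:
            # Keep only characters of your language (e.g. a-z and ' for English).
            transcript = f.read().strip().lower()
        info = sf.info(wav_path)
        duration = info.frames / info.samplerate
        out.write(f"{wav_path}\t{duration:.3f}\t{transcript}\n")
```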

**Recommendation**: For better performance, use the **Keras built-in training functions** as in the `train_keras_*.py` files, and/or TFRecords. Keras built-in training uses an **infinite dataset**, which avoids a potentially partial last batch; a minimal sketch of this pattern is shown below.
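
As a rough illustration of that recommendation (generic `tf.data`/Keras code, not this project's input pipeline), repeating the dataset and passing an explicit `steps_per_epoch` means every training step sees a full batch:

```python
import tensorflow as tf

# Toy data: 10 examples with batch size 4 would normally end each epoch
# with a partial batch of 2 examples.
num_examples, batch_size = 10, 4
xs = tf.random.normal([num_examples, 8])
ys = tf.random.uniform([num_examples], maxval=2, dtype=tf.int32)

dataset = (
    tf.data.Dataset.from_tensor_slices((xs, ys))
    .shuffle(num_examples)
    .repeat()  # "infinite" dataset: every batch is full
    .batch(batch_size, drop_remainder=True)
)

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# With an infinite dataset, Keras needs an explicit steps_per_epoch
# (this project derives it from the dataset's total_steps / metadata).
model.fit(dataset, epochs=2, steps_per_epoch=num_examples // batch_size)
```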

See [examples](./examples/) for some predefined ASR models and results

59 changes: 39 additions & 20 deletions examples/conformer/train_keras_subword_conformer.py
@@ -38,6 +38,10 @@

parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")

parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")

parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata")

parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")

parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
@@ -79,25 +83,38 @@
if args.tfrecords:
train_dataset = ASRTFRecordDatasetKeras(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.train_dataset_config)
**vars(config.learning_config.train_dataset_config),
indefinite=True
)
eval_dataset = ASRTFRecordDatasetKeras(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.eval_dataset_config)
)
# Update metadata calculated from both train and eval datasets
train_dataset.load_metadata(args.metadata_prefix)
eval_dataset.load_metadata(args.metadata_prefix)
# Use dynamic length
speech_featurizer.reset_length()
text_featurizer.reset_length()
else:
train_dataset = ASRSliceDatasetKeras(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.train_dataset_config)
**vars(config.learning_config.train_dataset_config),
indefinite=True
)
eval_dataset = ASRSliceDatasetKeras(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.train_dataset_config)
**vars(config.learning_config.train_dataset_config),
indefinite=True
)

global_batch_size = config.learning_config.running_config.batch_size
global_batch_size *= strategy.num_replicas_in_sync

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

with strategy.scope():
global_batch_size = config.learning_config.running_config.batch_size
global_batch_size *= strategy.num_replicas_in_sync
# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
@@ -114,19 +131,21 @@
epsilon=config.learning_config.optimizer_config["epsilon"]
)

conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank)

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

callbacks = [
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
]

conformer.fit(
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
validation_data=eval_data_loader, callbacks=callbacks,
steps_per_epoch=train_dataset.total_steps
conformer.compile(
optimizer=optimizer,
experimental_steps_per_execution=args.spx,
global_batch_size=global_batch_size,
blank=text_featurizer.blank
)

callbacks = [
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
]

conformer.fit(
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
validation_data=eval_data_loader, callbacks=callbacks,
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
)
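
The reorganization above follows the usual `tf.distribute` layout: the global batch size is scaled by `strategy.num_replicas_in_sync` and the input pipelines are built as plain `tf.data` datasets outside `strategy.scope()`, while the model and optimizer variables are created and compiled inside it. A generic sketch of that pattern (not this repository's code; `steps_per_execution` is the newer name for `experimental_steps_per_execution`):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # a TPUStrategy would be used on TPU

per_replica_batch_size = 8
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

# Input pipeline lives outside the scope; Keras distributes it across replicas.
train_ds = (
    tf.data.Dataset.from_tensor_slices(
        (tf.random.normal([64, 16]), tf.random.uniform([64], maxval=2, dtype=tf.int32))
    )
    .repeat()
    .batch(global_batch_size, drop_remainder=True)
)

with strategy.scope():
    # Model and optimizer variables must be created inside the scope.
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        steps_per_execution=4,  # batches each replica runs per host call
    )

model.fit(train_ds, epochs=1, steps_per_epoch=64 // global_batch_size)
```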
41 changes: 22 additions & 19 deletions examples/conformer/train_tpu_keras_subword_conformer.py
@@ -32,7 +32,7 @@

parser.add_argument("--bs", type=int, default=None, help="Batch size per replica")

parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance")
parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance")

parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab")

@@ -78,11 +78,13 @@

train_dataset = ASRTFRecordDatasetKeras(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.train_dataset_config)
**vars(config.learning_config.train_dataset_config),
indefinite=True
)
eval_dataset = ASRTFRecordDatasetKeras(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.eval_dataset_config)
**vars(config.learning_config.eval_dataset_config),
indefinite=True
)

if args.compute_lengths:
@@ -93,10 +95,14 @@
train_dataset.load_metadata(args.metadata_prefix)
eval_dataset.load_metadata(args.metadata_prefix)

batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
global_batch_size = batch_size
global_batch_size *= strategy.num_replicas_in_sync

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

with strategy.scope():
batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
global_batch_size = batch_size
global_batch_size *= strategy.num_replicas_in_sync
# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
@@ -120,17 +126,14 @@
blank=text_featurizer.blank
)

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

callbacks = [
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
]
callbacks = [
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
]

conformer.fit(
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
validation_data=eval_data_loader, callbacks=callbacks,
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
)
conformer.fit(
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
validation_data=eval_data_loader, callbacks=callbacks,
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
)
123 changes: 0 additions & 123 deletions examples/conformer/train_tpu_subword_conformer.py

This file was deleted.
