Add Squeezeformer to ASR (#4416)
* Initial squeezeformer impl

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Start time reduce and recovery

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Working commit of time reduction and time recovery modules

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix issue with number of params being incorrect

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add initializations to the model

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix scheduler

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Remove float()

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Correct order of operations

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Correct order of operations

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update time reduce PE to only update PE and nothing else

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix initialization

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix PE usage

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Comment out k2 for now

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add usage comments to buffered ctc script

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update docs

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add squeezeformer configs for CTC

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Mark squeezeformer as experimental

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add Jenkinsfile test

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add Jenkinsfile test

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix style

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Replace all with /content/

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Try Jenkinsfile Fix with closure

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update ctc config

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update ctc config

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update ctc config

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add squeezeformer

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add squeezeformer

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix Jenkinsfile

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix Jenkinsfile

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Try closure

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Remove test

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add back squeezeformer test

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Remove script tag

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Update for review comments

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Remove experimental

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Correct an issue with RNNT alignments

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Correct an issue with RNNT metrics

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Code formatting

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Correct offset calculation for no look ahead

Signed-off-by: smajumdar <smajumdar@nvidia.com>
titu1994 committed Jul 28, 2022
1 parent 16c96ba commit 96021f4
Showing 22 changed files with 1,402 additions and 21 deletions.
33 changes: 32 additions & 1 deletion Jenkinsfile
@@ -357,6 +357,37 @@ pipeline {
}
}

    stage('L2: ASR dev run - part two') {
      when {
        anyOf {
          branch 'main'
          changeRequest target: 'main'
        }
      }
      failFast true
      parallel {
        stage('L2: Speech to Text WPE - Squeezeformer') {
          steps {
            sh 'python examples/asr/asr_ctc/speech_to_text_ctc_bpe.py \
            --config-path="../conf/squeezeformer" --config-name="squeezeformer_ctc_bpe" \
            model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
            model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
            model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \
            model.tokenizer.type="wpe" \
            model.encoder.d_model=144 \
            model.train_ds.batch_size=4 \
            model.validation_ds.batch_size=4 \
            trainer.devices=[0] \
            trainer.accelerator="gpu" \
            +trainer.fast_dev_run=True \
            exp_manager.exp_dir=examples/asr/speech_to_text_wpe_squeezeformer_results'
            sh 'rm -rf examples/asr/speech_to_text_wpe_squeezeformer_results'
          }
        }
      }
    }


    stage('L2: Speaker dev run') {
      when {
        anyOf {
@@ -3000,7 +3031,7 @@ pipeline {
4"
sh "rm /home/TestData/nlp/megatron_gpt/TP2/test-increase.nemo"
}
}
}
}
}
stage('L2: Megatron T5 Pretraining and Resume Training TP=2') {
2 changes: 1 addition & 1 deletion README.rst
@@ -45,7 +45,7 @@ Key Features

* Speech processing
* `Automatic Speech Recognition (ASR) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/intro.html>`_
* Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, ...
* Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, ...
* Supports CTC and Transducer/RNNT losses/decoders
* Beam Search decoding
* `Language Modelling for ASR <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html>`_: N-gram LM in fusion with Beam Search decoding, Neural Rescoring with Transformer
13 changes: 12 additions & 1 deletion docs/source/asr/asr_all.bib
@@ -1045,4 +1045,15 @@ @misc{ssl_inter
  publisher = {arXiv},
  year = {2021},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
}

@misc{kim2022squeezeformer,
  doi = {10.48550/ARXIV.2206.00888},
  url = {https://arxiv.org/abs/2206.00888},
  author = {Kim, Sehoon and Gholami, Amir and Shaw, Albert and Lee, Nicholas and Mangalam, Karttikeya and Malik, Jitendra and Mahoney, Michael W. and Keutzer, Kurt},
  keywords = {Audio and Speech Processing (eess.AS), Computation and Language (cs.CL), Sound (cs.SD), FOS: Electrical engineering, electronic engineering, information engineering, FOS: Computer and information sciences},
  title = {Squeezeformer: An Efficient Transformer for Automatic Speech Recognition},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
11 changes: 11 additions & 0 deletions docs/source/asr/configs.rst
@@ -502,6 +502,17 @@ specify the tokenizer if you want to use sub-word encoding instead of character-based encoding.
The encoder section includes the details about the Conformer-CTC encoder architecture. You may find more information in the
config files and also :doc:`nemo.collections.asr.modules.ConformerEncoder<./api.html#nemo.collections.asr.modules.ConformerEncoder>`.

Squeezeformer-CTC
~~~~~~~~~~~~~~~~~

The config files for the Squeezeformer-CTC model with character-based encoding and with sub-word encoding are at
``<NeMo_git_root>/examples/asr/conf/squeezeformer/squeezeformer_ctc_char.yaml`` and ``<NeMo_git_root>/examples/asr/conf/squeezeformer/squeezeformer_ctc_bpe.yaml``
respectively. Components of the configs of `Squeezeformer-CTC <./models.html#Squeezeformer-CTC>`__ are similar to those of the `Conformer-CTC <./configs.html#conformer-ctc>`__ config.

The encoder section includes the details about the Squeezeformer-CTC encoder architecture. You may find more information in the
config files and also :doc:`nemo.collections.asr.modules.SqueezeformerEncoder<./api.html#nemo.collections.asr.modules.SqueezeformerEncoder>`.


ContextNet
~~~~~~~~~~

Binary file added docs/source/asr/images/squeezeformer.png
37 changes: 36 additions & 1 deletion docs/source/asr/models.rst
@@ -127,6 +127,8 @@ You may find the example config files of Conformer-Transducer model with character-based encoding at
``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_char.yaml`` and
with sub-word encoding at ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_bpe.yaml``.

.. _LSTM-Transducer_model:

LSTM-Transducer
---------------

@@ -138,12 +140,45 @@ It can be trained/used in unidirectional or bidirectional mode.
This model supports both the sub-word level and character level encodings. You may find the example config file of RNNT model with wordpiece encoding at ``<NeMo_git_root>/examples/asr/conf/lstm/lstm_transducer_bpe.yaml``.
You can find more details on the config files for the RNNT models at ``LSTM-Transducer <./configs.html#lstm-transducer>``.

.. _LSTM-CTC_model:

LSTM-CTC
-------
--------

LSTM-CTC model is a CTC-variant of the LSTM-Transducer model which uses CTC loss/decoding instead of Transducer.
You may find the example config file of LSTM-CTC model with wordpiece encoding at ``<NeMo_git_root>/examples/asr/conf/lstm/lstm_ctc_bpe.yaml``.

.. _Squeezeformer-CTC_model:

Squeezeformer-CTC
-----------------

Squeezeformer-CTC is a CTC-based variant of the Squeezeformer model introduced in :cite:`asr-models-kim2022squeezeformer`. Squeezeformer-CTC has a
similar encoder to the original Squeezeformer but uses CTC loss and decoding instead of RNNT/Transducer loss, which makes it a non-autoregressive model. The vast majority of the architecture is similar to the Conformer model, so please refer to `Conformer-CTC <./models.html#conformer-ctc>`__.

The model primarily differs from Conformer in the following ways:

* Temporal U-Net style time reduction, effectively reducing memory consumption and FLOPs for execution (see the sketch below).
* Unified activations throughout the model.
* Simplified module structure, with redundant layers removed.
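
A minimal sketch of the time reduction and recovery idea (illustrative only; the shapes and module names here are assumed, and NeMo's actual implementation lives in ``nemo.collections.asr.modules.SqueezeformerEncoder``):

.. code-block:: python

    import torch
    import torch.nn as nn

    class TimeReduce(nn.Module):
        """Halve the frame rate with a depthwise strided convolution."""

        def __init__(self, d_model: int):
            super().__init__()
            self.conv = nn.Conv1d(
                d_model, d_model, kernel_size=3, stride=2, padding=1, groups=d_model
            )

        def forward(self, x):  # x: (batch, time, d_model)
            # -> (batch, ceil(time / 2), d_model)
            return self.conv(x.transpose(1, 2)).transpose(1, 2)

    class TimeRecover(nn.Module):
        """Restore the original frame rate and add the saved pre-reduction skip."""

        def __init__(self, d_model: int):
            super().__init__()
            self.proj = nn.Linear(d_model, d_model)

        def forward(self, x, skip):  # skip: activations saved before reduction
            y = torch.repeat_interleave(x, 2, dim=1)[:, : skip.size(1)]
            return self.proj(y) + skip

The layers between the reduce and recovery points attend over half as many frames, which is where the memory and FLOPs savings come from.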

Here is the overall architecture of the encoder of Squeezeformer-CTC:

.. image:: images/squeezeformer.png
    :align: center
    :alt: Squeezeformer-CTC Model
    :scale: 50%

This model supports both the sub-word level and character level encodings. You can find more details on the config files for the
Squeezeformer-CTC models at `Squeezeformer-CTC <./configs.html#squeezeformer-ctc>`__. The variant with sub-word encoding is a BPE-based model
which can be instantiated using the :class:`~nemo.collections.asr.models.EncDecCTCModelBPE` class, while the
character-based variant is based on :class:`~nemo.collections.asr.models.EncDecCTCModel`.

You may find the example config files of Squeezeformer-CTC model with character-based encoding at
``<NeMo_git_root>/examples/asr/conf/squeezeformer/squeezeformer_ctc_char.yaml`` and
with sub-word encoding at ``<NeMo_git_root>/examples/asr/conf/squeezeformer/squeezeformer_ctc_bpe.yaml``.
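
As a quick orientation, a model can be constructed directly from the example config (a sketch with placeholder paths; the tokenizer directory and manifests must exist before construction):

.. code-block:: python

    from omegaconf import OmegaConf

    from nemo.collections.asr.models import EncDecCTCModelBPE

    cfg = OmegaConf.load("examples/asr/conf/squeezeformer/squeezeformer_ctc_bpe.yaml")
    cfg.model.tokenizer.dir = "<path to tokenizer directory>"     # placeholder
    cfg.model.train_ds.manifest_filepath = "<train manifest>"     # placeholder
    cfg.model.validation_ds.manifest_filepath = "<val manifest>"  # placeholder
    model = EncDecCTCModelBPE(cfg=cfg.model)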


References
----------

Expand Down
@@ -17,6 +17,17 @@
(1) Demonstrate how to use NeMo Models outside of PytorchLightning
(2) Shows example of batch ASR inference
(3) Serves as CI test for pre-trained checkpoint

python speech_to_text_buffered_infer_ctc.py \
    --asr_model="<Add model path here>" \
    --test_manifest="<Add test dataset here>" \
    --model_stride=4 \
    --batch_size=32 \
    --total_buffer_in_secs=4.0 \
    --chunk_len_in_ms=1000
"""

import copy
@@ -27,6 +38,7 @@

import torch
from omegaconf import OmegaConf
from tqdm import tqdm

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import word_error_rate
@@ -47,7 +59,7 @@ def get_wer_feat(mfst, asr, frame_len, tokens_per_chunk, delay, preprocessor_cfg, model_stride_in_secs, device):
    refs = []

    with open(mfst, "r") as mfst_f:
        for l in mfst_f:
        for l in tqdm(mfst_f, desc="Sample:"):
            asr.reset()
            row = json.loads(l.strip())
            asr.read_audio_file(row['audio_filepath'], delay, model_stride_in_secs)
@@ -129,7 +141,7 @@ def main():
        model_stride_in_secs,
        asr_model.device,
    )
    logging.info(f"WER is {round(wer, 2)} when decoded with a delay of {round(mid_delay*model_stride_in_secs, 2)}s")
    logging.info(f"WER is {round(wer, 4)} when decoded with a delay of {round(mid_delay*model_stride_in_secs, 2)}s")

if args.output_path is not None:

201 changes: 201 additions & 0 deletions examples/asr/conf/squeezeformer/squeezeformer_ctc_bpe.yaml
@@ -0,0 +1,201 @@
# It contains the default values for training a Squeezeformer-CTC ASR model, large size (~120M) with CTC loss and sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of Squeezeformer-CTC, other parameters are the same as in this config file.
# One extra layer (compared to original paper) is added to the medium and large variants to compensate for replacing the LSTM decoder with a linear one.
#
# | Model | d_model | n_layers | n_heads | time_masks | lr | time_reduce_idx | GBS |
# |--------------|---------|----------|---------|------------|--------|-----------------|------|
# | Extra-Small | 144 | 16 | 4 | 5 | 2e-3 | 7 | 1024 |
# | Small | 196 | 18 | 4 | 5 | 2e-3 | 8 | 1024 |
# | Small-Medium | 256 | 16 | 4 | 5 | 1.5e-3 | 7 | 1024 |
# | Medium | 324 | 20 | 4 | 7 | 1.5e-3 | 9 | 1024 |
# | Medium-Large | 512 | 18 | 8 | 10 | 1e-3 | 8 | 2048 |
# | Large | 640 | 22 | 8 | 10 | 5e-4 | 10 | 2048 |
#
# You may find more info about Squeezeformer-CTC here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#squeezeformer-ctc
# Pre-trained models of Squeezeformer-CTC can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html

name: "Squeezeformer-CTC-BPE"

model:
  sample_rate: 16000
  log_prediction: true # enables logging sample predictions in the output during training
  ctc_reduction: 'mean_batch'
  skip_nan_grad: false

  train_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: true
    num_workers: 8
    pin_memory: true
    use_start_end_token: false
    trim_silence: false
    max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
    min_duration: 0.1
    # tarred datasets
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    # bucketing params
    bucketing_strategy: "synced_randomized"
    bucketing_batch_size: null

  validation_ds:
    manifest_filepath: ???
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  test_ds:
    manifest_filepath: null
    sample_rate: ${model.sample_rate}
    batch_size: 8 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  # recommend small vocab size of 128 or 256 when using 4x sub-sampling
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
  tokenizer:
    dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
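  # Example invocation of that script (placeholder paths; flags assumed, verify against the script itself):
  #   python scripts/tokenizers/process_asr_text_tokenizer.py \
  #     --manifest=<train manifest> --data_root=<output dir> \
  #     --vocab_size=128 --tokenizer=spe --spe_type=bpe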

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  encoder:
    _target_: nemo.collections.asr.modules.SqueezeformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need different output size other than the default d_model
    n_layers: 18
    d_model: 512

    # Squeezeformer params
    adaptive_scale: true
    time_reduce_idx: 8
    time_recovery_idx: null

    # Sub-sampling params
    subsampling: dw_striding # dw_striding, vggnet, striding or stacking, vggnet may give better results but needs more memory
    subsampling_factor: 4 # must be power of 2
    subsampling_conv_channels: -1 # -1 sets it to d_model
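
    # Note: with window_stride=0.01 and subsampling_factor=4, each encoder frame
    # covers 40 ms of audio before time reduction (and 80 ms between the time
    # reduce and recovery layers); this matches --model_stride=4 in the buffered
    # CTC inference example above.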

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 31
    conv_norm_type: 'batch_norm' # batch_norm or layer_norm

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

  decoder:
    _target_: nemo.collections.asr.modules.ConvASRDecoder
    feat_in: null
    num_classes: -1
    vocabulary: []

  optim:
    name: adamw
    lr: 0.001
    # optimizer arguments
    betas: [0.9, 0.98]
    # less necessity for weight_decay as we already have large augmentations with SpecAug
    # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
    # weight decay of 0.0 with lr of 2.0 also works fine
    weight_decay: 4e-5

    # scheduler setup
    sched:
      name: NoamHoldAnnealing
      # scheduler config override
      warmup_steps: 5000 # paper uses ~ 6500 steps (20 epochs) out of 500 epochs.
      warmup_ratio: null
      hold_steps: 40000
      hold_ratio: null # paper uses ~ 40000 steps (160 epochs) out of 500 epochs.
      decay_rate: 1.0 # Noam decay = 0.5 and no hold steps. For Squeezeformer, use hold ~ 10-30% of training, then faster decay.
      min_lr: 1e-5
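      # Qualitative shape of NoamHoldAnnealing (from the parameter names above):
      # LR ramps up over warmup_steps, holds flat for hold_steps, then anneals
      # roughly as step^(-decay_rate) toward min_lr; decay_rate=1.0 decays faster
      # than the classic Noam rate of 0.5.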

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: null # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
  log_every_n_steps: 10 # Interval of logging.
  progress_bar_refresh_rate: 10
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: "val_wer"
    mode: "min"
    save_top_k: 5
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  # you need to set these two to True to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null