diff --git a/Jenkinsfile b/Jenkinsfile index 4f343b196f12..92b0a09dffcc 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -209,10 +209,10 @@ pipeline { } } - stage('L2: Speech to Text WPE') { + stage('L2: Speech to Text WPE - CitriNet') { steps { sh 'python examples/asr/speech_to_text_bpe.py \ - --config-path="experimental/configs/" --config-name="config_bpe" \ + --config-path="experimental/citrinet/" --config-name="config_bpe" \ model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \ @@ -223,6 +223,21 @@ pipeline { sh 'rm -rf examples/asr/speech_to_text_wpe_results' } } + + stage('L2: Speech to Text WPE - Conformer') { + steps { + sh 'python examples/asr/speech_to_text_bpe.py \ + --config-path="experimental/conformer" --config-name="conformer_bpe" \ + model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \ + model.tokenizer.type="wpe" \ + trainer.gpus=[1] \ + +trainer.fast_dev_run=True \ + exp_manager.exp_dir=examples/asr/speech_to_text_wpe_conformer_results' + sh 'rm -rf examples/asr/speech_to_text_wpe_conformer_results' + } + } } } diff --git a/examples/asr/experimental/configs/config_bpe.yaml b/examples/asr/experimental/citrinet/config_bpe.yaml similarity index 100% rename from examples/asr/experimental/configs/config_bpe.yaml rename to examples/asr/experimental/citrinet/config_bpe.yaml diff --git a/examples/asr/experimental/configs/contextnet_bpe/contextnet_192_2x_stride.yaml b/examples/asr/experimental/citrinet/contextnet_192_2x_stride.yaml similarity index 100% rename from examples/asr/experimental/configs/contextnet_bpe/contextnet_192_2x_stride.yaml rename to examples/asr/experimental/citrinet/contextnet_192_2x_stride.yaml diff --git a/examples/asr/experimental/configs/contextnet_bpe/contextnet_192_4x_stride.yaml b/examples/asr/experimental/citrinet/contextnet_192_4x_stride.yaml similarity index 100% rename from examples/asr/experimental/configs/contextnet_bpe/contextnet_192_4x_stride.yaml rename to examples/asr/experimental/citrinet/contextnet_192_4x_stride.yaml diff --git a/examples/asr/experimental/configs/contextnet_bpe/contextnet_192_8x_stride.yaml b/examples/asr/experimental/citrinet/contextnet_192_8x_stride.yaml similarity index 100% rename from examples/asr/experimental/configs/contextnet_bpe/contextnet_192_8x_stride.yaml rename to examples/asr/experimental/citrinet/contextnet_192_8x_stride.yaml diff --git a/examples/asr/experimental/conformer/conformer_bpe.yaml b/examples/asr/experimental/conformer/conformer_bpe.yaml new file mode 100644 index 000000000000..0e5e19427288 --- /dev/null +++ b/examples/asr/experimental/conformer/conformer_bpe.yaml @@ -0,0 +1,156 @@ +name: &name "Conformer-BPE" + +model: + sample_rate: &sample_rate 16000 + log_prediction: true + load_weights_from_checkpoint: null + ctc_reduction: 'mean_batch' + + train_ds: + manifest_filepath: ??? + sample_rate: 16000 + batch_size: 16 + trim_silence: false + max_duration: 16.7 + min_duration: 0.1 + shuffle: true + is_tarred: false + tarred_audio_filepaths: null + num_workers: 4 + pin_memory: false + use_start_end_token: true + + validation_ds: + manifest_filepath: ??? 
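As an aside on the CI stages above: the dotted arguments (`model.train_ds.manifest_filepath=...`, `trainer.gpus=[1]`) are Hydra overrides layered onto the renamed YAML configs, and the leading `+` in `+trainer.fast_dev_run=True` is Hydra's syntax for appending a key the base config does not define. A minimal sketch of that composition, using plain OmegaConf instead of Hydra (illustrative only, not part of this PR; the path assumes the post-rename location):

```python
from omegaconf import OmegaConf

# Base config as renamed in this PR, plus the same overrides the Jenkins stage passes.
base = OmegaConf.load("examples/asr/experimental/conformer/conformer_bpe.yaml")
overrides = OmegaConf.from_dotlist([
    "model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json",
    "model.tokenizer.dir=/home/TestData/asr_tokenizers/an4_wpe_128/",
    "model.tokenizer.type=wpe",
    "trainer.gpus=[1]",
])
cfg = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(cfg.model.tokenizer))
```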
+ sample_rate: 16000 + batch_size: 16 + shuffle: false + num_workers: 4 + pin_memory: false + use_start_end_token: true + + test_ds: + manifest_filepath: null + sample_rate: 16000 + batch_size: 16 + shuffle: false + num_workers: 4 + pin_memory: false + use_start_end_token: true + + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: ??? # Can be either bpe or wpe + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: *sample_rate + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: &n_mels 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 16 + stft_conv: false + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + # SpecAug params + freq_masks: 2 # set to zero to disable the SpecAug augmentation + time_masks: 2 # set to zero to disable the SpecAug augmentation + freq_width: 27 + time_width: 100 + # Cut-off params + rect_masks: 0 # set to zero to disable the cut-off augmentation + rect_time: 120 + rect_freq: 50 + + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: *n_mels + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 16 + d_model: 256 + + # Sub-sampling params + subsampling: vggnet # vggnet or striding + subsampling_factor: 4 # must be power of 2 + subsampling_conv_channels: 64 # set to -1 to make it equal to the d_model + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos, abs_pos + pos_emb_max_len: 5000 + n_heads: 4 + xscaling: true + + # Convolution module's params + conv_kernel_size: 31 + + ### regularization + dropout: 0.1 # The dropout used inside the Conformer Modules + dropout_emb: 0.1 # The dropout used embeddings + dropout_att: 0.0 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.LSTMDecoder + feat_in: null # If not provided, the feat_out of the encoder would be used + num_classes: -1 # filled with vocabulary size from tokenizer at runtime + vocabulary: [] # filled with vocabulary from tokenizer at runtime + lstm_hidden_size: 640 + bidirectional: False + num_layers: 1 + + optim: + name: novograd + lr: 0.01 + # optimizer arguments + betas: [0.8, 0.5] + weight_decay: 0.001 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 4000 + warmup_ratio: null + min_lr: 1e-9 + last_epoch: -1 + +trainer: + gpus: 0 # number of gpus + num_nodes: 1 + max_epochs: 100 + max_steps: null # computed at runtime if not set + val_check_interval: 1 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + distributed_backend: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + amp_level: O0 # O1/O2 for mixed precision + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
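The `preprocessor`, `encoder`, and `decoder` sections above each carry a `_target_` key naming the class to build; NeMo's `from_config_dict` (and `hydra.utils.instantiate`) follow this convention. A hedged sketch of that resolution step, with a made-up helper name and a stdlib class used for the demo so it runs without NeMo installed:

```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_target(cfg):
    # Split "package.module.ClassName", import the module, and call the class
    # with the remaining keys as keyword arguments.
    module_path, cls_name = cfg["_target_"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    kwargs = {k: v for k, v in OmegaConf.to_container(cfg).items() if k != "_target_"}
    return cls(**kwargs)

encoder_cfg = OmegaConf.create({
    "_target_": "nemo.collections.asr.modules.ConformerEncoder",
    "feat_in": 80, "n_layers": 16, "d_model": 256,
    "subsampling": "vggnet", "subsampling_factor": 4,
})
# encoder = instantiate_from_target(encoder_cfg)  # requires NeMo with this PR applied

demo = instantiate_from_target(OmegaConf.create({"_target_": "collections.OrderedDict"}))
print(type(demo))  # <class 'collections.OrderedDict'>
```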
+ num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + checkpoint_callback: false # Provided by exp_manager + logger: false # Provided by exp_manager + + +exp_manager: + exp_dir: null + name: *name + create_tensorboard_logger: true + create_checkpoint_callback: true + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/asr/experimental/conformer/conformer_char.yaml b/examples/asr/experimental/conformer/conformer_char.yaml new file mode 100644 index 000000000000..6129c0ce5bc8 --- /dev/null +++ b/examples/asr/experimental/conformer/conformer_char.yaml @@ -0,0 +1,153 @@ +name: &name "Conformer-char" + +model: + sample_rate: &sample_rate 16000 + labels: &labels [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", + "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] + log_prediction: true + load_weights_from_checkpoint: null + ctc_reduction: 'mean_batch' + + train_ds: + manifest_filepath: ??? + labels: *labels + sample_rate: 16000 + batch_size: 32 + trim_silence: false + max_duration: 16.7 + min_duration: 0.1 + shuffle: true + is_tarred: false + tarred_audio_filepaths: null + num_workers: 4 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + labels: *labels + sample_rate: 16000 + batch_size: 32 + shuffle: false + num_workers: 4 + pin_memory: true + + test_ds: + manifest_filepath: null + labels: *labels + sample_rate: 16000 + batch_size: 32 + shuffle: false + num_workers: 4 + pin_memory: true + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: *sample_rate + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: &n_mels 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 16 + stft_conv: false + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + # SpecAug params + freq_masks: 2 # set to zero to disable the SpecAug augmentation + time_masks: 2 # set to zero to disable the SpecAug augmentation + freq_width: 27 + time_width: 100 + # Cut-off params + rect_masks: 0 # set to zero to disable the cut-off augmentation + rect_time: 120 + rect_freq: 50 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: *n_mels + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 16 + d_model: 256 + + # Sub-sampling params + subsampling: vggnet # vggnet or striding + subsampling_factor: 4 # must be power of 2 + subsampling_conv_channels: 64 # set to -1 to make it equal to the d_model + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos, abs_pos + pos_emb_max_len: 5000 + n_heads: 4 + xscaling: true + + # Convolution module's params + conv_kernel_size: 31 + + ### regularization + dropout: 0.1 # The dropout used inside the Conformer Modules + dropout_emb: 0.1 # The dropout used embeddings + dropout_att: 0.0 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.LSTMDecoder + feat_in: null # If not provided, the feat_out of the encoder would be used + num_classes: 28 + vocabulary: *labels + 
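A small sanity check on the character config above (illustrative, not part of the PR): `decoder.num_classes: 28` must equal the length of the `labels` list, and the LSTMDecoder introduced in this PR appends one extra output for the CTC blank symbol internally.

```python
# The 28 labels: space, 26 lowercase letters, and the apostrophe.
labels = [" ", *"abcdefghijklmnopqrstuvwxyz", "'"]
assert len(labels) == 28
num_classes_with_blank = len(labels) + 1  # what the decoder actually emits per frame
```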
lstm_hidden_size: 640 + bidirectional: False + num_layers: 1 + + optim: + name: novograd + lr: 0.01 + # optimizer arguments + betas: [0.8, 0.5] + weight_decay: 0.001 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 1000 + warmup_ratio: null + min_lr: 1e-9 + last_epoch: -1 + +trainer: + gpus: 0 # number of gpus + num_nodes: 1 + max_epochs: 100 + max_steps: null # computed at runtime if not set + val_check_interval: 1 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + distributed_backend: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + amp_level: O0 # O1/O2 for mixed precision + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + checkpoint_callback: false # Provided by exp_manager + logger: false # Provided by exp_manager + + +exp_manager: + exp_dir: null + name: *name + create_tensorboard_logger: true + create_checkpoint_callback: true + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/asr/speech_to_text.py b/examples/asr/speech_to_text.py index d5636f41d6e0..fb164727f8f9 100644 --- a/examples/asr/speech_to_text.py +++ b/examples/asr/speech_to_text.py @@ -13,9 +13,11 @@ # limitations under the License. import pytorch_lightning as pl +from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecCTCModel from nemo.core.config import hydra_runner +from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @@ -44,8 +46,8 @@ hydra.run.dir="." 
\ trainer.gpus=2 \ trainer.max_epochs=2 \ - model.optim.args.params.betas=[0.8,0.5] \ - model.optim.args.params.weight_decay=0.0001 + model.optim.args.betas=[0.8,0.5] \ + model.optim.args.weight_decay=0.0001 Overide optimizer entirely python speech_to_text.py \ @@ -65,12 +67,24 @@ @hydra_runner(config_path="conf", config_name="config") def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer) trainer.fit(asr_model) + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + gpu = 1 if cfg.trainer.gpus != 0 else 0 + trainer = pl.Trainer( + gpus=gpu, + precision=cfg.trainer.precision, + amp_level=cfg.trainer.amp_level, + amp_backend=cfg.trainer.amp_backend, + ) + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + if __name__ == '__main__': main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/speech_to_text_bpe.py b/examples/asr/speech_to_text_bpe.py index e49e794e7c3b..400ffe3ef479 100644 --- a/examples/asr/speech_to_text_bpe.py +++ b/examples/asr/speech_to_text_bpe.py @@ -47,6 +47,7 @@ ``` """ import pytorch_lightning as pl +from omegaconf import OmegaConf from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.core.config import hydra_runner @@ -56,17 +57,21 @@ @hydra_runner(config_path="experimental/configs/", config_name="config_bpe") def main(cfg): - logging.info(f'Hydra config: {cfg.pretty()}') + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + print(OmegaConf.to_yaml(cfg)) trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer) trainer.fit(asr_model) if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: gpu = 1 if cfg.trainer.gpus != 0 else 0 - trainer = pl.Trainer(gpus=gpu) - if asr_model.prepare_test(trainer): + test_trainer = pl.Trainer( + gpus=gpu, precision=trainer.precision, amp_level=trainer.amp_level, amp_backend=trainer.amp_backend, + ) + if asr_model.prepare_test(test_trainer): trainer.test(asr_model) diff --git a/examples/nlp/text_classification/conf/text_classification_config.yaml b/examples/nlp/text_classification/conf/text_classification_config.yaml index 4b3963cad492..e88cca02d6a5 100644 --- a/examples/nlp/text_classification/conf/text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/text_classification_config.yaml @@ -15,7 +15,7 @@ # Config file for text classification with pre-trained BERT models trainer: - gpus: 1 # number of gpus (0 for CPU), or list of the GPUs to use e.g. [0, 1] + gpus: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] num_nodes: 1 max_epochs: 100 max_steps: null # precedence over max_epochs diff --git a/examples/nlp/text_classification/data/import_datasets.py b/examples/nlp/text_classification/data/import_datasets.py index 76d08c0db675..95db24cea7d7 100644 --- a/examples/nlp/text_classification/data/import_datasets.py +++ b/examples/nlp/text_classification/data/import_datasets.py @@ -21,7 +21,8 @@ --source_data_dir "./thucnews_orig_data/" \ --target_data_dir "./thucnews/" -It reads the data from "source_data_dir" folder, processes and converts the data into NeMo's format. Then writes the results into "target_data_dir" folder. 
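Both training scripts above now run evaluation after `fit()`. Condensed, the added logic looks roughly like this (a sketch under the PR's assumptions; the helper name is not in the PR): DDP training spawns one process per GPU, so a fresh single-device Trainer is built for testing, and `test()` only runs if `prepare_test()` reports a usable test dataloader.

```python
import pytorch_lightning as pl

def run_test_after_fit(asr_model, cfg):
    # Use one GPU if training used any, otherwise stay on CPU.
    gpu = 1 if cfg.trainer.gpus != 0 else 0
    test_trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision)
    # prepare_test() is False when model.test_ds is not configured; skip test() then.
    if asr_model.prepare_test(test_trainer):
        test_trainer.test(asr_model)
```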
+It reads the data from "source_data_dir" folder, processes and converts the data into NeMo's format. +Then writes the results into "target_data_dir" folder. """ import argparse @@ -40,7 +41,8 @@ def process_imdb(infold, outfold, uncased, modes=['train', 'test']): link = 'https://ai.stanford.edu/~amaas/data/sentiment/' raise ValueError( f'Data not found at {infold}. ' - f'Please download IMDB reviews dataset from {link} and extract it into the folder specified by source_data_dir argument.' + f'Please download IMDB reviews dataset from {link} and ' + f'extract it into the folder specified by source_data_dir argument.' ) logging.info(f'Processing IMDB dataset and store at {outfold}') @@ -153,9 +155,7 @@ def process_thucnews(infold, outfold): if __name__ == "__main__": # Parse the command-line arguments. parser = argparse.ArgumentParser(description="Process and convert datasets into NeMo\'s format.") - parser.add_argument( - "--dataset_name", required=True, type=str, choices=['imdb', 'thucnews', 'chemprot'], - ) + parser.add_argument("--dataset_name", required=True, type=str, choices=['imdb', 'thucnews', 'chemprot']) parser.add_argument( "--source_data_dir", required=True, type=str, help='The path to the folder containing the dataset files.' ) diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index c60b385b0c55..f742341ef3c3 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -165,7 +165,7 @@ def __getitem__(self, index): offset = 0 features = self.featurizer.process( - sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim, orig_sr=sample.orig_sr, + sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim, orig_sr=sample.orig_sr ) f, fl = features, torch.tensor(features.shape[0]).long() else: @@ -270,7 +270,7 @@ def __init__( self.labels = labels parser = parsers.make_parser( - labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize, + labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize ) super().__init__( @@ -474,13 +474,14 @@ def __init__( trim: bool = False, load_audio: bool = True, add_misc: bool = False, + use_start_end_token: bool = True, ): - if hasattr(tokenizer, 'bos_token'): + if use_start_end_token and hasattr(tokenizer, 'bos_token'): bos_id = tokenizer.bos_id else: bos_id = None - if hasattr(tokenizer, 'eos_token'): + if use_start_end_token and hasattr(tokenizer, 'eos_token'): eos_id = tokenizer.eos_id else: eos_id = None @@ -553,7 +554,7 @@ def __init__( load_audio=True, ): self.collection = collections.ASRSpeechLabel( - manifests_files=manifest_filepath.split(','), min_duration=min_duration, max_duration=max_duration, + manifests_files=manifest_filepath.split(','), min_duration=min_duration, max_duration=max_duration ) self.featurizer = featurizer @@ -577,7 +578,7 @@ def __getitem__(self, index): offset = 0 features = self.featurizer.process( - sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim, + sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim ) f, fl = features, torch.tensor(features.shape[0]).long() else: @@ -1029,13 +1030,14 @@ def __init__( add_misc: bool = False, global_rank: int = 0, world_size: int = 0, + use_start_end_token: bool = True, ): - if hasattr(tokenizer, 'bos_token'): + if use_start_end_token and hasattr(tokenizer, 'bos_token'): bos_id = tokenizer.bos_id else: bos_id = None - if 
hasattr(tokenizer, 'eos_token'): + if use_start_end_token and hasattr(tokenizer, 'eos_token'): eos_id = tokenizer.eos_id else: eos_id = None diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 9e3250071955..d5d8200ead58 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +from nemo.collections.asr.losses.ctc import CTCLoss diff --git a/nemo/collections/asr/losses/ctc.py b/nemo/collections/asr/losses/ctc.py index dfd70100386a..e20c554d7262 100644 --- a/nemo/collections/asr/losses/ctc.py +++ b/nemo/collections/asr/losses/ctc.py @@ -44,10 +44,16 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, num_classes, zero_infinity=False): + def __init__(self, num_classes, zero_infinity=False, reduction='mean_batch'): self._blank = num_classes # Don't forget to properly call base constructor - super().__init__(blank=self._blank, reduction='none', zero_infinity=zero_infinity) + if reduction == 'mean_batch': + ctc_reduction = 'none' + self._apply_batch_mean = True + elif reduction in ['sum', 'mean', 'none']: + ctc_reduction = reduction + self._apply_batch_mean = False + super().__init__(blank=self._blank, reduction=ctc_reduction, zero_infinity=zero_infinity) @typecheck() def forward(self, log_probs, targets, input_lengths, target_lengths): @@ -61,7 +67,8 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): loss = super().forward( log_probs=log_probs, targets=targets, input_lengths=input_lengths, target_lengths=target_lengths ) - loss = torch.mean(loss) + if self._apply_batch_mean: + loss = torch.mean(loss) return loss diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 40360e30f3b7..1a943d26b9aa 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -158,7 +158,7 @@ def update(self, predictions: torch.Tensor, targets: torch.Tensor, target_length if self.log_prediction: logging.info(f"\n") logging.info(f"reference:{references[0]}") - logging.info(f"decoded :{hypotheses[0]}") + logging.info(f"predicted:{hypotheses[0]}") for h, r in zip(hypotheses, references): if self.use_cer: diff --git a/nemo/collections/asr/metrics/wer_bpe.py b/nemo/collections/asr/metrics/wer_bpe.py index d7e0997d5dff..aad001fe81d7 100644 --- a/nemo/collections/asr/metrics/wer_bpe.py +++ b/nemo/collections/asr/metrics/wer_bpe.py @@ -121,7 +121,7 @@ def update(self, predictions: torch.Tensor, targets: torch.Tensor, target_length if self.log_prediction: logging.info(f"\n") logging.info(f"reference:{references[0]}") - logging.info(f"decoded :{hypotheses[0]}") + logging.info(f"predicted:{hypotheses[0]}") for h, r in zip(hypotheses, references): if self.use_cer: diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index 698f842bda5b..6fdce8a96106 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -17,7 +17,7 @@ from typing import Dict, Optional import torch -from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from 
nemo.collections.asr.data.audio_to_text import AudioToBPEDataset, TarredAudioToBPEDataset from nemo.collections.asr.losses.ctc import CTCLoss @@ -26,7 +26,6 @@ from nemo.collections.asr.parts.perturb import process_augmentations from nemo.collections.common import tokenizers from nemo.core.classes.common import PretrainedModelInfo -from nemo.core.neural_types import * from nemo.utils import logging __all__ = ['EncDecCTCModelBPE', 'JasperNetBPE', 'QuartzNetBPE'] @@ -67,22 +66,39 @@ def __init__(self, cfg: DictConfig, trainer=None): vocabulary = self.tokenizer.tokenizer.get_vocab() # Set the new vocabulary - cfg.decoder.params.vocabulary = ListConfig(list(vocabulary.values())) + with open_dict(cfg): + if "params" in cfg.decoder: + cfg.decoder.params.vocabulary = ListConfig(list(vocabulary.values())) + else: + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) # Override number of classes if placeholder provided - if cfg.decoder.params['num_classes'] < 1: + if "params" in cfg.decoder: + num_classes = cfg.decoder["params"]["num_classes"] + else: + num_classes = cfg.decoder["num_classes"] + + if num_classes < 1: logging.info( "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format( - cfg.decoder.params['num_classes'], len(vocabulary) + num_classes, len(vocabulary) ) ) - cfg.decoder.params['num_classes'] = len(vocabulary) + if "params" in cfg.decoder: + cfg.decoder["params"]["num_classes"] = len(vocabulary) + else: + cfg.decoder["num_classes"] = len(vocabulary) super().__init__(cfg=cfg, trainer=trainer) # Setup metric objects self._wer = WERBPE( - tokenizer=self.tokenizer, batch_dim_index=0, use_cer=False, ctc_decode=True, dist_sync_on_step=True, + tokenizer=self.tokenizer, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), ) def _setup_tokenizer(self): @@ -175,6 +191,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): add_misc=config.get('add_misc', False), global_rank=self.global_rank, world_size=self.world_size, + use_start_end_token=config.get('use_start_end_token', True), ) shuffle = False else: @@ -194,6 +211,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): trim=config.get('trim_silence', True), load_audio=config.get('load_audio', True), add_misc=config.get('add_misc', False), + use_start_end_token=config.get('use_start_end_token', True), ) return torch.utils.data.DataLoader( @@ -267,21 +285,40 @@ def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str): # Set the new vocabulary decoder_config = copy.deepcopy(self.decoder.to_config_dict()) - decoder_config.params.vocabulary = ListConfig(list(vocabulary.values())) + decoder_config.vocabulary = ListConfig(list(vocabulary.values())) + + if "params" in decoder_config: + decoder_num_classes = decoder_config['params']['num_classes'] + else: + decoder_num_classes = decoder_config['num_classes'] # Override number of classes if placeholder provided logging.info( "\nReplacing old number of classes ({}) with new number of classes - {}".format( - decoder_config['params']['num_classes'], len(vocabulary) + decoder_num_classes, len(vocabulary) ) ) - decoder_config['params']['num_classes'] = len(vocabulary) + + if "params" in decoder_config: + decoder_config['params']['num_classes'] = len(vocabulary) + else: + decoder_config['num_classes'] = len(vocabulary) del self.decoder self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config) 
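The vocabulary handling above supports both config layouts: older decoder configs nest their arguments under `params`, while the new Conformer configs are flat. A hedged sketch of that compatibility pattern (the helper name is made up for illustration):

```python
from omegaconf import OmegaConf, open_dict

def set_vocab_size(decoder_cfg, vocab_size):
    # Write num_classes through whichever layout (nested `params` or flat) is present,
    # replacing the placeholder value only when it is < 1.
    with open_dict(decoder_cfg):
        target = decoder_cfg.params if "params" in decoder_cfg else decoder_cfg
        if target.num_classes < 1:
            target.num_classes = vocab_size
    return decoder_cfg

cfg = OmegaConf.create({"_target_": "nemo.collections.asr.modules.LSTMDecoder",
                        "num_classes": -1, "vocabulary": []})
print(set_vocab_size(cfg, 128).num_classes)  # 128
```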
del self.loss - self.loss = CTCLoss(num_classes=self.decoder.num_classes_with_blank - 1, zero_infinity=True) - self._wer = WERBPE(tokenizer=self.tokenizer, batch_dim_index=0, use_cer=False, ctc_decode=True) + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + self._wer = WERBPE( + tokenizer=self.tokenizer, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + log_prediction=self._cfg.get("log_prediction", False), + ) # Update config OmegaConf.set_struct(self._cfg.decoder, False) diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index effa30a4d011..8257e4170f1e 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -20,7 +20,7 @@ import onnx import torch -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from nemo.collections.asr.data.audio_to_text import AudioToCharDataset, TarredAudioToCharDataset @@ -99,8 +99,31 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor) self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder) + + with open_dict(self._cfg): + if "params" in self._cfg.decoder: + if "feat_in" not in self._cfg.decoder.params or ( + not self._cfg.decoder.params.feat_in and hasattr(self.encoder, '_feat_out') + ): + self._cfg.decoder.params.feat_in = self.encoder._feat_out + if "feat_in" not in self._cfg.decoder.params or not self._cfg.decoder.params.feat_in: + raise ValueError("param feat_in of the decoder's config is not set!") + else: + if "feat_in" not in self._cfg.decoder or ( + not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out') + ): + self._cfg.decoder.feat_in = self.encoder._feat_out + if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in: + raise ValueError("param feat_in of the decoder's config is not set!") + self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder) - self.loss = CTCLoss(num_classes=self.decoder.num_classes_with_blank - 1, zero_infinity=True) + + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: self.spec_augmentation = EncDecCTCModel.from_config_dict(self._cfg.spec_augment) else: @@ -110,9 +133,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self._wer = WER( vocabulary=self.decoder.vocabulary, batch_dim_index=0, - use_cer=False, + use_cer=self._cfg.get('use_cer', False), ctc_decode=True, dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), ) @torch.no_grad() @@ -196,18 +220,23 @@ def change_vocabulary(self, new_vocabulary: List[str]): raise ValueError(f'New vocabulary must be non-empty list of chars. 
But I got: {new_vocabulary}') decoder_config = self.decoder.to_config_dict() new_decoder_config = copy.deepcopy(decoder_config) - new_decoder_config['params']['vocabulary'] = new_vocabulary - new_decoder_config['params']['num_classes'] = len(new_vocabulary) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) del self.decoder self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config) del self.loss - self.loss = CTCLoss(num_classes=self.decoder.num_classes_with_blank - 1, zero_infinity=True) + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) self._wer = WER( vocabulary=self.decoder.vocabulary, batch_dim_index=0, - use_cer=False, + use_cer=self._cfg.get('use_cer', False), ctc_decode=True, dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), ) # Update config diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py index a2112f0ec53c..36d553d1b34f 100644 --- a/nemo/collections/asr/modules/__init__.py +++ b/nemo/collections/asr/modules/__init__.py @@ -19,9 +19,11 @@ SpectrogramAugmentation, ) from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM +from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder from nemo.collections.asr.modules.conv_asr import ( ConvASRDecoder, ConvASRDecoderClassification, ConvASREncoder, SpeakerDecoder, ) +from nemo.collections.asr.modules.lstm_decoder import LSTMDecoder diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py new file mode 100644 index 000000000000..e80d1aa64b84 --- /dev/null +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -0,0 +1,222 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import OrderedDict + +import torch +import torch.nn as nn + +from nemo.collections.asr.parts.conformer_modules import ConformerEncoderBlock +from nemo.collections.asr.parts.multi_head_attention import PositionalEncoding, RelPositionalEncoding +from nemo.collections.asr.parts.subsampling import ConvSubsampling +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType + +__all__ = ['ConformerEncoder'] + + +class ConformerEncoder(NeuralModule, Exportable): + """ + The encoder for ASR model of Conformer. + Based on this paper: + 'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al. 
+ https://arxiv.org/abs/2005.08100 + + Args: + feat_in (int): the size of feature channels + n_layers (int): number of layers of ConformerBlock + d_model (int): the hidden size of the model + feat_out (int): the size of the output features + Defaults to -1 (means feat_out is d_model) + subsampling (str): the method of subsampling, choices=['vggnet', 'striding'] + subsampling_factor (int): the subsampling factor which should be power of 2 + Defaults to 4. + subsampling_conv_channels (int): the size of the convolutions in the subsampling module + Defaults to 64. + ff_expansion_factor (int): the expansion factor in feed forward layers + Defaults to 4. + self_attention_model (str): type of the attention layer and positional encoding + choices=['rel_pos', 'abs_pos']. + pos_emb_max_len (int): the maximum length of positional embeddings + Defaulst to 5000 + n_heads (int): number of heads in multi-headed attention layers + Defaults to 4. + xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model) + Defaults to True. + conv_kernel_size (int): the size of the convolutions in the convolutional modules + Defaults to 31. + dropout (float): the dropout rate used in all layers except the attention layers + Defaults to 0.1. + dropout_emb (float): the dropout rate used for the positional embeddings + Defaults to 0.1. + dropout_att (float): the dropout rate used for the attention layer + Defaults to 0.0. + """ + + def _prepare_for_export(self): + Exportable._prepare_for_export(self) + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device) + return tuple([input_example]) + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + ) + + @property + def output_types(self): + """Returns definitions of module output ports. 
+ """ + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + ) + + def __init__( + self, + feat_in, + n_layers, + d_model, + feat_out=-1, + subsampling='vggnet', + subsampling_factor=4, + subsampling_conv_channels=64, + ff_expansion_factor=4, + self_attention_model='rel_pos', + pos_emb_max_len=5000, + n_heads=4, + xscaling=True, + conv_kernel_size=31, + dropout=0.1, + dropout_emb=0.1, + dropout_att=0.0, + ): + super().__init__() + + d_ff = d_model * ff_expansion_factor + self.d_model = d_model + self.scale = math.sqrt(self.d_model) + + if xscaling: + self.xscale = math.sqrt(d_model) + else: + self.xscale = None + + if subsampling: + self.pre_encode = ConvSubsampling( + subsampling=subsampling, + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + conv_channels=subsampling_conv_channels, + activation=nn.ReLU(), + ) + self._feat_out = d_model + else: + self._feat_out = d_model + self.pre_encode = nn.Linear(feat_in, d_model) + + if self_attention_model == "rel_pos": + self.pos_enc = RelPositionalEncoding( + d_model=d_model, + dropout_rate=dropout, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_emb_rate=dropout_emb, + ) + elif self_attention_model == "abs_pos": + self.pos_enc = PositionalEncoding( + d_model=d_model, dropout_rate=dropout, max_len=pos_emb_max_len, reverse=False, xscale=self.xscale + ) + else: + raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + + self.layers = nn.ModuleList() + for i in range(n_layers): + layer = ConformerEncoderBlock( + d_model=d_model, + d_ff=d_ff, + conv_kernel_size=conv_kernel_size, + self_attention_model=self_attention_model, + n_heads=n_heads, + dropout=dropout, + dropout_att=dropout_att, + ) + self.layers.append(layer) + + if feat_out > 0 and feat_out != self.output_dim: + self.out_proj = nn.Linear(self.feat_out, feat_out) + self._feat_out = feat_out + else: + self.out_proj = None + self._feat_out = d_model + + @typecheck() + def forward(self, audio_signal, length): + audio_signal = torch.transpose(audio_signal, 1, 2) + + if isinstance(self.pre_encode, ConvSubsampling): + audio_signal, length = self.pre_encode(audio_signal, length) + else: + audio_signal = self.embed(audio_signal) + + audio_signal, pos_emb = self.pos_enc(audio_signal) + bs, xmax, idim = audio_signal.size() + + # Create the self-attention and padding masks + pad_mask = self.make_pad_mask(length, max_time=xmax, device=audio_signal.device) + xx_mask = pad_mask.unsqueeze(1).repeat([1, xmax, 1]) + xx_mask = xx_mask & xx_mask.transpose(1, 2) + pad_mask = (~pad_mask).unsqueeze(2) + + for lth, layer in enumerate(self.layers): + audio_signal = layer(x=audio_signal, att_mask=xx_mask, pos_emb=pos_emb, pad_mask=pad_mask,) + + if self.out_proj is not None: + audio_signal = self.out_proj(audio_signal) + + audio_signal = torch.transpose(audio_signal, 1, 2) + return audio_signal, length + + @staticmethod + def make_pad_mask(seq_lens, max_time, device=None): + """Make masking for padding.""" + bs = seq_lens.size(0) + seq_range = torch.arange(0, max_time, dtype=torch.int32) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_time) + seq_lens = seq_lens.type(seq_range_expand.dtype).to(seq_range_expand.device) + seq_length_expand = seq_lens.unsqueeze(-1) + mask = seq_range_expand < seq_length_expand + + if device: + mask = mask.to(device) + return mask diff --git a/nemo/collections/asr/modules/conv_asr.py 
b/nemo/collections/asr/modules/conv_asr.py index 83b50f41e375..45073fbbc4de 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -63,7 +63,7 @@ def input_example(self): Returns: A tuple of input examples. """ - input_example = torch.randn(16, self.__feat_in, 256).to(next(self.parameters()).device) + input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device) return tuple([input_example]) @property @@ -124,7 +124,7 @@ def __init__( activation = jasper_activations[activation]() feat_in = feat_in * frame_splicing - self.__feat_in = feat_in + self._feat_in = feat_in residual_panes = [] encoder_layers = [] @@ -174,6 +174,8 @@ def __init__( ) feat_in = lcfg['filters'] + self._feat_out = feat_in + self.encoder = torch.nn.Sequential(*encoder_layers) self.apply(lambda x: init_weights(x, mode=init_mode)) diff --git a/nemo/collections/asr/modules/lstm_decoder.py b/nemo/collections/asr/modules/lstm_decoder.py new file mode 100644 index 000000000000..5552638c8b8b --- /dev/null +++ b/nemo/collections/asr/modules/lstm_decoder.py @@ -0,0 +1,98 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +import torch +import torch.nn as nn + +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, LogprobsType, NeuralType + +__all__ = ['LSTMDecoder'] + + +class LSTMDecoder(NeuralModule, Exportable): + """ + Simple LSTM Decoder for ASR models + Args: + feat_in (int): size of the input features + num_classes (int): the size of the vocabulary + lstm_hidden_size (int): hidden size of the LSTM layers + vocabulary (vocab): The vocabulary + bidirectional (bool): default is False. Whether LSTMs are bidirectional or not + num_layers (int): default is 1. Number of LSTM layers stacked + """ + + @property + def input_types(self): + return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + + @property + def output_types(self): + return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) + + def __init__(self, feat_in, num_classes, lstm_hidden_size, vocabulary=None, bidirectional=False, num_layers=1): + super().__init__() + + if vocabulary is not None: + if num_classes != len(vocabulary): + raise ValueError( + f"If vocabulary is specified, it's length should be equal to the num_classes. 
" + f"Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" + ) + self.__vocabulary = vocabulary + self._feat_in = feat_in + # Add 1 for blank char + self._num_classes = num_classes + 1 + + self.lstm_layer = nn.LSTM( + input_size=feat_in, + hidden_size=lstm_hidden_size, + num_layers=num_layers, + batch_first=True, + bidirectional=bidirectional, + ) + self.linear_layer = torch.nn.Linear(in_features=lstm_hidden_size, out_features=self._num_classes) + + @typecheck() + def forward(self, encoder_output): + output = encoder_output.transpose(1, 2) + output, _ = self.lstm_layer(output) + output = self.linear_layer(output) + return torch.nn.functional.log_softmax(output, dim=-1) + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + bs = 8 + seq = 64 + input_example = torch.randn(bs, self._feat_in, seq).to(next(self.parameters()).device) + return tuple([input_example]) + + def _prepare_for_export(self): + Exportable._prepare_for_export(self) + + @property + def vocabulary(self): + return self.__vocabulary + + @property + def num_classes_with_blank(self): + return self._num_classes diff --git a/nemo/collections/asr/parts/activations.py b/nemo/collections/asr/parts/activations.py new file mode 100644 index 000000000000..627eef295717 --- /dev/null +++ b/nemo/collections/asr/parts/activations.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +__all__ = ['Swish'] + + +class Swish(nn.Module): + """ + Swish activation function introduced in 'https://arxiv.org/abs/1710.05941' + """ + + def forward(self, x): + return x * torch.sigmoid(x) diff --git a/nemo/collections/asr/parts/conformer_modules.py b/nemo/collections/asr/parts/conformer_modules.py new file mode 100644 index 000000000000..ef8943d4d510 --- /dev/null +++ b/nemo/collections/asr/parts/conformer_modules.py @@ -0,0 +1,170 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import torch +from torch import nn as nn +from torch.nn import LayerNorm + +from nemo.collections.asr.parts.activations import Swish +from nemo.collections.asr.parts.multi_head_attention import MultiHeadAttention, RelPositionMultiHeadAttention + +__all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerEncoderBlock'] + + +class ConformerEncoderBlock(torch.nn.Module): + """A single block of the Conformer encoder. + + Args: + d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward + d_ff (int): hidden dimension of PositionwiseFeedForward + n_heads (int): number of heads for multi-head attention + conv_kernel_size (int): kernel size for depthwise convolution in convolution module + dropout (float): dropout probabilities for linear layers + dropout_att (float): dropout probabilities for attention distributions + """ + + def __init__(self, d_model, d_ff, conv_kernel_size, self_attention_model, n_heads, dropout, dropout_att): + super(ConformerEncoderBlock, self).__init__() + + self.self_attention_model = self_attention_model + self.n_heads = n_heads + self.fc_factor = 0.5 + + # first feed forward module + self.norm_feed_forward1 = LayerNorm(d_model) + self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) + + # convolution module + self.norm_conv = LayerNorm(d_model) + self.conv = ConformerConvolution(d_model=d_model, kernel_size=conv_kernel_size) + + # multi-headed self-attention module + self.norm_self_att = LayerNorm(d_model) + if self_attention_model == 'rel_pos': + self.self_attn = RelPositionMultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att) + elif self_attention_model == 'abs_pos': + self.self_attn = MultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att) + else: + raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + + # second feed forward module + self.norm_feed_forward2 = LayerNorm(d_model) + self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) + + self.dropout = nn.Dropout(dropout) + self.norm_out = LayerNorm(d_model) + + def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None): + """ + Args: + x (torch.Tensor): input signals (B, T, d_model) + att_mask (torch.Tensor): attention masks(B, T, T) + pos_emb (torch.Tensor): (L, 1, d_model) + pad_mask (torch.tensor): padding mask + Returns: + x (torch.Tensor): (B, T, d_model) + """ + residual = x + x = self.norm_feed_forward1(x) + x = self.feed_forward1(x) + x = self.fc_factor * self.dropout(x) + residual + + residual = x + x = self.norm_self_att(x) + if self.self_attention_model == 'rel_pos': + x = self.self_attn(query=x, key=x, value=x, pos_emb=pos_emb, mask=att_mask) + elif self.self_attention_model == 'abs_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask) + else: + x = None + x = self.dropout(x) + residual + + residual = x + x = self.norm_conv(x) + x = self.conv(x) + x = self.dropout(x) + residual + + residual = x + x = self.norm_feed_forward2(x) + x = self.feed_forward2(x) + x = self.fc_factor * self.dropout(x) + residual + + x = self.norm_out(x) + return x + + +class ConformerConvolution(nn.Module): + """The convolution module for the Conformer model. 
+ Args: + d_model (int): hidden dimension + kernel_size (int): kernel size for depthwise convolution + """ + + def __init__(self, d_model, kernel_size): + super(ConformerConvolution, self).__init__() + assert (kernel_size - 1) % 2 == 0 + self.d_model = d_model + + self.pointwise_conv1 = nn.Conv1d( + in_channels=d_model, out_channels=d_model * 2, kernel_size=1, stride=1, padding=0, bias=True + ) + self.depthwise_conv = nn.Conv1d( + in_channels=d_model, + out_channels=d_model, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=d_model, + bias=True, + ) + self.batch_norm = nn.BatchNorm1d(d_model) + self.activation = Swish() + self.pointwise_conv2 = nn.Conv1d( + in_channels=d_model, out_channels=d_model, kernel_size=1, stride=1, padding=0, bias=True + ) + + def forward(self, x): + x = x.transpose(1, 2) + x = self.pointwise_conv1(x) + + x = nn.functional.glu(x, dim=1) + + x = self.depthwise_conv(x) + + x = self.batch_norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) + x = x.transpose(1, 2) + return x + + +class ConformerFeedForward(nn.Module): + """ + feed-forward module of Conformer model. + """ + + def __init__(self, d_model, d_ff, dropout, activation=Swish()): + super(ConformerFeedForward, self).__init__() + self.linear1 = nn.Linear(d_model, d_ff) + self.activation = activation + self.dropout = nn.Dropout(p=dropout) + self.linear2 = nn.Linear(d_ff, d_model) + + def forward(self, x): + x = self.linear1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.linear2(x) + return x diff --git a/nemo/collections/asr/parts/features.py b/nemo/collections/asr/parts/features.py index cd4a8f489fde..424ea65daaa9 100644 --- a/nemo/collections/asr/parts/features.py +++ b/nemo/collections/asr/parts/features.py @@ -285,7 +285,7 @@ def __init__( highfreq = highfreq or sample_rate / 2 filterbanks = torch.tensor( - librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float, + librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float ).unsqueeze(0) self.register_buffer("fb", filterbanks) @@ -388,4 +388,5 @@ def forward(self, x, seq_len): pad_amt = x.size(-1) % pad_to if pad_amt != 0: x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) + return x, seq_len diff --git a/nemo/collections/asr/parts/jasper.py b/nemo/collections/asr/parts/jasper.py index 633f135ca4fd..c2f4cdd7ccbc 100644 --- a/nemo/collections/asr/parts/jasper.py +++ b/nemo/collections/asr/parts/jasper.py @@ -19,11 +19,9 @@ import torch.nn as nn from torch import Tensor -jasper_activations = { - "hardtanh": nn.Hardtanh, - "relu": nn.ReLU, - "selu": nn.SELU, -} +from nemo.collections.asr.parts.activations import Swish + +jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish} def init_weights(m, mode='xavier_uniform'): @@ -257,11 +255,6 @@ def forward(self, x): return x * y -class Swish(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - class JasperBlock(nn.Module): __constants__ = ["conv_mask", "separable", "residual_mode", "res", "mconv"] @@ -561,7 +554,3 @@ def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]): return xs + [out], lens return [out], lens - - -# Register swish activation function -jasper_activations['swish'] = Swish diff --git a/nemo/collections/asr/parts/multi_head_attention.py b/nemo/collections/asr/parts/multi_head_attention.py new file mode 100644 index 000000000000..b36631d9eb46 --- /dev/null +++ 
b/nemo/collections/asr/parts/multi_head_attention.py @@ -0,0 +1,296 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Part of this code is adopted from https://github.com/espnet/espnet +""" + +import math + +import numpy as np +import torch +import torch.nn as nn + +__all__ = [ + 'RelPositionMultiHeadAttention', + 'RelPositionalEncoding', + 'PositionalEncoding', +] + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention layer. + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transforms query, key and value. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value (torch.Tensor): (batch, time2, size) + returns: + q (torch.Tensor): (batch, head, time1, size) + k (torch.Tensor): (batch, head, time2, size) + v (torch.Tensor): (batch, head, time2, size) + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. 
+ Args: + value (torch.Tensor): (batch, time2, size) + scores(torch.Tensor): (batch, time1, time2) + mask(torch.Tensor): (batch, time1, time2) + returns: + value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + if scores.dtype == torch.float16: + dtype = np.float16 + else: + dtype = np.float32 + min_value = np.finfo(dtype).min + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute 'Scaled Dot Product Attention'. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadAttention(MultiHeadAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional ecoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu=False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): (batch, time, size) + zero_triu (bool): return the lower triangular part of the matrix + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, mask, pos_emb): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 
+ Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + pos_emb (torch.Tensor) : (batch, time1, size) + Returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + Args: + d_model (int): embedding dim + dropout_rate (float): dropout rate + max_len (int): maximum input length + reverse (int): whether to reverse the input position + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False, xscale=None): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = xscale # math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + Returns: + Encoded Output (torch.Tensor): Its shape is (batch, time, ...) + """ + self.extend_pe(x) + if self.xscale: + x = x * self.xscale + x = x + self.pe[:, : x.size(1)] + return self.dropout(x), None + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. 
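+        Unlike PositionalEncoding, forward() does not add the encoding to the input; it returns the positional embeddings separately so RelPositionMultiHeadAttention can consume them as pos_emb.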
+ See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): embedding dim + dropout_rate (float): dropout rate + max_len (int): maximum input length + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_emb_rate=0.0): + super().__init__(d_model, dropout_rate, max_len, reverse=True, xscale=xscale) + + if dropout_emb_rate > 0: + self.dropout_emb = nn.Dropout(dropout_emb_rate) + else: + self.dropout_emb = None + + def forward(self, x): + """Compute positional encoding. + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + Returns: + x (torch.Tensor): Its shape is (batch, time, ...) + pos_emb (torch.Tensor): Its shape is (1, time, ...) + """ + self.extend_pe(x) + if self.xscale: + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + if self.dropout_emb: + pos_emb = self.dropout_emb(pos_emb) + return self.dropout(x), pos_emb diff --git a/nemo/collections/asr/parts/subsampling.py b/nemo/collections/asr/parts/subsampling.py new file mode 100644 index 000000000000..85f31611b8b4 --- /dev/null +++ b/nemo/collections/asr/parts/subsampling.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + + +class ConvSubsampling(torch.nn.Module): + """Convolutional subsampling which supports VGGNet and striding approach introduced in: + VGGNet Subsampling: https://arxiv.org/pdf/1910.12977.pdf + Striding Subsampling: + "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. + Args: + subsampling (str): The subsampling technique from {"vggnet", "striding"} + subsampling_factor (int): The subsampling factor which should be a power of 2 + feat_in (int): size of the input features + feat_out (int): size of the output features + conv_channels (int): Number of channels for the convolution layers. 
+ activation (Module): activation function, default is nn.ReLU() + """ + + def __init__(self, subsampling, subsampling_factor, feat_in, feat_out, conv_channels, activation=nn.ReLU()): + super(ConvSubsampling, self).__init__() + self._subsampling = subsampling + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + + in_channels = 1 + layers = [] + if subsampling == 'vggnet': + self._padding = 0 + self._stride = 2 + self._kernel_size = 2 + self._ceil_mode = True + + for i in range(self._sampling_num): + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1 + ) + ) + layers.append(activation) + layers.append( + torch.nn.Conv2d( + in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1 + ) + ) + layers.append(activation) + layers.append( + torch.nn.MaxPool2d( + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._padding, + ceil_mode=self._ceil_mode, + ) + ) + in_channels = conv_channels + elif subsampling == 'striding': + self._padding = 1 + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + for i in range(self._sampling_num): + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + in_length = feat_in + for i in range(self._sampling_num): + out_length = calc_length( + length=int(in_length), + padding=self._padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + ) + in_length = out_length + + self.out = torch.nn.Linear(conv_channels * out_length, feat_out) + self.conv = torch.nn.Sequential(*layers) + + def forward(self, x, lengths): + x = x.unsqueeze(1) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + # TODO: improve the performance of length calculation + new_lengths = lengths + for i in range(self._sampling_num): + new_lengths = [ + calc_length( + length=int(length), + padding=self._padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + ) + for length in new_lengths + ] + + new_lengths = torch.IntTensor(new_lengths).to(lengths.device) + return x, new_lengths + + +def calc_length(length, padding, kernel_size, stride, ceil_mode): + """ Calculates the output length of a Tensor passed through a convolution or max pooling layer""" + if ceil_mode: + length = math.ceil((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1) + else: + length = math.floor((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1) + return length diff --git a/nemo/core/config/__init__.py b/nemo/core/config/__init__.py index a92557c44cec..9eb915cc4ac5 100644 --- a/nemo/core/config/__init__.py +++ b/nemo/core/config/__init__.py @@ -32,6 +32,7 @@ from nemo.core.config.schedulers import ( CosineAnnealingParams, InverseSquareRootAnnealingParams, + NoamAnnealingParams, PolynomialDecayAnnealingParams, PolynomialHoldDecayAnnealingParams, SchedulerParams, diff --git a/nemo/core/config/schedulers.py b/nemo/core/config/schedulers.py index 884808a38636..629a1fb9fcef 100644 --- a/nemo/core/config/schedulers.py +++ b/nemo/core/config/schedulers.py @@ -80,6 +80,16 @@ 
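As a quick illustration of the length bookkeeping in ConvSubsampling above, here is a minimal, self-contained sketch for the "striding" settings (kernel 3, stride 2, padding 1, floor mode); the 1000-frame input is an arbitrary illustrative value, not something taken from the patch:

import math

def calc_length(length, padding, kernel_size, stride, ceil_mode):
    # same output-length formula as in nemo/collections/asr/parts/subsampling.py
    val = (length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1
    return int(math.ceil(val) if ceil_mode else math.floor(val))

# subsampling_factor=4 means the striding conv stack is applied log2(4) = 2 times
length = 1000  # e.g. spectrogram frames
for _ in range(2):
    length = calc_length(length, padding=1, kernel_size=3, stride=2, ceil_mode=False)
print(length)  # 250: the encoder sees roughly 4x fewer time steps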
class CosineAnnealingParams(WarmupSchedulerParams): min_lr: float = 0.0 +@dataclass +class NoamAnnealingParams(WarmupSchedulerParams): + """ + Cosine Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + min_lr: float = 0.0 + + @dataclass class WarmupAnnealingParams(WarmupSchedulerParams): """ @@ -233,6 +243,7 @@ def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> Sched 'SquareRootAnnealingParams': SquareRootAnnealingParams, 'InverseSquareRootAnnealingParams': InverseSquareRootAnnealingParams, 'CosineAnnealingParams': CosineAnnealingParams, + 'NoamAnnealingParams': NoamAnnealingParams, 'WarmupAnnealingParams': WarmupAnnealingParams, 'PolynomialDecayAnnealingParams': PolynomialDecayAnnealingParams, 'PolynomialHoldDecayAnnealingParams': PolynomialHoldDecayAnnealingParams, diff --git a/nemo/core/optim/__init__.py b/nemo/core/optim/__init__.py index a194068db7a8..b421582c0034 100644 --- a/nemo/core/optim/__init__.py +++ b/nemo/core/optim/__init__.py @@ -15,6 +15,7 @@ from nemo.core.optim.lr_scheduler import ( CosineAnnealing, InverseSquareRootAnnealing, + NoamAnnealing, PolynomialDecayAnnealing, PolynomialHoldDecayAnnealing, SquareAnnealing, diff --git a/nemo/core/optim/lr_scheduler.py b/nemo/core/optim/lr_scheduler.py index 94b340cda482..0bcb7a574712 100644 --- a/nemo/core/optim/lr_scheduler.py +++ b/nemo/core/optim/lr_scheduler.py @@ -62,7 +62,7 @@ def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( - "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning ) step = self.last_epoch @@ -214,7 +214,7 @@ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): def _get_lr(self, step): new_lrs = [ - _squareroot_annealing(initial_lr=initial_lr, step=step, max_steps=self.max_steps, min_lr=self.min_lr,) + _squareroot_annealing(initial_lr=initial_lr, step=step, max_steps=self.max_steps, min_lr=self.min_lr) for initial_lr in self.base_lrs ] return new_lrs @@ -228,7 +228,7 @@ def _get_lr(self, step): for initial_lr in self.base_lrs: if initial_lr < self.min_lr: raise ValueError( - f"{self} received an initial learning rate that " f"was lower than the minimum learning rate." + f"{self} received an initial learning rate that was lower than the minimum learning rate." ) new_lrs = [ @@ -243,6 +243,57 @@ def _get_lr(self, step): return new_lrs +class NoamAnnealing(_LRScheduler): + def __init__( + self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + ): + self._normalize = d_model ** (-0.5) + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
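+        # The resulting schedule follows the Noam formula from "Attention Is All You Need":
+        #   lr(step) = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
+        # and, once step > warmup_steps, the value is clamped from below by min_lr (see _noam_annealing).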
+ self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = max(1, self.last_epoch) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate that was lower than the minimum learning rate." + ) + + new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs] + return new_lrs + + def _noam_annealing(self, initial_lr, step): + mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5))) + out_lr = initial_lr * mult + if step > self.warmup_steps: + out_lr = max(out_lr, self.min_lr) + return out_lr + + class WarmupAnnealing(WarmupPolicy): def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs): super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) @@ -601,6 +652,7 @@ def compute_max_steps(max_epochs, accumulate_grad_batches, num_workers, num_samp 'WarmupHoldPolicy': WarmupHoldPolicy, 'SquareAnnealing': SquareAnnealing, 'CosineAnnealing': CosineAnnealing, + 'NoamAnnealing': NoamAnnealing, 'WarmupAnnealing': WarmupAnnealing, 'InverseSquareRootAnnealing': InverseSquareRootAnnealing, 'SquareRootAnnealing': SquareRootAnnealing, diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 3b9cca9ad302..fa8bba3afd03 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -25,33 +25,33 @@ @pytest.fixture() def asr_model(test_data_dir): - preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} + preprocessor = {'_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor'} encoder = { - 'cls': 'nemo.collections.asr.modules.ConvASREncoder', - 'params': { - 'feat_in': 64, - 'activation': 'relu', - 'conv_mask': True, - 'jasper': [ - { - 'filters': 1024, - 'repeat': 1, - 'kernel': [1], - 'stride': [1], - 'dilation': [1], - 'dropout': 0.0, - 'residual': False, - 'separable': True, - 'se': True, - 'se_context_size': -1, - } - ], - }, + '_target_': 'nemo.collections.asr.modules.ConvASREncoder', + 'feat_in': 64, + 'activation': 'relu', + 'conv_mask': True, + 'jasper': [ + { + 'filters': 1024, + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + } + ], } decoder = { - 'cls': 'nemo.collections.asr.modules.ConvASRDecoder', - 'params': {'feat_in': 1024, 'num_classes': -1, 'vocabulary': None}, + '_target_': 'nemo.collections.asr.modules.ConvASRDecoder', + 'feat_in': 1024, + 'num_classes': -1, + 'vocabulary': None, } tokenizer = {'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"), 'type': 'wpe'} diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index 151b7e96fb69..7df1c8393e4d 100644 --- 
a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -23,64 +23,60 @@ def asr_model(): preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} encoder = { - 'cls': 'nemo.collections.asr.modules.ConvASREncoder', - 'params': { - 'feat_in': 64, - 'activation': 'relu', - 'conv_mask': True, - 'jasper': [ - { - 'filters': 1024, - 'repeat': 1, - 'kernel': [1], - 'stride': [1], - 'dilation': [1], - 'dropout': 0.0, - 'residual': False, - 'separable': True, - 'se': True, - 'se_context_size': -1, - } - ], - }, + '_target_': 'nemo.collections.asr.modules.ConvASREncoder', + 'feat_in': 64, + 'activation': 'relu', + 'conv_mask': True, + 'jasper': [ + { + 'filters': 1024, + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + } + ], } decoder = { - 'cls': 'nemo.collections.asr.modules.ConvASRDecoder', - 'params': { - 'feat_in': 1024, - 'num_classes': 28, - 'vocabulary': [ - ' ', - 'a', - 'b', - 'c', - 'd', - 'e', - 'f', - 'g', - 'h', - 'i', - 'j', - 'k', - 'l', - 'm', - 'n', - 'o', - 'p', - 'q', - 'r', - 's', - 't', - 'u', - 'v', - 'w', - 'x', - 'y', - 'z', - "'", - ], - }, + '_target_': 'nemo.collections.asr.modules.ConvASRDecoder', + 'feat_in': 1024, + 'num_classes': 28, + 'vocabulary': [ + ' ', + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', + 'u', + 'v', + 'w', + 'x', + 'y', + 'z', + "'", + ], } modelConfig = DictConfig( {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder)}
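To close the loop on the new Conformer building blocks, a minimal shape-check sketch wiring RelPositionalEncoding into RelPositionMultiHeadAttention (assuming a branch with this patch is installed; the batch/time/model sizes below are arbitrary):

import torch
from nemo.collections.asr.parts.multi_head_attention import (
    RelPositionalEncoding,
    RelPositionMultiHeadAttention,
)

batch, time, d_model, n_heads = 2, 50, 256, 4
x = torch.randn(batch, time, d_model)

pos_enc = RelPositionalEncoding(d_model=d_model, dropout_rate=0.1, xscale=d_model ** 0.5)
attn = RelPositionMultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=0.1)

x, pos_emb = pos_enc(x)                                 # pos_emb is returned separately, not added to x
mask = torch.ones(batch, time, time, dtype=torch.bool)  # no padding: every position may attend everywhere
out = attn(query=x, key=x, value=x, mask=mask, pos_emb=pos_emb)
print(out.shape)  # torch.Size([2, 50, 256])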