diff --git a/README.rst b/README.rst index 85db0f92699a..6b302f5619b4 100644 --- a/README.rst +++ b/README.rst @@ -32,12 +32,6 @@ Introduction ------------ -NeMo is a toolkit for creating `Conversational AI `_ applications. - -`NeMo product page. `_ - -`Introductory video. `_ - The toolkit comes with extendable collections of pre-built modules and ready-to-use models for: * `Automatic Speech Recognition (ASR) `_ diff --git a/docs/source/asr/api.rst b/docs/source/asr/api.rst index 7d203a99aec9..174e0530d978 100644 --- a/docs/source/asr/api.rst +++ b/docs/source/asr/api.rst @@ -62,7 +62,7 @@ Modules Parts ----- -.. autoclass:: nemo.collections.asr.parts.jasper.JasperBlock +.. autoclass:: nemo.collections.asr.parts.submodules.jasper.JasperBlock :show-inheritance: :members: @@ -70,11 +70,11 @@ Parts Mixins ------ -.. autoclass:: nemo.collections.asr.parts.mixins.ASRBPEMixin +.. autoclass:: nemo.collections.asr.parts.mixins.mixins.ASRBPEMixin :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.mixins.ASRModuleMixin +.. autoclass:: nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin :show-inheritance: :members: @@ -129,39 +129,39 @@ Audio Augmentors :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.SpeedPerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.SpeedPerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.TimeStretchPerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.TimeStretchPerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.GainPerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.GainPerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.ImpulsePerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.ImpulsePerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.ShiftPerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.ShiftPerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.NoisePerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.NoisePerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.WhiteNoisePerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.WhiteNoisePerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.RirAndNoisePerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.RirAndNoisePerturbation :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.perturb.TranscodePerturbation +.. autoclass:: nemo.collections.asr.parts.preprocessing.perturb.TranscodePerturbation :show-inheritance: :members: @@ -179,25 +179,25 @@ RNNT Decoding :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.rnnt_greedy_decoding.GreedyRNNTInfer +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyRNNTInfer :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.rnnt_greedy_decoding.GreedyBatchedRNNTInfer +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedRNNTInfer :show-inheritance: :members: -.. autoclass:: nemo.collections.asr.parts.rnnt_beam_decoding.BeamRNNTInfer +.. 
autoclass:: nemo.collections.asr.parts.submodules.rnnt_beam_decoding.BeamRNNTInfer :show-inheritance: :members: Hypotheses ~~~~~~~~~~ -.. autoclass:: nemo.collections.asr.parts.rnnt_utils.Hypothesis +.. autoclass:: nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis :show-inheritance: :no-members: -.. autoclass:: nemo.collections.asr.parts.rnnt_utils.NBestHypotheses +.. autoclass:: nemo.collections.asr.parts.utils.rnnt_utils.NBestHypotheses :show-inheritance: :no-members: diff --git a/docs/source/asr/configs.rst b/docs/source/asr/configs.rst index 5e293db7234f..b606b0275043 100644 --- a/docs/source/asr/configs.rst +++ b/docs/source/asr/configs.rst @@ -342,7 +342,7 @@ configuration is a shortform notation for Citrinet-21x5xC, such that ``B = 21`` not be changed. To use Citrinet instead of QuartzNet, refer to the ``citrinet_512.yaml`` configuration found inside the ``examples/asr/conf/citrinet`` -directory. Citrinet is primarily comprised of the same :class:`~nemo.collections.asr.parts.jasper.JasperBlock` as ``Jasper`` or +directory. Citrinet is primarily comprised of the same :class:`~nemo.collections.asr.parts.submodules.jasper.JasperBlock` as ``Jasper`` or ``QuartzNet``. While the configs for Citrinet and QuartzNet are similar, we note the additional flags used for Citrinet below. Refer to the @@ -442,7 +442,7 @@ changed slightly as Citrinet utilizes sub-word tokenization. .. note:: The following information is relevant to any of the above models that implements its encoder as an :class:`~nemo.collections.asr.modules.conv_asr.ConvASREncoder`, and utilizes the ``SqueezeExcite`` mechanism. -The ``SqueezeExcite`` block within a :class:`~nemo.collections.asr.modules.conv_asr.ConvASREncoder` network can be modified to utilize a different context window after the model has been instantiated (even after the model has been trained) so as to evaluate the model with limited context. This can be achieved using the :meth:`~nemo.collections.asr.parts.mixins.ASRModuleMixin.change_conv_asr_se_context_window` +The ``SqueezeExcite`` block within a :class:`~nemo.collections.asr.modules.conv_asr.ConvASREncoder` network can be modified to utilize a different context window after the model has been instantiated (even after the model has been trained) so as to evaluate the model with limited context. This can be achieved using the :meth:`~nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin.change_conv_asr_se_context_window` .. code-block:: python @@ -473,3 +473,56 @@ specify the tokenizer if you want to use sub-word encoding instead of character- The encoder section includes the details about the Conformer-CTC encoder architecture. You may find more information in the config files and also :doc:`nemo.collections.asr.modules.ConformerEncoder<./api.html#nemo.collections.asr.modules.ConformerEncoder>`. + + +Fine-tuning Configurations +-------------------------- + +All ASR scripts support easy fine-tuning by partially or fully loading the pretrained weights from a checkpoint into the currently instantiated model. Pre-trained weights can be provided in multiple ways (a sketch of the common calling pattern follows this list): + +1) Providing a path to a NeMo model (via ``init_from_nemo_model``) +2) Providing the name of a pretrained NeMo model (which will be downloaded from the cloud) (via ``init_from_pretrained_model``) +3) Providing a path to a PyTorch Lightning checkpoint file (via ``init_from_ptl_ckpt``)
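All three keys are consumed by the model's ``maybe_init_from_pretrained_checkpoint`` method, which the example scripts in this diff call immediately after constructing the model. A minimal sketch of that calling pattern (``cfg`` is the Hydra config handed to the script's ``main``; the restoration logic itself lives in NeMo core and is not shown here):

```python
import pytorch_lightning as pl
from nemo.collections.asr.models import EncDecCTCModel

trainer = pl.Trainer(**cfg.trainer)
asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)

# Loads weights from init_from_nemo_model / init_from_pretrained_model /
# init_from_ptl_ckpt if one of them is present in cfg; otherwise a no-op.
asr_model.maybe_init_from_pretrained_checkpoint(cfg)

trainer.fit(asr_model)
```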
+ +Fine-tuning via a NeMo model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: sh + + python examples/asr/script_to_.py \ + --config-path= \ + --config-name= \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + trainer.gpus=-1 \ + trainer.max_epochs=50 \ + +init_from_nemo_model="" + + +Fine-tuning via a NeMo pretrained model name +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: sh + + python examples/asr/script_to_.py \ + --config-path= \ + --config-name= \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + trainer.gpus=-1 \ + trainer.max_epochs=50 \ + +init_from_pretrained_model="" + +Fine-tuning via a PyTorch Lightning checkpoint +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: sh + + python examples/asr/script_to_.py \ + --config-path= \ + --config-name= \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + trainer.gpus=-1 \ + trainer.max_epochs=50 \ + +init_from_ptl_ckpt="" \ No newline at end of file diff --git a/docs/source/asr/speaker_diarization/api.rst b/docs/source/asr/speaker_diarization/api.rst index 64ae1180c012..a820e1dc7bc0 100644 --- a/docs/source/asr/speaker_diarization/api.rst +++ b/docs/source/asr/speaker_diarization/api.rst @@ -12,6 +12,6 @@ Model Classes Mixins ------ -.. autoclass:: nemo.collections.asr.parts.mixins.DiarizationMixin +.. autoclass:: nemo.collections.asr.parts.mixins.mixins.DiarizationMixin :show-inheritance: :members: diff --git a/docs/source/asr/speaker_recognition/configs.rst b/docs/source/asr/speaker_recognition/configs.rst index b9716922f6b7..7b84f474e2db 100644 --- a/docs/source/asr/speaker_recognition/configs.rst +++ b/docs/source/asr/speaker_recognition/configs.rst @@ -80,7 +80,7 @@ minimum and maximum SNR specified with min_snr and max_snr respectively. This se max_snr_db: 15 -See the :class:`nemo.collections.asr.parts.perturb.AudioAugmentor` API section for more details. +See the :class:`nemo.collections.asr.parts.preprocessing.perturb.AudioAugmentor` API section for more details.
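For reference, the same noise augmentation can also be constructed in code through ``process_augmentations``, which builds an ``AudioAugmentor`` from a config dictionary. A hedged sketch (the manifest path is a placeholder, and the keys mirror the YAML snippet above):

```python
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations

# Each entry maps a perturbation name to its kwargs plus an application probability.
augmentor = process_augmentations({
    'noise': {
        'prob': 0.5,                             # apply to roughly half the samples
        'manifest_path': 'noise_manifest.json',  # placeholder path
        'min_snr_db': 0,
        'max_snr_db': 15,
    }
})
```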
Model Architecture Configurations diff --git a/docs/source/conf.py b/docs/source/conf.py index c14ccb6408eb..7dfba6e3e4dc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -52,6 +52,7 @@ 'nemo_text_processing.inverse_text_normalization', # Not installed automatically 'nemo_text_processing.text_normalization', # Not installed automatically 'attr', # attrdict in requirements, attr in import + 'torchmetrics', # inherited from PTL ] _skipped_autodoc_mock_imports = ['wrapt', 'numpy'] diff --git a/examples/asr/speech_to_label.py b/examples/asr/speech_to_label.py index 36128029f90c..8c71a66e137c 100644 --- a/examples/asr/speech_to_label.py +++ b/examples/asr/speech_to_label.py @@ -102,20 +102,37 @@ +trainer.precision=16 \ +trainer.amp_level=O1 # needed if using PyTorch < 1.6 +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +# Pretrained Models + +For documentation on existing pretrained models, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speech_classification/results.html# + """ import pytorch_lightning as pl +from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecClassificationModel from nemo.core.config import hydra_runner +from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @hydra_runner(config_path="conf", config_name="matchboxnet_3x1x64_v1") def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecClassificationModel(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: diff --git a/examples/asr/speech_to_text.py b/examples/asr/speech_to_text.py index 2faa510a60cb..ced4cb8e3af7 100644 --- a/examples/asr/speech_to_text.py +++ b/examples/asr/speech_to_text.py @@ -12,21 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl -from omegaconf import OmegaConf - -from nemo.collections.asr.models import EncDecCTCModel -from nemo.core.config import hydra_runner -from nemo.utils import logging -from nemo.utils.exp_manager import exp_manager - - """ +# Training the model + Basic run (on CPU for 50 epochs): python examples/asr/speech_to_text.py \ - model.train_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_train.json" \ - model.validation_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_val.json" \ - hydra.run.dir="." \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ trainer.gpus=0 \ trainer.max_epochs=50 @@ -41,19 +34,19 @@ Override some args of optimizer: python speech_to_text.py \ + # (Optional: --config-path= --config-name=) \ model.train_ds.manifest_filepath="./an4/train_manifest.json" \ model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ - hydra.run.dir="." 
\ trainer.gpus=2 \ trainer.max_epochs=2 \ model.optim.args.betas=[0.8,0.5] \ model.optim.args.weight_decay=0.0001 -Overide optimizer entirely +Override optimizer entirely python speech_to_text.py \ + # (Optional: --config-path= --config-name=) \ model.train_ds.manifest_filepath="./an4/train_manifest.json" \ model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ - hydra.run.dir="." \ trainer.gpus=2 \ trainer.max_epochs=2 \ model.optim.name=adamw \ @@ -62,16 +55,38 @@ +model.optim.args.betas=[0.8,0.5]\ +model.optim.args.weight_decay=0.0005 +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +# Pretrained Models + +For documentation on existing pretrained models, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html + """ +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecCTCModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + @hydra_runner(config_path="conf", config_name="config") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: diff --git a/examples/asr/speech_to_text_bpe.py b/examples/asr/speech_to_text_bpe.py index 244aac6377df..ac1251fa115d 100644 --- a/examples/asr/speech_to_text_bpe.py +++ b/examples/asr/speech_to_text_bpe.py @@ -50,7 +50,19 @@ exp_manager.wandb_logger_kwargs.name="" \ exp_manager.wandb_logger_kwargs.project="" ``` + +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +# Pretrained Models + +For documentation on existing pretrained models, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html + """ + import pytorch_lightning as pl from omegaconf import OmegaConf @@ -63,12 +75,14 @@ @hydra_runner(config_path="experimental/configs/", config_name="config_bpe") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') - print(OmegaConf.to_yaml(cfg)) + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) - asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: diff --git a/examples/asr/speech_to_text_rnnt.py b/examples/asr/speech_to_text_rnnt.py index 5ed0ff7b3361..39d3be9ca74e 100644 --- a/examples/asr/speech_to_text_rnnt.py +++ b/examples/asr/speech_to_text_rnnt.py @@ -12,32 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytorch_lightning as pl - -from nemo.collections.asr.models import EncDecRNNTModel -from nemo.core.config import hydra_runner -from nemo.utils.exp_manager import exp_manager - - """ -# Preparing the Tokenizer for the dataset -Use the `process_asr_text_tokenizer.py` script under /scripts/tokenizers/ in order to prepare the tokenizer. - -```sh -python /scripts/tokenizers/process_asr_text_tokenizer.py \ - --manifest= \ - --data_root="" \ - --vocab_size= \ - --tokenizer=<"bpe" or "wpe"> \ - --log -``` - # Training the model Basic run (on CPU for 50 epochs): python examples/asr/speech_to_text_rnnt.py \ - model.train_ds.manifest_filepath="" \ - model.validation_ds.manifest_filepath="" \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ trainer.gpus=0 \ trainer.max_epochs=50 @@ -56,12 +38,11 @@ --config-name="config_rnnt" \ model.train_ds.manifest_filepath="./an4/train_manifest.json" \ model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ - hydra.run.dir="." \ trainer.gpus=2 \ trainer.precision=16 \ trainer.max_epochs=2 \ - model.optim.args.params.betas=[0.8,0.5] \ - model.optim.args.params.weight_decay=0.0001 + model.optim.betas=[0.8,0.5] \ + model.optim.weight_decay=0.0001 Override optimizer entirely python speech_to_text_rnnt.py \ @@ -69,7 +50,6 @@ --config-name="config_rnnt" \ model.train_ds.manifest_filepath="./an4/train_manifest.json" \ model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ - hydra.run.dir="." \ trainer.gpus=2 \ trainer.precision=16 \ trainer.max_epochs=2 \ @@ -79,15 +59,33 @@ +model.optim.args.betas=[0.8,0.5]\ +model.optim.args.weight_decay=0.0005 +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + """ +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecRNNTModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + @hydra_runner(config_path="experimental/contextnet_rnnt", config_name="config_rnnt") def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: diff --git a/examples/asr/speech_to_text_rnnt_bpe.py b/examples/asr/speech_to_text_rnnt_bpe.py index 30332fafeb03..91e41deaf05a 100644 --- a/examples/asr/speech_to_text_rnnt_bpe.py +++ b/examples/asr/speech_to_text_rnnt_bpe.py @@ -12,69 +12,72 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +""" +# Preparing the Tokenizer for the dataset +Use the `process_asr_text_tokenizer.py` script under /scripts/tokenizers/ in order to prepare the tokenizer. 
+ +```sh +python /scripts/tokenizers/process_asr_text_tokenizer.py \ + --manifest= + OR + --data_file= \ + --data_root="" \ + --vocab_size= \ + --tokenizer=<"spe" or "wpe"> \ + --no_lower_case \ + --spe_type=<"unigram", "bpe", "char" or "word"> \ + --spe_character_coverage=1.0 \ + --log +``` + +# Training the model +```sh +python speech_to_text_rnnt_bpe.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + model.tokenizer.dir= \ + model.tokenizer.type= \ + trainer.gpus=-1 \ + trainer.accelerator="ddp" \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` -from nemo.collections.asr.models import EncDecRNNTBPEModel -from nemo.core.config import hydra_runner -from nemo.utils.exp_manager import exp_manager +# Fine-tune a model +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations """ -Basic run (on CPU for 50 epochs): - python examples/asr/speech_to_text_rnnt_bpe.py \ - model.train_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_train.json" \ - model.validation_ds.manifest_filepath="/Users/okuchaiev/Data/an4_dataset/an4_val.json" \ - hydra.run.dir="." \ - trainer.gpus=0 \ - trainer.max_epochs=50 - - -Add PyTorch Lightning Trainer arguments from CLI: - python speech_to_text_rnnt_bpe.py \ - ... \ - +trainer.fast_dev_run=true - -Hydra logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/.hydra)" -PTL logs will be found in "$(./outputs/$(date +"%y-%m-%d")/$(date +"%H-%M-%S")/lightning_logs)" - -Override some args of optimizer: - python speech_to_text_rnnt_bpe.py \ - --config-path="experimental/contextnet_rnnt" \ - --config-name="config_rnnt_bpe" \ - model.train_ds.manifest_filepath="./an4/train_manifest.json" \ - model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ - hydra.run.dir="." \ - trainer.gpus=2 \ - trainer.precision=16 \ - trainer.max_epochs=2 \ - model.optim.args.params.betas=[0.8,0.5] \ - model.optim.args.params.weight_decay=0.0001 - -Overide optimizer entirely - python speech_to_text_rnnt_bpe.py \ - --config-path="experimental/contextnet_rnnt" \ - --config-name="config_rnnt_bpe" \ - model.train_ds.manifest_filepath="./an4/train_manifest.json" \ - model.validation_ds.manifest_filepath="./an4/test_manifest.json" \ - hydra.run.dir="." 
\ - trainer.gpus=2 \ - trainer.precision=16 \ - trainer.max_epochs=2 \ - model.optim.name=adamw \ - model.optim.lr=0.001 \ - ~model.optim.args \ - +model.optim.args.betas=[0.8,0.5]\ - +model.optim.args.weight_decay=0.0005 -""" +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecRNNTBPEModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager @hydra_runner(config_path="experimental/contextnet_rnnt", config_name="config_rnnt_bpe") def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecRNNTBPEModel(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: diff --git a/examples/asr/vad_infer.py b/examples/asr/vad_infer.py index 00fca470c26b..ba036029f2fe 100644 --- a/examples/asr/vad_infer.py +++ b/examples/asr/vad_infer.py @@ -38,7 +38,7 @@ import torch from nemo.collections.asr.models import EncDecClassificationModel -from nemo.collections.asr.parts.vad_utils import get_vad_stream_status, prepare_manifest +from nemo.collections.asr.parts.utils.vad_utils import get_vad_stream_status, prepare_manifest from nemo.utils import logging try: diff --git a/examples/tts/conf/fastpitch_hifigan_e2e.yaml b/examples/tts/conf/fastpitch_hifigan_e2e.yaml index b4e257d353d3..c0570ee03121 100644 --- a/examples/tts/conf/fastpitch_hifigan_e2e.yaml +++ b/examples/tts/conf/fastpitch_hifigan_e2e.yaml @@ -42,7 +42,7 @@ model: num_workers: 8 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${model.n_mel_channels} frame_splicing: 1 diff --git a/examples/tts/conf/fastspeech2.yaml b/examples/tts/conf/fastspeech2.yaml index d3b2284a2220..ed1dea2373f8 100644 --- a/examples/tts/conf/fastspeech2.yaml +++ b/examples/tts/conf/fastspeech2.yaml @@ -47,7 +47,7 @@ model: num_workers: 4 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${n_mels} frame_splicing: 1 diff --git a/examples/tts/conf/fastspeech2_hifigan_e2e.yaml b/examples/tts/conf/fastspeech2_hifigan_e2e.yaml index 3cf56a201b0c..95acf3d7693a 100644 --- a/examples/tts/conf/fastspeech2_hifigan_e2e.yaml +++ b/examples/tts/conf/fastspeech2_hifigan_e2e.yaml @@ -52,7 +52,7 @@ model: num_workers: 4 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${n_mels} frame_splicing: 1 diff --git a/examples/tts/conf/hifigan/hifigan.yaml b/examples/tts/conf/hifigan/hifigan.yaml index d15177ae50a1..9412ced66932 100644 --- a/examples/tts/conf/hifigan/hifigan.yaml +++ b/examples/tts/conf/hifigan/hifigan.yaml @@ -9,7 +9,7 @@ defaults: model: preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 frame_splicing: 1 nfilt: 80 diff --git a/examples/tts/conf/melgan.yaml b/examples/tts/conf/melgan.yaml index 
ea31c7d43278..92599889d231 100644 --- a/examples/tts/conf/melgan.yaml +++ b/examples/tts/conf/melgan.yaml @@ -38,7 +38,7 @@ model: num_workers: 4 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 frame_splicing: 1 nfilt: ${n_mels} diff --git a/examples/tts/conf/squeezewave.yaml b/examples/tts/conf/squeezewave.yaml index cbd5f972a4aa..9fc620f4015e 100644 --- a/examples/tts/conf/squeezewave.yaml +++ b/examples/tts/conf/squeezewave.yaml @@ -42,7 +42,7 @@ model: num_workers: 4 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${n_mels} frame_splicing: 1 diff --git a/examples/tts/conf/tacotron2.yaml b/examples/tts/conf/tacotron2.yaml index 0bf9406ed7b5..11b37cf4d114 100644 --- a/examples/tts/conf/tacotron2.yaml +++ b/examples/tts/conf/tacotron2.yaml @@ -55,7 +55,7 @@ model: num_workers: 8 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${n_mels} frame_splicing: 1 diff --git a/examples/tts/conf/uniglow.yaml b/examples/tts/conf/uniglow.yaml index ae9af2ad6951..b74584d799f3 100644 --- a/examples/tts/conf/uniglow.yaml +++ b/examples/tts/conf/uniglow.yaml @@ -39,7 +39,7 @@ model: num_workers: 4 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${n_mels} frame_splicing: 1 diff --git a/examples/tts/conf/waveglow.yaml b/examples/tts/conf/waveglow.yaml index 407bb071b1cd..83f97a047426 100644 --- a/examples/tts/conf/waveglow.yaml +++ b/examples/tts/conf/waveglow.yaml @@ -38,7 +38,7 @@ model: num_workers: 4 preprocessor: - _target_: nemo.collections.asr.parts.features.FilterbankFeatures + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures dither: 0.0 nfilt: ${n_mels} frame_splicing: 1 diff --git a/examples/tts/test_tts_infer.py b/examples/tts/test_tts_infer.py index a0a39797e230..ce98fef1f3a3 100644 --- a/examples/tts/test_tts_infer.py +++ b/examples/tts/test_tts_infer.py @@ -26,7 +26,7 @@ from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import EncDecCTCModel -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder from nemo.utils import logging diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index a24893b999e6..6c81895c869a 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -20,7 +20,7 @@ import torch import webdataset as wd -from nemo.collections.asr.parts import collections +from nemo.collections.common.parts.preprocessing import collections from nemo.core.classes import Dataset, IterableDataset from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType, RegressionValuesType from nemo.utils import logging diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index f6b25db930f0..935fd456d0b2 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -25,8 +25,8 @@ 
from torch.nn import functional as F from nemo.collections.asr.data import vocabs -from nemo.collections.asr.parts import collections, parsers -from nemo.collections.asr.parts.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.common.parts.preprocessing import collections, parsers from nemo.core.classes import Dataset, IterableDataset from nemo.core.neural_types import * from nemo.core.neural_types.elements import ProbsType diff --git a/nemo/collections/asr/data/audio_to_text_dali.py b/nemo/collections/asr/data/audio_to_text_dali.py index e95c88156d1c..5a6e31157681 100644 --- a/nemo/collections/asr/data/audio_to_text_dali.py +++ b/nemo/collections/asr/data/audio_to_text_dali.py @@ -20,7 +20,7 @@ import torch from omegaconf import DictConfig -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.utils.decorators import experimental try: diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index 80c218932039..dc1b3fdae883 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -15,9 +15,50 @@ from typing import Optional import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, open_dict from nemo.collections.asr.data import audio_to_text, audio_to_text_dali +from nemo.utils import logging + + +def inject_dataloader_value_from_model_config(model_cfg: dict, dataloader_cfg: dict, key: str): + """ + Extracts the value of `key` provided at the top level of the model config, and propagates it + to the dataloader config. + + Args: + model_cfg: A DictConfig representing the model's config. + dataloader_cfg: A DictConfig representing the individual data loader. + key: A str value representing a key in the model_cfg whose value will be propagated to the + dataloader config. + """ + if key not in model_cfg: + logging.info( + f"Model level config does not contain `{key}`, please explicitly provide `{key}` to the dataloaders." + ) + return + + # If key exists in the data loader config (either set explicitly or as a placeholder (via None)) + if key in dataloader_cfg: + # Dataloader `labels` is provided and is non-null + if dataloader_cfg[key] is not None and model_cfg[key] != dataloader_cfg[key]: + # Model level `labels` don't match Dataloader level `labels` + logging.warning( + f'`{key}` is explicitly provided to the data loader, and is different from ' + f'the `{key}` provided at the model level config.\n' + f'If this is incorrect, please set the dataloader\'s `{key}` to None.' + ) + + else: + # Dataloader `key` is None or values match + # Propagate from model level `key` (even if they match) + with open_dict(dataloader_cfg): + dataloader_cfg[key] = model_cfg[key] + + else: + # If the key does not exist in dataloader_cfg at all, inject it explicitly + with open_dict(dataloader_cfg): + dataloader_cfg[key] = model_cfg[key]
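A hedged sketch of how this helper behaves at the call sites added to ``ctc_models.py`` and ``rnnt_models.py`` later in this diff (the configs here are toy values):

```python
from omegaconf import OmegaConf

from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config

model_cfg = OmegaConf.create({'sample_rate': 16000})
dataloader_cfg = OmegaConf.create({'manifest_filepath': 'train.json', 'sample_rate': None})

# The dataloader's placeholder (None) is filled in from the model-level value;
# an explicit, conflicting dataloader value would instead trigger the warning above.
inject_dataloader_value_from_model_config(model_cfg, dataloader_cfg, key='sample_rate')
assert dataloader_cfg.sample_rate == 16000
```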
def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) -> audio_to_text.AudioToCharDataset: diff --git a/nemo/collections/asr/data/feature_to_label.py b/nemo/collections/asr/data/feature_to_label.py index 6fe5ed669827..dcba068d7a71 100644 --- a/nemo/collections/asr/data/feature_to_label.py +++ b/nemo/collections/asr/data/feature_to_label.py @@ -15,7 +15,7 @@ import torch -from nemo.collections.asr.parts import collections +from nemo.collections.common.parts.preprocessing import collections from nemo.core.classes import Dataset from nemo.core.neural_types import AcousticEncodedRepresentation, LabelsType, LengthsType, NeuralType from nemo.utils import logging diff --git a/nemo/collections/asr/data/vocabs.py b/nemo/collections/asr/data/vocabs.py index 252c727ed1e7..38beb20b3b49 100644 --- a/nemo/collections/asr/data/vocabs.py +++ b/nemo/collections/asr/data/vocabs.py @@ -21,18 +21,8 @@ from typing import List import nltk -from nltk.corpus import cmudict -from nemo.collections.asr.parts import parsers - -try: - nltk.data.find('taggers/averaged_perceptron_tagger.zip') -except LookupError: - nltk.download('averaged_perceptron_tagger', quiet=True) -try: - nltk.data.find('corpora/cmudict.zip') -except LookupError: - nltk.download('cmudict', quiet=True) +from nemo.collections.common.parts.preprocessing import parsers try: import g2p_en # noqa @@ -70,6 +60,16 @@ def __init__( text_preprocessing_func=_text_preprocessing, word_tokenize_func=_word_tokenize, ): + # Download NLTK datasets if this class is to be instantiated + try: + nltk.data.find('taggers/averaged_perceptron_tagger.zip') + except LookupError: + nltk.download('averaged_perceptron_tagger', quiet=True) + try: + nltk.data.find('corpora/cmudict.zip') + except LookupError: + nltk.download('cmudict', quiet=True) + self.homograph2features = _g2p.homograph2features self.g2p_dict = self._construct_grapheme2phoneme_dict(phoneme_dict_path) self.use_seq2seq_for_oov = use_seq2seq_for_oov @@ -81,6 +81,8 @@ def __init__( @staticmethod def _construct_grapheme2phoneme_dict(phoneme_dict_path=None, encoding='latin-1'): if phoneme_dict_path is None: + from nltk.corpus import cmudict + return cmudict.dict() _alt_re = re.compile(r'\([0-9]+\)') diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 9c86ddf201c5..17861a3f573d 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -20,9 +20,9 @@ import torch from torchmetrics import Metric -from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode -from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode -from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.utils import logging __all__ = ['RNNTDecoding', 'RNNTWER'] diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 
4064f681e63c..0f679592a109 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -13,16 +13,15 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional +from typing import List import editdistance import torch from torchmetrics import Metric from nemo.collections.asr.metrics.rnnt_wer import AbstractRNNTDecoding -from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode -from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode -from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 83858d704728..e3d62f4281ff 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -18,7 +18,7 @@ import torch from torchmetrics import Metric -from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging __all__ = ['word_error_rate', 'WER'] diff --git a/nemo/collections/asr/metrics/wer_bpe.py b/nemo/collections/asr/metrics/wer_bpe.py index 3406e3c75ddb..a0e7d3855621 100644 --- a/nemo/collections/asr/metrics/wer_bpe.py +++ b/nemo/collections/asr/metrics/wer_bpe.py @@ -18,7 +18,7 @@ import torch from torchmetrics import Metric -from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index a451c9f7bd2b..79e73c48dc71 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -20,7 +20,6 @@ from math import ceil from typing import Dict, List, Optional, Union -import onnx import torch from omegaconf import DictConfig, ListConfig, OmegaConf from pytorch_lightning import Trainer @@ -28,8 +27,8 @@ from nemo.collections.asr.data import audio_to_label_dataset from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel -from nemo.collections.asr.parts.features import WaveformFeaturizer -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.common.losses import CrossEntropyLoss, MSELoss from nemo.collections.common.metrics import TopKClassificationAccuracy from nemo.core.classes.common import PretrainedModelInfo, typecheck diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py index d8226f037da0..3eb59debfce9 100644 --- a/nemo/collections/asr/models/clustering_diarizer.py +++ b/nemo/collections/asr/models/clustering_diarizer.py @@ -29,9 +29,9 @@ from nemo.collections.asr.models.classification_models import EncDecClassificationModel from nemo.collections.asr.models.label_models import ExtractSpeakerEmbeddingsModel -from 
nemo.collections.asr.parts.mixins import DiarizationMixin -from nemo.collections.asr.parts.speaker_utils import audio_rttm_map, perform_diarization, write_rttm2manifest -from nemo.collections.asr.parts.vad_utils import ( +from nemo.collections.asr.parts.mixins.mixins import DiarizationMixin +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, perform_diarization, write_rttm2manifest +from nemo.collections.asr.parts.utils.vad_utils import ( generate_overlap_vad_seq, generate_vad_segment_table, get_vad_stream_status, diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index f6b3ede113ec..c5c279e172a9 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -24,7 +24,7 @@ from nemo.collections.asr.metrics.wer_bpe import WERBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel from nemo.collections.asr.parts.mixins import ASRBPEMixin -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging, model_utils @@ -66,6 +66,14 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: results.append(model) + model = PretrainedModelInfo( + pretrained_model_name="stt_es_citrinet_512", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_citrinet_512", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_citrinet_512/versions/1.0.0/files/stt_es_citrinet_512.nemo", + ) + + results.append(model) + model = PretrainedModelInfo( pretrained_model_name="stt_en_conformer_ctc_small", description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_small", diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 70977c9c44ce..0557c37434cb 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -29,7 +29,7 @@ from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRModuleMixin -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType from nemo.utils import logging @@ -330,6 +330,12 @@ def change_vocabulary(self, new_vocabulary: List[str]): self._cfg.decoder = new_decoder_config OmegaConf.set_struct(self._cfg.decoder, True) + ds_keys = ['train_ds', 'validation_ds', 'test_ds'] + for key in ds_keys: + if key in self.cfg: + with open_dict(self.cfg[key]): + self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary) + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") def _setup_dataloader_from_config(self, config: Optional[Dict]): @@ -338,6 +344,10 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): else: augmentor = None + # Automatically inject args from model config to dataloader config + audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') + 
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels') + shuffle = config['shuffle'] device = 'gpu' if torch.cuda.is_available() else 'cpu' if config.get('use_dali', False): diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index e54e08b50bc6..14febd97e58e 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -26,8 +26,8 @@ from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss from nemo.collections.asr.models.asr_model import ExportableEncDecModel -from nemo.collections.asr.parts.features import WaveformFeaturizer -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.common.losses import CrossEntropyLoss as CELoss from nemo.collections.common.metrics import TopKClassificationAccuracy from nemo.core.classes import ModelPT diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 27a2e62757e2..da69c9d02d27 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -25,7 +25,7 @@ from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.parts.mixins import ASRBPEMixin -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging, model_utils diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index e535faa3fa1f..446f88d8784c 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -30,7 +30,7 @@ from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.parts.mixins import ASRModuleMixin -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType from nemo.utils import logging @@ -340,6 +340,12 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg + ds_keys = ['train_ds', 'validation_ds', 'test_ds'] + for key in ds_keys: + if key in self.cfg: + with open_dict(self.cfg[key]): + self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary) + logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") def change_decoding_strategy(self, decoding_cfg: DictConfig): @@ -384,6 +390,10 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): else: augmentor = None + # Automatically inject args from model config to dataloader config + audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') + 
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels') + shuffle = config['shuffle'] device = 'gpu' if torch.cuda.is_available() else 'cpu' if config.get('use_dali', False): diff --git a/nemo/collections/asr/models/wav2vec/wav2vec_model.py b/nemo/collections/asr/models/wav2vec/wav2vec_model.py index cd1a0721e45d..dd33c2e01757 100644 --- a/nemo/collections/asr/models/wav2vec/wav2vec_model.py +++ b/nemo/collections/asr/models/wav2vec/wav2vec_model.py @@ -31,8 +31,8 @@ from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecEncoderModelConfig from nemo.collections.asr.modules.wav2vec_modules import GumbelVectorQuantizer, compute_mask_indices -from nemo.collections.asr.parts.perturb import process_augmentations -from nemo.collections.asr.parts.wav2vec import ConvFeatureEncoder, GradMultiply, Wav2VecTransformerEncoder +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.submodules.wav2vec import ConvFeatureEncoder, GradMultiply, Wav2VecTransformerEncoder from nemo.core import ModelPT from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LossType, MaskType, NeuralType diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index 32f343453d78..3d92c5473fa1 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -20,8 +20,8 @@ import torch from packaging import version -from nemo.collections.asr.parts.features import FilterbankFeatures -from nemo.collections.asr.parts.spectr_augment import SpecAugment, SpecCutout +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures +from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout from nemo.core.classes import NeuralModule, typecheck from nemo.core.neural_types import ( AudioSignal, @@ -425,13 +425,6 @@ class SpectrogramAugmentation(NeuralModule): Defaults to 25. 
""" - def save_to(self, save_path: str): - pass - - @classmethod - def restore_from(cls, restore_path: str): - pass - @property def input_types(self): """Returns definitions of module input types @@ -462,7 +455,7 @@ def __init__( self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,) # self.spec_cutout.to(self._device) else: - self.spec_cutout = lambda x: x + self.spec_cutout = lambda input_spec: input_spec if freq_masks + time_masks > 0: self.spec_augment = SpecAugment( @@ -474,12 +467,12 @@ def __init__( mask_value=mask_value, ) else: - self.spec_augment = lambda x: x + self.spec_augment = lambda input_spec: input_spec @typecheck() def forward(self, input_spec): - augmented_spec = self.spec_cutout(input_spec) - augmented_spec = self.spec_augment(augmented_spec) + augmented_spec = self.spec_cutout(input_spec=input_spec) + augmented_spec = self.spec_augment(input_spec=augmented_spec) return augmented_spec diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index c6bb1234475c..0f82e172bb37 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -18,9 +18,9 @@ import torch import torch.nn as nn -from nemo.collections.asr.parts.conformer_modules import ConformerLayer -from nemo.collections.asr.parts.multi_head_attention import PositionalEncoding, RelPositionalEncoding -from nemo.collections.asr.parts.subsampling import ConvSubsampling +from nemo.collections.asr.parts.submodules.conformer_modules import ConformerLayer +from nemo.collections.asr.parts.submodules.multi_head_attention import PositionalEncoding, RelPositionalEncoding +from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable from nemo.core.classes.module import NeuralModule diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 3022325101bd..7f2767019d12 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from omegaconf import MISSING, ListConfig, OmegaConf -from nemo.collections.asr.parts.jasper import ( +from nemo.collections.asr.parts.submodules.jasper import ( JasperBlock, MaskedConv1d, StatsPoolLayer, @@ -127,6 +127,11 @@ def __init__( jasper = OmegaConf.to_container(jasper) activation = jasper_activations[activation]() + + # If the activation can be executed in place, do so. 
+ if hasattr(activation, 'inplace'): + activation.inplace = True + feat_in = feat_in * frame_splicing self._feat_in = feat_in diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 55e224fcfaf4..abe7d6ca6c2a 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -31,7 +31,7 @@ import torch from nemo.collections.asr.modules import rnnt_abstract -from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.common.parts import rnn from nemo.core.classes import typecheck from nemo.core.neural_types import ( diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index b46add09ba74..3f7f9805afb3 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -16,7 +16,7 @@ import torch -from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.core import NeuralModule diff --git a/nemo/collections/asr/parts/features.py b/nemo/collections/asr/parts/features.py index abeb5621ec08..06665c413a45 100644 --- a/nemo/collections/asr/parts/features.py +++ b/nemo/collections/asr/parts/features.py @@ -32,401 +32,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # This file contains code artifacts adapted from https://github.com/ryanleary/patter -import math -import librosa -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from librosa.util import tiny -from torch.autograd import Variable -from torch_stft import STFT - -from nemo.collections.asr.parts.perturb import AudioAugmentor -from nemo.collections.asr.parts.segment import AudioSegment -from nemo.collections.common.parts.patch_utils import stft_patch -from nemo.utils import logging - -CONSTANT = 1e-5 - - -def normalize_batch(x, seq_len, normalize_type): - if normalize_type == "per_feature": - x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) - x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) - for i in range(x.shape[0]): - if x[i, :, : seq_len[i]].shape[1] == 1: - raise ValueError( - "normalize_batch with `per_feature` normalize_type received a tensor of length 1. 
This will result " - "in torch.std() returning nan" - ) - x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1) - x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1) - # make sure x_std is not zero - x_std += CONSTANT - return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2) - elif normalize_type == "all_features": - x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) - x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) - for i in range(x.shape[0]): - x_mean[i] = x[i, :, : seq_len[i].item()].mean() - x_std[i] = x[i, :, : seq_len[i].item()].std() - # make sure x_std is not zero - x_std += CONSTANT - return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1) - elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type: - x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device) - x_std = torch.tensor(normalize_type["fixed_std"], device=x.device) - return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2) - else: - return x - - -def splice_frames(x, frame_splicing): - """ Stacks frames together across feature dim - - input is batch_size, feature_dim, num_frames - output is batch_size, feature_dim*frame_splicing, num_frames - - """ - seq = [x] - for n in range(1, frame_splicing): - seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2)) - return torch.cat(seq, dim=1) - - -class WaveformFeaturizer(object): - def __init__(self, sample_rate=16000, int_values=False, augmentor=None): - self.augmentor = augmentor if augmentor is not None else AudioAugmentor() - self.sample_rate = sample_rate - self.int_values = int_values - - def max_augmentation_length(self, length): - return self.augmentor.max_augmentation_length(length) - - def process(self, file_path, offset=0, duration=0, trim=False, orig_sr=None): - audio = AudioSegment.from_file( - file_path, - target_sr=self.sample_rate, - int_values=self.int_values, - offset=offset, - duration=duration, - trim=trim, - orig_sr=orig_sr, - ) - return self.process_segment(audio) - - def process_segment(self, audio_segment): - self.augmentor.perturb(audio_segment) - return torch.tensor(audio_segment.samples, dtype=torch.float) - - @classmethod - def from_config(cls, input_config, perturbation_configs=None): - if perturbation_configs is not None: - aa = AudioAugmentor.from_config(perturbation_configs) - else: - aa = None - - sample_rate = input_config.get("sample_rate", 16000) - int_values = input_config.get("int_values", False) - - return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa) - - -class FeaturizerFactory(object): - def __init__(self): - pass - - @classmethod - def from_config(cls, input_cfg, perturbation_configs=None): - return WaveformFeaturizer.from_config(input_cfg, perturbation_configs=perturbation_configs) - - -# Create helper class to patch forward func for use with AMP -class STFTPatch(STFT): - def forward(self, input_data): - return super().transform(input_data)[0] - - -# Create helper class for STFT that yields num_frames = num_samples // hop_length -class STFTExactPad(STFTPatch): - """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - - def __init__(self, *params, **kw_params): - super().__init__(*params, **kw_params) - self.pad_amount = (self.filter_length - self.hop_length) // 2 - - def inverse(self, magnitude, phase): - recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1) - - inverse_transform = F.conv_transpose1d( - recombine_magnitude_phase, - 
Variable(self.inverse_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ) - - if self.window is not None: - window_sum = librosa.filters.window_sumsquare( - self.window, - magnitude.size(-1), - hop_length=self.hop_length, - win_length=self.win_length, - n_fft=self.filter_length, - dtype=np.float32, - ) - # remove modulation effects - approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0]) - window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False).to( - magnitude.device - ) - inverse_transform[..., approx_nonzero_indices] /= window_sum[approx_nonzero_indices] - - # scale by hop ratio - inverse_transform *= self.filter_length / self.hop_length - - inverse_transform = inverse_transform[..., self.pad_amount :] - inverse_transform = inverse_transform[..., : -self.pad_amount :] - inverse_transform = inverse_transform.squeeze(1) - - return inverse_transform - - -class FilterbankFeatures(nn.Module): - """Featurizer that converts wavs to Mel Spectrograms. - See AudioToMelSpectrogramPreprocessor for args. - """ - - def __init__( - self, - sample_rate=16000, - n_window_size=320, - n_window_stride=160, - window="hann", - normalize="per_feature", - n_fft=None, - preemph=0.97, - nfilt=64, - lowfreq=0, - highfreq=None, - log=True, - log_zero_guard_type="add", - log_zero_guard_value=2 ** -24, - dither=CONSTANT, - pad_to=16, - max_duration=16.7, - frame_splicing=1, - exact_pad=False, - stft_exact_pad=False, # TODO: Remove this in 1.1.0 - stft_conv=False, # TODO: Remove this in 1.1.0 - pad_value=0, - mag_power=2.0, - use_grads=False, - ): - super().__init__() - if stft_conv or stft_exact_pad: - logging.warning( - "Using torch_stft is deprecated and will be removed in 1.1.0. Please set stft_conv and stft_exact_pad " - "to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True " - "as needed." - ) - if (exact_pad or stft_exact_pad) and n_window_stride % 2 == 1: - raise NotImplementedError( - f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the " - "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size." - ) - self.log_zero_guard_value = log_zero_guard_value - if ( - n_window_size is None - or n_window_stride is None - or not isinstance(n_window_size, int) - or not isinstance(n_window_stride, int) - or n_window_size <= 0 - or n_window_stride <= 0 - ): - raise ValueError( - f"{self} got an invalid value for either n_window_size or " - f"n_window_stride. Both must be positive ints." 
- ) - logging.info(f"PADDING: {pad_to}") - - self.win_length = n_window_size - self.hop_length = n_window_stride - self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) - self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None - self.stft_exact_pad = stft_exact_pad - self.stft_conv = stft_conv - - if stft_conv: - logging.info("STFT using conv") - if stft_exact_pad: - logging.info("STFT using exact pad") - self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window) - else: - self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window) - else: - logging.info("STFT using torch") - if exact_pad: - logging.info("STFT using exact pad") - torch_windows = { - 'hann': torch.hann_window, - 'hamming': torch.hamming_window, - 'blackman': torch.blackman_window, - 'bartlett': torch.bartlett_window, - 'none': None, - } - window_fn = torch_windows.get(window, None) - window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None - self.register_buffer("window", window_tensor) - self.stft = lambda x: stft_patch( - x, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - center=False if exact_pad else True, - window=self.window.to(dtype=torch.float), - return_complex=False, - ) - - self.normalize = normalize - self.log = log - self.dither = dither - self.frame_splicing = frame_splicing - self.nfilt = nfilt - self.preemph = preemph - self.pad_to = pad_to - highfreq = highfreq or sample_rate / 2 - - filterbanks = torch.tensor( - librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float - ).unsqueeze(0) - self.register_buffer("fb", filterbanks) - - # Calculate maximum sequence length - max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float)) - max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 - self.max_length = max_length + max_pad - self.pad_value = pad_value - self.mag_power = mag_power - - # We want to avoid taking the log of zero - # There are two options: either adding or clamping to a small value - if log_zero_guard_type not in ["add", "clamp"]: - raise ValueError( - f"{self} received {log_zero_guard_type} for the " - f"log_zero_guard_type parameter. It must be either 'add' or " - f"'clamp'." - ) - - self.use_grads = use_grads - if not use_grads: - self.forward = torch.no_grad()(self.forward) - - # log_zero_guard_value is the the small we want to use, we support - # an actual number, or "tiny", or "eps" - self.log_zero_guard_type = log_zero_guard_type - logging.debug(f"sr: {sample_rate}") - logging.debug(f"n_fft: {self.n_fft}") - logging.debug(f"win_length: {self.win_length}") - logging.debug(f"hop_length: {self.hop_length}") - logging.debug(f"n_mels: {nfilt}") - logging.debug(f"fmin: {lowfreq}") - logging.debug(f"fmax: {highfreq}") - logging.debug(f"using grads: {use_grads}") - - def log_zero_guard_value_fn(self, x): - if isinstance(self.log_zero_guard_value, str): - if self.log_zero_guard_value == "tiny": - return torch.finfo(x.dtype).tiny - elif self.log_zero_guard_value == "eps": - return torch.finfo(x.dtype).eps - else: - raise ValueError( - f"{self} received {self.log_zero_guard_value} for the " - f"log_zero_guard_type parameter. 
It must be either a " - f"number, 'tiny', or 'eps'" - ) - else: - return self.log_zero_guard_value - - def get_seq_len(self, seq_len): - if isinstance(self.stft, STFT): - pad_amount = self.stft.pad_amount * 2 - else: - # Assuming that center is True is stft_pad_amount = 0 - pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2 - seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1 - return seq_len.to(dtype=torch.long) - - @property - def filter_banks(self): - return self.fb - - def forward(self, x, seq_len): - seq_len = self.get_seq_len(seq_len.float()) - - if self.stft_pad_amount is not None: - x = torch.nn.functional.pad( - x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect" - ).squeeze(1) - - # dither (only in training mode for eval determinism) - if self.training and self.dither > 0: - x += self.dither * torch.randn_like(x) - - # do preemphasis - if self.preemph is not None: - x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) - - # disable autocast to get full range of stft values - with torch.cuda.amp.autocast(enabled=False): - x = self.stft(x) - - # torch returns real, imag; so convert to magnitude - if not self.stft_conv: - # guard is needed for sqrt if grads are passed through - guard = 0 if not self.use_grads else CONSTANT - if x.dtype in [torch.cfloat, torch.cdouble]: - x = torch.view_as_real(x) - x = torch.sqrt(x.pow(2).sum(-1) + guard) - - # get power spectrum - if self.mag_power != 1.0: - x = x.pow(self.mag_power) - - # dot with filterbank energies - x = torch.matmul(self.fb.to(x.dtype), x) - - # log features if required - if self.log: - if self.log_zero_guard_type == "add": - x = torch.log(x + self.log_zero_guard_value_fn(x)) - elif self.log_zero_guard_type == "clamp": - x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) - else: - raise ValueError("log_zero_guard_type was not understood") - - # frame splicing if required - if self.frame_splicing > 1: - x = splice_frames(x, self.frame_splicing) - - # normalize if required - if self.normalize: - x = normalize_batch(x, seq_len, normalize_type=self.normalize) - - # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency) - max_len = x.size(-1) - mask = torch.arange(max_len).to(x.device) - mask = mask.expand(x.size(0), max_len) >= seq_len.unsqueeze(1) - x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value) - del mask - pad_to = self.pad_to - if pad_to == "max": - x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value) - elif pad_to > 0: - pad_amt = x.size(-1) % pad_to - if pad_amt != 0: - x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) - - return x, seq_len +""" +ALIAS FILE for backward compatibility +""" +from nemo.collections.asr.parts.preprocessing.features import * diff --git a/nemo/collections/asr/parts/mixins/__init__.py b/nemo/collections/asr/parts/mixins/__init__.py new file mode 100644 index 000000000000..42caf13f664d --- /dev/null +++ b/nemo/collections/asr/parts/mixins/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.parts.mixins.mixins import ASRBPEMixin, ASRModuleMixin, DiarizationMixin diff --git a/nemo/collections/asr/parts/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py similarity index 99% rename from nemo/collections/asr/parts/mixins.py rename to nemo/collections/asr/parts/mixins/mixins.py index c91a80a61e4b..35e689b04eda 100644 --- a/nemo/collections/asr/parts/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -18,7 +18,7 @@ from omegaconf import DictConfig, OmegaConf, open_dict -from nemo.collections.asr.parts import asr_module_utils +from nemo.collections.asr.parts.utils import asr_module_utils from nemo.collections.common import tokenizers from nemo.utils import logging diff --git a/nemo/collections/asr/parts/preprocessing/__init__.py b/nemo/collections/asr/parts/preprocessing/__init__.py new file mode 100644 index 000000000000..c467ec332730 --- /dev/null +++ b/nemo/collections/asr/parts/preprocessing/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader +from nemo.collections.asr.parts.preprocessing.features import ( + STFT, + FeaturizerFactory, + FilterbankFeatures, + STFTExactPad, + STFTPatch, + WaveformFeaturizer, +) +from nemo.collections.asr.parts.preprocessing.perturb import ( + AudioAugmentor, + AugmentationDataset, + GainPerturbation, + ImpulsePerturbation, + NoisePerturbation, + Perturbation, + RirAndNoisePerturbation, + ShiftPerturbation, + SpeedPerturbation, + TimeStretchPerturbation, + TranscodePerturbation, + WhiteNoisePerturbation, + perturbation_types, + process_augmentations, + register_perturbation, +) +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment diff --git a/nemo/collections/asr/parts/feature_loader.py b/nemo/collections/asr/parts/preprocessing/feature_loader.py similarity index 100% rename from nemo/collections/asr/parts/feature_loader.py rename to nemo/collections/asr/parts/preprocessing/feature_loader.py diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py new file mode 100644 index 000000000000..6cf0c07aa759 --- /dev/null +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -0,0 +1,432 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2018 Ryan Leary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# This file contains code artifacts adapted from https://github.com/ryanleary/patter +import math + +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from librosa.util import tiny +from torch.autograd import Variable +from torch_stft import STFT + +from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.parts.patch_utils import stft_patch +from nemo.utils import logging + +CONSTANT = 1e-5 + + +def normalize_batch(x, seq_len, normalize_type): + if normalize_type == "per_feature": + x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) + x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) + for i in range(x.shape[0]): + if x[i, :, : seq_len[i]].shape[1] == 1: + raise ValueError( + "normalize_batch with `per_feature` normalize_type received a tensor of length 1. 
This will result " + "in torch.std() returning nan" + ) + x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1) + x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1) + # make sure x_std is not zero + x_std += CONSTANT + return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2) + elif normalize_type == "all_features": + x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) + x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) + for i in range(x.shape[0]): + x_mean[i] = x[i, :, : seq_len[i].item()].mean() + x_std[i] = x[i, :, : seq_len[i].item()].std() + # make sure x_std is not zero + x_std += CONSTANT + return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1) + elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type: + x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device) + x_std = torch.tensor(normalize_type["fixed_std"], device=x.device) + return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2) + else: + return x + + +def splice_frames(x, frame_splicing): + """ Stacks frames together across feature dim + + input is batch_size, feature_dim, num_frames + output is batch_size, feature_dim*frame_splicing, num_frames + + """ + seq = [x] + for n in range(1, frame_splicing): + seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2)) + return torch.cat(seq, dim=1) + + +class WaveformFeaturizer(object): + def __init__(self, sample_rate=16000, int_values=False, augmentor=None): + self.augmentor = augmentor if augmentor is not None else AudioAugmentor() + self.sample_rate = sample_rate + self.int_values = int_values + + def max_augmentation_length(self, length): + return self.augmentor.max_augmentation_length(length) + + def process(self, file_path, offset=0, duration=0, trim=False, orig_sr=None): + audio = AudioSegment.from_file( + file_path, + target_sr=self.sample_rate, + int_values=self.int_values, + offset=offset, + duration=duration, + trim=trim, + orig_sr=orig_sr, + ) + return self.process_segment(audio) + + def process_segment(self, audio_segment): + self.augmentor.perturb(audio_segment) + return torch.tensor(audio_segment.samples, dtype=torch.float) + + @classmethod + def from_config(cls, input_config, perturbation_configs=None): + if perturbation_configs is not None: + aa = AudioAugmentor.from_config(perturbation_configs) + else: + aa = None + + sample_rate = input_config.get("sample_rate", 16000) + int_values = input_config.get("int_values", False) + + return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa) + + +class FeaturizerFactory(object): + def __init__(self): + pass + + @classmethod + def from_config(cls, input_cfg, perturbation_configs=None): + return WaveformFeaturizer.from_config(input_cfg, perturbation_configs=perturbation_configs) + + +# Create helper class to patch forward func for use with AMP +class STFTPatch(STFT): + def forward(self, input_data): + return super().transform(input_data)[0] + + +# Create helper class for STFT that yields num_frames = num_samples // hop_length +class STFTExactPad(STFTPatch): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, *params, **kw_params): + super().__init__(*params, **kw_params) + self.pad_amount = (self.filter_length - self.hop_length) // 2 + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + 
Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = librosa.filters.window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False).to( + magnitude.device + ) + inverse_transform[..., approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= self.filter_length / self.hop_length + + inverse_transform = inverse_transform[..., self.pad_amount :] + inverse_transform = inverse_transform[..., : -self.pad_amount :] + inverse_transform = inverse_transform.squeeze(1) + + return inverse_transform + + +class FilterbankFeatures(nn.Module): + """Featurizer that converts wavs to Mel Spectrograms. + See AudioToMelSpectrogramPreprocessor for args. + """ + + def __init__( + self, + sample_rate=16000, + n_window_size=320, + n_window_stride=160, + window="hann", + normalize="per_feature", + n_fft=None, + preemph=0.97, + nfilt=64, + lowfreq=0, + highfreq=None, + log=True, + log_zero_guard_type="add", + log_zero_guard_value=2 ** -24, + dither=CONSTANT, + pad_to=16, + max_duration=16.7, + frame_splicing=1, + exact_pad=False, + stft_exact_pad=False, # TODO: Remove this in 1.1.0 + stft_conv=False, # TODO: Remove this in 1.1.0 + pad_value=0, + mag_power=2.0, + use_grads=False, + ): + super().__init__() + if stft_conv or stft_exact_pad: + logging.warning( + "Using torch_stft is deprecated and will be removed in 1.1.0. Please set stft_conv and stft_exact_pad " + "to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True " + "as needed." + ) + if (exact_pad or stft_exact_pad) and n_window_stride % 2 == 1: + raise NotImplementedError( + f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0, then the " + "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size." + ) + self.log_zero_guard_value = log_zero_guard_value + if ( + n_window_size is None + or n_window_stride is None + or not isinstance(n_window_size, int) + or not isinstance(n_window_stride, int) + or n_window_size <= 0 + or n_window_stride <= 0 + ): + raise ValueError( + f"{self} got an invalid value for either n_window_size or " + f"n_window_stride. Both must be positive ints."
+ ) + logging.info(f"PADDING: {pad_to}") + + self.win_length = n_window_size + self.hop_length = n_window_stride + self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) + self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None + self.stft_exact_pad = stft_exact_pad + self.stft_conv = stft_conv + + if stft_conv: + logging.info("STFT using conv") + if stft_exact_pad: + logging.info("STFT using exact pad") + self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window) + else: + self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window) + else: + logging.info("STFT using torch") + if exact_pad: + logging.info("STFT using exact pad") + torch_windows = { + 'hann': torch.hann_window, + 'hamming': torch.hamming_window, + 'blackman': torch.blackman_window, + 'bartlett': torch.bartlett_window, + 'none': None, + } + window_fn = torch_windows.get(window, None) + window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None + self.register_buffer("window", window_tensor) + self.stft = lambda x: stft_patch( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=False if exact_pad else True, + window=self.window.to(dtype=torch.float), + return_complex=False, + ) + + self.normalize = normalize + self.log = log + self.dither = dither + self.frame_splicing = frame_splicing + self.nfilt = nfilt + self.preemph = preemph + self.pad_to = pad_to + highfreq = highfreq or sample_rate / 2 + + filterbanks = torch.tensor( + librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float + ).unsqueeze(0) + self.register_buffer("fb", filterbanks) + + # Calculate maximum sequence length + max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float)) + max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 + self.max_length = max_length + max_pad + self.pad_value = pad_value + self.mag_power = mag_power + + # We want to avoid taking the log of zero + # There are two options: either adding or clamping to a small value + if log_zero_guard_type not in ["add", "clamp"]: + raise ValueError( + f"{self} received {log_zero_guard_type} for the " + f"log_zero_guard_type parameter. It must be either 'add' or " + f"'clamp'." + ) + + self.use_grads = use_grads + if not use_grads: + self.forward = torch.no_grad()(self.forward) + + # log_zero_guard_value is the small value we want to use, we support + # an actual number, or "tiny", or "eps" + self.log_zero_guard_type = log_zero_guard_type + logging.debug(f"sr: {sample_rate}") + logging.debug(f"n_fft: {self.n_fft}") + logging.debug(f"win_length: {self.win_length}") + logging.debug(f"hop_length: {self.hop_length}") + logging.debug(f"n_mels: {nfilt}") + logging.debug(f"fmin: {lowfreq}") + logging.debug(f"fmax: {highfreq}") + logging.debug(f"using grads: {use_grads}") + + def log_zero_guard_value_fn(self, x): + if isinstance(self.log_zero_guard_value, str): + if self.log_zero_guard_value == "tiny": + return torch.finfo(x.dtype).tiny + elif self.log_zero_guard_value == "eps": + return torch.finfo(x.dtype).eps + else: + raise ValueError( + f"{self} received {self.log_zero_guard_value} for the " + f"log_zero_guard_value parameter.
It must be either a " + f"number, 'tiny', or 'eps'" + ) + else: + return self.log_zero_guard_value + + def get_seq_len(self, seq_len): + if isinstance(self.stft, STFT): + pad_amount = self.stft.pad_amount * 2 + else: + # Assuming that if center is True, then stft_pad_amount = 0 + pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2 + seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1 + return seq_len.to(dtype=torch.long) + + @property + def filter_banks(self): + return self.fb + + def forward(self, x, seq_len): + seq_len = self.get_seq_len(seq_len.float()) + + if self.stft_pad_amount is not None: + x = torch.nn.functional.pad( + x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect" + ).squeeze(1) + + # dither (only in training mode for eval determinism) + if self.training and self.dither > 0: + x += self.dither * torch.randn_like(x) + + # do preemphasis + if self.preemph is not None: + x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) + + # disable autocast to get full range of stft values + with torch.cuda.amp.autocast(enabled=False): + x = self.stft(x) + + # torch returns real, imag; so convert to magnitude + if not self.stft_conv: + # guard is needed for sqrt if grads are passed through + guard = 0 if not self.use_grads else CONSTANT + if x.dtype in [torch.cfloat, torch.cdouble]: + x = torch.view_as_real(x) + x = torch.sqrt(x.pow(2).sum(-1) + guard) + + # get power spectrum + if self.mag_power != 1.0: + x = x.pow(self.mag_power) + + # dot with filterbank energies + x = torch.matmul(self.fb.to(x.dtype), x) + + # log features if required + if self.log: + if self.log_zero_guard_type == "add": + x = torch.log(x + self.log_zero_guard_value_fn(x)) + elif self.log_zero_guard_type == "clamp": + x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) + else: + raise ValueError("log_zero_guard_type was not understood") + + # frame splicing if required + if self.frame_splicing > 1: + x = splice_frames(x, self.frame_splicing) + + # normalize if required + if self.normalize: + x = normalize_batch(x, seq_len, normalize_type=self.normalize) + + # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency) + max_len = x.size(-1) + mask = torch.arange(max_len).to(x.device) + mask = mask.expand(x.size(0), max_len) >= seq_len.unsqueeze(1) + x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value) + del mask + pad_to = self.pad_to + if pad_to == "max": + x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value) + elif pad_to > 0: + pad_amt = x.size(-1) % pad_to + if pad_amt != 0: + x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) + + return x, seq_len diff --git a/nemo/collections/asr/parts/perturb.py b/nemo/collections/asr/parts/preprocessing/perturb.py similarity index 99% rename from nemo/collections/asr/parts/perturb.py rename to nemo/collections/asr/parts/preprocessing/perturb.py index a2a297dfd08d..812efc3eabac 100644 --- a/nemo/collections/asr/parts/perturb.py +++ b/nemo/collections/asr/parts/preprocessing/perturb.py @@ -48,12 +48,12 @@ from scipy import signal from torch.utils.data import IterableDataset -from nemo.collections.asr.parts import collections, parsers -from nemo.collections.asr.parts.segment import AudioSegment +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.parts.preprocessing import
collections, parsers from nemo.utils import logging try: - from nemo.collections.asr.parts import numba_utils + from nemo.collections.asr.parts.utils import numba_utils HAVE_NUMBA = True except (ImportError, ModuleNotFoundError): diff --git a/nemo/collections/asr/parts/segment.py b/nemo/collections/asr/parts/preprocessing/segment.py similarity index 86% rename from nemo/collections/asr/parts/segment.py rename to nemo/collections/asr/parts/preprocessing/segment.py index 1cbfc2382f52..b573109471a0 100644 --- a/nemo/collections/asr/parts/segment.py +++ b/nemo/collections/asr/parts/preprocessing/segment.py @@ -42,6 +42,7 @@ from kaldiio.matio import read_kaldi from kaldiio.utils import open_like_kaldi from pydub import AudioSegment as Audio +from pydub.exceptions import CouldntDecodeError from nemo.utils import logging @@ -145,7 +146,8 @@ def from_file( samples = samples.transpose() except RuntimeError as e: logging.error( - f"Loading audio via SoundFile raised RuntimeError: `{e}`. NeMo will fallback to loading via pydub." + f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`. " + f"NeMo will fall back to loading via pydub." ) elif isinstance(audio_file, str) and audio_file.strip()[-1] == "|": f = open_like_kaldi(audio_file, "rb") @@ -159,16 +161,19 @@ def from_file( samples = np.array(samples, dtype=np.float) / abs_max_value if samples is None: - samples = Audio.from_file(audio_file) - sample_rate = samples.frame_rate - if offset > 0: - # pydub does things in milliseconds - seconds = offset * 1000 - samples = samples[int(seconds * sample_rate) :] - if duration > 0: - seconds = duration * 1000 - samples = samples[: int(seconds)] - samples = np.array(samples.get_array_of_samples()) + try: + samples = Audio.from_file(audio_file) + sample_rate = samples.frame_rate + if offset > 0: + # pydub does things in milliseconds + seconds = offset * 1000 + samples = samples[int(seconds * sample_rate) :] + if duration > 0: + seconds = duration * 1000 + samples = samples[: int(seconds)] + samples = np.array(samples.get_array_of_samples()) + except CouldntDecodeError as e: + logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{e}`.") return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr) @@ -179,15 +184,19 @@ def segment_from_file(cls, audio_file, target_sr=None, n_segments=0, trim=False, Note that audio_file can be either the file path, or a file-like object.
""" - with sf.SoundFile(audio_file, 'r') as f: - sample_rate = f.samplerate - if n_segments > 0 and len(f) > n_segments: - max_audio_start = len(f) - n_segments - audio_start = random.randint(0, max_audio_start) - f.seek(audio_start) - samples = f.read(n_segments, dtype='float32') - else: - samples = f.read(dtype='float32') + try: + with sf.SoundFile(audio_file, 'r') as f: + sample_rate = f.samplerate + if n_segments > 0 and len(f) > n_segments: + max_audio_start = len(f) - n_segments + audio_start = random.randint(0, max_audio_start) + f.seek(audio_start) + samples = f.read(n_segments, dtype='float32') + else: + samples = f.read(dtype='float32') + samples = samples.transpose() + except RuntimeError as e: + logging.error(f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`.") samples = samples.transpose() return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr) diff --git a/nemo/collections/asr/parts/submodules/__init__.py b/nemo/collections/asr/parts/submodules/__init__.py new file mode 100644 index 000000000000..bc443be41c4c --- /dev/null +++ b/nemo/collections/asr/parts/submodules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nemo/collections/asr/parts/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py similarity index 97% rename from nemo/collections/asr/parts/conformer_modules.py rename to nemo/collections/asr/parts/submodules/conformer_modules.py index a14d112bd6c1..b0d3f7440c0a 100644 --- a/nemo/collections/asr/parts/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -16,8 +16,11 @@ from torch import nn as nn from torch.nn import LayerNorm -from nemo.collections.asr.parts.activations import Swish -from nemo.collections.asr.parts.multi_head_attention import MultiHeadAttention, RelPositionMultiHeadAttention +from nemo.collections.asr.parts.submodules.multi_head_attention import ( + MultiHeadAttention, + RelPositionMultiHeadAttention, +) +from nemo.collections.asr.parts.utils.activations import Swish __all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerLayer'] diff --git a/nemo/collections/asr/parts/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py similarity index 99% rename from nemo/collections/asr/parts/jasper.py rename to nemo/collections/asr/parts/submodules/jasper.py index 73cc213dd8c3..5cd5e00b6c24 100644 --- a/nemo/collections/asr/parts/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -21,7 +21,7 @@ from torch.nn.init import _calculate_correct_fan from torch.nn.modules.utils import _single -from nemo.collections.asr.parts.activations import Swish +from nemo.collections.asr.parts.utils.activations import Swish from nemo.utils import logging try: @@ -34,7 +34,7 @@ except ImportError: PYTORCH_QUANTIZATION_AVAILABLE = False -jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish} +jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish, "silu": nn.SiLU} def tds_uniform_(tensor, mode='fan_in'): diff --git a/nemo/collections/asr/parts/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py similarity index 100% rename from nemo/collections/asr/parts/multi_head_attention.py rename to nemo/collections/asr/parts/submodules/multi_head_attention.py diff --git a/nemo/collections/asr/parts/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py similarity index 99% rename from nemo/collections/asr/parts/rnnt_beam_decoding.py rename to nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 249983913726..203aa750f6e9 100644 --- a/nemo/collections/asr/parts/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -35,8 +35,7 @@ from tqdm import tqdm from nemo.collections.asr.modules import rnnt_abstract -from nemo.collections.asr.parts import rnnt_utils -from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType diff --git a/nemo/collections/asr/parts/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py similarity index 99% rename from nemo/collections/asr/parts/rnnt_greedy_decoding.py rename to nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 31242fbbe5f3..52124b90339f 100644 --- a/nemo/collections/asr/parts/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ 
-32,7 +32,7 @@ import torch from nemo.collections.asr.modules import rnnt_abstract -from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType diff --git a/nemo/collections/asr/parts/spectr_augment.py b/nemo/collections/asr/parts/submodules/spectr_augment.py similarity index 72% rename from nemo/collections/asr/parts/spectr_augment.py rename to nemo/collections/asr/parts/submodules/spectr_augment.py index 300c67f6a6e8..ef9821e3a338 100644 --- a/nemo/collections/asr/parts/spectr_augment.py +++ b/nemo/collections/asr/parts/submodules/spectr_augment.py @@ -17,8 +17,11 @@ import torch import torch.nn as nn +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import NeuralType, SpectrogramType -class SpecAugment(nn.Module): + +class SpecAugment(nn.Module, Typing): """ Zeroes out(cuts) random continuous horisontal or vertical segments of the spectrogram as described in @@ -36,6 +39,18 @@ class SpecAugment(nn.Module): are cut adaptively. """ + @property + def input_types(self): + """Returns definitions of module input types + """ + return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} + + @property + def output_types(self): + """Returns definitions of module output types + """ + return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} + def __init__( self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, rng=None, mask_value=0.0, ): @@ -59,9 +74,10 @@ def __init__( self.adaptive_temporal_width = True + @typecheck() @torch.no_grad() - def forward(self, x): - sh = x.shape + def forward(self, input_spec): + sh = input_spec.shape if self.adaptive_temporal_width: time_width = max(1, int(sh[2] * self.time_width)) @@ -74,19 +90,19 @@ def forward(self, x): w = self._rng.randint(0, self.freq_width) - x[idx, x_left : x_left + w, :] = self.mask_value + input_spec[idx, x_left : x_left + w, :] = self.mask_value for i in range(self.time_masks): y_left = self._rng.randint(0, sh[2] - time_width) w = self._rng.randint(0, time_width) - x[idx, :, y_left : y_left + w] = self.mask_value + input_spec[idx, :, y_left : y_left + w] = self.mask_value - return x + return input_spec -class SpecCutout(nn.Module): +class SpecCutout(nn.Module, Typing): """ Zeroes out(cuts) random rectangles in the spectrogram as described in (https://arxiv.org/abs/1708.04552). 
@@ -97,6 +113,18 @@ class SpecCutout(nn.Module): rect_time - maximum size of cut rectangles along the time dimension """ + @property + def input_types(self): + """Returns definitions of module input types + """ + return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} + + @property + def output_types(self): + """Returns definitions of module output types + """ + return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} + def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None): super(SpecCutout, self).__init__() @@ -106,9 +134,10 @@ def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None): self.rect_time = rect_time self.rect_freq = rect_freq + @typecheck() @torch.no_grad() - def forward(self, x): - sh = x.shape + def forward(self, input_spec): + sh = input_spec.shape for idx in range(sh[0]): for i in range(self.rect_masks): @@ -118,6 +147,6 @@ def forward(self, x): w_x = self._rng.randint(0, self.rect_freq) w_y = self._rng.randint(0, self.rect_time) - x[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0 + input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0 - return x + return input_spec diff --git a/nemo/collections/asr/parts/subsampling.py b/nemo/collections/asr/parts/submodules/subsampling.py similarity index 100% rename from nemo/collections/asr/parts/subsampling.py rename to nemo/collections/asr/parts/submodules/subsampling.py diff --git a/nemo/collections/asr/parts/wav2vec.py b/nemo/collections/asr/parts/submodules/wav2vec.py similarity index 100% rename from nemo/collections/asr/parts/wav2vec.py rename to nemo/collections/asr/parts/submodules/wav2vec.py diff --git a/nemo/collections/asr/parts/utils/__init__.py b/nemo/collections/asr/parts/utils/__init__.py new file mode 100644 index 000000000000..bc443be41c4c --- /dev/null +++ b/nemo/collections/asr/parts/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/parts/activations.py b/nemo/collections/asr/parts/utils/activations.py similarity index 88% rename from nemo/collections/asr/parts/activations.py rename to nemo/collections/asr/parts/utils/activations.py index 627eef295717..55ebaca7e12f 100644 --- a/nemo/collections/asr/parts/activations.py +++ b/nemo/collections/asr/parts/utils/activations.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import torch.nn as nn __all__ = ['Swish'] -class Swish(nn.Module): +class Swish(nn.SiLU): """ Swish activation function introduced in 'https://arxiv.org/abs/1710.05941' + Mathematically identical to SiLU. See note in nn.SiLU for references. 
""" - - def forward(self, x): - return x * torch.sigmoid(x) diff --git a/nemo/collections/asr/parts/asr_module_utils.py b/nemo/collections/asr/parts/utils/asr_module_utils.py similarity index 98% rename from nemo/collections/asr/parts/asr_module_utils.py rename to nemo/collections/asr/parts/utils/asr_module_utils.py index 5ad6036c168b..e077d7948c0d 100644 --- a/nemo/collections/asr/parts/asr_module_utils.py +++ b/nemo/collections/asr/parts/utils/asr_module_utils.py @@ -17,7 +17,7 @@ from omegaconf import DictConfig, open_dict from nemo.collections.asr.modules import conv_asr -from nemo.collections.asr.parts import jasper +from nemo.collections.asr.parts.submodules import jasper from nemo.utils import logging diff --git a/nemo/collections/asr/parts/nmse_clustering.py b/nemo/collections/asr/parts/utils/nmse_clustering.py similarity index 100% rename from nemo/collections/asr/parts/nmse_clustering.py rename to nemo/collections/asr/parts/utils/nmse_clustering.py diff --git a/nemo/collections/asr/parts/numba_utils.py b/nemo/collections/asr/parts/utils/numba_utils.py similarity index 100% rename from nemo/collections/asr/parts/numba_utils.py rename to nemo/collections/asr/parts/utils/numba_utils.py diff --git a/nemo/collections/asr/parts/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py similarity index 100% rename from nemo/collections/asr/parts/rnnt_utils.py rename to nemo/collections/asr/parts/utils/rnnt_utils.py diff --git a/nemo/collections/asr/parts/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py similarity index 99% rename from nemo/collections/asr/parts/speaker_utils.py rename to nemo/collections/asr/parts/utils/speaker_utils.py index 169d20228977..09a32b9b89f0 100644 --- a/nemo/collections/asr/parts/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -23,7 +23,7 @@ from pyannote.metrics.diarization import DiarizationErrorRate from tqdm import tqdm -from nemo.collections.asr.parts.nmse_clustering import COSclustering +from nemo.collections.asr.parts.utils.nmse_clustering import COSclustering from nemo.utils import logging diff --git a/nemo/collections/asr/parts/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py similarity index 100% rename from nemo/collections/asr/parts/vad_utils.py rename to nemo/collections/asr/parts/utils/vad_utils.py diff --git a/nemo/collections/common/parts/preprocessing/__init__.py b/nemo/collections/common/parts/preprocessing/__init__.py new file mode 100644 index 000000000000..bc443be41c4c --- /dev/null +++ b/nemo/collections/common/parts/preprocessing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nemo/collections/asr/parts/cleaners.py b/nemo/collections/common/parts/preprocessing/cleaners.py similarity index 100% rename from nemo/collections/asr/parts/cleaners.py rename to nemo/collections/common/parts/preprocessing/cleaners.py diff --git a/nemo/collections/asr/parts/collections.py b/nemo/collections/common/parts/preprocessing/collections.py similarity index 99% rename from nemo/collections/asr/parts/collections.py rename to nemo/collections/common/parts/preprocessing/collections.py index 146645b43213..b3003b586bab 100644 --- a/nemo/collections/asr/parts/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -19,7 +19,7 @@ import pandas as pd -from nemo.collections.asr.parts import manifest, parsers +from nemo.collections.common.parts.preprocessing import manifest, parsers from nemo.utils import logging diff --git a/nemo/collections/asr/parts/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py similarity index 100% rename from nemo/collections/asr/parts/manifest.py rename to nemo/collections/common/parts/preprocessing/manifest.py diff --git a/nemo/collections/asr/parts/parsers.py b/nemo/collections/common/parts/preprocessing/parsers.py similarity index 98% rename from nemo/collections/asr/parts/parsers.py rename to nemo/collections/common/parts/preprocessing/parsers.py index 57b6584366a7..749ca6c70a52 100644 --- a/nemo/collections/asr/parts/parsers.py +++ b/nemo/collections/common/parts/preprocessing/parsers.py @@ -16,7 +16,7 @@ import frozendict -from nemo.collections.asr.parts import cleaners +from nemo.collections.common.parts.preprocessing import cleaners class CharParser: diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py index 647070ce8a62..08f48905666e 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py @@ -285,6 +285,9 @@ def eval_epoch_end(self, outputs, mode): # if user specifies one validation dataloader, then PTL reverts to giving a list of dictionary instead of a list of list of dictionary if isinstance(outputs[0], dict): outputs = [outputs] + + loss_list = [] + sb_score_list = [] for dataloader_idx, output in enumerate(outputs): if dataloader_idx == 0: eval_loss = getattr(self, f'{mode}_loss').compute() @@ -339,6 +342,8 @@ def eval_epoch_end(self, outputs, mode): else: sb_score = 0.0 + loss_list.append(eval_loss.cpu().numpy()) + sb_score_list.append(sb_score) if dataloader_idx == 0: self.log(f"{mode}_loss", eval_loss, sync_dist=True) self.log(f"{mode}_sacreBLEU", sb_score, sync_dist=True) @@ -348,6 +353,10 @@ def eval_epoch_end(self, outputs, mode): self.log(f"{mode}_sacreBLEU_dl_index_{dataloader_idx}", sb_score, sync_dist=True) getattr(self, f'{mode}_loss_{dataloader_idx}').reset() + if len(loss_list) > 1: + self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True) + self.log(f"{mode}_sacreBLEU_avg", np.mean(sb_score_list), sync_dist=True) + def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs. 
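For downstream code, the `cleaners`/`collections`/`manifest`/`parsers` moves above amount to an import-path change only; an illustrative migration sketch:

    # Old location (pre-refactor), kept as a comment for reference:
    #   from nemo.collections.asr.parts import collections, parsers
    # New location introduced by this patch:
    from nemo.collections.common.parts.preprocessing import collections, parsers
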
diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 21d267bd340e..3711db2c8fac 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -422,7 +422,7 @@ def restore_from( restored_model = cls._default_restore_from( restore_path, override_config_path, map_location, strict, return_config ) - restored_model._trainer = trainer + restored_model.set_trainer(trainer) return restored_model else: return super().restore_from(restore_path, override_config_path, map_location, strict, return_config) diff --git a/nemo/collections/tts/data/datalayers.py b/nemo/collections/tts/data/datalayers.py index 89095001af42..cd5da7925bfd 100644 --- a/nemo/collections/tts/data/datalayers.py +++ b/nemo/collections/tts/data/datalayers.py @@ -55,9 +55,9 @@ from torch.nn.utils.rnn import pad_sequence from tqdm import tqdm -from nemo.collections.asr.parts import collections, parsers -from nemo.collections.asr.parts.features import WaveformFeaturizer -from nemo.collections.asr.parts.segment import AudioSegment +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.parts.preprocessing import collections, parsers from nemo.core.classes import Dataset from nemo.core.neural_types.elements import * from nemo.core.neural_types.neural_type import NeuralType diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 886d988d2ef0..a5967d3f26ab 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -21,7 +21,7 @@ from pytorch_lightning import Trainer from nemo.collections.asr.data.audio_to_text import FastPitchDataset -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.losses.fastpitchloss import FastPitchLoss from nemo.collections.tts.models.base import SpectrogramGenerator from nemo.collections.tts.modules.fastpitch import FastPitchModule diff --git a/nemo/collections/tts/models/fastpitch_hifigan_e2e.py b/nemo/collections/tts/models/fastpitch_hifigan_e2e.py index b61dc0bcd7b8..605365c6398e 100644 --- a/nemo/collections/tts/models/fastpitch_hifigan_e2e.py +++ b/nemo/collections/tts/models/fastpitch_hifigan_e2e.py @@ -24,7 +24,7 @@ from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from nemo.collections.asr.data.audio_to_text import FastPitchDataset -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy from nemo.collections.tts.losses.fastpitchloss import BaseFastPitchLoss from nemo.collections.tts.losses.fastspeech2loss import L1MelLoss diff --git a/nemo/collections/tts/models/fastspeech2.py b/nemo/collections/tts/models/fastspeech2.py index 198bc4052064..d93d3b25639b 100644 --- a/nemo/collections/tts/models/fastspeech2.py +++ b/nemo/collections/tts/models/fastspeech2.py @@ -22,7 +22,7 @@ from omegaconf import MISSING, DictConfig, OmegaConf, open_dict from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy from nemo.collections.tts.losses.fastspeech2loss import 
DurationLoss, L2MelLoss from nemo.collections.tts.models.base import SpectrogramGenerator diff --git a/nemo/collections/tts/models/fastspeech2_hifigan_e2e.py b/nemo/collections/tts/models/fastspeech2_hifigan_e2e.py index e80761cc712a..1ca2f76bb971 100644 --- a/nemo/collections/tts/models/fastspeech2_hifigan_e2e.py +++ b/nemo/collections/tts/models/fastspeech2_hifigan_e2e.py @@ -22,7 +22,7 @@ from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy from nemo.collections.tts.losses.fastspeech2loss import DurationLoss, L1MelLoss from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss diff --git a/nemo/collections/tts/models/glow_tts.py b/nemo/collections/tts/models/glow_tts.py index 6dc193debe9c..a5ceb18e9dbf 100644 --- a/nemo/collections/tts/models/glow_tts.py +++ b/nemo/collections/tts/models/glow_tts.py @@ -23,7 +23,7 @@ from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from nemo.collections.asr.data.audio_to_text import _AudioTextDataset -from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.tts.helpers.helpers import log_audio_to_tb, plot_alignment_to_numpy, plot_spectrogram_to_numpy from nemo.collections.tts.losses.glow_tts_loss import GlowTTSLoss from nemo.collections.tts.models.base import SpectrogramGenerator diff --git a/nemo/collections/tts/models/tacotron2.py b/nemo/collections/tts/models/tacotron2.py index ab5639a4a87d..f608ac206d1b 100644 --- a/nemo/collections/tts/models/tacotron2.py +++ b/nemo/collections/tts/models/tacotron2.py @@ -22,7 +22,7 @@ from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from torch import nn -from nemo.collections.asr.parts import parsers +from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.helpers.helpers import get_mask_from_lengths, tacotron2_log_to_tb_func from nemo.collections.tts.losses.tacotron2loss import Tacotron2Loss from nemo.collections.tts.models.base import SpectrogramGenerator diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index a1f025d0a2d9..0133524b801f 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -65,7 +65,6 @@ """ _MODEL_IS_RESTORED = False _NEMO_FILE_FOLDER = None -_MODEL_RESTORE_PATH = None class ModelPT(LightningModule, Model): @@ -117,6 +116,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.trainer = trainer # reference required for self.*_rank self._trainer = self.trainer # alias for backward compatibility + self._set_model_guid() + # Set device_id in AppState if torch.cuda.is_available() and torch.cuda.current_device() is not None: app_state = AppState() @@ -185,11 +186,13 @@ def register_artifact(self, config_path: str, src: str, verify_src_exists: bool """ app_state = AppState() + if src is None or src == "": return src if not hasattr(self, 'artifacts'): self.artifacts = {} + if self.artifacts is None: self.artifacts = {} @@ -258,10 +261,8 @@ def _handle_artifacts(self, nemo_file_folder): shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name)) # Update artifacts registry - new_artiitem = 
model_utils.ArtifactItem() - new_artiitem.path = "nemo:" + artifact_uniq_name - new_artiitem.path_type = model_utils.ArtifactPathType.TAR_PATH - self.artifacts[conf_path] = new_artiitem + artiitem.hashed_path = "nemo:" + artifact_uniq_name + self.artifacts[conf_path] = artiitem elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH: # process all tarfile artifacts in one go, so preserve key-value pair @@ -272,7 +273,8 @@ def _handle_artifacts(self, nemo_file_folder): # Process current tarfile artifacts by unpacking the previous tarfile and extract the artifacts # that are currently required. - if len(tarfile_artifacts) > 0: + model_metadata = app_state.get_model_metadata_from_guid(self.model_guid) + if len(tarfile_artifacts) > 0 and model_metadata.restoration_path is not None: # Need to step into nemo archive to extract file # Get path where the command is executed - the artifacts will be "retrieved" there # (original .nemo behavior) @@ -280,7 +282,7 @@ try: # Step into the nemo archive to try and find the file with tempfile.TemporaryDirectory() as archive_dir: - self._unpack_nemo_file(path2file=app_state.model_restore_path, out_folder=archive_dir) + self._unpack_nemo_file(path2file=model_metadata.restoration_path, out_folder=archive_dir) os.chdir(archive_dir) for conf_path, artiitem in tarfile_artifacts: # Get basename and copy it to nemo_file_folder @@ -305,7 +307,10 @@ def _update_artifact_paths(self, path2yaml_file): if self.artifacts is not None and len(self.artifacts) > 0: conf = OmegaConf.load(path2yaml_file) for conf_path, item in self.artifacts.items(): - conf.update_node(conf_path, item.path) + if item.hashed_path is None: + conf.update_node(conf_path, item.path) + else: + conf.update_node(conf_path, item.hashed_path) with open(path2yaml_file, 'w') as fout: OmegaConf.save(config=conf, f=fout, resolve=True) @@ -473,9 +478,7 @@ def restore_from( if not path.exists(restore_path): raise FileNotFoundError(f"Can't find {restore_path}") - global _MODEL_RESTORE_PATH - _MODEL_RESTORE_PATH = os.path.abspath(os.path.expanduser(restore_path)) - app_state.model_restore_path = _MODEL_RESTORE_PATH + app_state.model_restore_path = os.path.abspath(os.path.expanduser(restore_path)) return cls._default_restore_from(restore_path, override_config_path, map_location, strict, return_config) @classmethod @@ -1038,6 +1041,90 @@ def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str: """ return self._test_names[dataloader_idx] + @rank_zero_only + def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'): + """ + Initializes a given model with the parameters obtained via specific config arguments. + The state dict of the provided model will be updated with the `strict=False` setting so as to avoid + requiring an exact match of model parameters. + + Initializations: + init_from_nemo_model: Str path to a .nemo model, which will be instantiated in order + to extract the state dict. + + init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud). + The model will be downloaded (or a cached copy will be used), instantiated and then + its state dict will be extracted. + + init_from_ptl_ckpt: Str name of a Pytorch Lightning checkpoint file. It will be loaded and + the state dict will be extracted. + + Args: + cfg: The config used to instantiate the model. It need only contain one of the above keys.
+ map_location: str or torch.device() which represents where the intermediate state dict + (from the pretrained model or checkpoint) will be loaded. + + """ + args = ['init_from_nemo_model', 'init_from_pretrained_model', 'init_from_ptl_ckpt'] + arg_matches = [(1 if arg in cfg and arg is not None else 0) for arg in args] + + if sum(arg_matches) == 0: + # model weights do not need to be restored + return + + if sum(arg_matches) > 1: + raise ValueError( + f"Cannot pass more than one model initialization argument to config!\n" + f"Found : {[args[idx] for idx, arg_present in enumerate(arg_matches) if arg_present]}" + ) + + if 'init_from_nemo_model' in cfg and cfg.init_from_nemo_model is not None: + with open_dict(cfg): + # Restore model + model_path = cfg.pop('init_from_nemo_model') + restored_model = self.restore_from(model_path, map_location=map_location, strict=True) + + # Restore checkpoint into current model + self.load_state_dict(restored_model.state_dict(), strict=False) + logging.info(f'Model checkpoint restored from nemo file with path : `{model_path}`') + + del restored_model + + if 'init_from_pretrained_model' in cfg and cfg.init_from_pretrained_model is not None: + with open_dict(cfg): + # Restore model + model_name = cfg.pop('init_from_pretrained_model') + + # Check if model is being resumed or not - only works if `Trainer` is attached to model + if hasattr(self, 'trainer') and self.trainer is not None: + trainer = self.trainer + if hasattr(trainer, 'resume_from_checkpoint') and trainer.resume_from_checkpoint is not None: + logging.info( + "Model training is being resumed via Pytorch Lightning.\n" + "Initialization from pretrained model (via cloud) will be skipped." + ) + return + + restored_model = self.from_pretrained(model_name, map_location=map_location, strict=True) + + # Restore checkpoint into current model + self.load_state_dict(restored_model.state_dict(), strict=False) + logging.info(f'Model checkpoint restored from pretrained checkpoint with name : `{model_name}`') + + del restored_model + + if 'init_from_ptl_ckpt' in cfg and cfg.init_from_ptl_ckpt is not None: + with open_dict(cfg): + # Restore checkpoint + ckpt_path = cfg.pop('init_from_ptl_ckpt') + ckpt = torch.load(ckpt_path, map_location=map_location) + + # Restore checkpoint into current model + self.load_state_dict(ckpt['state_dict'], strict=False) + logging.info(f'Model checkpoint restored from pytorch lightning checkpoint with path : `{ckpt_path}`') + + del ckpt + def teardown(self, stage: str): """ Called at the end of fit and test.
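A usage sketch for the new `maybe_init_from_pretrained_checkpoint` hook; the model class, pretrained name, and checkpoint path below are illustrative placeholders and not part of this patch:

    from omegaconf import OmegaConf
    from nemo.collections.asr.models import EncDecCTCModel

    # Any ModelPT subclass works; the pretrained name is a placeholder.
    model = EncDecCTCModel.from_pretrained("QuartzNet15x5Base-En")

    # Exactly one of init_from_nemo_model / init_from_pretrained_model /
    # init_from_ptl_ckpt may be set, otherwise a ValueError is raised.
    cfg = OmegaConf.create({"init_from_ptl_ckpt": "/path/to/checkpoint.ckpt"})  # placeholder path
    model.maybe_init_from_pretrained_checkpoint(cfg)  # loads the state dict with strict=False
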
@@ -1290,23 +1377,32 @@ def _set_model_restore_state(is_being_restored: bool, folder: str = None):
         _MODEL_IS_RESTORED = is_being_restored
         _NEMO_FILE_FOLDER = folder

+    def _set_model_guid(self):
+        if not hasattr(self, 'model_guid'):
+            appstate = AppState()
+
+            # Generate a unique uuid for the instance
+            # also determine if the model is being restored or not, and preserve the path
+            self.model_guid = str(uuid.uuid4())
+            if self._is_model_being_restored():
+                restore_path = appstate.model_restore_path
+            else:
+                restore_path = None
+
+            appstate.register_model_guid(self.model_guid, restoration_path=restore_path)
+
     @staticmethod
     def _is_restore_type_tarfile() -> bool:
         """
         Utility method that checks if the restore path of the underlying Model
         is a tarfile (can be any valid archive).
         """
-        global _MODEL_RESTORE_PATH
-
         app_state = AppState()
-        if _MODEL_RESTORE_PATH is None and app_state.model_restore_path is None:
+        if app_state.model_restore_path is None:
             return False
         else:
-            if _MODEL_RESTORE_PATH:
-                if tarfile.is_tarfile(_MODEL_RESTORE_PATH):
-                    return True
-            elif app_state.model_restore_path:
+            if app_state.model_restore_path:
                 if tarfile.is_tarfile(app_state.model_restore_path):
                     return True
             else:
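A short sketch of the guid bookkeeping that ``_set_model_guid`` performs, using only the ``AppState`` helpers introduced in ``nemo/utils/app_state.py`` below; the archive path is illustrative:

.. code-block:: python

    import uuid

    from nemo.utils.app_state import AppState

    app_state = AppState()

    # Each model instance registers a fresh uuid; a model created during
    # restoration also records the .nemo archive it came from, which
    # _handle_artifacts() later uses to re-open the correct tarfile.
    guid = str(uuid.uuid4())
    app_state.register_model_guid(guid, restoration_path='/tmp/model.nemo')

    metadata = app_state.get_model_metadata_from_guid(guid)
    assert metadata.guid == guid
    assert metadata.restoration_path == '/tmp/model.nemo'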
diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py
index e3b1bbf45808..f2c710443125 100644
--- a/nemo/utils/app_state.py
+++ b/nemo/utils/app_state.py
@@ -12,11 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from dataclasses import dataclass
+from threading import Lock
+from typing import Dict, Optional
+
 from nemo.utils.metaclasses import Singleton


+@dataclass()
+class ModelMetadataRegistry:
+    guid: str
+    gidx: int
+    restoration_path: Optional[str] = None
+
+
 class AppState(metaclass=Singleton):
     def __init__(self):
+        # method call lock
+        self.__lock = Lock()
         # TODO: should we store global config in hydra_runner?
         self._app_cfg = None
@@ -50,6 +63,8 @@ def __init__(self):
         self._model_config_yaml = "model_config.yaml"
         self._model_weights_ckpt = "model_weights.ckpt"
         self._model_restore_path = None
+        self._all_model_restore_paths = []
+        self._model_guid_map = {}  # type: Dict[str, ModelMetadataRegistry]

     @property
     def device_id(self):
@@ -342,8 +357,30 @@ def model_weights_ckpt(self):

     @property
     def model_restore_path(self):
-        return self._model_restore_path
+        restore_path = self._all_model_restore_paths[-1] if len(self._all_model_restore_paths) > 0 else None
+        return restore_path

     @model_restore_path.setter
     def model_restore_path(self, path):
-        self._model_restore_path = path
+        with self.__lock:
+            self._model_restore_path = path
+            self._all_model_restore_paths.append(path)
+
+    def register_model_guid(self, guid: str, restoration_path: Optional[str] = None):
+        # Maps a guid to its restore path (None or last absolute path)
+        with self.__lock:
+            if guid in self._model_guid_map:
+                idx = self._model_guid_map[guid].gidx
+            else:
+                idx = len(self._model_guid_map)
+            self._model_guid_map[guid] = ModelMetadataRegistry(guid, idx, restoration_path=restoration_path)
+
+    def reset_model_guid_registry(self):
+        # Reset the guid mapping
+        with self.__lock:
+            self._model_guid_map.clear()
+
+    def get_model_metadata_from_guid(self, guid) -> ModelMetadataRegistry:
+        # Returns the global model idx and restoration path
+        metadata = self._model_guid_map[guid]
+        return metadata
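With this change, ``model_restore_path`` behaves like a stack rather than a single slot, which is what allows several models to be restored within one process (including nested restores). A sketch with illustrative paths:

.. code-block:: python

    from nemo.utils.app_state import AppState

    app_state = AppState()

    # Each restore_from() call pushes a path; the property returns the most
    # recently pushed one, so the innermost restore sees the right archive.
    app_state.model_restore_path = '/tmp/outer.nemo'
    app_state.model_restore_path = '/tmp/inner.nemo'
    assert app_state.model_restore_path == '/tmp/inner.nemo'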
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index 5b42045aa125..8e9400075dd3 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -646,9 +646,7 @@ def on_train_end(self, trainer, pl_module):
         pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))


-def configure_checkpointing(
-    trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, params: 'DictConfig',
-):
+def configure_checkpointing(trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, params: 'DictConfig'):
     """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
     callback or if trainer.weights_save_path was passed to Trainer.
     """
@@ -706,6 +704,7 @@ def configure_checkpointing(
     )

     checkpoint_callback = NeMoModelCheckpoint(**params)
+    checkpoint_callback.last_model_path = trainer.resume_from_checkpoint or ""
     trainer.callbacks.append(checkpoint_callback)

diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py
index 6b8225dd5d61..ee61db54ed07 100644
--- a/nemo/utils/model_utils.py
+++ b/nemo/utils/model_utils.py
@@ -46,6 +46,7 @@ class ArtifactPathType(Enum):
 class ArtifactItem:
     path: str
     path_type: ArtifactPathType
+    hashed_path: Optional[str] = None


 def resolve_dataset_name_from_cfg(cfg: DictConfig) -> str:
diff --git a/nemo_text_processing/inverse_text_normalization/verbalizers/money.py b/nemo_text_processing/inverse_text_normalization/verbalizers/money.py
index b6c312116fdb..30dec90c76e0 100644
--- a/nemo_text_processing/inverse_text_normalization/verbalizers/money.py
+++ b/nemo_text_processing/inverse_text_normalization/verbalizers/money.py
@@ -28,7 +28,7 @@ class MoneyFst(GraphFst):
     """
     Finite state transducer for verbalizing money, e.g.
-        money { integer_part: "12" fractional_part: 05 currency: "$" } -> $12.05
+        money { integer_part: "12" fractional_part: "05" currency: "$" } -> $12.05

     Args:
         decimal: DecimalFst
diff --git a/scripts/dataset_processing/ljspeech/create_manifests_and_textfiles.py b/scripts/dataset_processing/ljspeech/create_manifests_and_textfiles.py
index bb83cf844fa1..aede9f828d4a 100644
--- a/scripts/dataset_processing/ljspeech/create_manifests_and_textfiles.py
+++ b/scripts/dataset_processing/ljspeech/create_manifests_and_textfiles.py
@@ -23,7 +23,7 @@
 import sox
 import wget

-from nemo.collections.asr.parts import parsers
+from nemo.collections.common.parts.preprocessing import parsers

 parser = argparse.ArgumentParser()
 parser.add_argument('--ljspeech_base', required=True, default=None, type=str)
diff --git a/scripts/speaker_recognition/rttm_to_manifest.py b/scripts/speaker_recognition/rttm_to_manifest.py
index 5507e33efcb8..3d7b772b05b9 100644
--- a/scripts/speaker_recognition/rttm_to_manifest.py
+++ b/scripts/speaker_recognition/rttm_to_manifest.py
@@ -14,7 +14,7 @@

 import argparse

-from nemo.collections.asr.parts.speaker_utils import write_rttm2manifest
+from nemo.collections.asr.parts.utils.speaker_utils import write_rttm2manifest
 from nemo.utils import logging

 """
diff --git a/scripts/voice_activity_detection/vad_tune_threshold.py b/scripts/voice_activity_detection/vad_tune_threshold.py
index 70ffbd5f9305..1486ae730c40 100644
--- a/scripts/voice_activity_detection/vad_tune_threshold.py
+++ b/scripts/voice_activity_detection/vad_tune_threshold.py
@@ -16,7 +16,7 @@

 import numpy as np

-from nemo.collections.asr.parts.vad_utils import vad_tune_threshold_on_dev
+from nemo.collections.asr.parts.utils.vad_utils import vad_tune_threshold_on_dev
 from nemo.utils import logging

 if __name__ == "__main__":
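Connecting the ``ArtifactItem.hashed_path`` field added in ``nemo/utils/model_utils.py`` above with the ``_update_artifact_paths`` logic earlier in this diff, a small sketch (field values are illustrative):

.. code-block:: python

    from nemo.utils import model_utils

    # An artifact repacked into a .nemo archive keeps its original path and
    # gains a "nemo:"-prefixed name that is unique inside the tarfile.
    item = model_utils.ArtifactItem(
        path='/data/tokenizer.model',
        path_type=model_utils.ArtifactPathType.TAR_PATH,
        hashed_path='nemo:1a2b3c_tokenizer.model',
    )

    # _update_artifact_paths() writes hashed_path into the serialized config
    # whenever it is set, and falls back to the original path otherwise.
    resolved = item.hashed_path if item.hashed_path is not None else item.path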
diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py
index 9d008411b80c..b2a3ccfd09fe 100644
--- a/tests/collections/asr/test_asr_datasets.py
+++ b/tests/collections/asr/test_asr_datasets.py
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import os

 import pytest
-import torch
+from omegaconf import OmegaConf

 from nemo.collections.asr.data.audio_to_text import TarredAudioToBPEDataset, TarredAudioToCharDataset
-from nemo.collections.asr.parts.features import WaveformFeaturizer
+from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config
 from nemo.collections.common import tokenizers
+from nemo.utils import logging


 class TestASRDatasets:
@@ -79,6 +81,23 @@ def test_tarred_dataset(self, test_data_dir):
                 count += 1
             assert count == 32

+    @pytest.mark.unit
+    def test_mismatch_in_model_dataloader_config(self, caplog):
+        logging._logger.propagate = True
+        caplog.set_level(logging.WARNING)
+
+        model_cfg = OmegaConf.create(dict(labels=OmegaConf.create(["a", "b", "c"])))
+        dataloader_cfg = OmegaConf.create(dict(labels=copy.deepcopy(self.labels)))
+
+        inject_dataloader_value_from_model_config(model_cfg, dataloader_cfg, key='labels')
+
+        assert (
+            """`labels` is explicitly provided to the data loader, and is different from the `labels` provided at the model level config."""
+            in caplog.text
+        )
+
+        logging._logger.propagate = False
+
     @pytest.mark.unit
     def test_tarred_bpe_dataset(self, test_data_dir):
         manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/tarred_an4/tarred_audio_manifest.json'))
diff --git a/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py b/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py
index 3dfbd4c966a8..24ba5849850f 100644
--- a/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py
+++ b/tests/collections/asr/test_asr_filterbankfeatures_seq_len.py
@@ -17,7 +17,7 @@
 import pytest
 import torch

-from nemo.collections.asr.parts.features import FilterbankFeatures
+from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures


 class TestFilterbankFeatures:
diff --git a/tests/collections/asr/test_asr_metrics.py b/tests/collections/asr/test_asr_metrics.py
index 5adcde824a58..7f78c1d72079 100644
--- a/tests/collections/asr/test_asr_metrics.py
+++ b/tests/collections/asr/test_asr_metrics.py
@@ -14,14 +14,12 @@

 import random
-import string

 import pytest
 import torch

 from nemo.collections.asr.metrics.wer import WER, word_error_rate
-from nemo.collections.asr.parts.rnnt_utils import Hypothesis
-from nemo.utils import logging
+from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis


 class TestWordErrorRate:
diff --git a/tests/collections/asr/test_asr_modules.py b/tests/collections/asr/test_asr_modules.py
index bff683d0bbc8..26c3c88ef722 100644
--- a/tests/collections/asr/test_asr_modules.py
+++ b/tests/collections/asr/test_asr_modules.py
@@ -17,7 +17,7 @@
 from omegaconf import OmegaConf

 from nemo.collections.asr import modules
-from nemo.collections.asr.parts.rnnt_utils import Hypothesis
+from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 from nemo.utils import config_utils
diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py
index 4360f3d50321..1e1ec99cec97 100644
--- a/tests/collections/asr/test_asr_rnnt_encdec_model.py
+++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py
@@ -17,11 +17,10 @@
 import torch
 from omegaconf import DictConfig, ListConfig

-from nemo.collections.asr.metrics import rnnt_wer
 from nemo.collections.asr.models import EncDecRNNTModel
-from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
-from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
 from nemo.collections.asr.parts.numba import __NUMBA_MINIMUM_VERSION__, numba_utils
+from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
+from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
 from nemo.utils.config_utils import assert_dataclass_signature_match

 NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__)
diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
index ba25f00d8edf..6a6836d647c7 100644
--- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
+++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
@@ -21,9 +21,9 @@
 from omegaconf import DictConfig

 from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
-from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
-from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
 from nemo.collections.asr.parts.numba import __NUMBA_MINIMUM_VERSION__, numba_utils
+from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
+from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
 from nemo.collections.common import tokenizers

 NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__)
diff --git a/tests/collections/asr/test_jasper_block.py b/tests/collections/asr/test_jasper_block.py
index 298db767fca2..2b6cfb41eb6d 100644
--- a/tests/collections/asr/test_jasper_block.py
+++ b/tests/collections/asr/test_jasper_block.py
@@ -15,7 +15,7 @@
 import pytest
 import torch

-from nemo.collections.asr.parts import jasper
+from nemo.collections.asr.parts.submodules import jasper


 class TestJasperBlock:
diff --git a/tests/collections/asr/test_label_datasets.py b/tests/collections/asr/test_label_datasets.py
index 57b8bb2d8f06..fbe902528445 100644
--- a/tests/collections/asr/test_label_datasets.py
+++ b/tests/collections/asr/test_label_datasets.py
@@ -18,9 +18,8 @@
 from nemo.collections.asr.data.audio_to_label import TarredAudioToClassificationLabelDataset
 from nemo.collections.asr.data.feature_to_label import FeatureToSeqSpeakerLabelDataset
-from nemo.collections.asr.parts.feature_loader import ExternalFeatureLoader
-from nemo.collections.asr.parts.features import WaveformFeaturizer
-from nemo.collections.common import tokenizers
+from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader
+from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer


 class TestASRDatasets:
diff --git a/tests/collections/tts/test_waveglow.py b/tests/collections/tts/test_waveglow.py
index ca4eb6e28f84..2e544a83174a 100644
--- a/tests/collections/tts/test_waveglow.py
+++ b/tests/collections/tts/test_waveglow.py
@@ -41,7 +41,7 @@
 pcfg = DictConfig(
     {
-        "_target_": "nemo.collections.asr.parts.features.FilterbankFeatures",
+        "_target_": "nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures",
         "dither": 0.0,
         "nfilt": 80,
         "stft_conv": True,
diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py
index 5f24ee34e8ea..8cf0b893c372 100644
--- a/tests/core/test_save_restore.py
+++ b/tests/core/test_save_restore.py
@@ -24,6 +24,7 @@
 from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
 from nemo.collections.nlp.models import PunctuationCapitalizationModel
 from nemo.core.classes import ModelPT
+from nemo.utils.app_state import AppState


 def getattr2(object, attr):
@@ -169,7 +170,6 @@ def test_mock_save_to_restore_from(self):

         # Save test
         model_copy = self.__test_restore_elsewhere(model, map_location='cpu')
-        # assert filecmp.cmp(model.temp_file, model_copy.temp_file)

         # Restore test
         diff = model.w.weight - model_copy.w.weight
@@ -228,7 +228,6 @@ def test_mock_restore_from_config_override_with_OmegaConf(self):

         # Save test (with overridden config as OmegaConf object)
         model_copy = self.__test_restore_elsewhere(model, map_location='cpu', override_config_path=cfg)
-        assert filecmp.cmp(model.temp_file, model_copy.temp_file)

         # Restore test
         diff = model.w.weight - model_copy.w.weight
@@ -270,7 +269,6 @@ def test_mock_restore_from_config_override_with_yaml(self):
         # Restore test
         diff = model.w.weight - model_copy.w.weight
         assert diff.mean() <= 1e-9
-        # assert os.path.basename(model.temp_file) == model_copy.temp_file
         assert filecmp.cmp(model.temp_file, model_copy.temp_file)
         assert model_copy.temp_data == ["*****\n"]
@@ -310,3 +308,136 @@ def test_mock_save_to_restore_from_with_target_class(self):
         assert diff.mean() <= 1e-9
         assert isinstance(model_copy, MockModel)
         assert model_copy.temp_data == ["*****\n"]
+
+    @pytest.mark.unit
+    def test_mock_save_to_restore_from_multiple_models(self):
+        with tempfile.NamedTemporaryFile('w') as empty_file, tempfile.NamedTemporaryFile('w') as empty_file2:
+            # Write some data
+            empty_file.writelines(["*****\n"])
+            empty_file.flush()
+            empty_file2.writelines(["+++++\n"])
+            empty_file2.flush()
+
+            # Update config + create models
+            cfg = _mock_model_config()
+            cfg.model.temp_file = empty_file.name
+            cfg2 = _mock_model_config()
+            cfg2.model.temp_file = empty_file2.name
+
+            # Create models
+            model = MockModel(cfg=cfg.model, trainer=None)
+            model = model.to('cpu')
+            model2 = MockModel(cfg=cfg2.model, trainer=None)
+            model2 = model2.to('cpu')
+
+            assert model.temp_file == empty_file.name
+            assert model2.temp_file == empty_file2.name
+
+            # Save test
+            model_copy = self.__test_restore_elsewhere(model, map_location='cpu')
+            model2_copy = self.__test_restore_elsewhere(model2, map_location='cpu')
+
+            # Restore test
+            assert model_copy.temp_data == ["*****\n"]
+            assert model2_copy.temp_data == ["+++++\n"]
+
+    @pytest.mark.unit
+    def test_mock_save_to_restore_from_multiple_models_inverted_order(self):
+        with tempfile.NamedTemporaryFile('w') as empty_file, tempfile.NamedTemporaryFile('w') as empty_file2:
+            # Write some data
+            empty_file.writelines(["*****\n"])
+            empty_file.flush()
+            empty_file2.writelines(["+++++\n"])
+            empty_file2.flush()
+
+            # Update config + create models
+            cfg = _mock_model_config()
+            cfg.model.temp_file = empty_file.name
+            cfg2 = _mock_model_config()
+            cfg2.model.temp_file = empty_file2.name
+
+            # Create models
+            model = MockModel(cfg=cfg.model, trainer=None)
+            model = model.to('cpu')
+            model2 = MockModel(cfg=cfg2.model, trainer=None)
+            model2 = model2.to('cpu')
+
+            assert model.temp_file == empty_file.name
+            assert model2.temp_file == empty_file2.name
+
+            # Save test (inverted order)
+            model2_copy = self.__test_restore_elsewhere(model2, map_location='cpu')
+            model_copy = self.__test_restore_elsewhere(model, map_location='cpu')
+
+            # Restore test
+            assert model_copy.temp_data == ["*****\n"]
+            assert model2_copy.temp_data == ["+++++\n"]
+
+    @pytest.mark.unit
+    def test_mock_save_to_restore_chained(self):
+        with tempfile.NamedTemporaryFile('w') as empty_file, tempfile.NamedTemporaryFile('w') as empty_file2:
+            # Write some data
+            empty_file.writelines(["*****\n"])
+            empty_file.flush()
+
+            # Update config + create models
+ cfg = _mock_model_config() + cfg.model.temp_file = empty_file.name + + # Create models + model = MockModel(cfg=cfg.model, trainer=None) + model = model.to('cpu') + + assert model.temp_file == empty_file.name + + def save_copy(model, save_folder, restore_folder): + # Where model will be saved + model_save_path = os.path.join(save_folder, f"{model.__class__.__name__}.nemo") + model.save_to(save_path=model_save_path) + # Where model will be restored from + model_restore_path = os.path.join(restore_folder, f"{model.__class__.__name__}.nemo") + shutil.copy(model_save_path, model_restore_path) + return model_restore_path + + # Save test + with tempfile.TemporaryDirectory() as level4: + with tempfile.TemporaryDirectory() as level3: + with tempfile.TemporaryDirectory() as level2: + with tempfile.TemporaryDirectory() as level1: + path = save_copy(model, level1, level2) + model_copy2 = model.__class__.restore_from(path) + path = save_copy(model_copy2, level2, level3) + model_copy3 = model.__class__.restore_from(path) + path = save_copy(model_copy3, level3, level4) + model_copy = model.__class__.restore_from(path) + + # Restore test + assert model_copy.temp_data == ["*****\n"] + + # AppState test + appstate = AppState() + metadata = appstate.get_model_metadata_from_guid(model_copy.model_guid) + assert metadata.guid != model.model_guid + assert metadata.restoration_path == path + + @pytest.mark.unit + def test_mock_save_to_multiple_times(self): + with tempfile.NamedTemporaryFile('w') as empty_file, tempfile.TemporaryDirectory() as tmpdir: + # Write some data + empty_file.writelines(["*****\n"]) + empty_file.flush() + + # Update config + cfg = _mock_model_config() + cfg.model.temp_file = empty_file.name + + # Create model + model = MockModel(cfg=cfg.model, trainer=None) # type: MockModel + model = model.to('cpu') + + assert model.temp_file == empty_file.name + + # Save test + model.save_to(os.path.join(tmpdir, 'save_0.nemo')) + model.save_to(os.path.join(tmpdir, 'save_1.nemo')) + model.save_to(os.path.join(tmpdir, 'save_2.nemo')) diff --git a/tutorials/00_NeMo_Primer.ipynb b/tutorials/00_NeMo_Primer.ipynb index d7e4385e4521..50a42570fb7a 100644 --- a/tutorials/00_NeMo_Primer.ipynb +++ b/tutorials/00_NeMo_Primer.ipynb @@ -1,1281 +1,1281 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "7LfkL2r2Q1tr" - }, - "source": [ - "# Getting Started: Exploring Nemo Fundamentals\n", - "\n", - "NeMo is a toolkit for creating [Conversational AI](https://developer.nvidia.com/conversational-ai#started) applications.\n", - "\n", - "NeMo toolkit makes it possible for researchers to easily compose complex neural network architectures for conversational AI using reusable components - Neural Modules. Neural Modules are conceptual blocks of neural networks that take typed inputs and produce typed outputs. Such modules typically represent data layers, encoders, decoders, language models, loss functions, or methods of combining activations.\n", - "\n", - "The toolkit comes with extendable collections of pre-built modules and ready-to-use models for automatic speech recognition (ASR), natural language processing (NLP) and text synthesis (TTS). 
Built for speed, NeMo can utilize NVIDIA's Tensor Cores and scale out training to multiple GPUs and multiple nodes.\n", - "\n", - "For more information, please visit https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/#" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zLSy94NEQi-e" - }, - "outputs": [], - "source": [ - "\"\"\"\n", - "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", - "\n", - "Instructions for setting up Colab are as follows:\n", - "1. Open a new Python 3 notebook.\n", - "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", - "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", - "4. Run this cell to set up dependencies.\n", - "\"\"\"\n", - "# If you're using Google Colab and not running locally, run this cell.\n", - "\n", - "## Install dependencies\n", - "!pip install wget\n", - "!apt-get install sox libsndfile1 ffmpeg\n", - "!pip install unidecode\n", - "\n", - "# ## Install NeMo\n", - "BRANCH = 'v1.0.0'\n", - "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", - "\n", - "## Install TorchAudio\n", - "!pip install torchaudio>=0.6.0 -f https://download.pytorch.org/whl/torch_stable.html\n", - "\n", - "## Grab the config we'll use in this example\n", - "!mkdir configs" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6G2TZkaxcM0e" - }, - "source": [ - "## Foundations of NeMo\n", - "---------\n", - "\n", - "NeMo models leverage [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) Module, and are compatible with the entire PyTorch ecosystem. This means that users have the full flexibility of using the higher level APIs provided by PyTorch Lightning (via Trainer), or write their own training and evaluation loops in PyTorch directly (by simply calling the model and the individual components of the model).\n", - "\n", - "For NeMo developers, a \"Model\" is the neural network(s) as well as all the infrastructure supporting those network(s), wrapped into a singular, cohesive unit. As such, all NeMo models are constructed to contain the following out of the box (at the bare minimum, some models support additional functionality too!) - \n", - "\n", - " - Neural Network architecture - all of the modules that are required for the model.\n", - "\n", - " - Dataset + Data Loaders - all of the components that prepare the data for consumption during training or evaluation.\n", - "\n", - " - Preprocessing + Postprocessing - all of the components that process the datasets so they can easily be consumed by the modules.\n", - "\n", - " - Optimizer + Schedulers - basic defaults that work out of the box, and allow further experimentation with ease.\n", - "\n", - " - Any other supporting infrastructure - tokenizers, language model configuration, data augmentation etc.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XxAwtqWBQrNk" - }, - "outputs": [], - "source": [ - "import nemo\n", - "nemo.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "H01SHfKQh-gV" - }, - "source": [ - "## NeMo Collections\n", - "\n", - "NeMo is sub-divided into a few fundamental collections based on their domains - `asr`, `nlp`, `tts`. 
When you performed the `import nemo` statement above, none of the above collections were imported. This is because you might not need all of the collections at once, so NeMo allows partial imports of just one or more collection, as and when you require them.\n", - "\n", - "-------\n", - "Let's import the above three collections - " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "J09NNa8fhth7" - }, - "outputs": [], - "source": [ - "import nemo.collections.asr as nemo_asr\n", - "import nemo.collections.nlp as nemo_nlp\n", - "import nemo.collections.tts as nemo_tts" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bSvYoeBrjPza" - }, - "source": [ - "## NeMo Models in Collections\n", - "\n", - "NeMo contains several models for each of its collections, pertaining to certain common tasks involved in conversational AI. At a brief glance, let's look at all the Models that NeMo offers for the above 3 collections." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9LbbC_92i41f" - }, - "outputs": [], - "source": [ - "asr_models = [model for model in dir(nemo_asr.models) if model.endswith(\"Model\")]\n", - "asr_models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "t5_ax9Z8j9FC" - }, - "outputs": [], - "source": [ - "nlp_models = [model for model in dir(nemo_nlp.models) if model.endswith(\"Model\")]\n", - "nlp_models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bQdR6RJdkezq" - }, - "outputs": [], - "source": [ - "tts_models = [model for model in dir(nemo_tts.models) if model.endswith(\"Model\")]\n", - "tts_models" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iWKxKQnSkj9Z" - }, - "source": [ - "## The NeMo Model\n", - "\n", - "Let's dive deeper into what a NeMo model really is. There are many ways we can create these models - we can use the constructor and pass in a config, we can instantiate the model from a pre-trained checkpoint, or simply pass a pre-trained model name and instantiate a model directly from the cloud !\n", - "\n", - "---------\n", - "For now, let's try to work with an ASR model - [QuartzNet](https://arxiv.org/abs/1910.10261)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "n-XOQaW1kh3v" - }, - "outputs": [], - "source": [ - "quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained('QuartzNet15x5Base-En')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YP4X7KVPli6g" - }, - "outputs": [], - "source": [ - "quartznet.summarize();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MB91Swu0pIKr" - }, - "source": [ - "## Model Configuration using OmegaConf\n", - "--------\n", - "\n", - "So we could download, instantiate and analyse the high level structure of the `QuartzNet` model in a few lines! Now let's delve deeper into the configuration file that makes the model work.\n", - "\n", - "First, we import [OmegaConf](https://omegaconf.readthedocs.io/en/latest/). OmegaConf is an excellent library that is used throughout NeMo in order to enable us to perform yaml configuration management more easily. Additionally, it plays well with another library, [Hydra](https://hydra.cc/docs/intro/), that is used by NeMo to perform on the fly config edits from the command line, dramatically boosting ease of use of our config files !" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RkgrDJvumFER" - }, - "outputs": [], - "source": [ - "from omegaconf import OmegaConf" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CktakfBluA56" - }, - "source": [ - "All NeMo models come packaged with their model configuration inside the `cfg` attribute. While technically it is meant to be config declaration of the model as it has been currently constructed, `cfg` is an essential tool to modify the behaviour of the Model after it has been constructed. It can be safely used to make it easier to perform many essential tasks inside Models. \n", - "\n", - "To be doubly sure, we generally work on a copy of the config until we are ready to edit it inside the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ISd6z7sXt9Mm" - }, - "outputs": [], - "source": [ - "import copy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "N2_SiLHRve8A" - }, - "outputs": [], - "source": [ - "cfg = copy.deepcopy(quartznet.cfg)\n", - "print(OmegaConf.to_yaml(cfg))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W_V3e3W7vqOb" - }, - "source": [ - "## Analysing the contents of the Model config\n", - "----------\n", - "\n", - "Above we see a configuration for the QuartzNet model. As discussed in the beginning, NeMo models contain the entire definition of the neural network(s) as well as most of the surrounding infrastructure to support that model within themselves. Here, we see a perfect example of this behaviour.\n", - "\n", - "QuartzNet contains within its config - \n", - "\n", - "- `preprocessor` - MelSpectrogram preprocessing layer\n", - "- `encoder` - The acoustic encoder model.\n", - "- `decoder` - The CTC decoder layer.\n", - "- `optim` (and potentially `sched`) - Optimizer configuration. Can optionally include Scheduler information.\n", - "- `spec_augment` - Spectrogram Augmentation support.\n", - "- `train_ds`, `validation_ds` and `test_ds` - Dataset and data loader construction information." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sIwhdXkwxn6R" - }, - "source": [ - "## Modifying the contents of the Model config\n", - "----------\n", - "\n", - "Say we want to experiment with a different preprocessor (we want MelSpectrogram, but with different configuration than was provided in the original configuration). Or say we want to add a scheduler to this model during training. \n", - "\n", - "OmegaConf makes this a very simple task for us!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WlSZ8EA4yGKo" - }, - "outputs": [], - "source": [ - "# OmegaConf won't allow you to add new config items, so we temporarily disable this safeguard.\n", - "OmegaConf.set_struct(cfg, False)\n", - "\n", - "# Let's see the old optim config\n", - "print(\"Old Config: \")\n", - "print(OmegaConf.to_yaml(cfg.optim))\n", - "\n", - "sched = {'name': 'CosineAnnealing', 'warmup_steps': 1000, 'min_lr': 1e-6}\n", - "sched = OmegaConf.create(sched) # Convert it into a DictConfig\n", - "\n", - "# Assign it to cfg.optim.sched namespace\n", - "cfg.optim.sched = sched\n", - "\n", - "# Let's see the new optim config\n", - "print(\"New Config: \")\n", - "print(OmegaConf.to_yaml(cfg.optim))\n", - "\n", - "# Here, we restore the safeguards so no more additions can be made to the config\n", - "OmegaConf.set_struct(cfg, True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-nMDN66502kn" - }, - "source": [ - "## Updating the model from config\n", - "----------\n", - "\n", - "NeMo Models can be updated in a few ways, but we follow similar patterns within each collection so as to maintain consistency.\n", - "\n", - "Here, we will show the two most common ways to modify core components of the model - using the `from_config_dict` method, and updating a few special parts of the model.\n", - "\n", - "Remember, all NeMo models are PyTorch Lightning modules, which themselves are PyTorch modules, so we have a lot of flexibility here!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qrKzFYkZ20aa" - }, - "source": [ - "### Update model using `from_config_dict`\n", - "\n", - "In certain config files, you will notice the following pattern : \n", - "\n", - "```yaml\n", - "preprocessor:\n", - " _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor\n", - " normalize: per_feature\n", - " window_size: 0.02\n", - " sample_rate: 16000\n", - " window_stride: 0.01\n", - " window: hann\n", - " features: 64\n", - " n_fft: 512\n", - " frame_splicing: 1\n", - " dither: 1.0e-05\n", - " stft_conv: false\n", - "```\n", - "\n", - "You might ask, why are we using `_target_`? Well, it is generally rare for the preprocessor, encoder, decoder and perhaps a few other details to be changed often from the command line when experimenting. In order to stabilize these settings, we enforce that our preprocessor will always be of type `AudioToMelSpectrogramPreprocessor` for this model by setting its `_target_` attribute in the config. In order to provide its parameters in the class constructor, we simply add them after `_target_`.\n", - "\n", - "---------\n", - "Note, we can still change all of the parameters of this `AudioToMelSpectrogramPreprocessor` class from the command line using hydra, so we don't lose any flexibility once we decide what type of preprocessing class we want !\n", - "\n", - "This also gives us a flexible way to instantiate parts of the model from just the config object !" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1Be08R4szkT3" - }, - "outputs": [], - "source": [ - "new_preprocessor_config = copy.deepcopy(cfg.preprocessor)\n", - "new_preprocessor = quartznet.from_config_dict(new_preprocessor_config)\n", - "print(new_preprocessor)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UzJQ7Y8H4S_U" - }, - "source": [ - "So how do we actually update our model's internal preprocessor with something new? 
Well, since NeMo Model's are just pytorch Modules, we can just replace their attribute !" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WdtnPKX84OJ-" - }, - "outputs": [], - "source": [ - "quartznet.preprocessor = new_preprocessor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OMz2KR-24xTO" - }, - "outputs": [], - "source": [ - "quartznet.summarize();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gPb_BdPN40Ro" - }, - "source": [ - "--------\n", - "This might look like nothing changed - because we didn't actually modify the config for the preprocessor at all ! But as we showed above, we can easily modify the config for the preprocessor, instantiate it from config, and then just set it to the model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IV8WKJkD5E_Q" - }, - "source": [ - "-------\n", - "**NOTE**: Preprocessors don't generally have weights, so this was easy, but say we want to replace a part of the model which actually has trained parameters?\n", - "\n", - "Well, the above approach will still work, just remember the fact that the new module you inserted into `quartznet.encoder` or `quartznet.decoder` actually won't have pretrained weights. You can easily rectify that by loading the state dict for the module *before* you set it to the Model though!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YplQcgfG6S1U" - }, - "source": [ - "### Preserving the new config\n", - "\n", - "So we went ahead and updated the preprocessor of the model. We however also need to perform a crucial step - **preserving the updated config**!\n", - "\n", - "Why do we want to do this? NeMo has many ways of saving and restoring its models, which we will discuss a bit later. All of them depend on having an updated config that defines the model in its entirety, so if we modify anything, we should also update the corresponding part of the config to safely save and restore models." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dsxQHBV86R4a" - }, - "outputs": [], - "source": [ - "# Update the config copy\n", - "cfg.preprocessor = new_preprocessor_config\n", - "# Update the model config\n", - "quartznet.cfg = cfg" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eXRRBnJk5tCv" - }, - "source": [ - "## Update a few special components of the Model\n", - "---------\n", - "\n", - "While the above approach is good for most major components of the model, NeMo has special utilities for a few components.\n", - "\n", - "They are - \n", - "\n", - " - `setup_training_data`\n", - " - `setup_validation_data` and `setup_multi_validation_data`\n", - " - `setup_test_data` and `setup_multi_test_data`\n", - " - `setup_optimization`\n", - "\n", - "These special utilities are meant to help you easily setup training, validation, testing once you restore a model from a checkpoint.\n", - "\n", - "------\n", - "One of the major tasks of all conversational AI models is fine-tuning onto new datasets - new languages, new corpus of text, new voices etc. It is often insufficient to have just a pre-trained model. 
So these setup methods are provided to enable users to adapt models *after* they have been already trained or provided to you.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B7Y7wt2x9goJ" - }, - "source": [ - "You might remember having seen a few warning messages the moment you tried to instantiate the pre-trained model. Those warnings are in fact reminders to call the appropriate setup methods for the task you want to perform. \n", - "\n", - "Those warnings are simply displaying the old config that was used to train that model, and are a basic template that you can easily modify. You have the ability to modify the `train_ds`, `validation_ds` and `test_ds` sub-configs in their entirety in order to evaluate, fine-tune or train from scratch the model, or any further purpose as you require it.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1hXXdaup-QmG" - }, - "source": [ - "Let's discuss how to add the scheduler to the model below (which initially had just an optimizer in its config)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cveKWvMZ4zBo" - }, - "outputs": [], - "source": [ - "# Let's print out the current optimizer\n", - "print(OmegaConf.to_yaml(quartznet.cfg.optim))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XVguw3k0-f6b" - }, - "outputs": [], - "source": [ - "# Now let's update the config\n", - "quartznet.setup_optimization(cfg.optim);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1JZBCQeW-21X" - }, - "source": [ - "-------\n", - "We see a warning - \n", - "\n", - "```\n", - "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, cannot compute effective `max_steps` !\n", - " Scheduler will not be instantiated !\n", - "```\n", - "\n", - "We don't have a train dataset setup, nor do we have max_steps in the config. Most NeMo schedulers cannot be instantiated without computing how many train steps actually exist!\n", - "\n", - "Here, we can temporarily allow the scheduler construction by explicitly passing a max_steps value to be 100" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mqC89hfE-tqf" - }, - "outputs": [], - "source": [ - "OmegaConf.set_struct(cfg.optim.sched, False)\n", - "\n", - "cfg.optim.sched.max_steps = 100\n", - "\n", - "OmegaConf.set_struct(cfg.optim.sched, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "r22IqOBK_q6l" - }, - "outputs": [], - "source": [ - "# Now let's update the config and try again\n", - "quartznet.setup_optimization(cfg.optim);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "U7Eezf_sAVS0" - }, - "source": [ - "You might wonder why we didnt explicitly set `quartznet._cfg.optim = cfg.optim`. \n", - "\n", - "This is because the `setup_optimization()` method does it for you! You can still update the config manually." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "THqhXy_lQ7i8" - }, - "source": [ - "### Optimizer & Scheduler Config\n", - "\n", - "Optimizers and schedulers are common components of models, and are essential to train the model from scratch.\n", - "\n", - "They are grouped together under a unified `optim` namespace, as schedulers often operate on a given optimizer.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6HY51nuoSJs5" - }, - "source": [ - "### Let's breakdown the general `optim` structure\n", - "```yaml\n", - "optim:\n", - " name: novograd\n", - " lr: 0.01\n", - "\n", - " # optimizer arguments\n", - " betas: [0.8, 0.25]\n", - " weight_decay: 0.001\n", - "\n", - " # scheduler setup\n", - " sched:\n", - " name: CosineAnnealing\n", - "\n", - " # Optional arguments\n", - " max_steps: null # computed at runtime or explicitly set here\n", - " monitor: val_loss\n", - " reduce_on_plateau: false\n", - "\n", - " # scheduler config override\n", - " warmup_steps: 1000\n", - " warmup_ratio: null\n", - " min_lr: 1e-9\n", - "```\n", - "\n", - "Essential Optimizer components - \n", - "\n", - " - `name`: String name of the optimizer. Generally a lower case of the class name.\n", - " - `lr`: Learning rate is a required argument to all optimizers.\n", - "\n", - "Optional Optimizer components - after the above two arguments are provided, any additional arguments added under `optim` will be passed to the constructor of that optimizer as keyword arguments\n", - "\n", - " - `betas`: List of beta values to pass to the optimizer\n", - " - `weight_decay`: Optional weight decay passed to the optimizer.\n", - "\n", - "Optional Scheduler components - `sched` is an optional setup of the scheduler for the given optimizer.\n", - "\n", - "If `sched` is provided, only one essential argument needs to be provided : \n", - "\n", - " - `name`: The name of the scheduler. Generally, it is the full class name.\n", - "\n", - "Optional Scheduler components - \n", - "\n", - " - `max_steps`: Max steps as an override from the user. If one provides `trainer.max_steps` inside the trainer configuration, that value is used instead. If neither value is set, the scheduler will attempt to compute the `effective max_steps` using the size of the train data loader. If that too fails, then the scheduler will not be created at all.\n", - "\n", - " - `monitor`: Used if you are using an adaptive scheduler such as ReduceLROnPlateau. Otherwise ignored. Defaults to `loss` - indicating train loss as monitor.\n", - "\n", - " - `reduce_on_plateau`: Required to be set to true if using an adaptive scheduler.\n", - "\n", - "Any additional arguments under `sched` will be supplied as keyword arguments to the constructor of the scheduler.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "V3pQM2aj_6WX" - }, - "source": [ - "## Difference between the data loader setup methods\n", - "----------\n", - "\n", - "You might notice, we have multiple setup methods for validation and test data sets. We also don't have an equivalent `setup_multi_train_data`. \n", - "\n", - "In general, the `multi` methods refer to multiple data sets / data loaders. \n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g33nMx9WCJdj" - }, - "source": [ - "### Where's `setup_multi_train_data`?\n", - "With the above in mind, let's tackle why we don't have `setup_multi_train_data`. \n", - "\n", - "NeMo is concerned with multiple domains - `asr`, `nlp` and `tts`. 
The way datasets are setup and used in these domains is dramatically different. It is often unclear what it means to have multiple train datasets - do we concatenate them? Do we randomly sample (with same or different probability) from each of them? \n", - "\n", - "Therefore we leave such support for multiple datasets up to the model itself. For example, in ASR, you can concatenate multiple train manifest files by using commas when providing the `manifest_filepath` value!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BjI2Q5LECJib" - }, - "source": [ - "### What are multi methods?\n", - "\n", - "In many cases, especially true for ASR and NLP, we may have multiple validation and test datasets. The most common example for this in ASR is `Librispeech`, which has `dev_clean`, `dev_other`, `test_clean`, `test_other`.\n", - "\n", - "NeMo standardizes how to handle multiple data loaders for validation and testing, so that all of our collections have a similar look and feel, as well as ease development of our models. During evaluation, these datasets are treated independently and prepended with resolved names so that logs are separate!\n", - "\n", - "The `multi` methods are therefore generalizations of the single validation and single test data setup methods, with some additional functionality. If you provide multiple datasets, you still have to write code for just one dataset and NeMo will automatically attach the appropriate names to your logs so you can differentiate between them!\n", - "\n", - "Furthermore, they also automatically preserve the config the user passes to them when updating the validation or test data loaders.\n", - "\n", - "**In general, it is preferred to call the `setup_multi_validation_data` and `setup_multi_test_data` methods, even if you are only using single datasets, simply for the automated management they provide.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZKURHn0jH_52" - }, - "source": [ - "## Creating Model from constructor vs restoring a model\n", - "---------\n", - "\n", - "You might notice, we discuss all of the above setup methods in the context of model after it is restored. However, NeMo scripts do not call them inside any of the example train scripts themselves.\n", - "\n", - "This is because these methods are automatically called by the constructor when the Model is created for the first time, but these methods are skipped during restoration (either from a PyTorch Lightning checkpoint using `load_from_checkpoint`, or via `restore_from` method inside NeMo Models).\n", - "\n", - "This is done as most datasets are stored on a user's local directory, and the path to these datasets is set in the config (either set by default, or set by Hydra overrides). On the other hand, the models are meant to be portable. On another user's system, the data might not be placed at exactly the same location, or even on the same drive as specified in the model's config!\n", - "\n", - "Therefore we allow the constructor some brevity and automate such dataset setup, whereas restoration warns that data loaders were not set up and provides the user with ways to set up their own datasets.\n", - "\n", - "------\n", - "\n", - "Why are optimizers not restored automatically? 
Well, optimizers themselves don't face an issue, but as we saw before, schedulers depend on the number of train steps in order to calculate their schedule.\n", - "\n", - "However, if you don't wish to modify the optimizer and scheduler, and prefer to leave them to their default values, that's perfectly alright. The `setup_optimization()` method is automatically called by PyTorch Lightning for you when you begin training your model!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g91FE8mlMcnh" - }, - "source": [ - "## Saving and restoring models\n", - "----------\n", - "\n", - "NeMo provides a few ways to save and restore models. If you utilize the Experiment Manager that is part of all NeMo train scripts, PyTorch Lightning will automatically save checkpoints for you in the experiment directory.\n", - "\n", - "We can also use packaged files using the specialized `save_to` and `restore_from` methods." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NzMxga7QNYn8" - }, - "source": [ - "### Saving and Restoring from PTL Checkpoints\n", - "----------\n", - "\n", - "The PyTorch Lightning Trainer object will periodically save checkpoints when the experiment manager is being used during training.\n", - "\n", - "PyTorch Lightning checkpoints can then be loaded and evaluated / fine-tuned just as always using the class method `load_from_checkpoint`.\n", - "\n", - "For example, restore a QuartzNet model from a checkpoint - \n", - "\n", - "```python\n", - "quartznet = nemo_asr.models.EncDecCTCModel.load_from_checkpoint()\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W4YzAG-KOBkZ" - }, - "source": [ - "### Saving and Restoring from .nemo files\n", - "----------\n", - "\n", - "There are a few models which might require external dependencies to be packaged with them in order to restore them properly.\n", - "\n", - "One such example is an ASR model with an external BPE tokenizer. It is preferred if the model includes all of the components required to restore it, but a binary file for a tokenizer cannot be serialized into a PyTorch Lightning checkpoint.\n", - "\n", - "In such cases, we can use the `save_to` and `restore_from` method to package the entire model + its components (here, the tokenizer file(s)) into a tarfile. This can then be easily imported by the user and used to restore the model." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "P6_vMSwXNJ74" - }, - "outputs": [], - "source": [ - "# Save the model\n", - "quartznet.save_to('quartznet_15x5.nemo')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HrBhgaqyP4rU" - }, - "outputs": [], - "source": [ - "!ls -d -- *.nemo " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Tyht1E0DQGb_" - }, - "outputs": [], - "source": [ - "# Restore the model\n", - "temp_qn = nemo_asr.models.EncDecCTCModel.restore_from('quartznet_15x5.nemo')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dqNpmYYJQS2H" - }, - "outputs": [], - "source": [ - "temp_qn.summarize();" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "A5e42EoiZYjf" - }, - "outputs": [], - "source": [ - "# Note that the preprocessor + optimizer config have been preserved after the changes we made !\n", - "print(OmegaConf.to_yaml(temp_qn.cfg))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OI3RxwpcV-UF" - }, - "source": [ - "Note, that .nemo file is a simple .tar.gz with checkpoint, configuration and, potentially, other artifacts such as tokenizer configs being used by the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jFBAGcaDWLiu" - }, - "outputs": [], - "source": [ - "!cp quartznet_15x5.nemo quartznet_15x5.tar.gz\n", - "!tar -xvf quartznet_15x5.tar.gz" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mkau4Q9jZo1l" - }, - "source": [ - "### Extracting PyTorch checkpoints from NeMo tarfiles (Model level)\n", - "-----------\n", - "\n", - "While the .nemo tarfile is an excellent way to have a portable model, sometimes it is necessary for researchers to have access to the basic PyTorch save format. NeMo aims to be entirely compatible with PyTorch, and therefore offers a simple method to extract just the PyTorch checkpoint from the .nemo tarfile." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qccPANeycCoq" - }, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "A4zswOKHar9q" - }, - "outputs": [], - "source": [ - "state_dict = temp_qn.extract_state_dict_from('quartznet_15x5.nemo', save_dir='./pt_ckpt/')\n", - "!ls ./pt_ckpt/" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ACB-0dfnbFG3" - }, - "source": [ - "As we can see below, there is now a single basic PyTorch checkpoint available inside the `pt_ckpt` directory, which we can use to load the weights of the entire model as below" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4ZAF_A0uc5bB" - }, - "outputs": [], - "source": [ - "temp_qn.load_state_dict(torch.load('./pt_ckpt/model_weights.ckpt'))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hkq6EM99cS6y" - }, - "source": [ - "### Extracting PyTorch checkpoints from NeMo tarfiles (Module level)\n", - "----------\n", - "\n", - "While the above method is exceptional when extracting the checkpoint of the entire model, sometimes there may be a necessity to load and save the individual modules that comprise the Model.\n", - "\n", - "The same extraction method offers a flag to extract the individual model level checkpoints into their individual files, so that users have access to per-module level checkpoints." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LW6wve2zbT9D" - }, - "outputs": [], - "source": [ - "state_dict = temp_qn.extract_state_dict_from('quartznet_15x5.nemo', save_dir='./pt_module_ckpt/', split_by_module=True)\n", - "!ls ./pt_module_ckpt/" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DtV5vpb5d1ni" - }, - "source": [ - "Now, we can load and assign the weights of the individual modules of the above QuartzNet Model !" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rVHylSKFdywn" - }, - "outputs": [], - "source": [ - "temp_qn.preprocessor.load_state_dict(torch.load('./pt_module_ckpt/preprocessor.ckpt'))\n", - "temp_qn.encoder.load_state_dict(torch.load('./pt_module_ckpt/encoder.ckpt'))\n", - "temp_qn.decoder.load_state_dict(torch.load('./pt_module_ckpt/decoder.ckpt'))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88vOGV7VYcuu" - }, - "source": [ - "# NeMo with Hydra\n", - "\n", - "[Hydra](https://hydra.cc/docs/intro/) is used throughout NeMo as a way to enable rapid prototyping using predefined config files. Hydra and OmegaConf offer great compatibility with each other, and below we show a few general helpful tips to improve productivity with Hydra when using NeMo." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DfY6Ha3qYcxG" - }, - "source": [ - "## Hydra Help\n", - "--------\n", - "\n", - "Since our scripts are written with hydra in mind, you might notice that using `python --help` returns you a config rather than the usual help format from argparse. \n", - "\n", - "Using `--help` you can see the default config attached to the script - every NeMo script has at least one default config file attached to it. This gives you a guide on how you can modify values for an experiment.\n", - "\n", - "Hydra also has a special `--hydra-help` flag, which will offer you more help with respect to hydra itself as it is set up in the script.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gEsZlnfaYc3X" - }, - "source": [ - "## Changing config paths and files\n", - "---------\n", - "\n", - "While all NeMo models come with at least 1 default config file, one might want to switch configs without changing code. This is easily achieved by the following commands : \n", - "\n", - "- `--config-path`: Path to the directory which contains the config files\n", - "- `--config-name`: Name of the config file we wish to load.\n", - "\n", - "Note that these two arguments need to be at the very beginning of your execution statement, before you provide any command line overrides to your config file." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZyNHlArpYc9A" - }, - "source": [ - "## Overriding config from the command line\n", - "----------\n", - "\n", - "Hydra allows users to provide command line overrides to any part of the config. There are three cases to consider - \n", - "\n", - " - Override existing value in config\n", - " - Add new value in config\n", - " - Remove old value in config" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "96CKbvn6Yc7f" - }, - "source": [ - "### Overriding existing values in config\n", - "\n", - "Let's take the case where we want to change the optimizer from `novograd` to `adam`. Let's also change the beta values to default adam values.\n", - "\n", - "Hydra overrides are based on the `.` syntax - each `.` representing a level in the config itself.\n", - "\n", - "```sh\n", - "$ python