Add EMA Docs, fix common collection documentation #5757

Merged · merged 5 commits · Jan 17, 2023
Changes from 1 commit
56 changes: 56 additions & 0 deletions docs/source/common/callbacks.rst
@@ -0,0 +1,56 @@
*********
Callbacks
*********

Exponential Moving Average (EMA)
================================

During training, EMA maintains a moving average of the trained parameters.
EMA parameters can produce significantly better results and faster convergence for a variety of domains and models.

EMA is a simple calculation: the EMA weights are initialized from the model weights at the start of training.

After every training update, the EMA weights are updated from the new model weights:

.. math::

    ema_w = ema_w \times decay + model_w \times (1 - decay)
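For illustration, the update rule can be written as a few lines of PyTorch-style Python. This is only a minimal sketch of the formula above, not the NeMo implementation, and the function and argument names are hypothetical.

.. code-block:: python

    import torch

    @torch.no_grad()
    def ema_update(ema_params, model_params, decay=0.999):
        # ema_w = ema_w * decay + model_w * (1 - decay), applied per parameter tensor
        for ema_w, model_w in zip(ema_params, model_params):
            ema_w.mul_(decay).add_(model_w, alpha=1.0 - decay)

Here ``ema_params`` is a separate, detached copy of the model parameters that is kept alongside the trained weights and updated after every optimizer step.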

Enabling EMA is straightforward. We can pass the additional argument to the experiment manager at runtime.

.. code-block:: bash

    python examples/asr/asr_ctc/speech_to_text_ctc.py \
        model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
        model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
        trainer.devices=2 \
        trainer.accelerator='gpu' \
        trainer.max_epochs=50 \
        exp_manager.ema.enable=True  # pass this additional argument to enable EMA

To change the decay rate, pass the additional ``exp_manager.ema.decay`` argument.

.. code-block:: bash

    python examples/asr/asr_ctc/speech_to_text_ctc.py \
        model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
        model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
        trainer.devices=2 \
        trainer.accelerator='gpu' \
        trainer.max_epochs=50 \
        exp_manager.ema.enable=True \
        exp_manager.ema.decay=0.999

We also offer other helpful arguments; a short sketch of how they behave follows the example.

.. code-block:: bash

    # validate_original_weights: validate the original weights instead of the EMA weights.
    # every_n_steps: apply EMA every N steps instead of every step.
    # cpu_offload: offload the EMA weights to CPU. May introduce significant slow-downs.
    python examples/asr/asr_ctc/speech_to_text_ctc.py \
        model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
        model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
        trainer.devices=2 \
        trainer.accelerator='gpu' \
        trainer.max_epochs=50 \
        exp_manager.ema.enable=True \
        exp_manager.ema.validate_original_weights=True \
        exp_manager.ema.every_n_steps=2 \
        exp_manager.ema.cpu_offload=True
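To make the behaviour of these options concrete, the following is a rough, self-contained sketch of an EMA helper that applies the update every N steps and optionally keeps its copy of the weights on CPU. It is illustrative only and does not mirror the actual NeMo ``EMA`` callback; the class and argument names are hypothetical.

.. code-block:: python

    import torch

    class SimpleEMA:
        """Illustrative sketch: keeps an exponential moving average of a model's parameters."""

        def __init__(self, model, decay=0.999, every_n_steps=1, cpu_offload=False):
            device = torch.device("cpu") if cpu_offload else next(model.parameters()).device
            # A detached copy of the weights; kept on CPU when cpu_offload=True to save GPU memory.
            self.ema_params = [p.detach().clone().to(device) for p in model.parameters()]
            self.decay = decay
            self.every_n_steps = every_n_steps
            self._step = 0

        @torch.no_grad()
        def update(self, model):
            self._step += 1
            if self._step % self.every_n_steps != 0:
                return  # apply EMA every N steps instead of every step
            for ema_w, model_w in zip(self.ema_params, model.parameters()):
                ema_w.mul_(self.decay).add_(model_w.detach().to(ema_w.device), alpha=1.0 - self.decay)

Offloading the EMA copy to CPU saves GPU memory at the cost of a host-device transfer on every update, which is why it may introduce significant slow-downs.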
39 changes: 6 additions & 33 deletions docs/source/common/intro.rst
@@ -3,37 +3,10 @@ Common Collection

The common collection contains components that can be used across all collections.

Tokenizers
----------
.. automodule:: nemo.collections.common.tokenizers.AutoTokenizer
:special-members: __init__
.. automodule:: nemo.collections.common.tokenizers.SentencePieceTokenizer
:special-members: __init__
.. automodule:: nemo.collections.common.tokenizers.TokenizerSpec
:special-members: __init__
.. toctree::
   :maxdepth: 8


Losses
------
.. automodule:: nemo.collections.common.losses.AggregatorLoss
:special-members: __init__

.. automodule:: nemo.collections.common.losses.CrossEntropyLoss
:special-members: __init__

.. automodule:: nemo.collections.common.losses.MSELoss
:special-members: __init__

.. automodule:: nemo.collections.common.losses.SmoothedCrossEntropyLoss
:special-members: __init__
.. automodule:: nemo.collections.common.losses.SpanningLoss
:special-members: __init__


Metrics
-------

.. autoclass:: nemo.collections.common.metrics.Perplexity
:show-inheritance:
:members:
:undoc-members:
   callbacks
   losses
   metrics
   tokenizers
16 changes: 16 additions & 0 deletions docs/source/common/losses.rst
@@ -0,0 +1,16 @@
Losses
------
.. autoclass:: nemo.collections.common.losses.AggregatorLoss
    :special-members: __init__

.. autoclass:: nemo.collections.common.losses.CrossEntropyLoss
    :special-members: __init__

.. autoclass:: nemo.collections.common.losses.MSELoss
    :special-members: __init__

.. autoclass:: nemo.collections.common.losses.SmoothedCrossEntropyLoss
    :special-members: __init__

.. autoclass:: nemo.collections.common.losses.SpanningLoss
    :special-members: __init__
7 changes: 7 additions & 0 deletions docs/source/common/metrics.rst
@@ -0,0 +1,7 @@
Metrics
-------

.. autoclass:: nemo.collections.common.metrics.Perplexity
    :show-inheritance:
    :members:
    :undoc-members:
8 changes: 8 additions & 0 deletions docs/source/common/tokenizers.rst
@@ -0,0 +1,8 @@
Tokenizers
----------
.. autoclass:: nemo.collections.common.tokenizers.AutoTokenizer
    :special-members: __init__

.. autoclass:: nemo.collections.common.tokenizers.SentencePieceTokenizer
    :special-members: __init__

.. autoclass:: nemo.collections.common.tokenizers.TokenizerSpec
    :special-members: __init__
2 changes: 1 addition & 1 deletion nemo/collections/common/callbacks/ema.py
@@ -19,9 +19,9 @@

import pytorch_lightning as pl
import torch
from lightning_utilities.core.rank_zero import rank_zero_info
from pytorch_lightning import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info


class EMA(Callback):
1 change: 1 addition & 0 deletions requirements/requirements_docs.txt
@@ -1,3 +1,4 @@
Jinja2<3.1
@SeanNaren (Collaborator, Author) · Jan 9, 2023

When running bash update_docs_docker.sh in the docs folder, I was getting this error: readthedocs/readthedocs.org#9038

To fix this, I pinned the requirement for the docs build.

Collaborator

Wow, it's still an issue after all this time?? I could run it last year without this...

@SeanNaren (Collaborator, Author) · Jan 12, 2023

Tried again, but this requirement is still needed; since it builds in a Docker image, I'll assume it would be the same for everyone else.
latexcodec
numpy
sphinx>=3.0