Transformer-based Text Normalization Models #2415

Merged
56 commits merged on Jul 8, 2021

Commits
e60f018
Add notebook with recommendations for 8 kHz speech (#2326)
jbalam-nv Jun 18, 2021
b3c1b01
Add FastEmit support for RNNT Losses (#2374)
titu1994 Jun 22, 2021
7461bc4
Implement inference functions of TN models
Jun 23, 2021
d11bfc8
Minor Fix
Jun 24, 2021
21169a3
fix bugs in hifigan code (#2392)
Oktai15 Jun 23, 2021
f81dfc2
Update setup.py (#2394)
blisc Jun 23, 2021
2db9c4b
update checkpointing (#2396)
blisc Jun 23, 2021
3e76624
byt5 unicode implementation (#2365)
mchrzanowski Jun 24, 2021
af49749
Minor Fix
Jun 24, 2021
6fdc83f
Minor Fixes
Jun 24, 2021
02a1943
Add TextNormalizationTestDataset and testing/evaluation code
Jun 24, 2021
690dca3
Add TextNormalizationTaggerDataset and training code for tagger
Jun 25, 2021
4eeb6be
Restore from local nemo ckpts
Jun 25, 2021
d642381
Add TextNormalizationDecoderDataset
Jun 25, 2021
e7f1a3f
Add interactive mode for neural_text_normalization_test.py
Jun 25, 2021
2775d62
Add options to do training or not for tagger/decoder
Jun 25, 2021
cf13111
Renamed
Jun 25, 2021
85c1417
Implemented setup dataloader for decoder
Jun 25, 2021
7aeba27
Implemented training and validation for decoder
Jun 25, 2021
7bfa8de
Data augmentation for decoder training
Jun 25, 2021
8e370f0
Config change
Jun 25, 2021
85f1e3c
add blossom-ci.yml (#2401)
ericharper Jun 25, 2021
5d93237
Merge r1.1 bugfixes into main (#2407)
ericharper Jun 25, 2021
42ad961
Remove unused imports
Jun 28, 2021
b4ffb28
Add initial doc for text_normalization
Jun 28, 2021
d83ccf7
Fixed imports warnings
Jun 28, 2021
cff1dee
Minor Fix
Jun 28, 2021
1ab0583
Renamed
Jun 28, 2021
617fa9f
Allowed duplex modes
Jun 28, 2021
96e58b9
Minor Fix
Jun 29, 2021
ec26b7e
Add docs for duplex_text_normalization_train and duplex_text_normaliz…
Jun 29, 2021
9e3e825
docstrings for model codes + minor fix
Jun 29, 2021
f5b566c
Add more comments and doc strings
Jun 29, 2021
8859f1d
Merge branch 'NVIDIA:main' into neural_tn
laituan245 Jun 29, 2021
e0b036b
Add doc for datasets + Use time.perf_counter()
Jun 30, 2021
780ed53
Add code for preprocessing Google TN data
Jun 30, 2021
8580b9c
Add more docs and comments + Minor Fixes
Jun 30, 2021
8f38766
Add more licenses + Fixed comments + Minors
Jun 30, 2021
c68e1a3
Moved evaluation logic to DuplexTextNormalizationModel
Jul 1, 2021
b106c02
Add logging errors
Jul 1, 2021
2cd3b43
Updated validation code of tagger + Minors
Jul 1, 2021
ab5b1b9
Also write tag preds to log file
Jul 1, 2021
fed32a8
Add data augmentation for tagger dataset
Jul 1, 2021
0eaee3a
Added experimental decorators
Jul 1, 2021
c2def6c
Updated docs
Jul 1, 2021
6f47f2b
Updated duplex_tn_config.yaml
Jul 2, 2021
5360a48
Compute token precision of tagger using NeMo metrics
Jul 2, 2021
aafef49
Fixed saving issue when using ddp accelerator
Jul 2, 2021
fb2d7e7
Refactoring
Jul 3, 2021
fcf921e
Merge branch 'NVIDIA:main' into neural_tn
laituan245 Jul 3, 2021
55f9db0
Add option to keep punctuations in TextNormalizationTestDataset
Jul 4, 2021
381318c
Merge branch 'neural_tn' of github.com:laituan245/NeMo into neural_tn
Jul 4, 2021
722bb9e
Changes to input preprocessing + decoder's postprocessing
Jul 4, 2021
f15dad2
Fixed styles + Add references
Jul 7, 2021
a6e8cdc
Renamed examples/nlp/duplex_text_normalization/utils.py to helpers.py
Jul 7, 2021
f7d390b
Merge branch 'main' into neural_tn
okuchaiev Jul 8, 2021
146 changes: 146 additions & 0 deletions docs/source/nlp/text_normalization.rst
@@ -0,0 +1,146 @@
.. _text_normalization:

Text Normalization Models
==========================
Text normalization is the task of converting written text into its spoken form. For example,
``$123`` should be verbalized as ``one hundred twenty three dollars``, while ``123 King Ave``
should be verbalized as ``one twenty three King Avenue``. The inverse problem, inverse text
normalization, converts a spoken sequence (e.g., an ASR output) into its written form.

NeMo provides a neural-based system that can handle both text normalization (TN) and
inverse text normalization (ITN). At a high level, the system consists of two
components:

- `DuplexTaggerModel <https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py/>`__ - a Transformer-based tagger for identifying "semiotic" spans in the input (e.g., spans that are about times, dates, or monetary amounts).
- `DuplexDecoderModel <https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py/>`__ - a Transformer-based seq2seq model for decoding the semiotic spans into their appropriate forms (e.g., spoken forms for TN and written forms for ITN).

The typical workflow is to first train a DuplexTaggerModel and a DuplexDecoderModel. An example training script
is provided: `duplex_text_normalization_train.py <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py>`__.
After that, the two trained models can be used to initialize a `DuplexTextNormalizationModel <https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py/>`__ that can be used for end-to-end inference.
An example script for evaluation and inference is provided: `duplex_text_normalization_test.py <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/duplex_text_normalization/duplex_text_normalization_test.py>`__. The term
*duplex* refers to the fact that our system can be trained to do both TN and ITN. However, you can also specifically train the system for only one of the tasks.
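
Once both models are trained, they can be restored and composed for inference. The following is a
minimal sketch (not a supported entry point; it assumes the ``.nemo`` paths produced by the default
``duplex_tn_config.yaml`` and that the three classes are importable from ``nemo.collections.nlp.models``);
the ``duplex_text_normalization_test.py`` script mentioned above is the reference implementation.

.. code::

    # Hedged sketch: restore the two trained models and run end-to-end inference.
    import nemo.collections.nlp.data.text_normalization.constants as constants
    from nltk import word_tokenize
    from nemo.collections.nlp.models import (
        DuplexDecoderModel,
        DuplexTaggerModel,
        DuplexTextNormalizationModel,
    )

    # Paths follow the default nemo_path values in duplex_tn_config.yaml
    tagger = DuplexTaggerModel.restore_from('exps/tagger_model.nemo')
    decoder = DuplexDecoderModel.restore_from('exps/decoder_model.nemo')
    tn_model = DuplexTextNormalizationModel(tagger, decoder)

    # Mirror the interactive mode of duplex_text_normalization_test.py:
    # tokenize the input, then run both directions on the same sentence.
    sent = ' '.join(word_tokenize('I live in 123 King Ave'))
    itn_out, tn_out = tn_model._infer(
        [sent, sent], [constants.INST_BACKWARD, constants.INST_FORWARD]
    )[-1]
    print(tn_out)  # predicted spoken form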

NeMo Data Format
----------------
Both the DuplexTaggerModel and the DuplexDecoderModel use the same simple text format
for the dataset. The data needs to be stored in TAB-separated files (``.tsv``) with three columns.
The first column is the "semiotic class" (e.g., numbers, times, dates), the second is the token
in written form, and the third is the spoken form. An example sentence in the dataset is shown below.
In the example, ``sil`` denotes that a token is a punctuation mark, while ``self`` denotes that the spoken form is the
same as the written form. It is expected that a complete dataset contains three files: ``train.tsv``, ``dev.tsv``,
and ``test.tsv``.

.. code::

    PLAIN   The          <self>
    PLAIN   company 's   <self>
    PLAIN   revenues     <self>
    PLAIN   grew         <self>
    PLAIN   four         <self>
    PLAIN   fold         <self>
    PLAIN   between      <self>
    DATE    2005         two thousand five
    PLAIN   and          <self>
    DATE    2008         two thousand eight
    PUNCT   .            sil
    <eos>   <eos>
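
As an illustration of the format (this is not the NeMo data loader, just a sketch; the helper name and
the handling of ``sil``/``<self>`` follow the example above), the rows of such a file can be grouped
back into written/spoken sentence pairs as follows:

.. code::

    # Purely illustrative: group NeMo-format .tsv rows into (written, spoken) pairs.
    # Assumes the '<eos>' row separates sentences, as in the example above.
    def read_sentences(tsv_path):
        written, spoken = [], []
        with open(tsv_path, encoding='utf-8') as f:
            for line in f:
                parts = line.rstrip('\n').split('\t')
                if parts[0] == '<eos>':
                    yield ' '.join(written), ' '.join(spoken)
                    written, spoken = [], []
                    continue
                semiotic_class, written_tok, spoken_tok = parts
                written.append(written_tok)
                if spoken_tok == '<self>':    # spoken form equals the written form
                    spoken.append(written_tok)
                elif spoken_tok != 'sil':     # 'sil' marks punctuation; dropped here
                    spoken.append(spoken_tok)

    for written_sent, spoken_sent in read_sentences('train.tsv'):
        print(written_sent, '->', spoken_sent)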


An example script for generating a dataset in this format from the `Google text normalization dataset <https://www.kaggle.com/google-nlu/text-normalization>`_
can be found at `NeMo/examples/nlp/duplex_text_normalization/google_data_preprocessing.py <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/duplex_text_normalization/google_data_preprocessing.py>`__.
Note that the script also does some preprocessing on the spoken forms of URLs. For example,
given the URL "Zimbio.com", the original expected spoken form in the Google dataset is
"z_letter i_letter m_letter b_letter i_letter o_letter dot c_letter o_letter m_letter".
However, our script returns a more concise output, namely "zim bio dot com".

Model Training
--------------

An example training script is provided: `duplex_text_normalization_train.py <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py>`__.
The config file used for the example is at `duplex_tn_config.yaml <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml>`__.
You can change any of these parameters directly from the config file or update them with the command-line arguments.

The config file contains three main sections: the first configures the tagger, the second configures the decoder,
and the last describes the dataset. Most arguments in the example config file are self-explanatory (e.g.,
*decoder_model.optim.lr* refers to the learning rate for training the decoder). Most of the hyper-parameters are set to
values that we found to be effective. Some arguments that you may want to modify are:

- *data.base_dir*: The path to the dataset directory. It is expected that the directory contains three files: train.tsv, dev.tsv, and test.tsv.

- *tagger_model.nemo_path*: This is the path where the final trained tagger model will be saved to.

- *decoder_model.nemo_path*: This is the path where the final trained decoder model will be saved to.

Example of a training command:

.. code::

    python examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py \
        data.base_dir=<PATH_TO_DATASET_DIR> \
        mode={tn,itn,joint}

There are three different modes: ``tn`` trains a system for TN only, ``itn`` trains a system for
ITN only, and ``joint`` trains a system that can do both TN and ITN. Note that the above command
first trains a tagger and then trains a decoder sequentially.

You can also train only a tagger (without training a decoder) by running the
following command:

.. code::

    python examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py \
        data.base_dir=PATH_TO_DATASET_DIR \
        mode={tn,itn,joint} \
        decoder_model.do_training=false

Or you can also train only a decoder (without training a tagger):

.. code::

    python examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py \
        data.base_dir=PATH_TO_DATASET_DIR \
        mode={tn,itn,joint} \
        tagger_model.do_training=false


Model Architecture
------------------

The tagger model first uses a Transformer encoder (e.g., DistilRoBERTa) to build a
contextualized representation for each input token. It then uses a classification head
to predict the tag for each token (e.g., if a token should stay the same, its tag should
be ``SAME``). The decoder model then takes the semiotic spans identified by the tagger and
transforms them into the appropriate forms (e.g., spoken forms for TN and written forms for ITN).
The decoder model is essentially a Transformer-based encoder-decoder seq2seq model (e.g., the example
training script uses the T5-base model by default).
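
As a rough illustration of these two components (this is only a sketch using the Hugging Face
``transformers`` API, not the NeMo classes; the model names match the defaults in
``duplex_tn_config.yaml``, and the number of tag labels is an assumption):

.. code::

    # Hedged sketch of the two-stage architecture, not the NeMo implementation.
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoModelForTokenClassification,
        AutoTokenizer,
    )

    # Tagger: pretrained encoder with a token-classification head on top
    tagger_tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
    tagger = AutoModelForTokenClassification.from_pretrained(
        'distilroberta-base', num_labels=5  # label-set size is illustrative
    )

    # Decoder: Transformer-based encoder-decoder seq2seq model (T5-base by default)
    decoder_tokenizer = AutoTokenizer.from_pretrained('t5-base')
    decoder = AutoModelForSeq2SeqLM.from_pretrained('t5-base')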

We introduce a simple but effective technique to allow our model to be duplex. Depending on the
task the model is handling, we prepend the appropriate prefix to the input. For example, suppose
we want to transform the text ``I live in 123 King Ave`` to its spoken form (i.e., the TN problem);
then we simply prepend the prefix ``tn``, so the final input to our models is
``tn I live in 123 King Ave``. Similarly, for the ITN problem, we prepend the prefix ``itn``
to the input.
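
A minimal sketch of this prefixing step (the helper below is purely illustrative and not part of
NeMo; the example test script selects the direction via constants such as
``constants.INST_FORWARD`` and ``constants.INST_BACKWARD``):

.. code::

    # Purely illustrative helper: prepend the task prefix to the input text.
    def add_task_prefix(text: str, task: str) -> str:
        assert task in ('tn', 'itn')
        return f'{task} {text}'

    print(add_task_prefix('I live in 123 King Ave', 'tn'))
    # tn I live in 123 King Ave
    print(add_task_prefix('I live in one twenty three King Ave', 'itn'))
    # itn I live in one twenty three King Ave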

To improve the effectiveness and robustness of our models, we also apply some simple data
augmentation techniques during training.

Data Augmentation for Training DuplexTaggerModel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the Google English TN training data, about 93% of the tokens are not in any semiotic span. In other words, the ground-truth tags of most tokens are of trivial types (i.e., ``SAME`` and ``PUNCT``). To alleviate this class imbalance problem,
for each original instance with several semiotic spans, we create a new instance by simply concatenating all the semiotic spans together. For example, consider the following ITN instance:

Original instance: ``[The|SAME] [revenues|SAME] [grew|SAME] [a|SAME] [lot|SAME] [between|SAME] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [two|I-TRANSFORM] [and|SAME] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [five|I-TRANSFORM] [.|PUNCT]``

Augmented instance: ``[two|B-TRANSFORM] [thousand|I-TRANSFORM] [two|I-TRANSFORM] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [five|I-TRANSFORM]``

The argument ``data.train_ds.tagger_data_augmentation`` in the config file controls whether this data augmentation will be enabled or not.
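
A minimal sketch of this augmentation (illustrative only, not the NeMo dataset code; the tag names
follow the example above):

.. code::

    # Purely illustrative: build an extra training instance that keeps only the
    # semiotic (B-/I-TRANSFORM) spans of an original tagged instance.
    def augment_tagger_instance(tokens, tags):
        aug_tokens, aug_tags = [], []
        for tok, tag in zip(tokens, tags):
            if tag in ('B-TRANSFORM', 'I-TRANSFORM'):
                aug_tokens.append(tok)
                aug_tags.append(tag)
        return aug_tokens, aug_tags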


Data Augmentation for Training DuplexDecoderModel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Since the tagger may not be perfect, the inputs to the decoder may not all be semiotic spans. Therefore, to make the decoder more robust to the tagger's potential errors,
we train the decoder not only on semiotic spans but also on other, more "noisy" spans. This way, even if the tagger makes some errors, there is still a chance that the
final output is correct.

The argument ``data.train_ds.decoder_data_augmentation`` in the config file controls whether this data augmentation will be enabled or not.
136 changes: 136 additions & 0 deletions examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml
@@ -0,0 +1,136 @@
name: &name DuplexTextNormalization
mode: joint # Three possible choices ['tn', 'itn', 'joint']

# Pretrained Nemo Models
tagger_pretrained_model: null
decoder_pretrained_model: null

# Tagger
tagger_trainer:
  gpus: 1 # the number of gpus, 0 for CPU
  num_nodes: 1
  max_epochs: 5 # the number of training epochs
  checkpoint_callback: false # provided by exp_manager
  logger: false # provided by exp_manager
  accumulate_grad_batches: 1 # accumulates grads every k batches
  gradient_clip_val: 0.0
  amp_level: O0 # O1/O2 for mixed precision
  precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
  accelerator: ddp

tagger_model:
  do_training: true
  transformer: distilroberta-base
  tokenizer: ${tagger_model.transformer}
  nemo_path: ${tagger_exp_manager.exp_dir}/tagger_model.nemo # exported .nemo path

  optim:
    name: adamw
    lr: 5e-5
    weight_decay: 0.01

    sched:
      name: WarmupAnnealing

      # pytorch lightning args
      monitor: val_token_precision
      reduce_on_plateau: false

      # scheduler config override
      warmup_steps: null
      warmup_ratio: 0.1
      last_epoch: -1

tagger_exp_manager:
  exp_dir: exps # where to store logs and checkpoints
  name: tagger_training # name of experiment
  create_tensorboard_logger: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    save_top_k: 3
    monitor: "val_token_precision"
    mode: "max"
    save_best_model: true
    always_save_nemo: true

# Decoder
decoder_trainer:
  gpus: 1 # the number of gpus, 0 for CPU
  num_nodes: 1
  max_epochs: 3 # the number of training epochs
  checkpoint_callback: false # provided by exp_manager
  logger: false # provided by exp_manager
  accumulate_grad_batches: 1 # accumulates grads every k batches
  gradient_clip_val: 0.0
  amp_level: O0 # O1/O2 for mixed precision
  precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
  accelerator: ddp

decoder_model:
  do_training: true
  transformer: t5-base
  tokenizer: ${decoder_model.transformer}
  nemo_path: ${decoder_exp_manager.exp_dir}/decoder_model.nemo # exported .nemo path

  optim:
    name: adamw
    lr: 2e-4
    weight_decay: 0.01

    sched:
      name: WarmupAnnealing

      # pytorch lightning args
      monitor: val_loss
      reduce_on_plateau: false

      # scheduler config override
      warmup_steps: null
      warmup_ratio: 0.0
      last_epoch: -1

decoder_exp_manager:
  exp_dir: exps # where to store logs and checkpoints
  name: decoder_training # name of experiment
  create_tensorboard_logger: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    save_top_k: 3
    monitor: "val_loss"
    mode: "min"
    save_best_model: true
    always_save_nemo: true

# Data
data:
  base_dir: ??? # /path/to/data

  train_ds:
    data_path: ${data.base_dir}/train.tsv
    batch_size: 64
    shuffle: true
    do_basic_tokenize: false
    max_decoder_len: 80
    mode: ${mode}
    # Refer to the text_normalization doc for more information about data augmentation
    tagger_data_augmentation: true
    decoder_data_augmentation: true

  validation_ds:
    data_path: ${data.base_dir}/dev.tsv
    batch_size: 64
    shuffle: false
    do_basic_tokenize: false
    max_decoder_len: 80
    mode: ${mode}

  test_ds:
    data_path: ${data.base_dir}/test.tsv
    batch_size: 64
    shuffle: false
    mode: ${mode}

# Inference
inference:
  interactive: false # Set to true if you want to enable the interactive mode when running duplex_text_normalization_test.py
  errors_log_fp: errors.txt # Path to the file for logging the errors
91 changes: 91 additions & 0 deletions examples/nlp/duplex_text_normalization/duplex_text_normalization_test.py
@@ -0,0 +1,91 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script contains an example of how to evaluate a DuplexTextNormalizationModel.
Note that DuplexTextNormalizationModel is essentially a wrapper class around
DuplexTaggerModel and DuplexDecoderModel. Therefore, two trained NeMo models
should be specified before evaluation (one is a trained DuplexTaggerModel
and the other is a trained DuplexDecoderModel).

USAGE Example:
1. Obtain a processed test data file (refer to the `text_normalization doc <https://github.com/NVIDIA/NeMo/blob/main/docs/source/nlp/text_normalization.rst>`)
2.
# python duplex_text_normalization_test.py
    tagger_pretrained_model=PATH_TO_TRAINED_TAGGER
    decoder_pretrained_model=PATH_TO_TRAINED_DECODER
    data.test_ds.data_path=PATH_TO_TEST_FILE
    mode={tn,itn,joint}

The script also supports the `interactive` mode where a user can just make the model
run on any input text:
# python duplex_text_normalization_test.py
    tagger_pretrained_model=PATH_TO_TRAINED_TAGGER
    decoder_pretrained_model=PATH_TO_TRAINED_DECODER
    mode={tn,itn,joint}
    inference.interactive=true

This script uses the `/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml`
config file by default. Another option is to specify a different config file via the
command-line argument `--config-name=CONFIG_FILE_PATH`.

Note that when evaluating a DuplexTextNormalizationModel on a labeled dataset,
the script will automatically generate a file for logging the errors made
by the model. The location of this file is determined by the argument
`inference.errors_log_fp`.

"""


import nemo.collections.nlp.data.text_normalization.constants as constants

from nltk import word_tokenize
from omegaconf import DictConfig, OmegaConf
from utils import TAGGER_MODEL, DECODER_MODEL, instantiate_model_and_trainer

from nemo.utils import logging
from nemo.core.config import hydra_runner
from nemo.collections.nlp.models import DuplexTextNormalizationModel
from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset

@hydra_runner(config_path="conf", config_name="duplex_tn_config")
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model)

    if not cfg.inference.interactive:
        # Setup test_dataset
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.data.test_ds.mode)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.inference.errors_log_fp)
        print(f'\nTest results: {results}')
    else:
        while True:
            test_input = input('Input a test input:')
            test_input = ' '.join(word_tokenize(test_input))
            outputs = tn_model._infer(
                [test_input, test_input], [constants.INST_BACKWARD, constants.INST_FORWARD]
            )[-1]
            print(f'Prediction (ITN): {outputs[0]}')
            print(f'Prediction (TN): {outputs[1]}')

            should_continue = input('\nContinue (y/n): ').strip().lower()
            if should_continue.startswith('n'):
                break


if __name__ == '__main__':
    main()