diff --git a/docs/source/nlp/models.rst b/docs/source/nlp/models.rst index c6b9adbad02a..7218525b29e7 100755 --- a/docs/source/nlp/models.rst +++ b/docs/source/nlp/models.rst @@ -20,3 +20,4 @@ NeMo's NLP collection supports the following models: information_retrieval nlp_model machine_translation + text_normalization diff --git a/docs/source/nlp/nlp_all.bib b/docs/source/nlp/nlp_all.bib index f291310941e4..a7084fe78159 100644 --- a/docs/source/nlp/nlp_all.bib +++ b/docs/source/nlp/nlp_all.bib @@ -100,10 +100,26 @@ @article{post2018call } @misc{zhang2021sgdqa, - title={SGD-QA: Fast Schema-Guided Dialogue State Tracking for Unseen Services}, + title={SGD-QA: Fast Schema-Guided Dialogue State Tracking for Unseen Services}, author={Yang Zhang and Vahid Noroozi and Evelina Bakhturina and Boris Ginsburg}, year={2021}, eprint={2105.08049}, archivePrefix={arXiv}, primaryClass={cs.CL} -} \ No newline at end of file +} + +@article{Sproat2016RNNAT, + title={RNN Approaches to Text Normalization: A Challenge}, + author={R. Sproat and Navdeep Jaitly}, + journal={ArXiv}, + year={2016}, + volume={abs/1611.00068} +} + +@article{Zhang2019NeuralMO, + title={Neural Models of Text Normalization for Speech Applications}, + author={Hao Zhang and R. Sproat and Axel H. Ng and Felix Stahlberg and Xiaochang Peng and Kyle Gorman and B. Roark}, + journal={Computational Linguistics}, + year={2019}, + pages={293-338} +} diff --git a/docs/source/nlp/text_normalization.rst b/docs/source/nlp/text_normalization.rst new file mode 100644 index 000000000000..f3871c3115ca --- /dev/null +++ b/docs/source/nlp/text_normalization.rst @@ -0,0 +1,159 @@ +.. _text_normalization: + +Text Normalization Models +========================== +Text normalization is the task of converting a written text into its spoken form. For example, +``$123`` should be verbalized as ``one hundred twenty three dollars``, while ``123 King Ave`` +should be verbalized as ``one twenty three King Avenue``. At the same time, the inverse problem +is about converting a spoken sequence (e.g., an ASR output) into its written form. + +NeMo has an implementation that allows you to build a neural-based system that is able to do +both text normalization (TN) and also inverse text normalization (ITN). At a high level, the +system consists of two individual components: + +- `DuplexTaggerModel `__ - a Transformer-based tagger for identifying "semiotic" spans in the input (e.g., spans that are about times, dates, or monetary amounts). +- `DuplexDecoderModel `__ - a Transformer-based seq2seq model for decoding the semiotic spans into their appropriate forms (e.g., spoken forms for TN and written forms for ITN). + +The typical workflow is to first train a DuplexTaggerModel and also a DuplexDecoderModel. An example training script +is provided: `duplex_text_normalization_train.py `__. +After that, the two trained models can be used to initialize a `DuplexTextNormalizationModel `__ that can be used for end-to-end inference. +An example script for evaluation and inference is provided: `duplex_text_normalization_test.py `__. The term +*duplex* refers to the fact that our system can be trained to do both TN and ITN. However, you can also specifically train the system for only one of the tasks. + +NeMo Data Format +----------- +Both the DuplexTaggerModel model and the DuplexDecoderModel model use the same simple text format +as the dataset. The data needs to be stored in TAB separated files (``.tsv``) with three columns. 
+The first of which is the "semiotic class" (e.g., numbers, times, dates) , the second is the token +in written form, and the third is the spoken form. An example sentence in the dataset is shown below. +In the example, ``sil`` denotes that a token is a punctuation while ``self`` denotes that the spoken form is the +same as the written form. It is expected that a complete dataset contains three files: ``train.tsv``, ``dev.tsv``, +and ``test.tsv``. + +.. code:: + + PLAIN The + PLAIN company 's + PLAIN revenues + PLAIN grew + PLAIN four + PLAIN fold + PLAIN between + DATE 2005 two thousand five + PLAIN and + DATE 2008 two thousand eight + PUNCT . sil + + + +An example script for generating a dataset in this format from the `Google text normalization dataset `_ +can be found at `NeMo/examples/nlp/duplex_text_normalization/google_data_preprocessing.py `__. +Note that the script also does some preprocessing on the spoken forms of the URLs. For example, +given the URL "Zimbio.com", the original expected spoken form in the Google dataset is +"z_letter i_letter m_letter b_letter i_letter o_letter dot c_letter o_letter m_letter". +However, our script will return a more concise output which is "zim bio dot com". + +More information about the Google text normalization dataset can be found in the paper `RNN Approaches to Text Normalization: A Challenge `__ :cite:`nlp-textnorm-Sproat2016RNNAT`. + + +Model Training +-------------- + +An example training script is provided: `duplex_text_normalization_train.py `__. +The config file used for the example is at `duplex_tn_config.yaml `__. +You can change any of these parameters directly from the config file or update them with the command-line arguments. + +The config file contains three main sections. The first section contains the configs for the tagger, the second section is about the decoder, +and the last section is about the dataset. Most arguments in the example config file are quite self-explanatory (e.g., +*decoder_model.optim.lr* refers to the learning rate for training the decoder). We have set most of the hyper-parameters to +be the values that we found to be effective. Some arguments that you may want to modify are: + +- *data.base_dir*: The path to the dataset directory. It is expected that the directory contains three files: train.tsv, dev.tsv, and test.tsv. + +- *tagger_model.nemo_path*: This is the path where the final trained tagger model will be saved to. + +- *decoder_model.nemo_path*: This is the path where the final trained decoder model will be saved to. + +Example of a training command: + +.. code:: + + python examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py \ + data.base_dir= \ + mode={tn,itn,joint} + +There are 3 different modes. "tn" mode is for training a system for TN only. +"itn" mode is for training a system for ITN. "joint" is for training a system +that can do both TN and ITN at the same time. Note that the above command will +first train a tagger and then train a decoder sequentially. + +You can also train only a tagger (without training a decoder) by running the +following command: + +.. code:: + + python examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py \ + data.base_dir=PATH_TO_DATASET_DIR \ + mode={tn,itn,joint} \ + decoder_model.do_training=false + +Or you can also train only a decoder (without training a tagger): + +.. 
code:: + + python examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py \ + data.base_dir=PATH_TO_DATASET_DIR \ + mode={tn,itn,joint} \ + tagger_model.do_training=false + + +Model Architecture +-------------- + +The tagger model first uses a Transformer encoder (e.g., DistilRoBERTa) to build a +contextualized representation for each input token. It then uses a classification head +to predict the tag for each token (e.g., if a token should stay the same, its tag should +be ``SAME``). The decoder model then takes the semiotic spans identified by the tagger and +transform them into the appropriate forms (e.g., spoken forms for TN and written forms for ITN). +The decoder model is essentially a Transformer-based encoder-decoder seq2seq model (e.g., the example +training script uses the T5-base model by default). Overall, our design is partly inspired by the +RNN-based sliding window model proposed in the paper +`Neural Models of Text Normalization for Speech Applications `__ :cite:`nlp-textnorm-Zhang2019NeuralMO`. + +We introduce a simple but effective technique to allow our model to be duplex. Depending on the +task the model is handling, we append the appropriate prefix to the input. For example, suppose +we want to transform the text ``I live in 123 King Ave`` to its spoken form (i.e., TN problem), +then we will simply append the prefix ``tn`` to it and so the final input to our models will actually +be ``tn I live in tn 123 King Ave``. Similarly, for the ITN problem, we just append the prefix ``itn`` +to the input. + +To improve the effectiveness and robustness of our models, we also apply some simple data +augmentation techniques during training. + +Data Augmentation for Training DuplexTaggerModel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +In the Google English TN training data, about 93% of the tokens are not in any semiotic span. In other words, the ground-truth tags of most tokens are of trivial types (i.e., ``SAME`` and ``PUNCT``). To alleviate this class imbalance problem, +for each original instance with several semiotic spans, we create a new instance by simply concatenating all the semiotic spans together. For example, considering the following ITN instance: + +Original instance: ``[The|SAME] [revenues|SAME] [grew|SAME] [a|SAME] [lot|SAME] [between|SAME] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [two|I-TRANSFORM] [and|SAME] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [five|I-TRANSFORM] [.|PUNCT]`` + +Augmented instance: ``[two|B-TRANSFORM] [thousand|I-TRANSFORM] [two|I-TRANSFORM] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [five|I-TRANSFORM]`` + +The argument ``data.train_ds.tagger_data_augmentation`` in the config file controls whether this data augmentation will be enabled or not. + + +Data Augmentation for Training DuplexDecoderModel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Since the tagger may not be perfect, the inputs to the decoder may not all be semiotic spans. Therefore, to make the decoder become more robust against the tagger's potential errors, +we train the decoder with not only semiotic spans but also with some other more "noisy" spans. This way even if the tagger makes some errors, there will still be some chance that the +final output is still correct. + +The argument ``data.train_ds.decoder_data_augmentation`` in the config file controls whether this data augmentation will be enabled or not. + +References +---------- + +.. 
bibliography:: nlp_all.bib + :style: plain + :labelprefix: NLP-TEXTNORM + :keyprefix: nlp-textnorm- diff --git a/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml b/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml new file mode 100644 index 000000000000..c9aa55f850c3 --- /dev/null +++ b/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml @@ -0,0 +1,136 @@ +name: &name DuplexTextNormalization +mode: joint # Three possible choices ['tn', 'itn', 'joint'] + +# Pretrained Nemo Models +tagger_pretrained_model: null +decoder_pretrained_model: null + +# Tagger +tagger_trainer: + gpus: 1 # the number of gpus, 0 for CPU + num_nodes: 1 + max_epochs: 5 # the number of training epochs + checkpoint_callback: false # provided by exp_manager + logger: false # provided by exp_manager + accumulate_grad_batches: 1 # accumulates grads every k batches + gradient_clip_val: 0.0 + amp_level: O0 # O1/O2 for mixed precision + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: ddp + +tagger_model: + do_training: true + transformer: distilroberta-base + tokenizer: ${tagger_model.transformer} + nemo_path: ${tagger_exp_manager.exp_dir}/tagger_model.nemo # exported .nemo path + + optim: + name: adamw + lr: 5e-5 + weight_decay: 0.01 + + sched: + name: WarmupAnnealing + + # pytorch lightning args + monitor: val_token_precision + reduce_on_plateau: false + + # scheduler config override + warmup_steps: null + warmup_ratio: 0.1 + last_epoch: -1 + +tagger_exp_manager: + exp_dir: exps # where to store logs and checkpoints + name: tagger_training # name of experiment + create_tensorboard_logger: True + create_checkpoint_callback: True + checkpoint_callback_params: + save_top_k: 3 + monitor: "val_token_precision" + mode: "max" + save_best_model: true + always_save_nemo: true + +# Decoder +decoder_trainer: + gpus: 1 # the number of gpus, 0 for CPU + num_nodes: 1 + max_epochs: 3 # the number of training epochs + checkpoint_callback: false # provided by exp_manager + logger: false # provided by exp_manager + accumulate_grad_batches: 1 # accumulates grads every k batches + gradient_clip_val: 0.0 + amp_level: O0 # O1/O2 for mixed precision + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: ddp + +decoder_model: + do_training: true + transformer: t5-base + tokenizer: ${decoder_model.transformer} + nemo_path: ${decoder_exp_manager.exp_dir}/decoder_model.nemo # exported .nemo path + + optim: + name: adamw + lr: 2e-4 + weight_decay: 0.01 + + sched: + name: WarmupAnnealing + + # pytorch lightning args + monitor: val_loss + reduce_on_plateau: false + + # scheduler config override + warmup_steps: null + warmup_ratio: 0.0 + last_epoch: -1 + +decoder_exp_manager: + exp_dir: exps # where to store logs and checkpoints + name: decoder_training # name of experiment + create_tensorboard_logger: True + create_checkpoint_callback: True + checkpoint_callback_params: + save_top_k: 3 + monitor: "val_loss" + mode: "min" + save_best_model: true + always_save_nemo: true + +# Data +data: + base_dir: ??? 
# /path/to/data + + train_ds: + data_path: ${data.base_dir}/train.tsv + batch_size: 64 + shuffle: true + do_basic_tokenize: false + max_decoder_len: 80 + mode: ${mode} + # Refer to the text_normalization doc for more information about data augmentation + tagger_data_augmentation: true + decoder_data_augmentation: true + + validation_ds: + data_path: ${data.base_dir}/dev.tsv + batch_size: 64 + shuffle: false + do_basic_tokenize: false + max_decoder_len: 80 + mode: ${mode} + + test_ds: + data_path: ${data.base_dir}/test.tsv + batch_size: 64 + shuffle: false + mode: ${mode} + +# Inference +inference: + interactive: false # Set to true if you want to enable the interactive mode when running duplex_text_normalization_test.py + errors_log_fp: errors.txt # Path to the file for logging the errors diff --git a/examples/nlp/duplex_text_normalization/duplex_text_normalization_test.py b/examples/nlp/duplex_text_normalization/duplex_text_normalization_test.py new file mode 100644 index 000000000000..9c461b18ddbb --- /dev/null +++ b/examples/nlp/duplex_text_normalization/duplex_text_normalization_test.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script contains an example on how to evaluate a DuplexTextNormalizationModel. +Note that DuplexTextNormalizationModel is essentially a wrapper class around +DuplexTaggerModel and DuplexDecoderModel. Therefore, two trained NeMo models +should be specificied before evaluation (one is a trained DuplexTaggerModel +and the other is a trained DuplexDecoderModel). + +USAGE Example: +1. Obtain a processed test data file (refer to the `text_normalization doc `) +2. +# python duplex_text_normalization_test.py + tagger_pretrained_model=PATH_TO_TRAINED_TAGGER + decoder_pretrained_model=PATH_TO_TRAINED_DECODER + data.test_ds.data_path=PATH_TO_TEST_FILE + mode={tn,itn,joint} + +The script also supports the `interactive` mode where a user can just make the model +run on any input text: +# python duplex_text_normalization_test.py + tagger_pretrained_model=PATH_TO_TRAINED_TAGGER + decoder_pretrained_model=PATH_TO_TRAINED_DECODER + mode={tn,itn,joint} + inference.interactive=true + +This script uses the `/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml` +config file by default. The other option is to set another config file via command +line arguments by `--config-name=CONFIG_FILE_PATH'. + +Note that when evaluating a DuplexTextNormalizationModel on a labeled dataset, +the script will automatically generate a file for logging the errors made +by the model. The location of this file is determined by the argument +`inference.errors_log_fp`. 
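+
+The models can also be loaded and combined programmatically. A minimal sketch based on
+the code in main() below (the .nemo paths are placeholders for your own trained checkpoints):
+
+# from nltk import word_tokenize
+# import nemo.collections.nlp.data.text_normalization.constants as constants
+# from nemo.collections.nlp.models import DuplexDecoderModel, DuplexTaggerModel, DuplexTextNormalizationModel
+#
+# tagger = DuplexTaggerModel.restore_from('tagger_model.nemo')     # placeholder path
+# decoder = DuplexDecoderModel.restore_from('decoder_model.nemo')  # placeholder path
+# tn_model = DuplexTextNormalizationModel(tagger, decoder)
+# text = ' '.join(word_tokenize('I live in 123 King Ave'))
+# outputs = tn_model._infer([text], [constants.INST_FORWARD])[-1]
+# print(outputs[0])  # the spoken (TN) form of the input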
+ +""" + + +from helpers import DECODER_MODEL, TAGGER_MODEL, instantiate_model_and_trainer +from nltk import word_tokenize +from omegaconf import DictConfig, OmegaConf + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset +from nemo.collections.nlp.models import DuplexTextNormalizationModel +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +@hydra_runner(config_path="conf", config_name="duplex_tn_config") +def main(cfg: DictConfig) -> None: + logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}') + tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False) + decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False) + tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model) + + if not cfg.inference.interactive: + # Setup test_dataset + test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.data.test_ds.mode) + results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.inference.errors_log_fp) + print(f'\nTest results: {results}') + else: + while True: + test_input = input('Input a test input:') + test_input = ' '.join(word_tokenize(test_input)) + outputs = tn_model._infer([test_input, test_input], [constants.INST_BACKWARD, constants.INST_FORWARD])[-1] + print(f'Prediction (ITN): {outputs[0]}') + print(f'Prediction (TN): {outputs[1]}') + + should_continue = input('\nContinue (y/n): ').strip().lower() + if should_continue.startswith('n'): + break + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py b/examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py new file mode 100644 index 000000000000..b96451569603 --- /dev/null +++ b/examples/nlp/duplex_text_normalization/duplex_text_normalization_train.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script contains an example on how to train a DuplexTextNormalizationModel. +Note that DuplexTextNormalizationModel is essentially a wrapper class around +two other classes: + +(1) DuplexTaggerModel is a model for identifying spans in the input that need to +be normalized. Usually, such spans belong to semiotic classes (e.g., DATE, NUMBERS, ...). + +(2) DuplexDecoderModel is a model for normalizing the spans identified by the tagger. +For example, in the text normalization (TN) problem, each span will be converted to its +spoken form. In the inverse text normalization (ITN) problem, each span will be converted +to its written form. + +Therefore, this script consists of two parts, one is for training the tagger model +and the other is for training the decoder. + +This script uses the `/examples/nlp/duplex_text_normalization/conf/duplex_tn_config.yaml` +config file by default. 
The other option is to set another config file via command +line arguments by `--config-name=CONFIG_FILE_PATH'. Probably it is worth looking +at the example config file to see the list of parameters used for training. + +USAGE Example: +1. Obtain a processed dataset (refer to the `text_normalization doc `) +2. +# python duplex_text_normalization_train.py + data.base_dir=PATH_TO_DATASET_DIR + mode={tn,itn,joint} + +There are 3 different modes. `tn` mode is for training a system for TN only. +`itn` mode is for training a system for ITN. `joint` is for training a system +that can do both TN and ITN at the same time. Note that the above command will +first train a tagger and then train a decoder sequentially. + +You can also train only a tagger (without training a decoder) by running the +following command: +# python duplex_text_normalization_train.py + data.base_dir=PATH_TO_DATASET_DIR + mode={tn,itn,joint} + decoder_model.do_training=false + +Or you can also train only a decoder (without training a tagger): +# python duplex_text_normalization_train.py + data.base_dir=PATH_TO_DATASET_DIR + mode={tn,itn,joint} + tagger_model.do_training=false + +Information on the arguments: + +Most arguments in the example config file are quite self-explanatory (e.g., +`decoder_model.optim.lr` refers to the learning rate for training the decoder). +Some arguments we want to mention are: + ++ data.base_dir: The path to the dataset directory. It is expected that the +directory contains three files: train.tsv, dev.tsv, and test.tsv. + ++ tagger_model.nemo_path: This is the path where the final trained tagger model +will be saved to. + ++ decoder_model.nemo_path: This is the path where the final trained decoder model +will be saved to. +""" + + +from helpers import DECODER_MODEL, TAGGER_MODEL, instantiate_model_and_trainer +from omegaconf import DictConfig, OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="duplex_tn_config") +def main(cfg: DictConfig) -> None: + logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}') + + # Train the tagger + if cfg.tagger_model.do_training: + logging.info( + "================================================================================================" + ) + logging.info('Starting training tagger...') + tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, True) + exp_manager(tagger_trainer, cfg.get('tagger_exp_manager', None)) + tagger_trainer.fit(tagger_model) + if cfg.tagger_model.nemo_path: + tagger_model.to(tagger_trainer.accelerator.root_device) + tagger_model.save_to(cfg.tagger_model.nemo_path) + logging.info('Training finished!') + + # Train the decoder + if cfg.decoder_model.do_training: + logging.info( + "================================================================================================" + ) + logging.info('Starting training decoder...') + decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, True) + exp_manager(decoder_trainer, cfg.get('decoder_exp_manager', None)) + decoder_trainer.fit(decoder_model) + if cfg.decoder_model.nemo_path: + decoder_model.to(decoder_trainer.accelerator.root_device) + decoder_model.save_to(cfg.decoder_model.nemo_path) + logging.info('Training finished!') + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/duplex_text_normalization/google_data_preprocessing.py 
b/examples/nlp/duplex_text_normalization/google_data_preprocessing.py new file mode 100644 index 000000000000..656d9cba0c26 --- /dev/null +++ b/examples/nlp/duplex_text_normalization/google_data_preprocessing.py @@ -0,0 +1,182 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script can be used to process the raw data files of the Google Text Normalization dataset +to obtain data files of the format mentioned in the `text_normalization doc `. +Note that the script also does some preprocessing on the spoken forms of the URLs. For example, +given the URL "Zimbio.com", the original expected spoken form in the Google dataset is +"z_letter i_letter m_letter b_letter i_letter o_letter dot c_letter o_letter m_letter". +However, our script will return a more concise output which is "zim bio dot com". + + +USAGE Example: +1. Download the Google TN dataset from https://www.kaggle.com/google-nlu/text-normalization +2. Unzip the English subset (e.g., by running `tar zxvf en_with_types.tgz`). Then there will a folder named `en_with_types`. +3. Run this script +# python google_data_preprocessing.py \ + --data_dir=en_with_types/ \ + --output_dir=preprocessed/ \ + --lang=en + +In this example, the final preprocessed files will be stored in the `preprocessed` folder. +The folder should contain three files `train.tsv`, 'dev.tsv', and `test.tsv`. +""" + +from argparse import ArgumentParser +from os import listdir, mkdir +from os.path import isdir, isfile, join + +import wordninja +from helpers import flatten +from nltk import word_tokenize +from tqdm import tqdm + +import nemo.collections.nlp.data.text_normalization.constants as constants + +# Local Constants +ENGLISH = 'en' +SUPPORTED_LANGS = [ENGLISH] +TRAIN, DEV, TEST = 'train', 'dev', 'test' +SPLIT_NAMES = [TRAIN, DEV, TEST] +MAX_DEV_SIZE = 25000 + +# Helper Functions +def read_google_data(data_dir, lang): + """ + The function can be used to read the raw data files of the Google Text Normalization + dataset (which can be downloaded from https://www.kaggle.com/google-nlu/text-normalization) + + Args: + data_dir: Path to the data directory. The directory should contain files of the form output-xxxxx-of-00100 + lang: Selected language. Currently the only supported language is English. + Return: + train: A list of examples in the training set. 
+ dev: A list of examples in the dev set + test: A list of examples in the test set + """ + train, dev, test = [], [], [] + for fn in listdir(data_dir): + fp = join(data_dir, fn) + if not isfile(fp): + continue + if not fn.startswith('output'): + continue + with open(fp, 'r', encoding='utf-8') as f: + # Determine the current split + split_nb = int(fn.split('-')[1]) + if split_nb == 0: + cur_split = train + elif split_nb == 90: + cur_split = dev + elif split_nb == 99: + cur_split = test + else: + continue + # Loop through each line of the file + cur_classes, cur_tokens, cur_outputs = [], [], [] + for linectx, line in tqdm(enumerate(f)): + es = line.strip().split('\t') + if split_nb == 99 and linectx == 100002: + break + if len(es) == 2 and es[0] == '': + # Update cur_split + cur_outputs = process_url(cur_tokens, cur_outputs, lang) + cur_split.append((cur_classes, cur_tokens, cur_outputs)) + # Reset + cur_classes, cur_tokens, cur_outputs = [], [], [] + continue + # Update the current example + assert len(es) == 3 + cur_classes.append(es[0]) + cur_tokens.append(es[1]) + cur_outputs.append(es[2]) + dev = dev[:MAX_DEV_SIZE] + train_sz, dev_sz, test_sz = len(train), len(dev), len(test) + print(f'train_sz: {train_sz} | dev_sz: {dev_sz} | test_sz: {test_sz}') + return train, dev, test + + +def process_url(tokens, outputs, lang): + """ + The function is used to process the spoken form of every URL in an example + + Args: + tokens: The tokens of the written form + outputs: The expected outputs for the spoken form + lang: Selected language. Currently the only supported language is English. + Return: + outputs: The outputs for the spoken form with preprocessed URLs. + """ + if lang == ENGLISH: + for i in range(len(tokens)): + t, o = tokens[i], outputs[i] + if o != constants.SIL_WORD and '_letter' in o: + o_tokens = o.split(' ') + all_spans, cur_span = [], [] + for j in range(len(o_tokens)): + if len(o_tokens[j]) == 0: + continue + if o_tokens[j] == '_letter': + all_spans.append(cur_span) + all_spans.append([' ']) + cur_span = [] + else: + o_tokens[j] = o_tokens[j].replace('_letter', '') + cur_span.append(o_tokens[j]) + if len(cur_span) > 0: + all_spans.append(cur_span) + o_tokens = flatten(all_spans) + + o = '' + for o_token in o_tokens: + if len(o_token) > 1: + o += ' ' + o_token + ' ' + else: + o += o_token + o = o.strip() + o_tokens = wordninja.split(o) + o = ' '.join(o_tokens) + + outputs[i] = o + + return outputs + + +# Main code +if __name__ == '__main__': + parser = ArgumentParser(description='Preprocess Google text normalization dataset') + parser.add_argument('--data_dir', type=str, required=True, help='Path to folder with data') + parser.add_argument('--output_dir', type=str, default='preprocessed', help='Path to folder with preprocessed data') + parser.add_argument('--lang', type=str, default=ENGLISH, choices=SUPPORTED_LANGS, help='Language') + args = parser.parse_args() + + # Create the output dir (if not exist) + if not isdir(args.output_dir): + mkdir(args.output_dir) + + # Processing + train, dev, test = read_google_data(args.data_dir, args.lang) + for split, data in zip(SPLIT_NAMES, [train, dev, test]): + output_f = open(join(args.output_dir, f'{split}.tsv'), 'w+', encoding='utf-8') + for inst in data: + cur_classes, cur_tokens, cur_outputs = inst + for c, t, o in zip(cur_classes, cur_tokens, cur_outputs): + t = ' '.join(word_tokenize(t)) + if not o in constants.SPECIAL_WORDS: + o_tokens = word_tokenize(o) + o_tokens = [o_tok for o_tok in o_tokens if o_tok != constants.SIL_WORD] + o = ' 
'.join(o_tokens) + output_f.write(f'{c}\t{t}\t{o}\n') + output_f.write('\t\n') diff --git a/examples/nlp/duplex_text_normalization/helpers.py b/examples/nlp/duplex_text_normalization/helpers.py new file mode 100644 index 000000000000..68be578f3dae --- /dev/null +++ b/examples/nlp/duplex_text_normalization/helpers.py @@ -0,0 +1,73 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +from omegaconf import DictConfig, OmegaConf + +from nemo.collections.nlp.models import DuplexDecoderModel, DuplexTaggerModel +from nemo.utils import logging + +__all__ = ['TAGGER_MODEL', 'DECODER_MODEL', 'MODEL_NAMES', 'instantiate_model_and_trainer'] + +TAGGER_MODEL = 'tagger' +DECODER_MODEL = 'decoder' +MODEL_NAMES = [TAGGER_MODEL, DECODER_MODEL] + + +def instantiate_model_and_trainer(cfg: DictConfig, model_name: str, do_training: bool): + """ Function for instantiating a model and a trainer + Args: + cfg: The config used to instantiate the model and the trainer. + model_name: A str indicates whether the model to be instantiated is a tagger or a decoder (i.e., model_name should be either TAGGER_MODEL or DECODER_MODEL). + do_training: A boolean flag indicates whether the model will be trained or evaluated. + + Returns: + trainer: A PyTorch Lightning trainer + model: A NLPModel that can either be a DuplexTaggerModel or a DuplexDecoderModel + """ + assert model_name in MODEL_NAMES + logging.info(f'Model {model_name}') + + # Get configs for the corresponding models + trainer_cfg = cfg.get(f'{model_name}_trainer') + model_cfg = cfg.get(f'{model_name}_model') + pretrained_cfg = cfg.get(f'{model_name}_pretrained_model', None) + + trainer = pl.Trainer(**trainer_cfg) + + if not pretrained_cfg: + logging.info(f'Config: {OmegaConf.to_yaml(cfg)}') + if model_name == TAGGER_MODEL: + model = DuplexTaggerModel(model_cfg, trainer=trainer) + if model_name == DECODER_MODEL: + model = DuplexDecoderModel(model_cfg, trainer=trainer) + else: + logging.info(f'Loading pretrained model {pretrained_cfg}') + if model_name == TAGGER_MODEL: + model = DuplexTaggerModel.restore_from(pretrained_cfg) + if model_name == DECODER_MODEL: + model = DuplexDecoderModel.restore_from(pretrained_cfg) + + # Setup train and validation data + if do_training: + model.setup_training_data(train_data_config=cfg.data.train_ds) + model.setup_validation_data(val_data_config=cfg.data.validation_ds) + + logging.info(f'Model Device {model.device}') + return trainer, model + + +def flatten(l): + """ flatten a list of lists """ + return [item for sublist in l for item in sublist] diff --git a/nemo/collections/nlp/data/text_normalization/__init__.py b/nemo/collections/nlp/data/text_normalization/__init__.py new file mode 100644 index 000000000000..8cf835fbe0c1 --- /dev/null +++ b/nemo/collections/nlp/data/text_normalization/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo.collections.nlp.data.text_normalization.decoder_dataset import TextNormalizationDecoderDataset +from nemo.collections.nlp.data.text_normalization.tagger_dataset import TextNormalizationTaggerDataset +from nemo.collections.nlp.data.text_normalization.test_dataset import TextNormalizationTestDataset diff --git a/nemo/collections/nlp/data/text_normalization/constants.py b/nemo/collections/nlp/data/text_normalization/constants.py new file mode 100644 index 000000000000..5c6c572c7736 --- /dev/null +++ b/nemo/collections/nlp/data/text_normalization/constants.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DECODE_CTX_SIZE = 3 # the size of the input context to be provided to the DuplexDecoderModel +LABEL_PAD_TOKEN_ID = -100 + +# Task Prefixes +ITN_PREFIX = str(0) +TN_PREFIX = str(1) + +# Tagger Labels Prefixes +B_PREFIX = 'B-' # Denote beginning +I_PREFIX = 'I-' # Denote middle +TAGGER_LABELS_PREFIXES = [B_PREFIX, I_PREFIX] + +# Modes +TN_MODE = 'tn' +ITN_MODE = 'itn' +JOINT_MODE = 'joint' +MODES = [TN_MODE, ITN_MODE, JOINT_MODE] + +# Instance Directions +INST_BACKWARD = 'BACKWARD' +INST_FORWARD = 'FORWARD' +INST_DIRECTIONS = [INST_BACKWARD, INST_FORWARD] + +# TAGS +SAME_TAG = 'SAME' # Tag indicates that a token can be kept the same without any further transformation +TASK_TAG = 'TASK' # Tag indicates that a token belongs to a task prefix (the prefix indicates whether the current task is TN or ITN) +PUNCT_TAG = 'PUNCT' # Tag indicates that a token is a punctuation +TRANSFORM_TAG = 'TRANSFORM' # Tag indicates that a token needs to be transformed by the decoder +ALL_TAGS = [TASK_TAG, SAME_TAG, PUNCT_TAG, TRANSFORM_TAG] + +# ALL_TAG_LABELS +ALL_TAG_LABELS = [] +for prefix in TAGGER_LABELS_PREFIXES: + for tag in ALL_TAGS: + ALL_TAG_LABELS.append(prefix + tag) +ALL_TAG_LABELS.sort() + +# Special Words +SIL_WORD = 'sil' +SELF_WORD = '' +SPECIAL_WORDS = [SIL_WORD, SELF_WORD] + +# Mappings for Greek Letters +GREEK_TO_SPOKEN = { + 'Τ': 'tau', + 'Ο': 'omicron', + 'Δ': 'delta', + 'Η': 'eta', + 'Κ': 'kappa', + 'Ι': 'iota', + 'Θ': 'theta', + 'Α': 'alpha', + 'Σ': 'sigma', + 'Υ': 'upsilon', + 'Μ': 'mu', + 'Ε': 'epsilon', + 'Χ': 'chi', + 'Π': 'pi', + 'Ν': 'nu', + 'Λ': 'lambda', + 'Γ': 'gamma', + 'Β': 'beta', + 'Ρ': 'rho', + 'τ': 'tau', + 'υ': 'upsilon', + 'μ': 'mu', + 'φ': 'phi', + 'α': 'alpha', + 'λ': 'lambda', + 'ι': 'iota', + 'ς': 'sigma', + 'ο': 'omicron', + 'σ': 'sigma', 
+ 'η': 'eta', + 'π': 'pi', + 'ν': 'nu', + 'γ': 'gamma', + 'κ': 'kappa', + 'ε': 'epsilon', + 'β': 'beta', + 'ρ': 'rho', + 'ω': 'omega', + 'χ': 'chi', +} +SPOKEN_TO_GREEK = {v: k for k, v in GREEK_TO_SPOKEN.items()} + +# IDs for special tokens for encoding inputs of the decoder models +EXTRA_ID_0 = '' +EXTRA_ID_1 = '' diff --git a/nemo/collections/nlp/data/text_normalization/decoder_dataset.py b/nemo/collections/nlp/data/text_normalization/decoder_dataset.py new file mode 100644 index 000000000000..30f3b5799d6d --- /dev/null +++ b/nemo/collections/nlp/data/text_normalization/decoder_dataset.py @@ -0,0 +1,196 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from nltk import word_tokenize +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization.utils import read_data_file +from nemo.core.classes import Dataset +from nemo.utils.decorators.experimental import experimental + +__all__ = ['TextNormalizationDecoderDataset'] + + +@experimental +class TextNormalizationDecoderDataset(Dataset): + """ + Creates dataset to use to train a DuplexDecoderModel. + + Converts from raw data to an instance that can be used by Dataloader. + + For dataset to use to do end-to-end inference, see TextNormalizationTestDataset. + + Args: + input_file: path to the raw data file (e.g., train.tsv). For more info about the data format, refer to the `text_normalization doc `. + tokenizer: tokenizer of the model that will be trained on the dataset + mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time. + max_len: maximum length of sequence in tokens. The code will discard any training instance whose input or output is longer than the specified max_len. + decoder_data_augmentation (bool): a flag indicates whether to augment the dataset with additional data instances that may help the decoder become more robust against the tagger's errors. Refer to the doc for more info. 
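+
+        For illustration (a rough sketch of a generated pair, not an exact dump of the
+        preprocessing), a TN instance built from the span `123` in `I live in 123 King Ave`
+        looks roughly like:
+
+            input:  [TN prefix] I live in [EXTRA_ID_0] 123 [EXTRA_ID_1] King Ave
+            output: one twenty three
+
+        i.e., the task prefix, the left context, the span wrapped between the two special
+        tokens, and the right context on the input side, and the span's spoken form on the
+        output side (see DecoderDataInstance below).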
+ """ + + def __init__( + self, + input_file: str, + tokenizer: PreTrainedTokenizerBase, + mode: str, + max_len: int, + decoder_data_augmentation: bool, + ): + assert mode in constants.MODES + self.mode = mode + raw_insts = read_data_file(input_file) + + # Convert raw instances to TaggerDataInstance + insts, inputs, targets = [], [], [] + for (classes, w_words, s_words) in tqdm(raw_insts): + for ix, (_class, w_word, s_word) in enumerate(zip(classes, w_words, s_words)): + if s_word in constants.SPECIAL_WORDS: + continue + for inst_dir in constants.INST_DIRECTIONS: + if inst_dir == constants.INST_BACKWARD and mode == constants.TN_MODE: + continue + if inst_dir == constants.INST_FORWARD and mode == constants.ITN_MODE: + continue + # Create a DecoderDataInstance + inst = DecoderDataInstance( + w_words, s_words, inst_dir, start_idx=ix, end_idx=ix + 1, semiotic_class=_class + ) + insts.append(inst) + if decoder_data_augmentation: + noise_left = random.randint(1, 2) + noise_right = random.randint(1, 2) + inst = DecoderDataInstance( + w_words, s_words, inst_dir, start_idx=ix - noise_left, end_idx=ix + 1 + noise_right + ) + insts.append(inst) + + self.insts = insts + inputs = [inst.input_str for inst in insts] + targets = [inst.output_str for inst in insts] + + # Tokenization + self.inputs, self.examples = [], [] + self.tn_count, self.itn_count, long_examples_filtered = 0, 0, 0 + input_max_len, target_max_len = 0, 0 + for idx in range(len(inputs)): + # Input + _input = tokenizer([inputs[idx]]) + input_len = len(_input['input_ids'][0]) + if input_len > max_len: + long_examples_filtered += 1 + continue + + # Target + _target = tokenizer([targets[idx]]) + target_len = len(_target['input_ids'][0]) + if target_len > max_len: + long_examples_filtered += 1 + continue + + # Update + self.inputs.append(inputs[idx]) + _input['labels'] = _target['input_ids'] + self.examples.append(_input) + if inputs[idx].startswith(constants.TN_PREFIX): + self.tn_count += 1 + if inputs[idx].startswith(constants.ITN_PREFIX): + self.itn_count += 1 + input_max_len = max(input_max_len, input_len) + target_max_len = max(target_max_len, target_len) + print(f'long_examples_filtered: {long_examples_filtered}') + print(f'input_max_len: {input_max_len} | target_max_len: {target_max_len}') + + def __getitem__(self, idx): + example = self.examples[idx] + item = {key: val[0] for key, val in example.items()} + return item + + def __len__(self): + return len(self.examples) + + +class DecoderDataInstance: + """ + This class represents a data instance in a TextNormalizationDecoderDataset. + + Intuitively, each data instance can be thought as having the following form: + Input: + Output: + where the context size is determined by the constant DECODE_CTX_SIZE. + + Args: + w_words: List of words in the written form + s_words: List of words in the spoken form + inst_dir: Indicates the direction of the instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN). 
+ start_idx: The starting index of the input span in the original input text + end_idx: The ending index of the input span (exclusively) + semiotic_class: The semiotic class of the input span (can be set to None if not available) + """ + + def __init__(self, w_words, s_words, inst_dir, start_idx, end_idx, semiotic_class=None): + start_idx = max(start_idx, 0) + end_idx = min(end_idx, len(w_words)) + ctx_size = constants.DECODE_CTX_SIZE + extra_id_0 = constants.EXTRA_ID_0 + extra_id_1 = constants.EXTRA_ID_1 + + # Extract center words + c_w_words = w_words[start_idx:end_idx] + c_s_words = s_words[start_idx:end_idx] + + # Extract context + w_left = w_words[max(0, start_idx - ctx_size) : start_idx] + w_right = w_words[end_idx : end_idx + ctx_size] + s_left = s_words[max(0, start_idx - ctx_size) : start_idx] + s_right = s_words[end_idx : end_idx + ctx_size] + + # Process sil words and self words + for jx in range(len(s_left)): + if s_left[jx] == constants.SIL_WORD: + s_left[jx] = '' + if s_left[jx] == constants.SELF_WORD: + s_left[jx] = w_left[jx] + for jx in range(len(s_right)): + if s_right[jx] == constants.SIL_WORD: + s_right[jx] = '' + if s_right[jx] == constants.SELF_WORD: + s_right[jx] = w_right[jx] + for jx in range(len(c_s_words)): + if c_s_words[jx] == constants.SIL_WORD: + c_s_words[jx] = '' + if inst_dir == constants.INST_BACKWARD: + c_w_words[jx] = '' + if c_s_words[jx] == constants.SELF_WORD: + c_s_words[jx] = c_w_words[jx] + + # Extract input_words and output_words + c_w_words = word_tokenize(' '.join(c_w_words)) + c_s_words = word_tokenize(' '.join(c_s_words)) + w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right + s_input = s_left + [extra_id_0] + c_s_words + [extra_id_1] + s_right + if inst_dir == constants.INST_BACKWARD: + input_words = [constants.ITN_PREFIX] + s_input + output_words = c_w_words + if inst_dir == constants.INST_FORWARD: + input_words = [constants.TN_PREFIX] + w_input + output_words = c_s_words + # Finalize + self.input_str = ' '.join(input_words) + self.output_str = ' '.join(output_words) + self.direction = inst_dir + self.semiotic_class = semiotic_class diff --git a/nemo/collections/nlp/data/text_normalization/tagger_dataset.py b/nemo/collections/nlp/data/text_normalization/tagger_dataset.py new file mode 100644 index 000000000000..0f5c7dc09984 --- /dev/null +++ b/nemo/collections/nlp/data/text_normalization/tagger_dataset.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
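+
+# This module builds token-classification examples for the DuplexTaggerModel: every input
+# token is labeled with a B-/I- prefixed tag (TASK, SAME, PUNCT or TRANSFORM), and the
+# tagger data augmentation described in the text_normalization doc is applied here.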
+ +from nltk import word_tokenize +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization.utils import read_data_file +from nemo.core.classes import Dataset +from nemo.utils.decorators.experimental import experimental + +__all__ = ['TextNormalizationTaggerDataset'] + + +@experimental +class TextNormalizationTaggerDataset(Dataset): + """ + Creates dataset to use to train a DuplexTaggerModel. + + Converts from raw data to an instance that can be used by Dataloader. + + For dataset to use to do end-to-end inference, see TextNormalizationTestDataset. + + Args: + input_file: path to the raw data file (e.g., train.tsv). For more info about the data format, refer to the `text_normalization doc `. + tokenizer: tokenizer of the model that will be trained on the dataset + mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time. + do_basic_tokenize: a flag indicates whether to do some basic tokenization (i.e., using word_tokenize() of nltk) before using the tokenizer of the model + tagger_data_augmentation (bool): a flag indicates whether to augment the dataset with additional data instances + """ + + def __init__( + self, + input_file: str, + tokenizer: PreTrainedTokenizerBase, + mode: str, + do_basic_tokenize: bool, + tagger_data_augmentation: bool, + ): + assert mode in constants.MODES + self.mode = mode + raw_insts = read_data_file(input_file) + + # Convert raw instances to TaggerDataInstance + insts = [] + for (_, w_words, s_words) in tqdm(raw_insts): + for inst_dir in constants.INST_DIRECTIONS: + if inst_dir == constants.INST_BACKWARD and mode == constants.TN_MODE: + continue + if inst_dir == constants.INST_FORWARD and mode == constants.ITN_MODE: + continue + # Create a new TaggerDataInstance + inst = TaggerDataInstance(w_words, s_words, inst_dir, do_basic_tokenize) + insts.append(inst) + # Data Augmentation (if enabled) + if tagger_data_augmentation: + filtered_w_words, filtered_s_words = [], [] + for ix, (w, s) in enumerate(zip(w_words, s_words)): + if not s in constants.SPECIAL_WORDS: + filtered_w_words.append(w) + filtered_s_words.append(s) + if len(filtered_s_words) > 1: + inst = TaggerDataInstance(filtered_w_words, filtered_s_words, inst_dir) + insts.append(inst) + + self.insts = insts + texts = [inst.input_words for inst in insts] + tags = [inst.labels for inst in insts] + + # Tags Mapping + self.tag2id = {tag: id for id, tag in enumerate(constants.ALL_TAG_LABELS)} + + # Finalize + self.encodings = tokenizer(texts, is_split_into_words=True, padding=False, truncation=True) + self.labels = self.encode_tags(tags, self.encodings) + + def __getitem__(self, idx): + item = {key: val[idx] for key, val in self.encodings.items()} + item['labels'] = self.labels[idx] + return item + + def __len__(self): + return len(self.labels) + + def encode_tags(self, tags, encodings): + encoded_labels = [] + for i, label in enumerate(tags): + word_ids = encodings.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + # Special tokens have a word id that is None. We set the label + # to -100 (LABEL_PAD_TOKEN_ID) so they are automatically + # ignored in the loss function. + if word_idx is None: + label_ids.append(constants.LABEL_PAD_TOKEN_ID) + # We set the label for the first token of each word. 
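+                # (the first subword of a word gets the B- version of its tag; any
+                #  remaining subwords of the same word get the I- version below)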
+ elif word_idx != previous_word_idx: + label_id = self.tag2id[constants.B_PREFIX + label[word_idx]] + label_ids.append(label_id) + # We set the label for the other tokens in a word + else: + label_id = self.tag2id[constants.I_PREFIX + label[word_idx]] + label_ids.append(label_id) + previous_word_idx = word_idx + + encoded_labels.append(label_ids) + + return encoded_labels + + +class TaggerDataInstance: + """ + This class represents a data instance in a TextNormalizationTaggerDataset. + + Args: + w_words: List of words in the written form + s_words: List of words in the spoken form + direction: Indicates the direction of the instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN). + do_basic_tokenize: a flag indicates whether to do some basic tokenization (i.e., using word_tokenize() of nltk) before using the tokenizer of the model + """ + + def __init__(self, w_words, s_words, direction, do_basic_tokenize=False): + # Build input_words and labels + input_words, labels = [], [] + # Task Prefix + if direction == constants.INST_BACKWARD: + input_words.append(constants.ITN_PREFIX) + if direction == constants.INST_FORWARD: + input_words.append(constants.TN_PREFIX) + labels.append(constants.TASK_TAG) + # Main Content + for w_word, s_word in zip(w_words, s_words): + # Basic tokenization (if enabled) + if do_basic_tokenize: + w_word = ' '.join(word_tokenize(w_word)) + if not s_word in constants.SPECIAL_WORDS: + s_word = ' '.join(word_tokenize(s_word)) + # Update input_words and labels + if s_word == constants.SIL_WORD and direction == constants.INST_BACKWARD: + continue + if s_word == constants.SELF_WORD: + input_words.append(w_word) + labels.append(constants.SAME_TAG) + elif s_word == constants.SIL_WORD: + input_words.append(w_word) + labels.append(constants.PUNCT_TAG) + else: + if direction == constants.INST_BACKWARD: + input_words.append(s_word) + if direction == constants.INST_FORWARD: + input_words.append(w_word) + labels.append(constants.TRANSFORM_TAG) + self.input_words = input_words + self.labels = labels diff --git a/nemo/collections/nlp/data/text_normalization/test_dataset.py b/nemo/collections/nlp/data/text_normalization/test_dataset.py new file mode 100644 index 000000000000..1d235503433d --- /dev/null +++ b/nemo/collections/nlp/data/text_normalization/test_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
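+
+# This module builds the end-to-end evaluation examples used by DuplexTextNormalizationModel:
+# each raw sentence yields up to two (direction, input, target) triples, one per TN/ITN
+# direction, and sentence accuracy is an exact-match score (punctuation is stripped for ITN).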
+ +from copy import deepcopy +from typing import List + +from nltk import word_tokenize +from tqdm import tqdm + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization.utils import normalize_str, read_data_file, remove_puncts +from nemo.utils.decorators.experimental import experimental + +__all__ = ['TextNormalizationTestDataset'] + +# Test Dataset +@experimental +class TextNormalizationTestDataset: + """ + Creates dataset to use to do end-to-end inference + + Args: + input_file: path to the raw data file (e.g., train.tsv). For more info about the data format, refer to the `text_normalization doc `. + mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time. + keep_puncts: whether to keep punctuations in the inputs/outputs + """ + + def __init__(self, input_file: str, mode: str, keep_puncts: bool = False): + insts = read_data_file(input_file) + + # Build inputs and targets + self.directions, self.inputs, self.targets = [], [], [] + for (_, w_words, s_words) in insts: + # Extract words that are not punctuations + processed_w_words, processed_s_words = [], [] + for w_word, s_word in zip(w_words, s_words): + if s_word == constants.SIL_WORD: + if keep_puncts: + processed_w_words.append(w_word) + processed_s_words.append(w_word) + continue + if s_word == constants.SELF_WORD: + processed_s_words.append(w_word) + if not s_word in constants.SPECIAL_WORDS: + processed_s_words.append(s_word) + processed_w_words.append(w_word) + # Create examples + for direction in constants.INST_DIRECTIONS: + if direction == constants.INST_BACKWARD: + if mode == constants.TN_MODE: + continue + input_words = processed_s_words + output_words = processed_w_words + if direction == constants.INST_FORWARD: + if mode == constants.ITN_MODE: + continue + input_words = w_words + output_words = processed_s_words + # Basic tokenization + input_words = word_tokenize(' '.join(input_words)) + output_words = word_tokenize(' '.join(output_words)) + # Update self.directions, self.inputs, self.targets + self.directions.append(direction) + self.inputs.append(' '.join(input_words)) + self.targets.append(' '.join(output_words)) + self.examples = list(zip(self.directions, self.inputs, self.targets)) + + def __getitem__(self, idx): + return self.examples[idx] + + def __len__(self): + return len(self.inputs) + + @staticmethod + def is_same(pred: str, target: str, inst_dir: str): + """ + Function for checking whether the predicted string can be considered + the same as the target string + + Args: + pred: Predicted string + target: Target string + inst_dir: Direction of the instance (i.e., INST_BACKWARD or INST_FORWARD). + Return: an int value (0/1) indicating whether pred and target are the same. + """ + if inst_dir == constants.INST_BACKWARD: + pred = remove_puncts(pred) + target = remove_puncts(target) + pred = normalize_str(pred) + target = normalize_str(target) + return int(pred == target) + + @staticmethod + def compute_sent_accuracy(preds: List[str], targets: List[str], inst_directions: List[str]): + """ + Compute the sentence accuracy metric. + + Args: + preds: List of predicted strings. + targets: List of target strings. + inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD or INST_FORWARD). 
+ Return: the sentence accuracy score + """ + assert len(preds) == len(targets) + if len(targets) == 0: + return 'NA' + # Sentence Accuracy + correct_count = 0 + for inst_dir, pred, target in zip(inst_directions, preds, targets): + correct_count += TextNormalizationTestDataset.is_same(pred, target, inst_dir) + sent_accuracy = correct_count / len(targets) + + return sent_accuracy diff --git a/nemo/collections/nlp/data/text_normalization/utils.py b/nemo/collections/nlp/data/text_normalization/utils.py new file mode 100644 index 000000000000..ccaa98d29216 --- /dev/null +++ b/nemo/collections/nlp/data/text_normalization/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import string +from copy import deepcopy + +from nltk import word_tokenize +from tqdm import tqdm + +__all__ = ['read_data_file', 'normalize_str'] + + +def read_data_file(fp): + """ Reading the raw data from a file of NeMo format + For more info about the data format, refer to the + `text_normalization doc `. + """ + insts, w_words, s_words, classes = [], [], [], [] + # Read input file + with open(fp, 'r', encoding='utf-8') as f: + for line in tqdm(f): + es = [e.strip() for e in line.strip().split('\t')] + if es[0] == '': + inst = (deepcopy(classes), deepcopy(w_words), deepcopy(s_words)) + insts.append(inst) + # Reset + w_words, s_words, classes = [], [], [] + else: + classes.append(es[0]) + w_words.append(es[1]) + s_words.append(es[2]) + return insts + + +def normalize_str(input_str): + """ Normalize an input string """ + input_str = ' '.join(word_tokenize(input_str.strip().lower())) + input_str = input_str.replace(' ', ' ') + return input_str + + +def remove_puncts(input_str): + """ Remove punctuations from an input string """ + return input_str.translate(str.maketrans('', '', string.punctuation)) diff --git a/nemo/collections/nlp/models/__init__.py b/nemo/collections/nlp/models/__init__.py index 42f20f128e8a..75edd03d865f 100644 --- a/nemo/collections/nlp/models/__init__.py +++ b/nemo/collections/nlp/models/__init__.py @@ -13,6 +13,11 @@ # limitations under the License. from nemo.collections.nlp.models.dialogue_state_tracking.sgdqa_model import SGDQAModel +from nemo.collections.nlp.models.duplex_text_normalization import ( + DuplexDecoderModel, + DuplexTaggerModel, + DuplexTextNormalizationModel, +) from nemo.collections.nlp.models.entity_linking.entity_linking_model import EntityLinkingModel from nemo.collections.nlp.models.glue_benchmark.glue_benchmark_model import GLUEModel from nemo.collections.nlp.models.information_retrieval import BertDPRModel, BertJointIRModel diff --git a/nemo/collections/nlp/models/duplex_text_normalization/__init__.py b/nemo/collections/nlp/models/duplex_text_normalization/__init__.py new file mode 100644 index 000000000000..043bed0b2013 --- /dev/null +++ b/nemo/collections/nlp/models/duplex_text_normalization/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.nlp.models.duplex_text_normalization.duplex_decoder import DuplexDecoderModel +from nemo.collections.nlp.models.duplex_text_normalization.duplex_tagger import DuplexTaggerModel +from nemo.collections.nlp.models.duplex_text_normalization.duplex_tn import DuplexTextNormalizationModel diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py new file mode 100644 index 000000000000..975fa6f523ed --- /dev/null +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from time import perf_counter +from typing import List, Optional + +import nltk +import torch +import wordninja +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization import TextNormalizationDecoderDataset +from nemo.collections.nlp.models.duplex_text_normalization.utils import is_url +from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.core.classes.common import PretrainedModelInfo +from nemo.utils import logging +from nemo.utils.decorators.experimental import experimental + +nltk.download('punkt') + + +__all__ = ['DuplexDecoderModel'] + + +@experimental +class DuplexDecoderModel(NLPModel): + """ + Transformer-based (duplex) decoder model for TN/ITN. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer) + super().__init__(cfg=cfg, trainer=trainer) + self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.transformer) + + # Training + def training_step(self, batch, batch_idx): + """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. 
+ """ + # Apply Transformer + outputs = self.model( + input_ids=batch['input_ids'], + decoder_input_ids=batch['decoder_input_ids'], + attention_mask=batch['attention_mask'], + labels=batch['labels'], + ) + train_loss = outputs.loss + + lr = self._optimizer.param_groups[0]['lr'] + self.log('train_loss', train_loss) + self.log('lr', lr, prog_bar=True) + return {'loss': train_loss, 'lr': lr} + + # Validation and Testing + def validation_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop with the data from the validation dataloader + passed in as `batch`. + """ + + # Apply Transformer + outputs = self.model( + input_ids=batch['input_ids'], + decoder_input_ids=batch['decoder_input_ids'], + attention_mask=batch['attention_mask'], + labels=batch['labels'], + ) + val_loss = outputs.loss + + return {'val_loss': val_loss} + + def validation_epoch_end(self, outputs): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. + """ + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + self.log('val_loss', avg_loss) + + return { + 'val_loss': avg_loss, + } + + def test_step(self, batch, batch_idx): + """ + Lightning calls this inside the test loop with the data from the test dataloader + passed in as `batch`. + """ + return self.validation_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + """ + Called at the end of test to aggregate outputs. + :param outputs: list of individual outputs of each test step. + """ + return self.validation_epoch_end(outputs) + + # Functions for inference + @torch.no_grad() + def _infer( + self, + sents: List[List[str]], + nb_spans: List[int], + span_starts: List[List[int]], + span_ends: List[List[int]], + inst_directions: List[str], + ): + """ Main function for Inference + Args: + sents: A list of inputs tokenized by a basic tokenizer (e.g., using nltk.word_tokenize()). + nb_spans: A list of ints where each int indicates the number of semiotic spans in each input. + span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input. + span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input. + inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN). + + Returns: A list of lists where each list contains the decoded spans for the corresponding input. 
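+
+        Example (illustrative only; the decoded strings depend on the trained checkpoint,
+        and the constant names are the ones used elsewhere in this module):
+            sents = [['The', 'movie', 'came', 'out', 'in', '2005', '.']]
+            nb_spans = [1]
+            span_starts = [[5]]
+            span_ends = [[5]]
+            inst_directions = [constants.INST_FORWARD]
+            # _infer(...) might return [['two thousand five']]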
+ """ + self.eval() + + if sum(nb_spans) == 0: + return [[]] * len(sents) + model, tokenizer = self.model, self._tokenizer + model_max_len = model.config.n_positions + ctx_size = constants.DECODE_CTX_SIZE + extra_id_0 = constants.EXTRA_ID_0 + extra_id_1 = constants.EXTRA_ID_1 + + # Build all_inputs + input_centers, input_dirs, all_inputs = [], [], [] + for ix, sent in enumerate(sents): + cur_inputs = [] + for jx in range(nb_spans[ix]): + cur_start = span_starts[ix][jx] + cur_end = span_ends[ix][jx] + ctx_left = sent[max(0, cur_start - ctx_size) : cur_start] + ctx_right = sent[cur_end + 1 : cur_end + 1 + ctx_size] + span_words = sent[cur_start : cur_end + 1] + span_words_str = ' '.join(span_words) + if is_url(span_words_str): + span_words_str = span_words_str.lower() + input_centers.append(span_words_str) + input_dirs.append(inst_directions[ix]) + # Build cur_inputs + if inst_directions[ix] == constants.INST_BACKWARD: + cur_inputs = [constants.ITN_PREFIX] + if inst_directions[ix] == constants.INST_FORWARD: + cur_inputs = [constants.TN_PREFIX] + cur_inputs += ctx_left + cur_inputs += [extra_id_0] + span_words_str.split(' ') + [extra_id_1] + cur_inputs += ctx_right + all_inputs.append(' '.join(cur_inputs)) + + # Apply the decoding model + batch = tokenizer(all_inputs, padding=True, return_tensors='pt') + input_ids = batch['input_ids'].to(self.device) + generated_ids = model.generate(input_ids, max_length=model_max_len) + generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # Post processing + generated_texts = self.postprocess_output_spans(input_centers, generated_texts, input_dirs) + + # Prepare final_texts + final_texts, span_ctx = [], 0 + for nb_span in nb_spans: + cur_texts = [] + for i in range(nb_span): + cur_texts.append(generated_texts[span_ctx]) + span_ctx += 1 + final_texts.append(cur_texts) + + return final_texts + + def postprocess_output_spans(self, input_centers, output_spans, input_dirs): + greek_spokens = list(constants.GREEK_TO_SPOKEN.values()) + for ix, (_input, _output) in enumerate(zip(input_centers, output_spans)): + # Handle URL + if is_url(_input): + output_spans[ix] = ' '.join(wordninja.split(_output)) + continue + # Greek letters + if _input in greek_spokens: + if input_dirs[ix] == constants.INST_FORWARD: + output_spans[ix] = _input + if input_dirs[ix] == constants.INST_BACKWARD: + output_spans[ix] = constants.SPOKEN_TO_GREEK[_input] + return output_spans + + # Functions for processing data + def setup_training_data(self, train_data_config: Optional[DictConfig]): + if not train_data_config or not train_data_config.data_path: + logging.info( + f"Dataloader config or file_path for the train is missing, so no data loader for train is created!" + ) + self._train_dl = None + return + self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, mode="train") + + def setup_validation_data(self, val_data_config: Optional[DictConfig]): + if not val_data_config or not val_data_config.data_path: + logging.info( + f"Dataloader config or file_path for the validation is missing, so no data loader for validation is created!" + ) + self._validation_dl = None + return + self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, mode="val") + + def setup_test_data(self, test_data_config: Optional[DictConfig]): + if not test_data_config or test_data_config.data_path is None: + logging.info( + f"Dataloader config or file_path for the test is missing, so no data loader for test is created!" 
+ ) + self._test_dl = None + return + self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config, mode="test") + + def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str): + tokenizer, model = self._tokenizer, self.model + start_time = perf_counter() + logging.info(f'Creating {mode} dataset') + input_file = cfg.data_path + dataset = TextNormalizationDecoderDataset( + input_file, + tokenizer, + cfg.mode, + cfg.get('max_decoder_len', tokenizer.model_max_length), + cfg.get('decoder_data_augmentation', False), + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, label_pad_token_id=constants.LABEL_PAD_TOKEN_ID, + ) + dl = torch.utils.data.DataLoader( + dataset=dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle, collate_fn=data_collator, + ) + running_time = perf_counter() - start_time + logging.info(f'Took {running_time} seconds') + return dl + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. + """ + result = [] + return result diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py new file mode 100644 index 000000000000..91104e37f613 --- /dev/null +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py @@ -0,0 +1,299 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from time import perf_counter +from typing import List, Optional + +import nltk +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from torch import nn +from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization import TextNormalizationTaggerDataset +from nemo.collections.nlp.metrics.classification_report import ClassificationReport +from nemo.collections.nlp.models.duplex_text_normalization.utils import has_numbers +from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.core.classes.common import PretrainedModelInfo +from nemo.utils import logging +from nemo.utils.decorators.experimental import experimental + +nltk.download('punkt') + + +__all__ = ['DuplexTaggerModel'] + + +@experimental +class DuplexTaggerModel(NLPModel): + """ + Transformer-based (duplex) tagger model for TN/ITN. 
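+
+    Illustrative behavior (a sketch, not actual model output; SAME, PUNCT, B-TRANSFORM and
+    I-TRANSFORM are shorthand for the tag labels built from the constants module):
+        input words : ['The', 'movie', 'came', 'out', 'in', '2005', '.']
+        tag preds   : [SAME, SAME, SAME, SAME, SAME, B-TRANSFORM, PUNCT]
+    Only the tokens tagged as part of a TRANSFORM span (here, '2005') are sent to the decoder.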
+ """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer, add_prefix_space=True) + super().__init__(cfg=cfg, trainer=trainer) + self.num_labels = len(constants.ALL_TAG_LABELS) + self.model = AutoModelForTokenClassification.from_pretrained(cfg.transformer, num_labels=self.num_labels) + + # Loss Functions + self.loss_fct = nn.CrossEntropyLoss(ignore_index=constants.LABEL_PAD_TOKEN_ID) + + # setup to track metrics + self.classification_report = ClassificationReport(self.num_labels, mode='micro', dist_sync_on_step=True) + + # Training + def training_step(self, batch, batch_idx): + """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ + num_labels = self.num_labels + + # Apply Transformer + tag_logits = self.model(batch['input_ids'], batch['attention_mask']).logits + + # Loss + train_loss = self.loss_fct(tag_logits.view(-1, num_labels), batch['labels'].view(-1)) + + lr = self._optimizer.param_groups[0]['lr'] + self.log('train_loss', train_loss) + self.log('lr', lr, prog_bar=True) + return {'loss': train_loss, 'lr': lr} + + # Validation and Testing + def validation_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop with the data from the validation dataloader + passed in as `batch`. + """ + # Apply Transformer + tag_logits = self.model(batch['input_ids'], batch['attention_mask']).logits + tag_preds = torch.argmax(tag_logits, dim=2) + + # Update classification_report + predictions, labels = tag_preds.tolist(), batch['labels'].tolist() + for prediction, label in zip(predictions, labels): + cur_preds = [p for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID] + cur_labels = [l for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID] + self.classification_report( + torch.tensor(cur_preds).to(self.device), torch.tensor(cur_labels).to(self.device) + ) + + def validation_epoch_end(self, outputs): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. + """ + # calculate metrics and classification report + precision, _, _, report = self.classification_report.compute() + + logging.info(report) + + self.log('val_token_precision', precision) + + self.classification_report.reset() + + def test_step(self, batch, batch_idx): + """ + Lightning calls this inside the test loop with the data from the test dataloader + passed in as `batch`. + """ + return self.validation_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + """ + Called at the end of test to aggregate outputs. + :param outputs: list of individual outputs of each test step. + """ + return self.validation_epoch_end(outputs) + + # Functions for inference + @torch.no_grad() + def _infer(self, sents: List[List[str]], inst_directions: List[str]): + """ Main function for Inference + Args: + sents: A list of inputs tokenized by a basic tokenizer (e.g., using nltk.word_tokenize()). + inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN). + + Returns: + all_tag_preds: A list of list where each list contains the raw tag predictions for the corresponding input. + nb_spans: A list of ints where each int indicates the number of semiotic spans in each input. 
+            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
+            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
+        """
+        self.eval()
+
+        # Append the direction-specific prefix to each input
+        texts = []
+        for ix, sent in enumerate(sents):
+            if inst_directions[ix] == constants.INST_BACKWARD:
+                prefix = constants.ITN_PREFIX
+            if inst_directions[ix] == constants.INST_FORWARD:
+                prefix = constants.TN_PREFIX
+            texts.append([prefix] + sent)
+
+        # Apply the model
+        encodings = self._tokenizer(
+            texts, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt'
+        )
+        logits = self.model(**encodings.to(self.device)).logits
+        pred_indexes = torch.argmax(logits, dim=-1).tolist()
+
+        # Extract all_tag_preds
+        all_tag_preds = []
+        batch_size, max_len = encodings['input_ids'].size()
+        for ix in range(batch_size):
+            raw_tag_preds = [constants.ALL_TAG_LABELS[p] for p in pred_indexes[ix][1:]]
+            tag_preds, previous_word_idx = [], None
+            word_ids = encodings.word_ids(batch_index=ix)
+            for jx, word_idx in enumerate(word_ids):
+                if word_idx is None:
+                    continue
+                elif word_idx != previous_word_idx:
+                    tag_preds.append(raw_tag_preds[jx - 1])
+                previous_word_idx = word_idx
+            tag_preds = tag_preds[1:]
+            all_tag_preds.append(tag_preds)
+
+        # Postprocessing
+        all_tag_preds = [
+            self.postprocess_tag_preds(words, inst_dir, ps)
+            for words, inst_dir, ps in zip(sents, inst_directions, all_tag_preds)
+        ]
+
+        # Decoding
+        nb_spans, span_starts, span_ends = self.decode_tag_preds(all_tag_preds)
+
+        return all_tag_preds, nb_spans, span_starts, span_ends
+
+    def postprocess_tag_preds(self, words, inst_dir, preds):
+        """ Function for postprocessing the raw tag predictions of the model. It
+        corrects obvious mistakes in the tag predictions such as a TRANSFORM span
+        that starts with I_TRANSFORM_TAG (instead of B_TRANSFORM_TAG).
+
+        Args:
+            words: The words in the input text
+            inst_dir: The direction of the instance (i.e., INST_BACKWARD or INST_FORWARD).
+            preds: The raw tag predictions
+
+        Returns: The processed tag predictions
+        """
+        final_preds = []
+        for ix, p in enumerate(preds):
+            # a TRANSFORM span starts with I_TRANSFORM_TAG
+            if p == constants.I_PREFIX + constants.TRANSFORM_TAG:
+                if ix == 0 or (not constants.TRANSFORM_TAG in final_preds[ix - 1]):
+                    final_preds.append(constants.B_PREFIX + constants.TRANSFORM_TAG)
+                    continue
+            # a span has numbers but does not have TRANSFORM tags (for TN)
+            if inst_dir == constants.INST_FORWARD:
+                if has_numbers(words[ix]) and (not constants.TRANSFORM_TAG in p):
+                    final_preds.append(constants.B_PREFIX + constants.TRANSFORM_TAG)
+                    continue
+            # Default
+            final_preds.append(p)
+        return final_preds
+
+    def decode_tag_preds(self, tag_preds):
+        """ Decodes the raw tag predictions to locate the semiotic spans in the
+        input texts.
+
+        Args:
+            tag_preds: A list of lists where each list contains the raw tag predictions for the corresponding input.
+
+        Returns:
+            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
+            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
+            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
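+
+        Example (illustrative; SAME, PUNCT, B-TRANSFORM and I-TRANSFORM stand for the labels
+        built from the constants module, e.g., constants.B_PREFIX + constants.TRANSFORM_TAG):
+            tag_preds = [[SAME, B-TRANSFORM, I-TRANSFORM, PUNCT]]
+            -> nb_spans = [1], span_starts = [[1]], span_ends = [[2]]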
+ """ + nb_spans, span_starts, span_ends = [], [], [] + for i, preds in enumerate(tag_preds): + cur_nb_spans, cur_span_start = 0, None + cur_span_starts, cur_span_ends = [], [] + for ix, pred in enumerate(preds + ['EOS']): + if pred != constants.I_PREFIX + constants.TRANSFORM_TAG: + if not cur_span_start is None: + cur_nb_spans += 1 + cur_span_starts.append(cur_span_start) + cur_span_ends.append(ix - 1) + cur_span_start = None + if pred == constants.B_PREFIX + constants.TRANSFORM_TAG: + cur_span_start = ix + nb_spans.append(cur_nb_spans) + span_starts.append(cur_span_starts) + span_ends.append(cur_span_ends) + + return nb_spans, span_starts, span_ends + + # Functions for processing data + def setup_training_data(self, train_data_config: Optional[DictConfig]): + if not train_data_config or not train_data_config.data_path: + logging.info( + f"Dataloader config or file_path for the train is missing, so no data loader for train is created!" + ) + self._train_dl = None + return + self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, mode="train") + + def setup_validation_data(self, val_data_config: Optional[DictConfig]): + if not val_data_config or not val_data_config.data_path: + logging.info( + f"Dataloader config or file_path for the validation is missing, so no data loader for validation is created!" + ) + self._validation_dl = None + return + self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, mode="val") + + def setup_test_data(self, test_data_config: Optional[DictConfig]): + if not test_data_config or test_data_config.data_path is None: + logging.info( + f"Dataloader config or file_path for the test is missing, so no data loader for test is created!" + ) + self._test_dl = None + return + self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config, mode="test") + + def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str): + start_time = perf_counter() + logging.info(f'Creating {mode} dataset') + input_file = cfg.data_path + dataset = TextNormalizationTaggerDataset( + input_file, + self._tokenizer, + cfg.mode, + cfg.get('do_basic_tokenize', False), + cfg.get('tagger_data_augmentation', False), + ) + data_collator = DataCollatorForTokenClassification(self._tokenizer) + dl = torch.utils.data.DataLoader( + dataset=dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle, collate_fn=data_collator, + ) + running_time = perf_counter() - start_time + logging.info(f'Took {running_time} seconds') + return dl + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. + """ + result = [] + return result diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py new file mode 100644 index 000000000000..7b2b23c6cfef --- /dev/null +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py @@ -0,0 +1,206 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from math import ceil +from time import perf_counter +from typing import List + +import numpy as np +import torch.nn as nn +from nltk import word_tokenize +from tqdm import tqdm +from transformers import * + +import nemo.collections.nlp.data.text_normalization.constants as constants +from nemo.collections.nlp.data.text_normalization import TextNormalizationTestDataset +from nemo.collections.nlp.models.duplex_text_normalization.utils import get_formatted_string +from nemo.utils import logging +from nemo.utils.decorators.experimental import experimental + +__all__ = ['DuplexTextNormalizationModel'] + + +@experimental +class DuplexTextNormalizationModel(nn.Module): + """ + DuplexTextNormalizationModel is a wrapper class that can be used to + encapsulate a trained tagger and a trained decoder. The class is intended + to be used for inference only (e.g., for evaluation). + """ + + def __init__(self, tagger, decoder): + super(DuplexTextNormalizationModel, self).__init__() + + self.tagger = tagger + self.decoder = decoder + + def evaluate( + self, dataset: TextNormalizationTestDataset, batch_size: int, errors_log_fp: str, verbose: bool = True + ): + """ Function for evaluating the performance of the model on a dataset + + Args: + dataset: The dataset to be used for evaluation. + batch_size: Batch size to use during inference. You can set it to be 1 + (no batching) if you want to measure the running time of the model + per individual example (assuming requests are coming to the model one-by-one). 
+ errors_log_fp: Path to the file for logging the errors + verbose: if true prints and logs various evaluation results + + Returns: + results: A Dict containing the evaluation results (e.g., accuracy, running time) + """ + results = {} + error_f = open(errors_log_fp, 'w+') + + # Apply the model on the dataset + all_run_times, all_dirs, all_inputs = [], [], [] + all_tag_preds, all_final_preds, all_targets = [], [], [] + nb_iters = int(ceil(len(dataset) / batch_size)) + for i in tqdm(range(nb_iters)): + start_idx = i * batch_size + end_idx = (i + 1) * batch_size + batch_insts = dataset[start_idx:end_idx] + batch_dirs, batch_inputs, batch_targets = zip(*batch_insts) + # Inference and Running Time Measurement + batch_start_time = perf_counter() + batch_tag_preds, _, batch_final_preds = self._infer(batch_inputs, batch_dirs) + batch_run_time = (perf_counter() - batch_start_time) * 1000 # milliseconds + all_run_times.append(batch_run_time) + # Update all_dirs, all_inputs, all_tag_preds, all_final_preds and all_targets + all_dirs.extend(batch_dirs) + all_inputs.extend(batch_inputs) + all_tag_preds.extend(batch_tag_preds) + all_final_preds.extend(batch_final_preds) + all_targets.extend(batch_targets) + + # Metrics + tn_error_ctx, itn_error_ctx = 0, 0 + for direction in constants.INST_DIRECTIONS: + cur_dirs, cur_inputs, cur_tag_preds, cur_final_preds, cur_targets = [], [], [], [], [] + for dir, _input, tag_pred, final_pred, target in zip( + all_dirs, all_inputs, all_tag_preds, all_final_preds, all_targets + ): + if dir == direction: + cur_dirs.append(dir) + cur_inputs.append(_input) + cur_tag_preds.append(tag_pred) + cur_final_preds.append(final_pred) + cur_targets.append(target) + nb_instances = len(cur_final_preds) + sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(cur_final_preds, cur_targets, cur_dirs) + if verbose: + logging.info(f'\n============ Direction {direction} ============') + logging.info(f'Sentence Accuracy: {sent_accuracy}') + logging.info(f'nb_instances: {nb_instances}') + # Update results + results[direction] = {'sent_accuracy': sent_accuracy, 'nb_instances': nb_instances} + # Write errors to log file + for _input, tag_pred, final_pred, target in zip(cur_inputs, cur_tag_preds, cur_final_preds, cur_targets): + if not TextNormalizationTestDataset.is_same(final_pred, target, direction): + if direction == constants.INST_BACKWARD: + error_f.write('Backward Problem (ITN)\n') + itn_error_ctx += 1 + elif direction == constants.INST_FORWARD: + error_f.write('Forward Problem (TN)\n') + tn_error_ctx += 1 + formatted_input_str = get_formatted_string(_input.split(' ')) + formatted_tag_pred_str = get_formatted_string(tag_pred) + error_f.write(f'Original Input : {_input}\n') + error_f.write(f'Input : {formatted_input_str}\n') + error_f.write(f'Predicted Tags : {formatted_tag_pred_str}\n') + error_f.write(f'Predicted : {final_pred}\n') + error_f.write(f'Ground-Truth : {target}\n') + error_f.write('\n') + results['itn_error_ctx'] = itn_error_ctx + results['tn_error_ctx'] = tn_error_ctx + + # Running Time + avg_running_time = np.average(all_run_times) / batch_size # in ms + if verbose: + logging.info(f'Average running time (normalized by batch size): {avg_running_time} ms') + results['running_time'] = avg_running_time + + # Close log file + error_f.close() + + return results + + # Functions for inference + def _infer(self, sents: List[str], inst_directions: List[str]): + """ Main function for Inference + Args: + sents: A list of input texts. 
+            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).
+
+        Returns:
+            tag_preds: A list of lists where each list contains the tag predictions from the tagger for an input text.
+            output_spans: A list of lists where each list contains the decoded semiotic spans from the decoder for an input text.
+            final_outputs: A list of str where each str is the final output text for an input text.
+        """
+        # Preprocessing
+        sents = self.input_preprocessing(list(sents))
+
+        # Tagging and span decoding
+        tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(sents, inst_directions)
+        output_spans = self.decoder._infer(sents, nb_spans, span_starts, span_ends, inst_directions)
+
+        # Prepare final outputs
+        final_outputs = []
+        for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
+            cur_words, jx, span_idx = [], 0, 0
+            cur_spans = output_spans[ix]
+            while jx < len(sent):
+                tag, word = tags[jx], sent[jx]
+                if constants.SAME_TAG in tag:
+                    cur_words.append(word)
+                    jx += 1
+                elif constants.PUNCT_TAG in tag:
+                    jx += 1
+                else:
+                    jx += 1
+                    cur_words.append(cur_spans[span_idx])
+                    span_idx += 1
+                    while jx < len(sent) and tags[jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
+                        jx += 1
+            cur_output_str = ' '.join(cur_words)
+            cur_output_str = ' '.join(word_tokenize(cur_output_str))
+            final_outputs.append(cur_output_str)
+        return tag_preds, output_spans, final_outputs
+
+    def input_preprocessing(self, sents):
+        """ Function for preprocessing the input texts. The function replaces some
+        special symbols (e.g., '+', '=') with their spoken forms, tokenizes the inputs
+        with nltk.word_tokenize(), and maps Greek letters such as Δ or λ (if any) to their spoken forms.
+
+        Args:
+            sents: A list of input texts.
+
+        Returns: A list of preprocessed input texts.
+        """
+        # Basic Preprocessing and Tokenization
+        for ix, sent in enumerate(sents):
+            sents[ix] = sents[ix].replace('+', ' plus ')
+            sents[ix] = sents[ix].replace('=', ' equals ')
+            sents[ix] = sents[ix].replace('@', ' at ')
+            sents[ix] = sents[ix].replace('*', ' times ')
+        sents = [word_tokenize(sent) for sent in sents]
+
+        # Greek letters processing
+        for ix, sent in enumerate(sents):
+            for jx, tok in enumerate(sent):
+                if tok in constants.GREEK_TO_SPOKEN:
+                    sents[ix][jx] = constants.GREEK_TO_SPOKEN[tok]
+
+        return sents
diff --git a/nemo/collections/nlp/models/duplex_text_normalization/utils.py b/nemo/collections/nlp/models/duplex_text_normalization/utils.py
new file mode 100644
index 000000000000..7506443658b4
--- /dev/null
+++ b/nemo/collections/nlp/models/duplex_text_normalization/utils.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +__all__ = ['is_url', 'has_numbers'] + + +def is_url(input_str: str): + """ Check if a string is a URL """ + url_segments = ['www', 'http', '.org', '.com', '.tv'] + return any(segment in input_str for segment in url_segments) + + +def has_numbers(input_str: str): + """ Check if a string has a number character """ + return any(char.isdigit() for char in input_str) + + +def get_formatted_string(strs, str_max_len=10, space_len=2): + """ Get a nicely formatted string from a list of strings""" + padded_strs = [] + for cur_str in strs: + cur_str = cur_str + ' ' * (str_max_len - len(cur_str)) + padded_strs.append(cur_str[:str_max_len]) + + spaces = ' ' * space_len + return spaces.join(padded_strs) diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index b2a0d4b4f723..fcd05bd0517e 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -11,3 +11,5 @@ megatron-lm==2.2.0 inflect sacrebleu[ja] sacremoses>=0.0.43 +nltk==3.6.2 +wordninja==2.0.0
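
Taken together, the components above form an end-to-end inference pipeline: the tagger locates semiotic spans, the decoder verbalizes (or de-verbalizes) them, and ``DuplexTextNormalizationModel`` stitches the results back into full sentences. The sketch below is illustrative only; it assumes a tagger and a decoder have already been trained with the example training script and saved to the hypothetical paths ``tagger.nemo`` and ``decoder.nemo``.

.. code:: python

    import nemo.collections.nlp.data.text_normalization.constants as constants
    from nemo.collections.nlp.models import (
        DuplexDecoderModel,
        DuplexTaggerModel,
        DuplexTextNormalizationModel,
    )

    # Restore the two trained components (paths are hypothetical).
    tagger = DuplexTaggerModel.restore_from('tagger.nemo')
    decoder = DuplexDecoderModel.restore_from('decoder.nemo')

    # Wrap them for end-to-end inference.
    tn_model = DuplexTextNormalizationModel(tagger, decoder)

    # Run text normalization (TN) on a single sentence.
    sents = ['The movie came out in 2005 .']
    directions = [constants.INST_FORWARD]
    tag_preds, output_spans, final_outputs = tn_model._infer(sents, directions)
    # final_outputs[0] might read 'The movie came out in two thousand five'

For dataset-level evaluation, ``DuplexTextNormalizationModel.evaluate()`` (shown earlier in this patch) is the intended entry point; the snippet above only demonstrates how the pieces compose for ad-hoc inference.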