Commit

Fixed an error related to the task indicator of the tagger during inference. (#2627)

* A minor fix to _infer() of DuplexTaggerModel
Signed-off-by: Tuan Lai <tuanl@nvidia.com>

* Remove tagger data augmentation
Signed-off-by: Tuan Lai <tuanl@nvidia.com>

* Minor fixes to cache path
Signed-off-by: Tuan Lai <tuanl@nvidia.com>

* Style fix
Signed-off-by: Tuan Lai <tuanl@nvidia.com>
Tuan Manh Lai committed Aug 9, 2021
1 parent 241855c commit 7e6197d
Showing 5 changed files with 10 additions and 35 deletions.
15 changes: 1 addition & 14 deletions docs/source/nlp/text_normalization.rst
@@ -130,20 +130,7 @@ then we will simply append the prefix ``tn`` to it and so the final input to our
be ``tn I live in tn 123 King Ave``. Similarly, for the ITN problem, we just append the prefix ``itn``
to the input.

To improve the effectiveness and robustness of our models, we also apply some simple data
augmentation techniques during training.

Data Augmentation for Training DuplexTaggerModel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the Google English TN training data, about 93% of the tokens are not in any semiotic span. In other words, the ground-truth tags of most tokens are of trivial types (i.e., ``SAME`` and ``PUNCT``). To alleviate this class imbalance problem,
for each original instance with several semiotic spans, we create a new instance by simply concatenating all the semiotic spans together. For example, consider the following ITN instance:

Original instance: ``[The|SAME] [revenues|SAME] [grew|SAME] [a|SAME] [lot|SAME] [between|SAME] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [two|I-TRANSFORM] [and|SAME] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [five|I-TRANSFORM] [.|PUNCT]``

Augmented instance: ``[two|B-TRANSFORM] [thousand|I-TRANSFORM] [two|I-TRANSFORM] [two|B-TRANSFORM] [thousand|I-TRANSFORM] [five|I-TRANSFORM]``

The argument ``data.train_ds.tagger_data_augmentation`` in the config file controls whether this data augmentation will be enabled or not.
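
For reference, here is a minimal stand-alone sketch of the augmentation described above (the helper name, and the use of plain tag strings instead of NeMo's constants, are illustrative, not the actual implementation):

# Illustrative sketch of the tagger augmentation: build a second training
# instance that keeps only tokens inside semiotic spans.
TRIVIAL_TAGS = {"SAME", "PUNCT"}  # trivial tag names taken from the example above

def augment_instance(tokens, tags):
    """Return (tokens, tags) restricted to semiotic-span tokens, or None if too short."""
    kept = [(w, t) for w, t in zip(tokens, tags) if t not in TRIVIAL_TAGS]
    if len(kept) < 2:  # skip instances with at most one span token
        return None
    new_tokens, new_tags = zip(*kept)
    return list(new_tokens), list(new_tags)

tokens = ["The", "revenues", "grew", "a", "lot", "between",
          "two", "thousand", "two", "and", "two", "thousand", "five", "."]
tags = ["SAME"] * 6 + ["B-TRANSFORM", "I-TRANSFORM", "I-TRANSFORM", "SAME",
                       "B-TRANSFORM", "I-TRANSFORM", "I-TRANSFORM", "PUNCT"]
print(augment_instance(tokens, tags))
# -> (['two', 'thousand', 'two', 'two', 'thousand', 'five'], ['B-TRANSFORM', ...])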

To improve the robustness of the decoder, we also apply a simple data augmentation technique when training it.

Data Augmentation for Training DuplexDecoderModel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -123,9 +123,7 @@ data:
max_decoder_len: 80
mode: ${mode}
max_insts: -1 # Maximum number of instances (-1 means no limit)
# Refer to the text_normalization doc for more information about data augmentation
tagger_data_augmentation: true
decoder_data_augmentation: true
decoder_data_augmentation: true # Refer to the text_normalization doc for more information about data augmentation
use_cache: ${data.use_cache}

validation_ds:
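
After this change only the decoder flag remains in the ``data.train_ds`` block. As a hedged sketch (the keys mirror the snippet above; the code itself is illustrative, not part of this commit), missing keys such as the removed tagger flag simply fall back to a default when the config is read:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'mode': 'joint',
    'max_insts': -1,
    'decoder_data_augmentation': True,
    'use_cache': False,
})
decoder_aug = cfg.get('decoder_data_augmentation', False)  # True
tagger_aug = cfg.get('tagger_data_augmentation', False)    # False -- key no longer present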
@@ -72,7 +72,10 @@ def __init__(

# Get cache path
data_dir, filename = os.path.split(input_file)
cached_data_file = os.path.join(data_dir, f'cached_decoder_{filename}_{tokenizer_name}_{lang}_{max_insts}.pkl')
tokenizer_name_normalized = tokenizer_name.replace('/', '_')
cached_data_file = os.path.join(
data_dir, f'cached_decoder_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}.pkl'
)

if use_cache and os.path.exists(cached_data_file):
logging.warning(
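
The cache-path change above replaces any '/' in the tokenizer name with '_', since Hugging Face identifiers containing a slash would otherwise introduce a spurious sub-directory into the pickle file name. A small stand-alone sketch of the same idea (the function name and the example tokenizer name are illustrative):

import os

def build_cache_path(input_file, tokenizer_name, lang, max_insts, prefix='cached_decoder'):
    # Slashes in tokenizer names would be interpreted as directory separators,
    # so they are normalized to underscores before building the file name.
    data_dir, filename = os.path.split(input_file)
    tokenizer_name_normalized = tokenizer_name.replace('/', '_')
    return os.path.join(data_dir, f'{prefix}_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}.pkl')

print(build_cache_path('data/train.tsv', 'google/byt5-small', 'en', -1))
# data/cached_decoder_train.tsv_google_byt5-small_en_-1.pkl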
17 changes: 4 additions & 13 deletions nemo/collections/nlp/data/text_normalization/tagger_dataset.py
@@ -42,7 +42,6 @@ class TextNormalizationTaggerDataset(Dataset):
tokenizer_name: name of the tokenizer,
mode: should be one of the values ['tn', 'itn', 'joint']. `tn` mode is for TN only. `itn` mode is for ITN only. `joint` is for training a system that can do both TN and ITN at the same time.
do_basic_tokenize: a flag indicating whether to do some basic tokenization before using the tokenizer of the model
tagger_data_augmentation (bool): a flag indicating whether to augment the dataset with additional data instances
lang: language of the dataset
use_cache: Enables caching to use pickle format to store and read data from,
max_insts: Maximum number of instances (-1 means no limit)
@@ -55,7 +54,6 @@ def __init__(
tokenizer_name: str,
mode: str,
do_basic_tokenize: bool,
tagger_data_augmentation: bool,
lang: str,
use_cache: bool = False,
max_insts: int = -1,
@@ -69,7 +67,10 @@ def __init__(

# Get cache path
data_dir, filename = os.path.split(input_file)
cached_data_file = os.path.join(data_dir, f'cached_tagger_{filename}_{tokenizer_name}_{lang}_{max_insts}.pkl')
tokenizer_name_normalized = tokenizer_name.replace('/', '_')
cached_data_file = os.path.join(
data_dir, f'cached_tagger_{filename}_{tokenizer_name_normalized}_{lang}_{max_insts}.pkl'
)

if use_cache and os.path.exists(cached_data_file):
logging.warning(
@@ -96,16 +97,6 @@ def __init__(
# Create a new TaggerDataInstance
inst = TaggerDataInstance(w_words, s_words, inst_dir, do_basic_tokenize)
insts.append(inst)
# Data Augmentation (if enabled)
if tagger_data_augmentation:
filtered_w_words, filtered_s_words = [], []
for ix, (w, s) in enumerate(zip(w_words, s_words)):
if not s in constants.SPECIAL_WORDS:
filtered_w_words.append(w)
filtered_s_words.append(s)
if len(filtered_s_words) > 1:
inst = TaggerDataInstance(filtered_w_words, filtered_s_words, inst_dir)
insts.append(inst)

self.insts = insts
texts = [inst.input_words for inst in insts]
@@ -153,8 +153,6 @@ def _infer(self, sents: List[List[str]], inst_directions: List[str]):
texts.append([prefix] + sent)

# Apply the model
prefix = constants.TN_PREFIX
texts = [[prefix] + sent for sent in sents]
encodings = self._tokenizer(
texts, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt'
)
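
For context, the two deleted lines rebuilt ``texts`` with a hard-coded TN prefix after the per-sentence prefixes had already been attached, so ITN inputs lost their task indicator. A simplified, paraphrased sketch of the intended behaviour (the prefix values follow the docs above; the direction label and helper name are illustrative, not NeMo's constants):

# Paraphrased sketch, not the verbatim NeMo code.
TN_PREFIX, ITN_PREFIX = 'tn', 'itn'  # task-indicator tokens described in the docs
ITN_DIRECTION = 'itn'                # illustrative direction label

def build_tagger_inputs(sents, inst_directions):
    """Prepend the task indicator that matches each sentence's direction."""
    texts = []
    for sent, direction in zip(sents, inst_directions):
        prefix = ITN_PREFIX if direction == ITN_DIRECTION else TN_PREFIX
        texts.append([prefix] + sent)  # keep the per-sentence task indicator
    return texts
# The removed lines then replaced every indicator with the TN one,
# which is the bug this commit fixes.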
@@ -283,14 +281,12 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str):
start_time = perf_counter()
logging.info(f'Creating {mode} dataset')
input_file = cfg.data_path
tagger_data_augmentation = cfg.get('tagger_data_augmentation', False)
dataset = TextNormalizationTaggerDataset(
input_file,
self._tokenizer,
self.transformer_name,
cfg.mode,
cfg.do_basic_tokenize,
tagger_data_augmentation,
cfg.lang,
cfg.get('use_cache', False),
cfg.get('max_insts', -1),
