Compute token precision of tagger using NeMo metrics
Signed-off-by: Tuan Lai <tuanl@nvidia.com>
Tuan Lai committed Jul 2, 2021
1 parent 6f47f2b commit 76359fa
Showing 2 changed files with 18 additions and 37 deletions.
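In outline, the change swaps the hand-rolled sentence-accuracy bookkeeping for NeMo's ClassificationReport metric: the report is built once in __init__, updated with the non-padded token predictions during each validation step, and aggregated, logged, and reset at the end of the validation epoch. A minimal sketch of that lifecycle, using only the ClassificationReport calls visible in the diff below (the label count and token ids here are made up):

```python
import torch
from nemo.collections.nlp.metrics.classification_report import ClassificationReport

# Built once (e.g. in __init__); micro mode pools counts over all tag labels.
num_labels = 3  # hypothetical tag vocabulary size
classification_report = ClassificationReport(num_labels, mode='micro', dist_sync_on_step=True)

# Per validation step: feed token-level predictions and labels (padding already removed).
cur_preds = [0, 2, 1]   # hypothetical predicted tag ids
cur_labels = [0, 1, 1]  # hypothetical gold tag ids
classification_report(torch.tensor(cur_preds), torch.tensor(cur_labels))

# At epoch end: aggregate, log the precision, and clear the accumulated counts.
precision, _, _, report = classification_report.compute()
print(report)
classification_report.reset()
```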
@@ -33,7 +33,7 @@ tagger_model:
name: WarmupAnnealing

# pytorch lightning args
-monitor: val_sentence_accuracy
+monitor: val_token_precision
reduce_on_plateau: false

# scheduler config override
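The renamed monitor key only has an effect if a metric with the same name is actually logged from the model, which is what the second changed file below provides. In NeMo this block is consumed by the library's own scheduler setup; as a generic illustration of how a 'monitor' key is typically consumed in plain PyTorch Lightning (the class, layer, and optimizer here are illustrative, not taken from this repository):

```python
import torch
import pytorch_lightning as pl


class TaggerSketch(pl.LightningModule):
    """Hypothetical module; only the 'monitor' wiring is of interest here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                # Must match a metric name logged with self.log(...) during validation,
                # here the new 'val_token_precision'.
                'monitor': 'val_token_precision',
            },
        }
```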
@@ -31,6 +31,7 @@

from nemo.utils.decorators.experimental import experimental
from nemo.collections.nlp.models.duplex_text_normalization.utils import has_numbers
+from nemo.collections.nlp.metrics.classification_report import ClassificationReport
from nemo.collections.nlp.data.text_normalization import TextNormalizationTaggerDataset

__all__ = ['DuplexTaggerModel']
@@ -51,6 +52,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Loss Functions
self.loss_fct = nn.CrossEntropyLoss(ignore_index=constants.LABEL_PAD_TOKEN_ID)

+# setup to track metrics
+self.classification_report = ClassificationReport(
+    self.num_labels, mode='micro', dist_sync_on_step=True
+)

# Training
def training_step(self, batch, batch_idx):
"""
@@ -92,54 +98,29 @@ def validation_step(self, batch, batch_idx):

# Loss
tag_labels = batch['labels'].to(self.device)
val_loss = self.loss_fct(tag_logits.view(-1, num_labels),
tag_labels.view(-1))

-# Extract batch_predictions and batch_labels
+# Update classification_report
predictions, labels = tag_preds.tolist(), tag_labels.tolist()
-final_predictions = [
-    [constants.ALL_TAG_LABELS[p] for (p, l) in zip(prediction, label) \
-        if l != constants.LABEL_PAD_TOKEN_ID]
-    for prediction, label in zip(predictions, labels)
-]
-final_labels = [
-    [constants.ALL_TAG_LABELS[l] for (p, l) in zip(prediction, label) \
-        if l != constants.LABEL_PAD_TOKEN_ID]
-    for prediction, label in zip(predictions, labels)
-]

-# Compute sent_count and sent_correct
-sent_count, sent_correct = 0, 0
-for p, l in zip(final_predictions, final_labels):
-    sent_correct += int(p==l)
-    sent_count += 1
+for prediction, label in zip(predictions, labels):
+    cur_preds = [p for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID]
+    cur_labels = [l for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID]
+    self.classification_report(torch.tensor(cur_preds).to(self.device),
+                               torch.tensor(cur_labels).to(self.device))

return {
'val_loss': val_loss,
-'sent_count': torch.tensor(sent_count),
-'sent_correct': torch.tensor(sent_correct)
}

def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs.
:param outputs: list of individual outputs of each validation step.
"""
+# calculate metrics and classification report
+precision, _, _, report = self.classification_report.compute()

# Average loss
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
self.log('val_loss', avg_loss)
+logging.info(report)

-# Sentence Accuracy
-sent_correct = int(torch.stack([x['sent_correct'] for x in outputs]).sum())
-sent_count = int(torch.stack([x['sent_count'] for x in outputs]).sum())
-sent_accuracy = sent_correct / sent_count
-self.log('val_sentence_accuracy', sent_accuracy)
+self.log('val_token_precision', precision)

-return {
-    'val_loss': avg_loss,
-    'val_sentence_accuracy': sent_accuracy
-}
+self.classification_report.reset()

def test_step(self, batch, batch_idx):
"""
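Finally, the per-step update hinges on dropping positions whose gold label is the padding id before the metric sees them, mirroring the ignore_index used by the cross-entropy loss. A self-contained sketch of that filtering; the pad id of -100 is an assumption here, not read from the constants module:

```python
import torch

LABEL_PAD_TOKEN_ID = -100  # assumed value, mirroring CrossEntropyLoss(ignore_index=...)

prediction = [2, 0, 1, 1, 0]  # token-level tag ids from the tagger
label = [2, 0, LABEL_PAD_TOKEN_ID, 1, LABEL_PAD_TOKEN_ID]  # padded positions carry the pad id

# Keep only positions with a real gold label, exactly as in validation_step above.
cur_preds = [p for (p, l) in zip(prediction, label) if l != LABEL_PAD_TOKEN_ID]
cur_labels = [l for (p, l) in zip(prediction, label) if l != LABEL_PAD_TOKEN_ID]

assert cur_preds == [2, 0, 1] and cur_labels == [2, 0, 1]
# These trimmed lists are what get wrapped in tensors and fed to the report.
preds_t, labels_t = torch.tensor(cur_preds), torch.tensor(cur_labels)
```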
