
Add perplexity metric to LanguageModel (#2548)

- Add Perplexity metric
- Call Perplexity metric with `average_loss` from `LanguageModel`
- Add Perplexity docs
thesamuel authored and nelson-liu committed Apr 23, 2019
1 parent 4422d53 commit e2d153f89c40e331a2bb072866f12f6d482ecf72
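For background on the quantity being computed (not part of the diff): perplexity is the exponential of the average per-token negative log likelihood, so a model that spreads probability uniformly over a vocabulary of `V` words scores a perplexity of exactly `V`. A minimal sketch, with a hypothetical vocabulary size:

```python
import math

# Hypothetical vocabulary size, chosen only for illustration.
vocab_size = 10000

# A uniform model assigns probability 1/V to every word, so its average
# negative log likelihood (base e) is ln(V) per token.
average_nll = math.log(vocab_size)

# Exponentiating the average loss recovers the perplexity: exactly V here.
print(math.exp(average_nll))  # 10000.0
```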
allennlp/models/language_model.py
@@ -11,6 +11,7 @@
 from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
 from allennlp.nn.util import get_text_field_mask
 from allennlp.nn import InitializerApplicator
+from allennlp.training.metrics import Perplexity


 class _SoftmaxLoss(torch.nn.Module):
@@ -127,9 +128,11 @@ def __init__(self,
             self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
                                               embedding_dim=self._forward_dim)

-        # TODO(brendanr): Output perplexity here. e^loss
+        # This buffer is now unused and exists only for backwards compatibility reasons.
         self.register_buffer('_last_average_loss', torch.zeros(1))

+        self._perplexity = Perplexity()
+
         if dropout:
             self._dropout = torch.nn.Dropout(dropout)
         else:
@@ -300,8 +303,8 @@ def forward(self,  # type: ignore
             average_loss = forward_loss / num_targets.float()
         else:
             average_loss = torch.tensor(0.0).to(forward_targets.device)  # pylint: disable=not-callable
-        # this is stored to compute perplexity if needed
-        self._last_average_loss[0] = average_loss.detach().item()
+
+        self._perplexity(average_loss)

         if num_targets > 0:
             return_dict.update({
@@ -327,3 +330,6 @@ def forward(self,  # type: ignore
             })

         return return_dict
+
+    def get_metrics(self, reset: bool = False):
+        return {"perplexity": self._perplexity.get_metric(reset=reset)}
allennlp/training/metrics/__init__.py
@@ -18,6 +18,7 @@
 from allennlp.training.metrics.mention_recall import MentionRecall
 from allennlp.training.metrics.metric import Metric
 from allennlp.training.metrics.pearson_correlation import PearsonCorrelation
+from allennlp.training.metrics.perplexity import Perplexity
 from allennlp.training.metrics.sequence_accuracy import SequenceAccuracy
 from allennlp.training.metrics.span_based_f1_measure import SpanBasedF1Measure
 from allennlp.training.metrics.squad_em_and_f1 import SquadEmAndF1
allennlp/training/metrics/perplexity.py
@@ -0,0 +1,32 @@
+from overrides import overrides
+import torch
+
+from allennlp.training.metrics.average import Average
+from allennlp.training.metrics.metric import Metric
+
+
+@Metric.register("perplexity")
+class Perplexity(Average):
+    """
+    Perplexity is a common metric used for evaluating how well a language model
+    predicts a sample.
+
+    Notes
+    -----
+    Assumes negative log likelihood loss of each batch (base e). Provides the
+    average perplexity of the batches.
+    """
+
+    @overrides
+    def get_metric(self, reset: bool = False) -> float:
+        """
+        Returns
+        -------
+        The accumulated perplexity.
+        """
+        average_loss = super().get_metric(reset)
+        if average_loss == 0:
+            return 0.
+
+        # Exponentiate the loss to compute perplexity
+        return float(torch.exp(average_loss))
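Two properties of the metric are easy to miss and worth checking by hand: a freshly constructed (or just reset) instance reports `0.` rather than `e**0 = 1`, because the inherited `Average` yields 0 when nothing has been recorded, and the batch losses are averaged *before* exponentiation. A small sketch with hypothetical loss values:

```python
import torch
from allennlp.training.metrics import Perplexity

metric = Perplexity()
print(metric.get_metric())  # 0.0 -- no losses recorded yet, so the guard fires

# The metric averages the recorded losses before exponentiating,
# i.e. exp(mean(losses)) rather than mean(exp(losses)).
metric(torch.tensor(2.0))
metric(torch.tensor(4.0))
print(metric.get_metric())  # exp(3.0) ≈ 20.09, not (e**2 + e**4) / 2 ≈ 31.0
```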
doc/api/allennlp.training.metrics.rst
@@ -116,6 +116,12 @@ allennlp.training.metrics
    :undoc-members:
    :show-inheritance:

+.. _perplexity:
+.. automodule:: allennlp.training.metrics.perplexity
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 .. _sequence-accuracy:
 .. automodule:: allennlp.training.metrics.sequence_accuracy
    :members:
