Add ROUGE scores to default metrics.
PiperOrigin-RevId: 445206099
iftenney authored and LIT team committed Apr 28, 2022
1 parent bb06368 · commit 6773927
Showing 3 changed files with 128 additions and 13 deletions.
1 change: 1 addition & 0 deletions lit_nlp/app.py
@@ -477,6 +477,7 @@ def __init__(
         'multiclass': metrics.MulticlassMetrics(),
         'paired': metrics.MulticlassPairedMetrics(),
         'bleu': metrics.CorpusBLEU(),
+        'rouge': metrics.RougeL(),
     })
     gradient_map_interpreters = {
         'Grad L2 Norm': gradient_maps.GradientNorm(),
40 changes: 40 additions & 0 deletions lit_nlp/components/metrics.py
@@ -31,6 +31,7 @@
 from scipy.spatial import distance as scipy_distance
 from sklearn import metrics as sklearn_metrics
 
+from rouge_score import rouge_scorer
 JsonDict = types.JsonDict
 IndexedInput = types.IndexedInput
 Spec = types.Spec
@@ -468,6 +469,45 @@ def compute(self,
     return {'corpus_bleu' + name_suffix: bleu.score}
 
 
+class RougeL(SimpleMetrics):
+  """RougeL score for generation tasks."""
+
+  def __init__(self, *args, **kw):
+    super().__init__(*args, **kw)
+    self._scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+  def _score(self, reference, prediction):
+    return self._scorer.score(
+        target=reference, prediction=prediction)['rougeL'].fmeasure
+
+  def is_compatible(self, field_spec: types.LitType) -> bool:
+    """Return true if compatible with this field."""
+    return isinstance(field_spec,
+                      (types.GeneratedText, types.GeneratedTextCandidates))
+
+  def compute(self,
+              labels: Sequence[Text],
+              preds: Sequence[Union[Text, types.ScoredTextCandidates]],
+              label_spec: types.TextSegment,
+              pred_spec: Union[types.GeneratedText,
+                               types.GeneratedTextCandidates],
+              config: Optional[JsonDict] = None) -> Dict[Text, float]:
+    """Compute metric(s) between labels and predictions."""
+    del label_spec
+    del config
+
+    if not labels or not preds:
+      return {}
+
+    name_suffix = ''
+    if isinstance(pred_spec, types.GeneratedTextCandidates):
+      preds = [types.GeneratedTextCandidates.top_text(v) for v in preds]
+      name_suffix = '@1'
+    scores = list(map(self._score, labels, preds))
+
+    return {'rougeL' + name_suffix: np.mean(scores)}
+
+
 class BinaryConfusionMetricsImpl(SimpleMetrics):
   """Confusion matrix values for binary classification."""
 
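For reference, here is a minimal standalone sketch of how the new RougeL component behaves in isolation. This is a sketch only: it assumes lit_nlp and the rouge-score package are installed, the import paths follow the repository layout above, and the example inputs are invented for illustration.

from lit_nlp.api import types as lit_types
from lit_nlp.components import metrics

rouge = metrics.RougeL()

# GeneratedText predictions: one ROUGE-L F-measure per (label, prediction)
# pair, averaged into a single 'rougeL' entry. rouge_scorer lowercases,
# strips punctuation, and (with use_stemmer=True) stems before matching.
print(rouge.compute(
    labels=['A third test'],
    preds=['A third test example'],
    label_spec=lit_types.TextSegment(),
    pred_spec=lit_types.GeneratedText()))
# -> {'rougeL': 0.857...} (precision 3/4, recall 3/3)

# GeneratedTextCandidates predictions: only the first (top) candidate is
# scored, and the metric key gets an '@1' suffix.
print(rouge.compute(
    labels=['A third test'],
    preds=[[('A third test', -1.0), ('unrelated text', -20.0)]],
    label_spec=lit_types.TextSegment(),
    pred_spec=lit_types.GeneratedTextCandidates()))
# -> {'rougeL@1': 1.0}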
100 changes: 87 additions & 13 deletions lit_nlp/components/metrics_test.py
@@ -211,45 +211,119 @@ def test_compute(self):
 class CorpusBLEUTest(absltest.TestCase):
 
   def test_is_compatible(self):
-    corpusblue_metrics = metrics.CorpusBLEU()
+    bleu_metrics = metrics.CorpusBLEU()
+
+    # Only compatible with generation types.
+    self.assertTrue(bleu_metrics.is_compatible(types.GeneratedText()))
+    self.assertTrue(bleu_metrics.is_compatible(types.GeneratedTextCandidates()))
 
-    # Only compatible with GeneratedText spec.
-    self.assertTrue(corpusblue_metrics.is_compatible(types.GeneratedText()))
     self.assertFalse(
-        corpusblue_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
-    self.assertFalse(corpusblue_metrics.is_compatible(types.RegressionScore()))
+        bleu_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
+    self.assertFalse(bleu_metrics.is_compatible(types.RegressionScore()))
 
   def test_compute(self):
-    corpusblue_metrics = metrics.CorpusBLEU()
+    bleu_metrics = metrics.CorpusBLEU()
 
     # All correct predictions.
-    result = corpusblue_metrics.compute(
+    result = bleu_metrics.compute(
         ['This is a test.', 'Test two', 'A third test example'],
         ['This is a test.', 'Test two', 'A third test example'],
         types.GeneratedText(), types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result,
-                                           {'corpus_bleu': 100.00000})
+                                           {'corpus_bleu': 100.0000})
 
     # Some incorrect predictions.
-    result = corpusblue_metrics.compute(
+    result = bleu_metrics.compute(
         ['This is a test.', 'Test one', 'A third test'],
         ['This is a test.', 'Test two', 'A third test example'],
         types.GeneratedText(), types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result,
                                            {'corpus_bleu': 68.037493})
 
-    result = corpusblue_metrics.compute(
+    result = bleu_metrics.compute(
         ['This is a test.', 'Test one', 'A third test'],
         ['these test.', 'Test two', 'A third test example'],
         types.GeneratedText(), types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result,
-                                           {'corpus_bleu': 29.508062388758525})
+                                           {'corpus_bleu': 29.508062})
 
+    # Empty labels and predictions
+    result = bleu_metrics.compute([], [], types.GeneratedText(),
+                                  types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {})
+
+  def test_compute_with_candidates(self):
+    bleu_metrics = metrics.CorpusBLEU()
+
+    # Should only score the first one (@1).
+    labels = ['This is a test.', 'Test two']
+    preds = [
+        [('This is a test.', -1.0), ('foobar', -20.0)],
+        [('Test two', -1.0), ('spam', -20.0)],
+    ]
+
+    result = bleu_metrics.compute(labels, preds, types.TextSegment(),
+                                  types.GeneratedTextCandidates())
+    testing_utils.assert_deep_almost_equal(self, result,
+                                           {'corpus_bleu@1': 100.0000})
+
+
+class RougeLTest(absltest.TestCase):
+
+  def test_is_compatible(self):
+    rouge_metrics = metrics.RougeL()
+
+    # Only compatible with generation types.
+    self.assertTrue(rouge_metrics.is_compatible(types.GeneratedText()))
+    self.assertTrue(
+        rouge_metrics.is_compatible(types.GeneratedTextCandidates()))
+
+    self.assertFalse(
+        rouge_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
+    self.assertFalse(rouge_metrics.is_compatible(types.RegressionScore()))
+
+  def test_compute(self):
+    rouge_metrics = metrics.RougeL()
+
+    # All correct predictions.
+    result = rouge_metrics.compute(
+        ['This is a test.', 'Test two', 'A third test example'],
+        ['This is a test.', 'Test two', 'A third test example'],
+        types.TextSegment(), types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL': 1.0})
+
+    # Some incorrect predictions.
+    result = rouge_metrics.compute(
+        ['This is a test.', 'Test one', 'A third test'],
+        ['This is a test.', 'Test two', 'A third test example'],
+        types.TextSegment(), types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL': 0.785714})
+
+    result = rouge_metrics.compute(
+        ['This is a test.', 'Test one', 'A third test'],
+        ['these test.', 'Test two', 'A third test example'],
+        types.TextSegment(), types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL': 0.563492})
+
     # Empty labels and predictions
-    result = corpusblue_metrics.compute([], [], types.GeneratedText(),
-                                        types.GeneratedText())
+    result = rouge_metrics.compute([], [], types.GeneratedText(),
+                                   types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result, {})
+
+  def test_compute_with_candidates(self):
+    rouge_metrics = metrics.RougeL()
+
+    # Should only score the first one (@1).
+    labels = ['This is a test.', 'Test two']
+    preds = [
+        [('This is a test.', -1.0), ('foobar', -20.0)],
+        [('Test two', -1.0), ('spam', -20.0)],
+    ]
+
+    result = rouge_metrics.compute(labels, preds, types.TextSegment(),
+                                   types.GeneratedTextCandidates())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL@1': 1.0})
 
 
 class ClassifcationMarginTest(absltest.TestCase):
 
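As a quick sanity check on the expected value 0.785714 asserted in RougeLTest.test_compute above (given how rouge_scorer tokenizes — lowercasing and punctuation stripping — plus the use_stemmer=True setting in RougeL.__init__), the per-example ROUGE-L F-measures are: 1.0 for the identical first pair; 0.5 for 'Test one' vs. 'Test two' (one of two tokens in common, so precision = recall = 1/2); and 6/7 ≈ 0.857143 for 'A third test' vs. 'A third test example' (precision 3/4, recall 3/3). Their mean is (1.0 + 0.5 + 0.857143) / 3 ≈ 0.785714, which matches the test.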
