add Rouge score #399

Merged

Commits (26)
3a23b94
Adding new metric ROUGE Metric for text
karthikrangasai Jul 23, 2021
a398b6f
Added tests for the ROUGE metric
karthikrangasai Jul 23, 2021
bca2f12
Updated docs and imports, added types
karthikrangasai Jul 23, 2021
586e82a
Apply suggestions from code review
Borda Jul 23, 2021
5169561
Applied changes suggested in code review
karthikrangasai Jul 23, 2021
2dfcd9f
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 25, 2021
7e9fed1
Updated text dependencies and CHANGELOG
karthikrangasai Jul 25, 2021
8027712
Fix typing issues
karthikrangasai Jul 25, 2021
067c7f0
Updated docs dependencies
karthikrangasai Jul 25, 2021
cba098a
pkg
Borda Jul 26, 2021
506cfc7
pkg
Borda Jul 26, 2021
b0f28bd
set jiwer
Borda Jul 26, 2021
5b5872b
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 26, 2021
d5782fe
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 27, 2021
49f40dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2021
a0ef6e7
Simplified the implementation for batches and added more tests.
karthikrangasai Jul 27, 2021
e9ecc62
Updated docs requirements and removed unused imports.
karthikrangasai Jul 27, 2021
8fc553a
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 28, 2021
8d2e102
Merge branch 'master' into feature/51_add_rouge_score
Borda Jul 28, 2021
36e3a7d
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 28, 2021
af2c103
Fix typing, rigorously check rouge_keys, add tests for rouge_keys err…
karthikrangasai Jul 28, 2021
34f8394
Remove unused imports
karthikrangasai Jul 28, 2021
5fd1160
Apply suggestions from code review
Borda Jul 28, 2021
1439497
Fixed typing and added docstrings for update and compute method
karthikrangasai Jul 29, 2021
71d17f4
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 29, 2021
218418f
Changes based on review
karthikrangasai Jul 29, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))

- Added ROUGE Metric ([#399](https://github.com/PyTorchLightning/metrics/issues/399))

- Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))

4 changes: 4 additions & 0 deletions docs/source/references/functional.rst
@@ -355,6 +355,10 @@ bleu_score [func]
.. autofunction:: torchmetrics.functional.bleu_score
:noindex:

rouge_score [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.rouge_score
    :noindex:

wer [func]
~~~~~~~~~~
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -517,6 +517,12 @@ BLEUScore
.. autoclass:: torchmetrics.BLEUScore
:noindex:

ROUGEScore
~~~~~~~~~~

.. autoclass:: torchmetrics.ROUGEScore
:noindex:


WER
~~~
3 changes: 3 additions & 0 deletions requirements/docs.txt
@@ -13,3 +13,6 @@ sphinx-copybutton>=0.3

# integrations
pytorch-lightning>=1.1

# add extra requirements
-r text.txt
1 change: 0 additions & 1 deletion requirements/test.txt
@@ -15,7 +15,6 @@ phmdoctest>=1.1.1
cloudpickle>=1.3
scikit-learn>=0.24
scikit-image>0.17.1
nltk>=3.6

# add extra requirements
-r image.txt
4 changes: 3 additions & 1 deletion requirements/text.txt
@@ -1 +1,3 @@
jiwer==2.2.0
jiwer>=2.2.0
nltk>=3.6
rouge-score>=0.0.4
51 changes: 51 additions & 0 deletions tests/text/test_rouge.py
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from torch import tensor

from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE

PREDS = "My name is John".split()
TARGET = "Is your name John".split()


@pytest.mark.skipif(not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE), reason='test requires nltk and rouge-score')
@pytest.mark.parametrize("rouge_metric, expected", [("rouge1_recall", 0.25)])
def test_rouge_metric_functional(rouge_metric, expected):
pl_output = tensor(rouge_score(PREDS, TARGET)[rouge_metric]).float()
assert torch.allclose(pl_output, tensor(expected), 1e-4)


@pytest.mark.skipif(not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE), reason='test requires nltk and rouge-score')
@pytest.mark.parametrize("rouge_metric, expected", [("rouge1_recall", 0.25)])
def test_rouge_metric_class(rouge_metric, expected):
rouge = ROUGEScore()
pl_output = tensor(rouge(PREDS, TARGET)[rouge_metric]).float()
assert torch.allclose(pl_output, tensor(expected), 1e-4)


def test_rouge_metric_raises_errors_and_warnings():
    """Test that expected warnings and errors are raised."""
    if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
        with pytest.raises(
            ValueError,
            match='ROUGE metric requires that both nltk and rouge-score are installed.'
            ' Either as `pip install torchmetrics\\[text\\]` or `pip install nltk rouge-score`'
        ):
            ROUGEScore()
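
For reference, a minimal sketch of exercising the class-based metric over several batches (this assumes ROUGEScore follows the standard torchmetrics forward/compute accumulation API, consistent with the tests above):

    from torchmetrics.text.rouge import ROUGEScore

    rouge = ROUGEScore()
    # forward() updates the internal state and returns scores for the current inputs
    rouge("My name is John".split(), "Is your name John".split())
    rouge("The cat sat".split(), "A cat sat".split())
    # compute() aggregates over everything seen so far
    print(rouge.compute()["rouge1_fmeasure"])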
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -60,5 +60,5 @@
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.text import WER, BLEUScore # noqa: E402, F401
from torchmetrics.text import WER, BLEUScore, ROUGEScore # noqa: E402, F401
from torchmetrics.wrappers import BootStrapper # noqa: E402, F401
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -58,4 +58,5 @@
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
from torchmetrics.functional.text.rouge import rouge_score # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
158 changes: 158 additions & 0 deletions torchmetrics/functional/text/rouge.py
@@ -0,0 +1,158 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List, Tuple

import numpy as np
from torch import Tensor, tensor

from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE

if _ROUGE_SCORE_AVAILABLE:
from rouge_score.rouge_scorer import RougeScorer
from rouge_score.scoring import AggregateScore, BootstrapAggregator, Score
else:
RougeScorer, AggregateScore, Score, BootstrapAggregator = object, object, object, object


def add_newline_to_end_of_each_sentence(x: str) -> str:
    """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
    assert _NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
    import nltk
    nltk.download("punkt", quiet=True, force=False)
    x = re.sub("<n>", "", x)  # remove pegasus newline char
    return "\n".join(nltk.sent_tokenize(x))


def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, Tensor]:
    """Flatten the aggregated scores into a `{rouge_key}_{statistic}` to scalar-tensor mapping."""
    flattened_result: Dict[str, Tensor] = {}
for rouge_key, rouge_aggregate_score in result.items():
for stat in ["precision", "recall", "fmeasure"]:
mid = rouge_aggregate_score.mid
score = round(getattr(mid, stat), decimal_places)
flattened_result[f"{rouge_key}_{stat}"] = tensor(score)
return flattened_result
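
# Illustrative sketch of the flattening (`Score` and `AggregateScore` are
# namedtuples from `rouge_score.scoring`, so this assumes rouge-score is installed):
#   mid = Score(precision=0.25, recall=0.25, fmeasure=0.25)
#   format_rouge_results({"rouge1": AggregateScore(low=mid, mid=mid, high=mid)})
#   returns {'rouge1_precision': tensor(0.2500),
#            'rouge1_recall': tensor(0.2500),
#            'rouge1_fmeasure': tensor(0.2500)}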


class RougeBatchAggregator(BootstrapAggregator):
"""
Aggregates rouge scores and provides confidence intervals.
"""

def aggregate(self) -> Dict[str, AggregateScore]:
"""
Override function to wrap the final results in `Score` objects.
This is due to the scores being replaced with a list of torch tensors.
"""
result = {}
for score_type, scores in self._scores.items():
# Stack scores into a 2-d matrix of (sample, measure).
score_matrix = np.vstack(tuple(scores))
# Percentiles are returned as (interval, measure).
percentiles = self._bootstrap_resample(score_matrix)
# Extract the three intervals (low, mid, high).
intervals = tuple(Score(*percentiles[j, :]) for j in range(3))
result[score_type] = AggregateScore(low=intervals[0], mid=intervals[1], high=intervals[2])
return result

def add_scores(self, scores: Dict[str, List[Tensor]]) -> None:
self._scores = scores


def _rouge_score_update(
preds: List[str],
targets: List[str],
scores: Dict[str, List[Tensor]],
scorer: RougeScorer,
newline_sep: bool = False,
) -> None:
    """Update the ``scores`` dictionary in-place with a (precision, recall, fmeasure) tensor per sample."""
for pred, target in zip(preds, targets):
# rougeLsum expects "\n" separated sentences within a summary
if newline_sep:
pred = add_newline_to_end_of_each_sentence(pred)
target = add_newline_to_end_of_each_sentence(target)
results = scorer.score(pred, target)
for key, score in results.items():
score = tensor([score.precision, score.recall, score.fmeasure])
scores[key].append(score)


def _rouge_score_compute(scores: Dict[str, List[Tensor]], aggregator: RougeBatchAggregator) -> Dict[str, Tensor]:
aggregator.add_scores(scores)
result = aggregator.aggregate()
return format_rouge_results(result)
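
# Taken together, these two helpers form the update/compute pipeline that
# `rouge_score` wires up below. A minimal sketch (assuming rouge-score is installed):
#   scorer = RougeScorer(("rouge1",), use_stemmer=False)
#   scores = {"rouge1": []}
#   _rouge_score_update(["My name is John"], ["Is your name John"], scores=scores, scorer=scorer)
#   _rouge_score_compute(scores, aggregator=RougeBatchAggregator())
#   returns {'rouge1_precision': tensor(...), 'rouge1_recall': tensor(...), 'rouge1_fmeasure': tensor(...)}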


def rouge_score(
preds: List[str],
targets: List[str],
newline_sep: bool = False,
use_stemmer: bool = False,
    rouge_keys: Tuple[str, ...] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
) -> Dict[str, Tensor]:
"""
Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.

Args:
preds:
An iterable of predicted sentences.
targets:
An iterable of target sentences.
        newline_sep:
            Insert a newline between sentences in each input; ``rougeLsum`` expects sentences to be newline-separated.
use_stemmer:
Use Porter stemmer to strip word suffixes to improve matching.
rouge_keys:
A list of rouge types to calculate.

Return:
Python dictionary of rouge scores for each input rouge key.

Example:
>>> targets = "Is your name John".split()
>>> preds = "My name is John".split()
>>> from pprint import pprint
>>> pprint(rouge_score(preds, targets)) # doctest: +NORMALIZE_WHITESPACE +SKIP
{'rouge1_fmeasure': 0.25,
'rouge1_precision': 0.25,
'rouge1_recall': 0.25,
'rouge2_fmeasure': 0.0,
'rouge2_precision': 0.0,
'rouge2_recall': 0.0,
'rougeL_fmeasure': 0.25,
'rougeL_precision': 0.25,
'rougeL_recall': 0.25,
'rougeLsum_fmeasure': 0.25,
'rougeLsum_precision': 0.25,
'rougeLsum_recall': 0.25}

References:
[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin https://aclanthology.org/W04-1013/
"""

    if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
        raise ValueError(
            'ROUGE metric requires that both nltk and rouge-score are installed.'
            ' Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`'
        )

aggregator = RougeBatchAggregator()
scorer = RougeScorer(rouge_keys, use_stemmer=use_stemmer)
scores: Dict[str, List[Tensor]] = {key: [] for key in rouge_keys}

_rouge_score_update(preds, targets, scores=scores, scorer=scorer, newline_sep=newline_sep)
return _rouge_score_compute(scores, aggregator=aggregator)
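
As a usage note, restricting `rouge_keys` limits the computation to the ROUGE variants you actually need; a short sketch against the functional API added above:

    from torchmetrics.functional import rouge_score

    preds = ["My name is John"]
    targets = ["Is your name John"]
    # compute only ROUGE-1 and ROUGE-L for this pair
    scores = rouge_score(preds, targets, rouge_keys=("rouge1", "rougeL"))
    print(scores["rouge1_fmeasure"], scores["rougeL_fmeasure"])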
1 change: 1 addition & 0 deletions torchmetrics/text/__init__.py
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.text.bleu import BLEUScore # noqa: F401
from torchmetrics.text.rouge import ROUGEScore # noqa: F401
from torchmetrics.text.wer import WER # noqa: F401