evaluation refactoring #23

Merged · merged 20 commits · Jun 17, 2021
Changes from 11 commits
2 changes: 1 addition & 1 deletion __init__.py
@@ -1,3 +1,3 @@
 import SummerTime.model
 import SummerTime.dataset.stdatasets as data
-import SummerTime.eval
+import SummerTime.evaluation
176 changes: 83 additions & 93 deletions demo.ipynb

Large diffs are not rendered by default.

11 changes: 0 additions & 11 deletions eval/Metric.py

This file was deleted.

4 changes: 0 additions & 4 deletions eval/__init__.py

This file was deleted.

12 changes: 0 additions & 12 deletions eval/bertscore.py

This file was deleted.

12 changes: 0 additions & 12 deletions eval/bleu.py

This file was deleted.

17 changes: 0 additions & 17 deletions eval/rouge.py

This file was deleted.

15 changes: 0 additions & 15 deletions eval/rougewe.py

This file was deleted.

6 changes: 6 additions & 0 deletions evaluation/__init__.py
@@ -0,0 +1,6 @@
from .rouge_metric import Rouge
from .bertscore_metric import BertScore
from .rougewe_metric import RougeWe
from .bleu_metric import Bleu

SUPPORTED_EVALUATION_METRICS = [BertScore, RougeWe, Bleu]
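For context, a minimal usage sketch (not part of the diff) of the registry defined above; it assumes the package is importable as evaluation, as in the tests added later in this PR, and that summ_eval is installed so the metric modules import cleanly.

from evaluation import SUPPORTED_EVALUATION_METRICS

# Print the metadata each registered metric class declares on itself.
for metric_class in SUPPORTED_EVALUATION_METRICS:
    print(metric_class.metric_name, metric_class.range, metric_class.higher_is_better)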
22 changes: 22 additions & 0 deletions evaluation/base_metric.py
@@ -0,0 +1,22 @@
from typing import List, Tuple, Dict

class SummMetric():
metric_name: str = None
range: Tuple[float, float] = None
higher_is_better: bool = None
low_resource: bool = None

def evaluate(self,
## TODO zhangir: integrate with dataset api
inputs: List[str],
targets: List[str],
keys: List[str]) -> Dict[str, float]:
"""
All metrics must implement this method.
:param inputs: A list of generated summaries.
:param targets: A list of target summaries corresponding to each entry of inputs.
:param keys: Which scores to return,
e.g. ['rouge_1_f_score', 'rouge_2_f_score']
:return: A dictionary mapping each requested key to its score.
"""
raise NotImplementedError("the base class for metrics shouldn't be instantiated!")
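For orientation, a minimal sketch (not part of this PR) of what a concrete subclass of SummMetric could look like; ExactMatch and its key name are hypothetical, chosen only to illustrate the contract without an external backend.

from typing import List, Dict
from SummerTime.evaluation.base_metric import SummMetric

class ExactMatch(SummMetric):
    metric_name = "exact match"  # hypothetical metric, for illustration only
    range = (0, 1)               # inclusive on both ends
    higher_is_better = True
    low_resource = True

    def evaluate(self,
                 inputs: List[str],
                 targets: List[str],
                 keys: List[str] = ["exact_match"]) -> Dict[str, float]:
        # Fraction of generated summaries identical to their targets.
        score = sum(i == t for i, t in zip(inputs, targets)) / len(inputs)
        return {key: score for key in keys}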
20 changes: 20 additions & 0 deletions evaluation/bertscore_metric.py
@@ -0,0 +1,20 @@
from summ_eval.bert_score_metric import BertScoreMetric
from SummerTime.evaluation.summeval_metric import SummEvalMetric
from typing import List, Dict

class BertScore(SummEvalMetric):
metric_name = 'bert score'
range = (0, 1)
Reviewer comment: This is a nice design. I would also add a comment on top of the range variable noting whether it's inclusive or exclusive on the boundaries.

higher_is_better = True
low_resource = False

def __init__(self):
se_metric = BertScoreMetric()
super(BertScore, self).__init__(se_metric)

def evaluate(self,
inputs: List[str],
targets: List[str],
keys: List[str] = ['bert_score_f1']) -> Dict[str, float]:
#TODO zhangir: update when datasets api is merged
return super(BertScore, self).evaluate(inputs, targets, keys)
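A usage sketch for the class above (illustrative, not part of the diff); it assumes summ_eval and its bert-score dependency are installed, and the default key 'bert_score_f1' comes from the signature above.

from SummerTime.evaluation import BertScore

metric = BertScore()
scores = metric.evaluate(
    inputs=["The lights from Yankee Stadium failed to sell at auction."],
    targets=["An auction for the Yankee Stadium lights produced no bids."],
)
print(scores)  # e.g. {'bert_score_f1': 0.87}; the exact value depends on the model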
20 changes: 20 additions & 0 deletions evaluation/bleu_metric.py
@@ -0,0 +1,20 @@
from summ_eval.bleu_metric import BleuMetric
from SummerTime.evaluation.summeval_metric import SummEvalMetric
from typing import List, Dict

class Bleu(SummEvalMetric):
metric_name = 'bleu'
range = (0, 10)
higher_is_better = True
low_resource = True

def __init__(self):
se_metric = BleuMetric()
super(Bleu, self).__init__(se_metric)

def evaluate(self,
inputs: List[str],
targets: List[str],
keys: List[str] = ['bleu']) -> Dict[str, float]:
# TODO zhangir: potentially update when dataset api is merged.
return super(Bleu, self).evaluate(inputs, targets, keys)
19 changes: 19 additions & 0 deletions evaluation/rouge_metric.py
@@ -0,0 +1,19 @@
from summ_eval.rouge_metric import RougeMetric
from SummerTime.evaluation.summeval_metric import SummEvalMetric
from typing import List, Dict

class Rouge(SummEvalMetric):
metric_name = 'rouge'
range = (0, 1)
higher_is_better = True
low_resource = True

def __init__(self):
se_metric = RougeMetric()
super(Rouge, self).__init__(se_metric)

def evaluate(self,
inputs: List[str],
targets: List[str],
keys: List[str] = ['rouge_3_f_score']) -> Dict[str, float]:
return super(Rouge, self).evaluate(inputs, targets, keys)
22 changes: 22 additions & 0 deletions evaluation/rougewe_metric.py
@@ -0,0 +1,22 @@
from summ_eval.rouge_we_metric import RougeWeMetric
from SummerTime.evaluation.summeval_metric import SummEvalMetric
from typing import List, Dict
import nltk

class RougeWe(SummEvalMetric):
metric_name = 'rougeWE'
range = (0, 1)
higher_is_better = True
low_resource = False

def __init__(self):
nltk.download('stopwords')
se_metric = RougeWeMetric()
super(RougeWe, self).__init__(se_metric)

def evaluate(self,
inputs: List[str],
targets: List[str],
keys: List[str] = ['rouge_we_3_f']) -> Dict[str, float]:
#TODO zhangir: update when dataset api is merged.
return super(RougeWe, self).evaluate(inputs, targets, keys)
19 changes: 19 additions & 0 deletions evaluation/summeval_metric.py
@@ -0,0 +1,19 @@
from typing import List, Dict

from .base_metric import SummMetric
from summ_eval.metric import Metric as SEMetric

class SummEvalMetric(SummMetric):
"""
Generic class for a summarization metric whose backend is SummEval.
"""

def __init__(self,
se_metric: SEMetric):
self.se_metric = se_metric

def evaluate(self,
inputs: List[str],
targets: List[str],
keys: List[str]) -> Dict[str, float]:
score_dict = self.se_metric.evaluate_batch(
inputs, targets)
return {key: score_dict[key] for key in keys}
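An illustrative sketch (not part of the PR) of the key-filtering behaviour implemented above; FakeMetric is a hypothetical stand-in for a real summ_eval metric so the example avoids running a heavy backend.

from SummerTime.evaluation.summeval_metric import SummEvalMetric

class FakeMetric:
    # Mimics summ_eval's evaluate_batch interface with canned scores.
    def evaluate_batch(self, inputs, targets):
        return {"score_a": 0.5, "score_b": 0.9}

wrapper = SummEvalMetric(FakeMetric())
# Only the requested keys survive the dict comprehension in evaluate().
print(wrapper.evaluate(["hypothesis"], ["reference"], keys=["score_a"]))  # {'score_a': 0.5}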
1 change: 0 additions & 1 deletion summertime_pkg/README.md

This file was deleted.

41 changes: 41 additions & 0 deletions tests/evaluation_test.py
@@ -0,0 +1,41 @@
import unittest
from typing import Tuple, List, Dict

from evaluation import SUPPORTED_EVALUATION_METRICS

class TestEvaluationMetrics(unittest.TestCase):
def get_summary_pair(self, size: int = 1) -> Tuple[List[str], List[str]]:
test_output = [ """
Glowing letters that had been hanging above
the Yankee stadium from 1976 to 2008 were placed for auction at
Sotheby’s on Wednesday, but were not sold. The current owner
of the sign is Reggie Jackson, a Yankee hall-of-famer."""]
test_target = ["""
An auction for the lights from Yankee Stadium failed to
produce any bids on Wednesday at Sotheby’s. The lights,
currently owned by former Yankees player Reggie Jackson,
lit the stadium from 1976 until 2008."""]

return test_output, test_target


def test_evaluate(self):
print(f"{'#'*10} test_evaluate STARTS {'#'*10}")

for metric_class in SUPPORTED_EVALUATION_METRICS:
print(f"Test on {metric_class}")
metric = metric_class()

test_output, test_target = self.get_summary_pair()
score_dict = metric.evaluate(test_output, test_target)
print(f"{metric_class} output dictionary")
print(score_dict)
self.assertIsInstance(score_dict, dict)
self.assertNotEqual(score_dict, {})
for key in score_dict:
self.assertTrue(metric.range[0] <= score_dict[key])
self.assertTrue(score_dict[key] <= metric.range[1])


if __name__ == '__main__':
unittest.main()