Skip to content

Commit

Permalink
Allowing multi-file submissions
Browse files Browse the repository at this point in the history
- Individual datasets keyed under XXX_val.
- Submission must have "submission_name" in its keys
  • Loading branch information
tuetschek committed Jan 12, 2021
1 parent 1e8319e commit b0fb6aa
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 39 deletions.
27 changes: 26 additions & 1 deletion gem_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@


# Data holder classes
from .texts import Predictions, References
from .texts import Predictions, References, Submission
from typing import Optional

# Metric implementations
from .meteor import Meteor
Expand All @@ -15,3 +16,27 @@
# TODO make this populate automatically based on imports
REFERENCED_METRICS = [BLEU, Meteor, ROUGE]
REFERENCELESS_METRICS = [MSTTR, NGramStats]


def compute(outs: Predictions, refs: Optional[References]) -> dict:
    """Main metrics computation routine.

    Runs every referenceless metric over the given system outputs; when a
    References object is supplied, additionally runs every reference-based
    metric against it (pass None to compute referenceless metrics only).
    Returns a dict mapping metric names to their computed values, plus the
    input filenames and the number of instances under 'N'.
    """
    results = {'predictions_file': outs.filename,
               'N': len(outs)}

    # referenceless metrics always apply
    for metric_cls in REFERENCELESS_METRICS:
        results.update(metric_cls().compute(outs))

    # without references there is nothing more to do
    if refs is None:
        return results

    results['references_file'] = refs.filename
    for metric_cls in REFERENCED_METRICS:
        results.update(metric_cls().compute(outs, refs))
    return results
53 changes: 42 additions & 11 deletions gem_metrics/texts.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
#!/usr/bin/env python3

import nltk
from typing import Optional
import json
import string
import nltk
from .nltk_data import nltk_ensure_download


class Texts:
"""Holder class for output texts or references."""

def __init__(self, key, data_file, tokenize_func):
self.filename = data_file
def __init__(self, key, data, tokenize_func):
self.key = key
# TODO allow other data formats?
with open(data_file, 'r', encoding='UTF-8') as fh:
self.all_data = json.load(fh)
self.data = [item[key] for item in self.all_data]
if not isinstance(data, dict):
self.filename = data
# TODO allow other data formats?
with open(data, 'r', encoding='UTF-8') as fh:
data = json.load(fh)
else:
self.filename = data.get('filename')
self.all_data = data['values']
self.language = data['language']

self.data = [item[key] for item in self.all_data]

# detect if we're using multiple texts per instance
self.multi_ref = isinstance(self.data[0], list)
Expand Down Expand Up @@ -53,9 +60,9 @@ class Predictions(Texts):

PUNCTUATION = set(string.punctuation)

def __init__(self, data_file):
def __init__(self, data):
nltk_ensure_download('tokenizers/punkt')
super().__init__(key='generated', data_file=data_file, tokenize_func=nltk.tokenize.word_tokenize)
super().__init__(key='generated', data=data, tokenize_func=nltk.tokenize.word_tokenize)
self._lc_tokenized = [[w.lower() for w in item] for item in self.list_tokenized]
self._nopunct_lc_tokenized = [[w for w in item if w not in self.PUNCTUATION] for item in self._lc_tokenized]

Expand All @@ -73,6 +80,30 @@ def list_tokenized_lower_nopunct(self):
class References(Texts):
    """Data holder class for references/targets."""

    def __init__(self, data):
        """Load references from a filename or a pre-parsed dict (see Texts).

        Per-instance reference texts are read from the 'target' key of each
        item; multiple references per instance (lists) are supported by Texts.
        """
        # make sure the punkt model needed by word_tokenize is present
        nltk_ensure_download('tokenizers/punkt')
        super().__init__(key='target', data=data, tokenize_func=nltk.tokenize.word_tokenize)

class Submission:
    """Data holder class for multi-dataset submissions.

    Accepts either an already-parsed dict or a path to a JSON file. The
    submission data must contain a "submission_name" key; individual
    datasets' outputs are expected under keys ending in "_val" and are
    wrapped in Predictions objects keyed by the dataset name (suffix
    stripped).
    """

    def __init__(self, data):
        if isinstance(data, dict):
            # pre-parsed submission data; a filename may be carried inside it
            self.filename = data.get('filename')
            self.all_data = data
        else:
            # treat the argument as a path and load the JSON from disk
            self.filename = data
            with open(data, 'r', encoding='UTF-8') as fh:
                self.all_data = json.load(fh)
        # read metadata from the parsed data, not the raw argument --
        # when a filename was passed, `data` is still the path string
        self.name = self.all_data['submission_name']
        self.param_count = self.all_data.get('param_count')
        # build per-dataset Predictions from all "XXX_val" entries
        self.entries = {}
        for key in self.all_data.keys():
            if not key.endswith('_val'):
                continue
            dataset_name = key[:-len('_val')]
            self.entries[dataset_name] = Predictions(self.all_data[key])

    def predictions_for(self, dataset_name: str) -> Optional['Predictions']:
        """Return per-dataset predictions, or None if the dataset is absent."""
        return self.entries.get(dataset_name)
34 changes: 7 additions & 27 deletions run_metrics.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,29 @@
#!/usr/bin/env python3

from argparse import ArgumentParser
from typing import Optional
import json
import sys

import gem_metrics


def compute_metrics(outs: gem_metrics.Predictions, refs: Optional[gem_metrics.References]) -> dict:
"""Main metrics computation routine. Expects a Predictions and a References object, holding
system outputs and corresponding references (References may be None -- only referenceless metrics
are computed in such a case).
Returns a dict with the results.
"""
# initialize values storage
values = {'predictions_file': outs.filename,
'N': len(outs)}

# compute referenceless metrics
for metric_class in gem_metrics.REFERENCELESS_METRICS:
metric = metric_class()
values.update(metric.compute(outs))

# compute ref-based metrics
if args.references_file is not None:
values['references_file'] = refs.filename
for metric_class in gem_metrics.REFERENCED_METRICS:
metric = metric_class()
values.update(metric.compute(outs, refs))
return values


def main(args):
"""Main entry point -- load inputs, call metrics measuring, print outputs"""

# load system predictions
outs = gem_metrics.Predictions(args.predictions_file)
with open(args.predictions_file, encoding='UTF-8') as fh:
data = json.load(fh)
if isinstance(data, dict) and 'submission_name' in data:
outs = gem_metrics.Submission(data)
else:
outs = gem_metrics.Predictions(data)

# load references, if available
if args.references_file is not None:
refs = gem_metrics.References(args.references_file)
assert(len(refs) == len(outs))

values = compute_metrics(outs, refs)
values = gem_metrics.compute(outs, refs)

# print output
out_fh = sys.stdout
Expand Down
11 changes: 11 additions & 0 deletions test_data/small-outs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"language": "en",
"values": [
{
"generated": "Alimentum is located in the city centre. It is not family-friendly."
},
{
"generated": "Alimentum is a non family-friendly restaurant near Burger King in the city centre."
}
]
}
20 changes: 20 additions & 0 deletions test_data/small-refs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"language": "en",
"values": [
{
"target": [
"There is a place in the city centre, Alimentum, that is not family-friendly.",
"In the city centre there is a venue name Alimentum, this is not a family-friendly venue.",
"Alimentum in city centre is not a family-friendly place."
]
},
{
"target": [
"Alimentum is not family-friendly, and is near the Burger King in the city centre.",
"Near Burger King in city centre is the adult establishment Alimentum.",
"Alimentum is not family-friendly. Alimentum is in the city center and it is near Burger King.",
"Alimentum is an adult establish found in the city centre area near Burger King."
]
}
]
}

0 comments on commit b0fb6aa

Please sign in to comment.