Skip to content

Commit

Permalink
Initial support for references in multi-file mode
Browse files Browse the repository at this point in the history
  • Loading branch information
tuetschek committed Jan 12, 2021
1 parent b0fb6aa commit 440dc8c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
5 changes: 5 additions & 0 deletions gem_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ def compute(outs: Predictions, refs: Optional[References]) -> dict:
metric = metric_class()
values.update(metric.compute(outs, refs))
return values


def load_references(dataset_name: str) -> Optional[References]:
    """Load the default reference file for a standard GEM dataset.

    This is currently a placeholder: ``dataset_name`` is not inspected and
    no file is read, so the result is always ``None``.
    """
    # Stub -- no default references are available yet.
    return None
5 changes: 5 additions & 0 deletions gem_metrics/texts.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,8 @@ def __init__(self, data):
def predictions_for(self, dataset_name: str) -> Optional[Predictions]:
    """Look up the predictions stored under ``dataset_name``.

    Returns the result of ``self.entries.get(dataset_name)``; presumably
    ``entries`` is a dict, in which case a missing name yields ``None``.
    """
    per_dataset = self.entries
    return per_dataset.get(dataset_name)

@property
def datasets(self):
    """Names of the datasets that have predictions, as a new list.

    The names are the keys of ``self.entries``, in that mapping's own
    iteration order.
    """
    return [*self.entries.keys()]
32 changes: 26 additions & 6 deletions run_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,37 @@ def main(args):
# load system predictions
with open(args.predictions_file, encoding='UTF-8') as fh:
data = json.load(fh)

# multi-file submissions
if isinstance(data, dict) and 'submission_name' in data:
outs = gem_metrics.Submission(data)
data = gem_metrics.Submission(data)

ref_data = None
if args.references_file:
with open(args.references_file, encoding='UTF-8') as fh:
raw_ref_data = json.load(fh)
assert(sorted(list(raw_ref_data.keys())) == sorted(data.datasets))
for dataset in data.datasets:
ref_data[dataset] = gem_metrics.References(ref_data[dataset])
values = {}
for dataset in data.datasets:
outs = data.predictions_for(dataset)
# use default reference files if no custom ones are provided
refs = ref_data[dataset] if ref_data else gem_metrics.load_references(dataset)
if refs:
assert(len(refs) == len(outs))
values[dataset] = gem_metrics.compute(outs, refs)

# single-file mode
else:
outs = gem_metrics.Predictions(data)

# load references, if available
if args.references_file is not None:
refs = gem_metrics.References(args.references_file)
assert(len(refs) == len(outs))
# load references, if available
if args.references_file is not None:
refs = gem_metrics.References(args.references_file)
assert(len(refs) == len(outs))

values = gem_metrics.compute(outs, refs)
values = gem_metrics.compute(outs, refs)

# print output
out_fh = sys.stdout
Expand Down

0 comments on commit 440dc8c

Please sign in to comment.