from typing import Dict, List, Optional, Set
from collections import defaultdict
import logging
import os
import tempfile
import subprocess
import shutil

from allennlp.common.util import is_distributed
from allennlp.common.checks import ConfigurationError
from allennlp.training.metrics.metric import Metric

logger = logging.getLogger(__name__)

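# The bundled srl-eval.pl script is expected in the `tools` directory one level above this file.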
DEFAULT_SRL_EVAL_PATH = os.path.abspath(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "tools", "srl-eval.pl")
)


@Metric.register("srl_eval")
class SrlEvalScorer(Metric):
"""
    This class uses the external srl-eval.pl script to compute the CoNLL SRL metrics.
    AllenNLP ships with the srl-eval.pl script, but you will need Perl 5.x to run it.
Note that this metric reads and writes from disk quite a bit. In particular, it
writes and subsequently reads two files per __call__, which is typically invoked
once per batch. You probably don't want to include it in your training loop;
instead, you should calculate this on a validation set only.
# Parameters
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`).
        The path to the srl-eval.pl script.
ignore_classes : `List[str]`, optional (default=`None`).
A list of classes to ignore.
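
    # Example

    A minimal usage sketch; the variable names and the ignored `"V"` label below are
    illustrative:

    ```python
    scorer = SrlEvalScorer(ignore_classes=["V"])
    # One call per batch; see `__call__` for the expected argument shapes.
    scorer(verb_indices, sentences, predicted_conll_tags, gold_conll_tags)
    metrics = scorer.get_metric(reset=True)
    print(metrics["f1-measure-overall"])
    ```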
"""

    def __init__(
self, srl_eval_path: str = DEFAULT_SRL_EVAL_PATH, ignore_classes: List[str] = None
) -> None:
self._srl_eval_path = srl_eval_path
        self._ignore_classes = set(ignore_classes or [])
# These will hold per label span counts.
self._true_positives: Dict[str, int] = defaultdict(int)
self._false_positives: Dict[str, int] = defaultdict(int)
self._false_negatives: Dict[str, int] = defaultdict(int)

    def __call__(
self, # type: ignore
batch_verb_indices: List[Optional[int]],
batch_sentences: List[List[str]],
batch_conll_formatted_predicted_tags: List[List[str]],
batch_conll_formatted_gold_tags: List[List[str]],
) -> None:
"""
# Parameters
        batch_verb_indices : `List[Optional[int]]`, required.
            The index of the verbal predicate in each sentence for which the
            gold labels are the arguments, or None if the sentence contains no
            verbal predicate.
batch_sentences : `List[List[str]]`, required.
The word tokens for each instance in the batch.
        batch_conll_formatted_predicted_tags : `List[List[str]]`, required.
            A list of predicted CoNLL-formatted SRL tag sequences (each itself a list
            of strings) to compute scores for. Use
            `allennlp_models.structured_prediction.models.srl.convert_bio_tags_to_conll_format`
            to convert from BIO to CoNLL format before passing the tags into the metric,
            if applicable.
        batch_conll_formatted_gold_tags : `List[List[str]]`, required.
            A list of gold CoNLL-formatted SRL tag sequences (each itself a list of
            strings) to use as a reference. Use
            `allennlp_models.structured_prediction.models.srl.convert_bio_tags_to_conll_format`
            to convert from BIO to CoNLL format before passing the tags into the metric,
            if applicable.
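
        # Example

        An illustrative single-sentence batch (the CoNLL span strings below are just
        an example of the expected format):

        ```python
        scorer(
            batch_verb_indices=[2],
            batch_sentences=[["The", "cat", "sat", "."]],
            batch_conll_formatted_predicted_tags=[["(ARG0*", "*)", "(V*)", "*"]],
            batch_conll_formatted_gold_tags=[["(ARG0*", "*)", "(V*)", "*"]],
        )
        ```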
"""
if not os.path.exists(self._srl_eval_path):
raise ConfigurationError(f"srl-eval.pl not found at {self._srl_eval_path}.")
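        # Write the gold and predicted tags to temporary CoNLL files for the perl script to read.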
tempdir = tempfile.mkdtemp()
gold_path = os.path.join(tempdir, "gold.txt")
predicted_path = os.path.join(tempdir, "predicted.txt")
with open(predicted_path, "w", encoding="utf-8") as predicted_file, open(
gold_path, "w", encoding="utf-8"
) as gold_file:
for verb_index, sentence, predicted_tag_sequence, gold_tag_sequence in zip(
batch_verb_indices,
batch_sentences,
batch_conll_formatted_predicted_tags,
batch_conll_formatted_gold_tags,
):
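                # Imported here rather than at module level, presumably to avoid a circular
                # import with the SRL model module, which itself uses this metric.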
from allennlp_models.structured_prediction.models.srl import (
write_conll_formatted_tags_to_file,
)
write_conll_formatted_tags_to_file(
predicted_file,
gold_file,
verb_index,
sentence,
predicted_tag_sequence,
gold_tag_sequence,
)
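        # Invoke the official CoNLL scorer: `perl srl-eval.pl <gold file> <predicted file>`.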
perl_script_command = ["perl", self._srl_eval_path, gold_path, predicted_path]
try:
completed_process = subprocess.run(
perl_script_command, stdout=subprocess.PIPE, universal_newlines=True, check=True
)
except FileNotFoundError:
raise FileNotFoundError(
"'File not found' while running the evaluation. Do you have perl installed?"
)
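        # Each per-label results line from srl-eval.pl has seven whitespace-separated
        # fields: label, correct, excess, missed, precision, recall, F1.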
for line in completed_process.stdout.split("\n"):
stripped = line.strip().split()
if len(stripped) == 7:
tag = stripped[0]
# Overall metrics are calculated in get_metric, skip them here.
if tag == "Overall" or tag in self._ignore_classes:
continue
# This line contains results for a span
num_correct = int(stripped[1])
num_excess = int(stripped[2])
num_missed = int(stripped[3])
self._true_positives[tag] += num_correct
self._false_positives[tag] += num_excess
self._false_negatives[tag] += num_missed
# Note: we cannot aggregate across distributed workers because each worker
# may end up with different tags, and in such a case, the reduce operation
# will stall, or return with inaccurate values.
shutil.rmtree(tempdir)

    def get_metric(self, reset: bool = False):
"""
# Returns
        A Dict per label containing the following span-based metrics:
- precision : `float`
- recall : `float`
- f1-measure : `float`
        Additionally, `precision-overall`, `recall-overall` and `f1-measure-overall`
        keys are included, giving the metrics computed over all spans jointly.
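
        For example, with labels `ARG0` and `ARG1` the returned dictionary contains
        keys such as (values illustrative):

        ```python
        {
            "precision-ARG0": 0.90,
            "recall-ARG0": 0.85,
            "f1-measure-ARG0": 0.87,
            "precision-ARG1": 0.80,
            "recall-ARG1": 0.80,
            "f1-measure-ARG1": 0.80,
            "precision-overall": 0.86,
            "recall-overall": 0.83,
            "f1-measure-overall": 0.84,
        }
        ```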
"""
if is_distributed():
raise RuntimeError(
"Distributed aggregation for `SrlEvalScorer` is currently not supported."
)
all_tags: Set[str] = set()
all_tags.update(self._true_positives.keys())
all_tags.update(self._false_positives.keys())
all_tags.update(self._false_negatives.keys())
all_metrics = {}
for tag in all_tags:
if tag == "overall":
raise ValueError(
"'overall' is disallowed as a tag type, "
"rename the tag type to something else if necessary."
)
precision, recall, f1_measure = self._compute_metrics(
self._true_positives[tag], self._false_positives[tag], self._false_negatives[tag]
)
            precision_key = f"precision-{tag}"
            recall_key = f"recall-{tag}"
            f1_key = f"f1-measure-{tag}"
all_metrics[precision_key] = precision
all_metrics[recall_key] = recall
all_metrics[f1_key] = f1_measure
# Compute the precision, recall and f1 for all spans jointly.
precision, recall, f1_measure = self._compute_metrics(
sum(self._true_positives.values()),
sum(self._false_positives.values()),
sum(self._false_negatives.values()),
)
all_metrics["precision-overall"] = precision
all_metrics["recall-overall"] = recall
all_metrics["f1-measure-overall"] = f1_measure
if reset:
self.reset()
return all_metrics

    @staticmethod
def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int):
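        """
        Compute span-level precision, recall and F1 from raw counts:
        `precision = TP / (TP + FP)`, `recall = TP / (TP + FN)` and
        `f1 = 2 * precision * recall / (precision + recall)`, with a small
        epsilon in each denominator to guard against division by zero.
        """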
precision = true_positives / (true_positives + false_positives + 1e-13)
recall = true_positives / (true_positives + false_negatives + 1e-13)
f1_measure = 2.0 * (precision * recall) / (precision + recall + 1e-13)
return precision, recall, f1_measure

    def reset(self):
self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)