Skip to content

Commit

Permalink
Adds exact match metrics to LIT.
Browse files Browse the repository at this point in the history
These metrics support comparing GeneratedText and GeneratedTextCandidates outputs against TextSegment or MultiSegmentAnnotations labels.

We expect most labels will be TextSegment, but some datasets, e.g., TyDi QA, are better represented as MultiSegmentAnnotations.

PiperOrigin-RevId: 487499002
  • Loading branch information
RyanMullins authored and LIT team committed Nov 10, 2022
1 parent fd2b976 commit eac9382
Show file tree
Hide file tree
Showing 3 changed files with 389 additions and 1 deletion.
1 change: 1 addition & 0 deletions lit_nlp/components/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ def default_metrics() -> ComponentGroup:
'paired': metrics.MulticlassPairedMetrics(),
'bleu': metrics.CorpusBLEU(),
'rouge': metrics.RougeL(),
'exactmatch': metrics.ExactMatchMetrics(),
})
104 changes: 104 additions & 0 deletions lit_nlp/components/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,3 +809,107 @@ class BinaryConfusionMetrics(ClassificationMetricsWrapper):

def __init__(self):
  """Initializes the wrapper with a new `BinaryConfusionMetricsImpl`."""
  ClassificationMetricsWrapper.__init__(self, BinaryConfusionMetricsImpl())


class ExactMatchMetrics(SimpleMetrics):
  """Exact match metrics for text generations.

  Compares `GeneratedText` or `GeneratedTextCandidates` predictions against
  `TextSegment` or `MultiSegmentAnnotations` labels and reports the proportion
  of predictions that are string-equal to (one of) the label(s).
  """

  def meta_spec(self) -> types.Spec:
    """Returns the spec for the Exact Match metrics.

    Returns:
      A dict of MetricResult specs for the metrics computed by this class.
    """
    return {
        'exactmatch': types.MetricResult(
            best_value=types.MetricBestValue.HIGHEST,
            description='The proportion of exact matches. Closer to 1 is '
            'better.',
        ),
        'exactmatch@1': types.MetricResult(
            best_value=types.MetricBestValue.HIGHEST,
            description='The proportion of exact matches for the top predicted '
            'candidate. Closer to 1 is better.',
        )
    }

  def is_field_compatible(self, pred_spec: LitType,
                          parent_spec: Optional[LitType]) -> bool:
    """Return true if compatible with this field.

    Args:
      pred_spec: The field in the model's output spec containing the generated
        text, must be of type GeneratedText or GeneratedTextCandidates.
      parent_spec: The field in the dataset containing the ground truth, must be
        of type MultiSegmentAnnotations or TextSegment.

    Returns:
      True if the pred_spec and parent_spec pair are compatible.
    """
    pred_supported = isinstance(pred_spec, (types.GeneratedText,
                                            types.GeneratedTextCandidates))
    parent_supported = isinstance(parent_spec, (types.TextSegment,
                                                types.MultiSegmentAnnotations))
    return pred_supported and parent_supported

  def compute(
      self,
      labels: Sequence[Any],
      preds: Sequence[Any],
      label_spec: types.LitType,
      pred_spec: types.LitType,
      config: Optional[JsonDict] = None) -> lit_components.MetricsDict:
    """Compute exact matches between labels and predictions.

    Args:
      labels: Ground truth against which predictions are compared.
      preds: The predictions made by the model.
      label_spec: A `MultiSegmentAnnotations` or `TextSegment` spec describing
        the types of elements in `labels`.
      pred_spec: A `GeneratedText` or `GeneratedTextCandidates` spec describing
        the types of elements in `preds`.
      config: unused parameter from base class.

    Returns:
      A dict containing the proportion of exact matches in the predictions,
      stored in the `exactmatch` key if `pred_spec` is `GeneratedText` or the
      `exactmatch@1` key if `pred_spec` is `GeneratedTextCandidates`. An empty
      dict is returned if either `labels` or `preds` is empty.

    Raises:
      TypeError: If `label_spec` is not a `TextSegment` or
        `MultiSegmentAnnotations`, or if `pred_spec` is not a `GeneratedText`
        or `GeneratedTextCandidates`.
    """
    del config

    if not labels or not preds:
      return {}

    if not isinstance(label_spec,
                      (types.TextSegment, types.MultiSegmentAnnotations)):
      # Report the type of the offending label_spec (not pred_spec).
      raise TypeError('label_spec must be a TextSegment or '
                      'MultiSegmentAnnotations, received '
                      f'{type(label_spec).__name__}')

    if not isinstance(pred_spec,
                      (types.GeneratedText, types.GeneratedTextCandidates)):
      raise TypeError('pred_spec must be a GeneratedText or '
                      'GeneratedTextCandidates, received '
                      f'{type(pred_spec).__name__}')

    if isinstance(pred_spec, types.GeneratedTextCandidates):
      # Only the top candidate of each prediction is scored.
      texts = [types.GeneratedTextCandidates.top_text(v) for v in preds]
      name_suffix = '@1'
    else:
      texts = preds
      name_suffix = ''

    # The label type is the same for every example, so check it once.
    multi_answer = isinstance(label_spec, types.MultiSegmentAnnotations)

    matches = 0
    for label, pred in zip(labels, texts):
      if multi_answer:
        # MultiSegmentAnnotations means that labels is a
        # Sequence[api.dtypes.AnnotationCluster]; the prediction counts as a
        # match if it equals the `label` of any annotation in the example.
        if any(pred == annotation.label for annotation in label):
          matches += 1
      elif pred == label:
        # Otherwise, labels is a Sequence[str].
        matches += 1

    return {f'exactmatch{name_suffix}': matches / len(preds)}

0 comments on commit eac9382

Please sign in to comment.