Adds multi-label metrics (accuracy, precision, recall, F1) to LIT.
PiperOrigin-RevId: 515069808
RyanMullins authored and LIT team committed Mar 8, 2023
1 parent dd23083 commit c0e3663
Showing 3 changed files with 525 additions and 0 deletions.
1 change: 1 addition & 0 deletions lit_nlp/components/core.py
@@ -137,6 +137,7 @@ def default_metrics() -> ComponentGroup:
return ComponentGroup({
'regression': metrics.RegressionMetrics(),
'multiclass': metrics.MulticlassMetrics(),
'multilabel': metrics.MultilabelMetrics(),
'paired': metrics.MulticlassPairedMetrics(),
'bleu': metrics.CorpusBLEU(),
'rouge': metrics.RougeL(),
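For context, this one-line change registers the new metric under the 'multilabel' key whenever default_metrics() assembles the standard component group; the component can also be constructed directly. A minimal sketch (illustrative, not part of this commit):

from lit_nlp.components import metrics

multilabel = metrics.MultilabelMetrics()  # the same component registered above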
153 changes: 153 additions & 0 deletions lit_nlp/components/metrics.py
@@ -28,6 +28,7 @@
from scipy import stats as scipy_stats
from scipy.spatial import distance as scipy_distance
from sklearn import metrics as sklearn_metrics
from sklearn import preprocessing as sklearn_preprocessing

from rouge_score import rouge_scorer

@@ -36,6 +37,8 @@
LitType = types.LitType
Spec = types.Spec

_MultiLabelBinarizer = sklearn_preprocessing.MultiLabelBinarizer


def map_pred_keys(
data_spec: Spec, model_output_spec: Spec,
@@ -559,6 +562,156 @@ def __init__(self):
ClassificationMetricsWrapper.__init__(self, MulticlassPairedMetricsImpl())


class MultilabelMetrics(SimpleMetrics):
"""Metrics for assessing the performance of multi-label learning models."""

def is_field_compatible(
self, pred_spec: types.LitType, parent_spec: Optional[types.LitType]
) -> bool:
"""Determines the compatibility of a field with these metrics.
Args:
pred_spec: The field in the model's output spec containing the predicted
labels, must be a `SparseMultilabelPreds` type.
parent_spec: The field in the dataset containing the ground truth, must be
a `SparseMultilabel` field.
Returns:
True if the pred_spec and parent_spec pair are compatible.
"""
pred_suppported = isinstance(pred_spec, types.SparseMultilabelPreds)
parent_supported = isinstance(parent_spec, types.SparseMultilabel)
return pred_suppported and parent_supported
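# Illustrative sketch (not part of this commit; assumes these LitTypes
# construct with their defaults):
#
#   metric = MultilabelMetrics()
#   metric.is_field_compatible(
#       types.SparseMultilabelPreds(), types.SparseMultilabel())  # True
#   metric.is_field_compatible(
#       types.RegressionScore(), types.SparseMultilabel())  # False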

def meta_spec(self) -> types.Spec:
"""Returns the Spec describing the computed metrics."""
return {
'exactmatch': types.MetricResult(
best_value=types.MetricBestValue.HIGHEST,
description=(
'Multi-label accuracy is the exact match ratio; the proportion '
'of exact matches between the predicted labels and the true '
'labels across all examples. Closer to 1 is better.'
),
),
'precision': types.MetricResult(
best_value=types.MetricBestValue.HIGHEST,
description=(
'The mean proportion of correctly predicted labels out of all '
'predicted labels across examples. Closer to 1 is better.'
),
),
'recall': types.MetricResult(
best_value=types.MetricBestValue.HIGHEST,
description=(
'The mean proportion of correctly predicted labels relative to '
'the true labels across examples. Closer to 1 is better.'
),
),
'f1': types.MetricResult(
best_value=types.MetricBestValue.HIGHEST,
description=(
'The mean performance of the model (i.e., the harmonic mean of '
'precision and recall) across examples. Closer to 1 is better.'
),
),
}
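# Worked illustration (not part of this commit): for true labels
# [{'a', 'b'}, {'c'}] and predicted labels [{'a'}, {'c'}], exact match is
# 0.5 (only the second example matches exactly), sample-averaged precision
# is (1/1 + 1/1) / 2 = 1.0, recall is (1/2 + 1/1) / 2 = 0.75, and F1 is the
# mean of the per-example harmonic means, (2/3 + 1) / 2 ≈ 0.83.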

def compute(
self,
labels: Sequence[Sequence[str]],
preds: Sequence[types.ScoredTextCandidates],
label_spec: types.LitType,
pred_spec: types.LitType,
config: Optional[types.JsonDict] = None,
) -> lit_components.MetricsDict:
"""Computes standard metrics for multi-label classification models.
Args:
labels: Ground truth against which predictions are compared.
preds: The predictions made by the model.
label_spec: A `SparseMultilabel` instance describing the types of elements
in `labels`.
pred_spec: A `SparseMultilabelPreds` instance describing the types of
elements in `preds`.
config: unused parameter from base class.
Returns:
A dict containing the accuracy (exact match ratio), precision, recall, and
F1 score for the provided predictions given true labels.
Raises:
TypeError: If `label_spec` is not a `SparseMultilabel` instance or
`pred_spec` is not a `SparseMultilabelPreds` instance.
ValueError: If `labels` is not the same length as `preds`.
"""
# TODO(b/271864674): Use this config dict to get user-defined thresholds
del config # unused in multi-label metrics, for now.

if not labels or not preds:
return {}

num_labels = len(labels)
num_preds = len(preds)
if num_labels != num_preds:
raise ValueError(
'Must have exactly as many labels as predictions. Received '
f'{num_labels} labels and {num_preds} preds.'
)

if not isinstance(label_spec, types.SparseMultilabel):
raise TypeError(
'label_spec must be a SparseMultilabel, received '
f'{type(label_spec).__name__}'
)

if not isinstance(pred_spec, types.SparseMultilabelPreds):
raise TypeError(
'pred_spec must be a SparseMultilabelPreds, received '
f'{type(pred_spec).__name__}'
)

# Learn the complete vocabulary of the possible labels
if pred_spec.vocab: # Try to get the labels from the model's output spec
all_labels: list[Sequence[str]] = [pred_spec.vocab]
elif label_spec.vocab: # Or, try to get them from the dataset spec
all_labels: list[Sequence[str]] = [label_spec.vocab]
else: # Otherwise, derive them from the observed labels
# WARNING: this is only correct for metrics like precision, recall, and
# exact-match accuracy which do not depend on knowing the full label set.
# For per-label accuracy this will give incorrect results if not all
# labels are observed in a given sample.
all_labels: list[Sequence[str]] = []
all_labels.extend(labels)
all_labels.extend([{l for l, _ in p} for p in preds])

binarizer = _MultiLabelBinarizer()
binarizer.fit(all_labels)
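# For example (standard sklearn behavior): after fitting on the vocabulary
# ['bird', 'cat', 'dog'], binarizer.transform([['cat', 'dog']]) returns
# [[0, 1, 1]], a binary indicator vector over the sorted label classes.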

# Next, extract the labels from the ScoredTextCandidates for binarization.
pred_labels = [
# TODO(b/271864674): Update this set comprehension to respect
# user-defined margins from the config dict or pred_spec.threshold.
{l for l, s in p if s is not None and s > 0.5} for p in preds
]
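# E.g., a prediction [('cat', 0.9), ('dog', 0.3), ('bird', None)] yields
# the label set {'cat'} under this fixed 0.5 threshold.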

# Transform the true and predicted labels into the binarized vector space.
v_true = binarizer.transform(labels)
v_pred = binarizer.transform(pred_labels)

# Compute and return the metrics
return {
'exactmatch': sklearn_metrics.accuracy_score(v_true, v_pred),
'precision': sklearn_metrics.precision_score(
v_true, v_pred, average='samples'
),
'recall': sklearn_metrics.recall_score(
v_true, v_pred, average='samples'
),
'f1': sklearn_metrics.f1_score(v_true, v_pred, average='samples'),
}
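# Usage sketch (illustrative, not part of this commit; assumes the LitType
# constructors accept their defaults):
#
#   metric = MultilabelMetrics()
#   metric.compute(
#       labels=[['cat', 'dog'], ['bird']],
#       preds=[[('cat', 0.8), ('dog', 0.4)], [('bird', 0.9)]],
#       label_spec=types.SparseMultilabel(),
#       pred_spec=types.SparseMultilabelPreds(),
#   )
#   # Expected (approximately):
#   # {'exactmatch': 0.5, 'precision': 1.0, 'recall': 0.75, 'f1': 0.833}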


class CorpusBLEU(SimpleMetrics):
"""Corpus BLEU score using SacreBLEU."""

