Improves robustness of Metrics compatibility checks.
Previously, Metrics classes used the `.is_compatible()` method inherited from Interpreter, overriding it with a different signature to check compatibility against the LitType spec of the predicted field.

This change adds a new `.is_field_compatible()` method that takes a required `pred_spec` and an optional `parent_spec` parameter, checking compatibility against both the predicted field and its parent (label) field in the dataset.

Calls to `.is_compatible()` on Metrics classes now always return `False`, preventing them from being misused as general-purpose Interpreters such as LIME or Integrated Gradients.
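
A minimal sketch of the new contract, drawn from the diff below (the class name here is hypothetical; the `lit_nlp.api.types` names are the ones used in the diff):

from typing import Optional

from lit_nlp.api import types


class ExampleRegressionMetric:
  """Hypothetical metric illustrating the new compatibility API."""

  def is_compatible(self, model) -> bool:
    # Metrics are not general-purpose Interpreters, so model-level
    # compatibility always reports False.
    return False

  def is_field_compatible(self, pred_spec: types.LitType,
                          parent_spec: Optional[types.LitType]) -> bool:
    # Compatible when the predicted field is a RegressionScore; this
    # particular metric places no constraint on the parent (label) field.
    del parent_spec
    return isinstance(pred_spec, types.RegressionScore)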

PiperOrigin-RevId: 462600295
RyanMullins authored and LIT team committed Jul 22, 2022
1 parent 65c5b8a commit 0d8341d
Showing 4 changed files with 187 additions and 119 deletions.
169 changes: 97 additions & 72 deletions lit_nlp/components/metrics.py
@@ -16,15 +16,14 @@

import abc
import collections
from typing import cast, Dict, List, Sequence, Tuple, Text, Optional, Callable, Any, Union
from typing import Any, Callable, cast, Optional, Sequence, Union

from absl import logging
from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.components import classification_results
from lit_nlp.lib import utils
import numpy as np
import sacrebleu
from scipy import stats as scipy_stats
@@ -35,31 +34,31 @@

JsonDict = types.JsonDict
IndexedInput = types.IndexedInput
LitType = types.LitType
Spec = types.Spec


def map_pred_keys(
data_spec: lit_model.Spec, model_output_spec: lit_model.Spec,
predicate: Callable[[types.LitType], bool]) -> Dict[Text, Text]:
"""Find output fields matching predicate, and return a mapping to input fields."""
data_spec: Spec, model_output_spec: Spec,
predicate: Callable[[LitType, Optional[LitType]], bool]) -> dict[str, str]:
"""Returns a map of compatible output fields and their parent input fields."""
ret = {}
for pred_key in utils.find_keys(model_output_spec, predicate):
pred_field_spec = model_output_spec[pred_key]
label_key = getattr(pred_field_spec, 'parent', None)
if label_key is None:
logging.warning("Pred key '%s' has no parent field. Skipping.", pred_key)
continue # skip fields with no pointer
if label_key not in data_spec:
# This may be intentional, if running on unlabeled data.
logging.warning(
"Pred key '%s' points to missing label field '%s'. Skipping.",
pred_key, label_key)
for pred_key, pred_spec in model_output_spec.items():
parent_key: Optional[str] = getattr(pred_spec, 'parent', None)
if parent_key is None:
logging.warning("Skipping '%s': No parent provided.", pred_key)
continue
ret[pred_key] = label_key

parent_spec: Optional[LitType] = data_spec.get(parent_key)
if predicate(pred_spec, parent_spec):
ret[pred_key] = parent_key
else:
logging.warning(
"Skipping '%s': incompatible parent '%s'.", pred_key, parent_key)
return ret


def nan_to_none(metrics: Dict[str, float]) -> Dict[str, Optional[float]]:
def nan_to_none(metrics: dict[str, float]) -> dict[str, Optional[float]]:
# NaN is not a valid JSON value, so replace with None which will be
# serialized as null.
# TODO(lit-dev): consider moving this logic to serialize.py?
@@ -69,17 +68,22 @@ def nan_to_none(metrics: Dict[str, float]) -> Dict[str, Optional[float]]:
class SimpleMetrics(lit_components.Interpreter):
"""Base class for simple metrics, which should render in the main metrics table."""

def is_compatible(self, model: lit_model.Model) -> bool:
"""Metrics should always return false for Model-level compatibility."""
return False

@abc.abstractmethod
def is_compatible(self, field_spec: types.LitType) -> bool:
"""Return true if compatible with this field."""
def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Returns true if compatible with the predicted field and its parent."""
pass

def compute(self,
labels: Sequence[Any],
preds: Sequence[Any],
label_spec: types.LitType,
pred_spec: types.LitType,
config: Optional[JsonDict] = None) -> Dict[Text, float]:
label_spec: LitType,
pred_spec: LitType,
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
raise NotImplementedError(
'Subclass should implement this, or override compute_with_metadata() directly.'
@@ -89,25 +93,26 @@ def compute_with_metadata(
self,
labels: Sequence[Any],
preds: Sequence[Any],
label_spec: types.LitType,
pred_spec: types.LitType,
label_spec: LitType,
pred_spec: LitType,
indices: Sequence[types.ExampleId],
metas: Sequence[JsonDict],
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
"""As compute(), but has access to indices and metadata."""
return self.compute(labels, preds, label_spec, pred_spec, config)

def run(self,
inputs: List[JsonDict],
inputs: list[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.Dataset,
model_outputs: Optional[List[JsonDict]] = None,
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None):
if model_outputs is None:
model_outputs = list(model.predict(inputs))

spec = model.spec()
field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
field_map = map_pred_keys(dataset.spec(), spec.output,
self.is_field_compatible)
ret = []
for pred_key, label_key in field_map.items():
# Extract fields
@@ -132,16 +137,17 @@ def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
model_outputs: Optional[List[JsonDict]] = None,
config: Optional[JsonDict] = None) -> List[JsonDict]:
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None) -> list[JsonDict]:
if model_outputs is None:
model_outputs = list(model.predict_with_metadata(indexed_inputs))

# TODO(lit-team): pre-compute this mapping in constructor?
# This would require passing a model name to this function so we can
# reference a pre-computed list.
spec = model.spec()
field_map = map_pred_keys(dataset.spec(), spec.output, self.is_compatible)
field_map = map_pred_keys(dataset.spec(), spec.output,
self.is_field_compatible)
ret = []
for pred_key, label_key in field_map.items():
# Extract fields
@@ -179,20 +185,25 @@ class ClassificationMetricsWrapper(lit_components.Interpreter):
def __init__(self, metrics: SimpleMetrics):
self._metrics = metrics

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_compatible(self, model: lit_model.Model) -> bool:
"""Metrics should always return false for Model-level compatibility."""
return False

def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Return true if compatible with this field."""
return self._metrics.is_compatible(field_spec)
return self._metrics.is_field_compatible(pred_spec, parent_spec)

def run(self,
inputs: List[JsonDict],
inputs: list[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.Dataset,
model_outputs: Optional[List[JsonDict]] = None,
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None):
# Get margin for each input for each pred key and add them to a config dict
# to pass to the wrapped metrics.
field_map = map_pred_keys(dataset.spec(),
model.spec().output, self.is_compatible)
model.spec().output, self.is_field_compatible)
margin_config = {}
for pred_key in field_map:
field_config = config.get(pred_key) if config else None
@@ -208,12 +219,12 @@ def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
model_outputs: Optional[List[JsonDict]] = None,
config: Optional[JsonDict] = None) -> List[JsonDict]:
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None) -> list[JsonDict]:
# Get margin for each input for each pred key and add them to a config dict
# to pass to the wrapped metrics.
field_map = map_pred_keys(dataset.spec(),
model.spec().output, self.is_compatible)
model.spec().output, self.is_field_compatible)
margin_config = {}
for pred_key in field_map:
inputs = [ex['data'] for ex in indexed_inputs]
@@ -230,16 +241,18 @@ def run_with_metadata(self,
class RegressionMetrics(SimpleMetrics):
"""Standard regression metrics."""

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Return true if compatible with this field."""
return isinstance(field_spec, types.RegressionScore)
del parent_spec
return isinstance(pred_spec, types.RegressionScore)

def compute(self,
labels: Sequence[float],
preds: Sequence[float],
label_spec: types.Scalar,
pred_spec: types.RegressionScore,
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
del config

@@ -284,7 +297,7 @@ def get_all_metrics(self,
# null_idx as the negative / "other" class.
if null_idx is not None:
# Note: labels here are indices.
labels: List[int] = [
labels: list[int] = [
i for i in range(len(pred_spec.vocab)) if i != null_idx
]
ret['precision'] = sklearn_metrics.precision_score(
@@ -313,16 +326,18 @@ def get_all_metrics(self,

return ret

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Return true if compatible with this field."""
return isinstance(field_spec, types.MulticlassPreds)
del parent_spec
return isinstance(pred_spec, types.MulticlassPreds)

def compute(self,
labels: Sequence[Text],
labels: Sequence[str],
preds: Sequence[np.ndarray],
label_spec: types.CategoryLabel,
pred_spec: types.MulticlassPreds,
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
# TODO(lit-dev): compare on strings instead of converting to indices?
# This should be more robust to skew in label sets.
@@ -358,13 +373,15 @@ class MulticlassPairedMetricsImpl(SimpleMetrics):
the model to the perturbations.
"""

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Return true if compatible with this field."""
return isinstance(field_spec, types.MulticlassPreds)
del parent_spec
return isinstance(pred_spec, types.MulticlassPreds)

@staticmethod
def find_pairs(indices: Sequence[types.ExampleId],
metas: Sequence[JsonDict]) -> List[Tuple[int, int]]:
metas: Sequence[JsonDict]) -> list[tuple[int, int]]:
"""Find valid pairs in the current selection, and return list indices."""
id_to_position = {example_id: i for i, example_id in enumerate(indices)}
pairs = [] # (i,j) relative to labels and preds lists
@@ -381,11 +398,11 @@ def compute_with_metadata(
self,
labels: Sequence[Any],
preds: Sequence[Any],
label_spec: types.LitType,
label_spec: LitType,
pred_spec: types.MulticlassPreds,
indices: Sequence[types.ExampleId],
metas: Sequence[JsonDict],
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
del labels # Unused; we only care about preds.
del label_spec # Unused; we only care about preds.

@@ -423,18 +440,21 @@ class CorpusBLEU(SimpleMetrics):

BLEU_SMOOTHING_VAL = 0.1

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_field_compatible(self, pred_spec: LitType,
parent_spec: LitType) -> bool:
"""Return true if compatible with this field."""
return isinstance(field_spec,
(types.GeneratedText, types.GeneratedTextCandidates))
is_pred_compatible = isinstance(
pred_spec, (types.GeneratedText, types.GeneratedTextCandidates))
is_parent_compatible = isinstance(parent_spec, types.StringLitType)
return is_pred_compatible and is_parent_compatible

def compute(self,
labels: Sequence[Text],
preds: Sequence[Union[Text, types.ScoredTextCandidates]],
labels: Sequence[str],
preds: Sequence[Union[str, types.ScoredTextCandidates]],
label_spec: types.TextSegment,
pred_spec: Union[types.GeneratedText,
types.GeneratedTextCandidates],
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
del label_spec
del config
@@ -462,18 +482,21 @@ def _score(self, reference, prediction):
return self._scorer.score(
target=reference, prediction=prediction)['rougeL'].fmeasure

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Return true if compatible with this field."""
return isinstance(field_spec,
(types.GeneratedText, types.GeneratedTextCandidates))
is_pred_compatible = isinstance(
pred_spec, (types.GeneratedText, types.GeneratedTextCandidates))
is_parent_compatible = isinstance(parent_spec, types.StringLitType)
return is_pred_compatible and is_parent_compatible

def compute(self,
labels: Sequence[Text],
preds: Sequence[Union[Text, types.ScoredTextCandidates]],
labels: Sequence[str],
preds: Sequence[Union[str, types.ScoredTextCandidates]],
label_spec: types.TextSegment,
pred_spec: Union[types.GeneratedText,
types.GeneratedTextCandidates],
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
del label_spec
del config
@@ -496,7 +519,7 @@ class BinaryConfusionMetricsImpl(SimpleMetrics):
def get_all_metrics(self,
y_true: Sequence[int],
y_pred: Sequence[int],
vocab: Sequence[Text],
vocab: Sequence[str],
null_idx: Optional[int] = None):
# Filter out unlabeled examples before calculating metrics.
labeled_example_indices = [
@@ -514,19 +537,21 @@ def get_all_metrics(self,
ret['TP'] = matrix[1][1]
return ret

def is_compatible(self, field_spec: types.LitType) -> bool:
def is_field_compatible(self, pred_spec: LitType,
parent_spec: LitType) -> bool:
"""Return true if binary classification with ground truth."""
if not isinstance(field_spec, types.MulticlassPreds):
if not (isinstance(pred_spec, types.MulticlassPreds) and
isinstance(parent_spec, types.CategoryLabel)):
return False
class_spec = cast(types.MulticlassPreds, field_spec)
return len(class_spec.vocab) == 2 and class_spec.parent
class_spec = cast(types.MulticlassPreds, pred_spec)
return len(class_spec.vocab) == 2

def compute(self,
labels: Sequence[Text],
labels: Sequence[str],
preds: Sequence[np.ndarray],
label_spec: types.CategoryLabel,
pred_spec: types.MulticlassPreds,
config: Optional[JsonDict] = None) -> Dict[Text, float]:
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
del label_spec # Unused; get vocab from pred_spec.

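For reference, a minimal usage sketch of how the per-field predicate plugs into `map_pred_keys`; the spec contents, field names, and the direct `RegressionMetrics()` instantiation are illustrative assumptions, not part of this commit:

from lit_nlp.api import types
from lit_nlp.components import metrics

# Toy dataset and model-output specs; field names are made up.
data_spec = {
    'sentence': types.TextSegment(),
    'score': types.Scalar(),  # ground-truth label field
}
model_output_spec = {
    # 'parent' names the dataset field holding the label.
    'pred_score': types.RegressionScore(parent='score'),
}

metric = metrics.RegressionMetrics()
field_map = metrics.map_pred_keys(
    data_spec, model_output_spec, metric.is_field_compatible)
# Expected mapping of prediction field -> label field: {'pred_score': 'score'}
print(field_map)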
