Skip to content

Commit

Permalink
Adds Metrics class to lit_nlp.api.components
Browse files Browse the repository at this point in the history
Refactors existing metrics implementations to inherit from this new base class.

PiperOrigin-RevId: 485867499
  • Loading branch information
RyanMullins authored and LIT team committed Nov 3, 2022
1 parent ecd3a66 commit de7d8ba
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 70 deletions.
77 changes: 76 additions & 1 deletion lit_nlp/api/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
"""Base classes for LIT backend components."""
import abc
import inspect
from typing import Optional, Sequence
from typing import Any, Optional, Sequence

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types

JsonDict = types.JsonDict
IndexedInput = types.IndexedInput
MetricsDict = dict[str, float]


class Interpreter(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -173,6 +174,80 @@ def generate(self,
pass


class Metrics(Interpreter):
  """Base class for LIT metrics components.

  Subclasses are expected to supply meta_spec(), run() and
  is_field_compatible(), plus either compute() or compute_with_metadata().
  The versions defined here either raise NotImplementedError or forward to a
  sibling method.
  """

  # Methods required by the Interpreter base class.

  def is_compatible(self, model: lit_model.Model,
                    dataset: lit_dataset.Dataset) -> bool:
    """Checks whether any model output field can be scored by this metric.

    Args:
      model: the model whose output spec is inspected.
      dataset: the dataset whose spec is searched for each field's parent.

    Returns:
      True as soon as is_field_compatible() accepts one output field paired
      with its parent field from the dataset spec; False if none is accepted.
    """

    def _parent_spec_of(field_spec: types.LitType) -> Optional[types.LitType]:
      # Fields without a `parent` attribute are looked up under None, which
      # yields None unless the dataset spec happens to contain that key.
      parent_name = getattr(field_spec, 'parent', None)
      return dataset.spec().get(parent_name)

    return any(
        self.is_field_compatible(field_spec, _parent_spec_of(field_spec))
        for field_spec in model.output_spec().values())

  def meta_spec(self):
    """Describes the metrics computed by this class; subclasses must override.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError('Subclass should define its own meta spec.')

  def run(
      self,
      inputs: Sequence[JsonDict],
      model: lit_model.Model,
      dataset: lit_dataset.Dataset,
      model_outputs: Optional[list[JsonDict]] = None,
      config: Optional[JsonDict] = None) -> list[JsonDict]:
    """Runs the metric over the inputs; subclasses must override.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError(
        'Subclass should implement its own run using compute.')

  def run_with_metadata(
      self,
      indexed_inputs: Sequence[IndexedInput],
      model: lit_model.Model,
      dataset: lit_dataset.IndexedDataset,
      model_outputs: Optional[list[JsonDict]] = None,
      config: Optional[JsonDict] = None) -> list[JsonDict]:
    """Unwraps the 'data' entry of each indexed input and forwards to run()."""
    return self.run([example['data'] for example in indexed_inputs], model,
                    dataset, model_outputs, config)

  # Methods introduced by this class.

  def is_field_compatible(
      self,
      pred_spec: types.LitType,
      parent_spec: Optional[types.LitType]) -> bool:
    """Reports compatibility with a prediction field and its parent field.

    Args:
      pred_spec: spec of the prediction field.
      parent_spec: spec of the field's parent, or None when there is none.

    Raises:
      NotImplementedError: always, in this base class.
    """
    del pred_spec, parent_spec  # Not used by the base class.
    raise NotImplementedError('Subclass should implement field compatibility.')

  def compute(
      self,
      labels: Sequence[Any],
      preds: Sequence[Any],
      label_spec: types.LitType,
      pred_spec: types.LitType,
      config: Optional[JsonDict] = None) -> MetricsDict:
    """Computes metric(s) from labels and predictions; subclasses override.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError('Subclass should implement this, or override '
                              'compute_with_metadata() directly.')

  def compute_with_metadata(
      self,
      labels: Sequence[Any],
      preds: Sequence[Any],
      label_spec: types.LitType,
      pred_spec: types.LitType,
      indices: Sequence[types.ExampleId],
      metas: Sequence[JsonDict],
      config: Optional[JsonDict] = None) -> MetricsDict:
    """Variant of compute() that also receives example indices and metadata.

    The base implementation discards `indices` and `metas` and returns the
    result of compute() on the remaining arguments.
    """
    del indices, metas  # Not used by the base class.
    return self.compute(labels, preds, label_spec, pred_spec, config)


class Annotator(metaclass=abc.ABCMeta):
"""Base class for LIT annotator components.
Expand Down
52 changes: 11 additions & 41 deletions lit_nlp/components/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Metric component and implementations."""

import abc
import collections
import enum
from typing import Any, Callable, Optional, Sequence, Union, cast
Expand Down Expand Up @@ -81,46 +80,11 @@ def nan_to_none(metrics: dict[str, float]) -> dict[str, Optional[float]]:
return {k: (v if not np.isnan(v) else None) for k, v in metrics.items()}


class SimpleMetrics(lit_components.Interpreter):
"""Base class for simple metrics, which should render in the main metrics table."""

def is_compatible(self, model: lit_model.Model,
dataset: lit_dataset.Dataset) -> bool:
"""Metrics should always return false for Model-level compatibility."""
del model, dataset # TODO(b/254832560): Use these once metrics get promoted
return False

@abc.abstractmethod
def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Returns true if compatible with the predicted field and its parent."""
pass

def compute(self,
labels: Sequence[Any],
preds: Sequence[Any],
label_spec: LitType,
pred_spec: LitType,
config: Optional[JsonDict] = None) -> dict[str, float]:
"""Compute metric(s) between labels and predictions."""
raise NotImplementedError(
'Subclass should implement this, or override compute_with_metadata() directly.'
)

def compute_with_metadata(
self,
labels: Sequence[Any],
preds: Sequence[Any],
label_spec: LitType,
pred_spec: LitType,
indices: Sequence[types.ExampleId],
metas: Sequence[JsonDict],
config: Optional[JsonDict] = None) -> dict[str, float]:
"""As compute(), but has access to indices and metadata."""
return self.compute(labels, preds, label_spec, pred_spec, config)
class SimpleMetrics(lit_components.Metrics):
"""Base class for built-in metrics rendered in the main metrics table."""

def run(self,
inputs: list[JsonDict],
inputs: Sequence[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.Dataset,
model_outputs: Optional[list[JsonDict]] = None,
Expand Down Expand Up @@ -206,8 +170,7 @@ def __init__(self, metrics: SimpleMetrics):
def is_compatible(self, model: lit_model.Model,
dataset: lit_dataset.Dataset) -> bool:
"""Defers model-level compatibility to the wrapped metrics component."""
del model, dataset # TODO(b/254832560): Use these once metrics get promoted
return False
return self._metrics.is_compatible(model, dataset)

def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
Expand Down Expand Up @@ -464,6 +427,13 @@ class MulticlassPairedMetricsImpl(SimpleMetrics):
the model to the perturbations.
"""

def meta_spec(self) -> types.Spec:
  """Describes the metrics reported by this component.

  Returns:
    A spec with one entry per metric name ('num_pairs', 'swap_rate' and
    'mean_jsd'), each produced by _get_best_value_spec().
  """
  spec: types.Spec = {}
  spec['num_pairs'] = _get_best_value_spec()
  spec['swap_rate'] = _get_best_value_spec(BestValue.ZERO)
  spec['mean_jsd'] = _get_best_value_spec(BestValue.HIGHEST)
  return spec

def is_field_compatible(self, pred_spec: LitType,
parent_spec: Optional[LitType]) -> bool:
"""Return true if compatible with this field."""
Expand Down
91 changes: 63 additions & 28 deletions lit_nlp/components/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================
"""Tests for lit_nlp.components.metrics."""

from typing import Optional

from absl.testing import absltest
from absl.testing import parameterized
from lit_nlp.api import dataset as lit_dataset
Expand All @@ -25,18 +23,39 @@
from lit_nlp.lib import testing_utils

LitType = types.LitType
DUMMY_MODEL = testing_utils.TestModelBatched()


class TestGenTextModel(lit_model.Model):
  """Minimal text-generation model stub used by the compatibility tests.

  It takes a single 'input' text segment and emits a single 'output'
  generated-text field whose parent is 'input'.
  """

  def input_spec(self) -> types.Spec:
    """Returns a spec with one TextSegment field named 'input'."""
    return {'input': types.TextSegment()}

  def output_spec(self) -> types.Spec:
    """Returns a spec with one GeneratedText field parented to 'input'."""
    return {'output': types.GeneratedText(parent='input')}

  def predict_minibatch(self,
                        inputs: list[types.JsonDict]) -> list[types.JsonDict]:
    """Returns one fixed prediction per input.

    Each prediction is a separate dict. Multiplying a one-element list would
    repeat a reference to the same dict, so a caller mutating one prediction
    would silently change all of them.
    """
    return [{'output': 'test_output'} for _ in inputs]


_CLASSIFICATION_MODEL = testing_utils.TestModelClassification()
_GENERATED_TEXT_MODEL = TestGenTextModel()
_REGRESSION_MODEL = testing_utils.TestIdentityRegressionModel()


class RegressionMetricsTest(parameterized.TestCase):

@parameterized.named_parameters(('with model', DUMMY_MODEL),
('without model', None))
def test_is_compatible(self, model: Optional[lit_model.Model]):
@parameterized.named_parameters(
('cls_model', _CLASSIFICATION_MODEL, False),
('gen_text_model', _GENERATED_TEXT_MODEL, False),
('reg_model', _REGRESSION_MODEL, True),
)
def test_is_compatible(self, model: lit_model.Model, expected: bool):
"""Only the regression model is expected to be compatible."""
regression_metrics = metrics.RegressionMetrics()
self.assertFalse(regression_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model})))
compat = regression_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model}))
self.assertEqual(compat, expected)

@parameterized.named_parameters(
('regression', types.RegressionScore(), None, True),
Expand Down Expand Up @@ -95,13 +114,17 @@ def test_compute_empty_labels(self):

class MulticlassMetricsTest(parameterized.TestCase):

@parameterized.named_parameters(('with model', DUMMY_MODEL),
('without model', None))
def test_is_compatible(self, model: Optional[lit_model.Model]):
@parameterized.named_parameters(
('cls_model', _CLASSIFICATION_MODEL, True),
('reg_model', _REGRESSION_MODEL, False),
('gen_text_model', _GENERATED_TEXT_MODEL, False),
)
def test_is_compatible(self, model: lit_model.Model, expected: bool):
"""Only the classification model is expected to be compatible."""
multiclass_metrics = metrics.MulticlassMetrics()
self.assertFalse(multiclass_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model})))
compat = multiclass_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model}))
self.assertEqual(compat, expected)

@parameterized.named_parameters(
('multiclass', types.MulticlassPreds(vocab=['']), None, True),
Expand Down Expand Up @@ -230,13 +253,17 @@ def test_compute_empty_labels(self):

class MulticlassPairedMetricsTest(parameterized.TestCase):

@parameterized.named_parameters(('with model', DUMMY_MODEL),
('without model', None))
def test_is_compatible(self, model: Optional[lit_model.Model]):
@parameterized.named_parameters(
('cls_model', _CLASSIFICATION_MODEL, True),
('reg_model', _REGRESSION_MODEL, False),
('gen_text_model', _GENERATED_TEXT_MODEL, False),
)
def test_is_compatible(self, model: lit_model.Model, expected: bool):
"""Only the classification model is expected to be compatible."""
multiclass_paired_metrics = metrics.MulticlassPairedMetrics()
self.assertFalse(multiclass_paired_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model})))
compat = multiclass_paired_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model}))
self.assertEqual(compat, expected)

@parameterized.named_parameters(
('multiclass', types.MulticlassPreds(vocab=['']), None, True),
Expand Down Expand Up @@ -307,13 +334,17 @@ def test_compute(self):

class CorpusBLEUTest(parameterized.TestCase):

@parameterized.named_parameters(('with model', DUMMY_MODEL),
('without model', None))
def test_is_compatible(self, model: Optional[lit_model.Model]):
@parameterized.named_parameters(
('cls_model', _CLASSIFICATION_MODEL, False),
('reg_model', _REGRESSION_MODEL, False),
('gen_text_model', _GENERATED_TEXT_MODEL, True),
)
def test_is_compatible(self, model: lit_model.Model, expected: bool):
"""Only the generated-text model is expected to be compatible."""
bleu_metrics = metrics.CorpusBLEU()
self.assertFalse(bleu_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model})))
compat = bleu_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model}))
self.assertEqual(compat, expected)

@parameterized.named_parameters(
('generated text + str', types.GeneratedText(), types.StringLitType(),
Expand Down Expand Up @@ -380,13 +411,17 @@ def test_compute_with_candidates(self):

class RougeLTest(parameterized.TestCase):

@parameterized.named_parameters(('with model', DUMMY_MODEL),
('without model', None))
def test_is_compatible(self, model: Optional[lit_model.Model]):
@parameterized.named_parameters(
('cls_model', _CLASSIFICATION_MODEL, False),
('reg_model', _REGRESSION_MODEL, False),
('gen_text_model', _GENERATED_TEXT_MODEL, True),
)
def test_is_compatible(self, model: lit_model.Model, expected: bool):
"""Only the generated-text model is expected to be compatible."""
rouge_metrics = metrics.RougeL()
self.assertFalse(rouge_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model})))
compat = rouge_metrics.is_compatible(
model, lit_dataset.NoneDataset({'test': model}))
self.assertEqual(compat, expected)

@parameterized.named_parameters(
('generated text + str', types.GeneratedText(), types.StringLitType(),
Expand Down

0 comments on commit de7d8ba

Please sign in to comment.