Commit ecd3a66

Adds dataset arg to Component is_compatible() functions.

Updates many interpreters and tests as a result.

Converts any file affected by the above to PEP 585 typings.
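
Concretely, the PEP 585 conversion swaps deprecated typing aliases for builtin generics, as in this schematic before/after (illustrative; not lifted from any one file in the commit):

# Before: aliases from typing, deprecated since Python 3.9.
from typing import Dict, List, Text

def field_names(spec: Dict[Text, int]) -> List[Text]:
  return list(spec)

# After: builtin generics per PEP 585.
def field_names(spec: dict[str, int]) -> list[str]:
  return list(spec)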

PiperOrigin-RevId: 485866683
RyanMullins authored and LIT team committed Nov 3, 2022
1 parent ba98322 commit ecd3a66
Showing 38 changed files with 685 additions and 426 deletions.
11 changes: 6 additions & 5 deletions lit_nlp/api/components.py

@@ -61,9 +61,10 @@ def run_with_metadata(self,
     inputs = [ex['data'] for ex in indexed_inputs]
     return self.run(inputs, model, dataset, model_outputs, config)
 
-  def is_compatible(self, model: lit_model.Model):
-    """Return if interpreter is compatible with the given model."""
-    del model
+  def is_compatible(self, model: lit_model.Model,
+                    dataset: lit_dataset.Dataset) -> bool:
+    """Return if interpreter is compatible with the dataset and model."""
+    del dataset, model  # Unused in base class
     return True
 
   def config_spec(self) -> types.Spec:
@@ -97,8 +98,8 @@ class ComponentGroup(Interpreter):
   def __init__(self, subcomponents: dict[str, Interpreter]):
     self._subcomponents = subcomponents
 
-  def meta_spec(self) -> dict[str, types.LitType]:
-    spec: dict[str, types.LitType] = {}
+  def meta_spec(self) -> types.Spec:
+    spec: types.Spec = {}
     for component_name, component in self._subcomponents.items():
       for field_name, field_spec in component.meta_spec().items():
         spec[f'{component_name}: {field_name}'] = field_spec

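For component authors, the new contract is the two-argument override shown above. A minimal sketch of a conforming subclass (MyInterpreter and its compatibility rule are hypothetical; the imports are LIT's real modules):

from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types


class MyInterpreter(lit_components.Interpreter):
  """Hypothetical interpreter illustrating the widened hook."""

  def is_compatible(self, model: lit_model.Model,
                    dataset: lit_dataset.Dataset) -> bool:
    # Both specs are now available, so a component can gate on dataset
    # fields as well as on model inputs and outputs.
    has_text_input = any(
        isinstance(field, types.TextSegment)
        for field in model.input_spec().values())
    return has_text_input and bool(dataset.spec())
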
8 changes: 5 additions & 3 deletions lit_nlp/api/model.py

@@ -258,8 +258,9 @@ def supports_concurrent_predictions(self):
   def predict_minibatch(self, inputs: List[JsonDict], **kw) -> List[JsonDict]:
     return self.wrapped.predict_minibatch(inputs, **kw)
 
-  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterator[JsonDict]:
-    return self.wrapped.predict(inputs, **kw)
+  def predict(self, inputs: Iterable[JsonDict], *args,
+              **kw) -> Iterator[JsonDict]:
+    return self.wrapped.predict(inputs, *args, **kw)
 
   # NOTE: if a subclass modifies predict(), it should also override this to
   #   call the custom predict() method - otherwise this will delegate to the
@@ -312,7 +313,8 @@ def __init__(self,
     self._max_qps = max_qps
     self._pool = multiprocessing.pool.ThreadPool(max_concurrent_requests)
 
-  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterator[JsonDict]:
+  def predict(self, inputs: Iterable[JsonDict], *unused_args,
+              **unused_kwargs) -> Iterator[JsonDict]:
     batches = utils.batch_iterator(
         inputs, max_batch_size=self.max_minibatch_size())
     batches = utils.rate_limit(batches, self._max_qps)

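Both predict() changes follow one pattern: accept extra positional arguments and either forward them (the wrapper, via *args) or drop them explicitly (the batched remote model, via *unused_args), so widened call sites cannot raise TypeError. A self-contained toy version of the pattern (not LIT code):

from collections.abc import Iterable, Iterator


class Inner:
  def predict(self, inputs: Iterable[int], *args, **kw) -> Iterator[int]:
    del args, kw  # Tolerated but unused, like *unused_args above.
    return (x + 1 for x in inputs)


class Wrapper:
  def __init__(self, wrapped: Inner):
    self.wrapped = wrapped

  def predict(self, inputs: Iterable[int], *args, **kw) -> Iterator[int]:
    # Forward any extras untouched, mirroring the wrapper change above.
    return self.wrapped.predict(inputs, *args, **kw)


print(list(Wrapper(Inner()).predict([1, 2, 3], 'extra')))  # [2, 3, 4]
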
47 changes: 32 additions & 15 deletions lit_nlp/app.py

@@ -57,23 +57,40 @@ def _build_metadata(self):
     model_info = {}
     for name, m in self._models.items():
       mspec: lit_model.ModelSpec = m.spec()
-      info = {}
-      info['spec'] = {'input': mspec.input, 'output': mspec.output}
+      info = {
+          'description': m.description(),
+          'spec': {
+              'input': mspec.input,
+              'output': mspec.output
+          }
+      }
+
       # List compatible datasets.
       info['datasets'] = [
-          dname for dname, ds in self._datasets.items()
-          if mspec.is_compatible_with_dataset(ds.spec())
+          name for name, dataset in self._datasets.items()
+          if mspec.is_compatible_with_dataset(dataset.spec())
       ]
       if len(info['datasets']) == 0:  # pylint: disable=g-explicit-length-test
         logging.error("Error: model '%s' has no compatible datasets!", name)
-      info['generators'] = [
-          name for name, gen in self._generators.items() if gen.is_compatible(m)
-      ]
-      info['interpreters'] = [
-          name for name, interp in self._interpreters.items()
-          if interp.is_compatible(m)
-      ]
-      info['description'] = m.description()
+
+      compat_gens: set[str] = set()
+      compat_interps: set[str] = set()
+
+      for d in info['datasets']:
+        dataset: lit_dataset.Dataset = self._datasets[d]
+        compat_gens.update([
+            name for name, gen in self._generators.items()
+            if gen.is_compatible(model=m, dataset=dataset)
+        ])
+        compat_interps.update([
+            name for name, interp in self._interpreters.items()
+            if interp.is_compatible(model=m, dataset=dataset)
+        ])
+
+      info['generators'] = [name for name in self._generators.keys()
+                            if name in compat_gens]
+      info['interpreters'] = [name for name in self._interpreters.keys()
+                              if name in compat_interps]
       model_info[name] = info
 
     dataset_info = {}
@@ -139,15 +156,15 @@ def _reconstitute_inputs(self, inputs: Sequence[Union[IndexedInput, str]],
   def _save_datapoints(self, data, dataset_name: str, path: str, **unused_kw):
     """Save datapoints to disk."""
     if self._demo_mode:
-      logging.warn('Attempted to save datapoints in demo mode.')
+      logging.warning('Attempted to save datapoints in demo mode.')
       return None
     return self._datasets[dataset_name].save(data['inputs'], path)
 
   def _load_datapoints(self, unused_data, dataset_name: str, path: str,
                        **unused_kw):
     """Load datapoints from disk."""
     if self._demo_mode:
-      logging.warn('Attempted to load datapoints in demo mode.')
+      logging.warning('Attempted to load datapoints in demo mode.')
       return None
     dataset = self._datasets[dataset_name].load(path)
     return dataset.indexed_examples
@@ -538,7 +555,7 @@ def __init__(
         for name, model in models.items()
     }
 
-    self._datasets = dict(datasets)
+    self._datasets: dict[str, lit_dataset.Dataset] = dict(datasets)
    # TODO(b/202210900): get rid of this, just dynamically create the empty
    # dataset on the frontend.
     self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

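Net effect of the metadata change: a generator or interpreter is listed for a model if it is compatible with at least one of that model's compatible datasets, and the final lists keep registry order. A condensed restatement of that logic (the function and argument shapes are illustrative):

def compatible_names(components: dict, model, datasets: list) -> list[str]:
  """Names of components compatible with `model` on at least one dataset."""
  compat = {
      name for name, component in components.items()
      for dataset in datasets
      if component.is_compatible(model=model, dataset=dataset)
  }
  # Filter the original mapping so registry insertion order is preserved.
  return [name for name in components if name in compat]
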
10 changes: 6 additions & 4 deletions lit_nlp/components/ablation_flip.py

@@ -213,12 +213,14 @@ def _generate_leave_one_out_ablation_score(
         ret.append((field, idxs[i], loo_score))
     return ret
 
-  def is_compatible(self, model: lit_model.Model) -> bool:
+  def is_compatible(self, model: lit_model.Model,
+                    dataset: lit_dataset.Dataset) -> bool:
+    del dataset  # Unused by AblationFlip
     supported_inputs = (types.SparseMultilabel, types.TextSegment, types.URL)
     supported_preds = (types.MulticlassPreds, types.RegressionScore)
-    input_fields = utils.find_spec_keys(model.input_spec(), supported_inputs)
-    output_fields = utils.find_spec_keys(model.output_spec(), supported_preds)
-    return (bool(input_fields) and bool(output_fields))
+    valid_inputs = utils.spec_contains(model.input_spec(), supported_inputs)
+    valid_outputs = utils.spec_contains(model.output_spec(), supported_preds)
+    return valid_inputs and valid_outputs
 
   def config_spec(self) -> types.Spec:
     return {

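This commit repeatedly swaps the find_spec_keys(...) plus bool(...) idiom for utils.spec_contains, which answers the yes/no question directly. A minimal sketch of the difference (the toy spec is illustrative):

from lit_nlp.api import types
from lit_nlp.lib import utils

spec = {'text': types.TextSegment(), 'score': types.RegressionScore()}

# Old idiom: materialize the matching keys, then test truthiness.
assert bool(utils.find_spec_keys(spec, (types.TextSegment,)))

# New idiom: ask for containment directly.
assert utils.spec_contains(spec, (types.TextSegment,))
assert not utils.spec_contains(spec, (types.MulticlassPreds,))
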
4 changes: 3 additions & 1 deletion lit_nlp/components/ablation_flip_test.py

@@ -18,6 +18,7 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
+from lit_nlp.api import dataset as lit_dataset
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types
 from lit_nlp.components import ablation_flip
@@ -71,7 +72,8 @@ def test_ablation_flip_is_compatible(self,
                                        input_spec: types.Spec,
                                        exp: bool):
     model = model_ctr(input_spec)
-    compatible = self.ablation_flip.is_compatible(model)
+    compatible = self.ablation_flip.is_compatible(
+        model, lit_dataset.NoneDataset({'test': model}))
     self.assertEqual(compatible, exp)
 
 if __name__ == '__main__':

7 changes: 4 additions & 3 deletions lit_nlp/components/classification_results.py

@@ -129,9 +129,10 @@ def run(  # pytype: disable=signature-mismatch # overriding-parameter-type-checks
       results.append(input_result)
     return results
 
-  def is_compatible(self, model: lit_model.Model) -> bool:
-    output_spec = model.output_spec()
-    return True if self._find_supported_pred_keys(output_spec) else False
+  def is_compatible(self, model: lit_model.Model,
+                    dataset: lit_dataset.Dataset) -> bool:
+    del dataset  # Unused during model classification
+    return lit_utils.spec_contains(model.output_spec(), types.MulticlassPreds)
 
   def _find_supported_pred_keys(self, output_spec: types.Spec) -> list[str]:
     return lit_utils.find_spec_keys(output_spec, types.MulticlassPreds)

17 changes: 11 additions & 6 deletions lit_nlp/components/classification_results_test.py

@@ -15,24 +15,29 @@
 """Tests for lit_nlp.components.classification_results."""
 
 from absl.testing import absltest
+from absl.testing import parameterized
 from lit_nlp.api import dataset as lit_dataset
 from lit_nlp.api import dtypes
+from lit_nlp.api import model as lit_model
 from lit_nlp.components import classification_results
 from lit_nlp.lib import testing_utils
 import numpy as np
 
 
-class ClassificationResultsTest(absltest.TestCase):
+class ClassificationResultsTest(parameterized.TestCase):
 
   def setUp(self):
     super(ClassificationResultsTest, self).setUp()
     self.interpreter = classification_results.ClassificationInterpreter()
 
-  def test_is_compatible(self):
-    self.assertTrue(self.interpreter.is_compatible(
-        testing_utils.TestModelClassification()))
-    self.assertFalse(self.interpreter.is_compatible(
-        testing_utils.TestRegressionModel({})))
+  @parameterized.named_parameters(
+      ('classification', testing_utils.TestModelClassification(), True),
+      ('regression', testing_utils.TestRegressionModel({}), False),
+  )
+  def test_is_compatible(self, model: lit_model.Model, expected: bool):
+    compat = self.interpreter.is_compatible(
+        model, lit_dataset.NoneDataset({'test': model}))
+    self.assertEqual(compat, expected)
 
   def test_no_label(self):
     dataset = lit_dataset.Dataset(None, None)

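Both updated tests use the same placeholder pattern for interpreters that ignore their dataset argument: a NoneDataset built from a model mapping stands in for a real dataset. Sketched in isolation (assuming `model` and `interpreter` from a surrounding test):

from lit_nlp.api import dataset as lit_dataset

# The mapping key is arbitrary; NoneDataset only needs the models, and its
# empty spec satisfies the two-argument is_compatible() contract.
placeholder = lit_dataset.NoneDataset({'test': model})
assert interpreter.is_compatible(model, placeholder)
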
19 changes: 11 additions & 8 deletions lit_nlp/components/core.py

@@ -75,13 +75,6 @@ def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
       'Integrated Gradients': gradient_maps.IntegratedGradients(),
       'LIME': lime_explainer.LIME(),
   }
-  metrics_group: ComponentGroup = ComponentGroup({
-      'regression': metrics.RegressionMetrics(),
-      'multiclass': metrics.MulticlassMetrics(),
-      'paired': metrics.MulticlassPairedMetrics(),
-      'bleu': metrics.CorpusBLEU(),
-      'rouge': metrics.RougeL(),
-  })
   # Ensure the prediction analysis interpreters are included.
   prediction_analysis_interpreters: dict[str, Interpreter] = {
       'classification': classification_results.ClassificationInterpreter(),
@@ -94,7 +87,7 @@ def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
       'tcav': tcav.TCAV(),
       'curves': curves.CurvesInterpreter(),
       'thresholder': thresholder.Thresholder(),
-      'metrics': metrics_group,
+      'metrics': default_metrics(),
       'pdp': pdp.PdpInterpreter(),
       'Salience Clustering': salience_clustering.SalienceClustering(
           gradient_map_interpreters),
@@ -105,3 +98,13 @@ def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
       **prediction_analysis_interpreters,
       **embedding_based_interpreters)
   return interpreters
+
+
+def default_metrics() -> ComponentGroup:
+  return ComponentGroup({
+      'regression': metrics.RegressionMetrics(),
+      'multiclass': metrics.MulticlassMetrics(),
+      'paired': metrics.MulticlassPairedMetrics(),
+      'bleu': metrics.CorpusBLEU(),
+      'rouge': metrics.RougeL(),
+  })

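Factoring the metrics bundle into default_metrics() makes it reusable outside default_interpreters(); a sketch of custom wiring under that assumption (the surrounding dict is illustrative):

from lit_nlp.components import core

interpreters = {
    # The stock bundle: regression, multiclass, paired, BLEU, and ROUGE-L.
    'metrics': core.default_metrics(),
}
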
16 changes: 10 additions & 6 deletions lit_nlp/components/curves.py

@@ -14,7 +14,7 @@
 # ==============================================================================
 """An interpreter for generating data for ROC and PR curves."""
 
-from typing import cast, List, Optional, Sequence, Text
+from typing import cast, Optional, Sequence
 
 from lit_nlp.api import components as lit_components
 from lit_nlp.api import dataset as lit_dataset
@@ -98,11 +98,15 @@ def run_with_metadata(self,
     # Create and return the result.
     return {ROC_DATA: roc_data, PR_DATA: pr_data}
 
-  def is_compatible(self, model: lit_model.Model) -> bool:
-    # A model is compatible if it is a classification model and has
-    # reference to the ground truth in the dataset.
+  def is_compatible(self, model: lit_model.Model,
+                    dataset: lit_dataset.Dataset) -> bool:
+    """True if using a classification model and dataset has ground truth."""
     output_spec = model.output_spec()
-    return True if self._find_supported_pred_keys(output_spec) else False
+    supported_keys = self._find_supported_pred_keys(output_spec)
+    has_parents = all(
+        cast(types.MulticlassPreds, output_spec[key]).parent in dataset.spec()
+        for key in supported_keys)
+    return bool(supported_keys) and has_parents
 
   def config_spec(self) -> types.Spec:
     # If a model is a multiclass classifier, a user can specify which
@@ -113,7 +117,7 @@ def config_spec(self) -> types.Spec:
   def meta_spec(self) -> types.Spec:
     return {ROC_DATA: types.CurveDataPoints(), PR_DATA: types.CurveDataPoints()}
 
-  def _find_supported_pred_keys(self, output_spec: types.Spec) -> List[Text]:
+  def _find_supported_pred_keys(self, output_spec: types.Spec) -> list[str]:
     """Returns the list of supported prediction keys in the model output.
 
     Args:

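The curves check is now dataset-aware: every MulticlassPreds output must name a parent (ground truth) field that exists in the dataset spec. Schematically (toy specs; vocab and field names are illustrative):

from lit_nlp.api import types

output_spec = {
    'probas': types.MulticlassPreds(vocab=['0', '1'], parent='label'),
}
dataset_spec = {'label': types.CategoryLabel()}

# Mirrors is_compatible() above: prediction keys must exist, and each must
# point at a ground-truth field present in the dataset.
keys = list(output_spec.keys())
assert bool(keys) and all(output_spec[k].parent in dataset_spec for k in keys)
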
2 changes: 1 addition & 1 deletion lit_nlp/components/curves_test.py

@@ -196,7 +196,7 @@ def test_meta_spec(self):
       ('no_parent', NoParentTestModel(), False))
   def test_model_compatibility(self, model: Model, exp_is_compat: bool):
     """A model is incompatible if prediction is not MulticlassPreds."""
-    self.assertEqual(self.ci.is_compatible(model), exp_is_compat)
+    self.assertEqual(self.ci.is_compatible(model, TestDataset()), exp_is_compat)
 
 
 if __name__ == '__main__':
