Skip to content

Commit

Permalink
Adds components.core library w/ functions for default interpreters + …
Browse files Browse the repository at this point in the history
…generators

PiperOrigin-RevId: 469241495
  • Loading branch information
RyanMullins authored and LIT team committed Aug 22, 2022
1 parent ab057b5 commit 9ea4ab2
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 73 deletions.
78 changes: 5 additions & 73 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,7 @@
from lit_nlp.api import layout
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.components import ablation_flip
from lit_nlp.components import classification_results
from lit_nlp.components import curves
from lit_nlp.components import gradient_maps
from lit_nlp.components import hotflip
from lit_nlp.components import lemon_explainer
from lit_nlp.components import lime_explainer
from lit_nlp.components import metrics
from lit_nlp.components import model_salience
from lit_nlp.components import nearest_neighbors
from lit_nlp.components import pca
from lit_nlp.components import pdp
from lit_nlp.components import projection
from lit_nlp.components import regression_results
from lit_nlp.components import salience_clustering
from lit_nlp.components import scrambler
from lit_nlp.components import shap_explainer
from lit_nlp.components import tcav
from lit_nlp.components import thresholder
from lit_nlp.components import umap
from lit_nlp.components import word_replacer
from lit_nlp.components import core
from lit_nlp.lib import caching
from lit_nlp.lib import serialize
from lit_nlp.lib import ui_state
Expand Down Expand Up @@ -492,65 +472,17 @@ def __init__(
self._datasets = lit_dataset.IndexedDataset.index_all(
self._datasets, caching.input_hash)

# Generator initialization
if generators is not None:
self._generators = generators
else:
self._generators = {
'Ablation Flip': ablation_flip.AblationFlip(),
'Hotflip': hotflip.HotFlip(),
'Scrambler': scrambler.Scrambler(),
'Word Replacer': word_replacer.WordReplacer(),
}
self._generators = core.default_generators()

# Interpreter initialization
if interpreters is not None:
self._interpreters = interpreters

else:
metrics_group = lit_components.ComponentGroup({
'regression': metrics.RegressionMetrics(),
'multiclass': metrics.MulticlassMetrics(),
'paired': metrics.MulticlassPairedMetrics(),
'bleu': metrics.CorpusBLEU(),
'rouge': metrics.RougeL(),
})
gradient_map_interpreters = {
'Grad L2 Norm': gradient_maps.GradientNorm(),
'Grad ⋅ Input': gradient_maps.GradientDotInput(),
'Integrated Gradients': gradient_maps.IntegratedGradients(),
'LIME': lime_explainer.LIME(),
}
# pyformat: disable
self._interpreters: dict[str, lit_components.Interpreter] = {
'Model-provided salience': model_salience.ModelSalience(self._models),
'counterfactual explainer': lemon_explainer.LEMON(),
'tcav': tcav.TCAV(),
'curves': curves.CurvesInterpreter(),
'thresholder': thresholder.Thresholder(),
'metrics': metrics_group,
'pdp': pdp.PdpInterpreter(),
'Salience Clustering': salience_clustering.SalienceClustering(
gradient_map_interpreters),
'Tabular SHAP': shap_explainer.TabularShapExplainer(),
}
# pyformat: enable
self._interpreters.update(gradient_map_interpreters)

# Ensure the prediction analysis interpreters are included.
prediction_analysis_interpreters = {
'classification': classification_results.ClassificationInterpreter(),
'regression': regression_results.RegressionInterpreter(),
}
# Ensure the embedding-based interpreters are included.
embedding_based_interpreters = {
'nearest neighbors': nearest_neighbors.NearestNeighbors(),
# Embedding projectors expose a standard interface, but get special
# handling so we can precompute the projections if requested.
'pca': projection.ProjectionManager(pca.PCAModel),
'umap': projection.ProjectionManager(umap.UmapModel),
}
self._interpreters = dict(**self._interpreters,
**prediction_analysis_interpreters,
**embedding_based_interpreters)
self._interpreters = core.default_interpreters(self._models)

# Component to sync state from TS -> Python. Used in notebooks.
if sync_state:
Expand Down
107 changes: 107 additions & 0 deletions lit_nlp/components/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers for getting default values for LitApp configurations."""
from typing import Union
from lit_nlp.api import components as lit_components
from lit_nlp.api import model as lit_model
from lit_nlp.components import ablation_flip
from lit_nlp.components import classification_results
from lit_nlp.components import curves
from lit_nlp.components import gradient_maps
from lit_nlp.components import hotflip
from lit_nlp.components import lemon_explainer
from lit_nlp.components import lime_explainer
from lit_nlp.components import metrics
from lit_nlp.components import model_salience
from lit_nlp.components import nearest_neighbors
from lit_nlp.components import pca
from lit_nlp.components import pdp
from lit_nlp.components import projection
from lit_nlp.components import regression_results
from lit_nlp.components import salience_clustering
from lit_nlp.components import scrambler
from lit_nlp.components import shap_explainer
from lit_nlp.components import tcav
from lit_nlp.components import thresholder
from lit_nlp.components import umap
from lit_nlp.components import word_replacer

# Short local aliases for the LIT API types referenced below, both in type
# annotations and (for ComponentGroup) to construct the default metrics group.
ComponentGroup = lit_components.ComponentGroup
Generator = lit_components.Generator
Interpreter = lit_components.Interpreter
Model = lit_model.Model


def default_generators() -> dict[str, Generator]:
  """Builds the default generators used in a LitApp.

  Returns:
    A new dict mapping each generator's name to a freshly constructed
    instance of that generator.
  """
  generators: dict[str, Generator] = {}
  generators['Ablation Flip'] = ablation_flip.AblationFlip()
  generators['Hotflip'] = hotflip.HotFlip()
  generators['Scrambler'] = scrambler.Scrambler()
  generators['Word Replacer'] = word_replacer.WordReplacer()
  return generators


def default_interpreters(models: dict[str, Model]) -> dict[str, Interpreter]:
  """Assembles the default interpreters (and metrics) used in a LitApp.

  Args:
    models: The models included in the LitApp. They are handed to the
      'Model-provided salience' interpreter, since models may provide their
      own salience information.

  Returns:
    A new dict keyed by interpreter name. The 'metrics' entry holds a
    ComponentGroup bundling the individual metrics components.
  """
  # Embedding-based interpreters that must always be present.
  embedding_components: dict[str, Interpreter] = {
      'nearest neighbors': nearest_neighbors.NearestNeighbors(),
      # Embedding projectors expose a standard interface, but get special
      # handling so we can precompute the projections if requested.
      'pca': projection.ProjectionManager(pca.PCAModel),
      'umap': projection.ProjectionManager(umap.UmapModel),
  }

  # Salience methods; registered individually in the result and also shared
  # with the salience clustering interpreter constructed below.
  salience_methods: dict[str, Interpreter] = {
      'Grad L2 Norm': gradient_maps.GradientNorm(),
      'Grad ⋅ Input': gradient_maps.GradientDotInput(),
      'Integrated Gradients': gradient_maps.IntegratedGradients(),
      'LIME': lime_explainer.LIME(),
  }

  # All metrics are exposed together under the single 'metrics' key.
  default_metrics: ComponentGroup = ComponentGroup({
      'regression': metrics.RegressionMetrics(),
      'multiclass': metrics.MulticlassMetrics(),
      'paired': metrics.MulticlassPairedMetrics(),
      'bleu': metrics.CorpusBLEU(),
      'rouge': metrics.RougeL(),
  })

  # Prediction analysis interpreters that must always be present.
  result_interpreters: dict[str, Interpreter] = {
      'classification': classification_results.ClassificationInterpreter(),
      'regression': regression_results.RegressionInterpreter(),
  }

  # pyformat: disable
  defaults: dict[str, Union[ComponentGroup, Interpreter]] = {
      'Model-provided salience': model_salience.ModelSalience(models),
      'counterfactual explainer': lemon_explainer.LEMON(),
      'tcav': tcav.TCAV(),
      'curves': curves.CurvesInterpreter(),
      'thresholder': thresholder.Thresholder(),
      'metrics': default_metrics,
      'pdp': pdp.PdpInterpreter(),
      'Salience Clustering': salience_clustering.SalienceClustering(
          salience_methods),
      'Tabular SHAP': shap_explainer.TabularShapExplainer(),
  }
  # pyformat: enable

  # Append the remaining groups after the core entries, in this fixed order.
  for extras in (salience_methods, result_interpreters, embedding_components):
    defaults.update(extras)
  return defaults

0 comments on commit 9ea4ab2

Please sign in to comment.