Skip to content

Commit

Permalink
Adds heatmap mode to feature attribution module and SHAP explainer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 436614966
  • Loading branch information
RyanMullins authored and LIT team committed Mar 23, 2022
1 parent 0c65066 commit 76379ad
Show file tree
Hide file tree
Showing 9 changed files with 470 additions and 35 deletions.
2 changes: 2 additions & 0 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from lit_nlp.components import projection
from lit_nlp.components import salience_clustering
from lit_nlp.components import scrambler
from lit_nlp.components import shap_explainer
from lit_nlp.components import tcav
from lit_nlp.components import thresholder
from lit_nlp.components import umap
Expand Down Expand Up @@ -487,6 +488,7 @@ def __init__(
'pdp': pdp.PdpInterpreter(),
'Salience Clustering': salience_clustering.SalienceClustering(
gradient_map_interpreters),
'Tabular SHAP': shap_explainer.TabularShapExplainer(),
# Embedding projectors expose a standard interface, but get special
# handling so we can precompute the projections if requested.
'pca': projection.ProjectionManager(pca.PCAModel),
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/client/core/lit_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export abstract class LitModule extends ReactiveElement {
@observable @property({type: Number}) selectionServiceIndex = 0;

// tslint:disable-next-line:no-any
private readonly latestLoadPromises = new Map<string, Promise<any>>();
protected readonly latestLoadPromises = new Map<string, Promise<any>>();

protected readonly apiService = app.getService(ApiService);
protected readonly appState = app.getService(AppState);
Expand Down
1 change: 1 addition & 0 deletions lit_nlp/client/elements/expansion_panel.css
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
height: 30px;
padding: 2px 8px;
border-bottom: 1px solid var(--lit-neutral-300);
border-top: 1px solid var(--lit-neutral-100);

display: flex;
flex-direction: row;
Expand Down
7 changes: 4 additions & 3 deletions lit_nlp/client/elements/interpreter_controls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ import {isLitSubtype} from '../lib/utils';

import {styles} from './interpreter_controls.css';

interface Settings {
/** Settings for an interpreter */
export interface InterpreterSettings {
[name: string]: boolean | number | string | string[];
}

/** Custom click event for interpreter controls */
export interface InterpreterClick {
name: string;
settings: Settings;
settings: InterpreterSettings;
}

/**
Expand All @@ -51,7 +52,7 @@ export class InterpreterControls extends ReactiveElement {
@observable @property({type: String}) description = '';
@observable @property({type: String}) applyButtonText = 'Apply';
@property({type: Boolean, reflect: true}) applyButtonDisabled = false;
@observable settings: Settings = {};
@observable settings: InterpreterSettings = {};
@property({type: Boolean, reflect: true}) opened = false;

static override get styles() {
Expand Down
30 changes: 30 additions & 0 deletions lit_nlp/client/modules/feature_attribution_module.css
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
.module-container {
display: flex;
flex-direction: column;
}

.module-toolbar {
border-bottom: 1px solid var(--lit-neutral-300);
}

.module-results-area {
display: flex;
flex: 1 0 auto;
flex-direction: row;
}

.side-navigation {
border-right: 1px solid var(--lit-neutral-300);
max-width: 300px;
min-width: 300px;
overflow-y: auto;
}

.main-content {
max-width: calc(100% - 300px);
display: flex;
flex: 1 0 auto;
flex-direction: column;
}

.module-results {
flex: 1 0 auto;
overflow-y: scroll;
}
173 changes: 144 additions & 29 deletions lit_nlp/client/modules/feature_attribution_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@
// tslint:disable:no-new-decorators
import {html} from 'lit';
import {customElement} from 'lit/decorators';
import {styleMap} from 'lit/directives/style-map';
import {computed, observable} from 'mobx';

import {app} from '../core/app';
import {FacetsChange} from '../core/faceting_control';
import {LitModule} from '../core/lit_module';
import {TableData} from '../elements/table';
import {InterpreterClick, InterpreterSettings} from '../elements/interpreter_controls';
import {SortableTemplateResult, TableData} from '../elements/table';
import {IndexedInput, ModelInfoMap} from '../lib/types';
import * as utils from '../lib/utils';
import {findSpecKeys} from '../lib/utils';
import {AppState, GroupService} from '../services/services';
import {findSpecKeys, isLitSubtype} from '../lib/utils';
import {SignedSalienceCmap} from '../services/color_service';
import {NumericFeatureBins} from '../services/group_service';
import {AppState, GroupService} from '../services/services';

import {styles as sharedStyles} from '../lib/shared_styles.css';
import {styles} from './feature_attribution_module.css';

const ALL_DATA = 'Entire Dataset';
const SELECTION = 'Selection';

interface AttributionStats {
min: number;
Expand Down Expand Up @@ -83,7 +87,7 @@ export class FeatureAttributionModule extends LitModule {
return Object.values(modelSpecs).some(modelInfo => {
// The model directly outputs FeatureSalience
const hasIntrinsicSalience =
findSpecKeys(modelInfo.spec.output, 'FeatureSalience').length > 0;
findSpecKeys(modelInfo.spec.output, 'FeatureSalience').length > 0;

// At least one compatible interpreter outputs FeatureSalience
const canDeriveSalience = modelInfo.interpreters.some(name => {
Expand All @@ -105,13 +109,18 @@ export class FeatureAttributionModule extends LitModule {
// ---- Instance Properties ----

private readonly groupService = app.getService(GroupService);
private readonly colorMap = new SignedSalienceCmap();

@observable private startsOpen?: string;
@observable private isColored = false;
@observable private features: string[] = [];
@observable private bins: NumericFeatureBins = {};
@observable private readonly settings =
new Map<string, InterpreterSettings>();
@observable private summaries: SummariesMap = {};
@observable private readonly enabled: VisToggles = {
'model': this.hasIntrinsicSalience
'model': this.hasIntrinsicSalience,
[SELECTION]: false
};

@computed
Expand Down Expand Up @@ -179,6 +188,10 @@ export class FeatureAttributionModule extends LitModule {

await this.predict(ALL_DATA, this.appState.currentInputData);

if (this.enabled[SELECTION]) {
await this.predict(SELECTION, this.selectionService.selectedInputData);
}

if (this.features.length) {
for (const [facet, group] of Object.entries(this.facets)) {
await this.predict(facet, group.data);
Expand All @@ -196,8 +209,20 @@ export class FeatureAttributionModule extends LitModule {
*/
private async interpret(name: string, facet: string, data: IndexedInput[]) {
const runKey = `interpretations-${name}`;
const {configSpec} = this.appState.metadata.interpreters[name];
const defaultCallConfig: {[key: string]: unknown} = {};

for (const [configKey, configInfo] of Object.entries(configSpec)) {
if (configInfo.default) {
defaultCallConfig[configKey] = configInfo.default;
} else if (configInfo.vocab && configInfo.vocab.length) {
defaultCallConfig[configKey] = configInfo.vocab[0];
}
}

const callConfig = this.settings.get(name) || defaultCallConfig;
const promise = this.apiService.getInterpretations(
data, this.model, this.appState.currentDataset, name, {},
data, this.model, this.appState.currentDataset, name, callConfig,
`Running ${name}`);
const results =
(await this.loadLatest(runKey, promise)) as FeatureSalienceResult[];
Expand Down Expand Up @@ -231,6 +256,11 @@ export class FeatureAttributionModule extends LitModule {
await this.interpret(interpreter, ALL_DATA,
this.appState.currentInputData);

if (this.enabled[SELECTION]) {
await this.interpret(interpreter, SELECTION,
this.selectionService.selectedInputData);
}

if (this.features.length) {
for (const [facet, group] of Object.entries(this.facets)) {
await this.interpret(interpreter, facet, group.data);
Expand Down Expand Up @@ -264,44 +294,109 @@ export class FeatureAttributionModule extends LitModule {
return statsMap;
}

private renderFacetControls() {
private renderSecondaryControls() {
const change = () => {this.isColored = !this.isColored;};
const updateFacets = (event: CustomEvent<FacetsChange>) => {
this.features = event.detail.features;
this.bins = event.detail.bins;
};

// clang-format off
return html`<faceting-control @facets-change=${updateFacets}
contextName=${FeatureAttributionModule.title}>
</faceting-control>`;
return html`
<faceting-control @facets-change=${updateFacets}
contextName=${FeatureAttributionModule.title}>
</faceting-control>
<span style="felx: 1 1 auto;"></span>
<lit-checkbox label="Heatmap" ?checked=${this.isColored}
@change=${() => {change();}}>
</lit-checkbox>`;
// clang-format on
}

private renderSalienceControls() {
private renderPrimaryControls() {
const change = (name: string) => {
this.enabled[name] = !this.enabled[name];
};
const isNoSelectionEmtp = this.selectionService.selectedIds.length === 0;
const selectionTitle =
`Attribution is calculated for the entire dataset by default.${
isNoSelectionEmtp ? ' Make a selection to enable this checkbox.' :
''}`;
// clang-format off
return html`
<lit-checkbox label="Show attributions for selection"
title=${selectionTitle} ?disabled=${isNoSelectionEmtp}
?checked=${this.enabled[SELECTION]}
@change=${() => {change(SELECTION);}}>
</lit-checkbox>
<span style="width: 16px;"></span>
<span>Show attributions from:</span>
${this.hasIntrinsicSalience ?
html` <lit-checkbox label=${this.model}
?checked=${this.enabled['model']}
@change=${() => {change('model');}}>
</lit-checkbox>` : null}
${this.salienceInterpreters.map(interp =>
html`<lit-checkbox label=${interp} ?checked=${this.enabled[interp]}
@change=${() => {change(interp);}}>
</lit-checkbox>`)}`;
// clang-format on
}

private renderInterpreterControls(interpreter: string) {
const {configSpec, description} =
this.appState.metadata.interpreters[interpreter];
const clonedSpec = Object.assign({}, configSpec);
for (const fieldName of Object.keys(clonedSpec)) {
// If the interpreter uses a field matcher, then get the matching field
// names from the specified spec and use them as the vocab.
if (isLitSubtype(clonedSpec[fieldName], ['FieldMatcher'])) {
clonedSpec[fieldName].vocab =
this.appState.getSpecKeysFromFieldMatcher(
clonedSpec[fieldName], this.model);
}
}
const interpreterControlClick = (event: CustomEvent<InterpreterClick>) => {
this.settings.set(interpreter, event.detail.settings);
if (!this.enabled[interpreter]) this.enabled[interpreter] = true;
};
return html`
<lit-interpreter-controls @interpreter-click=${interpreterControlClick}
.spec=${configSpec} .name=${interpreter}
.description=${description || ''}
.opened=${this.enabled[interpreter]}>
</lit-interpreter-controls>`;
}

private renderColoredCell(value: number): SortableTemplateResult {
const txtColor = this.colorMap.textCmap(value);
const bgColor = this.colorMap.bgCmap(value);
const styles = styleMap({
'width': '100%',
'height': '100%',
'position': 'relative',
'text-align': 'right',
'color': txtColor,
'background-color': bgColor
});
const template = html`<div style=${styles}>${value.toFixed(4)}</div>`;
return {value, template};
}

private renderTable(summary: AttributionStatsMap) {
const columnNames = ['field', 'min', 'median', 'max', 'mean'];
const columnNames = [
{name: 'field', rightAlign: false},
{name: 'min', rightAlign: true},
{name: 'median', rightAlign: true},
{name: 'max', rightAlign: true},
{name: 'mean', rightAlign: true}
];
const tableData: TableData[] =
Object.entries(summary).map(([feature, stats]) => {
const {min, median, max, mean} = stats;
return [feature, min, median, max, mean];
let fieldsArray: number[] | SortableTemplateResult[] =
[min, median, max, mean];

if (this.isColored) {
fieldsArray = fieldsArray.map(v => this.renderColoredCell(v));
}

return [feature, ...fieldsArray];
});

// clang-format off
Expand All @@ -325,18 +420,38 @@ export class FeatureAttributionModule extends LitModule {
// clang-format off
return html`
<div class='module-container'>
<div class='module-toolbar'>${this.renderSalienceControls()}</div>
<div class='module-toolbar'>${this.renderFacetControls()}</div>
<div class='module-toolbar'>${this.renderPrimaryControls()}</div>
<div class='module-results-area'>
${Object.entries(this.summaries)
.sort()
.map(([facet, summary]) => html`
<div class="attribution-container">
<expansion-panel .label=${facet}
?expanded=${facet === this.startsOpen}>
${this.renderTable(summary)}
</expansion-panel>
</div>`)}
<div class='side-navigation'>
${this.salienceInterpreters.map(interpreter =>
this.renderInterpreterControls(interpreter))}
</div>
<div class='main-content'>
<div class='module-toolbar'>${this.renderSecondaryControls()}</div>
<div class='module-results'>
${!(Object.keys(this.summaries).length ||
this.latestLoadPromises.size)?
html`<div style="padding: 8px;">
Select a model or interpreter to show attributions.
Attributions are calculated from the entire dataset by
default, but can also be calculated for the selection
or any facets of the entire dataset. Faceting of
selections is not supported.
</div>`: null}
${Object.entries(this.summaries)
.sort()
.map(([facet, summary]) => html`
<div class="attribution-container">
<expansion-panel .label=${facet}
?expanded=${facet === this.startsOpen}>
${this.renderTable(summary)}
</expansion-panel>
</div>`)}
${this.latestLoadPromises.size ?
html`<lit-spinner size=${24} color="var(--lit-cyea-400)">
</lit-spinner>`: null}
</div>
</div>
</div>
</div>`;
// clang-format on
Expand Down

0 comments on commit 76379ad

Please sign in to comment.