Skip to content

Commit

Permalink
Add DataService and use to include TCAV cosine similarity scores in s…
Browse files Browse the repository at this point in the history
…calars and data table module.

DataService stores new (non-input, non-model-output) columns of data for each datapoint in the dataset, and fetches the values for these columns when new datapoints are created.

TCAV module adds a new column when a statistically-significant CAV is created, containing the cosine similarity of each datapoint to the CAV. The TCAV interpreter is updated to be able to generate these scores from a list of datapoints and a CAV, if a pre-computed CAV is supplied to the interpreter.

The scalars module plots this new value (and any other scalars in the data service).
The data table module will show this new value (and any other cols added by data service that are tagged as to be visualized in the data table).

The group service and color service also make use of the data service to take into account these new sources of information per datapoint, so datapoints can be colored and faceted by these new columns now.

PiperOrigin-RevId: 425972500
  • Loading branch information
jameswex authored and LIT team committed Feb 2, 2022
1 parent a7c1928 commit 9bdc23e
Show file tree
Hide file tree
Showing 21 changed files with 838 additions and 93 deletions.
28 changes: 28 additions & 0 deletions lit_nlp/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class LitType(metaclass=abc.ABCMeta):
"""Base class for LIT Types."""
required: bool = True # for input fields, mark if required by the model.
annotated: bool = False # If this type is created from an Annotator.
show_in_data_table = True # If true, show this info in the data table.
# TODO(lit-dev): Add defaults for all LitTypes
default = None # an optional default value for a given type.

Expand Down Expand Up @@ -75,6 +76,18 @@ def from_json(d: JsonDict):
del d["__mro__"]
return cls(**d)

# TODO(b/162269499): remove this once we have a proper implementation of
# these types on the frontend.
@classmethod
def cls_to_json(cls) -> JsonDict:
  """Describe this class itself (not an instance) as a JSON-style dict.

  Returns:
    A dict with "__class__" fixed to the string "type", "__name__" set to the
    class name, and "__mro__" holding the names of every class in the method
    resolution order, so that inheritance can be checked on the frontend.
  """
  # Names of all parent classes, taken from the method resolution order (mro).
  mro_names = [klass.__name__ for klass in cls.mro()]
  return {
      "__class__": "type",
      "__name__": cls.__name__,
      "__mro__": mro_names,
  }

Spec = Dict[Text, LitType]

Expand Down Expand Up @@ -103,6 +116,21 @@ def remap_spec(spec: Spec, keymap: Dict[str, str]) -> Spec:
return ret


# TODO(b/162269499): remove this once we have a proper implementation of
# these types on the frontend.
def all_littypes():
  """Collect JSON class info for every direct or indirect LitType subclass.

  Returns:
    A dict mapping each subclass's __name__ to the value returned by that
    subclass's cls_to_json(). LitType itself is not included.
  """
  found = set()
  # Worklist of classes whose own subclasses still need to be visited.
  pending = list(LitType.__subclasses__())
  while pending:
    candidate = pending.pop()
    if candidate in found:
      # Already reached through another parent; avoid duplicate work.
      continue
    found.add(candidate)
    pending.extend(candidate.__subclasses__())

  return {klass.__name__: klass.cls_to_json() for klass in found}


##
# Concrete type classes
# LINT.IfChange
Expand Down
1 change: 1 addition & 0 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def _build_metadata(self):
'generators': generator_info,
'interpreters': interpreter_info,
'layouts': self._layouts,
'littypes': types.all_littypes(),
# Global configuration
'demoMode': self._demo_mode,
'defaultLayout': self._default_layout,
Expand Down
8 changes: 6 additions & 2 deletions lit_nlp/client/core/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {Constructor, LitComponentLayouts} from '../lib/types';
import {ApiService} from '../services/api_service';
import {ClassificationService} from '../services/classification_service';
import {ColorService} from '../services/color_service';
import {DataService} from '../services/data_service';
import {FocusService} from '../services/focus_service';
import {GroupService} from '../services/group_service';
import {LitService} from '../services/lit_service';
Expand Down Expand Up @@ -130,11 +131,13 @@ export class LitApp {
const regressionService = new RegressionService(apiService, appState);
const settingsService =
new SettingsService(appState, modulesService, selectionService0);
const groupService = new GroupService(appState);
const dataService = new DataService(appState);
const groupService = new GroupService(appState, dataService);
const classificationService =
new ClassificationService(apiService, appState, groupService);
const colorService = new ColorService(
appState, groupService, classificationService, regressionService);
appState, groupService, classificationService, regressionService,
dataService);
const focusService = new FocusService(selectionService0);

// Initialize url syncing of state
Expand All @@ -145,6 +148,7 @@ export class LitApp {
this.services.set(AppState, appState);
this.services.set(ClassificationService, classificationService);
this.services.set(ColorService, colorService);
this.services.set(DataService, dataService);
this.services.set(FocusService, focusService);
this.services.set(GroupService, groupService);
this.services.set(ModulesService, modulesService);
Expand Down
5 changes: 3 additions & 2 deletions lit_nlp/client/core/faceting_control.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {observable} from 'mobx';

import {app} from '../core/app';
import {ReactiveElement} from '../lib/elements';
import {getStepSizeGivenRange} from '../lib/utils';
import {FacetingConfig, FacetingMethod, GroupService, NumericFeatureBins} from '../services/group_service';

import {styles as sharedStyles} from '../lib/shared_styles.css';
Expand Down Expand Up @@ -282,7 +283,7 @@ export class FacetingControl extends ReactiveElement {
/** This method creates two bins given a threshold set by the user. */
const [min, max] = this.groupService.numericalFeatureRanges[feature];
const delta = max - min;
const step = delta > 100 ? 10: delta > 10 ? 1 : delta > 1 ? 0.1 : 0.01;
const step = getStepSizeGivenRange(delta);
const value = (config.threshold || (delta/2 + min)).toString();
// clang-format off
inputField = html`
Expand All @@ -295,7 +296,7 @@ export class FacetingControl extends ReactiveElement {
* This method infers the number of bins to create from the spec, and
* therefore does not take user input, so we use this div for alignment.
*/
inputField = html`<div class="no-input">}</div>`;
inputField = html`<div class="no-input"></div>`;
}

const rowClass = classMap({
Expand Down
5 changes: 3 additions & 2 deletions lit_nlp/client/core/faceting_control_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {Checkbox} from '@material/mwc-checkbox';
import {LitApp} from '../core/app';
import {LitCheckbox} from '../elements/checkbox';
import {mockMetadata} from '../lib/testing_utils';
import {AppState, GroupService} from '../services/services';
import {AppState, DataService, GroupService} from '../services/services';


describe('faceting control test', () => {
Expand All @@ -26,13 +26,14 @@ describe('faceting control test', () => {
// Set up.
const app = new LitApp();
const appState = app.getService(AppState);
const dataService = app.getService(DataService);
// Stop appState from trying to make the call to the back end
// to load the data (causes test flakiness).
spyOn(appState, 'loadData').and.returnValue(Promise.resolve());
appState.metadata = mockMetadata;
appState.setCurrentDataset('sst_dev');

const groupService = new GroupService(appState);
const groupService = new GroupService(appState, dataService);
facetCtrl = new FacetingControl(groupService);
document.body.appendChild(facetCtrl);
document.body.addEventListener('facets-change', facetChangeHandler);
Expand Down
37 changes: 37 additions & 0 deletions lit_nlp/client/lib/testing_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,36 @@ export const mockMetadata: LitMetadata = {
}
}
},
'color_test': {
'spec': {
'testNumFeat0': {
'__class__': 'LitType',
'__name__': 'Scalar',
'__mro__': ['Scalar', 'LitType', 'object'],
'required': true
},
'testNumFeat1': {
'__class__': 'LitType',
'__name__': 'Scalar',
'__mro__': ['Scalar', 'LitType', 'object'],
'required': true
},
'testFeat0': {
'__class__': 'LitType',
'__name__': 'CategoryLabel',
'__mro__': ['CategoryLabel', 'LitType', 'object'],
'required': true,
'vocab': ['0', '1']
},
'testFeat1': {
'__class__': 'LitType',
'__name__': 'CategoryLabel',
'__mro__': ['CategoryLabel', 'LitType', 'object'],
'required': true,
'vocab': ['a', 'b', 'c']
}
}
},
'penguin_dev': {
'spec': {
'body_mass_g': {
Expand Down Expand Up @@ -324,6 +354,13 @@ export const mockMetadata: LitMetadata = {
'umap': emptySpec(),
},
'layouts': {},
'littypes': {
'Scalar': {
'__class__': 'type',
'__mro__': ['Scalar', 'LitType', 'object'],
'__name__': 'Scalar'
}
},
'demoMode': false,
'defaultLayout': 'default',
'canonicalURL': undefined
Expand Down
6 changes: 4 additions & 2 deletions lit_nlp/client/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {chunkWords, isLitSubtype} from './utils';
export type D3Selection = d3.Selection<any, any, any, any>;

export type LitClass = 'LitType';
export type LitName = 'LitType'|'String'|'TextSegment'|'GeneratedText'|
export type LitName = 'type'|'LitType'|'String'|'TextSegment'|'GeneratedText'|
'GeneratedTextCandidates'|'ReferenceTexts'|'URL'|'SearchQuery'|'Tokens'|
'TokenTopKPreds'|'Scalar'|'RegressionScore'|'CategoryLabel'|
'MulticlassPreds'|'SequenceTags'|'SpanLabels'|'EdgeLabels'|
Expand All @@ -39,7 +39,7 @@ export const listFieldTypes: LitName[] =
['Tokens', 'SequenceTags', 'SpanLabels', 'EdgeLabels', 'SparseMultilabel'];

export interface LitType {
__class__: LitClass;
__class__: LitClass|'type';
__name__: LitName;
__mro__: string[];
parent?: string;
Expand All @@ -65,6 +65,7 @@ export interface LitType {
token_prefix?: string;
select_all?: boolean;
autosort?: boolean;
show_in_data_table?: boolean;
}

export interface Spec {
Expand Down Expand Up @@ -118,6 +119,7 @@ export interface LitMetadata {
generators: ComponentInfoMap;
interpreters: ComponentInfoMap;
layouts: LitComponentLayouts;
littypes: Spec;
demoMode: boolean;
defaultLayout: string;
canonicalURL?: string;
Expand Down
5 changes: 5 additions & 0 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,8 @@ export function replaceNth(str: string, orig: string, replacement: string,
RegExp("^(?:.*?" + escapedOrig + "){" + n.toString() + "}"),
x => x.replace(RegExp(escapedOrig + "$"), replacement));
}

/**
 * Picks a step size suited to the magnitude of a range of values.
 *
 * Ranges above 100 step by 10, above 10 by 1, and above 1 by 0.1. Every other
 * input (including zero, negative, or NaN ranges) falls through to 0.01.
 */
export function getStepSizeGivenRange(range: number) {
  if (range > 100) {
    return 10;
  }
  if (range > 10) {
    return 1;
  }
  if (range > 1) {
    return 0.1;
  }
  return 0.01;
}
12 changes: 8 additions & 4 deletions lit_nlp/client/modules/data_table_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import {formatForDisplay, IndexedInput, ModelInfoMap, Spec} from '../lib/types';
import {compareArrays, findSpecKeys, shortenId} from '../lib/utils';
import {ClassificationInfo} from '../services/classification_service';
import {RegressionInfo} from '../services/regression_service';
import {ClassificationService, FocusService, RegressionService, SelectionService} from '../services/services';
import {ClassificationService, DataService, FocusService, RegressionService, SelectionService} from '../services/services';

import {styles} from './data_table_module.css';

Expand All @@ -58,6 +58,7 @@ export class DataTableModule extends LitModule {
app.getService(ClassificationService);
private readonly regressionService = app.getService(RegressionService);
private readonly focusService = app.getService(FocusService);
private readonly dataService = app.getService(DataService);

@observable columnVisibility = new Map<string, boolean>();
@observable
Expand All @@ -81,8 +82,10 @@ export class DataTableModule extends LitModule {
get keys(): string[] {
// Use currentInputData to get keys / column names because filteredData
// might have 0 length;
const keys = this.appState.currentInputDataKeys.filter(d => d !== 'meta');
return keys;
const keys = this.appState.currentInputDataKeys;
const dataKeys = this.dataService.cols.filter(
col => col.dataType.show_in_data_table).map(col => col.name);
return keys.concat(dataKeys);
}

@computed
Expand Down Expand Up @@ -192,7 +195,8 @@ export class DataTableModule extends LitModule {

const dataEntries =
this.keys.filter(k => this.columnVisibility.get(k))
.map(k => formatForDisplay(d.data[k], this.dataSpec[k]));
.map(k => formatForDisplay(this.dataService.getVal(d.id, k),
this.dataSpec[k]));

const ret: TableData = [index];
if (this.columnVisibility.get('id')) {
Expand Down
31 changes: 22 additions & 9 deletions lit_nlp/client/modules/scalar_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {LitModule} from '../core/lit_module';
import {D3Selection, formatForDisplay, IndexedInput, ModelInfoMap, ModelSpec, Preds, Spec} from '../lib/types';
import {doesOutputSpecContain, findSpecKeys, getThresholdFromMargin, isLitSubtype} from '../lib/utils';
import {FocusData} from '../services/focus_service';
import {ClassificationService, ColorService, GroupService, FocusService, RegressionService} from '../services/services';
import {ClassificationService, ColorService, DataService, GroupService, FocusService, RegressionService} from '../services/services';

import {styles} from './scalar_module.css';
import {styles as sharedStyles} from '../lib/shared_styles.css';
Expand Down Expand Up @@ -84,6 +84,7 @@ export class ScalarModule extends LitModule {
private readonly groupService = app.getService(GroupService);
private readonly regressionService = app.getService(RegressionService);
private readonly focusService = app.getService(FocusService);
private readonly dataService = app.getService(DataService);

private readonly inputIDToIndex = new Map();
private resizeObserver!: ResizeObserver;
Expand All @@ -101,12 +102,12 @@ export class ScalarModule extends LitModule {
`translate(${ScalarModule.plotLeftMargin},${ScalarModule.plotTopMargin})`;

@computed
private get inputKeys() {
private get scalarKeys() {
return this.groupService.numericalFeatureNames;
}

@computed
private get scalarKeys() {
private get scalarModelOutputKeys() {
const outputSpec = this.appState.currentModelSpecs[this.model].spec.output;
return findSpecKeys(outputSpec, 'Scalar');
}
Expand Down Expand Up @@ -137,6 +138,18 @@ export class ScalarModule extends LitModule {
}
});

// Update predictions when new scalar columns exist to plot.
const getScalarKeys = () => this.scalarKeys;
this.reactImmediately(getScalarKeys, scalarKeys => {
this.updatePredictions(this.appState.currentInputData);
});

// Update predictions when new data values are set.
const getDataVals = () => this.dataService.dataVals;
this.reactImmediately(getDataVals, dataVals => {
this.updatePredictions(this.appState.currentInputData);
});

const getSelectedInputData = () => this.selectionService.selectedInputData;
this.reactImmediately(getSelectedInputData, selectedInputData => {
if (selectedInputData != null) {
Expand Down Expand Up @@ -336,8 +349,8 @@ export class ScalarModule extends LitModule {
const pred = Object.assign(
{}, classificationPreds[i], scalarPreds[i], regressionPreds[i],
{id: currId});
for (const inputKey of this.inputKeys) {
pred[inputKey] = currentInputData[i].data[inputKey];
for (const scalarKey of this.scalarKeys) {
pred[scalarKey] = this.dataService.getVal(currId, scalarKey);
}
preds.push(pred);
}
Expand Down Expand Up @@ -378,7 +391,7 @@ export class ScalarModule extends LitModule {
scoreRange[0] = scoreRange[0] - .1;
scoreRange[1] = scoreRange[1] + .1;
}
} else if (this.inputKeys.indexOf(key) !== -1) {
} else if (this.scalarKeys.indexOf(key) !== -1) {
scoreRange = this.groupService.numericalFeatureRanges[key];
}

Expand Down Expand Up @@ -433,7 +446,7 @@ export class ScalarModule extends LitModule {
const scatterplot = item as SVGGElement;
const key = (item as HTMLElement).dataset['key'];

if (key == null || this.inputKeys.indexOf(key) !== -1) {
if (key == null || this.scalarKeys.indexOf(key) !== -1) {
return;
}

Expand Down Expand Up @@ -762,10 +775,10 @@ export class ScalarModule extends LitModule {
// clang-format off
return html`
<div id='container'>
${this.scalarKeys.map(key => this.renderPlot(key, ''))}
${this.scalarModelOutputKeys.map(key => this.renderPlot(key, ''))}
${this.classificationKeys.map(key =>
this.renderClassificationGroup(key))}
${this.inputKeys.map(key => this.renderPlot(key, ''))}
${this.scalarKeys.map(key => this.renderPlot(key, ''))}
</div>
`;
// clang-format on
Expand Down

0 comments on commit 9bdc23e

Please sign in to comment.