Skip to content

Commit

Permalink
Add scalar model outputs to data service.
Browse files Browse the repository at this point in the history
Removes the need for special handling of scalar preds types in scalars module.

PiperOrigin-RevId: 452813933
  • Loading branch information
jameswex authored and LIT team committed Jun 3, 2022
1 parent bcdbb80 commit 00749fc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 15 deletions.
20 changes: 5 additions & 15 deletions lit_nlp/client/modules/scalar_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import {ThresholdChange} from '../elements/threshold_slider';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {D3Selection, formatForDisplay, IndexedInput, ModelInfoMap, ModelSpec, Preds, Spec} from '../lib/types';
import {doesOutputSpecContain, findSpecKeys, getThresholdFromMargin, isLitSubtype} from '../lib/utils';
import {CalculatedColumnType, CLASSIFICATION_SOURCE_PREFIX, REGRESSION_SOURCE_PREFIX} from '../services/data_service';
import {CalculatedColumnType, CLASSIFICATION_SOURCE_PREFIX, REGRESSION_SOURCE_PREFIX, SCALAR_SOURCE_PREFIX} from '../services/data_service';
import {FocusData} from '../services/focus_service';
import {ClassificationService, ColorService, DataService, FocusService, GroupService, SelectionService} from '../services/services';

Expand Down Expand Up @@ -118,20 +118,15 @@ export class ScalarModule extends LitModule {
return true;
} else if (col.source.includes(REGRESSION_SOURCE_PREFIX)) {
return false;
} else if (col.source.includes(CLASSIFICATION_SOURCE_PREFIX)) {
} else if (col.source.includes(CLASSIFICATION_SOURCE_PREFIX) ||
col.source.includes(SCALAR_SOURCE_PREFIX)) {
return col.source.includes(this.model);
} else {
return true;
}
});
}

@computed
private get scalarModelOutputKeys() {
const outputSpec = this.appState.currentModelSpecs[this.model].spec.output;
return findSpecKeys(outputSpec, 'Scalar');
}

@computed
private get classificationKeys() {
const outputSpec = this.appState.currentModelSpecs[this.model].spec.output;
Expand Down Expand Up @@ -365,18 +360,14 @@ export class ScalarModule extends LitModule {
currentInputData, this.model, dataset, ['MulticlassPreds']),
this.apiService.getPreds(
currentInputData, this.model, dataset, ['RegressionScore']),
this.apiService.getPreds(
currentInputData, this.model, dataset, ['Scalar']),
]);
const results = await this.loadLatest('predictionScores', promise);
if (results === null) {
return;
}
const classificationPreds = results[0];
const regressionPreds = results[1];
const scalarPreds = results[2];
if (classificationPreds == null && regressionPreds == null &&
scalarPreds == null) {
if (classificationPreds == null && regressionPreds == null) {
return;
}

Expand All @@ -386,7 +377,7 @@ export class ScalarModule extends LitModule {
// TODO(lit-dev): structure this as a proper IndexedInput,
// rather than having 'id' as a regular field.
const pred = Object.assign(
{}, classificationPreds[i], scalarPreds[i], regressionPreds[i],
{}, classificationPreds[i], regressionPreds[i],
{id: currId});
for (const scalarKey of this.scalarColumnsToPlot) {
pred[scalarKey] = this.dataService.getVal(currId, scalarKey);
Expand Down Expand Up @@ -819,7 +810,6 @@ export class ScalarModule extends LitModule {
// clang-format off
return html`
<div id='container'>
${this.scalarModelOutputKeys.map(key => this.renderPlot(key, ''))}
${this.classificationKeys.map(key =>
this.renderClassificationGroup(key))}
${this.scalarColumnsToPlot.map(key => this.renderPlot(key, ''))}
Expand Down
30 changes: 30 additions & 0 deletions lit_nlp/client/services/data_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ export type ColumnData = Map<string, ValueType>;
export const CLASSIFICATION_SOURCE_PREFIX = 'Classification';
/** Column source prefix for columns from the regression interpreter. */
export const REGRESSION_SOURCE_PREFIX = 'Regression';
/** Column source prefix for columns from scalar model outputs. */
export const SCALAR_SOURCE_PREFIX = 'Scalar';

/**
* Data service singleton, responsible for maintaining columns of computed data
Expand Down Expand Up @@ -110,6 +112,7 @@ export class DataService extends LitService {
}
for (const model of this.appState.currentModels) {
this.runRegression(model, this.appState.currentInputData);
this.runScalarPreds(model, this.appState.currentInputData);
}
}, {fireImmediately: true});

Expand Down Expand Up @@ -219,6 +222,33 @@ export class DataService extends LitService {
}
}

/**
* Run scalar predictions and store results in data service.
*/
private async runScalarPreds(model: string, data: IndexedInput[]) {
const {output} = this.appState.currentModelSpecs[model].spec;
if (findSpecKeys(output, 'Scalar').length === 0) {
return;
}

const predsPromise = this.apiService.getPreds(
data, model, this.appState.currentDataset, ['Scalar']);
const preds = await predsPromise;

// Add scalar results as new column to the data service.
if (preds == null || preds.length === 0) {
return;
}
const scalarKeys = Object.keys(preds[0]);
for (const key of scalarKeys) {
const scoreFeatName = this.getColumnName(model, key);
const scores = preds.map(pred => pred[key]);
const dataType = this.appState.createLitType('Scalar', false);
const source = `${SCALAR_SOURCE_PREFIX}:${model}`;
this.addColumnFromList(scores, data, scoreFeatName, dataType, source);
}
}

@action
async setValuesForNewDatapoints(datapoints: IndexedInput[]) {
// When new datapoints are created, set their data values for each
Expand Down

0 comments on commit 00749fc

Please sign in to comment.