Skip to content

Commit

Permalink
Pass requested LitTypes directly into API getPreds query.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 465625844
  • Loading branch information
cjqian authored and LIT team committed Aug 5, 2022
1 parent a36a936 commit 74b5dbb
Show file tree
Hide file tree
Showing 14 changed files with 26 additions and 29 deletions.
3 changes: 0 additions & 3 deletions lit_nlp/client/lib/lit_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ export class LitType {
annotated: boolean = false;
// TODO(b/162269499): Update to camel case once we've replaced old LitType.
show_in_data_table: boolean = false;

// TODO(b/162269499): Add isCompatible functionality.
}

/** A type alias for LitType with an align property. */
Expand Down Expand Up @@ -413,7 +411,6 @@ export class SubwordOffsets extends ListLitType {
export class SparseMultilabel extends StringList {
/** Label names. */
vocab?: string[] = undefined;
// TODO(b/162269499) Migrate non-comma separators to custom type.
/** Separator used for display purposes. */
separator: string = ',';
}
Expand Down
1 change: 0 additions & 1 deletion lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ export function mapsContainSame<T>(mapA: Map<string, T>, mapB: Map<string, T>) {

/** Returns a list of names corresponding to LitTypes. */
export function getTypeNames(litTypes: LitTypeTypesList) : LitName[] {
// TODO(b/162269499): Update apiService to ingest types directly.
// TypeScript treats `typeof LitType` as a constructor function.
// Cast to any to access the name property.
// tslint:disable-next-line:no-any
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/client/modules/annotated_text_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ export class AnnotatedTextModule extends LitModule {

const promise = this.apiService.getPreds(
[input], this.model, this.appState.currentDataset,
['MultiSegmentAnnotations'], 'Retrieving annotations');
[MultiSegmentAnnotations], 'Retrieving annotations');
const results = await this.loadLatest('answers', promise);
if (results === null) return;

Expand Down
6 changes: 3 additions & 3 deletions lit_nlp/client/modules/attention_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import {observable} from 'mobx';

import {app} from '../core/app';
import {LitModule} from '../core/lit_module';
import {AttentionHeads as AttentionHeadsLitType} from '../lib/lit_types';
import {AttentionHeads as AttentionHeadsLitType, Tokens as TokensLitType} from '../lib/lit_types';
import {IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types';
import {doesOutputSpecContain, findSpecKeys, getTextWidth, getTokOffsets, sumArray} from '../lib/utils';
import {FocusService} from '../services/services';
Expand Down Expand Up @@ -101,8 +101,8 @@ export class AttentionModule extends LitModule {
if (selectedInput === null) return;
const dataset = this.appState.currentDataset;
const promise = this.apiService.getPreds(
[selectedInput], this.model, dataset, ['Tokens', 'AttentionHeads'],
'Fetching attention');
[selectedInput], this.model, dataset,
[TokensLitType, AttentionHeadsLitType], 'Fetching attention');
const res = await this.loadLatest('attentionAndTokens', promise);
if (res === null) return;
this.preds = res[0];
Expand Down
3 changes: 2 additions & 1 deletion lit_nlp/client/modules/feature_attribution_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ export class FeatureAttributionModule extends LitModule {
*/
private async predict(facet: string, data: IndexedInput[]) {
const promise = this.apiService.getPreds(
data, this.model, this.appState.currentDataset, ['FeatureSalience']);
data, this.model, this.appState.currentDataset,
[FeatureSalienceLitType]);
const results = await this.loadLatest('predictionScores', promise);

if (results == null) return;
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/client/modules/generated_image_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {LitModule} from '../core/lit_module';
import {GeneratedURL, ImageBytes} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {IndexedInput, ModelInfoMap, Spec} from '../lib/types';
import {getTypeNames, doesOutputSpecContain, isLitSubtype} from '../lib/utils';
import {doesOutputSpecContain, isLitSubtype} from '../lib/utils';

/**
* A LIT module that renders generated text.
Expand Down Expand Up @@ -76,7 +76,7 @@ export class GeneratedImageModule extends LitModule {
const dataset = this.appState.currentDataset;
const promise = this.apiService.getPreds(
[input], this.model, dataset,
getTypeNames([...GeneratedImageModule.supportedTypes, GeneratedURL]),
[...GeneratedImageModule.supportedTypes, GeneratedURL],
'Generating images');
const results = await this.loadLatest('generatedImages', promise);
if (results === null) return;
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/client/modules/generated_text_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import {DiffMode, GeneratedTextResult, GENERATION_TYPES} from '../lib/generated_
import {GeneratedText, GeneratedTextCandidates, LitTypeWithParent, ReferenceScores, ReferenceTexts} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {IndexedInput, Input, ModelInfoMap, Spec} from '../lib/types';
import {getTypeNames, doesOutputSpecContain, findSpecKeys} from '../lib/utils';
import {doesOutputSpecContain, findSpecKeys} from '../lib/utils';

import {styles} from './generated_text_module.css';

Expand Down Expand Up @@ -114,7 +114,7 @@ export class GeneratedTextModule extends LitModule {

const dataset = this.appState.currentDataset;
const promise = this.apiService.getPreds(
[input], this.model, dataset, getTypeNames([...GENERATION_TYPES, ReferenceScores]),
[input], this.model, dataset, [...GENERATION_TYPES, ReferenceScores],
'Generating text');
const results = await this.loadLatest('generatedText', promise);
if (results === null) return;
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/client/modules/lm_prediction_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ export class LanguageModelPredictionModule extends LitModule {

const dataset = this.appState.currentDataset;
const promise = this.apiService.getPreds(
[input], this.model, dataset, ['Tokens', 'TokenTopKPreds'],
[input], this.model, dataset, [Tokens, TokenTopKPreds],
'Loading tokens');
const results = await this.loadLatest('modelPreds', promise);
if (results === null) return;
Expand Down Expand Up @@ -201,7 +201,7 @@ export class LanguageModelPredictionModule extends LitModule {

const dataset = this.appState.currentDataset;
const promise = this.apiService.getPreds(
[this.maskedInput], this.model, dataset, ['TokenTopKPreds']);
[this.maskedInput], this.model, dataset, [TokenTopKPreds]);
const results = await this.loadLatest('mlmResults', promise);
if (results === null) return;

Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/client/modules/multilabel_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import {LitModule} from '../core/lit_module';
import {SortableTemplateResult, TableData, TableEntry} from '../elements/table';
import {SparseMultilabelPreds} from '../lib/lit_types';
import {formatBoolean, IndexedInput, ModelInfoMap, NumericResults, Spec} from '../lib/types';
import {getTypeNames, doesOutputSpecContain, findSpecKeys} from '../lib/utils';
import {doesOutputSpecContain, findSpecKeys} from '../lib/utils';
import {SelectionService} from '../services/services';

import {styles} from './multilabel_module.css';
Expand Down Expand Up @@ -104,7 +104,7 @@ export class MultilabelModule extends LitModule {
const results = await Promise.all(models.map(
async model => this.apiService.getPreds(
datapoints, model, this.appState.currentDataset,
getTypeNames([SparseMultilabelPreds]))));
[SparseMultilabelPreds])));
if (results === null) {
this.resultsInfo = {};
return;
Expand Down
5 changes: 2 additions & 3 deletions lit_nlp/client/modules/sequence_salience_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {canonicalizeGenerationResults, GeneratedTextResult, GENERATION_TYPES, ge
import {Salience} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {IndexedInput, ModelInfoMap, Spec} from '../lib/types';
import {getTypeNames, sumArray} from '../lib/utils';
import {sumArray} from '../lib/utils';
import {SignedSalienceCmap, UnsignedSalienceCmap} from '../services/color_service';

import {styles} from './sequence_salience_module.css';
Expand Down Expand Up @@ -130,8 +130,7 @@ export class SequenceSalienceModule extends LitModule {
this.currentPreds = undefined;

const promise = this.apiService.getPreds(
[input], this.model, this.appState.currentDataset,
getTypeNames(GENERATION_TYPES),
[input], this.model, this.appState.currentDataset, GENERATION_TYPES,
'Getting targets from model prediction');
const results = await this.loadLatest('generationResults', promise);
if (results === null) return;
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/client/modules/span_graph_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import {LitModule} from '../core/lit_module';
import {AnnotationLayer, SpanGraph} from '../elements/span_graph_vis_vertical';
import {EdgeLabels, SequenceTags, SpanLabels, LitTypeTypesList, LitTypeWithAlign, TextSegment, Tokens} from '../lib/lit_types';
import {EdgeLabel, IndexedInput, Input, ModelInfoMap, Preds, SpanLabel, Spec} from '../lib/types';
import {getTypeNames, findSpecKeys} from '../lib/utils';
import {findSpecKeys} from '../lib/utils';

import {styles as sharedStyles} from '../lib/shared_styles.css';

Expand Down Expand Up @@ -256,7 +256,7 @@ export class SpanGraphModule extends LitModule {
} else {
const promise = this.apiService.getPreds(
[input], this.model, this.appState.currentDataset,
getTypeNames([Tokens, ...supportedPredTypes]));
[Tokens, ...supportedPredTypes]);

const results = await this.loadLatest('getPreds', promise);
if (!results) return;
Expand Down
5 changes: 2 additions & 3 deletions lit_nlp/client/modules/tda_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import {canonicalizeGenerationResults, GeneratedTextResult, GENERATION_TYPES, ge
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {FieldMatcher, LitTypeWithParent, InfluentialExamples} from '../lib/lit_types';
import {CallConfig, ComponentInfoMap, IndexedInput, Input, ModelInfoMap, Spec} from '../lib/types';
import {cloneSpec, getTypeNames, filterToKeys, findSpecKeys} from '../lib/utils';
import {cloneSpec, filterToKeys, findSpecKeys} from '../lib/utils';
import {AppState, SelectionService} from '../services/services';

import {styles} from './tda_module.css';
Expand Down Expand Up @@ -247,8 +247,7 @@ export class TrainingDataAttributionModule extends LitModule {
this.currentPreds = undefined;

const promise = this.apiService.getPreds(
[input], this.model, this.appState.currentDataset,
getTypeNames(GENERATION_TYPES),
[input], this.model, this.appState.currentDataset, GENERATION_TYPES,
'Getting targets from model prediction');
const results = await this.loadLatest('generationResults', promise);
if (results === null) return;
Expand Down
8 changes: 5 additions & 3 deletions lit_nlp/client/services/api_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
* limitations under the License.
*/

import {LitTypeTypesList} from '../lib/lit_types';
import {CallConfig, IndexedInput, LitMetadata, Preds} from '../lib/types';
import {deserializeLitTypesInLitMetadata} from '../lib/utils';
import {deserializeLitTypesInLitMetadata, getTypeNames} from '../lib/utils';

import {LitService} from './lit_service';
import {StatusService} from './status_service';
Expand Down Expand Up @@ -99,13 +100,14 @@ export class ApiService extends LitService {
*/
getPreds(
inputs: IndexedInput[], model: string, datasetName: string,
requestedTypes: string[], loadMessage?: string): Promise<Preds[]> {
requestedTypes: LitTypeTypesList,
loadMessage?: string): Promise<Preds[]> {
loadMessage = loadMessage || 'Fetching predictions';
return this.queryServer(
'/get_preds', {
'model': model,
'dataset_name': datasetName,
'requested_types': requestedTypes.join(','),
'requested_types': getTypeNames(requestedTypes).join(','),
},
inputs, loadMessage);
}
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/client/services/data_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ export class DataService extends LitService {
}

const predsPromise = this.apiService.getPreds(
data, model, this.appState.currentDataset, ['Scalar']);
data, model, this.appState.currentDataset, [Scalar]);
const preds = await predsPromise;

// Add scalar results as new column to the data service.
Expand Down

0 comments on commit 74b5dbb

Please sign in to comment.