Skip to content

Commit

Permalink
Adds sparse mode to Classification Results module.
Browse files Browse the repository at this point in the history
If enabled using the toggle in the footer, the Classification Results module will only show classes that have scored above 0.01.

This toggle is off by default and only available for fields that have more than 10 labels in their vocab.

PiperOrigin-RevId: 482293876
  • Loading branch information
RyanMullins authored and LIT team committed Oct 19, 2022
1 parent 68cb051 commit 20a8f31
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
66 changes: 46 additions & 20 deletions lit_nlp/client/modules/classification_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import {DataService, SelectionService} from '../services/services';

import {styles} from './classification_module.css';

const SPARSE_MODE_THRESHOLD = 0.01;

interface DisplayInfo {
value: number;
isTruth: boolean;
Expand Down Expand Up @@ -76,14 +78,15 @@ export class ClassificationModule extends LitModule {
private readonly pinnedSelectionService =
app.getService(SelectionService, 'pinned');

@observable private sparseMode = false;
@observable private labeledPredictions: LabeledPredictions = {};

override firstUpdated() {
const getSelectionChanges = () =>
[this.appState.compareExamplesEnabled, this.appState.currentModels,
this.pinnedSelectionService.primarySelectedInputData,
this.selectionService.primarySelectedInputData,
this.dataService.dataVals];
const getSelectionChanges = () => [
this.appState.compareExamplesEnabled, this.appState.currentModels,
this.pinnedSelectionService.primarySelectedInputData, this.sparseMode,
this.selectionService.primarySelectedInputData, this.dataService.dataVals
];
this.reactImmediately(getSelectionChanges, () => {this.updateSelection();});
}

Expand Down Expand Up @@ -150,33 +153,51 @@ export class ClassificationModule extends LitModule {
const label = labels[i];

// Map the predctions for each example into DisplayInfo objects
const rowPreds = scores.map((score, j): DisplayInfo => {
const rowPreds = [];

for (let j = 0; j < scores.length; j++) {
const score = scores[j];

// Only push null scores if not in sparseMode
if (score == null) {
return {value: 0, isPredicted: false, isTruth: false};
if (!this.sparseMode) {
rowPreds.push({value: 0, isPredicted: false, isTruth: false});
}
continue;
}

const value = score[i];
const isPredicted = label === predictedClasses[j];
const {data} = inputs[j];
const isTruth = (parent != null && data[parent] === labels[i]);
return {value, isPredicted, isTruth};
});
// Push values if not in sparseMode or if above threshold
if (!this.sparseMode || value >= SPARSE_MODE_THRESHOLD) {
rowPreds.push({value, isPredicted, isTruth});
}
}

labeledPredictions[topLevelKey][label] = rowPreds;
if (rowPreds.length) labeledPredictions[topLevelKey][label] = rowPreds;
}
}

return labeledPredictions;
}

override renderImpl() {
const hasGroundTruth = this.appState.currentModels.some(
model =>
const clsFieldSpecs =
this.appState.currentModels.flatMap((model) =>
Object.values(this.appState.currentModelSpecs[model].spec.output)
.some(
feature => feature instanceof MulticlassPreds &&
feature.parent != null &&
feature.parent in
this.appState.currentModelSpecs[model].spec.input));
.filter(
(fieldSpec) => fieldSpec instanceof MulticlassPreds
) as MulticlassPreds[]);

const hasGroundTruth = clsFieldSpecs.some((fs) =>
fs.parent != null && fs.parent in this.appState.currentDatasetSpec);

const allowSparseMode = clsFieldSpecs.some((fs) => fs.vocab.length > 10);

const onClickSwitch = () => {this.sparseMode = !this.sparseMode;};

return html`<div class='module-container'>
<div class="module-results-area">
${
Expand All @@ -185,14 +206,19 @@ export class ClassificationModule extends LitModule {
const featureTable =
this.renderFeatureTable(labelRow, hasGroundTruth);
return arr.length === 1 ? featureTable : html`
<expansion-panel .label=${fieldName} expanded>
${featureTable}
</expansion-panel>`;
<expansion-panel .label=${fieldName} expanded>
${featureTable}
</expansion-panel>`;
})}
</div>
<div class="module-footer">
<annotated-score-bar-legend ?hasTruth=${hasGroundTruth}>
</annotated-score-bar-legend>
${allowSparseMode ? html`
<div class='switch-container' @click=${onClickSwitch}>
<div>Only show classes above ${SPARSE_MODE_THRESHOLD}</div>
<mwc-switch .checked=${this.sparseMode}></mwc-switch>
</div>` : null}
</div>
</div>`;
}
Expand Down
2 changes: 2 additions & 0 deletions lit_nlp/examples/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import layout
from lit_nlp.components import classification_results
from lit_nlp.components import image_gradient_maps

from lit_nlp.examples.datasets import imagenette
Expand Down Expand Up @@ -55,6 +56,7 @@ def main(_):
datasets = {'imagenette': imagenette.ImagenetteDataset()}
models = {'mobilenet': mobilenet.MobileNet()}
interpreters = {
'classification': classification_results.ClassificationInterpreter(),
'Grad': image_gradient_maps.VanillaGradients(),
'Integrated Gradients': image_gradient_maps.IntegratedGradients(),
'Blur IG': image_gradient_maps.BlurIG(),
Expand Down

0 comments on commit 20a8f31

Please sign in to comment.