Skip to content

Commit

Permalink
Added PR/ROC curves to UI
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 441737849
  • Loading branch information
jameswex authored and LIT team committed Apr 14, 2022
1 parent 78e2e9c commit 0f9fd4d
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 7 deletions.
2 changes: 2 additions & 0 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.components import ablation_flip
from lit_nlp.components import curves
from lit_nlp.components import gradient_maps
from lit_nlp.components import hotflip
from lit_nlp.components import lemon_explainer
Expand Down Expand Up @@ -488,6 +489,7 @@ def __init__(
'Model-provided salience': model_salience.ModelSalience(self._models),
'counterfactual explainer': lemon_explainer.LEMON(),
'tcav': tcav.TCAV(),
'curves': curves.CurvesInterpreter(),
'thresholder': thresholder.Thresholder(),
'nearest neighbors': nearest_neighbors.NearestNeighbors(),
'metrics': metrics_group,
Expand Down
2 changes: 2 additions & 0 deletions lit_nlp/client/default/layout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {ClassificationModule} from '../modules/classification_module';
import {ColorModule} from '../modules/color_module';
import {ConfusionMatrixModule} from '../modules/confusion_matrix_module';
import {CounterfactualExplainerModule} from '../modules/counterfactual_explainer_module';
import {CurvesModule} from '../modules/curves_module';
import {DataTableModule, SimpleDataTableModule} from '../modules/data_table_module';
import {DatapointEditorModule, SimpleDatapointEditorModule} from '../modules/datapoint_editor_module';
import {DocumentationModule} from '../modules/documentation_module';
Expand Down Expand Up @@ -118,6 +119,7 @@ export const LAYOUTS: LitComponentLayouts = {
'Metrics': [
MetricsModule,
ConfusionMatrixModule,
CurvesModule,
ThresholderModule,
],
'Influence': [TrainingDataAttributionModule],
Expand Down
12 changes: 8 additions & 4 deletions lit_nlp/client/elements/line_chart.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export class LineChart extends ReactiveElement {
@property({type: Number}) margin = 30; // Default margin size.
@property({type: Number}) width = 0;
@property({type: Number}) height = 0;
@property({type: Array}) xScale: number[] = [];
@property({type: Array}) yScale: number[] = [];

static override get styles() {
Expand All @@ -54,9 +55,12 @@ export class LineChart extends ReactiveElement {
}

private getXScale() {
const labels = Array.from(this.scores.keys());
return d3.scaleLinear().domain([d3.min(labels)!, d3.max(labels)!]).range(
[0, this.width]);
let scale: number[] = this.xScale;
if (scale == null || scale.length < 2) {
const labels = Array.from(this.scores.keys());
scale = [d3.min(labels)!, d3.max(labels)!];
}
return d3.scaleLinear().domain(scale).range([0, this.width]);
}

private getYScale() {
Expand Down Expand Up @@ -98,7 +102,7 @@ export class LineChart extends ReactiveElement {
chart.selectAll('*').remove();

// Make axes.
const xAxis = d3.axisBottom(x);
const xAxis = d3.axisBottom(x).ticks(5);
const yAxis = d3.axisLeft(y).ticks(5);
chart.append('g')
.attr('transform', `translate(0, ${this.height})`)
Expand Down
5 changes: 5 additions & 0 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ export function isBinaryClassification(litType: LitType) {
return predictionLabels.length === 2 && nullIdx != null;
}

/** Returns if a LitType has a parent field. */
export function hasParent(litType: LitType) {
return litType.parent != null;
}

/**
* Helper function to make an object into a human readable key.
* Sorts object keys, so order of object does not matter.
Expand Down
13 changes: 13 additions & 0 deletions lit_nlp/client/modules/curves_module.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.charts-holder {
display: flex;
flex-flow: wrap;
padding-top: 4px;
}

.module-toolbar {
border-bottom: 1px solid var(--lit-neutral-300);
}

.chart-holder {
padding-bottom: 12px;
}

0 comments on commit 0f9fd4d

Please sign in to comment.