Skip to content

Commit

Permalink
Updates and UI fixes for LM salience
Browse files Browse the repository at this point in the history
- Remove redundant "Target:" label on dropdown
- Help icon next to target selector dropdown
- Fix tooltip text on colormap slider
- Remove "Show self scores" toggle
- Remove "token_loss" for now
- Add a progress indicator for salience requests

PiperOrigin-RevId: 607565021
  • Loading branch information
iftenney authored and LIT team committed Feb 16, 2024
1 parent 406fbc7 commit 77583e7
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 33 deletions.
35 changes: 26 additions & 9 deletions lit_nlp/client/modules/lm_salience_module.css
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
padding: 8px;
}

.chip-container-dense {
padding: 8px;
}

.pre-wrap {
white-space: pre-wrap;
}
Expand All @@ -23,6 +19,7 @@
white-space: nowrap;
text-overflow: ellipsis;
overflow-x: hidden;
line-height: 22px;
}

lit-switch .icon-button {
Expand Down Expand Up @@ -63,12 +60,10 @@ lit-switch .icon-button {
margin-right: 8px;
}

.controls-group-variable > label {
min-width: 45px;
}

.controls-group-variable .dropdown {
max-width: calc(100% - 45px);
max-width: calc(100% - 22px);
margin-right: 4px;
text-overflow: ellipsis;
}

.vertical-separator {
Expand All @@ -95,4 +90,26 @@ color-legend {
/* extra space to keep other controls from jumping when legend changes */
/* width: 400px; */
margin-right: 16px;
}


/* Pending request indicator */
.loading-indicator-container {
position: relative;
width: 100%;
top: -2px;
}

@keyframes running-progress {
0% { margin-left: 0; margin-right: 100%; }
50% { margin-left: 35%; margin-right: 0%; }
100% { margin-left: 100%; margin-right: 0%; }
}

.loading-indicator {
position: absolute;
background-color: var(--lit-neutral-500);
width: 100%;
height: 2px;
animation: running-progress 2s cubic-bezier(0.4, 0, 0.2, 1) infinite;
}
61 changes: 44 additions & 17 deletions lit_nlp/client/modules/lm_salience_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import '../elements/fused_button_bar';
import {css, html} from 'lit';
// tslint:disable:no-new-decorators
import {customElement} from 'lit/decorators.js';
import {classMap} from 'lit/directives/class-map.js';
import {computed, observable} from 'mobx';

import {LitModule} from '../core/lit_module';
Expand Down Expand Up @@ -556,17 +557,21 @@ export class LMSalienceModule extends SingleExampleSingleModelModule {
`;
}

/* Disabled for space reasons. */
// renderSelfScoreSelector() {
// const onClickToggleSelfSalience = () => {
// this.showSelfSalience = !this.showSelfSalience;
// };
// // prettier-ignore
// return html`
// <lit-switch labelLeft="Show self scores"
// ?selected=${this.showSelfSalience}
// @change=${onClickToggleSelfSalience}>
// </lit-switch>
// `;
// }
renderSelfScoreSelector() {
const onClickToggleSelfSalience = () => {
this.showSelfSalience = !this.showSelfSalience;
};
// prettier-ignore
return html`
<lit-switch labelLeft="Show self scores"
?selected=${this.showSelfSalience}
@change=${onClickToggleSelfSalience}>
</lit-switch>
`;
return null;
}

renderMethodSelector() {
Expand Down Expand Up @@ -632,14 +637,29 @@ export class LMSalienceModule extends SingleExampleSingleModelModule {
</option>`;
});

const targetSelectorHelp =
'Select a (response) from the model or a pre-defined (target) sequence from the dataset.';

// prettier-ignore
return html`
<div class="controls-group controls-group-variable"
title="Target string for salience.">
<label class="dropdown-label">Target:</label>
<select class="dropdown" @change=${onChangeTarget}>
${options}
</select>
<lit-tooltip content=${targetSelectorHelp} tooltipPosition="left">
<span class="help-icon material-icon-outlined icon-button">
help_outline
</span>
</lit-tooltip>
</div>`;
}

renderLoadingIndicator() {
// prettier-ignore
return html`
<div class='loading-indicator-container'>
<div class='loading-indicator'></div>
</div>`;
}

Expand All @@ -658,12 +678,22 @@ export class LMSalienceModule extends SingleExampleSingleModelModule {
return `Explaining ${this.printTargetForHuman(start, end)}`;
};

const requestPending = this.targetTokenSpan !== undefined &&
this.salienceResultCache[this.spanToKey(this.targetTokenSpan)] ===
REQUEST_PENDING;
// const requestPending = true;
const infoLineClasses = classMap({
'target-info-line': true,
'gray-text': requestPending,
});

// prettier-ignore
return html`
<div class="controls-group controls-group-variable"
title="Selected target span.">
<div class="target-info-line">
<div class=${infoLineClasses}>
${printSelectedTargets()}
${requestPending ? this.renderLoadingIndicator() : null}
</div>
</div>
`;
Expand Down Expand Up @@ -741,12 +771,9 @@ export class LMSalienceModule extends SingleExampleSingleModelModule {
});
}

// TODO: revert to 4px for non-dense view if we can figure out the
// display mode for token chips? Needs more padding for block mode,
// but also indentation and newlines are wonky.
// prettier-ignore
return html`
<div class=${this.denseView ? 'chip-container-dense' : 'chip-container'}>
<div class='chip-container'>
<lm-salience-chips .tokensWithWeights=${segmentsWithWeights}
?dense=${this.denseView} ?preSpace=${this.denseView}
.cmap=${this.cmap} breakNewlines displayBlock>
Expand Down Expand Up @@ -793,7 +820,7 @@ export class LMSalienceModule extends SingleExampleSingleModelModule {
<lit-numeric-input min="0" max="6" step="0.25" id='gamma-slider'
value="${this.cmapGamma}" @change=${onChangeGamma}>
</lit-numeric-input>
<mwc-icon class='icon-button value-reset-icon' title='Reset gamma'
<mwc-icon class='icon-button value-reset-icon' title='Reset colormap'
@click=${resetGamma}>
restart_alt
</mwc-icon>
Expand Down
8 changes: 4 additions & 4 deletions lit_nlp/examples/models/instrumented_keras_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def layer_intercept_fn(x, i):
FieldNames.GRAD_NORM: grad_l2,
FieldNames.GRAD_DOT_INPUT: grad_dot_input,
# Shift token loss to align with (input) tokens.
FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1),
# FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1),
}

return batched_outputs
Expand All @@ -322,7 +322,7 @@ def _postprocess(self, preds):
):
preds[key] = preds[key][mask]
# First token (<bos>) is not actually predicted, so return 0 for loss.
preds[FieldNames.TOKEN_LOSS][0] = 0
# preds[FieldNames.TOKEN_LOSS][0] = 0

return preds

Expand Down Expand Up @@ -353,11 +353,11 @@ def input_spec(self):
def output_spec(self) -> lit_types.Spec:
return {
FieldNames.TOKENS: lit_types.Tokens(parent=""), # All tokens.
FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(
align=FieldNames.TOKENS
),
FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS),
# FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS),
}


Expand Down
6 changes: 3 additions & 3 deletions lit_nlp/examples/models/pretrained_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _pred(self, encoded_inputs, target_masks):
"grad_l2": grad_l2,
"grad_dot_input": grad_dot_input,
# Shift token loss to align with (input) tokens.
"token_loss": tf.roll(per_token_loss, shift=1, axis=1),
# "token_loss": tf.roll(per_token_loss, shift=1, axis=1),
}

return batched_outputs
Expand All @@ -609,7 +609,7 @@ def _postprocess(self, preds):
for key in utils.find_spec_keys(self.output_spec(), lit_types.TokenScores):
preds[key] = preds[key][mask]
# First token (usually <s>) is not actually predicted, so return 0 for loss.
preds["token_loss"][0] = 0
# preds["token_loss"][0] = 0

return preds

Expand Down Expand Up @@ -645,7 +645,7 @@ def output_spec(self) -> lit_types.Spec:
"tokens": lit_types.Tokens(parent=""), # all tokens
"grad_l2": lit_types.TokenScores(align="tokens"),
"grad_dot_input": lit_types.TokenScores(align="tokens"),
"token_loss": lit_types.TokenScores(align="tokens"),
# "token_loss": lit_types.TokenScores(align="tokens"),
}


Expand Down

0 comments on commit 77583e7

Please sign in to comment.