How can we visualize saliency across an entire paragraph?


This notebook riffs on [ecco](https://github.com/jalammar/ecco) visualizations, using gpt2 and gradient x input saliency.

(I'd be interested in trying a T5 Q&A or summarization model with the [LIT API](https://github.com/PAIR-code/lit/blob/main/documentation/python_api.md), along with other saliency methods, but haven't figured out how to hook everything up yet.)
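
For anyone new to gradient x input, here is a rough sketch of the idea on a toy model. This is not ecco's implementation: `toy_score`, `grad_x_input`, the weights and the embeddings are all made up for illustration, and the per-token L2 norm plus sum-to-one normalization is an assumption about how the per-dimension values get collapsed to one scalar per token.

```python
import math

def toy_score(embeddings, weights):
    """Toy differentiable 'model': tanh of a weighted sum over all embedding values."""
    total = sum(w * x for emb in embeddings for w, x in zip(weights, emb))
    return math.tanh(total)

def grad_x_input(embeddings, weights):
    """Per-token saliency: L2 norm of (gradient * embedding), normalized to sum to 1."""
    score = toy_score(embeddings, weights)
    # d tanh(z)/dx_j = (1 - tanh(z)^2) * w_j, identical for every token in this toy model.
    dscore = 1.0 - score ** 2
    raw = []
    for emb in embeddings:
        products = [dscore * w * x for w, x in zip(weights, emb)]
        raw.append(math.sqrt(sum(p * p for p in products)))
    total = sum(raw)
    return [r / total for r in raw] if total else raw

# Three hypothetical tokens with 2-dimensional embeddings.
embeddings = [[0.1, 0.0], [0.5, -0.2], [0.0, 0.05]]
weights = [1.0, 2.0]
saliency = grad_x_input(embeddings, weights)
```

The middle token has the largest embedding values along the weighted directions, so it ends up with the largest share of the saliency.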

# Setup

In [1]:
%%capture

!pip install ecco
!pip install sentencepiece

import ecco
lm = ecco.from_pretrained('gpt2', verbose=False)

In [2]:
def outputToSaliency(output, attr_method='grad_x_input'):
  # importance_id selects which output token's attribution fills 'value';
  # as written it is always 0, i.e. the first generated token.
  position = output.n_input_tokens
  importance_id = position - output.n_input_tokens

  tokens = []
  attribution = output.attribution[attr_method]
  for idx, token in enumerate(output.tokens[0]):
    token_type = 'input' if idx < output.n_input_tokens else 'output'
    if idx < len(attribution[importance_id]):
      imp = attribution[importance_id][idx]
    else:
      imp = 0  # tokens beyond this attribution row get no value

    tokens.append({
      'token': token,
      'token_id': int(output.token_ids[0][idx]),
      'type': token_type,
      'value': str(imp),  # because json complains of floats
      'position': idx
    })

  return {
    'tokens': tokens,
    # One attribution row per generated token.
    'attributions': [att.tolist() for att in attribution]
  }

In [3]:
import IPython
import google.colab

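def jsViz(data, settings=None):
  # Avoid a shared mutable default argument.
  settings = settings or {}
  url = 'https://roadtolarissa.com/colab/fpld4dp8lo/2021-07-saliency-viz/paragraph-minimap/watch-files.js?6'

  # Swap in the requested chart type, e.g. 'paragraph-tall-heatmap'.
  if 'type' in settings:
    url = url.replace('paragraph-minimap', settings['type'])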

  HTML_TEMPLATE = '''
    <link rel='stylesheet' href='__hs_placeholder'>
    <script src='https://pair.withgoogle.com/explorables/third_party/d3_.js'></script>
    <script src='https://pair.withgoogle.com/explorables/third_party/d3-scale-chromatic.v1.min.js'></script>
    <a style='display:none' class='no-js' href='{url}'>Click to authenticate</a>
    <div class='container'></div>

    <script>window.python_data = {data}</script>
    <script>window.python_settings = {settings}</script>
    <script>window.timeoutMS = 250</script>
    <script src='{url}'></script>
  '''

  IPython.display.display(IPython.display.HTML(HTML_TEMPLATE.format(
      data=data, 
      settings=settings, 
      url=url)
  ))
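
`jsViz` drops the Python `repr` of `data` and `settings` straight into a `<script>` tag. That works for the dicts used here (strings, ints and lists), but `True`, `False` or `None` would not be valid JavaScript. A possible alternative, sketched below, is to serialize with `json.dumps` first. `to_js_literal` is a hypothetical helper, not part of this notebook or of ecco.

```python
import json

def to_js_literal(value):
    """Serialize a Python value to text that can be placed inside a <script> tag."""
    text = json.dumps(value)
    # Keep a token such as "</script>" from closing the surrounding tag early.
    # "\/" is a valid JSON/JavaScript escape for "/".
    return text.replace('</', '<\\/')

settings = {'type': 'paragraph-minimap', 'isDev': False, 'note': None}
print(to_js_literal(settings))
# prints {"type": "paragraph-minimap", "isDev": false, "note": null}
```

With this approach the template would receive `data=to_js_literal(data)` and `settings=to_js_literal(settings)` instead of the raw objects.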

In [4]:
from IPython.display import Javascript
def resize_colab_cell():
  display(Javascript('google.colab.output.setIframeHeight(0, true, {maxHeight: 5000})'))
get_ipython().events.register('pre_run_cell', resize_colab_cell)


# Charts

With longer input passages, text saliency visualizations typically only let us look at a single output token's attributions at once:

In [5]:
text="""The race to Mars involves competition between manufacturers and nations. NASA has demurred in a potential rivalry with SpaceX or other manufacturers in any possible race to be first to Mars. It instead sees synergies in possible cooperation with such entities.  Boeing has stated that one of its rockets will lead to the first crewed expedition to Mars, before SpaceX or others will land a crewed mission. Boeing is the primary contractor on the U.S. Space Launch System (SLS) NASA rocket program that has the ultimate goal of a crewed Mars mission. SpaceX has declined to state that it is a race, or that it needs to race Boeing. Blue Origin has stated that with its New Armstrong and New Glenn rockets, it may be attempting missions to Mars, head-to-head with SpaceX's Interplanetary Transport System. This may result in commercial competition going to Mars. The first organization to land on Mars will be"""
output = lm.generate(text, generate=30, do_sample=True)
output.saliency()

By shrinking the saliency highlights down to small thumbnails, we can see the attributions for all the output tokens at the same time:

In [6]:
jsViz(outputToSaliency(output), {'type': 'paragraph-minimap', 'isDev': 0})

Comparing thumbnails across output tokens is still tricky. Instead of preserving input token position, we can squeeze each input token down to a 2px wide bar, so the same input token sits closer together across output tokens.

This visualization also allows for direct comparisons of attributions between two output tokens with a diverging color scale. Try clicking on an output token and hovering over another.

In [7]:
jsViz(outputToSaliency(output), {'type': 'paragraph-tall-heatmap', 'isDev': 0})
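
The comparison between two output tokens can be thought of as a signed difference per input token, which is what a diverging color scale is suited to show. The sketch below is a guess at that computation, not the code behind the chart. `attribution_diff` is a hypothetical helper, and it assumes attribution rows can differ in length, so only the shared prefix is compared.

```python
def attribution_diff(attr_a, attr_b):
    """Signed per-token difference between two output tokens' attributions."""
    # Rows may differ in length, so compare only the positions both rows cover.
    n = min(len(attr_a), len(attr_b))
    return [attr_a[i] - attr_b[i] for i in range(n)]

# Hypothetical attribution rows for two output tokens.
clicked = [0.5, 0.25, 0.25]
hovered = [0.25, 0.25, 0.5, 0.0]
diff = attribution_diff(clicked, hovered)
# Positive values: the clicked token leans on that input more; negative: the hovered one does.
```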

Rearranging the charts and making the bars a bit thinner, we can also compare attributions across continuations.

If the model answers a question in a different way, is it looking at different input text? This vis could also be used to compare saliency methods.

In [8]:
outputs = [
  output, 
  lm.generate(text, generate=30, do_sample=True),
  lm.generate(text, generate=30, do_sample=True),
]

In [9]:
jsViz(list(map(outputToSaliency, outputs)), {'type': 'paragraph-tall-multiple', 'isDev': 0})
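
One rough way to put a number on "is it looking at different input text" is to average each input token's attribution over a continuation and compare the resulting vectors. This is only a sketch under assumptions: `mean_input_attribution` and `cosine_similarity` are hypothetical helpers, each attribution row is assumed to cover at least the input tokens, and cosine similarity is just one possible choice of measure.

```python
import math

def mean_input_attribution(attributions, n_input_tokens):
    """Average each input token's attribution over all output tokens."""
    totals = [0.0] * n_input_tokens
    for row in attributions:
        for i in range(n_input_tokens):
            totals[i] += row[i]
    return [t / len(attributions) for t in totals]

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

# Hypothetical attributions for two continuations of a 2-token input.
run_a = [[1.0, 0.0], [1.0, 0.0, 0.0]]
run_b = [[0.0, 1.0]]
similarity = cosine_similarity(
    mean_input_attribution(run_a, 2),
    mean_input_attribution(run_b, 2),
)
```

With real data, the rows would come from the `'attributions'` list that `outputToSaliency` returns for each continuation.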

More than ~30 output tokens take up too much vertical space. 

Shrinking the heights of the bars this time to squeeze in more output tokens, we're left with a heatmap:

In [10]:
longOutputs = [
  lm.generate(text, generate=100, do_sample=True),
  lm.generate(text, generate=100, do_sample=True),
  lm.generate(text, generate=100, do_sample=True),
]

In [11]:
jsViz(list(map(outputToSaliency, longOutputs)), {'isDev': 0, 'type': 'paragraph-long-output'})

Other ideas:

- Input instead of output focused
  - With a fixed token width and a short continuation, you could show output attributions above every input token
  - Clicking could show how attributions are different between two input tokens 
- Non-linear positioning of tokens
  - UMAP/PCA 
  - Color could encode sentence position; x/y could show attribution. 
- Only show "interesting" interactions
  - Remove stop words / low values
  - Shift-click to expand a row of the heat map
- Can we pull anything else out of attribution methods besides scalars?
- Would normalizing relative position be useful? 
- Interactions between tokens / perturbation
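
As a starting point for the "remove stop words / low values" idea, a filter along these lines could run before the data is handed to the chart. Everything here is hypothetical: `interesting_tokens`, the tiny `STOP_WORDS` set and the `min_value` threshold are placeholders, not tuned choices.

```python
STOP_WORDS = {'the', 'a', 'an', 'of', 'to', 'in', 'and', 'or', 'is'}  # illustrative only

def interesting_tokens(tokens, attribution_row, min_value=0.05):
    """Keep (position, token, value) for tokens that are not stop words or low-valued."""
    kept = []
    for position, (token, value) in enumerate(zip(tokens, attribution_row)):
        # Tokens may carry a leading space, so normalize before the lookup.
        word = token.strip().lower()
        if word in STOP_WORDS or value < min_value:
            continue
        kept.append((position, token, value))
    return kept

tokens = ['The', ' race', ' to', ' Mars']
row = [0.4, 0.3, 0.2, 0.01]
kept = interesting_tokens(tokens, row)
```

Positions are kept alongside the tokens so the surviving entries can still be placed correctly in a chart.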