In [None]:
%%capture
%%javascript
// Point RequireJS at CDN builds of d3 and jQuery under the module names
// "d3" and "jquery"; %%capture above keeps the cell from printing output.
// NOTE(review): presumably needed by the bertviz views rendered later in this
// notebook (classic Notebook front end) — confirm before removing.
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

In [None]:
from pathlib import Path

import pandas as pd
import torch
import bertviz

# Input CSV; view_attention also reads the per-prompt attention dumps
# ("<name>_attn<id>.pt") from this file's directory.
DATA_PATH = Path("..") / "test-output" / "test.csv"

In [None]:
# Load the prompt table, indexed by its "id" column, and display it.
# NOTE(review): presumably these ids are the prompt_id values whose
# "_attn{id}.pt" dumps view_attention loads below — confirm.
texts = pd.read_csv(filepath_or_buffer=DATA_PATH, index_col="id")
texts

In [None]:
def _attention_file(prompt_id: int) -> Path:
    """Return the attention-dump path for ``prompt_id``, next to DATA_PATH.

    The dump is named ``<prefix>_attn<prompt_id>.pt``, where ``<prefix>`` is
    DATA_PATH's file name up to its first "." (e.g. ``test_attn0.pt`` for
    ``test.csv``).
    """
    prefix = DATA_PATH.name.split(".")[0]
    return DATA_PATH.parent / f"{prefix}_attn{prompt_id}.pt"


def _stack_layer_attention(step_attention, n_tokens: int) -> list:
    """Combine per-generation-step attention tensors into one tensor per layer.

    Args:
        step_attention: One entry per generation step; each entry holds one
            4-D attention tensor per layer. Dim 1 is read as the head axis,
            dim -2 is concatenated across steps and dim -1 is padded.
            NOTE(review): presumably the Hugging Face
            ``generate(output_attentions=True)`` layout
            (batch, heads, query_len, key_len) — confirm.
        n_tokens: Total prompt + response token count; every step's last axis
            is right-padded with zeros to this width.

    Returns:
        A list with one tensor per layer. After the final transpose, the
        concatenated step rows (plus one leading zero row) sit on the last
        axis and the n_tokens-wide padded axis is second-to-last. The result
        is square only when the step rows total ``n_tokens - 1``.
    """
    layers = []
    for layer_idx in range(len(step_attention[0])):
        per_step = [step[layer_idx] for step in step_attention]
        # Right-pad each step's last axis so every step has width n_tokens.
        padded = [
            torch.nn.functional.pad(
                attn,
                (0, n_tokens - attn.size(-1)),
                mode="constant",
                value=0,
            )
            for attn in per_step
        ]
        # Build the extra all-zero row from the first step's tensor so it
        # shares that tensor's dtype, device and batch/head sizes. A plain
        # torch.zeros((1, ...)) is float32 on the CPU with batch 1, which
        # makes torch.cat fail for GPU-resident or batch > 1 attention.
        first = per_step[0]
        zero_row = first.new_zeros((first.size(0), first.size(1), 1, n_tokens))
        # NOTE(review): prepending (not appending) this row and then swapping
        # the last two axes shifts every step row by one position. This looks
        # deliberate: column t then holds the attention computed at the
        # position just before token t. Confirm the intended reading.
        layers.append(torch.cat([zero_row, *padded], dim=-2).transpose(-1, -2))
    return layers


def view_attention(prompt_id: int, head_view: bool):
    """Render bertviz attention for one saved prompt/response dump.

    Loads ``<prefix>_attn<prompt_id>.pt`` from DATA_PATH's directory. The dump
    must be a dict with "prompt_tokens", "response_tokens" and "attention".
    It stitches the per-step attention into one tensor per layer and passes
    it to bertviz.

    Args:
        prompt_id: Number embedded in the attention dump's file name.
        head_view: True for ``bertviz.head_view``, False for
            ``bertviz.model_view``.

    Returns:
        Whatever the chosen bertviz view function returns.
    """
    # torch.load unpickles the file: only open dumps you produced yourself.
    # map_location="cpu" lets dumps saved from a GPU run open on CPU-only hosts.
    data = torch.load(_attention_file(prompt_id), map_location="cpu")
    tokens = data["prompt_tokens"] + data["response_tokens"]
    n_tokens = len(data["prompt_tokens"]) + len(data["response_tokens"])
    layers = _stack_layer_attention(data["attention"], n_tokens)
    if head_view:
        return bertviz.head_view(layers, tokens)
    return bertviz.model_view(layers, tokens)

# Model View
<b>The model view provides a bird's-eye view of attention throughout the entire model</b>. Each cell shows the attention weights for a particular head, indexed by layer (row) and head (column). The lines in each cell represent the attention from one token (left) to another (right), with line weight proportional to the attention value (which ranges from 0 to 1). For a more detailed explanation, please refer to the [blog](https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1).

## Usage
👉 **Click** on any **cell** for a detailed view of attention for the associated attention head (click it again to deselect it). <br/>
👉 Then **hover** over any **token** on the left side of the detail view to filter the attention from that token.

In [None]:
# Model view for prompt 0: all layers (rows) x heads (columns) at once.
view_attention(head_view=False, prompt_id=0)

# Head View
<b>The head view visualizes attention in one or more heads from a single Transformer layer.</b> Each line shows the attention from one token (left) to another (right). Line weight reflects the attention value (which ranges from 0 to 1), while line color identifies the attention head. When multiple heads are selected (indicated by the colored tiles at the top), the corresponding visualizations are overlaid onto one another. For a more detailed explanation of attention in Transformer models, please refer to the [blog](https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1).

## Usage
👉 **Hover** over any **token** on the left/right side of the visualization to filter attention from/to that token. <br/>
👉 **Double-click** on any of the **colored tiles** at the top to filter to the corresponding attention head.<br/>
👉 **Single-click** on any of the **colored tiles** to toggle selection of the corresponding attention head. <br/>
👉 **Click** on the **Layer** drop-down to change the model layer (zero-indexed).


In [None]:
# Head view for prompt 0: one layer at a time, heads overlaid by colour.
view_attention(head_view=True, prompt_id=0)

# Credit
Attention analysis interfaces and descriptions are taken from the [BERTViz Interactive Tutorial](https://colab.research.google.com/drive/1hXIQ77A4TYS4y3UthWF-Ci7V7vVUoxmQ).