In [57]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

In [None]:
# prompt = 'the dog walked the man to the park'

# tokens = tokenizer.encode(prompt, return_tensors='pt')
# print(tokens)
# print(tokenizer.tokenize(prompt))
# outputs = model(tokens, output_attentions=True)
# print(len(outputs.attentions))
# outputs.attentions[0].shape

In [None]:
# from bertviz import model_view, head_view

# model_view(outputs.attentions, tokenizer.convert_ids_to_tokens(tokens[0]))
# head_view(outputs.attentions, tokenizer.convert_ids_to_tokens(tokens[0]))

In [None]:
# Goal:
# 1. Take a prompt.
# 2. Generate an answer to the question.
# 3. Quantify where attention was paid on each token of the answer.
# 4. Visualize that attention. Visualization should work for long passages.

In [107]:
def inspect_shape(name, obj):
    # Since this is the top-level inspect call, print the name of the variable.
    print(f'{name}: ', end='')
    _inspect_shape(obj, len(name) + 2)  # + 2 for ': '


def _class_name(obj):
    return type(obj).__name__


def _inspect_shape(obj, indent):
    """ 
    Print the given obj. Does not print preceding name or preceding indentation.
    Param `indent` should be the indentation level where this obj is being printed. Will be used to print nested properties if necessary.
    """

    # Base Cases (sequences with known-typed contents, tensors, other objects or primitives)
    if isinstance(obj, str):
        print(f'{_class_name(obj)}[len={len(obj)}]')
    elif isinstance(obj, torch.Tensor):
        print(f'{_class_name(obj)}[shape={list(obj.shape)}, sum={obj.sum()}]')
    else:

        # Dict where contents should be recursively inspected
        try:
            print(f'{_class_name(obj)}[len={len(obj.items())}]')
            for key, val in obj.items():
                print(' ' * (indent + 2) + f'{key}: ', end='')
                _inspect_shape(val, indent + len(key) + 2)  # + 2 for ': '
        except (TypeError, AttributeError):
            # Sequence where contents should be recursively inspected
            try:
                print(f'{_class_name(obj)}[len={len(obj)}]')
                for i in range(min(len(obj), 3)):
                    print(' ' * (indent + 2) + f'{i}: ', end='')
                    # + 2 for ': '
                    _inspect_shape(obj[i], indent + len(str(i)) + 2)
            except (TypeError, AttributeError):
                # Default case if nothing else works
                print(f'{_class_name(obj)}')

In [153]:
import altair as alt
import pandas as pd


def attn_viz(prompt_tokens, generated_sequence_tokens, generated_token_avg_attns, width, scale):
    """
    prompt_tokens should be the Tensor of token IDs from the prompt, 1 x (prompt len)
    generated_sequence should be the raw sequence from huggingface .generate(), including both prompt and generated tokens.
    generated_token_avg_attns should be a sequence with len(generated tokens), where each element is a sequence with len(tokens through generated token, including prompt)
    """
    df = _prepare_attn_df(prompt_tokens, generated_sequence_tokens,
                          generated_token_avg_attns)
    _show_attn_chart(df, width, scale)


def _printable_token_text(token):
    return tokenizer.decode(token).replace('\n', '\\n')


def _prepare_attn_df(prompt_tokens, generated_sequence_tokens, generated_token_avg_attns):
    df = pd.DataFrame({'token_id': generated_sequence_tokens, 'token_text': map(
        _printable_token_text, generated_sequence_tokens)})

    attn_from_cols = [f'token_attn_from_{i}' for i in range(
        len(generated_sequence_tokens))]
    df = df.reindex(columns=['token_id', 'token_text'] +
                    attn_from_cols, fill_value=0.0)
    df['token_idx'] = df.index

    # Generated tokens give attention to preceding tokens
    # TODO: Make this more efficient, maybe with an entire matrix instead of setting each vector?
    for token_idx, generated_token_avg_attn in enumerate(generated_token_avg_attns):
        # inspect_shape('generated_token_avg_attn', generated_token_avg_attn)
        print('prompt tokens shape', prompt_tokens.shape)
        absolute_token_idx = prompt_tokens.shape[1] + token_idx
        df.loc[:len(generated_token_avg_attn) - 1,
               f'token_attn_from_{absolute_token_idx}'] = list(generated_token_avg_attn)
    display(df)
    return df


def _show_attn_chart(df, width, scale):
    """
    df should have columns:
      token_idx: absolute index of token within sequence, including prompt tokens
      token_id: int representation of token
      token_text: string representation of token
      token_attn_from_{i}: float representing the attention placed on this token by the token with index i in the sequence
    width in pixels, of total visualization
    scale in float, 1.0 is normal scale (48px x 24px boxes)
    """

    chart_df = df.copy()

    # Add x and y information to the tokens based on the calculated number of tokens that can fit the given width
    tokens_w = width // (scale * 48)
    chart_df['x'] = chart_df['token_idx'].apply(lambda idx: idx % tokens_w)
    chart_df['y'] = chart_df['token_idx'].apply(lambda idx: idx // tokens_w)

    display(chart_df)

    base = alt.Chart(chart_df).encode(
        x=alt.X('x:N').title('').axis(labels=False, ticks=False),
        y=alt.Y('y:N').title('').axis(labels=False, ticks=False)
    ).properties(
        width=width,
        # Number of columns times height of each box
        height=(len(df.index) // tokens_w) * (scale * 24)
    )

    highlight = base.mark_rect().encode(
        alt.Color('token_attn_from_100:Q').scale(
            scheme='tealblues').legend(None)
    )

    text = base.mark_text(baseline='middle').encode(
        text='token_text:N'
    )

    display(highlight + text)


def average_token_attn(token_attn):
    # TODO: Make this more efficient using torch tensor operations instead of splitting lists
    layer_avg_attns = []
    for layer_idx, layer_attn in enumerate(token_attn):
        # print('layer_idx', layer_idx)
        # Take mean across model's attention heads
        layer_avg_attn = layer_attn.squeeze(0).mean(dim=0)
        # inspect_shape('layer_avg_attn', layer_avg_attn)
        layer_avg_attn_cleaned = torch.concat([
            torch.tensor([0.]),  # First entry is null attention, set it to 0
            # Remove all tokens except the most recent (since the first token generated has entire prompts' worth of attention generated) and remove the first entry
            layer_avg_attn[-1][1:],
            torch.tensor([0.]),  # Add a 0 for the current token itself
        ])
        # inspect_shape('layer_avg_attn_cleaned', layer_avg_attn_cleaned)
        layer_avg_attn_normalized = layer_avg_attn_cleaned / layer_avg_attn_cleaned.sum()
        # inspect_shape('layer_avg_attn_normalized', layer_avg_attn_normalized)
        layer_avg_attns.append(layer_avg_attn_normalized)
    # inspect_shape('layer_avg_attns', layer_attn)
    return torch.stack(layer_avg_attns).mean(dim=0)


def generate_and_attn(prompt):
    prompt_tokens = tokenizer.encode(prompt, return_tensors='pt')
    inspect_shape('tokens', prompt_tokens)

    outputs = model.generate(prompt_tokens, max_new_tokens=48,
                             output_attentions=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
    inspect_shape('outputs', outputs)
    # token_avg_attn = average_token_attn(outputs.attentions[1])
    # inspect_shape('token_avg_attn', token_avg_attn)
    # print(token_avg_attn)
    # for token in outputs.sequences[0]:
    #     print(token, tokenizer.decode(token), end=', ')
    generated_token_avg_attns = list(map(
        average_token_attn, outputs.attentions))
    inspect_shape('generated_token_avg_attns', generated_token_avg_attns)

    inspect_shape('outputs.sequences[0]', outputs.sequences[0])
    attn_viz(prompt_tokens,
             outputs.sequences[0], generated_token_avg_attns, 512, 1.0)


PROMPT = """
The 2008 Summer Olympics torch relay was run from March 24 until August 8, 2008, prior to the 2008 Summer Olympics, with the theme of "one world, one dream". The torch relay took place over 45km with 18 total runners.

Q: What was the theme?
A:
""".strip()

generate_and_attn(PROMPT)

tokens: Tensor[shape=[1, 60], sum=254623]
outputs: GreedySearchDecoderOnlyOutput[len=3]
           sequences: Tensor[shape=[1, 108], sum=454166]
           attentions: tuple[len=48]
                       0: tuple[len=12]
                          0: Tensor[shape=[1, 12, 60, 60], sum=720.0]
                          1: Tensor[shape=[1, 12, 60, 60], sum=720.0]
                          2: Tensor[shape=[1, 12, 60, 60], sum=720.0]
                       1: tuple[len=12]
                          0: Tensor[shape=[1, 12, 1, 61], sum=11.999999046325684]
                          1: Tensor[shape=[1, 12, 1, 61], sum=12.000000953674316]
                          2: Tensor[shape=[1, 12, 1, 61], sum=12.0]
                       2: tuple[len=12]
                          0: Tensor[shape=[1, 12, 1, 62], sum=12.000000953674316]
                          1: Tensor[shape=[1, 12, 1, 62], sum=12.0]
                          2: Tensor[shape=[1, 12, 1, 62], sum=11.999999046325684]
           past_key_valu

Unnamed: 0,token_id,token_text,token_attn_from_0,token_attn_from_1,token_attn_from_2,token_attn_from_3,token_attn_from_4,token_attn_from_5,token_attn_from_6,token_attn_from_7,...,token_attn_from_99,token_attn_from_100,token_attn_from_101,token_attn_from_102,token_attn_from_103,token_attn_from_104,token_attn_from_105,token_attn_from_106,token_attn_from_107,token_idx
0,464,The,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0
1,3648,2008,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.003128,0.002640,0.003904,0.008104,0.016115,0.003378,0.003049,0.004015,0.003860,1
2,10216,Summer,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.001960,0.002374,0.002796,0.004580,0.005535,0.002284,0.001731,0.001930,0.002719,2
3,14935,Olympics,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.002512,0.004261,0.005024,0.005697,0.007323,0.005890,0.003913,0.003842,0.003923,3
4,28034,torch,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.001940,0.006182,0.004596,0.004449,0.009333,0.011573,0.005620,0.003233,0.004239,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
103,28034,torch,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.078872,0.090071,0.041559,0.034214,103
104,24248,relay,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.069411,0.087705,0.060633,104
105,1718,took,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.082932,0.118218,105
106,1295,place,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.067498,106


TypeError: _show_attn_chart() missing 1 required positional argument: 'scale'