In [57]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Checkpoint name, defined once so the tokenizer and the model are always
# loaded from the same pretrained weights/vocabulary.
MODEL_NAME = 'gpt2'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [None]:
# prompt = 'the dog walked the man to the park'

# tokens = tokenizer.encode(prompt, return_tensors='pt')
# print(tokens)
# print(tokenizer.tokenize(prompt))
# outputs = model(tokens, output_attentions=True)
# print(len(outputs.attentions))
# outputs.attentions[0].shape

In [None]:
# from bertviz import model_view, head_view

# model_view(outputs.attentions, tokenizer.convert_ids_to_tokens(tokens[0]))
# head_view(outputs.attentions, tokenizer.convert_ids_to_tokens(tokens[0]))

In [None]:
# Goal:
# 1. Take a prompt that contains a question.
# 2. Generate an answer to the question.
# 3. For each token of the answer, quantify how much attention was paid to each earlier token.

In [79]:
def inspect_shape(name, obj):
    """Print a labelled structural summary of ``obj``.

    The label ``name`` is written first without a trailing newline, so the
    first line of the summary produced by ``_inspect_shape`` continues on the
    same output line.
    """
    label = f'{name}: '
    print(label, end='')
    # Top-level call: the summary starts with no indentation.
    _inspect_shape(obj, indent=0)


def _inspect_shape(obj, indent):
    # Print given the current indent

    # Sequence where contents should be recursively inspected
    if isinstance(obj, (list, tuple)):
        print(' ' * indent + f'{type(obj)}[len={len(obj)}]')
        for i in range(min(len(obj), 3)):
            _inspect_shape(obj[i], indent + 2)

    # Base Cases (sequences with known-typed contents, tensors, other objects or primitives)
    elif isinstance(obj, str):
        print(' ' * indent + f'str[len={len(obj)}]')
    elif isinstance(obj, torch.Tensor):
        print(' ' * indent + f'tensor[shape={obj.shape}, sum={obj.sum()}]')
    else:
        print(' ' * indent + f'{type(obj)}')

In [85]:
def average_token_attn(token_attn, verbose=True):
    """Collapse one generation step's attention into one weight per token.

    For every layer the attention is averaged over heads, the row of the most
    recent query position is kept, its first entry is zeroed, a trailing zero
    is appended, and the result is rescaled to sum to 1. The per-layer vectors
    are then averaged across layers.

    Args:
        token_attn: Sequence with one attention tensor per layer. In this
            notebook each tensor is ``(1, heads, query_len, key_len)``; the
            leading dimension is dropped with ``squeeze(0)``, so a batch size
            other than 1 is not supported.
        verbose: When True (the default) print per-layer debug traces using
            ``inspect_shape``. Pass False to compute silently.

    Returns:
        1-D tensor of length ``key_len + 1``.

    Raises:
        ValueError: If ``token_attn`` contains no layers.
    """
    layer_avg_attns = []
    for layer_idx, layer_attn in enumerate(token_attn):
        if verbose:
            print('layer_idx', layer_idx)
        # Take mean across model's attention heads
        layer_avg_attn = layer_attn.squeeze(0).mean(dim=0)
        if verbose:
            inspect_shape('layer_avg_attn', layer_avg_attn)

        # Build the padding on the same device as the attention so the concat
        # also works when the model does not live on the CPU.
        zero = torch.zeros(1, device=layer_avg_attn.device)
        layer_avg_attn_cleaned = torch.concat([
            zero,  # First entry is null attention, set it to 0
            # Keep only the most recent query row (the first generation step
            # carries a row for every prompt token) and drop its first entry.
            layer_avg_attn[-1][1:],
            zero,  # Add a 0 for the current token itself
        ])
        if verbose:
            inspect_shape('layer_avg_attn_cleaned', layer_avg_attn_cleaned)

        # Rescale so the remaining weights sum to 1.
        layer_avg_attn_normalized = layer_avg_attn_cleaned / layer_avg_attn_cleaned.sum()
        if verbose:
            inspect_shape('layer_avg_attn_normalized', layer_avg_attn_normalized)
        layer_avg_attns.append(layer_avg_attn_normalized)

    if not layer_avg_attns:
        raise ValueError('token_attn must contain at least one layer of attention')

    if verbose:
        # Inspect the collected per-layer vectors (not the last raw layer).
        inspect_shape('layer_avg_attns', layer_avg_attns)
    return torch.stack(layer_avg_attns).mean(dim=0)


def generate_and_attn(prompt, max_new_tokens=48, step=1):
    """Generate a continuation of ``prompt`` and print one step's attention summary.

    Uses the module-level ``tokenizer`` and ``model``. Shapes of the
    intermediate values are traced with ``inspect_shape``; the averaged
    attention vector for the chosen step is printed, not returned.

    Args:
        prompt: Text to encode and feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 48).
        step: Index into ``outputs.attentions`` selecting which generation
            step to summarise (default 1). In this notebook's trace, entry 0
            holds the attention over the whole prompt and later entries hold
            a single query row each.

    Raises:
        IndexError: If ``step`` is outside the range of generation steps that
            were actually produced (for example, presumably, when generation
            ends before reaching ``step``).
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    inspect_shape('tokens', tokens)

    outputs = model.generate(
        tokens,
        max_new_tokens=max_new_tokens,
        output_attentions=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    inspect_shape('outputs.attentions', outputs.attentions)

    # Fail with an explicit message instead of a bare tuple-index error.
    num_steps = len(outputs.attentions)
    if not -num_steps <= step < num_steps:
        raise IndexError(
            f'step {step} is out of range: generation produced {num_steps} step(s)'
        )

    token_attn = average_token_attn(outputs.attentions[step])
    inspect_shape('token_attn', token_attn)
    print(token_attn)


# Example input: a short passage followed by a question about it.
# .strip() removes the newlines that the triple-quoted literal places before
# the passage and after "A:", so the prompt text ends exactly at "A:".
PROMPT = """
The 2008 Summer Olympics torch relay was run from March 24 until August 8, 2008, prior to the 2008 Summer Olympics, with the theme of "one world, one dream". The torch relay took place over 45km with 18 total runners.

Q: What was the theme?
A:
""".strip()

# Run generation on the example prompt; the shape traces and the averaged
# attention vector are printed by generate_and_attn itself.
generate_and_attn(PROMPT)

tokens: tensor[shape=torch.Size([1, 60]), sum=254623]
outputs.attentions: <class 'tuple'>[len=48]
  <class 'tuple'>[len=12]
    tensor[shape=torch.Size([1, 12, 60, 60]), sum=720.0]
    tensor[shape=torch.Size([1, 12, 60, 60]), sum=720.0]
    tensor[shape=torch.Size([1, 12, 60, 60]), sum=720.0]
  <class 'tuple'>[len=12]
    tensor[shape=torch.Size([1, 12, 1, 61]), sum=11.999999046325684]
    tensor[shape=torch.Size([1, 12, 1, 61]), sum=12.000000953674316]
    tensor[shape=torch.Size([1, 12, 1, 61]), sum=12.0]
  <class 'tuple'>[len=12]
    tensor[shape=torch.Size([1, 12, 1, 62]), sum=12.000000953674316]
    tensor[shape=torch.Size([1, 12, 1, 62]), sum=12.0]
    tensor[shape=torch.Size([1, 12, 1, 62]), sum=11.999999046325684]
layer_idx 0
layer_avg_attn: tensor[shape=torch.Size([1, 61]), sum=0.9999999403953552]
layer_avg_attn_cleaned: tensor[shape=torch.Size([62]), sum=0.9806371331214905]
layer_avg_attn_normalized: tensor[shape=torch.Size([62]), sum=1.0]
layer_idx 1
layer_avg_attn: tensor[

In [84]:
# Scratch check: F.normalize rescales a vector to unit L2 norm (p=2 is its
# default), which is not the same as dividing by the sum of the entries.
a = torch.tensor([1.0, 2.0, 3.0, 4.0])
torch.nn.functional.normalize(a, p=2.0, dim=0)

tensor([0.1826, 0.3651, 0.5477, 0.7303])