In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from matplotlib import pyplot as plt

In [None]:
# Hugging Face Hub id of the TinyStories-1M checkpoint used throughout this notebook.
model_url = 'roneneldan/TinyStories-1M'

# Tokenizer and weights come from the same checkpoint so their token ids agree.
tokenizer = AutoTokenizer.from_pretrained(model_url)
model = AutoModelForCausalLM.from_pretrained(model_url)

In [None]:
str_ = 'The boy picked up the chair'
tokenized = tokenizer.encode(str_, return_tensors='pt')

# A single, unpadded prompt: every position is a real token, so the mask is all ones.
attention_mask = torch.ones(tokenized.shape, device=tokenized.device)

# Some tokenizers define no pad token; in that case reuse EOS so generate()
# receives an explicit pad id.
pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id

# One beam (greedy unless the model's generation config enables sampling);
# max_length counts the prompt tokens as well as the generated ones.
output = model.generate(
    tokenized,
    attention_mask=attention_mask,
    pad_token_id=pad_token_id,
    max_length=100,
    num_beams=1,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))


In [None]:
def get_attention(model, layer, input_, tokenizer=None):
    """Reconstruct the per-head attention patterns of one self-attention layer.

    Forward hooks capture the outputs of the layer's ``q_proj`` and ``k_proj``
    linear maps during a single forward pass over ``input_``. Each head's pattern
    is then recomputed as ``softmax(q @ k.T)`` with a causal mask, so query
    position ``i`` only attends to key positions ``j <= i``. If the config marks
    this layer as ``'local'`` (GPT-Neo style ``attention_layers`` /
    ``window_size``), the sliding-window mask is applied as well.

    NOTE(review): no ``1/sqrt(head_dim)`` scaling is applied. This matches the
    original computation and, as far as I know, GPT-Neo's attention. Verify
    before reusing with another architecture.

    Args:
        model: Causal LM exposing ``transformer.h.{layer}.attn.attention.q_proj``
            / ``k_proj`` submodules and ``model.config.num_heads`` (the Hugging
            Face GPT-Neo module layout).
        layer: Index of the transformer block to inspect.
        input_: Text to tokenize and run through the model.
        tokenizer: Tokenizer used to encode ``input_``. Defaults to the
            notebook-level ``tokenizer`` (the previous implicit behavior).

    Returns:
        ``(attn_patterns, tokenized)``. ``attn_patterns`` is a list of
        ``num_heads`` tensors of shape ``(seq_len, seq_len)`` whose row ``i`` is
        the attention distribution of query token ``i``. ``tokenized`` is the
        tensor returned by ``tokenizer.encode(input_, return_tensors='pt')``.
    """
    if tokenizer is None:
        # Backward compatibility: previously the notebook-global tokenizer was used implicitly.
        tokenizer = globals()['tokenizer']

    num_heads = model.config.num_heads
    captured = {}

    def make_hook(name):
        def hook(module, inputs, output):
            # Keep only the first call; detach and upcast to float32 for a stable softmax.
            captured.setdefault(name, output.detach().float())
        return hook

    prefix = f'transformer.h.{layer}.attn.attention'
    handles = [
        model.get_submodule(f'{prefix}.q_proj').register_forward_hook(make_hook('query')),
        model.get_submodule(f'{prefix}.k_proj').register_forward_hook(make_hook('key')),
    ]

    tokenized = tokenizer.encode(input_, return_tensors='pt')
    try:
        # Only the hooked activations are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            model(tokenized)
    finally:
        # Always remove the hooks, even if the forward pass raises, so they do
        # not keep firing on later calls to the shared model.
        for handle in handles:
            handle.remove()

    # First batch element, split along the feature dim into num_heads slices.
    # Assumes each head occupies a contiguous head_dim slice of the projection
    # output (as with a view(..., num_heads, head_dim) split).
    query_heads = torch.chunk(captured['query'][0], num_heads, dim=-1)
    key_heads = torch.chunk(captured['key'][0], num_heads, dim=-1)

    seq_len = query_heads[0].shape[0]
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool, device=query_heads[0].device)
    allowed = ones.tril()  # causal: no attention to future tokens
    attention_layers = getattr(model.config, 'attention_layers', None)
    window_size = getattr(model.config, 'window_size', None)
    if attention_layers is not None and window_size and attention_layers[layer] == 'local':
        # Sliding window: additionally mask keys at least window_size positions back.
        allowed &= ~ones.tril(-window_size)

    attn_patterns = []
    for q, k in zip(query_heads, key_heads):
        # The diagonal is always allowed, so no row is fully -inf.
        scores = (q @ k.T).masked_fill(~allowed, float('-inf'))
        attn_patterns.append(torch.softmax(scores, dim=-1))

    return attn_patterns, tokenized


In [None]:
# Layer / head to visualise; HEAD must be < model.config.num_heads.
LAYER = 2
HEAD = 7

attn_patterns, tokenized = get_attention(
    model=model,
    layer=LAYER,
    input_='The little wolf walked past the forest. His paws did not make a sound.',
)
attn_pattern = attn_patterns[HEAD].detach()

# Plot by position, not by decoded string: repeated tokens (e.g. the two '.'
# in this sentence) would otherwise collapse into one dict key / one bar on a
# categorical axis, dropping or mislabelling values.
token_labels = [tokenizer.decode(token_id) for token_id in tokenized[0]]
n_tokens = len(token_labels)
positions = list(range(n_tokens))

# squeeze=False keeps a 2-D array of Axes even when the input is a single token.
fig, axs = plt.subplots(n_tokens, squeeze=False, figsize=(5, 20))
axs = axs.flatten()

for i, ax in enumerate(axs):
    # Row i of the pattern: how query token i distributes its attention over the sequence.
    ax.set_title(f'{i}: {token_labels[i]!r}')
    ax.bar(positions, attn_pattern[i].tolist())
    ax.set_xticks(positions)
    ax.set_xticklabels(token_labels, rotation=45)

fig.suptitle(f'Layer {LAYER}, head {HEAD}')
fig.tight_layout()