In [None]:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

from data_handling import load_tinystories_data
from attention_extraction import plot_selfattention_pattern, extract_all_attention
from pos_tagger import PosTagger

from collections import defaultdict
from statistics import mean

from matplotlib import pyplot as plt
import pandas as pd

In [None]:

# Checkpoint under study. output_attentions=True makes every forward pass
# expose per-layer attention weights, read via `.attentions` in later cells.
model_url = 'roneneldan/TinyStories-1M'

tokenizer = AutoTokenizer.from_pretrained(model_url)
model = AutoModelForCausalLM.from_pretrained(model_url, output_attentions=True)

# TinyStories validation stories used by every analysis below.
data = load_tinystories_data('data/tinystories_val.txt')

In [None]:
def get_tag_relation(attention, tags, key_tag, context_size=40):
    """Average attention that each query tag pays to preceding ``key_tag`` tokens.

    For every query position ``idx >= context_size``, the attention row
    restricted to the ``context_size`` keys just before the query
    (``attention[idx, idx - context_size:idx]``) is further restricted to keys
    tagged ``key_tag`` and summed. That sum is recorded under the query
    token's own tag. Queries whose window contains no ``key_tag`` token are
    skipped. Positions before ``context_size`` are never used as queries, so
    every window inspected has full length.

    Args:
        attention: 2-D attention matrix indexed ``attention[query, key]``
            (torch tensor or numpy array; its masked slices must support
            ``.sum().item()``).
        tags: one POS tag per position. Assumed to line up one-to-one with the
            rows/columns of ``attention``; TODO confirm for subword tokens.
        key_tag: the tag whose received attention is measured.
        context_size: number of preceding positions inspected per query.
            Defaults to 40, the value used by the aggregation cell, so calls
            that omit it work instead of raising ``TypeError``.

    Returns:
        dict mapping each query tag to the mean, over its qualifying queries,
        of the summed attention on ``key_tag`` tokens. Empty when no query
        qualifies (e.g. the sequence has at most ``context_size`` positions).
    """
    key_tag_mask = np.asarray(tags) == key_tag

    sums_by_query_tag = defaultdict(list)

    # Starting at context_size guarantees idx - context_size >= 0, i.e. a full window.
    for idx in range(max(context_size, 0), len(tags)):
        start = idx - context_size
        window_mask = key_tag_mask[start:idx]
        if window_mask.any():
            window = attention[idx, start:idx]
            sums_by_query_tag[tags[idx]].append(window[window_mask].sum().item())

    return {query_tag: mean(values) for query_tag, values in sums_by_query_tag.items()}

In [None]:
# Aggregate, over the first N_STORIES validation stories, how much attention
# head `head` of layer `layer` sends from each query tag to preceding `tag` tokens.
layer = 0
head = 11
tag = 'DT'
N_STORIES = 100
CONTEXT_SIZE = 40

avg_attention_aggr = defaultdict(list)

# NOTE(review): built once instead of once per story; assumes PosTagger.tag_input
# keeps no per-call state — confirm.
pos_tagger = PosTagger(tokenizer)

for story in data[:N_STORIES]:
    tokens, tags, words = pos_tagger.tag_input(story, return_words=True)
    tokenized = tokenizer(story, return_tensors='pt')
    # Inference only: don't build an autograd graph for every story.
    with torch.no_grad():
        attentions = model(tokenized.input_ids).attentions
    attention = attentions[layer][0][head]

    avg_attention = get_tag_relation(attention, tags, key_tag=tag, context_size=CONTEXT_SIZE)

    for query_tag, value in avg_attention.items():
        avg_attention_aggr[query_tag].append(value)

# Collapse each tag's per-story values into (mean across stories, number of stories).
for query_tag, values in avg_attention_aggr.items():
    avg_attention_aggr[query_tag] = (mean(values), len(values))

In [None]:
# Bar chart of mean attention to preceding `tag` tokens, grouped by the query's tag.
# Bar opacity encodes how many stories contributed to each bar (most-sampled = opaque).
fig, ax = plt.subplots(1, figsize=(15, 3))

# Guard: with no qualifying stories, zip(*...) / max() would raise ValueError.
if avg_attention_aggr:
    query_tags = list(avg_attention_aggr)  # a list, not dict_keys, for matplotlib
    means, counts = zip(*avg_attention_aggr.values())
    max_count = max(counts)
    bars = ax.bar(query_tags, means)
    for bar, count in zip(bars, counts):
        bar.set_alpha(count / max_count)

ax.set(
    title=f'Layer {layer}, head {head}: mean attention to preceding {tag!r} tokens',
    xlabel='query token POS tag',
    ylabel='mean attention mass',
);

In [None]:
# Inspect one story: data[100] is the first story after the 100 used in the aggregate above.
# `story` instead of `input`, which would shadow the builtin.
story = data[100]
pos_tagger = PosTagger(tokenizer)
tokens, tags, words = pos_tagger.tag_input(story, return_words=True)
tokenized = tokenizer(story, return_tensors='pt')
# Inference only: no autograd graph needed.
with torch.no_grad():
    attentions = model(tokenized.input_ids).attentions
attention = attentions[layer][0][head]

# Plot only the first few positions so the heatmap stays readable.
start, end = (0, 7)

plot_selfattention_pattern(attention[start:end, start:end], words[start:end], tags[start:end])

In [None]:
# Spot check: layer 2, batch item 0, head 6 — weight from query 64 onto key 63.
attentions[2][0, 6, 64, 63]

In [None]:
# Print every position in the example story tagged with the current key tag `tag`.
for position, pos_tag in enumerate(tags):
    if pos_tag != tag:
        continue
    print(pos_tag, position)


In [None]:
# Print every position in the example story tagged as a past-tense verb ('VBD').
for position, pos_tag in enumerate(tags):
    if pos_tag != 'VBD':
        continue
    print(pos_tag, position)

In [None]:
# Tag count next to attention-row count. get_tag_relation indexes both by the
# same position, so these two numbers should agree.
len(tags), attention.shape[0]

In [None]:
# Spot check on the selected head: weight from query position 24 onto key 10.
attention[24][10]

In [None]:
# Sanity-check the model by continuing a custom prompt.
string_ = 'Once upon a time, a little boy named Florian saw'
# Separate name so the example story's `tokenized` (used by later cells) is not overwritten.
prompt_tokenized = tokenizer(string_, return_tensors='pt')

output = model.generate(
    prompt_tokenized.input_ids,
    attention_mask=prompt_tokenized.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    max_length=300,
)

# output is (batch, seq): decode the whole first sequence. output[0][0] would
# be a single token id.
print(tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
# Raw token ids from the generation cell (presumably prompt + continuation).
# NOTE(review): if the forward-pass cell further down has run since, `output` is a
# model output object instead, and this shows its first field.
output[0]

In [None]:
# Dimensions of the selected head's (query, key) attention matrix.
attention.size()

In [None]:
# Exploratory variant of get_tag_relation. For every query position, look back up
# to `window_size` tokens and take the *mean* (not the sum) attention weight on
# tokens tagged `key_tag`, then average per query tag.
key_tag = 'EX'
window_size = 30

key_tag_mask = np.array(tags) == key_tag

avgs = defaultdict(list)

# `query_tag`, not `tag`: reusing `tag` would overwrite the global key tag
# ('DT') that the tag-listing and plotting cells read.
for idx, query_tag in enumerate(tags):
    start = max(0, idx - window_size)
    mask_window = key_tag_mask[start:idx]
    if mask_window.any():
        avgs[query_tag].append(attention[idx, start:idx][mask_window].mean().item())

avgs = {query_tag: mean(values) for query_tag, values in avgs.items()}
avgs


In [None]:
# Layer 2, head 7 on the example story: attention from each query tag to preceding
# possessive pronouns ('PRP$'). Note: `layer`/`head` are globals other cells also read.
layer = 2
head = 7

# context_size is required by get_tag_relation; 40 matches the aggregation cell.
get_tag_relation(attentions[layer][0][head], tags, 'PRP$', context_size=40)

In [None]:
# Full attention map of the selected head over the whole example story, labeled
# with both words and POS tags like the windowed plot of the first positions.
plot_selfattention_pattern(attentions[layer][0][head], words, tags)

In [None]:
# Forward pass on the example story (data[100], the story tagged in the example cell).
# The global `tokenized` is not used: after the generation cell it holds the prompt,
# and its attention would not line up with `tags`/`words`.
with torch.no_grad():
    output = model(tokenizer(data[100], return_tensors='pt').input_ids)

In [None]:

# (query, key) attention matrix of head `head` in layer `layer`, batch item 0.
att_for_head = output.attentions[layer][0, head]

In [None]:
# Attention from each query tag to preceding determiners for the selected head.
# context_size is required by get_tag_relation; 40 matches the aggregation cell.
get_tag_relation(att_for_head, tags, 'DT', context_size=40)