In [None]:
import pandas as pd
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pos_tagger import PosTagger
from data_handling import load_tinystories_data
from sklearn.model_selection import train_test_split
from attention_extraction import extract_all_attention, get_causal_selfattention_pattern
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
import plotting

from statistics import mean

In [None]:
# Load the probe results (indexed below as data[layer][head]).
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale-dependent default encoding.
with open("../probe-results/results_queries.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# Key names of the first entry (layer index 0, head index 0), in file order.
tags = [*data[0][0]]

In [None]:
# Grid dimensions of the probe results: outer list -> layers, inner list -> heads.
layers, heads = len(data), len(data[0])
# Every key listed before 'accuracy' is treated as a POS-tag entry.
postags = tags[:tags.index('accuracy')]

In [None]:
# Spot-check: display the 'PRP$' entry stored for layer index 5, head index 13.
data[5][13]['PRP$']

In [None]:
# For every layer, record the head with the highest overall 'accuracy'
# (the first such head wins on ties).
layer_max = {}

for layer in range(layers):
    max_head = max(range(heads), key=lambda head: data[layer][head]['accuracy'])
    max_accuracy = data[layer][max_head]['accuracy']

    layer_max[layer] = {'head': f'{layer}_{max_head}', 'accuracy': max_accuracy}


In [None]:
# For every POS tag, find the (layer, head) pair with the highest 'f1-score'
# (pairs are scanned layer-major; the first maximum wins on ties).
tag_max_head = {}

for tag in postags:
    max_layer, max_head = max(
        ((layer, head) for layer in range(layers) for head in range(heads)),
        key=lambda pair: data[pair[0]][pair[1]][tag]['f1-score'],
    )
    max_f1 = data[max_layer][max_head][tag]['f1-score']

    tag_max_head[tag] = {
        'head': f'{max_layer}_{max_head}',
        'f1-score': max_f1
    }



In [None]:
# Display the best '<layer>_<head>' identifier and its f1-score for each POS tag.
tag_max_head

#### Model, tokenizer and validation data

In [None]:
# Checkpoint under analysis; `output_attentions=True` is passed at load time
# so that forward passes expose `.attentions` (as read by the cells below).
model_url = 'roneneldan/TinyStories-1M'
tokenizer = AutoTokenizer.from_pretrained(model_url)
model = AutoModelForCausalLM.from_pretrained(model_url, output_attentions=True)
pos_tagger = PosTagger(tokenizer)

# Texts used by the attention analyses.
sentences = load_tinystories_data('../data/tinystories_val.txt')

In [None]:
def extract_averaged_queries(attentions, layer, head, lookback, tag, tags_train, all_tags):
    """Average the attention that ``tag`` positions pay to each preceding tag, and plot it.

    For every position whose tag equals ``tag``, the attention weights from that
    position to each of the (at most ``lookback``) preceding positions are grouped
    by the tag of the attended position; each group is then averaged.

    Args:
        attentions: Indexable as ``attentions[layer][0][head][query][key]``, where
            the innermost element is a tensor (it is detached and converted to
            NumPy).
        layer: Layer index into ``attentions``.
        head: Head index within the selected layer.
        lookback: Maximum number of preceding positions considered per query.
        tag: Tag of the query positions to analyse.
        tags_train: Sequence holding one tag per position.
        all_tags: Tags drawn on the x-axis, in this order.

    Returns:
        ``{tag: {context_tag: mean_attention}}``. Every entry of ``all_tags`` is
        present (``0`` when it was never attended to). Context tags that are not
        listed in ``all_tags`` are also returned, but they are not plotted.
    """
    collected = {inner_tag: [] for inner_tag in all_tags}

    for index, current_tag in enumerate(tags_train):
        if current_tag != tag:
            continue
        before = max(index - lookback, 0)  # clamp the window at the sequence start
        for i in range(before, index):
            value = attentions[layer][0][head][index][i].detach().cpu().numpy()
            # setdefault: a context tag missing from `all_tags` must not raise KeyError.
            collected.setdefault(tags_train[i], []).append(value)

    inner_dict = {
        inner_tag: (np.mean(value_list) if value_list else 0)
        for inner_tag, value_list in collected.items()
    }
    averaged_attention = {tag: inner_dict}

    # Bars follow the order of `all_tags` so that plots are comparable across calls.
    fig, ax = plt.subplots(figsize=(12, 5))
    ordered_values = [inner_dict.get(inner_tag, 0) for inner_tag in all_tags]
    ax.bar(x=all_tags, height=ordered_values, color='darkblue')

    ax.grid(axis='y', linestyle='--', color='gray', alpha=0.7)
    ax.set_xticks(range(len(all_tags)))
    ax.set_xticklabels(all_tags, rotation=90, ha='center')  # rotate for readability

    plt.title(f"Averaged Attention Values from {tag}")
    plt.ylabel("Average Attention Value")
    plt.xlabel("POS Tags")
    plt.show()

    return averaged_attention


In [None]:
# For each of the first 10 entries of `sentences`, plot the average attention
# that 'PRP' positions pay to preceding tags in layer index 2, head index 6.
tag = 'PRP'  # loop-invariant, so it is set once

for i in range(10):
    inputs = tokenizer(sentences[i], return_tensors="pt")
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    attentions = outputs.attentions
    tags_train = pos_tagger.tag_input(sentences[i])

    # Collapse every tag that starts with 'VB' into a single 'VB' bucket.
    tags_vb = ['VB' if t.startswith('VB') else t for t in tags_train[1]]

    extract_averaged_queries(
        layer=2,
        head=6,
        lookback=40,
        tag=tag,
        tags_train=tags_vb,
        attentions=attentions,
        all_tags=postags,
    )

In [None]:
def extract_queries(index, layer, head, lookback, tag, tags_train=None, attentions=None):
    """Plot the mean attention from position ``index`` to each preceding POS tag.

    The attention weights from ``index`` to the (at most ``lookback``) preceding
    positions are grouped by the tag of the attended position and averaged with
    a true arithmetic mean. Tags that occur in the sequence but not inside the
    window are drawn with height 0. Bars are sorted alphabetically so the plot
    is reproducible between runs.

    Args:
        index: Query position.
        layer: Layer index into ``attentions``.
        head: Head index within the selected layer.
        lookback: Maximum number of preceding positions to include; the window
            is clamped at the start of the sequence.
        tag: Label used in the plot title.
        tags_train: Tagger output whose element ``[1]`` holds one tag per
            position. Defaults to the module-level ``tags_train``.
        attentions: Indexable as ``attentions[layer][0][head][query][key]``.
            Defaults to the module-level ``attentions``.

    Raises:
        NameError: If an argument is omitted and the corresponding module-level
            variable does not exist.
    """
    if tags_train is None or attentions is None:
        # Backward compatibility: the five-argument call form reads notebook globals.
        namespace = globals()
        try:
            if tags_train is None:
                tags_train = namespace["tags_train"]
            if attentions is None:
                attentions = namespace["attentions"]
        except KeyError as missing:
            raise NameError(f"name {missing} is not defined; pass it explicitly") from None

    position_tags = tags_train[1]
    start = max(index - lookback, 0)  # clamp instead of silently plotting nothing
    attention_row = attentions[layer][0][head][index]

    # Collect all samples first; averaging pairwise on the fly would weight
    # later samples more heavily than earlier ones.
    samples = {}
    for i in range(start, index):
        samples.setdefault(position_tags[i], []).append(
            attention_row[i].detach().cpu().numpy()
        )
    averaged = {name: float(np.mean(values)) for name, values in samples.items()}

    all_tags = sorted(set(position_tags))  # deterministic bar order
    heights = [averaged.get(name, 0) for name in all_tags]  # 0 for tags outside the window

    plt.figure(figsize=(12, 4))
    plt.bar(x=all_tags, height=heights, color='darkblue')
    plt.grid(axis='y', linestyle='--', color='gray', alpha=0.7)
    plt.title(f"Averaged Attention Values from {tag}")
    plt.xticks(range(len(all_tags)), all_tags, rotation=90, ha='center')
    plt.ylabel("Average Attention")

    plt.show()

In [None]:
# Select one entry of `sentences` and prepare its tags and attention maps;
# the query position inspected afterwards is index 24.
data_filter = sentences[1530]
inputs = tokenizer(data_filter, return_tensors="pt")
tags_train = pos_tagger.tag_input(data_filter, return_words = True)
# Inference only: gradients are not needed to read the attention maps.
with torch.no_grad():
    outputs = model(**inputs)
attentions = outputs.attentions
tag = tags_train[1][24]

In [None]:
# Query position 24 in layer index 5, head index 5, looking back over 24 positions.
extract_queries(index=24, layer=5, head=5, lookback=24, tag=tag)

In [None]:
# Same head and query position as the bar chart, drawn with the project's plotting helper.
plotting.plot_selfattention_from_idx(
    attention=attentions[5][0][5],
    tags=tags_train[1],
    words=tags_train[2],
    idx=24,
    context_size=24,
)

In [None]:
def extract_averaged_across_sentences(layer, head, lookback, tag, data, tokenizer, model, pos_tagger):
    """Average, over many texts, the attention that ``tag`` positions pay to preceding tags.

    Each text is run through ``model``. For every position tagged ``tag``, the
    attention weights to the (at most ``lookback``) preceding positions are
    grouped by the tag of the attended position. The per-tag means are drawn as
    a bar chart whose bar opacity is the tag's frequency relative to the most
    frequent tag.

    Args:
        layer: Layer index into the model's attentions.
        head: Head index within the selected layer.
        lookback: Maximum number of preceding positions considered per query.
        tag: Tag of the query positions to analyse.
        data: Iterable of texts; may be empty.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``; its result
            is unpacked into ``model(**inputs)``.
        model: Its output must expose ``.attentions``, indexable as
            ``attentions[layer][0][head][query]`` and yielding a tensor row.
        pos_tagger: ``pos_tagger.tag_input(text)[1]`` must be the list of tags,
            one per position.

    Returns:
        ``{context_tag: mean_attention}`` with one entry for every tag seen in
        ``data`` (``0`` when the tag never appeared inside a window); ``{}`` for
        empty ``data``.
    """
    attention_samples = {}
    tag_counts = Counter()

    for sentence in data:
        inputs = tokenizer(sentence, return_tensors="pt")
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            attentions = model(**inputs).attentions

        sentence_tags = pos_tagger.tag_input(sentence)[1]
        tag_counts.update(sentence_tags)

        for index, current_tag in enumerate(sentence_tags):
            if current_tag != tag:
                continue
            start = max(0, index - lookback)
            # One slice per query row instead of one tensor lookup per element.
            weights = attentions[layer][0][head][index][start:index].tolist()
            for context_tag, value in zip(sentence_tags[start:index], weights):
                attention_samples.setdefault(context_tag, []).append(value)

    averaged_attention = {
        seen_tag: (float(np.mean(attention_samples[seen_tag])) if seen_tag in attention_samples else 0)
        for seen_tag in tag_counts
    }

    sorted_tags = sorted(tag_counts)
    sorted_values = [averaged_attention[name] for name in sorted_tags]

    # default=1 keeps the function usable when no tag was seen at all.
    max_count = max(tag_counts.values(), default=1)
    alphas = [tag_counts[name] / max_count for name in sorted_tags]

    plt.figure(figsize=(12, 4))
    bars = plt.bar(sorted_tags, sorted_values, color='darkblue', alpha=1.0)

    # Opacity encodes how often each tag occurred relative to the most frequent tag.
    for bar, alpha in zip(bars, alphas):
        bar.set_alpha(alpha)

    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.xticks(rotation=90, ha='center')
    plt.title(f"Averaged Attention from {tag}")
    plt.ylabel("Average Attention")
    plt.xlabel("POS Tags")
    plt.tight_layout()
    plt.show()

    return averaged_attention


In [None]:
# Layer index 5, head index 5: attention paid by 'PRP$' positions, averaged
# over the first 100 entries of `sentences`.
averaged_attention_result = extract_averaged_across_sentences(
    data=sentences[:100],
    tokenizer=tokenizer,
    model=model,
    pos_tagger=pos_tagger,
    layer=5,
    head=5,
    lookback=40,
    tag='PRP$',
)

In [None]:
def get_tag_relation(attention, tags, key_tag, context_size):
    """Mean attention mass that each tag sends to ``key_tag`` positions in its window.

    For every position ``idx >= context_size``, the attention weights from
    ``idx`` to the preceding ``context_size`` positions tagged ``key_tag`` are
    summed. The sums are grouped by the tag found at ``idx`` and averaged per
    group. Positions whose window holds no ``key_tag`` contribute nothing.

    Args:
        attention: 2-D array-like indexed as ``attention[query, key_slice]``.
        tags: Sequence holding one tag per position.
        key_tag: Tag of the attended-to positions.
        context_size: Window length; positions below it are skipped.

    Returns:
        Plain ``dict`` mapping each contributing tag to its mean summed attention.
    """
    # NOTE(review): positions below `context_size` are skipped, so the
    # `max(0, ...)` clamp can never take effect; confirm this is intended.
    is_key_position = np.array(tags) == key_tag
    sums_by_tag = defaultdict(list)

    for position, current_tag in enumerate(tags):
        if position >= context_size:
            window_start = max(0, position - context_size)
            window_attention = attention[position, window_start:position]
            window_is_key = is_key_position[window_start:position]

            if window_is_key.any():
                attended_mass = window_attention[window_is_key].sum().item()
                sums_by_tag[current_tag].append(attended_mass)

    return {name: mean(values) for name, values in sums_by_tag.items()}

def plot_avg_attention(data, model, tokenizer, layer=5, head=5, tag='PRP$', context_size=40):
    """Plot, per POS tag, the average attention directed at ``tag`` positions.

    Each text is tagged and run through ``model``. ``get_tag_relation`` yields
    one value per source tag for that text; those values are averaged across
    texts. Bar opacity is the number of texts that contributed a value for the
    tag, relative to the best-covered tag.

    Args:
        data: Iterable of texts; may be empty.
        model: Called as ``model(input_ids)``; the output must expose ``.attentions``.
        tokenizer: Called as ``tokenizer(text, return_tensors='pt')``; the result
            must expose ``.input_ids``. It is also passed to ``PosTagger``.
        layer: Layer index into the attentions.
        head: Head index within the selected layer.
        tag: Attended-to tag, forwarded to ``get_tag_relation`` as ``key_tag``.
        context_size: Window length forwarded to ``get_tag_relation``.

    Returns:
        Mapping ``{source_tag: (mean_attention, number_of_texts)}``; empty when
        no text contributed a value.
    """
    avg_attention_aggr = defaultdict(list)

    # Loop-invariant: the notebook reuses one tagger across texts elsewhere too.
    pos_tagger = PosTagger(tokenizer)

    for text in data:
        _, tags, _ = pos_tagger.tag_input(text, return_words=True)
        tokenized = tokenizer(text, return_tensors='pt')
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            attentions = model(tokenized.input_ids).attentions
        attention = attentions[layer][0][head]

        avg_attention = get_tag_relation(attention, tags, key_tag=tag, context_size=context_size)

        for key, value in avg_attention.items():
            avg_attention_aggr[key].append(value)

    # Replace each list of samples by (mean, sample count); keys are unchanged,
    # so updating values while iterating is safe.
    for key, values in avg_attention_aggr.items():
        avg_attention_aggr[key] = (mean(values), len(values))

    sorted_tags = sorted(avg_attention_aggr)
    sorted_values = [avg_attention_aggr[name][0] for name in sorted_tags]
    counts = [avg_attention_aggr[name][1] for name in sorted_tags]

    # default=1 avoids a ValueError when no text produced any value.
    max_count = max(counts, default=1)
    alphas = [count / max_count for count in counts]

    plt.figure(figsize=(12, 4))
    bars = plt.bar(sorted_tags, sorted_values, color='darkblue', alpha=1.0)

    # Opacity encodes how many texts contributed to each bar.
    for bar, alpha in zip(bars, alphas):
        bar.set_alpha(alpha)

    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.xticks(rotation=90, ha='center')

    plt.title(f"Averaged Attention to {tag}")
    plt.ylabel("Average Attention")
    plt.xlabel("POS Tags")

    plt.tight_layout()
    plt.show()

    return avg_attention_aggr

In [None]:
# Layer index 5, head index 5: attention directed at 'NN' positions, averaged
# over the first 100 entries of `sentences`.
averaged_key_attention_result = plot_avg_attention(
    data=sentences[:100],
    model=model,
    tokenizer=tokenizer,
    layer=5,
    head=5,
    tag='NN',
    context_size=40,
)