In [None]:
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from matplotlib import pyplot as plt 

from pos_tagger import PosTagger
from attention_extraction import extract_all_attention, get_causal_selfattention_pattern
from data_handling import load_tinystories_data
from plotting import plot_probe_results_from_tag, plot_idx_of_highest_output, plot_selfattention_from_idx

# Example stories; individual stories are picked out by index in later cells.
data = load_tinystories_data('../data/tinystories_val.txt')


# Model and its tokenizer come from the same Hugging Face checkpoint.
# output_attentions=True asks the model to also return attention weights.
model_url = 'roneneldan/TinyStories-1M'
tokenizer = AutoTokenizer.from_pretrained(model_url)
model = AutoModelForCausalLM.from_pretrained(model_url, output_attentions=True)

# Part-of-speech tagger built on top of the model's tokenizer.
pos_tagger = PosTagger(tokenizer)


def _read_json(path):
    """Parse and return the JSON document stored at `path`."""
    with open(path, 'r') as fh:
        return json.load(fh)


# Saved probe results, one file for the key vectors and one for the query vectors.
results_key = _read_json('../probe-results/results_keys.json')
results_query = _read_json('../probe-results/results_queries.json')

In [None]:
import os

# Output directory for every figure in this case study (trailing slash kept
# because file names are appended with `+`). Create it up front: later cells
# save PDFs here, and fig.savefig raises FileNotFoundError when the directory
# does not exist, e.g. on a fresh checkout.
folder = '../figures/case_studies/but_heads/'
os.makedirs(folder, exist_ok=True)

## Head active after "but"

Layer 6, head 13 becomes active after a "but" and keeps attending back to it. Notably, the probes did not exhibit an ability to identify conjunctions (the `CC` tag). This may be because the conjunction class contains many words other than "but", whereas this head is more specific.

In [None]:
# Head under study and the example story it is analysed on.
layer = 6
head = 13
input = data[50]  # NOTE(review): shadows the builtin `input`; name kept in case later cells use it

# Saved key-probe results for the coordinating-conjunction tag 'CC'.
plot_probe_results_from_tag(results_key, 'CC', cmap='Blues', outfile=folder + 'but_CC_probe.pdf')

# Presumably the token index at which this head's output peaks (first of three
# return values) — confirm against plotting.plot_idx_of_highest_output.
top_idx, _, _ = plot_idx_of_highest_output(model, tokenizer, input, layer, head, pos_tagger,
                                           color=(0.5, 0.1, 0.1), outfile=folder + 'but_activity.pdf')

keys, queries, values = extract_all_attention(model, tokenizer, input)
# Indexed below as attention[query_position, key_position].
attention = get_causal_selfattention_pattern(keys[layer][head], queries[layer][head])
tokens, tags, words = pos_tagger.tag_input(input, return_words=True)


# Plot layout: `n_rows` consecutive query tokens (one subplot each), starting
# `rows_before_peak` tokens before the peak, with `start_buffer` extra key
# tokens of left context shown on every row.
n_rows = 10
rows_before_peak = 5
start_buffer = 5

# Only positions covered by both the tagged words and the attention rows are plottable.
n_tokens = min(len(words), len(attention))
assert n_tokens > 0, 'story produced no tokens to plot'

peak_idx = top_idx
# First query row. Clamped so the window never starts at a negative index
# (which would silently wrap around to the end of the story) and never runs
# past the last token (which previously raised IndexError when building bars).
top_idx = min(max(peak_idx - rows_before_peak, 0), max(n_tokens - n_rows, 0))
start = max(top_idx - start_buffer, 0)
stop = min(top_idx + n_rows, n_tokens)

# Key positions and their labels are the same for every row, so build them once.
key_positions = range(start, stop)
tick_labels = [f'{word} ({tag})' for word, tag in zip(words[start:stop], tags[start:stop])]

# squeeze=False keeps `axs` a 2-D array even if only one row fits in the story.
fig, axs = plt.subplots(stop - top_idx, 1, figsize=(5, 8), sharex=True, squeeze=False)
axs = axs.flatten()

for ax, query in zip(axs, range(top_idx, stop)):
    # Label each row with its query token, just to the right of the axes.
    ax.text(1.05, 0.5, words[query], transform=ax.transAxes, rotation=0,
            fontsize=12, va="center", ha="left")

    heights = [weight.item() for weight in attention[query, start:stop]]
    bars = ax.bar(key_positions, heights, color=(0.1, 0.2, 0.5))
    # Highlight the bar for the query token's own position.
    bars[query - start].set_color((0.4, 0.6, 0.8))
    ax.set_ylim((0, 0.8))
    ax.set_xticks(key_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')

fig.savefig(folder + 'but_attention.pdf', bbox_inches='tight')