In [None]:
import json
import os
from collections import Counter

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from attention_extraction import extract_all_attention
from data_handling import load_tinystories_data
from plotting import plot_f1_matrix
from pos_tagger import PosTagger

In [None]:
def get_pos_tags(pos_tagger, data):
    """Return one flat list with the POS tags of every sentence in ``data``.

    Each sentence is passed to ``pos_tagger.tag_input(sentence, return_words=True)``,
    which must return a 3-tuple; only the middle element (the tags) is kept.
    Tags are concatenated in input order.
    """
    flat_tags = []
    for sentence in data:
        # tokens and words are returned by the tagger but not needed here.
        _tokens, sentence_tags, _words = pos_tagger.tag_input(sentence, return_words=True)
        flat_tags += sentence_tags
    return flat_tags

In [None]:
def init(model_url:str,data_path:str,data_size:int,vb:bool = False,random_state=None):
    """Load the model and data, build POS-tag labels, and extract attention features.

    Args:
        model_url: Model id or path passed to ``AutoModelForCausalLM.from_pretrained``
            and ``AutoTokenizer.from_pretrained``.
        data_path: Path handed to ``load_tinystories_data``.
        data_size: Maximum number of sentences kept in EACH of the train and test
            splits. The cap is applied after the 80/20 split.
        vb: If True, every tag whose name starts with ``'VB'`` is collapsed into a
            single ``'VB'`` label in both splits.
        random_state: Seed forwarded to ``train_test_split``. ``None`` (default)
            preserves the previous behaviour of a fresh random split per call.
            Pass an int so that separate runs (e.g. keys vs. queries probes) see
            the same sentences and their results are comparable and reproducible.

    Returns:
        ``(num_heads, num_layers, keys_train, keys_test, queries_train,
        queries_test, tags_train, tags_test)``. Value vectors are extracted by
        ``extract_all_attention`` but are not returned.
    """
    model = AutoModelForCausalLM.from_pretrained(model_url)
    tokenizer = AutoTokenizer.from_pretrained(model_url)

    # NOTE(review): `num_heads` / `num_layers` are the attribute names of a
    # GPT-Neo-style config (presumably what TinyStories models use); other
    # architectures may name these differently -- confirm if model_url changes.
    num_heads = model.config.num_heads
    num_layers = model.config.num_layers

    data = load_tinystories_data(data_path)
    data_train, data_test = train_test_split(data, test_size=0.2, random_state=random_state)
    data_train = data_train[:data_size]
    data_test = data_test[:data_size]

    pos_tagger = PosTagger(tokenizer)
    tags_train = get_pos_tags(pos_tagger, data_train)
    tags_test = get_pos_tags(pos_tagger, data_test)

    if vb:
        # Merge all verb sub-tags into one coarse 'VB' class.
        tags_train = ['VB' if tag.startswith('VB') else tag for tag in tags_train]
        tags_test = ['VB' if tag.startswith('VB') else tag for tag in tags_test]

    # Value tensors are produced by extract_all_attention but not used downstream.
    keys_train, queries_train, _values_train = extract_all_attention(model, tokenizer, data_train)
    keys_test, queries_test, _values_test = extract_all_attention(model, tokenizer, data_test)

    return num_heads, num_layers, keys_train, keys_test, queries_train, queries_test, tags_train, tags_test

In [None]:
# Model whose attention heads are probed, and the text file the sentences come from
# (path is relative to this notebook's working directory).
model_url, data_path = 'roneneldan/TinyStories-1M', '../data/tinystories_val.txt'

In [None]:
(
    num_heads, num_layers,
    keys_train, keys_test,
    queries_train, queries_test,
    tags_train, tags_test,
) = init(
    model_url=model_url,
    data_path=data_path,
    data_size=200,  # cap on sentences kept per split
    vb=False,       # keep fine-grained VB* tags instead of collapsing them
)

In [None]:
def train_predict(num_heads:int,num_layers:int,train_data:np.ndarray,test_data:np.ndarray,tags_train:list,tags_test:list,filename:str,results_dir:str = "results"):
    """Fit one logistic-regression probe per (layer, head) and evaluate it on held-out data.

    Args:
        num_heads: Number of attention heads per layer.
        num_layers: Number of layers.
        train_data: Indexed as ``train_data[layer][head]``. Each entry is the
            feature matrix for that head, with one row per label in
            ``tags_train``. Assumed to be ``(n_tokens, head_dim)``; TODO confirm
            against ``extract_all_attention``.
        test_data: Same layout as ``train_data``, for the test split.
        tags_train: One label per training row.
        tags_test: One label per test row.
        filename: Suffix of the output file ``results_{filename}.json``.
        results_dir: Directory the JSON report is written to. It is created if
            missing. Defaults to ``"results"``, the location used previously.

    Returns:
        Nested list ``probes[layer][head]`` of fitted ``LogisticRegression`` probes.

    Side effects:
        Writes ``results[layer][head]`` (the ``classification_report`` dict of
        each probe) as JSON to ``<results_dir>/results_{filename}.json``.
    """
    # Create the output directory before training. Previously a missing
    # directory raised FileNotFoundError only after every probe had been fit.
    os.makedirs(results_dir, exist_ok=True)
    out_path = os.path.join(results_dir, f"results_{filename}.json")

    probes = [
        [LogisticRegression(solver='newton-cg', max_iter=100) for _ in range(num_heads)]
        for _ in range(num_layers)
    ]

    results = []
    for layer in tqdm(range(num_layers), desc="layers"):
        layer_reports = []
        # leave=False: avoid leaving one finished progress bar per layer in the output.
        for head in tqdm(range(num_heads), desc=f"layer {layer}", leave=False):
            probe = probes[layer][head]
            probe.fit(train_data[layer][head], tags_train)
            preds = probe.predict(test_data[layer][head])
            layer_reports.append(
                classification_report(
                    tags_test,
                    preds,
                    output_dict=True,
                    # zero_division=1 reports 1.0 (not 0) for precision of classes
                    # never predicted and recall of classes never present, which
                    # can inflate macro-averaged precision/recall.
                    zero_division=1,
                )
            )
        results.append(layer_reports)

    with open(out_path, "w") as json_file:
        json.dump(results, json_file, indent=4)

    return probes

In [None]:
# Probe the query vectors; the per-head reports are saved with the 'queries' suffix.
probes = train_predict(
    num_heads=num_heads,
    num_layers=num_layers,
    train_data=queries_train,
    test_data=queries_test,
    tags_train=tags_train,
    tags_test=tags_test,
    filename='queries',
)

In [None]:
# Presumably the classification reports written by train_predict(filename='keys').
# NOTE(review): train_predict writes under results/, not ../probe-results/; confirm
# these files were copied there from an earlier run.
with open('../probe-results/results_keys.json', 'r') as keys_report_file:
    results = json.load(keys_report_file)

In [None]:
# Plot the key-probe results using the 'Blues' colormap.
plot_f1_matrix(
    results=results,
    filename='keys',
    color='Blues',
)

In [None]:
# Presumably the classification reports written by train_predict(filename='queries').
# NOTE(review): train_predict writes under results/, not ../probe-results/; confirm
# these files were copied there from an earlier run.
with open('../probe-results/results_queries.json', 'r') as queries_report_file:
    results = json.load(queries_report_file)

In [None]:
# Plot the query-probe results using the 'Reds' colormap.
plot_f1_matrix(
    results=results,
    filename='queries',
    color='Reds',
)