In [None]:
import torch
import numpy
import pandas as pd
import os
import random
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


In [None]:
# Load GPT-2 XL through TransformerLens with its load-time weight preprocessing
# enabled: LayerNorm folding, centering of writing/unembedding weights, and
# refactoring of the factored attention matrices (see from_pretrained's docs).
model = HookedTransformer.from_pretrained(
    "gpt2-XL",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)


In [None]:
# Display the loaded model's repr as a quick sanity check of the architecture.
model

In [None]:
# Relative path (from the notebook's working directory) to the singular/plural
# dataset; the ablation loop reads its 'sentence' and 'answer' columns.
data_path = '../dataset_csvs_singular_plural/s_plurals.csv'

In [None]:
data = pd.read_csv(data_path)

# The ablation loop reads row['sentence'] and row['answer']; fail here with a clear
# message rather than with a bare KeyError inside the model loop.
REQUIRED_COLUMNS = {'sentence', 'answer'}
missing_columns = REQUIRED_COLUMNS - set(data.columns)
assert not missing_columns, f"{data_path} is missing required columns: {sorted(missing_columns)}"

data.shape, data.columns

In [None]:
# Precomputed average values that replace_hook_z writes into ablated heads. It is
# indexed downstream as [layer, head, :]; the exact shape is presumably
# (n_layers, n_heads, d_head) -- confirm via the shape cell.
average_weights_path = '../nouns_average_weights_sing_to_plu.npy'
average_attention_weights = np.load(average_weights_path)

In [None]:
# Inspect the array's shape. replace_hook_z indexes it as [layer, head, :], so the
# first two axes must cover every ablated layer and head (presumably (48, 25, 64)
# for GPT-2 XL -- confirm).
average_attention_weights.shape

In [None]:
# Model dimensions: number of layers, heads per layer, and per-head size. These
# presumably mirror GPT-2 XL's model.cfg (n_layers, n_heads, d_head) -- confirm.
layers, heads, dim = 48, 25, 64

In [None]:
import ast

token0_file = '../dataset_csvs_singular_plural/predictions/coordinates_tokens_0.txt'
token1_file = '../dataset_csvs_singular_plural/predictions/coordinates_tokens_1.txt'
token2_file = '../dataset_csvs_singular_plural/predictions/coordinates_tokens_2.txt'
token3_file = '../dataset_csvs_singular_plural/predictions/coordinates_tokens_3.txt'
token4_file = '../dataset_csvs_singular_plural/predictions/coordinates_tokens_4.txt'

files = [token0_file, token1_file, token2_file, token3_file, token4_file]

# Offset added to every parsed layer index before it is used as a TransformerLens
# layer number. NOTE(review): presumably the coordinate files index layers relative
# to layer 6 -- confirm against whatever produced them.
LAYER_OFFSET = 6
# Which line of each coordinates file to use.
line_to_process = 0


def _parse_head_coordinates(line, layer_offset=LAYER_OFFSET):
    """Parse one coordinates line into a ``{layer: [head, ...]}`` dict.

    The text after the first ':' has its first two characters and its last
    character stripped before ``ast.literal_eval`` (presumably a leading space +
    quote and a trailing quote -- confirm the file format). It must evaluate to an
    iterable of ``(layer, head, ...)`` sequences; ``layer_offset`` is added to each
    layer. Heads keep their file order, duplicates included.
    """
    payload = line.strip().split(':')[1][2:-1]
    layer_to_heads = {}
    for coordinate in ast.literal_eval(payload):
        layer = coordinate[0] + layer_offset
        layer_to_heads.setdefault(layer, []).append(coordinate[1])
    return layer_to_heads


# Token position -> {layer: [heads]}. Keys start at 1; the ablation loop uses them
# as sequence positions and runs with prepend_bos=True, so position 0 is presumably BOS.
head_ablation_mapping = {}
for i, file_name in enumerate(files):
    with open(file_name, 'r') as file:
        lines = file.readlines()
    # File i describes token position i + 1, so extending `files` needs no other change.
    head_ablation_mapping[i + 1] = _parse_head_coordinates(lines[line_to_process])



In [None]:
# Re-key each token's {layer: heads} dict in ascending layer order, so hooks are
# later built layer by layer.
head_ablation_mapping = {
    token_position: {layer: layer_map[layer] for layer in sorted(layer_map)}
    for token_position, layer_map in head_ablation_mapping.items()
}


In [None]:
# Show which token positions have ablation entries (1..5 when all five coordinate
# files were parsed).
head_ablation_mapping.keys()

In [None]:
import torch
import torch.nn.functional as F
import pandas as pd

def replace_hook_z(z, hook, heads_to_ablate, token_idx, avg_weights=None):
    """Overwrite selected heads' ``hook_z`` output at one position with average values.

    Forward-hook body for ``blocks.{layer}.attn.hook_z``. The layer number is parsed
    from ``hook.name``. For every head listed under that layer in ``heads_to_ablate``,
    ``z[:, token_idx, head, :]`` is replaced with ``avg_weights[layer, head, :]``.

    Args:
        z: 4-D activation, modified in place. Presumably laid out as
            [batch, position, head_index, d_head] per TransformerLens' hook_z.
        hook: The hook point; only ``hook.name`` is read.
        heads_to_ablate: Mapping (or sequence) from layer number to an iterable of
            head indices, or to a single int head. Layers with no entry pass through.
        token_idx: Sequence position to overwrite. Positions outside ``z`` are
            skipped (shorter prompts simply have no token there).
        avg_weights: Array/tensor indexed ``[layer, head, :]``. Defaults to the
            notebook-global ``average_attention_weights``.

    Returns:
        ``z`` (the same tensor object).
    """
    layer_num = int(hook.name.split('.')[1])
    try:
        heads = heads_to_ablate[layer_num]
    except (KeyError, IndexError):
        # Nothing to ablate in this layer. Narrow except: real errors still surface.
        return z
    if isinstance(heads, int):
        heads = [heads]

    seq_len, n_heads = z.size(1), z.size(2)
    if not (-seq_len <= token_idx < seq_len):
        return z

    if avg_weights is None:
        avg_weights = average_attention_weights

    for head_index in heads:
        if head_index < n_heads:  # ignore head indices the model does not have
            replacement = torch.as_tensor(avg_weights[layer_num, head_index, :])
            # Match z's device and dtype explicitly (the .npy may be float64).
            z[:, token_idx, head_index, :] = replacement.to(device=z.device, dtype=z.dtype)

    return z

# Build the ablation hooks once: they do not depend on the sentence being processed.
# The token position and this layer's heads are bound as lambda default arguments.
# A plain closure over the loop variables is resolved only when the hook runs, which
# is after the loop has finished. Every hook would then ablate the *last* token's
# heads at the *last* position.
hooks = []
for token, layer_heads_mapping in head_ablation_mapping.items():
    for layer, layer_heads in layer_heads_mapping.items():
        if isinstance(layer_heads, int):
            layer_heads = [layer_heads]
        hooks.append((
            f'blocks.{layer}.attn.hook_z',
            lambda z, hook, mapping={layer: layer_heads}, token_idx=token: replace_hook_z(
                z, hook, mapping, token_idx
            ),
        ))

results = []
# Inference only: don't record an autograd graph (and its saved activations) on
# every forward pass of the large model.
with torch.no_grad():
    for idx, row in data.iterrows():
        singular = row['sentence']
        plural = row['answer']
        print(f"Processing sentence {idx+1}/{len(data)}")

        with model.hooks(fwd_hooks=hooks):
            logits = model(singular, prepend_bos=True, return_type="logits")

        # Greedy next-token prediction after the (ablated) prompt, with its probability.
        probs = F.softmax(logits[0, -1], dim=-1)
        pred_id = probs.argmax().item()
        prediction = model.to_string(pred_id)
        probability = probs[pred_id].item()

        results.append({
            'singular': singular,
            'plural': plural,
            'prediction': prediction,
            'probability': probability,
        })

# Explicit columns keep the frame well-formed even when `data` is empty.
results_df = pd.DataFrame(results, columns=['singular', 'plural', 'prediction', 'probability'])

# NOTE(review): exact string match. Tokens decoded by model.to_string may carry a
# leading space (e.g. " dogs"); confirm the 'answer' column uses the same form.
correct_count = int((results_df['plural'] == results_df['prediction']).sum())
total_count = len(results_df)
accuracy_pct = correct_count / total_count * 100 if total_count else 0.0
print(f"Correct predictions: {correct_count}/{total_count} ({accuracy_pct:.2f}%)")
