In [2]:
import torch
import numpy
import pandas as pd
import os
import random
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


In [2]:
# Load GPT-2 XL into TransformerLens with explicit weight-processing flags.
# NOTE(review): refactor_factored_attn_matrices re-factors W_V/W_O, so hook_z
# values depend on these flags; the cached mean activations patched into
# hook_z later should come from a model loaded the same way — confirm.
model = HookedTransformer.from_pretrained(
    "gpt2-XL",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)


Loaded pretrained model gpt2-XL into HookedTransformer


In [3]:
# Display the module tree, including the HookPoints (e.g. attn.hook_z) that the ablation cells patch.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-47): 48 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [4]:
# Prompt dataset for the singular->plural task (columns: 'sentence', 'answer').
data_path = "../dataset_csvs_singular_plural/s_plurals.csv"

In [5]:
data = pd.read_csv(data_path)

# Fail fast on a malformed dataset instead of hours into the ablation sweep:
# the sweep cells read exactly the 'sentence' and 'answer' columns.
missing_cols = {"sentence", "answer"} - set(data.columns)
assert not missing_cols, f"{data_path} is missing columns: {sorted(missing_cols)}"
assert len(data) > 0, f"{data_path} contains no rows"

data.shape, data.columns

((127, 2), Index(['sentence', 'answer'], dtype='object'))

In [3]:
# Cached per-(layer, head) vectors that the ablation cells write into attn.hook_z.
# NOTE(review): despite the name these are presumably mean hook_z (head output)
# activations, not attention-pattern weights — confirm against the script that made them.
average_attention_weights = np.load(file="../nouns_average_weights_sing_to_plu.npy")

In [4]:
# Expect (n_layers, n_heads, d_head) for the loaded model.
np.shape(average_attention_weights)

(48, 25, 64)

In [8]:
# Sweep dimensions, read from the loaded model instead of hard-coded
# (gpt2-XL: 48 layers, 25 heads, d_head = 64).
layers = model.cfg.n_layers
heads = model.cfg.n_heads
dim = model.cfg.d_head

# The cached mean activations must line up with the model being patched;
# a mismatch would otherwise surface only as an indexing error mid-sweep.
assert tuple(average_attention_weights.shape) == (layers, heads, dim), (
    f"mean-activation array has shape {tuple(average_attention_weights.shape)}, "
    f"expected {(layers, heads, dim)}"
)

In [11]:
# Mean-ablate one attention head at a time: for every (layer, head), overwrite
# that head's hook_z output at token position 4 with its cached mean and record
# the model's top next-token prediction for each prompt.

# torch.as_tensor returns an existing tensor unchanged, so re-running this cell
# neither copies the data again nor emits the torch.tensor(tensor) UserWarning.
average_attention_weights = torch.as_tensor(average_attention_weights)


def replace_hook_z(z, hook, head=None, pos=4):
    """Overwrite attn.hook_z at one token position with the cached mean.

    Args:
        z: hook_z activation, indexed here as [batch, position, head, d_head].
        hook: the HookPoint being run; its name ``blocks.{layer}.attn.hook_z``
            selects which layer's means are used.
        head: index of the head to ablate, or None to ablate every head.
        pos: token position to patch. The default 4 is the 4th prompt token,
            because the prepended BOS token occupies index 0.

    Returns:
        ``z``, modified in place.

    Raises:
        ValueError: if the sequence is too short to contain position ``pos``.
    """
    layer_num = int(hook.name.split('.')[1])
    if z.shape[1] <= pos:
        raise ValueError(
            f"{hook.name}: sequence length {z.shape[1]} has no position {pos}"
        )
    heads_idx = slice(None) if head is None else head
    mean_z = average_attention_weights[layer_num, heads_idx, :]
    z[:, pos, heads_idx, :] = mean_z.to(device=z.device, dtype=z.dtype)
    return z


output_dir = '../mean_ablated_predictions'
os.makedirs(output_dir, exist_ok=True)

for target_layer in range(layers):
    hook_name = f'blocks.{target_layer}.attn.hook_z'
    for head in range(heads):
        print(f"Processing layer {target_layer}, head {head}")
        # Pass the head explicitly (bound via a default argument) so the hook
        # never depends on the global loop variable `head`.
        fwd_hooks = [(hook_name, lambda z, hook, h=head: replace_hook_z(z, hook, head=h))]
        results = []
        for singular, plural in zip(data['sentence'], data['answer']):
            # Inference only: avoid building an autograd graph on every forward pass.
            with torch.no_grad(), model.hooks(fwd_hooks=fwd_hooks):
                logits = model(singular, prepend_bos=True, return_type="logits")

            # Top-1 next-token prediction at the final position and its probability.
            probs = F.softmax(logits[0, -1], dim=-1)
            pred_id = probs.argmax().item()
            results.append({
                'singular': singular,
                'plural': plural,
                'prediction': model.to_string(pred_id),
                'probability': probs[pred_id].item(),
            })

        results_df = pd.DataFrame(results)
        results_df.to_csv(
            f'{output_dir}/predictions_s_plural_XL_layer_{target_layer}_{head}.csv'
        )


  average_attention_weights = torch.tensor(average_attention_weights)


Processing layer 0, head 0
Processing layer 0, head 1
Processing layer 0, head 2
Processing layer 0, head 3
Processing layer 0, head 4
Processing layer 0, head 5
Processing layer 0, head 6
Processing layer 0, head 7
Processing layer 0, head 8
Processing layer 0, head 9
Processing layer 0, head 10
Processing layer 0, head 11
Processing layer 0, head 12
Processing layer 0, head 13
Processing layer 0, head 14
Processing layer 0, head 15
Processing layer 0, head 16
Processing layer 0, head 17
Processing layer 0, head 18
Processing layer 0, head 19
Processing layer 0, head 20
Processing layer 0, head 21
Processing layer 0, head 22
Processing layer 0, head 23
Processing layer 0, head 24
Processing layer 1, head 0
Processing layer 1, head 1
Processing layer 1, head 2
Processing layer 1, head 3
Processing layer 1, head 4
Processing layer 1, head 5
Processing layer 1, head 6
Processing layer 1, head 7
Processing layer 1, head 8
Processing layer 1, head 9
Processing layer 1, head 10
Processing l

In [12]:
# Mean-ablate every head of one layer at a time: for every layer, overwrite the
# whole hook_z output at token position 4 with the cached per-head means and
# record the model's top next-token prediction for each prompt.

# torch.as_tensor returns an existing tensor unchanged, so re-running this cell
# neither copies the data again nor emits the torch.tensor(tensor) UserWarning.
average_attention_weights = torch.as_tensor(average_attention_weights)


def replace_hook_z(z, hook, head=None, pos=4):
    """Overwrite attn.hook_z at one token position with the cached mean.

    Args:
        z: hook_z activation, indexed here as [batch, position, head, d_head].
        hook: the HookPoint being run; its name ``blocks.{layer}.attn.hook_z``
            selects which layer's means are used.
        head: index of the head to ablate, or None to ablate every head.
        pos: token position to patch. The default 4 is the 4th prompt token,
            because the prepended BOS token occupies index 0.

    Returns:
        ``z``, modified in place.

    Raises:
        ValueError: if the sequence is too short to contain position ``pos``.
    """
    layer_num = int(hook.name.split('.')[1])
    if z.shape[1] <= pos:
        raise ValueError(
            f"{hook.name}: sequence length {z.shape[1]} has no position {pos}"
        )
    heads_idx = slice(None) if head is None else head
    mean_z = average_attention_weights[layer_num, heads_idx, :]
    z[:, pos, heads_idx, :] = mean_z.to(device=z.device, dtype=z.dtype)
    return z


output_dir = '../mean_ablated_predictions'
os.makedirs(output_dir, exist_ok=True)

for target_layer in range(layers):
    print(f"Processing layer {target_layer}")
    # With head=None (the default), replace_hook_z patches every head of the layer,
    # so it can be registered directly without a wrapping lambda.
    fwd_hooks = [(f'blocks.{target_layer}.attn.hook_z', replace_hook_z)]
    results = []
    for singular, plural in zip(data['sentence'], data['answer']):
        # Inference only: avoid building an autograd graph on every forward pass.
        with torch.no_grad(), model.hooks(fwd_hooks=fwd_hooks):
            logits = model(singular, prepend_bos=True, return_type="logits")

        # Top-1 next-token prediction at the final position and its probability.
        probs = F.softmax(logits[0, -1], dim=-1)
        pred_id = probs.argmax().item()
        results.append({
            'singular': singular,
            'plural': plural,
            'prediction': model.to_string(pred_id),
            'probability': probs[pred_id].item(),
        })

    results_df = pd.DataFrame(results)
    results_df.to_csv(f'{output_dir}/predictions_s_plural_XL_layer_{target_layer}.csv')


  average_attention_weights = torch.tensor(average_attention_weights)


Processing layer 0
Processing layer 1
Processing layer 2
Processing layer 3
Processing layer 4
Processing layer 5
Processing layer 6
Processing layer 7
Processing layer 8
Processing layer 9
Processing layer 10
Processing layer 11
Processing layer 12
Processing layer 13
Processing layer 14
Processing layer 15
Processing layer 16
Processing layer 17
Processing layer 18
Processing layer 19
Processing layer 20
Processing layer 21
Processing layer 22
Processing layer 23
Processing layer 24
Processing layer 25
Processing layer 26
Processing layer 27
Processing layer 28
Processing layer 29
Processing layer 30
Processing layer 31
Processing layer 32
Processing layer 33
Processing layer 34
Processing layer 35
Processing layer 36
Processing layer 37
Processing layer 38
Processing layer 39
Processing layer 40
Processing layer 41
Processing layer 42
Processing layer 43
Processing layer 44
Processing layer 45
Processing layer 46
Processing layer 47


In [None]:
# Exact-match count for `results_df`, i.e. only the configuration the most
# recently executed sweep cell wrote last (layer 47 for the per-layer sweep),
# not an aggregate over layers or heads.
# Predictions from model.to_string are expected to carry GPT-2's leading-space
# marker, so the first character is dropped before comparing with 'plural'.
count = int((results_df['prediction'].str[1:] == results_df['plural']).sum())
count

-0.36473373