<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;Run this notebook in Colab via the badge, or in a local notebook.

# Setup

In [None]:
%%capture
# Install the Redwood Research Easy-Transformer fork (presumably the source of the
# `easy_transformer` imports used in the next cell).
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
# Extra libraries. NOTE(review): none of these installs is version-pinned, so a
# fresh environment may resolve different versions than the ones used originally.
%pip install einops datasets transformers fancy_einsum

In [None]:
from copy import deepcopy
import torch

# Fail fast unless exactly one CUDA device is visible: the model is moved to the
# GPU with `.cuda()` when it is loaded further down.
assert torch.cuda.device_count() == 1
from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

# Enable the autoreload extension when running inside IPython/Jupyter so that
# edits to imported modules are picked up without restarting the kernel.
# The `is not None` guard skips this when no IPython shell is active.
ipython = get_ipython()
if ipython is not None:
    # `InteractiveShell.magic()` is deprecated; `run_line_magic(name, line)` is
    # the supported way to invoke a line magic programmatically.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Initialise the model. (Use a larger N or fewer templates to avoid warnings about in-template ablation.)

In [None]:
# Load pretrained GPT-2 and move it to the GPU.
model = EasyTransformer.from_pretrained("gpt2").cuda()
# Variant without the `.cuda()` move:
# model = EasyTransformer.from_pretrained("gpt2")
# NOTE(review): presumably this makes the per-head `attn.hook_result` activations
# available; they are read by the scatter/correlation functions below — confirm.
model.set_use_attn_result(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Moving model to device:  cuda
Finished loading pretrained model gpt2 into EasyTransformer!


# Generate dataset with multiple prompts

In [None]:
#@title Names list
# Pool of candidate first names used to fill the [S1], [S2], ... placeholders
# when prompts are generated. Order is preserved from the original list.
names = [
    "Michael", "Christopher", "Jessica", "Matthew", "Ashley", "Jennifer",
    "Joshua", "Amanda", "Daniel", "David", "James", "Robert",
    "John", "Joseph", "Andrew", "Ryan", "Brandon", "Jason",
    "Justin", "Sarah", "William", "Jonathan", "Stephanie", "Brian",
    "Nicole", "Nicholas", "Anthony", "Heather", "Eric", "Elizabeth",
    "Adam", "Megan", "Melissa", "Kevin", "Steven", "Thomas",
    "Timothy", "Christina", "Kyle", "Rachel", "Laura", "Lauren",
    "Amber", "Brittany", "Danielle", "Richard", "Kimberly", "Jeffrey",
    "Amy", "Crystal", "Michelle", "Tiffany", "Jeremy", "Benjamin",
    "Mark", "Emily", "Aaron", "Charles", "Rebecca", "Jacob",
    "Stephen", "Patrick", "Sean", "Erin", "Jamie", "Kelly",
    "Samantha", "Nathan", "Sara", "Dustin", "Paul", "Angela",
    "Tyler", "Scott", "Katherine", "Andrea", "Gregory", "Erica",
    "Mary", "Travis", "Lisa", "Kenneth", "Bryan", "Lindsey",
    "Kristen", "Jose", "Alexander", "Jesse", "Katie", "Lindsay",
    "Shannon", "Vanessa", "Courtney", "Christine", "Alicia", "Cody",
    "Allison", "Bradley", "Samuel",
]

In [None]:
def filter_names(names):
    """Keep only the names that the global `model`'s tokenizer encodes as one token.

    Each name is tokenized exactly as given (no leading space is added).
    Returns a new list in the original order.
    """
    single_token_names = []
    for candidate in names:
        pieces = model.tokenizer.tokenize(candidate)
        if len(pieces) == 1:
            single_token_names.append(candidate)
    return single_token_names


# Re-bind `names` to the filtered pool used by the prompt generator below.
names = filter_names(names)

In [None]:
import random

def make_prompts_list(names, template, num_sentences, num_targ_tokens):
    """Build `num_sentences` distinct prompts by filling a template with random names.

    Args:
        names: pool of candidate names to sample from.
        template: text containing placeholders "[S1]", "[S2]", ...,
            "[S<num_targ_tokens>]"; each is replaced by one sampled name.
        num_sentences: number of distinct prompts to produce.
        num_targ_tokens: how many names are drawn (without replacement) per prompt.

    Returns:
        A list of dicts, one per prompt, with keys "S1".."S<num_targ_tokens>"
        (the sampled names) and "text" (the filled-in template).

    Raises:
        ValueError: if `num_sentences` exceeds the number of distinct texts the
            template and name pool can produce (the original rejection loop
            would spin forever in that case), or if `random.sample` is asked
            for more names than the pool holds.
    """
    # Upper bound on the number of distinct texts: ordered selections of names
    # for the placeholders that actually occur in the template. Assumes the
    # names themselves do not contain placeholder text.
    placeholders_used = sum(
        1 for i in range(1, num_targ_tokens + 1) if f"[S{i}]" in template
    )
    max_distinct = 1
    for offset in range(placeholders_used):
        max_distinct *= max(len(names) - offset, 0)
    if num_sentences > max_distinct:
        raise ValueError(
            f"Cannot generate {num_sentences} distinct prompts: the template uses "
            f"{placeholders_used} placeholder(s) and {len(names)} name(s) allow at "
            f"most {max_distinct} distinct prompt(s)."
        )

    sentences = []
    generated_set = set()  # texts produced so far, used to reject duplicates
    while len(sentences) < num_sentences:
        unique_names = random.sample(names, k=num_targ_tokens)
        temp_template = template
        sentence_dict = {}
        for i, name in enumerate(unique_names, start=1):
            temp_template = temp_template.replace(f"[S{i}]", name)
            sentence_dict[f'S{i}'] = name
        sentence_dict['text'] = temp_template
        if sentence_dict['text'] not in generated_set:
            generated_set.add(sentence_dict['text'])
            sentences.append(sentence_dict)
    return sentences

In [None]:
class Dataset:
    """Tokenized prompts plus the token position of every target word and of the last token.

    Attributes set by `__init__`:
        prompts: the list of prompt dicts passed in (each has "text" and target keys).
        tokenizer: the tokenizer passed in; used for both encoding and position lookup.
        N: the dataset size passed in by the caller (returned by `len()`).
        toks: int32 tensor of padded input ids, one row per prompt.
        word_idx: dict mapping each target key (and "end") to a tensor holding,
            per prompt, the index of that token in the tokenized prompt.
    """

    def __init__(self, prompts, tokenizer, N, S1_is_first=False):
        """Tokenize `prompts` and locate the target tokens.

        Args:
            prompts: list of dicts with a "text" entry plus one entry per target
                (e.g. "S1", "S2"); target keys are read from the first prompt.
            tokenizer: callable as `tokenizer(texts, padding=True).input_ids` and
                providing `tokenize(text)`. NOTE(review): positions assume padding
                is appended on the right — confirm for the tokenizer in use.
            N: number of prompts, as reported by `len(self)`.
            S1_is_first: if True, "S1" is looked up without the "Ġ" prefix,
                for templates where S1 is the very first token.

        Raises:
            ValueError: if a target token does not occur in its tokenized prompt.
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = N

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly instead of via a float tensor.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # Tokenize every prompt once with the injected tokenizer and reuse the
        # result for all lookups (no reliance on a global `model`).
        tokenized = [self.tokenizer.tokenize(text) for text in texts]

        # word_idx[targ][i] is the index of the first occurrence of target
        # `targ` in prompt i.
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt, tokens in zip(self.prompts, tokenized):
                if S1_is_first and targ == "S1":  # first token has no "Ġ" space marker
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                try:
                    target_index = tokens.index(target_token)
                except ValueError:
                    raise ValueError(
                        f"Target token {target_token!r} for key {targ!r} not found "
                        f"in tokenized prompt {prompt['text']!r}"
                    ) from None
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        # "end" is the index of the last token of each prompt.
        self.word_idx["end"] = torch.tensor([len(tokens) - 1 for tokens in tokenized])

    def __len__(self):
        """Return the dataset size `N` supplied at construction."""
        return self.N

In [None]:
# Seed Python's RNG so the sampled prompts, and every number derived from them
# below, are reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)

# Alternative template in which S1 is the very first token of the prompt
# (presumably meant to be used with S1_is_first=True):
# template = "[S1] and [S2] went to the store. [S2] gave a drink to"
template = "Then, [S1] and [S2] went to the store, [S2] gave a drink to"
num_inputs = 30
num_targ_tokens = 2
prompts_list = make_prompts_list(names, template, num_inputs, num_targ_tokens)
dataset = Dataset(prompts_list, model.tokenizer, num_inputs, S1_is_first=False)

# Copy score

In [None]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Score how often head `layer`.`head` maps a target word back onto its own token.

    The model is run on `dataset.toks`; the residual stream after block 0
    ("blocks.0.hook_resid_post") is normalized with `model.blocks[1].attn.ln1`,
    passed through the head's W_V / b_V and W_O, then through `ln_final` and
    `unembed`. For every prompt and every target key (each prompt key other
    than "text") the top-5 decoded tokens at the target's position are checked
    for " " + <name>.

    Args:
        model: the transformer to probe (must expose `cache_some`, `blocks`,
            `ln_final`, `unembed` and `tokenizer`).
        layer, head: indices of the attention head whose W_V/W_O are used.
        dataset: object with `prompts`, `toks` and `word_idx`.
        verbose: unused; kept for backward compatibility.
        neg: if True the head output is multiplied by -1 before unembedding.
        print_tokens: selects the return value (nothing extra is printed).

    Returns:
        If `print_tokens` is truthy, a dict mapping each name to its top-5
        decoded tokens (keyed by name, so a name used in several prompts keeps
        only its last entry). Otherwise the list of names found in their own top-5.

    Side effects:
        Prints one line with the top-5 accuracy in percent.
    """
    cache = {}
    # NOTE(review): `cache_some` presumably attaches hooks that stay registered
    # after this call — confirm whether `model.reset_hooks()` is needed when
    # this function is called in a loop.
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    sign = -1 if neg else 1

    # Inference only: no autograd graph is needed for top-k lookups.
    with torch.no_grad():
        model(dataset.toks.long())
        # NOTE(review): block 1's ln1 is used regardless of `layer` — confirm intended.
        z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

        v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
        v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

        o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
        logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0
    n_total = 0  # number of (prompt, target) pairs actually scored

    pred_tokens_dict = {}
    words_moved = []
    # Target keys are taken from the first prompt in the dataset.
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            position = dataset.word_idx[word][seq_idx]
            top_ids = torch.topk(logits[seq_idx, position], k).indices
            pred_tokens = [model.tokenizer.decode(token) for token in top_ids]
            pred_tokens_dict[prompt[word]] = pred_tokens
            n_total += 1
            if " " + prompt[word] in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])

    # Divide by the pairs that were scored rather than by `dataset.N`, which is
    # caller-supplied and may not match `len(dataset.prompts)`.
    percent_right = (n_right / n_total) * 100 if n_total else 0.0
    print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        return pred_tokens_dict
    return words_moved

## Get important heads

Find what heads are specific to certain inputs, and what's common to the template.

Get the important heads from `circuit_expms_template.ipynb` (section "print top heads"): copy the output of `top_indices` and reformat it onto one line (e.g. using ChatGPT).

NOTE: Not all attention heads simply copy, so use attention patterns to determine which heads actually copy, and refine this list of heads accordingly.

(E.g., if you copy all the top heads from IOI, only 9.9 and 10.0 are name movers, while the other heads are "S-inhibition", "induction", or "duplicate" heads, so only the name movers and backup name movers will have top accuracy.)

In [None]:
# (layer, head) pairs to score, pasted from the "top heads" output of the
# circuit notebook referenced above.
top_val = [(8, 10), (9, 9), (8, 6), (5, 5), (7, 3), (7, 9), (9, 7), (10, 6), (3, 0), (10, 0)]
# For each head, print its position in `top_val` and the list of names that
# appeared in their own top-5 (print_tokens=False selects that return value).
for index, (layer, head) in enumerate(top_val):
    print(index, get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 8.10 (sign=1) : Top 5 accuracy: 0.0%
0 []
Copy circuit for head 9.9 (sign=1) : Top 5 accuracy: 100.0%
1 ['Tyler', 'Brian', 'Matthew', 'Emily', 'Laura', 'Amy', 'Crystal', 'John', 'Matthew', 'Amy', 'Joshua', 'Joseph', 'Jennifer', 'Aaron', 'Jennifer', 'Sean', 'Jason', 'Kevin', 'David', 'Mark', 'Jason', 'Emily', 'Aaron', 'Jessica', 'Charles', 'Joshua', 'Steven', 'Patrick', 'Emily', 'Mark', 'Crystal', 'Brian', 'Justin', 'Ryan', 'Patrick', 'Kelly', 'Brian', 'Jason', 'Justin', 'Kevin', 'James', 'Patrick', 'Stephen', 'Jonathan', 'Rachel', 'Brian', 'Scott', 'Jennifer', 'Michael', 'Rachel', 'Charles', 'Mark', 'Daniel', 'Jennifer', 'Michelle', 'Ryan', 'Amy', 'Justin', 'James', 'Eric']
Copy circuit for head 8.6 (sign=1) : Top 5 accuracy: 8.333333333333332%
2 ['Aaron', 'Aaron', 'Kelly', 'Mark', 'Michelle']
Copy circuit for head 5.5 (sign=1) : Top 5 accuracy: 0.0%
3 []
Copy circuit for head 7.3 (sign=1) : Top 5 accuracy: 0.0%
4 []
Copy circuit for head 7.9 (sign=1) : Top 5 accu

Look at heads with strong pos on heatmap

In [None]:
# With the default print_tokens=True this returns {name: top-5 decoded tokens}
# for head 9.9, which the notebook displays as the cell output.
get_copy_scores(model, 9, 9, dataset)

Copy circuit for head 9.9 (sign=1) : Top 5 accuracy: 100.0%


{'Tyler': [' Tyler', 'Tyler', ' TY', 'ty', ' Slater'],
 'Brian': [' Brian', 'Brian', ' BI', 'BUG', 'BI'],
 'Matthew': [' Matthew', 'Matthew', ' Matthews', 'Matt', 'MAT'],
 'Emily': [' Emily', 'Emily', ' EM', ' Ember', 'EM'],
 'Laura': [' Laura', 'Laura', ' Lar', ' Laur', ' Laf'],
 'Amy': [' Amy', 'Amy', ' amy', 'amy', ' ALS'],
 'Crystal': [' Crystal', ' crystal', 'Crystal', ' crystals', ' Cry'],
 'John': [' John', 'John', 'JOHN', 'john', ' Johns'],
 'Joshua': [' Joshua', 'Joshua', 'Josh', ' Josh', ' Jak'],
 'Joseph': [' Joseph', 'Joseph', 'Joe', ' Joe', ' Mos'],
 'Jennifer': [' Jennifer', 'Jennifer', 'Jen', ' Jen', 'jen'],
 'Aaron': [' Aaron', 'Aaron', ' Ezra', ' HA', 'achus'],
 'Sean': [' Sean', 'Sean', 'ayne', ' Stras', ' Syd'],
 'Jason': [' Jason', 'Jason', ' raspberry', 'nar', ' Gir'],
 'Kevin': ['Kevin', ' Kevin', ' KD', 'km', ' Cannon'],
 'David': [' David', 'David', ' david', 'avid', ' Davidson'],
 'Mark': [' Mark', 'Mark', 'mark', ' mark', ' Marks'],
 'Jessica': [' Jessica', 'J

# Writing direction results with scatterplot

In [None]:
def scatter_attention_and_contribution(
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """Relate a head's attention to each name with how much it writes in that name's direction.

    For every prompt and every target key (each prompt key other than "text"),
    this records:
      * the attention probability of head `layer_no`.`head_no` from the prompt's
        "end" position to the first occurrence of the name token, and
      * the dot product between the head's output at the "end" position
        (`attn.hook_result`) and the name token's column of `model.unembed.W_U`.

    Args:
        model: the transformer to probe (must expose `cache_all`, `tokenizer`
            and `unembed.W_U`).
        layer_no, head_no: indices of the attention head to inspect.
        dataset: object with `prompts`, `toks` and `word_idx["end"]`.
        S1_is_first: if True, "S1" is tokenized without a leading space, for
            templates where S1 is the very first token.
        return_vals: if True, return the DataFrame of values (takes precedence
            over `return_fig`; no figure is built).
        return_fig: if True, return the plotly figure instead of showing it.

    Returns:
        A DataFrame, a figure, or None (after `fig.show()`), as selected above.

    Side effects:
        Prints a message for every target whose token is not found in its
        prompt; such targets are skipped.
    """
    model_unembed = model.unembed.W_U.detach().cpu()
    rows = []
    cache = {}
    # NOTE(review): `cache_all` presumably attaches hooks that stay registered
    # after this call — confirm whether `model.reset_hooks()` is needed.
    model.cache_all(cache)

    # Inference only: the cached activations are detached below anyway.
    with torch.no_grad():
        model(dataset.toks.long())

    attn_key = f"blocks.{layer_no}.attn.hook_attn"
    result_key = f"blocks.{layer_no}.attn.hook_result"

    for i, prompt in enumerate(dataset.prompts):
        targ_tokens = [key for key in dataset.prompts[0].keys() if key != 'text']
        toks = None  # prompt token ids, computed once on first use
        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # first token has no leading space
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]

            if toks is None:
                toks = model.tokenizer(prompt["text"])["input_ids"]
            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            end_pos = dataset.word_idx["end"][i]
            prob = cache[attn_key][i, head_no, end_pos, s_pos].detach().cpu()
            resid = cache[result_key][i, end_pos, head_no, :].detach().cpu()
            dot = torch.einsum("a,a->", resid, model_unembed[:, s_tok].detach())
            # Label the row with the key that was actually found, so a skipped
            # target can no longer shift labels onto the wrong name.
            rows.append([prob, dot, s_id, prompt["text"]])

    viz_df = pd.DataFrame(
        rows, columns=["Attn Prob on Name", "Dot w Name Embed", "Name Type", "text"]
    )
    if return_vals:
        return viz_df

    fig = px.scatter(
        viz_df,
        x="Attn Prob on Name",
        y="Dot w Name Embed",
        color="Name Type",
        hover_data=["text"],
        title=f"How Strong {layer_no}.{head_no} Writes in the Name Embed Direction Relative to Attn Prob",
    )
    if return_fig:
        return fig
    fig.show()

In [None]:
# Plot head 9.9; neither return_vals nor return_fig is set, so the figure is
# shown directly via fig.show().
scatter_attention_and_contribution(
    model=model, layer_no=9, head_no=9, dataset=dataset, S1_is_first=False
)

## Correlation vals

In [None]:
def get_prob_dot(  # same as scatterplot, but output x and y vals instead of plotting
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """Return the x/y values behind the attention-vs-contribution scatter plot.

    For each prompt and each target name whose token is found in the prompt,
    collects (a) the attention probability of head `layer_no`.`head_no` from the
    prompt's "end" position to the name's first occurrence and (b) the dot
    product of the head's output at "end" with the name token's column of
    `model.unembed.W_U`.

    `return_vals` and `return_fig` are accepted but not used.

    Returns:
        (all_prob, all_dot): two equally long lists of 0-dim tensors.
    """
    unembed_matrix = model.unembed.W_U.detach().cpu()
    attn_probs, name_dots = [], []
    activations = {}
    model.cache_all(activations)
    model(dataset.toks.long())

    attn_key = f"blocks.{layer_no}.attn.hook_attn"
    result_key = f"blocks.{layer_no}.attn.hook_result"

    for row, prompt in enumerate(dataset.prompts):
        # (position, unembed direction) for every target found in this prompt
        located = []
        target_keys = [k for k in dataset.prompts[0].keys() if k != 'text']
        for key in target_keys:
            # S1 at the very start of a prompt carries no leading space.
            if S1_is_first and key == "S1":
                name_id = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                name_id = model.tokenizer(" " + prompt[key])["input_ids"][0]

            prompt_ids = model.tokenizer(prompt["text"])["input_ids"]
            if name_id not in prompt_ids:
                print(f"{name_id} is not present in {prompt_ids}. Skipping...")
                continue
            located.append((prompt_ids.index(name_id), unembed_matrix[:, name_id].detach()))

        for position, direction in located:
            end_pos = dataset.word_idx["end"][row]
            prob = activations[attn_key][row, head_no, end_pos, position].detach().cpu()
            head_out = activations[result_key][row, end_pos, head_no, :].detach().cpu()
            attn_probs.append(prob)
            name_dots.append(torch.einsum("a,a->", head_out, direction))

    return attn_probs, name_dots


In [None]:
# Collect the per-name attention probabilities and dot products for head 9.9;
# these are the inputs to the correlation computed in the next cell.
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=9, dataset=dataset, S1_is_first=False
)

In [None]:
import scipy.stats as stats

# Pearson correlation between the attention probability on a name and the dot
# product of the head output with that name's unembedding column.
# X and Y should be arrays, lists, or pandas Series
# (here both are the lists returned by get_prob_dot).
correlation, p_value = stats.pearsonr(all_prob, all_dot)

print("Correlation:", correlation)
print("p-value:", p_value)

Correlation: 0.9469951138520497
p-value: 2.7410665403780096e-30
