<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup

In [1]:
%%capture
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
%pip install einops datasets transformers fancy_einsum

In [2]:
from copy import deepcopy
import torch

assert torch.cuda.device_count() == 1
from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

# Turn on autoreload so edits to imported modules are picked up without restarting
# the kernel. The None check makes this a no-op when no IPython shell is active.
ipython = get_ipython()
if ipython is not None:
    # `InteractiveShell.magic` is deprecated; `run_line_magic` is the supported call.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Initialise the model. (Use a larger N or fewer templates to avoid warnings about in-template ablation.)

In [3]:
# Load pretrained GPT-2 into EasyTransformer and move it to the GPU.
model = EasyTransformer.from_pretrained("gpt2").cuda()
# CPU variant, kept for reference (the import cell asserts exactly one CUDA device):
# model = EasyTransformer.from_pretrained("gpt2")
# NOTE(review): presumably required for the "blocks.{layer}.attn.hook_result" cache
# entries read later in this notebook -- confirm against EasyTransformer.
model.set_use_attn_result(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Moving model to device:  cuda
Finished loading pretrained model gpt2 into EasyTransformer!


# Generate dataset with multiple prompts

In [4]:
def generate_prompts_list():
    """Build the sliding-window month prompts.

    Each prompt is a run of four consecutive month names. Windows start at
    January through August, so eight prompts are produced and the last one is
    "August September October November".

    Returns:
        list[dict]: one dict per window with keys ``'S1'``..``'S4'`` (the four
        month names, in order) and ``'text'`` (the names joined by single spaces).
    """
    month_names = (
        "January February March April May June "
        "July August September October November December"
    ).split()
    window_len = 4
    n_windows = 8

    def _as_prompt(start):
        # Slice one window and label its members S1..S4 before adding the text.
        window = month_names[start:start + window_len]
        prompt = {f"S{pos}": name for pos, name in enumerate(window, start=1)}
        prompt["text"] = " ".join(window)
        return prompt

    return [_as_prompt(start) for start in range(n_windows)]

# Materialise the prompt dicts once and echo them so their layout can be inspected.
prompts_list = generate_prompts_list()
print(prompts_list)

[{'S1': 'January', 'S2': 'February', 'S3': 'March', 'S4': 'April', 'text': 'January February March April'}, {'S1': 'February', 'S2': 'March', 'S3': 'April', 'S4': 'May', 'text': 'February March April May'}, {'S1': 'March', 'S2': 'April', 'S3': 'May', 'S4': 'June', 'text': 'March April May June'}, {'S1': 'April', 'S2': 'May', 'S3': 'June', 'S4': 'July', 'text': 'April May June July'}, {'S1': 'May', 'S2': 'June', 'S3': 'July', 'S4': 'August', 'text': 'May June July August'}, {'S1': 'June', 'S2': 'July', 'S3': 'August', 'S4': 'September', 'text': 'June July August September'}, {'S1': 'July', 'S2': 'August', 'S3': 'September', 'S4': 'October', 'text': 'July August September October'}, {'S1': 'August', 'S2': 'September', 'S3': 'October', 'S4': 'November', 'text': 'August September October November'}]


In [5]:
class Dataset:
    """Tokenised prompts plus, for each prompt, the token index of every target word.

    Attributes set by ``__init__``:
        prompts: the prompt dicts as passed in. Each has a ``"text"`` entry; every
            other key (taken from the first prompt) names a target word.
        tokenizer: the tokenizer as passed in.
        N: number of prompts.
        toks: ``torch.int`` tensor built from
            ``tokenizer(texts, padding=True).input_ids``.
        word_idx: maps each target key, plus ``"end"``, to a 1-D tensor holding
            one token index per prompt.
    """

    def __init__(self, prompts, tokenizer, S1_is_first=False):
        """Tokenise ``prompts`` and locate each target word inside its prompt.

        Args:
            prompts: list of dicts, each with a ``"text"`` entry and one entry per
                target word. The target keys are read from the first prompt.
            tokenizer: object callable as ``tokenizer(texts, padding=True)`` (the
                result exposes ``.input_ids``) that also provides
                ``.tokenize(text)``.
            S1_is_first: set True when ``S1`` is the first word of the text, so
                its token carries no leading-space ``"Ġ"`` marker.

        Raises:
            ValueError: if a target word is not found as a single token of its
                prompt (for example when ``S1_is_first`` does not match the
                prompt layout).
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the id tensor as int directly instead of round-tripping via float.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # Tokenise each prompt once with the tokenizer this dataset was given
        # (not the notebook-global `model`) and reuse the result for every lookup.
        tokens_per_prompt = [self.tokenizer.tokenize(text) for text in texts]

        # word_idx[key][i] is the index, within prompt i, of that prompt's `key` token.
        self.word_idx = {}
        target_keys = [key for key in self.prompts[0].keys() if key != "text"]
        for targ in target_keys:
            positions = []
            for prompt, tokens in zip(self.prompts, tokens_per_prompt):
                if S1_is_first and targ == "S1":
                    # First word of the text: no leading-space marker on its token.
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                try:
                    positions.append(tokens.index(target_token))
                except ValueError as err:
                    raise ValueError(
                        f"Target {targ}={prompt[targ]!r} (token {target_token!r}) "
                        f"not found in tokens {tokens} of prompt {prompt['text']!r}"
                    ) from err
            self.word_idx[targ] = torch.tensor(positions)

        # "end" is the last token index of each prompt's own tokenisation
        # (computed from `tokenize`, so padding is not counted).
        self.word_idx["end"] = torch.tensor(
            [len(tokens) - 1 for tokens in tokens_per_prompt]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [6]:
# S1_is_first=True: every prompt text starts with S1, so its token has no leading-space "Ġ" marker.
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

# Copy score

In [7]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Score how often head ``layer.head`` maps a target word back onto itself.

    The block-0 residual output is normalised with ``model.blocks[1].attn.ln1``,
    pushed through the head's value projection (``W_V``/``b_V``) and output
    projection (``W_O``), then through ``ln_final`` and the unembedding. For each
    prompt and each target key, the top-5 decoded tokens at that word's position
    are checked for the word itself, prefixed with a space.

    Args:
        model: model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer.decode`` as used below.
        layer: index of the block whose head is analysed.
        head: index of the head within that block.
        dataset: object with ``prompts``, ``toks``, ``word_idx`` and ``N``.
        verbose: accepted for compatibility; not used.
        neg: if True, the head's output is negated before unembedding.
        print_tokens: if truthy return the per-word top tokens, otherwise return
            the list of words that were found in their own top-5.

    Returns:
        dict | list: with ``print_tokens`` a dict mapping word -> top-5 decoded
        tokens (a word occurring in several prompts keeps only its last entry);
        otherwise the list of matched words, one entry per match.

    Side effects:
        Prints the top-5 accuracy as a percentage.
    """
    sign = -1 if neg else 1

    cache = {}
    # NOTE(review): the hook registered here is not removed afterwards; presumably
    # each call adds another one -- confirm against EasyTransformer.
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")

    # Pure inference: no autograd graph is needed for any of the projections below.
    with torch.no_grad():
        model(dataset.toks.long())
        z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

        v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
        v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

        o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
        logits = model.unembed(model.ln_final(o))

    k = 5  # size of the top-k window inspected at every target position
    # Target keys are read from the first prompt; "text" is the full prompt string.
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    n_right = 0
    pred_tokens_dict = {}
    words_moved = []
    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            position = dataset.word_idx[word][seq_idx]
            top_ids = torch.topk(logits[seq_idx, position], k).indices
            pred_tokens = [model.tokenizer.decode(token) for token in top_ids]
            pred_tokens_dict[prompt[word]] = pred_tokens
            # A hit means the space-prefixed word is among the head's top-k outputs.
            if " " + prompt[word] in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])

    percent_right = (n_right / (dataset.N * len(words))) * 100
    print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        return pred_tokens_dict
    return words_moved

## Get important heads

Find what heads are specific to certain inputs, and what's common to the template.

Get the important heads from `circuit_expms_template.ipynb` (section "Print top heads"): copy the output of `top_indices` and reformat it onto a single line (e.g. using ChatGPT).

NOTE: Not all attention heads simply copy, so use the attention patterns to determine which heads actually copy, and refine this list of heads accordingly.

(E.g., if you copy all the top heads from IOI, only 9.9 and 10.0 are name movers, while the other heads are "S-inhibition", "induction", or "duplicate" heads, so only the name movers and backup name movers (NM) will have top accuracy.)

In [8]:
# (layer, head) pairs to score; per the markdown above, these come from the top
# heads reported by circuit_expms_template.ipynb.
top_val = [(0, 10), (0, 1), (5,5), (6,1), (7, 10), (8,8), (7,11), (8,11), (9,1), (9,5), (10,7)]
# print_tokens=False makes each call return the list of words found in their own top-5.
for index, (layer, head) in enumerate(top_val):
    print(index, get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 0.10 (sign=1) : Top 5 accuracy: 0.0%
0 []
Copy circuit for head 0.1 (sign=1) : Top 5 accuracy: 0.0%
1 []
Copy circuit for head 5.5 (sign=1) : Top 5 accuracy: 0.0%
2 []
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 3.125%
3 ['August']
Copy circuit for head 7.10 (sign=1) : Top 5 accuracy: 71.875%
4 ['February', 'March', 'April', 'March', 'April', 'May', 'April', 'May', 'June', 'May', 'June', 'July', 'July', 'August', 'July', 'August', 'September', 'August', 'September', 'October', 'September', 'October', 'November']
Copy circuit for head 8.8 (sign=1) : Top 5 accuracy: 6.25%
5 ['March', 'November']
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 65.625%
6 ['February', 'March', 'April', 'March', 'April', 'May', 'May', 'June', 'May', 'June', 'July', 'June', 'July', 'August', 'July', 'August', 'September', 'August', 'September', 'September', 'November']
Copy circuit for head 8.11 (sign=1) : Top 5 accuracy: 65.625%
7 ['February', 'March', 'April', 'March', '

Look at heads with strong pos on heatmap

In [10]:
# Default print_tokens=True: show the top-5 decoded tokens per word for head 9.1.
get_copy_scores(model, 9, 1, dataset)

Copy circuit for head 9.1 (sign=1) : Top 5 accuracy: 0.0%


{'January': [' once', 'once', ' occasional', ' seventh', '296'],
 'February': [' third', ' fourth', ' once', ' seventh', ' fifth'],
 'March': [' fifth', ' seventh', ' fourth', ' sixth', ' third'],
 'April': [' seventh', ' fifth', ' sixth', 'ISC', ' five'],
 'May': [' seventh', ' sixth', 'ISC', ' once', ' 157'],
 'June': [' seventh', ' third', ' seven', ' sixth', 'seven'],
 'July': [' seventh', ' eighth', 'seven', ' once', ' seven'],
 'August': ['ighth', ' eighth', ' ninth', ' final', ' occasional'],
 'September': [' 120', 'ure', 'Loading', '�', ' Nguyen'],
 'October': [' Nur', 'undrum', '�', '�', ' 121'],
 'November': ['oor', 'ة', ' 122', 'Enlarge', ' Nur']}

In [11]:
# Default print_tokens=True: show the top-5 decoded tokens per word for head 4.4.
get_copy_scores(model, 4, 4, dataset)

Copy circuit for head 4.4 (sign=1) : Top 5 accuracy: 0.0%


{'January': ['onom', 'osc', 'natureconservancy', 'ovie', ' somew'],
 'February': ['onom', 'ovie', 'natureconservancy', 'osc', 'icult'],
 'March': ['onom', 'osc', ' somew', ' Canaver', ' physic'],
 'April': ['onom', 'natureconservancy', 'ovie', ' Borders', 'osc'],
 'May': ['Redd', 'natureconservancy', 'models', 'Maps', ' Craigslist'],
 'June': ['onom', 'natureconservancy', 'ovie', 'ベ', 'osc'],
 'July': ['onom', 'ovie', 'natureconservancy', 'osc', ' Wiki'],
 'August': ['onom', 'natureconservancy', ' Canaver', 'osc', 'onomic'],
 'September': ['rel', 'aimon', ' Ludwig', ' inher', ' assistants'],
 'October': ['rel', ' Zin', ' snakes', ' under', ' roommate'],
 'November': ['rel', ' handles', 'oof', ' Zin', 'ologist']}

# Writing direction results with scatterplot

In [None]:
# NOTE: this definition shadows the `scatter_attention_and_contribution` imported
# from easy_transformer.ioi_utils in the import cell.
def scatter_attention_and_contribution(
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Scatter, for head ``layer_no.head_no``, attention paid to each target token
    against how strongly the head writes in that token's unembedding direction.

    For every prompt and every target key (all prompt keys except ``'text'``),
    one point is produced:
      * x: attention probability from the prompt's "end" position to the target
        token's position;
      * y: dot product between the head's output at the "end" position and the
        target token's column of ``model.unembed.W_U``.

    A target whose token id does not occur in the tokenised prompt is reported
    with a printed message and skipped. Every remaining point keeps its own
    target key as the "Seq Position" label.

    Args:
        model: model exposing ``cache_all``, ``tokenizer`` and ``unembed.W_U``.
        layer_no: block index of the head.
        head_no: head index within the block.
        dataset: object with ``prompts``, ``toks`` and ``word_idx["end"]``.
        S1_is_first: if True, the ``S1`` token is looked up without a leading
            space (use when S1 is the first word of the prompt text).
        return_vals: if True, return the DataFrame of points; no figure is built.
        return_fig: if True (and ``return_vals`` is False), return the figure
            instead of showing it.

    Returns:
        The DataFrame, the plotly figure, or None (figure shown), per the flags.
    """
    model_unembed = model.unembed.W_U.detach().cpu()
    attn_key = f"blocks.{layer_no}.attn.hook_attn"
    result_key = f"blocks.{layer_no}.attn.hook_result"

    cache = {}
    model.cache_all(cache)
    model(dataset.toks.long())

    # Target keys are read from the first prompt; "text" is the full prompt string.
    targ_tokens = (
        [key for key in dataset.prompts[0].keys() if key != 'text']
        if len(dataset.prompts) > 0
        else []
    )

    rows = []
    for i, prompt in enumerate(dataset.prompts):
        toks = model.tokenizer(prompt["text"])["input_ids"]
        end_pos = dataset.word_idx["end"][i]

        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # first word: token has no leading space
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]

            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            # Label, position and direction are handled together for this target,
            # so a skipped target can never shift the labels of the others.
            prob = cache[attn_key][i, head_no, end_pos, s_pos].detach().cpu()
            resid = cache[result_key][i, end_pos, head_no, :].detach().cpu()
            dire = model_unembed[:, s_tok]
            dot = torch.einsum("a,a->", resid, dire)
            # Plain floats keep the DataFrame columns numeric.
            rows.append([prob.item(), dot.item(), s_id, prompt["text"]])

    viz_df = pd.DataFrame(
        rows, columns=["Attn Prob on Month", "Dot w Month Embed", "Seq Position", "text"]
    )
    if return_vals:
        return viz_df

    fig = px.scatter(
        viz_df,
        x="Attn Prob on Month",
        y="Dot w Month Embed",
        color="Seq Position",
        hover_data=["text"],
        title=f"How Strong {layer_no}.{head_no} Writes in the Month Embed Direction Relative to Attn Prob",
    )
    if return_fig:
        return fig
    fig.show()

In [None]:
# NOTE(review): `dataset` was built with S1_is_first=True, but False is passed here,
# so S1 is looked up with a leading space, is not found, and is skipped for every
# prompt (see the "Skipping..." output). Confirm whether excluding S1 is intended.
scatter_attention_and_contribution(
    model=model, layer_no=9, head_no=1, dataset=dataset, S1_is_first=False
)

3269 is not present in [21339, 3945, 2805, 3035]. Skipping...
3945 is not present in [21816, 2805, 3035, 1737]. Skipping...
2805 is not present in [16192, 3035, 1737, 2795]. Skipping...
3035 is not present in [16784, 1737, 2795, 2901]. Skipping...
1737 is not present in [6747, 2795, 2901, 2932]. Skipping...
2795 is not present in [15749, 2901, 2932, 2693]. Skipping...
2901 is not present in [16157, 2932, 2693, 3267]. Skipping...
2932 is not present in [17908, 2693, 3267, 3389]. Skipping...


## Correlation vals

In [None]:
def get_prob_dot(  # same as scatterplot, but output x and y vals instead of plotting
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Collect, for head ``layer_no.head_no``, the attention paid to each target
    token and how strongly the head writes in that token's unembedding direction.

    Nothing is plotted. For every prompt and every target key (all prompt keys
    except ``'text'``) one pair of values is produced:
      * prob: attention probability from the prompt's "end" position to the
        target token's position;
      * dot: dot product between the head's output at the "end" position and the
        target token's column of ``model.unembed.W_U``.

    A target whose token id does not occur in the tokenised prompt is reported
    with a printed message and contributes no pair.

    Args:
        model: model exposing ``cache_all``, ``tokenizer`` and ``unembed.W_U``.
        layer_no: block index of the head.
        head_no: head index within the block.
        dataset: object with ``prompts``, ``toks`` and ``word_idx["end"]``.
        S1_is_first: if True, the ``S1`` token is looked up without a leading
            space (use when S1 is the first word of the prompt text).
        return_vals: accepted for signature compatibility; not used.
        return_fig: accepted for signature compatibility; not used.

    Returns:
        tuple[list, list]: ``(all_prob, all_dot)``, two equally long lists of
        0-d CPU tensors in prompt order, then target-key order.
    """
    model_unembed = model.unembed.W_U.detach().cpu()
    attn_key = f"blocks.{layer_no}.attn.hook_attn"
    result_key = f"blocks.{layer_no}.attn.hook_result"

    cache = {}
    model.cache_all(cache)
    model(dataset.toks.long())

    # Target keys are read from the first prompt; "text" is the full prompt string.
    targ_tokens = (
        [key for key in dataset.prompts[0].keys() if key != 'text']
        if len(dataset.prompts) > 0
        else []
    )

    all_prob = []
    all_dot = []
    for i, prompt in enumerate(dataset.prompts):
        # Tokenise the prompt once; every target lookup below reuses it.
        toks = model.tokenizer(prompt["text"])["input_ids"]
        end_pos = dataset.word_idx["end"][i]

        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # first word: token has no leading space
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]

            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            prob = cache[attn_key][i, head_no, end_pos, s_pos].detach().cpu()
            resid = cache[result_key][i, end_pos, head_no, :].detach().cpu()
            dire = model_unembed[:, s_tok]
            all_prob.append(prob)
            all_dot.append(torch.einsum("a,a->", resid, dire))

    return all_prob, all_dot


In [None]:
# Attention-vs-write values for head 9.9.
# NOTE(review): `dataset` was built with S1_is_first=True, but False is passed here,
# so S1 is not found and is skipped for every prompt (see the "Skipping..." output).
# Confirm whether excluding S1 is intended.
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=9, dataset=dataset, S1_is_first=False
)

3269 is not present in [21339, 3945, 2805, 3035]. Skipping...
3945 is not present in [21816, 2805, 3035, 1737]. Skipping...
2805 is not present in [16192, 3035, 1737, 2795]. Skipping...
3035 is not present in [16784, 1737, 2795, 2901]. Skipping...
1737 is not present in [6747, 2795, 2901, 2932]. Skipping...
2795 is not present in [15749, 2901, 2932, 2693]. Skipping...
2901 is not present in [16157, 2932, 2693, 3267]. Skipping...
2932 is not present in [17908, 2693, 3267, 3389]. Skipping...


In [None]:
# NOTE(review): this import sits mid-notebook; the later correlation cells rely on
# `stats` being defined by this cell, so it must run before them.
import scipy.stats as stats

# X and Y should be arrays, lists, or pandas Series
# Pearson correlation between attention probability and write strength, using the
# values from the most recent get_prob_dot call.
correlation, p_value = stats.pearsonr(all_prob, all_dot)

print("Correlation:", correlation)
print("p-value:", p_value)

Correlation: 0.9307392782710868
p-value: 4.390668557761558e-11


In [None]:
# Same measurement for head 9.1; overwrites all_prob / all_dot from the previous cell.
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=1, dataset=dataset, S1_is_first=False
)

# Relies on `stats` imported in an earlier cell.
correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

3269 is not present in [21339, 3945, 2805, 3035]. Skipping...
3945 is not present in [21816, 2805, 3035, 1737]. Skipping...
2805 is not present in [16192, 3035, 1737, 2795]. Skipping...
3035 is not present in [16784, 1737, 2795, 2901]. Skipping...
1737 is not present in [6747, 2795, 2901, 2932]. Skipping...
2795 is not present in [15749, 2901, 2932, 2693]. Skipping...
2901 is not present in [16157, 2932, 2693, 3267]. Skipping...
2932 is not present in [17908, 2693, 3267, 3389]. Skipping...
Correlation: 0.7485558896583466
p-value: 2.5860397031824707e-05


In [None]:
# Same measurement for head 7.10; overwrites all_prob / all_dot from the previous cell.
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=7, head_no=10, dataset=dataset, S1_is_first=False
)

# Relies on `stats` imported in an earlier cell.
correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

3269 is not present in [21339, 3945, 2805, 3035]. Skipping...
3945 is not present in [21816, 2805, 3035, 1737]. Skipping...
2805 is not present in [16192, 3035, 1737, 2795]. Skipping...
3035 is not present in [16784, 1737, 2795, 2901]. Skipping...
1737 is not present in [6747, 2795, 2901, 2932]. Skipping...
2795 is not present in [15749, 2901, 2932, 2693]. Skipping...
2901 is not present in [16157, 2932, 2693, 3267]. Skipping...
2932 is not present in [17908, 2693, 3267, 3389]. Skipping...
Correlation: 0.9170188900139262
p-value: 3.0041050057793885e-10


In [None]:
# Same measurement for head 5.1; overwrites all_prob / all_dot from the previous cell.
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=5, head_no=1, dataset=dataset, S1_is_first=False
)

# Relies on `stats` imported in an earlier cell.
correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

3269 is not present in [21339, 3945, 2805, 3035]. Skipping...
3945 is not present in [21816, 2805, 3035, 1737]. Skipping...
2805 is not present in [16192, 3035, 1737, 2795]. Skipping...
3035 is not present in [16784, 1737, 2795, 2901]. Skipping...
1737 is not present in [6747, 2795, 2901, 2932]. Skipping...
2795 is not present in [15749, 2901, 2932, 2693]. Skipping...
2901 is not present in [16157, 2932, 2693, 3267]. Skipping...
2932 is not present in [17908, 2693, 3267, 3389]. Skipping...
Correlation: 0.714348201909229
p-value: 8.814064668526422e-05


In [None]:
# Same measurement for head 0.3; overwrites all_prob / all_dot from the previous cell.
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=0, head_no=3, dataset=dataset, S1_is_first=False
)

# Relies on `stats` imported in an earlier cell.
correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

3269 is not present in [21339, 3945, 2805, 3035]. Skipping...
3945 is not present in [21816, 2805, 3035, 1737]. Skipping...
2805 is not present in [16192, 3035, 1737, 2795]. Skipping...
3035 is not present in [16784, 1737, 2795, 2901]. Skipping...
1737 is not present in [6747, 2795, 2901, 2932]. Skipping...
2795 is not present in [15749, 2901, 2932, 2693]. Skipping...
2901 is not present in [16157, 2932, 2693, 3267]. Skipping...
2932 is not present in [17908, 2693, 3267, 3389]. Skipping...
Correlation: 0.12830774329431274
p-value: 0.5501739536565748
