<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup

In [1]:
%%capture
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
%pip install einops datasets transformers fancy_einsum

In [2]:
from copy import deepcopy
import torch

assert torch.cuda.device_count() == 1
from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

ipython = get_ipython()
if ipython is not None:
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

 Initialise model (use larger N or fewer templates for no warnings about in-template ablation)

In [13]:
model = EasyTransformer.from_pretrained("gpt2-medium").cuda()
# model = EasyTransformer.from_pretrained("gpt2")
model.set_use_attn_result(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Moving model to device:  cuda
Finished loading pretrained model gpt2-medium into EasyTransformer!


# Generate dataset with multiple prompts

In [9]:
class Dataset:
    def __init__(self, prompts, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        # self.io_tokenIDs = [
        #     self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        # ]
        # self.s_tokenIDs = [
        #     self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        # ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [29]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+2),
            'S3': str(i+4),
            'S4': str(i+6),
            'S5': str(i+8),
            'S6': str(i+10),
            # 'corr': str(i+12),
            # 'incorr': str(i+10),
            'text': f"{i} {i+2} {i+4} {i+6} {i+8} {i+10}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 100)
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

# add2 score on last digit of seq (the one to pred next)

In [30]:
def get_addTwo_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_added = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        # for word in words:
        word = words[-1]

        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # get prev member after digit prompt[word]
        targword = str(int(prompt[word]) + 2)

        targToken_in_topK = 'no'
        if " " + targword in pred_tokens or targword in pred_tokens:
            n_right += 1
            words_added.append(prompt[word])
            targToken_in_topK = 'yes'
            # if prompt[word] == '99':
            #     pdb.set_trace()
        # if prompt[word] in pred_tokens_dict:
        #     pdb.set_trace()
        pred_tokens_dict[prompt[word]] = (pred_tokens, targword, targToken_in_topK)

    # percent_right = (n_right / (dataset.N * len(words))) * 100
    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        print(n_right)
        print((dataset.N))
        print(words_added)
        return pred_tokens_dict
    else:
        if percent_right > 0:
            return words_added

## Find add2 heads of all gpt2 med

In [33]:
all_heads = [(layer, head) for layer in range(24) for head in range(16)]
for index, (layer, head) in enumerate(all_heads):
    results = get_addTwo_scores(model, layer, head, dataset, print_tokens=False)
    if results:
        print((layer, head), results)

Head 5.8 (sign=1) : Top 5 accuracy: 1.0101010101010102%
(5, 8) ['29']
Head 6.1 (sign=1) : Top 5 accuracy: 3.0303030303030303%
(6, 1) ['26', '33', '36']
Head 7.2 (sign=1) : Top 5 accuracy: 6.0606060606060606%
(7, 2) ['26', '32', '34', '36', '46', '74']
Head 9.4 (sign=1) : Top 5 accuracy: 15.151515151515152%
(9, 4) ['12', '13', '18', '21', '23', '27', '28', '29', '32', '38', '43', '45', '48', '63', '68']
Head 9.9 (sign=1) : Top 5 accuracy: 15.151515151515152%
(9, 9) ['19', '26', '30', '34', '36', '39', '49', '50', '53', '69', '74', '93', '102', '103', '105']
Head 10.1 (sign=1) : Top 5 accuracy: 6.0606060606060606%
(10, 1) ['19', '26', '29', '32', '33', '34']
Head 10.8 (sign=1) : Top 5 accuracy: 38.38383838383838%
(10, 8) ['12', '14', '17', '19', '20', '25', '26', '30', '32', '34', '39', '41', '44', '45', '49', '50', '51', '52', '53', '62', '66', '69', '70', '71', '74', '76', '79', '82', '84', '87', '94', '99', '102', '103', '104', '106', '107', '109']
Head 11.1 (sign=1) : Top 5 accuracy:

## Find add2 heads of add2 circ

In [31]:
circ_246 =  [(0, 2), (0, 3), (0, 4), (0, 5), (0, 9), (0, 10), (0, 14), (1, 2), (1, 4), (1, 7), (1, 14), (2, 3), (2, 4), (2, 5), (2, 7), (2, 8), (2, 9), (2, 15), (3, 0), (3, 3), (3, 13), (3, 14), (3, 15), (4, 2), (4, 6), (4, 8), (4, 10), (4, 11), (5, 8), (6, 14), (6, 15), (7, 2), (7, 11), (7, 13), (8, 0), (9, 3), (9, 4), (9, 5), (9, 6), (9, 12), (9, 15), (10, 1), (10, 4), (10, 9), (10, 10), (10, 13), (10, 14), (11, 1), (11, 4), (11, 5), (11, 8), (12, 1), (12, 4), (12, 12), (12, 13), (12, 15), (13, 5), (13, 12), (13, 13), (14, 5), (14, 14), (15, 5), (15, 7), (15, 11), (15, 12), (15, 15), (16, 6), (16, 7), (16, 9), (16, 11), (16, 13), (16, 14), (17, 0), (17, 1), (17, 12), (18, 3), (18, 11), (18, 13), (19, 1), (19, 4), (20, 0), (20, 1), (20, 14), (21, 0), (21, 2), (21, 7)]
for index, (layer, head) in enumerate(circ_246):
    print((layer, head), get_addTwo_scores(model, layer, head, dataset, print_tokens=False))

(0, 2) None
(0, 3) None
(0, 4) None
(0, 5) None
(0, 9) None
(0, 10) None
(0, 14) None
(1, 2) None
(1, 4) None
(1, 7) None
(1, 14) None
(2, 3) None
(2, 4) None
(2, 5) None
(2, 7) None
(2, 8) None
(2, 9) None
(2, 15) None
(3, 0) None
(3, 3) None
(3, 13) None
(3, 14) None
(3, 15) None
(4, 2) None
(4, 6) None
(4, 8) None
(4, 10) None
(4, 11) None
Head 5.8 (sign=1) : Top 5 accuracy: 1.0101010101010102%
(5, 8) ['29']
(6, 14) None
(6, 15) None
Head 7.2 (sign=1) : Top 5 accuracy: 6.0606060606060606%
(7, 2) ['26', '32', '34', '36', '46', '74']
(7, 11) None
(7, 13) None
(8, 0) None
(9, 3) None
Head 9.4 (sign=1) : Top 5 accuracy: 15.151515151515152%
(9, 4) ['12', '13', '18', '21', '23', '27', '28', '29', '32', '38', '43', '45', '48', '63', '68']
(9, 5) None
(9, 6) None
(9, 12) None
(9, 15) None
Head 10.1 (sign=1) : Top 5 accuracy: 6.0606060606060606%
(10, 1) ['19', '26', '29', '32', '33', '34']
(10, 4) None
(10, 9) None
(10, 10) None
(10, 13) None
(10, 14) None
Head 11.1 (sign=1) : Top 5 accuracy

In [32]:
get_addTwo_scores(model, 14, 14, dataset)

Head 14.14 (sign=1) : Top 5 accuracy: 57.57575757575758%
57
99
['20', '30', '31', '32', '34', '35', '36', '37', '40', '41', '47', '49', '50', '51', '52', '53', '56', '57', '60', '61', '62', '63', '64', '66', '67', '68', '69', '70', '71', '72', '74', '75', '76', '79', '80', '81', '82', '83', '84', '85', '86', '87', '90', '91', '92', '93', '95', '96', '100', '101', '103', '104', '105', '106', '107', '108', '109']


{'11': ([' 12', ' Eleven', ' 11', ' Fields', '12'], '13', 'no'),
 '12': ([' 13', '13', ' 12', '12', ' thirteen'], '14', 'no'),
 '13': ([' 14', ' Sieg', ' 13', ' 1400', '13'], '15', 'no'),
 '14': ([' 15', '15', ' 14', '14', '1500'], '16', 'no'),
 '15': ([' 16', ' 15', '16', ' 1600', '15'], '17', 'no'),
 '16': ([' 17', '17', ' 1700', ' 16', 'eenth'], '18', 'no'),
 '17': ([' 18', ' 17', ' 1800', ' 179', '17'], '19', 'no'),
 '18': ([' 19', 'eteenth', ' Mae', ' 18', '19'], '20', 'no'),
 '19': ([' 20', ' 19', ' 1919', '20', ' 1920'], '21', 'no'),
 '20': ([' 21', '21', ' 20', '221', ' 22'], '22', 'yes'),
 '21': ([' 22', ' 1921', ' 1922', ' 222', '22'], '23', 'no'),
 '22': ([' 1923', ' 23', ' 223', '23', ' 1922'], '24', 'no'),
 '23': ([' 24', ' 23', ' 1923', '24', ' 1924'], '25', 'no'),
 '24': ([' 25', '25', ' 24', '24', ' 249'], '26', 'no'),
 '25': ([' 26', '26', ' 25', 'NR', '25'], '27', 'no'),
 '26': ([' 27', '27', '26', ' 26', ' 267'], '28', 'no'),
 '27': ([' 28', ' 27', '28', '27', ' 278'

# Compare Copy scores for last pos

In [47]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        # for word in words:
        word = words[-1]

        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        token_in_topK = 'no'
        if " " + prompt[word] in pred_tokens or prompt[word] in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            token_in_topK = 'yes'
        pred_tokens_dict[prompt[word]] = (pred_tokens, token_in_topK)

    # percent_right = (n_right / (dataset.N * len(words))) * 100
    percent_right = (n_right / (dataset.N)) * 100
    print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        return words_moved

In [48]:
for index, (layer, head) in enumerate(circ_246):
    print((layer, head), get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 0.2 (sign=1) : Top 5 accuracy: 0.0%
(0, 2) []
Copy circuit for head 0.3 (sign=1) : Top 5 accuracy: 0.0%
(0, 3) []
Copy circuit for head 0.4 (sign=1) : Top 5 accuracy: 0.0%
(0, 4) []
Copy circuit for head 0.5 (sign=1) : Top 5 accuracy: 0.0%
(0, 5) []
Copy circuit for head 0.9 (sign=1) : Top 5 accuracy: 0.0%
(0, 9) []
Copy circuit for head 0.10 (sign=1) : Top 5 accuracy: 0.0%
(0, 10) []
Copy circuit for head 0.14 (sign=1) : Top 5 accuracy: 0.0%
(0, 14) []
Copy circuit for head 1.2 (sign=1) : Top 5 accuracy: 0.0%
(1, 2) []
Copy circuit for head 1.4 (sign=1) : Top 5 accuracy: 0.0%
(1, 4) []
Copy circuit for head 1.7 (sign=1) : Top 5 accuracy: 0.0%
(1, 7) []
Copy circuit for head 1.14 (sign=1) : Top 5 accuracy: 0.0%
(1, 14) []
Copy circuit for head 2.3 (sign=1) : Top 5 accuracy: 0.0%
(2, 3) []
Copy circuit for head 2.4 (sign=1) : Top 5 accuracy: 0.0%
(2, 4) []
Copy circuit for head 2.5 (sign=1) : Top 5 accuracy: 0.0%
(2, 5) []
Copy circuit for head 2.7 (sign=1) : Top 5

# Input the first token of seq, not last

In [39]:
def get_addTwo_scores_token0(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_added = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        # for word in words:
        word = words[0]  ######### CHANGE TO THIS from last to first! #########

        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # get +2 after digit prompt[word]
        targword = str(int(prompt[word]) + 2)

        targToken_in_topK = 'no'
        if " " + targword in pred_tokens or targword in pred_tokens:
            n_right += 1
            words_added.append(prompt[word])
            targToken_in_topK = 'yes'
        pred_tokens_dict[prompt[word]] = (pred_tokens, targword, targToken_in_topK)

    # percent_right = (n_right / (dataset.N * len(words))) * 100
    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        print(n_right)
        print((dataset.N))
        print(words_added)
        return pred_tokens_dict
    else:
        if percent_right > 0:
            return words_added

In [40]:
get_addTwo_scores_token0(model, 14, 14, dataset)

Head 14.14 (sign=1) : Top 5 accuracy: 8.080808080808081%
8
99
['20', '30', '40', '62', '71', '81', '86', '91']


{'1': (['bara', ' Ai', 'isphere', 'ahime', '²'], '3', 'no'),
 '2': (['3', 'ahime', 'alion', ' 3', 'atform'], '4', 'no'),
 '3': (['4', ' 4', 'daq', 'atform', 'iola'], '5', 'no'),
 '4': (['5', 'baugh', 'BB', 'Balt', 'acea'], '6', 'no'),
 '5': (['cill', ' 6', 'agu', ' brush', '6'], '7', 'no'),
 '6': ([' 7', '7', '�', ' 07', '�'], '8', 'no'),
 '7': ([' 8', ' 808', ' 7', ' Lawson', '8'], '9', 'no'),
 '8': ([' Caldwell', 'acea', ' 8', 'bara', ' 9'], '10', 'no'),
 '9': ([' Abe', ' Ark', ' AA', ' AQ', ' 9'], '11', 'no'),
 '10': (['1111', '11', 'CT', '910', ' 11'], '12', 'no'),
 '11': (['arta', '12', 'IDA', ' 12', ' Dawson'], '13', 'no'),
 '12': (['arta', 'tailed', '13', '12', ' 13'], '14', 'no'),
 '13': ([' 14', '14', 'BW', ' 1914', ' 13'], '15', 'no'),
 '14': (['15', 'arta', '1500', ' 15', 'ayer'], '16', 'no'),
 '15': ([' 16', 'ayson', '16', '1600', 'bara'], '17', 'no'),
 '16': (['17', ' 17', 'alde', 'eenth', '�'], '18', 'no'),
 '17': ([' 18', ' DD', ' Awoken', 'ッド', ' 17'], '19', 'no'),
 '18

# Input the second token of seq, not last

In [41]:
def get_addTwo_scores_token1(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_added = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        # for word in words:
        word = words[1]  ######### CHANGE TO THIS from last to 2nd! #########

        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # get +2 after digit prompt[word]
        targword = str(int(prompt[word]) + 2)

        targToken_in_topK = 'no'
        if " " + targword in pred_tokens or targword in pred_tokens:
            n_right += 1
            words_added.append(prompt[word])
            targToken_in_topK = 'yes'
        pred_tokens_dict[prompt[word]] = (pred_tokens, targword, targToken_in_topK)

    # percent_right = (n_right / (dataset.N * len(words))) * 100
    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        print(n_right)
        print((dataset.N))
        print(words_added)
        return pred_tokens_dict
    else:
        if percent_right > 0:
            return words_added

In [42]:
get_addTwo_scores_token1(model, 14, 14, dataset)

Head 14.14 (sign=1) : Top 5 accuracy: 47.474747474747474%
47
99
['25', '30', '32', '35', '36', '37', '40', '41', '44', '47', '49', '50', '51', '52', '53', '57', '60', '61', '62', '64', '66', '67', '68', '69', '70', '71', '72', '74', '75', '76', '79', '80', '81', '83', '84', '85', '86', '87', '89', '90', '91', '92', '93', '95', '96', '100', '101']


{'3': ([' fourth', ' 4', '4', ' four', ' Fourth'], '5', 'no'),
 '4': (['5', ' 5', ' fifth', ' 4', '五'], '6', 'no'),
 '5': ([' 6', ' sixth', ' brush', ' Sixth', '6'], '7', 'no'),
 '6': ([' 7', ' seventh', '7', ' Seventh', ' 6'], '8', 'no'),
 '7': ([' 7', ' 8', ' seventh', ' VIII', ' eighth'], '9', 'no'),
 '8': ([' 9', ' 8', ' ninth', '9', '889'], '10', 'no'),
 '9': ([' 980', ' 9', 'apo', ' Sapp', ' 10'], '11', 'no'),
 '10': ([' 11', '11', ' 111', '1111', ' 1911'], '12', 'no'),
 '11': ([' 12', ' Eleven', ' 1280', ' 11', ' sidx'], '13', 'no'),
 '12': ([' 13', '13', 'ASE', ' thirteen', ' 12'], '14', 'no'),
 '13': ([' 14', ' Sieg', ' 13', ' 1400', '14'], '15', 'no'),
 '14': (['15', ' 15', ' 14', '14', '1500'], '16', 'no'),
 '15': ([' 16', '16', ' 1600', ' 15', ' sixteen'], '17', 'no'),
 '16': ([' 17', '17', ' 1700', ' 16', ' seventeen'], '18', 'no'),
 '17': ([' 18', ' 17', ' 1800', ' 179', '17'], '19', 'no'),
 '18': ([' 19', ' 18', 'eteenth', ' Hammond', '19'], '20', 'no'),
 '19': ([' 19', 

Perhaps first token has a space char in front? This is more similar to 'last'.

# Input the all tokens of seq

In [45]:
def get_addTwo_scores_allPos(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_added = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:  ######### CHANGE TO USE THIS #########
        # word = words[1]

            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get +2 after digit prompt[word]
            targword = str(int(prompt[word]) + 2)

            targToken_in_topK = 'no'
            if " " + targword in pred_tokens or targword in pred_tokens:
                n_right += 1
                words_added.append(prompt[word])
                targToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, targword, targToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    # percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        print(n_right)
        print((dataset.N))
        print(words_added)
        return pred_tokens_dict
    else:
        if percent_right > 0:
            return words_added

In [46]:
get_addTwo_scores_allPos(model, 14, 14, dataset)

Head 14.14 (sign=1) : Top 5 accuracy: 44.78114478114478%
266
99
['20', '20', '20', '20', '20', '30', '25', '31', '30', '32', '25', '30', '32', '34', '35', '30', '32', '34', '36', '35', '37', '30', '32', '36', '35', '37', '30', '32', '34', '36', '40', '35', '37', '41', '36', '40', '35', '37', '41', '36', '40', '37', '41', '40', '41', '47', '40', '41', '47', '49', '40', '44', '50', '47', '49', '51', '44', '50', '52', '47', '49', '51', '53', '50', '52', '47', '49', '51', '53', '50', '52', '56', '49', '51', '53', '57', '50', '52', '56', '51', '53', '57', '52', '56', '60', '53', '57', '61', '56', '60', '62', '57', '61', '63', '60', '62', '64', '57', '61', '60', '62', '64', '66', '61', '67', '60', '62', '64', '66', '68', '61', '67', '69', '62', '64', '66', '68', '70', '67', '69', '71', '62', '64', '66', '68', '70', '72', '67', '69', '71', '66', '68', '70', '72', '74', '67', '69', '71', '75', '68', '70', '72', '74', '76', '69', '71', '75', '70', '72', '74', '76', '71', '75', '79', '72', '74',

{'1': (['bara', ' Ai', 'isphere', 'ahime', '²'], '3', 'no'),
 '3': (['4', ' 4', 'daq', 'atform', 'iola'], '5', 'no'),
 '5': (['cill', ' 6', 'agu', ' brush', '6'], '7', 'no'),
 '7': ([' 8', ' 808', ' 7', ' Lawson', '8'], '9', 'no'),
 '9': ([' Abe', ' Ark', ' AA', ' AQ', ' 9'], '11', 'no'),
 '11': (['arta', '12', 'IDA', ' 12', ' Dawson'], '13', 'no'),
 '2': (['3', 'ahime', 'alion', ' 3', 'atform'], '4', 'no'),
 '4': (['5', 'baugh', 'BB', 'Balt', 'acea'], '6', 'no'),
 '6': ([' 7', '7', '�', ' 07', '�'], '8', 'no'),
 '8': ([' Caldwell', 'acea', ' 8', 'bara', ' 9'], '10', 'no'),
 '10': (['1111', '11', 'CT', '910', ' 11'], '12', 'no'),
 '12': (['arta', 'tailed', '13', '12', ' 13'], '14', 'no'),
 '13': ([' 14', '14', 'BW', ' 1914', ' 13'], '15', 'no'),
 '14': (['15', 'arta', '1500', ' 15', 'ayer'], '16', 'no'),
 '15': ([' 16', 'ayson', '16', '1600', 'bara'], '17', 'no'),
 '16': (['17', ' 17', 'alde', 'eenth', '�'], '18', 'no'),
 '17': ([' 18', ' DD', ' Awoken', 'ッド', ' 17'], '19', 'no'),
 '18

# TBC

The cells below have not been updated yet for next scores so disregard them:

---



## Writing direction results with scatterplot

In [None]:
def scatter_attention_and_contribution(
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Plot a scatter plot
    for each input sequence with the attention paid to S
    and the amount that is written in the S directions
    """

    n_heads = model.cfg.n_heads
    n_layers = model.cfg.n_layers
    model_unembed = model.unembed.W_U.detach().cpu()
    df = []
    cache = {}
    model.cache_all(cache)

    logits = model(dataset.toks.long())

    for i, prompt in enumerate(dataset.prompts):
        s_toks = []
        s_positions = []
        s_dirs = []

        targ_tokens = [key for key in dataset.prompts[0].keys() if key != 'text']
        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # only use this if first token doesn't have space Ġ in front
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]
            s_toks.append(s_tok)

            toks = model.tokenizer(prompt["text"])["input_ids"]
            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            s_pos = toks.index(s_tok)
            s_positions.append(s_pos)

            s_dir = model_unembed[:, s_tok].detach()
            s_dirs.append(s_dir)

        for dire, posses, tok_type in zip(s_dirs, s_positions, targ_tokens):
            prob = sum(
                [
                    cache[f"blocks.{layer_no}.attn.hook_attn"][
                        i, head_no, dataset.word_idx["end"][i], pos
                    ]
                    .detach()
                    .cpu()
                    for pos in [posses]
                ]
            )
            resid = (
                cache[f"blocks.{layer_no}.attn.hook_result"][
                    i, dataset.word_idx["end"][i], head_no, :
                ]
                .detach()
                .cpu()
            )
            dot = torch.einsum("a,a->", resid, dire)
            df.append([prob, dot, tok_type, prompt["text"]])

    viz_df = pd.DataFrame(
        df, columns=[f"Attn Prob on Number", f"Dot w Number Embed", "Seq Position", "text"]
    )
    fig = px.scatter(
        viz_df,
        x=f"Attn Prob on Number",
        y=f"Dot w Number Embed",
        color="Seq Position",
        hover_data=["text"],
        title=f"How Strong {layer_no}.{head_no} Writes in the Number Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    else:
        fig.show()

In [None]:
scatter_attention_and_contribution(
    model=model, layer_no=9, head_no=1, dataset=dataset, S1_is_first=False
)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

## Correlation vals

In [None]:
def get_prob_dot(  # same as scatterplot, but output x and y vals instead of plotting
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Plot a scatter plot
    for each input sequence with the attention paid to S
    and the amount that is written in the S directions
    """

    n_heads = model.cfg.n_heads
    n_layers = model.cfg.n_layers
    model_unembed = model.unembed.W_U.detach().cpu()
    # df = []
    all_prob = []
    all_dot = []
    cache = {}
    model.cache_all(cache)

    logits = model(dataset.toks.long())

    for i, prompt in enumerate(dataset.prompts):
        s_toks = []
        s_positions = []
        s_dirs = []

        targ_tokens = [key for key in dataset.prompts[0].keys() if key != 'text']
        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # only use this if first token doesn't have space Ġ in front
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]
            s_toks.append(s_tok)

            toks = model.tokenizer(prompt["text"])["input_ids"]
            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            s_pos = toks.index(s_tok)
            s_positions.append(s_pos)

            s_dir = model_unembed[:, s_tok].detach()
            s_dirs.append(s_dir)

        for dire, posses, tok_type in zip(s_dirs, s_positions, targ_tokens):
            prob = sum(
                [
                    cache[f"blocks.{layer_no}.attn.hook_attn"][
                        i, head_no, dataset.word_idx["end"][i], pos
                    ]
                    .detach()
                    .cpu()
                    for pos in [posses]
                ]
            )
            resid = (
                cache[f"blocks.{layer_no}.attn.hook_result"][
                    i, dataset.word_idx["end"][i], head_no, :
                ]
                .detach()
                .cpu()
            )
            dot = torch.einsum("a,a->", resid, dire)
            #df.append([prob, dot, tok_type, prompt["text"]])
            all_prob.append(prob)
            all_dot.append(dot)

    return all_prob, all_dot


In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=9, dataset=dataset, S1_is_first=False
)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
import scipy.stats as stats

# X and Y should be arrays, lists, or pandas Series
correlation, p_value = stats.pearsonr(all_prob, all_dot)

print("Correlation:", correlation)
print("p-value:", p_value)

Correlation: 0.8127540109290008
p-value: 8.99435128806603e-70


In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=1, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=7, head_no=10, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=5, head_no=1, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=0, head_no=3, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen