<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup

In [None]:
%%capture
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
%pip install einops datasets transformers fancy_einsum

In [None]:
from copy import deepcopy
import torch

assert torch.cuda.device_count() == 1
from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

# Enable autoreload when running inside IPython/Jupyter so that edited modules
# are picked up without restarting the kernel. Nothing happens when
# get_ipython() yields None.
ipython = get_ipython()
if ipython is not None:
    # run_line_magic(name, args) is the supported API; InteractiveShell.magic()
    # is deprecated.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Initialise the model (use a larger N or fewer templates to avoid warnings about in-template ablation).

In [None]:
# Load the pretrained "gpt2" weights into an EasyTransformer and move it to the GPU.
model = EasyTransformer.from_pretrained("gpt2").cuda()
# CPU alternative (note: the import cell asserts that exactly one CUDA device exists):
# model = EasyTransformer.from_pretrained("gpt2")
# NOTE(review): presumably exposes per-head attention outputs to hooks (a later
# cell reads "blocks.{layer}.attn.hook_result") - confirm against Easy-Transformer.
model.set_use_attn_result(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Moving model to device:  cuda
Finished loading pretrained model gpt2 into EasyTransformer!


In [None]:
import pdb

# Generate dataset with multiple prompts

In [None]:
def generate_prompts_list(x, y, seq_len=10):
    """Build prompts that each contain a run of consecutive descending integers.

    One prompt is produced for every start value ``i`` in ``range(x, y, -1)``
    (``x`` inclusive, ``y`` exclusive, counting down).

    Args:
        x: First (largest) start value.
        y: Stop value; start values stay strictly greater than ``y``.
        seq_len: Number of integers per prompt. Defaults to 10, which
            reproduces the original fixed ``S1`` .. ``S10`` layout.

    Returns:
        A list of dicts. Each dict maps ``"S1"`` .. ``"S<seq_len>"`` to the
        string form of ``i``, ``i-1``, ..., in that key order, followed by
        ``"text"``: the same numbers joined by single spaces.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    prompts_list = []
    for i in range(x, y, -1):
        members = [str(i - offset) for offset in range(seq_len)]
        # Insertion order matters: downstream code treats the first/last
        # non-"text" key as the first/last sequence position.
        prompt_dict = {f"S{pos}": member for pos, member in enumerate(members, start=1)}
        prompt_dict["text"] = " ".join(members)
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(101, 11)

In [None]:
class Dataset:
    """Tokenized prompts plus, for each target key, its token position per prompt.

    Attributes set by ``__init__``:
        prompts: The prompt dicts that were passed in.
        tokenizer: The tokenizer that was passed in.
        N: Number of prompts.
        max_len: Largest ``input_ids`` length over the individually tokenized prompts.
        groups: List of index arrays, one per template; there is a single
            template here, so it holds one array covering every prompt.
        toks: Integer tensor of the padded ``input_ids`` for all prompts.
        word_idx: Dict mapping each target key (every prompt key except
            ``"text"`` and ``"S11"``) and ``"end"`` to a tensor holding, per
            prompt, the token index of that target / of the last token.
    """

    def __init__(self, prompts, tokenizer, S1_is_first=False):
        """Tokenize ``prompts`` and locate every target token.

        Args:
            prompts: Non-empty list of dicts with a ``"text"`` entry and one
                entry per target (e.g. ``"S1"``), whose value is the target's text.
            tokenizer: Callable tokenizer whose result exposes ``input_ids``
                and which also provides ``tokenize(text)``.
            S1_is_first: If True, ``"S1"`` is looked up without the ``Ġ``
                space prefix (use when it is the first token of the text).

        Raises:
            ValueError: If a target token is not found among a prompt's tokens
                (raised by ``list.index``).
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        self.max_len = max(len(self.tokenizer(text).input_ids) for text in texts)

        # Only one template exists, so every prompt gets template id 0.
        template_ids = np.zeros(self.N, dtype=int)
        self.groups = [
            np.where(template_ids == template_id)[0]
            for template_id in np.unique(template_ids)
        ]

        # Build the integer tensor directly instead of going through a float tensor.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # Tokenize each prompt once, with the tokenizer given to this instance
        # (not a notebook-global model), and reuse it for every target key.
        tokenized = [self.tokenizer.tokenize(text) for text in texts]

        # word_idx[targ][i] is the index of the first token in prompt i that
        # equals the target token.
        self.word_idx = {}
        targets = [key for key in self.prompts[0].keys() if (key != 'text' and key != 'S11')]
        for targ in targets:
            positions = []
            for prompt, tokens in zip(self.prompts, tokenized):
                if S1_is_first and targ == "S1":  # first token has no space marker Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                positions.append(tokens.index(target_token))
            self.word_idx[targ] = torch.tensor(positions)

        # "end" is the index of the last token of each prompt.
        self.word_idx["end"] = torch.tensor([len(tokens) - 1 for tokens in tokenized])

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [None]:
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

# Prev score

In [None]:
import pdb

In [None]:
def get_prev_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Score how often one head maps the last sequence number to its predecessor.

    The cached ``blocks.0.hook_resid_post`` activations are passed through
    block 1's ``ln1``, the chosen head's ``W_V``/``b_V`` and ``W_O``, then
    ``ln_final`` and ``unembed``. For every prompt, the top-5 decoded tokens at
    the position of the last target key are checked for the number one below
    the number at that position (with or without a leading space).

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer`` as used below.
        layer: Layer index of the head.
        head: Head index within the layer.
        dataset: Object with ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing calls keep working.
        neg: If True, the head output is negated before unembedding.
        print_tokens: If True, print the counts and return the per-number dict.

    Returns:
        If ``print_tokens`` is true: dict mapping each evaluated number (str) to
        ``(top-5 tokens, expected predecessor, 'yes' | 'no')``.
        Otherwise: list of the numbers whose predecessor was in the top 5.
        The list is empty when there were no hits (previously ``None``).
    """
    import warnings  # local import keeps this cell self-contained

    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    sign = -1 if neg else 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # target keys come from the first prompt; only the last one is evaluated
    words = [key for key in dataset.prompts[0].keys() if key != 'text']
    word = words[-1]

    for seq_idx, prompt in enumerate(dataset.prompts):
        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # expected predecessor of the number at this position
        prev_word = str(int(prompt[word]) - 1)

        prevToken_in_topK = 'no'
        if " " + prev_word in pred_tokens or prev_word in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            prevToken_in_topK = 'yes'
        if prompt[word] in pred_tokens_dict:
            # Used to be a pdb.set_trace(), which blocks non-interactive runs.
            warnings.warn(
                f"duplicate number {prompt[word]!r} at prompt {seq_idx}; "
                "overwriting its earlier entry"
            )
        pred_tokens_dict[prompt[word]] = (pred_tokens, prev_word, prevToken_in_topK)

    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        print(n_right)
        print((dataset.N))
        print(words_moved)
        return pred_tokens_dict
    # Always a list, so callers never have to special-case None.
    return words_moved

In [None]:
get_prev_scores(model, 6, 9, dataset)

Head 6.9 (sign=1) : Top 5 accuracy: 70.0%
63
90
['91', '90', '89', '87', '81', '80', '79', '78', '77', '76', '75', '74', '73', '71', '70', '69', '68', '67', '65', '64', '63', '61', '60', '59', '58', '57', '54', '53', '51', '50', '49', '48', '47', '45', '44', '43', '42', '41', '34', '31', '30', '29', '28', '27', '26', '24', '23', '22', '20', '19', '17', '16', '14', '13', '12', '11', '10', '9', '8', '7', '6', '5', '3']


{'92': ([' 92', ' 96', ' 93', ' 94', ' 88'], '91', 'no'),
 '91': ([' 91', ' 92', ' 93', ' 94', ' 90'], '90', 'yes'),
 '90': ([' 90', '90', ' 89', ' 88', ' 92'], '89', 'yes'),
 '89': ([' 89', ' 88', ' 1889', ' 29', ' 90'], '88', 'yes'),
 '88': ([' 88', ' 89', ' 8', ' 86', ' 48'], '87', 'no'),
 '87': ([' 87', ' 89', ' 88', ' 47', ' 86'], '86', 'yes'),
 '86': ([' 86', ' 46', ' 96', ' 87', ' 88'], '85', 'no'),
 '85': ([' 85', ' 86', ' 87', ' 45', ' 89'], '84', 'no'),
 '84': ([' 84', ' 8', ' 89', ' 44', ' 86'], '83', 'no'),
 '83': ([' 83', ' 84', ' 79', ' 93', ' 87'], '82', 'no'),
 '82': ([' 82', ' 84', ' 78', ' 83', ' Mont'], '81', 'no'),
 '81': ([' 81', ' 79', ' 80', ' 77', ' 84'], '80', 'yes'),
 '80': ([' 80', ' 79', ' 90', ' 78', ' 40'], '79', 'yes'),
 '79': ([' 79', ' 78', ' 59', ' 77', ' 29'], '78', 'yes'),
 '78': ([' 78', ' 79', ' 77', '78', ' Mont'], '77', 'yes'),
 '77': ([' 77', ' 78', ' 7', ' 79', ' 76'], '76', 'yes'),
 '76': ([' 76', ' 77', ' 74', ' 75', ' 78'], '75', 'yes'),
 '7

In [None]:
get_prev_scores(model, 9, 1, dataset)

Head 9.1 (sign=1) : Top 5 accuracy: 0.0%
0
90
[]


{'92': ([' 93', ' 94', ' 95', ' 92', ' 97'], '91', 'no'),
 '91': ([' 92', ' 93', ' 95', ' 94', ' 97'], '90', 'no'),
 '90': ([' 91', ' 95', ' 90', ' 92', ' 100'], '89', 'no'),
 '89': ([' 91', ' 90', ' 94', ' 95', ' 92'], '88', 'no'),
 '88': ([' 90', ' 89', ' 91', ' 99', ' 94'], '87', 'no'),
 '87': ([' 88', ' 89', ' 90', ' 98', ' 92'], '86', 'no'),
 '86': ([' 87', ' 88', ' 89', ' 92', ' 86'], '85', 'no'),
 '85': ([' 86', ' 85', ' 87', ' 90', ' 91'], '84', 'no'),
 '84': ([' 85', ' 86', 'rity', '85', ' 84'], '83', 'no'),
 '83': ([' 84', ' 85', ' 86', '84', ' 83'], '82', 'no'),
 '82': ([' 83', ' 84', ' 85', ' 86', ' 82'], '81', 'no'),
 '81': ([' 82', ' 83', ' 81', ' 84', ' 85'], '80', 'no'),
 '80': ([' 81', ' 80', ' 85', ' 90', ' 82'], '79', 'no'),
 '79': ([' 80', ' 81', ' 79', ' eighty', ' 85'], '78', 'no'),
 '78': ([' 79', ' 80', ' 81', ' 78', ' 84'], '77', 'no'),
 '77': ([' 78', ' 79', ' 80', ' 77', ' 81'], '76', 'no'),
 '76': ([' 77', ' 78', ' 79', ' 81', ' 76'], '75', 'no'),
 '75': (['

# Compare Copy scores

In [None]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Score how often one head's output decodes back to the input number itself.

    The cached ``blocks.0.hook_resid_post`` activations go through block 1's
    ``ln1``, the head's ``W_V``/``b_V`` and ``W_O``, then ``ln_final`` and
    ``unembed``. At every target position of every prompt, the top-5 decoded
    tokens are checked for the number at that position (with or without a
    leading space). The accuracy line is always printed.

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer`` as used below.
        layer: Layer index of the head.
        head: Head index within the layer.
        dataset: Object with ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused.
        neg: If True, the head output is negated before unembedding.
        print_tokens: If ``True``, return the per-number dict instead of the hit list.

    Returns:
        Dict mapping number (str) to ``(top-5 tokens, 'yes' | 'no')`` when
        ``print_tokens == True`` (numbers that occur more than once keep only
        their last entry); otherwise the list of numbers found in their top 5.
    """
    hook_name = "blocks.0.hook_resid_post"
    activations = {}
    model.cache_some(activations, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())

    sign = -1 if neg else 1
    normed_resid = model.blocks[1].attn.ln1(activations[hook_name])

    values = torch.einsum("eab,bc->eac", normed_resid, model.blocks[layer].attn.W_V[head])
    values += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)
    head_out = sign * torch.einsum("sph,hd->spd", values, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(head_out))

    k = 5
    hits = 0
    top_tokens_by_number = {}
    copied_numbers = []
    # target keys are taken from the first prompt
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            position = dataset.word_idx[word][seq_idx]
            top_ids = torch.topk(logits[seq_idx, position], k).indices
            pred_tokens = [model.tokenizer.decode(token) for token in top_ids]

            number = prompt[word]
            if " " + number in pred_tokens or number in pred_tokens:
                hits += 1
                copied_numbers.append(number)
                verdict = 'yes'
            else:
                verdict = 'no'
            top_tokens_by_number[number] = (pred_tokens, verdict)

    percent_right = (hits / (dataset.N * len(words))) * 100
    print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return top_tokens_by_number
    return copied_numbers

In [None]:
get_copy_scores(model, 9, 1, dataset)

Copy circuit for head 9.1 (sign=1) : Top 5 accuracy: 59.27835051546392%


{'1': ([' one', ' needed', ' preferred', ' single', '2'], 'no'),
 '2': ([' third', ' fourth', 'third', '3', ' fifth'], 'no'),
 '3': ([' fourth', ' third', ' fifth', 'Fourth', 'fourth'], 'no'),
 '4': ([' fifth', ' sixth', ' seventh', 'fifth', 'five'], 'no'),
 '5': ([' sixth', ' seventh', ' fifth', '6', ' eighth'], 'no'),
 '6': ([' seventh', ' Seventh', ' sixth', '7', ' eighth'], 'no'),
 '7': ([' seventh', ' eighth', ' ninth', ' VIII', ' Seventh'], 'no'),
 '8': ([' ninth', ' eighth', 'ighth', ' seventh', '9'], 'no'),
 '9': ([' ninth', ' seventh', ' tenth', ' eighth', ' sixth'], 'no'),
 '10': ([' tenth', ' seventh', ' eighth', ' ninth', ' sixth'], 'no'),
 '11': ([' eighth', ' seventh', ' specific', ' 12', '12'], 'no'),
 '12': ([' seventh', ' eighth', '13', ' specific', '14'], 'no'),
 '13': ([' seventh', '14', ' 14', ' eighth', 'Only'], 'no'),
 '14': ([' eighth', ' seventh', ' fifth', ' maximum', ' ninth'], 'no'),
 '15': ([' seventh', ' eighth', ' sixth', ' fifth', ' final'], 'no'),
 '16':

# Compare to Next score

In [None]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Score how often one head maps the last sequence number to its successor.

    The cached ``blocks.0.hook_resid_post`` activations are passed through
    block 1's ``ln1``, the chosen head's ``W_V``/``b_V`` and ``W_O``, then
    ``ln_final`` and ``unembed``. For every prompt, the top-5 decoded tokens at
    the position of the last target key are checked for the number one above
    the number at that position (with or without a leading space).

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer`` as used below.
        layer: Layer index of the head.
        head: Head index within the layer.
        dataset: Object with ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing calls keep working.
        neg: If True, the head output is negated before unembedding.
        print_tokens: If True, return the per-number dict instead of the hit list.

    Returns:
        If ``print_tokens`` is true: dict mapping each evaluated number (str) to
        ``(top-5 tokens, expected successor, 'yes' | 'no')``.
        Otherwise: list of the numbers whose successor was in the top 5.
        The list is empty when there were no hits (previously ``None``).
    """
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    sign = -1 if neg else 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # target keys come from the first prompt; only the last one is evaluated
    words = [key for key in dataset.prompts[0].keys() if key != 'text']
    word = words[-1]

    for seq_idx, prompt in enumerate(dataset.prompts):
        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # expected successor of the number at this position
        next_word = str(int(prompt[word]) + 1)

        nextToken_in_topK = 'no'
        if " " + next_word in pred_tokens or next_word in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            nextToken_in_topK = 'yes'
        pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        return pred_tokens_dict
    # Always a list, so callers never have to special-case None.
    return words_moved

In [None]:
get_next_scores(model, 9, 1, dataset)

Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 100.0%


{'92': ([' 93', ' 94', ' 95', ' 92', ' 97'], '93', 'yes'),
 '91': ([' 92', ' 93', ' 95', ' 94', ' 97'], '92', 'yes'),
 '90': ([' 91', ' 95', ' 90', ' 92', ' 100'], '91', 'yes'),
 '89': ([' 91', ' 90', ' 94', ' 95', ' 92'], '90', 'yes'),
 '88': ([' 90', ' 89', ' 91', ' 99', ' 94'], '89', 'yes'),
 '87': ([' 88', ' 89', ' 90', ' 98', ' 92'], '88', 'yes'),
 '86': ([' 87', ' 88', ' 89', ' 92', ' 86'], '87', 'yes'),
 '85': ([' 86', ' 85', ' 87', ' 90', ' 91'], '86', 'yes'),
 '84': ([' 85', ' 86', 'rity', '85', ' 84'], '85', 'yes'),
 '83': ([' 84', ' 85', ' 86', '84', ' 83'], '84', 'yes'),
 '82': ([' 83', ' 84', ' 85', ' 86', ' 82'], '83', 'yes'),
 '81': ([' 82', ' 83', ' 81', ' 84', ' 85'], '82', 'yes'),
 '80': ([' 81', ' 80', ' 85', ' 90', ' 82'], '81', 'yes'),
 '79': ([' 80', ' 81', ' 79', ' eighty', ' 85'], '80', 'yes'),
 '78': ([' 79', ' 80', ' 81', ' 78', ' 84'], '79', 'yes'),
 '77': ([' 78', ' 79', ' 80', ' 77', ' 81'], '78', 'yes'),
 '76': ([' 77', ' 78', ' 79', ' 81', ' 76'], '77', '

In [None]:
get_next_scores(model, 6, 9, dataset)

Next circuit for head 6.9 (sign=1) : Top 5 accuracy: 74.44444444444444%


{'92': ([' 92', ' 96', ' 93', ' 94', ' 88'], '93', 'yes'),
 '91': ([' 91', ' 92', ' 93', ' 94', ' 90'], '92', 'yes'),
 '90': ([' 90', '90', ' 89', ' 88', ' 92'], '91', 'no'),
 '89': ([' 89', ' 88', ' 1889', ' 29', ' 90'], '90', 'yes'),
 '88': ([' 88', ' 89', ' 8', ' 86', ' 48'], '89', 'yes'),
 '87': ([' 87', ' 89', ' 88', ' 47', ' 86'], '88', 'yes'),
 '86': ([' 86', ' 46', ' 96', ' 87', ' 88'], '87', 'yes'),
 '85': ([' 85', ' 86', ' 87', ' 45', ' 89'], '86', 'yes'),
 '84': ([' 84', ' 8', ' 89', ' 44', ' 86'], '85', 'no'),
 '83': ([' 83', ' 84', ' 79', ' 93', ' 87'], '84', 'yes'),
 '82': ([' 82', ' 84', ' 78', ' 83', ' Mont'], '83', 'yes'),
 '81': ([' 81', ' 79', ' 80', ' 77', ' 84'], '82', 'no'),
 '80': ([' 80', ' 79', ' 90', ' 78', ' 40'], '81', 'no'),
 '79': ([' 79', ' 78', ' 59', ' 77', ' 29'], '80', 'no'),
 '78': ([' 78', ' 79', ' 77', '78', ' Mont'], '79', 'yes'),
 '77': ([' 77', ' 78', ' 7', ' 79', ' 76'], '78', 'yes'),
 '76': ([' 76', ' 77', ' 74', ' 75', ' 78'], '77', 'yes'),
 

In [None]:
# Sweep every attention head; layer/head counts come from the model config
# instead of a hard-coded 12 so the cell also works for other model sizes.
all_heads = [
    (layer, head)
    for layer in range(model.cfg.n_layers)
    for head in range(model.cfg.n_heads)
]
for layer, head in all_heads:
    get_next_scores(model, layer, head, dataset, print_tokens=False)

Next circuit for head 3.5 (sign=1) : Top 5 accuracy: 3.3333333333333335%
Next circuit for head 5.0 (sign=1) : Top 5 accuracy: 5.555555555555555%
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 15.555555555555555%
Next circuit for head 6.9 (sign=1) : Top 5 accuracy: 74.44444444444444%
Next circuit for head 7.2 (sign=1) : Top 5 accuracy: 33.33333333333333%
Next circuit for head 7.7 (sign=1) : Top 5 accuracy: 4.444444444444445%
Next circuit for head 7.8 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Next circuit for head 7.10 (sign=1) : Top 5 accuracy: 55.55555555555556%
Next circuit for head 7.11 (sign=1) : Top 5 accuracy: 8.88888888888889%
Next circuit for head 8.0 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Next circuit for head 8.1 (sign=1) : Top 5 accuracy: 38.88888888888889%
Next circuit for head 8.8 (sign=1) : Top 5 accuracy: 80.0%
Next circuit for head 8.11 (sign=1) : Top 5 accuracy: 33.33333333333333%
Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 100.0%
Next circui

In [None]:
# Sweep every attention head; layer/head counts come from the model config
# instead of a hard-coded 12 so the cell also works for other model sizes.
all_heads = [
    (layer, head)
    for layer in range(model.cfg.n_layers)
    for head in range(model.cfg.n_heads)
]
for layer, head in all_heads:
    get_prev_scores(model, layer, head, dataset, print_tokens=False)

Head 2.1 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 2.4 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 3.5 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 5.0 (sign=1) : Top 5 accuracy: 7.777777777777778%
Head 5.1 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 6.1 (sign=1) : Top 5 accuracy: 2.2222222222222223%
Head 6.9 (sign=1) : Top 5 accuracy: 70.0%
Head 6.10 (sign=1) : Top 5 accuracy: 2.2222222222222223%
Head 7.2 (sign=1) : Top 5 accuracy: 56.666666666666664%
Head 7.7 (sign=1) : Top 5 accuracy: 6.666666666666667%
Head 7.8 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 7.10 (sign=1) : Top 5 accuracy: 8.88888888888889%
Head 7.11 (sign=1) : Top 5 accuracy: 12.222222222222221%
Head 8.0 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 8.1 (sign=1) : Top 5 accuracy: 57.77777777777777%
Head 8.6 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 8.8 (sign=1) : Top 5 accuracy: 35.55555555555556%
Head 8.11 (sign=1) : Top 5 accuracy: 1.1111111111111112%
Head 9.3

Unlike the incr circuit, which has head 9.1, the decr circuit doesn't seem to have heads specialized for "next". There is head 10.2, but it wasn't chosen for the 'prune fwd' circuit (though it was for 'prune backw').

# Find prev heads of desc circ

from: https://colab.research.google.com/drive/1odPpf7w_gBG8ZfAB2L6SXZszsDUk1CGA#scrollTo=ET--8aulD8pE&line=1&uniqifier=1

In [None]:
# (layer, head) pairs of the decr circuit, copied from the Colab linked above.
decr_circ = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (1, 0), (1, 5), (2, 2), (2, 4), (2, 9), (3, 0), (3, 3), (3, 7), (3, 10), (4, 6), (4, 7), (4, 10), (4, 11), (5, 1), (5, 5), (5, 6), (6, 1), (6, 7), (6, 9), (7, 2), (7, 10), (7, 11), (8, 1), (8, 6), (8, 8), (8, 10), (9, 5), (10, 7), (11, 0), (11, 8), (11, 11)]
# Print each head's position in the list next to the value returned by its prev score.
for index, (layer, head) in enumerate(decr_circ):
    print(index, get_prev_scores(model, layer, head, dataset, print_tokens=False))

Head 0.1 (sign=1) : Top 5 accuracy: 0.0%
0 []
Head 0.3 (sign=1) : Top 5 accuracy: 0.0%
1 []
Head 0.5 (sign=1) : Top 5 accuracy: 0.0%
2 []
Head 0.7 (sign=1) : Top 5 accuracy: 0.0%
3 []
Head 0.9 (sign=1) : Top 5 accuracy: 0.0%
4 []
Head 1.0 (sign=1) : Top 5 accuracy: 0.0%
5 []
Head 1.5 (sign=1) : Top 5 accuracy: 0.0%
6 []
Head 2.2 (sign=1) : Top 5 accuracy: 0.0%
7 []
Head 2.4 (sign=1) : Top 5 accuracy: 1.1111111111111112%
8 ['92']
Head 2.9 (sign=1) : Top 5 accuracy: 0.0%
9 []
Head 3.0 (sign=1) : Top 5 accuracy: 0.0%
10 []
Head 3.3 (sign=1) : Top 5 accuracy: 0.0%
11 []
Head 3.7 (sign=1) : Top 5 accuracy: 0.0%
12 []
Head 3.10 (sign=1) : Top 5 accuracy: 0.0%
13 []
Head 4.6 (sign=1) : Top 5 accuracy: 0.0%
14 []
Head 4.7 (sign=1) : Top 5 accuracy: 0.0%
15 []
Head 4.10 (sign=1) : Top 5 accuracy: 0.0%
16 []
Head 4.11 (sign=1) : Top 5 accuracy: 0.0%
17 []
Head 5.1 (sign=1) : Top 5 accuracy: 1.1111111111111112%
18 ['91']
Head 5.5 (sign=1) : Top 5 accuracy: 0.0%
19 []
Head 5.6 (sign=1) : Top 5 acc

Compare these scores to next scores

In [None]:
# Same heads as the prev-score loop, scored with the next score for comparison.
for index, (layer, head) in enumerate(decr_circ):
    print(index, get_next_scores(model, layer, head, dataset, print_tokens=False))

Next circuit for head 0.1 (sign=1) : Top 5 accuracy: 0.0%
0 []
Next circuit for head 0.3 (sign=1) : Top 5 accuracy: 0.0%
1 []
Next circuit for head 0.5 (sign=1) : Top 5 accuracy: 0.0%
2 []
Next circuit for head 0.7 (sign=1) : Top 5 accuracy: 0.0%
3 []
Next circuit for head 0.9 (sign=1) : Top 5 accuracy: 0.0%
4 []
Next circuit for head 1.0 (sign=1) : Top 5 accuracy: 0.0%
5 []
Next circuit for head 1.5 (sign=1) : Top 5 accuracy: 0.0%
6 []
Next circuit for head 2.2 (sign=1) : Top 5 accuracy: 0.0%
7 []
Next circuit for head 2.4 (sign=1) : Top 5 accuracy: 0.0%
8 []
Next circuit for head 2.9 (sign=1) : Top 5 accuracy: 0.0%
9 []
Next circuit for head 3.0 (sign=1) : Top 5 accuracy: 0.0%
10 []
Next circuit for head 3.3 (sign=1) : Top 5 accuracy: 0.0%
11 []
Next circuit for head 3.7 (sign=1) : Top 5 accuracy: 0.0%
12 []
Next circuit for head 3.10 (sign=1) : Top 5 accuracy: 0.0%
13 []
Next circuit for head 4.6 (sign=1) : Top 5 accuracy: 0.0%
14 []
Next circuit for head 4.7 (sign=1) : Top 5 accura

# Input the first token of seq, not last

In [None]:
def get_prev_scores_token0(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Prev score evaluated at the first sequence position instead of the last.

    The cached ``blocks.0.hook_resid_post`` activations are passed through
    block 1's ``ln1``, the chosen head's ``W_V``/``b_V`` and ``W_O``, then
    ``ln_final`` and ``unembed``. For every prompt, the top-5 decoded tokens at
    the position of the first target key are checked for the number one below
    the number at that position (with or without a leading space).

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer`` as used below.
        layer: Layer index of the head.
        head: Head index within the layer.
        dataset: Object with ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing calls keep working.
        neg: If True, the head output is negated before unembedding.
        print_tokens: If True, print the counts and return the per-number dict.

    Returns:
        If ``print_tokens`` is true: dict mapping each evaluated number (str) to
        ``(top-5 tokens, expected predecessor, 'yes' | 'no')``.
        Otherwise: list of the numbers whose predecessor was in the top 5.
        The list is empty when there were no hits (previously ``None``).
    """
    import warnings  # local import keeps this cell self-contained

    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    sign = -1 if neg else 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # target keys come from the first prompt; only the first one is evaluated
    words = [key for key in dataset.prompts[0].keys() if key != 'text']
    word = words[0]

    for seq_idx, prompt in enumerate(dataset.prompts):
        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # expected predecessor of the number at this position
        prev_word = str(int(prompt[word]) - 1)

        prevToken_in_topK = 'no'
        if " " + prev_word in pred_tokens or prev_word in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            prevToken_in_topK = 'yes'
        if prompt[word] in pred_tokens_dict:
            # Used to be a pdb.set_trace(), which blocks non-interactive runs.
            warnings.warn(
                f"duplicate number {prompt[word]!r} at prompt {seq_idx}; "
                "overwriting its earlier entry"
            )
        pred_tokens_dict[prompt[word]] = (pred_tokens, prev_word, prevToken_in_topK)

    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        print(n_right)
        print((dataset.N))
        print(words_moved)
        return pred_tokens_dict
    # Always a list, so callers never have to special-case None.
    return words_moved

In [None]:
get_prev_scores_token0(model, 6, 9, dataset)

Head 6.9 (sign=1) : Top 5 accuracy: 3.3333333333333335%
3
90
['70', '66', '26']


{'101': (['102', '101', '1001', 'Type', ' Series'], '100', 'no'),
 '100': ([' Series', 'Domain', '-', '100', 'opia'], '99', 'no'),
 '99': ([' Series', 'Type', 'Made', '99', 'Us'], '98', 'no'),
 '98': ([' Series', '98', 'Made', 'Type', 'less'], '97', 'no'),
 '97': ([' Series', 'Force', 'Made', 'Type', 'Maker'], '96', 'no'),
 '96': ([' Series', 'Type', 'Made', 'Count', 'Series'], '95', 'no'),
 '95': ([' Series', 'thood', 'Type', '65', 'Maker'], '94', 'no'),
 '94': ([' Series', 'Type', 'Mode', 'visors', 'Series'], '93', 'no'),
 '93': ([' Series', 'Type', 'Series', 'Bi', 'Maker'], '92', 'no'),
 '92': ([' Series', 'Maker', 'meter', 'Mode', 'Made'], '91', 'no'),
 '91': (['Type', ' Series', 'Mode', 'Maker', 'Cause'], '90', 'no'),
 '90': (['Maker', 'thood', 'Type', 'Made', ' Series'], '89', 'no'),
 '89': ([' Series', 'Count', 'Made', 'Maker', ' Labs'], '88', 'no'),
 '88': ([' Series', '88', 'Made', 'Desk', 'Count'], '87', 'no'),
 '87': ([' Series', ' Labs', ' Facility', ' Experiment', 'Count']

# Input the second token of seq, not last

In [None]:
def get_prev_scores_token1(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Prev score evaluated at the second sequence position instead of the last.

    The cached ``blocks.0.hook_resid_post`` activations are passed through
    block 1's ``ln1``, the chosen head's ``W_V``/``b_V`` and ``W_O``, then
    ``ln_final`` and ``unembed``. For every prompt, the top-5 decoded tokens at
    the position of the second target key are checked for the number one below
    the number at that position (with or without a leading space).

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer`` as used below.
        layer: Layer index of the head.
        head: Head index within the layer.
        dataset: Object with ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing calls keep working.
        neg: If True, the head output is negated before unembedding.
        print_tokens: If True, print the counts and return the per-number dict.

    Returns:
        If ``print_tokens`` is true: dict mapping each evaluated number (str) to
        ``(top-5 tokens, expected predecessor, 'yes' | 'no')``.
        Otherwise: list of the numbers whose predecessor was in the top 5.
        The list is empty when there were no hits (previously ``None``).
    """
    import warnings  # local import keeps this cell self-contained

    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    sign = -1 if neg else 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # target keys come from the first prompt; only the second one is evaluated
    words = [key for key in dataset.prompts[0].keys() if key != 'text']
    word = words[1]

    for seq_idx, prompt in enumerate(dataset.prompts):
        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # expected predecessor of the number at this position
        prev_word = str(int(prompt[word]) - 1)

        prevToken_in_topK = 'no'
        if " " + prev_word in pred_tokens or prev_word in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            prevToken_in_topK = 'yes'
        if prompt[word] in pred_tokens_dict:
            # Used to be a pdb.set_trace(), which blocks non-interactive runs.
            warnings.warn(
                f"duplicate number {prompt[word]!r} at prompt {seq_idx}; "
                "overwriting its earlier entry"
            )
        pred_tokens_dict[prompt[word]] = (pred_tokens, prev_word, prevToken_in_topK)

    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        print(n_right)
        print((dataset.N))
        print(words_moved)
        return pred_tokens_dict
    # Always a list, so callers never have to special-case None.
    return words_moved

In [None]:
get_prev_scores_token1(model, 6, 9, dataset)

Head 6.9 (sign=1) : Top 5 accuracy: 67.77777777777779%
61
90
['99', '97', '95', '94', '93', '92', '90', '89', '81', '80', '79', '78', '77', '76', '74', '73', '71', '70', '69', '68', '65', '64', '63', '62', '61', '60', '59', '58', '57', '54', '53', '52', '51', '48', '47', '45', '44', '43', '42', '41', '40', '38', '37', '34', '33', '31', '30', '29', '28', '27', '26', '24', '23', '22', '21', '17', '16', '14', '13', '12', '11']


{'100': ([' 100', ' 50', ' 101', ' 96', ' 10'], '99', 'no'),
 '99': ([' 99', ' 59', ' 39', ' 98', ' 29'], '98', 'yes'),
 '98': ([' 98', ' 96', ' 58', ' 94', ' 88'], '97', 'no'),
 '97': ([' 96', ' 97', ' 98', ' 95', ' 93'], '96', 'yes'),
 '96': ([' 96', ' 56', ' 94', ' 97', ' 92'], '95', 'no'),
 '95': ([' 95', ' 96', ' 93', ' 94', 'の魔'], '94', 'yes'),
 '94': ([' 94', ' 96', ' 93', ' 92', ' 54'], '93', 'yes'),
 '93': ([' 93', ' 92', ' 94', ' 91', ' 293'], '92', 'yes'),
 '92': ([' 92', ' 96', ' 93', ' 94', ' 91'], '91', 'yes'),
 '91': ([' 91', ' 93', ' 92', ' 96', ' 61'], '90', 'no'),
 '90': ([' 90', ' 89', ' 92', ' 88', ' 60'], '89', 'yes'),
 '89': ([' 89', ' 59', ' 39', ' 29', ' 88'], '88', 'yes'),
 '88': ([' 88', ' 89', ' 8', ' 58', ' 86'], '87', 'no'),
 '87': ([' 87', ' 89', ' 47', ' 77', ' 96'], '86', 'no'),
 '86': ([' 86', ' 96', ' 87', ' 46', ' 56'], '85', 'no'),
 '85': ([' 85', ' 86', ' 89', ' 87', ' 75'], '84', 'no'),
 '84': ([' 84', ' 89', ' 79', ' 8', ' 86'], '83', 'no'),
 '83'

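Perhaps the low first-token score is related to the leading space character (the first token is tokenized without one)? The second-token score is much more similar to the 'last'-token score.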

# Input all tokens of seq

In [None]:
def get_prev_scores_allPos(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    """Prev score evaluated at every target position of every prompt.

    The cached ``blocks.0.hook_resid_post`` activations are passed through
    block 1's ``ln1``, the chosen head's ``W_V``/``b_V`` and ``W_O``, then
    ``ln_final`` and ``unembed``. At each target position, the top-5 decoded
    tokens are checked for the number one below the number at that position
    (with or without a leading space). Accuracy is taken over
    ``dataset.N * number_of_target_keys`` positions.

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer`` as used below.
        layer: Layer index of the head.
        head: Head index within the layer.
        dataset: Object with ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing calls keep working.
        neg: If True, the head output is negated before unembedding.
        print_tokens: If True, print the counts and return the per-position dict.

    Returns:
        If ``print_tokens`` is true: dict keyed by ``"<prompt index>_<number>"``
        with values ``(top-5 tokens, expected predecessor, 'yes' | 'no')``.
        Otherwise: list of the numbers whose predecessor was in the top 5.
        The list is empty when there were no hits (previously ``None``).
    """
    import warnings  # local import keeps this cell self-contained

    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    sign = -1 if neg else 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # target keys come from the first prompt
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # expected predecessor of the number at this position
            prev_word = str(int(prompt[word]) - 1)

            prevToken_in_topK = 'no'
            if " " + prev_word in pred_tokens or prev_word in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])
                prevToken_in_topK = 'yes'

            # Check the key that is actually stored. The old check looked up the
            # bare number, which can never equal a "<seq>_<number>" key, and
            # dropped into pdb.
            entry_key = str(seq_idx) + "_" + prompt[word]
            if entry_key in pred_tokens_dict:
                warnings.warn(
                    f"duplicate entry {entry_key!r}; overwriting its earlier entry"
                )
            pred_tokens_dict[entry_key] = (pred_tokens, prev_word, prevToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    if percent_right > 0:
        print(f"Head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        print(n_right)
        print((dataset.N * len(words)))
        print(words_moved)
        return pred_tokens_dict
    # Always a list, so callers never have to special-case None.
    return words_moved

In [None]:
get_prev_scores_allPos(model, 6, 9, dataset)

Head 6.9 (sign=1) : Top 5 accuracy: 65.88888888888889%
593
900
['99', '97', '96', '95', '94', '93', '99', '97', '96', '95', '94', '93', '91', '97', '96', '95', '94', '93', '91', '90', '97', '96', '95', '94', '93', '91', '90', '89', '95', '94', '93', '91', '90', '89', '95', '94', '93', '90', '89', '87', '94', '93', '90', '89', '87', '93', '90', '89', '87', '92', '90', '89', '87', '90', '89', '87', '90', '89', '84', '89', '88', '84', '81', '84', '81', '80', '81', '80', '79', '84', '81', '80', '79', '78', '84', '81', '80', '79', '78', '77', '81', '80', '79', '78', '77', '76', '81', '80', '79', '78', '77', '76', '75', '81', '80', '79', '78', '77', '76', '75', '74', '81', '80', '79', '78', '77', '76', '75', '74', '73', '80', '79', '78', '77', '76', '75', '74', '73', '79', '78', '77', '76', '75', '74', '73', '72', '71', '78', '77', '76', '75', '74', '73', '72', '71', '70', '77', '76', '75', '74', '73', '72', '71', '70', '69', '76', '75', '74', '73', '72', '71', '70', '69', '68', '74', '73', 

{'0_101': (['102', '101', '1001', 'Type', ' Series'], '100', 'no'),
 '0_100': ([' 100', ' 50', ' 101', ' 96', ' 10'], '99', 'no'),
 '0_99': ([' 99', ' 59', ' 98', ' 95', ' 96'], '98', 'yes'),
 '0_98': ([' 98', ' 96', ' 94', ' 58', ' 88'], '97', 'no'),
 '0_97': ([' 97', ' 96', ' 98', ' 95', ' 77'], '96', 'yes'),
 '0_96': ([' 96', ' 94', ' 56', ' 97', ' 95'], '95', 'yes'),
 '0_95': ([' 95', 'の魔', ' 96', ' 94', ' 93'], '94', 'yes'),
 '0_94': ([' 94', ' 96', ' 93', ' 92', ' 104'], '93', 'yes'),
 '0_93': ([' 93', ' 94', ' 92', ' 91', ' 23'], '92', 'yes'),
 '0_92': ([' 92', ' 96', ' 93', ' 94', ' 88'], '91', 'no'),
 '1_100': ([' Series', 'Domain', '-', '100', 'opia'], '99', 'no'),
 '1_99': ([' 99', ' 59', ' 39', ' 98', ' 29'], '98', 'yes'),
 '1_98': ([' 98', ' 96', ' 94', ' 58', ' 88'], '97', 'no'),
 '1_97': ([' 97', ' 96', ' 98', ' 95', ' 77'], '96', 'yes'),
 '1_96': ([' 96', ' 94', ' 56', ' 97', ' 95'], '95', 'yes'),
 '1_95': ([' 95', 'の魔', ' 96', ' 94', ' 93'], '94', 'yes'),
 '1_94': ([' 

# TBC

The cells below have not been updated yet for next scores so disregard them:

---



## Writing direction results with scatterplot

In [None]:
def scatter_attention_and_contribution(
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Relate one head's attention on each number token to what the head writes.

    For every prompt in ``dataset.prompts`` and every target key (all keys of
    ``dataset.prompts[0]`` except ``"text"``) one row is recorded containing:

    * the attention probability from the ``"end"`` position to the target
      token, read from ``blocks.{layer_no}.attn.hook_attn``;
    * the dot product between the head's output at the ``"end"`` position
      (``blocks.{layer_no}.attn.hook_result``) and the target token's column
      of the unembedding matrix ``model.unembed.W_U``;
    * the target key itself (e.g. ``"S2"``) and the prompt text.

    A target whose token id does not occur in the tokenized prompt text is
    reported with a printed message and omitted. Each remaining row keeps the
    key of the target it was actually computed from.

    Args:
        model: object providing ``unembed.W_U``, ``tokenizer``,
            ``cache_all(cache)`` and a forward call on a token tensor.
        layer_no: index of the attention layer to inspect.
        head_no: index of the head inside that layer.
        dataset: object providing ``toks``, ``prompts`` (list of dicts with a
            ``"text"`` entry plus one entry per target) and
            ``word_idx["end"]`` (one position per prompt).
        S1_is_first: if True, the ``"S1"`` target is tokenized without a
            leading space (use when it is the very first token of the text).
        return_vals: if True, return the DataFrame of rows and build no figure.
        return_fig: if True (and ``return_vals`` is False), return the figure
            instead of showing it.

    Returns:
        The DataFrame if ``return_vals``; else the plotly figure if
        ``return_fig``; otherwise ``None`` after calling ``fig.show()``.
    """
    model_unembed = model.unembed.W_U.detach().cpu()

    # Register `cache` with the model; the forward pass below is run only so
    # that the activations read further down end up in `cache`.
    cache = {}
    model.cache_all(cache)
    model(dataset.toks.long())

    attn_key = f"blocks.{layer_no}.attn.hook_attn"
    result_key = f"blocks.{layer_no}.attn.hook_result"
    targ_tokens = [key for key in dataset.prompts[0].keys() if key != "text"]

    rows = []
    for i, prompt in enumerate(dataset.prompts):
        # Tokenize the prompt once; it does not depend on the target.
        toks = model.tokenizer(prompt["text"])["input_ids"]
        end_pos = dataset.word_idx["end"][i]
        resid = cache[result_key][i, end_pos, head_no, :].detach().cpu()

        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # only use this if first token doesn't have space Ġ in front
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]

            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            prob = cache[attn_key][i, head_no, end_pos, s_pos].detach().cpu()
            s_dir = model_unembed[:, s_tok]
            dot = torch.einsum("a,a->", resid, s_dir)
            # Label the row with the target it came from. Pairing results
            # with `targ_tokens` after the loop mislabels every row that
            # follows a skipped target.
            rows.append([prob, dot, s_id, prompt["text"]])

    viz_df = pd.DataFrame(
        rows, columns=["Attn Prob on Number", "Dot w Number Embed", "Seq Position", "text"]
    )
    if return_vals:
        # The figure would be discarded, so do not build it.
        return viz_df

    fig = px.scatter(
        viz_df,
        x="Attn Prob on Number",
        y="Dot w Number Embed",
        color="Seq Position",
        hover_data=["text"],
        title=f"How Strong {layer_no}.{head_no} Writes in the Number Embed Direction Relative to Attn Prob",
    )

    if return_fig:
        return fig
    fig.show()

In [None]:
# Head 9.1: attention probability on each number token vs. how strongly the
# head writes along that token's unembedding direction.
scatter_attention_and_contribution(
    model,
    layer_no=9,
    head_no=1,
    dataset=dataset,
    S1_is_first=False,
)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

## Correlation vals

In [None]:
def get_prob_dot(  # same as scatterplot, but output x and y vals instead of plotting
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Collect the x/y values of the attention-vs-write scatter without plotting.

    For every prompt in ``dataset.prompts`` and every target key (all keys of
    ``dataset.prompts[0]`` except ``"text"``) whose token occurs in the
    tokenized prompt text, two values are collected:

    * the attention probability from the ``"end"`` position to the target
      token (``blocks.{layer_no}.attn.hook_attn``);
    * the dot product between the head's output at the ``"end"`` position
      (``blocks.{layer_no}.attn.hook_result``) and the target token's column
      of ``model.unembed.W_U``.

    Targets whose token id is not found are reported with a printed message
    and skipped.

    Args:
        model: object providing ``unembed.W_U``, ``tokenizer``,
            ``cache_all(cache)`` and a forward call on a token tensor.
        layer_no: index of the attention layer to inspect.
        head_no: index of the head inside that layer.
        dataset: object providing ``toks``, ``prompts`` and
            ``word_idx["end"]``.
        S1_is_first: if True, the ``"S1"`` target is tokenized without a
            leading space.
        return_vals: unused; kept so the signature matches
            ``scatter_attention_and_contribution``.
        return_fig: unused; kept for the same reason.

    Returns:
        ``(all_prob, all_dot)``: two equal-length lists of 0-dim CPU tensors,
        in prompt order and, within a prompt, in target-key order.
    """
    model_unembed = model.unembed.W_U.detach().cpu()

    # Register `cache` with the model; the forward pass below is run only so
    # that the activations read further down end up in `cache`.
    cache = {}
    model.cache_all(cache)
    model(dataset.toks.long())

    attn_key = f"blocks.{layer_no}.attn.hook_attn"
    result_key = f"blocks.{layer_no}.attn.hook_result"
    targ_tokens = [key for key in dataset.prompts[0].keys() if key != "text"]

    all_prob = []
    all_dot = []
    for i, prompt in enumerate(dataset.prompts):
        # Tokenize the prompt once; it does not depend on the target.
        toks = model.tokenizer(prompt["text"])["input_ids"]
        end_pos = dataset.word_idx["end"][i]
        resid = cache[result_key][i, end_pos, head_no, :].detach().cpu()

        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # only use this if first token doesn't have space Ġ in front
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]

            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            all_prob.append(cache[attn_key][i, head_no, end_pos, s_pos].detach().cpu())
            all_dot.append(torch.einsum("a,a->", resid, model_unembed[:, s_tok]))

    return all_prob, all_dot


In [None]:
# Head 9.9: collect attention probabilities and unembed-direction dot products.
all_prob, all_dot = get_prob_dot(
    model,
    layer_no=9,
    head_no=9,
    dataset=dataset,
    S1_is_first=False,
)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
import scipy.stats as stats

# Pearson correlation between the attention probabilities and the dot products
# returned by get_prob_dot (both inputs are equal-length sequences).
correlation, p_value = stats.pearsonr(x=all_prob, y=all_dot)

print("Correlation:", correlation)
print("p-value:", p_value)

Correlation: 0.8127540109290008
p-value: 8.99435128806603e-70


In [None]:
# Head 9.1: correlation between attention probability and write strength.
all_prob, all_dot = get_prob_dot(
    model,
    layer_no=9,
    head_no=1,
    dataset=dataset,
    S1_is_first=False,
)

correlation, p_value = stats.pearsonr(x=all_prob, y=all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
# Head 7.10: correlation between attention probability and write strength.
all_prob, all_dot = get_prob_dot(
    model,
    layer_no=7,
    head_no=10,
    dataset=dataset,
    S1_is_first=False,
)

correlation, p_value = stats.pearsonr(x=all_prob, y=all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
# Head 5.1: correlation between attention probability and write strength.
all_prob, all_dot = get_prob_dot(
    model,
    layer_no=5,
    head_no=1,
    dataset=dataset,
    S1_is_first=False,
)

correlation, p_value = stats.pearsonr(x=all_prob, y=all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
# Head 0.3: correlation between attention probability and write strength.
all_prob, all_dot = get_prob_dot(
    model,
    layer_no=0,
    head_no=3,
    dataset=dataset,
    S1_is_first=False,
)

correlation, p_value = stats.pearsonr(x=all_prob, y=all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen