<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup

In [1]:
%%capture
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
%pip install einops datasets transformers fancy_einsum

In [2]:
from copy import deepcopy
import torch

assert torch.cuda.device_count() == 1
from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

ipython = get_ipython()
if ipython is not None:
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

 Initialise model (use larger N or fewer templates for no warnings about in-template ablation)

In [3]:
model = EasyTransformer.from_pretrained("gpt2").cuda()
# model = EasyTransformer.from_pretrained("gpt2")
model.set_use_attn_result(True)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Moving model to device:  cuda
Finished loading pretrained model gpt2 into EasyTransformer!


In [4]:
import pdb

# Generate dataset with multiple prompts

In [32]:
def generate_prompts_list():
    prompts_list = []
    for i in range(1, 98):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list()

In [33]:
class Dataset:
    def __init__(self, prompts, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [34]:
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

# Next score

In [14]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        # for word in words:
        word = words[-1]
        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        # get next member after digit prompt[word]
        next_word = str(int(prompt[word]) + 1)

        nextToken_in_topK = 'no'
        if " " + next_word in pred_tokens or next_word in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            nextToken_in_topK = 'yes'
        pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    # percent_right = (n_right / (dataset.N * len(words))) * 100
    percent_right = (n_right / (dataset.N )) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        # return words_moved
        return percent_right

In [None]:
get_next_scores(model, 9, 1, dataset)

Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 87.37113402061856%


{'1': ([' one', ' needed', ' preferred', ' single', '2'], '2', 'yes'),
 '2': ([' third', ' fourth', 'third', '3', ' fifth'], '3', 'yes'),
 '3': ([' fourth', ' third', ' fifth', 'Fourth', 'fourth'], '4', 'no'),
 '4': ([' fifth', ' sixth', ' seventh', 'fifth', 'five'], '5', 'no'),
 '5': ([' sixth', ' seventh', ' fifth', '6', ' eighth'], '6', 'yes'),
 '6': ([' seventh', ' Seventh', ' sixth', '7', ' eighth'], '7', 'yes'),
 '7': ([' seventh', ' eighth', ' ninth', ' VIII', ' Seventh'], '8', 'no'),
 '8': ([' ninth', ' eighth', 'ighth', ' seventh', '9'], '9', 'yes'),
 '9': ([' ninth', ' seventh', ' tenth', ' eighth', ' sixth'], '10', 'no'),
 '10': ([' tenth', ' seventh', ' eighth', ' ninth', ' sixth'], '11', 'no'),
 '11': ([' eighth', ' seventh', ' specific', ' 12', '12'], '12', 'yes'),
 '12': ([' seventh', ' eighth', '13', ' specific', '14'], '13', 'yes'),
 '13': ([' seventh', '14', ' 14', ' eighth', 'Only'], '14', 'yes'),
 '14': ([' eighth', ' seventh', ' fifth', ' maximum', ' ninth'], '15',

# Compare next scores to copy scores

In [None]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        # for word in words:
        word = words[-1]
        pred_tokens = [
            model.tokenizer.decode(token)
            for token in torch.topk(
                logits[seq_idx, dataset.word_idx[word][seq_idx]], k
            ).indices
        ]

        token_in_topK = 'no'
        if " " + prompt[word] in pred_tokens or prompt[word] in pred_tokens:
            n_right += 1
            words_moved.append(prompt[word])
            token_in_topK = 'yes'
        pred_tokens_dict[prompt[word]] = (pred_tokens, token_in_topK)

    # percent_right = (n_right / (dataset.N * len(words))) * 100
    percent_right = (n_right / (dataset.N )) * 100
    print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        return words_moved

In [None]:
get_copy_scores(model, 9, 1, dataset)

Copy circuit for head 9.1 (sign=1) : Top 5 accuracy: 59.27835051546392%


{'1': ([' one', ' needed', ' preferred', ' single', '2'], 'no'),
 '2': ([' third', ' fourth', 'third', '3', ' fifth'], 'no'),
 '3': ([' fourth', ' third', ' fifth', 'Fourth', 'fourth'], 'no'),
 '4': ([' fifth', ' sixth', ' seventh', 'fifth', 'five'], 'no'),
 '5': ([' sixth', ' seventh', ' fifth', '6', ' eighth'], 'no'),
 '6': ([' seventh', ' Seventh', ' sixth', '7', ' eighth'], 'no'),
 '7': ([' seventh', ' eighth', ' ninth', ' VIII', ' Seventh'], 'no'),
 '8': ([' ninth', ' eighth', 'ighth', ' seventh', '9'], 'no'),
 '9': ([' ninth', ' seventh', ' tenth', ' eighth', ' sixth'], 'no'),
 '10': ([' tenth', ' seventh', ' eighth', ' ninth', ' sixth'], 'no'),
 '11': ([' eighth', ' seventh', ' specific', ' 12', '12'], 'no'),
 '12': ([' seventh', ' eighth', '13', ' specific', '14'], 'no'),
 '13': ([' seventh', '14', ' 14', ' eighth', 'Only'], 'no'),
 '14': ([' eighth', ' seventh', ' fifth', ' maximum', ' ninth'], 'no'),
 '15': ([' seventh', ' eighth', ' sixth', ' fifth', ' final'], 'no'),
 '16':

# Find more next heads

## Get important heads

Find what heads are specific to certain inputs, and what's common to the template.

Get important heads from: circuit_expms_template.ipynb (Section: print top heads. Copy output of 'top_indices'; put on one line using chatgpt)

NOTE: not all attention heads just copy, so use attention patterns to determine which ones copy to refine this list of heads

(Eg. if you copy all the top heads from IOI, only 9.9 and 10.0 are name movers while other heads are "S-inhibition", "induction", "duplicate", so only the name movers + backup NM will have top accuracy)

## last token only

In [None]:
top_val = [(1,5), (4,4), (6,1), (7, 10), (7, 11), (8, 11), (9, 1)]
for index, (layer, head) in enumerate(top_val):
    print(index, get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 1.5 (sign=1) : Top 5 accuracy: 0.0%
0 []
Copy circuit for head 4.4 (sign=1) : Top 5 accuracy: 0.0%
1 []
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 1.0309278350515463%
2 ['18']
Copy circuit for head 7.10 (sign=1) : Top 5 accuracy: 100.0%
3 ['4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100']
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 23.711340206185564%
4 ['4', '5', '6', '10', '11', '13', '14', '15', '20', '29', '39', '40', '41',

In [None]:
top_val = [(1,5), (4,4), (6,1), (7, 10), (7, 11), (8, 11), (9, 1)]
for index, (layer, head) in enumerate(top_val):
    print(index, get_next_scores(model, layer, head, dataset, print_tokens=False))

Next circuit for head 1.5 (sign=1) : Top 5 accuracy: 0.0%
0 []
Next circuit for head 4.4 (sign=1) : Top 5 accuracy: 0.0%
1 []
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 18.556701030927837%
2 ['4', '5', '6', '7', '8', '9', '11', '12', '13', '15', '17', '18', '20', '21', '45', '46', '65', '90']
Next circuit for head 7.10 (sign=1) : Top 5 accuracy: 63.91752577319587%
3 ['4', '11', '16', '17', '21', '22', '26', '27', '28', '31', '32', '33', '35', '37', '38', '41', '42', '43', '44', '45', '46', '47', '48', '49', '51', '52', '53', '55', '56', '57', '58', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '73', '76', '77', '78', '80', '81', '82', '83', '84', '85', '86', '87', '88', '90', '91', '92', '93', '96', '97']
Next circuit for head 7.11 (sign=1) : Top 5 accuracy: 8.24742268041237%
4 ['4', '5', '10', '22', '48', '52', '53', '58']
Next circuit for head 8.11 (sign=1) : Top 5 accuracy: 32.98969072164948%
5 ['27', '33', '35', '45', '47', '48', '55', '57', '58'

### loop over all heads

only print out if there's a match

In [15]:
all_next_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_next_scores.append(get_next_scores(model, layer, head, dataset, print_tokens=False))

Next circuit for head 3.5 (sign=1) : Top 5 accuracy: 4.123711340206185%
Next circuit for head 5.0 (sign=1) : Top 5 accuracy: 8.24742268041237%
Next circuit for head 5.1 (sign=1) : Top 5 accuracy: 3.0927835051546393%
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 18.556701030927837%
Next circuit for head 6.9 (sign=1) : Top 5 accuracy: 76.28865979381443%
Next circuit for head 7.2 (sign=1) : Top 5 accuracy: 43.29896907216495%
Next circuit for head 7.7 (sign=1) : Top 5 accuracy: 3.0927835051546393%
Next circuit for head 7.8 (sign=1) : Top 5 accuracy: 2.0618556701030926%
Next circuit for head 7.10 (sign=1) : Top 5 accuracy: 63.91752577319587%
Next circuit for head 7.11 (sign=1) : Top 5 accuracy: 8.24742268041237%
Next circuit for head 8.0 (sign=1) : Top 5 accuracy: 1.0309278350515463%
Next circuit for head 8.1 (sign=1) : Top 5 accuracy: 43.29896907216495%
Next circuit for head 8.8 (sign=1) : Top 5 accuracy: 75.25773195876289%
Next circuit for head 8.11 (sign=1) : Top 5 accuracy: 32.98

In [16]:
sum(all_next_scores)/len(all_next_scores)

4.08791523482245

In [13]:
all_next_scores

[[],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['7', '11', '15', '17'],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['4', '5', '7', '13', '17', '27', '41', '51'],
 ['51', '61', '67'],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['4',
  '5',
  '6',
  '7',
  '8',
  '9',
  '11',
  '12',
  '13',
  '15',
  '17',
  '18',
  '20',
  '21',
  '45',
  '46',
  '65',
  '90'],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['4',
  '5',
  '6',
  '7',
  '8',
  '9',
  '10',
  '11',
  '12',
  '13',
  '15',
  '16',
  '18',
  '19',
  '20',
  '21',
  '22',
  '23',
  '24',
  '25',
  '26',
  '27',
  '28',
  '29',
  '32',
  '33',
  '36',
  '38',
  '41',
  '42',
  '43',
  '44',
  '45',
  '46',
  '47',
  '48',
  '49',
  '51',
  '52',
  '53',
  '55',
  '56',
  '57',
  '58',
  '59',
  '60',
  '61',
  '62',
  '

## all tokens

In [20]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
        # for word in words[1:]:
            # word = words[-1]
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            token_in_topK = 'no'
            if " " + prompt[word] in pred_tokens or prompt[word] in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])
                token_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, token_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    # percent_right = (n_right / (dataset.N )) * 100
    if percent_right > 0:
        print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        # return words_moved
        return percent_right

In [21]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
        # for word in words[1:]:
            # word = words[-1]
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get next member after digit prompt[word]
            next_word = str(int(prompt[word]) + 1)

            nextToken_in_topK = 'no'
            if " " + next_word in pred_tokens or next_word in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])
                nextToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    # percent_right = (n_right / (dataset.N )) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        # return words_moved
        return percent_right

In [22]:
top_val = [(1,5), (4,4), (6,1), (7, 10), (7, 11), (8, 11), (9, 1)]
for index, (layer, head) in enumerate(top_val):
    print(index, get_copy_scores(model, layer, head, dataset, print_tokens=False))

0 0.0
1 0.0
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 1.7500000000000002%
2 1.7500000000000002
Copy circuit for head 7.10 (sign=1) : Top 5 accuracy: 87.0%
3 87.0
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 17.25%
4 17.25
Copy circuit for head 8.11 (sign=1) : Top 5 accuracy: 76.0%
5 76.0
Copy circuit for head 9.1 (sign=1) : Top 5 accuracy: 60.25%
6 60.25


In [23]:
top_val = [(1,5), (4,4), (6,1), (7, 10), (7, 11), (8, 11), (9, 1)]
for index, (layer, head) in enumerate(top_val):
    print(index, get_next_scores(model, layer, head, dataset, print_tokens=False))

0 0.0
1 0.0
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 14.499999999999998%
2 14.499999999999998
Next circuit for head 7.10 (sign=1) : Top 5 accuracy: 48.75%
3 48.75
Next circuit for head 7.11 (sign=1) : Top 5 accuracy: 9.0%
4 9.0
Next circuit for head 8.11 (sign=1) : Top 5 accuracy: 26.0%
5 26.0
Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 87.25%
6 87.25


In [24]:
top_val = [(0, 10), (0, 1), (5,5), (6,1), (7, 10), (8,8), (7,11), (8,11), (9,1), (9,5), (10,7)]
for index, (layer, head) in enumerate(top_val):
    print(index, get_next_scores(model, layer, head, dataset, print_tokens=False))

0 0.0
1 0.0
Next circuit for head 5.5 (sign=1) : Top 5 accuracy: 0.25%
2 0.25
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 14.499999999999998%
3 14.499999999999998
Next circuit for head 7.10 (sign=1) : Top 5 accuracy: 48.75%
4 48.75
Next circuit for head 8.8 (sign=1) : Top 5 accuracy: 55.25%
5 55.25
Next circuit for head 7.11 (sign=1) : Top 5 accuracy: 9.0%
6 9.0
Next circuit for head 8.11 (sign=1) : Top 5 accuracy: 26.0%
7 26.0
Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 87.25%
8 87.25
9 0.0
10 0.0


Compare these scores to copy scores

In [25]:
top_val = [(0, 10), (0, 1), (5,5), (6,1), (7, 10), (8,8), (7,11), (8,11), (9,1), (9,5), (10,7)]
for index, (layer, head) in enumerate(top_val):
    print(index, get_copy_scores(model, layer, head, dataset, print_tokens=False))

0 0.0
1 0.0
Copy circuit for head 5.5 (sign=1) : Top 5 accuracy: 0.75%
2 0.75
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 1.7500000000000002%
3 1.7500000000000002
Copy circuit for head 7.10 (sign=1) : Top 5 accuracy: 87.0%
4 87.0
Copy circuit for head 8.8 (sign=1) : Top 5 accuracy: 37.0%
5 37.0
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 17.25%
6 17.25
Copy circuit for head 8.11 (sign=1) : Top 5 accuracy: 76.0%
7 76.0
Copy circuit for head 9.1 (sign=1) : Top 5 accuracy: 60.25%
8 60.25
9 0.0
10 0.0


In [26]:
get_next_scores(model, 8, 8, dataset)

Next circuit for head 8.8 (sign=1) : Top 5 accuracy: 55.25%


{'1': ([' multiplayer', ' Multiplayer', 'Thread', ' Instant', ' Realm'],
  '2',
  'no'),
 '2': ([' multiplayer', 'nerg', 'Thread', ' Instant', ' Multiplayer'],
  '3',
  'no'),
 '3': ([' multiplayer', 'Thread', 'nerg', 'Assembly', ' Multiplayer'],
  '4',
  'no'),
 '4': (['Thread', 'Assembly', ' multiplayer', 'nerg', 'ortal'], '5', 'no'),
 '5': (['Thread', 'Cooldown', '�', 'nerg', ' Realm'], '6', 'no'),
 '6': (['Thread', 'Cooldown', 'nerg', 'pload', '�'], '7', 'no'),
 '7': (['Thread', 'thread', ' multiplayer', '�', ' multit'], '8', 'no'),
 '8': (['Thread', 'Cooldown', 'ROM', 'RAM', '�'], '9', 'no'),
 '9': (['Thread', ' Realm', 'abyte', 'Cooldown', ' Realms'], '10', 'no'),
 '10': (['Cooldown', 'Thread', 'Timeout', '�', ' cooldown'], '11', 'no'),
 '11': (['Thread', ' MSI', '━', 'Mom', ' multiplayer'], '12', 'no'),
 '12': (['Cooldown', 'Thread', 'Timeout', 'ynchron', 'Mom'], '13', 'no'),
 '13': (['Cooldown', 'Thread', ' cooldown', ' binds', 'ynchron'], '14', 'no'),
 '14': (['Cooldown', 'Thr

# loop over all heads

only print out if there's a match

In [35]:
all_copy_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_copy_scores.append(get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 1.7 (sign=1) : Top 5 accuracy: 2.0618556701030926%
Copy circuit for head 2.1 (sign=1) : Top 5 accuracy: 1.5463917525773196%
Copy circuit for head 2.4 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Copy circuit for head 3.5 (sign=1) : Top 5 accuracy: 1.2886597938144329%
Copy circuit for head 4.8 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Copy circuit for head 5.0 (sign=1) : Top 5 accuracy: 11.34020618556701%
Copy circuit for head 5.1 (sign=1) : Top 5 accuracy: 48.71134020618557%
Copy circuit for head 5.5 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 1.804123711340206%
Copy circuit for head 6.4 (sign=1) : Top 5 accuracy: 0.5154639175257731%
Copy circuit for head 6.9 (sign=1) : Top 5 accuracy: 83.76288659793815%
Copy circuit for head 6.10 (sign=1) : Top 5 accuracy: 1.0309278350515463%
Copy circuit for head 7.1 (sign=1) : Top 5 accuracy: 1.0309278350515463%
Copy circuit for head 7.2 (sign=1) : Top 5 accuracy

In [36]:
sum(all_copy_scores)/len(all_copy_scores)

5.725587056128294

In [37]:
all_next_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_next_scores.append(get_next_scores(model, layer, head, dataset, print_tokens=False))

Next circuit for head 1.7 (sign=1) : Top 5 accuracy: 0.7731958762886598%
Next circuit for head 2.1 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Next circuit for head 3.5 (sign=1) : Top 5 accuracy: 2.3195876288659796%
Next circuit for head 4.9 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Next circuit for head 5.0 (sign=1) : Top 5 accuracy: 7.731958762886598%
Next circuit for head 5.1 (sign=1) : Top 5 accuracy: 2.0618556701030926%
Next circuit for head 5.5 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 14.948453608247423%
Next circuit for head 6.9 (sign=1) : Top 5 accuracy: 57.73195876288659%
Next circuit for head 7.2 (sign=1) : Top 5 accuracy: 31.185567010309278%
Next circuit for head 7.7 (sign=1) : Top 5 accuracy: 1.5463917525773196%
Next circuit for head 7.8 (sign=1) : Top 5 accuracy: 1.2886597938144329%
Next circuit for head 7.10 (sign=1) : Top 5 accuracy: 48.71134020618557%
Next circuit for head 7.11 (sign=1) : Top 5 accura

In [38]:
sum(all_next_scores)/len(all_next_scores)

3.2932416953035517

# months

In [60]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
        # for word in words[1:]:
            # word = words[-1]
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get next member after digit prompt[word]
            next_word = str(months[months.index(prompt[word]) + 1])

            nextToken_in_topK = 'no'
            if " " + next_word in pred_tokens or next_word in pred_tokens:
                print(prompt[word], next_word)
                n_right += 1
                words_moved.append(prompt[word])
                nextToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    # percent_right = (n_right / (dataset.N )) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        # return words_moved
        return percent_right

In [51]:
def generate_prompts_list():
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    prompts_list = []
    for i in range(0, 8):
        prompt_dict = {
            'S1': months[i],
            'S2': months[i+1],
            'S3': months[i+2],
            'S4': months[i+3],
            'text': f"{months[i]} {months[i+1]} {months[i+2]} {months[i+3]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list()
print(prompts_list)

[{'S1': 'January', 'S2': 'February', 'S3': 'March', 'S4': 'April', 'text': 'January February March April'}, {'S1': 'February', 'S2': 'March', 'S3': 'April', 'S4': 'May', 'text': 'February March April May'}, {'S1': 'March', 'S2': 'April', 'S3': 'May', 'S4': 'June', 'text': 'March April May June'}, {'S1': 'April', 'S2': 'May', 'S3': 'June', 'S4': 'July', 'text': 'April May June July'}, {'S1': 'May', 'S2': 'June', 'S3': 'July', 'S4': 'August', 'text': 'May June July August'}, {'S1': 'June', 'S2': 'July', 'S3': 'August', 'S4': 'September', 'text': 'June July August September'}, {'S1': 'July', 'S2': 'August', 'S3': 'September', 'S4': 'October', 'text': 'July August September October'}, {'S1': 'August', 'S2': 'September', 'S3': 'October', 'S4': 'November', 'text': 'August September October November'}]


In [52]:
class Dataset:
    def __init__(self, prompts, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [53]:
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

## loop over all heads

only print out if there's a match

In [54]:
all_copy_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_copy_scores.append(get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 5.0 (sign=1) : Top 5 accuracy: 34.375%
Copy circuit for head 5.1 (sign=1) : Top 5 accuracy: 31.25%
Copy circuit for head 5.4 (sign=1) : Top 5 accuracy: 6.25%
Copy circuit for head 5.7 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 6.0 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 6.9 (sign=1) : Top 5 accuracy: 81.25%
Copy circuit for head 6.10 (sign=1) : Top 5 accuracy: 6.25%
Copy circuit for head 7.2 (sign=1) : Top 5 accuracy: 75.0%
Copy circuit for head 7.5 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 7.7 (sign=1) : Top 5 accuracy: 9.375%
Copy circuit for head 7.8 (sign=1) : Top 5 accuracy: 59.375%
Copy circuit for head 7.10 (sign=1) : Top 5 accuracy: 71.875%
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 71.875%
Copy circuit for head 8.1 (sign=1) : Top 5 accuracy: 75.0%
Copy circuit for head 8.6 (sign=1) : Top 5 accuracy: 9.375%
Copy circuit for head 8.7 (sign=1) : 

In [55]:
sum(all_copy_scores)/len(all_copy_scores)

7.400173611111111

In [61]:
all_next_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_next_scores.append(get_next_scores(model, layer, head, dataset, print_tokens=False))

September October
Next circuit for head 3.5 (sign=1) : Top 5 accuracy: 3.125%
May June
May June
June July
June July
July August
July August
November December
Next circuit for head 5.0 (sign=1) : Top 5 accuracy: 21.875%
May June
June July
July August
August September
Next circuit for head 5.1 (sign=1) : Top 5 accuracy: 12.5%
July August
July August
Next circuit for head 5.4 (sign=1) : Top 5 accuracy: 6.25%
February March
May June
May June
May June
July August
July August
July August
August September
November December
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 28.125%
May June
May June
June July
May June
June July
July August
July August
November December
Next circuit for head 6.9 (sign=1) : Top 5 accuracy: 25.0%
June July
June July
Next circuit for head 6.10 (sign=1) : Top 5 accuracy: 6.25%
June July
June July
September October
September October
Next circuit for head 7.2 (sign=1) : Top 5 accuracy: 12.5%
September October
September October
October November
October November
Nove

In [57]:
sum(all_next_scores)/len(all_next_scores)

3.1901041666666665

# months try months but check if NUMERAL IS IN IT

In [70]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    ranks = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
        # for word in words[1:]:
            # word = words[-1]
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get next member after digit prompt[word]
            next_word = str(ranks[months.index(prompt[word]) + 1])

            nextToken_in_topK = 'no'
            if " " + next_word in pred_tokens or next_word in pred_tokens:
                # print(prompt[word], next_word)
                n_right += 1
                words_moved.append(prompt[word])
                nextToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)
        # if print_tokens == True:
        #     if " " + next_word in pred_tokens or next_word in pred_tokens:
        #         print(pred_tokens_dict)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    # percent_right = (n_right / (dataset.N )) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        # print(pred_tokens_dict)
        return pred_tokens_dict
    else:
        # return words_moved
        return percent_right

In [None]:
def generate_prompts_list():
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    prompts_list = []
    for i in range(0, 8):
        prompt_dict = {
            'S1': months[i],
            'S2': months[i+1],
            'S3': months[i+2],
            'S4': months[i+3],
            'text': f"{months[i]} {months[i+1]} {months[i+2]} {months[i+3]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list()
print(prompts_list)

[{'S1': 'January', 'S2': 'February', 'S3': 'March', 'S4': 'April', 'text': 'January February March April'}, {'S1': 'February', 'S2': 'March', 'S3': 'April', 'S4': 'May', 'text': 'February March April May'}, {'S1': 'March', 'S2': 'April', 'S3': 'May', 'S4': 'June', 'text': 'March April May June'}, {'S1': 'April', 'S2': 'May', 'S3': 'June', 'S4': 'July', 'text': 'April May June July'}, {'S1': 'May', 'S2': 'June', 'S3': 'July', 'S4': 'August', 'text': 'May June July August'}, {'S1': 'June', 'S2': 'July', 'S3': 'August', 'S4': 'September', 'text': 'June July August September'}, {'S1': 'July', 'S2': 'August', 'S3': 'September', 'S4': 'October', 'text': 'July August September October'}, {'S1': 'August', 'S2': 'September', 'S3': 'October', 'S4': 'November', 'text': 'August September October November'}]


In [None]:
class Dataset:
    def __init__(self, prompts, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [None]:
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

## loop over all heads

only print out if there's a match

In [None]:
all_copy_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_copy_scores.append(get_copy_scores(model, layer, head, dataset, print_tokens=False))

Copy circuit for head 5.0 (sign=1) : Top 5 accuracy: 34.375%
Copy circuit for head 5.1 (sign=1) : Top 5 accuracy: 31.25%
Copy circuit for head 5.4 (sign=1) : Top 5 accuracy: 6.25%
Copy circuit for head 5.7 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 6.0 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 6.9 (sign=1) : Top 5 accuracy: 81.25%
Copy circuit for head 6.10 (sign=1) : Top 5 accuracy: 6.25%
Copy circuit for head 7.2 (sign=1) : Top 5 accuracy: 75.0%
Copy circuit for head 7.5 (sign=1) : Top 5 accuracy: 3.125%
Copy circuit for head 7.7 (sign=1) : Top 5 accuracy: 9.375%
Copy circuit for head 7.8 (sign=1) : Top 5 accuracy: 59.375%
Copy circuit for head 7.10 (sign=1) : Top 5 accuracy: 71.875%
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 71.875%
Copy circuit for head 8.1 (sign=1) : Top 5 accuracy: 75.0%
Copy circuit for head 8.6 (sign=1) : Top 5 accuracy: 9.375%
Copy circuit for head 8.7 (sign=1) : 

In [None]:
sum(all_copy_scores)/len(all_copy_scores)

7.400173611111111

In [73]:
all_next_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_next_scores.append(get_next_scores(model, layer, head, dataset, print_tokens=False))

Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 31.25%


In [74]:
sum(all_next_scores)/len(all_next_scores)

0.2170138888888889

In [67]:
get_next_scores(model, 9, 1, dataset)

{'January': ([' once', 'once', ' occasional', ' seventh', '296'],
  'February',
  'no'),
 'February': ([' third', ' fourth', ' once', ' seventh', ' fifth'],
  'March',
  'no'),
 'March': ([' fifth', ' seventh', ' fourth', ' sixth', ' third'],
  'April',
  'no'),
 'April': ([' seventh', ' fifth', ' sixth', 'ISC', ' five'], 'May', 'no'),
 'May': ([' seventh', ' sixth', 'ISC', ' once', ' 157'], 'June', 'no'),
 'June': ([' seventh', ' third', ' seven', ' sixth', 'seven'], 'July', 'no'),
 'July': ([' seventh', ' eighth', 'seven', ' once', ' seven'], 'August', 'no'),
 'August': (['ighth', ' eighth', ' ninth', ' final', ' occasional'],
  'September',
  'no'),
 'September': ([' 120', 'ure', 'Loading', '�', ' Nguyen'], 'October', 'no'),
 'October': ([' Nur', 'undrum', '�', '�', ' 121'], 'November', 'no'),
 'November': (['oor', 'ة', ' 122', 'Enlarge', ' Nur'], 'December', 'no')}

# nw

In [81]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    # months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    months = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
        # for word in words[1:]:
            # word = words[-1]
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get next member after digit prompt[word]
            next_word = str(months[months.index(prompt[word]) + 1])

            nextToken_in_topK = 'no'
            if " " + next_word in pred_tokens or next_word in pred_tokens:
                # print(prompt[word], next_word)
                n_right += 1
                words_moved.append(prompt[word])
                nextToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    # percent_right = (n_right / (dataset.N )) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens == True:
        return pred_tokens_dict
    else:
        # return words_moved
        return percent_right

In [77]:
def generate_prompts_list(x ,y):
    words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': words[i],
            'S2': words[i+1],
            'S3': words[i+2],
            'S4': words[i+3],
            'text': f"{words[i]} {words[i+1]} {words[i+2]} {words[i+3]}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(0, 8)
prompts_list

[{'S1': 'one',
  'S2': 'two',
  'S3': 'three',
  'S4': 'four',
  'text': 'one two three four'},
 {'S1': 'two',
  'S2': 'three',
  'S3': 'four',
  'S4': 'five',
  'text': 'two three four five'},
 {'S1': 'three',
  'S2': 'four',
  'S3': 'five',
  'S4': 'six',
  'text': 'three four five six'},
 {'S1': 'four',
  'S2': 'five',
  'S3': 'six',
  'S4': 'seven',
  'text': 'four five six seven'},
 {'S1': 'five',
  'S2': 'six',
  'S3': 'seven',
  'S4': 'eight',
  'text': 'five six seven eight'},
 {'S1': 'six',
  'S2': 'seven',
  'S3': 'eight',
  'S4': 'nine',
  'text': 'six seven eight nine'},
 {'S1': 'seven',
  'S2': 'eight',
  'S3': 'nine',
  'S4': 'ten',
  'text': 'seven eight nine ten'},
 {'S1': 'eight',
  'S2': 'nine',
  'S3': 'ten',
  'S4': 'eleven',
  'text': 'eight nine ten eleven'}]

In [78]:
class Dataset:
    def __init__(self, prompts, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [79]:
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

## loop over all heads

only print out if there's a match

In [82]:
all_next_scores = []
all_heads = [(layer, head) for layer in range(12) for head in range(12)]
for index, (layer, head) in enumerate(all_heads):
    all_next_scores.append(get_next_scores(model, layer, head, dataset, print_tokens=False))

Next circuit for head 1.7 (sign=1) : Top 5 accuracy: 3.125%
Next circuit for head 2.1 (sign=1) : Top 5 accuracy: 3.125%
Next circuit for head 4.8 (sign=1) : Top 5 accuracy: 3.125%
Next circuit for head 5.0 (sign=1) : Top 5 accuracy: 9.375%
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 75.0%
Next circuit for head 7.2 (sign=1) : Top 5 accuracy: 21.875%
Next circuit for head 7.7 (sign=1) : Top 5 accuracy: 18.75%
Next circuit for head 7.11 (sign=1) : Top 5 accuracy: 6.25%
Next circuit for head 8.1 (sign=1) : Top 5 accuracy: 43.75%
Next circuit for head 8.8 (sign=1) : Top 5 accuracy: 56.25%
Next circuit for head 9.1 (sign=1) : Top 5 accuracy: 90.625%
Next circuit for head 9.7 (sign=1) : Top 5 accuracy: 9.375%
Next circuit for head 10.2 (sign=1) : Top 5 accuracy: 43.75%
Next circuit for head 11.4 (sign=1) : Top 5 accuracy: 43.75%


In [83]:
sum(all_next_scores)/len(all_next_scores)

2.9730902777777777

# TBC

The cells below have not been updated yet for next scores so disregard them:

---



# Writing direction results with scatterplot

In [None]:
def scatter_attention_and_contribution(
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Plot a scatter plot
    for each input sequence with the attention paid to S
    and the amount that is written in the S directions
    """

    n_heads = model.cfg.n_heads
    n_layers = model.cfg.n_layers
    model_unembed = model.unembed.W_U.detach().cpu()
    df = []
    cache = {}
    model.cache_all(cache)

    logits = model(dataset.toks.long())

    for i, prompt in enumerate(dataset.prompts):
        s_toks = []
        s_positions = []
        s_dirs = []

        targ_tokens = [key for key in dataset.prompts[0].keys() if key != 'text']
        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # only use this if first token doesn't have space Ġ in front
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]
            s_toks.append(s_tok)

            toks = model.tokenizer(prompt["text"])["input_ids"]
            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            s_pos = toks.index(s_tok)
            s_positions.append(s_pos)

            s_dir = model_unembed[:, s_tok].detach()
            s_dirs.append(s_dir)

        for dire, posses, tok_type in zip(s_dirs, s_positions, targ_tokens):
            prob = sum(
                [
                    cache[f"blocks.{layer_no}.attn.hook_attn"][
                        i, head_no, dataset.word_idx["end"][i], pos
                    ]
                    .detach()
                    .cpu()
                    for pos in [posses]
                ]
            )
            resid = (
                cache[f"blocks.{layer_no}.attn.hook_result"][
                    i, dataset.word_idx["end"][i], head_no, :
                ]
                .detach()
                .cpu()
            )
            dot = torch.einsum("a,a->", resid, dire)
            df.append([prob, dot, tok_type, prompt["text"]])

    viz_df = pd.DataFrame(
        df, columns=[f"Attn Prob on Number", f"Dot w Number Embed", "Seq Position", "text"]
    )
    fig = px.scatter(
        viz_df,
        x=f"Attn Prob on Number",
        y=f"Dot w Number Embed",
        color="Seq Position",
        hover_data=["text"],
        title=f"How Strong {layer_no}.{head_no} Writes in the Number Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    else:
        fig.show()

In [None]:
scatter_attention_and_contribution(
    model=model, layer_no=9, head_no=1, dataset=dataset, S1_is_first=False
)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

## Correlation vals

In [None]:
def get_prob_dot(  # same as scatterplot, but output x and y vals instead of plotting
    model,
    layer_no,
    head_no,
    dataset,
    S1_is_first=False,
    return_vals=False,
    return_fig=False,
):
    """
    Plot a scatter plot
    for each input sequence with the attention paid to S
    and the amount that is written in the S directions
    """

    n_heads = model.cfg.n_heads
    n_layers = model.cfg.n_layers
    model_unembed = model.unembed.W_U.detach().cpu()
    # df = []
    all_prob = []
    all_dot = []
    cache = {}
    model.cache_all(cache)

    logits = model(dataset.toks.long())

    for i, prompt in enumerate(dataset.prompts):
        s_toks = []
        s_positions = []
        s_dirs = []

        targ_tokens = [key for key in dataset.prompts[0].keys() if key != 'text']
        for s_id in targ_tokens:
            if S1_is_first and s_id == "S1":  # only use this if first token doesn't have space Ġ in front
                s_tok = model.tokenizer(prompt["S1"])["input_ids"][0]
            else:
                s_tok = model.tokenizer(" " + prompt[s_id])["input_ids"][0]
            s_toks.append(s_tok)

            toks = model.tokenizer(prompt["text"])["input_ids"]
            try:
                s_pos = toks.index(s_tok)
            except ValueError:
                print(f"{s_tok} is not present in {toks}. Skipping...")
                continue

            s_pos = toks.index(s_tok)
            s_positions.append(s_pos)

            s_dir = model_unembed[:, s_tok].detach()
            s_dirs.append(s_dir)

        for dire, posses, tok_type in zip(s_dirs, s_positions, targ_tokens):
            prob = sum(
                [
                    cache[f"blocks.{layer_no}.attn.hook_attn"][
                        i, head_no, dataset.word_idx["end"][i], pos
                    ]
                    .detach()
                    .cpu()
                    for pos in [posses]
                ]
            )
            resid = (
                cache[f"blocks.{layer_no}.attn.hook_result"][
                    i, dataset.word_idx["end"][i], head_no, :
                ]
                .detach()
                .cpu()
            )
            dot = torch.einsum("a,a->", resid, dire)
            #df.append([prob, dot, tok_type, prompt["text"]])
            all_prob.append(prob)
            all_dot.append(dot)

    return all_prob, all_dot


In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=9, dataset=dataset, S1_is_first=False
)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
import scipy.stats as stats

# X and Y should be arrays, lists, or pandas Series
correlation, p_value = stats.pearsonr(all_prob, all_dot)

print("Correlation:", correlation)
print("p-value:", p_value)

Correlation: 0.8127540109290008
p-value: 8.99435128806603e-70


In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=9, head_no=1, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=7, head_no=10, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=5, head_no=1, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen

In [None]:
all_prob, all_dot = get_prob_dot(
    model=model, layer_no=0, head_no=3, dataset=dataset, S1_is_first=False
)

correlation, p_value = stats.pearsonr(all_prob, all_dot)
print("Correlation:", correlation)
print("p-value:", p_value)

352 is not present in [16, 362, 513, 604]. Skipping...
362 is not present in [17, 513, 604, 642]. Skipping...
513 is not present in [18, 604, 642, 718]. Skipping...
604 is not present in [19, 642, 718, 767]. Skipping...
642 is not present in [20, 718, 767, 807]. Skipping...
718 is not present in [21, 767, 807, 860]. Skipping...
767 is not present in [22, 807, 860, 838]. Skipping...
807 is not present in [23, 860, 838, 1367]. Skipping...
860 is not present in [24, 838, 1367, 1105]. Skipping...
838 is not present in [940, 1367, 1105, 1511]. Skipping...
1367 is not present in [1157, 1105, 1511, 1478]. Skipping...
1105 is not present in [1065, 1511, 1478, 1315]. Skipping...
1511 is not present in [1485, 1478, 1315, 1467]. Skipping...
1478 is not present in [1415, 1315, 1467, 1596]. Skipping...
1315 is not present in [1314, 1467, 1596, 1248]. Skipping...
1467 is not present in [1433, 1596, 1248, 678]. Skipping...
1596 is not present in [1558, 1248, 678, 1160]. Skipping...
1248 is not presen