<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup

In [1]:
%%capture
!pip install accelerate

In [2]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
%pip install einops



In [4]:
import torch
import numpy as np
import einops
from copy import deepcopy

# Load Model

In [8]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [9]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [10]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

# tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)
tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH, use_fast= False, add_prefix_space= False)
hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [12]:
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

In [13]:
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    hf_model = hf_model,
    tokenizer = tokenizer,
    device = "cpu",
    fold_ln = False,
    center_writing_weights = False,
    center_unembed = False,
)

del hf_model

model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda


# Dataset and Score Functions

In [16]:
class Dataset:
    def __init__(self, prompts, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        pos_dict = {}
        list_tokens = tokenizer.tokenize('2 4 6 ')
        for i, tok_as_str in enumerate(list_tokens):
            pos_dict['S'+str(i)] = i

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [17]:
model.blocks[1].attn

Attention(
  (hook_k): HookPoint()
  (hook_q): HookPoint()
  (hook_v): HookPoint()
  (hook_z): HookPoint()
  (hook_attn_scores): HookPoint()
  (hook_pattern): HookPoint()
  (hook_result): HookPoint()
  (hook_rot_k): HookPoint()
  (hook_rot_q): HookPoint()
)

In [19]:
def get_copy_scores(model, layer, head, dataset, neg=False, print_all_results=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    # z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])
    z_0 = model.blocks[1].attn.hook_z(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            token_in_topK = 'no'
            if " " + prompt[word] in pred_tokens or prompt[word] in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])
                token_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, token_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    if percent_right > 0:
        print(f"Copy circuit for head {layer}.{head}: Top {k} accuracy: {percent_right}%")

    if print_all_results == True:
        print(pred_tokens_dict)

    return percent_right

In [18]:
def get_next_scores(model, layer, head, dataset, task="numerals", neg=False, print_all_results=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    # z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])
    z_0 = model.blocks[1].attn.hook_z(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    numwords = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    ranks = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get next member after seq member prompt[word]
            if task == "numerals":
                next_word = str(int(prompt[word]) + 1)
            elif task == "numwords":
                next_word = str(numwords[numwords.index(prompt[word]) + 1])
            elif task == "months":
                next_word = str(ranks[months.index(prompt[word]) + 1])

            nextToken_in_topK = 'no'
            if " " + next_word in pred_tokens or next_word in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])
                nextToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head}: Top {k} accuracy: {percent_right}%")

    if print_all_results == True:
        print(pred_tokens_dict)

    return percent_right

In [20]:
def get_next_next_scores(model, layer, head, dataset, task="numerals", neg=False, print_all_results=True):
    cache = {}
    model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
    model(dataset.toks.long())
    if neg:
        sign = -1
    else:
        sign = 1
    # z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])
    z_0 = model.blocks[1].attn.hook_z(cache["blocks.0.hook_resid_post"])

    v = torch.einsum("eab,bc->eac", z_0, model.blocks[layer].attn.W_V[head])
    v += model.blocks[layer].attn.b_V[head].unsqueeze(0).unsqueeze(0)

    o = sign * torch.einsum("sph,hd->spd", v, model.blocks[layer].attn.W_O[head])
    logits = model.unembed(model.ln_final(o))

    k = 5
    n_right = 0

    pred_tokens_dict = {}
    words_moved = []
    # get the keys from the first prompt in the dataset
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    # numwords = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']
    # months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
    # ranks = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            pred_tokens = [
                model.tokenizer.decode(token)
                for token in torch.topk(
                    logits[seq_idx, dataset.word_idx[word][seq_idx]], k
                ).indices
            ]

            # get next NEXT member after seq member prompt[word]
            if task == "numerals":
                next_word = str(int(prompt[word]) + 2)
            # elif task == "numwords":
            #     next_word = str(numwords[numwords.index(prompt[word]) + 1])
            # elif task == "months":
            #     next_word = str(ranks[months.index(prompt[word]) + 1])

            nextToken_in_topK = 'no'
            if " " + next_word in pred_tokens or next_word in pred_tokens:
                n_right += 1
                words_moved.append(prompt[word])
                nextToken_in_topK = 'yes'
            pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, nextToken_in_topK)

    percent_right = (n_right / (dataset.N * len(words))) * 100
    if percent_right > 0:
        print(f"Next next score for head {layer}.{head}: Top {k} accuracy: {percent_right}%")

    if print_all_results == True:
        print(pred_tokens_dict)

    return percent_right

# Numerals

## make dataset

In [21]:
def generate_prompts():
    prompts_list = []
    # prompt_dict = {
    #     # 'corr': '8',
    #     # 'incorr': '6',
    #     'text': f"2 4 6 "
    # }
    # list_tokens = tokenizer.tokenize('2 4 6 ')
    # for i, tok_as_str in enumerate(list_tokens):
    #     if tok_as_str == '▁':
    #         prompt_dict['S'+str(i)] = ' '
    #     else:
    #         prompt_dict['S'+str(i)] = tok_as_str
    prompt_dict = {
        'S1': '2',
        'S2': '4',
        'S3': '6',
        'text': f"2 4 6 "
    }
    prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts()
prompts_list

[{'S1': '2', 'S2': '4', 'S3': '6', 'text': '2 4 6 '}]

In [22]:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [23]:
dataset = Dataset(prompts_list, tokenizer)

## get copy scores

In [None]:
get_copy_scores(model, 16, 0, dataset)



{'2': (['YES', ' yes', ' YES', 'yes', ' Yes'], 'no'), '4': ([' Yes', ' yes', ' Хронологија', ' YES', ' Inner'], 'no'), '6': ([' consec', '<0x94>', ' yes', 'bay', 'yes'], 'no')}


0.0

In [None]:
# %%capture
# all_copy_scores = []
all_copy_scores = {}
all_heads = [(layer, head) for layer in range(32) for head in range(32)]
for index, (layer, head) in enumerate(all_heads):
    # all_copy_scores.append(get_copy_scores(model, layer, head, dataset, print_all_results=False))
    all_copy_scores[(layer, head)] = get_copy_scores(model, layer, head, dataset, print_all_results=False)

In [None]:
k=5
for (layer, head), percent_right in all_copy_scores.items():
    if percent_right > 0:
        print(f"Copy circuit for head {layer}.{head}: Top {k} accuracy: {percent_right}%")

Copy circuit for head 6.9: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 6.16: Top 5 accuracy: 66.66666666666666%
Copy circuit for head 7.10: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 11.15: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 12.2: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 12.26: Top 5 accuracy: 66.66666666666666%
Copy circuit for head 16.19: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 16.24: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 17.12: Top 5 accuracy: 66.66666666666666%
Copy circuit for head 17.22: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 17.31: Top 5 accuracy: 66.66666666666666%
Copy circuit for head 22.31: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 24.29: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 25.3: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 27.9: Top 5 accuracy: 33.33333333333333%
Copy circuit for head 29.14: Top 5 accuracy: 33

In [None]:
sum(all_copy_scores.values())/len(all_copy_scores)

0.7486979166666667

## get next scores

In [None]:
# %%capture
# all_next_scores = []
all_next_scores = {}
all_heads = [(layer, head) for layer in range(32) for head in range(32)]
for index, (layer, head) in enumerate(all_heads):
    # all_next_scores.append(get_next_scores(model, layer, head, dataset, task="numerals", print_all_results=False))
    all_next_scores[(layer, head)] = get_next_scores(model, layer, head, dataset, print_all_results=False)

In [28]:
k=5
for (layer, head), percent_right in all_next_scores.items():
    if percent_right > 0:
        print(f"Next score for head {layer}.{head}: Top {k} accuracy: {percent_right}%")

Next score for head 6.16: Top 5 accuracy: 33.33333333333333%
Next score for head 12.26: Top 5 accuracy: 33.33333333333333%
Next score for head 17.12: Top 5 accuracy: 33.33333333333333%
Next score for head 20.17: Top 5 accuracy: 33.33333333333333%
Next score for head 22.31: Top 5 accuracy: 66.66666666666666%
Next score for head 24.29: Top 5 accuracy: 33.33333333333333%
Next score for head 28.5: Top 5 accuracy: 33.33333333333333%
Next score for head 29.7: Top 5 accuracy: 33.33333333333333%
Next score for head 29.14: Top 5 accuracy: 33.33333333333333%
Next score for head 30.13: Top 5 accuracy: 66.66666666666666%
Next score for head 31.10: Top 5 accuracy: 33.33333333333333%
Next score for head 31.30: Top 5 accuracy: 33.33333333333333%


In [27]:
sum(all_next_scores.values())/len(all_next_scores)

0.4557291666666665

In [None]:
get_next_scores(model, 22, 31, dataset, task="numerals")



{'2': (['YES', ' yes', ' YES', 'yes', ' Yes'], '3', 'no'), '4': ([' Yes', ' yes', ' Хронологија', ' YES', ' Inner'], '5', 'no'), '6': ([' consec', '<0x94>', ' yes', 'bay', 'yes'], '7', 'no')}


0.0

In [30]:
get_next_scores(model, 30, 13, dataset, task="numerals")



Next circuit for head 30.13: Top 5 accuracy: 66.66666666666666%
{'2': (['0', '7', '9', '8', '6'], '3', 'no'), '4': (['0', '5', '7', '9', '6'], '5', 'yes'), '6': (['0', '5', '7', '8', '6'], '7', 'yes')}


66.66666666666666

## next next scores

In [24]:
get_next_next_scores(model, 16, 0, dataset, task="numerals")



{'2': (['YES', ' yes', ' YES', 'yes', ' Yes'], '4', 'no'), '4': ([' Yes', ' yes', ' Хронологија', ' YES', ' Inner'], '6', 'no'), '6': ([' consec', '<0x94>', ' yes', 'bay', 'yes'], '8', 'no')}


0.0

In [None]:
all_next_next_scores = {}
all_heads = [(layer, head) for layer in range(32) for head in range(32)]
for index, (layer, head) in enumerate(all_heads):
    percent_right = get_next_next_scores(model, layer, head, dataset, print_all_results=False)
    if percent_right > 0:
        all_next_next_scores[(layer, head)] = percent_right

In [31]:
k=5
for (layer, head), percent_right in all_next_next_scores.items():
    if percent_right > 0:
        print(f"Next +2 score for head {layer}.{head}: Top {k} accuracy: {percent_right}%")

Next +2 score for head 12.26: Top 5 accuracy: 33.33333333333333%
Next +2 score for head 17.12: Top 5 accuracy: 33.33333333333333%
Next +2 score for head 22.31: Top 5 accuracy: 33.33333333333333%
Next +2 score for head 28.24: Top 5 accuracy: 33.33333333333333%
Next +2 score for head 29.14: Top 5 accuracy: 33.33333333333333%
Next +2 score for head 30.13: Top 5 accuracy: 66.66666666666666%
Next +2 score for head 31.30: Top 5 accuracy: 66.66666666666666%


In [34]:
for (layer, head), percent_right in all_next_next_scores.items():
    if percent_right > 0:
        get_next_next_scores(model, layer, head, dataset, print_all_results=True)



Next next score for head 12.26: Top 5 accuracy: 33.33333333333333%
{'2': (['2', '3', '1', '4', '6'], '4', 'yes'), '4': ([' and', 'koz', ' <', 'ank', ' Pear'], '6', 'no'), '6': (['3', '4', '2', '6', '1'], '8', 'no')}
Next next score for head 17.12: Top 5 accuracy: 33.33333333333333%
{'2': (['2', '1', '3', '4', '5'], '4', 'yes'), '4': (['_', 'US', ' ::', ' US', '<0xAB>'], '6', 'no'), '6': (['-', '_', '/', '#', '6'], '8', 'no')}




Next next score for head 22.31: Top 5 accuracy: 33.33333333333333%
{'2': (['3', '4', '5', '6', '7'], '4', 'yes'), '4': (['widet', '**', 'ii', ' (**', ' $('], '6', 'no'), '6': (['6', '4', '3', '7', '5'], '8', 'no')}
Next next score for head 28.24: Top 5 accuracy: 33.33333333333333%
{'2': (['nd', 'ND', 'nde', 'nder', 'ns'], '4', 'no'), '4': (['nd', 'ND', '0', 'nde', '6'], '6', 'yes'), '6': (['nd', 'ND', 'nde', 'itu', 'нд'], '8', 'no')}




Next next score for head 29.14: Top 5 accuracy: 33.33333333333333%
{'2': (['2', '1', '4', '3', '0'], '4', 'yes'), '4': (['ery', '.', 'now', '-', 'ware'], '6', 'no'), '6': (['-', 'ble', '.', 'ery', '1'], '8', 'no')}
Next next score for head 30.13: Top 5 accuracy: 66.66666666666666%
{'2': (['0', '7', '9', '8', '6'], '4', 'no'), '4': (['0', '5', '7', '9', '6'], '6', 'yes'), '6': (['0', '5', '7', '8', '6'], '8', 'yes')}
Next next score for head 31.30: Top 5 accuracy: 66.66666666666666%
{'2': (['2', '3', '1', '4', '6'], '4', 'yes'), '4': ([' twenty', ' .', ' necess', ' twelve', ' thirty'], '6', 'no'), '6': (['4', '5', '6', '3', '8'], '8', 'yes')}


In [33]:
sum(all_next_next_scores.values())/len(all_next_next_scores)

0.29296874999999994

In [32]:
get_next_next_scores(model, 31, 30, dataset, task="numerals")



Next next score for head 31.30: Top 5 accuracy: 66.66666666666666%
{'2': (['2', '3', '1', '4', '6'], '4', 'yes'), '4': ([' twenty', ' .', ' necess', ' twelve', ' thirty'], '6', 'no'), '6': (['4', '5', '6', '3', '8'], '8', 'yes')}


66.66666666666666