<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;Open this notebook in Colab with the badge, or run it in a local notebook.

# Setup

In [1]:
%%capture
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
%pip install einops datasets transformers fancy_einsum

In [2]:
from copy import deepcopy
import torch

assert torch.cuda.device_count() == 1
from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

# Enable autoreload so edits to imported modules are picked up without
# restarting the kernel. get_ipython() returns None outside IPython, in which
# case there is nothing to configure.
ipython = get_ipython()
if ipython is not None:
    # InteractiveShell.magic() is deprecated; run_line_magic(name, line) is
    # the supported way to invoke a line magic programmatically.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Initialise the model (use a larger N or fewer templates to avoid warnings about in-template ablation).

In [3]:
# Load the pretrained GPT-2 medium checkpoint, then move it onto the GPU.
# To experiment with the small model instead, load "gpt2" here:
#     model = EasyTransformer.from_pretrained("gpt2")
model = EasyTransformer.from_pretrained("gpt2-medium")
model = model.cuda()
# NOTE(review): presumably exposes each attention head's individual output to
# hooks -- confirm against the EasyTransformer documentation.
model.set_use_attn_result(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Moving model to device:  cuda
Finished loading pretrained model gpt2-medium into EasyTransformer!


In [4]:
import pdb

# Generate dataset with multiple prompts

In [5]:
def generate_prompts_list(start=1, stop=98, seq_len=4):
    """Build prompts made of consecutive, increasing integers.

    For every ``i`` in ``range(start, stop)`` one prompt is produced whose
    members are ``i, i+1, ..., i+seq_len-1``.

    Args:
        start: First value of ``i`` (inclusive). Defaults to 1.
        stop: End of the range of ``i`` (exclusive). Defaults to 98.
        seq_len: Number of sequence members per prompt. Defaults to 4.

    Returns:
        A list of dicts. Each dict maps ``'S1'`` .. ``'S<seq_len>'`` to the
        members as strings, followed by ``'text'``, the members joined by
        single spaces. With the defaults this yields 97 prompts, from
        ``"1 2 3 4"`` to ``"97 98 99 100"``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")

    prompts_list = []
    for i in range(start, stop):
        members = [str(i + offset) for offset in range(seq_len)]
        # Insert the S-keys first and 'text' last so the key order matches
        # what downstream code sees when it iterates over prompt keys.
        prompt_dict = {f"S{pos}": member for pos, member in enumerate(members, start=1)}
        prompt_dict["text"] = " ".join(members)
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list()

In [6]:
class Dataset:
    """Tokenised prompts plus the token position of every target word.

    Attributes:
        prompts: The list of prompt dicts passed in. Each has a ``'text'``
            entry plus one entry per target word (e.g. ``'S1'``).
        tokenizer: The tokenizer used for both the padded ids and the
            per-prompt token strings.
        N: Number of prompts.
        toks: Integer tensor of token ids returned by
            ``tokenizer(texts, padding=True)``.
        word_idx: Dict mapping each target key, and ``'end'``, to a tensor
            with one token index per prompt.
    """

    def __init__(self, prompts, tokenizer, S1_is_first=False):
        """Tokenise the prompts and locate every target word.

        Args:
            prompts: List of prompt dicts. The target keys are taken from the
                first prompt (every key except ``'text'``).
            tokenizer: Callable tokenizer that also provides ``tokenize``.
            S1_is_first: When True, the ``'S1'`` target is matched without
                the ``"Ġ"`` prefix. Use this only when the first token of the
                prompt has no leading space.

        Raises:
            ValueError: If a target token does not occur in its prompt.
        """
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly rather than round-tripping the
        # ids through a float tensor.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # Tokenise every prompt once with the injected tokenizer (not a
        # notebook-global model) and reuse the result for all targets.
        tokenized = [self.tokenizer.tokenize(text) for text in texts]

        # word_idx: for every prompt, the token index of each target token
        # and of "end". Each value is a tensor with one element per prompt.
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if key != 'text']:
            targ_lst = []
            for prompt, tokens in zip(self.prompts, tokenized):
                if S1_is_first and targ == "S1":
                    # First token of the prompt: no leading-space marker.
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                try:
                    # index() returns the first occurrence of the token.
                    target_index = tokens.index(target_token)
                except ValueError:
                    raise ValueError(
                        f"Target {targ}={prompt[targ]!r} (token {target_token!r}) "
                        f"not found in tokenised prompt {prompt['text']!r}: {tokens}"
                    ) from None
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        # "end" is the last token of each unpadded prompt.
        # NOTE(review): these indices only line up with self.toks if padding
        # is appended on the right -- confirm the tokenizer's padding side.
        self.word_idx["end"] = torch.tensor([len(tokens) - 1 for tokens in tokenized])

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [7]:
# Every prompt text starts with a bare number (no leading space), so S1 has
# to be matched without the "Ġ" prefix.
dataset = Dataset(
    prompts=prompts_list,
    tokenizer=model.tokenizer,
    S1_is_first=True,
)

# Next score

In [15]:
def get_next_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True, k=5):
    """Score how often one head's value/output path promotes the *next* number.

    The layer-0 residual output is normalised with ``model.blocks[1].attn.ln1``,
    pushed through the chosen head's ``W_V``/``b_V`` and ``W_O``, then through
    ``model.ln_final`` and ``model.unembed``. For each prompt, the top-``k``
    tokens at the position of the last target word are decoded, and the prompt
    counts as a hit when ``int(word) + 1`` (with or without a leading space)
    is among them.

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer``.
        layer: Index of the block whose head is analysed.
        head: Index of the head within that block.
        dataset: Object exposing ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing callers keep working.
        neg: When True, the head output is negated before unembedding.
        print_tokens: When truthy, return the per-prompt prediction dict;
            otherwise return the list of words that were hits.
        k: How many top tokens to inspect per position. Defaults to 5.

    Returns:
        If ``print_tokens`` is truthy: a dict mapping the last target word of
        each prompt to ``(pred_tokens, next_word, 'yes' | 'no')``.
        Otherwise: the list of target words whose successor was in the top k.

    Side effects:
        Prints the top-k accuracy when it is greater than zero.
    """
    sign = -1 if neg else 1

    # Pure inference: no autograd graph is needed for any of this.
    with torch.no_grad():
        cache = {}
        # NOTE(review): cache_some presumably registers a hook on the model;
        # confirm whether hooks accumulate across repeated calls.
        model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
        # NOTE(review): this forward pass does not depend on layer/head, so a
        # sweep over all heads recomputes it every time.
        model(dataset.toks.long())
        z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

        attn = model.blocks[layer].attn
        v = torch.einsum("eab,bc->eac", z_0, attn.W_V[head])
        v = v + attn.b_V[head].unsqueeze(0).unsqueeze(0)

        o = sign * torch.einsum("sph,hd->spd", v, attn.W_O[head])
        logits = model.unembed(model.ln_final(o))

    n_right = 0
    pred_tokens_dict = {}
    words_moved = []
    # Target keys come from the first prompt; only the last one is scored.
    words = [key for key in dataset.prompts[0].keys() if key != 'text']
    word = words[-1]

    for seq_idx, prompt in enumerate(dataset.prompts):
        position = dataset.word_idx[word][seq_idx]
        top_ids = torch.topk(logits[seq_idx, position], k).indices
        pred_tokens = [model.tokenizer.decode(token) for token in top_ids]

        # The successor of the number stored under `word`.
        next_word = str(int(prompt[word]) + 1)

        # Accept the token with or without a leading space.
        hit = " " + next_word in pred_tokens or next_word in pred_tokens
        if hit:
            n_right += 1
            words_moved.append(prompt[word])
        pred_tokens_dict[prompt[word]] = (pred_tokens, next_word, 'yes' if hit else 'no')

    percent_right = (n_right / (dataset.N)) * 100
    if percent_right > 0:
        print(f"Next circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        return pred_tokens_dict
    return words_moved

In [16]:
# Sweep every (layer, head) pair: 24 layers x 16 heads. These counts are
# hard-coded and must agree with the loaded model.
all_heads = [(layer, head) for layer in range(24) for head in range(16)]
for index in range(len(all_heads)):
    layer, head = all_heads[index]
    # Only the printed accuracy is of interest here (it is printed when > 0),
    # so the returned list of hit words is discarded.
    get_next_scores(model, layer, head, dataset, print_tokens=False)

Next circuit for head 5.8 (sign=1) : Top 5 accuracy: 2.0618556701030926%
Next circuit for head 6.1 (sign=1) : Top 5 accuracy: 14.432989690721648%
Next circuit for head 6.11 (sign=1) : Top 5 accuracy: 3.0927835051546393%
Next circuit for head 7.2 (sign=1) : Top 5 accuracy: 9.278350515463918%
Next circuit for head 9.4 (sign=1) : Top 5 accuracy: 13.402061855670103%
Next circuit for head 9.9 (sign=1) : Top 5 accuracy: 53.608247422680414%
Next circuit for head 10.1 (sign=1) : Top 5 accuracy: 14.432989690721648%
Next circuit for head 10.8 (sign=1) : Top 5 accuracy: 45.36082474226804%
Next circuit for head 10.9 (sign=1) : Top 5 accuracy: 2.0618556701030926%
Next circuit for head 11.1 (sign=1) : Top 5 accuracy: 93.81443298969072%
Next circuit for head 11.5 (sign=1) : Top 5 accuracy: 1.0309278350515463%
Next circuit for head 12.0 (sign=1) : Top 5 accuracy: 7.216494845360824%
Next circuit for head 12.1 (sign=1) : Top 5 accuracy: 5.154639175257731%
Next circuit for head 12.13 (sign=1) : Top 5 acc

# Compare next scores to copy scores

In [10]:
def get_copy_scores(model, layer, head, dataset, verbose=False, neg=False, print_tokens=True, k=5):
    """Score how often one head's value/output path promotes the *same* token.

    The layer-0 residual output is normalised with ``model.blocks[1].attn.ln1``,
    pushed through the chosen head's ``W_V``/``b_V`` and ``W_O``, then through
    ``model.ln_final`` and ``model.unembed``. For every prompt and every target
    word, the top-``k`` tokens at that word's position are decoded, and the
    pair counts as a hit when the word itself (with or without a leading
    space) is among them.

    Args:
        model: Model exposing ``cache_some``, ``blocks``, ``ln_final``,
            ``unembed`` and ``tokenizer``.
        layer: Index of the block whose head is analysed.
        head: Index of the head within that block.
        dataset: Object exposing ``toks``, ``prompts``, ``word_idx`` and ``N``.
        verbose: Unused; kept so existing callers keep working.
        neg: When True, the head output is negated before unembedding.
        print_tokens: When truthy, return the prediction dict; otherwise
            return the list of words that were hits.
        k: How many top tokens to inspect per position. Defaults to 5.

    Returns:
        If ``print_tokens`` is truthy: a dict mapping each target word string
        to ``(pred_tokens, 'yes' | 'no')``. The dict is keyed by the word
        string alone, so when the same string occurs at several
        (prompt, position) pairs only the last entry is kept.
        Otherwise: the list of target words that were in their own top k,
        in prompt-major order.

    Side effects:
        Prints the top-k accuracy when it is greater than zero.
    """
    sign = -1 if neg else 1

    # Pure inference: no autograd graph is needed for any of this.
    with torch.no_grad():
        cache = {}
        # NOTE(review): cache_some presumably registers a hook on the model;
        # confirm whether hooks accumulate across repeated calls.
        model.cache_some(cache, lambda x: x == "blocks.0.hook_resid_post")
        # NOTE(review): this forward pass does not depend on layer/head, so a
        # sweep over all heads recomputes it every time.
        model(dataset.toks.long())
        z_0 = model.blocks[1].attn.ln1(cache["blocks.0.hook_resid_post"])

        attn = model.blocks[layer].attn
        v = torch.einsum("eab,bc->eac", z_0, attn.W_V[head])
        v = v + attn.b_V[head].unsqueeze(0).unsqueeze(0)

        o = sign * torch.einsum("sph,hd->spd", v, attn.W_O[head])
        logits = model.unembed(model.ln_final(o))

    n_right = 0
    pred_tokens_dict = {}
    words_moved = []
    # Target keys come from the first prompt in the dataset.
    words = [key for key in dataset.prompts[0].keys() if key != 'text']

    for seq_idx, prompt in enumerate(dataset.prompts):
        for word in words:
            position = dataset.word_idx[word][seq_idx]
            top_ids = torch.topk(logits[seq_idx, position], k).indices
            pred_tokens = [model.tokenizer.decode(token) for token in top_ids]

            # Accept the token with or without a leading space.
            hit = " " + prompt[word] in pred_tokens or prompt[word] in pred_tokens
            if hit:
                n_right += 1
                words_moved.append(prompt[word])
            pred_tokens_dict[prompt[word]] = (pred_tokens, 'yes' if hit else 'no')

    # Accuracy is taken over every (prompt, target word) pair.
    percent_right = (n_right / (dataset.N * len(words))) * 100
    if percent_right > 0:
        print(f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%")

    if print_tokens:
        return pred_tokens_dict
    return words_moved

In [12]:
# Collect, for every (layer, head) pair, the words that head copies into its
# own top-k. 24 layers x 16 heads are hard-coded and must agree with the
# loaded model.
all_heads = [(layer, head) for layer in range(24) for head in range(16)]
all_copy_scores = {}
for index in range(len(all_heads)):
    layer, head = all_heads[index]
    cop_sco = get_copy_scores(model, layer, head, dataset, print_tokens=False)
    all_copy_scores[(layer, head)] = cop_sco

Copy circuit for head 5.8 (sign=1) : Top 5 accuracy: 1.0309278350515463%
Copy circuit for head 6.1 (sign=1) : Top 5 accuracy: 20.36082474226804%
Copy circuit for head 6.11 (sign=1) : Top 5 accuracy: 1.804123711340206%
Copy circuit for head 7.2 (sign=1) : Top 5 accuracy: 17.010309278350515%
Copy circuit for head 9.4 (sign=1) : Top 5 accuracy: 19.329896907216497%
Copy circuit for head 9.7 (sign=1) : Top 5 accuracy: 0.25773195876288657%
Copy circuit for head 9.9 (sign=1) : Top 5 accuracy: 92.26804123711341%
Copy circuit for head 10.1 (sign=1) : Top 5 accuracy: 26.54639175257732%
Copy circuit for head 10.8 (sign=1) : Top 5 accuracy: 85.30927835051546%
Copy circuit for head 10.9 (sign=1) : Top 5 accuracy: 1.2886597938144329%
Copy circuit for head 11.1 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 12.0 (sign=1) : Top 5 accuracy: 11.34020618556701%
Copy circuit for head 12.1 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 12.13 (sign=1) : Top 5 accuracy: 43.55670103092783%
C