# Setup

In [1]:
save_files = True

In [2]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pickle
from google.colab import files

import matplotlib.pyplot as plt
import statistics

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer #, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
import pdb

## Import functions from repo

In [7]:
!git clone https://github.com/apartresearch/seqcont_circuits.git
%cd /content/seqcont_circuits/src/iter_node_pruning

Cloning into 'seqcont_circuits'...
remote: Enumerating objects: 1022, done.[K
remote: Counting objects: 100% (488/488), done.[K
remote: Compressing objects: 100% (287/287), done.[K
remote: Total 1022 (delta 296), reused 378 (delta 190), pack-reused 534[K
Receiving objects: 100% (1022/1022), 18.76 MiB | 14.06 MiB/s, done.
Resolving deltas: 100% (659/659), done.
/content/seqcont_circuits/src/iter_node_pruning


In [8]:
## comment this out when debugging functions in colab to use funcs defined in colab

# don't improt this
# # from dataset import Dataset

from metrics import *
from head_ablation_fns import *
from mlp_ablation_fns import *
from node_ablation_fns import *
from loop_node_ablation_fns import *

## fns

In [9]:
import random


In [10]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a dict whose values are tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [11]:
def generate_prompts_list_longer(text, tokens):
    prompts_list = []
    prompt_dict = {
        'corr': str(1),
        'incorr': str(2),
        'text': text
        # 'text': model.to_string(tokens)[0]
        }
    tokens_as_strs = model.tokenizer.tokenize(text)
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i in range(tokens.shape[1]):
    for i, tok in enumerate(tokens_as_strs):
        prompt_dict['S'+str(i)] = tok
    # for i, tok in enumerate(tokens):
    #     prompt_dict['S'+str(i)] = model.to_string(tok)

    # prompt_dict = {
    #     'corr': '4',
    #     'incorr': '3',
    #     'text': model.to_string(tokens)[0]
    # }
    # # list_tokens = tokenizer.tokenize('1 2 3 ')
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i, tok_as_str in enumerate(tokens_as_strs):
    #     if tok_as_str == '▁':
    #         prompt_dict['S'+str(i)] = ' '
    #     else:
    #         prompt_dict['S'+str(i)] = tok_as_str
    prompts_list.append(prompt_dict)
    return prompts_list

# Load Model

In [12]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [13]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [14]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)
# tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH, use_fast= False, add_prefix_space= False)
hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [15]:
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

In [16]:
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    hf_model = hf_model,
    tokenizer = tokenizer,
    device = "cpu",
    fold_ln = False,
    center_writing_weights = False,
    center_unembed = False,
)

del hf_model

model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda


# new ablation functions

In [17]:
def get_heads_actv_mean(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Output: The mean activations of a head's output
    '''
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # for layer in range(model.cfg.n_layers):
    #     z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
    #     for template_group in means_dataset.groups:
    #         z_for_this_template = z_for_this_layer[template_group]
    #         z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
    #         if z_means_for_this_template.shape[0] == 5:
    #             pdb.set_trace()
    #         means[layer, template_group] = z_means_for_this_template

    del(means_cache)

    return means

In [18]:
# def mask_circ_heads(
#     means_dataset: Dataset,
#     model: HookedTransformer,
#     circuit: Dict[str, List[Tuple[int, int]]],
#     seq_pos_to_keep: Dict[str, str],
# ) -> Dict[int, Bool[Tensor, "batch seq head"]]:
#     '''
#     Output: for each layer, a mask of circuit components that should not be ablated
#     '''
#     heads_and_posns_to_keep = {}
#     batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

#     for layer in range(model.cfg.n_layers):

#         mask = t.zeros(size=(batch, seq, n_heads))

#         for (head_type, head_list) in circuit.items():
#             seq_pos = seq_pos_to_keep[head_type]
#             # if seq_pos == 'S7':
#             #     pdb.set_trace()
#             indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
#             for (layer_idx, head_idx) in head_list:
#                 if layer_idx == layer:
#                     # if indices.item() == 7:
#                     #     pdb.set_trace()
#                     mask[:, indices, head_idx] = 1
#                     # mask[:, :, head_idx] = 1  # keep L.H at all pos

#         heads_and_posns_to_keep[layer] = mask.bool()
#     # pdb.set_trace()
#     return heads_and_posns_to_keep

In [19]:
def mask_circ_heads(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Output: for each layer, a mask of circuit components that should not be ablated
    '''
    heads_and_posns_to_keep = {}
    # batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    batch, seq, n_heads = len(means_dataset), len(circuit.keys()), model.cfg.n_heads
    # print(seq)

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # mask[:, indices, head_idx] = 1
                    mask[:, :, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

In [20]:
def hook_func_mask_head(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    # components_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    # means: Float[Tensor, "layer batch seq head d_head"],
    circuit: Dict[str, List[Tuple[int, int]]],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Use this to not mask components
    '''
    # mask_for_this_layer = components_to_keep[hook.layer()].unsqueeze(-1).to(z.device)
    # z = t.where(mask_for_this_layer, z, means[hook.layer()])

    ###
    # heads_and_posns_to_keep = {}
    # batch, seq, n_heads = z.shape[0], z.shape[1], model.cfg.n_heads  # components_to_keep[0].shape[0] is batch

    # for layer in range(model.cfg.n_layers):

    #     mask = t.zeros(size=(batch, seq, n_heads))

    #     for (head_type, head_list) in circuit.items():
    #         # seq_pos = seq_pos_to_keep[head_type]
    #         # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
    #         for (layer_idx, head_idx) in head_list:
    #             if layer_idx == layer:
    #                 # mask[:, indices, head_idx] = 1
    #                 mask[:, :, head_idx] = 1

    #     heads_and_posns_to_keep[layer] = mask.bool()
    ###
    mask_for_this_layer = t.zeros(size=(z.shape[0], z.shape[1], z.shape[2]))
    for (head_type, head_list) in circuit.items():
        # seq_pos = seq_pos_to_keep[head_type]
        # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
        for (layer_idx, head_idx) in head_list:
            if layer_idx == hook.layer():
                # mask[:, indices, head_idx] = 1
                mask_for_this_layer[:, :, head_idx] = 1

    mask_for_this_layer = mask_for_this_layer.bool()
    mask_for_this_layer = mask_for_this_layer.unsqueeze(-1).to(z.device)  # d_model is 1; then is broadcast in where

    z = t.where(mask_for_this_layer, z, 0)

    return z

In [21]:
def add_ablation_hook_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Ablate the model, except as components and positions to keep
    '''

    model.reset_hooks(including_permanent=True)
    means = get_heads_actv_mean(means_dataset, model)
    components_to_keep = mask_circ_heads(means_dataset, model, circuit, seq_pos_to_keep)

    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=circuit,
    )

    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)
    return model

In [22]:
# from dataset import Dataset
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
import einops
from functools import partial
import torch as t
from torch import Tensor
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool

# from head_ablation_fns import *
# from mlp_ablation_fns import *

def add_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    for i in range(num_pos ):
        CIRCUIT['S'+str(i)] = heads_lst
        # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
        # if i == num_pos - 1:
        #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        # else:
        SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = get_heads_actv_mean(means_dataset, model)
    # Convert this into a boolean map
    components_to_keep = mask_circ_heads(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=CIRCUIT,
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    # if all_entries_true(components_to_keep) == False:
    #     pdb.set_trace()
    ########################
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}
    # # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    # num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    # for i in range(num_pos ):
    #     CIRCUIT['S'+str(i)] = mlp_lst
    #     # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
    #     # if i == num_pos - 1:
    #     #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
    #     # else:
    #     SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    # # Compute the mean of each head's output on the ABC dataset, grouped by template
    # means = get_MLPs_actv_mean(means_dataset, model)

    # # Convert this into a boolean map
    # components_to_keep = mask_circ_MLPs(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # # Get a hook function which will patch in the mean z values for each head, at
    # # all positions which aren't important for the circuit
    # hook_fn = partial(
    #     hook_func_mask_mlp_out,
    #     components_to_keep=components_to_keep,
    #     means=means
    # )

    # model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

In [23]:
def all_entries_true(tensor_dict):
    for key, tensor in tensor_dict.items():
        if not torch.all(tensor).item():
            return False
    return True

# ablation fns mult tok answers

In [24]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(5):
        # print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        # print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            # print('\nTotal logit diff: ', total_score.item())
            break

        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        # if next_char == '':
        #     next_char = ' '
        # clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
    return corr_ans_tokLen

In [25]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    for i in range(corr_ans_tokLen):
        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        if i == corr_ans_tokLen - 1:
            print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())

In [26]:
# Function to randomly choose 50 pairs ensuring less than 10 overlap with heads_of_circ
def choose_heads_to_remove(filtered_pairs, heads_of_circ, num_pairs=50, max_overlap=10):
    while True:
        head_to_remove = random.sample(filtered_pairs, num_pairs)
        overlap_count = len([head for head in head_to_remove if head in heads_of_circ])
        if overlap_count < max_overlap:
            return head_to_remove

In [27]:
def ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen):  # correct_ans
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    # logits = model(tokens)
    # next_token = logits[0, -1].argmax(dim=-1)
    # next_char = model.to_string(next_token)

    total_score = 0
    ans_so_far = ''
    for i in range(correct_ans_tokLen):
        # if next_char == '':
        #     next_char = ' '

        # clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
        #     print(model.to_string(tokens))
        #     # print(f"Sequence so far: {clean_text}")
        #     # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        # tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
            # print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        ans_so_far += next_char
        correct_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
    # if ans_so_far == corr_ans:
        # print('\nTotal logit diff: ', total_score.item())
    return ans_so_far

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())

# Define circs

In [28]:
# from Llama2_numerals_1to10.ipynb
nums_1to9 = [(0, 2), (0, 5), (0, 6), (0, 15), (1, 15), (1, 28), (2, 13), (2, 24), (3, 24), (4, 3), (4, 16), (5, 11), (5, 13), (5, 15), (5, 16), (5, 23), (5, 25), (5, 27), (6, 11), (6, 14), (6, 20), (6, 23), (6, 24), (6, 26), (6, 28), (6, 30), (6, 31), (7, 0), (7, 13), (7, 21), (7, 30), (8, 0), (8, 2), (8, 12), (8, 15), (8, 26), (8, 27), (8, 30), (8, 31), (9, 15), (9, 16), (9, 23), (9, 26), (9, 27), (9, 29), (9, 31), (10, 1), (10, 13), (10, 18), (10, 23), (10, 29), (11, 7), (11, 8), (11, 9), (11, 17), (11, 18), (11, 25), (11, 28), (12, 18), (12, 19), (12, 23), (12, 27), (13, 6), (13, 11), (13, 20), (14, 18), (14, 19), (14, 20), (14, 21), (16, 0), (18, 19), (18, 21), (18, 25), (18, 26), (18, 31), (19, 28), (20, 17), (21, 0), (21, 2), (22, 18), (22, 20), (22, 25), (23, 27), (26, 2)]
len(nums_1to9)

84

In [29]:
# nw_circ = [(0, 1), (0, 4), (0, 6), (0, 7), (0, 8), (0, 10), (0, 11), (0, 12), (1, 16), (1, 24), (1, 27), (1, 28), (2, 2), (2, 5), (2, 8), (2, 24), (2, 30), (3, 7), (3, 14), (3, 19), (3, 23), (4, 3), (5, 16), (5, 25), (6, 11), (6, 14), (7, 0), (7, 30), (8, 0), (8, 2), (8, 3), (8, 4), (8, 6), (8, 21), (8, 31), (9, 1), (9, 3), (9, 7), (9, 11), (9, 29), (9, 31), (10, 13), (10, 18), (10, 23), (10, 24), (10, 25), (10, 27), (11, 18), (11, 28), (12, 18), (12, 26), (13, 11), (13, 17), (13, 18), (13, 19), (13, 20), (13, 21), (13, 23), (14, 7), (14, 14), (15, 25), (15, 28), (16, 0), (16, 12), (16, 14), (16, 15), (16, 16), (16, 19), (16, 24), (16, 29), (17, 17), (17, 23), (17, 31), (18, 31), (19, 12), (20, 17), (27, 20), (27, 25), (27, 27), (27, 31), (28, 5), (29, 5)]
# in order from most impt to least based on how much changes perf when ablated
nw_circ = [(20, 17), (5, 25), (16, 0), (29, 5), (3, 19), (6, 11), (15, 25), (8, 0), (16, 24), (8, 4), (7, 0), (6, 14), (16, 29), (5, 16), (12, 26), (4, 3), (3, 7), (7, 30), (11, 28), (28, 5), (17, 31), (13, 11), (13, 20), (12, 18), (1, 27), (10, 13), (18, 31), (8, 6), (9, 1), (0, 4), (2, 2), (9, 11), (19, 12), (1, 16), (13, 17), (9, 7), (11, 18), (2, 24), (10, 18), (9, 31), (9, 29), (2, 30), (2, 5), (1, 24), (2, 8), (15, 28), (27, 31), (16, 14), (3, 23), (3, 14), (10, 23), (27, 20), (8, 3), (14, 7), (14, 14), (16, 15), (8, 2), (17, 17), (0, 1), (10, 27), (16, 19), (0, 8), (0, 12), (1, 28), (0, 11), (17, 23), (0, 10), (0, 6), (13, 19), (8, 31), (10, 24), (16, 12), (13, 23), (13, 21), (27, 27), (9, 3), (27, 25), (16, 16), (8, 21), (0, 7), (13, 18), (10, 25)]
len(nw_circ)

82

In [30]:
# impt_months_heads = ([(23, 17), (17, 11), (16, 0), (26, 14), (18, 9), (5, 25), (22, 20), (6, 24), (26, 9), (12, 18), (13, 20), (19, 12), (27, 29), (13, 14), (16, 14), (12, 26), (19, 30), (16, 18), (31, 27), (26, 28), (16, 1), (18, 1), (19, 28), (18, 31), (29, 4), (17, 0), (14, 1), (17, 12), (12, 15), (28, 16), (10, 1), (16, 19), (9, 27), (30, 1), (19, 27), (0, 3), (15, 11), (21, 3), (11, 19), (12, 0), (23, 11), (8, 14), (16, 8), (22, 13), (13, 3), (4, 19), (14, 15), (12, 20), (19, 16), (18, 5)])
months_circ = [(20, 17), (6, 11), (16, 0), (5, 15), (17, 11), (23, 16), (5, 25), (7, 0), (26, 14), (6, 14), (12, 22), (8, 4), (12, 15), (16, 29), (15, 25), (5, 16), (18, 31), (14, 7), (11, 18), (4, 12), (3, 19), (12, 2), (11, 28), (4, 3), (18, 9), (8, 14), (12, 3), (11, 2), (10, 13), (4, 16), (1, 22), (11, 16), (3, 15), (13, 31), (2, 4), (2, 16), (8, 13), (0, 13), (8, 15), (12, 28), (1, 5), (0, 4), (0, 25), (3, 24), (13, 11), (1, 24), (8, 16), (13, 8), (3, 26), (0, 6), (3, 23), (1, 3), (14, 3), (8, 19), (8, 12), (14, 2), (8, 5), (1, 28), (8, 20), (2, 30), (8, 6), (10, 1), (13, 20), (19, 27)]
len(months_circ)

64

In [31]:
intersect_all = list(set(nums_1to9) & set(nw_circ) & set(months_circ))
len(intersect_all)

16

In [32]:
union_all = list(set(nums_1to9) | set(nw_circ) | set(months_circ))
len(union_all)

172

# auto measure fns

In [90]:
def ablate_circ_autoScore(model, circuit, sequences_as_str, next_members):
    corr_text = "5 3 9"
    list_outputs = []
    score = 0
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)

        heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
        head_to_remove = circuit
        heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

        mlps_not_ablate = [layer for layer in range(32)]

        output_after_ablate = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen)
        list_outputs.append(output_after_ablate)
        print(correct_ans, output_after_ablate)
        if correct_ans == output_after_ablate:
            score += 1
    perc_score = score / len(next_members)
    return perc_score, list_outputs

In [91]:
def ablate_randcirc_autoScore(model, sequences_as_str, next_members, num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap):
    corr_text = "5 3 9"
    list_outputs = []
    all_scores = []
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        prompt_score = 0
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)
        for j in range(num_rand_runs):
            all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
            filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_not_overlap] # Filter out heads_not_overlap from all_possible_pairs

            # Randomly choose num_heads_rand pairs ensuring less than num_not_overlap overlaps with heads_not_overlap
            head_to_remove = choose_heads_to_remove(filtered_pairs, heads_not_overlap, num_heads_rand, num_not_overlap)

            heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]

            mlps_not_ablate = [layer for layer in range(32)]

            output_after_ablate = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen)
            # list_outputs.append(output_after_ablate)
            # print(correct_ans, output_after_ablate)
            if correct_ans == output_after_ablate:
                prompt_score += 1
        print(prompt_score / num_rand_runs)
        all_scores.append(prompt_score / num_rand_runs)

    perc_score = sum(all_scores) / len(next_members)
    return perc_score, list_outputs

In [92]:
# import random

# def gen_single_addition_prompts(num_prompts):
#     sequences = []
#     next_members = []

#     single_digit_additions = [(random.randint(0, 9), random.randint(0, 9)) for _ in range(num_prompts)]
#     # double_digit_additions = [(random.randint(10, 99), random.randint(10, 99)) for _ in range(num_prompts)]

#     for a, b in single_digit_additions:
#         prompt = f"{a} + {b} = "
#         answer = str(a + b)
#         sequences.append(prompt)
#         next_members.append(answer)

#     # for a, b in double_digit_additions:
#     #     prompt = f"{a} + {b} = "
#     #     answer = str(a + b)
#     #     sequences.append(prompt)
#     #     next_members.append(answer)

#     print("Sequences:")
#     print(sequences)
#     print("\nNext Members:")
#     print(next_members)
#     return sequences, next_members

import random

def gen_single_addition_prompts(num_prompts):
    sequences = []
    next_members = []
    seen_pairs = set()

    while len(sequences) < num_prompts:
        a = random.randint(0, 9)
        b = random.randint(0, 9)
        if (a, b) not in seen_pairs:
            seen_pairs.add((a, b))
            prompt = f"{a} + {b} = "
            answer = str(a + b)
            sequences.append(prompt)
            next_members.append(answer)

    print("Sequences:")
    print(sequences)
    print("\nNext Members:")
    print(next_members)
    return sequences, next_members

In [93]:
# import random

# def gen_double_addition_prompts(num_prompts):
#     sequences = []
#     next_members = []

#     # single_digit_additions = [(random.randint(0, 9), random.randint(0, 9)) for _ in range(num_prompts)]
#     double_digit_additions = [(random.randint(10, 99), random.randint(10, 99)) for _ in range(num_prompts)]

#     # for a, b in single_digit_additions:
#     #     prompt = f"{a} + {b} = "
#     #     answer = str(a + b)
#     #     sequences.append(prompt)
#     #     next_members.append(answer)

#     for a, b in double_digit_additions:
#         prompt = f"{a} + {b} = "
#         answer = str(a + b)
#         sequences.append(prompt)
#         next_members.append(answer)

#     print("Sequences:")
#     print(sequences)
#     print("\nNext Members:")
#     print(next_members)
#     return sequences, next_members

import random

def gen_double_addition_prompts(num_prompts):
    sequences = []
    next_members = []
    seen_pairs = set()

    while len(sequences) < num_prompts:
        a = random.randint(10, 99)
        b = random.randint(10, 99)
        if (a, b) not in seen_pairs:
            seen_pairs.add((a, b))
            prompt = f"{a} + {b} = "
            answer = str(a + b)
            sequences.append(prompt)
            next_members.append(answer)

    print("Sequences:")
    print(sequences)
    print("\nNext Members:")
    print(next_members)
    return sequences, next_members

In [94]:
# import random

# def gen_s_subtraction_prompts(num_prompts):
#     sequences = []
#     next_members = []

#     single_digit_subtractions = [(random.randint(0, 9), random.randint(0, 9)) for _ in range(num_prompts)]
#     # double_digit_subtractions = [(random.randint(10, 99), random.randint(10, 99)) for _ in range(num_prompts)]

#     for a, b in single_digit_subtractions:
#         # Ensure a is greater than or equal to b to avoid negative results
#         a, b = max(a, b), min(a, b)
#         prompt = f"{a} - {b} = "
#         answer = str(a - b)
#         sequences.append(prompt)
#         next_members.append(answer)

#     # for a, b in double_digit_subtractions:
#     #     # Ensure a is greater than or equal to b to avoid negative results
#     #     a, b = max(a, b), min(a, b)
#     #     prompt = f"{a} - {b} = "
#     #     answer = str(a - b)
#     #     sequences.append(prompt)
#     #     next_members.append(answer)

#     print("Sequences:")
#     print(sequences)
#     print("\nNext Members:")
#     print(next_members)
#     return sequences, next_members

import random

def gen_s_subtraction_prompts(num_prompts):
    sequences = []
    next_members = []
    seen_pairs = set()

    while len(sequences) < num_prompts:
        a = random.randint(0, 9)
        b = random.randint(0, 9)
        a, b = max(a, b), min(a, b)  # Ensure a is greater than or equal to b to avoid negative results
        if (a, b) not in seen_pairs:
            seen_pairs.add((a, b))
            prompt = f"{a} - {b} = "
            answer = str(a - b)
            sequences.append(prompt)
            next_members.append(answer)

    print("Sequences:")
    print(sequences)
    print("\nNext Members:")
    print(next_members)
    return sequences, next_members

In [95]:
# import random

# def gen_d_subtraction_prompts(num_prompts):
#     sequences = []
#     next_members = []

#     # single_digit_subtractions = [(random.randint(0, 9), random.randint(0, 9)) for _ in range(num_prompts)]
#     double_digit_subtractions = [(random.randint(10, 99), random.randint(10, 99)) for _ in range(num_prompts)]

#     # for a, b in single_digit_subtractions:
#     #     # Ensure a is greater than or equal to b to avoid negative results
#     #     a, b = max(a, b), min(a, b)
#     #     prompt = f"{a} - {b} = "
#     #     answer = str(a - b)
#     #     sequences.append(prompt)
#     #     next_members.append(answer)

#     for a, b in double_digit_subtractions:
#         # Ensure a is greater than or equal to b to avoid negative results
#         a, b = max(a, b), min(a, b)
#         prompt = f"{a} - {b} = "
#         answer = str(a - b)
#         sequences.append(prompt)
#         next_members.append(answer)

#     print("Sequences:")
#     print(sequences)
#     print("\nNext Members:")
#     print(next_members)
#     return sequences, next_members

import random

def gen_d_subtraction_prompts(num_prompts):
    sequences = []
    next_members = []
    seen_pairs = set()

    while len(sequences) < num_prompts:
        a = random.randint(10, 99)
        b = random.randint(10, 99)
        a, b = max(a, b), min(a, b)  # Ensure a is greater than or equal to b to avoid negative results
        if (a, b) not in seen_pairs:
            seen_pairs.add((a, b))
            prompt = f"{a} - {b} = "
            answer = str(a - b)
            sequences.append(prompt)
            next_members.append(answer)

    print("Sequences:")
    print(sequences)
    print("\nNext Members:")
    print(next_members)
    return sequences, next_members


# addition

In [38]:
num_prompts = 50
sequences_as_str, next_members = gen_addition_prompts(num_prompts)

Sequences:
['1 + 1 = ', '9 + 9 = ', '7 + 7 = ', '1 + 9 = ', '4 + 1 = ', '6 + 0 = ', '4 + 0 = ', '0 + 4 = ', '3 + 3 = ', '7 + 7 = ', '1 + 9 = ', '8 + 3 = ', '3 + 2 = ', '5 + 1 = ', '2 + 4 = ', '2 + 5 = ', '1 + 6 = ', '8 + 2 = ', '8 + 0 = ', '9 + 9 = ', '9 + 8 = ', '8 + 9 = ', '2 + 9 = ', '7 + 1 = ', '2 + 9 = ', '5 + 1 = ', '5 + 5 = ', '6 + 0 = ', '5 + 5 = ', '5 + 0 = ', '5 + 4 = ', '9 + 4 = ', '4 + 8 = ', '8 + 2 = ', '0 + 4 = ', '8 + 7 = ', '3 + 7 = ', '7 + 5 = ', '8 + 0 = ', '3 + 8 = ', '1 + 9 = ', '7 + 2 = ', '4 + 9 = ', '6 + 4 = ', '5 + 6 = ', '0 + 9 = ', '5 + 6 = ', '7 + 2 = ', '5 + 8 = ', '2 + 4 = ', '78 + 98 = ', '71 + 87 = ', '71 + 59 = ', '15 + 82 = ', '74 + 16 = ', '28 + 78 = ', '21 + 13 = ', '20 + 66 = ', '23 + 36 = ', '49 + 50 = ', '48 + 71 = ', '33 + 84 = ', '76 + 15 = ', '46 + 59 = ', '25 + 41 = ', '71 + 29 = ', '50 + 92 = ', '26 + 18 = ', '94 + 99 = ', '96 + 25 = ', '13 + 48 = ', '52 + 10 = ', '85 + 61 = ', '75 + 50 = ', '53 + 42 = ', '78 + 53 = ', '82 + 51 = ', '89 + 99 =

In [39]:
# all_heads = [(layer, head) for layer in range(32) for head in range(32)]
# input the circuit to ablate, not what to keep
# perc_score, list_outputs = ablate_circ_autoScore(model, [], sequences_as_str, next_members)
# perc_score

In [40]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

2 2
18 18
14 14
10 10
5 5
6 6
4 4
4 4
6 6
14 14
10 10
11 11
5 5
6 6
6 6
7 7
7 7
10 10
8 8<0x0A><0x0A>8+
18 18
17 17
17 17
11 11
8 9<0x0A><0x0A>Yourturn
11 11
6 6
10 10
6 6
10 10
5 5<0x0A><0x0A>5+
9 9
13 13
12 12
10 10
4 4
15 15
10 10
12 12
8 8<0x0A><0x0A>8+
11 11
10 10
9 9
13 13
10 10
11 11
9 9
11 11
9 9
13 13
6 6


KeyboardInterrupt: 

In [41]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

2 1
18 18
14 14
10 12
5 5
6 6
4 4
4 4
6 3
14 14
10 12
11 12
5 3
6 6
6 1
7 1
7 1
10 10
8 8<0x0A><0x0A><0x0A>8
18 18
17 18
17 16
11 12
8 14<0x0A><0x0A><0x0A>
11 12
6 6
10 10
6 6
10 10
5 5<0x0A><0x0A><0x0A>5
9 1
13 10
12 12
10 10
4 4
15 15
10 33
12 12
8 8<0x0A><0x0A><0x0A>8
11 33
10 12
9 1
13 18
10 10
11 11
9 0
11 11
9 1
13 10
6 1
176 88<0x0A>
158 711
130 71<0x0A>
97 11
90 90
106 30<0x0A>
34 14
86 20
59 36
99 49
119 491
117 333
91 76
105 55<0x0A>
66 25
100 711
142 500
44 42
193 94<0x0A>
121 97<0x0A><0x0A> 
61 11
62 62
146 85<0x0A>
125 755
95 53
131 788
133 82<0x0A>
188 99<0x0A>
145 75<0x0A>
96 40
108 60<0x0A>
125 588
151 56<0x0A>
140 140
105 335
139 70<0x0A>
124 1000<0x0A>
77 11
91 71
152 600
83 42
98 71
54 54
129 333
135 733
56 26
138 449
77 61
57 33
162 100<0x0A><0x0A>


0.26

In [42]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

2 2
18 18
14 14
10 1+
5 5
6 6
4 4
4 4
6 6
14 14
10 1+
11 10
5 5
6 6
6 1
7 1
7 7
10 10
8 8<0x0A><0x0A>Hinweis:
18 18
17 18
17 10
11 10
8 8<0x0A><0x0A><0x0A>So
11 10
6 6
10 10
6 6
10 10
5 5<0x0A><0x0A><0x0A>Answer
9 1
13 14
12 18
10 10
4 4
15 15
10 10
12 12
8 8<0x0A><0x0A>Hinweis:
11 10
10 1+
9 1
13 10
10 10
11 11
9 9
11 11
9 1
13 10
6 1
176 80<0x0A>
158 50<0x0A>
130 71+
97 12
90 <0x0A><0x0A>
106 156
34 5<0x0A>
86 12
59 10
99 9<0x0A>
119 <0x0A><0x0A> 
117 110
91 12
105 <0x0A><0x0A> 
66 10
100 79<0x0A>
142 50+
44 10
193 94+
121 10<0x0A><0x0A> 
61 10
62 52
146 85+
125 120
95 10
131 115
133 85<0x0A>
188 170
145 70+
96 40
108 8<0x0A><0x0A>
125 125
151 10<0x0A>
140 84+
105 10<0x0A>
139 10<0x0A>
124 10<0x0A><0x0A> 
77 7<0x0A>
91 71
152 60+
83 14
98 19
54 17
129 <0x0A><0x0A><0x0A>
135 10<0x0A>
56 6<0x0A>
138 164
77 63
57 7<0x0A>
162 10<0x0A><0x0A> 


0.28

In [43]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

2 1
18 18
14 49
10 10
5 1
6 6
4 0
4 0
6 3
14 49
10 10
11 10
5 1
6 1
6 2
7 1
7 1
10 10
8 8<0x0A><0x0A><0x0A><0x0A>
18 18
17 10
17 10
11 10
8 10<0x0A><0x0A><0x0A>
11 10
6 1
10 10
6 6
10 10
5 5<0x0A><0x0A><0x0A><0x0A>
9 1
13 10
12 16
10 10
4 0
15 10
10 3+
12 15
8 8<0x0A><0x0A><0x0A><0x0A>
11 3+
10 10
9 1
13 10
10 16
11 10
9 0
11 10
9 1
13 50
6 2
176 70<0x0A>
158 71+
130 71+


KeyboardInterrupt: 

In [None]:
num_rand_runs = 10
heads_not_overlap = intersect_all
num_heads_rand = 100
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str, next_members,
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

# single digit addition

In [96]:
num_prompts = 50
sequences_as_str, next_members = gen_single_addition_prompts(num_prompts)

Sequences:
['2 + 7 = ', '0 + 5 = ', '2 + 4 = ', '6 + 8 = ', '0 + 2 = ', '3 + 7 = ', '6 + 1 = ', '7 + 9 = ', '8 + 6 = ', '5 + 6 = ', '9 + 0 = ', '2 + 5 = ', '4 + 0 = ', '0 + 9 = ', '6 + 4 = ', '0 + 1 = ', '0 + 3 = ', '1 + 8 = ', '1 + 5 = ', '1 + 6 = ', '7 + 5 = ', '3 + 6 = ', '3 + 0 = ', '9 + 5 = ', '5 + 3 = ', '7 + 0 = ', '9 + 8 = ', '5 + 4 = ', '1 + 9 = ', '1 + 3 = ', '0 + 6 = ', '4 + 7 = ', '1 + 1 = ', '3 + 3 = ', '1 + 4 = ', '2 + 0 = ', '0 + 0 = ', '4 + 1 = ', '4 + 2 = ', '7 + 7 = ', '7 + 8 = ', '4 + 8 = ', '5 + 9 = ', '2 + 2 = ', '8 + 7 = ', '8 + 5 = ', '0 + 4 = ', '1 + 7 = ', '8 + 0 = ', '9 + 4 = ']

Next Members:
['9', '5', '6', '14', '2', '10', '7', '16', '14', '11', '9', '7', '4', '9', '10', '1', '3', '9', '6', '7', '12', '9', '3', '14', '8', '7', '17', '9', '10', '4', '6', '11', '2', '6', '5', '2', '0', '5', '6', '14', '15', '12', '14', '4', '15', '13', '4', '8', '8', '13']


In [97]:
# all_heads = [(layer, head) for layer in range(32) for head in range(32)]
# input the circuit to ablate, not what to keep
# perc_score, list_outputs = ablate_circ_autoScore(model, [], sequences_as_str, next_members)
# perc_score

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

9 9
5 5
6 6
14 14
2 2
10 10
7 7


In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

In [None]:
num_rand_runs = 10
heads_not_overlap = intersect_all
num_heads_rand = 100
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str, next_members,
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

# double digit addition

In [53]:
num_prompts = 50
sequences_as_str, next_members = gen_double_addition_prompts(num_prompts)

Sequences:
['95 + 45 = ', '12 + 23 = ', '68 + 82 = ', '36 + 79 = ', '87 + 75 = ', '64 + 53 = ', '26 + 38 = ', '35 + 56 = ', '13 + 55 = ', '35 + 32 = ', '91 + 60 = ', '62 + 17 = ', '38 + 16 = ', '25 + 30 = ', '15 + 13 = ', '47 + 98 = ', '72 + 58 = ', '40 + 72 = ', '51 + 91 = ', '47 + 49 = ', '57 + 71 = ', '81 + 67 = ', '30 + 37 = ', '55 + 51 = ', '51 + 31 = ', '77 + 86 = ', '28 + 91 = ', '75 + 33 = ', '94 + 55 = ', '36 + 44 = ', '15 + 13 = ', '15 + 18 = ', '90 + 14 = ', '66 + 85 = ', '86 + 30 = ', '33 + 70 = ', '68 + 56 = ', '62 + 95 = ', '72 + 30 = ', '32 + 65 = ', '68 + 29 = ', '56 + 94 = ', '11 + 49 = ', '56 + 15 = ', '56 + 21 = ', '55 + 94 = ', '93 + 27 = ', '49 + 59 = ', '25 + 93 = ', '11 + 80 = ']

Next Members:
['140', '35', '150', '115', '162', '117', '64', '91', '68', '67', '151', '79', '54', '55', '28', '145', '130', '112', '142', '96', '128', '148', '67', '106', '82', '163', '119', '108', '149', '80', '28', '33', '104', '151', '116', '103', '124', '157', '102', '97', '97', '1

In [54]:
# all_heads = [(layer, head) for layer in range(32) for head in range(32)]
# input the circuit to ablate, not what to keep
# perc_score, list_outputs = ablate_circ_autoScore(model, [], sequences_as_str, next_members)
# perc_score

In [55]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

140 <0x0A><0x0A>Pleaseenteryour
35 <0x0A><0x0A>
150 <0x0A><0x0A>What
115 <0x0A><0x0A><0x0A>
162 <0x0A><0x0A>What
117 <0x0A><0x0A><0x0A>
64 64
91 91
68 68
67 <0x0A><0x0A>
151 <0x0A><0x0A>Please
79 79
54 <0x0A><0x0A>Pleaseentera
55 55
28 <0x0A><0x0A>
145 <0x0A><0x0A>Please
130 120
112 <0x0A><0x0A>Please
142 91<0x0A>
96 <0x0A><0x0A>
128 <0x0A><0x0A>What
148 <0x0A><0x0A><0x0A>
67 37
106 <0x0A><0x0A>Please
82 82
163 <0x0A><0x0A>Please
119 119
108 <0x0A><0x0A>Please
149 <0x0A><0x0A>Please
80 <0x0A><0x0A>
28 <0x0A><0x0A>
33 <0x0A><0x0A>
104 104
151 <0x0A><0x0A><0x0A>
116 <0x0A><0x0A>Please
103 103
124 <0x0A><0x0A>Whatisthe
157 <0x0A><0x0A><0x0A>
102 <0x0A><0x0A>Please
97 97
97 97
150 <0x0A><0x0A><0x0A>
60 60
71 71
77 77
149 <0x0A><0x0A><0x0A>
120 <0x0A><0x0A>Pleaseprovidethe
108 <0x0A><0x0A>Please
118 118
91 91


0.32

In [56]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

140 95<0x0A><0x0A><0x0A>
35 13
150 698
115 336
162 88<0x0A>
117 64<0x0A>
64 26
91 33
68 11
67 33
151 91<0x0A>
79 62
54 338<0x0A><0x0A>
55 30
28 13
145 498
130 72<0x0A>
112 407
142 511
96 47
128 577
148 88<0x0A>
67 33
106 555
82 51
163 777
119 28<0x0A>
108 78<0x0A>
149 94<0x0A>
80 33
28 13
33 12
104 900
151 70<0x0A>
116 86<0x0A>
103 333
124 700<0x0A><0x0A>
157 62<0x0A>
102 72<0x0A>
97 33
97 69
150 566
60 11
71 56
77 56
149 555
120 93<0x0A><0x0A><0x0A>
108 59<0x0A>
118 27<0x0A>
91 11


0.0

In [57]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

140 10<0x0A><0x0A> 
35 10
150 10<0x0A>
115 36+
162 150
117 10<0x0A>
64 10
91 10
68 18
67 10
151 91<0x0A>
79 65
54 <0x0A><0x0A><0x0A>Answer:
55 5<0x0A>
28 8<0x0A>
145 <0x0A><0x0A> 
130 72+
112 40+
142 51+
96 <0x0A><0x0A>
128 120
148 81+
67 30
106 10<0x0A>
82 5<0x0A>
163 142
119 10<0x0A>
108 10<0x0A>
149 99<0x0A>
80 10
28 8<0x0A>
33 12
104 95<0x0A>
151 6+ 
116 86<0x0A>
103 30<0x0A>
124 <0x0A><0x0A> 10
157 62+
102 72<0x0A>
97 32
97 68
150 10<0x0A>
60 1+
71 11
77 10
149 10<0x0A>
120 <0x0A><0x0A> 93
108 249
118 10<0x0A>
91 8<0x0A>


0.0

In [58]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

140 95+ 4
35 23
150 16+
115 36<0x0A>
162 150
117 10<0x0A>
64 26
91 35
68 15
67 15
151 91+
79 17
54 38<0x0A><0x0A><0x0A>
55 25
28 15
145 47+
130 72+
112 40+
142 51+
96 49
128 57<0x0A>
148 81+
67 30
106 10<0x0A>
82 51
163 420
119 28+
108 15+
149 10<0x0A>
80 36
28 15
33 15
104 90+
151 6+ 
116 86+
103 30<0x0A>
124 120<0x0A><0x0A>
157 62+
102 72+
97 32
97 29
150 10<0x0A>
60 10
71 15
77 10
149 10<0x0A>
120 93+ 2
108 49+
118 25+
91 10


0.0

In [59]:
num_rand_runs = 10
heads_not_overlap = intersect_all
num_heads_rand = 100
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str, next_members,
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

0.0
1.0
0.5
0.6
0.4
0.7
0.7
0.7
0.5
0.7
0.7
0.9
0.0
0.9
0.8
1.0
0.7
0.7
0.8
0.8
0.6
0.7
0.4
0.4
0.9
0.5
0.9
0.7
0.7
0.5
0.8
0.9
0.6
0.5
0.6
0.7
0.0
0.8
0.7
1.0
1.0
0.5
0.9
0.8
0.8
0.8
0.0
0.8
0.8
0.8


0.6639999999999999

# s subtraction

In [76]:
num_prompts = 50
sequences_as_str, next_members = gen_s_subtraction_prompts(num_prompts)

Sequences:
['9 - 0 = ', '3 - 0 = ', '6 - 0 = ', '8 - 0 = ', '9 - 2 = ', '9 - 4 = ', '7 - 7 = ', '6 - 3 = ', '2 - 0 = ', '6 - 1 = ', '4 - 1 = ', '5 - 2 = ', '4 - 4 = ', '9 - 2 = ', '3 - 0 = ', '8 - 0 = ', '7 - 6 = ', '9 - 3 = ', '3 - 0 = ', '7 - 3 = ']

Next Members:
['9', '3', '6', '8', '7', '5', '0', '3', '2', '5', '3', '3', '0', '7', '3', '8', '1', '6', '3', '4']


In [77]:
# all_heads = [(layer, head) for layer in range(32) for head in range(32)]
# input the circuit to ablate, not what to keep
perc_score, list_outputs = ablate_circ_autoScore(model, [], sequences_as_str, next_members)
perc_score

9 9
3 3
6 6
8 8
7 7
5 5
0 0
3 3
2 2
5 5
3 3
3 3
0 0
7 7
3 3
8 8
1 1
6 6
3 3
4 4


1.0

In [78]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

9 9
3 3
6 6
8 8
7 9
5 9
0 0
3 6
2 2
5 5
3 4
3 3
0 0
7 9
3 3
8 8
1 </s>
6 6
3 3
4 4


0.7

In [79]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

9 9
3 3
6 6
8 8
7 4
5 1
0 4
3 2
2 2
5 6
3 4
3 2
0 4
7 4
3 3
8 8
1 4
6 3
3 3
4 2


0.4

In [80]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

9 9
3 3
6 6
8 8
7 6
5 1
0 0
3 3
2 2
5 6
3 4
3 3
0 2
7 6
3 3
8 8
1 1
6 6
3 3
4 4


0.7

In [81]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

9 9
3 0
6 0
8 8
7 1
5 1
0 7
3 1
2 0
5 1
3 1
3 1
0 1
7 1
3 0
8 8
1 1
6 1
3 0
4 1


0.2

In [82]:
num_rand_runs = 10
heads_not_overlap = intersect_all
num_heads_rand = 100
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str, next_members,
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


1.0

# d subtraction

In [83]:
num_prompts = 50
sequences_as_str, next_members = gen_d_subtraction_prompts(num_prompts)

Sequences:
['79 - 13 = ', '84 - 11 = ', '45 - 21 = ', '23 - 12 = ', '55 - 52 = ', '82 - 48 = ', '91 - 69 = ', '76 - 59 = ', '75 - 51 = ', '52 - 45 = ', '62 - 18 = ', '70 - 42 = ', '94 - 18 = ', '53 - 13 = ', '68 - 41 = ', '44 - 18 = ', '53 - 14 = ', '36 - 35 = ', '93 - 18 = ', '90 - 69 = ', '34 - 21 = ', '81 - 16 = ', '70 - 66 = ', '52 - 47 = ', '49 - 12 = ', '89 - 51 = ', '55 - 35 = ', '45 - 20 = ', '53 - 37 = ', '93 - 47 = ', '62 - 62 = ', '76 - 30 = ', '72 - 43 = ', '68 - 68 = ', '49 - 31 = ', '36 - 12 = ', '93 - 20 = ', '64 - 31 = ', '78 - 46 = ', '83 - 29 = ', '97 - 83 = ', '82 - 16 = ', '94 - 16 = ', '92 - 25 = ', '13 - 12 = ', '78 - 22 = ', '61 - 18 = ', '79 - 19 = ', '81 - 51 = ', '71 - 37 = ']

Next Members:
['66', '73', '24', '11', '3', '34', '22', '17', '24', '7', '44', '28', '76', '40', '27', '26', '39', '1', '75', '21', '13', '65', '4', '5', '37', '38', '20', '25', '16', '46', '0', '46', '29', '0', '18', '24', '73', '33', '32', '54', '14', '66', '78', '67', '1', '56', '43'

In [84]:
# all_heads = [(layer, head) for layer in range(32) for head in range(32)]
# input the circuit to ablate, not what to keep
perc_score, list_outputs = ablate_circ_autoScore(model, [], sequences_as_str, next_members)
perc_score

66 66
73 73
24 24
11 11
3 3
34 34
22 22
17 17
24 24
7 7
44 44
28 28
76 76
40 40
27 27
26 26
39 39
1 1
75 75
21 21
13 13
65 65
4 4
5 5
37 37
38 38
20 20
25 25
16 16
46 46
0 0
46 46
29 29
0 0
18 18
24 24
73 73
33 33
32 32
54 54
14 14
66 66
78 78
67 67
1 1
56 56
43 43
60 60
30 30
34 34


1.0

In [85]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

66 66
73 73
24 24
11 11
3 6
34 34
22 22
17 17
24 24
7 1
44 44
28 28
76 76
40 40
27 27
26 26
39 39
1 3
75 75
21 21
13 13
65 65
4 7
5 6
37 36
38 38
20 20
25 <0x0A><0x0A>
16 16
46 46
0 0
46 46
29 29
0 0
18 18
24 24
73 73
33 33
32 32
54 54
14 14
66 66
78 78
67 67
1 <0x0A>
56 56
43 42
60 60
30 30
34 34


0.82

In [86]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

66 6<0x0A>
73 73
24 22
11 17
3 5
34 17
22 91
17 76
24 14
7 1
44 62
28 17
76 73
40 42
27 16
26 44
39 53
1 3
75 93
21 13
13 33
65 81
4 1
5 1
37 49
38 17
20 17
25 22
16 15
46 21
0 5
46 76
29 17
0 6
18 14
24 33
73 93
33 64
32 78
54 26
14 10
66 82
78 78
67 36
1 1
56 38
43 61
60 79
30 81
34 71


0.06

In [87]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

66 6<0x0A>
73 84
24 3<0x0A>
11 3<0x0A>
3 2
34 72
22 53
17 7<0x0A>
24 2<0x0A>
7 1
44 56
28 65
76 32
40 2<0x0A>
27 12
26 4<0x0A>
39 20
1 2
75 25
21 8<0x0A>
13 3<0x0A>
65 73
4 7
5 3
37 5<0x0A>
38 14
20 20
25 5<0x0A>
16 20
46 21
0 3
46 7<0x0A>
29 18
0 8
18 14
24 3<0x0A>
73 93
33 1<0x0A>
32 12
54 53
14 10
66 76
78 94
67 92
1 2
56 5<0x0A>
43 53
60 6<0x0A>
30 79
34 21


0.02

In [88]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

66 7<0x0A>
73 10
24 45
11 23
3 2
34 24
22 69
17 39
24 35
7 1
44 18
28 70
76 18
40 13
27 41
26 44
39 14
1 3
75 18
21 90
13 14
65 16
4 7
5 2
37 49
38 31
20 15
25 45
16 37
46 33
0 6
46 30
29 23
0 6
18 49
24 12
73 10
33 10
32 46
54 83
14 73
66 16
78 16
67 25
1 1
56 14
43 18
60 7<0x0A>
30 81
34 37


0.02

In [89]:
num_rand_runs = 10
heads_not_overlap = intersect_all
num_heads_rand = 100
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str, next_members,
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


1.0