# Setup

## Change Inputs Here

In [1]:
model_name = "gpt2-small"
save_files = True

In [2]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pickle
from google.colab import files

import matplotlib.pyplot as plt
import statistics

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
import pdb

## Load Model

In [6]:
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/apartresearch/seqcont_circuits.git
%cd /content/seqcont_circuits/src/iter_node_pruning

Cloning into 'seqcont_circuits'...
remote: Enumerating objects: 1069, done.[K
remote: Counting objects: 100% (535/535), done.[K
remote: Compressing objects: 100% (312/312), done.[K
remote: Total 1069 (delta 336), reused 403 (delta 212), pack-reused 534 (from 1)[K
Receiving objects: 100% (1069/1069), 19.61 MiB | 19.52 MiB/s, done.
Resolving deltas: 100% (699/699), done.
/content/seqcont_circuits/src/iter_node_pruning


In [9]:
## comment this out when debugging functions in colab to use funcs defined in colab

# don't improt this
# # from dataset import Dataset

from metrics import *
from head_ablation_fns import *
from mlp_ablation_fns import *
from node_ablation_fns import *
from loop_node_ablation_fns import *

## fns

In [10]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [11]:
def generate_prompts_list_longer(text, tokens):
    prompts_list = []
    prompt_dict = {
        'corr': str(1),
        'incorr': str(2),
        'text': text}
    tokens_as_strs = model.tokenizer.tokenize(text)
    # for i in range(tokens.shape[1]):
    for i, tok in enumerate(tokens_as_strs):
        prompt_dict['S'+str(i)] = tok
    prompts_list.append(prompt_dict)
    return prompts_list

# Load datasets

In [None]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'corr': str(i+4),
            'incorr': str(i+3),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 2)
prompts_list

[{'S1': '1',
  'S2': '2',
  'S3': '3',
  'S4': '4',
  'corr': '5',
  'incorr': '4',
  'text': '1 2 3 4'}]

In [None]:
pos_dict = {}
for i in range(len(model.tokenizer.tokenize(prompts_list[0]['text']))):
    pos_dict['S'+str(i)] = i

In [None]:
dataset = Dataset(prompts_list, pos_dict, model.tokenizer)

In [None]:
import random

def generate_prompts_list_corr(prompt_list):
    outlist = []
    # for i in range(100):
    for prompt_dict in prompts_list:
        r1 = random.randint(1, 12)
        r2 = random.randint(1, 12)
        while True:
            r3 = random.randint(1, 12)
            r4 = random.randint(1, 12)
            if r4 - 1 != r3:
                break
        new_text = prompt_dict['text'].replace(prompt_dict['S1'], str(r1)).replace(prompt_dict['S2'], str(r2)).replace(prompt_dict['S3'], str(r3)).replace(prompt_dict['S4'], str(r4))
        new_prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': prompt_dict['corr'],
            'incorr': prompt_dict['incorr'],
            'text': new_text
        }
        outlist.append(new_prompt_dict)
    return outlist
prompts_list_2 = generate_prompts_list_corr(prompts_list)
len(prompts_list_2)

1

In [None]:
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)

## Get orig score

In [None]:
model.reset_hooks(including_permanent=True)
logits_original = model(dataset.toks)
orig_score = get_logit_diff(logits_original, dataset)
orig_score

tensor(6.0631, device='cuda:0')

In [None]:
import gc

del(logits_original)
torch.cuda.empty_cache()
gc.collect()

35

# logit diff for mult tok answers

In [19]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(5):
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        dataset = Dataset(prompts_list, pos_dict, model.tokenizer)

        # new_score = get_logit_diff(logits, dataset)

        # measure how far away predicted logit is from corr token?

        # corr_logits = logits[:, dataset.word_idx["end"], dataset.corr_tokenIDs]
        # incorr_logits = logits[:, dataset.word_idx["end"], dataset.incorr_tokenIDs]
        # new_score = corr_logits - incorr_logits

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            print('\nTotal logit diff: ', total_score.item())
            break
        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
    return corr_ans_tokLen

In [20]:
clean_text = "1 2 3 4"
corr_ans = ' 5'
corr_ans_tokLen = clean_gen(model, clean_text, corr_ans)

Sequence so far: '1 2 3 4'
logit diff of new char: tensor([16.8118], device='cuda:0')
5th char = ' 5'

Total logit diff:  16.811798095703125


In [21]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens
    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
    for i in range(corr_ans_tokLen):
    # for i in range(5):
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        clean_text = clean_text + next_char
        tokens = model.to_tokens(clean_text).to(device)
        tokens = tokens[:, 1:]
        print(clean_text)

        # get new ablation dataset
        model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        corr_text = corr_text + next_char
        corr_tokens = model.to_tokens(corr_text).to(device)
        prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)
        print(corr_text)

        pos_dict = {}
        for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
            pos_dict['S'+str(i)] = i

        dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)

        model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        print('\n')
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")

        new_score = get_logit_diff(logits, dataset)
        total_score += new_score
        print(f"corr logit of new char: {new_score}")
    print('\n Total corr logit: ', total_score.item())

In [22]:
clean_text = "1 2 3"
corr_text = "5 3 9"
heads_not_ablate = []  # ablate all heads but not MLPs
mlps_not_ablate = []  # ablate all MLPs
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3'
4th char = '.'
1 2 3.
5 3 9.


Sequence so far: '1 2 3.'
corr logit of new char: 0.8616600036621094

 Total corr logit:  0.8616600036621094


# new ablation functions

In [27]:
def get_heads_actv_mean(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Output: The mean activations of a head's output
    '''
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # for layer in range(model.cfg.n_layers):
    #     z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
    #     for template_group in means_dataset.groups:
    #         z_for_this_template = z_for_this_layer[template_group]
    #         z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
    #         if z_means_for_this_template.shape[0] == 5:
    #             pdb.set_trace()
    #         means[layer, template_group] = z_means_for_this_template

    del(means_cache)

    return means

In [28]:
# def mask_circ_heads(
#     means_dataset: Dataset,
#     model: HookedTransformer,
#     circuit: Dict[str, List[Tuple[int, int]]],
#     seq_pos_to_keep: Dict[str, str],
# ) -> Dict[int, Bool[Tensor, "batch seq head"]]:
#     '''
#     Output: for each layer, a mask of circuit components that should not be ablated
#     '''
#     heads_and_posns_to_keep = {}
#     batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

#     for layer in range(model.cfg.n_layers):

#         mask = t.zeros(size=(batch, seq, n_heads))

#         for (head_type, head_list) in circuit.items():
#             seq_pos = seq_pos_to_keep[head_type]
#             # if seq_pos == 'S7':
#             #     pdb.set_trace()
#             indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
#             for (layer_idx, head_idx) in head_list:
#                 if layer_idx == layer:
#                     # if indices.item() == 7:
#                     #     pdb.set_trace()
#                     mask[:, indices, head_idx] = 1
#                     # mask[:, :, head_idx] = 1  # keep L.H at all pos

#         heads_and_posns_to_keep[layer] = mask.bool()
#     # pdb.set_trace()
#     return heads_and_posns_to_keep

In [29]:
def mask_circ_heads(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Output: for each layer, a mask of circuit components that should not be ablated
    '''
    heads_and_posns_to_keep = {}
    # batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    batch, seq, n_heads = len(means_dataset), len(circuit.keys()), model.cfg.n_heads
    # print(seq)

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # mask[:, indices, head_idx] = 1
                    mask[:, :, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

In [30]:
def hook_func_mask_head(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    # components_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    # means: Float[Tensor, "layer batch seq head d_head"],
    circuit: Dict[str, List[Tuple[int, int]]],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Use this to not mask components
    '''
    # mask_for_this_layer = components_to_keep[hook.layer()].unsqueeze(-1).to(z.device)
    # z = t.where(mask_for_this_layer, z, means[hook.layer()])

    ###
    # heads_and_posns_to_keep = {}
    # batch, seq, n_heads = z.shape[0], z.shape[1], model.cfg.n_heads  # components_to_keep[0].shape[0] is batch

    # for layer in range(model.cfg.n_layers):

    #     mask = t.zeros(size=(batch, seq, n_heads))

    #     for (head_type, head_list) in circuit.items():
    #         # seq_pos = seq_pos_to_keep[head_type]
    #         # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
    #         for (layer_idx, head_idx) in head_list:
    #             if layer_idx == layer:
    #                 # mask[:, indices, head_idx] = 1
    #                 mask[:, :, head_idx] = 1

    #     heads_and_posns_to_keep[layer] = mask.bool()
    ###
    mask_for_this_layer = t.zeros(size=(z.shape[0], z.shape[1], z.shape[2]))
    for (head_type, head_list) in circuit.items():
        # seq_pos = seq_pos_to_keep[head_type]
        # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
        for (layer_idx, head_idx) in head_list:
            if layer_idx == hook.layer():
                # mask[:, indices, head_idx] = 1
                mask_for_this_layer[:, :, head_idx] = 1

    mask_for_this_layer = mask_for_this_layer.bool()
    mask_for_this_layer = mask_for_this_layer.unsqueeze(-1).to(z.device)  # d_model is 1; then is broadcast in where

    z = t.where(mask_for_this_layer, z, 0)

    return z

In [31]:
def add_ablation_hook_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Ablate the model, except as components and positions to keep
    '''

    model.reset_hooks(including_permanent=True)
    means = get_heads_actv_mean(means_dataset, model)
    components_to_keep = mask_circ_heads(means_dataset, model, circuit, seq_pos_to_keep)

    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=circuit,
    )

    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)
    return model

In [32]:
# from dataset import Dataset
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
import einops
from functools import partial
import torch as t
from torch import Tensor
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool

# from head_ablation_fns import *
# from mlp_ablation_fns import *

def add_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    for i in range(num_pos ):
        CIRCUIT['S'+str(i)] = heads_lst
        # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
        # if i == num_pos - 1:
        #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        # else:
        SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = get_heads_actv_mean(means_dataset, model)
    # Convert this into a boolean map
    components_to_keep = mask_circ_heads(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=CIRCUIT,
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    # if all_entries_true(components_to_keep) == False:
    #     pdb.set_trace()
    ########################
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}
    # # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    # num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    # for i in range(num_pos ):
    #     CIRCUIT['S'+str(i)] = mlp_lst
    #     # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
    #     # if i == num_pos - 1:
    #     #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
    #     # else:
    #     SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    # # Compute the mean of each head's output on the ABC dataset, grouped by template
    # means = get_MLPs_actv_mean(means_dataset, model)

    # # Convert this into a boolean map
    # components_to_keep = mask_circ_MLPs(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # # Get a hook function which will patch in the mean z values for each head, at
    # # all positions which aren't important for the circuit
    # hook_fn = partial(
    #     hook_func_mask_mlp_out,
    #     components_to_keep=components_to_keep,
    #     means=means
    # )

    # model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

In [33]:
def all_entries_true(tensor_dict):
    for key, tensor in tensor_dict.items():
        if not torch.all(tensor).item():
            return False
    return True

# ablation fns mult tok answers

In [34]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(5):
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            print('\nTotal logit diff: ', total_score.item())
            break

        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        # if next_char == '':
        #     next_char = ' '
        # clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
    return corr_ans_tokLen

In [35]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    for i in range(corr_ans_tokLen):
        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == corr_ans_tokLen - 1:
        #     print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    return model.to_string(tokens)

In [36]:
def ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    # logits = model(tokens)
    # next_token = logits[0, -1].argmax(dim=-1)
    # next_char = model.to_string(next_token)

    total_score = 0
    ans_so_far = ''
    ans_str_tok = tokenizer.tokenize(correct_ans)[1:] # correct_ans is str
    corr_tokenIDs = []
    for correct_ansPos in range(len(ans_str_tok)):
        tokID = model.tokenizer.encode(ans_str_tok[correct_ansPos])[2:][0] # 2: to skip padding <s> and ''
        corr_tokenIDs.append(tokID)
    correct_ans_tokLen = len(corr_tokenIDs)
    for ansPos in range(correct_ans_tokLen):
        # if next_char == '':
        #     next_char = ' '

        # clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
        #     print(model.to_string(tokens))
        #     # print(f"Sequence so far: {clean_text}")
        #     # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        # tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
            # print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        ans_so_far += next_char
        correct_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        ansTok_IDs = torch.tensor(corr_tokenIDs[ansPos])

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # corrTok_logits = logits[:, -1, next_token]
        corrTok_logits = logits[range(logits.size(0)), -1, ansTok_IDs]  # not next_token, as that's what's pred, not the token to measure
        # pdb.set_trace()
        total_score += corrTok_logits
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    # return ans_so_far, total_score.item()
    return ans_so_far

# auto measure fns

In [37]:
def ablate_circ_autoScore(model, circuit, sequences_as_str, next_members):
    corr_text = "5 3 9"
    list_outputs = []
    score = 0
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)

        heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
        head_to_remove = circuit
        heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

        mlps_not_ablate = [layer for layer in range(32)]

        output_after_ablate = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen)
        list_outputs.append(output_after_ablate)
        print(correct_ans, output_after_ablate)
        if correct_ans == output_after_ablate:
            score += 1
    perc_score = score / len(next_members)
    return perc_score, list_outputs

In [38]:
def ablate_randcirc_autoScore(model, sequences_as_str, next_members, num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap):
    corr_text = "5 3 9"
    list_outputs = []
    all_scores = []
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        prompt_score = 0
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)
        for j in range(num_rand_runs):
            all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
            filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_not_overlap] # Filter out heads_not_overlap from all_possible_pairs

            # Randomly choose num_heads_rand pairs ensuring less than num_not_overlap overlaps with heads_not_overlap
            head_to_remove = choose_heads_to_remove(filtered_pairs, heads_not_overlap, num_heads_rand, num_not_overlap)

            heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]

            mlps_not_ablate = [layer for layer in range(32)]

            output_after_ablate = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen)
            # list_outputs.append(output_after_ablate)
            # print(correct_ans, output_after_ablate)
            if correct_ans == output_after_ablate:
                prompt_score += 1
        print(prompt_score / num_rand_runs)
        all_scores.append(prompt_score / num_rand_runs)

    perc_score = sum(all_scores) / len(next_members)
    return perc_score, list_outputs

# 1 2 3 genr ablation expms

In [None]:
clean_text = "1 2 3"
corr_text = "5 3 9"

## ablate just head 9.1 and MLP 9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 0.42961740493774414
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 0.42962169647216797
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 0.4296226501464844
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.42961835861206055
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 0.42961645126342773

 Total corr logit:  2.1480965614318848


## ablate 4.4, 7.11, 9.1

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 3.36678409576416
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 3.3667917251586914
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 3.3667922019958496
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 3.3667869567871094
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 3.3667917251586914

 Total corr logit:  16.833946228027344


## ablate mlp 9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 0.8109622001647949
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 0.8109650611877441
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 0.8109645843505859
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.8109612464904785
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 0.8109622001647949

 Total corr logit:  4.054815292358398


## ablate 4.4, 7.11, 9.1 and mlp 9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 3'
1 2 3 3
5 3 9 3


Sequence so far: '1 2 3 3'
corr logit of new char: -1.0452260971069336
5th char = ' 3'
1 2 3 3 3
5 3 9 3 3


Sequence so far: '1 2 3 3 3'
corr logit of new char: -1.0452265739440918
6th char = ' 3'
1 2 3 3 3 3
5 3 9 3 3 3


Sequence so far: '1 2 3 3 3 3'
corr logit of new char: -1.0452251434326172
7th char = ' 3'
1 2 3 3 3 3 3
5 3 9 3 3 3 3


Sequence so far: '1 2 3 3 3 3 3'
corr logit of new char: -1.0452260971069336
8th char = ' 3'
1 2 3 3 3 3 3 3
5 3 9 3 3 3 3 3


Sequence so far: '1 2 3 3 3 3 3 3'
corr logit of new char: -1.0452251434326172

 Total corr logit:  -5.226129055023193


## 6.2, 4.1, 7.1

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(6, 2), (4,1), (7,1)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 0.8091115951538086
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 0.8091144561767578
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 0.8091154098510742
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.809107780456543
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 0.8091120719909668

 Total corr logit:  4.04556131362915


In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
# heads_not_ablate = [(9, 1)]
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
len(heads_not_ablate)

141

# one two three

In [None]:
clean_text = "one two three"
corr_text = "five nine two"
corr_ans_tokLen = 1

clean

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 5.505155086517334

 Total corr logit:  5.505155086517334


corrupt the subcircuit

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = '-'
one two three-
five nine two-


Sequence so far: 'one two three-'
corr logit of new char: -0.5595006942749023

 Total corr logit:  -0.5595006942749023


ablate 4.4, 7.11, 9.1

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = '-'
one two three-
five nine two-


Sequence so far: 'one two three-'
corr logit of new char: -1.0319595336914062

 Total corr logit:  -1.0319595336914062


corrupt 9.1 and mlp9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 2.635061264038086

 Total corr logit:  2.635061264038086


ablate mlp 9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 3.019651412963867

 Total corr logit:  3.019651412963867


ablate just 9.1

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 5.045960903167725

 Total corr logit:  5.045960903167725


ablate random head

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 5.4420366287231445

 Total corr logit:  5.4420366287231445


ablate all

In [None]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' five'
one two three five
five nine two five


Sequence so far: 'one two three five'
corr logit of new char: -0.12970507144927979

 Total corr logit:  -0.12970507144927979


# January February March

In [None]:
clean_text = "January February March"
corr_text = "April July July"
corr_ans_tokLen = 1

clean

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 9.440199851989746

 Total corr logit:  9.440199851989746


corrupt the subcircuit

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' August'
January February March August
April July July August


Sequence so far: 'January February March August'
corr logit of new char: -1.1505086421966553

 Total corr logit:  -1.1505086421966553


ablate 4.4, 7.11, 9.1

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 0.5852069854736328

 Total corr logit:  0.5852069854736328


corrupt 9.1 and mlp9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: -3.3458361625671387

 Total corr logit:  -3.3458361625671387


ablate mlp 9

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: -0.9824318885803223

 Total corr logit:  -0.9824318885803223


ablate just 9.1

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 8.566516876220703

 Total corr logit:  8.566516876220703


ablate random head

In [None]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 9.635494232177734

 Total corr logit:  9.635494232177734


ablate all

In [None]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' August'
January February March August
April July July August


Sequence so far: 'January February March August'
corr logit of new char: -1.855480670928955

 Total corr logit:  -1.855480670928955


# fns

# numerals

In [41]:
correct_prompts = []
for i in range(1, 9):
    correct_prompts.append(f"{i} {i+1} {i+2} {i+3}")

In [54]:
corr_ans = []
for i in range(1, 9):
    corr_ans.append(str(i+4))

In [59]:
# clean
num_corr = 0
corr_text = "0 0 0 0"
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]

big3_outputs = []
lst_out = []
for clean_text, ans in zip(correct_prompts, corr_ans):
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    answer = prompt_out[0].split(' ')[-1]
    lst_out.append(answer)
    if answer == ans:
        num_corr += 1
    # print(prompt_out[0])
print(lst_out)
print(num_corr / len(corr_ans))

['5', '6', '7', '8', '9', '10', '11', '12']
1.0


In [64]:
# big 3 heads + MLP 9
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(12) if layer != 9]

corr_text = "0 0 0 0"
num_corr = 0
big3_outputs = []
lst_out = []
for clean_text, ans in zip(correct_prompts, corr_ans):
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    answer = prompt_out[0].split(' ')[-1]
    lst_out.append(answer)
    if answer == ans:
        num_corr += 1
    print(prompt_out[0])
print(lst_out)
print(num_corr / len(corr_ans))

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12
['5', '6', '7', '8', '9', '10', '9', '12']
0.875


In [63]:
# big 3 heads + MLP 9
corr_text = "0 0 0 0"
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(12) if layer != 9]

big3_outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    print(prompt_out[0])

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12


In [45]:
# ablate 4.4, 7.11, 9.1
corr_text = "0 0 0 0"
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(12)]

big3_outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    print(prompt_out[0])

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12


In [46]:
# corrupt 9.1 and mlp9
corr_text = "0 0 0 0"
head_to_remove = ([(9, 1)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(12) if layer != 9]

big3_outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    print(prompt_out[0])

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12


In [47]:
# ablate mlp 9
corr_text = "0 0 0 0"
head_to_remove = ([])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(12) if layer != 9]

big3_outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    print(prompt_out[0])

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12


ablate just 9.1

In [48]:
# ablate just 9.1
corr_text = "0 0 0 0"
head_to_remove = ([(9, 1)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(12) ]

big3_outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 1)
    big3_outputs.append(prompt_out[0])
    print(prompt_out[0])

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12


In [49]:
import random

corr_text = "0 0 0 0"
mlps_not_ablate = [layer for layer in range(12) if layer != 9]

# Define the number of runs
num_runs = 10
# Store outputs for each run
big3_outputs = []

# Perform 10 runs, each time randomly selecting 3 heads to ablate
for _ in range(num_runs):
    # Randomly select 3 heads to ablate from heads_not_ablate
    heads_to_remove = random.sample(heads_not_ablate, 3)
    heads_not_ablate_run = [x for x in heads_not_ablate if (x not in heads_to_remove)]

    for clean_text in correct_prompts:
        # Generate outputs with the randomly selected heads removed
        prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate_run, mlps_not_ablate, 1)
        big3_outputs.append(prompt_out[0])
        print(prompt_out[0])


<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 9
<|endoftext|>7 8 9 10

<|endoftext|>8 9 10 11

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 10
<|endoftext|>8 9 10 11 10
<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10

<|endoftext|>8 9 10 11 12
<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 9
<|endoftext|>6 7 8 9 10
<|endoftext|>7 8 9 10 9
<|endoftext|>8 9 10 11 12
<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 5 6 7
<|endoftext|>4 5 6 7 8
<|endoftext|>5 6 7 8 7
<|endoftext|>6 7 8 9 8
<|endoftext|>7 8 9 10

<|endoftext|>8 9 10 11

<|endoftext|>1 2 3 4 5
<|endoftext|>2 3 4 5 6
<|endoftext|>3 4 