# Setup

## Change Inputs Here

In [1]:
model_name = "gpt2-small"
save_files = True

In [2]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pickle
from google.colab import files

import matplotlib.pyplot as plt
import statistics

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
import pdb

## Load Model

In [6]:
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/apartresearch/seqcont_circuits.git
%cd /content/seqcont_circuits/src/iter_node_pruning

Cloning into 'seqcont_circuits'...
remote: Enumerating objects: 875, done.[K
remote: Counting objects: 100% (341/341), done.[K
remote: Compressing objects: 100% (223/223), done.[K
remote: Total 875 (delta 187), reused 257 (delta 107), pack-reused 534[K
Receiving objects: 100% (875/875), 16.78 MiB | 21.61 MiB/s, done.
Resolving deltas: 100% (550/550), done.
/content/seqcont_circuits/src/iter_node_pruning


In [9]:
## comment this out when debugging functions in colab to use funcs defined in colab

# don't improt this
# # from dataset import Dataset

from metrics import *
from head_ablation_fns import *
from mlp_ablation_fns import *
from node_ablation_fns import *
from loop_node_ablation_fns import *

## fns

In [10]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [11]:
def generate_prompts_list_longer(text, tokens):
    prompts_list = []
    prompt_dict = {
        'corr': str(1),
        'incorr': str(2),
        'text': text}
    tokens_as_strs = model.tokenizer.tokenize(text)
    # for i in range(tokens.shape[1]):
    for i, tok in enumerate(tokens_as_strs):
        prompt_dict['S'+str(i)] = tok
    prompts_list.append(prompt_dict)
    return prompts_list

# Load datasets

In [12]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'corr': str(i+4),
            'incorr': str(i+3),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 2)
prompts_list

[{'S1': '1',
  'S2': '2',
  'S3': '3',
  'S4': '4',
  'corr': '5',
  'incorr': '4',
  'text': '1 2 3 4'}]

In [13]:
pos_dict = {}
for i in range(len(model.tokenizer.tokenize(prompts_list[0]['text']))):
    pos_dict['S'+str(i)] = i

In [14]:
dataset = Dataset(prompts_list, pos_dict, model.tokenizer)

In [15]:
import random

def generate_prompts_list_corr(prompt_list):
    outlist = []
    # for i in range(100):
    for prompt_dict in prompts_list:
        r1 = random.randint(1, 12)
        r2 = random.randint(1, 12)
        while True:
            r3 = random.randint(1, 12)
            r4 = random.randint(1, 12)
            if r4 - 1 != r3:
                break
        new_text = prompt_dict['text'].replace(prompt_dict['S1'], str(r1)).replace(prompt_dict['S2'], str(r2)).replace(prompt_dict['S3'], str(r3)).replace(prompt_dict['S4'], str(r4))
        new_prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': prompt_dict['corr'],
            'incorr': prompt_dict['incorr'],
            'text': new_text
        }
        outlist.append(new_prompt_dict)
    return outlist
prompts_list_2 = generate_prompts_list_corr(prompts_list)
len(prompts_list_2)

1

In [16]:
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)

## Get orig score

In [17]:
model.reset_hooks(including_permanent=True)
logits_original = model(dataset.toks)
orig_score = get_logit_diff(logits_original, dataset)
orig_score

tensor(6.0631, device='cuda:0')

In [18]:
import gc

del(logits_original)
torch.cuda.empty_cache()
gc.collect()

0

# logit diff for mult tok answers

In [19]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(5):
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        dataset = Dataset(prompts_list, pos_dict, model.tokenizer)

        # new_score = get_logit_diff(logits, dataset)

        # measure how far away predicted logit is from corr token?

        # corr_logits = logits[:, dataset.word_idx["end"], dataset.corr_tokenIDs]
        # incorr_logits = logits[:, dataset.word_idx["end"], dataset.incorr_tokenIDs]
        # new_score = corr_logits - incorr_logits

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            print('\nTotal logit diff: ', total_score.item())
            break
        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
    return corr_ans_tokLen

In [20]:
clean_text = "1 2 3 4"
corr_ans = ' 5'
corr_ans_tokLen = clean_gen(model, clean_text, corr_ans)

Sequence so far: '1 2 3 4'
logit diff of new char: tensor([16.8118], device='cuda:0')
5th char = ' 5'

Total logit diff:  16.811784744262695


In [21]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens
    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
    for i in range(corr_ans_tokLen):
    # for i in range(5):
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        clean_text = clean_text + next_char
        tokens = model.to_tokens(clean_text).to(device)
        tokens = tokens[:, 1:]
        print(clean_text)

        # get new ablation dataset
        model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        corr_text = corr_text + next_char
        corr_tokens = model.to_tokens(corr_text).to(device)
        prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)
        print(corr_text)

        pos_dict = {}
        for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
            pos_dict['S'+str(i)] = i

        dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)

        model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        print('\n')
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")

        new_score = get_logit_diff(logits, dataset)
        total_score += new_score
        print(f"corr logit of new char: {new_score}")
    print('\n Total corr logit: ', total_score.item())

In [22]:
clean_text = "1 2 3"
corr_text = "5 3 9"
heads_not_ablate = []  # ablate all heads but not MLPs
mlps_not_ablate = []  # ablate all MLPs
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3'
4th char = '.'
1 2 3.
5 3 9.


Sequence so far: '1 2 3.'
corr logit of new char: 0.861663818359375

 Total corr logit:  0.861663818359375


# 1 2 3 genr ablation expms

In [23]:
clean_text = "1 2 3"
corr_text = "5 3 9"

## ablate just head 9.1 and MLP 9

In [24]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 0.42961740493774414
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 0.42962169647216797
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 0.4296226501464844
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.42961835861206055
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 0.42961645126342773

 Total corr logit:  2.1480965614318848


## ablate 4.4, 7.11, 9.1

In [25]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 3.36678409576416
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 3.3667917251586914
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 3.3667922019958496
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 3.3667869567871094
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 3.3667917251586914

 Total corr logit:  16.833946228027344


## ablate mlp 9

In [26]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 0.8109622001647949
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 0.8109650611877441
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 0.8109645843505859
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.8109612464904785
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 0.8109622001647949

 Total corr logit:  4.054815292358398


## ablate 4.4, 7.11, 9.1 and mlp 9

In [27]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 3'
1 2 3 3
5 3 9 3


Sequence so far: '1 2 3 3'
corr logit of new char: -1.0452260971069336
5th char = ' 3'
1 2 3 3 3
5 3 9 3 3


Sequence so far: '1 2 3 3 3'
corr logit of new char: -1.0452265739440918
6th char = ' 3'
1 2 3 3 3 3
5 3 9 3 3 3


Sequence so far: '1 2 3 3 3 3'
corr logit of new char: -1.0452251434326172
7th char = ' 3'
1 2 3 3 3 3 3
5 3 9 3 3 3 3


Sequence so far: '1 2 3 3 3 3 3'
corr logit of new char: -1.0452260971069336
8th char = ' 3'
1 2 3 3 3 3 3 3
5 3 9 3 3 3 3 3


Sequence so far: '1 2 3 3 3 3 3 3'
corr logit of new char: -1.0452251434326172

 Total corr logit:  -5.226129055023193


## 6.2, 4.1, 7.1

In [28]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(6, 2), (4,1), (7,1)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: '1 2 3'
4th char = ' 4'
1 2 3 4
5 3 9 4


Sequence so far: '1 2 3 4'
corr logit of new char: 0.8091115951538086
5th char = ' 5'
1 2 3 4 5
5 3 9 4 5


Sequence so far: '1 2 3 4 5'
corr logit of new char: 0.8091144561767578
6th char = ' 6'
1 2 3 4 5 6
5 3 9 4 5 6


Sequence so far: '1 2 3 4 5 6'
corr logit of new char: 0.8091154098510742
7th char = ' 7'
1 2 3 4 5 6 7
5 3 9 4 5 6 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.809107780456543
8th char = ' 8'
1 2 3 4 5 6 7 8
5 3 9 4 5 6 7 8


Sequence so far: '1 2 3 4 5 6 7 8'
corr logit of new char: 0.8091120719909668

 Total corr logit:  4.04556131362915


In [29]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
# heads_not_ablate = [(9, 1)]
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
len(heads_not_ablate)

141

# M T, two days after is

In [30]:
clean_text = "What comes after Monday is Tuesday, and two days after is"
corr_text = "What comes after X is Y, and two days after is"

## clean

In [31]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' Wednesday'
What comes after Monday is Tuesday, and two days after is Wednesday
What comes after X is Y, and two days after is Wednesday


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday'
corr logit of new char: -0.49404239654541016
14th char = '.'
What comes after Monday is Tuesday, and two days after is Wednesday.
What comes after X is Y, and two days after is Wednesday.


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday.'
corr logit of new char: -0.49404239654541016
15th char = '\n'
What comes after Monday is Tuesday, and two days after is Wednesday.

What comes after X is Y, and two days after is Wednesday.



Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday.\n'
corr logit of new char: -0.49404239654541016
16th char = '\n'
What comes after Monday is Tuesday, and two days after is Wed

## corrupt the subcircuit

In [32]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' the'
What comes after Monday is Tuesday, and two days after is the
What comes after X is Y, and two days after is the


Sequence so far: 'What comes after Monday is Tuesday, and two days after is the'
corr logit of new char: -0.890857458114624
14th char = ' day'
What comes after Monday is Tuesday, and two days after is the day
What comes after X is Y, and two days after is the day


Sequence so far: 'What comes after Monday is Tuesday, and two days after is the day'
corr logit of new char: -0.890857458114624
15th char = ' of'
What comes after Monday is Tuesday, and two days after is the day of
What comes after X is Y, and two days after is the day of


Sequence so far: 'What comes after Monday is Tuesday, and two days after is the day of'
corr logit of new char: -0.890857458114624
16th char = ' the'
What comes after Monday is Tuesday, and two days after is the day of the
What comes after X is Y, 

## ablate 4.4, 7.11, 9.1

In [33]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' Monday'
What comes after Monday is Tuesday, and two days after is Monday
What comes after X is Y, and two days after is Monday


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Monday'
corr logit of new char: -0.5167403221130371
14th char = ','
What comes after Monday is Tuesday, and two days after is Monday,
What comes after X is Y, and two days after is Monday,


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Monday,'
corr logit of new char: -0.5167403221130371
15th char = ' and'
What comes after Monday is Tuesday, and two days after is Monday, and
What comes after X is Y, and two days after is Monday, and


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Monday, and'
corr logit of new char: -0.5167403221130371
16th char = ' then'
What comes after Monday is Tuesday, and two days after is Monday, and then
What c

## corrupt 9.1 and mlp9

In [34]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' the'
What comes after Monday is Tuesday, and two days after is the
What comes after X is Y, and two days after is the


Sequence so far: 'What comes after Monday is Tuesday, and two days after is the'
corr logit of new char: -0.8888349533081055
14th char = ' day'
What comes after Monday is Tuesday, and two days after is the day
What comes after X is Y, and two days after is the day


Sequence so far: 'What comes after Monday is Tuesday, and two days after is the day'
corr logit of new char: -0.8888349533081055
15th char = ' of'
What comes after Monday is Tuesday, and two days after is the day of
What comes after X is Y, and two days after is the day of


Sequence so far: 'What comes after Monday is Tuesday, and two days after is the day of'
corr logit of new char: -0.8888349533081055
16th char = ' the'
What comes after Monday is Tuesday, and two days after is the day of the
What comes after X is 

## ablate mlp 9

In [35]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' Wednesday'
What comes after Monday is Tuesday, and two days after is Wednesday
What comes after X is Y, and two days after is Wednesday


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday'
corr logit of new char: -0.8782567977905273
14th char = '.'
What comes after Monday is Tuesday, and two days after is Wednesday.
What comes after X is Y, and two days after is Wednesday.


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday.'
corr logit of new char: -0.8782567977905273
15th char = '\n'
What comes after Monday is Tuesday, and two days after is Wednesday.

What comes after X is Y, and two days after is Wednesday.



Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday.\n'
corr logit of new char: -0.8782567977905273
16th char = '\n'
What comes after Monday is Tuesday, and two days after is Wednes

## ablate just 9.1

In [36]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' Monday'
What comes after Monday is Tuesday, and two days after is Monday
What comes after X is Y, and two days after is Monday


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Monday'
corr logit of new char: -0.49973201751708984
14th char = '.'
What comes after Monday is Tuesday, and two days after is Monday.
What comes after X is Y, and two days after is Monday.


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Monday.'
corr logit of new char: -0.49973201751708984
15th char = '\n'
What comes after Monday is Tuesday, and two days after is Monday.

What comes after X is Y, and two days after is Monday.



Sequence so far: 'What comes after Monday is Tuesday, and two days after is Monday.\n'
corr logit of new char: -0.49973201751708984
16th char = '\n'
What comes after Monday is Tuesday, and two days after is Monday.


What comes after X is Y

## ablate random head

In [37]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

Sequence so far: 'What comes after Monday is Tuesday, and two days after is'
13th char = ' Wednesday'
What comes after Monday is Tuesday, and two days after is Wednesday
What comes after X is Y, and two days after is Wednesday


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday'
corr logit of new char: -0.4928019046783447
14th char = '.'
What comes after Monday is Tuesday, and two days after is Wednesday.
What comes after X is Y, and two days after is Wednesday.


Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday.'
corr logit of new char: -0.4928019046783447
15th char = '\n'
What comes after Monday is Tuesday, and two days after is Wednesday.

What comes after X is Y, and two days after is Wednesday.



Sequence so far: 'What comes after Monday is Tuesday, and two days after is Wednesday.\n'
corr logit of new char: -0.4928019046783447
16th char = '\n'
What comes after Monday is Tuesday, and two days after is Wednes

# test clean prompts

In [38]:
# clean_text = "1"
# tokens = model.to_tokens(clean_text).to(device)
# tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens
# # tokens = model.tokenizer(clean_text)['input_ids']
# logits = model(tokens)
# next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
# next_char = model.to_string(next_token)

clean_text = "1"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: '1'
logit diff of new char: tensor([10.5143], device='cuda:0')
2th char = '.'
Sequence so far: '1.'
logit diff of new char: tensor([11.9692], device='cuda:0')
3th char = '0'
Sequence so far: '1.0'
logit diff of new char: tensor([13.1672], device='cuda:0')
4th char = '.'
Sequence so far: '1.0.'
logit diff of new char: tensor([15.8012], device='cuda:0')
5th char = '0'
Sequence so far: '1.0.0'
logit diff of new char: tensor([14.2398], device='cuda:0')
6th char = '.'


5

In [39]:
clean_text = "two"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'two'
logit diff of new char: tensor([9.6223], device='cuda:0')
2th char = ','
Sequence so far: 'two,'
logit diff of new char: tensor([10.2186], device='cuda:0')
3th char = ' and'
Sequence so far: 'two, and'
logit diff of new char: tensor([11.2373], device='cuda:0')
4th char = ' the'
Sequence so far: 'two, and the'
logit diff of new char: tensor([10.0012], device='cuda:0')
5th char = ' other'
Sequence so far: 'two, and the other'
logit diff of new char: tensor([12.4016], device='cuda:0')
6th char = ' two'


5

In [40]:
clean_text = "March"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'March'
logit diff of new char: tensor([10.2301], device='cuda:0')
2th char = ','
Sequence so far: 'March,'
logit diff of new char: tensor([11.7677], device='cuda:0')
3th char = ' the'
Sequence so far: 'March, the'
logit diff of new char: tensor([10.4229], device='cuda:0')
4th char = ' U'
Sequence so far: 'March, the U'
logit diff of new char: tensor([19.0277], device='cuda:0')
5th char = '.'
Sequence so far: 'March, the U.'
logit diff of new char: tensor([20.5512], device='cuda:0')
6th char = 'S'


5

In [41]:
clean_text = "Bob is first. David is"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'Bob is first. David is'
logit diff of new char: tensor([15.9324], device='cuda:0')
7th char = ' second'
Sequence so far: 'Bob is first. David is second'
logit diff of new char: tensor([17.8772], device='cuda:0')
8th char = '.'
Sequence so far: 'Bob is first. David is second.'
logit diff of new char: tensor([16.3890], device='cuda:0')
9th char = '\n'
Sequence so far: 'Bob is first. David is second.\n'
logit diff of new char: tensor([22.0736], device='cuda:0')
10th char = '\n'
Sequence so far: 'Bob is first. David is second.\n\n'
logit diff of new char: tensor([17.7922], device='cuda:0')
11th char = 'David'


5

In [42]:
clean_text = "Two days after Monday is"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'Two days after Monday is'
logit diff of new char: tensor([13.7510], device='cuda:0')
6th char = ' when'
Sequence so far: 'Two days after Monday is when'
logit diff of new char: tensor([13.4375], device='cuda:0')
7th char = ' the'
Sequence so far: 'Two days after Monday is when the'
logit diff of new char: tensor([10.9209], device='cuda:0')
8th char = ' FBI'
Sequence so far: 'Two days after Monday is when the FBI'
logit diff of new char: tensor([15.5805], device='cuda:0')
9th char = ' announced'
Sequence so far: 'Two days after Monday is when the FBI announced'
logit diff of new char: tensor([16.2980], device='cuda:0')
10th char = ' it'


5

In [43]:
clean_text = "Bob is first in line. David is"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'Bob is first in line. David is'
logit diff of new char: tensor([16.5300], device='cuda:0')
9th char = ' second'
Sequence so far: 'Bob is first in line. David is second'
logit diff of new char: tensor([18.4162], device='cuda:0')
10th char = '.'
Sequence so far: 'Bob is first in line. David is second.'
logit diff of new char: tensor([16.6176], device='cuda:0')
11th char = '\n'
Sequence so far: 'Bob is first in line. David is second.\n'
logit diff of new char: tensor([22.3888], device='cuda:0')
12th char = '\n'
Sequence so far: 'Bob is first in line. David is second.\n\n'
logit diff of new char: tensor([18.9882], device='cuda:0')
13th char = 'David'


5

In [44]:
clean_text = "uno dos tres"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'uno dos tres'
logit diff of new char: tensor([11.3545], device='cuda:0')
5th char = ' un'
Sequence so far: 'uno dos tres un'
logit diff of new char: tensor([11.6727], device='cuda:0')
6th char = 'as'
Sequence so far: 'uno dos tres unas'
logit diff of new char: tensor([11.1833], device='cuda:0')
7th char = ' de'
Sequence so far: 'uno dos tres unas de'
logit diff of new char: tensor([10.9759], device='cuda:0')
8th char = ' la'
Sequence so far: 'uno dos tres unas de la'
logit diff of new char: tensor([11.1198], device='cuda:0')
9th char = ' v'


5

In [45]:
clean_text = "uno dos tres"
corr_ans_tokLen = 3

heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
mlps_not_ablate = [layer for layer in range(12)]
clean_gen(model, clean_text, corr_ans)

Sequence so far: 'uno dos tres'
logit diff of new char: tensor([11.3545], device='cuda:0')
5th char = ' un'
Sequence so far: 'uno dos tres un'
logit diff of new char: tensor([11.6727], device='cuda:0')
6th char = 'as'
Sequence so far: 'uno dos tres unas'
logit diff of new char: tensor([11.1833], device='cuda:0')
7th char = ' de'
Sequence so far: 'uno dos tres unas de'
logit diff of new char: tensor([10.9759], device='cuda:0')
8th char = ' la'
Sequence so far: 'uno dos tres unas de la'
logit diff of new char: tensor([11.1198], device='cuda:0')
9th char = ' v'


5

# Bob is first. David is

In [46]:
clean_text = "Bob is first. David is"
corr_text = "Bob is X. David is"
corr_ans_tokLen = 1

clean

In [47]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.4071230888366699

 Total corr logit:  0.4071230888366699


corrupt the subcircuit

In [48]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.6557116508483887

 Total corr logit:  0.6557116508483887


ablate 4.4, 7.11, 9.1

In [49]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.37902164459228516

 Total corr logit:  0.37902164459228516


corrupt 9.1 and mlp9

In [50]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.6732892990112305

 Total corr logit:  0.6732892990112305


ablate mlp 9

In [51]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.6711874008178711

 Total corr logit:  0.6711874008178711


ablate just 9.1

In [52]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.4114093780517578

 Total corr logit:  0.4114093780517578


ablate random head

In [53]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' second'
Bob is first. David is second
Bob is X. David is second


Sequence so far: 'Bob is first. David is second'
corr logit of new char: 0.40241003036499023

 Total corr logit:  0.40241003036499023


ablate all

In [54]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'Bob is first. David is'
7th char = ' Y'
Bob is first. David is Y
Bob is X. David is Y


Sequence so far: 'Bob is first. David is Y'
corr logit of new char: 1.2515344619750977

 Total corr logit:  1.2515344619750977


# one two three

In [55]:
clean_text = "one two three"
corr_text = "five nine two"
corr_ans_tokLen = 1

clean

In [56]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 5.505155086517334

 Total corr logit:  5.505155086517334


corrupt the subcircuit

In [57]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = '-'
one two three-
five nine two-


Sequence so far: 'one two three-'
corr logit of new char: -0.5595006942749023

 Total corr logit:  -0.5595006942749023


ablate 4.4, 7.11, 9.1

In [58]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = '-'
one two three-
five nine two-


Sequence so far: 'one two three-'
corr logit of new char: -1.0319595336914062

 Total corr logit:  -1.0319595336914062


corrupt 9.1 and mlp9

In [59]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 2.635061264038086

 Total corr logit:  2.635061264038086


ablate mlp 9

In [60]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 3.019651412963867

 Total corr logit:  3.019651412963867


ablate just 9.1

In [61]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 5.045960903167725

 Total corr logit:  5.045960903167725


ablate random head

In [62]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' four'
one two three four
five nine two four


Sequence so far: 'one two three four'
corr logit of new char: 5.4420366287231445

 Total corr logit:  5.4420366287231445


ablate all

In [63]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'one two three'
4th char = ' five'
one two three five
five nine two five


Sequence so far: 'one two three five'
corr logit of new char: -0.12970507144927979

 Total corr logit:  -0.12970507144927979


# January February March

In [64]:
clean_text = "January February March"
corr_text = "April July July"
corr_ans_tokLen = 1

clean

In [65]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 9.440199851989746

 Total corr logit:  9.440199851989746


corrupt the subcircuit

In [66]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' August'
January February March August
April July July August


Sequence so far: 'January February March August'
corr logit of new char: -1.1505086421966553

 Total corr logit:  -1.1505086421966553


ablate 4.4, 7.11, 9.1

In [67]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 0.5852069854736328

 Total corr logit:  0.5852069854736328


corrupt 9.1 and mlp9

In [68]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: -3.3458361625671387

 Total corr logit:  -3.3458361625671387


ablate mlp 9

In [69]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: -0.9824318885803223

 Total corr logit:  -0.9824318885803223


ablate just 9.1

In [70]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 8.566516876220703

 Total corr logit:  8.566516876220703


ablate random head

In [71]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' April'
January February March April
April July July April


Sequence so far: 'January February March April'
corr logit of new char: 9.635494232177734

 Total corr logit:  9.635494232177734


ablate all

In [72]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 'January February March'
4th char = ' August'
January February March August
April July July August


Sequence so far: 'January February March August'
corr logit of new char: -1.855480670928955

 Total corr logit:  -1.855480670928955


# 1 2 3 4 5 6

In [73]:
clean_text = "1 2 3 4 5 6"
corr_text = "8 5 9 4 2 4"
corr_ans_tokLen = 1

clean

In [74]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 7'
1 2 3 4 5 6 7
8 5 9 4 2 4 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 6.063086032867432

 Total corr logit:  6.063086032867432


corrupt the subcircuit

In [75]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 6'
1 2 3 4 5 6 6
8 5 9 4 2 4 6


Sequence so far: '1 2 3 4 5 6 6'
corr logit of new char: -0.3448066711425781

 Total corr logit:  -0.3448066711425781


ablate 4.4, 7.11, 9.1

In [76]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 7'
1 2 3 4 5 6 7
8 5 9 4 2 4 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 1.3072271347045898

 Total corr logit:  1.3072271347045898


corrupt 9.1 and mlp9

In [77]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 7'
1 2 3 4 5 6 7
8 5 9 4 2 4 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: -0.2949347496032715

 Total corr logit:  -0.2949347496032715


ablate mlp 9

In [78]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 7'
1 2 3 4 5 6 7
8 5 9 4 2 4 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 0.2177433967590332

 Total corr logit:  0.2177433967590332


ablate just 9.1

In [79]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 7'
1 2 3 4 5 6 7
8 5 9 4 2 4 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 5.689393997192383

 Total corr logit:  5.689393997192383


ablate random head

In [80]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 7'
1 2 3 4 5 6 7
8 5 9 4 2 4 7


Sequence so far: '1 2 3 4 5 6 7'
corr logit of new char: 6.105478763580322

 Total corr logit:  6.105478763580322


ablate all

In [81]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6'
7th char = ' 3'
1 2 3 4 5 6 3
8 5 9 4 2 4 3


Sequence so far: '1 2 3 4 5 6 3'
corr logit of new char: -0.07252740859985352

 Total corr logit:  -0.07252740859985352


# 1 to 50

In [82]:
import random

# Generate a string of numbers from 1 to 50
sequence_string = ' '.join(map(str, range(1, 51)))

# Generate a string of random numbers picked from 1 to 50
random_numbers = [random.randint(1, 50) for _ in range(50)]
random_string = ' '.join(map(str, random_numbers))

print("Sequence String: ", sequence_string)
print("Random Numbers String: ", random_string)

Sequence String:  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
Random Numbers String:  34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49


In [83]:
clean_text = sequence_string
corr_text = random_string
corr_ans_tokLen = 1

clean

In [84]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: 6.063088417053223

 Total corr logit:  6.063088417053223


corrupt the subcircuit

In [85]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: -3.1905994415283203

 Total corr logit:  -3.1905994415283203


ablate 4.4, 7.11, 9.1

In [86]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: 2.5344972610473633

 Total corr logit:  2.5344972610473633


corrupt 9.1 and mlp9

In [87]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: -3.2330355644226074

 Total corr logit:  -3.2330355644226074


ablate mlp 9

In [88]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: -1.9877004623413086

 Total corr logit:  -1.9877004623413086


ablate just 9.1

In [89]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: 2.6247963905334473

 Total corr logit:  2.6247963905334473


ablate random head

In [90]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 51'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 51


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51'
corr logit of new char: 6.098381042480469

 Total corr logit:  6.098381042480469


ablate all

In [91]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50'
51th char = ' 25'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 25
34 32 3 33 41 8 27 25 35 29 27 18 27 29 11 42 5 13 20 5 2 45 8 3 30 35 28 33 32 1 47 12 31 48 20 8 27 25 24 50 25 3 19 50 48 22 50 8 7 49 25


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 25'
corr logit of new char: -3.4891464710235596

 Total corr logit:  -3.4891464710235596


# 1 to 20

In [108]:
import random

# Generate a string of numbers from 1 to 50
sequence_string = ' '.join(map(str, range(1, 21)))

# Generate a string of random numbers picked from 1 to 50
random_numbers = [random.randint(1, 20) for _ in range(20)]
random_string = ' '.join(map(str, random_numbers))

print("Sequence String: ", sequence_string)
print("Random Numbers String: ", random_string)

Sequence String:  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
Random Numbers String:  19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19


In [109]:
clean_text = sequence_string
corr_text = random_string
corr_ans_tokLen = 1

clean

In [110]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 21'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 21


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21'
corr logit of new char: 6.063088417053223

 Total corr logit:  6.063088417053223


corrupt the subcircuit

In [111]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 20'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 20
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 20


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 20'
corr logit of new char: -0.41758298873901367

 Total corr logit:  -0.41758298873901367


ablate 4.4, 7.11, 9.1

In [112]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 21'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 21


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21'
corr logit of new char: 1.1238365173339844

 Total corr logit:  1.1238365173339844


corrupt 9.1 and mlp9

In [113]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 21'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 21


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21'
corr logit of new char: -0.5075206756591797

 Total corr logit:  -0.5075206756591797


ablate mlp 9

In [114]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 21'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 21


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21'
corr logit of new char: 0.18439197540283203

 Total corr logit:  0.18439197540283203


ablate just 9.1

In [115]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 21'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 21


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21'
corr logit of new char: 5.55417537689209

 Total corr logit:  5.55417537689209


ablate random head

In [116]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 21'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 21


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21'
corr logit of new char: 6.122335433959961

 Total corr logit:  6.122335433959961


ablate all

In [117]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20'
21th char = ' 20'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 20
19 17 16 3 19 10 19 18 18 6 16 8 12 1 18 15 5 20 18 19 20


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 20'
corr logit of new char: -0.2929258346557617

 Total corr logit:  -0.2929258346557617


# 1 to 100

In [118]:
import random

# Generate a string of numbers from 1 to 50
sequence_string = ' '.join(map(str, range(1, 101)))

# Generate a string of random numbers picked from 1 to 50
random_numbers = [random.randint(1, 100) for _ in range(100)]
random_string = ' '.join(map(str, random_numbers))

print("Sequence String: ", sequence_string)
print("Random Numbers String: ", random_string)

Sequence String:  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
Random Numbers String:  17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62


In [119]:
clean_text = sequence_string
corr_text = random_string
corr_ans_tokLen = 1

clean

In [120]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

corrupt the subcircuit

In [121]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

ablate 4.4, 7.11, 9.1

In [122]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
head_to_remove = ([(9, 1), (4,4), (7,11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

corrupt 9.1 and mlp9

In [123]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

ablate mlp 9

In [124]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated

mlps_not_ablate = [layer for layer in range(12) if layer != 9]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

ablate just 9.1

In [125]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((9, 1))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

ablate random head

In [126]:
heads_not_ablate = [(layer, head) for layer in range(12) for head in range(12)]  # unablated
heads_not_ablate.remove((6, 2))

mlps_not_ablate = [layer for layer in range(12)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = '\n'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62



Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 

ablate all

In [127]:
heads_not_ablate = [ ]

mlps_not_ablate = []

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100'
101th char = ' 65'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 65
17 50 98 49 40 17 48 29 11 61 99 97 66 54 12 1 90 99 24 29 24 54 60 95 85 75 24 1 11 63 78 84 46 26 15 4 83 50 62 54 98 75 15 80 39 63 73 48 68 64 95 62 26 94 39 86 15 16 85 78 72 13 84 84 42 77 97 63 26 6 10 47 91 31 1 36 15 31 49 51 13 80 69 79 55 3 20 24 44 54 17 45 82 95 96 36 81 26 62 62 65


Sequence so far: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 2