# Setup

In [1]:
save_files = True

In [2]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pickle
from google.colab import files

import matplotlib.pyplot as plt
import statistics

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer #, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
import pdb

## Import functions from repo

In [7]:
!git clone https://github.com/apartresearch/seqcont_circuits.git
%cd /content/seqcont_circuits/src/iter_node_pruning

Cloning into 'seqcont_circuits'...
remote: Enumerating objects: 875, done.[K
remote: Counting objects: 100% (341/341), done.[K
remote: Compressing objects: 100% (223/223), done.[K
remote: Total 875 (delta 187), reused 257 (delta 107), pack-reused 534[K
Receiving objects: 100% (875/875), 16.78 MiB | 10.26 MiB/s, done.
Resolving deltas: 100% (550/550), done.
/content/seqcont_circuits/src/iter_node_pruning


In [8]:
## comment this out when debugging functions in colab to use funcs defined in colab

# don't improt this
# # from dataset import Dataset

from metrics import *
from head_ablation_fns import *
from mlp_ablation_fns import *
from node_ablation_fns import *
from loop_node_ablation_fns import *

## fns

In [9]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a dict whose values are tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [10]:
def generate_prompts_list_longer(text, tokens):
    prompts_list = []
    prompt_dict = {
        'corr': str(1),
        'incorr': str(2),
        'text': text
        # 'text': model.to_string(tokens)[0]
        }
    tokens_as_strs = model.tokenizer.tokenize(text)
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i in range(tokens.shape[1]):
    for i, tok in enumerate(tokens_as_strs):
        prompt_dict['S'+str(i)] = tok
    # for i, tok in enumerate(tokens):
    #     prompt_dict['S'+str(i)] = model.to_string(tok)

    # prompt_dict = {
    #     'corr': '4',
    #     'incorr': '3',
    #     'text': model.to_string(tokens)[0]
    # }
    # # list_tokens = tokenizer.tokenize('1 2 3 ')
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i, tok_as_str in enumerate(tokens_as_strs):
    #     if tok_as_str == '▁':
    #         prompt_dict['S'+str(i)] = ' '
    #     else:
    #         prompt_dict['S'+str(i)] = tok_as_str
    prompts_list.append(prompt_dict)
    return prompts_list

# Load Model

In [11]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [12]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [13]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)
# tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH, use_fast= False, add_prefix_space= False)
hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [14]:
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

In [15]:
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    hf_model = hf_model,
    tokenizer = tokenizer,
    device = "cpu",
    fold_ln = False,
    center_writing_weights = False,
    center_unembed = False,
)

del hf_model

model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda


# Load datasets

In [21]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'corr': str(i+4),
            'incorr': str(i+3),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 2)
prompts_list

[{'S1': '1',
  'S2': '2',
  'S3': '3',
  'S4': '4',
  'corr': '5',
  'incorr': '4',
  'text': '1 2 3 4'}]

In [22]:
pos_dict = {}
for i in range(len(model.tokenizer.tokenize(prompts_list[0]['text']))):
    pos_dict['S'+str(i)] = i

In [23]:
dataset = Dataset(prompts_list, pos_dict, model.tokenizer)

TypeError: Dataset.__init__() missing 1 required positional argument: 'tokens'

In [None]:
import random

def generate_prompts_list_corr(prompt_list):
    outlist = []
    # for i in range(100):
    for prompt_dict in prompts_list:
        r1 = random.randint(1, 12)
        r2 = random.randint(1, 12)
        while True:
            r3 = random.randint(1, 12)
            r4 = random.randint(1, 12)
            if r4 - 1 != r3:
                break
        new_text = prompt_dict['text'].replace(prompt_dict['S1'], str(r1)).replace(prompt_dict['S2'], str(r2)).replace(prompt_dict['S3'], str(r3)).replace(prompt_dict['S4'], str(r4))
        new_prompt_dict = {
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': prompt_dict['corr'],
            'incorr': prompt_dict['incorr'],
            'text': new_text
        }
        outlist.append(new_prompt_dict)
    return outlist
prompts_list_2 = generate_prompts_list_corr(prompts_list)
len(prompts_list_2)

In [None]:
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)

## Get orig score

In [None]:
model.reset_hooks(including_permanent=True)
logits_original = model(dataset.toks)
orig_score = get_logit_diff(logits_original, dataset)
orig_score

In [None]:
import gc

del(logits_original)
torch.cuda.empty_cache()
gc.collect()

# new dataset

In [16]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer, tokens):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        # self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
        #     torch.int
        # )
        self.toks = tokens
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a dict whose values are tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                # tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        # targ_lst = []
        # for prompt in self.prompts:
        #     input_text = prompt["text"]
        #     tokens = self.tokenizer.tokenize(input_text)
        #     end_token_index = len(tokens) - 1
        #     targ_lst.append(end_token_index)
        # self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

# logit diff for mult tok answers

In [24]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(5):
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # dataset = Dataset(prompts_list, pos_dict, model.tokenizer)
        dataset = Dataset(prompts_list, pos_dict, model.tokenizer, tokens)

        # new_score = get_logit_diff(logits, dataset)

        # measure how far away predicted logit is from corr token?

        # corr_logits = logits[:, dataset.word_idx["end"], dataset.corr_tokenIDs]
        # incorr_logits = logits[:, dataset.word_idx["end"], dataset.incorr_tokenIDs]
        # new_score = corr_logits - incorr_logits

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            print('\nTotal logit diff: ', total_score.item())
            break
        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        # if next_char == '':
        #     next_char = ' '
        # clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
    return corr_ans_tokLen

In [25]:
clean_text = "1 2 3"
corr_ans = ' 5'
corr_ans_tokLen = clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> 1 2 3'
logit diff of new char: tensor([16.3678], device='cuda:0')
8th char = ''
Sequence so far: '<s> 1 2 3 '
logit diff of new char: tensor([19.4441], device='cuda:0')
9th char = '4'
Sequence so far: '<s> 1 2 3 4'
logit diff of new char: tensor([20.5775], device='cuda:0')
10th char = ''
Sequence so far: '<s> 1 2 3 4 '
logit diff of new char: tensor([19.2083], device='cuda:0')
11th char = '5'
Sequence so far: '<s> 1 2 3 4 5'
logit diff of new char: tensor([18.8848], device='cuda:0')
12th char = ''


In [26]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    # for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    # for i in range(num_pos + 1):
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, tokens)
    # pdb.set_trace()
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens
    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    for i in range(corr_ans_tokLen):
    # for i in range(5):
        if next_char == '':
            next_char = ' '
        # print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
        # tokens = tokens[:, 1:]
        # print(clean_text)
        print(f"Sequence so far: {clean_text}")
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        # clean_text = model.to_string(tokens)[0]
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        print(clean_text)
        # print(tokens.shape)

        # get new ablation dataset
        model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        corr_text = corr_text + next_char
        # corr_tokens = model.to_tokens(corr_text).to(device)

        # corr_text = model.to_string(corr_tokens)[0]
        corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)
        print(corr_text)
        # print(corr_tokens.shape)

        pos_dict = {}
        # for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
        # for i in range(corr_tokens.shape[1]):
        num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos + 1):
        for i in range(num_pos ):
            pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
        dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # print('\n')
        # print(f"Sequence so far: {model.to_string(tokens)[0]!r}")

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())

# test

In [None]:
# heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
# mlps_not_ablate = [layer for layer in range(32)]
# clean_text = "1 2 3"
# tokens = model.to_tokens(clean_text).to(device)
# prompts_list = generate_prompts_list_longer(clean_text, tokens)
# corr_text = "5 3 9"
# corr_tokens = model.to_tokens(corr_text).to(device)
# prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

In [None]:
# tokens = model.to_tokens(clean_text).to(device)
# prompts_list = generate_prompts_list_longer(clean_text, tokens)

# corr_tokens = model.to_tokens(corr_text).to(device)
# prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

# model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
# pos_dict = {}
# for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
#     pos_dict['S'+str(i)] = i
# dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
# model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

# # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens
# # pdb.set_trace()
# logits = model(tokens)
# next_token = logits[0, -1].argmax(dim=-1)
# next_char = model.to_string(next_token)

# total_score = 0

# for i in range(1):
# # for i in range(5):
#     if next_char == '':
#         next_char = ' '
#     print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
#     print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

#     # if next_char == '':
#     #     clean_text = clean_text + ' '
#     # else:
#     clean_text = clean_text + next_char
#     # tokens = model.to_tokens(clean_text).to(device)
#     # tokens = tokens[:, 1:]
#     # print(clean_text)

#     # clean_text = model.to_string(tokens)[0]
#     tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
#     print(clean_text)
#     # print(tokens.shape)

#     # get new ablation dataset
#     model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

#     corr_text = corr_text + next_char
#     # corr_tokens = model.to_tokens(corr_text).to(device)

#     # corr_text = model.to_string(corr_tokens)[0]
#     corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
#     prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)
#     print(corr_text)
#     # print(corr_tokens.shape)

#     pos_dict = {}
#     # for i in range(len(model.tokenizer.tokenize(prompts_list_2[0]['text']))):
#     for i in range(corr_tokens.shape[1]):
#         pos_dict['S'+str(i)] = i

#     dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)

#     # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

#     logits = model(tokens)
#     next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
#     next_char = model.to_string(next_token)

#     # print('\n')
#     # print(f"Sequence so far: {model.to_string(tokens)[0]!r}")

#     # new_score = get_logit_diff(logits, dataset)
#     # total_score += new_score
#     # print(f"corr logit of new char: {new_score}")
# # print('\n Total corr logit: ', total_score.item())

Sequence so far: '<s> 1 2 3'
8th char = ' '
1 2 3 
5 3 9 


# test clean prompts

In [None]:
# clean_text = "1"
# tokens = model.to_tokens(clean_text).to(device)
# tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens
# # tokens = model.tokenizer(clean_text)['input_ids']
# logits = model(tokens)
# next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
# next_char = model.to_string(next_token)

clean_text = "1"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> 1'
logit diff of new char: tensor([12.9576], device='cuda:0')
4th char = '.'
Sequence so far: '<s> 1.'
logit diff of new char: tensor([5.0549], device='cuda:0')
5th char = ''
Sequence so far: '<s> 1. '
logit diff of new char: tensor([8.5630], device='cuda:0')
6th char = 'Introduction'
Sequence so far: '<s> 1. Introduction'
logit diff of new char: tensor([21.3909], device='cuda:0')
6th char = '<0x0A>'
Sequence so far: '<s> 1. Introduction<0x0A>'
logit diff of new char: tensor([18.8126], device='cuda:0')
12th char = '<0x0A>'


5

In [None]:
clean_text = "two"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> two'
logit diff of new char: tensor([5.4994], device='cuda:0')
3th char = '-'
Sequence so far: '<s> two-'
logit diff of new char: tensor([10.8123], device='cuda:0')
4th char = 'time'
Sequence so far: '<s> two-time'
logit diff of new char: tensor([10.4234], device='cuda:0')
5th char = 'Pul'
Sequence so far: '<s> two-timePul'
logit diff of new char: tensor([21.1257], device='cuda:0')
7th char = 'itzer'
Sequence so far: '<s> two-timePulitzer'
logit diff of new char: tensor([19.2238], device='cuda:0')
8th char = 'Prize'


5

In [None]:
clean_text = "March"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> March'
logit diff of new char: tensor([17.5723], device='cuda:0')
3th char = ''
Sequence so far: '<s> March '
logit diff of new char: tensor([19.5974], device='cuda:0')
4th char = '2'
Sequence so far: '<s> March 2'
logit diff of new char: tensor([18.3477], device='cuda:0')
5th char = '0'
Sequence so far: '<s> March 20'
logit diff of new char: tensor([20.1788], device='cuda:0')
6th char = '1'
Sequence so far: '<s> March 201'
logit diff of new char: tensor([19.7587], device='cuda:0')
7th char = '7'


5

In [None]:
clean_text = "Bob is first. David is"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> Bob is first. David is'
logit diff of new char: tensor([13.4227], device='cuda:0')
8th char = 'second'
Sequence so far: '<s> Bob is first. David issecond'
logit diff of new char: tensor([18.8607], device='cuda:0')
9th char = '.'
Sequence so far: '<s> Bob is first. David issecond.'
logit diff of new char: tensor([10.3436], device='cuda:0')
10th char = '<0x0A>'
Sequence so far: '<s> Bob is first. David issecond.<0x0A>'
logit diff of new char: tensor([11.8312], device='cuda:0')
15th char = '<0x0A>'
Sequence so far: '<s> Bob is first. David issecond.<0x0A><0x0A>'
logit diff of new char: tensor([10.4858], device='cuda:0')
20th char = '<0x0A>'


5

In [None]:
clean_text = "Two days after Monday is"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> Two days after Monday is'
logit diff of new char: tensor([17.0134], device='cuda:0')
7th char = 'T'
Sequence so far: '<s> Two days after Monday isT'
logit diff of new char: tensor([18.5075], device='cuda:0')
8th char = 'ues'
Sequence so far: '<s> Two days after Monday isTues'
logit diff of new char: tensor([24.0123], device='cuda:0')
9th char = 'day'
Sequence so far: '<s> Two days after Monday isTuesday'
logit diff of new char: tensor([17.9468], device='cuda:0')
10th char = '.'
Sequence so far: '<s> Two days after Monday isTuesday.'
logit diff of new char: tensor([9.2363], device='cuda:0')
11th char = '<0x0A>'


5

In [None]:
clean_text = "uno dos tres"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> uno dos tres'
logit diff of new char: tensor([7.8090], device='cuda:0')
5th char = 'cuatro'
Sequence so far: '<s> uno dos trescuatro'
logit diff of new char: tensor([8.2790], device='cuda:0')
7th char = 'five'
Sequence so far: '<s> uno dos trescuatrofive'
logit diff of new char: tensor([11.5774], device='cuda:0')
8th char = 'six'
Sequence so far: '<s> uno dos trescuatrofivesix'
logit diff of new char: tensor([18.0978], device='cuda:0')
10th char = 'se'
Sequence so far: '<s> uno dos trescuatrofivesixse'
logit diff of new char: tensor([20.8952], device='cuda:0')
11th char = 'ven'


5

In [None]:
clean_text = "one two three"
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> one two three'
logit diff of new char: tensor([10.4469], device='cuda:0')
5th char = 'four'
Sequence so far: '<s> one two threefour'
logit diff of new char: tensor([12.6489], device='cuda:0')
6th char = 'five'
Sequence so far: '<s> one two threefourfive'
logit diff of new char: tensor([16.8563], device='cuda:0')
7th char = 'six'
Sequence so far: '<s> one two threefourfivesix'
logit diff of new char: tensor([16.7848], device='cuda:0')
9th char = 'se'
Sequence so far: '<s> one two threefourfivesixse'
logit diff of new char: tensor([23.9431], device='cuda:0')
10th char = 'ven'


5

In [None]:
clean_text = "2 4 6 "
corr_ans_tokLen = 3
clean_gen(model, clean_text, corr_ans)

Sequence so far: '<s> 2 4 6 '
logit diff of new char: tensor([19.3199], device='cuda:0')
9th char = '8'
Sequence so far: '<s> 2 4 6 8'
logit diff of new char: tensor([9.7377], device='cuda:0')
10th char = ''
Sequence so far: '<s> 2 4 6 8 '
logit diff of new char: tensor([16.1787], device='cuda:0')
11th char = '1'
Sequence so far: '<s> 2 4 6 8 1'
logit diff of new char: tensor([22.3343], device='cuda:0')
12th char = '0'
Sequence so far: '<s> 2 4 6 8 10'
logit diff of new char: tensor([7.2428], device='cuda:0')
13th char = ''


5

# test if diff answers in two genr ways

In [None]:
# Get list of arguments to pass to `generate` (specifically these are the ones relating to sampling)
generate_kwargs = dict(
    do_sample = False, # deterministic output so we can compare it to the HF model
    top_p = 1.0, # suppresses annoying output errors
    temperature = 1.0, # suppresses annoying output errors
)

In [None]:
prompt = "uno dos tres"
output = model.generate(prompt, max_new_tokens=1, **generate_kwargs)
print(output)

  0%|          | 0/1 [00:00<?, ?it/s]

uno dos tres cuatro


In [None]:
prompt = "one two three"
output = model.generate(prompt, max_new_tokens=1, **generate_kwargs)
print(output)

  0%|          | 0/1 [00:00<?, ?it/s]

one two three four


# 1 2 3 genr ablation expms

In [29]:
def mask_circ_heads(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Output: for each layer, a mask of circuit components that should not be ablated
    '''
    heads_and_posns_to_keep = {}
    batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            # if seq_pos == 'S7':
            #     pdb.set_trace()
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # if indices.item() == 7:
                    #     pdb.set_trace()
                    mask[:, indices, head_idx] = 1
                    # mask[:, :, head_idx] = 1  # keep L.H at all pos

        heads_and_posns_to_keep[layer] = mask.bool()
    # pdb.set_trace()
    return heads_and_posns_to_keep

In [27]:
# from dataset import Dataset
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
import einops
from functools import partial
import torch as t
from torch import Tensor
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool

# from head_ablation_fns import *
# from mlp_ablation_fns import *

def add_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    for i in range(num_pos ):
        CIRCUIT['S'+str(i)] = heads_lst
        # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
        # if i == num_pos - 1:
        #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        # else:
        SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = get_heads_actv_mean(means_dataset, model)
    # Convert this into a boolean map
    components_to_keep = mask_circ_heads(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=CIRCUIT,
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    # if all_entries_true(components_to_keep) == False:
    #     pdb.set_trace()
    ########################
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}
    # # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    # num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    # for i in range(num_pos ):
    #     CIRCUIT['S'+str(i)] = mlp_lst
    #     # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
    #     # if i == num_pos - 1:
    #     #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
    #     # else:
    #     SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    # # Compute the mean of each head's output on the ABC dataset, grouped by template
    # means = get_MLPs_actv_mean(means_dataset, model)

    # # Convert this into a boolean map
    # components_to_keep = mask_circ_MLPs(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # # Get a hook function which will patch in the mean z values for each head, at
    # # all positions which aren't important for the circuit
    # hook_fn = partial(
    #     hook_func_mask_mlp_out,
    #     components_to_keep=components_to_keep,
    #     means=means
    # )

    # Apply hook
    # model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

In [None]:
clean_text = "1 2 3"
corr_text = "5 3 9"
# corr_text = "1 2 3"

clean

In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


In [None]:
def all_entries_true(tensor_dict):
    for key, tensor in tensor_dict.items():
        if not torch.all(tensor).item():
            return False
    return True

ablate just head 20.7

In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate.remove((20, 7))

mlps_not_ablate = [layer for layer in range(32)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


In [None]:
clean_text = "1 2 3"
corr_text = "5 3 9"
heads_not_ablate = []  # ablate all heads but not MLPs
mlps_not_ablate = []  # ablate all MLPs
corr_ans_tokLen = 2
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 1
9th char = '1'
1 2 3 1
5 3 9 1


In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 7), (1, 11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 7), (1, 11)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32) if layer != 1]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 7), (1, 11), (16,0)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32) if layer != 1]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 7), (1, 11), (16,0)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32) if layer <10]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


In [None]:
heads_not_ablate = []

mlps_not_ablate = [layer for layer in range(32)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 1
9th char = '1'
1 2 3 1
5 3 9 1
Sequence so far: 1 2 3 10
10th char = '0'
1 2 3 10
5 3 9 10


In [None]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 7), (1, 11), (16,0), (0, 30), (0, 9), (15,25)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)]

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 1 2 3 
8th char = ' '
1 2 3 
5 3 9 
Sequence so far: 1 2 3 4
9th char = '4'
1 2 3 4
5 3 9 4
Sequence so far: 1 2 3 4 
10th char = ' '
1 2 3 4 
5 3 9 4 


# 2 4 6

In [31]:
clean_text = "2 4 6"
corr_text = "5 3 9"
# corr_text = "1 2 3"

clean

In [32]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 2 4 6 
8th char = ' '
2 4 6 
5 3 9 
Sequence so far: 2 4 6 8
9th char = '8'
2 4 6 8
5 3 9 8
Sequence so far: 2 4 6 8 
10th char = ' '
2 4 6 8 
5 3 9 8 


In [33]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(0, 13)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 2 4 6 
8th char = ' '
2 4 6 
5 3 9 
Sequence so far: 2 4 6 1
9th char = '1'
2 4 6 1
5 3 9 1
Sequence so far: 2 4 6 10
10th char = '0'
2 4 6 10
5 3 9 10


In [34]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(0, 1)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 2 4 6 
8th char = ' '
2 4 6 
5 3 9 
Sequence so far: 2 4 6 8
9th char = '8'
2 4 6 8
5 3 9 8
Sequence so far: 2 4 6 8 
10th char = ' '
2 4 6 8 
5 3 9 8 


In [35]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(1, 14)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

Sequence so far: 2 4 6 
8th char = ' '
2 4 6 
5 3 9 
Sequence so far: 2 4 6 8
9th char = '8'
2 4 6 8
5 3 9 8
Sequence so far: 2 4 6 8 
10th char = ' '
2 4 6 8 
5 3 9 8 


# uno dos tres

In [28]:
model.to_tokens("uno dos tres")

tensor([[   1, 6888, 3248, 9941]], device='cuda:0')

In [29]:
model.to_tokens("uno uno uno")

tensor([[   1, 6888, 6888, 6888]], device='cuda:0')

In [30]:
clean_text = "uno dos tres"
corr_text = "uno uno uno" # dos tres cinco seis

In [31]:
def get_heads_actv_mean(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Output: The mean activations of a head's output
    '''
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # for layer in range(model.cfg.n_layers):
    #     z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
    #     for template_group in means_dataset.groups:
    #         z_for_this_template = z_for_this_layer[template_group]
    #         z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
    #         if z_means_for_this_template.shape[0] == 5:
    #             pdb.set_trace()
    #         means[layer, template_group] = z_means_for_this_template

    del(means_cache)

    return means

In [32]:
def mask_circ_heads(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Output: for each layer, a mask of circuit components that should not be ablated
    '''
    heads_and_posns_to_keep = {}
    # batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    batch, seq, n_heads = len(means_dataset), len(circuit.keys()), model.cfg.n_heads
    print(seq)

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # mask[:, indices, head_idx] = 1
                    mask[:, :, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

In [33]:
def hook_func_mask_head(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    # components_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    # means: Float[Tensor, "layer batch seq head d_head"],
    circuit: Dict[str, List[Tuple[int, int]]],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Use this to not mask components
    '''
    # mask_for_this_layer = components_to_keep[hook.layer()].unsqueeze(-1).to(z.device)
    # z = t.where(mask_for_this_layer, z, means[hook.layer()])

    ###
    # heads_and_posns_to_keep = {}
    # batch, seq, n_heads = z.shape[0], z.shape[1], model.cfg.n_heads  # components_to_keep[0].shape[0] is batch

    # for layer in range(model.cfg.n_layers):

    #     mask = t.zeros(size=(batch, seq, n_heads))

    #     for (head_type, head_list) in circuit.items():
    #         # seq_pos = seq_pos_to_keep[head_type]
    #         # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
    #         for (layer_idx, head_idx) in head_list:
    #             if layer_idx == layer:
    #                 # mask[:, indices, head_idx] = 1
    #                 mask[:, :, head_idx] = 1

    #     heads_and_posns_to_keep[layer] = mask.bool()
    ###
    mask_for_this_layer = t.zeros(size=(z.shape[0], z.shape[1], z.shape[2]))
    for (head_type, head_list) in circuit.items():
        # seq_pos = seq_pos_to_keep[head_type]
        # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
        for (layer_idx, head_idx) in head_list:
            if layer_idx == hook.layer():
                # mask[:, indices, head_idx] = 1
                mask_for_this_layer[:, :, head_idx] = 1

    mask_for_this_layer = mask_for_this_layer.bool()
    mask_for_this_layer = mask_for_this_layer.unsqueeze(-1).to(z.device)  # d_model is 1; then is broadcast in where

    z = t.where(mask_for_this_layer, z, 0)

    return z

In [34]:
def add_ablation_hook_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Ablate the model, except as components and positions to keep
    '''

    model.reset_hooks(including_permanent=True)
    means = get_heads_actv_mean(means_dataset, model)
    components_to_keep = mask_circ_heads(means_dataset, model, circuit, seq_pos_to_keep)

    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=circuit,
    )

    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)
    return model

clean

In [35]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

4
Sequence so far: uno dos trescuatro
5th char = 'cuatro'
uno dos trescuatro
uno uno unocuatro
7
Sequence so far: uno dos trescuatrocinco
6th char = 'cinco'
uno dos trescuatrocinco
uno uno unocuatrocinco
9
Sequence so far: uno dos trescuatrocincoseis
7th char = 'seis'
uno dos trescuatrocincoseis
uno uno unocuatrocincoseis
11


In [36]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(0, 13)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

4
Sequence so far: uno dos trescuatro
5th char = 'cuatro'
uno dos trescuatro
uno uno unocuatro
7
Sequence so far: uno dos trescuatrocinco
6th char = 'cinco'
uno dos trescuatrocinco
uno uno unocuatrocinco
9
Sequence so far: uno dos trescuatrocincoseis
7th char = 'seis'
uno dos trescuatrocincoseis
uno uno unocuatrocincoseis
11


In [37]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 17)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

4
Sequence so far: uno dos tres<0x0A>
5th char = '<0x0A>'
uno dos tres<0x0A>
uno uno uno<0x0A>
10
Sequence so far: uno dos tres<0x0A><0x0A>
6th char = '<0x0A>'
uno dos tres<0x0A><0x0A>
uno uno uno<0x0A><0x0A>
15
Sequence so far: uno dos tres<0x0A><0x0A><0x0A>
7th char = '<0x0A>'
uno dos tres<0x0A><0x0A><0x0A>
uno uno uno<0x0A><0x0A><0x0A>
20


# seqcont word problems

In [38]:
clean_text = "What comes after Monday is Tuesday, and two days after is"
corr_text = "What comes after X is Y, and two days after is"

clean

In [39]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

13
Sequence so far: What comes after Monday is Tuesday, and two days after isTh
16th char = 'Th'
What comes after Monday is Tuesday, and two days after isTh
What comes after X is Y, and two days after isTh
14
Sequence so far: What comes after Monday is Tuesday, and two days after isThurs
17th char = 'urs'
What comes after Monday is Tuesday, and two days after isThurs
What comes after X is Y, and two days after isThurs
15
Sequence so far: What comes after Monday is Tuesday, and two days after isThursday
18th char = 'day'
What comes after Monday is Tuesday, and two days after isThursday
What comes after X is Y, and two days after isThursday
16


In [40]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(0, 13)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

13
Sequence so far: What comes after Monday is Tuesday, and two days after isTh
16th char = 'Th'
What comes after Monday is Tuesday, and two days after isTh
What comes after X is Y, and two days after isTh
14
Sequence so far: What comes after Monday is Tuesday, and two days after isThurs
17th char = 'urs'
What comes after Monday is Tuesday, and two days after isThurs
What comes after X is Y, and two days after isThurs
15
Sequence so far: What comes after Monday is Tuesday, and two days after isThursday
18th char = 'day'
What comes after Monday is Tuesday, and two days after isThursday
What comes after X is Y, and two days after isThursday
16


In [41]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 17)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

13
Sequence so far: What comes after Monday is Tuesday, and two days after isTh
16th char = 'Th'
What comes after Monday is Tuesday, and two days after isTh
What comes after X is Y, and two days after isTh
14
Sequence so far: What comes after Monday is Tuesday, and two days after isThurs
17th char = 'urs'
What comes after Monday is Tuesday, and two days after isThurs
What comes after X is Y, and two days after isThurs
15
Sequence so far: What comes after Monday is Tuesday, and two days after isThursday
18th char = 'day'
What comes after Monday is Tuesday, and two days after isThursday
What comes after X is Y, and two days after isThursday
16


Obtain heads from top 20 of https://colab.research.google.com/drive/1p_x98vp4OMx46rphUdIk64E7P94cQmg9#scrollTo=susSZdqpqVzd&line=1&uniqifier=1

In [42]:
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = ([(20, 17), (1, 11), (0, 30), (0, 9), (5, 26), (16, 0), (13, 6), (15, 25),
    (5, 15), (6, 11), (5, 25), (5, 17), (1, 28), (29, 5), (4, 3), (15, 15),
    (26, 2), (10, 25), (2, 2), (23, 2)])
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

mlps_not_ablate = [layer for layer in range(32)] #  if layer <10

ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 3)

13
Sequence so far: What comes after Monday is Tuesday, and two days after isTh
16th char = 'Th'
What comes after Monday is Tuesday, and two days after isTh
What comes after X is Y, and two days after isTh
14
Sequence so far: What comes after Monday is Tuesday, and two days after isThurs
17th char = 'urs'
What comes after Monday is Tuesday, and two days after isThurs
What comes after X is Y, and two days after isThurs
15
Sequence so far: What comes after Monday is Tuesday, and two days after isThursday
18th char = 'day'
What comes after Monday is Tuesday, and two days after isThursday
What comes after X is Y, and two days after isThursday
16
