# Setup

In [None]:
save_files = True

In [None]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pickle
from google.colab import files

import matplotlib.pyplot as plt
import statistics

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer #, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import pdb

## Import functions from repo

In [None]:
!git clone https://github.com/apartresearch/seqcont_circuits.git
%cd /content/seqcont_circuits/src/iter_node_pruning

Cloning into 'seqcont_circuits'...
remote: Enumerating objects: 1022, done.[K
remote: Counting objects: 100% (488/488), done.[K
remote: Compressing objects: 100% (287/287), done.[K
remote: Total 1022 (delta 296), reused 378 (delta 190), pack-reused 534[K
Receiving objects: 100% (1022/1022), 18.76 MiB | 13.26 MiB/s, done.
Resolving deltas: 100% (659/659), done.
/content/seqcont_circuits/src/iter_node_pruning


In [None]:
## comment this out when debugging functions in colab to use funcs defined in colab

# don't improt this
# # from dataset import Dataset

from metrics import *
from head_ablation_fns import *
from mlp_ablation_fns import *
from node_ablation_fns import *
from loop_node_ablation_fns import *

## fns

In [None]:
import random


In [None]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a dict whose values are tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [None]:
def generate_prompts_list_longer(text, tokens):
    prompts_list = []
    prompt_dict = {
        'corr': str(1),
        'incorr': str(2),
        'text': text
        # 'text': model.to_string(tokens)[0]
        }
    tokens_as_strs = model.tokenizer.tokenize(text)
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i in range(tokens.shape[1]):
    for i, tok in enumerate(tokens_as_strs):
        prompt_dict['S'+str(i)] = tok
    # for i, tok in enumerate(tokens):
    #     prompt_dict['S'+str(i)] = model.to_string(tok)

    # prompt_dict = {
    #     'corr': '4',
    #     'incorr': '3',
    #     'text': model.to_string(tokens)[0]
    # }
    # # list_tokens = tokenizer.tokenize('1 2 3 ')
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i, tok_as_str in enumerate(tokens_as_strs):
    #     if tok_as_str == '▁':
    #         prompt_dict['S'+str(i)] = ' '
    #     else:
    #         prompt_dict['S'+str(i)] = tok_as_str
    prompts_list.append(prompt_dict)
    return prompts_list

# Load Model

In [None]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [None]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)
# tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH, use_fast= False, add_prefix_space= False)
hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [None]:
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

In [None]:
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    hf_model = hf_model,
    tokenizer = tokenizer,
    device = "cpu",
    fold_ln = False,
    center_writing_weights = False,
    center_unembed = False,
)

del hf_model

model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda


# Define circs

In [None]:
# from Llama2_numerals_1to10.ipynb
nums_1to9 = [(0, 2), (0, 5), (0, 6), (0, 15), (1, 15), (1, 28), (2, 13), (2, 24), (3, 24), (4, 3), (4, 16), (5, 11), (5, 13), (5, 15), (5, 16), (5, 23), (5, 25), (5, 27), (6, 11), (6, 14), (6, 20), (6, 23), (6, 24), (6, 26), (6, 28), (6, 30), (6, 31), (7, 0), (7, 13), (7, 21), (7, 30), (8, 0), (8, 2), (8, 12), (8, 15), (8, 26), (8, 27), (8, 30), (8, 31), (9, 15), (9, 16), (9, 23), (9, 26), (9, 27), (9, 29), (9, 31), (10, 1), (10, 13), (10, 18), (10, 23), (10, 29), (11, 7), (11, 8), (11, 9), (11, 17), (11, 18), (11, 25), (11, 28), (12, 18), (12, 19), (12, 23), (12, 27), (13, 6), (13, 11), (13, 20), (14, 18), (14, 19), (14, 20), (14, 21), (16, 0), (18, 19), (18, 21), (18, 25), (18, 26), (18, 31), (19, 28), (20, 17), (21, 0), (21, 2), (22, 18), (22, 20), (22, 25), (23, 27), (26, 2)]
len(nums_1to9)

84

In [None]:
# nw_circ = [(0, 1), (0, 4), (0, 6), (0, 7), (0, 8), (0, 10), (0, 11), (0, 12), (1, 16), (1, 24), (1, 27), (1, 28), (2, 2), (2, 5), (2, 8), (2, 24), (2, 30), (3, 7), (3, 14), (3, 19), (3, 23), (4, 3), (5, 16), (5, 25), (6, 11), (6, 14), (7, 0), (7, 30), (8, 0), (8, 2), (8, 3), (8, 4), (8, 6), (8, 21), (8, 31), (9, 1), (9, 3), (9, 7), (9, 11), (9, 29), (9, 31), (10, 13), (10, 18), (10, 23), (10, 24), (10, 25), (10, 27), (11, 18), (11, 28), (12, 18), (12, 26), (13, 11), (13, 17), (13, 18), (13, 19), (13, 20), (13, 21), (13, 23), (14, 7), (14, 14), (15, 25), (15, 28), (16, 0), (16, 12), (16, 14), (16, 15), (16, 16), (16, 19), (16, 24), (16, 29), (17, 17), (17, 23), (17, 31), (18, 31), (19, 12), (20, 17), (27, 20), (27, 25), (27, 27), (27, 31), (28, 5), (29, 5)]
# in order from most impt to least based on how much changes perf when ablated
nw_circ = [(20, 17), (5, 25), (16, 0), (29, 5), (3, 19), (6, 11), (15, 25), (8, 0), (16, 24), (8, 4), (7, 0), (6, 14), (16, 29), (5, 16), (12, 26), (4, 3), (3, 7), (7, 30), (11, 28), (28, 5), (17, 31), (13, 11), (13, 20), (12, 18), (1, 27), (10, 13), (18, 31), (8, 6), (9, 1), (0, 4), (2, 2), (9, 11), (19, 12), (1, 16), (13, 17), (9, 7), (11, 18), (2, 24), (10, 18), (9, 31), (9, 29), (2, 30), (2, 5), (1, 24), (2, 8), (15, 28), (27, 31), (16, 14), (3, 23), (3, 14), (10, 23), (27, 20), (8, 3), (14, 7), (14, 14), (16, 15), (8, 2), (17, 17), (0, 1), (10, 27), (16, 19), (0, 8), (0, 12), (1, 28), (0, 11), (17, 23), (0, 10), (0, 6), (13, 19), (8, 31), (10, 24), (16, 12), (13, 23), (13, 21), (27, 27), (9, 3), (27, 25), (16, 16), (8, 21), (0, 7), (13, 18), (10, 25)]
len(nw_circ)

82

In [None]:
# impt_months_heads = ([(23, 17), (17, 11), (16, 0), (26, 14), (18, 9), (5, 25), (22, 20), (6, 24), (26, 9), (12, 18), (13, 20), (19, 12), (27, 29), (13, 14), (16, 14), (12, 26), (19, 30), (16, 18), (31, 27), (26, 28), (16, 1), (18, 1), (19, 28), (18, 31), (29, 4), (17, 0), (14, 1), (17, 12), (12, 15), (28, 16), (10, 1), (16, 19), (9, 27), (30, 1), (19, 27), (0, 3), (15, 11), (21, 3), (11, 19), (12, 0), (23, 11), (8, 14), (16, 8), (22, 13), (13, 3), (4, 19), (14, 15), (12, 20), (19, 16), (18, 5)])
months_circ = [(20, 17), (6, 11), (16, 0), (5, 15), (17, 11), (23, 16), (5, 25), (7, 0), (26, 14), (6, 14), (12, 22), (8, 4), (12, 15), (16, 29), (15, 25), (5, 16), (18, 31), (14, 7), (11, 18), (4, 12), (3, 19), (12, 2), (11, 28), (4, 3), (18, 9), (8, 14), (12, 3), (11, 2), (10, 13), (4, 16), (1, 22), (11, 16), (3, 15), (13, 31), (2, 4), (2, 16), (8, 13), (0, 13), (8, 15), (12, 28), (1, 5), (0, 4), (0, 25), (3, 24), (13, 11), (1, 24), (8, 16), (13, 8), (3, 26), (0, 6), (3, 23), (1, 3), (14, 3), (8, 19), (8, 12), (14, 2), (8, 5), (1, 28), (8, 20), (2, 30), (8, 6), (10, 1), (13, 20), (19, 27)]
len(months_circ)

64

In [None]:
intersect_all = list(set(nums_1to9) & set(nw_circ) & set(months_circ))
len(intersect_all)

16

In [None]:
union_all = list(set(nums_1to9) | set(nw_circ) | set(months_circ))
len(union_all)

172

# new ablation functions

In [None]:
def get_heads_actv_mean(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Output: The mean activations of a head's output
    '''
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # for layer in range(model.cfg.n_layers):
    #     z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
    #     for template_group in means_dataset.groups:
    #         z_for_this_template = z_for_this_layer[template_group]
    #         z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
    #         if z_means_for_this_template.shape[0] == 5:
    #             pdb.set_trace()
    #         means[layer, template_group] = z_means_for_this_template

    del(means_cache)

    return means

In [None]:
# def mask_circ_heads(
#     means_dataset: Dataset,
#     model: HookedTransformer,
#     circuit: Dict[str, List[Tuple[int, int]]],
#     seq_pos_to_keep: Dict[str, str],
# ) -> Dict[int, Bool[Tensor, "batch seq head"]]:
#     '''
#     Output: for each layer, a mask of circuit components that should not be ablated
#     '''
#     heads_and_posns_to_keep = {}
#     batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

#     for layer in range(model.cfg.n_layers):

#         mask = t.zeros(size=(batch, seq, n_heads))

#         for (head_type, head_list) in circuit.items():
#             seq_pos = seq_pos_to_keep[head_type]
#             # if seq_pos == 'S7':
#             #     pdb.set_trace()
#             indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
#             for (layer_idx, head_idx) in head_list:
#                 if layer_idx == layer:
#                     # if indices.item() == 7:
#                     #     pdb.set_trace()
#                     mask[:, indices, head_idx] = 1
#                     # mask[:, :, head_idx] = 1  # keep L.H at all pos

#         heads_and_posns_to_keep[layer] = mask.bool()
#     # pdb.set_trace()
#     return heads_and_posns_to_keep

In [None]:
def mask_circ_heads(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Output: for each layer, a mask of circuit components that should not be ablated
    '''
    heads_and_posns_to_keep = {}
    # batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    batch, seq, n_heads = len(means_dataset), len(circuit.keys()), model.cfg.n_heads
    # print(seq)

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # mask[:, indices, head_idx] = 1
                    mask[:, :, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

In [None]:
def hook_func_mask_head(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    # components_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    # means: Float[Tensor, "layer batch seq head d_head"],
    circuit: Dict[str, List[Tuple[int, int]]],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Use this to not mask components
    '''
    # mask_for_this_layer = components_to_keep[hook.layer()].unsqueeze(-1).to(z.device)
    # z = t.where(mask_for_this_layer, z, means[hook.layer()])

    ###
    # heads_and_posns_to_keep = {}
    # batch, seq, n_heads = z.shape[0], z.shape[1], model.cfg.n_heads  # components_to_keep[0].shape[0] is batch

    # for layer in range(model.cfg.n_layers):

    #     mask = t.zeros(size=(batch, seq, n_heads))

    #     for (head_type, head_list) in circuit.items():
    #         # seq_pos = seq_pos_to_keep[head_type]
    #         # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
    #         for (layer_idx, head_idx) in head_list:
    #             if layer_idx == layer:
    #                 # mask[:, indices, head_idx] = 1
    #                 mask[:, :, head_idx] = 1

    #     heads_and_posns_to_keep[layer] = mask.bool()
    ###
    mask_for_this_layer = t.zeros(size=(z.shape[0], z.shape[1], z.shape[2]))
    for (head_type, head_list) in circuit.items():
        # seq_pos = seq_pos_to_keep[head_type]
        # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
        for (layer_idx, head_idx) in head_list:
            if layer_idx == hook.layer():
                # mask[:, indices, head_idx] = 1
                mask_for_this_layer[:, :, head_idx] = 1

    mask_for_this_layer = mask_for_this_layer.bool()
    mask_for_this_layer = mask_for_this_layer.unsqueeze(-1).to(z.device)  # d_model is 1; then is broadcast in where

    z = t.where(mask_for_this_layer, z, 0)

    return z

In [None]:
def add_ablation_hook_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Ablate the model, except as components and positions to keep
    '''

    model.reset_hooks(including_permanent=True)
    means = get_heads_actv_mean(means_dataset, model)
    components_to_keep = mask_circ_heads(means_dataset, model, circuit, seq_pos_to_keep)

    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=circuit,
    )

    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)
    return model

In [None]:
# from dataset import Dataset
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
import einops
from functools import partial
import torch as t
from torch import Tensor
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool

# from head_ablation_fns import *
# from mlp_ablation_fns import *

def add_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    for i in range(num_pos ):
        CIRCUIT['S'+str(i)] = heads_lst
        # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
        # if i == num_pos - 1:
        #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        # else:
        SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = get_heads_actv_mean(means_dataset, model)
    # Convert this into a boolean map
    components_to_keep = mask_circ_heads(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=CIRCUIT,
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    # if all_entries_true(components_to_keep) == False:
    #     pdb.set_trace()
    ########################
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}
    # # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    # num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    # for i in range(num_pos ):
    #     CIRCUIT['S'+str(i)] = mlp_lst
    #     # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
    #     # if i == num_pos - 1:
    #     #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
    #     # else:
    #     SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    # # Compute the mean of each head's output on the ABC dataset, grouped by template
    # means = get_MLPs_actv_mean(means_dataset, model)

    # # Convert this into a boolean map
    # components_to_keep = mask_circ_MLPs(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # # Get a hook function which will patch in the mean z values for each head, at
    # # all positions which aren't important for the circuit
    # hook_fn = partial(
    #     hook_func_mask_mlp_out,
    #     components_to_keep=components_to_keep,
    #     means=means
    # )

    # model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

In [None]:
def all_entries_true(tensor_dict):
    for key, tensor in tensor_dict.items():
        if not torch.all(tensor).item():
            return False
    return True

# ablation fns mult tok answers

In [None]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(5):
        print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            print('\nTotal logit diff: ', total_score.item())
            break

        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        # if next_char == '':
        #     next_char = ' '
        # clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
    return corr_ans_tokLen

In [None]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    for i in range(corr_ans_tokLen):
        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == corr_ans_tokLen - 1:
        #     print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    return model.to_string(tokens)

In [None]:
def ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    # logits = model(tokens)
    # next_token = logits[0, -1].argmax(dim=-1)
    # next_char = model.to_string(next_token)

    total_score = 0
    ans_so_far = ''
    ans_str_tok = tokenizer.tokenize(correct_ans)[1:] # correct_ans is str
    corr_tokenIDs = []
    for correct_ansPos in range(len(ans_str_tok)):
        tokID = model.tokenizer.encode(ans_str_tok[correct_ansPos])[2:][0] # 2: to skip padding <s> and ''
        corr_tokenIDs.append(tokID)
    correct_ans_tokLen = len(corr_tokenIDs)
    for ansPos in range(correct_ans_tokLen):
        # if next_char == '':
        #     next_char = ' '

        # clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
        #     print(model.to_string(tokens))
        #     # print(f"Sequence so far: {clean_text}")
        #     # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        # tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
            # print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        ans_so_far += next_char
        correct_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        ansTok_IDs = torch.tensor(corr_tokenIDs[ansPos])

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # corrTok_logits = logits[:, -1, next_token]
        corrTok_logits = logits[range(logits.size(0)), -1, ansTok_IDs]  # not next_token, as that's what's pred, not the token to measure
        # pdb.set_trace()
        total_score += corrTok_logits
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    # return ans_so_far, total_score.item()
    return ans_so_far

# auto measure fns

In [None]:
def ablate_circ_autoScore(model, circuit, sequences_as_str, next_members):
    corr_text = "5 3 9"
    list_outputs = []
    score = 0
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)

        heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
        head_to_remove = circuit
        heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

        mlps_not_ablate = [layer for layer in range(32)]

        output_after_ablate = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen)
        list_outputs.append(output_after_ablate)
        print(correct_ans, output_after_ablate)
        if correct_ans == output_after_ablate:
            score += 1
    perc_score = score / len(next_members)
    return perc_score, list_outputs

In [None]:
def ablate_randcirc_autoScore(model, sequences_as_str, next_members, num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap):
    corr_text = "5 3 9"
    list_outputs = []
    all_scores = []
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        prompt_score = 0
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)
        for j in range(num_rand_runs):
            all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
            filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_not_overlap] # Filter out heads_not_overlap from all_possible_pairs

            # Randomly choose num_heads_rand pairs ensuring less than num_not_overlap overlaps with heads_not_overlap
            head_to_remove = choose_heads_to_remove(filtered_pairs, heads_not_overlap, num_heads_rand, num_not_overlap)

            heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]

            mlps_not_ablate = [layer for layer in range(32)]

            output_after_ablate = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans_tokLen)
            # list_outputs.append(output_after_ablate)
            # print(correct_ans, output_after_ablate)
            if correct_ans == output_after_ablate:
                prompt_score += 1
        print(prompt_score / num_rand_runs)
        all_scores.append(prompt_score / num_rand_runs)

    perc_score = sum(all_scores) / len(next_members)
    return perc_score, list_outputs

# chose rand circs

In [None]:
# Function to randomly choose 50 pairs ensuring less than 10 overlap with heads_of_circ
def choose_heads_to_remove(filtered_pairs, heads_of_circ, num_pairs=50, max_overlap=10):
    while True:
        head_to_remove = random.sample(filtered_pairs, num_pairs)
        overlap_count = len([head for head in head_to_remove if head in heads_of_circ])
        if overlap_count < max_overlap:
            return head_to_remove

In [None]:
import random
num_rand_runs = 50
lst_rand_head_to_remove = []

heads_not_overlap = intersect_all
num_heads_rand = 100
num_not_overlap = len(intersect_all)
for j in range(num_rand_runs):
    all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
    filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_not_overlap] # Filter out heads_not_overlap from all_possible_pairs
    head_to_remove = choose_heads_to_remove(filtered_pairs, heads_not_overlap, num_heads_rand, num_not_overlap)
    # heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]
    lst_rand_head_to_remove.append(head_to_remove)

In [None]:
import pickle
from google.colab import files
with open('lst_rand_head_to_remove.pkl', 'wb') as file:
    pickle.dump(lst_rand_head_to_remove, file)
files.download('lst_rand_head_to_remove.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
pwd

'/content/seqcont_circuits/src/iter_node_pruning'

In [None]:
import pickle
with open('/content/lst_rand_head_to_remove.pkl', 'rb') as file:
    lst_rand_head_to_remove = pickle.load(file)

In [None]:
for lst in lst_rand_head_to_remove:
    print(lst)

[(13, 24), (28, 4), (26, 31), (12, 14), (27, 29), (27, 4), (21, 14), (0, 13), (10, 24), (5, 27), (4, 26), (16, 27), (13, 26), (5, 9), (30, 5), (11, 9), (12, 17), (22, 10), (4, 12), (4, 6), (15, 18), (2, 26), (16, 11), (23, 31), (5, 14), (27, 23), (28, 3), (0, 18), (8, 17), (4, 11), (9, 14), (23, 19), (12, 7), (19, 29), (13, 4), (25, 27), (6, 10), (26, 17), (8, 11), (2, 27), (27, 8), (22, 2), (1, 27), (1, 3), (8, 27), (11, 10), (25, 13), (14, 21), (25, 7), (1, 16), (12, 10), (8, 28), (25, 19), (1, 11), (3, 3), (5, 11), (4, 31), (28, 8), (19, 20), (20, 23), (10, 7), (22, 17), (18, 13), (6, 26), (27, 20), (2, 4), (13, 14), (22, 27), (27, 28), (6, 16), (19, 16), (7, 10), (13, 3), (12, 15), (5, 20), (28, 24), (14, 26), (13, 21), (24, 29), (29, 29), (3, 23), (22, 6), (27, 19), (7, 31), (19, 19), (20, 0), (13, 29), (7, 17), (3, 31), (22, 8), (21, 30), (3, 21), (23, 17), (25, 4), (6, 24), (30, 28), (25, 26), (18, 23), (26, 12), (10, 17)]
[(11, 14), (7, 19), (29, 25), (21, 4), (19, 4), (27, 17)

# test prompts

In [None]:
instruction = "Be concise. "
clean_text =  "If today is November 20th, then in 28 days it will be"
clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

['<s> Be concise. If today is November 20th, then in 28 days it will be December 18']


In [None]:
instruction = "Be concise. "
clean_text =  "If this month is April, and four months pass, what month is it? "
clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

['<s> Be concise. If this month is April, and four months pass, what month is it? \n\nAnswer: May']

In [None]:
instruction = "Be concise. "
clean_text =  "If this month is April, and two months pass, what month is it? "
clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

['<s> Be concise. If this month is April, and two months pass, what month is it? \n\nAnswer: May']

In [None]:
instruction = "Be concise. "
clean_text =  "If this month is April, and 4 months pass, what month is it? "
clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)

['<s> Be concise. If this month is April, and 4 months pass, what month is it? \n\nAnswer: May']

In [None]:
instruction = "Be concise. "
clean_text =  "If this month is April, and 4 months pass, what month is it? Answer: "
# clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)

['<s> If this month is April, and 4 months pass, what month is it? Answer:  If this month is April and 4 months pass, then it is August']

In [None]:
instruction = "Be concise. "
clean_text =  "If this month is April, and four months pass, what month is it? Answer: "
# clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)

['<s> If this month is April, and four months pass, what month is it? Answer:  If four months pass, it will be August.\n\nIf this month']

In [None]:
instruction = "Be concise. "
clean_text =  "If this month is April, and four months pass, what month is it? Answer: "
# clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)

In [None]:
instruction = "Be concise. "
clean_text =  "If this month is May, and four months pass, what month is it? Answer: "
# clean_text = instruction + clean_text
corr_text = "uno uno uno" # dos tres cinco seis
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
mlps_not_ablate = [layer for layer in range(32)]
ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)

['<s> If this month is May, and four months pass, what month is it? Answer:  It is now September.\n\nIf this month is June, and four']

# If this month is May, and four months pass, what month is it? Answer:

In [None]:
# Generate prompts as described by replacing the month "May" and number "four"
import random

# Define a list of months and a range of numbers to choose from
months = ["January", "February", "March", "April", "May", "June",
          "July", "August", "September", "October", "November", "December"]
numbers = list(range(1, 13))  # Using a realistic range for months in a year

# Generate 10 unique prompts by randomly selecting months and numbers
prompts = []
for _ in range(10):
    month = random.choice(months)
    number = random.choice(numbers)
    prompts.append(f"If this month is {month}, and {number} months pass, what month is it? Answer: ")

prompts

['If this month is August, and 9 months pass, what month is it? Answer: ',
 'If this month is June, and 1 months pass, what month is it? Answer: ',
 'If this month is May, and 3 months pass, what month is it? Answer: ',
 'If this month is November, and 1 months pass, what month is it? Answer: ',
 'If this month is March, and 11 months pass, what month is it? Answer: ',
 'If this month is November, and 5 months pass, what month is it? Answer: ',
 'If this month is May, and 8 months pass, what month is it? Answer: ',
 'If this month is April, and 10 months pass, what month is it? Answer: ',
 'If this month is March, and 11 months pass, what month is it? Answer: ',
 'If this month is April, and 5 months pass, what month is it? Answer: ']

In [None]:
# Correcting the function to accurately extract the month and the number of months to add

def correct_and_simplify_calculation(prompt):
    # Extract the month and number of months correctly
    words = prompt.split()
    current_month = words[4].strip(',')  # Corrected to properly extract the month name
    months_to_add = int(words[6])  # Corrected to extract the correct integer value for months to add

    # Compute the future month index considering the circular nature of months
    current_index = months.index(current_month)
    future_index = (current_index + months_to_add) % 12

    # Return the future month based on the computed index
    return months[future_index]

# Generate answers using the correctly adjusted function
final_corrected_answers = [correct_and_simplify_calculation(prompt) for prompt in prompts]
final_corrected_answers

['May',
 'July',
 'August',
 'December',
 'February',
 'April',
 'January',
 'February',
 'February',
 'September']

In [None]:
# unablated

outputs = []
# for clean_text in correct_prompts:
for clean_text in prompts:
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
    outputs.append(prompt_out)
    print(prompt_out)

['<s> If this month is August, and 9 months pass, what month is it? Answer:  If it is August and 9 months pass, then it is May.']
['<s> If this month is June, and 1 months pass, what month is it? Answer:  If this month is June and 1 month passes, then it is July']
['<s> If this month is May, and 3 months pass, what month is it? Answer: 3 months pass, so it is now August.</s>sports betting']
['<s> If this month is November, and 1 months pass, what month is it? Answer: 1 month has passed, so it is now December.</s>sports bet']
['<s> If this month is March, and 11 months pass, what month is it? Answer: 11 months after March is February.\n\nIf this month is March']
['<s> If this month is November, and 5 months pass, what month is it? Answer: 6 months have passed, so it is now December.</s>sports bet']
['<s> If this month is May, and 8 months pass, what month is it? Answer: 13 months have passed, so it is now August.</s>sports']
['<s> If this month is April, and 10 months pass, what month i

In [None]:
# unablated

outputs = []
instruction = "If this month is March, and 2 months pass, what month is it? Answer: May. "
# for clean_text in correct_prompts:
for clean_text in prompts:
    clean_text = instruction + clean_text
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 2)
    outputs.append(prompt_out)
    print(prompt_out)

['<s> If this month is March, and 2 months pass, what month is it? Answer: May. If this month is August, and 9 months pass, what month is it? Answer: 5th']
['<s> If this month is March, and 2 months pass, what month is it? Answer: May. If this month is June, and 1 months pass, what month is it? Answer:  July.']
['<s> If this month is March, and 2 months pass, what month is it? Answer: May. If this month is May, and 3 months pass, what month is it? Answer: 8.']
['<s> If this month is March, and 2 months pass, what month is it? Answer: May. If this month is November, and 1 months pass, what month is it? Answer:  December.']
['<s> If this month is March, and 2 months pass, what month is it? Answer: May. If this month is March, and 11 months pass, what month is it? Answer: 20']
['<s> If this month is March, and 2 months pass, what month is it? Answer: May. If this month is November, and 5 months pass, what month is it? Answer: 4th']
['<s> If this month is March, and 2 months pass, what mon

Use chatgpt or manual to get correct indices

In [None]:
# Indices of correct answers
correct_indices = [0, 1, 2, 3, 4, 8, 9]

# Subset using the correct indices
correct_prompts = [prompts[i] for i in correct_indices]
correct_prompts

['If this month is August, and 9 months pass, what month is it? Answer: ',
 'If this month is June, and 1 months pass, what month is it? Answer: ',
 'If this month is May, and 3 months pass, what month is it? Answer: ',
 'If this month is November, and 1 months pass, what month is it? Answer: ',
 'If this month is March, and 11 months pass, what month is it? Answer: ',
 'If this month is March, and 11 months pass, what month is it? Answer: ',
 'If this month is April, and 5 months pass, what month is it? Answer: ']

In [None]:
answers_of_correct_prompts = [final_corrected_answers[i] for i in correct_indices]
answers_of_correct_prompts

['May', 'July', 'August', 'December', 'February', 'February', 'September']

In [None]:
# big 3 heads
head_to_remove = [(20,17), (16,0), (5,25)]
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
    outputs.append(prompt_out)
    print(prompt_out)

['<s> If this month is August, and 9 months pass, what month is it? Answer:  If it is August and 9 months pass, then it is November.']
['<s> If this month is June, and 1 months pass, what month is it? Answer:  If this month is June and 1 month passes, then it is July']
['<s> If this month is May, and 3 months pass, what month is it? Answer: 3 months after May is August.</s>01.05.2']
['<s> If this month is November, and 1 months pass, what month is it? Answer: 1 month has passed, so it is now December.</s>sports bet']
['<s> If this month is March, and 11 months pass, what month is it? Answer: 11 months after March is December.\n\nIf this month is December']
['<s> If this month is March, and 11 months pass, what month is it? Answer: 11 months after March is December.\n\nIf this month is December']
['<s> If this month is April, and 5 months pass, what month is it? Answer:  If it is April and 5 months pass, then it is now July']


In [None]:
# random, len 3 (not from saved head combo presets) ; ssave all results

all_prompt_outputs = []
heads_of_circ = intersect_all
num_heads_rand = 3
num_not_overlap = len(intersect_all)
all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_of_circ] # Filter out heads_of_circ from all_possible_pairs
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "0 0 0"
for clean_text in correct_prompts:
    output_for_a_prompt = []
    for i in range(10):
        # Randomly choose pairs ensuring no overlaps with heads_of_circ
        head_to_remove = choose_heads_to_remove(filtered_pairs, heads_of_circ, num_heads_rand, num_not_overlap)
        heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]
        out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
        # print(out[0])
        output_for_a_prompt.append(out[0])
    print(out)
    all_prompt_outputs.append(output_for_a_prompt)

['<s> If this month is August, and 9 months pass, what month is it? Answer: 9 months after August is May.\n\nIf this month is August,']
['<s> If this month is June, and 1 months pass, what month is it? Answer:  If this month is June and 1 month passes, then it is July']
['<s> If this month is May, and 3 months pass, what month is it? Answer: 3 months pass, so it is now August.</s>sports betting']
['<s> If this month is November, and 1 months pass, what month is it? Answer: 1 month has passed, so it is now December.</s>sports bet']
['<s> If this month is March, and 11 months pass, what month is it? Answer: 11 months after March is February.\n\nIf this month is March']
['<s> If this month is March, and 11 months pass, what month is it? Answer: 11 months after March is February.\n\nIf this month is March']
['<s> If this month is April, and 5 months pass, what month is it? Answer:  If it is April and 5 months pass, then it is September.']


In [None]:
all_prompt_outputs

[['<s> If this month is August, and 9 months pass, what month is it? Answer: 9 months after August is May.\n\nIf this month is August,',
  '<s> If this month is August, and 9 months pass, what month is it? Answer:  If it is August and 9 months pass, then it is May.',
  '<s> If this month is August, and 9 months pass, what month is it? Answer:  If it is August and 9 months pass, then it is May.',
  '<s> If this month is August, and 9 months pass, what month is it? Answer: 9 months after August is May.\n\nIf this month is August,',
  '<s> If this month is August, and 9 months pass, what month is it? Answer: 9 months after August is May.\n\nIf this month is August,',
  '<s> If this month is August, and 9 months pass, what month is it? Answer:  If it is August and 9 months pass, then it is May.',
  '<s> If this month is August, and 9 months pass, what month is it? Answer:  If it is August and 9 months pass, then it is May.',
  '<s> If this month is August, and 9 months pass, what month is 

# (more data) If this month is May, and four months pass, what month is it? Answer:

In [None]:
# Generate prompts as described by replacing the month "May" and number "four"
import random

# Define a list of months and a range of numbers to choose from
months = ["January", "February", "March", "April", "May", "June",
          "July", "August", "September", "October", "November", "December"]
numbers = list(range(1, 13))  # Using a realistic range for months in a year

# Generate 10 unique prompts by randomly selecting months and numbers
prompts = []
for _ in range(100):
    month = random.choice(months)
    number = random.choice(numbers)
    prompts.append(f"If this month is {month}, and {number} months pass, what month is it? Answer: ")

prompts

['If this month is March, and 8 months pass, what month is it? Answer: ',
 'If this month is October, and 4 months pass, what month is it? Answer: ',
 'If this month is May, and 8 months pass, what month is it? Answer: ',
 'If this month is February, and 9 months pass, what month is it? Answer: ',
 'If this month is June, and 6 months pass, what month is it? Answer: ',
 'If this month is June, and 1 months pass, what month is it? Answer: ',
 'If this month is May, and 9 months pass, what month is it? Answer: ',
 'If this month is December, and 3 months pass, what month is it? Answer: ',
 'If this month is November, and 12 months pass, what month is it? Answer: ',
 'If this month is June, and 11 months pass, what month is it? Answer: ',
 'If this month is December, and 6 months pass, what month is it? Answer: ',
 'If this month is November, and 5 months pass, what month is it? Answer: ',
 'If this month is December, and 9 months pass, what month is it? Answer: ',
 'If this month is Octo

In [None]:
# Correcting the function to accurately extract the month and the number of months to add

def correct_and_simplify_calculation(prompt):
    # Extract the month and number of months correctly
    words = prompt.split()
    current_month = words[4].strip(',')  # Corrected to properly extract the month name
    months_to_add = int(words[6])  # Corrected to extract the correct integer value for months to add

    # Compute the future month index considering the circular nature of months
    current_index = months.index(current_month)
    future_index = (current_index + months_to_add) % 12

    # Return the future month based on the computed index
    return months[future_index]

# Generate answers using the correctly adjusted function
final_corrected_answers = [correct_and_simplify_calculation(prompt) for prompt in prompts]
final_corrected_answers

['November',
 'February',
 'January',
 'November',
 'December',
 'July',
 'February',
 'March',
 'November',
 'May',
 'June',
 'April',
 'September',
 'August',
 'January',
 'February',
 'September',
 'July',
 'August',
 'June',
 'August',
 'November',
 'November',
 'June',
 'April',
 'March',
 'September',
 'October',
 'March',
 'January',
 'December',
 'November',
 'October',
 'May',
 'September',
 'August',
 'August',
 'May',
 'October',
 'February',
 'January',
 'October',
 'October',
 'May',
 'June',
 'March',
 'January',
 'May',
 'April',
 'November',
 'May',
 'August',
 'September',
 'January',
 'September',
 'July',
 'June',
 'June',
 'March',
 'April',
 'November',
 'April',
 'December',
 'August',
 'December',
 'July',
 'January',
 'February',
 'April',
 'September',
 'June',
 'October',
 'October',
 'March',
 'January',
 'November',
 'April',
 'September',
 'June',
 'November',
 'September',
 'November',
 'November',
 'November',
 'October',
 'December',
 'July',
 'September

In [108]:
# unablated

unfiltered_outputs = []
for clean_text in prompts:
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
    unfiltered_outputs.append(prompt_out[0])
    print(prompt_out[0])

<s> If this month is March, and 8 months pass, what month is it? Answer:  If this month is March and 8 months pass, then it is November
<s> If this month is October, and 4 months pass, what month is it? Answer:  If this month is October and 4 months pass, then it is January
<s> If this month is May, and 8 months pass, what month is it? Answer: 13 months have passed, so it is now August.</s>sports
<s> If this month is February, and 9 months pass, what month is it? Answer:  If this month is February and 9 months pass, then 9 months
<s> If this month is June, and 6 months pass, what month is it? Answer: 6 months after June is December.

If this month is December,
<s> If this month is June, and 1 months pass, what month is it? Answer:  If this month is June and 1 month passes, then it is July
<s> If this month is May, and 9 months pass, what month is it? Answer: 11 months.

Explanation: If it is May and
<s> If this month is December, and 3 months pass, what month is it? Answer: 3 months af

In [109]:
len(unfiltered_outputs)

100

Use chatgpt or manual to get correct indices

In [110]:
def find_indices_with_correct_answers(test_strings, correct_answers):
    """
    For each string in 'test_strings', this function checks if the corresponding correct answer
    from 'correct_answers' appears after the word "Answer:". If the correct answer is present,
    it collects the index of that string.

    :param test_strings: List of strings where each string includes "Answer: <some text>"
    :param correct_answers: List of correct answers to check in each string
    :return: List of indices where the correct answer appears after "Answer:"
    """
    correct_indices = []
    for index, (test_string, correct_answer) in enumerate(zip(test_strings, correct_answers)):
        # Find the position where "Answer:" occurs
        answer_pos = test_string.find("Answer:")
        if answer_pos != -1:
            # Extract the part of the string after "Answer:"
            answer_text = test_string[answer_pos + len("Answer:"):].strip()
            # Check if the correct answer appears in this part of the string
            if correct_answer in answer_text:
                correct_indices.append(index)
                # print(test_string)
                # print(answer_text)

    return correct_indices

# Find indices where the answers are correct
indices_with_correct_answers = find_indices_with_correct_answers(unfiltered_outputs, final_corrected_answers)
print(indices_with_correct_answers)

[0, 4, 5, 7, 8, 10, 12, 16, 17, 18, 20, 22, 26, 27, 31, 33, 34, 36, 39, 40, 42, 43, 46, 47, 49, 51, 56, 59, 61, 62, 63, 64, 67, 69, 71, 73, 88, 89, 91, 92, 93, 94, 97]


In [111]:
len(indices_with_correct_answers)

43

In [112]:
# Indices of correct answers
# correct_indices = [0, 1, 4, 5, 7, 8, 10, 12, 15, 17, 18, 21, 23, 27, 28, 31, 34, 35, 37, 40, 41, 47, 48, 54, 57, 62, 63, 65,
#                    66, 72, 75, 76, 78, 80, 83, 85, 88, 89, 90, 91, 93, 95, 96, 98]

# Subset using the correct indices
correct_prompts = [prompts[i] for i in indices_with_correct_answers]
correct_prompts

['If this month is March, and 8 months pass, what month is it? Answer: ',
 'If this month is June, and 6 months pass, what month is it? Answer: ',
 'If this month is June, and 1 months pass, what month is it? Answer: ',
 'If this month is December, and 3 months pass, what month is it? Answer: ',
 'If this month is November, and 12 months pass, what month is it? Answer: ',
 'If this month is December, and 6 months pass, what month is it? Answer: ',
 'If this month is December, and 9 months pass, what month is it? Answer: ',
 'If this month is August, and 1 months pass, what month is it? Answer: ',
 'If this month is April, and 3 months pass, what month is it? Answer: ',
 'If this month is January, and 7 months pass, what month is it? Answer: ',
 'If this month is April, and 4 months pass, what month is it? Answer: ',
 'If this month is November, and 12 months pass, what month is it? Answer: ',
 'If this month is August, and 1 months pass, what month is it? Answer: ',
 'If this month is 

In [113]:
len(correct_prompts)

43

In [114]:
answers_of_correct_prompts = [final_corrected_answers[i] for i in indices_with_correct_answers]
answers_of_correct_prompts

['November',
 'December',
 'July',
 'March',
 'November',
 'June',
 'September',
 'September',
 'July',
 'August',
 'August',
 'November',
 'September',
 'October',
 'November',
 'May',
 'September',
 'August',
 'February',
 'January',
 'October',
 'May',
 'January',
 'May',
 'November',
 'August',
 'June',
 'April',
 'April',
 'December',
 'August',
 'December',
 'February',
 'September',
 'October',
 'March',
 'January',
 'May',
 'June',
 'November',
 'February',
 'November',
 'July']

In [115]:
len(answers_of_correct_prompts)

43

eval again to double check

In [116]:
# unablated

outputs = []
for clean_text in correct_prompts:
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
    outputs.append(prompt_out)
    print(prompt_out)

['<s> If this month is March, and 8 months pass, what month is it? Answer:  If this month is March and 8 months pass, then it is November']
['<s> If this month is June, and 6 months pass, what month is it? Answer: 6 months after June is December.\n\nIf this month is December,']
['<s> If this month is June, and 1 months pass, what month is it? Answer:  If this month is June and 1 month passes, then it is July']
['<s> If this month is December, and 3 months pass, what month is it? Answer: 3 months after December is March, so the month is March.</s>0']
['<s> If this month is November, and 12 months pass, what month is it? Answer: 12 months later, it is November again.</s>sports betting']
['<s> If this month is December, and 6 months pass, what month is it? Answer: 6 months after December is June, so the month is June.</s>0']
['<s> If this month is December, and 9 months pass, what month is it? Answer: 9 months after December is September, so the answer is September.</s>0']
['<s> If this m

In [96]:
def calculate_correct_answer_percentage(test_strings, correct_answers):
    """
    This function calculates the percentage of correct answers in 'test_strings' where each string
    should contain the corresponding correct answer from 'correct_answers' after "Answer:".

    :param test_strings: List of strings where each string includes "Answer: <some text>"
    :param correct_answers: List of correct answers to check in each string
    :return: Percentage of correct answers
    """
    correct_count = 0
    for test_string, correct_answer in zip(test_strings, correct_answers):
        # Find the position where "Answer:" occurs
        answer_pos = test_string.find("Answer:")
        if answer_pos != -1:
            # Extract the part of the string after "Answer:"
            answer_text = test_string[answer_pos + len("Answer:"):].strip()
            # Check if the correct answer appears in this part of the string
            if correct_answer in answer_text:
                correct_count += 1
            # else:
            #     print(test_string)

    # Calculate the percentage of correct answers
    if len(test_strings) > 0:
        percentage_correct = (correct_count / len(test_strings)) * 100
    else:
        percentage_correct = 0  # Handle the case where there are no test strings
    # print(correct_count)
    return percentage_correct

In [117]:
outputs = [out[0] for out in outputs]

In [118]:
calculate_correct_answer_percentage(outputs, answers_of_correct_prompts)

100.0

In [119]:
# big 3 heads
head_to_remove = [(20,17), (16,0), (5,25)]
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

big3_outputs = []
for clean_text in correct_prompts:
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
    big3_outputs.append(prompt_out[0])
    print(prompt_out[0])

<s> If this month is March, and 8 months pass, what month is it? Answer:  If this month is March and 8 months pass, then 8 months
<s> If this month is June, and 6 months pass, what month is it? Answer: 6 months after June is December.

If this month is December,
<s> If this month is June, and 1 months pass, what month is it? Answer:  If this month is June and 1 month passes, then it is July
<s> If this month is December, and 3 months pass, what month is it? Answer: 3 months after December is March.</s>sports betting, sportsbook
<s> If this month is November, and 12 months pass, what month is it? Answer: 12 months later, it is November again.</s>sports betting
<s> If this month is December, and 6 months pass, what month is it? Answer: 6 months after December is June, so the month is June.</s>0
<s> If this month is December, and 9 months pass, what month is it? Answer: 9 months after December is September, so the month is September.</s>s
<s> If this month is August, and 1 months pass, wh

In [120]:
calculate_correct_answer_percentage(big3_outputs, answers_of_correct_prompts)

62.7906976744186

In [122]:
# random, len 3 (not from saved head combo presets) ; ssave all results

all_outputs_all_runs = []
heads_of_circ = intersect_all
num_heads_rand = 3
num_not_overlap = len(intersect_all)
all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_of_circ] # Filter out heads_of_circ from all_possible_pairs
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "0 0 0"
for i in range(10):
    # Randomly choose pairs ensuring no overlaps with heads_of_circ
    head_to_remove = choose_heads_to_remove(filtered_pairs, heads_of_circ, num_heads_rand, num_not_overlap)
    heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]
    output_for_run = []
    for clean_text in correct_prompts:
        out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 15)
        # print(out[0])
        output_for_run.append(out[0])
    print(calculate_correct_answer_percentage(output_for_run, answers_of_correct_prompts))
    all_outputs_all_runs.append(output_for_run)

79.06976744186046
100.0
93.02325581395348
88.37209302325581
65.11627906976744
90.69767441860465
76.74418604651163
88.37209302325581
95.34883720930233
81.3953488372093


In [124]:
all_scores = []
for output_for_run in all_outputs_all_runs:
    score = calculate_correct_answer_percentage(output_for_run, answers_of_correct_prompts)
    all_scores.append(score)
print(all_scores)
print(sum(all_scores) / len(all_scores))

[79.06976744186046, 100.0, 93.02325581395348, 88.37209302325581, 65.11627906976744, 90.69767441860465, 76.74418604651163, 88.37209302325581, 95.34883720930233, 81.3953488372093]
85.8139534883721


# If today is the Xth of month M, what date will it be in Y days?”

In [None]:
from datetime import datetime, timedelta
import random

def generate_prompts_and_correct_dates(N):
    months = ["January", "February", "March", "April", "May", "June",
              "July", "August", "September", "October", "November", "December"]

    prompts = []
    correct_dates = []

    for _ in range(N):
        month_index = random.randint(0, 11)
        day = random.randint(1, 28)  # to avoid issues with different month lengths
        days_to_add = random.randint(1, 28)
        current_date = datetime(2024, month_index + 1, day)
        future_date = current_date + timedelta(days=days_to_add)
        future_month = months[future_date.month - 1]
        prompt = f"If today is {months[month_index]} {day}th, then in {days_to_add} days it will be "
        correct_date = f"{future_month} {future_date.day}th"

        prompts.append(prompt)
        correct_dates.append(correct_date)

    return prompts, correct_dates

N = 20
prompts, correct_dates = generate_prompts_and_correct_dates(N)

# Printing the results
# print("Prompts:")
# for prompt in prompts:
#     print(prompt)
# print("\nCorrect Answers:")
# for date in correct_dates:
#     print(date)

In [None]:
file_path = '/content/template_1_unablated_correct.txt'
correct_prompts = []
with open(file_path, 'r') as file:
    for line in file:
        correct_prompts.append([line.strip()])
print(correct_prompts)

[['Be concise. If today is July 23th, then in 22 days it will be'], ['Be concise. If today is April 19th, then in 25 days it will be'], ['Be concise. If today is March 21th, then in 16 days it will be'], ['Be concise. If today is June 28th, then in 18 days it will be'], ['Be concise. If today is April 14th, then in 11 days it will be'], ['Be concise. If today is April 7th, then in 20 days it will be'], ['Be concise. If today is October 28th, then in 10 days it will be'], ['Be concise. If today is May 26th, then in 5 days it will be'], ['Be concise. If today is April 17th, then in 28 days it will be'], ['Be concise. If today is September 16th, then in 12 days it will be'], ['Be concise. If today is October 21th, then in 17 days it will be'], ['Be concise. If today is July 12th, then in 23 days it will be'], ['Be concise. If today is January 27th, then in 11 days it will be'], ['Be concise. If today is May 18th, then in 18 days it will be'], ['Be concise. If today is August 18th, then in

In [None]:
# unablated

outputs = []
instruction = "Be concise. "
# for clean_text in correct_prompts:
for clean_text in prompts:
    # clean_text = instruction + clean_text
    # clean_text = clean_text[0]
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    outputs.append(prompt_out)
    print(prompt_out)

['<s> If today is September 24th, then in 14 days it will be  October 8th.']
['<s> If today is April 13th, then in 17 days it will be  May 1st.']
['<s> If today is September 16th, then in 26 days it will be  October 14th']
['<s> If today is March 19th, then in 6 days it will be  March 25th']
['<s> If today is February 1th, then in 20 days it will be  February 20th']
['<s> If today is June 14th, then in 17 days it will be  July 1st.']
['<s> If today is August 14th, then in 6 days it will be  August 20th']
['<s> If today is July 19th, then in 21 days it will be  August 9th.']
['<s> If today is November 27th, then in 3 days it will be  November 30th']
['<s> If today is July 10th, then in 28 days it will be  August 7th.']
['<s> If today is March 25th, then in 16 days it will be  April 10th']
['<s> If today is November 13th, then in 9 days it will be  November 22nd']
['<s> If today is August 18th, then in 27 days it will be  September 15th']
['<s> If today is December 23th, then in 2 days i

In [None]:
from google.colab import files
with open('template_1_unablated_wAns.txt', 'w') as f:
    for line in outputs:
        f.write(f"{line[0]}\n")
files.download('template_1_unablated_wAns.txt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
file_path = '/content/template_1_unablated_wAns.txt'
outputs = []
with open(file_path, 'r') as file:
    for line in file:
        if line != '\n':
            outputs.append(line.replace('\n', ''))
print(outputs)

['<s> If today is September 24th, then in 14 days it will be  October 8th.', '<s> If today is April 13th, then in 17 days it will be  May 1st.', '<s> If today is September 16th, then in 26 days it will be  October 14th', '<s> If today is March 19th, then in 6 days it will be  March 25th', '<s> If today is February 1th, then in 20 days it will be  February 20th', '<s> If today is June 14th, then in 17 days it will be  July 1st.', '<s> If today is August 14th, then in 6 days it will be  August 20th', '<s> If today is July 19th, then in 21 days it will be  August 9th.', '<s> If today is November 27th, then in 3 days it will be  November 30th', '<s> If today is July 10th, then in 28 days it will be  August 7th.', '<s> If today is March 25th, then in 16 days it will be  April 10th', '<s> If today is November 13th, then in 9 days it will be  November 22nd', '<s> If today is August 18th, then in 27 days it will be  September 15th', '<s> If today is December 23th, then in 2 days it will be 25t

In [None]:
outputs[-6]

'<s> If today is January 17th, then in 11 days it will be 2022-'

In [None]:
correct_dates[-6]

'January 28th'

In [None]:
out_ans = outputs[0].split(' ')[-2] + ' ' + outputs[0].split(' ')[-1]
out_ans = out_ans.replace('.','')
out_ans

'October 8th'

In [None]:
correct_dates[0]

'October 8th'

In [None]:
outputs[0].split(' ')[-2]

'October'

In [None]:
len(outputs)

50

In [None]:
def validate_prompt(output, correct_date):
    # try:
    #     date_str = prompt.split(" is ")[1].split(", then in ")[0]
    #     days_to_add = int(prompt.split(" in ")[1].split(" days")[0])
    #     start_date = datetime.strptime(date_str, "%B %dth")
    #     end_date = start_date + timedelta(days=days_to_add)
    #     expected_date = datetime.strptime(correct_date, "%B %dth")
    #     return end_date == expected_date
    # except ValueError:
    #     return False
    out_ans = output.split(' ')[-2] + ' ' + output.split(' ')[-1]
    out_ans = out_ans.replace('.','')
    if out_ans == correct_date:
        return True
    else:
        return False

def get_correct_prompts(outputs, correct_dates):
    corr_prompts = []
    correct_dates_of_correct_prompts = []
    for output, correct_date in zip(outputs, correct_dates):
        out_ans = output.split(' ')[-2] + ' ' + output.split(' ')[-1]
        out_ans = out_ans.replace('.','')
        if out_ans == correct_date:
            corr_prompts.append(output)
            correct_dates_of_correct_prompts.append(correct_date)
    return corr_prompts, correct_dates_of_correct_prompts

# Validate all prompts
# results = [validate_prompt(prompt, correct_date) for prompt, correct_date in zip(outputs, correct_dates)]
correctPrompts, correct_dates_of_correct_prompts = get_correct_prompts(outputs, correct_dates)

# Print the results
# for prompt, is_correct in zip(outputs, results):
#     print(f"Prompt: {prompt} - {'Correct' if is_correct else 'Incorrect'}")

# Calculate the percentage of correct prompts
# percentage_correct = sum(results) / len(results) * 100
percentage_correct = len(correctPrompts) / len(outputs) * 100
print(f"Percentage of correct prompts: {percentage_correct}%")

Percentage of correct prompts: 36.0%


In [None]:
correctPrompts = [out.replace('<s> ', '') for out in correctPrompts]

In [None]:
correctPrompts = [' '.join(out.split(' ')[:-2])[:-1] for out in correctPrompts]

In [None]:
with open('template_1_unablated_correct.txt', 'w') as f:
    for line in correctPrompts:
        f.write(f"{line}\n")
files.download('template_1_unablated_correct.txt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
file_path = '/content/template_1_unablated_correct.txt'
correct_prompts = []
with open(file_path, 'r') as file:
    for line in file:
        correct_prompts.append(line.replace('\n', ''))
print(correct_prompts)

['If today is September 24th, then in 14 days it will be', 'If today is March 19th, then in 6 days it will be', 'If today is August 14th, then in 6 days it will be', 'If today is July 19th, then in 21 days it will be', 'If today is November 27th, then in 3 days it will be', 'If today is July 10th, then in 28 days it will be', 'If today is March 25th, then in 16 days it will be', 'If today is November 18th, then in 28 days it will be', 'If today is July 25th, then in 16 days it will be', 'If today is May 1th, then in 11 days it will be', 'If today is August 12th, then in 15 days it will be', 'If today is May 11th, then in 1 days it will be', 'If today is July 15th, then in 14 days it will be', 'If today is January 18th, then in 2 days it will be', 'If today is July 19th, then in 20 days it will be', 'If today is April 22th, then in 3 days it will be', 'If today is August 23th, then in 23 days it will be', 'If today is June 10th, then in 4 days it will be']


In [None]:
# intersect_all
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
head_to_remove = intersect_all
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

outputs = []
instruction = "Be concise. "
for clean_text in correct_prompts:
    clean_text = instruction + clean_text
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    outputs.append(prompt_out)
    print(prompt_out)

outputs = [' '.join(out[0].replace('<s> ', '').split(' ')[:-2])[:-1] for out in outputs]
with open('template_1_intersectAll.txt', 'w') as f:
    for line in outputs:
        f.write(f"{line}\n")
files.download('template_1_intersectAll.txt')

['<s> Be concise. If today is September 24th, then in 14 days it will be October 6th.']
['<s> Be concise. If today is March 19th, then in 6 days it will be March 24th']
['<s> Be concise. If today is August 14th, then in 6 days it will be August 20th']
['<s> Be concise. If today is July 19th, then in 21 days it will be July 29th']
['<s> Be concise. If today is November 27th, then in 3 days it will be December 1st.']
['<s> Be concise. If today is July 10th, then in 28 days it will be August 10th']
['<s> Be concise. If today is March 25th, then in 16 days it will be April 10th']
['<s> Be concise. If today is November 18th, then in 28 days it will be December 18th']
['<s> Be concise. If today is July 25th, then in 16 days it will be August 15th']
['<s> Be concise. If today is May 1th, then in 11 days it will be May 12th']
['<s> Be concise. If today is August 12th, then in 15 days it will be August 26th']
['<s> Be concise. If today is May 11th, then in 1 days it will be May 12th']
['<s> Be 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# big 3 heads
head_to_remove = [(20,17), (16,0), (5,25)]
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

outputs = []
instruction = "Be concise. "
for clean_text in correct_prompts:
    clean_text = instruction + clean_text
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    outputs.append(prompt_out)
    print(prompt_out)

['<s> Be concise. If today is September 24th, then in 14 days it will be October 8th.']
['<s> Be concise. If today is March 19th, then in 6 days it will be March 25th']
['<s> Be concise. If today is August 14th, then in 6 days it will be August 20th']
['<s> Be concise. If today is July 19th, then in 21 days it will be July 30th']
['<s> Be concise. If today is November 27th, then in 3 days it will be November 30th']
['<s> Be concise. If today is July 10th, then in 28 days it will be July 31st']
['<s> Be concise. If today is March 25th, then in 16 days it will be April 11th']
['<s> Be concise. If today is November 18th, then in 28 days it will be December 18th']
['<s> Be concise. If today is July 25th, then in 16 days it will be August 11th']
['<s> Be concise. If today is May 1th, then in 11 days it will be May 12th']
['<s> Be concise. If today is August 12th, then in 15 days it will be August 27th']
['<s> Be concise. If today is May 11th, then in 1 days it will be May 12th']
['<s> Be co

In [None]:
outputs = [out[0] for out in outputs]
with open('template_1_big3.txt', 'w') as f:
    for line in outputs:
        f.write(f"{line}\n")
files.download('template_1_big3.txt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
file_path = '/content/template_1_big3.txt'
big_3_outputs = []
with open(file_path, 'r') as file:
    for line in file:
        if line != '\n':
            big_3_outputs.append(line.replace('\n', ''))
print(big_3_outputs)

['<s> Be concise. If today is September 24th, then in 14 days it will be October 8th.', '<s> Be concise. If today is March 19th, then in 6 days it will be March 25th', '<s> Be concise. If today is August 14th, then in 6 days it will be August 20th', '<s> Be concise. If today is July 19th, then in 21 days it will be July 30th', '<s> Be concise. If today is November 27th, then in 3 days it will be November 30th', '<s> Be concise. If today is July 10th, then in 28 days it will be July 31st', '<s> Be concise. If today is March 25th, then in 16 days it will be April 11th', '<s> Be concise. If today is November 18th, then in 28 days it will be December 18th', '<s> Be concise. If today is July 25th, then in 16 days it will be August 11th', '<s> Be concise. If today is May 1th, then in 11 days it will be May 12th', '<s> Be concise. If today is August 12th, then in 15 days it will be August 27th', '<s> Be concise. If today is May 11th, then in 1 days it will be May 12th', '<s> Be concise. If to

In [None]:
len(big_3_outputs)

18

In [None]:
len(correct_dates_of_correct_prompts)

18

In [None]:
def get_correct_prompts_only(outputs, correct_dates):
    corr_prompts = []
    # correct_dates_of_correct_prompts = []
    for output, correct_date in zip(outputs, correct_dates):
        out_ans = output.split(' ')[-2] + ' ' + output.split(' ')[-1]
        out_ans = out_ans.replace('.','')
        if out_ans == correct_date:
            corr_prompts.append(output)
            # correct_dates_of_correct_prompts.append(correct_date)
    return corr_prompts

In [None]:
correctPrompts = get_correct_prompts_only(big_3_outputs, correct_dates_of_correct_prompts)
percentage_correct = len(correctPrompts) / len(big_3_outputs) * 100
print(f"Percentage of correct prompts: {percentage_correct}%")

Percentage of correct prompts: 61.111111111111114%


In [None]:
# big 3 heads
head_to_remove = [(20,17), (16,0), (5,25)]
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

outputs = []
instruction = "Be concise. "
for clean_text in correct_prompts:
    clean_text = instruction + clean_text
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    outputs.append(prompt_out)
    print(prompt_out)

In [None]:
# random, len 4 (not from saved head combo presets) ; ssave all results

all_prompt_outputs = []
heads_of_circ = intersect_all
num_heads_rand = 4
num_not_overlap = len(intersect_all)
all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_of_circ] # Filter out heads_of_circ from all_possible_pairs
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "0 0 0"
for clean_text in correct_prompts:
    output_for_a_prompt = []
    for i in range(10):
        # Randomly choose pairs ensuring no overlaps with heads_of_circ
        head_to_remove = choose_heads_to_remove(filtered_pairs, heads_of_circ, num_heads_rand, num_not_overlap)
        heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]
        out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
        # print(out[0])
        output_for_a_prompt.append(out[0])
    print(out)
    all_prompt_outputs.append(output_for_a_prompt)

['<s> If today is September 24th, then in 14 days it will be October 8th.']
['<s> If today is March 19th, then in 6 days it will be March 25th']
['<s> If today is August 14th, then in 6 days it will be August 20th']
['<s> If today is July 19th, then in 21 days it will be August 9th.']
['<s> If today is November 27th, then in 3 days it will be November 30th']
['<s> If today is July 10th, then in 28 days it will be August 7th.']
['<s> If today is March 25th, then in 16 days it will be April 10th']
['<s> If today is November 18th, then in 28 days it will be December 16th']
['<s> If today is July 25th, then in 16 days it will be August 10th']
['<s> If today is May 1th, then in 11 days it will be May 12th']
['<s> If today is August 12th, then in 15 days it will be August 27th']
['<s> If today is May 11th, then in 1 days it will be May 12th']
['<s> If today is July 15th, then in 14 days it will be July 29th']
['<s> If today is January 18th, then in 2 days it will be January 20th']
['<s> If t

In [None]:
import pdb
def num_correct_prompts_rand(list_of_list_outputs, correct_dates):
    all_scores_for_prompts = []
    for i, prompt_outputs in enumerate(list_of_list_outputs):
        num_corr_for_prompt = 0
        correct_date = correct_dates[i]
        for run_output in prompt_outputs:
            # pdb.set_trace()
            run_output = run_output[0]
            out_ans = run_output.split(' ')[-2] + ' ' + run_output.split(' ')[-1]
            # print(out_ans)
            out_ans = out_ans.replace('.','')
            if out_ans == correct_date:
                num_corr_for_prompt += 1
        perc_corr_for_prompt = num_corr_for_prompt / len(prompt_outputs)
        print(run_output, ' : ', perc_corr_for_prompt)
        all_scores_for_prompts.append(perc_corr_for_prompt)
    return sum(all_scores_for_prompts) / len(all_scores_for_prompts)  * 100

In [None]:
percentage_correct = num_correct_prompts_rand(all_prompt_outputs, correct_dates_of_correct_prompts) # randAbl_prompt_outputs
print(f"Percentage of correct prompts: {percentage_correct}%")

<s> If today is September 24th, then in 14 days it will be October 8th.  :  0.9
<s> If today is March 19th, then in 6 days it will be March 25th  :  1.0
<s> If today is August 14th, then in 6 days it will be August 20th  :  1.0
<s> If today is July 19th, then in 21 days it will be August 9th.  :  1.0
<s> If today is November 27th, then in 3 days it will be November 30th  :  0.9
<s> If today is July 10th, then in 28 days it will be August 7th.  :  1.0
<s> If today is March 25th, then in 16 days it will be April 10th  :  1.0
<s> If today is November 18th, then in 28 days it will be December 16th  :  1.0
<s> If today is July 25th, then in 16 days it will be August 10th  :  1.0
<s> If today is May 1th, then in 11 days it will be May 12th  :  1.0
<s> If today is August 12th, then in 15 days it will be August 27th  :  1.0
<s> If today is May 11th, then in 1 days it will be May 12th  :  1.0
<s> If today is July 15th, then in 14 days it will be July 29th  :  1.0
<s> If today is January 18th, t

In [None]:
with open('template_1_rand.txt', 'w') as f:
    for line in all_prompt_outputs:
        f.write(f"{line}\n")
files.download('template_1_rand.txt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# (more data, runs) If today is the Xth of month M, what date will it be in Y days?”

In [None]:
from datetime import datetime, timedelta
import random

def generate_prompts_and_correct_dates(N):
    months = ["January", "February", "March", "April", "May", "June",
              "July", "August", "September", "October", "November", "December"]

    prompts = []
    correct_dates = []

    for _ in range(N):
        month_index = random.randint(0, 11)
        day = random.randint(1, 28)  # to avoid issues with different month lengths
        days_to_add = random.randint(1, 28)
        current_date = datetime(2024, month_index + 1, day)
        future_date = current_date + timedelta(days=days_to_add)
        future_month = months[future_date.month - 1]
        prompt = f"If today is {months[month_index]} {day}th, then in {days_to_add} days it will be "
        correct_date = f"{future_month} {future_date.day}th"

        prompts.append(prompt)
        correct_dates.append(correct_date)

    return prompts, correct_dates

N = 100
prompts, correct_dates = generate_prompts_and_correct_dates(N)

# Printing the results
# print("Prompts:")
# for prompt in prompts:
#     print(prompt)
# print("\nCorrect Answers:")
# for date in correct_dates:
#     print(date)

In [None]:
# unablated

outputs = []
instruction = "Be concise. "
# for clean_text in correct_prompts:
for clean_text in prompts:
    # clean_text = instruction + clean_text
    # clean_text = clean_text[0]
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    outputs.append(prompt_out[0])
    print(prompt_out)

['<s> If today is April 26th, then in 21 days it will be  May 16th']
['<s> If today is February 23th, then in 26 days it will be  March 19th']
['<s> If today is April 12th, then in 15 days it will be  April 27th']
['<s> If today is August 15th, then in 15 days it will be  August 30th']
['<s> If today is October 20th, then in 26 days it will be  November 15th']
['<s> If today is July 20th, then in 9 days it will be 29th (2']
['<s> If today is April 15th, then in 11 days it will be  April 26th']
['<s> If today is December 7th, then in 14 days it will be  December 21st']
['<s> If today is February 7th, then in 25 days it will be  March 7th.']
['<s> If today is August 19th, then in 18 days it will be  September 8th.']
['<s> If today is June 6th, then in 1 days it will be  June 7th.']
['<s> If today is May 6th, then in 1 days it will be  May 7th.']
['<s> If today is May 26th, then in 16 days it will be  June 12th']
['<s> If today is August 2th, then in 13 days it will be  August 15th']
['<s

In [None]:
from google.colab import files

In [None]:
# with open('template_1_unablated_wAns.txt', 'w') as f:
#     for line in outputs:
#         f.write(f"{line[0]}\n")
# files.download('template_1_unablated_wAns.txt')

# with open('template_1_unablated_wAns.pkl', 'wb') as file:
#     pickle.dump(outputs, file)
#     files.download('template_1_unablated_wAns.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
file_path = '/content/template_1_unablated_wAns.pkl'
with open(file_path, 'rb') as file:
    outputs = pickle.load(file)

In [None]:
outputs = [out[0] for out in outputs]

In [None]:
len(outputs)

100

In [None]:
out_ans = outputs[0].split(' ')[-2] + ' ' + outputs[0].split(' ')[-1]
out_ans = out_ans.replace('.','')
out_ans

'May 16th'

In [None]:
correct_dates[0]

'May 17th'

In [None]:
outputs[0].split(' ')[-2]

'May'

In [None]:
def get_correct_prompts(outputs, correct_dates):
    corr_prompts = []
    correct_dates_of_correct_prompts = []
    for output, correct_date in zip(outputs, correct_dates):
        out_ans = output.split(' ')[-2] + ' ' + output.split(' ')[-1]
        out_ans = out_ans.replace('.','')
        if out_ans == correct_date:
            corr_prompts.append(output)
            correct_dates_of_correct_prompts.append(correct_date)
    return corr_prompts, correct_dates_of_correct_prompts

Percentage of correct prompts: 41.0%


In [None]:
correctPrompts, correct_dates_of_correct_prompts = get_correct_prompts(outputs, correct_dates)
percentage_correct = len(correctPrompts) / len(outputs) * 100
print(f"Percentage of correct prompts: {percentage_correct}%")

Percentage of correct prompts: 41.0%


In [None]:
correct_prompts = [out.replace('<s> ', '') for out in correctPrompts]
# correct_prompts = [' '.join(out.split(' ')[:-2])[:-1] for out in correctPrompts] # keep space at end
correct_prompts = [' '.join(out.split(' ')[:-2]) for out in correct_prompts]

In [None]:
with open('template_1_unablated_correct.txt', 'w') as f:
    for line in correct_prompts:
        f.write(f"{line}\n")
files.download('template_1_unablated_correct.txt')

with open('template_1_unablated_correct.pkl', 'wb') as file:
    pickle.dump(correct_prompts, file)
    files.download('template_1_unablated_correct.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
with open('correct_dates_of_correct_prompts.pkl', 'wb') as file:
    pickle.dump(correct_dates_of_correct_prompts, file)
    files.download('correct_dates_of_correct_prompts.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Double check

In [None]:
# unablated

outputs = []
# instruction = "Be concise. "
for clean_text in correct_prompts:
    # clean_text = instruction + clean_text
    # clean_text = clean_text[0]
    corr_text = "uno uno uno" # dos tres cinco seis
    heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
    mlps_not_ablate = [layer for layer in range(32)]
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    outputs.append(prompt_out[0])
    print(prompt_out)

['<s> If today is April 12th, then in 15 days it will be  April 27th']
['<s> If today is August 15th, then in 15 days it will be  August 30th']
['<s> If today is October 20th, then in 26 days it will be  November 15th']
['<s> If today is April 15th, then in 11 days it will be  April 26th']
['<s> If today is June 6th, then in 1 days it will be  June 7th.']
['<s> If today is May 6th, then in 1 days it will be  May 7th.']
['<s> If today is August 2th, then in 13 days it will be  August 15th']
['<s> If today is September 20th, then in 8 days it will be  September 28th']
['<s> If today is November 16th, then in 28 days it will be  December 14th']
['<s> If today is December 16th, then in 20 days it will be  January 5th.']
['<s> If today is March 22th, then in 21 days it will be  April 12th']
['<s> If today is June 4th, then in 5 days it will be  June 9th.']
['<s> If today is July 4th, then in 13 days it will be  July 17th']
['<s> If today is November 17th, then in 18 days it will be  Decembe

In [None]:
percentage_correct = get_correct_prompts_only(outputs, correct_dates_of_correct_prompts)

Percentage of correct prompts: 100.0%


In [None]:
# big 3 heads
head_to_remove = [(20,17), (16,0), (5,25)]
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

big_3_outputs = []
# instruction = "Be concise. "
for clean_text in correct_prompts:
    # clean_text = instruction + clean_text
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    big_3_outputs.append(prompt_out[0])
    print(prompt_out)

['<s> If today is April 12th, then in 15 days it will be  April 27th']
['<s> If today is August 15th, then in 15 days it will be  August 30th']
['<s> If today is October 20th, then in 26 days it will be  November 10th']
['<s> If today is April 15th, then in 11 days it will be  April 26th']
['<s> If today is June 6th, then in 1 days it will be  June 7th.']
['<s> If today is May 6th, then in 1 days it will be  May 7th.']
['<s> If today is August 2th, then in 13 days it will be  August 14th']
['<s> If today is September 20th, then in 8 days it will be  September 28th']
['<s> If today is November 16th, then in 28 days it will be  December 14th']
['<s> If today is December 16th, then in 20 days it will be 2022-']
['<s> If today is March 22th, then in 21 days it will be  April 12th']
['<s> If today is June 4th, then in 5 days it will be  June 9th.']
['<s> If today is July 4th, then in 13 days it will be  July 17th']
['<s> If today is November 17th, then in 18 days it will be  December 1st.']

In [None]:
def get_correct_prompts_only(outputs, correct_dates):
    num_corr_prompts = 0
    for output, correct_date in zip(outputs, correct_dates):
        out_ans = output.split(' ')[-2] + ' ' + output.split(' ')[-1]
        out_ans = out_ans.replace('.','')
        if out_ans == correct_date:
            num_corr_prompts += 1
    percentage_correct = num_corr_prompts / len(outputs) * 100
    print(f"Percentage of correct prompts: {percentage_correct}%")
    return percentage_correct

In [None]:
percentage_correct = get_correct_prompts_only(big_3_outputs, correct_dates_of_correct_prompts)

Percentage of correct prompts: 75.60975609756098%


In [None]:
with open('template_1_big3.txt', 'w') as f:
    for line in big_3_outputs:
        f.write(f"{line}\n")
files.download('template_1_big3.txt')

with open('template_1_big3.pkl', 'wb') as file:
    pickle.dump(big_3_outputs, file)
    files.download('template_1_big3.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# big 5 heads
head_to_remove = [(20,17), (16,0), (5,25), (6,11), (11,18)]
heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "uno uno uno" # dos tres cinco seis

big5_outputs = []
# instruction = "Be concise. "
for i, clean_text in enumerate(correct_prompts):
    # clean_text = instruction + clean_text
    prompt_out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
    big5_outputs.append(prompt_out[0])
    print(i, prompt_out)

0 ['<s> If today is April 12th, then in 15 days it will be  April 27th']
1 ['<s> If today is August 15th, then in 15 days it will be  August 30th']
2 ['<s> If today is October 20th, then in 26 days it will be  November 10th']
3 ['<s> If today is April 15th, then in 11 days it will be 2023-']
4 ['<s> If today is June 6th, then in 1 days it will be  June 7th.']
5 ['<s> If today is May 6th, then in 1 days it will be  May 7th.']
6 ['<s> If today is August 2th, then in 13 days it will be  August 13th']
7 ['<s> If today is September 20th, then in 8 days it will be  September 28th']
8 ['<s> If today is November 16th, then in 28 days it will be  December 14th']
9 ['<s> If today is December 16th, then in 20 days it will be 2022-']
10 ['<s> If today is March 22th, then in 21 days it will be  April 12th']
11 ['<s> If today is June 4th, then in 5 days it will be  June 9th.']
12 ['<s> If today is July 4th, then in 13 days it will be  July 17th']
13 ['<s> If today is November 17th, then in 18 days i

In [None]:
percentage_correct = get_correct_prompts_only(big5_outputs, correct_dates_of_correct_prompts)

Percentage of correct prompts: 68.29268292682927%


In [None]:
with open('template_1_big5.txt', 'w') as f:
    for line in big5_outputs:
        f.write(f"{line}\n")
files.download('template_1_big5.txt')

with open('template_1_big5.pkl', 'wb') as file:
    pickle.dump(big5_outputs, file)
    files.download('template_1_big5.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# random, len 3 (not from saved head combo presets) ; ssave all results

all_prompt_outputs = []
heads_of_circ = intersect_all
num_heads_rand = 3
num_not_overlap = len(intersect_all)
all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_of_circ] # Filter out heads_of_circ from all_possible_pairs
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "0 0 0"
for p, clean_text in enumerate(correct_prompts):
    output_for_a_prompt = []
    for i in range(10):
        # Randomly choose pairs ensuring no overlaps with heads_of_circ
        head_to_remove = choose_heads_to_remove(filtered_pairs, heads_of_circ, num_heads_rand, num_not_overlap)
        heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]
        out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
        # print(out[0])
        output_for_a_prompt.append(out[0])
    print(p, out)
    all_prompt_outputs.append(output_for_a_prompt)

0 ['<s> If today is April 12th, then in 15 days it will be  April 27th']
1 ['<s> If today is August 15th, then in 15 days it will be  August 30th']
2 ['<s> If today is October 20th, then in 26 days it will be  November 15th']
3 ['<s> If today is April 15th, then in 11 days it will be  April 26th']
4 ['<s> If today is June 6th, then in 1 days it will be  June 7th.']
5 ['<s> If today is May 6th, then in 1 days it will be  May 7th.']
6 ['<s> If today is August 2th, then in 13 days it will be 15th August.']
7 ['<s> If today is September 20th, then in 8 days it will be  September 28th']
8 ['<s> If today is November 16th, then in 28 days it will be  December 14th']
9 ['<s> If today is December 16th, then in 20 days it will be  January 5th.']
10 ['<s> If today is March 22th, then in 21 days it will be  April 12th']
11 ['<s> If today is June 4th, then in 5 days it will be  June 9th.']
12 ['<s> If today is July 4th, then in 13 days it will be  July 17th']
13 ['<s> If today is November 17th, the

In [None]:
import pdb
def num_correct_prompts_rand(list_of_list_outputs, correct_dates):
    all_scores_for_prompts = []
    for i, prompt_outputs in enumerate(list_of_list_outputs):
        num_corr_for_prompt = 0
        correct_date = correct_dates[i]
        for run_output in prompt_outputs:
            # pdb.set_trace()
            # run_output = run_output[0]
            out_ans = run_output.split(' ')[-2] + ' ' + run_output.split(' ')[-1]
            # print(out_ans)
            out_ans = out_ans.replace('.','')
            if out_ans == correct_date:
                num_corr_for_prompt += 1
        perc_corr_for_prompt = num_corr_for_prompt / len(prompt_outputs)
        print(run_output, ' : ', perc_corr_for_prompt)
        all_scores_for_prompts.append(perc_corr_for_prompt)
    return sum(all_scores_for_prompts) / len(all_scores_for_prompts)  * 100

In [None]:
percentage_correct = num_correct_prompts_rand(all_prompt_outputs, correct_dates_of_correct_prompts) # randAbl_prompt_outputs
print(f"Percentage of correct prompts: {percentage_correct}%")

<s> If today is April 12th, then in 15 days it will be  April 27th  :  1.0
<s> If today is August 15th, then in 15 days it will be  August 30th  :  0.6
<s> If today is October 20th, then in 26 days it will be  November 15th  :  0.8
<s> If today is April 15th, then in 11 days it will be  April 26th  :  1.0
<s> If today is June 6th, then in 1 days it will be  June 7th.  :  0.9
<s> If today is May 6th, then in 1 days it will be  May 7th.  :  1.0
<s> If today is August 2th, then in 13 days it will be 15th August.  :  0.8
<s> If today is September 20th, then in 8 days it will be  September 28th  :  1.0
<s> If today is November 16th, then in 28 days it will be  December 14th  :  1.0
<s> If today is December 16th, then in 20 days it will be  January 5th.  :  1.0
<s> If today is March 22th, then in 21 days it will be  April 12th  :  1.0
<s> If today is June 4th, then in 5 days it will be  June 9th.  :  1.0
<s> If today is July 4th, then in 13 days it will be  July 17th  :  1.0
<s> If today is 

In [None]:
with open('template_1_rand.txt', 'w') as f:
    for line in all_prompt_outputs:
        f.write(f"{line}\n")
files.download('template_1_rand.txt')

with open('template_1_rand.pkl', 'wb') as file:
    pickle.dump(all_prompt_outputs, file)
    files.download('template_1_rand.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# random, len 3 (not from saved head combo presets) ; ssave all results

all_prompt_outputs = []
heads_of_circ = intersect_all
num_heads_rand = 3
num_not_overlap = len(intersect_all)
all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_of_circ] # Filter out heads_of_circ from all_possible_pairs
mlps_not_ablate = [layer for layer in range(32)]
corr_text = "0 0 0"
for p, clean_text in enumerate(correct_prompts):
    output_for_a_prompt = []
    for i in range(50):
        # Randomly choose pairs ensuring no overlaps with heads_of_circ
        head_to_remove = choose_heads_to_remove(filtered_pairs, heads_of_circ, num_heads_rand, num_not_overlap)
        heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]
        out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, 5)
        # print(out[0])
        output_for_a_prompt.append(out[0])
    print(p, out)
    all_prompt_outputs.append(output_for_a_prompt)

0 ['<s> If today is April 12th, then in 15 days it will be  April 27th']
1 ['<s> If today is August 15th, then in 15 days it will be  September 1st.']
2 ['<s> If today is October 20th, then in 26 days it will be  November 15th']
3 ['<s> If today is April 15th, then in 11 days it will be  April 26th']
4 ['<s> If today is June 6th, then in 1 days it will be 7th (today).']
5 ['<s> If today is May 6th, then in 1 days it will be  May 7th.']
6 ['<s> If today is August 2th, then in 13 days it will be  August 15th']
7 ['<s> If today is September 20th, then in 8 days it will be  September 28th']
8 ['<s> If today is November 16th, then in 28 days it will be  December 14th']
9 ['<s> If today is December 16th, then in 20 days it will be  January 5th.']
10 ['<s> If today is March 22th, then in 21 days it will be  April 12th']
11 ['<s> If today is June 4th, then in 5 days it will be  June 9th.']
12 ['<s> If today is July 4th, then in 13 days it will be  July 17th']
13 ['<s> If today is November 17th

KeyboardInterrupt: 

In [None]:
percentage_correct = num_correct_prompts_rand(all_prompt_outputs[:-1], correct_dates_of_correct_prompts) # randAbl_prompt_outputs
print(f"Percentage of correct prompts: {percentage_correct}%")

<s> If today is April 12th, then in 15 days it will be  April 27th  :  1.0
<s> If today is August 15th, then in 15 days it will be  September 1st.  :  0.66
<s> If today is October 20th, then in 26 days it will be  November 15th  :  0.84
<s> If today is April 15th, then in 11 days it will be  April 26th  :  0.74
<s> If today is June 6th, then in 1 days it will be 7th (today).  :  0.82
<s> If today is May 6th, then in 1 days it will be  May 7th.  :  1.0
<s> If today is August 2th, then in 13 days it will be  August 15th  :  0.88
<s> If today is September 20th, then in 8 days it will be  September 28th  :  0.98
<s> If today is November 16th, then in 28 days it will be  December 14th  :  0.96
<s> If today is December 16th, then in 20 days it will be  January 5th.  :  0.94
<s> If today is March 22th, then in 21 days it will be  April 12th  :  1.0
<s> If today is June 4th, then in 5 days it will be  June 9th.  :  1.0
<s> If today is July 4th, then in 13 days it will be  July 17th  :  0.98
Pe

# What are the months in a year?

In [None]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    for i in range(corr_ans_tokLen):
        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == corr_ans_tokLen - 1:
        #     print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    return model.to_string(tokens)

In [None]:
clean_text = "What are the months in a year? Give all of them as a list. Be concise."
corr_text = "5 3 9"
num_toks_gen = 50

In [None]:
output = []
for i in range(50):
    heads_of_circ = nums_1to9
    num_heads_rand = 86
    num_not_overlap = len(nums_1to9)

    all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
    # Filter out heads_of_circ from all_possible_pairs
    filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_of_circ]

    # Randomly choose 100 pairs ensuring less than 50 overlaps with heads_of_circ
    head_to_remove = choose_heads_to_remove(filtered_pairs, heads_of_circ, 86, num_not_overlap)

    heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]

    mlps_not_ablate = [layer for layer in range(32)]

    out = ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, num_toks_gen)
    print(out)
    output.append(out)

['<s> What are the months in a year? Give all of them as a list. Be concise.\nThe months in a year are:\n\n1. January\n2. February\n3. March\n4. April\n5. May\n6. June\n7. July\n8. August\n9. September\n10. October']
['<s> What are the months in a year? Give all of them as a list. Be concise.\n\nAnswer:\n\nThere are 12 months in a year. Here they are in a list:\n\n1. January\n2. February\n3. March\n4. April\n5. May\n6. June\n7']
['<s> What are the months in a year? Give all of them as a list. Be concise.\n\nAnswer:\nThe months of a year are:\n\n1. January\n2. February\n3. March\n4. April\n5. May\n6. June\n7. July\n8. August\n9. September\n']
['<s> What are the months in a year? Give all of them as a list. Be concise.\n\nThe months in a year are:\n\n1. January\n2. February\n3. March\n4. April\n5. May\n6. June\n7. July\n8. August\n9. September\n10.']
['<s> What are the months in a year? Give all of them as a list. Be concise.\n\nThe months in a year are:\n\n1. January\n2. February\n3. Ma