# Setup

In [1]:
save_files = True

In [2]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pickle
from google.colab import files

import matplotlib.pyplot as plt
import statistics

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer #, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
import pdb

## Import functions from repo

In [7]:
!git clone https://github.com/apartresearch/seqcont_circuits.git
%cd /content/seqcont_circuits/src/iter_node_pruning

Cloning into 'seqcont_circuits'...
remote: Enumerating objects: 1022, done.[K
remote: Counting objects: 100% (488/488), done.[K
remote: Compressing objects: 100% (287/287), done.[K
remote: Total 1022 (delta 296), reused 378 (delta 190), pack-reused 534[K
Receiving objects: 100% (1022/1022), 18.76 MiB | 14.38 MiB/s, done.
Resolving deltas: 100% (659/659), done.
/content/seqcont_circuits/src/iter_node_pruning


In [8]:
## comment this out when debugging functions in colab to use funcs defined in colab

# don't improt this
# # from dataset import Dataset

from metrics import *
from head_ablation_fns import *
from mlp_ablation_fns import *
from node_ablation_fns import *
from loop_node_ablation_fns import *

## fns

In [9]:
import random


In [10]:
class Dataset:
    def __init__(self, prompts, pos_dict, tokenizer):  # , S1_is_first=False
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.corr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            # self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
            self.tokenizer.encode(prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a dict whose values are tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        # for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'corr' and key != 'incorr')]:
        for targ in [key for key in pos_dict]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = self.tokenizer.tokenize(input_text)
                # if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                #     target_token = prompt[targ]
                # else:
                #     target_token = "Ġ" + prompt[targ]
                # target_index = tokens.index(target_token)
                target_index = pos_dict[targ]
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

In [11]:
def generate_prompts_list_longer(text, tokens):
    prompts_list = []
    prompt_dict = {
        'corr': str(1),
        'incorr': str(2),
        'text': text
        # 'text': model.to_string(tokens)[0]
        }
    tokens_as_strs = model.tokenizer.tokenize(text)
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i in range(tokens.shape[1]):
    for i, tok in enumerate(tokens_as_strs):
        prompt_dict['S'+str(i)] = tok
    # for i, tok in enumerate(tokens):
    #     prompt_dict['S'+str(i)] = model.to_string(tok)

    # prompt_dict = {
    #     'corr': '4',
    #     'incorr': '3',
    #     'text': model.to_string(tokens)[0]
    # }
    # # list_tokens = tokenizer.tokenize('1 2 3 ')
    # tokens_as_strs = model.to_string(tokens)[0].split()
    # for i, tok_as_str in enumerate(tokens_as_strs):
    #     if tok_as_str == '▁':
    #         prompt_dict['S'+str(i)] = ' '
    #     else:
    #         prompt_dict['S'+str(i)] = tok_as_str
    prompts_list.append(prompt_dict)
    return prompts_list

# Load Model

In [12]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [13]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [14]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)
# tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH, use_fast= False, add_prefix_space= False)
hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [15]:
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

In [16]:
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    hf_model = hf_model,
    tokenizer = tokenizer,
    device = "cpu",
    fold_ln = False,
    center_writing_weights = False,
    center_unembed = False,
)

del hf_model

model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda


# new ablation functions

In [17]:
def get_heads_actv_mean(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Output: The mean activations of a head's output
    '''
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )
    n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)

    # for layer in range(model.cfg.n_layers):
    #     z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
    #     for template_group in means_dataset.groups:
    #         z_for_this_template = z_for_this_layer[template_group]
    #         z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
    #         if z_means_for_this_template.shape[0] == 5:
    #             pdb.set_trace()
    #         means[layer, template_group] = z_means_for_this_template

    del(means_cache)

    return means

In [18]:
# def mask_circ_heads(
#     means_dataset: Dataset,
#     model: HookedTransformer,
#     circuit: Dict[str, List[Tuple[int, int]]],
#     seq_pos_to_keep: Dict[str, str],
# ) -> Dict[int, Bool[Tensor, "batch seq head"]]:
#     '''
#     Output: for each layer, a mask of circuit components that should not be ablated
#     '''
#     heads_and_posns_to_keep = {}
#     batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

#     for layer in range(model.cfg.n_layers):

#         mask = t.zeros(size=(batch, seq, n_heads))

#         for (head_type, head_list) in circuit.items():
#             seq_pos = seq_pos_to_keep[head_type]
#             # if seq_pos == 'S7':
#             #     pdb.set_trace()
#             indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
#             for (layer_idx, head_idx) in head_list:
#                 if layer_idx == layer:
#                     # if indices.item() == 7:
#                     #     pdb.set_trace()
#                     mask[:, indices, head_idx] = 1
#                     # mask[:, :, head_idx] = 1  # keep L.H at all pos

#         heads_and_posns_to_keep[layer] = mask.bool()
#     # pdb.set_trace()
#     return heads_and_posns_to_keep

In [19]:
def mask_circ_heads(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Output: for each layer, a mask of circuit components that should not be ablated
    '''
    heads_and_posns_to_keep = {}
    # batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    batch, seq, n_heads = len(means_dataset), len(circuit.keys()), model.cfg.n_heads
    # print(seq)

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # mask[:, indices, head_idx] = 1
                    mask[:, :, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

In [20]:
def hook_func_mask_head(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    # components_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    # means: Float[Tensor, "layer batch seq head d_head"],
    circuit: Dict[str, List[Tuple[int, int]]],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Use this to not mask components
    '''
    # mask_for_this_layer = components_to_keep[hook.layer()].unsqueeze(-1).to(z.device)
    # z = t.where(mask_for_this_layer, z, means[hook.layer()])

    ###
    # heads_and_posns_to_keep = {}
    # batch, seq, n_heads = z.shape[0], z.shape[1], model.cfg.n_heads  # components_to_keep[0].shape[0] is batch

    # for layer in range(model.cfg.n_layers):

    #     mask = t.zeros(size=(batch, seq, n_heads))

    #     for (head_type, head_list) in circuit.items():
    #         # seq_pos = seq_pos_to_keep[head_type]
    #         # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
    #         for (layer_idx, head_idx) in head_list:
    #             if layer_idx == layer:
    #                 # mask[:, indices, head_idx] = 1
    #                 mask[:, :, head_idx] = 1

    #     heads_and_posns_to_keep[layer] = mask.bool()
    ###
    mask_for_this_layer = t.zeros(size=(z.shape[0], z.shape[1], z.shape[2]))
    for (head_type, head_list) in circuit.items():
        # seq_pos = seq_pos_to_keep[head_type]
        # indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
        for (layer_idx, head_idx) in head_list:
            if layer_idx == hook.layer():
                # mask[:, indices, head_idx] = 1
                mask_for_this_layer[:, :, head_idx] = 1

    mask_for_this_layer = mask_for_this_layer.bool()
    mask_for_this_layer = mask_for_this_layer.unsqueeze(-1).to(z.device)  # d_model is 1; then is broadcast in where

    z = t.where(mask_for_this_layer, z, 0)

    return z

In [21]:
def add_ablation_hook_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Ablate the model, except as components and positions to keep
    '''

    model.reset_hooks(including_permanent=True)
    means = get_heads_actv_mean(means_dataset, model)
    components_to_keep = mask_circ_heads(means_dataset, model, circuit, seq_pos_to_keep)

    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=circuit,
    )

    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)
    return model

In [22]:
# from dataset import Dataset
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
import einops
from functools import partial
import torch as t
from torch import Tensor
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool

# from head_ablation_fns import *
# from mlp_ablation_fns import *

def add_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    CIRCUIT = {}
    SEQ_POS_TO_KEEP = {}
    # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    for i in range(num_pos ):
        CIRCUIT['S'+str(i)] = heads_lst
        # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
        # if i == num_pos - 1:
        #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
        # else:
        SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    model.reset_hooks(including_permanent=True)

    # Compute the mean of each head's output on the ABC dataset, grouped by template
    means = get_heads_actv_mean(means_dataset, model)
    # Convert this into a boolean map
    components_to_keep = mask_circ_heads(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # Get a hook function which will patch in the mean z values for each head, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_func_mask_head,
        # components_to_keep=components_to_keep,
        # means=means,
        circuit=CIRCUIT,
    )

    # Apply hook
    model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

    # if all_entries_true(components_to_keep) == False:
    #     pdb.set_trace()
    ########################
    # CIRCUIT = {}
    # SEQ_POS_TO_KEEP = {}
    # # for i in range(len(model.tokenizer.tokenize(means_dataset.prompts[0]['text']))):
    # num_pos = len(model.tokenizer(means_dataset.prompts[0]['text']).input_ids)
    # for i in range(num_pos ):
    #     CIRCUIT['S'+str(i)] = mlp_lst
    #     # if i == len(model.tokenizer.tokenize(means_dataset.prompts[0]['text'])) - 1:
    #     # if i == num_pos - 1:
    #     #     SEQ_POS_TO_KEEP['S'+str(i)] = 'end'
    #     # else:
    #     SEQ_POS_TO_KEEP['S'+str(i)] = 'S'+str(i)

    # # Compute the mean of each head's output on the ABC dataset, grouped by template
    # means = get_MLPs_actv_mean(means_dataset, model)

    # # Convert this into a boolean map
    # components_to_keep = mask_circ_MLPs(means_dataset, model, CIRCUIT, SEQ_POS_TO_KEEP)

    # # Get a hook function which will patch in the mean z values for each head, at
    # # all positions which aren't important for the circuit
    # hook_fn = partial(
    #     hook_func_mask_mlp_out,
    #     components_to_keep=components_to_keep,
    #     means=means
    # )

    # model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

    return model

In [23]:
def all_entries_true(tensor_dict):
    for key, tensor in tensor_dict.items():
        if not torch.all(tensor).item():
            return False
    return True

# ablation fns mult tok answers

In [64]:
def clean_gen(model, clean_text, corr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(10):
        # print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        corr_logits = logits[:, -1, next_token]
        total_score += corr_logits
        # print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            # print('\nTotal logit diff: ', total_score.item())
            break

        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        # if next_char == '':
        #     next_char = ' '
        # clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
    # return corr_ans_tokLen, total_score
    return total_score.item()

In [65]:
def ablate_then_gen(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, corr_ans_tokLen):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    logits = model(tokens)
    next_token = logits[0, -1].argmax(dim=-1)
    next_char = model.to_string(next_token)

    total_score = 0

    for i in range(corr_ans_tokLen):
        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        if i == corr_ans_tokLen - 1:
            print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())

In [66]:
# Function to randomly choose 50 pairs ensuring less than 10 overlap with heads_of_circ
def choose_heads_to_remove(filtered_pairs, heads_of_circ, num_pairs=50, max_overlap=10):
    while True:
        head_to_remove = random.sample(filtered_pairs, num_pairs)
        overlap_count = len([head for head in head_to_remove if head in heads_of_circ])
        if overlap_count < max_overlap:
            return head_to_remove

In [67]:
answer_str = '10'
ans_str_tok = tokenizer.tokenize(answer_str)[1:]
ans_str_tok

['1', '0']

In [68]:
corr_tokenIDs = []
for ansPos in range(len(ans_str_tok)):
    # ansPos_corrTokIDS = [] # this is the inner list. each member is a promptID
    # for promptID in range(len(self.prompts)):
    #     tokID = self.tokenizer.encode(self.prompts[promptID]['corr'][ansPos])[2:][0] # 2: to skip padding <s> and ''
    #     ansPos_corrTokIDS.append(tokID)
    # self.corr_tokenIDs.append(ansPos_corrTokIDS)

    tokID = model.tokenizer.encode(ans_str_tok[ansPos])[2:][0] # 2: to skip padding <s> and ''
    corr_tokenIDs.append(tokID)
corr_tokenIDs

[29896, 29900]

In [69]:
def ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    # logits = model(tokens)
    # next_token = logits[0, -1].argmax(dim=-1)
    # next_char = model.to_string(next_token)

    total_score = 0
    ans_so_far = ''
    ans_str_tok = tokenizer.tokenize(correct_ans)[1:] # correct_ans is str
    corr_tokenIDs = []
    for correct_ansPos in range(len(ans_str_tok)):
        tokID = model.tokenizer.encode(ans_str_tok[correct_ansPos])[2:][0] # 2: to skip padding <s> and ''
        corr_tokenIDs.append(tokID)
    correct_ans_tokLen = len(corr_tokenIDs)
    for ansPos in range(correct_ans_tokLen):
        # if next_char == '':
        #     next_char = ' '

        # clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
        #     print(model.to_string(tokens))
        #     # print(f"Sequence so far: {clean_text}")
        #     # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        # tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
            # print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        ans_so_far += next_char
        correct_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        ansTok_IDs = torch.tensor(corr_tokenIDs[ansPos])

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # corrTok_logits = logits[:, -1, next_token]
        corrTok_logits = logits[range(logits.size(0)), -1, ansTok_IDs]  # not next_token, as that's what's pred, not the token to measure
        # pdb.set_trace()
        total_score += corrTok_logits
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    return ans_so_far, total_score.item()

# multi-promt auto score fns

In [70]:
def clean_autoScore(model, sequences_as_str, next_members):
    list_outputs = []
    total_orig_logits = 0
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        orig_score = clean_gen(model, clean_text, correct_ans)

        total_orig_logits += orig_score
    return total_orig_logits

In [71]:
def ablate_circ_autoScore(model, circuit, sequences_as_str, next_members):
    corr_text = "5 3 9"
    list_outputs = []
    score = 0
    total_orig_logits = 0
    total_abl_logits = 0
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        # correct_ans_tokLen, orig_score = clean_gen(model, clean_text, correct_ans)
        # orig_score = clean_gen(model, clean_text, correct_ans)

        heads_not_ablate = [(layer, head) for layer in range(32) for head in range(32)]  # unablated
        head_to_remove = circuit
        heads_not_ablate = [x for x in heads_not_ablate if (x not in head_to_remove)]

        mlps_not_ablate = [layer for layer in range(32)]

        output_after_ablate, ablated_score = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans)
        list_outputs.append(output_after_ablate)
        # total_orig_logits += orig_score
        total_abl_logits += ablated_score
        # print(correct_ans, output_after_ablate)
        # print(orig_score, ablated_score)
        # print('logit ratio: ', ablated_score / orig_score)
        if correct_ans == output_after_ablate:
            score += 1
    perc_score = score / len(next_members)
    return perc_score, list_outputs , total_abl_logits # / total_orig_logits

In [81]:
def ablate_randcirc_autoScore(model, sequences_as_str, next_members, num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap):
    corr_text = "5 3 9"
    list_outputs = []
    all_scores = []
    for clean_text, correct_ans in zip(sequences_as_str, next_members):
        prompt_score = 0
        correct_ans_tokLen = clean_gen(model, clean_text, correct_ans)
        for j in range(num_rand_runs):
            all_possible_pairs =  [(layer, head) for layer in range(32) for head in range(32)]
            filtered_pairs = [pair for pair in all_possible_pairs if pair not in heads_not_overlap] # Filter out heads_not_overlap from all_possible_pairs

            # Randomly choose num_heads_rand pairs ensuring less than num_not_overlap overlaps with heads_not_overlap
            head_to_remove = choose_heads_to_remove(filtered_pairs, heads_not_overlap, num_heads_rand, num_not_overlap)

            heads_not_ablate = [x for x in all_possible_pairs if x not in head_to_remove]

            mlps_not_ablate = [layer for layer in range(32)]

            output_after_ablate, logit_diff = ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans)
            list_outputs.append(output_after_ablate)
            # print(correct_ans, output_after_ablate)
            if correct_ans == output_after_ablate:
                prompt_score += 1
        # print(prompt_score / num_rand_runs)
        all_scores.append(prompt_score / num_rand_runs)

    perc_score = sum(all_scores) / len(next_members)
    return perc_score, list_outputs

In [73]:
def gen_intervaled_seqs(interval, start, num_prompts):
    sequences = []
    next_members = []

    # Generate overlapping intervals
    for _ in range(num_prompts):
        sequence = [start, start + interval, start + interval*2]
        next_member = str(start + interval*3)
        sequences.append(sequence)
        next_members.append(next_member)
        start += interval  # Move to the next starting point

    sequences_as_str = [" ".join(map(str, seq)) for seq in sequences]
    sequences_as_str = [member + " " for member in sequences_as_str]
    print("Sequences:")
    print(sequences_as_str)
    print("\nNext Members:")
    print(next_members)
    return sequences_as_str, next_members

# Define circs

In [30]:
# from Llama2_numerals_1to10.ipynb
nums_1to9 = [(0, 2), (0, 5), (0, 6), (0, 15), (1, 15), (1, 28), (2, 13), (2, 24), (3, 24), (4, 3), (4, 16), (5, 11), (5, 13), (5, 15), (5, 16), (5, 23), (5, 25), (5, 27), (6, 11), (6, 14), (6, 20), (6, 23), (6, 24), (6, 26), (6, 28), (6, 30), (6, 31), (7, 0), (7, 13), (7, 21), (7, 30), (8, 0), (8, 2), (8, 12), (8, 15), (8, 26), (8, 27), (8, 30), (8, 31), (9, 15), (9, 16), (9, 23), (9, 26), (9, 27), (9, 29), (9, 31), (10, 1), (10, 13), (10, 18), (10, 23), (10, 29), (11, 7), (11, 8), (11, 9), (11, 17), (11, 18), (11, 25), (11, 28), (12, 18), (12, 19), (12, 23), (12, 27), (13, 6), (13, 11), (13, 20), (14, 18), (14, 19), (14, 20), (14, 21), (16, 0), (18, 19), (18, 21), (18, 25), (18, 26), (18, 31), (19, 28), (20, 17), (21, 0), (21, 2), (22, 18), (22, 20), (22, 25), (23, 27), (26, 2)]
len(nums_1to9)

84

In [31]:
# nw_circ = [(0, 1), (0, 4), (0, 6), (0, 7), (0, 8), (0, 10), (0, 11), (0, 12), (1, 16), (1, 24), (1, 27), (1, 28), (2, 2), (2, 5), (2, 8), (2, 24), (2, 30), (3, 7), (3, 14), (3, 19), (3, 23), (4, 3), (5, 16), (5, 25), (6, 11), (6, 14), (7, 0), (7, 30), (8, 0), (8, 2), (8, 3), (8, 4), (8, 6), (8, 21), (8, 31), (9, 1), (9, 3), (9, 7), (9, 11), (9, 29), (9, 31), (10, 13), (10, 18), (10, 23), (10, 24), (10, 25), (10, 27), (11, 18), (11, 28), (12, 18), (12, 26), (13, 11), (13, 17), (13, 18), (13, 19), (13, 20), (13, 21), (13, 23), (14, 7), (14, 14), (15, 25), (15, 28), (16, 0), (16, 12), (16, 14), (16, 15), (16, 16), (16, 19), (16, 24), (16, 29), (17, 17), (17, 23), (17, 31), (18, 31), (19, 12), (20, 17), (27, 20), (27, 25), (27, 27), (27, 31), (28, 5), (29, 5)]
# in order from most impt to least based on how much changes perf when ablated
nw_circ = [(20, 17), (5, 25), (16, 0), (29, 5), (3, 19), (6, 11), (15, 25), (8, 0), (16, 24), (8, 4), (7, 0), (6, 14), (16, 29), (5, 16), (12, 26), (4, 3), (3, 7), (7, 30), (11, 28), (28, 5), (17, 31), (13, 11), (13, 20), (12, 18), (1, 27), (10, 13), (18, 31), (8, 6), (9, 1), (0, 4), (2, 2), (9, 11), (19, 12), (1, 16), (13, 17), (9, 7), (11, 18), (2, 24), (10, 18), (9, 31), (9, 29), (2, 30), (2, 5), (1, 24), (2, 8), (15, 28), (27, 31), (16, 14), (3, 23), (3, 14), (10, 23), (27, 20), (8, 3), (14, 7), (14, 14), (16, 15), (8, 2), (17, 17), (0, 1), (10, 27), (16, 19), (0, 8), (0, 12), (1, 28), (0, 11), (17, 23), (0, 10), (0, 6), (13, 19), (8, 31), (10, 24), (16, 12), (13, 23), (13, 21), (27, 27), (9, 3), (27, 25), (16, 16), (8, 21), (0, 7), (13, 18), (10, 25)]
len(nw_circ)

82

In [32]:
# impt_months_heads = ([(23, 17), (17, 11), (16, 0), (26, 14), (18, 9), (5, 25), (22, 20), (6, 24), (26, 9), (12, 18), (13, 20), (19, 12), (27, 29), (13, 14), (16, 14), (12, 26), (19, 30), (16, 18), (31, 27), (26, 28), (16, 1), (18, 1), (19, 28), (18, 31), (29, 4), (17, 0), (14, 1), (17, 12), (12, 15), (28, 16), (10, 1), (16, 19), (9, 27), (30, 1), (19, 27), (0, 3), (15, 11), (21, 3), (11, 19), (12, 0), (23, 11), (8, 14), (16, 8), (22, 13), (13, 3), (4, 19), (14, 15), (12, 20), (19, 16), (18, 5)])
months_circ = [(20, 17), (6, 11), (16, 0), (5, 15), (17, 11), (23, 16), (5, 25), (7, 0), (26, 14), (6, 14), (12, 22), (8, 4), (12, 15), (16, 29), (15, 25), (5, 16), (18, 31), (14, 7), (11, 18), (4, 12), (3, 19), (12, 2), (11, 28), (4, 3), (18, 9), (8, 14), (12, 3), (11, 2), (10, 13), (4, 16), (1, 22), (11, 16), (3, 15), (13, 31), (2, 4), (2, 16), (8, 13), (0, 13), (8, 15), (12, 28), (1, 5), (0, 4), (0, 25), (3, 24), (13, 11), (1, 24), (8, 16), (13, 8), (3, 26), (0, 6), (3, 23), (1, 3), (14, 3), (8, 19), (8, 12), (14, 2), (8, 5), (1, 28), (8, 20), (2, 30), (8, 6), (10, 1), (13, 20), (19, 27)]
len(months_circ)

64

In [33]:
intersect_all = list(set(nums_1to9) & set(nw_circ) & set(months_circ))
len(intersect_all)

16

In [34]:
union_all = list(set(nums_1to9) | set(nw_circ) | set(months_circ))
len(union_all)

172

# (+100) seq - summed correct tok logits

In [None]:
interval = 100
start = 0
num_prompts = 50
sequences_as_str, next_members = gen_intervaled_seqs(interval, start, num_prompts)

Sequences:
['0 100 200 ', '100 200 300 ', '200 300 400 ', '300 400 500 ', '400 500 600 ', '500 600 700 ', '600 700 800 ', '700 800 900 ', '800 900 1000 ', '900 1000 1100 ', '1000 1100 1200 ', '1100 1200 1300 ', '1200 1300 1400 ', '1300 1400 1500 ', '1400 1500 1600 ', '1500 1600 1700 ', '1600 1700 1800 ', '1700 1800 1900 ', '1800 1900 2000 ', '1900 2000 2100 ', '2000 2100 2200 ', '2100 2200 2300 ', '2200 2300 2400 ', '2300 2400 2500 ', '2400 2500 2600 ', '2500 2600 2700 ', '2600 2700 2800 ', '2700 2800 2900 ', '2800 2900 3000 ', '2900 3000 3100 ', '3000 3100 3200 ', '3100 3200 3300 ', '3200 3300 3400 ', '3300 3400 3500 ', '3400 3500 3600 ', '3500 3600 3700 ', '3600 3700 3800 ', '3700 3800 3900 ', '3800 3900 4000 ', '3900 4000 4100 ', '4000 4100 4200 ', '4100 4200 4300 ', '4200 4300 4400 ', '4300 4400 4500 ', '4400 4500 4600 ', '4500 4600 4700 ', '4600 4700 4800 ', '4700 4800 4900 ', '4800 4900 5000 ', '4900 5000 5100 ']

Next Members:
['300', '400', '500', '600', '700', '800', '900', '1

In [None]:
orig_logit_sum = clean_autoScore(model, sequences_as_str, next_members)

In [None]:
perc_score, list_outputs, total_abl_logits = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
print(perc_score)
print(total_abl_logits/ orig_logit_sum)

0.8
0.9105765476871324


In [None]:
perc_score, list_outputs, total_abl_logits = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
print(perc_score)
print(total_abl_logits/ orig_logit_sum)

0.0
0.7278133123943198


In [54]:
perc_score, list_outputs, total_abl_logits = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
print(perc_score)
print(total_abl_logits/ orig_logit_sum)

0.0
0.7464113388731783


In [55]:
perc_score, list_outputs, total_abl_logits = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
print(perc_score)
print(total_abl_logits/ orig_logit_sum)

0.0
0.5278497158347519


In [82]:
num_rand_runs = 10
heads_not_overlap = intersect_all
num_heads_rand = len(intersect_all)
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str, next_members,
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
# print(perc_score / len(next_members))

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
0.9
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
0.998


# fns using logit diff

In [88]:
interval = 100
start = 0
num_prompts = 50
sequences_as_str, next_members = gen_intervaled_seqs(interval, start, num_prompts)

incorr_answers = [prompt.split(' ')[-2] for prompt in sequences_as_str]

Sequences:
['0 100 200 ', '100 200 300 ', '200 300 400 ', '300 400 500 ', '400 500 600 ', '500 600 700 ', '600 700 800 ', '700 800 900 ', '800 900 1000 ', '900 1000 1100 ', '1000 1100 1200 ', '1100 1200 1300 ', '1200 1300 1400 ', '1300 1400 1500 ', '1400 1500 1600 ', '1500 1600 1700 ', '1600 1700 1800 ', '1700 1800 1900 ', '1800 1900 2000 ', '1900 2000 2100 ', '2000 2100 2200 ', '2100 2200 2300 ', '2200 2300 2400 ', '2300 2400 2500 ', '2400 2500 2600 ', '2500 2600 2700 ', '2600 2700 2800 ', '2700 2800 2900 ', '2800 2900 3000 ', '2900 3000 3100 ', '3000 3100 3200 ', '3100 3200 3300 ', '3200 3300 3400 ', '3300 3400 3500 ', '3400 3500 3600 ', '3500 3600 3700 ', '3600 3700 3800 ', '3700 3800 3900 ', '3800 3900 4000 ', '3900 4000 4100 ', '4000 4100 4200 ', '4100 4200 4300 ', '4200 4300 4400 ', '4300 4400 4500 ', '4400 4500 4600 ', '4500 4600 4700 ', '4600 4700 4800 ', '4700 4800 4900 ', '4800 4900 5000 ', '4900 5000 5100 ']

Next Members:
['300', '400', '500', '600', '700', '800', '900', '1

In [89]:
def clean_autoScore(model, sequences_as_str, next_members, incorr_answers):
    list_outputs = []
    total_orig_logits = 0
    for clean_text, correct_ans, incorr_ans in zip(sequences_as_str, next_members, incorr_answers):
        orig_score = clean_gen(model, clean_text, correct_ans, incorr_ans)

        total_orig_logits += orig_score
    return total_orig_logits

In [87]:
def clean_gen(model, clean_text, corr_ans, incorr_ans):
    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    tokens = model.to_tokens(clean_text).to(device)
    # tokens = tokens[:, 1:] # get rid of prepend bos when using model.to_tokens

    wrongAns_str_tok = tokenizer.tokenize(incorr_ans)[1:]
    incorr_tokenIDs = []
    for incorrect_ansPos in range(len(wrongAns_str_tok)):
        tokID = model.tokenizer.encode(ans_str_tok[incorrect_ansPos])[2:][0] # 2: to skip padding <s> and ''
        incorr_tokenIDs.append(tokID)

    total_score = 0
    corr_ans_tokLen = 0
    ans_so_far = ''
    # while True:
    for i in range(10):
        # print(f"Sequence so far: {model.to_string(tokens)[0]!r}")
        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        corr_logits = logits[:, -1, next_token]

        wrongTok_IDs = torch.tensor(incorr_tokenIDs[i])
        incorr_logits = logits[:, -1, wrongTok_IDs]
        # total_score += corr_logits
        total_score += (corr_logits - incorr_logits)
        # print(f"logit diff of new char: {corr_logits}")

        ans_so_far += next_char
        corr_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
        if ans_so_far == corr_ans:
            # print('\nTotal logit diff: ', total_score.item())
            break

        # Define new input sequence, by appending the previously generated token
        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
        # if next_char == '':
        #     next_char = ' '
        # clean_text = clean_text + next_char
        # tokens = model.to_tokens(clean_text).to(device)
    # return corr_ans_tokLen, total_score
    return total_score.item()

In [91]:
orig_logit_sum = clean_autoScore(model, sequences_as_str, next_members, incorr_answers)
orig_logit_sum

IndexError: list index out of range

In [None]:
def ablate_auto_score(model, clean_text, corr_text, heads_not_ablate, mlps_not_ablate, correct_ans):
    tokens = model.to_tokens(clean_text).to(device)
    prompts_list = generate_prompts_list_longer(clean_text, tokens)

    corr_tokens = model.to_tokens(corr_text).to(device)
    prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook
    pos_dict = {}
    num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
    for i in range(num_pos ):
        pos_dict['S'+str(i)] = i
    dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer)
    model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

    # logits = model(tokens)
    # next_token = logits[0, -1].argmax(dim=-1)
    # next_char = model.to_string(next_token)

    total_score = 0
    ans_so_far = ''
    ans_str_tok = tokenizer.tokenize(correct_ans)[1:] # correct_ans is str
    corr_tokenIDs = []
    for correct_ansPos in range(len(ans_str_tok)):
        tokID = model.tokenizer.encode(ans_str_tok[correct_ansPos])[2:][0] # 2: to skip padding <s> and ''
        corr_tokenIDs.append(tokID)
    correct_ans_tokLen = len(corr_tokenIDs)
    for ansPos in range(correct_ans_tokLen):
        # if next_char == '':
        #     next_char = ' '

        # clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
        #     print(model.to_string(tokens))
        #     # print(f"Sequence so far: {clean_text}")
        #     # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        # tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        # get new ablation dataset
        # model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

        # corr_text = corr_text + next_char
        # corr_tokens = torch.cat([corr_tokens, next_token[None, None]], dim=-1)
        # prompts_list_2 = generate_prompts_list_longer(corr_text, corr_tokens)

        # pos_dict = {}
        # num_pos = len(model.tokenizer(prompts_list_2[0]['text']).input_ids)
        # for i in range(num_pos ):
        #     pos_dict['S'+str(i)] = i

        # dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, corr_tokens)

        # model = add_ablation_hook_MLP_head(model, dataset_2, heads_not_ablate, mlps_not_ablate)

        logits = model(tokens)
        next_token = logits[0, -1].argmax(dim=-1) # Get the predicted token at the end of our sequence
        next_char = model.to_string(next_token)

        if next_char == '':
            next_char = ' '

        clean_text = clean_text + next_char
        # if i == correct_ans_tokLen - 1:
            # print(model.to_string(tokens))
            # print(f"Sequence so far: {clean_text}")
            # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        tokens = torch.cat([tokens, next_token[None, None]], dim=-1)

        ans_so_far += next_char
        correct_ans_tokLen += 1
        # print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

        ansTok_IDs = torch.tensor(corr_tokenIDs[ansPos])

        # new_score = get_logit_diff(logits, dataset)
        # total_score += new_score
        # corrTok_logits = logits[:, -1, next_token]
        corrTok_logits = logits[range(logits.size(0)), -1, ansTok_IDs]  # not next_token, as that's what's pred, not the token to measure
        # pdb.set_trace()
        total_score += corrTok_logits
        # print(f"corr logit of new char: {new_score}")
    # print('\n Total corr logit: ', total_score.item())
    return ans_so_far, total_score.item()

# (+2) seq

In [None]:
interval = 2
start = 2
num_prompts = 20
sequences_as_str, next_members = gen_intervaled_seqs(interval, start, num_prompts)

Sequences:
['2 4 6 ', '4 6 8 ', '6 8 10 ', '8 10 12 ', '10 12 14 ', '12 14 16 ', '14 16 18 ', '16 18 20 ', '18 20 22 ', '20 22 24 ', '22 24 26 ', '24 26 28 ', '26 28 30 ', '28 30 32 ', '30 32 34 ', '32 34 36 ', '34 36 38 ', '36 38 40 ', '38 40 42 ', '40 42 44 ']

Next Members:
['8', '10', '12', '14', '16', '18', '20', '22', '24', '26', '28', '30', '32', '34', '36', '38', '40', '42', '44', '46']


In [None]:
perc_score, list_outputs, logit_ratio = ablate_circ_autoScore(model, [], sequences_as_str, next_members)
print(perc_score)
logit_ratio

8 8
logit ratio:  1.0
10 10
logit ratio:  1.0
12 12
logit ratio:  1.0
14 14
logit ratio:  1.0
16 16
logit ratio:  1.0
18 18
logit ratio:  1.0
20 20
logit ratio:  1.0
22 22
logit ratio:  1.0
24 24
logit ratio:  1.0
26 26
logit ratio:  1.0
28 28
logit ratio:  1.0
30 30
logit ratio:  1.0
32 32
logit ratio:  1.0
34 34
logit ratio:  1.0
36 36
logit ratio:  1.0
38 38
logit ratio:  1.0
40 40
logit ratio:  1.0
42 42
logit ratio:  1.0
44 44
logit ratio:  1.0
46 46
logit ratio:  1.0
1.0


1.0

In [None]:
perc_score, list_outputs, logit_ratio = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
print(perc_score)
logit_ratio

8 8
logit ratio:  0.7353686519021883
10 10
logit ratio:  0.7973682116716272
12 10
logit ratio:  0.8519332673926002
14 10
logit ratio:  0.8844083631865375
16 10
logit ratio:  0.7750002515709837
18 10
logit ratio:  0.8442251521091306
20 10
logit ratio:  0.7898947100286495
22 10
logit ratio:  0.8053892642476077
24 10
logit ratio:  0.8266712200783308
26 25
logit ratio:  0.8546860552797781
28 28
logit ratio:  0.8427185880013404
30 29
logit ratio:  0.8676719176313099
32 32
logit ratio:  0.9395627550305542
34 33
logit ratio:  0.8147240138984405
36 35
logit ratio:  0.8258292482283477
38 38
logit ratio:  0.8460781404319373
40 39
logit ratio:  0.7352914844424208
42 42
logit ratio:  0.840594478881024
44 45
logit ratio:  0.8143526195962727
46 45
logit ratio:  0.8231797074664411
0.3


0.8271088476138267

In [None]:
perc_score, list_outputs, logit_ratio = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
print(perc_score)
logit_ratio

8 6
logit ratio:  0.6140289200227146
10 10
logit ratio:  0.7823467730392415
12 10
logit ratio:  0.8351860044636241
14 10
logit ratio:  0.818153290680594
16 10
logit ratio:  0.7188690148952612
18 14
logit ratio:  0.804041993464688
20 16
logit ratio:  0.763741683131062
22 18
logit ratio:  0.7755448905542037
24 10
logit ratio:  0.7563106433563089
26 10
logit ratio:  0.8485805287523001
28 10
logit ratio:  0.8158035133519199
30 20
logit ratio:  0.8968695568321476
32 10
logit ratio:  0.8226770998357318
34 20
logit ratio:  0.7617077728664517
36 33
logit ratio:  0.7343630779094977
38 34
logit ratio:  0.8071547862242475
40 36
logit ratio:  0.7433231446713092
42 40
logit ratio:  0.6774103206784406
44 40
logit ratio:  0.7035567837358885
46 44
logit ratio:  0.7990218727397458
0.05


0.7764357879128586

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

8 2
10 1 
12 0 
14 1 
16 1 
18 10
20 10
22 10
24 1 
26 0 
28 2 
30 10
32 10
34 30
36 3 
38 32
40 36
42 3 
44 3 
46 0 


0.0

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

8 6
10 6 
12 10
14 10
16 14
18 16
20 16
22 18
24 20
26 24
28 26
30 28
32 30
34 32
36 34
38 36
40 36
42 38
44 43
46 44


0.0

In [None]:
num_rand_runs = 5
heads_not_overlap = intersect_all
num_heads_rand = len(intersect_all)
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str[:10], next_members[:10],
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

0.8
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


0.9800000000000001

# (+3) seq

In [None]:
interval = 3
start = 3
num_prompts = 20
sequences_as_str, next_members = gen_intervaled_seqs(interval, start, num_prompts)

Sequences:
['3 6 9', '6 9 12', '9 12 15', '12 15 18', '15 18 21', '18 21 24', '21 24 27', '24 27 30', '27 30 33', '30 33 36', '33 36 39', '36 39 42', '39 42 45', '42 45 48', '45 48 51', '48 51 54', '51 54 57', '54 57 60', '57 60 63', '60 63 66']

Next Members:
['12', '15', '18', '21', '24', '27', '30', '33', '36', '39', '42', '45', '48', '51', '54', '57', '60', '63', '66', '69']


In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

12 10
15 15
18 18
21 <0x0A>Answer:
24 10
27 25
30 30
33 30
36 30
39 30
42 36
45 45
48 48
51 47
54 55
57 55
60 58
63 63
66 65
69 67


0.3

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

12 10
15 12
18 12
21 15
24 18
27 11
30 10
33 10
36 27
39 33
42 33
45 39
48 45
51 45
54 48
57 48
60 54
63 60
66 53
69 63


0.0

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

12 11
15 1.<0x0A>
18 11
21 10
24 16.
27 12
30 12
33 7.<0x0A>
36 10
39 30
42 33
45 33
48 34
51 44
54 10
57 31
60 15
63 10
66 10
69 63


0.0

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

12 33
15 12
18 15
21 18
24 18
27 24
30 27
33 20
36 30
39 36
42 36
45 39
48 45
51 48
54 51
57 54
60 57
63 57
66 60
69 63


0.0

In [None]:
num_rand_runs = 5
heads_not_overlap = intersect_all
num_heads_rand = len(intersect_all)
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str[:10], next_members[:10],
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

1.0
1.0
0.4
1.0
0.6
0.4
0.8
0.8
1.0
0.8


0.78

# (+10) seq

In [None]:
interval = 10
start = 0
num_prompts = 10
sequences_as_str, next_members = gen_intervaled_seqs(interval, start, num_prompts)

Sequences:
['0 10 20', '10 20 30', '20 30 40', '30 40 50', '40 50 60', '50 60 70', '60 70 80', '70 80 90', '80 90 100', '90 100 110']

Next Members:
['30', '40', '50', '60', '70', '80', '90', '100', '110', '120']


In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, intersect_all, sequences_as_str, next_members)
perc_score

30 10
40 10
50 50
60 50
70 70
80 80
90 90
100 708
110 100
120 100


0.4

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nums_1to9, sequences_as_str, next_members)
perc_score

30 81
40 10
50 10
60 10
70 70
80 10
90 10
100 010
110 100
120 100


0.1

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, nw_circ, sequences_as_str, next_members)
perc_score

30 00
40 00
50 30
60 40
70 50
80 60
90 70
100 809
110 0000
120 100


0.0

In [None]:
perc_score, list_outputs = ablate_circ_autoScore(model, months_circ, sequences_as_str, next_members)
perc_score

30 9-0
40 30
50 40
60 50
70 60
80 70
90 70
100 909
110 101
120 101


0.0

In [None]:
num_rand_runs = 5
heads_not_overlap = intersect_all
num_heads_rand = len(intersect_all)
num_not_overlap = len(intersect_all)
perc_score, list_outputs = ablate_randcirc_autoScore(model, sequences_as_str[:10], next_members[:10],
                                                    num_rand_runs, heads_not_overlap, num_heads_rand, num_not_overlap)
perc_score

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


1.0