In [1]:
import torch
import os
import json
from torch.nn import CosineSimilarity
import matplotlib.pyplot as plt
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
from functools import partial
from baukit import TraceDict
from einops import rearrange, einsum
from collections import defaultdict
import matplotlib.pyplot as plt
from plotly_utils import imshow, scatter
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
import math
import seaborn as sns
from peft import PeftModel
import pickle

import pysvelte
import analysis_utils
from counterfactual_datasets.entity_tracking import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seed = 30
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
transformers.set_seed(seed)

%load_ext autoreload
%autoreload 2

In [2]:
print("Model Loading...")
# path = "AlekseyKorshuk/vicuna-7b"
# path = "/data/nikhil_prakash/goat-finetuning/drawn-moon-15/"
# path = "/data/nikhil_prakash/llama_weights/7B"
# tokenizer = AutoTokenizer.from_pretrained(path)
# model = AutoModelForCausalLM.from_pretrained(path).to(device)

base_model = "decapoda-research/llama-7b-hf"
lora_weights = "tiedong/goat-lora-7b"

tokenizer = LlamaTokenizer.from_pretrained(
    "hf-internal-testing/llama-tokenizer", padding_side="right"
)
model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=False,
    torch_dtype=torch.float32,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    lora_weights,
    torch_dtype=torch.float32,
    device_map={"": 0},
)

tokenizer.pad_token_id = tokenizer.eos_token_id
# llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Model Loading...


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

## Evaluating Models

In [30]:
data_file = "./box_datasets/no_instructions/alternative/Random/7/train.jsonl"
object_file = "./box_datasets/filtered_objects_with_bnc_frequency.csv"
batch_size = 50

In [31]:
raw_data = entity_tracking_example_sampler(
    tokenizer=tokenizer,
    num_samples=500,
    data_file=data_file,
    # object_file=object_file,
    few_shot=False,
    alt_examples=True,
    # num_ents_or_ops=3,
    architecture="LLaMAForCausalLM",
)

dataset = Dataset.from_dict(
    {
        "input_ids": raw_data[0],
        "last_token_indices": raw_data[1],
        "labels": raw_data[2],
    }
).with_format("torch")

print(f"Length of dataset: {len(dataset)}")

dataloader = DataLoader(dataset, batch_size=batch_size)

Length of dataset: 500


In [32]:
idx = 0
print(
    f"Prompt: {tokenizer.decode(dataset[idx]['input_ids'][:dataset[idx]['last_token_indices']+1])}"
)
print(f"Answer: {tokenizer.decode(dataset[idx]['labels'])}")

Prompt: <s>The document is in Box X, the pot is in Box T, the magnet is in Box A, the game is in Box E, the bill is in Box M, the cross is in Box K, the map is in Box D. Box X contains the
Answer: document


In [6]:
total_count = 0
correct_count = 0
model.eval()
errors = defaultdict(int)
with torch.no_grad():
    for _, output in tqdm(enumerate(tqdm(dataloader))):
        for k, v in output.items():
            if v is not None and isinstance(v, torch.Tensor):
                output[k] = v.to(model.device)

        outputs = model(input_ids=output["input_ids"])

        for bi in range(output["labels"].size(0)):
            label = output["labels"][bi]
            pred = torch.argmax(outputs.logits[bi][output["last_token_indices"][bi]])
            box_label = output["input_ids"][bi][output["last_token_indices"][bi] - 2]
            prev_box_label_pos = output["input_ids"][bi].eq(box_label).nonzero()[:, 0][0].item()
            prev_box_label_index = prev_box_label_pos // 8 + 1

            if label == pred:
                correct_count += 1
            else:
                errors[prev_box_label_index] += 1
                # print(prev_box_label_pos, prev_box_label_index)
                # print(f"Label: {tokenizer.decode(label)}, Prediction: {tokenizer.decode(pred)}")
            total_count += 1

del outputs
torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

  0%|          | 0/10 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
 10%|█         | 1/10 [00:02<00:20,  2.24s/it]
 20%|██        | 2/10 [00:04<00:17,  2.15s/it]
 30%|███       | 3/10 [00:06<00:14,  2.13s/it]
 40%|████      | 4/10 [00:08<00:12,  2.11s/it]
 50%|█████     | 5/10 [00:10<00:10,  2.10s/it]
 60%|██████    | 6/10 [00:12<00:08,  2.10s/it]
 70%|███████   | 7/10 [00:14<00:06,  2.10s/it]
 80%|████████  | 8/10 [00:16<00:04,  2.10s/it]
 90%|█████████ | 9/10 [00:18<00:02,  2.09s/it]
100%|██████████| 10/10 [00:21<00:00,  2.11s/it]
10it [00:21,  2.11s/it]

Task accuracy: 0.82





## Loading Counterfactual Data

In [4]:
num_boxes = 7
raw_data = box_index_aligner_examples(
    tokenizer,
    num_samples=50,
    data_file=f"./box_datasets/no_instructions/alternative/Random/{num_boxes}/train.jsonl",
    # object_file="./box_datasets/objects_with_bnc_frequency.csv",
    architecture="LLaMAForCausalLM",
    few_shot=False,
    alt_examples=True,
    num_ents_or_ops=num_boxes,
)

In [5]:
base_tokens = raw_data[0]
base_last_token_indices = raw_data[1]
source_tokens = raw_data[2]
source_last_token_indices = raw_data[3]
correct_answer_token = raw_data[4]
# incorrect_answer_token = raw_data[6]

base_tokens = torch.cat([t.unsqueeze(dim=0) for t in base_tokens], dim=0).to(device)
source_tokens = torch.cat([t.unsqueeze(dim=0) for t in source_tokens], dim=0).to(device)

In [6]:
for i in range(-6, -1):
    print(tokenizer.decode(raw_data[0][i][: raw_data[1][i] + 1]))
    print(tokenizer.decode(raw_data[2][i][: raw_data[3][i] + 1]))
    print(tokenizer.decode(raw_data[4][i]))
    print()

<s>The paper is in Box D, the shell is in Box U, the car is in Box V, the television is in Box O, the drink is in Box S, the fan is in Box C, the sheet is in Box Z. Box V contains the
<s>The clock is in Box M, the bomb is in Box J, the newspaper is in Box G, the letter is in Box L, the suit is in Box Y, the computer is in Box R, the wheel is in Box V. Box R contains the
car

<s>The paper is in Box D, the shell is in Box U, the car is in Box V, the television is in Box O, the drink is in Box S, the fan is in Box C, the sheet is in Box Z. Box O contains the
<s>The document is in Box X, the pot is in Box T, the magnet is in Box A, the game is in Box E, the bill is in Box M, the cross is in Box K, the map is in Box D. Box S contains the
television

<s>The paper is in Box D, the shell is in Box U, the car is in Box V, the television is in Box O, the drink is in Box S, the fan is in Box C, the sheet is in Box Z. Box S contains the
<s>The ticket is in Box N, the book is in Box J, the gift is 

## Implementing Path Patching

In [7]:
hook_points = [
    f"base_model.model.model.layers.{layer}.self_attn.o_proj" for layer in range(model.config.num_hidden_layers)
]

# hook_points += [f"model.layers.{layer}.mlp" for layer in range(model.config.num_hidden_layers)]

with torch.no_grad():
    # Step 1
    with TraceDict(
        model,
        hook_points,
        retain_input=True,
    ) as clean_cache:
        _ = model(base_tokens)

    with TraceDict(
        model,
        hook_points,
        retain_input=True,
    ) as corrupt_cache:
        _ = model(source_tokens)

In [8]:
def patching_heads(
    inputs,
    output,
    layer,
    sender_layer,
    sender_head,
    clean_last_token_indices,
    corrupt_last_token_indices,
    rel_pos,
):
    """
    rel_pos: Represents the token position relative to the "real" (non-padded) last token in the sequence. All the heads at this position and subsequent positions need to patched from clean run, except the sender head at this position.
    """

    input = inputs[0]
    batch_size = input.size(0)

    if "o_proj" in layer:
        input = rearrange(
            input,
            "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
            n_heads=model.config.num_attention_heads,
        )
        clean_head_outputs = rearrange(
            clean_cache[layer].input,
            "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
            n_heads=model.config.num_attention_heads,
        )
        corrupt_head_outputs = rearrange(
            corrupt_cache[layer].input,
            "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
            n_heads=model.config.num_attention_heads,
        )

        layer = int(layer.split(".")[4])
        if sender_layer == layer:
            for bi in range(batch_size):
                if rel_pos == -1:
                    # Computing the previous query box label token position
                    clean_prev_box_label_pos = analysis_utils.compute_prev_query_box_pos(
                        base_tokens[bi], clean_last_token_indices[bi]
                    )
#                     corrupt_prev_box_label_pos = analysis_utils.compute_prev_query_box_pos(
#                         source_tokens[bi], corrupt_last_token_indices[bi]
#                     )

                    # Since, queery box may not present in the corrupt prompt, patch in
                    # the output of heads from any random box label token
                    corrupt_prev_box_label_pos = random.choice(range(6, 49, 7))
                else:
                    clean_prev_box_label_pos = clean_last_token_indices[bi] - rel_pos
                    corrupt_prev_box_label_pos = corrupt_last_token_indices[bi] - rel_pos

                for pos in range(clean_prev_box_label_pos, clean_last_token_indices[bi] + 1):
                    for head_ind in range(model.config.num_attention_heads):
                        if head_ind == sender_head and pos == clean_prev_box_label_pos:
                            input[bi, pos, sender_head] = corrupt_head_outputs[
                                bi, corrupt_prev_box_label_pos, sender_head
                            ]
                        else:
                            input[bi, pos, head_ind] = clean_head_outputs[bi, pos, head_ind]

        else:
            for bi in range(batch_size):
                if rel_pos == -1:
                    # Computing the previous query box label token position
                    clean_prev_box_label_pos = analysis_utils.compute_prev_query_box_pos(
                        base_tokens[bi], clean_last_token_indices[bi]
                    )
                else:
                    clean_prev_box_label_pos = clean_last_token_indices[bi] - rel_pos

                for pos in range(clean_prev_box_label_pos, clean_last_token_indices[bi] + 1):
                    input[bi, pos] = clean_head_outputs[bi, pos]

        input = rearrange(
            input,
            "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
            n_heads=model.config.num_attention_heads,
        )

        w_o = model.base_model.model.model.layers[layer].self_attn.o_proj.weight
        output = einsum(
            input,
            w_o,
            "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model",
        )

    return output

In [9]:
def patching_receiver_heads(
    output, layer, patched_cache, receiver_heads, clean_last_token_indices, rel_pos
):
    batch_size = output.size(0)
    receiver_heads_in_curr_layer = [h for l, h in receiver_heads if l == int(layer.split(".")[4])]

    output = rearrange(
        output,
        "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
        n_heads=model.config.num_attention_heads,
    )
    patched_head_outputs = rearrange(
        patched_cache[layer].output,
        "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
        n_heads=model.config.num_attention_heads,
    )

    # Patch in the output of the receiver heads from patched run
    for receiver_head in receiver_heads_in_curr_layer:
        for bi in range(batch_size):
            if rel_pos == -1:
                # Computing the previous query box label token position
                clean_prev_box_label_pos = analysis_utils.compute_prev_query_box_pos(
                    base_tokens[bi], clean_last_token_indices[bi]
                )
            else:
                clean_prev_box_label_pos = clean_last_token_indices[bi] - rel_pos

            output[bi, clean_prev_box_label_pos, receiver_head] = patched_head_outputs[
                bi, clean_prev_box_label_pos, receiver_head
            ]

    output = rearrange(
        output,
        "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
        n_heads=model.config.num_attention_heads,
    )

    return output

In [26]:
receiver_heads = heads_at_query_box_pos

receiver_layers = list(
    set([f"base_model.model.model.layers.{layer}.self_attn.v_proj" for layer, _ in receiver_heads])
)

# receiver_heads = [[layer, head] for layer, head in [[21, 3]]]

print(receiver_layers)
print(receiver_heads)

['base_model.model.model.layers.13.self_attn.v_proj', 'base_model.model.model.layers.6.self_attn.v_proj', 'base_model.model.model.layers.9.self_attn.v_proj', 'base_model.model.model.layers.7.self_attn.v_proj', 'base_model.model.model.layers.11.self_attn.v_proj', 'base_model.model.model.layers.10.self_attn.v_proj', 'base_model.model.model.layers.1.self_attn.v_proj']
[[10, 3], [13, 14], [9, 2], [9, 7], [11, 23], [6, 10], [11, 24], [9, 10], [1, 9], [7, 17]]


In [27]:
path_patching_logits = torch.zeros(
    model.config.num_hidden_layers, model.config.num_attention_heads
).to(device)
batch_size = base_tokens.size(0)
apply_softmax = torch.nn.Softmax(dim=-1)

for layer in tqdm(range(model.config.num_hidden_layers)):
    for head in range(model.config.num_attention_heads):
        with torch.no_grad():
            # Step 2
            with TraceDict(
                model,
                hook_points + receiver_layers,
                retain_input=True,
                edit_output=partial(
                    patching_heads,
                    sender_layer=layer,
                    sender_head=head,
                    clean_last_token_indices=base_last_token_indices,
                    corrupt_last_token_indices=source_last_token_indices,
                    rel_pos=-1,
                ),
            ) as patched_cache:
                _ = model(base_tokens)

            # Step 3
            with TraceDict(
                model,
                receiver_layers,
                retain_input=True,
                edit_output=partial(
                    patching_receiver_heads,
                    patched_cache=patched_cache,
                    receiver_heads=receiver_heads,
                    clean_last_token_indices=base_last_token_indices,
                    rel_pos=-1,
                ),
            ) as _:
                patched_out = model(base_tokens)

            for bi in range(batch_size):
                logits = apply_softmax(patched_out.logits[bi, base_last_token_indices[bi]])
                path_patching_logits[layer, head] += (logits[correct_answer_token[bi]]).item()

            path_patching_logits[layer, head] = path_patching_logits[layer, head] / batch_size

del patched_out
torch.cuda.empty_cache()

100%|██████████| 32/32 [1:36:50<00:00, 181.57s/it]


In [28]:
torch.save(path_patching_logits, "./new_pp_exps/post_submission/goat-7b/seed_30/heads_at_prev_query_box_pos.pt")

In [29]:
heads_at_prev_query_box_pos = analysis_utils.compute_topk_components(
    path_patching_logits, k=5,largest=False
)
print(f"Head influencing object info fetcher heads: {heads_at_query_box_pos}")

Head influencing object info fetcher heads: [[10, 3], [13, 14], [9, 2], [9, 7], [11, 23], [6, 10], [11, 24], [9, 10], [1, 9], [7, 17]]


## Evaluation Metrics

### Faithfulness

In [33]:
data_file = "./box_datasets/no_instructions/alternative/Random/7/train.jsonl"
object_file = "./box_datasets/filtered_objects_with_bnc_frequency.csv"
batch_size = 50

In [34]:
raw_data = generate_data_for_eval(
    tokenizer=tokenizer,
    num_samples=3500,
    data_file=data_file,
    num_boxes=7,
)

ablate_dataset = Dataset.from_dict(
    {
        "input_ids": raw_data[0],
        "last_token_indices": raw_data[1],
    }
).with_format("torch")

print(f"Length of dataset: {len(ablate_dataset)}")

ablate_dataloader = DataLoader(ablate_dataset, batch_size=batch_size)

Length of dataset: 500


In [35]:
idx = 0
print(
    f"Prompt: {tokenizer.decode(dataset[idx]['input_ids'][:dataset[idx]['last_token_indices']+1])}"
)

Prompt: <s>The document is in Box X, the pot is in Box T, the magnet is in Box A, the game is in Box E, the bill is in Box M, the cross is in Box K, the map is in Box D. Box X contains the


In [36]:
if model.config.architectures[0] == "LlamaForCausalLM":
    modules = [f"model.layers.{layer}.self_attn.o_proj" for layer in range(32)]
else:
    modules = [f"base_model.model.model.layers.{layer}.self_attn.o_proj" for layer in range(32)]

mean_activations = {}
with torch.no_grad():
    # Assuming a single batch
    for _, output in tqdm(enumerate(tqdm(ablate_dataloader))):
        for k, v in output.items():
            if v is not None and isinstance(v, torch.Tensor):
                output[k] = v.to(model.device)

        with TraceDict(model, modules, retain_input=True) as cache:
            _ = model(output["input_ids"])

        for layer in modules:
            if "self_attn" in layer:
                if layer in mean_activations:
                    mean_activations[layer] += torch.mean(cache[layer].input, dim=0)
                else:
                    mean_activations[layer] = torch.mean(cache[layer].input, dim=0)
            else:
                if layer in mean_activations:
                    mean_activations[layer] += torch.mean(cache[layer].output, dim=0)
                else:
                    mean_activations[layer] = torch.mean(cache[layer].output, dim=0)

        del cache
        torch.cuda.empty_cache()

    for layer in modules:
        mean_activations[layer] /= len(ablate_dataloader)

  0%|          | 0/10 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
 10%|█         | 1/10 [00:02<00:19,  2.15s/it]
 20%|██        | 2/10 [00:04<00:17,  2.13s/it]
 30%|███       | 3/10 [00:06<00:14,  2.12s/it]
 40%|████      | 4/10 [00:08<00:12,  2.11s/it]
 50%|█████     | 5/10 [00:10<00:10,  2.11s/it]
 60%|██████    | 6/10 [00:12<00:08,  2.11s/it]
 70%|███████   | 7/10 [00:14<00:06,  2.11s/it]
 80%|████████  | 8/10 [00:16<00:04,  2.11s/it]
 90%|█████████ | 9/10 [00:19<00:02,  2.11s/it]
100%|██████████| 10/10 [00:21<00:00,  2.11s/it]
10it [00:21,  2.11s/it]


In [37]:
def mean_ablate(inputs, output, layer, circuit_components, mean_activations, input_tokens):
    if isinstance(inputs, tuple):
        inputs = inputs[0]

    if isinstance(output, tuple):
        output = output[0]

    inputs = rearrange(
        inputs,
        "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
        n_heads=model.config.num_attention_heads,
    )

    mean_act = rearrange(
        mean_activations[layer],
        "seq_len (n_heads d_head) -> 1 seq_len n_heads d_head",
        n_heads=model.config.num_attention_heads,
    )

    last_pos = inputs.size(1) - 1
    for bi in range(inputs.size(0)):
        prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
            input_tokens[bi], input_tokens[bi].size(0) - 1
        )
        for token_pos in range(inputs.size(1)):
            if (
                token_pos != prev_query_box_pos
                and token_pos != last_pos
                and token_pos != last_pos - 2
                and token_pos != prev_query_box_pos + 1
            ):
                inputs[bi, token_pos, :] = mean_act[0, token_pos, :]
            elif token_pos == prev_query_box_pos:
                for head_idx in range(model.config.num_attention_heads):
                    if head_idx not in circuit_components[-1][layer]:
                        inputs[bi, token_pos, head_idx] = mean_act[0, token_pos, head_idx]
            elif token_pos == prev_query_box_pos + 1:
                for head_idx in range(model.config.num_attention_heads):
                    if head_idx not in circuit_components[-2][layer]:
                        inputs[bi, token_pos, head_idx] = mean_act[0, token_pos, head_idx]
            elif token_pos == last_pos:
                for head_idx in range(model.config.num_attention_heads):
                    if head_idx not in circuit_components[0][layer]:
                        inputs[bi, token_pos, head_idx] = mean_act[0, token_pos, head_idx]
            elif token_pos == last_pos - 2:
                for head_idx in range(model.config.num_attention_heads):
                    if head_idx not in circuit_components[2][layer]:
                        inputs[bi, token_pos, head_idx] = mean_act[0, token_pos, head_idx]

    inputs = rearrange(
        inputs,
        "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
        n_heads=model.config.num_attention_heads,
    )
    w_o = model.state_dict()[f"{layer}.weight"]
    output = einsum(
        inputs, w_o, "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model"
    )

    return output

In [38]:
def eval(model, dataloader, modules, circuit_components, mean_activations):
    correct_count, total_count = 0, 0
    with torch.no_grad():
        for _, output in enumerate(tqdm(dataloader)):
            for k, v in output.items():
                if v is not None and isinstance(v, torch.Tensor):
                    output[k] = v.to(model.device)
    
            with TraceDict(
                model,
                modules,
                retain_input=True,
                edit_output=partial(
                    mean_ablate,
                    circuit_components=circuit_components,
                    mean_activations=mean_activations,
                    input_tokens=output["input_ids"],
                ),
            ) as _:
                outputs = model(output["input_ids"])
    
            for bi in range(output["labels"].size(0)):
                label = output["labels"][bi]
                pred = torch.argmax(outputs.logits[bi][output["last_token_indices"][bi]])
    
                if label == pred:
                    correct_count += 1
                # else:
                #     print(f"Label: {tokenizer.decode(label)}, Prediction: {tokenizer.decode(pred)}")
                total_count += 1
    
            del outputs
            torch.cuda.empty_cache()
    
    current_acc = round(correct_count / total_count, 2)
    print(f"Task accuracy: {current_acc}")
    return current_acc

In [40]:
# [47, 15, 25, 10]
n_value_fetcher = 61
n_pos_trans = 15
n_pos_detect = 30
n_struct_read = 5

circuit_components = {}
circuit_components[0] = defaultdict(list)
circuit_components[2] = defaultdict(list)
circuit_components[-1] = defaultdict(list)
circuit_components[-2] = defaultdict(list)

root_path = "./new_pp_exps/post_submission/goat-7b/seed_30"
path = root_path + "/direct_logit_heads.pt"

direct_logit_heads = analysis_utils.compute_topk_components(torch.load(path), k=n_value_fetcher, largest=False)

path = root_path + "/heads_affect_direct_logit.pt"
heads_affecting_direct_logit_heads = analysis_utils.compute_topk_components(
    torch.load(path), k=n_pos_trans, largest=False
)

path = root_path + "/heads_at_query_box_pos.pt"
head_at_query_box_token = analysis_utils.compute_topk_components(
    torch.load(path), k=n_pos_detect, largest=False
)

path = root_path + "/heads_at_prev_query_box_pos.pt"
heads_at_prev_box_pos = analysis_utils.compute_topk_components(torch.load(path), k=n_struct_read, largest=False)

intersection = []
for head in direct_logit_heads:
    if head in heads_affecting_direct_logit_heads:
        intersection.append(head)

for head in intersection:
    direct_logit_heads.remove(head)

print(
    len(direct_logit_heads),
    len(heads_affecting_direct_logit_heads),
    len(head_at_query_box_token),
    len(heads_at_prev_box_pos),
)

for layer_idx, head in direct_logit_heads:
    if model.config.architectures[0] == "LlamaForCausalLM":
        layer = f"model.layers.{layer_idx}.self_attn.o_proj"
    else:
        layer = f"base_model.model.model.layers.{layer_idx}.self_attn.o_proj"
    circuit_components[0][layer].append(head)

for layer_idx, head in heads_affecting_direct_logit_heads:
    if model.config.architectures[0] == "LlamaForCausalLM":
        layer = f"model.layers.{layer_idx}.self_attn.o_proj"
    else:
        layer = f"base_model.model.model.layers.{layer_idx}.self_attn.o_proj"
    circuit_components[0][layer].append(head)


for layer_idx, head in head_at_query_box_token:
    if model.config.architectures[0] == "LlamaForCausalLM":
        layer = f"model.layers.{layer_idx}.self_attn.o_proj"
    else:
        layer = f"base_model.model.model.layers.{layer_idx}.self_attn.o_proj"
    circuit_components[2][layer].append(head)


for layer_idx, head in heads_at_prev_box_pos:
    if model.config.architectures[0] == "LlamaForCausalLM":
        layer = f"model.layers.{layer_idx}.self_attn.o_proj"
    else:
        layer = f"base_model.model.model.layers.{layer_idx}.self_attn.o_proj"
    circuit_components[-1][layer].append(head)

for pos in circuit_components.keys():
    for layer_idx in circuit_components[pos].keys():
        circuit_components[pos][layer_idx] = list(set(circuit_components[pos][layer_idx]))

eval(model, dataloader, modules, circuit_components, mean_activations)

50 15 30 5


100%|██████████| 10/10 [01:06<00:00,  6.64s/it]

Task accuracy: 0.48





0.48