In [1]:
import torch
import os
from torch.nn import CosineSimilarity
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
from functools import partial
from baukit import TraceDict
from einops import rearrange, einsum
from collections import defaultdict
import matplotlib.pyplot as plt
from plotly_utils import imshow, scatter
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from peft import PeftModel

import pysvelte
import analysis_utils
from counterfactual_datasets.entity_tracking import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(10)

%load_ext autoreload
%autoreload 2

  warn(f"Failed to load image Python extension: {e}")


In [2]:
print("Model Loading...")
# path = "AlekseyKorshuk/vicuna-7b"
path = "/data/nikhil_prakash/llama_weights/7B"
llama_tokenizer = AutoTokenizer.from_pretrained(path)
llama_model = AutoModelForCausalLM.from_pretrained(path).to(device)

base_model = "decapoda-research/llama-7b-hf"
lora_weights = "tiedong/goat-lora-7b"

goat_model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=False,
    torch_dtype=torch.float32,
    device_map="auto",
)
goat_model = PeftModel.from_pretrained(
    goat_model,
    lora_weights,
    torch_dtype=torch.float32,
    device_map={'': 0},
)

llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

Model Loading...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [3]:
data_file = "./box_datasets/no_instructions/alternative/Random/7/train.jsonl"
object_file = "./box_datasets/objects_with_bnc_frequency.csv"
batch_size = 8

In [4]:
raw_data = entity_tracking_example_sampler(
    tokenizer=llama_tokenizer,
    num_samples=500,
    data_file=data_file,
    # object_file=object_file,
    few_shot=False,
    alt_examples=True,
    # num_ents_or_ops=3,
    architecture="LLaMAForCausalLM",
)

dataset = Dataset.from_dict(
    {
        "input_ids": raw_data[0],
        "last_token_indices": raw_data[1],
        "labels": raw_data[2],
    }
).with_format("torch")

print(f"Length of dataset: {len(dataset)}")

dataloader = DataLoader(dataset, batch_size=batch_size)

Length of dataset: 500


In [5]:
idx = 0
print(
    f"Prompt: {llama_tokenizer.decode(dataset[idx]['input_ids'][:dataset[idx]['last_token_indices']+1])}"
)
print(f"Answer: {llama_tokenizer.decode(dataset[idx]['labels'])}")

Prompt:  The document is in Box X, the pot is in Box T, the magnet is in Box A, the game is in Box E, the bill is in Box M, the cross is in Box K, the map is in Box D. Box X contains the
Answer:  document


In [6]:
total_count = 0
correct_count = 0
goat_model.eval()
with torch.no_grad():
    for _, inputs in tqdm(enumerate(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        outputs = llama_model(input_ids=inputs["input_ids"])

        for bi in range(inputs["labels"].size(0)):
            label = inputs["labels"][bi]
            pred = torch.argmax(outputs.logits[bi][inputs["last_token_indices"][bi]])

            if label == pred:
                correct_count += 1
            # else:
            #     print(f"Label: {tokenizer.decode(label)}, Prediction: {tokenizer.decode(pred)}")
            total_count += 1

del outputs
torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

63it [00:24,  2.58it/s]

Task accuracy: 0.66





In [7]:
with open("circuit_heads.json", "r") as f:
    circuit_heads = json.load(f)

In [8]:
llama_modules = [[f"model.layers.{layer}.self_attn.k_proj", 
                  f"model.layers.{layer}.self_attn.q_proj",
                  f"model.layers.{layer}.self_attn.v_proj",
                 f"model.layers.{layer}.self_attn.o_proj"] 
                 for layer in range(32)]
goat_modules = [[f"base_model.model.model.layers.{layer}.self_attn.k_proj", 
                 f"base_model.model.model.layers.{layer}.self_attn.q_proj",
                 f"base_model.model.model.layers.{layer}.self_attn.v_proj",
                f"base_model.model.model.layers.{layer}.self_attn.o_proj"] 
                for layer in range(32)]

llama_modules = [item for sublist in llama_modules for item in sublist]
goat_modules = [item for sublist in goat_modules for item in sublist]

In [9]:
goat_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(goat_model.device)

        with TraceDict(goat_model, goat_modules, retain_input=True) as cache:
            _ = goat_model(inputs["input_ids"])
        
        for llama_layer, goat_layer in zip(llama_modules, goat_modules):
            if "o_proj" in llama_layer:
                if bi in goat_cache:
                    goat_cache[bi][llama_layer] = cache[goat_layer].input.cpu()
                else:
                    goat_cache[bi] = {}
                    goat_cache[bi][llama_layer] = cache[goat_layer].input.cpu()
            else:
                if bi in goat_cache:
                    goat_cache[bi][llama_layer] = cache[goat_layer].output.cpu()
                else:
                    goat_cache[bi] = {}
                    goat_cache[bi][llama_layer] = cache[goat_layer].output.cpu()

63it [00:49,  1.27it/s]


In [37]:
def cross_model_patching(inputs, outputs, layer, bi, relative_pos, input_tokens):
    if isinstance(inputs, tuple):
        inputs = inputs[0]

    if isinstance(outputs, tuple):
        outputs = outputs[0]

    cache = rearrange(
                goat_cache[bi][layer],
                "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
                n_heads=llama_model.config.num_attention_heads,
            )
    
    if "o_proj" in layer:
        pass
#         inputs = rearrange(
#                 inputs,
#                 "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
#                 n_heads=llama_model.config.num_attention_heads,
#             )

#         for rel_pos, heads in relative_pos.items():
#             curr_layer_heads = [h for l, h in heads if l == int(layer.split(".")[2])]            
            
#             if rel_pos == -1:
#                 for batch in range(inputs.size(0)):
#                     prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                         input_tokens["input_ids"][batch],
#                         input_tokens["last_token_indices"][batch]
#                     )
#                     for head in curr_layer_heads:
#                         inputs[batch, prev_query_box_pos, head] = cache[batch, prev_query_box_pos, head]

#             else:
#                 pos = inputs.size(1) - rel_pos - 1
#                 for head in curr_layer_heads:
#                     inputs[:, pos, head] = cache[:, pos, head]

#         inputs = rearrange(
#                 inputs,
#                 "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
#                 n_heads=llama_model.config.num_attention_heads,
#             )
#         w_o = llama_model.state_dict()[f"{layer}.weight"]
#         outputs = einsum(
#             inputs, w_o, "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model"
#         )

    else:
        outputs = rearrange(
                outputs,
                "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
                n_heads=llama_model.config.num_attention_heads,
            )

        for rel_pos, heads in relative_pos.items():
            curr_layer_heads = [h for l, h in heads if l == int(layer.split(".")[2])]   
            
            if rel_pos == -1:
                for batch in range(inputs.size(0)):
                    prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
                        input_tokens["input_ids"][batch],
                        input_tokens["last_token_indices"][batch]
                    )
#                 if "v_proj" in layer:
#                     for head in curr_layer_heads:
#                         outputs[:, prev_query_box_pos, head] = cache[:, prev_query_box_pos, head]

#                 if "k_proj" in layer:
#                     for head in curr_layer_heads:
#                         outputs[:, prev_query_box_pos, head] = cache[:, prev_query_box_pos, head]

#                 if "q_proj" in layer:
#                     for head in curr_layer_heads:
#                         outputs[:, prev_query_box_pos, head] = cache[:, prev_query_box_pos, head]
            else:
#                 pos = outputs.size(1) - rel_pos - 1
#                 for batch in range(inputs.size(0)):
#                     prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                         input_tokens["input_ids"][batch],
#                         input_tokens["last_token_indices"][batch]
#                     )
#                     correct_obj_pos = prev_query_box_pos - 4
                if "v_proj" in layer:
                    for head in curr_layer_heads:
                        outputs[:, :, head] = cache[:, :, head]

                if "k_proj" in layer:
                    for head in curr_layer_heads:
                        outputs[:, :, head] = cache[:, :, head]

                if "q_proj" in layer:
                    for head in curr_layer_heads:
                        outputs[:, -1, head] = cache[:, -1, head]

        outputs = rearrange(
                    outputs,
                    "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
                    n_heads=llama_model.config.num_attention_heads,
                )

    return outputs

In [38]:
with open("./new_masks/llama-7b/direct_logit_heads/object_value/0.01.txt", "r") as f:
    data = f.readlines()
    value_fetcher_heads = json.loads(data[0].split(": ")[1])

with open("./new_masks/llama-7b/heads_affect_direct_logit/positional/0.01.txt", "r") as f:
    data = f.readlines()
    positional_info_fetcher_heads = json.loads(data[0].split(": ")[1])

with open("./new_masks/llama-7b/heads_at_query_box_pos/positional/0.01.txt", "r") as f:
    data = f.readlines()
    duplicate_token_heads = json.loads(data[0].split(": ")[1])

print(f"Value Fetcher Heads: {len(value_fetcher_heads)}")
print(f"Heads affecting direct logit heads: {len(positional_info_fetcher_heads)}")
print(f"Heads at query box token: {len(duplicate_token_heads)}")
print(f"Heads at prev query box token: {len(circuit_heads['heads_at_prev_box_pos'])}")

Value Fetcher Heads: 39
Heads affecting direct logit heads: 7
Heads at query box token: 17
Heads at prev query box token: 4


In [39]:
relative_pos = {}
relative_pos[0] = value_fetcher_heads
# relative_pos[0] = positional_info_fetcher_heads
# relative_pos[2] = duplicate_token_heads
# relative_pos[-1] = circuit_heads['heads_at_prev_box_pos']

In [40]:
correct_count, total_count = 0, 0
correct_obj_attn_score_mean, incorrect_obj_attn_score_sum, first_token_attn_score_sum = 0, 0, 0

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        with TraceDict(llama_model, llama_modules, retain_input=True, edit_output=partial(
                                                                        cross_model_patching,
                                                                        bi = bi,
                                                                        relative_pos = relative_pos,
                                                                        input_tokens = inputs)) as trace:
                outputs = llama_model(inputs["input_ids"], output_attentions=True)

        for batch in range(inputs["labels"].size(0)):
            label = inputs["labels"][batch]
            pred = torch.argmax(outputs.logits[batch][inputs["last_token_indices"][batch]])

            if label == pred:
                correct_count += 1
            total_count += 1

        del outputs
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
correct_obj_attn_score_mean = correct_obj_attn_score_mean / len(dataloader)
print(f"Task accuracy: {current_acc}")

63it [01:19,  1.26s/it]

Task accuracy: 0.75



