In [1]:
import torch
import os
from torch.nn import CosineSimilarity
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
from functools import partial
from baukit import TraceDict
from einops import rearrange, einsum
from collections import defaultdict
import matplotlib.pyplot as plt
from plotly_utils import imshow, scatter
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from peft import PeftModel

import pysvelte
import analysis_utils
from counterfactual_datasets.entity_tracking import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(10)

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print("Model Loading...")
# path = "AlekseyKorshuk/vicuna-7b"
path = "/data/nikhil_prakash/llama_weights/7B"
llama_tokenizer = AutoTokenizer.from_pretrained(path)
llama_model = AutoModelForCausalLM.from_pretrained(path).to(device)

# base_model = "decapoda-research/llama-7b-hf"
# lora_weights = "tiedong/goat-lora-7b"

# goat_model = LlamaForCausalLM.from_pretrained(
#     base_model,
#     load_in_8bit=False,
#     torch_dtype=torch.float32,
#     device_map="auto",
# )
# goat_model = PeftModel.from_pretrained(
#     goat_model,
#     lora_weights,
#     torch_dtype=torch.float32,
#     device_map={'': 0},
# )

llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

Model Loading...


In [3]:
data_file = "./box_datasets/no_instructions/alternative/Random/5/train.jsonl"
object_file = "./box_datasets/objects_with_bnc_frequency.csv"
batch_size = 16

In [4]:
raw_data = entity_tracking_example_sampler(
    tokenizer=llama_tokenizer,
    num_samples=600,
    data_file=data_file,
    # object_file=object_file,
    few_shot=False,
    alt_examples=True,
    # num_ents_or_ops=3,
    architecture="LLaMAForCausalLM",
)

dataset = Dataset.from_dict(
    {
        "input_ids": raw_data[0],
        "last_token_indices": raw_data[1],
        "labels": raw_data[2],
    }
).with_format("torch")

print(f"Length of dataset: {len(dataset)}")

dataloader = DataLoader(dataset, batch_size=batch_size)

Length of dataset: 600


In [5]:
idx = 0
print(
    f"Prompt: {llama_tokenizer.decode(dataset[idx]['input_ids'][:dataset[idx]['last_token_indices']+1])}"
)
print(f"Answer: {llama_tokenizer.decode(dataset[idx]['labels'])}")

Prompt:  The machine is in Box X, the magazine is in Box T, the key is in Box A, the fig is in Box E, the bomb is in Box M. Box X contains the
Answer:  machine


In [6]:
total_count = 0
correct_count = 0
goat_model.eval()
with torch.no_grad():
    for _, inputs in tqdm(enumerate(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(goat_model.device)

        outputs = goat_model(input_ids=inputs["input_ids"])

        for bi in range(inputs["labels"].size(0)):
            label = inputs["labels"][bi]
            pred = torch.argmax(outputs.logits[bi][inputs["last_token_indices"][bi]])

            if label == pred:
                correct_count += 1
            # else:
            #     print(f"Label: {tokenizer.decode(label)}, Prediction: {tokenizer.decode(pred)}")
            total_count += 1

del outputs
torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

38it [00:19,  1.94it/s]

Task accuracy: 0.9





In [7]:
root_path = "./new_pp_exps/5_boxes/reverse"
path = root_path + "/direct_logit_heads.pt"
logit_values = torch.load(path)
direct_logit_heads = analysis_utils.compute_topk_components(torch.load(path), k=35, largest=False)

path = root_path + "/heads_affect_direct_logit.pt"
logit_values = torch.load(path)
heads_affecting_direct_logit_heads = analysis_utils.compute_topk_components(
    torch.load(path), k=10, largest=False
)

path = root_path + "/heads_at_query_box_pos.pt"
logit_values = torch.load(path)
head_at_query_box_token = analysis_utils.compute_topk_components(
    torch.load(path), k=15, largest=False
)

path = root_path + "/heads_at_prev_query_box_pos.pt"
logit_values = torch.load(path)
heads_at_prev_box_pos = analysis_utils.compute_topk_components(torch.load(path), k=5, largest=False)

path = root_path + "/heads_at_next_token_to_prev_box_token.pt"
logit_values = torch.load(path)
heads_at_next_token_to_box_pos = analysis_utils.compute_topk_components(
    torch.load(path), k=5, largest=False
)

In [8]:
llama_modules = [[f"model.layers.{layer}.self_attn.k_proj", 
                  f"model.layers.{layer}.self_attn.q_proj",
                  f"model.layers.{layer}.self_attn.v_proj",
                 f"model.layers.{layer}.self_attn.o_proj"] 
                 for layer in range(32)]
goat_modules = [[f"base_model.model.model.layers.{layer}.self_attn.k_proj", 
                 f"base_model.model.model.layers.{layer}.self_attn.q_proj",
                 f"base_model.model.model.layers.{layer}.self_attn.v_proj",
                f"base_model.model.model.layers.{layer}.self_attn.o_proj"] 
                for layer in range(32)]

llama_modules = [item for sublist in llama_modules for item in sublist]
goat_modules = [item for sublist in goat_modules for item in sublist]

In [9]:
goat_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(goat_model.device)

        with TraceDict(goat_model, goat_modules, retain_input=True) as cache:
            _ = goat_model(inputs["input_ids"])
        
        for llama_layer, goat_layer in zip(llama_modules, goat_modules):
            if "o_proj" in llama_layer:
                if bi in goat_cache:
                    goat_cache[bi][llama_layer] = cache[goat_layer].input.cpu()
                else:
                    goat_cache[bi] = {}
                    goat_cache[bi][llama_layer] = cache[goat_layer].input.cpu()
            else:
                if bi in goat_cache:
                    goat_cache[bi][llama_layer] = cache[goat_layer].output.cpu()
                else:
                    goat_cache[bi] = {}
                    goat_cache[bi][llama_layer] = cache[goat_layer].output.cpu()

38it [00:49,  1.30s/it]


In [27]:
def cross_model_patching(inputs, outputs, layer, heads, bi, rel_pos, input_tokens):
    if isinstance(inputs, tuple):
        inputs = inputs[0]

    if isinstance(outputs, tuple):
        outputs = outputs[0]

    cache = rearrange(
                goat_cache[bi][layer],
                "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
                n_heads=llama_model.config.num_attention_heads,
            )

    curr_layer_heads = [h for l, h in heads if l == int(layer.split(".")[2])]
    
    if "o_proj" in layer:
        pass
#         inputs = rearrange(
#                 inputs,
#                 "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
#                 n_heads=llama_model.config.num_attention_heads,
#             )
#         if rel_pos == -1:
#             for batch in range(inputs.size(0)):
#                 prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                     input_tokens["input_ids"][batch],
#                     input_tokens["last_token_indices"][batch]
#                 )
#                 correct_object_pos = prev_query_box_pos - 3
#                 for head in curr_layer_heads:
#                     inputs[:, correct_object_pos, head] = cache[:, correct_object_pos, head]
            
#         else:
#             pos = outputs.size(1) - rel_pos - 1
#             for head in curr_layer_heads:
#                 inputs[:, pos, head] = cache[:, pos, head]

#         inputs = rearrange(
#                 inputs,
#                 "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
#                 n_heads=llama_model.config.num_attention_heads,
#             )
#         w_o = llama_model.state_dict()[f"{layer}.weight"]
#         outputs = einsum(
#             inputs, w_o, "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model"
#         )

    else:
        outputs = rearrange(
                outputs,
                "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
                n_heads=llama_model.config.num_attention_heads,
            )
        
#         if "v_proj" in layer:
#             for batch in range(outputs.size(0)):
#                 prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                     input_tokens["input_ids"][batch],
#                     input_tokens["last_token_indices"][batch]
#                 )
#                 pos = prev_query_box_pos - 4
# #                 pos = outputs.size(1)
#                 for head in curr_layer_heads:
#                     outputs[batch, pos, head] = cache[batch, pos, head]

#         if "k_proj" in layer:
#             for batch in range(outputs.size(0)):
#                 prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                     input_tokens["input_ids"][batch],
#                     input_tokens["last_token_indices"][batch]
#                 )
#                 pos = prev_query_box_pos - 4
# #                 pos = outputs.size(1) - 2 - 1
#                 for head in curr_layer_heads:
#                     outputs[batch, pos, head] = cache[batch, pos, head]

        if "v_proj" in layer:
#             object_pos = [i for i in range(2, 35, 7)]
#             for batch in range(outputs.size(0)):
#                 prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                     input_tokens["input_ids"][batch],
#                     input_tokens["last_token_indices"][batch]
#                 )
#                 pos = prev_query_box_pos - 4
            for head in curr_layer_heads:
                outputs[:, :, head] = cache[:, :, head]

        if "k_proj" in layer:
#             object_pos = [i for i in range(2, 35, 7)]
            for head in curr_layer_heads:
                outputs[:, :, head] = cache[:, :, head]

        if "q_proj" in layer:
            pos = outputs.size(1) - rel_pos - 1
            for head in curr_layer_heads:
                outputs[:, pos, head] = cache[:, pos, head]

        outputs = rearrange(
                    outputs,
                    "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
                    n_heads=llama_model.config.num_attention_heads,
                )

    return outputs

In [28]:
correct_count, total_count = 0, 0

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(goat_model.device)

        with TraceDict(llama_model, llama_modules, retain_input=True, edit_output=partial(
                                                                        cross_model_patching,
                                                                        heads = direct_logit_heads,
                                                                        bi = bi,
                                                                        rel_pos = 0,
                                                                        input_tokens = inputs)) as _:
                outputs = llama_model(inputs["input_ids"])
        
        for batch in range(inputs["labels"].size(0)):
            label = inputs["labels"][batch]
            pred = torch.argmax(outputs.logits[batch][inputs["last_token_indices"][batch]])

            if label == pred:
                correct_count += 1
            # else:
            #     print(f"Label: {tokenizer.decode(label)}, Prediction: {tokenizer.decode(pred)}")
            total_count += 1
        
        del outputs
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

38it [00:19,  1.92it/s]

Task accuracy: 0.89





In [227]:
object_pos = [i for i in range(2, 35, 7)]

In [102]:
llama_tokenizer.decode(inputs["input_ids"][1])

' The block is in Box N, the bus is in Box J, the dress is in Box L, the wire is in Box K, the boot is in Box V. Box K contains the'