In [1]:
import torch
import os
from torch.nn import CosineSimilarity
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
from functools import partial
from baukit import TraceDict
from einops import rearrange, einsum
from collections import defaultdict
import matplotlib.pyplot as plt
from plotly_utils import imshow, scatter
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from peft import PeftModel

import pysvelte
import analysis_utils
from counterfactual_datasets.entity_tracking import *

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed torch's global RNG so any sampling done by later cells is reproducible.
torch.manual_seed(10)

%load_ext autoreload
%autoreload 2

  warn(f"Failed to load image Python extension: {e}")


In [2]:
print("Model Loading...")
# Local LLaMA-7B weights. Set the LLAMA_WEIGHTS_PATH environment variable to load them from
# a different location (a hub id such as "AlekseyKorshuk/vicuna-7b" also works here).
path = os.environ.get("LLAMA_WEIGHTS_PATH", "/data/nikhil_prakash/llama_weights/7B")
llama_tokenizer = AutoTokenizer.from_pretrained(path)
llama_model = AutoModelForCausalLM.from_pretrained(path).to(device)

# Goat: the LoRA adapter `lora_weights` applied on top of `base_model`, kept in fp32 so its
# cached activations can be patched into the fp32 LLaMA model without precision changes.
base_model = "decapoda-research/llama-7b-hf"
lora_weights = "tiedong/goat-lora-7b"

goat_model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=False,
    torch_dtype=torch.float32,
    device_map="auto",
)
goat_model = PeftModel.from_pretrained(
    goat_model,
    lora_weights,
    torch_dtype=torch.float32,
    device_map={'': 0},
)
# Switch to inference mode here, once, so every later cell (including the activation cache)
# runs the adapter deterministically no matter which cells were executed before it.
goat_model.eval()

# llama_tokenizer also tokenizes every prompt fed to goat_model; pad with EOS since no pad
# token is configured here (NOTE(review): assumes the checkpoint ships without one).
llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

Model Loading...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [3]:
# Prompt file (7 boxes, no instructions), object vocabulary (currently unused: the sampler's
# object_file argument is commented out), and evaluation batch size.
data_file, object_file, batch_size = (
    "./box_datasets/no_instructions/alternative/Random/7/train.jsonl",
    "./box_datasets/objects_with_bnc_frequency.csv",
    8,
)

In [4]:
# Sample 600 zero-shot prompts; elements 0, 1 and 2 of the result become the dataset columns
# below (token ids, index of the last prompt token, answer label).
raw_data = entity_tracking_example_sampler(
    tokenizer=llama_tokenizer,
    data_file=data_file,
    num_samples=600,
    architecture="LLaMAForCausalLM",
    few_shot=False,
    alt_examples=True,
    # object_file=object_file,
    # num_ents_or_ops=3,
)

column_names = ("input_ids", "last_token_indices", "labels")
dataset = Dataset.from_dict(
    {name: raw_data[position] for position, name in enumerate(column_names)}
).with_format("torch")

print(f"Length of dataset: {len(dataset)}")

# No shuffling: later cells key cached Goat activations by batch index, so every pass over
# the dataloader must yield the same batches in the same order.
dataloader = DataLoader(dataset, batch_size=batch_size)

Length of dataset: 600


In [5]:
# Show one decoded prompt (up to and including its last token) and its expected answer.
idx = 0
example = dataset[idx]
prompt_token_ids = example["input_ids"][: example["last_token_indices"] + 1]
print(f"Prompt: {llama_tokenizer.decode(prompt_token_ids)}")
print(f"Answer: {llama_tokenizer.decode(example['labels'])}")

Prompt:  The document is in Box X, the pot is in Box T, the magnet is in Box A, the game is in Box E, the bill is in Box M, the cross is in Box K, the map is in Box D. Box X contains the
Answer:  document


In [7]:
# Baseline: next-token accuracy of goat_model, i.e. how often the argmax of the logits at each
# prompt's last token equals the label.
total_count = 0
correct_count = 0
goat_model.eval()
with torch.no_grad():
    for inputs in tqdm(dataloader):
        inputs = {
            k: v.to(goat_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        outputs = goat_model(input_ids=inputs["input_ids"])
        logits = outputs.logits

        # One label and one last-token index per example; reshape(-1) also accepts (batch, 1).
        labels = inputs["labels"].reshape(-1).to(logits.device)
        last_idx = inputs["last_token_indices"].reshape(-1).to(logits.device)
        batch_idx = torch.arange(labels.size(0), device=logits.device)
        preds = logits[batch_idx, last_idx].argmax(dim=-1)

        # To inspect failures: llama_tokenizer.decode(labels[i]) vs llama_tokenizer.decode(preds[i]).
        correct_count += (preds == labels).sum().item()
        total_count += labels.size(0)

        # Free this batch's logits before the next forward pass (inside the loop, so every
        # batch is released, not just the last one).
        del outputs, logits
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count * 100, 2)
print(f"Task accuracy: {current_acc}")

75it [00:30,  2.47it/s]

Task accuracy: 82.33





In [144]:
# Head groups identified by the path-patching experiments on 7-box prompts.
root_path = "./new_pp_exps/reverse/7_boxes"


def load_head_group(file_name, k):
    """Load `root_path/file_name` and return its k components per compute_topk_components(largest=False)."""
    # NOTE(review): torch.load unpickles the file; only point this at trusted experiment outputs.
    scores = torch.load(root_path + "/" + file_name)
    return analysis_utils.compute_topk_components(scores, k=k, largest=False)


direct_logit_heads = load_head_group("direct_logit_heads.pt", k=52)
heads_affecting_direct_logit_heads = load_head_group("heads_affect_direct_logit.pt", k=15)
head_at_query_box_token = load_head_group("heads_at_query_box_pos.pt", k=30)
heads_at_prev_box_pos = load_head_group("heads_at_prev_query_box_pos.pt", k=5)

# A head listed in both groups is kept only in heads_affecting_direct_logit_heads.
direct_logit_heads = [
    head for head in direct_logit_heads if head not in heads_affecting_direct_logit_heads
]

In [145]:
# Sizes of the four head groups, space-separated on one line.
head_groups = (
    direct_logit_heads,
    heads_affecting_direct_logit_heads,
    head_at_query_box_token,
    heads_at_prev_box_pos,
)
print(*(len(group) for group in head_groups))

40 15 30 5


In [139]:
# Attention projections traced in each of the 32 decoder layers, ordered k, q, v, o per layer.
# The Goat names are the same modules seen through the PeftModel wrapper, which adds the
# "base_model.model." prefix.
llama_modules = [
    f"model.layers.{layer}.self_attn.{proj}"
    for layer in range(32)
    for proj in ("k_proj", "q_proj", "v_proj", "o_proj")
]
goat_modules = [f"base_model.model.{name}" for name in llama_modules]

In [140]:
# goat_cache[batch_index][llama_module_name] holds, on the CPU, the Goat activation that the
# patching hook copies into LLaMA: the input of o_proj, and the output of k/q/v_proj.
goat_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        for key, value in inputs.items():
            if value is not None and isinstance(value, torch.Tensor):
                inputs[key] = value.to(goat_model.device)

        with TraceDict(goat_model, goat_modules, retain_input=True) as cache:
            _ = goat_model(inputs["input_ids"])

        batch_cache = goat_cache.setdefault(bi, {})
        for llama_name, goat_name in zip(llama_modules, goat_modules):
            traced = cache[goat_name]
            activation = traced.input if "o_proj" in llama_name else traced.output
            batch_cache[llama_name] = activation.cpu()

75it [00:44,  1.70it/s]


In [168]:
def cross_model_patching(inputs, outputs, layer, bi, relative_pos, input_tokens):
    """Overwrite selected LLaMA attention heads with the cached Goat activations.

    Used as the ``edit_output`` hook of ``TraceDict`` (bound with ``functools.partial``),
    so it is called once per traced module with that module's input and output.

    Args:
        inputs: Input of the traced module (tensor or tuple). Not used by the current
            patching logic; kept for the hook signature.
        outputs: Output of the traced module (tensor or tuple whose first element is the
            tensor), shaped ``(batch, seq_len, n_heads * d_head)``.
        layer: LLaMA module name of the form ``model.layers.{idx}.self_attn.{proj}``; the
            layer index is the third dot-separated field.
        bi: Dataloader batch index used to look up ``goat_cache[bi][layer]``.
        relative_pos: Dict mapping a relative position to a list of ``(layer, head)`` pairs.
        input_tokens: The batch dict. Not used by the current patching logic.

    Returns:
        The output tensor, same shape as ``outputs``. For ``k_proj``/``v_proj`` modules every
        listed head of this layer is replaced at *all* sequence positions (the relative
        position key only selects which groups apply). Entries under key ``-1`` are
        skipped, because per-position patching at the previous query box is not
        implemented. ``q_proj`` and ``o_proj`` outputs are returned unchanged. Patching
        writes through a reshaped view, so a contiguous ``outputs`` is edited in place.
    """
    if isinstance(inputs, tuple):
        inputs = inputs[0]
    if isinstance(outputs, tuple):
        outputs = outputs[0]

    # Only key/value projections are patched; everything else passes through untouched.
    if "k_proj" not in layer and "v_proj" not in layer:
        return outputs

    layer_idx = int(layer.split(".")[2])
    heads_to_patch = sorted(
        {
            int(head)
            for rel_pos, heads in relative_pos.items()
            if rel_pos != -1
            for head_layer, head in heads
            if int(head_layer) == layer_idx
        }
    )
    if not heads_to_patch:
        return outputs

    n_heads = llama_model.config.num_attention_heads
    # (..., n_heads * d_head) -> (..., n_heads, d_head), head-major as in the module's output.
    split_outputs = outputs.reshape(*outputs.shape[:-1], n_heads, -1)
    cached = goat_cache[bi][layer]
    cached = cached.reshape(*cached.shape[:-1], n_heads, -1)

    # The cache lives on the CPU: move only the selected heads to the model's device/dtype.
    split_outputs[:, :, heads_to_patch] = cached[:, :, heads_to_patch].to(
        device=split_outputs.device, dtype=split_outputs.dtype
    )
    return split_outputs.reshape(outputs.shape)

In [169]:
# Relative position -> head group passed to cross_model_patching. Uncomment an entry to add
# that group to the patching run.
relative_pos = {
    # 0: direct_logit_heads,
    0: heads_affecting_direct_logit_heads,
    # 2: head_at_query_box_token,
    # -1: heads_at_prev_box_pos,
}

In [170]:
# Accuracy of llama_model when cross_model_patching copies the cached Goat activations
# (goat_cache) into the heads selected by relative_pos.
correct_count, total_count = 0, 0
with torch.no_grad():
    for bi, inputs in enumerate(tqdm(dataloader)):
        inputs = {
            k: v.to(llama_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        patch_goat_activations = partial(
            cross_model_patching,
            bi=bi,
            relative_pos=relative_pos,
            input_tokens=inputs,
        )
        # Attention maps are not requested: nothing here reads outputs.attentions, and
        # returning every layer's maps only costs memory. Add output_attentions=True again
        # for attention-score analyses.
        with TraceDict(
            llama_model,
            llama_modules,
            retain_input=True,
            edit_output=patch_goat_activations,
        ):
            outputs = llama_model(inputs["input_ids"])
        logits = outputs.logits

        # One label and one last-token index per example; reshape(-1) also accepts (batch, 1).
        labels = inputs["labels"].reshape(-1).to(logits.device)
        last_idx = inputs["last_token_indices"].reshape(-1).to(logits.device)
        batch_idx = torch.arange(labels.size(0), device=logits.device)
        preds = logits[batch_idx, last_idx].argmax(dim=-1)

        # To inspect failures: llama_tokenizer.decode(labels[i]) vs llama_tokenizer.decode(preds[i]).
        correct_count += (preds == labels).sum().item()
        total_count += labels.size(0)

        del outputs, logits
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count * 100, 2)
print(f"Task accuracy: {current_acc}")

75it [00:50,  1.48it/s]

Task accuracy: 73.83





In [None]:
# Task accuracy: 66.5, Attn Score: 0.01679055020213127

In [None]:

# Task accuracy: 78.33, Attn Score: 0.017840217798948288

In [None]:
# Task accuracy: 66.5, Attn Score: 0.12626506388187408

In [None]:
# Task accuracy: 78.33, Attn Score: 0.1326800286769867

In [None]:
# Direct Logit Heads
# Original - 66.5
# Query - 75.17
# All keys (no Query) - 66.67
# All values - 69.0
# All Keys with final query - 75.33
# All Keys with final query and correct object value vector - 76.0
# All Keys with final query and all value vectors - 76.5

In [None]:
# Heads affecting Direct Logit Heads
# Original - 66.5
# Query (final pos) - 69.67
# All Keys (no Query) - 61.67
# All values - 73.67
# All Keys with final query - 71.33
# All Keys with final query and box label pos value vector - 74.67
# All Keys with final query and box all value vectors - 78.33

In [None]:
# Task accuracy: 66.5
# Correct object attn score: 0.14423631131649017
# Incorrect object attn score: 0.0002922056009992957
# Start token attn score: 0.41476279497146606

In [110]:
# Task accuracy: 75.17
# Correct object attn score: 0.18178720772266388
# Incorrect object attn score: 0.00031543115619570017
# Start token attn score: 0.28832370042800903

In [97]:
# Accuracies (%) from an input-patching sweep (query / key / value) whose generating cell is
# no longer in the notebook. They are hard-coded from its saved output so that Restart & Run
# All does not fail with a NameError here.
input_patching_scores = {"query": 75.17, "key": 66.33, "value": 67.0}
input_patching_scores

{'query': 75.17, 'key': 66.33, 'value': 67.0}

In [96]:
# Manually set the recorded "value" accuracy.
input_patching_scores.update({'value': 67.0})

In [55]:
# Accuracies (%) from patching each head group. The cell that computed them is no longer in
# the notebook, so the values are hard-coded from its saved output so that Restart & Run All
# does not fail with a NameError here.
patching_scores = {
    'direct_logit_heads': 75.83,
    'heads_affect_direct_logit': 78.33,
    'head_at_query_box_token': 66.17,
    'heads_at_prev_box_pos': 67.33,
    'direct_logit_heads + heads_affect_direct_logit': 79.67,
    'all_head_groups': 79.5,
}
patching_scores

{'direct_logit_heads': 75.83,
 'heads_affect_direct_logit': 78.33,
 'head_at_query_box_token': 66.17,
 'heads_at_prev_box_pos': 67.33,
 'direct_logit_heads + heads_affect_direct_logit': 79.67,
 'all_head_groups': 79.5}

In [None]:
#             prev_query_box_pos = analysis_utils.compute_prev_query_box_pos(
#                 inputs["input_ids"][batch],
#                 inputs["last_token_indices"][batch]
#             )
#             correct_obj_pos = prev_query_box_pos - 4
#             object_pos = [i for i in range(2, 50, 7)]
#             object_pos.remove(correct_obj_pos)

#             correct_obj_attn_temp, incorrect_obj_attn_temp, first_token_attn_temp = 0, 0, 0
#             for l, h in direct_logit_heads:
#                 correct_obj_attn_temp += outputs.attentions[l][batch, h, -1, correct_obj_pos]
#                 first_token_attn_temp += outputs.attentions[l][batch, h, -1, 0]
                
#                 for pos in object_pos:
#                     incorrect_obj_attn_temp += outputs.attentions[l][batch, h, -1, pos]
#                 incorrect_obj_attn_temp = incorrect_obj_attn_temp/len(object_pos)
                
#             correct_obj_attn_score_sum += correct_obj_attn_temp/len(direct_logit_heads)
#             first_token_attn_score_sum += first_token_attn_temp/len(direct_logit_heads)
#             incorrect_obj_attn_score_sum += incorrect_obj_attn_temp/len(direct_logit_heads)


#         correct_obj_attn_score_sum = correct_obj_attn_score_sum / inputs["labels"].size(0)
#         first_token_attn_score_sum = first_token_attn_score_sum / inputs["labels"].size(0)
#         incorrect_obj_attn_score_sum = incorrect_obj_attn_score_sum / inputs["labels"].size(0)

In [73]:
# Token positions of the seven objects in a 7-box prompt: the first at index 2, then every
# 7 tokens (the decoded tokens at these positions are shown a few cells below).
object_pos = list(range(2, 50, 7))

In [77]:
# Decode the final dataset sample explicitly instead of relying on the `inputs`/`batch` loop
# variables leaked from an evaluation cell. With the unshuffled dataloader, this is the last
# item of the last batch.
last_example_ids = dataset[len(dataset) - 1]["input_ids"]
llama_tokenizer.decode(last_example_ids)

' The cross is in Box Z, the machine is in Box K, the drink is in Box M, the magnet is in Box E, the paper is in Box A, the ball is in Box Y, the milk is in Box L. Box A contains the'

In [75]:
# Tokens at object_pos for the final dataset sample (the last item of the last, unshuffled
# batch), taken from the dataset rather than from loop variables leaked by another cell.
[llama_tokenizer.decode(token_idx) for token_idx in dataset[len(dataset) - 1]["input_ids"][object_pos]]

[' cross', ' machine', ' drink', ' magnet', ' paper', ' ball', ' milk']