In [2]:
import json
import math
import os
import pickle
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from baukit import TraceDict
from datasets import Dataset
from einops import rearrange, einsum
from peft import PeftModel
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer

import pysvelte
import analysis_utils
from plotly_utils import imshow, scatter
from counterfactual_datasets.entity_tracking import *

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fix PyTorch's RNG seed so that repeated runs draw the same random numbers.
torch.manual_seed(42)

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer and the causal language model from the local checkpoint
# directory, then move the model onto `device`.
path = "./llama_7b/"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path).to(device)

# Disabled alternative setup, kept for reference: base LLaMA-7B weights with
# the "goat" LoRA adapter applied through PEFT.
# base_model = "decapoda-research/llama-7b-hf"
# lora_weights = "tiedong/goat-lora-7b"

# tokenizer = LlamaTokenizer.from_pretrained(
#     "hf-internal-testing/llama-tokenizer", padding_side="right"
# )
# model = LlamaForCausalLM.from_pretrained(
#     base_model,
#     load_in_8bit=False,
#     torch_dtype=torch.float32,
#     device_map="auto",
# )
# model = PeftModel.from_pretrained(
#     model,
#     lora_weights,
#     torch_dtype=torch.float32,
#     device_map={"": 0},
# )

# Reuse the end-of-sequence token id as the padding token id
# (presumably because this tokenizer ships without a pad token -- TODO confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards: 100%|██████████| 3/3 [00:10<00:00,  3.65s/it]


## Load Data

In [45]:
# Dataset configuration.
# `num_boxes` selects which sub-directory of the box dataset is read below.
num_boxes = 7
# Number of examples per batch produced by the DataLoader.
batch_size = 8
data_file_path = f"./box_datasets/no_instructions/alternative/Random/{num_boxes}/train.jsonl"
# NOTE(review): `object_file_path` is not referenced by any cell visible in this
# part of the notebook -- confirm whether it is still needed.
object_file_path = "./box_datasets/filtered_objects_with_bnc_frequency.csv"

In [46]:
def load_data(raw_data, batch_size):
    """Wrap positionally-ordered raw columns in an unshuffled torch DataLoader.

    Args:
        raw_data: Indexable whose items 0-4 hold, in this order, the base input
            ids, base last positions, source input ids, source last positions
            and labels.
        batch_size: Number of examples per batch.

    Returns:
        A ``DataLoader`` over a torch-formatted ``Dataset`` that yields batches
        in the original example order (``shuffle=False``).
    """
    column_names = (
        "base_input_ids",
        "base_input_last_pos",
        "source_input_ids",
        "source_input_last_pos",
        "labels",
    )
    # Column i of the dataset is taken from raw_data[i].
    columns = {name: raw_data[position] for position, name in enumerate(column_names)}
    torch_dataset = Dataset.from_dict(columns).with_format("torch")

    return DataLoader(torch_dataset, batch_size=batch_size, shuffle=False)

In [184]:
# Build the raw examples for this experiment variant. The result is consumed
# positionally by `load_data` (items 0-4).
raw_data = add_raw_text_at_start(
    tokenizer=tokenizer,
    num_samples=500,
    data_file=data_file_path,
)

In [185]:
# Batch the raw examples; order is preserved because load_data does not shuffle.
dataloader = load_data(raw_data=raw_data, batch_size=batch_size)

In [186]:
# Sanity check: decode one example (base prompt, source prompt, label) from the
# first batch. Each prompt is cut just after its recorded last position.
data = next(iter(dataloader))
bi = 0
for ids_key, last_pos_key in (
    ("base_input_ids", "base_input_last_pos"),
    ("source_input_ids", "source_input_last_pos"),
):
    prompt_end = data[last_pos_key][bi] + 1
    print(tokenizer.decode(data[ids_key][bi][:prompt_end]))
print(tokenizer.decode(data["labels"][bi]))

 The document is in Box X, the pot is in Box T, the magnet is in Box A, the game is in Box E, the bill is in Box M, the cross is in Box K, the map is in Box D. Box X contains the
 There are three boxes, Box PP, Box BB and Box AA, the plant is in Box D, the fig is in Box E, the brick is in Box K, the radio is in Box M, the book is in Box S, the magnet is in Box C, the rock is in Box Z. Box M contains the
 game


## Activation Patching

In [187]:
# Load the attention heads selected by the mask file. Only the first line is
# needed; it is expected to look like "<label>: <JSON list of [layer, head]>".
# A local name is used for the line so that the notebook-wide `data` variable
# (the inspected batch) is not clobbered by this cell.
with open("./new_masks/llama-7b/heads_affect_direct_logit/positional/0.01.txt", "r") as f:
    first_line = f.readline()

# maxsplit=1 keeps the whole JSON payload intact even if it contains ": ".
positional_info_fetcher_heads = json.loads(first_line.split(": ", 1)[1])

In [188]:
# Names of every attention output projection to hook. The layer count is read
# from the model config (32 for LLaMA-7B) instead of being hard-coded.
modules = [
    f"model.layers.{i}.self_attn.o_proj"
    for i in range(model.config.num_hidden_layers)
]

In [189]:
# Cache, for every batch of source prompts, the input fed to each hooked
# o_proj module. Layout: source_cache[batch_index][module_name] -> CPU tensor.
#
# NOTE(review): the full sequence is stored for every layer and batch, while
# `patching` only reads the last position; slicing here would cut host memory
# substantially if no other consumer needs the earlier positions.
source_cache = {}
for bi, inputs in enumerate(dataloader):
    # Only the source prompts are needed for this pass.
    source_input_ids = inputs["source_input_ids"].to(model.device)

    # Inference only: disable autograd so no graph is kept for the forward pass.
    with torch.no_grad(), TraceDict(model, modules, retain_input=True) as cache:
        model(source_input_ids)

    source_cache[bi] = {
        module: cache[module].input.detach().cpu() for module in modules
    }

In [190]:
def patching(inputs, output, layer, patching_heads, bi):
    """Overwrite selected heads at the last position with source-run activations.

    Intended to be bound with ``functools.partial`` and passed as
    ``edit_output`` to ``TraceDict``. For the hooked ``o_proj`` module named
    ``layer``, the module input is split per head, the chosen heads are
    replaced at the final sequence position by the matching slice of
    ``source_cache[bi][layer]``, and the projection output is recomputed from
    the patched input.

    Args:
        inputs: Input of the hooked module, either a tensor or a tuple whose
            first element is the tensor; laid out as
            (batch, seq_len, n_heads * d_head).
        output: Original module output. It is ignored and replaced.
        layer: Module name of the form ``"model.layers.<idx>.self_attn.o_proj"``.
        patching_heads: Mapping whose values are iterables of
            ``(layer_index, head_index)`` pairs. The keys (relative positions)
            are currently not used: patching always targets position ``-1``.
        bi: Batch index used to look up ``source_cache``.

    Returns:
        The recomputed projection output, shaped (batch, seq_len, d_model).
    """
    if isinstance(inputs, tuple):
        inputs = inputs[0]

    n_heads = model.config.num_attention_heads
    # The layer index is fixed for this call, so parse it once.
    layer_index = int(layer.split(".")[2])

    # NOTE: the per-head view may alias the module's input tensor, so the
    # writes below can modify that tensor in place.
    inputs = rearrange(
        inputs,
        "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
        n_heads=n_heads,
    )
    cache = rearrange(
        source_cache[bi][layer],
        "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
        n_heads=n_heads,
    )

    # Heads of this layer to patch, gathered across all relative positions.
    curr_layer_heads = {
        head
        for heads in patching_heads.values()
        for head_layer, head in heads
        if head_layer == layer_index
    }
    for head in curr_layer_heads:
        # Slice assignment copies across devices (the cache lives on the CPU).
        inputs[:, -1, head] = cache[:, -1, head]

    inputs = rearrange(
        inputs,
        "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
    )

    # Look the weight up directly instead of materialising the whole state
    # dict on every hook call; detach() matches what state_dict() returned.
    w_o = model.get_submodule(layer).weight.detach()
    # Recompute the projection by hand (no bias term is applied here).
    # The CUDA cache is deliberately not flushed inside the hook: that would
    # run once per layer per forward pass and stall the GPU.
    return einsum(
        inputs,
        w_o,
        "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model",
    )

In [191]:
# Heads to patch, keyed by relative position (0 is the only entry used here).
patching_heads = {0: positional_info_fetcher_heads}

In [192]:
# Run the base prompts with the selected heads patched from the source run and
# measure how often the top next-token prediction equals the label.
correct_count, total_count = 0, 0
# tqdm wraps the dataloader itself (not the enumerate object) so that it can
# read the number of batches and display a total / ETA.
for bi, inputs in enumerate(tqdm(dataloader)):
    inputs = {
        k: v.to(model.device) if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }

    # Inference only: disable autograd so no graph is kept for the forward pass.
    with torch.no_grad(), TraceDict(
        model,
        modules,
        retain_input=True,
        edit_output=partial(patching, patching_heads=patching_heads, bi=bi),
    ):
        outputs = model(inputs["base_input_ids"])

    # NOTE(review): predictions are read at sequence position -1 although the
    # dataset also provides `base_input_last_pos`; confirm that position -1 is
    # always the last real (non-padding) token.
    preds = outputs.logits[:, -1].argmax(dim=-1)
    labels = inputs["labels"].reshape(-1)  # one label id per example
    correct_count += (preds == labels).sum().item()
    total_count += preds.size(0)

    del outputs
    torch.cuda.empty_cache()

# Guard against an empty dataloader instead of raising ZeroDivisionError.
accuracy = correct_count / total_count if total_count else float("nan")
print(f"Accuracy: {accuracy}")

0it [00:00, ?it/s]

63it [00:36,  1.74it/s]

Accuracy: 0.252





In [14]:
Positional Info relative to what?
# Raw text at the start - 0.312
# Raw text at the end - 0.36
# Additional tokens between box and object - 0.398

Is the model keeping track of boxes?
# One additional segment at the start - 0.224
# One additional segment at the end - 0.316
# Additional boxes before correct segment - 0.2

Semantic association and token order seem to be important
# Incorrect box segment index - 0.162
# Box object order altered - 0.156
# object is not in the box - 0.154

Commas are irrelevant
# No comma - 0.407
# Commas after objects - 0.386

## Additional desiderata

In [3]:
# Load the saved per-desideratum results.
# NOTE: this rebinds `data`, which earlier cells use for other objects.
with open("additional_desiderata_results.json", "r") as f:
    data = json.load(f)

In [7]:
# Print "<name>: (mean, std)" for each desideratum, rounded to two decimals.
# np.std is the population standard deviation (ddof=0).
for desideratum, values in data.items():
    # Cast to plain floats so the printed text does not depend on how the
    # installed NumPy version renders its scalar types.
    mean = round(float(np.mean(values)), 2)
    std = round(float(np.std(values)), 2)
    print(f"{desideratum}: ({mean}, {std})")

raw_text_start: (34.56, 1.75)
raw_text_end: (33.32, 1.61)
additional_tokens_btw_obj_and_box: (38.24, 1.63)
add_segment_start: (24.6, 1.33)
add_segment_end: (30.98, 2.02)
add_boxes_before_correct_segment: (19.66, 1.2)
incorrect_box_segment_index: (17.1, 1.77)
Box_object_altered_order: (14.18, 1.66)
object_not_in_box: (17.18, 1.4)
no_comma: (37.5, 1.24)
comma_after_object: (36.58, 1.89)
