In [3]:
import json
import random
import sys
from collections import defaultdict
from functools import partial

import torch
from datasets import Dataset
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import LlamaTokenizer, LlamaForCausalLM

sys.path.append("../")
from data.data_utils import *

from pp_utils import eval_circuit_performance, get_mean_activations

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed Python's and torch's random number generators with the same value.
# torch.manual_seed stays last so the cell's displayed value is unchanged.
seed = 5
random.seed(seed)
torch.manual_seed(seed)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f747c52c3f0>

## Loading Evaluation Data

In [4]:
# Tokenizer shared by every model evaluated below; sequences are padded on the right.
tokenizer_name = "hf-internal-testing/llama-tokenizer"
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, padding_side="right")

# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [5]:
# Source file for the evaluation examples and the batch size used to iterate them.
data_file = "../data/dataset.jsonl"
batch_size = 50

# Project sampler (star-imported from data.data_utils). Its result is indexed
# below as (input_ids, last_token_indices, labels).
raw_data = entity_tracking_example_sampler(
    architecture="LlamaForCausalLM",
    data_file=data_file,
    num_samples=500,
    tokenizer=tokenizer,
)

# Build the dataset column by column, in the order the sampler returns them.
column_names = ("input_ids", "last_token_indices", "labels")
dataset = Dataset.from_dict(
    {name: raw_data[position] for position, name in enumerate(column_names)}
)
dataset = dataset.with_format("torch")

print(f"Length of dataset: {len(dataset)}")

dataloader = DataLoader(dataset, batch_size=batch_size)

Length of dataset: 500


## Loading Circuits

In [6]:
# Each JSON file is indexed below by head-group name
# ('value_fetcher', 'pos_transmitter', 'pos_detector', 'struct_reader').
with open("./results/circuits/llama_circuit.json") as llama_file:
    llama_circuit = json.load(llama_file)

with open("./results/circuits/goat_circuit.json") as goat_file:
    goat_circuit = json.load(goat_file)

In [7]:
# Report the size of every head group in the Llama circuit, then the overall total.
print("Llama Circuit")
llama_total_heads = 0
for title, group in (
    ("Value Fetcher", "value_fetcher"),
    ("Position Transmitter", "pos_transmitter"),
    ("Position Detector", "pos_detector"),
    ("Structure Reader", "struct_reader"),
):
    group_size = len(llama_circuit[group])
    llama_total_heads += group_size
    print(f"{title} Heads: {group_size}")
print(f"Total Heads: {llama_total_heads}")

Llama Circuit
Value Fetcher Heads: 40
Position Transmitter Heads: 7
Position Detector Heads: 20
Structure Reader Heads: 5
Total Heads: 72


In [8]:
# Report the size of every head group in the Goat circuit, then the overall total.
print("GOAT CIRCUIT")
goat_total_heads = 0
for title, group in (
    ("Value Fetcher", "value_fetcher"),
    ("Position Transmitter", "pos_transmitter"),
    ("Position Detector", "pos_detector"),
    ("Structure Reader", "struct_reader"),
):
    group_size = len(goat_circuit[group])
    goat_total_heads += group_size
    print(f"{title} Heads: {group_size}")
print(f"Total Heads: {goat_total_heads}")

GOAT CIRCUIT
Value Fetcher Heads: 68
Position Transmitter Heads: 28
Position Detector Heads: 40
Structure Reader Heads: 39
Total Heads: 175


## Helper Methods

In [9]:
def get_circuit(model, circuit_heads):
    """Group the circuit's attention heads by bucket key and ``o_proj`` module name.

    Args:
        model: Object exposing ``config.architectures``. When the first entry is
            ``"LlamaForCausalLM"`` module names start with ``model.layers``;
            otherwise they start with ``base_model.model.model.layers``
            (presumably the PEFT-wrapped layout -- confirm against the caller).
        circuit_heads: Mapping with the keys ``'value_fetcher'``,
            ``'pos_transmitter'``, ``'pos_detector'`` and ``'struct_reader'``,
            each an iterable of ``(layer_idx, head)`` pairs.

    Returns:
        dict: Keys ``0``, ``2`` and ``-1``, each a ``defaultdict(list)`` mapping
        ``"<prefix>.<layer_idx>.self_attn.o_proj"`` to the list of heads in that
        layer. Value fetcher and position transmitter heads share key ``0``,
        position detector heads use ``2`` and structure reader heads use ``-1``.
        NOTE(review): the keys look like relative token positions consumed by
        ``eval_circuit_performance`` -- confirm in pp_utils.

    Raises:
        KeyError: If ``circuit_heads`` lacks one of the four head groups.
    """
    # The module-name prefix depends only on the model, so resolve it once
    # instead of re-checking the architecture for every head.
    if model.config.architectures[0] == "LlamaForCausalLM":
        layer_template = "model.layers.{}.self_attn.o_proj"
    else:
        layer_template = "base_model.model.model.layers.{}.self_attn.o_proj"

    circuit_components = {
        0: defaultdict(list),
        2: defaultdict(list),
        -1: defaultdict(list),
    }

    # (head group, bucket key) in the original processing order, so heads
    # within a bucket keep the same order as before.
    group_to_key = (
        ("value_fetcher", 0),
        ("pos_transmitter", 0),
        ("pos_detector", 2),
        ("struct_reader", -1),
    )
    for group, key in group_to_key:
        for layer_idx, head in circuit_heads[group]:
            circuit_components[key][layer_template.format(layer_idx)].append(head)

    return circuit_components

In [10]:
def eval_model_performance(model, dataloader):
    """Compute the model's accuracy over ``dataloader``, rounded to two decimals.

    For every example the prediction is the argmax of the logits at that
    example's ``last_token_indices`` position; it counts as correct when it
    equals the example's ``labels`` entry.

    Args:
        model: Callable as ``model(input_ids=...)`` returning an object with a
            ``.logits`` tensor indexed as (example, position, vocabulary).
            Must also expose ``.eval()`` and ``.device``.
        dataloader: Iterable of dict batches holding the tensors
            ``"input_ids"``, ``"last_token_indices"`` and ``"labels"``, with one
            index and one label per example.

    Returns:
        float: ``correct / total`` rounded to 2 decimals, or ``0.0`` when the
        dataloader yields no examples.
    """
    total_count = 0
    correct_count = 0
    model.eval()
    with torch.no_grad():
        # Wrap the loader once: the previous nested tqdm printed a second,
        # redundant "Nit" progress line.
        for batch in tqdm(dataloader):
            batch = {
                key: value.to(model.device) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }

            logits = model(input_ids=batch["input_ids"]).logits

            # Keep everything on the logits' device so the comparison below
            # never mixes devices.
            labels = batch["labels"].reshape(-1).to(logits.device)
            last_idx = batch["last_token_indices"].reshape(-1).long().to(logits.device)

            # Pick, for each example, the logits at its last-token position
            # and compare all predictions against the labels in one go.
            rows = torch.arange(labels.size(0), device=logits.device)
            preds = logits[rows, last_idx].argmax(dim=-1)

            correct_count += int((preds == labels).sum().item())
            total_count += labels.size(0)

            # Release the (large) logits before the next forward pass.
            del logits

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # An empty dataloader used to fail on `del outputs` / division by zero.
    if total_count == 0:
        return 0.0

    return round(correct_count / total_count, 2)

## Model and Circuit Performance: Llama-7B

In [36]:
# Checkpoint directory for the Llama-7B weights (absolute, machine-specific path).
path = "/home/local_nikhil/Projects/llama_weights/7B"

# Drop a model left over from a previously run section, if there is one, and
# release the cached GPU memory before loading the next model.
try:
    del model
except NameError:
    pass
else:
    torch.cuda.empty_cache()

model = LlamaForCausalLM.from_pretrained(path).to(device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.21s/it]


In [37]:
# Map the circuit's heads to the module names of the currently loaded model.
# NOTE(review): this Llama-7B section passes `goat_circuit`, while the Vicuna
# and naive-finetuned sections pass `llama_circuit` -- confirm that evaluating
# the Goat circuit on Llama-7B is intended here.
circuit_components = get_circuit(model, goat_circuit)

# Project helper from pp_utils, called with 500 samples in batches of 50.
# Presumably returns per-module mean activations plus the list of hooked
# modules, used as the ablation baseline below -- TODO confirm in pp_utils.
mean_activations, modules = get_mean_activations(
    model, tokenizer, data_file, num_samples=500, batch_size=50
)

Computing mean activations...


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:21<00:00,  2.11s/it]


In [38]:
# Accuracy of the full model on the evaluation set.
model_acc = eval_model_performance(model, dataloader)
print(f"Model Performance {model_acc}")

# Accuracy from the pp_utils helper, given the circuit heads and the mean
# activations (presumably non-circuit heads are mean-ablated -- confirm in pp_utils).
circuit_acc = eval_circuit_performance(model, dataloader, modules, circuit_components, mean_activations)
print(f"Circuit Performance {circuit_acc}")

# Faithfulness is circuit accuracy relative to model accuracy. It is undefined
# (NaN) when the model accuracy is 0, which previously raised ZeroDivisionError.
faithfulness = round(circuit_acc / model_acc, 2) if model_acc else float("nan")
print(f"Faithfulness: {faithfulness}")

100%|██████████| 10/10 [00:21<00:00,  2.12s/it]
10it [00:21,  2.12s/it]


Model Performance 0.66


100%|██████████| 10/10 [00:42<00:00,  4.25s/it]

Circuit Performance 0.77
Faithfulness: 1.17





## Model and Circuit Performance: Vicuna-7B

In [13]:
# Model identifier passed to `from_pretrained` for the Vicuna-7B weights.
path = "AlekseyKorshuk/vicuna-7b"

# Drop a model left over from a previously run section, if there is one, and
# release the cached GPU memory before loading the next model.
try:
    del model
except NameError:
    pass
else:
    torch.cuda.empty_cache()

model = LlamaForCausalLM.from_pretrained(path).to(device)

Loading checkpoint shards: 100%|██████████| 14/14 [00:06<00:00,  2.12it/s]


In [14]:
# Map the Llama circuit's heads to the module names of the currently loaded model.
circuit_components = get_circuit(model, llama_circuit)

# Project helper from pp_utils, called with 500 samples in batches of 50.
# Presumably returns per-module mean activations plus the list of hooked
# modules, used as the ablation baseline below -- TODO confirm in pp_utils.
mean_activations, modules = get_mean_activations(
    model, tokenizer, data_file, num_samples=500, batch_size=50
)

Computing mean activations...


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:21<00:00,  2.13s/it]


In [15]:
# Accuracy of the full model on the evaluation set.
model_acc = eval_model_performance(model, dataloader)
print(f"Model Performance {model_acc}")

# Accuracy from the pp_utils helper, given the circuit heads and the mean
# activations (presumably non-circuit heads are mean-ablated -- confirm in pp_utils).
circuit_acc = eval_circuit_performance(model, dataloader, modules, circuit_components, mean_activations)
print(f"Circuit Performance {circuit_acc}")

# Faithfulness is circuit accuracy relative to model accuracy. It is undefined
# (NaN) when the model accuracy is 0, which previously raised ZeroDivisionError.
faithfulness = round(circuit_acc / model_acc, 2) if model_acc else float("nan")
print(f"Faithfulness: {faithfulness}")

100%|██████████| 10/10 [00:21<00:00,  2.15s/it]
10it [00:21,  2.15s/it]


Model Performance 0.67


100%|██████████| 10/10 [00:43<00:00,  4.38s/it]

Circuit Performance 0.65
Faithfulness: 0.97





## Model and Circuit Performance: Goat-7B

In [27]:
# Base checkpoint and the LoRA adapter applied on top of it.
base_model = "decapoda-research/llama-7b-hf"
lora_weights = "tiedong/goat-lora-7b"

# Drop a model left over from a previously run section, if there is one, and
# release the cached GPU memory before loading the next model.
try:
    del model
except NameError:
    pass
else:
    torch.cuda.empty_cache()

# Load the base model in full precision, letting `device_map="auto"` place it.
model = LlamaForCausalLM.from_pretrained(
    base_model, load_in_8bit=False, torch_dtype=torch.float32, device_map="auto"
)

# Wrap the base model with the LoRA adapter; the adapter is mapped to device 0.
model = PeftModel.from_pretrained(
    model, lora_weights, torch_dtype=torch.float32, device_map={"": 0}
)

Loading checkpoint shards: 100%|██████████| 33/33 [00:12<00:00,  2.58it/s]


In [67]:
# Map the Goat circuit's heads to the module names of the currently loaded model.
# NOTE(review): get_circuit chooses the module-name prefix from
# model.config.architectures[0]; confirm the PEFT-wrapped model reports a value
# other than "LlamaForCausalLM" so the "base_model.model.model..." names are used.
circuit_components = get_circuit(model, goat_circuit)

# Project helper from pp_utils, called with 500 samples in batches of 50.
# Presumably returns per-module mean activations plus the list of hooked
# modules, used as the ablation baseline below -- TODO confirm in pp_utils.
mean_activations, modules = get_mean_activations(
    model, tokenizer, data_file, num_samples=500, batch_size=50
)

Computing mean activations...


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:22<00:00,  2.24s/it]


In [68]:
# Accuracy of the full model on the evaluation set.
model_acc = eval_model_performance(model, dataloader)
print(f"Model Performance: {model_acc}")

# Accuracy from the pp_utils helper, given the circuit heads and the mean
# activations (presumably non-circuit heads are mean-ablated -- confirm in pp_utils).
circuit_acc = eval_circuit_performance(model, dataloader, modules, circuit_components, mean_activations)
print(f"Circuit Performance: {circuit_acc}")

# Faithfulness is circuit accuracy relative to model accuracy. It is undefined
# (NaN) when the model accuracy is 0, which previously raised ZeroDivisionError.
faithfulness = round(circuit_acc / model_acc, 2) if model_acc else float("nan")
print(f"Faithfulness: {faithfulness}")

100%|██████████| 10/10 [00:22<00:00,  2.24s/it]
10it [00:22,  2.24s/it]


Model Performance: 0.82


100%|██████████| 10/10 [00:45<00:00,  4.51s/it]

Circuit Performance: 0.1
Faithfulness: 0.12





## Model and Circuit Performance: Naive finetuned model

In [17]:
# Checkpoint directory of the naively finetuned model (absolute, machine-specific path).
path = "/media/local_nikhil/disk/weights_naive/possessed-candle-14"

# Drop a model left over from a previously run section, if there is one, and
# release the cached GPU memory before loading the next model.
try:
    del model
except NameError:
    pass
else:
    torch.cuda.empty_cache()

model = LlamaForCausalLM.from_pretrained(path).to(device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.89s/it]


In [34]:
# Map the Llama circuit's heads to the module names of the currently loaded model.
circuit_components = get_circuit(model, llama_circuit)

# Project helper from pp_utils, called with 500 samples in batches of 50.
# Presumably returns per-module mean activations plus the list of hooked
# modules, used as the ablation baseline below -- TODO confirm in pp_utils.
mean_activations, modules = get_mean_activations(
    model, tokenizer, data_file, num_samples=500, batch_size=50
)

Computing mean activations...


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:21<00:00,  2.13s/it]


In [35]:
# Accuracy of the full model on the evaluation set.
model_acc = eval_model_performance(model, dataloader)
print(f"Model Performance {model_acc}")

# Accuracy from the pp_utils helper, given the circuit heads and the mean
# activations (presumably non-circuit heads are mean-ablated -- confirm in pp_utils).
circuit_acc = eval_circuit_performance(model, dataloader, modules, circuit_components, mean_activations)
print(f"Circuit Performance {circuit_acc}")

# Faithfulness is circuit accuracy relative to model accuracy. It is undefined
# (NaN) when the model accuracy is 0, which previously raised ZeroDivisionError.
faithfulness = round(circuit_acc / model_acc, 2) if model_acc else float("nan")
print(f"Faithfulness: {faithfulness}")

100%|██████████| 10/10 [00:21<00:00,  2.13s/it]
10it [00:21,  2.13s/it]


Model Performance 0.82


100%|██████████| 10/10 [00:42<00:00,  4.28s/it]

Circuit Performance 0.72
Faithfulness: 0.88





## Circuit Comparison

In [16]:
head_groups = ['value_fetcher', 'pos_transmitter', 'pos_detector', 'struct_reader']

# For every head group, print:
#   <group>: <#llama heads> | <#goat heads> | <#shared heads> | <precision> | <recall>
# where precision = shared / #llama heads and recall = shared / #goat heads.
for head_group in head_groups:
    llama_heads = llama_circuit[head_group]
    goat_heads = goat_circuit[head_group]

    # JSON yields [layer, head] lists; hash them as tuples so each membership
    # test is a set lookup rather than a scan over a list of lists.
    goat_head_set = {tuple(pair) for pair in goat_heads}
    intersection = sum((l, h) in goat_head_set for l, h in llama_heads)

    # An empty group has no defined ratio; report NaN instead of raising
    # ZeroDivisionError.
    precision = round(intersection/len(llama_heads), 2) if llama_heads else float("nan")
    recall = round(intersection/len(goat_heads), 2) if goat_heads else float("nan")

    print(f"{head_group}: {len(llama_heads)} | {len(goat_heads)} | {intersection} | {precision} | {recall}")

value_fetcher: 40 | 68 | 27 | 0.68 | 0.4
pos_transmitter: 7 | 28 | 6 | 0.86 | 0.21
pos_detector: 20 | 40 | 16 | 0.8 | 0.4
struct_reader: 5 | 39 | 5 | 1.0 | 0.13
