In [1]:
import json
import random
import torch
from functools import partial
from baukit import TraceDict
from einops import rearrange, einsum
from tqdm import tqdm

from cmap_utils import get_model_and_tokenizer, load_data, eval_model_performance, cmap_in, cmap_out

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed both Python's `random` module and PyTorch with the same fixed value.
seed = 10
random.seed(seed)
torch.manual_seed(seed)

# IPython autoreload in mode 2, so edits to imported modules (e.g. cmap_utils)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Loading Models and Tokenizer

In [2]:
# The tokenizer returned with the Llama model is the one kept and used below;
# the tokenizer returned with the other model(s) is discarded.
llama_model, tokenizer = get_model_and_tokenizer(model_name="llama", device=device)
# NOTE(review): the Goat load is disabled here, but the "CMAP (Goat -> Llama)"
# section below references `goat_model` (accuracy evaluation, module names and
# activation caching). Re-enable this line before running that section,
# otherwise those cells fail with a NameError on a fresh kernel.
# goat_model, _ = get_model_and_tokenizer(model_name="goat", device=device)
float_model, _ = get_model_and_tokenizer(model_name="float", device=device)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Loading Data

In [3]:
# Build the evaluation dataloader: 500 samples served in batches of 8.
data_file = "../data/dataset.jsonl"
dataloader = load_data(
    tokenizer=tokenizer,
    data_file=data_file,
    num_samples=500,
    batch_size=8,
)

Length of dataset: 500


# Loading circuits

In [4]:
def _read_json(path):
    """Return the parsed JSON content of the file at `path`."""
    with open(path, "r") as f:
        return json.load(f)


# Circuits stored by experiment 1, one JSON file per model.
llama_circuit = _read_json("../experiment_1/results/circuits/llama_circuit.json")
goat_circuit = _read_json("../experiment_1/results/circuits/goat_circuit.json")
float_circuit = _read_json("../experiment_1/results/circuits/float_circuit.json")

In [5]:
# DCM result files for the Llama circuit, in the order:
# value fetcher, position transmitter, position detector.
llama_dcm_files = [
    "../experiment_2/results/DCM/llama_circuit/value_fetcher/object_value/0.01.txt",
    "../experiment_2/results/DCM/llama_circuit/pos_transmitter/positional/0.01.txt",
    "../experiment_2/results/DCM/llama_circuit/pos_detector/positional/0.01.txt",
]

# Only the first line of each file is used; its second ": "-separated field
# is parsed as JSON.
llama_dcm_heads = []
for dcm_file in llama_dcm_files:
    with open(dcm_file, "r") as f:
        first_line = f.readlines()[0]
    llama_dcm_heads.append(json.loads(first_line.split(": ")[1]))

llama_value_fetcher, llama_pos_transmitter, llama_pos_detector = llama_dcm_heads

# The structure-reader heads come from the experiment 1 circuit file instead.
llama_struct_reader = llama_circuit["struct_reader"]

print(f"Value Fetcher Heads: {len(llama_value_fetcher)}")
print(f"Heads affecting direct logit heads: {len(llama_pos_transmitter)}")
print(f"Heads at query box token: {len(llama_pos_detector)}")
print(f"Heads at prev query box token: {len(llama_struct_reader)}")

Value Fetcher Heads: 40
Heads affecting direct logit heads: 5
Heads at query box token: 14
Heads at prev query box token: 5


In [6]:
# DCM result files for the Goat circuit, in the order:
# value fetcher, position transmitter, position detector.
goat_dcm_files = [
    "../experiment_2/results/DCM/goat_circuit/value_fetcher/object_value/0.01.txt",
    "../experiment_2/results/DCM/goat_circuit/pos_transmitter/positional/0.01.txt",
    "../experiment_2/results/DCM/goat_circuit/pos_detector/positional/0.01.txt",
]

# Only the first line of each file is used; its second ": "-separated field
# is parsed as JSON.
goat_dcm_heads = []
for dcm_file in goat_dcm_files:
    with open(dcm_file, "r") as f:
        first_line = f.readlines()[0]
    goat_dcm_heads.append(json.loads(first_line.split(": ")[1]))

goat_value_fetcher, goat_pos_transmitter, goat_pos_detector = goat_dcm_heads

# The structure-reader heads come from the experiment 1 circuit file instead.
goat_struct_reader = goat_circuit["struct_reader"]

print(f"Value Fetcher Heads: {len(goat_value_fetcher)}")
print(f"Heads affecting direct logit heads: {len(goat_pos_transmitter)}")
print(f"Heads at query box token: {len(goat_pos_detector)}")
print(f"Heads at prev query box token: {len(goat_struct_reader)}")

Value Fetcher Heads: 56
Heads affecting direct logit heads: 15
Heads at query box token: 18
Heads at prev query box token: 39


In [7]:
# DCM result files for the FLoat circuit, in the order:
# value fetcher, position transmitter, position detector.
float_dcm_files = [
    "../experiment_2/results/DCM/float_circuit/value_fetcher/object_value/0.01.txt",
    "../experiment_2/results/DCM/float_circuit/pos_transmitter/positional/0.01.txt",
    "../experiment_2/results/DCM/float_circuit/pos_detector/positional/0.01.txt",
]

# Only the first line of each file is used; its second ": "-separated field
# is parsed as JSON.
float_dcm_heads = []
for dcm_file in float_dcm_files:
    with open(dcm_file, "r") as f:
        first_line = f.readlines()[0]
    float_dcm_heads.append(json.loads(first_line.split(": ")[1]))

float_value_fetcher, float_pos_transmitter, float_pos_detector = float_dcm_heads

# The structure-reader heads come from the experiment 1 circuit file instead.
float_struct_reader = float_circuit["struct_reader"]

print(f"Value Fetcher Heads: {len(float_value_fetcher)}")
print(f"Heads affecting direct logit heads: {len(float_pos_transmitter)}")
print(f"Heads at query box token: {len(float_pos_detector)}")
print(f"Heads at prev query box token: {len(float_struct_reader)}")

Value Fetcher Heads: 60
Heads affecting direct logit heads: 13
Heads at query box token: 22
Heads at prev query box token: 38


# CMAP (Goat -> Llama)

### Model Performance

In [4]:
# Task accuracy of each model on the shared dataloader, Llama first, then Goat.
llama_acc, goat_acc = [
    eval_model_performance(model, dataloader, device)
    for model in (llama_model, goat_model)
]

print(f"LLAMA accuracy: {llama_acc}")
print(f"Goat accuracy: {goat_acc}")

  0%|          | 0/63 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  2%|▏         | 1/63 [00:00<00:32,  1.90it/s]
  3%|▎         | 2/63 [00:00<00:26,  2.27it/s]
  5%|▍         | 3/63 [00:01<00:24,  2.42it/s]
  6%|▋         | 4/63 [00:01<00:23,  2.50it/s]
  8%|▊         | 5/63 [00:02<00:22,  2.54it/s]
 10%|▉         | 6/63 [00:02<00:22,  2.57it/s]
 11%|█         | 7/63 [00:02<00:21,  2.59it/s]
 13%|█▎        | 8/63 [00:03<00:21,  2.60it/s]
 14%|█▍        | 9/63 [00:03<00:20,  2.61it/s]
 16%|█▌        | 10/63 [00:03<00:20,  2.61it/s]
 17%|█▋        | 11/63 [00:04<00:19,  2.62it/s]
 19%|█▉        | 12/63 [00:04<00:19,  2.62it/s]
 21%|██        | 13/63 [00:05<00:19,  2.62it/s]
 22%|██▏       | 14/63 [00:05<00:18,  2.62it/s]
 24%|██▍       | 15/63 [00:05<00:18,  2.62it/s]
 25%|██▌       | 16/63 [00:06<00:17,  2.62it/s]
 27%|██▋       | 17/63 [00:06<00:17,  2.62it/s]
 29%|██▊       | 18/63 [00:07<00:17,  2.62it/s]
 30%|███       | 19/63 [00:07<00:16,  2.62it/s]
 32%|███▏      | 20/63 [00:07<00:16,

LLAMA accuracy: 0.66
Goat accuracy: 0.82





### Loading Model Activations

In [5]:
# Attention projection modules to trace, listed layer by layer in the order
# k_proj, q_proj, v_proj, o_proj. Both lists share that order because the
# caching cell pairs them position by position with zip().
attn_projections = ("k_proj", "q_proj", "v_proj", "o_proj")

llama_modules = [
    f"model.layers.{layer}.self_attn.{proj}"
    for layer in range(llama_model.config.num_hidden_layers)
    for proj in attn_projections
]

# Goat's module paths carry an extra "base_model.model." prefix.
goat_modules = [
    f"base_model.model.model.layers.{layer}.self_attn.{proj}"
    for layer in range(goat_model.config.num_hidden_layers)
    for proj in attn_projections
]

In [6]:
# goat_cache[batch_index][llama_module_name] -> Goat activation moved to the CPU.
# Entries are keyed by the *Llama* module name paired with each Goat module.
goat_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), desc="goat_cache"):
        # Move every tensor field of the batch onto the model's device.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(goat_model.device)

        with TraceDict(goat_model, goat_modules, retain_input=True) as traced:
            goat_model(inputs["input_ids"])

        for goat_layer, llama_layer in zip(goat_modules, llama_modules):
            # o_proj is cached at its input; k/q/v projections at their output.
            is_o_proj = "o_proj" in llama_layer and "o_proj" in goat_layer
            record = traced[goat_layer]
            activation = record.input if is_o_proj else record.output
            goat_cache.setdefault(bi, {})[llama_layer] = activation.cpu()

goat_cache: 63it [00:53,  1.17it/s]


In [7]:
# llama_cache[batch_index][module_name] -> Llama activation moved to the CPU.
llama_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), desc="llama_cache"):
        # Move every tensor field of the batch onto the model's device.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(llama_model.device)

        with TraceDict(llama_model, llama_modules, retain_input=True) as traced:
            llama_model(inputs["input_ids"])

        for llama_layer in llama_modules:
            # o_proj is cached at its input; k/q/v projections at their output.
            record = traced[llama_layer]
            activation = record.input if "o_proj" in llama_layer else record.output
            llama_cache.setdefault(bi, {})[llama_layer] = activation.cpu()

llama_cache: 63it [00:46,  1.34it/s]


### CMAP (output patching)

In [21]:
# Full circuit (Select group of heads for CMAP accordingly)
# The integer keys are consumed by cmap_out.
# NOTE(review): presumably token-position codes; confirm in cmap_utils.
pos_heads_dict = {
    # Concatenate into a NEW list. The previous `pos_heads_dict[0] = goat_value_fetcher`
    # followed by `+=` extended goat_value_fetcher itself in place, so the
    # value-fetcher list silently gained the position-transmitter heads and
    # grew again on every re-run of this cell.
    0: goat_value_fetcher + goat_pos_transmitter,
    2: goat_pos_detector,
    -1: goat_circuit['struct_reader'],
}

In [22]:
# Run Llama while cmap_out edits the outputs of the traced modules using the
# cached Goat activations, then score the prediction at each sample's last token.
correct_count, total_count = 0, 0

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        # Move every tensor field of the batch onto the model's device.
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        # `finetuned_cache` is the keyword the FLoat -> Llama patching cell in
        # this notebook passes to cmap_out; this cell previously passed
        # `goat_cache=`. NOTE(review): confirm against cmap_out's signature in
        # cmap_utils.
        patch_fn = partial(
            cmap_out,
            model=llama_model,
            finetuned_cache=goat_cache,
            bi=bi,
            pos_heads_dict=pos_heads_dict,
            input_tokens=inputs,
        )

        with TraceDict(llama_model,
                       llama_modules,
                       retain_input=True,
                       edit_output=patch_fn):
            outputs = llama_model(inputs["input_ids"], output_attentions=True)

        # Use a dedicated index here: reusing `bi` would clobber the batch
        # index of the enclosing loop.
        for sample_idx in range(inputs["labels"].size(0)):
            label = inputs["labels"][sample_idx]
            last_token = inputs["last_token_indices"][sample_idx]
            pred = torch.argmax(outputs.logits[sample_idx][last_token])

            if label == pred:
                correct_count += 1
            total_count += 1

        # Release the batch's outputs before the next forward pass.
        del outputs
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

63it [00:33,  1.90it/s]

Task accuracy: 0.82





Output CMAP Results (Goat -> Llama):
- Full Circuit: 0.82
- Value Fetcher: 0.82
- Position Transmitter: 0.78
- Position Detector: 0.62
- Structure Reader: 0.65


# CMAP (FLoat -> Llama)

### Model Performance

In [4]:
# Task accuracy of each model on the shared dataloader, Llama first, then FLoat.
llama_acc, float_acc = [
    eval_model_performance(model, dataloader, device)
    for model in (llama_model, float_model)
]

print(f"LLAMA accuracy: {llama_acc}")
print(f"FLoat accuracy: {float_acc}")

  0%|          | 0/63 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  2%|▏         | 1/63 [00:00<00:32,  1.92it/s]
  3%|▎         | 2/63 [00:00<00:26,  2.28it/s]
  5%|▍         | 3/63 [00:01<00:24,  2.42it/s]
  6%|▋         | 4/63 [00:01<00:23,  2.50it/s]
  8%|▊         | 5/63 [00:02<00:22,  2.54it/s]
 10%|▉         | 6/63 [00:02<00:22,  2.57it/s]
 11%|█         | 7/63 [00:02<00:21,  2.59it/s]
 13%|█▎        | 8/63 [00:03<00:21,  2.60it/s]
 14%|█▍        | 9/63 [00:03<00:20,  2.61it/s]
 16%|█▌        | 10/63 [00:03<00:20,  2.61it/s]
 17%|█▋        | 11/63 [00:04<00:19,  2.61it/s]
 19%|█▉        | 12/63 [00:04<00:19,  2.62it/s]
 21%|██        | 13/63 [00:05<00:19,  2.62it/s]
 22%|██▏       | 14/63 [00:05<00:18,  2.62it/s]
 24%|██▍       | 15/63 [00:05<00:18,  2.62it/s]
 25%|██▌       | 16/63 [00:06<00:17,  2.62it/s]
 27%|██▋       | 17/63 [00:06<00:17,  2.62it/s]
 29%|██▊       | 18/63 [00:07<00:17,  2.62it/s]
 30%|███       | 19/63 [00:07<00:16,  2.62it/s]
 32%|███▏      | 20/63 [00:07<00:16,

LLAMA accuracy: 0.66
FLoat accuracy: 0.82





### Loading Model Activations

In [8]:
# Attention projection modules to trace, listed layer by layer in the order
# k_proj, q_proj, v_proj, o_proj.
llama_modules = [
    f"model.layers.{layer}.self_attn.{proj}"
    for layer in range(llama_model.config.num_hidden_layers)
    for proj in ("k_proj", "q_proj", "v_proj", "o_proj")
]

# FLoat is traced under the same module names; keep an independent copy of the list.
float_modules = list(llama_modules)

In [9]:
# float_cache[batch_index][module_name] -> FLoat activation moved to the CPU.
float_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), desc="float_cache"):
        # Move every tensor field of the batch onto the model's device.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(float_model.device)

        with TraceDict(float_model, float_modules, retain_input=True) as traced:
            float_model(inputs["input_ids"])

        for float_layer in float_modules:
            # o_proj is cached at its input; k/q/v projections at their output.
            record = traced[float_layer]
            activation = record.input if "o_proj" in float_layer else record.output
            float_cache.setdefault(bi, {})[float_layer] = activation.cpu()

float_cache: 63it [00:42,  1.50it/s]


In [10]:
# llama_cache[batch_index][module_name] -> Llama activation moved to the CPU.
llama_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), desc="llama_cache"):
        # Move every tensor field of the batch onto the model's device.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(llama_model.device)

        with TraceDict(llama_model, llama_modules, retain_input=True) as traced:
            llama_model(inputs["input_ids"])

        for llama_layer in llama_modules:
            # o_proj is cached at its input; k/q/v projections at their output.
            record = traced[llama_layer]
            activation = record.input if "o_proj" in llama_layer else record.output
            llama_cache.setdefault(bi, {})[llama_layer] = activation.cpu()

llama_cache: 63it [00:44,  1.43it/s]


### CMAP (output patching)

In [11]:
# Full circuit (Select group of heads for CMAP accordingly)
# The integer keys are consumed by cmap_out.
# NOTE(review): presumably token-position codes; confirm in cmap_utils.
pos_heads_dict = {
    # Concatenate into a NEW list. The previous `pos_heads_dict[0] = float_value_fetcher`
    # followed by `+=` extended float_value_fetcher itself in place, so the
    # value-fetcher list silently gained the position-transmitter heads and
    # grew again on every re-run of this cell.
    0: float_value_fetcher + float_pos_transmitter,
    2: float_pos_detector,
    -1: float_circuit['struct_reader'],
}

In [None]:
# Run Llama while cmap_out edits the outputs of the traced modules using the
# cached FLoat activations, then score the prediction at each sample's last token.
correct_count, total_count = 0, 0

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader)):
        # Move every tensor field of the batch onto the model's device.
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        patch_fn = partial(
            cmap_out,
            model=llama_model,
            finetuned_cache=float_cache,
            bi=bi,
            pos_heads_dict=pos_heads_dict,
            input_tokens=inputs,
        )

        with TraceDict(llama_model,
                       llama_modules,
                       retain_input=True,
                       edit_output=patch_fn):
            outputs = llama_model(inputs["input_ids"], output_attentions=True)

        # Use a dedicated index here: reusing `bi` would clobber the batch
        # index of the enclosing loop.
        for sample_idx in range(inputs["labels"].size(0)):
            label = inputs["labels"][sample_idx]
            last_token = inputs["last_token_indices"][sample_idx]
            pred = torch.argmax(outputs.logits[sample_idx][last_token])

            if label == pred:
                correct_count += 1
            total_count += 1

        # Release the batch's outputs before the next forward pass.
        del outputs
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

28it [00:14,  1.91it/s]

Output CMAP Results (FLoat -> Llama):
- Full Circuit: 0.82
- Value Fetcher: 0.82
- Position Transmitter: 0.74
- Position Detector: 0.56
- Structure Reader: 0.65