In [1]:
import json
import random
import torch
from functools import partial
from baukit import TraceDict
from einops import rearrange, einsum
from tqdm import tqdm

from cmap_utils import get_model_and_tokenizer, load_data, eval_model_performance, cmap_in, cmap_out

# Run on the GPU when CUDA is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed Python's and PyTorch's global RNGs with the same fixed value.
seed = 10
torch.manual_seed(seed)
random.seed(seed)

%load_ext autoreload
%autoreload 2

### Loading Models and Tokenizer

In [2]:
# Load the "llama" model together with its tokenizer; the tokenizer is reused
# below to build the dataloader.
llama_model, tokenizer = get_model_and_tokenizer(model_name="llama", device=device)
# The second model's tokenizer is discarded.
# NOTE(review): the variable is called `goat_model` but model_name="naive" is
# requested here — confirm which fine-tuned model the cells below are meant to
# use before comparing against the "Goat" / "Naive" result tables.
goat_model, _ = get_model_and_tokenizer(model_name="naive", device=device)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Loading Data

In [3]:
# Dataset path, relative to the notebook's working directory.
data_file = "../data/dataset.jsonl"
# Request 500 samples in batches of 8, tokenized with the tokenizer loaded above.
dataloader = load_data(tokenizer=tokenizer, data_file=data_file, num_samples=500, batch_size=8)

Length of dataset: 500


### Model Performance

In [4]:
# Unpatched task accuracy of each model, evaluated on the same dataloader.
llama_acc = eval_model_performance(llama_model, dataloader, device)
goat_acc = eval_model_performance(goat_model, dataloader, device)

print(f"LLAMA accuracy: {llama_acc}", f"Goat accuracy: {goat_acc}", sep="\n")

  0%|          | 0/63 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  2%|▏         | 1/63 [00:00<00:32,  1.91it/s]
  3%|▎         | 2/63 [00:00<00:26,  2.28it/s]
  5%|▍         | 3/63 [00:01<00:24,  2.43it/s]
  6%|▋         | 4/63 [00:01<00:23,  2.50it/s]
  8%|▊         | 5/63 [00:02<00:22,  2.55it/s]
 10%|▉         | 6/63 [00:02<00:22,  2.58it/s]
 11%|█         | 7/63 [00:02<00:21,  2.60it/s]
 13%|█▎        | 8/63 [00:03<00:21,  2.61it/s]
 14%|█▍        | 9/63 [00:03<00:20,  2.62it/s]
 16%|█▌        | 10/63 [00:03<00:20,  2.62it/s]
 17%|█▋        | 11/63 [00:04<00:19,  2.63it/s]
 19%|█▉        | 12/63 [00:04<00:19,  2.63it/s]
 21%|██        | 13/63 [00:05<00:19,  2.63it/s]
 22%|██▏       | 14/63 [00:05<00:18,  2.62it/s]
 24%|██▍       | 15/63 [00:05<00:18,  2.62it/s]
 25%|██▌       | 16/63 [00:06<00:17,  2.62it/s]
 27%|██▋       | 17/63 [00:06<00:17,  2.62it/s]
 29%|██▊       | 18/63 [00:06<00:17,  2.62it/s]
 30%|███       | 19/63 [00:07<00:16,  2.62it/s]
 32%|███▏      | 20/63 [00:07<00:16,

LLAMA accuracy: 0.66
Goat accuracy: 0.82





### Loading Model Activations

In [6]:
# Names of the modules to trace: the four attention projections of every
# decoder layer. The lists are flat, ordered layer by layer and, within a
# layer, as k_proj, q_proj, v_proj, o_proj.
llama_modules = [
    f"model.layers.{layer}.self_attn.{proj}"
    for layer in range(llama_model.config.num_hidden_layers)
    for proj in ("k_proj", "q_proj", "v_proj", "o_proj")
]
goat_modules = [
    f"model.layers.{layer}.self_attn.{proj}"
    for layer in range(goat_model.config.num_hidden_layers)
    for proj in ("k_proj", "q_proj", "v_proj", "o_proj")
]

In [7]:
# Per-batch activations of goat_model, stored on the CPU as
# goat_cache[batch_index][llama_module_name] so that hooks installed on
# llama_model can look them up by the name of the module being hooked.
goat_cache = {}

# The two name lists are paired positionally below; fail loudly instead of
# letting zip() silently drop the unmatched tail.
assert len(goat_modules) == len(llama_modules), "goat/llama module lists must align one-to-one"

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), total=len(dataloader), desc="goat_cache"):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(goat_model.device)

        with TraceDict(goat_model, goat_modules, retain_input=True) as cache:
            goat_model(inputs["input_ids"])

        batch_cache = goat_cache.setdefault(bi, {})
        for goat_layer, llama_layer in zip(goat_modules, llama_modules):
            traced = cache[goat_layer]
            # o_proj: keep what goes INTO the projection;
            # q/k/v_proj: keep what comes OUT of the projection.
            is_o_proj = "o_proj" in llama_layer and "o_proj" in goat_layer
            batch_cache[llama_layer] = (traced.input if is_o_proj else traced.output).cpu()

goat_cache: 63it [00:44,  1.40it/s]


In [8]:
# Per-batch activations of llama_model itself, stored on the CPU as
# llama_cache[batch_index][module_name].
llama_cache = {}

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), total=len(dataloader), desc="llama_cache"):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        with TraceDict(llama_model, llama_modules, retain_input=True) as cache:
            llama_model(inputs["input_ids"])

        batch_cache = llama_cache.setdefault(bi, {})
        for llama_layer in llama_modules:
            traced = cache[llama_layer]
            # o_proj: keep what goes INTO the projection;
            # q/k/v_proj: keep what comes OUT of the projection.
            is_o_proj = "o_proj" in llama_layer
            batch_cache[llama_layer] = (traced.input if is_o_proj else traced.output).cpu()

llama_cache: 63it [00:44,  1.41it/s]


### Loading Circuits

In [9]:
# Circuit files written by experiment 1, one JSON document per model.
with open("../experiment_1/results/circuits/llama_circuit.json", "r") as llama_fh:
    llama_circuit = json.loads(llama_fh.read())

with open("../experiment_1/results/circuits/goat_circuit.json", "r") as goat_fh:
    goat_circuit = json.loads(goat_fh.read())

In [10]:
# Head groups of the llama circuit taken from the DCM result files. Only the
# first line of each file is used; it is expected to look like
# "<label>: <JSON>". Splitting with maxsplit=1 keeps the whole JSON payload
# even if the payload itself contains ": ".
with open("../experiment_2/results/DCM/llama_circuit/value_fetcher/object_value/0.01.txt", "r") as f:
    llama_value_fetcher = json.loads(f.readline().split(": ", 1)[1])

with open("../experiment_2/results/DCM/llama_circuit/pos_transmitter/positional/0.01.txt", "r") as f:
    llama_pos_transmitter = json.loads(f.readline().split(": ", 1)[1])

with open("../experiment_2/results/DCM/llama_circuit/pos_detector/positional/0.01.txt", "r") as f:
    llama_pos_detector = json.loads(f.readline().split(": ", 1)[1])

# The structure-reader group comes straight from the experiment-1 circuit file.
llama_struct_reader = llama_circuit["struct_reader"]

print(f"Value Fetcher Heads: {len(llama_value_fetcher)}")
print(f"Heads affecting direct logit heads: {len(llama_pos_transmitter)}")
print(f"Heads at query box token: {len(llama_pos_detector)}")
print(f"Heads at prev query box token: {len(llama_struct_reader)}")

Value Fetcher Heads: 40
Heads affecting direct logit heads: 5
Heads at query box token: 14
Heads at prev query box token: 5


In [11]:
# Head groups of the goat circuit taken from the DCM result files. Only the
# first line of each file is used; it is expected to look like
# "<label>: <JSON>". Splitting with maxsplit=1 keeps the whole JSON payload
# even if the payload itself contains ": ".
with open("../experiment_2/results/DCM/goat_circuit/value_fetcher/object_value/0.01.txt", "r") as f:
    goat_value_fetcher = json.loads(f.readline().split(": ", 1)[1])

with open("../experiment_2/results/DCM/goat_circuit/pos_transmitter/positional/0.01.txt", "r") as f:
    goat_pos_transmitter = json.loads(f.readline().split(": ", 1)[1])

with open("../experiment_2/results/DCM/goat_circuit/pos_detector/positional/0.01.txt", "r") as f:
    goat_pos_detector = json.loads(f.readline().split(": ", 1)[1])

# The structure-reader group comes straight from the experiment-1 circuit file.
goat_struct_reader = goat_circuit["struct_reader"]

print(f"Value Fetcher Heads: {len(goat_value_fetcher)}")
print(f"Heads affecting direct logit heads: {len(goat_pos_transmitter)}")
print(f"Heads at query box token: {len(goat_pos_detector)}")
print(f"Heads at prev query box token: {len(goat_struct_reader)}")

Value Fetcher Heads: 56
Heads affecting direct logit heads: 15
Heads at query box token: 18
Heads at prev query box token: 39


### CMAP (output patching)

In [14]:
# Full circuit (Select group of heads for CMAP accordingly)
pos_heads_dict = {}
# pos_heads_dict[0] = goat_value_fetcher
pos_heads_dict[0] = goat_circuit['pos_transmitter']
# pos_heads_dict[2] = goat_pos_detector
# pos_heads_dict[-1] = goat_circuit['struct_reader']

In [15]:
# Evaluate llama_model while the cmap_out hook edits the outputs of the traced
# attention projections, then report task accuracy.
correct_count, total_count = 0, 0

with torch.no_grad():
    # `bi` is the batch index handed to the hook alongside goat_cache; it must
    # not be reused as a loop variable inside this body.
    for bi, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        with TraceDict(llama_model,
                       llama_modules,
                       retain_input=True,
                       edit_output=partial(
                            cmap_out,
                            model = llama_model,
                            goat_cache = goat_cache,
                            bi = bi,
                            pos_heads_dict = pos_heads_dict,
                            input_tokens = inputs)):
            # NOTE(review): only outputs.logits is read below; output_attentions
            # is kept as in the original run.
            outputs = llama_model(inputs["input_ids"], output_attentions=True)

        # Score each sample at the index of its last prompt token.
        for sample_idx in range(inputs["labels"].size(0)):
            label = inputs["labels"][sample_idx]
            pred = torch.argmax(outputs.logits[sample_idx][inputs["last_token_indices"][sample_idx]])

            if label == pred:
                correct_count += 1
            total_count += 1

        # Release the logits before the next batch to limit GPU memory use.
        del outputs
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

63it [00:28,  2.23it/s]

Task accuracy: 0.8





Output CMAP Results (Goat -> Llama):
- Full Circuit: 0.82
- Value Fetcher: 0.82
- Position Transmitter: 0.78
- Position Detector: 0.62
- Structure Reader: 0.65
- Heads in Group B: 0.81

Output CMAP Results (Naive -> Llama):
- Full Circuit: 0.82
- Value Fetcher: 0.81
- Position Transmitter: 0.74
- Position Detector: 0.52
- Structure Reader: 0.64
- Heads in Group B: 0.8


### CMAP (input patching)

In [421]:
# Select group of heads for CMAP accordingly
pos_heads_dict = {}
# pos_heads_dict[0] = llama_value_fetcher
pos_heads_dict[0] = goat_pos_transmitter

In [424]:
# Evaluate llama_model while the cmap_in hook edits the outputs of the traced
# attention projections, then report task accuracy.
correct_count, total_count = 0, 0

with torch.no_grad():
    # `bi` is the batch index handed to the hook alongside the two caches; it
    # must not be reused as a loop variable inside this body.
    for bi, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        with TraceDict(llama_model,
                       llama_modules,
                       retain_input=True,
                       edit_output=partial(
                            cmap_in,
                            model = llama_model,
                            goat_cache = goat_cache,
                            llama_cache = llama_cache,
                            patching_component = ["q_proj", "k_proj", "v_proj"], #Options: "q_proj" (query), "k_proj" (key), "v_proj" (value)
                            bi = bi,
                            pos_heads_dict = pos_heads_dict,
                            input_tokens = inputs)):
            # NOTE(review): only outputs.logits is read below; output_attentions
            # is kept as in the original run.
            outputs = llama_model(inputs["input_ids"], output_attentions=True)

        # Score each sample at the index of its last prompt token.
        for sample_idx in range(inputs["labels"].size(0)):
            label = inputs["labels"][sample_idx]
            pred = torch.argmax(outputs.logits[sample_idx][inputs["last_token_indices"][sample_idx]])

            if label == pred:
                correct_count += 1
            total_count += 1

        # Release the logits before the next batch to limit GPU memory use.
        del outputs
        torch.cuda.empty_cache()

current_acc = round(correct_count / total_count, 2)
print(f"Task accuracy: {current_acc}")

63it [01:00,  1.05it/s]

Task accuracy: 0.78





Input CMAP Results:
- Value Fetcher:
    - Query: 0.77
    - Key: 0.63
    - Value: 0.69
    - QK: 0.76
    - QKV: 0.77
- Position Transmitter:
    - Query: 0.66
    - Key: 0.65
    - Value: 0.69
    - QK: 0.65
    - QKV: 0.67

### Patching Weights

In [17]:
def patch_weights(inputs=None,
                  output=None,
                  layer=None,
                  bi=None,
                  patching_weight_matrix=None,
                  weight_patching_heads=None,
                  output_patching_heads=None,
                  weight_rel_pos=None,
                  output_rel_pos=None):
    """Hook (used as TraceDict's ``edit_output``) that recomputes a projection's output.

    Two independent edits are supported; whichever applies to ``layer`` replaces
    ``output`` (if both apply, the o_proj edit wins because it runs last):

    1. Weight patching: for modules whose name contains ``patching_weight_matrix``,
       the weight rows belonging to the selected heads are replaced by the goat
       model's rows (base weight + LoRA delta) and the output is recomputed as
       ``inputs @ W.T`` at every sequence position.
    2. Output patching: for ``o_proj`` modules, the selected heads' slice of the
       projection input at one position is replaced by the cached goat
       activation, and the output is recomputed with llama's own ``o_proj`` weight.

    Reads the notebook globals ``llama_model``, ``goat_model`` and ``goat_cache``.

    Args:
        inputs: input of the hooked module; a tuple is unwrapped to its first item.
        output: original output of the hooked module, returned when nothing applies.
        layer: module name; the third dot-separated field is parsed as the layer index.
        bi: batch index used to look up ``goat_cache[bi][layer]``.
        patching_weight_matrix: substring (e.g. ``"v_proj"``) selecting the modules
            whose weights are patched; ``None`` disables weight patching.
        weight_patching_heads: iterable of ``(layer, head)`` pairs for weight
            patching; ``None`` is treated as empty.
        output_patching_heads: iterable of ``(layer, head)`` pairs for output
            patching; ``None`` is treated as empty.
        weight_rel_pos: accepted for backward compatibility; currently unused
            because weight patching is applied at all positions.
        output_rel_pos: position to patch, counted back from the end of the
            sequence (0 is the last token).

    Returns:
        The recomputed output, or ``output`` unchanged if no edit applied.
    """
    if isinstance(inputs, tuple):
        inputs = inputs[0]

    layer_idx = int(layer.split(".")[2])
    weight_heads = [h for l, h in (weight_patching_heads or []) if l == layer_idx]
    output_heads = [h for l, h in (output_patching_heads or []) if l == layer_idx]

    num_heads = llama_model.config.num_attention_heads
    head_size = llama_model.config.hidden_size // num_heads

    if patching_weight_matrix is not None and (patching_weight_matrix in layer) and weight_heads:
        # Cloned because the selected rows are overwritten below.
        llama_w = llama_model.state_dict()[f"{layer}.weight"].clone()

        # Build the state dict once; `+` and `@` already allocate new tensors.
        goat_state = goat_model.state_dict()
        goat_prefix = f"base_model.model.{layer}"
        lora_A = goat_state[f"{goat_prefix}.lora_A.default.weight"]
        lora_B = goat_state[f"{goat_prefix}.lora_B.default.weight"]
        # NOTE(review): the LoRA delta is added without any scaling factor; this
        # matches a merged adapter only if the adapter's scaling is 1 — confirm
        # against the adapter config.
        goat_w = goat_state[f"{goat_prefix}.weight"] + lora_B @ lora_A

        for head_idx in weight_heads:
            rows = slice(head_idx * head_size, (head_idx + 1) * head_size)
            llama_w[rows, :] = goat_w[rows, :]

        # Recomputed for the whole sequence, not just one position.
        output = inputs @ llama_w.T

    if ("o_proj" in layer) and output_heads:
        # Work on a copy so the module's real input tensor is left untouched.
        patched = rearrange(
            inputs,
            "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
            n_heads=num_heads,
        ).clone()
        gcache = rearrange(
            goat_cache[bi][layer],
            "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
            n_heads=num_heads,
        )

        pos = patched.size(1) - output_rel_pos - 1
        for head_idx in output_heads:
            patched[:, pos, head_idx] = gcache[:, pos, head_idx].to(patched.device)

        patched = rearrange(
            patched,
            "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
            n_heads=num_heads,
        )
        w_o = llama_model.state_dict()[f"{layer}.weight"]
        output = einsum(
            patched, w_o, "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model"
        )

    return output

In [18]:
# Evaluate llama_model with goat's v_proj weights swapped in for the
# position-transmitter heads; no o_proj output patching is requested.
correct_count, total_count = 0, 0
llama_model.eval()
with torch.no_grad():
    # One progress bar only: wrapping enumerate() in a second tqdm draws two
    # interleaved bars. `bi` is the batch index handed to the hook; it must not
    # be reused as a loop variable inside this body.
    for bi, inputs in enumerate(tqdm(dataloader)):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        with TraceDict(
            llama_model,
            llama_modules,
            retain_input=True,
            edit_output=partial(
                patch_weights,
                bi=bi,
                patching_weight_matrix="v_proj",
                weight_patching_heads=goat_pos_transmitter,
                output_patching_heads=[],
                weight_rel_pos=0,
                output_rel_pos=2,
            ),
        ):
            outputs = llama_model(inputs["input_ids"])

        # Score each sample at the index of its last prompt token.
        for sample_idx in range(inputs["labels"].size(0)):
            label = inputs["labels"][sample_idx]
            pred = torch.argmax(outputs.logits[sample_idx][inputs["last_token_indices"][sample_idx]])

            if label == pred:
                correct_count += 1
            total_count += 1

        # Release the logits before the next batch to limit GPU memory use.
        del outputs
        torch.cuda.empty_cache()

  0%|          | 0/63 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  2%|▏         | 1/63 [00:00<00:27,  2.23it/s]
  3%|▎         | 2/63 [00:00<00:26,  2.33it/s]
  5%|▍         | 3/63 [00:01<00:25,  2.37it/s]
  6%|▋         | 4/63 [00:01<00:24,  2.39it/s]
  8%|▊         | 5/63 [00:02<00:24,  2.40it/s]
 10%|▉         | 6/63 [00:02<00:23,  2.40it/s]
 11%|█         | 7/63 [00:02<00:23,  2.41it/s]
 13%|█▎        | 8/63 [00:03<00:22,  2.41it/s]
 14%|█▍        | 9/63 [00:03<00:22,  2.41it/s]
 16%|█▌        | 10/63 [00:04<00:21,  2.41it/s]
 17%|█▋        | 11/63 [00:04<00:21,  2.41it/s]
 19%|█▉        | 12/63 [00:05<00:21,  2.42it/s]
 21%|██        | 13/63 [00:05<00:20,  2.42it/s]
 22%|██▏       | 14/63 [00:05<00:20,  2.42it/s]
 24%|██▍       | 15/63 [00:06<00:20,  2.39it/s]
 25%|██▌       | 16/63 [00:06<00:19,  2.40it/s]
 27%|██▋       | 17/63 [00:07<00:19,  2.35it/s]
 29%|██▊       | 18/63 [00:07<00:19,  2.37it/s]
 30%|███       | 19/63 [00:07<00:18,  2.38it/s]
 32%|███▏      | 20/63 [00:08<00:17,

In [19]:
# Task accuracy accumulated by the preceding evaluation loop, to two decimals.
round(correct_count/total_count, 2)

0.65

### Patching inputs to K/Q/V matrices

In [374]:
# Per-batch INPUTS of every traced goat projection (q/k/v/o alike), stored on
# the CPU as goat_input_cache[batch_index][llama_module_name].
goat_input_cache = {}

# The two name lists are paired positionally below; fail loudly instead of
# letting zip() silently drop the unmatched tail.
assert len(goat_modules) == len(llama_modules), "goat/llama module lists must align one-to-one"

with torch.no_grad():
    for bi, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(goat_model.device)

        with TraceDict(goat_model, goat_modules, retain_input=True) as cache:
            goat_model(inputs["input_ids"])

        batch_cache = goat_input_cache.setdefault(bi, {})
        for goat_layer, llama_layer in zip(goat_modules, llama_modules):
            batch_cache[llama_layer] = cache[goat_layer].input.cpu()

63it [00:34,  1.83it/s]


In [476]:
def patch_inputs(inputs=None, output=None, layer=None, bi=None):
    """Hook (used as TraceDict's ``edit_output``) that feeds goat inputs through llama's v_proj.

    For ``v_proj`` modules in layers that contain position-transmitter heads
    (``goat_pos_transmitter``), the output slice of each such head is replaced by
    ``goat_input @ W_head.T``, where ``goat_input`` is the cached input of the
    corresponding goat module and ``W_head`` is llama's own weight rows for that
    head. The replacement covers every sequence position. All other modules
    pass through unchanged.

    Reads the notebook globals ``llama_model``, ``goat_pos_transmitter`` and
    ``goat_input_cache``.

    Args:
        inputs: unused; accepted so the hook signature matches the other hooks.
        output: original output of the hooked module; edited in place for the
            selected heads.
        layer: module name; the third dot-separated field is parsed as the layer index.
        bi: batch index used to look up ``goat_input_cache[bi][layer]``.

    Returns:
        The (possibly edited) module output with its original layout.
    """
    layer_idx = int(layer.split(".")[2])
    position_trans_curr_layer = [h for l, h in goat_pos_transmitter if l == layer_idx]

    num_heads = llama_model.config.num_attention_heads
    head_size = llama_model.config.hidden_size // num_heads

    if ("v_proj" in layer) and len(position_trans_curr_layer) > 0:
        # Move the cached input once, to wherever the output lives, instead of
        # calling .cuda() per head (which also fails on CPU-only machines).
        goat_inp = goat_input_cache[bi][layer].to(output.device)

        # Read-only access, so no clone of the full weight matrix is needed.
        llama_w = llama_model.state_dict()[f"{layer}.weight"]

        # Split by llama's head count, consistent with head_size above.
        output = rearrange(output,
                   "batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
                   n_heads=num_heads)

        for head_idx in position_trans_curr_layer:
            head_start = head_idx * head_size
            head_end = (head_idx + 1) * head_size

            # Output of this v_proj head in llama when fed goat's input.
            res = goat_inp @ llama_w[head_start:head_end, :].T

            # Patched at every sequence position, not only the last token.
            output[:, :, head_idx] = res

        output = rearrange(output,
                       "batch seq_len n_heads d_head -> batch seq_len (n_heads d_head)",
                       n_heads=num_heads)

    return output

In [477]:
# Evaluate llama_model while the patch_inputs hook edits the traced projections.
correct_count, total_count = 0, 0
with torch.no_grad():
    # One progress bar only: wrapping enumerate() in a second tqdm draws two
    # interleaved bars. `bi` is the batch index handed to the hook; it must not
    # be reused as a loop variable inside this body.
    for bi, inputs in enumerate(tqdm(dataloader)):
        # Move every tensor of the batch onto the model's device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama_model.device)

        with TraceDict(
            llama_model,
            llama_modules,
            retain_input=True,
            edit_output=partial(
                patch_inputs,
                bi=bi,
            ),
        ):
            outputs = llama_model(inputs["input_ids"])

        # Score each sample at the index of its last prompt token.
        for sample_idx in range(inputs["labels"].size(0)):
            label = inputs["labels"][sample_idx]
            pred = torch.argmax(outputs.logits[sample_idx][inputs["last_token_indices"][sample_idx]])

            if label == pred:
                correct_count += 1
            total_count += 1

        # Release the logits before the next batch to limit GPU memory use.
        del outputs
        torch.cuda.empty_cache()

  0%|          | 0/63 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  2%|▏         | 1/63 [00:00<00:28,  2.20it/s]
  3%|▎         | 2/63 [00:00<00:26,  2.32it/s]
  5%|▍         | 3/63 [00:01<00:25,  2.37it/s]
  6%|▋         | 4/63 [00:01<00:24,  2.38it/s]
  8%|▊         | 5/63 [00:02<00:24,  2.40it/s]
 10%|▉         | 6/63 [00:02<00:23,  2.40it/s]
 11%|█         | 7/63 [00:02<00:23,  2.41it/s]
 13%|█▎        | 8/63 [00:03<00:22,  2.41it/s]
 14%|█▍        | 9/63 [00:03<00:22,  2.41it/s]
 16%|█▌        | 10/63 [00:04<00:22,  2.41it/s]
 17%|█▋        | 11/63 [00:04<00:21,  2.40it/s]
 19%|█▉        | 12/63 [00:05<00:21,  2.41it/s]
 21%|██        | 13/63 [00:05<00:20,  2.41it/s]
 22%|██▏       | 14/63 [00:05<00:20,  2.41it/s]
 24%|██▍       | 15/63 [00:06<00:19,  2.41it/s]
 25%|██▌       | 16/63 [00:06<00:19,  2.41it/s]
 27%|██▋       | 17/63 [00:07<00:19,  2.40it/s]
 29%|██▊       | 18/63 [00:07<00:18,  2.40it/s]
 30%|███       | 19/63 [00:07<00:18,  2.40it/s]
 32%|███▏      | 20/63 [00:08<00:17,

In [478]:
# Task accuracy accumulated by the preceding evaluation loop, to two decimals.
round(correct_count/total_count, 2)

0.76