In [1]:
import os
import sys

# Helper modules such as `patching_utils.py` and `plotly_utils.py` live in the
# repository root, which is the parent of the directory this notebook runs from.
# Despite its name, `current_dir` therefore holds that *parent* directory.
current_dir = os.path.dirname(os.getcwd())
print(current_dir)

# Guard against duplicate sys.path entries when the cell is re-run.
if not (current_dir in sys.path):
    sys.path.append(current_dir)

/home/qinyuan/function-induction


In [2]:
import torch as t
from torch import Tensor
import torch.nn.functional as F

from transformer_lens import HookedTransformer

# Nothing in this notebook needs gradients (it only runs forward passes and
# reads cached activations), so turn autograd off globally for speed and memory.
t.set_grad_enabled(False)

from jaxtyping import Float, Int, Bool, Union

import circuitsvis as cv
from IPython.display import display
# `plotly_utils` is a repo-local module, importable because the first cell put
# the repository root on sys.path.
from plotly_utils import imshow

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from patching_utils import batched_get_path_patch_to_repr, batched_get_path_patch_to_head, prepare_data_for_fwd

In [4]:
def analyze_path_to_repr(
    receiver_layers: list[int],
    receiver_input: str,
    model: HookedTransformer,
    data: list[dict],
    begin_layer: int,
    batch_size: int
):
    """Path-patch from every attention head into residual-stream receivers.

    Calls `batched_get_path_patch_to_repr` with the `receiver_input` activation
    of each layer in `receiver_layers` as receivers. It converts the three
    returned logit differences into a relative change, plots that as a heatmap
    and returns it.

    Args:
        receiver_layers: Layers whose `receiver_input` activation receives the
            patched paths. Must be non-empty.
        receiver_input: Receiver activation name, e.g. "resid_post".
        model: Model to patch.
        data: Processed examples passed through to the patching helper.
        begin_layer: Sender layers before this index are zeroed in the output.
        batch_size: Batch size used by the patching helper.

    Returns:
        Tensor indexed [layer, head] holding
        (patched - contrast) / (contrast - normal). Rows for sender layers
        outside [begin_layer, max(receiver_layers)] are set to 0.

    Raises:
        ValueError: If `receiver_layers` is empty.
    """
    if not receiver_layers:
        # Fail before the (multi-minute) patching run instead of at `max()` below.
        raise ValueError("receiver_layers must contain at least one layer")

    t.cuda.empty_cache()
    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_repr(
        receiver_layers, receiver_input, model, data, begin_layer, batch_size
    )
    # Relative change of the logit difference caused by patching each sender head.
    # NOTE(review): no guard against contrast == normal (division by zero);
    # presumably the dataset always separates the two runs. Confirm.
    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff - normal_logit_diff)

    # Exclude sender layers before the analysed window.
    results[:begin_layer, :] = 0.0
    end_layer = max(receiver_layers)
    # For resid_post receivers, heads in the last receiver layer still write
    # into that layer's residual stream, so only layers strictly after it are zeroed.
    results[end_layer + 1:, :] = 0.0

    visualize_heatmap(results, info=f"({receiver_input} of layer(s) {receiver_layers})")
    return results

def analyze_path_to_head(
    receiver_heads: list[tuple[int, int]],
    receiver_input: str,
    model: HookedTransformer,
    data: list[dict],
    begin_layer: int,
    batch_size: int
):
    """Path-patch from every attention head into the inputs of specific heads.

    Calls `batched_get_path_patch_to_head` with the `receiver_input` (e.g. "v")
    of each receiver head. It converts the three returned logit differences
    into a relative change, plots that as a heatmap and returns it.

    Args:
        receiver_heads: (layer, head) pairs receiving the patched paths.
            Must be non-empty.
        receiver_input: Receiver input name, e.g. "q", "k" or "v".
        model: Model to patch.
        data: Processed examples passed through to the patching helper.
        begin_layer: Sender layers before this index are zeroed in the output.
        batch_size: Batch size used by the patching helper.

    Returns:
        Tensor indexed [layer, head] holding
        (patched - contrast) / (contrast - normal). Rows for sender layers
        before `begin_layer`, or at/after the last receiver layer, are set to 0.

    Raises:
        ValueError: If `receiver_heads` is empty.
    """
    if not receiver_heads:
        # Fail before the (multi-minute) patching run instead of at `max()` below.
        raise ValueError("receiver_heads must contain at least one (layer, head) pair")

    t.cuda.empty_cache()
    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_head(
        receiver_heads, receiver_input, model, data, begin_layer, batch_size
    )
    # Relative change of the logit difference caused by patching each sender head.
    # NOTE(review): no guard against contrast == normal (division by zero). Confirm.
    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff - normal_logit_diff)

    # Exclude sender layers before the analysed window.
    results[:begin_layer, :] = 0.0
    end_layer = max(head[0] for head in receiver_heads)
    # Unlike the residual-stream case, heads in the receiver's own layer run in
    # parallel with it and cannot feed its input, so zero from end_layer onwards.
    results[end_layer:, :] = 0.0

    receiver_names = ", ".join(f"{head[0]}.{head[1]}" for head in receiver_heads)
    visualize_heatmap(results, info=f"({receiver_input} of head(s) {receiver_names})")
    return results

def visualize_heatmap(results, info=""):
    """Show a [layer, head] tensor of relative changes as a percentage heatmap.

    The tensor is transposed so layers run along the x-axis and heads along
    the y-axis. `info` is appended to the plot title.
    """
    title = f"Path patching results {info}"
    as_percent = 100 * results.t()
    axis_labels = {"x": "Layer", "y": "Head", "color": "Logit diff. variation"}
    imshow(
        as_percent,
        title=title,
        labels=axis_labels,
        coloraxis={"colorbar_ticksuffix": "%"},
        width=800,
        height=400,
    )

def visualize_heads(heads, model, data, mode="contrast", n_examples=4):
    """Render batch-averaged attention patterns of selected heads with circuitsvis.

    Args:
        heads: Iterable of (layer, head) pairs to display.
        model: HookedTransformer used to run the examples.
        data: Processed examples. Only the first `n_examples` are used.
        mode: "contrast" to show patterns on the contrast prompts, or "normal"
            to show them on the normal prompts.
        n_examples: Number of examples to average patterns over (default 4,
            the value previously hard-coded).

    Raises:
        ValueError: If `mode` is neither "contrast" nor "normal". The check
            runs before any forward pass.
    """
    if mode not in ("contrast", "normal"):
        raise ValueError(f"mode must be 'contrast' or 'normal', got {mode!r}")

    # `heads` is iterated twice below; materialize it so generators also work.
    heads = list(heads)
    data = data[:n_examples]
    t.cuda.empty_cache()
    normal_input, contrast_input, _, _, normal_cache, contrast_cache, _, _ = prepare_data_for_fwd(model, data)

    if mode == "contrast":
        cache, prompts = contrast_cache, contrast_input
    else:
        cache, prompts = normal_cache, normal_input
    # Token labels come from the first prompt only, while patterns are averaged
    # over all examples. NOTE(review): this assumes every prompt shares one
    # token layout. Confirm against prepare_data_for_fwd.
    tokens = model.to_tokens(prompts)[0]

    # One [q, k] pattern per requested head: select the head (dim 1) and
    # average over dim 0 (the example dimension).
    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack([
        cache["pattern", layer][:, head].mean(0)
        for layer, head in heads
    ])

    display(cv.attention.attention_patterns(
        attention=attn_patterns_for_important_heads,
        tokens=model.to_str_tokens(tokens),
        attention_head_names=[f"{layer}.{head}" for layer, head in heads],
    ))

In [5]:
# Llama-2-7B loaded through TransformerLens onto the first GPU.
model_name = "meta-llama/Llama-2-7b-hf"
device = "cuda:0"
model = HookedTransformer.from_pretrained(model_name, device=device)
# NOTE(review): presumably expands grouped K/V heads to one per query head so
# per-head hooks line up. This is likely a no-op for Llama-2-7B, which is not a
# grouped-query-attention model. Confirm.
model.set_ungroup_grouped_query_attention(True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]


Loaded pretrained model meta-llama/Llama-2-7b-hf into HookedTransformer


In [6]:
# Load the addition dataset and turn it into few-shot prompts.
from data_utils import process_dataset, read_jsonl

setting = "setting1"
nmax = 9
offset = 1
n_icl_examples = 4
filename = f"../data/addition/{setting}/addition_nmax{nmax}_offset{offset}.jsonl"
data = read_jsonl(filename)

# Keep only the first 4 processed examples so the demo runs in minutes.
processed_data = process_dataset(data, n_icl_examples=n_icl_examples, offset=offset)[:4]

In [9]:
# Path patching from every head into the layer-31 resid_post (about 5 min).
results_resid_post_31 = analyze_path_to_repr(
    receiver_layers=[31],
    receiver_input="resid_post",
    model=model,
    data=processed_data,
    begin_layer=0,
    batch_size=4,
)

batch: 100%|██████████| 1/1 [04:44<00:00, 284.34s/it]


In [10]:
# Bucket heads by the size of their relative effect |r| on the layer-31
# resid_post: head_list_1 holds heads with |r| > 5%, and head_list_2 holds the
# remaining heads with |r| > 2%.
head_list_1 = [
    tuple(layer_head)
    for layer_head in (results_resid_post_31.abs() > 0.05).nonzero().tolist()
]
print(f"rel_diff > 0.05:\t {sorted(head_list_1, reverse=True)}")

head_list_2 = [
    tuple(layer_head)
    for layer_head in (results_resid_post_31.abs() > 0.02).nonzero().tolist()
    if tuple(layer_head) not in head_list_1
]
print(f"0.05 > rel_diff > 0.02:\t {sorted(head_list_2, reverse=True)}")

rel_diff > 0.05:	 [(31, 30), (31, 10), (31, 4), (29, 26), (29, 16)]
0.05 > rel_diff > 0.02:	 [(31, 28), (30, 26), (29, 2), (18, 30), (14, 1), (11, 2)]


In [7]:
# Head lists reported by llama_2_7b_offline.py, which averages over 100 examples
# rather than the 4 used here, so they are less noisy. Running this cell
# replaces any head_list_1 / head_list_2 computed from the 4-example run.
head_list_1 = [(31, 30), (31, 10), (31, 4), (29, 26), (29, 16)]
head_list_2 = [(31, 28), (30, 26), (30, 3), (16, 24)]
print("rel_diff > 0.05:\t", sorted(head_list_1, reverse=True))
print("0.05 > rel_diff > 0.02:\t", sorted(head_list_2, reverse=True))

rel_diff > 0.05:	 [(31, 30), (31, 10), (31, 4), (29, 26), (29, 16)]
0.05 > rel_diff > 0.02:	 [(31, 28), (30, 26), (30, 3), (16, 24)]


In [13]:
results_h31_10_v = analyze_path_to_head(
    receiver_heads=[(31, 10)],
    receiver_input="v",
    model=model,
    data=processed_data,
    begin_layer=0,
    batch_size=4,
)

layer: 100%|██████████| 32/32 [04:41<00:00,  8.80s/it]
batch: 100%|██████████| 1/1 [04:41<00:00, 281.79s/it]


In [14]:
results_h29_16_v = analyze_path_to_head(
    receiver_heads=[(29, 16)],
    receiver_input="v",
    model=model,
    data=processed_data,
    begin_layer=0,
    batch_size=4,
)

layer: 100%|██████████| 30/30 [04:23<00:00,  8.77s/it]
batch: 100%|██████████| 1/1 [04:23<00:00, 263.36s/it]
