In [1]:
import os
import sys

# To import functions from `patching_utils.py` and `plotly_utils.py`,
# the repository directory must be on the system path.
# NOTE: despite its name, `current_dir` is the PARENT of the working directory
# (dirname of getcwd), so this presumably expects the notebook to be launched
# from a direct subdirectory of the repository root.
current_dir = os.path.dirname(os.getcwd())
# Echo the resolved path so a wrong launch directory is easy to spot.
print(current_dir)
# The membership check keeps re-runs of this cell from appending duplicates.
if current_dir not in sys.path:
    sys.path.append(current_dir)

/home/qinyuan/function-induction


### Helpers

In [2]:
import torch as t
from torch import Tensor
import torch.nn.functional as F

from transformer_lens import HookedTransformer

# Globally turn off gradient tracking for everything that runs after this
# line in the kernel.
t.set_grad_enabled(False)

from jaxtyping import Float, Int, Bool, Union

import circuitsvis as cv
from IPython.display import display
from plotly_utils import imshow

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from patching_utils import batched_get_path_patch_to_repr, batched_get_path_patch_to_head, prepare_data_for_fwd

In [4]:
def analyze_path_to_repr(
    receiver_layers: list[int],
    receiver_input: str,
    model: HookedTransformer,
    data: list[dict],
    begin_layer: int,
    batch_size: int
):
    """Run path patching towards a representation and plot the normalised effect.

    All arguments are forwarded unchanged to `batched_get_path_patch_to_repr`,
    which returns the patched, normal and contrast logit differences.

    The returned score is
    `(patched - contrast) / (contrast - normal)`, with dimension 0 indexed by
    layer (presumably dimension 1 is the head index, matching the heatmap
    axes). Rows below `begin_layer` and rows above the highest receiver layer
    are zeroed before plotting.

    Args:
        receiver_layers: Layers whose representation is the patching target.
        receiver_input: Name of the receiving activation (e.g. "resid_post").
        model: Model handed to the patching helper.
        data: Examples handed to the patching helper.
        begin_layer: First layer whose scores are kept.
        batch_size: Batch size handed to the patching helper.

    Returns:
        The masked, normalised score tensor (also shown as a heatmap).
    """
    t.cuda.empty_cache()

    patched, normal, contrast = batched_get_path_patch_to_repr(
        receiver_layers, receiver_input, model, data, begin_layer, batch_size
    )

    numerator = patched - contrast
    denominator = contrast - normal
    results = numerator / denominator

    # Keep only layers in [begin_layer, last receiver layer]; the receiver
    # layer itself stays because the slice starts one past it.
    last_receiver_layer = max(receiver_layers)
    results[:begin_layer, :] = 0.0
    results[last_receiver_layer + 1:, :] = 0.0

    visualize_heatmap(results)
    return results

def analyze_path_to_head(
    receiver_heads: list[tuple[int, int]],
    receiver_input: str,
    model: HookedTransformer,
    data: list[dict],
    begin_layer: int,
    batch_size: int
):
    """Run path patching towards specific attention heads and plot the effect.

    All arguments are forwarded unchanged to `batched_get_path_patch_to_head`,
    which returns the patched, normal and contrast logit differences.

    The returned score is
    `(patched - contrast) / (contrast - normal)`, with dimension 0 indexed by
    layer (presumably dimension 1 is the head index, matching the heatmap
    axes). Rows below `begin_layer` and rows from the highest receiver layer
    onwards are zeroed before plotting.

    Args:
        receiver_heads: Non-empty list of `(layer, head)` pairs to patch into.
        receiver_input: Which head input receives the patch (e.g. "v").
        model: Model handed to the patching helper.
        data: Examples handed to the patching helper.
        begin_layer: First layer whose scores are kept.
        batch_size: Batch size handed to the patching helper.

    Returns:
        The masked, normalised score tensor (also shown as a heatmap).

    Raises:
        ValueError: If `receiver_heads` is empty.
    """
    # Fail before the expensive patching run rather than at `max()` afterwards.
    if not receiver_heads:
        raise ValueError("receiver_heads must contain at least one (layer, head) pair")

    t.cuda.empty_cache()
    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_head(receiver_heads, receiver_input, model, data, begin_layer, batch_size)
    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff - normal_logit_diff)
    results[:begin_layer, :] = 0.0
    # Unlike the representation variant, the receiver layer itself is masked
    # too (slice starts AT end_layer) — presumably because only heads in
    # strictly earlier layers can feed a head's input; confirm with the helper.
    end_layer = max(layer for layer, _ in receiver_heads)
    results[end_layer:, :] = 0.0
    visualize_heatmap(results)
    return results

def visualize_heatmap(results, info=""):
    """Plot path patching scores as a percentage heatmap.

    `results` is transposed and scaled by 100 before plotting, so its first
    dimension ends up on the x axis ("Layer") and its second on the y axis
    ("Head"). `info` is appended to the figure title.
    """
    percent_scores = 100 * results.t()
    plot_title = "Path patching results {}".format(info)
    axis_labels = {"x": "Layer", "y": "Head", "color": "Logit diff. variation"}
    colorbar_options = {"colorbar_ticksuffix": "%"}

    imshow(
        percent_scores,
        title=plot_title,
        labels=axis_labels,
        coloraxis=colorbar_options,
        width=500,
        height=400,
    )

def visualize_heads(heads: list[tuple[int, int]], model, data: list[dict], mode: str = "contrast"):
    """Display the attention patterns of the given heads.

    Only the first four entries of `data` are used. The caches and prompts
    come from `prepare_data_for_fwd(model, data)`; `mode` selects whether the
    contrast or the normal run is shown.

    Each head's pattern is averaged over dimension 0 of the cached pattern
    (presumably the batch dimension), while the token labels come from the
    first prompt only.
    NOTE(review): this looks like it assumes all prompts tokenize to aligned
    positions — confirm against the data format.

    Args:
        heads: `(layer, head)` pairs to display.
        model: Model providing `to_tokens` / `to_str_tokens`, also passed to
            `prepare_data_for_fwd`.
        data: Examples passed to `prepare_data_for_fwd` (truncated to 4).
        mode: Either "contrast" or "normal".

    Raises:
        ValueError: If `mode` is neither "contrast" nor "normal".
    """
    # Validate first: an unknown mode previously ran the forward passes and
    # then crashed with UnboundLocalError on `cache`.
    if mode not in ("contrast", "normal"):
        raise ValueError(f"mode must be 'contrast' or 'normal', got {mode!r}")

    data = data[:4]
    t.cuda.empty_cache()
    normal_input, contrast_input, _, _, normal_cache, contrast_cache, _, _ = prepare_data_for_fwd(model, data)

    if mode == "contrast":
        cache = contrast_cache
        tokens = model.to_tokens(contrast_input)[0]
    else:  # mode == "normal", guaranteed by the check above
        cache = normal_cache
        tokens = model.to_tokens(normal_input)[0]

    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack([
        cache["pattern", layer][:, head].mean(0)
            for layer, head in heads
    ])

    display(cv.attention.attention_patterns(
        attention = attn_patterns_for_important_heads,
        tokens = model.to_str_tokens(tokens),
        attention_head_names = [f"{layer}.{head}" for layer, head in heads],
    ))

### Loading model and data

In [5]:
# NOTE(review): the GPU index is hard-coded; on a machine with fewer than
# three CUDA devices "cuda:2" will not exist — adjust before running.
device = "cuda:2"
model_name = "meta-llama/Meta-Llama-3-8B"
model = HookedTransformer.from_pretrained(model_name, device=device)
# Presumably expands the grouped key/value heads so each attention head gets
# its own k/v activations for per-head patching — confirm in the
# TransformerLens documentation.
model.set_ungroup_grouped_query_attention(True)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  4.05it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer


In [6]:
# Load the addition dataset selected by the settings below.
from data_utils import process_dataset, read_jsonl

setting = "setting1"
nmax = 999
offset = 1
n_icl_examples = 4

# Relative path: resolved against the notebook's working directory.
filename = f"../data/addition/{setting}/addition_nmax{nmax}_offset{offset}.jsonl"
data = read_jsonl(filename)

# Keep only the first 4 processed examples for the demo.
processed_data = process_dataset(data, n_icl_examples=n_icl_examples, offset=offset)[:4]

### Path Patching Demo

In [7]:
# Path patching into `resid_post` at layer 31, keeping scores from layer 0 on.
# Slow: the recorded run took about 5.7 minutes for this single batch.
results_resid_post_31 = analyze_path_to_repr([31], "resid_post", model, processed_data, begin_layer=0, batch_size=4)

batch: 100%|██████████| 1/1 [05:42<00:00, 342.70s/it]


In [8]:
# Path patching into the "v" input of head (layer 29, head 11), keeping
# scores from layer 0 on. Slow: the recorded run took about 5.4 minutes.
results_h29_11_v = analyze_path_to_head([(29, 11)], "v", model, processed_data, begin_layer=0, batch_size=4)

batch:   0%|          | 0/1 [00:00<?, ?it/s]

layer: 100%|██████████| 30/30 [05:21<00:00, 10.72s/it]
batch: 100%|██████████| 1/1 [05:21<00:00, 321.98s/it]


In [9]:
# Path patching into the "v" input of head (layer 26, head 2), keeping
# scores from layer 0 on. Slow: the recorded run took about 4.8 minutes.
results_h26_2_v = analyze_path_to_head([(26, 2)], "v", model, processed_data, begin_layer=0, batch_size=4)

batch:   0%|          | 0/1 [00:00<?, ?it/s]

layer: 100%|██████████| 27/27 [04:49<00:00, 10.71s/it]
batch: 100%|██████████| 1/1 [04:49<00:00, 289.55s/it]
