In [1]:
import os
import sys

# to import functions from `patching_utils.py` and `plotly_utils.py`,
# we need to add the repository directory to the system path.
current_dir = os.path.dirname(os.getcwd())
print(current_dir)
if current_dir not in sys.path:
    sys.path.append(current_dir)

/home/qinyuan/function-induction


### Helpers

In [2]:
import torch as t
from torch import Tensor
import torch.nn.functional as F

from transformer_lens import HookedTransformer

t.set_grad_enabled(False)

from jaxtyping import Float, Int, Bool, Union

import circuitsvis as cv
from IPython.display import display
from plotly_utils import imshow

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from patching_utils import batched_get_path_patch_to_repr, batched_get_path_patch_to_head, prepare_data_for_fwd

In [4]:
def analyze_path_to_repr(
    receiver_layers: list[int],
    receiver_input: str,
    model: HookedTransformer,
    data: list[dict],
    begin_layer: int,
    batch_size: int
):
    t.cuda.empty_cache()
    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_repr(receiver_layers, receiver_input, model, data, begin_layer, batch_size)
    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff  - normal_logit_diff)
    results[:begin_layer, :] = 0.0
    end_layer=max(receiver_layers)
    results[end_layer:, :] = 0.0
    visualize_heatmap(results)
    return results

def analyze_path_to_head(
    receiver_heads: list[int],
    receiver_input: str,
    model: HookedTransformer,
    data: list[dict],
    begin_layer: int,
    batch_size: int
):
    t.cuda.empty_cache()
    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_head(receiver_heads, receiver_input, model, data, begin_layer, batch_size)
    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff  - normal_logit_diff)
    results[:begin_layer, :] = 0.0
    end_layer=max([head[0] for head in receiver_heads])
    results[end_layer:, :] = 0.0
    visualize_heatmap(results)
    return results

def visualize_heatmap(results, info=""):
    imshow(
        100 * results.t(),
        title="Path patching results {}".format(info),
        labels={"x":"Layer", "y":"Head", "color": "Logit diff. variation"},
        coloraxis=dict(colorbar_ticksuffix = "%"),
        width=800,
        height=400,
    )

def visualize_heads(heads, model, data, mode="contrast"):
    data = data[:4]
    t.cuda.empty_cache()
    normal_input, contrast_input, _, _, normal_cache, contrast_cache, _, _ = prepare_data_for_fwd(model, data)
    
    if mode == "contrast":
        cache = contrast_cache
        tokens = model.to_tokens(contrast_input)[0]
    elif mode == "normal":
        cache = normal_cache
        tokens = model.to_tokens(normal_input)[0]

    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack([
        cache["pattern", layer][:, head].mean(0)
            for layer, head in heads
    ])
    
    display(cv.attention.attention_patterns(
        attention = attn_patterns_for_important_heads,
        tokens = model.to_str_tokens(tokens),
        attention_head_names = [f"{layer}.{head}" for layer, head in heads],
    ))

### Loading model and data

In [5]:
device = "cuda:0"
model_name = "google/gemma-2-9b"
model = HookedTransformer.from_pretrained(model_name, device=device)
model.set_ungroup_grouped_query_attention(True)

Loading checkpoint shards: 100%|██████████| 8/8 [00:00<00:00, 127.67it/s]


Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [6]:
# Load data
from data_utils import process_dataset, read_jsonl

setting, nmax, offset, n_icl_examples = "setting1", 9, 1, 4
filename = f"../data/addition/{setting}/addition_nmax{nmax}_offset{offset}.jsonl"
data = read_jsonl(filename)
processed_data = process_dataset(data, n_icl_examples=n_icl_examples, offset=offset)

# load 12 examples for demo
processed_data = processed_data[:12]

### Path Patching Demo

#### Patching to output logits

In [9]:
# Estimated run time: 15 min
results_resid_post_41 = analyze_path_to_repr([41], "resid_post", model, processed_data, begin_layer=0, batch_size=4)

batch:   0%|          | 0/3 [00:00<?, ?it/s]

batch: 100%|██████████| 3/3 [15:52<00:00, 317.40s/it]


In [10]:
# Tracking which heads have |r| > 2%

head_list_1 = [tuple(idx) for idx in t.nonzero(t.abs(results_resid_post_41) > 0.05, as_tuple=False).tolist()]
print(f"rel_diff > 0.05:\t {sorted(head_list_1, reverse=True)}")

head_list_2 = [tuple(idx) for idx in t.nonzero(t.abs(results_resid_post_41) > 0.02, as_tuple=False).tolist() if tuple(idx) not in head_list_1]
print(f"0.05 > rel_diff > 0.02:\t {sorted(head_list_2, reverse=True)}")

rel_diff > 0.05:	 [(39, 7), (36, 7)]
0.05 > rel_diff > 0.02:	 [(40, 12), (40, 11), (39, 12), (32, 1), (25, 13)]


The demo above runs on 12 test cases, which is good for preliminary exploration but may be noisy. We repeat the experiments with 100 test cases by running `gemma_2_9b_offline.py`. Below we load and visualize the results, which also corresponds to Figure 2(a) in the figure.

In [7]:
filename = "./results/gemma_2_9b/resid_post_41.pt"
results_resid_post_41 = t.load(filename)
visualize_heatmap(results_resid_post_41)

In [8]:
head_list_1 = [tuple(idx) for idx in t.nonzero(t.abs(results_resid_post_41) > 0.05, as_tuple=False).tolist()]
print(f"rel_diff > 0.05:\t {sorted(head_list_1, reverse=True)}")

head_list_2 = [tuple(idx) for idx in t.nonzero(t.abs(results_resid_post_41) > 0.02, as_tuple=False).tolist() if tuple(idx) not in head_list_1]
print(f"0.05 > rel_diff > 0.02:\t {sorted(head_list_2, reverse=True)}")

rel_diff > 0.05:	 [(41, 4), (39, 7), (36, 7)]
0.05 > rel_diff > 0.02:	 [(41, 5), (40, 12), (40, 11), (39, 12), (32, 6), (32, 1), (25, 13)]


We can visualize the attention pattern of these heads to help our analysis.

In [9]:
# Consolidation Heads
visualize_heads([(41, 4), (41, 5), (40, 11), (40, 12)], model, processed_data)

In [10]:
# Function Induction Heads
visualize_heads([(39, 7), (39, 12), (36, 7), (32, 1), (32, 6), (25, 13)], model, processed_data)

#### Patching to H39.7 v

In [17]:
results_h39_7_v = analyze_path_to_head([(39, 7)], "v", model, processed_data, begin_layer=0, batch_size=4)

layer: 100%|██████████| 40/40 [05:01<00:00,  7.53s/it]
layer: 100%|██████████| 40/40 [05:02<00:00,  7.55s/it]
layer: 100%|██████████| 40/40 [05:02<00:00,  7.56s/it]
batch: 100%|██████████| 3/3 [15:07<00:00, 302.37s/it]


Similarly we can run patching with 100 examples using `gemma_2_9b_offline.py`. Below we load and visualize the results.

In [11]:
filename = "./results/gemma_2_9b/ind_h39_7_v.pt"
results_h39_7_v = t.load(filename)
visualize_heatmap(results_h39_7_v)

In [11]:
# Previous Token Heads (3 of 8 in total)
visualize_heads([(38, 6), (38, 7), (35, 9)], model, processed_data)