In [2]:
!pip install jaxtyping transformer-lens
!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

Collecting jaxtyping
  Downloading jaxtyping-0.2.37-py3-none-any.whl.metadata (6.6 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.11.0-py3-none-any.whl.metadata (12 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping)
  Downloading wadler_lindig-0.1.3-py3-none-any.whl.metadata (17 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer-lens)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting fancy-einsum>=0.0.3 (from transformer-lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.7.1->transformer-lens)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets>=2.7.1->transformer-lens)
  Downloading xxhash

In [3]:
import os
import sys
import gc
import einops
import numpy as np
import circuitsvis as cv
import plotly.express as px
import torch
from rich.table import Table, Column
from jaxtyping import Float, Int, Bool
from typing import Literal, Callable
from rich import print as rprint
from IPython.display import display, HTML
from torch import Tensor

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
from plotly_utils_user import line, imshow
import pandas as pd
from transformer_lens import HookedTransformer, utils, ActivationCache, patching

In [4]:
def _row_to_sample(row):
    """Convert one CSV row into the sample dict used by the rest of the notebook.

    The first four columns are read positionally as: active sentence, passive
    sentence, agent, distractor. Each prompt is one full sentence followed by
    the other sentence with its last whitespace-separated word removed.
    """
    active, passive, agent, distractor = row[0], row[1], row[2], row[3]
    passive_without_last_word = ' '.join(passive.split()[:-1])
    active_without_last_word = ' '.join(active.split()[:-1])
    return dict(
        active=active,
        passive=passive,
        agent=agent,
        distractor=distractor,
        prompt=f"{active} {passive_without_last_word}",
        inverse_prompt=f"{passive} {active_without_last_word}",
        # Answers carry a leading space; the tuple order is swapped for the inverse prompt.
        answer=(f' {agent}', f' {distractor}'),
        inverse_answer=(f' {distractor}', f' {agent}'),
    )


dataset = [
    _row_to_sample(csv_row)
    for csv_row in pd.read_csv(os.path.join("./utils", "final_dataset.csv")).to_numpy()
]

In [5]:
# Load the pretrained checkpoint into a HookedTransformer. Prefer the GPU, but
# fall back to the CPU so this cell does not crash on a machine without CUDA.
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

config.json:   0%|          | 0.00/844 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]



Loaded pretrained model meta-llama/Llama-3.2-3B into HookedTransformer


In [6]:
# Sanity check on the first dataset sample: print the prompt and expected
# answer, then let utils.test_prompt report the model's top predictions.
first_sample = dataset[0]
print(f"Prompt that we test for the single case: {first_sample['prompt']}")
print(f"Answer that we test for the single case: {first_sample['agent']}")
utils.test_prompt(
    first_sample["prompt"],
    f" {first_sample['agent']}",  # leading space, matching how answers are built in `dataset`
    model,
    prepend_bos=True,
)

Prompt that we test for the single case: The engineer built the bridge. The bridge was built by the
Answer that we test for the single case: engineer
Tokenized prompt: ['<|begin_of_text|>', 'The', ' engineer', ' built', ' the', ' bridge', '.', ' The', ' bridge', ' was', ' built', ' by', ' the']
Tokenized answer: [' engineer']


Top 0th token. Logit: 17.15 Prob: 40.44% Token: | engineer|
Top 1th token. Logit: 14.98 Prob:  4.60% Token: | engineers|
Top 2th token. Logit: 14.80 Prob:  3.85% Token: | contractor|
Top 3th token. Logit: 14.38 Prob:  2.53% Token: | people|
Top 4th token. Logit: 14.16 Prob:  2.04% Token: | government|
Top 5th token. Logit: 13.77 Prob:  1.38% Token: | architect|
Top 6th token. Logit: 13.75 Prob:  1.35% Token: | workers|
Top 7th token. Logit: 13.46 Prob:  1.01% Token: | builder|
Top 8th token. Logit: 13.39 Prob:  0.94% Token: | company|
Top 9th token. Logit: 13.13 Prob:  0.72% Token: | construction|


In [7]:
# Run the garbage collector first so tensors kept alive only by reference
# cycles are released, and only then ask the CUDA caching allocator to hand
# back its unused blocks (the reverse order leaves that memory cached).
freed_objects = gc.collect()
torch.cuda.empty_cache()
# Keep the collector's count as the cell output, as before.
freed_objects

21

In [8]:
# Build the prompt batch and the matching (correct, incorrect) answer-token
# pairs. Every kept sample contributes two consecutive rows: the prompt and its
# inverse prompt, so rows 2k and 2k+1 always belong to the same sample.
prompts, answers, answers_tokens_list = [], [], []
for dataset_element in dataset:
    # Tokenise both answers together and transpose so each row holds one
    # (first answer, second answer) pair — presumably shape (tokens_per_answer, 2).
    answer_pair = model.to_tokens(dataset_element["answer"], prepend_bos=False).T
    inverse_answer_pair = model.to_tokens(dataset_element["inverse_answer"], prepend_bos=False).T
    # Skip samples whose answers do not tokenise to exactly one row; the logit
    # difference below compares a single token per answer.
    if answer_pair.shape[0] != 1:
        continue

    prompts.append(dataset_element["prompt"])
    prompts.append(dataset_element["inverse_prompt"])

    answers.append(dataset_element["answer"])
    answers.append(dataset_element["inverse_answer"])

    answers_tokens_list.append(answer_pair)
    answers_tokens_list.append(inverse_answer_pair)

# Fail with a clear message instead of an opaque torch.concat error on an empty list.
if not prompts:
    raise ValueError("No dataset sample has single-token answers; nothing to analyse.")

answer_tokens = torch.concat(answers_tokens_list, dim=0)
tokens = model.to_tokens(prompts, prepend_bos=True)
# Follow the device the model was loaded on rather than hard-coding "cuda".
tokens = tokens.to(model.cfg.device)
# No gradients are needed for this forward pass; disabling autograd avoids
# retaining the computation graph behind the logits.
with torch.no_grad():
    original_logits, cache = model.run_with_cache(tokens)

In [9]:
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
) -> Float[Tensor, "*batch"]:
    '''
    Returns the logit difference between the correct and incorrect answer.

    Args:
        logits: model output logits for a batch of prompts.
        answer_tokens: integer token ids, one row per prompt; column 0 is the
            correct answer and column 1 the incorrect one. Defaults to the
            notebook-level `answer_tokens` captured when this function was defined.
        per_prompt: if True, return one difference per prompt instead of the mean.

    Returns:
        The mean logit difference (a scalar tensor), or the per-prompt
        differences when per_prompt=True.
    '''
    # Only the logits at the last sequence position are used for the answer.
    # NOTE(review): if prompts in a batch tokenise to different lengths, position
    # -1 may be a padding position for the shorter ones — confirm all prompts
    # share the same token length.
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Pick out the logits of the two candidate answer tokens for each prompt.
    # The index tensor must live on the same device as the logits.
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(
        dim=-1, index=answer_tokens.to(final_logits.device)
    )
    # Column 0 holds the correct answer's logit, column 1 the incorrect one's.
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [10]:
clean_tokens = tokens
# Prompts were appended in (prompt, inverse_prompt) pairs, so swapping each even
# row with the following odd row turns every prompt into its own inverse.
assert len(clean_tokens) % 2 == 0, "expected prompts to come in (prompt, inverse_prompt) pairs"
indices = [i+1 if i % 2 == 0 else i-1 for i in range(len(clean_tokens))]
corrupted_tokens = clean_tokens[indices]

print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]), "\n"
    "Corrupted string 0:", model.to_string(corrupted_tokens[0])
)

# These forward passes are only used for evaluation and for caching
# activations, so run them without autograd to avoid keeping the computation
# graph of the model alive behind the logits.
with torch.no_grad():
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Baselines used later to normalise the patching metric.
clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|begin_of_text|>The engineer built the bridge. The bridge was built by the 
Corrupted string 0: <|begin_of_text|>The bridge was built by the engineer. The engineer built the
Clean logit diff: 6.2646
Corrupted logit diff: -6.2646


In [11]:
# Run the garbage collector first so tensors kept alive only by reference
# cycles are released, and only then ask the CUDA caching allocator to hand
# back its unused blocks (the reverse order leaves that memory cached).
freed_objects = gc.collect()
torch.cuda.empty_cache()
# Keep the collector's count as the cell output, as before.
freed_objects

0

In [12]:
def metric_fn(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> Float[Tensor, ""]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on corrupted input, and 1 when performance is same as on clean input.

    Args:
        logits: model output logits for the (patched) run.
        answer_tokens: integer token ids, one (correct, incorrect) row per prompt
            in `logits`. Pass the matching slice when `logits` covers only part
            of the dataset.
        corrupted_logit_diff: baseline average logit diff on the corrupted input.
        clean_logit_diff: baseline average logit diff on the clean input.

    The three defaults are the notebook-level values captured when this function
    was defined; re-run this cell after recomputing them.
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    # Rescale so that corrupted -> 0 and clean -> 1.
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff  - corrupted_logit_diff)


In [14]:
# Activation patching on resid_pre, run in small batches and then averaged
# over the whole dataset.
batch_size = 8
num_prompts = len(corrupted_tokens)
# Ceil division: the final, possibly smaller, batch must be processed too
# (floor division silently dropped the trailing prompts).
num_batches = (num_prompts + batch_size - 1) // batch_size

per_batch_results = []
per_batch_sizes = []

for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, num_prompts)

    # Slice the corrupted tokens, the clean activations and the answer tokens
    # with the same row range so they stay aligned prompt by prompt.
    batch_corrupted_tokens = corrupted_tokens[start_idx:end_idx]
    batch_clean_cache = {
        key: value[start_idx:end_idx]
        for key, value in clean_cache.items()
    }
    batch_answer_tokens = answer_tokens[start_idx:end_idx]

    with torch.no_grad():
        # The lambda makes metric_fn score this batch against its own answer tokens.
        batch_act_patch_resid_pre = patching.get_act_patch_resid_pre(
            model=model,
            corrupted_tokens=batch_corrupted_tokens,
            clean_cache=batch_clean_cache,
            patching_metric=lambda logits: metric_fn(logits, answer_tokens=batch_answer_tokens)
        )

    per_batch_results.append(batch_act_patch_resid_pre)
    per_batch_sizes.append(end_idx - start_idx)
    torch.cuda.empty_cache()

# metric_fn is linear in the batch-mean logit diff, so weighting every batch by
# its number of prompts gives exactly the metric over the full dataset, even
# when the last batch is smaller than the others.
stacked_results = torch.stack(per_batch_results, 0)
batch_weights = torch.tensor(per_batch_sizes, dtype=stacked_results.dtype, device=stacked_results.device)
batch_weights = (batch_weights / batch_weights.sum()).view(-1, *([1] * (stacked_results.ndim - 1)))
all_act_patch_resid_pre = (stacked_results * batch_weights).sum(0)
# Axis labels come from the first clean prompt only.
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

In [15]:
# Heatmap of the batch-averaged resid_pre patching result
# (x axis: position, y axis: layer).
resid_pre_plot_options = dict(
    labels={"x": "Position", "y": "Layer"},
    x=labels,
    title="resid_pre Activation Patching",
    width=600,
)
imshow(all_act_patch_resid_pre, **resid_pre_plot_options)

In [17]:
# Activation patching with get_act_patch_block_every, run in small batches and
# then averaged over the whole dataset.
batch_size = 8
num_prompts = len(corrupted_tokens)
# Ceil division: the final, possibly smaller, batch must be processed too
# (floor division silently dropped the trailing prompts).
num_batches = (num_prompts + batch_size - 1) // batch_size

per_batch_results = []
per_batch_sizes = []

for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, num_prompts)

    # Slice the corrupted tokens, the clean activations and the answer tokens
    # with the same row range so they stay aligned prompt by prompt.
    batch_corrupted_tokens = corrupted_tokens[start_idx:end_idx]
    batch_clean_cache = {
        key: value[start_idx:end_idx]
        for key, value in clean_cache.items()
    }
    batch_answer_tokens = answer_tokens[start_idx:end_idx]

    with torch.no_grad():
        # The lambda makes metric_fn score this batch against its own answer tokens.
        batch_act_patch_block_every = patching.get_act_patch_block_every(
            model,
            batch_corrupted_tokens,
            batch_clean_cache,
            lambda logits: metric_fn(logits, answer_tokens=batch_answer_tokens)
        )

    per_batch_results.append(batch_act_patch_block_every)
    per_batch_sizes.append(end_idx - start_idx)
    torch.cuda.empty_cache()

# metric_fn is linear in the batch-mean logit diff, so weighting every batch by
# its number of prompts gives exactly the metric over the full dataset, even
# when the last batch is smaller than the others.
stacked_results = torch.stack(per_batch_results, 0)
batch_weights = torch.tensor(per_batch_sizes, dtype=stacked_results.dtype, device=stacked_results.device)
batch_weights = (batch_weights / batch_weights.sum()).view(-1, *([1] * (stacked_results.ndim - 1)))
# NOTE: the variable keeps its historical "resid_pre" name because the plotting
# cell reads it; here it holds the block-every patching result.
all_act_patch_resid_pre = (stacked_results * batch_weights).sum(0)
# Axis labels come from the first clean prompt only.
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

  0%|          | 0/364 [00:00<?, ?it/s]

In [18]:
# One heatmap per patched component (residual stream, attention output, MLP
# output). The title previously said "Attn Head Output", which did not match
# what this figure shows.
imshow(
    all_act_patch_resid_pre,
    x=labels,
    facet_col=0, # This argument tells plotly which dimension to split into separate plots
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"], # Subtitles of separate plots
    title="Activation Patching Per Block (Residual Stream, Attn Output, MLP Output)",
    labels={"x": "Sequence Position", "y": "Layer"},
    width=1000,
)

In [19]:
# Activation patching of attention head outputs (all positions), run in small
# batches and then averaged over the whole dataset.
batch_size = 8
num_prompts = len(corrupted_tokens)
# Ceil division: the final, possibly smaller, batch must be processed too
# (floor division silently dropped the trailing prompts).
num_batches = (num_prompts + batch_size - 1) // batch_size

per_batch_results = []
per_batch_sizes = []

for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, num_prompts)

    # Slice the corrupted tokens, the clean activations and the answer tokens
    # with the same row range so they stay aligned prompt by prompt.
    batch_corrupted_tokens = corrupted_tokens[start_idx:end_idx]
    batch_clean_cache = {
        key: value[start_idx:end_idx]
        for key, value in clean_cache.items()
    }
    batch_answer_tokens = answer_tokens[start_idx:end_idx]

    with torch.no_grad():
        # The lambda makes metric_fn score this batch against its own answer tokens.
        batch_act_patch_attn_head_out = patching.get_act_patch_attn_head_out_all_pos(
            model,
            batch_corrupted_tokens,
            batch_clean_cache,
            lambda logits: metric_fn(logits, answer_tokens=batch_answer_tokens)
        )

    per_batch_results.append(batch_act_patch_attn_head_out)
    per_batch_sizes.append(end_idx - start_idx)
    torch.cuda.empty_cache()

# metric_fn is linear in the batch-mean logit diff, so weighting every batch by
# its number of prompts gives exactly the metric over the full dataset, even
# when the last batch is smaller than the others.
stacked_results = torch.stack(per_batch_results, 0)
batch_weights = torch.tensor(per_batch_sizes, dtype=stacked_results.dtype, device=stacked_results.device)
batch_weights = (batch_weights / batch_weights.sum()).view(-1, *([1] * (stacked_results.ndim - 1)))
# NOTE: the variable keeps its historical "resid_pre" name because the plotting
# cell reads it; here it holds the attention-head patching result.
all_act_patch_resid_pre = (stacked_results * batch_weights).sum(0)
# Axis labels come from the first clean prompt only.
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

  0%|          | 0/672 [00:00<?, ?it/s]

  0%|          | 0/672 [00:00<?, ?it/s]

  0%|          | 0/672 [00:00<?, ?it/s]

  0%|          | 0/672 [00:00<?, ?it/s]

In [20]:
# Heatmap of the batch-averaged attention-head patching result
# (x axis: head, y axis: layer).
attn_head_plot_options = dict(
    labels={"y": "Layer", "x": "Head"},
    title="attn_head_out Activation Patching (All Pos)",
    width=600,
)
imshow(all_act_patch_resid_pre, **attn_head_plot_options)