In [1]:
!pip install -q transformers datasets accelerate git+https://github.com/TransformerLensOrg/TransformerLens git+https://github.com/neelnanda-io/neel-plotly.git  kaleido

In [2]:
import torch
from transformer_lens import HookedTransformer, ActivationCache, patching, HookedTransformerConfig
import transformer_lens
from transformer_lens.patching import (
    generic_activation_patch,
    get_act_patch_resid_pre,
    get_act_patch_attn_out,
    get_act_patch_mlp_out,
)
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.utils import *
# from .autonotebook import tqdm as notebook_tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
from accelerate import init_empty_weights
from typing import List, Tuple
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
import os
from neel_plotly import line, imshow, scatter
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
# Set CUDA_LAUNCH_BLOCKING=1 for this process. Presumably a debugging aid that makes
# CUDA kernel launches synchronous so errors point at the failing call -- confirm it is
# still wanted, since it is generally expected to slow down GPU execution.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab.
import plotly.io as pio

# Because of rendering issues, Plotly graphics show up either in Colab OR in VSCode
# depending on the renderer, hence the debug switch: "png" is chosen only when
# DEBUG_MODE is set and the notebook is not running in Colab.
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

In [3]:
# Force the "colab" Plotly renderer unconditionally.
# NOTE(review): this overrides whatever the preceding cell selected from
# IN_COLAB / DEBUG_MODE, so the "png" branch there never takes effect -- confirm
# whether this cell is still needed.
import plotly.io as pio
pio.renderers.default = "colab"

In [None]:
# Authenticate with the Hugging Face Hub via the library's login() helper.
# Presumably required because the meta-llama checkpoint loaded below is access-gated -- confirm.
import huggingface_hub
huggingface_hub.login()

In [5]:
# Instantiate an Accelerator and print the device it reports.
# NOTE(review): `accelerator` is not referenced by any other cell visible here.
from accelerate import Accelerator
accelerator = Accelerator()
print(accelerator.device)

cuda


In [6]:
# Checkpoint to load into HookedTransformer; the commented lines are alternatives.
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = 'gpt2-small'
# model_name = "google/gemma-2-2b-it"
# model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

In [7]:
# Disable autograd globally; the cells visible in this notebook only run forward passes.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f9e0c144cb0>

In [8]:
# Load the pretrained checkpoint named by `model_name` into a HookedTransformer.
# Keyword arguments left disabled by the author:
#   move_to_device=True, n_devices=torch.cuda.device_count()
model = HookedTransformer.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 2/2 [04:19<00:00, 129.50s/it]


Loaded pretrained model meta-llama/Llama-3.2-3B-Instruct into HookedTransformer


In [9]:
# Show which device the loaded model's config reports.
model.cfg.device

device(type='cuda')

# Activation Patching

In [10]:
# Seven-digit addition prompts fed to the model.
prompts = [
    "What is 1234567 + 7654321?",
    "What is 8326542 + 1673345?",
    "What is 3785816 + 6214087?",
]

# One (correct_answer, incorrect_answer) pair per prompt, in the same order as `prompts`.
answers = [
    ("8888888", "9999999"),
    ("9999887", "1235789"),
    ("9999903", "2170830"),
]

# The same sums as `prompts`, with the two operands swapped.
prompts_c = [
    "What is 7654321 + 1234567?",
    "What is 1673345 + 8326542?",
    "What is 6214087 + 3785816?",
]

# Token id used to right-pad answer token sequences to a common length.
PAD_TOKEN = model.tokenizer.pad_token_id

In [24]:
# Step 2: Tokenize Prompts and Answers
clean_tokens = model.to_tokens(prompts)

# Every tensor built below is placed on the same device as the prompt tokens
device = clean_tokens.device

# Tokenize each (correct, incorrect) answer pair and right-pad the shorter member
# with PAD_TOKEN so both tensors of a pair have the same number of tokens.
# NOTE(review): to_tokens is called with its default BOS handling, so these answer
# tensors may begin with a BOS token -- confirm that is intended for the metric.
answer_token_indices = []
for correct_answer, incorrect_answer in answers:
    correct_tokens = model.to_tokens(correct_answer)[0].to(device)
    incorrect_tokens = model.to_tokens(incorrect_answer)[0].to(device)

    # Positive gap: the correct answer is shorter; negative gap: the incorrect one is
    length_gap = len(incorrect_tokens) - len(correct_tokens)
    if length_gap > 0:
        correct_tokens = torch.cat(
            [correct_tokens, torch.full((length_gap,), PAD_TOKEN, dtype=torch.long, device=device)]
        )
    elif length_gap < 0:
        incorrect_tokens = torch.cat(
            [incorrect_tokens, torch.full((-length_gap,), PAD_TOKEN, dtype=torch.long, device=device)]
        )

    answer_token_indices.append((correct_tokens, incorrect_tokens))

# Step 3: Create Corrupted Tokens
# FIX: this line previously read `prompts_o`, a name that no cell in the notebook
# defines, so the cell raised NameError on a fresh kernel (Restart & Run All).
# `prompts_c` is the only other prompt list that is defined and it was otherwise unused.
# NOTE(review): if `prompts_o` was meant to be the original prompts, use `prompts` here instead.
corrupted_tokens = model.to_tokens(prompts_c).to(device)

# Activation patching swaps activations position by position, so the clean and
# corrupted batches must have identical shapes.
assert corrupted_tokens.shape == clean_tokens.shape, (
    f"clean tokens {tuple(clean_tokens.shape)} and corrupted tokens "
    f"{tuple(corrupted_tokens.shape)} must have the same shape"
)

for i, (correct_tokens, incorrect_tokens) in enumerate(answer_token_indices):
    # Overwrite the trailing positions of each row with the incorrect answer tokens
    # (the pair was padded to equal length above, so either length gives the same span)
    answer_start_idx = -len(correct_tokens)
    answer_end_idx = corrupted_tokens.shape[1]
    corrupted_tokens[i, answer_start_idx:answer_end_idx] = incorrect_tokens

# Step 4: Run the Model
# Cache activations for both runs; the clean cache is what gets patched into the corrupted run.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)


In [25]:
def arithmetic_metric(logits, answer_token_indices, corrupted_baseline, clean_baseline):
    """
    Normalised correct-vs-incorrect logit difference, averaged over the batch.

    For every example, the logits at the last ``len(tokens)`` sequence positions are
    read off at the given token ids and summed, once for the correct answer and once
    for the incorrect answer. The per-example differences (correct - incorrect) are
    averaged and then rescaled so that ``corrupted_baseline`` maps to 0 and
    ``clean_baseline`` maps to 1.

    Args:
        logits: Model output logits, indexed as [batch, position, vocab].
        answer_token_indices: One ``(correct_tokens, incorrect_tokens)`` pair of 1-D
            token-id tensors per batch element.
        corrupted_baseline: Raw metric value of the corrupted run (pass 0 together
            with ``clean_baseline=1`` to obtain the un-normalised value).
        clean_baseline: Raw metric value of the clean run.

    Returns:
        The normalised metric as a 0-dim PyTorch tensor.
    """

    def _summed_answer_logit(example_logits, token_ids):
        # Take the trailing positions, pick each position's logit for its token id, add up
        trailing = example_logits[-len(token_ids):, :]
        return trailing.gather(1, token_ids.unsqueeze(1)).sum()

    per_example_diffs = [
        _summed_answer_logit(logits[row], answer_token_indices[row][0])
        - _summed_answer_logit(logits[row], answer_token_indices[row][1])
        for row in range(logits.shape[0])
    ]

    mean_diff = torch.stack(per_example_diffs).mean()

    # Linear rescaling between the two baselines
    return (mean_diff - corrupted_baseline) / (clean_baseline - corrupted_baseline)


In [26]:
# Single-argument adapter around arithmetic_metric
def arithmetic_metric_wrapper(logits):
    """Score `logits` with arithmetic_metric using the notebook-level answer tokens and baselines.

    The globals are looked up when the function is called, not when it is defined.
    """
    baselines = {
        "corrupted_baseline": CORRUPTED_BASELINE,
        "clean_baseline": CLEAN_BASELINE,
    }
    return arithmetic_metric(logits, answer_token_indices=answer_token_indices, **baselines)


In [27]:
# Compute baselines: with (0, 1) as the normalisation bounds, arithmetic_metric
# returns the raw averaged logit difference of each run.
CORRUPTED_BASELINE, CLEAN_BASELINE = (
    arithmetic_metric(run_logits, answer_token_indices, 0, 1)
    for run_logits in (corrupted_logits, clean_logits)
)

print(f"Corrupted Baseline: {CORRUPTED_BASELINE:.4f}")
print(f"Clean Baseline: {CLEAN_BASELINE:.4f}")


Corrupted Baseline: -5.9844
Clean Baseline: -1.6410


In [28]:
# Run block-level activation patching: the corrupted tokens and the clean cache are
# handed to the patching helper, and each patched run is scored with
# arithmetic_metric_wrapper. The result is plotted in the next cell with one facet
# each for residual stream, attention output and MLP output.
every_block_result = patching.get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, arithmetic_metric_wrapper
)


100%|██████████| 364/364 [00:51<00:00,  7.04it/s]
100%|██████████| 364/364 [00:51<00:00,  7.03it/s]
100%|██████████| 364/364 [00:51<00:00,  7.05it/s]


In [29]:
# x-axis labels: each token string of the first clean prompt, tagged with its position index
block_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
]

# One heatmap facet per patched activation type, colour scale clamped to [-1, 1]
img = imshow(
    every_block_result,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Activation Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmin=-1,
    zmax=1,
    x=block_position_labels,
    return_fig=True,
)

In [30]:
# Save the per-block patching figure to disk as a PNG in the working directory.
img.write_image('./fig_o1.png')

In [31]:
# Patch each attention head's output (all positions at once) and score every patch
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(
    model,
    corrupted_tokens,
    clean_cache,
    arithmetic_metric_wrapper,
)

# Layer x head heatmap of the patching scores
img2 = imshow(
    attn_head_out_all_pos_act_patch_results,
    title="attn_head_out Activation Patching (All Pos)",
    xaxis="Head",
    yaxis="Layer",
    return_fig=True,
)
img2.write_image('./fig_o2.png')

 12%|█▎        | 84/672 [00:13<01:30,  6.53it/s]

100%|██████████| 672/672 [01:43<00:00,  6.47it/s]


In [32]:
# One "L<layer>H<head>" label per attention head, layer-major to match the
# "(layer head)" flattening applied to the results below
ALL_HEAD_LABELS = [
    f"L{i}H{j}"
    for i, j in itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads))
]

# Patch each head's output at each position, then flatten (layer, head) into one axis
attn_head_out_act_patch_results = einops.rearrange(
    patching.get_act_patch_attn_head_out_by_pos(
        model, corrupted_tokens, clean_cache, arithmetic_metric_wrapper
    ),
    "layer pos head -> (layer head) pos",
)

# Head x position heatmap; x labels are the first clean prompt's tokens with their index
img3 = imshow(
    attn_head_out_act_patch_results,
    title="attn_head_out Activation Patching By Pos",
    xaxis="Pos",
    yaxis="Head Label",
    x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
    y=ALL_HEAD_LABELS,
    return_fig=True,
)
img3.write_image('./fig_o3.png')

  2%|▏         | 186/8736 [00:28<21:49,  6.53it/s]

100%|██████████| 8736/8736 [22:31<00:00,  6.47it/s]
