# Q&A example with TransformerLens

In [1]:
import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

import pandas as pd

In [2]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
def vis_attn_patterns(model, text, layers, compact=True):
    str_tokens = model.to_str_tokens(text)
    logits, cache = model.run_with_cache(text, remove_batch_dim=True)

    if compact:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
    
    else:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern))

We set our model:

In [4]:
model = HookedTransformer.from_pretrained("gpt2-medium")

Loaded pretrained model gpt2-medium into HookedTransformer


We import and embed the example text:

In [5]:
import json 

'''
with open("Test.json","r") as infile: 
    infile.readline()
    lines = infile.readlines()
    '''

with open("Test.json", "r") as infile:
    data = json.load(infile)

#data = [json.loads(line.strip("\n").rstrip(',').rstrip("]")) for line in lines]

questions = [line["question"] for line in data]
answers = [line["rationale"] for line in data]

#Embedding the questions and answers
question_embeds = model.to_tokens(questions)
answer_embeds = model.to_tokens(answers)

str_questions = model.to_str_tokens(questions)
str_answers = model.to_str_tokens(answers)

print(len(question_embeds), len(answer_embeds))
print(questions)

10 10
['0+0=', '1+2=', '2+4=', '3+6=', '4+8=', '5+10=', '6+12=', '7+14=', '8+16=', '9+18=']


In [6]:
vis_attn_patterns(model, questions[0], layers=[0])
vis_attn_patterns(model, questions[1], layers=[0])

Trying activation patching

In [7]:
clean_prompt = "What is the capital of France?"
corrupted_prompt = "What is the capital of England?"

clean_answer = "Paris"
corrupted_answer = "London"

clean_logits, clean_cache = model.run_with_cache(clean_prompt)
corrupted_logits = model(corrupted_prompt)

clean_index = model.to_single_token(clean_answer)
corrupted_index = model.to_single_token(corrupted_answer)

clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
print(f"Clean answer logit difference: {clean_diff:.4f}")

corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]
print(f"Corrupted answer logit difference: {corrupted_diff:.4f}")

Clean answer logit difference: 3.1147
Corrupted answer logit difference: -4.0909


Then we want to patch the clean prompt onto the corrupted prompt.

In [16]:
def activation_patching_hook(resid_pre, hook, position, clean_cache):
    clean_activation = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_activation[:, position, :]
    return resid_pre

def model_data(model, prompt):
    tokens = model.to_tokens(prompt)
    logits, cache = model.run_with_cache(tokens)
    return tokens, logits, cache

#num_positions = len(model.to_tokens(clean_prompt)[0])
#patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

def activation_patching(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    clean_logits, clean_cache = model.run_with_cache(clean_prompt)
    corrupted_logits = model(corrupted_prompt)

    clean_index = model.to_single_token(clean_answer)
    corrupted_index = model.to_single_token(corrupted_answer)

    clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
    corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]

    corrupted_tokens = model.to_tokens(corrupted_prompt)
    num_positions = len(model.to_tokens(clean_prompt)[0])
   
    patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
    for layer in tqdm.tqdm(range(model.cfg.n_layers)):
        for position in range(num_positions):
            # We use a temporary hook with functool.partial to patch at each position
            temp_hook = partial(activation_patching_hook, position=position, clean_cache=clean_cache)
            # We then run the model with hooks as usual
            patched_logits = model.run_with_hooks(corrupted_tokens, 
                                                  fwd_hooks=[(utils.get_act_name("resid_pre", layer), temp_hook)])
            # We then calculate the logit difference
            patched_diff = (patched_logits[0, -1, clean_index] - patched_logits[0, -1, corrupted_index]).detach()
            # We then store the result in the patching_result tensor, normalizing it
            patching_result[layer, position] = (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff)

    return patching_result, patched_logits


In [83]:

print(list(clean_cache.keys()))
print("Looking for; ", utils.get_act_name("activation", 0))
patching_result = activation_patching(model, clean_prompt, model.to_tokens(corrupted_prompt))



['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.ln1.ho

  0%|          | 0/12 [00:00<?, ?it/s]

Layer 0, Position 0
Layer 0, Position 1
Layer 0, Position 2
Layer 0, Position 3
Layer 0, Position 4
Layer 0, Position 5
Layer 0, Position 6
Layer 0, Position 7
Layer 1, Position 0
Layer 1, Position 1
Layer 1, Position 2
Layer 1, Position 3
Layer 1, Position 4
Layer 1, Position 5
Layer 1, Position 6
Layer 1, Position 7
Layer 2, Position 0
Layer 2, Position 1
Layer 2, Position 2
Layer 2, Position 3
Layer 2, Position 4
Layer 2, Position 5
Layer 2, Position 6
Layer 2, Position 7
Layer 3, Position 0
Layer 3, Position 1
Layer 3, Position 2
Layer 3, Position 3
Layer 3, Position 4
Layer 3, Position 5
Layer 3, Position 6
Layer 3, Position 7
Layer 4, Position 0
Layer 4, Position 1
Layer 4, Position 2
Layer 4, Position 3
Layer 4, Position 4
Layer 4, Position 5
Layer 4, Position 6
Layer 4, Position 7
Layer 5, Position 0
Layer 5, Position 1
Layer 5, Position 2
Layer 5, Position 3
Layer 5, Position 4
Layer 5, Position 5
Layer 5, Position 6
Layer 5, Position 7
Layer 6, Position 0
Layer 6, Position 1


We then visualize our results

In [9]:
token_labels = [f'{token}_{index}' for index, token in enumerate(model.to_str_tokens(clean_prompt))]
imshow(patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Patching Result",)

NameError: name 'patching_result' is not defined

When patching the clean results onto the corrupted results, the change is first very localized, but is then brought to the end in the last few layers. 
We now want to try with a more complicated example.

In [10]:
clean_physics_prompt = "When you drop a ball, it falls to the ground because of"
corrupted_physics_prompt = "When you throw a ball, it falls to the ground because of"
clean_answer = " gravity"
corrupted_answer = " air"

In [15]:
model.generate(clean_physics_prompt, max_new_tokens=10, temperature=0.0, top_p=1.0, do_sample=False)
#utils.test_prompt(clean_physics_prompt, clean_answer, model)

  0%|          | 0/10 [00:00<?, ?it/s]

'When you drop a ball, it falls to the ground because of gravity. When you drop a ball, it falls'

In [17]:
patching_result_physics = activation_patching(model, clean_physics_prompt, corrupted_physics_prompt, clean_answer, corrupted_answer)
token_labels = [f'{token}_{index}' for index, token in enumerate(model.to_str_tokens(clean_physics_prompt))]
imshow(patching_result_physics[0], x=token_labels, xaxis="Position", yaxis="Layer", title="Patching Result",)

  0%|          | 0/24 [00:00<?, ?it/s]

In [19]:
# What about the probability distribution now?
patching_logits = patching_result_physics[1][0, -1, :]
# Get the logits for the last position
logits = patching_result_physics[1][0, -1, :]  # shape: (vocab_size,)

# Convert logits to probabilities
probs = logits.softmax(dim=-1)

# Get the top 10 token indices and their probabilities
top_probs, top_indices = probs.topk(10)

# Decode the tokens to strings
top_tokens = [model.to_string([idx.item()]) for idx in top_indices]

# Print the results
for token, prob in zip(top_tokens, top_probs):
    print(f"{token!r}: {prob.item():.4f}")

' gravity': 0.6126
' the': 0.1128
' its': 0.0351
' friction': 0.0283
' a': 0.0258
' inertia': 0.0244
' two': 0.0168
' momentum': 0.0128
' gravitational': 0.0125
' your': 0.0073


This seems to work!!!

But can we make a function that does not require the questions to be the same length? And where the answers can be multiple words?

In [38]:
x = []
y = []
induction = []
dict_questions = {}
for question, question_token in zip(questions, question_embeds):
    length = len(question_token)
    
    induction_score_store = np.zeros((model.cfg.n_layers, model.cfg.n_heads))

    # A function for the average induction score
    def induction_score_hook(activation_pattern, hook):
        """
        Computes the average induction score for a given activation pattern.
        
        Args:
            activation_pattern (torch.Tensor): The activation pattern to compute the induction score for.
            hook (HookPoint): The hook point that triggered this function.
            
        Returns:
            torch.Tensor: The average induction score.
        """
        # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back
        # (This only has entries for tokens with index>=seq_len)
        induction_stripe = activation_pattern.diagonal(dim1=-2, dim2=-1, offset=1-length//2)
        # Get an average score per head
        induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
        # Store the result.
        induction_score_store[hook.layer(), :] = induction_score

    # We make a boolean filter on activation names, that's true only on attention pattern names.
    pattern_hook_names_filter = lambda name: name.endswith("pattern")
    print(question)
    # Run with hooks (this is where we write to the `induction_score_store` tensor`)
    model.run_with_hooks(question_token, 
        return_type=None, # For efficiency, we don't need to calculate the logits
        fwd_hooks=[(
            pattern_hook_names_filter,
            induction_score_hook
        )]
    )
    #print(induction_score_store)
    # Get the global max value and its flat index
    max_value = np.max(induction_score_store)
    max_index = np.argmax(induction_score_store)

    # Convert flat index to (layer, head) coordinates
    #max_index = np.unravel_index(max_index_flat.cpu().numpy(), induction_score_store.shape)
    #print("Question:", question)
    #print("Max value:", max_value.item())
    #print("Max index (layer, head):", max_index)
    '''
    if (max_index[0], max_index[1]) not in dict_questions:
        dict_questions[(max_index[0], max_index[1])] = [question]
    else:
        dict_questions[(max_index[0], max_index[1])].append(question)
    '''
    print(max_index[0])
    x.append(max_index[0])
    y.append(max_index[1])
    print(x)
    induction.append(induction_score_store)
 
print(x, y)
df = pd.DataFrame({
    "Layer": x,
    "Head": y,
    "Question": questions[:len(x)]
})
#print(dict_questions[(int(x[0]), int(y[0]))])
px.scatter(
    data_frame=df,
    x="Layer",
    y="Head",
    hover_data=["Question"],
    title="Induction Scores for Questions"
).show()
#px.scatter(x=np.array(x), y=np.array(y), labels={"x": "Layer", "y": "Head"},  title="Induction Scores for Questions", hover_data=["question"]).show()

0+0=


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

The problem here is that many of the questions have the same max index. Therefore, it might not be the best method to visualize the inuction heads.

In [None]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, perplexity=30, random_state=42, learning_rate="auto", early_exaggeration=5.0, n_iter=5000, init="pca")
induction_embeds = tsne.fit_transform(np.array(induction))



NameError: name 'np' is not defined