# Q&A example with TransformerLens

In [1]:
import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

import pandas as pd

In [2]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
def vis_attn_patterns(model, text, layers, compact=True):
    str_tokens = model.to_str_tokens(text)
    logits, cache = model.run_with_cache(text, remove_batch_dim=True)

    if compact:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
    
    else:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern))

We set our model:

In [4]:
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


We import and embed the example text:

In [5]:
import json 

'''
with open("Test.json","r") as infile: 
    infile.readline()
    lines = infile.readlines()
    '''

with open("Test.json", "r") as infile:
    data = json.load(infile)

#data = [json.loads(line.strip("\n").rstrip(',').rstrip("]")) for line in lines]

questions = [line["question"] for line in data]
answers = [line["rationale"] for line in data]

#Embedding the questions and answers
question_embeds = model.to_tokens(questions)
answer_embeds = model.to_tokens(answers)

str_questions = model.to_str_tokens(questions)
str_answers = model.to_str_tokens(answers)

print(len(question_embeds), len(answer_embeds))
print(questions)

FileNotFoundError: [Errno 2] No such file or directory: 'Test.json'

In [6]:
vis_attn_patterns(model, questions[0], layers=[0])
vis_attn_patterns(model, questions[1], layers=[0])

NameError: name 'questions' is not defined

In [5]:
# Test
clean = "The capital of Norway is"
answer = " Oslo"
corrupted = "The capital of Sweden is"
answer_2 = " Stockholm"

print(model.to_str_tokens([clean, answer]))
print(model.to_str_tokens([corrupted, answer_2]))

[['<|endoftext|>', 'The', ' capital', ' of', ' Norway', ' is'], ['<|endoftext|>', ' Oslo']]
[['<|endoftext|>', 'The', ' capital', ' of', ' Sweden', ' is'], ['<|endoftext|>', ' Stockholm']]


# Activation patching

In [7]:
clean_prompt = "What is the capital of France?"
corrupted_prompt = "What is the capital of England?"

clean_answer = "Paris"
corrupted_answer = "London"

clean_logits, clean_cache = model.run_with_cache(clean_prompt)
corrupted_logits = model(corrupted_prompt)

clean_index = model.to_single_token(clean_answer)
corrupted_index = model.to_single_token(corrupted_answer)

clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
print(f"Clean answer logit difference: {clean_diff:.4f}")

corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]
print(f"Corrupted answer logit difference: {corrupted_diff:.4f}")

Clean answer logit difference: 2.1247
Corrupted answer logit difference: -5.3342


Then we want to patch the clean prompt onto the corrupted prompt.

In [6]:
def activation_patching_hook(resid_pre, hook, position, clean_cache):
    clean_activation = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_activation[:, position, :]
    return resid_pre

def model_data(model, prompt):
    tokens = model.to_tokens(prompt)
    logits, cache = model.run_with_cache(tokens)
    return tokens, logits, cache

def activation_patching(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    '''
    Performs activation patching of the clean prompt onto the corrupted prompt. The prompts must have the same number of tokens.

    Returns: 
    patching_results (list[tensor[layers, positions]]): The logit difference after patching
    patched_logits (list[tensor[num_tokens, logits]]): The logits of the tokens after patching
    '''
    
    clean_logits, clean_cache = model.run_with_cache(clean_prompt)
    corrupted_logits = model(corrupted_prompt)
    print("Clean answer:",clean_answer)
    print("Corrupted answer:", corrupted_answer)
    clean_index = model.to_single_token(clean_answer)
    corrupted_index = model.to_single_token(corrupted_answer)

    clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
    corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]

    clean_tokens = model.to_tokens(clean_prompt)
    corrupted_tokens= model.to_tokens(corrupted_prompt)
    num_positions = len(model.to_tokens(clean_prompt)[0])

    assert len(clean_tokens[0]) == len(corrupted_tokens[0]), "The prompts must have the same number of tokens."

   
    patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
    for layer in tqdm.tqdm(range(model.cfg.n_layers)):
        for position in range(num_positions):
            # We use a temporary hook with functool.partial to patch at each position
            temp_hook = partial(activation_patching_hook, position=position, clean_cache=clean_cache)
            # We then run the model with hooks as usual
            patched_logits = model.run_with_hooks(corrupted_tokens, 
                                                  fwd_hooks=[(utils.get_act_name("resid_pre", layer), temp_hook)])
            
            # We then calculate the logit difference
            patched_diff = (patched_logits[0, -1, clean_index] - patched_logits[0, -1, corrupted_index]).detach()
            # We then store the result in the patching_result tensor, normalizing it
            if abs(clean_diff-corrupted_diff) < 1e-16:
                patching_result[layer, position] = 0
            else:
                patching_result[layer, position] = abs((patched_diff - corrupted_diff) / (clean_diff - corrupted_diff))
    print(patched_logits.shape)
    return patching_result, patched_logits

def activation_patching_mult(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    ''' 
    Performs activation patching on prompts with multi-word answers by using separate run-throughs.
    The answers must have the same number of tokens
    '''
    patching_result = []
    patched_logits = []
    clean_answers_tokens = model.to_str_tokens(clean_answer)[1:]
    corrupted_answers_tokens = model.to_str_tokens(corrupted_answer)[1:]

    if len(clean_answers_tokens) == 1:
        return activation_patching(model, clean_prompt, corrupted_prompt, clean_answers_tokens, corrupted_answers_tokens)
    print(model.to_str_tokens(clean_prompt))
    print(model.to_str_tokens(corrupted_prompt))
    print("Number of run throughs:", len(clean_answers_tokens))
    for i in range(len(clean_answers_tokens)):
        p_result, p_logits = activation_patching(model, clean_prompt, corrupted_prompt, clean_answers_tokens[0], corrupted_answers_tokens[0])
        patching_result.append(p_result)
        patched_logits.append(p_logits)
        clean_prompt += clean_answers_tokens[0]
        clean_answers_tokens = clean_answers_tokens[1:]
        corrupted_prompt += corrupted_answers_tokens[0]
        corrupted_answers_tokens = corrupted_answers_tokens[1:]

    return patching_result, patched_logits

In [7]:

print(list(clean_cache.keys()))
print("Looking for; ", utils.get_act_name("activation", 0))
patching_result = activation_patching(model, clean_prompt, model.to_tokens(corrupted_prompt))



NameError: name 'clean_cache' is not defined

When patching the clean results onto the corrupted results, the change is first very localized, but is then brought to the end in the last few layers. 
We now want to try with a more complicated example.

In [8]:
clean_physics_prompt = "The action of adding numbers is called"
corrupted_physics_prompt = "The action of multiplying numbers is called"
clean_answer = " addition, and the result is their sum"
corrupted_answer = " multiplication, and the result is their product"

In [11]:
model.generate(clean_physics_prompt, max_new_tokens=10, temperature=0.0, top_p=1.0, do_sample=False)
utils.test_prompt(clean_physics_prompt, clean_answer, model)

  0%|          | 0/10 [00:00<?, ?it/s]

Tokenized prompt: ['<|endoftext|>', 'The', ' action', ' of', ' adding', ' numbers', ' is', ' called']
Tokenized answer: [' addition', ',', ' and', ' the', ' result', ' is', ' their', ' sum']


Top 0th token. Logit: 12.96 Prob: 14.15% Token: | "|
Top 1th token. Logit: 12.67 Prob: 10.58% Token: | the|
Top 2th token. Logit: 12.00 Prob:  5.38% Token: | a|
Top 3th token. Logit: 10.96 Prob:  1.91% Token: | '|
Top 4th token. Logit: 10.87 Prob:  1.74% Token: | an|
Top 5th token. Logit: 10.21 Prob:  0.90% Token: | for|
Top 6th token. Logit:  9.94 Prob:  0.69% Token: | adding|
Top 7th token. Logit:  9.89 Prob:  0.65% Token: | as|
Top 8th token. Logit:  9.73 Prob:  0.56% Token: | in|
Top 9th token. Logit:  9.69 Prob:  0.54% Token: | by|


Top 0th token. Logit: 16.07 Prob: 30.04% Token: |.|
Top 1th token. Logit: 15.66 Prob: 19.83% Token: |,|
Top 2th token. Logit: 15.09 Prob: 11.24% Token: | and|
Top 3th token. Logit: 14.39 Prob:  5.60% Token: | or|
Top 4th token. Logit: 14.15 Prob:  4.39% Token: | to|
Top 5th token. Logit: 13.51 Prob:  2.32% Token: | (|
Top 6th token. Logit: 13.45 Prob:  2.19% Token: | .|
Top 7th token. Logit: 13.28 Prob:  1.83% Token: | of|
Top 8th token. Logit: 12.91 Prob:  1.26% Token: | ,|
Top 9th token. Logit: 12.78 Prob:  1.11% Token: |/|


Top 0th token. Logit: 15.73 Prob: 50.26% Token: | and|
Top 1th token. Logit: 13.25 Prob:  4.21% Token: | which|
Top 2th token. Logit: 12.94 Prob:  3.08% Token: | but|
Top 3th token. Logit: 12.70 Prob:  2.44% Token: | as|
Top 4th token. Logit: 12.65 Prob:  2.32% Token: | so|
Top 5th token. Logit: 12.64 Prob:  2.30% Token: | or|
Top 6th token. Logit: 12.64 Prob:  2.29% Token: | because|
Top 7th token. Logit: 12.33 Prob:  1.68% Token: | not|
Top 8th token. Logit: 12.24 Prob:  1.54% Token: | the|
Top 9th token. Logit: 11.77 Prob:  0.96% Token: | a|


Top 0th token. Logit: 15.61 Prob: 26.87% Token: | it|
Top 1th token. Logit: 14.54 Prob:  9.26% Token: | is|
Top 2th token. Logit: 14.28 Prob:  7.13% Token: | the|
Top 3th token. Logit: 13.37 Prob:  2.86% Token: | in|
Top 4th token. Logit: 13.15 Prob:  2.30% Token: | when|
Top 5th token. Logit: 13.06 Prob:  2.11% Token: | adding|
Top 6th token. Logit: 13.02 Prob:  2.01% Token: | this|
Top 7th token. Logit: 12.67 Prob:  1.43% Token: | there|
Top 8th token. Logit: 12.65 Prob:  1.39% Token: | we|
Top 9th token. Logit: 12.61 Prob:  1.34% Token: | as|


Top 0th token. Logit: 12.76 Prob:  5.24% Token: | number|
Top 1th token. Logit: 12.15 Prob:  2.84% Token: | process|
Top 2th token. Logit: 11.85 Prob:  2.12% Token: | action|
Top 3th token. Logit: 11.70 Prob:  1.81% Token: | more|
Top 4th token. Logit: 11.69 Prob:  1.79% Token: | result|
Top 5th token. Logit: 11.62 Prob:  1.68% Token: | numbers|
Top 6th token. Logit: 11.61 Prob:  1.65% Token: | first|
Top 7th token. Logit: 11.32 Prob:  1.24% Token: | idea|
Top 8th token. Logit: 11.28 Prob:  1.19% Token: | effect|
Top 9th token. Logit: 11.27 Prob:  1.18% Token: | amount|


Top 0th token. Logit: 17.75 Prob: 67.82% Token: | is|
Top 1th token. Logit: 16.08 Prob: 12.78% Token: | of|
Top 2th token. Logit: 14.52 Prob:  2.69% Token: | can|
Top 3th token. Logit: 13.58 Prob:  1.05% Token: | will|
Top 4th token. Logit: 13.40 Prob:  0.88% Token: | depends|
Top 5th token. Logit: 13.23 Prob:  0.74% Token: | should|
Top 6th token. Logit: 13.15 Prob:  0.69% Token: | has|
Top 7th token. Logit: 12.89 Prob:  0.53% Token: |,|
Top 8th token. Logit: 12.82 Prob:  0.49% Token: | in|
Top 9th token. Logit: 12.73 Prob:  0.45% Token: | may|


Top 0th token. Logit: 15.24 Prob: 22.41% Token: | that|
Top 1th token. Logit: 14.77 Prob: 14.08% Token: | a|
Top 2th token. Logit: 14.66 Prob: 12.51% Token: | the|
Top 3th token. Logit: 13.40 Prob:  3.56% Token: | an|
Top 4th token. Logit: 12.75 Prob:  1.87% Token: | what|
Top 5th token. Logit: 12.39 Prob:  1.31% Token: | to|
Top 6th token. Logit: 12.37 Prob:  1.27% Token: | usually|
Top 7th token. Logit: 12.23 Prob:  1.11% Token: | often|
Top 8th token. Logit: 12.19 Prob:  1.07% Token: | something|
Top 9th token. Logit: 12.08 Prob:  0.95% Token: | not|


Top 0th token. Logit: 11.77 Prob:  2.14% Token: | own|
Top 1th token. Logit: 11.67 Prob:  1.94% Token: | number|
Top 2th token. Logit: 11.67 Prob:  1.94% Token: | effect|
Top 3th token. Logit: 11.54 Prob:  1.70% Token: | value|
Top 4th token. Logit: 11.01 Prob:  1.00% Token: | numbers|
Top 5th token. Logit: 10.98 Prob:  0.98% Token: | addition|
Top 6th token. Logit: 10.98 Prob:  0.98% Token: | total|
Top 7th token. Logit: 10.87 Prob:  0.87% Token: | ratio|
Top 8th token. Logit: 10.85 Prob:  0.86% Token: | combined|
Top 9th token. Logit: 10.80 Prob:  0.82% Token: | sum|


In [7]:
def imshow_patching_result(model, patching_results, corrupted_prompt, corrupted_answer):
    '''
    Visualizes the logit differences caused by activation patching in a heat map. If the answer has more than one token, "patching_results" must be a list of results.
    '''
    if isinstance(patching_results, list):
        print(f"The list has {len(patching_results)} elements.")
        len_ans = len(patching_results)
        for i in range(len_ans):
            tokens = model.to_str_tokens(corrupted_prompt+corrupted_answer)
            labels = [f'{token}_{index}' for index, token in enumerate(tokens)][:-len(patching_results)+i+1]
            print(labels)
            if not torch.all(patching_results[i] == 0):
                print(patching_results[i].shape)
                px.imshow(patching_results[i].detach(), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", x=labels, labels={"x": "Position", "y": "Layer"}, title="Patching Results").show()
    else:
        print("There is only one patching result.")
        tokens = model.to_str_tokens(corrupted_prompt)
        labels = [f'{token}_{index}' for index, token in enumerate(tokens)]
        print(labels)
        px.imshow(patching_results.detach(), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", x=labels, labels={"x": "Position", "y": "Layer"}, title="Patching Results").show()
    

In [39]:
patching_result_physics = activation_patching_mult(model, clean,corrupted, answer, answer_2)

Clean answer: [' Oslo']
Corrupted answer: [' Stockholm']


  0%|          | 0/24 [00:00<?, ?it/s]

torch.Size([1, 6, 50257])


In [40]:
imshow_patching_result(model, patching_result_physics[0], corrupted, answer_2)

There is only one patching result.
['<|endoftext|>_0', 'The_1', ' capital_2', ' of_3', ' Sweden_4', ' is_5']


It seems that even after multiple separate runs, the old patchings still influence the results slightly.

Question: If we always took the absolute value of the difference (ie no red), woul this lead to the same interpretation or is there a difference? 

In [11]:
# What about the probability distribution now?
def get_prediction(model, logits, num_top=5):
    if isinstance(logits, list):
        for i in range(len(logits)):
            logits_i = logits[i][0, -1, :]
            probs = logits_i.softmax(dim=-1)
            top_probs, top_indices = probs.topk(num_top)
            top_tokens = [model.to_string(index.item()) for index in top_indices]

            for token, prob in zip(top_tokens, top_probs):
                print(f'{token:>15}: {prob.item():.4f}, ', end='')
            print()
    
    else:
        logits_i = logits[0, -1, :]
        probs = logits_i.softmax(dim=-1)
        top_probs, top_indices = probs.topk(num_top)
        top_tokens = [model.to_string(index.item()) for index in top_indices]

        for token, prob in zip(top_tokens, top_probs):
            print(f'{token}: {prob.item():.4f}, ', end='')

get_prediction(model, patching_result_physics[1], num_top=2)


              ": 0.1507,             the: 0.1075, 
              .: 0.3018,               ,: 0.2002, 
            and: 0.5176,           which: 0.0428, 
             it: 0.2846,              is: 0.1009, 
         number: 0.0612,         process: 0.0323, 
             is: 0.6871,              of: 0.1305, 
           that: 0.2379,               a: 0.1375, 
            own: 0.0264,  multiplication: 0.0235, 


Analysis of results from activation patching:


# Attribution patching

A technique for patching multip

# Induction heads

In [38]:
x = []
y = []
induction = []
dict_questions = {}
for question, question_token in zip(questions, question_embeds):
    length = len(question_token)
    
    induction_score_store = np.zeros((model.cfg.n_layers, model.cfg.n_heads))

    # A function for the average induction score
    def induction_score_hook(activation_pattern, hook):
        """
        Computes the average induction score for a given activation pattern.
        
        Args:
            activation_pattern (torch.Tensor): The activation pattern to compute the induction score for.
            hook (HookPoint): The hook point that triggered this function.
            
        Returns:
            torch.Tensor: The average induction score.
        """
        # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back
        # (This only has entries for tokens with index>=seq_len)
        induction_stripe = activation_pattern.diagonal(dim1=-2, dim2=-1, offset=1-length//2)
        # Get an average score per head
        induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
        # Store the result.
        induction_score_store[hook.layer(), :] = induction_score

    # We make a boolean filter on activation names, that's true only on attention pattern names.
    pattern_hook_names_filter = lambda name: name.endswith("pattern")
    print(question)
    # Run with hooks (this is where we write to the `induction_score_store` tensor`)
    model.run_with_hooks(question_token, 
        return_type=None, # For efficiency, we don't need to calculate the logits
        fwd_hooks=[(
            pattern_hook_names_filter,
            induction_score_hook
        )]
    )
    #print(induction_score_store)
    # Get the global max value and its flat index
    max_value = np.max(induction_score_store)
    max_index = np.argmax(induction_score_store)

    # Convert flat index to (layer, head) coordinates
    #max_index = np.unravel_index(max_index_flat.cpu().numpy(), induction_score_store.shape)
    #print("Question:", question)
    #print("Max value:", max_value.item())
    #print("Max index (layer, head):", max_index)
    '''
    if (max_index[0], max_index[1]) not in dict_questions:
        dict_questions[(max_index[0], max_index[1])] = [question]
    else:
        dict_questions[(max_index[0], max_index[1])].append(question)
    '''
    print(max_index[0])
    x.append(max_index[0])
    y.append(max_index[1])
    print(x)
    induction.append(induction_score_store)
 
print(x, y)
df = pd.DataFrame({
    "Layer": x,
    "Head": y,
    "Question": questions[:len(x)]
})
#print(dict_questions[(int(x[0]), int(y[0]))])
px.scatter(
    data_frame=df,
    x="Layer",
    y="Head",
    hover_data=["Question"],
    title="Induction Scores for Questions"
).show()
#px.scatter(x=np.array(x), y=np.array(y), labels={"x": "Layer", "y": "Head"},  title="Induction Scores for Questions", hover_data=["question"]).show()

0+0=


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

The problem here is that many of the questions have the same max index. Therefore, it might not be the best method to visualize the induction heads.

In [None]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, perplexity=30, random_state=42, learning_rate="auto", early_exaggeration=5.0, n_iter=5000, init="pca")
induction_embeds = tsne.fit_transform(np.array(induction))



NameError: name 'np' is not defined

In [57]:
print(model.to_tokens(""))

tensor([[50256]])


# Direct Logit Attribution

A technique to see what components contrinute the most to an output.

# Logit Lens

A method for seeing how the predictions change throughout a network. We "pretend" that a certain layer is the last layer and then see what the prediction would have been.