# Q&A example with TransformerLens

In [6]:
import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a tensor as a red/blue heatmap whose colour scale is centred on zero.

    Args:
        tensor: Values to plot; converted with ``utils.to_numpy`` before plotting.
        renderer: Passed straight to the figure's ``show`` call.
        xaxis: Label used for the x axis.
        yaxis: Label used for the y axis.
        **kwargs: Forwarded unchanged to ``px.imshow``.

    Returns:
        None. The figure is displayed as a side effect.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a tensor as a line plot.

    Args:
        tensor: Values to plot; converted with ``utils.to_numpy`` before plotting.
        renderer: Passed straight to the figure's ``show`` call.
        xaxis: Label used for the x axis.
        yaxis: Label used for the y axis.
        **kwargs: Forwarded unchanged to ``px.line``.

    Returns:
        None. The figure is displayed as a side effect.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    line_figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    line_figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of ``y`` against ``x``.

    Args:
        x: Horizontal coordinates; converted with ``utils.to_numpy``.
        y: Vertical coordinates; converted with ``utils.to_numpy``.
        xaxis: Label used for the x axis.
        yaxis: Label used for the y axis.
        caxis: Label used for the colour axis.
        renderer: Passed straight to the figure's ``show`` call.
        **kwargs: Forwarded unchanged to ``px.scatter``.

    Returns:
        None. The figure is displayed as a side effect.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    scatter_figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    scatter_figure.show(renderer)

In [8]:
def vis_attn_patterns(model, text, layers, compact=True):
    """Display circuitsvis attention visualisations for the given layers.

    Args:
        model: Object providing ``to_str_tokens`` and ``run_with_cache``
            (a ``HookedTransformer`` in this notebook).
        text: Input forwarded to both ``model.to_str_tokens`` and
            ``model.run_with_cache``. The cache is built with
            ``remove_batch_dim=True``, so this is presumably meant to be a
            single example -- TODO confirm.
        layers: Iterable of layer indices whose attention pattern is shown.
        compact: If true, render with ``cv.attention.attention_patterns``;
            otherwise with ``cv.attention.attention_heads``.

    Returns:
        None. One visualisation per layer is displayed as a side effect.
    """
    str_tokens = model.to_str_tokens(text)
    # Only the activation cache is needed; the logits are discarded.
    _, cache = model.run_with_cache(text, remove_batch_dim=True)

    # Both renderers are called with the same keyword arguments, so choose once.
    render = cv.attention.attention_patterns if compact else cv.attention.attention_heads
    for layer in layers:
        display(render(tokens=str_tokens, attention=cache["pattern", layer]))

We load the pretrained GPT-2 small model:

In [9]:
# Load the pretrained "gpt2-small" model into a HookedTransformer.
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


We load the example questions and rationales and tokenize them:

In [None]:
import json

# Read the examples. The first line is discarded and every remaining line is
# parsed as one JSON object after stripping a trailing "," or "]".
# NOTE(review): this presumes the file is a JSON array written one object per
# line (opening "[" alone on the first line) -- confirm against test.tok.json.
# JSON text is UTF-8, so the encoding is given explicitly instead of relying on
# the platform default.
with open("test.tok.json", "r", encoding="utf-8") as infile:
    infile.readline()
    lines = infile.readlines()

# Lines that are empty once the array punctuation is stripped (blank lines, or a
# closing "]" on its own line) are skipped instead of crashing json.loads("").
data = [
    json.loads(payload)
    for payload in (raw.strip("\n").rstrip(",").rstrip("]") for raw in lines)
    if payload.strip()
]

questions = [record["question"] for record in data]
answers = [record["rationale"] for record in data]

# Tokenize the questions and answers. Despite the `_embeds` suffix these are
# token tensors produced by `model.to_tokens`, not embedding vectors.
question_embeds = model.to_tokens(questions)
answer_embeds = model.to_tokens(answers)

# String form of the tokens; given a list of texts, `to_str_tokens` returns one
# list of token strings per text.
str_questions = model.to_str_tokens(questions)
str_answers = model.to_str_tokens(answers)

: 

In [None]:
# Half of the (padded) token length of the answers batch.
# `str_answers` holds one list of token strings per answer, so `len(str_answers)`
# would be the number of examples rather than a sequence length; the token
# tensor's last dimension is used instead.
# NOTE(review): confirm that half the padded length is the intended offset for
# the induction stripe on this (non-repeated) text.
seq_len = answer_embeds.shape[-1] // 2
# One score per (layer, head); rows are filled in by `induction_score_hook`.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

# A function for the average induction score
def induction_score_hook(activation_pattern, hook):
    """
    Forward hook that records the mean induction-stripe attention of each head.

    Reads the module-level `seq_len` and overwrites one row of the module-level
    `induction_score_store`. Nothing is returned.

    Args:
        activation_pattern (torch.Tensor): Attention pattern for one layer. After the
            diagonal over its last two dimensions is taken it is reduced as
            "batch head_index position", so it is presumably laid out as
            (batch, head_index, destination_position, source_position).
        hook (HookPoint): The hook point that triggered this function; `hook.layer()`
            selects which row of `induction_score_store` is written.
    """
    # Diagonal `seq_len - 1` below the main one: the attention each destination
    # position pays to the source position seq_len-1 tokens back.
    # (Only positions far enough into the sequence have an entry on it.)
    stripe_offset = 1 - seq_len
    stripe = activation_pattern.diagonal(offset=stripe_offset, dim1=-2, dim2=-1)
    # Collapse batch and position so a single mean score per head remains.
    per_head_score = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    # Write the scores into this layer's row of the shared store.
    induction_score_store[hook.layer(), :] = per_head_score

# Boolean filter on activation names: true only for attention pattern names.
def pattern_hook_names_filter(name):
    """Return True when `name` ends with "pattern"."""
    return name.endswith("pattern")

# Run with hooks (this is where we write to the `induction_score_store` tensor)
model.run_with_hooks(
    answer_embeds,
    return_type=None,  # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)

# Plot the induction scores for each head in each layer.
# `imshow` displays the figure itself and returns nothing, so no value is kept.
imshow(induction_score_store, xaxis="Head", yaxis="Layer", title="Induction Score by Head", text_auto=".2f")

# Visualise the attention patterns for a single example. `vis_attn_patterns`
# converts its input with `to_str_tokens` and drops the batch dimension of the
# cache, so it cannot take the whole `answer_embeds` batch.
vis_attn_patterns(model, answers[0], [5, 6, 7, 8])
