In [1]:
import functools
import sys
from pathlib import Path
from typing import Callable
import circuitsvis as cv
import einops
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from eindex import eindex
from IPython.display import display
from jaxtyping import Float, Int
from torch import Tensor
from tqdm import tqdm
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from transformer_lens.hook_points import HookPoint

# Choose the compute device in priority order: Apple MPS, then CUDA, then plain CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# Make the course's `exercises` directory importable. The course root is the first
# ancestor of the current working directory that contains the chapter folder
# (`next` raises StopIteration if no ancestor does).
chapter = "chapter1_transformer_interp"
section = "part2_intro_to_mech_interp"
root_dir = next(filter(lambda parent: (parent / chapter).exists(), Path.cwd().parents))
exercises_dir = root_dir.joinpath(chapter, "exercises")
section_dir = exercises_dir.joinpath(section)
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part2_intro_to_mech_interp.tests as tests

from plotly_utils import (
    hist,
    imshow,
    plot_comp_scores,
    plot_logit_attribution,
    plot_loss_difference,
)

# Disable autograd globally: this notebook only runs forward passes and inspects
# activations/weights, so skipping gradient tracking saves computation time and memory.
t.set_grad_enabled(False)
# True when this code is executed as the top-level module (`__name__ == "__main__"`).
MAIN = __name__ == "__main__"

In [2]:
# Configure and build a 2-layer, attention-only "shortformer" transformer.
cfg = HookedTransformerConfig(
    # --- architecture ---
    n_layers=2,
    n_heads=12,
    d_head=64,
    d_model=768,
    attn_only=True,  # no MLP blocks (the default is False)
    normalization_type=None,  # no layernorm (the default is "LN", i.e. layernorm with weights & biases)
    attention_dir="causal",
    # Shortformer positional embeddings: positional information is not part of the residual
    # stream that the value/output path (x @ W_V @ W_O) reads from, so position cannot be
    # moved between sequence positions by the OV circuit.
    positional_embedding_type="shortformer",
    # --- vocabulary / context ---
    d_vocab=50278,
    n_ctx=2048,
    tokenizer_name="EleutherAI/gpt-neox-20b",
    # --- bookkeeping ---
    use_attn_result=True,  # needed for the per-head `cache["result", layer]` lookups used later
    seed=398,
)

model = HookedTransformer(cfg)

In [3]:
# Import trained weights for our transformer: fetch the checkpoint file from the
# Hugging Face Hub and load its state dict into the freshly constructed `model`.
from huggingface_hub import hf_hub_download
REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"
# `weights_path` is the value returned by hf_hub_download and is passed straight to t.load below.
weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
# map_location places the loaded tensors on `device`; weights_only=True restricts
# unpickling to tensor/primitive data rather than arbitrary Python objects.
pretrained_weights = t.load(weights_path, map_location=device, weights_only=True)
# As the last expression of the cell, the load result is echoed as the cell output.
model.load_state_dict(pretrained_weights)

<All keys matched successfully>

In [4]:
# Get repeated text
def generate_repeated_tokens(model: HookedTransformer, seq_len: int, batch_size: int = 1):
    """
    Generates a batch of repeated random token sequences of the form:
    <BOS> sequence sequence

    Args:
        model: supplies the BOS token id (`model.tokenizer.bos_token_id`), the vocabulary
            size (`model.cfg.d_vocab`) and the device the tokens are placed on
            (`model.W_E.device`).
        seq_len: length of the random sequence before it is repeated.
        batch_size: number of independent sequences to generate.

    Outputs are:
        rep_tokens: int64 tensor of shape [batch_size, 1+2*seq_len], on the model's device
    """
    prefix = t.full((batch_size, 1), model.tokenizer.bos_token_id, dtype=t.long)
    # Sample with the default (CPU) generator so results do not depend on which accelerator
    # is in use; the finished tensor is moved to the model's device below.
    seq = t.randint(0, model.cfg.d_vocab, (batch_size, seq_len))
    rep_tokens = t.cat((prefix, seq, seq), dim=-1)
    # Keep tokens on the same device as the model so later indexing/comparisons against
    # model outputs never mix devices.
    return rep_tokens.to(model.W_E.device)

In [5]:
def run_and_cache_model_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch_size: int = 1
) -> tuple[Tensor, Tensor, ActivationCache]:
    """
    Builds a repeated random-token batch with `generate_repeated_tokens`, runs the model on it
    with activation caching, and hands back everything needed for later analysis.

    Returns a (tokens, logits, cache) triple:
        rep_tokens: [batch_size, 1+2*seq_len]
        rep_logits: [batch_size, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    """
    repeated = generate_repeated_tokens(model, seq_len, batch_size)
    model_logits, activations = model.run_with_cache(repeated)
    return repeated, model_logits, activations

In [6]:
# First method of analysis: visualizing the attn patterns
# A single repeated sequence is enough for visual inspection; `remove_batch_dim=True`
# makes every cached activation drop its leading batch axis.
tokens = generate_repeated_tokens(model, 50, 1)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
for layer in range(model.cfg.n_layers):
    attn = cache["pattern", layer]
    display(
        cv.attention.attention_patterns(
            tokens=model.to_str_tokens(tokens),
            attention=attn,
            # Take the head count from the config rather than hard-coding 12, so the labels
            # stay in sync with the model.
            attention_head_names=[f"L{layer}H{i}" for i in range(model.cfg.n_heads)],
        )
    )

Upon inspection, we can see that head 0.7 is strongly prev-token, and heads 1.4 and 1.10 attend to the token following the prev occurrence. This suggests induction, although we don't technically have proof that they are performing induction and successfully copying. For this, we turn to logit attribution.

In [8]:
# Method 2: Logit attribution
def logit_attribution(
    embed: Float[Tensor, "seq d_model"],
    l1_results: Float[Tensor, "seq nheads d_model"],
    l2_results: Float[Tensor, "seq nheads d_model"],
    W_U: Float[Tensor, "d_model d_vocab"],
    tokens: Int[Tensor, "seq"],
) -> Float[Tensor, "seq-1 n_components"]:
    """
    Direct logit attribution: how much each residual-stream component contributes to the logit
    of the token that actually comes next, at every position except the last.

    Inputs:
        embed: the embeddings of the tokens
        l1_results: per-head outputs of the first attention layer (head is the middle dimension)
        l2_results: per-head outputs of the second attention layer (head is the middle dimension)
        W_U: the unembedding matrix
        tokens: the token ids of the sequence

    Returns:
        Tensor of shape (seq_len-1, n_components), the concatenation along dim=-1 of:
            the direct (embedding) path, shape (seq-1, 1)
            the first attention layer's heads, shape (seq-1, n_heads)
            the second attention layer's heads, shape (seq-1, n_heads)
        so n_components = 1 + 2*n_heads
    """
    # Column i is the unembedding direction of the token that follows position i.
    next_token_directions = W_U[:, tokens[1:]]

    # The last position has no "next token", hence the [:-1] on every component.
    direct_path = einops.einsum(next_token_directions, embed[:-1], "emb seq, seq emb -> seq")
    per_layer_heads = [
        einops.einsum(next_token_directions, head_outputs[:-1], "emb seq, seq nhead emb -> seq nhead")
        for head_outputs in (l1_results, l2_results)
    ]
    return t.concat([direct_path.unsqueeze(-1), *per_layer_heads], dim=-1)

# Attribute logits on the single repeated sequence from the attention-pattern cell above:
# `cache` was produced with remove_batch_dim=True, and `tokens.squeeze()` drops the size-1
# batch axis of `tokens` so both line up with the un-batched shapes logit_attribution expects.
logit_attr = logit_attribution(cache["embed"], cache["result", 0], cache["result", 1], model.W_U, tokens.squeeze())

# Plot the attribution tensor computed above for the same (un-batched) token sequence.
plot_logit_attribution(
    model,
    logit_attr,
    tokens.squeeze(),
    title="Logit attribution (random induction prompt)",
)

We see that heads 1.4 and 1.10 are in fact dominating the logit attribution for the second half of the repeated sequence, strongly increasing the logits for the correct prediction. This is more concrete evidence that they are performing copying.

In [6]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """
    Log-probability the model assigns to the token that actually comes next.

    Entry [b, s] of the result is log_softmax(logits)[b, s, tokens[b, s+1]], so the last
    position (which has no next token) is dropped.
    """
    log_probs = logits.log_softmax(dim=-1)
    # Pair the prediction made at position s with the token at position s+1.
    next_tokens = tokens[:, 1:].unsqueeze(-1)
    return log_probs[:, :-1].gather(dim=-1, index=next_tokens).squeeze(-1)

In [9]:
# Method 3 - Ablations
# A fresh batch of repeated sequences for the ablation experiments. Note that this rebinds
# the notebook-level `tokens`, `seq_len` and `batch_size`; get_ablation_scores below reads
# `seq_len` from this scope.
seq_len = 50
batch_size = 10
tokens = generate_repeated_tokens(model, seq_len, batch_size)

def head_zero_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    """
    Zero-ablates one head in place.

    z is the weighted sum of value vectors, before applying W_O; every entry belonging to
    head `head_index_to_ablate` is overwritten with 0. Nothing is returned, the tensor is
    modified in place.
    """
    z[:, :, head_index_to_ablate].fill_(0.0)

def head_mean_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    """
    Mean-ablates one head in place.

    For the chosen head, the z vector at every (batch, position) is replaced by the mean of
    that head's z vectors over the batch at the same position, i.e. the head's "average
    behaviour" is what gets projected through W_O afterwards. With batch_size = 1 the mean
    equals the original value, so the hook changes nothing.
    """
    batch_mean = z[:, :, head_index_to_ablate].mean(dim=0, keepdim=True)  # [1, seq, d_head]
    z[:, :, head_index_to_ablate] = batch_mean  # broadcast back over the batch axis

def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    ablation_function: Callable = head_zero_ablation_hook,
) -> Float[Tensor, "n_layers n_heads"]:
    """
    Measures how much the loss on the repeated half of the sequence increases when each
    attention head is ablated in turn.

    Args:
        model: the transformer to ablate.
        tokens: token batch laid out as <BOS> + sequence + sequence (as produced by
            `generate_repeated_tokens`); the repeated-sequence length is derived from its
            width rather than read from a notebook global.
        ablation_function: hook taking (z, hook, head_index_to_ablate) and editing `z` in place.

    Returns:
        CPU tensor of shape [n_layers, n_heads]; entry [layer, head] is
        (loss with that head ablated) - (loss with no ablation).
    """
    # <BOS> + seq + seq  =>  the repeated sequence has (width - 1) // 2 tokens.
    rep_seq_len = (tokens.shape[-1] - 1) // 2

    def loss_on_repeat(logits: Tensor) -> float:
        # get_log_probs applies log_softmax itself, so raw logits are passed in. Only the
        # last (rep_seq_len - 1) predictions are scored: those are the positions inside the
        # second copy where the next token is predictable from the first copy.
        return -get_log_probs(logits, tokens)[:, -(rep_seq_len - 1) :].mean().item()

    model.reset_hooks()
    loss_no_ablation = loss_on_repeat(model(tokens, return_type="logits"))

    # Losses are pulled out as Python floats, so this CPU tensor never gets mixed with
    # tensors living on the model's device.
    loss_differences = t.zeros((model.cfg.n_layers, model.cfg.n_heads))
    for layer in range(model.cfg.n_layers):
        hook_name = utils.get_act_name("z", layer)
        for head in range(model.cfg.n_heads):
            hook_fn = functools.partial(ablation_function, head_index_to_ablate=head)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(hook_name, hook_fn)])
            loss_differences[layer, head] = loss_on_repeat(ablated_logits) - loss_no_ablation
    return loss_differences

# Loss increase per head when that head's z vectors are set to zero.
zero_ablation_loss_differences = get_ablation_scores(model, tokens, head_zero_ablation_hook)

# Heatmap with layers on the y axis and heads on the x axis; text_auto prints each
# cell's value with two decimals.
imshow(
    zero_ablation_loss_differences,
    labels={"x": "Head", "y": "Layer", "color": "Loss diff"},
    title="Loss Difference After Zero Ablation",
    text_auto=".2f",
    width=900,
    height=350,
)

# Same experiment, but each head's z vectors are replaced by their mean over the batch
# instead of by zeros.
mean_ablation_loss_differences = get_ablation_scores(model, tokens, head_mean_ablation_hook)

imshow(
    mean_ablation_loss_differences,
    labels={"x": "Head", "y": "Layer", "color": "Loss diff"},
    title="Loss Difference After Mean Ablation",
    text_auto=".2f",
    width=900,
    height=350,
)

While we suspected head 0.7 was part of the induction circuit from visualizing its attention pattern, direct logit attribution
does not show anything for it. By "switching off" head 0.7 and observing a big hit on the loss, we see the transformer is
not copying as well (we are looking at second half of repeated sequence, so copying = right answer), providing evidence
that head 0.7 is an important part of the induction circuit.

From the mean-ablation experiment, two new heads emerge as important to the induction circuit: heads 0.4 and 0.11.
What are they doing in the induction circuit? Head 0.11 often attends strongly to the current token, while head 0.4 attends to
the past few tokens, most strongly to the token two positions back. Their exact role is still to be determined (TBD).

Note: In this notebook, we have identified attention heads that comprise an induction circuit. However, we have NOT yet
reverse-engineered the circuit, in that we don't know exactly how it works: is it Q-composition? K-composition? V-composition?

----------------------------------- REVERSE ENGINEERING --------------------------------------------

The full OV circuit is W_E @ W_V @ W_O @ W_U.

This is a bilinear form: with one-hot token vectors, token_i.T @ full_OV_circuit @ token_j is what attending to source token_i adds to the logit of output token_j at the position doing the attending.

In [7]:
# Head under study: head 4 of layer index 1 (written "1.4" in the prose).
head_index = 4
layer = 1


# Per-head value and output weights, plus the embedding / unembedding matrices.
W_O = model.W_O[layer, head_index]
W_V = model.W_V[layer, head_index]
W_E = model.W_E
W_U = model.W_U
# Keep W_V @ W_O in factored form so the product with W_E and W_U below does not have to
# be materialised as a dense matrix.
OV_circuit = FactoredMatrix(W_V, W_O)
full_OV_circuit = W_E @ OV_circuit @ W_U

# Check the result against the course-provided reference test.
tests.test_full_OV_circuit(full_OV_circuit, model, layer, head_index)

All tests in `test_full_OV_circuit` passed!


In [8]:
# Draw 200 random token ids and look only at the sub-block of the full OV circuit whose
# rows and columns are those tokens.
indices = t.randint(0, model.cfg.d_vocab, (200,))
# Indexing the FactoredMatrix before calling .AB computes A[indices, :] @ B[:, indices],
# i.e. presumably the block (A @ B)[indices][:, indices] (rows and columns selected
# independently, not paired element-wise) - TODO confirm against FactoredMatrix.__getitem__.
full_OV_circuit_sample = full_OV_circuit[indices, indices].AB # Computes A[indices, :] @ B[:, indices] without forming the full A @ B

# Rows are the sampled input tokens, columns the sampled output tokens.
imshow(
    full_OV_circuit_sample,
    labels={"x": "Logits on output token", "y": "Input token"},
    title="Full OV circuit for copying head",
    width=700,
    height=600,
)

In [9]:
def top_1_acc(full_OV_circuit: FactoredMatrix, batch_size: int = 1000) -> float:
    """
    Compute the argmax of each row (i.e. over dim=1) and return the fraction of rows whose
    maximum value lies on the circuit diagonal.

    Args:
        full_OV_circuit: square factored matrix; indexing it with a 1-D index tensor and
            reading `.AB` must yield the corresponding dense rows.
        batch_size: number of rows materialised at a time, so the full dense product is
            never held in memory at once.

    Returns:
        Fraction (between 0 and 1) of rows i for which argmax_j M[i, j] == i.
    """
    n_rows = full_OV_circuit.shape[0]
    n_on_diagonal = 0

    for indices in t.split(t.arange(n_rows), batch_size):
        AB_slice = full_OV_circuit[indices].AB  # [<=batch_size, n_cols]
        argmaxs = AB_slice.argmax(dim=1)  # [<=batch_size]
        # The row ids are created on CPU; move them next to the argmaxes so the comparison
        # also works when the circuit lives on an accelerator.
        n_on_diagonal += (argmaxs == indices.to(argmaxs.device)).sum().item()
    return n_on_diagonal / n_rows
# Report the diagonal hit rate of the single-head (1.4) full OV circuit.
print(
    f"Fraction of the time that the best logit is on the diagonal: {top_1_acc(full_OV_circuit):.4f}"
)

Fraction of the time that the best logit is on the diagonal: 0.3079


In [10]:
# Combine heads 4 and 10 of layer index 1 into one "effective" OV circuit.
# Stacking W_O along its input (d_head) axis and W_V along its output (d_head) axis means
# W_V_both @ W_O_both equals the sum over the two heads of W_V[h] @ W_O[h].
W_O_both = einops.rearrange(model.W_O[1, [4, 10]], "head d_head d_model -> (head d_head) d_model")
W_V_both = einops.rearrange(model.W_V[1, [4, 10]], "head d_model d_head -> d_model (head d_head)")

# Kept factored, as with the single-head circuit, to avoid forming the dense product.
W_OV_eff = W_E @ FactoredMatrix(W_V_both, W_O_both) @ W_U

In [11]:
# Report the diagonal hit rate of the combined (heads 1.4 + 1.10) effective OV circuit.
print(
    f"Fraction of the time that the best logit is on the diagonal: {top_1_acc(W_OV_eff):.4f}"
)

Fraction of the time that the best logit is on the diagonal: 0.9556


For the effective OV circuit, we see that the highest values are on the diagonal. The diagonal value AA is what a previous occurrence of token_A adds to the logit for token_A at the current position, if the current position attends to that previous occurrence. Since this is the highest value in each row, it means the effective OV circuit is performing copying (when a head attends to a token, it mostly boosts the logit of that same token).

The full QK circuit is (W_pos + W_E) @ W_QK @ (W_pos.T + W_E.T). We claim we can ignore the W_E part (will be justified shortly),
and just look at W_pos @ W_QK @ W_pos.T.

This is a bilinear form: with one-hot position vectors, pos_i.T @ W_pos @ W_QK @ W_pos.T @ pos_j is the attention score that query position i gives to key position j.

In [12]:
# Head under study: head 7 of layer index 0 (the suspected previous-token head).
layer = 0
head_index = 7

# Compute full QK matrix (for positional embeddings)
W_pos = model.W_pos
W_QK = model.W_Q[layer, head_index] @ model.W_K[layer, head_index].T
pos_by_pos_scores = W_pos @ W_QK @ W_pos.T # [2048, 2048]

# Mask, scale and softmax the scores: the lower-triangular mask enforces causality (a query
# position may only see keys at or before it); masked entries get a large negative score so
# they receive ~0 probability after the softmax.
mask = t.tril(t.ones_like(pos_by_pos_scores)).bool()
pos_by_pos_pattern = t.where(
    mask, pos_by_pos_scores / model.cfg.d_head**0.5, -1.0e6
).softmax(-1)

# Plot the results. diag(-1) holds the attention each position pays to the position
# immediately before it.
print(f"Avg lower-diagonal value: {pos_by_pos_pattern.diag(-1).mean():.4f}")
imshow(
    utils.to_numpy(pos_by_pos_pattern[:200, :200]),
    labels={"x": "Key", "y": "Query"},
    # The title now matches the [:200, :200] slice actually plotted (it previously said 100).
    title="Attention patterns for prev-token QK circuit, first 200 indices",
    width=700,
    height=600,
)

Avg lower-diagonal value: 0.9978


Since the lower diagonal completely dominates the probability distribution across each row, we see that position i gives the highest attention score to position i - 1, and so this is a prev-token QK circuit. Now we will show why it was justified to ignore the token embeddings:

In [18]:
# Run the model on one fresh repeated sequence and keep tokens, logits and activation cache.
# The batch dimension is NOT removed here, so cached activations keep a leading axis of size 1.
seq_len = 50
batch_size = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(
    model, seq_len, batch_size
)


In [32]:
# Query/key weights of head 0.7, and the two inputs to layer 0: token embeddings and
# positional embeddings.
W_Q = model.W_Q[0, 7]
W_K = model.W_K[0, 7]
emb = rep_cache["embed"]
pos_emb = rep_cache["pos_embed"]
# NOTE(review): this concatenates along dim 0, which only acts as a "component" axis
# (0 = embed, 1 = pos embed) if `rep_cache` still has a batch axis of size 1 - the einsum
# patterns below call that axis `b`. Running this cell after `rep_cache.remove_batch_dim()`
# or with batch_size > 1 would presumably break or mix components - confirm before reuse.
decomposed_input = t.cat((emb, pos_emb), dim = 0)
query_decomposed = einops.einsum(decomposed_input, W_Q, "b p d_model, d_model d_head -> b p d_head")
key_decomposed = einops.einsum(decomposed_input, W_K, "b p d_model, d_model d_head -> b p d_head")
# get norms (over d_head), giving one value per (component, position)
query_norms = query_decomposed.norm(dim = -1)
key_norms = key_decomposed.norm(dim = -1)

# grid plot of query norms: position vs embed, pos emb
component_labels = ["Embed", "PosEmbed"]
imshow(
    query_norms,
    labels={"x": "Position", "y": "Component"},
    title=f"Norms of components of query",
    y=component_labels,
    width=800,
    height=400,
)

# same grid for the key side
imshow(
    key_norms,
    labels={"x": "Position", "y": "Component"},
    title=f"Norms of components of key",
    y=component_labels,
    width=800,
    height=400,
)

Now we decompose the attention patterns for the L1 heads. The query and key vectors for an L1 attention head can be decomposed into 14 terms each -- the embeddings, positional embeddings, and outputs of the 12 L0 heads (each multiplied by W_Q / W_K). These 14 activations are not interpretable since they have no privileged basis -- however, it's a safe bet that larger activations are going to have a greater overall effect on the attention scores, so we can take the norm of each of the 14 activations.

In [34]:
def decompose_qk_input(cache: ActivationCache) -> Float[Tensor, "n_heads+2 posn d_model"]:
    """
    Collects every component that feeds the query/key computation of the second attention
    layer (layer index 1) and stacks them along a new leading axis, in this order:
    token embeddings, positional embeddings, then the output of each layer-0 head.

    `cache` must have had its batch dimension removed. Summing the result over dim 0 gives
    the full query/key input of that layer.
    """
    embed_component = cache["embed"][None]  # [1, posn, d_model]
    pos_component = cache["pos_embed"][None]  # [1, posn, d_model]
    # Move the head axis to the front so each layer-0 head becomes its own component.
    head_components = einops.rearrange(
        cache["result", 0], "posn n_head d_model -> n_head posn d_model"
    )  # [n_heads, posn, d_model]
    return t.cat([embed_component, pos_component, head_components], dim=0)  # [n_heads + 2, posn, d_model]


def decompose_q(
    decomposed_qk_input: Float[Tensor, "n_heads+2 posn d_model"],
    ind_head_index: int,
    model: HookedTransformer,
) -> Float[Tensor, "n_heads+2 posn d_head"]:
    """
    Projects each decomposed QK-input component through the query weights of head
    `ind_head_index` in layer index 1.

    Slice [i] of the result is component_i @ W_Q, so summing over axis 0 recovers the
    head's query vectors (up to any bias).
    """
    query_weights = model.W_Q[1, ind_head_index]
    return einops.einsum(
        decomposed_qk_input,
        query_weights,
        "component posn d_model, d_model d_head -> component posn d_head",
    )


def decompose_k(
    decomposed_qk_input: Float[Tensor, "n_heads+2 posn d_model"],
    ind_head_index: int,
    model: HookedTransformer,
) -> Float[Tensor, "n_heads+2 posn d_head"]:
    """
    Projects each decomposed QK-input component through the key weights of head
    `ind_head_index` in layer index 1.

    Slice [i] of the result is component_i @ W_K, so summing over axis 0 recovers the
    head's key vectors (up to any bias).
    """
    key_weights = model.W_K[1, ind_head_index]
    return einops.einsum(
        decomposed_qk_input,
        key_weights,
        "component posn d_model, d_model d_head -> component posn d_head",
    )


# Recompute rep tokens/logits/cache, if we haven't already
seq_len = 50
batch_size = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(model, seq_len, batch_size)
# In-place: from here on every activation in `rep_cache` lacks the batch axis, which is
# what decompose_qk_input's rearrange pattern expects.
rep_cache.remove_batch_dim()

# Induction head under study: head 4 of layer index 1.
ind_head_index = 4

# First we get decomposed q and k input, and check they're what we expect
decomposed_qk_input = decompose_qk_input(rep_cache)
decomposed_q = decompose_q(decomposed_qk_input, ind_head_index, model)
decomposed_k = decompose_k(decomposed_qk_input, ind_head_index, model)
# The components must sum to the layer's residual input plus the positional embedding
# (added separately here, i.e. pos_embed is not already contained in resid_pre).
t.testing.assert_close(
    decomposed_qk_input.sum(0), rep_cache["resid_pre", 1] + rep_cache["pos_embed"], rtol=0.01, atol=1e-05
)
# ...and the projected components must sum to the head's actual cached queries / keys.
t.testing.assert_close(decomposed_q.sum(0), rep_cache["q", 1][:, ind_head_index], rtol=0.01, atol=0.001)
t.testing.assert_close(decomposed_k.sum(0), rep_cache["k", 1][:, ind_head_index], rtol=0.01, atol=0.01)

# Second, we plot our results
# Row order matches decompose_qk_input: token embeddings, positional embeddings, then one
# row per layer-0 head.
component_labels = ["Embed", "PosEmbed"] + [f"0.{h}" for h in range(model.cfg.n_heads)]
for decomposed_input, name in [(decomposed_q, "query"), (decomposed_k, "key")]:
    imshow(
        # Plot the actual L2 norm over d_head so the values agree with the "Norms" title
        # (the squared norm was plotted before) and with the layer-0 norm plots earlier on.
        utils.to_numpy(decomposed_input.norm(dim=-1)),
        labels={"x": "Position", "y": "Component"},
        title=f"Norms of components of {name}",
        y=component_labels,
        width=800,
        height=400,
    )

We see that the output of the L0.7 attention head dominates the key vectors. The token embeddings dominate the query vectors. This motivates us to look at the attention scores obtained when the query comes only from the token embeddings, and the keys come only from the L0.7 output, since this is a good approximation of the real attention scores by our previous analysis.

In [None]:
def decompose_attn_scores(decomposed_q: t.Tensor, decomposed_k: t.Tensor) -> Tensor:
    """
    Dot-products every query component with every key component at every pair of positions.

    Both inputs have shape [component, posn, d_head]. The result has shape
    [query_component, key_component, query_pos, key_pos]; entry [i, j, p, q] is the
    (unscaled) attention-score contribution of query component i at position p against key
    component j at position q, so summing over the first two axes gives the full scores.
    """
    return einops.einsum(
        decomposed_q,
        decomposed_k,
        "q_comp q_pos d_head, k_comp k_pos d_head -> q_comp k_comp q_pos k_pos",
    )


# Check decompose_attn_scores against the course-provided reference test.
tests.test_decompose_attn_scores(decompose_attn_scores, decomposed_q, decomposed_k)

All tests in `test_decompose_attn_scores` passed!


In [42]:
# Score contribution for every (query component, key component) pair, then its standard
# deviation over all (query_pos, key_pos) pairs: one number per component pair.
decomposed_scores = decompose_attn_scores(decomposed_q, decomposed_k)
decomposed_stds = einops.reduce(
    decomposed_scores,
    "query_decomp key_decomp query_pos key_pos -> query_decomp key_decomp",
    t.std,
)

# First plot: attention score contribution from (query_component, key_component) = (Embed, L0H7)
# Index 0 is "Embed" and index 9 is head 0.7 (2 non-head components + head 7) in the
# component ordering used by `component_labels`. tril keeps only causal (key <= query)
# entries; scores are scaled by sqrt(d_head) as in attention.
imshow(
    utils.to_numpy(t.tril(decomposed_scores[0, 9]) / model.cfg.d_head ** 0.5),
    title="Attention score contributions from query = embed, key = output of L0H7<br>(by query & key sequence positions)",
    width=700,
)

# Second plot: std dev over query and key positions, shown by component
imshow(
    utils.to_numpy(decomposed_stds),
    labels={"x": "Key Component", "y": "Query Component"},
    title="Std dev of attn score contributions across sequence positions<br>(by query & key component)",
    x=component_labels,
    y=component_labels,
    width=700,
)

We see the characteristic induction pattern of attending primarily to (seq_len - 1) tokens backward, since this is the token following the previous occurrence of the current token. The standard deviation tells us more -- when we decompose the attention scores into 14^2 = 196 terms, we see that only when query = embed, key = L0.7 output is there variation in attention scores across sequence positions. This means the other 195 terms contribute roughly constant scores (i.e. something close to a uniform distribution) and are not doing anything, so the induction pattern is primarily due to query = embed, key = L0.7 output.

We have found a prev-token QK circuit in head 0.7 and a K-composition QK circuit in head 1.4.

We can combine these to create the K_comp_full_circuit = W_E @ W_QK_1.4 @ W_OV_0.7.T @ W_E.T.

This is a bilinear form, where token_A.T @ W_E @ W_QK @ W_OV.T @ W_E.T @ token_A is how much a residual position p that looks like token_A.T @ W_E will attend to a residual position q that looks like token_A.T @ W_E @ W_OV_0.7 in attention head 1.4. We know from the query decomposition that p = index of the current occurrence of token_A, and from the L0.7 QK circuit that q = index of the previous occurrence of token_A + 1.

Hence, if we have a sequence AB ... AB, then the second A token attends most prominently to the first B token in head 1.4 (if the diagonal values of the full K-comp circuit are high, which is what we find in the next code block).

Then from the head 1.4 OV circuit, we know that when token_B @ W_E @ W_OV_1.4 gets added to the second A token position logits, it will mainly add to the logits for B, and this completes the reverse engineering of copying. 

In [51]:
def find_K_comp_full_circuit(
    model: HookedTransformer, prev_token_head_index: int, ind_head_index: int
) -> FactoredMatrix:
    """
    Builds the K-composition circuit between a layer-0 previous-token head and a layer-1
    induction head as a (vocab, vocab)-size FactoredMatrix.

    Rows are the query side (token embeddings straight into the induction head's W_Q);
    columns are the key side (token embeddings passed through the previous-token head's
    OV circuit and then the induction head's W_K).
    """
    # Query side: W_E @ W_Q of the induction head.
    query_side = model.W_E @ model.W_Q[1, ind_head_index]
    # Key side: W_E @ W_V @ W_O of the previous-token head, then W_K of the induction head.
    # Multiplying left to right keeps every intermediate at most (vocab, d_model) in size.
    key_side = (
        model.W_E
        @ model.W_V[0, prev_token_head_index]
        @ model.W_O[0, prev_token_head_index]
        @ model.W_K[1, ind_head_index]
    )
    return FactoredMatrix(query_side, key_side.T)


# Previous-token head 0.7 composed (via keys) with induction head 1.4.
prev_token_head_index = 7
ind_head_index = 4
K_comp_circuit = find_K_comp_full_circuit(model, prev_token_head_index, ind_head_index)

# Check the implementation against the course-provided reference test.
tests.test_find_K_comp_full_circuit(find_K_comp_full_circuit, model)

# top_1_acc counts rows whose largest entry is on the diagonal, i.e. query tokens whose
# highest-scoring key token is the same token.
print(f"Fraction of tokens where the highest activating key is the same token: {top_1_acc(K_comp_circuit):.4f}")

All tests in `test_find_K_comp_full_circuit` passed!
Fraction of tokens where the highest activating key is the same token: 0.6285
