# TransformerLens & induction circuits

![kcomp_diagram.png](https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/kcomp_diagram.png)

In [1]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install transformer_lens
  %pip install circuitsvis
except:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")

Running as a Jupyter notebook - intended for development only!


In [3]:
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "notebook_connected" # or use "browser" if you want plots to open with browser
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from fancy_einsum import einsum
from torchtyping import TensorType as TT
from typing import List, Optional, Callable, Tuple, Union
import functools
from tqdm import tqdm
from IPython.display import display

from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

# Globally switch off autograd: gradients are not needed for the contents of
# this notebook, so skipping them saves computation time.
t.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", caxis="", **kwargs):
    '''
    Builds a plotly heatmap of `tensor` using a red/blue colour scale centred on zero.

    `xaxis`, `yaxis` and `caxis` label the x axis, y axis and colour bar. Any extra
    keyword arguments are forwarded to `px.imshow`. `renderer` is accepted but unused.

    Returns the figure (it is not shown here).
    '''
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    return px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    '''
    Builds a plotly line chart of `tensor`.

    `xaxis` and `yaxis` label the axes. Any extra keyword arguments are forwarded
    to `px.line`. `renderer` is accepted but unused.

    Returns the figure (it is not shown here).
    '''
    axis_labels = {"x": xaxis, "y": yaxis}
    values = utils.to_numpy(tensor)
    return px.line(values, labels=axis_labels, **kwargs)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    '''
    Builds a plotly scatter plot of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` label the x axis, y axis and colour bar. Any extra
    keyword arguments are forwarded to `px.scatter`. `renderer` is accepted but unused.

    Returns the figure (it is not shown here).
    '''
    x_np, y_np = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    return px.scatter(x=x_np, y=y_np, labels=axis_labels, **kwargs)

def plot_comp_scores(model: HookedTransformer, comp_scores: TT["heads", "heads"], title: str = "", baseline: Optional[t.Tensor] = None) -> go.Figure:
    '''
    Builds a heatmap of composition scores between layer-0 heads (rows) and layer-1 heads (columns).

    Tick labels are "L0H{h}" / "L1H{h}" for h in range(model.cfg.n_heads).

    If `baseline` is given, a diverging red/blue scale centred on `baseline` is used.
    Otherwise a sequential blue scale starting at zero is used.
    '''
    head_ids = range(model.cfg.n_heads)
    if baseline is None:
        colour_opts = dict(color_continuous_scale="Blues", color_continuous_midpoint=None, zmin=0.0)
    else:
        colour_opts = dict(color_continuous_scale="RdBu", color_continuous_midpoint=baseline, zmin=None)
    return px.imshow(
        utils.to_numpy(comp_scores),
        y=[f"L0H{h}" for h in head_ids],
        x=[f"L1H{h}" for h in head_ids],
        labels={"x": "Layer 1", "y": "Layer 0"},
        title=title,
        **colour_opts,
    )

def enable_plotly_in_cell():
  '''
  Prepares the current output cell for plotly figures.

  Emits a <script> tag that loads require.js from the notebook server's static
  files, then initialises plotly's notebook mode with connected=False.
  '''
  import IPython
  from plotly.offline import init_notebook_mode
  requirejs_tag = '''<script src="/static/components/requirejs/require.js"></script>'''
  display(IPython.core.display.HTML(requirejs_tag))
  init_notebook_mode(connected=False)

# Repeats the earlier `t.set_grad_enabled(False)` call from this cell. As the cell's
# last expression, its repr is what the cell displays as output.
t.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1e284c22d40>

In [4]:
def solutions_get_ablation_scores(model: HookedTransformer, tokens: TT["batch", "seq"]) -> TT["n_layers", "n_heads"]:
    '''
    Reference solution: measures how much the loss changes when each attention head is ablated.

    For every (layer, head) the model is re-run with a forward hook on that layer's
    "result" activation which ablates the chosen head. The stored score is
    (loss with ablation) - (loss without ablation).

    Relies on `cross_entropy_loss` and `head_ablation_hook`, which are not defined in
    this part of the notebook - presumably defined elsewhere; confirm before calling.
    '''
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    ablation_scores = t.zeros((n_layers, n_heads), device=model.cfg.device)

    # Baseline loss from an unmodified forward pass
    baseline_loss = cross_entropy_loss(model(tokens, return_type="logits"), tokens)

    for layer in tqdm(range(n_layers)):
        hook_name = utils.get_act_name("result", layer)
        for head in range(n_heads):
            ablate_this_head = functools.partial(head_ablation_hook, head_index_to_ablate=head)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(hook_name, ablate_this_head)])
            ablation_scores[layer, head] = cross_entropy_loss(ablated_logits, tokens) - baseline_loss
    return ablation_scores

def solutions_mask_scores(attn_scores: TT["query_d_model", "key_d_model"]):
    '''
    Reference solution: applies a causal mask to a matrix of attention scores.

    Entries on or below the main diagonal are kept. Entries above it are replaced
    with -1e6, so they vanish after a softmax over the last dimension.
    '''
    keep = t.ones_like(attn_scores).tril().bool()
    fill_value = t.tensor(-1.0e6).to(attn_scores.device)
    return t.where(keep, attn_scores, fill_value)

def solutions_decompose_attn_scores(decomposed_q: t.Tensor, decomposed_k: t.Tensor) -> t.Tensor:
    '''
    Reference solution: attention scores broken down by query component and key component.

    Inputs are indexed (component, position, d_model). The output is indexed
    (q_comp, k_comp, q_pos, k_pos), where each entry is the dot product over d_model
    of one query-component vector with one key-component vector.
    '''
    pattern = "q_comp q_pos d_model, k_comp k_pos d_model -> q_comp k_comp q_pos k_pos"
    return einsum(pattern, decomposed_q, decomposed_k)

def solutions_find_K_comp_full_circuit(model: HookedTransformer, prev_token_head_index: int, ind_head_index: int) -> FactoredMatrix:
    '''
    Reference solution: builds the full K-composition circuit as a FactoredMatrix.

    Left factor (query side):  W_E @ W_Q of layer-1 head `ind_head_index`.
    Right factor (key side):   transpose of W_E @ W_V @ W_O of layer-0 head
                               `prev_token_head_index`, followed by W_K of the layer-1 head.
    '''
    embed = model.W_E

    # Query side: embeddings go straight into the layer-1 head's query weights
    query_side = embed @ model.W_Q[1, ind_head_index]

    # Key side: embeddings pass through the layer-0 head's OV weights before the layer-1 key weights
    key_side = embed @ model.W_V[0, prev_token_head_index] @ model.W_O[0, prev_token_head_index] @ model.W_K[1, ind_head_index]

    return FactoredMatrix(query_side, key_side.T)

def solutions_get_comp_score(
    W_A: TT["in_A", "out_A"], 
    W_B: TT["out_A", "out_B"]
) -> float:
    '''
    Reference solution: composition score between two matrices.

    Returns ||W_A @ W_B|| / (||W_A|| * ||W_B||) as a Python float, where ||.|| is
    the square root of the sum of squared entries.
    '''
    def frobenius_norm(matrix):
        return matrix.pow(2).sum().sqrt()

    denominator = frobenius_norm(W_A) * frobenius_norm(W_B)
    return (frobenius_norm(W_A @ W_B) / denominator).item()

In [5]:
def test_get_ablation_scores(ablation_scores: TT["layer", "head"], model: HookedTransformer, rep_tokens: TT["batch", "seq"]):
    '''
    Checks `ablation_scores` against the reference solution computed on the same
    model and tokens. Raises (via `t.testing.assert_close`) if they differ.
    '''
    reference_scores = solutions_get_ablation_scores(model, rep_tokens)
    t.testing.assert_close(ablation_scores, reference_scores)
    print("All tests in `test_get_ablation_scores` passed!")

def test_full_OV_circuit(OV_circuit: FactoredMatrix, model: HookedTransformer, layer: int, head: int):
    '''
    Checks `OV_circuit` against W_E @ (W_V @ W_O) @ W_U for the given layer and head.

    Only the top-left 20x20 corner of each factored matrix is compared.
    '''
    head_OV = FactoredMatrix(model.W_V[layer, head], model.W_O[layer, head])
    reference_circuit = model.W_E @ head_OV @ model.W_U
    t.testing.assert_close(OV_circuit.get_corner(20), reference_circuit.get_corner(20))
    print("All tests in `test_full_OV_circuit` passed!")

def test_pos_by_pos_pattern(pattern: TT["n_ctx", "n_ctx"], model: HookedTransformer, layer: int, head: int):
    '''
    Checks a position-by-position attention pattern for the given layer and head.

    The expected pattern is built from the positional embeddings alone:
    W_pos @ W_Q @ W_K^T @ W_pos^T, scaled by sqrt(d_head), causally masked, then
    softmaxed over the key dimension. Only the first 50x50 entries are compared.

    Raises (via `t.testing.assert_close`) if `pattern` does not match.
    '''
    W_pos = model.W_pos
    W_QK = model.W_Q[layer, head] @ model.W_K[layer, head].T
    score_expected = W_pos @ W_QK @ W_pos.T
    masked_scaled = solutions_mask_scores(score_expected / model.cfg.d_head ** 0.5)
    pattern_expected = t.softmax(masked_scaled, dim=-1)
    t.testing.assert_close(pattern[:50, :50], pattern_expected[:50, :50])
    # Fixed: the message previously named `test_full_OV_circuit` (copy-paste error)
    print("All tests in `test_pos_by_pos_pattern` passed!")

def test_decompose_attn_scores(decompose_attn_scores: Callable, q: t.Tensor, k: t.Tensor):
    '''
    Runs the supplied `decompose_attn_scores` on (q, k) and checks the result
    against the reference solution. Raises (via `t.testing.assert_close`) on mismatch.
    '''
    candidate = decompose_attn_scores(q, k)
    reference = solutions_decompose_attn_scores(q, k)
    t.testing.assert_close(candidate, reference)
    print("All tests in `test_decompose_attn_scores` passed!")

def test_find_K_comp_full_circuit(find_K_comp_full_circuit: Callable, model: HookedTransformer):
    '''
    Runs the supplied `find_K_comp_full_circuit` with previous-token head 7 and
    induction head 4, and checks it against the reference solution.

    The result must be a FactoredMatrix. Only the top-left 20x20 corner is compared.
    '''
    prev_token_head, induction_head = 7, 4
    candidate: FactoredMatrix = find_K_comp_full_circuit(model, prev_token_head, induction_head)
    reference: FactoredMatrix = solutions_find_K_comp_full_circuit(model, prev_token_head, induction_head)
    assert isinstance(candidate, FactoredMatrix), "Should return a FactoredMatrix object!"
    t.testing.assert_close(candidate.get_corner(20), reference.get_corner(20))
    print("All tests in `test_find_K_comp_full_circuit` passed!")

def test_get_comp_score(get_comp_score: Callable):
    '''
    Runs the supplied `get_comp_score` on random 3x4 and 4x5 matrices and checks
    that it returns a float within 1e-5 of the reference solution.
    '''
    left_matrix = t.rand(3, 4)
    right_matrix = t.rand(4, 5)
    candidate = get_comp_score(left_matrix, right_matrix)
    reference = solutions_get_comp_score(left_matrix, right_matrix)
    assert isinstance(candidate, float)
    assert abs(candidate - reference) < 1e-5
    print("All tests in `test_get_comp_score` passed!")

In [6]:
# Load pretrained GPT-2 Small. Use the GPU when one is available, otherwise fall
# back to CPU instead of crashing on the previously hard-coded `.cuda()`.
gpt2_small = HookedTransformer.from_pretrained("gpt2-small").to("cuda" if t.cuda.is_available() else "cpu")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


In [7]:
# Sample text fed to the model below. It is model input, so its exact wording
# affects the printed loss.
model_description_text = '''## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly. 

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!'''

# Calling the model on a string with return_type="loss" gives its loss on that text
loss = gpt2_small(model_description_text, return_type="loss")
print("Model loss:", loss)

Model loss: tensor(4.3943, device='cuda:0')


In [8]:
logits = gpt2_small(model_description_text, return_type="logits")
# Highest-logit token at every position. The final position is dropped because
# there is no following token to compare its prediction against.
prediction = logits.argmax(dim=-1).squeeze()[:-1]
# Targets are the input tokens shifted by one (the first token is never a target)
true_tokens = gpt2_small.to_tokens(model_description_text).squeeze()[1:]
num_correct = (prediction == true_tokens).sum()

print(f"Model accuracy: {num_correct}/{len(true_tokens)}")
# Show only the predictions that matched their target token
print(f"Correct words: {gpt2_small.to_str_tokens(prediction[prediction == true_tokens])}")

Model accuracy: 32/112
Correct words: ['\n', '\n', 'former', ' with', ' models', '.', ' can', ' of', 'ooked', 'Trans', 'former', '_', 'NAME', '`.', ' model', ' the', 'Trans', 'former', ' to', ' be', ' and', '-', '.', '\n', ' at', 'PT', '-', ',', ' model', ',', "'s", ' the']


## Caching all Activations

With `logits, cache = model.run_with_cache(tokens)`, we can look at all of the internal activations of a model.

In [9]:
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = gpt2_small.to_tokens(gpt2_text)
# Run the model and cache its activations. remove_batch_dim=True is requested so
# cached activations can be indexed without a leading batch axis.
gpt2_logits, gpt2_cache = gpt2_small.run_with_cache(gpt2_tokens, remove_batch_dim=True)

In [11]:
# Look up the layer-0 attention pattern in the cache by its full hook name
attn_patterns_layer_0 = gpt2_cache["blocks.0.attn.hook_pattern"]

## Visualising Attention Heads

A key insight from the *A Mathematical Framework for Transformer Circuits* paper is that we should focus on interpreting the parts of the model that are intrinsically interpretable - the input tokens, the output logits and the attention patterns. Everything else (the residual stream, keys, queries, values, etc.) is a compressed intermediate state used when calculating meaningful things.

In [12]:
print(type(gpt2_cache))
# Tuple-style cache lookup for the layer-0 attention pattern
# (the printed shape below is [12, 33, 33])
attention_pattern = gpt2_cache["pattern", 0, "attn"]
print(attention_pattern.shape)
# String form of each token, used as labels in the visualisation
gpt2_str_tokens = gpt2_small.to_str_tokens(gpt2_text)

print("Layer 0 Head Attention Patterns:")
display(cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern))

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 33, 33])
Layer 0 Head Attention Patterns:


# Finding induction heads

In [13]:
# Configuration for the 2-layer, attention-only model (no normalisation layers)
# that the rest of the notebook analyses.
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal", # NOTE(review): this comment used to say the default is "bidirectional"; TransformerLens documents "causal" as the default - confirm
    attn_only=True, # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b", 
    seed=398,
    use_attn_result=True, # presumably needed for the per-head "result" hook used in solutions_get_ablation_scores - confirm
    normalization_type=None, # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer"
)

In [None]:
# Path to the saved state dict, relative to the notebook's working directory
WEIGHT_PATH = "attn_only_2L_half.pth"

model = HookedTransformer(cfg)
# Map the checkpoint onto the model's configured device rather than a hard-coded
# "cuda", so loading also works on machines without a GPU.
pretrained_weights = t.load(WEIGHT_PATH, map_location=model.cfg.device)
model.load_state_dict(pretrained_weights)

In [15]:
text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."
# Run the 2-layer model on the text and cache its activations; the batch
# dimension is dropped so the cache can be indexed per head below.
logits, cache = model.run_with_cache(text, remove_batch_dim=True)

In [16]:
str_tokens = model.to_str_tokens(text)
# Display one attention-pattern visualisation per layer
for layer in range(model.cfg.n_layers):
    attention_pattern = cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))

We notice that there are three basic patterns which repeat quite frequently:

* `prev_token_heads`, which attend mainly to the previous token (e.g. head `0.7`)
* `current_token_heads`, which attend mainly to the current token (e.g. head `1.6`)
* `first_token_heads`, which attend mainly to the first token (e.g. head `0.9`, although this is a bit less clear-cut than the other two)

The `prev_token_heads` and `current_token_heads` are perhaps unsurprising, because words that are close together in a sequence probably have a lot more mutual information (i.e. we could get quite far using bigram or trigram prediction). 

The `first_token_heads` are a bit more surprising. The basic intuition here is that the first token in a sequence is often used as a resting or null position for heads that only sometimes activate (since our attention probabilities always have to add up to 1).


In [17]:
def current_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be current-token heads

    A head counts as a current-token head when the mean of its attention pattern's
    main diagonal (each position attending to itself) is above 0.4.

    The layer and head counts are read from the global `model`, not from `cache`.
    '''
    return [
        f"{layer}.{head}"
        for layer in range(model.cfg.n_layers)
        for head in range(model.cfg.n_heads)
        if cache["pattern", layer][head].diagonal().mean() > 0.4
    ]

    

In [18]:
# Report the heads flagged by the detector on the natural-language `cache` from above
print("Heads attending to current token  = ", ", ".join(current_attn_detector(cache)))

Heads attending to current token  =  0.9


In [21]:
def generate_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> t.Tensor:
    '''
    Generates a sequence of repeated random tokens

    Each row is: the tokenizer's BOS token, then `seq_len` random token ids drawn
    uniformly from [0, model.cfg.d_vocab), then the same `seq_len` ids again.

    The result is placed on `model.cfg.device` (previously a hard-coded `.cuda()`,
    which fails on machines without a GPU).

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    # Build the BOS column directly as int64 rather than via float ones * id
    prefix = t.full((batch, 1), model.tokenizer.bos_token_id, dtype=t.int64)
    # Sampled on CPU (as before) so the random stream does not depend on the device
    rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)
    rep_tokens = t.cat([prefix, rep_tokens_half, rep_tokens_half], dim=-1).to(model.cfg.device)
    return rep_tokens

def run_and_cache_model_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> Tuple[t.Tensor, t.Tensor, ActivationCache]:
    '''
    Builds a batch of repeated random token sequences with `generate_repeated_tokens`
    and runs the model on it with activation caching.

    Returns a tuple, in this order:
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    '''
    repeated = generate_repeated_tokens(model, seq_len, batch)
    model_logits, activations = model.run_with_cache(repeated)
    return repeated, model_logits, activations

In [22]:
def per_token_losses(logits, tokens):
    '''
    Negative log-probability the model assigns to each actual next token.

    Position i of `logits` is scored against token i+1, so the result has one
    entry fewer than the sequence length. Only the first batch element is returned.
    '''
    log_probs = F.log_softmax(logits, dim=-1)
    next_tokens = tokens[:, 1:].unsqueeze(-1)
    picked = log_probs[:, :-1].gather(dim=-1, index=next_tokens).squeeze(-1)
    return -picked[0]

In [23]:
seq_len = 50
batch = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(model, seq_len, batch)
# Drop the batch dimension from the cached activations (batch is 1 here), so
# later cells can index patterns as cache["pattern", layer][head]
rep_cache.remove_batch_dim()
rep_str = model.to_str_tokens(rep_tokens)
model.reset_hooks()
# ptl[i] is the loss for predicting token i+1, so it has 2*seq_len entries:
# the first seq_len cover the first copy of the random tokens, the rest the repeat
ptl = per_token_losses(rep_logits, rep_tokens)
print(f"Performance on the first half: {ptl[:seq_len].mean():.3f}")
print(f"Performance on the second half: {ptl[seq_len:].mean():.3f}")
# hover_name skips the first token because no loss is computed for predicting it
fig = px.line(
    utils.to_numpy(ptl), hover_name=rep_str[1:],
    title=f"Per token loss on sequence of length {seq_len} repeated twice",
    labels={"index": "Sequence position", "value": "Loss"}
).update_layout(showlegend=False, hovermode="x unified")
# Shade the first half red and the repeated half green
fig.add_vrect(x0=0, x1=seq_len-.5, fillcolor="red", opacity=0.2, line_width=0)
fig.add_vrect(x0=seq_len-.5, x1=2*seq_len-1, fillcolor="green", opacity=0.2, line_width=0)

enable_plotly_in_cell()
fig.show()

Performance on the first half: 14.240
Performance on the second half: 3.733


In [30]:
# Sanity check: expect 1 + 2*seq_len string tokens (101 for seq_len = 50)
len(rep_str)

101

In [31]:
# Inspect the string tokens: the BOS token followed by the random sequence and its repeat
rep_str

['<|endoftext|>',
 'rias',
 ' [(',
 ' Cyt',
 ' decreasing',
 'INC',
 ' neat',
 ' post',
 ' 99',
 'cancer',
 'church',
 ' hat',
 ' patient',
 'rel',
 ' Downtown',
 ' associ',
 'Marker',
 'incoln',
 ' cyl',
 ' Adding',
 'ometric',
 ' confusing',
 ' Moment',
 ' val',
 ' dirty',
 'igrant',
 ' Was',
 'ITC',
 'Hope',
 'erver',
 ' queries',
 'account',
 ' sinister',
 ' files',
 ' courtesy',
 'ulas',
 'prot',
 ' rocky',
 'bounds',
 ' aspect',
 ' indoors',
 ' turkey',
 'igion',
 ' :',
 'filed',
 ' prepared',
 'umbent',
 'ails',
 'oste',
 ' competence',
 'itating',
 'rias',
 ' [(',
 ' Cyt',
 ' decreasing',
 'INC',
 ' neat',
 ' post',
 ' 99',
 'cancer',
 'church',
 ' hat',
 ' patient',
 'rel',
 ' Downtown',
 ' associ',
 'Marker',
 'incoln',
 ' cyl',
 ' Adding',
 'ometric',
 ' confusing',
 ' Moment',
 ' val',
 ' dirty',
 'igrant',
 ' Was',
 'ITC',
 'Hope',
 'erver',
 ' queries',
 'account',
 ' sinister',
 ' files',
 ' courtesy',
 'ulas',
 'prot',
 ' rocky',
 'bounds',
 ' aspect',
 ' indoors',
 ' t

In [25]:
# Display one attention-pattern visualisation per layer for the repeated-token run
for layer in range(model.cfg.n_layers):
    attention_pattern = rep_cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=rep_str, attention=attention_pattern))

In [24]:
def induction_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be induction heads

    Remember - the tokens used to generate rep_cache are (bos_token, *rand_tokens, *rand_tokens)

    With that layout, a query at the second occurrence of a token attending to the
    token just after its first occurrence lies seq_len-1 places below the main
    diagonal. A head is flagged when the mean attention along that diagonal is above 0.4.

    The layer and head counts are read from the global `model`, not from `cache`.
    '''
    detected = []
    for layer in range(model.cfg.n_layers):
        layer_patterns = cache["pattern", layer]
        for head in range(model.cfg.n_heads):
            head_pattern = layer_patterns[head]
            # Length of one copy of the random tokens (total length is 1 + 2*half_len)
            half_len = (head_pattern.shape[-1] - 1) // 2
            induction_stripe = head_pattern.diagonal(-half_len + 1)
            if induction_stripe.mean() > 0.4:
                detected.append(f"{layer}.{head}")
    return detected

In [26]:
# Report the heads flagged by the detector on the repeated-token cache
print("Induction heads = ", ", ".join(induction_attn_detector(rep_cache)))

Induction heads =  1.4, 1.10
