## Test of TransformerLens for decoder models

We import all necessary packages

In [None]:
import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial


Plotter helper functions

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a Plotly Express image plot.

    The data is converted with ``utils.to_numpy`` and drawn with the "RdBu"
    colour scale, with the colour midpoint pinned at 0.0.

    Args:
        tensor: Data to plot; anything ``utils.to_numpy`` accepts.
        renderer: Passed straight to the figure's ``show`` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to ``px.imshow``.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a Plotly Express line plot.

    Args:
        tensor: Data to plot; anything ``utils.to_numpy`` accepts.
        renderer: Passed straight to the figure's ``show`` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to ``px.line``.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly Express scatter plot of ``y`` against ``x``.

    Args:
        x: Horizontal coordinates; anything ``utils.to_numpy`` accepts.
        y: Vertical coordinates; anything ``utils.to_numpy`` accepts.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed straight to the figure's ``show`` call.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

Some setup:

In [45]:
# Switch autograd off globally for the cells that follow.
torch.set_grad_enabled(mode=False)
print("Disabled automatic differentiation")

Disabled automatic differentiation


# Some helper functions

In [91]:
def vis_attn_patterns(model, text, layers):
    """Display the attention patterns of the chosen layers for ``text``.

    Runs ``model`` once on ``text`` with activation caching and, for each
    requested layer, displays a circuitsvis attention-pattern widget.

    Args:
        model: Object providing ``to_str_tokens`` and ``run_with_cache``
            (a ``HookedTransformer`` in this notebook).
        text: Prompt to run through the model.
        layers: A single layer index, or an iterable of layer indices.
    """
    # A bare int is accepted as shorthand for a single layer.
    if isinstance(layers, int):
        layers = [layers]

    str_tokens = model.to_str_tokens(text)
    # Only the activation cache is needed here; the logits are discarded.
    _, cache = model.run_with_cache(text, remove_batch_dim=True)

    for layer in layers:
        attention_pattern = cache["pattern", layer]
        # `display` is the global provided by the IPython/Jupyter kernel.
        display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))

We then load the GPT2-small model and try it out:

In [100]:
# Pick the device reported by the TransformerLens helper and load "gpt2-small" onto it.
# `model` and `test_text` are reused by later cells.
device = utils.get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Sanity check: ask the model for its loss on one short factual sentence.
test_text = "The capital of Norway is Oslo."
loss = model(test_text, return_type="loss")
print("model loss:", loss)

Loaded pretrained model gpt2-small into HookedTransformer
model loss: tensor(4.4879)


We now try caching the model's activations and visualizing the attention patterns

In [101]:
# Tokenize the test sentence and run it through the model, collecting the activation cache.
# `model_tokens` is reused by the ablation cell further down.
model_tokens = model.to_tokens(test_text)
model_logits, model_cache = model.run_with_cache(model_tokens, remove_batch_dim=True)

print(type(model_cache))
# Attention pattern of layer 0. The recorded output shows shape (12, 8, 8),
# presumably (head, query position, key position) -- TODO confirm axis order.
attention_pattern = model_cache["pattern", 0, "attn"]
print(attention_pattern.shape)
gpt2_str_tokens = model.to_str_tokens(test_text)

print("Layer 0 Head Attention Patterns")
# Last expression of the cell, so the widget is rendered as the cell's output.
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 8, 8])
Layer 0 Head Attention Patterns


We try to ablate a single attention head (head 8 in layer 0) by using the hook features of the transformer_lens package

In [102]:
# Which attention head to knock out: head index 8 in layer 0.
# Both names are read by the ablation hook and the hooked run below.
layer_to_ablate = 0
head_index_to_ablate = 8

# Head-ablation hook: given the activation at the hooked point, it silences
# one attention head and hands the tensor back.
def head_ablation_hook(value, hook):
    """Zero out head ``head_index_to_ablate`` of ``value`` in place.

    Args:
        value: Activation at the hook point. It is indexed as a 4-D tensor
            with the head on axis 2; the run recorded in this notebook
            prints shape (1, 8, 12, 64).
        hook: Second argument supplied by the hooking machinery; unused.

    Returns:
        The same ``value`` tensor, with the selected head set to zero.
    """
    print(f"Shape of the value tensor: {value.shape}")
    # Same index as value[:, :, head_index_to_ablate, :], spelled out as a tuple.
    head_slice = (slice(None), slice(None), head_index_to_ablate, slice(None))
    value[head_slice] = 0.0
    return value

# Baseline: loss of the unmodified model on the test sentence.
original_loss = model(model_tokens, return_type="loss")
# Same input, but with head_ablation_hook attached as a forward hook on the
# "v" activation of the chosen layer.
ablated_loss = model.run_with_hooks(
    model_tokens, 
    return_type="loss", 
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate), 
        head_ablation_hook
        )]
    )
# Compare the two losses side by side, rounded to three decimals.
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")


Shape of the value tensor: torch.Size([1, 8, 12, 64])
Original Loss: 4.488
Ablated Loss: 5.268


# Predicting the next word in a factual statement


In [103]:
text_norway = "The capital of Norway is called"
text_physics = "The unit of temperature is called"
correct_completion = " Oslo"
correct_completion_physics = " Kelvin"

# Tokenize once and feed the tokens to the model, so the prompt is not
# tokenized a second time inside the forward call.
token_norway = model.to_tokens(text_norway)
prob_tensor = model(token_norway).softmax(dim=-1)

print(prob_tensor.shape)

# Probabilities over the vocabulary for the token that follows the prompt.
next_token_probs = prob_tensor[0, -1, :]

# Tensor.max()/argmax() are vectorised; the builtin max() would iterate over
# all vocabulary entries one Python-level comparison at a time.
prediction = next_token_probs.max().item()
print(prediction)  # probability of the most likely next token

index = next_token_probs.argmax().item()
print(model.to_string(index))  # text of the most likely next token

print(next_token_probs[model.to_single_token(correct_completion)].item())  # probability of the correct completion

# Ranked top-token summaries for both prompts.
utils.test_prompt(text_norway, correct_completion, model)
utils.test_prompt(text_physics, correct_completion_physics, model)


torch.Size([1, 7, 50257])
0.06769462674856186
 "
0.024034103378653526
Tokenized prompt: ['<|endoftext|>', 'The', ' capital', ' of', ' Norway', ' is', ' called']
Tokenized answer: [' Oslo']


Top 0th token. Logit: 13.70 Prob:  6.77% Token: | "|
Top 1th token. Logit: 13.48 Prob:  5.40% Token: | the|
Top 2th token. Logit: 12.81 Prob:  2.77% Token: | '|
Top 3th token. Logit: 12.67 Prob:  2.40% Token: | Oslo|
Top 4th token. Logit: 12.49 Prob:  2.02% Token: | V|
Top 5th token. Logit: 12.49 Prob:  2.02% Token: | N|
Top 6th token. Logit: 12.27 Prob:  1.61% Token: | H|
Top 7th token. Logit: 12.26 Prob:  1.60% Token: | St|
Top 8th token. Logit: 12.25 Prob:  1.59% Token: | The|
Top 9th token. Logit: 12.23 Prob:  1.55% Token: | J|


Tokenized prompt: ['<|endoftext|>', 'The', ' unit', ' of', ' temperature', ' is', ' called']
Tokenized answer: [' Kelvin']


Top 0th token. Logit: 14.86 Prob: 45.28% Token: | the|
Top 1th token. Logit: 13.63 Prob: 13.33% Token: | a|
Top 2th token. Logit: 12.44 Prob:  4.03% Token: | "|
Top 3th token. Logit: 12.20 Prob:  3.17% Token: | an|
Top 4th token. Logit: 11.00 Prob:  0.96% Token: | '|
Top 5th token. Logit: 10.67 Prob:  0.69% Token: | temperature|
Top 6th token. Logit: 10.42 Prob:  0.53% Token: | its|
Top 7th token. Logit: 10.08 Prob:  0.38% Token: | therm|
Top 8th token. Logit: 10.01 Prob:  0.36% Token: | ther|
Top 9th token. Logit:  9.88 Prob:  0.31% Token: | heat|


The model does not seem able to give the correct answer to either prompt. Using factual statements involving capitals no longer seems like a good idea.

Update: After switching to GPT2-medium, the answers now seem sensible. The assumption is then that the distributions will be even better when switching to an even larger model.

# Test with FCI

We now try visualizing attention patterns for question 1 from FCI:

In [105]:
# Prompt for FCI question 1, ending mid-sentence so the model has to continue it.
text_fci1 = "Two metal balls are the same size but one weighs twice as much as the other. The balls are dropped from the roof of a single story building at the same instant of time. The time it takes the balls to reach the ground below will be"
# NOTE(review): `tokens`, `logits` and `cache` are not read again in the visible
# part of the notebook; vis_attn_patterns runs its own forward pass on the text.
tokens = model.to_tokens(text_fci1)
logits, cache = model.run_with_cache(text_fci1, remove_batch_dim=True)

# Show the attention patterns of layers 0 and 10 for the prompt.
vis_attn_patterns(model, text_fci1, [0, 10])

We now try looking for induction heads with a simple example:

In [106]:
# A simple example to visualize attention patterns in a Transformer model:
# the same passage is fed twice in a row, built from a single literal so that
# the two copies are guaranteed to be identical.
fci1_passage = "Two metal balls are the same size but one weighs twice as much as the other. The balls are dropped from the roof of a single story building at the same instant of time."
repeated_fci1_text = " ".join([fci1_passage, fci1_passage])
vis_attn_patterns(model, repeated_fci1_text, [0, 10])