## Test of TransformerLens for decoder models

We import all necessary packages

In [None]:
import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial


Plotter helper functions

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a heatmap on the "RdBu" colour scale with its midpoint at 0.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` before plotting.
        renderer: Passed straight to the figure's `show()` call.
        xaxis: Label used for the "x" entry of the plot labels.
        yaxis: Label used for the "y" entry of the plot labels.
        **kwargs: Any further keyword arguments are forwarded to `px.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a line plot.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` before plotting.
        renderer: Passed straight to the figure's `show()` call.
        xaxis: Label used for the "x" entry of the plot labels.
        yaxis: Label used for the "y" entry of the plot labels.
        **kwargs: Any further keyword arguments are forwarded to `px.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy`.
        y: Vertical coordinates; converted with `utils.to_numpy`.
        xaxis: Label used for the "x" entry of the plot labels.
        yaxis: Label used for the "y" entry of the plot labels.
        caxis: Label used for the "color" entry of the plot labels.
        renderer: Passed straight to the figure's `show()` call.
        **kwargs: Any further keyword arguments are forwarded to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    plot_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=plot_labels, **kwargs)
    figure.show(renderer)

We then load the GPT2-small model and try it out

In [9]:
# Ask TransformerLens which torch device to use, then load GPT-2 small onto it.
device = utils.get_device()
model = HookedTransformer.from_pretrained(
    model_name="gpt2-small",
    device=device,
)

# Smoke test: feed a raw string and request the loss instead of the logits.
test_text = "The capital of Norway is Oslo."
loss = model(test_text, return_type="loss")
print("model loss:", loss)

Loaded pretrained model gpt2-small into HookedTransformer
model loss: tensor(4.4879, grad_fn=<DivBackward0>)


We now try storing and visualizing the activation patterns

In [10]:
# Tokenize the prompt once: ids for the model, string tokens for labelling the plot.
model_tokens = model.to_tokens(test_text)
gpt2_str_tokens = model.to_str_tokens(test_text)

# Forward pass that also records the intermediate activations.
# remove_batch_dim=True is requested so cached tensors come back without the batch axis.
model_logits, model_cache = model.run_with_cache(
    model_tokens,
    remove_batch_dim=True,
)
print(type(model_cache))

# Attention pattern of layer 0 (one square map per head).
attention_pattern = model_cache["pattern", 0, "attn"]
print(attention_pattern.shape)

print("Layer 0 Head Attention Patterns")
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 8, 8])
Layer 0 Head Attention Patterns


We try ablating a single attention head (head index 8 in layer 0, as set in the cell below) by zeroing its value vectors, using the hook features of the transformer_lens package

In [None]:
# Target of the ablation: which layer's value activations to hook, and which head to zero.
layer_to_ablate = 0
head_index_to_ablate = 8


# The type annotations are NOT necessary, they're just a useful guide to the reader.
def head_ablation_hook(value: torch.Tensor, hook) -> torch.Tensor:
    """Forward hook that zeroes the value vectors of head `head_index_to_ablate`.

    Args:
        value: Activation handed to the hook. It is indexed as a 4-D tensor with the
            head index on the third axis (the printed shape in the recorded output is
            [1, 8, 12, 64]).
        hook: Hook point object supplied by the caller; unused here.

    Returns:
        The same tensor, modified in place, so the zeroed values are used downstream.
    """
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.0
    return value


# Name of the activation the hook is attached to: the "v" activation of the chosen layer.
ablation_hook_name = utils.get_act_name("v", layer_to_ablate)

# Loss without and with the ablation hook, on the same tokens.
original_loss = model(model_tokens, return_type="loss")
ablated_loss = model.run_with_hooks(
    model_tokens,
    return_type="loss",
    fwd_hooks=[(ablation_hook_name, head_ablation_hook)],
)

print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")


Shape of the value tensor: torch.Size([1, 8, 12, 64])
Original Loss: 4.488
Ablated Loss: 5.268


Then we visualize the new patterns to check that ablating the head worked

In [16]:
# Visualize the attention patterns of the ABLATED forward pass.
#
# A plain `run_with_cache` call runs the model without any hook attached and would just
# reproduce the un-ablated activations. The ablation hook is therefore attached for the
# duration of the cached run and removed again afterwards.
ablation_hook_name = utils.get_act_name("v", layer_to_ablate)
model.add_hook(ablation_hook_name, head_ablation_hook)
try:
    ablated_logits, ablated_cache = model.run_with_cache(model_tokens, remove_batch_dim=True)
finally:
    # Always detach the hook so later cells run the unmodified model again.
    model.reset_hooks()
print(type(ablated_cache))

# Direct check that the ablation took effect: a head's output `z` is the attention-weighted
# sum of its value vectors, so zeroed values must give an all-zero `z` for that head.
ablated_head_z = ablated_cache["z", layer_to_ablate][..., head_index_to_ablate, :]
assert torch.all(ablated_head_z == 0), "ablated head still produces a non-zero output"

# NOTE: the attention pattern of the ablated layer is computed from queries and keys only,
# so zeroing the values is not expected to change this picture; the assertion above is the
# actual check. Separate names keep the clean `model_cache` / `attention_pattern` intact.
ablated_attention_pattern = ablated_cache["pattern", layer_to_ablate, "attn"]
print(ablated_attention_pattern.shape)
gpt2_str_tokens = model.to_str_tokens(test_text)

cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=ablated_attention_pattern)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 8, 8])
