In [2]:
import einops
import numpy as np
import plotly.express as px
import torch
import transformer_lens
from scipy import fft

import fourier_analysis
from src import const
from src import loss
from src import plot
from src import task
from src import utils

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the saved model from const.MODEL_SAVE_PATH through the project helper,
# passing the modulus const.MOD and keeping everything on the CPU.
hooked_model: transformer_lens.HookedTransformer = utils.load_model(
    path=const.MODEL_SAVE_PATH,
    mod=const.MOD,
    device="cpu",
)

In [4]:
# Make a dataset and labels for the model
# Build the full dataset and its labels on the CPU, then report its dimensions:
# first as (number of examples, tokens per example), then as the a x b grid it covers.
dataset, labels = task.make_dataset_and_labels(p=const.MOD, device="cpu")

num_examples = dataset.shape[0]
tokens_per_example = dataset.shape[1]
print(f"Dataset shape: {num_examples:,},\t\t{tokens_per_example}")
print(f"Dataset shape: {const.MOD} x {const.MOD},\t[a, b, =]")

Dataset shape: 12,769,		3
Dataset shape: 113 x 113,	[a, b, =]


In [5]:
# Get a prediction for every possible input and cache the activations
# Run the model once over the whole dataset. run_with_cache returns the model
# output together with a cache of intermediate activations keyed by hook name.
original_logits, cache = hooked_model.run_with_cache(dataset)


In [6]:
# List every cached activation with its shape, one hook per line.
for hook_name, activation in cache.items():
    print(hook_name, activation.shape)

hook_embed torch.Size([12769, 3, 128])
hook_pos_embed torch.Size([12769, 3, 128])
blocks.0.hook_resid_pre torch.Size([12769, 3, 128])
blocks.0.attn.hook_q torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_k torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_v torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_attn_scores torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_pattern torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_z torch.Size([12769, 3, 4, 32])
blocks.0.hook_attn_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_mid torch.Size([12769, 3, 128])
blocks.0.mlp.hook_pre torch.Size([12769, 3, 512])
blocks.0.mlp.hook_post torch.Size([12769, 3, 512])
blocks.0.hook_mlp_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_post torch.Size([12769, 3, 128])


In [7]:
# Embedding matrix with its last row dropped; the printed shape is (113, 128).
# NOTE(review): presumably the dropped row is the "=" token, leaving one row
# per number 0..MOD-1 — confirm against the tokenisation in `task`.
W_E = hooked_model.embed.W_E[:-1]
print("W_E", W_E.shape)

# W_neur = W_E @ W_V @ W_O @ W_in
# The linear transformation from the input to the ReLU.
# The printed shape is (4, 113, 512): the leading axis of size 4 comes from the
# per-head W_V / W_O matrices (the cached q/k/v activations also show 4 heads).
W_neur = (
    W_E
    @ hooked_model.blocks[0].attn.W_V
    @ hooked_model.blocks[0].attn.W_O
    @ hooked_model.blocks[0].mlp.W_in
)
print("W_neur", W_neur.shape)

# W_logit = W_out @ W_U
# The linear transformation from the ReLU to the logits
W_logit = hooked_model.blocks[0].mlp.W_out @ hooked_model.unembed.W_U
print("W_logit", W_logit.shape)

W_E torch.Size([113, 128])
W_neur torch.Size([4, 113, 512])
W_logit torch.Size([512, 113])


## Basic Analysis

### Plotting Attention Patterns

In [8]:
# Layer-0 MLP activations at the final sequence position, for every input.
# NOTE(review): assumes cache["post", 0, "mlp"] resolves to
# blocks.0.mlp.hook_post (shape [12769, 3, 512] in the listing above) — confirm.
neuron_activations = cache["post", 0, "mlp"][:, -1, :]

In [9]:
# Average the layer-0 attention pattern over all inputs, then keep only the
# row for the final query position: one row per head, one column per source token.
mean_final_position_attention = cache["pattern", 0].mean(dim=0)[:, -1, :]
head_labels = [f"Head {i}" for i in range(4)]
source_labels = ["a", "b", "="]

plot.imshow(
    mean_final_position_attention,
    title="Average Attention Pattern per Head",
    xaxis="Source",
    yaxis="Head",
    x=source_labels,
    y=head_labels,
    width=800,
    height=800,
)

In [10]:
# Attention pattern of every head for ONE specific input (rather than the
# average over the dataset), read at the final query position.
# EXAMPLE_INDEX picks the row of `dataset` to inspect; its first two tokens are
# the operands a and b (the dataset columns are [a, b, =]).
EXAMPLE_INDEX = 5
example_a, example_b = dataset[EXAMPLE_INDEX, :2].tolist()

plot.imshow(
    cache["pattern", 0][EXAMPLE_INDEX][:, -1, :],
    title=(
        f"Attention Pattern per Head for Input {EXAMPLE_INDEX} "
        f"(a={example_a}, b={example_b})"
    ),
    xaxis="Source",
    yaxis="Head",
    x=["a", "b", "="],
    y=[f"Head {i}" for i in range(4)],
    width=800,
    height=800,
)

In [11]:
# Attention paid by the final position to source position 0 (token a), for every
# input and head, unflattened from the combined (a b) batch axis into one
# a-by-b grid per head.
attention_final_to_first = cache["pattern", 0][:, :, -1, 0]
attention_grid_per_head = einops.rearrange(
    attention_final_to_first,
    "(a b) head -> head a b",
    a=const.MOD,
    b=const.MOD,
)

plot.imshow(
    attention_grid_per_head,
    title="Attention for each head from a -> =",
    xaxis="b",
    yaxis="a",
    facet_col=0,
)

In [19]:
# Activations of the first five neurons, unflattened from the combined (a b)
# batch axis into one a-by-b grid per neuron and drawn as separate facets.
first_neuron_grids = einops.rearrange(
    neuron_activations[:, :5],
    "(a b) neuron -> neuron a b",
    a=const.MOD,
    b=const.MOD,
)
neuron_figure = plot.imshow(
    first_neuron_grids,
    title="First 5 neuron activations",
    xaxis="b",
    yaxis="a",
    facet_col=0,
)

# Each facet annotation ends in "=<index>"; relabel it as a 1-based neuron number.
neuron_figure.for_each_annotation(
    lambda annotation: annotation.update(
        text=f'Neuron {int(annotation.text.split("=")[-1]) + 1}'
    )
)

### Singular Value Decomposition

In [20]:
    # Singular value decomposition of the embedding matrix W_E. U and S are
    # kept for the plots that follow; the third return value is discarded.
    U, S, _ = torch.linalg.svd(W_E)

In [21]:
# Line plot of the singular values of W_E, in the order returned by the SVD.
# x and y are passed explicitly so the keys in `labels` match the plotted
# columns. Passing the bare array as the data frame puts Plotly Express in
# wide-form mode, where the columns are named "index"/"value" and an "x" label
# is silently ignored.
singular_values = S.detach().numpy()
px.line(
    x=np.arange(singular_values.size),
    y=singular_values,
    title="Singular Values",
    labels={"x": "Singular Value Index", "y": "Singular Value"},
    template=plot.TEMPLATE,
).update_layout(showlegend=False)

In [34]:
# The singular values plot shows the first 8 singular directions of the
# embedding are the most important. The next two are slightly important, the
# rest are basically unused, so the first 10 columns of U are shown here.
# U has one row per input token and one column per singular vector, so the
# vocabulary runs along the y axis and the singular-vector index along x.
plot.imshow(
    U[:, :10],
    title="First 10 Left Singular Vectors of W_E",
    aspect="auto",
    xaxis="Singular Vector",
    yaxis="Input Vocabulary",
)

## Frequency Analysis

In [43]:
# Sum the first 8 left singular vectors into a single signal over the input
# vocabulary and take its discrete Fourier transform.
principal_components = U[:, :8].sum(dim=1).detach().numpy()
yf = fft.fft(principal_components)

# Plot the scaled magnitude of the first MOD // 2 frequency bins.
num_frequencies = const.MOD // 2
amplitudes = 2.0 / const.MOD * np.abs(yf[0:num_frequencies])
px.line(
    x=np.arange(num_frequencies),
    y=amplitudes,
    labels={"x": "Frequency", "y": "Amplitude"},
    template=plot.TEMPLATE,
)

In [46]:
# Build the Fourier basis for inputs modulo const.MOD (on the CPU) along with a
# name for each basis component, then show the basis as a heatmap with one row
# per named component. `fourier_analysis` is imported with the other modules in
# the import cell at the top of the notebook.
fourier_basis, fourier_basis_names = fourier_analysis.make_fourier_basis(
    base=const.MOD, device="cpu"
)
plot.imshow(
    fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names
)

Multiplying the Fourier basis by the embedding matrix `W_E`, we can see how the embedding is composed of sin and cos functions.
The rows that stand out correspond to the key frequencies from the previous plots.
What this adds is that the embedding uses both the sin and the cos component of each of the 4 key frequencies.

In [47]:
# Project the embedding matrix onto the Fourier basis: one row per Fourier
# component, one column per residual-stream dimension.
embedding_in_fourier_basis = fourier_basis @ W_E

plot.imshow(
    embedding_in_fourier_basis,
    title="Embedding in Fourier Basis",
    xaxis="Residual Stream",
    yaxis="Fourier Component",
    y=fourier_basis_names,
)