In [1]:
import os
DEV_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2
        
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

# change renderer to colab if needed
import plotly.io as pio
if IN_COLAB or not DEV_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
    
print(f"Using renderer: {pio.renderers.default}")

# import circuit vis
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

import warnings
warnings.filterwarnings("ignore")

# Main imports
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

# transformer lens stuff
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

# set grad to false cuz we dont need to train
torch.set_grad_enabled(False)

from plot_utils import *


Using renderer: notebook_connected


# Factored matrices
- Some matrix AB can be very large. We can factor it into A = [Large, small] and B = [small, Large] to save space

- this allows us to do all basic matrix calcs very efficiently

In [2]:
torch.manual_seed(50)

A = torch.randn(5, 2)
B = torch.randn(2, 5)

AB = A @ B
AB_factored = FactoredMatrix(A, B)

print("Norms:")
print("AB:", AB.norm())
print("AB_factored:", AB_factored.norm())

print(f"Right dim: {AB_factored.rdim}, left dim: {AB_factored.ldim}, hidden dim: {AB_factored.mdim}")



Norms:
AB: tensor(9.9105)
AB_factored: tensor(9.9105)
Right dim: 5, left dim: 5, hidden dim: 2


# eigenvalues and singular values
We can also look at the eigenvalues and singular values of the matrix. Note that, because the matrix is rank 2 but 5 by 5, the final 3 eigenvalues and singular values are zero - the factored class omits the zeros.


In [4]:
print("Eigenvalues:")
print(torch.linalg.eig(AB).eigenvalues)
print(AB_factored.eigenvalues)
print()
print("Singular Values:")
print(torch.linalg.svd(AB).S)
print(AB_factored.S)


Eigenvalues:
tensor([-6.2877e+00+0.j,  1.8626e-09+0.j,  2.3121e+00+0.j, -8.9038e-08+0.j,
        -4.1527e-07+0.j])
tensor([-6.2877+0.j,  2.3121+0.j])

Singular Values:
tensor([8.3126e+00, 5.3963e+00, 2.2029e-07, 5.7690e-08, 1.2164e-08])
tensor([8.3126, 5.3963])


# Multiplication

In [8]:
C = torch.randn(5, 300)

ABC = AB @ C
ABC_factor = AB_factored @ C
print("Unfactored:", ABC.shape, ABC.norm().round(decimals=3))
print("Factored:", ABC_factor.shape, ABC_factor.norm().round(decimals=3))
print(f"Right dimension: {ABC_factor.rdim}, Left dimension: {ABC_factor.ldim}, Hidden dimension: {ABC_factor.mdim}")


Unfactored: torch.Size([5, 300]) tensor(171.4090)
Factored: torch.Size([5, 300]) tensor(171.4090)
Right dimension: 300, Left dimension: 5, Hidden dimension: 2


In [9]:
# get original AB
AB_unfactored = AB_factored.AB
print(torch.isclose(AB_unfactored, AB).all())



tensor(True)


# More advanced example:
- toymodels paper

In [14]:
# set device
device = utils.get_device()
print(f"Using device: {device}")


device="cpu"


Using device: mps


In [15]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)


Loaded pretrained model gpt2-small into HookedTransformer


In [16]:
OV_circuit_all_heads = model.OV
print(OV_circuit_all_heads)

# eigenvalues 
OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues 
print(OV_circuit_all_heads_eigenvalues.shape)
print(OV_circuit_all_heads_eigenvalues.dtype)


FactoredMatrix: Shape(torch.Size([12, 12, 768, 768])), Hidden Dim(64)
torch.Size([12, 12, 64])
torch.complex64


In [17]:
OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)
imshow(utils.to_numpy(OV_copying_score), xaxis="Head", yaxis="Layer", title="OV Copying Score for each head in GPT-2 Small", zmax=1.0, zmin=-1.0)


In [18]:
scatter(x=OV_circuit_all_heads_eigenvalues[-1, -1, :].real, y=OV_circuit_all_heads_eigenvalues[-1, -1, :].imag, title="Eigenvalues of Head L11H11 of GPT-2 Small", xaxis="Real", yaxis="Imaginary")


# full ov circuit

In [19]:
full_ov_circuit = model.embed.W_E @ OV_circuit_all_heads @ model.unembed.W_U
print(full_ov_circuit)
full_ov_circuit_eigenvalues = full_ov_circuit.eigenvalues
print(full_ov_circuit_eigenvalues.shape)
print(full_ov_circuit_eigenvalues.dtype)

# display copying score
full_ov_circuit_copying_score = full_ov_circuit_eigenvalues.sum(dim=-1).real / full_ov_circuit_eigenvalues.abs().sum(dim=-1)
imshow(utils.to_numpy(full_ov_circuit_copying_score), xaxis="Head", yaxis="Layer", title="Full OV Copying Score for each head in GPT-2 Small", zmax=1.0, zmin=-1.0)


FactoredMatrix: Shape(torch.Size([12, 12, 50257, 50257])), Hidden Dim(64)
torch.Size([12, 12, 64])
torch.complex64


In [21]:
scatter(x=full_ov_circuit_copying_score.flatten(), y=OV_copying_score.flatten(), hover_name=[f"L{layer}H{head}" for layer in range(12) for head in range(12)], title="OV Copying Score for each head in GPT-2 Small", xaxis="Full OV Copying Score", yaxis="OV Copying Score")
