# Toy Models of Superposition

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SamAdamDay/mechanistic-interpretability-projects/blob/main/playground/toy-models.ipynb)


Playing around with the paper [Toy Models of Superposition](https://transformer-circuits.pub/2022/toy_model/index.html) (Elhage et al., 2022).

In [105]:
# Notebook-wide switches.
FORCE_CPU: bool = True  # run on CPU even when CUDA is available
TEST_MODEL: str = "gpt2-small"  # pretrained model name passed to HookedTransformer.from_pretrained
DO_WORD_EMBEDDINGS: bool = False  # guards the model-loading and word-embedding cells
LOSS_PLOT_SMOOTHING: int = 500  # Gaussian sigma (in training steps) for the loss plot

## Setup

In [2]:
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

import plotly.io as pio
pio.renderers.default = "colab+vscode"

Running as a Jupyter notebook
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [97]:
import re
from typing import Callable
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, IterableDataset

import numpy as np

from scipy.ndimage import gaussian_filter1d

from fancy_einsum import einsum

from tqdm import tqdm

import plotly.express as px

import matplotlib.pyplot as plt

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)

In [4]:
# Autograd is switched off globally for this notebook; any cell that trains
# has to re-enable it explicitly (e.g. inside `with torch.enable_grad():`).
torch.set_grad_enabled(False)

# Prefer CUDA when present, unless FORCE_CPU pins everything to the CPU.
device = "cuda" if (torch.cuda.is_available() and not FORCE_CPU) else "cpu"
print(device)

cpu



CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)



## Load the test model

In [5]:
if DO_WORD_EMBEDDINGS:
    # NOTE(review): center_writing_weights=False presumably keeps W_E exactly
    # as stored in the checkpoint, so embedding norms/cosine similarities below
    # are not affected by TransformerLens' weight centering -- confirm.
    test_model = HookedTransformer.from_pretrained(
        TEST_MODEL,
        center_writing_weights=False,
    ).to(device)

## Word embedding arithmetic

In [6]:
if DO_WORD_EMBEDDINGS:

    def get_token_embedding(token: str, model = test_model) -> torch.Tensor:
        """Return the row of ``model.W_E`` that embeds ``token``.

        The string is converted to a token id with ``model.to_single_token``
        (presumably raising if it does not map to exactly one token -- TODO
        confirm against TransformerLens). NOTE(review): the default ``model``
        is bound to ``test_model`` when this cell runs, not at call time.
        """
        token = model.to_single_token(token)
        # Indexing with a single id already yields a 1-D row; squeeze() is
        # presumably a no-op kept for safety.
        return model.W_E[token, :].squeeze()

    def compute_embedding_arithmetic(expression: str, model = test_model) -> torch.Tensor:
        """Evaluate an arithmetic expression whose operands are words.

        Every alphanumeric run in ``expression`` is replaced by a call to
        ``get_token_embedding`` and the resulting Python source is evaluated,
        e.g. ``"boy - girl"`` becomes
        ``"get_token_embedding('boy', model) - get_token_embedding('girl', model)"``.

        SECURITY: uses ``eval``. Only pass trusted, hard-coded expressions,
        never external input.
        """
        expression_eval = re.sub(r"[a-zA-Z0-9]+", lambda x: f"get_token_embedding('{x.group(0)}', model)", expression)
        # eval resolves `get_token_embedding` (global) and `model` (this
        # function's local) from the current scope, so neither may be renamed.
        vector = eval(expression_eval)
        return vector

    def print_embedding_arithmetic_norm(expression: str, model = test_model) -> None:
        """Print the L2 norm of the vector produced by ``expression``."""
        vector = compute_embedding_arithmetic(expression, model)
        print(f"{expression}: {vector.norm().item()}")

    def print_embedding_arithmetic_cossim(expression1: str, expression2: str, model = test_model) -> None:
        """Print the cosine similarity between the vectors of two expressions."""
        vector1 = compute_embedding_arithmetic(expression1, model)
        vector2 = compute_embedding_arithmetic(expression2, model)
        # Both vectors are single embedding-sized rows, so compare along dim 0.
        cosine_similarity = F.cosine_similarity(vector1, vector2, dim=0).item()
        print(f"'{expression1}' vs '{expression2}': {cosine_similarity}")

In [7]:
if DO_WORD_EMBEDDINGS:
    # Mean L2 norm of the token-embedding rows of W_E: a reference scale
    # for the arithmetic norms printed in the next cell.
    average_W_E_norm = test_model.W_E.norm(dim=1).mean().item()
    print(average_W_E_norm)

In [8]:
if DO_WORD_EMBEDDINGS:
    # Norm of the difference between two word-pair offsets. Presumably a value
    # small relative to the average embedding norm hints that both pairs differ
    # along a similar direction. The first entry looks like an unrelated control.
    for expression in [
        "(wine - type) - (four - cat)",
        "(boy - girl) - (man - woman)",
        "(adult - child) - (man - boy)",
        "(good - bad) - (best - worst)",
    ]:
        print_embedding_arithmetic_norm(expression)

In [9]:
if DO_WORD_EMBEDDINGS:
    # Cosine similarity between two word-pair offsets. The first pair looks
    # like an unrelated control for the three analogies that follow.
    for left, right in [
        ("wine - type", "four - cat"),
        ("boy - girl", "man - woman"),
        ("adult - child", "man - boy"),
        ("good - bad", "best - worst"),
    ]:
        print_embedding_arithmetic_cossim(left, right)

## Demonstrating superposition

I replicate the experiments in [Section 2](https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating) of the paper.

In [50]:
@dataclass
class Config:
    num_features: int = 20
    hidden_size: int = 5
    use_relu: bool = True
    importances: Callable = lambda i: 0.7 ** i
    sparsities: float | list = 0.5
    init_range: float = 0.02

In [51]:
class ToyModel(nn.Module):
    """Tied-weight toy autoencoder.

    Features are projected into a smaller hidden space with ``W`` and mapped
    back with ``W.T`` plus a bias ``b``. The output optionally passes through
    a ReLU.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # (num_features, hidden_size) weight, normal-initialised in place.
        self.W = nn.Parameter(torch.empty((cfg.num_features, cfg.hidden_size)))
        nn.init.normal_(self.W, std=cfg.init_range)
        self.b = nn.Parameter(torch.zeros((cfg.num_features,)))

    def forward(self, x):
        hidden = x @ self.W
        reconstruction = hidden @ self.W.T + self.b
        return F.relu(reconstruction) if self.cfg.use_relu else reconstruction

In [90]:
def toy_dataloader(
    cfg: Config,
    num_samples: int = 1024,
    batch_size: int = 1,
):
    """Yield ``num_samples`` batches of sparse synthetic feature vectors.

    Each batch has shape ``(batch_size, cfg.num_features)``. Every entry is
    independently zeroed with probability ``cfg.sparsities`` (a single number
    shared by all features, or one value per feature). Otherwise it is drawn
    uniformly from ``[0, 1)``.

    This is a plain generator, not a ``torch.utils.data.DataLoader``.
    """
    sparsities = cfg.sparsities
    # Accept ints too (e.g. ``sparsities=0``), not just floats.
    if isinstance(sparsities, (int, float)):
        sparsities = [float(sparsities)] * cfg.num_features
    sparsities = torch.tensor(sparsities, dtype=torch.float32)

    for _ in range(num_samples):
        # True where the feature is *kept*. torch.rand is in [0, 1), so `>=`
        # gives keep-probability exactly 1 - S, including S = 0 (dense) and
        # S = 1 (all zero).
        is_present = torch.rand((batch_size, cfg.num_features)) >= sparsities
        yield torch.rand((batch_size, cfg.num_features)) * is_present


# Peek at one batch: 10 samples of 5 features, each zeroed with probability 0.5 (default sparsity).
print(next(toy_dataloader(Config(num_features=5), num_samples=10, batch_size=10)))

tensor([[0.7588, 0.0000, 0.4051, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.7076, 0.0000],
        [0.0000, 0.4444, 0.5666, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.8452, 0.0000],
        [0.1525, 0.0000, 0.0108, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2600, 0.0000, 0.0000],
        [0.6516, 0.7368, 0.0000, 0.7745, 0.2918],
        [0.0000, 0.0000, 0.2006, 0.2340, 0.0000],
        [0.6502, 0.4878, 0.8981, 0.7094, 0.9493],
        [0.2213, 0.8215, 0.0000, 0.0000, 0.0552]])


In [87]:
class ToyLoss(nn.Module):
    """Importance-weighted squared-error loss.

    The loss is the sum over every element of
    ``importances[i] * (y_pred - y_true)[..., i] ** 2``, where ``i`` is the
    feature index.

    The importances are stored as a buffer, so ``ToyLoss(cfg).to(device)``
    moves them along with the module and they never receive gradients.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        importances = torch.tensor(
            [cfg.importances(i) for i in range(cfg.num_features)],
            dtype=torch.float32,
        )
        # Unlike a plain tensor attribute, a buffer follows .to()/.cuda()
        # and is excluded from .parameters().
        self.register_buffer("importances", importances)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the weighted squared error summed over batch and features.

        Assumes the feature dimension is last (size ``cfg.num_features``),
        so the importances broadcast along it.
        """
        squared_error = (y_pred - y_true) ** 2
        return (squared_error * self.importances).sum()


# Sanity check: targets are offset from the predictions by 0, 1, ..., n-1.
# With unit importances the loss is sum(i**2); for n = 20 that is 2470.
# With importances 0.7**i the large offsets are heavily down-weighted.
x_test = torch.rand(Config.num_features)
offsets = torch.arange(Config.num_features)
print(ToyLoss(Config(importances=lambda i: 1))(x_test, x_test + offsets))
print(ToyLoss(Config(importances=lambda i: 0.7**i))(x_test, x_test + offsets))

tensor(2470.)
tensor(42.7268)


In [110]:
# Linear (no ReLU), fully dense toy model from Section 2 of the paper.
cfg = Config(
    num_features=20,
    hidden_size=5,
    sparsities=0.0,
    importances=lambda i: 0.7**i,
    use_relu=False,
)
num_samples = 100000
batch_size = 1024

linear_model = ToyModel(cfg).to(device)
# Importance-weighted loss. Previously this was immediately overwritten by
# nn.MSELoss(), which silently dropped the importances.
loss_fn = ToyLoss(cfg).to(device)
optimizer = optim.Adam(linear_model.parameters(), lr=0.01)
data_loader = toy_dataloader(cfg, num_samples=num_samples, batch_size=batch_size)

losses = np.empty((num_samples,))
# Autograd is disabled globally earlier in the notebook. The previous
# `loss.requires_grad = True` workaround only turned `loss` into a detached
# leaf, so no gradient reached the parameters and the model never trained.
# Re-enable gradients for the training loop instead.
with torch.enable_grad():
    for i, batch in tqdm(enumerate(data_loader), total=num_samples):
        batch = batch.to(device)

        loss = loss_fn(linear_model(batch), batch)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses[i] = loss.item()

100%|██████████| 100000/100000 [00:23<00:00, 4201.53it/s]


In [112]:
# Gaussian-smooth the per-step losses so the trend is visible through the noise.
losses_smoothed = gaussian_filter1d(losses, sigma=LOSS_PLOT_SMOOTHING)
px.line(
    x=np.arange(len(losses_smoothed)),
    y=losses_smoothed,
    title="Training loss smoothed",
    labels={"x": "Training step", "y": f"Loss (Gaussian-smoothed, sigma={LOSS_PLOT_SMOOTHING})"},
).show()