# Doing linear regression on the token embeddings

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SamAdamDay/mechanistic-interpretability-projects/blob/main/playground/token-embed-regression.ipynb)

What features can be extracted from a model's token embeddings using linear regression?

In [8]:
# Notebook parameters. The trailing `#@param` markers are Colab form annotations.
#   DEVELOPMENT_MODE          - when True, the "colab" plotly renderer is used even outside Colab
#   MODEL_NAME                - model name passed to HookedTransformer.from_pretrained
#   TOKEN_BEGIN_SPACE         - token prefix that is not counted in the "without space" lengths
#   LENGTH_OUTLIER_THRESHOLD  - tokens with more characters than this are treated as length outliers
#   NUMERIC_OUTLIER_THRESHOLD - numeric tokens with a value above this are treated as outliers
#   DATASET_SIZE              - number of shuffled dataset rows kept for the residual stream analysis
#   BATCH_SIZE                - batch size of the DataLoader feeding the model
DEVELOPMENT_MODE = False                #@param {type:"boolean"}
MODEL_NAME = "gpt2"                     #@param {type:"string"}
TOKEN_BEGIN_SPACE = "Ġ"                 #@param {type:"string"}
LENGTH_OUTLIER_THRESHOLD = 15           #@param {type:"integer"}
NUMERIC_OUTLIER_THRESHOLD = 1000        #@param {type:"integer"}
DATASET_SIZE = 256                      #@param {type:"integer"}
BATCH_SIZE = 2                          #@param {type:"integer"}

## Setup

In [9]:
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload








In [10]:
import plotly.io as pio

# Pick the plotly renderer: "colab" when running in Colab or in development
# mode, otherwise the "notebook_connected" renderer.
use_colab_renderer = IN_COLAB or DEVELOPMENT_MODE
pio.renderers.default = "colab" if use_colab_renderer else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")


Using renderer: notebook_connected


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np

from sklearn.linear_model import LinearRegression

from datasets import load_dataset

from fancy_einsum import einsum

from tqdm import tqdm

import plotly.express as px

import matplotlib.pyplot as plt

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)


In [12]:
# Turn off autograd globally; the cells in this notebook only run the model forward.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7efe4c7aab90>

In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# (Set `device = "cpu"` here by hand to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


## Get model and tokens

In [14]:
# Load the pretrained model named by MODEL_NAME onto the selected device
model = HookedTransformer.from_pretrained(MODEL_NAME, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [15]:
# Token embedding matrix of the model, plus a numpy copy for use with scikit-learn
W_E = model.W_E
W_E_numpy = utils.to_numpy(W_E)
# One row per token; the recorded output for gpt2 is (50257, 768)
print(W_E_numpy.shape)

(50257, 768)


In [16]:
# Size of the tokenizer vocabulary
d_vocab = model.tokenizer.vocab_size
# String form of every token id, so str_tokens[i] is the token with id i
str_tokens = model.tokenizer.convert_ids_to_tokens(list(range(d_vocab)))

## Token length

Regressing on the number of characters in a token.

In [17]:
# Number of characters of every token string, including any leading space marker
lengths_basic = np.array(list(map(len, str_tokens)))


def len_no_space(token_str: str) -> int:
    """Return the length of ``token_str`` without a leading space marker.

    If the token starts with ``TOKEN_BEGIN_SPACE`` the marker is not counted.
    The marker's own length is subtracted (rather than a fixed 1) so the
    result stays correct when ``TOKEN_BEGIN_SPACE`` is configured to a
    string that is not exactly one character long.
    """
    if token_str.startswith(TOKEN_BEGIN_SPACE):
        return len(token_str) - len(TOKEN_BEGIN_SPACE)
    return len(token_str)


# Number of characters of every token string, not counting the leading space marker
lengths_nospace = np.array(list(map(len_no_space, str_tokens)))

In [18]:
# Distribution of token lengths, counting the leading space marker
px.histogram(x=lengths_basic, title="Token Basic Lengths")

In [19]:
# Distribution of token lengths, not counting the leading space marker
px.histogram(x=lengths_nospace, title="Token Lengths Without Space")

What are these weird outliers?

In [20]:
# List every token whose string is longer than LENGTH_OUTLIER_THRESHOLD characters
print([x for x in str_tokens if len(x) > LENGTH_OUTLIER_THRESHOLD])



Let's mask out the outliers, i.e. the tokens longer than `LENGTH_OUTLIER_THRESHOLD` characters.

In [21]:
# NOTE: despite the name, True marks tokens to KEEP (length <= threshold);
# indexing with this mask removes the outliers.
length_outliers_mask = lengths_basic <= LENGTH_OUTLIER_THRESHOLD

In [22]:
# Each task maps a display name to its (features, targets) pair. The features
# are always token embeddings; the targets are one of the length variants.
embeddings_without_outliers = W_E_numpy[length_outliers_mask]
regression_tasks = {
    "Basic lengths": (W_E_numpy, lengths_basic),
    "Lengths without space": (W_E_numpy, lengths_nospace),
    "Basic lengths excluding outliers": (
        embeddings_without_outliers,
        lengths_basic[length_outliers_mask],
    ),
    "Lengths without space excluding outliers": (
        embeddings_without_outliers,
        lengths_nospace[length_outliers_mask],
    ),
}

# Fit one linear regression per task and keep its predictions on the data it
# was fitted on.
regressions, preditions = {}, {}
for name, (X, y) in regression_tasks.items():
    fitted = LinearRegression().fit(X, y)
    regressions[name] = fitted
    preditions[name] = fitted.predict(X)

In [23]:
# One box plot per task: the distribution of predicted lengths for each true length
for name, (_, true_lengths) in regression_tasks.items():
    fig = px.box(
        x=true_lengths,
        y=preditions[name],
        title=name,
        labels=dict(x="True length", y="Predicted length"),
    )
    # Draw the x=y reference line up to this task's largest true length
    # instead of a hard-coded 15, so it spans the data for the tasks that
    # include outliers and follows any change of the outlier threshold.
    diagonal_end = int(true_lengths.max())
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=diagonal_end,
        y1=diagonal_end,
        line=dict(dash="dot"),
        label=dict(text="x=y"),
    )
    fig.show()

In [24]:
# Print each task's name followed by the score of its regression on the
# data it was fitted on.
for name in regression_tasks:
    X, y = regression_tasks[name]
    print(name, regressions[name].score(X, y), sep="\n")

Basic lengths
0.707913266402122
Lengths without space
0.6734827903596416
Basic lengths excluding outliers
0.8134518075536497
Lengths without space excluding outliers
0.7879467384965992


## Numerical value

Can we regress on the numerical value of those tokens which are numbers?

In [25]:
# Determine which tokens are numbers and their numerical values
is_number_mask = np.empty(d_vocab, dtype=bool)
numerical_values = []
for i,str_token in enumerate(str_tokens):
    is_number_mask[i] = all(x in "0123456789" for x in str_token)
    if is_number_mask[i]:
        numerical_values.append(int(str_token))

numerical_values = np.array(numerical_values)

In [26]:
# Distribution of the values of all numeric tokens
px.histogram(x=numerical_values, title="Numerical tokens").show()

Looks like there's an outlier!

In [27]:
# NOTE: despite the name, True marks numeric tokens to KEEP (value <= threshold).
# The mask indexes `numerical_values`, i.e. only the numeric tokens, not the full vocabulary.
numeric_outliers_mask = numerical_values <= NUMERIC_OUTLIER_THRESHOLD

In [28]:
# Distribution of numeric token values once values above the threshold are removed
px.histogram(x=numerical_values[numeric_outliers_mask], title="Numerical tokens excluding outliers").show()

In [29]:
# Embeddings and values of the numeric tokens that survive the outlier filter
numeric_embeddings = W_E_numpy[is_number_mask][numeric_outliers_mask]
numeric_targets = numerical_values[numeric_outliers_mask]

# Fit the regression and predict on the same tokens it was fitted on
regression = LinearRegression().fit(numeric_embeddings, numeric_targets)
prediction = regression.predict(numeric_embeddings)

fig = px.box(
    x=numeric_targets,
    y=prediction,
    title="Numerical Values",
    labels={"x": "True Numerical Value", "y": "Predicted Numerical Value"},
)
# Dotted x=y reference line from the origin up to the outlier threshold
diagonal = dict(
    type="line",
    x0=0,
    y0=0,
    x1=NUMERIC_OUTLIER_THRESHOLD,
    y1=NUMERIC_OUTLIER_THRESHOLD,
    line={"dash": "dot"},
    label={"text": "x=y"},
)
fig.add_shape(**diagonal)
fig.show()

## Continuing up the model

Does the token length direction identified above still make sense in the residual stream further up the model?

In [30]:
# Reuse the regression fitted on the token embeddings (basic lengths, outliers excluded)
length_regression = regressions["Basic lengths excluding outliers"]

In [31]:
# Load the wikitext-2 test split and drop rows whose text is empty
dataset = load_dataset("wikitext", "wikitext-2-v1", split="test")
dataset = dataset.filter(lambda data_dict: data_dict["text"] != "")
# Shuffle with a fixed seed so the sampled rows, and therefore every score
# computed from them, are reproducible across runs; keep DATASET_SIZE rows.
dataset = dataset.shuffle(seed=0)[:DATASET_SIZE]
# Tokenize the sampled texts, leaving the tokens where `to_tokens` creates
# them instead of moving them to the model's device.
dataset_tokens = model.to_tokens(dataset["text"], move_to_device=False)

Found cached dataset wikitext (/home/sam/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /home/sam/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-a9a0c12ef3289d1b.arrow


In [32]:
# Get the names of the hooks
resid_hooks = list(model.hook_dict.keys())
print(resid_hooks)

['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_resid_mid', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_resid_mid', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.hook_resid_mid', 'blocks.2.hook_resid_post', 'blocks.3.hook_resid_pre', 'blocks.3.hook_resid_mid', 'blocks.3.hook_resid_post', 'blocks.4.hook_resid_pre', 'blocks.4.hook_resid_mid', 'blocks.4.hook_resid_post', 'blocks.5.hook_resid_pre', 'blocks.5.hook_resid_mid', 'blocks.5.hook_resid_post', 'blocks.6.hook_resid_pre', 'blocks.6.hook_resid_mid', 'blocks.6.hook_resid_post', 'blocks.7.hook_resid_pre', 'blocks.7.hook_resid_mid', 'blocks.7.hook_resid_post', 'blocks.8.hook_resid_pre', 'blocks.8.hook_resid_mid', 'blocks.8.hook_resid_post', 'blocks.9.hook_resid_pre', 'blocks.9.hook_resid_mid', 'blocks.9.hook_resid_post', 'blocks.10.hook_resid_pre', 'blocks.10.hook_resid_mid', 'blocks.10.hook_resid_post', 'blocks.11.hook_resid_pre', 'blocks.11.hook_resid_mid', 'block

In [33]:
dataloader = DataLoader(dataset_tokens, batch_size=BATCH_SIZE)

# Per hook: activation vectors and the true lengths of the tokens they belong to
resid_vectors = {name: [] for name in resid_hooks}
actual_lengths = {name: [] for name in resid_hooks}

bos_token_id = model.tokenizer.bos_token_id

for batch in tqdm(dataloader):

    # Only cache the hooks that are read below instead of every activation
    _, cache = model.run_with_cache(batch, names_filter=resid_hooks)

    # Positions to keep: everything that is not the BOS token.
    # NOTE(review): presumably padding uses the same token id for this
    # tokenizer and is dropped by this mask too - confirm.
    keep_mask = batch != bos_token_id
    # The kept tokens and their lengths are the same for every hook, so
    # compute them once per batch rather than once per hook.
    batch_lengths = lengths_basic[utils.to_numpy(batch[keep_mask])]

    for name in resid_hooks:
        resid_vectors[name].append(utils.to_numpy(cache[name][keep_mask]))
        actual_lengths[name].append(batch_lengths)

# Score the embedding-fitted length regression on the vectors collected at each hook
resid_regression_scores = {}
for name in resid_hooks:
    resid_vectors[name] = np.concatenate(resid_vectors[name])
    actual_lengths[name] = np.concatenate(actual_lengths[name])
    resid_regression_scores[name] = length_regression.score(resid_vectors[name], actual_lengths[name])

100%|██████████| 128/128 [04:01<00:00,  1.89s/it]


In [35]:
# Drop the collected activations and lengths; only the scores are used from here on
del resid_vectors
del actual_lengths

In [42]:
# Regression score at every hook, in the order the hooks were collected
layer_names = list(resid_regression_scores.keys())
layer_scores = list(resid_regression_scores.values())
px.line(
    x=layer_names,
    y=layer_scores,
    labels={"y": "R2 score", "x": "Layer"},
    title=f"Linear regression score for token length prediction through {MODEL_NAME} residual stream",
)

In [41]:
# Absolute value of every score, so the scores can be drawn on a log axis
resid_regression_scores_abs = np.absolute(list(resid_regression_scores.values()))

px.line(
    x=list(resid_regression_scores.keys()),
    y=resid_regression_scores_abs,
    labels={"y": "Absolute R2 score", "x": "Layer"},
    title=f"Absolute linear regression score for token length prediction through {MODEL_NAME} residual stream",
    log_y=True,
)