# Counting parentheses in token embedding space

Can we recover the number of opening and closing parentheses in a token from its embedding vector alone?

In [1]:
# Notebook parameters. The trailing `#@param` annotations are Colab form-field markers.
# Only FORCE_CPU and MODEL_NAME are referenced by the cells visible in this notebook.
# NOTE(review): the remaining constants are not used in the visible cells — confirm
# whether they are leftovers before removing them.
FORCE_CPU = True                        #@param {type:"boolean"}
MODEL_NAME = "gelu-1l"                  #@param {type:"string"}
TOKEN_BEGIN_SPACE = "Ġ"                 #@param {type:"string"}
LENGTH_OUTLIER_THRESHOLD = 15           #@param {type:"integer"}
NUMERIC_OUTLIER_THRESHOLD = 1000        #@param {type:"integer"}
DATASET_SIZE = 256                      #@param {type:"integer"}
BATCH_SIZE = 2                          #@param {type:"integer"}

## Setup

In [2]:
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

import plotly.io as pio
pio.renderers.default = "colab+vscode"

Running as a Jupyter notebook
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [32]:
import re
from typing import Callable
from dataclasses import dataclass
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, IterableDataset

import numpy as np
from numpy.typing import NDArray

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import confusion_matrix

from rich.console import Console
from rich.table import Table

from fancy_einsum import einsum

from tqdm import tqdm

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

import matplotlib.pyplot as plt

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)

In [4]:
# Use the GPU only when one is available and the user has not forced CPU execution.
device = "cuda" if torch.cuda.is_available() and not FORCE_CPU else "cpu"
print(device)

cpu



CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)



In [5]:
# Disable autograd globally: the visible cells only read the embedding matrix and fit
# scikit-learn models on it, so no torch gradients are used.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f537d533750>

## Load model and tokens

In [6]:
# Load the pretrained model onto the device chosen above, so that the FORCE_CPU
# setting (rather than a hard-coded "cpu") decides where the weights live.
model = HookedTransformer.from_pretrained(MODEL_NAME, device=device)

Loaded pretrained model gelu-1l into HookedTransformer


In [7]:
# Embedding matrix of the loaded model, plus a NumPy copy for use with scikit-learn.
# The recorded output below shows its shape as (48262, 512).
W_E = model.W_E
W_E_numpy = utils.to_numpy(W_E)
print(W_E_numpy.shape)

(48262, 512)


In [8]:
# String form of every token id in the tokenizer's vocabulary, in token-id order.
# NOTE(review): later cells pair these with the rows of W_E_numpy, which assumes that
# tokenizer.vocab_size equals the number of embedding rows — confirm.
d_vocab = model.tokenizer.vocab_size
str_tokens = model.tokenizer.convert_ids_to_tokens(list(range(d_vocab)))

## Number of parentheses

In [9]:
# Per-token counts of "(" and ")" characters, in the same order as `str_tokens`.
num_open = np.array([token.count("(") for token in str_tokens])
num_close = np.array([token.count(")") for token in str_tokens])

In [10]:
# Distribution of the number of "(" characters per token.
px.histogram(x=num_open, title="Num open brackets")

In [11]:
# Distribution of the number of ")" characters per token.
# (Previously this plotted `num_open` again under the "closed" title.)
px.histogram(x=num_close, title="Num closed brackets")

## Linear regression

In [12]:
# Fit one ordinary least-squares model per bracket type, regressing the bracket count
# of each token on its embedding row.
open_lin_regression = LinearRegression().fit(W_E_numpy, num_open)
close_lin_regression = LinearRegression().fit(W_E_numpy, num_close)

# NOTE(review): predictions are made on the same rows the models were fitted on, so the
# plots below show in-sample fit, not generalisation to held-out tokens.
open_lin_pred = open_lin_regression.predict(W_E_numpy)
close_lin_pred = close_lin_regression.predict(W_E_numpy)

In [18]:
# Box plot of the linear-regression predictions, grouped by the true "(" count.
fig = px.box(
    x=num_open,
    y=open_lin_pred,
    title="Linear regression of open brackets",
    labels=dict(x="True number of opens", y="Predicted number of opens"),
)
# Reference diagonal (prediction == truth). Its endpoint is taken from the data rather
# than hard-coded, so it stays correct when MODEL_NAME (and hence the vocabulary) changes.
max_open = int(num_open.max())
fig.add_shape(
    type="line",
    x0=0,
    y0=0,
    x1=max_open,
    y1=max_open,
    line=dict(dash="dot"),
    label=dict(text="x=y"),
)
fig.show()

In [20]:
# Box plot of the linear-regression predictions, grouped by the true ")" count.
fig = px.box(
    x=num_close,
    y=close_lin_pred,
    title="Linear regression of closed brackets",
    labels=dict(x="True number of closeds", y="Predicted number of closeds"),
)
# Reference diagonal (prediction == truth). Its endpoint is taken from the data rather
# than hard-coded, so it stays correct when MODEL_NAME (and hence the vocabulary) changes.
max_close = int(num_close.max())
fig.add_shape(
    type="line",
    x0=0,
    y0=0,
    x1=max_close,
    y1=max_close,
    line=dict(dash="dot"),
    label=dict(text="x=y"),
)
fig.show()

## Logistic regression

In [22]:
# Treat each distinct bracket count as a class and fit a classifier on the embedding
# rows, using scikit-learn's default LogisticRegression settings.
open_log_regression = LogisticRegression().fit(W_E_numpy, num_open)
close_log_regression = LogisticRegression().fit(W_E_numpy, num_close)

# NOTE(review): as with the linear models, predictions are in-sample (same rows as fit).
open_log_pred = open_log_regression.predict(W_E_numpy)
close_log_pred = close_log_regression.predict(W_E_numpy)

In [None]:
def display_confusion_matrix(confusion_matrix, labels):
    """Render a confusion matrix as a rich table in the console.

    Args:
        confusion_matrix: Two-dimensional array indexed as ``[actual, predicted]``.
            Note that inside this function the name refers to this argument, not to
            any function of the same name imported at module level.
        labels: Sequence of class labels; used for the column headers and the first
            cell of each row, in the same order as the matrix axes.

    Only the leading ``len(labels)`` rows and columns of the matrix are shown.
    """
    table = Table(show_header=True, header_style="bold")

    # Header row: a corner cell followed by one column per predicted label.
    table.add_column("Actual / Predicted", justify="center")
    for heading in labels:
        table.add_column(str(heading), justify="center")

    # One body row per actual label, with the counts for each predicted label.
    for row_idx, actual in enumerate(labels):
        counts = [
            str(confusion_matrix[row_idx, col_idx])
            for col_idx, _ in enumerate(labels)
        ]
        table.add_row(str(actual), *counts)

    Console().print(table)

In [31]:
# Derive the class labels from the data instead of hard-coding them, and pass them to
# `confusion_matrix` so its rows/columns are guaranteed to match the table headers.
open_classes = np.unique(num_open)
display_confusion_matrix(
    confusion_matrix(num_open, open_log_pred, labels=open_classes),
    labels=[str(c) for c in open_classes],
)

In [33]:
# Derive the class labels from the data instead of hard-coding them, and pass them to
# `confusion_matrix` so its rows/columns are guaranteed to match the table headers.
close_classes = np.unique(num_close)
display_confusion_matrix(
    confusion_matrix(num_close, close_log_pred, labels=close_classes),
    labels=[str(c) for c in close_classes],
)

## Comparing logistic regression on a random embedding matrix

In [36]:
# Random baseline: i.i.d. Gaussian entries with the same overall mean and standard
# deviation as the real embedding matrix. A fixed seed makes the baseline (and the
# confusion matrices derived from it) reproducible across kernel restarts.
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)
W_E_random = rng.normal(
    loc=np.mean(W_E_numpy),
    scale=np.std(W_E_numpy),
    size=W_E_numpy.shape,
)

In [39]:
# Same classifiers as before, but fitted on the random matrix: this shows how much of
# the in-sample fit can be obtained from embeddings that carry no token information.
open_rand_log_regression = LogisticRegression().fit(W_E_random, num_open)
close_rand_log_regression = LogisticRegression().fit(W_E_random, num_close)

# In-sample predictions, to be compared against the real-embedding results.
open_rand_log_pred = open_rand_log_regression.predict(W_E_random)
close_rand_log_pred = close_rand_log_regression.predict(W_E_random)

In [40]:
# Derive the class labels from the data instead of hard-coding them, and pass them to
# `confusion_matrix` so its rows/columns are guaranteed to match the table headers.
open_classes = np.unique(num_open)
display_confusion_matrix(
    confusion_matrix(num_open, open_rand_log_pred, labels=open_classes),
    labels=[str(c) for c in open_classes],
)

In [41]:
# Derive the class labels from the data instead of hard-coding them, and pass them to
# `confusion_matrix` so its rows/columns are guaranteed to match the table headers.
close_classes = np.unique(num_close)
display_confusion_matrix(
    confusion_matrix(num_close, close_rand_log_pred, labels=close_classes),
    labels=[str(c) for c in close_classes],
)