# Verified Integer Mathematics in Transformers - Train an SAE

This Colab trains a sparse autoencoder (SAE) on a Transformer model to help understand its MLP layer.

The models perform integer addition and/or subtraction, e.g. 133357+182243=+0315600 and 123450-345670=-0222220. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens.
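A minimal sketch of this question/answer format follows. It assumes that operands are zero-padded to `n_digits` and that the answer is a sign followed by `n_digits + 1` zero-padded digits, as in the examples above. It also assumes that a zero result gets a "+" sign. `format_question` is a hypothetical helper and is not part of the libraries used in this Colab.

```python
def format_question(a: int, b: int, op: str, n_digits: int = 6) -> tuple[str, str]:
    """Build the question and answer strings for a single sum (hypothetical helper)."""
    if op == "+":
        result = a + b
    elif op == "-":
        result = a - b
    else:
        raise ValueError(f"Unsupported operator: {op!r}")
    # Question: two zero-padded operands, the operator and "=" (2 * n_digits + 2 tokens)
    question = f"{a:0{n_digits}d}{op}{b:0{n_digits}d}="
    # Answer: a sign plus n_digits + 1 digits (the extra digit holds any carry)
    sign = "+" if result >= 0 else "-"
    answer = f"{sign}{abs(result):0{n_digits + 1}d}"
    return question, answer


question, answer = format_question(133357, 182243, "+")
tokens = list(question + answer)  # one token per character
print(question, answer, len(question), len(answer))  # 133357+182243= +0315600 14 8
```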

This Colab follows on from https://github.com/PhilipQuirke/quanta_addition_subtraction/blob/main/notebooks/VerifiedArithmeticTrain.ipynb, which trained the models and saved each one as model_name.pth and model_name_train.json.



## Tips for using the Colab
 * You can run and alter the code in this Colab notebook yourself in Google Colab ( https://colab.research.google.com/ ).
 * To run the notebook in Google Colab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the "QuantaMechInterp" public library as "qt", and the "maths_tools" library from the quanta_addition_subtraction repository as "mt". These libraries implement the "QuantaTool" approach to transformer analysis that this Colab uses. Refer to [README.md](https://github.com/PhilipQuirke/quanta_addition_subtraction/blob/main/README.md) for more detail.

In [1]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    !pip install matplotlib

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy
    !pip install scikit-learn

except ImportError:
    IN_COLAB = False

    def setup_jupyter(install_libraries=False):
        if install_libraries:
            !pip install matplotlib==3.8.4
            !pip install kaleido==0.2.1
            !pip install transformer_lens==1.15.0
            !pip install torchtyping==0.1.4
            !pip install transformers==4.39.3

            !pip install numpy==1.26.4
            !pip install plotly==5.20.0
            !pip install pytest==8.1.1
            !pip install scikit-learn==1.4.1.post1

        print("Running as a Jupyter notebook - intended for development only!")
        from IPython import get_ipython

        ipython = get_ipython()
        # Automatically reload edited library code (e.g. HookedTransformer) without restarting the kernel
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # setup_jupyter(install_libraries=True)   # Uncomment if you need to install libraries in notebook.
    setup_jupyter(install_libraries=False)

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Jupyter notebooks than for Colab
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [4]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
import itertools
import re
from enum import Enum

In [5]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

In [6]:
import transformer_lens
from transformer_lens.utils import download_file_from_hf
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [7]:
!pip install scikit-optimize

import sklearn  # aka scikit-learn
import skopt    # aka scikit-optimize



In [8]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qt


In [9]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_addition_subtraction.git
import maths_tools as mt


# Part 1A: Configuration

Which existing model do we want to analyze?

The existing model weights created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/quanta_addition_subtraction/blob/main/notebooks/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace (in Part 5). Refer to https://github.com/PhilipQuirke/quanta_addition_subtraction/blob/main/README.md for more detail.

In [10]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
cfg = mt.MathsConfig()


# Which model do we want to analyze? Uncomment one line:

# Addition models
#cfg.set_model_names( "add_d5_l1_h3_t15K_s372001" )  # AddAccuracy=Two9s. Inaccurate as it only has one layer. Can predict S0, S1 and S2 complexity questions.
#cfg.set_model_names( "add_d5_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.6e-08
#cfg.set_model_names( "add_d5_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t15K_s372001" )  # AddAccuracy=Five9s. AvgFinalLoss=1.7e-08. (2/M fail: 018539+789353=+0807892 ModelAnswer: +0707892, 747332+057349=+0804681 ModelAnswer: +0704681)
#cfg.set_model_names( "add_d6_l2_h3_t20K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=1.5e-08. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t20K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=7e-09
#cfg.set_model_names( "add_d6_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09
#cfg.set_model_names( "add_d10_l2_h3_t40K_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=7e-09. (1/M fail: 0000000555+0000000445=+00000001000 ModelAnswer: +00000000900)
#cfg.set_model_names( "add_d10_l2_h3_t40K_gf_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=3.5e-09. GrokFast.

# Subtraction model
#cfg.set_model_names( "sub_d6_l2_h3_t30K_s372001" )  # SubAccuracy=Six9s. AvgFinalLoss=5.8e-06
#cfg.set_model_names( "sub_d10_l2_h3_t75K_s173289" )  # SubAccuracy=Two9s. (6672/M fails) AvgFinalLoss=0.002002.
#cfg.set_model_names( "sub_d10_l2_h3_t75K_gf_s173289" )  # SubAccuracy=Two9s. GrokFast. (5246/M fails) AvgFinalLoss=0.001197

# Mixed (addition and subtraction) model
#cfg.set_model_names( "mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=5e-09. (1/M fail: 463687+166096=+0629783 ModelAnswer: +0639783)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_s173289" )  # Add/SubAccuracy=Five9s/Two9s. AvgFinalLoss=1.125e-06 (2/M fail: 3301956441+6198944455=+09500900896 ModelAnswer: +09500800896) (295/M fail: 8531063649-0531031548=+08000032101 ModelAnswer: +07900032101)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_gf_s173289" )  # Add/SubAccuracy=Six9s/Three9s. GrokFast. AvgFinalLoss=8.85e-07 (1/M fail) (294/M fail)

# Mixed models initialized with addition model
#cfg.set_model_names( "ins1_mix_d6_l2_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss=2.4e-08 (5/M fails e.g. 565000-364538=+0200462 ModelAnswer: +0100462)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=1.8e-08. (3/M fails e.g. 072074-272074=-0200000 ModelAnswer: +0200000)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t80K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss=1.6e-08 (3/M fails e.g. 229672-229678=-0000006 ModelAnswer: +0000006) (EnrichFalse => 0/M, 4/M)
cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=8e-09. MAIN FOCUS
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s173289" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.4e-08. (3/M fails e.g. 850038+159060=+1009098 ModelAnswer: +0009098) (2/M fails e.g. 77285-477285=+0100000 Q: ModelAnswer: +0000000) (EnrichFalse => 0/M, 3/M)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=2.9e-08. (4/M fails e.g. 986887-286887=+0700000 ModelAnswer: +0600000) (EnrichFalse => 0/M, 3/M)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=6.3e-07 (6/M fails e.g. 5068283822+4931712829=+09999996651 ModelAnswer: +19099996651) (7/M fails e.g. 3761900218-0761808615=+03000091603 ModelAnswer: +02000091603)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_gf_s572091" ) # Add/SubAccuracy=Five9s/Two9s. GrokFast. AvgFinalLoss=4.0e-06 (2/M fails) (1196/M fails)

# Mixed model initialized with addition model. Reset useful heads every 100 epochs.
#cfg.set_model_names( "ins2_mix_d6_l4_h4_t40K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (3/M fails e.g. 530757+460849=+0991606 ModelAnswer: +0091606) (8 fails e.g. 261926-161857=+0100069 ModelAnswer: +0000069)

# Mixed model initialized with addition model. Reset useful heads & MLPs every 100 epochs.
#cfg.set_model_names( "ins3_mix_d6_l4_h3_t40K_s372001" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=3.0e-04. (17/M fails e.g. 273257+056745=+0330002 ModelAnswer: +0320002) (3120 fails e.g. 09075-212133=-0003058 ModelAnswer: +0003058)

# Mixed models initialized with addition model.
#cfg.set_model_names( "ins4_mix_d6_l3_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO
#cfg.set_model_names( "ins4_mix_d6_l2_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO
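The model names above seem to follow a consistent naming convention. Each name has an optional `insN_` insertion-mode prefix, then the task (`add`, `sub` or `mix`), then digits (`d`), layers (`l`) and heads (`h`). After that come the training length in thousands (`t…K`, assumed here to count training batches), an optional `_gf` GrokFast flag, and the seed (`s`). This reading is inferred from the names alone. `parse_model_name` is a hypothetical helper, and `cfg.set_model_names` may parse names differently. Two printed outputs agree with this reading for the main model: it reports `InsertMode= 1` and loads as 3 transformer blocks.

```python
import re
from dataclasses import dataclass

# Assumed naming convention, inferred from the model names listed above
_MODEL_NAME_RE = re.compile(
    r"^(?:ins(?P<insert_mode>\d+)_)?"        # optional insertion-mode prefix, e.g. ins1_
    r"(?P<task>add|sub|mix)"                 # addition, subtraction or mixed
    r"_d(?P<digits>\d+)_l(?P<layers>\d+)_h(?P<heads>\d+)"
    r"_t(?P<train_k>\d+)K"                   # training length in thousands (assumed batches)
    r"(?P<grokfast>_gf)?"                    # optional GrokFast flag
    r"_s(?P<seed>\d+)$"                      # random seed
)


@dataclass
class ModelNameParts:
    insert_mode: int
    task: str
    n_digits: int
    n_layers: int
    n_heads: int
    n_training_batches: int
    grokfast: bool
    seed: int


def parse_model_name(name: str) -> ModelNameParts:
    """Split a model name into its (assumed) components. Hypothetical helper."""
    match = _MODEL_NAME_RE.match(name)
    if match is None:
        raise ValueError(f"Unrecognised model name: {name!r}")
    return ModelNameParts(
        insert_mode=int(match["insert_mode"] or 0),  # 0 when there is no insN_ prefix (assumed)
        task=match["task"],
        n_digits=int(match["digits"]),
        n_layers=int(match["layers"]),
        n_heads=int(match["heads"]),
        n_training_batches=int(match["train_k"]) * 1000,
        grokfast=match["grokfast"] is not None,
        seed=int(match["seed"]),
    )


print(parse_model_name("ins1_mix_d6_l3_h4_t40K_s372001"))
```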

# Part 1B: Configuration: Input and Output file names



In [11]:
cfg.batch_size = 512 # Default analysis batch size
if cfg.n_layers >= 3 and cfg.n_heads >= 4:
  cfg.batch_size = 256 # Reduce batch size to avoid memory constraint issues.

cfg.set_seed(cfg.analysis_seed)

In [12]:
main_fname_pth = cfg.model_name + '.pth'
main_fname_train_json = cfg.model_name + '_train.json'
main_fname_sae_pth = cfg.model_name + '_sae.pth'
main_repo_name="PhilipQuirke/VerifiedArithmetic"

In [13]:
# Update "cfg" with additional training config (including cfg.insert_*) information from existing file:
#      https://huggingface.co/PhilipQuirke/VerifiedArithmetic/raw/main/ins1_mix_d6_l3_h4_t40K_s372001_train.json
training_data_json = qt.download_huggingface_json(main_repo_name, main_fname_train_json)
training_loss_list = qt.load_training_json(cfg, training_data_json)


In [14]:
def print_config():
  print("%Add=", cfg.perc_add, "%Sub=", cfg.perc_sub, "%Mult=", cfg.perc_mult, "InsertMode=", cfg.insert_mode, "File=", cfg.model_name)

In [15]:
print_config()
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)
print('Main model will be read from HuggingFace file', main_repo_name, main_fname_pth)
print('Main model training config / loss was read from HuggingFace file', main_fname_train_json)
print('Saving sae model to', main_repo_name, main_fname_sae_pth)

%Add= 20 %Sub= 80 %Mult= 0 InsertMode= 1 File= ins1_mix_d6_l3_h4_t40K_s372001
weight_decay= 0.1 lr= 8e-05 batch_size= 64
Main model will be read from HuggingFace file PhilipQuirke/VerifiedArithmetic ins1_mix_d6_l3_h4_t40K_s372001.pth
Main model training config / loss was read from HuggingFace file ins1_mix_d6_l3_h4_t40K_s372001_train.json
Saving sae model to PhilipQuirke/VerifiedArithmetic ins1_mix_d6_l3_h4_t40K_s372001_sae.pth


# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

  

In [16]:
cfg.initialize_maths_token_positions()
mt.set_maths_vocabulary(cfg)
mt.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

['D5', 'D4', 'D3', 'D2', 'D1', 'D0', 'OPR', "D'5", "D'4", "D'3", "D'2", "D'1", "D'0", '=', 'A7', 'A6', 'A5', 'A4', 'A3', 'A2', 'A1', 'A0']
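The printed list gives each token position a name. For n-digit questions, the positions are the first operand's digits, the operator, the second operand's digits, "=", and then n + 2 answer positions (sign plus n + 1 digits). The sketch below rebuilds the same list. `token_position_meanings` is a hypothetical helper; the sketch does not claim to be how `mt.set_maths_question_meanings` is implemented.

```python
def token_position_meanings(n_digits: int = 6) -> list[str]:
    """Name every token position of an n-digit question plus its answer (hypothetical helper)."""
    first = [f"D{i}" for i in range(n_digits - 1, -1, -1)]     # e.g. D5..D0 when n_digits=6
    second = [f"D'{i}" for i in range(n_digits - 1, -1, -1)]   # e.g. D'5..D'0
    answer = [f"A{i}" for i in range(n_digits + 1, -1, -1)]    # e.g. A7..A0 (sign + 7 digits)
    return first + ["OPR"] + second + ["="] + answer


meanings = token_position_meanings(6)
print(len(meanings), meanings[:7], meanings[-8:])
```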


# Part 3B: Set Up: Create model

In [17]:
# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = cfg.get_HookedTransformerConfig()

# Create the main transformer model
cfg.main_model = HookedTransformer(ht_cfg)

# Part 5: Set Up: Load Model from HuggingFace

In [18]:
print("Loading model from HuggingFace", main_repo_name, main_fname_pth)

cfg.main_model.load_state_dict(download_file_from_hf(repo_name=main_repo_name, file_name=main_fname_pth, force_is_torch=True))
cfg.main_model.eval()

Loading model from HuggingFace PhilipQuirke/VerifiedArithmetic ins1_mix_d6_l3_h4_t40K_s372001.pth



HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

# Part 6: Train an SAE

Ideally, the trained SAE has:
- Low loss and MSE (good reconstruction of the MLP activations; very important)
- Low sparsity (here "sparsity" means the number of SAE neurons that activate, so low sparsity means few neurons are active for any given prediction)
- Few active neurons overall, across all predictions (this makes interpretation easier)
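To make these criteria concrete, here is a minimal NumPy sketch of a classic KL-divergence sparse autoencoder loss, evaluated once on random data (no training loop). It is an assumption that `qt.analyze_mlp_with_sae` works this way. The sigmoid encoder, the KL form of the sparsity penalty, the 0.5 activity threshold and `d_mlp=16` are illustrative choices. The parameter names `sparsity_target`, `sparsity_weight` and `l1_weight` mirror the ones passed to the qt functions. Here `active_fraction` roughly corresponds to per-prediction sparsity, and `neurons_used` to the overall active-neuron count.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def sae_loss(x, W_enc, b_enc, W_dec, b_dec, sparsity_target=0.1,
             sparsity_weight=1e-3, l1_weight=1e-4, active_threshold=0.5, eps=1e-8):
    """Loss terms and interpretability metrics for a hypothetical KL-sparse autoencoder."""
    # Encode the MLP activations into encoding_dim features in (0, 1)
    h = sigmoid(x @ W_enc + b_enc)                       # [n_samples, encoding_dim]
    # Decode back into the original MLP activation space
    x_hat = h @ W_dec + b_dec                            # [n_samples, d_mlp]
    mse = float(np.mean((x - x_hat) ** 2))

    # KL divergence between the target firing rate rho and each feature's mean rate rho_hat
    rho = sparsity_target
    rho_hat = np.clip(h.mean(axis=0), eps, 1.0 - eps)
    kl = float(np.sum(rho * np.log(rho / rho_hat)
                      + (1.0 - rho) * np.log((1.0 - rho) / (1.0 - rho_hat))))
    # An L1 penalty on feature activations also pushes towards sparse codes
    l1 = float(np.mean(np.abs(h)))

    active = h > active_threshold                        # which features fire for each sample
    return {
        "loss": mse + sparsity_weight * kl + l1_weight * l1,
        "mse": mse,
        "kl_penalty": kl,
        "l1_penalty": l1,
        "active_fraction": float(active.mean()),         # average share of features firing per sample
        "neurons_used": int(active.any(axis=0).sum()),   # features firing on at least one sample
    }


# Random data stands in for layer-0 MLP activations
rng = np.random.default_rng(0)
d_mlp, encoding_dim = 16, 32
x = rng.normal(size=(256, d_mlp))
W_enc = rng.normal(scale=0.1, size=(d_mlp, encoding_dim))
b_enc = np.zeros(encoding_dim)
W_dec = rng.normal(scale=0.1, size=(encoding_dim, d_mlp))
b_dec = np.zeros(d_mlp)
metrics = sae_loss(x, W_enc, b_enc, W_dec, b_dec)
print(metrics)
```

Raising `sparsity_weight` or lowering `sparsity_target` trades reconstruction quality (MSE) for sparser codes. The hyperparameter sweep in Part 7 explores that trade-off.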


In [20]:
dataloader = mt.get_mixed_maths_dataloader(cfg, num_batches=1000, enrich_data=True)
print("Data set size", len(dataloader.dataset))

Data set size 64000


In [21]:
sae, score, loss, sparsity, neurons_used = qt.analyze_mlp_with_sae(cfg, dataloader, layer_num=0, encoding_dim=32, learning_rate=5e-4, sparsity_target=0.1, sparsity_weight=1e-3, num_epochs=10)
print( f"Score: {score:.4f}, Loss {loss:.4f}, Sparsity {sparsity:.4f}, Neurons Used: {neurons_used}.")

Epoch: 1, Score: 9.0763, Loss: 0.0736, MSE: 0.0686, Sparsity Penalty: 4.9361, L1 Penalty: 0.3280, Sparsity: 71.69%, Neurons used: 32/32 (100.00%)
Epoch: 6, Score: 4.7991, Loss: 0.0297, MSE: 0.0288, Sparsity Penalty: 0.9105, L1 Penalty: 0.1846, Sparsity: 82.80%, Neurons used: 32/32 (100.00%)
Epoch: 10, Score: 4.7784, Loss: 0.0292, MSE: 0.0287, Sparsity Penalty: 0.4986, L1 Penalty: 0.1591, Sparsity: 85.44%, Neurons used: 32/32 (100.00%)
Score: 4.7784, Loss 0.0292, Sparsity 0.8544, Neurons Used: 32.


In [None]:
sae, score, loss, sparsity, neurons_used = qt.analyze_mlp_with_sae(cfg, dataloader, layer_num=0, encoding_dim=64, learning_rate=0.001, sparsity_target=0.05, sparsity_weight=0.1, num_epochs=10)
print( f"Score: {score:.4f}, Loss {loss:.4f}, Sparsity {sparsity:.4f}, Neurons Used: {neurons_used}.")

# Part 7: Sweep hyperparams to find the best SAE

Sweep the hyperparameters to train and score multiple SAEs, then keep the best-scoring one. This is slow, because every combination in the grid trains a separate SAE.

In [None]:
param_grid = {
    'encoding_dim': [64], #[32, 64, 128, 256, 512],
    'learning_rate': [1e-3], # [1e-4, 1e-3, 1e-2],
    'sparsity_target': [0.05, 0.1], # [0.001, 0.005, 0.01, 0.05, 0.1],
    'sparsity_weight': [0.1], #[1e-3, 1e-2, 1e-1, 1.0],
    'l1_weight': [1e-4, 1e-3, 1e-2],  # [1e-6, 1e-5, 1e-4, 1e-3, 1e-2],
    'num_epochs': [10],
    'patience': [2]
}
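The grid includes a `patience` entry. Presumably this is the usual early-stopping patience: training stops once the loss has failed to improve for `patience` consecutive epochs. How `qt.optimize_sae_hyperparameters` actually uses it is an assumption. The sketch below shows that common rule with `run_with_patience`, a hypothetical helper, applied to made-up losses.

```python
def run_with_patience(epoch_losses, patience=2, min_delta=0.0):
    """Return (epochs_run, best_loss) under a patience-based early-stopping rule (hypothetical helper)."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(epoch_losses, start=1):
        if loss < best_loss - min_delta:
            best_loss = loss                 # improvement: remember it and reset the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_loss      # stop early
    return len(epoch_losses), best_loss


# Made-up per-epoch losses: two epochs without improvement stop training at epoch 5
print(run_with_patience([0.08, 0.05, 0.04, 0.041, 0.042, 0.01], patience=2))  # (5, 0.04)
```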

In [None]:
num_experiments = 1
for param_values in param_grid.values():
    num_experiments *= len(param_values)

print(f"Number of configurations to test: {num_experiments}")
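The configuration count is the size of the Cartesian product of the grid's value lists. The sketch below enumerates those combinations with `itertools.product` and picks the best one. `toy_score` is a made-up stand-in for training and scoring an SAE. The sketch also assumes that a lower score is better, since the logged Score fell from 9.0763 to 4.7784 as training progressed in Part 6. Neither detail is taken from `qt.optimize_sae_hyperparameters`.

```python
import itertools

param_grid = {
    'encoding_dim': [64],
    'learning_rate': [1e-3],
    'sparsity_target': [0.05, 0.1],
    'sparsity_weight': [0.1],
    'l1_weight': [1e-4, 1e-3, 1e-2],
    'num_epochs': [10],
    'patience': [2],
}


def iter_configs(grid):
    """Yield one dict per combination of hyperparameter values."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def toy_score(config):
    """Made-up stand-in for 'train an SAE with config and return its score'."""
    return config['l1_weight'] * 100 + config['sparsity_target']


configs = list(iter_configs(param_grid))
print(len(configs))                    # 6
best = min(configs, key=toy_score)     # assumes lower score is better
print(best)
```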

In [None]:
sae, score, neurons_used, params = qt.optimize_sae_hyperparameters(cfg, dataloader, layer_num=0, param_grid=param_grid)

# Part 8: Visualize the SAE


In [None]:
qt.analyze_and_visualize_sae(cfg, sae, dataloader, layer_num=0, max_samples=1000, perplexity=30, n_iter=250)
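The `perplexity` and `n_iter` arguments suggest that `qt.analyze_and_visualize_sae` projects the SAE features to 2D with t-SNE; this is an assumption. Below is a minimal scikit-learn sketch of that kind of projection, applied to random non-negative data standing in for SAE feature activations. `tsne_of_sae_features` is a hypothetical helper. Recent scikit-learn releases name the iteration count `max_iter` rather than `n_iter`, so the sketch leaves it at the default.

```python
import numpy as np
from sklearn.manifold import TSNE


def tsne_of_sae_features(encoded, max_samples=1000, perplexity=30.0, seed=0):
    """Project SAE feature activations [n_samples, encoding_dim] to 2D with t-SNE."""
    rng = np.random.default_rng(seed)
    if encoded.shape[0] > max_samples:
        # Subsample so t-SNE stays fast, mirroring the max_samples argument above
        keep = rng.choice(encoded.shape[0], size=max_samples, replace=False)
        encoded = encoded[keep]
    # scikit-learn requires perplexity < n_samples
    perplexity = min(perplexity, encoded.shape[0] - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=seed)
    return tsne.fit_transform(encoded)


rng = np.random.default_rng(0)
fake_features = np.maximum(0.0, rng.normal(size=(300, 64)))  # stand-in for SAE activations
embedding = tsne_of_sae_features(fake_features, max_samples=200)
print(embedding.shape)  # (200, 2)
```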