# Quanta Maths: Integer Addition and Subtraction in Transformers - Train an SAE

This Colab trains a sparse autoencoder (SAE) on a Transformer model to help understand its MLP layer.

The models perform integer addition and/or subtraction, e.g. 133357+182243=+0315600 and 123450-345670=-0222220. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens.
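As a minimal sketch of this token layout, the hypothetical helper `format_question` below (not part of the quanta_maths library) zero-pads each operand to `n_digits` digits and writes the answer as a sign followed by `n_digits + 1` digits. Writing a zero result with a `+` sign is an assumption made here for illustration.

```python
def format_question(x: int, y: int, op: str, n_digits: int = 6) -> tuple[str, str]:
    """Return (question, answer) strings, with one character per token."""
    if op not in ("+", "-"):
        raise ValueError("op must be '+' or '-'")
    result = x + y if op == "+" else x - y
    # Question: first operand, operator, second operand, equals sign.
    question = f"{x:0{n_digits}d}{op}{y:0{n_digits}d}="
    # Answer: sign token followed by n_digits + 1 zero-padded digits.
    # Assumption: a zero result is written with a '+' sign.
    sign = "-" if result < 0 else "+"
    answer = f"{sign}{abs(result):0{n_digits + 1}d}"
    return question, answer


q, a = format_question(133357, 182243, "+")
print(q + a, len(q), len(a))
```

For 6-digit operands this yields 14 question characters and 8 answer characters, matching the token counts described above.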

This Colab follows on from the QuantaMathsTrain notebook ( https://github.com/PhilipQuirke/quanta_maths/blob/main/notebooks/QuantaMathsTrain.ipynb ), which trained the models and output model.pth and training_loss.json.



## Tips for using the Colab
 * You can run and alter the code in this Colab notebook yourself in Google Colab ( https://colab.research.google.com/ ).
 * To run the notebook in Google Colab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the public "quanta_mech_interp" library as "qmi" and the public "quanta_maths" library as "mmi". Refer to [README.md](https://github.com/PhilipQuirke/quanta_maths/blob/main/README.md) for more detail.

In [1]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    !pip install matplotlib
    !pip install kaleido

except ImportError:
    IN_COLAB = False

    def setup_jupyter(install_libraries=False):
        if install_libraries:
            !pip install matplotlib
            !pip install kaleido

        print("Running as a Jupyter notebook - intended for development only!")
        from IPython import get_ipython

        ipython = get_ipython()
        # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # setup_jupyter(install_libraries=True)   # Uncomment if you need to install libraries in notebook.
    setup_jupyter(install_libraries=False)

Running as a Colab notebook
Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)
Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
Installing collected packages: kaleido
Successfully installed kaleido-0.2.1


In [2]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

Collecting git+https://github.com/PhilipQuirke/quanta_mech_interp.git
  Cloning https://github.com/PhilipQuirke/quanta_mech_interp.git to /tmp/pip-req-build-d1xeco85
  Running command git clone --filter=blob:none --quiet https://github.com/PhilipQuirke/quanta_mech_interp.git /tmp/pip-req-build-d1xeco85
  Resolved https://github.com/PhilipQuirke/quanta_mech_interp.git to commit e17badd114f60ed4695c750e8073337d97fd8bb4
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting scikit-optimize (from QuantaMechInterp==1.0)
  Downloading scikit_optimize-0.10.2-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting torchtyping>=0.1.4 (from QuantaMechInterp==1.0)
  Downloading torchtyping-0.1.5-py3-none-any.whl.metadata (9.5 kB)
INFO: pip is looking at multiple versions of torchtyping to determine which version is compatible with other requirements. This could take a while.
  

In [3]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_maths.git
import MathsMechInterp as mmi

Collecting git+https://github.com/PhilipQuirke/quanta_maths.git
  Cloning https://github.com/PhilipQuirke/quanta_maths.git to /tmp/pip-req-build-k5d4nr3o
  Running command git clone --filter=blob:none --quiet https://github.com/PhilipQuirke/quanta_maths.git /tmp/pip-req-build-k5d4nr3o
  Resolved https://github.com/PhilipQuirke/quanta_maths.git to commit 0d7f3307c09a4149d4cdb6e4aca13666bb4b397c
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting transformer_lens (from MathsMechInterp==1.0.0)
  Downloading transformer_lens-2.11.0-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens->MathsMechInterp==1.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens->MathsMechInterp==1.0.0)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Coll

In [4]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [5]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [6]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
import itertools
import re
from enum import Enum

In [7]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [8]:
import transformer_lens
from transformer_lens.utils import download_file_from_hf
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [9]:
import sklearn # Aka scikit-learn
import skopt # Aka scikit-optimize

# Part 1A: Configuration

Which existing model do we want to analyze?

The existing model weights created by the sister Colab [QuantaMathsTrain](https://github.com/PhilipQuirke/quanta_maths/blob/main/notebooks/QuantaMathsTrain.ipynb) are loaded from HuggingFace (in Part 5). Refer to https://github.com/PhilipQuirke/quanta_maths/blob/main/README.md for more detail.

In [13]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
cfg = mmi.MathsConfig()


# Which model do we want to analyze? Uncomment one line:

# Addition models
#cfg.set_model_names( "add_d5_l1_h3_t15K_s372001" )  # AddAccuracy=Two9s. Inaccurate as only has one layer. Reproduces previous paper model.
#cfg.set_model_names( "add_d5_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.6e-08
#cfg.set_model_names( "add_d5_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. (0/M fail). Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t15K_s372001" )  # AddAccuracy=Five9s. AvgFinalLoss=1.7e-08. (2/M fail)
#cfg.set_model_names( "add_d6_l2_h3_t20K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=1.5e-08. (0/M fail). Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t20K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=7e-09.  (0/M fail)
#cfg.set_model_names( "add_d6_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. (0/M fail)
#cfg.set_model_names( "add_d7_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d8_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d9_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d10_l2_h3_t40K_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=7e-09. (1/M fail)
#cfg.set_model_names( "add_d10_l2_h3_t40K_gf_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=3.5e-09. GrokFast. Minor accuracy improvement
#cfg.set_model_names( "add_d11_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=8e-09. (2/M fail)
#cfg.set_model_names( "add_d12_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=5e-09. (3/M fail)
#cfg.set_model_names( "add_d13_l2_h3_t50K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=6.3e-08. (1/M fail)
#cfg.set_model_names( "add_d14_l2_h3_t60K_s572091" )  # AddAccuracy=Three9s. AvgFinalLoss=5.6e-06. (199/M fail)
#cfg.set_model_names( "add_d15_l2_h3_t80K_s572091" ) # AddAccuracy=Five9s. AvgFinalLoss=8.6e-08 (10/M fail)
#cfg.set_model_names( "add_d20_l2_h3_t80K_s572091" ) # AddAccuracy=Poor! AvgFinalLoss=0.20!

# Subtraction models
#cfg.set_model_names( "sub_d6_l2_h3_t30K_s372001" )  # SubAccuracy=Six9s. AvgFinalLoss=5.8e-06
#cfg.set_model_names( "sub_d10_l2_h3_t75K_s173289" )  # SubAccuracy=Two9s. AvgFinalLoss=0.002002. (6672/M fails)
#cfg.set_model_names( "sub_d10_l2_h3_t75K_gf_s173289" )  # SubAccuracy=Two9s. GrokFast. AvgFinalLoss=0.001197. (5246/M fails). Minor accuracy improvement

# Mixed (addition and subtraction) models
#cfg.set_model_names( "mix_d5_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=9e-09. (0/M fails, 0/M fails)
#cfg.set_model_names( "mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=5e-09. (1/M fail)
#cfg.set_model_names( "mix_d7_l3_h4_t50K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=2e-08. (2/M fails, 6/M fails)
#cfg.set_model_names( "mix_d8_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=4.7e-08. (0/M fails, 7/M fails)
#cfg.set_model_names( "mix_d9_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Four9s. AvgFinalLoss=3.2e-07. (1/M fails, 33/M fails)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_s173289" )  # Add/SubAccuracy=Five9s/Two9s. AvgFinalLoss=1.125e-06 (2/M fail, 295/M fail)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_gf_s173289" )  # Add/SubAccuracy=Six9s/Three9s. GrokFast. AvgFinalLoss=8.85e-07 (1/M fail, 294/M fail). Minor accuracy improvement
#cfg.set_model_names( "mix_d11_l3_h4_t80K_s572091" )  # Add/SubAccuracy=Six9s/Four9s AvgFinalLoss=3.9e-08 (0/M fail, 13/M fail)
#cfg.set_model_names( "mix_d12_l3_h4_t85K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (2/M fail, 10/M fail)
#cfg.set_model_names( "mix_d13_l3_h4_t85K_s572091" )  # Add/SubAccuracy=Three9s/Two9s. AvgFinalLoss=9.5e-06. (399/M fail, 4164/M fail)

# Mixed models initialized with addition model. Params fine-tuned during training
#cfg.set_model_names( "ins1_mix_d5_l2_h3_t40K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d6_l2_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss = 2.4e-08 (5/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=1.8e-08. (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t80K_s572091" )  # Add/SubAccuracy=Six9s/Five9s AvgLoss = 1.6e-08 (3/M fails)
cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=8e-09. MAIN FOCUS
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s173289" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.4e-08. (3/M fails, 2/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=2.9e-08. (4/M fails)
#cfg.set_model_names( "ins1_mix_d7_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Five9s/Six9s. AvgFinalLoss=1.3e-08. (4/M fails, 1/M fails)
#cfg.set_model_names( "ins1_mix_d8_l3_h4_t70K_s572091" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=7.2e-06. (50/M fails, 1196/M fails)
#cfg.set_model_names( "ins1_mix_d9_l3_h4_t70K_s572091" )  # Add/SubAccuracy=TODO. AvgFinalLoss=TODO. (50/M fails, TODO/M fails)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss 6.3e-07 (6/M fails, 7/M fails)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_gf_s572091" ) # Add/SubAccuracy=Five9s/Two9s. GrokFast. AvgFinalLoss=4.0e-06 (2/M fails, 1196/M fails). Worse accuracy than without GF!
#cfg.set_model_names( "ins1_mix_d11_l3_h4_t75K_s572091" )  # Add/SubAccuracy=TODO. AvgFinalLoss=TODO. (TODO/M fails, TODO/M fails)

# Mixed model initialized with addition model. Reset useful heads every 100 epochs during training
#cfg.set_model_names( "ins2_mix_d6_l4_h4_t40K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (3/M fails, 8/M fails)

# Mixed model initialized with addition model. Reset useful heads & MLPs every 100 epochs during training
#cfg.set_model_names( "ins3_mix_d6_l4_h3_t40K_s372001" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=3.0e-04. (17/M fails, 3120/M fails)

# Mixed model initialized with addition model. Reset useful heads & MLPs every training epoch
#cfg.set_model_names( "ins4_mix_d6_l4_h3_t40K_s372001" )  # AvgFinalLoss=4.7e-06
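The model names above appear to follow a fixed pattern: an optional `insN_` prefix, the operation (`add`, `sub` or `mix`), then `d` digits, `l` layers, `h` heads, `t` training length in thousands, an optional `gf` (GrokFast) flag, and `s` seed. The sketch below decodes a name under that reading. The helper `parse_model_name` and the field names are hypothetical, and the interpretation of `t` as training steps is an assumption; the library's own `cfg.set_model_names` may work differently.

```python
import re

# Assumed naming pattern, inferred from the model names listed above.
MODEL_NAME_RE = re.compile(
    r"^(?:ins(?P<insert_mode>\d+)_)?"
    r"(?P<operation>add|sub|mix)"
    r"_d(?P<n_digits>\d+)_l(?P<n_layers>\d+)_h(?P<n_heads>\d+)"
    r"_t(?P<training_k>\d+)K"
    r"(?P<grokfast>_gf)?"
    r"_s(?P<seed>\d+)$"
)


def parse_model_name(name: str) -> dict:
    """Decode a model name into its (assumed) configuration fields."""
    match = MODEL_NAME_RE.match(name)
    if match is None:
        raise ValueError(f"Unrecognized model name: {name}")
    g = match.groupdict()
    return {
        "insert_mode": int(g["insert_mode"]) if g["insert_mode"] else 0,
        "operation": g["operation"],
        "n_digits": int(g["n_digits"]),
        "n_layers": int(g["n_layers"]),
        "n_heads": int(g["n_heads"]),
        # Assumption: "t40K" means 40,000 training steps.
        "training_steps": int(g["training_k"]) * 1000,
        "grokfast": g["grokfast"] is not None,
        "seed": int(g["seed"]),
    }


print(parse_model_name("ins1_mix_d6_l3_h4_t40K_s372001"))
```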

# Part 1B: Configuration: Input and Output file names



In [14]:
cfg.batch_size = 512 # Default analysis batch size
if cfg.n_layers >= 3 and cfg.n_heads >= 4:
  cfg.batch_size = 256 # Reduce batch size to avoid memory constraint issues.

cfg.set_seed(cfg.analysis_seed)

In [15]:
base_repo_name = 'PhilipQuirke'
model_pth_fname = 'model.pth'
model_training_loss_fname = 'training_loss.json'
model_sae_pth_fname = 'model_sae.pth'

cfg.hf_repo = base_repo_name + "/QuantaMaths_" + cfg.model_name

print('Model files will be read from HuggingFace repo:', base_repo_name, model_pth_fname, 'and', model_training_loss_fname)
print('Saving sae model to temporary Colab file:', model_sae_pth_fname)

Model files will be read from HuggingFace repo: PhilipQuirke model.pth and training_loss.json
Saving sae model to temporary Colab file: model_sae.pth


In [16]:
# Update "cfg" with additional training config (including cfg.insert_*) information from existing file:
#      https://huggingface.co/PhilipQuirke/QuantaMaths_ins1_mix_d6_l3_h4_t40K_s372001/training_loss.json
training_data_json = qmi.download_huggingface_json(cfg.hf_repo, model_training_loss_fname)
training_loss_list = qmi.load_training_json(cfg, training_data_json)

training_loss.json:   0%|          | 0.00/901k [00:00<?, ?B/s]

In [17]:
def print_config():
  print("%Add=", cfg.perc_add, "%Sub=", cfg.perc_sub, "%Mult=", cfg.perc_mult, "InsertMode=", cfg.insert_mode, "File=", cfg.model_name)

In [18]:
print_config()
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)

%Add= 20 %Sub= 80 %Mult= 0 InsertMode= 1 File= ins1_mix_d6_l3_h4_t40K_s372001
weight_decay= 0.1 lr= 8e-05 batch_size= 64


# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

  

In [19]:
cfg.initialize_maths_token_positions()
mmi.set_maths_vocabulary(cfg)
mmi.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

['D5', 'D4', 'D3', 'D2', 'D1', 'D0', 'OPR', "D'5", "D'4", "D'3", "D'2", "D'1", "D'0", '=', 'A7', 'A6', 'A5', 'A4', 'A3', 'A2', 'A1', 'A0']
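One way to read the labels printed above: `D5`..`D0` are the digits of the first operand, `OPR` is the operator, `D'5`..`D'0` are the digits of the second operand, `=` ends the question, and `A7`..`A0` are the eight answer tokens (with `A7` presumably holding the sign). The hypothetical helper below, which is not part of the library, rebuilds the same list for a given digit count.

```python
def token_position_labels(n_digits: int = 6) -> list[str]:
    """Build position labels for an n_digits question and its answer."""
    first = [f"D{i}" for i in range(n_digits - 1, -1, -1)]
    second = [f"D'{i}" for i in range(n_digits - 1, -1, -1)]
    # The answer has a sign token plus n_digits + 1 digit tokens.
    answer = [f"A{i}" for i in range(n_digits + 1, -1, -1)]
    return first + ["OPR"] + second + ["="] + answer


labels = token_position_labels(6)
print(len(labels), labels)
```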


# Part 3B: Set Up: Create model

In [20]:
# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = cfg.get_HookedTransformerConfig()

# Create the main transformer model
cfg.main_model = HookedTransformer(ht_cfg)

# Part 5: Set Up: Load Model from HuggingFace

In [22]:
print("Loading model from HuggingFace:", cfg.hf_repo, model_pth_fname)

cfg.main_model.load_state_dict(download_file_from_hf(repo_name=cfg.hf_repo, file_name=model_pth_fname, force_is_torch=True))
cfg.main_model.eval()

Loading model from HuggingFace: PhilipQuirke/QuantaMaths_ins1_mix_d6_l3_h4_t40K_s372001 model.pth


model.pth:   0%|          | 0.00/41.8M [00:00<?, ?B/s]

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

# Part 6: Train an SAE

Train an SAE that ideally has:
- Low loss and MSE (good reconstruction; this matters most)
- Low sparsity (sparsity is the share of neurons that activate, so low sparsity means few neurons are active for any given prediction)
- A low number of active neurons (makes interpretation easier)
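To make these criteria concrete, here is a minimal pure-Python sketch of how such metrics could be computed from a batch of SAE activations. It assumes a conventional sparse-autoencoder objective (MSE, a KL sparsity penalty toward `sparsity_target`, and an L1 penalty); the function `sae_metrics` is hypothetical, and `mmi.analyze_mlp_with_sae` may define its loss, sparsity and score differently.

```python
import math


def sae_metrics(inputs, reconstructions, activations, sparsity_target=0.05, eps=1e-8):
    """Compute illustrative SAE quality metrics from nested lists of floats."""
    n_samples = len(activations)
    n_neurons = len(activations[0])

    # Reconstruction error: mean squared error over all samples and dimensions.
    sq_errors = [
        (x - r) ** 2
        for xs, rs in zip(inputs, reconstructions)
        for x, r in zip(xs, rs)
    ]
    mse = sum(sq_errors) / len(sq_errors)

    # Sparsity: fraction of (sample, neuron) pairs with a non-zero activation.
    active = sum(1 for row in activations for a in row if a > 0)
    sparsity = active / (n_samples * n_neurons)

    # Neurons used: neurons that fire on at least one sample.
    neurons_used = sum(
        1 for j in range(n_neurons) if any(row[j] > 0 for row in activations)
    )

    # L1 penalty: mean absolute activation.
    l1 = sum(abs(a) for row in activations for a in row) / (n_samples * n_neurons)

    # KL penalty between the target rate and each neuron's mean activation,
    # with the mean clamped into (0, 1) to keep the logarithms finite.
    kl = 0.0
    rho = sparsity_target
    for j in range(n_neurons):
        rho_hat = sum(row[j] for row in activations) / n_samples
        rho_hat = min(max(rho_hat, eps), 1 - eps)
        kl += rho * math.log(rho / rho_hat)
        kl += (1 - rho) * math.log((1 - rho) / (1 - rho_hat))

    return {"mse": mse, "sparsity": sparsity, "neurons_used": neurons_used, "l1": l1, "kl": kl}


metrics = sae_metrics(
    inputs=[[1.0, 0.0], [0.0, 1.0]],
    reconstructions=[[0.9, 0.1], [0.0, 1.0]],
    activations=[[0.5, 0.0, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
)
print(metrics)
```

In this toy batch only two of the four neurons ever fire, so `neurons_used` is 2 and a quarter of all activations are non-zero.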


In [23]:
dataloader = mmi.get_mixed_maths_dataloader(cfg, num_batches=1000, enrich_data=True)
print("Data set size", len(dataloader.dataset))

Data set size 64000


In [None]:
sae, score, loss, sparsity, neurons_used = mmi.analyze_mlp_with_sae(cfg, dataloader, layer_num=0, encoding_dim=32, learning_rate=5e-4, sparsity_target=0.1, sparsity_weight=1e-3, num_epochs=10)
print( f"Score: {score:.4f}, Loss {loss:.4f}, Sparsity {sparsity:.4f}, Neurons Used: {neurons_used}.")

Epoch: 1, Score: 9.0763, Loss: 0.0736, MSE: 0.0686, Sparsity Penalty: 4.9361, L1 Penalty: 0.3280, Sparsity: 71.69%, Neurons used: 32/32 (100.00%)


In [None]:
sae, score, loss, sparsity, neurons_used = mmi.analyze_mlp_with_sae(cfg, dataloader, layer_num=0, encoding_dim=64, learning_rate=0.001, sparsity_target=0.05, sparsity_weight=0.1, num_epochs=10)
print( f"Score: {score:.4f}, Loss {loss:.4f}, Sparsity {sparsity:.4f}, Neurons Used: {neurons_used}.")

# Part 7: Sweep hyperparams to find the best SAE

Sweep hyperparameters to train and score multiple SAEs and find the best-scoring one. This is slow.

In [None]:
param_grid = {
    'encoding_dim': [64], #[32, 64, 128, 256, 512],
    'learning_rate': [1e-3], # [1e-4, 1e-3, 1e-2],
    'sparsity_target': [0.05, 0.1], # [0.001, 0.005, 0.01, 0.05, 0.1],
    'sparsity_weight': [0.1], #[1e-3, 1e-2, 1e-1, 1.0],
    'l1_weight': [1e-4, 1e-3, 1e-2],  # [1e-6, 1e-5, 1e-4, 1e-3, 1e-2],
    'num_epochs': [10],
    'patience': [2]
}

In [None]:
num_experiments = 1
for param_values in param_grid.values():
    num_experiments *= len(param_values)

print(f"Number of configurations to test: {num_experiments}")
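The count above is the size of the Cartesian product of the value lists. As a sketch, assuming the sweep visits every combination (how `mmi.optimize_sae_hyperparameters` iterates internally is not shown here), the individual configurations could be enumerated with `itertools.product`. The helper `expand_grid` and the reduced `example_grid` are illustrative only.

```python
import itertools


def expand_grid(grid: dict) -> list[dict]:
    """Return one dict per combination of values in the grid."""
    keys = list(grid)
    value_lists = [grid[k] for k in keys]
    return [dict(zip(keys, values)) for values in itertools.product(*value_lists)]


# Reduced illustrative grid, not the full param_grid used in this notebook.
example_grid = {
    "sparsity_target": [0.05, 0.1],
    "l1_weight": [1e-4, 1e-3, 1e-2],
    "num_epochs": [10],
}

configs = expand_grid(example_grid)
print(len(configs))  # 2 * 3 * 1 = 6 combinations
for config in configs:
    print(config)
```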

In [None]:
sae, score, neurons_used, params = mmi.optimize_sae_hyperparameters(cfg, dataloader, layer_num=0, param_grid=param_grid)

# Part 8: Visualize the SAE


In [None]:
mmi.analyze_and_visualize_sae(cfg, sae, dataloader, layer_num=0, max_samples=1000, perplexity=30, n_iter=250)