# Verified Integer Arithmetic in Transformers - Training Graphs

This Colab graphs training data across multiple models.

This Colab follows on from https://github.com/PhilipQuirke/quanta_addition_and_subtraction/blob/main/notebooks/VerifiedArithmeticTrain.ipynb , which trains the models and outputs the files model_name.pth and model_name_train.json.



## Tips for using the Colab
 * You can run and alter the code in this Colab notebook yourself in Google Colab ( https://colab.research.google.com/ ).
 * To run the notebook in Google Colab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the "quanta_tools" library (installed from the quanta_mech_interp repository) as "qt" and the "maths_tools" library as "mt". These libraries are specific to this Colab's "QuantaTool" approach to transformer analysis. Refer to [README.md](https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md) for more detail.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    !pip install matplotlib

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy
    !pip install scikit-learn

except:
    IN_COLAB = False

    def setup_jupyter(install_libraries=False):
        if install_libraries:
            !pip install matplotlib==3.8.4
            !pip install kaleido==0.2.1
            !pip install transformer_lens==1.15.0
            !pip install torchtyping==0.1.4
            !pip install transformers==4.39.3

            !pip install numpy==1.26.4
            !pip install plotly==5.20.0
            !pip install pytest==8.1.1
            !pip install scikit-learn==1.4.1.post1

        print("Running as a Jupyter notebook - intended for development only!")
        from IPython import get_ipython

        ipython = get_ipython()
        # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
        ipython.magic("load_ext autoreload")
        ipython.magic("autoreload 2")

    # setup_jupyter(install_libraries=True)   # Uncomment if you need to install libraries in notebook.
    setup_jupyter(install_libraries=False)

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Colab sessions (and any non-development run) use the "colab" renderer;
# local development notebooks use "notebook_connected".
pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the axis-title and figure-title fonts of the built-in "plotly" template.
plotly_template_layout = pio.templates['plotly'].layout
plotly_template_layout.xaxis.title.font.size = 20
plotly_template_layout.yaxis.title.font.size = 20
plotly_template_layout.title.font.size = 30

In [None]:
import json
import torch
import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
!pip install scikit-optimize

import re
import sklearn # Aka scikit.learn
import skopt # Aka scikit.optimize

In [None]:
import transformer_lens
from transformer_lens.utils import download_file_from_hf
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import quanta_tools as qt

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_addition_subtraction.git
import maths_tools as mt

# Part 1A: Configuration

Which existing model do we want to graph?

The existing model weights created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace (in Part 5). Refer to https://github.com/PhilipQuirke/quanta_addition_and_subtraction/blob/main/README.md for more detail.

In [None]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
# This single `cfg` object is shared by the later cells: set_model_names() below selects the model,
# and qt.load_training_json() is later handed this same object together with the downloaded training JSON.
cfg = mt.MathsConfig()


# Which model do we want to graph? Uncomment one line:

# Addition models
#cfg.set_model_names( "add_d5_l1_h3_t15K_s372001" )  # AddAccuracy=Two9s. Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions.
#cfg.set_model_names( "add_d5_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.6e-08
#cfg.set_model_names( "add_d5_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t15K_s372001" )  # AddAccuracy=Fives. AvgFinalLoss=1.7e-08. (2/M fail: 018539+789353=+0807892 ModelAnswer: +0707892, 747332+057349=+0804681 ModelAnswer: +0704681)
#cfg.set_model_names( "add_d6_l2_h3_t20K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=1.5e-08. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t20K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=7e-09
#cfg.set_model_names( "add_d6_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09
#cfg.set_model_names( "add_d7_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d8_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d9_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d10_l2_h3_t40K_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=7e-09. (1/M fail: 0000000555+0000000445=+00000001000 ModelAnswer: +00000000900)
#cfg.set_model_names( "add_d10_l2_h3_t40K_gf_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=3.5e-09. GrokFast.
#cfg.set_model_names( "add_d11_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=8e-09. (2/M fail)
#cfg.set_model_names( "add_d12_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=5e-09. (3/M fail)
#cfg.set_model_names( "add_d13_l2_h3_t50K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=6.3e-08. (1/M fail)
# ACTIVE SELECTION: the only uncommented set_model_names() call in this cell.
cfg.set_model_names( "add_d14_l2_h3_t60K_s572091" )  # AddAccuracy=Three9s. AvgFinalLoss=5.6e-06. (199/M fail)
#cfg.set_model_names( "add_d15_l2_h3_t80K_s572091" ) # AddAccuracy=Five9s. AvgFinalLoss=8.6e-08 (10/M fail)
#cfg.set_model_names( "add_d20_l2_h3_t80K_s572091" ) # AddAccuracy=Poor! AvgFinalLoss=0.20!

# Subtraction model
#cfg.set_model_names( "sub_d6_l2_h3_t30K_s372001" )  # SubAccuracy=Six9s. AvgFinalLoss=5.8e-06
#cfg.set_model_names( "sub_d10_l2_h3_t75K_s173289" )  # SubAccuracy=Two9s. (6672/M fails) AvgFinalLoss=0.002002.
#cfg.set_model_names( "sub_d10_l2_h3_t75K_gf_s173289" )  # SubAccuracy=Two9s. GrokFast. (5246/M fails) AvgFinalLoss=0.001197

# Mixed (addition and subtraction) model
#cfg.set_model_names( "mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=5e-09. (1/M fail: 463687+166096=+0629783 ModelAnswer: +0639783)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_s173289" )  # Add/SubAccuracy=Five9s/Two9s. AvgFinalLoss=1.125e-06 (2/M fail: 3301956441+6198944455=+09500900896 ModelAnswer: +09500800896) (295/M fail: 8531063649-0531031548=+08000032101 ModelAnswer: +07900032101)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_gf_s173289" )  # Add/SubAccuracy=Six9s/Three9s. GrokFast. AvgFinalLoss 8.85e-07 (1/M fail) (294/M fail)

# Mixed models initialized with addition model
#cfg.set_model_names( "ins1_mix_d6_l2_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss = 2.4e-08 (5/M fails e.g. 565000-364538=+0200462 ModelAnswer: +0100462)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=1.8e-08. (3/M fails e.g. 072074-272074=-0200000 ModelAnswer: +0200000)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t80K_s572091" )  # Add/SubAccuracy=Six9s/Five9s AvgLoss = 1.6e-08 (3/M fails e.g. 229672-229678=-0000006 ModelAnswer: +0000006) (EnrichFalse => 0/M, 4/M)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=8e-09. MAIN FOCUS
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s173289" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.4e-08. (3/M fails e.g. 850038+159060=+1009098 ModelAnswer: +0009098) (2/M fails e.g. 77285-477285=+0100000 Q: ModelAnswer: +0000000) (EnrichFalse => 0/M, 3/M)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=2.9e-08. (4/M fails e.g. 986887-286887=+0700000 ModelAnswer: +0600000) (EnrichFalse => 0/M, 3/M)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss 6.3e-07  (6/M fails e.g. 5068283822+4931712829=+09999996651 ModelAnswer: +19099996651) (7/M fails e.g. 3761900218-0761808615=+03000091603 ModelAnswer: +02000091603)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_gf_s572091" ) # Add/SubAccuracy=Five9s/Two9s. GrokFast. AvgFinalLoss=4.0e-06 (2/M fails) (1196/M fails)

# Mixed model initialized with addition model. Reset useful heads every 100 epochs.
#cfg.set_model_names( "ins2_mix_d6_l4_h4_t40K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (3/M fails e.g. 530757+460849=+0991606 ModelAnswer: +0091606) (8 fails e.g. 261926-161857=+0100069 ModelAnswer: +0000069)

# Mixed model initialized with addition model. Reset useful heads & MLPs every 100 epochs.
#cfg.set_model_names( "ins3_mix_d6_l4_h3_t40K_s372001" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=3.0e-04. (17/M fails e.g. 273257+056745=+0330002 ModelAnswer: +0320002) (3120 fails e.g. 09075-212133=-0003058 ModelAnswer: +0003058)

# Addition&Subtraction model initialized with addition model. Reset useful heads & MLPs every epoch
#cfg.set_model_names( "ins4_mix_d6_l4_h3_t40K_s372001" )  # AvgFinalLoss 4.7e-06

# Mixed models initialized with addition model. Insert mode 5
#cfg.set_model_names( "ins5_mix_d6_l3_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO
#cfg.set_model_names( "ins5_mix_d6_l2_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO

# Part 1B: Configuration: Input and Output file names



In [None]:
# Derive every per-model file name from cfg.model_name by appending its suffix.
(
    main_fname_pth,
    main_fname_train_json,
    main_fname_behavior_json,
    main_fname_maths_json,
) = (
    cfg.model_name + file_suffix
    for file_suffix in ('.pth', '_train.json', '_behavior.json', '_maths.json')
)
# HuggingFace repository that hosts the files named above.
main_repo_name = "PhilipQuirke/VerifiedArithmetic"

In [None]:
# Update "cfg" with additional training config (including cfg.insert_*) with information stored in
# the selected model's "<model_name>_train.json" file in the HuggingFace repository, for example:
#      https://huggingface.co/PhilipQuirke/VerifiedArithmetic/raw/main/ins1_mix_d6_l3_h4_t40K_s372001_train.json
training_data_json = qt.download_huggingface_json(main_repo_name, main_fname_train_json)
# The returned list is the loss series that the Part 2 cell passes to qt.plot_loss_lines().
training_loss_list = qt.load_training_json(cfg, training_data_json)

In [None]:
# Report which HuggingFace files belong to the selected model.
# NOTE(review): no cell visible in this notebook section downloads the .pth weights;
# only the *_train.json file is fetched - confirm the first message is still accurate.
print('Main model will be read from HuggingFace file', main_repo_name, main_fname_pth)
print('Main model training config / loss will be read from HuggingFace file', main_fname_train_json)

# Part 2: Single model training loss graphs

In [None]:
# Summarise the configuration held in cfg: question mix, insert mode and model name,
# then the optimiser settings, then the recorded final-loss figures.
print("%Add=", cfg.perc_add, "%Sub=", cfg.perc_sub, "%Mult=", cfg.perc_mult, "InsertMode=", cfg.insert_mode, "File=", cfg.model_name)
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)
print( "Avg loss over last 5 epochs (AvgFinalLoss)", cfg.avg_final_loss)
print( "Final epoch loss", cfg.final_loss)

In [None]:
# Graph the selected model's training loss twice: an early linear-scale view,
# then the whole run on a log axis (the log-scale figure is also written to disk).
if training_loss_list:
    answer_digits = cfg.n_digits + 1
    title_font_size = 32
    tick_font_size = 24

    # Linear y-axis, restricted to the first 1500 entries of the loss list.
    qt.plot_loss_lines(
        cfg,
        1500,
        [training_loss_list[:1500]],
        labels=['All'],
        log_y=False,
        title='Training Loss',
        title_font_size=title_font_size,
        tick_font_size=tick_font_size,
    )

    # Log y-axis over cfg.n_training_steps; keep the figure so it can be exported.
    full_title, fig = qt.plot_loss_lines(
        cfg,
        cfg.n_training_steps,
        [training_loss_list],
        labels=['All'],
        log_y=True,
        title='Training Loss',
        title_font_size=title_font_size,
        tick_font_size=tick_font_size,
    )
    pio.write_image(fig, cfg.model_name + '_LogTrainingLoss.' + cfg.graph_file_suffix)

# Part 3: Multiple Addition model training loss graphs

In [None]:
# Model names (without the "_train.json" suffix) whose training losses are compared.
json_file_paths = [
    "add_d5_l2_h3_t40K_s372001",
    "add_d6_l2_h3_t40K_s372001",
    "add_d7_l2_h3_t45K_s173289",
    "add_d8_l2_h3_t45K_s173289",
    "add_d9_l2_h3_t45K_s173289",
    "add_d10_l2_h3_t40K_s572091",
    "add_d11_l2_h3_t50K_s572091",
    "add_d12_l2_h3_t50K_s572091",
    "add_d13_l2_h3_t50K_s572091",
    "add_d14_l2_h3_t60K_s572091",
    "add_d15_l2_h3_t80K_s572091",
]

all_training_loss_lists = []
model_labels = []

# Load and process the training loss data from each JSON file.
# Loop-local names and a per-model config object are used on purpose: the earlier
# cells' globals (cfg, training_data_json, training_loss_list) must keep describing
# the selected main model, so re-running those cells after this one stays correct.
for file_path in json_file_paths:
    model_cfg = mt.MathsConfig()
    model_cfg.set_model_names(file_path)
    model_train_json = qt.download_huggingface_json(main_repo_name, file_path + "_train.json")
    model_loss_list = qt.load_training_json(model_cfg, model_train_json)
    all_training_loss_lists.append(model_loss_list)
    # The label is the second "_"-separated token of the name, e.g. "d5" from "add_d5_l2_...".
    model_labels.append(file_path.split("_")[1])

In [None]:
def smooth_data(data, window_size=500, min_periods=None):
    """Compute centred rolling mean, minimum and maximum of a 1-D sequence.

    Args:
        data: sequence of numbers (anything accepted by ``pd.Series``).
        window_size: number of points in each centred rolling window.
        min_periods: minimum number of observations a window needs to yield a
            value. ``None`` (the default) keeps pandas' default for a
            fixed-size window, which is the window size itself, so roughly
            half a window at each end of the result (and the whole result when
            ``len(data) < window_size``) is NaN. Pass ``1`` to get values all
            the way to both edges.

    Returns:
        Tuple ``(smoothed_mean, smoothed_min, smoothed_max)`` of ``pd.Series``
        with the same length and index as ``pd.Series(data)``.
    """
    # Build the rolling window once and reuse it for all three statistics.
    windowed = pd.Series(data).rolling(
        window=window_size, center=True, min_periods=min_periods
    )
    return windowed.mean(), windowed.min(), windowed.max()

def hex_to_rgba(hex_color, alpha=0.2):
    """Turn a hex color such as '#1f77b4' into a CSS-style 'rgba(r,g,b,alpha)' string.

    Leading '#' characters are stripped; the first three pairs of hex digits
    are read as the red, green and blue channels.
    """
    digits = hex_color.lstrip('#')
    # Parse the three two-character channel fields in order: red, green, blue.
    r, g, b = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return f'rgba({r},{g},{b},{alpha})'

In [None]:
# Two stacked subplots of the same smoothed loss curves: row 1 uses a linear y-axis
# with the x-axis limited to the first 15K steps, row 2 uses a log y-axis over the full run.
fig = make_subplots(rows=2, cols=1, subplot_titles=("Training Loss Comparison (First 15K Steps)", "Training Loss Comparison (Log Scale)"))

# Color palette for different traces
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#aec7e8']

for i, loss_list in enumerate(all_training_loss_lists):
    smooth_mean, smooth_min, smooth_max = smooth_data(loss_list)
    x_vals = list(range(len(loss_list)))
    label = model_labels[i]
    # Wrap around the palette so that more models than colors does not raise IndexError.
    color = colors[i % len(colors)]

    # The same three traces are drawn on both subplots (row 1: regular scale,
    # row 2: log scale). Trace order matters: the lower bound fills up to the
    # trace added immediately before it on the same subplot.
    for row in (1, 2):
        # Invisible upper bound; it only serves as the target of the 'tonexty' fill.
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_max,
                mode='lines',
                line=dict(width=0),
                showlegend=False,
                legendgroup=label,
                name=f'{label} Upper Bound'
            ),
            row=row, col=1
        )

        # Invisible lower bound; shades the min-max band in a translucent model color.
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_min,
                mode='lines',
                line=dict(width=0),
                fillcolor=hex_to_rgba(color),
                fill='tonexty',
                showlegend=False,
                legendgroup=label,
                name=f'{label} Lower Bound'
            ),
            row=row, col=1
        )

        # Add mean line. Only the row-1 line gets a legend entry; the shared
        # legendgroup makes that entry toggle all of this model's traces.
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_mean,
                mode='lines',
                line=dict(color=color, width=2),
                showlegend=(row == 1),
                legendgroup=label,
                name=label
            ),
            row=row, col=1
        )

# Update layout
fig.update_layout(
    height=1000,
    showlegend=True,
    legend=dict(
        yanchor="top",
        y=0.95,
        xanchor="right",
        x=0.95
    ),
    template="plotly_white"
)

# Update axes labels and ranges
fig.update_xaxes(title_text="Training Step", row=1, col=1, range=[0, 15000])
fig.update_xaxes(title_text="Training Step", row=2, col=1, range=[0, 80000])
fig.update_yaxes(title_text="Training Loss", row=1, col=1)
fig.update_yaxes(title_text="Training Loss (Log Scale)", type="log", row=2, col=1)

fig.show()