# Quanta Maths: Integer Addition and Subtraction in Transformers - Training Graphs

This Colab graphs training data across multiple models.

This Colab follows on from https://github.com/PhilipQuirke/quanta_maths/blob/main/notebooks/QMTrain.ipynb, which trained the models and output model.pth and training_loss.json.



## Tips for using the Colab
 * You can run and alter the code in this Colab notebook yourself in Google Colab (https://colab.research.google.com/).
 * To run the notebook in Google Colab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the "QuantaMechInterp" public library as "qmi" and the "MathsMechInterp" public library as "mmi". Refer to [README.md](https://github.com/PhilipQuirke/Quanta_Maths/blob/main/README.md) for more detail.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
 
    !pip install matplotlib
    !pip install kaleido

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    #!pip install matplotlib==3.8.4
    #!pip install kaleido==0.2.1
    # Install required libraries if not already installed
    import sys
    import subprocess
    def install(package):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

    # List of required packages
    required_packages = ["matplotlib", "kaleido", "plotly", "nbformat", "scikit-learn", "scikit-optimize", "huggingface_hub", "torch", "numpy"]
    for pkg in required_packages:
        try:
            __import__(pkg.replace('-', '_'))
        except ImportError:
            print(f"Installing {pkg}...")
            install(pkg)

    from IPython import get_ipython
    ipython = get_ipython()
    %load_ext autoreload
    %autoreload 2

    # Uncomment below to force reinstall if needed
    # for pkg in required_packages:
    #     install(pkg)

In [None]:
# Import the QuantaMechInterp library as "qmi", installing it from GitHub first
# if it is not already importable in this environment.
try:
    import QuantaMechInterp as qmi
except ImportError:
    import sys
    # IPython shell escape: run pip via the same Python executable as this kernel
    !{sys.executable} -m pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
    import QuantaMechInterp as qmi

In [None]:
# Import the MathsMechInterp library as "mmi", installing it from GitHub first
# if it is not already importable in this environment.
try:
    import MathsMechInterp as mmi
except ImportError:
    import sys
    # IPython shell escape: run pip via the same Python executable as this kernel
    !{sys.executable} -m pip install --upgrade git+https://github.com/PhilipQuirke/quanta_maths.git
    import MathsMechInterp as mmi

In [None]:
# Plotly needs a different renderer in VSCode/local notebooks than in Colab.
import kaleido  # NOTE(review): presumably imported so pio.write_image can export static images - confirm
import plotly.io as pio

# Only a local development notebook (outside Colab) uses "notebook_connected";
# every other case falls back to Plotly's Colab renderer.
pio.renderers.default = (
    "notebook_connected" if DEVELOPMENT_MODE and not IN_COLAB else "colab"
)
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the axis-title and chart-title fonts of the default "plotly" template
plotly_layout = pio.templates['plotly'].layout
for axis in (plotly_layout.xaxis, plotly_layout.yaxis):
    axis.title.font.size = 20
plotly_layout.title.font.size = 30

In [None]:
import json
import torch
import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
import re
import sklearn # Aka scikit.learn
import skopt # Aka scikit.optimize

In [None]:
import transformer_lens
from transformer_lens.utils import download_file_from_hf
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Part 1A: Configuration

Which existing model do we want to graph?

The training loss data for the chosen model, created by the sister Colab [QuantaMathsTrain](https://github.com/PhilipQuirke/Quanta_Maths/blob/main/assets/QuantaMathsTrain.ipynb), is loaded from HuggingFace (in Part 1B). Refer to https://github.com/PhilipQuirke/quanta_maths/blob/main/README.md for more detail.

In [None]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
cfg = mmi.MathsConfig()


# Which model do we want to graph? Uncomment exactly one cfg.set_model_names(...) line.
# The trailing comment on each line records notes on that model: its accuracy rating,
# AvgFinalLoss, and failure counts ("/M", presumably per million test questions)
# with example failing questions where known.

# Addition models
#cfg.set_model_names( "add_d5_l1_h3_t15K_s372001" )  # AddAccuracy=Two9s. Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions.
#cfg.set_model_names( "add_d5_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.6e-08
#cfg.set_model_names( "add_d5_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t15K_s372001" )  # AddAccuracy=Fives. AvgFinalLoss=1.7e-08. (2/M fail: 018539+789353=+0807892 ModelAnswer: +0707892, 747332+057349=+0804681 ModelAnswer: +0704681)
#cfg.set_model_names( "add_d6_l2_h3_t20K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=1.5e-08. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t20K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=7e-09
#cfg.set_model_names( "add_d6_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09
#cfg.set_model_names( "add_d7_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d8_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d9_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d10_l2_h3_t40K_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=7e-09. (1/M fail: 0000000555+0000000445=+00000001000 ModelAnswer: +00000000900)
#cfg.set_model_names( "add_d10_l2_h3_t40K_gf_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=3.5e-09. GrokFast.
#cfg.set_model_names( "add_d11_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=8e-09. (2/M fail)
#cfg.set_model_names( "add_d12_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=5e-09. (3/M fail)
#cfg.set_model_names( "add_d13_l2_h3_t50K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=6.3e-08. (1/M fail)
cfg.set_model_names( "add_d14_l2_h3_t60K_s572091" )  # AddAccuracy=Three9s. AvgFinalLoss=5.6e-06. (199/M fail)
#cfg.set_model_names( "add_d15_l2_h3_t80K_s572091" ) # AddAccuracy=Five9s. AvgFinalLoss=8.6e-08 (10/M fail)
#cfg.set_model_names( "add_d20_l2_h3_t80K_s572091" ) # AddAccuracy=Poor! AvgFinalLoss=0.20!

# Subtraction model
#cfg.set_model_names( "sub_d6_l2_h3_t30K_s372001" )  # SubAccuracy=Six9s. AvgFinalLoss=5.8e-06
#cfg.set_model_names( "sub_d10_l2_h3_t75K_s173289" )  # SubAccuracy=Two9s. (6672/M fails) AvgFinalLoss=0.002002.
#cfg.set_model_names( "sub_d10_l2_h3_t75K_gf_s173289" )  # SubAccuracy=Two9s. GrokFast. (5246/M fails) AvgFinalLoss=0.001197

# Mixed (addition and subtraction) model
#cfg.set_model_names( "mix_d5_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=9e-09. (0/M fails, 0/M fails)
#cfg.set_model_names( "mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=5e-09. (1/M fail)
#cfg.set_model_names( "mix_d7_l3_h4_t50K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=2e-08. (2/M fails, 6/M fails)
#cfg.set_model_names( "mix_d8_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=4.7e-08. (0/M fails, 7/M fails)
#cfg.set_model_names( "mix_d9_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Four9s. AvgFinalLoss=3.2e-07. (1/M fails, 33/M fails)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_s173289" )  # Add/SubAccuracy=Five9s/Two9s. AvgFinalLoss=1.125e-06 (2/M fail, 295/M fail)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_gf_s173289" )  # Add/SubAccuracy=Six9s/Three9s. GrokFast. AvgFinalLoss=8.85e-07 (1/M fail, 294/M fail)
#cfg.set_model_names( "mix_d11_l3_h4_t80K_s572091" )  # Add/SubAccuracy=Six9s/Four9s AvgFinalLoss=3.9e-08 (0/M fail, 13/M fail)
#cfg.set_model_names( "mix_d12_l3_h4_t85K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "mix_d13_l3_h4_t85K_s572091" )  # Add/SubAccuracy=TODO

# Mixed models initialized with addition model
#cfg.set_model_names( "ins1_mix_d5_l2_h3_t40K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d6_l2_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss = 2.4e-08 (5/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=1.8e-08. (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t80K_s572091" )  # Add/SubAccuracy=Six9s/Five9s AvgLoss = 1.6e-08 (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=8e-09. MAIN FOCUS
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s173289" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.4e-08. (3/M fails, 2/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=2.9e-08. (4/M fails)
#cfg.set_model_names( "ins1_mix_d7_l3_h4_t50K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d8_l3_h4_t70K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d9_l3_h4_t70K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss 6.3e-07 (6/M fails, 7/M fails)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_gf_s572091" ) # Add/SubAccuracy=Five9s/Two9s. GrokFast. AvgFinalLoss=4.0e-06 (2/M fails, 1196/M fails)
#cfg.set_model_names( "ins1_mix_d11_l3_h4_t75K_s572091" )  # Add/SubAccuracy=TODO

# Mixed model initialized with addition model. Reset useful heads every 100 epochs.
#cfg.set_model_names( "ins2_mix_d6_l4_h4_t40K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (3/M fails e.g. 530757+460849=+0991606 ModelAnswer: +0091606) (8 fails e.g. 261926-161857=+0100069 ModelAnswer: +0000069)

# Mixed model initialized with addition model. Reset useful heads & MLPs every 100 epochs.
#cfg.set_model_names( "ins3_mix_d6_l4_h3_t40K_s372001" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=3.0e-04. (17/M fails e.g. 273257+056745=+0330002 ModelAnswer: +0320002) (3120 fails e.g. 09075-212133=-0003058 ModelAnswer: +0003058)

# Addition&Subtraction model initialized with addition model. Reset useful heads & MLPs every epoch
#cfg.set_model_names( "ins4_mix_d6_l4_h3_t40K_s372001" )  # AvgFinalLoss 4.7e-06

# Mixed models initialized with addition model. Insert mode 5
#cfg.set_model_names( "ins5_mix_d6_l3_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO
#cfg.set_model_names( "ins5_mix_d6_l2_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO

# Part 1B: Configuration: Input and Output file names



In [None]:
# HuggingFace owner and the standard file names stored in each model's repo
base_repo_name = 'PhilipQuirke'
model_pth_fname = 'model.pth'
model_training_loss_fname = 'training_loss.json'
model_behaviors_fname = 'behaviors.json'
model_features_fname = 'features.json'

# Each model has its own repo, named "<owner>/QuantaMaths_<model name>"
cfg.hf_repo = base_repo_name + "/QuantaMaths_" + cfg.model_name

# Report the full repo id actually read from (not just the owner name)
print('Model files will be read from HuggingFace repo:', cfg.hf_repo, 'files:', model_pth_fname, 'and', model_training_loss_fname)
print('Model analysis tags will be read from HuggingFace repo:', cfg.hf_repo, 'files:', model_behaviors_fname, 'and', model_features_fname)

In [None]:
# Update "cfg" with additional training config (including cfg.insert_*) using the
# information stored in the selected model's training_loss.json, e.g.:
#      https://huggingface.co/.../QuantaMaths_ins1_mix_d6_l3_h4_t40K_s372001/training_loss.json
# training_loss_list holds the loss history graphed in Part 2 (presumably one value per training step).
training_data_json = qmi.download_huggingface_json(cfg.hf_repo, model_training_loss_fname)
training_loss_list = qmi.load_training_json(cfg, training_data_json)

# Part 2: Single model training loss graphs

In [None]:
# Summarise the selected model's data mix, optimiser settings and final loss.
# One label/value pair per line; print() joins them with single spaces.
print(
    "%Add=", cfg.perc_add,
    "%Sub=", cfg.perc_sub,
    "%Mult=", cfg.perc_mult,
    "InsertMode=", cfg.insert_mode,
    "File=", cfg.model_name,
)
print(
    "weight_decay=", cfg.weight_decay,
    "lr=", cfg.lr,
    "batch_size=", cfg.batch_size,
)
print("Avg loss over last 5 epochs (AvgFinalLoss)",
      cfg.avg_final_loss)
print("Final epoch loss",
      cfg.final_loss)

In [None]:
# Show the model final training loss and graph loss over epochs:
# a linear-scale view of the first training steps, then a log-scale view of
# the whole run, which is also written to an image file.
if training_loss_list:
    title_font_size = 32
    tick_font_size = 24
    # Number of initial training steps shown on the linear-scale plot
    early_steps = 1500

    qmi.plot_loss_lines(cfg, early_steps, [training_loss_list[:early_steps]], labels=['All'], log_y=False,
                        title='Training Loss', title_font_size=title_font_size, tick_font_size=tick_font_size)

    full_title, fig = qmi.plot_loss_lines(cfg, cfg.n_training_steps, [training_loss_list], labels=['All'], log_y=True,
                                          title='Training Loss', title_font_size=title_font_size, tick_font_size=tick_font_size)
    pio.write_image(fig, cfg.model_name + '_LogTrainingLoss.' + cfg.graph_file_suffix)

# Part 3A: Multiple model training loss graphs

In [None]:
# Notebook-level state: filled by load_training_data(), read by the plotting functions
all_training_loss_lists, model_labels = [], []

In [None]:
# Load and process the training loss data from each model
def load_training_data(model_names):
    """Download and parse the training loss history of each named model.

    Results go into the notebook-level globals read by the plotting functions:
    `all_training_loss_lists` (one loss list per model, in `model_names`
    order) and `model_labels` (a "d<n_digits>" legend label per model).

    Each model gets its own fresh `mmi.MathsConfig`. `qmi.load_training_json`
    updates the config it is given with that model's training settings, so
    the shared notebook-level `cfg` (the model chosen in Part 1A) is
    deliberately not passed here; passing it would overwrite that model's
    configuration with whichever comparison model was loaded last.

    Args:
        model_names: Iterable of model names, e.g. "add_d6_l2_h3_t40K_s372001".
    """
    global all_training_loss_lists
    global model_labels

    all_training_loss_lists = []
    model_labels = []

    for model_name in model_names:
        model_cfg = mmi.MathsConfig()
        model_cfg.set_model_names(model_name)
        model_cfg.hf_repo = base_repo_name + "/QuantaMaths_" + model_name

        training_data_json = qmi.download_huggingface_json(model_cfg.hf_repo, model_training_loss_fname)
        all_training_loss_lists.append(qmi.load_training_json(model_cfg, training_data_json))
        model_labels.append("d" + str(model_cfg.n_digits))

In [None]:
def smooth_data(data, window_size=500, min_periods=None):
    """Apply centred rolling mean/min/max statistics to the data.

    Args:
        data: 1-D sequence of numbers (e.g. a per-step training loss list).
        window_size: Number of consecutive points in each rolling window.
        min_periods: Minimum number of non-NaN observations a window needs to
            produce a value. The default (None) requires a full window, so
            roughly the first and last window_size/2 points are NaN (all of
            them if len(data) < window_size). Pass 1 to smooth right up to the
            edges.

    Returns:
        Tuple of three pandas Series (mean, min, max), index-aligned with `data`.
    """
    # Build the rolling window once and reuse it for all three statistics
    rolling = pd.Series(data).rolling(window=window_size, center=True, min_periods=min_periods)
    return rolling.mean(), rolling.min(), rolling.max()

def hex_to_rgba(hex_color, alpha=0.2):
    """Return a CSS 'rgba(r,g,b,alpha)' string for a '#rrggbb' colour.

    Leading '#' characters are optional and ignored; only the first six
    hex digits after them are read.
    """
    digits = hex_color.lstrip('#')
    # Each channel is one two-character hex pair: rr, gg, bb
    r, g, b = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return f'rgba({r},{g},{b},{alpha})'

In [None]:
def plot_training_data_regular(prefix, max_steps=15000):
    """Plot a linear-scale comparison of the early training loss of every loaded model.

    Reads the notebook-level globals `all_training_loss_lists` and
    `model_labels` (filled by `load_training_data`). For each model, the first
    `max_steps` loss values are passed through `smooth_data`. The rolling
    min/max is drawn as a translucent band, and the rolling mean as a solid
    line named with the model's label. Clicking a model in the legend toggles
    its band and mean line together.

    Args:
        prefix: Text prepended to the chart title, e.g. "Addition".
        max_steps: Number of initial training steps to plot. Defaults to
            15000; limiting the data reduces memory usage.
    """
    fig = go.Figure()

    # Color palette for different traces; reused cyclically so that more than
    # eleven models do not raise IndexError
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#aec7e8']

    for i, loss_list in enumerate(all_training_loss_lists):
        color = colors[i % len(colors)]
        group = f"model{i}"  # ties this model's band and mean line together in the legend
        loss_list = loss_list[:max_steps]
        x_vals = list(range(len(loss_list)))

        smooth_mean, smooth_min, smooth_max = smooth_data(loss_list)

        # Lower edge of the min/max band: an invisible line the next trace fills to
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_min,
                mode='lines',
                line=dict(width=0),
                hoverinfo='skip',
                showlegend=False,
                legendgroup=group,
            )
        )

        # Upper edge; fill='tonexty' shades the area down to the previous trace
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_max,
                mode='lines',
                line=dict(width=0),
                fillcolor=hex_to_rgba(color),
                fill='tonexty',
                hoverinfo='skip',
                showlegend=False,
                legendgroup=group,
            )
        )

        # Rolling-mean line: the only trace of this model shown in the legend
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_mean,
                mode='lines',
                line=dict(color=color, width=2),
                name=model_labels[i],
                legendgroup=group,
            )
        )

    # Update layout; ":g" renders 15000 as "15" so the default title is "(First 15K Steps)"
    fig.update_layout(
        title=f"{prefix} Training Loss Comparison (First {max_steps / 1000:g}K Steps)",
        height=600,
        template="plotly_white",
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.95,
            xanchor="right",
            x=0.95
        )
    )

    # max_steps is a parameter, so the x-range is defined even if no model data is loaded
    fig.update_xaxes(title_text="Training Step", range=[0, max_steps])
    fig.update_yaxes(title_text="Training Loss")

    fig.show()

In [None]:
def plot_training_data_log(prefix, max_steps=80000):
    """Plot a log-scale comparison of the training loss of every loaded model, then save it.

    Reads the notebook-level globals `all_training_loss_lists` and
    `model_labels` (filled by `load_training_data`). For each model, up to
    `max_steps` loss values are passed through `smooth_data`. The rolling
    min/max is drawn as a translucent band, and the rolling mean as a solid
    line named with the model's label. The x-axis spans the longest plotted
    series, so no model is cropped.

    Args:
        prefix: Text prepended to the chart title, and used in the saved
            file's title, e.g. "Addition".
        max_steps: Maximum number of training steps plotted per model
            (default 80000).
    """
    fig = go.Figure()

    # Color palette for different traces; reused cyclically so that more than
    # eleven models do not raise IndexError
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#aec7e8']

    # Length of the longest plotted series; sets the x-axis range after the loop
    longest = 0

    for i, loss_list in enumerate(all_training_loss_lists):
        color = colors[i % len(colors)]
        group = f"model{i}"  # ties this model's band and mean line together in the legend
        # Limit data to max_steps or the length of the loss_list
        loss_list = loss_list[:max_steps]
        longest = max(longest, len(loss_list))
        x_vals = list(range(len(loss_list)))

        smooth_mean, smooth_min, smooth_max = smooth_data(loss_list)

        # Lower edge of the min/max band: an invisible line the next trace fills to
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_min,
                mode='lines',
                line=dict(width=0),
                hoverinfo='skip',
                showlegend=False,
                legendgroup=group,
            )
        )

        # Upper edge; fill='tonexty' shades the area down to the previous trace
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_max,
                mode='lines',
                line=dict(width=0),
                fillcolor=hex_to_rgba(color),
                fill='tonexty',
                hoverinfo='skip',
                showlegend=False,
                legendgroup=group,
            )
        )

        # Rolling-mean line: the only trace of this model shown in the legend
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=smooth_mean,
                mode='lines',
                line=dict(color=color, width=2),
                name=model_labels[i],
                legendgroup=group,
            )
        )

    # Update layout
    fig.update_layout(
        title=prefix + " Training Loss Comparison (Log Scale)",
        height=600,
        template="plotly_white",
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.95,
            xanchor="right",
            x=0.95
        )
    )

    # Use the longest series, not just the last model's length, for the x-range
    fig.update_xaxes(title_text="Training Step", range=[0, longest])
    fig.update_yaxes(title_text="Training Loss (Log Scale)", type="log")

    fig.show()

    # NOTE(review): save_plt_to_file is not given `fig`. Confirm it saves this
    # Plotly figure; Part 2 saves its Plotly figure with pio.write_image.
    qmi.save_plt_to_file(cfg=cfg, full_title=prefix + "_training")

# Part 3B: Multiple Addition model training loss graphs

In [None]:
# Addition models to compare, one per digit count (d5 to d15). Each model's
# legend label is "d<n_digits>", as built by load_training_data.
model_names1 = [
    "add_d5_l2_h3_t40K_s372001",
    "add_d6_l2_h3_t40K_s372001",
    "add_d7_l2_h3_t45K_s173289",
    "add_d8_l2_h3_t45K_s173289",
    "add_d9_l2_h3_t45K_s173289",
    "add_d10_l2_h3_t40K_s572091",
    "add_d11_l2_h3_t50K_s572091",
    "add_d12_l2_h3_t50K_s572091",
    "add_d13_l2_h3_t50K_s572091",
    "add_d14_l2_h3_t60K_s572091",
    "add_d15_l2_h3_t80K_s572091",
]

# Replaces all_training_loss_lists / model_labels with these models' data
load_training_data(model_names1)

In [None]:
# Only the log-scale comparison is drawn; uncomment the next line for the
# linear-scale view of the first 15K training steps.
#plot_training_data_regular("Addition")
plot_training_data_log("Addition")

# Part 3C: Multiple Mixed (Addition and Subtraction) model training loss graphs

In [None]:
# Mixed (addition and subtraction) models to compare, one per digit count (d5 to d13)
model_names2 = [
    "mix_d5_l3_h4_t40K_s372001",
    "mix_d6_l3_h4_t40K_s372001",
    "mix_d7_l3_h4_t50K_s372001",
    "mix_d8_l3_h4_t60K_s173289",
    "mix_d9_l3_h4_t60K_s173289",
    "mix_d10_l3_h4_t75K_s173289",
    "mix_d11_l3_h4_t80K_s572091",
    "mix_d12_l3_h4_t85K_s572091",
    "mix_d13_l3_h4_t85K_s572091",
]

# Replaces all_training_loss_lists / model_labels with these models' data
load_training_data(model_names2)

In [None]:
# Only the log-scale comparison is drawn; uncomment the next line for the
# linear-scale view of the first 15K training steps.
#plot_training_data_regular("Mixed")
plot_training_data_log("Mixed")


# Part 3D: Multiple Mixed models (initialized with addition models) training loss graphs

In [None]:
# Mixed models initialized with an addition model ("ins1_" prefix), one per
# digit count (d5 to d11). Note that the layer/head counts differ between entries.
model_names3 = [
    "ins1_mix_d5_l2_h3_t40K_s572091",
    "ins1_mix_d6_l3_h4_t40K_s372001",
    "ins1_mix_d7_l3_h4_t50K_s572091",
    "ins1_mix_d8_l3_h4_t70K_s572091",
    "ins1_mix_d9_l3_h4_t70K_s572091",
    "ins1_mix_d10_l3_h3_t50K_s572091",
    "ins1_mix_d11_l3_h4_t75K_s572091",
]

# Replaces all_training_loss_lists / model_labels with these models' data
load_training_data(model_names3)

In [None]:
# Only the log-scale comparison is drawn; uncomment the next line for the
# linear-scale view of the first 15K training steps.
#plot_training_data_regular("Inited Mixed")
plot_training_data_log("Inited Mixed")