# Verified Arithmetic in Transformers - Describe Model Algorithm

This Colab describes the algorithm of Transformer models in terms of model behaviours and algorithmic sub-tasks (analysed in other Colabs and stored in JSON files).

The models perform integer addition and/or subtraction, e.g. 133357+182243=+0315600 and 123450-345670=-0222220. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens.
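
As a minimal sketch of this format, the following builds the question and answer strings for an n-digit model and counts the tokens, assuming one token per character. The helper name `format_example` is hypothetical and is not part of the notebook's libraries.

```python
def format_example(x: int, y: int, op: str, n_digits: int) -> tuple:
    """Return (question, answer) strings in the fixed-width format described above."""
    if op not in ("+", "-"):
        raise ValueError("op must be '+' or '-'")
    result = x + y if op == "+" else x - y
    # Operands are zero-padded to n_digits digits.
    question = f"{x:0{n_digits}d}{op}{y:0{n_digits}d}="
    # The answer is a sign followed by n_digits + 1 zero-padded digits.
    sign = "-" if result < 0 else "+"
    answer = f"{sign}{abs(result):0{n_digits + 1}d}"
    return question, answer


question, answer = format_example(133357, 182243, "+", 6)
print(question + answer)           # 133357+182243=+0315600
print(len(question), len(answer))  # 14 8
```

The same helper applied to `123450-345670` yields a negative answer with the same fixed width.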

This Colab follows on from:
- https://github.com/PhilipQuirke/quanta_maths/blob/main/notebooks/VerifiedArithmeticTrain.ipynb, which trains the models. It outputs model_name.pth and model_name_train.json.
- https://github.com/PhilipQuirke/quanta_maths/blob/main/notebooks/VerifiedArithmeticAnalyse.ipynb, which analyzes the models. It documents their sub-tasks in model_name_behavior.json and model_name_maths.json.

This Colab loads the above JSON files from the HuggingFace repository https://huggingface.co/PhilipQuirke/VerifiedArithmetic/raw/main


## Tips for using the Colab
 * You can run and alter the code in this Colab notebook yourself in Google Colab ( https://colab.research.google.com/ ).
 * To run the notebook in Google Colab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the public "QuantaMechInterp" library as "qmi" and the public "MathsMechInterp" library as "mmi". These libraries are specific to this Colab's "QuantaTool" approach to transformer analysis. Refer to [README.md](https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md) for more detail.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    !pip install matplotlib
    !pip install kaleido

except ImportError:
    IN_COLAB = False

    def setup_jupyter(install_libraries=False):
        if install_libraries:
            !pip install matplotlib==3.8.4
            !pip install kaleido==0.2.1

        print("Running as a Jupyter notebook - intended for development only!")
        from IPython import get_ipython

        ipython = get_ipython()
        # Automatically reload edited modules (e.g. the HookedTransformer code) without restarting the kernel
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # setup_jupyter(install_libraries=True)   # Uncomment if you need to install libraries in notebook.
    setup_jupyter(install_libraries=False)

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_maths.git
import MathsMechInterp as mmi

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
import itertools
import re
from enum import Enum

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from huggingface_hub import hf_hub_download

In [None]:
import sklearn # Aka scikit-learn
import skopt # Aka scikit-optimize

# Part 1A: Configuration

Which existing model do we want to analyze?

The existing model weights created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/quanta_maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace (in Part 5). Refer to https://github.com/PhilipQuirke/quanta_maths/blob/main/README.md for more detail.
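
The model names used in this notebook (e.g. `ins1_mix_d6_l3_h4_t40K_s372001`) appear to encode the model configuration. The sketch below parses such a name. The field meanings are assumptions inferred from the names and comments in this notebook: `d` = digits, `l` = layers, `h` = heads, `t` = training length in thousands, `gf` = GrokFast, `s` = seed, and an optional `insN` prefix for models initialized from an addition model. The function `parse_model_name` is hypothetical and is not part of the `qmi` or `mmi` libraries.

```python
import re

# Assumed naming pattern, inferred from the model names listed in this notebook.
MODEL_NAME_RE = re.compile(
    r"^(?:(?P<insert>ins\d+)_)?(?P<op>add|sub|mix)"
    r"_d(?P<digits>\d+)_l(?P<layers>\d+)_h(?P<heads>\d+)"
    r"_t(?P<train_k>\d+)K(?P<grokfast>_gf)?_s(?P<seed>\d+)$"
)


def parse_model_name(name: str) -> dict:
    """Split a model name into its (assumed) configuration fields."""
    match = MODEL_NAME_RE.match(name)
    if match is None:
        raise ValueError(f"Unrecognised model name: {name!r}")
    return {
        "insert_mode": match["insert"],          # None when there is no insN prefix
        "operation": match["op"],
        "n_digits": int(match["digits"]),
        "n_layers": int(match["layers"]),
        "n_heads": int(match["heads"]),
        "train_k": int(match["train_k"]),
        "grokfast": match["grokfast"] is not None,
        "seed": int(match["seed"]),
    }


print(parse_model_name("mix_d7_l3_h4_t50K_s372001"))
```

In the notebook itself, these values are read from the stored training JSON file rather than from the name.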

In [None]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
cfg = mmi.MathsConfig()

# Which model do we want to analyze? Uncomment one line:

# Addition models
#cfg.set_model_names( "add_d5_l1_h3_t15K_s372001" )  # AddAccuracy=Two9s. Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions.
#cfg.set_model_names( "add_d5_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.6e-08
#cfg.set_model_names( "add_d5_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.7e-08. MAIN FOCUS
#cfg.set_model_names( "add_d6_l2_h3_t20K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=1.5e-08
#cfg.set_model_names( "add_d6_l2_h3_t20K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=7e-09
#cfg.set_model_names( "add_d6_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss 2e-09. Fewest nodes
#cfg.set_model_names( "add_d10_l2_h3_t40K_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=7e-09. (1/M fail: 0000000555+0000000445=+00000001000 ModelAnswer: +00000000900)
#cfg.set_model_names( "add_d10_l2_h3_t40K_gf_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=3.5e-09. GrokFast.

# Subtraction model
#cfg.set_model_names( "sub_d6_l2_h3_t30K_s372001" )  # SubAccuracy=Six9s. AvgFinalLoss=5.8e-06
#cfg.set_model_names( "sub_d10_l2_h3_t75K_s173289" )  # SubAccuracy=Two9s. (6672/M fails) AvgFinalLoss=0.002002022
#cfg.set_model_names( "sub_d10_l2_h3_t75K_gf_s173289" )  # SubAccuracy=Two9s. GrokFast. (5246/M fails) AvgFinalLoss=0.001197

# Mixed (addition and subtraction) model
#cfg.set_model_names( "mix_d5_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=9e-09. (0/M fails, 0/M fails)
#cfg.set_model_names( "mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=5e-09. (1/M fail)
cfg.set_model_names( "mix_d7_l3_h4_t50K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=2e-08. (2/M fails, 6/M fails)
#cfg.set_model_names( "mix_d8_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=4.7e-08. (0/M fails, 7/M fails)
#cfg.set_model_names( "mix_d9_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Four9s. AvgFinalLoss=3.2e-07. (1/M fails, 33/M fails)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_s173289" )  # Add/SubAccuracy=Five9s/Two9s. AvgFinalLoss=1.125e-06 (2/M fail, 295/M fail)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_gf_s173289" )  # Add/SubAccuracy=Six9s/Three9s. GrokFast. AvgFinalLoss=8.85e-07 (1/M fail, 294/M fail)
#cfg.set_model_names( "mix_d11_l3_h4_t80K_s572091" )  # Add/SubAccuracy=Six9s/Four9s AvgFinalLoss=3.9e-08 (0/M fail, 13/M fail)

# Mixed models initialized with addition model
#cfg.set_model_names( "ins1_mix_d5_l2_h3_t40K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d6_l2_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss = 2.4e-08 (5/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=1.8e-08. (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t80K_s572091" )  # Add/SubAccuracy=Six9s/Five9s AvgLoss = 1.6e-08 (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=8e-09. MAIN FOCUS
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s173289" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.4e-08. (3/M fails, 2/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=2.9e-08. (4/M fails)
#cfg.set_model_names( "ins1_mix_d7_l3_h4_t50K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d8_l3_h4_t70K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d9_l3_h4_t70K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss 6.3e-07 (6/M fails, 7/M fails)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_gf_s572091" ) # Add/SubAccuracy=Five9s/Two9s. GrokFast. AvgFinalLoss=4.0e-06 (2/M fails, 1196/M fails)
#cfg.set_model_names( "ins1_mix_d11_l3_h4_t75K_s572091" )  # Add/SubAccuracy=TODO

# Mixed model initialized with addition model. Reset useful heads every 100 epochs.
#cfg.set_model_names( "ins2_mix_d6_l4_h4_t40K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (3/M fails e.g. 530757+460849=+0991606 ModelAnswer: +0091606) (8 fails e.g. 261926-161857=+0100069 ModelAnswer: +0000069)

# Mixed model initialized with addition model. Reset useful heads & MLPs every 100 epochs.
#cfg.set_model_names( "ins3_mix_d6_l4_h3_t40K_s372001" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=3.0e-04. (17/M fails e.g. 273257+056745=+0330002 ModelAnswer: +0320002) (3120 fails e.g. 209075-212133=-0003058 ModelAnswer: +0003058)

# Mixed models initialized with addition model.
#cfg.set_model_names( "ins4_mix_d6_l3_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO
#cfg.set_model_names( "ins4_mix_d6_l2_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO

# Part 1B: Configuration: Input and Output file names



In [None]:
cfg.batch_size = 512 # Default analysis batch size
if cfg.n_layers >= 3 and cfg.n_heads >= 4:
  cfg.batch_size = 256 # Reduce batch size to avoid memory constraint issues.

In [None]:
main_fname_pth = cfg.model_name + '.pth'
main_fname_train_json = cfg.model_name + '_train.json'
main_fname_behavior_json = cfg.model_name + '_behavior.json'
main_fname_maths_json = cfg.model_name + '_maths.json'
main_repo_name="PhilipQuirke/VerifiedArithmetic"

In [None]:
# Update "cfg" with additional training config information (including cfg.insert_*) from a stored file, e.g.:
#      https://huggingface.co/PhilipQuirke/VerifiedArithmetic/raw/main/ins1_mix_d6_l3_h4_t40K_s372001_train.json
training_data_json = qmi.download_huggingface_json(main_repo_name, main_fname_train_json)
training_loss_list = qmi.load_training_json(cfg, training_data_json)
print('Loaded main model training config / loss from', main_repo_name, main_fname_train_json)

In [None]:
def print_config():
  print("%Add=", cfg.perc_add, "%Sub=", cfg.perc_sub, "%Mult=", cfg.perc_mult, "InsertMode=", cfg.insert_mode, "File=", cfg.model_name)

In [None]:
print_config()
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)
print('Main model will be read from HuggingFace file', main_repo_name, main_fname_pth)
print('Main model training config / loss will be read from HuggingFace file', main_fname_train_json)
print('Main model behavior analysis tags will be read from HuggingFace file', main_fname_behavior_json)
print('Main model maths analysis tags will be read from HuggingFace file', main_fname_maths_json)

In [None]:
# Singleton QuantaTool "ablation intervention" configuration class
acfg = qmi.acfg

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

  

In [None]:
mmi.set_maths_vocabulary(cfg)
mmi.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

# Part 4: Results: Model training loss

In [None]:
print_config()
print( "Avg loss over last 5 epochs", cfg.avg_final_loss)
print( "Final epoch loss", cfg.final_loss)

# Part 6A: Set Up: Load nodes and behaviour tags
Load the useful nodes and associated behaviour tags from a JSON file stored on HuggingFace

In [None]:
print("Loading useful node list with behavior tags from HuggingFace", main_repo_name, main_fname_behavior_json)

file_path = hf_hub_download(repo_id=main_repo_name, filename=main_fname_behavior_json, revision="main")

cfg.useful_nodes.load_nodes(file_path)

# Part 6B: Results: Show nodes and behaviour tags

In [None]:
cfg.useful_nodes.sort_nodes()
cfg.useful_nodes.print_node_tags()

In [None]:
def show_quanta_map( title, major_tag : qmi.QType, minor_tag : str, get_node_details,
        image_width_inches : int = 9, image_height_inches : int = 6,
        blue_shades : bool = True, cell_num_shades : int = 6,
        filters : qmi.FilterNode = None, cell_fontsize : int = 9,
        combine_identical_cells : bool = True, show_perc_circles : bool = False ):

  test_nodes = cfg.useful_nodes
  if filters is not None:
    test_nodes = qmi.filter_nodes(test_nodes, filters)

  ax1, quanta_results, num_results = qmi.calc_quanta_map(
      cfg, blue_shades, cell_num_shades,
      test_nodes, major_tag.value, minor_tag, get_node_details,
      cell_fontsize, combine_identical_cells, show_perc_circles,
      image_width_inches, image_height_inches )

  if num_results > 0:
    plt.show()

# Part 16A: Results: Show failure percentage map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.
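
To make the metric concrete, here is a minimal sketch of a failure percentage calculation for one ablated node. The function `failure_percentage` and the sample "ablated" answers are hypothetical illustrations; the notebook itself reads precomputed values from the behavior JSON file via `qmi.get_quanta_fail_perc`.

```python
def failure_percentage(expected: list, predicted: list) -> float:
    """Percentage of questions whose predicted answer differs from the expected answer."""
    if not expected or len(expected) != len(predicted):
        raise ValueError("expected and predicted must be non-empty and the same length")
    failures = sum(e != p for e, p in zip(expected, predicted))
    return 100.0 * failures / len(expected)


# Correct answers for four questions, and made-up answers after ablating one node.
expected = ["+0315600", "+0991606", "+0100069", "-0003058"]
ablated = ["+0315600", "+0091606", "+0000069", "-0003058"]
print(failure_percentage(expected, ablated))  # 50.0
```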

In [None]:
show_quanta_map( "Failure Frequency Behavior Per Node",
                qmi.QType.FAIL, "", qmi.get_quanta_fail_perc,
                image_height_inches = 2 * cfg.n_layers,
                cell_num_shades = qmi.FAIL_SHADES, combine_identical_cells = False, show_perc_circles = True)

# Part 16B - Show answer impact behavior map

Show the purpose of each useful cell by impact on the answer digits A0 to Amax.
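
As an illustration of "impact", the sketch below lists which answer digits differ between a correct answer and an incorrect one, assuming the answer string is a sign followed by digits Amax down to A0. The helper `impacted_digits` is hypothetical; the example strings are the failure cases quoted in the Part 1A configuration comments.

```python
def impacted_digits(expected: str, predicted: str) -> list:
    """Names of the answer digits (A0 = least significant) where the two answers differ.

    The leading sign token is ignored, because only digits A0..Amax are reported.
    """
    if len(expected) != len(predicted):
        raise ValueError("answers must have the same length")
    n_digit_tokens = len(expected) - 1  # every token after the sign
    impacted = []
    for index in range(1, len(expected)):
        if expected[index] != predicted[index]:
            # index 1 holds Amax, the last index holds A0
            impacted.append(f"A{n_digit_tokens - index}")
    return impacted


print(impacted_digits("+0991606", "+0091606"))  # ['A5']
print(impacted_digits("+0330002", "+0320002"))  # ['A4']
```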


In [None]:
show_quanta_map( "Answer Impact Behavior Per Node",
                qmi.QType.IMPACT, "", qmi.get_quanta_impact,
                image_height_inches = 2 * cfg.n_layers,
                cell_num_shades = cfg.num_answer_positions)

# Part 16C: Result: Show attention map

Show attention quanta of useful heads

In [None]:
# Only maps attention heads, not MLP layers
show_quanta_map( "Attention Behavior Per Head",
                qmi.QType.ATTN, "", qmi.get_quanta_attention,
                image_height_inches = 3 * cfg.n_layers,
                cell_num_shades = qmi.ATTN_SHADES )

# Part 16D - Show question complexity map


In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show whether addition (S), positive-answer subtraction (M) and negative-answer subtraction (N) questions rely on the node.
  show_quanta_map( "Maths Operation Coverage",
                  qmi.QType.MATH, "", mmi.get_maths_operation_complexity,
                  image_height_inches = 1.75 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = 4, combine_identical_cells = False)

In [None]:
if cfg.perc_add > 0:
  # For each useful cell, show the minimum addition question complexity that relies on the node, as measured using quanta S0, S1, S2, ...
  show_quanta_map( "Addition Min-Complexity",
                  qmi.QType.MATH_ADD, mmi.MathsBehavior.ADD_COMPLEXITY_PREFIX.value, mmi.get_maths_min_complexity,
                  image_height_inches = 1.25 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = qmi.MATH_ADD_SHADES)

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show the minimum "positive-answer subtraction" question complexity that relies on the node, as measured using quanta M0, M1, M2, ...
  show_quanta_map( "Positive-answer Subtraction Min-Complexity",
                  qmi.QType.MATH_SUB, mmi.MathsBehavior.SUB_COMPLEXITY_PREFIX.value, mmi.get_maths_min_complexity,
                  image_height_inches = 1.5 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = qmi.MATH_SUB_SHADES)

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show the minimum "negative-answer subtraction" question complexity that relies on the node, as measured using quanta N0, N1, N2, ...
  show_quanta_map( "Negative-answer Subtraction Min-Complexity",
                  qmi.QType.MATH_NEG, mmi.MathsBehavior.NEG_COMPLEXITY_PREFIX.value, mmi.get_maths_min_complexity,
                  image_height_inches = 1.5 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = qmi.MATH_SUB_SHADES)

# Part 17: Set Up: Load maths sub-task tags from json file

Load the useful nodes' maths sub-task tags from a JSON file stored on HuggingFace.

In [None]:
main_repo_name="PhilipQuirke/VerifiedArithmetic"
print("Loading maths sub-tasks from HuggingFace", main_repo_name, main_fname_maths_json)

file_path = hf_hub_download(repo_id=main_repo_name, filename=main_fname_maths_json, revision="main")

cfg.useful_nodes.load_nodes(file_path)

# Part 23A: Show algorithm quanta map

Plot the "maths" tags as a quanta map. This is an automatically generated partial explanation of the model algorithm.

Nodes with multiple tags perform multiple tasks.

In [None]:
qmi.print_algo_purpose_results(cfg)

In [None]:
show_quanta_map( "Maths Purpose Per Node", qmi.QType.ALGO, "", qmi.get_quanta_binary,
                #image_width_inches = 11,
                cell_num_shades = 2)

# Part 23B: Show known quanta per answer digit

Each of the late positions is solely focused on calculating one answer digit. Show the data we have collected on each late answer digit.
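
The relationship between token positions and answer digits can be sketched as follows, assuming the token layout described in the introduction (first operand, operator, second operand, "=", answer sign, then answer digits from Amax down to A0). The helper functions and the labels "OP" and "SIGN" are hypothetical; `cfg.token_position_meanings` may use different names.

```python
def token_position_names(n_digits: int) -> list:
    """Label each token position for an n-digit question (assumed layout)."""
    first = [f"D{i}" for i in range(n_digits - 1, -1, -1)]
    second = [f"D'{i}" for i in range(n_digits - 1, -1, -1)]
    answer = [f"A{i}" for i in range(n_digits, -1, -1)]
    return first + ["OP"] + second + ["=", "SIGN"] + answer


def predicted_token(names: list, position: int):
    """The token the model predicts while reading the token at `position`."""
    return names[position + 1] if position + 1 < len(names) else None


names = token_position_names(6)
print(len(names))                            # 22
print(names[18], predicted_token(names, 18)) # A3 A2
```

Under these assumptions, position 18 of a 6-digit model holds token A3 and is where answer digit A2 is predicted.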



In [None]:
for position in range(cfg.num_question_positions + 1, cfg.n_ctx - 1):
  print("Position:", position)

  # Calculate a table of the known quanta for the specified position for each late token position
  mmi.calc_maths_quanta_for_position_nodes(cfg, position)

  plt.show()

# Part 24: Show maths tags

Show a list of the nodes that have proved useful in calculations, together with data on the nodes' behavior and algorithmic purposes.
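
Attention-head node names in this notebook look like `P14L0H2`. The sketch below parses such names and sorts them numerically, assuming `P` is the token position, `L` the layer and `H` the head. The helper `parse_head_name` is hypothetical (the library's own parser is `qmi.str_to_node_location`), and MLP node names are not handled.

```python
import re

HEAD_NAME_RE = re.compile(r"^P(\d+)L(\d+)H(\d+)$")


def parse_head_name(name: str) -> tuple:
    """Return (position, layer, head) for a name such as 'P14L0H2'."""
    match = HEAD_NAME_RE.match(name)
    if match is None:
        raise ValueError(f"Not an attention-head node name: {name!r}")
    position, layer, head = (int(group) for group in match.groups())
    return position, layer, head


names = ["P14L1H1", "P9L0H1", "P14L0H2", "P12L1H2"]
# Sorting on the parsed tuple orders by position, then layer, then head.
print(sorted(names, key=parse_head_name))
# ['P9L0H1', 'P12L1H2', 'P14L0H2', 'P14L1H1']
```

A plain string sort would place `P12L1H2` before `P9L0H1`, which is why the numeric key is used.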


In [None]:
cfg.useful_nodes.print_node_tags(qmi.QType.ALGO.value, "", False)

In [None]:
mixed_model = cfg.model_name.startswith(
    ("ins1_mix_d6_l3_h4_t40K", "ins2_mix_d6_l4_h4_t40K", "ins3_mix_d6_l4_h3_t40K"))

# Part 25 : Results: Test Algorithm - Addition

In [None]:
print_config()

In [None]:
# For answer digits (excluding Amax), an An.SA node is needed before the answer digit is revealed
def algo_task_search(algo_nodes, from_digit, to_digit, filter_function, mandatory : bool = True):
    impact_digit = from_digit
    while impact_digit <= to_digit:
        cfg.test_algo_clause(algo_nodes, filter_function(impact_digit), mandatory)
        impact_digit += 1

In [None]:
# Read as: And( HasAlgoTag:A4.SA, Position<=A5 )
def algo_sa_filter(impact_digit):
    return qmi.FilterAnd(qmi.FilterAlgo(mmi.add_sa_functions.tag(impact_digit)), qmi.FilterPosition(cfg.an_to_position_name(impact_digit+1), qmi.QCondition.MAX))

# For each question digit, an An.SA node exists before the answer digit is revealed
def algo_sa_search(algo_nodes):
    algo_task_search(algo_nodes, 0, cfg.n_digits-1, algo_sa_filter)


# Read as: And( HasAlgoTag:A4.SC, Position<=A5 )
def algo_sc_filter(impact_digit):
    return qmi.FilterAnd(qmi.FilterAlgo(mmi.add_sc_functions.tag(impact_digit)), qmi.FilterPosition(cfg.an_to_position_name(impact_digit+1), qmi.QCondition.MAX))

# For each question digit, except A0, an An.SC node exists before the answer digit is revealed
def algo_sc_search(algo_nodes, mandatory : bool = True):
    algo_task_search(algo_nodes, 1, cfg.n_digits-1, algo_sc_filter, mandatory)


# Read as: And( HasAlgoTag:A4.SS, Position<=A5 )
def algo_ss_filter(impact_digit):
    return qmi.FilterAnd(qmi.FilterAlgo(mmi.add_ss_functions.tag(impact_digit)), qmi.FilterPosition(cfg.an_to_position_name(impact_digit+1), qmi.QCondition.MAX))

# For each question digit, except A0 and A1, an An.SS node exists before the answer digit is revealed
def algo_ss_search(algo_nodes):
    algo_task_search(algo_nodes, 2, cfg.n_digits-1, algo_ss_filter)


# Read as: And( HasAlgoTag:A4.ST, Position<=Amax )
def algo_st_filter(impact_digit):
    return qmi.FilterAnd(qmi.FilterAlgo(mmi.add_st_functions.tag(impact_digit)), qmi.FilterPosition(cfg.an_to_position_name(cfg.n_digits+1), qmi.QCondition.MAX))

# For each question digit, except A0 and A1, an An.ST node exists by Amax
def algo_st_search(algo_nodes):
    algo_task_search(algo_nodes, 2, cfg.n_digits-1, algo_st_filter)


# The nodes that implement the ST sub-task must be sequenced so that the SV values can be calculated.
def algo_sv_search(algo_nodes):
    # TODO: It is not clear what this "sequencing" condition is
    pass

## Part 25A: 1-layer addition models use sub-tasks An.SA, An.SC, An.SS

A 1-layer model can do 99% of addition questions using SA, SC and SS sub-tasks (as per Paper 1).

For add_d5_l1_h3_t15K_s372001 below search says: 14/15 heads have purpose assigned. 0/6 neurons have purpose assigned.
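
The sub-task names are not defined in this notebook. The sketch below assumes that SA is the per-digit base add (digit sum modulo 10), SC flags a digit pair that makes a carry (sum of 10 or more), and SS flags a digit pair that sums to exactly 9 (so an incoming carry would cascade). Treat these definitions and the helper names as assumptions for illustration only.

```python
def digits_lsd_first(value: int, n_digits: int) -> list:
    """Digits of value, least significant first, zero-padded to n_digits."""
    return [(value // 10**i) % 10 for i in range(n_digits)]


def addition_subtasks(x: int, y: int, n_digits: int) -> list:
    """Per-digit SA, SC and SS values under the assumed definitions."""
    rows = []
    pairs = zip(digits_lsd_first(x, n_digits), digits_lsd_first(y, n_digits))
    for n, (d, d2) in enumerate(pairs):
        rows.append({
            "digit": n,
            "SA": (d + d2) % 10,   # assumed: base add
            "SC": d + d2 >= 10,    # assumed: this pair makes a carry
            "SS": d + d2 == 9,     # assumed: this pair sums to nine
        })
    return rows


# 555+445: digit 0 makes a carry, which cascades through two sum-to-nine pairs.
for row in addition_subtasks(555, 445, 3):
    print(row)
```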

In [None]:
if cfg.perc_add > 0 and cfg.n_layers == 1 :

    algo_nodes = cfg.start_algorithm_test(acfg)

    algo_sa_search(algo_nodes)
    algo_sc_search(algo_nodes, mandatory=True)
    algo_ss_search(algo_nodes)
    algo_sv_search(algo_nodes)

    cfg.print_algo_clause_results()
    cfg.print_algo_purpose_results(algo_nodes)

## Part 25B: 2-layer addition models use sub-tasks An.SA, An.ST and maybe An.SC

A 2-layer model can do 99.9999% of addition questions using SA and ST sub-tasks (as per Paper 2). It may use (redundant) SC sub-tasks.

For add_d6_l2_h3_t15K below search says: 21/31 heads have purpose assigned. 0/17 neurons have purpose assigned.

TODO: 4 other models.
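
The ST sub-task is not defined in this notebook. The sketch below assumes ST is a three-way classification of each digit pair (sum above 9: definite carry; sum equal to 9: carry only if one arrives from the lower digit; sum below 9: no carry), and shows that resolving these states from the least significant digit upwards gives exact addition. The function names and state labels are hypothetical.

```python
def tricase(d: int, d2: int) -> str:
    """Assumed ST classification of one digit pair."""
    total = d + d2
    if total > 9:
        return "carry"
    if total == 9:
        return "maybe"  # carries only if the lower digit passes a carry in
    return "none"


def add_with_tricase(x: int, y: int, n_digits: int) -> int:
    """Add two n-digit numbers using base add plus resolved tri-state carries."""
    xs = [(x // 10**i) % 10 for i in range(n_digits)]
    ys = [(y // 10**i) % 10 for i in range(n_digits)]
    carry = 0
    answer_digits = []  # least significant first
    for d, d2 in zip(xs, ys):
        answer_digits.append((d + d2 + carry) % 10)
        state = tricase(d, d2)
        carry = 1 if state == "carry" or (state == "maybe" and carry) else 0
    answer_digits.append(carry)  # Amax is 0 or 1
    return int("".join(str(v) for v in reversed(answer_digits)))


print(add_with_tricase(555, 445, 10))  # 1000
```

The example question is the one-in-a-million failure case quoted for `add_d10_l2_h3_t40K_s572091` in Part 1A, where a carry must cascade through two sum-to-nine digit pairs.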

In [None]:
if cfg.perc_add > 0 and cfg.n_layers >= 2 :

  algo_nodes = cfg.start_algorithm_test(acfg)

  algo_sa_search(algo_nodes)
  algo_st_search(algo_nodes)
  # SC does the same job as ST but is less accurate. So the SC nodes are redundant and may be optimised out.
  algo_sc_search(algo_nodes, mandatory=False)

  cfg.print_algo_clause_results()
  cfg.print_algo_purpose_results(algo_nodes)

# Part 26: Results: Test Algorithm - Subtraction

## Part 26A: 1-layer subtraction uses sub-tasks MD, MB & MZ

A 1-layer model can do 99% of subtraction questions using MD, MB and MZ sub-tasks (as per Paper 1).

In [None]:
# Read as: And( HasAlgoTag:A4.MD, Position<=A5 )
def algo_md_filter(impact_digit):
    return qmi.FilterAnd(qmi.FilterAlgo(mmi.sub_md_functions.tag(impact_digit)), qmi.FilterPosition(cfg.an_to_position_name(impact_digit+1), qmi.QCondition.MAX))

# For each question digit, an An.MD node exists before the answer digit is revealed
def algo_md_search(algo_nodes):
    algo_task_search(algo_nodes, 0, cfg.n_digits-1, algo_md_filter)


# Read as: And( HasAlgoTag:A4.MB, Position<=A5 )
def algo_mb_filter(impact_digit):
    return qmi.FilterAnd(qmi.FilterAlgo(mmi.sub_mb_functions.tag(impact_digit)), qmi.FilterPosition(cfg.an_to_position_name(impact_digit+1), qmi.QCondition.MAX))

# For each question digit, an An.MB node exists before the answer digit is revealed
def algo_mb_search(algo_nodes, mandatory : bool = True):
    algo_task_search(algo_nodes, 0, cfg.n_digits-1, algo_mb_filter, mandatory)


In [None]:
if cfg.perc_sub > 0 and cfg.n_layers == 1 :

    algo_nodes = cfg.start_algorithm_test(acfg)

    algo_md_search(algo_nodes)
    algo_mb_search(algo_nodes, mandatory = True)

    cfg.print_algo_clause_results()
    cfg.print_algo_purpose_results(algo_nodes)

## Part 26B: Test Algorithm - Subtraction

To accurately predict whether the answer sign is + or -, the model must calculate whether D < D'. To do so, the model must evaluate Dn < D'n or (Dn = D'n and (Dn-1 < D'n-1 or (Dn-1 = D'n-1 and (Dn-2 < D'n-2 or ( etc. It must complete this calculation before the answer sign is revealed.

We expect to see nodes useful in negative answer questions, with PCA bigram (or trigram) outputs, attending to these input pairs, evaluated in this order, before the answer sign is revealed.
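
The digit-by-digit comparison described above can be sketched as a scan from the most significant digit pair to the least significant, stopping at the first pair that differs. This shows only the logic the model needs to reproduce, not how the model implements it; the function name is hypothetical.

```python
def is_negative_answer(d: str, d_prime: str) -> bool:
    """True if D < D', comparing equal-length digit strings from the most significant digit."""
    if len(d) != len(d_prime):
        raise ValueError("operands must be zero-padded to the same length")
    for digit, digit_prime in zip(d, d_prime):
        if int(digit) < int(digit_prime):
            return True   # first differing pair decides: D < D'
        if int(digit) > int(digit_prime):
            return False  # first differing pair decides: D > D'
    return False          # all pairs equal: D = D', so the answer is +0


print(is_negative_answer("209075", "212133"))  # True
print(is_negative_answer("261926", "161857"))  # False
```

Only the first differing digit pair matters, which is why the pairs must be evaluated in order of decreasing significance.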

In [None]:
def algo_mt_filter(impact_digit,sign_position):
    return qmi.FilterAnd(
      qmi.FilterAlgo(mmi.sub_mt_functions.tag(impact_digit)),
      qmi.FilterPosition(sign_position, qmi.QCondition.MAX))


# For answer digits (excluding Amax), An.MT is needed before the answer digit is revealed
def algo_mt_search(algo_nodes):
  mt_locations = {}

  sign_position = cfg.an_to_position_name(cfg.n_digits+1)

  for impact_digit in range(cfg.n_digits):
    # For answer digits (excluding the +/- answer sign and 0 or 1 first answer digit), An.MT is calculated before the answer sign is revealed
    position = cfg.test_algo_clause(algo_nodes, algo_mt_filter(impact_digit, sign_position))
    mt_locations[impact_digit] = position

  # Check that mt_locations[6] < mt_locations[5] < mt_locations[4] < etc
  print("MT Locations:", mt_locations)
  for impact_digit in range(cfg.n_digits):
    if impact_digit > 0:
      sc1 = mt_locations[impact_digit]
      sc2 = mt_locations[impact_digit-1]
      description = f"MT Ordering: A{impact_digit}={sc1}, A{impact_digit-1}={sc2}"
      cfg.test_algo_logic(description, sc1 >= 0 and sc2 >= 0 and sc1 < sc2)

In [None]:
if cfg.perc_sub > 0 and cfg.n_layers >= 2 :

    algo_nodes = cfg.start_algorithm_test(acfg)

    algo_md_search(algo_nodes)
    algo_mt_search(algo_nodes)
    algo_mb_search(algo_nodes, mandatory = False) # TODO: Is this true?

    cfg.print_algo_clause_results()
    cfg.print_algo_purpose_results(algo_nodes)

# Part 27: Test Algorithm - Mixed Addition and Subtraction model

What algorithm do mixed models use to perform both addition and subtraction? Our working hypothesis is Hypothesis 2 as described in https://github.com/PhilipQuirke/verified_transformers/blob/main/mixed_model.md





## Part 27A: Calculating answer digit A2 in token position A3

The graph below displays the same (behavior and algorithm) data as the quanta maps. Refer to https://github.com/PhilipQuirke/verified_transformers/blob/main/mixed_model.md section 27A for more explanation.

    


In [None]:
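# Position 18 is token position A3 for a 6-digit model; adjust it for models with a different number of digits.
mmi.calc_maths_quanta_for_position_nodes(cfg, 18)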
plt.show()

## Part 27.H3 Calculate whether D > D' (using NG tasks)

In [None]:
filters = qmi.FilterContains(qmi.QType.MATH_NEG, "")

#print("NG tagged nodes:", qmi.filter_nodes( cfg.useful_nodes, filters ).get_node_names())

show_quanta_map( "Subtraction Behavior NG Nodes", qmi.QType.MATH_SUB, "", mmi.get_maths_min_complexity, 9, 6, filters=filters, blue_shades=False, cell_num_shades=2)
show_quanta_map( "Attention Behavior Per NG Head", qmi.QType.ATTN, "", qmi.get_quanta_attention, 9, 8, filters=filters, cell_num_shades=10)
show_quanta_map( "Maths Purpose Per NG Node", qmi.QType.ALGO, "", qmi.get_quanta_binary, 9, 6, filters=filters, cell_num_shades=2)

In [None]:
if mixed_model:
  algo_nodes = cfg.start_algorithm_test(acfg)

  algo_sc_search(algo_nodes)
  cfg.print_algo_clause_results()

  cfg.print_algo_purpose_results(algo_nodes)

# Part 30: Unit Test automated searches

In [None]:
def check_algo_tag_exists(node_location_as_str, the_tags ):
  node_location = qmi.str_to_node_location(node_location_as_str)
  node = cfg.useful_nodes.get_node(node_location)
  if node is None:
      print( "Node", node_location_as_str, "is missing")
  else:
      for the_tag in the_tags:
          if not node.contains_tag(qmi.QType.ALGO.value, the_tag):
              print( "Node", node.name(), "is missing tag", the_tag, "It has tags:", node.tags )

In [None]:
print(cfg.model_name)

if cfg.model_name.startswith('add_d6_l2_h3_t15K'):
  check_algo_tag_exists('P11L0H0', ['A2.ST'] )
  check_algo_tag_exists('P12L0H0', ['A3.ST'] )
  check_algo_tag_exists('P14L0H0', ['A5.SS', 'A4.ST'] )
  check_algo_tag_exists('P14L0H2', ['A5.SC', 'A5.ST'] )
  check_algo_tag_exists('P14L1H1', ['OPR'] )
  check_algo_tag_exists('P15L0H0', ['A4.SC'] )
  check_algo_tag_exists('P15L0H1', ['A5.SA'] )
  check_algo_tag_exists('P15L0H2', ['A5.SA'] )
  check_algo_tag_exists('P16L0H0', ['A3.SC'] )
  check_algo_tag_exists('P16L0H1', ['A4.SA'] )
  check_algo_tag_exists('P16L0H2', ['A4.SA'] )
  check_algo_tag_exists('P17L0H0', ['A2.SC'] )
  check_algo_tag_exists('P17L0H1', ['A3.SA'] )
  check_algo_tag_exists('P17L0H2', ['A3.SA'] )
  check_algo_tag_exists('P18L0H0', ['A1.SC'] )
  check_algo_tag_exists('P18L0H1', ['A2.SA'] )
  check_algo_tag_exists('P18L0H2', ['A2.SA'] )
  check_algo_tag_exists('P19L0H0', ['A0.SC'] )
  check_algo_tag_exists('P19L0H1', ['A1.SA'] )
  check_algo_tag_exists('P19L0H2', ['A1.SA'] )
  check_algo_tag_exists('P20L0H1', ['A0.SA'] )
  check_algo_tag_exists('P20L0H2', ['A0.SA'] )

if cfg.model_name.startswith('mix_d6_l3_h4_t40K'):
  check_algo_tag_exists('P8L0H1', ['OPR'] )
  check_algo_tag_exists('P13L2H0', ['A7.NG'] )
  check_algo_tag_exists('P15L0H0', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P15L0H3', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P16L0H3', ['A4.SA.A4', 'A4.MD.A4'] )
  check_algo_tag_exists('P17L0H1', ['A3.NG'] )
  check_algo_tag_exists('P17L0H3', ['A3.SA.A3', 'A3.MD.A3'] )
  check_algo_tag_exists('P18L0H1', ['A2.NG'] )
  check_algo_tag_exists('P18L0H3', ['A2.SA.A2', 'A2.MD.A2'] )
  check_algo_tag_exists('P19L0H1', ['A1.NG'] )
  check_algo_tag_exists('P19L0H3', ['A1.SA.A1', 'A1.MD.A1'] )
  check_algo_tag_exists('P20L0H0', ['A0.SA', 'A0.MD'] )
  check_algo_tag_exists('P20L0H3', ['A0.SA', 'A0.MD'] )
  check_algo_tag_exists('P20L2H1', ['A0.NG'] )

if cfg.model_name.startswith('ins1_mix_d6_l3_h4_t40K'):
  check_algo_tag_exists('P6L0H0', ['OPR'] )
  check_algo_tag_exists('P9L0H0', ['A4.MT'] )
  check_algo_tag_exists('P9L0H1', ['A4.ST'] )
  check_algo_tag_exists('P10L0H1', ['A2.MT'] )
  check_algo_tag_exists('P10L0H3', ['OPR'] )
  check_algo_tag_exists('P12L0H0', ['A3.MT'] )
  check_algo_tag_exists('P12L0H1', ['A3.ST'] )
  check_algo_tag_exists('P12L0H3', ['OPR'] )
  check_algo_tag_exists('P12L1H2', ['A1.MT'] )
  check_algo_tag_exists('P13L0H3', ['OPR'] )
  check_algo_tag_exists('P13L1H0', ['OPR'] )
  check_algo_tag_exists('P13L2H0', ['OPR'] )
  check_algo_tag_exists('P14L0H0', ['A5.SS', 'OPR'] )
  check_algo_tag_exists('P14L0H1', ['OPR'] )
  check_algo_tag_exists('P14L0H2', ['A5.SC', 'A5.ST', 'SGN'] )
  check_algo_tag_exists('P14L1H2', ['SGN'] )
  check_algo_tag_exists('P14L1H3', ['SGN'] )
  check_algo_tag_exists('P15L0H0', ['A4.SC'] )
  check_algo_tag_exists('P15L0H1', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P15L0H2', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P15L0H3', ['SGN'] )
  check_algo_tag_exists('P16L0H0', ['A3.SC'] )
  check_algo_tag_exists('P16L0H1', ['A4.SA', 'A4.MD', 'A4.ND'] )
  check_algo_tag_exists('P16L0H2', ['A4.SA', 'A4.MD', 'A4.ND'] )
  check_algo_tag_exists('P16L0H3', ['SGN'] )
  check_algo_tag_exists('P16L1H0', ['SGN'] )
  check_algo_tag_exists('P16L2H0', ['SGN'] )
  check_algo_tag_exists('P17L0H0', ['A2.SC'] )
  check_algo_tag_exists('P17L0H1', ['A3.SA', 'A3.MD', 'A3.ND'] )
  check_algo_tag_exists('P17L0H2', ['A3.SA', 'A3.MD', 'A3.ND'] )
  check_algo_tag_exists('P17L0H3', ['SGN'] )
  check_algo_tag_exists('P18L0H0', ['A1.SC'] )
  check_algo_tag_exists('P18L0H1', ['A2.SA', 'A2.MD', 'A2.ND'] )
  check_algo_tag_exists('P18L0H2', ['A2.SA', 'A2.MD', 'A2.ND'] )
  check_algo_tag_exists('P19L0H0', ['A0.MB.A1'] )
  check_algo_tag_exists('P19L0H1', ['A1.SA', 'A1.MD', 'A1.ND'] )
  check_algo_tag_exists('P19L0H2', ['A1.SA', 'A1.MD', 'A1.ND'] )
  check_algo_tag_exists('P20L0H1', ['A0.SA', 'A0.MD', 'A0.ND'] )
  check_algo_tag_exists('P20L0H2', ['A0.SA', 'A0.MD', 'A0.ND'] )