# Verified Integer Mathematics in Transformers - Analyse the Model

This CoLab analyses a Transformer model that performs integer addition, subtraction and multiplication e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+283935. Each digit is a separate token. For 6 digit questions, the model is given 14 "question" (input) tokens, and must then predict the corresponding 8 "answer" (output) tokens.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports "verified_transformer" public library as "qt". This library is specific to this CoLab's "QuantaTool" approach to transformer analysis. Refer to [README.md](https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md) for more detail.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    !pip install matplotlib

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy
    !pip install scikit-learn

except:
    IN_COLAB = False

    def setup_jupyter(install_libraries=False):
        if install_libraries:
            !pip install matplotlib==3.8.4
            !pip install kaleido==0.2.1
            !pip install transformer_lens==1.15.0
            !pip install torchtyping==0.1.4
            !pip install transformers==4.39.3

            !pip install numpy==1.26.4
            !pip install plotly==5.20.0
            !pip install pytest==8.1.1
            !pip install scikit-learn==1.4.1.post1

        print("Running as a Jupyter notebook - intended for development only!")
        from IPython import get_ipython

        ipython = get_ipython()
        # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
        ipython.magic("load_ext autoreload")
        ipython.magic("autoreload 2")

    # setup_jupyter(install_libraries=True)   # Uncomment if you need to install libraries in notebook.
    setup_jupyter(install_libraries=False)

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Plotly's "colab" renderer is used in Colab (or outside development mode); local notebooks use "notebook_connected".
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the axis-title and chart-title font sizes in the "plotly" template so graphs are readable.
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
import itertools
import re
from enum import Enum

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# Import Principal Component Analysis (PCA) library
use_pca = True
try:
  from sklearn.decomposition import PCA
except Exception as e:
  print("pca import failed with exception:", e)
  use_pca = False

  # Sometimes version conflicts means the PCA library does not import. This workaround partially fixes the issue
  !pip install --upgrade numpy
  !pip install --upgrade scikit-learn

  # To complete workaround, now select menu option "Runtime > Restart session and Run all".
  stop

In [None]:
# Remove any previously installed QuantaTools. "|| true" ignores a non-zero exit status (e.g. when it is not installed).
! pip uninstall QuantaTools -y || true   # Ensure a clean install.

In [None]:
# Refer https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md
!pip install --upgrade git+https://github.com/PhilipQuirke/verified_transformers.git  # Specify @branch if testing a specific branch
# The repository provides the QuantaTools package, referred to as `qt` throughout this notebook.
import QuantaTools as qt

# Part 1A: Configuration

Which existing model do we want to analyse?

The existing model weightings created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace (in Part 5). Refer https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md for more detail.

In [None]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
cfg = qt.MathsConfig()

# Singleton QuantaTool "ablation intervention" configuration class
# acfg is the library's module-level instance (qt.acfg), not a newly constructed object.
acfg = qt.acfg

In [None]:
# Which model do we want to analyse? Uncomment one line:

#cfg.model_name = "" # Use configuration specified in cfg defaults

# 5-digit and 6-digit Addition models
#cfg.model_name = "add_d5_l1_h3_t15K_s372001"  # Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions.
#cfg.model_name = "add_d5_l2_h3_t15K_s372001"  # AvgFinalLoss=1.6e-08. Accurate on 1M Qs
cfg.model_name = "add_d6_l2_h3_t15K_s372001"  # AvgFinalLoss=1.7e-08. Accurate on 1M Qs
#cfg.model_name = "add_d6_l2_h3_t20K_s173289"  # AvgFinalLoss=1.5e-08. Accurate on 1M Qs
#cfg.model_name = "add_d6_l2_h3_t20K_s572091"  # AvgFinalLoss=7e-09. Accurate on 1M Qs

# 6-digit Subtraction model
#cfg.model_name = "sub_d6_l2_h3_t30K_s372001"  # AvgFinalLoss=5.8e-06. Fails 1M Qs

# 6-digit Mixed (addition and subtraction) models
#cfg.model_name = "mix_d6_l3_h4_t40K_s372001"  # AvgFinalLoss=5e-09. Fails 1M Qs

#  "ins1" 6-digit Mixed models initialised with 6-digit addition model
#cfg.model_name = "ins1_mix_d6_l3_h4_t40K_s372001"  # AvgFinalLoss=8e-09. Accurate on 1M Qs for Add and Sub
#cfg.model_name = "ins1_mix_d6_l3_h4_t40K_s173289"  # AvgFinalLoss=1.6e-08. 936K for Add, 1M Qs for Sub
#cfg.model_name = "ins1_mix_d6_l3_h4_t50K_s572091"  # AvgFinalLoss=2.9e-08. 1M for Add. 300K for Sub. For 000041-000047=-0000006 gives +0000006. Improve training data.
#cfg.model_name = "ins1_mix_d6_l3_h3_t40K_s572091"  #  AvgFinalLoss=1.8e-08. Fails on 1M Qs. For 099111-099111=+0000000 gives -0000000. Improve training data.

# "ins2" 6-digit Mixed model initialised with 6-digit addition model. Reset useful heads every 100 epochs.
#cfg.model_name = "ins2_mix_d6_l4_h4_t40K_s372001"  # AvgFinalLoss=7e-09. Fails 1M Qs

# "ins3" 6-digit Mixed model initialised with 6-digit addition model. Reset useful heads & MLPs every 100 epochs.
#cfg.model_name = "ins3_mix_d6_l4_h3_t40K_s372001"  # AvgFinalLoss=2.6e-06. Fails 1M Qs

# Part 1B: Configuration: Input and Output file names



In [None]:
# Needed when user changes model_name and reruns this Colab a second time
# Clear useful-node, algorithm and ablation state left over from any previous run before reconfiguring.
cfg.reset_useful()
cfg.reset_algo()
cfg.initialize_maths_token_positions()
acfg.reset_ablate()

if cfg.model_name != "":
  # Update cfg member data n_digits, n_layers, n_heads, n_training_steps from model_name
  cfg.parse_model_name()

  # The model-name prefix decides the subtraction share; anything unmatched is treated as addition-only.
  # Addition model
  cfg.perc_sub = 0
  if cfg.model_name.startswith("sub_") :
    # Subtraction model
    cfg.perc_sub = 100
  elif cfg.model_name.startswith("mix") :
    # Mixed (addition and subtraction) model
    cfg.perc_sub = 66 # Train on 66% subtraction and 33% addition question batches
  elif cfg.model_name.startswith("ins") :
    # Mixed model initialised with an addition model (using insert mode 1, 2 or 3)
    cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches

  # We train multiple versions of some models, inserting different addition models.
  # By default record this model's own training seed as the insert seed.
  cfg.insert_training_seed = cfg.training_seed

  if cfg.model_name.startswith("ins1_mix_d6_l3") :
    if cfg.training_seed == 372001:
      # Mixed model initialised with add_d6_l2_h3_t15K.pth.
      cfg.insert_n_training_steps = 15000
    else:
      # Mixed model initialised with add_d6_l2_h3_t20K.pth.
      cfg.insert_n_training_steps = 20000

  cfg.batch_size = 512 # Default analysis batch size
  if cfg.n_layers >= 3 and cfg.n_heads >= 4:
    cfg.batch_size = 256 # Reduce batch size to avoid memory constraint issues.

In [None]:
# All file names derive from the configuration prefix: the .pth weights file
# (downloaded from HuggingFace) and the JSON files for behavior / algorithm tags.
main_fname = cfg.file_config_prefix()
main_fname_pth = main_fname + '.pth'
main_fname_behavior_json = main_fname + '_behavior.json'
main_fname_algorithm_json = main_fname + '_algorithm.json'

def print_config():
  """Print one line summarising the operation mix, insert mode and model file prefix."""
  summary = ["%Add=", cfg.perc_add(),
             "%Sub=", cfg.perc_sub,
             "%Mult=", cfg.perc_mult,
             "InsertMode=", cfg.insert_mode,
             "File=", main_fname]
  print(*summary)

# Summarise the configuration and where the model / tag files come from and go to.
print_config()
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)
# The weights are downloaded from HuggingFace in Part 5 (message previously misspelt "HuggingLab").
print('Main model will be read from HuggingFace file', main_fname_pth)
print('Main model behavior analysis tags will save to Colab temporary file', main_fname_behavior_json)
print('Main model algorithm analysis tags will save to Colab temporary file', main_fname_algorithm_json)

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

  

In [None]:
# Display the weights file name (a bare final expression is echoed as the cell output).
main_fname_pth

In [None]:
# Register the maths token vocabulary and question token-position meanings on cfg,
# then print cfg.token_position_meanings as a check.
qt.set_maths_vocabulary(cfg)
qt.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

# Part 3B: Set Up: Create model

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
# Build an (untrained) HookedTransformer with the architecture described by cfg; trained weights are loaded in Part 5.
ht_cfg = HookedTransformerConfig(
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp(),
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,  # Same vocabulary for inputs and predicted outputs
    n_ctx = cfg.n_ctx(),
    init_weights = True,
    # Prefer the GPU. Fall back to CPU (slow) instead of failing immediately when no CUDA device is available.
    device = "cuda" if torch.cuda.is_available() else "cpu",
    seed = cfg.training_seed,
)

cfg.main_model = HookedTransformer(ht_cfg)

# Part 4: Set Up: Loss Function & Data Generator
This maths loss function and data generator are imported from QuantaTools as logits_to_tokens_loss, loss_fn, maths_data_generator_core and maths_data_generator.

In [None]:
# Define "iterator" maths "questions" data generator function. Invoked using next().
# Each next(ds) yields one 2-D batch of generated questions (sliced as [:3,:] in the next cell).
ds = qt.maths_data_generator( cfg )

In [None]:
# Generate sample data generator (unit test)
# Print the first 3 rows of one generated batch as a quick sanity check of the generator.
print(next(ds)[:3,:])

# Part 5: Set Up: Load Model from HuggingFace

In [None]:
# Display the weights file name about to be downloaded (echoed as the cell output).
main_fname_pth

In [None]:
# Download the trained weights (.pth) from the HuggingFace repo and load them into the model created in Part 3B.
main_repo_name="PhilipQuirke/VerifiedArithmetic"
print("Loading model from HuggingFace", main_repo_name, main_fname_pth)

cfg.main_model.load_state_dict(utils.download_file_from_hf(repo_name=main_repo_name, file_name=main_fname_pth, force_is_torch=True))
# Put the model in inference (evaluation) mode for analysis.
cfg.main_model.eval()

# Part 7A: Set Up: Create sample maths questions

Create a batch of manually-curated mathematics test questions, and cache some sample model prediction outputs.

In [None]:
# Curated test questions; .shape[0] is used as the number of questions.
varied_questions = qt.make_maths_test_questions_and_answers(cfg)
num_varied_questions = varied_questions.shape[0]

# NOTE(review): presumably installs the ablation hooks and caches mean activations over
# these questions for later mean-ablation experiments - confirm in QuantaTools.
qt.a_set_ablate_hooks(cfg)
qt.a_calc_mean_values(cfg, varied_questions)

In [None]:
# Quick look at the test batch: number of questions, tokens per question, and the first question.
print("Num questions:", num_varied_questions, "Question length:", len(varied_questions[0]))
print("Sample Question:", varied_questions[0])

# Part 7B: Results: Can the model correctly predict sample questions?

Ask the model to predict the varied_questions (without intervention) to see if the model gets them all right. Categorise answers by complexity.

In [None]:
# Test maths question prediction accuracy on the sample questions provided.
# Does NOT use acfg.* or UsefulInfo.* information
# Used to estimate the accuracy of the model's predictions.
# Returns a reduced set of questions - removing questions that the model failed to answer.
print_config()

# Print details of any failed questions for this call only.
acfg.show_test_failures = True
varied_questions = qt.test_maths_questions_by_complexity(cfg, acfg, varied_questions)
acfg.show_test_failures = False

# Recount, as failed questions have been removed.
num_varied_questions = varied_questions.shape[0]

# Part 9 : Results: Can the model do 1 million questions without error?

If the model passes this test, this is evidence (not proof) that the model is fully accurate. There may be very rare edge cases (say 1 in ten million) that did not appear in the test questions. Even if you believe you know all the edge cases, and have enriched the training data to contain them, you may not have thought of all edge cases, so this is not proof.

If the model fails this test:
- Add a few of the failures into the "test questions" into qt.make_maths_test_questions_and_answers()
- Understand the "use case(s)" driving these failures
- Alter the Training CoLab data_generator_core to enrich the training data with examples of these use case(s) and retrain the model.  

Takes ~25 mins to run (successfully) for ins_mix_d6_l3_h4_t40K_s372001

In [None]:
# Takes ~25 minutes to run with num_questions=1000000.
# When not interested in this test, use num_questions=1000 for speed (while still checking the code runs).
# Set num_questions=1000000 to run the full Part 9 test. Failures are printed for this call only.
acfg.show_test_failures = True
qt.test_correctness_on_num_questions(cfg, acfg, num_questions=1000)
acfg.show_test_failures = False

# Part 10: Set Up: Which token positions are used by the model?

Ablate all nodes in each (question and answer) token position (by overriding the model memory aka residual stream). If the model's prediction loss increases, the token position is useful to the algorithm. Unused token positions are excluded from further analysis. Used to populate the UsefulInfo.useful_positions data. This is token **position level** information.  

In [None]:
# One entry per token position: the number of failures when that position is ablated,
# or "." when ablating it causes no failures (rendered by cfg.calc_position_failures_map).
num_failures_list = []

for position in range(cfg.n_ctx()):
  # Test accuracy of model in predicting question answers. Ablates all nodes at acfg.ablate_position. Does NOT use UsefulInfo.* information.
  num_fails = qt.test_maths_questions_by_impact(cfg, acfg, varied_questions, position, True)

  if num_fails > 0:
    # Add position to UsefulInfo.useful_positions
    cfg.add_useful_position(position)
    num_failures_list.append(num_fails)
  else:
    # append, not `+= "."`: += extends the list with the string's characters,
    # which only worked because "." happens to be one character.
    num_failures_list.append(".")

# Part 11: Results: Which token positions are used by the model?

Which token positions are used in the model's predictions? Unused token positions are excluded from further analysis.

In [None]:
print_config()
print("num_questions=", num_varied_questions)
print("useful_positions=", cfg.useful_positions )
print()

# graph_file_suffix is also read by show_quanta_map to decide whether quanta maps are saved to file.
cfg.graph_file_suffix = "svg" # Can be pdf, svg or png
# Plot the per-position failure counts collected in Part 10, save the plot, then show it.
cfg.calc_position_failures_map(num_failures_list)
qt.save_plt_to_file(cfg=cfg, full_title="Failures When Position Ablated")
plt.show()

# Part 12A: Set Up: Which nodes are used by the model?

Here we ablate each (attention head and MLP neuron) node in each (question and answer) token position to see if the model's prediction loss increases. If loss increases then the "node + token position" is used by the algorithm. Used to calculate the UsefulInfo.useful_node_location. This is **position+node level** information.


In [None]:
# Rebuild the useful-node list from scratch for this model.
cfg.useful_nodes = qt.UsefulNodeList()

# NOTE(review): presumably ablates each MLP layer / attention head per position and, via the
# supplied test callback, adds nodes whose ablation causes failures, with their behaviour tags - confirm.
qt.ablate_mlp_and_add_useful_node_tags(cfg, varied_questions, qt.test_maths_questions_and_add_useful_node_tags)
qt.ablate_head_and_add_useful_node_tags(cfg, varied_questions, qt.test_maths_questions_and_add_useful_node_tags)
qt.add_node_attention_tags(cfg, varied_questions)

cfg.useful_nodes.sort_nodes()

# Part 12B: Results: Which nodes are used by the model?

Here are the (attention head and MLP neuron) nodes in each (question and answer) token position used by the model during predictions.

In [None]:
# Print each useful node together with its behaviour tags.
cfg.useful_nodes.print_node_tags()

 # Part 13: Set up: Show and save Quanta map

 Using the UsefulNodes and filtering their tags, show a 2D map of the nodes and the tag minor versions.

In [None]:
def show_quanta_map( title, standard_quanta, cell_num_shades,
        filters : qt.FilterNode, major_tag : qt.QType, minor_tag : str, get_node_details,
        image_width_inches : int = -1, image_height_inches : int = -1,
        combine_identical_cells : bool = True, cell_fontsize : int = 9 ):
  """Draw a 2-D quanta map of the useful nodes' tags, then save and/or title it and show it.

  Args:
    title: map title; used as the saved file title, or appended to the axis title.
    standard_quanta: passed through to qt.calc_quanta_map.
    cell_num_shades: number of colour shades, passed through to qt.calc_quanta_map.
    filters: optional qt.FilterNode; when not None only the matching useful nodes are mapped.
    major_tag: qt.QType whose .value selects the tag category to map.
    minor_tag: minor tag string passed through to qt.calc_quanta_map.
    get_node_details: per-node callback passed to qt.calc_quanta_map (e.g. qt.get_quanta_fail_perc).
    image_width_inches, image_height_inches: figure size passed to qt.calc_quanta_map
      (-1 presumably means "library default" - confirm in QuantaTools).
    combine_identical_cells, cell_fontsize: passed through to qt.calc_quanta_map.

  Nothing is saved or shown when qt.calc_quanta_map reports zero results. When
  cfg.graph_file_suffix is set the map is saved via qt.save_plt_to_file (no axis title is
  added); otherwise a descriptive axis title is set instead.
  """
  test_nodes = cfg.useful_nodes
  if filters is not None:
    test_nodes = qt.filter_nodes(test_nodes, filters)

  ax1, quanta_results, num_results = qt.calc_quanta_map(
      cfg, standard_quanta, cell_num_shades,
      test_nodes, major_tag.value, minor_tag, get_node_details,
      cell_fontsize, combine_identical_cells, image_width_inches, image_height_inches )

  if num_results > 0:
    if cfg.graph_file_suffix:
      print("Saving quanta map:", title)
      qt.save_plt_to_file(cfg=cfg, full_title=title)
    else:
      ax1.set_title(f"{cfg.file_config_prefix()} {title} ({len(quanta_results)} nodes)")

    # Show plot
    plt.show()

# Part 16A: Results: Show failure percentage map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

A cell containing "< 1" may add some risk to the accuracy of the overall analysis process. Check to see if this represents a new use case. Improve the test data set to contain more instances of this (new or existing) use case.

In [None]:
# Trailing positional args: image_width_inches=9, image_height_inches=2*n_layers, combine_identical_cells=False.
show_quanta_map( "Failure Frequency Behavior Per Node", True, qt.FAIL_SHADES, None, qt.QType.FAIL, "", qt.get_quanta_fail_perc, 9, 2 * cfg.n_layers, False)

# Part 16B - Show answer impact behavior map

Show the purpose of each useful cell by impact on the answer digits A0 to Amax.


In [None]:
# cell_num_shades = number of answer positions; image size 9 x (2 * n_layers) inches.
show_quanta_map( "Answer Impact Behavior Per Node", True, cfg.num_answer_positions, None, qt.QType.IMPACT, "", qt.get_quanta_impact, 9, 2 * cfg.n_layers)

# Part 16C: Result: Show attention map

Show attention quanta of useful heads

In [None]:
# Only maps attention heads, not MLP layers
# Image size 9 x (3 * n_layers) inches.
show_quanta_map( "Attention Behavior Per Head", True, qt.ATTN_SHADES, None, qt.QType.ATTN, "", qt.get_quanta_attention, 9, 3 * cfg.n_layers)

# Part 16D - Show question complexity map

Show the "minimum" addition purpose of each useful cell by S0 to S5 quanta.
Show the "minimum" subtraction purpose of each useful cell by M0 to M5 quanta

In [None]:
# Only for models trained on some addition. Image size 9 x 4 inches.
if cfg.perc_add() > 0:
  show_quanta_map( "Addition Min-Complexity Behavior Per Node", False,
      qt.MATH_ADD_SHADES, None, qt.QType.MATH_ADD, qt.MathsBehavior.ADD_COMPLEXITY_PREFIX.value, qt.get_maths_min_complexity, 9, 4)

In [None]:
# Only for models trained on some subtraction. Image size 9 x 6 inches.
# Title typo fixed ("Postive" -> "Positive"); note the title is also used for the saved file name.
if cfg.perc_sub > 0:
  show_quanta_map( "Positive-answer Subtraction Min-Complexity Per Node", False,
      qt.MATH_SUB_SHADES, None, qt.QType.MATH_SUB, qt.MathsBehavior.SUB_COMPLEXITY_PREFIX.value, qt.get_maths_min_complexity, 9, 6)

In [None]:
# Only for models trained on some subtraction. Image size 9 x 6 inches.
if cfg.perc_sub > 0:
  show_quanta_map( "Negative-answer Subtraction Min-Complexity Per Node", False,
      qt.MATH_SUB_SHADES, None, qt.QType.MATH_NEG, qt.MathsBehavior.NEG_COMPLEXITY_PREFIX.value, qt.get_maths_min_complexity, 9, 6)

# Part 19A: Results: Manual interpretation of PCA results

Principal Component Analysis (PCA) is a powerful technique that aids in mechanistic interpretability by simplifying complex datasets into principal components that capture the most significant variance within the data.

This library uses PCA to help understand the purpose of individual useful nodes. For more background refer https://github.com/PhilipQuirke/verified_transformers/blob/main/pca.md

If an attention head and an answer digit An gives an interpretable response (2 or 3 distinct output clusters) on 3 groups of questions aligned to T8, T9 and T10 definitions, then plot the response and add a PCA tag



In [None]:
# Create a cache of sample maths questions based on the T8, T9, T10 categorisation in cfg.tricase_questions_dict
# The automatic PCA search (auto_node_pca) reads this dict keyed by (answer_digit, operation).
qt.make_maths_tricase_questions(cfg)

In [None]:
# Plot all attention heads with the clearest An selected
# Each entry appears to be [position, layer, head, answer_digit] (TODO confirm against qt.manual_nodes_pca).
# Entries should be unique: duplicates just repeat the same plot.
if use_pca:

  if cfg.model_name == "add_d5_l1_h3_t15K_s372001":
    qt.manual_nodes_pca(cfg, qt.MathsToken.PLUS,
      [[ 12, 0, 0, 4 ],
      [ 12, 0, 2, 3 ],
      [ 13, 0, 0, 3 ],
      [ 14, 0, 0, 2 ],
      [ 15, 0, 0, 1 ],
      [ 16, 0, 0, 0 ]])

  if cfg.model_name == "add_d5_l2_h3_t15K_s372001":
    qt.manual_nodes_pca(cfg, qt.MathsToken.PLUS,
      [[10, 0, 0, 2 ],
      [ 12, 0, 0, 3 ],
      [ 12, 1, 0, 3 ],
      [ 12, 1, 1, 4 ],
      [ 12, 1, 2, 4 ],
      [ 13, 0, 0, 0 ],
      [ 13, 0, 0, 3 ],
      [ 13, 1, 2, 2 ],
      [ 14, 0, 0, 0 ],
      [ 14, 0, 0, 2 ],
      [ 14, 1, 2, 2 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 0, 0 ],
      [ 15, 1, 1, 1 ],
      [ 16, 0, 0, 0 ]])

  if cfg.model_name == "add_d6_l2_h3_t15K_s372001":
    qt.manual_nodes_pca(cfg, qt.MathsToken.PLUS,
      [[11, 0, 1, 1 ], # Reasonable?
      [ 11, 0, 0, 2 ],
      [ 12, 0, 0, 3 ],
      [ 13, 0, 0, 1 ],
      [ 14, 0, 0, 4 ],
      [ 14, 1, 1, 4 ],
      [ 15, 0, 0, 4 ],
      [ 15, 1, 1, 4 ],
      [ 15, 1, 2, 4 ],
      [ 16, 0, 0, 3 ],
      [ 16, 1, 1, 3 ],
      [ 17, 0, 0, 2 ],
      [ 17, 1, 1, 2 ],
      [ 18, 0, 0, 0 ],
      [ 18, 0, 0, 1 ],
      [ 19, 0, 0, 0 ]])

  if cfg.model_name.startswith("sub_d6_l2_h3_t30K"):
    qt.manual_nodes_pca(cfg, qt.MathsToken.MINUS,
      [[14, 0, 1, 0 ],
      [ 15, 0, 0, 5 ],
      [ 15, 1, 1, 0 ],
      [ 15, 1, 1, 1 ],
      [ 15, 1, 1, 2 ],
      [ 15, 1, 1, 3 ],
      [ 15, 1, 1, 4 ],
      [ 15, 1, 2, 0 ],
      [ 15, 1, 2, 1 ],
      [ 15, 1, 2, 2 ],
      [ 15, 1, 2, 3 ],
      [ 16, 0, 0, 0 ],
      [ 16, 0, 0, 1 ],
      [ 16, 0, 0, 2 ],
      [ 16, 0, 0, 3 ],
      [ 16, 0, 0, 4 ],
      [ 16, 0, 0, 5 ],
      [ 16, 1, 0, 4 ],
      [ 16, 1, 1, 0 ],
      [ 16, 1, 1, 1 ],
      [ 16, 1, 1, 2 ],
      [ 16, 1, 1, 3 ],
      [ 16, 1, 1, 4 ],
      [ 16, 1, 2, 0 ],
      [ 16, 1, 2, 1 ],
      [ 16, 1, 2, 2 ],
      [ 16, 1, 2, 3 ],
      [ 16, 1, 2, 4 ],
      [ 16, 1, 2, 5 ],
      [ 17, 0, 0, 0 ],
      [ 17, 0, 0, 1 ],
      [ 17, 0, 0, 2 ],
      [ 17, 0, 0, 3 ],
      [ 17, 1, 0, 3 ],
      [ 17, 1, 0, 4 ],
      [ 17, 1, 2, 0 ],
      [ 17, 1, 2, 4 ],
      [ 18, 0, 0, 0 ],
      [ 18, 0, 0, 1 ],
      [ 18, 0, 0, 2 ],
      [ 18, 0, 2, 0 ],
      [ 18, 1, 2, 3 ],
      [ 19, 0, 0, 0 ],
      [ 19, 1, 2, 2 ],
      [ 20, 0, 0, 0 ]])

  if cfg.model_name == "mix_d6_l3_h4_t40K_s372001":
    qt.manual_nodes_pca(cfg, qt.MathsToken.PLUS,
      [[ 8, 0, 0, 4 ],
      [  9, 0, 1, 3 ],
      [ 10, 0, 1, 2 ],
      [ 10, 0, 1, 3 ],
      [ 11, 0, 1, 1 ],
      [ 11, 0, 1, 2 ],
      [ 12, 0, 0, 0 ],
      [ 13, 0, 1, 4 ],
      [ 13, 1, 1, 0 ],
      [ 13, 1, 1, 1 ],
      [ 13, 1, 1, 2 ],
      [ 13, 1, 1, 3 ],
      [ 13, 1, 2, 3 ],
      [ 13, 1, 2, 5 ],
      [ 14, 0, 1, 0 ],
      [ 15, 0, 0, 0 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 0, 2 ],
      [ 15, 0, 0, 3 ],
      [ 15, 0, 0, 4 ],
      [ 15, 0, 0, 5 ],
      [ 15, 1, 0, 0 ],
      [ 15, 1, 0, 1 ],
      [ 15, 1, 0, 2 ],
      [ 15, 1, 0, 3 ],
      [ 15, 1, 0, 4 ]])

  if cfg.model_name == "ins1_mix_d6_l3_h4_t40K_s372001":
    qt.manual_nodes_pca(cfg, qt.MathsToken.PLUS,
      [[10, 0, 0, 2 ],
      [ 10, 0, 1, 2 ],
      [ 11, 0, 0, 2 ],
      [ 11, 0, 1, 1 ],
      [ 13, 0, 1, 0 ],
      [ 13, 1, 3, 1 ],
      [ 14, 1, 2, 0 ],
      [ 14, 1, 2, 2 ],
      [ 14, 1, 3, 4 ],
      [ 15, 0, 3, 5 ],
      [ 15, 1, 2, 2 ],
      [ 15, 1, 3, 4 ],
      [ 16, 0, 3, 4 ],
      [ 16, 1, 2, 0 ],
      [ 16, 1, 2, 1 ],
      [ 16, 1, 2, 2 ],
      [ 16, 1, 3, 2 ],
      [ 17, 0, 3, 3 ],
      [ 17, 1, 2, 2 ],
      [ 17, 1, 3, 2 ],
      [ 18, 0, 3, 2 ],
      [ 18, 1, 3, 1 ],
      [ 19, 0, 3, 1 ],
      [ 19, 2, 0, 0 ],
      [ 19, 2, 1, 0 ],
      [ 20, 0, 0, 0 ],
      [ 20, 0, 3, 0 ]])

    qt.manual_nodes_pca(cfg, qt.MathsToken.MINUS,
      [[10, 0, 0, 2 ],
      [ 10, 0, 1, 2 ],
      [ 11, 0, 0, 2 ],
      [ 11, 0, 1, 1 ],
      [ 13, 0, 1, 0 ],
      [ 14, 0, 0, 4 ],
      [ 14, 0, 2, 5 ],
      [ 14, 1, 2, 0 ],
      [ 14, 1, 2, 2 ],
      [ 14, 1, 3, 4 ]])

  print('Finished generating plots')

else:
  print( "PCA library failed to import. So PCA not done")

# Part 19B: Results: Automatic interpretation of PCA results (Optional)

Part 19A is manual and selective. This part is automatic. It tests the useful attention heads not already PCA-tagged in Part 19A. Where the first (single) principal component explains more than 75% of the variance, it plots the node and adds a "weak" PCA tag (under QType.MATH_ADD or QType.MATH_SUB).

In [None]:
def auto_node_pca(ax, index, node_location, operation, answer_digit, perc_threshold=0.75):
    """Plot and weakly PCA-tag one node for one answer digit when its first principal component dominates.

    Args:
        ax: matplotlib axis to draw the PCA scatter on (only drawn on success).
        index: unused.
        node_location: the node to analyse.
        operation: qt.MathsToken.PLUS or qt.MathsToken.MINUS.
        answer_digit: the answer digit An whose tricase questions are used.
        perc_threshold: the node passes when qt.pca_evr_0_percent(pca) is strictly greater than this.
            NOTE(review): the default 0.75 looks like a fraction, but auto_find_pca passes 75 -
            confirm the scale returned by qt.pca_evr_0_percent.

    Returns:
        True if the node was plotted and tagged, else False (including when no tricase
        questions exist for (answer_digit, operation) or the PCA could not be computed).
    """
    title, error_message = qt.maths_tools._build_title_and_error_message(
        cfg=cfg, node_location=node_location, operation=operation, answer_digit=answer_digit
    )

    question_key = (answer_digit, operation)
    if question_key not in cfg.tricase_questions_dict:
        return False
    test_inputs = cfg.tricase_questions_dict[question_key]

    pca, pca_attn_outputs, title = qt.calc_pca_for_an(
        cfg=cfg, node_location=node_location, title=title, error_message=error_message, test_inputs=test_inputs
    )
    if pca is None:
        return False
    if not (qt.pca_evr_0_percent(pca) > perc_threshold):
        return False

    qt.maths_tools.plot_pca_for_an(ax, pca_attn_outputs, title)

    # Subtraction results are tagged MATH_SUB; the negative-answer (NEG) case is not handled.
    if operation == qt.MathsToken.PLUS:
        major_tag = qt.QType.MATH_ADD
    else:
        major_tag = qt.QType.MATH_SUB
    cfg.add_useful_node_tag( node_location, major_tag.value, qt.maths_tools.pca_op_tag(answer_digit, operation, False) )
    return True

In [None]:
def auto_find_pca_node(node, op, perc_threshold):
  """Try automatic PCA on one node for every answer digit and show the plots that pass.

  Args:
    node: useful node to analyse (passed to auto_node_pca as node_location).
    op: operation token, e.g. qt.MathsToken.PLUS or qt.MathsToken.MINUS.
    perc_threshold: passing threshold forwarded to auto_node_pca.
  """
  n_cols = 4
  n_answer_digits = cfg.n_digits + 1
  # Keep the original 2x4 layout, adding rows if there are more answer digits than 8 slots
  # (the fixed 2x4 grid raised IndexError once an 9th plot was needed).
  n_rows = max(2, (n_answer_digits + n_cols - 1) // n_cols)

  # squeeze=False keeps axs 2-D whatever the grid size.
  fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
  fig.set_figheight(2 * n_rows)
  fig.set_figwidth(10)

  num_plotted = 0
  for answer_digit in range(n_answer_digits):
    # Re-use the same slot until a plot succeeds, so passing plots are packed together.
    ax = axs[num_plotted // n_cols, num_plotted % n_cols]
    if auto_node_pca(ax, num_plotted, node, op, answer_digit, perc_threshold):
      num_plotted += 1

  if num_plotted == 0:
    # Nothing passed the threshold: discard the figure rather than showing an empty one.
    plt.close(fig)
    return

  # Remove any graphs we dont need after all
  for unused in range(num_plotted, n_rows * n_cols):
    axs[unused // n_cols, unused % n_cols].remove()

  plt.tight_layout()
  plt.show()

In [None]:
def auto_find_pca(operation):
  """Run the automatic (weak) PCA search over all useful attention heads for one operation.

  Heads that already have a PCA tag for this operation (e.g. added manually in Part 19A)
  are skipped, as are MLP nodes.

  Args:
    operation: qt.MathsToken.PLUS or qt.MathsToken.MINUS.
  """
  print("Automatic (weak) PCA tags for", cfg.model_name, "with operation", qt.token_to_char(cfg, operation))
  # Compared with qt.pca_evr_0_percent(...) inside auto_node_pca.
  perc_threshold = 75

  # Loop-invariant: the tags to check depend only on the operation, not on the node.
  is_plus = operation == qt.MathsToken.PLUS
  major_tag = qt.QType.MATH_ADD if is_plus else qt.QType.MATH_SUB # Does not handle NEG case
  minor_tag_prefix = qt.MathsBehavior.ADD_PCA_TAG if is_plus else qt.MathsBehavior.SUB_PCA_TAG

  for node in cfg.useful_nodes.nodes:
    # Exclude nodes with a (manual) PCA tag - for any answer digit(s)). Exclude MLP neurons.
    if node.is_head and not node.contains_tag(major_tag.value, minor_tag_prefix.value):
      print( "Doing PCA on node", node.name())
      auto_find_pca_node(node, operation, perc_threshold)

# Run the automatic PCA search for each operation the model was trained on (needs the PCA import from Part 0).
#if False: # Suppress for speed
if use_pca:
  if cfg.perc_add() > 0:
    auto_find_pca(qt.MathsToken.PLUS)
  if cfg.perc_sub > 0:
    auto_find_pca(qt.MathsToken.MINUS)

# Part 20A: Results: Show useful nodes and behaviour tags

In [None]:
# Re-sort (the PCA steps may have added tags) and print every useful node with its behaviour tags.
cfg.useful_nodes.sort_nodes()
cfg.useful_nodes.print_node_tags()

# Part 20B: Results: Save useful nodes and behaviour tags to json file

In [None]:
# Serialize and save the useful nodes list to a temporary CoLab file in JSON format
# Download the file from the Colab file pane if it is needed after the session ends.
print( "Saving useful node list with behavior tags:", main_fname_behavior_json)
cfg.useful_nodes.save_nodes(main_fname_behavior_json)

# Part 22 : Results: Search for model algorithm tasks

Here we find which model nodes perform which specific algorithm task.
- **Automatic searches** for node purposes are preferred, as they are applicable to several models, and survive (non-significant, node-reordering) differences between models caused by differences in training.
- **Manually written tests** of node purposes, specific to a single model instance are also supported.

The qt.search_and_tag searches for a task on useful nodes by:
- **filtering** useful nodes, based on "tag" pre-requisites, to find the few nodes worth investigating. For more detail refer https://github.com/PhilipQuirke/verified_transformers/blob/main/filter.md
- **intervention ablation** testing on the interesting nodes:
  - The first "store" question is run without hooks
  - The second "clean" question is run with hooks interjecting some data from the "store" run. This run gives, not a "clean" answer, but an "intervened" answer, which mixes the "store" answer and the "clean" answer. Our beliefs about the nodes' algorithmic purpose are baked into the store question, clean question and intervened answer.
- An **algorithm tag** is added to all interesting nodes that pass the intervention ablation test(s)

In [None]:
# Clear algorithm (ALGO) tags from any previous run and keep per-question test output quiet for the Part 22 searches.
cfg.useful_nodes.reset_node_tags(qt.QType.ALGO.value)
acfg.show_test_failures = False
acfg.show_test_successes = False

## Part 22B: Automated An.SS search

Search for addition "Use Sum 9" (SS) tasks e.g. 34633+55555=+090188 where D3 and D'3 sum to 9 (4+5), and D2 + D'2 >= 10 (6+5).

In [None]:
# Filter candidates with add_ss_prereqs, confirm by the add_ss_test intervention, and tag passing nodes with add_ss_tag.
qt.search_and_tag( cfg, acfg, qt.add_ss_prereqs, qt.add_ss_test, qt.add_ss_tag)

## Part 22C: Automated An.SC search

Search for addition "Make Carry 1" (SC) tasks e.g. 222222+666966=+0889188 where D2 + D'2 > 10

In [None]:
# Filter candidates with add_sc_prereqs, confirm by the add_sc_test intervention, and tag passing nodes with add_sc_tag.
qt.search_and_tag( cfg, acfg, qt.add_sc_prereqs, qt.add_sc_test, qt.add_sc_tag)

## Part 22D: Automated An.SA search

Search for addition "Simple Add" (SA) tasks e.g. 555555+111111=+0666666 where D3 + D'3 < 10

The SA task is sometimes split/shared over 2 attention heads in the same position and layer.


In [None]:
# do_pair_search: also consider pairs of heads, as SA may be split over 2 heads in the same position and layer.
qt.search_and_tag( cfg, acfg, qt.add_sa_prereqs, qt.add_sa_test, qt.add_sa_tag, do_pair_search = True, allow_impact_mismatch = True )

## Part 22E: Automated Dn.ST search

Search for D0.ST to D5.ST with impact "A65432" to "A65" in early tokens.

A0 and A1 are simple to calculate and so do NOT use Dn.ST or Dn.STm values. So A0 and A1 are excluded from the answer impact.

In [None]:
# Clear only previous ST algorithm tags, then search for Dn.ST nodes.
cfg.useful_nodes.reset_node_tags(qt.QType.ALGO.value, qt.MathsTask.ST_TAG.value)
qt.search_and_tag( cfg, acfg, qt.add_st_prereqs, qt.add_st_test, qt.add_st_tag)

## Part 22F: Automated An.MD search

Search for positive-answer subtraction "Difference" (MD) tasks e.g. 666666-222222=+0444444 where D3 >= D'3

The MD task may be split/shared over 2 attention heads in the same position at the same layer.


In [None]:
# Subtraction-capable models only. MD may be split over 2 heads, hence do_pair_search.
if cfg.perc_sub > 0:
  qt.search_and_tag( cfg, acfg, qt.sub_md_prereqs, qt.sub_md_test, qt.sub_md_tag, do_pair_search = True, allow_impact_mismatch = True )

## Part 22G: Automated An.MB search

Search for positive-answer subtraction "Borrow One" (MB) tasks e.g. 222222-111311=+0110911 where D2 < D'2 (2 < 3)

In [None]:
# Subtraction-capable models only; uses the same pair-search options as the MD search.
if cfg.perc_sub > 0:
  qt.search_and_tag( cfg, acfg, qt.sub_mb_prereqs, qt.sub_mb_test, qt.sub_mb_tag, do_pair_search = True, allow_impact_mismatch = True )

## Part 22H: Automated Dn.MT search

To accurately predict if the answer sign is + or - the model must calculate if
D < D'. To calculate this, the model must calculate Dn < D'n or (Dn = D'n and (Dn-1 < D'n-1 or (Dn-1 = D'n-1 and (Dn-2 < D'n-2 or ...)))). It must calculate this by the time it predicts the answer sign (the first answer token).

Aligned to the Addition calculations, we assume this calculation is based on "TriCase" data. We search for Dn.MT task nodes which are useful in negative-answer questions, with PCA bigram or trigram outputs, in early token positions, attending to Dn, D'n input pairs.

(The combined calculation that gives D < D' is NOT tested here.)

In [None]:
# Subtraction-capable models only. Clear previous MT tags, then search for Dn.MT nodes.
if cfg.perc_sub > 0:
  cfg.useful_nodes.reset_node_tags(qt.QType.ALGO.value, qt.MathsTask.MT_TAG.value)
  # Verbose output for this search only: restore the quiet Part 22 defaults afterwards (even on error)
  # so the later OPR / SGN / ND searches are not verbose too.
  acfg.show_test_failures = True
  acfg.show_test_successes = True
  try:
    qt.search_and_tag( cfg, acfg, qt.sub_mt_prereqs, qt.sub_mt_test, qt.sub_mt_tag)
  finally:
    acfg.show_test_failures = False
    acfg.show_test_successes = False

## Part 22I: Automated OPR search

For mixed models that do addition and subtraction the operation token "+/-" (in the middle of the question) is key. Find nodes that attend to the question operation.

In [None]:
# NOTE(review): qt.succeed_test is presumably an always-pass test, so tagging relies on opr_prereqs alone - confirm.
if cfg.perc_sub > 0:
  qt.search_and_tag( cfg, acfg, qt.opr_prereqs, qt.succeed_test, qt.opr_tag)

## Part 22J: Automated SGN search

For mixed models that do addition and subtraction, and for our subtraction models, the answer sign token "+/-" (at the start of the answer) is important. Find nodes that attend to the answer sign token.

In [None]:
# NOTE(review): as for OPR, qt.succeed_test presumably always passes, so sgn_prereqs decides the tagging - confirm.
if cfg.perc_sub > 0:
  qt.search_and_tag( cfg, acfg, qt.sgn_prereqs, qt.succeed_test, qt.sgn_tag)

## Part 22K: Automated Dn.ND search

Search for negative-answer subtraction Difference (ND) tasks e.g. 033333-111111=-0077778 where D < D'

The ND task may be split/shared over 2 attention heads in the same position at the same layer.

In [None]:
def neg_nd_tag(impact_digit):
    """Return the ND algorithm tag for an answer digit: '<answer name>.<ND tag>'."""
    return ".".join((qt.answer_name(impact_digit), qt.MathsTask.ND_TAG.value))

In [None]:
# These rules are prerequisites for (not proof of) a Neg Difference node
def neg_nd_prereqs(cfg, position, impact_digit):
  # Impacts An and pays attention to Dn and D'n
  return qt.math_common_prereqs(cfg, position, impact_digit, impact_digit)

In [None]:
def neg_nd_test1(cfg, acfg, alter_digit):
    """Build the first (store, clean, expected) triple for the negative-answer ND test.

    Both questions have negative answers (D < D'), so the model outputs -(D' - D).
    The ND value of answer digit n is the borrow-free difference (D'n - Dn) % 10.
    Patching the store question's ND value into the clean run shifts the
    magnitude of the clean answer by (store ND - clean ND) * 10**alter_digit.

    Args:
      cfg: config; uses cfg.repeat_digit(d) and cfg.n_digits.
      acfg: unused; kept so all test builders share one signature.
      alter_digit: index n of the answer digit whose ND value is patched.

    Returns:
      (store_question, clean_question, intervened_answer), where each question
      is a [D, D'] pair of ints and intervened_answer is the expected signed answer.
    """
    # Store: 033333 - 111111 = -0077778
    store_question = [cfg.repeat_digit(3), cfg.repeat_digit(1)]
    store_question[0] = store_question[0] // 10  # Convert 333333 to 033333

    # Clean: 099999 - 444444 = -0344445
    clean_question = [cfg.repeat_digit(9), cfg.repeat_digit(4)]
    clean_question[0] = clean_question[0] // 10  # Convert 999999 to 099999

    def nd_value(question):
        # (D'n - Dn) % 10 for n == alter_digit, ignoring any borrow.
        place = 10 ** alter_digit
        return ((question[1] // place) % 10 - (question[0] // place) % 10) % 10

    # Below the top digit the shift is 8 - 5 = 3, as before. The top digit has
    # Dn = 0 in both questions, so its ND values are 1 (store) vs 4 (clean). The
    # old hard-coded (7-4) got that case wrong. In both questions every digit
    # above A0 receives a borrow and no digit wraps, so the shift is linear.
    nd_shift = nd_value(store_question) - nd_value(clean_question)
    intervened_answer = clean_question[0] - clean_question[1] - nd_shift * 10 ** alter_digit

    # Unit test
    if cfg.n_digits == 6 and alter_digit == 3:
        assert store_question[0] == 33333
        assert clean_question[0] == 99999
        assert clean_question[0] - clean_question[1] == -344445
        assert intervened_answer == -347445
    if cfg.n_digits == 6 and alter_digit == 5:
        assert intervened_answer == -44445

    return store_question, clean_question, intervened_answer


def neg_nd_test2(cfg, acfg, alter_digit):
    """Build the second (store, clean, expected) triple for the negative-answer ND test.

    Both questions have negative answers (D < D'), so the model outputs -(D' - D).
    The ND value of answer digit n is the borrow-free difference (D'n - Dn) % 10.
    Patching the store question's ND value into the clean run shifts the
    magnitude of the clean answer by (store ND - clean ND) * 10**alter_digit.

    Args:
      cfg: config; uses cfg.repeat_digit(d).
      acfg: unused; kept so all test builders share one signature.
      alter_digit: index n of the answer digit whose ND value is patched.

    Returns:
      (store_question, clean_question, intervened_answer), where each question
      is a [D, D'] pair of ints and intervened_answer is the expected signed answer.
    """
    # Store: 066666 - 222222 = -0155556
    store_question = [cfg.repeat_digit(6), cfg.repeat_digit(2)]
    store_question[0] = store_question[0] // 10  # Remove top digit

    # Clean: 099999 - 333333 = -0233334
    clean_question = [cfg.repeat_digit(9), cfg.repeat_digit(3)]
    clean_question[0] = clean_question[0] // 10  # Remove top digit

    def nd_value(question):
        # (D'n - Dn) % 10 for n == alter_digit, ignoring any borrow.
        place = 10 ** alter_digit
        return ((question[1] // place) % 10 - (question[0] // place) % 10) % 10

    # e.g. alter_digit=3: ND 6 (store) vs 4 (clean) -> expect -0235334.
    #      alter_digit=5 (top digit): ND 2 vs 3 -> expect -0133334.
    nd_shift = nd_value(store_question) - nd_value(clean_question)
    intervened_answer = clean_question[0] - clean_question[1] - nd_shift * 10 ** alter_digit

    return store_question, clean_question, intervened_answer


def neg_nd_test(cfg, acfg, node_locations, alter_digit, strong):
    """Intervention test for the negative-answer subtraction Difference (ND) task.

    Runs the two store/clean interventions built by neg_nd_test1 and
    neg_nd_test2 on the candidate nodes. Both must succeed.

    Args:
      cfg, acfg: model and algorithm configs passed through to qt.
      node_locations: candidate node(s) to intervene on.
      alter_digit: answer digit An whose ND value is patched.
      strong: if True, use the first return value of
        qt.run_strong_intervention (full answer match). Otherwise use the
        third return value (impact result).

    Returns:
      True if both interventions succeed under the chosen criterion.
    """
    intervention_impact = qt.answer_name(alter_digit)

    store_question, clean_question, intervened_answer = neg_nd_test1(cfg, acfg, alter_digit)
    success1, _, impact_success1 = qt.run_strong_intervention(cfg, acfg, node_locations, store_question, clean_question, qt.MathsToken.MINUS, intervention_impact, intervened_answer)

    store_question, clean_question, intervened_answer = neg_nd_test2(cfg, acfg, alter_digit)
    success2, _, impact_success2 = qt.run_strong_intervention(cfg, acfg, node_locations, store_question, clean_question, qt.MathsToken.MINUS, intervention_impact, intervened_answer)

    success = (success1 and success2) if strong else (impact_success1 and impact_success2)

    if success:
        # For a negative answer the model outputs -(D' - D). Its borrow-free
        # digit difference is (D'n - Dn) % 10, not the addition formula (Dn + D'n) % 10.
        print("Test confirmed", acfg.node_names(), "perform A"+str(alter_digit)+".ND = (D'"+str(alter_digit)+" - D"+str(alter_digit)+") % 10 impacting "+intervention_impact+" accuracy.", "" if strong else "Weak")

    return success

In [None]:
if cfg.perc_sub > 0:
    # Negative-answer Difference (ND) search. Pair search lets the task be
    # shared across two heads.
    qt.search_and_tag(
        cfg, acfg,
        neg_nd_prereqs, neg_nd_test, neg_nd_tag,
        do_pair_search=True,
        allow_impact_mismatch=True,
    )

## Part 22L: Automated Dn.NB search

Search for negative-answer subtraction "Borrow One" (NB) tasks e.g. 033333-111411=-0078078 where D < D' and D2 < D'2

The NB task may be split/shared over 2 attention heads in the same position at the same layer.

In [None]:
def neg_nb_tag(impact_digit):
    """Return the NB algorithm tag for an answer digit: '<answer name>.<NB tag>'."""
    return ".".join((qt.answer_name(impact_digit), qt.MathsTask.NB_TAG.value))

In [None]:
def neg_nb_prereqs(cfg, position, impact_digit):
    """Prerequisites for the negative-answer subtraction "Borrow One" (NB) task.

    The node must pay attention to Dn-1 and D'n-1 and impact An, where
    n == impact_digit.
    """
    attended_digit = impact_digit - 1
    return qt.math_common_prereqs(cfg, position, attended_digit, impact_digit)

In [None]:
def neg_nb_test(cfg, acfg, node_locations, impact_digit, strong):
    """Intervention ablation test for the negative-answer subtraction "Borrow One" (NB) task.

    The candidate node attends to digit alter_digit = impact_digit - 1 and should
    impact answer digit A[impact_digit]. In the store question D' - D does not
    borrow out of digit alter_digit. In the clean question it does (for any
    alter_digit below the top digit). Patching store activations into the clean
    run should therefore make the negative clean answer more negative by
    10 ** impact_digit.

    Args:
      cfg, acfg: model and algorithm configs passed through to qt.
      node_locations: candidate node(s) to intervene on.
      impact_digit: answer digit An expected to be impacted.
      strong: if True, use the first return value of
        qt.run_strong_intervention (full answer match). Otherwise use the
        third return value (impact result), as neg_nd_test does.

    Returns:
      True if the test succeeds under the chosen criterion. If alter_digit is
      outside 0..n_digits-1, calls acfg.reset_intervention() and returns False.
    """
    alter_digit = impact_digit - 1

    if alter_digit < 0 or alter_digit >= cfg.n_digits:
        acfg.reset_intervention()
        return False

    intervention_impact = qt.answer_name(impact_digit)

    # Store: 022222 - 111311 = -0089089 (shown for alter_digit=2).
    # D'[alter_digit]=3 > D[alter_digit]=2, so D' - D does not borrow out of that digit.
    store_question = [cfg.repeat_digit(2), cfg.repeat_digit(1)]
    store_question[0] = store_question[0] // 10  # Convert 222222 to 022222
    store_question[1] += (3 - 1) * (10 ** alter_digit)  # Raise D'[alter_digit] from 1 to 3

    # Clean: 077777 - 444444 = -0366667.
    # D'n=4 < Dn=7 at every digit below the top, so D' - D borrows out of each of them.
    clean_question = [cfg.repeat_digit(7), cfg.repeat_digit(4)]
    clean_question[0] = clean_question[0] // 10  # Convert 777777 to 077777

    # Without the borrow into A[impact_digit] the magnitude grows by 10 ** impact_digit,
    # e.g. -0367667 for alter_digit=2 (-0366677 for alter_digit=0).
    intervened_answer = clean_question[0] - clean_question[1] - 10 ** (alter_digit + 1)

    strong_success, _, impact_success = qt.run_strong_intervention(
        cfg, acfg, node_locations, store_question, clean_question,
        qt.MathsToken.MINUS, intervention_impact, intervened_answer)

    # Honour `strong`. Previously the strong result was always used, even
    # though the message below is labelled "Weak" when strong is False.
    success = strong_success if strong else impact_success

    if success:
        print( "Test confirmed", acfg.node_names(), "perform A"+str(alter_digit)+".NB impacting "+intervention_impact+" accuracy.", "" if strong else "Weak")

    return success

In [None]:
if cfg.perc_sub > 0:
    # Reset existing NB algorithm tags, then search for NB nodes (single heads or pairs).
    cfg.useful_nodes.reset_node_tags(qt.QType.ALGO.value, qt.MathsTask.NB_TAG.value)
    qt.search_and_tag(
        cfg, acfg,
        neg_nb_prereqs, neg_nb_test, neg_nb_tag,
        do_pair_search=True,
        allow_impact_mismatch=True,
    )

# Part 23A: Show algorithm quanta map

Plot the "algorithm" tags generated in previous steps as a quanta map. This is an automatically generated partial explanation of the model algorithm.

Nodes with multiple tags were tagged (found) by more than one of the above task searches.

In [None]:
# File format (suffix) used for graphs saved by the following cells.
# NOTE(review): the old comment "# Else pdf" contradicted itself; confirm which
# other suffixes qt supports before changing this value.
cfg.graph_file_suffix = "pdf"

In [None]:
# Text summary of algorithm purposes, then a blank line.
qt.print_algo_purpose_results(cfg)
print()

# Quanta map of the ALGO tags, one cell per node.
show_quanta_map(
    "Algorithm Purpose Per Node",
    True,
    2,
    None,
    qt.QType.ALGO,
    "",
    qt.get_quanta_binary,
    9,
)

In [None]:
# Quanta map of attention behaviour per head, shaded with qt.ATTN_SHADES.
show_quanta_map(
    "Attention Behavior Per Head",
    True,
    qt.ATTN_SHADES,
    None,
    qt.QType.ATTN,
    "",
    qt.get_quanta_attention,
    9,
    3 * cfg.n_layers,
)

# Part 23B: Show known quanta per answer digit

Each late position is solely focused on calculating one answer digit. Show the data we have collected on each late answer digit.



In [None]:
# Walk the late token positions: from num_question_positions + 1 up to
# (but excluding) n_ctx() - 1. Plot and save a quanta table for each one.
for position in range(cfg.num_question_positions + 1, cfg.n_ctx() - 1):
    print("Position:", position)
    qt.calc_maths_quanta_for_position_nodes(cfg, position)
    qt.save_plt_to_file(
        cfg=cfg,
        full_title="Quanta At " + qt.position_name(position),
    )
    plt.show()

# Part 24: Save useful nodes with behaviour and algorithm tags to JSON file

Show a list of the nodes that have proved useful in calculations, together with data on the nodes behavior and algorithmic purposes.
Save the data to a Colab temporary JSON file.



In [None]:
# List every useful node's ALGO tags (no tag filter: empty string).
cfg.useful_nodes.print_node_tags(
    qt.QType.ALGO.value,
    "",
    False,
)

In [None]:
# Persist the useful nodes (including algorithm tags) to a temporary JSON file.
print(
    "Saving useful node list with algorithm tags:",
    main_fname_algorithm_json,
)
cfg.useful_nodes.save_nodes(main_fname_algorithm_json)

In [None]:
# True for the mixed addition/subtraction models checked in Part 27.
mixed_model = cfg.model_name.startswith((
    "ins1_mix_d6_l3_h4_t40K",
    "ins2_mix_d6_l4_h4_t40K",
    "ins3_mix_d6_l4_h3_t40K",
))

# Part 25 : Results: Test Algorithm - Addition

## Part25A : Model add_d5_l1_h3_t30K. Tasks An.SA, An.SC, An.SS

This 1-layer model can't do all addition questions. This hypothesis mirrors Paper 1. 14/15 heads have purpose assigned. 0/6 neurons have purpose assigned.

In [None]:
def test_algo_sa(algo_nodes):
    """Require an An.SA node for every answer digit (excluding Amax).

    Each node must sit at or before cfg.an_to_position_name(n + 1)
    (QCondition.MAX), i.e. before answer digit An is revealed.
    """
    for digit in range(cfg.n_digits):
        algo_filter = qt.FilterAlgo(qt.add_sa_tag(digit))
        position_filter = qt.FilterPosition(cfg.an_to_position_name(digit + 1), qt.QCondition.MAX)
        cfg.test_algo_clause(algo_nodes, qt.FilterAnd(algo_filter, position_filter))

def test_algo_sc(algo_nodes):
    """Require an An.SC node for every answer digit except A0 (and Amax).

    Each node must sit at or before cfg.an_to_position_name(n + 1)
    (QCondition.MAX), i.e. before answer digit An is revealed.
    """
    for digit in range(1, cfg.n_digits):
        algo_filter = qt.FilterAlgo(qt.add_sc_tag(digit))
        position_filter = qt.FilterPosition(cfg.an_to_position_name(digit + 1), qt.QCondition.MAX)
        cfg.test_algo_clause(algo_nodes, qt.FilterAnd(algo_filter, position_filter))

def test_algo_ss(algo_nodes):
    """Require an An.SS node for every answer digit except A0 and A1 (and Amax).

    Each node must sit at or before cfg.an_to_position_name(n + 1)
    (QCondition.MAX), i.e. before answer digit An is revealed.
    """
    for digit in range(2, cfg.n_digits):
        algo_filter = qt.FilterAlgo(qt.add_ss_tag(digit))
        position_filter = qt.FilterPosition(cfg.an_to_position_name(digit + 1), qt.QCondition.MAX)
        cfg.test_algo_clause(algo_nodes, qt.FilterAnd(algo_filter, position_filter))

In [None]:
if cfg.model_name == "add_d5_l1_h3_t30K":
    # Paper-1 style 1-layer addition hypothesis: SA, SC and SS clauses.
    algo_nodes = cfg.start_algorithm_test(acfg)

    test_algo_sa(algo_nodes)
    test_algo_sc(algo_nodes)
    test_algo_ss(algo_nodes)
    cfg.print_algo_clause_results()

    cfg.print_algo_purpose_results(algo_nodes)

## Part25B : Models add_d5/d6_l2_h3_t15K. Tasks An.SA, An.SC, An.SS, An.AC

These 2 and 3-layer models can do addition accurately.

- add_d6_l2_h3_t15K: 21/31 heads have purpose assigned. 0/17 neurons have purpose assigned.

In [None]:
def test_algo_ac_or_ss(model_nodes):
    """Test the carry-related addition clauses (An.ST and An.SS) for every answer digit.

    Per answer digit n, three filters are built:
      * early_st: an An.ST node at or before cfg.an_to_position_name(n_digits + 1),
        i.e. before Amax (0 or 1) is revealed.
      * late_st: an An.ST node at or before cfg.an_to_position_name(n + 1).
      * any_ss: an An.SS node at or before cfg.an_to_position_name(n + 1).

    1-layer models: require any_ss for every digit except A0.
    Deeper models: require any_ss OR late_st for every digit except A0, and
    early_st for every digit.
    """
    for digit in range(cfg.n_digits):
        early_st = qt.FilterAnd(
            qt.FilterAlgo(qt.add_st_tag(digit)),
            qt.FilterPosition(cfg.an_to_position_name(cfg.n_digits + 1), qt.QCondition.MAX))
        late_st = qt.FilterAnd(
            qt.FilterAlgo(qt.add_st_tag(digit)),
            qt.FilterPosition(cfg.an_to_position_name(digit + 1), qt.QCondition.MAX))
        any_ss = qt.FilterAnd(
            qt.FilterAlgo(qt.add_ss_tag(digit)),
            qt.FilterPosition(cfg.an_to_position_name(digit + 1), qt.QCondition.MAX))

        if cfg.n_layers == 1:
            if digit > 0:
                cfg.test_algo_clause(model_nodes, any_ss)
            continue

        if digit > 0:
            cfg.test_algo_clause(model_nodes, qt.FilterOr(any_ss, late_st))
        cfg.test_algo_clause(model_nodes, early_st)

In [None]:
if cfg.model_name in ("add_d5_l2_h3_t15K", "add_d6_l2_h3_t15K"):
    # 2-layer addition hypothesis: SA and SC clauses, plus ST-or-SS carry clauses.
    algo_nodes = cfg.start_algorithm_test(acfg)

    test_algo_sa(algo_nodes)
    test_algo_sc(algo_nodes)
    test_algo_ac_or_ss(algo_nodes)
    cfg.print_algo_clause_results()

    cfg.print_algo_purpose_results(algo_nodes)

# Part 26: Results: Test Algorithm - Subtraction

## Part 26A : Model sub_d6_l2_h3_t30K. Tasks MD, MB, MZ

This 2-layer model can do subtraction accurately. TBC

In [None]:
def test_algo_md_mb(algo_nodes):
    """Require An.MD and then An.MB nodes for every answer digit (excluding Amax).

    Each node must sit at or before cfg.an_to_position_name(n + 1)
    (QCondition.MAX), i.e. before answer digit An is revealed.
    """
    for digit in range(cfg.n_digits):
        for tag_of in (qt.sub_md_tag, qt.sub_mb_tag):
            cfg.test_algo_clause(
                algo_nodes,
                qt.FilterAnd(
                    qt.FilterAlgo(tag_of(digit)),
                    qt.FilterPosition(cfg.an_to_position_name(digit + 1), qt.QCondition.MAX),
                ),
            )


def test_algo_mz(algo_nodes):
    """Placeholder for the An.MZ clauses; currently asserts nothing.

    Intended rule: for answer digits (excluding Amax), an An.MZ node is needed
    before the answer digit is revealed.
    """
    for _ in range(cfg.n_digits):
        # TODO: add the An.MZ clause once an MZ tag helper exists. The draft was:
        #   cfg.test_algo_clause(algo_nodes, qt.FilterAnd(qt.FilterAlgo(sub_mz_tag(impact_digit)),
        #       qt.FilterPosition(cfg.an_to_position_name(impact_digit+1), qt.QCondition.MAX)))
        continue

In [None]:
if cfg.model_name == "sub_d6_l2_h3_t30K":
    # Subtraction hypothesis: MD and MB clauses (MZ is still a placeholder).
    algo_nodes = cfg.start_algorithm_test(acfg)

    test_algo_md_mb(algo_nodes)
    test_algo_mz(algo_nodes)
    cfg.print_algo_clause_results()

    cfg.print_algo_purpose_results(algo_nodes)

## Part 26B: Test Algorithm - Subtraction - Negative Answer

To accurately predict whether the answer sign is + or -, the model must calculate whether
D < D'. To calculate this, the model must evaluate, from the highest digit down: Dn < D'n, or (Dn = D'n and (Dn-1 < D'n-1, or (Dn-1 = D'n-1 and (Dn-2 < D'n-2, or ...)))). It must predict this before the answer sign is revealed.

We expect to see nodes useful in negative answer questions, with PCA bigram (or trigram) outputs, attending to these input pairs, evaluated in this order, before the answer sign is revealed.

In [None]:
def test_algo_sc(algo_nodes):
    """Test the An.MT (D < D' comparison) nodes of the subtraction sign calculation.

    Despite the name and the "SC" labels in its output, the clauses use
    qt.sub_mt_tag. This definition replaces the addition test_algo_sc defined
    earlier in the notebook for every cell run after this one.

    1. For every answer digit n, require an An.MT node at or before the answer
       sign position (cfg.an_to_position_name(n_digits + 1), QCondition.MAX).
    2. Require the positions returned by cfg.test_algo_clause to strictly
       decrease with digit: position[n] < position[n-1], so higher digits are
       compared earlier. Both positions must also be >= 0.
    """
    sign_position = cfg.an_to_position_name(cfg.n_digits + 1)

    # NOTE(review): assumes test_algo_clause returns the matching node's position
    # (a negative value when nothing matched) - confirm in qt.
    mt_positions = {
        digit: cfg.test_algo_clause(algo_nodes, qt.FilterAnd(
            qt.FilterAlgo(qt.sub_mt_tag(digit)),
            qt.FilterPosition(sign_position, qt.QCondition.MAX)))
        for digit in range(cfg.n_digits)
    }

    print("SC Locations:", mt_positions)
    for digit in range(1, cfg.n_digits):
        higher = mt_positions[digit]
        lower = mt_positions[digit - 1]
        cfg.test_algo_logic(
            f"SC Ordering: A{digit}={higher}, A{digit-1}={lower}",
            higher >= 0 and lower >= 0 and higher < lower)

In [None]:
if cfg.model_name.startswith(("sub_d6_l2_h3_t30K", 'ins1_mix_d6_l3_h4_t40K')):
    # Negative-answer (D < D') ordering test.
    algo_nodes = cfg.start_algorithm_test(acfg)

    test_algo_sc(algo_nodes)
    cfg.print_algo_clause_results()

    cfg.print_algo_purpose_results(algo_nodes)

# Part 27: Test Algorithm - Mixed Addition and Subtraction model

What algorithm do mixed models use to perform both addition and subtraction? Our working hypothesis is Hypothesis 2 as described in https://github.com/PhilipQuirke/verified_transformers/blob/main/mixed_model.md





## Part 27A: Calculating answer digit A2 in token position A3

The graph below displays the same (behavior and algorithm) data as the quanta maps. Refer to https://github.com/PhilipQuirke/verified_transformers/blob/main/mixed_model.md section 27A for more explanation.

    


In [None]:
# Quanta table for token position 18 (hard-coded; this section analyses the
# token position that calculates answer digit A2).
qt.calc_maths_quanta_for_position_nodes(cfg, 18)
qt.save_plt_to_file(
    cfg=cfg,
    full_title="Quanta At " + qt.position_name(18),
)
plt.show()

In [None]:
if cfg.model_name.startswith("ins1_mix_d6_l3_h4_t40K"):
    # PCA of hand-picked position-18 nodes for addition (PLUS) questions.
    # Each row is a node location in the format qt.manual_nodes_pca expects
    # (presumably [position, layer, head/neuron index, ...] - TODO confirm).
    qt.manual_nodes_pca(
        cfg,
        qt.MathsToken.PLUS,
        [
            [18, 0, 3, 2],
            [18, 0, 0, 1],
            [18, 1, 0, 2],
            [18, 1, 0, 1],
            [18, 1, 1, 1],
            [18, 1, 2, 1],
            [18, 1, 3, 1],
        ],
    )

In [None]:
if cfg.model_name == "ins1_mix_d6_l3_h4_t40K_s372001":
    # PCA of hand-picked position-18 nodes for subtraction (MINUS) questions.
    # Note: this matches one exact seed, unlike the prefix match in the PLUS cell.
    qt.manual_nodes_pca(
        cfg,
        qt.MathsToken.MINUS,
        [
            [18, 0, 3, 3],
            [18, 0, 3, 2],
            [18, 0, 0, 1],
            [18, 0, 3, 1],
            [18, 1, 0, 1],
            [18, 1, 1, 1],
            [18, 1, 2, 1],
            [18, 1, 3, 1],
        ],
    )

## Part 27.H3 Calculate whether D > D' (using NG tasks)

In [None]:
# Restrict the maps to nodes matched by a MATH_NEG "contains" filter with an
# empty string (presumably: any node carrying a MATH_NEG quanta - TODO confirm).
filters = qt.FilterContains(qt.QType.MATH_NEG, "")

# Debug aid:
# print("NG tagged nodes:", qt.filter_nodes(cfg.useful_nodes, filters).get_node_names())

show_quanta_map("Subtraction Behavior NG Nodes", False, 2, filters,
                qt.QType.MATH_SUB, "", qt.get_maths_min_complexity, 9, 6)
show_quanta_map("Attention Behavior Per NG Head", True, 10, filters,
                qt.QType.ATTN, "", qt.get_quanta_attention, 9, 8)
show_quanta_map("Algorithm Purpose Per NG Node", True, 2, filters,
                qt.QType.ALGO, "", qt.get_quanta_binary, 9, 6)

In [None]:
if mixed_model:
    # Run the negative-answer ordering test on the mixed models.
    algo_nodes = cfg.start_algorithm_test(acfg)

    test_algo_sc(algo_nodes)
    cfg.print_algo_clause_results()

    cfg.print_algo_purpose_results(algo_nodes)

# Part 30: Unit Test automated searches

In [None]:
def check_algo_tag_exists(node_location_as_str, the_tags):
    """Print a message for every expected ALGO tag that a useful node lacks.

    Args:
      node_location_as_str: node name such as 'P14L0H2', parsed with
        qt.str_to_node_location.
      the_tags: ALGO tag strings expected on that node.

    Raises:
      AssertionError: if cfg.useful_nodes.get_node returns None for the location.
    """
    location = qt.str_to_node_location(node_location_as_str)
    node = cfg.useful_nodes.get_node(location)
    assert node is not None

    missing_tags = (tag for tag in the_tags if not node.contains_tag(qt.QType.ALGO.value, tag))
    for tag in missing_tags:
        print("Node", node.name(), "is missing tag", tag, "It has tags:", node.tags)

In [None]:
# Regression checks for the automated searches: for each known model, confirm
# that specific nodes carry the expected ALGO tags. check_algo_tag_exists prints
# a message for each missing tag and asserts that the node exists.
print(cfg.model_name)

# 2-layer addition model.
if cfg.model_name.startswith('add_d6_l2_h3_t15K'):
  check_algo_tag_exists('P11L0H0', ['D2.TC'] )
  check_algo_tag_exists('P12L0H0', ['D3.TC'] )
  check_algo_tag_exists('P14L0H0', ['A5.SS', 'D4.TC'] )
  check_algo_tag_exists('P14L0H2', ['A5.SC', 'D5.TC'] )
  check_algo_tag_exists('P14L1H1', ['OPR'] )
  check_algo_tag_exists('P15L0H0', ['A4.SC'] )
  check_algo_tag_exists('P15L0H1', ['A5.SA'] )
  check_algo_tag_exists('P15L0H2', ['A5.SA'] )
  check_algo_tag_exists('P16L0H0', ['A3.SC'] )
  check_algo_tag_exists('P16L0H1', ['A4.SA'] )
  check_algo_tag_exists('P16L0H2', ['A4.SA'] )
  check_algo_tag_exists('P17L0H0', ['A2.SC'] )
  check_algo_tag_exists('P17L0H1', ['A3.SA'] )
  check_algo_tag_exists('P17L0H2', ['A3.SA'] )
  check_algo_tag_exists('P18L0H0', ['A1.SC'] )
  check_algo_tag_exists('P18L0H1', ['A2.SA'] )
  check_algo_tag_exists('P18L0H2', ['A2.SA'] )
  check_algo_tag_exists('P19L0H0', ['A0.SC'] )
  check_algo_tag_exists('P19L0H1', ['A1.SA'] )
  check_algo_tag_exists('P19L0H2', ['A1.SA'] )
  check_algo_tag_exists('P20L0H1', ['A0.SA'] )
  check_algo_tag_exists('P20L0H2', ['A0.SA'] )

# Mixed addition/subtraction model.
# NOTE(review): other cells refer to this model family as 'ins1_mix_d6_l3_h4_t40K'.
# A name starting with 'mix_d6...' would not match those, and some expected tags
# here use a different format (e.g. 'A4.SA.A4'). Confirm this block is still reachable.
if cfg.model_name.startswith('mix_d6_l3_h4_t40K'):
  check_algo_tag_exists('P8L0H1', ['OPR'] )
  check_algo_tag_exists('P13L2H0', ['A7.NG'] )
  check_algo_tag_exists('P15L0H0', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P15L0H3', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P16L0H3', ['A4.SA.A4', 'A4.MD.A4'] )
  check_algo_tag_exists('P17L0H1', ['A3.NG'] )
  check_algo_tag_exists('P17L0H3', ['A3.SA.A3', 'A3.MD.A3'] )
  check_algo_tag_exists('P18L0H1', ['A2.NG'] )
  check_algo_tag_exists('P18L0H3', ['A2.SA.A2', 'A2.MD.A2'] )
  check_algo_tag_exists('P19L0H1', ['A1.NG'] )
  check_algo_tag_exists('P19L0H3', ['A1.SA.A1', 'A1.MD.A1'] )
  check_algo_tag_exists('P20L0H0', ['A0.SA', 'A0.MD'] )
  check_algo_tag_exists('P20L0H3', ['A0.SA', 'A0.MD'] )
  check_algo_tag_exists('P20L2H1', ['A0.NG'] )

# Mixed addition/subtraction model (ins1 variant, 3 layers, 4 heads).
if cfg.model_name.startswith('ins1_mix_d6_l3_h4_t40K'):
  check_algo_tag_exists('P6L0H0', ['OPR'] )
  check_algo_tag_exists('P9L0H0', ['A4.MT'] )
  check_algo_tag_exists('P9L0H1', ['A4.ST'] )
  check_algo_tag_exists('P10L0H1', ['A2.MT'] )
  check_algo_tag_exists('P10L0H3', ['OPR'] )
  check_algo_tag_exists('P12L0H0', ['A3.MT'] )
  check_algo_tag_exists('P12L0H1', ['A3.ST'] )
  check_algo_tag_exists('P12L0H3', ['OPR'] )
  check_algo_tag_exists('P12L1H2', ['A1.MT'] )
  check_algo_tag_exists('P13L0H3', ['OPR'] )
  check_algo_tag_exists('P13L1H0', ['OPR'] )
  check_algo_tag_exists('P13L2H0', ['OPR'] )
  check_algo_tag_exists('P14L0H0', ['A5.SS', 'OPR'] )
  check_algo_tag_exists('P14L0H1', ['OPR'] )
  check_algo_tag_exists('P14L0H2', ['A5.SC', 'A5.ST', 'SGN'] )
  check_algo_tag_exists('P14L1H2', ['SGN'] )
  check_algo_tag_exists('P14L1H3', ['SGN'] )
  check_algo_tag_exists('P15L0H0', ['A4.SC'] )
  check_algo_tag_exists('P15L0H1', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P15L0H2', ['A5.SA', 'A5.MD'] )
  check_algo_tag_exists('P15L0H3', ['SGN'] )
  check_algo_tag_exists('P16L0H0', ['A3.SC'] )
  check_algo_tag_exists('P16L0H1', ['A4.SA', 'A4.MD', 'A4.ND'] )
  check_algo_tag_exists('P16L0H2', ['A4.SA', 'A4.MD', 'A4.ND'] )
  check_algo_tag_exists('P16L0H3', ['SGN'] )
  check_algo_tag_exists('P16L1H0', ['SGN'] )
  check_algo_tag_exists('P16L2H0', ['SGN'] )
  check_algo_tag_exists('P17L0H0', ['A2.SC'] )
  check_algo_tag_exists('P17L0H1', ['A3.SA', 'A3.MD', 'A3.ND'] )
  check_algo_tag_exists('P17L0H2', ['A3.SA', 'A3.MD', 'A3.ND'] )
  check_algo_tag_exists('P17L0H3', ['SGN'] )
  check_algo_tag_exists('P18L0H0', ['A1.SC'] )
  check_algo_tag_exists('P18L0H1', ['A2.SA', 'A2.MD', 'A2.ND'] )
  check_algo_tag_exists('P18L0H2', ['A2.SA', 'A2.MD', 'A2.ND'] )
  check_algo_tag_exists('P19L0H0', ['A0.MB.A1'] )
  check_algo_tag_exists('P19L0H1', ['A1.SA', 'A1.MD', 'A1.ND'] )
  check_algo_tag_exists('P19L0H2', ['A1.SA', 'A1.MD', 'A1.ND'] )
  check_algo_tag_exists('P20L0H1', ['A0.SA', 'A0.MD', 'A0.ND'] )
  check_algo_tag_exists('P20L0H2', ['A0.SA', 'A0.MD', 'A0.ND'] )