# Quanta Maths: Integer Addition and Subtraction in Transformers - Analyze a Model

This Colab analyzes the behavior and algorithm feature sub-tasks performed by nodes in Transformer models.

The models perform integer addition and/or subtraction, e.g. 133357+182243=+0315600 and 123450-345670=-0222220. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens.
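
As a minimal sketch of the token layout described above, the snippet below renders a question and its answer as fixed-width text and splits it into one token per character. The helper name `format_question` is hypothetical (it is not part of the QuantaMaths library), and treating a zero answer as "+" is an assumption.

```python
def format_question(x: int, y: int, op: str, n_digits: int = 6) -> str:
    """Render 'x op y = answer' in a fixed-width layout (hypothetical helper)."""
    if op not in ("+", "-"):
        raise ValueError("op must be '+' or '-'")
    result = x + y if op == "+" else x - y
    # Assumption: a zero answer is written with a '+' sign.
    sign = "-" if result < 0 else "+"
    # Operands are zero-padded to n_digits; the answer has one extra digit.
    question = f"{x:0{n_digits}d}{op}{y:0{n_digits}d}="
    answer = f"{sign}{abs(result):0{n_digits + 1}d}"
    return question + answer


text = format_question(133357, 182243, "+")
tokens = list(text)            # one token per character
question_tokens = tokens[:14]  # 6 digits + operator + 6 digits + '='
answer_tokens = tokens[14:]    # sign + 7 digits
print(text, len(question_tokens), len(answer_tokens))
```

For 6-digit operands this gives 14 question tokens and 8 answer tokens, matching the counts above.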

This Colab follows on from https://github.com/PhilipQuirke/quanta_maths/blob/main/notebooks/QuantaMathsTrain.ipynb which trained the models and output model.pth and training_loss.json.

## Tips for using the Colab
 * You can run and alter the code in this Colab notebook yourself in Google Colab ( https://colab.research.google.com/ ).
 * To run the notebook in Google Colab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the public "QuantaMechInterp" library as "qmi" and the public "QuantaMaths" library (module MathsMechInterp) as "mmi". Refer to [README.md](https://github.com/PhilipQuirke/QuantaMaths/blob/main/README.md) for more detail.

In [1]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    !pip install matplotlib
    !pip install kaleido

except ImportError:
    IN_COLAB = False

    def setup_jupyter(install_libraries=False):
        if install_libraries:
            !pip install matplotlib
            !pip install kaleido

        print("Running as a Jupyter notebook - intended for development only!")
        from IPython import get_ipython

        ipython = get_ipython()
        # Code to automatically update the HookedTransformer code as it's edited, without restarting the kernel
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # setup_jupyter(install_libraries=True)   # Uncomment if you need to install libraries in notebook.
    setup_jupyter(install_libraries=False)

Running as a Colab notebook
Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)
Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
Installing collected packages: kaleido
Successfully installed kaleido-0.2.1


In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

Collecting git+https://github.com/PhilipQuirke/quanta_mech_interp.git
  Cloning https://github.com/PhilipQuirke/quanta_mech_interp.git to /tmp/pip-req-build-8dcr6cp0
  Running command git clone --filter=blob:none --quiet https://github.com/PhilipQuirke/quanta_mech_interp.git /tmp/pip-req-build-8dcr6cp0
  Resolved https://github.com/PhilipQuirke/quanta_mech_interp.git to commit e17badd114f60ed4695c750e8073337d97fd8bb4


In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_maths.git
import MathsMechInterp as mmi

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import torch
import numpy as np

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [None]:
import re
import sklearn # Aka scikit-learn
import skopt # Aka scikit-optimize

In [None]:
import transformer_lens
from transformer_lens.utils import download_file_from_hf
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Part 1A: Configuration

Which existing model do we want to analyze?

The existing model weights created by the sister Colab [QuantaMathsTrain](https://github.com/PhilipQuirke/quanta_maths/blob/main/assets/QuantaMathsTrain.ipynb) are loaded from HuggingFace (in Part 5). Refer to https://github.com/PhilipQuirke/quanta_maths/blob/main/README.md for more detail.
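
The model names used in the next cell appear to follow a regular pattern. The sketch below splits a name into its parts. The field meanings (digits, layers, heads, training steps in thousands, GrokFast flag, seed, insert mode) are inferred from the comments in this notebook and are assumptions, and `parse_model_name` is a hypothetical helper, not the library's own parsing.

```python
import re

# Assumed pattern: [insN_]<op>_d<digits>_l<layers>_h<heads>_t<steps>K[_gf]_s<seed>
MODEL_NAME_PATTERN = re.compile(
    r"^(?:ins(?P<insert_mode>\d+)_)?"
    r"(?P<operation>add|sub|mix)"
    r"_d(?P<digits>\d+)_l(?P<layers>\d+)_h(?P<heads>\d+)"
    r"_t(?P<train_k>\d+)K"
    r"(?P<grokfast>_gf)?"
    r"_s(?P<seed>\d+)$"
)


def parse_model_name(name: str) -> dict:
    """Split a model name into its (assumed) components."""
    match = MODEL_NAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unrecognised model name: {name}")
    return {
        # Insert mode 0 here simply means "no insN_ prefix present".
        "insert_mode": int(match["insert_mode"] or 0),
        "operation": match["operation"],
        "digits": int(match["digits"]),
        "layers": int(match["layers"]),
        "heads": int(match["heads"]),
        "train_steps": int(match["train_k"]) * 1000,
        "grokfast": match["grokfast"] is not None,
        "seed": int(match["seed"]),
    }


print(parse_model_name("ins1_mix_d6_l3_h4_t40K_s372001"))
```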

In [None]:
# Singleton QuantaTool "main" configuration class. MathsConfig is derived from the chain AlgoConfig > UsefulConfig > ModelConfig
cfg = mmi.MathsConfig()


# Which pre-existing model do we want to analyze? Uncomment exactly one line:

# Addition models
#cfg.set_model_names( "add_d5_l1_h3_t15K_s372001" )  # AddAccuracy=Two9s. Inaccurate as only has one layer. Reproduces previous paper model.
cfg.set_model_names( "add_d5_l2_h3_t15K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=1.6e-08
#cfg.set_model_names( "add_d5_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. (0/M fail). Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t15K_s372001" )  # AddAccuracy=Five9s. AvgFinalLoss=1.7e-08. (2/M fail)
#cfg.set_model_names( "add_d6_l2_h3_t20K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=1.5e-08. (0/M fail). Fewest nodes
#cfg.set_model_names( "add_d6_l2_h3_t20K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=7e-09.  (0/M fail)
#cfg.set_model_names( "add_d6_l2_h3_t40K_s372001" )  # AddAccuracy=Six9s. AvgFinalLoss=2e-09. (0/M fail)
#cfg.set_model_names( "add_d7_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d8_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d9_l2_h3_t45K_s173289" )  # AddAccuracy=Six9s. AvgFinalLoss=3e-09. (0/M fail)
#cfg.set_model_names( "add_d10_l2_h3_t40K_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=7e-09. (1/M fail)
#cfg.set_model_names( "add_d10_l2_h3_t40K_gf_s572091" ) # AddAccuracy=Six9s. AvgFinalLoss=3.5e-09. GrokFast. Minor accuracy improvement
#cfg.set_model_names( "add_d11_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=8e-09. (2/M fail)
#cfg.set_model_names( "add_d12_l2_h3_t50K_s572091" )  # AddAccuracy=Five9s. AvgFinalLoss=5e-09. (3/M fail)
#cfg.set_model_names( "add_d13_l2_h3_t50K_s572091" )  # AddAccuracy=Six9s. AvgFinalLoss=6.3e-08. (1/M fail)
#cfg.set_model_names( "add_d14_l2_h3_t60K_s572091" )  # AddAccuracy=Three9s. AvgFinalLoss=5.6e-06. (199/M fail)
#cfg.set_model_names( "add_d15_l2_h3_t80K_s572091" ) # AddAccuracy=Five9s. AvgFinalLoss=8.6e-08 (10/M fail)
#cfg.set_model_names( "add_d20_l2_h3_t80K_s572091" ) # AddAccuracy=Poor! AvgFinalLoss=0.20!

# Subtraction model
#cfg.set_model_names( "sub_d6_l2_h3_t30K_s372001" )  # SubAccuracy=Six9s. AvgFinalLoss=5.8e-06
#cfg.set_model_names( "sub_d10_l2_h3_t75K_s173289" )  # SubAccuracy=Two9s. AvgFinalLoss=0.002002. (6672/M fails)
#cfg.set_model_names( "sub_d10_l2_h3_t75K_gf_s173289" )  # SubAccuracy=Two9s. GrokFast. AvgFinalLoss=0.001197. (5246/M fails). Minor accuracy improvement

# Mixed (addition and subtraction) model
#cfg.set_model_names( "mix_d5_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=9e-09. (0/M fails, 0/M fails)
#cfg.set_model_names( "mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=5e-09. (1/M fail)
#cfg.set_model_names( "mix_d7_l3_h4_t50K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=2e-08. (2/M fails, 6/M fails)
#cfg.set_model_names( "mix_d8_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=4.7e-08. (0/M fails, 7/M fails)
#cfg.set_model_names( "mix_d9_l3_h4_t60K_s173289" )  # Add/SubAccuracy=Six9s/Four9s. AvgFinalLoss=3.2e-07. (1/M fails, 33/M fails)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_s173289" )  # Add/SubAccuracy=Five9s/Two9s. AvgFinalLoss=1.125e-06 (2/M fail, 295/M fail)
#cfg.set_model_names( "mix_d10_l3_h4_t75K_gf_s173289" )  # Add/SubAccuracy=Six9s/Three9s. GrokFast. AvgFinalLoss=8.85e-07 (1/M fail, 294/M fail). Minor accuracy improvement
#cfg.set_model_names( "mix_d11_l3_h4_t80K_s572091" )  # Add/SubAccuracy=Six9s/Four9s AvgFinalLoss=3.9e-08 (0/M fail, 13/M fail)
#cfg.set_model_names( "mix_d12_l3_h4_t85K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (2/M fail, 10/M fail)
#cfg.set_model_names( "mix_d13_l3_h4_t85K_s572091" )  # Add/SubAccuracy=Three9s/Two9s. AvgFinalLoss=9.5e-06. (399/M fail, 4164/M fail)

# Mixed models initialized with addition model. Params fine-tuned during training
#cfg.set_model_names( "ins1_mix_d5_l2_h3_t40K_s572091" )  # Add/SubAccuracy=TODO
#cfg.set_model_names( "ins1_mix_d6_l2_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgLoss = 2.4e-08 (5/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t40K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=1.8e-08. (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h3_t80K_s572091" )  # Add/SubAccuracy=Six9s/Five9s AvgLoss = 1.6e-08 (3/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s372001" )  # Add/SubAccuracy=Six9s/Six9s. AvgFinalLoss=8e-09. MAIN FOCUS
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t40K_s173289" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.4e-08. (3/M fails, 2/M fails)
#cfg.set_model_names( "ins1_mix_d6_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Six9s/Five9s. AvgFinalLoss=2.9e-08. (4/M fails)
#cfg.set_model_names( "ins1_mix_d7_l3_h4_t50K_s572091" )  # Add/SubAccuracy=Five9s/Six9s. AvgFinalLoss=1.3e-08. (4/M fails, 1/M fails)
#cfg.set_model_names( "ins1_mix_d8_l3_h4_t70K_s572091" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=7.2e-06. (50/M fails, 1196/M fails)
#cfg.set_model_names( "ins1_mix_d9_l3_h4_t70K_s572091" )  # Add/SubAccuracy=TODO. AvgFinalLoss=TODO. (50/M fails, TODO/M fails)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_s572091" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss 6.3e-07 (6/M fails, 7/M fails)
#cfg.set_model_names( "ins1_mix_d10_l3_h3_t50K_gf_s572091" ) # Add/SubAccuracy=Five9s/Two9s. GrokFast. AvgFinalLoss=4.0e-06 (2/M fails, 1196/M fails). Worse accuracy than without GF!
#cfg.set_model_names( "ins1_mix_d11_l3_h4_t75K_s572091" )  # Add/SubAccuracy=TODO. AvgFinalLoss=TODO. (TODO/M fails, TODO/M fails)

# Mixed model initialized with addition model. Reset useful heads every 100 epochs during training
#cfg.set_model_names( "ins2_mix_d6_l4_h4_t40K_s372001" )  # Add/SubAccuracy=Five9s/Five9s. AvgFinalLoss=1.7e-08. (3/M fails, 8/M fails)

# Mixed model initialized with addition model. Reset useful heads & MLPs every 100 epochs during training
#cfg.set_model_names( "ins3_mix_d6_l4_h3_t40K_s372001" )  # Add/SubAccuracy=Four9s/Two9s. AvgFinalLoss=3.0e-04. (17/M fails, 3120/M fails)

# Addition&Subtraction model initialized with addition model. Reset useful heads & MLPs every training epoch
#cfg.set_model_names( "ins4_mix_d6_l4_h3_t40K_s372001" )  # AvgFinalLoss=4.7e-06

# Mixed models initialized with addition model. Insert mode 5
#cfg.set_model_names( "ins5_mix_d6_l3_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO
#cfg.set_model_names( "ins5_mix_d6_l2_h4_t30K_s775824" )  # Add/SubAccuracy=???/??? TODO

# Part 1B: Configuration: Input and Output file names



In [None]:
base_repo_name = 'PhilipQuirke'
model_pth_fname = 'model.pth'
model_training_loss_fname = 'training_loss.json'
model_behaviors_fname = 'behaviors.json'
model_features_fname = 'features.json'

cfg.hf_repo = base_repo_name + "/QuantaMaths_" + cfg.model_name # "PhilipQuirke/QuantaMaths_add_d6_l2_h3_t15K_s372001"

In [None]:
cfg.batch_size = 512 # Default analysis batch size
if cfg.n_layers >= 3 and cfg.n_heads >= 4:
  cfg.batch_size = 256 # Reduce batch size to avoid memory constraint issues.

In [None]:
# Update "cfg" with additional training config (including cfg.insert_*) using information stored in, for example:
#      https://huggingface.co/PhilipQuirke/QuantaMaths_ins1_mix_d6_l3_h4_t40K_s372001/training_loss.json
training_data_json = qmi.download_huggingface_json(cfg.hf_repo, model_training_loss_fname)
training_loss_list = qmi.load_training_json(cfg, training_data_json)
print('Loaded main model training config / loss from', cfg.hf_repo, model_training_loss_fname)

In [None]:
def print_config():
  print("%Add=", cfg.perc_add, "%Sub=", cfg.perc_sub, "%Mult=", cfg.perc_mult, "InsertMode=", cfg.insert_mode, "File=", cfg.model_name)

In [None]:
print_config()
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)
print('Model files will be read from HuggingFace repo:', base_repo_name, model_pth_fname, 'and', model_training_loss_fname)
print('Model analysis tags will be saved to Colab temporary files:', model_behaviors_fname, 'and', model_features_fname)

# Part 2: Results: Model training loss

In [None]:
print_config()
print( "Avg loss over last 5 epochs (AvgFinalLoss)", cfg.avg_final_loss)
print( "Final epoch loss", cfg.final_loss)

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

  

In [None]:
cfg.initialize_maths_token_positions()
mmi.set_maths_vocabulary(cfg)
mmi.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

# Part 3B: Set Up: Create model

In [None]:
# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = cfg.get_HookedTransformerConfig()

# Create the main transformer model
cfg.main_model = HookedTransformer(ht_cfg)

# Part 4: Set Up: Loss Function & Data Generator
The maths loss function and data generator are imported from QuantaTools as logits_to_tokens_loss, loss_fn, maths_data_generator_core and maths_data_generator.

In [None]:
# Define "iterator" maths "questions" data generator function. Invoked using next().
cfg.set_seed(cfg.analysis_seed)
ds = mmi.maths_data_generator( cfg )

In [None]:
# Generate a sample batch from the data generator and print its first 3 questions (unit test)
print(next(ds)[:3,:])

# Part 5: Set Up: Load Model from HuggingFace

In [None]:
print("Loading model from HuggingFace", cfg.hf_repo, model_pth_fname)

cfg.main_model.load_state_dict(download_file_from_hf(repo_name=cfg.hf_repo, file_name=model_pth_fname, force_is_torch=True))
cfg.main_model.eval()

# Part 6: Set Up: Create sample maths questions

Create a batch of manually-curated mathematics test questions, and cache some sample model prediction outputs.

In [None]:
# Singleton QuantaTool "ablation intervention" configuration class
acfg = qmi.acfg
acfg.reset_ablate()
cfg.configure_acfg_singleton()

In [None]:
varied_questions = mmi.make_maths_test_questions_and_answers(cfg)
num_varied_questions = varied_questions.shape[0]

qmi.a_set_ablate_hooks(cfg) # Updates acfg
qmi.a_calc_mean_values(cfg, varied_questions)

In [None]:
print("Num questions:", num_varied_questions, "Question length:", len(varied_questions[0]))
print("Sample Question:", varied_questions[0])

# Part 7: Results: Can the model correctly predict sample questions?

Ask the model to predict the varied_questions (without intervention) to see whether it gets them all right. Answers are categorized by complexity.

In [None]:
# Test maths question prediction accuracy on the sample questions provided.
# Does NOT use UsefulInfo.* information
# Used to estimate the accuracy of the model's predictions.
# Returns a reduced set of questions - removing questions that the model failed to answer.
print_config()

acfg.show_test_failures = True
varied_questions = mmi.test_maths_questions_by_complexity(cfg, acfg, varied_questions)
acfg.show_test_failures = False

num_varied_questions = varied_questions.shape[0]

# Part 9: Results: Is the model 99.9999% accurate?

The model's accuracy is 99.9999% (aka "six 9s") if it can predict one million questions with 0 or 1 failed predictions. If it has 2 to 10 failed predictions the model's accuracy is called 99.999% (aka "five 9s").

Note: There may be very rare edge cases (say 1 in ten million) that did not appear in the test questions. So this empirical test can **not** prove 100% accuracy.
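
The naming rule above can be sketched as a small function that maps failures per million questions to a label. Only the "six 9s" (0 or 1 failures) and "five 9s" (2 to 10 failures) bands are stated in this notebook. Extending the same factor-of-ten pattern to lower bands is an assumption, and `nines_label` is a hypothetical helper.

```python
def nines_label(failures_per_million: int) -> str:
    """Map failures per one million test questions to an accuracy label."""
    if failures_per_million < 0:
        raise ValueError("failure count cannot be negative")
    names = {6: "Six9s", 5: "Five9s", 4: "Four9s",
             3: "Three9s", 2: "Two9s", 1: "One9"}
    limit = 1  # Six9s allows at most 1 failure per million
    for nines in range(6, 0, -1):
        if failures_per_million <= limit:
            return names[nines]
        limit *= 10  # assumed: each lower band allows 10x more failures
    return "Poor"


for fails in (0, 2, 33, 199, 6672):
    print(fails, nines_label(fails))
```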

If the model fails some questions, consider:
- Adding a few of the failures to the "test questions" in mmi.make_maths_test_questions_and_answers()
- Understanding the "use case(s)" driving these failures
- Altering mmi.maths_data_generator_core to enrich the training data with examples of these use case(s)
- Retraining the model using the QuantaMathsTrain Colab.

Takes ~25 mins to run for ins_mix_d6_l3_h4_t40K_s372001

In [None]:
run_1m_test = False

Enriching data means adding more "hard" and subtraction questions. Enriched data was used during training. Using enrich_data has little impact on the model accuracy measured here.

In [None]:
enrich_data = True

In [None]:
if run_1m_test:
    acfg.show_test_failures = False
    mmi.test_correctness_on_num_questions(cfg, acfg, num_questions=1000000, enrich_data=enrich_data)
    acfg.show_test_failures = False

# Part 10: Set Up: Which token positions are used by the model?

Ablate all nodes in each (question and answer) token position (by overriding the model memory aka residual stream). If the model's prediction loss increases, the token position is useful to the algorithm. Unused token positions are excluded from further analysis. Used to populate the UsefulInfo.useful_positions data. This is token **position level** information.  
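
The idea of position-level ablation can be illustrated with a toy example that uses no real transformer. In this sketch the "model" is a hypothetical weighted sum over per-position activations, and ablation replaces one position's activation with its mean over the batch. All names and numbers are made up for illustration. The notebook itself ablates the residual stream of the real model through the library.

```python
from statistics import mean

# Hypothetical toy "model": positions 0 and 2 have zero weight, so they are unused.
TOY_WEIGHTS = [0.0, 2.0, 0.0, 1.0]


def toy_model(activations):
    """Return a prediction from one question's per-position activations."""
    return round(sum(w * a for w, a in zip(TOY_WEIGHTS, activations)))


def count_failures_when_ablated(batch, position):
    """Mean-ablate one position and count questions whose prediction changes."""
    mean_value = mean(row[position] for row in batch)
    failures = 0
    for row in batch:
        ablated = list(row)
        ablated[position] = mean_value  # overwrite with the batch mean
        if toy_model(ablated) != toy_model(row):
            failures += 1
    return failures


batch = [
    [1.0, 3.0, 5.0, 2.0],
    [4.0, 1.0, 0.0, 6.0],
    [2.0, 5.0, 7.0, 1.0],
]
useful_positions = [p for p in range(len(TOY_WEIGHTS))
                    if count_failures_when_ablated(batch, p) > 0]
print("useful_positions=", useful_positions)
```

Only the positions whose ablation changes at least one prediction are kept, which mirrors how unused positions are dropped from further analysis.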

In [None]:
num_failures_list = []
acfg.show_test_failures = False

for position in range(cfg.n_ctx):
  # Test accuracy of model in predicting question answers. Ablates all nodes at acfg.ablate_position. Does NOT use UsefulInfo.* information.
  num_fails = mmi.test_maths_questions_by_impact(cfg, acfg, varied_questions, position, ablate=True)

  if num_fails > 0:
    # Add position to UsefulInfo.useful_positions
    cfg.add_useful_position(position)
    num_failures_list.append(num_fails)
  else:
    num_failures_list.append(".")

# Part 11: Results: Which token positions are used by the model?

Which token positions are used in the model's predictions? Unused token positions are excluded from further analysis.

In [None]:
print_config()
print("num_questions=", num_varied_questions)
print("useful_positions=", cfg.useful_positions )
print()

cfg.calc_position_failures_map(num_failures_list)
qmi.save_plt_to_file(cfg=cfg, full_title="Failures When Position Ablated")
plt.show()

# Part 12A: Set Up: Which nodes are used by the model?

Here we ablate each (attention head and MLP neuron) node in each (question and answer) token position to see if the model's prediction loss increases. If the loss increases, then the "node + token position" is used by the algorithm. Used to calculate the UsefulInfo.useful_node_location. This is **position+node level** information. Unused nodes are excluded from further analysis.


In [None]:
cfg.useful_nodes = qmi.UsefulNodeList()

qmi.ablate_mlp_and_add_useful_node_tags(cfg, varied_questions, mmi.test_maths_questions_and_add_useful_node_tags)
qmi.ablate_head_and_add_useful_node_tags(cfg, varied_questions, mmi.test_maths_questions_and_add_useful_node_tags)
qmi.add_node_attention_tags(cfg, varied_questions)

cfg.useful_nodes.sort_nodes()

# Part 12B: Results: Which nodes are used by the model?

Here are the (attention head and MLP neuron) nodes in each (question and answer) token position used by the model during predictions.

In [None]:
cfg.useful_nodes.print_node_tags()

# Part 13: Set Up: Show and save Quanta map

Using the UsefulNodes and filtering their tags, show a 2D map of the nodes and the tag minor versions.

In [None]:
def show_quanta_map( title, major_tag : qmi.QType, minor_tag : str, get_node_details,
        image_width_inches : int = -1, image_height_inches : int = -1,
        blue_shades : bool = True, cell_num_shades : int = 6,
        filters : qmi.FilterNode = None, cell_fontsize : int = 9,
        combine_identical_cells : bool = True, show_perc_circles : bool = False ):

  test_nodes = cfg.useful_nodes
  if filters is not None:
    test_nodes = qmi.filter_nodes(test_nodes, filters)

  ax1, quanta_results, num_results = qmi.calc_quanta_map(
      cfg, blue_shades, cell_num_shades,
      test_nodes, major_tag.value, minor_tag, get_node_details,
      cell_fontsize, combine_identical_cells, show_perc_circles,
      image_width_inches, image_height_inches )

  if num_results > 0:
    if cfg.graph_file_suffix > "":
      print("Saving quanta map:", title)
      qmi.save_plt_to_file(cfg=cfg, full_title=title)
    else:
      ax1.set_title(cfg.file_config_prefix + ' ' + title + ' ({} nodes)'.format(len(quanta_results)))

    plt.show()

# Part 16A: Results: Show failure percentage map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated. Lower percentages correspond to rarer edge cases. The grey space represents nodes that are not used by the model.

A cell containing "< 1" may add some risk to the accuracy of the overall analysis process. Check to see if this represents a new use case. Improve the test data set to contain more instances of this (new or existing) use case.

In [None]:
show_quanta_map( "Failure Frequency Behavior Per Node", qmi.QType.FAIL, "", qmi.get_quanta_fail_perc,
                cell_num_shades = qmi.FAIL_SHADES, combine_identical_cells = False, show_perc_circles = True)

# Part 16B: Results: Show answer impact behavior map

This map shows the answer digit(s) A0 .. An+1 impacted when we ablate each useful node in the model. Cells containing values like A5..2 are used in multiple prediction steps to calculate multiple answer digits, e.g. A2 to A5. Late token positions focus on predicting one answer digit, partially by using results calculated in early token positions.


In [None]:
show_quanta_map( "Answer Impact Behavior Per Node", qmi.QType.IMPACT, "", qmi.get_quanta_impact,
                cell_num_shades = cfg.num_answer_positions)

# Part 16C: Results: Show attention map

This map shows the input tokens each useful attention head attends to at each token position.



In [None]:
# Only maps attention heads, not MLP layers
show_quanta_map( "Attention Behavior Per Head", qmi.QType.ATTN, "", qmi.get_quanta_attention,
                # image_height_inches = 8, # image_width_inches = 11,
                cell_num_shades = qmi.ATTN_SHADES )

# Part 16D: Results: Show question complexity map

This map shows whether each useful node is used to answer the question classes: addition (S), positive-answer subtraction (M) and/or negative-answer subtraction (N). In mixed models, nodes may be used in the prediction of two or three question classes. That is, they are polysemantic.
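
As a rough illustration of the three question classes and of what "polysemantic" means here, the sketch below classifies questions as S, M or N and counts how many classes each node serves. The node names and their class sets are made-up data, both helpers are hypothetical, and treating a zero-answer subtraction as M is an assumption.

```python
from typing import Dict, Set, Tuple


def question_class(x: int, y: int, op: str) -> str:
    """Classify a question as addition (S), positive (M) or negative (N) subtraction."""
    if op == "+":
        return "S"
    if op == "-":
        # Assumption: a zero answer is grouped with the positive-answer class.
        return "M" if x >= y else "N"
    raise ValueError("op must be '+' or '-'")


def operation_coverage(node_classes: Dict[str, Set[str]]) -> Tuple[int, ...]:
    """Count nodes per class, then nodes used by 3, 2 and 1 classes."""
    per_class = [sum(c in used for used in node_classes.values()) for c in "SMN"]
    per_count = [sum(len(used) == k for used in node_classes.values()) for k in (3, 2, 1)]
    return tuple(per_class + per_count)


# Made-up example: which question classes fail when each node is ablated.
example_nodes = {
    "P14L0H1": {"S", "M", "N"},  # polysemantic: used by all three classes
    "P15L1H2": {"S", "M"},
    "P16L0H3": {"N"},
}
print(question_class(123450, 345670, "-"), operation_coverage(example_nodes))
```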





In [None]:
if cfg.perc_sub > 0:
  num_add, num_sub, num_neg, num_triple, num_double, num_single = mmi.get_maths_nodes_operation_coverage(cfg.useful_nodes.nodes)
  print( "# useful nodes:", len(cfg.useful_nodes.nodes))
  print( "# useful nodes involved in S, M, N operations:", num_add, num_sub, num_neg )
  print( "# useful nodes involved in 3, 2, 1 operations:", num_triple, num_double, num_single)
  print()

  # For each useful cell, show if addition (S), positive-answer subtraction (M) and negative-answer subtraction (N) questions relies on the node.
  show_quanta_map( "Maths Operation Coverage", qmi.QType.MATH, "", mmi.get_maths_operation_complexity,
                  blue_shades = False, cell_num_shades = 4, combine_identical_cells = False)

This map shows the simplest (lowest complexity) addition quanta S0, S1, etc impacted when we ablate each node in an addition or mixed model. To answer S0 questions, only the S0 nodes are used. To answer S1 questions, S0 and S1 nodes are used, etc.

In [None]:
if cfg.perc_add > 0:
  # For each useful cell, show the minimum addition question complexity that relies on the node, as measured using quanta S0, S1, S2, ...
  show_quanta_map( "Addition Min-Complexity", qmi.QType.MATH_ADD, mmi.MathsBehavior.ADD_COMPLEXITY_PREFIX.value, mmi.get_maths_min_complexity,
                  blue_shades = False, cell_num_shades = qmi.MATH_ADD_SHADES)

This map shows the simplest (lowest complexity) subtraction quanta M0, M1, etc impacted when we ablate each node in a subtraction or mixed model. To answer M0 questions, only the M0 nodes are used. To answer M1 questions, M0 and M1 nodes are used, etc.

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show the minimum "positive-answer subtraction" question complexity that relies on the node, as measured using quanta M0, M1, M2, ...
  show_quanta_map( "Positive-answer Subtraction Min-Complexity", qmi.QType.MATH_SUB, mmi.MathsBehavior.SUB_COMPLEXITY_PREFIX.value, mmi.get_maths_min_complexity,
                  blue_shades = False, cell_num_shades = qmi.MATH_SUB_SHADES)

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show the minimum "negative-answer subtraction" question complexity that relies on the node, as measured using quanta N0, N1, N2, ...
  show_quanta_map( "Negative-answer Subtraction Min-Complexity", qmi.QType.MATH_NEG, mmi.MathsBehavior.NEG_COMPLEXITY_PREFIX.value, mmi.get_maths_min_complexity,
                  blue_shades = False, cell_num_shades = qmi.MATH_SUB_SHADES)

# Part 19A: Detect attention head output clustering w.r.t ST8/9/10 (Manual)

Principal Component Analysis (PCA) is a powerful technique that aids in mechanistic interpretability by simplifying complex datasets into principal components that capture the most significant variance within the data.

This library uses PCA to help understand the purpose of individual useful nodes. For more background, refer to https://github.com/PhilipQuirke/QuantaMaths/blob/main/pca.md

If an attention head and an answer digit An give an interpretable response (2 or 3 distinct output clusters) on 3 groups of questions aligned to the ST8, ST9 and ST10 definitions, then plot the response and add a PCA tag.
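
To make "capturing the most significant variance" concrete, here is a minimal sketch that computes the explained variance ratio of the first principal component for 2D points, using the closed-form eigenvalues of a 2x2 covariance matrix. Real attention head outputs are high-dimensional and the notebook relies on the library's own PCA routine, so `first_component_evr` is only an illustrative, hypothetical helper.

```python
import math


def first_component_evr(points):
    """Explained variance ratio of the first principal component of 2D points."""
    n = len(points)
    if n < 2:
        raise ValueError("need at least two points")
    mean_x = sum(p[0] for p in points) / n
    mean_y = sum(p[1] for p in points) / n
    # Entries of the sample covariance matrix [[sxx, sxy], [sxy, syy]]
    sxx = sum((p[0] - mean_x) ** 2 for p in points) / (n - 1)
    syy = sum((p[1] - mean_y) ** 2 for p in points) / (n - 1)
    sxy = sum((p[0] - mean_x) * (p[1] - mean_y) for p in points) / (n - 1)
    # Eigenvalues of a symmetric 2x2 matrix are half_trace +/- spread;
    # their sum is the trace, i.e. the total variance.
    half_trace = (sxx + syy) / 2
    spread = math.sqrt(((sxx - syy) / 2) ** 2 + sxy ** 2)
    total = 2 * half_trace
    if total == 0:
        raise ValueError("all points are identical")
    return (half_trace + spread) / total


on_a_line = [(0, 0), (1, 2), (2, 4), (3, 6)]
symmetric = [(1, 0), (-1, 0), (0, 1), (0, -1)]
print(first_component_evr(on_a_line), first_component_evr(symmetric))
```

Points lying on a line give a ratio close to 1, while a symmetric cloud gives 0.5, so a high first-component ratio suggests the output varies mainly along one direction.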



In [None]:
# Create a cache of sample maths questions, based on the ST8, ST9, ST10 categorisation, in cfg.tricase_questions_dict
mmi.make_maths_tricase_questions(cfg)

cfg.useful_nodes.reset_node_tags(qmi.QType.MATH_ADD.value, mmi.MathsBehavior.ADD_PCA_TAG.value)
cfg.useful_nodes.reset_node_tags(qmi.QType.MATH_SUB.value, mmi.MathsBehavior.SUB_PCA_TAG.value)
cfg.useful_nodes.reset_node_tags(qmi.QType.MATH_NEG.value, mmi.MathsBehavior.NEG_PCA_TAG.value)

In [None]:
manual_pca = False

# Plot all attention heads with the clearest An selected. Data is manually selected
if cfg.model_name == "add_d5_l1_h3_t15K_s372001":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[ 12, 0, 0, 4 ],
    [ 12, 0, 2, 3 ],
    [ 13, 0, 0, 3 ],
    [ 14, 0, 0, 2 ],
    [ 15, 0, 0, 1 ]])

elif cfg.model_name == "add_d5_l2_h3_t15K_s372001":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[10, 0, 0, 2 ],
    [ 10, 0, 2, 1 ],
    [ 12, 0, 0, 3 ],
    [ 12, 1, 0, 3 ],
    [ 12, 1, 1, 4 ],
    [ 12, 1, 2, 4 ],
    [ 13, 0, 0, 3 ],
    [ 13, 1, 2, 2 ],
    [ 14, 0, 0, 2 ],
    [ 14, 1, 2, 2 ],
    [ 15, 0, 0, 1 ],
    [ 15, 1, 1, 1 ]])

elif cfg.model_name == "add_d6_l2_h3_t15K_s372001":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[11, 0, 1, 1 ],
    [ 11, 0, 0, 2 ],
    [ 12, 0, 0, 3 ],
    [ 13, 0, 0, 1 ],
    [ 13, 0, 1, 0 ],
    [ 14, 0, 0, 4 ],
    [ 14, 1, 1, 4 ],
    [ 15, 0, 0, 4 ],
    [ 15, 1, 1, 4 ],
    [ 15, 1, 2, 4 ],
    [ 16, 0, 0, 3 ],
    [ 16, 1, 1, 3 ],
    [ 17, 0, 0, 2 ],
    [ 17, 1, 1, 2 ],
    [ 18, 0, 0, 1 ]])


elif cfg.model_name == "add_d6_l2_h3_t20K_s173289":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[ 14, 1, 0, 5 ],
    [ 14, 1, 1, 4 ],
    [ 14, 1, 2, 4 ],
    [ 15, 1, 1, 4 ],
    [ 15, 1, 2, 4 ],
    [ 16, 1, 1, 3 ],
    [ 16, 1, 2, 3 ],
    [ 17, 1, 1, 2 ],
    [ 18, 1, 1, 1 ]])

elif cfg.model_name == "add_d6_l2_h3_t20K_s572091":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[ 10, 0, 0, 3 ],
    [ 11, 0, 0, 2 ],
    [ 12, 0, 0, 1 ],
    [ 14, 0, 0, 4 ],
    [ 14, 1, 0, 4 ],
    [ 14, 1, 1, 4 ],
    [ 14, 1, 2, 3 ],
    [ 15, 0, 0, 4 ],
    [ 15, 1, 0, 4 ],
    [ 15, 1, 1, 4 ],
    [ 15, 1, 2, 4 ],
    [ 16, 0, 0, 3 ],
    [ 16, 1, 0, 3 ],
    [ 16, 1, 1, 3 ],
    [ 17, 0, 0, 2 ],
    [ 17, 1, 1, 2 ],
    [ 18, 0, 0, 1 ]])

elif cfg.model_name == "add_d10_l2_h3_t40K_s572091":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[ 20, 0, 0, 1 ],
    [ 23, 1, 0, 8 ],
    [ 24, 1, 0, 7 ],
    [ 22, 1, 0, 8 ],
    [ 25, 1, 0, 6 ],
    [ 26, 1, 0, 5 ],
    [ 27, 1, 0, 4 ],
    [ 27, 1, 2, 4 ],
    [ 28, 1, 0, 3 ],
    [ 29, 1, 0, 2 ],
    [ 30, 1, 0, 1 ]])

elif cfg.model_name.startswith("sub_d6_l2_h3_t30K"):
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.MINUS,
    [[ 9, 0, 1, 3 ],
    [ 10, 0, 1, 2 ],
    [ 11, 0, 1, 1 ],
    [ 13, 0, 1, 4 ],
    [ 13, 1, 2, 5 ],
    [ 15, 0, 0, 5 ],
    [ 15, 1, 1, 0 ],
    [ 15, 1, 1, 1 ],
    [ 15, 1, 1, 2 ],
    [ 15, 1, 1, 3 ],
    [ 15, 1, 1, 4 ],
    [ 15, 1, 2, 0 ],
    [ 15, 1, 2, 1 ],
    [ 15, 1, 2, 2 ],
    [ 15, 1, 2, 3 ],
    [ 16, 0, 0, 0 ],
    [ 16, 0, 0, 1 ],
    [ 16, 0, 0, 2 ],
    [ 16, 0, 0, 3 ],
    [ 16, 0, 0, 4 ],
    [ 16, 0, 0, 5 ],
    [ 16, 1, 0, 4 ],
    [ 16, 1, 1, 0 ],
    [ 16, 1, 1, 1 ],
    [ 16, 1, 1, 2 ],
    [ 16, 1, 1, 3 ],
    [ 16, 1, 1, 4 ],
    [ 16, 1, 2, 0 ],
    [ 16, 1, 2, 1 ],
    [ 16, 1, 2, 2 ],
    [ 16, 1, 2, 3 ],
    [ 16, 1, 2, 4 ],
    [ 16, 1, 2, 5 ],
    [ 17, 0, 0, 0 ],
    [ 17, 0, 0, 1 ],
    [ 17, 0, 0, 2 ],
    [ 17, 0, 0, 3 ],
    [ 17, 1, 0, 3 ],
    [ 17, 1, 0, 4 ],
    [ 17, 1, 2, 4 ],
    [ 18, 0, 0, 0 ],
    [ 18, 0, 0, 1 ],
    [ 18, 0, 0, 2 ],
    [ 18, 1, 0, 2 ],
    [ 18, 1, 2, 3 ],
    [ 19, 0, 0, 0 ],
    [ 19, 1, 2, 2 ],
    [ 20, 0, 0, 0 ]])

elif cfg.model_name == "mix_d6_l3_h4_t40K_s372001":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[ 8, 1, 0, 4 ],
    [ 11, 1, 0, 1 ],
    [ 12, 1, 0, 0 ],
    [ 13, 1, 1, 1 ],
    [ 14, 2, 1, 5 ],
    [ 15, 2, 1, 4 ],
    [ 16, 2, 1, 3 ],
    [ 17, 2, 1, 2 ],
    [ 18, 1, 0, 4 ],
    [ 18, 2, 1, 1 ]])

elif cfg.model_name == "ins1_mix_d6_l3_h4_t40K_s372001":
  manual_pca = True
  mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
    [[10, 0, 0, 2 ],
    [ 10, 0, 1, 2 ],
    [ 11, 0, 0, 2 ],
    [ 11, 0, 1, 1 ],
    [ 13, 0, 1, 0 ],
    [ 13, 1, 3, 1 ],
    [ 14, 1, 2, 0 ],
    [ 14, 1, 2, 2 ],
    [ 14, 1, 3, 4 ],
    [ 15, 0, 3, 5 ],
    [ 15, 1, 2, 2 ],
    [ 15, 1, 3, 4 ],
    [ 16, 0, 3, 4 ],
    [ 16, 1, 2, 0 ],
    [ 16, 1, 2, 1 ],
    [ 16, 1, 2, 2 ],
    [ 16, 1, 3, 2 ],
    [ 17, 0, 3, 3 ],
    [ 17, 1, 2, 2 ],
    [ 17, 1, 3, 2 ],
    [ 18, 0, 3, 2 ],
    [ 18, 1, 3, 1 ],
    [ 19, 0, 3, 1 ],
    [ 19, 2, 0, 0 ],
    [ 19, 2, 1, 0 ],
    [ 20, 0, 0, 0 ],
    [ 20, 0, 3, 0 ]])

  mmi.manual_nodes_pca(cfg, mmi.MathsToken.MINUS,
    [[10, 0, 0, 2 ],
    [ 10, 0, 1, 2 ],
    [ 11, 0, 0, 2 ],
    [ 11, 0, 1, 1 ],
    [ 13, 0, 1, 0 ],
    [ 14, 0, 0, 4 ],
    [ 14, 0, 2, 5 ],
    [ 14, 1, 2, 0 ],
    [ 14, 1, 2, 2 ],
    [ 14, 1, 3, 4 ]])

# Part 19B: Set Up: Detect attention head output clustering w.r.t ST8/9/10 (Auto)

Automatic detection of attention heads that have output clustered into 2 or 3 clusters aligned to ST8, ST9 and ST10 categories.
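
One of the cluster-quality scores used for this detection is the Calinski-Harabasz score: the between-cluster dispersion divided by the within-cluster dispersion, each normalised by its degrees of freedom. The sketch below is a plain-Python version for 1D values, written for illustration only. The function name and sample values are hypothetical, and the notebook obtains its scores from the library rather than from this code.

```python
import math


def calinski_harabasz_1d(values, labels):
    """Calinski-Harabasz score for 1D values with given cluster labels."""
    clusters = {}
    for value, label in zip(values, labels):
        clusters.setdefault(label, []).append(value)
    n, k = len(values), len(clusters)
    if k < 2 or n <= k:
        raise ValueError("need at least 2 clusters and more points than clusters")
    overall_mean = sum(values) / n
    between = 0.0  # dispersion of cluster centres around the overall mean
    within = 0.0   # dispersion of points around their own cluster centre
    for members in clusters.values():
        centre = sum(members) / len(members)
        between += len(members) * (centre - overall_mean) ** 2
        within += sum((v - centre) ** 2 for v in members)
    if within == 0:
        return math.inf  # perfectly tight clusters
    return (between / (k - 1)) / (within / (n - k))


values = [0.0, 1.0, 10.0, 11.0]
good = calinski_harabasz_1d(values, [0, 0, 1, 1])  # labels match the two groups
bad = calinski_harabasz_1d(values, [0, 1, 0, 1])   # labels cut across the groups
print(good, bad)
```

Well-separated clusters score far higher than a labelling that cuts across the groups, which is why a minimum score can serve as a filter.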

In [None]:
# PCA explained variance [0] percent threshold
evr_perc_threshold = 30
# Silhouette Score threshold. Range is 0 to 100
silhouette_threshold = 20
# Calinski-Harabasz Score (aka Variance Ratio Criterion) threshold. Ratio of the sum of between-clusters dispersion and of within-cluster dispersion. Higher values indicate better-defined clusters.
calinski_harabasz_threshold = 50
# Label Agreement Score threshold. Custom metric to measure how well the clustering aligns with the three question types. It ranges from 0 to 100.
label_agreement_threshold = 50

In [None]:
def auto_node_pca(ax, index, node_location, operation, answer_digit):

    base_title, error_message = mmi._build_title_and_error_message(
        cfg=cfg, node_location=node_location, operation=operation, answer_digit=answer_digit
    )

    if (answer_digit, operation) in cfg.tricase_questions_dict:
        test_inputs = cfg.tricase_questions_dict[(answer_digit, operation)]
    else:
        return False

    # Full_title is "P10L0H1 A3 78/43/62/42" = NodeLocation AnswerDigit EVR[0]/MaxSilhouetteScore/MaxCalinskiHarabaszScore/MaxLabelAgreementScore
    pca, pca_attn_outputs, full_title, cluster_results = qmi.calc_pca_for_an(
        cfg=cfg, node_location=node_location, title=base_title, error_message=error_message, test_inputs=test_inputs
    )

    if pca is not None:
        evr_perc = qmi.pca_evr_0_percent(pca)
        if evr_perc > evr_perc_threshold:

            silhouette_scores = cluster_results['silhouette_scores']
            calinski_harabasz_scores = cluster_results['calinski_harabasz_scores']
            label_agreement_scores = cluster_results['label_agreement_scores']

            silhouette_score = max( silhouette_scores['2_clusters'], silhouette_scores['3_clusters'] )
            calinski_harabasz_score = max( calinski_harabasz_scores['2_clusters'], calinski_harabasz_scores['3_clusters'])
            label_agreement_score = max( label_agreement_scores['2_clusters'], label_agreement_scores['3_clusters'])

            if silhouette_score >= silhouette_threshold and calinski_harabasz_score >= calinski_harabasz_threshold and label_agreement_score > label_agreement_threshold:
                mmi.plot_pca_for_an(ax, pca_attn_outputs, full_title)

                major_tag = qmi.QType.MATH_ADD if operation == mmi.MathsToken.PLUS else qmi.QType.MATH_SUB # Does not handle NEG case
                cfg.add_useful_node_tag( node_location, major_tag.value, mmi.pca_op_tag(answer_digit, operation) )
                return True

    return False

In [None]:
def auto_find_pca(operation):
    print("Automatic PCA tags for", cfg.model_name, "with operation ", qmi.token_to_char(cfg, operation))
    title = cfg.model_name + "_PCA_" + qmi.token_to_char(cfg, operation)

    n_cols, n_rows, fig, axs = mmi.plot_nodes_pca_start_core(4, 5)

    index = 0
    for node in cfg.useful_nodes.nodes:
        if node.is_head:
          for answer_digit in range(cfg.n_digits+1):
            ax = axs[index // n_cols, index % n_cols]
            if auto_node_pca(ax, index, node, operation, answer_digit):
              index += 1

              if index == n_cols * n_rows:
                  mmi.plot_nodes_pca_end(n_cols, n_rows, axs, cfg, title, index)
                  n_cols, n_rows, fig, axs = mmi.plot_nodes_pca_start_core(4, 5)
                  index = 0

    if index > 0:
        mmi.plot_nodes_pca_end(n_cols, n_rows, axs, cfg, title, index)

# Part 19B: Results: Detect attention head output clustering w.r.t ST8/9/10 (Auto)

Automatic detection of attention heads that have output clustered into 2 or 3 clusters aligned to ST8, ST9 and ST10 categories.
Output includes sklearn warnings and may include plots that do not show 2 or 3 clusters.

In [None]:
if not manual_pca:
    if cfg.perc_add > 0:
        auto_find_pca(mmi.MathsToken.PLUS)
    if cfg.perc_sub > 0:
        auto_find_pca(mmi.MathsToken.MINUS)

# Part 20: Results: Show useful nodes and behaviour tags

In [None]:
cfg.useful_nodes.sort_nodes()
cfg.useful_nodes.print_node_tags()

# Part 21: Results: Save useful nodes and behaviour tags to json file

In [None]:
# Serialize and save the useful nodes list to a temporary CoLab file in JSON format
print( "Saving useful node list with behavior tags:", model_behaviors_fname)
cfg.useful_nodes.save_nodes(model_behaviors_fname)

# Part 22 : Results: Search for model algorithm tasks

Here we find which model nodes perform which specific algorithm task.
- **Automatic searches** for node purposes are preferred, as they are applicable to several models, and survive (non-significant, node-reordering) differences between models caused by differences in training.
- **Manually written tests** of node purposes, specific to a single model instance are also supported.

The qmi.search_and_tag function searches for a task on useful nodes by:
- **filtering** useful nodes, based on "tag" pre-requisites, to find the few nodes worth investigating. For more detail refer to https://github.com/PhilipQuirke/QuantaMaths/blob/main/filter.md
- **intervention ablation** testing on the interesting nodes:
  - The first "store" question is run without hooks
  - The second "clean" question is run with hooks injecting some data from the "store" run. This run gives, not a "clean" answer, but an "intervened" answer, which mixes the "store" answer and the "clean" answer. Our beliefs about the node's algorithmic purpose are baked into the store question, clean question and intervened answer.
- An **algorithm tag** is added to all interesting nodes that pass the intervention ablation test(s)
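
The store/clean/intervened pattern above can be illustrated with a toy example. This is a minimal sketch, not the `qmi.search_and_tag` implementation: `run_toy_adder` is a hypothetical two-digit adder whose intermediate values stand in for model nodes, and the node names used as dictionary keys are illustrative only.

```python
def run_toy_adder(d, d_prime, hooks=None):
    """Toy two-digit adder whose intermediate values act as 'nodes'.

    `hooks` maps a node name to a value that overrides that node's output.
    Returns the answer modulo 100 together with every node output.
    """
    hooks = hooks or {}
    acts = {}
    units_sum = d % 10 + d_prime % 10
    acts["A0.SA"] = hooks.get("A0.SA", units_sum % 10)   # units digit
    acts["A0.SC"] = hooks.get("A0.SC", units_sum // 10)  # carry out of the units
    acts["A1.SA"] = hooks.get("A1.SA", (d // 10 + d_prime // 10) % 10)
    tens_digit = (acts["A1.SA"] + acts["A0.SC"]) % 10
    return tens_digit * 10 + acts["A0.SA"], acts

# 1. "Store" question, run without hooks: 25+38 produces a carry (A0.SC = 1)
store_answer, store_acts = run_toy_adder(25, 38)

# 2. "Clean" question, run once unhooked and once with the stored carry injected
clean_answer, _ = run_toy_adder(11, 22)
intervened_answer, _ = run_toy_adder(11, 22, hooks={"A0.SC": store_acts["A0.SC"]})

print(store_answer, clean_answer, intervened_answer)  # 63 33 43
```

If the hooked node really computes the carry, only the tens digit of the clean answer should change (33 becomes 43). A node that passes such a prediction would earn the corresponding algorithm tag.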

In [None]:
cfg.useful_nodes.reset_node_tags(qmi.QType.ALGO.value)
acfg.show_test_failures = False
acfg.show_test_successes = False

## Part 22B: Automated An.SS search

Search for addition "Use Sum 9" (SS) tasks e.g. 34633+55555=+090188 where D3 and D'3 sum to 9 (4+5), and D2 + D'2 > 10.
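
As a hedged illustration of what makes a question an SS test case, the sketch below lists the digit positions (counting from D0, the least significant digit) whose digit pair sums to 9 while the digit immediately below generates a carry. The helper names are hypothetical, and the check deliberately ignores longer chains of 9s.

```python
def digit_pair_sums(d, d_prime, n_digits):
    """Return [D0+D'0, D1+D'1, ...], least significant digit first."""
    return [(d // 10**n) % 10 + (d_prime // 10**n) % 10 for n in range(n_digits)]

def find_ss_positions(d, d_prime, n_digits):
    """Digit indices n where Dn+D'n == 9 and the digit below makes a carry."""
    sums = digit_pair_sums(d, d_prime, n_digits)
    return [n for n in range(1, n_digits) if sums[n] == 9 and sums[n - 1] >= 10]

print(digit_pair_sums(34633, 55555, 5))    # [8, 8, 11, 9, 8]
print(find_ss_positions(34633, 55555, 5))  # [3]
```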

In [None]:
qmi.search_and_tag( cfg, acfg, mmi.add_ss_functions )

## Part 22C: Automated An.SC search

Search for addition "Make Carry 1" (SC) tasks e.g. 222222+666966=+0889188 where D2 + D'2 > 10.

(Sometimes the model uses ST **instead** of SC, and sometimes it uses ST **and** SC. For A1, the model can **accurately** use just SC. For A0, SC and ST are not needed.)

In [None]:
#acfg.show_test_failures = True
qmi.search_and_tag( cfg, acfg, mmi.add_sc_functions )

## Part 22D: Automated An.SA search

Search for addition "Simple Add" (SA) tasks e.g. 555555+111111=+0666666 where D3 + D'3 < 10

The SA task is sometimes split/shared over 2 attention heads in the same position and layer.
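
To make the role of SA concrete, here is a small sketch of ordinary digit-by-digit addition written in terms of the sub-task names. It is not a claim about how any particular model combines its nodes; `add_by_subtasks` is a hypothetical helper that reproduces the answer format used in this notebook.

```python
def add_by_subtasks(d, d_prime, n_digits):
    """Add two n-digit integers digit by digit, naming the per-digit sub-results."""
    answer_digits = []
    carry = 0
    for n in range(n_digits):
        pair_sum = (d // 10**n) % 10 + (d_prime // 10**n) % 10
        sa = pair_sum % 10  # "Simple Add": digit sum ignoring carries
        answer_digits.append((sa + carry) % 10)
        # Carry into the next digit: made directly (SC) or passed through a 9 (SS)
        carry = 1 if pair_sum + carry >= 10 else 0
    answer_digits.append(carry)  # most significant answer digit (0 or 1)
    return "+" + "".join(str(x) for x in reversed(answer_digits))

print(add_by_subtasks(555555, 111111, 6))  # +0666666
print(add_by_subtasks(222222, 666966, 6))  # +0889188
```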


In [None]:
#acfg.show_test_failures = True
qmi.search_and_tag( cfg, acfg, mmi.add_sa_functions,
                  do_pair_search = True, allow_impact_mismatch = True )

## Part 22E: Automated An.ST search

Search for A0.ST to A5.ST with impact "A65432" to "A65" in early tokens.

A0 and A1 are simple to calculate and so do NOT use An.ST or An.STm values. So A0 and A1 are excluded from the answer impact.
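
The sketch below shows how a tri-case value per digit pair is enough to resolve cascading carries. It assumes that the ST8, ST9 and ST10 categories mean a digit-pair sum of at most 8, exactly 9, and at least 10 respectively; the function names are hypothetical.

```python
def st(dn, dpn):
    """Tri-case class of one digit pair: 8 (no carry), 9 (passes a carry on), 10 (makes a carry)."""
    s = dn + dpn
    return 10 if s >= 10 else 9 if s == 9 else 8

def carry_into(n, d, d_prime):
    """Carry entering answer digit An, resolved from the ST values of digits n-1 .. 0."""
    for k in range(n - 1, -1, -1):  # scan downward from digit n-1
        case = st((d // 10**k) % 10, (d_prime // 10**k) % 10)
        if case == 10:
            return 1  # a definite carry reaches An
        if case == 8:
            return 0  # the carry chain is broken
        # case == 9: the outcome depends on the next lower digit
    return 0  # ran out of digits: no carry

# 34633+55555=+090188: carries enter A3 and A4 only
print([carry_into(n, 34633, 55555) for n in range(6)])  # [0, 0, 0, 1, 1, 0]
```

In this sketch A0 never needs a lower digit, which is in line with the simplest answer digits being excluded from the answer impact.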

In [None]:
qmi.search_and_tag( cfg, acfg, mmi.add_st_functions,
                  do_pair_search = True, allow_impact_mismatch = True )

## Part 22F: Automated An.MD search

Search for positive-answer subtraction "Difference" (MD) tasks e.g. 666666-222222=+0444444 where D3 >= D'3

The MD task may be split/shared over 2 attention heads in the same position at the same layer.


In [None]:
if cfg.perc_sub > 0:
  qmi.search_and_tag( cfg, acfg, mmi.sub_md_functions,
                    do_pair_search = True, allow_impact_mismatch = True )

## Part 22G: Automated An.MB search

Search for positive-answer subtraction "Borrow One" (MB) tasks e.g. 222222-111311=+0110911 where D2 < D'2

(Sometimes the model uses MT **instead** of MB, and sometimes it uses MT **and** MB. For A1, the model can **accurately** use just MB. For A0, MB and MT are not needed.)

In [None]:
if cfg.perc_sub > 0:
  qmi.search_and_tag( cfg, acfg, mmi.sub_mb_functions,
                    allow_impact_mismatch = True )

## Part 22H: Automated An.MT search

For accuracy, the addition algorithm calculates cascading "carry one" in early tokens using the An.ST sub-task. Paralleling this, the subtraction algorithm calculates cascading "borrow one" in early tokens using the An.MT sub-task.

This section locates An.MT sub-tasks.

Define An.MT = +1 if Dn > D'n else 0 if Dn == D'n else -1  
The cascading "borrow one" calculation is then:  
A3.MV = fn(A3.MT, fn(A2.MT, fn(A1.MT, A0.MT)))  
where fn(A,B) = +1 if A == 1 or (A == 0 and B <> -1) else -1  
and the output "-1" means a cascading borrow one.

The above fn could be simplified, but the (below) GT sub-task often relies on the above definition. The tricase An.MT definition also mirrors the addition An.ST definition.
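
The definitions above translate directly into Python. This is a sketch of the arithmetic only, not of the library's test functions; the helper names `digit`, `mt`, `fn` and `mv` are chosen here for illustration.

```python
def digit(x, k):
    """Digit k of x, where k = 0 is the least significant digit."""
    return (x // 10**k) % 10

def mt(dn, dpn):
    """An.MT: +1 if Dn > D'n, 0 if equal, -1 if Dn < D'n."""
    return (dn > dpn) - (dn < dpn)

def fn(a, b):
    """Combine a higher digit's MT value `a` with the result `b` from the lower digits."""
    return 1 if a == 1 or (a == 0 and b != -1) else -1

def mv(n, d, d_prime):
    """An.MV = fn(An.MT, fn(An-1.MT, ... fn(A1.MT, A0.MT))), for n >= 1."""
    result = mt(digit(d, 0), digit(d_prime, 0))
    for k in range(1, n + 1):
        result = fn(mt(digit(d, k), digit(d_prime, k)), result)
    return result

# 222222-111311=+0110911: digits 2..0 of D are smaller than those of D' (222 < 311)
print([mv(n, 222222, 111311) for n in range(1, 4)])  # [1, -1, 1]
```

Here `mv(2, ...)` is -1, which corresponds to the borrow taken by the next digit up (A3 = 2 - 1 - 1 = 0 in the example).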

In [None]:
if cfg.perc_sub > 0:
    #acfg.show_test_failures = True
    #acfg.show_test_successes = True
    qmi.search_and_tag( cfg, acfg, mmi.sub_mt_functions,
                      do_pair_search = True)

## Part 22I: Automated OPR search

For mixed models that do addition and subtraction the operation token "+/-" (in the middle of the question) is key. Find nodes that attend to the question operation.

In [None]:
if cfg.perc_sub > 0 and cfg.perc_add > 0 :
  qmi.search_and_tag( cfg, acfg, mmi.opr_functions )

## Part 22J: Automated SGN search

For mixed models that do addition and subtraction, and for our subtraction models, the answer sign token "+/-" (at the start of the answer) is important. Find nodes that attend to the answer sign token.

In [None]:
if cfg.perc_sub > 0:
  qmi.search_and_tag( cfg, acfg, mmi.sgn_functions )

## Part 22K: Automated An.ND search

Search for negative-answer subtraction Difference (ND) tasks e.g. 033333-111111=-077778 where D < D'

The ND task may be split/shared over 2 attention heads in the same position at the same layer.

In [None]:
if cfg.perc_sub > 0:
  qmi.search_and_tag( cfg, acfg, mmi.neg_nd_functions,
                    do_pair_search = True, allow_impact_mismatch = True)

## Part 22L: Automated An.NB search

Search for negative-answer subtraction Borrow One (NB) tasks e.g. 033333-111411=-078078 where D < D' and D2 < D'2

(Sometimes the model uses NT **instead** of NB, and sometimes it uses NT **and** NB.)

In [None]:
if cfg.perc_sub > 0:
  qmi.search_and_tag( cfg, acfg, mmi.neg_nb_functions,
                    allow_impact_mismatch = True)

## Part 22M: Automated An.GT search

Both SUB (e.g. 00600-00201=+000399) and NEG (00100-00201=-000101) questions rely on knowing whether D > D'. How is this calculated?

Approach 1: Model has specific GT nodes:
Define An.GT = +1 if Dn > D'n else 0 if Dn == D'n else -1  
When n_digits = 4, D > D' = fn(A3.GT, fn(A2.GT, fn(A1.GT, A0.GT)))
where fn(A,B) = +1 if A == 1 or (A == 0 and B <> -1) else -1

Approach 2: Model leverages the existing MT nodes:
When n_digits = 4, D > D' = fn(A3.MT, fn(A2.MT, fn(A1.MT, A0.MT)))
where fn(A,B) = +1 if A == 1 or (A == 0 and B <> -1) else -1

Usually, one node performs both say A3.MT and A3.GT sub-tasks, but in some models the A3.MT and A3.GT functions are performed by distinct nodes. Hence we test for the MT and GT behavior separately.

Both approaches mirror the calculation style used in ADD to calculate Amax as 1 or 0.
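
As a sanity check on the cascade, the sketch below folds the per-digit GT values with `fn` and compares the result against Python's own integer comparison for every pair of two-digit operands. The helper names are hypothetical. Note that, under the definition as written, equal operands yield +1, so the cascade effectively tests D >= D'.

```python
from functools import reduce

def gt(dn, dpn):
    """An.GT: +1 if Dn > D'n, 0 if equal, -1 if Dn < D'n."""
    return (dn > dpn) - (dn < dpn)

def fn(a, b):
    """+1 if a == 1 or (a == 0 and b != -1), else -1."""
    return 1 if a == 1 or (a == 0 and b != -1) else -1

def cascade(d, d_prime, n_digits):
    """fn(A[n-1].GT, ... fn(A1.GT, A0.GT)), folded from the least significant digit up."""
    gts = [gt((d // 10**k) % 10, (d_prime // 10**k) % 10) for k in range(n_digits)]
    return reduce(lambda lower, higher: fn(higher, lower), gts[1:], gts[0])

# Exhaustive check over all two-digit operand pairs
mismatches = [(d, dp) for d in range(100) for dp in range(100)
              if (cascade(d, dp, 2) == 1) != (d >= dp)]
print(len(mismatches))  # 0
```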

In [None]:
if cfg.perc_sub > 0:
    #acfg.show_test_successes = False
    qmi.search_and_tag(cfg, acfg, mmi.sub_gt_functions,
                      allow_impact_mismatch = True)

# Part 23: Show algorithm quanta map

This map shows a compacted view of all useful token positions (horizontally) and all useful attention heads and MLP layers
(vertically) used in predictions as blue cells. In each cell, the algorithm sub-task(s) Base Add SA, Make Carry SC, TriCase ST, etc found by automated search with ablation testing are shown.

Sometimes a subtask is logically shared across two attention heads. The SA, MD and ND subtasks sometimes do this.

This map plots the "algorithm" tags generated in previous steps as a quanta map. This is an automatically generated partial explanation of the model algorithm.

Nodes with multiple tags were tagged (found) by more than one of the above subtask searches.

In [None]:
# If a cell only has an OPR tag or only has a SGN tag then we do not understand its purpose.
# The tag is just an "attention" fact. Remove these tags from the algorithm map
# (A cell that has both OPR and SGN tags, we believe it is a "Select question case" node. We keep it)
for node in cfg.useful_nodes.nodes:
    tags = node.filter_tags(qmi.QType.ALGO.value)
    if len(tags) == 1:
        only_tag = tags[0]
        if mmi.MathsTask.OPR_TAG.value in only_tag or mmi.MathsTask.SGN_TAG.value in only_tag:
            print( "Removing", node.name(), only_tag)
            node.reset_tags(qmi.QType.ALGO.value)

In [None]:
print_config()
qmi.print_algo_purpose_results(cfg)

In [None]:
# Show useful nodes that have identified algorithm sub-task tags
show_quanta_map( "Maths Purpose Per Node", qmi.QType.ALGO, "", qmi.get_quanta_binary,
                 cell_num_shades = 2)

In [None]:
# Show ALL useful nodes with their algorithm sub-task tags (if any)
show_quanta_map( "Maths Purpose All Nodes", qmi.QType.IMPACT, "", qmi.get_quanta_algo,
                 cell_num_shades = 3)

# Part 24: Show known quanta per answer digit

Each of the late positions is solely focused on calculating one answer digit. Show the data we have collected on each late answer digit.



In [None]:
for position in range(cfg.num_question_positions + 1, cfg.n_ctx - 1):
  print("Position:", position)

  # Calculate a table of the known quanta for the specified position for each late token position
  mmi.calc_maths_quanta_for_position_nodes(cfg, position)

  qmi.save_plt_to_file(cfg=cfg, full_title="Quanta At "+ qmi.position_name(position))

  plt.show()

# Part 25: Compare ST and SC
The sub-tasks ST and SC are similar: They both take Dn,D'n inputs (10x10) and generate "carry one" outputs. They differ in that ST occurs in early tokens and has tri-state output, whereas SC occurs in late tokens and has bi-state output. For a sample mixed model, this figure shows PCA results comparing ST and SC output for A2 and A3.
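
The difference between the two output spaces can be seen by enumerating the 10x10 input grid. This sketch assumes that ST8, ST9 and ST10 mean a digit-pair sum of at most 8, exactly 9, and at least 10, and that SC outputs 1 exactly when the sum is at least 10.

```python
from collections import Counter

st_counts = Counter()
sc_counts = Counter()
for dn in range(10):
    for dpn in range(10):
        s = dn + dpn
        # ST: tri-state output
        st_counts["ST10" if s >= 10 else "ST9" if s == 9 else "ST8"] += 1
        # SC: bi-state output (carry or no carry)
        sc_counts[1 if s >= 10 else 0] += 1

print(dict(st_counts))
print(dict(sc_counts))
```

Under these assumptions the only inputs on which the two sub-tasks must differ are the 10 digit pairs that sum to 9: ST gives them their own class, while SC groups them with the "no carry" cases.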

In [None]:
if cfg.perc_add > 0 and cfg.n_layers >= 2:
    a2st = cfg.useful_nodes.get_node_by_tag(qmi.QType.ALGO.value, "A2.ST")
    a2sc = cfg.useful_nodes.get_node_by_tag(qmi.QType.ALGO.value, "A2.SC")
    a3st = cfg.useful_nodes.get_node_by_tag(qmi.QType.ALGO.value, "A3.ST")
    a3sc = cfg.useful_nodes.get_node_by_tag(qmi.QType.ALGO.value, "A3.SC")

    if a2st is not None and a2sc is not None and a3st is not None and a3sc is not None:
      # For all nodes, the attention head output may be transformed by the MLP layer. The images below do NOT show this.
      mmi.manual_nodes_pca(cfg, mmi.MathsToken.PLUS,
          [[a2st.position, a2st.layer, a2st.num, 2], # A2.ST Trigram needed
          [a2sc.position, a2sc.layer, a2sc.num, 2],  # A2.SC Bigram needed (Trigram superset okay)
          [a3st.position, a3st.layer, a3st.num, 3],  # A3.ST Trigram needed
          [a3sc.position, a3sc.layer, a3sc.num, 3]]) # A3.SC Bigram needed (Trigram superset okay)

# Part 26: Save useful nodes with behaviour and algorithm tags to JSON file

Show a list of the nodes that have proved useful in calculations, together with data on the nodes' behavior and algorithmic purposes.
Save the data to a Colab temporary JSON file.



In [None]:
cfg.useful_nodes.print_node_tags(qmi.QType.ALGO.value, "", False)

In [None]:
# Serialize and save the useful nodes list with feature tags to a temporary CoLab file in JSON format
print( "Saving useful node list with feature tags:", model_features_fname)
cfg.useful_nodes.save_nodes(model_features_fname, qmi.QType.ALGO.value)