# Verified Integer Mathematics in Transformers - Analyse the Model

This CoLab analyses a Transformer model that performs integer addition, subtraction and multiplication e.g. 133357+182243=+0315600, 123450-345670=-0123230 and 000345*000823=+283935. Each digit is a separate token. For 6 digit questions, the model is given 14 "question" (input) tokens, and must then predict the corresponding 8 "answer" (output) tokens.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries.

Imports the "QuantaTools" public library (installed from the verified_transformers repository) as "qt". This library is specific to this CoLab's "QuantaTool" approach to transformer analysis. Refer to https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md for more detail.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install matplotlib

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy
    !pip install scikit-learn

except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for Colab than for VSCode/Jupyter notebooks
import kaleido
import plotly.io as pio

# Colab sessions, and any run with DEVELOPMENT_MODE switched off, use the "colab"
# renderer; local development notebooks use "notebook_connected".
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the axis-title and figure-title fonts of the 'plotly' template
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
import itertools
import re
from enum import Enum

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


In [None]:
# Import Principal Component Analysis (PCA) library.
# use_pca records whether the import worked; the PCA cells later in the notebook test it.
use_pca = True
try:
  from sklearn.decomposition import PCA
except Exception as e:
  print("pca import failed with exception:", e)
  use_pca = False

  # Sometimes version conflicts mean the PCA library does not import. This workaround partially fixes the issue
  !pip install --upgrade numpy
  !pip install --upgrade scikit-learn

  # To complete workaround, now select menu option "Runtime > Restart session and Run all".
  # NOTE(review): "stop" is not defined in the visible code, so this presumably raises
  # NameError on purpose to halt "Run all" at this cell - confirm.
  stop

In [None]:
# Refer https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md
!pip install --upgrade git+https://github.com/PhilipQuirke/verified_transformers.git
import QuantaTools as qt

# Part 1A: Configuration

Which existing model do we want to analyse?

The existing model weightings created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace. Refer https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md for more detail.

In [None]:
# Singleton QuantaTool configuration class. MathsConfig is derived from AlgoConfig > UsefulConfig > ModelConfig
# cfg is read and updated by the cells that follow.
cfg = qt.MathsConfig()

# Singleton QuantaTool ablation intervention configuration class
# acfg is passed alongside cfg to the test and intervention helpers below.
acfg = qt.acfg

In [None]:
# Which existing model do we want to analyse?
# Uncomment exactly one line. cfg.parse_model_name() (called in Part 1B) derives
# n_digits, n_layers, n_heads and n_training_steps from this name.
# cfg.model_name = "" # Use configuration specified in Part 1A
# cfg.model_name = "add_d5_l1_h3_t30K"  # 5 digit addition model. Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions
# cfg.model_name = "add_d5_l2_h3_t15K"  # 5 digit addition model
# cfg.model_name = "add_d6_l2_h3_t15K"  # 6 digit addition model
# cfg.model_name = "sub_d6_l2_h3_t30K"  # 6 digit subtraction model
# cfg.model_name = "mix_d6_l3_h4_t40K"  # 6 digit addition and subtraction model. AvgFinalLoss=8e-09
cfg.model_name = "ins1_mix_d6_l3_h4_t40K"  # 6 digit addition / subtraction model. Initialise with addition model. Handles 1m Qs for Add and Sub
# cfg.model_name = "ins2_mix_d6_l4_h4_t40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads every 100 epochs. AvgFinalLoss=7e-09. Fails 1m Qs
# cfg.model_name = "ins3_mix_d6_l4_h3_t40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads & MLPs every 100 epochs. AvgFinalLoss=2.6e-06. Fails 1m Qs

# Part 1B: Configuration: Input and Output file names



In [None]:
# Clear state left over from an earlier run. Needed when user changes model_name and reruns this Colab a second time
cfg.reset_useful()
cfg.reset_algo()
cfg.initialize_maths_token_positions()
acfg.reset_ablate()

if cfg.model_name != "":
  # Update cfg member data n_digits, n_layers, n_heads, n_training_steps from model_name
  cfg.parse_model_name()

  # Default: a pure addition model. The branches below override this per model.
  cfg.perc_sub = 0

  if cfg.model_name.startswith("sub_"):
    cfg.perc_sub = 100

  elif cfg.model_name == "mix_d6_l3_h4_t40K":
    cfg.batch_size = 256
    cfg.perc_sub = 66 # Train on 66% subtraction and 33% addition question batches

  elif cfg.model_name in ("ins1_mix_d6_l3_h4_t40K", "ins2_mix_d6_l4_h4_t40K", "ins3_mix_d6_l4_h3_t40K"):
    # All three "ins" models were initialised with add_d6_l2_h3_t15K.pth:
    #   ins1: no resets
    #   ins2: trained & reset useful heads every 100 epochs
    #   ins3: trained & reset useful heads & MLPs every 100 epochs
    cfg.batch_size = 256
    cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches

In [None]:
# All input/output file names share the prefix derived from the model configuration
main_fname = cfg.file_config_prefix()
main_fname_pth = main_fname + '.pth'
main_fname_behavior_json = main_fname + '_behavior.json'
main_fname_algorithm_json = main_fname + '_algorithm.json'

def print_config():
  """Print the multiplication / subtraction / addition percentages and the model file-name prefix."""
  print("%Mult=", cfg.perc_mult, "%Sub=", cfg.perc_sub, "%Add=", cfg.perc_add(), "File", main_fname)

print_config()
# The weights are downloaded from HuggingFace in Part 5 (the message previously said "HuggingLab")
print('Main model will be read from HuggingFace file', main_fname_pth)
print('Main model behavior analysis tags will save to Colab temporary file', main_fname_behavior_json)
print('Main model algorithm analysis tags will save to Colab temporary file', main_fname_algorithm_json)

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

  

In [None]:
# Set up the maths vocabulary and the question-token meanings on cfg,
# then show the resulting meaning of each token position
qt.set_maths_vocabulary(cfg)
qt.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

# Part 3B: Set Up: Create model

In [None]:
# Transformer creation
# Builds an untrained HookedTransformer from cfg; the trained weights are loaded in Part 5.

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp(),
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    n_ctx = cfg.n_ctx(),
    init_weights = True,
    # NOTE(review): the device is hard-coded, so this cell needs a GPU runtime
    # (see "Tips for using the Colab").
    device = "cuda",
    seed = cfg.training_seed,
)

cfg.main_model = HookedTransformer(ht_cfg)

# Part 4: Set Up: Loss Function & Data Generator
This maths loss function and data generator are imported from QuantaTools as logits_to_tokens_loss, loss_fn, maths_data_generator_core and maths_data_generator.

In [None]:
# Define "iterator" maths "questions" data generator function. Invoked using next().
# Each next(ds) yields one batch of question tokens (indexed as a 2D array in the next cell).
ds = qt.maths_data_generator( cfg )

In [None]:
# Generate sample data generator (unit test)
# Draw one batch and print its first 3 rows as a sanity check.
print(next(ds)[:3,:])

# Part 5: Set Up: Load Model from HuggingFace

In [None]:
# HuggingFace repository holding the trained weights
main_repo_name="PhilipQuirke/VerifiedArithmetic"
print("Loading model from HuggingFace", main_repo_name, main_fname_pth)

# Download the .pth state dict and load it into the model created in Part 3B
cfg.main_model.load_state_dict(utils.download_file_from_hf(repo_name=main_repo_name, file_name=main_fname_pth, force_is_torch=True))
# Switch the model to evaluation mode: this notebook only analyses it, it does not train it
cfg.main_model.eval()

# Part 7A: Set Up: Create sample maths questions

Create a batch of manually-curated mathematics test questions, and cache some sample model prediction outputs.

In [None]:
# Batch of manually-curated test questions (with answers); one row per question
varied_questions = qt.make_maths_test_questions_and_answers(cfg)
num_varied_questions = varied_questions.shape[0]

# Install the ablation hooks, then calculate mean values over the sample questions
# NOTE(review): presumably these means are what the ablation experiments substitute in - confirm in QuantaTools.
qt.a_set_ablate_hooks(cfg)
qt.a_calc_mean_values(cfg, varied_questions)

In [None]:
# Report how many sample questions there are and the token length of one question
print("Num questions:", num_varied_questions, "Question length:", len(varied_questions[0]))

# Part 7B: Results: Can the model correctly predict sample questions?

Ask the model to predict the varied_questions (without intervention) to see if the model gets them all right. Categorise answers by complexity

In [None]:
# Test maths question prediction accuracy on the sample questions provided.
# Does NOT use acfg.* or UsefulInfo.* information
# Used to estimate the accuracy of the model's predictions.
# Returns a reduced set of questions - removing questions that the model failed to answer.
print_config()
# Uncomment the next line to print each failing question
#acfg.show_test_failures = True
varied_questions = qt.test_maths_questions_by_complexity(cfg, acfg, varied_questions)
# Refresh the count, because failed questions have been removed
num_varied_questions = varied_questions.shape[0]

# Part 9 : Results: Can the model do 1 million questions without error?

If the model passes this test, this is evidence (not proof) that the model is fully accurate. There may be very rare edge cases (say 1 in ten million) that did not appear in the test questions. Even if you believe you know all the edge cases, and have enriched the training data to contain them, you may not have thought of all edge cases, so this is not proof.

If the model fails this test:
- Add a few of the failures into the "test questions" created in Part 7A
- Understand the "use case(s)" driving these failures
- Alter the Training CoLab data_generator_core to enrich the training data with examples of these use case(s) and retrain the model.  

Takes ~25 mins to run (successfully) for ins_mix_d6_l3_h4_t40K_s372001

In [None]:
def one_million_questions_core():
  """Test the model on about one million generated questions, stopping at the first failing batch.

  Uses the current cfg.perc_sub / cfg.perc_mult settings to decide which kind of
  question is generated. Replaces the global data generator ds, sets
  acfg.verbose to False and sets cfg.analysis_seed. Prints progress every 100
  batches, then a summary, and a warning if any batch contained a failure.
  Returns None.
  """
  global ds

  acfg.verbose = False

  cfg.analysis_seed = 345621 # Randomly chosen
  # Re-initialise the data generator. Pass cfg, as the original creation of ds does;
  # calling the generator factory without it omits its configuration argument.
  ds = qt.maths_data_generator( cfg )

  the_successes = 0
  the_fails = 0

  num_batches = 1000000//cfg.batch_size
  for epoch in range(num_batches):
      tokens = next(ds)

      # Number of questions in this batch that the model got wrong
      the_fails = qt.test_maths_questions_by_impact(cfg, acfg, tokens, 0, False)

      if the_fails > 0:
        break

      the_successes += cfg.batch_size

      if epoch % 100 == 0:
          print("Batch", epoch, "of", num_batches, "#Successes=", the_successes)

  print("successes", the_successes, "num_fails", the_fails)
  if the_fails > 0:
    # This was a bare string expression, so the warning was never shown
    print("WARNING: Model is not fully accurate. It failed the 1M Q test")

In [None]:
def one_million_questions():
  """Run the one-million-question test separately for each operation the model supports.

  Addition is tested when cfg.perc_add() > 0 and subtraction when cfg.perc_sub > 0.
  cfg.perc_sub and cfg.perc_mult are temporarily overridden so that each run
  generates a single kind of question, and are always restored afterwards.
  Returns None.
  """
  store_perc_sub = cfg.perc_sub
  store_perc_mult = cfg.perc_mult

  print_config()
  print()

  try:
    if cfg.perc_add() > 0:
      print("Addition:")
      cfg.perc_sub = 0
      cfg.perc_mult = 0
      one_million_questions_core()

    if store_perc_sub > 0:
      print("Subtraction:")
      cfg.perc_sub = 100
      cfg.perc_mult = 0
      one_million_questions_core()
      print()

  finally:
    # Restore the original mix even if the (long) run raises or is interrupted,
    # so later cells do not see the temporary single-operation settings.
    cfg.perc_sub = store_perc_sub
    cfg.perc_mult = store_perc_mult


# Takes ~25 minutes to run
# one_million_questions()

# Part 10: Set Up: Which token positions are used by the model?

Ablate all nodes in each (question and answer) token position (by overriding the model memory aka residual stream). If the model's prediction loss increases, the token position is useful to the algorithm. Unused token positions are excluded from further analysis. Used to populate the UsefulInfo.useful_positions data. This is token **position level** information.  

In [None]:
# One entry per token position: the number of failed questions when that
# position is ablated, or "." when ablating it causes no failures.
num_failures_list = []

for position in range(cfg.n_ctx()):
  # Test accuracy of model in predicting question answers. Ablates all nodes at acfg.ablate_position. Does NOT use UsefulInfo.* information.
  num_fails = qt.test_maths_questions_by_impact(cfg, acfg, varied_questions, position, True)

  if num_fails > 0:
    # Add position to UsefulInfo.useful_positions
    cfg.add_useful_position(position)

  num_failures_list.append(num_fails if num_fails > 0 else ".")

In [None]:
def save_plt_to_file( full_title ):
  """Save the current matplotlib figure, if cfg.graph_file_suffix is set.

  The file name is the model file prefix + full_title + '.' + suffix, with all
  spaces, commas, colons and hyphens removed. Does nothing when the suffix is empty.
  """
  if not cfg.graph_file_suffix > "":
    return

  raw_name = cfg.file_config_prefix() + full_title + '.' + cfg.graph_file_suffix
  # Strip the four unwanted characters in a single pass
  filename = raw_name.translate(str.maketrans("", "", " ,:-"))
  plt.savefig(filename, bbox_inches='tight', pad_inches=0)

# Part 11: Results: Which token positions are used by the model?

Which token positions are used in the model's predictions? Unused token positions are excluded from further analysis.




In [None]:
# Summarise which token positions the model relies on
print_config()
print("num_questions=", num_varied_questions)
print("useful_positions=", cfg.useful_positions )
print()

# Draw the per-position failure map from the counts gathered in Part 10,
# save it (when a graph file suffix is configured) and display it
cfg.calc_position_failures_map(num_failures_list)
save_plt_to_file("Failures When Position Ablated")
plt.show()

# Part 12A: Set Up: Which nodes are used by the model?

Here we ablate each (attention head and MLP neuron) node in each (question and answer) token position to see if the model's prediction loss increases. If loss increases then the "node + token position" is used by the algorithm. Used to calculate the UsefulInfo.useful_node_location. This is **position+node level** information.


In [None]:
# Start from an empty list of useful nodes
cfg.useful_nodes = qt.UsefulNodeList()

acfg.verbose = False
# Ablate MLPs, then attention heads; test_maths_questions_and_add_useful_node_tags is passed as the test callback
qt.ablate_mlp_and_add_useful_node_tags(cfg, varied_questions, qt.test_maths_questions_and_add_useful_node_tags)
qt.ablate_head_and_add_useful_node_tags(cfg, varied_questions, qt.test_maths_questions_and_add_useful_node_tags)
# Add attention tags to the nodes (shown later in the attention map)
qt.add_node_attention_tags(cfg, varied_questions)

cfg.useful_nodes.sort_nodes()

# Part 12B: Results: Which nodes are used by the model?

Here are the (attention head and MLP neuron) nodes in each (question and answer) token position used by the model during predictions.

In [None]:
# List every useful node together with the tags gathered so far
cfg.useful_nodes.print_node_tags()

 # Part 13: Set up: Show and save Quanta map

 Using the UsefulNodes and filtering their tags, show a 2D map of the nodes and the tag minor versions.

In [None]:
def show_quanta_map( title, standard_quanta, shades, filters : qt.FilterNode, major_tag : qt.QType, minor_tag : str, get_node_details, base_fontsize = 10, max_width = 10 ):
  """Draw a 2D quanta map of the useful nodes and either save or title it.

  Args:
    title: Map title; also used to build the file name when the map is saved.
    standard_quanta, shades, get_node_details, base_fontsize, max_width:
      Passed unchanged to qt.calc_quanta_map.
    filters: Optional node filter. None means use every node in cfg.useful_nodes.
    major_tag: Tag type to map; its .value is passed to qt.calc_quanta_map.
    minor_tag: Minor tag text passed to qt.calc_quanta_map.

  Returns:
    The num_results value returned by qt.calc_quanta_map.
  """
  # "is None" is the identity test; "== None" would defer to any __eq__ the filter class defines
  test_nodes = cfg.useful_nodes if filters is None else qt.filter_nodes(cfg.useful_nodes, filters)

  ax1, quanta_results, num_results = qt.calc_quanta_map(cfg, standard_quanta, shades, test_nodes, major_tag.value, minor_tag, get_node_details, base_fontsize, max_width )

  if cfg.graph_file_suffix > "":
    # Saved maps are written without a title
    print("Saving quanta map:", title)
    save_plt_to_file(title)
  else:
    ax1.set_title(cfg.file_config_prefix() + ' ' + title + ' ({} nodes)'.format(len(quanta_results)))

  # Show plot
  plt.show()

  return num_results

# Part 16A: Results: Show failure percentage map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

A cell containing "< 1" may add some risk to the accuracy of the overall analysis process. Check to see if this represents a new use case. Improve the test data set to contain more instances of this (new or existing) use case.

In [None]:
# Map of the failure percentage per node, over all useful nodes (no filter), base font size 9
show_quanta_map( "Failure Frequency Behavior Per Node", True, qt.FAIL_SHADES, None, qt.QType.FAIL, "", qt.get_quanta_fail_perc, 9)

# Part 16B - Show answer impact behavior map

Show the purpose of each useful cell by impact on the answer digits A0 to Amax.


In [None]:
# Map of which answer digits each useful node impacts.
# NOTE(review): the "shades" argument here is cfg.num_answer_positions, whereas the other maps pass a qt.*_SHADES constant - confirm this is intended.
show_quanta_map( "Answer Impact Behavior Per Node", True, cfg.num_answer_positions, None, qt.QType.IMPACT, "", qt.get_quanta_impact, 9, 6)

# Part 16C: Result: Show attention map

Show attention quanta of useful heads

In [None]:
# Only maps attention heads, not MLP layers
# Uses the attention tags added to the nodes in Part 12A.
show_quanta_map( "Attention Behavior Per Head", True, qt.ATTN_SHADES, None, qt.QType.ATTN, "", qt.get_quanta_attention, 10, 6)

# Part 16D - Show question complexity map

Show the "minimum" addition purpose of each useful cell by S0 to S5 quanta.
Show the "minimum" subtraction purpose of each useful cell by M0 to M5 quanta

In [None]:
# Only draw the addition complexity map when the model handles addition questions
if cfg.perc_add() > 0:
  show_quanta_map( "Addition Min-Complexity Behavior Per Node", False, qt.MATH_ADD_SHADES, None, qt.QType.MATH_ADD, "", qt.get_maths_min_complexity)

In [None]:
# Only draw the subtraction complexity map when the model handles subtraction questions
if cfg.perc_sub > 0:
  show_quanta_map( "Subtraction Min-Complexity Behavior Per Node", False, qt.MATH_SUB_SHADES, None, qt.QType.MATH_SUB, "", qt.get_maths_min_complexity)

# Part 19A: Set Up: Calc and graph PCA decomposition.


In [None]:
# Create a cache of sample maths questions based on the T8, T9, T10 categorisation in cfg.tricase_questions_dict
# The dictionary is later indexed by (answer_digit, operation) when calculating each PCA.
qt.make_maths_tricase_questions(cfg)

def pca_evr_0_percent(pca):
  """Return the explained-variance ratio of the first principal component as a whole-number percentage."""
  first_component_ratio = pca.explained_variance_ratio_[0]
  rounded_percentage = round(first_component_ratio * 100, 0)
  return int(rounded_percentage)

In [None]:
# Calculate one Principal Component Analysis
def calc_pca_for_an(node_location, operation, answer_digit):
  """Fit a 6-component PCA to one attention head's output over the tricase questions.

  Args:
    node_location: Location of an attention head (is_head must be true); its
      layer, position and num attributes select the head output to analyse.
    operation: Operation token; with answer_digit it forms the key into
      cfg.tricase_questions_dict.
    answer_digit: Answer digit An whose tricase questions are used.

  Returns:
    (pca, pca_attn_outputs, title) on success. On any exception the failure is
    printed and (None, None, None) is returned, so callers must check for None.
  """
  assert node_location.is_head

  try:
    t_questions = cfg.tricase_questions_dict[(answer_digit, operation)]

    _, the_cache = cfg.main_model.run_with_cache(t_questions)

    # Output of individual heads, without final bias. This lookup does not depend
    # on the question, so it is done once rather than once per question.
    attention_cache = the_cache["result", node_location.layer, "attn"]

    # Gather this head's output for all the (randomly chosen) questions.
    # attention_cache[i] has shape [n_ctx, n_head, d_model]
    attn_outputs = torch.stack(
      [attention_cache[i][node_location.position, node_location.num, :] for i in range(len(t_questions))],
      dim=0).cpu()

    pca = PCA(n_components=6)
    pca.fit(attn_outputs)
    pca_attn_outputs = pca.transform(attn_outputs)

    title = node_location.name() + ', A'+str(answer_digit) + ', EVR[0]=' + str(pca_evr_0_percent(pca)) + '%'

    return pca, pca_attn_outputs, title
  except Exception as e:
    # Best effort: report the node that failed and let the caller carry on
    print( "calc_pca_for_an Failed:" + node_location.name() + " " + qt.token_to_char(cfg, operation) + " " + qt.answer_name(answer_digit), e)
    return None, None, None

In [None]:
# Plot the PCA of PnLnHn's attention pattern, using T8, T9, T10 questions that differ in the An digit
def plot_pca_for_an(ax, pca_attn_outputs, title):
  """Scatter the first two principal components, coloured by question group.

  The rows of pca_attn_outputs are taken as three consecutive groups: the first
  qt.TRICASE_QUESTIONS rows are T8 (red), the next qt.TRICASE_QUESTIONS rows are
  T9 (green) and the remaining rows are T10 (blue). A non-empty title is set on ax.
  """
  group_size = qt.TRICASE_QUESTIONS
  groups = [
    (slice(None, group_size), 'red', 'T8 (0-8)'), # t8 questions
    (slice(group_size, 2*group_size), 'green', 'T9'), # t9 questions
    (slice(2*group_size, None), 'blue', 'T10 (10-18)'), # t10 questions
  ]
  for rows, colour, label in groups:
    ax.scatter(pca_attn_outputs[rows, 0], pca_attn_outputs[rows, 1], color=colour, label=label)

  if title != "" :
    ax.set_title(title)

In [None]:
def pca_op_tag(the_digit, operation, strong):
  """Build the PCA minor tag "<answer name>.<operation tag>", with ".Weak" appended when strong is false."""
  answer = qt.answer_name(the_digit)

  if operation == qt.MathsToken.PLUS:
    operation_tag = qt.MathsBehavior.PCA_ADD_TAG.value
  else:
    operation_tag = qt.MathsBehavior.PCA_SUB_TAG.value

  strength_suffix = "" if strong else ".Weak"
  return answer + "." + operation_tag + strength_suffix

In [None]:
def manual_node_pca(ax, position, layer, num, operation, answer_digit):
  """Plot the PCA for one manually chosen attention head and add a strong PCA tag to it.

  Args:
    ax: Matplotlib axes to draw on.
    position, layer, num: Token position, layer and head number of the node.
    operation: Operation token the PCA questions use.
    answer_digit: Answer digit An to analyse.

  If the PCA cannot be calculated, nothing is plotted and no tag is added.
  """
  node_location = qt.NodeLocation(position, layer, True, num)
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)
  if pca is None:
    # calc_pca_for_an has already printed the failure. Plotting the None result
    # would raise, and the node has not earned a PCA tag.
    return

  plot_pca_for_an(ax, pca_attn_outputs, title)

  # Add the strong PCA tag to node PCA:A5.TR
  cfg.add_useful_node_tag( node_location, qt.QType.PCA, pca_op_tag(answer_digit, operation, True) )


def auto_node_pca(ax, index, node_location, operation, answer_digit, perc_threshold):
  """Plot and weakly tag a node when its first principal component explains more than perc_threshold percent.

  Returns True when the node was plotted and tagged. Returns False when the PCA
  could not be calculated or did not exceed the threshold. index is accepted but not used.
  """
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)

  # Guard clauses: no PCA available, or first component not dominant enough
  if pca is None:
    return False
  if pca_evr_0_percent(pca) <= perc_threshold:
    return False

  plot_pca_for_an(ax, pca_attn_outputs, title)

  # Add the weak PCA tag to node PCA:A5.TR.Weak
  cfg.add_useful_node_tag( node_location, qt.QType.PCA, pca_op_tag(answer_digit, operation, False) )
  return True

In [None]:
def manual_nodes_pca(op, nodes):
  """Plot a grid of PCA scatter plots for manually selected attention heads and tag each one.

  Args:
    op: Operation token the PCA questions use.
    nodes: Sequence of [position, layer, head number, answer digit] entries,
      one per subplot.

  The last cell of the grid holds the legend. The figure is saved via
  save_plt_to_file and then shown.
  """
  print("Manual PCA tags for", cfg.model_name, "with operation", qt.token_to_char(cfg, op))

  cols = 4
  rows = 1 + (len(nodes)+1) // cols

  # squeeze=False keeps axs two-dimensional even when rows == 1 (fewer than 3 nodes).
  # Without it plt.subplots returns a 1-D array and the [row, col] indexing below fails.
  fig, axs = plt.subplots(rows, cols, squeeze=False)
  fig.set_figheight(rows*2 + 1)
  fig.set_figwidth(10)

  for index, node in enumerate(nodes):
    manual_node_pca(axs[index // cols, index % cols], node[0], node[1], node[2], op, node[3])

  # Remove any graphs we dont need (except last one)
  for index in range(len(nodes), rows * cols - 1):
    axs[index // cols, index % cols].remove()

  # Replace last graph with the legend (taken from the first subplot)
  lines_labels = [axs[0,0].get_legend_handles_labels()]
  lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
  axs[rows-1, cols-1].legend(lines, labels)
  axs[rows-1, cols-1].axis('off') # Now, to hide the last subplot

  plt.tight_layout()
  save_plt_to_file('Pca Tr')
  plt.show()

# Part 19B: Results: Manual interpretation of PCA results

If an attention head and an answer digit An gives an interpretable response (2 or 3 distinct output clusters) on 3 groups of questions aligned to T8, T9 and T10 definitions, then plot the response and add a QType.PCA tag



In [None]:
# Remove any existing PCA tags so this part can be rerun without stale tags.
# NOTE(review): this passes the enum member, while the later ALGO reset passes qt.QType.ALGO.value - confirm which form reset_node_tags expects.
cfg.useful_nodes.reset_node_tags(qt.QType.PCA)

In [None]:
# Plot all attention heads with the clearest An selected
# Each row below is [token position, layer, head number, answer digit]: manual_nodes_pca
# passes them, in that order, to manual_node_pca. The rows are hand-picked per model, so
# only the table whose name matches cfg.model_name is used.
if use_pca:

  if cfg.model_name == "add_d5_l1_h3_t30K" :
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[ 12, 0, 0, 4 ],
      [ 12, 0, 2, 3 ],
      [ 13, 0, 0, 3 ],
      [ 13, 0, 2, 2 ],
      [ 14, 0, 0, 2 ],
      [ 14, 0, 2, 1 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 2, 0 ],
      [ 16, 0, 0, 0 ]])

  if cfg.model_name == "add_d5_l2_h3_t15K" :
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[10, 0, 0, 2 ],
      [ 12, 0, 0, 3 ],
      [ 12, 1, 0, 3 ],
      [ 12, 1, 1, 4 ],
      [ 12, 1, 2, 4 ],
      [ 13, 0, 0, 0 ],
      [ 13, 0, 0, 3 ],
      [ 13, 1, 2, 2 ],
      [ 14, 0, 0, 0 ],
      [ 14, 0, 0, 2 ],
      [ 14, 1, 2, 2 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 0, 0 ],
      [ 15, 1, 1, 1 ],
      [ 16, 0, 0, 0 ]])

  if cfg.model_name == "add_d6_l2_h3_t15K" :
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[11, 0, 0, 2 ],
      [ 12, 0, 0, 3 ],
      [ 13, 0, 0, 1 ],
      [ 14, 0, 0, 4 ],
      [ 14, 1, 1, 4 ],
      [ 15, 0, 0, 4 ],
      [ 15, 1, 1, 4 ],
      [ 16, 0, 0, 3 ],
      [ 16, 1, 1, 3 ],
      [ 17, 0, 0, 2 ],
      [ 17, 1, 1, 2 ],
      [ 18, 0, 0, 0 ],
      [ 18, 0, 0, 1 ],
      [ 19, 0, 0, 0 ]])

  # NOTE(review): the row [ 18, 0, 2, 0 ] appears three times in the table below (once in
  # sequence and twice at the end), so that node is plotted and tagged repeatedly - confirm
  # whether the last two rows are intentional padding or a copy-paste error.
  if cfg.model_name == "sub_d6_l2_h3_t30K" :
    manual_nodes_pca(qt.MathsToken.MINUS,
      [[14, 0, 1, 0 ],
      [ 15, 0, 0, 5 ],
      [ 15, 1, 1, 0 ],
      [ 15, 1, 1, 1 ],
      [ 15, 1, 1, 2 ],
      [ 15, 1, 1, 3 ],
      [ 15, 1, 1, 4 ],
      [ 15, 1, 2, 0 ],
      [ 15, 1, 2, 1 ],
      [ 15, 1, 2, 2 ],
      [ 15, 1, 2, 3 ],
      [ 16, 0, 0, 0 ],
      [ 16, 0, 0, 1 ],
      [ 16, 0, 0, 2 ],
      [ 16, 0, 0, 3 ],
      [ 16, 0, 0, 4 ],
      [ 16, 0, 0, 5 ],
      [ 16, 1, 0, 4 ],
      [ 16, 1, 1, 0 ],
      [ 16, 1, 1, 1 ],
      [ 16, 1, 1, 2 ],
      [ 16, 1, 1, 3 ],
      [ 16, 1, 1, 4 ],
      [ 16, 1, 2, 0 ],
      [ 16, 1, 2, 1 ],
      [ 16, 1, 2, 2 ],
      [ 16, 1, 2, 3 ],
      [ 16, 1, 2, 4 ],
      [ 16, 1, 2, 5 ],
      [ 17, 0, 0, 0 ],
      [ 17, 0, 0, 1 ],
      [ 17, 0, 0, 2 ],
      [ 17, 0, 0, 3 ],
      [ 17, 1, 0, 3 ],
      [ 17, 1, 0, 4 ],
      [ 17, 1, 2, 0 ],
      [ 17, 1, 2, 4 ],
      [ 18, 0, 0, 0 ],
      [ 18, 0, 0, 1 ],
      [ 18, 0, 0, 2 ],
      [ 18, 0, 2, 0 ],
      [ 18, 1, 2, 3 ],
      [ 19, 0, 0, 0 ],
      [ 19, 1, 2, 2 ],
      [ 20, 0, 0, 0 ],
      [ 18, 0, 2, 0 ],
	    [ 18, 0, 2, 0 ]])

  if cfg.model_name == "mix_d6_l3_h4_t40K" : # TBC
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[ 8, 0, 0, 4 ],
      [  9, 0, 1, 3 ],
      [ 10, 0, 1, 2 ],
      [ 10, 0, 1, 3 ],
      [ 11, 0, 1, 1 ],
      [ 11, 0, 1, 2 ],
      [ 12, 0, 0, 0 ],
      [ 13, 0, 1, 4 ],
      [ 13, 1, 1, 0 ],
      [ 13, 1, 1, 1 ],
      [ 13, 1, 1, 2 ],
      [ 13, 1, 1, 3 ],
      [ 13, 1, 2, 3 ],
      [ 13, 1, 2, 5 ],
      [ 14, 0, 1, 0 ],
      [ 15, 0, 0, 0 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 0, 2 ],
      [ 15, 0, 0, 3 ],
      [ 15, 0, 0, 4 ],
      [ 15, 0, 0, 5 ],
      [ 15, 1, 0, 0 ],
      [ 15, 1, 0, 1 ],
      [ 15, 1, 0, 2 ],
      [ 15, 1, 0, 3 ],
      [ 15, 1, 0, 4 ]])

  if cfg.model_name == "ins1_mix_d6_l3_h4_t40K" :
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[13, 1, 3, 1 ],
      [ 14, 1, 2, 0 ],
      [ 14, 1, 2, 2 ],
      [ 14, 1, 3, 4 ],
      [ 15, 0, 3, 5 ],
      [ 15, 1, 2, 2 ],
      [ 15, 1, 3, 4 ],
      [ 16, 0, 3, 4 ],
      [ 16, 1, 2, 0 ],
      [ 16, 1, 2, 1 ],
      [ 16, 1, 2, 2 ],
      [ 16, 1, 3, 2 ],
      [ 17, 0, 3, 3 ],
      [ 17, 1, 2, 2 ],
      [ 17, 1, 3, 2 ],
      [ 18, 0, 3, 2 ],
      [ 18, 1, 3, 1 ],
      [ 19, 0, 3, 1 ],
      [ 19, 2, 0, 0 ],
      [ 19, 2, 1, 0 ],
      [ 20, 0, 0, 0 ],
      [ 20, 0, 3, 0 ]])

else:
  # use_pca is False when the sklearn PCA import failed in Part 0
  print( "PCA library failed to import. So PCA not done")

# Part 19C: Results: Automatic interpretation of PCA results

Part 19B is manual and selective. This part is automatic. It tests attention heads that have no PCA tag from Part 19B. Where the first (single) principal component explains more than 75% of the variance of the node's output, it adds a QType.PCA "weak" tag.

In [None]:
def auto_find_pca_node(node, op, perc_threshold):
  """Try every answer digit for one node and plot each PCA that passes the threshold.

  Draws on a 2 x 4 grid (up to 8 graphs). Each answer digit from 0 to cfg.n_digits
  is handed to auto_node_pca; a grid cell is only consumed when that call returns
  True. Unused cells are removed before the figure is shown.
  """
  grid_rows, grid_cols = 2, 4 # Allow up to 8 graphs
  fig, axs = plt.subplots(grid_rows, grid_cols)
  fig.set_figheight(4)
  fig.set_figwidth(10)

  used = 0
  for answer_digit in range(cfg.n_digits+1):
    row, col = divmod(used, grid_cols)
    if auto_node_pca(axs[row, col], used, node, op, answer_digit, perc_threshold):
      used += 1

  # Remove any graphs we dont need after all
  for spare in range(used, grid_rows * grid_cols):
    row, col = divmod(spare, grid_cols)
    axs[row, col].remove()

  plt.tight_layout()
  plt.show()

In [None]:
def auto_find_pca(op):
  """Search every useful attention head that has no PCA tag for this operation, adding weak PCA tags.

  Args:
    op: Operation token (qt.MathsToken.PLUS selects the addition tag; anything
      else selects the subtraction tag).

  A node is plotted and tagged by auto_find_pca_node when the first principal
  component explains more than 75% of the variance. MLP nodes are skipped.
  """
  print("Automatic (weak) PCA tags for", cfg.model_name, "with operation", qt.token_to_char(cfg, op))
  perc_threshold = 75

  # The prefix depends only on the operation, so calculate it once, outside the loop.
  # Use .value so it is the same text that pca_op_tag writes into the tag; the major
  # tag on the contains_tag call below is likewise passed as a .value.
  minor_tag_prefix = (qt.MathsBehavior.PCA_ADD_TAG if op == qt.MathsToken.PLUS else qt.MathsBehavior.PCA_SUB_TAG).value

  for node in cfg.useful_nodes.nodes:

    # Exclude nodes with a (manual) PCA tag - for any answer digit(s)). Exclude MLP neurons.
    if node.is_head and not node.contains_tag(qt.QType.PCA.value, minor_tag_prefix):
      print( "Doing PCA on node", node.name())

      auto_find_pca_node(node, op, perc_threshold)


# Run the automatic search for each operation the model handles, provided the PCA library imported
if use_pca:
  if cfg.perc_add() > 0:
    auto_find_pca(qt.MathsToken.PLUS)
  if cfg.perc_sub > 0:
    auto_find_pca(qt.MathsToken.MINUS)

# Part 20A: Results: Show useful nodes and behaviour tags

In [None]:
# Re-sort (nodes have gained PCA tags since the last listing), then list every node with its tags
cfg.useful_nodes.sort_nodes()
cfg.useful_nodes.print_node_tags()

# Part 20B: Results: Save useful nodes and behaviour tags to json file

In [None]:
# Serialize and save the useful nodes list to a temporary CoLab file in JSON format
# The file name was derived from the model configuration in Part 1B.
print( "Saving useful node list with behavior tags:", main_fname_behavior_json)
cfg.useful_nodes.save_nodes(main_fname_behavior_json)

# Part 21 : Set up: Interchange Interventions

Set up a framework to run the model with two questions:
- The first "store" question is run without hooks
- The second "clean" question is run with hooks interjecting some data from the "store" run. This run gives, not a "clean" answer, but an "intervened" answer, which mixes the "store" answer and the "clean" answer.


In [None]:
def run_intervention_core(node_locations, store_question, clean_question, operation, expected_answer_impact, expected_answer_int, strong):
  """Run one interchange intervention and return a text description of the run.

  Args:
    node_locations: Non-empty list of node locations to intervene on.
    store_question: [operand0, operand1] of the question run without hooks.
    clean_question: [operand0, operand1] of the question run with hooks.
    operation: Operation token (qt.MathsToken.PLUS means addition; anything else
      is treated as subtraction).
    expected_answer_impact: Expected impact, passed to acfg.reset_intervention.
    expected_answer_int: Expected intervened answer as an integer.
    strong: Only used to label the description as "Strong" or "Weak".

  Returns:
    A description string naming the nodes intervened on, followed by the
    description produced by qt.a_run_attention_intervention. The results
    themselves are left on acfg for the caller to inspect.
  """
  assert(len(node_locations) > 0)

  # Each operand of both questions must lie strictly between -10^n_digits and +10^n_digits.
  # (Previously only the upper bound of operand 0 and the lower bound of operand 1
  # were checked, each twice.)
  limit = 10 ** cfg.n_digits
  for question in (store_question, clean_question):
    assert(-limit < question[0] < limit)
    assert(-limit < question[1] < limit)

  # Calculate the test (clean) question answer e.g. "+006671"
  clean_answer_int = clean_question[0]+clean_question[1] if operation == qt.MathsToken.PLUS else clean_question[0]-clean_question[1]
  clean_answer_str = qt.int_to_answer_str(cfg, clean_answer_int)
  expected_answer_str = qt.int_to_answer_str(cfg, expected_answer_int)

  # Matrices of tokens. Use the operation parameter: acfg is only given this run's
  # operation by reset_intervention below, so acfg.operation could still hold the
  # value left by an earlier run.
  store_question_and_answer = qt.make_maths_questions_and_answers(cfg, operation, qt.QType.UNKNOWN, qt.MathsBehavior.UNKNOWN, [store_question])
  clean_question_and_answer = qt.make_maths_questions_and_answers(cfg, operation, qt.QType.UNKNOWN, qt.MathsBehavior.UNKNOWN, [clean_question])

  acfg.reset_intervention(expected_answer_str, expected_answer_impact, operation)
  acfg.ablate_node_locations = node_locations

  run_description = qt.a_run_attention_intervention(cfg, store_question_and_answer, clean_question_and_answer, clean_answer_str)

  return "Intervening on " + acfg.node_names() + ", " + ("Strong" if strong else "Weak") + ", Node[0]=" + acfg.ablate_node_locations[0].name() + ", " + run_description


# Run an intervention where we have a precise expectation of the intervention impact
def run_strong_intervention(node_locations, store_question, clean_question, operation, expected_answer_impact, expected_answer_int):
  """Run an intervention whose exact answer and impact are predicted in advance.

  Returns (success, answer_success, impact_success), where success requires both
  the intervened answer and the intervened impact to equal the expected values.
  A failure is printed when acfg.show_test_failures is set.
  """
  # These are the actual model prediction outputs (while applying our node-level intervention).
  description = run_intervention_core(
    node_locations, store_question, clean_question, operation,
    expected_answer_impact, expected_answer_int, True)

  answer_success = acfg.intervened_answer == acfg.expected_answer
  impact_success = acfg.intervened_impact == acfg.expected_impact
  success = answer_success and impact_success

  if acfg.show_test_failures:
    if not success:
      print("Failed: " + description)

  return success, answer_success, impact_success


# Run an intervention where we expect the intervention to have a non-zero impact and we cant precisely predict the answer
def run_weak_intervention(node_locations, store_question, clean_question, operation):
  """Run an intervention that is only expected to change the answer in some way.

  The "expected" answer handed to the core run is the unmodified clean answer, so
  the run succeeds when the intervened answer differs from it and the intervened
  impact is not qt.NO_IMPACT_TAG. Returns that success flag.
  """
  # Calculate the test (clean) question answer e.g. "+006671"
  if operation == qt.MathsToken.PLUS:
    expected_answer_int = clean_question[0]+clean_question[1]
  else:
    expected_answer_int = clean_question[0]-clean_question[1]

  description = run_intervention_core(node_locations, store_question, clean_question, operation, qt.NO_IMPACT_TAG, expected_answer_int, False)

  # Success means the intervention visibly changed both the answer and the impact
  success = (acfg.intervened_answer != acfg.expected_answer) and (acfg.intervened_impact != qt.NO_IMPACT_TAG)

  if acfg.show_test_failures and not success:
    print("Failed: Intervention had no impact on the answer", description)

  return success

In [None]:
def repeat_digit(digit):
    """Return the integer formed by writing digit cfg.n_digits times, e.g. 4 -> 444444 for 6 digits."""
    repeated_text = "".join([str(digit)] * cfg.n_digits)
    return int(repeated_text)


# unit test
# Only checked for 6-digit models, because the expected value is written out for 6 digits.
if cfg.n_digits == 6:
  assert repeat_digit(4) == 444444

In [None]:
def succeed_test(node_locations, alter_digit, strong):
  """Test callback that always passes.

  Prints the name of the first node (and of the second node when there is one),
  followed by "Weak" when `strong` is false. `alter_digit` is accepted but unused.
  """
  first_name = node_locations[0].name()
  second_name = node_locations[1].name() if len(node_locations)>1 else ""
  strength = "" if strong else "Weak"
  print( "Test confirmed", first_name, second_name, strength)
  return True

In [None]:
# Reset the ALGO-type tags on the useful nodes (presumably so the searches that follow start from a clean slate)
cfg.useful_nodes.reset_node_tags(qt.QType.ALGO.value)

# Part 22A : Results: Search for model operations

Here we find which model nodes perform which specific operations. (The resulting "nodes doing operations" data will be used later to test an algorithm hypothesis.)  

**Automatic searches** for node purposes are preferred, as they are applicable to several models, and survive (non-significant, node-reordering) changes to the model after training. When a node purpose is detected, this is documented as a tag on the node.

**Manually written tests** of node purposes, specific to a single model instance are also supported.

# Part 22B: Automated An.US search

The addition Use Sum 9 (US) operation is a simple task. Search for An.US tasks.

In [None]:
def add_us_tag(impact_digit):
  """Return the US tag text: answer name of digit impact_digit-1, a dot, then the ADD_US_TAG value."""
  answer_part = qt.answer_name(impact_digit-1)
  us_suffix = qt.MathsAlgorithm.ADD_US_TAG.value
  return answer_part + "." + us_suffix

In [None]:
# Prerequisites for (not proof of) an Addition UseSum9 node
def add_us_prereqs(position, impact_digit):
  """Build the combined (AND) filter a node must pass to be a candidate addition US node."""
  attended_digit = impact_digit-2
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(attended_digit)), # Attends to Dn-2
    qt.FilterAttention(cfg.ddn_to_position_name(attended_digit)), # Attends to D'n-2
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
  ]
  return qt.FilterAnd(*required)

In [None]:
def add_us_test(node_locations, alter_digit, strong):
  """Strong-intervention test for candidate addition Use Sum 9 (US) nodes.

  Builds a store question that is not a US case and a clean question that is,
  predicts the answer expected under intervention, and runs
  run_strong_intervention with the PLUS operation.

  Returns False (after calling acfg.reset_intervention()) when alter_digit is
  outside 2..cfg.n_digits; otherwise returns the success flag of the intervention.
  `strong` only affects the text of the confirmation message.
  """
  if not (2 <= alter_digit <= cfg.n_digits):
    acfg.reset_intervention()
    return False

  impact_tag = qt.answer_name(alter_digit)
  upper_unit = 10 ** (alter_digit - 1)
  lower_unit = 10 ** (alter_digit - 2)

  # Store question e.g. 25222 + 44444 = 69666. Has no Dn-2.MC but has Dn-1.US so not a US case
  store_question = [repeat_digit(2) + (5-2) * upper_unit, repeat_digit(4)]

  # Clean question e.g. 34633 + 55555 = 90188. Has Dn-2.MC and Dn-1.US so is a US case
  clean_question = [repeat_digit(3) + (4-3) * upper_unit + (6-3) * lower_unit, repeat_digit(5)]

  # Under intervention we expect e.g. 80188
  expected_answer = clean_question[0] + clean_question[1] - 10 ** (alter_digit)

  # Unit test of the arithmetic above, using the worked example from the comments
  if cfg.n_digits == 5 and alter_digit == 4:
    assert store_question[0] == 25222
    assert clean_question[0] == 34633
    assert clean_question[0] + clean_question[1] == 90188
    assert expected_answer == 80188

  outcome = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS, impact_tag, expected_answer)
  success = outcome[0]

  if success:
    strength = "" if strong else "Weak"
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".US impacting "+impact_tag+" accuracy.", strength)

  return success

In [None]:
# Search for addition US nodes: add_us_prereqs supplies the candidate filter, add_us_test the test, add_us_tag the tag text.
# NOTE(review): the meaning of the two trailing False flags is defined by qt.search_and_tag - confirm there.
qt.search_and_tag( cfg, acfg, add_us_prereqs, add_us_test, add_us_tag, False, False)

# Part 22C: Automated An.MC search

The addition Make Carry (MC) operation is a simple task. Search for An.MC tasks.

In [None]:
def add_mc_tag(impact_digit):
  """Return the MC tag text: answer name of digit impact_digit-1, a dot, then the ADD_MC_TAG value."""
  answer_part = qt.answer_name(impact_digit-1)
  mc_suffix = qt.MathsAlgorithm.ADD_MC_TAG.value
  return answer_part + "." + mc_suffix

In [None]:
# Prerequisites for (not proof of) an Addition MakeCarry node
def add_mc_prereqs(position, impact_digit):
  """Build the combined (AND) filter a node must pass to be a candidate addition MC node."""
  carry_source_digit = impact_digit-1 # MC is calculated on the next lower-value digit.
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(carry_source_digit)), # Attends to Dn-1
    qt.FilterAttention(cfg.ddn_to_position_name(carry_source_digit)), # Attends to D'n-1
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
  ]
  return qt.FilterAnd(*required)

In [None]:
def add_mc_test(node_locations, impact_digit, strong):
  """Strong-intervention test for candidate addition Make Carry (MC) nodes.

  The altered digit is impact_digit-1. The store question generates a carry at
  that digit, the clean question does not, so under intervention the clean
  answer is expected to gain one unit in the next higher digit.

  Returns False (after calling acfg.reset_intervention()) when the altered digit
  is outside 0..cfg.n_digits-1; otherwise returns the intervention success flag.
  `strong` only affects the text of the confirmation message.
  """
  alter_digit = impact_digit - 1

  if not (0 <= alter_digit < cfg.n_digits):
    acfg.reset_intervention()
    return False

  impact_tag = qt.answer_name(impact_digit)

  # Store question e.g. 222222 + 666966 = 889188. Has Dn.MC
  store_question = [repeat_digit(2), repeat_digit(6) + (9 - 6) * (10 ** alter_digit)]

  # Clean question 333333 + 555555 = 888888. No Dn.MC
  clean_question = [repeat_digit(3), repeat_digit(5)]

  # Under intervention we expect e.g. 889888
  expected_answer = clean_question[0] + clean_question[1] + 10 ** (alter_digit+1)

  outcome = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS, impact_tag, expected_answer)
  success = outcome[0]

  if success:
    strength = "" if strong else "Weak"
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".MC impacting "+impact_tag+" accuracy.", strength)

  return success

In [None]:
# Search for addition MC nodes: add_mc_prereqs supplies the candidate filter, add_mc_test the test, add_mc_tag the tag text.
# NOTE(review): the meaning of the two trailing False flags is defined by qt.search_and_tag - confirm there.
qt.search_and_tag( cfg, acfg, add_mc_prereqs, add_mc_test, add_mc_tag, False, False)

# Part 22D: Automated An.BA search

The addition Base Add (BA) operation is a simple task. The task may be split/shared over 2 attention heads in the same position. Search for An.BA calculations.

(Note: BS and BA give same result in edge case when D'=0 or D=D'=5. Avoid tests that use these cases)

In [None]:
def add_ba_tag(impact_digit):
  """Return the BA tag text: answer name of impact_digit, a dot, then the ADD_BA_TAG value."""
  answer_part = qt.answer_name(impact_digit)
  ba_suffix = qt.MathsAlgorithm.ADD_BA_TAG.value
  return answer_part + "." + ba_suffix

In [None]:
# Prerequisites for (not proof of) an Addition BaseAdd node
def add_ba_prereqs(position, impact_digit):
  """Build the combined (AND) filter a node must pass to be a candidate addition BA node."""
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(impact_digit)), # Attends to Dn
    qt.FilterAttention(cfg.ddn_to_position_name(impact_digit)), # Attends to D'n
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
    qt.FilterAlgo(add_ba_tag(impact_digit), qt.QCondition.NOT), # Has not already been flagged as a BA task
  ]
  return qt.FilterAnd(*required)

In [None]:
def add_ba_test1(alter_digit):
  """First BA test case: return (store_question, clean_question, intervened_answer).

  Store is 2..2 + 1..1 (= 3..3) and clean is 5..5 + 4..4 (= 9..9); neither has a carry.
  The expected intervened answer is the clean sum with digit alter_digit changed from 9 to 3,
  e.g. 999399.
  """
  store_question = [repeat_digit(2), repeat_digit(1)]
  clean_question = [repeat_digit(5), repeat_digit(4)]

  digit_shift = (3-9) * 10 ** alter_digit
  intervened_answer = clean_question[0] + clean_question[1] + digit_shift

  return store_question, clean_question, intervened_answer


def add_ba_test2(alter_digit):
  """Second BA test case: return (store_question, clean_question, intervened_answer).

  Store is 2..2 + 6..6 (= 8..8) and clean is 5..5 + 1..1 (= 6..6); neither has a carry.
  The expected intervened answer is the clean sum with digit alter_digit changed from 6 to 8,
  e.g. 666866.
  """
  store_question = [repeat_digit(2), repeat_digit(6)]
  clean_question = [repeat_digit(5), repeat_digit(1)]

  digit_shift = (8-6) * 10 ** alter_digit
  intervened_answer = clean_question[0] + clean_question[1] + digit_shift

  return store_question, clean_question, intervened_answer


def add_ba_test(node_locations, alter_digit, strong):
  """Intervention test for candidate addition Base Add (BA) nodes.

  Runs the two cases from add_ba_test1 and add_ba_test2 through
  run_strong_intervention with the PLUS operation. When `strong` is true both
  runs must fully succeed; otherwise only the impact checks of both runs must
  succeed. Returns that combined result.
  """
  impact_tag = qt.answer_name(alter_digit)

  outcomes = []
  for build_case in (add_ba_test1, add_ba_test2):
    store_question, clean_question, expected_answer = build_case(alter_digit)
    outcomes.append(run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS, impact_tag, expected_answer))

  (success1, _, impact_success1), (success2, _, impact_success2) = outcomes

  if strong:
    success = success1 and success2
  else:
    success = impact_success1 and impact_success2

  if success:
    digit = str(alter_digit)
    # acfg.intervened_answer here is whatever the second run left on acfg
    print( "Test confirmed:", acfg.node_names(), "perform D"+digit+".BA = (D"+digit+" + D'"+digit+") % 10 impacting "+impact_tag+" accuracy.", "" if strong else "Weak", acfg.intervened_answer)

  return success

In [None]:
# Optional debugging aids, left disabled:
#cfg.useful_nodes.reset_node_tags(QType.ALGO)
#acfg.show_test_failures = True
# Search for addition BA nodes: add_ba_prereqs supplies the candidate filter, add_ba_test the test, add_ba_tag the tag text.
# NOTE(review): both trailing flags are True here (the BA task may be split over 2 heads) - confirm their meaning in qt.search_and_tag.
qt.search_and_tag( cfg, acfg, add_ba_prereqs, add_ba_test, add_ba_tag, True, True )

# Part 22E: Automated Dn.C search

Search for D0.C to D5.C with impact "A65432" to "A65" in early tokens.

A0 and A1 are too simple to need Dn.C values so they are excluded from the answer impact.

In [None]:
def add_tc_tag(focus_digit):
  """Return the Dn.C tag text: "D", the digit number, a dot, then the ADD_TC_TAG value."""
  digit_part = "D" + str(focus_digit)
  tc_suffix = qt.MathsAlgorithm.ADD_TC_TAG.value
  return digit_part + "." + tc_suffix

In [None]:
# Prerequisites for (not proof of) an Addition Dn.C node
def add_tc_prereqs(position, focus_digit):
  """Build the combined (AND) filter a node must pass to be a candidate addition Dn.C node."""
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(focus_digit)), # Attends to Dn
    qt.FilterAttention(cfg.ddn_to_position_name(focus_digit)), # Attends to D'n
    # Node PCA is interpretable (bigram or trigram output) with respect to T8,T9,T10
    qt.FilterPCA(qt.MathsBehavior.PCA_ADD_TAG.value, qt.QCondition.CONTAINS),
  ]
  return qt.FilterAnd(*required)

In [None]:
def add_tc_test(node_locations, focus_digit, strong):
  """Weak-intervention test for candidate addition Dn.C (TriCase) nodes.

  The store question carries at focus_digit, the clean question has no carry.
  Returns the result of run_weak_intervention with the PLUS operation.
  `strong` only affects the text of the confirmation message.
  """
  # Store question e.g. 222222 + 777977 = 1000199. Has Dn.MC
  store_question = [repeat_digit(2), repeat_digit(7) + (9 - 7) * (10 ** focus_digit)]

  # Clean question 333333 + 666666 = 999999. No Dn.MC
  clean_question = [repeat_digit(3), repeat_digit(6)]

  success = run_weak_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS)

  if success:
    digit = str(focus_digit)
    description = acfg.node_names() + " perform D"+digit+".C = TriCase(D"+digit+" + D'"+digit+")"
    strength = "" if strong else "Weak"
    print("Test confirmed", description, "Impact:", acfg.intervened_impact, strength)

  return success

In [None]:
# Search for addition Dn.C nodes, but only in multi-layer models.
if cfg.n_layers > 1: # Have not seen this task in 1-layer models.
  qt.search_and_tag( cfg, acfg, add_tc_prereqs, add_tc_test, add_tc_tag,
    False, # Have not seen this task split between nodes.
    False,
    cfg.n_digits, cfg.num_question_positions) # These occur from the first D'n digit to the first answer digit.

# Part 22F: Automated An.BS search

The subtraction Base Subtraction (BS) operation is a simple task. The task may be split/shared over 2 attention heads in the same position. Search for An.BS calculations.

(Note: BS and BA give same result in edge case when D'=0 or D=D'=5. Avoid tests that use these cases)

In [None]:
def sub_bs_tag(impact_digit):
  """Return the BS tag text: answer name of impact_digit, a dot, then the SUB_BS_TAG value."""
  answer_part = qt.answer_name(impact_digit)
  bs_suffix = qt.MathsAlgorithm.SUB_BS_TAG.value
  return answer_part + "." + bs_suffix

In [None]:
# Prerequisites for (not proof of) a BaseSubtraction node
def sub_bs_prereqs(position, impact_digit):
  """Build the combined (AND) filter a node must pass to be a candidate subtraction BS node."""
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(impact_digit)), # Attends to Dn
    qt.FilterAttention(cfg.ddn_to_position_name(impact_digit)), # Attends to D'n
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
    qt.FilterAlgo(sub_bs_tag(impact_digit), qt.QCondition.NOT), # Has not already been flagged as a BS task
  ]
  return qt.FilterAnd(*required)

In [None]:
def sub_bs_test1(alter_digit):
  """First BS test case: return (store_question, clean_question, intervened_answer).

  Store is 3..3 - 1..1 (= 2..2) and clean is 9..9 - 4..4 (= 5..5); neither borrows.
  The expected intervened answer is the clean difference with digit alter_digit changed
  from 5 to 2, e.g. 555255.
  """
  store_question = [repeat_digit(3), repeat_digit(1)]
  clean_question = [repeat_digit(9), repeat_digit(4)]

  digit_shift = (2-5) * 10 ** alter_digit
  intervened_answer = clean_question[0] - clean_question[1] + digit_shift

  return store_question, clean_question, intervened_answer


def sub_bs_test2(alter_digit):
  """Second BS test case: return (store_question, clean_question, intervened_answer).

  Store is 6..6 - 2..2 (= 4..4) and clean is 9..9 - 3..3 (= 6..6); neither borrows.
  The expected intervened answer is the clean difference with digit alter_digit changed
  from 6 to 4, e.g. 666466.
  """
  store_question = [repeat_digit(6), repeat_digit(2)]
  clean_question = [repeat_digit(9), repeat_digit(3)]

  digit_shift = (4-6) * 10 ** alter_digit
  intervened_answer = clean_question[0] - clean_question[1] + digit_shift

  return store_question, clean_question, intervened_answer


def sub_bs_test(node_locations, alter_digit, strong):
  """Intervention test for candidate subtraction Base Subtraction (BS) nodes.

  Runs the two cases from sub_bs_test1 and sub_bs_test2 through
  run_strong_intervention with the MINUS operation. When `strong` is true both
  runs must fully succeed; otherwise only the impact checks of both runs must
  succeed. Returns that combined result.
  """
  intervention_impact = qt.answer_name(alter_digit)

  store_question, clean_question, intervened_answer = sub_bs_test1(alter_digit)
  success1, _, impact_success1 = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.MINUS, intervention_impact, intervened_answer)

  store_question, clean_question, intervened_answer = sub_bs_test2(alter_digit)
  success2, _, impact_success2 = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.MINUS, intervention_impact, intervened_answer)

  success = (success1 and success2) if strong else (impact_success1 and impact_success2)

  if success:
    # BS is a subtraction task, so the reported formula uses "-"
    print( "Test confirmed:", acfg.node_names(), "perform D"+str(alter_digit)+".BS = (D"+str(alter_digit)+" - D'"+str(alter_digit)+") % 10 impacting "+intervention_impact+" accuracy.", "" if strong else "Weak")

  return success

In [None]:
# Optional debugging aids, left disabled:
#cfg.useful_nodes.reset_node_tags(QType.ALGO)
#acfg.show_test_failures = True
# Search for subtraction BS nodes: sub_bs_prereqs supplies the candidate filter, sub_bs_test the test, sub_bs_tag the tag text.
# NOTE(review): the final argument is presumably the first position to search (the first answer position) - confirm in qt.search_and_tag.
qt.search_and_tag( cfg, acfg, sub_bs_prereqs, sub_bs_test, sub_bs_tag, True, True, cfg.num_question_positions )

# Part 22G: Automated An.BO search

The subtraction Borrow One operation is a simple task. Search for An.BO tasks.

In [None]:
def sub_bo_tag(impact_digit):
  """Return the BO tag text: answer name of digit impact_digit-1, a dot, then the SUB_BO_TAG value."""
  answer_part = qt.answer_name(impact_digit-1)
  bo_suffix = qt.MathsAlgorithm.SUB_BO_TAG.value
  return answer_part + "." + bo_suffix

In [None]:
# Prerequisites for (not proof of) a subtraction Borrow One node
def sub_bo_prereqs(position, impact_digit):
  """Build the combined (AND) filter a node must pass to be a candidate subtraction BO node."""
  borrow_source_digit = impact_digit-1 # BO is calculated on the next lower-value digit.
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(borrow_source_digit)), # Attends to Dn-1
    qt.FilterAttention(cfg.ddn_to_position_name(borrow_source_digit)), # Attends to D'n-1
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
  ]
  return qt.FilterAnd(*required)

In [None]:
def sub_bo_test(node_locations, impact_digit, strong):
  """Strong-intervention test for candidate subtraction Borrow One (BO) nodes.

  The altered digit is impact_digit-1. The store question borrows at that
  digit, the clean question does not, so under intervention the clean answer
  is expected to lose one unit in the next higher digit.

  Returns False (after calling acfg.reset_intervention()) when the altered digit
  is outside 0..cfg.n_digits-1; otherwise returns the intervention success flag.
  `strong` only affects the text of the confirmation message.
  """
  alter_digit = impact_digit - 1

  if not (0 <= alter_digit < cfg.n_digits):
    acfg.reset_intervention()
    return False

  impact_tag = qt.answer_name(impact_digit)

  # Store question e.g. 222222 - 111311 = +0110911. Has Dn.BO
  store_question = [repeat_digit(2), repeat_digit(1) + (3 - 1) * (10 ** alter_digit)]

  # Clean question 777777 - 444444 = +0333333. No Dn.BO
  clean_question = [repeat_digit(7), repeat_digit(4)]

  # Under intervention we expect e.g. +0332333
  expected_answer = clean_question[0] - clean_question[1] - 10 ** (alter_digit+1)

  outcome = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.MINUS, impact_tag, expected_answer)
  success = outcome[0]

  if success:
    strength = "" if strong else "Weak"
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".BO impacting "+impact_tag+" accuracy.", strength)

  return success

In [None]:
# Optional debugging aids, left disabled:
#cfg.useful_nodes.reset_node_tags(qt.QType.ALGO)
#acfg.show_test_failures = True
# Search for subtraction BO nodes: sub_bo_prereqs supplies the candidate filter, sub_bo_test the test, sub_bo_tag the tag text.
# NOTE(review): the meaning of the two trailing True flags is defined by qt.search_and_tag - confirm there.
qt.search_and_tag( cfg, acfg, sub_bo_prereqs, sub_bo_test, sub_bo_tag, True, True)

# Part 22H: Automated An.NG search

Some useful nodes are only used in subtraction when D < D' in the D - D' question, i.e. negative-answer questions. We claim these nodes are associated somehow with converting D - D' to the mathematically equivalent - ( D' - D )

To calculate D'>D, model needs to calculate D'6>D6, ..., D'0>D0 and then combine the results in a cascade D'6>D6 else ... else D'0>D0. We expect to see nodes only used in NG questions, with PCA bigram (or trigram) outputs, attending to these input pairs. The "else cascade" is tested later.

In [None]:
def sub_ng_tag(impact_digit):
  """Return the NG tag text: answer name of impact_digit, a dot, then the SUB_NG_TAG value."""
  answer_part = qt.answer_name(impact_digit)
  ng_suffix = qt.MathsAlgorithm.SUB_NG_TAG.value
  return answer_part + "." + ng_suffix

In [None]:
def sub_ng_prereqs(position, impact_digit):
  """Build the combined (AND) filter a node must pass to be a candidate subtraction NG node."""
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
    qt.FilterContains(qt.QType.MATH_SUB, qt.MathsBehavior.SUB_NG_TAG.value), # Impacts negative-answer questions
  ]
  return qt.FilterAnd(*required)

In [None]:
# Test that intervening on this node changes the answer of a negative-answer-subtraction question
def sub_ng_test(node_locations, focus_digit, strong):
  """Weak-intervention test for candidate subtraction NG (negative answer) nodes.

  The store question has a positive answer; the clean question has a negative
  answer caused solely by digit focus_digit of the second operand.

  Returns False (after calling acfg.reset_intervention()) when focus_digit is
  at or beyond cfg.n_digits; otherwise returns the result of
  run_weak_intervention with the MINUS operation. `strong` only affects the
  text of the confirmation message.
  """
  if focus_digit >= cfg.n_digits:
    acfg.reset_intervention()
    return False

  # 555555 - 000000 = +0555555. Is a positive-answer-subtraction
  store_question = [repeat_digit(5), repeat_digit(0)]

  # 222222 - 222422 = -0000200. Is a negative-answer-subtraction question because of focus_digit
  clean_question = [repeat_digit(2), repeat_digit(2)]
  clean_question[1] += 2 * (10 ** focus_digit)

  success = run_weak_intervention(node_locations, store_question, clean_question, qt.MathsToken.MINUS)

  if success:
    description = acfg.node_names() + " perform A"+str(focus_digit)+".NG"
    print("Test confirmed", description, "Impact:", acfg.intervened_impact, "" if strong else "Weak")

  return success

In [None]:
# Optional debugging aid, left disabled (and its matching reset after the search):
#acfg.show_test_failures = True
# Search for subtraction NG nodes: sub_ng_prereqs supplies the candidate filter, sub_ng_test the test, sub_ng_tag the tag text.
qt.search_and_tag( cfg, acfg, sub_ng_prereqs, sub_ng_test, sub_ng_tag, False, False )
#acfg.show_test_failures = False

# Part 22I: Automated OP search

For mixed models that do addition and subtraction the operation token "+/-" (in the middle of the question) is key. Find nodes that attend to the question operation.

In [None]:
def mix_op_tag(impact_digit):
  """Return the OP tag text: "An." followed by the MIX_OP_TAG value. `impact_digit` is ignored."""
  op_suffix = qt.MathsAlgorithm.MIX_OP_TAG.value
  return "An." + op_suffix

In [None]:
def mix_op_prereqs(position, impact_digit):
  """Build the combined (AND) filter for heads at `position` that attend to the operation token. `impact_digit` is unused."""
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.op_position_name()), # Attends to the question operation token
  ]
  return qt.FilterAnd(*required)

In [None]:
# Optional debugging aids, left disabled:
#cfg.useful_nodes.reset_node_tags(QType.ALGO)
#acfg.show_test_failures = True
# Tag nodes matching mix_op_prereqs; succeed_test always passes, so the prerequisites alone decide.
qt.search_and_tag( cfg, acfg, mix_op_prereqs, succeed_test, mix_op_tag, False, False )

# Part 22J: Automated SG search

For mixed models that do addition and subtraction, and for our subtraction models, the answer sign token "+/-" (at the start of the answer) is important. Find nodes that attend to the answer sign token.

In [None]:
def mix_sg_tag(impact_digit):
  """Return the SG tag text: "An." followed by the MIX_SG_TAG value. `impact_digit` is ignored."""
  sg_suffix = qt.MathsAlgorithm.MIX_SG_TAG.value
  return "An." + sg_suffix

In [None]:
def mix_sg_prereqs(position, impact_digit):
  """Build the combined (AND) filter for heads at `position` that attend to answer token n_digits+1. `impact_digit` is unused."""
  # Presumably answer token n_digits+1 is the +/- sign at the start of the answer - confirm in cfg.an_to_position_name
  sign_position_name = cfg.an_to_position_name(cfg.n_digits+1)
  required = [
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(sign_position_name),
  ]
  return qt.FilterAnd(*required)

In [None]:
# Optional debugging aids, left disabled:
#cfg.useful_nodes.reset_node_tags(QType.ALGO)
#acfg.show_test_failures = True
# Tag nodes matching mix_sg_prereqs; succeed_test always passes, so the prerequisites alone decide.
qt.search_and_tag( cfg, acfg, mix_sg_prereqs, succeed_test, mix_sg_tag, False, False )

# Part 23A: Show algorithm quanta map

Plot the "algorithm" tags generated in previous steps as a quanta map. This is an automatically generated partial explanation of the model algorithm.

Nodes with multiple tags were tagged (found) by more than one of the above task searches

In [None]:
# Plot the ALGO tags as a quanta map; the return value is used below as the count of nodes with a purpose assigned
num_results = show_quanta_map( "Algorithm Purpose Per Node", False, 2, None, qt.QType.ALGO, "", qt.get_quanta_binary, 9)

# Report how many of the useful nodes were covered
print( num_results, "of", len(cfg.useful_nodes.nodes), "useful nodes have an algorithm purpose assigned.")

# Part 23B: Show known quanta per answer digit

Each of the late positions is solely focused on calculating one answer digit. Show the data we have collected on each late answer digit.  



In [None]:
# Visit each late position: from one past cfg.num_question_positions up to (not including) cfg.n_ctx() - 1
for position in range(cfg.num_question_positions + 1, cfg.n_ctx() - 1):
  print("Position:", position)

  # Calculate a table of the known quanta for the specified position for each late token position
  qt.calc_maths_quanta_for_position_nodes(cfg, position)

  # Save the plot under a title derived from the position name, then display it
  save_plt_to_file("Quanta At "+ qt.position_name(position))

  plt.show()

# Part 24: Save useful nodes with behaviour and algorithm tags to JSON file

Show a list of the nodes that have proved useful in calculations, together with data on the nodes behavior and algorithmic purposes.
Save the data to a Colab temporary JSON file.



In [None]:
# Print the ALGO-type tags of the useful nodes
cfg.useful_nodes.print_node_tags(qt.QType.ALGO.value, "", False)

In [None]:
# Serialize and save the useful nodes list with algorithm tags to a temporary CoLab file in JSON format.
# The target path main_fname_algorithm_json is defined earlier in the notebook.
print( "Saving useful node list with algorithm tags:", main_fname_algorithm_json)
cfg.useful_nodes.save_nodes(main_fname_algorithm_json)

In [None]:
# True when the loaded model is one of the three known "ins*_mix" model names
mixed_model = cfg.model_name in ("ins1_mix_d6_l3_h4_t40K", "ins2_mix_d6_l4_h4_t40K", "ins3_mix_d6_l4_h3_t40K")

# Part 25 : Results: Test Algorithm - Addition

## Part25A : Model add_d5_l1_h3_t30K. Tasks An.BA, An.MC, An.US

This 1-layer model can't do all addition questions. This hypothesis mirrors Paper 1. 14/15 heads have purpose assigned. 0/6 neurons have purpose assigned.

In [None]:
# Every answer digit (excluding Amax) needs an An.BA node before that answer digit is revealed
def test_algo_ba(algo_nodes):
  """Test one algorithm clause per digit 0..n_digits-1: a BA-tagged node positioned by answer token digit+1."""
  for digit in range(cfg.n_digits):
    has_ba_tag = qt.FilterAlgo(add_ba_tag(digit))
    in_time = qt.FilterPosition(cfg.an_to_position_name(digit+1), qt.QCondition.MUST_BY)
    cfg.test_algo_clause(algo_nodes, qt.FilterAnd(has_ba_tag, in_time))

# Every answer digit (excluding Amax and A0) needs an An.MC node before that answer digit is revealed
def test_algo_mc(algo_nodes):
  """Test one algorithm clause per digit 1..n_digits-1: an MC-tagged node positioned by answer token digit+1."""
  for digit in range(1, cfg.n_digits):
    has_mc_tag = qt.FilterAlgo(add_mc_tag(digit))
    in_time = qt.FilterPosition(cfg.an_to_position_name(digit+1), qt.QCondition.MUST_BY)
    cfg.test_algo_clause(algo_nodes, qt.FilterAnd(has_mc_tag, in_time))

# Every answer digit (excluding Amax, A1 and A0) needs an An.US node before that answer digit is revealed
def test_algo_us(algo_nodes):
  """Test one algorithm clause per digit 2..n_digits-1: a US-tagged node positioned by answer token digit+1."""
  for digit in range(2, cfg.n_digits):
    has_us_tag = qt.FilterAlgo(add_us_tag(digit))
    in_time = qt.FilterPosition(cfg.an_to_position_name(digit+1), qt.QCondition.MUST_BY)
    cfg.test_algo_clause(algo_nodes, qt.FilterAnd(has_us_tag, in_time))

In [None]:
# Algorithm hypothesis test for the 1-layer addition model: BA, MC and US clauses
if cfg.model_name == "add_d5_l1_h3_t30K":

  algo_nodes = cfg.start_algorithm_test(acfg)

  test_algo_ba(algo_nodes)
  test_algo_mc(algo_nodes)
  test_algo_us(algo_nodes)
  cfg.print_algo_clause_results()

  cfg.print_algo_purpose_results(algo_nodes)

## Part25B : Models add_d5/d6_l2_h3_t15K. Tasks An.BA, An.MC, An.US, Dn.C

These 2-layer models can do addition accurately. This hypothesis mirrors Paper 2.

In [None]:
# Before Amax is revealed (as a 0 or 1), there must be a Dn.C node for every digit pair
# For each digit (except A0) there must be either an An.US or an Dn.C
def test_algo_tc_or_us(model_nodes):
  """Test the Dn.C / An.US clauses of the addition hypothesis for every digit.

  1-layer models: an An.US node is required for each digit except digit 0.
  Other models: each digit except digit 0 needs an An.US node or a Dn.C node
  by its own answer token, and every digit needs a Dn.C node by answer token
  n_digits+1.
  """
  for digit in range(cfg.n_digits):
    tc_tag = add_tc_tag(digit)
    own_position = qt.QCondition.MUST_BY

    early_dnc = qt.FilterAnd(qt.FilterAlgo(tc_tag), qt.FilterPosition(cfg.an_to_position_name(cfg.n_digits+1), own_position))
    late_dnc = qt.FilterAnd(qt.FilterAlgo(add_tc_tag(digit)), qt.FilterPosition(cfg.an_to_position_name(digit+1), own_position))
    any_dnus = qt.FilterAnd(qt.FilterAlgo(add_us_tag(digit)), qt.FilterPosition(cfg.an_to_position_name(digit+1), own_position))

    needs_carry_info = digit > 0

    if cfg.n_layers == 1:
      # There must be a Dn.US node for every answer digit except A0
      if needs_carry_info:
        cfg.test_algo_clause(model_nodes, any_dnus)
      continue

    # There must be a Dn.US node or a Dn.C node for every digit except A0
    if needs_carry_info:
      cfg.test_algo_clause(model_nodes, qt.FilterOr(any_dnus, late_dnc))

    # There must be a Dn.C node for every digit before the first 1 or 0 digit is calculated
    cfg.test_algo_clause(model_nodes, early_dnc)

In [None]:
# Algorithm hypothesis test for the 2-layer addition models: BA, MC and (Dn.C or US) clauses
if cfg.model_name == "add_d5_l2_h3_t15K" or cfg.model_name == "add_d6_l2_h3_t15K":

  algo_nodes = cfg.start_algorithm_test(acfg)

  test_algo_ba(algo_nodes)
  test_algo_mc(algo_nodes)
  test_algo_tc_or_us(algo_nodes)
  cfg.print_algo_clause_results()

  cfg.print_algo_purpose_results(algo_nodes)

# Part 26: Results: Test Algorithm - Subtraction

## Part 26A : Model sub_d6_l2_h3_t30K. Tasks BS, BO, SZ

This 2-layer model can do subtraction accurately. TBC

In [None]:
# Every answer digit (excluding Amax) needs An.BS and An.BO nodes before that answer digit is revealed
def test_algo_bs_bo(algo_nodes):
  """Test two algorithm clauses per digit 0..n_digits-1: a BS-tagged node, then a BO-tagged node, each by answer token digit+1."""
  for digit in range(cfg.n_digits):
    bs_clause = qt.FilterAnd(qt.FilterAlgo(sub_bs_tag(digit)), qt.FilterPosition(cfg.an_to_position_name(digit+1), qt.QCondition.MUST_BY))
    cfg.test_algo_clause(algo_nodes, bs_clause)

    bo_clause = qt.FilterAnd(qt.FilterAlgo(sub_bo_tag(digit)), qt.FilterPosition(cfg.an_to_position_name(digit+1), qt.QCondition.MUST_BY))
    cfg.test_algo_clause(algo_nodes, bo_clause)


# For answer digits (excluding Amax), An.SZ nodes would be needed before the answer digit is revealed
def test_algo_sz(algo_nodes):
  """Placeholder for the An.SZ clause checks: iterates over the digits but currently tests nothing."""
  for _ in range(cfg.n_digits):
      # TODO: intended clause, disabled (sub_sz_tag is not defined in the visible part of this notebook):
      #cfg.test_algo_clause(algo_nodes, qt.FilterAnd(qt.FilterAlgo(sub_sz_tag(impact_digit)), qt.FilterPosition(cfg.an_to_position_name(impact_digit+1), qt.QCondition.MUST_BY)))
      continue

In [None]:
# Algorithm hypothesis test for the 2-layer subtraction model: BS/BO clauses plus the (currently empty) SZ check
if cfg.model_name == "sub_d6_l2_h3_t30K":

  algo_nodes = cfg.start_algorithm_test(acfg)

  test_algo_bs_bo(algo_nodes)
  test_algo_sz(algo_nodes)
  cfg.print_algo_clause_results()

  cfg.print_algo_purpose_results(algo_nodes)

## Part 26B: Test Algorithm - Subtraction - Negative Answer

We claim the model converts subtraction questions D - D' to the mathematically equivalent - ( D' - D ) when D' > D. To do this the model needs to know when D' > D (or equivalently when D < D').

To calculate D'>D, the model needs to calculate D'6>D6 else D'5>D5 else D'4>D4 else D'3>D3 else D'2>D2 else D'1>D1 else D'0>D0. We expect to see nodes only used in NG questions, with PCA bigram (or trigram) outputs, attending to these input pairs, evaluated in this order.

In [None]:
# For answer digits (excluding Amax), An.NG is needed before the answer digit is revealed
def test_algo_ng(algo_nodes):
  """Test the An.NG clauses and the ordering of the positions they resolve to.

  For each digit the clause result (as returned by cfg.test_algo_clause) is
  recorded; then for each digit above 0 the recorded value must be less than
  the one recorded for the next lower digit.
  """
  # For answer digits (excluding the +/- sign and 0 or 1 Amax), An.NG is calculated before the answer digit is revealed
  # NOTE(review): the position here is an_to_position_name(digit), not digit+1 as in the other test_algo_* helpers - confirm intended.
  ng_locations = {
    digit: cfg.test_algo_clause(algo_nodes, qt.FilterAnd(qt.FilterAlgo(sub_ng_tag(digit)), qt.FilterPosition(cfg.an_to_position_name(digit), qt.QCondition.MUST_BY)))
    for digit in range(cfg.n_digits)
  }

  # Higher digits must be located before lower digits: locations[n-1] < ... < locations[1] < locations[0]
  for digit in range(1, cfg.n_digits):
    in_order = ng_locations[digit] < ng_locations[digit-1]
    cfg.test_algo_logic("NG Ordering for A" + str(digit), in_order)

In [None]:
# Negative-answer (NG) hypothesis test for the 2-layer subtraction model
if cfg.model_name == "sub_d6_l2_h3_t30K":

  algo_nodes = cfg.start_algorithm_test(acfg)

  test_algo_ng(algo_nodes)
  cfg.print_algo_clause_results()

  cfg.print_algo_purpose_results(algo_nodes)

# Part 27: Test Algorithm - Mixed Addition and Subtraction model

What algorithm do mixed models use to perform both addition and subtraction? Our working hypothesis is in https://github.com/PhilipQuirke/verified_transformers/blob/main/mixed_readme.md





## Consider position A3 calculating answer digit A2

The below graph uses the same (behavior and algorithm) data as the quanta maps. Notes:
- Some attention heads are used in both Add and Sub (e.g. to do BA or BS, MC or BO).
- Two attention heads are specific to Add (e.g. P18L1H2, P18L1H3).
- Four attention heads are specific to Sub (e.g. P18L0H3, P18L1H0, P18L1H1, P18L2H0).
  - One attention head (P18L2H0) is only used in NG questions.  
- One attention head (P18L0H3) attends to the Op (+/-) token.
- Four attention heads attend to the = token, which is when the sign (A7 being + or -) is calculated.
- Three nodes attend to the A7 token, which is when the A6 (0 or 1) answer digit is calculated.
    


In [None]:
# Position 18 is hard-coded: it matches the P18 nodes discussed in the notes above
qt.calc_maths_quanta_for_position_nodes(cfg, 18)
plt.show()

## Part 27.H3 Calculate whether D > D' (using NG tasks)

In [None]:
# Only display nodes with the SUB_NG_TAG tag.
filters = qt.FilterContains(qt.QType.MATH_SUB, qt.MathsBehavior.SUB_NG_TAG.value)

# List the names of the useful nodes that pass the filter
print("NG tagged nodes:", qt.filter_nodes( cfg.useful_nodes, filters ).get_node_names())

# Three views restricted to the same filter: subtraction behavior, attention behavior, algorithm purpose
show_quanta_map( "Subtraction Behavior NG Nodes", False, 2, filters, qt.QType.MATH_SUB, "", qt.get_maths_min_complexity)
show_quanta_map( "Attention Behavior Per NG Head", True, 10, filters, qt.QType.ATTN, "", qt.get_quanta_attention, 10, 6)
show_quanta_map( "Algorithm Purpose Per NG Node", False, 2, filters, qt.QType.ALGO, "", qt.get_quanta_binary, 9, 10)

In [None]:
# Negative-answer (NG) hypothesis test, run only for the models flagged by mixed_model
if mixed_model:
  algo_nodes = cfg.start_algorithm_test(acfg)

  test_algo_ng(algo_nodes)
  cfg.print_algo_clause_results()

  cfg.print_algo_purpose_results(algo_nodes)

# Part 30: Unit Test automated searches

In [None]:
def unit_test_node_tag(node_location_as_str, the_tags ):
  """Check that the useful node at the given location carries every expected ALGO tag.

  Asserts that the node exists. A missing tag does not raise: it is reported
  with a printed "Unit test failure" line, one per missing tag.
  """
  location = qt.str_to_node_location(node_location_as_str)
  node = cfg.useful_nodes.get_node(location)
  assert node is not None

  # Lazily yields the expected tags that the node lacks, in the order given
  missing_tags = (tag for tag in the_tags if not node.contains_tag( qt.QType.ALGO, tag))
  for tag in missing_tags:
    print( "Unit test failure: Node", node.name(), "does not have expected tag", tag)

In [None]:
# Show which model the expectations below are checked against
print(cfg.model_name)

# Expected ALGO tags per node for the 2-layer, 6-digit addition model
if cfg.model_name == 'add_d6_l2_h3_t15K':
  unit_test_node_tag('P11L0H0', ['D2.TC'] )
  unit_test_node_tag('P12L0H0', ['D3.TC'] )
  unit_test_node_tag('P14L0H0', ['A5.US', 'D4.TC'] )
  unit_test_node_tag('P14L0H2', ['A5.MC', 'D5.TC'] )
  unit_test_node_tag('P14L1H1', ['OP'] )
  unit_test_node_tag('P15L0H0', ['A4.MC'] )
  unit_test_node_tag('P15L0H1', ['A5.BA'] )
  unit_test_node_tag('P15L0H2', ['A5.BA'] )
  unit_test_node_tag('P16L0H0', ['A3.MC'] )
  unit_test_node_tag('P16L0H1', ['A4.BA'] )
  unit_test_node_tag('P16L0H2', ['A4.BA'] )
  unit_test_node_tag('P17L0H0', ['A2.MC'] )
  unit_test_node_tag('P17L0H1', ['A3.BA'] )
  unit_test_node_tag('P17L0H2', ['A3.BA'] )
  unit_test_node_tag('P18L0H0', ['A1.MC'] )
  unit_test_node_tag('P18L0H1', ['A2.BA'] )
  unit_test_node_tag('P18L0H2', ['A2.BA'] )
  unit_test_node_tag('P19L0H0', ['A0.MC'] )
  unit_test_node_tag('P19L0H1', ['A1.BA'] )
  unit_test_node_tag('P19L0H2', ['A1.BA'] )
  unit_test_node_tag('P20L0H1', ['A0.BA'] )
  unit_test_node_tag('P20L0H2', ['A0.BA'] )

# Expected ALGO tags per node for the 3-layer mixed model.
# NOTE(review): some expectations use a longer form such as 'A4.BA.A4', which the tag functions in this notebook
# do not visibly produce - confirm how node.contains_tag matches these.
if cfg.model_name == 'mix_d6_l3_h4_t40K':
  unit_test_node_tag('P8L0H1', ['OP'] )
  unit_test_node_tag('P13L2H0', ['A7.NG'] )
  unit_test_node_tag('P15L0H0', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P15L0H3', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P16L0H3', ['A4.BA.A4', 'A4.BS.A4'] )
  unit_test_node_tag('P17L0H1', ['A3.NG'] )
  unit_test_node_tag('P17L0H3', ['A3.BA.A3', 'A3.BS.A3'] )
  unit_test_node_tag('P18L0H1', ['A2.NG'] )
  unit_test_node_tag('P18L0H3', ['A2.BA.A2', 'A2.BS.A2'] )
  unit_test_node_tag('P19L0H1', ['A1.NG'] )
  unit_test_node_tag('P19L0H3', ['A1.BA.A1', 'A1.BS.A1'] )
  unit_test_node_tag('P20L0H0', ['A0.BA', 'A0.BS'] )
  unit_test_node_tag('P20L0H3', ['A0.BA', 'A0.BS'] )
  unit_test_node_tag('P20L2H1', ['A0.NG'] )

if cfg.model_name == 'ins1_mix_d6_l3_h4_t40K':
  unit_test_node_tag('P9L0H1', ['D4.TC'] )
  unit_test_node_tag('P9L0H3', ['A5.NG', 'OP'] )
  unit_test_node_tag('P10L0H1', ['D2.TC'] )
  unit_test_node_tag('P12L0H1', ['D3.TC'] )
  unit_test_node_tag('P13L2H0', ['A7.NG'] )
  unit_test_node_tag('P14L0H0', ['A5.US', 'A6.NG'] )
  unit_test_node_tag('P14L0H2', ['A5.MC', 'D5.TC'] )
  unit_test_node_tag('P15L0H0', ['A4.MC'] )
  unit_test_node_tag('P15L0H1', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P15L0H2', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P16L0H0', ['A3.MC'] )
  unit_test_node_tag('P16L0H1', ['A4.BA', 'A4.BS'] )
  unit_test_node_tag('P16L0H2', ['A4.BA', 'A4.BS'] )
  unit_test_node_tag('P16L2H0', ['A4.NG'] )
  unit_test_node_tag('P17L0H0', ['A2.MC'] )
  unit_test_node_tag('P17L0H1', ['A3.BA', 'A3.BS'] )
  unit_test_node_tag('P17L0H2', ['A3.BA', 'A3.BS'] )
  unit_test_node_tag('P17L2H0', ['A3.NG'] )
  unit_test_node_tag('P18L0H0', ['A1.MC'] )
  unit_test_node_tag('P18L0H1', ['A2.BA', 'A2.BS'] )
  unit_test_node_tag('P18L0H2', ['A2.BA', 'A2.BS'] )
  unit_test_node_tag('P18L2H0', ['A2.NG'] )
  unit_test_node_tag('P19L0H0', ['A0.MC'] )
  unit_test_node_tag('P19L0H1', ['A1.BA', 'A1.BS'] )
  unit_test_node_tag('P19L0H2', ['A1.BA', 'A1.BS'] )
  unit_test_node_tag('P19L2H0', ['A1.NG'] )
  unit_test_node_tag('P20L0H0', ['A0.NG'] )
  unit_test_node_tag('P20L0H1', ['A0.BA', 'A0.BS'] )
  unit_test_node_tag('P20L0H2', ['A0.BA', 'A0.BS'] )