# Accurate Integer Mathematics in Transformers - Analyse the Model

This CoLab analyses a Transformer model that performs integer addition, subtraction and multiplication, e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+283935. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens.
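
As a rough illustration of this fixed-width layout, the sketch below formats one question and its answer and counts the tokens. It is not the notebook's tokenizer: `format_question` is a hypothetical helper, it assumes one character per token, and it covers only addition and subtraction. The subtraction check in the usage note reuses the 000041-000047 example from the model notes in Part 1A.

```python
def format_question(x: int, y: int, op: str, n_digits: int = 6) -> tuple[str, str]:
    """Return (question, answer) strings, one character per token (an assumption)."""
    if op not in ("+", "-"):
        raise ValueError("this sketch only covers addition and subtraction")
    result = x + y if op == "+" else x - y
    # Operands are zero-padded to n_digits; the question ends with "=".
    question = f"{x:0{n_digits}d}{op}{y:0{n_digits}d}="
    # The answer is a sign followed by n_digits + 1 zero-padded digits.
    sign = "+" if result >= 0 else "-"
    answer = f"{sign}{abs(result):0{n_digits + 1}d}"
    return question, answer


question, answer = format_question(133357, 182243, "+")
print(question, answer, len(question), len(answer))  # 133357+182243= +0315600 14 8
```

Calling `format_question(41, 47, "-")` gives the answer string `-0000006`, the correct form of the sign-sensitive case noted in Part 1A.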

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries. Don't bother reading.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install matplotlib
    !pip install prettytable

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy --upgrade
    !pip install scikit-learn --upgrade

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
from prettytable import PrettyTable
import itertools
import re

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


In [None]:
# Import Principal Component Analysis (PCA) library
use_pca = True
try:
  from sklearn.decomposition import PCA
except Exception as e:
  print("pca import failed with exception:", e)
  use_pca = False

  # Sometimes version conflicts means the PCA library does not import. This workaround partially fixes the issue
  !pip install --upgrade numpy
  !pip install --upgrade scikit-learn

  # To complete the workaround, select the menu option "Runtime > Restart session and Run all".
  raise RuntimeError("PCA import failed: restart the session and run all cells")

In [None]:
# Refer https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md
!pip install --upgrade git+https://github.com/PhilipQuirke/verified_transformers.git
import QuantaTools as qt

# Part 1A: Configuration

Which existing model do we want to analyse?

The existing model weights created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace. Refer to https://github.com/PhilipQuirke/verified_transformers/blob/main/README.md for more detail.

In [None]:
# Singleton QuantaTool configuration class. MathsConfig is derived from AlgoConfig > UsefulConfig > ModelConfig
cfg = qt.MathsConfig()

# Singleton QuantaTool ablation intervention configuration class
acfg = qt.acfg

In [None]:
# Which model do we want to analyse? Uncomment one line:

#cfg.model_name = "" # Use configuration specified in cfg defaults

# 5-digit and 6-digit Addition models
#cfg.model_name = "add_d5_l1_h3_t15K_s372001"  # Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions.
#cfg.model_name = "add_d5_l2_h3_t15K_s372001"  # AvgFinalLoss=1.6e-08. Accurate on 1M Qs
#cfg.model_name = "add_d6_l2_h3_t15K_s372001"  # AvgFinalLoss=1.7e-08. Accurate on 1M Qs
#cfg.model_name = "add_d6_l2_h3_t20K_s173289"  # AvgFinalLoss=1.5e-08. Accurate on 1M Qs
#cfg.model_name = "add_d6_l2_h3_t20K_s572091"  # AvgFinalLoss=7e-09. Accurate on 1M Qs

# 6-digit Subtraction model
#cfg.model_name = "sub_d6_l2_h3_t30K_s372001"  # AvgFinalLoss=5.8e-06. Fails 1M Qs

# 6-digit Mixed (addition and subtraction) models
#cfg.model_name = "mix_d6_l3_h4_t40K_s372001"  # AvgFinalLoss=5e-09. Fails 1M Qs

#  "ins1" 6-digit Mixed models initialised with 6-digit addition model
cfg.model_name = "ins1_mix_d6_l3_h4_t40K_s372001"  # AvgFinalLoss=8e-09. Accurate on 1M Qs for Add and Sub
#cfg.model_name = "ins1_mix_d6_l3_h4_t40K_s173289"  # AvgFinalLoss=1.6e-08. 936K for Add, 1M Qs for Sub
#cfg.model_name = "ins1_mix_d6_l3_h4_t50K_s572091"  # AvgFinalLoss=2.9e-08. 1M for Add. 300K for Sub. For 000041-000047=-0000006 gives +0000006. Improve training data.
#cfg.model_name = "ins1_mix_d6_l3_h3_t40K_s572091"  #  AvgFinalLoss=1.8e-08. Fails on 1M Qs. For 099111-099111=+0000000 gives -0000000. Improve training data.

# "ins2" 6-digit Mixed model initialised with 6-digit addition model. Reset useful heads every 100 epochs.
#cfg.model_name = "ins2_mix_d6_l4_h4_t40K_s372001"  # AvgFinalLoss=7e-09. Fails 1M Qs

# "ins3" 6-digit Mixed model initialised with 6-digit addition model. Reset useful heads & MLPs every 100 epochs.
#cfg.model_name = "ins3_mix_d6_l4_h3_t40K_s372001"  # AvgFinalLoss=2.6e-06. Fails 1M Qs

# Part 1B: Configuration: Input and Output file names
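
The model names above appear to follow a fixed convention (operation, digits, layers, heads, training steps in thousands, seed), which is what `cfg.parse_model_name()` relies on. The sketch below shows one way such a name could be decoded. The regular expression and `parse_model_name_sketch` are hypothetical, inferred only from names such as "ins1_mix_d6_l3_h4_t40K_s372001"; the real QuantaTools parser may differ.

```python
import re

# Hypothetical pattern inferred from the model names listed in Part 1A.
MODEL_NAME_RE = re.compile(
    r"(?:(?P<insert>ins\d+)_)?(?P<kind>add|sub|mix)"
    r"_d(?P<n_digits>\d+)_l(?P<n_layers>\d+)_h(?P<n_heads>\d+)"
    r"_t(?P<steps_k>\d+)K_s(?P<seed>\d+)"
)


def parse_model_name_sketch(model_name: str) -> dict:
    """Decode a model name into its configuration fields (illustrative only)."""
    match = MODEL_NAME_RE.fullmatch(model_name)
    if match is None:
        raise ValueError(f"unrecognised model name: {model_name!r}")
    return {
        "insert_mode": match["insert"],  # None when not initialised from another model
        "kind": match["kind"],
        "n_digits": int(match["n_digits"]),
        "n_layers": int(match["n_layers"]),
        "n_heads": int(match["n_heads"]),
        "n_training_steps": int(match["steps_k"]) * 1000,  # "t40K" means 40000 steps
        "training_seed": int(match["seed"]),
    }


info = parse_model_name_sketch("ins1_mix_d6_l3_h4_t40K_s372001")
print(info["n_digits"], info["n_layers"], info["n_heads"], info["n_training_steps"])
```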

In [None]:
# Needed when user changes model_name and reruns this Colab a second time
cfg.reset_useful()
cfg.reset_algo()
cfg.initialize_maths_token_positions()
acfg.reset_ablate()

if cfg.model_name != "":
  # Update cfg member data n_digits, n_layers, n_heads, n_training_steps from model_name
  cfg.parse_model_name()

  # Addition model
  cfg.perc_sub = 0
  if cfg.model_name.startswith("sub_") :
    # Subtraction model
    cfg.perc_sub = 100
  elif cfg.model_name.startswith("mix") :
    # Mixed (addition and subtraction) model
    cfg.perc_sub = 66 # Train on 66% subtraction and 34% addition question batches
  elif cfg.model_name.startswith("ins") :
    # Mixed model initialised with an addition model (using insert mode 1, 2 or 3)
    cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches

  # We train multiple versions of some models, inserting different addition models.
  cfg.insert_training_seed = cfg.training_seed

  if cfg.model_name.startswith("ins1_mix_d6_l3") :
    if cfg.training_seed == 372001:
      # Mixed model initialised with add_d6_l2_h3_t15K.pth.
      cfg.insert_n_training_steps = 15000
    else:
      # Mixed model initialised with add_d6_l2_h3_t20K.pth.
      cfg.insert_n_training_steps = 20000

  cfg.batch_size = 512 # Default analysis batch size
  if cfg.n_layers >= 3 and cfg.n_heads >= 4:
    cfg.batch_size = 256 # Reduce batch size to avoid memory constraint issues.

In [None]:
main_fname = cfg.file_config_prefix()
main_fname_pth = main_fname + '.pth'
main_fname_behavior_json = main_fname + '_behavior.json'
main_fname_algorithm_json = main_fname + '_algorithm.json'

def print_config():
  print("%Add=", cfg.perc_add(), "%Sub=", cfg.perc_sub, "%Mult=", cfg.perc_mult, "InsertMode=", cfg.insert_mode, "File=", main_fname)

print_config()
print("weight_decay=", cfg.weight_decay, "lr=", cfg.lr, "batch_size=", cfg.batch_size)
print('Main model will be read from HuggingFace file', main_fname_pth)
print('Main model behavior analysis tags will save to Colab temporary file', main_fname_behavior_json)
print('Main model algorithm analysis tags will save to Colab temporary file', main_fname_algorithm_json)

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

In [None]:
qt.set_maths_vocabulary(cfg)
qt.set_maths_question_meanings(cfg)
print(cfg.token_position_meanings)

# Part 3B: Set Up: Create model

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp(),
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    n_ctx = cfg.n_ctx(),
    init_weights = True,
    device = "cuda",
    seed = cfg.training_seed,
)

cfg.main_model = HookedTransformer(ht_cfg)

optimizer = torch.optim.AdamW(cfg.main_model.parameters(),
                        lr = cfg.lr,
                        weight_decay = cfg.weight_decay,
                        betas = (0.9, 0.98))

max_iter = cfg.n_training_steps
warmup_iter = max_iter // 5
scheduler1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=int(warmup_iter))
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(np.ceil((max_iter-warmup_iter))))
scheduler  = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[int(warmup_iter)])
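
The schedule built above is a linear warm-up over the first fifth of training followed by cosine annealing. The sketch below computes the idealised multiplier applied to `cfg.lr`, using the closed forms documented for `LinearLR` and `CosineAnnealingLR` (with the default minimum rate of 0). `lr_factor` is a hypothetical helper for illustration; it ignores any step-boundary details of `SequentialLR`.

```python
import math


def lr_factor(step: int, max_iter: int, start_factor: float = 0.01) -> float:
    """Idealised learning-rate multiplier: linear warm-up, then cosine decay to 0."""
    warmup_iter = max_iter // 5
    if step < warmup_iter:
        # Linear ramp from start_factor (step 0) towards 1.0 (step warmup_iter).
        return start_factor + (1.0 - start_factor) * step / warmup_iter
    # Cosine annealing over the remaining steps, from 1.0 down to 0.0.
    t = step - warmup_iter
    t_max = max_iter - warmup_iter
    return 0.5 * (1.0 + math.cos(math.pi * t / t_max))


for step in (0, 1500, 3000, 9000, 15000):
    print(step, round(lr_factor(step, max_iter=15000), 4))
```

For a 15K-step model the multiplier starts at 0.01, reaches 1.0 at step 3000, and falls to half its peak midway through the cosine phase.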

# Part 4: Set Up: Loss Function & Data Generator
The maths loss function and data generator are imported from QuantaTools as logits_to_tokens_loss, loss_fn, maths_data_generator_core and maths_data_generator.
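
To make the "iterator invoked using next()" pattern concrete, here is a minimal stand-in for such a generator. It is an assumption-laden sketch, not the QuantaTools implementation: `questions_generator_sketch` is hypothetical, it yields lists of characters rather than token tensors, and it picks one operation per batch using a `perc_sub` percentage.

```python
import random


def questions_generator_sketch(n_digits=6, batch_size=4, perc_sub=80, seed=0):
    """Endlessly yield batches of question+answer character lists (illustrative only)."""
    rng = random.Random(seed)
    while True:
        # One operation per batch: subtraction with probability perc_sub percent.
        op = "-" if rng.randrange(100) < perc_sub else "+"
        batch = []
        for _ in range(batch_size):
            x = rng.randrange(10 ** n_digits)
            y = rng.randrange(10 ** n_digits)
            result = x - y if op == "-" else x + y
            sign = "+" if result >= 0 else "-"
            text = f"{x:0{n_digits}d}{op}{y:0{n_digits}d}={sign}{abs(result):0{n_digits + 1}d}"
            batch.append(list(text))
        yield batch


ds_sketch = questions_generator_sketch()
batch = next(ds_sketch)  # Each call to next() produces a fresh batch
for row in batch[:3]:
    print("".join(row))
```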

In [None]:
# Define "iterator" maths "questions" data generator function. Invoked using next().
ds = qt.maths_data_generator( cfg )

In [None]:
# Test data generator
tokens = next(ds)
print(tokens[:3,:])

# Part 5: Set Up: Load Model from HuggingFace

In [None]:
main_repo_name="PhilipQuirke/Accurate6DigitSubtraction"
print("Loading model from HuggingFace", main_repo_name, main_fname_pth)

cfg.main_model.load_state_dict(utils.download_file_from_hf(repo_name=main_repo_name, file_name=main_fname_pth, force_is_torch=True))
cfg.main_model.eval()

# Part 7A: Set Up: Create sample maths questions


In [None]:
varied_questions = qt.make_maths_test_questions_and_answers(cfg)
num_varied_questions = varied_questions.shape[0]

qt.a_set_ablate_hooks(cfg)
qt.a_calc_mean_values(cfg, varied_questions)

In [None]:
print("Num questions:", num_varied_questions, "Question length:", len(varied_questions[0]))

# Part 7B: Results: Can the model correctly predict sample questions?

Ask the model to predict the varied_questions (without intervention) to see whether it gets them all right. Answers are categorised by complexity.

In [None]:
# Test maths question prediction accuracy on the sample questions provided.
# Does NOT use acfg.* or UsefulInfo.* information
# Used to estimate the accuracy of the model's predictions.
# Returns a reduced set of questions - removing questions that the model failed to answer.
print_config()

acfg.show_test_failures = True
varied_questions = qt.test_maths_questions_by_complexity(cfg, acfg, varied_questions)
acfg.show_test_failures = False

num_varied_questions = varied_questions.shape[0]

# Part 10: Set Up: Which token positions are used by the model?

Here we ablate all heads in each (question and answer) token position (overriding the model memory aka residual stream) and see if the model's prediction loss increases. If loss increases the token position is used by the algorithm. Unused token positions can be excluded from further analysis.

Used to calculate the UsefulInfo.useful_positions. This is (question and answer) **token-position level** information.  
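
The position-level search can be pictured with a toy analogue. In the sketch below the "model" is a plain function and "ablation" overwrites an input value, whereas the notebook overrides activations inside the Transformer; all names here are hypothetical. Because the real test questions have already been filtered to ones the model answers correctly, a changed prediction is treated as a failure.

```python
def find_useful_positions(model, questions, n_positions, ablate):
    """Return (useful_positions, failures_per_position) for a black-box model."""
    useful, failures = [], []
    for position in range(n_positions):
        # Count questions whose prediction changes when this position is ablated.
        num_fails = sum(model(ablate(q, position)) != model(q) for q in questions)
        failures.append(num_fails)
        if num_fails > 0:
            useful.append(position)
    return useful, failures


def toy_model(q):
    """Toy stand-in for the model: it ignores position 1 entirely."""
    return q[0] + q[2]


def toy_ablate(q, position):
    """Toy stand-in for ablation: overwrite one position with a fixed value."""
    return q[:position] + (0,) + q[position + 1:]


toy_questions = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
useful, failures = find_useful_positions(toy_model, toy_questions, 3, toy_ablate)
print(useful, failures)  # [0, 2] [3, 0, 3]
```

Position 1 never changes the prediction, so it would be excluded from further analysis.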

In [None]:
num_failures_list = []

for position in range(cfg.n_ctx()):
  # Test accuracy of model in predicting question answers. Ablates all nodes at acfg.ablate_position
  # Does NOT use UsefulInfo.* information. Used to populate UsefulInfo.useful_positions
  num_fails = qt.test_maths_questions_by_impact(cfg, acfg, varied_questions, position, True)

  # Record the failure count, or "." when ablating this position causes no failures
  num_failures_list.append(num_fails if num_fails > 0 else ".")

  if num_fails > 0:
    cfg.add_useful_position(position)

# Part 11: Results: Which token positions are used by the model?

Which token positions are used in the model's predictions? Unused token positions can be excluded from further analysis.

In [None]:
print_config()
print("num_questions=", num_varied_questions)
print("useful_positions=", cfg.useful_positions )
print()

cfg.calc_position_failures_map(num_failures_list, 16)
qt.save_plt_to_file(cfg=cfg, full_title="Failures When Position Ablated")
plt.show()

# Part 12A: Set Up: Which nodes are used by the model?

Here we ablate each (attention head and MLP neuron) node in each (question and answer) token position to see if the model's prediction loss increases. If loss increases then the "node + token position" is used by the algorithm. Unused node+positions can be excluded from further analysis.

Used to calculate the UsefulInfo.useful_node_location. This is (question and answer) **node+position level** information.

In [None]:
cfg.useful_nodes = qt.UsefulNodeList()

acfg.verbose = False
qt.ablate_mlp_and_add_useful_node_tags(cfg, varied_questions, qt.test_maths_questions_and_add_useful_node_tags)
qt.ablate_head_and_add_useful_node_tags(cfg, varied_questions, qt.test_maths_questions_and_add_useful_node_tags)
qt.add_node_attention_tags(cfg, varied_questions)

cfg.useful_nodes.sort_nodes()

# Part 12B: Results: Which nodes are used by the model?

Here are the (attention head and MLP neuron) nodes in each (question and answer) token position used by the model during predictions.

In [None]:
cfg.useful_nodes.print_node_tags()

# Part 13: Set Up: Show and save Quanta map

Using the UsefulNodes and filtering their tags, show a 2D map of the nodes and the tag minor versions.

In [None]:
def show_quanta_map( title, major_tag : qt.QType, minor_tag : str, get_node_details,
        image_width_inches : int = 9, image_height_inches : int = 6,
        blue_shades : bool = True, cell_num_shades : int = 6,
        filters : qt.FilterNode = None, cell_fontsize : int = 9,
        combine_identical_cells : bool = True, show_perc_circles : bool = False ):

  test_nodes = cfg.useful_nodes
  if filters is not None:
    test_nodes = qt.filter_nodes(test_nodes, filters)

  ax1, quanta_results, num_results = qt.calc_quanta_map(
      cfg, blue_shades, cell_num_shades,
      test_nodes, major_tag.value, minor_tag, get_node_details,
      cell_fontsize, combine_identical_cells, show_perc_circles,
      image_width_inches, image_height_inches )

  if num_results > 0:
    if cfg.graph_file_suffix > "":
      print("Saving quanta map:", title)
      qt.save_plt_to_file(cfg=cfg, full_title=title)
    else:
      ax1.set_title(cfg.file_config_prefix() + ' ' + title + ' ({} nodes)'.format(len(quanta_results)))

    plt.show()

# Part 16A: Results: Show failure percentage map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

A cell containing "< 1" may add some risk to the accuracy of the overall analysis process. Check to see if this represents a new use case. Improve the test data set to contain more instances of this (new or existing) use case.
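
The "< 1" label can be read as a failure rate that is non-zero but rounds below one percent. The sketch below shows one plausible way to derive such a label from raw counts. `fail_perc_label` is hypothetical; how QuantaTools actually formats the cell text is not shown in this notebook.

```python
def fail_perc_label(num_fails: int, num_questions: int) -> str:
    """Format a failure rate as map-cell text (assumed convention, for illustration)."""
    if num_questions <= 0:
        raise ValueError("num_questions must be positive")
    if num_fails == 0:
        return ""  # Node is never needed on this test set
    perc = 100.0 * num_fails / num_questions
    # A rare but non-zero failure rate is flagged rather than rounded down to 0.
    return "< 1" if perc < 1 else str(round(perc))


for fails in (0, 3, 250, 1000):
    print(fails, repr(fail_perc_label(fails, 1000)))
```

A node that fails on only 3 of 1000 questions gets "< 1", which signals a use case that is thinly covered by the test data.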

In [None]:
show_quanta_map( "Failure Frequency Behavior Per Node",
                qt.QType.FAIL, "", qt.get_quanta_fail_perc,
                image_height_inches = 2 * cfg.n_layers,
                cell_num_shades = qt.FAIL_SHADES, combine_identical_cells = False, show_perc_circles = True)

# Part 16B: Results: Show answer impact behavior map

Show the purpose of each useful cell by impact on the answer digits A0 to Amax

In [None]:
show_quanta_map( "Answer Impact Behavior Per Node",
                qt.QType.IMPACT, "", qt.get_quanta_impact,
                image_height_inches = 2 * cfg.n_layers,
                cell_num_shades = cfg.num_answer_positions)

# Part 16C: Results: Show attention map

Show attention quanta of useful heads

In [None]:
# Only maps attention heads, not MLP layers
show_quanta_map( "Attention Behavior Per Head",
                qt.QType.ATTN, "", qt.get_quanta_attention,
                image_height_inches = 3 * cfg.n_layers,
                cell_num_shades = qt.ATTN_SHADES )

# Part 16D: Results: Show question complexity map

Show the "minimum" addition purpose of each useful cell by S0 to S5 quanta.
Show the "minimum" subtraction purpose of each useful cell by M0 to M5 quanta.

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show whether addition (S), positive-answer subtraction (M) and negative-answer subtraction (N) questions rely on the node.
  show_quanta_map( "Maths Operation Coverage",
                  qt.QType.MATH, "", qt.get_maths_operation_complexity,
                  image_height_inches = 1.75 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = 4, combine_identical_cells = False)

In [None]:
if cfg.perc_add() > 0:
  # For each useful cell, show the minimum addition question complexity that relies on the node, as measured using quanta S0, S1, S2, ...
  show_quanta_map( "Addition Min-Complexity",
                  qt.QType.MATH_ADD, qt.MathsBehavior.ADD_COMPLEXITY_PREFIX.value, qt.get_maths_min_complexity,
                  image_height_inches = 1.25 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = qt.MATH_ADD_SHADES)

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show the minimum "positive-answer subtraction" question complexity that relies on the node, as measured using quanta M0, M1, M2, ...
  show_quanta_map( "Positive-answer Subtraction Min-Complexity",
                  qt.QType.MATH_SUB, qt.MathsBehavior.SUB_COMPLEXITY_PREFIX.value, qt.get_maths_min_complexity,
                  image_height_inches = 1.5 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = qt.MATH_SUB_SHADES)

In [None]:
if cfg.perc_sub > 0:
  # For each useful cell, show the minimum "negative-answer subtraction" question complexity that relies on the node, as measured using quanta N0, N1, N2, ...
  show_quanta_map( "Negative-answer Subtraction Min-Complexity",
                  qt.QType.MATH_NEG, qt.MathsBehavior.NEG_COMPLEXITY_PREFIX.value, qt.get_maths_min_complexity,
                  image_height_inches = 1.5 * cfg.n_layers,
                  blue_shades = False, cell_num_shades = qt.MATH_SUB_SHADES)

# Part 18: Set Up: Calc and graph PCA decomposition

In [None]:
# Create a cache of sample maths questions based on the T8, T9, T10 categorisation in cfg.tricase_questions_dict
qt.make_maths_tricase_questions(cfg)

def pca_evr_0_percent(pca):
  return int(round(pca.explained_variance_ratio_[0]*100,0))

In [None]:
# Calculate one Principal Component Analysis
def calc_pca_for_an(node_location, operation, answer_digit):
  assert node_location.is_head == True

  try:
    t_questions = cfg.tricase_questions_dict[(answer_digit, operation)]

    _, the_cache = cfg.main_model.run_with_cache(t_questions)

    # Gather this head's output (without final bias) at the node's position for every question
    attention_cache = the_cache["result", node_location.layer, "attn"]  # Shape [batch, n_ctx, n_head, d_model]
    attn_outputs = attention_cache[:, node_location.position, node_location.num, :].cpu()

    pca = PCA(n_components=6)
    pca.fit(attn_outputs)
    pca_attn_outputs = pca.transform(attn_outputs)

    title = node_location.name() + ', A'+str(answer_digit) + ', EVR[0]=' + str(pca_evr_0_percent(pca)) + '%'

    return pca, pca_attn_outputs, title
  except Exception as e:
    print( "calc_pca_for_an Failed:" + node_location.name() + " " + qt.token_to_char(cfg, operation) + " " + qt.answer_name(answer_digit), e)
    return None, None, None

In [None]:
# Plot the PCA of PnLnHn's attention pattern, using T8, T9, T10 questions that differ in the An digit
def plot_pca_for_an(ax, pca_attn_outputs, title):
  ax.scatter(pca_attn_outputs[:qt.TRICASE_QUESTIONS, 0], pca_attn_outputs[:qt.TRICASE_QUESTIONS, 1], color='red', label='T8 (0-8)') # t8 questions
  ax.scatter(pca_attn_outputs[qt.TRICASE_QUESTIONS:2*qt.TRICASE_QUESTIONS, 0], pca_attn_outputs[qt.TRICASE_QUESTIONS:2*qt.TRICASE_QUESTIONS, 1], color='green', label='T9') # t9 questions
  ax.scatter(pca_attn_outputs[2*qt.TRICASE_QUESTIONS:, 0], pca_attn_outputs[2*qt.TRICASE_QUESTIONS:, 1], color='blue', label='T10 (10-18)') # t10 questions
  if title != "" :
    ax.set_title(title)

In [None]:
def pca_op_tag(the_digit, operation, strong):
  minor_tag_prefix = qt.MathsBehavior.ADD_PCA_TAG if operation == qt.MathsToken.PLUS else qt.MathsBehavior.SUB_PCA_TAG
  return qt.answer_name(the_digit)  + "." + (minor_tag_prefix.value) + ( "" if strong else ".Weak")

In [None]:
def manual_node_pca(ax, position, layer, num, operation, answer_digit):

  node_location = qt.NodeLocation(position, layer, True, num)
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)
  plot_pca_for_an(ax, pca_attn_outputs, title)

  major_tag = qt.QType.MATH_ADD if operation == qt.MathsToken.PLUS else qt.QType.MATH_SUB
  cfg.add_useful_node_tag( node_location, major_tag.value, pca_op_tag(answer_digit, operation, True) )


def auto_node_pca(ax, index, node_location, operation, answer_digit, perc_threshold):

  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)
  if pca is not None:
    perc = pca_evr_0_percent(pca)
    if perc > perc_threshold:
      plot_pca_for_an(ax, pca_attn_outputs, title)

      major_tag = qt.QType.MATH_ADD if operation == qt.MathsToken.PLUS else qt.QType.MATH_SUB
      cfg.add_useful_node_tag( node_location, major_tag.value, pca_op_tag(answer_digit, operation, False) )
      return True

  return False

In [None]:
def manual_nodes_pca(op, nodes):
  print("Manual PCA tags for", cfg.model_name, "with operation", qt.token_to_char(cfg, op))

  cols = 4
  rows = 1 + (len(nodes)+1) // cols

  fig, axs = plt.subplots(rows, cols)
  fig.set_figheight(rows*2 + 1)
  fig.set_figwidth(10)

  index = 0
  for node in nodes:
    manual_node_pca(axs[index // cols, index % cols], node[0], node[1], node[2], op, node[3])
    index += 1

  # Remove any graphs we don't need (except last one)
  while index < rows * cols - 1:
    ax = axs[index // cols, index % cols]
    ax.remove()
    index += 1

  # Replace last graph with the legend
  lines_labels = [axs[0,0].get_legend_handles_labels()]
  lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
  axs[rows-1, cols-1].legend(lines, labels)
  axs[rows-1, cols-1].axis('off') # Now, to hide the last subplot

  plt.tight_layout()
  qt.save_plt_to_file(cfg=cfg, full_title='Pca Tr')
  plt.show()

# Part 19B: Results: Manual interpretation of PCA results

If an attention head gives an interpretable response for an answer digit An (2 or 3 distinct output clusters) on the 3 groups of questions aligned to the T8, T9 and T10 definitions, then plot the response and add a PCA tag.
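
The three question groups can be read off the plot legend in `plot_pca_for_an`: T8 covers digit-pair sums 0 to 8, T9 covers a sum of exactly 9, and T10 covers sums 10 to 18. A T8 pair never produces a carry, a T10 pair always does, and a T9 pair carries only if a carry arrives from the digit below. The sketch below is a minimal classifier under that reading; `tricase` is a hypothetical helper, not a QuantaTools function.

```python
def tricase(d: int, dd: int) -> str:
    """Classify a digit pair (Dn, D'n) by its sum: T8 (0-8), T9 (9) or T10 (10-18)."""
    if not (0 <= d <= 9 and 0 <= dd <= 9):
        raise ValueError("both arguments must be single digits")
    total = d + dd
    if total <= 8:
        return "T8"   # Cannot carry, even with a carry in
    if total == 9:
        return "T9"   # Carries only when a carry comes in from the lower digit
    return "T10"      # Always carries


print(tricase(3, 5), tricase(4, 5), tricase(9, 9))  # T8 T9 T10
```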

In [None]:
# Plot all attention heads with the clearest An selected
if use_pca:

  if cfg.model_name.startswith("add_d5_l2_h3_t15K") :  # Matches the "t15K" naming used in Part 1A
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[ 8, 0, 1, 2],
      [  9, 0, 1, 1],
      [ 10, 0, 1, 0],
      [ 11, 0, 1, 3],
      [ 11, 0, 2, 4],
      [ 12, 0, 1, 3],
      [ 13, 0, 1, 2],
      [ 14, 0, 1, 1],
      [ 15, 0, 1, 0]])

  if cfg.model_name.startswith("add_d6_l2_h3_t15K") :  # Matches the "t15K" naming used in Part 1A
    manual_nodes_pca(qt.MathsToken.PLUS,
      [[ 11, 0, 0, 2],
      [ 12, 0, 0, 3],
      [ 13, 0, 0, 1],
      [ 14, 0, 0, 4],
      [ 14, 1, 1, 4],
      [ 15, 0, 0, 4],
      [ 15, 1, 1, 4],
      [ 15, 1, 2, 4],
      [ 17, 0, 0, 2]])

else:
  print( "PCA library failed to import. So PCA not done")

# Part 19C: Results: Automatic interpretation of PCA results

Part 19B is manual and selective. This part is automatic. It tests nodes not tagged in Part 19B and, where the first (single) principal component explains more than 75% of the variance of a node's output, plots the result and adds a PCA "weak" tag.

In [None]:
def auto_find_pca_node(node, op, perc_threshold):
  fig, axs = plt.subplots(2, 4) # Allow up to 8 graphs
  fig.set_figheight(4)
  fig.set_figwidth(10)

  index = 0
  for answer_digit in range(cfg.n_digits+1):
    ax = axs[index // 4, index % 4]
    if auto_node_pca(ax, index, node, op, answer_digit, perc_threshold):
      index += 1

  # Remove any graphs we don't need after all
  while index < 2 * 4:
    ax = axs[index // 4, index % 4]
    ax.remove()
    index += 1

  plt.tight_layout()
  plt.show()

In [None]:
def auto_find_pca(operation):
  print("Automatic (weak) PCA tags for", cfg.model_name, "with operation", qt.token_to_char(cfg, operation))
  perc_threshold = 75

  for node in cfg.useful_nodes.nodes:

    # Exclude nodes with a (manual) PCA tag for any answer digit(s). Exclude MLP neurons.
    major_tag = qt.QType.MATH_ADD if operation == qt.MathsToken.PLUS else qt.QType.MATH_SUB
    minor_tag_prefix = qt.MathsBehavior.ADD_PCA_TAG if operation == qt.MathsToken.PLUS else qt.MathsBehavior.SUB_PCA_TAG
    if node.is_head and not node.contains_tag(major_tag.value, minor_tag_prefix.value):
      print( "Doing PCA on node", node.name())

      auto_find_pca_node(node, operation, perc_threshold)

if False: # Suppressed for speed
  if use_pca:
    if cfg.perc_add() > 0:
      auto_find_pca(qt.MathsToken.PLUS)
    if cfg.perc_sub > 0:
      auto_find_pca(qt.MathsToken.MINUS)

# Part 20A: Results: Show useful nodes and behaviour tags

In [None]:
cfg.useful_nodes.sort_nodes()
cfg.useful_nodes.print_node_tags()

# Part 20B: Results: Save useful nodes and behaviour tags to json file

In [None]:
# Serialize and save the useful nodes list to a temporary CoLab file in JSON format
print( "Saving useful node list with behavior tags:", main_fname_behavior_json)
cfg.useful_nodes.save_nodes(main_fname_behavior_json)

# Part 21: Set Up: Interchange Interventions

Set up a framework to run the model with two questions:
- The first "store" question is run without hooks
- The second "clean" question is run with hooks injecting some data from the "store" run. This run gives, not a "clean" answer, but an "intervened" answer, which mixes the "store" answer and the "clean" answer.
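
The two-run pattern can be sketched with a toy adder in place of the Transformer. Here the "activations" are simply the per-digit sums, and patching copies one of them from the store run into the clean run. Every function below is hypothetical and only mimics the bookkeeping; the notebook's real interventions patch attention-head outputs via hooks. The question pairs are the ones used by `add_sa_test`.

```python
def digit_sums(x: int, y: int, n_digits: int = 6) -> list[int]:
    """Toy 'activations': the per-digit sums Dn + D'n, least significant first."""
    return [(x // 10 ** n) % 10 + (y // 10 ** n) % 10 for n in range(n_digits)]


def answer_from_sums(sums: list[int]) -> int:
    """Toy 'unembedding': recombine per-digit sums into the final integer."""
    return sum(s * 10 ** n for n, s in enumerate(sums))


def interchange(store, clean, patch_digit: int, n_digits: int = 6) -> int:
    """Run the clean question with one value patched in from the store question."""
    store_cache = digit_sums(*store, n_digits)  # Run 1: "store" question, no hooks
    clean_sums = digit_sums(*clean, n_digits)   # Run 2: "clean" question
    clean_sums[patch_digit] = store_cache[patch_digit]  # The intervention
    return answer_from_sums(clean_sums)


# Store 222222+333333, clean 555555+444444 (=999999), patch digit 2.
print(interchange((222222, 333333), (555555, 444444), patch_digit=2))  # 999599
```

The intervened answer 999599 is neither the store answer (555555) nor the clean answer (999999): only the patched digit changes, matching the expected value `repeat_digit(9) + (5 - 9) * (10 ** alter_digit)` computed in `add_sa_test`.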

In [None]:
def run_intervention_core(node_locations, store_question, clean_question, operation, expected_answer_impact, expected_answer_int, strong):
  assert(len(node_locations) > 0)
  # Every operand of both questions must fit in cfg.n_digits digits
  for question in (store_question, clean_question):
    for operand in question[:2]:
      assert -10 ** cfg.n_digits < operand < 10 ** cfg.n_digits

  # Calculate the test (clean) question answer e.g. "+006671"
  clean_answer_int = clean_question[0]+clean_question[1] if operation == qt.MathsToken.PLUS else clean_question[0]-clean_question[1]
  clean_answer_str = qt.int_to_answer_str(cfg, clean_answer_int)
  expected_answer_str = qt.int_to_answer_str(cfg, expected_answer_int)

  # Matrices of tokens
  # Use the operation passed in: acfg.operation is only updated by reset_intervention below
  store_question_and_answer = qt.make_maths_questions_and_answers(cfg, operation, qt.QType.UNKNOWN, qt.MathsBehavior.UNKNOWN, [store_question])
  clean_question_and_answer = qt.make_maths_questions_and_answers(cfg, operation, qt.QType.UNKNOWN, qt.MathsBehavior.UNKNOWN, [clean_question])

  acfg.reset_intervention(expected_answer_str, expected_answer_impact, operation)
  acfg.ablate_node_locations = node_locations

  run_description = qt.a_run_attention_intervention(cfg, store_question_and_answer, clean_question_and_answer, clean_answer_str)

  return "Intervening on " + acfg.node_names() + ", " + ("Strong" if strong else "Weak") + ", Node[0]=" + acfg.ablate_node_locations[0].name() + ", " + run_description


# Run an intervention where we have a precise expectation of the intervention impact
def run_strong_intervention(node_locations, store_question, clean_question, operation, expected_answer_impact, expected_answer_int):

  # These are the actual model prediction outputs (while applying our node-level intervention).
  description = run_intervention_core(node_locations, store_question, clean_question, operation, expected_answer_impact, expected_answer_int, True)

  answer_success = (acfg.intervened_answer == acfg.expected_answer)
  impact_success = (acfg.intervened_impact == acfg.expected_impact)
  success = answer_success and impact_success

  if acfg.show_test_failures and not success:
    print("Failed: " + description)

  return success, answer_success, impact_success


# Run an intervention where we expect the intervention to have a non-zero impact and we can't precisely predict the answer
def run_weak_intervention(node_locations, store_question, clean_question, operation):

  # Calculate the test (clean) question answer e.g. "+006671"
  expected_answer_int = clean_question[0]+clean_question[1] if operation == qt.MathsToken.PLUS else clean_question[0]-clean_question[1]

  description = run_intervention_core(node_locations, store_question, clean_question, operation, qt.NO_IMPACT_TAG, expected_answer_int, False)

  success = not ((acfg.intervened_answer == acfg.expected_answer) or (acfg.intervened_impact == qt.NO_IMPACT_TAG))

  if acfg.show_test_failures and not success:
    print("Failed: Intervention had no impact on the answer", description)

  return success

In [None]:
def repeat_digit(digit):
    return int(str(digit) * cfg.n_digits)


# unit test
if cfg.n_digits == 6:
  assert repeat_digit(4) == 444444

In [None]:
def succeed_test(node_locations, alter_digit, strong):
  print( "Test confirmed", node_locations[0].name(), node_locations[1].name() if len(node_locations)>1 else "", "" if strong else "Weak")
  return True

In [None]:
cfg.useful_nodes.reset_node_tags(qt.QType.ALGO.value)

## Part 21C: Automated An.SA search

Search for addition "Simple Add" (SA) tasks e.g. 555555+111111=+0666666 where D3 + D'3 < 10

The SA task is sometimes split (shared) across 2 attention heads in the same position and layer.
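
As described by the confirmation message in `add_sa_test`, An.SA is the digit-wise sum (Dn + D'n) % 10, ignoring any carry. A minimal sketch of that arithmetic follows; `digit` and `simple_add` are hypothetical helpers written for illustration, with digit 0 taken as the least significant.

```python
def digit(x: int, n: int) -> int:
    """Return digit n of x, where digit 0 is the least significant."""
    return (x // 10 ** n) % 10


def simple_add(x: int, y: int, n: int) -> int:
    """An.SA: the sum of the two n-th digits modulo 10, ignoring carries."""
    return (digit(x, n) + digit(y, n)) % 10


x, y = 555555, 111111
print([simple_add(x, y, n) for n in range(6)])  # [6, 6, 6, 6, 6, 6]
```

With no carries, as in 555555+111111, SA alone yields every answer digit. When a carry arrives from the lower digit the SA value is off by one: `simple_add(222222, 666966, 3)` is 8, while the true answer digit A3 of 889188 is 9.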

In [None]:
def add_sa_tag(the_digit):
  return qt.answer_name(the_digit) + "." + qt.MathsTask.SA_TAG.value

In [None]:
# These rules are prerequisites for (not proof of) an Addition Simple Add (SA) node
def add_sa_prereqs(cfg, position, impact_digit):
  return qt.FilterAnd(
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(impact_digit)), # Attends to Dn
    qt.FilterAttention(cfg.ddn_to_position_name(impact_digit)), # Attends to D'n
    qt.FilterImpact(qt.answer_name(impact_digit)), # Impacts An
    qt.FilterAlgo(add_sa_tag(impact_digit), qt.QCondition.NOT)) # Has not already been flagged as an SA task

In [None]:
def add_sa_test(cfg, acfg, node_locations, alter_digit, strong):
  intervention_impact = qt.answer_name(alter_digit)

  store_question = [repeat_digit(2), repeat_digit(3)] # 222222 + 333333 = 555555
  clean_question = [repeat_digit(5), repeat_digit(4)] # 555555 + 444444 = 999999
  intervened_answer = repeat_digit(9) + (5 - 9) * (10 ** alter_digit)
  success1, answer_success1, impact_success1 = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS, intervention_impact, intervened_answer)

  store_question = [repeat_digit(2), repeat_digit(1)] # 222222 + 111111 = 333333
  clean_question = [repeat_digit(5), repeat_digit(4)] # 555555 + 444444 = 999999
  intervened_answer = repeat_digit(9) + (3 - 9) * (10 ** alter_digit)
  success2, answer_success2, impact_success2 = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS, intervention_impact, intervened_answer)

  success = (success1 and success2) if strong else (impact_success1 and impact_success2)

  if success:
    print( "Test confirmed:", acfg.node_names(), "perform A"+str(alter_digit)+".SA = (D"+str(alter_digit)+" + D'"+str(alter_digit)+") % 10 impacting "+intervention_impact+" accuracy.", "" if strong else "Weak", acfg.intervened_answer)

  return success

In [None]:
#cfg.useful_nodes.reset_node_tags(QType.ALGO.value)
#acfg.show_test_failures = True
qt.search_and_tag( cfg, acfg, add_sa_prereqs, add_sa_test, add_sa_tag, True, True )

## Part 22C: Automated An.SC search

Search for addition "Make Carry 1" (SC) tasks e.g. 222222+666966=+0889188 where D2 + D'2 >= 10
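A minimal sketch of the carry arithmetic the SC test relies on (hypothetical helpers, not the notebook's library): a digit pair makes a carry when it sums to 10 or more, and patching that carry into a clean question with no carries should add 1 to the next higher answer digit.

```python
# Illustrative sketch only: helper names are hypothetical.
def digit_at(number, index):
    """Return the decimal digit of `number` at position `index` (0 = units)."""
    return (number // 10 ** index) % 10


def make_carry(d, d_prime):
    """SC: 1 if the digit pair alone generates a carry, else 0."""
    return 1 if d + d_prime >= 10 else 0


def expected_sc_answer(clean_question, alter_digit):
    """Clean answer plus a carry of 1 arriving at digit alter_digit + 1.

    Assumes the clean question has no carries of its own.
    """
    return clean_question[0] + clean_question[1] + 10 ** (alter_digit + 1)


store = (222222, 666966)  # D2 + D'2 = 2 + 9 = 11, so digit 2 makes a carry
clean = (333333, 555555)  # every digit pair sums to 8, so no carries

print(make_carry(digit_at(store[0], 2), digit_at(store[1], 2)))  # 1
print(expected_sc_answer(clean, 2))  # 889888
```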

In [None]:
def add_sc_tag(impact_digit):
  return qt.answer_name(impact_digit-1)  + "." + qt.MathsTask.SC_TAG.value

In [None]:
# These rules are prerequisites for (not proof of) an Addition MakeCarry node
def add_sc_prereqs(cfg, position, impact_digit):
  return qt.FilterAnd(
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(impact_digit-1)), # MC is calculated on the next lower-value digit.
    qt.FilterAttention(cfg.ddn_to_position_name(impact_digit-1)), # MC is calculated on the next lower-value digit.
    qt.FilterImpact(qt.answer_name(impact_digit)))

In [None]:
def add_sc_test(cfg, acfg, node_locations, impact_digit, strong):
  alter_digit = impact_digit - 1

  if alter_digit < 0 or alter_digit >= cfg.n_digits:
    acfg.reset_intervention()
    return False

  intervention_impact = qt.answer_name(impact_digit)

  # 222222 + 666966 = 889188. Has Dn.MC
  store_question = [repeat_digit(2), repeat_digit(6)]
  store_question[1] += (9 - 6) * (10 ** alter_digit)

  # 333333 + 555555 = 888888. No Dn.MC
  clean_question = [repeat_digit(3), repeat_digit(5)]

  # When we intervene we expect answer 889888
  intervened_answer = clean_question[0] + clean_question[1] + 10 ** (alter_digit+1)

  success, _, _ = run_strong_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS, intervention_impact, intervened_answer)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform A"+str(alter_digit)+".SC impacting "+intervention_impact+" accuracy.", "" if strong else "Weak")

  return success

In [None]:
qt.search_and_tag( cfg, acfg, add_sc_prereqs, add_sc_test, add_sc_tag)

## Part 22E: Automated Dn.ST search

Search for D0.ST to D5.ST with impact "A65432" to "A65" in early tokens.

A0 and A1 are too simple to need Dn.ST values so they are excluded from the answer impact.
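The sketch below illustrates why a three-way classification of each digit pair is useful for cascading carries. It assumes, without confirmation from this notebook, that the ST (TriCase) value separates sums below 9, sums of exactly 9, and sums of 10 or more. The state labels `ST8`, `ST9`, `ST10` and all function names are hypothetical.

```python
# Illustrative sketch only: the three-state definition and all names are assumptions.
def tri_case(d, d_prime):
    """Classify one digit pair by its effect on the carry chain."""
    total = d + d_prime
    if total >= 10:
        return "ST10"  # always makes a carry
    if total == 9:
        return "ST9"   # makes a carry only if a carry arrives from below
    return "ST8"       # never makes a carry


def digit_states(a, b, n_digits=6):
    """Tri-state for each digit position, least significant digit first."""
    return [tri_case((a // 10 ** i) % 10, (b // 10 ** i) % 10)
            for i in range(n_digits)]


def carry_out(states):
    """Combine per-digit states (least significant first) into the final carry."""
    carry = False
    for state in states:
        if state == "ST10":
            carry = True
        elif state == "ST8":
            carry = False
        # ST9 passes the incoming carry through unchanged
    return carry


states = digit_states(222222, 777977)
print(states)             # ['ST9', 'ST9', 'ST10', 'ST9', 'ST9', 'ST9']
print(carry_out(states))  # True: the carry from digit 2 cascades through the 9s
```

For the clean question 333333 + 666666 every state is `ST9` and nothing starts a carry, so the same function returns `False`.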

In [None]:
def add_st_tag(focus_digit):
  return "D" + str(focus_digit) + "." + qt.MathsTask.ST_TAG.value

In [None]:
# These rules are prerequisites for (not proof of) an Addition Dn.ST node
def add_st_prereqs(cfg, position, focus_digit):
  return qt.FilterAnd(
    qt.FilterHead(),
    qt.FilterPosition(qt.position_name(position)),
    qt.FilterAttention(cfg.dn_to_position_name(focus_digit)), # Attends to Dn
    qt.FilterAttention(cfg.ddn_to_position_name(focus_digit)), # Attends to D'n
    qt.FilterContains(qt.QType.MATH_ADD, qt.MathsBehavior.ADD_PCA_TAG.value)) # Node PCA is interpretable (bigram or trigram output) with respect to T8,T9,T10

In [None]:
def add_st_test(cfg, acfg, node_locations, focus_digit, strong):
  # 222222 + 777977 = 1000199. Has Dn.MC
  store_question = [repeat_digit(2), repeat_digit(7)]
  store_question[1] += (9 - 7) * (10 ** focus_digit)

  # 333333 + 666666 = 999999. No Dn.MC
  clean_question = [repeat_digit(3), repeat_digit(6)]

  success = run_weak_intervention(node_locations, store_question, clean_question, qt.MathsToken.PLUS)

  if success:
    description = acfg.node_names() + " perform D"+str(focus_digit)+".C = TriCase(D"+str(focus_digit)+" + D'"+str(focus_digit)+")"
    print("Test confirmed", description, "Impact:", acfg.intervened_impact, "" if strong else "Weak")

  return success

In [None]:
qt.search_and_tag( cfg, acfg, add_st_prereqs, add_st_test, add_st_tag)

# Part 21D: Show algorithm quanta map

Plot the "algorithm" tags generated in previous steps as a quanta map. This is an automatically generated partial explanation of the model algorithm.

In [None]:
qt.print_algo_purpose_results(cfg)
print()

show_quanta_map( "Maths Purpose Per Node", qt.QType.ALGO, "", qt.get_quanta_binary, cell_num_shades = 2)

# Part 21E: Save useful nodes and tags to JSON file

Show a list of the nodes that have proved useful in calculations.
For each useful node, show its useful facts, stored as tags.
Save the data to a Colab temporary JSON file.



In [None]:
cfg.useful_nodes.print_node_tags(qt.QType.ALGO.value, "", False)

In [None]:
# Serialize and save the useful nodes list with algorithm tags to a temporary CoLab file in JSON format
print( "Saving useful node list with algorithm tags:", main_fname_algorithm_json)
cfg.useful_nodes.save_nodes(main_fname_algorithm_json)

# Part 21F: TBA

In [None]:
print( "Claim that P10.L0.H1 performs D1.C2 = TriAdd(V1.C, TriCase(D0, D0’)) impacting A5, A4, A3 & A2 accuracy")
print()
nodes = [qt.NodeLocation(10, 0, True, 1)]

store_question = [ 11111, 33333] # Sum is 044444. V0 has no MC.
clean_question = [ 44444, 55555] # Sum is 099999. V0 has no MC
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, qt.NO_IMPACT_TAG, 99999 )

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
clean_question = [ 44444, 55555] # Sum is 099999. V0 has no MC
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, "A5432", 100099 )

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
clean_question = [  4444,  5555] # Sum is 009999. V0 has no MC
run_strong_intervention( nodes, store_question, clean_question,qt.MathsToken.PLUS, "A432", 10099 )

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
clean_question = [   444,   555] # Sum is 000999. V0 has no MC
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, "A32", 1099 )

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
clean_question = [    44,    55] # Sum is 000099. V0 has no MC
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, "A2", 199 )

# Deprecated: Confirmed that P10.L0.H1 is: Based on D0 and D0'. Triggers on a V0 carry value. Provides "carry 1" used in A5, A4, A3 & A2 calculation.

In [None]:
print( "Claim that P11.L0.H1 performs D3.C4 = TriAdd(TriCase(D3, D3’),TriAdd(V2.C,V1.C2)) impacting A5 accuracy")
print()
nodes = [qt.NodeLocation(11, 0, True, 1)]

store_question = [44444, 44444] # Sum is 088888. V3 sums to 8 (has no MC).
clean_question = [11111, 11111] # Sum is 022222. V3 has no MC.
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, qt.NO_IMPACT_TAG, 22222 )

store_question = [16111, 13111] # Sum is 029222. V3 sums to 9 (has no MC).
clean_question = [44444, 55555] # Sum is 099999. V3 has no MC
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, qt.NO_IMPACT_TAG, 99999 )

store_question = [16111, 16111] # Sum is 032222. V3 has MC
clean_question = [44444, 55555] # Sum is 099999. V3 has no MC
run_strong_intervention( nodes, store_question, clean_question,qt.MathsToken.PLUS, "A5", 199999 )

# Deprecated: Confirmed that P11.L0.H1 is: Based on D3 and D3'. Triggers on a V3 carry value. Provides "carry 1" used in A5 calculations.

In [None]:
print( "Claim that P11.L0.H2 performs D4.C = TriCase(D4, D4’) impacting A5 accuracy")
print()
nodes = [qt.NodeLocation(11, 0, True, 2)]

store_question = [44444, 55555] # Sum is 099999. V4 has no MC.
clean_question = [11111, 11111] # Sum is 022222. V4 has no MC.
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, qt.NO_IMPACT_TAG, 22222)

store_question = [71111, 71111] # Sum is 142222. V4 has MC
clean_question = [44444, 55555] # Sum is 099999. V4 has no MC
run_strong_intervention( nodes, store_question, clean_question, qt.MathsToken.PLUS, "A5", 199999 )

# Deprecated: Confirmed that P11.L0.H2 is: Based on D4 and D4'. Triggers on a V4 carry value. Provides "carry 1" used in A5 calculation.

# Part 22: MLP Visualisation (incomplete, on-hold)

In [None]:
import einops
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import clear_output


def get_mlp_data(data_set_name):
  cfg.main_model.reset_hooks()
  cfg.main_model.set_use_attn_result(True)
  sample_logits, sample_cache = cfg.main_model.run_with_cache(varied_questions.cuda())
  data_set = sample_cache[data_set_name]
  # print( data_set_name + " shape", data_set.shape) # 239, 22, 2040 = num_varied_questions, cfg.n_ctx, cfg.d_mlp

  raw_data = data_set[:,-3]
  # print( "raw_data shape", raw_data.shape) # 239, 2040 = num_varied_questions, cfg.d_mlp

  answer = einops.rearrange(raw_data, "(x y) d_mlp -> x y d_mlp", x=num_varied_questions).cpu().numpy()
  # print( "answer shape", answer.shape) # 239, 1, 2040 = num_varied_questions, ??, cfg.d_mlp

  return answer


l0_mlp_hook_pre_sq = get_mlp_data('blocks.0.mlp.hook_pre')
l0_mlp_hook_post_sq = get_mlp_data('blocks.0.mlp.hook_post')
l1_mlp_hook_pre_sq = get_mlp_data('blocks.1.mlp.hook_pre') if cfg.n_layers > 1 else l0_mlp_hook_pre_sq
l1_mlp_hook_post_sq = get_mlp_data('blocks.1.mlp.hook_post') if cfg.n_layers > 1 else l0_mlp_hook_post_sq


def plot_mlp_neuron_activation(pos: int):
    clear_output()

    l0_mlp_pre_data = l0_mlp_hook_pre_sq[:,:,pos]
    l0_mlp_post_data = l0_mlp_hook_post_sq[:,:,pos]
    l1_mlp_pre_data = l1_mlp_hook_pre_sq[:,:,pos]
    l1_mlp_post_data = l1_mlp_hook_post_sq[:,:,pos]

    fig, axs = plt.subplots(1, 2, figsize=(8,4))

    # Note: the layer 1 activations are plotted here; the layer 0 slices above are currently unused
    plot = axs[0].imshow(l1_mlp_pre_data, cmap='magma', vmin=0, vmax=1)
    cbar = plt.colorbar(plot, ax=axs[0], fraction=0.1)
    cbar.set_label(r'l1_mlp_pre_data {}'.format(pos))
    #axs[0].set_ylim(-0.5, 99.5)
    #axs[0].set_yticks(range(100), labels=range(100), size=5.5);
    #axs[0].set_xticks(range(100), labels=range(100), size=5.5, rotation='vertical');

    plot = axs[1].imshow(l1_mlp_post_data, cmap='magma', vmin=0, vmax=1)
    cbar = plt.colorbar(plot, ax=axs[1], fraction=0.1)
    cbar.set_label(r'l1_mlp_post_data {}'.format(pos))
    #axs[1].set_ylim(-0.5, 99.5)
    #axs[1].set_yticks(range(100), labels=range(100), size=5.5);
    #axs[1].set_xticks(range(100), labels=range(100), size=5.5, rotation='vertical');


interact(plot_mlp_neuron_activation, pos=widgets.IntText(value=0, description='Index:'))

# Part 25: Results: Can the model do 1 million questions without error?

If the model passes this test, this is evidence (not proof) that the model is fully accurate. Very rare edge cases (say 1 in ten million) may not have appeared in the test questions. Even if you believe you know all the edge cases and have enriched the training data to contain them, you may have missed some.
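To put rough numbers on "evidence, not proof", the sketch below assumes the test questions are independent random draws, which is a simplification of the data generator used here. It computes the largest per-question failure rate still consistent with zero failures, and the chance that a 1 in ten million edge case never appears in the test. The function names are hypothetical.

```python
# Illustrative sketch only: assumes independent, identically distributed test questions.
def max_error_rate(n_trials, confidence=0.95):
    """Largest per-question failure rate consistent with zero failures in n_trials.

    Solves (1 - p) ** n_trials = 1 - confidence for p.
    """
    return 1.0 - (1.0 - confidence) ** (1.0 / n_trials)


def prob_edge_case_missed(error_rate, n_trials):
    """Probability that a failure with the given rate never occurs in n_trials."""
    return (1.0 - error_rate) ** n_trials


print(max_error_rate(1_000_000))               # about 3e-6 (the "rule of three": 3 / n_trials)
print(prob_edge_case_missed(1e-7, 1_000_000))  # about 0.905
```

Under these assumptions, a clean run of one million questions only bounds the failure rate at about 3 in a million, and a 1 in ten million edge case would be missed roughly 90% of the time.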

If the model fails this test:
- Add a few of the failures into the "test questions" in part 6C
- Understand the "use case(s)" driving these failures
- Alter the Training CoLab data_generator_core to enrich the training data with examples of these use case(s) and retrain the model.

Takes ~25 mins to run (successfully) for ins_mix_d6_l3_h4_train40K_seed372001

In [None]:
def one_million_questions_core():
  global ds

  acfg.verbose = False

  cfg.analysis_seed = 345621 # Randomly chosen
  ds = qt.maths_data_generator() # Re-initialise the data generator

  the_successes = 0
  the_fails = 0

  num_batches = 1000000//cfg.batch_size
  for epoch in range(num_batches):
      tokens = next(ds)

      the_fails = qt.test_maths_questions_by_impact(cfg, acfg, tokens, 0, False)

      if the_fails> 0:
        break

      the_successes = the_successes + cfg.batch_size

      if epoch % 100 == 0:
          print("Batch", epoch, "of", num_batches, "#Successes=", the_successes)

  print("successes", the_successes, "num_fails", the_fails)
  if the_fails > 0:
    print("WARNING: Model is not fully accurate. It failed the 1M Q test")

In [None]:
def one_million_questions():
  store_perc_sub = cfg.perc_sub
  store_perc_mult = cfg.perc_mult

  print_config()
  print()

  if cfg.perc_add() > 0:
    print("Addition:")
    cfg.perc_sub = 0
    cfg.perc_mult = 0
    one_million_questions_core()

  if store_perc_sub > 0:
    print("Subtraction:")
    cfg.perc_sub = 100
    cfg.perc_mult = 0
    one_million_questions_core()
    print()

  cfg.perc_sub = store_perc_sub
  cfg.perc_mult = store_perc_mult


# Takes ~25 minutes to run
# one_million_questions()