# Accurate Integer Mathematics in Transformers - Analyse the Model

This Colab analyses a Transformer model that performs integer addition, subtraction and multiplication e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+283935. Each digit is a separate token. For 6 digit questions, the model is given 14 "question" (input) tokens, and must then predict the corresponding 8 "answer" (output) tokens.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0A: Import libraries
Imports standard libraries. Don't bother reading.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install matplotlib

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy
    !pip install scikit-learn

except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Colab needs its own renderer; local VSCode / Jupyter development uses notebook_connected
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge axis titles (20pt) and the graph title (30pt) in the default plotly template
_plotly_layout = pio.templates['plotly'].layout
for _plotly_axis in (_plotly_layout.xaxis, _plotly_layout.yaxis):
    _plotly_axis.title.font.size = 20
_plotly_layout.title.font.size = 30

In [None]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
import itertools
import re
from enum import Enum

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Part 0B: Import PCA library

In [None]:
# Import Principal Component Analysis (PCA) library
# use_pca records whether sklearn's PCA imported successfully.
use_pca = True
try:
  from sklearn.decomposition import PCA
except Exception as e:
  print("pca import failed with exception:", e)
  use_pca = False

  # Sometimes version conflicts means the PCA library does not import. This workaround partially fixes the issue
  !pip install --upgrade numpy
  !pip install --upgrade scikit-learn

  # To complete workaround, now select menu option "Runtime > Restart session and run all".
  # `stop` is not defined in this notebook, so this line raises NameError - presumably
  # deliberate, to halt "run all" here until the session is restarted.
  stop

# Part 0C: Import verified_transformers library

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/verified_transformers.git
from QuantaTools import ModelConfig, token_to_char, tokens_to_string

from QuantaTools import UsefulConfig, position_name, position_name_to_int, row_location_name, location_name, NodeLocation, UsefulNode, UsefulNodeList, str_to_node_location, answer_name

from QuantaTools import QuantaFilter, QuantaType, MAX_ATTENTION_TAGS, MIN_ATTENTION_PERC, NO_IMPACT_TAG
from QuantaTools import FilterNode, FilterAnd, FilterOr, FilterHead, FilterNeuron, FilterContains, FilterPosition, FilterAttention, FilterImpact, FilterPCA, FilterAlgo, filter_nodes
from QuantaTools import get_answer_impact, get_question_answer_impact, compact_answer_if_sequential, get_quanta_impact
from QuantaTools import calc_quanta_map, get_quanta_fail_perc, get_quanta_attention, get_quanta_binary

from QuantaTools import MathsConfig, MathsTokens, MathsTag, AlgoTag, set_maths_vocabulary
from QuantaTools import int_to_answer_str, tokens_to_unsigned_int, tokens_to_answer, insert_question_number, make_a_maths_question, sort_unique_digits
from QuantaTools import get_maths_question_complexity, maths_data_generator, maths_data_generator_core, make_maths_questions, make_maths_test_questions

# Part 1A: Configuration: Detailed

In [None]:
# Main configuration class for main model creation and training.
# Inheritance chain: ModelConfig > UsefulConfig > MathsConfig > ColabConfig
class ColabConfig(MathsConfig):
  """MathsConfig plus a handle to the HookedTransformer under analysis."""

  def __init__(self):
    super().__init__()
    # Created in Part 3B and loaded from HuggingFace in Part 5
    self.main_model = None


# The one configuration instance shared by every cell in this notebook
cfg = ColabConfig()

# Part 1B: Configuration: Summary

Which existing model do we want to analyse?

The existing model weightings created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace.

In [None]:
# Which existing model do we want to analyse?
# Uncomment exactly one line. The name appears to encode the task (add / sub / mix),
# d<digits>, l<layers>, h<heads> and t<training steps>; cfg.parse_model_name() in
# Part 1C reads n_digits, n_layers, n_heads, n_training_steps from it.
# cfg.model_name = "" # Use configuration specified in Part 1A
# cfg.model_name = "add_d5_l1_h3_t30K"  # 5 digit addition model. Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions
# cfg.model_name = "add_d5_l2_h3_t15K"  # 5 digit addition model
# cfg.model_name = "add_d6_l2_h3_t15K"  # 6 digit addition model
# cfg.model_name = "sub_d6_l2_h3_t30K"  # 6 digit subtraction model
# cfg.model_name = "mix_d6_l3_h4_t40K"  # 6 digit addition and subtraction model. AvgFinalLoss=8e-09
cfg.model_name = "ins1_mix_d6_l3_h4_t40K"  # 6 digit addition / subtraction model. Initialise with addition model. Handles 1m Qs for Add and Sub
# cfg.model_name = "ins2_mix_d6_l4_h4_t40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads every 100 epochs. AvgFinalLoss=7e-09. Fails 1m Qs
# cfg.model_name = "ins3_mix_d6_l4_h3_t40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads & MLPs every 100 epochs. AvgFinalLoss=2.6e-06. Fails 1m Qs

# Part 1C: Configuration: Input and Output file names



In [None]:
# Training settings of the mixed addition/subtraction models: model_name -> (batch_size, perc_sub).
# The ins* models were initialised with add_d6_l2_h3_t15K.pth; ins2 also reset useful heads
# every 100 epochs, and ins3 reset useful heads & MLPs every 100 epochs.
_MIX_MODEL_SETTINGS = {
  "mix_d6_l3_h4_t40K": (256, 66),       # 66% subtraction / 33% addition question batches
  "ins1_mix_d6_l3_h4_t40K": (256, 80),  # 80% subtraction / 20% addition question batches
  "ins2_mix_d6_l4_h4_t40K": (256, 80),  # 80% subtraction / 20% addition question batches
  "ins3_mix_d6_l4_h3_t40K": (256, 80),  # 80% subtraction / 20% addition question batches
}

if cfg.model_name != "":
  # Update cfg member data n_digits, n_layers, n_heads, n_training_steps from model_name
  cfg.parse_model_name()

  if cfg.model_name.startswith("sub_"):
    cfg.perc_sub = 100  # Subtraction-only model
  elif cfg.model_name in _MIX_MODEL_SETTINGS:
    cfg.batch_size, cfg.perc_sub = _MIX_MODEL_SETTINGS[cfg.model_name]

In [None]:
# File names derived from the model configuration
main_fname = cfg.file_config_prefix()
main_fname_pth = main_fname + '.pth'
main_fname_behavior_json = main_fname + '_behavior.json'
main_fname_algorithm_json = main_fname + '_algorithm.json'

def print_config():
  """Print the question mix (% multiplication / subtraction / addition) and the model file prefix."""
  print("%Mult=", cfg.perc_mult, "%Sub=", cfg.perc_sub, "%Add=", cfg.perc_add(), "File", main_fname)

print_config()
# The weights are downloaded from HuggingFace in Part 5
print('Main model will be read from HuggingFace file', main_fname_pth)
print('Main model behavior analysis tags will save to Colab temporary file', main_fname_behavior_json)
print('Main model algorithm analysis tags will save to Colab temporary file', main_fname_algorithm_json)

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

Convert from
- Human-readable character to numeric token index.
- Convert numeric token positions to position "meanings"
- Convert from number to human-readable string and vice versa

In [None]:
# Vocabulary dictionary: Mapping from character (key) to token (value)
# set_maths_vocabulary (QuantaTools) presumably stores this vocabulary on cfg - confirm in the library.
set_maths_vocabulary(cfg)

In [None]:
# Question layout (6 digits): D5..D0 at P0..P5, operator at P6, D'5..D'0 at P7..P12,
# "=" at P13, then the answer tokens, ending with A0 in the last position (P21).
def dn_to_position_name(n):
  """Position name of first-operand digit Dn (e.g. D0 -> P5, D1 -> P4 for 6 digit questions)."""
  return position_name(cfg.n_digits - 1 - n)


def ddn_to_position_name(n):
  """Position name of second-operand digit D'n (e.g. D'0 -> P12, D'1 -> P11 for 6 digit questions)."""
  return position_name(2 * cfg.n_digits - n)


def an_to_position_name(n):
  """Position name of answer digit An, counted back from the last position (e.g. A0 -> P21, A1 -> P20 for 6 digits)."""
  # TODO(PQR): add dependency on ascending and move to library
  return position_name(cfg.n_ctx() - 1 - n)


def op_position_name():
  """Position name of the operator token (+, -, * or /), which follows the first operand."""
  return position_name(cfg.n_digits)


def set_question_meanings():
  """Set cfg.token_position_meanings to: D5, .., D0, Op, D'5, .., D'0, =, then the answer meanings.

  The question part is rebuilt here; the trailing cfg.num_answer_positions entries
  (the answer meanings) are kept from the existing cfg.token_position_meanings.
  """
  digit_suffixes = [str(cfg.n_digits - i - 1) for i in range(cfg.n_digits)]
  first_operand = ["D" + suffix for suffix in digit_suffixes]
  second_operand = ["D'" + suffix for suffix in digit_suffixes]
  q_meanings = first_operand + ["Op"] + second_operand + ["="]  # "Op" stands in for +, - or *

  cfg.token_position_meanings = q_meanings + cfg.token_position_meanings[-cfg.num_answer_positions:]
  print(cfg.token_position_meanings)


set_question_meanings()

# Part 3B: Set Up: Create model

In [None]:
# Transformer creation
#
# Field documentation: https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
# The device is hard-coded to "cuda"; the rest of this notebook also calls .cuda() directly.
_ht_cfg_kwargs = dict(
    n_layers=cfg.n_layers,
    n_heads=cfg.n_heads,
    d_model=cfg.d_model,
    d_head=cfg.d_head,
    d_mlp=cfg.d_mlp(),
    act_fn=cfg.act_fn,
    normalization_type='LN',
    d_vocab=cfg.d_vocab,
    d_vocab_out=cfg.d_vocab,   # Output vocabulary equals input vocabulary
    n_ctx=cfg.n_ctx(),
    init_weights=True,
    device="cuda",
    seed=cfg.training_seed,
)
ht_cfg = HookedTransformerConfig(**_ht_cfg_kwargs)

cfg.main_model = HookedTransformer(ht_cfg)

# Part 4: Set Up: Loss Function & Data Generator
This section defines the loss function and the training/testing data generator.

In [None]:
# Loss functions

# Per-token log-probability of each correct answer token, from a batch of prediction "logits"
def logits_to_tokens_loss(logits, tokens, n_answer_digits=None):
  """Return per-answer-token log-probabilities and the model's predicted answer tokens.

  Args:
    logits: model output, indexed [batch, position, vocab].
    tokens: question + answer tokens, indexed [batch, position].
    n_answer_digits: number of answer tokens at the end of each sequence. Defaults to
      cfg.answer_tokens(). (An addition answer can have one more digit than the
      question, and the answer also has a +/- sign token.)

  Returns:
    ans_loss: [batch, n_answer_digits] log-probability assigned to each correct answer
      token (0 is perfect; more negative is worse).
    max_prob_tokens: [batch, n_answer_digits] most likely token at each answer position.
  """
  if n_answer_digits is None:
    n_answer_digits = cfg.answer_tokens()

  # Logits at position p predict token p+1, so the answer tokens are predicted by the
  # n_answer_digits positions ending one before the final position.
  ans_logits = logits[:, -(n_answer_digits+1):-1]

  # Convert raw scores (logits) into log-probabilities. float64 presumably for precision,
  # as final losses are tiny for accurate models (e.g. AvgFinalLoss=8e-09 in Part 1B).
  ans_probs = F.log_softmax(ans_logits.to(torch.float64), dim=-1)

  max_prob_tokens = torch.argmax(ans_probs, dim=-1)

  # The answer tokens themselves
  ans_tokens = tokens[:, -n_answer_digits:]

  # Pick out the log-probability the model assigned to each correct answer token
  ans_loss = torch.gather(ans_probs, -1, ans_tokens[:, :, None])[..., 0]

  return ans_loss, max_prob_tokens


# Loss = negative mean log-probability (over dim 0) of the correct answer tokens
def loss_fn(ans_loss):
  """Return -ans_loss.mean(0), i.e. the negative mean log-probability along dim 0."""
  return -ans_loss.mean(0)

In [None]:
# Define "iterator" maths "questions" data generator function. Invoked using next().
# Each next(ds) yields a batch of question+answer token rows (shown in the next cell).
ds = maths_data_generator( cfg )

In [None]:
# Generate sample data generator (unit test): print the first 3 rows of one batch
print(next(ds)[:3,:])

# Part 5: Set Up: Load Model from HuggingFace

In [None]:
main_repo_name="PhilipQuirke/VerifiedArithmetic"
print("Loading model from HuggingFace", main_repo_name, main_fname_pth)

# Download the trained weights, load them into the model, and switch to inference mode
_main_state_dict = utils.download_file_from_hf(repo_name=main_repo_name, file_name=main_fname_pth, force_is_torch=True)
cfg.main_model.load_state_dict(_main_state_dict)
cfg.main_model.eval()

# Part 7A : Set up: Prediction Framework


In [None]:
class T_Config(NodeLocation):
  """Test/analysis settings shared by the prediction and ablation code.

  It is also a NodeLocation: the ablation loops store the node currently being
  tested (position, layer, num, is_head) on the instance.
  """

  # Number of layers for which hook point names are pre-computed in reset()
  max_hook_layers = 4


  def __init__(self):
    # NOTE(review): argument meaning/order is defined by QuantaTools.NodeLocation
    super().__init__(0, 0, True, 0)
    self.reset()


  def reset(self):
    """Restore thresholds, counters, hook point names and mean-ablation samples to defaults."""
    self.threshold = 0.01
    self.verbose = False

    # How many test questions are in the manually-curated varied_questions test set?
    self.num_varied_questions = 0
    # How many of the manually-curated varied_questions can the model answer?
    self.num_varied_successes = 0

    layers = range(self.max_hook_layers)
    # attn.hook_z is the "attention head output" hook point name per layer: 'blocks.0.attn.hook_z' etc
    self.l_attn_hook_z_name = [utils.get_act_name('z', layer, 'a') for layer in layers]
    # hook_resid_pre is the "pre residual memory update" hook point name per layer
    self.l_hook_resid_pre_name = [f'blocks.{layer}.hook_resid_pre' for layer in layers]
    # hook_resid_post is the "post residual memory update" hook point name per layer
    self.l_hook_resid_post_name = [f'blocks.{layer}.hook_resid_post' for layer in layers]
    # mlp.hook_post is the "MLP layer" hook point name per layer: 'blocks.0.mlp.hook_post' etc
    self.l_mlp_hook_post_name = [utils.get_act_name('post', layer) for layer in layers]

    # Sample model outputs used in ablation interventions
    self.mean_attn_z = []
    self.mean_resid_post = []
    self.mean_mlp_hook_post = []


  def print_prediction_success_rate(self):
    """Print how many varied_questions the model answered correctly, with a verdict."""
    if self.num_varied_questions == 0:
      # Without this guard, zero questions would be reported as "all correct".
      print("WARNING: No varied_questions have been evaluated, so prediction accuracy is unknown.")
      return

    bad_predictions = self.num_varied_questions - self.num_varied_successes
    print(f"Varied_questions prediction success rate = {self.num_varied_successes / self.num_varied_questions * 100:.2f}% ({self.num_varied_successes} good, {bad_predictions} bad)")

    if bad_predictions == 0:
      # This is evidence not proof because there may be very rare edge cases (say 1 in ten million) that do not exist in the test questions.
      # Even if you believe you know all the edge cases, and have enriched the training data to contain them, you may not have thought of all edge cases, so this is not proof.
      print("Model got all test questions correct. This is a pre-requisite for the model to be fully accurate, but this is NOT proof.")
    else:
      print("WARNING: Model is not fully accurate as it got", bad_predictions, "questions wrong.")

In [None]:
class A_Config(T_Config):
  """Scratch state for ablation / intervention experiments.

  On top of T_Config this holds:
  - the node locations currently being ablated and the attention-head values
    captured by the per-layer "get" hooks (layer_store),
  - the expected vs. actual outcome of the current intervention,
  - running totals across interventions.
  """

  # Per-layer attention hooks exist for layers 0..3 (a_get_l0..l3 / a_put_l0..l3)
  max_hook_layers = 4


  def __init__(self):
    super().__init__()
    self.reset_hooks()
    self.reset_intervention()
    self.reset_intervention_totals()
    self.show_test_failures = False


  def reset_hooks(self):
    """Clear the ablation targets, stored head values and hook lists."""
    # A list of NodeLocations
    self.node_locations = []

    # Attention-head values stored by the "get" hooks, indexed by layer.
    # One slot per hooked layer: the layer-3 get hook writes layer_store[3],
    # so 3 slots raised IndexError on 4-layer models.
    self.layer_store = [[] for _ in range(self.max_hook_layers)]

    self.questions = []
    self.attn_get_hooks = []
    self.attn_put_hooks = []

    self.reset_node_location()
    # Stricter than T_Config's default threshold: any loss increase counts during ablation
    self.threshold = 0.00001


  def reset_intervention(self, expected_answer_int = 0, expected_impact = NO_IMPACT_TAG, operation = MathsTokens.PLUS):
    """Prepare for a new intervention experiment with the given expected outcome."""
    self.operation = operation

    # Expected output of an intervention ablation experiment
    self.expected_answer = int_to_answer_str(cfg,expected_answer_int)
    self.expected_impact = expected_impact if expected_impact != "" else NO_IMPACT_TAG

    # Actual outputs of an intervention ablation experiment
    self.intervened_answer = ""
    self.intervened_impact = NO_IMPACT_TAG

    self.abort = False


  def reset_intervention_totals(self):
    """Zero the counters accumulated across intervention experiments."""
    self.num_tests_run = 0
    self.num_tags_added = 0


  def node_names(self):
    """Return the names of the ablated node locations as a comma-separated string."""
    return ", ".join(node.name() for node in self.node_locations)


acfg = A_Config()

In [None]:
# Get and put attention head value hooks

def a_null_attn_z_hook(value, hook):
  """No-op forward hook: leaves the attention head output untouched."""
  return None


def validate_value(name, value):
  """Return True when `value` has a non-empty batch (first) dimension.

  An empty batch (TransformerLens has been seen returning e.g. a [0, 22, 3, 170]
  tensor) is treated as bad data: the current experiment is printed and
  acfg.abort is set so it can be abandoned.
  """
  if value.shape[0] != 0:
    return True

  print( "Aborted", name, acfg.node_names(), acfg.questions, acfg.operation, acfg.expected_answer, acfg.expected_impact)
  acfg.abort = True
  return False


def _store_attn_z(layer, hook_fn_name, value):
  """Save a copy of the attention head output `value` into acfg.layer_store[layer], if valid."""
  if validate_value(hook_fn_name, value):
    acfg.layer_store[layer] = value.clone()


def a_get_l0_attn_z_hook(value, hook):
  # value observed as e.g. [1, 22, 3, 170] = num_questions, cfg.n_ctx, cfg.n_heads, cfg.d_head
  _store_attn_z(0, "a_get_l0_attn_z_hook", value)

def a_get_l1_attn_z_hook(value, hook):
  _store_attn_z(1, "a_get_l1_attn_z_hook", value)

def a_get_l2_attn_z_hook(value, hook):
  _store_attn_z(2, "a_get_l2_attn_z_hook", value)

def a_get_l3_attn_z_hook(value, hook):
  _store_attn_z(3, "a_get_l3_attn_z_hook", value)


def _put_attn_z(layer, value):
  """Overwrite, in place, each ablated head in `layer` with the value that layer's get hook stored.

  Each layer must read its own acfg.layer_store[layer]. The layer 1-3 put hooks
  previously read layer 0's stored values.
  """
  for location in acfg.node_locations:
    if location.layer == layer:
      stored = acfg.layer_store[layer]
      value[:,location.position,location.num,:] = stored[:,location.position,location.num,:].clone()


def a_put_l0_attn_z_hook(value, hook):
  # value observed as e.g. [1, 22, 3, 170] = num_questions, cfg.n_ctx, cfg.n_heads, d_head
  _put_attn_z(0, value)

def a_put_l1_attn_z_hook(value, hook):
  _put_attn_z(1, value)

def a_put_l2_attn_z_hook(value, hook):
  _put_attn_z(2, value)

def a_put_l3_attn_z_hook(value, hook):
  _put_attn_z(3, value)


def a_reset(node_locations):
  """Reset acfg's hook state and target `node_locations`.

  Builds one (hook_z name, hook fn) pair per model layer for the get and put hooks.
  """
  acfg.reset_hooks()
  acfg.node_locations = node_locations

  get_fns = (a_get_l0_attn_z_hook, a_get_l1_attn_z_hook, a_get_l2_attn_z_hook, a_get_l3_attn_z_hook)
  put_fns = (a_put_l0_attn_z_hook, a_put_l1_attn_z_hook, a_put_l2_attn_z_hook, a_put_l3_attn_z_hook)
  hook_names = acfg.l_attn_hook_z_name

  acfg.attn_get_hooks = list(zip(hook_names, get_fns))[:cfg.n_layers]
  acfg.attn_put_hooks = list(zip(hook_names, put_fns))[:cfg.n_layers]

# Part 7B: Set Up: Create sample maths questions

Create a batch of manually-curated mathematics test questions

In [None]:
# Manually-curated test questions (from QuantaTools). predict_varied_questions (Part 7C)
# may later drop any the model answers incorrectly.
varied_questions = make_maths_test_questions(cfg)

In [None]:
def calc_mean_values(the_questions):
  """Run `the_questions` through the model and store batch-mean activations for mean ablation.

  Prints the cache contents and the sample mean loss. Then, for the layer-0
  attn.hook_z, hook_resid_post and mlp.hook_post activations, stores the mean over
  the question (first) dimension (keepdim=True) in acfg.mean_attn_z,
  acfg.mean_resid_post and acfg.mean_mlp_hook_post respectively.
  """
  questions_gpu = the_questions.cuda()

  # Run the sample batch, gather the cache
  cfg.main_model.reset_hooks()
  cfg.main_model.set_use_attn_result(True)
  logits, cache = cfg.main_model.run_with_cache(questions_gpu)
  print(cache) # Gives names of datasets in the cache

  losses_raw, _ = logits_to_tokens_loss(logits, questions_gpu)
  print("Sample Mean Loss", utils.to_numpy(loss_fn(losses_raw).mean())) # Loss < 0.04 is good

  def batch_mean(hook_name):
    # e.g. attn_z: [350, 22, 3, 170] -> [1, 22, 3, 170]; resid_post: [350, 22, 510] -> [1, 22, 510];
    # mlp post: [350, 22, 2040] -> [1, 22, 2040]  (num_questions, cfg.n_ctx, ...)
    sample = cache[hook_name]
    print("Sample", hook_name, sample.shape)
    mean = torch.mean(sample, dim=0, keepdim=True)
    print("Mean", hook_name, mean.shape)
    return mean

  acfg.mean_attn_z = batch_mean(acfg.l_attn_hook_z_name[0])          # attention head output
  acfg.mean_resid_post = batch_mean(acfg.l_hook_resid_post_name[0])  # post residual update
  acfg.mean_mlp_hook_post = batch_mean(acfg.l_mlp_hook_post_name[0]) # MLP layer

In [None]:
# Compute the mean activations later used for mean-ablation interventions
calc_mean_values(varied_questions)

# Part 7C: Results: Predict varied questions

Ask the model to predict the varied_questions (without intervention) to see if the model gets them all right. Categorise answers by group

In [None]:
def predict_varied_questions():
  """Ask the model to answer every question in the global `varied_questions` (no intervention).

  Side effects:
  - prints the success rate per complexity group ("<major>.<minor>" from
    get_maths_question_complexity), then the overall rate via acfg,
  - sets acfg.num_varied_questions / acfg.num_varied_successes,
  - if any question was answered wrongly, replaces the global `varied_questions`
    with only the correctly answered ones. Otherwise the failures would turn up in
    every cell of the quanta maps.
  """
  global varied_questions

  num_questions = varied_questions.shape[0]
  correct_list = [True] * num_questions

  questions_gpu = varied_questions.cuda()
  # Inference only: no autograd graph is needed
  with torch.no_grad():
    all_logits = cfg.main_model(questions_gpu)
    _, all_max_prob_tokens = logits_to_tokens_loss(all_logits, questions_gpu)

  # group_name -> [num_good, num_bad]
  categorization_results = {}
  for question_num in range(num_questions):
    q = varied_questions[question_num]

    model_answer_str = tokens_to_string(cfg, all_max_prob_tokens[question_num])
    model_answer_num = int(model_answer_str)

    major_tag, minor_tag = get_maths_question_complexity(cfg, q)
    group_name = major_tag + "." + minor_tag

    correct = (model_answer_num == tokens_to_answer(cfg, q))
    correct_list[question_num] = correct

    counts = categorization_results.setdefault(group_name, [0, 0])
    counts[0 if correct else 1] += 1

  # Print per-group success rates and accumulate the overall totals
  acfg.num_varied_questions = 0
  acfg.num_varied_successes = 0
  for group_name, (num_good, num_bad) in categorization_results.items():
    total = num_good + num_bad
    success_rate = num_good / total * 100 if total != 0 else 0
    print(f"Group {group_name}: Success Rate = {success_rate:.2f}% ({num_good} good, {num_bad} bad)")
    acfg.num_varied_questions += total
    acfg.num_varied_successes += num_good

  acfg.print_prediction_success_rate()
  if acfg.num_varied_successes < acfg.num_varied_questions:
    # Remove the questions that the model failed to answer as they turn up in every cell of the quanta maps
    org_size = varied_questions.shape[0]
    varied_questions = varied_questions[torch.tensor(correct_list)]
    new_size = varied_questions.shape[0]
    print("RESOLUTION: Understand these failures. Enrich the training data to provide more examples. Retrain the model.")
    print("INTERIM: Have reduced 'varied_questions' size from", org_size, "to", new_size, "so can continue.")


predict_varied_questions()

# Part 12A: Set Up: Predict Questions and Evaluate Quanta

Get model to predict given question answers, with ablation hook(s), and categorise how many questions fail.

In [None]:
# Ask the model to predict the question answers (with the hooks either reading data, doing intervention ablations, or doing nothing)
def predict_questions_core(questions, the_hooks):
  """Run the model on `questions` with forward hooks `the_hooks` attached.

  Returns (all_losses_raw, all_max_prob_tokens) as produced by logits_to_tokens_loss.
  Runs under torch.no_grad(): this is inference only, so building an autograd
  graph would only waste GPU memory.
  """
  cfg.main_model.reset_hooks()
  cfg.main_model.set_use_attn_result(True)

  questions_gpu = questions.cuda()
  with torch.no_grad():
    all_logits = cfg.main_model.run_with_hooks(questions_gpu, return_type="logits", fwd_hooks=the_hooks)
    all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, questions_gpu)

  return all_losses_raw, all_max_prob_tokens

In [None]:
def c_predict_questions(questions, the_hooks):
  """Run `questions` with `the_hooks` attached and return how many were answered wrongly.

  A question counts as a failure only if its mean loss exceeds acfg.threshold AND
  get_question_answer_impact reports an impacted answer digit ('A' in the impact
  string). With acfg.verbose set, each failing question is printed.
  """
  all_losses_raw, all_max_prob_tokens = predict_questions_core(questions, the_hooks)

  num_fails = 0
  for question_num in range(questions.shape[0]):
    q = questions[question_num]

    the_loss_mean = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())

    # Only show the question if the loss exceeds the threshold (because of the ablated token position)
    if the_loss_mean > acfg.threshold:
      answer_str = tokens_to_string(cfg, all_max_prob_tokens[question_num])

      # Only count the question if the model got the question wrong
      impact_str = get_question_answer_impact(cfg, q, answer_str )
      if 'A' in impact_str:
        num_fails += 1

        if acfg.verbose :
          # tokens_to_string takes cfg first (the one-argument call raised TypeError)
          print(tokens_to_string(cfg, q), "Q: ModelAnswer:", answer_str, "Impact:", impact_str, "Loss:", the_loss_mean )

  return num_fails

In [None]:
def c_set_resid_post_hook(value, hook):
  """Mean-ablate the residual stream at token position acfg.position, in place.

  value: hook_resid_post activations, observed as [batch, n_ctx, d_model]
  (e.g. [64, 22, 510]). Every question in the batch receives acfg.mean_resid_post
  at that position.
  NOTE(review): calc_mean_values computes acfg.mean_resid_post from layer 0 only,
  but this hook is registered on every layer's hook_resid_post - confirm intended.
  """
  pos = acfg.position
  mean_at_pos = acfg.mean_resid_post[0, pos, :]
  value[:, pos, :] = mean_at_pos.clone()


# Mean-ablate the residual stream at every layer, one token position at a time.
# Positions where any question then fails are recorded as useful positions.
# num_failures_list gets one entry per position: the failure count, or "." if none.
num_failures_list = []
# Defined unconditionally because the Part 12B summary prints it
num_questions = varied_questions.shape[0]
if cfg.n_digits >= 5 :
  c_fwd_hooks = [(hook_name, c_set_resid_post_hook) for hook_name in acfg.l_hook_resid_post_name][:cfg.n_layers]

  # The loop variable is acfg.position because c_set_resid_post_hook reads it from there
  for acfg.position in range(cfg.n_ctx()):
    num_fails = c_predict_questions(varied_questions, c_fwd_hooks)

    num_failures_list.append(num_fails if num_fails > 0 else ".")

    if num_fails > 0:
      cfg.add_useful_position(acfg.position)

# Part 9 : Results: Can the model do 1 million questions without error?

If the model passes this test, this is evidence (not proof) that the model is fully accurate. There may be very rare edge cases (say 1 in ten million) that did not appear in the test questions. Even if you believe you know all the edge cases, and have enriched the training data to contain them, you may not have thought of all edge cases, so this is not proof.

If the model fails this test:
- Add a few of the failures to the manually-curated test questions (in this notebook they are created in Part 7B by make_maths_test_questions)
- Understand the "use case(s)" driving these failures
- Alter the Training CoLab data_generator_core to enrich the training data with examples of these use case(s) and retrain the model.

Takes ~25 mins to run (successfully) for ins_mix_d6_l3_h4_t40K_seed372001

In [None]:
def null_hook(value, hook):
  """No-op forward hook, so the model runs with no intervention."""
  return None

In [None]:
def one_million_questions_core():
  """Ask the model ~1M freshly generated questions, in cfg.batch_size batches, with no intervention.

  Questions follow the current cfg.perc_sub / cfg.perc_mult mix. Stops at the first
  batch that contains a wrongly answered question, then prints the totals and a
  warning if anything failed. Replaces the global data generator `ds`.
  """
  global ds

  acfg.verbose = False

  cfg.analysis_seed = 345621 # Randomly chosen
  # Re-initialise the data generator with the same factory used in Part 4.
  # NOTE(review): confirm maths_data_generator seeds from cfg.analysis_seed (not only
  # cfg.training_seed); otherwise these questions may repeat the training stream.
  ds = maths_data_generator(cfg)

  # A do-nothing hook: the model runs unmodified
  the_hook = [(acfg.l_attn_hook_z_name[0], null_hook)]

  the_successes = 0
  the_fails = 0

  num_batches = 1000000 // cfg.batch_size
  for epoch in range(num_batches):
    tokens = next(ds)

    the_fails = c_predict_questions(tokens, the_hook)
    if the_fails > 0:
      break

    the_successes += cfg.batch_size

    if epoch % 100 == 0:
      print("Batch", epoch, "of", num_batches, "#Successes=", the_successes)

  print("successes", the_successes, "num_fails", the_fails)
  if the_fails > 0:
    print("WARNING: Model is not fully accurate. It failed the 1M Q test")

In [None]:
def one_million_questions():
  """Run the 1M-question test separately for addition and for subtraction.

  Temporarily forces cfg.perc_sub / cfg.perc_mult so each run generates a single
  operation. The original values are restored even if a run is interrupted
  (e.g. KeyboardInterrupt during this ~25 minute test).
  NOTE: multiplication questions are not tested here.
  """
  store_perc_sub = cfg.perc_sub
  store_perc_mult = cfg.perc_mult

  print_config()
  print()

  try:
    if cfg.perc_add() > 0:
      print("Addition:")
      cfg.perc_sub = 0
      cfg.perc_mult = 0
      one_million_questions_core()

    if store_perc_sub > 0:
      print("Subtraction:")
      cfg.perc_sub = 100
      cfg.perc_mult = 0
      one_million_questions_core()
      print()
  finally:
    cfg.perc_sub = store_perc_sub
    cfg.perc_mult = store_perc_mult


# Takes ~25 minutes to run
# one_million_questions()

# Part 12B: Results: Ablate ALL Heads in EACH token position. What is the impact?

Here we ablate all heads in each token position (overriding the model memory aka residual stream) and see if loss increases. If loss increases the token position is used by the algorithm. Unused token positions can be excluded from further analysis. Use "C_" prefix

In [None]:
def save_plt_to_file( full_title ):
  """Save the current matplotlib figure when cfg.graph_file_suffix is set (non-empty).

  File name = model file prefix + full_title + '.' + suffix, with spaces, commas,
  colons and hyphens removed.
  """
  # Truthiness (rather than `> ""`) also treats a None suffix as "don't save"
  if not cfg.graph_file_suffix:
    return

  filename = cfg.file_config_prefix() + full_title + '.' + cfg.graph_file_suffix
  # Drop characters that are awkward in file names, in a single pass
  filename = filename.translate(str.maketrans("", "", " ,:-"))
  plt.savefig(filename, bbox_inches='tight', pad_inches=0)

In [None]:
# Summarise and plot the per-position ablation results from Part 12A.
# NOTE(review): num_questions and num_failures_list come from the Part 12A ablation cell;
# confirm num_questions is assigned there for every cfg.n_digits value.
print_config()
print()
print("Number of failures when ALL Heads in EACH token position are ablated")
print("num_questions=", num_questions, "min_useful_position=", cfg.min_useful_position(), "max_useful_position=", cfg.max_useful_position() )
print()

# NOTE(review): the meaning of 16 is defined by calc_position_failures_map (QuantaTools) - confirm
cfg.calc_position_failures_map(num_failures_list, 16)
save_plt_to_file("Failures When Position Ablated")
plt.show()

# Part 14: Setup: Analysis of quanta per node

Evaluate quanta at node (not position) resolution. Uses "u_" prefix.

In [None]:
def u_add_node_tag( the_location, major_tag, minor_tag ):
  """Check the_location is inside the model's dimensions, then tag it on cfg.useful_nodes.

  `num` is a head index for attention nodes, or an MLP slice index otherwise.
  """
  assert 0 <= the_location.position < cfg.n_ctx()
  assert 0 <= the_location.layer < cfg.n_layers
  num_limit = cfg.n_heads if the_location.is_head else cfg.mlp_slices()
  assert 0 <= the_location.num < num_limit

  cfg.useful_nodes.add_node_tag( the_location, major_tag, minor_tag )

In [None]:
def u_predict_questions(questions, the_hooks):
  """Run `questions` with `the_hooks` attached and tag the current node (acfg) with failure quanta.

  A question fails if its mean loss exceeds acfg.threshold and an answer digit is
  impacted ('A' in the impact string). If any fail, acfg's location is tagged with:
  the failure percentage, the impacted answer digits, and the complexity tags of
  failing addition / subtraction questions.
  """
  all_losses_raw, all_max_prob_tokens = predict_questions_core(questions, the_hooks)

  num_fails = 0
  impact_fails = ""
  add_complexity_fails = ""
  sub_complexity_fails = ""

  for question_num in range(questions.shape[0]):
    q = questions[question_num]

    the_loss_mean = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())

    # Only show the question if the loss exceeds the threshold (because of the ablated token position)
    if the_loss_mean > acfg.threshold:
      answer_str = tokens_to_string(cfg, all_max_prob_tokens[question_num])

      impact_str = get_question_answer_impact(cfg, q, answer_str )
      # Only count the question if the model got the question wrong
      if 'A' in impact_str:
        num_fails += 1

        impact_fails += impact_str

        major_tag, minor_tag = get_maths_question_complexity(cfg, q)
        if major_tag == QuantaType.MATH_ADD:
          add_complexity_fails += minor_tag
        elif major_tag == QuantaType.MATH_SUB:
          sub_complexity_fails += minor_tag

        if acfg.verbose :
          # tokens_to_string takes cfg first (the one-argument call raised TypeError)
          print(tokens_to_string(cfg, q), "U: ModelAnswer:", answer_str, "Complexity:", major_tag, "Impact:", impact_str, "Loss:", the_loss_mean )


  if num_fails > 0:

    # Add percentage failure quanta
    perc = int( 100.0 * num_fails / len(questions))
    u_add_node_tag( acfg, QuantaType.FAIL, str(perc) )

    # Add summary of all answer digit impact quanta failures
    u_add_node_tag( acfg, QuantaType.IMPACT, "A" + sort_unique_digits(impact_fails, True) )

    # Add summary of all addition question complexity quanta failures
    if add_complexity_fails != "":
      u_add_node_tag( acfg, QuantaType.MATH_ADD, "S" + sort_unique_digits(add_complexity_fails, False) )

    # Add summary of all subtraction question complexity quanta failures
    if sub_complexity_fails != "":
      sub_complexity_fails = sort_unique_digits(sub_complexity_fails, False)
      if sub_complexity_fails == "":
        sub_complexity_fails = MathsTag.SUB_NG_TAG
      else:
        sub_complexity_fails = "M" + sub_complexity_fails
      u_add_node_tag( acfg, QuantaType.MATH_SUB, sub_complexity_fails )

In [None]:
def u_mlp_hook_post(value, hook):
  """Mean-ablate every MLP neuron at token position acfg.position.

  Hook callback: overwrites `value` in place with the stored mean activations in
  acfg.mean_mlp_hook_post at the same position. `hook` is not used.
  NOTE(review): this cell is repeated verbatim further down the notebook; the last
  definition executed is the one the ablation loop picks up.
  """
  # Observed shape of value: [num_questions, cfg.n_ctx, cfg.d_mlp] e.g. [1099, 22, 2040]
  pos = acfg.position
  value[:, pos, :] = acfg.mean_mlp_hook_post[:, pos, :].clone()
  # PQR: slice-level ablation is disabled. When only one of cfg.mlp_slices() slices
  # of the cfg.d_mlp neurons was ablated, the MLP never failed (cause unknown).
  # The disabled variant ablated value[:, pos, start:end] from the same mean tensor,
  # with slice_size = cfg.d_mlp // cfg.mlp_slices(), start = acfg.num * slice_size
  # and end = start + slice_size.


# Ablating the MLP in each layer in each position and seeing if the loss increases shows which layer+MLP are used by the algorithm.
def u_mlp_perform_all(questions):
  """Mean-ablate each layer's MLP at each useful position and score `questions`.

  The current position / layer / slice number are published through
  acfg.position, acfg.layer and acfg.num because the hook (u_mlp_hook_post) reads
  acfg.position from there; u_predict_questions presumably uses them to tag nodes.
  NOTE(review): u_mlp_hook_post ignores acfg.num, so every slice iteration repeats
  the same whole-MLP ablation.
  """
  acfg.is_head = False
  for position in cfg.useful_positions:
    acfg.position = position
    for layer in range(cfg.n_layers):
      acfg.layer = layer
      for slice_num in range(cfg.mlp_slices()):
        acfg.num = slice_num
        hook_name = acfg.l_mlp_hook_post_name[acfg.layer]
        u_predict_questions(questions, [(hook_name, u_mlp_hook_post)])

In [None]:
def u_mlp_hook_post(value, hook):
  """Replace the MLP activations at acfg.position with their stored means.

  In-place mean ablation of all neurons at one token position; `hook` is unused.
  NOTE(review): duplicate of another cell with the same definition in this notebook.
  """
  # Observed shape of value: [num_questions, cfg.n_ctx, cfg.d_mlp] e.g. [1099, 22, 2040]
  mean_at_position = acfg.mean_mlp_hook_post[:, acfg.position, :].clone()
  value[:, acfg.position, :] = mean_at_position
  # PQR: when only a slice of the neurons is ablated the MLP never fails - why?
  # Disabled slice variant: slice_size = cfg.d_mlp // cfg.mlp_slices(),
  # start = acfg.num * slice_size, end = start + slice_size, and only
  # value[:, acfg.position, start:end] is replaced by the matching mean slice.


# Ablating the MLP in each layer in each position and seeing if the loss increases shows which layer+MLP are used by the algorithm.
def u_mlp_perform_all(questions):
  """Run one mean-ablation prediction pass per (useful position, layer, MLP slice).

  Loop state lives on acfg (position, layer, num) so the hook can see it.
  NOTE(review): duplicate of another cell with the same definition in this notebook.
  """
  acfg.is_head = False
  for pos in cfg.useful_positions:
    acfg.position = pos
    for layer_index in range(cfg.n_layers):
      acfg.layer = layer_index
      for num in range(cfg.mlp_slices()):
        acfg.num = num
        hooks = [(acfg.l_mlp_hook_post_name[acfg.layer], u_mlp_hook_post)]
        u_predict_questions(questions, hooks)

In [None]:
def u_mlp_hook_post(value, hook):
  """Hook: mean-ablate the whole MLP output at token position acfg.position (in place).

  `hook` is unused. NOTE(review): duplicate of another cell in this notebook.
  """
  # Observed shape of value: [num_questions, cfg.n_ctx, cfg.d_mlp] e.g. [1099, 22, 2040]
  position = acfg.position
  means = acfg.mean_mlp_hook_post
  value[:, position, :] = means[:, position, :].clone()
  # PQR: slice-level ablation (neurons acfg.num * slice_size .. + slice_size, where
  # slice_size = cfg.d_mlp // cfg.mlp_slices()) is disabled because the MLP never
  # failed when ablated that way - cause still unknown.


# Ablating the MLP in each layer in each position and seeing if the loss increases shows which layer+MLP are used by the algorithm.
def u_mlp_perform_all(questions):
  """Mean-ablate every MLP at every useful position, one prediction pass each.

  Sets acfg.is_head = False, then iterates positions x layers x slices, storing the
  current indices on acfg for the hook. NOTE(review): duplicate cell.
  """
  acfg.is_head = False
  for useful_position in cfg.useful_positions:
    acfg.position = useful_position
    for layer_num in range(cfg.n_layers):
      acfg.layer = layer_num
      for mlp_slice in range(cfg.mlp_slices()):
        acfg.num = mlp_slice
        mlp_hook = (acfg.l_mlp_hook_post_name[acfg.layer], u_mlp_hook_post)
        u_predict_questions(questions, [mlp_hook])

In [None]:
def u_head_attn_hook_z(value, hook):
  """Hook: mean-ablate head acfg.num's z vector at token position acfg.position.

  Overwrites `value` in place with the stored means in acfg.mean_attn_z, for all
  questions in the batch. `hook` is unused.
  """
  # Observed shape of value: [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head
  pos, head = acfg.position, acfg.num
  mean_z = acfg.mean_attn_z[:, pos, head, :].clone()
  value[:, pos, head, :] = mean_z


# Ablating each head in each layer in each position and seeing if the loss increases shows which position+layer+head are used by the algorithm.
def u_head_perform_all(questions):
  """Mean-ablate each attention head at each useful position and score `questions`.

  The current position / layer / head are stored on acfg (position, layer, num)
  because u_head_attn_hook_z reads acfg.position and acfg.num.
  """
  acfg.is_head = True
  for position in cfg.useful_positions:
    acfg.position = position
    for layer in range(cfg.n_layers):
      acfg.layer = layer
      for head in range(cfg.n_heads):
        acfg.num = head
        hook_name = acfg.l_attn_hook_z_name[acfg.layer]
        u_predict_questions(questions, [(hook_name, u_head_attn_hook_z)])

In [None]:
def u_calculate_attention_tags(questions):
  """Tag each useful attention head with the token positions it attends to most.

  Runs cfg.main_model on `questions`, averages every layer's attention pattern over
  the batch, then for each head node looks at the row for node.position in head
  node.num. Up to 4 QuantaType.ATTENTION tags of the form "P<token>=<percent>" are
  added per head, one for each of the top positions receiving >= 1% of the attention.
  Existing ATTENTION tags are reset first.
  """
  cfg.useful_nodes.reset_node_tags(QuantaType.ATTENTION)

  # Only the attention patterns are read below, so do not cache every other
  # activation (residual stream, MLP, ...) for the whole question batch.
  _, cache = cfg.main_model.run_with_cache(
    questions, names_filter=lambda name: name.endswith("hook_pattern"))

  # Per layer: pattern averaged over the batch.
  # [cfg.batch_size, cfg.n_heads, cfg.n_ctx, cfg.n_ctx] -> [cfg.n_heads, cfg.n_ctx, cfg.n_ctx]
  all_attention_weights = [
    cache["pattern", layer, "attn"].mean(dim=0) for layer in range(cfg.n_layers)]

  for node in cfg.useful_nodes.nodes:
    if not node.is_head:
      continue

    # Get attention weights for this token in this head
    weights = all_attention_weights[node.layer][node.num, node.position, :]

    # torch.topk raises if k exceeds the number of available positions
    top_tokens = torch.topk(weights, min(4, weights.shape[-1]))
    attention_percentage = top_tokens.values / weights.sum() * 100

    # Add up to 4 tags with percs per head
    for token_idx, perc in zip(top_tokens.indices.tolist(), attention_percentage.tolist()):
      if perc >= 1.0:
        u_add_node_tag( node, QuantaType.ATTENTION, f"P{token_idx}={perc:.0f}" )

In [None]:
# Rebuild the useful-node list from scratch: mean-ablate every MLP and every head at
# each useful position (nodes whose ablation raises the loss are recorded), then tag
# the head nodes with their main attention targets and sort the list.
acfg.verbose = False
cfg.useful_nodes = UsefulNodeList()
u_mlp_perform_all(varied_questions)
u_head_perform_all(varied_questions)
u_calculate_attention_tags(varied_questions)
cfg.useful_nodes.sort_nodes()

# Part 15A: Set up: Show and save quanta map

Using the UsefulNodes and filtering their tags, show a 2D map of the nodes and their tags' minor versions.

In [None]:
def show_quanta_map( title, standard_quanta, shades, filters : FilterNode, major_tag : str, minor_tag : str, get_node_details, base_fontsize = 10, max_width = 10 ):
  """Draw a 2D quanta map of the useful nodes, then save or title it and show it.

  Args:
    title: map title; also the file name passed to save_plt_to_file.
    standard_quanta, shades, get_node_details, base_fontsize, max_width:
      passed straight through to calc_quanta_map.
    filters: optional node filter; None maps every node in cfg.useful_nodes.
    major_tag, minor_tag: which tags to map (passed to calc_quanta_map).

  When cfg.graph_file_suffix is set, the figure is saved (without an axes title);
  otherwise the axes get a title that includes the number of mapped nodes.
  """
  test_nodes = cfg.useful_nodes if filters is None else filter_nodes(cfg.useful_nodes, filters)

  ax1, quanta_results = calc_quanta_map(cfg, standard_quanta, shades, test_nodes, major_tag, minor_tag, get_node_details, base_fontsize, max_width )

  if cfg.graph_file_suffix:
    print("Saving quanta map:", title)
    save_plt_to_file(title)
  else:
    ax1.set_title(f"{cfg.file_config_prefix()} {title} ({len(quanta_results)} nodes)")

  # Show plot
  plt.show()

# Part 16A: Results: Show failure percentage map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

A cell containing "< 1" may add some risk to the accuracy of the overall analysis process. Check to see if this represents a new use case. Improve the test data set to contain more instances of this (new or existing) use case.

In [None]:
# Map each useful node's failure percentage (QuantaType.FAIL tags): all nodes, 10 shades, font size 9
show_quanta_map( "Failure Frequency Behavior Per Node", True, 10, None, QuantaType.FAIL, "", get_quanta_fail_perc, 9)

# Part 16B - Show answer impact behavior map

Show the purpose of each useful cell by impact on the answer digits A0 to A5.


In [None]:
# Map which answer digits each useful node impacts (QuantaType.IMPACT tags); one shade per answer token
show_quanta_map( "Answer Impact Behavior Per Node", True, cfg.answer_tokens(), None, QuantaType.IMPACT, "", get_quanta_impact, 9, 6)

# Part 16C: Result: Show attention map

Show attention quanta of useful heads

In [None]:
# Only maps attention heads, not MLP layers (u_calculate_attention_tags adds
# QuantaType.ATTENTION tags to head nodes only)
show_quanta_map( "Attention Behavior Per Head", True, 10, None, QuantaType.ATTENTION, "", get_quanta_attention, 10, 6)

# Part 16D - Show question complexity map

Show the "minimum" addition purpose of each useful cell by S0 to S5 quanta.
Show the "minimum" subtraction purpose of each useful cell by M0 to M5 quanta.

In [None]:
def get_quanta_min_complexity(cfg, node, major_tag, minor_tag, shades):
  """Return (cell_text, color_index) for a node's minimum-complexity map cell.

  cell_text is the first two characters of node.min_tag_suffix(major_tag, minor_tag)
  (e.g. "S3"). color_index is the digit in the second character, shades-1 if that
  character is missing or not a digit, and 0 when the suffix is empty.
  `cfg` is unused.
  """
  suffix = node.min_tag_suffix( major_tag, minor_tag )
  if suffix == "":
    return suffix, 0

  cell_text = suffix[:2]
  level_char = cell_text[1:2]  # "" when cell_text has only one character
  color_index = int(level_char) if level_char.isdigit() else shades - 1
  return cell_text, color_index

In [None]:
# Addition complexity map (QuantaType.MATH_ADD tags); only drawn when cfg.perc_add() > 0
if cfg.perc_add() > 0:
  show_quanta_map( "Addition Min-Complexity Behavior Per Node", False, 6, None, QuantaType.MATH_ADD, "", get_quanta_min_complexity)

In [None]:
# Subtraction complexity map (QuantaType.MATH_SUB tags); only drawn when cfg.perc_sub > 0
if cfg.perc_sub > 0:
  show_quanta_map( "Subtraction Min-Complexity Behavior Per Node", False, 4, None, QuantaType.MATH_SUB, "", get_quanta_min_complexity)

# Part 19A: Set up: Calculate and graph PCA decomposition


In [None]:
# Number of questions per test case (T8, T9, T10) made by make_t_questions;
# plot_pca_for_an uses it to split the stacked PCA outputs back into the three groups.
tn_questions = 100


def make_t_questions(test_digit, test_case, operation):
  """Make tn_questions questions that isolate the digit at position `test_digit`.

  The operand digits x, y at position `test_digit` are chosen by `test_case`:
    PLUS:  8 -> x+y in 0..8,  9 -> x+y == 9,  10 -> x+y in 10..18
    MINUS: 8 -> x < y,        9 -> x == y,    10 -> x > y
  The lower digits are random noise chosen so they cause no carry (PLUS) or
  borrow (MINUS) into position `test_digit`; the higher digits are zero.

  Raises:
    ValueError: test_case is not 8, 9 or 10, or operation is not PLUS or MINUS.
  """
  if test_case not in (8, 9, 10):
    raise ValueError(f"make_t_questions: unsupported test_case {test_case} (expected 8, 9 or 10)")
  if operation != MathsTokens.PLUS and operation != MathsTokens.MINUS:
    raise ValueError(f"make_t_questions: unsupported operation {operation} (expected PLUS or MINUS)")

  limit = 10 ** test_digit
  questions = []
  for _ in range(tn_questions):
    if operation == MathsTokens.PLUS:
      if test_case == 8:
        # These are n_digit addition questions where x and y sum is between 0 to 8
        x = random.randint(0, 8)
        y = random.randint(0, 8-x)
      elif test_case == 9:
        # These are n_digit addition questions where x and y sum is 9
        x = random.randint(0, 9)
        y = 9 - x
      else:
        # These are n_digit addition questions where x and y sum is between 10 to 18
        x = random.randint(1, 9)
        y = random.randint(10-x, 9)

      # Randomise the lower digits - x_noise + y_noise < limit so there is no MakeCarry
      x_noise = random.randint(0, limit-1)
      y_noise = random.randint(0, limit-1 - x_noise)

    else:
      if test_case == 8:
        # These are n_digit subtraction questions where x - y < 0
        x = random.randint(0, 8)
        y = random.randint(x+1, 9)
      elif test_case == 9:
        # These are n_digit subtraction questions where x - y is 0
        x = random.randint(0, 9)
        y = x
      else:
        # These are n_digit subtraction questions where x - y > 0
        x = random.randint(1, 9)
        y = random.randint(0, x-1)

      # Randomise the lower digits - y_noise <= x_noise so there is no BorrowOne
      x_noise = random.randint(0, limit-1)
      y_noise = random.randint(0, x_noise)

    questions.append([x * limit + x_noise, y * limit + y_noise])

  return make_maths_questions(cfg, operation, "", "", questions)



def make_tricase_questions(test_digit, operation):
  """Stack the T8, T9 and T10 question sets for `test_digit`, in that order."""
  per_case = [make_t_questions(test_digit, test_case, operation) for test_case in (8, 9, 10)]
  return torch.vstack(per_case)


# Cache the sample questions (by answer_digit and operation) for later reuse.
# Keys are (answer_digit, operation) for answer_digit in 0..cfg.n_digits-1 and
# operation in {PLUS, MINUS}; there is no entry for answer_digit == cfg.n_digits.
# NOTE(review): this loop leaves `answer_digit` and `operation` behind as notebook
# globals (last values cfg.n_digits-1 and MathsTokens.MINUS); code that reads them
# by mistake will silently see those values.
t_questions_dict = {}
for answer_digit in range(cfg.n_digits):
    for operation in [MathsTokens.PLUS, MathsTokens.MINUS]:
        t_questions = make_tricase_questions(answer_digit, operation)
        # Use a tuple of (answer_digit, operation) as the key for indexing
        t_questions_dict[(answer_digit, operation)] = t_questions

In [None]:
def pca_evr_0_percent(pca):
  """Return the first principal component's explained variance ratio as a whole percent."""
  first_ratio = pca.explained_variance_ratio_[0]
  as_percent = round(first_ratio * 100, 0)
  return int(as_percent)

In [None]:
# Calculate one Principal Component Analysis
def calc_pca_for_an(node_location, operation, answer_digit):
  """PCA of one attention head's output over the T8/T9/T10 questions for answer digit An.

  Uses the cached questions t_questions_dict[(answer_digit, operation)] and the
  head's "result" activation at node_location.position.

  Returns:
    (pca, pca_attn_outputs, title) on success. On any failure (e.g. no cached
    questions for this key, or the "result" activation not being cached), the
    error is printed and (None, None, None) is returned.
  """
  assert node_location.is_head

  try:
    t_questions = t_questions_dict[(answer_digit, operation)]

    # Only the per-head attention result is read, so only cache that hook
    _, the_cache = cfg.main_model.run_with_cache(
      t_questions, names_filter=lambda name: name.endswith("hook_result"))

    # Output of individual heads, without final bias: [batch, n_ctx, n_heads, d_model].
    # The lookup is the same for every question, so slice all questions at once.
    attention_cache = the_cache["result", node_location.layer, "attn"]
    attn_outputs = attention_cache[:len(t_questions), node_location.position, node_location.num, :].cpu()

    pca = PCA(n_components=6)
    pca.fit(attn_outputs)
    pca_attn_outputs = pca.transform(attn_outputs)

    title = node_location.name() + ', A'+str(answer_digit) + ', EVR[0]=' + str(pca_evr_0_percent(pca)) + '%'

    return pca, pca_attn_outputs, title
  except Exception as e:
    # Best effort: report and let the caller skip this node / digit
    print( "calc_pca_for_an Failed:" + node_location.name() + " " + token_to_char(cfg, operation) + " " + answer_name(answer_digit), e)
    return None, None, None

In [None]:
# Plot the PCA of PnLnHn's attention pattern, using T8, T9, T10 questions that differ in the An digit
def plot_pca_for_an(ax, pca_attn_outputs, title):
  """Scatter PC0 against PC1 on `ax`, one colour per question group.

  Rows are assumed to be stacked T8, T9, T10 with tn_questions rows per group
  (the last group takes all remaining rows). The title is set unless it is "".
  """
  groups = (
    (slice(0, tn_questions), 'red', 'T8 (0-8)'),                  # t8 questions
    (slice(tn_questions, 2*tn_questions), 'green', 'T9'),         # t9 questions
    (slice(2*tn_questions, None), 'blue', 'T10 (10-18)'),         # t10 questions
  )
  for rows, colour, label in groups:
    ax.scatter(pca_attn_outputs[rows, 0], pca_attn_outputs[rows, 1], color=colour, label=label)

  if title != "":
    ax.set_title(title)

In [None]:
def pca_op_tag(the_digit, operation, strong):
  """Build a PCA tag: answer_name(the_digit) + "." + the operation's PCA tag, plus ".Weak" if not strong."""
  op_tag = MathsTag.PCA_ADD_TAG if operation == MathsTokens.PLUS else MathsTag.PCA_SUB_TAG
  weak_suffix = "" if strong else ".Weak"
  return answer_name(the_digit) + "." + op_tag + weak_suffix

In [None]:
def manual_node_pca(ax, position, layer, num, operation, answer_digit):
  """Plot the PCA of head (position, layer, num) for answer digit An and tag it strongly.

  If the PCA cannot be calculated (calc_pca_for_an prints the reason), `ax` is left
  empty and no tag is added.
  """
  node_location = NodeLocation(position, layer, True, num)
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)
  if pca is None:
    # Plotting None would raise, and an unverified node must not get a strong tag
    return

  plot_pca_for_an(ax, pca_attn_outputs, title)

  # Add the strong PCA tag to node PCA:A5.TR
  u_add_node_tag( node_location, QuantaType.PCA, pca_op_tag(answer_digit, operation, True) )


def auto_node_pca(ax, index, node_location, operation, answer_digit, perc_threshold):
  """Plot and weakly tag the node when PC0 explains more than perc_threshold percent.

  Returns True when a plot was drawn on `ax` and a weak PCA tag added, else False.
  `index` is unused.
  """
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)
  if pca is None or pca_evr_0_percent(pca) <= perc_threshold:
    return False

  plot_pca_for_an(ax, pca_attn_outputs, title)

  # Weak PCA tag e.g. PCA:A5.TR.Weak
  u_add_node_tag( node_location, QuantaType.PCA, pca_op_tag(answer_digit, operation, False) )
  return True

In [None]:
def manual_nodes_pca(op, nodes):
  """Plot the PCA of each listed head, add strong PCA tags, then save and show the figure.

  Each entry of `nodes` is [position, layer, head_num, answer_digit]. Plots are laid
  out 4 per row and the final grid cell holds the shared legend.
  """
  print("Manual PCA tags for", cfg.model_name, "with operation", token_to_char(cfg, op))

  cols = 4
  # One cell per node plus one for the legend, rounded up to whole rows
  rows = (len(nodes) + cols) // cols

  # squeeze=False keeps axs 2D even when there is a single row
  fig, axs = plt.subplots(rows, cols, squeeze=False)
  fig.set_figheight(rows*2 + 1)
  fig.set_figwidth(10)

  for index, node in enumerate(nodes):
    manual_node_pca(axs[index // cols, index % cols], node[0], node[1], node[2], op, node[3])

  # Remove any graphs we dont need (except the last one, which holds the legend)
  for index in range(len(nodes), rows * cols - 1):
    axs[index // cols, index % cols].remove()

  # Replace last graph with the legend; every plot uses the same three labels
  lines, labels = axs[0, 0].get_legend_handles_labels()
  axs[rows-1, cols-1].legend(lines, labels)
  axs[rows-1, cols-1].axis('off') # Now, to hide the last subplot

  plt.tight_layout()
  save_plt_to_file('Pca Tr')
  plt.show()

# Part 19B: Results: Manual interpretation of PCA results

If an attention head gives an interpretable response for an answer digit An (2 or 3 distinct output clusters) on the 3 groups of questions aligned to the T8, T9 and T10 definitions, then plot the response and add a QuantaType.PCA tag.



In [None]:
# Clear any existing PCA tags before the manual and automatic PCA passes add new ones
cfg.useful_nodes.reset_node_tags(QuantaType.PCA)

In [None]:
# Plot all attention heads with the clearest An selected.
# Each entry is [position, layer, head_num, answer_digit]; keep entries unique,
# because every entry produces its own plot and strong PCA tag.
if use_pca:

  if cfg.model_name == "add_d5_l1_h3_t30K" :
    manual_nodes_pca(MathsTokens.PLUS,
      [[ 12, 0, 0, 4 ],
      [ 12, 0, 2, 3 ],
      [ 13, 0, 0, 3 ],
      [ 13, 0, 2, 2 ],
      [ 14, 0, 0, 2 ],
      [ 14, 0, 2, 1 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 2, 0 ],
      [ 16, 0, 0, 0 ]])

  if cfg.model_name == "add_d5_l2_h3_t15K" :
    manual_nodes_pca(MathsTokens.PLUS,
      [[10, 0, 0, 2 ],
      [ 12, 0, 0, 3 ],
      [ 12, 1, 0, 3 ],
      [ 12, 1, 1, 4 ],
      [ 12, 1, 2, 4 ],
      [ 13, 0, 0, 0 ],
      [ 13, 0, 0, 3 ],
      [ 13, 1, 2, 2 ],
      [ 14, 0, 0, 0 ],
      [ 14, 0, 0, 2 ],
      [ 14, 1, 2, 2 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 0, 0 ],
      [ 15, 1, 1, 1 ],
      [ 16, 0, 0, 0 ]])

  if cfg.model_name == "add_d6_l2_h3_t15K" :
    manual_nodes_pca(MathsTokens.PLUS,
      [[11, 0, 0, 2 ],
      [ 12, 0, 0, 3 ],
      [ 13, 0, 0, 1 ],
      [ 14, 0, 0, 4 ],
      [ 14, 1, 1, 4 ],
      [ 15, 0, 0, 4 ],
      [ 15, 1, 1, 4 ],
      [ 16, 0, 0, 3 ],
      [ 16, 1, 1, 3 ],
      [ 17, 0, 0, 2 ],
      [ 17, 1, 1, 2 ],
      [ 18, 0, 0, 0 ],
      [ 18, 0, 0, 1 ],
      [ 19, 0, 0, 0 ]])

  if cfg.model_name == "sub_d6_l2_h3_t30K" :
    manual_nodes_pca(MathsTokens.MINUS,
      [[14, 0, 1, 0 ],
      [ 15, 0, 0, 5 ],
      [ 15, 1, 1, 0 ],
      [ 15, 1, 1, 1 ],
      [ 15, 1, 1, 2 ],
      [ 15, 1, 1, 3 ],
      [ 15, 1, 1, 4 ],
      [ 15, 1, 2, 0 ],
      [ 15, 1, 2, 1 ],
      [ 15, 1, 2, 2 ],
      [ 15, 1, 2, 3 ],
      [ 16, 0, 0, 0 ],
      [ 16, 0, 0, 1 ],
      [ 16, 0, 0, 2 ],
      [ 16, 0, 0, 3 ],
      [ 16, 0, 0, 4 ],
      [ 16, 0, 0, 5 ],
      [ 16, 1, 0, 4 ],
      [ 16, 1, 1, 0 ],
      [ 16, 1, 1, 1 ],
      [ 16, 1, 1, 2 ],
      [ 16, 1, 1, 3 ],
      [ 16, 1, 1, 4 ],
      [ 16, 1, 2, 0 ],
      [ 16, 1, 2, 1 ],
      [ 16, 1, 2, 2 ],
      [ 16, 1, 2, 3 ],
      [ 16, 1, 2, 4 ],
      [ 16, 1, 2, 5 ],
      [ 17, 0, 0, 0 ],
      [ 17, 0, 0, 1 ],
      [ 17, 0, 0, 2 ],
      [ 17, 0, 0, 3 ],
      [ 17, 1, 0, 3 ],
      [ 17, 1, 0, 4 ],
      [ 17, 1, 2, 0 ],
      [ 17, 1, 2, 4 ],
      [ 18, 0, 0, 0 ],
      [ 18, 0, 0, 1 ],
      [ 18, 0, 0, 2 ],
      [ 18, 0, 2, 0 ],
      [ 18, 1, 2, 3 ],
      [ 19, 0, 0, 0 ],
      [ 19, 1, 2, 2 ],
      [ 20, 0, 0, 0 ]])

  if cfg.model_name == "mix_d6_l3_h4_t40K" : # TBC
    manual_nodes_pca(MathsTokens.PLUS,
      [[ 8, 0, 0, 4 ],
      [  9, 0, 1, 3 ],
      [ 10, 0, 1, 2 ],
      [ 10, 0, 1, 3 ],
      [ 11, 0, 1, 1 ],
      [ 11, 0, 1, 2 ],
      [ 12, 0, 0, 0 ],
      [ 13, 0, 1, 4 ],
      [ 13, 1, 1, 0 ],
      [ 13, 1, 1, 1 ],
      [ 13, 1, 1, 2 ],
      [ 13, 1, 1, 3 ],
      [ 13, 1, 2, 3 ],
      [ 13, 1, 2, 5 ],
      [ 14, 0, 1, 0 ],
      [ 15, 0, 0, 0 ],
      [ 15, 0, 0, 1 ],
      [ 15, 0, 0, 2 ],
      [ 15, 0, 0, 3 ],
      [ 15, 0, 0, 4 ],
      [ 15, 0, 0, 5 ],
      [ 15, 1, 0, 0 ],
      [ 15, 1, 0, 1 ],
      [ 15, 1, 0, 2 ],
      [ 15, 1, 0, 3 ],
      [ 15, 1, 0, 4 ]])

  if cfg.model_name == "ins1_mix_d6_l3_h4_t40K" :
    manual_nodes_pca(MathsTokens.PLUS,
      [[13, 1, 3, 1 ],
      [ 14, 1, 2, 0 ],
      [ 14, 1, 2, 2 ],
      [ 14, 1, 3, 4 ],
      [ 15, 0, 3, 5 ],
      [ 15, 1, 2, 2 ],
      [ 15, 1, 3, 4 ],
      [ 16, 0, 3, 4 ],
      [ 16, 1, 2, 0 ],
      [ 16, 1, 2, 1 ],
      [ 16, 1, 2, 2 ],
      [ 16, 1, 3, 2 ],
      [ 17, 0, 3, 3 ],
      [ 17, 1, 2, 2 ],
      [ 17, 1, 3, 2 ],
      [ 18, 0, 3, 2 ],
      [ 18, 1, 3, 1 ],
      [ 19, 0, 3, 1 ],
      [ 19, 2, 0, 0 ],
      [ 19, 2, 1, 0 ],
      [ 20, 0, 0, 0 ],
      [ 20, 0, 3, 0 ]])

else:
  print( "PCA library failed to import. So PCA not done")

# Part 19C: Results: Automatic interpretation of PCA results

Part 19B is manual and selective. This part is automatic: it tests the attention heads not already tagged in Part 19B and, where the first (single) principal component explains more than 75% of the variance of a head's output for an answer digit, plots the result and adds a QuantaType.PCA "weak" tag.

In [None]:
def auto_find_pca_node(node, op, perc_threshold):
  """Try PCA on `node` for each answer digit that has cached questions; plot the strong ones.

  Plots go into a 2 x 4 grid (up to 8 graphs); unused grid cells are removed.
  Answer digits without an entry in t_questions_dict are skipped, so calc_pca_for_an
  does not print a lookup failure for them.
  """
  rows, cols = 2, 4 # Allow up to 8 graphs
  fig, axs = plt.subplots(rows, cols)
  fig.set_figheight(4)
  fig.set_figwidth(10)

  index = 0
  for answer_digit in range(cfg.n_digits+1):
    if index >= rows * cols:
      break # Grid is full
    if (answer_digit, op) not in t_questions_dict:
      continue # No sample questions for this answer digit (e.g. the final carry digit)
    ax = axs[index // cols, index % cols]
    if auto_node_pca(ax, index, node, op, answer_digit, perc_threshold):
      index += 1

  # Remove any graphs we dont need after all
  for spare in range(index, rows * cols):
    axs[spare // cols, spare % cols].remove()

  plt.tight_layout()
  plt.show()

In [None]:
def auto_find_pca(op):
  """Search useful heads for strong PCA components for operation `op` and add weak tags.

  Heads that already carry a PCA tag for `op` (e.g. from the manual pass) and MLP
  nodes are skipped. Each remaining head is passed to auto_find_pca_node with a
  75 percent threshold.
  """
  print("Automatic (weak) PCA tags for", cfg.model_name, "with operation", token_to_char(cfg, op))
  perc_threshold = 75

  # Tag prefix for this operation - must come from the `op` parameter, not from any
  # notebook-level `operation` variable left over from earlier cells.
  minor_tag_prefix = MathsTag.PCA_ADD_TAG if op == MathsTokens.PLUS else MathsTag.PCA_SUB_TAG

  for node in cfg.useful_nodes.nodes:

    # Exclude nodes with a (manual) PCA tag - for any answer digit(s)). Exclude MLP neurons.
    if node.is_head and not node.contains_tag(QuantaType.PCA, minor_tag_prefix):
      print( "Doing PCA on node", node.name())

      auto_find_pca_node(node, op, perc_threshold)


# Run the automatic (weak) PCA search for each operation with a non-zero percentage in cfg
if use_pca:
  if cfg.perc_add() > 0:
    auto_find_pca(MathsTokens.PLUS)
  if cfg.perc_sub > 0:
    auto_find_pca(MathsTokens.MINUS)

# Part 20A: Results: Show useful nodes and behaviour tags

In [None]:
# Show the useful nodes together with their behaviour tags
cfg.useful_nodes.print_node_tags()

# Part 20B: Results: Save useful nodes and behaviour tags to json file

In [None]:
# Serialize and save the useful nodes list (with behaviour tags) to a temporary CoLab
# file in JSON format, named by main_fname_behavior_json
print( "Saving useful node list with behavior tags:", main_fname_behavior_json)
cfg.useful_nodes.save_nodes(main_fname_behavior_json)

# Part 21A : Set up: Interchange Interventions

Here we prove that model nodes perform specified calculations. If all the calculations in an algorithm hypothesis are found to exist in a model instance, this provides evidence for the hypothesis.   

**Automatic searches** for node purposes are preferred, as they are applicable to several models and survive (non-significant, node-reordering) changes to the model after training. When a node purpose is detected, this is documented as a tag on the node.

**Manually written tests** of node purposes, specific to a single model instance are also supported.

In [None]:
def run_intervention_core(node_locations, store_question, clean_question, strong):
  """Patch activations stored from store_question into a run of clean_question.

  Predicts store_question with acfg.attn_get_hooks (storing activations), then
  predicts clean_question with acfg.attn_put_hooks, and compares the model's answer
  with the clean answer. Sets acfg.intervened_answer and acfg.intervened_impact
  (unless aborted) and returns a human-readable description of the run.
  `strong` only changes the wording of the description.
  """
  # Sanity-check the operand sizes of both questions
  limit = 10 ** cfg.n_digits
  for question in (store_question, clean_question):
    # NOTE(review): only question[0] < limit and question[1] > -limit are asserted.
    # The mirror bounds are not checked; judging by the "Bad all_losses_raw" branch
    # below, an operand >= limit instead surfaces as an abort. Confirm before tightening.
    assert question[0] < limit
    assert question[1] > -limit

  a_reset(node_locations)
  acfg.num_tests_run+= 1

  # Calculate the clean (no intervention) test question answer  e.g. "+006671"
  clean_answer_int = clean_question[0]+clean_question[1] if acfg.operation == MathsTokens.PLUS else clean_question[0]-clean_question[1]
  clean_answer = int_to_answer_str(cfg,clean_answer_int)
  description = "Intervening on " + acfg.node_names() + ", " + ("Strong" if strong else "Weak") + ", CleanAnswer: " + clean_answer + ", ExpectedAnswer/Impact: " + acfg.expected_answer + "/" + acfg.expected_impact
  # Predict "store" question and store activation values
  acfg.questions = make_maths_questions(cfg, acfg.operation, "", "", [store_question])
  predict_questions_core(acfg.questions, acfg.attn_get_hooks)
  if acfg.abort:
    return description + " (Aborted on store)"

  # Predict "test" question overriding PnLmHp to give a bad answer
  acfg.questions = make_maths_questions(cfg, acfg.operation, "", "", [clean_question])
  all_losses_raw, all_max_prob_tokens = predict_questions_core(acfg.questions, acfg.attn_put_hooks)
  if acfg.abort:
    return description + " (Aborted on test)"
  if all_losses_raw.shape[0] == 0:
    acfg.abort = True
    print( "Bad all_losses_raw", all_losses_raw.shape, store_question, clean_question, "Check if a question component exceeds 10 ** n_digits" )
    return description + " (Aborted on Bad all_losses_raw)"
  loss_max = utils.to_numpy(loss_fn(all_losses_raw[0]).max())
  acfg.intervened_answer = tokens_to_string(cfg, all_max_prob_tokens[0])

  # Compare the clean test question answer to what the model generated (impacted by the ablation intervention)
  acfg.intervened_impact = get_answer_impact( cfg, clean_answer, acfg.intervened_answer )
  if acfg.intervened_impact == "":
    acfg.intervened_impact = NO_IMPACT_TAG

  description += ", IntervenedAnswer/Impact: " + acfg.intervened_answer + "/" + acfg.intervened_impact

  # Only report the loss when it exceeds the threshold
  if loss_max > acfg.threshold:
    loss_str = NO_IMPACT_TAG if loss_max < 1e-7 else str(loss_max)
    description += ", Loss: " + loss_str

  return description


# Run an intervention where we have a precise expectation of the intervention impact
def run_strong_intervention(node_locations, store_question, clean_question, operation, expected_impact, expected_answer_int):
  """Run an intervention that must yield an exact answer and an exact impact.

  Returns (success, answer_success, impact_success); success needs both to match.
  When acfg.show_test_failures is set, failed runs print their description.
  """
  acfg.reset_intervention(expected_answer_int, expected_impact, operation)

  # The model's actual predictions while our node-level intervention is applied
  description = run_intervention_core(node_locations, store_question, clean_question, True)

  answer_ok = acfg.intervened_answer == acfg.expected_answer
  impact_ok = acfg.intervened_impact == acfg.expected_impact
  success = answer_ok and impact_ok

  if acfg.show_test_failures and not success:
    print("Failed: " + description)

  return success, answer_ok, impact_ok


# Run an intervention where we expect the intervention to have a non-zero impact and we cant precisely predict the answer
def run_weak_intervention(node_locations, store_question, clean_question, operation):
  """Run an intervention that only needs to change the model's answer.

  Returns True when the intervened answer differs from the clean answer AND a
  real impact (not NO_IMPACT_TAG) was recorded.
  """
  # The test (clean) question answer e.g. "+006671"
  if operation == MathsTokens.PLUS:
    expected_answer_int = clean_question[0] + clean_question[1]
  else:
    expected_answer_int = clean_question[0] - clean_question[1]
  acfg.reset_intervention(expected_answer_int, NO_IMPACT_TAG, operation)

  description = run_intervention_core(node_locations, store_question, clean_question, False)

  success = acfg.intervened_answer != acfg.expected_answer and acfg.intervened_impact != NO_IMPACT_TAG

  if acfg.show_test_failures and not success:
    print("Failed: Intervention had no impact on the answer", description)

  return success

In [None]:
def repeat_digit(digit):
    """Return the integer formed by writing `digit` cfg.n_digits times (4 -> 444444 when n_digits is 6)."""
    repeated_text = str(digit) * cfg.n_digits
    return int(repeated_text)


# unit test for repeat_digit (only meaningful for 6-digit models)
if cfg.n_digits == 6:
  assert repeat_digit(4) == 444444

In [None]:
def succeed_test(node_locations, alter_digit, strong):
  """Test function that always passes: print the node name(s) and return True.

  `alter_digit` is unused; "Weak" is printed when `strong` is false.
  """
  first_name = node_locations[0].name()
  second_name = node_locations[1].name() if len(node_locations) > 1 else ""
  weak_marker = "" if strong else "Weak"
  print( "Test confirmed", first_name, second_name, weak_marker)
  return True

In [None]:
# Search the specified useful node(s), using the test_function, for the expected impact on the_impact_digit
def search_and_tag_digit_position( the_impact_digit, the_test_nodes, test_function, strong, the_tag, do_pair_search ):
  """Tag the first node (or same-layer node pair) in the_test_nodes that passes test_function.

  Single nodes are tried first, then (if do_pair_search) pairs. The winner gets a
  QuantaType.ALGO tag - the_tag, plus "." + acfg.intervened_impact on a weak pass -
  and acfg.num_tags_added is increased. Returns True if anything was tagged.
  """
  def _full_tag():
    # Read after test_function ran, so the weak suffix reflects that run
    return the_tag + ("" if strong else "." + acfg.intervened_impact)

  # Try single nodes first
  for node in the_test_nodes.nodes:
    if test_function( [node], the_impact_digit, strong):
      node.add_tag(QuantaType.ALGO, _full_tag())
      acfg.num_tags_added += 1
      return True

  if not do_pair_search:
    return False

  # Try pairs of nodes. Sometimes a task is split across two attention heads (i.e. a virtual attention head).
  # Only if the 2 nodes are in the same layer (and of the same kind) can they act in parallel and so "sum".
  for first, second in itertools.combinations(the_test_nodes.nodes, 2):
    if first.layer != second.layer or first.is_head != second.is_head:
      continue
    if test_function( [first, second], the_impact_digit, strong):
      full_tag = _full_tag()
      first.add_tag(QuantaType.ALGO, full_tag)
      second.add_tag(QuantaType.ALGO, full_tag)
      acfg.num_tags_added += 2
      return True

  return False


# For each useful position, search the related useful node(s), using the test_function, for the expected impact on the_impact_digit.
def search_and_tag_digit( prerequisites_function, the_impact_digit, test_function, tag_function, do_pair_search, do_weak_search, from_position, to_position ):
  """Search positions from_position..to_position (inclusive) for a node doing the task.

  A position of -1 means cfg.min_useful_position() / cfg.max_useful_position().
  Runs a strong pass and, if do_weak_search, a weak second pass. Returns True as
  soon as any node (or pair) is tagged.
  """
  the_tag = tag_function(the_impact_digit)

  first_position = cfg.min_useful_position() if from_position == -1 else from_position
  last_position = cfg.max_useful_position() if to_position == -1 else to_position

  # In some models, we don't predict the intervened_answer correctly in test_function.
  # So we may do a weak second pass and may add say "A5.BS.A632" tag to a node.
  passes = [True, False] if do_weak_search else [True]
  for strong in passes:
    for position in range(first_position, last_position + 1):
      candidates = filter_nodes( cfg.useful_nodes, prerequisites_function(position, the_impact_digit))
      if search_and_tag_digit_position( the_impact_digit, candidates, test_function, strong, the_tag, do_pair_search ):
        return True

  return False


# For each answer digit, for each useful position, search the related useful node(s), using the test_function, for the expected impact on the_impact_digit. We may do 2 passes.
def search_and_tag( prerequisites_function, test_function, tag_function, do_pair_search, do_weak_search, from_position = -1, to_position = -1):
  """Run search_and_tag_digit for every answer position and print the run totals."""
  acfg.reset_intervention_totals()

  for impact_digit in range(cfg.num_answer_positions):
    search_and_tag_digit(prerequisites_function, impact_digit, test_function, tag_function,
                         do_pair_search, do_weak_search, from_position, to_position)

  print( f"Ran {acfg.num_tests_run} intervention test(s). Added {acfg.num_tags_added} tag(s)")

In [None]:
# Clear any existing ALGO tags before the intervention searches below add new ones
cfg.useful_nodes.reset_node_tags(QuantaType.ALGO)

# Part 21B: Automated An.US search

The addition Use Sum 9 (US) operation is a simple task. Search for An.US tasks.

In [None]:
def add_us_tag(impact_digit):
  """UseSum9 tag, named after the answer digit just below the impacted one."""
  lower_answer = answer_name(impact_digit - 1)
  return lower_answer + "." + AlgoTag.ADD_US_TAG

In [None]:
# These rules are prerequisites for (not proof of) an Addition UseSum9 node
def add_us_prereqs(position, impact_digit):
  """Filter: a head at `position` that attends to Dn-2 and D'n-2 and impacts An."""
  two_below = impact_digit - 2
  return FilterAnd(
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(two_below)), # Attends to Dn-2
    FilterAttention(ddn_to_position_name(two_below)), # Attends to D'n-2
    FilterImpact(answer_name(impact_digit)))

In [None]:
def add_us_test(node_locations, alter_digit, strong):
  """Intervention test: do node_locations perform the UseSum9 (US) task for digit n?

  `alter_digit` is the impacted answer digit n; only 2..cfg.n_digits can be tested.
  Activations from a non-US question are patched into a US question; the answer is
  expected to drop by exactly 10 ** n, impacting An only. Returns True on success.
  """
  if not (2 <= alter_digit <= cfg.n_digits):
    acfg.reset_intervention()
    return False

  intervention_impact = answer_name(alter_digit)
  one_below = 10 ** (alter_digit - 1)
  two_below = 10 ** (alter_digit - 2)

  # 25222 + 44444 = 69666. Has no Dn-2.MC but has Dn-1.US so not a US case
  store_question = [repeat_digit(2) + (5-2) * one_below, repeat_digit(4)]

  # 34633 + 55555 = 90188. Has Dn-2.MC and Dn-1.US so is a US case
  clean_question = [repeat_digit(3) + (4-3) * one_below + (6-3) * two_below, repeat_digit(5)]

  # When we intervene we expect answer 80188
  intervened_answer = clean_question[0] + clean_question[1] - 10 ** (alter_digit)

  # Unit test
  if cfg.n_digits == 5 and alter_digit == 4:
    assert store_question[0] == 25222
    assert clean_question[0] == 34633
    assert clean_question[0] + clean_question[1] == 90188
    assert intervened_answer == 80188

  success, _, _ = run_strong_intervention(node_locations, store_question, clean_question, MathsTokens.PLUS, intervention_impact, intervened_answer)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".US impacting "+intervention_impact+" accuracy.", "" if strong else "Weak")

  return success

In [None]:
# Search for An.US nodes: single nodes only (no pair search), strong pass only (no weak pass)
search_and_tag( add_us_prereqs, add_us_test, add_us_tag, False, False)

# Part 21C: Automated An.MC search

The addition Make Carry (MC) operation is a simple task. Search for An.MC tasks.

In [None]:
def add_mc_tag(impact_digit):
  """MakeCarry tag, named after the answer digit just below the impacted one."""
  carry_source = answer_name(impact_digit - 1)
  return carry_source + "." + AlgoTag.ADD_MC_TAG

In [None]:
# These rules are prerequisites for (not proof of) an Addition MakeCarry node
def add_mc_prereqs(position, impact_digit):
  """Filter: a head at `position` that attends to Dn-1 and D'n-1 and impacts An."""
  one_below = impact_digit - 1  # MC is calculated on the next lower-value digit.
  return FilterAnd(
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(one_below)),
    FilterAttention(ddn_to_position_name(one_below)),
    FilterImpact(answer_name(impact_digit)))

In [None]:
def add_mc_test(node_locations, impact_digit, strong):
  """Intervention test: do `node_locations` perform the addition Make Carry (MC) step?

  The carry feeding answer digit `impact_digit` is made on digit impact_digit-1. Activations
  from a store question that makes that carry are patched into a clean question that makes
  none; the expected answer after intervention is the clean answer plus 10**impact_digit.
  Returns the success flag from run_strong_intervention, or False when impact_digit-1 is
  not a question digit.
  """
  carry_digit = impact_digit - 1

  # Only question digits 0 .. n_digits-1 can make a carry.
  if not 0 <= carry_digit < cfg.n_digits:
    acfg.reset_intervention()
    return False

  impact_name = answer_name(impact_digit)
  carry_unit = 10 ** carry_digit

  # e.g. 222222 + 666966 = 889188 : digit carry_digit makes a carry (2 + 9)
  store_question = [repeat_digit(2), repeat_digit(6) + (9 - 6) * carry_unit]

  # e.g. 333333 + 555555 = 888888 : no carries
  clean_question = [repeat_digit(3), repeat_digit(5)]

  # The patched-in carry should add one to the next digit up, e.g. 889888
  intervened_answer = clean_question[0] + clean_question[1] + 10 * carry_unit

  success, _, _ = run_strong_intervention(node_locations, store_question, clean_question, MathsTokens.PLUS, impact_name, intervened_answer)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform D"+str(carry_digit)+".MC impacting "+impact_name+" accuracy.", "" if strong else "Weak")

  return success

In [None]:
# Automated An.MC search: candidate filter, intervention test and tag builder are handed to search_and_tag.
search_and_tag( add_mc_prereqs, add_mc_test, add_mc_tag, False, False)

# Part 21D: Automated An.BA search

The addition Base Add (BA) operation is a simple task. The task may be split/shared over 2 attention heads in the same position. Search for An.BA calculations.

In [None]:
def add_ba_tag(impact_digit):
  """Return the algorithm tag for a Base Add node that computes answer digit `impact_digit`."""
  return ".".join((answer_name(impact_digit), AlgoTag.ADD_BA_TAG))

In [None]:
# These rules are prerequisites for (not proof of) an Addition BaseAdd node
def add_ba_prereqs(position, impact_digit):
  """Build the filter a candidate An.BA node must pass before it is intervention-tested.

  Requires FilterHead(), the given `position`, attention to Dn and D'n (n = impact_digit),
  an impact on An, and no existing An.BA tag, so already-confirmed nodes are not re-tested.
  """
  node_filters = [FilterHead(), FilterPosition(position_name(position))]
  node_filters.append(FilterAttention(dn_to_position_name(impact_digit)))   # Dn
  node_filters.append(FilterAttention(ddn_to_position_name(impact_digit)))  # D'n
  node_filters.append(FilterImpact(answer_name(impact_digit)))              # An
  node_filters.append(FilterAlgo(add_ba_tag(impact_digit), QuantaFilter.NOT))
  return FilterAnd(*node_filters)

In [None]:
def add_ba_test1(alter_digit):
  """Build the first Base Add intervention case for digit `alter_digit`.

  Returns (store_question, clean_question, intervened_answer). Neither question makes a
  carry, so patching the BA result at `alter_digit` should turn a 9 into a 3 in that digit.
  """
  store_question = [repeat_digit(2), repeat_digit(1)]  # e.g. 222222 + 111111 = 333333
  clean_question = [repeat_digit(5), repeat_digit(4)]  # e.g. 555555 + 444444 = 999999
  digit_shift = (3 - 9) * 10 ** alter_digit            # e.g. 999999 -> 999399
  intervened_answer = clean_question[0] + clean_question[1] + digit_shift
  return store_question, clean_question, intervened_answer


def add_ba_test2(alter_digit):
  """Build the second Base Add intervention case for digit `alter_digit`.

  Returns (store_question, clean_question, intervened_answer). Neither question makes a
  carry, so patching the BA result at `alter_digit` should turn a 6 into an 8 in that digit.
  """
  store_question = [repeat_digit(2), repeat_digit(6)]  # e.g. 222222 + 666666 = 888888
  clean_question = [repeat_digit(5), repeat_digit(1)]  # e.g. 555555 + 111111 = 666666
  digit_shift = (8 - 6) * 10 ** alter_digit            # e.g. 666666 -> 666866
  intervened_answer = clean_question[0] + clean_question[1] + digit_shift
  return store_question, clean_question, intervened_answer


def add_ba_test(node_locations, alter_digit, strong):
  """Run both Base Add intervention cases against `node_locations` for digit `alter_digit`.

  With `strong`, both cases must fully succeed. Otherwise it is enough that both cases
  show the expected impact (third element of run_strong_intervention's result).
  Returns the combined success flag.
  """
  impact_name = answer_name(alter_digit)

  outcomes = []
  for make_case in (add_ba_test1, add_ba_test2):
    store_question, clean_question, intervened_answer = make_case(alter_digit)
    outcomes.append(run_strong_intervention(node_locations, store_question, clean_question, MathsTokens.PLUS, impact_name, intervened_answer))
  (success1, _, impact_success1), (success2, _, impact_success2) = outcomes

  if strong:
    success = success1 and success2
  else:
    success = impact_success1 and impact_success2

  if success:
    digit = str(alter_digit)
    claim = "perform D" + digit + ".BA = (D" + digit + " + D'" + digit + ") % 10 impacting " + impact_name + " accuracy."
    print( "Test confirmed:", acfg.node_names(), claim, "" if strong else "Weak")

  return success

In [None]:
# Automated An.BA search. The two True flags and the cfg.question_tokens() argument are interpreted by
# search_and_tag. Presumably they allow the task to be split over 2 heads (see the markdown above)
# and bound the positions searched. TODO confirm.
search_and_tag( add_ba_prereqs, add_ba_test, add_ba_tag, True, True, cfg.question_tokens() )

# Part 21E: Automated Dn.C search

Search for D0.C to D5.C with impact "A65432" to "A65" in early tokens.

A0 and A1 are too simple to need Dn.C values so they are excluded from the answer impact.

In [None]:
def add_tc_tag(focus_digit):
  """Return the algorithm tag for a Dn.C node computed on question digit pair `focus_digit`."""
  return ".".join(("D" + str(focus_digit), AlgoTag.ADD_TC_TAG))

In [None]:
# These rules are prerequisites for (not proof of) an Addition Dn.C node
def add_tc_prereqs(position, focus_digit):
  """Build the filter a candidate Dn.C node must pass before it is intervention-tested.

  Requires FilterHead(), the given `position`, attention to Dn and D'n (n = focus_digit),
  and an addition PCA tag. Per the original note, this means the node's PCA is interpretable
  (bigram or trigram output) with respect to T8, T9, T10.
  """
  node_filters = [FilterHead(), FilterPosition(position_name(position))]
  node_filters.append(FilterAttention(dn_to_position_name(focus_digit)))   # Dn
  node_filters.append(FilterAttention(ddn_to_position_name(focus_digit)))  # D'n
  node_filters.append(FilterPCA(MathsTag.PCA_ADD_TAG, QuantaFilter.CONTAINS))
  return FilterAnd(*node_filters)

In [None]:
def add_tc_test(node_locations, focus_digit, strong):
  """Weak intervention test: do `node_locations` compute Dn.C for digit pair `focus_digit`?

  Patches activations from a store question whose digit `focus_digit` makes a carry into a
  clean question with no carries, via run_weak_intervention. Returns its success flag.
  """
  # e.g. 222222 + 777977 = 1000199 : digit focus_digit makes a carry (2 + 9) that cascades up
  store_question = [repeat_digit(2), repeat_digit(7) + (9 - 7) * 10 ** focus_digit]

  # e.g. 333333 + 666666 = 999999 : no carries
  clean_question = [repeat_digit(3), repeat_digit(6)]

  success = run_weak_intervention(node_locations, store_question, clean_question, MathsTokens.PLUS)
  if not success:
    return success

  digit = str(focus_digit)
  description = acfg.node_names() + " perform D" + digit + ".C = TriCase(D" + digit + " + D'" + digit + ")"
  print("Test confirmed", description, "Impact:", acfg.intervened_impact, "" if strong else "Weak")
  return success

In [None]:
# Automated Dn.C search, run only for models with more than one layer.
if cfg.n_layers > 1: # Have not seen this task in 1-layer models.
  search_and_tag( add_tc_prereqs, add_tc_test, add_tc_tag,
    False, # Have not seen this task split between nodes.
    False,
    cfg.n_digits, cfg.question_tokens()) # These occur from the first D'n digit to the first answer digit.

# Part 21F: Automated An.BS search

The subtraction Base Subtraction (BS) operation is a simple task. The task may be split/shared over 2 attention heads in the same position. Search for An.BS calculations.

In [None]:
def sub_bs_tag(impact_digit):
  """Return the algorithm tag for a Base Subtraction node that computes answer digit `impact_digit`."""
  return ".".join((answer_name(impact_digit), AlgoTag.SUB_BS_TAG))

In [None]:
# These rules are prerequisites for (not proof of) a BaseSubtraction node
def sub_bs_prereqs(position, impact_digit):
  """Build the filter a candidate An.BS node must pass before it is intervention-tested.

  Requires FilterHead(), the given `position`, attention to Dn and D'n (n = impact_digit),
  an impact on An, and no existing An.BS tag, so already-confirmed nodes are not re-tested.
  """
  node_filters = [FilterHead(), FilterPosition(position_name(position))]
  node_filters.append(FilterAttention(dn_to_position_name(impact_digit)))   # Dn
  node_filters.append(FilterAttention(ddn_to_position_name(impact_digit)))  # D'n
  node_filters.append(FilterImpact(answer_name(impact_digit)))              # An
  node_filters.append(FilterAlgo(sub_bs_tag(impact_digit), QuantaFilter.NOT))
  return FilterAnd(*node_filters)

In [None]:
def sub_bs_test1(alter_digit):
  """Build the first Base Subtraction intervention case for digit `alter_digit`.

  Returns (store_question, clean_question, intervened_answer). Neither question borrows,
  so patching the BS result at `alter_digit` should turn a 5 into a 2 in that digit.
  """
  store_question = [repeat_digit(3), repeat_digit(1)]  # e.g. 333333 - 111111 = 222222
  clean_question = [repeat_digit(9), repeat_digit(4)]  # e.g. 999999 - 444444 = 555555
  digit_shift = (2 - 5) * 10 ** alter_digit            # e.g. 555555 -> 555255
  intervened_answer = clean_question[0] - clean_question[1] + digit_shift
  return store_question, clean_question, intervened_answer


def sub_bs_test2(alter_digit):
  """Build the second Base Subtraction intervention case for digit `alter_digit`.

  Returns (store_question, clean_question, intervened_answer). Neither question borrows,
  so patching the BS result at `alter_digit` should turn a 6 into a 4 in that digit.
  """
  store_question = [repeat_digit(6), repeat_digit(2)]  # e.g. 666666 - 222222 = 444444
  clean_question = [repeat_digit(9), repeat_digit(3)]  # e.g. 999999 - 333333 = 666666
  digit_shift = (4 - 6) * 10 ** alter_digit            # e.g. 666666 -> 666466
  intervened_answer = clean_question[0] - clean_question[1] + digit_shift
  return store_question, clean_question, intervened_answer


def sub_bs_test(node_locations, alter_digit, strong):
  """Run both Base Subtraction intervention cases against `node_locations` for digit `alter_digit`.

  With `strong`, both cases must fully succeed. Otherwise it is enough that both cases
  show the expected impact (third element of run_strong_intervention's result).
  Returns the combined success flag.
  """
  impact_name = answer_name(alter_digit)

  outcomes = []
  for make_case in (sub_bs_test1, sub_bs_test2):
    store_question, clean_question, intervened_answer = make_case(alter_digit)
    outcomes.append(run_strong_intervention(node_locations, store_question, clean_question, MathsTokens.MINUS, intervention_impact := impact_name, intervened_answer))
  (success1, _, impact_success1), (success2, _, impact_success2) = outcomes

  success = (success1 and success2) if strong else (impact_success1 and impact_success2)

  if success:
    digit = str(alter_digit)
    # BS is a subtraction: the digit is (Dn - D'n) % 10 (previously mis-reported as "+").
    claim = "perform D" + digit + ".BS = (D" + digit + " - D'" + digit + ") % 10 impacting " + impact_name + " accuracy."
    print( "Test confirmed:", acfg.node_names(), claim, "" if strong else "Weak")

  return success

In [None]:
# Automated An.BS search. The two True flags and cfg.question_tokens() are interpreted by search_and_tag
# (presumably allowing the task to be split over 2 heads, as the markdown above says - TODO confirm).
# Uncomment the lines below to clear existing ALGO tags / show intervention failures while debugging.
#cfg.useful_nodes.reset_node_tags(QuantaType.ALGO)
#acfg.show_test_failures = True
search_and_tag( sub_bs_prereqs, sub_bs_test, sub_bs_tag, True, True, cfg.question_tokens() )

# Part 21G: Automated An.BO search

The subtraction Borrow One operation is a simple task. Search for An.BO tasks.

In [None]:
def sub_bo_tag(impact_digit):
  """Return the algorithm tag for a Borrow One node that impacts answer digit `impact_digit`.

  The borrow is made on the next lower-value digit, so the tag is named after A(impact_digit-1).
  """
  borrow_digit = impact_digit - 1
  return ".".join((answer_name(borrow_digit), AlgoTag.SUB_BO_TAG))

In [None]:
# These rules are prerequisites for (not proof of) an subtraction Borrow One node
def sub_bo_prereqs(position, impact_digit):
  """Build the filter a candidate An.BO node must pass before it is intervention-tested.

  The node must pass FilterHead(), sit at `position`, attend to both Dn and D'n of the
  next lower-value digit (where the borrow is made) and impact answer digit `impact_digit`.
  """
  borrow_digit = impact_digit - 1  # BO is calculated on the next lower-value digit.
  node_filters = [FilterHead(), FilterPosition(position_name(position))]
  node_filters.append(FilterAttention(dn_to_position_name(borrow_digit)))
  node_filters.append(FilterAttention(ddn_to_position_name(borrow_digit)))
  node_filters.append(FilterImpact(answer_name(impact_digit)))
  return FilterAnd(*node_filters)

In [None]:
def sub_bo_test(node_locations, impact_digit, strong):
  """Intervention test: do `node_locations` perform the subtraction Borrow One (BO) step?

  The borrow affecting answer digit `impact_digit` is made on digit impact_digit-1.
  Activations from a store question that borrows on that digit are patched into a clean
  question that never borrows; the expected answer after intervention is the clean answer
  minus 10**impact_digit. Returns the success flag from run_strong_intervention, or False
  when impact_digit-1 is not a question digit.
  """
  borrow_digit = impact_digit - 1

  # Only question digits 0 .. n_digits-1 can borrow.
  if not 0 <= borrow_digit < cfg.n_digits:
    acfg.reset_intervention()
    return False

  impact_name = answer_name(impact_digit)
  borrow_unit = 10 ** borrow_digit

  # e.g. 222222 - 111311 = 110911 : digit borrow_digit borrows (2 - 3)
  store_question = [repeat_digit(2), repeat_digit(1) + (3 - 1) * borrow_unit]

  # e.g. 777777 - 444444 = 333333 : no borrows
  clean_question = [repeat_digit(7), repeat_digit(4)]

  # The patched-in borrow should take one from the next digit up, e.g. 332333
  intervened_answer = clean_question[0] - clean_question[1] - 10 * borrow_unit

  # Self-check of the worked example in the comments above (6-digit model, A3)
  if cfg.n_digits == 6 and impact_digit == 3:
    assert store_question == [222222, 111311]
    assert store_question[0] - store_question[1] == 110911
    assert clean_question == [777777, 444444]
    assert clean_question[0] - clean_question[1] == 333333
    assert intervened_answer == 332333

  success, _, _ = run_strong_intervention(node_locations, store_question, clean_question, MathsTokens.MINUS, impact_name, intervened_answer)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform D"+str(borrow_digit)+".BO impacting "+impact_name+" accuracy.", "" if strong else "Weak")

  return success

In [None]:
# Automated An.BO search: candidate filter, intervention test and tag builder are handed to search_and_tag.
# Uncomment the lines below to clear existing ALGO tags / show intervention failures while debugging.
#cfg.useful_nodes.reset_node_tags(QuantaType.ALGO)
#acfg.show_test_failures = True
search_and_tag( sub_bo_prereqs, sub_bo_test, sub_bo_tag, False, False)

# Part 21H: Automated An.NG search

Some useful nodes are only used in subtraction when D < D' in the question D - D', i.e. in negative-answer questions. We claim these nodes are somehow associated with converting D - D' to the mathematically equivalent - ( D' - D ).

To calculate D'>D, the model needs to calculate D'6>D6, ..., D'0>D0 and then combine the results in a cascade: D'6>D6 else ... else D'0>D0. We expect to see nodes only used in NG questions, with PCA bigram (or trigram) outputs, attending to these input pairs. The "else cascade" is tested later.

In [None]:
def sub_ng_tag(impact_digit):
  """Return the algorithm tag for a negative-answer (NG) node that impacts answer digit `impact_digit`."""
  return ".".join((answer_name(impact_digit), AlgoTag.SUB_NG_TAG))

In [None]:
def sub_ng_prereqs(position, impact_digit):
  """Build the filter (necessary, not sufficient) a candidate An.NG node must pass.

  The node must pass FilterHead(), sit at `position`, impact answer digit `impact_digit`,
  impact negative-answer subtraction questions but none of the positive-answer classes
  S0-S3, and have an interpretable subtraction PCA (bigram or trigram output w.r.t. T8, T9, T10).
  """
  node_filters = [FilterHead(), FilterPosition(position_name(position)), FilterImpact(answer_name(impact_digit))]

  # Impacts negative-answer questions ...
  node_filters.append(FilterContains(QuantaType.MATH_SUB, MathsTag.SUB_NG_TAG))

  # ... but no positive-answer subtraction questions (of any complexity)
  for positive_tag in (MathsTag.SUB_S0_TAG, MathsTag.SUB_S1_TAG, MathsTag.SUB_S2_TAG, MathsTag.SUB_S3_TAG):
    node_filters.append(FilterContains(QuantaType.MATH_SUB, positive_tag, QuantaFilter.NOT))

  node_filters.append(FilterPCA(MathsTag.PCA_SUB_TAG, QuantaFilter.CONTAINS))
  return FilterAnd(*node_filters)

In [None]:
# Test that if we ablate this node then a negative-answer-subtraction question answer swaps to its positive complement
def sub_ng_test(node_locations, focus_digit, strong):
  """Weak intervention test for an A<focus_digit>.NG node.

  Patches activations from a positive-answer subtraction into a question that is negative
  only because D'<focus_digit> exceeds D<focus_digit>, via run_weak_intervention, and
  returns its success flag. Returns False immediately when focus_digit >= cfg.n_digits.
  """
  if focus_digit >= cfg.n_digits:
    acfg.reset_intervention()
    return False

  # e.g. 555555 - 333333 = 222222 : positive answer
  store_question = [repeat_digit(5), repeat_digit(3)]

  # e.g. 444444 - 444644 = -200 : negative only because of digit focus_digit
  clean_question = [repeat_digit(4), repeat_digit(4) + 2 * 10 ** focus_digit]

  success = run_weak_intervention(node_locations, store_question, clean_question, MathsTokens.MINUS)
  if not success:
    return success

  description = acfg.node_names() + " perform A" + str(focus_digit) + ".NG"
  print("Test confirmed", description, "Impact:", acfg.intervened_impact, "" if strong else "Weak")
  return success

In [None]:
# Automated An.NG search: candidate filter, intervention test and tag builder are handed to search_and_tag.
search_and_tag( sub_ng_prereqs, sub_ng_test, sub_ng_tag, False, False )

# Part 21I: Automated OP search

For mixed models that do addition and subtraction the operation token "+/-" (in the middle of the question) is key. Find nodes that attend to the question operation.

In [None]:
def mix_op_tag(impact_digit):
  """Return the digit-independent tag for nodes attending to the question operator.

  `impact_digit` is unused; it is accepted so the function matches the other tag builders.
  """
  return ".".join(("An", AlgoTag.MIX_OP_TAG))

In [None]:
def mix_op_prereqs(position, impact_digit):
  """Build the filter for OP candidates: FilterHead() at `position` attending to the question operator.

  `impact_digit` is unused; it is accepted so the function matches the other prereq builders.
  """
  node_filters = [FilterHead(), FilterPosition(position_name(position))]
  node_filters.append(FilterAttention(op_position_name()))
  return FilterAnd(*node_filters)

In [None]:
# Automated OP search. succeed_test is used as the "intervention test", so presumably every node
# passing mix_op_prereqs is tagged without an intervention (TODO confirm in succeed_test).
# Uncomment the lines below to clear existing ALGO tags / show intervention failures while debugging.
#cfg.useful_nodes.reset_node_tags(QuantaType.ALGO)
#acfg.show_test_failures = True
search_and_tag( mix_op_prereqs, succeed_test, mix_op_tag, False, False )

# Part 21J: Automated SG search

For mixed models that do addition and subtraction, and for our subtraction models, the answer sign token "+/-" (at the start of the answer) is important. Find nodes that attend to the answer sign token.

In [None]:
def mix_sg_tag(impact_digit):
  """Return the digit-independent tag for nodes attending to the answer sign token.

  `impact_digit` is unused; it is accepted so the function matches the other tag builders.
  """
  return ".".join(("An", AlgoTag.MIX_SG_TAG))

In [None]:
def mix_sg_prereqs(position, impact_digit):
  """Build the filter for SG candidates: FilterHead() at `position` attending to answer token A(n_digits+1).

  `impact_digit` is unused; it is accepted so the function matches the other prereq builders.
  """
  sign_token_position = an_to_position_name(cfg.n_digits + 1)
  node_filters = [FilterHead(), FilterPosition(position_name(position))]
  node_filters.append(FilterAttention(sign_token_position))
  return FilterAnd(*node_filters)

In [None]:
# Automated SG search. succeed_test is used as the "intervention test", so presumably every node
# passing mix_sg_prereqs is tagged without an intervention (TODO confirm in succeed_test).
# Uncomment the lines below to clear existing ALGO tags / show intervention failures while debugging.
#cfg.useful_nodes.reset_node_tags(QuantaType.ALGO)
#acfg.show_test_failures = True
search_and_tag( mix_sg_prereqs, succeed_test, mix_sg_tag, False, False )

# Part 22: Show algorithm quanta map

Plot the "algorithm" tags generated in previous steps as a quanta map. This is an automatically generated partial explanation of the model algorithm.

In [None]:
# Plot the QuantaType.ALGO tags assigned by the automated searches above as a quanta map.
show_quanta_map( "Algorithm Purpose Per Node", False, 2, None, QuantaType.ALGO, "", get_quanta_binary, 9)

# Part 23: Save useful nodes with behaviour and algorithm tags to JSON file

Show a list of the nodes that have proved useful in calculations, together with data on the nodes' behavior and algorithmic purposes.
Save the data to a Colab temporary JSON file.



In [None]:
# Print the useful nodes together with their QuantaType.ALGO tags.
cfg.useful_nodes.print_node_tags(QuantaType.ALGO, False)

In [None]:
# Serialize and save the useful nodes list with algorithm tags to a temporary CoLab file in JSON format.
# The file path comes from main_fname_algorithm_json, defined earlier in the notebook.
print( "Saving useful node list with algorithm tags:", main_fname_algorithm_json)
cfg.useful_nodes.save_nodes(main_fname_algorithm_json)

# Part 25A: Setup: Test Algorithm utilities

In [None]:
# Running tallies of algorithm-hypothesis clauses that passed / failed.
# Reset by start_algorithm_test() and incremented by test_algo_clause() / test_algo_logic().
num_algo_valid_clauses = 0
num_algo_invalid_clauses = 0

In [None]:
def start_algorithm_test():
  """Begin a hypothesis test run.

  Resets both clause tallies to zero, prints the model's prediction success rate plus a
  blank line, and returns the useful nodes that have a known algorithmic purpose.
  """
  global num_algo_valid_clauses, num_algo_invalid_clauses
  num_algo_valid_clauses = num_algo_invalid_clauses = 0

  acfg.print_prediction_success_rate()
  print()

  # Nodes carrying any algorithm tag
  has_algo_tag = FilterAlgo("", QuantaFilter.MUST)
  return filter_nodes( cfg.useful_nodes, has_algo_tag)

In [None]:
# Does a useful node exist matching the filters? If so, return the position
def test_algo_clause(node_list, the_filters):
  """Evaluate one hypothesis clause: does any node in `node_list` match `the_filters`?

  Prints the outcome and increments the matching global tally. Returns the position of
  the first matching node, or -1 when no node matches.
  """
  global num_algo_invalid_clauses, num_algo_valid_clauses

  matching_nodes = filter_nodes(node_list, the_filters)

  if len(matching_nodes.nodes) == 0:
    print( "Clause invalid: No nodes match", the_filters.describe())
    num_algo_invalid_clauses += 1
    return -1

  print( "Clause valid:", matching_nodes.get_node_names(), " match", the_filters.describe())
  num_algo_valid_clauses += 1
  return matching_nodes.nodes[0].position

In [None]:
def test_algo_logic(clause_name, clause_valid):
  """Record a clause whose validity was computed by the caller: print it and bump the matching tally."""
  global num_algo_invalid_clauses, num_algo_valid_clauses

  verdict = "Clause valid:" if clause_valid else "Clause invalid:"
  print( verdict, clause_name)

  if clause_valid:
    num_algo_valid_clauses += 1
  else:
    num_algo_invalid_clauses += 1

In [None]:
# Show the fraction of hypothesis clauses that were valid
def print_algo_clause_results():
  """Print how many of the clauses evaluated since start_algorithm_test() were valid."""
  total_clauses = num_algo_valid_clauses + num_algo_invalid_clauses
  print( "Overall", num_algo_valid_clauses, "out of", total_clauses, "algorithm clauses succeeded")

In [None]:
# Show the fraction of useful nodes that have an assigned algorithmic purpose
def print_algo_purpose_results(algo_nodes):
  """Print what fraction of the useful heads and MLP neurons have an algorithmic purpose.

  algo_nodes: the algorithm-tagged subset of the useful nodes (as returned by
  start_algorithm_test()); it must provide num_heads() and num_neurons().
  A model with no useful heads (or neurons) reports 0.00% instead of raising ZeroDivisionError.
  """
  def percent(part, whole):
    # Some models have no useful MLP neurons (or heads); avoid dividing by zero.
    return part / whole * 100 if whole else 0.0

  num_heads = cfg.useful_nodes.num_heads()
  num_neurons = cfg.useful_nodes.num_neurons()

  num_heads_with_purpose = algo_nodes.num_heads()
  num_neurons_with_purpose = algo_nodes.num_neurons()

  print()
  print( f"{num_heads_with_purpose} of {num_heads} useful attention heads ({percent(num_heads_with_purpose, num_heads):.2f}%) have an algorithmic purpose assigned." )
  print( f"{num_neurons_with_purpose} of {num_neurons} useful MLP neurons ({percent(num_neurons_with_purpose, num_neurons):.2f}%) have an algorithmic purpose assigned." )

# Part 25B: Results: Test Algorithm - Addition

## Part25B.1 Model add_d5_l1_h3_t30K. Tasks An.BA, An.MC, An.US

This 1-layer model can't do all addition questions. This hypothesis mirrors Paper 1. 14/15 heads have purpose assigned. 0/6 neurons have purpose assigned.

In [None]:
# For answer digits (excluding Amax), An.BA and An.MC nodes are needed before the answer digit is revealed
def test_algo_ba_mc(algo_nodes):
  """Evaluate the An.BA and An.MC clauses of the addition hypothesis.

  For each answer digit An (0 <= n < cfg.n_digits) there must be an An.BA node, and for
  n > 0 a make-carry node for digit n-1. Each clause is checked with QuantaFilter.MUST_BY
  against an_to_position_name(n+1), i.e. before An is revealed.
  """
  for impact_digit in range(cfg.n_digits):
    test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(add_ba_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))

    # A0 has no lower-value digit to take a carry from. add_mc_test rejects impact_digit 0,
    # so an add_mc_tag(0) node can never exist and this clause would always fail spuriously.
    if impact_digit > 0:
      test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(add_mc_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))


# For answer digits (excluding Amax), An.US nodes are needed before the answer digit is revealed
def test_algo_us(algo_nodes):
  """Evaluate the An.US clause for each answer digit An (0 <= n < cfg.n_digits), deadline an_to_position_name(n+1)."""
  for impact_digit in range(cfg.n_digits):
    test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(add_us_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))

In [None]:
# Evaluate the 1-layer addition hypothesis (An.BA, An.MC, An.US) when the add_d5_l1_h3_t30K model is loaded.
if cfg.model_name == "add_d5_l1_h3_t30K":

  algo_nodes = start_algorithm_test()

  test_algo_ba_mc(algo_nodes)
  test_algo_us(algo_nodes)
  print_algo_clause_results()

  print_algo_purpose_results(algo_nodes)

## Part25B.2 Models add_d5/d6_l2_h3_t15K. Tasks An.BA, An.MC, An.US, Dn.C

These 2-layer models can do addition accurately. This hypothesis mirrors Paper 2.

In [None]:
# Before Amax is revealed (as a 0 or 1), there must be a Dn.C node for every digit pair
# For each digit (except A0) there must be either an An.US or an Dn.C
def test_algo_tc_us(model_nodes):
  """Evaluate the An.US / Dn.C clauses of the addition hypothesis.

  For each digit n in range(cfg.n_digits):
    * 1-layer models: for n > 0, an An.US node by an_to_position_name(n+1).
    * Otherwise: for n > 0, an An.US or Dn.C node by an_to_position_name(n+1); and for
      every n, a Dn.C node by an_to_position_name(cfg.n_digits+1).
  """
  for impact_digit in range(cfg.n_digits):
    early_dnc = FilterAnd(FilterAlgo(add_tc_tag(impact_digit)), FilterPosition(an_to_position_name(cfg.n_digits+1), QuantaFilter.MUST_BY))
    late_dnc = FilterAnd(FilterAlgo(add_tc_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY))
    any_dnus = FilterAnd(FilterAlgo(add_us_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY))

    single_layer = cfg.n_layers == 1

    # Every answer digit except A0 needs An.US (1-layer) or An.US / Dn.C (deeper models)
    if impact_digit > 0:
      test_algo_clause(model_nodes, any_dnus if single_layer else FilterOr(any_dnus, late_dnc))

    # Deeper models need a Dn.C for every digit before the first 1 or 0 digit is calculated
    if not single_layer:
      test_algo_clause(model_nodes, early_dnc)

In [None]:
# Evaluate the 2-layer addition hypothesis (An.BA, An.MC, An.US, Dn.C) when one of these models is loaded.
if cfg.model_name in ("add_d5_l2_h3_t15K", "add_d6_l2_h3_t15K"):

  algo_nodes = start_algorithm_test()

  test_algo_ba_mc(algo_nodes)
  test_algo_tc_us(algo_nodes)
  print_algo_clause_results()

  # print_algo_purpose_results needs the algorithm-tagged nodes; calling it with no argument raised TypeError.
  print_algo_purpose_results(algo_nodes)

# Part 25C: Results: Test Algorithm - Subtraction

## Part 25C.1: Model sub_d6_l2_h3_t30K. Tasks BS, BO, SZ

This 2-layer model can do subtraction accurately. TBC

In [None]:
# For answer digits (excluding Amax), An.BS and An.BO nodes are needed before the answer digit is revealed
def test_algo_bs_bo(algo_nodes):
  """Evaluate the An.BS and An.BO clauses of the subtraction hypothesis.

  For each answer digit An (0 <= n < cfg.n_digits) there must be an An.BS node, and for
  n > 0 a borrow-one node for digit n-1. Each clause is checked with QuantaFilter.MUST_BY
  against an_to_position_name(n+1), i.e. before An is revealed.
  """
  for impact_digit in range(cfg.n_digits):
    test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(sub_bs_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))

    # A0 has no lower-value digit to borrow from. sub_bo_test rejects impact_digit 0,
    # so a sub_bo_tag(0) node can never exist and this clause would always fail spuriously.
    if impact_digit > 0:
      test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(sub_bo_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))


# For answer digits (excluding Amax), An.SZ nodes are needed before the answer digit is revealed
def test_algo_sz(algo_nodes):
  """Placeholder for the An.SZ clauses: there is no SZ search or tag builder yet, so nothing is evaluated."""
  for impact_digit in range(cfg.n_digits):
    pass
    # TODO: enable once a sub_sz_tag helper exists:
    #test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(sub_sz_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))

In [None]:
# Evaluate the subtraction BS/BO/SZ hypothesis when the sub_d6_l2_h3_t30K model is loaded.
if cfg.model_name == "sub_d6_l2_h3_t30K":

  algo_nodes = start_algorithm_test()

  test_algo_bs_bo(algo_nodes)
  test_algo_sz(algo_nodes)
  print_algo_clause_results()

  print_algo_purpose_results(algo_nodes)

## Part 25C.2: Test Algorithm - Subtraction - Negative Answer

We claim the model converts subtraction questions D - D' to the mathematically equivalent - ( D' - D ) when D' > D. To do this the model needs to know when D' > D (or equivalently when D < D').

To calculate D'>D, the model needs to calculate D'6>D6 else D'5>D5 else D'4>D4 else D'3>D3 else D'2>D2 else D'1>D1 else D'0>D0. We expect to see nodes only used in NG questions, with PCA bigram (or trigram) outputs, attending to these input pairs, evaluated in this order.


In [None]:
# For answer digits (excluding Amax), An.NG is needed before the answer digit is revealed
def test_algo_ng(algo_nodes):
  """Evaluate the NG (negative-answer) clauses of the subtraction hypothesis.

  For each answer digit An (0 <= n < cfg.n_digits) there must be an An.NG node satisfying
  QuantaFilter.MUST_BY against an_to_position_name(n). The NG nodes must also appear in
  cascade order: the node found for a higher digit must sit at an earlier position than
  the node found for the next lower digit.

  NOTE(review): the other test_algo_* helpers use an_to_position_name(n+1) as the deadline
  for An; confirm whether an_to_position_name(n) is intended here.
  """
  ng_locations = {}

  for impact_digit in range(cfg.n_digits):
    # For answer digits (excluding the +/- sign and 0 or 1 Amax), An.NG is calculated before the answer digit is revealed.
    # test_algo_clause returns the position of the first matching node, or -1 when none match.
    position = test_algo_clause(algo_nodes, FilterAnd(FilterAlgo(sub_ng_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit), QuantaFilter.MUST_BY)))
    ng_locations[impact_digit] = position

  # Check that ng_locations[n_digits-1] < ... < ng_locations[1] < ng_locations[0].
  # A missing NG node is reported as position -1. Without the >= 0 checks it would compare
  # as the earliest position and the ordering clause would pass spuriously.
  for impact_digit in range(1, cfg.n_digits):
    higher_position = ng_locations[impact_digit]
    lower_position = ng_locations[impact_digit-1]
    in_order = higher_position >= 0 and lower_position >= 0 and higher_position < lower_position
    test_algo_logic("NG Ordering for A" + str(impact_digit), in_order)

In [None]:
# Evaluate the negative-answer (NG) hypothesis when the sub_d6_l2_h3_t30K model is loaded.
if cfg.model_name == "sub_d6_l2_h3_t30K":

  algo_nodes = start_algorithm_test()

  test_algo_ng(algo_nodes)
  print_algo_clause_results()

  print_algo_purpose_results(algo_nodes)

# Part 26: Test Algorithm - Mixed Addition and Subtraction model

Our working assumption for models ins1_mix_d6_l3_h4_t40K, ins2_mix_d6_l4_h4_t40K and ins3_mix_d6_l4_h3_t40K is that the model's algorithm is:

* H1: Pays attention to the +\- question operator (using OP task)
* If operator is "+" then  
  * H2: Does addition using BA, MC, US & TC tasks
* Else
  * H3: Calculates whether D > D' (using NG tasks)
  * If D > D' then
    * H4: Amax is "+"
    * H4: Does subtraction using BS, BO, SZ & T?? tasks
  * Else
    * H5: Applies D - D' = - (D' - D) transform
    * H6: Amax is "-"
    * H4: Does subtraction using BS, BO, SZ & T?? tasks
    * H5: Applies D - D' = - (D' - D) transform a second time

Each of the algorithm sub-tasks H1 to H6 will be validated separately.












In [None]:
# True when the loaded model is one of the mixed addition/subtraction models named in the markdown above.
mixed_model = cfg.model_name == "ins1_mix_d6_l3_h4_t40K" or cfg.model_name == "ins2_mix_d6_l4_h4_t40K" or cfg.model_name == "ins3_mix_d6_l4_h3_t40K"

## Part 26.H3 Calculate whether D > D' (using NG tasks)

In [None]:
# Only display nodes with the SUB_NG_TAG tag.
filters = FilterContains(QuantaType.MATH_SUB, MathsTag.SUB_NG_TAG)
# Three views of the same NG-node subset: subtraction behaviour, attention and algorithm tags.
show_quanta_map( "Subtraction Behavior NG Nodes", False, 2, filters, QuantaType.MATH_SUB, "", get_quanta_min_complexity)
show_quanta_map( "Attention Behavior Per (NG) Head", True, 10, filters, QuantaType.ATTENTION, "", get_quanta_attention, 10, 6)
show_quanta_map( "Algorithm Purpose Per (NG) Node", False, 2, filters, QuantaType.ALGO, "", get_quanta_binary, 9, 10)

In [None]:
# Evaluate the NG part (H3) of the mixed-model hypothesis.
if mixed_model:
  algo_nodes = start_algorithm_test()

  test_algo_ng(algo_nodes)
  print_algo_clause_results()

  print_algo_purpose_results(algo_nodes)

# Part 30: Unit Test automated searches

In [None]:
def unit_test_node_tag(node_location_as_str, the_tags ):
  """Report any expected algorithm tag missing from the useful node at `node_location_as_str`.

  Asserts that a useful node exists at that location. Missing tags are printed, not raised.
  """
  node = cfg.useful_nodes.get_node(str_to_node_location(node_location_as_str))
  assert node is not None

  missing_tags = (tag for tag in the_tags if not node.contains_tag( QuantaType.ALGO, tag))
  for missing in missing_tags:
    print( "Unit test failure: Node", node.name(), "does not have expected tag", missing)

In [None]:
# Check that the automated searches assigned the expected algorithm tags for known models.
print(cfg.model_name)

# Expected tags for the 6-digit, 2-layer addition model
if cfg.model_name == 'add_d6_l2_h3_t15K':
  unit_test_node_tag('P11L0H0', ['D2.TC'] )
  unit_test_node_tag('P12L0H0', ['D3.TC'] )
  unit_test_node_tag('P14L0H0', ['A5.US', 'D4.TC'] )
  unit_test_node_tag('P14L0H2', ['A5.MC', 'D5.TC'] )
  unit_test_node_tag('P14L1H1', ['OP'] )
  unit_test_node_tag('P15L0H0', ['A4.MC'] )
  unit_test_node_tag('P15L0H1', ['A5.BA'] )
  unit_test_node_tag('P15L0H2', ['A5.BA'] )
  unit_test_node_tag('P16L0H0', ['A3.MC'] )
  unit_test_node_tag('P16L0H1', ['A4.BA'] )
  unit_test_node_tag('P16L0H2', ['A4.BA'] )
  unit_test_node_tag('P17L0H0', ['A2.MC'] )
  unit_test_node_tag('P17L0H1', ['A3.BA'] )
  unit_test_node_tag('P17L0H2', ['A3.BA'] )
  unit_test_node_tag('P18L0H0', ['A1.MC'] )
  unit_test_node_tag('P18L0H1', ['A2.BA'] )
  unit_test_node_tag('P18L0H2', ['A2.BA'] )
  unit_test_node_tag('P19L0H0', ['A0.MC'] )
  unit_test_node_tag('P19L0H1', ['A1.BA'] )
  unit_test_node_tag('P19L0H2', ['A1.BA'] )
  unit_test_node_tag('P20L0H1', ['A0.BA'] )
  unit_test_node_tag('P20L0H2', ['A0.BA'] )

# Expected tags for the 6-digit, 3-layer mixed addition/subtraction model.
# NOTE(review): tags such as 'A4.BA.A4' / 'A4.BS.A4' do not follow the answer_name + "." + tag
# format built by add_ba_tag / sub_bs_tag (e.g. 'A4.BA'). Confirm whether these are typos.
if cfg.model_name == 'mix_d6_l3_h4_t40K':
  unit_test_node_tag('P8L0H1', ['OP'] )
  unit_test_node_tag('P13L2H0', ['A7.NG'] )
  unit_test_node_tag('P15L0H0', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P15L0H3', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P16L0H3', ['A4.BA.A4', 'A4.BS.A4'] )
  unit_test_node_tag('P17L0H1', ['A3.NG'] )
  unit_test_node_tag('P17L0H3', ['A3.BA.A3', 'A3.BS.A3'] )
  unit_test_node_tag('P18L0H1', ['A2.NG'] )
  unit_test_node_tag('P18L0H3', ['A2.BA.A2', 'A2.BS.A2'] )
  unit_test_node_tag('P19L0H1', ['A1.NG'] )
  unit_test_node_tag('P19L0H3', ['A1.BA.A1', 'A1.BS.A1'] )
  unit_test_node_tag('P20L0H0', ['A0.BA', 'A0.BS'] )
  unit_test_node_tag('P20L0H3', ['A0.BA', 'A0.BS'] )
  unit_test_node_tag('P20L2H1', ['A0.NG'] )

if cfg.model_name == 'ins1_mix_d6_l3_h4_t40K':
  unit_test_node_tag('P9L0H1', ['D4.TC'] )
  unit_test_node_tag('P9L0H3', ['A5.NG', 'OP'] )
  unit_test_node_tag('P10L0H1', ['D2.TC'] )
  unit_test_node_tag('P12L0H1', ['D3.TC'] )
  unit_test_node_tag('P13L2H0', ['A7.NG'] )
  unit_test_node_tag('P14L0H0', ['A5.US', 'A6.NG'] )
  unit_test_node_tag('P14L0H2', ['A5.MC', 'D5.TC'] )
  unit_test_node_tag('P15L0H0', ['A4.MC'] )
  unit_test_node_tag('P15L0H1', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P15L0H2', ['A5.BA', 'A5.BS'] )
  unit_test_node_tag('P16L0H0', ['A3.MC'] )
  unit_test_node_tag('P16L0H1', ['A4.BA', 'A4.BS'] )
  unit_test_node_tag('P16L0H2', ['A4.BA', 'A4.BS'] )
  unit_test_node_tag('P16L2H0', ['A4.NG'] )
  unit_test_node_tag('P17L0H0', ['A2.MC'] )
  unit_test_node_tag('P17L0H1', ['A3.BA', 'A3.BS'] )
  unit_test_node_tag('P17L0H2', ['A3.BA', 'A3.BS'] )
  unit_test_node_tag('P17L2H0', ['A3.NG'] )
  unit_test_node_tag('P18L0H0', ['A1.MC'] )
  unit_test_node_tag('P18L0H1', ['A2.BA', 'A2.BS'] )
  unit_test_node_tag('P18L0H2', ['A2.BA', 'A2.BS'] )
  unit_test_node_tag('P18L2H0', ['A2.NG'] )
  unit_test_node_tag('P19L0H0', ['A0.MC'] )
  unit_test_node_tag('P19L0H1', ['A1.BA', 'A1.BS'] )
  unit_test_node_tag('P19L0H2', ['A1.BA', 'A1.BS'] )
  unit_test_node_tag('P19L2H0', ['A1.NG'] )
  unit_test_node_tag('P20L0H0', ['A0.NG'] )
  unit_test_node_tag('P20L0H1', ['A0.BA', 'A0.BS'] )
  unit_test_node_tag('P20L0H2', ['A0.BA', 'A0.BS'] )