# Accurate Integer Mathematics in Transformers - Analyse the Model

This CoLab analyses a Transformer model that performs integer addition, subtraction and multiplication e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+0283935. Each digit is a separate token. For 6 digit questions, the model is given 14 "question" (input) tokens, and must then predict the corresponding 8 "answer" (output) tokens.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0A: Import libraries
Imports standard libraries. Don't bother reading.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install matplotlib
    !pip install prettytable

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy
    !pip install scikit-learn

except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Pick the Plotly renderer: "colab" when running in Colab (or whenever development
# mode is off), otherwise the connected-notebook renderer used for local development.
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Set the axis-title and graph-title font sizes on the 'plotly' template.
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import torch
import torch.nn.functional as F
import numpy as np
import random
from prettytable import PrettyTable
import itertools
import re
from enum import Enum

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Part 0B: Import PCA library

In [None]:
# Import Principal Component Analysis (PCA) library
use_pca = True
try:
  from sklearn.decomposition import PCA
except Exception as e:
  print("pca import failed with exception:", e)
  use_pca = False

  # Sometimes version conflicts means the PCA library does not import. This workaround partially fixes the issue
  !pip install --upgrade numpy
  !pip install --upgrade scikit-learn

  # To complete workaround, now select menu option "Runtime > Restart session and run all".
  stop

# Part 0C: Import verified_transformers library

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/verified_transformers.git
from QuantaTools import QuantaFilter, QuantaType, position_name, position_name_to_int, row_location_name, location_name, NodeLocation, UsefulNode, str_to_node_location, UsefulInfo, useful_info, answer_name
from QuantaTools import token_to_char, tokens_to_string
from QuantaTools import FilterAnd, FilterOr, FilterHead, FilterNeuron, FilterContains, FilterPosition, FilterAttention, FilterImpact, FilterPCA, FilterAlgo, filter_nodes
from QuantaTools import create_custom_colormap, calc_quanta_map, get_quanta_fail_perc, get_quanta_attention, get_quanta_binary
from QuantaTools import get_answer_impact, get_question_answer_impact, compact_answer_if_sequential, get_quanta_impact

# Part 1A: Configuration: Detailed

In [None]:
# Tokens used in vocab. (Token indexes 0 to 9 represent digits 0 to 9)
# The operator and equals tokens are numbered directly after the digits.
PLUS_INDEX = 10
MINUS_INDEX = 11
EQUALS_INDEX = 12
MULT_INDEX = 13
DIV_INDEX = 14
MAX_INDEX = DIV_INDEX # Highest token index in use; the vocab size is MAX_INDEX+1

In [None]:
# Main configuration class for main model creation and training
class Config():
  """Settings describing the model to analyse, its data mix and the output formats.

  All values are class-level defaults. The derived sizes (question_tokens,
  answer_tokens, n_ctx) are methods because they depend on n_digits.
  """

  #@markdown Main Model
  n_layers: int = 3 #@param
  n_heads: int = 4 #@param

  d_vocab: int = MAX_INDEX+1
  d_model: int = 510
  d_mlp_multiplier: int = 4
  # Evaluated once, when the class body runs, from the two defaults above
  d_mlp: int = d_mlp_multiplier * d_model
  d_head: int = 170
  training_seed: int = 372001
  analysis_seed: int = 673023

  #@markdown Data
  n_digits: int = 6 #@param
  act_fn: str = 'relu'
  batch_size: int = 512 # Training used 64. Larger for speed during analysis

  #@markdown Optimizer
  n_training_steps: int = 40000 #@param
  weight_decay: float = 0.00008
  lr: float = 0.1

  #@markdown Actions

  # Percent of questions that are multiplication, subtraction (rest are addition questions).
  perc_mult: int = 0 # e.g. 20
  perc_sub: int = 0 #@param e.g. 80
  def perc_add(self):
    """Return the percent of addition questions: whatever mult and sub leave over, never below 0."""
    return max(0, 100 - self.perc_mult - self.perc_sub)

  #@markdown Insert Model
  insert_mode: int = 0 #@param 0=None 1=Init, 2=FreezeHeads 3=FreezeAll


  # Save graphs to CoLab temp files as PDF or SVG. You can manually export temp files for re-use in papers.
  graph_file_suffix: str = "svg"

  # The format to output prettytable in. Options are text|html|json|csv|latex
  # Use Text for this CoLab, latex for Overleaf output, and html for GitHub blog output
  table_out_format: str = "text"


  # The number of question tokens
  # This is also the token position of the first answer digit (which is a "+" or a  "-")
  def question_tokens(self):
    """Return the number of question tokens: two n_digits operands, the operator and '='."""
    return self.n_digits*2 + 2
  def answer_tokens(self):
    """Return the number of answer tokens: n_digits + 2."""
    return self.n_digits + 2
  def n_ctx(self):
    """Return the total sequence length (question tokens plus answer tokens)."""
    return self.question_tokens() + self.answer_tokens()

  # How many slices do we break the MLP layer up into?
  def mlp_slices(self):
    """Return the number of slices each MLP layer is split into."""
    return 1 # Paper 2 used this granularity
    # return self.n_heads * self.d_mlp_multiplier # Alternative for Paper 3?

  # Model name prefix for models stored on HuggingFace
  model_name: str = ""

  # The transformer model under analysis (None until it is created)
  main_model : HookedTransformer = None


cfg = Config()

# Part 1B: Configuration: Summary

Which existing model do we want to analyse?

The existing model weights created by the sister Colab [VerifiedArithmeticTrain](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/VerifiedArithmeticTrain.ipynb) are loaded from HuggingFace.

In [None]:
# Which existing model do we want to analyse?
# cfg.model_name = "" # Use configuration specified in Part 1A
# cfg.model_name = "add_d5_l1_h3_t30K"  # 5 digit addition model. Inaccurate as only has one layer. Can predict S0, S1 and S2 complexity questions
# cfg.model_name = "add_d5_l2_h3_t15K"  # 5 digit addition model
# cfg.model_name = "add_d6_l2_h3_t15K"  # 6 digit addition model
cfg.model_name = "sub_d6_l2_h3_t30K"  # 6 digit subtraction model
# cfg.model_name = "mix_d6_l3_h4_t40K"  # 6 digit addition and subtraction model. AvgFinalLoss=8e-09
# cfg.model_name = "ins1_mix_d6_l3_h4_t40K"  # 6 digit addition / subtraction model. Initialise with addition model. Handles 1m Qs for Add and Sub
# cfg.model_name = "ins2_mix_d6_l4_h4_t40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads every 100 epochs. AvgFinalLoss=7e-09. Fails 1m Qs
# cfg.model_name = "ins3_mix_d6_l4_h3_t40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads & MLPs every 100 epochs. AvgFinalLoss=2.6e-06. Fails 1m Qs

# Part 1C: Configuration: Input and Output file names



In [None]:
# Derive the configuration from the chosen model name, e.g. "sub_d6_l2_h3_t30K"
# means 6 digits, 2 layers, 3 heads and 30K training steps.
if cfg.model_name != "":

  # Raw strings keep "\d" a regex escape rather than an (invalid) string escape.
  # "\d+" also accepts multi-digit values such as "d10_" or "t100K".
  match = re.search(r"d(\d+)_", cfg.model_name)
  if match:
    cfg.n_digits = int(match.group(1))

  match = re.search(r"l(\d+)_", cfg.model_name)
  if match:
    cfg.n_layers = int(match.group(1))

  match = re.search(r"h(\d+)_", cfg.model_name)
  if match:
    cfg.n_heads = int(match.group(1))

  match = re.search(r"t(\d+)K", cfg.model_name)
  if match:
    cfg.n_training_steps = int(match.group(1)) * 1000

  cfg.perc_sub = 0
  cfg.insert_mode = 0

  if cfg.model_name.startswith("sub_") :
    cfg.perc_sub = 100

  # Mixed addition/subtraction models: model name -> (perc_sub, insert_mode).
  # All of them use a batch size of 256.
  mixed_model_presets = {
    # Train on 66% subtraction and 33% addition question batches
    "mix_d6_l3_h4_t40K": (66, 0),
    # Train on 80% subtraction and 20% addition question batches. Initialise with add_d6_l2_h3_t15K.pth.
    "ins1_mix_d6_l3_h4_t40K": (80, 1),
    # As ins1, but train & reset useful heads every 100 epochs
    "ins2_mix_d6_l4_h4_t40K": (80, 2),
    # As ins1, but trained & reset useful heads & MLPs every 100 epochs
    "ins3_mix_d6_l4_h3_t40K": (80, 3),
  }
  if cfg.model_name in mixed_model_presets:
    cfg.batch_size = 256
    cfg.perc_sub, cfg.insert_mode = mixed_model_presets[cfg.model_name]

In [None]:
def file_prefix():
  """Return the file-name prefix '<op>_d<digits>_l<layers>_h<heads>_'.

  <op> is 'mul', 'sub' or 'add' when the configuration is 100% that operation
  (checked in that order), otherwise 'mix'.
  """
  if cfg.perc_mult == 100:
    op_prefix = 'mul'
  elif cfg.perc_sub == 100:
    op_prefix = 'sub'
  elif cfg.perc_add() == 100:
    op_prefix = 'add'
  else:
    op_prefix = 'mix'

  return f"{op_prefix}_d{cfg.n_digits}_l{cfg.n_layers}_h{cfg.n_heads}_"



# Build the base file name: optional 'ins<mode>_' prefix, the operation/shape
# prefix, then the training steps (in thousands) and the training seed.
train_str = f"{cfg.n_training_steps // 1000}K"
main_fname = "".join([
  "" if cfg.insert_mode == 0 else f"ins{cfg.insert_mode}_",
  file_prefix(),
  f"t{train_str}_s{cfg.training_seed}",
])
# File names derived from the base name
main_fname_pth = f"{main_fname}.pth"
main_fname_behavior_json = f"{main_fname}_behavior.json"
main_fname_algorithm_json = f"{main_fname}_algorithm.json"

def print_config():
  """Print the percentage of each question type and the main model file name."""
  print("%Mult=", cfg.perc_mult, "%Sub=", cfg.perc_sub, "%Add=", cfg.perc_add(), "File", main_fname)

print_config()
print('Main model will be read from HuggingFace file', main_fname_pth)
print('Main model behavior analysis tags will save to Colab temporary file', main_fname_behavior_json)
print('Main model algorithm analysis tags will save to Colab temporary file', main_fname_algorithm_json)

# Part 3A: Set Up: Vocabulary / Embedding / Unembedding

Convert from
- Human-readable character to numeric token index.
- Convert numeric token positions to position "meanings"
- Convert from number to human-readable string and vice versa

In [None]:
# Vocabulary dictionary: Mapping from character (key) to token (value)
# Digit characters '0'..'9' map to tokens 0..9; the symbols are added afterwards.
useful_info.char_to_token = {str(digit) : digit for digit in range(10)}
useful_info.char_to_token.update({
  '+': PLUS_INDEX,
  '-': MINUS_INDEX,
  '=': EQUALS_INDEX,
  '*': MULT_INDEX,
  # NOTE(review): division is keyed by a backslash, not '/' - confirm this is intended
  '\\': DIV_INDEX,
})

In [None]:
# Unit tests
# Check the token -> character conversions against the vocabulary defined above:
# a digit token, an operator token, and a short token sequence.
assert token_to_char(4) == '4'
assert token_to_char(MULT_INDEX) == '*'
assert tokens_to_string([EQUALS_INDEX,4,0,7]) == '=407'

In [None]:
# Token layout is D(n-1)..D0, operator, D'(n-1)..D'0, "=", then the answer tokens.
# The examples below are for 6 digit questions (22 token positions, P0 to P21).

# Convert first-operand digit Dn to its position name: D0 to P5, D1 to P4, D2 to P3
def dn_to_position_name(n):
  token_position = cfg.n_digits - 1 - n
  return position_name(token_position)
# Convert second-operand digit D'n to its position name: D'0 to P12, D'1 to P11, D'2 to P10
def ddn_to_position_name(n):
  token_position = 2 * cfg.n_digits - n
  return position_name(token_position)
# Convert answer digit An to its position name: A0 to P21, A1 to P20, A2 to P19
def an_to_position_name(n):
  token_position = cfg.n_ctx() - 1 - n
  return position_name(token_position)
# Position name of the operator token, which directly follows the first operand
def op_position_name():
  token_position = cfg.n_digits
  return position_name(token_position)


def set_question_meanings():
  """Label every token position and store the labels in useful_info.

  Question positions are labelled D5, .., D0, +, D'5, .., D'0, = (for 6 digits).
  The answer labels are the ones useful_info generates itself.
  """
  # Digit indexes from most significant to least significant
  digit_indexes = range(cfg.n_digits - 1, -1, -1)

  q_meanings = ["D" + str(d) for d in digit_indexes]
  q_meanings.append("+") # Stands in for operation +, - or *
  q_meanings.extend("D'" + str(d) for d in digit_indexes)
  q_meanings.append("=")

  useful_info.initialize_token_positions(cfg.question_tokens(), cfg.answer_tokens(), False )
  # Keep the generated answer-position labels, replace the question-position labels
  answer_meanings = useful_info.token_position_meanings[-useful_info.num_answer_positions:]
  useful_info.token_position_meanings = q_meanings + answer_meanings
  print(useful_info.token_position_meanings)


set_question_meanings()

In [None]:
def int_to_answer_str( n ):
  """Format n as an answer string: a '+' or '-' sign, then abs(n) zero-padded to n_digits+1 digits."""
  sign = "+" if n >= 0 else "-"
  return sign + str(abs(n)).zfill(cfg.n_digits + 1)


# Unit test
if cfg.n_digits == 6 :
  assert int_to_answer_str(1234) == "+0001234"

In [None]:
# Convert digit tokens such as 0,0,1,2,3,4,5 to the number 12345
def tokens_to_unsigned_int( q, offset, digits ):
  """Read `digits` tokens of q starting at `offset` as a base-10 number, most significant first."""
  value = 0
  for position in range(offset, offset + digits):
    value = value * 10 + q[position]
  return value


# Convert the answer tokens "-12345" to -12345, and "+12345" to 12345
def tokens_to_answer(q):
  """Return the signed answer held in the answer part of token sequence q."""
  # The sign token is the first answer token, directly after the question tokens
  sign_position = cfg.question_tokens()

  # The answer has one more digit than each operand (n_digits+1 digits)
  magnitude = tokens_to_unsigned_int( q, sign_position + 1, cfg.n_digits + 1 )

  return - magnitude if q[sign_position] == MINUS_INDEX else magnitude

# Part 3B: Set Up: Create model

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    # Model shape
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp,
    n_ctx = cfg.n_ctx(),
    # Vocabulary: the same tokens are used for input and output
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    # Activation and normalisation
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    # Initialisation and placement
    init_weights = True,
    seed = cfg.training_seed,
    device = "cuda",
)

# Create the model with this configuration
cfg.main_model = HookedTransformer(ht_cfg)

# Part 4: Set Up: Loss Function & Data Generator
This section defines the loss function and the training/testing data generator.

In [None]:
# Loss functions

# Calculate the per-token log-probability by comparing a batch of prediction "logits" to answer "tokens"
def logits_to_tokens_loss(logits, tokens):
  """Return (log-probability of each correct answer token, most likely token at each answer position)."""
  # The answer has a +/- sign plus one more digit than each operand
  n_answer_tokens = cfg.n_digits+2

  # The logits for the answer tokens sit one position earlier than the tokens
  # they are compared against, hence the slice shifted by one
  answer_logits = logits[:, -(n_answer_tokens+1):-1]

  # Convert the raw scores into log-probabilities over the vocabulary (in float64)
  answer_log_probs = F.log_softmax(answer_logits.to(torch.float64), dim=-1)

  # Most likely token at each answer position
  predicted_tokens = answer_log_probs.argmax(dim=-1)

  # The correct answer tokens
  target_tokens = tokens[:, -(n_answer_tokens):]

  # Pick out the log-probability assigned to each correct token
  target_log_probs = answer_log_probs.gather(-1, target_tokens[:, :, None])[..., 0]

  return target_log_probs, predicted_tokens


# Calculate loss as the negative of the mean (over dimension 0) per-token log-probability
def loss_fn(ans_loss):
  """Return the negated mean of ans_loss over dimension 0."""
  mean_log_prob = ans_loss.mean(0)
  return -mean_log_prob

In [None]:
# Generate an enriched data batch for one operator type
# Batch entries (shown here for n_digits = 6) are formatted as:
#   "Addition"       DDDDDD+DDDDDD=+AAAAAAA e.g. 550030+800020=+1350050
#   "Subtraction"    DDDDDD-DDDDDD=-AAAAAAA e.g. 550030-800020=-0249990, 800020-550030=+0249990
#   "Multiplication" 000DDD*000DDD=+AAAAAAA e.g. 000345*000678=+0233910
def data_generator_core( batch_op ):
  """Build one batch of random questions, with their answers, for a single operator.

  Args:
    batch_op: The operator token written into every row. MULT_INDEX and
      MINUS_INDEX give products and differences; any other value gives sums.

  Returns:
    An int64 tensor of shape (cfg.batch_size, cfg.n_ctx()). Each row holds the
    question tokens followed by the answer sign and n_digits+1 answer digits.
  """

  batch = torch.zeros((cfg.batch_size, cfg.n_ctx())).to(torch.int64)
  x = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))
  y = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))

  if batch_op == MULT_INDEX:
    # Convert from NNNNNN*NNNNNN= to 000NNN*000NNN= so answer (product) is NNNNNN
    num_zeros = cfg.n_digits // 2
    for z in range(num_zeros):
      x[:, z] = 0
      y[:, z] = 0

  # Enrich the question data on 40% of add/sub batches (randint(1, 5) is 1 or 2) to speed up training
  if ( batch_op == PLUS_INDEX or batch_op == MINUS_INDEX ) and (random.randint(1, 5) < 3):
    # Flatten x and y to 1D tensors. These are views, so edits also change x and y
    x_flat = x.view(-1)
    y_flat = y.view(-1)

    if batch_op == PLUS_INDEX :
      # The UseSum9 task is compound and rare and so hard to learn.
      # Increase the MakeSum9 case frequency
      # UseSum9 also relies on MakeCarry1 (50%) from previous column.
      num_elements_to_modify = int(0.40 * x.numel()) # 40%
      indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
      if random.randint(1, 2) == 1:
        x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
      else:
        y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]
    else:
      # Empirically, the model seems to struggle with the sign calculation.
      # Minus signs are rarer than positive signs.
      # Generate more negative answers by increasing the y value
      y_flat[y_flat < 9] += 1

    # Reshape x and y back to its original shape
    x = x_flat.view(x.shape)
    y = y_flat.view(x.shape)


  first_answer_index = cfg.question_tokens()

  batch[:, :cfg.n_digits] = x
  batch[:, cfg.n_digits] = batch_op
  batch[:, 1+cfg.n_digits:1+cfg.n_digits*2] = y
  batch[:, first_answer_index-1] = EQUALS_INDEX

  # Convert each row of digits into a single n_digits number
  x_values = x[:, 0]
  y_values = y[:, 0]
  for dn in range(1,cfg.n_digits):
    x_values = x_values * 10 + x[:, dn]
    y_values = y_values * 10 + y[:, dn]

  # Elementwise operations to give the 1D tensor answers
  if batch_op == MULT_INDEX:
    answers = x_values * y_values
  elif batch_op == MINUS_INDEX:
    answers = x_values - y_values
  else:
    answers = x_values + y_values

  # Insert the answer signs into the batch (all rows at once)
  batch[:, first_answer_index] = PLUS_INDEX
  batch[answers < 0, first_answer_index] = MINUS_INDEX

  # Insert the answer digits, least significant digit first, one column at a time.
  # Digits above the answer's most significant digit come out as zero.
  remaining = answers.abs()
  for j in range(cfg.n_digits+1):
    batch[:, cfg.n_ctx()-j-1] = remaining % 10
    remaining = remaining // 10

  return batch

In [None]:
# Define "iterator" data generator function. Invoked using next().
def data_generator( ):
  """Yield an endless stream of question batches, moved to the GPU.

  Each batch is for one operator, chosen at random in proportion to
  cfg.perc_mult and cfg.perc_sub (the remainder is addition).
  """
  # Seed both random sources used for batch creation: torch generates the digits,
  # Python's random picks the operator and the enrichment. Seeding only torch
  # would leave the batch sequence non-reproducible.
  torch.manual_seed(cfg.analysis_seed)
  random.seed(cfg.analysis_seed)
  while True:

    batch_rand = random.randint(1, 100)
    batch_op = MULT_INDEX if batch_rand <= cfg.perc_mult else MINUS_INDEX if batch_rand <= cfg.perc_mult + cfg.perc_sub else PLUS_INDEX

    batch = data_generator_core( batch_op )

    yield batch.cuda()

In [None]:
# Initialise the data generator
# Creating the generator runs none of its code; seeding happens on the first next(ds)
ds = data_generator()

In [None]:
# Run data generator
# Draw one batch and print its first 3 rows
print(next(ds)[:3,:])

# Part 5: Set Up: Load Model from HuggingFace

In [None]:
# HuggingFace repository holding the trained model files
main_repo_name="PhilipQuirke/VerifiedArithmetic"
print("Loading model from HuggingFace", main_repo_name, main_fname_pth)

# Download the saved state dict and load it into the model created earlier
cfg.main_model.load_state_dict(utils.download_file_from_hf(repo_name=main_repo_name, file_name=main_fname_pth, force_is_torch=True))
# Switch the model to evaluation mode
cfg.main_model.eval()

# Part 6A: Set Up: Major Quanta types

Extending the imported QuantaType values (POSITION, FAIL, IMPACT, ATTENTION and ALGO), we define model-specific QuantaType values  

In [None]:
# What type of mathematical operation was the question
# These are plain strings defined in this notebook
QuantaType_MATH_ADD = "Math.Add"
QuantaType_MATH_SUB = "Math.Sub"
QuantaType_MATH_VARIED = "Math.Varied" # Mixture of question types. Aka Unknown

# Part 6B: Set Up: Minor Quanta types

In [None]:
# Related to QuantaType.IMPACT:
# No answer digits were impacted by the intervention
NO_IMPACT_TAG = "(none)"


# Related to QuantaType.PCA:
# PCA says the node's output is interpretable and aligned to the T8,T9,T10 questions, giving 2 or 3 distinct output clusters
PCA_ADD_TAG = "TR"



# Related to QuantaType_MATH_ADD:
# Addition operation "complexity" minor tags
MATH_ADD_S0_TAG = "S0"
MATH_ADD_S1_TAG = "S1"
MATH_ADD_S2_TAG = "S2"
MATH_ADD_S3_TAG = "S3"
MATH_ADD_S4_TAG = "S4"
MATH_ADD_S5_TAG = "S5"


# Related to QuantaType_MATH_SUB:
# Subtraction operation "complexity" minor tags
# (The constant names use S0..S3 but the tag values are M0..M3)
MATH_SUB_S0_TAG = "M0"
MATH_SUB_S1_TAG = "M1"
MATH_SUB_S2_TAG = "M2"
MATH_SUB_S3_TAG = "M3"
MATH_SUB_NG_TAG = "NG"


# Related to QuantaType.ALGO:
# Note: ALGO_SUB_NG_TAG has the same string value ("NG") as MATH_SUB_NG_TAG above
ALGO_ADD_BA_TAG = "BA" # Addition - Base Add (Dn, D'n)
ALGO_ADD_MC_TAG = "MC" # Addition - Make Carry (Dn, D'n)
ALGO_ADD_US_TAG = "US" # Addition - Use Sum 9 (Dn, D'n)
ALGO_ADD_TC_TAG = "TC" # Addition - TriCase (Dn, D'n)
ALGO_SUB_BS_TAG = "BS" # Subtraction - Base Sub (Dn, D'n)
ALGO_SUB_BO_TAG = "BO" # Subtraction - Borrow One (Dn, D'n)
ALGO_SUB_SZ_TAG = "SZ" # Subtraction - Sum Zero (Dn, D'n)
ALGO_SUB_NG_TAG = "NG" # Subtraction - Answer is negative (that is A - B where A < B)
ALGO_MIX_OP_TAG = "OP" # Add/Sub - Attends to operation token

# Part 7B: Set Up: Create sample questions by Complexity Quanta

Sets of sample questions by complexity quanta

In [None]:
# Insert a number into the question
def insert_question_number(the_question, index, first_digit_index, the_digits, n):
  """Write the last `the_digits` decimal digits of n into row `index` of the_question.

  The digits fill columns first_digit_index .. first_digit_index+the_digits-1,
  most significant digit first, zero-padded on the left. Updates the_question in place.
  """
  remaining = n

  # Walk from the least significant column back to the most significant one
  for column in range(first_digit_index + the_digits - 1, first_digit_index - 1, -1):
    the_question[index, column] = remaining % 10
    remaining = remaining // 10


# Create a single question
def make_a_question(the_question, index, q1, q2, operator ):
  """Fill row `index` of the_question with the tokens for 'q1 <operator> q2 = answer'.

  The answer is q1-q2 for MINUS_INDEX, q1*q2 for MULT_INDEX and q1+q2 for any
  other operator. It is written as a sign token followed by n_digits+1 digits.
  """
  n = cfg.n_digits

  # Question: first operand, operator, second operand, equals
  insert_question_number(the_question, index, 0, n, q1)
  the_question[index, n] = operator
  insert_question_number( the_question, index, n+1, n, q2)
  the_question[index, 2*n+1] = EQUALS_INDEX

  if operator == MINUS_INDEX:
    answer = q1-q2
  elif operator == MULT_INDEX:
    answer = q1*q2
  else:
    answer = q1+q2

  # Answer: sign token, then the magnitude
  the_question[index, 2*n+2] = MINUS_INDEX if answer < 0 else PLUS_INDEX
  insert_question_number(the_question, index, 2*n + 3, n+1, abs(answer))


# Create a batch of questions from a 2D matrix of ints
def make_questions(operator, q_matrix):
  """Turn [a, b] operand pairs into an int64 tensor of tokenised questions.

  Pairs where either operand has more than cfg.n_digits digits are skipped,
  so the result can have fewer rows than q_matrix.
  """
  limit = 10 ** cfg.n_digits
  questions = torch.zeros((len(q_matrix), cfg.n_ctx())).to(torch.int64)
  n_used = 0

  for pair in q_matrix:
    a, b = pair[0], pair[1]

    # Skip operands that do not fit in n_digits
    if a >= limit or b >= limit:
      continue

    make_a_question(questions, n_used, a, b, operator)
    n_used += 1

  # Drop the unused rows at the end
  return questions[:n_used]

In [None]:
# Manually create some questions that strongly test one quanta


# Make BaseAdd (S0) addition questions
def make_s0_questions():
  """Return (major tag, minor tag, questions) for the hand-picked S0 addition questions.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [0, 0],
    [1, 3],
    [12345, 33333],
    [33333, 12345],
    [45762, 33113],
    [888, 11111],
    [2362, 23123],
    [15, 81],
    [1000, 4441],
    [4440, 11111],
    [24033, 25133],
    [23533, 21133],
    [32500, 1],
    [31500, 1111],
    [5500, 12323],
    [4500, 2209],
    [33345, 66643], # =099988
    [66643, 33345], # =099988
    [10770, 44111],
    [60000, 31111],
    [10000, 21111],
    [107700, 441111],
    [600000, 311111],
    [100000, 211111],
    [1077000, 4411111],
    [6000000, 3111111],
    [1000000, 2111111],
    [10770000, 44111111],
    [60000000, 3111111],
    [10000000, 2111111],
  ]
  return QuantaType_MATH_ADD, MATH_ADD_S0_TAG, make_questions(PLUS_INDEX, pairs)

# Make UseCarry1 (S1) addition questions
def make_s1_questions():
  """Return (major tag, minor tag, questions) for the hand-picked S1 addition questions.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [15, 45],
    [27, 55],
    [35, 59],
    [150, 451],
    [270, 551],
    [350, 591],
    [1500, 4511],
    [2700, 5511],
    [3500, 5911],
    [40035, 41149],
    # [ 44000, 46000], D6 L1 H3 model cant handle this.
    [70000, 41111],
    [15000, 25111],
    [35000, 35111],
    [45000, 35111],
    [67000, 25111],
    [19000, 76111],
    [15020, 45091],
    [25002, 55019],
    [35002, 59019],
    [150211, 450911],
    [250021, 550191],
    [350021, 590191],
    [1502111, 4509111],
    [2500211, 5501911],
    [3500211, 5901911],
    [15021111, 45091111],
    [25002111, 55019111],
    [35002111, 59019111],
  ]
  return QuantaType_MATH_ADD, MATH_ADD_S1_TAG, make_questions(PLUS_INDEX, pairs)


# Make SimpleUseSum9 (S2) addition questions
def make_s2_questions():
  """Return (major tag, minor tag, questions) for the hand-picked S2 addition questions.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  Two of the 6-digit pairs appear twice in the list.
  """
  pairs = [
    [55, 45],
    [45, 55],
    [45, 59],
    [35, 69],
    [25, 79],
    [15, 85],
    [15, 88],
    [15518, 14511],
    [14518, 15511],
    [24533, 25933],
    [23533, 26933],
    [32511, 7911],
    [31511, 8511],
    [551, 451],
    [451, 551],
    [10881, 41127],
    [41127, 10881],
    [12386, 82623],
    [108811, 411271],
    [411271, 108811],
    [123861, 826231],
    [994890, 80105],
    [970590, 96026],
    [994890, 80105],
    [970590, 96026],
    [1088111, 4112711],
    [4112711, 1088111],
    [1238611, 8262311],
    [10881111, 41127111],
    [41127111, 10881111],
    [12386111, 82623111],
  ]
  return QuantaType_MATH_ADD, MATH_ADD_S2_TAG, make_questions(PLUS_INDEX, pairs)

# These are two level UseSum9 cascades (S3)
def make_s3_questions():
  """Return (major tag, minor tag, questions) for the hand-picked S3 addition questions."""
  pairs = [
    [555, 445],
    [3340, 6661],
    [8880, 1121],
    [1120, 8881],
    [123, 877],
    [877, 123],
    [321, 679],
    [679, 321],
    [1283, 78785],
  ]
  return QuantaType_MATH_ADD, MATH_ADD_S3_TAG, make_questions(PLUS_INDEX, pairs)


# These are three level UseSum9 cascades (S4)
def make_s4_questions():
  """Return (major tag, minor tag, questions) for the hand-picked S4 addition questions."""
  pairs = [
    [5555, 4445],
    [55550, 44451],
    [3334, 6666],
    [33340, 66661],
    [8888, 1112],
    [88880, 11121],
    [1234, 8766],
    [4321, 5679],
  ]
  return QuantaType_MATH_ADD, MATH_ADD_S4_TAG, make_questions(PLUS_INDEX, pairs)


# These are four level UseSum9 cascades (S5)
def make_s5_questions():
  """Return (major tag, minor tag, questions) for the hand-picked S5 addition questions."""
  pairs = [
    [44445, 55555],
    [33334, 66666],
    [88888, 11112],
    [12345, 87655],
    [54321, 45679],
    [45545, 54455],
    [36634, 63366],
    [81818, 18182],
    [87345, 12655],
    [55379, 44621],
  ]
  return QuantaType_MATH_ADD, MATH_ADD_S5_TAG, make_questions(PLUS_INDEX, pairs)


# Make questions focus mainly on 1 digit at a time
# (assuming that the 0 + 0 digit additions/subtractions are trivial bigrams)
def make_sn_questions():
  """Return (major tag, "S*", questions) for addition questions of mixed complexity.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [1, 0],
    [4, 3],
    [5, 5],
    [8, 1],
    [40, 31],
    [44, 46],
    [400, 311],
    [440, 461],
    [800, 111],
    [270, 471],
    [600, 311],
    [4000, 3111],
    [4400, 4611],
    [6000, 3111],
    [7000, 4111],
    [40000, 31111],
    [44000, 45111],
    [60000, 31111],
    [70000, 41111],
    [10000, 21111],
    [15000, 25111],
    [35000, 35111],
    [45000, 85111],
    [67000, 85111],
    [99000, 76111],
    [76000, 99111],
    [670000, 851111],
    [990000, 761111],
    [760000, 991111],
    [6700000, 8511111],
    [9900000, 7611111],
    [7600000, 9911111],
    [67000000, 85111111],
    [99000000, 76111111],
    [76000000, 99111111],
  ]
  return QuantaType_MATH_ADD, "S*", make_questions(PLUS_INDEX, pairs)


# Make M0 questions - when no column generates a Borrow One. Answer is always positive (or zero).
def make_m0_questions():
  """Return (major tag, minor tag, questions) for the hand-picked M0 subtraction questions.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [0, 0],
    [6, 6],
    [61, 60],
    [611, 600],
    [6111, 6000],
    [61111, 60000],
    [611111, 600000],
    [6111111, 6000000],
    [61111111, 60000000],
    [66666, 12345],
    [33333, 12321],
    [45762, 34551],
    [78901, 78901], # = +000000
    [23123, 23123], # = +000000
    [86, 15],
    [4440, 1230],
    [88746, 86544],
    [27833, 25133],
    [23533, 21133],
    [32501, 1],
    [31511, 1111],
    [55555, 12323],
    [45454, 22022],
    [66643, 3341],
    [66643, 30042],
    [99999, 44012],
    [61111, 30000],
    [99111, 99111], # = +000000
    [999991, 440120],
    [611111, 300000],
    [991111, 991111], # = +0000000
    [9999911, 4401200],
    [6111111, 3000000],
    [9911111, 9911111], # = +00000000
    [99999111, 44012000],
    [61111111, 30000000],
    [99111111, 99111111], # = +000000000
  ]
  return QuantaType_MATH_SUB, MATH_SUB_S0_TAG, make_questions(MINUS_INDEX, pairs)

# Make subtraction M1 questions with exactly one "borrow 1" instance. Answer is always positive.
def make_m1_questions():
  """Return (major tag, minor tag, questions) for the hand-picked M1 subtraction questions.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [22222, 11113],
    [22222, 11131],
    [22222, 11311],
    [22222, 13111],
    [14, 8],
    [141, 80],
    [1411, 800],
    [14111, 8000],
    [55514, 11108],
    [55141, 11080],
    [51411, 10800],
    [140111, 8000],
    [88888, 22229],
    [77777, 22292],
    [66666, 22922],
    [888888, 222292],
    [777777, 222922],
    [666666, 229222],
    [8888888, 2222922],
    [7777777, 2229222],
    [6666666, 2292222],
    [88888888, 22229222],
    [77777777, 22292222],
    [66666666, 22922222],
  ]
  return QuantaType_MATH_SUB, MATH_SUB_S1_TAG, make_questions(MINUS_INDEX, pairs)

# Make subtraction M2 questions containing BO and DZ. Answer is always positive (or zero).
def make_m2_questions():
  """Return (major tag, minor tag, questions) for the hand-picked M2 subtraction questions.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [22212, 11113],
    [22122, 11131],
    [21222, 11311],
    [904, 8],
    [9041, 80],
    [90411, 800],
    [55514, 11118],
    [55141, 11180],
    [51411, 11800],
    [88888, 22289],
    [77777, 22792],
    [66666, 26922],
    [888888, 222892],
    [777777, 227922],
    [666666, 269222],
    [8888888, 2228922],
    [7777777, 2279222],
    [6666666, 2692222],
    [88888888, 22289222],
    [77777777, 22792222],
    [66666666, 26922222],
  ]
  return QuantaType_MATH_SUB, MATH_SUB_S2_TAG, make_questions(MINUS_INDEX, pairs)


# Make subtraction M3,M4,... questions containing BO and multiple DZs. Answer is always positive (or zero).
def make_m3_questions():
    """Return (major tag, minor tag, questions) for the hand-picked M3 subtraction questions.

    Every pair has one column that borrows, followed by two or more columns
    whose digits are equal. Pairs with an operand too long for cfg.n_digits
    are dropped by make_questions.
    """
    return QuantaType_MATH_SUB, MATH_SUB_S3_TAG, make_questions( MINUS_INDEX,
      [[22112, 11113],
      [ 21122, 11131],
      [ 99004,     8],
      [ 90041,    80],
      [ 55114, 11118],
      [ 51140, 11180],
      [ 88888, 22889],
      [ 87777, 27792],
      [ 888888, 228892],
      [ 877777, 277922],
      [ 8888888, 2288922],
      [ 7777777, 2779222],
      [ 88888888, 22889222],
      # Was 28892222, which borrows in three columns and has no equal-digit column
      [ 77777777, 27792222]])


# Make subtraction questions with negative answers
def make_ng_questions():
  """Return (major tag, minor tag, questions) for hand-picked subtraction questions where a < b.

  Pairs with an operand too long for cfg.n_digits are dropped by make_questions.
  """
  pairs = [
    [0, 1],
    [7, 9],
    [12345, 33333],
    [888, 11111],
    [2362, 23123],
    [15, 81],
    [1111, 4440],
    [24033, 25133],
    [23533, 88133],
    [5511, 12323],
    [4511, 22209],
    [88888, 88889],
    [55555, 55556],
    [88881, 88891],
    [55551, 55561],
    [88811, 88911],
    [55511, 55611],
    [88746, 89544],
    [27833, 29133],
    [23533, 23833],
    [31511, 41111],
    [55555, 62323],
    [45454, 72022],
    [66643, 73341],
    [66643, 90042],
    [99998, 99999],
    [8, 12],
    [41, 232],
    [44, 523],
    [234, 334],
    [7777, 8434],
    [88888, 92222],
    [77777, 84340],
    [888888, 922220],
    [777777, 843400],
    [8888888, 9222200],
    [7777777, 8434000],
    [88888888, 92222000],
    [77777777, 84340000],
  ]
  return QuantaType_MATH_SUB, MATH_SUB_NG_TAG, make_questions(MINUS_INDEX, pairs)


# Two random batches used as the first and last parts of the varied question set.
v0 = next(ds) # Could be Add, Sub or Mult
v1 = next(ds) # Could be Add, Sub or Mult
# When the configuration mixes addition and subtraction, use one batch of each
# instead. Note these two batches are not moved to the GPU here.
if cfg.perc_add() > 0 and cfg.perc_sub > 0 :
  v0 = data_generator_core( PLUS_INDEX )
  v1 = data_generator_core( MINUS_INDEX )


# Returns ~1000 random and up to ~150 manually-chosen questions
def make_varied_questions():
  """Stack the random batches v0 and v1 around the manually-chosen questions, on the GPU.

  - 100% multiplication: only v0 and v1.
  - 100% addition: v0, S0 to S4, v1.
  - 100% subtraction: v0, M0 to M3, NG, v1.
  - Otherwise: v0, the S0-S4 sets interleaved with M0-M3 and NG, then S5, S*, v1.
  """
  if cfg.perc_mult == 100 :
    return torch.vstack((v0.cuda(), v1.cuda()))

  # Each maker returns (major tag, minor tag, questions); only the questions are needed
  add_makers = (make_s0_questions, make_s1_questions, make_s2_questions, make_s3_questions,
                make_s4_questions, make_s5_questions, make_sn_questions)
  sub_makers = (make_m0_questions, make_m1_questions, make_m2_questions, make_m3_questions,
                make_ng_questions)
  add_sets = [maker()[2] for maker in add_makers]
  sub_sets = [maker()[2] for maker in sub_makers]

  if cfg.perc_add() == 100 :
    middle = add_sets[:5]
  elif cfg.perc_sub == 100 :
    middle = sub_sets
  else:
    # s0, m0, s1, m1, .., s4, ng and then the remaining addition sets
    middle = [questions for pair in zip(add_sets[:5], sub_sets) for questions in pair]
    middle += add_sets[5:]

  return torch.vstack(tuple(questions.cuda() for questions in [v0, *middle, v1]))

# Part 7C: Set Up: Evaluate mathematical Complexity quanta e.g. Add.S2, Sub.M1

Functions that evaluate the "mathematical complexity" of a question

In [None]:
# Analyse and return the complexity quanta for the Addition (S0 to S4+) or Subtraction (M0 to NG) questions
def get_question_complexity(question):
  """Classify one question as a (major_tag, minor_tag) complexity pair.

  The operator token is read from index cfg.n_digits. Addition questions are
  graded by how far a carry cascades through sum-is-9 columns; subtraction
  questions by how far a borrow cascades through difference-is-0 columns
  (or as NG when the answer is negative). Any other operator prints a
  diagnostic and returns (QuantaType_MATH_VARIED, "OP?").
  """
  n_digits = cfg.n_digits
  tokens = utils.to_numpy(question)
  operands = tokens[:2*n_digits+2]
  op = tokens[n_digits]

  # Build two per-column flag vectors. Column dn of the question is stored at
  # index n_digits-1-dn, i.e. the flag order is reversed relative to token order.
  def flag_columns(first_test, second_test):
    first = torch.zeros(n_digits).to(torch.int64)
    second = torch.zeros(n_digits).to(torch.int64)
    for dn in range(n_digits):
      top = operands[dn]
      bottom = operands[dn + n_digits + 1]
      if first_test(top, bottom):
        first[n_digits-1-dn] = 1
      if second_test(top, bottom):
        second[n_digits-1-dn] = 1
    return first, second

  # True when some trigger digit is immediately followed by `length` consecutive carrier digits
  def has_cascade(trigger, carrier, length):
    for start in range(n_digits - length):
      if trigger[start] == 1 and all(carrier[start + step] == 1 for step in range(1, length + 1)):
        return True
    return False

  if op == PLUS_INDEX:
    # Locate the MC (column sum > 9) and MS (column sum == 9) digits (if any)
    mc, ms = flag_columns(lambda a, b: a + b > 9, lambda a, b: a + b == 9)

    # No column makes a carry
    if torch.sum(mc) == 0:
      return QuantaType_MATH_ADD, MATH_ADD_S0_TAG

    # There are carries, but no sum-is-9 column for them to cascade through
    if torch.sum(ms) == 0:
      return QuantaType_MATH_ADD, MATH_ADD_S1_TAG

    # Longest cascade is tested first: MC cascading through 4, 3, 2 then 1 MS digits
    for length, tag in ((4, MATH_ADD_S5_TAG), (3, MATH_ADD_S4_TAG), (2, MATH_ADD_S3_TAG), (1, MATH_ADD_S2_TAG)):
      if has_cascade(mc, ms, length):
        return QuantaType_MATH_ADD, tag

    # Carries and sum-is-9 columns exist but are never adjacent
    return QuantaType_MATH_ADD, MATH_ADD_S1_TAG

  if op == MINUS_INDEX:
    a = tokens_to_unsigned_int( question, 0, n_digits )
    b = tokens_to_unsigned_int( question, n_digits + 1, n_digits )
    if a - b < 0:
      return QuantaType_MATH_SUB, MATH_SUB_NG_TAG

    # Locate the BO (column difference < 0) and MZ (column difference == 0) digits (if any)
    bo, mz = flag_columns(lambda a, b: a - b < 0, lambda a, b: a - b == 0)

    # Evaluate BaseSub questions - when no column generates a Borrow One
    if torch.sum(bo) == 0:
      return QuantaType_MATH_SUB, MATH_SUB_S0_TAG

    # Longest cascade is tested first: BO cascading through 3, 2 then 1 MZ digits.
    # NOTE(review): the 3-digit case uses the literal "M4+" rather than a named tag constant - confirm
    for length, tag in ((3, "M4+"), (2, MATH_SUB_S3_TAG), (1, MATH_SUB_S2_TAG)):
      if has_cascade(bo, mz, length):
        return QuantaType_MATH_SUB, tag

    return QuantaType_MATH_SUB, MATH_SUB_S1_TAG

  # Should never get here
  print("get_question_complexity OP? exception", question)
  return QuantaType_MATH_VARIED, "OP?"

In [None]:
def unit_test_quanta_core(make_questions):
  """Check one hand-made question set against get_question_complexity.

  make_questions() must return (major_tag, complexity, questions). A header
  line is printed, then one "Complexity mismatch:" line for every question
  whose evaluated tags differ from the expected pair.
  """
  expected_major, expected_complexity, questions = make_questions()
  num_questions = questions.shape[0]
  print( expected_major + ":" + expected_complexity, "#Questions=", num_questions)

  for question in questions:
    major_tag, complexity = get_question_complexity(question)
    if not (major_tag == expected_major and complexity == expected_complexity):
      print( "Complexity mismatch:", expected_major, major_tag, expected_complexity, complexity, question)


# Test that our "sample questions by quanta" and "question quanta evaluation" are aligned.
# If this fails, either the sample questions or the evaluation is buggy.
def unit_test_quanta():
  """Run unit_test_quanta_core over every addition, then every subtraction, question maker."""
  addition_makers = (
    make_s0_questions, make_s1_questions, make_s2_questions,
    make_s3_questions, make_s4_questions, make_s5_questions)
  subtraction_makers = (
    make_m0_questions, make_m1_questions, make_m2_questions,
    make_m3_questions, make_ng_questions)

  for make_questions in addition_makers + subtraction_makers:
    unit_test_quanta_core(make_questions)


unit_test_quanta()

# Part 8A: Set Up: Question prediction function

Create sets of sample questions exercising different quanta

In [None]:
# Build a test batch of random and manually-chosen questions
varied_questions = make_varied_questions()


# Run the sample batch through the model once and keep the activation cache
cfg.main_model.reset_hooks()
cfg.main_model.set_use_attn_result(True)
sample_logits, sample_cache = cfg.main_model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_prob_tokens = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean) # Loss < 0.04 is good


# Hook point names are prepared for layers 0..3.
# attn.hook_z is the "attention head output" hook point: 'blocks.0.attn.hook_z' etc
l_attn_hook_z_name = [utils.get_act_name('z', layer, 'a') for layer in range(4)]
sample_attn_z_0 = sample_cache[l_attn_hook_z_name[0]]
print("Sample", l_attn_hook_z_name[0], sample_attn_z_0.shape) # e.g. [350, 22, 3, 170] = num_questions, cfg.n_ctx, n_heads, d_head
# Average over the question dimension; this mean is what the attention-head ablation hook copies in
mean_attn_z = sample_attn_z_0.mean(dim=0, keepdim=True)
print("Mean", l_attn_hook_z_name[0], mean_attn_z.shape) # e.g. [1, 22, 3, 170] = 1, cfg.n_ctx, n_heads, d_head


# hook_resid_pre is the "pre residual memory update" hook point name (at a specified layer)
l_hook_resid_pre_name = [f'blocks.{layer}.hook_resid_pre' for layer in range(4)]


# hook_resid_post is the "post residual memory update" hook point name (at a specified layer)
l_hook_resid_post_name = [f'blocks.{layer}.hook_resid_post' for layer in range(4)]
sample_resid_post_0 = sample_cache[l_hook_resid_post_name[0]]
print("Sample", l_hook_resid_post_name[0], sample_resid_post_0.shape) # e.g. [350, 22, 510] = num_questions, cfg.n_ctx, d_model
mean_resid_post = sample_resid_post_0.mean(dim=0, keepdim=True)
print("Mean", l_hook_resid_post_name[0], mean_resid_post.shape) # e.g. [1, 22, 510] = 1, cfg.n_ctx, d_model


# mlp.hook_post is the "MLP layer" hook point: 'blocks.0.mlp.hook_post' etc
l_mlp_hook_post_name = [utils.get_act_name('post', layer) for layer in range(4)]
sample_mlp_hook_post_0 = sample_cache[l_mlp_hook_post_name[0]]
print("Sample", l_mlp_hook_post_name[0], sample_mlp_hook_post_0.shape) # e.g. [350, 22, 2040] = num_questions, cfg.n_ctx, cfg.d_mlp
mean_mlp_hook_post = sample_mlp_hook_post_0.mean(dim=0, keepdim=True)
print("Mean", l_mlp_hook_post_name[0], mean_mlp_hook_post.shape) # e.g. [1, 22, 2040] = 1, cfg.n_ctx, cfg.d_mlp

In [None]:
verbose = True


class T_Config(NodeLocation):
  """Accumulates question-test results and doubles as the "current node" location.

  The single instance `tcfg` is used both to tally prediction results
  (per batch and overall) and, through the attributes set on it during the
  ablation loops (position, layer, num, is_head), to say which node is
  currently being ablated.
  """

  # Per-batch counters, cleared by clear_questions_results()
  num_questions : int
  correct_answers : int
  total_mean_loss : float
  correct_list : list # One entry per question of the latest batch: was the answer correct?

  # Totals over every batch reported since the last reset()
  sum_num_questions : int
  sum_correct_answers : int

  output : PrettyTable # Results table; one row per reported batch

  threshold : float # A question is only inspected as a failure when its mean loss exceeds this


  def __init__(self):
    # NOTE(review): the NodeLocation arguments are presumably (position, layer, is_head, num) - confirm
    super().__init__(0, 0, True, 0)
    self.reset()


  # Clear all counters and start a fresh results table.
  # All mutable state is created here, per instance, rather than as shared class attributes.
  def reset(self):
    self.num_questions = 0
    self.correct_answers = 0
    self.total_mean_loss = 0.0
    self.sum_num_questions = 0
    self.correct_list = []
    self.sum_correct_answers = 0
    self.output = PrettyTable()
    self.output.field_names = ["Complexity", "#Questions", "#Correct", "%Correct", "Mean loss"]
    self.threshold = 0.01


  # Clear the question summary results
  def clear_questions_results(self, title):

    self.num_questions = 0
    self.correct_answers = 0
    self.total_mean_loss = 0.0
    self.correct_list = []

    if verbose:
      print(title)


  # Add the question summary results to the table and the overall totals.
  # (Nothing is printed until print_overall_results is called.)
  def print_questions_results(self, prefix):
    # An empty batch has no meaningful percentage or mean loss; report 0 rather than dividing by zero
    if self.num_questions > 0:
      perc_correct = 100*self.correct_answers/self.num_questions
      mean_loss = self.total_mean_loss/self.num_questions
    else:
      perc_correct = 0.0
      mean_loss = 0.0

    self.output.add_row([prefix, self.num_questions, str(self.correct_answers), perc_correct, mean_loss])
    self.sum_num_questions += self.num_questions
    self.sum_correct_answers += self.correct_answers


  # Print the overall summary results
  def print_overall_results(self):
    self.output.add_row(["OVERALL", self.sum_num_questions, self.sum_correct_answers, "", ""])
    print(self.output.get_formatted_string(out_format=cfg.table_out_format))


  # Evidence (not proof) the model is accurate
  def might_be_fully_accurate(self):
    return self.sum_num_questions == self.sum_correct_answers


tcfg = T_Config()

In [None]:
# Ask model to predict answer for each question & collect results
def do_questions(questions, show_failures = False):
  """Predict every question with the unhooked model and tally results in tcfg.

  Updates tcfg.num_questions, tcfg.correct_list, tcfg.total_mean_loss and
  tcfg.correct_answers. A question line is printed when the global `verbose`
  is set, or when show_failures is set and the answer was wrong.
  """

  tcfg.num_questions = questions.shape[0]
  tcfg.correct_list = [True] * tcfg.num_questions

  # Run with no hook
  cuda_questions = questions.cuda()
  all_logits = cfg.main_model(cuda_questions)
  all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, cuda_questions)

  for question_num in range(tcfg.num_questions):
    q = questions[question_num]

    losses = loss_fn(all_losses_raw[question_num])
    mean_loss = utils.to_numpy(losses.mean())
    tcfg.total_mean_loss += mean_loss

    model_answer_str = tokens_to_string(all_max_prob_tokens[question_num])

    # A badly wrong prediction may not be a parsable number (e.g. a sign in a digit slot).
    # Count that as an incorrect answer instead of aborting the whole run.
    try:
      model_answer_num = int(model_answer_str)
    except ValueError:
      model_answer_num = None

    a = tokens_to_answer(q)

    correct = (model_answer_num is not None) and (model_answer_num == a)
    tcfg.correct_list[question_num] = correct

    if correct :
      tcfg.correct_answers += 1

    if verbose or (show_failures and not correct):
      print(tokens_to_string(q), "ModelAnswer:", model_answer_str, "Loss:", mean_loss )


In [None]:
def print_question_results_core( title, questions, show_failures = False):
  """Run one batch of questions and record its summary row under `title`.

  The per-batch counters in tcfg are cleared first, the questions are
  predicted, and the outcome is then added to tcfg's results table.
  """
  tcfg.clear_questions_results(title)
  do_questions(questions, show_failures=show_failures)
  tcfg.print_questions_results(title)


def print_question_results( make_questions, show_failures = False):
  """Build a question set with make_questions() and report on it.

  make_questions() must return (major_tag, quanta_case, questions); the
  table row is titled "<major_tag>.<quanta_case>".
  """
  major_tag, quanta_case, questions = make_questions()
  print_question_results_core( major_tag + "." + quanta_case, questions, show_failures)

# Part 8B: Results: Prediction success by Complexity quanta

This section runs hand-curated test cases to indicate which complexity quanta the model can (probably) handle.

Not proof - our test cases might be inadequate.

In [None]:
verbose = False

# Report prediction accuracy for each hand-made addition question set, then the overall totals
if cfg.perc_add() > 0:
  tcfg.reset()
  for make_add_questions in (
      make_s0_questions, make_s1_questions, make_s2_questions, make_s3_questions,
      make_s4_questions, make_s5_questions, make_sn_questions):
    print_question_results(make_add_questions)
  tcfg.print_overall_results()

In [None]:
verbose = False

# Report prediction accuracy for each hand-made subtraction question set, then the overall totals
if cfg.perc_sub > 0:
  tcfg.reset()
  for make_sub_questions in (
      make_m0_questions, make_m1_questions, make_m2_questions,
      make_m3_questions, make_ng_questions):
    print_question_results(make_sub_questions)
  tcfg.print_overall_results()

In [None]:
# Varied questions includes 2 random batches of questions. Show any questions that we can't calculate correctly.
tcfg.reset()
print_question_results_core( QuantaType_MATH_VARIED, varied_questions, True)
tcfg.print_overall_results()

model_might_be_fully_accurate = tcfg.might_be_fully_accurate()
if not model_might_be_fully_accurate:
  # Drop the questions the model got wrong, otherwise they would turn up in every cell of the quanta maps
  org_size = varied_questions.shape[0]
  varied_questions = varied_questions[torch.tensor(tcfg.correct_list)]
  new_size = varied_questions.shape[0]

  print()
  print("WARNING: Model is not fully accurate as it got", org_size - new_size, "questions wrong.")
  print("RESOLUTION: Understand these failures. Enrich the training data to provide more examples. Retrain the model.")
  print("INTERIM: Have reduced 'varied_questions' size from", org_size, "to", new_size, "so can continue.")
else:
  # Evidence, not proof: very rare edge cases (say 1 in ten million) may simply be absent from the test questions,
  # and even a training set enriched with every known edge case can miss ones nobody has thought of.
  print("Model got all test questions correct. This is a pre-requisite for the model to be fully accurate, but this is NOT proof it is fully accurate.")

# Part 12A: Set Up: Predict Questions and Evaluate Quanta

Get model to predict given question answers, with ablation hook(s), and categorise how many questions fail.

In [None]:
# Ask the model to predict the question answers (with the hooks either reading data, doing intervention ablations, or doing nothing)
def predict_questions_core(questions, the_hooks):
  """Run the questions through the model with the given forward hooks.

  Any previously registered hooks are reset first. Returns the pair
  (all_losses_raw, all_max_prob_tokens) produced by logits_to_tokens_loss.
  """
  model = cfg.main_model
  model.reset_hooks()
  model.set_use_attn_result(True)

  cuda_questions = questions.cuda()
  hooked_logits = model.run_with_hooks(cuda_questions, return_type="logits", fwd_hooks=the_hooks)
  all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(hooked_logits, cuda_questions)

  return all_losses_raw, all_max_prob_tokens

In [None]:
def c_predict_questions(questions, the_hooks):
  """Predict the questions under the given hooks and return how many answers were wrong.

  A question counts as a failure only when its mean loss exceeds
  tcfg.threshold AND its answer-impact string contains 'A'.
  """
  all_losses_raw, all_max_prob_tokens = predict_questions_core(questions, the_hooks)

  num_fails = 0
  for question_num in range(questions.shape[0]):
    the_loss_mean = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())

    # Low loss: the hook did not disturb this question enough to inspect it
    if not (the_loss_mean > tcfg.threshold):
      continue

    q = questions[question_num]
    answer_str = tokens_to_string(all_max_prob_tokens[question_num])

    # Only count the question if the model got the question wrong
    impact_str = get_question_answer_impact( q, answer_str )
    if 'A' not in impact_str:
      continue

    num_fails += 1
    if verbose :
      print(tokens_to_string(q), "Q: ModelAnswer:", answer_str, "Impact:", impact_str, "Loss:", the_loss_mean )

  return num_fails

In [None]:
verbose = False


def c_set_resid_post_hook(value, hook):
  """Mean-ablate the residual stream at token position tcfg.position.

  Overwrites `value` in place: for every question in the batch, the row at
  tcfg.position is replaced by the matching row of mean_resid_post.
  """
  #print( "In hook", l_hook_resid_post_name[tcfg.layer], tcfg.ablate, tcfg.position, value.shape) # Get [64, 22, 510] = cfg.batch_size, num_tokens, d_model

  ablated_position = tcfg.position
  value[:, ablated_position, :] = mean_resid_post[0, ablated_position, :].clone()


# One entry per token position: the number of failed questions, or "." when none failed
num_failures_list = []
num_questions = 0
if cfg.n_digits >= 5 :
  # The same ablation hook is attached to the resid_post hook point of every model layer
  c_fwd_hooks = [(hook_name, c_set_resid_post_hook) for hook_name in l_hook_resid_post_name[:4]][:cfg.n_layers]

  num_questions = varied_questions.shape[0]

  for tcfg.position in range(cfg.n_ctx()):
    num_fails = c_predict_questions(varied_questions, c_fwd_hooks)

    num_failures_list.append(num_fails if num_fails > 0 else ".")

    # A position whose ablation breaks at least one answer is recorded as useful
    if num_fails > 0:
      assert tcfg.position < cfg.n_ctx()
      useful_info.add_useful_position(tcfg.position)


# Part 9 : Results: Can the model do 1 million questions without error?

If the model passes this test, this is evidence (not proof) that the model is fully accurate. There may be very rare edge cases (say 1 in ten million) that did not appear in the test questions. Even if you believe you know all the edge cases, and have enriched the training data to contain them, you may not have thought of all edge cases, so this is not proof.

If the model fails this test:
- Add a few of the failures into the "test questions" in part 6C
- Understand the "use case(s)" driving these failures
- Alter the Training CoLab data_generator_core to enrich the training data with examples of these use case(s) and retrain the model.  

Takes ~25 mins to run (successfully) for ins_mix_d6_l3_h4_t40K_seed372001

In [None]:
def null_hook(value, hook):
  """Forward hook that deliberately does nothing (leaves `value` untouched)."""
  return None

In [None]:
def one_million_questions_core():
  """Predict about one million generated questions, stopping at the first failing batch.

  Side effects: sets the global `verbose` to True, sets cfg.analysis_seed and
  replaces the global data generator `ds`. Prints progress every 100 batches,
  a final summary, and a warning if any batch contained a failure.
  """
  global verbose
  global ds

  verbose = True

  cfg.analysis_seed = 345621 # Randomly chosen
  ds = data_generator() # Re-initialise the data generator

  the_successes = 0
  the_fails = 0

  # The hook does nothing, so the questions are predicted by the unmodified model
  the_hook = [(l_attn_hook_z_name[0], null_hook)]

  num_batches = 1000000//cfg.batch_size
  for epoch in range(num_batches):
    tokens = next(ds)

    the_fails = c_predict_questions(tokens, the_hook)

    if the_fails > 0:
      break

    the_successes = the_successes + cfg.batch_size

    if epoch % 100 == 0:
      print("Batch", epoch, "of", num_batches, "#Successes=", the_successes)

  print("successes", the_successes, "num_fails", the_fails)
  if the_fails > 0:
    # This message used to be a bare string expression, so it was never shown
    print("WARNING: Model is not fully accurate. It failed the 1M Q test")

In [None]:
def one_million_questions():
  """Run the one-million-question test if the model passed all the test questions.

  For a mixed addition+subtraction configuration the test is run twice, once
  with cfg.perc_sub forced to 100 and once with it forced to 0. The original
  cfg.perc_sub is restored afterwards so later cells that check
  cfg.perc_sub still see the real configuration.
  """
  print_config()
  print()

  if not model_might_be_fully_accurate:
    print("WARNING: Model is not fully accurate. It failed some test questions")
    return

  if cfg.perc_add() > 0 and cfg.perc_sub > 0:
    original_perc_sub = cfg.perc_sub
    try:
      print("Subtraction:")
      cfg.perc_sub = 100
      one_million_questions_core()
      print()
      print("Addition:")
      cfg.perc_sub = 0
      one_million_questions_core()
    finally:
      # Undo the temporary override even if a run is interrupted
      cfg.perc_sub = original_perc_sub

  else:
    # Predict 1M (sub, add or mult) questions
    one_million_questions_core()


# Takes ~25 minutes to run
# one_million_questions()

# Part 12B: Results: Ablate ALL Heads in EACH token position. What is the impact?

Here we ablate all heads in each token position (overriding the model memory aka residual stream) and see if loss increases. If loss increases the token position is used by the algorithm. Unused token positions can be excluded from further analysis. Use "C_" prefix

In [None]:
def save_plt_to_file( full_title ):
  """Save the current matplotlib figure, provided cfg.graph_file_suffix is set.

  The file name is file_prefix() + the title with spaces, commas, colons
  and hyphens removed + '.' + cfg.graph_file_suffix.
  """
  if not (cfg.graph_file_suffix > ""):
    return

  compact_title = full_title
  for unwanted in (" ", ",", ":", "-"):
    compact_title = compact_title.replace(unwanted, "")

  filename = file_prefix() + compact_title + '.' + cfg.graph_file_suffix
  plt.savefig(filename, bbox_inches='tight', pad_inches=0)

In [None]:
print_config()
print()
print("Number of failures when ALL Heads in EACH token position are ablated")
print("num_questions=", num_questions, "min_useful_position=", useful_info.min_useful_position(), "max_useful_position=", useful_info.max_useful_position() )
print()

# Column headers: a label column followed by one column per token position
columns = ["Posn"]
for i in range(cfg.n_ctx()):
  columns.append(position_name(i))

rows = ["Posn", "# fails"]
# Row 1: the meaning of each token position. Row 2: the failure count when that position is ablated.
meaning_row = ["Posn"] + useful_info.token_position_meanings
failure_row = ["# fails"] + num_failures_list
data = [meaning_row, failure_row]

# Draw the table on an otherwise empty (axis-less) figure
fig, ax = plt.subplots(figsize=(16,1))
ax.axis('tight')
ax.axis('off')

table = ax.table(cellText=data, colLabels=columns, loc='center', cellLoc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)  # Fixed font size, since automatic sizing is switched off above
table.scale(1, 1.5)  # Column widths unchanged, row heights 1.5x

save_plt_to_file("Failures When Position Ablated")

plt.show()


# Part 14: Setup: Analysis of quanta per node

Evaluate quanta at node (not position) resolution. Uses "u_" prefix.

In [None]:
def u_add_node_tag( the_location, major_tag, minor_tag ):
  """Validate the node location against the model dimensions, then tag it in useful_info.

  Raises AssertionError when position, layer or num lies outside the range
  allowed by cfg (num is limited by cfg.n_heads for a head and by
  cfg.mlp_slices() otherwise).
  """
  num_limit = cfg.n_heads if the_location.is_head else cfg.mlp_slices()

  assert 0 <= the_location.position < cfg.n_ctx()
  assert 0 <= the_location.layer < cfg.n_layers
  assert 0 <= the_location.num < num_limit

  useful_info.add_node_tag( the_location, major_tag, minor_tag )

In [None]:
# Convert "A1231231278321" to "12378" or "87321"
def sort_unique_digits(raw_input_string, do_reverse):
  """Return the distinct digit characters of the input as a sorted string.

  Non-digit characters are ignored. The digits are in ascending order, or
  descending order when do_reverse is true.
  """
  distinct_digits = {char for char in raw_input_string if char.isdigit()}
  return ''.join(sorted(distinct_digits, reverse=do_reverse))


# Unit test
assert sort_unique_digits("A1231231278321", False) == "12378"
assert sort_unique_digits("A1231231278321", True) == "87321"

In [None]:
def u_predict_questions(questions, the_hooks):
  """Predict the questions under the given hooks and tag the current node (tcfg) with what failed.

  A question counts as a failure when its mean loss exceeds tcfg.threshold
  and its answer-impact string contains 'A'. If anything failed, the node
  gets a FAIL tag (integer percentage of failed questions), an IMPACT tag
  and, where relevant, addition and subtraction complexity tags.
  Returns nothing.
  """
  all_losses_raw, all_max_prob_tokens = predict_questions_core(questions, the_hooks)

  num_fails = 0
  impact_parts = []
  add_parts = []
  sub_parts = []

  for question_num in range(questions.shape[0]):
    the_loss_mean = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())

    # Low loss: the hook did not disturb this question enough to inspect it
    if not (the_loss_mean > tcfg.threshold):
      continue

    q = questions[question_num]
    answer_str = tokens_to_string(all_max_prob_tokens[question_num])

    # Only count the question if the model got the question wrong
    impact_str = get_question_answer_impact( q, answer_str )
    if 'A' not in impact_str:
      continue

    num_fails += 1
    impact_parts.append(impact_str)

    major_tag, minor_tag = get_question_complexity(q)
    if major_tag == QuantaType_MATH_ADD:
      add_parts.append(minor_tag)
    elif major_tag == QuantaType_MATH_SUB:
      sub_parts.append(minor_tag)

    if verbose :
      print(tokens_to_string(q), "U: ModelAnswer:", answer_str, "Complexity:", major_tag, "Impact:", impact_str, "Loss:", the_loss_mean )

  if num_fails == 0:
    return

  # Add percentage failure quanta
  perc = int( 100.0 * num_fails / len(questions))
  u_add_node_tag( tcfg, QuantaType.FAIL, str(perc) )

  # Add summary of all answer digit impact quanta failures
  u_add_node_tag( tcfg, QuantaType.IMPACT, "A" + sort_unique_digits("".join(impact_parts), True) )

  # Add summary of all addition question complexity quanta failures
  add_complexity_fails = "".join(add_parts)
  if add_complexity_fails != "":
    u_add_node_tag( tcfg, QuantaType_MATH_ADD, "S" + sort_unique_digits(add_complexity_fails, False) )

  # Add summary of all subtraction question complexity quanta failures
  sub_complexity_fails = "".join(sub_parts)
  if sub_complexity_fails != "":
    sub_digits = sort_unique_digits(sub_complexity_fails, False)
    # No digits at all means every failing tag was digit-free, so the NG tag is recorded
    sub_tag = MATH_SUB_NG_TAG if sub_digits == "" else "M" + sub_digits
    u_add_node_tag( tcfg, QuantaType_MATH_SUB, sub_tag )

In [None]:
def u_mlp_hook_post(value, hook):
  """Mean-ablate the whole MLP activation at token position tcfg.position (in place)."""
  # print( "In u_mlp_hook_post", value.shape) # Get [1099, 22, 2040] = num_questions, cfg.n_ctx, cfg.d_mlp (# neurons)

  ablated_position = tcfg.position
  value[:, ablated_position, :] = mean_mlp_hook_post[:, ablated_position, :].clone()

  # Open question (PQR): ablating only slice tcfg.num, i.e. neurons
  # [tcfg.num * slice_size, tcfg.num * slice_size + slice_size) with
  # slice_size = cfg.d_mlp // cfg.mlp_slices(), never made the MLP fail. Why??
  # Until that is understood, the whole layer is ablated regardless of tcfg.num.


# Ablating the MLP in each layer in each position and seeing if the loss increases shows which layer+MLP are used by the algorithm.
def u_mlp_perform_all(questions):
  """Ablate the MLP at every useful position, layer and slice, tagging the nodes that cause failures.

  tcfg is updated to describe the node under test before each prediction run.
  """
  tcfg.is_head = False
  for position in useful_info.positions:
    tcfg.position = position
    for layer in range(cfg.n_layers):
      tcfg.layer = layer
      for slice_num in range(cfg.mlp_slices()):
        tcfg.num = slice_num
        the_hook = [(l_mlp_hook_post_name[tcfg.layer], u_mlp_hook_post)]
        u_predict_questions(questions, the_hook)

In [None]:
# NOTE(review): this cell redefines the function of the same name from an earlier cell; the later definition wins.
def u_mlp_hook_post(value, hook):
  """Mean-ablate the whole MLP activation at token position tcfg.position (in place)."""
  # print( "In u_mlp_hook_post", value.shape) # Get [1099, 22, 2040] = num_questions, cfg.n_ctx, cfg.d_mlp (# neurons)

  ablated_position = tcfg.position
  value[:, ablated_position, :] = mean_mlp_hook_post[:, ablated_position, :].clone()

  # Open question (PQR): ablating only slice tcfg.num, i.e. neurons
  # [tcfg.num * slice_size, tcfg.num * slice_size + slice_size) with
  # slice_size = cfg.d_mlp // cfg.mlp_slices(), never made the MLP fail. Why??
  # Until that is understood, the whole layer is ablated regardless of tcfg.num.


# Ablating the MLP in each layer in each position and seeing if the loss increases shows which layer+MLP are used by the algorithm.
# NOTE(review): this cell redefines the function of the same name from an earlier cell; the later definition wins.
def u_mlp_perform_all(questions):
  """Ablate the MLP at every useful position, layer and slice, tagging the nodes that cause failures.

  tcfg is updated to describe the node under test before each prediction run.
  """
  tcfg.is_head = False
  for position in useful_info.positions:
    tcfg.position = position
    for layer in range(cfg.n_layers):
      tcfg.layer = layer
      for slice_num in range(cfg.mlp_slices()):
        tcfg.num = slice_num
        the_hook = [(l_mlp_hook_post_name[tcfg.layer], u_mlp_hook_post)]
        u_predict_questions(questions, the_hook)

In [None]:
# NOTE(review): this cell redefines the function of the same name from an earlier cell; the later definition wins.
def u_mlp_hook_post(value, hook):
  """Mean-ablate the whole MLP activation at token position tcfg.position (in place)."""
  # print( "In u_mlp_hook_post", value.shape) # Get [1099, 22, 2040] = num_questions, cfg.n_ctx, cfg.d_mlp (# neurons)

  ablated_position = tcfg.position
  value[:, ablated_position, :] = mean_mlp_hook_post[:, ablated_position, :].clone()

  # Open question (PQR): ablating only slice tcfg.num, i.e. neurons
  # [tcfg.num * slice_size, tcfg.num * slice_size + slice_size) with
  # slice_size = cfg.d_mlp // cfg.mlp_slices(), never made the MLP fail. Why??
  # Until that is understood, the whole layer is ablated regardless of tcfg.num.


# Ablating the MLP in each layer in each position and seeing if the loss increases shows which layer+MLP are used by the algorithm.
# NOTE(review): this cell redefines the function of the same name from an earlier cell; the later definition wins.
def u_mlp_perform_all(questions):
  """Ablate the MLP at every useful position, layer and slice, tagging the nodes that cause failures.

  tcfg is updated to describe the node under test before each prediction run.
  """
  tcfg.is_head = False
  for position in useful_info.positions:
    tcfg.position = position
    for layer in range(cfg.n_layers):
      tcfg.layer = layer
      for slice_num in range(cfg.mlp_slices()):
        tcfg.num = slice_num
        the_hook = [(l_mlp_hook_post_name[tcfg.layer], u_mlp_hook_post)]
        u_predict_questions(questions, the_hook)

In [None]:
def u_head_attn_hook_z(value, hook):
  """Mean-ablate one attention head (tcfg.num) at one token position (tcfg.position), in place."""
  # print( "In u_head_attn_hook_z", value.shape) # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head

  ablated_position = tcfg.position
  ablated_head = tcfg.num

  # Copy the sample-mean head output into every question of the batch
  value[:, ablated_position, ablated_head, :] = mean_attn_z[:, ablated_position, ablated_head, :].clone()


# Ablating each head in each layer in each position and seeing if the loss increases shows which position+layer+head are used by the algorithm.
def u_head_perform_all(questions):
  """Ablate every attention head at every useful position and layer, tagging the nodes that cause failures.

  tcfg is updated to describe the node under test before each prediction run.
  """
  tcfg.is_head = True
  for position in useful_info.positions:
    tcfg.position = position
    for layer in range(cfg.n_layers):
      tcfg.layer = layer
      for head_num in range(cfg.n_heads):
        tcfg.num = head_num
        the_hook = [(l_attn_hook_z_name[tcfg.layer], u_head_attn_hook_z)]
        u_predict_questions(questions, the_hook)

In [None]:
def h_null_attn_z_hook(value, hook):
  """Attention hook_z hook that does nothing (the activation is left unchanged)."""
  #print("In h_null_attn_z_hook", value.shape)  # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head
  return None


def u_calculate_attention_tags(questions):
  """Tag every useful attention head with the token positions it attends to most.

  Existing ATTENTION tags are cleared first. The attention pattern of each
  layer is averaged over the questions; for each head node the 4 most
  attended positions are taken, and those receiving at least 1% of the
  attention are recorded as "P<position>=<percent>" tags.
  """
  useful_info.reset_node_tags(QuantaType.ATTENTION)

  _, cache = cfg.main_model.run_with_cache(questions)

  # One averaged pattern per layer: (batch, n_heads, n_ctx, n_ctx) -> (n_heads, n_ctx, n_ctx)
  mean_patterns = [cache["pattern", layer, "attn"].mean(dim=0) for layer in range(cfg.n_layers)]

  for node in useful_info.nodes:
    if not node.is_head:
      continue

    # Attention paid by this head, from this node's position, to every position
    weights = mean_patterns[node.layer][node.num, node.position, :]

    top_tokens = torch.topk(weights, 4)
    attention_percentage = top_tokens.values / weights.sum() * 100

    # Add up to 4 tags with percs per head
    for perc, token_idx in zip(attention_percentage, top_tokens.indices):
      if perc >= 1.0:
        u_add_node_tag( node, QuantaType.ATTENTION, f"P{token_idx}={perc:.0f}" )

In [None]:
# Silence the per-question prints made inside the prediction helpers
verbose = False
# Start from an empty node list; the ablation runs below add tags for the nodes that matter
useful_info.nodes = []
# Mean-ablate the MLP at each useful position/layer and tag the nodes whose ablation causes failures
u_mlp_perform_all(varied_questions)
# Same for each attention head at each useful position/layer
u_head_perform_all(varied_questions)
# Attention tags are computed last because they are only added to head nodes already in useful_info.nodes
u_calculate_attention_tags(varied_questions)
useful_info.sort_nodes()

 # Part 15A: Set up: Show and save Quanta map

 Using the UsefulNodes and filtering their tags, show a 2D map of the nodes and the tag minor versions.

In [None]:
def show_quanta_map( title, custom_cmap, shades, major_tag, minor_tag, get_node_details, base_fontsize = 10, max_width = 10):
  """Print the configuration, then draw (and optionally save) one quanta map.

  The map itself is built by calc_quanta_map. When cfg.graph_file_suffix is
  set the figure is saved under `title` and no title is drawn on it;
  otherwise a title including the node count is drawn on the plot.
  """

  print_config()
  print()

  ax1, quanta_results = calc_quanta_map(custom_cmap, shades, major_tag, minor_tag, get_node_details, base_fontsize, max_width)

  saving_to_file = cfg.graph_file_suffix > ""
  if not saving_to_file:
    node_count_text = ' ({} nodes)'.format(len(quanta_results))
    ax1.set_title(file_prefix() + ' ' + title + node_count_text)
  else:
    print("Saving quanta map:", title)
    save_plt_to_file(title)

  # Show plot
  plt.show()

# Part 16A: Results: Show failure percentage quanta map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

A cell containing "< 1" may add some risk to the accuracy of the overall analysis process. Check to see if this represents a new use case. Improve the test data set to contain more instances of this (new or existing) use case.

In [None]:
# Map the QuantaType.FAIL tags (the failure percentages recorded while ablating each node): 10 shades, base font size 9
show_quanta_map( "Failure Frequency Per Node", plt.cm.winter, 10, QuantaType.FAIL, "", get_quanta_fail_perc, 9)

# Part 16B - Show answer impact quanta map

Show the purpose of each useful cell by impact on the answer digits A0 to A5.


In [None]:
# Map the QuantaType.IMPACT tags (which answer digits failed when the node was ablated): cfg.n_digits+2 shades, base font size 9, max width 6
show_quanta_map( "Answer Impact Per Node", plt.cm.winter, cfg.n_digits+2, QuantaType.IMPACT, "", get_quanta_impact, 9, 6)

# Part 16C: Result: Show attention quanta map

Show attention quanta of useful heads

In [None]:
# Only maps attention heads, not MLP layers (ATTENTION tags are only ever added to head nodes).
# Uses 10 shades, base font size 10 and max width 6.
show_quanta_map( "Attention Per Head", plt.cm.winter, 10, QuantaType.ATTENTION, "", get_quanta_attention, 10, 6)

# Part 16D - Show question complexity (S*) quanta map

Show the "minimum" addition purpose of each useful cell by S0 to S5 quanta.
Show the "minimum" subtraction purpose of each useful cell by M0 to M5 quanta

In [None]:
def get_quanta_min_complexity(node, major_tag, minor_tag, shades):
  """Return (cell_text, color_index) for a node's minimum-complexity tag.

  cell_text is the first two characters of node.min_tag_suffix(...), or ""
  when the node has no such tag (color_index 0). color_index is the second
  character when it is a digit, otherwise the last shade (shades-1).
  """
  cell_text = node.min_tag_suffix( major_tag, minor_tag )
  if cell_text == "":
    return cell_text, 0

  cell_text = cell_text[:2]
  level_char = cell_text[1:2] # Empty when the tag is a single character
  color_index = int(level_char) if level_char.isdigit() else shades-1

  return cell_text, color_index


def show_quanta_min_tags( title, major_tag, minor_tag, shades):
  """Show a quanta map of each node's minimum-complexity tag, using the custom colormap."""
  min_complexity_cmap = create_custom_colormap()
  show_quanta_map( title, min_complexity_cmap, shades, major_tag, minor_tag, get_quanta_min_complexity)

In [None]:
# Skipped unless cfg.perc_add() is positive. Maps the minimum addition complexity tag per node with 6 shades.
if cfg.perc_add() > 0:
  show_quanta_min_tags( "Addition Min-Complexity Per Node", QuantaType_MATH_ADD, "", 6)

In [None]:
# Skipped unless cfg.perc_sub is positive. Maps the minimum subtraction complexity tag per node with 4 shades.
if cfg.perc_sub > 0:
  show_quanta_min_tags( "Subtraction Min-Complexity Per Node", QuantaType_MATH_SUB, "", 4)

In [None]:
# Skipped unless cfg.perc_sub is positive. Same subtraction map, but with MATH_SUB_NG_TAG passed as the minor tag filter.
if cfg.perc_sub > 0:
  show_quanta_min_tags( "Neg-Answer Sub Min-Complexity Per Node", QuantaType_MATH_SUB, MATH_SUB_NG_TAG, 4)

# Part 19A: Set Up: Calc and graph PCA decomposition.


In [None]:
tn_questions = 100


def make_t_questions(test_digit, test_case, operation):
  """Make tn_questions random questions that pin one digit column to a chosen case.

  The digit pair (x, y) in column `test_digit` is drawn according to
  test_case (8, 9 or 10 - see the inline comments). The lower digits are
  random "noise" chosen so they cannot carry into (addition) or borrow from
  (subtraction) the tested column. Returns make_questions(operation, pairs).
  """
  limit = 10 ** test_digit
  questions = []
  for _ in range(tn_questions):
    x_noise = 0
    y_noise = 0

    if operation == PLUS_INDEX:
      if test_case == 8:
        # Digit pair whose sum is between 0 and 8
        x = random.randint(0, 8)
        y = random.randint(0, 8-x)
      elif test_case == 9:
        # Digit pair whose sum is exactly 9
        x = random.randint(0, 9)
        y = 9 - x
      elif test_case == 10:
        # Digit pair whose sum is between 10 and 18
        x = random.randint(1, 9)
        y = random.randint(10-x, 9)

      # Randomise the lower digits - x_noise + y_noise stays below limit, so no MakeCarry
      x_noise = random.randint(0, limit-1)
      y_noise = random.randint(0, limit-1 - x_noise)

    elif operation == MINUS_INDEX:
      if test_case == 8:
        # Digit pair where x - y < 0
        x = random.randint(0, 8)
        y = random.randint(x+1, 9)
      elif test_case == 9:
        # Digit pair where x - y is 0
        x = random.randint(0, 9)
        y = x
      elif test_case == 10:
        # Digit pair where x - y > 0
        x = random.randint(1, 9)
        y = random.randint(0, x-1)

      # Randomise the lower digits - y_noise never exceeds x_noise, so no BorrowOne
      x_noise = random.randint(0, limit-1)
      y_noise = random.randint(0, x_noise)

    # Shift the tested digit into its column and add the lower-digit noise
    questions.append([x * limit + x_noise, y * limit + y_noise])

  return make_questions(operation, questions)



def make_tricase_questions(test_digit, operation):
  """Stack the T8, T9 and T10 question batches (in that order) for one digit and operation."""
  batches = tuple(make_t_questions(test_digit, test_case, operation) for test_case in (8, 9, 10))
  return torch.vstack(batches)


# Cache the sample questions (by answer_digit and operation) for later reuse.
# Building them once means every PCA run for a given (answer_digit, operation) uses the same questions.
# NOTE(review): only answer digits 0 .. cfg.n_digits-1 are cached, but auto_find_pca_node asks for
# digits up to cfg.n_digits inclusive; that lookup raises KeyError, which calc_pca_for_an catches and reports.
t_questions_dict = {}
for answer_digit in range(cfg.n_digits):
    for operation in [PLUS_INDEX, MINUS_INDEX]:
        t_questions = make_tricase_questions(answer_digit, operation)
        # Use a tuple of (answer_digit, operation) as the key for indexing
        t_questions_dict[(answer_digit, operation)] = t_questions

In [None]:
def pca_evr_0_percent(pca):
  """Return the first component's explained-variance ratio as a whole-number percentage."""
  first_ratio = pca.explained_variance_ratio_[0]
  percent = round(first_ratio * 100, 0)
  return int(percent)

In [None]:
# Calculate one Principal Component Analysis
def calc_pca_for_an(node_location, operation, answer_digit):
  """Fit a 6-component PCA to one attention head's output over the cached tri-case questions.

  Returns (pca, transformed_outputs, title). On any failure the error is printed
  and (None, None, None) is returned instead.
  """
  assert node_location.is_head == True

  try:
    sample_questions = t_questions_dict[(answer_digit, operation)]

    _, sample_cache = cfg.main_model.run_with_cache(sample_questions)

    # Output of individual heads, without final bias.
    # Each per-question entry is indexed [position, head, :] below.
    head_results = sample_cache["result", node_location.layer, "attn"]

    # Gather this head's output at this position for all the (randomly chosen) questions
    per_question = [
      head_results[q][node_location.position, node_location.num, :]
      for q in range(len(sample_questions))
    ]
    attn_outputs = torch.stack(per_question, dim=0).cpu()

    pca = PCA(n_components=6)
    pca.fit(attn_outputs)
    pca_attn_outputs = pca.transform(attn_outputs)

    title = "".join([node_location.name(), ', A', str(answer_digit), ', EVR[0]=', str(pca_evr_0_percent(pca)), '%'])

    return pca, pca_attn_outputs, title
  except Exception as e:
    print( "calc_pca_for_an Failed:" + node_location.name() + " " + token_to_char(operation) + " " + answer_name(answer_digit), e)
    return None, None, None

In [None]:
# Plot the PCA of PnLnHn's attention output, using T8, T9, T10 questions that differ in the An digit
def plot_pca_for_an(ax, pca_attn_outputs, title):
  """Scatter the first two principal components, one colour per tri-case group.

  Rows of pca_attn_outputs are assumed to be ordered T8, T9, T10 with
  tn_questions rows each (the order make_tricase_questions stacks them).
  """
  groups = [
    (slice(None, tn_questions), 'red', 'T8 (0-8)'),                 # t8 questions
    (slice(tn_questions, 2*tn_questions), 'green', 'T9'),           # t9 questions
    (slice(2*tn_questions, None), 'blue', 'T10 (10-18)'),           # t10 questions
  ]
  for rows, colour, label in groups:
    ax.scatter(pca_attn_outputs[rows, 0], pca_attn_outputs[rows, 1], color=colour, label=label)

  if title != "":
    ax.set_title(title)

In [None]:
def pca_tag(the_digit, strong):
  """Build the PCA tag for an answer digit; weak results get a ".Weak" suffix."""
  suffix = "" if strong else ".Weak"
  return answer_name(the_digit)  + "." + PCA_ADD_TAG + suffix

In [None]:
def manual_node_pca(ax, position, layer, num, operation, answer_digit):
  """Run, plot and tag the PCA for one manually chosen attention head.

  ax: matplotlib axes to draw on.
  position, layer, num: identify the attention head (num is the head number).
  operation: question operation token (e.g. PLUS_INDEX).
  answer_digit: the answer digit An whose tri-case questions are used.

  If the PCA could not be calculated, nothing is plotted and no tag is added.
  """
  node_location = NodeLocation(position, layer, True, num)
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)
  if pca is None:
    # calc_pca_for_an has already printed the failure. Without a PCA result there is
    # nothing to plot, and tagging the node would claim evidence we do not have.
    return

  plot_pca_for_an(ax, pca_attn_outputs, title)

  # Add the strong PCA tag to node PCA:A5.TR
  u_add_node_tag( node_location, QuantaType.PCA, pca_tag(answer_digit, True) )


def auto_node_pca(ax, index, node_location, operation, answer_digit, perc_threshold):
  """Plot and weakly tag a node's PCA if its first component explains more than perc_threshold percent.

  Returns True when a graph was drawn on ax, otherwise False.
  The index argument is accepted but not used.
  """
  pca, pca_attn_outputs, title = calc_pca_for_an(node_location, operation, answer_digit)

  if pca is None:
    return False

  if not (pca_evr_0_percent(pca) > perc_threshold):
    return False

  plot_pca_for_an(ax, pca_attn_outputs, title)

  # Add the weak PCA tag to node PCA:A5.TR.Weak
  u_add_node_tag( node_location, QuantaType.PCA, pca_tag(answer_digit, False) )
  return True

In [None]:
def manual_nodes_pca(op, nodes):
  """Plot a grid of PCA graphs for manually selected attention heads, and tag each node.

  op: question operation token (e.g. PLUS_INDEX).
  nodes: list of [position, layer, head_num, answer_digit] entries.

  The grid always has at least one spare cell; the last cell holds the legend.
  """
  cols = 4
  rows = 1 + (len(nodes)+1) // cols

  # squeeze=False keeps axs 2-D even when rows == 1 (fewer than 3 nodes),
  # so the [row, col] indexing below always works.
  fig, axs = plt.subplots(rows, cols, squeeze=False)
  fig.set_figheight(rows*2 + 1)
  fig.set_figwidth(10)

  for index, node in enumerate(nodes):
    manual_node_pca(axs[index // cols, index % cols], node[0], node[1], node[2], op, node[3])

  # Remove any graphs we dont need (except last one)
  for index in range(len(nodes), rows * cols - 1):
    axs[index // cols, index % cols].remove()

  # Replace last graph with the legend (taken from the first graph)
  lines, labels = axs[0,0].get_legend_handles_labels()
  axs[rows-1, cols-1].legend(lines, labels)
  axs[rows-1, cols-1].axis('off') # Now, to hide the last subplot

  plt.tight_layout()
  save_plt_to_file('Pca Tr')
  plt.show()

# Part 19B: Results: Manual interpretation of PCA results

If an attention head, for an answer digit An, gives an interpretable response (2 or 3 distinct output clusters) on the 3 groups of questions aligned to the T8, T9 and T10 definitions, then we plot the response and add a QuantaType.PCA tag to the node.



In [None]:
# Clear any existing QuantaType.PCA tags so that re-running the cells below starts from a clean slate
useful_info.reset_node_tags(QuantaType.PCA)

In [None]:
# Plot all attention heads with the clearest An selected.
# Each entry is [position, layer, head, answer_digit]; the trailing comment records the
# node name and the EVR[0] that was observed for it.
# Only the list matching cfg.model_name is used; other models plot nothing here.
if use_pca:

  if cfg.model_name == "add_d5_l1_h3_t30K" :
    manual_nodes_pca(PLUS_INDEX,
      [[ 12, 0, 0, 4 ],  # P12L0H0 with A4 EVR[0]=98%
      [ 12, 0, 2, 3 ],  # P12L0H2 with A3 EVR[0]=99%
      [ 13, 0, 0, 3 ],  # P13L0H0 with A3 EVR[0]=99%
      [ 13, 0, 2, 2 ],  # P13L0H2 with A2 EVR[0]=99%
      [ 14, 0, 0, 2 ],  # P14L0H0 with A2 EVR[0]=100%
      [ 14, 0, 2, 1 ],  # P14L0H2 with A1 EVR[0]=99%
      [ 15, 0, 0, 1 ],  # P15L0H0 with A1 EVR[0]=99%
      [ 15, 0, 2, 0 ],  # P15L0H2 with A0 EVR[0]=99%
      [ 16, 0, 0, 0 ]]) # P16L0H0 with A0 EVR[0]=99%

  if cfg.model_name == "add_d5_l2_h3_t15K" :
    manual_nodes_pca(PLUS_INDEX,
      [[10, 0, 0, 2 ],  # P10L0H0 with A2 EVR[0]=91%
      [ 12, 0, 0, 3 ],  # P12L0H0 with A3 EVR[0]=87%
      [ 12, 1, 0, 3 ],  # P12L1H0 with A3 EVR[0]=79%
      [ 12, 1, 1, 4 ],  # P12L1H1 with A4 EVR[0]=81%
      [ 12, 1, 2, 4 ],  # P12L1H2 with A4 EVR[0]=70%
      [ 13, 0, 0, 0 ],  # P13L0H0 with A0 EVR[0]=88%
      [ 13, 0, 0, 3 ],  # P13L0H0 with A3 EVR[0]=90%
      [ 13, 1, 2, 2 ],  # P13L1H2 with A3 EVR[0]=88%  NOTE(review): entry uses answer digit 2 but this comment says A3 - confirm
      [ 14, 0, 0, 0 ],  # P14L0H0 with A0 EVR[0]=90%
      [ 14, 0, 0, 2 ],  # P14L0H0 with A2 EVR[0]=91%
      [ 14, 1, 2, 2 ],  # P14L1H2 with A2 EVR[0]=74%
      [ 15, 0, 0, 1 ],  # P15L0H0 with A1 EVR[0]=89%
      [ 15, 0, 0, 0 ],  # P15L0H0 with A0 EVR[0]=96%
      [ 15, 1, 1, 1 ],  # P15L1H1 with A1 EVR[0]=89%
      [ 16, 0, 0, 0 ]]) # P16L0H0 with A0 EVR[0]=90%

  if cfg.model_name == "add_d6_l2_h3_t15K" :
    manual_nodes_pca(PLUS_INDEX,
      [[11, 0, 0, 2 ],  # P11L0H0 with A2 EVR[0]=85%
      [ 12, 0, 0, 3 ],  # P12L0H0 with A3 EVR[0]=87%
      [ 13, 0, 0, 1 ],  # P13L0H0 with A1 EVR[0]=84%
      [ 14, 0, 0, 4 ],  # P14L0H0 with A4 EVR[0]=86%
      [ 14, 1, 1, 4 ],  # P14L1H1 with A4 EVR[0]=82%
      [ 15, 0, 0, 4 ],  # P15L0H0 with A4 EVR[0]=86%
      [ 15, 1, 1, 4 ],  # P15L1H1 with A4 EVR[0]=83%
      [ 16, 0, 0, 3 ],  # P16L0H0 with A3 EVR[0]=87%
      [ 16, 1, 1, 3 ],  # P16L1H1 with A3 EVR[0]=92%
      [ 17, 0, 0, 2 ],  # P17L0H0 with A2 EVR[0]=86%
      [ 17, 1, 1, 2 ],  # P17L1H1 with A2 EVR[0]=83%
      [ 18, 0, 0, 0 ],  # P18L0H0 with A0 EVR[0]=80%
      [ 18, 0, 0, 1 ],  # P18L0H1 with A1 EVR[0]=85%  NOTE(review): entry uses head 0 but this comment says H1 - confirm
      [ 19, 0, 0, 0 ]]) # P19L0H0 with A0 EVR[0]=85%

  if cfg.model_name == "sub_d6_l2_h3_t30K" :
    # NOTE(review): this is a subtraction model name but the questions use PLUS_INDEX - confirm MINUS_INDEX was not intended
    manual_nodes_pca(PLUS_INDEX,
      [[ 8, 0, 0, 4 ],  # P8L0H0 with A4 EVR[0]=37%
      [  9, 0, 1, 3 ],  # P9L0H1 with A3 EVR[0]=84%
      [ 10, 0, 1, 2 ],  # P10L0H1 with A2 EVR[0]=77%
      [ 10, 0, 1, 3 ],  # P10L0H1 with A3 EVR[0]=61%
      [ 11, 0, 1, 1 ],  # P11L0H1 with A1 EVR[0]=82%
      [ 11, 0, 1, 2 ],  # P11L0H1 with A2 EVR[0]=69%
      [ 12, 0, 0, 0 ],  # P12L0H0 with A0 EVR[0]=57%
      [ 13, 0, 1, 4 ],  # P13L0H1 with A4 EVR[0]=94%
      [ 13, 1, 1, 0 ],  # P13L1H1 with A0 EVR[0]=89%
      [ 13, 1, 1, 1 ],  # P13L1H1 with A1 EVR[0]=96%
      [ 13, 1, 1, 2 ],  # P13L1H1 with A2 EVR[0]=94%
      [ 13, 1, 1, 3 ],  # P13L1H1 with A3 EVR[0]=93%
      [ 13, 1, 2, 3 ],  # P13L1H2 with A3 EVR[0]=73%
      [ 13, 1, 2, 5 ],  # P13L1H2 with A5 EVR[0]=79%
      [ 14, 0, 1, 0 ],  # P14L0H1 with A0 EVR[0]=80%
      [ 15, 0, 0, 0 ],  # P15L0H0 with A0 EVR[0]=100%!
      [ 15, 0, 0, 1 ],  # P15L0H0 with A1 EVR[0]=100%!
      [ 15, 0, 0, 2 ],  # P15L0H0 with A2 EVR[0]=100%!
      [ 15, 0, 0, 3 ],  # P15L0H0 with A3 EVR[0]=100%!
      [ 15, 0, 0, 4 ],  # P15L0H0 with A4 EVR[0]=100%!
      [ 15, 0, 0, 5 ],  # P15L0H0 with A5 EVR[0]=100%!
      [ 15, 1, 0, 0 ],  # P15L1H0 with A0 EVR[0]=99%
      [ 15, 1, 0, 1 ],  # P15L1H0 with A1 EVR[0]=99%
      [ 15, 1, 0, 2 ],  # P15L1H0 with A2 EVR[0]=97%
      [ 15, 1, 0, 3 ],  # P15L1H0 with A3 EVR[0]=97%
      [ 15, 1, 0, 4 ]]) # P15L1H0 with A4 EVR[0]=84%

  if cfg.model_name == "ins1_mix_d6_l3_h4_t40K" :
    manual_nodes_pca(PLUS_INDEX,
      [[13, 1, 3, 1 ],  # P13L1H3 with A1 EVR[0]=85%
      [ 14, 1, 2, 0 ],  # P14L1H2 with A0 EVR[0]=77%
      [ 14, 1, 2, 2 ],  # P14L1H2 with A2 EVR[0]=82%
      [ 14, 1, 3, 4 ],  # P14L1H3 with A4 EVR[0]=87%
      [ 15, 0, 3, 5 ],  # P15L0H3 with A5 EVR[0]=100%
      [ 15, 1, 2, 2 ],  # P15L1H2 with A2 EVR[0]=81%
      [ 15, 1, 3, 4 ],  # P15L1H3 with A4 EVR[0]=87%
      [ 16, 0, 3, 4 ],  # P16L0H3 with A4 EVR[0]=99%
      [ 16, 1, 2, 0 ],  # P16L1H0 with A0 EVR[0]=78%  NOTE(review): entry uses head 2 but this comment says H0 - confirm
      [ 16, 1, 2, 1 ],  # P16L1H1 with A1 EVR[0]=77%  NOTE(review): entry uses head 2 but this comment says H1 - confirm
      [ 16, 1, 2, 2 ],  # P16L1H2 with A2 EVR[0]=80%
      [ 16, 1, 3, 2 ],  # P16L1H3 with A2 EVR[0]=78%
      [ 17, 0, 3, 3 ],  # P17L0H3 with A3 EVR[0]=99%
      [ 17, 1, 2, 2 ],  # P17L1H2 with A2 EVR[0]=84%
      [ 17, 1, 3, 2 ],  # P17L1H3 with A2 EVR[0]=90%
      [ 18, 0, 3, 2 ],  # P18L0H3 with A2 EVR[0]=98%
      [ 18, 1, 3, 1 ],  # P18L1H3 with A1 EVR[0]=88%
      [ 19, 0, 3, 1 ],  # P19L0H3 with A1 EVR[0]=98%
      [ 19, 2, 0, 0 ],  # P19L2H0 with A0 EVR[0]=77%
      [ 19, 2, 1, 0 ],  # P19L2H1 with A0 EVR[0]=80%
      [ 20, 0, 0, 0 ],  # P20L0H0 with A0 EVR[0]=80%
      [ 20, 0, 3, 0 ]]) # P20L0H3 with A0 EVR[0]=96%

else:
  print( "PCA library failed to import. So PCA not done")

# Part 19C: Results: Automatic interpretation of PCA results

Part 19B is manual and selective. This part is automatic. It tests the attention heads that were not tagged in Part 19B and, for each answer digit where the first (single) principal component explains more than 75% of the variance in the head's output, it adds a QuantaType.PCA "weak" tag.

In [None]:
def auto_find_pca_node(node, op, perc_threshold):
  """Try every answer digit on one node and show the PCA graphs that pass the threshold.

  Graphs fill a 2 x 4 grid from the top left; unused cells are removed before display.
  """
  n_rows, n_cols = 2, 4  # Allow up to 8 graphs
  fig, axs = plt.subplots(n_rows, n_cols)
  fig.set_figheight(4)
  fig.set_figwidth(10)

  # "used" counts the graphs drawn so far, so it is also the index of the next free cell
  used = 0
  for answer_digit in range(cfg.n_digits+1):
    next_ax = axs[used // n_cols, used % n_cols]
    if auto_node_pca(next_ax, used, node, op, answer_digit, perc_threshold):
      used += 1

  # Remove any graphs we dont need after all
  for spare in range(used, n_rows * n_cols):
    axs[spare // n_cols, spare % n_cols].remove()

  plt.tight_layout()
  plt.show()

In [None]:
def auto_find_pca(op):
  """Run the automatic PCA search over every useful attention head that has no PCA tag yet."""
  perc_threshold = 75

  for node in useful_info.nodes:
    # Exclude MLP neurons.
    if not node.is_head:
      continue

    # Exclude nodes with a (manual) PCA tag - for any answer digit(s)).
    if node.contains_tag(QuantaType.PCA, ""):
      continue

    print( "Doing PCA on node", node.name(), "operation", token_to_char(op))

    auto_find_pca_node(node, op, perc_threshold)


# Run the automatic PCA search for each operation that cfg reports a non-zero percentage for.
# Skipped entirely when use_pca is false.
if use_pca:
  if cfg.perc_add() > 0:
    auto_find_pca(PLUS_INDEX)
  if cfg.perc_sub > 0:
    auto_find_pca(MINUS_INDEX)

# Part 20A: Results: Show useful nodes and behaviour tags

In [None]:
# Print every useful node together with the tags collected so far
useful_info.print_node_tags()

# Part 20B: Results: Save useful nodes and behaviour tags to json file

In [None]:
# Serialize and save the useful nodes list to a temporary CoLab file in JSON format.
# The file name comes from main_fname_behavior_json, which is defined earlier in the notebook.
print( "Saving useful node list with behavior tags:", main_fname_behavior_json)
useful_info.save_nodes(main_fname_behavior_json)

# Part 21A : Set up: Interchange Interventions

Here we prove that model nodes perform specified calculations. If all the calculations in an algorithm hypothesis are found to exist in a model instance, this provides evidence for the hypothesis.   

**Automatic searches** for node purposes are preferred, as they are applicable to several models, and survive (non-significant, node-reordering) changes to the model after training. When a node purpose is detected, this is documented as a tag on the node.

**Manually written tests** of node purposes, specific to a single model instance are also supported.

In [None]:
class A_Config():
  """State shared between the interchange-intervention hooks and the code that runs them.

  The module-level instance acfg is read and written by the a_get_* and a_put_* hooks.
  """

  # A list of NodeLocations being intervened on
  node_locations = []

  # Activations captured by the "get" hooks. One slot per model layer, indexed by layer number.
  layer_store = []

  questions = []
  null_hooks = []
  get_hooks = []
  put_hooks = []

  # Expected output of an intervention ablation experiment
  expected_answer = ""
  expected_impact = "" # e.g A32
  # Actual outputs of an intervention ablation experiment
  intervened_answer = ""
  intervened_impact = "" # e.g A32


  def reset_hooks(self):
    """Clear the node list, stored activations, questions and hook lists.

    Also resets tcfg's node location and sets tcfg.threshold.
    """
    self.node_locations = []
    # Four slots, because hooks exist for layers 0 to 3 (a_get_l3_attn_z_hook writes slot 3)
    self.layer_store = [[] for _ in range(4)]
    self.questions = []
    self.null_hooks = []
    self.get_hooks = []
    self.put_hooks = []

    tcfg.reset_node_location()
    tcfg.threshold = 0.00001


  def reset_intervention(self, expected_answer_int = 0, expected_impact = NO_IMPACT_TAG):
    """Record the expected answer and impact of the next intervention and clear the actual results.

    An empty expected_impact is stored as NO_IMPACT_TAG.
    """
    self.expected_answer = int_to_answer_str(expected_answer_int)
    self.expected_impact = expected_impact if expected_impact != "" else NO_IMPACT_TAG
    self.intervened_answer = ""
    self.intervened_impact = NO_IMPACT_TAG


  def __init__(self):
    self.reset_hooks()
    self.reset_intervention()


  def node_names(self):
    """Return the names of the intervened nodes as a comma-separated string."""
    return ", ".join(node.name() for node in self.node_locations)


acfg = A_Config()

In [None]:
# Get and put attention head value hooks

def a_null_attn_z_hook(value, hook):
  """Hook that does nothing: the activation passes through unchanged."""
  #print("In a_null_attn_z_hook", value.shape)  # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head
  return None


def _a_store_attn_z(layer, value):
  # Keep a private copy of this layer's activation for later use by the put hooks
  # print( "In a_get_attn_z_hook", value.shape) # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head
  acfg.layer_store[layer] = value.clone()


def a_get_l0_attn_z_hook(value, hook):
  """Store the layer 0 attention z activation in acfg.layer_store[0]."""
  _a_store_attn_z(0, value)

def a_get_l1_attn_z_hook(value, hook):
  """Store the layer 1 attention z activation in acfg.layer_store[1]."""
  _a_store_attn_z(1, value)

def a_get_l2_attn_z_hook(value, hook):
  """Store the layer 2 attention z activation in acfg.layer_store[2]."""
  _a_store_attn_z(2, value)

def a_get_l3_attn_z_hook(value, hook):
  """Store the layer 3 attention z activation in acfg.layer_store[3]."""
  _a_store_attn_z(3, value)


def _a_put_attn_z(value, layer):
  # Overwrite, in place, the activation of every intervened node in this layer with the
  # activation that the matching get hook stored for the SAME layer.
  # (Layers 1 to 3 used to copy from layer_store[0], i.e. from the wrong layer.)
  # print( "In a_put_attn_z_hook", value.shape) # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, d_head
  for location in acfg.node_locations:
    if location.layer == layer:
      value[:,location.position,location.num,:] = acfg.layer_store[layer][:,location.position,location.num,:].clone()


def a_put_l0_attn_z_hook(value, hook):
  """Patch the layer 0 nodes in acfg.node_locations with the stored layer 0 activations."""
  _a_put_attn_z(value, 0)

def a_put_l1_attn_z_hook(value, hook):
  """Patch the layer 1 nodes in acfg.node_locations with the stored layer 1 activations."""
  _a_put_attn_z(value, 1)

def a_put_l2_attn_z_hook(value, hook):
  """Patch the layer 2 nodes in acfg.node_locations with the stored layer 2 activations."""
  _a_put_attn_z(value, 2)

def a_put_l3_attn_z_hook(value, hook):
  """Patch the layer 3 nodes in acfg.node_locations with the stored layer 3 activations."""
  _a_put_attn_z(value, 3)


def a_reset(node_locations):
  """Reset acfg and rebuild its null/get/put hook lists for the model's first cfg.n_layers layers."""
  acfg.reset_hooks()
  acfg.node_locations = node_locations

  def hooks_for(handlers):
    # Pair each per-layer handler with that layer's hook name, then keep only the model's layers
    return [(l_attn_hook_z_name[layer], handler) for layer, handler in enumerate(handlers)][:cfg.n_layers]

  acfg.null_hooks = hooks_for([a_null_attn_z_hook, a_null_attn_z_hook, a_null_attn_z_hook, a_null_attn_z_hook])
  acfg.get_hooks = hooks_for([a_get_l0_attn_z_hook, a_get_l1_attn_z_hook, a_get_l2_attn_z_hook, a_get_l3_attn_z_hook])
  acfg.put_hooks = hooks_for([a_put_l0_attn_z_hook, a_put_l1_attn_z_hook, a_put_l2_attn_z_hook, a_put_l3_attn_z_hook])

In [None]:
def run_intervention_core(node_locations, store_question, test_question, operation):
  """Run one interchange intervention and record the outcome in acfg.

  The model first predicts store_question while the get hooks capture the attention
  activations. It then predicts test_question while the put hooks overwrite the
  activations of node_locations with the captured ones.

  Sets acfg.intervened_answer and acfg.intervened_impact.
  Returns a text description of the intervention. The intervened answer, impact
  and loss are only included when the maximum loss exceeds tcfg.threshold.
  """
  # Range-check both questions (the store question was previously not checked at all)
  for question in (store_question, test_question):
    assert(question[0] < + 10 ** cfg.n_digits)
    assert(question[1] > - 10 ** cfg.n_digits)

  a_reset(node_locations)

  # Calculate the clean (no intervention) test question answer  e.g. "+006671"
  clean_answer_int = test_question[0]+test_question[1] if operation == PLUS_INDEX else test_question[0]-test_question[1]
  clean_answer = int_to_answer_str(clean_answer_int)
  description = "Intervening on " + acfg.node_names() + ", CleanAnswer: " + clean_answer

  # Predict "store" question and store activation values
  acfg.questions = make_questions(operation, [store_question])
  predict_questions_core(acfg.questions, acfg.get_hooks)

  # Predict "test" question overriding PnLmHp to give a bad answer
  acfg.questions = make_questions(operation, [test_question])
  all_losses_raw, all_max_prob_tokens = predict_questions_core(acfg.questions, acfg.put_hooks)
  loss_max = utils.to_numpy(loss_fn(all_losses_raw[0]).max())
  acfg.intervened_answer = tokens_to_string(all_max_prob_tokens[0])

  # Compare the clean test question answer to what the model generated (impacted by the ablation intervention)
  acfg.intervened_impact = get_answer_impact( clean_answer, acfg.intervened_answer )
  if acfg.intervened_impact == "":
    acfg.intervened_impact = NO_IMPACT_TAG

  if loss_max > tcfg.threshold:
    loss_str = NO_IMPACT_TAG if loss_max < 1e-7 else str(loss_max)

    description += ", IntervenedAnswer/Impact: " + acfg.intervened_answer + "/" + acfg.intervened_impact + ", Loss: " + loss_str

  return description


# Run an intervention where we have a precise expectation of the intervention impact
def run_strong_intervention(node_locations, store_question, test_question, operation, expected_impact, expected_answer_int, show_failures = True):
  """Run an intervention and compare the result with the exact expected answer and impact.

  Returns (success, answer_success, impact_success), where success requires both
  the answer and the impact to match.
  """
  acfg.reset_intervention(expected_answer_int, expected_impact)

  # These are the actual model prediction outputs (while applying our node-level intervention).
  description = run_intervention_core(node_locations, store_question, test_question, operation)

  answer_success = (acfg.expected_answer == acfg.intervened_answer)
  impact_success = (acfg.expected_impact == acfg.intervened_impact)
  success = answer_success and impact_success

  if not success and show_failures:
    print( description )
    print("Failed: Expected:", acfg.expected_answer, acfg.expected_impact, "Model predicted:", acfg.intervened_answer, acfg.intervened_impact)

  return success, answer_success, impact_success


# Run an intervention where we expect the intervention to have a non-zero impact and we cant precisely predict the answer
def run_weak_intervention(node_locations, store_question, test_question, operation, show_failures = True):
  """Run an intervention that should change the answer in some way.

  Succeeds when the intervened answer differs from the clean answer and the
  recorded impact is not NO_IMPACT_TAG. Returns that success flag.
  """
  # Calculate the test (clean) question answer e.g. "+006671"
  if operation == PLUS_INDEX:
    clean_answer_int = test_question[0]+test_question[1]
  else:
    clean_answer_int = test_question[0]-test_question[1]
  acfg.reset_intervention(clean_answer_int, NO_IMPACT_TAG)

  description = run_intervention_core(node_locations, store_question, test_question, operation)

  success = (acfg.intervened_answer != acfg.expected_answer) and (acfg.intervened_impact != NO_IMPACT_TAG)

  if not success and show_failures:
    print("Failed: Intervention had no impact on the answer", description)

  return success

In [None]:
def repeat_digit(digit):
    """Return the integer written as digit repeated cfg.n_digits times, e.g. 4 -> 444444."""
    repeated = "".join([str(digit)] * cfg.n_digits)
    return int(repeated)


# unit test
if cfg.n_digits == 6:
  assert repeat_digit(4) == 444444

In [None]:
def ignore_test(node_locations, alter_digit, strong, show_failures = False):
  """Placeholder test function that accepts any node: it always returns True."""
  return True

In [None]:
# Search the specified useful node(s), using the test_function, for the expected impact on the_impact_digit
def search_and_tag_digit_position( the_impact_digit, the_test_nodes, test_function, strong, the_tag, do_pair_search ):
  """Tag the first node (or pair of nodes) that passes test_function.

  Single nodes are tried first. If do_pair_search is set, every pair of nodes is
  tried next, because a task is sometimes split across two attention heads
  (i.e. a virtual attention head).
  On a weak (not strong) pass the tag gets "." + acfg.intervened_impact appended.
  Returns True as soon as something is tagged, otherwise False.
  """
  candidates = [[node] for node in the_test_nodes]
  if do_pair_search:
    candidates += [list(pair) for pair in itertools.combinations(the_test_nodes, 2)]

  for group in candidates:
    if not test_function( group, the_impact_digit, strong):
      continue

    full_tag = the_tag + ("" if strong else "." + acfg.intervened_impact)
    for member in group:
      member.add_tag(QuantaType.ALGO, full_tag)
    return True

  return False


# For each useful position, search the related useful node(s), using the test_function, for the expected impact on the_impact_digit.
def search_and_tag_digit( prerequisites_function, the_impact_digit, test_function, tag_function, do_pair_search, do_weak_search, from_position, to_position ):
  """Search positions from_position..to_position (inclusive) for a node performing the task.

  A value of -1 for from_position / to_position means the minimum / maximum useful position.
  A strong pass is always done. A weak pass follows only if do_weak_search is set
  and the strong pass found nothing.
  Returns True when a node (or pair) was tagged, otherwise False.
  """
  the_tag = tag_function(the_impact_digit)

  first_position = useful_info.min_useful_position() if from_position == -1 else from_position
  last_position = useful_info.max_useful_position() if to_position == -1 else to_position

  # In some models, we don't predict the intervened_answer correctly in test_function.
  # So we may do a weak second pass and may add say "A5.BS.A632" tag to a node.
  passes = [True, False] if do_weak_search else [True]

  for strong in passes:
    for position in range(first_position, last_position+1):
      test_nodes = filter_nodes( useful_info.nodes, prerequisites_function(position, the_impact_digit))
      if search_and_tag_digit_position( the_impact_digit, test_nodes, test_function, strong, the_tag, do_pair_search ):
        return True

  return False


# For each answer digit, for each useful position, search the related useful node(s), using the test_function, for the expected impact on the_impact_digit. We may do 2 passes.
def search_and_tag( prerequisites_function, test_function, tag_function, do_pair_search = True, do_weak_search = True, from_position = -1, to_position = -1):
  """Run search_and_tag_digit once for every answer position in useful_info."""
  answer_positions = range(useful_info.num_answer_positions)
  for the_impact_digit in answer_positions:
    search_and_tag_digit(
      prerequisites_function,
      the_impact_digit,
      test_function,
      tag_function,
      do_pair_search,
      do_weak_search,
      from_position,
      to_position )

In [None]:
# Clear any existing QuantaType.ALGO tags before the automated searches below add new ones
useful_info.reset_node_tags(QuantaType.ALGO)

# Part 21B: Automated Dn.US search

The addition Use Sum 9 (US) operation is a simple task. Search for US tasks.

In [None]:
def add_us_tag(impact_digit):
  """Build the Use Sum 9 tag. It is named after the digit one below the impacted answer digit."""
  tag_digit = impact_digit-1
  return ".".join([answer_name(tag_digit), ALGO_ADD_US_TAG])

In [None]:
# These rules are prerequisites for (not proof of) an Addition UseSum9 node
def add_us_prereqs(position, impact_digit):
  """Return the filter a node must pass before it is tested as a UseSum9 node."""
  attended_digit = impact_digit-2
  criteria = [
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(attended_digit)), # Attends to Dn-2
    FilterAttention(ddn_to_position_name(attended_digit)), # Attends to D'n-2
    FilterImpact(answer_name(impact_digit)),
  ]
  return FilterAnd(*criteria)

In [None]:
def add_us_test(node_locations, alter_digit, strong):
  """Intervention test for an addition Use Sum 9 (US) node impacting answer digit alter_digit.

  Returns True when the strong intervention produces exactly the expected answer and impact.
  Digits outside 2..cfg.n_digits cannot be tested and return False.
  """
  if not (2 <= alter_digit <= cfg.n_digits):
    acfg.reset_intervention()
    return False

  intervention_impact = answer_name(alter_digit)
  upper_unit = 10 ** (alter_digit - 1)
  lower_unit = 10 ** (alter_digit - 2)

  # 25222 + 44444 = 69666. Has no Dn-2.MC but has Dn-1.US so not a US case
  store_question = [repeat_digit(2) + (5-2) * upper_unit, repeat_digit(4)]

  # 34633 + 55555 = 90188. Has Dn-2.MC and Dn-1.US so is a US case
  test_question = [repeat_digit(3) + (4-3) * upper_unit + (6-3) * lower_unit, repeat_digit(5)]

  # When we intervene we expect answer 80188
  intervened_answer = test_question[0] + test_question[1] - 10 ** (alter_digit)

  # Unit test
  if cfg.n_digits == 5 and alter_digit == 4:
    assert store_question[0] == 25222
    assert test_question[0] == 34633
    assert test_question[0] + test_question[1] == 90188
    assert intervened_answer == 80188

  success, _, _ = run_strong_intervention(node_locations, store_question, test_question, PLUS_INDEX, intervention_impact, intervened_answer, False)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".US impacting "+intervention_impact+" accuracy.", "Strong:", strong)

  return success

In [None]:
# Search for Use Sum 9 nodes. No pair search and no weak pass.
# Not guarded by "if cfg.perc_add() > 0": the test should simply not succeed in subtraction cases.
search_and_tag( add_us_prereqs, add_us_test, add_us_tag, False, False)

# Part 21C: Automated Dn.MC search

The addition Make Carry (MC) operation is a simple task. Search for MC tasks.

In [None]:
def add_mc_tag(impact_digit):
  """Build the Make Carry tag. It is named after the digit one below the impacted answer digit."""
  tag_digit = impact_digit-1
  return ".".join([answer_name(tag_digit), ALGO_ADD_MC_TAG])

In [None]:
# These rules are prerequisites for (not proof of) an Addition MakeCarry node
def add_mc_prereqs(position, impact_digit):
  """Return the filter a node must pass before it is tested as a MakeCarry node."""
  # MC is calculated on the next lower-value digit.
  carry_digit = impact_digit-1
  criteria = [
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(carry_digit)),
    FilterAttention(ddn_to_position_name(carry_digit)),
    FilterImpact(answer_name(impact_digit)),
  ]
  return FilterAnd(*criteria)

In [None]:
def add_mc_test(node_locations, impact_digit, strong):
  """Intervention test for an addition Make Carry (MC) node impacting answer digit impact_digit.

  The carry is made at the digit below impact_digit. Returns True when the strong
  intervention produces exactly the expected answer and impact.
  """
  alter_digit = impact_digit - 1

  if not (0 <= alter_digit < cfg.n_digits):
    acfg.reset_intervention()
    return False

  intervention_impact = answer_name(impact_digit)

  # 222222 + 666966 = 889188. Has Dn.MC
  store_question = [repeat_digit(2), repeat_digit(6) + (9 - 6) * (10 ** alter_digit)]

  # 333333 + 555555 = 888888. No Dn.MC
  test_question = [repeat_digit(3), repeat_digit(5)]

  # When we intervene we expect answer 889888
  clean_sum = test_question[0] + test_question[1]
  intervened_answer = clean_sum + 10 ** (alter_digit+1)

  success, _, _ = run_strong_intervention(node_locations, store_question, test_question, PLUS_INDEX, intervention_impact, intervened_answer, False)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".MC impacting "+intervention_impact+" accuracy.", "Strong:", strong)

  return success

In [None]:
# Search for Make Carry nodes. No pair search and no weak pass.
# Not guarded by "if cfg.perc_add() > 0": the test should simply not succeed in subtraction cases.
search_and_tag( add_mc_prereqs, add_mc_test, add_mc_tag, False, False)

# Part 21D: Automated Dn.BA search

The addition Base Add (BA) operation is a simple task. The task may be split/shared over 2 attention heads in the same position. Search for BA calculations.

In [None]:
def add_ba_tag(impact_digit):
  """Build the Base Add tag for the impacted answer digit."""
  return ".".join([answer_name(impact_digit), ALGO_ADD_BA_TAG])

In [None]:
# These rules are prerequisites for (not proof of) an Addition BaseAdd node
def add_ba_prereqs(position, impact_digit):
  """Return the filter a node must pass before it is tested as a BaseAdd node."""
  criteria = [
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(impact_digit)), # Attends to Dn
    FilterAttention(ddn_to_position_name(impact_digit)), # Attends to D'n
    FilterImpact(answer_name(impact_digit)), # Impacts An
    FilterAlgo(add_ba_tag(impact_digit), QuantaFilter.NOT), # Has not already been flagged as a BA task
  ]
  return FilterAnd(*criteria)

In [None]:
def add_ba_test1(alter_digit):
  """First BaseAdd scenario. Returns (store_question, test_question, expected intervened answer)."""
  # 222222 + 111111 = 333333. No Dn.MC
  store_question = [repeat_digit(2), repeat_digit(1)]

  # 555555 + 444444 = 999999. No Dn.MC
  test_question = [repeat_digit(5), repeat_digit(4)]

  # When we intervene we expect answer 999399: the altered digit changes from 9 to 3
  clean_sum = test_question[0] + test_question[1]
  intervened_answer = clean_sum + (3-9) * 10 ** alter_digit

  return store_question, test_question, intervened_answer


def add_ba_test2(alter_digit):
  """Second BaseAdd scenario. Returns (store_question, test_question, expected intervened answer)."""
  # 222222 + 666666 = 888888. No Dn.MC
  store_question = [repeat_digit(2), repeat_digit(6)]

  # 555555 + 111111 = 666666. No Dn.MC
  test_question = [repeat_digit(5), repeat_digit(1)]

  # When we intervene we expect answer 666866: the altered digit changes from 6 to 8
  clean_sum = test_question[0] + test_question[1]
  intervened_answer = clean_sum + (8-6) * 10 ** alter_digit

  return store_question, test_question, intervened_answer


def add_ba_test(node_locations, alter_digit, strong, show_failures = False):
  """Intervention test for an addition Base Add (BA) node at alter_digit.

  Both scenarios (add_ba_test1 and add_ba_test2) are always run. A strong pass needs
  the exact answer and impact in both; a weak pass only needs the impact to match.
  """
  intervention_impact = answer_name(alter_digit)

  outcomes = []
  for build_scenario in (add_ba_test1, add_ba_test2):
    store_question, test_question, intervened_answer = build_scenario(alter_digit)
    outcomes.append(run_strong_intervention(node_locations, store_question, test_question, PLUS_INDEX, intervention_impact, intervened_answer, show_failures))

  # Each outcome is (success, answer_success, impact_success)
  if strong:
    success = outcomes[0][0] and outcomes[1][0]
  else:
    success = outcomes[0][2] and outcomes[1][2]

  if success:
    print( "Test confirmed:", acfg.node_names(), "perform D"+str(alter_digit)+".BA = (D"+str(alter_digit)+" + D"+str(alter_digit)+"') % 10 impacting "+intervention_impact+" accuracy.", "Strong:", strong)

  return success

In [None]:
# Search for Base Add nodes, with pair search and a weak second pass enabled.
# Not guarded by "if cfg.perc_add() > 0": the test should simply not succeed in subtraction cases.
search_and_tag( add_ba_prereqs, add_ba_test, add_ba_tag, True, True )

# Part 21E: Automated Dn.C search

Search for D0.C to D5.C with impact "A65432" to "A65" in early tokens.

A0 and A1 are too simple to need Dn.C values so they are excluded from the answer impact.

In [None]:
def add_tc_tag(focus_digit):
  """Build the Dn.C tag for the question digit in focus, e.g. "D3." followed by ALGO_ADD_TC_TAG."""
  return "".join(["D", str(focus_digit), ".", ALGO_ADD_TC_TAG])

In [None]:
# These rules are prerequisites for (not proof of) an Addition Dn.C node
def add_tc_prereqs(position, focus_digit):
  """Return the filter a node must pass before it is tested as a Dn.C node."""
  criteria = [
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(focus_digit)), # Attends to Dn
    FilterAttention(ddn_to_position_name(focus_digit)), # Attends to D'n
    FilterPCA(PCA_ADD_TAG, QuantaFilter.CONTAINS), # Node PCA is interpretable with respect to T8,T9,T10
  ]
  return FilterAnd(*criteria)

In [None]:
def add_tc_test(node_locations, focus_digit, strong):
  """Intervention test for an addition Dn.C (TriCase) node at focus_digit.

  Uses a weak intervention: the test passes when patching the node changes the
  answer in any way, because the exact intervened answer cannot be predicted.
  Returns that success flag.
  """
  # 222222 + 777977 = 1000199. Has Dn.MC
  store_question = [repeat_digit(2), repeat_digit(7)]
  store_question[1] += (9 - 7) * (10 ** focus_digit)

  # 333333 + 666666 = 999999. No Dn.MC
  test_question = [repeat_digit(3), repeat_digit(6)]

  success = run_weak_intervention(node_locations, store_question, test_question, PLUS_INDEX, False)

  if success:
    description = acfg.node_names() + " perform D"+str(focus_digit)+".C = TriCase(D"+str(focus_digit)+" + D"+str(focus_digit)+"')"
    print("Test confirmed", description, "Impact:", acfg.intervened_impact, "Strong:", strong)

  return success

In [None]:
# Search for Dn.C nodes over a restricted range of token positions.
# Not guarded by "if cfg.perc_add() > 0": the test should simply not succeed in subtraction cases.
search_and_tag( add_tc_prereqs, add_tc_test, add_tc_tag,
  False, # Have not seen this task split between nodes.
  False, # No weak second pass
  cfg.n_digits, 2*cfg.n_digits+2) # These occur from the first D'n digit to the first answer digit.

# Part 21F: Automated Dn.BS search

The subtraction Base Subtraction (BS) operation is a simple task. The task may be split/shared over 2 attention heads in the same position. Search for BS calculations.

In [None]:
def sub_bs_tag(impact_digit):
  """Build the Base Subtraction tag for the impacted answer digit."""
  return ".".join([answer_name(impact_digit), ALGO_SUB_BS_TAG])

In [None]:
# These rules are prerequisites for (not proof of) a BaseSubtraction node
def sub_bs_prereqs(position, impact_digit):
  """Return the filter a node must pass before it is tested as a BaseSubtraction node."""
  criteria = [
    FilterHead(),
    FilterPosition(position_name(position)),
    FilterAttention(dn_to_position_name(impact_digit)), # Attends to Dn
    FilterAttention(ddn_to_position_name(impact_digit)), # Attends to D'n
    FilterImpact(answer_name(impact_digit)), # Impacts An
    FilterAlgo(sub_bs_tag(impact_digit), QuantaFilter.NOT), # Has not already been flagged as a BS task
  ]
  return FilterAnd(*criteria)

In [None]:
def sub_bs_test1(alter_digit):
  """First BaseSubtraction scenario. Returns (store_question, test_question, expected intervened answer)."""
  # 333333 - 111111 = 222222. No Dn.BO
  store_question = [repeat_digit(3), repeat_digit(1)]

  # 999999 - 444444 = 555555. No Dn.BO
  test_question = [repeat_digit(9), repeat_digit(4)]

  # When we intervene we expect answer 555255: the altered digit changes from 5 to 2
  clean_difference = test_question[0] - test_question[1]
  intervened_answer = clean_difference + (2-5) * 10 ** alter_digit

  return store_question, test_question, intervened_answer


# Second BS test case, using different digits from sub_bs_test1.
# Returns (store_question, test_question, intervened_answer).
def sub_bs_test2(alter_digit):
  store_pair = [repeat_digit(6), repeat_digit(2)] # 666666 - 222222 = 444444. No Dn.BO
  test_pair = [repeat_digit(5), repeat_digit(3)] # 555555 - 333333 = 222222. No Dn.BO

  # The intervention should swap one answer digit from 2 to 4, e.g. 222422
  expected = test_pair[0] - test_pair[1] + (4-2) * 10 ** alter_digit

  return store_pair, test_pair, expected


# Test whether the given nodes perform the Base Subtraction (BS) task for digit alter_digit.
#
# Runs run_strong_intervention on both sub_bs_test1 and sub_bs_test2 (MINUS_INDEX questions).
#   node_locations: the candidate node(s) to intervene on.
#   alter_digit:    the digit whose answer (answer_name(alter_digit)) should be impacted.
#   strong:         if True both runs must report overall success;
#                   if False only the impact-success result of both runs is required.
#   show_failures:  passed through to run_strong_intervention.
# Returns the combined success flag, and prints a confirmation line on success.
def sub_bs_test(node_locations, alter_digit, strong, show_failures = False):
  intervention_impact = answer_name(alter_digit)

  store_question, test_question, intervened_answer = sub_bs_test1(alter_digit)
  success1, _, impact_success1 = run_strong_intervention(node_locations, store_question, test_question, MINUS_INDEX, intervention_impact, intervened_answer, show_failures)

  store_question, test_question, intervened_answer = sub_bs_test2(alter_digit)
  success2, _, impact_success2 = run_strong_intervention(node_locations, store_question, test_question, MINUS_INDEX, intervention_impact, intervened_answer, show_failures)

  success = (success1 and success2) if strong else (impact_success1 and impact_success2)

  if success:
    # Base Subtraction is a difference, so the message reports Dn - Dn' (not Dn + Dn').
    print( "Test confirmed:", acfg.node_names(), f"perform D{alter_digit}.BS = (D{alter_digit} - D{alter_digit}') % 10 impacting {intervention_impact} accuracy.", "Strong:", strong)

  return success

In [None]:
# Run the automated search using the Base Subtraction (BS) prerequisite, test and tag functions.
search_and_tag(
  sub_bs_prereqs,
  sub_bs_test,
  sub_bs_tag,
  True,
  True,
)

# Part 21G: Automated Dn.BO search

The subtraction Borrow One operation is a simple task. Search for BO tasks.

In [None]:
# Name of the algorithm tag for the Borrow One (BO) task.
# The tag is named after the digit below impact_digit (impact_digit-1).
def sub_bo_tag(impact_digit):
  return f"{answer_name(impact_digit-1)}.{ALGO_SUB_BO_TAG}"

In [None]:
# Prerequisites for (not proof of) a subtraction Borrow One node.
# Returns one FilterAnd that a candidate node at `position` must satisfy for digit `impact_digit`.
def sub_bo_prereqs(position, impact_digit):
  # BO is calculated on the next lower-value digit.
  lower_digit = impact_digit-1

  is_head = FilterHead()
  at_position = FilterPosition(position_name(position))
  attends_dn = FilterAttention(dn_to_position_name(lower_digit))
  attends_ddn = FilterAttention(ddn_to_position_name(lower_digit))
  impacts_an = FilterImpact(answer_name(impact_digit))

  return FilterAnd(is_head, at_position, attends_dn, attends_ddn, impacts_an)

In [None]:
# Test whether the given nodes perform the Borrow One (BO) task that impacts answer digit impact_digit.
#
# The borrow is generated on the next lower-value digit (alter_digit = impact_digit - 1)
# and changes the answer digit impact_digit.
#   node_locations: the candidate node(s) to intervene on.
#   impact_digit:   the answer digit expected to change.
#   strong:         only reported in the confirmation message.
#                   NOTE(review): unlike sub_bs_test, the overall (strong) success result is
#                   always used here regardless of this flag - confirm this is intended.
# Returns the success flag from run_strong_intervention, and prints a confirmation line on success.
def sub_bo_test(node_locations, impact_digit, strong):
  alter_digit = impact_digit - 1
  intervention_impact = answer_name(impact_digit)

  # 222222 - 111311 = 110911. Has Dn.BO
  store_question = [repeat_digit(2), repeat_digit(1)]
  store_question[1] += (3 - 1) * (10 ** alter_digit)

  # 777777 - 444444 = 333333. No Dn.BO
  test_question = [repeat_digit(7), repeat_digit(4)]

  # When we intervene we expect answer 332333.
  # The borrow from alter_digit reduces the answer digit at impact_digit by one,
  # so the correction is 10 ** impact_digit (not 10 ** alter_digit).
  intervened_answer = test_question[0] - test_question[1] - 10 ** impact_digit

  success, _, _ = run_strong_intervention(node_locations, store_question, test_question, MINUS_INDEX, intervention_impact, intervened_answer, False)

  if success:
    print( "Test confirmed", acfg.node_names(), "perform D"+str(alter_digit)+".BO impacting "+intervention_impact+" accuracy.", "Strong:", strong)

  return success

In [None]:
# Run the automated search using the Borrow One (BO) prerequisite, test and tag functions.
#if cfg.perc_sub > 0: Should not succeed in addition cases
search_and_tag(
  sub_bo_prereqs,
  sub_bo_test,
  sub_bo_tag,
  False,
  False,
) #PQR debug

In [None]:
#print(filter_nodes( useful_info.nodes, sub_bo_prereqs(17, 3)))
#print(filter_nodes( useful_info.nodes, sub_bo_prereqs(17, 4)))
#print(filter_nodes( useful_info.nodes, sub_bo_prereqs(17, 5)))

#nodes = filter_nodes( useful_info.nodes, sub_bo_prereqs(17, 3) )
#print(nodes[0].name())
#sub_bo_test( [NodeLocation(17, 0, True, 0)], 3, True)


# Part 21H: Automated Dn.NG search

Some useful nodes are only used in subtraction when A < B in the A - B question, i.e. in negative-answer questions. We claim these nodes are somehow associated with converting A - B to - ( B - A ).

In [None]:
# Name of the algorithm tag for the negative-answer (NG) task on answer digit impact_digit.
def sub_ng_tag(impact_digit):
  return f"{answer_name(impact_digit)}.{ALGO_SUB_NG_TAG}"

In [None]:
# Prerequisites for a node that is only used by negative-answer subtraction questions.
# Returns one FilterAnd that a candidate node at `position` must satisfy for digit `impact_digit`.
def sub_ng_prereqs(position, impact_digit):
  is_head = FilterHead()
  at_position = FilterPosition(position_name(position))
  impacts_an = FilterImpact(answer_name(impact_digit)) # Impacts An
  impacts_negative = FilterContains(QuantaType_MATH_SUB, MATH_SUB_NG_TAG) # Impacts negative-answer questions

  # Does not impact positive-answer subtraction questions (of any complexity)
  positive_answer_tags = (MATH_SUB_S0_TAG, MATH_SUB_S1_TAG, MATH_SUB_S2_TAG, MATH_SUB_S3_TAG)
  not_positive = [FilterContains(QuantaType_MATH_SUB, tag, QuantaFilter.NOT) for tag in positive_answer_tags]

  return FilterAnd(is_head, at_position, impacts_an, impacts_negative, *not_positive)

In [None]:
# Run the automated search for negative-answer (NG) nodes. ignore_test is passed as the test function.
search_and_tag(
  sub_ng_prereqs,
  ignore_test,
  sub_ng_tag,
)

# Part 21I: Automated Dn.OP search

For mixed models that do addition and subtraction the operation token "+/-" is key. Find nodes that attend to the operation.

In [None]:
# Name of the algorithm tag for nodes that attend to the operation token.
# impact_digit is not used: the same tag is returned for every digit.
def mix_op_tag(impact_digit):
  op_tag = ALGO_MIX_OP_TAG
  return op_tag

In [None]:
# Prerequisites for a node that attends to the operation token.
# impact_digit is not used by these filters.
def sub_op_prereqs(position, impact_digit):
  is_head = FilterHead()
  at_position = FilterPosition(position_name(position))
  attends_op = FilterAttention(op_position_name())

  return FilterAnd(is_head, at_position, attends_op)

In [None]:
# Run the automated search for operation (OP) nodes. ignore_test is passed as the test function.
search_and_tag(
  sub_op_prereqs,
  ignore_test,
  mix_op_tag,
)

# Part 22: Show algorithm quanta map

Plot the "algorithm" tags generated in previous steps as a quanta map. This is an automatically generated partial explanation of the model algorithm.

In [None]:
# Plot the ALGO tags as a quanta map.
show_quanta_map(
  "Algorithm Purpose Per Node",
  create_custom_colormap(),
  2,
  QuantaType.ALGO,
  "",
  get_quanta_binary,
  9,
)

# Part 23: Save useful nodes with behaviour and algorithm tags to JSON file

Show a list of the nodes that have proved useful in calculations, together with data on each node's behavior and algorithmic purposes.
Save the data to a Colab temporary JSON file.



In [None]:
# Print the ALGO tags of the useful nodes.
useful_info.print_node_tags(
  QuantaType.ALGO,
  False,
)

In [None]:
# Announce the target file name, then serialize and save the useful nodes list
# (with algorithm tags) to a temporary CoLab file in JSON format.
print(
  "Saving useful node list with algorithm tags:",
  main_fname_algorithm_json,
)
useful_info.save_nodes(main_fname_algorithm_json)

# Part 24: Test Addition Hypothesis

In [None]:
# Running totals of hypothesis clauses that failed / succeeded; updated by node_exists.
num_hypothesis_clause_failures = num_hypothesis_clause_successes = 0

In [None]:
# Check one hypothesis clause: does any node in the_nodes satisfy the_filters?
# Prints the outcome and increments the matching global success/failure counter.
def node_exists(the_nodes, the_filters):
  global num_hypothesis_clause_failures, num_hypothesis_clause_successes

  hits = filter_nodes( the_nodes, the_filters)

  if len(hits) == 0:
    print( "Clause failed:", the_filters.describe())
    num_hypothesis_clause_failures += 1
    return

  print( "Clause succeeded on ", len(hits), "node(s) including", hits[0].name(), ":", the_filters.describe())
  num_hypothesis_clause_successes += 1

In [None]:
# Test the hypothesised algorithm against the tagged nodes.
# Each node_exists call checks one clause, prints its outcome and updates the global
# success/failure counters, so the order of the calls fixes the order of the printed report.

# Get the model nodes with a known algorithmic purpose
model_nodes = filter_nodes( useful_info.nodes, FilterAlgo("", QuantaFilter.MAY))
assert len(model_nodes) > 0

# Addition clauses
if cfg.perc_add() > 0:
  for impact_digit in range(cfg.n_digits):
    # For every answer digit (except the first 1 or 0 answer digit), Dn.BA and Dn.MC values are calculated before the answer digit is revealed
    node_exists(model_nodes, FilterAnd(FilterAlgo(add_ba_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))
    node_exists(model_nodes, FilterAnd(FilterAlgo(add_mc_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))

    # Candidate clauses for this digit. All three are built on every iteration; which are checked depends on cfg.n_layers.
    # early_dnc is positioned relative to answer digit cfg.n_digits+1, the other two relative to impact_digit+1.
    early_dnc = FilterAnd(FilterAlgo(add_tc_tag(impact_digit)), FilterPosition(an_to_position_name(cfg.n_digits+1), QuantaFilter.MUST_BY))
    late_dnc = FilterAnd(FilterAlgo(add_tc_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY))
    any_dnus = FilterAnd(FilterAlgo(add_us_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY))

    if cfg.n_layers == 1:
      # There must be a Dn.US node for every answer digit except A0
      if impact_digit > 0:
        node_exists(model_nodes, any_dnus)
    else:
      # There must be a Dn.US node or a Dn.C node for every digit except A0
      if impact_digit > 0:
        node_exists(model_nodes, FilterOr(any_dnus, late_dnc))

      # There must be a Dn.C node for every digit before the first 1 or 0 digit is calculated
      node_exists(model_nodes, early_dnc)

# Subtraction clauses
# NOTE(review): perc_sub is read as an attribute here while perc_add is called as a method above - confirm perc_sub is not a method.
if cfg.perc_sub > 0:
  for impact_digit in range(cfg.n_digits):
    # For every answer digit (except the first 1 or 0 answer digit), Dn.BS values are calculated before the answer digit is revealed
    node_exists(model_nodes, FilterAnd(FilterAlgo(sub_bs_tag(impact_digit)), FilterPosition(an_to_position_name(impact_digit+1), QuantaFilter.MUST_BY)))


# Summarise the clause results accumulated by node_exists
print( "Overall", num_hypothesis_clause_successes, "out of", num_hypothesis_clause_successes + num_hypothesis_clause_failures, "clauses succeeded")

# Part 30: Unit Test automated searches

In [None]:
# Assert that the useful node at node_location_as_str (e.g. "P14L0H0") exists
# and carries every ALGO tag listed in the_tags.
def unit_test_node_tag(node_location_as_str, the_tags ):
  node = useful_info.get_node(str_to_node_location(node_location_as_str))
  assert node is not None

  assert all(node.contains_tag( QuantaType.ALGO, expected_tag) for expected_tag in the_tags)

In [None]:
print(cfg.model_name)

if cfg.model_name == "add_d6_l2_h3_t15K":
  # Expected ALGO tags per node for this model, checked in the order listed.
  for expected_node_name, expected_node_tags in (
    ("P11L0H0", ["D2.TC"]),
    ("P12L0H0", ["D3.TC"]),
    ("P14L0H0", ["A5.US", "D4.TC"]),
    ("P14L0H2", ["A5.MC", "D5.TC"]),
    ("P14L1H1", ["OP"]),
    ("P15L0H0", ["A4.MC"]),
    ("P15L0H1", ["A5.BA"]),
    ("P15L0H2", ["A5.BA"]),
    ("P16L0H0", ["A3.MC"]),
    ("P16L0H1", ["A4.BA"]),
    ("P16L0H2", ["A4.BA"]),
    ("P17L0H0", ["A2.MC"]),
    ("P17L0H1", ["A3.BA"]),
    ("P17L0H2", ["A3.BA"]),
    ("P18L0H0", ["A1.MC"]),
    ("P18L0H1", ["A2.BA"]),
    ("P18L0H2", ["A2.BA"]),
    ("P19L0H0", ["A0.MC"]),
    ("P19L0H1", ["A1.BA"]),
    ("P19L0H2", ["A1.BA"]),
    ("P20L0H1", ["A0.BA"]),
    ("P20L0H2", ["A0.BA"]),
  ):
    unit_test_node_tag(expected_node_name, expected_node_tags )


