# Accurate Integer Mathematics in Transformers - Analyse the Model

This Colab analyses a Transformer model that performs integer addition, subtraction and multiplication, e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+0283935. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens.

The model weights created by the sister Colab [Accurate_Math_Train](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/Accurate_Math_Train.ipynb) are loaded from HuggingFace.

The focus is on ins_sub_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K.pth, which was initialised with add_d6_l2_h3_dm510_dh170_ctx22_seed372001_train15K.pth before being trained. The `insert` setting in Part 1 controls whether the `ins_` model file is loaded.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 1: Configuration


In [853]:
# Token vocabulary. Token indexes 0 to 9 represent the digits 0 to 9;
# the operator and punctuation tokens follow them.
PLUS_INDEX, MINUS_INDEX, EQUALS_INDEX, MULT_INDEX, DIV_INDEX = 10, 11, 12, 13, 14
# Highest token index in use (the vocabulary size is MAX_INDEX + 1).
MAX_INDEX = DIV_INDEX

In [854]:
# Main configuration class for main model creation and training
class Config():
  """Settings for model architecture, data generation, training and output formatting.

  Values are class-level defaults. The `#@markdown` / `#@param` comments are Google
  Colab form annotations that expose some of them as editable fields.
  """
  #@markdown Main Model
  n_layers: int = 3 #@param
  n_heads: int = 4 #@param

  d_vocab: int = MAX_INDEX+1
  d_model: int = 510
  d_mlp: int = 4 * d_model
  d_head: int = 170
  seed: int = 372001 # 129000

  #@markdown Data
  n_digits: int = 6 #@param
  # Two n_digits operands + operator + "=" + answer sign + (n_digits+1) answer digits.
  n_ctx: int = 3 * n_digits + 4
  act_fn: str = 'relu'
  batch_size: int = 64

  #@markdown Optimizer
  n_training_steps: int = 20000 #@param
  weight_decay: float = 0.00008
  lr: float = 0.1

  #@markdown Actions

  # Percent of questions that are multiplication, subtraction (rest are addition questions).
  perc_mult: int = 0 #@param e.g. 20
  perc_sub: int = 100 #@param e.g. 80
  def perc_add(self):
    """Return the percentage of addition questions: whatever mult and sub leave over, floored at 0."""
    return max(0, 100 - self.perc_mult - self.perc_sub)

  #@markdown Insert Model
  insert: bool = False #@param   Was the model trained using an inserted "known good" model


  # Save graphs to CoLab temp files as PDF and HTML. Can manually export files for re-use in papers.
  save_graph_to_file: bool = True

  # The format to output prettytable in. Options are text|html|json|csv|latex
  # Use Text for this CoLab, latex for Overleaf output, and html for GitHub blog output
  table_out_format: str = "text"


cfg = Config()

In [855]:
def file_name_suffix(digits, layers, heads, d_model, d_head, ctx, seed, training_steps):
  """Return the checkpoint-name suffix, e.g. "_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K".

  training_steps is written in whole thousands (rounded down) followed by "K".
  """
  thousands = training_steps // 1000
  return f'_d{digits}_l{layers}_h{heads}_dm{d_model}_dh{d_head}_ctx{ctx}_seed{seed}_train{thousands}K'

# Checkpoint file name: optional "ins_" prefix, the operator mix, then the architecture/training suffix.
fname_prefix = ('ins_' if cfg.insert else '') + (
    'mul' if cfg.perc_mult == 100 else
    'sub' if cfg.perc_sub == 100 else
    'add' if cfg.perc_add() == 100 else
    'mix')
main_fname_suffix = fname_prefix + file_name_suffix(
    cfg.n_digits, cfg.n_layers, cfg.n_heads, cfg.d_model, cfg.d_head,
    cfg.n_ctx, cfg.seed, cfg.n_training_steps)
main_fname_full = f'{main_fname_suffix}.pth'


def print_config():
  """Print the operator percentages and the model file-name stem."""
  print("Config: %Mult=", cfg.perc_mult, "%Sub=", cfg.perc_sub, "%Add=", cfg.perc_add(), main_fname_suffix)

print_config()

Config: %Mult= 0 %Sub= 100 %Add= 0 sub_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K


# Part 2A: Import libraries
Imports standard libraries. Don't bother reading.

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install numpy --upgrade
    !pip install scikit-learn --upgrade
    !pip install matplotlib
    !pip install prettytable

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Colab needs its own renderer; local notebooks use the connected notebook renderer.
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Enlarge the axis-title and plot-title fonts of the default "plotly" template.
for _axis_name in ('xaxis', 'yaxis'):
    getattr(pio.templates['plotly'].layout, _axis_name).title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30
del _axis_name

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from prettytable import PrettyTable

In [None]:
import matplotlib.pyplot as plt

# Principal Component Analysis (PCA) is optional: if scikit-learn cannot be imported
# for any reason, report it and record that PCA is unavailable.
try:
  from sklearn.decomposition import PCA
except Exception as pca_import_error:
  print("pca import exception:", pca_import_error)
  use_pca = False
else:
  use_pca = True

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


 # Part 2B: Utilities

In [None]:
def get_row_heading(i):
  """Return the label for flat row index i.

  Each layer contributes cfg.n_heads attention-head rows ("L{layer}H{head}")
  followed by one MLP row ("MLP ").
  """
  layer, head = divmod(i, cfg.n_heads + 1)
  if head >= cfg.n_heads:
    return "MLP "
  return f"L{layer}H{head}"


def show_2d_map_start(width_inches=12,height_inches=6):
  """Create and return (fig, ax) of the given size in inches, with square cells for a 2D map."""
  fig, ax = plt.subplots(figsize=(width_inches, height_inches))
  ax.set_aspect('equal', adjustable='box')  # keep map cells square
  return fig, ax


def show_2d_map_end(title, fig, ax, min_col, max_col):
  """Label, title, optionally save (PDF), and display a 2D map created by show_2d_map_start.

  Rows are labelled with get_row_heading, with row 0 at the top. Columns are
  labelled P{min_col} .. P{max_col}. The plot title (and saved file name) combines
  fname_prefix, `title` and the digit/layer/head configuration.
  """
  # table_rows() is defined outside this section; assumed to give the number of map rows.
  num_rows = table_rows()
  num_cols = max_col - min_col + 1
  # Tick labels run bottom-up, so list the highest row index first to put row 0 on top.
  vertical_labels = [get_row_heading(num_rows-i-1) for i in range(num_rows)]
  horizontal_labels = [f"P{min_col+i}" for i in range(num_cols)]

  # Set axis limits
  ax.set_xlim(0, num_cols)
  ax.set_ylim(0, num_rows)

  # Centre the tick labels within each cell
  ax.set_xticks(np.arange(0.5, num_cols, 1))
  ax.set_yticks(np.arange(0.5, num_rows, 1))
  ax.set_xticklabels(horizontal_labels)
  ax.set_yticklabels(vertical_labels)

  # Move the x-axis to the top
  ax.xaxis.tick_top()
  ax.xaxis.set_label_position('top')

  filename = fname_prefix + '_' + title + '_d{}_l{}_h{}'.format(cfg.n_digits, cfg.n_layers, cfg.n_heads)
  ax.set_title(filename)

  if cfg.save_graph_to_file:
    # Save the figure we were given, not whichever figure pyplot considers "current".
    fig.savefig(filename+".pdf")

  # Show plot
  plt.show()

# Part 3: Create model
This section defines the token embedding / unembedding and creates the model.

In [None]:
# Embedding / Unembedding

def token_to_char(i):
  """Return the display character for token index i ("?" for unknown indexes).

  Indexes below 10 are digits and are rendered with str(i).
  """
  if i < 10:
    return str(i)
  symbol_table = (
    (PLUS_INDEX, "+"),
    (MINUS_INDEX, "-"),
    (EQUALS_INDEX, "="),
    (MULT_INDEX, "*"),
    (DIV_INDEX, "\\"),
  )
  for token_index, symbol in symbol_table:
    if i == token_index:
      return symbol
  return "?"

def tokens_to_string(tokens):
    """Render a token sequence as a string, keeping at most cfg.n_ctx tokens."""
    as_array = utils.to_numpy(tokens)
    return "".join(map(token_to_char, as_array[:cfg.n_ctx]))

def string_to_tokens(string, batch: bool=False):
    """Convert a question string (e.g. "123+456=") into a tensor of token indexes.

    Spaces and newlines are ignored; any other unknown character raises KeyError.
    With batch=True the result gets a leading batch dimension of size 1.
    """
    char_to_token = {str(digit): digit for digit in range(10)}
    char_to_token.update({
        '+': PLUS_INDEX,
        '-': MINUS_INDEX,
        '=': EQUALS_INDEX,
        '*': MULT_INDEX,
        '\\': DIV_INDEX,
    })

    token_list = [char_to_token[ch] for ch in string if ch not in '\n ']
    result = torch.tensor(token_list)
    return result[None, :] if batch else result

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    n_layers=cfg.n_layers,
    n_heads=cfg.n_heads,
    d_model=cfg.d_model,
    d_head=cfg.d_head,
    d_mlp=cfg.d_mlp,
    act_fn=cfg.act_fn,
    normalization_type='LN',
    d_vocab=cfg.d_vocab,
    d_vocab_out=cfg.d_vocab,
    n_ctx=cfg.n_ctx,
    init_weights=True,
    device="cuda",
    seed=cfg.seed,
)
main_model = HookedTransformer(ht_cfg)

# AdamW, with a linear warm-up over the first fifth of the steps followed by cosine annealing.
optimizer = torch.optim.AdamW(main_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.98))

max_iter = cfg.n_training_steps
warmup_iter = max_iter // 5
scheduler1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=int(warmup_iter))
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(np.ceil(max_iter - warmup_iter)))
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[int(warmup_iter)])

# Part 4: Loss Function & Data Generator
This section defines the loss function and the training/testing data generator.

In [None]:
# Loss functions

# Score a batch of prediction "logits" against the answer "tokens", one value per answer token.
def logits_to_tokens_loss(logits, tokens):
  """Return per-answer-token log-probabilities of the correct tokens, and the predicted tokens.

  The answer is the last cfg.n_digits+2 tokens: a +/- sign, then n_digits+1 digits
  (addition can produce one more digit than the operands). Logits at position p
  predict token p+1, so the scored logits are shifted one position earlier than the
  answer tokens. Log-softmax is computed in float64.

  Returns:
    (correct_token_log_probs, predicted_tokens), both of shape
    (batch, cfg.n_digits+2). The second is the argmax token at each answer position.
  """
  n_answer_tokens = cfg.n_digits + 2

  answer_logits = logits[:, -(n_answer_tokens + 1):-1]
  answer_log_probs = F.log_softmax(answer_logits.to(torch.float64), dim=-1)
  predicted_tokens = answer_log_probs.argmax(dim=-1)

  answer_tokens = tokens[:, -n_answer_tokens:]
  correct_token_log_probs = answer_log_probs.gather(-1, answer_tokens.unsqueeze(-1)).squeeze(-1)

  return correct_token_log_probs, predicted_tokens


# Loss = negative of the mean per-token log-probability (a negative log-likelihood).
def loss_fn(ans_loss):
  """Return the negated mean of ans_loss over dimension 0."""
  return ans_loss.mean(dim=0).neg()

In [None]:
# Define "iterator" data generator function. Invoked using next().
# "Addition" batch entries are formated XXXXXX+YYYYYY=+ZZZZZZZ e.g. 550030+800020=+1350050
# "Subtraction" batch entries are formated XXXXXX-YYYYYY=-ZZZZZZZ e.g. 550030-800020=-0249990, 800020-550030=+0249990
# "Multiplication" batch entries are formated 000XXX*000YYY=+ZZZZZZZ e.g. 000345*000678=+0233910
def data_generator( ):
    """Yield an endless stream of tokenised question+answer batches on the GPU.

    Each batch is an int64 tensor of shape (cfg.batch_size, cfg.n_ctx). All rows in a
    batch share one operator. The operator is picked per batch so that about
    cfg.perc_mult % of batches are multiplication, cfg.perc_sub % subtraction and the
    rest addition.

    The torch global RNG is seeded with cfg.seed when the first batch is requested.
    The Python-level choices use a private random.Random(cfg.seed). So the sequence
    of batches is the same on every run.
    """
    torch.manual_seed(cfg.seed)
    # Private, seeded RNG for the operator choice and enrichment decisions. The
    # unseeded global `random` module made the batch stream differ between runs.
    rng = random.Random(cfg.seed)
    while True:
        # Generate a batch of random question digits (answers calculated below)
        batch = torch.zeros((cfg.batch_size, cfg.n_ctx)).to(torch.int64)
        x = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))
        y = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))

        batch_rand = rng.randint(1, 100)
        batch_op = MULT_INDEX if batch_rand <= cfg.perc_mult else MINUS_INDEX if batch_rand <= cfg.perc_mult + cfg.perc_sub else PLUS_INDEX

        if batch_op == MULT_INDEX:
            # Convert from NNNNNN*NNNNNN= to 000NNN*000NNN= so the product fits in the answer digits
            num_zeros = cfg.n_digits // 2
            x[:, :num_zeros] = 0
            y[:, :num_zeros] = 0

        # Enrich the question data on 40% of addition/subtraction batches to speed up training
        if batch_op in (PLUS_INDEX, MINUS_INDEX) and rng.randint(1, 5) < 3:
            # Flatten x and y to 1D views (the in-place edits below also change x and y)
            x_flat = x.view(-1)
            y_flat = y.view(-1)

            if batch_op == PLUS_INDEX:
                # The UseSum9 task is compound and rare and so hard to learn.
                # Increase the MakeSum9 case frequency
                # UseSum9 also relies on MakeCarry1 (50%) from previous column.
                num_elements_to_modify = int(0.40 * x.numel()) # 40%
                indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
                if rng.randint(1, 2) == 1:
                    x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
                else:
                    y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]
            else:
                # Empirically, the model seems to struggle with the sign calculation.
                # Minus signs are rarer than positive signs.
                # Generate more negative answers by increasing the y value
                y_flat[y_flat < 9] += 1

            # Reshape x and y back to their original shape
            x = x_flat.view(x.shape)
            y = y_flat.view(y.shape)

        batch[:, :cfg.n_digits] = x
        batch[:, cfg.n_digits] = batch_op
        batch[:, 1+cfg.n_digits:1+cfg.n_digits*2] = y
        batch[:, 1+cfg.n_digits*2] = EQUALS_INDEX

        # Convert each row's digits into an n_digits-digit integer
        x_values = x[:, 0]
        y_values = y[:, 0]
        for dn in range(1, cfg.n_digits):
            x_values = x_values * 10 + x[:, dn]
            y_values = y_values * 10 + y[:, dn]

        # Elementwise operations to give the 1D tensor answers
        if batch_op == MULT_INDEX:
            answers = x_values * y_values
        elif batch_op == MINUS_INDEX:
            answers = x_values - y_values
        else:
            answers = x_values + y_values

        # Insert the answer sign, then the answer magnitude right-aligned and zero-padded
        # into the last n_digits+1 positions (vectorised over the whole batch).
        signs = torch.full_like(answers, PLUS_INDEX)
        signs[answers < 0] = MINUS_INDEX
        batch[:, 2+cfg.n_digits*2] = signs
        magnitudes = answers.abs()
        for j in range(cfg.n_digits+1):
            batch[:, cfg.n_ctx-j-1] = magnitudes % 10
            magnitudes = magnitudes // 10

        yield batch.cuda()

In [None]:
# Initialise the data generator. As a Python generator, data_generator's body does not
# run until the first next(ds) call, which is when it seeds torch's RNG from cfg.seed.
ds = data_generator()

In [None]:
# Test data generator: draw one batch and show its first three token sequences.
tokens = next(ds)
print(tokens[:3])

# Part 5: Load Model from HuggingFace

In [None]:
print("Loading model", main_fname_full)

# Fetch the trained weights for this configuration from the HuggingFace repo and load them.
# NOTE(review): force_is_torch presumably deserialises with torch.load (pickle);
# only load checkpoints from a trusted repository.
main_model.load_state_dict(
    utils.download_file_from_hf(
        repo_name="PhilipQuirke/Accurate6DigitSubtraction",
        file_name=main_fname_full,
        force_is_torch=True,
    )
)

# Switch to inference mode for the analysis below.
main_model.eval()

# Part 8: Sample Questions Set Up

Create sets of sample questions exercising different use cases.

In [None]:
# Insert a number into the question
def insert_question_number(the_question, index, first_digit_index, the_digits, n):
  """Write n into row `index` as `the_digits` decimal digits starting at column first_digit_index.

  Digits are right-aligned and zero-padded. Higher-order digits that do not fit are dropped.
  """
  remaining = n
  for column in range(first_digit_index + the_digits - 1, first_digit_index - 1, -1):
    the_question[index, column] = remaining % 10
    remaining = remaining // 10


# Create a single question
def make_a_question(the_question, index, q1, q2, operator ):
  """Fill row `index` of the_question with the tokens of "q1 <operator> q2 = <signed answer>".

  The answer is q1-q2 for MINUS_INDEX, q1*q2 for MULT_INDEX, and q1+q2 otherwise. It is
  written as a sign token followed by cfg.n_digits+1 digits.
  """
  n = cfg.n_digits

  insert_question_number(the_question, index, 0, n, q1)
  the_question[index, n] = operator
  insert_question_number(the_question, index, n+1, n, q2)
  the_question[index, 2*n+1] = EQUALS_INDEX

  if operator == MINUS_INDEX:
    answer = q1 - q2
  elif operator == MULT_INDEX:
    answer = q1 * q2
  else:
    answer = q1 + q2

  the_question[index, 2*n+2] = MINUS_INDEX if answer < 0 else PLUS_INDEX
  insert_question_number(the_question, index, 2*n + 3, n+1, abs(answer))


# Create a batch of questions from a 2D matrix of ints
def make_questions(q_matrix, operator):
  """Build tokenised questions (with answers) from a list of [a, b] operand pairs.

  Pairs where either operand has more than cfg.n_digits digits are skipped.
  Returns an int64 tensor of shape (number_of_kept_pairs, cfg.n_ctx).
  """
  questions = torch.zeros((len(q_matrix), cfg.n_ctx)).to(torch.int64)

  limit = 10 ** cfg.n_digits
  real_len = 0
  for row in q_matrix:
    a = row[0]
    b = row[1]

    if a < limit and b < limit:
      # Pack kept questions contiguously. Then a skipped pair can never leave an
      # all-zero row inside the returned slice, or push valid questions off its end.
      make_a_question(questions, real_len, a, b, operator)
      real_len += 1

  return questions[:real_len]

In [None]:
# Analyse the question and return its use case: addition S0, S1, S2 or S3+;
# subtraction M0, M1, M2 or M3+; "MUL" for multiplication; "OP?" otherwise.
def get_question_case(question):
  """Classify a tokenised question by the carry / borrow pattern it exercises.

  Columns are indexed from the units column (index 0) upwards.

  Addition (MC: digit pair sums to more than 9; MS: digit pair sums to exactly 9):
    "S0"  no column carries
    "S1"  some column carries, but no carry lands on an MS column
    "S2"  a carry lands on an MS column (cascades one column)
    "S3+" a carry lands on an MS column that is followed by another MS column
  Subtraction (B1: top digit < bottom digit; MZ: the two digits are equal):
    "M0"  no column borrows
    "M1"  some column borrows, but no borrow lands on an MZ column
    "M2"  a borrow lands on an MZ column (cascades one column)
    "M3+" a borrow lands on an MZ column that is followed by another MZ column
  """
  qlist = utils.to_numpy(question)
  inputs = qlist[:2*cfg.n_digits+2]
  operator = qlist[cfg.n_digits]

  if operator == PLUS_INDEX:

    # Locate the MC and MS digits (if any)
    mc = torch.zeros(cfg.n_digits).to(torch.int64)
    ms = torch.zeros(cfg.n_digits).to(torch.int64)
    for dn in range(cfg.n_digits):
      digit_sum = inputs[dn] + inputs[dn + cfg.n_digits + 1]
      if digit_sum == 9:
        ms[cfg.n_digits-1-dn] = 1
      if digit_sum > 9:
        mc[cfg.n_digits-1-dn] = 1

    # Calculate the use case of a question
    if torch.sum(mc) == 0:
      return "S0"

    if torch.sum(ms) == 0:
      return "S1"

    for dn in range(cfg.n_digits-2):
      if mc[dn] == 1 and ms[dn+1] == 1 and ms[dn+2] == 1:
        return "S3+" # MC cascades 2 or more digits

    for dn in range(cfg.n_digits-1):
      if mc[dn] == 1 and ms[dn+1] == 1:
        return "S2" # Simple US 9

    return "S1"


  if operator == MINUS_INDEX:

    # Locate the B1 and MZ digits (if any). MZ is defined on the digit DIFFERENCE
    # (equal digits), not the sum: any equal pair propagates an incoming borrow.
    b1 = torch.zeros(cfg.n_digits).to(torch.int64)
    mz = torch.zeros(cfg.n_digits).to(torch.int64)
    for dn in range(cfg.n_digits):
      digit_diff = inputs[dn] - inputs[dn + cfg.n_digits + 1]
      if digit_diff < 0:
        b1[cfg.n_digits-1-dn] = 1
      if digit_diff == 0:
        mz[cfg.n_digits-1-dn] = 1

    # Evaluate BaseSub questions - when no column generates a Borrow One
    if torch.sum(b1) == 0:
      return "M0"

    # Evaluate subtraction "cascade multiple steps" questions
    for dn in range(cfg.n_digits-2):
      if b1[dn] == 1 and mz[dn+1] == 1 and mz[dn+2] == 1:
        return "M3+" # B1 cascades 2 or more digits

    # Evaluate subtraction "cascade 1" questions
    for dn in range(cfg.n_digits-1):
      if b1[dn] == 1 and mz[dn+1] == 1:
        return "M2" # B1 cascades 1 digit

    return "M1"


  if operator == MULT_INDEX:
    return "MUL"


  return "OP?"

In [None]:
# Manually create some questions that strongly test one use case


# Make S0 (BaseAdd) addition questions: intended to have no column that generates a carry.
def make_s0_questions():
    """Return tokenised addition questions intended to exercise the S0 (no carry) case.

    make_questions skips pairs whose operands have more than cfg.n_digits digits,
    so the longer pairs at the end only apply to wider configurations.
    """
    return make_questions(
      [[0, 0],
      [12345, 33333],
      [33333, 12345],
      [45762, 33113],
      [888, 11111],
      [2362, 23123],
      [15, 81],
      [1000, 4440],
      [4440, 1000],
      [24033, 25133],
      [23533, 21133],
      [32500, 1],
      [31500, 1111],
      [5500, 12323],
      [4500, 2209],
      [33345, 66643], # =099988
      [66643, 33345], # =099988
      [10990, 44000],
      [60000, 30000],
      [10000, 20000],
      [109900, 440000],
      [600000, 300000],
      [100000, 200000],
      [1099000, 4400000],
      [6000000, 3000000],
      [1000000, 2000000],
      [10990000, 44000000],
      [60000000, 3000000],
      [10000000, 20000000]],
      PLUS_INDEX)


# Make S1 / UseCarry1 (addition) questions: carries occur, but no carry cascades through a sum-9 column.
def make_s1_questions():
    """Return tokenised addition questions intended to exercise the S1 (simple carry) case.

    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    """
    return make_questions(
      [[ 15, 45],
      [ 27, 55],
      [ 35, 59],
      [ 150, 450],
      [ 270, 550],
      [ 350, 590],
      [ 1500, 4500],
      [ 2700, 5500],
      [ 3500, 5900],
      [ 40035, 40049],
      # [ 44000, 46000], D6 L1 H3 model cant handle this.
      [ 70000, 40000],
      [ 15000, 25000],
      [ 35000, 35000],
      [ 45000, 35000],
      [ 67000, 25000],
      [ 19000, 76000],
      [ 15020, 45091],
      [ 25002, 55019],
      [ 35002, 59019],
      [ 150200, 450910],
      [ 250020, 550190],
      [ 350020, 590190],
      [ 1502000, 4509100],
      [ 2500200, 5501900],
      [ 3500200, 5901900],
      [ 15020000, 45091000],
      [ 25002000, 55019000],
      [ 35002000, 59019000]],
      PLUS_INDEX)


# Make S2 / SimpleUseSum9 (addition) questions: a carry lands on a column whose digits sum to 9.
def make_s2_questions():
    """Return tokenised addition questions intended to exercise the S2 (one-column cascade) case.

    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    """
    return make_questions(
      [[ 55, 45],
      [ 45, 55],
      [ 45, 59],
      [ 35, 69],
      [ 25, 79],
      [ 15, 85],
      [ 15, 88],
      [ 15508, 14500],
      [ 14508, 15500],
      [ 24533, 25933],
      [ 23533, 26933],
      [ 32500, 7900],
      [ 31500, 8500],
      [ 550, 450],
      [ 450, 550],
      [ 10880, 41127],
      [ 41127, 10880],
      [ 12386, 82623],
      [ 108800, 411270],
      [ 411270, 108800],
      [ 123860, 826230],
      [ 1088000, 4112700],
      [ 4112700, 1088000],
      [ 1238600, 8262300],
      [ 10880000, 41127000],
      [ 41127000, 10880000],
      [ 12386000, 82623000]],
      PLUS_INDEX)


# Make S3+ / CascadeUseSum9 (addition) questions: a carry cascades through two or more sum-9 columns.
def make_s3plus_questions():
    """Return tokenised addition questions intended to exercise multi-column carry cascades.

    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    NOTE(review): [3340, 6660] and [8880, 1120] appear in both the "two level" and
    "three level" groups below.
    """
    return make_questions(
      # These are two level UseSum9 cascades
      [[ 555, 445],
      [ 3340, 6660],
      [ 8880, 1120],
      [ 1120, 8880],
      [ 123, 877],
      [ 877, 123],
      [ 321, 679],
      [ 679, 321],
      [ 1283, 88786],
      # These are three level UseSum9 cascades
      [ 5555, 4445],
      [ 55550, 44450],
      [ 334, 666],
      [ 3340, 6660],
      [ 33400, 66600],
      [ 888, 112],
      [ 8880, 1120],
      [ 88800, 11200],
      [ 1234, 8766],
      [ 4321, 5679],
      # These are four level UseSum9 cascades
      [ 44445, 55555],
      [ 33334, 66666],
      [ 88888, 11112],
      [ 12345, 87655],
      [ 54321, 45679],
      [ 45545, 54455],
      [ 36634, 63366],
      [ 81818, 18182],
      [ 87345, 12655],
      [ 55379, 44621],
      # These are five level UseSum9 cascades
      [ 818818, 181182],
      [ 873345, 126655],
      [ 553379, 446621],
      # These are six level UseSum9 cascades
      [ 8188818, 1811182],
      [ 8733345, 1266655],
      [ 5533379, 4466621],
      # These are seven level UseSum9 cascades
      [ 81888818, 18111182],
      [ 87333345, 12666655],
      [ 55333379, 44626661]],
      PLUS_INDEX)


# Make questions that focus mainly on 1 digit at a time
# (assuming that the 0 + 0 digit additions/subtractions are trivial bigrams)
def make_sn_questions():
    """Return tokenised addition questions whose non-zero digits sit in a few columns only.

    Together they cover different answer digit positions. make_questions skips pairs
    whose operands have more than cfg.n_digits digits.
    """
    return make_questions(
      [[ 1, 0],
      [ 4, 3],
      [ 5, 5],
      [ 8, 1],
      [ 40, 30],
      [ 44, 46],
      [ 400, 300],
      [ 440, 460],
      [ 800, 100],
      [ 270, 470],
      [ 600, 300],
      [ 4000, 3000],
      [ 4400, 4600],
      [ 6000, 3000],
      [ 7000, 4000],
      [ 40000, 30000],
      [ 44000, 45000],
      [ 60000, 30000],
      [ 70000, 40000],
      [ 10000, 20000],
      [ 15000, 25000],
      [ 35000, 35000],
      [ 45000, 85000],
      [ 67000, 85000],
      [ 99000, 76000],
      [ 76000, 99000],
      [ 670000, 850000],
      [ 990000, 760000],
      [ 760000, 990000],
      [ 6700000, 8500000],
      [ 9900000, 7600000],
      [ 7600000, 9900000],
      [ 67000000, 85000000],
      [ 99000000, 76000000],
      [ 76000000, 99000000]],
      PLUS_INDEX)


# Make M0 questions - when no column generates a Borrow One. Answer is always positive (or zero).
def make_m0_questions():
    """Return tokenised subtraction questions intended to exercise the M0 (no borrow) case.

    make_questions skips pairs whose operands have more than cfg.n_digits digits. Note
    that some such pairs (e.g. [6000000, 6000000]) sit in the middle of this list, not
    only at the end.
    """
    return make_questions(
      [[0, 0],
      [6, 6],
      [60, 60],
      [600, 600],
      [6000, 6000],
      [60000, 60000],
      [600000, 600000],
      [6000000, 6000000],
      [60000000, 60000000],
      [66666, 12345],
      [33333, 12321],
      [45762, 34551],
      [78901, 78901], # = +000000
      [23123, 23123], # = +000000
      [86, 15],
      [4440, 1230],
      [88746, 86544],
      [27833, 25133],
      [23533, 21133],
      [32501, 1],
      [31511, 1111],
      [55555, 12323],
      [45454, 22022],
      [66643, 3341],
      [66643, 30042],
      [99999, 44012],
      [60000, 30000],
      [99000, 99000], # = +000000
      [999990, 440120],
      [600000, 300000],
      [990000, 990000], # = +0000000
      [9999900, 4401200],
      [6000000, 3000000],
      [9900000, 9900000], # = +00000000
      [99999000, 44012000],
      [60000000, 30000000],
      [99000000, 99000000]], # = +000000000
      MINUS_INDEX)

# Make subtraction M1 questions with exactly one "borrow 1" instance. Answer is always positive.
def make_m1_questions():
    """Return tokenised subtraction questions intended to exercise the M1 (single borrow, no cascade) case.

    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    NOTE(review): [14000, 8000] appears twice.
    """
    return make_questions(
      [[22222, 11113],
      [ 22222, 11131],
      [ 22222, 11311],
      [ 22222, 13111],
      [    14,     8],
      [   140,    80],
      [  1400,   800],
      [ 14000,  8000],
      [ 55514, 11108],
      [ 55140, 11080],
      [ 51400, 10800],
      [ 14000,  8000],
      [ 88888, 22229],
      [ 77777, 22292],
      [ 66666, 22922],
      [ 888888, 222292],
      [ 777777, 222922],
      [ 666666, 229222],
      [ 8888888, 2222922],
      [ 7777777, 2229222],
      [ 6666666, 2292222],
      [ 88888888, 22229222],
      [ 77777777, 22292222],
      [ 66666666, 22922222]],
      MINUS_INDEX)

# Make subtraction M2 questions containing B1 and DZ. Answer is always positive (or zero).
def make_m2_questions():
    """Return tokenised subtraction questions intended to exercise the M2 case.

    In M2, a borrow lands on one column whose two digits are equal.
    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    """
    return make_questions(
      [[22212, 11113],
      [ 22122, 11131],
      [ 21222, 11311],
      [   904,     8],
      [  9040,    80],
      [ 90400,   800],
      [ 55514, 11118],
      [ 55140, 11180],
      [ 51400, 11800],
      [ 88888, 22289],
      [ 77777, 22892],
      [ 66666, 28922],
      [ 888888, 222892],
      [ 777777, 228922],
      [ 666666, 289222],
      [ 8888888, 2228922],
      [ 7777777, 2289222],
      [ 6666666, 2892222],
      [ 88888888, 22289222],
      [ 77777777, 22892222],
      [ 66666666, 28922222]],
      MINUS_INDEX)

# Make subtraction M3,M4,... questions containing B1 and multiple DZs. Answer is always positive (or zero).
def make_m3plus_questions():
    """Return tokenised subtraction questions intended to exercise multi-column borrow cascades (M3+).

    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    """
    return make_questions(
      [[21112, 11113],
      [ 21122, 11131],
      [ 99004,     8],
      [ 90040,    80],
      [ 55114, 11118],
      [ 51140, 11180],
      [ 88888, 22889],
      [ 77777, 28892],
      [ 888888, 228892],
      [ 777777, 288922],
      [ 8888888, 2288922],
      [ 7777777, 2889222],
      [ 88888888, 22889222],
      [ 77777777, 28892222]],
      MINUS_INDEX)

# Make subtraction questions with negative answers
def make_ng_questions():
    """Return tokenised subtraction questions where the second operand is larger (negative answers).

    make_questions skips pairs whose operands have more than cfg.n_digits digits.
    """
    return make_questions(
      [[ 0, 1],
      [ 88888, 88889],
      [ 55555, 55556],
      [ 88881, 88891],
      [ 55551, 55561],
      [ 88811, 88911],
      [ 55511, 55611],
      [ 8, 12],
      [ 40, 232],
      [ 44, 523],
      [ 234, 334],
      [ 7777, 8434],
      [ 88888, 92222],
      [ 77777, 84340],
      [ 888888, 922220],
      [ 777777, 843400],
      [ 8888888, 9222200],
      [ 7777777, 8434000],
      [ 88888888, 92222000],
      [ 77777777, 84340000]],
      MINUS_INDEX)


def make_addition_questions():
  """Return every hand-crafted addition question set (S0, S1, S2, S3+, SN), stacked on the GPU."""
  question_sets = (
    make_s0_questions(),
    make_s1_questions(),
    make_s2_questions(),
    make_s3plus_questions(),
    make_sn_questions(),
  )
  return torch.vstack([qs.cuda() for qs in question_sets])


def make_subtraction_questions():
  """Return every hand-crafted subtraction question set (M0, M1, M2, M3+, negative), stacked on the GPU."""
  question_sets = (
    make_m0_questions(),
    make_m1_questions(),
    make_m2_questions(),
    make_m3plus_questions(),
    make_ng_questions(),
  )
  return torch.vstack([qs.cuda() for qs in question_sets])


# Two random batches from the data generator (each could be Add, Sub or Mult), drawn in order v0 then v1.
v0, v1 = next(ds), next(ds)


# Returns random and manually-chosen questions
def make_varied_questions():
  """Return random batches v0 and v1 wrapped around the hand-crafted questions for the configured operators.

  Multiplication-only configurations use just v0 and v1. Otherwise the addition (S*)
  and/or subtraction (M*) sets sit between v0 and v1. When both operators are in use
  they are interleaved pairwise (S0, M0, S1, M1, ...).
  """
  if cfg.perc_mult == 100 :
    parts = [v0, v1]
  else:
    addition_sets = [make_s0_questions(), make_s1_questions(), make_s2_questions(), make_s3plus_questions(), make_sn_questions()]
    subtraction_sets = [make_m0_questions(), make_m1_questions(), make_m2_questions(), make_m3plus_questions(), make_ng_questions()]

    if cfg.perc_sub == 0 :
      middle = addition_sets
    elif cfg.perc_sub == 100 :
      middle = subtraction_sets
    else:
      middle = [qs for pair in zip(addition_sets, subtraction_sets) for qs in pair]

    parts = [v0, *middle, v1]

  return torch.vstack([part.cuda() for part in parts])

In [None]:
# Build a test batch of random and manually-chosen questions
varied_questions = make_varied_questions()


# Run the sample batch, gather the cache
main_model.reset_hooks()
main_model.set_use_attn_result(True)
sample_logits, sample_cache = main_model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_prob_tokens = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean) # Loss < 0.04 is good


# attn.hook_z is the "attention head output" hook point name, one per layer (follows cfg.n_layers)
l_attn_hook_z_name = [utils.get_act_name('z', layer, 'a') for layer in range(cfg.n_layers)] # 'blocks.0.attn.hook_z' etc
sample_attn_z = sample_cache[l_attn_hook_z_name[0]]
print("Sample", l_attn_hook_z_name[0], sample_attn_z.shape) # [num_questions, cfg.n_ctx, cfg.n_heads, cfg.d_head]
mean_attn_z = torch.mean(sample_attn_z, dim=0, keepdim=True)
print("Mean", l_attn_hook_z_name[0], mean_attn_z.shape) # [1, cfg.n_ctx, cfg.n_heads, cfg.d_head]


# hook_resid_pre is the "pre residual memory update" hook point name, one per layer
l_hook_resid_pre_name = [f'blocks.{layer}.hook_resid_pre' for layer in range(cfg.n_layers)]


# hook_resid_post is the "post residual memory update" hook point name, one per layer
l_hook_resid_post_name = [f'blocks.{layer}.hook_resid_post' for layer in range(cfg.n_layers)]
sample_resid_post = sample_cache[l_hook_resid_post_name[0]]
print("Sample", l_hook_resid_post_name[0], sample_resid_post.shape) # [num_questions, cfg.n_ctx, cfg.d_model]
mean_resid_post = torch.mean(sample_resid_post, dim=0, keepdim=True)
print("Mean", l_hook_resid_post_name[0], mean_resid_post.shape) # [1, cfg.n_ctx, cfg.d_model]


# mlp.hook_post is the "MLP layer" hook point name, one per layer
l_mlp_hook_post_name = [utils.get_act_name('post', layer) for layer in range(cfg.n_layers)] # 'blocks.0.mlp.hook_post' etc
sample_mlp_hook_post = sample_cache[l_mlp_hook_post_name[0]]
print("Sample", l_mlp_hook_post_name[0], sample_mlp_hook_post.shape) # [num_questions, cfg.n_ctx, cfg.d_mlp]
mean_mlp_hook_post = torch.mean(sample_mlp_hook_post, dim=0, keepdim=True)
print("Mean", l_mlp_hook_post_name[0], mean_mlp_hook_post.shape) # [1, cfg.n_ctx, cfg.d_mlp]

In [None]:
def tokens_to_unsigned_int( q, offset, digits ):
  """Read q[offset] .. q[offset+digits-1] as decimal digits (most significant first) and return the integer."""
  value = 0
  for position in range(offset, offset + digits):
    value = value * 10 + q[position]
  return value


def tokens_to_signed_int( q, offset, digits ):
  """Read a sign token at q[offset] followed by `digits` digit tokens.

  The value is negated when the sign token is MINUS_INDEX.
  """
  magnitude = tokens_to_unsigned_int( q, offset+1, digits )
  return -magnitude if q[offset] == MINUS_INDEX else magnitude

In [None]:
verbose = True

class T_Config():
  """Accumulates question-test results and formats them as a PrettyTable.

  The per-case counters (num_questions, correct_answers, total_mean_loss) are filled
  in by do_questions. The sum_* fields accumulate across cases for the OVERALL row.
  """
  num_questions : int
  correct_answers : int
  total_mean_loss : float

  sum_num_questions : int
  sum_correct_answers : int

  output = PrettyTable()


  def reset(self):
    """Zero all counters and start a new, empty results table."""
    self.num_questions = 0
    self.correct_answers = 0
    self.total_mean_loss = 0.0
    self.sum_num_questions = 0
    self.sum_correct_answers = 0

    self.output = PrettyTable()
    self.output.field_names = ["Case", "#Questions", "#Correct", "%Correct", "Mean loss"]


  # Clear the question summary results
  def clear_questions_results(self, title):
    """Zero the per-case counters and, when the module-level `verbose` flag is set, print the title."""
    self.num_questions = 0
    self.correct_answers = 0
    self.total_mean_loss = 0.0

    if verbose:
      print(title)


  # Record the question summary results
  def print_questions_results(self, prefix):
    """Add a table row for the current case and fold its counts into the overall totals.

    Despite the name this does not print; print_overall_results prints the table.
    """
    if self.num_questions > 0:
      pct_correct = 100*self.correct_answers/self.num_questions
      mean_loss = self.total_mean_loss/self.num_questions
    else:
      # An empty case (e.g. every pair was out of range) would otherwise raise ZeroDivisionError.
      pct_correct = ""
      mean_loss = ""
    self.output.add_row([prefix, self.num_questions, self.correct_answers, pct_correct, mean_loss])
    self.sum_num_questions += self.num_questions
    self.sum_correct_answers += self.correct_answers


  # Print the overall summary results
  def print_overall_results(self):
    """Append the OVERALL row and print the table in cfg.table_out_format."""
    self.output.add_row(["OVERALL", self.sum_num_questions, self.sum_correct_answers, "", ""])
    print(self.output.get_formatted_string(out_format=cfg.table_out_format))


tcfg = T_Config()
tcfg.reset()

In [None]:
# Ask model to predict answer for each question & collect results
def do_questions(questions, show_failures = False):
  """Run the model on a batch of questions and accumulate the results into tcfg.

  Sets tcfg.num_questions. Adds each question's mean loss to tcfg.total_mean_loss and
  counts exact-match answers in tcfg.correct_answers. Prints every question when the
  module-level `verbose` flag is set, otherwise only incorrect ones if show_failures.
  """
  tcfg.num_questions = questions.shape[0]

  # Inference only: no need to build an autograd graph.
  with torch.no_grad():
    all_logits = main_model(questions.cuda())
  all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, questions.cuda())

  for question_num in range(tcfg.num_questions):
    q = questions[question_num]

    losses = loss_fn(all_losses_raw[question_num])
    mean_loss = utils.to_numpy(losses.mean())
    tcfg.total_mean_loss += mean_loss

    model_answer_str = tokens_to_string(all_max_prob_tokens[question_num])
    # A well-trained model predicts a sign followed by digits. An imperfect one can emit
    # other tokens (e.g. "=" or "*"), which int() rejects; count those answers as wrong.
    try:
      model_answer_num = int(model_answer_str)
    except ValueError:
      model_answer_num = None

    # 6 digit addition yields a 7 digit answer. Hence cfg.n_digits+1
    a = tokens_to_signed_int(q, cfg.n_digits*2 + 2, cfg.n_digits+1)

    correct = model_answer_num is not None and bool(model_answer_num == a)
    if correct :
      tcfg.correct_answers += 1

    if verbose or (show_failures and not correct):
      print(tokens_to_string(q), "ModelAnswer:", model_answer_str, "Correct:", correct, "Loss:", mean_loss )


# Part 9: Prediction Analysis By Use Case
This section runs addition (BA, UC1 and US9) and subtraction (BS, B1, C1, CN) test cases to show which use cases the model can handle.

In [None]:
def print_question_results( title, questions, show_failures = False):
  """Score one case: reset the case counters, run the questions, then add the case's row to tcfg's table."""
  case_title = title
  tcfg.clear_questions_results(case_title)
  do_questions(questions, show_failures)
  tcfg.print_questions_results(case_title)

In [None]:
# Summarise the model's accuracy on addition in two tables, each ending with an OVERALL row:
# first all addition questions together, then one row per hand-crafted case.
verbose = False

if cfg.perc_add() > 0:
  print( "ADDITION:")

  # Table 1: every addition case at once; wrong answers are printed (show_failures=True)
  tcfg.reset()
  print_question_results("Add.S*: All addition cases", make_addition_questions(), True)
  tcfg.print_overall_results()

  # Table 2: one row per addition case
  tcfg.reset()
  print_question_results("Add.S0: Only use BA", make_s0_questions())
  print_question_results("Add.S1: BA and MC. No US", make_s1_questions())
  print_question_results("Add.S2: BA, MC and 1 x US", make_s2_questions())
  print_question_results("Add.S3,S4,S5: BA, MC and 2,3,4 x US", make_s3plus_questions())
  print_question_results("Add.SN: Cover different answer digits", make_sn_questions())
  tcfg.print_overall_results()

In [None]:
# Summarise the model's accuracy on subtraction in two tables, each ending with an OVERALL row:
# first all subtraction questions together, then one row per hand-crafted case.
verbose = False

# NOTE(review): perc_add is called as a method (cfg.perc_add()) elsewhere, but perc_sub is read as an
# attribute here - confirm against the Config class which form each one is.
if cfg.perc_sub > 0:
  print( "SUBTRACTION:")

  # Table 1: every subtraction case at once; wrong answers are printed (show_failures=True)
  tcfg.reset()
  print_question_results("Sub.M*,NG: All subtraction cases", make_subtraction_questions(), True)
  tcfg.print_overall_results()

  # Table 2: one row per subtraction case
  tcfg.reset()
  print_question_results("Sub.M0: Only use BS", make_m0_questions())
  print_question_results("Sub.M1: BS & B1 only", make_m1_questions())
  print_question_results("Sub.M2: BS, B1 and 1 x DZ", make_m2_questions())
  print_question_results("Sub.M3,M4,M5: BS, B1 and 2,3,4 x DZ", make_m3plus_questions())
  print_question_results("Sub.Neg: Negative answer", make_ng_questions())
  tcfg.print_overall_results()

# Part 11: Set Up "Count" Framework

Create a way to get the model to predict answers to sample questions and to analyse/show the results. Uses the prefix "c_".

In [None]:
# Build up a list of success/failure by case (BA, MC1, US9) found, and the frequency of each case
c_case_counts = {}


def count_question_cases(questions):
  """Replace c_case_counts with a {case: number of questions} tally over `questions`.

  The case of each question is whatever get_question_case reports for it.
  """
  global c_case_counts

  c_case_counts = {}
  for row in range(questions.shape[0]):
    case = get_question_case(questions[row])
    c_case_counts[case] = c_case_counts.get(case, 0) + 1

In [None]:
# Compare each digit in the answer. Returns a A+45 pattern where '+' means a failed sign and '4' means a failed 4th digit
def get_digit_accuracy_impact(a_int, answer_str, n_digits=None):
  """Compare the model's answer string with the correct answer, character by character.

  a_int:      the correct answer, as a Python int or a 0-d tensor.
  answer_str: the model's predicted answer, either sign + (n_digits+1) digits (e.g. "+0315600"),
              or the legacy unsigned form of just the n_digits+1 digits.
  n_digits:   number of digits per operand; defaults to cfg.n_digits.

  Returns "" when every character matches. Otherwise returns "A" followed by one symbol per
  mismatch, most significant first: "+" for the sign and "n" for a wrong answer digit An.
  A missing character (answer_str too short) counts as a mismatch.
  """
  if n_digits is None:
    n_digits = cfg.n_digits
  a_int = int(a_int)
  num_answer_digits = n_digits + 1
  digit_labels = [str(n_digits - i) for i in range(num_answer_digits)]

  if len(answer_str) == num_answer_digits + 1:
    # Signed answer: the expected string must carry the sign too, or every digit comparison is shifted by one.
    expected = ("-" if a_int < 0 else "+") + str(abs(a_int)).zfill(num_answer_digits)
    labels = ["+"] + digit_labels
  else:
    expected = str(a_int).zfill(num_answer_digits)
    labels = digit_labels

  match_str = "A"
  for i, label in enumerate(labels):
    predicted = answer_str[i] if i < len(answer_str) else None
    if predicted != expected[i]:
      match_str += label

  return "" if match_str == "A" else match_str

In [None]:
# Build up a list of success/failure digit-patterns found, and the frequency of each pattern
c_pattern_fails = {}


def clear_pattern_fails():
  """Forget every recorded failure pattern."""
  global c_pattern_fails
  c_pattern_fails = {}


def add_pattern_fail(match_str):
  """Count one more failure with the digit pattern match_str (as produced by get_digit_accuracy_impact)."""
  c_pattern_fails[match_str] = c_pattern_fails.get(match_str, 0) + 1


def get_pattern_fails():
  """Describe the recorded failure patterns, most frequent first.

  Returns (results, top_result). results is every "pattern=count" cell, each followed by a space.
  top_result is the same cells joined by ", " - despite its name it is not limited to the top pattern.
  Both are "" when nothing has been recorded.
  """
  ranked = sorted(c_pattern_fails.items(), key=lambda item: item[1], reverse=True)
  cells = [f"{pattern}={count}" for pattern, count in ranked]

  results = "".join(cell + " " for cell in cells)
  top_result = ", ".join(cells)
  return results, top_result


def get_pattern_fails_total():
  """Return the total number of recorded failures across all patterns (0 when there are none)."""
  return sum(count for count in c_pattern_fails.values() if isinstance(count, int))

In [None]:
# Build up a count of failure cases
c_case_fails = {}


def clear_case_fails():
  """Forget every recorded failure case."""
  global c_case_fails
  c_case_fails = {}


def add_case_fail(case_key):
  """Count one more failed question of case case_key."""
  c_case_fails[case_key] = c_case_fails.get(case_key, 0) + 1


def total_case_fails():
  """Return the number of failed questions across all cases."""
  return sum(c_case_fails.values())


def get_case_fails():
  """Return one "%case=percent " cell per failed case, most failures first.

  percent is the rounded share of that case's questions (as tallied in c_case_counts) that failed.
  Returns "" when nothing has failed.
  """
  ranked = sorted(c_case_fails.items(), key=lambda item: item[1], reverse=True)
  return "".join(
    "%" + case + "=" + str(round(100 * fails / c_case_counts[case])) + " "
    for case, fails in ranked
  )

In [None]:
def predict_experiment_question(questions, the_hook, the_threshold, tag = ""):
  """Run the model with forward hooks `the_hook` and record which questions it then gets wrong.

  Only questions whose mean loss exceeds the_threshold are checked. Wrong answers are recorded
  by case (add_case_fail) and by digit pattern (add_pattern_fail), after clearing both tallies
  and recounting the cases of `questions`.
  tag is accepted for interface compatibility but is not used here.

  Returns the mean, over all questions, of each question's mean loss (0.0 for no questions).
  """
  clear_case_fails()
  clear_pattern_fails()

  count_question_cases(questions)

  main_model.reset_hooks()
  main_model.set_use_attn_result(True)

  questions_gpu = questions.cuda()
  # Inference only: skipping autograd saves memory and allows the hooks to edit activations in place.
  with torch.no_grad():
    all_logits = main_model.run_with_hooks(questions_gpu, return_type="logits", fwd_hooks=the_hook)
    all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, questions_gpu)

  num_questions = questions.shape[0]
  total_loss = 0.0
  for question_num in range(num_questions):
    q = questions[question_num]

    q_loss_mean = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())
    total_loss += q_loss_mean

    # Only show the question if the loss exceeds the threshold (because of the ablated token position)
    if q_loss_mean > the_threshold:
      answer_str = tokens_to_string(all_max_prob_tokens[question_num])

      # n digit operands yield an n+1 digit answer (after its sign). Hence cfg.n_digits+1
      a = tokens_to_signed_int(q, cfg.n_digits*2 + 2, cfg.n_digits+1)

      match_str = get_digit_accuracy_impact( a, answer_str )
      # An empty match_str means every answer character was correct; only count wrong answers.
      if match_str:
        the_case = get_question_case(q)
        add_case_fail(the_case)
        add_pattern_fail(match_str)

        if verbose :
          print(tokens_to_string(q), "ModelAnswer:", answer_str, "Matches:", match_str, "Loss:", q_loss_mean, "Case:", the_case )

  return total_loss / num_questions if num_questions > 0 else 0.0

# Part 12: Ablate ALL Heads in EACH token position. What is the impact on Loss?

Here we ablate all heads in each token position (overriding the model memory, aka the residual stream) and see if the loss increases. If the loss increases, the token position is used by the algorithm. Unused token positions can be excluded from further analysis. Uses the "C_" prefix.

In [None]:
class C_Config():
  """Settings and results of the position-ablation ("c_") experiment."""
  position : int = 0  # zero-based token position to ablate
  threshold : float = 0.01
  questions = varied_questions
  output = PrettyTable()
  perc_list = []
  hook_calls : int = 0

  min_useful_position : int = -1 # Minimum useful position where loss increases on ablation
  max_useful_position : int = -1 # Maximum useful position where loss increases on ablation


  def get_column_headings(self):
    """Return the table headings: "Position" then "P<n>" for every position from min to max useful position."""
    first, last = self.min_useful_position, self.max_useful_position
    return ["Position"] + [f"P{n}" for n in range(first, last + 1)]


ccfg = C_Config()
ccfg.output.field_names = ["Position", "Fails", "% Fails by Case", "# Fails by Patterns"]

In [None]:
verbose = False


def c_set_resid_post_hook(value, hook):
  """Forward hook: overwrite the residual stream at token position ccfg.position with its mean value.

  value is the residual stream, [batch, n_ctx, d_model] per earlier debugging output ([64, 22, 510]).
  """
  # Copy the mean resid post values in position N to all the batch questions
  value[:,ccfg.position,:] = mean_resid_post[0,ccfg.position,:].clone()


num_questions = 0
if cfg.n_digits >= 5 :
  # Hook every layer so the position is fully ablated however many layers the model has
  # (covers the 1- and 2-layer cases as before, and also the last layer of deeper models).
  c_fwd_hooks = [(l_hook_resid_post_name[layer], c_set_resid_post_hook) for layer in range(cfg.n_layers)]

  num_questions = ccfg.questions.shape[0]

  # Start from a clean slate so that re-running this cell does not append to (or keep) stale results.
  ccfg.perc_list = []
  ccfg.min_useful_position = -1
  ccfg.max_useful_position = -1
  ccfg.output.clear_rows()

  for ccfg.position in range(cfg.n_ctx):
    clear_case_fails()
    clear_pattern_fails()

    predict_experiment_question(ccfg.questions, c_fwd_hooks, ccfg.threshold)

    num_fails = total_case_fails()
    perc_fails = 0
    if num_fails > 0:
      perc_fails = round(100 * num_fails / num_questions)

      if ccfg.min_useful_position == -1:
        ccfg.min_useful_position = ccfg.position
      ccfg.max_useful_position = ccfg.position

    ccfg.perc_list = ccfg.perc_list + [perc_fails]

    (pattern_results, top_pattern) = get_pattern_fails()
    ccfg.output.add_row([str(ccfg.position), str(perc_fails)+"%", get_case_fails(), pattern_results])

In [None]:
print_config()
print("num_questions=", num_questions, "min_useful_position=", ccfg.min_useful_position, "max_useful_position=", ccfg.max_useful_position )
print()

# Bar chart of the % of questions that fail when each token position is ablated
# (ccfg.perc_list holds one percentage per position, in position order).
c_positions = list(range(len(ccfg.perc_list)))
plt.bar(c_positions, ccfg.perc_list, color='blue', alpha=0.5)
plt.xticks(c_positions)
plt.xlabel('Position')
plt.ylabel('% Fails')
plt.title('% of questions failing when each token position is ablated')
# Tweak spacing to prevent clipping of ylabel
plt.subplots_adjust(left=0.15)
plt.show()

print(ccfg.output.get_formatted_string(out_format=cfg.table_out_format))

# Part 13: Setup: Cell matrix

Uses "u_" prefix.

In [None]:
def table_row(the_layer, the_head):
  """Map a (layer, head) pair to its row in the cell matrix.

  Each layer owns n_heads + 1 consecutive rows: one per attention head, then one for the MLP
  (which callers address with head index cfg.n_heads).
  """
  rows_per_layer = cfg.n_heads + 1
  return rows_per_layer * the_layer + the_head

def table_rows():
  """Total number of rows in the cell matrix (heads plus MLP, for every layer)."""
  return cfg.n_layers * (cfg.n_heads + 1)

def table_cols():
  """Number of matrix columns: one per token position from the min to the max useful position."""
  first, last = ccfg.min_useful_position, ccfg.max_useful_position
  return last + 1 - first

In [None]:
class UsefulCell():
  """One attention head or MLP layer, at one token position, found to be useful to the model."""
  # Is this cell an attention head? If not, it must be an MLP layer
  is_head: bool = True

  # Position.Layer.Head of the cell
  position: int = -1  # token-position. Zero to cfg.n_ctx - 1
  layer: int = -1
  head: int = -1

  # Tags related to the cell (set per instance in __init__)
  tags: list


  def __init__(self):
    # Give every cell its own list: a class-level list would be shared (and mutated by add_tag) across all cells.
    self.tags = []


  # Row in a table that this cell is drawn
  def cell_row(self):
    return table_row(self.layer, self.head)


  def add_tag(self, tag):
    """Add a non-empty tag to this cell, if not already present."""
    if tag != "" and tag not in self.tags:
      self.tags.append(tag)


  def min_tag_suffix(self, prefix):
    """Return the smallest tag suffix among this cell's tags that start with prefix, or "" if there are none.

    The suffix is the text between the first and second '.', e.g. "S0" for "Add.S0".
    Tags without a '.' are ignored.
    """
    suffixes = [s.split('.')[1] for s in self.tags if s.startswith(prefix) and '.' in s]
    return min(suffixes) if suffixes else ""

In [None]:
class U_Config():
  """Head/MLP-by-position failure matrix, plus the list of cells found to be useful to the model."""
  # This is a head+MLP (row) by token (column) matrix of percent of failure percentages with associated notes
  fail_percs = [[]]
  fail_notes = [[]]
  num_heads : int
  num_mlps : int

  # We (once) calculate the list of cells (attention head and MLP layers per position) that are useful to the model.
  calc_useful_cells = True
  # Once this list of useful cells is calculated (available) it is used to speed up functions.
  useful_cells = []


  def reset(self):
    """Zero the failure matrix, notes and counters; also forget the useful cells while calc_useful_cells is set."""
    self.fail_percs = [[0 for _ in range(cfg.n_ctx)] for _ in range((cfg.n_heads + 1) * cfg.n_layers)]
    self.fail_notes = [["" for _ in range(cfg.n_ctx)] for _ in range((cfg.n_heads + 1) * cfg.n_layers)]
    self.num_heads = 0
    self.num_mlps = 0

    if self.calc_useful_cells:
      self.useful_cells = []


  def reset_tags(self):
    """Give every useful cell a fresh, empty tag list."""
    for cell in self.useful_cells:
      cell.tags = []


  def get_cell( self, the_position, the_row ):
    """Return the useful cell at (position, matrix row), or a new placeholder UsefulCell (position -1) if none."""
    for cell in self.useful_cells:
      if cell.position == the_position and cell.cell_row() == the_row:
        return cell

    return UsefulCell()


  def add_fail_perc( self, the_position, the_layer, the_head, perc_fails, notes, tag ):
    """Record a failure percentage of at least 1% (and its notes) for one cell.

    While calc_useful_cells is set the cell is also appended to useful_cells; otherwise the
    already-known useful cell at that spot (if any) is tagged with tag.
    """
    if perc_fails >= 1:
      the_row = table_row(the_layer, the_head)

      self.fail_percs[the_row][the_position] = perc_fails
      self.fail_notes[the_row][the_position] = notes

      if self.calc_useful_cells:
        # Add this useful cell. Check that we do not already have a cell at that row/col
        assert self.get_cell( the_position, the_row).position < 0

        new_cell = UsefulCell()
        new_cell.is_head = the_head != cfg.n_heads
        new_cell.position = the_position
        new_cell.layer = the_layer
        new_cell.head = the_head
        new_cell.add_tag(tag)

        self.useful_cells += [new_cell]
      else:
        # get_cell returns a throwaway placeholder (position -1) when the cell is not a known
        # useful cell; tagging that placeholder would be lost (or leak into shared state).
        cell = self.get_cell(the_position, the_row)
        if cell.position >= 0:
          cell.add_tag(tag)


  def add_head_fail_perc( self, the_position, the_layer, the_head, perc_fails, notes, tag ):
    """Record an attention-head failure and count it in num_heads."""
    self.add_fail_perc( the_position, the_layer, the_head, perc_fails, notes, tag )
    self.num_heads += 1


  def add_mlp_fail_perc( self, the_position, the_layer, perc_fails, notes, tag ):
    """Record an MLP failure (stored in the row after the layer's heads) and count it in num_mlps."""
    self.add_fail_perc( the_position, the_layer, cfg.n_heads, perc_fails, notes, tag )
    self.num_mlps += 1


ucfg = U_Config()
ucfg.reset()

In [None]:
# Draw a 2-D (head/MLP row by token-position column) colour map of the percentage failures.
def print_u_fail_percs(title):
  """Draw ucfg.fail_percs for the useful positions as a coloured grid, one cell per head/MLP and position."""
  print(title, "% failures when each head & MLP in each position is ablated: #FailedHeads=", ucfg.num_heads, "#FailedMlps=", ucfg.num_mlps )

  fig1, ax1 = show_2d_map_start(12, 2*cfg.n_layers) # Width, Height in inches

  # RdYlGn runs red -> yellow -> green, so higher failure percentages get greener shades
  shades = 10
  colors = [plt.cm.RdYlGn(i/10) for i in range(shades)]

  # Add cells
  for i in range(table_rows()):
    for j in range(table_cols()):
      value = ucfg.fail_percs[i][ccfg.min_useful_position+j]
      if value == 100 and table_cols() > 15:
        value = 99 # Avoid overlapping figures in the matrix.

      cell_color = 'lightgrey'  # Color for empty cells
      if value > 0:
        # Clamp so that 100% maps to the last shade instead of indexing past the end of colors
        color_index = min(value // shades, shades - 1)
        cell_color = colors[color_index]
        cell_text = str(value)+"%"
        ax1.text(j + 0.5, i + 0.5, cell_text, ha='center', va='center', color='black')

      ax1.add_patch(plt.Rectangle((j, i), 1, 1, fill=True, color=cell_color))

  show_2d_map_end(title, fig1, ax1, ccfg.min_useful_position, ccfg.max_useful_position )



# Print a 2-D (head/MLP row by token-position column) table of the failure notes.
def print_u_fail_notes():
  """Print ucfg.fail_notes for the useful positions as a table with one row per head/MLP."""
  print("The most common failure pattern (with associated failure #) when each head or MLP in each position is ablated")

  cell_output = PrettyTable()
  cell_output.field_names = ccfg.get_column_headings()

  for i in range((cfg.n_heads + 1) * cfg.n_layers):
    datums = [get_row_heading(i)]
    for j in range(ccfg.min_useful_position, ccfg.max_useful_position+1):
      datums = datums + [ucfg.fail_notes[i][j]]
    cell_output.add_row(datums)

  print(cell_output.get_formatted_string(out_format=cfg.table_out_format))

# Part 14: Setup: Ablate each MLP in EACH position. Impact on Loss?
Ablating the MLP in each layer in each position and seeing if the loss increases shows which position+layer MLPs are used by the algorithm. Uses the "m_" prefix.

In [None]:
class M_Config():
  """Settings and results table for ablating one MLP layer at one token position (the "m_" experiments)."""
  position : int  # zero-based token-position to ablate
  layer : int     # zero-based layer to ablate. 0 to cfg.n_layers - 1
  threshold : float
  output = PrettyTable()
  hook_calls : int
  questions = varied_questions


  def reset(self):
    """Point back at P0.L0, restore the default threshold, zero the hook counter and start an empty results table."""
    self.position, self.layer = 0, 0
    self.threshold = 0.12

    table = PrettyTable()
    table.field_names = ["Position", "MLP Layer", "% Fails", "% Fails by Case", "# Fails by Patterns"]
    self.output = table

    self.hook_calls = 0


mcfg = M_Config()
mcfg.reset()

In [None]:
def m_mlp_hook_post(value, hook):
  """Forward hook: mean-ablate the MLP post-activations at token position mcfg.position, counting each call."""
  mcfg.hook_calls += 1

  # value is presumably [batch, n_ctx, d_mlp]; the [1, d_mlp] mean slice broadcasts over the batch.
  value[:, mcfg.position, :] = mean_mlp_hook_post[:, mcfg.position, :].clone()


def m_perform_core(tag):
  """Ablate the MLP at (mcfg.position, mcfg.layer) and record any resulting failures in mcfg.output and ucfg."""
  clear_case_fails()
  clear_pattern_fails()

  hooks = [(l_mlp_hook_post_name[mcfg.layer], m_mlp_hook_post)]
  predict_experiment_question(mcfg.questions, hooks, mcfg.threshold, tag)

  num_fails = total_case_fails()
  if num_fails <= 0:
    return

  perc_fails = round(100 * num_fails / mcfg.questions.shape[0])
  pattern_results, top_pattern = get_pattern_fails()

  row = [str(mcfg.position), str(mcfg.layer), perc_fails, get_case_fails(), pattern_results]
  mcfg.output.add_row(row)

  ucfg.add_mlp_fail_perc( mcfg.position, mcfg.layer, perc_fails, top_pattern, tag )


def m_perform_all():
  """Reset the matrix and MLP results, then ablate every MLP layer at every token position."""
  ucfg.reset()
  mcfg.reset()
  for position in range(cfg.n_ctx):
    mcfg.position = position
    for layer in range(cfg.n_layers):
      mcfg.layer = layer
      m_perform_core("")


def m_perform_useful(tag):
  """Reset the matrix and MLP results, then re-ablate only the known useful MLP cells, tagging them with tag."""
  ucfg.reset()
  mcfg.reset()
  for cell in ucfg.useful_cells:
    if cell.is_head:
      continue
    mcfg.position, mcfg.layer = cell.position, cell.layer
    m_perform_core(tag)


def m_print_results(title):
  """When verbose, print the config and the MLP-ablation results table."""
  if not verbose:
    return
  print_config()
  print()
  print(title, mcfg.questions.shape[0])
  print(mcfg.output.get_formatted_string(out_format=cfg.table_out_format))

# Part 15: Setup: Ablate EACH head in EACH position. Impact on Digit & Task Loss?
Ablating each head in each layer in each position and seeing if the loss increases shows which position+layer+head are used by the algorithm. Use "h_" prefix.

In [None]:
class H_Config():
  """Settings and results table for ablating one attention head at one token position (the "h_" experiments)."""
  position : int # zero-based token position to ablate. 0 to cfg.n_ctx - 1
  layer : int # zero-based layer to ablate. 0 to cfg.n_layers - 1
  head : int # zero-based head to ablate. 0 to cfg.n_heads - 1
  threshold : float
  output = PrettyTable()
  hook_calls: int
  questions = varied_questions


  def reset(self):
    """Point back at P0.L0.H0, restore the default threshold, zero the hook counter and start an empty results table."""
    self.position = self.layer = self.head = 0
    self.threshold = 0.12

    table = PrettyTable()
    table.field_names = ["Position", "Layer", "Head", "% Fails", "% Fails by Case", "# Fails by Impact"]
    self.output = table

    self.hook_calls = 0


  def print_results(self, title):
    """When verbose, print the config, the hook-call count and the head-ablation results table."""
    if not verbose:
      return
    print_config()
    print()
    print(title, self.questions.shape[0], "#hook_calls=", self.hook_calls)
    print(self.output.get_formatted_string(out_format=cfg.table_out_format))


hcfg = H_Config()
hcfg.reset()

In [None]:
def h_set_attn_hook_z(value, hook):
  """Forward hook: mean-ablate head hcfg.head's z activations at token position hcfg.position, counting each call."""
  hcfg.hook_calls += 1
  # value is presumably [batch, n_ctx, n_heads, d_head] (a debug print showed [1, 22, 3, 170]).

  # Mean ablate: copy the mean attention z values for this position and head to all the batch questions.
  # NOTE(review): mean_attn_z is indexed the same way for every hcfg.layer - confirm it holds the mean
  # for the layer being ablated rather than for a single layer.
  value[:,hcfg.position,hcfg.head,:] = mean_attn_z[:,hcfg.position,hcfg.head,:].clone()


def h_perform_core(tag):
  """Ablate head (hcfg.position, hcfg.layer, hcfg.head) and record any resulting failures in hcfg.output and ucfg."""
  clear_case_fails()
  clear_pattern_fails()

  the_hook = [(l_attn_hook_z_name[hcfg.layer], h_set_attn_hook_z)]
  predict_experiment_question(hcfg.questions, the_hook, hcfg.threshold, tag)

  num_fails = total_case_fails()
  if num_fails > 0:
    perc_fails = round(100 * num_fails / hcfg.questions.shape[0])
    (pattern_results, top_pattern) = get_pattern_fails()

    hcfg.output.add_row([str(hcfg.position), str(hcfg.layer), str(hcfg.head), perc_fails, get_case_fails(), pattern_results])

    ucfg.add_head_fail_perc( hcfg.position, hcfg.layer, hcfg.head, perc_fails, top_pattern, tag)


def h_perform_all():
  """Ablate every head in every layer at each useful token position found by the position-ablation experiment."""
  hcfg.reset()
  if ccfg.min_useful_position < 0:
    # No useful positions were found (or that experiment did not run). Without this guard,
    # range(-1, 0) would ablate position -1, i.e. the last token.
    return
  for hcfg.position in range(ccfg.min_useful_position, ccfg.max_useful_position+1):
    for hcfg.layer in range(cfg.n_layers):
      for hcfg.head in range(cfg.n_heads):
        h_perform_core("")


def h_perform_useful(tag):
  """Re-ablate only the known useful attention-head cells, tagging them with tag."""
  hcfg.reset()
  for useful_cell in ucfg.useful_cells:
    if useful_cell.is_head:
      hcfg.position = useful_cell.position
      hcfg.layer = useful_cell.layer
      hcfg.head = useful_cell.head
      h_perform_core(tag)


# Part 16: Calculate and show cell matrices

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

In [None]:
def calc_cell_matrices(tag, questions, all_cells):
  """Ablate MLPs and then attention heads on `questions`, filling ucfg's failure matrix.

  all_cells=True scans every candidate cell; otherwise only the cells already in ucfg.useful_cells
  are re-tested (and tagged with tag). The MLP pass runs first because the m_perform_* functions
  reset ucfg, while the h_perform_* functions do not.
  """
  passes = (
    (mcfg, m_perform_all, m_perform_useful, m_print_results),
    (hcfg, h_perform_all, h_perform_useful, hcfg.print_results),
  )
  for experiment_cfg, perform_all, perform_useful, print_results in passes:
    experiment_cfg.questions = questions
    if all_cells:
      perform_all()
    else:
      perform_useful(tag)
    print_results(tag)

In [None]:
def print_cell_matrices(title):
  """Print the model config, then the failure-percentage map and the failure-pattern notes, separated by blank lines."""
  print_config()
  for show in (lambda: print_u_fail_percs(title), print_u_fail_notes):
    print()
    show()

In [None]:
def run_cell_matrices():
  """Scan every cell on varied_questions to find the useful cells, then freeze that set and print the matrices."""
  global verbose
  verbose = False

  label = "Varied"
  ucfg.calc_useful_cells = True   # collect the useful cells during this one full scan...
  calc_cell_matrices(label, varied_questions, True)
  ucfg.calc_useful_cells = False  # ...and afterwards only re-test (and tag) that fixed set

  print_cell_matrices(label)


run_cell_matrices()

In [None]:
def calc_and_print_cell_matrices(tag, questions):
  """Re-test only the known useful cells on `questions` (tagging failures with tag), then print the matrices."""
  all_cells = False
  calc_cell_matrices(tag, questions, all_cells=all_cells)
  print_cell_matrices(tag)

# Part 17A - Case Analysis
Processing Addition and Subtraction questions

In [None]:
# Clear every useful cell's tags so the per-case analyses below start from a clean slate.
ucfg.reset_tags()

In [None]:
# Re-test the useful cells on S0 (BA only) addition questions; cells with >= 1% failures get tag "Add.S0".
if cfg.perc_add() > 0:
  calc_and_print_cell_matrices("Add.S0", make_s0_questions())

In [None]:
# Same for S1 addition questions (BA and MC, no US), tag "Add.S1".
if cfg.perc_add() > 0:
  calc_and_print_cell_matrices("Add.S1", make_s1_questions())

In [None]:
# Same for S2 addition questions (BA, MC and 1 x US), tag "Add.S2".
if cfg.perc_add() > 0:
  calc_and_print_cell_matrices("Add.S2", make_s2_questions())

In [None]:
# Same for S3-S5 addition questions (BA, MC and 2-4 x US), tag "Add.S3+".
if cfg.perc_add() > 0:
  calc_and_print_cell_matrices("Add.S3+", make_s3plus_questions())

In [None]:
# Same for M0 subtraction questions (BS only), tag "Sub.M0".
if cfg.perc_sub > 0:
  calc_and_print_cell_matrices("Sub.M0", make_m0_questions())

In [None]:
# Same for M1 subtraction questions (BS & B1 only), tag "Sub.M1".
if cfg.perc_sub > 0:
  calc_and_print_cell_matrices("Sub.M1", make_m1_questions())

In [None]:
# Same for M2 subtraction questions (BS, B1 and 1 x DZ), tag "Sub.M2".
if cfg.perc_sub > 0:
  calc_and_print_cell_matrices("Sub.M2", make_m2_questions())

In [None]:
# Same for M3-M5 subtraction questions (BS, B1 and 2-4 x DZ), tag "Sub.M3+".
if cfg.perc_sub > 0:
  calc_and_print_cell_matrices("Sub.M3+", make_m3plus_questions())

In [None]:
# Same for subtraction questions with a negative answer, tag "Sub.NG".
if cfg.perc_sub > 0:
  calc_and_print_cell_matrices("Sub.NG", make_ng_questions())

In [None]:
# Smoke test: score the M0 (BS only) subtraction questions, then re-test the useful cells on them,
# tagging any that fail with "UnitTest".
qs = make_m0_questions()

#verbose = True
print_question_results("", qs)
calc_and_print_cell_matrices("UnitTest", qs)

In [None]:
# Process the v0 / v1 question sets (defined elsewhere) here in case they include new edge cases not covered in the above hand-crafted questions
verbose = True

hcfg.reset()
hcfg.position = 0
hcfg.layer = 0

# Ablate P0.L0 for heads 0 and 2 on both question sets, labelling the output with the head actually ablated.
# (The loop targets are hcfg attributes, so no new notebook-level names are created.)
for hcfg.head in (0, 2):
  print("Head", hcfg.head)
  for hcfg.questions in (v0, v1):
    h_perform_core("??.??")

# Part 17B - Quanta Analysis
Show the "minimum" addition purpose of each useful cell by S0 to S5 quanta.
Show the "minimum" subtraction purpose of each useful cell by M0 to M5 quanta

In [None]:
def draw_quanta_tags( tag_prefix, file_prefix):
  """Draw the useful-cell map, labelling each cell with its lowest quanta tag suffix for tag_prefix.

  For example, a cell tagged "Add.S2" and "Add.S0" is labelled "S0" when tag_prefix is "Add".
  The shade comes from the digit after the suffix's first character (the 0 in "S0"); suffixes
  without a digit there use the first shade. file_prefix is passed to show_2d_map_end as the title.
  """
  fig1, ax1 = show_2d_map_start(12, 2*cfg.n_layers) # Width, Height in inches

  # The first shades of RdYlGn: red through orange towards yellow
  shades = 6
  colors = [plt.cm.RdYlGn(i/10) for i in range(shades)]

  # Add cells
  for i in range(table_rows()):
    for j in range(table_cols()):
      cell_color = 'lightgrey'  # Color for empty cells

      cell = ucfg.get_cell( ccfg.min_useful_position+j, i )
      if cell.position >= 0:
        suffix = cell.min_tag_suffix( tag_prefix )
        if suffix != "":
          color_index = int(suffix[1]) if len(suffix) > 1 and suffix[1] in "0123456789" else 0
          # Clamp: a digit above shades-1 (e.g. "S9") would otherwise index past the end of colors
          cell_color = colors[min(color_index, shades - 1)]
          ax1.text(j + 0.5, i + 0.5, suffix, ha='center', va='center', color='black')

      ax1.add_patch(plt.Rectangle((j, i), 1, 1, fill=True, color=cell_color))

  show_2d_map_end(file_prefix, fig1, ax1, ccfg.min_useful_position, ccfg.max_useful_position )


In [None]:
if cfg.perc_add() > 0:
  # Label each useful cell with the lowest "Add.*" case tag it picked up above (e.g. S0 before S1)
  draw_quanta_tags( "Add", "Add.All")

In [None]:
if cfg.perc_sub > 0:
  # Label each useful cell with the lowest "Sub.*" case tag it picked up above (e.g. M0 before M1)
  draw_quanta_tags( "Sub", "Sub.All")

# Part 18: Set Up: Calculate and graph PCA decomposition

In [None]:
tn_questions = 100


def _make_t_digit_pair(test_case, operation):
  """Return random (x, y) values for the test digit of one T8/T9/T10 question (see make_t_questions)."""
  if operation == PLUS_INDEX:
    if test_case == 8:
      # The test digits add up to 0 .. 8
      x = random.randint(0, 8)
      return x, random.randint(0, 8-x)
    if test_case == 9:
      # The test digits add up to exactly 9
      x = random.randint(0, 9)
      return x, 9 - x
    if test_case == 10:
      # The test digits add up to 10 .. 18
      x = random.randint(1, 9)
      return x, random.randint(10-x, 9)

  if operation == MINUS_INDEX:
    if test_case == 8:
      # The test-digit difference x - y is negative
      x = random.randint(0, 8)
      return x, random.randint(x+1, 9)
    if test_case == 9:
      # The test-digit difference is zero
      x = random.randint(0, 9)
      return x, x
    if test_case == 10:
      # The test-digit difference is positive. x starts at 1 so a y < x exists:
      # random.randint(0, -1) raises ValueError.
      x = random.randint(1, 9)
      return x, random.randint(0, x-1)

  raise ValueError("make_t_questions: unsupported operation " + str(operation) + " / test_case " + str(test_case))


def make_t_questions(test_digit, test_case, operation):
  """Make tn_questions questions whose operands differ in a controlled way at digit position test_digit.

  For addition (PLUS_INDEX) test_case selects the sum of the two test digits:
    8 -> 0..8, 9 -> exactly 9, 10 -> 10..18.
  For subtraction (MINUS_INDEX) it selects the sign of their difference x - y:
    8 -> negative, 9 -> zero, 10 -> positive.
  The test_digit digits below the test digit are random.

  Raises ValueError for any other operation or test_case.
  """
  limit = 10 ** test_digit
  questions = []
  for _ in range(tn_questions):
    x, y = _make_t_digit_pair(test_case, operation)

    # Randomise the last test_digit digits (those below the test digit) of both numbers
    x = x * limit + random.randint(0, limit-1)
    y = y * limit + random.randint(0, limit-1)
    questions.append([x, y])
  return make_questions(questions, operation)



def make_tricase_questions(test_digit, operation):
  """Stack T8, T9 and T10 question batches (tn_questions each, in that order) for one test digit."""
  q1 = make_t_questions(test_digit, 8, operation)
  q2 = make_t_questions(test_digit, 9, operation)
  q3 = make_t_questions(test_digit, 10, operation)

  questions = torch.vstack((q1, q2, q3))

  return questions

In [None]:
# Do one Principal Component Analysis
def calc_tricase_pca(t_position, t_layer, t_head, t_digit, operation):
  """Fit a 6-component PCA to one head's output at one token position over T8/T9/T10 questions.

  Returns (pca, pca_attn_outputs, title). pca_attn_outputs has one row per question (the T8
  rows first, then T9, then T10); title names the cell, the digit and the first explained-variance ratio.
  """
  t_questions = make_tricase_questions(t_digit, operation)

  with torch.no_grad():
    _, t_cache = main_model.run_with_cache(t_questions)

  # Output of individual heads, without final bias: per question [n_ctx, n_head, d_model].
  # Presumably needs main_model.set_use_attn_result(True), which predict_experiment_question sets.
  attention_cache = t_cache["result", t_layer, "attn"]
  # One slice over the batch replaces a per-question loop that re-fetched the same cache entry each time.
  attn_outputs = attention_cache[:, t_position, t_head, :].cpu()

  pca = PCA(n_components=6)
  pca.fit(attn_outputs)
  pca_attn_outputs = pca.transform(attn_outputs)

  title = tokens_to_string([operation]) + 'P' + str(t_position) + '.L' + str(t_layer) + '.H'+str(t_head) + ', A'+str(t_digit) + ', EVR[0]=' + str(int(round(pca.explained_variance_ratio_[0]*100,0))) + '%'

  return (pca, pca_attn_outputs, title)


# Plot one PCA scatter graph
def graph_pca(pca, pca_attn_outputs, ax, title):
  """Scatter the first two principal components on ax, one colour per tn_questions-sized T8/T9/T10 group."""
  n = tn_questions
  groups = (
    (slice(0, n), 'red', 'T8 (0-8)'),          # t8 questions
    (slice(n, 2*n), 'green', 'T9'),            # t9 questions
    (slice(2*n, None), 'blue', 'T10 (10-18)'), # t10 questions
  )
  for rows, colour, label in groups:
    ax.scatter(pca_attn_outputs[rows, 0], pca_attn_outputs[rows, 1], color=colour, label=label)

  if title != "" :
    ax.set_title(title)

In [None]:
# Graph the PCA of Pasn.Ln.Hn's attention pattern, using T8, T9, T10 questions that differ in the An digit
def add_one_pca_subplot(ax, t_position, t_layer, t_head, t_digit, operation):
  """Draw one PCA scatter on ax. Failures are reported rather than raised, so a grid of subplots still completes."""
  try:
    pca, pca_attn_outputs, title = calc_tricase_pca(t_position, t_layer, t_head, t_digit, operation)
    graph_pca( pca, pca_attn_outputs, ax, title)
  except Exception as e:
    call_args = ",".join(str(arg) for arg in (t_position, t_layer, t_head, t_digit, operation))
    print(f"add_one_pca_subplot({call_args}) Failed:", e)

In [None]:
def save_plt_to_file( full_title ):
  """When cfg.save_graph_to_file is set, save the current figure as '<full_title>.png'.

  Spaces and colons in the title become underscores; commas are dropped.
  """
  if not cfg.save_graph_to_file:
    return
  safe_name = full_title.translate(str.maketrans({" ": "_", ",": None, ":": "_"}))
  plt.savefig(safe_name + '.png')

# Part 19A: Addition PCA decomposition tri-state results

Plot attention heads in the positions 8 to 16 with a clear "tri-state" response to (exactly) one An.

In [None]:
if not use_pca:
  print( "PCA library failed to import. So PCA not done")

if use_pca and cfg.perc_add() > 0 :
  op = PLUS_INDEX

  # (position, layer, head, digit) of each attention head whose PCA shows a clear tri-state
  # response to exactly one answer digit An. They are plotted row-major into a grid of tri_grid shape.
  tri_specs = []
  tri_grid = (0, 0)

  if cfg.n_digits == 5 and cfg.n_layers == 2 and cfg.n_heads == 3:
    # Useful attention heads in the positions 8 to 14, each with its clearest An
    tri_grid = (4, 2)
    tri_specs = [
      (8, 0, 1, 2),    # P8.L0.H1 is interpretable only for A2
      (9, 0, 1, 1),    # P9.L0.H1 is interpretable only for A1
      (11, 0, 1, 3),   # P11.L0.H1 is interpretable only for A3
      (11, 0, 2, 4),   # P11.L0.H2 is interpretable only for A4
      (12, 0, 1, 3),   # P12.L0.H1 is interpretable only for A3
      (13, 0, 1, 2),   # P13.L0.H1 is interpretable only for A2
      (14, 0, 1, 1),   # P14.L0.H1 is interpretable only for A1
    ]

  if cfg.n_digits == 6 and cfg.n_layers == 2 and cfg.n_heads == 3:
    # Useful attention heads in the positions 10 to 19, each with its clearest An
    tri_grid = (5, 2)
    tri_specs = [
      (10, 0, 0, 3),   # P10.L0.H0 is interpretable only for A3
      (11, 0, 0, 2),   # P11.L0.H0 is interpretable only for A2
      (12, 0, 0, 1),   # P12.L0.H0 is interpretable only for A1
      (14, 0, 0, 4),   # P14.L0.H0 is interpretable only for A4
      (14, 1, 1, 4),   # P14.L1.H1 is interpretable only for A4
      (15, 0, 0, 4),   # P15.L0.H0 is interpretable only for A4
      (16, 0, 0, 3),   # P16.L0.H0 is interpretable only for A3
      (17, 0, 0, 2),   # P17.L0.H0 is interpretable only for A2
      (19, 0, 0, 0),   # P19.L0.H0 is interpretable only for A0
    ]

  if not tri_specs:
    # Only the configurations above have been hand-analysed; a figure here would just be blank axes.
    print( "No hand-selected tri-state PCA heads for this model configuration. So PCA tri-state plot not done")
  else:
    fig, axs = plt.subplots(*tri_grid)
    fig.set_figheight(8)
    fig.set_figwidth(5)

    axes_list = list(axs.flat)
    for ax, spec in zip(axes_list, tri_specs):
      add_one_pca_subplot(ax, *spec, op)
    for ax in axes_list[len(tri_specs):]:
      ax.axis('off')  # hide unused grid slots

    lines, labels = axes_list[0].get_legend_handles_labels()
    fig.legend(lines, labels, loc='lower center', ncol=4)

    plt.tight_layout()
    save_plt_to_file('PCA_Trigrams')
    plt.show()

# Part 19B: Addition PCA decomposition bi-state results

Plot attention heads in the positions 8 to 16 with a clear "bi-state" response to (exactly) one An.

In [None]:
if not use_pca:
  print( "PCA library failed to import. So PCA not done")

if use_pca and cfg.perc_add() > 0 :
  op = PLUS_INDEX

  # (position, layer, head, digit) of each attention head whose PCA shows a clear bi-state
  # response to exactly one answer digit An.
  bi_specs = []

  if cfg.n_digits == 5 and cfg.n_layers == 2 and cfg.n_heads == 3:
    bi_specs = [
      (10, 0, 1, 0),   # P10.L0.H1 is clear only for A0
      (15, 0, 1, 0),   # P15.L0.H1 is clear only for A0
    ]

  if cfg.n_digits == 6 and cfg.n_layers == 2 and cfg.n_heads == 3:
    bi_specs = [
      (12, 0, 2, 0),   # P12.L0.H2 is clear only for A0
    ]

  if not bi_specs:
    # Only the configurations above have been hand-analysed; a figure here would just be blank axes.
    print( "No hand-selected bi-state PCA heads for this model configuration. So PCA bi-state plot not done")
  else:
    fig, axs = plt.subplots(1, 2)
    fig.set_figheight(2)
    fig.set_figwidth(5)

    for ax, spec in zip(axs, bi_specs):
      add_one_pca_subplot(ax, *spec, op)
    for ax in axs[len(bi_specs):]:
      ax.axis('off')  # hide unused grid slots

    lines, labels = axs[0].get_legend_handles_labels()
    fig.legend(lines, labels, loc='lower center', ncol=4)

    plt.tight_layout()
    save_plt_to_file('PCA_Bigrams')
    plt.show()

# Part 19C: PCA decomposition of all useful attention heads for answer digits 0 to 7

Parts 19A and 19B plot only hand-selected cells. This part plots every useful attention head, so use it to find (and verify) the interesting ones.

In [None]:
def graph_all_pca_results(op):
  """Draw PCA subplots for answer digits 0-7 of every useful attention head.

  For each cell in ucfg.useful_cells whose is_head flag is set (other cells
  are skipped), prints the cell location and shows a 4x2 grid of subplots,
  one per answer digit, filled row by row.

  op: operator token (e.g. PLUS_INDEX or MINUS_INDEX) passed through to
      add_one_pca_subplot.
  """
  head_cells = (cell for cell in ucfg.useful_cells if cell.is_head)
  for cell in head_cells:
    position, layer, head = cell.position, cell.layer, cell.head
    print( "PCA: position=", position, "layer=", layer, "head=", head)

    fig, axs = plt.subplots(4, 2)

    # axs.flat walks the grid row-major ([0,0], [0,1], [1,0], ...), so the
    # enumerate index is the answer digit 0..7.
    for answer_digit, ax in enumerate(axs.flat):
      add_one_pca_subplot(ax, position, layer, head, answer_digit, op)

    plt.tight_layout()
    plt.show()


# Sweep every useful head once per operator present in the training mix.
if not use_pca:
  print( "PCA library failed to import. So PCA not done")
else:
  if cfg.perc_add() > 0:
    graph_all_pca_results(PLUS_INDEX)
  if cfg.perc_sub > 0:
    graph_all_pca_results(MINUS_INDEX)

# Part 21A: Set Up Interchange Interventions

Here we test the mapping of our mathematical framework (causal abstraction) onto the model's attention heads.







In [None]:
class A_Config():
  """Mutable state shared by the interchange-intervention hooks and helpers.

  Every instance is initialised through reset(), so each instance owns its
  own lists. Previously the lists were class attributes, shared by all
  instances until reset() was called.
  """
  token_position : int  # The token position we want to get/set. P8 to P11 contribute to A5 calculations
  layer : int # The layer we want to get/set
  heads : list # The heads we want to get/set
  threshold : float # Max per-question loss above which a prediction counts as a failure
  hook_calls : int # Number of times an active hook has fired
  answer_failures : int   # Failures of any digit

  questions : object # [] after reset(); callers assign a batch of question tokens
  store : object # [] after reset(); the get hooks assign a cloned activation tensor
  null_hooks : list # (hook_name, hook_fn) pairs
  get_hooks : list # (hook_name, hook_fn) pairs
  put_hooks : list # (hook_name, hook_fn) pairs


  def __init__(self):
    self.reset()


  def reset(self):
    """Restore every field to its default, giving fresh (unshared) lists."""
    self.token_position = 10
    self.layer = 0
    self.heads = []
    self.threshold = 0.00001
    self.hook_calls = 0
    self.answer_failures = 0
    self.questions = []
    self.store = []
    self.null_hooks = []
    self.get_hooks = []
    self.put_hooks = []


acfg = A_Config()
acfg.reset()

In [None]:
# Get and put attention head value hooks

def a_null_attn_z_hook(value, hook):
  """Baseline hook: count the call and leave ``value`` unchanged.

  Returns None, so the activation passes through untouched.
  """
  global acfg
  acfg.hook_calls = acfg.hook_calls + 1


def a_get_l0_attn_z_hook(value, hook):
  """Save a copy of the layer-0 attention z activations into acfg.store.

  Does nothing unless acfg.layer == 0. When active, it counts the call and
  stores value.clone() (the whole tensor, all positions and heads).
  """
  global acfg

  if acfg.layer != 0:
    return
  acfg.hook_calls += 1
  acfg.store = value.clone()


def a_get_l1_attn_z_hook(value, hook):
  """Save a copy of the layer-1 attention z activations into acfg.store.

  Does nothing unless acfg.layer == 1. When active, it counts the call and
  stores value.clone() (the whole tensor, all positions and heads).
  """
  global acfg

  if acfg.layer != 1:
    return
  acfg.hook_calls += 1
  acfg.store = value.clone()


def a_put_l0_attn_z_hook(value, hook):
  """Overwrite selected layer-0 heads at one position with the stored values.

  Does nothing unless acfg.layer == 0. When active, for each head in
  acfg.heads, value[:, acfg.token_position, head, :] is replaced in place
  by the matching slice of acfg.store. The indexing assumes value is rank 4
  with token position on axis 1 and head on axis 2. Per the original debug
  note that is (batch, n_ctx, n_heads, d_head); TODO confirm.
  """
  global acfg

  if acfg.layer != 0:
    return
  acfg.hook_calls += 1
  pos = acfg.token_position
  for h in acfg.heads:
    value[:, pos, h, :] = acfg.store[:, pos, h, :].clone()


def a_put_l1_attn_z_hook(value, hook):
  """Overwrite selected layer-1 heads at one position with the stored values.

  Does nothing unless acfg.layer == 1. When active, for each head in
  acfg.heads, value[:, acfg.token_position, head, :] is replaced in place
  by the matching slice of acfg.store. The indexing assumes value is rank 4
  with token position on axis 1 and head on axis 2; TODO confirm.
  """
  global acfg

  if acfg.layer != 1:
    return
  acfg.hook_calls += 1
  pos = acfg.token_position
  for h in acfg.heads:
    value[:, pos, h, :] = acfg.store[:, pos, h, :].clone()


def a_reset(token_position, layer, heads):
  """Reset acfg, then set the intervention target and build the hook lists.

  Hook lists are (hook_name, hook_fn) pairs for the z hooks of layers 0 and
  1 only. For any other acfg.layer none of the get/put hooks will act.
  """
  global acfg

  acfg.reset()
  acfg.token_position, acfg.layer, acfg.heads = token_position, layer, heads

  l0_name, l1_name = l_attn_hook_z_name[0], l_attn_hook_z_name[1]
  acfg.null_hooks = [(l0_name, a_null_attn_z_hook)]
  acfg.get_hooks = [(l0_name, a_get_l0_attn_z_hook), (l1_name, a_get_l1_attn_z_hook)]
  acfg.put_hooks = [(l0_name, a_put_l0_attn_z_hook), (l1_name, a_put_l1_attn_z_hook)]

In [None]:
def a_predict_question(description, the_hooks, always):
  """Run main_model on acfg.questions with the given hooks and report results.

  For each question the maximum per-token loss is compared with
  acfg.threshold. Questions above it are counted in acfg.answer_failures,
  and the impacted answer digits are described via get_digit_accuracy_impact.

  Args:
    description: label printed at the start of each reported line.
    the_hooks: list of (hook_name, hook_fn) pairs passed as fwd_hooks.
    always: if True, print a line for every question; otherwise print only
      for questions whose max loss exceeds acfg.threshold.
  """
  global acfg

  acfg.hook_calls = 0
  acfg.answer_failures = 0

  main_model.reset_hooks()
  main_model.set_use_attn_result(True)

  questions = acfg.questions.cuda()
  all_logits = main_model.run_with_hooks(questions, return_type="logits", fwd_hooks=the_hooks)
  all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, questions)

  for question_num in range(acfg.questions.shape[0]):
    loss_max = utils.to_numpy(loss_fn(all_losses_raw[question_num]).max())
    answer_str = tokens_to_string(all_max_prob_tokens[question_num])
    failed = loss_max > acfg.threshold

    match_str = ""
    if failed:
      acfg.answer_failures += 1
      q = acfg.questions[question_num]
      a = tokens_to_signed_int(q, cfg.n_digits*2 + 2, cfg.n_digits+1)
      match_str = get_digit_accuracy_impact( a, answer_str )
    if match_str == "":
      match_str = "(none)"

    # Bug fix: this test and loss_str previously used an undefined
    # `loss_mean`; the per-question max loss is the value computed above.
    if always or failed:
      loss_str = "(none)" if loss_max < 1e-7 else str(loss_max)

      print(description, "  ModelPredicts:", answer_str, "  DigitsImpacted:", match_str, "  Loss:", loss_str)


In [None]:
def a_run_intervention_core(token_position, layer, heads, store_question, alter_question):
  """Patch attention-head outputs captured on one question into another.

  Step 1 runs store_question (built as an addition question) with the null
  hook, then with the get hooks, which save the hooked layer's attention z
  activations into acfg.store.
  Step 2 runs alter_question with the null hook, then with the put hooks.
  The put hooks overwrite `heads` at `token_position` in `layer` with the
  saved values, and the result is always printed.
  """
  a_reset(token_position, layer, heads)

  # Step 1: capture activations (including the Vn.BA) from the store question.
  acfg.questions = make_questions([store_question], PLUS_INDEX)
  for label, hooks in (("Unit test (null hook)", acfg.null_hooks),
                       ("Store activation", acfg.get_hooks)):
    a_predict_question(label, hooks, False)

  # Step 2: replay them into the alter question to try to change its answer.
  acfg.questions = make_questions([alter_question], PLUS_INDEX)
  a_predict_question("Unit test (null hook)", acfg.null_hooks, False)
  head_suffix = "".join(str(head_index) + "," for head_index in acfg.heads)
  prompt = f"Intervening on P{token_position!s}.L{layer!s}.H{head_suffix}"
  a_predict_question(prompt, acfg.put_hooks, True)


def a_run_intervention(description, token_position, layer, heads, store_question, alter_question):
  """Print `description`, run one interchange intervention, then a blank line.

  Only runs when cfg.n_digits == 5 and cfg.n_layers == 2. For any other
  model configuration it silently does nothing. Presumably this is because
  the positions/heads in the experiments were chosen for that model.
  """
  if (cfg.n_digits, cfg.n_layers) != (5, 2):
    return
  print(description)
  a_run_intervention_core(token_position, layer, heads, store_question, alter_question)
  print()

# Part 21C: Subtraction edge case tests

In [None]:
# `global` at module level is a no-op, so a plain assignment is enough here.
verbose = True

if cfg.perc_sub > 0:
  # Look for questions that fail even with no intervention; print every result.
  acfg.questions = make_ng_questions()
  a_predict_question("Subtraction Neg questions", acfg.null_hooks, True)




# Part 21B: Run Addition Interchange Interventions

Here we test the mapping of our mathematical framework (causal abstraction) onto the model's attention heads.


In [None]:
print( "Test claim that P8.L0.H1 performs V2.C = TriCase(D2, D2’) impacting A4 and A5 accuracy")
print()

store_question = [44444, 55555] # Sum is 099999. V2 has no MC.
alter_question = [11111, 11111] # Sum is 022222. V2 has no MC.
a_run_intervention("No V2.MC: No impact expected", 8, 0, [1], store_question, alter_question)

store_question = [77711, 22711] # Sum is 100422. V2 has MC
alter_question = [44444, 55555] # Sum is 099999. V2 has no MC
a_run_intervention("Insert V2.MC: Expect A54 digit impacts. Expect 109999.", 8, 0, [1], store_question, alter_question)

store_question = [17711, 22711] # Sum is 040422. V2 has MC
alter_question = [ 4444,  5555] # Sum is 009999. V2 has no MC
a_run_intervention("Insert V2.MC: Expect A4 digit impact. Expect 019999.", 8, 0, [1], store_question, alter_question)

# Confirmed that P8.L0.H1 is: Based on D2 and D2'. Triggers on a V2 carry value. Provides "carry 1" used in A5 and A4 calculation.

In [None]:
print( "Test claim that P9.L0.H1 performs V1.C = TriCase(D1, D1’) impacting A5, A4 & A3 accuracy")
print()

# Each row: (description, store_question, alter_question). The P9.L0.H1
# activation from store_question is patched into the alter_question run.
p9_cases = [
  # 44444+55555=099999 and 11111+11111=022222: neither has a V1 carry.
  ("No V1.MC: No impact expected", [ 44444, 55555], [ 11111, 11111]),
  # 11171+11171=022342 (V1: 7+7 carries) patched into 099999.
  ("Insert V1.MC: Expect A543 digit impacts. Expect 100999.", [ 11171, 11171], [ 44444, 55555]),
  # Same store question patched into 4444+5555=009999.
  ("Insert V1.MC: Expect A43 digit impacts. Expect 010999.", [ 11171, 11171], [  4444,  5555]),
  # Same store question patched into 444+555=000999.
  ("Insert V1.MC: Expect A3 digit impact. Expect 001999.", [ 11171, 11171], [   444,   555]),
]
for description, store_question, alter_question in p9_cases:
  a_run_intervention(description, 9, 0, [1], store_question, alter_question)

# Confirmed that P9.L0.H1 is: Based on D1 and D1'. Triggers on a V1 carry value. Provides "carry 1" used in A5, A4 & A3 calculation.

In [None]:
print( "Test claim that P10.L0.H1 performs V1.C2 = TriAdd(V1.C, TriCase(D0, D0’)) impacting A5, A4, A3 & A2 accuracy")
print()

# Each row: (description, store_question, alter_question). The P10.L0.H1
# activation from store_question is patched into the alter_question run.
p10_cases = [
  # 11111+33333=044444 into 44444+55555=099999: no V0 carry in either.
  # Results: No impact as expected
  ("No impact expected", [ 11111, 33333], [ 44444, 55555]),
  # 11117+11117=022234 (V0: 7+7 carries) into 099999.
  # Results: Impact on A5432 as expected. Got expected value 100099
  ("Insert D0.MC: Expect A5432 digit impacts. Expect 100099.", [ 11117, 11117], [ 44444, 55555]),
  # Same store question into 4444+5555=009999.
  ("Insert D0.MC: Expect A432 digit impacts. Expect 010099.", [ 11117, 11117], [  4444,  5555]),
  # Same store question into 444+555=000999.
  ("Insert D0.MC: Expect A32 digit impacts. Expect 001099", [ 11117, 11117], [   444,   555]),
  # Same store question into 44+55=000099.
  ("Insert D0.MC Expect A2 digit impacts. Expect 000199", [ 11117, 11117], [    44,    55]),
]
for description, store_question, alter_question in p10_cases:
  a_run_intervention(description, 10, 0, [1], store_question, alter_question)

# Confirmed that P10.L0.H1 is: Based on D0 and D0'. Triggers on a V0 carry value. Provides "carry 1" used in A5, A4, A3 & A2 calculation.

In [None]:
print( "Test claim that P11.L0.H1 performs V3.C4 = TriAdd(TriCase(D3, D3’),TriAdd(V2.C,V1.C2)) impacting A5 accuracy")
print()

store_question = [44444, 44444] # Sum is 088888. V3 sums to 8 (has no MC).
alter_question = [11111, 11111] # Sum is 022222. V3 has no MC.
a_run_intervention("No V3.MC: No impact expected", 11, 0, [1], store_question, alter_question)

store_question = [16111, 13111] # Sum is 029222. V3 sums to 9 (has no MC).
alter_question = [44444, 55555] # Sum is 099999. V3 has no MC
a_run_intervention("No V3.MC: No impact expected", 11, 0, [1], store_question, alter_question)

store_question = [16111, 16111] # Sum is 032222. V3 has MC
alter_question = [44444, 55555] # Sum is 099999. V3 has no MC
a_run_intervention("Insert V3.MC: Expect A5 digit impact. Expect 199999.", 11, 0, [1], store_question, alter_question)

# Confirmed that P11.L0.H1 is: Based on D3 and D3'. Triggers on a V3 carry value. Provides "carry 1" used in A5 calculations.

In [None]:
print( "Test claim that P11.L0.H2 performs V4.C = TriCase(D4, D4’) impacting A5 accuracy")
print()

store_question = [44444, 55555] # Sum is 099999. V4 has no MC.
alter_question = [11111, 11111] # Sum is 022222. V4 has no MC.
a_run_intervention("No V4.MC: No impact expected", 11, 0, [2], store_question, alter_question)

store_question = [71111, 71111] # Sum is 142222. V4 has MC
alter_question = [44444, 55555] # Sum is 099999. V4 has no MC
a_run_intervention("Insert V4.MC: Expect A5 digit impact. Expect 199999.", 11, 0, [2], store_question, alter_question)

# Confirmed that P11.L0.H2 is: Based on D4 and D4'. Triggers on a V4 carry value. Provides "carry 1" used in A5 calculation.

In [None]:
print( "Test claim that P12.L0.H0 and H2 performs V4.BA = (D4 + D4’) % 10 impacting A4 accuracy")
print()

# 72222+71111=143333: the D4 digits sum to 14, so V4.BA = 4.
store_question = [72222, 71111]
# 12342+56573=068915: its A4 digit (6) should be replaced by 4.
alter_question = [12342, 56573]
a_run_intervention(
  "Override D4/D4'. Expect A4 digit impact. Expect 048915",
  token_position=12, layer=0, heads=[0, 2],
  store_question=store_question, alter_question=alter_question)

# Confirmed that P12.L0.H0+H2 is: Adds D4 and D4'. Impacts A4

In [None]:
print( "Test claim that P13.L0.H0 and H2 performs V3.BA = (D3 + D3’) % 10 impacting A3 accuracy")
print()

# 23222+13111=36333: the D3 digits sum to 6, so V3.BA = 6.
store_question = [23222, 13111]
# 12342+56573=68915: its A3 digit (8) should be replaced by 6.
alter_question = [12342, 56573]
a_run_intervention(
  "Override D3/D3'. Expect A3 digit impact. Expect 66915",
  token_position=13, layer=0, heads=[0, 2],
  store_question=store_question, alter_question=alter_question)

# Confirmed that P13.L0.H0+H2 is: Adds D3 and D3'. Impacts A3

In [None]:
print( "Test claim that P14.L0.H0 and H2 performs V2.BA = (D2 + D2’) % 10 impacting A2 accuracy")
print()

store_question = [22322, 11311] # Sum is 33633. No V1.MC
alter_question = [12342, 56573] # Sum is 68915. Has V1.MC
a_run_intervention("Override D2/D2'. Expect A2 digit impact. Expect 68715", 14, 0, [0,2], store_question, alter_question)

store_question = [22322, 11311] # Sum is 33633. No V1.MC
alter_question = [12133, 56133] # Sum is 68266. No V1.MC
a_run_intervention("Override D2/D2'. Expect A2 digit impact. Expect 68666", 14, 0, [0,2], store_question, alter_question)

# Confirmed that P14.L0.H0 and H2 both impact A2, and together sum D2 and D2'

In [None]:
print( "Test claim that P14.L0.H1 calculates V1.C1 but also relies on P10.V1.C2, impacting A2 accuracy")
print()

# Group 1: patch V1 carry information into another question, first at
# P14.L0.H1 and then, as a control, at P10.L0.H1.
# Rows: (store_question, alter_question, description).
v1_c1_cases = [
  # 55555+44454=100009 (has V1.MC) into 22222+33333=055555 (no V1.MC).
  # Observed: P14 gives 055655 (correct); P10 gives 055555 (no impact).
  ([55555, 44454], [22222, 33333], "Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 055655"),
  # 55590+44490=100080 (has V1.MC) into 12345+54321=066666 (no V1.MC).
  # Observed: P14 gives 066766 (correct); P10 gives 066666 (no impact).
  ([55590, 44490], [12345, 54321], "Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 066766"),
  # 12345+54321=066666 (no V1.MC) into 55590+44490=100080 (has V1.MC).
  # Observed: P14 gives 100980 (correct); P10 gives 100080 (no impact).
  ([12345, 54321], [55590, 44490], "Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 100980"),
]
for store_question, alter_question, description in v1_c1_cases:
  for position in (14, 10):
    a_run_intervention(description, position, 0, [1], store_question, alter_question)

# Above shows:
# - P14.L0.H1 behaviour is different from P10.L0.H1 behaviour
# - P14.L0.H1 does not simply copy P10.L0.H1 (although this would be a valid way to get perfect accuracy in A2)
print()

# Group 2: the alter questions both sum to 100000 (V0.MC, V1.MC and V1.C2);
# the store questions have no V1.MC.
# Observed for both rows: P14 gives 100900 and P10 gives 099900 (both correct).
v1_c2_cases = [
  ([12345, 54321], [55555, 44445]),  # 066666 into 100000
  ([22222, 33333], [66663, 33337]),  # 055555 into 100000
]
for store_question, alter_question in v1_c2_cases:
  a_run_intervention("Override V1.MC impacting V1.C2. Expect A2 digit impact. Expect 100900", 14, 0, [1], store_question, alter_question)
  a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 099900", 10, 0, [1], store_question, alter_question)

# Above shows:
# - P14.L0.H1 does rely on P10.L0.H1 for V1.C2 information when V1.C != V1.C2
# - P14.L0.H1 calculates V1.C information itself from D1+D1'.

# Overall confirmed: P14.L0.H1 calculates V1.C1 but also relies on P10.V1.C2 when V1.C != V1.C2. Impacts A2

In [None]:
print( "Test claim that P15.L0.H0 and H2 performs V1.BA = (D1 + D1’) % 10 impacting A1 accuracy")
print()

# 22242+11141=33383: D1 digits sum to 8, and there is no V0 carry.
store_question = [22242, 11141]
# 12322+56523=68845: its A1 digit (4) should be replaced by 8.
alter_question = [12322, 56523]
a_run_intervention(
  "Override D1/D1'. Expect A1 digit impact. Expect 68885",
  token_position=15, layer=0, heads=[0, 2],
  store_question=store_question, alter_question=alter_question)

# Confirmed that P15.L0.H0 and H2 both impact A1, and together sum D1 and D1'

In [None]:
print( "Test claim that P15.L0.H1 performs V0.MC = (D0 + D0’) / 10 impacting A1 (A one) accuracy")
print()

store_question = [22244, 11149] # Sum is 33393. Has V0.MC
alter_question = [12342, 56513] # Sum is 68855. No V0.MC
a_run_intervention("Override D0.MC. Expect A1 digit impact. Expect 68865", 15, 0, [1], store_question, alter_question)

# Now test counter-claim that an intervention where both questions do NOT generate a D0.MC has NO impact on A1
store_question = [22242, 11141] # Sum is 33383. No V0.MC
alter_question = [12342, 56523] # Sum is 68865. No V0.MC
a_run_intervention("No impact expected", 15, 0, [1], store_question, alter_question)

# Confirmed that P15.L0.H1: Triggers when D0 + D0' >= 10 (a carry). Impacts A1 digit by 1

In [None]:
print( "Test claim that P16.L0.H0 and H2 performs D0.BA = (D0 + D0’) % 10 impacting A0 accuracy")
print()

# Each row: (description, store_question, alter_question). The alter
# question is 12342+56563=68905 in both rows, and its A0 digit should
# become the store question's (D0 + D0') % 10.
p16_cases = [
  # 22225+11114=33339: D0 digits sum to 9.
  ("Override D0/D0'. Expect A0 digit impact. Expect 68909", [22225, 11114], [12342, 56563]),
  # 22228+11119=33347: D0 digits sum to 17, so D0.BA = 7.
  ("Override D0/D0'. Expect A0 digit impact. Expect 68907", [22228, 11119], [12342, 56563]),
]
for description, store_question, alter_question in p16_cases:
  a_run_intervention(description, 16, 0, [0,2], store_question, alter_question)

# Confirmed that P16.L0.H0 and H2 both impact A0, and together sum D0 and D0'

# Part 22: MLP Visualisation (incomplete, on-hold)

In [None]:
import einops
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import clear_output


# Number of questions in the batch that generated sample_cache.
num_questions = varied_questions.shape[0]


def get_mlp_data(data_set_name):
  """Return one cached MLP hook's activations at index -3 of axis 1, as numpy.

  data_set_name: key into sample_cache, e.g. 'blocks.0.mlp.hook_pre'.
  Axis 1 is assumed to be the token position. An earlier run recorded the
  cached shape as (num_questions, n_ctx, d_mlp); TODO confirm. The slice is
  then split into (num_questions, y, d_mlp) via einops, which gives y == 1
  when the batch holds exactly num_questions questions.
  """
  position_slice = sample_cache[data_set_name][:, -3]
  reshaped = einops.rearrange(position_slice, "(x y) d_mlp -> x y d_mlp", x=num_questions)
  return reshaped.cpu().numpy()


l0_mlp_hook_pre_sq = get_mlp_data('blocks.0.mlp.hook_pre')
l0_mlp_hook_post_sq = get_mlp_data('blocks.0.mlp.hook_post')
if cfg.n_layers > 1:
  l1_mlp_hook_pre_sq = get_mlp_data('blocks.1.mlp.hook_pre')
  l1_mlp_hook_post_sq = get_mlp_data('blocks.1.mlp.hook_post')
else:
  # Single-layer model: alias the layer-0 arrays so the l1 names stay valid.
  l1_mlp_hook_pre_sq = l0_mlp_hook_pre_sq
  l1_mlp_hook_post_sq = l0_mlp_hook_post_sq


def plot_mlp_neuron_activation(pos: int):
    """Show MLP pre- and post-activation heat-maps for MLP neuron `pos`.

    Uses the l1_mlp_hook_*_sq arrays. These hold layer 1 when the model has
    more than one layer; otherwise they alias layer 0, and the labels say so.
    Fixes: the colorbar labels previously said "l0" while layer-1 data was
    plotted, and the figure was never explicitly shown.
    """
    clear_output()

    # Slice the last axis (d_mlp) -> one column per question for this neuron.
    pre_data = l1_mlp_hook_pre_sq[:, :, pos]
    post_data = l1_mlp_hook_post_sq[:, :, pos]
    layer_label = 'l1' if cfg.n_layers > 1 else 'l0'

    fig, axs = plt.subplots(1, 2, figsize=(8,4))

    # NOTE(review): vmin=0/vmax=1 clips pre-activations outside [0, 1];
    # confirm this colour range is intended.
    plot = axs[0].imshow(pre_data, cmap='magma', vmin=0, vmax=1)
    cbar = plt.colorbar(plot, fraction=0.1)
    cbar.set_label(r'{}_mlp_pre_data {}'.format(layer_label, pos))

    plot = axs[1].imshow(post_data, cmap='magma', vmin=0, vmax=1)
    cbar = plt.colorbar(plot, fraction=0.1)
    cbar.set_label(r'{}_mlp_post_data {}'.format(layer_label, pos))

    plt.show()


interact(plot_mlp_neuron_activation, pos=widgets.IntText(value=0, description='Index:'))

# Part 25: Is the model 100% accurate?

This is hard to prove. If the model makes 1M predictions without error, we assume it is 100% accurate.

This part takes ~9 minutes to run for the add_d6_l2_h3 model.

In [None]:
def null_hook(value, hook):
  """Forward hook that leaves the activation alone but turns off `verbose`.

  Returns None, so `value` passes through unchanged.
  """
  global verbose
  verbose = False
  return None

In [None]:
def one_million_questions():
  """Test the model on at least one million freshly generated questions.

  Re-creates the global data generator `ds` (batch size 512, fixed seed).
  Each batch is passed to predict_experiment_question with a null hook.
  The loop stops at the first batch for which get_pattern_fails_total()
  is > 0. Progress is printed every 100 batches, then a summary.

  Fixes: the failure count was never recorded (the summary always printed
  num_fails 0), and the batch count rounded down to 999,936 questions.

  Side effects: overwrites cfg.batch_size, cfg.seed and the global `ds`,
  and sets the global `verbose` to False. None of these are restored.
  """
  global verbose
  global ds

  verbose = False

  cfg.batch_size = 512 # For speed
  cfg.seed = 345621 # Randomly chosen
  ds = data_generator() # Re-initialise the data generator

  the_threshold = 0.12
  the_successes = 0
  the_fails = 0

  # Round up so at least target_questions questions are actually tested.
  target_questions = 1000000
  num_batches = (target_questions + cfg.batch_size - 1) // cfg.batch_size

  # The hook list never changes, so build it once.
  the_hook = [(l_attn_hook_z_name[0], null_hook)]

  for batch_num in range(num_batches):
    tokens = next(ds)

    predict_experiment_question(tokens, the_hook, the_threshold)

    # NOTE(review): assumes get_pattern_fails_total() reflects the batch just
    # predicted (or a running total that was 0 before it); confirm.
    fails = get_pattern_fails_total()
    if fails > 0:
      the_fails = fails
      break

    the_successes = the_successes + cfg.batch_size

    if batch_num % 100 == 0:
      print("Batch", batch_num, "of", num_batches, "#Successes=", the_successes)

  print("successes", the_successes, "num_fails", the_fails)


# Commented out as it takes ~9 minutes to run with cfg.n_layers=2, n_digits=6, train=30K
# one_million_questions()