# Accurate Integer Addition in Transformers

This CoLab defines, trains and analyses a Transformer model that performs integer addition, e.g. 33357+82243=115600. Each digit is a separate token. For 5 digit addition, the model is given 12 "question" (input) tokens (two 5 digit numbers, "+" and "="), and must then predict the corresponding 6 "answer" (output) tokens.
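The token counts above can be checked with a small sketch. The helper name `split_question_answer` is hypothetical (it is not part of this notebook) and it works on characters rather than on the notebook's integer token indexes.

```python
def split_question_answer(x, y, n_digits=5):
    """Split an addition problem into question and answer characters.

    Hypothetical helper for illustration only: one character per token,
    operands zero-padded to n_digits, answer zero-padded to n_digits + 1.
    """
    question = f"{x:0{n_digits}d}+{y:0{n_digits}d}="
    answer = f"{x + y:0{n_digits + 1}d}"
    return list(question), list(answer)


question, answer = split_question_answer(33357, 82243)
print(len(question), "question tokens:", "".join(question))
print(len(answer), "answer tokens:", "".join(answer))
```

The answer is always padded to one more digit than the operands, which is why the model predicts 6 tokens for 5 digit addition.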


ProblemShape.svg

This CoLab follows on from the [Understanding Addition in Transformers](https://colab.research.google.com/drive/1p71dC3LCPJIfFKqr2rhmBZBOyyMpovyn) CoLab, which explains integer addition and documents a rare high-loss use case called "Use Sum 9 Cascade". This CoLab seeks to eliminate this high-loss use case.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 1: Setup
Imports standard libraries. This section can be skipped.


In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install kaleido
    #%pip install git+https://github.com/neelnanda-io/TransformerLens.git@new-demo
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
    #%pip install circuitsvis
    %pip install jaxtyping
    %pip install einops
    %pip install fancy_einsum
    %pip install torchtyping
    %pip install transformers
    %pip install datasets
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from IPython.display import display

import circuitsvis as cv


In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Part 2: Configuration
This section defines the token embedding / unembedding and creates the model.

The model has been successfully trained to do 2, 5, 10, 15 digit integer addition. The default is n_digits = 5.

The model has been successfully trained with 2, 3 or 4 attention heads. The default is n_heads = 3. More heads do not increase accuracy.

The model has been successfully trained with 1 or 2 layers. The default is n_layers = 2. Using two layers instead of one increases accuracy in Cascading UseSum9 cases.

long_equals changes the question format from "12345+11111=023456" to "12345+11111equals023456". This does not increase accuracy.
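The effect of long_equals on the context length can be sketched as follows. The function name `context_length` is hypothetical; the formula is an assumption derived from the two question formats above.

```python
def context_length(n_digits, long_equals=False):
    """Number of tokens in one question-plus-answer sequence.

    Hypothetical helper: two n_digits operands, one "+" token, a separator
    ("=" is 1 token, "equals" is 6 tokens) and an answer of n_digits + 1 tokens.
    """
    separator = len("equals") if long_equals else len("=")
    return n_digits + 1 + n_digits + separator + (n_digits + 1)


for n in (2, 5, 10, 15):
    print(n, context_length(n), context_length(n, long_equals=True))
```

This simplifies to 3 * n_digits + 3 for the "=" format and 3 * n_digits + 8 for the "equals" format.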

Setting more_ms9_cases to true (to increase the percentage of UseSum9 cases in the training data) speeds up training, but doesn't increase trained accuracy.
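The idea behind more_ms9_cases can be sketched in plain Python. This is a simplified illustration, not the notebook's generator: the helper name `force_sum9_pairs` and its `fraction` and `rng` parameters are hypothetical, and it always overwrites the second operand, whereas the notebook's tensor code randomly picks which operand to overwrite.

```python
import random


def force_sum9_pairs(x_digits, y_digits, fraction=0.4, rng=None):
    """Overwrite a fraction of the y digits so the (x, y) pair sums to exactly 9.

    Hypothetical helper: 9 - x is always a valid digit (0..9) when x is a digit.
    """
    rng = rng or random.Random(0)
    x, y = list(x_digits), list(y_digits)
    num_to_modify = int(fraction * len(x))
    for i in rng.sample(range(len(x)), num_to_modify):
        y[i] = 9 - x[i]
    return x, y


x_digits = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
y_digits = [0] * 10
new_x, new_y = force_sum9_pairs(x_digits, y_digits)
num_sum9 = sum(a + b == 9 for a, b in zip(new_x, new_y))
print(num_sum9, "of", len(new_x), "digit pairs now sum to 9")
```

More Make Sum 9 pairs means more columns where an incoming carry must cascade, which is the Use Sum 9 case the model finds hardest.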



In [None]:
more_ms9_cases = True # When doing addition, increase frequency of Make Sum 9 cases in training data
long_equals = False # If true, in question formats use "equals" instead of "="


In [None]:
# Tokens used in vocab. (Token indexes 0 to 9 represent digits 0 to 9)
PLUS_INDEX = 10
MINUS_INDEX = 11
EQUALS_INDEX = 12
E_INDEX = 13
Q_INDEX = 14
U_INDEX = 15
A_INDEX = 16
L_INDEX = 17
S_INDEX = 18

#@markdown Model
n_layers = 2 #@param
d_vocab = (S_INDEX+1 if long_equals == True else EQUALS_INDEX+1)
n_heads = 3 #@param
d_model = ( 512 // n_heads ) * n_heads    # About 512, and divisible by n_heads
d_head = d_model // n_heads               # About 170 when n_heads == 3
d_mlp = 4 * d_model
seed = 129000 #@param

#@markdown Data
n_digits = 5 #@param
n_ctx = 3 * n_digits + (8 if long_equals == True else 3)
act_fn = 'relu'
batch_size = 64 #@param

#@markdown Optimizer
lr = 0.00008 #@param
weight_decay = 0.1 #@param
n_training_steps = 20000


In [None]:
# Train or load model? Training saves the model weights in a temporary CoLab file
train_model = True #@param

# The name of the temporary CoLab file to save model to
pth_location = "model.pth"

# Save graphs to files as PDF and HTML
save_graph_to_file = False

In [None]:
# Embedding / Unembedding

def tokens_to_string(tokens):
    tokens = utils.to_numpy(tokens)
    x = "".join([str(i) for i in tokens[:n_digits]])
    y = "".join([str(i) for i in tokens[n_digits+1:n_digits*2+1]])
    z = "".join([str(i) for i in tokens[n_ctx-n_digits-1:]])
    equals = "equals" if long_equals == True else "="
    operator = "+"
    return f"{x}{operator}{y}{equals}{z}"

def string_to_tokens(string, batch=False):
    lookup = {str(i):i for i in range(10)}
    lookup['+']=PLUS_INDEX
    lookup['-']=MINUS_INDEX
    lookup['=']=EQUALS_INDEX

    if long_equals == True:
      lookup['e']=E_INDEX
      lookup['q']=Q_INDEX
      lookup['u']=U_INDEX
      lookup['a']=A_INDEX
      lookup['l']=L_INDEX
      lookup['s']=S_INDEX

    tokens = [lookup[i] for i in string if i not in '\n ']
    if batch:
        return torch.tensor(tokens)[None, :]
    else:
        return torch.tensor(tokens)

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
cfg = HookedTransformerConfig(
    n_layers = n_layers,
    n_heads = n_heads,
    d_model = d_model,
    d_head = d_head,
    d_mlp = d_mlp,
    act_fn = act_fn,
    normalization_type = 'LN',
    d_vocab=d_vocab,
    d_vocab_out=d_vocab,
    n_ctx=n_ctx,
    init_weights = True,
    device="cuda",
    seed = seed,
)

model = HookedTransformer(cfg)

optimizer = optim.AdamW(model.parameters(),
                        lr=lr,
                        weight_decay=weight_decay,
                        betas=(0.9, 0.98))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))

# Part 3: Data Generator. Addition sub-task categorisation
This section defines the loss function and the training/testing data generator.

It also defines functions to categorise the training data by the addition sub-task defined in the paper. The addition sub tasks are abbreviated as:
- BA is Base Add. Calculates the sum of two digits Dn and Dn' modulo 10, ignoring any carry over from previous columns.
- MC1 is Make Carry 1. Evaluates to true if adding digits Dn and Dn' results in a carry over of 1 to the next column.
- MS9 is Make Sum 9. Evaluates to true if adding digits Dn and Dn' gives exactly 9.
- UC1 is Use Carry 1. Takes the previous column's carry output and adds it to the sum of the current digit pair.
- US9 is Use Sum 9. Propagates (aka cascades) a carry over of 1 to the next column if the current column sums to 9 and the previous column generated a carry over. US9 is the most complex task as it spans three digits. For some rare questions (e.g. 00555 + 00445 = 01000) US9 applies to up to four sequential digits, causing a chain effect, with the MC1 cascading through multiple digits.
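The sub-task definitions above can be sketched for a single question in plain Python. The helper name `classify_columns` is hypothetical. As an assumption it follows the data generator's convention, in which MC1 includes the incoming carry, so a column summing to 9 with a carry in also makes a carry.

```python
def classify_columns(x, y, n_digits=5):
    """Return one dict of sub-task flags per digit column, units column first.

    Hypothetical helper for illustration; not the notebook's tensor implementation.
    """
    xs = [int(c) for c in f"{x:0{n_digits}d}"][::-1]
    ys = [int(c) for c in f"{y:0{n_digits}d}"][::-1]
    rows, carry_in = [], False
    for dx, dy in zip(xs, ys):
        pair_sum = dx + dy
        ms9 = pair_sum == 9                       # Make Sum 9
        uc1 = carry_in                            # Use Carry 1 from the previous column
        mc1 = pair_sum + int(carry_in) >= 10      # Make Carry 1 (incoming carry included)
        rows.append({"BA": pair_sum % 10, "MC1": mc1, "MS9": ms9,
                     "UC1": uc1, "US9": ms9 and uc1})
        carry_in = mc1
    return rows


rows = classify_columns(555, 445)
for column, flags in enumerate(rows):
    print("D" + str(column), flags)
```

For 00555 + 00445 the units column makes a carry, and the next two columns each sum to 9, so the carry cascades through both of them before landing in the thousands column.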

In [None]:
# Loss functions

# Calculate the per-token log-probabilities by comparing a batch of prediction "logits" to answer "tokens"
def logits_to_tokens_loss(logits, tokens):

  # The logits at the n_digits+1 positions just before each answer token predict the addition answer
  ans_logits = logits[:, -(n_digits+2):-1]

  # Convert the raw score (logits) vectors into log-probability distributions.
  # The softmax emphasizes the largest scores and suppresses the smaller ones, to make them more distinguishable.
  ans_probs = F.log_softmax(ans_logits.to(torch.float64), dim=-1)

  max_indices = torch.argmax(ans_probs, dim=-1)

  # The last "n_digit+1" tokens are the model’s answer.
  ans_tokens = tokens[:, -(n_digits+1):]

  # Extract values from the ans_probs tensor, based on indices from the ans_tokens tensor
  ans_loss = torch.gather(ans_probs, -1, ans_tokens[:, :, None])[..., 0]

  return ans_loss, max_indices

# Calculate loss as the negative of the per-token log-probability, averaged over the batch
def loss_fn(ans_loss):
  return -ans_loss.mean(0)
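
The loss calculation can be illustrated without PyTorch. This is a minimal sketch using made-up logits over a two-token vocabulary; the helper names `log_softmax` and `answer_log_probs` are hypothetical stand-ins for `F.log_softmax` and `torch.gather`.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def answer_log_probs(answer_logits, answer_tokens):
    """Pick the log-probability assigned to the correct token at each position."""
    return [log_softmax(logits)[tok] for logits, tok in zip(answer_logits, answer_tokens)]


# Illustrative values only: two answer positions, vocabulary of two tokens
toy_logits = [[0.0, 0.0], [math.log(3.0), 0.0]]
toy_tokens = [0, 0]
log_probs = answer_log_probs(toy_logits, toy_tokens)
toy_loss = -sum(log_probs) / len(log_probs)
print(log_probs, toy_loss)
```

The first position is a coin flip (probability 0.5 for the correct token), the second favours the correct token (probability 0.75), and the loss is the negated mean of the two log-probabilities.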

In [None]:
# Define "iterator" data generator function. Invoked using next().
# "Addition" batch entries are formatted XXXXX+YYYYY=ZZZZZZ e.g. 55003+80002=135005
# "Subtraction" batch entries are formatted XXXXX-YYYYY=ZZZZZZ e.g. 55003-80002=-24999, 80002-55003=024999
# Note that answer has one more digit than the question
# Returns characteristics of each batch entry to aid later analysis
def data_generator(batch_size, n_digits, seed):
    torch.manual_seed(seed)
    while True:
        #generate a batch of questions (answers calculated below)
        batch = torch.zeros((batch_size, n_ctx)).to(torch.int64)
        x = torch.randint(0, 10, (batch_size, n_digits))
        y = torch.randint(0, 10, (batch_size, n_digits))

        if more_ms9_cases == True:
          # The UseSum9 task is compound and rare and so hard to learn.
          # For some of the batches, we increase the MakeSum9 case frequency.
          # UseSum9 also relies on MakeCarry1 (50%) from the previous column.
          # So UseSum9 frequency is increased by roughly 40% * 40% * 50% = 8%
          if random.randint(1, 5) < 3: # 40% of batches: randint is inclusive, so only 1 and 2 pass
            # Flatten x and y to 1D tensors
            x_flat = x.view(-1)
            y_flat = y.view(-1)

            num_elements_to_modify = int(0.40 * x.numel()) # 40%
            indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
            if random.randint(1, 2) == 1:
              x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
            else:
              y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]

            # Reshape x and y back to its original shape
            x = x_flat.view(x.shape)
            y = y_flat.view(x.shape)

        batch[:, :n_digits] = x
        batch[:, n_digits] = PLUS_INDEX
        batch[:, 1+n_digits:1+n_digits*2] = y
        if long_equals == True:
          batch[:, 1+n_digits*2] = E_INDEX
          batch[:, 2+n_digits*2] = Q_INDEX
          batch[:, 3+n_digits*2] = U_INDEX
          batch[:, 4+n_digits*2] = A_INDEX
          batch[:, 5+n_digits*2] = L_INDEX
          batch[:, 6+n_digits*2] = S_INDEX
        else:
          batch[:, 1+n_digits*2] = EQUALS_INDEX

        # These attributes are used for testing addition
        base_adds = torch.zeros((batch_size,n_digits)).to(torch.int64)
        make_carry1s = torch.zeros((batch_size,n_digits)).to(torch.int64)
        sum9s = torch.zeros((batch_size,n_digits)).to(torch.int64)
        use_carry1s = torch.zeros((batch_size,n_digits)).to(torch.int64)
        use_sum9s = torch.zeros((batch_size,n_digits)).to(torch.int64)

        # generate the addition question answers & other info for testing
        for i in range(n_digits):
            # the column in the test attributes being updated
            test_col = n_digits-1-i

            base_add = batch[:, n_digits-1-i] + batch[:, 2*n_digits-i]
            base_adds[:, test_col] = base_add % 10

            sum9 = (base_add == 9)
            sum9s[:, test_col] = sum9

            if i>0:
              use_carry1s[:, test_col] = make_carry1s[:, test_col+1]
            use_carry = use_carry1s[:, test_col]

            use_sum9s[:, test_col] = sum9 & use_carry;

            digit_sum = base_add + use_carry1s[:, test_col]

            make_carry = (digit_sum >= 10)
            make_carry1s[:, test_col] = make_carry

            batch[:, -1-i] = (digit_sum % 10)

        # Final (possible) carry to highest digit of the sum
        batch[:, -1-n_digits] = make_carry1s[:, 0]

        yield batch.cuda(), base_adds.cuda(), make_carry1s.cuda(), sum9s.cuda(), use_carry1s.cuda(), use_sum9s.cuda()

In [None]:
ds = data_generator(batch_size, n_digits, seed)

tokens, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = next(ds)

print(tokens[0])

In [None]:
# Data generator unit test (optional)
# This unit test checks that the above data_generator function is sensible
def unit_test_data_generator(train_tokens, train_use_carry1s, train_make_carry1s):
  test_token = train_tokens[0]
  test_use_carry = train_use_carry1s[0]
  test_make_carry = train_make_carry1s[0]

  if n_digits == 5:
    digits = test_token.cpu().numpy()
    use = test_use_carry.cpu().numpy()
    force = test_make_carry.cpu().numpy()

    num1 = digits[0]*10000 + digits[1]*1000 + digits[2]*100 + digits[3]*10 + digits[4];
    num2 = digits[6]*10000 + digits[7]*1000 + digits[8]*100 + digits[9]*10 + digits[10];
    if long_equals == True:
      sum = digits[17]*100000 + digits[18]*10000 + digits[19]*1000 + digits[20]*100 + digits[21]*10 + digits[22];
    else:
      sum = digits[12]*100000 + digits[13]*10000 + digits[14]*1000 + digits[15]*100 + digits[16]*10 + digits[17];

    assert num1 + num2 == sum, "Unit test failed: Data generator: Bad sum"
    assert (digits[4]+digits[10]+use[4]>=10) == force[4], "Unit test failed: Data generator: Bad carry 0"
    assert (digits[3]+digits[9]+use[3]>=10) == force[3], "Unit test failed: Data generator: Bad carry 1"
    assert (digits[2]+digits[8]+use[2]>=10) == force[2], "Unit test failed: Data generator: Bad carry 2"
    assert (digits[1]+digits[7]+use[1]>=10) == force[1], "Unit test failed: Data generator: Bad carry 3"
    assert (digits[0]+digits[6]+use[0]>=10) == force[0], "Unit test failed: Data generator: Bad carry 4"

In [None]:
# Base-Add-only loss
# Identify the subset of (simple) tokens that only require BA (not UC1 or US9) to get the correct answer
# Array index 0 is the 'Units' digit. Array index 3 is the 'Thousands' digit.
ba_alldigits_loss = []
ba_alldigits_oneloss = 0

ba_perdigit_loss = []
ba_perdigit_cases = 0
ba_total_cases = 0


# Base Add AllDigits
# Identify the tokens in the batch where UC1 is false for all columns simultaneously, so only BA is required on all digits
def calculate_ba_oneloss(tokens, per_token_losses, base_adds, use_carry1s):
  global ba_alldigits_oneloss

  answer = 0
  any_use_carry1s = torch.any(use_carry1s.bool(), dim=1)
  no_use_carry1s = ~ any_use_carry1s
  num_cases = utils.to_numpy(torch.sum(no_use_carry1s))
  if num_cases > 0 :
    filtered_loss = per_token_losses[:, -n_digits:] * no_use_carry1s[:, None]
    sum_loss = torch.sum(filtered_loss)
    answer = - utils.to_numpy(sum_loss) / num_cases
    answer = answer / n_digits  # Approximately align the scale of ba_alldigits_loss to ba_perdigit_loss
  ba_alldigits_oneloss = answer


def calculate_ba_loss(tokens, per_token_losses, base_adds, use_carry1s):
  global ba_perdigit_cases
  global ba_total_cases

  # Base Add All Digits
  # Identify the tokens in the batch where UC1 is false for all columns simultaneously, so only BA is required on all digits
  calculate_ba_oneloss(tokens, per_token_losses, base_adds, use_carry1s)
  ba_alldigits_loss.append(ba_alldigits_oneloss)


  # Base Add Per Digit
  # For each token in the batch, identify the digit columns (e.g. 3) where use_carry is false, so only BA is required on that digit
  ba_perdigit_cases = 0;
  for digit_num in range(n_digits):
    answer = 0
    no_use_carry = 1 - use_carry1s[:, -1-digit_num]
    num_cases = utils.to_numpy(torch.sum(no_use_carry))
    ba_perdigit_cases += num_cases
    ba_total_cases += num_cases
    if num_cases > 0 :
      filtered_loss = per_token_losses[:, -1-digit_num] * no_use_carry
      sum_loss = torch.sum(filtered_loss)
      answer = - utils.to_numpy(sum_loss) / num_cases
    if len(ba_perdigit_loss)<=digit_num:
      ba_perdigit_loss.append([])
    if (num_cases == 0) & (len(ba_perdigit_loss[digit_num]) > 0) :
      answer = ba_perdigit_loss[digit_num][-1] # Use the previous step's loss. Improves graph
    ba_perdigit_loss[digit_num].append(answer)

In [None]:
# Use Carry 1 loss
# Identify the subset of tokens that require UC1 (but not US9) to get the correct answer
# Array index 0 is the 'Units' digit. Array index 3 is the 'Thousands' digit.
uc1_anydigits_loss = []
uc1_anydigits_oneloss = 0

uc1_perdigit_loss = []
uc1_perdigit_cases = 0
uc1_total_cases = 0


# UC1 AnyDigits (exclude Sum9)
# Identify the tokens in the batch where UC1 is used at least once over the columns & Sum9 is never used
def calculate_uc1_loss_any(tokens, per_token_losses, use_carry1s, sum9s):
  global uc1_anydigits_oneloss

  num_use_carry1s = torch.sum(use_carry1s, dim=1)
  any_use_carry1s = torch.where( num_use_carry1s != 0, 1, 0 ) # At least one digit uses UC1
  num_sum9s = torch.sum(sum9s, dim=1) # Use the sum9s argument, not the global use_sum9s tensor
  no_sum9s = torch.where( num_sum9s == 0, 1, 0 ) # No digits have Sum9 true
  filtered_cases = any_use_carry1s & no_sum9s
  num_cases = utils.to_numpy(torch.sum(filtered_cases))
  filtered_indices = torch.nonzero(filtered_cases).squeeze()
  filtered_token_losses = per_token_losses[filtered_indices]
  answer = - filtered_token_losses.mean()
  uc1_anydigits_oneloss = utils.to_numpy(answer)


def calculate_uc1_loss(tokens, per_token_losses, use_carry1s, sum9s):
  global uc1_perdigit_cases
  global uc1_total_cases

  # UC1 AnyDigits (exclude Sum9)
  # Identify the tokens in the batch where UC1 is used at least once over the columns & Sum9 is never used
  calculate_uc1_loss_any(tokens, per_token_losses, use_carry1s, sum9s)
  uc1_anydigits_loss.append(uc1_anydigits_oneloss)

  # UC1 PerDigit (exclude Sum9)
  # For each token in the batch, identify the digit columns (e.g. 3) where UC1 is used on the columns & Sum9 is not true
  uc1_perdigit_cases = 0
  for digit_num in range(n_digits):
    answer = 0
    use_carry = use_carry1s[:, -1-digit_num]
    no_sum9 = 1 - sum9s[:, -1-digit_num]
    filtered_cases = use_carry & no_sum9
    num_cases = utils.to_numpy(torch.sum(filtered_cases))
    uc1_perdigit_cases += num_cases
    uc1_total_cases += num_cases
    if num_cases > 0 :
      filtered_loss = per_token_losses[:, -1-digit_num] * filtered_cases
      sum_loss = torch.sum(filtered_loss)
      answer = - utils.to_numpy(sum_loss) / num_cases
    if len(uc1_perdigit_loss)<=digit_num:
      uc1_perdigit_loss.append([])
    if (num_cases==0) & (len(uc1_perdigit_loss[digit_num]) > 0) :
      answer = uc1_perdigit_loss[digit_num][-1] # Use the previous step's loss. Improves graph
    uc1_perdigit_loss[digit_num].append(answer)

In [None]:
# Use Sum 9 loss
# Identify the subset of tokens that require US9 (being Sum9 and Carry1 from prev column) to get the correct answer
# Array index 0 is the 'Units' digit. Array index 3 is the 'Thousands' digit.
us9_anydigits_loss = []
us9_anydigits_oneloss = 0

us9_perdigit_loss = []
us9_perdigit_cases = 0
us9_total_cases = 0


# US9 OneDigit
# Identify the tokens in the batch where US9 is used at least once over the columns
def calculate_us9_oneloss(tokens, per_token_losses, use_sum9s):
  global us9_anydigits_oneloss

  num_use_sum9s = torch.sum(use_sum9s, dim=1)
  filtered_num_use_sum9s = torch.where( num_use_sum9s != 0, 1, 0 ) # At least OneDigit uses US9
  num_cases = utils.to_numpy(torch.sum(filtered_num_use_sum9s))
  filtered_indices = torch.nonzero(filtered_num_use_sum9s).squeeze()
  filtered_token_losses = per_token_losses[filtered_indices]
  answer = - filtered_token_losses.mean()
  us9_anydigits_oneloss = utils.to_numpy(answer);


def calculate_us9_loss(tokens, per_token_losses, use_sum9s):
  global us9_perdigit_cases
  global us9_total_cases

  # US9 OneDigit
  # Identify the tokens in the batch where US9 is used at least once over the columns
  calculate_us9_oneloss(tokens, per_token_losses, use_sum9s)
  us9_anydigits_loss.append(us9_anydigits_oneloss)

  # For each token in the batch, identify the digit columns (e.g. 3) where US9 is used
  us9_perdigit_cases = 0
  for digit_num in range(n_digits):
    answer = 0
    # use_sum9s already holds (sum9 & use_carry) per column, so use the argument rather than global tensors
    filtered_cases = use_sum9s[:, -1-digit_num]
    num_cases = utils.to_numpy(torch.sum(filtered_cases))
    us9_perdigit_cases += num_cases
    us9_total_cases += num_cases
    if num_cases > 0 :
      filtered_loss = per_token_losses[:, -1-digit_num] * filtered_cases
      sum_loss = torch.sum(filtered_loss)
      answer = - utils.to_numpy(sum_loss) / num_cases
    if len(us9_perdigit_loss)<=digit_num:
      us9_perdigit_loss.append([])
    if (num_cases==0) & (len(us9_perdigit_loss[digit_num]) > 0) :
      answer = us9_perdigit_loss[digit_num][-1] # Use the previous step's loss. Improves graph
    us9_perdigit_loss[digit_num].append(answer)


In [None]:
# Check that us9_perdigit_loss, uc1_perdigit_loss and ba_perdigit_loss do NOT overlap
# This ensures the graphs of each are non-overlapping
def unit_test_nonoverlapping():
  global ba_perdigit_cases
  global ba_total_cases
  global uc1_perdigit_cases
  global uc1_total_cases
  global us9_perdigit_cases
  global us9_total_cases

  perdigit_numcases = us9_perdigit_cases + uc1_perdigit_cases + ba_perdigit_cases
  assert (perdigit_numcases == batch_size * n_digits), "Cases overlap: " + str(perdigit_numcases) + " != " + str(batch_size*n_digits)

# Part 4: Train model with Infinite Data
Train model for n_training_steps, storing train_losses per epoch.

At each of the n_training_steps training steps, new training data (a batch of batch_size questions) is generated, and the model is trained and its loss calculated on that batch. No separate "testing" data is needed, as the training data is freshly generated each step, so memorisation of past training data by the model (if any) is minimally beneficial. For 5 digit addition there are 10 billion possible questions, and with the default settings (20,000 steps of 64 questions) the model is trained on roughly 1.3 million of them.
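The share of the question space seen during training can be estimated with a short sketch. The function name `training_coverage` is hypothetical and its defaults mirror the notebook's default configuration. The count is an upper bound on distinct questions, since randomly generated batches can contain repeats.

```python
def training_coverage(n_digits=5, batch_size=64, n_training_steps=20000):
    """Return (possible questions, questions generated, fraction of the space generated)."""
    possible = (10 ** n_digits) ** 2        # every pair of n_digits operands
    seen = batch_size * n_training_steps    # one fresh batch per training step
    return possible, seen, seen / possible


possible, seen, fraction = training_coverage()
print(f"{seen:,} of {possible:,} possible questions ({fraction:.4%})")
```

With so small a fraction of the space generated, low loss on fresh batches is hard to explain by memorisation alone.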

In [None]:
def print_config():
  print("n_digits=", n_digits, "n_heads=", n_heads, "n_layers=", n_layers, "n_ctx=", n_ctx)
  print("seed=", seed, "long_equals=", long_equals, "n_training_steps=", n_training_steps)

In [None]:
# Initialise the data generator
ds = data_generator(batch_size, n_digits, seed)

if train_model:
  # Train the model
  train_losses_list = []
  per_token_train_losses_list = []

  for epoch in tqdm.tqdm(range(n_training_steps)):

      tokens, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = next(ds)
      logits = model(tokens)

      per_token_train_losses_raw, _ = logits_to_tokens_loss(logits, tokens)
      per_token_train_losses = loss_fn(per_token_train_losses_raw)
      per_token_train_losses_list.append(utils.to_numpy(per_token_train_losses))

      train_loss = per_token_train_losses.mean()
      train_loss.backward()
      train_losses_list.append(train_loss.item())

      calculate_ba_loss(tokens, per_token_train_losses_raw, base_adds, use_carry1s)
      calculate_uc1_loss(tokens, per_token_train_losses_raw, use_carry1s, sum9s)
      calculate_us9_loss(tokens, per_token_train_losses_raw, use_sum9s)

      optimizer.step()
      scheduler.step()
      optimizer.zero_grad()

      if epoch % 100 == 0:
          print(epoch, train_loss.item())
          unit_test_data_generator(tokens, use_carry1s, make_carry1s)
          unit_test_nonoverlapping()


In [None]:
# Even at the end of training, the loss can wobble between epochs, perhaps based on the number of rare edge cases in the training data.
# Use the average of the last 5 training losses as the "final loss"
print_config()
if train_model:
  print("Final training loss", round(sum(train_losses_list[-5:]) / 5, 6))

In [None]:
if train_model:
  # Save the model to file
  torch.save(model.state_dict(), pth_location)
else:
  # Load the model from file
  model.load_state_dict(torch.load(pth_location))
  model.eval()

# Part 5: Training Loss Analysis - High Level Graphs

This section analyses the training loss by graphing it at a high level.

The loss curve for all digits shows visible inflection points (bumps), but is too high level to help understand the algorithm.

When this graph is decomposed into 'per digit' graphs, distinct 'per digit' curves appear, suggesting that the model refines each digit semi-independently.
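The decomposition can be sketched in plain Python with made-up numbers. The helper name `per_digit_and_overall` is hypothetical; the notebook itself stacks the per-step losses into a NumPy array and slices columns.

```python
def per_digit_and_overall(per_token_losses):
    """Split a steps-by-digits loss table into one curve per digit plus the overall curve.

    per_token_losses: one row per training step, one column per answer digit.
    """
    n_steps = len(per_token_losses)
    n_cols = len(per_token_losses[0])
    overall = [sum(row) / n_cols for row in per_token_losses]
    per_digit = [[per_token_losses[step][digit] for step in range(n_steps)]
                 for digit in range(n_cols)]
    return per_digit, overall


# Illustrative values only: 2 training steps, 3 answer digits
toy_losses = [[0.9, 0.5, 0.1], [0.6, 0.3, 0.0]]
per_digit, overall = per_digit_and_overall(toy_losses)
print(per_digit)
print(overall)
```

The overall curve is the mean of the per digit curves at each step, so a bump in it can come from a single digit improving while the others stay flat.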



In [None]:
epochs_to_graph=1200

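# Helper function to plot multiple lines
# Note: despite its name, all_epochs=True plots only the first epochs_to_graph steps; all_epochs=False plots every step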
def lines(raw_lines_list, x=None, mode='lines', labels=None, xaxis='Epoch', yaxis='Loss', title = '', log_y=False, hover=None, all_epochs=True, **kwargs):

    lines_list = raw_lines_list if all_epochs==False else [row[:epochs_to_graph] for row in raw_lines_list]
    log_suffix = '' if log_y==False else ' (Log)'
    epoch_suffix = '' if all_epochs==False else ' (' + str(epochs_to_graph) + ' steps)'
    full_title = title + log_suffix + epoch_suffix

    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    if save_graph_to_file :
      fig = go.Figure(layout={})
      print(full_title)
    else:
      fig = go.Figure(layout={'title':full_title})

    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis + log_suffix)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    if save_graph_to_file:
        fig.update_layout(margin=dict(l=10, r=10, t=10, b=10),width=1200,height=300)

    fig.show(bbox_inches="tight")

    if save_graph_to_file:
        filename = full_title.replace(" ", "").replace("(", "").replace(")", "").replace("&", "").replace(",", "").replace("%", "")   +'.pdf'
        pio.write_image(fig, filename)


if train_model:
  title_suffix = ' Loss Curves for ' + str(n_digits) + ' digit addition'
  per_token_losses = np.stack(per_token_train_losses_list, axis=0)

  line(train_losses_list,
      title=title_suffix)

  all_epochs = True;
  for i in range(2):
    lines([per_token_losses[:, i] for i in range(1+n_digits)]+[train_losses_list],
          labels = [f'digit {i}' for i in range(1+n_digits)]+['all_digits'],
          title='Per digit'+title_suffix,
          all_epochs=all_epochs)

    #lines([per_token_losses[:, i] for i in range(1+n_digits)]+[train_losses_list],
    #      labels = [f'digit {i}' for i in range(1+n_digits)]+['all_digits'],
    #      title='Per digit'+title_suffix,
    #      all_epochs=all_epochs,
    #      log_y=True)

    all_epochs = False;

  for i in range(1+n_digits):
    print('Final Loss for digit ' + str(i) + ' is ', per_token_losses[-1, i])

# Part 6: Training Loss Analysis - Single task (multiple digits)
The previous section graphed across all the (BA, UC1, US9) tasks. This section graphs one of the addition sub-tasks at a time:
- The all-digits graphs for one task (say BA) again show several inflection points but do not provide significant insights.
- The per digit graphs for one task (say BA) again show distinct per digit curves. More useful but no significant insights.


In [None]:
# Graph per digit series using "normal" and "log" scale
def graph_perdigit(losslist, num_series, title_suffix, showlog, all_epochs=True):
    lines([losslist[i] for i in range(num_series)],
          labels = [f'digit {i}' for i in range(num_series)],
          title='PerDigit '+title_suffix,
          all_epochs=all_epochs)

    if showlog:
      lines([losslist[i] for i in range(num_series)],
            labels = [f'digit {i}' for i in range(num_series)],
            title='PerDigit '+title_suffix,
            all_epochs=all_epochs,
            log_y=True)

    if all_epochs==True :
      total_loss = 0
      for i in range(num_series):
        print('Final Loss for digit ' + str(i) + ' is', losslist[i][-1])
        total_loss += losslist[i][-1]
      print('Mean Loss is', total_loss/num_series)
      print()

## Base Add task graphs
Graphs token loss vs step in use case where only BA (not UC1 or US9) is needed to get the correct answer.

In [None]:
if train_model:
  perc = (int)(100 * ba_total_cases / (ba_total_cases + uc1_total_cases + us9_total_cases))
  print('BA Loss' + ' (' + str(ba_total_cases) + ' cases, ' + str(perc) + '%)')

  the_title = 'BA Loss'

  # For use cases where use_carry1s is false for all columns simultaneously, so BA can be used on all digits
  line(ba_alldigits_loss, title='AllDigits '+the_title)

  # For each digit independently
  graph_perdigit(ba_perdigit_loss, n_digits, the_title, False, True)
  graph_perdigit(ba_perdigit_loss, n_digits, the_title, False, False)

## Use Carry 1 (excluding Use Sum 9) task graphs
Graphs token loss vs step where use_carry1s is used at least once over the digits columns (and Sum9 is not used at all)

In [None]:
if train_model:
  perc = (int)(100 * uc1_total_cases / (ba_total_cases + uc1_total_cases + us9_total_cases))
  print( 'UC1 Loss (' + str(uc1_total_cases) + ' cases, ' + str(perc) + '%)' )
  the_title = 'UC1 Loss'

  lines([uc1_anydigits_loss],
        labels = ['at least 1 digit'],
        title='AllDigits '+the_title)

  # For each digit independently
  graph_perdigit(uc1_perdigit_loss, n_digits, the_title, False, True)
  graph_perdigit(uc1_perdigit_loss, n_digits, the_title, False, False)


## Use Sum 9 task graphs
Graphs token loss vs step where US9 is used at least once over the digits columns

In [None]:
if train_model:
  perc = (int)(100 * us9_total_cases / (ba_total_cases + uc1_total_cases + us9_total_cases))
  print( 'US9 Loss (' + str(us9_total_cases) + ' cases, ' + str(perc) + '%)')
  the_title = 'US9 Loss'

  lines([us9_anydigits_loss],
        labels = ['any digits'],
        title='AllDigits '+the_title)

  # For each digit independently
  graph_perdigit(us9_perdigit_loss, n_digits, the_title, False, True)
  graph_perdigit(us9_perdigit_loss, n_digits, the_title, False, False)

# Part 7: Training Loss Analysis - Single digit (multiple tasks)
The graphs in this section show multiple tasks but only one digit. These graphs provide some insights, including:
- The lowest value digit (D0) has a simple, steep loss curve
- The middle value digits (D1, D2 and D3 in 5 digit addition) have similar loss curves
- The highest value digit (D4 in 5 digit addition) has a different loss curve from the middle digits, suggesting the algorithm for this digit differs from the middle digits.


## Per digit BA & UC1 task graphs
For each digit, graph the BaseAdd and UC1 tasks for curve comparison

In [None]:
if train_model:
  for whichdigit in range(n_digits):

    the_title = 'Loss for BA & UC1 tasks for Digit ' + str(whichdigit)

    lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]],
          labels = ['BA']+['UC1'],
          title=the_title)
    lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]],
          labels = ['BA']+['UC1'],
          title=the_title,
          log_y=True)
    lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]],
          labels = ['BA']+['UC1'],
          title=the_title,
          all_epochs=False)

## Per digit BA, UC1 & US9 task graphs
For each digit, graph the BaseAdd, UC1 & US9 tasks for curve comparison.

The high variability (noise) in the US9 curve comes from the rarity of this use case: there are only ~4 examples in each training batch of 64, so a single prediction error adds significant loss. On average, the US9 loss matches the BA and UC1 curves.
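
To illustrate why a rare use case gives a noisy curve, the sketch below compares how far one badly predicted question moves the mean loss of a 4-example subset versus a 30-example subset. The loss values (0.01 baseline, 2.0 for the error), the 30-example count and the function name are illustrative assumptions, not measurements from this model.

```python
# Illustrative only: the loss values and subset sizes below are assumed, not measured.
def mean_loss_with_one_error(n_examples, base_loss=0.01, error_loss=2.0):
    """Mean loss of a subset in which exactly one example is predicted badly."""
    losses = [base_loss] * (n_examples - 1) + [error_loss]
    return sum(losses) / n_examples

rare = mean_loss_with_one_error(4)     # e.g. ~4 US9 questions in a batch of 64
common = mean_loss_with_one_error(30)  # a more common use case (assumed count)
print(f"rare subset mean loss:   {rare:.4f}")
print(f"common subset mean loss: {common:.4f}")
```

The same single error is divided by a much smaller count in the rare subset, so the per-batch US9 mean jumps around far more than the BA or UC1 means.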

In [None]:
if train_model:
  for whichdigit in range(n_digits):

    the_title = 'Loss for BA, UC1 & US9 Tasks for Digit ' + str(whichdigit)

    lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]]+[us9_perdigit_loss[whichdigit]],
          labels = ['BA']+['UC1']+['US9'],
          title=the_title)
    lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]]+[us9_perdigit_loss[whichdigit]],
          labels = ['BA']+['UC1']+['US9'],
          title=the_title,
          log_y=True)
    lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]]+[us9_perdigit_loss[whichdigit]],
          labels = ['BA']+['UC1']+['US9'],
          title=the_title,
          all_epochs=False)

# Part 8: Questions Set Up

Create sets of sample questions (grouped by task) for the model to answer.
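
As a rough picture of the token layout the helper functions below write into each row, here is a minimal plain-Python sketch. The token ids `ASSUMED_PLUS` and `ASSUMED_EQUALS` and both function names are hypothetical (the notebook defines its own `PLUS_INDEX` and `EQUALS_INDEX`), and the sketch assumes the single-token `=` form rather than `long_equals`.

```python
# Hypothetical token ids for the non-digit tokens; the notebook defines its own
# PLUS_INDEX and EQUALS_INDEX elsewhere.
ASSUMED_PLUS = 10
ASSUMED_EQUALS = 11

def to_digits(n, width):
    """Return n as a list of `width` digits, highest-value digit first, zero padded."""
    digits = []
    for _ in range(width):
        digits.append(n % 10)
        n //= 10
    return digits[::-1]

def encode_question(q1, q2, n_digits=5):
    """Token layout: q1 digits, '+', q2 digits, '=', then n_digits+1 answer digits."""
    return (to_digits(q1, n_digits) + [ASSUMED_PLUS]
            + to_digits(q2, n_digits) + [ASSUMED_EQUALS]
            + to_digits(q1 + q2, n_digits + 1))

example = encode_question(33357, 82243)
print(len(example), example)
```

Short numbers such as 15 are zero padded on the left, so every question occupies the same number of tokens.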




In [None]:
# Insert a number into the question
def insert_question_number(the_question, index, first_digit_index, the_digits, n):

  last_digit_index = first_digit_index + the_digits - 1

  for j in range(the_digits):
    the_question[index, last_digit_index-j] = n % 10
    n = n // 10


# Create a single question
def make_a_question(the_question, index, q1, q2):
  a = q1 + q2

  insert_question_number(the_question, index, 0, n_digits, q1)

  the_question[index, n_digits] = PLUS_INDEX

  insert_question_number( the_question, index, n_digits+1, n_digits, q2)

  if long_equals == True:
    the_question[index, 2*n_digits+1] = E_INDEX
    the_question[index, 2*n_digits+2] = Q_INDEX
    the_question[index, 2*n_digits+3] = U_INDEX
    the_question[index, 2*n_digits+4] = A_INDEX
    the_question[index, 2*n_digits+5] = L_INDEX
    the_question[index, 2*n_digits+6] = S_INDEX
    offset = 7
  else:
    the_question[index, 2*n_digits+1] = EQUALS_INDEX
    offset = 2

  insert_question_number(the_question, index, 2*n_digits + offset, n_digits+1, q1+q2)


# Convert the model's predicted answer tokens into a string of digits
def prediction_to_string(max_indices):
  answer = "".join([str(i) for i in utils.to_numpy(max_indices)[0]])
  return answer

In [None]:
def make_ba_questions():
  questions = torch.zeros((17, n_ctx)).to(torch.int64)
  if n_digits >= 5 :
    make_a_question( questions, 0, 12345, 33333)
    make_a_question( questions, 1, 33333, 12345)
    make_a_question( questions, 2, 45762, 33113)
    make_a_question( questions, 3, 888, 11111)
    make_a_question( questions, 4, 2362, 23123)
    make_a_question( questions, 5, 15, 81)
    make_a_question( questions, 6, 1000, 4440)
    make_a_question( questions, 7, 4440, 1000)
    make_a_question( questions, 8, 24033, 25133)
    make_a_question( questions, 9, 23533, 21133)
    make_a_question( questions, 10, 32500, 1)
    make_a_question( questions, 11, 31500, 1111)
    make_a_question( questions, 12, 5500, 12323)
    make_a_question( questions, 13, 4500, 2209)
    make_a_question( questions, 14, 10990, 44000)
    make_a_question( questions, 15, 60000, 30000)
    make_a_question( questions, 16, 10000, 20000)
  return questions


def make_uc1_questions():
  n = 24 if n_digits >= 10 else 19
  questions = torch.zeros((n, n_ctx)).to(torch.int64)
  if n_digits >= 5 :
    make_a_question( questions, 0, 15, 45)
    make_a_question( questions, 1, 25, 55)
    make_a_question( questions, 2, 35, 59)
    make_a_question( questions, 3, 40035, 40049)
    make_a_question( questions, 4, 5025, 5059)
    make_a_question( questions, 5, 15, 65)
    make_a_question( questions, 6, 44000, 46000)
    make_a_question( questions, 7, 70000, 40000)
    make_a_question( questions, 8, 15000, 25000)
    make_a_question( questions, 9, 35000, 35000)
    make_a_question( questions, 10, 45000, 85000)
    make_a_question( questions, 11, 67000, 85000)
    make_a_question( questions, 12, 99000, 76000)
    make_a_question( questions, 13, 1500, 4500)
    make_a_question( questions, 14, 2500, 5500)
    make_a_question( questions, 15, 3500, 5900)
    make_a_question( questions, 16, 15020, 45091)
    make_a_question( questions, 17, 25002, 55019)
    make_a_question( questions, 18, 35002, 59019)
  if n_digits >= 10 :
    make_a_question( questions, 19, 25000000, 55000000)
    make_a_question( questions, 20, 35000000, 59000000)
    make_a_question( questions, 21, 150200000, 450910000)
    make_a_question( questions, 22, 250020000, 550190000)
    make_a_question( questions, 23, 350020000, 590190000)
  return questions


def make_simple_us9_questions():
  questions = torch.zeros((18, n_ctx)).to(torch.int64)
  if n_digits >= 5 :
    make_a_question( questions, 0, 55, 45)
    make_a_question( questions, 1, 45, 55)
    make_a_question( questions, 2, 45, 59)
    make_a_question( questions, 3, 35, 69)
    make_a_question( questions, 4, 25, 79)
    make_a_question( questions, 5, 15, 85)
    make_a_question( questions, 6, 15, 88)
    make_a_question( questions, 7, 15508, 14500)
    make_a_question( questions, 8, 14508, 15500)
    make_a_question( questions, 9, 24533, 25933)
    make_a_question( questions, 10, 23533, 26933)
    make_a_question( questions, 11, 32500, 7900)
    make_a_question( questions, 12, 31500, 8500)
    make_a_question( questions, 13, 550, 450)
    make_a_question( questions, 14, 450, 550)
    make_a_question( questions, 15, 10880, 41127)
    make_a_question( questions, 16, 41127, 10880)
    make_a_question( questions, 17, 12386, 82623)
  return questions


def make_cascade_us9_questions(clean = True):
  questions = torch.zeros((29, n_ctx)).to(torch.int64)
  if n_digits >= 5 :
    # These are two level UseSum9 cascades
    make_a_question( questions, 0, 555, 445)
    make_a_question( questions, 1, 3340, 6660)
    make_a_question( questions, 2, 8880, 1120)
    make_a_question( questions, 3, 1120, 8880)
    make_a_question( questions, 4, 123, 877)
    make_a_question( questions, 5, 877, 123)
    make_a_question( questions, 6, 321, 679)
    make_a_question( questions, 7, 679, 321)
    make_a_question( questions, 8, 1283, 88786)
    # These are three level UseSum9 cascades
    make_a_question( questions, 9, 5555, 4445)
    make_a_question( questions, 10, 55550, 44450)
    make_a_question( questions, 11, 334, 666)
    make_a_question( questions, 12, 3340, 6660)
    make_a_question( questions, 13, 33400, 66600)
    make_a_question( questions, 14, 888, 112)
    make_a_question( questions, 15, 8880, 1120)
    make_a_question( questions, 16, 88800, 11200)
    make_a_question( questions, 17, 1234, 8766)
    make_a_question( questions, 18, 4321, 5679)
    # These are four level UseSum9 cascades
    make_a_question( questions, 19, 44445, 55555)
    make_a_question( questions, 20, 33334, 66666)
    make_a_question( questions, 21, 88888, 11112)
    make_a_question( questions, 22, 12345, 87655)
    make_a_question( questions, 23, 54321, 45679)
    make_a_question( questions, 24, 45545, 54455)
    make_a_question( questions, 25, 36634, 63366)
    make_a_question( questions, 26, 81818, 18182)
    make_a_question( questions, 27, 87345, 12655)
    make_a_question( questions, 28, 55379, 44621)
  return questions


# These questions focus mainly on 1 digit at a time
# (We're assuming that the 0 + 0 digit additions are trivial bigrams)
def make_step_questions():
  questions = torch.zeros((26, n_ctx)).to(torch.int64)
  if n_digits >= 5 :
      make_a_question( questions, 0, 1, 0)
      make_a_question( questions, 1, 4, 3)
      make_a_question( questions, 2, 5, 5)
      make_a_question( questions, 3, 8, 1)
      make_a_question( questions, 4, 40, 30)
      make_a_question( questions, 5, 44, 46)
      make_a_question( questions, 6, 400, 300)
      make_a_question( questions, 7, 440, 460)
      make_a_question( questions, 8, 800, 100)
      make_a_question( questions, 9, 270, 470)
      make_a_question( questions, 10, 600, 300)
      make_a_question( questions, 11, 4000, 3000)
      make_a_question( questions, 12, 4400, 4600)
      make_a_question( questions, 13, 6000, 3000)
      make_a_question( questions, 14, 7000, 4000)
      make_a_question( questions, 15, 40000, 30000)
      make_a_question( questions, 16, 44000, 46000)
      make_a_question( questions, 17, 60000, 30000)
      make_a_question( questions, 18, 70000, 40000)
      make_a_question( questions, 19, 10000, 20000)
      make_a_question( questions, 20, 15000, 25000)
      make_a_question( questions, 21, 35000, 35000)
      make_a_question( questions, 22, 45000, 85000)
      make_a_question( questions, 23, 67000, 85000)
      make_a_question( questions, 24, 99000, 76000)
      make_a_question( questions, 25, 76000, 99000)
  return questions


# Returns ~100 manually-chosen questions
def all_manual_questions():
  q1 = make_ba_questions()
  q2 = make_uc1_questions()
  q3 = make_simple_us9_questions()
  q4 = make_cascade_us9_questions()
  q5 = make_step_questions()

  all_manual = torch.vstack((q1, q2, q3, q4, q5))

  return all_manual


# Returns 64 random and ~100 manually-chosen questions
def make_varied_questions():
  random_questions, _, _, _, _, _ = next(ds)

  all_manual = all_manual_questions()

  all_tokens = torch.vstack((random_questions.cuda(), all_manual.cuda()))

  return all_tokens

In [None]:
num_questions = 0
correct_answers = 0
verbose = True
total_mean_loss = 0.0


# Clear the question summary results
def clear_questions_results(title):
  global num_questions
  global correct_answers
  global verbose
  global total_mean_loss

  num_questions = 0
  correct_answers = 0
  total_mean_loss = 0

  if verbose:
    print(title)


# Ask model to predict answer for each question & collect results
def do_questions(questions):
    global num_questions
    global correct_answers
    global verbose
    global total_mean_loss

    num_questions = questions.shape[0]
    for question_num in range(num_questions):
      q = questions[question_num]

      # Run with no hook
      the_logits = model(q.cuda())

      q_2d = q.unsqueeze(0)
      losses_raw, max_indices = logits_to_tokens_loss(the_logits, q_2d.cuda())
      losses = loss_fn(losses_raw)
      mean_loss = utils.to_numpy(losses.mean())
      total_mean_loss = total_mean_loss + mean_loss

      model_answer_str = prediction_to_string(max_indices)
      model_answer_num = int(model_answer_str)

      i = n_digits*2 + 7 if long_equals == True else n_digits*2 + 2

      a = 0
      # 5 digit addition yields a 6 digit answer. Hence n_digits+1
      for j in range(n_digits+1):
        a = a * 10 + int(q[i+j])  # Convert the token to a Python int so 'correct' is a plain bool

      correct = (model_answer_num == a)
      if correct :
        correct_answers += 1

      if verbose:
        print(tokens_to_string(q), "ModelAnswer:", model_answer_str, "Correct:", correct, "Loss:", mean_loss )


# Print the question summary results
def print_questions_results(prefix):
  global num_questions
  global correct_answers
  global verbose
  global total_mean_loss

  print(prefix, num_questions, "questions.", correct_answers, "correct. % Correct:", 100*correct_answers/num_questions, "Mean loss:", total_mean_loss/num_questions)
  if verbose:
    print("")

In [None]:
# Build a test batch of 64 random and ~100 manually-chosen questions
varied_questions = make_varied_questions();

# Run the sample batch, gather the cache
model.reset_hooks()
model.set_use_attn_result(True)
sample_logits, sample_cache = model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_indices = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean) # Loss < 0.04 is good

# attn.hook_z is a key output of the attention heads, combining pattern and value
hook_l0_z_name = 'blocks.0.attn.hook_z'
sample_attn_z = sample_cache[hook_l0_z_name]
print("Sample", hook_l0_z_name, sample_attn_z.shape) # gives [64, 18, 3, 170] = batch_size, num_tokens, n_heads, d_head
mean_attn_z = torch.mean(sample_attn_z, dim=0, keepdim=True)
print("Mean", hook_l0_z_name, mean_attn_z.shape) # gives [1, 18, 3, 170] = 1, num_tokens, n_heads, d_head
hook_l1_z_name = 'blocks.1.attn.hook_z'

hook_l0_resid_post_name = 'blocks.0.hook_resid_post'
sample_resid_post = sample_cache[hook_l0_resid_post_name]
print("Sample", hook_l0_resid_post_name, sample_resid_post.shape) # gives [64, 18, 510] = batch_size, num_tokens, d_model
mean_resid_post = torch.mean(sample_resid_post, dim=0, keepdim=True)
print("Mean", hook_l0_resid_post_name, mean_resid_post.shape) # gives [1, 18, 510] = 1, num_tokens, d_model
hook_l1_resid_post_name = 'blocks.1.hook_resid_post'


# Part 9: Prediction Analysis - By Use Case
This section sets up BA, UC1 and US9 test cases that will be re-used in later experiments to show the impact of ablating heads or steps.
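
For orientation, the use cases can be told apart from the two question numbers alone. The sketch below is a plain-integer restatement of the rules applied by `get_question_case` later in this notebook; the function name `classify_question` is hypothetical, and "MC1" is the label that function returns for the Use Carry 1 case.

```python
def classify_question(q1, q2, n_digits=5):
    """Classify an addition question as BA, MC1, SimpleUS9 or CascadeUS9."""
    # Per-column flags, index 0 = lowest-value digit
    makes_carry = []
    sums_to_9 = []
    for _ in range(n_digits):
        column_sum = q1 % 10 + q2 % 10
        makes_carry.append(column_sum > 9)
        sums_to_9.append(column_sum == 9)
        q1 //= 10
        q2 //= 10

    if not any(makes_carry):
        return "BA"          # No column generates a carry
    if not any(sums_to_9):
        return "MC1"         # Carries exist but never pass through a sum-9 column

    # A carry feeding two consecutive sum-9 columns is a cascade
    for d in range(n_digits - 2):
        if makes_carry[d] and sums_to_9[d + 1] and sums_to_9[d + 2]:
            return "CascadeUS9"
    # A carry feeding a single sum-9 column
    for d in range(n_digits - 1):
        if makes_carry[d] and sums_to_9[d + 1]:
            return "SimpleUS9"
    return "MC1"

for pair in [(12345, 33333), (15, 45), (55, 45), (555, 445)]:
    print(pair, classify_question(*pair))
```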



In [None]:
print_config()

In [None]:
def do_ba_questions():
  clear_questions_results("Simple BaseAdd cases")
  questions = make_ba_questions()
  do_questions(questions)
  print_questions_results("BaseAdd:")

def do_uc1_questions():
  clear_questions_results("These are Use Carry 1 (UC1) examples (not UseSum9 examples)")
  questions = make_uc1_questions()
  do_questions(questions)
  print_questions_results("UseCarry1:")

def do_simple_us9_questions():
  clear_questions_results("These are simple (one level) UseSum9 examples")
  questions = make_simple_us9_questions()
  do_questions(questions)
  print_questions_results("SimpleUS9")

def do_cascade_us9_questions():
  clear_questions_results("These are UseSum9 two, three and four level cascades")
  questions = make_cascade_us9_questions()
  do_questions(questions)
  print_questions_results("CascadeUS9")

def do_step_questions():
  clear_questions_results("These questions focus on different steps")
  questions = make_step_questions()
  do_questions(questions)
  print_questions_results("Steps")

verbose = False

do_ba_questions()
do_uc1_questions()
do_simple_us9_questions()
do_cascade_us9_questions()
do_step_questions()

# Part 10: What prediction steps does the model actually use?

Here we ablate all heads in each step and see if loss increases to show which steps (if any) are **not** used by the algorithm. These steps can be excluded from further analysis.

This section overrides (ablates) the model memory (residual stream) at each step. It confirms that for:
- n_digits = 5, n_layers = 1 : the addition algorithm does **not** use any data generated in steps 0 to 10 inclusive. In these steps the model has **not** yet seen the full question and every digit in the question is independent of every other digit, making accurate answer prediction infeasible. The model also does not use the last (17th) step. Therefore, the addition is started and completed in 6 steps (11 to 16)
- n_digits = 5, n_layers = 2 : the addition algorithm does **not** use any data generated in steps 0 to 7 inclusive. The model also does not use the last (17th) step. Therefore, the addition is started and completed in 9 steps (8 to 16).
- n_digits = 10, n_layers = 2 :  TBA
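
The step numbers above follow from the token layout. As a minimal sketch, assuming a single `=` token (not `long_equals`) and that the model at step s predicts token s+1, the steps that must emit an answer digit can be computed as below. The function name is hypothetical, and the sketch says nothing about which earlier steps a given model also relies on.

```python
def answer_prediction_steps(n_digits):
    """Return (n_ctx, steps) where steps are the positions that predict an answer token."""
    n_question_tokens = 2 * n_digits + 2   # two numbers, '+' and '='
    n_answer_tokens = n_digits + 1         # e.g. 5 digit addition gives a 6 digit answer
    n_ctx = n_question_tokens + n_answer_tokens
    first_step = n_question_tokens - 1     # the '=' position, where the full question is visible
    last_step = n_ctx - 2                  # this step predicts the final answer token
    return n_ctx, list(range(first_step, last_step + 1))

n_ctx, steps = answer_prediction_steps(5)
print(n_ctx, steps)
```

For 5 digit addition this gives steps 11 to 16, and the final position (step 17) has no following token to predict, which is consistent with it being unused.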

In [None]:
# Experiment 3: Apply a hook to override residual stream (blocks.0.hook_resid_post) in each step in turn and see impact on loss

exp3_step = 0  # zero-based step to ablate
exp3_threshold = 0.02

def exp3_hook(value, hook):
  #print( "In hook", hook_l0_resid_post_name, exp3_ablate, exp3_step, value.shape) # Get [64, 18, 510] = batch_size, num_tokens, d_model

  # Copy the mean resid post values in step N to all the batch questions
  value[:,exp3_step,:] = mean_resid_post[0,exp3_step,:].clone()


if n_digits >= 5 :
  exp3_fwd_hooks = [(hook_l0_resid_post_name, exp3_hook)]

  for exp3_step in range(n_ctx):
    model.reset_hooks()
    exp3_logits = model.run_with_hooks(varied_questions.cuda(), return_type="logits", fwd_hooks=exp3_fwd_hooks)
    exp3_losses_raw, resid_post_max_indices = logits_to_tokens_loss(exp3_logits, varied_questions.cuda())
    exp3_loss_mean = utils.to_numpy(loss_fn(exp3_losses_raw).mean())

    loss_description = "Good" if exp3_loss_mean < exp3_threshold else "BAD"
    print("Sample Loss", hook_l0_resid_post_name, "Loss", exp3_loss_mean, "Step", exp3_step, "Ablate", loss_description)

  print()

# Part 11: Set Up Run Framework

Create a way to get the model to predict the answers to sample questions, then analyse and show the results.
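
The framework below summarises each wrong answer as a per-digit match pattern. A minimal plain-Python sketch of that idea follows; the function name is hypothetical, and it assumes the model answer string already has n_digits+1 characters.

```python
def digit_accuracy_pattern(expected, model_answer_str, n_digits=5):
    """Return e.g. 'yNyyyN': 'y' where the answer digit matches, 'N' where it differs."""
    expected_str = str(expected).zfill(n_digits + 1)  # pad to the full answer width
    return "".join("y" if m == e else "N"
                   for m, e in zip(model_answer_str, expected_str))

# 35000 + 35000 = 070000, compared with a model answer of 060009
print(digit_accuracy_pattern(35000 + 35000, "060009"))
```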

In [None]:
# Compare each digit in the answer. Returns a yyNNy pattern where y means the digits match and N means a failure
def get_digit_accuracy_pattern(a_int, answer_str):
  a_str = str(a_int.cpu().numpy()).zfill(n_digits+1)
  match_str = ""
  for i in range(n_digits+1):
      if answer_str[i] == a_str[i]:
          match_str += "y"  # Matching digit
      else:
          match_str += "N"  # Non-matching digit
  return match_str

In [None]:
# Build up a list of success/failure patterns found, and the frequency of each pattern
exp2_step_counts = {}

def clear_step_counts():
  global exp2_step_counts

  exp2_step_counts = {}

def add_step_count(step_key):
  global exp2_step_counts

  if step_key in exp2_step_counts:
    # If the key is already in the dictionary, increment its count
    exp2_step_counts[step_key] += 1
  else:
    # If the key is not in the dictionary, add it with a count of 1
    exp2_step_counts[step_key] = 1

def print_step_counts():
  global exp2_step_counts

  if len(exp2_step_counts) > 0 :
    print("  Step failures:", exp2_step_counts)

def get_step_counts_total():
  global exp2_step_counts

  if len(exp2_step_counts) == 0:
    return 0

  total_sum = 0
  for key, value in exp2_step_counts.items():
      if isinstance(value, int):
          total_sum += value
  return total_sum

In [None]:
# Analyse the question and return the use case as BA, MC1, SimpleUS9 or CascadeUS9
def get_question_case(q):
  qa = utils.to_numpy(q)
  qn = qa[:2*n_digits+2]

  # Locate the MC and MS digits (if any)
  mc = torch.zeros( n_digits).to(torch.int64)
  ms = torch.zeros( n_digits).to(torch.int64)
  for dn in range(n_digits):
    if qn[dn] + qn[dn + n_digits + 1] == 9:
      ms[n_digits-1-dn] = 1
    if qn[dn] + qn[dn + n_digits +1] > 9:
      mc[n_digits-1-dn] = 1

  # Calculate the use case of a question
  if torch.sum(mc) == 0:
    return "BA"

  if torch.sum(ms) == 0:
    return "MC1"

  for dn in range(n_digits):
    if dn < n_digits-2 and mc[dn] == 1 and ms[dn+1] == 1 and ms[dn+2] == 1:
      return "CascadeUS9"

  for dn in range(n_digits):
    if dn < n_digits-1 and mc[dn] == 1 and ms[dn+1] == 1:
      return "SimpleUS9"

  return "MC1"


# Test that the above code works as expected
def unit_test_get_question_case_core(correct_case, questions):
  for i in range(questions.shape[0]):
    question_case = get_question_case(questions[i])
    if question_case != correct_case:
      print( "Case mismatch:", correct_case, question_case, questions[i])

def unit_test_get_question_case():
  unit_test_get_question_case_core( "BA", make_ba_questions())
  unit_test_get_question_case_core( "MC1", make_uc1_questions())
  unit_test_get_question_case_core( "SimpleUS9", make_simple_us9_questions())
  unit_test_get_question_case_core( "CascadeUS9", make_cascade_us9_questions())

unit_test_get_question_case()

In [None]:
# Build up a count of cases found
exp2_case_counts = {}

def clear_case_counts():
  global exp2_case_counts

  exp2_case_counts = {}

def add_case_count(case_key):
  global exp2_case_counts

  if case_key in exp2_case_counts:
    # If the key is already in the dictionary, increment its count
    exp2_case_counts[case_key] += 1
  else:
    # If the key is not in the dictionary, add it with a count of 1
    exp2_case_counts[case_key] = 1

def print_case_counts():
  global exp2_case_counts

  if len(exp2_case_counts) > 0:
    print( "  Case failures:", exp2_case_counts)

In [None]:
def do_experiment_question(questions, the_hook, the_threshold):

  num_questions = questions.shape[0]
  for question_num in range(num_questions):
    q = questions[question_num]

    model.reset_hooks()
    exp_logits = model.run_with_hooks(q.cuda(), return_type="logits", fwd_hooks=the_hook)

    q_2d = q.unsqueeze(0)
    exp_losses_raw, exp_max_indices = logits_to_tokens_loss(exp_logits, q_2d.cuda())
    exp_loss_mean = utils.to_numpy(loss_fn(exp_losses_raw).mean())
    exp_answer_str = prediction_to_string(exp_max_indices)

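    # Locate the first answer token and read the (n_digits+1) digit correct answer.
    # (For n_digits = 5 this is token 12, or 17 with long_equals)
    i = n_digits*2 + 7 if long_equals == True else n_digits*2 + 2
    a = 0
    for j in range(n_digits+1):
      a = a * 10 + q[i+j]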

    # Only show the question if the loss exceeds the threshold (because of the ablated step)
    if exp_loss_mean > the_threshold:
      match_str = get_digit_accuracy_pattern( a, exp_answer_str )
      # Only count the question if the model got the question wrong
      if 'N' in match_str:
        the_case = get_question_case(q)
        # Always count the failure, so summaries are also available in verbose mode
        add_case_count(the_case)
        add_step_count(match_str)
        if verbose:
          print(tokens_to_string(q), "ModelAnswer:", exp_answer_str, "Matches:", match_str, "Loss:", exp_loss_mean, "Case:", the_case )

# Part 11A: Impact on digit accuracy and task accuracy of ablating all heads in a step.
Here we ablate all heads in each step and see whether loss increases for specific **digits** and **tasks**. This shows which steps are associated with calculating which digits and tasks.

Notes:
* The paper claims that an answer digit (such as A3) is calculated one token before it is revealed. Visual inspection of the failures helps determine which answer digit is dependent on each step.
* With n_layers = 1: when ablating step 16, there are examples where the model gets two digits wrong (e.g. Question: 35000 + 35000 = 070000, ModelAnswer: 060009). The error in the higher digit is irrelevant, because the model had already predicted and revealed that digit a few tokens earlier. Only errors in digits that have not yet been revealed at the step being tested are significant.
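
The experiments in this notebook use two ablation styles: mean ablation (Part 10) and zero ablation (Experiment 2). The toy sketch below shows the difference on a tiny nested-list "residual stream" indexed as `acts[batch][step][feature]`. The data and function names are made up for illustration; the real hooks operate in place on PyTorch tensors.

```python
def mean_over_batch(acts):
    """acts[b][s][f] -> mean[s][f], averaged over the batch dimension b."""
    n_batch = len(acts)
    n_steps = len(acts[0])
    n_feat = len(acts[0][0])
    return [[sum(acts[b][s][f] for b in range(n_batch)) / n_batch
             for f in range(n_feat)]
            for s in range(n_steps)]

def ablate_step(acts, step, replacement):
    """Return a copy of acts where every batch item's vector at `step` is replaced."""
    return [[list(replacement) if s == step else list(vec)
             for s, vec in enumerate(item)]
            for item in acts]

# Made-up activations: 2 questions, 2 steps, 2 features
acts = [[[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]]]

mean = mean_over_batch(acts)
mean_ablated = ablate_step(acts, step=1, replacement=mean[1])
zero_ablated = ablate_step(acts, step=1, replacement=[0.0, 0.0])
print(mean_ablated)
print(zero_ablated)
```

Mean ablation removes the question-specific information at a step while keeping activations at a typical scale, whereas zero ablation may push the model further from anything it saw in training.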


In [None]:
# Experiment 2 - Ablate all heads in each step to see what question digits and tasks then fail.

exp2_step = 0 # zero-based step to ablate
exp2_threshold = 0.1
exp2_questions = make_varied_questions()

verbose = False;

def exp2_hook(value, hook):
  # Zero ablate: overwrite the residual stream at step N for all the batch questions.
  # (Alternative: mean ablate by copying the mean resid post values in step N instead)
  #value[:,exp2_step,:] = mean_resid_post[0,exp2_step,:].clone()
  value[:,exp2_step,:] = 0


if n_digits >= 5 :
  exp2_fwd_hooks = [(hook_l0_resid_post_name, exp2_hook)]

  print_config()
  print("num_questions=", exp2_questions.shape[0])
  print()

  for exp2_step in range(n_ctx):
    clear_case_counts()
    clear_step_counts()

    do_experiment_question(exp2_questions, exp2_fwd_hooks, exp2_threshold)

    # Skip steps with no errors
    if (not verbose) and ((len(exp2_case_counts) > 0) or (len(exp2_step_counts) > 0)):
      print( "Ablating all heads in step", exp2_step)
      print_case_counts()
      print_step_counts()


# Part 11B: Impact on digit accuracy and task accuracy of ablating some heads in a step.
Ablating each head in each layer in each step and seeing if the loss increases shows which head+layer+step are / aren't used by the algorithm.

In the 2 layer model, most second layer head+step combinations are not used.
The results are diagrammed in StaircaseA3L2_Part1.drawio
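
The sweep below has a simple shape: for every step, layer and head, ablate that one position, count the questions that now fail, and record the count in a table. A minimal sketch of that bookkeeping follows. The function names and the stand-in `fake_count_failures` callback (with made-up results) are hypothetical; in the notebook the counts come from running the model with a hook.

```python
def sweep_failures(n_ctx, n_layers, n_heads, count_failures):
    """Build table[layer][step][head] = number of failures when that position is ablated."""
    table = [[[0] * n_heads for _ in range(n_ctx)] for _ in range(n_layers)]
    for step in range(n_ctx):
        for layer in range(n_layers):
            for head in range(n_heads):
                table[layer][step][head] = count_failures(step, layer, head)
    return table

def used_positions(table):
    """List the (layer, step, head) positions whose ablation caused at least one failure."""
    return [(layer, step, head)
            for layer, steps in enumerate(table)
            for step, heads in enumerate(steps)
            for head, n_failures in enumerate(heads)
            if n_failures > 0]

# Stand-in for a real ablation run: the result is invented for illustration
def fake_count_failures(step, layer, head):
    return 3 if (step, layer, head) == (12, 0, 1) else 0

table = sweep_failures(n_ctx=18, n_layers=2, n_heads=3, count_failures=fake_count_failures)
print(used_positions(table))
```

Positions that never appear in the result can be excluded from further analysis for that set of questions.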

In [None]:
exp7_step = 0 # zero-based step to ablate. 0 to say 17
exp7_layer = 0 # zero-based layer to ablate. 0 to 1
exp7_head = 0 # zero-based head to ablate. 0 to 2
exp7_threshold = 0.1
exp7_questions = varied_questions
exp7_l0_failures = [[0 for _ in range(n_heads)] for _ in range(n_ctx)]
exp7_l1_failures = [[0 for _ in range(n_heads)] for _ in range(n_ctx)]
verbose = False;


def exp7_reset_failures():
  global exp7_l0_failures
  global exp7_l1_failures

  exp7_l0_failures = [[0 for _ in range(n_heads)] for _ in range(n_ctx)]
  exp7_l1_failures = [[0 for _ in range(n_heads)] for _ in range(n_ctx)]


def exp7_hook(value, hook):
  global exp7_step
  global exp7_head

  # print( "In hook", hook.name, value.shape) # Get [1, 18, 3, 170] = batch, n_ctx, n_heads, d_head

  # Mean ablate: overwrite one head's attn.hook_z output at step N with its mean over the sample batch.
  # The mean is taken from the layer being hooked (mean_attn_z only covers layer 0)
  mean_z = sample_cache[hook.name][:, exp7_step, exp7_head, :].mean(dim=0)
  value[:,exp7_step,exp7_head,:] = mean_z


def exp7_perform_core():
  global exp7_step
  global exp7_head
  global exp7_threshold
  global exp7_questions
  global exp7_l0_failures
  global exp7_l1_failures

  clear_case_counts()
  clear_step_counts()

  the_hook = [(hook_l0_z_name, exp7_hook)] if exp7_layer == 0 else [(hook_l1_z_name, exp7_hook)]
  do_experiment_question(exp7_questions, the_hook, exp7_threshold)

  if exp7_layer==0:
    exp7_l0_failures[exp7_step][exp7_head] = exp7_l0_failures[exp7_step][exp7_head] + get_step_counts_total()
  else:
    exp7_l1_failures[exp7_step][exp7_head] = exp7_l1_failures[exp7_step][exp7_head] + get_step_counts_total()

  if verbose:
    failures = exp7_l0_failures[exp7_step][exp7_head] if exp7_layer==0 else exp7_l1_failures[exp7_step][exp7_head]
    print( "  Step", exp7_step, "  layer", exp7_layer, "head", exp7_head, "#failures", failures )
    print_case_counts()
    print_step_counts()


def exp7_perform(title):
  global exp7_step
  global exp7_layer
  global exp7_head
  global exp7_questions
  global exp7_l0_failures
  global exp7_l1_failures

  if n_digits >= 5 :
    if n_layers == 2 :

      exp7_reset_failures()

      for exp7_step in range(n_ctx):
        for exp7_layer in range(n_layers):
          for exp7_head in range(n_heads):

            exp7_perform_core()

      print(title, exp7_questions.shape[0])
      for exp7_step in range(n_ctx):
        if sum(exp7_l0_failures[exp7_step]) > 0 or sum(exp7_l1_failures[exp7_step]) > 0 :
          print("  Step", exp7_step, "  L0Hn #failures", exp7_l0_failures[exp7_step], "L1Hn #failures", exp7_l1_failures[exp7_step])
      print()


exp7_questions = varied_questions
exp7_perform("# varied questions:")

## Part 11C - BA Analysis
For n_digits = 5, n_layers = 2:
- S0 to S11 and S17 are not relevant
- L1 is not relevant
- L0H1 is not relevant

In [None]:
exp7_questions = make_ba_questions()
exp7_perform("# BA questions:")

## Part 11D - UC1 Analysis
For n_digits = 5, n_layers = 2:
- S0 to S11 and S17 are not relevant
- L1 is not relevant


In [None]:
exp7_questions = make_uc1_questions()
exp7_perform("# UC1 questions:")

## Part 11E - Simple US9 Analysis
For n_digits = 5, n_layers = 2 for SimpleUS9:
- S0 to S7 and S17 are not relevant
- L1 is not relevant


In [None]:
exp7_questions = make_simple_us9_questions()
exp7_perform("#SimpleUS9 questions")

## Part 11F - Cascade US9 Analysis
For n_digits = 5, n_layers = 2 for Cascade US9:
- S0 to S7 and S17 are not relevant


In [None]:
exp7_questions = make_cascade_us9_questions()
exp7_perform("#CascadeUS9 questions")

# Part 12: Prediction Analysis - Shape

The "Prediction" sections analyse the model after it has been trained, by looking at how it predicts answers to questions. This section only shows the shape of the data available; no insights are derived from it.


Get some new tokens from the data generator

In [None]:
tokens, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = next(ds)

print("tokens.shape", tokens.shape)
print("sample tokens", tokens[:4])
print(tokens_to_string(tokens[0]))
print(tokens_to_string(tokens[1]))
print(tokens_to_string(tokens[2]))
print(tokens_to_string(tokens[3]))

Run the model on the tokens

In [None]:
original_logits, cache = model.run_with_cache(tokens)
print("original_logits.numel", original_logits.numel())

Get key weight matrices:

In [None]:
W_E = model.embed.W_E[:-1]
print("W_E shape:", W_E.shape)

W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur shape:", W_neur.shape)

W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit shape:", W_logit.shape)

In [None]:
per_token_train_losses_original, _ = logits_to_tokens_loss(original_logits, tokens)
original_loss = loss_fn(per_token_train_losses_original).mean()
print("Original Loss:", utils.to_numpy(original_loss))

In [None]:
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

for param_name, param in cache.items():
    print(param_name + ' shape:', param.shape)

# Part 13: Prediction Analysis - Attention Patterns
Attention patterns show which token(s) the model's attention heads are paying attention to in each step of the prediction calculation.

For the default CoLab set up, the model has 3 attention heads and performs 5 digit addition. The attention pattern is an 18 by 18 grid of squares (as 54321+77779=132100 is 18 tokens). Time proceeds vertically downwards, with one additional token being revealed horizontally at each step, giving the overall triangle shape. The visualisation shows that, after the question is fully revealed (at step 11), each head starts attending to pairs of question digits from left to right (i.e. high-value digits before lower-value digits), giving the "double staircase" shape. The three heads attend to a given digit pair in three different steps, giving a time ordering of heads.
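
The triangle shape comes from causal masking: at each step a head can only attend to tokens already revealed. A minimal sketch of that mask for the 18 token case is shown below (the function name is hypothetical, and this shows only which squares can be non-zero, not the learned attention weights).

```python
def causal_mask(n_ctx):
    """mask[query][key] is 1 when the token at `key` is visible at step `query`."""
    return [[1 if key <= query else 0 for key in range(n_ctx)]
            for query in range(n_ctx)]

mask = causal_mask(18)
for step in (0, 11, 17):
    print("step", step, "can attend to", sum(mask[step]), "tokens")
```

At step 11 all 12 question tokens are visible, which is the first point at which a head can attend to any pair of question digits.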

### Show attention patterns for some randomly chosen tokens

In [None]:
def show_token_attention_patterns(index, layer, token_at_index, use_case):

  the_tokens = [str(token) for token in token_at_index.tolist()]
  if layer == 0:
    # Pass the token tensor itself, as every other call to tokens_to_string does
    tokens_str = tokens_to_string(token_at_index)
    print("Attention patterns for", tokens_str)

  attention_pattern=cache["pattern", layer, "attn"][index]
  display(cv.attention.attention_patterns(
      tokens=the_tokens,
      attention=attention_pattern,
      attention_head_names=[f"L{layer}H{i}" for i in range(n_heads)],
  ))


sample_size = 3

# Show attention patterns for some randomly chosen tokens
for i in range(sample_size):
  for layer in range(n_layers):
    show_token_attention_patterns(i, layer, tokens[i], "Misc")


In [None]:
if save_graph_to_file:

  tokens_str = []
  for i in range(n_heads):
    one_token_str = []
    for j in tokens[i]:
      one_token_str.append(str(utils.to_numpy(j)))
    tokens_str.append(one_token_str)

  # Refer https://github.com/callummcdougall/CircuitsVis/blob/main/python/circuitsvis/circuitsvis_demo.ipynb

  html_object = cv.attention.from_cache(
      cache = cache,
      tokens = tokens_str, # list of list of strings
      return_mode = "html",
  )

  # Create a CoLab file containing the attention pattern(s) in HTML
  filename = "AttentionPattern" + str(n_digits) + "Digits" + str(n_heads) + "Heads.html"
  with open(filename, "w") as f:
      f.write(html_object.data)

  # Manually download the CoLab "html" file and open in your local browser.
  # Install and use the Edge extension "FireShot" to save a portion of the HTML page as a PDF

### Show attention patterns for some questions that are BA only



In [None]:
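# Select the questions where no digit uses UC1, i.e. BA-only questions
any_use_carry1s = torch.any(use_carry1s.bool(), dim=1)
no_use_carry1s = ~any_use_carry1s
ba_num_cases = utils.to_numpy(torch.sum(no_use_carry1s))
if ba_num_cases >= sample_size:
  print(f"Attention patterns for first few BA-only tokens ({ba_num_cases} of {tokens.shape[0]})")
  # Keep the original batch indices so that each question is paired with its own
  # cached attention pattern (assumes cache was generated from tokens)
  ba_indices = torch.nonzero(no_use_carry1s).flatten().tolist()
  for i in ba_indices[:sample_size]:
    for layer in range(n_layers):
      show_token_attention_patterns(i, layer, tokens[i], "BAOnly")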

### Show attention patterns for some tokens which use UC1 (and not US9)

In [None]:
num_use_carry1s = torch.sum(use_carry1s, dim=1)
any_use_carry1s = torch.where(num_use_carry1s != 0, 1, 0) # At least one digit uses UC1
num_sum9s = torch.sum(use_sum9s, dim=1)
no_sum9s = torch.where(num_sum9s == 0, 1, 0) # No digits have Sum9 true
filtered_cases = any_use_carry1s & no_sum9s
uc1_num_cases = utils.to_numpy(torch.sum(filtered_cases))
if uc1_num_cases >= sample_size:
  print(f"Attention patterns for first few UC1-only (and not US9) tokens ({uc1_num_cases} of {tokens.shape[0]})")
  # Keep the original batch indices so that each question is paired with its own
  # cached attention pattern (assumes cache was generated from tokens)
  uc1_indices = torch.nonzero(filtered_cases).flatten().tolist()
  for i in uc1_indices[:sample_size]:
    for layer in range(n_layers):
      show_token_attention_patterns(i, layer, tokens[i], "UC1")

### Show attention patterns for some tokens which use US9

In [None]:
num_sum9s = torch.sum(use_sum9s, dim=1)
any_sum9s = torch.where(num_sum9s != 0, 1, 0) # At least one digit uses Sum9
us9_cases = utils.to_numpy(torch.sum(any_sum9s))
if us9_cases >= sample_size:
  print(f"Attention patterns for first few US9 tokens ({us9_cases} of {tokens.shape[0]})")
  # Keep the original batch indices so that each question is paired with its own
  # cached attention pattern (assumes cache was generated from tokens)
  us9_indices = torch.nonzero(any_sum9s).flatten().tolist()
  for i in us9_indices[:sample_size]:
    for layer in range(n_layers):
      show_token_attention_patterns(i, layer, tokens[i], "US9")

# Part 14A: Test Hypothesis D4.T1 is calculated at S11L0H2

When n_digits = 5 and n_layers = 2, test whether D4.T1 is calculated at S11L0H2.
If it is, then when we ablate S11L0H2 we expect answer digit A4, and possibly A5, to be inaccurate.
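To check this expectation, it helps to name which answer digits differ between the correct answer and a prediction. The sketch below is a standalone illustration, not the notebook's `exp7` code. It assumes that A0 is the least significant answer digit and that answers are zero-padded to n_digits + 1 tokens; the helper `wrong_answer_digits` and the "predicted" string are hypothetical, and the prediction is made up for the example rather than taken from the model.

```python
def wrong_answer_digits(expected, predicted):
    # Return the names (e.g. "A4") of the answer digits that differ,
    # most significant digit first. Assumes A0 is the units digit.
    if len(expected) != len(predicted):
        raise ValueError("expected and predicted must have the same length")
    n = len(expected)
    return [
        f"A{n - 1 - pos}"
        for pos, (e, p) in enumerate(zip(expected, predicted))
        if e != p
    ]


# First test question from the cell below: 15508 + 14500
expected = f"{15508 + 14500:06d}"   # zero-padded to 6 answer tokens
predicted = "020008"                # invented prediction, for illustration only

impacted = wrong_answer_digits(expected, predicted)
print(expected, predicted, impacted)
```

Under the hypothesis, the ablated model's impacted digits would be limited to A4 and possibly A5; errors in lower digits would suggest the head does something else as well.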

In [None]:
if n_digits >= 5 and n_layers == 2:
  exp7_step = 11 # zero-based step to ablate. 0 to say 17
  exp7_layer = 0 # zero-based layer to ablate. 0 to 1
  exp7_head = 2 # zero-based head to ablate. 0 to 2

  exp7_questions = torch.zeros((7, n_ctx)).to(torch.int64)
  make_a_question( exp7_questions, 0, 15508, 14500)
  make_a_question( exp7_questions, 1, 14508, 15500)
  make_a_question( exp7_questions, 2, 24533, 25933)
  make_a_question( exp7_questions, 3, 23533, 26933)
  make_a_question( exp7_questions, 4, 10880, 41127)
  make_a_question( exp7_questions, 5, 41127, 10880)
  make_a_question( exp7_questions, 6, 12386, 82623)

  verbose = True
  exp7_reset_failures()
  exp7_perform_core()