# Accurate Integer Addition in Transformers - Analyse the Model

This CoLab analyses a Transformer model that performs integer addition e.g. 33357+82243=115600. Each digit is a separate token. For 5 digit addition, the model is given 12 "question" (input) tokens, and must then predict the corresponding 6 "answer" (output) tokens.

For speed, this CoLab relies on the model weightings created by the sister CoLab [Accurate_Addition_Train](https://github.com/PhilipQuirke/transformer-maths/blob/main/assets/Accurate_Addition_Train.ipynb). These model weightings are loaded from HuggingFace.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 1: Configuration


In [None]:
# Tokens used in vocab. (Token indexes 0 to 9 represent digits 0 to 9)
# The operator tokens take the indexes immediately after the digits.
PLUS_INDEX = 10
MINUS_INDEX = 11
EQUALS_INDEX = 12

class Config():
  # Notebook-wide settings, read everywhere else through the "cfg" instance created below.
  # Lines tagged with "#@param" / "#@markdown" are Colab form annotations and are left untouched.
  #@markdown Model
  n_layers: int = 2 #@param
  n_heads: int = 3 #@param

  # Vocabulary size: the digit tokens plus every operator token up to EQUALS_INDEX
  d_vocab: int = EQUALS_INDEX+1
  d_model: int = ( 512 // n_heads ) * n_heads # About 512, and divisible by n_heads
  d_mlp: int = 4 * d_model
  d_head: int = d_model // n_heads  # About 170 when n_heads == 3
  seed: int = 129000 #@param

  #@markdown Data
  n_digits: int = 5 #@param
  # Context length: two n_digits operands, the operator, "=", and an (n_digits+1)-digit answer
  n_ctx: int = 3 * n_digits + 3
  act_fn: str = 'relu'
  batch_size: int = 64 #@param

  #@markdown Optimizer
  n_training_steps: int = 30000 #@param
  lr: float = 0.00008 #@param
  weight_decay: float = 0.1 #@param

  # Save graphs to CoLab temp files as PDF and HTML. Can manually export files for re-use in papers.
  save_graph_to_file: bool = True

  # The format to output prettytable in. Options are text|html|json|csv|latex
  # Use Text for this CoLab, latex for Overleaf output, and html for GitHub blog output
  table_out_format: str = "text"


cfg = Config()

In [None]:
def file_name_suffix(digits, layers, heads, d_model, d_head, ctx, seed, training_steps):
  """Build the file-name suffix that encodes a model configuration.

  The number of training steps is reported in whole thousands (integer
  division by 1000) followed by "K".
  """
  thousands_of_steps = training_steps // 1000
  return (
    f'_d{digits}_l{layers}_h{heads}_dm{d_model}_dh{d_head}'
    f'_ctx{ctx}_seed{seed}_train{thousands_of_steps}K'
  )

# Name of the weights file for the current configuration.
# Note that main_fname_suffix already includes the operation prefix; main_fname_full adds the ".pth" extension.
op_prefix = 'add'
fname_prefix = op_prefix
main_fname_suffix = fname_prefix + file_name_suffix(cfg.n_digits, cfg.n_layers, cfg.n_heads, cfg.d_model, cfg.d_head, cfg.n_ctx, cfg.seed, cfg.n_training_steps)
main_fname_full = main_fname_suffix + '.pth'


def print_config():
  """Print the configuration string (main_fname_suffix) so printed results can be tied to a model."""
  print("Config:", main_fname_suffix)

print_config()

# Part 2: Import libraries
Imports standard libraries. Don't bother reading.


In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install matplotlib
    !pip install prettytable

    !pip install kaleido
    !pip install transformer_lens
    !pip install torchtyping
    !pip install transformers

    !pip install numpy -- upgrade
    !pip install scikit-learn -- upgrade

except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Use the Colab renderer when in Colab (or when not in development mode); otherwise the notebook renderer.
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Set the axis-title and chart-title font sizes on the 'plotly' template.
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from prettytable import PrettyTable

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import textwrap

# Use Principal Component Analysis (PCA) library
# use_pca records whether the PCA import succeeded; a failure is reported but does not stop the notebook.
use_pca = True
try:
  from sklearn.decomposition import PCA
except Exception as e:
  # NOTE(review): the broad except looks deliberate (the import can fail for reasons other than a missing package) - confirm
  print("pca import exception:", e)
  use_pca = False

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Part 3: Create model
This section defines the token embedding / unembedding and creates the model.

In [None]:
# Embedding / Unembedding

def token_to_char(i):
  """Return the printable character for token index i.

  Indexes below 10 are rendered with str(i); the operator indexes map to
  "+", "-" and "="; anything else becomes "?".
  """
  if i < 10:
    return str(i)

  operator_chars = ((PLUS_INDEX, "+"), (MINUS_INDEX, "-"), (EQUALS_INDEX, "="))
  for operator_index, operator_char in operator_chars:
    if i == operator_index:
      return operator_char

  return "?"


def tokens_to_string(tokens):
    """Render a 1D sequence of token indexes as text, using at most the first cfg.n_ctx tokens."""
    token_array = utils.to_numpy(tokens)
    visible_tokens = token_array[:cfg.n_ctx]
    return "".join(map(token_to_char, visible_tokens))


def string_to_tokens(string, batch: bool=False):
    """Convert text such as "12+34=" into a tensor of token indexes.

    Spaces and newlines are skipped. Any other character that is not a digit,
    "+", "-" or "=" raises KeyError. With batch=True the result gains a
    leading dimension of size 1.
    """
    char_to_token = {
        **{str(digit): digit for digit in range(10)},
        '+': PLUS_INDEX,
        '-': MINUS_INDEX,
        '=': EQUALS_INDEX,
    }

    token_ids = []
    for char in string:
        if char in '\n ':
            continue
        token_ids.append(char_to_token[char])

    token_tensor = torch.tensor(token_ids)
    return token_tensor.unsqueeze(0) if batch else token_tensor

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
# NOTE: the device is hard-coded to "cuda", so this cell needs a GPU runtime.
ht_cfg = HookedTransformerConfig(
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp,
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    n_ctx = cfg.n_ctx,
    init_weights = True,
    device = "cuda",
    seed = cfg.seed,
)

model = HookedTransformer(ht_cfg)

# NOTE(review): optimizer and scheduler are not used by the analysis code visible in this section;
# presumably kept to mirror the training notebook - confirm before removing.
optimizer = torch.optim.AdamW(model.parameters(),
                        lr = cfg.lr,
                        weight_decay = cfg.weight_decay,
                        betas = (0.9, 0.98))

# Learning-rate multiplier is step/10 for the first 10 scheduler steps and 1 afterwards.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))

# Part 4: Data Generator. Addition sub-task categorisation
This section defines the loss function and the training/testing data generator.

It also defines functions to categorise the training data by the addition sub-task defined in the paper. The addition sub tasks are abbreviated as:
- BA is Base Add. Calculates the sum of two digits Dn and Dn' modulo 10, ignoring any carry over from previous columns.
- MC1 is Make Carry 1. Evaluates to true if adding digits Dn and Dn' results in a carry over of 1 to the next column.
- MS9 is Make Sum 9. Evaluates to true if adding digits Dn and Dn' gives exactly 9.
- UC1 is Use Carry 1. Takes the previous column's carry output and adds it to the sum of the current digit pair.
- US9 is Use Sum 9. Propagates (aka cascades) a carry over of 1 to the next column if the current column sums to 9 and the previous column generated a carry over. US9 is the most complex task as it spans three digits. For some rare questions (e.g. 00555 + 00445 = 01000) US9 applies to up to four sequential digits, causing a chain effect, with the MC1 cascading through multiple digits.

In [None]:
# Loss functions

# Score a batch of prediction "logits" against the answer "tokens"
def logits_to_tokens_loss(logits, tokens):
  """Return per-answer-token log-probabilities and the model's predicted answer tokens.

  Args:
    logits: model output, indexed as (batch, position, vocab).
    tokens: the question+answer tokens, indexed as (batch, position).

  Returns:
    ans_loss: for each of the last cfg.n_digits+1 tokens, the log-probability
      (log_softmax in float64) that the model assigned to the true token.
    max_indices: the argmax token at each of those answer positions.
  """
  answer_len = cfg.n_digits + 1

  # The logits at position p predict token p+1, so the answer predictions are
  # the answer_len positions that end one before the final position.
  ans_logits = logits[:, -(answer_len + 1):-1]

  ans_log_probs = F.log_softmax(ans_logits.to(torch.float64), dim=-1)
  max_indices = ans_log_probs.argmax(dim=-1)

  # The true answer is the last answer_len tokens of each row.
  ans_tokens = tokens[:, -answer_len:]

  # Pick, at every answer position, the log-probability of the true token.
  ans_loss = ans_log_probs.gather(-1, ans_tokens.unsqueeze(-1)).squeeze(-1)

  return ans_loss, max_indices


# Loss is the negated mean (over the first dimension) of the per-token log-probabilities
def loss_fn(ans_loss):
  """Return the negative of ans_loss averaged over dimension 0."""
  mean_over_first_dim = ans_loss.mean(dim=0)
  return mean_over_first_dim.neg()

In [None]:
# Define "iterator" data generator function. Invoked using next().
# "Addition" batch entries are formated XXXXX+YYYYY=ZZZZZZ e.g. 55003+80002=135005
# "Subtraction" batch entries are formated XXXXX-YYYYY=ZZZZZZ e.g. 55003-80002=-24999, 80002-55003=024999
# (only addition batches are produced by the code below)
# Note that answer has one more digit than the question
# Returns characteristics of each batch entry to aid later analysis
# Each next() yields six CUDA tensors: the token batch, then base_adds, make_carry1s,
# sum9s, use_carry1s and use_sum9s (one column per question digit, most significant first).
def data_generator():
    # NOTE: only torch is seeded here. The "random" module calls below are not seeded,
    # so which batches get the sum-9 boost can differ between runs.
    torch.manual_seed(cfg.seed)
    while True:
        #generate a batch of questions (answers calculated below)
        batch = torch.zeros((cfg.batch_size, cfg.n_ctx)).to(torch.int64)
        x = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))
        y = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))


        # The UseSum9 task is compound and rare and so hard to learn.
        # For some batches, we increase the MakeSum9 case frequency
        # UseSum9 also relies on MakeCarry1 (50%) from previous column.
        # NOTE(review): randint(1, 5) < 3 is true for 2 of the 5 possible values, i.e. 40% of
        # batches (not the 60% originally documented), giving 40% * 40% * 50% = 8%.
        if random.randint(1, 5) < 3: # 40% of batches
          # Flatten x and y to 1D tensors
          x_flat = x.view(-1)
          y_flat = y.view(-1)

          num_elements_to_modify = int(0.40 * x.numel()) # 40%
          indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
          # Force the chosen digit pairs to sum to exactly 9 by overwriting one operand
          if random.randint(1, 2) == 1:
            x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
          else:
            y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]

          # Reshape x and y back to its original shape
          x = x_flat.view(x.shape)
          y = y_flat.view(x.shape)


        batch[:, :cfg.n_digits] = x
        batch[:, cfg.n_digits] = PLUS_INDEX
        batch[:, 1+cfg.n_digits:1+cfg.n_digits*2] = y
        batch[:, 1+cfg.n_digits*2] = EQUALS_INDEX

        # These attributes are used for testing addition
        base_adds = torch.zeros((cfg.batch_size,cfg.n_digits)).to(torch.int64)
        make_carry1s = torch.zeros((cfg.batch_size,cfg.n_digits)).to(torch.int64)
        sum9s = torch.zeros((cfg.batch_size,cfg.n_digits)).to(torch.int64)
        use_carry1s = torch.zeros((cfg.batch_size,cfg.n_digits)).to(torch.int64)
        use_sum9s = torch.zeros((cfg.batch_size,cfg.n_digits)).to(torch.int64)

        # generate the addition question answers & other info for testing
        # i counts columns from the least significant digit (i == 0) upwards
        for i in range(cfg.n_digits):
            # the column in the test attributes being updated
            test_col = cfg.n_digits-1-i

            base_add = batch[:, cfg.n_digits-1-i] + batch[:, 2*cfg.n_digits-i]
            base_adds[:, test_col] = base_add % 10

            sum9 = (base_add == 9)
            sum9s[:, test_col] = sum9

            # The carry used by this column is the carry made by the column to its right
            if i>0:
              use_carry1s[:, test_col] = make_carry1s[:, test_col+1]
            use_carry = use_carry1s[:, test_col]

            use_sum9s[:, test_col] = sum9 & use_carry;

            digit_sum = base_add + use_carry1s[:, test_col]

            make_carry = (digit_sum >= 10)
            make_carry1s[:, test_col] = make_carry

            batch[:, -1-i] = (digit_sum % 10)

        # Final (possible) carry to highest digit of the sum
        batch[:, -1-cfg.n_digits] = make_carry1s[:, 0]

        yield batch.cuda(), base_adds.cuda(), make_carry1s.cuda(), sum9s.cuda(), use_carry1s.cuda(), use_sum9s.cuda()

In [None]:
# Create the shared question generator (also used later by make_varied_questions)
ds = data_generator()

# Draw one batch: the tokens plus the five per-digit sub-task tensors
tokens, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = next(ds)

# Show the first question of the batch as raw token indexes
print(tokens[0])

# Part 5: Load Model from HuggingFace

In [None]:
print("Loading model", main_fname_full)

# Download the weights file named after the current configuration from HuggingFace and load it into the model
model.load_state_dict(utils.download_file_from_hf(repo_name="PhilipQuirke/Accurate5DigitAddition", file_name=main_fname_full, force_is_torch=True))

# Put the model into evaluation mode
model.eval()

# Part 8: Sample Questions Set Up

Create sets of sample questions (by task) to ask the model to predict

In [None]:
# Major tag names. UsefulCell tags have the form "Major.Minor"; these strings are the "Major" part.

# Mathematical operations
addition_major_tag = "Add"
subtraction_major_tag = "Sub"
multiplication_major_tag = "Mult"
varied_major_tag = "Varied"

# Answer digit impact
impact_major_tag = "Impact"

# Ablation failure percentage
perc_major_tag = "FailPerc"

# Attention pattern
attention_major_tag = "Attn"

In [None]:
# Write the decimal digits of a number into one row of the question matrix
def insert_question_number(the_question, index, first_digit_index, the_digits, n):
  """Store the last the_digits decimal digits of n in row index of the_question.

  Digits occupy columns first_digit_index .. first_digit_index+the_digits-1,
  most significant first, and are zero padded on the left. the_question is
  modified in place; nothing is returned.
  """
  remaining = n
  units_column = first_digit_index + the_digits - 1

  # Walk from the units column leftwards, peeling off one digit per column
  for column in range(units_column, first_digit_index - 1, -1):
    the_question[index, column] = remaining % 10
    remaining = remaining // 10


# Fill one row with "q1 + q2 = answer"
def make_a_question(the_question, index, q1, q2):
  """Write the addition question q1+q2 and its answer into row index of the_question.

  Layout: cfg.n_digits digits of q1, PLUS_INDEX, cfg.n_digits digits of q2,
  EQUALS_INDEX, then cfg.n_digits+1 digits of the sum.
  """
  answer = q1 + q2
  n = cfg.n_digits

  insert_question_number(the_question, index, 0, n, q1)
  the_question[index, n] = PLUS_INDEX

  insert_question_number(the_question, index, n + 1, n, q2)
  the_question[index, 2*n + 1] = EQUALS_INDEX

  # The answer starts straight after the "=" token
  insert_question_number(the_question, index, 2*n + 2, n + 1, answer)


# Turn a list of [first, second] operand pairs into a batch of tokenised questions
def make_questions(q_matrix):
  """Build an int64 question tensor with one row per usable operand pair.

  Pairs in which either operand is not below 10 ** cfg.n_digits are silently
  skipped, so the result may have fewer rows than q_matrix has entries.
  """
  limit = 10 ** cfg.n_digits
  questions = torch.zeros((len(q_matrix), cfg.n_ctx)).to(torch.int64)

  rows_used = 0
  for pair in q_matrix:
    first, second = pair[0], pair[1]
    if not (first < limit and second < limit):
      continue  # operand does not fit in cfg.n_digits digits

    make_a_question(questions, rows_used, first, second)
    rows_used += 1

  return questions[:rows_used]

In [None]:
# Manually create some questions that strongly test one use case


def make_s0_questions():
    """Return hand-picked questions in which no digit column produces a carry."""
    pairs = [
        (12345, 33333),
        (33333, 12345),
        (45762, 33113),
        (888, 11111),
        (2362, 23123),
        (15, 81),
        (1000, 4440),
        (4440, 1000),
        (24033, 25133),
        (23533, 21133),
        (32500, 1),
        (31500, 1111),
        (5500, 12323),
        (4500, 2209),
        (33345, 66643),  # =099988
        (66643, 33345),  # =099988
        (10990, 44000),
        (60000, 30000),
        (10000, 20000),
    ]
    return make_questions(pairs)


def make_s1_questions():
    """Return hand-picked Use Carry 1 questions: a carry occurs but no digit column sums to 9."""
    pairs = [
        (15, 45),
        (25, 55),
        (35, 59),
        (40035, 40049),
        (5025, 5059),
        (15, 65),
        (44000, 46000),
        (70000, 40000),
        (15000, 25000),
        (35000, 35000),
        (45000, 85000),
        (67000, 85000),
        (99000, 76000),
        (1500, 4500),
        (2500, 5500),
        (3500, 5900),
        (15020, 45091),
        (25002, 55019),
        (35002, 59019),
    ]
    return make_questions(pairs)


# One level UseSum9 cascades: a carry feeds into a single column that sums to 9
def make_s2_questions():
    """Return hand-picked one-level UseSum9 questions."""
    pairs = [
        (55, 45),
        (45, 55),
        (45, 59),
        (35, 69),
        (25, 79),
        (15, 85),
        (15, 88),
        (15508, 14500),
        (14508, 15500),
        (24533, 25933),
        (23533, 26933),
        (32500, 7900),
        (31500, 8500),
        (550, 450),
        (450, 550),
        (10880, 41127),
        (41127, 10880),
        (12386, 82623),
    ]
    return make_questions(pairs)


# Two level UseSum9 cascades: a carry ripples through two consecutive sum-9 columns
def make_s3_questions():
    """Return hand-picked two-level UseSum9 questions."""
    pairs = [
        (555, 445),
        (3340, 6660),
        (8880, 1120),
        (1120, 8880),
        (123, 877),
        (877, 123),
        (321, 679),
        (679, 321),
        (1283, 78785),
    ]
    return make_questions(pairs)


# Three level UseSum9 cascades: a carry ripples through three consecutive sum-9 columns
def make_s4_questions():
    """Return hand-picked three-level UseSum9 questions."""
    pairs = [
        (5555, 4445),
        (55550, 44450),
        (3334, 6666),
        (33340, 66660),
        (8888, 1112),
        (88880, 11120),
        (1234, 8766),
        (4321, 5679),
    ]
    return make_questions(pairs)


# Four level UseSum9 cascades: a carry ripples through four consecutive sum-9 columns
def make_s5_questions():
    """Return hand-picked four-level UseSum9 questions."""
    pairs = [
        (44445, 55555),
        (33334, 66666),
        (88888, 11112),
        (12345, 87655),
        (54321, 45679),
        (45545, 54455),
        (36634, 63366),
        (81818, 18182),
        (87345, 12655),
        (55379, 44621),
    ]
    return make_questions(pairs)


# Questions that mostly exercise a single digit column at a time
# (We're assuming that the 0 + 0 digit additions are trivial bigrams)
def make_sn_questions():
    """Return hand-picked questions whose non-zero digits sit in one or two columns."""
    pairs = [
        (1, 0),
        (4, 3),
        (5, 5),
        (8, 1),
        (40, 30),
        (44, 46),
        (400, 300),
        (440, 460),
        (800, 100),
        (270, 470),
        (600, 300),
        (4000, 3000),
        (4400, 4600),
        (6000, 3000),
        (7000, 4000),
        (40000, 30000),
        (44000, 46000),
        (60000, 30000),
        (70000, 40000),
        (10000, 20000),
        (15000, 25000),
        (35000, 35000),
        (45000, 85000),
        (67000, 85000),
        (99000, 76000),
        (76000, 99000),
    ]
    return make_questions(pairs)


# Returns two generator batches of random questions plus all the manually-chosen sets
def make_varied_questions():
  """Stack random and hand-written questions into one CUDA tensor.

  Row order: one batch from the generator ds, the S0..S5 and SN manual sets,
  then a second batch from ds.
  """
  leading_random = next(ds)[0]

  manual_makers = (
    make_s0_questions,
    make_s1_questions,
    make_s2_questions,
    make_s3_questions,
    make_s4_questions,
    make_s5_questions,
    make_sn_questions,
  )
  manual_sets = [maker() for maker in manual_makers]

  trailing_random = next(ds)[0]

  parts = [leading_random, *manual_sets, trailing_random]
  return torch.vstack([part.cuda() for part in parts])

In [None]:
# The answer has one more digit than each operand. Hence cfg.n_digits+1
def get_answer(q):
  """Return the numeric value of the answer digits stored in question row q."""
  first_answer_pos = 2 * cfg.n_digits + 2  # skip both operands, the operator and "="
  total = 0
  for offset in range(cfg.n_digits + 1):
    total = total * 10 + q[first_answer_pos + offset]
  return total

In [None]:
# Running tallies shared by clear_questions_results / do_questions / print_questions_results
num_questions, correct_answers = 0, 0
total_mean_loss = 0.0
# When True, the helper functions print per-question detail
verbose = True


# Zero the question tallies before a new set of questions is run
def clear_questions_results(title):
  """Reset the global tallies to 0 and, when verbose is set, print title."""
  global num_questions, correct_answers, total_mean_loss

  num_questions = correct_answers = total_mean_loss = 0

  if verbose:
    print(title)


# Ask model to predict answer for each question & collect results
def do_questions(questions):
    """Run the model (no hooks) on questions and update the global tallies.

    Sets num_questions to the number of rows, adds each question's mean loss
    to total_mean_loss and increments correct_answers for every question whose
    predicted answer equals the true answer. Per-question detail is printed
    when verbose is set.
    """
    global num_questions
    global correct_answers
    global total_mean_loss

    # Get model predictions with no hook
    device_questions = questions.cuda()
    all_logits = model(device_questions)
    all_losses_raw, all_max_indices = logits_to_tokens_loss(all_logits, device_questions)

    num_questions = questions.shape[0]
    for question_num in range(num_questions):
        q = questions[question_num]

        mean_loss = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())
        total_mean_loss = total_mean_loss + mean_loss

        model_answer_str = tokens_to_string(all_max_indices[question_num])

        # The model may predict a non-digit token ("+", "-", "=" or "?") in the
        # answer. int() would raise ValueError on such a string, so a non-numeric
        # answer is simply counted as incorrect.
        correct = model_answer_str.isdigit() and int(model_answer_str) == int(get_answer(q))
        if correct:
            correct_answers += 1

        if verbose:
            print(tokens_to_string(q), "ModelAnswer:", model_answer_str, "Correct:", correct, "Loss:", mean_loss )


# Append the summary of the most recent question run to a results table
def print_questions_results(prefix, output_table):
  """Add one row to output_table: label, #questions, #correct, %correct and mean loss."""
  percent_correct = 100*correct_answers/num_questions
  average_loss = total_mean_loss/num_questions

  summary_row = [prefix, num_questions, str(correct_answers), percent_correct, average_loss]
  output_table.add_row(summary_row)

In [None]:
# Build a test batch of random questions (two generator batches) and the manually-chosen questions
varied_questions = make_varied_questions();


# Run the sample batch, gather the cache
model.reset_hooks()
model.set_use_attn_result(True)
sample_logits, sample_cache = model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_indices = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
# loss_fn averages over the questions; .mean() then averages over the answer digits
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean) # Loss < 0.04 is good


# Hook-point names for layers 0 and 1, plus the mean (over all sample questions) of selected layer-0 activations.
# The means are used later as replacement values when ablating.
# NOTE: every name list below is hard-coded for exactly two layers.

# attn.hook_z is the "attention head output" hook point name (at a specified layer)
l_attn_hook_z_name = [utils.get_act_name('z', 0, 'a'),utils.get_act_name('z', 1, 'a')] # 'blocks.0.attn.hook_z' etc
sample_attn_z = sample_cache[l_attn_hook_z_name[0]]
print("Sample", l_attn_hook_z_name[0], sample_attn_z.shape) # gives [239, 18, 3, 170] = #questions, cfg.n_ctx, n_heads, d_head
mean_attn_z = torch.mean(sample_attn_z, dim=0, keepdim=True)
print("Mean", l_attn_hook_z_name[0], mean_attn_z.shape) # gives [1, 18, 3, 170] = 1, cfg.n_ctx, n_heads, d_head


# hook_resid_pre is the "pre residual memory update" hook point name (at a specified layer)
l_hook_resid_pre_name = ['blocks.0.hook_resid_pre','blocks.1.hook_resid_pre']


# hook_resid_post is the "post residual memory update" hook point name (at a specified layer)
l_hook_resid_post_name = ['blocks.0.hook_resid_post','blocks.1.hook_resid_post']
sample_resid_post = sample_cache[l_hook_resid_post_name[0]]
print("Sample", l_hook_resid_post_name[0], sample_resid_post.shape) # gives [239, 18, 510] = #questions, cfg.n_ctx, cfg.d_model
mean_resid_post = torch.mean(sample_resid_post, dim=0, keepdim=True)
print("Mean", l_hook_resid_post_name[0], mean_resid_post.shape) # gives [1, 18, 510] = 1, cfg.n_ctx, d_model


# mlp.hook_pre hook point name (at a specified layer)
l_mlp_hook_pre_name = [utils.get_act_name('pre', 0),utils.get_act_name('pre', 1)] # 'blocks.0.mlp.hook_pre' etc


# mlp.hook_post is the "MLP layer" hook point name (at a specified layer)
l_mlp_hook_post_name = [utils.get_act_name('post', 0),utils.get_act_name('post', 1)] # 'blocks.0.mlp.hook_post' etc
sample_mlp_hook_post = sample_cache[l_mlp_hook_post_name[0]]
print("Sample", l_mlp_hook_post_name[0], sample_mlp_hook_post.shape) # gives [239, 18, 2040] = #questions, cfg.n_ctx, cfg.d_mlp
mean_mlp_hook_post = torch.mean(sample_mlp_hook_post, dim=0, keepdim=True)
print("Mean", l_mlp_hook_post_name[0], mean_mlp_hook_post.shape) # gives [1, 18, 2040] = 1, cfg.n_ctx, cfg.d_mlp

# Part 9: Prediction Analysis By Use Case
This section sets up BA, UC1 and US9 test cases that will be re-used in later experiments to show the impact of ablating heads or token positions.

In [None]:
# Results table for the per-use-case prediction runs
exp0_output = PrettyTable()
exp0_output.field_names = ["Case", "#Questions", "#Correct", "%Correct", "Mean loss"]
# Suppress per-question printing for these runs
verbose = False

In [None]:
# Grand totals accumulated across every question set that is run
sum_total_mean_loss = 0
sum_num_questions = 0


def print_question_results( title1, title2, questions):
  """Run one question set, add its row (labelled title2) to exp0_output and update the grand totals.

  title1 is the heading printed when verbose is set.
  """
  global sum_total_mean_loss, sum_num_questions

  clear_questions_results(title1)
  do_questions(questions)
  print_questions_results(title2, exp0_output)

  sum_total_mean_loss = sum_total_mean_loss + total_mean_loss
  sum_num_questions = sum_num_questions + num_questions

In [None]:
# Run every manually-chosen question set and add one table row per set
print_question_results("Simple BaseAdd cases", "Add.S0", make_s0_questions())
print_question_results("These are Use Carry 1 (UC1) examples (not UseSum9 examples)", "Add.S1", make_s1_questions())
print_question_results("These are simple (one level) UseSum9 exampless", "Add.S2", make_s2_questions())
print_question_results("These are UseSum9 two level cascades", "Add.S3", make_s3_questions())
print_question_results("These are UseSum9 three level cascades", "Add.S4", make_s4_questions())
print_question_results("These are UseSum9 four level cascades", "Add.S5", make_s5_questions())
print_question_results("These questions focus on different answer digits", "Add.SN", make_sn_questions())

# The last column is "Mean loss", so the OVERALL row must divide the summed loss by the
# number of questions (as each per-case row does) rather than report the raw sum.
exp0_output.add_row(["OVERALL", sum_num_questions, "", "", sum_total_mean_loss / sum_num_questions if sum_num_questions else 0])

print_config()
print()
print(exp0_output.get_formatted_string(out_format=cfg.table_out_format))

# Part 10A: Set Up "Quanta" evaluations

Define tools to evaluate quanta

In [None]:
# Compare each answer digit. Returns e.g. "A45", where each digit names an incorrect answer digit
def get_answer_impact(q, answer_str):
  """Describe which answer digits in answer_str differ from the true answer in q.

  Answer digits are numbered cfg.n_digits (leftmost) down to 0 (rightmost).
  Returns "" when every digit matches, otherwise "A" followed by the numbers
  of the mismatching digits, leftmost first.
  """
  a_str = tokens_to_string(q[-(cfg.n_digits+1):])

  wrong_digits = "".join(
    str(cfg.n_digits - i)
    for i in range(cfg.n_digits + 1)
    if answer_str[i] != a_str[i]
  )

  return "A" + wrong_digits if wrong_digits else ""

In [None]:
# Analyse the question and return its complexity case: S0, S1, S2, S3 or S4
def get_question_complexity(q):
  """Classify an addition question by how far a carry cascades.

  Returns:
    "S0": no digit pair sums to more than 9.
    "S1": some pair exceeds 9, but no such pair is immediately followed (on
          its more-significant side) by a pair summing to exactly 9.
    "S2" / "S3" / "S4": a pair exceeding 9 is followed by one / two /
          three-or-more consecutive pairs that each sum to exactly 9.
  """
  n = cfg.n_digits
  digits = utils.to_numpy(q)[:2*n + 2]

  # Digit-pair sums ordered from least significant (index 0) to most significant
  pair_sums = [digits[n - 1 - k] + digits[2*n - k] for k in range(n)]
  makes_carry = [pair_sum > 9 for pair_sum in pair_sums]
  makes_sum9 = [pair_sum == 9 for pair_sum in pair_sums]

  if not any(makes_carry):
    return "S0"

  if not any(makes_sum9):
    return "S1"

  # Look for the longest cascade first: a carry followed by `length` sum-9 columns
  for length, label in ((3, "S4"), (2, "S3"), (1, "S2")):
    for dn in range(n - length):
      if makes_carry[dn] and all(makes_sum9[dn + 1:dn + 1 + length]):
        return label

  return "S1"

In [None]:
# Check get_question_complexity against a set of questions with a known case
def unit_test_get_question_complexity_core(correct_case, questions):
  """Print the set size, then a "Case mismatch" line for every question not classified as correct_case."""
  print( correct_case, "#Questions=", questions.shape[0])
  for question in questions:
    question_case = get_question_complexity(question)
    if question_case == correct_case:
      continue
    print( "Case mismatch:", correct_case, question_case, question)

def unit_test_get_question_complexity():
  """Check every manual question set against its expected complexity case.

  The four-level (s5) set is expected to be classified "S4", because "S4"
  covers cascades of three or more digits.
  """
  expectations = (
    ("S0", make_s0_questions),
    ("S1", make_s1_questions),
    ("S2", make_s2_questions),
    ("S3", make_s3_questions),
    ("S4", make_s4_questions),
    ("S4", make_s5_questions),
  )
  for correct_case, make_set in expectations:
    unit_test_get_question_complexity_core( correct_case, make_set())

unit_test_get_question_complexity()

# Part 10B: Set Up "Quanta" result lists
Define tools to count failures by quanta / metrics. Use prefix "q_"

In [None]:
def increment_dictionary_case_count(dictionary, the_case):
  """Add 1 to dictionary[the_case], starting from 0 when the key is not yet present."""
  dictionary[the_case] = dictionary.get(the_case, 0) + 1

In [None]:
class Q_Config():
  """Tallies of questions asked and questions failed, keyed by complexity case."""

  # Class-level defaults are kept so code that reads them off the class still works.
  # __init__ gives each instance its own dictionaries, so instances never share
  # (or mutate) these class-level objects.
  # Build up a list of questions by case (e.g. Add.S0, Add.S1, Sub.M0, Sub.NG, etc)
  question_complexity_counts = {}
  # Build up a count of question failure cases (e.g. Add.S0, Add.S1, Sub.M0, Sub.NG, etc)
  question_complexity_fails = {}


  def __init__(self):
    self.reset()


  def reset(self):
    """Discard all tallies, giving this instance fresh empty dictionaries."""
    self.question_complexity_counts = {}
    self.question_complexity_fails = {}


  def add_questions_complexity_count(self, questions):
    """Count every question in the batch under its complexity case."""
    for i in range(questions.shape[0]):
      q_case = get_question_complexity(questions[i])
      increment_dictionary_case_count(self.question_complexity_counts, q_case)


  def add_question_complexity_fail(self, the_case):
    """Record one failed question for complexity case the_case."""
    increment_dictionary_case_count(self.question_complexity_fails, the_case)



qcfg = Q_Config()
qcfg.reset()

In [None]:
def q_total_complexity_fails():
  """Return the total number of failures recorded in qcfg, summed over all complexity cases."""
  return sum(qcfg.question_complexity_fails.values())


def q_get_complexity_fails():
  """Format the failure rate per complexity case, e.g. "%S2=40 %S1=5 ".

  Cases are listed in descending order of failure count. Each percentage is
  the case's failures relative to the number of questions of that case,
  rounded to a whole number. Returns "" when nothing failed.
  """
  fails_by_count = sorted(qcfg.question_complexity_fails.items(), key=lambda item: item[1], reverse=True)

  parts = []
  for key, value in fails_by_count:
    percent = round(100 * value / qcfg.question_complexity_counts[key])
    parts.append("%" + key + "=" + str(percent) + " ")

  return "".join(parts)

# Part 12: Ablate ALL Heads in EACH token position. What is the impact on Loss?

Here we ablate all heads in each token position (overriding the model memory aka residual stream) and see if loss increases. If loss increases the token position is used by the algorithm. Unused token positions can be excluded from further analysis. Use "C_" prefix

In [None]:
def quanta_row(the_layer, the_head):
  """Return the table row for a layer/head pair; each layer owns cfg.n_heads+1 consecutive rows."""
  rows_per_layer = cfg.n_heads + 1
  return the_layer * rows_per_layer + the_head


def quanta_rows():
  """Return the total number of table rows: cfg.n_heads+1 rows for each of the cfg.n_layers layers."""
  rows_per_layer = cfg.n_heads + 1
  return rows_per_layer * cfg.n_layers


def get_quanta_row_heading(i):
  """Return the heading for table row i, e.g. "L0H2" for a head or "L1MLP" for a layer's last row."""
  rows_per_layer = cfg.n_heads + 1
  layer = i // rows_per_layer
  slot = i % rows_per_layer

  # The last slot in each layer is the MLP; the others are attention heads
  suffix = "MLP" if slot >= cfg.n_heads else "H" + str(slot)
  return "L" + str(layer) + suffix

In [None]:
class C_Config():
  """State for the token-position ablation experiment."""

  # Class-level defaults are kept so code that reads them off the class still works.
  # __init__ gives each instance its own table and lists; without it, "+=" on the
  # class-level lists mutates them in place and every instance shares the results.
  output = PrettyTable()
  threshold : float = 0.01  # a question is only examined when its mean loss exceeds this

  useful_positions = []  # sparse ordered list of useful token positions e.g. 0,1,8,9,10,11
  useful_rows = []  # sparse ordered list of useful quanta_rows e.g. 0,1,2,3,4,7

  curr_position : int = 0   # zero-based token position to ablate


  def __init__(self):
    self.output = PrettyTable()
    self.useful_positions = []
    self.useful_rows = []


  # Add a token position that we know is used in calculations
  def add_useful_position(self, position):
    """Record position as useful, ignoring duplicates."""
    if position not in self.useful_positions:
      self.useful_positions.append(position)


  # Add a quanta row that we know is used in calculations
  def add_useful_row(self, row):
    """Record row as useful, ignoring duplicates."""
    if row not in self.useful_rows:
      self.useful_rows.append(row)


ccfg = C_Config()
ccfg.output.field_names = ["Position", "Fails", "% Fails by Complexity"]

In [None]:
def q_predict_questions(questions, the_hook):
  """Run the model on questions with the given forward hooks and tally failures in qcfg.

  qcfg is reset first and the questions are counted by complexity case. A
  question is recorded as a failure when its mean loss exceeds ccfg.threshold
  and at least one predicted answer digit is wrong.
  """
  model.reset_hooks()
  model.set_use_attn_result(True)

  device_questions = questions.cuda()
  all_logits = model.run_with_hooks(device_questions, return_type="logits", fwd_hooks=the_hook)
  all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, device_questions)

  qcfg.reset()
  qcfg.add_questions_complexity_count(questions)

  for q, q_losses, q_prediction in zip(questions, all_losses_raw, all_max_prob_tokens):
    the_loss_mean = utils.to_numpy(loss_fn(q_losses).mean())

    # Skip questions whose loss stayed at or below the threshold despite the hook
    if not (the_loss_mean > ccfg.threshold):
      continue

    answer_str = tokens_to_string(q_prediction)

    # Only count the question if the model got at least one answer digit wrong
    impact_str = get_answer_impact( q, answer_str )
    if 'A' not in impact_str:
      continue

    the_complexity = get_question_complexity(q)
    qcfg.add_question_complexity_fail(the_complexity)

    if verbose:
      print(tokens_to_string(q), "ModelAnswer:", answer_str, "Impact:", impact_str, "Complexity:", the_complexity )

In [None]:
verbose = False


def c_set_resid_post_hook(value, hook):
  """Forward hook: overwrite one token position with the sample-mean residual values.

  The position is ccfg.curr_position. value is modified in place for every
  question in the batch (indexed as batch, position, d_model); nothing is
  returned, and the hook argument is not used.
  """
  position = ccfg.curr_position

  # Copy the mean resid post values at this position to all the batch questions
  mean_at_position = mean_resid_post[0, position, :].clone()
  value[:, position, :] = mean_at_position


# Ablate each token position in turn and record what percentage of questions then fail.
# Positions whose ablation causes any failure are remembered in ccfg.useful_positions.
num_questions = 0
# NOTE: nothing is run (and ccfg.useful_positions stays empty) when cfg.n_digits < 5
if cfg.n_digits >= 5 :
  # Hook the post-residual stream of layer 0, and of layer 1 too unless the model has a single layer
  c_fwd_hooks = [(l_hook_resid_post_name[0], c_set_resid_post_hook)] if cfg.n_layers == 1 else [(l_hook_resid_post_name[0], c_set_resid_post_hook),(l_hook_resid_post_name[1], c_set_resid_post_hook)]

  num_questions = varied_questions.shape[0]

  # The loop variable is the attribute itself, so the hook sees the position being ablated
  for ccfg.curr_position in range(cfg.n_ctx):
    q_predict_questions(varied_questions, c_fwd_hooks)

    num_fails = q_total_complexity_fails()
    perc_fails = 0
    if num_fails > 0:
      perc_fails = round(100 * num_fails / num_questions)

      ccfg.add_useful_position(ccfg.curr_position)

    ccfg.output.add_row([str(ccfg.curr_position), str(perc_fails)+"%", q_get_complexity_fails()])

In [None]:
print_config()
# min()/max() raise ValueError on an empty list, which happens when no position was found
# useful (or the ablation loop did not run); default=None prints "None" instead of crashing.
print("num_questions=", num_questions, "min_useful_position=", min(ccfg.useful_positions, default=None), "max_useful_position=", max(ccfg.useful_positions, default=None) )
print()

print(ccfg.output.get_formatted_string(out_format=cfg.table_out_format))

# Part 13: Setup: Useful node (cell) matrix

Uses "u_" prefix.

In [None]:
# One "useful" node of the model: a token position combined with either an
# attention head or an MLP layer, plus the tags (quanta) collected for it.
class UsefulCell():
  # Position.Layer.Head of the cell
  position: int = -1  # token-position. Zero to cfg.n_ctx - 1
  layer: int = -1
  head: int = -1  # head == cfg.n_heads denotes the MLP of the layer (see is_head)

  # Tags related to the cell of form "MajorVersion.MinorVersion".
  # Class-level default only; __init__ and reset give each instance its own list.
  tags = []


  # Is this cell an attention head? If not, it must be an MLP layer
  def is_head(self):
    return self.head != cfg.n_heads


  # Return the cell to its "unset" state
  def reset(self):
    self.position = -1
    self.layer = -1
    self.head = -1
    self.tags = []


  # Remove some/all tags from this cell.
  # An empty major_version removes every tag; otherwise only the tags that start
  # with major_version are removed.
  def reset_tags(self, major_version):
    # The original compared the builtin `filter` with "", which is never true and
    # only produced the right result because every string starts with "".
    if major_version == "":
      self.tags = []
    else:
      self.tags = [s for s in self.tags if not s.startswith(major_version)]


  # Row in a table that this cell is drawn
  def cell_row(self):
    return quanta_row(self.layer, self.head)


  # Add a tag to this cell (if not already present). Empty tags are ignored.
  def add_tag(self, tag):
    if tag != "" and tag not in self.tags:
      self.tags += [tag]


  # Return the minor versions of all tags with the matching major version.
  # The minor version is the text between the first and second "." of the tag.
  def filter_tags(self, major_version):
    assert major_version != ""

    # Filter strings that start with the given major version
    filtered_strings = [s for s in self.tags if s.startswith(major_version)]

    # Extract minor versions
    return [s.split('.')[1] for s in filtered_strings]


  # Return minimum tag with the matching major and minor versions.
  # Returns "" when nothing matches. With show_plus, a "+" is appended when more
  # than one tag matched.
  def min_tag_suffix(self, major_version, minor_version, show_plus = False):
    suffixes = self.filter_tags(major_version)

    if minor_version != "":
      suffixes = [s for s in suffixes if s.startswith(minor_version)]

    answer = min(suffixes) if suffixes else ""

    if show_plus and len(suffixes) > 1:
      answer += "+"

    return answer


  # Return the minor version of the only tag with the matching major_version,
  # or "" when there is none. Raises AssertionError if several tags match.
  def only_tag(self, major_version):
    assert major_version != ""

    filtered_strings = [s for s in self.tags if s.startswith(major_version)]

    num_strings = len(filtered_strings)
    if num_strings > 1:
      print("only_tag logic failure", major_version, num_strings, filtered_strings)
      # Raise explicitly: a bare `assert False` disappears under `python -O`
      raise AssertionError("only_tag logic failure")

    return filtered_strings[0].split('.')[1] if num_strings == 1 else ""


  # Plain-dict view of the cell. Note that "tags" is the live list, not a copy.
  def to_dict(self):
    return {
      "position": self.position,
      "layer": self.layer,
      "head": self.head,
      "tags": self.tags
    }


  def __init__(self, position, layer, head, tags):
    self.position = position
    self.layer = layer
    self.head = head
    self.tags = tags

In [None]:
# Registry of the useful cells found by the ablation experiments, plus the
# "current" position/layer/head that the ablation hooks and add_cell_tag act on.
class U_Config():
  num_heads : int   # Number of distinct attention-head cells registered
  num_mlps : int    # Number of distinct MLP cells registered

  useful_cells : list

  curr_position : int
  curr_layer : int
  curr_head : int   # cfg.n_heads denotes the MLP of the layer


  # Give every instance its own state. Previously the attributes only existed after
  # an explicit reset() call, and the class-level list was shared by all instances.
  def __init__(self):
    self.reset()


  # Forget all cells and counters and point the "current" cell at position 0, layer 0, head 0
  def reset(self):
    self.num_heads = 0
    self.num_mlps = 0
    self.useful_cells = []

    self.curr_position = 0
    self.curr_layer = 0
    self.curr_head = 0


  # Remove tags with the given major version from every cell ("" removes all tags)
  def reset_tags(self, major_version = ""):
    for cell in self.useful_cells:
      cell.reset_tags(major_version)


  # Return the cell drawn at table row the_row for token position the_position,
  # or None if no such cell has been registered
  def get_cell( self, the_row, the_position ):
    for cell in self.useful_cells:
      if cell.position == the_position and cell.cell_row() == the_row:
        return cell

    return None


  # Add a tag to the current cell (curr_position, curr_layer, curr_head),
  # registering the cell first if this is the first tag it receives
  def add_cell_tag( self, tag ):
    assert self.curr_position  >= 0
    assert self.curr_layer >= 0
    assert self.curr_head >= 0
    assert self.curr_position < cfg.n_ctx
    assert self.curr_layer < cfg.n_layers
    assert self.curr_head <= cfg.n_heads  # == cfg.n_heads is allowed: it denotes the MLP

    the_row = quanta_row(self.curr_layer, self.curr_head)
    assert the_row >= 0

    the_cell = self.get_cell(the_row, self.curr_position )
    if the_cell is None:

      the_cell = UsefulCell(self.curr_position , self.curr_layer, self.curr_head, [])

      self.useful_cells += [the_cell]
      if the_cell.is_head():
        self.num_heads += 1
      else:
        self.num_mlps += 1

    the_cell.add_tag(tag)


# Global useful-cell registry shared by the ablation hooks and the quanta maps below.
# reset() initialises the counters, the cell list and the current position/layer/head.
ucfg = U_Config()
ucfg.reset()

In [None]:
# Run the model on `questions` with the ablation hook(s) `the_hook` installed and tag
# the current cell (ucfg.curr_position / curr_layer / curr_head) with what broke:
# one question-complexity tag and one answer-impact tag per failed question, plus
# one failure-percentage tag. If anything failed, the cell's row is marked useful.
def u_predict_questions(questions, the_hook):

  model.reset_hooks()
  model.set_use_attn_result(True)

  all_logits = model.run_with_hooks(questions.cuda(), return_type="logits", fwd_hooks=the_hook)
  all_losses_raw, all_max_prob_tokens = logits_to_tokens_loss(all_logits, questions.cuda())

  fail_count = 0
  for idx in range(questions.shape[0]):
    question = questions[idx]

    mean_loss = utils.to_numpy(loss_fn(all_losses_raw[idx]).mean())

    # Ignore questions whose loss did not rise above the threshold under ablation
    if not (mean_loss > ccfg.threshold):
      continue

    predicted = tokens_to_string(all_max_prob_tokens[idx])
    impact = get_answer_impact( question, predicted )

    # Only count the question if the model got the question wrong
    if 'A' not in impact:
      continue

    fail_count += 1
    complexity = get_question_complexity(question)

    # Add question complexity quanta
    ucfg.add_cell_tag( addition_major_tag + "."+ complexity )

    # Add answer digit impact quanta
    ucfg.add_cell_tag( impact_major_tag + "."+ impact )

    if verbose :
      print(tokens_to_string(question), "U: ModelAnswer:", predicted, "Complexity:", complexity, "Impact:", impact, "Loss:", mean_loss )

  if fail_count == 0:
    return

  # Add percentage failure quanta (truncated to a whole percent)
  failure_perc = int( 100.0 * fail_count / len(questions))
  ucfg.add_cell_tag( perc_major_tag + '.' + str(failure_perc) )

  ccfg.add_useful_row(quanta_row(ucfg.curr_layer, ucfg.curr_head))

In [None]:
# Forward hook that mean-ablates the MLP output at one token position.
# Overwrites `value` in place at ucfg.curr_position with the values stored at the
# same position in mean_mlp_hook_post. Nothing is returned.
# NOTE(review): mean_mlp_hook_post is defined outside this cell - presumably the mean
# MLP post-activation over a sample of questions; confirm against its definition.
def m_mlp_hook_post(value, hook):
  #print( "In m_mlp_hook_post", value.shape) # Get [1, 18, 2040] = ???, cfg.n_ctx, cfg.d_mlp

  # Mean ablate. Copy the mean MLP post values in position N to all the batch questions
  value[:,ucfg.curr_position,:] =  mean_mlp_hook_post[:,ucfg.curr_position,:].clone()


# Mean-ablate the MLP of every layer at every useful token position, one at a time.
# A rise in loss shows that the ablated position+layer MLP is used by the algorithm.
# The MLP is identified by curr_head == cfg.n_heads.
def u_mlp_perform_all(questions):
  ucfg.curr_head = cfg.n_heads

  for position in ccfg.useful_positions:
    ucfg.curr_position = position

    for layer in range(cfg.n_layers):
      ucfg.curr_layer = layer

      mlp_hook = [(l_mlp_hook_post_name[layer], m_mlp_hook_post)]
      u_predict_questions(questions, mlp_hook)

In [None]:
# Forward hook that mean-ablates one attention head at one token position.
# Overwrites `value` in place at (ucfg.curr_position, ucfg.curr_head) with the values
# stored at the same location in mean_attn_z. Nothing is returned.
# NOTE(review): mean_attn_z is defined outside this cell - presumably the mean
# attention z activation over a sample of questions; confirm against its definition.
def h_set_attn_hook_z(value, hook):
  # print( "In h_set_attn_hook_z", value.shape) # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head

  # Mean ablate. Copy the mean attention z values in position N, head H to all the batch questions
  value[:,ucfg.curr_position,ucfg.curr_head,:] = mean_attn_z[:,ucfg.curr_position,ucfg.curr_head,:].clone()


# Mean-ablate every attention head of every layer at every useful token position,
# one at a time, and let u_predict_questions tag the cells that matter.
def u_head_perform_all(questions):
  for position in ccfg.useful_positions:
    ucfg.curr_position = position

    for layer in range(cfg.n_layers):
      ucfg.curr_layer = layer

      for head in range(cfg.n_heads):
        ucfg.curr_head = head

        head_hook = [(l_attn_hook_z_name[layer], h_set_attn_hook_z)]
        u_predict_questions(questions, head_hook)

In [None]:
# Do-nothing forward hook: leaves `value` untouched.
# Not referenced anywhere in the visible part of the notebook.
def h_null_attn_z_hook(value, hook):
  global ucfg

  #print("In h_null_attn_z_hook", value.shape)  # Get [1, 22, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head


# Tag each useful attention-head cell with the (up to 4) token positions it attends
# to most, averaged over `questions`. Tags look like "<attention_major_tag>.<token>=<perc>".
# Existing attention tags are removed first. MLP cells are skipped.
def u_calculate_attention_tags(questions):
  ucfg.reset_tags(attention_major_tag)

  _, cache = model.run_with_cache(questions)

  # Per layer: the attention pattern averaged over dim 0 (the batch of questions)
  mean_patterns = []
  for layer in range(cfg.n_layers):
    layer_pattern = cache["pattern", layer, "attn"]
    mean_patterns.append(layer_pattern.mean(dim=0))

  for cell in ucfg.useful_cells:
    if not cell.is_head():
      continue

    # Attention paid by this head from this cell's token position to every position
    weights = mean_patterns[cell.layer][cell.head, cell.position, :]

    strongest = torch.topk(weights, 4)
    percentages = strongest.values / weights.sum() * 100

    # Add up to 4 tags with percs per head. Sources below 1% are ignored
    for perc, token_idx in zip(percentages, strongest.indices):
      if perc >= 1.0:
        ucfg.curr_position = cell.position
        ucfg.curr_layer = cell.layer
        ucfg.curr_head = cell.head
        ucfg.add_cell_tag( f"{attention_major_tag}.{token_idx}={perc:.0f}" )


In [None]:
# Build the useful-cell matrix from scratch: ablate every MLP, then every attention
# head, at each useful position, and finally add attention tags to the head cells found.
verbose = False
ucfg.reset()
u_mlp_perform_all(varied_questions)
u_head_perform_all(varied_questions)
u_calculate_attention_tags(varied_questions)

 # Part 15: Set up: Draw quanta map

In [None]:
# Build the green-to-yellow linear colormap used by the quanta maps
def create_custom_colormap():
    return mcolors.LinearSegmentedColormap.from_list("custom_colormap", ["green", "yellow"])


# Make a color paler by blending it with white.
# factor is the share of white in the result (0 = unchanged, 1 = pure white).
# The white operand has 4 components, so color is expected to be RGBA.
def pale_color(color, factor=0.5):
    white_part = factor * np.array([1, 1, 1, 1])
    color_part = (1 - factor) * np.array(color)
    return white_part + color_part

In [None]:
# One populated cell of a quanta map: where it sits in the model
# (row, token-position column), the text to show and the index of its color shade.
class quanta_result:
  model_row : int = 0
  model_col : int = 0
  cell_text : str = ""
  color_index :int = -1

  def __init__(self, model_row, model_col, cell_text, color_index):
    self.model_row, self.model_col = model_row, model_col
    self.cell_text, self.color_index = cell_text, color_index


# Ask get_cell_details for the text and color of every (useful row, useful position)
# combination and return a quanta_result for each one that has non-empty text.
def calc_quanta_results( major_version, minor_version, get_cell_details, shades ):

  populated = []

  for row in ccfg.useful_rows:
    for col in ccfg.useful_positions:
      text, shade = get_cell_details(row, col, major_version, minor_version, shades)
      if text == "" :
        continue

      populated.append(quanta_result(model_row=row, model_col=col, cell_text=text, color_index=shade ))

  return populated


# Return the first result located at (row, col), or None if there is none
def find_quanta_result_by_row_col(row, col, quanta_results):
    matching = (
        candidate
        for candidate in quanta_results
        if candidate.model_row == row and candidate.model_col == col
    )
    return next(matching, None)

In [None]:
# Convert a token position to its name. With cfg.n_digits == 5 the positions map to
# D4, .., D0, +, D'4, .., D'0, =, A5, .., A0
def token_position_to_name( position ):
  n = cfg.n_digits

  # First operand, most significant digit first
  if position < n:
    return "D" + str(n-position-1)

  if position == n:
    return "+"

  # Second operand
  if position <= 2 * n:
    return "D'" + str(2*n-position)

  if position == 2 * n + 1:
    return "="

  # Answer digits
  return "A" + str(3*n-position+2)


# Manual check: print the name of every token position for visual inspection.
# Nothing is asserted.
def unit_test_token_position_to_name():
  for i in range (cfg.n_ctx):
    print(token_position_to_name(i))


# Uncomment to run the manual check
#unit_test_token_position_to_name()

In [None]:
# Draw one filled 1x1 square on ax with its lower-left corner at (j, row)
def show_quanta_add_patch(ax, j, row, cell_color):
  ax.add_patch(plt.Rectangle((j, row), 1, 1, fill=True, color=cell_color))


# Draw a "quanta map": a grid with one row per useful model row (head or MLP) and one
# column per useful token position, where each populated cell shows a short text on a
# colored background.
#   title             - graph title (also used to build the PDF file name)
#   custom_cmap       - colormap that is sampled into `shades` pale colors
#   shades            - number of distinct colors; color_index values index into them
#   major_version     - tag major version passed through to get_cell_details
#   minor_version     - tag minor version passed through to get_cell_details
#   get_cell_details  - callback(row, col, major_version, minor_version, shades)
#                       returning (cell_text, color_index)
#   base_fontsize     - font size for short cell texts; longer texts are drawn smaller
#   max_width         - column at which cell text is wrapped
# When cfg.save_graph_to_file is set the figure is saved as a PDF (without a title);
# otherwise a title including the node count is added. The figure is always shown.
def show_quanta_map( title, custom_cmap, shades, major_version, minor_version, get_cell_details, base_fontsize = 10, max_width = 10):

  quanta_results = calc_quanta_results(major_version, minor_version, get_cell_details, shades)

  # Only rows and columns that contain at least one populated cell are drawn
  distinct_rows = set()
  distinct_cols = set()

  for result in quanta_results:
      distinct_rows.add(result.model_row)
      distinct_cols.add(result.model_col)

  distinct_rows = sorted(distinct_rows)
  distinct_cols = sorted(distinct_cols)

  print_config()
  print()

  # Create figure and axes
  fig1, ax1 = plt.subplots(figsize=(2*len(distinct_cols)/3, 2*len(distinct_rows)/3))  # Adjust the figure size as needed

  # Ensure cells are square
  ax1.set_aspect('equal', adjustable='box')
  ax1.yaxis.set_tick_params(labelleft=True, labelright=False)

  # One pale color per shade, sampled evenly from the colormap
  colors = [pale_color(custom_cmap(i/shades)) for i in range(shades)]
  vertical_labels = []
  horizontal_labels = []
  wrapper = textwrap.TextWrapper(width=max_width)


  # The first model row is drawn at the top of the plot, so y starts at the highest value
  show_row = len(distinct_rows)-1
  for raw_row in distinct_rows:
    vertical_labels += [get_quanta_row_heading(raw_row)]

    show_col = 0
    for raw_col in distinct_cols:
      cell_color = 'lightgrey'  # Color for empty cells

      # Collect the column headings exactly once, while drawing the bottom row
      if show_row == 0:
        horizontal_labels += [token_position_to_name(raw_col)]

      result = find_quanta_result_by_row_col(raw_row, raw_col, quanta_results)
      if result != None:
        # A negative color_index means "no color": keep the empty-cell grey
        cell_color = colors[result.color_index] if result.color_index >= 0 else 'lightgrey'
        # Shrink the font for longer texts so they fit in the cell
        the_fontsize = base_fontsize if len(result.cell_text) < 4 else base_fontsize-1 if len(result.cell_text) < 5 else base_fontsize-2
        wrapped_text = wrapper.fill(text=result.cell_text)
        ax1.text(show_col + 0.5, show_row + 0.5, wrapped_text, ha='center', va='center', color='black', fontsize=the_fontsize)

      show_quanta_add_patch(ax1, show_col, show_row, cell_color)
      show_col += 1

    show_row -= 1


  # Configure x axis
  ax1.set_xlim(0, len(horizontal_labels))
  ax1.set_xticks(np.arange(0.5, len(horizontal_labels), 1))
  ax1.set_xticklabels(horizontal_labels)
  ax1.xaxis.tick_top()
  ax1.xaxis.set_label_position('top')
  ax1.tick_params(axis='x', length=0)
  for label in ax1.get_xticklabels():
    label.set_fontsize(9)

  # Configure y axis
  vertical_labels = vertical_labels[::-1] # Reverse the order, because rows were drawn top-down
  ax1.set_ylim(0, len(vertical_labels))
  ax1.set_yticks(np.arange(0.5, len(vertical_labels), 1))
  ax1.set_yticklabels(vertical_labels)
  ax1.tick_params(axis='y', length=0)
  for label in ax1.get_yticklabels():
    label.set_horizontalalignment('left')
    label.set_position((-0.1, 0))  # Adjust the horizontal position
    #label.set_fontsize(9)

  fulltitle = op_prefix + ': ' + title + ' (d{}_l{}_h{})'.format(cfg.n_digits, cfg.n_layers, cfg.n_heads)

  if cfg.save_graph_to_file:
    print("Saving quanta map:", fulltitle)
    # Replace spaces, dashes and colons with underscores to build the file name
    filename = fulltitle.replace( ' ', '_').replace( '-', '_').replace( ':', '_')
    plt.savefig(filename+".pdf", bbox_inches='tight', pad_inches=0)
    #plt.savefig(filename+".svg")
  else:
    ax1.set_title(fulltitle + ' ({} nodes)'.format(len(quanta_results)))

  # Show plot
  plt.show()

# Part 16A: Results: Show failure percentage quanta map

Show the percentage failure rate (incorrect prediction) when individual Attention Heads and MLPs are ablated.

In [None]:
# Debug listing: one line per useful cell with its position, layer, head and tags
for cell in ucfg.useful_cells:
  print( cell.position, cell.layer, cell.head, cell.tags)

In [None]:
# Cell-detail callback for show_quanta_map: failure percentage of a cell.
# Returns (cell_text, color_index). cell_text is "" when no cell exists at (row, col),
# otherwise the percentage taken from the cell's single `major_version` tag, shown as
# "N%" or "<1%". color_index is the shade bucket for that percentage.
# minor_version is unused; it is part of the callback signature.
def get_quanta_fail_percs( row, col, major_version, minor_version, shades):
  cell_text = ""
  color_index = 0

  cell = ucfg.get_cell( row, col )
  if cell is not None:
    cell_text = cell.only_tag( major_version )
    value = int(cell_text) if cell_text != "" else 0

    if value == 100 and ccfg.num_col_headings() > 5:
      value = 99 # Avoid overlapping figures in the matrix.

    # Spread 0..100% over the `shades` buckets and clamp to the last one.
    # The previous `value // shades` produced index `shades` for 100% (out of range
    # whenever the 99 cap above does not apply) and was only right for shades == 10.
    color_index = min(value * shades // 100, shades - 1)
    cell_text = (str(value) if value > 0 else "<1") + "%"

  return cell_text, color_index


show_quanta_map( varied_major_tag, plt.cm.winter, 10, perc_major_tag, "", get_quanta_fail_percs, 10)

# Part 16B: Result: Show attention quanta map

Show attention quanta of useful cells

In [None]:
min_attention_perc = 1 # Only show input tokens with >= 1% of attention


# Cell-detail callback for show_quanta_map: which tokens an attention head attends to.
# Only maps attention heads, not MLP layers (those always yield "", 0).
# Returns (cell_text, color_index): the space-separated names of the attended tokens,
# and a shade derived from the summed attention percentage.
def get_quanta_attention_tag(row, col, major_version, minor_version, shades):
  if "MLP" in get_quanta_row_heading(row):
    return "", 0

  cell = ucfg.get_cell( row, col )
  if cell is None:
    return "", 0

  token_names = []
  sum_perc = 0
  for tag_suffix in cell.filter_tags( major_version ):
    # Each tag suffix has the form "<token position>=<percentage>"
    parts = tag_suffix.split("=")
    token_pos = int(parts[0])
    the_perc = int(parts[1])
    if the_perc >= min_attention_perc:
      token_names.append(token_position_to_name(token_pos))
      sum_perc += the_perc

  cell_text = " ".join(token_names)
  color_index = 10 - sum_perc // 10    # Want >90% => Dark-Green, and <10% => Yellow

  return cell_text, color_index


# Only maps attention heads, not MLP layers
show_quanta_map( "Attention per node", create_custom_colormap(), 10, attention_major_tag, "", get_quanta_attention_tag, 10, 6)

# Part 16C - Show question complexity (S*) quanta map
Show the "minimum" addition purpose of each useful cell by S0 to S4 quanta.

In [None]:
# Cell-detail callback for show_quanta_map: the minimum tag of a cell.
# Returns (cell_text, color_index). cell_text is the smallest matching tag suffix ("" if
# the cell is missing or has no match). The shade is the digit in the second character
# of the text when there is one, otherwise the last shade.
def get_quanta_min_tag(row, col, major_version, minor_version, shades):
  cell = ucfg.get_cell( row, col )
  if cell is None:
    return "", 0

  cell_text = cell.min_tag_suffix( major_version, minor_version )
  if cell_text == "" :
    return cell_text, 0

  has_digit = len(cell_text) > 1 and cell_text[1].isdigit()
  color_index = int(cell_text[1]) if has_digit else shades-1

  return cell_text, color_index


show_quanta_map( "Addition minimum-quanta per node", create_custom_colormap(), 6, addition_major_tag, "", get_quanta_min_tag, 11)

# Part 16D - Show answer impact quanta map

Show the purpose of each useful cell by impact on the answer digits A0 to A5.

In [None]:
# True when every digit is exactly one more than the digit before it.
# Sequences with fewer than two digits count as sequential.
def is_sequential(digits):
    neighbours = zip(digits, digits[1:])
    return all(int(current) + 1 == int(following) for current, following in neighbours)

In [None]:
# Keep only the first occurrence of each character, preserving the original order
def remove_duplicate_digits(input_string):
    # dict keys are unique and keep insertion order
    return "".join(dict.fromkeys(input_string))

# Unit test
# print( remove_duplicate_digits("1231231278321"))

In [None]:
# Cell-detail callback for show_quanta_map: which answer digits a cell impacts.
# Merges the digits of all the cell's `major_version` tags into one text such as "A32"
# or "A1..4" (three or more consecutive digits are shown as a range). The prefix is
# "A-" if any tag contains a dash. The shade is the first digit shown, else the last shade.
# Returns ("", 0) when the cell is missing or has no matching tags.
def get_impact_quanta_range( row, col, major_version, minor_version, shades):

  cell = ucfg.get_cell( row, col )
  if cell is None or len(cell.tags) == 0:
    return "", 0

  tag_suffixes = cell.filter_tags( major_version )
  if len(tag_suffixes) == 0:
    return "", 0

  # Check for '-' sign
  prefix = "A-" if any('-' in suffix for suffix in tag_suffixes) else "A"

  # Collect every digit of every tag, then de-duplicate and sort
  all_digits = "".join(ch for suffix in tag_suffixes for ch in suffix if ch.isdigit())
  digits = sorted(remove_duplicate_digits(all_digits))

  if len(digits) >= 3 and is_sequential(digits):
    digits = f"{digits[0]}..{digits[-1]}"

  # Joining numbers with the appropriate prefix
  cell_text = prefix + ''.join(digits)

  starts_with_digit = len(cell_text) > 1 and cell_text[1].isdigit()
  color_index = int(cell_text[1]) if starts_with_digit else shades-1

  return cell_text, color_index


show_quanta_map( "Answer-digit-impact per node", create_custom_colormap(), cfg.n_digits+2, impact_major_tag, "", get_impact_quanta_range, 11)

# Part 18: Setup: Calc and graph PCA decomposition

In [None]:
# Number of questions generated per case (T8, T9, T10)
tn_questions = 100

# Make tn_questions addition questions where the digits at index test_digit
# (the 10**test_digit place) of the two numbers add up to between 0 and 8.
# The test_digit digits below that place are random in both numbers.
def make_t8_questions(test_digit):
    scale = 10 ** test_digit
    pairs = []
    for _ in range(tn_questions):
        top_x = random.randint(0, 8)
        top_y = random.randint(0, 8-top_x)
        low_x = random.randint(0, scale-1)
        low_y = random.randint(0, scale-1)
        pairs.append([top_x * scale + low_x, top_y * scale + low_y])
    return make_questions(pairs)


# Make tn_questions addition questions where the digits at index test_digit
# (the 10**test_digit place) of the two numbers add up to exactly 9.
# The test_digit digits below that place are random in both numbers.
def make_t9_questions(test_digit):
    scale = 10 ** test_digit
    pairs = []
    for _ in range(tn_questions):
        top_x = random.randint(0, 9)
        top_y = 9 - top_x
        low_x = random.randint(0, scale-1)
        low_y = random.randint(0, scale-1)
        pairs.append([top_x * scale + low_x, top_y * scale + low_y])
    return make_questions(pairs)


# Make tn_questions addition questions where the digits at index test_digit
# (the 10**test_digit place) of the two numbers add up to between 10 and 18.
# The test_digit digits below that place are random in both numbers.
def make_t10_questions(test_digit):
    scale = 10 ** test_digit
    pairs = []
    for _ in range(tn_questions):
        top_x = random.randint(1, 9)
        top_y = random.randint(10-top_x, 9)
        low_x = random.randint(0, scale-1)
        low_y = random.randint(0, scale-1)
        pairs.append([top_x * scale + low_x, top_y * scale + low_y])
    return make_questions(pairs)


# Stack the T8, T9 and T10 questions for test_digit, in that order, into one batch.
# graph_pca relies on this order and on each group holding tn_questions questions.
def make_tricase_questions(test_digit):
  return torch.vstack((
    make_t8_questions(test_digit),
    make_t9_questions(test_digit),
    make_t10_questions(test_digit),
  ))

In [None]:
# Do one Principal Component Analysis.
# Generates fresh T8/T9/T10 questions for answer digit t_digit, runs the model and
# fits a 6-component PCA to the output of head t_head in layer t_layer at token
# position t_position. Returns (fitted pca, projected outputs, short title).
def calc_tricase_pca(t_position, t_layer, t_head, t_digit):
  t_questions = make_tricase_questions(t_digit)
  #print('Sample t8 question:', t_questions[0].tolist())
  #print('Sample t9 question:', t_questions[tn_questions].tolist())
  #print('Sample t10 question:', t_questions[2*tn_questions].tolist())

  _, t_cache = model.run_with_cache(t_questions)

  # Gather the head's output for all the (randomly chosen) questions
  per_question = []
  for question_num in range(len(t_questions)):
    head_results = t_cache["result", t_layer, "attn"] # Output of individual heads, without final bias
    one_question = head_results[question_num]  # Shape [n_ctx, n_head, d_model]
    per_question.append(one_question[t_position, t_head, :])

  attn_outputs = torch.stack(per_question, dim=0).cpu()

  pca = PCA(n_components=6)
  pca.fit(attn_outputs)
  projected = pca.transform(attn_outputs)

  title = f"P{t_position}.L{t_layer}.H{t_head}, A{t_digit}"

  return (pca, projected, title)


# Plot one PCA scatter graph on ax: first principal component against the second.
# The rows of pca_attn_outputs are expected in T8, T9, T10 order with tn_questions
# rows for each of the first two groups. The pca argument itself is not used.
def graph_pca(pca, pca_attn_outputs, ax, title):
  groups = [
    (slice(None, tn_questions), 'red', 'T8: 0-8'),                    # t8 questions
    (slice(tn_questions, 2*tn_questions), 'green', 'T9'),             # t9 questions
    (slice(2*tn_questions, None), 'blue', 'T10: 10-18'),              # t10 questions
  ]
  for rows, color, label in groups:
    ax.scatter(pca_attn_outputs[rows, 0], pca_attn_outputs[rows, 1], color=color, label=label)

  if title != "" :
    ax.set_title(title)

In [None]:
# Graph the PCA of Pn.Ln.Hn's attention head output on the given subplot axis,
# using T8, T9, T10 questions that differ in the An digit.
# Each call generates new random questions and runs the model again.
def add_one_pca_subplot(ax, t_position, t_layer, t_head, t_digit):
  pca, pca_attn_outputs, title = calc_tricase_pca(t_position, t_layer, t_head, t_digit)
  graph_pca( pca, pca_attn_outputs, ax, title)

In [None]:
# Save the current matplotlib figure as a PDF named after full_title
# (spaces and colons become underscores, commas are dropped).
# Does nothing unless cfg.save_graph_to_file is set.
def save_plt_to_file( full_title ):
  if not cfg.save_graph_to_file:
    return

  stem = full_title.replace(" ", "_").replace(",", "").replace(":", "_")
  plt.savefig(stem + '.pdf')

# Part 19: PCA decomposition tri-state results

Plot attention heads in the positions 8 to 16 with a clear "tri-state" response to (exactly) one An.

In [None]:
# Tri-state PCA plots. The cells below are hand-picked for the 5-digit, 2-layer model,
# so nothing is drawn for any other configuration or when use_pca is off.
if cfg.n_digits == 5 and cfg.n_layers == 2 and use_pca :

  # 4x2 grid: seven PCA plots, with the eighth slot used only for the legend
  fig, axs = plt.subplots(4, 2)
  fig.set_figheight(8)
  fig.set_figwidth(5)

  # Plot useful attention heads in positions 8 to 14 with the clearest An selected.
  # Arguments: axis, token position, layer, head, answer digit
  add_one_pca_subplot(axs[0, 0], 8, 0, 1, 2)    # P8.L0.H1 is clear only for A2
  add_one_pca_subplot(axs[0, 1], 9, 0, 1, 1)    # P9.L0.H1 is clear only for A1
  add_one_pca_subplot(axs[1, 0], 11, 0, 1, 3)   # P11.L0.H1 is clear only for A3
  add_one_pca_subplot(axs[1, 1], 11, 0, 2, 4)   # P11.L0.H2 is clear only for A4
  add_one_pca_subplot(axs[2, 0], 12, 0, 1, 3)   # P12.L0.H1 is clear only for A3
  add_one_pca_subplot(axs[2, 1], 13, 0, 1, 2)   # P13.L0.H1 is clear only for A2
  add_one_pca_subplot(axs[3, 0], 14, 0, 1, 1)   # P14.L0.H1 is clear only for A1

  # Reuse the legend entries of the first subplot (all subplots share the same three labels)
  lines_labels = [axs[0,0].get_legend_handles_labels()]
  lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
  # fig.legend(lines, labels, loc='lower center', ncol=4)
  # fig.subplots_adjust(bottom=0.2)  # Adjust the bottom spacing

  axs[3, 1].legend(lines, labels)
  axs[3, 1].axis('off') # Now, to hide the last subplot

  plt.tight_layout()
  save_plt_to_file('PCA_Trigrams')
  plt.show()

# Part 19B: PCA decomposition bi-state results

Plot attention heads in the positions 8 to 16 with a clear "bi-state" response to (exactly) one An.

In [None]:
# Bi-state PCA plots. The cells below are hand-picked for the 5-digit, 2-layer model,
# so nothing is drawn for any other configuration or when use_pca is off.
if cfg.n_digits == 5 and cfg.n_layers == 2 and use_pca :

  # Two plots side by side
  fig, axs = plt.subplots(1, 2)
  fig.set_figheight(2)
  fig.set_figwidth(5)

  # Plot the attention heads in positions 10 and 15, both for answer digit A0.
  # Arguments: axis, token position, layer, head, answer digit
  add_one_pca_subplot(axs[0], 10, 0, 1, 0)   # P10.L0.H1 is clear only for A0
  add_one_pca_subplot(axs[1], 15, 0, 1, 0)   # P15.L0.H1 is clear only for A0

  # Legend entries are collected but not drawn (the fig.legend call is commented out)
  lines_labels = [axs[0].get_legend_handles_labels()]
  lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
  #fig.legend(lines, labels, loc='lower center', ncol=4)

  plt.tight_layout()
  save_plt_to_file('PCA_Bigrams')
  plt.show()

In [None]:
# Do one Principal Component Analysis and graph it in its own figure.
# The figure title includes the explained variance ratio of the first component.
# The figure is saved (if enabled) and shown, then all explained variance ratios are printed.
def run_one_tricase_pca(t_position, t_layer, t_head, t_digit):

  pca, projected, cell_title = calc_tricase_pca(t_position, t_layer, t_head, t_digit)

  # Plot the PCA results (no per-axis title; the figure title is set below)
  _, axis = plt.subplots()
  graph_pca(pca, projected, axis, "")

  full_title = 'PCA of attention: n_digits=' + str(cfg.n_digits) + ', ' + cell_title
  first_evr = round(pca.explained_variance_ratio_[0],3)
  plt.title(full_title + ', EVR[0]=' + str(first_evr) )

  plt.tight_layout()
  save_plt_to_file(full_title)
  plt.show()

  print( "First few principal components explain variance of:", pca.explained_variance_ratio_)

# Part 19C: PCA decomposition of useful cells with digits 0 to 4

Parts 19A and 19B are selective. This part is not. Use it to find (verify) the interesting parts.

In [None]:
# For every useful attention-head cell, show one figure with five PCA plots,
# one per answer digit A0 to A4. MLP cells are skipped.
def graph_all_pca_results():

  for useful_cell in ucfg.useful_cells:
      if not useful_cell.is_head():
        continue

      position, layer, head = useful_cell.position, useful_cell.layer, useful_cell.head
      print( "PCA: position=", position, "layer=", layer, "head=", head)

      fig, axs = plt.subplots(3, 2)

      # Fill the 3x2 grid row by row; the sixth slot stays empty
      for digit in range(5):
        add_one_pca_subplot(axs[digit // 2, digit % 2], position, layer, head, digit)

      plt.tight_layout()
      plt.show()


if use_pca :
  graph_all_pca_results()

# Part 20: Implement Mathematical framework

Demonstrates that the mathematical framework (not the model) can do 1,000,000 additions without error.

In [None]:
# Classify the sum of one digit pair:
#   10 = the sum is 10 or more
#    9 = the sum is exactly 9
#    8 = the sum is 8 or less
def tri_case(dn,dnd):
  total = dn + dnd
  if total == 9:
    return 9
  return 10 if total >= 10 else 8

# Combine the tri-state of a digit pair (dn_cx) with the tri-state coming from the
# lower digits (dm_cy). Only dm_cy == 10 changes the outcome:
#   dn_cx == 10                  -> 10
#   dn_cx == 9 and dm_cy == 10   -> 10
#   dn_cx == 8 and dm_cy == 10   -> 9
#   anything else                -> 8
def tri_add(dn_cx, dm_cy):
  lower_is_ten = dm_cy == 10

  if dn_cx == 10 or (lower_is_ten and dn_cx == 9):
    return 10

  return 9 if (lower_is_ten and dn_cx == 8) else 8

# Mathematical framework for 5-digit addition, written with the tri-state operations.
# d4..d0 are the digits of the first number (most significant first) and d4d..d0d those
# of the second. Returns the six answer digits (a5, a4, a3, a2, a1, a0).
# Each comment gives: quantity | definition | the model node(s) it is mapped to.
def addition_psuedo_code( d4, d3, d2, d1, d0, d4d, d3d, d2d, d1d, d0d ):
  # ---- Tri-state carry information, built up from the least significant digits ----

  # V2.C | TriCase(D2, D2') | P8.L0.H1 and P8.L0.MLP
  v2_c = tri_case(d2, d2d)

  # V1.C | TriCase(D1, D1') | P9.L0.H1 and P9.L0.MLP
  v1_c = tri_case(d1, d1d)

  # V1.C2 | TriAdd(V1.C, TriCase(D0, D0')) | P10.L0.H1 and P10.L0.MLP (copied at P14.L0.H1)
  v1_c2 = tri_add( v1_c, tri_case(d0, d0d) )

  # V2.C3 | TriAdd(V2.C,V1.C2) | P13.L0.H1
  v2_c3 = tri_add(v2_c, v1_c2)

  # V3.C4 | TriAdd(TriCase(D3, D3'), TriAdd(V2.C,V1.C2)) | P11.L0.H1 and again P12.L0.H1
  v3_c4 = tri_add( tri_case(d3, d3d), v2_c3 )

  # V4.C | TriCase(D4, D4') | P11.L0.H2
  # V4.C5 | TriAdd(V4.C, V3.C4) | P11.L0.MLP
  v4_c5 = tri_add( tri_case(d4, d4d), v3_c4 )

  # ---- Base additions: each digit pair without any carry ----

  # V4.BA | (D4 + D4') % 10 | P12.L0.H0 + H2
  v4_ba = (d4 + d4d) % 10

  # V3.BA | (D3 + D3') % 10 | P13.L0.H0 + H2
  v3_ba = (d3 + d3d) % 10

  # V2.BA | (D2 + D2') % 10 | P14.L0.H0 + H2
  v2_ba = (d2 + d2d) % 10

  # V1.BA | (D1 + D1') % 10 | P15.L0.H0 + H2
  v1_ba = (d1 + d1d) % 10

  # D0.MC | (D0 + D0') // 10 | P15.L0.H1
  v0_mc = (d0 + d0d) // 10

  # ---- Answer digits. A tri-state of 10 means "carry one": 10 // 10 == 1, 8 // 10 == 9 // 10 == 0 ----

  # A5 | (V4.C5 == 10) | P11.L1.MLP
  a5 = 1 if v4_c5 == 10 else 0

  # A4 | (V4.BA + V3.C4 / 10) % 10 | P12.L0.MLP and P12.L1.MLP
  a4 = (v4_ba + (v3_c4 // 10)) % 10

  # A3 | (V3.BA + V2.C3 / 10) % 10 | P13.L0.MLP and P13.L1.MLP
  a3 = (v3_ba + (v2_c3 // 10)) % 10

  # A2 | (V2.BA + V1.C2 / 10) % 10 | P14.L0.MLP and P14.L1.MLP
  a2 = (v2_ba + (v1_c2 // 10)) % 10

  # A1 | (V1.BA + D0.MC) % 10 | P15.L0.MLP and P15.L1.MLP
  a1 = (v1_ba + v0_mc) % 10

  # A0 | (D0 + D0') % 10 | P16.L0.H0 + H2 P16.L0.MLP and P16.L1.MLP
  a0 = (d0 + d0d) % 10

  return a5, a4, a3, a2, a1, a0


# Check the mathematical framework on one tokenised 5-digit question.
# Reads the first number from tokens 0-4 and the second from tokens 6-10 (token 5 is
# skipped), runs addition_psuedo_code and compares with real integer addition.
# Returns True when they agree; prints the details and returns False when they differ.
# Returns None (without checking anything) unless cfg.n_digits == 5.
def do_addition_question(question):
  if cfg.n_digits != 5:
    return

  d4, d3, d2, d1, d0 = [int(question[i]) for i in range(0, 5)]
  d4d, d3d, d2d, d1d, d0d = [int(question[i]) for i in range(6, 11)]

  a5, a4, a3, a2, a1, a0 = addition_psuedo_code( d4, d3, d2, d1, d0, d4d, d3d, d2d, d1d, d0d)

  # Rebuild the three integers from their digits
  d = d4 * 10000 + d3 * 1000 + d2 * 100 + d1 * 10 + d0
  dd = d4d * 10000 + d3d * 1000 + d2d * 100 + d1d * 10 + d0d
  a = a5 * 100000 + a4 * 10000 + a3 * 1000 + a2 * 100 + a1 * 10 + a0

  if d + dd == a :
    return True

  print(d4, d3, d2, d1, d0, "+" ,d4d, d3d, d2d, d1d, d0d, "=", a5, a4, a3, a2, a1, a0 )
  print("Bad addition:", d, "+", dd, "=", a, "Should be", d+dd, "Delta", d+dd-a)
  return False

In [None]:
# Check the mathematical framework against about 1,000,000 questions drawn from the
# data generator ds, one batch at a time. Stops after the first batch that contains
# a failure and prints the totals. Does nothing unless cfg.n_digits == 5.
def verify_mathematical_framework():
  if cfg.n_digits != 5:
    return

  num_successes = 0
  num_fails = 0

  num_batches = 1000000//cfg.batch_size
  for epoch in range(num_batches):
    tokens, _, _, _, _, _ = next(ds)

    num_fails += sum(1 for i in range(cfg.batch_size) if not do_addition_question(tokens[i]))

    if num_fails > 0:
      break

    # Successes are counted per whole batch
    num_successes += cfg.batch_size
    if epoch % 250 == 0:
        print("Batch", epoch, "of", num_batches, "#Successes=", num_successes)

  print("successes", num_successes, "num_fails", num_fails)


# Uncomment to run the check (slow)
#verify_mathematical_framework()

# Part 21A : Set Up Interchange Interventions

Here we test our mapping of our mathematical framework (causal abstraction) to the model attention heads.







In [None]:
# State shared by the interchange-intervention hooks and helpers.
# All values are class-level defaults; the helpers assign to them through the single
# instance `acfg` created below.
class A_Config():
  token_position : int = 15 # The token position we want to get/set. P8 to P11 contribute to A5 calculations
  layer : int = 0 # The layer we want to get/set
  heads = [] # The heads we want to get/set
  threshold : float = 0.00001 # A question counts as failed when its mean loss exceeds this

  hook_calls: int = 0          # Number of hook invocations in the current prediction
  answer_failures : int = 0    # Failures of any digit

  questions = []  # Questions of the current prediction (set from make_questions by the helpers)
  store = []      # Activations captured by the "get" hooks, replayed by the "put" hooks

  # Hook lists passed to model.run_with_hooks; filled in by a_reset
  null_hooks = []
  get_hooks = []
  put_hooks = []


acfg = A_Config()

In [None]:
# Get and put attention head value hooks

# Hook that leaves the activation untouched and only counts that it was called.
# Used for the "Unit test (null hook)" baseline predictions.
def a_null_attn_z_hook(value, hook):
  global acfg

  acfg.hook_calls += 1
  #print("In a_null_attn_z_hook", value.shape)  # Get [1, 18, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head


# "Get" hook for layer 0: when layer 0 is the layer being intervened on, count the
# call and keep a copy of the whole activation in acfg.store. Otherwise do nothing.
def a_get_l0_attn_z_hook(value, hook):
  if acfg.layer != 0:
    return

  acfg.hook_calls += 1
  # print( "In a_get_l0_attn_z_hook", value.shape) # Get [1, 18, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head
  acfg.store = value.clone()


# "Get" hook for layer 1: when layer 1 is the layer being intervened on, count the
# call and keep a copy of the whole activation in acfg.store. Otherwise do nothing.
def a_get_l1_attn_z_hook(value, hook):
  if acfg.layer != 1:
    return

  acfg.hook_calls += 1
  # print( "In a_get_l1_attn_z_hook", value.shape) # Get [1, 18, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, cfg.d_head
  acfg.store = value.clone()


# "Put" hook for layer 0: when layer 0 is the layer being intervened on, overwrite (in
# place) the selected heads at the selected token position with the stored activation.
def a_put_l0_attn_z_hook(value, hook):
  if acfg.layer != 0:
    return

  acfg.hook_calls += 1
  # print( "In a_put_l0_attn_z_hook", value.shape) # Get [1, 18, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, d_head
  position = acfg.token_position
  for head_index in acfg.heads:
    value[:,position,head_index,:] = acfg.store[:,position,head_index,:].clone()


# "Put" hook for layer 1: when layer 1 is the layer being intervened on, overwrite (in
# place) the selected heads at the selected token position with the stored activation.
def a_put_l1_attn_z_hook(value, hook):
  if acfg.layer != 1:
    return

  acfg.hook_calls += 1
  # print( "In a_put_l1_attn_z_hook", value.shape) # Get [1, 18, 3, 170] = ???, cfg.n_ctx, cfg.n_heads, d_head
  position = acfg.token_position
  for head_index in acfg.heads:
    value[:,position,head_index,:] = acfg.store[:,position,head_index,:].clone()


# Prepare acfg for one intervention: remember which position, layer and heads to
# intervene on, clear the counters, and build the null/get/put hook lists.
# The get and put lists hook both layer 0 and layer 1; each hook only acts when
# its layer equals acfg.layer.
def a_reset(token_position, layer, heads):
  acfg.token_position = token_position
  acfg.layer = layer
  acfg.heads = heads

  acfg.hook_calls = 0
  acfg.answer_failures = 0

  acfg.null_hooks = [
    (l_attn_hook_z_name[0], a_null_attn_z_hook),
  ]
  acfg.get_hooks = [
    (l_attn_hook_z_name[0], a_get_l0_attn_z_hook),
    (l_attn_hook_z_name[1], a_get_l1_attn_z_hook),
  ]
  acfg.put_hooks = [
    (l_attn_hook_z_name[0], a_put_l0_attn_z_hook),
    (l_attn_hook_z_name[1], a_put_l1_attn_z_hook),
  ]

In [None]:
def a_predict_question(description, the_hooks, always):
  """Run the model on acfg.questions with the given hooks and report the answers.

  description: text printed at the start of each report line.
  the_hooks: forward hooks passed to model.run_with_hooks.
  always: if true, print a line for every question; otherwise only for
    questions whose loss exceeds acfg.threshold.

  Resets acfg.hook_calls and acfg.answer_failures, then counts one answer
  failure per question whose loss exceeds acfg.threshold.
  """
  acfg.hook_calls = 0
  acfg.answer_failures = 0

  model.reset_hooks()
  model.set_use_attn_result(True)

  questions_on_gpu = acfg.questions.cuda()
  all_logits = model.run_with_hooks(questions_on_gpu, return_type="logits", fwd_hooks=the_hooks)
  all_losses_raw, all_max_indices = logits_to_tokens_loss(all_logits, questions_on_gpu)

  for question_num in range(acfg.questions.shape[0]):
    loss_max = utils.to_numpy(loss_fn(all_losses_raw[question_num]).mean())
    answer_str = tokens_to_string(all_max_indices[question_num])

    failed = loss_max > acfg.threshold

    # The impacted digits are only worked out for a wrong answer.
    impact = ""
    if failed:
      acfg.answer_failures += 1
      impact = get_answer_impact( acfg.questions[question_num], answer_str )
    match_str = "(none)" if impact == "" else impact

    if not (always or failed):
      continue

    # A loss this close to zero is reported as no loss at all.
    if loss_max < 1e-7:
      loss_str = "(none)"
    else:
      loss_str = str(loss_max)

    print(description, "  ModelPredicts:", answer_str, "  DigitsImpacted:", match_str, "  Loss:", loss_str)


In [None]:
def a_run_intervention_core(token_position, layer, heads, store_question, alter_question):
  """Patch attention-head output from one question into a run of another.

  First runs store_question, saving the attention z activation of `layer`
  via the get hooks. Then runs alter_question with the put hooks, which
  overwrite the listed `heads` at `token_position` with the saved values,
  and always prints the resulting prediction.
  """
  a_reset(token_position, layer, heads)

  # Pass 1: sanity-check with the null hook, then store the activations
  # (including the Vn.BA) of the first question.
  acfg.questions = make_questions([store_question])
  a_predict_question("Unit test (null hook)", acfg.null_hooks, False)
  a_predict_question("Store activation", acfg.get_hooks, False)

  # Pass 2: sanity-check the second question, then rerun it while
  # overriding Pn_Lm_Hp to give a bad answer.
  acfg.questions = make_questions([alter_question])
  a_predict_question("Unit test (null hook)", acfg.null_hooks, False)

  # Every head index is followed by a comma, including the last one.
  head_list = "".join(str(head_index) + "," for head_index in acfg.heads)
  prompt = f"Intervening on P{token_position}.L{layer}.H" + head_list
  a_predict_question(prompt, acfg.put_hooks, True)


def a_run_intervention(description, token_position, layer, heads, store_question, alter_question):
  """Print `description` and run one interchange intervention.

  Does nothing unless the model is configured with 5 digits and 2 layers
  (cfg.n_digits == 5 and cfg.n_layers == 2).
  """
  if cfg.n_digits != 5 or cfg.n_layers != 2:
    return

  print(description)
  a_run_intervention_core(token_position, layer, heads, store_question, alter_question)
  print()

# Part 21B : Run Interchange Interventions

Here we test the mapping of our mathematical framework (causal abstraction) to the model's attention heads.


In [None]:
# Interchange interventions on token position 8, layer 0, head 1.
# Each a_run_intervention call stores the head's activation from a run of
# store_question and patches it into a run of alter_question.
print( "Test claim that P8.L0.H1 performs V2.C = TriCase(D2, D2’) impacting A4 and A5 accuracy")
print()

# Control: neither question has a V2 carry.
store_question = [44444, 55555] # Sum is 099999. V2 has no MC.
alter_question = [11111, 11111] # Sum is 022222. V2 has no MC.
a_run_intervention("No V2.MC: No impact expected", 8, 0, [1], store_question, alter_question)

store_question = [77711, 22711] # Sum is 100422. V2 has MC
alter_question = [44444, 55555] # Sum is 099999. V2 has no MC
a_run_intervention("Insert V2.MC: Expect A54 digit impacts. Expect 109999.", 8, 0, [1], store_question, alter_question)

store_question = [17711, 22711] # Sum is 040422. V2 has MC
alter_question = [ 4444,  5555] # Sum is 009999. V2 has no MC
a_run_intervention("Insert V2.MC: Expect A4 digit impact. Expect 019999.", 8, 0, [1], store_question, alter_question)

# Confirmed that P8.L0.H1 is: Based on D2 and D2'. Triggers on a V2 carry value. Provides "carry 1" used in A5 and A4 calculation.

In [None]:
# Interchange interventions on token position 9, layer 0, head 1.
# Each a_run_intervention call stores the head's activation from a run of
# store_question and patches it into a run of alter_question.
print( "Test claim that P9.L0.H1 performs V1.C = TriCase(D1, D1’) impacting A5, A4 & A3 accuracy")
print()

# Control: neither question has a V1 carry.
store_question = [ 44444, 55555] # Sum is 099999. V1 has no MC.
alter_question = [ 11111, 11111] # Sum is 022222. V1 has no MC
a_run_intervention("No V1.MC: No impact expected", 9, 0, [1], store_question, alter_question)

store_question = [ 11171, 11171] # Sum is 022342. V1 has MC
alter_question = [ 44444, 55555] # Sum is 099999. V1 has no MC.
a_run_intervention("Insert V1.MC: Expect A543 digit impacts. Expect 100999.", 9, 0, [1], store_question, alter_question)

store_question = [ 11171, 11171] # Sum is 022342. V1 has MC
alter_question = [  4444,  5555] # Sum is 009999. V1 has no MC
a_run_intervention("Insert V1.MC: Expect A43 digit impacts. Expect 010999.", 9, 0, [1], store_question, alter_question)

store_question = [ 11171, 11171] # Sum is 022342. V1 has MC
alter_question = [   444,   555] # Sum is 000999. V1 has no MC
a_run_intervention("Insert V1.MC: Expect A3 digit impact. Expect 001999.", 9, 0, [1], store_question, alter_question)

# Confirmed that P9.L0.H1 is: Based on D1 and D1'. Triggers on a V1 carry value. Provides "carry 1" used in A5, A4 & A3 calculation.

In [None]:
# Interchange interventions on token position 10, layer 0, head 1.
# Each a_run_intervention call stores the head's activation from a run of
# store_question and patches it into a run of alter_question.
print( "Test claim that P10.L0.H1 performs V1.C2 = TriAdd(V1.C, TriCase(D0, D0’)) impacting A5, A4, A3 & A2 accuracy")
print()

# Control: neither question has a V0 carry.
store_question = [ 11111, 33333] # Sum is 044444. V0 has no MC.
alter_question = [ 44444, 55555] # Sum is 099999. V0 has no MC
a_run_intervention("No impact expected", 10, 0, [1], store_question, alter_question)
# Results: No impact as expected

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
alter_question = [ 44444, 55555] # Sum is 099999. V0 has no MC
a_run_intervention("Insert D0.MC: Expect A5432 digit impacts. Expect 100099.", 10, 0, [1], store_question, alter_question)
# Results: Impact on A5432 as expected. Got expected value 100099

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
alter_question = [  4444,  5555] # Sum is 009999. V0 has no MC
a_run_intervention("Insert D0.MC: Expect A432 digit impacts. Expect 010099.", 10, 0, [1], store_question, alter_question)

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
alter_question = [   444,   555] # Sum is 000999. V0 has no MC
a_run_intervention("Insert D0.MC: Expect A32 digit impacts. Expect 001099", 10, 0, [1], store_question, alter_question)

store_question = [ 11117, 11117] # Sum is 022234. V0 has MC
alter_question = [    44,    55] # Sum is 000099. V0 has no MC
a_run_intervention("Insert D0.MC Expect A2 digit impacts. Expect 000199", 10, 0, [1], store_question, alter_question)

# Confirmed that P10.L0.H1 is: Based on D0 and D0'. Triggers on a V0 carry value. Provides "carry 1" used in A5, A4, A3 & A2 calculation.

In [None]:
# Interchange interventions on token position 11, layer 0, head 1.
# Each a_run_intervention call stores the head's activation from a run of
# store_question and patches it into a run of alter_question.
print( "Test claim that P11.L0.H1 performs V3.C4 = TriAdd(TriCase(D3, D3’),TriAdd(V2.C,V1.C2)) impacting A5 accuracy")
print()

# Control: D3 digits sum to 8, so there is no V3 carry.
store_question = [44444, 44444] # Sum is 088888. V3 sums to 8 (has no MC).
alter_question = [11111, 11111] # Sum is 022222. V3 has no MC.
a_run_intervention("No V3.MC: No impact expected", 11, 0, [1], store_question, alter_question)

# Control: D3 digits sum to 9, still no V3 carry.
store_question = [16111, 13111] # Sum is 029222. V3 sums to 9 (has no MC).
alter_question = [44444, 55555] # Sum is 099999. V3 has no MC
a_run_intervention("No V3.MC: No impact expected", 11, 0, [1], store_question, alter_question)

store_question = [16111, 16111] # Sum is 032222. V3 has MC
alter_question = [44444, 55555] # Sum is 099999. V3 has no MC
a_run_intervention("Insert V3.MC: Expect A5 digit impact. Expect 199999.", 11, 0, [1], store_question, alter_question)

# Confirmed that P11.L0.H1 is: Based on D3 and D3'. Triggers on a V3 carry value. Provides "carry 1" used in A5 calculations.

In [None]:
# Interchange interventions on token position 11, layer 0, head 2.
# Each a_run_intervention call stores the head's activation from a run of
# store_question and patches it into a run of alter_question.
print( "Test claim that P11.L0.H2 performs V4.C = TriCase(D4, D4’) impacting A5 accuracy")
print()

# Control: neither question has a V4 carry.
store_question = [44444, 55555] # Sum is 099999. V4 has no MC.
alter_question = [11111, 11111] # Sum is 022222. V4 has no MC.
a_run_intervention("No V4.MC: No impact expected", 11, 0, [2], store_question, alter_question)

store_question = [71111, 71111] # Sum is 142222. V4 has MC
alter_question = [44444, 55555] # Sum is 099999. V4 has no MC
a_run_intervention("Insert V4.MC: Expect A5 digit impact. Expect 199999.", 11, 0, [2], store_question, alter_question)

# Confirmed that P11.L0.H2 is: Based on D4 and D4'. Triggers on a V4 carry value. Provides "carry 1" used in A5 calculation.

In [None]:
# Interchange intervention on token position 12, layer 0, heads 0 and 2 together.
# The heads' activations from a run of store_question are patched into a run
# of alter_question.
print( "Test claim that P12.L0.H0 and H2 performs V4.BA = (D4 + D4’) % 10 impacting A4 accuracy")
print()

store_question = [72222, 71111] # Sum is 143333
alter_question = [12342, 56573] # Sum is 068915
a_run_intervention("Override D4/D4'. Expect A4 digit impact. Expect 048915", 12, 0, [0,2], store_question, alter_question)

# Confirmed that P12.L0.H0+H2 is: Adds D4 and D4'. Impacts A4

In [None]:
# Interchange intervention on token position 13, layer 0, heads 0 and 2 together.
# The heads' activations from a run of store_question are patched into a run
# of alter_question.
print( "Test claim that P13.L0.H0 and H2 performs V3.BA = (D3 + D3’) % 10 impacting A3 accuracy")
print()

store_question = [23222, 13111] # Sum is 36333
alter_question = [12342, 56573] # Sum is 68915
a_run_intervention("Override D3/D3'. Expect A3 digit impact. Expect 66915", 13, 0, [0,2], store_question, alter_question)

# Confirmed that P13.L0.H0+H2 is: Adds D3 and D3'. Impacts A3

In [None]:
# Interchange interventions on token position 14, layer 0, heads 0 and 2 together.
# The heads' activations from a run of store_question are patched into a run
# of alter_question.
print( "Test claim that P14.L0.H0 and H2 performs V2.BA = (D2 + D2’) % 10 impacting A2 accuracy")
print()

# The altered question carries from V1, so A2 is the patched D2 sum plus one.
store_question = [22322, 11311] # Sum is 33633. No V1.MC
alter_question = [12342, 56573] # Sum is 68915. Has V1.MC
a_run_intervention("Override D2/D2'. Expect A2 digit impact. Expect 68715", 14, 0, [0,2], store_question, alter_question)

# The altered question has no V1 carry, so A2 is the patched D2 sum itself.
store_question = [22322, 11311] # Sum is 33633. No V1.MC
alter_question = [12133, 56133] # Sum is 68266. No V1.MC
a_run_intervention("Override D2/D2'. Expect A2 digit impact. Expect 68666", 14, 0, [0,2], store_question, alter_question)

# Confirmed that P14.L0.H0 and H2 both impact A2, and together sum D2 and D2'

In [None]:
# Interchange interventions comparing head 1 of layer 0 at token position 14
# with the same head at token position 10. Every store/alter pair is run
# twice, once per position; the trailing comments record the observed answer.
print( "Test claim that P14.L0.H1 calculates V1.C1 but also relies on P10.V1.C2, impacting A2 accuracy")
print()


store_question = [55555, 44454] # Sum is 100009. Has V1.MC
alter_question = [22222, 33333] # Sum is 055555. No V1.MC
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 055655", 14, 0, [1], store_question, alter_question) # Get 055655. Correct
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 055655", 10, 0, [1], store_question, alter_question) # Get 055555. No impact.

store_question = [55590, 44490] # Sum is 100080. Has V1.MC
alter_question = [12345, 54321] # Sum is 066666. No V1.MC
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 066766", 14, 0, [1], store_question, alter_question) # Get 066766. Correct
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 066766", 10, 0, [1], store_question, alter_question) # Get 066666. No impact.

store_question = [12345, 54321] # Sum is 066666. No V1.MC
alter_question = [55590, 44490] # Sum is 100080. Has V1.MC
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 100980", 14, 0, [1], store_question, alter_question) # Get 100980. Correct
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 100980", 10, 0, [1], store_question, alter_question) # Get 100080. No impact.

# Above shows:
# - P14.L0.H1 behaviour is different from P10.L0.H1 behaviour
# - P14.L0.H1 does not simply copy P10.L0.H1 (although this would be a valid way to get perfect accuracy in A2)
print()

store_question = [12345, 54321] # Sum is 066666. No V1.MC
alter_question = [55555, 44445] # Sum is 100000. Has V0.MC, V1.MC, V1.C2
a_run_intervention("Override V1.MC impacting V1.C2. Expect A2 digit impact. Expect 100900", 14, 0, [1], store_question, alter_question) # Get 100900. Correct
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 099900", 10, 0, [1], store_question, alter_question) # Get 099900. Correct

store_question = [22222, 33333] # Sum is 055555. No V1.MC
alter_question = [66663, 33337] # Sum is 100000. Has V0.MC, V1.MC, V1.C2
a_run_intervention("Override V1.MC impacting V1.C2. Expect A2 digit impact. Expect 100900", 14, 0, [1], store_question, alter_question) # Get 100900. Correct
a_run_intervention("Override V1.MC impacting V1.C1. Expect A2 digit impact. Expect 099900", 10, 0, [1], store_question, alter_question) # Get 099900. Correct

# Above shows:
# - P14.L0.H1 does rely on P10.L0.H1 for V1.C2 information when V1.C != V1.C2
# - P14.L0.H1 calculates V1.C information itself from D1+D1'.

# Overall confirmed: P14.L0.H1 calculates V1.C1 but also relies on P10.V1.C2 when V1.C != V1.C2. Impacts A2

In [None]:
# Interchange intervention on token position 15, layer 0, heads 0 and 2 together.
# The heads' activations from a run of store_question are patched into a run
# of alter_question.
print( "Test claim that P15.L0.H0 and H2 performs V1.BA = (D1 + D1’) % 10 impacting A1 accuracy")
print()

store_question = [22242, 11141] # Sum is 33383. No V0.MC
alter_question = [12322, 56523] # Sum is 68845. No V0.MC
a_run_intervention("Override D1/D1'. Expect A1 digit impact. Expect 68885", 15, 0, [0,2], store_question, alter_question)

# Confirmed that P15.L0.H0 and H2 both impact A1, and together sum D1 and D1'

In [None]:
# Interchange interventions on token position 15, layer 0, head 1.
# Each a_run_intervention call stores the head's activation from a run of
# store_question and patches it into a run of alter_question.
print( "Test claim that P15.L0.H1 performs V0.MC = (D0 + D0’) / 10 impacting A1 (A one) accuracy")
print()

store_question = [22244, 11149] # Sum is 33393. Has V0.MC
alter_question = [12342, 56513] # Sum is 68855. No V0.MC
a_run_intervention("Override D0.MC. Expect A1 digit impact. Expect 68865", 15, 0, [1], store_question, alter_question)

# Now test counter-claim that an intervention where both questions do NOT generate a D0.MC has NO impact on A1
store_question = [22242, 11141] # Sum is 33383. No V0.MC
alter_question = [12342, 56523] # Sum is 68865. No V0.MC
a_run_intervention("No impact expected", 15, 0, [1], store_question, alter_question)

# Confirmed that P15.L0.H1: Triggers when D0 + D0' >= 10. Impacts A1 digit by 1

In [None]:
# Interchange interventions on token position 16, layer 0, heads 0 and 2 together.
# The heads' activations from a run of store_question are patched into a run
# of alter_question.
print( "Test claim that P16.L0.H0 and H2 performs D0.BA = (D0 + D0’) % 10 impacting A0 accuracy")
print()

# Stored D0 digits sum to 9 (no carry).
store_question = [22225, 11114] # Sum is 33339
alter_question = [12342, 56563] # Sum is 68905
a_run_intervention("Override D0/D0'. Expect A0 digit impact. Expect 68909", 16, 0, [0,2], store_question, alter_question)

# Stored D0 digits sum to 17, so only the units digit 7 is expected in A0.
store_question = [22228, 11119] # Sum is 33347
alter_question = [12342, 56563] # Sum is 68905
a_run_intervention("Override D0/D0'. Expect A0 digit impact. Expect 68907", 16, 0, [0,2], store_question, alter_question)

# Confirmed that P16.L0.H0 and H2 both impact A0, and together sum D0 and D0'

# Part 22: Ablate each NEURON in each useful MLP layer. What is impact on loss?

Uses n_ prefix. Determines which neurons are useful, so we can focus on them in Part 25 & manually view them.






In [None]:
class N_Config:
  # Shared settings and scratch state for the neuron-ablation code (n_ prefix).
  position: int = 0  # zero-based token-position to ablate
  layer: int = 0  # zero-based layer to ablate. 0 to cfg.n_layers
  neuron: int = 0  # zero-based neuron to ablate. 0 to cfg.d_mlp
  questions = varied_questions  # questions the ablated model is tested on
  output = PrettyTable()  # results table; n_reset() replaces it with a fresh one
  hook_calls: int = 0  # number of times n_mlp_hook_post has run


# Single instance read and written by the n_ functions.
ncfg = N_Config()


def n_reset():
  """Give ncfg a fresh, empty results table and zero its hook-call counter."""
  table = PrettyTable()
  table.field_names = ["Position", "MLP Layer", "Neuron", "% Fails", "% Fails by Case", "# Fails by Patterns"]

  ncfg.output = table
  ncfg.hook_calls = 0


def n_mlp_hook_post(value, hook):
  """Mean-ablate one MLP neuron at one token position, in place.

  The neuron (ncfg.neuron) and token position (ncfg.position) come from the
  shared ncfg state; the replacement value is read from mean_mlp_hook_post.
  """
  ncfg.hook_calls = ncfg.hook_calls + 1
  # Shape seen while debugging: [1, 18, 2040] = ???, cfg.n_ctx, cfg.d_mlp

  position, neuron = ncfg.position, ncfg.neuron
  value[:, position, neuron] = mean_mlp_hook_post[:, position, neuron].clone()

In [None]:
class UsefulNeuron:
  # Identifies one MLP neuron, at one token position, that was recorded as useful.
  position: int = 0  # zero-based token-position to ablate
  layer: int = 0  # zero-based layer to ablate. 0 to cfg.n_layers
  neuron: int = 0  # zero-based neuron to ablate. 0 to cfg.d_mlp


# Neurons recorded so far; starts empty.
useful_neurons = []


def n_perform_core(show_all = False):
  """Ablate the neuron selected in ncfg and record it if that causes failures.

  Runs ncfg.questions through the model with n_mlp_hook_post attached to the
  MLP layer ncfg.layer. When at least one question fails (or show_all is
  true), adds a row to the ncfg.output table and appends a UsefulNeuron
  describing the position/layer/neuron to the module-level useful_neurons list.

  show_all: record the neuron even when the ablation caused no failures.
  """
  the_hook = [(l_mlp_hook_post_name[ncfg.layer], n_mlp_hook_post)]
  q_predict_questions(ncfg.questions, the_hook)

  num_fails = q_total_complexity_fails()
  if not (show_all or num_fails > 0):
    return

  perc_fails = round(100 * num_fails / ncfg.questions.shape[0])
  # Only the per-pattern summary is reported; the second item is not used here.
  pattern_results, _ = get_sorted_impact_fails()

  ncfg.output.add_row([str(ncfg.position), str(ncfg.layer), str(ncfg.neuron), perc_fails, q_get_complexity_fails(), pattern_results])

  useful_neuron = UsefulNeuron()
  useful_neuron.position = ncfg.position
  useful_neuron.layer = ncfg.layer
  useful_neuron.neuron = ncfg.neuron

  # append() mutates the module-level list in place. The previous
  # `useful_neurons += [...]` made the name local to this function and
  # raised UnboundLocalError the first time a neuron was recorded.
  useful_neurons.append(useful_neuron)

In [None]:
def run_neurons():
  """Ablate every neuron of every useful MLP cell and print one table per cell.

  The test questions are the s0..s4 question sets stacked into one batch.
  For each entry of ucfg.useful_cells that is not an attention head, each of
  the cfg.d_mlp neurons is ablated in turn via n_perform_core().
  """
  question_sets = [
    make_s0_questions(),
    make_s1_questions(),
    make_s2_questions(),
    make_s3_questions(),
    make_s4_questions(),
  ]
  ncfg.questions = torch.vstack(tuple(question_set.cuda() for question_set in question_sets))
  print( "# questions", ncfg.questions.shape[0], "# usefulcells", len(ucfg.useful_cells))

  for useful_cell in ucfg.useful_cells:
    # Attention-head cells are skipped; only MLP cells are examined.
    if useful_cell.is_head():
      continue

    n_reset()
    ncfg.position = useful_cell.position
    ncfg.layer = useful_cell.layer

    # For each useful MLP layer, check the neurons.
    for the_neuron in range(cfg.d_mlp):
      ncfg.neuron = the_neuron
      n_perform_core()

    print(ncfg.output.get_formatted_string(out_format=cfg.table_out_format))

#run_neurons()

# Part 25: MLP Visualisation

In [None]:
import einops
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import clear_output


# Number of questions (rows) in ncfg.questions. get_mlp_data() uses this as
# the size of the leading axis when regrouping activations from sample_cache,
# so it is assumed to be the batch that generated sample_cache — TODO confirm.
num_questions = ncfg.questions.shape[0]


def get_mlp_data(data_set_name):
  """Return one cached MLP activation as a numpy array, for a single token position.

  Reads sample_cache[data_set_name], keeps index -3 of its second axis, and
  regroups the rows so the leading axis has length num_questions.
  """
  # Shapes seen while debugging:
  #   sample_cache entry : 239, 18, 2040 = num_questions, n_ctx, d_mlp
  #   after [:, -3]      : 239, 2040     = num_questions, d_mlp
  #   returned array     : 239, 1, 2040  = num_questions, ??, d_mlp
  activations_at_token = sample_cache[data_set_name][:, -3]

  regrouped = einops.rearrange(activations_at_token, "(x y) d_mlp -> x y d_mlp", x=num_questions)
  return regrouped.cpu().numpy()


# MLP activations before ('pre') and after ('post') the activation function,
# for layers 0 and 1, as returned by get_mlp_data().
# Generated in the order: l0 pre, l0 post, l1 pre, l1 post.
(l0_mlp_hook_pre_sq,
 l0_mlp_hook_post_sq,
 l1_mlp_hook_pre_sq,
 l1_mlp_hook_post_sq) = (
  get_mlp_data(utils.get_act_name(act_kind, mlp_layer))
  for mlp_layer in (0, 1)
  for act_kind in ('pre', 'post')
)


def plot_mlp_neuron_activation(pos: int):
    """Show the layer-1 MLP pre- and post-activation of one neuron as heatmaps.

    pos: index into the last (d_mlp) axis of the arrays built by
        get_mlp_data(), i.e. it selects a neuron. The name is kept because
        the interact() call passes it as the `pos` keyword.

    Draws two images side by side (pre on the left, post on the right), each
    with its own colour bar, clipped to the range 0..1.
    """
    clear_output()

    # Only layer 1 is plotted. The labels used to say "l0" while showing this
    # layer-1 data, so they now name the layer that is actually drawn.
    l1_mlp_pre_data = l1_mlp_hook_pre_sq[:,:,pos]
    l1_mlp_post_data = l1_mlp_hook_post_sq[:,:,pos]

    fig, axs = plt.subplots(1, 2, figsize=(8,4))

    # ax= is passed explicitly so each colour bar sits beside its own image.
    plot = axs[0].imshow(l1_mlp_pre_data, cmap='magma', vmin=0, vmax=1)
    cbar = fig.colorbar(plot, ax=axs[0], fraction=0.1)
    cbar.set_label('l1_mlp_pre_data {}'.format(pos))

    plot = axs[1].imshow(l1_mlp_post_data, cmap='magma', vmin=0, vmax=1)
    cbar = fig.colorbar(plot, ax=axs[1], fraction=0.1)
    cbar.set_label('l1_mlp_post_data {}'.format(pos))


interact(plot_mlp_neuron_activation, pos=widgets.IntText(value=0, description='Index:'))