# Accurate Integer Multiplication in Transformers - Train the Model

This CoLab defines, trains and analyses a Transformer model that performs integer addition, subtraction and multiplication e.g. 33357+82243=115600, 012345-034567=-022222 and 00345*00123=042435. Each digit is a separate token. For 5 digit questions, the model is given 12 "question" (input) tokens, and must then predict the corresponding 6 "answer" (output) tokens.

This CoLab trains the model, storing the results to Google Drive.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 1: Configuration

In [None]:
# Vocabulary layout: token ids 0 to 9 are the digits 0 to 9; the symbol ids follow consecutively.
PLUS_INDEX, MINUS_INDEX, EQUALS_INDEX, MULT_INDEX, DIV_INDEX = range(10, 15)
MAX_INDEX = DIV_INDEX  # Highest token id in use

class Config():
  # Hyper-parameters for the model, the generated data and the optimizer, held as class attributes.
  # Later attributes are derived from earlier ones, so the declaration order matters.
  # The "#@markdown" / "#@param" markers look like Colab form annotations - keep them unchanged.
  #@markdown Model
  n_layers: int = 3 #@param
  n_heads: int = 4 #@param

  d_vocab: int = MAX_INDEX+1
  d_model: int = ( 512 // n_heads ) * n_heads # About 512, and divisible by n_heads
  d_mlp: int = 4 * d_model
  d_head: int = d_model // n_heads  # 128 with n_heads == 4 (the default above); 170 when n_heads == 3
  seed: int = 129000 #@param

  #@markdown Data
  n_digits: int = 5 #@param
  n_ctx: int = 3 * n_digits + 3  # Two n_digits operands + operator + "=" + (n_digits+1) answer tokens
  act_fn: str = 'relu'
  batch_size: int = 64 #@param

  #@markdown Optimizer
  n_training_steps: int = 40000 #@param
  lr: float = 0.00008 #@param
  weight_decay: float = 0.1 #@param

  # Save graphs to CoLab temp files as PDF and HTML. Can manually export files for re-use in papers.
  save_graph_to_file: bool = True

  #@markdown Actions

  # Percent of questions that are multiplication and subtraction. Rest are addition.
  perc_mult: int = 100 #@param
  perc_sub: int = 0 #@param


# Single shared configuration instance read by the rest of the notebook
cfg = Config()

# Part 2: Import libraries
Imports standard libraries. Will ask for access to your Google Drive so that the trained model weights can be written to it.

In [None]:
from google.colab import drive
from pathlib import Path

In [None]:
# Training saves the trained model weights to a file in your Google Drive, so you will need to
# give this CoLab permission to access your Google Drive.
# Loading the model from that file later avoids the ~10 minutes spent on training it.

GLOBAL=True
if not GLOBAL:
    rootdir=Path('./')
else:
    drive.mount('/content/drive', force_remount=False)
    rootdir=Path('/content/drive/MyDrive/AI/CoLabOutput/')

# File name prefix describes the mix of operations the model is trained on
if cfg.perc_mult == 100:
  full_fname = 'mult'
elif cfg.perc_sub == 100:
  full_fname = 'sub'
elif cfg.perc_mult == 0 and cfg.perc_sub == 0:
  full_fname = 'add'
else:
  full_fname = 'mix'

# Remainder of the file name encodes the model shape, seed and training length (in thousands of steps)
epochs = str(cfg.n_training_steps//1000) + "K"
full_fname += f'_D{cfg.n_digits}_L{cfg.n_layers}_H{cfg.n_heads}_dmodel{cfg.d_model}_dhead{cfg.d_head}_ctx{cfg.n_ctx}_seed{cfg.seed}_train{epochs}.pt'
model_save_location = rootdir/full_fname

print(f'model will save to {model_save_location}')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
model will save to /content/drive/MyDrive/AI/CoLabOutput/mult_D5_L3_H4_dmodel512_dhead128_ctx18_seed129000_train40K.pt


In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    %pip install kaleido
    %pip install transformer_lens
    %pip install circuitsvis
    %pip install torchtyping
    %pip install transformers

except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

# Use the "colab" renderer in Colab (or whenever development mode is off); otherwise render for a local notebook
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

Using renderer: colab


In [None]:
# Enlarge the axis-title and figure-title fonts of the 'plotly' template so graphs stay legible when exported
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import tqdm.auto as tqdm
import random
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt
import circuitsvis as cv

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Part 3: Create model
This section defines the token embedding / unembedding and creates the model.

In [None]:
# Embedding / Unembedding

def tokens_to_string(tokens):
    """Render one tokenised question-and-answer as text, e.g. "00123*00784=096432".

    Args:
        tokens: one row of cfg.n_ctx token ids (tensor, array or list). The ids may also
            be given in string form, as show_token_attention_patterns passes them.

    Returns:
        "<x><operator><y>=<answer>". The operator is decoded from the token at position
        cfg.n_digits (previously it was always printed as "+"), and a sign token in the
        answer is printed as "-" instead of its numeric id.
    """
    symbols = {
        PLUS_INDEX: "+",
        MINUS_INDEX: "-",
        EQUALS_INDEX: "=",
        MULT_INDEX: "*",
        DIV_INDEX: "\\",  # string_to_tokens spells division as a backslash
    }

    def render(token):
        # int() accepts both numeric token ids and their string form
        value = int(token)
        return symbols.get(value, str(value))

    tokens = utils.to_numpy(tokens)
    x = "".join(render(t) for t in tokens[:cfg.n_digits])
    y = "".join(render(t) for t in tokens[cfg.n_digits+1:cfg.n_digits*2+1])
    z = "".join(render(t) for t in tokens[cfg.n_ctx-cfg.n_digits-1:])
    equals = "="
    # Fall back to the legacy "+" when the operator slot does not hold a known symbol
    operator = symbols.get(int(tokens[cfg.n_digits]), "+")
    return f"{x}{operator}{y}{equals}{z}"

def string_to_tokens(string, batch: bool=False):
    """Convert a question string such as "00123*00784=" into a tensor of token ids.

    Spaces and newlines are skipped. Division may be written as a backslash (the original
    spelling) or as "/". Any other unknown character raises KeyError.

    Args:
        string: text made of digits and the symbols + - = * \\ /.
        batch: when True, a leading batch dimension of size 1 is added.

    Returns:
        A 1D tensor of token ids, or a [1, n] tensor when batch is True.
    """
    lookup = {str(i): i for i in range(10)}
    lookup.update({
        '+': PLUS_INDEX,
        '-': MINUS_INDEX,
        '=': EQUALS_INDEX,
        '*': MULT_INDEX,
        '\\': DIV_INDEX,
        '/': DIV_INDEX,
    })

    tokens = torch.tensor([lookup[ch] for ch in string if ch not in '\n '])
    return tokens[None, :] if batch else tokens

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    # Architecture
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp,
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    # Vocabulary and context
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    n_ctx = cfg.n_ctx,
    # Initialisation and placement
    init_weights = True,
    seed = cfg.seed,
    device = "cuda",
)

model = HookedTransformer(ht_cfg)

optimizer = optim.AdamW(
    model.parameters(),
    lr = cfg.lr,
    weight_decay = cfg.weight_decay,
    betas = (0.9, 0.98),
)

# Learning-rate schedule: linear warm-up from 1% of cfg.lr over the first fifth of training,
# then cosine annealing over the remaining steps.
max_iter = cfg.n_training_steps
warmup_iter = max_iter // 5
scheduler1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iter)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iter - warmup_iter)
scheduler  = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[warmup_iter])

# Part 4: Data Generator
This section defines the loss function and the training/testing data generator.


In [None]:
# Loss functions

# Compare a batch of prediction "logits" with the answer "tokens".
# Returns (log-probability assigned to each correct answer token, most likely token per answer position).
def logits_to_tokens_loss(logits, tokens):
  n_answer_tokens = cfg.n_digits + 1

  # The logits at position p predict the token at position p+1, so the answer predictions
  # sit one position to the left of the answer tokens themselves.
  answer_log_probs = F.log_softmax(logits[:, -(n_answer_tokens+1):-1].to(torch.float64), dim=-1)

  # Highest-scoring token at each answer position, i.e. the model's answer
  predicted_tokens = answer_log_probs.argmax(dim=-1)

  # The last "n_digits+1" tokens of each row are the correct answer
  target_tokens = tokens[:, -n_answer_tokens:]

  # Pick out the log-probability the model gave to each correct answer token
  per_token_log_probs = answer_log_probs.gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)

  return per_token_log_probs, predicted_tokens

# Per-token loss: the negated mean, over the batch dimension, of the per-token log-probabilities
def loss_fn(ans_loss):
  batch_mean = ans_loss.mean(dim=0)
  return -batch_mean

In [None]:
# Define "iterator" data generator function. Invoked using next().
# "Addition" batch entries are formatted XXXXX+YYYYY=ZZZZZZ e.g. 55003+80002=135005
# "Subtraction" batch entries are formatted XXXXX-YYYYY=ZZZZZZ e.g. 55003-80002=-24999, 80002-55003=024999
# "Multiplication" batch entries are formatted 00XXX*00YYY=ZZZZZZ e.g. 00123*00784=096432
def data_generator( ):
    """Yield an endless stream of question batches, one operation per batch.

    Each yielded tensor is int64 of shape [cfg.batch_size, cfg.n_ctx], moved to the GPU,
    laid out as: n_digits operand tokens, operator, n_digits operand tokens, "=",
    then n_digits+1 answer tokens (zero padded; a negative answer starts with MINUS_INDEX).

    The torch RNG is re-seeded with cfg.seed when iteration starts.
    NOTE(review): the operation mix is drawn from Python's `random`, which that seed
    does not cover, so mixed-operation runs are not reproducible from cfg.seed alone.
    """
    torch.manual_seed(cfg.seed)

    n_digits = cfg.n_digits
    answer_start = 2 + 2*n_digits  # Index of the first answer token (sign or most significant digit)

    # Place value of each operand column, most significant first (10000, 1000, ... 1 for 5 digits)
    place_values = torch.tensor([10 ** k for k in range(n_digits - 1, -1, -1)], dtype=torch.int64)

    # A product must fit in n_digits+1 answer tokens, so multiplication operands are limited
    # to (n_digits+1)//2 significant digits. For 5 digits this gives the 00NNN format.
    n_leading_zeros = n_digits - (n_digits + 1) // 2

    while True:
        #generate a batch of questions (answers calculated below)
        batch = torch.zeros((cfg.batch_size, cfg.n_ctx)).to(torch.int64)
        x = torch.randint(0, 10, (cfg.batch_size, n_digits))
        y = torch.randint(0, 10, (cfg.batch_size, n_digits))

        # Choose one operation for the whole batch according to the configured percentages
        batch_rand = random.randint(1, 100)
        if batch_rand <= cfg.perc_mult:
            batch_op = MULT_INDEX
        elif batch_rand <= cfg.perc_mult + cfg.perc_sub:
            batch_op = MINUS_INDEX
        else:
            batch_op = PLUS_INDEX

        if batch_op == MULT_INDEX:
            # Convert from NNNNN*NNNNN= to 00NNN*00NNN=
            x[:, :n_leading_zeros] = 0
            y[:, :n_leading_zeros] = 0

        batch[:, :n_digits] = x
        batch[:, n_digits] = batch_op
        batch[:, 1+n_digits:1+n_digits*2] = y
        batch[:, 1+n_digits*2] = EQUALS_INDEX

        # Convert each row of digits into the number it represents
        x_values = (x * place_values).sum(dim=1)
        y_values = (y * place_values).sum(dim=1)

        # Elementwise operations giving 1D tensor
        if batch_op == MULT_INDEX:
            answers = x_values * y_values
        elif batch_op == MINUS_INDEX:
            answers = x_values - y_values
        else:
            answers = x_values + y_values

        # Write the answer digits for the whole batch at once, least significant digit last
        magnitudes = answers.abs()
        for j in range(n_digits + 1):
            batch[:, cfg.n_ctx - 1 - j] = magnitudes % 10
            magnitudes = magnitudes // 10

        # Negative answers get a leading minus sign, in their own rows only. A negative
        # difference has at most n_digits digits, so this slot holds a padding zero.
        # (Previously the sign was written into every row of the batch.)
        batch[answers < 0, answer_start] = MINUS_INDEX

        yield batch.cuda()

In [None]:
# Smoke test: draw one batch from a fresh generator and show its first five questions
ds = data_generator()
tokens = next(ds)
print(tokens[:5])

tensor([[ 0,  0,  4,  4,  4, 13,  0,  0,  5,  7,  1, 12,  2,  5,  3,  5,  2,  4],
        [ 0,  0,  7,  8,  3, 13,  0,  0,  8,  8,  7, 12,  6,  9,  4,  5,  2,  1],
        [ 0,  0,  9,  9,  9, 13,  0,  0,  2,  2,  8, 12,  2,  2,  7,  7,  7,  2],
        [ 0,  0,  4,  9,  8, 13,  0,  0,  2,  6,  3, 12,  1,  3,  0,  9,  7,  4],
        [ 0,  0,  8,  3,  4, 13,  0,  0,  2,  8,  0, 12,  2,  3,  3,  5,  2,  0]],
       device='cuda:0')


# Part 5: Train multiplication model with Infinite Data
Train model for n_training_steps, storing train_losses per epoch.

At each training step (of n_training_steps), new training data (a batch of batch_size questions) is generated, the model is trained on it, and the loss is calculated on it. No separate "testing" data is needed, as the training data is freshly generated each step, so memorisation of past training data by the model (if any) is minimally beneficial. For 5 digit addition there are 10 billion possible questions, and model training is on ~2 million questions.

In [None]:
def print_config():
  """Print the main configuration values on a single line as label/value pairs."""
  labelled_values = (
    ("%Mult=", cfg.perc_mult),
    ("%Sub=", cfg.perc_sub),
    ("digits=", cfg.n_digits),
    ("heads=", cfg.n_heads),
    ("layers=", cfg.n_layers),
    ("ctx=", cfg.n_ctx),
    ("seed=", cfg.seed),
    ("training_steps=", cfg.n_training_steps),
  )
  # Flatten to label, value, label, value, ... so print() separates every item with one space
  print(*(item for pair in labelled_values for item in pair))

In [None]:
# Initialise the data generator.
# A fresh generator re-seeds torch with cfg.seed on its first next(), so training starts
# from the beginning of the question stream rather than after any batches drawn earlier.
ds = data_generator()

In [None]:
# Train the model
# train_losses_list collects one scalar loss per step; per_token_train_losses_list collects
# one array per step holding the loss of each answer token. Both are read by the graphing cells.
train_losses_list = []
per_token_train_losses_list = []

for epoch in tqdm.tqdm(range(cfg.n_training_steps)):

  # Every step trains on a newly generated batch of questions
  tokens = next(ds)
  logits = model(tokens)

  # Log-probabilities of the correct answer tokens, then their negated batch mean per answer token
  per_token_train_losses_raw, _ = logits_to_tokens_loss(logits, tokens)
  per_token_train_losses = loss_fn(per_token_train_losses_raw)
  per_token_train_losses_list.append(utils.to_numpy(per_token_train_losses))

  # The scalar training loss is the mean over the answer tokens
  train_loss = per_token_train_losses.mean()
  train_loss.backward()
  train_losses_list.append(train_loss.item())

  # Apply the update, advance the learning-rate schedule, then clear gradients for the next step
  optimizer.step()
  scheduler.step()
  optimizer.zero_grad()

  # Progress report every 100 steps
  if epoch % 100 == 0:
    the_loss = train_loss.item()
    print(epoch, the_loss)

  0%|          | 0/40000 [00:00<?, ?it/s]

0 2.7543175178955934
100 2.3049452856905663
200 2.1979681145449157
300 2.1967886171802795
400 2.138759447254304
500 2.0609740325318753
600 1.9511289317281186
700 1.8498127050717497
800 1.735269290898528
900 1.6297449144281768
1000 1.5749387613188681
1100 1.552785613323952
1200 1.540126557594336
1300 1.4857083509574693
1400 1.4404442771071877
1500 1.4835853444005953
1600 1.4319663233426811
1700 1.4252510732850872
1800 1.407832119758263
1900 1.4085540245295096
2000 1.3605676959473947
2100 1.3913708242387024
2200 1.3451829431017348
2300 1.3294406100821443
2400 1.3716071842401356
2500 1.3597743599780094
2600 1.2566654792113074
2700 1.3522836419213573
2800 1.3529823052027494
2900 1.2778078765775644
3000 1.2881200801630868
3100 1.2944332553691984
3200 1.2646078416896942
3300 1.2717859478430804
3400 1.2408022278616695
3500 1.223415434440483
3600 1.2674888689560606
3700 1.2738858564401467
3800 1.207969567860267
3900 1.30014140115737
4000 1.2543828101359016
4100 1.1741731942198244
4200 1.174094

In [None]:
print("Saving model to Google drive", model_save_location)
# torch.save does not create missing folders, so make sure the output folder exists first
model_save_location.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_save_location)

Saving model to Google drive /content/drive/MyDrive/AI/CoLabOutput/mult_D5_L3_H4_dmodel512_dhead128_ctx18_seed129000_train40K.pt


# Part 6: Final Training Loss

- add_D5_L2_H3_dmodel510_dhead170_ctx18_seed129000_train30K has loss of 1.38e-07
- sub_D5_L2_H3_dmodel510_dhead170_ctx18_seed129000_train30K has loss of 2.501e-06
- mult_D5_L2_H3_dmodel510_dhead170_ctx18_seed129000_train30K has loss of 0.594356712
- mult_D5_L3_H3_dmodel510_dhead170_ctx18_seed129000_train30K has loss of 0.408103767
- mult_D5_L3_H4_dmodel512_dhead128_ctx18_seed129000_train30K has loss of 0.377608407
- mult_D5_L3_H4_dmodel512_dhead128_ctx18_seed129000_train40K has loss of 0.34133034


In [None]:
print_config()

# Report the mean of the last five per-step training losses, rounded to 9 decimal places.
# Slicing (rather than indexing -5..-1) also copes with runs of fewer than five steps.
recent_losses = train_losses_list[-5:]
final_training_loss = round(sum(recent_losses)/len(recent_losses), 9)
print( "Final training loss", final_training_loss)

%Mult= 100 %Sub= 0 digits= 5 heads= 4 layers= 3 ctx= 18 seed= 129000 training_steps= 40000
Final training loss 0.34133034


# Part 7: Line Graphs

This section analyses the training loss by graphing it at a high level.

The loss curve for all digits shows visible inflection points (bumps), but is too high level to help understand the algorithm.

When this graph is decomposed into 'per digit' graphs, the interesting distinct 'per digit' curves appear, showing each digit is being refined semi-independently, with the model algorithm refining each digit separately.

In [None]:
# Number of leading training steps shown in the truncated graphs
epochs_to_graph=1200

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a single-series Plotly Express line graph of `tensor`; extra kwargs go to px.line."""
    axis_labels = {"x":xaxis, "y":yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

# Helper function to plot multiple lines
def lines(raw_lines_list, x=None, mode='lines', labels=None, xaxis='Epoch', yaxis='Loss', title = '', log_y=False, hover=None, all_epochs=True, **kwargs):
    """Plot several curves on one Plotly figure and optionally save it as a PDF.

    Args:
        raw_lines_list: list of 1D series (or a 2D tensor, one series per row).
        x: shared x values; defaults to 0..len(first series)-1.
        mode, hover, **kwargs: passed to each go.Scatter trace.
        labels: one legend label per series; defaults to the series index.
        xaxis, yaxis, title: axis titles and figure title.
        log_y: use a logarithmic y axis.
        all_epochs: NOTE - despite the name, True truncates every series to its first
            epochs_to_graph points (and says so in the title); False plots them in full.

    When cfg.save_graph_to_file is set, the title is printed instead of drawn, the figure
    gets a compact fixed size, and it is written to "<title without punctuation>.pdf".
    """
    truncate = bool(all_epochs)
    lines_list = [row[:epochs_to_graph] for row in raw_lines_list] if truncate else raw_lines_list

    log_suffix = ' (Log)' if log_y else ''
    epoch_suffix = ' (' + str(epochs_to_graph) + ' steps)' if truncate else ''
    full_title = title + log_suffix + epoch_suffix

    if isinstance(lines_list, torch.Tensor):
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x = np.arange(len(lines_list[0]))

    if cfg.save_graph_to_file:
        # Saved graphs carry no title of their own; print it so the output stays labelled
        fig = go.Figure(layout={})
        print(full_title)
    else:
        fig = go.Figure(layout={'title':full_title})

    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis + log_suffix)
    for c, series in enumerate(lines_list):
        if isinstance(series, torch.Tensor):
            series = utils.to_numpy(series)
        label = labels[c] if labels is not None else c
        fig.add_trace(go.Scatter(x=x, y=series, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    if cfg.save_graph_to_file:
        fig.update_layout(margin=dict(l=10, r=10, t=10, b=10),width=1200,height=300)

    # bbox_inches is a matplotlib option with no meaning for a Plotly figure, so it is not passed
    fig.show()

    if cfg.save_graph_to_file:
        # Strip spaces and punctuation from the title to form the file name
        filename = full_title.translate(str.maketrans('', '', ' ()&,%')) + '.pdf'
        pio.write_image(fig, filename)



# Name the operation the model was actually trained on (the title used to always say "addition")
if cfg.perc_mult == 100:
  operation_name = 'multiplication'
elif cfg.perc_sub == 100:
  operation_name = 'subtraction'
elif cfg.perc_mult == 0 and cfg.perc_sub == 0:
  operation_name = 'addition'
else:
  operation_name = 'mixed arithmetic'

title_suffix = ' Loss Curves for ' + str(cfg.n_digits) + ' digit ' + operation_name
# One row per training step, one column per answer token
per_token_losses = np.stack(per_token_train_losses_list, axis=0)

line(train_losses_list,
    title=title_suffix)

# Draw the per-digit graph twice: first truncated to epochs_to_graph steps, then in full
for all_epochs in (True, False):
  lines([per_token_losses[:, i] for i in range(1+cfg.n_digits)]+[train_losses_list],
        labels = [f'digit {i}' for i in range(1+cfg.n_digits)]+['all_digits'],
        title='Per digit'+title_suffix,
        all_epochs=all_epochs)

for i in range(1+cfg.n_digits):
  print('Final Loss for digit ' + str(i) + ' is ', per_token_losses[-1, i])

# Part 8: Questions Set Up

Create sets of sample questions (by task) to ask the model to predict

In [None]:
# Write the decimal digits of n into row `index` of the_question, zero padded to
# `the_digits` columns starting at first_digit_index (most significant digit first).
def insert_question_number(the_question, index, first_digit_index, the_digits, n):

  remaining = n
  # Walk the columns from the least significant (rightmost) to the most significant
  for column in range(first_digit_index + the_digits - 1, first_digit_index - 1, -1):
    the_question[index, column] = remaining % 10
    remaining = remaining // 10


# Fill row `index` of the_question with the addition question "q1+q2=" and its answer
def make_a_question(the_question, index, q1, q2):
  operator_column = cfg.n_digits
  equals_column = 2*cfg.n_digits + 1
  answer_column = equals_column + 1

  # First operand, "+", second operand, "="
  insert_question_number(the_question, index, 0, cfg.n_digits, q1)
  the_question[index, operator_column] = PLUS_INDEX
  insert_question_number(the_question, index, operator_column + 1, cfg.n_digits, q2)
  the_question[index, equals_column] = EQUALS_INDEX

  # The answer takes one more digit than the operands to hold a final carry
  insert_question_number(the_question, index, answer_column, cfg.n_digits+1, q1+q2)


# Create a batch of addition questions from a 2D matrix of ints (one [q1, q2] pair per row).
# A pair with an operand of 10**n_digits or more is skipped, leaving its row all zeros.
def make_questions(q_matrix):
  questions = torch.zeros(len(q_matrix), cfg.n_ctx, dtype=torch.int64)

  limit = 10 ** cfg.n_digits
  for row, pair in enumerate(q_matrix):
    first, second = pair[0], pair[1]
    if first < limit and second < limit:
      make_a_question(questions, row, first, second)

  return questions


def prediction_to_string(max_indices):
  """Join the predicted token ids of the first row of max_indices into one string."""
  first_row = utils.to_numpy(max_indices)[0]
  return "".join(map(str, first_row))

In [None]:
# Analyse an addition question and return its use case: "BA", "MC1", "SimpleUS9" or "CascadeUS9"
def get_question_case(q):
  n = cfg.n_digits
  question = utils.to_numpy(q)[:2*n+2]

  # For each digit pair record whether it makes a carry (sum > 9) or sums to exactly 9.
  # Index 0 is the least significant digit pair.
  makes_carry = [False] * n
  sums_to_nine = [False] * n
  for column in range(n):
    pair_sum = question[column] + question[column + n + 1]
    place = n - 1 - column
    sums_to_nine[place] = pair_sum == 9
    makes_carry[place] = pair_sum > 9

  # No carry anywhere
  if not any(makes_carry):
    return "BA"

  # Carries, but no digit pair sums to 9
  if not any(sums_to_nine):
    return "MC1"

  # A carry followed by two more-significant digit pairs that each sum to 9
  if any(makes_carry[p] and sums_to_nine[p+1] and sums_to_nine[p+2] for p in range(n-2)):
    return "CascadeUS9"

  # A carry followed by one more-significant digit pair that sums to 9
  if any(makes_carry[p] and sums_to_nine[p+1] for p in range(n-1)):
    return "SimpleUS9"

  return "MC1"

In [None]:
# Manually create some questions that strongly test one use case


def make_ba_questions():
    """Build the hand-picked addition questions in which no digit pair produces a carry."""
    operand_pairs = [
        [12345, 33333], [33333, 12345], [45762, 33113], [888, 11111],
        [2362, 23123], [15, 81], [1000, 4440], [4440, 1000],
        [24033, 25133], [23533, 21133], [32500, 1], [31500, 1111],
        [5500, 12323], [4500, 2209],
        [33345, 66643],  # =099988
        [66643, 33345],  # =099988
        [10990, 44000], [60000, 30000], [10000, 20000],
    ]
    return make_questions(operand_pairs)


def make_uc1_questions():
    """Build the hand-picked addition questions that have carries but no digit pair summing to 9."""
    operand_pairs = [
        [15, 45], [25, 55], [35, 59], [40035, 40049],
        [5025, 5059], [15, 65], [44000, 46000], [70000, 40000],
        [15000, 25000], [35000, 35000], [45000, 85000], [67000, 85000],
        [99000, 76000], [1500, 4500], [2500, 5500], [3500, 5900],
        [15020, 45091], [25002, 55019], [35002, 59019],
    ]
    return make_questions(operand_pairs)


def make_simple_us9_questions():
    """Build the hand-picked addition questions where a carry feeds one digit pair that sums to 9."""
    operand_pairs = [
        [55, 45], [45, 55], [45, 59], [35, 69],
        [25, 79], [15, 85], [15, 88], [15508, 14500],
        [14508, 15500], [24533, 25933], [23533, 26933], [32500, 7900],
        [31500, 8500], [550, 450], [450, 550], [10880, 41127],
        [41127, 10880], [12386, 82623],
    ]
    return make_questions(operand_pairs)


def make_cascade_us9_questions(clean = True):
    """Build the hand-picked addition questions where a carry cascades through several sum-to-9 digit pairs.

    The `clean` argument is accepted but not used.
    """
    # These are two level UseSum9 cascades
    two_level = [
        [555, 445], [3340, 6660], [8880, 1120], [1120, 8880],
        [123, 877], [877, 123], [321, 679], [679, 321],
        [1283, 88786],
    ]
    # These are three level UseSum9 cascades
    three_level = [
        [5555, 4445], [55550, 44450], [334, 666], [3340, 6660],
        [33400, 66600], [888, 112], [8880, 1120], [88800, 11200],
        [1234, 8766], [4321, 5679],
    ]
    # These are four level UseSum9 cascades
    four_level = [
        [44445, 55555], [33334, 66666], [88888, 11112], [12345, 87655],
        [54321, 45679], [45545, 54455], [36634, 63366], [81818, 18182],
        [87345, 12655], [55379, 44621],
    ]
    return make_questions(two_level + three_level + four_level)


# These questions focus mainly on 1 digit at a time
# (We're assuming that the 0 + 0 digit additions are trivial bigrams)
def make_answerdigit_questions():
    """Build the hand-picked addition questions that each exercise mainly one digit position."""
    operand_pairs = [
        [1, 0], [4, 3], [5, 5], [8, 1],
        [40, 30], [44, 46],
        [400, 300], [440, 460], [800, 100], [270, 470], [600, 300],
        [4000, 3000], [4400, 4600], [6000, 3000], [7000, 4000],
        [40000, 30000], [44000, 46000], [60000, 30000], [70000, 40000],
        [10000, 20000], [15000, 25000], [35000, 35000], [45000, 85000],
        [67000, 85000], [99000, 76000], [76000, 99000],
    ]
    return make_questions(operand_pairs)


# Returns 2 * cfg.batch_size random questions (128 by default) and ~100 manually-chosen questions
def make_varied_questions():
  """Stack two freshly generated batches around the five hand-picked question sets.

  Returns:
      An int64 tensor on the GPU with one question per row (cfg.n_ctx columns).
  """
  # data_generator yields a single batch tensor per next(), so it is taken as-is
  # (the previous six-way tuple unpacking raised an error).
  q0 = next(ds)
  q1 = make_ba_questions()
  q2 = make_uc1_questions()
  q3 = make_simple_us9_questions()
  q4 = make_cascade_us9_questions()
  q5 = make_answerdigit_questions()
  q6 = next(ds)

  questions = torch.vstack((q0.cuda(), q1.cuda(), q2.cuda(), q3.cuda(), q4.cuda(), q5.cuda(), q6.cuda()))

  return questions

In [None]:
# Check that get_question_case classifies every question in a set as the expected case,
# printing each question that is classified differently.
def unit_test_get_question_case_core(correct_case, questions):
  print( correct_case, "#Questions=", questions.shape[0])
  for question in questions:
    question_case = get_question_case(question)
    if question_case == correct_case:
      continue
    print( "Case mismatch:", correct_case, question_case, question)

def unit_test_get_question_case():
  """Run the classification check over each hand-picked question set."""
  expected_case_and_builder = (
    ("BA", make_ba_questions),
    ("MC1", make_uc1_questions),
    ("SimpleUS9", make_simple_us9_questions),
    ("CascadeUS9", make_cascade_us9_questions),
  )
  for expected_case, build_questions in expected_case_and_builder:
    unit_test_get_question_case_core( expected_case, build_questions())

unit_test_get_question_case()

In [None]:
# Module-level verbosity flag; null_hook and one_million_questions switch it off
verbose = True

In [None]:
# Build a test batch of 2 * cfg.batch_size random and ~100 manually-chosen questions
varied_questions = make_varied_questions();


# Run the sample batch, gather the cache
model.reset_hooks()
model.set_use_attn_result(True)
sample_logits, sample_cache = model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_indices = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean) # Loss < 0.04 is good


# attn.hook_z is the "attention head output" hook point name (at a specified layer)
# Used in h_set_attn_hook_z and t_*_hook functions
# NOTE(review): only layers 0 and 1 are listed here and below, while cfg.n_layers defaults to 3 - confirm this is intended
l_attn_hook_z_name = ['blocks.0.attn.hook_z','blocks.1.attn.hook_z']
sample_attn_z = sample_cache[l_attn_hook_z_name[0]]
print("Sample", l_attn_hook_z_name[0], sample_attn_z.shape) # gives [n_questions, cfg.n_ctx, n_heads, d_head], e.g. [239, 18, 4, 128] with the Part 1 defaults
mean_attn_z = torch.mean(sample_attn_z, dim=0, keepdim=True)
print("Mean", l_attn_hook_z_name[0], mean_attn_z.shape) # gives [1, cfg.n_ctx, n_heads, d_head], e.g. [1, 18, 4, 128]


# hook_resid_pre is the "pre residual memory update" hook point name (at a specified layer)
# Used in o_*_hook functions
l_hook_resid_pre_name = ['blocks.0.hook_resid_pre','blocks.1.hook_resid_pre']


# hook_resid_post is the "post residual memory update" hook point name (at a specified layer)
# Used in c_*_hook and t_*_hook functions
l_hook_resid_post_name = ['blocks.0.hook_resid_post','blocks.1.hook_resid_post']
sample_resid_post = sample_cache[l_hook_resid_post_name[0]]
print("Sample", l_hook_resid_post_name[0], sample_resid_post.shape) # gives [n_questions, cfg.n_ctx, d_model], e.g. [239, 18, 512]
mean_resid_post = torch.mean(sample_resid_post, dim=0, keepdim=True)
print("Mean", l_hook_resid_post_name[0], mean_resid_post.shape) # gives [1, cfg.n_ctx, d_model], e.g. [1, 18, 512]


# mlp.hook_post is the "MLP layer" hook point name (at a specified layer)
# Used in m_*_hook functions
l_mlp_hook_post_name = ['blocks.0.mlp.hook_post','blocks.1.mlp.hook_post']
sample_mlp_hook_post = sample_cache[l_mlp_hook_post_name[0]]
print("Sample", l_mlp_hook_post_name[0], sample_mlp_hook_post.shape) # gives [n_questions, cfg.n_ctx, d_mlp], e.g. [239, 18, 2048] (d_mlp = 4 * d_model)
mean_mlp_hook_post = torch.mean(sample_mlp_hook_post, dim=0, keepdim=True)
print("Mean", l_mlp_hook_post_name[0], mean_mlp_hook_post.shape) # gives [1, cfg.n_ctx, d_mlp], e.g. [1, 18, 2048]

# Part 9: Attention Patterns
Attention patterns show which token(s) the model's attention heads are paying attention to in each token position of the prediction calculation.

For the 3-head, 5 digit addition configuration, the model has 3 attention heads. The attention pattern is 18 by 18 squares (as 54321+77779=132100 is 18 tokens). Time proceeds vertically downwards, with one additional token being revealed horizontally at each token position, giving the overall triangle shape. This visualisation provided insights. After the question is fully revealed (at token position 11), each head starts attending to pairs of question digits from left to right (i.e. high-value digits before lower-value digits), giving the "double staircase" shape. The three heads attend to a given digit pair in three different token positions, giving a time ordering of heads.

In [None]:
def show_token_attention_patterns(index, layer, token_at_index, use_case):
  """Display the attention patterns of one layer for question `index` of the sample batch.

  token_at_index supplies the token labels; the question text is printed once, for layer 0.
  The use_case argument is accepted but not used.
  """
  the_tokens = list(map(str, token_at_index.tolist()))
  if layer == 0:
    print("Attention patterns for", tokens_to_string(the_tokens))

  # Pattern of every head in this layer for the chosen question, taken from the sample cache
  layer_pattern = sample_cache["pattern", layer, "attn"][index]
  widget = cv.attention.attention_patterns(
      tokens=the_tokens,
      attention=layer_pattern,
      #attention_head_names=[f"L{layer}H{i}" for i in range(cfg.n_heads)],
  )
  display(widget)


sample_size = 3

# Show attention patterns for the first few questions of the sample batch.
# The patterns come from sample_cache, which was built from varied_questions, so the token
# labels must come from varied_questions too (not from the last training batch in `tokens`).
for i in range(sample_size):
  for layer in range(cfg.n_layers):
    show_token_attention_patterns(i, layer, varied_questions[i], "Misc")


In [None]:
if cfg.save_graph_to_file:

  # Build a list of token-label lists (one string per token) for the first cfg.n_heads rows of `tokens`.
  # NOTE(review): `tokens` is the most recent generated batch, while sample_cache was built from
  # varied_questions - confirm which questions the export below is meant to label.
  tokens_str = []
  for i in range(cfg.n_heads):
    one_token_str = []
    for j in tokens[i]:
      one_token_str.append(str(utils.to_numpy(j)))
    tokens_str.append(one_token_str)

  # Refer https://github.com/callummcdougall/CircuitsVis/blob/main/python/circuitsvis/circuitsvis_demo.ipynb

  # The export itself is currently disabled; tokens_str is only used by the commented-out code below.
  # html_object = cv.attention.from_cache(
  #    cache = sample_cache,
  #    tokens = tokens_str, # list of list of strings
  #    return_mode = "html",
  #)

  # Create a CoLab file containing the attention pattern(s) in HTML
  #filename = "AttentionPattern" + str(cfg.n_digits) + "Digits" + str(cfg.n_heads) + "Heads.html"
  #with open(filename, "w") as f:
  #    f.write(html_object.data)

  # Manually download the CoLab "html" file and open in your local browser.
  # Install and use the Edge extension "FireShot" to save a portion of the HTML page as a PDF

# Part 10: Run a prediction question

Create a way to get the model to predict the answers to sample questions, and to analyse and show the results.

In [None]:
def predict_experiment_question(questions, the_hook, the_threshold):
  """Run each question through the model (with hooks) and report the poorly-predicted ones.

  Args:
      questions: 2D tensor, one tokenised question (with its answer) per row.
      the_hook: list of (hook point name, hook function) pairs for run_with_hooks.
      the_threshold: a question is printed when its mean answer-token loss exceeds this.

  Returns:
      The mean loss of the LAST question only (0 when `questions` has no rows).
  """
  c_loss_mean = 0

  for question_num in range(questions.shape[0]):
    q = questions[question_num]

    model.reset_hooks()
    model.set_use_attn_result(True)
    exp_logits = model.run_with_hooks(q.cuda(), return_type="logits", fwd_hooks=the_hook)

    q_2d = q.unsqueeze(0)
    exp_losses_raw, exp_max_indices = logits_to_tokens_loss(exp_logits, q_2d.cuda())
    c_loss_mean = utils.to_numpy(loss_fn(exp_losses_raw).mean())

    # Only show the question if the loss exceeds the threshold (because of the ablated token position)
    if c_loss_mean > the_threshold:
      # These three values were previously referenced without ever being defined
      exp_answer_str = prediction_to_string(exp_max_indices)
      true_answer_str = "".join(str(i) for i in utils.to_numpy(q[-(cfg.n_digits+1):]))
      match_str = "Y" if exp_answer_str == true_answer_str else "N"
      # NOTE: get_question_case analyses carries, so the case is only meaningful for addition questions
      the_case = get_question_case(q)
      print(tokens_to_string(q), "ModelAnswer:", exp_answer_str, "Matches:", match_str, "Loss:", c_loss_mean, "Case:", the_case )

  return c_loss_mean

# Part 11 : Is the model 100% accurate?

This is hard to prove. If it does 1M predictions without error, then we assume it is 100%.

This part takes ~90 mins to run for 5-digit 2-layer model.

In [None]:
# Hook function that does not touch the hooked activation: it returns None and its only
# effect is to switch the module-level `verbose` flag off.
# (Presumably a None return leaves the activation unchanged - confirm against the TransformerLens hook semantics.)
def null_hook(value, hook):
  global verbose

  verbose = False


def one_million_questions():
  """Ask the model ~1M generated questions and stop at the first batch containing a failure.

  A question counts as a failure when the mean loss over its answer tokens exceeds
  exp9_threshold. Totals are kept in the module-level exp9_successes / exp9_fails
  and printed at the end.
  """
  global verbose
  global exp9_threshold
  global exp9_successes
  global exp9_fails

  verbose = False
  exp9_threshold = 0.12
  exp9_successes = 0
  exp9_fails = 0

  the_hook = [(l_attn_hook_z_name[0], null_hook)]

  num_batches = 1000000//cfg.batch_size
  for epoch in range(num_batches):
      # data_generator yields a single batch tensor (the previous tuple unpacking raised an error)
      tokens = next(ds)

      # Score the whole batch in one forward pass; no gradients are needed for evaluation
      model.reset_hooks()
      with torch.no_grad():
          logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=the_hook)
      losses_raw, _ = logits_to_tokens_loss(logits, tokens)
      question_losses = -losses_raw.mean(dim=1)

      # Count the failures directly; total_case_fails() and hcfg are not defined in this notebook
      failed = question_losses > exp9_threshold
      exp9_fails = int(failed.sum().item())
      if exp9_fails > 0:
        for q, q_loss in zip(tokens[failed], question_losses[failed]):
          print(tokens_to_string(q), "Loss:", q_loss.item())
        break

      exp9_successes = exp9_successes + tokens.shape[0]

      if epoch % 100 == 0:
          print("Batch", epoch, "of", num_batches, "#Successes=", exp9_successes)

  print("successes", exp9_successes, "num_fails", exp9_fails)


# Commented out as it takes ~90 minutes to run with cfg.n_layers=2, n_digits=5, train=30K
# one_million_questions()

# Part 12: Model Transfer

In [None]:
# Transfer the model small weights to the model large
def transfer_model(small_model, large_model, start_layer, end_layer, transfer_ln=True, transfer_embeds=True):
    """Copy the weights of small_model into a block of layers of large_model, in place.

    Args:
    small_model: The model to transfer weights from
    large_model: The model to transfer weights to
    start_layer: The first layer to transfer weights to
    end_layer: The last layer to transfer weights to (Note that this is end-inclusive!)
    transfer_ln: Also copy the layer-norm weights of the transferred blocks and the final layer norm
    transfer_embeds: Also copy the token embedding, positional embedding and unembedding

    Heads and MLP neurons are copied into the leading slots of the large model, so any extra
    heads / neurons of the large model keep their existing weights.

    Raises:
    AssertionError: if the two models are incompatible or the layer range is invalid.
    """
    cfg_keys = ["d_head", "d_mlp", "d_model", "n_heads", "n_layers"]
    small_cfg = {k: v for k,v in small_model.cfg.__dict__.items() if k in cfg_keys}
    large_cfg = {k: v for k,v in large_model.cfg.__dict__.items() if k in cfg_keys}

    # Sanity checks for large model > small model
    assert small_cfg["d_model"] == large_cfg["d_model"]
    assert small_cfg["d_head"] == large_cfg["d_head"]
    assert small_cfg["n_layers"] <= large_cfg["n_layers"]
    assert small_cfg["n_heads"] <= large_cfg["n_heads"]
    assert small_cfg["d_mlp"] <= large_cfg["d_mlp"]

    # end_layer is inclusive, so it must be a valid block index (strictly below n_layers),
    # and a single-layer transfer (start_layer == end_layer) is legitimate.
    assert 0 <= start_layer <= end_layer < large_cfg["n_layers"] # Make sure start_layer and end_layer are valid
    assert end_layer - start_layer + 1 == small_cfg["n_layers"] # Make sure the number of layers to transfer is correct

    n_heads = small_cfg["n_heads"]
    d_mlp = small_cfg["d_mlp"]
    layer_pairs = list(enumerate(range(start_layer, end_layer+1)))

    # Transfer heads and MLPs
    for small_layer_no, large_layer_no in layer_pairs:
        src = small_model.blocks[small_layer_no]
        dst = large_model.blocks[large_layer_no]

        # Transfer Heads: per-head tensors are indexed by head first. W_O is included so
        # that the copied heads write to the residual stream as they did in the small model.
        for name in ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V"):
            getattr(dst.attn, name).data[:n_heads] = getattr(src.attn, name).clone().data
        # b_O is shared by all heads (size d_model), so it is copied whole
        dst.attn.b_O.data = src.attn.b_O.clone().data

        # Transfer MLPs
        dst.mlp.W_in.data[:, :d_mlp] = src.mlp.W_in.clone().data
        dst.mlp.b_in.data[:d_mlp] = src.mlp.b_in.clone().data
        dst.mlp.W_out.data[:d_mlp,] = src.mlp.W_out.clone().data
        dst.mlp.b_out.data = src.mlp.b_out.clone().data

    if transfer_ln:
        for small_layer_no, large_layer_no in layer_pairs:
            src = small_model.blocks[small_layer_no]
            dst = large_model.blocks[large_layer_no]
            # ln1 feeds attention and ln2 feeds the MLP; both are needed for the block to behave as trained
            dst.ln1.w.data = src.ln1.w.clone().data
            dst.ln1.b.data = src.ln1.b.clone().data
            dst.ln2.w.data = src.ln2.w.clone().data
            dst.ln2.b.data = src.ln2.b.clone().data

        large_model.ln_final.w.data = small_model.ln_final.w.clone().data
        large_model.ln_final.b.data = small_model.ln_final.b.clone().data

    if transfer_embeds:
        large_model.embed.W_E.data = small_model.embed.W_E.clone().data
        large_model.pos_embed.W_pos.data = small_model.pos_embed.W_pos.clone().data
        large_model.unembed.W_U.data = small_model.unembed.W_U.clone().data
        large_model.unembed.b_U.data = small_model.unembed.b_U.clone().data

#transfer_model(small_model=model_small, large_model=model_large, start_layer=1, end_layer=2)