# Accurate Integer Mathematics in Transformers - Train the Model

This Colab defines and trains a Transformer model that performs integer addition, subtraction and multiplication, e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+0283935. Each digit is a separate token. For 6-digit questions, the model is given 14 "question" (input) tokens and must then predict the corresponding 8 "answer" (output) tokens: a sign followed by 7 digits.
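The token layout can be sketched in plain Python. This is an illustrative helper only: `format_question` is a hypothetical name that does not appear in the notebook, and it assumes the answer is always written as a sign followed by `n_digits + 1` zero-padded digits.

```python
# Hypothetical helper (not part of the notebook): lay out one question and its
# answer as a string with one character per token.
def format_question(x: int, y: int, op: str, n_digits: int = 6) -> str:
    if op == "+":
        answer = x + y
    elif op == "-":
        answer = x - y
    elif op == "*":
        answer = x * y
    else:
        raise ValueError(f"Unsupported operator: {op}")
    sign = "-" if answer < 0 else "+"
    # Operands are zero-padded to n_digits, the answer magnitude to n_digits + 1.
    return f"{x:0{n_digits}d}{op}{y:0{n_digits}d}={sign}{abs(answer):0{n_digits + 1}d}"


example = format_question(133357, 182243, "+")
n_question_tokens = example.index("=") + 1          # everything up to and including "="
n_answer_tokens = len(example) - n_question_tokens  # sign plus answer digits
print(example, n_question_tokens, n_answer_tokens)
```

Counting characters this way is a quick check that the question and answer lengths add up to the model's context length of `3 * n_digits + 4`.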


This Colab trains the model and saves the results to temporary Colab files. Useful models are then manually copied to HuggingFace.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 0: Import libraries
Imports standard libraries. Do not read.

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install kaleido
    !pip install transformer_lens
    !pip install circuitsvis
    !pip install torchtyping
    !pip install transformers

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import json
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import tqdm.auto as tqdm
import random
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt
import circuitsvis as cv

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Part 1A: Configuration

In [None]:
# Tokens used in vocab. (Token indexes 0 to 9 represent digits 0 to 9)
PLUS_INDEX = 10
MINUS_INDEX = 11
EQUALS_INDEX = 12
MULT_INDEX = 13
DIV_INDEX = 14
MAX_INDEX = DIV_INDEX

In [None]:
# Main configuration class for main model creation and training
class Config():
  #@markdown Main Model
  n_layers: int = 3 #@param
  n_heads: int = 4 #@param

  d_vocab: int = MAX_INDEX+1
  d_model: int = 510
  d_mlp: int = 4 * d_model
  d_head: int = 170
  training_seed: int = 372001

  #@markdown Data
  n_digits: int = 6 #@param
  n_ctx: int = 3 * n_digits + 4
  act_fn: str = 'relu'
  batch_size: int = 64

  #@markdown Optimizer
  n_training_steps: int = 40000 #@param
  lr: float = 0.00008
  weight_decay: float = 0.1

  #@markdown Maths Operations
  # Percent of questions that are multiplication, subtraction and addition.
  perc_mult: int = 0
  perc_sub: int = 80 #@param
  def perc_add(self):
    return max(0, 100 - self.perc_mult - self.perc_sub)

  #@markdown Insert Model
  insert_mode: int = 0 # 0=None 1=Init, 2=FreezeHeads 3=FreezeAll
  insert_late: bool = False

  insert_n_layers: int = 2 #@param
  insert_n_heads: int = 3 #@param
  insert_training_seed: int = 372001
  insert_n_training_steps: int = 15000 #@param

  # Save graphs to CoLab temp files as PDF and HTML. Can manually export files for re-use in papers.
  save_graph_to_file: bool = True

  def to_dict(self):
    return {
      "n_layers": self.n_layers,
      "n_heads": self.n_heads,
      "d_vocab": self.d_vocab,
      "d_model": self.d_model,
      "d_mlp": self.d_mlp,
      "d_head": self.d_head,
      "training_seed": self.training_seed,
      "n_digits": self.n_digits,
      "n_ctx": self.n_ctx,
      "act_fn": self.act_fn,
      "batch_size": self.batch_size,
      "n_training_steps": self.n_training_steps,
      "lr": self.lr,
      "weight_decay": self.weight_decay,
      "perc_mult": self.perc_mult,
      "perc_sub": self.perc_sub,
      "insert_late": self.insert_late,
      "insert_mode": self.insert_mode,
      "insert_n_layers": self.insert_n_layers,
      "insert_n_heads": self.insert_n_heads,
      "insert_training_seed": self.insert_training_seed,
      "insert_n_training_steps": self.insert_n_training_steps,
    }


cfg = Config()

# Part 1B: Configuration


This CoLab can be configured to:
- Train a model using the traditional approach. For example, add_d6_l2_h3_train15K_seed372001.pth is trained from scratch using 100% addition questions to give an "Addition" model with a very low loss (9e-9).
- Train a "mixed" model to do two tasks by inserting a "known good" model into the untrained composite model. For example, initialising ins1_mix_d6_l3_h4_train20K_seed372001.pth with add_d6_l2_h3_dm510_dh170_train15K_seed372001, and then training it on 80% subtraction questions and 20% addition questions.

If we are inserting existing model weights, they are loaded from HuggingFace.
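The model file names encode the training setup. As a rough sketch, a name can be unpacked back into its settings with a regular expression. The pattern below is an assumption inferred from the example names in this notebook, and `parse_model_name` is a hypothetical helper. It does not handle variants with extra fields such as `dm510_dh170`.

```python
import re

# Assumed naming pattern, inferred from names like "ins1_mix_d6_l3_h4_train40K_seed372001".
MODEL_NAME_PATTERN = re.compile(
    r"^(?:ins(?P<insert_mode>\d+)_)?"                    # optional insert mode prefix
    r"(?P<task>add|sub|mul|mix)"                         # task mix
    r"_d(?P<n_digits>\d+)_l(?P<n_layers>\d+)_h(?P<n_heads>\d+)"
    r"_train(?P<k_steps>\d+)K"                           # training steps in thousands
    r"(?:_seed(?P<seed>\d+))?$"                          # optional training seed
)


def parse_model_name(name: str) -> dict:
    """Hypothetical helper: split a model name into its configuration fields."""
    match = MODEL_NAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unrecognised model name: {name}")
    fields = match.groupdict()
    return {
        "insert_mode": int(fields["insert_mode"] or 0),
        "task": fields["task"],
        "n_digits": int(fields["n_digits"]),
        "n_layers": int(fields["n_layers"]),
        "n_heads": int(fields["n_heads"]),
        "n_training_steps": int(fields["k_steps"]) * 1000,
        "seed": int(fields["seed"]) if fields["seed"] else None,
    }


parsed = parse_model_name("ins1_mix_d6_l3_h4_train20K_seed372001")
print(parsed)
```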

In [None]:
# Which model do we want to train?
#model_name = "add_d5_l2_h3_train15K"  # 5 digit addition model
#model_name = "add_d6_l2_h3_train15K"  # 6 digit addition model
#model_name = "sub_d6_l2_h3_train30K"  # 6 digit subtraction model
#model_name = "mix_d6_l3_h4_train40K"  # 6 digit addition and subtraction model
#model_name = "ins1_mix_d6_l3_h4_train40K"  # 6 digit addition / subtraction model. Initialise with addition model.
#model_name = "ins2_mix_d6_l4_h4_train40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads every 100 epochs
#model_name = "ins3_mix_d6_l4_h3_train40K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads & MLPs every 100 epochs
model_name = "ins3_mix_d6_l4_h3_train60K"  # 6 digit addition / subtraction model. Initialised with addition model. Reset useful heads & MLPs every 100 epochs

In [None]:
if model_name == "add_d5_l2_h3_train15K" :
  cfg.n_digits = 5
  cfg.n_layers = 2
  cfg.n_heads = 3
  cfg.n_training_steps = 15000
  cfg.perc_sub = 0
  cfg.insert_mode = 0

if model_name == "add_d6_l2_h3_train15K" :
  cfg.n_digits = 6
  cfg.n_layers = 2
  cfg.n_heads = 3
  cfg.n_training_steps = 15000
  cfg.perc_sub = 0
  cfg.insert_mode = 0

if model_name == "sub_d6_l2_h3_train30K" :
  cfg.n_digits = 6
  cfg.n_layers = 2
  cfg.n_heads = 3
  cfg.n_training_steps = 30000
  cfg.perc_sub = 100
  cfg.insert_mode = 0

if model_name == "mix_d6_l3_h4_train40K" :
  cfg.n_digits = 6
  cfg.n_layers = 3
  cfg.n_heads = 4
  cfg.n_training_steps = 40000
  cfg.perc_sub = 66 # Train on 66% subtraction and 34% addition question batches
  cfg.insert_mode = 0

if model_name == "ins1_mix_d6_l3_h4_train40K" :
  cfg.n_digits = 6
  cfg.n_layers = 3
  cfg.n_heads = 4
  cfg.n_training_steps = 40000
  cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches
  cfg.insert_mode = 1 # Initialise with add_d6_l2_h3_train15K.pth.

if model_name == "ins2_mix_d6_l4_h4_train40K" :
  cfg.n_digits = 6
  cfg.n_layers = 4
  cfg.n_heads = 4
  cfg.n_training_steps = 40000
  cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches
  cfg.insert_mode = 2 # Initialise with add_d6_l2_h3_train15K.pth. Train & reset useful heads every 100 epochs

if model_name == "ins3_mix_d6_l4_h3_train40K" :
  cfg.n_digits = 6
  cfg.n_layers = 4
  cfg.n_heads = 3
  cfg.n_training_steps = 40000
  cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches
  cfg.insert_mode = 3 # Initialise with add_d6_l2_h3_train15K.pth. Trained & reset useful heads & MLPs every 100 epochs

if model_name == "ins3_mix_d6_l4_h3_train60K" :
  cfg.n_digits = 6
  cfg.n_layers = 4
  cfg.n_heads = 3
  cfg.n_training_steps = 60000
  cfg.perc_sub = 80 # Train on 80% subtraction and 20% addition question batches
  cfg.insert_mode = 3 # Initialise with add_d6_l2_h3_train15K.pth. Trained & reset useful heads & MLPs every 100 epochs

cfg.n_ctx = 3 * cfg.n_digits + 4

In [None]:
def file_name_suffix(digits, layers, heads, training_steps, seed):
  train_str = str(training_steps//1000) + "K"
  return '_d{}_l{}_h{}_train{}_seed{}'.format(digits, layers, heads, train_str, seed)

main_fname = '' if cfg.insert_mode == 0 else 'ins{}_'.format(cfg.insert_mode)
main_fname += 'mul' if cfg.perc_mult == 100 else 'sub' if cfg.perc_sub == 100 else 'add' if cfg.perc_add() == 100 else 'mix'
main_fname += file_name_suffix(cfg.n_digits, cfg.n_layers, cfg.n_heads, cfg.n_training_steps, cfg.training_seed)
main_fname_pth = main_fname + '.pth'
main_fname_json = main_fname + '.json'


def print_config():
  print("%Mult=", cfg.perc_mult, "%Sub=", cfg.perc_sub, "%Add=", cfg.perc_add(), "File=", main_fname)

print_config()
print('Main model will save to Colab temporary file', main_fname_pth)
print('Main model config etc will save to Colab temporary file', main_fname_json)

# Part 3: Create main_model
This section defines the token embedding / unembedding and creates the model.

In [None]:
# Embedding / Unembedding

def token_to_char(i):
  if i < 10:
    return str(i)
  if i == PLUS_INDEX:
    return "+"
  if i == MINUS_INDEX:
    return "-"
  if i == EQUALS_INDEX:
    return "="
  if i == MULT_INDEX:
    return "*"
  if i == DIV_INDEX:
    return "\\"
  return "?"

def tokens_to_string(tokens):
    tokens = utils.to_numpy(tokens)
    return "".join([token_to_char(i) for i in tokens[:cfg.n_ctx]])

def string_to_tokens(string, batch: bool=False):
    lookup = {str(i):i for i in range(10)}
    lookup['+']=PLUS_INDEX
    lookup['-']=MINUS_INDEX
    lookup['=']=EQUALS_INDEX
    lookup['*']=MULT_INDEX
    lookup['\\']=DIV_INDEX

    tokens = [lookup[i] for i in string if i not in '\n ']
    if batch:
        return torch.tensor(tokens)[None, :]
    else:
        return torch.tensor(tokens)

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp,
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    n_ctx = cfg.n_ctx,
    init_weights = True,
    device = "cuda",
    seed = cfg.training_seed,
)

main_model = HookedTransformer(ht_cfg)

optimizer = optim.AdamW(main_model.parameters(),
                        lr = cfg.lr,
                        weight_decay = cfg.weight_decay,
                        betas = (0.9, 0.98))

max_iter = cfg.n_training_steps
warmup_iter = max_iter // 5
scheduler1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=int(warmup_iter))
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(np.ceil((max_iter-warmup_iter))))
scheduler  = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[int(warmup_iter)])

# Part 4: Loss Function & Data Generator
This section defines the loss function and the training/testing data generator.
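The loss only scores the answer tokens. The following is a minimal pure-Python sketch of that idea for a single sequence, assuming the usual next-token convention that the logits at position p predict the token at position p + 1. The names `log_softmax` and `answer_loss` are illustrative, and the notebook's own version operates on batched PyTorch tensors instead.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def answer_loss(logits_per_position, tokens, n_answer_tokens):
    """Negative log-probability of each of the last n_answer_tokens tokens.

    logits_per_position[p] is assumed to predict tokens[p + 1].
    """
    n = len(tokens)
    losses = []
    for target_pos in range(n - n_answer_tokens, n):
        log_probs = log_softmax(logits_per_position[target_pos - 1])
        losses.append(-log_probs[tokens[target_pos]])
    return losses


# Toy example: a 3-token vocabulary, a 5-token sequence, and 2 answer tokens.
toy_tokens = [0, 1, 2, 1, 0]
uniform_logits = [[0.0, 0.0, 0.0] for _ in toy_tokens]
demo_losses = answer_loss(uniform_logits, toy_tokens, n_answer_tokens=2)
print(demo_losses)  # a model with no preference pays log(vocab size) per token
```

Keeping one loss value per answer position, rather than a single scalar, is what later allows the loss to be plotted per digit.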


In [None]:
# Loss functions

# Calculate the per-token log-probability of the correct answer by comparing a batch of prediction "logits" to answer "tokens"
def logits_to_tokens_loss(logits, tokens):
  # The answer has a +/- sign plus n_digits+1 digits (an addition answer can have one more digit than the question)
  n_answer_digits = cfg.n_digits+2

  # The logits that predict the answer tokens. The logits at position p predict the token at position p+1
  ans_logits = logits[:, -(n_answer_digits+1):-1]

  # Convert the raw scores (logits) into log-probabilities over the vocabulary
  ans_probs = F.log_softmax(ans_logits.to(torch.float64), dim=-1)

  # The most likely token at each answer position
  max_prob_tokens = torch.argmax(ans_probs, dim=-1)

  # The correct answer tokens (sign and digits)
  ans_tokens = tokens[:, -(n_answer_digits):]

  # Extract the log-probability of each correct answer token from ans_probs
  ans_loss = torch.gather(ans_probs, -1, ans_tokens[:, :, None])[..., 0]

  return ans_loss, max_prob_tokens


# Calculate the loss per answer position as the negative log-probability, averaged over the batch
def loss_fn(ans_loss):
  return -ans_loss.mean(0)

In [None]:
# Generate an enriched data batch for one operator type
# "Addition" batch entries are formatted XXXXXX+YYYYYY=+ZZZZZZZ e.g. 550030+800020=+1350050
# "Subtraction" batch entries are formatted XXXXXX-YYYYYY=-ZZZZZZZ e.g. 550030-800020=-0249990, 800020-550030=+0249990
# "Multiplication" batch entries are formatted 000XXX*000YYY=+ZZZZZZZ e.g. 000345*000678=+0233910
def data_generator_core( batch_op ):

  batch = torch.zeros((cfg.batch_size, cfg.n_ctx)).to(torch.int64)
  x = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))
  y = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))

  if batch_op == MULT_INDEX:
    # Convert from NNNNNN*NNNNNN= to 000NNN*000NNN= so answer (product) is NNNNNN
    num_zeros = cfg.n_digits // 2
    for z in range(num_zeros):
      x[:, z] = 0
      y[:, z] = 0

  # Enrich the question data on 60% of batches to speed up training
  random_case = random.randint(1, 5)
  if ( batch_op == PLUS_INDEX or batch_op == MINUS_INDEX ) and (random_case <= 3):
    # Flatten x and y to 1D tensors
    x_flat = x.view(-1)
    y_flat = y.view(-1)

    num_elements_to_modify = int(0.40 * x.numel()) # 40%
    indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
    if batch_op == PLUS_INDEX :
      # The S2, S3, etc tasks are compound and increasingly rare, and so harder to learn.
      # Increase the MakeSum9 case frequency to increase the frequency of S2, S3, etc
      if random.randint(1, 2) == 1:
        x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
      else:
        y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]
    else:
      # Generate more questions with negative answers (NG task) by increasing the y value
      if random_case == 1:
        y_flat[y_flat < 9] += 1

      # The M2, M3, etc tasks are compound and increasingly rare, and so harder to learn.
      # Increase the DiffZero case frequency to increase the frequency of M2, M3, etc
      if random.randint(1, 2) == 1:
        x_flat[indices_to_modify] = y_flat[indices_to_modify]
      else:
        y_flat[indices_to_modify] = x_flat[indices_to_modify]


    # Reshape x and y back to its original shape
    x = x_flat.view(x.shape)
    y = y_flat.view(x.shape)


  batch[:, :cfg.n_digits] = x
  batch[:, cfg.n_digits] = batch_op
  batch[:, 1+cfg.n_digits:1+cfg.n_digits*2] = y
  batch[:, 1+cfg.n_digits*2] = EQUALS_INDEX

  # Convert each row of digits into a single n_digits-digit number
  x_values = x[:, 0]
  y_values = y[:, 0]
  for dn in range(1,cfg.n_digits):
    x_values = x_values * 10 + x[:, dn]
    y_values = y_values * 10 + y[:, dn]

  # Elementwise operations to give the 1D tensor answers
  if batch_op == MULT_INDEX:
    answers = x_values * y_values
  else:
    if batch_op == MINUS_INDEX:
      answers = x_values - y_values
    else:
      answers = x_values + y_values

  # Insert the answers into the batch
  for i in range(cfg.batch_size):
    answer = answers[i]

    sign = PLUS_INDEX
    if answer < 0:
      sign = MINUS_INDEX
      answer = - answer

    batch[i, 2+cfg.n_digits*2] = sign
    for j in range(cfg.n_digits+1):
      batch[i, cfg.n_ctx-j-1] = answer % 10
      answer = answer // 10
      if answer == 0:
          break

  return batch

In [None]:
# Define "iterator" data generator function. Invoked using next().
def data_generator( ):
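  torch.manual_seed(cfg.training_seed)
  random.seed(cfg.training_seed) # Python's random module picks the batch operator and enrichment case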
  while True:

    batch_rand = random.randint(1, 100)
    batch_op = MULT_INDEX if batch_rand <= cfg.perc_mult else MINUS_INDEX if batch_rand <= cfg.perc_mult + cfg.perc_sub else PLUS_INDEX

    batch = data_generator_core( batch_op )

    yield batch.cuda()

In [None]:
# Initialise the data generator
ds = data_generator()

In [None]:
# Test data generator
tokens = next(ds)
print(tokens[:3,:])

# Part 5: Read insert_model from HuggingFace (optional)

If we are initialising the untrained model with an existing model,
then we load the existing model from HuggingFace.
We load both the model weights and a json file stating which nodes in the model are actually doing useful calculations.
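As a rough illustration of the node file, the sketch below parses a small made-up JSON list with the fields the notebook reads (`position`, `layer`, `head`, `tags`). The sample values and tag strings are invented for illustration, and `UsefulNode` and `load_useful_nodes` are hypothetical names. It follows the notebook's convention that a `head` value equal to the number of heads marks an MLP layer rather than an attention head.

```python
import json
from dataclasses import dataclass


@dataclass
class UsefulNode:
    position: int  # token position, 0 to n_ctx - 1
    layer: int
    head: int      # head index, or n_heads to denote the layer's MLP
    tags: list


# Made-up sample data: real files and tag strings may look different.
SAMPLE_JSON = """
[
  {"position": 14, "layer": 0, "head": 1, "tags": ["example_tag_a"]},
  {"position": 15, "layer": 0, "head": 3, "tags": ["example_tag_b"]},
  {"position": 15, "layer": 1, "head": 0, "tags": []}
]
"""


def load_useful_nodes(text: str) -> list:
    """Turn the JSON text into a list of UsefulNode records."""
    return [UsefulNode(**item) for item in json.loads(text)]


N_HEADS = 3  # assumed head count of the inserted model
nodes = load_useful_nodes(SAMPLE_JSON)
head_nodes = [n for n in nodes if n.head != N_HEADS]
mlp_nodes = [n for n in nodes if n.head == N_HEADS]
print(len(head_nodes), "attention head nodes,", len(mlp_nodes), "MLP nodes")
```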

In [None]:
insert_base_name = ""
insert_weights_fname = ""
insert_nodes_fname = ""

In [None]:
# Read insert_model weights from HuggingFace
if cfg.insert_mode >= 1:
  insert_base_name = "add" + file_name_suffix(cfg.n_digits, cfg.insert_n_layers, cfg.insert_n_heads, cfg.insert_n_training_steps, cfg.insert_training_seed)

  insert_weights_fname = insert_base_name + ".pth"

  ht_cfg = HookedTransformerConfig(
      n_layers = cfg.insert_n_layers,
      n_heads = cfg.insert_n_heads,
      d_model = cfg.d_model, # Assume constant
      d_head = cfg.d_head, # Assume constant
      d_mlp = cfg.d_mlp, # Assume constant
      act_fn = cfg.act_fn, # Assume constant
      normalization_type = 'LN',
      d_vocab = cfg.d_vocab, # Assume constant
      d_vocab_out = cfg.d_vocab, # Assume constant
      n_ctx = cfg.n_ctx, # Assume constant
      init_weights = True, # Assume constant
      device = "cuda",
      seed = cfg.insert_training_seed,
  )

  insert_model = HookedTransformer(ht_cfg)

  print('Loading insertion model from', insert_weights_fname)
  insert_model.load_state_dict(utils.download_file_from_hf(repo_name="PhilipQuirke/Accurate6DigitSubtraction", file_name=insert_weights_fname, force_is_torch=True))
  insert_model.eval()

  print("Loaded insert model", insert_weights_fname)

In [None]:
class UsefulCell():
  # Position.Layer.Head of the cell
  position: int  # token-position. Zero to cfg.n_ctx - 1
  layer: int
  head: int

  # Tags related to the cell of form "MajorVersion.MinorVersion"
  tags: list


  # Is this cell an attention head? If not, it must be an MLP layer
  def is_head(self):
    return self.head != cfg.insert_n_heads


  def __init__(self, position, layer, head, tags):
    self.position = position
    self.layer = layer
    self.head = head
    self.tags = tags


  @classmethod
  def from_dict(cls, data):
      return cls(data['position'], data['layer'], data['head'], data['tags'])

In [None]:
useful_cells = [] # Stays empty if no model is inserted or the download fails
if cfg.insert_mode >= 1:
  # Read insert_model useful node information from HuggingFace
  huggingface_directory_url = 'https://huggingface.co/PhilipQuirke/Accurate6DigitSubtraction/raw/main/'
  insert_nodes_fname = huggingface_directory_url + insert_base_name + '_tags.json'
  print(insert_nodes_fname)

  # Download the file
  response = requests.get(insert_nodes_fname)

  # Ensure the request was successful
  if response.status_code == 200:
      # Load the JSON data
      json_data = json.loads(response.content.decode('utf-8'))

      useful_cells = [UsefulCell.from_dict(item) for item in json_data]

      print( "Loaded:", len(useful_cells), "Sample:", useful_cells[0].tags)
  else:
      print( "Failed to download the file:", response.status_code)

# Part 6B: Transfer all of insert_model into main_model (optional)



In [None]:
# Transfer all attention head weights from the small to the main model, updating the first (left-most) small.n_heads heads of main_model
def transfer_all_heads(small_model, small_cfg, start_layer, end_layer, large_model):
  for small_layer_no, large_layer_no in enumerate(range(start_layer, end_layer+1)):
    large_model.blocks[large_layer_no].attn.W_Q.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.W_Q.clone().data
    large_model.blocks[large_layer_no].attn.W_K.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.W_K.clone().data
    large_model.blocks[large_layer_no].attn.W_V.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.W_V.clone().data

    large_model.blocks[large_layer_no].attn.b_Q.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.b_Q.clone().data
    large_model.blocks[large_layer_no].attn.b_K.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.b_K.clone().data
    large_model.blocks[large_layer_no].attn.b_V.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.b_V.clone().data

In [None]:
# Transfer all MLP layer weights from the small to the main model, updating the first (left-most) small.d_mlp neurons of main_model
def transfer_all_mlps(small_model, small_cfg, start_layer, end_layer, large_model):
  for small_layer_no, large_layer_no in enumerate(range(start_layer, end_layer+1)):
    large_model.blocks[large_layer_no].mlp.W_in.data[:, :small_cfg["d_mlp"]] = small_model.blocks[small_layer_no].mlp.W_in.clone().data
    large_model.blocks[large_layer_no].mlp.b_in.data[:small_cfg["d_mlp"]] = small_model.blocks[small_layer_no].mlp.b_in.clone().data
    large_model.blocks[large_layer_no].mlp.W_out.data[:small_cfg["d_mlp"], :] = small_model.blocks[small_layer_no].mlp.W_out.clone().data
    large_model.blocks[large_layer_no].mlp.b_out.data = small_model.blocks[small_layer_no].mlp.b_out.clone().data

In [None]:
def transfer_all_ln(small_model, start_layer, end_layer, large_model):
  for small_layer_no, large_layer_no in enumerate(range(start_layer, end_layer+1)):
    large_model.blocks[large_layer_no].ln1.w.data = small_model.blocks[small_layer_no].ln1.w.clone().data
    large_model.blocks[large_layer_no].ln1.b.data = small_model.blocks[small_layer_no].ln1.b.clone().data

  large_model.ln_final.w.data = small_model.ln_final.w.clone().data

In [None]:
def transfer_all_embeds(small_model, large_model):
  large_model.embed.W_E.data = small_model.embed.W_E.clone().data
  large_model.pos_embed.W_pos.data = small_model.pos_embed.W_pos.clone().data
  large_model.unembed.W_U.data = small_model.unembed.W_U.clone().data

In [None]:
small_cfg = {}
large_cfg = {}

# Insert the small model weights into the large model
def transfer_full_model(small_model, large_model, start_layer, end_layer, transfer_ln=True, transfer_embeds=True):
  """Args:
  small_model: The model to transfer weights from
  large_model: The model to transfer weights to
  start_layer: The first layer to transfer weights to
  end_layer: The last layer to transfer weights to (Note that this is end-inclusive!)
  """
  global small_cfg
  global large_cfg

  small_cfg = {k: v for k,v in small_model.cfg.__dict__.items() if k in ["d_head", "d_mlp", "d_model", "n_heads", "n_layers"]}
  large_cfg = {k: v for k,v in large_model.cfg.__dict__.items() if k in ["d_head", "d_mlp", "d_model", "n_heads", "n_layers"]}

  # Sanity checks for large model > small model
  assert small_cfg["d_model"] == large_cfg["d_model"]
  assert small_cfg["d_head"] == large_cfg["d_head"]
  assert small_cfg["n_layers"] <= large_cfg["n_layers"]
  assert small_cfg["n_heads"] <= large_cfg["n_heads"]
  assert small_cfg["d_mlp"] <= large_cfg["d_mlp"]

  assert 0 <= start_layer <= end_layer < large_cfg["n_layers"] # Make sure start_layer and end_layer are valid (end_layer is an inclusive index)
  assert end_layer - start_layer + 1 == small_cfg["n_layers"] # Make sure the number of layers to transfer is correct

  transfer_all_heads(small_model, small_cfg, start_layer, end_layer, large_model)
  transfer_all_mlps(small_model, small_cfg, start_layer, end_layer, large_model)
  if transfer_ln:
    transfer_all_ln(small_model, start_layer, end_layer, large_model)
  if transfer_embeds:
    transfer_all_embeds(small_model, large_model)

In [None]:
def insert_existing_model( first_time ):
  if cfg.insert_mode >= 1 :
    # Is the destination the first few or last few layers of the main_model?
    start_layer = cfg.n_layers - cfg.insert_n_layers if cfg.insert_late else 0
    end_layer = start_layer + cfg.insert_n_layers-1

    if first_time:
      print( "Inserting trained small_model", insert_weights_fname)
      print( "into larger, untrained main_model", main_fname)
      print( "destination layers:", start_layer, end_layer)

    transfer_full_model(insert_model, main_model, start_layer, end_layer, first_time, first_time)


insert_existing_model( True )

# Part 6C: Transfer useful heads of insert_model into main_model (optional)

Transfer just the useful attention heads from insert_model into main_model.
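A given head can appear several times in the useful-node list, once per token position it is used at, so copying per entry repeats work. One possible way to avoid that is to collapse the list to unique (layer, head) pairs first. The sketch below uses plain tuples as a hypothetical stand-in for the notebook's node records, and `unique_heads` is an illustrative name.

```python
# Hypothetical stand-in for the useful-node records: (position, layer, head) tuples.
# A head value equal to n_heads denotes an MLP layer and is skipped here.
def unique_heads(cells, n_heads):
    """Return the sorted unique (layer, head) pairs that are attention heads."""
    return sorted({(layer, head) for _, layer, head in cells if head != n_heads})


cells = [
    (14, 0, 1),  # head 0.1 used at position 14
    (15, 0, 1),  # the same head used again at position 15
    (15, 0, 3),  # MLP entry for layer 0 (head == n_heads)
    (16, 1, 2),
    (17, 1, 2),
]

heads_to_copy = unique_heads(cells, n_heads=3)
print(heads_to_copy)  # each head would then be copied once
```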

In [None]:
# Transfer one attention head's weights from the small to the main model.
# The head keeps its index, so the first small.n_heads heads of main_model are updated,
# matching the slots written by transfer_all_heads.
def transfer_one_head(small_model, small_layer_no, small_head_no, large_model, start_layer):
  large_layer_no = start_layer + small_layer_no
  large_head_no = small_head_no

  large_model.blocks[large_layer_no].attn.W_Q.data[large_head_no] = small_model.blocks[small_layer_no].attn.W_Q.clone().data[small_head_no]
  large_model.blocks[large_layer_no].attn.W_K.data[large_head_no] = small_model.blocks[small_layer_no].attn.W_K.clone().data[small_head_no]
  large_model.blocks[large_layer_no].attn.W_V.data[large_head_no] = small_model.blocks[small_layer_no].attn.W_V.clone().data[small_head_no]

  large_model.blocks[large_layer_no].attn.b_Q.data[large_head_no] = small_model.blocks[small_layer_no].attn.b_Q.clone().data[small_head_no]
  large_model.blocks[large_layer_no].attn.b_K.data[large_head_no] = small_model.blocks[small_layer_no].attn.b_K.clone().data[small_head_no]
  large_model.blocks[large_layer_no].attn.b_V.data[large_head_no] = small_model.blocks[small_layer_no].attn.b_V.clone().data[small_head_no]

In [None]:
def transfer_useful_heads(small_model, large_model):
  if cfg.insert_mode >= 2 and len(useful_cells) > 0:
    # Is the destination the first few or last few layers of the main_model?
    start_layer = cfg.n_layers - cfg.insert_n_layers if cfg.insert_late else 0

    # TODO: Somewhat inefficient loop as a given head may exist in the list multiple times (with different use_cell.position and use_cell.tags values)
    for use_cell in useful_cells:
      if use_cell.is_head():
        transfer_one_head(small_model, use_cell.layer, use_cell.head, large_model, start_layer)

# Part 7: Train add/sub/mult main_model with Infinite Data
Train main_model for n_training_steps, recording the training loss at every step. (The code calls each training step an "epoch".)

At each training step, a new batch of batch_size questions is generated, and the model is trained and its loss calculated on that batch. No separate "testing" data is needed because the training data is freshly generated each step, so memorising past training data gives the model almost no benefit. For 6-digit addition or subtraction there are 10^12 (1000 billion) possible questions.
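The size of the question space can be checked with a few lines of arithmetic. The sketch below uses the default `batch_size` (64) and `n_training_steps` (40000) from the configuration above and, as a simplifying assumption, treats every generated question as distinct and of the same operation. That gives an upper bound on how much of the space one training run can cover.

```python
n_digits = 6
batch_size = 64
n_training_steps = 40000

# Each operand has 10**n_digits possible values, and a question has two operands.
questions_per_operation = 10 ** (2 * n_digits)

# Upper bound on the number of questions seen during one training run.
questions_seen = batch_size * n_training_steps

fraction_seen = questions_seen / questions_per_operation
print(f"{questions_per_operation:,} possible questions per operation")
print(f"{questions_seen:,} questions seen, a fraction of {fraction_seen:.2e}")
```

Even under this generous bound, a run sees only a few millionths of the possible questions, which is why memorisation is of little use to the model.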

In [None]:
# Train the model
train_losses_list = []
per_token_train_losses_list = []

for epoch in tqdm.tqdm(range(cfg.n_training_steps)):

  tokens = next(ds)
  logits = main_model(tokens)

  per_token_train_losses_raw, _ = logits_to_tokens_loss(logits, tokens)
  per_token_train_losses = loss_fn(per_token_train_losses_raw)
  per_token_train_losses_list.append(utils.to_numpy(per_token_train_losses))

  train_loss = per_token_train_losses.mean()
  train_loss.backward()
  train_losses_list.append(train_loss.item())

  optimizer.step()
  scheduler.step()
  optimizer.zero_grad()

  if epoch % 100 == 0:
    print(epoch, train_loss.item())
    if cfg.insert_mode == 2:
      # Approximate freezing: re-copy the useful attention heads from insert_model
      transfer_useful_heads(insert_model, main_model)
    if cfg.insert_mode == 3:
      # Approximate freezing: re-copy all attention heads and MLP layers from insert_model
      insert_existing_model( False )

print(epoch, train_loss.item())

In [None]:
final_training_loss = round(sum(train_losses_list[-5:]) / 5, 9)

print( "AvgFinalLoss", final_training_loss)
print( "Final loss", train_losses_list[-1])

In [None]:
# These temporary Colab files can be manually downloaded from the Colab "Files" tab (at left).
# The download can be manually loaded into HuggingFace so the "Accurate Math - Analyse" Colab can access it.

print("Saving main model to temporary Colab file", main_fname_pth)
torch.save(main_model.state_dict(), main_fname_pth)

In [None]:
extra_data = {
    "Config": cfg.to_dict(),
    "AvgFinalLoss": final_training_loss,
    "FinalLoss": train_losses_list[-1],
    "TrainingLoss": train_losses_list
}

print( "Saving main model config etc to temporary Colab file:", main_fname_json)
save_cfg = cfg.to_dict()
with open(main_fname_json, 'w') as file:
    json.dump(extra_data, file)

# Part 8: Training Loss - Addition and Subtraction

Results of several training runs made on 26Jan24:

Addition:
- 26Jan24: add_d5_l2_h3_train15K_seed372001. AvgFinalLoss=1.6e-08. Handles 1m Qs
- 26Jan24: add_d6_l2_h3_train15K_seed372001. AvgFinalLoss=1.7e-08. Handles 1m Qs

Subtraction:
- 26Jan24: sub_d6_l2_h3_train20K_seed372001. AvgFinalLoss=9.8e-05. Fails 1m Qs
- 26Jan24: sub_d6_l2_h3_train30K_seed372001. AvgFinalLoss=5.8e-06. Fails 1m Qs

Mixed:
- 26Jan24: sub_d6_l3_h4_train20K_seed372001. AvgFinalLoss=5e-09. Fails 1m Qs

Mixed+Insert (add_d6_l2_h3_train15K_seed372001 inserted):
- 26Jan24: ins1_mix_d6_l3_h4_train40K_seed372001. AvgFinalLoss=8e-09. Handles 1m Qs for Add and Sub
- 26Jan24: ins2_mix_d6_l4_h4_train40K_seed372001. AvgFinalLoss=7e-09. Fails 1m Qs
- 26Jan24: ins3_mix_d6_l4_h3_train40K_seed372001. AvgFinalLoss=2.6e-06. Fails 1m Qs
- ??Jan24: ins3_mix_d6_l4_h3_train60K_seed372001. AvgFinalLoss=????

# Part 9: Line Graphs

This section analyses the training loss by graphing it at a high level.

The loss curve for all digits shows visible inflection points (bumps), but is too high level to help understand the algorithm.

When this graph is decomposed into 'per digit' graphs, distinct 'per digit' curves appear, suggesting that the model's algorithm refines each digit semi-independently.
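The relationship between the overall curve and the per-digit curves can be sketched with made-up numbers. The overall loss at each step is the mean of the per-position losses at that step, while each position's curve can cross a given threshold at a different step. The values and the `first_step_below` helper below are illustrative only.

```python
from statistics import mean

# Made-up losses: one row per training step, one column per answer position.
per_token_losses = [
    [0.9, 2.1, 2.3, 2.2],
    [0.2, 1.0, 1.9, 2.0],
    [0.0, 0.1, 0.6, 1.8],
]

# The overall ("all digits") curve: mean across positions at each step.
overall = [mean(row) for row in per_token_losses]

# The per-position curves: transpose rows (steps) into columns (positions).
per_position_curves = [list(col) for col in zip(*per_token_losses)]


def first_step_below(curve, threshold):
    """Index of the first step whose loss is below threshold, or None."""
    for step, value in enumerate(curve):
        if value < threshold:
            return step
    return None


learned_at = [first_step_below(curve, 0.5) for curve in per_position_curves]
print(overall)
print(learned_at)  # positions cross the threshold at different steps, or not yet
```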

In [None]:
def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

# Helper function to plot multiple lines
def lines(raw_lines_list, x=None, mode='lines', labels=None, xaxis='Epoch', yaxis='Loss', title = '', log_y=False, hover=None, **kwargs):

    lines_list = raw_lines_list
    log_suffix = '' if log_y==False else ' (Log)'
    full_title = title + log_suffix

    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    if cfg.save_graph_to_file :
      fig = go.Figure(layout={})
      print(full_title)
    else:
      fig = go.Figure(layout={'title':full_title})

    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis + log_suffix)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    if cfg.save_graph_to_file:
        fig.update_layout(margin=dict(l=10, r=10, t=10, b=10),width=1200,height=300)

    fig.show()

    if cfg.save_graph_to_file:
        filename = full_title.replace(" ", "").replace("(", "").replace(")", "").replace("&", "").replace(",", "").replace("%", "")   +'.pdf'
        pio.write_image(fig, filename)


title_suffix = 'Digit Loss Curves ' + main_fname
per_token_losses = np.stack(per_token_train_losses_list, axis=0)

line(train_losses_list,
    title=title_suffix)

lines([per_token_losses[:, i] for i in range(1+cfg.n_digits)]+[train_losses_list],
      labels = [f'digit {i}' for i in range(1+cfg.n_digits)]+['all_digits'],
      title='Per ' + title_suffix, log_y=False)

lines([per_token_losses[:, i] for i in range(1+cfg.n_digits)]+[train_losses_list],
      labels = [f'digit {i}' for i in range(1+cfg.n_digits)]+['all_digits'],
      title='Per ' + title_suffix, log_y=True)

for i in range(1+cfg.n_digits):
  print('Final Loss for digit ' + str(i) + ' is ', per_token_losses[-1, i])

# Part 10: Questions Set Up

Creates sets of sample questions (by task) for the model to answer.

In [None]:
def make_varied_questions():
  q0 = next(ds)
  q1 = next(ds)
  q2 = next(ds)
  q3 = next(ds)

  questions = torch.vstack((q0.cuda(), q1.cuda(), q2.cuda(), q3.cuda()))

  return questions

In [None]:
verbose = True

In [None]:
# Build a test batch of random questions
varied_questions = make_varied_questions()


# Run the sample batch, gather the cache
main_model.reset_hooks()
main_model.set_use_attn_result(True)
sample_logits, sample_cache = main_model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_prob_tokens = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean)


# attn.hook_z is the "attention head output" hook point name (at a specified layer)
l_attn_hook_z_name = [utils.get_act_name('z', 0, 'a'),utils.get_act_name('z', 1, 'a')] # 'blocks.0.attn.hook_z' etc
sample_attn_z = sample_cache[l_attn_hook_z_name[0]]
print("Sample", l_attn_hook_z_name[0], sample_attn_z.shape) # gives [239, 18, 3, 170] = num_questions, cfg.n_ctx, n_heads, d_head
mean_attn_z = torch.mean(sample_attn_z, dim=0, keepdim=True)
print("Mean", l_attn_hook_z_name[0], mean_attn_z.shape) # gives [1, 18, 3, 170] = 1, cfg.n_ctx, n_heads, d_head


# hook_resid_pre is the "pre residual memory update" hook point name (at a specified layer)
l_hook_resid_pre_name = ['blocks.0.hook_resid_pre','blocks.1.hook_resid_pre']


# hook_resid_post is the "post residual memory update" hook point name (at a specified layer)
l_hook_resid_post_name = ['blocks.0.hook_resid_post','blocks.1.hook_resid_post']
sample_resid_post = sample_cache[l_hook_resid_post_name[0]]
print("Sample", l_hook_resid_post_name[0], sample_resid_post.shape) # gives [239, 18, 510] = num_questions, cfg.n_ctx, d_model
mean_resid_post = torch.mean(sample_resid_post, dim=0, keepdim=True)
print("Mean", l_hook_resid_post_name[0], mean_resid_post.shape) # gives [1, 18, 510] = 1, cfg.n_ctx, d_model


# mlp.hook_post is the "MLP layer" hook point name (at a specified layer)
l_mlp_hook_post_name = [utils.get_act_name('post', 0),utils.get_act_name('post', 1)] # 'blocks.0.mlp.hook_post' etc
sample_mlp_hook_post = sample_cache[l_mlp_hook_post_name[0]]
print("Sample", l_mlp_hook_post_name[0], sample_mlp_hook_post.shape) # gives [239, 18, 2040] = num_questions, cfg.n_ctx, d_model*4
mean_mlp_hook_post = torch.mean(sample_mlp_hook_post, dim=0, keepdim=True)
print("Mean", l_mlp_hook_post_name[0], mean_mlp_hook_post.shape) # gives [1, 18, 2040] = 1, cfg.n_ctx, d_model*4
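
The cell above hard-codes hook point names for layers 0 and 1 only, while some models in the training log use 4 layers. A minimal sketch of building the same names for any number of layers follows. It assumes the naming pattern shown in the comments above (e.g. 'blocks.0.attn.hook_z'); the `hook_names` helper and its dictionary keys are hypothetical and not part of the notebook or of TransformerLens.

```python
def hook_names(n_layers):
    """Build per-layer hook point names (hypothetical helper).

    Assumes the 'blocks.<layer>.<hook>' naming pattern used in the cell above.
    """
    layers = range(n_layers)
    return {
        "attn_z": [f"blocks.{layer}.attn.hook_z" for layer in layers],
        "resid_pre": [f"blocks.{layer}.hook_resid_pre" for layer in layers],
        "resid_post": [f"blocks.{layer}.hook_resid_post" for layer in layers],
        "mlp_post": [f"blocks.{layer}.mlp.hook_post" for layer in layers],
    }

names = hook_names(n_layers=4)
for kind, per_layer in names.items():
    print(kind, per_layer)
```

Each list can then be indexed by layer in the same way as `l_attn_hook_z_name[0]` above.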

# Part 11: Attention Patterns
Attention patterns show which token(s) the model's attention heads are paying attention to in each token position of the prediction calculation.

For the default CoLab set up, the model has 3 attention heads and performs 5 digit addition. The attention pattern is 18 by 18 squares (as 54321+77779=132100 is 18 tokens). Time proceeds vertically downwards, with one additional token being revealed horizontally at each token position, giving the overall triangle shape. This visualisation provided the following insights. After the question is fully revealed (at token position 11), each head starts attending to pairs of question digits from left to right (i.e. high-value digits before lower-value digits), giving the "double staircase" shape. The three heads attend to a given digit pair at three different token positions, giving a time ordering of heads.
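
The token counts and the triangle shape described above can be checked with a small sketch. It assumes one token per character, as stated for this model, and uses a plain lower-triangular causal mask. The variable names are illustrative only, and no model is run.

```python
question = "54321+77779="
answer = "132100"
example_tokens = list(question + answer)  # one token per character

n_ctx = len(example_tokens)
last_question_pos = len(question) - 1  # position of the '=' token

# Causal mask: at position `row` the model can attend to positions 0..row only,
# which is why the attention pattern is a lower triangle.
causal_mask = [[col <= row for col in range(n_ctx)] for row in range(n_ctx)]
visible_counts = [sum(row) for row in causal_mask]

print("Tokens:", n_ctx)
print("Question fully revealed at position:", last_question_pos)
print("Visible tokens per position:", visible_counts)
```

The answer digits are predicted from position 11 onwards, so the rows below that point are where the "double staircase" appears.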

In [None]:
def show_token_attention_patterns(index, layer, token_at_index, use_case):

  the_tokens = [str(token) for token in token_at_index.tolist()]
  if layer == 0:
    tokens_str = tokens_to_string(token_at_index)
    print("Attention patterns for", tokens_str)

  attention_pattern=sample_cache["pattern", layer, "attn"][index]
  display(cv.attention.attention_patterns(
      tokens=the_tokens,
      attention=attention_pattern,
      #attention_head_names=[f"L{layer}H{i}" for i in range(cfg.n_heads)],
  ))


sample_size = 3

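# Show attention patterns for the first few sample questions.
# sample_cache was built from varied_questions, so index that batch to keep tokens and patterns aligned.
for i in range(sample_size):
  for layer in range(cfg.n_layers):
    show_token_attention_patterns(i, layer, varied_questions[i], "Misc")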


In [None]:
if cfg.save_graph_to_file:

  tokens_str = []
  for i in range(cfg.n_heads):
    one_token_str = []
    for j in tokens[i]:
      one_token_str.append(str(utils.to_numpy(j)))
    tokens_str.append(one_token_str)

  # Refer https://github.com/callummcdougall/CircuitsVis/blob/main/python/circuitsvis/circuitsvis_demo.ipynb

  # html_object = cv.attention.from_cache(
  #    cache = sample_cache,
  #    tokens = tokens_str, # list of list of strings
  #    return_mode = "html",
  #)

  # Create a CoLab file containing the attention pattern(s) in HTML
  #filename = "AttentionPattern" + str(cfg.n_digits) + "Digits" + str(cfg.n_heads) + "Heads.html"
  #with open(filename, "w") as f:
  #    f.write(html_object.data)

  # Manually download the CoLab "html" file and open in your local browser.
  # Install and use the Edge extension "FireShot" to save a portion of the HTML page as a PDF