# Accurate Integer Mathematics in Transformers - Train the Model

This CoLab defines and trains a Transformer model that performs integer addition, subtraction and multiplication, e.g. 133357+182243=+0315600, 123450-345670=-0222220 and 000345*000823=+0283935. Each digit is a separate token. For 6 digit questions, the model is given 14 "question" (input) tokens, and must then predict the corresponding 8 "answer" (output) tokens.
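
The token layout described above can be sketched in plain Python. This is a minimal illustration, not the notebook's own code: the helper names `encode` and `make_example` are hypothetical, and the token indices are assumed to match those defined in Part 1.

```python
# Hypothetical helpers for illustration; token indices assumed to match Part 1.
PLUS_INDEX, MINUS_INDEX, EQUALS_INDEX, MULT_INDEX = 10, 11, 12, 13
SYMBOLS = {"+": PLUS_INDEX, "-": MINUS_INDEX, "=": EQUALS_INDEX, "*": MULT_INDEX}

def encode(text):
    """Map each character to one token index (digits 0-9 map to themselves)."""
    return [int(ch) if ch.isdigit() else SYMBOLS[ch] for ch in text]

def make_example(x, y, op, n_digits=6):
    """Build the question string and the signed, zero-padded answer string."""
    value = {"+": x + y, "-": x - y, "*": x * y}[op]
    sign = "-" if value < 0 else "+"
    question = f"{x:0{n_digits}d}{op}{y:0{n_digits}d}="
    answer = f"{sign}{abs(value):0{n_digits + 1}d}"
    return question, answer

question, answer = make_example(123450, 345670, "-")
tokens = encode(question + answer)
print(question, answer)
print(len(encode(question)), "question tokens,", len(encode(answer)), "answer tokens")
```

Padding the answer to `n_digits + 1` digits plus a sign keeps every example the same length, whichever operation is used.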

This CoLab can be configured to:
- Train a model using the traditional approach. For example, add_d6_l2_h3_dm510_dh170_ctx22_seed129000_train15K.pth is trained from scratch using 100% addition questions to give an "Addition" model with a very low loss (9e-9).
- Train a model by inserting a "known good" model into the untrained composite model. For example, initialising ins_sub_d6_l3_h4_dm510_dh170_ctx22_seed129000_train20K.pth with add_d6_l2_h3_dm510_dh170_ctx22_seed129000_train15K, and then training it on 100% subtraction questions.
- Train a "mixed" model to do two tasks by inserting a "known good" model into the untrained composite model. For example, initialising ins_mix_d6_l3_h4_dm510_dh170_ctx22_seed129000_train20K.pth with add_d6_l2_h3_dm510_dh170_ctx22_seed129000_train15K, and then training it on 80% subtraction questions and 20% addition questions.
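
The percentage mix is applied per batch: each batch contains questions for a single operation, chosen by comparing a random roll against cumulative percentage thresholds. The following is a minimal sketch of that selection, assuming a hypothetical helper `pick_batch_op` and a locally seeded random generator (neither is part of the notebook).

```python
import random

def pick_batch_op(perc_mult, perc_sub, rng):
    """Pick one operation for a whole batch using cumulative percentage thresholds."""
    roll = rng.randint(1, 100)  # inclusive on both ends
    if roll <= perc_mult:
        return "*"
    if roll <= perc_mult + perc_sub:
        return "-"
    return "+"  # whatever percentage remains is addition

rng = random.Random(0)  # fixed seed so the sketch is repeatable
counts = {"+": 0, "-": 0, "*": 0}
for _ in range(10_000):
    counts[pick_batch_op(perc_mult=0, perc_sub=80, rng=rng)] += 1
print(counts)
```

Over many batches the counts approach the configured 80% subtraction and 20% addition split.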


This CoLab trains the model, storing the results to Google Drive.

## Tips for using the Colab
 * You can run and alter the code in this CoLab notebook yourself in Google CoLab ( https://colab.research.google.com/ ).
 * To run the notebook, in Google CoLab, **you will need to** go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.
 * Some graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate.
 * Collapse irrelevant sections with the dropdown arrows.
 * Search the page using the search in the sidebar, not CTRL+F.

# Part 1: Configuration

In [89]:
# Tokens used in vocab. (Token indexes 0 to 9 represent digits 0 to 9)
PLUS_INDEX = 10
MINUS_INDEX = 11
EQUALS_INDEX = 12
MULT_INDEX = 13
DIV_INDEX = 14
MAX_INDEX = DIV_INDEX

In [90]:
# Main configuration class for main model creation and training
class Config():
  #@markdown Main Model
  n_layers: int = 3 #@param
  n_heads: int = 4 #@param

  d_vocab: int = MAX_INDEX+1
  d_model: int = 510
  d_mlp: int = 4 * d_model
  d_head: int = 170
  seed: int = 372001

  #@markdown Data
  n_digits: int = 6 #@param
  n_ctx: int = 3 * n_digits + 4
  act_fn: str = 'relu'
  batch_size: int = 256

  #@markdown Optimizer
  n_training_steps: int = 20000 #@param
  lr: float = 0.00008
  weight_decay: float = 0.1

  #@markdown Maths Operations
  # Percent of questions that are multiplication, subtraction and addition.
  perc_mult: int = 0 #@param
  perc_sub: int = 80 #@param
  def perc_add(self):
    return max(0, 100 - self.perc_mult - self.perc_sub)


  # Save graphs to CoLab temp files as PDF and HTML. Can manually export files for re-use in papers.
  save_graph_to_file: bool = True


cfg = Config()

In [91]:
# Optional configuration class for inserting a sub-model's weights into the main model before training
class IConfig():
  #@markdown Insert Model
  insert: bool = True #@param
  insert_late: bool = False

  n_layers: int = 2 #@param
  n_heads: int = 3 #@param
  seed: int = 372001

  #@markdown Data
  n_digits: int = 6 #@param
  n_ctx: int = 3 * n_digits + 4

  #@markdown Optimizer
  n_training_steps: int = 15000 #@param


icfg = IConfig()

In [92]:
def file_name_suffix(digits, layers, heads, d_model, d_head, ctx, seed, training_steps):
  epoch_str = str(training_steps//1000) + "K"
  return '_d{}_l{}_h{}_dm{}_dh{}_ctx{}_seed{}_train{}'.format(digits, layers, heads, d_model, d_head, ctx, seed, epoch_str)

fname_prefix = 'ins_' if icfg.insert else ''
fname_prefix += 'mul' if cfg.perc_mult == 100 else 'sub' if cfg.perc_sub == 100 else 'add' if cfg.perc_add() == 100 else 'mix'
main_fname_suffix = fname_prefix + file_name_suffix(cfg.n_digits, cfg.n_layers, cfg.n_heads, cfg.d_model, cfg.d_head, cfg.n_ctx, cfg.seed, cfg.n_training_steps)


def print_config():
  print("%Mult=", cfg.perc_mult, "%Sub=", cfg.perc_sub, "%Add=", cfg.perc_add(), "File=", main_fname_suffix)

print_config()

%Mult= 0 %Sub= 80 %Add= 20 File= ins_mix_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K


# Part 2: Import libraries
Imports standard libraries. You will be asked to grant access to your Google Drive so that the model weights can be written to it.

In [93]:
from google.colab import drive
from pathlib import Path

In [94]:
# Training saves the trained model weights to a file in your Google Drive. You will need to give permission for this CoLab to access your Google Drive.
# Loading the model from your Google Drive later avoids the roughly 10 mins spent training it.

GLOBAL=True
if GLOBAL:
    drive.mount('/content/drive', force_remount=False)
    rootdir=Path('/content/drive/MyDrive/AI/CoLabOutput/')
else:
    rootdir=Path('./')

main_fname_full = main_fname_suffix + '.pth'

main_persist_location = rootdir/f'{main_fname_full}'

print('main model will save to {}'.format(str(main_persist_location)))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
main model will save to /content/drive/MyDrive/AI/CoLabOutput/ins_mix_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K.pth


In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    !pip install kaleido
    !pip install transformer_lens
    !pip install circuitsvis
    !pip install torchtyping
    !pip install transformers

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook
ERROR: Operation cancelled by user


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import kaleido
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import plotly.express as px
import plotly.graph_objects as go

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import tqdm.auto as tqdm
import random
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt
import circuitsvis as cv

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Part 3: Create main_model
This section defines the token embedding / unembedding and creates the model.

In [None]:
# Embedding / Unembedding

def token_to_char(i):
  if i < 10:
   return str(i)
  if i == PLUS_INDEX:
    return "+"
  if i == MINUS_INDEX:
    return "-"
  if i == EQUALS_INDEX:
    return "="
  if i == MULT_INDEX:
    return "*"
  if i == DIV_INDEX:
    return "\\"
  return "?"

def tokens_to_string(tokens):
    tokens = utils.to_numpy(tokens)
    return "".join([token_to_char(i) for i in tokens[:cfg.n_ctx]])

def string_to_tokens(string, batch: bool=False):
    lookup = {str(i):i for i in range(10)}
    lookup['+']=PLUS_INDEX
    lookup['-']=MINUS_INDEX
    lookup['=']=EQUALS_INDEX
    lookup['*']=MULT_INDEX
    lookup['\\']=DIV_INDEX

    tokens = [lookup[i] for i in string if i not in '\n ']
    if batch:
        return torch.tensor(tokens)[None, :]
    else:
        return torch.tensor(tokens)

In [None]:
# Transformer creation

# Structure is documented at https://neelnanda-io.github.io/TransformerLens/transformer_lens.html#transformer_lens.HookedTransformerConfig.HookedTransformerConfig
ht_cfg = HookedTransformerConfig(
    n_layers = cfg.n_layers,
    n_heads = cfg.n_heads,
    d_model = cfg.d_model,
    d_head = cfg.d_head,
    d_mlp = cfg.d_mlp,
    act_fn = cfg.act_fn,
    normalization_type = 'LN',
    d_vocab = cfg.d_vocab,
    d_vocab_out = cfg.d_vocab,
    n_ctx = cfg.n_ctx,
    init_weights = True,
    device = "cuda",
    seed = cfg.seed,
)

main_model = HookedTransformer(ht_cfg)

optimizer = optim.AdamW(main_model.parameters(),
                        lr = cfg.lr,
                        weight_decay = cfg.weight_decay,
                        betas = (0.9, 0.98))

max_iter = cfg.n_training_steps
warmup_iter = max_iter // 5
scheduler1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=int(warmup_iter))
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(np.ceil((max_iter-warmup_iter))))
scheduler  = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[int(warmup_iter)])

# Part 4: Loss Function & Data Generator
This section defines the loss function and the training/testing data generator.
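
As a rough illustration of the loss used in this section, the sketch below computes the negative log-probability of the correct token at each answer position in plain Python. The function names and toy numbers are illustrative assumptions; the notebook's own version operates on batched tensors and slices the logits one position earlier than the tokens, because the logits at position p predict the token at position p+1.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one list of raw scores."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_total for v in logits]

def per_token_loss(answer_logits, answer_tokens):
    """Negative log-probability of the correct token at each answer position."""
    return [-log_softmax(scores)[token]
            for scores, token in zip(answer_logits, answer_tokens)]

# Toy example: 2 answer positions, vocabulary of 3 tokens (illustrative numbers only).
answer_logits = [[5.0, 0.0, 0.0],   # confident and correct -> small loss
                 [0.0, 0.0, 0.0]]   # uniform guess -> loss of log(3)
answer_tokens = [0, 2]
losses = per_token_loss(answer_logits, answer_tokens)
print(losses, sum(losses) / len(losses))
```

A confident correct prediction contributes a loss near zero, while a uniform guess over V tokens contributes log(V).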


In [None]:
# Loss functions

# Calculate the per-token log-probability of the correct answer by comparing a batch of prediction "logits" to answer "tokens"
def logits_to_tokens_loss(logits, tokens):
  # The answer can have one more digit than the question. The answer also has a +/- sign
  n_answer_digits = cfg.n_digits+2

  # The addition answer digit token probabilities
  ans_logits = logits[:, -(n_answer_digits+1):-1]

  # Convert the raw score (logits) vector into log-probabilities (the log of the softmax distribution).
  # Softmax emphasizes the largest scores and suppresses the smaller ones, making them more distinguishable.
  ans_probs = F.log_softmax(ans_logits.to(torch.float64), dim=-1)

  max_prob_tokens = torch.argmax(ans_probs, dim=-1)

  # The addition answer digit tokens
  ans_tokens = tokens[:, -(n_answer_digits):]

  # Extract values from the ans_probs tensor, based on indices from the ans_tokens tensor
  ans_loss = torch.gather(ans_probs, -1, ans_tokens[:, :, None])[..., 0]

  return ans_loss, max_prob_tokens

# Calculate loss as the negative mean (over the batch) of the per-token log-probability
def loss_fn(ans_loss):
  return -ans_loss.mean(0)

In [None]:
# Define "iterator" data generator function. Invoked using next().
# "Addition" batch entries are formatted XXXXXX+YYYYYY=+ZZZZZZZ e.g. 550030+800020=+1350050
# "Subtraction" batch entries are formatted XXXXXX-YYYYYY=-ZZZZZZZ e.g. 550030-800020=-0249990, 800020-550030=+0249990
# "Multiplication" batch entries are formatted 000XXX*000YYY=+0ZZZZZZ e.g. 000345*000678=+0233910
def data_generator( ):
    torch.manual_seed(cfg.seed)
    while True:
        #generate a batch of questions (answers calculated below)
        batch = torch.zeros((cfg.batch_size, cfg.n_ctx)).to(torch.int64)
        x = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))
        y = torch.randint(0, 10, (cfg.batch_size, cfg.n_digits))

        batch_rand = random.randint(1, 100)
        batch_op = MULT_INDEX if batch_rand <= cfg.perc_mult else MINUS_INDEX if batch_rand <= cfg.perc_mult + cfg.perc_sub else PLUS_INDEX

        if batch_op == MULT_INDEX:
          # Convert from NNNNNN*NNNNNN= to 000NNN*000NNN= so answer (product) is NNNNNN
          num_zeros = cfg.n_digits // 2
          for z in range(num_zeros):
            x[:, z] = 0
            y[:, z] = 0


        # Enrich the question data on 40% of batches (randint(1, 5) < 3) to speed up training
        if ( batch_op == PLUS_INDEX or batch_op == MINUS_INDEX ) and (random.randint(1, 5) < 3):
            # Flatten x and y to 1D tensors
            x_flat = x.view(-1)
            y_flat = y.view(-1)

            if batch_op == PLUS_INDEX :
              # The UseSum9 task is compound and rare and so hard to learn.
              # Increase the MakeSum9 case frequency
              # UseSum9 also relies on MakeCarry1 (50%) from previous column.
              num_elements_to_modify = int(0.40 * x.numel()) # 40%
              indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
              if random.randint(1, 2) == 1:
                x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
              else:
                y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]
            else:
              # Empirically, the model seems to struggle with the sign calculation.
              # Minus signs are rarer than positive signs.
              # Generate more negative answers by increasing the y value
              y_flat[y_flat < 9] += 1

            # Reshape x and y back to its original shape
            x = x_flat.view(x.shape)
            y = y_flat.view(x.shape)


        batch[:, :cfg.n_digits] = x
        batch[:, cfg.n_digits] = batch_op
        batch[:, 1+cfg.n_digits:1+cfg.n_digits*2] = y
        batch[:, 1+cfg.n_digits*2] = EQUALS_INDEX

        # Convert each row of digits into a single n_digits-digit number
        x_values = x[:, 0]
        y_values = y[:, 0]
        for dn in range(1,cfg.n_digits):
          x_values = x_values * 10 + x[:, dn]
          y_values = y_values * 10 + y[:, dn]

        # Elementwise operations to give the 1D tensor answers
        if batch_op == MULT_INDEX:
          answers = x_values * y_values
        else:
          if batch_op == MINUS_INDEX:
            answers = x_values - y_values
          else:
            answers = x_values + y_values

        # Insert the answers into the batch
        for i in range(cfg.batch_size):
          answer = answers[i]

          sign = PLUS_INDEX
          if answer < 0:
            sign = MINUS_INDEX
            answer = - answer

          batch[i, 2+cfg.n_digits*2] = sign
          for j in range(cfg.n_digits+1):
            batch[i, cfg.n_ctx-j-1] = answer % 10
            answer = answer // 10
            if answer == 0:
                break

        yield batch.cuda()

In [None]:
# Initialise the data generator
ds = data_generator()

In [None]:
# Test data generator
tokens = next(ds)
print(tokens[:3,:])

# Part 5: Read insert_model weights from Google drive (optional)

In [None]:
insert_fname_full= ""

if icfg.insert == True:

  insert_fname_full = "add" + file_name_suffix(icfg.n_digits, icfg.n_layers, icfg.n_heads, cfg.d_model, cfg.d_head, icfg.n_ctx, icfg.seed, icfg.n_training_steps) + ".pth"
  insert_persist_location = rootdir/f'{insert_fname_full}'

  ht_cfg = HookedTransformerConfig(
      n_layers = icfg.n_layers,
      n_heads = icfg.n_heads,
      d_model = cfg.d_model, # Assume constant
      d_head = cfg.d_head, # Assume constant
      d_mlp = cfg.d_mlp, # Assume constant
      act_fn = cfg.act_fn, # Assume constant
      normalization_type = 'LN',
      d_vocab = cfg.d_vocab, # Assume constant
      d_vocab_out = cfg.d_vocab, # Assume constant
      n_ctx = icfg.n_ctx,
      init_weights = True, # Assume constant
      device = "cuda",
      seed = icfg.seed,
  )

  insert_model = HookedTransformer(ht_cfg)

  print('insert_model will load from {}'.format(str(insert_persist_location)))
  insert_model.load_state_dict(torch.load(insert_persist_location))
  insert_model.eval()

# Part 6A: Set N token positions to "no grad" (optional, deprecated)

In [None]:
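# Deprecated: intended to freeze the weights of some attention heads.
# Note: cfg.n_nograd_positions is not defined in Config above, and setting requires_grad
# on a slice of .data does not freeze the underlying parameter.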
def set_nograd_positions():
  n_positions = cfg.n_nograd_positions
  if n_positions > 0:
    print("Freezing weights of first", n_positions, "attention heads")
    for layer in range(cfg.n_layers):
        attention_heads = main_model.blocks[layer].attn
        attention_heads.W_Q.data[n_positions:].requires_grad = False
        attention_heads.W_K.data[n_positions:].requires_grad = False
        attention_heads.W_V.data[n_positions:].requires_grad = False

#set_nograd_positions()

# Part 6B: Insert insert_model weights into untrained main_model (optional)
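
The idea behind the insert can be shown with a toy sketch in plain Python: the trained heads of the small model overwrite the leading head slots of the larger model, and any extra heads keep their fresh initialisation. The helper `insert_heads` and the list-of-lists "weights" are illustrative assumptions; the notebook's own function works on the model's weight tensors.

```python
def insert_heads(small_heads, large_heads):
    """Overwrite the first len(small_heads) head slots; leave the rest untouched."""
    assert len(small_heads) <= len(large_heads), "small model must not have more heads"
    merged = [list(h) for h in large_heads]               # copy, so the input is not mutated
    merged[:len(small_heads)] = [list(h) for h in small_heads]
    return merged

# Toy weights: 3 trained heads inserted into a 4 head layer (illustrative numbers only).
small = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
large = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]
merged = insert_heads(small, large)
print(merged)
```

The fourth head keeps its original values, so the larger model starts training with spare capacity alongside the inserted weights.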



In [None]:
# Insert the small model weights into the large model
def do_model_insert(small_model, large_model, start_layer, end_layer, transfer_ln=True, transfer_embeds=True):
    """Args:
    small_model: The model to transfer weights from
    large_model: The model to transfer weights to
    start_layer: The first layer to transfer weights to
    end_layer: The last layer to transfer weights to (Note that this is end-inclusive!)
    """
    small_cfg = {k: v for k,v in small_model.cfg.__dict__.items() if k in ["d_head", "d_mlp", "d_model", "n_heads", "n_layers"]}
    large_cfg = {k: v for k,v in large_model.cfg.__dict__.items() if k in ["d_head", "d_mlp", "d_model", "n_heads", "n_layers"]}

    # Sanity checks for large model > small model
    assert small_cfg["d_model"] == large_cfg["d_model"]
    assert small_cfg["d_head"] == large_cfg["d_head"]
    assert small_cfg["n_layers"] <= large_cfg["n_layers"]
    assert small_cfg["n_heads"] <= large_cfg["n_heads"]
    assert small_cfg["d_mlp"] <= large_cfg["d_mlp"]

    assert 0 <= start_layer <= end_layer < large_cfg["n_layers"] # Make sure start_layer and end_layer are valid (end_layer is inclusive)
    assert end_layer - start_layer + 1 == small_cfg["n_layers"] # Make sure the number of layers to transfer is correct

    # Transfer heads and MLPs
    for small_layer_no, large_layer_no in enumerate(range(start_layer, end_layer+1)):
        # Transfer Heads
        large_model.blocks[large_layer_no].attn.W_Q.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.W_Q.clone().data
        large_model.blocks[large_layer_no].attn.W_K.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.W_K.clone().data
        large_model.blocks[large_layer_no].attn.W_V.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.W_V.clone().data

        large_model.blocks[large_layer_no].attn.b_Q.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.b_Q.clone().data
        large_model.blocks[large_layer_no].attn.b_K.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.b_K.clone().data
        large_model.blocks[large_layer_no].attn.b_V.data[:small_cfg["n_heads"]] = small_model.blocks[small_layer_no].attn.b_V.clone().data

        # Transfer MLPs
        large_model.blocks[large_layer_no].mlp.W_in.data[:, :small_cfg["d_mlp"]] = small_model.blocks[small_layer_no].mlp.W_in.clone().data
        large_model.blocks[large_layer_no].mlp.b_in.data[:small_cfg["d_mlp"]] = small_model.blocks[small_layer_no].mlp.b_in.clone().data
        large_model.blocks[large_layer_no].mlp.W_out.data[:small_cfg["d_mlp"],] = small_model.blocks[small_layer_no].mlp.W_out.clone().data
        large_model.blocks[large_layer_no].mlp.b_out.data = small_model.blocks[small_layer_no].mlp.b_out.clone().data

    if transfer_ln:
        for small_layer_no, large_layer_no in enumerate(range(start_layer, end_layer+1)):
            large_model.blocks[large_layer_no].ln1.w.data = small_model.blocks[small_layer_no].ln1.w.clone().data
            large_model.blocks[large_layer_no].ln1.b.data = small_model.blocks[small_layer_no].ln1.b.clone().data

        large_model.ln_final.w.data = small_model.ln_final.w.clone().data

    if transfer_embeds:
        large_model.embed.W_E.data = small_model.embed.W_E.clone().data
        large_model.pos_embed.W_pos.data = small_model.pos_embed.W_pos.clone().data
        large_model.unembed.W_U.data = small_model.unembed.W_U.clone().data

In [None]:
if icfg.insert == True:
  print( "Inserting small_model", insert_fname_full, "into main_model", main_fname_suffix)

  if icfg.insert_late == True:
    # override last few layers
    do_model_insert(small_model=insert_model, large_model=main_model, start_layer=cfg.n_layers - icfg.n_layers, end_layer=cfg.n_layers-1)
  else:
    # override first few layers
    do_model_insert(small_model=insert_model, large_model=main_model, start_layer=0, end_layer=icfg.n_layers-1)

# Part 7: Train add/sub/mult main_model with Infinite Data
Train main_model for n_training_steps, storing the training loss at each step.

At each training step (of n_training_steps), a new batch of training data (batch_size questions) is generated, and the model is trained and its loss calculated on that batch. No separate "testing" data is needed, as the training data is freshly generated each step, so memorisation of past training data by the model (if any) is minimally beneficial. For 6 digit addition or subtraction there are 1000 billion (10^12) possible questions.
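
The scale argument can be checked with a few lines of arithmetic. This sketch assumes the batch size and step count from the Part 1 configuration and treats every generated question as distinct, which gives an upper bound on coverage.

```python
n_digits = 6
batch_size = 256
n_training_steps = 20_000

possible_questions = (10 ** n_digits) ** 2       # every (x, y) pair for one operation
questions_seen = batch_size * n_training_steps   # upper bound: assumes no repeats
fraction_seen = questions_seen / possible_questions

print(possible_questions, questions_seen, fraction_seen)
```

Under these assumptions the model sees only about five questions in every million possible ones, so it cannot reach a very low loss by memorising them.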

In [None]:
# Train the model
train_losses_list = []
per_token_train_losses_list = []

for epoch in tqdm.tqdm(range(cfg.n_training_steps)):

  tokens = next(ds)
  logits = main_model(tokens)

  per_token_train_losses_raw, _ = logits_to_tokens_loss(logits, tokens)
  per_token_train_losses = loss_fn(per_token_train_losses_raw)
  per_token_train_losses_list.append(utils.to_numpy(per_token_train_losses))

  train_loss = per_token_train_losses.mean()
  train_loss.backward()
  train_losses_list.append(train_loss.item())

  optimizer.step()
  scheduler.step()
  optimizer.zero_grad()

  if epoch % 100 == 0:
    print(epoch, train_loss.item())


print(epoch, train_loss.item())

In [None]:
print("Saving main model to Google drive", main_persist_location)
torch.save(main_model.state_dict(), main_persist_location)

# Part 8: Training Loss - Addition and Subtraction

From 1Jan24, new seed, bigger batches, enriched data, optimiser:

Addition:
- 11Jan24: add_d6_l1_h3_dm510_dh170_ctx22_seed372001_train15K. AvgLoss=0.01952 (50mins. FinalLoss=0.023058)
- 6Jan24: add_d6_l2_h3_dm510_dh170_ctx22_seed372001_train15K. AvgLoss=1e-8 (15mins. FinalLoss=9.6877e-9)

Subtraction:
- 10Jan24: sub_d6_l2_h3_dm510_dh170_ctx22_seed372001_train20K. AvgLoss=1.2e-5 (15K=3.062789e-05, FinalLoss=1.45665e-05)
- 6Jan24: sub_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K. AvgLoss=1.8e-8 (15K=6.8653e-07 FinalLoss=2.06524e-08)

Subtraction+Insert (add_d6_l2_h3_..._train15K inserted, insert_late=False):
- 10Jan24: ins_sub_d6_l2_h3_dm510_dh170_ctx22_seed372001_train20K. AvgLoss=1.2e-6 (15K=0.001378 FinalLoss=8.60275e-07)
- 10Jan24: ins_sub_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K. AvgLoss=3.4e-8 (15K=4.53658e-07 FinalLoss=3.11361e-08)

Mixed:
- 11Jan24: mix_d6_l2_h3_dm510_dh170_ctx22_seed372001_train20K, AvgLoss=2.8357e-5 (15K=0.003333 FinalLoss=1.23922e-05, 60%Sub 40%Add)
- 11Jan24: mix_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K, AvgLoss=1.521e-6 (15K=0.000134 FinalLoss=1.4144e-06, 60%Sub 40%Add)

Mixed+Insert (add_d6_l2_h3_..._train15K inserted, insert_late=False):
- 11Jan24: ins_mix_d6_l2_h3_dm510_dh170_ctx22_seed372001_train20K. AvgLoss=8.5828e-5 (15K=5.713924e-05 FinalLoss=2.045787e-06, 80%Sub 20%Add)
- 6Jan24: ins_mix_d6_l3_h4_dm510_dh170_ctx22_seed372001_train20K. AvgLoss=2.101e-06. (FinalLoss=9.73815e-07, 80%Sub 20%Add)
- 11Jan24: ins_mix_d6_l3_h4_dm510_dh170_ctx22_seed372001_train40K. AvgLoss=XXX. (15K=XXX 20K=XXX 25K=XXX 30K=XXX 35K=XXX 40K=XXX. FinalLoss=9.73815e-07, 80%Sub 20%Add)

In [None]:
print_config()

final_training_loss = round(sum(train_losses_list[-5:]) / 5, 9)
print( "Final training loss", train_losses_list[-1])
print( "Last 5 average", final_training_loss)

# Part 9: Line Graphs

This section analyses the training loss by graphing it at a high level.

The loss curve for all digits shows visible inflection points (bumps), but is too high level to help understand the algorithm.

When this graph is decomposed into 'per digit' graphs, the interesting distinct 'per digit' curves appear, showing each digit is being refined semi-independently, with the model algorithm refining each digit separately.
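
The decomposition amounts to reading the same loss table two ways: averaging across answer positions gives the overall curve, while taking one column at a time gives a curve per position. A minimal sketch with made-up numbers (not real training output):

```python
# per_token_losses[step][position]: illustrative numbers, not real training output.
per_token_losses = [
    [0.9, 1.2, 2.0],
    [0.1, 0.8, 1.5],
    [0.0, 0.1, 0.7],
]

# One value per training step: the mean over all answer positions.
overall = [sum(row) / len(row) for row in per_token_losses]

# One curve per answer position: transpose the table.
per_position = [list(column) for column in zip(*per_token_losses)]

print(overall)
print(per_position)
```

In this toy data the first position is learned by the final step while the last position is still improving, a difference that the averaged curve hides.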

In [None]:
def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

# Helper function to plot multiple lines
def lines(raw_lines_list, x=None, mode='lines', labels=None, xaxis='Epoch', yaxis='Loss', title = '', log_y=False, hover=None, **kwargs):

    lines_list = raw_lines_list
    log_suffix = '' if log_y==False else ' (Log)'
    full_title = title + log_suffix

    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    if cfg.save_graph_to_file :
      fig = go.Figure(layout={})
      print(full_title)
    else:
      fig = go.Figure(layout={'title':full_title})

    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis + log_suffix)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    if cfg.save_graph_to_file:
        fig.update_layout(margin=dict(l=10, r=10, t=10, b=10),width=1200,height=300)

    fig.show()

    if cfg.save_graph_to_file:
        filename = full_title.replace(" ", "").replace("(", "").replace(")", "").replace("&", "").replace(",", "").replace("%", "")   +'.pdf'
        pio.write_image(fig, filename)


title_suffix = 'Digit Loss Curves ' + main_fname_suffix
per_token_losses = np.stack(per_token_train_losses_list, axis=0)

line(train_losses_list,
    title=title_suffix)

# Each answer has a sign token followed by cfg.n_digits+1 digit tokens.
# 'digit 0' is the leftmost (most significant) answer digit.
n_answer_tokens = cfg.n_digits + 2
answer_labels = ['sign'] + [f'digit {i}' for i in range(n_answer_tokens - 1)]

lines([per_token_losses[:, i] for i in range(n_answer_tokens)]+[train_losses_list],
      labels = answer_labels+['all_digits'],
      title='Per digit'+title_suffix, log_y=False)

lines([per_token_losses[:, i] for i in range(n_answer_tokens)]+[train_losses_list],
      labels = answer_labels+['all_digits'],
      title='Per digit'+title_suffix, log_y=True)

for i in range(n_answer_tokens):
  print('Final Loss for ' + answer_labels[i] + ' is ', per_token_losses[-1, i])

# Part 10: Questions Set Up

Create sets of sample questions (by task) for the model to answer.

In [None]:
def make_varied_questions():
  q0 = next(ds)
  q1 = next(ds)
  q2 = next(ds)
  q3 = next(ds)

  questions = torch.vstack((q0.cuda(), q1.cuda(), q2.cuda(), q3.cuda()))

  return questions

In [None]:
verbose = True

In [None]:
# Build a test batch of random questions
varied_questions = make_varied_questions();


# Run the sample batch, gather the cache
main_model.reset_hooks()
main_model.set_use_attn_result(True)
sample_logits, sample_cache = main_model.run_with_cache(varied_questions.cuda())
print(sample_cache) # Gives names of datasets in the cache
sample_losses_raw, sample_max_prob_tokens = logits_to_tokens_loss(sample_logits, varied_questions.cuda())
sample_loss_mean = utils.to_numpy(loss_fn(sample_losses_raw).mean())
print("Sample Mean Loss", sample_loss_mean)


# attn.hook_z is the "attention head output" hook point name (at a specified layer)
l_attn_hook_z_name = [utils.get_act_name('z', 0, 'a'),utils.get_act_name('z', 1, 'a')] # 'blocks.0.attn.hook_z' etc
sample_attn_z = sample_cache[l_attn_hook_z_name[0]]
print("Sample", l_attn_hook_z_name[0], sample_attn_z.shape) # gives [num_questions, cfg.n_ctx, cfg.n_heads, cfg.d_head], e.g. [1024, 22, 4, 170] for the Part 1 configuration
mean_attn_z = torch.mean(sample_attn_z, dim=0, keepdim=True)
print("Mean", l_attn_hook_z_name[0], mean_attn_z.shape) # gives [1, cfg.n_ctx, cfg.n_heads, cfg.d_head]


# hook_resid_pre is the "pre residual memory update" hook point name (at a specified layer)
l_hook_resid_pre_name = ['blocks.0.hook_resid_pre','blocks.1.hook_resid_pre']


# hook_resid_post is the "post residual memory update" hook point name (at a specified layer)
l_hook_resid_post_name = ['blocks.0.hook_resid_post','blocks.1.hook_resid_post']
sample_resid_post = sample_cache[l_hook_resid_post_name[0]]
print("Sample", l_hook_resid_post_name[0], sample_resid_post.shape) # gives [num_questions, cfg.n_ctx, cfg.d_model], e.g. [1024, 22, 510] for the Part 1 configuration
mean_resid_post = torch.mean(sample_resid_post, dim=0, keepdim=True)
print("Mean", l_hook_resid_post_name[0], mean_resid_post.shape) # gives [1, cfg.n_ctx, cfg.d_model]


# mlp.hook_post is the "MLP layer" hook point name (at a specified layer)
l_mlp_hook_post_name = [utils.get_act_name('post', 0),utils.get_act_name('post', 1)] # 'blocks.0.mlp.hook_post' etc
sample_mlp_hook_post = sample_cache[l_mlp_hook_post_name[0]]
print("Sample", l_mlp_hook_post_name[0], sample_mlp_hook_post.shape) # gives [num_questions, cfg.n_ctx, cfg.d_mlp], e.g. [1024, 22, 2040] for the Part 1 configuration
mean_mlp_hook_post = torch.mean(sample_mlp_hook_post, dim=0, keepdim=True)
print("Mean", l_mlp_hook_post_name[0], mean_mlp_hook_post.shape) # gives [1, cfg.n_ctx, cfg.d_mlp]

# Part 11: Attention Patterns
Attention patterns show which token(s) the model's attention heads are paying attention to in each token position of the prediction calculation.

The description below refers to a CoLab set up with 3 attention heads performing 5 digit addition without an answer sign. In that set up the attention pattern is 18 by 18 squares (as 54321+77779=132100 is 18 tokens); with the Part 1 configuration above (6 digits, signed answers) it is 22 by 22. Time proceeds vertically downwards, with one additional token being revealed horizontally at each token position, giving the overall triangle shape. This visualisation provided insights. After the question is fully revealed (at token position 11 in the 5 digit set up), each head starts attending to pairs of question digits from left to right (i.e. high-value digits before lower-value digits), giving the "double staircase" shape. The three heads attend to a given digit pair at three different token positions, giving a time ordering of heads.
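
The overall triangle shape comes from causal masking: each token position can attend only to itself and to earlier positions. A minimal sketch of that mask in plain Python, using the 18 token example and a hypothetical helper `causal_mask`:

```python
def causal_mask(n_ctx):
    """mask[row][col] is True when position `row` may attend to position `col`."""
    return [[col <= row for col in range(n_ctx)] for row in range(n_ctx)]

n_ctx = 18  # 5 digit addition without a sign token, e.g. 54321+77779=132100
mask = causal_mask(n_ctx)
visible = sum(sum(row) for row in mask)
print(visible, "of", n_ctx * n_ctx, "squares can be non-zero")

# Print the first few rows: one more token is revealed at each position.
for row in mask[:4]:
    print("".join("#" if cell else "." for cell in row))
```

Only the lower triangle can hold attention weight, so any structure such as the staircase appears inside that triangle.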

In [None]:
def show_token_attention_patterns(index, layer, token_at_index, use_case):

  the_tokens = [str(token) for token in token_at_index.tolist()]
  if layer == 0:
    tokens_str = tokens_to_string(token_at_index)
    print("Attention patterns for", tokens_str)

  attention_pattern=sample_cache["pattern", layer, "attn"][index]
  display(cv.attention.attention_patterns(
      tokens=the_tokens,
      attention=attention_pattern,
      #attention_head_names=[f"L{layer}H{i}" for i in range(cfg.n_heads)],
  ))


sample_size = 3

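# Show attention patterns for the first few sample questions.
# sample_cache was built from varied_questions, so the tokens must come from the same batch.
for i in range(sample_size):
  for layer in range(cfg.n_layers):
    show_token_attention_patterns(i, layer, varied_questions[i], "Misc")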


In [None]:
if cfg.save_graph_to_file:

  tokens_str = []
  for i in range(cfg.n_heads):
    one_token_str = []
    for j in tokens[i]:
      one_token_str.append(str(utils.to_numpy(j)))
    tokens_str.append(one_token_str)

  # Refer https://github.com/callummcdougall/CircuitsVis/blob/main/python/circuitsvis/circuitsvis_demo.ipynb

  # html_object = cv.attention.from_cache(
  #    cache = sample_cache,
  #    tokens = tokens_str, # list of list of strings
  #    return_mode = "html",
  #)

  # Create a CoLab file containing the attention pattern(s) in HTML
  #filename = "AttentionPattern" + str(cfg.n_digits) + "Digits" + str(cfg.n_heads) + "Heads.html"
  #with open(filename, "w") as f:
  #    f.write(html_object.data)

  # Manually download the CoLab "html" file and open in your local browser.
  # Install and use the Edge extension "FireShot" to save a portion of the HTML page as a PDF