# Addition Interpretability Analysis
This workbook includes content from:

* [A Mechanistic Interpretability Analysis of Grokking (Stable)](https://colab.research.google.com/drive/1F6_1_cWXE5M7WocUcpQWp3v8z4b1jL20#scrollTo=XdjjunLxwi_b) - the "Experiment: Task 2: 5 Digit Addition" overview section.

* [Non_Modular_Addition_Grokking_Tasks.ipynb](https://github.com/mechanistic-interpretability-grokking/progress-measures-paper/blob/main/Non_Modular_Addition_Grokking_Tasks%20(1).ipynb) - the sections to run the model.

This Colab's purpose is explained in a paper, [Addition Interpretability Cohort3 - OverLeaf](https://www.overleaf.com/project/64c75f5e7211fe5cb86623d2). In short, it builds on Neel Nanda's work to elucidate the internal algorithm of the 5-digit addition neural network model.

# A Mechanistic Interpretability Analysis of Grokking: Setup
A collection of helper functions and setup code; no need to read.

In [None]:
!nvidia-smi
!pip install einops

Sun Aug 27 21:25:25 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   60C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

from google.colab import drive
from pathlib import Path
import pickle
import os

import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc

import itertools

## Defining Transformer

In [None]:
# A helper class to get access to intermediate activations (inspired by Garcon)
# It's a dummy module that is the identity function by default
# I can wrap any intermediate activation in a HookPoint and get a convenient
# way to add PyTorch hooks
class HookPoint(nn.Module):
    def __init__(self):
        super().__init__()
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        # Called by the model at initialisation
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        # Hook format is fn(activation, hook_name)
        # Change it into PyTorch hook format (this includes input and output,
        # which are the same for a HookPoint)
        def full_hook(module, module_input, module_output):
            return hook(module_output, name=self.name)
        if dir=='fwd':
            handle = self.register_forward_hook(full_hook)
            self.fwd_hooks.append(handle)
        elif dir=='bwd':
            handle = self.register_backward_hook(full_hook)
            self.bwd_hooks.append(handle)
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir='fwd'):
        if (dir=='fwd') or (dir=='both'):
            for hook in self.fwd_hooks:
                hook.remove()
            self.fwd_hooks = []
        if (dir=='bwd') or (dir=='both'):
            for hook in self.bwd_hooks:
                hook.remove()
            self.bwd_hooks = []
        if dir not in ['fwd', 'bwd', 'both']:
            raise ValueError(f"Invalid direction {dir}")

    def forward(self, x):
        return x
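The HookPoint idea can be illustrated without PyTorch. The sketch below uses a hypothetical pure-Python stand-in (`ToyHookPoint` and the `cache` dict are illustrative, not part of the notebook): an identity function that passes its output, tagged with its name, to every registered hook. This is roughly how `cache_all` collects named activations.

```python
from typing import Any, Callable, Dict, List


class ToyHookPoint:
    """Hypothetical stand-in for HookPoint: identity, but calls hooks on its output."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.fwd_hooks: List[Callable[..., Any]] = []

    def add_hook(self, hook: Callable[..., Any]) -> None:
        # Same hook format as HookPoint: fn(activation, name=...)
        self.fwd_hooks.append(hook)

    def remove_hooks(self) -> None:
        self.fwd_hooks = []

    def __call__(self, x: Any) -> Any:
        for hook in self.fwd_hooks:
            hook(x, name=self.name)
        return x  # identity: the activation passes through unchanged


# Usage: cache every activation that flows through a named hook point
cache: Dict[str, Any] = {}


def save_hook(activation: Any, name: str) -> None:
    cache[name] = activation


hook_resid_pre = ToyHookPoint("blocks.0.hook_resid_pre")
hook_resid_pre.add_hook(save_hook)
out = hook_resid_pre([1.0, 2.0, 3.0])
```

Because the hook point is the identity, wrapping an activation in it never changes the forward pass; it only gives hooks a named place to observe it.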

In [None]:
# Define network architecture
# I defined my own transformer from scratch so I'd fully understand each component
# - I expect this wasn't necessary or particularly important, and a bunch of this
# replicates existing PyTorch functionality

# Embed & Unembed
class Embed(nn.Module):
    def __init__(self, d_vocab, d_model):
        super().__init__()
        self.W_E = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_model))

    def forward(self, x):
        return torch.einsum('dbp -> bpd', self.W_E[:, x])

class Unembed(nn.Module):
    def __init__(self, d_vocab, d_model):
        super().__init__()
        self.W_U = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_vocab))

    def forward(self, x):
        return (x @ self.W_U)

# Positional Embeddings
class PosEmbed(nn.Module):
    def __init__(self, max_ctx, d_model):
        super().__init__()
        self.W_pos = nn.Parameter(torch.randn(max_ctx, d_model)/np.sqrt(d_model))

    def forward(self, x):
        return x+self.W_pos[:x.shape[-2]]

# LayerNorm
class LayerNorm(nn.Module):
    def __init__(self, d_model, epsilon = 1e-4, model=[None]):
        super().__init__()
        self.model = model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        if self.model[0].use_ln:
            x = x - x.mean(axis=-1)[..., None]
            x = x / (x.std(axis=-1)[..., None] + self.epsilon)
            x = x * self.w_ln
            x = x + self.b_ln
            return x
        else:
            return x

# Attention
class Attention(nn.Module):
    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        self.W_K = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))
        self.W_Q = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))
        self.W_V = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads)/np.sqrt(d_model))
        self.register_buffer('mask', torch.tril(torch.ones((n_ctx, n_ctx))))
        self.d_head = d_head
        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()

    def forward(self, x):
        k = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))
        q = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))
        v = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))
        attn_scores_pre = torch.einsum('biph,biqh->biqp', k, q)
        attn_scores_masked = torch.tril(attn_scores_pre) - 1e10 * (1 - self.mask[:x.shape[-2], :x.shape[-2]])
        attn_matrix = self.hook_attn(F.softmax(self.hook_attn_pre(attn_scores_masked/np.sqrt(self.d_head)), dim=-1))
        z = self.hook_z(torch.einsum('biph,biqp->biqh', v, attn_matrix))
        z_flat = einops.rearrange(z, 'b i q h -> b q (i h)')
        out = torch.einsum('df,bqf->bqd', self.W_O, z_flat)
        return out

# MLP Layers
class MLP(nn.Module):
    def __init__(self, d_model, d_mlp, act_type, model):
        super().__init__()
        self.model = model
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model)/np.sqrt(d_model))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp)/np.sqrt(d_model))
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        # self.ln = LayerNorm(d_mlp, model=self.model)
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()
        assert act_type in ['ReLU', 'GeLU']

    def forward(self, x):
        x = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in)
        if self.act_type=='ReLU':
            x = F.relu(x)
        elif self.act_type=='GeLU':
            x = F.gelu(x)
        x = self.hook_post(x)
        x = torch.einsum('dm,bpm->bpd', self.W_out, x) + self.b_out
        return x

# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        # self.ln1 = LayerNorm(d_model, model=self.model)
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        # self.ln2 = LayerNorm(d_model, model=self.model)
        self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, x):
        x = self.hook_resid_mid(x + self.hook_attn_out(self.attn((self.hook_resid_pre(x)))))
        x = self.hook_resid_post(x + self.hook_mlp_out(self.mlp((x))))
        return x

# Full transformer
class Transformer(nn.Module):
    def __init__(self, num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_cache=False, use_ln=True):
        super().__init__()
        self.cache = {}
        self.use_cache = use_cache

        self.embed = Embed(d_vocab, d_model)
        self.pos_embed = PosEmbed(n_ctx, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self]) for i in range(num_layers)])
        # self.ln = LayerNorm(d_model, model=[self])
        self.unembed = Unembed(d_vocab, d_model)
        self.use_ln = use_ln

        for name, module in self.named_modules():
            if type(module)==HookPoint:
                module.give_name(name)

    def forward(self, x):
        x = self.embed(x)
        x = self.pos_embed(x)
        for block in self.blocks:
            x = block(x)
        # x = self.ln(x)
        x = self.unembed(x)
        return x

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        return [module for name, module in self.named_modules() if 'hook' in name]

    def remove_all_hooks(self):
        for hp in self.hook_points():
            hp.remove_hooks('fwd')
            hp.remove_hooks('bwd')

    def cache_all(self, cache, incl_bwd=False):
        # Caches all activations wrapped in a HookPoint
        def save_hook(tensor, name):
            cache[name] = tensor.detach()
        def save_hook_back(tensor, name):
            cache[name+'_grad'] = tensor[0].detach()
        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')
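The causal mask in `Attention.forward` can be checked with a small hypothetical pure-Python re-implementation (`causal_attention_pattern` is an illustrative name, and nested lists stand in for tensors). Each query position q gets zero attention on key positions p > q, and each row of the pattern sums to 1.

```python
import math
from typing import List


def causal_attention_pattern(scores: List[List[float]], d_head: int) -> List[List[float]]:
    """Mask future keys, scale by sqrt(d_head) and softmax each query row."""
    n = len(scores)
    pattern = []
    for q in range(n):
        # Keys p > q lie in the future: replace their score with -1e10, as the notebook does
        row = [
            (scores[q][p] if p <= q else -1e10) / math.sqrt(d_head)
            for p in range(n)
        ]
        m = max(row)  # subtract the row max for a numerically stable softmax
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        pattern.append([e / total for e in exps])
    return pattern


pattern = causal_attention_pattern(
    [[1.0, 5.0, 5.0], [2.0, 2.0, 9.0], [0.0, 0.0, 0.0]], d_head=4
)
```

Even though the first row has large scores for later keys, those keys receive exactly zero weight after masking.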

## Helper Functions

In [None]:
def cross_entropy_high_precision(logits, labels):
    # Shapes: batch x vocab, batch
    # Cast logits to float64 because log_softmax has a float32 underflow on overly
    # confident data and can only return multiples of 1.2e-7 (the smallest float x
    # such that 1+x is different from 1 in float32). This leads to loss spikes
    # and dodgy gradients
    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)
    loss = -torch.mean(prediction_logprobs)
    return loss

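def test_logits(logits, bias_correction=False, original_logits=None, mode='all'):
    # Inherited from the modular-addition notebook: relies on the globals p, labels,
    # is_train and is_test, which the cells above do not define.
    # Calculates cross entropy loss of logits representing a batch of all p^2
    # possible inputs
    # Batch dimension is assumed to be first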
    if logits.shape[1]==p*p:
        logits = logits.T
    if logits.shape==torch.Size([p*p, p+1]):
        logits = logits[:, :-1]
    logits = logits.reshape(p*p, p)
    if bias_correction:
        # Applies bias correction - we correct for any missing bias terms,
        # independent of the input, by centering the new logits along the batch
        # dimension, and then adding the average original logits across all inputs
        logits = einops.reduce(original_logits - logits, 'batch ... -> ...', 'mean') + logits
    if mode=='train':
        return cross_entropy_high_precision(logits[is_train], labels[is_train])
    elif mode=='test':
        return cross_entropy_high_precision(logits[is_test], labels[is_test])
    elif mode=='all':
        return cross_entropy_high_precision(logits, labels)
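The float32 problem described in `cross_entropy_high_precision` can be reproduced without PyTorch. This hedged sketch emulates float32 rounding with `struct` (a simplification of what `F.log_softmax` does internally; the helper names are hypothetical). Take two logits where the correct one is 21 above the alternative. The true loss is about 7.6e-10, but once the softmax normaliser is rounded to float32 the computed loss collapses to exactly 0.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def confident_loss(logit_gap: float, simulate_float32: bool) -> float:
    """Loss -log_softmax([0, -logit_gap])[0] = log(1 + exp(-logit_gap))."""
    # The softmax normaliser sum(exp(logits - max)) is 1 + exp(-logit_gap)
    normaliser = 1.0 + math.exp(-logit_gap)
    if simulate_float32:
        # e.g. with logit_gap=21, 1 + 7.6e-10 rounds back to exactly 1.0 in float32
        normaliser = to_float32(normaliser)
    return math.log(normaliser)


loss64 = confident_loss(21.0, simulate_float32=False)
loss32 = confident_loss(21.0, simulate_float32=True)
```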

In [None]:
#Plotting functions
# This is mostly a bunch of over-engineered mess to hack Plotly into producing
# the pretty pictures I want, I recommend not reading too closely unless you
# want Plotly hacking practice
def to_numpy(tensor, flat=False):
    if type(tensor)!=torch.Tensor:
        return tensor
    if flat:
        return tensor.flatten().detach().cpu().numpy()
    else:
        return tensor.detach().cpu().numpy()
def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):
    if tensor.shape[0]==p*p:
        tensor = unflatten_first(tensor)
    tensor = torch.squeeze(tensor)
    px.imshow(to_numpy(tensor, flat=False),
              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name},
              **kwargs).show()
# Set default colour scheme
imshow = partial(imshow, color_continuous_scale='Blues')
# Creates good defaults for showing divergent colour scales (ie with both
# positive and negative values, where 0 is white)
imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
# Presets a bunch of defaults to imshow to make it suitable for showing heatmaps
# of activations with x axis being input 1 and y axis being input 2.
inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    if type(y)==torch.Tensor:
        y = to_numpy(y, flat=True)
    if type(x)==torch.Tensor:
        x=to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()
def scatter(x, y, **kwargs):
    px.scatter(x=to_numpy(x, flat=True), y=to_numpy(y, flat=True), **kwargs).show()
def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()


In [None]:
def unflatten_first(tensor):
    if tensor.shape[0]==p*p:
        return einops.rearrange(tensor, '(x y) ... -> x y ...', x=p, y=p)
    else:
        return tensor
def cos(x, y):
    return (x.dot(y))/x.norm()/y.norm()
def mod_div(a, b):
    return (a*pow(b, p-2, p))%p
def normalize(tensor, axis=0):
    return tensor/(tensor).pow(2).sum(keepdim=True, axis=axis).sqrt()
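`mod_div` (a leftover from the modular-addition notebook) relies on Fermat's little theorem: for a prime p and b not divisible by p, b^(p-2) is the inverse of b modulo p. Below is a minimal sketch with the modulus passed explicitly instead of read from the global `p`. The prime 113 is only an illustrative assumption.

```python
def mod_div(a: int, b: int, prime: int) -> int:
    # Fermat's little theorem: b**(prime-2) % prime is the inverse of b modulo prime
    return (a * pow(b, prime - 2, prime)) % prime


prime = 113  # illustrative prime (assumption; the notebook reads a global p instead)
q = mod_div(5, 7, prime)  # q satisfies (q * 7) % prime == 5
```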


# Non_Modular_Addition_Grokking_Tasks - Configuration
Parameter configuration that controls training/test data and model training.

In [None]:
is_finite = False #@param

train_model = True #@param

# Optional. Every 100 epochs, sanity-checks the first question in the current batch
# and checks that the per-digit loss categories do not overlap
unit_tests = True #@param


In [None]:
#@markdown Model
num_layers = 1
d_vocab = 12
d_vocab_out = 10
d_model = 512 #@param
num_heads = 4
d_head = d_model//num_heads
d_mlp = 4 * d_model
seed = 129000 #@param
#@markdown Data
num_digits = 5 #@param
n_ctx = 3*num_digits + 3
act_type = 'ReLU'
batch_size = 64 #@param
num_data = 750 #@param
#@markdown Optimizer
lr = 1e-4 #@param
weight_decay = 0.1 #@param
num_epochs = 3000 #@param

#@markdown Misc
checkpoint_models = False #@param
checkpoint_every = 50 #@param

PLUS_INDEX = 10
EQUALS_INDEX = 11


# Non_Modular_Addition_Grokking_Tasks - Set Up
Embedding / unembedding helpers, the transformer definition, and the loss functions.

In [None]:
# Embedding / Unembedding

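def tokens_to_string(tokens):
    # Layout: num_digits digits of x, '+', num_digits digits of y, '=', num_digits+1 answer digits
    tokens = to_numpy(tokens)
    x = "".join([str(i) for i in tokens[:num_digits]])
    y = "".join([str(i) for i in tokens[num_digits+1:2*num_digits+1]])
    z = "".join([str(i) for i in tokens[2*num_digits+2:]])
    return f"  {x}\n +{y}\n={z}"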

def string_to_tokens(string, batch=False):
    lookup = {str(i):i for i in range(10)}
    lookup['+']=10
    lookup['=']=11
    tokens = [lookup[i] for i in string if i not in '\n ']
    if batch:
        return torch.tensor(tokens)[None, :]
    else:
        return torch.tensor(tokens)
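Here is a pure-Python sketch of the token layout. The helper names `encode` and `split_tokens` are hypothetical; they mirror `string_to_tokens` and `tokens_to_string` without torch. A 5-digit question occupies n_ctx = 3*num_digits + 3 = 18 positions, with `+` at position num_digits and `=` at position 2*num_digits + 1.

```python
from typing import List, Tuple

PLUS_INDEX = 10
EQUALS_INDEX = 11


def encode(question: str) -> List[int]:
    """Pure-Python mirror of string_to_tokens (returns a list, not a tensor)."""
    lookup = {str(i): i for i in range(10)}
    lookup["+"] = PLUS_INDEX
    lookup["="] = EQUALS_INDEX
    return [lookup[ch] for ch in question if ch not in "\n "]


def split_tokens(tokens: List[int], num_digits: int = 5) -> Tuple[List[int], List[int], List[int]]:
    """Slice tokens into operand x, operand y and the answer, as tokens_to_string does."""
    x = tokens[:num_digits]
    y = tokens[num_digits + 1 : 2 * num_digits + 1]
    z = tokens[2 * num_digits + 2 :]
    return x, y, z


tokens = encode("55003+80002=135005")
x, y, z = split_tokens(tokens)
```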

In [None]:
# Transformer

torch.manual_seed(seed)
model = Transformer(num_layers=num_layers,
                    d_vocab=d_vocab,
                    d_model=d_model,
                    d_mlp=d_mlp,
                    d_head=d_head,
                    num_heads=num_heads,
                    n_ctx=n_ctx,
                    act_type=act_type)
                    # PQ: Commented out: d_vocab_out=d_vocab_out)
model.to('cuda')
optimizer = optim.AdamW(model.parameters(),
                        lr=lr,
                        weight_decay=weight_decay,
                        betas=(0.9, 0.98))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))
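The scheduler is a linear warm-up. `LambdaLR` multiplies the base learning rate by min(step/10, 1), so the rate ramps from 0 to `lr` over the first 10 scheduler steps and then stays constant. Below is a hedged sketch of the resulting schedule; `warmup_lr` is a hypothetical helper. Because `LambdaLR` applies the factor for step 0 when it is constructed, the first update in the training loops is made with a learning rate of 0.

```python
def warmup_lr(step: int, base_lr: float = 1e-4, warmup_steps: int = 10) -> float:
    """Learning rate after `step` scheduler steps, mirroring lambda step: min(step/10, 1)."""
    return base_lr * min(step / warmup_steps, 1)


schedule = [warmup_lr(s) for s in range(12)]
```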

In [None]:
# Loss functions

# Calculate the per-token probability by comparing a batch of predictions "logits" to answers "tokens"
def get_pred_log_probs(logits, tokens):

    # Each position's logits predict the *next* token, so the logits at positions
    # -(num_digits+2) .. -2 predict the num_digits+1 answer tokens
    # (a 5-digit addition can give a 6-digit answer).
    trunc_logits = logits[:, -(num_digits+2):-1]

    # Convert the raw scores (logits) into log-probabilities over the vocabulary.
    # Cast to float64 to avoid float32 underflow on very confident predictions.
    log_probs = F.log_softmax(trunc_logits.to(torch.float64), dim=-1)

    # The last num_digits+1 input tokens are the correct answer digits.
    ans_tokens = tokens[:, -(num_digits+1):]

    # Extract values from the log_probs tensor, based on indices from the ans_tokens tensor
    pred_log_probs = torch.gather(log_probs, -1, ans_tokens[:, :, None])[..., 0]

    return pred_log_probs

# Per-token loss: the negative log-probability of the correct digit, averaged over
# the batch (dim 0), giving one loss value per answer position
def loss_fn(pred_log_probs):
    return -pred_log_probs.mean(0)
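Because the model is autoregressive, the logits at position t are a prediction for the token at position t+1. Below is a small sketch of the index arithmetic in `get_pred_log_probs`; the helper is hypothetical and uses plain Python lists in place of tensors. For num_digits = 5, the logits at positions 11 to 16 (the `=` token and the first five answer digits) score the answer digits at positions 12 to 17.

```python
from typing import List, Tuple


def prediction_alignment(num_digits: int) -> List[Tuple[int, int]]:
    """Pair each logit position used by get_pred_log_probs with the answer token it predicts."""
    n_ctx = 3 * num_digits + 3
    positions = list(range(n_ctx))
    logit_positions = positions[-(num_digits + 2):-1]   # logits[:, -(num_digits+2):-1]
    answer_positions = positions[-(num_digits + 1):]    # tokens[:, -(num_digits+1):]
    return list(zip(logit_positions, answer_positions))


pairs = prediction_alignment(5)
```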

# Paper - Set Up
Data generator (expanded and with unit test), paper sub-task calculations.

In [None]:
# Define "iterator" data generator function. Invoke using next().
# Batch entries are in format XXXXX+YYYYY=ZZZZZZ e.g. 55003+80002=135005
# Modified to provide more characteristics of each batch entry.
def data_generator(batch_size, num_digits, seed):
    torch.manual_seed(seed)
    while True:
        #generate a batch of addition questions (answers calculated below)
        batch = torch.zeros((batch_size, 3*num_digits+3)).to(torch.int64)
        x = torch.randint(0, 10, (batch_size, num_digits))
        y = torch.randint(0, 10, (batch_size, num_digits))
        batch[:, :num_digits] = x
        batch[:, num_digits] = PLUS_INDEX
        batch[:, 1+num_digits:1+num_digits*2] = y
        batch[:, 1+num_digits*2] = EQUALS_INDEX
        #print("Batch", batch)

        # These attributes are used for testing the model training progress
        base_adds = torch.zeros((batch_size,num_digits)).to(torch.int64)
        make_carry1s = torch.zeros((batch_size,num_digits)).to(torch.int64)
        sum9s = torch.zeros((batch_size,num_digits)).to(torch.int64)
        use_carry1s = torch.zeros((batch_size,num_digits)).to(torch.int64)
        use_sum9s = torch.zeros((batch_size,num_digits)).to(torch.int64)

        # generate the addition question answers & other info for testing
        for i in range(num_digits):
            # the column (0 = most significant digit) in the test attributes being updated
            test_col = num_digits-1-i

            base_add = batch[:, num_digits-1-i]+batch[:, 2*num_digits-i]
            base_adds[:, test_col] = base_add

            sum9 = (base_add == 9)
            sum9s[:, test_col] = sum9

            if i>0:
                use_carry1s[:, test_col] = make_carry1s[:, test_col+1]
            use_carry = use_carry1s[:, test_col]

            use_sum9s[:, test_col] = sum9 & use_carry

            digit_sum = base_add + use_carry1s[:, test_col]

            make_carry = (digit_sum >= 10)
            make_carry1s[:, test_col] = make_carry

            batch[:, -1-i] = (digit_sum % 10)

        # Final (possible) carry to highest digit of the sum
        batch[:, -1-num_digits] = make_carry1s[:, 0]

        yield batch.cuda(), base_adds.cuda(), make_carry1s.cuda(), sum9s.cuda(), use_carry1s.cuda(), use_sum9s.cuda()
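To make the per-column attributes concrete, here is a hypothetical pure-Python re-implementation for a single question (`column_features` is illustrative, not part of the notebook). Take 12345 + 07655 = 020000. The units column makes a carry. The carry then cascades through three columns whose digits sum to 9, so those columns are UseSum9 cases.

```python
from typing import Dict, List


def column_features(x: List[int], y: List[int]) -> Dict[str, List[int]]:
    """Per-column attributes as data_generator computes them, for one question.

    x and y are digit lists, most significant digit first (index 0), like test_col.
    """
    n = len(x)
    base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = ([0] * n for _ in range(5))
    answer = [0] * (n + 1)
    for i in range(n):  # i = 0 is the units column
        col = n - 1 - i
        base_adds[col] = x[col] + y[col]
        sum9s[col] = int(base_adds[col] == 9)
        if i > 0:
            # The incoming carry comes from the less significant column to the right
            use_carry1s[col] = make_carry1s[col + 1]
        use_sum9s[col] = sum9s[col] & use_carry1s[col]
        digit_sum = base_adds[col] + use_carry1s[col]
        make_carry1s[col] = int(digit_sum >= 10)
        answer[col + 1] = digit_sum % 10
    answer[0] = make_carry1s[0]  # the final carry becomes the leading answer digit
    return {
        "base_adds": base_adds,
        "make_carry1s": make_carry1s,
        "sum9s": sum9s,
        "use_carry1s": use_carry1s,
        "use_sum9s": use_sum9s,
        "answer": answer,
    }


features = column_features([1, 2, 3, 4, 5], [0, 7, 6, 5, 5])  # 12345 + 07655
```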


In [None]:
# Data generator unit test (optional)
# This unit test checks that the above data_generator function is sensible
def unit_test_data_generator(train_tokens, train_use_carry1s, train_make_carry1s):
  test_token = train_tokens[0]
  test_use_carry = train_use_carry1s[0]
  test_make_carry = train_make_carry1s[0]

  if num_digits == 5:
    digits = test_token.cpu().numpy()
    use = test_use_carry.cpu().numpy()
    force = test_make_carry.cpu().numpy()
    #print(digits)
    #print(use)
    #print(force)
    num1 = digits[0]*10000 + digits[1]*1000 + digits[2]*100 + digits[3]*10 + digits[4]
    num2 = digits[6]*10000 + digits[7]*1000 + digits[8]*100 + digits[9]*10 + digits[10]
    total = digits[12]*100000 + digits[13]*10000 + digits[14]*1000 + digits[15]*100 + digits[16]*10 + digits[17]
    assert num1 + num2 == total, "Unit test failed: Data generator: Bad sum"
    assert (digits[4]+digits[10]+use[4]>=10) == force[4], "Unit test failed: Data generator: Bad carry 0"
    assert (digits[3]+digits[9]+use[3]>=10) == force[3], "Unit test failed: Data generator: Bad carry 1"
    assert (digits[2]+digits[8]+use[2]>=10) == force[2], "Unit test failed: Data generator: Bad carry 2"
    assert (digits[1]+digits[7]+use[1]>=10) == force[1], "Unit test failed: Data generator: Bad carry 3"
    assert (digits[0]+digits[6]+use[0]>=10) == force[0], "Unit test failed: Data generator: Bad carry 4"



In [None]:
# BaseAdd-only loss data and volumes. Array index 0 is the 'Units' digit. Array index 3 is the 'Thousands' digit.
ba_alldigits_loss = []
ba_perdigit_loss = []
ba_perdigit_cases = 0
ba_total_cases = 0


def calculate_baseadd_loss(tokens, prediction_logits, per_token_losses, base_adds, use_carry1s):
  global ba_perdigit_cases
  global ba_total_cases

  # BaseAdd AllDigits
  # Consider each test token in batch where UseCarry1 is false for all columns simultaneously, so BaseAdd can be used on all digits
  answer = 0
  any_use_carry1s = torch.any(use_carry1s.bool(), dim=1)
  no_use_carry1s = ~ any_use_carry1s
  num_cases = to_numpy(torch.sum(no_use_carry1s))
  if num_cases > 0 :
    filtered_loss = per_token_losses[:, -num_digits:] * no_use_carry1s[:, None]
    sum_loss = torch.sum(filtered_loss)
    answer = - to_numpy(sum_loss) / num_cases
    answer = answer / num_digits  # Approximately align the scale of ba_alldigits_loss to ba_perdigit_loss
  ba_alldigits_loss.append(answer)


  # BaseAdd PerDigit
  # Consider each test token in batch and each digit column (e.g. 5) where use_carry is false, so BaseAdd can be used on that digit
  ba_perdigit_cases = 0
  for digit_num in range(num_digits):
    answer = 0
    no_use_carry = 1 - use_carry1s[:, -1-digit_num]
    num_cases = to_numpy(torch.sum(no_use_carry))
    ba_perdigit_cases += num_cases
    ba_total_cases += num_cases
    if num_cases > 0 :
      filtered_loss = per_token_losses[:, -1-digit_num] * no_use_carry
      sum_loss = torch.sum(filtered_loss)
      answer = - to_numpy(sum_loss) / num_cases
    if len(ba_perdigit_loss)<=digit_num:
      ba_perdigit_loss.append([])
    if (num_cases == 0) & (len(ba_perdigit_loss[digit_num]) > 0) :
      answer = ba_perdigit_loss[digit_num][-1] # Use the previous step's loss. Improves graph
    ba_perdigit_loss[digit_num].append(answer)
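Each PerDigit calculation is a masked mean. Given a 0/1 mask m (1 where the case applies) and per-token log-probabilities l, the loss is -sum(l * m) / sum(m). When no entries match, the code reuses the previous step's value so the plotted curve has no gaps. Below is a minimal plain-Python sketch; the function name is hypothetical, and lists stand in for tensors.

```python
from typing import List, Optional


def masked_mean_loss(log_probs: List[float], mask: List[int], previous: Optional[float] = None) -> float:
    """Mean of -log_prob over entries where mask is 1.

    If no entry is selected, fall back to the previous step's value (or 0 if there is none).
    """
    num_cases = sum(mask)
    if num_cases == 0:
        return previous if previous is not None else 0.0
    return -sum(lp * m for lp, m in zip(log_probs, mask)) / num_cases


loss = masked_mean_loss([-0.1, -2.0, -0.3], [1, 0, 1])  # the -2.0 entry is masked out
```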


In [None]:
# UseCarry1-specific loss data and volumes
uc1_anydigits_loss = []
uc1_perdigit_loss = []
uc1_perdigit_cases = 0
uc1_total_cases = 0


def calculate_usecarry1_loss(tokens, prediction_logits, per_token_losses, use_carry1s, sum9s):
  global uc1_perdigit_cases
  global uc1_total_cases

  # UseCarry1 AnyDigits (exclude Sum9)
  # Consider each test token in batch where UseCarry1 is used at least once over the columns & Sum9 is never used
  num_use_carry1s = torch.sum(use_carry1s, dim=1)
  any_use_carry1s = torch.where( num_use_carry1s != 0, 1, 0 ) # At least one digit uses UseCarry1
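  # UseSum9 = Sum9 & UseCarry1 (as in data_generator). Derive it from the arguments instead of
  # reading a global use_sum9s, which only the infinite-data training loop defines
  num_sum9s = torch.sum(sum9s * use_carry1s, dim=1)
  no_sum9s = torch.where( num_sum9s == 0, 1, 0 ) # No digits have UseSum9 true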
  filtered_cases = any_use_carry1s & no_sum9s
  num_cases = to_numpy(torch.sum(filtered_cases))
  filtered_indices = torch.nonzero(filtered_cases).squeeze()
  filtered_token_losses = per_token_losses[filtered_indices]
  answer = - filtered_token_losses.mean()
  uc1_anydigits_loss.append(to_numpy(answer))

  # UseCarry1 PerDigit (exclude Sum9)
  # Consider each test token in batch and each digit column (e.g. 5) where UseCarry1 is used on the columns & Sum9 is not true
  uc1_perdigit_cases = 0
  for digit_num in range(num_digits):
    answer = 0
    use_carry = use_carry1s[:, -1-digit_num]
    no_sum9 = 1 - sum9s[:, -1-digit_num]
    filtered_cases = use_carry & no_sum9
    num_cases = to_numpy(torch.sum(filtered_cases))
    uc1_perdigit_cases += num_cases
    uc1_total_cases += num_cases
    if num_cases > 0 :
      filtered_loss = per_token_losses[:, -1-digit_num] * filtered_cases
      sum_loss = torch.sum(filtered_loss)
      answer = - to_numpy(sum_loss) / num_cases
    if len(uc1_perdigit_loss)<=digit_num:
      uc1_perdigit_loss.append([])
    if (num_cases==0) & (len(uc1_perdigit_loss[digit_num]) > 0) :
      answer = uc1_perdigit_loss[digit_num][-1] # Use the previous step's loss. Improves graph
    uc1_perdigit_loss[digit_num].append(answer)
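The AnyDigits filter selects whole questions rather than columns. Below is a hedged plain-Python sketch of that selection; the helper is hypothetical, and the per-column 0/1 lists are the kind data_generator produces, most significant digit first. UseSum9 is derived as UseCarry1 AND Sum9, as in data_generator. Of the three example questions, only 00019 + 00001 uses a carry without any UseSum9 column.

```python
from typing import List


def usecarry1_anydigits_rows(use_carry1s: List[List[int]], sum9s: List[List[int]]) -> List[int]:
    """Indices of questions that use a carry in at least one column but never hit UseSum9."""
    rows = []
    for r, (carries, nines) in enumerate(zip(use_carry1s, sum9s)):
        use_sum9s = [c & s for c, s in zip(carries, nines)]  # UseSum9 = UseCarry1 & Sum9
        if any(carries) and not any(use_sum9s):
            rows.append(r)
    return rows


use_carry1s = [
    [1, 1, 1, 1, 0],  # 12345 + 07655
    [0, 0, 0, 0, 0],  # 55003 + 80002
    [0, 0, 0, 1, 0],  # 00019 + 00001
]
sum9s = [
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
selected = usecarry1_anydigits_rows(use_carry1s, sum9s)
```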


In [None]:
# UseSum9-specific loss data and volumes
us9_onedigit_loss = []
us9_twodigits_loss = []
us9_perdigit_loss = []
us9_perdigit_cases = 0
us9_total_cases = 0


def calculate_usesum9_loss(tokens, prediction_logits, per_token_losses, use_sum9s):
  global us9_perdigit_cases
  global us9_total_cases

  # UseSum9 OneDigit
  # Consider each test token in batch where UseSum9 is used once over the columns
  num_use_sum9s = torch.sum(use_sum9s, dim=1)
  filtered_num_use_sum9s = torch.where( num_use_sum9s == 1, 1, 0 ) # Exactly OneDigit uses UseSum9
  num_cases = to_numpy(torch.sum(filtered_num_use_sum9s))
  filtered_indices = torch.nonzero(filtered_num_use_sum9s).squeeze()
  filtered_token_losses = per_token_losses[filtered_indices]
  answer = - filtered_token_losses.mean()
  us9_onedigit_loss.append(to_numpy(answer))

  # UseSum9 TwoDigits
  # Consider each test token in batch where UseSum9 is used twice over the columns
  num_use_sum9s = torch.sum(use_sum9s, dim=1)
  filtered_num_use_sum9s = torch.where( num_use_sum9s == 2, 1, 0 ) # Exactly TwoDigits uses UseSum9
  num_cases = to_numpy(torch.sum(filtered_num_use_sum9s))
  filtered_indices = torch.nonzero(filtered_num_use_sum9s).squeeze()
  filtered_token_losses = per_token_losses[filtered_indices]
  answer = - filtered_token_losses.mean()
  us9_twodigits_loss.append(to_numpy(answer))

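  # UseSum9 PerDigit
  # Consider each test token in batch and each digit column where UseSum9 is true (UseCarry1 and Sum9 both true).
  # Uses the use_sum9s argument rather than the globals use_carry1s / sum9s, which only the infinite-data loop defines
  us9_perdigit_cases = 0
  for digit_num in range(num_digits):
    answer = 0
    filtered_cases = use_sum9s[:, -1-digit_num]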
    num_cases = to_numpy(torch.sum(filtered_cases))
    us9_perdigit_cases += num_cases
    us9_total_cases += num_cases
    if num_cases > 0 :
      filtered_loss = per_token_losses[:, -1-digit_num] * filtered_cases
      sum_loss = torch.sum(filtered_loss)
      answer = - to_numpy(sum_loss) / num_cases
    if len(us9_perdigit_loss)<=digit_num:
      us9_perdigit_loss.append([])
    if (num_cases==0) & (len(us9_perdigit_loss[digit_num]) > 0) :
      answer = us9_perdigit_loss[digit_num][-1] # Use the previous step's loss. Improves graph
    us9_perdigit_loss[digit_num].append(answer)
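`calculate_usesum9_loss` buckets questions by the number of UseSum9 columns. Questions with three or more cascading UseSum9 columns, such as 12345 + 07655, fall into neither the OneDigit nor the TwoDigits curve. Below is a small hypothetical plain-Python sketch of the bucketing.

```python
from collections import Counter
from typing import List


def use_sum9_buckets(use_sum9s: List[List[int]]) -> Counter:
    """Count questions by how many of their columns are UseSum9 cases."""
    return Counter(sum(row) for row in use_sum9s)


buckets = use_sum9_buckets([
    [0, 1, 1, 1, 0],  # 12345 + 07655: three UseSum9 columns (in neither curve)
    [0, 0, 0, 0, 0],  # 55003 + 80002
    [0, 0, 0, 1, 0],  # 00045 + 00055: OneDigit
    [0, 0, 1, 1, 0],  # 00445 + 00555: TwoDigits
    [0, 1, 0, 0, 0],  # 04500 + 05500: OneDigit
])
one_digit_cases, two_digit_cases = buckets[1], buckets[2]
```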


In [None]:
# Check that us9_perdigit_loss, uc1_perdigit_loss and ba_perdigit_loss are nonoverlapping
# This is needed so that the graphs of each are independent
def unit_test_nonoverlapping():
  global ba_perdigit_cases
  global ba_total_cases
  global uc1_perdigit_cases
  global uc1_total_cases
  global us9_perdigit_cases
  global us9_total_cases

  perdigit_numcases = us9_perdigit_cases + uc1_perdigit_cases + ba_perdigit_cases
  assert (perdigit_numcases == batch_size * num_digits), "Cases overlap: " + str(perdigit_numcases) + " != " + str(batch_size*num_digits)
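`unit_test_nonoverlapping` relies on the three PerDigit filters forming a partition of the batch_size * num_digits columns:

- BaseAdd takes UseCarry1 = 0.
- UseCarry1 takes UseCarry1 = 1 and Sum9 = 0.
- UseSum9 takes UseCarry1 = 1 and Sum9 = 1.

Below is a hedged plain-Python sketch of that counting, using a hypothetical helper and example data for three 5-digit questions. Each filter is applied independently, as in the notebook, and the counts still add up to the total.

```python
from typing import List, Tuple


def per_digit_case_counts(use_carry1s: List[List[int]], sum9s: List[List[int]]) -> Tuple[int, int, int]:
    """Apply the three PerDigit filters independently and count matching columns."""
    cells = [(c, s) for carries, nines in zip(use_carry1s, sum9s) for c, s in zip(carries, nines)]
    ba = sum(1 - c for c, s in cells)          # BaseAdd PerDigit: no incoming carry
    uc1 = sum(c & (1 - s) for c, s in cells)   # UseCarry1 PerDigit: carry, column sum is not 9
    us9 = sum(c & s for c, s in cells)         # UseSum9 PerDigit: carry into a column summing to 9
    return ba, uc1, us9


counts = per_digit_case_counts(
    [[1, 1, 1, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0]],  # UseCarry1 per column
    [[0, 1, 1, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],  # Sum9 per column
)
```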


# Non_Modular_Addition_Grokking_Tasks - Train model - Finite Data (Alternative 1)
Train model for num_epochs, storing train_losses & test_losses per epoch. Param is_finite must be True.

Each epoch (num_epochs, 3000 by default) the model is trained on the **same** training data: a single full batch of num_data (750) questions. A new batch of "testing" data is generated each epoch and used to calculate the test loss. Memorisation of the (constant) training data by the model is beneficial and so likely.

In [None]:
if is_finite and train_model:

    # Initialise the data generator. Paper delta
    test_ds = data_generator(batch_size, num_digits, seed)
    train_ds = data_generator(num_data, num_digits, seed)
    train_tokens, train_base_adds, train_make_carry1s, train_sum9s, train_use_carry1s, train_use_sum9s = next(train_ds)

    train_losses_list = []
    per_token_train_losses_list = []
    test_losses_list = []
    per_token_test_losses_list = []
    # sds=[]
    # epochs = [0]
    # sds.append(model.state_dict())

    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model(train_tokens)
        per_token_train_losses_raw = get_pred_log_probs(train_logits, train_tokens)
        per_token_train_losses = loss_fn(per_token_train_losses_raw)
        per_token_train_losses_list.append(to_numpy(per_token_train_losses))

        train_loss = per_token_train_losses.mean()
        train_loss.backward()
        train_losses_list.append(train_loss.item())

        test_tokens, test_base_adds, test_make_carry1s, test_sum9s, test_use_carry1s, test_use_sum9s = next(test_ds) # Paper delta

        # Evaluate on the test batch without tracking gradients: the test loss is only measured, never backpropagated
        with torch.no_grad():
            test_logits = model(test_tokens)
            per_token_test_losses_raw = get_pred_log_probs(test_logits, test_tokens)
            per_token_test_losses = loss_fn(per_token_test_losses_raw)
        per_token_test_losses_list.append(to_numpy(per_token_test_losses))

        test_loss = per_token_test_losses.mean()
        test_losses_list.append(test_loss.item())

        calculate_baseadd_loss(test_tokens, test_logits, per_token_test_losses_raw, test_base_adds, test_use_carry1s) # Paper delta
        calculate_usecarry1_loss(test_tokens, test_logits, per_token_test_losses_raw, test_use_carry1s, test_sum9s) # Paper delta
        calculate_usesum9_loss(test_tokens, test_logits, per_token_test_losses_raw, test_use_sum9s) # Paper delta

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if epoch % 100 == 0:
          print(epoch, train_loss.item(), test_loss.item())
          if unit_tests:
            unit_test_data_generator(test_tokens, test_use_carry1s, test_make_carry1s) # Paper delta
            unit_test_nonoverlapping() # Paper delta

# Non_Modular_Addition_Grokking_Tasks - Train model - Infinite Data (Alternative 2)
Train model for num_epochs, storing train_losses & test_losses per epoch. Param is_finite must be False.

In each epoch (3000 in total) a new batch of 64 training questions is generated, and the model is trained on it and its loss calculated. No separate "testing" data is needed, because the training data is new at every step, so memorising past training data (if the model does so at all) brings minimal benefit.
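A minimal sketch of drawing one fresh batch per step, assuming the token layout described in the attention-pattern section (digits are tokens 0-9, 10 is "+", 11 is "=", and the 6-digit answer is written most significant digit first). `make_addition_batch` and `digits_msb_first` are hypothetical helpers for illustration, not the notebook's `data_generator`, which also returns the per-digit task flags.

```python
import numpy as np

PLUS, EQUALS = 10, 11  # assumed token ids for "+" and "="

def digits_msb_first(values, width):
    """Split each integer into `width` digits, most significant digit first."""
    powers = 10 ** np.arange(width - 1, -1, -1)
    return (values[:, None] // powers) % 10

def make_addition_batch(batch_size=64, num_digits=5, rng=None):
    """Return a [batch_size, 3*num_digits + 3] array of question + answer tokens."""
    rng = np.random.default_rng() if rng is None else rng
    high = 10 ** num_digits
    a = rng.integers(0, high, size=batch_size)
    b = rng.integers(0, high, size=batch_size)
    total = a + b  # needs num_digits + 1 digits; the leading digit may be 0
    return np.concatenate([
        digits_msb_first(a, num_digits),
        np.full((batch_size, 1), PLUS),
        digits_msb_first(b, num_digits),
        np.full((batch_size, 1), EQUALS),
        digits_msb_first(total, num_digits + 1),
    ], axis=1)
```

With 10^10 possible question pairs and 64 x 3000 = 192,000 questions drawn over the whole run, repeated questions are rare, which is why this mode can do without a held-out test set.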

In [None]:
if not is_finite and train_model:

    # Initialise the data generator
    ds = data_generator(batch_size, num_digits, seed)

    train_losses_list = []
    per_token_train_losses_list = []
    sds=[]
    epochs = [0]
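    # Clone the tensors: state_dict() returns references to the live weights, which optimizer.step() keeps updating
    sds.append({k: v.detach().clone() for k, v in model.state_dict().items()})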

    for epoch in tqdm.tqdm(range(num_epochs)):

        tokens, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = next(ds) # Paper delta
        logits = model(tokens)

        per_token_train_losses_raw = get_pred_log_probs(logits, tokens)
        per_token_train_losses = loss_fn(per_token_train_losses_raw)
        per_token_train_losses_list.append(to_numpy(per_token_train_losses))

        train_loss = per_token_train_losses.mean()
        train_loss.backward()
        train_losses_list.append(train_loss.item())

        calculate_baseadd_loss(tokens, logits, per_token_train_losses_raw, base_adds, use_carry1s) # Paper delta
        calculate_usecarry1_loss(tokens, logits, per_token_train_losses_raw, use_carry1s, sum9s) # Paper delta
        calculate_usesum9_loss(tokens, logits, per_token_train_losses_raw, use_sum9s) # Paper delta

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if epoch % 100 == 0:
            print(epoch, train_loss.item())
            if unit_tests:
                unit_test_data_generator(tokens, use_carry1s, make_carry1s) # Paper delta
                unit_test_nonoverlapping() # Paper delta
        if checkpoint_models:
            if (epoch+1) % (checkpoint_every) == 0:
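                # Clone the tensors: state_dict() returns references to the live weights
                sds.append({k: v.detach().clone() for k, v in model.state_dict().items()})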
                epochs.append(epoch+1)



0 2.7430343616761856
100 1.9971816125458297
200 1.9878598365624653
300 1.9211451122371277
400 1.799695116234186
500 1.7733734137662993
600 1.6818455862327346
700 1.640927731484273
800 1.6633163840308485
900 1.5949398357414302
1000 1.4884997076008966
1100 1.3666489853980661
1200 1.3691433715944594
1300 1.3107232249306124
1400 1.299196280155246
1500 1.2031917735316908
1600 1.0476638786300827
1700 0.7068740237688674
1800 0.48319902596597375
1900 0.32962253964640326
2000 0.24967609327256418
2100 0.18736288746112573
2200 0.13135599463987707
2300 0.09809812016126725
2400 0.12397030583380486
2500 0.09909783280253348
2600 0.05823351531285724
2700 0.0658819872624349
2800 0.05454502611187108
2900 0.0663233497679983


# Non_Modular_Addition_Grokking_Tasks - Loss Graphs
Recreates the previously published loss graphs, showing that the code changes made here do not alter the previous outcomes.

In [None]:
def plot_losses( all_losses_list, per_token_losses_list, title_suffix, legend_title):
  line(all_losses_list,
      title=title_suffix)

  lines([per_token_losses_list[:, i] for i in range(1+num_digits)]+[all_losses_list],
        labels = [f'tok {i}' for i in range(1+num_digits)]+[legend_title],
        title='Per-digit '+title_suffix,
        xaxis='Step',
        yaxis='Loss')

  lines([per_token_losses_list[:, i] for i in range(1+num_digits)]+[all_losses_list],
        labels = [f'tok {i}' for i in range(1+num_digits)]+[legend_title],
        title='Per-digit (Log scale) '+title_suffix,
        xaxis='Step',
        yaxis='Log Loss',
        log_y=True)


data_size = '(Finite)' if is_finite else '(Infinite)'
title_suffix1 = ' Loss Curves for 5 digit addition ' + data_size

data_type = '(Training)'
title_suffix = data_type + title_suffix1
per_token_losses = np.stack(per_token_train_losses_list, axis=0)
plot_losses(train_losses_list, per_token_losses, title_suffix, 'train_loss')

if is_finite:

  data_type = '(Testing)'
  title_suffix = data_type + title_suffix1
  per_token_losses = np.stack(per_token_test_losses_list, axis=0)
  plot_losses(test_losses_list, per_token_losses, title_suffix, 'test_loss')


# Paper Results - Visualisation set up

In [None]:
%pip install jaxtyping
%pip install transformer_lens
%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
import circuitsvis as cv
from IPython.display import display

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-ow9s6brc
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-ow9s6brc
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit 5afe6fed827592dd525490b81e213bc3e2241a4a
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
# Graph per-digit series using "normal" and "log" scale
def graph_perdigit(losslist, num_series, title_suffix, showlog):
  lines([losslist[i] for i in range(num_series)],
        labels = [f'tok {i}' for i in range(num_series)],
        title='Per-digit '+title_suffix,
        xaxis='Step',
        yaxis='Loss')
  if showlog:
    lines([losslist[i] for i in range(num_series)],
          labels = [f'tok {i}' for i in range(num_series)],
          title='Per-digit (log) '+title_suffix,
          xaxis='Step',
          yaxis='Log Loss',
          log_y=True)


# Paper Results - Single task (multiple digits) loss graphs
Graph the losses over training epochs of the addition sub-tasks. Each graph contains all digits, but only shows one type of task.
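The percentages in the three task graphs below appear to treat each question as belonging to exactly one category. A minimal sketch of that split, mirroring the filters used in the attention-pattern cell later in this notebook (BaseAdd-only: no digit uses UseCarry1; UseCarry1: at least one digit uses UseCarry1 and none uses UseSum9; UseSum9: at least one digit uses UseSum9). It works on one batch of NumPy flag arrays, whereas the `*_total_cases` counters in the cells below accumulate over training. `categorise_questions` and `category_percentages` are hypothetical names, and the split is only a true partition if every UseSum9 question also has a UseCarry1 digit.

```python
import numpy as np

def categorise_questions(use_carry1s, use_sum9s):
    """Boolean mask per category for a batch.

    use_carry1s, use_sum9s: [batch, num_digits] arrays of 0/1 flags.
    """
    any_carry1 = use_carry1s.astype(bool).any(axis=1)
    any_sum9 = use_sum9s.astype(bool).any(axis=1)
    return {
        "BaseAdd": ~any_carry1,                  # no digit needs a carry
        "UseCarry1": any_carry1 & ~any_sum9,     # carries, but no cascading Sum9
        "UseSum9": any_sum9,                     # at least one cascading carry
    }

def category_percentages(masks):
    """Truncated integer percentage per category, like the (int)(100 * ...) cells below."""
    total = sum(int(m.sum()) for m in masks.values())
    return {name: int(100 * m.sum() / total) for name, m in masks.items()}
```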

## BaseAdd task graphs
Graphs token loss vs step for the cases where only BaseAdd (not UseCarry1 or UseSum9) is needed to get the correct answer.

In [None]:
data_size = '(Finite) ' if is_finite else '(Infinite) '
data_type = '(Testing) ' if is_finite else '(Training) '

perc = int(100 * ba_total_cases / (ba_total_cases + uc1_total_cases + us9_total_cases))
title_suffix = 'BaseAdd Loss ' + data_size + data_type + '(' + str(ba_total_cases) + ' cases, ' + str(perc) + '%)'

# For use cases where use_carry1s is false for all columns simultaneously, so BaseAdd can be used on all digits
line(ba_alldigits_loss, title='AllDigits ' +title_suffix)
#line(ba_alldigits_loss, title='Log AllDigits ' +title_suffix, log_y=True)

# For each digit independently
graph_perdigit(ba_perdigit_loss, num_digits, title_suffix, False)


## UseCarry1 (excluding UseSum9) task graphs
Graphs token loss vs step for the cases where UseCarry1 is used in at least one digit column (and UseSum9 is not used at all).

In [None]:
perc = int(100 * uc1_total_cases / (ba_total_cases + uc1_total_cases + us9_total_cases))
title_suffix = 'UseCarry1 Loss ' + data_size + data_type + '(' + str(uc1_total_cases) + ' cases, ' + str(perc) + '%)'

lines([uc1_anydigits_loss],
      labels = ['at least 1 digit'],
      title=title_suffix,
      xaxis='Step',
      yaxis='Loss')
#lines([uc1_anydigits_loss],
#      labels = ['at least 1 digit'],
#      title='Log '+title_suffix,
#      xaxis='Step',
#      yaxis='Log Loss',
#      log_y=True)

# For each digit independently
graph_perdigit(uc1_perdigit_loss, num_digits, title_suffix, False)

## UseSum9 task graphs
Graphs token loss vs step for the cases where UseSum9 is used in one or two digit columns.

In [None]:
perc = int(100 * us9_total_cases / (ba_total_cases + uc1_total_cases + us9_total_cases))
title_suffix = 'UseSum9 Loss ' + data_size + data_type + '(' + str(us9_total_cases) + ' cases, ' + str(perc) + '%)'

lines([us9_onedigit_loss]+[us9_twodigits_loss],
      labels = ['1 digit']+['2 digits'],
      title=title_suffix,
      xaxis='Step',
      yaxis='Loss')
#lines([us9_onedigit_loss]+[us9_twodigits_loss],
#      labels = ['1 digit']+['2 digits'],
#      title='Log '+title_suffix,
#      xaxis='Step',
#      yaxis='Log Loss',
#      log_y=True)

# For each digit independently
graph_perdigit(us9_perdigit_loss, num_digits, title_suffix, False)

# Paper Results - Single digit (multiple tasks) loss graphs
Graph the losses over training epochs of the addition sub-tasks. Each graph contains one digit, but multiple tasks.

## Per digit BaseAdd & UseCarry1 task graphs
For each digit, graph the BaseAdd and UseCarry1 tasks for curve comparison.

In [None]:
for whichdigit in range(num_digits):

  title_suffix = 'Loss for BaseAdd & UseCarry1 Tasks for digit ' + str(whichdigit) + ' ' + data_size + data_type

  lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]],
        labels = ['BaseAdd']+['UseCarry1'],
        title=title_suffix,
        xaxis='Step',
        yaxis='Loss')
  lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]],
        labels = ['BaseAdd']+['UseCarry1'],
        title='Log '+title_suffix,
        xaxis='Step',
        yaxis='Log Loss',
        log_y=True)

## Per digit BaseAdd, UseCarry1 & UseSum9 task graphs
For each digit, graph the BaseAdd, UseCarry1 & UseSum9 tasks for curve comparison.

In [None]:
for whichdigit in range(num_digits):

  title_suffix = 'Loss for BaseAdd, UseCarry1 & UseSum9 Tasks for digit ' + str(whichdigit) + ' ' + data_size + data_type

  lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]]+[us9_perdigit_loss[whichdigit]],
        labels = ['BaseAdd']+['UseCarry1']+['UseSum9'],
        title=title_suffix,
        xaxis='Step',
        yaxis='Loss')
  lines([ba_perdigit_loss[whichdigit]]+[uc1_perdigit_loss[whichdigit]]+[us9_perdigit_loss[whichdigit]],
        labels = ['BaseAdd']+['UseCarry1']+['UseSum9'],
        title='Log '+title_suffix,
        xaxis='Step',
        yaxis='Log Loss',
        log_y=True)

# Paper Results - Attention Patterns
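Suppose a question's tokens are displayed as 7 7 4 2 6 10 1 7 6 1 1 11 0 9 5 0 3 7, where token 10 is "+" and token 11 is "=".
This should be read as 77426 + 17611 = 095037 (the 6-digit answer keeps its leading 0).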
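A hypothetical helper that renders a token sequence in this reading, under the same assumption that 10 and 11 are the "+" and "=" tokens. The notebook's own `tokens_to_string`, used in the hook experiment below, appears to lay the digits out by position instead.

```python
def decode_question(tokens):
    """Render question + answer tokens as 'a + b = answer'.

    Assumes tokens 0-9 are digits, 10 is '+' and 11 is '='.
    """
    symbols = {10: "+", 11: "="}
    text = "".join(symbols.get(int(t), str(int(t))) for t in tokens)
    question, answer = text.split("=")
    a, b = question.split("+")
    return f"{a} + {b} = {answer}"
```

For example, `decode_question([7, 7, 4, 2, 6, 10, 1, 7, 6, 1, 1, 11, 0, 9, 5, 0, 3, 7])` returns `"77426 + 17611 = 095037"`.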



In [33]:
tokens, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s = next(ds) # Paper delta

cache = {}
model.cache_all(cache)
model(tokens)
print("cache.keys", cache.keys())
print("tokens.shape", tokens.shape)

def show_token_attention_patterns(index, token_at_index):
  # Available hooks are 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_v',
  # 'blocks.0.attn.hook_attn_pre', 'blocks.0.attn.hook_attn', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid',
  # 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post'
  attention_pattern=cache['blocks.0.attn.hook_attn'][index]

  # Render this question's attention pattern for each layer 0 head
  token_strs = [str(token) for token in token_at_index.tolist()]
  display(cv.attention.attention_patterns(
      tokens=token_strs,
      attention=attention_pattern,
      attention_head_names=[f"L0H{i}" for i in range(4)],
  ))

sample_size = 3

# Show attention patterns for the first few questions in the batch
print("Attention patterns for first few tokens")
for i in range(sample_size):
  show_token_attention_patterns(i, tokens[i])

# Show attention patterns for some tokens which only use BaseAdd across all digits
any_use_carry1s = torch.any(use_carry1s.bool(), dim=1)
no_use_carry1s = ~ any_use_carry1s
num_cases = to_numpy(torch.sum(no_use_carry1s))
if num_cases >= sample_size :
  print(f"Attention patterns for first few BaseAdd-only tokens ({num_cases} of {tokens.shape[0]})")
  baseadd_tokens = tokens[no_use_carry1s==1]
  for i in range(sample_size):
    show_token_attention_patterns(i, baseadd_tokens[i])

# Show attention patterns for some tokens which UseCarry1 (and not UseSum9)
num_use_carry1s = torch.sum(use_carry1s, dim=1)
any_use_carry1s = torch.where( num_use_carry1s != 0, 1, 0 ) # At least one digit uses UseCarry1
num_sum9s = torch.sum(use_sum9s, dim=1)
no_sum9s = torch.where( num_sum9s == 0, 1, 0 ) # No digit uses UseSum9
filtered_cases = any_use_carry1s & no_sum9s
num_cases = to_numpy(torch.sum(filtered_cases))
if num_cases >= sample_size :
  print(f"Attention patterns for first few UseCarry1-only (and not UseSum9) tokens ({num_cases} of {tokens.shape[0]})")
  usecarry1_tokens = tokens[filtered_cases==1]
  for i in range(sample_size):
    show_token_attention_patterns(i, usecarry1_tokens[i])

# Show attention patterns for some tokens which UseSum9
num_sum9s = torch.sum(use_sum9s, dim=1)
any_sum9s = torch.where( num_sum9s != 0, 1, 0 ) # At least one digit uses UseSum9
num_cases = to_numpy(torch.sum(any_sum9s))
if num_cases >= sample_size :
  print(f"Attention patterns for first few UseSum9 tokens ({num_cases} of {tokens.shape[0]})")
  usesum9_tokens = tokens[any_sum9s==1]
  for i in range(sample_size):
    show_token_attention_patterns(i, usesum9_tokens[i])



cache.keys dict_keys(['blocks.0.hook_resid_pre', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_pre', 'blocks.0.attn.hook_attn', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post'])
tokens.shape torch.Size([64, 18])
Attention patterns for first few tokens


Attention patterns for first few BaseAdd-only tokens (9 of 64)


Attention patterns for first few UseCarry1-only (and not UseSum9) tokens (45 of 64)


Attention patterns for first few UseSum9 tokens (10 of 64)


# Paper Results - Hook Experiments

In [35]:
# To add a hook that modifies rather than caches you do something like this
def test_hook(value, name):
  # Uncomment the next line to zero-ablate the layer 0 attention pattern (as written, this hook is a no-op)
  #value = torch.zeros_like(value)
  return value

hook_attn_0 = model.blocks[0].attn.hook_attn

cache = {}
model.cache_all(cache)

hook_attn_0.remove_hooks()
logits_without_hook = model(tokens)
hook_attn_0.add_hook(test_hook)
logits_with_hook = model(tokens)
hook_attn_0.remove_hooks()

#assert not torch.equal(logits_without_hook, logits_with_hook)  # only holds once test_hook actually ablates

print("logits_without_hook.shape", logits_without_hook.shape)
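predicted_tokens = torch.argmax(logits_without_hook, dim=-1)
print(predicted_tokens.shape)
# The logits at position i predict token i+1, so this rendering is offset by one position:
# only the predictions at positions 11-16 (which predict the 6 answer digits) are meaningful
print(predicted_tokens[0])
print(tokens_to_string(predicted_tokens[0]))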
#print("tokens_to_string", tokens_to_string(logits_without_hook[:,1]))
#print("logits_without_hook", logits_without_hook)
#print("logits_with_hook.shape", logits_with_hook.shape)
#print("logits_with_hook", logits_with_hook)





logits_without_hook.shape torch.Size([64, 18, 12])
torch.Size([64, 18])
tensor([0, 2, 9, 8, 0, 8, 8, 8, 4, 7, 6, 0, 9, 6, 4, 2, 4, 1], device='cuda:0')
  02980
 +88476
=964241
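Because the logits at position i predict token i+1, the answer digits (positions 12-17) are predicted at positions 11-16, so the printout above is offset by one; read that way, the predicted answer for the first question is 0 9 6 4 2 4. A minimal NumPy sketch of scoring predictions with that shift applied, assuming 18-token sequences ending in a 6-token answer. `answer_digit_accuracy` is a hypothetical helper; torch tensors would first go through `to_numpy`.

```python
import numpy as np

def answer_digit_accuracy(logits, tokens, num_digits=5):
    """Per-answer-digit accuracy of next-token predictions.

    logits: [batch, seq_len, vocab]; tokens: [batch, seq_len].
    The num_digits + 1 answer tokens sit at the end of each sequence and are
    predicted by the positions one step before them.
    """
    n_answer = num_digits + 1
    preds = np.argmax(logits, axis=-1)[:, -n_answer - 1:-1]  # shift by one position
    targets = tokens[:, -n_answer:]
    return (preds == targets).mean(axis=0)  # one accuracy per answer digit
```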


## Logit Lens
From https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb
That demo studies GPT-2 Small on an indirect-object-identification task, and observes: "Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually decreases from there."

Layer k of a transformer means the kth transformer block, but each block consists of an attention layer (to move information around) and an MLP layer (to process information). The model in this notebook has a single block (blocks.0).

In [36]:
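# Note: as written this raises AttributeError. `cache` here is a plain dict filled by model.cache_all(),
# not a TransformerLens ActivationCache, so it has no accumulated_resid() method.
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")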


AttributeError: ignored
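Since `cache` is a plain dict, a minimal logit lens can instead be computed directly from its residual stream entries. This sketch unembeds each stage and reports the mean log-probability of the correct answer tokens, assuming no final LayerNorm or unembedding bias and the 18-token layout above. `logit_lens_correct_logprob` is a hypothetical helper.

```python
import numpy as np

def logit_lens_correct_logprob(resid_by_stage, W_U, tokens, num_digits=5):
    """Mean log-prob of the correct answer tokens when each residual stage is unembedded directly.

    resid_by_stage: dict of stage name -> [batch, seq_len, d_model] residual stream.
    W_U: [d_model, vocab] unembedding matrix (final LayerNorm / bias ignored).
    """
    n_answer = num_digits + 1
    targets = tokens[:, -n_answer:]                             # [batch, n_answer]
    results = {}
    for name, resid in resid_by_stage.items():
        logits = resid[:, -n_answer - 1:-1] @ W_U               # positions that predict the answer
        logits = logits - logits.max(axis=-1, keepdims=True)    # numerically stable log-softmax
        log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
        correct = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
        results[name] = float(correct.mean())
    return results
```

In this notebook the stages would be `{k: to_numpy(cache[k]) for k in ['blocks.0.hook_resid_pre', 'blocks.0.hook_resid_mid', 'blocks.0.hook_resid_post']}` (embeddings, after attention, after the MLP), and `W_U` would come from the model's unembedding weights, for example `to_numpy(model.unembed.W_U)`, which is an assumed attribute name.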

## Head Attribution
From https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb
We can further break down the output of each attention layer into the sum of the outputs of each attention head.

In that demo (GPT-2 Small): "Each attention layer consists of 12 heads, which each act independently and additively. We see that only a few heads really matter - heads L9H6 and L9H9 contribute a lot positively (explaining why attention layer 9 is so important), while heads L10H7 and L11H10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful)." The model here has 4 heads in its only layer.

In [None]:
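# Like the logit lens cell, this assumes a TransformerLens ActivationCache; the plain dict `cache` used here has no stack_head_results()
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)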
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")
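A minimal per-head direct logit attribution that works from raw arrays instead: each head's slice of the output projection is applied to its `z` values and unembedded, and the correct answer-token logit is averaged. The shapes are assumptions: `z` laid out as [batch, n_heads, seq_len, d_head] (as from `blocks.0.attn.hook_z`) and `W_O` as [d_model, n_heads * d_head] with head-major columns; check both against the model definition. LayerNorm and biases are ignored, and `per_head_direct_logit_attribution` is a hypothetical helper.

```python
import numpy as np

def per_head_direct_logit_attribution(z, W_O, W_U, tokens, num_digits=5):
    """Mean correct-answer-token logit contributed directly by each attention head.

    z:   [batch, n_heads, seq_len, d_head]
    W_O: [d_model, n_heads * d_head]  (head-major column order assumed)
    W_U: [d_model, vocab]
    Returns an array of shape [n_heads].
    """
    n_heads, d_head = z.shape[1], z.shape[3]
    n_answer = num_digits + 1
    targets = tokens[:, -n_answer:]
    contributions = []
    for h in range(n_heads):
        W_O_h = W_O[:, h * d_head:(h + 1) * d_head]          # this head's [d_model, d_head] slice
        head_out = z[:, h, -n_answer - 1:-1] @ W_O_h.T       # [batch, n_answer, d_model]
        logits = head_out @ W_U                              # [batch, n_answer, vocab]
        correct = np.take_along_axis(logits, targets[..., None], axis=-1)[..., 0]
        contributions.append(correct.mean())
    return np.array(contributions)
```

Because the attribution is linear, the per-head values sum to the direct effect of the whole attention output, which is a useful sanity check.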

Lucia says:
Other things you can try to get started with hooks:
- Disable a component and see how much the loss increases. The naive way to do this is to hook into its activations and set them to 0. A better way to do it that doesn't throw the downstream components out of distribution is to set the activations to their mean value over many tokens.
- Save a component's activations (the ones of size d_model that write into the residual stream). Project them onto the unembed to see what tokens the component is directly boosting if any.
- Note that if the component is only indirectly useful because it's read by a downstream component, it may write to a direction unaligned with any unembed direction.
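A minimal sketch of the mean-ablation idea, using the same `(value, name)` hook signature as `test_hook` above. It is written with NumPy so it runs standalone; with the model's torch activations, `mean.expand_as(value).clone()` would replace `np.broadcast_to(...).copy()`. The mean is taken over the batch axis only (an assumption), so each position keeps its own average. `make_mean_ablation_hook` is a hypothetical helper.

```python
import numpy as np

def make_mean_ablation_hook(reference_activations):
    """Build a hook (value, name) -> value that replaces activations with their batch mean.

    reference_activations: [batch, ...] activations collected over many questions,
    e.g. from a cache entry.
    """
    mean = reference_activations.mean(axis=0, keepdims=True)

    def mean_ablation_hook(value, name):
        # Ignore the incoming values and return the reference mean in the same shape
        return np.broadcast_to(mean, value.shape).copy()

    return mean_ablation_hook
```

For an attention pattern hook such as `hook_attn`, the batch mean of the patterns is still a valid attention distribution (each row sums to 1), which keeps downstream components closer to their normal inputs than zeroing does.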

I would see if I can find components responsible for each part. Then, if I found them all, I would look at model checkpoints before and after each phase change and see whether I can show that those components are absent in the pre-phase-change checkpoint and present afterwards.

This is more shaky speculation: we know that MLPs can implement maps/memorise lots of n-gram completions, so it could be doing something like mapping from every possible tuple of digits to the sum of the two, or it could map a tuple of digits (a, b) -> a % b, which could plausibly be helpful.

I think the model will consistently make use of the positional embeddings to figure out which digits to attend to. Most large models have early components which choose to write different positional information into the residual stream, but with a one-layer (1L) model we don't have to worry about that.

I'm wondering whether, at each prediction digit, it could copy all the input digits into the position/digit vector, then in the MLP layer do five different modulo operations plus a final add operation, using a lot of superposition to compress all the possible values. But then we wouldn't see the characteristic phase change. Maybe that's how it was done at the memorization stage.
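For reference, the column-wise arithmetic that any of these hypothesised mechanisms would have to reproduce can be sketched as below. The flag definitions (base_add = (a + b) % 10, make_carry1 = a + b >= 10, sum9 = a + b == 9) are our assumed reading of the notebook's `base_adds`, `make_carry1s` and `sum9s` tensors, not taken from `data_generator`. In this sketch index 0 is the most significant column; the notebook's tensors may index columns differently. `add_by_columns` is a hypothetical helper.

```python
def add_by_columns(a_digits, b_digits):
    """Add two equal-length digit lists (most significant first) column by column.

    Returns the answer digits (one longer than the inputs) plus per-column flags.
    """
    n = len(a_digits)
    answer, carry_in = [0] * (n + 1), 0
    base_adds, make_carry1s, sum9s = [0] * n, [False] * n, [False] * n
    for i in range(n - 1, -1, -1):  # least significant column first
        s = a_digits[i] + b_digits[i]
        base_adds[i], make_carry1s[i], sum9s[i] = s % 10, s >= 10, s == 9
        answer[i + 1] = (s + carry_in) % 10
        carry_in = (s + carry_in) // 10  # a sum9 column passes an incoming carry on
    answer[0] = carry_in
    return answer, base_adds, make_carry1s, sum9s
```

For 77426 + 17611 this gives the answer digits 0 9 5 0 3 7, and for 00995 + 00005 the carry cascades through both Sum9 columns to give 0 0 1 0 0 0.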
