# Interpreting a Sorting Model

In this notebook we train a one-layer, attention-only transformer to sort fixed-length sequences. Along the way we show how to use EasyTransformer to initialise and train the model, and then how to interpret the trained model.

## Setup

In [None]:
# Detect whether we are running inside Google Colab. Only the `google.colab` import is
# guarded, so unrelated failures (e.g. a broken plotly install) are no longer silently
# swallowed and misreported as "not in Colab".
try:
  import google.colab
  IN_COLAB = True
except ImportError:
  IN_COLAB = False

import plotly.io as pio

if IN_COLAB:
  print("Running as a Colab notebook")
  pio.renderers.default = "colab"
else:
  print("Running as a Jupyter notebook - intended for development only!")
  from IPython import get_ipython
  ipython = get_ipython()
  # get_ipython() returns None outside an IPython session (e.g. when run as a plain script)
  if ipython is not None:
    # Code to automatically update the EasyTransformer code as it's edited without restarting the kernel.
    # run_line_magic replaces the deprecated ipython.magic(...) call.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
  pio.renderers.default = "vscode"
  

Running as a Jupyter notebook - intended for development only!


In [None]:
import os
import subprocess
import sys

if IN_COLAB:
    # Install with the kernel's own interpreter so the package lands in the environment
    # this notebook imports from. check_call raises CalledProcessError if pip fails,
    # whereas os.system silently discarded the exit status.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/neelnanda-io/Easy-Transformer.git",
    ])

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

# from google.colab import drive
from pathlib import Path
import pickle
import os


import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc
import collections
import copy

# import comet_ml
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets

import ipywidgets as wg

In [None]:
from easy_transformer.utils import gelu_new, to_numpy, get_corner # Helper functions
from easy_transformer.hook_points import HookedRootModule, HookPoint
from easy_transformer.EasyTransformer import EasyTransformer,TransformerBlock, MLP, Attention, LayerNormPre, PosEmbed, Unembed, Embed
from easy_transformer.experiments import ExperimentMetric, AblationConfig, EasyAblation, EasyPatching, PatchingConfig
from easy_transformer.EasyTransformerConfig import EasyTransformerConfig
import easy_transformer


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Training an Algorithmic Model


In [None]:

# Architecture of the sorting model: a single attention-only layer with one head and
# no normalization.
sorting_cfg = EasyTransformerConfig(
    d_model = 32,
    d_head = 32,  # the single head is as wide as d_model
    n_heads = 1,
    n_layers=1,
    n_ctx = 42,  # START + 20 unsorted tokens + STOP + 20 sorted tokens (see data_generator)
    d_vocab=32,  # 30 sortable values plus the START (30) and STOP (31) tokens
    normalization_type=None,
    attn_only=True,
    use_attn_result=True,
    )
# Optimisation hyperparameters
batch_size = 16
num_epochs = 50001  # the training loop draws one fresh batch per "epoch", so this counts optimiser steps
lr = 1e-3
# Seed torch and numpy before the model is constructed
seed = 123
torch.manual_seed(seed)
np.random.seed(seed)
sorting_model = EasyTransformer.from_config(sorting_cfg).to(device)
sorting_optimizer = torch.optim.Adam(sorting_model.parameters(), lr=lr)

In [None]:
def data_generator(batch_size, n_ctx, d_vocab):
    """Yield an endless stream of sorting examples, one batch of token ids per `next()`.

    Every row has the layout START x_1 ... x_k STOP sorted(x_1 ... x_k), for example
    START 3 6 2 15 23 0 3 2 STOP 0 2 2 3 3 6 15 23, with k = (n_ctx - 2) // 2.
    Token d_vocab - 2 is START and token d_vocab - 1 is STOP; the values to sort are
    drawn uniformly from 0 .. d_vocab - 3. The model is evaluated on the log loss of
    predicting the next token in the second (sorted) half of the sequence.

    Tensors are moved to the notebook-level `device`.

    Args:
        batch_size: number of sequences per yielded batch.
        n_ctx: context length; determines k as described above.
        d_vocab: vocabulary size, including the two special tokens.

    Yields:
        A long tensor of shape (batch_size, 2 * k + 2).
    """
    half_len = (n_ctx - 2) // 2
    start_id = d_vocab - 2
    stop_id = d_vocab - 1
    # The special-token columns are identical for every batch, so build them once
    start_col = torch.full((batch_size, 1), start_id, dtype=torch.long).to(device)
    stop_col = torch.full((batch_size, 1), stop_id, dtype=torch.long).to(device)
    while True:
        # randint's upper bound is exclusive, so START and STOP are never sampled
        unsorted = torch.randint(0, start_id, (batch_size, half_len)).to(device)
        ordered, _ = torch.sort(unsorted, dim=1)
        yield torch.concat([start_col, unsorted, stop_col, ordered], dim=1)
# Build the generator used for training and keep its first batch around as a fixed
# example that later cells inspect.
data_gen = data_generator(
    batch_size=batch_size,
    n_ctx=sorting_model.cfg.n_ctx,
    d_vocab=sorting_model.cfg.d_vocab,
)
example_tokens = next(data_gen)
print("Sanity check our training data:")
print(example_tokens[:2])

Sanity check our training data:
tensor([[30, 22, 19, 12,  0, 10,  2,  6, 17, 29,  4, 11,  1,  6,  1, 22,  9, 24,
         21, 13, 10, 31,  0,  1,  1,  2,  4,  6,  6,  9, 10, 10, 11, 12, 13, 17,
         19, 21, 22, 22, 24, 29],
        [30, 10, 16, 25, 27, 19, 12, 21,  6, 20,  6,  6, 25,  3, 19, 24, 16, 11,
         18, 10,  5, 31,  3,  5,  6,  6,  6, 10, 10, 11, 12, 16, 16, 18, 19, 19,
         20, 21, 24, 25, 25, 27]], device='cuda:0')


In [None]:
def lm_cross_entropy(logits, tokens, return_per_token=False):
    """Next-token cross-entropy loss.

    Args:
        logits: tensor indexed as (batch, position, vocab).
        tokens: long tensor indexed as (batch, position) holding the token ids.
        return_per_token: if True, return the loss at every predicting position
            instead of its mean.

    Returns:
        The mean loss as a scalar tensor, or a (batch, position - 1) tensor of
        per-token losses when `return_per_token` is True. Both are negative
        log-probabilities, so the mean of the per-token output equals the scalar
        output. (Previously the per-token branch returned the un-negated
        log-probabilities, which disagreed in sign with the scalar branch.)
    """
    # We offset by 1 because we're predicting the NEXT token: the logits at position i
    # are scored against the token at position i + 1, and the final logits are unused.
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
    # gather needs an index of the same rank as log_probs, hence the trailing None / squeeze
    per_token_loss = -log_probs.gather(-1, tokens[:, 1:, None]).squeeze(-1)
    if return_per_token:
        return per_token_loss
    return per_token_loss.mean()

def lm_accuracy(logits, tokens, return_per_token=False):
    """Fraction of positions at which the top-scoring logit matches the next token.

    Args:
        logits: tensor indexed as (batch, position, vocab).
        tokens: tensor indexed as (batch, position) holding the token ids.
        return_per_token: if True, return the boolean hit/miss tensor of shape
            (batch, position - 1) rather than the overall fraction.

    Returns:
        A scalar tensor with the fraction of correct predictions, or the
        per-position boolean tensor when `return_per_token` is True.
    """
    # The argmax at position i is the guess for token i + 1, so the last guess and the
    # first token have no counterpart and are dropped before comparing.
    guesses = logits.argmax(dim=-1)[:, :-1]
    targets = tokens[:, 1:]
    hits = guesses == targets
    if not return_per_token:
        return hits.sum()/hits.numel()
    return hits

In [None]:
def sorting_loss(logits, tokens, n_ctx=sorting_model.cfg.n_ctx):
    """Next-token cross-entropy measured only on the sorted second half of each sequence.

    Args:
        logits: model logits, indexed as (batch, position, vocab).
        tokens: the input token ids, indexed as (batch, position).
        n_ctx: context length of the sequences; defaults to the sorting model's n_ctx
            as it was when this function was defined.

    Returns:
        The scalar loss returned by lm_cross_entropy on the sliced inputs.
    """
    # First kept position is the STOP token (START, then (n_ctx-2)//2 unsorted tokens
    # precede it); its logits are the prediction for the first sorted token.
    first_kept = 1 + (n_ctx - 2) // 2
    kept_logits = logits[:, first_kept:, :]
    kept_tokens = tokens[:, first_kept:]
    return lm_cross_entropy(kept_logits, kept_tokens)

def sorting_acc(logits, tokens, n_ctx=sorting_model.cfg.n_ctx):
    """Next-token accuracy measured only on the sorted second half of each sequence.

    Args:
        logits: model logits, indexed as (batch, position, vocab).
        tokens: the input token ids, indexed as (batch, position).
        n_ctx: context length of the sequences; defaults to the sorting model's n_ctx
            as it was when this function was defined.

    Returns:
        The scalar accuracy returned by lm_accuracy on the sliced inputs.
    """
    # Same slicing as sorting_loss: keep everything from the STOP token onwards
    first_kept = 1 + (n_ctx - 2) // 2
    kept_logits = logits[:, first_kept:, :]
    kept_tokens = tokens[:, first_kept:]
    return lm_accuracy(kept_logits, kept_tokens)


In [None]:
# Metric history, sampled every 100 steps
losses, accuracies, epochs = [], [], []

for epoch in tqdm.tqdm(range(num_epochs)):
    # One optimiser step on a freshly generated batch
    tokens = next(data_gen).to(device)
    logits = sorting_model(tokens, return_type='logits')
    loss = sorting_loss(logits, tokens)
    loss.backward()
    sorting_optimizer.step()
    sorting_optimizer.zero_grad()

    if epoch % 100 != 0:
        continue

    # Periodic logging, using the batch that was just trained on
    accuracy = sorting_acc(logits, tokens)
    print(f"Epoch: {epoch}. Loss: {loss:.4f}. Accuracy: {accuracy:.2%}")
    losses.append(loss.detach().cpu().numpy())
    accuracies.append(accuracy.detach().cpu().numpy())
    epochs.append(epoch)

# Training curves
px.line(
    x=epochs,
    y=losses,
    title="Loss",
    labels={'x':'Epoch', 'y':'Loss'},
).show()
px.line(
    x=epochs,
    y=accuracies,
    title="Accuracy",
    labels={'x':'Epoch', 'y':'Accuracy'},
).show()

  0%|          | 0/50001 [00:00<?, ?it/s]

Epoch: 0. Loss: 3.4684. Accuracy: 4.69%
Epoch: 100. Loss: 2.8659. Accuracy: 24.38%
Epoch: 200. Loss: 1.9344. Accuracy: 31.56%
Epoch: 300. Loss: 1.7086. Accuracy: 38.12%
Epoch: 400. Loss: 1.5999. Accuracy: 40.94%
Epoch: 500. Loss: 1.6157. Accuracy: 37.19%
Epoch: 600. Loss: 1.5636. Accuracy: 34.38%
Epoch: 700. Loss: 1.4924. Accuracy: 38.75%
Epoch: 800. Loss: 1.4309. Accuracy: 38.75%
Epoch: 900. Loss: 1.3987. Accuracy: 45.31%
Epoch: 1000. Loss: 1.3399. Accuracy: 41.88%
Epoch: 1100. Loss: 1.3521. Accuracy: 47.81%
Epoch: 1200. Loss: 1.2868. Accuracy: 53.75%
Epoch: 1300. Loss: 1.2665. Accuracy: 45.94%
Epoch: 1400. Loss: 1.2239. Accuracy: 51.56%
Epoch: 1500. Loss: 1.1769. Accuracy: 52.50%
Epoch: 1600. Loss: 1.1026. Accuracy: 58.44%
Epoch: 1700. Loss: 1.1249. Accuracy: 54.37%
Epoch: 1800. Loss: 1.0538. Accuracy: 57.50%
Epoch: 1900. Loss: 0.9979. Accuracy: 55.31%
Epoch: 2000. Loss: 0.9683. Accuracy: 57.81%
Epoch: 2100. Loss: 1.0160. Accuracy: 59.06%
Epoch: 2200. Loss: 0.9005. Accuracy: 64.38%
E

Sample output:

In [None]:
# Dropdown for choosing which of the example sequences to plot
dropdown = wg.Dropdown(
    description='Number:',
    options=list(range(batch_size)),
    value=0,
    disabled=False,
)

seq_len = (sorting_model.cfg.n_ctx - 2)//2
example_logits = sorting_model(example_tokens)


def _predicted_and_actual(row):
    """Return (predicted, actual) next tokens in the sorted half of batch entry `row`, as numpy arrays."""
    # Predictions are read from STOP through the second-to-last position; the matching
    # true tokens are the sorted tokens, one position to the right.
    predicted = example_logits.argmax(-1)[row, seq_len+1:-1].detach().cpu().numpy()
    actual = example_tokens[row, seq_len+2:].detach().cpu().numpy()
    return predicted, actual


index = 0
predicted_token, actual_token = _predicted_and_actual(index)

fig = go.FigureWidget(
    px.line(
        x=np.arange(seq_len),
        y=[predicted_token, actual_token],
        title="Predicted vs Actual Next Token",
        labels={'x':'Token Position', 'wide_variable_0':'Predicted Token Index', 'wide_variable_1':"True Next Token Index", "value":"Token Value"},
    )
)


def update_fig(change):
    """Dropdown observer: redraw both traces for the newly selected batch entry."""
    # Ignore every notification except a change of the dropdown's value
    if change['type'] != 'change' or change['name'] != 'value':
        return
    new_predicted, new_actual = _predicted_and_actual(change['new'])
    # Plot them
    fig.data[0].y = new_predicted
    fig.data[1].y = new_actual


dropdown.observe(update_fig)
display(wg.VBox([dropdown, fig]))

VBox(children=(Dropdown(description='Number:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),…

In [None]:
# Embedding-side weights. Token and positional embeddings are concatenated along dim 1 so
# that a single matrix covers both kinds of input.
# NOTE(review): assumes W_E and W_pos are stored as (d_model, n) — confirm against the library.
W_E = sorting_model.embed.W_E
W_pos = sorting_model.pos_embed.W_pos
W_E_pos = torch.concat((W_E, W_pos), dim=1)
W_U = sorting_model.unembed.W_U

# Attention weights of the only layer; the leading [0] selects the first head
attn_layer = sorting_model.blocks[0].attn
W_K = attn_layer.W_K[0]
W_Q = attn_layer.W_Q[0]
W_V = attn_layer.W_V[0]
W_O = attn_layer.W_O[0]

# End-to-end circuits: (query input) x (key input) scores, and (output token) x (input) effects
QK_circuit = W_E_pos.T @ W_Q.T @ W_K @ W_E_pos
OV_circuit = W_U @ W_O @ W_V @ W_E_pos

# Axis labels in the same order as the concatenation above: vocabulary first, then positions
vocab_labels = [f"tok {i}" for i in range(sorting_cfg.d_vocab-2)]
vocab_labels += ["START", "STOP"]

pos_labels = ["pos START"]
pos_labels += [f"pos in {j}" for j in range(seq_len)]
pos_labels.append("pos STOP")
pos_labels += [f"pos out {j}" for j in range(seq_len)]


In [None]:
# Input-side axis shared by both heatmaps: vocabulary entries followed by positions
combined_labels = vocab_labels + pos_labels

px.imshow(
    to_numpy(QK_circuit),
    labels={'x':'Key', 'y':'Query'},
    x=combined_labels,
    y=combined_labels,
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
).show()
px.imshow(
    to_numpy(OV_circuit),
    labels={'x':'Input', 'y':'Output'},
    x=combined_labels,
    y=vocab_labels,
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
).show()

In [None]:
# Turn on use_attn_result for the model. sorting_cfg already passes use_attn_result=True,
# so this call looks redundant — NOTE(review): confirm before removing.
sorting_model.set_use_attn_result(True)

Setting use_attn_result to True


In [None]:
# Re-run the example batch, this time also returning a cache of intermediate activations.
# This overwrites the example_logits computed in the plotting cell above.
example_logits, example_cache = sorting_model.run_with_hooks(example_tokens, return_type='logits', return_cache=True)


In [None]:

# Show the names of the activations captured in the cache
example_cache.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post'])

In [None]:
# Collect the individual contributions used for direct logit attribution below: token
# embedding, positional embedding, each attention head's result, and the attention
# output bias, all broadcast to (batch, pos, d_model) and stacked along a new leading axis.
attn_bias = einops.repeat(sorting_model.blocks[0].attn.b_O, "d_model -> batch pos d_model", batch=example_tokens.size(0), pos=example_tokens.size(1))
pos_embed = einops.repeat(example_cache['hook_pos_embed'], "pos d_model -> batch pos d_model", batch=example_tokens.size(0))
# hook_result carries the head index on dim 2. Iterate over however many heads exist
# instead of hard-coding heads 0 and 1: sorting_cfg sets n_heads = 1, for which indexing
# head 1 raises an IndexError.
attn_result = example_cache['blocks.0.attn.hook_result']
head_results = [attn_result[:, :, head] for head in range(attn_result.size(2))]
components = torch.stack([example_cache['hook_embed'], pos_embed, *head_results, attn_bias], dim=0)
print(components.shape)
# One label per stacked component, in the same order as the stack above
component_labels = ['embed', 'pos_embed'] + [f'result_{head}' for head in range(len(head_results))] + ['attn_bias']

torch.Size([5, 16, 42, 32])


In [None]:
# The correct answers are the tokens of the sorted half of each sequence
example_token_answers = example_tokens[:, seq_len+2:]
# Look up the unembedding row of every correct answer token
W_U_per_token = sorting_model.unembed.W_U[example_token_answers]
print(W_U_per_token.shape)

# Keep the positions that predict a sorted token: STOP through the second-to-last position
predicting_components = components[:, :, 1+seq_len:-1]
# Dot each component with the correct token's unembedding row -> (batch, component, position)
direct_logit_attr = torch.einsum('cbpm,bpm->bcp', predicting_components, W_U_per_token)
print(direct_logit_attr.shape)

px.imshow(
    to_numpy(direct_logit_attr),
    y=component_labels,
    animation_frame=0,
    color_continuous_midpoint=0.0,
    color_continuous_scale='RdBu',
    title="Direct Logit Attribution",
).show()

torch.Size([16, 20, 32])
torch.Size([16, 5, 20])


In [None]:
def _value_at_next_token(scores):
    """At every position but the last, pick the entry of `scores` that belongs to the token which actually comes next."""
    return scores[:, :-1].gather(-1, example_tokens[:, 1:, None])[..., 0]


# Three views of how strongly the model favours the true next token:
# raw logits, logits centred over the vocabulary, and log-probabilities.
centred_logits = example_logits - example_logits.mean(-1, keepdim=True)
example_log_probs = F.log_softmax(example_logits, dim=-1)

px.imshow(to_numpy(_value_at_next_token(example_logits)), color_continuous_midpoint=0.0, color_continuous_scale='RdBu').show()
px.imshow(to_numpy(_value_at_next_token(centred_logits)), color_continuous_midpoint=0.0, color_continuous_scale='RdBu').show()
# Left as the cell's final expression so the notebook renders it
px.imshow(to_numpy(_value_at_next_token(example_log_probs)), color_continuous_midpoint=0.0, color_continuous_scale='RdBu')

In [None]:
from easy_transformer.utils import lm_cross_entropy_loss


# Evaluate the library's cross-entropy on the sorted half of the example batch.
# NOTE(review): logits and tokens are sliced one position apart here (seq_len+2 vs
# seq_len+3), whereas sorting_loss passes both with the same offset and relies on
# lm_cross_entropy shifting by one internally. If lm_cross_entropy_loss also shifts
# internally (presumably — verify against easy_transformer.utils), each logit here is
# scored against the token two positions ahead rather than the next token.
lm_cross_entropy_loss(example_logits[:, seq_len+2:-1], example_tokens[:, seq_len+3:])

tensor(3.4714, device='cuda:0', grad_fn=<NegBackward0>)