# Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-tci3e3x0
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-tci3e3x0
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 760135a27c4b7873b0cb66aca541958f8939f60b
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.29.2-py3-none-any.whl (297 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.4/297.4 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Plotly graphics only show up in Colab OR VSCode depending on the renderer, which is
# painful when developing demos. The "colab" renderer is used inside Colab and whenever
# DEBUG_MODE is off; a local kernel with DEBUG_MODE on falls back to static "png" output.
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [4]:
#import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [5]:
#Plotting helper functions

import plotly.graph_objects as go

# Keyword arguments routed to fig.update_layout instead of px.imshow.
update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title",
    "colorbar", "colorscale", "coloraxis", "title_x", "title_y", "bargap",
    "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "legend_title_text",
    "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor",
    "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
}


def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a tensor as a Plotly heatmap with a diverging colour scale.

    Args:
        tensor: array-like accepted by utils.to_numpy, or a list of tensors that is
            combined with torch.stack first.
        renderer: optional Plotly renderer name passed to fig.show.
        xaxis, yaxis: axis labels.
        **kwargs: keys in update_layout_set go to fig.update_layout. "facet_labels"
            (a sequence of strings) overwrites the text of fig.layout.annotations
            in order (e.g. facet titles). Everything else is forwarded to px.imshow,
            where color_continuous_scale defaults to "RdBu" and
            color_continuous_midpoint to 0.0. Both defaults can be overridden.
    """
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    # A default rather than a hard-coded argument, so callers may pass their own midpoint
    # without triggering a duplicate-keyword TypeError.
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)
    fig = px.imshow(
        utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        **kwargs_pre,
    ).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]["text"] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` (converted with utils.to_numpy) as a single Plotly line chart.

    Extra keyword arguments are forwarded to px.line; `renderer` goes to fig.show.
    """
    fig = px.line(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly scatter plot of y against x (both converted with utils.to_numpy).

    `caxis` labels the colour dimension; extra kwargs are forwarded to px.scatter.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_labels,
        **kwargs,
    )
    fig.show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title='', log_y=False, hover=None, renderer=None, **kwargs):
    """Plot several lines on one Plotly figure.

    Args:
        lines_list: sequence of 1D tensors/arrays, or a tensor whose first dimension
            indexes the lines.
        x: shared x values; defaults to 0..len(first line)-1.
        mode: go.Scatter mode (e.g. 'lines', 'markers').
        labels: optional per-line legend names; defaults to the line's index.
        xaxis, yaxis, title: axis titles and figure title.
        log_y: if True, use a log-scale y axis.
        hover: hovertext passed to every trace.
        renderer: optional Plotly renderer for fig.show (same as the other helpers).
        **kwargs: forwarded to every go.Scatter trace.

    Raises:
        ValueError: if lines_list is empty and x is not given.
    """
    # isinstance rather than type(...) == so tensor subclasses (e.g. nn.Parameter) work too
    if isinstance(lines_list, torch.Tensor):
        lines_list = list(lines_list)  # iterating a tensor yields its rows along dim 0
    if x is None:
        if len(lines_list) == 0:
            raise ValueError("lines_list must contain at least one line when x is not given")
        x = np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title': title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    # `series` instead of `line` so the module-level line() helper isn't shadowed
    for c, series in enumerate(lines_list):
        if isinstance(series, torch.Tensor):
            series = utils.to_numpy(series)
        label = labels[c] if labels is not None else c
        fig.add_trace(go.Scatter(x=x, y=series, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show(renderer)

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a Plotly bar chart using the "simple_white" template.

    Extra keyword arguments are forwarded to px.bar; `renderer` goes to fig.show.
    """
    values = utils.to_numpy(tensor)
    fig = px.bar(y=values, labels={"x": xaxis, "y": yaxis}, template="simple_white", **kwargs)
    fig.show(renderer)

In [6]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [7]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = "", local_model=None):
    """Render attention patterns of selected heads with PySvelte's AttentionMulti widget.

    Args:
        heads: iterable of heads, each either a (layer, head_index) tuple or a flat
            index equal to layer * n_heads + head_index.
        local_tokens: token batch; only batch element 0 is visualized.
        local_cache: activation cache indexed as local_cache["pattern", layer].
        title: text shown in an <h3> heading above the widget.
        local_model: model providing cfg.n_heads and to_str_tokens; defaults to the
            notebook-global `model`.

    Raises:
        ImportError: if PySvelte is not installed.
    """
    # The top-level `import pysvelte` is commented out, so import it here where it is needed.
    try:
        import pysvelte
    except ImportError as err:
        raise ImportError(
            "visualize_attn_patterns requires PySvelte (the Colab setup cell installs it with %pip)"
        ) from err
    from IPython.display import display

    if local_model is None:
        local_model = model
    n_heads = local_model.cfg.n_heads
    batch_index = 0

    labels = []
    patterns = []
    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // n_heads, head % n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    # Heads are stacked on the last dim; presumably the layout AttentionMulti expects — TODO confirm
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(
        tokens=local_model.to_str_tokens(local_tokens[batch_index]),
        attention=patterns,
        head_labels=labels,
    )
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

In [8]:
# Make sure autograd is on globally before building and training the model.
torch.set_grad_enabled(mode=True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb611dff490>

In [9]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cpu


In [10]:
LIST_LEN = 3
MAX_INT = 50
cfg = HookedTransformerConfig(
    # Small attention-only model: 2 layers, one head each, no LayerNorm
    n_layers=2,
    n_heads=1,
    attn_only=True,
    normalization_type=None,
    d_model=128,
    d_head=128,
    # Sequence layout: BOS d1 d2 d3 MID p1 p2 p3 END a1 a2 a3
    n_ctx=3 * LIST_LEN + 3,
    # Ids 0..MAX_INT-1 encode list values (and the permutation indices 0..LIST_LEN-1).
    # make_data_generator uses MAX_INT as END, MAX_INT+1 as MID and MAX_INT+2 as BOS.
    d_vocab=MAX_INT + 3,
    # Only list values are ever predicted, so the unembedding skips the special tokens
    d_vocab_out=MAX_INT,
    device=device,
    seed=0,
)
model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [11]:
# disable biases
# Freeze every bias parameter (b_Q, b_K, b_V, b_O, b_U, ...) so only non-bias
# parameters receive gradients, and print each parameter's shape/trainability.
for name, param in model.named_parameters():
    # Match on the leaf attribute name (e.g. "b_Q" in "blocks.0.attn.b_Q"). A plain
    # substring test would also catch names like "sub_layer.W" that merely contain "b_".
    if name.rsplit(".", 1)[-1].startswith("b_"):
        param.requires_grad = False
    print(name, param.shape, param.requires_grad)

embed.W_E torch.Size([53, 128]) True
pos_embed.W_pos torch.Size([12, 128]) True
blocks.0.attn.W_Q torch.Size([1, 128, 128]) True
blocks.0.attn.W_O torch.Size([1, 128, 128]) True
blocks.0.attn.b_Q torch.Size([1, 128]) False
blocks.0.attn.b_O torch.Size([128]) False
blocks.0.attn.W_K torch.Size([1, 128, 128]) True
blocks.0.attn.W_V torch.Size([1, 128, 128]) True
blocks.0.attn.b_K torch.Size([1, 128]) False
blocks.0.attn.b_V torch.Size([1, 128]) False
blocks.1.attn.W_Q torch.Size([1, 128, 128]) True
blocks.1.attn.W_O torch.Size([1, 128, 128]) True
blocks.1.attn.b_Q torch.Size([1, 128]) False
blocks.1.attn.b_O torch.Size([128]) False
blocks.1.attn.W_K torch.Size([1, 128, 128]) True
blocks.1.attn.W_V torch.Size([1, 128, 128]) True
blocks.1.attn.b_K torch.Size([1, 128]) False
blocks.1.attn.b_V torch.Size([1, 128]) False
unembed.W_U torch.Size([128, 50]) True
unembed.b_U torch.Size([50]) False


# Training

# Task Dataset

In [12]:
def make_data_generator(cfg, batch_size, seed=0, list_len=None, max_int=None):
    """Yield an endless stream of batches for the list-permutation task.

    Each row is ``BOS d1..dL MID p1..pL END a1..aL`` where the d are random values in
    [0, max_int), p is a permutation of 0..L-1 and a = d[p]. One permutation is drawn
    per batch and shared by every row in that batch.

    Args:
        cfg: config object; only ``cfg.d_vocab`` is read. The three largest ids are the
            special tokens: d_vocab-1 = BOS, d_vocab-2 = MID, d_vocab-3 = END.
        batch_size: number of rows per batch.
        seed: seed for a private torch.Generator. Drawing batches does not reseed or
            consume torch's global RNG.
        list_len: list length L; defaults to the notebook-global LIST_LEN.
        max_int: exclusive upper bound on list values; defaults to the global MAX_INT.

    Yields:
        int64 CPU tensor of shape (batch_size, 3 * list_len + 3).
    """
    if list_len is None:
        list_len = LIST_LEN
    if max_int is None:
        max_int = MAX_INT
    bos_token = cfg.d_vocab - 1
    mid_token = cfg.d_vocab - 2
    end_token = cfg.d_vocab - 3

    # Private generator instead of torch.manual_seed: reproducible batches without
    # silently resetting the global random stream used by everything else.
    rng = torch.Generator().manual_seed(seed)

    # Separator columns never change, so build them once.
    bos_col = torch.full((batch_size, 1), bos_token, dtype=torch.long)
    mid_col = torch.full((batch_size, 1), mid_token, dtype=torch.long)
    end_col = torch.full((batch_size, 1), end_token, dtype=torch.long)

    while True:
        seq = torch.randint(0, max_int, (batch_size, list_len), generator=rng)
        perm = torch.randperm(list_len, generator=rng)
        ans = seq[:, perm]
        x = torch.cat(
            [bos_col, seq, mid_col, perm.expand(batch_size, -1), end_col, ans],
            dim=-1,
        )
        yield x

# Peek at one tiny batch to eyeball the token layout.
print(next(make_data_generator(cfg, batch_size=2)))


tensor([[52, 44, 39, 33, 51,  1,  2,  0, 50, 39, 33, 44],
        [52, 10, 13, 29, 51,  1,  2,  0, 50, 13, 29, 10]])


# Loss Function

In [13]:
def loss_fn(logits, tokens, list_len=None):
    """Mean cross-entropy over the answer tokens only.

    Only the last ``list_len`` tokens (the permuted answer a1..aL) are scored. Position
    t predicts token t+1, so the logits for them are the ``list_len`` positions ending
    one before the last.

    Args:
        logits: (batch, seq, vocab) model output.
        tokens: (batch, seq) input ids; the final ``list_len`` ids index into vocab.
        list_len: answer length; defaults to the notebook-global LIST_LEN.

    Returns:
        0-dim float64 tensor: mean negative log-probability of the correct answers.
    """
    if list_len is None:
        list_len = LIST_LEN
    # Upcast to float64 before log_softmax
    answer_logits = logits[:, -list_len - 1:-1, :].to(torch.float64)
    labels = tokens[:, -list_len:]

    log_probs = answer_logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

# Run a tiny batch through the untrained model; its loss is expected to be close to
# the uniform-prediction loss computed in the next cell.
with torch.no_grad():
    tokens = next(make_data_generator(cfg, batch_size=2)).to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
print(loss)

tensor(3.9640, dtype=torch.float64)


In [14]:
# Cross-entropy of a uniform guess over the d_vocab_out possible answers: log(d_vocab_out).
print(f"uniform loss: {np.log(cfg.d_vocab_out)}")

uniform loss: 3.912023005428146


# Setup Optimizer

In [15]:
# AdamW hyper-parameters
lr, wd, betas = 1e-3, 0.01, (0.9, 0.98)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)

# Endless stream of freshly generated training batches
batch_size = 256
train_data_loader = make_data_generator(cfg, batch_size=batch_size)

# Training Loop

In [16]:
num_epochs = 4000  # each "epoch" here is one optimizer step on a newly generated batch
train_losses = []

for epoch in tqdm.tqdm(range(num_epochs)):
    tokens = next(train_data_loader).to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    train_losses.append(loss.item())

    # Log every 100 steps
    if not epoch % 100:
        print(f"Epoch: {epoch}, train_loss: {train_losses[-1]}")

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch: 0, train_loss: 3.9126681423614014
Epoch: 100, train_loss: 1.7503021757225596
Epoch: 200, train_loss: 1.4704949486103251
Epoch: 300, train_loss: 0.9829875616178234
Epoch: 400, train_loss: 0.03023505979362966
Epoch: 500, train_loss: 0.039463109395148056
Epoch: 600, train_loss: 0.0006447695685105333
Epoch: 700, train_loss: 0.022440187069656493
Epoch: 800, train_loss: 0.0015573624023373274
Epoch: 900, train_loss: 2.16064946622243e-05
Epoch: 1000, train_loss: 0.003525535229202404
Epoch: 1100, train_loss: 0.0010389910995369164
Epoch: 1200, train_loss: 0.00015402106237967064
Epoch: 1300, train_loss: 4.502118188786725e-05
Epoch: 1400, train_loss: 3.591175651756231e-06
Epoch: 1500, train_loss: 1.7859222722574152e-06
Epoch: 1600, train_loss: 0.0031789599349202283
Epoch: 1700, train_loss: 0.0004103817472111452
Epoch: 1800, train_loss: 2.1507927352726425e-05
Epoch: 1900, train_loss: 6.925996265165917e-06
Epoch: 2000, train_loss: 3.628497620560638e-06
Epoch: 2100, train_loss: 1.7201994410672

In [None]:
# Training loss per step
line(train_losses, xaxis="Epoch", yaxis="Loss", title="Loss curve")

# Sanity Check

In [None]:
# get a test sample with multiple different permutations
test_batch_size = 256
test_data = []
sub_batch_size = 4
for i in range(test_batch_size // sub_batch_size):
    test_data.append(next(make_data_generator(cfg, sub_batch_size, seed=i)))

test_data = torch.cat(test_data, dim=0).to(device)
print(test_data.shape)

In [None]:
with torch.inference_mode():
    # Position t predicts token t+1, so the answer predictions sit one slot to the left
    test_logits = model(test_data)[:, -LIST_LEN - 1:-1, :]
    test_labels = test_data[:, -LIST_LEN:]
    preds = test_logits.argmax(dim=-1)

    acc = preds.eq(test_labels).float().mean()
    print("Test sample accuracy:", acc.item())

# Save Model

In [None]:
# Create the checkpoint directory; idempotent, so re-running the notebook doesn't error.
Path("../models").mkdir(parents=True, exist_ok=True)

In [None]:
filename = "../models/permute_lists_model.pt"
# Persist only the weights (the state_dict), not the module object itself
torch.save(obj=model.state_dict(), f=filename)

In [None]:
# check we can load in model
LIST_LEN = 3
MAX_INT = 50
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    n_ctx=3*LIST_LEN+3, # BOS d1 d2 d3 MID p1 p2 p3 END a1 a2 a3
    d_vocab= MAX_INT+3, # 0, ..., MAX_INT-1, BOS, MID, END
    d_vocab_out=MAX_INT,
    attn_only=True,
    normalization_type=None,
    device=device,
    seed=0,
)

model_loaded = HookedTransformer(cfg)
model.load_state_dict(torch.load(filename), strict=True)