# Introduction

The accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. Trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!

EDIT: The loss spikes were due to the learning rate being max(step/100, 1.0) not min! Thanks to MadHatter for catching that.

# Setup

In [1]:
try:
    import google.colab

    IN_COLAB = True
    !pip install einops
    !pip install https://github.com/neelnanda-io/TransformerLens@no-position-experiment
except:
    IN_COLAB = False

from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "colab"
import tqdm.auto as tqdm
import einops
from transformer_lens.utils import to_numpy

Some plotting code. Wrappers around Plotly, not important to understand.

In [2]:
def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

# Model Training

## Setup

### Defining the Model

In [3]:
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LN",
)
model = HookedTransformer(cfg)

Moving model to device:  cuda


In [4]:
def deactivate_position(model):
    model.pos_embed.W_pos.data[:] = 0.0
    model.pos_embed.W_pos.requires_grad = False


deactivate_position(model)

In [5]:
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_attn): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
    (1): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPo

### Define data + Loss function

In [6]:
def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):
    torch.manual_seed(seed)
    while True:
        x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))
        if incl_bos_token:
            x[:, 0] = 0
        yield x


data_generator = make_data_generator(cfg, 2)
print(next(data_generator))

tensor([[  0,  93,  34, 155, 274, 116, 114, 248,  68,   3, 298,  83, 194,  20,
           8, 133,  32,  66,  62,  73, 210, 273,  46, 243, 104, 232, 161, 125,
         123, 251,   7,   4, 115, 127,  21,   1,  89, 142,   6,  15, 298, 251,
          88, 229, 108, 114,  23,  88,   3, 265],
        [  0, 118,  46, 274, 105, 268, 131,  35,  19,  58, 226, 278,  27,  25,
         276, 180, 164,   4,  95,  27,  74, 201, 105,  65,  80, 185,  44, 258,
         105,  60,  58,  47, 126,  60, 294, 253, 258, 136,  29, 101, 258,  77,
          80, 180, 159, 169, 122, 117,  27, 194]])


In [7]:
def loss_fn(logits, tokens, per_token=False):
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]
    logits = logits[:, 1:]
    tokens = tokens[:, :-1]
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

In [8]:
# Test the loss function works
test_tokens = torch.arange(5)[None, :]
test_logits = torch.randn(1, 5, 10)
test_logits[:, 1, 0] = 10.0
test_logits[:, 2, 1] = 10.0
test_logits[:, 3, 2] = 10.0
test_logits[:, 4, 3] = 10.0
print(loss_fn(test_logits, test_tokens, per_token=True))
print(loss_fn(test_logits, test_tokens, per_token=False))

tensor([[0.0004, 0.0003, 0.0031, 0.0005]])
tensor(0.0011)


### Setup Optimizer


In [9]:
batch_size = 256
num_epochs = 4000
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

data_loader = make_data_generator(cfg, batch_size)

## Model Training

In [10]:
losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    tokens = next(data_loader)
    tokens = tokens.cuda()
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    losses.append(loss.item())
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: {loss.item()}")
px.line(losses, labels={"x": "Epoch", "y": "Loss"})

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch 0: 6.0046892166137695
Epoch 100: 5.636989593505859
Epoch 200: 5.43870210647583
Epoch 300: 5.084463119506836
Epoch 400: 4.105402946472168
Epoch 500: 2.667759418487549
Epoch 600: 1.4100089073181152
Epoch 700: 0.6534866094589233
Epoch 800: 0.26310086250305176
Epoch 900: 0.094235360622406
Epoch 1000: 0.07662342488765717
Epoch 1100: 0.05123501643538475
Epoch 1200: 0.0633467361330986
Epoch 1300: 0.0698024183511734
Epoch 1400: 0.03592035919427872
Epoch 1500: 0.06732264906167984
Epoch 1600: 0.028138982132077217
Epoch 1700: 0.02272624894976616
Epoch 1800: 0.02585722878575325
Epoch 1900: 0.04599686339497566
Epoch 2000: 0.21788650751113892
Epoch 2100: 0.052709151059389114
Epoch 2200: 0.025653734803199768
Epoch 2300: 0.03516862168908119
Epoch 2400: 0.017889760434627533
Epoch 2500: 0.013999780640006065
Epoch 2600: 0.036015357822179794
Epoch 2700: 0.021333860233426094
Epoch 2800: 0.07593370974063873
Epoch 2900: 0.01114147063344717
Epoch 3000: 0.007803339511156082
Epoch 3100: 0.0085709709674119

In [11]:
# torch.save(model.state_dict(), "no_pos_experiment_state_dict_v0.pth")

# Model Interpretability

In [12]:
model.pos_embed.W_pos.norm()

tensor(0., device='cuda:0')

## Look at attention patterns

In [13]:
big_data_loader = make_data_generator(cfg, 4000)
big_tokens = next(big_data_loader)
big_tokens = big_tokens.cuda()
logits, cache = model.run_with_cache(big_tokens)
print("Loss:", loss_fn(logits, big_tokens).item())

Loss: 0.005800994113087654


In [14]:
print(cache)

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_attn', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_attn', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 

In [15]:
cache["blocks.0.attn.hook_attn"].shape

torch.Size([4000, 1, 50, 50])

In [16]:
batch_index = 0
tokens = big_tokens[batch_index]
imshow(
    to_numpy(cache["attn", 0].mean([0, 1])),
    title="Layer 0 Attention Pattern",
    height=500,
    width=500,
)
imshow(
    to_numpy(cache["attn", 1].mean([0, 1])),
    title="Layer 1 Attention Pattern",
    height=500,
    width=500,
)

## Look at how different bits of the model directly contribute to the logits

In [17]:
resid_components = [
    cache["embed"],
    cache["attn_out", 0],
    cache["mlp_out", 0],
    cache["attn_out", 1],
    cache["mlp_out", 1],
]
labels = ["embed", "A0", "M0", "A1", "M2"]
resid_stack = torch.stack(resid_components, 0)
resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)
print(resid_stack.shape)

torch.Size([5, 4000, 50, 64])


In [18]:
fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U
logit_components = (
    resid_stack[:, batch_index]
    @ fold_W_U
    / cache["scale", None, "ln_final"][batch_index]
)
print(logit_components.shape)

torch.Size([5, 50, 300])


In [19]:
logit_components = logit_components - logit_components.mean(-1, keepdim=True)
line(
    logit_components[:, torch.arange(1, model.cfg.n_ctx).cuda(), tokens[:-1]].T,
    line_labels=labels,
)

## Folding In LayerNorm

In [20]:
analysis_cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LNPre",
    init_weights=False,
)
analysis_model = HookedTransformer(analysis_cfg)
state_dict = model.state_dict()
analysis_model.load_and_process_state_dict(
    state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True
)
deactivate_position(analysis_model)

Moving model to device:  cuda


In [21]:
# analysis_model()

## Understand Attn 0


In [22]:
QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T
imshow(QK, yaxis="Query", xaxis="Key")

In [23]:
OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]
imshow(OV, yaxis="Input Vocab", xaxis="Neuron")

In [24]:
line(OV[:, torch.randint(0, 256, (5,))])

## Understand MLP 0

In [25]:
imshow(cache["post", 0][batch_index], yaxis="Pos", xaxis="Neuron")
imshow(cache["post", 0].mean(0), yaxis="Pos", xaxis="Neuron")
imshow((cache["post", 0] > 0).float()[batch_index], yaxis="Pos", xaxis="Neuron")
imshow((cache["post", 0] > 0).float().mean(0), yaxis="Pos", xaxis="Neuron")

## Understand Attn 1

## Understand MLP 1

# Experiment

In [26]:
new_token_batch = next(big_data_loader).cuda()
baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()
print("Baseline loss:", baseline_loss)

Baseline loss: 0.005588936153799295


In [27]:
hook_list = list(model.hook_dict.keys())
losses = []
loss_labels = []
for hook_name in hook_list:
    if hook_name != "hook_pos_embed" and "result" not in hook_name:
        average_act = cache[hook_name].mean(0)

        def replacing_with_average_act(activation, hook):
            activation[:] = einops.repeat(
                average_act, "... -> batch ...", batch=new_token_batch.size(0)
            )
            return activation

        logits = model.run_with_hooks(
            new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]
        )
        loss = loss_fn(logits, new_token_batch)
        print(hook_name, loss.item())
        losses.append(loss.item())
        loss_labels.append(hook_name)

hook_embed 15.897111892700195
blocks.0.ln1.hook_scale 0.005609430838376284
blocks.0.ln1.hook_normalized 0.511729896068573
blocks.0.ln2.hook_scale 0.006470857188105583
blocks.0.ln2.hook_normalized 4.198207855224609
blocks.0.attn.hook_k 0.006857300642877817
blocks.0.attn.hook_q 0.5248472690582275
blocks.0.attn.hook_v 0.007329730782657862
blocks.0.attn.hook_z 0.5096153616905212
blocks.0.attn.hook_attn_scores 0.5726690888404846
blocks.0.attn.hook_attn 0.5747630596160889
blocks.0.mlp.hook_pre 4.198207855224609
blocks.0.mlp.hook_post 4.294939041137695
blocks.0.hook_attn_out 0.5096153616905212
blocks.0.hook_mlp_out 4.294939041137695
blocks.0.hook_resid_pre 15.897111892700195
blocks.0.hook_resid_mid 18.945459365844727
blocks.0.hook_resid_post 18.333091735839844
blocks.1.ln1.hook_scale 0.008781297132372856
blocks.1.ln1.hook_normalized 16.925254821777344
blocks.1.ln2.hook_scale 0.005587007850408554
blocks.1.ln2.hook_normalized 9.671163558959961
blocks.1.attn.hook_k 0.059003107249736786
blocks.1.

In [28]:
line(losses, hover_name=loss_labels)

In [29]:
cache.cache_dict.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_attn', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_attn', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_n