<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Introduction

This is the accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. It trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!

EDIT: The loss spikes were due to the learning rate multiplier being `max(step/100, 1.0)` rather than `min(step/100, 1.0)`! Thanks to MadHatter for catching that.

# Setup

Here, we check whether the code is running in Google Colab, in GitHub Actions, or in a local Jupyter notebook. In Colab or GitHub Actions, we install the `einops` and `transformer_lens` libraries. We then import the libraries used throughout: TransformerLens for the model, PyTorch and NumPy for tensor operations, Plotly for visualization, and tqdm for progress bars. Finally, we set the device to the GPU if one is available, and to the CPU otherwise.

In [None]:
# NBVAL_IGNORE_OUTPUT
import os

# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab

    IN_COLAB = True
    print("Running as a Colab notebook")
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")

if IN_COLAB or IN_GITHUB:
    %pip install einops
    %pip install transformer_lens==1.15.0

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git

from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "colab"
import tqdm.auto as tqdm
import einops
from transformer_lens.utils import to_numpy

device = "cuda" if torch.cuda.is_available() else "cpu"

Running as a Jupyter notebook - intended for development only!


Here, we define two helper functions, `line` and `imshow`, to simplify plotting tensor data with Plotly.

1. **`line` Function**: Converts a tensor to a NumPy array and generates a line plot with `px.line`. Optional `line_labels` can name each line, and `x` and `y` axis labels can be set.

2. **`imshow` Function**: Converts a tensor to a NumPy array and displays it as a heatmap with `px.imshow`, using a "RdBu" color scale centered at zero. The function also allows customization of axis labels and other display settings.

These functions provide a consistent and efficient way to visualize tensor data for model analysis.


In [None]:
def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

# Model Training

## Setup

### Defining the Model

Here, we create a configuration (`cfg`) for a simple transformer model using `HookedTransformerConfig` and then initialize the model with this configuration.

1. **Configuration Details**:
   - **`n_layers=2`**: Specifies 2 transformer layers.
   - **`d_model=64`**: Sets the residual stream width (the model’s hidden dimension) to 64.
   - **`d_head=64`** and **`n_heads=1`**: Sets the attention head size to 64 with only 1 head per layer, meaning each layer's attention is not split into multiple heads.
   - **`d_mlp=256`**: Defines the hidden layer dimension in the feedforward (MLP) block as 256.
   - **`d_vocab=300`**: Sets the vocabulary size to 300, meaning the model can represent up to 300 unique tokens.
   - **`n_ctx=50`**: Limits the context length (sequence length) to 50 tokens.
   - **`act_fn="relu"`**: Specifies ReLU as the activation function in the MLP layers.
   - **`normalization_type="LN"`**: Sets layer normalization (LN) as the normalization type.
   - **`device=device`**: Moves the model to the GPU if available, or to the CPU otherwise.

2. **Model Initialization**:
   - We initialize a transformer model instance (`model`) with this configuration. This model will have minimal complexity, making it easier to analyze while still retaining essential transformer characteristics.
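
For a sense of scale, the configuration above can be turned into a rough parameter count. This is a minimal pure-Python sketch, assuming the standard TransformerLens layout in which every attention and MLP weight matrix has a bias, each LayerNorm has a weight and a bias, and the unembedding has a bias `b_U`. The helper `count_params` is hypothetical, not part of TransformerLens; if the layout assumption holds, its total should agree with `sum(p.numel() for p in model.parameters())`.

```python
def count_params(n_layers, d_model, d_head, n_heads, d_mlp, d_vocab, n_ctx):
    """Hypothetical helper: parameter count for a HookedTransformer-style model with LayerNorm.

    Assumes every linear map has a bias and every LayerNorm has a weight and a bias.
    """
    embed = d_vocab * d_model                 # W_E
    pos_embed = n_ctx * d_model               # W_pos (zeroed and frozen in this notebook)
    ln = 2 * d_model                          # LayerNorm weight + bias
    attn = (
        3 * n_heads * d_model * d_head        # W_Q, W_K, W_V
        + 3 * n_heads * d_head                # b_Q, b_K, b_V
        + n_heads * d_head * d_model          # W_O
        + d_model                             # b_O
    )
    mlp = d_model * d_mlp + d_mlp + d_mlp * d_model + d_model  # W_in, b_in, W_out, b_out
    block = 2 * ln + attn + mlp               # ln1, ln2, attention, MLP
    unembed = d_model * d_vocab + d_vocab     # W_U, b_U
    total = embed + pos_embed + n_layers * block + ln + unembed  # the extra ln is ln_final
    return {"total": total, "trainable": total - pos_embed, "per_block": block}


counts = count_params(n_layers=2, d_model=64, d_head=64, n_heads=1, d_mlp=256, d_vocab=300, n_ctx=50)
print(counts)  # {'total': 141996, 'trainable': 138796, 'per_block': 49984}
```

The positional embedding (3,200 parameters) is reported separately because it is the only part that gets zeroed and frozen.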
  

In [None]:
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LN",
    device=device,
)
model = HookedTransformer(cfg)

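The function `deactivate_position` disables the model's positional embeddings by setting their weights (`W_pos`) to zero and freezing them (`requires_grad = False`). This removes all *explicit* positional information: the input to the first layer no longer depends on where a token sits in the sequence. Since predicting the previous token requires knowing which token came just before, any positional signal the model uses must be derived from the causal attention mask.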

We then call `deactivate_position(model)` to apply this change to the model.
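
Why is the task still learnable? The causal mask makes positions asymmetric even without `W_pos`. Below is a minimal pure-Python sketch of one mechanism by which this can happen, assuming a head that attends uniformly to all earlier tokens and a hypothetical scalar feature that is 1.0 on the BOS token (always at position 0) and 0.0 elsewhere. The output at position `i` is then `1 / (i + 1)`, a function of position alone. This illustrates the idea; it is not a claim about what this particular model learned.

```python
def uniform_causal_attention(values):
    """Output at position i is the average of values[0..i] (uniform attention under a causal mask)."""
    outputs = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        outputs.append(running_sum / (i + 1))
    return outputs


n_ctx = 6
# Hypothetical feature: 1.0 for the BOS token at position 0, 0.0 for every other token.
bos_feature = [1.0] + [0.0] * (n_ctx - 1)
print(uniform_causal_attention(bos_feature))
# [1.0, 0.5, 0.3333333333333333, 0.25, 0.2, 0.16666666666666666]
```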

In [None]:
def deactivate_position(model):
    model.pos_embed.W_pos.data[:] = 0.0
    model.pos_embed.W_pos.requires_grad = False


deactivate_position(model)

In [None]:
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

### Define data + Loss function

The `make_data_generator` function creates an infinite generator that yields batches of random token sequences.

It generates `cfg.n_ctx`-long sequences of random token IDs drawn uniformly from 1 to `cfg.d_vocab - 1` (the upper bound of `torch.randint` is exclusive). If `incl_bos_token` is `True`, the first token in each sequence is set to 0, which serves as a beginning-of-sequence (BOS) token. Seeding PyTorch's global random number generator with `seed` makes the batches reproducible.

Here, `data_generator = make_data_generator(cfg, 2)` initializes the generator with a batch size of 2, and `print(next(data_generator))` displays the first batch.

In [None]:
def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):
    torch.manual_seed(seed)
    while True:
        x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))
        if incl_bos_token:
            x[:, 0] = 0
        yield x


data_generator = make_data_generator(cfg, 2)
print(next(data_generator))

tensor([[  0,  93,  34, 155, 274, 116, 114, 248,  68,   3, 298,  83, 194,  20,
           8, 133,  32,  66,  62,  73, 210, 273,  46, 243, 104, 232, 161, 125,
         123, 251,   7,   4, 115, 127,  21,   1,  89, 142,   6,  15, 298, 251,
          88, 229, 108, 114,  23,  88,   3, 265],
        [  0, 118,  46, 274, 105, 268, 131,  35,  19,  58, 226, 278,  27,  25,
         276, 180, 164,   4,  95,  27,  74, 201, 105,  65,  80, 185,  44, 258,
         105,  60,  58,  47, 126,  60, 294, 253, 258, 136,  29, 101, 258,  77,
          80, 180, 159, 169, 122, 117,  27, 194]])


The `loss_fn` function computes the negative log-likelihood of the correct *previous* token under the model's predictions (`logits`).

1. **Alignment**: Drops the logits at the first position and the last token, so the logits at position `i` are scored against the token at position `i - 1`. This offset is what makes the task previous-token prediction.
2. **Log Probabilities**: Converts `logits` to log probabilities and extracts the log probability of the correct tokens.
3. **Output**: Returns either the per-token loss (if `per_token=True`) or the average loss across tokens.

This function measures how well the model predicts the previous token; lower is better, and a perfect model would reach a loss of 0.
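
The alignment step is easy to get wrong, so here is a minimal pure-Python restatement of `loss_fn` for a single sequence. The helpers `previous_token_pairs` and `previous_token_nll` are hypothetical, written only for illustration. The (position, target) pairs they use are exactly the ones the test cell below checks, where position `i` is given a large logit on token `i - 1`.

```python
import math


def previous_token_pairs(tokens):
    """Pair each logit position i >= 1 with its target, the token at position i - 1."""
    return [(i, tokens[i - 1]) for i in range(1, len(tokens))]


def previous_token_nll(logits, tokens):
    """Mean negative log-likelihood of the previous token for a single sequence.

    logits: per-position lists of vocabulary scores, shape [pos][vocab]
    tokens: token ids, shape [pos]
    """
    losses = []
    for pos, target in previous_token_pairs(tokens):
        row = logits[pos]
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        losses.append(log_norm - row[target])  # equals -log_softmax(row)[target]
    return sum(losses) / len(losses)


print(previous_token_pairs([0, 1, 2, 3, 4]))  # [(1, 0), (2, 1), (3, 2), (4, 3)]

# Uniform (all-zero) logits over a 300-token vocabulary give the chance-level loss log(300).
uniform_logits = [[0.0] * 300 for _ in range(50)]
print(previous_token_nll(uniform_logits, [0] + list(range(1, 50))))  # about 5.704
```

The training log starts around 6.04, somewhat above this uniform baseline of log(300) ≈ 5.70, since a randomly initialized model's predictions are not exactly uniform.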

In [None]:
def loss_fn(logits, tokens, per_token=False):
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]
    logits = logits[:, 1:]
    tokens = tokens[:, :-1]
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

In [None]:
# Test the loss function works
test_tokens = torch.arange(5)[None, :]
test_logits = torch.randn(1, 5, 10)
test_logits[:, 1, 0] = 10.0
test_logits[:, 2, 1] = 10.0
test_logits[:, 3, 2] = 10.0
test_logits[:, 4, 3] = 10.0
print(loss_fn(test_logits, test_tokens, per_token=True))
print(loss_fn(test_logits, test_tokens, per_token=False))

tensor([[0.0004, 0.0003, 0.0031, 0.0005]])
tensor(0.0011)


### Setup Optimizer


- **Training Parameters**: Defines `batch_size` (256), `num_epochs` (4000; each "epoch" here is a single freshly sampled batch, so this is really the number of training steps), `lr` (1e-4), `betas` (0.9, 0.95), `max_grad_norm` (1.0) for gradient clipping, and `wd` (0.1), the weight decay used for regularization.
- **Optimizer and Scheduler**: Uses the AdamW optimizer with weight decay, plus a `LambdaLR` scheduler that linearly warms the learning rate up from 0 to `lr` over the first 100 steps and then holds it constant.
- **Data Generator**: Initializes `data_loader` to produce batches of random token sequences for training.

These settings configure the training run.
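
The warmup schedule, and the `max`/`min` bug mentioned in the EDIT note at the top, can be sketched in plain Python. `LambdaLR` sets the learning rate to the optimizer's base `lr` times the value returned by the lambda; the function names below are hypothetical.

```python
warmup_steps = 100
base_lr = 1e-4


def lr_multiplier(step):
    """Linear warmup as passed to LambdaLR: 0 at step 0, 1 from step 100 onward."""
    return min(step / warmup_steps, 1.0)


def buggy_lr_multiplier(step):
    """The earlier bug from the EDIT note: max instead of min, so the multiplier never stops growing."""
    return max(step / warmup_steps, 1.0)


# LambdaLR sets the learning rate to base_lr * multiplier(step).
for step in [0, 50, 100, 1000, 4000]:
    print(step, base_lr * lr_multiplier(step), base_lr * buggy_lr_multiplier(step))
```

With `min`, the very first update uses a learning rate of 0 and the full 1e-4 is reached at step 100. With `max`, the learning rate starts at 1e-4 and keeps growing linearly, reaching 40 times that value by step 4000, which is what caused the loss spikes.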


In [None]:
batch_size = 256
num_epochs = 4000
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

data_loader = make_data_generator(cfg, batch_size)

## Training Loop

In [None]:
losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    tokens = next(data_loader)
    tokens = tokens.to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    losses.append(loss.item())
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: {loss.item()}")
px.line(losses, labels={"x": "Epoch", "y": "Loss"})


Epoch 0: 6.039131164550781
Epoch 100: 5.773892879486084
Epoch 200: 5.573237895965576
Epoch 300: 5.444890022277832
Epoch 400: 5.3126444816589355
Epoch 500: 5.152464389801025
Epoch 600: 4.953516483306885
Epoch 700: 4.677230358123779
Epoch 800: 4.353099822998047
Epoch 900: 3.9406914710998535
Epoch 1000: 3.4933784008026123
Epoch 1100: 3.07138991355896
Epoch 1200: 2.6529295444488525
Epoch 1300: 2.2651336193084717
Epoch 1400: 1.9132359027862549
Epoch 1500: 1.576438307762146
Epoch 1600: 1.2859177589416504
Epoch 1700: 1.0253156423568726
Epoch 1800: 0.8068246841430664
Epoch 1900: 0.6299871802330017
Epoch 2000: 0.47548314929008484
Epoch 2100: 0.3611340820789337
Epoch 2200: 0.2577555775642395
Epoch 2300: 0.19410978257656097
Epoch 2400: 0.14035893976688385
Epoch 2500: 0.10599333792924881
Epoch 2600: 0.07851045578718185
Epoch 2700: 0.055136531591415405
Epoch 2800: 0.041809480637311935
Epoch 2900: 0.0317872129380703
Epoch 3000: 0.025179561227560043
Epoch 3100: 0.017474526539444923
Epoch 3200: 0.0167

In [None]:
# torch.save(model.state_dict(), "no_pos_experiment_state_dict_v0.pth")

# Model Interpretability

In [None]:
model.pos_embed.W_pos.norm()

tensor(0.)

## Look at attention patterns

In [None]:
big_data_loader = make_data_generator(cfg, 4000)
big_tokens = next(big_data_loader)
big_tokens = big_tokens.to(device)
logits, cache = model.run_with_cache(big_tokens)
print("Loss:", loss_fn(logits, big_tokens).item())

Loss: 0.003689224598929286


In [None]:
print(cache)

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_sc

In [None]:
cache["blocks.0.attn.hook_pattern"].shape

torch.Size([4000, 1, 50, 50])

In [None]:
batch_index = 0
tokens = big_tokens[batch_index]
imshow(
    to_numpy(cache["attn", 0].mean([0, 1])),
    title="Layer 0 Attention Pattern",
    height=500,
    width=500,
)
imshow(
    to_numpy(cache["attn", 1].mean([0, 1])),
    title="Layer 1 Attention Pattern",
    height=500,
    width=500,
)

***The visualization above shows that, averaged over the batch, layer 0's attention looks diffuse, with no sharp structure across tokens, while layer 1 displays a clear diagonal pattern, suggesting it has learned to attend to the preceding token. Since the model has no positional embeddings, layer 1 must obtain positional information some other way. Layer 0 should not be written off, though: in the mean-ablation experiment at the end of this notebook, replacing layer 0's attention output with its average raises the loss from about 0.004 to about 2.46.***
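
The "diagonal" reading of layer 1 can be quantified. Here is a minimal NumPy sketch, assuming we score a head by the average weight each query position puts on the immediately preceding key; the helper `previous_token_score` is hypothetical. A perfect previous-token head scores 1.0, while a head that attends uniformly to all earlier tokens scores only about 0.07 at `n_ctx = 50`.

```python
import numpy as np


def previous_token_score(pattern):
    """Average attention each query position i >= 1 pays to key position i - 1.

    pattern: [query_pos, key_pos] attention pattern whose rows sum to 1.
    """
    return float(np.diagonal(pattern, offset=-1).mean())


n_ctx = 50
# Ideal previous-token head: row i puts all its weight on column i - 1 (row 0 can only attend to itself).
ideal = np.eye(n_ctx, k=-1)
ideal[0, 0] = 1.0

# Uniform causal head: row i spreads weight 1 / (i + 1) over columns 0..i.
uniform = np.tril(np.ones((n_ctx, n_ctx)))
uniform /= uniform.sum(-1, keepdims=True)

print(previous_token_score(ideal))    # 1.0
print(previous_token_score(uniform))  # about 0.071
```

In this notebook it could be applied as `previous_token_score(to_numpy(cache["attn", 1].mean([0, 1])))`.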

## Look at how different bits of the model directly contribute to the logits

In [None]:
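# Everything that writes into the residual stream (pos_embed is all zeros, so it is omitted)
resid_components = [
    cache["embed"],
    cache["attn_out", 0],
    cache["mlp_out", 0],
    cache["attn_out", 1],
    cache["mlp_out", 1],
]
labels = ["embed", "A0", "M0", "A1", "M1"]
resid_stack = torch.stack(resid_components, 0)  # [component, batch, pos, d_model]
# Center each component along d_model, mirroring the centering step of the final LayerNorm
resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)
print(resid_stack.shape)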

torch.Size([5, 4000, 50, 64])


In [None]:
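# Fold ln_final's weight into the unembedding and divide by the final LayerNorm's scale,
# so each component's direct logit contribution is linear in that component
# (ln_final's bias and b_U are constant offsets and are dropped)
fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U
logit_components = resid_stack[:, batch_index] @ fold_W_U / cache["scale"][batch_index]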
print(logit_components.shape)

torch.Size([5, 50, 300])


In [None]:
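# Subtract the mean over the vocabulary (the softmax is unchanged by a constant shift),
# then plot each component's contribution to the logit of the correct previous token
# at positions 1..n_ctx-1: shape [n_ctx - 1, n_components], one line per component
logit_components = logit_components - logit_components.mean(-1, keepdim=True)
line(
    logit_components[:, torch.arange(1, model.cfg.n_ctx).to(device), tokens[:-1]].T,
    line_labels=labels,
)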

***The visualization above shows each component's direct contribution to the logit of the correct (previous) token at each position. The embedding's contribution stays fairly flat, while the attention layers (A0 and A1) fluctuate substantially across positions, with A1 showing especially pronounced peaks and troughs. The MLP layers (M0 and M1) are more stable, with moderate variation, which may indicate a refining role. Overall, the position-dependent part of the prediction appears to come mostly from the attention layers, while the embedding and MLPs contribute more uniform offsets.***
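
The cell that builds `fold_W_U` relies on the final LayerNorm being linear once its scale is fixed. Here is a minimal NumPy sketch with random stand-in weights (none of these arrays come from the trained model), assuming the usual LayerNorm definition: center, divide by a per-position scale, then apply an elementwise weight `w` and bias `b`. Centering each component, dividing by the *shared* scale, and multiplying by `w[:, None] * W_U` gives per-component logits that add up to the full logits, up to constant bias terms.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, d_model, d_vocab, eps = 5, 64, 300, 1e-5

# Random stand-ins: five components (embed, A0, M0, A1, M1) at a single position.
components = rng.normal(size=(n_components, d_model))
w = rng.normal(size=d_model)               # ln_final weight
b = rng.normal(size=d_model)               # ln_final bias
W_U = rng.normal(size=(d_model, d_vocab))  # unembedding matrix
b_U = rng.normal(size=d_vocab)             # unembedding bias

# Full computation: sum the components into the residual stream, apply LayerNorm, unembed.
resid = components.sum(0)
centered = resid - resid.mean()
scale = np.sqrt((centered ** 2).mean() + eps)
full_logits = (centered / scale * w + b) @ W_U + b_U

# Direct logit attribution: center each component, divide by the shared scale,
# and map it through the unembedding with the LayerNorm weight folded in.
toy_fold_W_U = w[:, None] * W_U
centered_components = components - components.mean(-1, keepdims=True)
per_component = (centered_components / scale) @ toy_fold_W_U  # [n_components, d_vocab]

# Per-component logits plus the constant bias terms reconstruct the full logits.
print(np.allclose(per_component.sum(0) + b @ W_U + b_U, full_logits))  # True
```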

## Folding In LayerNorm

In [None]:
analysis_cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LNPre",
    init_weights=False,
)
analysis_model = HookedTransformer(analysis_cfg)
state_dict = model.state_dict()
analysis_model.load_and_process_state_dict(
    state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True
)
deactivate_position(analysis_model)
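
`fold_ln=True` absorbs each LayerNorm's weight and bias into the weights that follow it, which is why `analysis_cfg` uses `"LNPre"` (centering and scaling only). `center_writing_weights=True` rests on a related observation, sketched here in NumPy under the assumption that it subtracts each writing matrix's mean along the `d_model` axis: because the next LayerNorm subtracts the mean of the residual stream anyway, removing the mean from what a layer writes leaves the normalized result unchanged. The arrays are random stand-ins, not the trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_mlp, eps = 64, 256, 1e-5


def ln_pre(x):
    """LayerNorm without weight or bias ("LNPre"): center, then divide by the scale."""
    centered = x - x.mean(-1, keepdims=True)
    return centered / np.sqrt((centered ** 2).mean(-1, keepdims=True) + eps)


resid = rng.normal(size=(3, d_model))      # residual stream before a layer writes to it
neuron_acts = rng.normal(size=(3, d_mlp))  # e.g. MLP post-activations
W_out = rng.normal(size=(d_mlp, d_model))  # a "writing" matrix into the residual stream

# Center the writing weights along d_model, so every vector the layer writes has zero mean.
W_out_centered = W_out - W_out.mean(-1, keepdims=True)

before = ln_pre(resid + neuron_acts @ W_out)
after = ln_pre(resid + neuron_acts @ W_out_centered)
print(np.allclose(before, after))  # True: the next LayerNorm removes the mean anyway
```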

In [None]:
# analysis_model()

## Understand Attn 0


In [None]:
# Full QK circuit of layer 0's head from token embeddings: [query token, key token]
QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T
imshow(QK, yaxis="Query", xaxis="Key")
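
Each entry of the QK matrix is a bilinear attention score between one query token and one key token. A small NumPy sketch with random stand-in weights shows the correspondence. Like the cell above, it ignores LayerNorm and biases, and it assumes (as is the TransformerLens default) that real attention scores are also divided by `sqrt(d_head)`, so the matrix is best read for its relative pattern.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_model, d_head = 300, 64, 64

# Random stand-ins for the embedding and layer 0's query/key weights.
W_E = rng.normal(size=(d_vocab, d_model))
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))

# Full QK circuit: one bilinear score for every (query token, key token) pair.
toy_QK = W_E @ W_Q @ W_K.T @ W_E.T  # [d_vocab (query), d_vocab (key)]

# Entry [q, k] is the dot product of query token q's query vector with key token k's key vector.
q_tok, k_tok = 7, 42
score = (W_E[q_tok] @ W_Q) @ (W_E[k_tok] @ W_K)
print(np.isclose(toy_QK[q_tok, k_tok], score))  # True
```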

In [None]:
# Full OV circuit of layer 0's head, read out through MLP 0's input weights: [input token, neuron]
OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]
imshow(OV, yaxis="Input Vocab", xaxis="Neuron")

In [None]:
# Plot the OV circuit for 5 randomly chosen MLP 0 neurons
line(OV[:, torch.randint(0, model.cfg.d_mlp, (5,))])

## Understand MLP 0

In [None]:
imshow(cache["post", 0][batch_index], yaxis="Pos", xaxis="Neuron")
imshow(cache["post", 0].mean(0), yaxis="Pos", xaxis="Neuron")
imshow((cache["post", 0] > 0).float()[batch_index], yaxis="Pos", xaxis="Neuron")
imshow((cache["post", 0] > 0).float().mean(0), yaxis="Pos", xaxis="Neuron")

## Understand Attn 1

## Understand MLP 1

# Experiment

In [None]:
new_token_batch = next(big_data_loader).to(device)
baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()
print("Baseline loss:", baseline_loss)

Baseline loss: 0.0036276562605053186


In [None]:
hook_list = list(model.hook_dict.keys())
losses = []  # note: this overwrites the training-loss list from earlier
loss_labels = []
for hook_name in hook_list:
    # Skip hooks that were not cached, the all-zero positional embedding, and per-head result hooks
    if (
        hook_name in cache
        and hook_name != "hook_pos_embed"
        and "result" not in hook_name
    ):
        # Mean ablation: replace the activation with its average over the cached batch
        average_act = cache[hook_name].mean(0)

        def replacing_with_average_act(activation, hook):
            activation[:] = einops.repeat(
                average_act, "... -> batch ...", batch=new_token_batch.size(0)
            )
            return activation

        logits = model.run_with_hooks(
            new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]
        )
        loss = loss_fn(logits, new_token_batch)
        print(hook_name, loss.item())
        losses.append(loss.item())
        loss_labels.append(hook_name)

hook_embed 10.975448608398438
blocks.0.ln1.hook_scale 0.4518754482269287
blocks.0.ln1.hook_normalized 2.4589312076568604
blocks.0.ln2.hook_scale 0.02511436492204666
blocks.0.ln2.hook_normalized 9.506545066833496
blocks.0.attn.hook_k 0.026024164631962776
blocks.0.attn.hook_q 1.9935917854309082
blocks.0.attn.hook_v 0.19012247025966644
blocks.0.attn.hook_z 2.4572794437408447
blocks.0.attn.hook_attn_scores 1.9351273775100708
blocks.0.attn.hook_pattern 1.9483344554901123
blocks.0.mlp.hook_pre 9.506546020507812
blocks.0.mlp.hook_post 9.526301383972168
blocks.0.hook_attn_out 2.4572784900665283
blocks.0.hook_mlp_out 9.526301383972168
blocks.0.hook_resid_pre 10.975448608398438
blocks.0.hook_resid_mid 11.21129035949707
blocks.0.hook_resid_post 10.834088325500488
blocks.1.ln1.hook_scale 0.021276870742440224
blocks.1.ln1.hook_normalized 9.080503463745117
blocks.1.ln2.hook_scale 0.003745849709957838
blocks.1.ln2.hook_normalized 4.472580432891846
blocks.1.attn.hook_k 0.0021377610974013805
blocks.1.a

In [None]:
line(losses, hover_name=loss_labels)
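
The ablation results are easier to read as a sorted list of loss changes relative to the clean baseline. A minimal pure-Python sketch (the helper `rank_ablations` is hypothetical), fed with a handful of the losses printed by the ablation cell; the `example_` names avoid overwriting the notebook's own variables.

```python
def rank_ablations(labels, losses, baseline_loss):
    """Sort hooks by how much mean-ablating them changes the loss relative to the clean baseline."""
    deltas = [(label, loss - baseline_loss) for label, loss in zip(labels, losses)]
    return sorted(deltas, key=lambda pair: pair[1], reverse=True)


# A subset of the baseline and mean-ablation losses printed above.
example_baseline = 0.0036276562605053186
example_labels = [
    "hook_embed",
    "blocks.0.attn.hook_k",
    "blocks.0.attn.hook_q",
    "blocks.0.hook_attn_out",
    "blocks.0.hook_mlp_out",
    "blocks.1.attn.hook_k",
]
example_losses = [
    10.975448608398438,
    0.026024164631962776,
    1.9935917854309082,
    2.4572784900665283,
    9.526301383972168,
    0.0021377610974013805,
]
for label, delta in rank_ablations(example_labels, example_losses, example_baseline):
    print(f"{label:28s} {delta:+.4f}")
```

With the notebook's own variables this would be `rank_ablations(loss_labels, losses, baseline_loss)`. A signed difference is more informative than the raw loss here: mean-ablating `blocks.1.attn.hook_k` slightly *lowers* the loss below the baseline.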

In [None]:
cache.cache_dict.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.

# Notes

*ChatGPT, developed by OpenAI, contributed to the Markdown content.*