<a href="https://colab.research.google.com/github/DKdekes/rotary-interp/blob/main/rotary_interp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpreting Rotary Embeddings

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

This notebook covers work done to understand how a single attention layer can
use information from rotary embeddings to find the previous token in input.

Work is captured in [this repo](https://github.com/DKdekes/rotary-interp).

For more context on rotary embeddings: https://blog.eleuther.ai/rotary-embeddings/


# Task
This task is Problem 3.21 from [200 Concrete Open Problems](https://docs.google.com/spreadsheets/d/1oOdrQ80jDK-aGn-EVdDt3dg65GhmzrvBWzJ6MUZB8n4/edit#gid=0).

Train a 1L attention-only transformer with rotary embeddings to predict the previous token and reverse engineer how it does this.

# Motivation
Rotary embeddings inject a relative positional encoding into the queries and keys of attention, but exactly how transformers make use of this encoding is not known.

# Setup
(No need to read)

In [225]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-6_3ohket
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-6_3ohket
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 3f0e1c3a32ef5a69b11284ffd0ddfefe11197bc5
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [226]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [227]:
import circuitsvis as cv

In [228]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [229]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [230]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Model Definition
The architecture trained to perform the previous token identification task is a single attention layer with rotary positional encoding. Given the simple task, there is just one head, which keeps the architecture as simple as possible.

In [231]:
from transformer_lens import HookedTransformerConfig, HookedTransformer
from transformer_lens.train import HookedTransformerTrainConfig, train
import wandb

device = "cuda" if torch.cuda.is_available() else "cpu"

n_heads = 1
d_head = n_heads * 8

cfg = {
 'n_layers': 1,
 'd_model': 512,
 'n_ctx': 10,
 'd_head': d_head,
 'model_name': 'Attn_Only_1L512W_C4_Rotary',
 'n_heads': n_heads,
 'd_mlp': 2048,
 'act_fn': 'solu_ln',
 'd_vocab': 48262,
 'eps': 1e-05,
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_split_qkv_input': False,
 'use_local_attn': False,
 'original_architecture': 'neel',
 'from_checkpoint': False,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'tokenizer_name': 'NeelNanda/gpt-neox-tokenizer-digits',
 'window_size': None,
 'attn_types': None,
 'init_mode': 'gpt2',
 'normalization_type': 'LNPre',
 'device': device,
 'n_devices': 1,
 'attention_dir': 'causal',
 'attn_only': True,
 'seed': None,
 'initializer_range': -1,
 'init_weights': True,
 'scale_attn_by_inverse_layer_idx': False,
 'positional_embedding_type': 'rotary',
 'final_rms': False,
 'd_vocab_out': 48262,
 'parallel_attn_mlp': False,
 'rotary_dim': None,
 'use_hook_tokens': False,
 'gated_mlp': False
 }

cfg = HookedTransformerConfig.from_dict(cfg)
model = HookedTransformer(cfg)

# Data
Random strings of digits were generated as training data. The `NeelNanda/gpt-neox-tokenizer-digits` tokenizer tokenizes digits at the character level. I know it's overkill to use a gpt-neox tokenizer for this task. This is what I've hacked together :)

In [232]:
from torch.utils.data import Dataset
from random import randint

# generate texts
n_texts = 10000
texts = [''.join([str(randint(0, 9)) for _ in range(10)]) for _ in range(n_texts)]

class TokenDataset(Dataset):
  def __init__(self, tokenizer, texts):
    self.tokenizer = tokenizer
    self.texts = texts

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    # Build an int64 tensor directly instead of creating a float tensor and casting it
    item = {'tokens': torch.tensor(self.tokenizer(self.texts[idx])['input_ids'], dtype=torch.long)}
    return item

dataset = TokenDataset(model.tokenizer, texts)
dataset[0]

{'tokens': tensor([18, 26, 21, 25, 20, 21, 18, 21, 22, 18])}

# Loss Function
I adapted the `lm_cross_entropy_loss` function that is the default for HookedTransformer so that the loss is now the cross entropy between the output logits at `t` and the token at `t-1`.

In [233]:
def prev_token_cross_entropy_loss(
    logits: torch.Tensor,
    tokens: torch.Tensor,
    per_token: bool = False,
) -> torch.Tensor:
    """Cross entropy loss for predicting the PREVIOUS token.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
        per_token (bool, optional): If True, return the per-position loss (the negative log prob of the correct token) with shape [batch, pos-1], since the first position has no previous token to predict. If False, return the mean of those losses. Defaults to False.
    """
    if tokens.device != logits.device:
        tokens = tokens.to(logits.device)
    
    log_probs = F.log_softmax(logits, dim=-1)
    
    # Offset needed because we're predicting the PREVIOUS token (this means the first logit is meaningless)
    predicted_log_probs = log_probs[..., 1:, :].gather(
        dim=-1, index=tokens[..., :-1, None]
    )[..., 0]
    
    if per_token:
        return -predicted_log_probs
    else:
        return -predicted_log_probs.mean()

# update model's loss function
model.loss_fn = prev_token_cross_entropy_loss
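
To make the index shift concrete, here is a minimal plain-Python sketch of the same alignment (logits at position `t` scored against the token at `t-1`). It is an illustration only, not the function used for training; `prev_token_nll` and the toy logits are hypothetical names and values made up for this example.

```python
import math

def prev_token_nll(logits, tokens):
    """Mean negative log-likelihood of tokens[t-1] under the logits at position t.

    logits: list of per-position lists of length d_vocab; tokens: list of ints.
    Position 0 is skipped because it has no previous token.
    """
    losses = []
    for t in range(1, len(tokens)):
        row = logits[t]
        m = max(row)
        # Numerically stable log-sum-exp, then subtract the target logit
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_z - row[tokens[t - 1]])
    return sum(losses) / len(losses)

tokens = [2, 0, 1]
# Toy logits that strongly favor the previous token at positions 1 and 2
# (the row at position 0 is ignored by the loss).
good = [[0.0, 0.0, 0.0], [0.0, 0.0, 10.0], [10.0, 0.0, 0.0]]
# Toy logits that favor the current token instead.
bad = [[0.0, 0.0, 10.0], [10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]
print(prev_token_nll(good, tokens) < prev_token_nll(bad, tokens))
```

A model that copies the current token is penalized heavily under this loss, which is what forces attention to look one position back.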

In [234]:
train_config = HookedTransformerTrainConfig(
  num_epochs=5,
  batch_size=32 
)  

train(  
    model=model,
    config=train_config,
    dataset=dataset
)

Moving model to device:  cuda


Epoch 1 Samples 32 Step 0 Loss 11.036900520324707
Epoch 1 Samples 1632 Step 50 Loss 0.008286873809993267
Epoch 1 Samples 3232 Step 100 Loss 0.0018355099018663168
Epoch 1 Samples 4832 Step 150 Loss 0.0012214584276080132
Epoch 1 Samples 6432 Step 200 Loss 0.0009280472295358777
Epoch 1 Samples 8032 Step 250 Loss 0.0007234661024995148
Epoch 1 Samples 9632 Step 300 Loss 0.0005883136764168739


Epoch 2 Samples 32 Step 0 Loss 0.0005574462120421231
Epoch 2 Samples 1632 Step 50 Loss 0.0004614414647221565
Epoch 2 Samples 3232 Step 100 Loss 0.00039343730895780027
Epoch 2 Samples 4832 Step 150 Loss 0.000334952404955402
Epoch 2 Samples 6432 Step 200 Loss 0.00029160501435399055
Epoch 2 Samples 8032 Step 250 Loss 0.0002537959080655128
Epoch 2 Samples 9632 Step 300 Loss 0.00022657022054772824


Epoch 3 Samples 32 Step 0 Loss 0.00021895331155974418
Epoch 3 Samples 1632 Step 50 Loss 0.00019464342040009797
Epoch 3 Samples 3232 Step 100 Loss 0.00017357338219881058
Epoch 3 Samples 4832 Step 150 Loss 0.00015821577107999474
Epoch 3 Samples 6432 Step 200 Loss 0.00014436972560361028
Epoch 3 Samples 8032 Step 250 Loss 0.00013205688446760178
Epoch 3 Samples 9632 Step 300 Loss 0.0001223777508130297


Epoch 4 Samples 32 Step 0 Loss 0.00011904319399036467
Epoch 4 Samples 1632 Step 50 Loss 0.00010831476538442075
Epoch 4 Samples 3232 Step 100 Loss 0.00010063857916975394
Epoch 4 Samples 4832 Step 150 Loss 9.418081026524305e-05
Epoch 4 Samples 6432 Step 200 Loss 8.750403503654525e-05
Epoch 4 Samples 8032 Step 250 Loss 8.163017628248781e-05
Epoch 4 Samples 9632 Step 300 Loss 7.597107469337061e-05


Epoch 5 Samples 32 Step 0 Loss 7.468429976142943e-05
Epoch 5 Samples 1632 Step 50 Loss 6.948664668016136e-05
Epoch 5 Samples 3232 Step 100 Loss 6.514987762784585e-05
Epoch 5 Samples 4832 Step 150 Loss 6.127375672804192e-05
Epoch 5 Samples 6432 Step 200 Loss 5.786862675449811e-05
Epoch 5 Samples 8032 Step 250 Loss 5.460215106722899e-05
Epoch 5 Samples 9632 Step 300 Loss 5.109476114739664e-05


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)

# Activation Pattern
The attention pattern of the trained model shows very strong attention towards the previous token, which is to be expected when training on previous token identification.

In [242]:
act_txt = "123456789"
act_tokens = model.to_tokens(act_txt)
act_logits, act_cache = model.run_with_cache(act_tokens, remove_batch_dim=True)

attention_pattern = act_cache["pattern", 0, "attn"]
act_str_tokens = model.to_str_tokens(act_txt)

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=act_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:
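
One way to put a number on "strong attention towards the previous token" is to average the attention weight on the first sub-diagonal of the pattern. The sketch below does this in plain Python on a made-up 4-token pattern; `prev_token_attention` and `toy_pattern` are hypothetical, and the values are not taken from the trained model.

```python
def prev_token_attention(pattern):
    """Average attention weight that each query position t >= 1 puts on key t - 1.

    `pattern` is a square nested list indexed [query_pos][key_pos].
    """
    n = len(pattern)
    weights = [pattern[t][t - 1] for t in range(1, n)]
    return sum(weights) / len(weights)

# Hypothetical causal pattern in which each query attends mostly to the previous token.
toy_pattern = [
    [1.00, 0.00, 0.00, 0.00],
    [0.95, 0.05, 0.00, 0.00],
    [0.02, 0.93, 0.05, 0.00],
    [0.01, 0.02, 0.94, 0.03],
]
print(round(prev_token_attention(toy_pattern), 2))
```

Assuming `attention_pattern` has shape `[head, query_pos, key_pos]` after `remove_batch_dim=True`, something like `prev_token_attention(attention_pattern[0].tolist())` should give the corresponding number for the single head here.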


# Hypothesis
For this model to predict the previous token reliably, `Q_t` and `K_t-1` need to have a large dot product, and the attention pattern gathered above supports this. That is all attention has to get right, because the [Output-Value Circuit](https://transformer-circuits.pub/2021/framework/index.html) then just needs to forward the information from the previous token to the output of the current token.

I hypothesize that `Q_t` and `K_t-1` have a large dot product because dot product attention is directly using information from the rotary positional encoding to create an affinity between `t` and `t-1`. `W_Q` and `W_K` shouldn't be learning to add any token-specific information, since the task that has been learned is purely positional. Honestly this seems somewhat trivial, but this is what I've had time to investigate :)
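
The reason rotary can create a position-only affinity is that the dot product of a rotated query and a rotated key depends only on the relative offset between their positions, not on the absolute positions. Below is a minimal plain-Python sketch of that property. The `rotate` helper is hypothetical and pairs adjacent dimensions, which is an assumption of this sketch (libraries may pair dimensions differently); the property holds for either pairing.

```python
import math

def rotate(vec, pos, base=10000.0):
    """Apply a rotary-style rotation for position `pos` to an even-length vector.

    Adjacent dimensions (2i, 2i+1) are treated as a 2D pair and rotated by
    pos * theta_i, with theta_i = base ** (-2i / d). The pairing is an
    assumption of this sketch, not necessarily what a given library does.
    """
    d = len(vec)
    out = []
    for i in range(d // 2):
        angle = pos * base ** (-2 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        out.append(x * math.cos(angle) - y * math.sin(angle))
        out.append(x * math.sin(angle) + y * math.cos(angle))
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Arbitrary pre-rotary query and key vectors with d_head = 8
q = [0.3, -1.2, 0.7, 0.5, -0.4, 0.9, 1.1, -0.6]
k = [1.0, 0.2, -0.8, 0.4, 0.6, -0.3, 0.1, 0.5]

# Same relative offset (query one position after key) at two absolute positions
score_early = dot(rotate(q, 2), rotate(k, 1))
score_late = dot(rotate(q, 9), rotate(k, 8))
print(abs(score_early - score_late) < 1e-9)
```

So if the pre-rotary Q and K are roughly constant across positions, the attention score becomes a function of `t - s` alone, which is exactly what a purely positional task needs.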

# Experiment
To test whether rotary embeddings are contributing the information that attention needs to perform previous token identification, I trained the above model, collected Q and K activations before and after rotary was applied, and compared the outputs at each position using pairwise cosine similarity. I mainly wanted to visualize and analyze the positional information encoded in Q and K post-rotary.

I'm leaving the first token position out of this analysis since its output is meaningless (the model has no previous token to predict at this position).

In [243]:
#@title Cosine Sim Utilities
def pairwise_cosine_similarity(x):
    """This function takes a 2D PyTorch tensor x with shape (n, m) as input, 
    where n is the number of samples and m is the number of features. 
    
    The function returns a symmetric cosine similarity matrix with shape (n, n), 
    where the element at position (i, j) represents the cosine similarity between 
    the i-th and j-th samples.
    """
    # Normalize the input tensor along each row (L2 norm)
    x_normalized = F.normalize(x, p=2, dim=1)
    
    # Compute the cosine similarity matrix by multiplying the normalized tensor with its transpose
    similarity_matrix = torch.matmul(x_normalized, x_normalized.t())
    
    return similarity_matrix
  
def pairwise_sim_cache(cache, component, idx, hook, title=None):
    cache_vals = cache[component, idx, hook]
    pcs = pairwise_cosine_similarity(cache_vals.squeeze()[1:, :])
    imshow(pcs, title=title)

Looking at pairwise similarity between the Q outputs for each token position shows that `W_Q` has not learned to extract any useful information from the input. Instead it has learned to maintain a very consistent output. This way the application of rotary will yield useful information, free of noise, during attention calculation.

In [244]:
pairwise_sim_cache(act_cache, 'q', 0, 'attn', title='Pairwise Similarity of Q Pre-Rotary')

Because the pre-rotation Q output is so consistent, the post-rotation output is very clean! You can see that tokens near each other will be encoded closely due to the rotation being applied by rotary.

(I think a shirt with this pattern would be pretty cool)

In [238]:
pairwise_sim_cache(act_cache, 'rot_q', 0, 'attn', title='Pairwise Similarity of Q Post-Rotary')

The above findings for Q also apply to K:

In [239]:
pairwise_sim_cache(act_cache, 'k', 0, 'attn', title='Pairwise Similarity of K Pre-Rotary')

In [240]:
pairwise_sim_cache(act_cache, 'rot_k', 0, 'attn', title='Pairwise Similarity of K Post-Rotary')

The following output is something that I did not anticipate in my hypothesis, but it makes a lot of sense:

Visualizing cosine similarity between pre-rotary Q and K outputs reveals that `W_Q` and `W_K` have actually learned to introduce an **offset** between each other's outputs. This offset very likely introduces a shift in the post-rotary QK values so that the previous token is predicted. If this offset did not exist and pre-rotary Q and K outputs were nearly identical, the post-rotary Q and K would have their largest dot product at the current token, and the model would always predict the current token. I'll refer to this learned offset as the **QK offset**.

In [241]:
q_cache = act_cache["q", 0, "attn"].squeeze()[1:, :]
k_cache = act_cache["k", 0, "attn"].squeeze()[1:, :]
imshow(pairwise_cosine_similarity(torch.cat((q_cache, k_cache))), title='Similarity Between Q and K Pre-Rotary')
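
The QK offset argument can be illustrated with a toy plain-Python example. It assumes an idealized offset in which the pre-rotary key is the pre-rotary query rotated forward by exactly one position; whether the trained model's offset has this exact form is not established here. `rotate` and `best_key_position` are hypothetical helpers, and `rotate` pairs adjacent dimensions as a simplifying assumption.

```python
import math

def rotate(vec, pos, base=10000.0):
    """Rotary-style rotation of an even-length vector for position `pos`.

    Pairs adjacent dimensions (an assumption of this sketch) and rotates
    pair i by pos * base ** (-2i / d).
    """
    d = len(vec)
    out = []
    for i in range(d // 2):
        angle = pos * base ** (-2 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        out.append(x * math.cos(angle) - y * math.sin(angle))
        out.append(x * math.sin(angle) + y * math.cos(angle))
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def best_key_position(q_pre, k_pre, t):
    """Return the causal key position s <= t with the highest post-rotary score."""
    q_t = rotate(q_pre, t)
    scores = [dot(q_t, rotate(k_pre, s)) for s in range(t + 1)]
    return max(range(t + 1), key=lambda s: scores[s])

# A position-independent pre-rotary query, as the Q similarity plot suggests
q_pre = [0.3, -1.2, 0.7, 0.5, -0.4, 0.9, 1.1, -0.6]

# Identical pre-rotary Q and K: the query at t matches the key at t best
no_offset = best_key_position(q_pre, q_pre, t=5)
# Idealized QK offset: pre-rotary K is Q rotated one position forward
with_offset = best_key_position(q_pre, rotate(q_pre, 1), t=5)
print(no_offset, with_offset)
```

With identical pre-rotary vectors the best match is the current position; with the one-step offset the key at `s` looks like the query at `s + 1`, so the best match moves to `t - 1`.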

TODO: draw the circuit.

# Conclusion

The QK offset seems to point towards the core of how this model predicts the previous token. More rigor is required to prove that the QK offset is doing what has been hypothesized, but the evidence so far seems fairly convincing.

The task of previous token prediction, this model implementation, and the semi-identified circuit seem pretty trivial. Next steps would be to test if QK offsets exist in slightly more complex models, and explore what QK offsets might look like when the task involves processing information that is not purely positional.

I wanted to say that I had a ton of fun working on this problem! Thank you Neel for building the tooling to support open MI and providing beginners with an accessible jumping-off point.