<a href="https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/demo_notebook/EasyTransformer_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")

except ImportError:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")
  from IPython import get_ipython
  ipython = get_ipython()
  # Automatically reload the EasyTransformer code as it is edited, without restarting the kernel
  ipython.run_line_magic("load_ext", "autoreload")
  ipython.run_line_magic("autoreload", "2")
  

Running as a Jupyter notebook - intended for development only!




In [2]:
DEBUG_MODE = False
import plotly.io as pio
if IN_COLAB or not DEBUG_MODE:
  # Plotly figures render in either Colab or VS Code depending on the renderer, which is awkward when developing demos, hence DEBUG_MODE
  pio.renderers.default = "colab"
else:
  pio.renderers.default = "vscode"


In [3]:
import os
if IN_COLAB:
    os.system('pip install git+https://github.com/neelnanda-io/Easy-Transformer.git')

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

# from google.colab import drive
from pathlib import Path
import pickle


import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc
import collections
import copy

# import comet_ml
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets


In [5]:
from easy_transformer.utils import gelu_new, to_numpy, get_corner, lm_cross_entropy_loss # Helper functions
from easy_transformer.hook_points import HookedRootModule, HookPoint # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig
import easy_transformer
from easy_transformer.experiments import ExperimentMetric, AblationConfig, EasyAblation, EasyPatching, PatchingConfig


In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hook Points

A Garcon-style interface: the key component is the `HookPoint` class, a layer that any activation within the model can be wrapped in. A HookPoint acts as an identity function, but lets us attach PyTorch hooks to access and edit the activation passing through it. This means we can take any model and add access points to every interesting activation by wrapping each one in a HookPoint.

There is also a `HookedRootModule` class, a utility class that the root module (the model we run) should inherit from. It provides several utility functions for working with hooks.

The default interface is the `run_with_hooks` method on the root module, which runs a forward pass on the model with a list of hooks, each paired with the name of the hook point it should be attached to for that pass.

The syntax for a hook is `function(activation, hook)`, where `activation` is the activation the hook point wraps and `hook` is the `HookPoint` instance the function is attached to. If the function returns a new activation, that replaces the old one; if it edits the activation in place, the edit persists; and if it returns None without editing anything, the activation is left as is.
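To make these semantics concrete, here is a plain-Python sketch of the hook-point contract: identity by default, hooks called as `fn(activation, hook)`, and a non-None return value replacing the activation. `ToyHookPoint` is a hypothetical illustration only; the real `HookPoint` is a PyTorch layer built on PyTorch hooks, and in-place tensor edits are not modelled here.

```python
class ToyHookPoint:
    """Hypothetical, simplified stand-in for HookPoint (not the library's code).

    Acts as the identity on its input, but runs any attached hook functions
    with the signature fn(activation, hook).
    """

    def __init__(self, name):
        self.name = name
        self.hook_fns = []

    def add_hook(self, fn):
        self.hook_fns.append(fn)

    def __call__(self, activation):
        for fn in self.hook_fns:
            result = fn(activation, self)
            # A non-None return value replaces the activation; None keeps it
            if result is not None:
                activation = result
        return activation


hp = ToyHookPoint("hook_mid")
print(hp(28.0))  # 28.0 (no hooks: identity)

seen = []
hp.add_hook(lambda act, hook: seen.append((hook.name, act)))  # returns None
hp.add_hook(lambda act, hook: act * 2)                        # returns a new value
print(hp(28.0))  # 56.0
print(seen)      # [('hook_mid', 28.0)]
```

Hooks run in the order they were added, so the recording hook sees the value before the doubling hook replaces it.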



## Hook Points Example

Here's a simple example of how to use the classes:

We define a basic network with two layers that each take a scalar input $x$, square it, and add a constant:
$x_0=x$, $x_1=x_0^2+3$, $x_2=x_1^2-4$.

We wrap the input, each layer's output, and the intermediate value of each layer (the square) in a hook point.

In [7]:
from easy_transformer.hook_points import HookedRootModule, HookPoint

class SquareThenAdd(nn.Module):
    def __init__(self, offset):
        super().__init__()
        self.offset = nn.Parameter(torch.tensor(offset))
        self.hook_square = HookPoint()
    
    def forward(self, x):
        # The hook_square doesn't change the value, but lets us access it
        square = self.hook_square(x * x)
        return self.offset + square
    
class TwoLayerModel(HookedRootModule):
    def __init__(self):
        super().__init__()
        self.layer1 = SquareThenAdd(3.)
        self.layer2 = SquareThenAdd(-4.)
        self.hook_in = HookPoint()
        self.hook_mid = HookPoint()
        self.hook_out = HookPoint()

        # We need to call the setup function of HookedRootModule to build an 
        # internal dictionary of modules and hooks, and to give each hook a name
        super().setup()
    
    def forward(self, x):
        # We wrap the input and each layer's output in a hook - they leave the 
        # value unchanged (unless there's a hook added to explicitly change it), 
        # but allow us to access it.
        x_in = self.hook_in(x)
        x_mid = self.hook_mid(self.layer1(x_in))
        x_out = self.hook_out(self.layer2(x_mid))
        return x_out
model = TwoLayerModel()



We can add a cache to save the activation at each hook point.

(There's a custom `run_with_cache` function on the root module as a convenience, which wraps `model.forward` and returns `(model_out, cache_object)`. We could also manually add hooks with `run_with_hooks` that store activations in a global caching dictionary, which is often useful if we only want to store, eg, subsets or functions of some activations.)
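The manual alternative can be sketched in plain Python. `toy_run_with_hooks` below is a hypothetical stand-in that mirrors `TwoLayerModel` and imitates the hook semantics described above (hooks selected by an exact name or by a name filter function); it is not the library's implementation. The caching hook stores only a subset of activations, the two squares, in a global dictionary.

```python
from types import SimpleNamespace


def toy_run_with_hooks(x, fwd_hooks):
    """Plain-Python stand-in (hypothetical) for TwoLayerModel.run_with_hooks.

    fwd_hooks is a list of (name_or_filter, hook_fn) pairs, where
    name_or_filter is either an exact hook name or a function name -> bool.
    """
    def hook_point(name, value):
        for name_or_filter, fn in fwd_hooks:
            if callable(name_or_filter):
                matches = name_or_filter(name)
            else:
                matches = name_or_filter == name
            if matches:
                out = fn(value, SimpleNamespace(name=name))
                if out is not None:
                    value = out
        return value

    x_in = hook_point("hook_in", x)
    x_mid = hook_point("hook_mid", 3.0 + hook_point("layer1.hook_square", x_in * x_in))
    x_out = hook_point("hook_out", -4.0 + hook_point("layer2.hook_square", x_mid * x_mid))
    return x_out


square_cache = {}

def cache_hook(activation, hook):
    # Store the activation and return None, so the forward pass is unchanged
    square_cache[hook.name] = activation

toy_out = toy_run_with_hooks(5.0, [(lambda name: name.endswith("hook_square"), cache_hook)])
print(toy_out)       # 780.0
print(square_cache)  # {'layer1.hook_square': 25.0, 'layer2.hook_square': 784.0}
```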

In [8]:

out, cache = model.run_with_cache(torch.tensor(5.))
print('Model output:', out.item())
for key in cache:
    print(f"Value cached at hook {key}", cache[key].item())

Model output: 780.0
Value cached at hook hook_in 5.0
Value cached at hook layer1.hook_square 25.0
Value cached at hook hook_mid 28.0
Value cached at hook layer2.hook_square 784.0
Value cached at hook hook_out 780.0


We can also use hooks to intervene on activations, eg setting the square computed in layer 2 (`layer2.hook_square`) to zero changes the output to $0 + (-4) = -4$.

In [9]:
def set_to_zero_hook(tensor, hook):
    print(hook.name)
    # Returning a new tensor replaces the activation at this hook point
    return torch.zeros_like(tensor)
print('Output after intervening on layer2.hook_square', 
      model.run_with_hooks(torch.tensor(5.),
                           fwd_hooks = [('layer2.hook_square', set_to_zero_hook)]).item())

layer2.hook_square
Output after intervening on layer2.hook_square -4.0


# Transformer models

We now define a stripped-down transformer. There are helper functions to load in the weights of several families of open-source LLMs: OpenAI's GPT-2, Facebook's OPT and Eleuther's GPT-Neo.

Note: OPT-350M is not supported. It applies the LayerNorms to the *outputs* of each layer, which means we cannot fold the LayerNorm weights and biases into other layers, and supporting it would require a notably different architecture.

**TODO:** Add in GPT-J and GPT-NeoX functionality

The list of supported model names:
```
['gpt2',
 'gpt2-medium',
 'gpt2-large',
 'gpt2-xl',
 'facebook/opt-125m',
 'facebook/opt-1.3b',
 'facebook/opt-2.7b',
 'facebook/opt-6.7b',
 'facebook/opt-13b',
 'facebook/opt-30b',
 'facebook/opt-66b',
 'EleutherAI/gpt-neo-125M',
 'EleutherAI/gpt-neo-1.3B',
 'EleutherAI/gpt-neo-2.7B']
```

# Examples

## Setup

Load in GPT-2 small

In [10]:
model_name = 'gpt2' #@param ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'facebook/opt-125m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 'facebook/opt-66b', 'EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B']
model = EasyTransformer.from_pretrained(model_name).to(device)

Create some reference text to run the models on. Models come with `to_tokens` and `to_str_tokens` methods, which convert text to a tensor of tokens and to a list of individual tokens as strings, respectively. Though GPT-2 was not trained with a beginning-of-sequence (BOS) token, we prepend one (`<|endoftext|>`) by default, since the first token is often used as a "resting position" by inactive attention heads and, as a result, has weird behaviour.

In [11]:
prompt = 'Interpretability is great'
# The model has a method to_str_tokens
print(model.to_str_tokens(prompt, prepend_bos=True))

prompt_2 = 'AI Alignment is great'
# We can go via the to_tokens method
tokens_2 = model.to_tokens(prompt_2, prepend_bos=True)
# to_str_tokens also takes a tensor of tokens as input, though only for a *single* example
print(model.to_str_tokens(tokens_2))

['<|endoftext|>', 'Inter', 'pret', 'ability', ' is', ' great']
['<|endoftext|>', 'AI', ' Al', 'ignment', ' is', ' great']


In [12]:
print('Reference: Hyperparameters for the model')
dataclasses.asdict(model.cfg)

Reference: Hyperparameters for the model


{'n_layers': 12,
 'd_model': 768,
 'n_ctx': 1024,
 'd_head': 64,
 'n_heads': 12,
 'model_name': 'gpt2',
 'd_mlp': 3072,
 'act_fn': 'gelu_new',
 'd_vocab': 50257,
 'eps': 1e-05,
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_local_attn': False,
 'model_family': 'gpt2',
 'checkpoint': None,
 'tokenizer_name': 'gpt2',
 'window_size': None,
 'attn_types': None,
 'init_mode': 'gpt2',
 'normalization_type': 'LNPre',
 'device': 'cuda',
 'attention_dir': 'causal',
 'attn_only': False,
 'seed': 42,
 'initializer_range': 0.02,
 'init_weights': False,
 'scale_attn_by_inverse_layer_idx': False,
 'positional_embedding_type': 'standard'}

## Using the model

The model can be given either text or tokens as an input (text is automatically converted to a `batch_size=1` batch of tokens). 

This time, we'll disable the automatic prepending of a BoS token. Here it doesn't really matter either way.

In [14]:
prompt = 'Hello World!'
print(model.to_str_tokens(prompt, prepend_bos=False))
tokens = model.to_tokens(prompt, prepend_bos=False)
logits_tokens = model(tokens)
logits_text = model(prompt, prepend_bos=False)

['Hello', ' World', '!']


EasyTransformer applies several transformations internally that keep the model computationally equivalent to the original model but make it more interpretable, eg [folding in LayerNorm weights](https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization) and making the unembedding mean zero (the softmax is unchanged when the same constant is added to every logit).
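As a sketch of why this folding is possible (standard algebra, not code from the library): if a LayerNorm with weight $\gamma$ and bias $\beta$ feeds directly into a linear map with weight $W$ and bias $b$, the two affine steps merge into one linear map acting on the centered and rescaled input $\hat{x}$:

$$
W(\gamma \odot \hat{x} + \beta) + b = \left(W \operatorname{diag}(\gamma)\right)\hat{x} + \left(W\beta + b\right)
$$

Only the centering and rescaling then remain as a separate step, which appears to be what `normalization_type: 'LNPre'` in the config above refers to.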

Importantly, as we remove a constant offset from $W_U$, the log probs (and thus the loss) are unchanged, but the logits are translated.
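A quick numerical check of this claim in plain Python: removing a constant offset from $W_U$ adds the same constant to every logit at a given position, and that leaves the log probs unchanged. The `log_softmax` helper is written out by hand for illustration.

```python
import math


def log_softmax(values):
    """Numerically stable log-softmax of a list of floats."""
    m = max(values)
    log_norm = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_norm for v in values]


toy_logits = [2.0, -1.0, 0.5]
shift = 7.25  # a constant added to every logit at this position
shifted_logits = [v + shift for v in toy_logits]

lp = log_softmax(toy_logits)
lp_shifted = log_softmax(shifted_logits)
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(lp, lp_shifted)))  # True
print(shifted_logits == toy_logits)  # False: the raw logits differ
```

This is exactly the pattern in the comparison below: the log probs match the original model, the raw logits do not.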

In [15]:
original_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
easy_logits = model(tokens)
original_model_logits = original_model(tokens).logits

easy_log_probs = F.log_softmax(easy_logits, dim=-1)
original_model_log_probs = F.log_softmax(original_model_logits, dim=-1)
print("Logit shape:", easy_logits.shape)
print('Fraction of log probs the same between easy model and original model:')
print(torch.isclose(original_model_log_probs, easy_log_probs).sum()/easy_log_probs.numel())
print('Fraction of logits the same between easy model and original model:')
print(torch.isclose(original_model_logits, easy_logits).sum()/easy_logits.numel())

Logit shape: torch.Size([1, 3, 50257])
Fraction of log probs the same between easy model and original model:
tensor(1., device='cuda:0')
Fraction of logits the same between easy model and original model:
tensor(0., device='cuda:0')


## Basic Examples

Print the shapes of all activations in the embedding and the first layer (the other layers have identical shapes; just change the index in `blocks.{layer}.`).

**Note:** This cell is a good reference for creating hooks - it's extremely useful to know the shapes of different activations as accessible by each hook!

By convention, each activation has shape batch x position x ..., where the final dimension(s) are d_model, (head_index x d_head) or d_mlp. The exceptions are `hook_attn` (attention patterns), which has shape batch x head_index x query_pos x key_pos; `hook_pos_embed`, which has no batch dimension (position x d_model); and the LayerNorm `hook_scale` activations, whose final dimension is 1.

**Reference:**
```
batch_size=4
n_ctx=50 (sequence length of this input; the model's maximum context is 1024)
d_head=64
d_model=768
d_mlp=3072
n_heads=12
n_layers=12
```
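The shape conventions can also be written down as a small lookup, which is handy for checking a hook before running it. `expected_shapes` is a hypothetical helper (not part of EasyTransformer) covering a subset of hook names; its values match the shapes printed below for `batch_size=4` and a 50-token input.

```python
def expected_shapes(batch, pos, d_model=768, n_heads=12, d_head=64, d_mlp=3072):
    """Hypothetical helper: expected activation shapes for block 0 of GPT-2 small."""
    return {
        "hook_embed": (batch, pos, d_model),
        "hook_pos_embed": (pos, d_model),  # no batch dimension
        "blocks.0.hook_resid_pre": (batch, pos, d_model),
        "blocks.0.ln1.hook_scale": (batch, pos, 1),
        "blocks.0.attn.hook_q": (batch, pos, n_heads, d_head),
        "blocks.0.attn.hook_attn": (batch, n_heads, pos, pos),  # query_pos x key_pos
        "blocks.0.attn.hook_z": (batch, pos, n_heads, d_head),
        "blocks.0.mlp.hook_post": (batch, pos, d_mlp),
        "blocks.0.hook_resid_post": (batch, pos, d_model),
    }


shapes = expected_shapes(batch=4, pos=50)
print(shapes["blocks.0.attn.hook_attn"])  # (4, 12, 50, 50)
```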

In [16]:
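def embed_or_first_layer(name):
    # Keep hooks outside the transformer blocks (embeddings, final LayerNorm) plus everything in block 0
    return not name.startswith('blocks') or name.startswith('blocks.0.')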
def print_shape(tensor, hook):
    print(f'Activation at hook {hook.name} has shape:')
    print(tensor.shape)
random_tokens = torch.randint(1000, 10000, (4, 50))
logits = model.run_with_hooks(random_tokens, fwd_hooks=[(embed_or_first_layer, print_shape)])

Activation at hook hook_embed has shape:
torch.Size([4, 50, 768])
Activation at hook hook_pos_embed has shape:
torch.Size([50, 768])
Activation at hook blocks.0.hook_resid_pre has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.ln1.hook_scale has shape:
torch.Size([4, 50, 1])
Activation at hook blocks.0.ln1.hook_normalized has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.attn.hook_q has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.attn.hook_k has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.attn.hook_v has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.attn.hook_attn has shape:
torch.Size([4, 12, 50, 50])
Activation at hook blocks.0.attn.hook_z has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.hook_attn_out has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.hook_resid_mid has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.ln2.hook_scale has shape:
torch.Size([4, 50,

Print the top corner of every activation in the embedding and first layer

**Note:** This is useful to do as a sanity check when debugging a model, to quickly and roughly compare the new activations to the original activations (without looking at the full enormous tensors)

In [17]:
def print_corner(tensor, hook):
    print(hook.name)
    print(get_corner(tensor))
logits = model.run_with_hooks(tokens, fwd_hooks=[(embed_or_first_layer, print_corner)])

hook_embed
tensor([[[-0.0692, -0.1332,  0.0107],
         [-0.0903,  0.0212,  0.3009],
         [-0.1106, -0.0398,  0.0326]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
hook_pos_embed
tensor([[-0.0134, -0.1920,  0.0095],
        [ 0.0250, -0.0528, -0.0939],
        [ 0.0065, -0.0825,  0.0568]], device='cuda:0',
       grad_fn=<SliceBackward0>)
blocks.0.hook_resid_pre
tensor([[[-0.0826, -0.3252,  0.0201],
         [-0.0653, -0.0316,  0.2071],
         [-0.1041, -0.1223,  0.0894]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
blocks.0.ln1.hook_scale
tensor([[[0.3795],
         [0.2153],
         [0.1962]]], device='cuda:0', grad_fn=<SliceBackward0>)
blocks.0.ln1.hook_normalized
tensor([[[-0.2176, -0.8570,  0.0531],
         [-0.3033, -0.1468,  0.9616],
         [-0.5308, -0.6233,  0.4556]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
blocks.0.attn.hook_q
tensor([[[[ 0.4701, -0.0775,  0.5531],
          [-1.6738,  0.0741, -2.0410],
          [ 0.4965,  0.9504,  0.2968

Cache all activations


In [18]:
logits, cache = model.run_with_cache(tokens)
for name in cache:
    if embed_or_first_layer(name):
        print(name, cache[name].shape)

hook_embed torch.Size([1, 3, 768])
hook_pos_embed torch.Size([3, 768])
blocks.0.hook_resid_pre torch.Size([1, 3, 768])
blocks.0.ln1.hook_scale torch.Size([1, 3, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 3, 768])
blocks.0.attn.hook_q torch.Size([1, 3, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 3, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 3, 12, 64])
blocks.0.attn.hook_attn torch.Size([1, 12, 3, 3])
blocks.0.attn.hook_z torch.Size([1, 3, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 3, 768])
blocks.0.hook_resid_mid torch.Size([1, 3, 768])
blocks.0.ln2.hook_scale torch.Size([1, 3, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 3, 768])
blocks.0.mlp.hook_pre torch.Size([1, 3, 3072])
blocks.0.mlp.hook_post torch.Size([1, 3, 3072])
blocks.0.hook_mlp_out torch.Size([1, 3, 768])
blocks.0.hook_resid_post torch.Size([1, 3, 768])
ln_final.hook_scale torch.Size([1, 3, 1])
ln_final.hook_normalized torch.Size([1, 3, 768])


To save GPU memory, we can cache activations on the CPU instead. Note that this is much slower (roughly 3x in the timings below), since every activation has to be copied off the GPU.

In [19]:
random_tokens = torch.randint(1000, 10000, (1, 300))

print('Run time when copying to the CPU')
%timeit -n 3 logits, cache = model.run_with_cache(random_tokens, device='cpu')
model.reset_hooks()
if torch.cuda.is_available():
    print('Run time when just caching on GPU')
    %timeit -n 3 logits, cache = model.run_with_cache(random_tokens, device='cuda')

Run time when copying to the CPU
63.5 ms ± 4.11 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)
Run time when just caching on GPU
20 ms ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)


## Editing Activations
**To change an activation, add a hook to that HookPoint which returns the new activation**

Pruning attention heads

In [20]:
# Example - prune heads 0, 3 and 7 from layer 3 and heads 8 and 9 from layer 7
layer = 3
head_indices = torch.tensor([0, 3, 7])
layer_2 = 7
head_indices_2 = torch.tensor([8, 9])
def prune_fn_1(z, hook):
    # The shape of the z tensor is batch x pos x head_index x d_head
    z[:, :, head_indices, :] = 0.
    return z
def prune_fn_2(z, hook):
    # The shape of the z tensor is batch x pos x head_index x d_head
    z[:, :, head_indices_2, :] = 0.
    return z
logits = model.run_with_hooks(tokens, fwd_hooks=[(f'blocks.{layer}.attn.hook_z', prune_fn_1),
                                                       (f'blocks.{layer_2}.attn.hook_z', prune_fn_2)])

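Restrict all attention heads to only attend to the current and previous token. The hook zeroes the masked attention weights without renormalizing, so each row of the pattern no longer sums to 1.

**Validation:** The logits for the first 2 positions are the same (they could only attend to the current and previous token anyway), while the logits for the third position differ.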
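The validation follows from which keys each query is allowed to see. Here is a plain-Python sketch (the helper `allowed_keys` is hypothetical) comparing the model's causal mask with the extra restriction `key_pos > query_pos - 2` used in `restrict_attn`.

```python
def allowed_keys(n_ctx, window=None):
    """Key positions each query position may attend to.

    window=None: plain causal attention (key <= query).
    window=2: additionally require key > query - 2, as in restrict_attn.
    """
    result = {}
    for q in range(n_ctx):
        keys = list(range(q + 1))  # causal mask: key <= query
        if window is not None:
            keys = [k for k in keys if k > q - window]
        result[q] = keys
    return result


full = allowed_keys(4)
restricted = allowed_keys(4, window=2)
print(restricted)  # {0: [0], 1: [0, 1], 2: [1, 2], 3: [2, 3]}
print([q for q in range(4) if full[q] == restricted[q]])  # [0, 1]
```

Only queries 0 and 1 keep their full set of keys, so only their logits are guaranteed to be unchanged.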

In [21]:
model.reset_hooks()
def filter_hook_attn(name):
    split_name = name.split('.')
    return (split_name[-1]=='hook_attn')
def restrict_attn(attn, hook):
    # Attn has shape batch x head_index x query_pos x key_pos
    n_ctx = attn.size(-2)
    key_pos = torch.arange(n_ctx)[None, :]
    query_pos = torch.arange(n_ctx)[:, None]
    mask = (key_pos>(query_pos-2)).to(device)
    ZERO = torch.tensor(0.).to(device)
    attn = torch.where(mask, attn, ZERO)
    return attn
text = "GPU go brrrr"
original_logits = model(text)
logits = model.run_with_hooks(text, fwd_hooks=[(filter_hook_attn, restrict_attn)])
print('New logits')
print(get_corner(logits, 3))
print('Original logits')
print(get_corner(original_logits, 3))

New logits
tensor([[[ 7.5261, 11.1214,  7.8919],
         [ 5.5660,  6.0116,  6.5193],
         [10.0751,  6.8902,  3.8754]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
Original logits
tensor([[[ 7.5261, 11.1214,  7.8919],
         [ 5.5660,  6.0116,  6.5193],
         [10.1794,  6.3186,  3.4537]]], device='cuda:0',
       grad_fn=<SliceBackward0>)


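Freezing attention patterns: here we do two runs of the model, first on the original text, caching the attention patterns, and then on the new text, loading the cached patterns. This relies on both prompts having the same number of tokens, so that the cached patterns have the right shape.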


In [22]:
attn_cache = {}
def cache_attn(attn, hook):
    attn_cache[hook.name]=attn

def freeze_attn(attn, hook):
    return attn_cache[hook.name]
text = "Freezing attention is good"
text_2 = "Freezing attention is bad"
logits = model.run_with_hooks(text, fwd_hooks=[(filter_hook_attn, cache_attn)])

logits_2 = model.run_with_hooks(text_2, fwd_hooks=[(filter_hook_attn, freeze_attn)])


## Using Hook Contexts

**Each hook point has a dictionary `hook.ctx` that can be used to store information between runs** - this is useful for keeping running totals, etc 

A running total of times a neuron activation was positive


In [23]:
# We focus on neuron 20 in layer 7
model.reset_hooks()
animal_texts = ['The dog was green', 'The cat was blue', 'The squid was magenta', 'The blobfish was grey']
layer = 7
neuron_index = 20
def running_total_hook(neuron_acts, hook):
    if 'total' not in hook.ctx:
        hook.ctx['total']=0
    print('Neuron acts:', neuron_acts[0, :, neuron_index])
    hook.ctx['total']+=(neuron_acts[0, :, neuron_index]>0).sum().item()
    print('Running total:', hook.ctx['total'])

for animal_text in animal_texts:
    print(model.to_str_tokens(animal_text))
    model.run_with_hooks(animal_text, fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', running_total_hook)])

['<|endoftext|>', 'The', ' dog', ' was', ' green']
Neuron acts: tensor([-0.0097, -0.1189, -0.1593,  0.7287, -0.0855], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 1
['<|endoftext|>', 'The', ' cat', ' was', ' blue']
Neuron acts: tensor([-0.0097, -0.1189, -0.1408,  0.5849, -0.0564], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 2
['<|endoftext|>', 'The', ' squid', ' was', ' mag', 'enta']
Neuron acts: tensor([-0.0097, -0.1189,  0.6551,  0.7568, -0.1282, -0.0406], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 4
['<|endoftext|>', 'The', ' blob', 'fish', ' was', ' grey']
Neuron acts: tensor([-0.0097, -0.1189,  0.0608,  0.4433,  0.7619, -0.0141], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 7


Finding the dataset example that most activates a given neuron


In [24]:
# We focus on neuron 13 in layer 5
model.reset_hooks(clear_contexts=True)
animal_texts = ['The dog was green', 'The cat was blue', 'The squid was magenta', 'The blobfish was grey']
layer = 5
neuron_index = 13
def best_act_hook(neuron_acts, hook, text):
    if 'best' not in hook.ctx:
        hook.ctx['best']=-1e3
    print('Neuron acts:', neuron_acts[0, :, neuron_index])
    if hook.ctx['best']<neuron_acts[0, :, neuron_index].max():
        print(f'Updating best act from {hook.ctx["best"]} to {neuron_acts[0, :, neuron_index].max().item()}')
        hook.ctx['best'] = neuron_acts[0, :, neuron_index].max().item()
        hook.ctx['text'] = text

for animal_text in animal_texts:
    print(model.to_str_tokens(animal_text))
    # Use partial to give the hook access to the relevant text
    model.run_with_hooks(animal_text, fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', partial(best_act_hook, text=animal_text))])
print()
print('Maximally activating dataset example:', model.hook_dict[f'blocks.{layer}.mlp.hook_post'].ctx['text'])
model.reset_hooks(clear_contexts=True)

['<|endoftext|>', 'The', ' dog', ' was', ' green']
Neuron acts: tensor([-0.0064, -0.1598, -0.1656,  0.0594,  0.0503], device='cuda:0',
       grad_fn=<SelectBackward0>)
Updating best act from -1000.0 to 0.059385620057582855
['<|endoftext|>', 'The', ' cat', ' was', ' blue']
Neuron acts: tensor([-0.0064, -0.1598, -0.1520,  0.2058,  0.0560], device='cuda:0',
       grad_fn=<SelectBackward0>)
Updating best act from 0.059385620057582855 to 0.2058115005493164
['<|endoftext|>', 'The', ' squid', ' was', ' mag', 'enta']
Neuron acts: tensor([-0.0064, -0.1598, -0.1667,  0.0659, -0.1564, -0.1610], device='cuda:0',
       grad_fn=<SelectBackward0>)
['<|endoftext|>', 'The', ' blob', 'fish', ' was', ' grey']
Neuron acts: tensor([-0.0064, -0.1598, -0.1666,  0.0968,  0.1661, -0.0204], device='cuda:0',
       grad_fn=<SelectBackward0>)

Maximally activating dataset example: The cat was blue
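The bookkeeping that `best_act_hook` does in `hook.ctx` can be replayed in plain Python on the activations printed above, which makes the result easy to check. An ordinary dictionary stands in for `hook.ctx`, and the values are the rounded activations from the output, not a new run.

```python
# Neuron 13 activations in layer 5, copied (rounded) from the output above
toy_neuron_acts = {
    "The dog was green": [-0.0064, -0.1598, -0.1656, 0.0594, 0.0503],
    "The cat was blue": [-0.0064, -0.1598, -0.1520, 0.2058, 0.0560],
    "The squid was magenta": [-0.0064, -0.1598, -0.1667, 0.0659, -0.1564, -0.1610],
    "The blobfish was grey": [-0.0064, -0.1598, -0.1666, 0.0968, 0.1661, -0.0204],
}

ctx = {}  # plays the role of hook.ctx: it persists across the loop below

for text, acts in toy_neuron_acts.items():
    if "best" not in ctx:
        ctx["best"] = -1e3
    if max(acts) > ctx["best"]:
        ctx["best"] = max(acts)
        ctx["text"] = text

print(ctx["text"], ctx["best"])  # The cat was blue 0.2058
```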


## Fancier Examples

Looking for heads that mostly attend to the previous token


In [25]:
long_text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.'
print('Long text:', long_text)
# We first cache attention patterns
attn_cache = {}
def cache_attn(attn, hook):
    attn_cache[hook.name]=attn
logits = model.run_with_hooks(long_text, fwd_hooks=[(filter_hook_attn, cache_attn)])

# We then go through the cache and find the average attention paid to previous tokens
prev_token_scores = np.zeros((model.cfg.n_layers, model.cfg.n_heads))
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        attn = attn_cache[f"blocks.{layer}.attn.hook_attn"][0, head]
        prev_token_scores[layer, head]=attn.diag(-1).mean().item()

px.imshow(prev_token_scores, 
          x=[f'Head {hi}' for hi in range(model.cfg.n_heads)], 
          y=[f'Layer {i}' for i in range(model.cfg.n_layers)], 
          title='Prev Token Scores', 
          color_continuous_scale='Blues')

Long text: Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
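`attn.diag(-1)` picks out the entries where query position $i$ attends to key position $i-1$, so the score is the average attention each token pays to its predecessor. A plain-Python sketch on two hand-built 4x4 patterns (hypothetical, not taken from GPT-2):

```python
def prev_token_score(pattern):
    """Mean of pattern[i][i - 1] over i >= 1, like attn.diag(-1).mean()."""
    vals = [pattern[i][i - 1] for i in range(1, len(pattern))]
    return sum(vals) / len(vals)


n = 4
# A perfect previous-token head: query i puts all its weight on key i - 1 (query 0 on itself)
prev_head = [[1.0 if k == max(q - 1, 0) else 0.0 for k in range(n)] for q in range(n)]
# A head that always attends to the first token
first_token_head = [[1.0 if k == 0 else 0.0 for k in range(n)] for q in range(n)]

print(prev_token_score(prev_head))         # 1.0
print(prev_token_score(first_token_head))  # 0.3333333333333333
```

Note that the first-token head still scores 1/3 on this short input, because for query 1 the previous token is the first token; on the long text above this effect is diluted.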


[ROME-style](https://rome.baulab.info/) patching for causal tracing: we run the model on two prompts with different answers, eg "Steve Jobs founded" -> " Apple" and "Bill Gates founded" -> " Microsoft". We then patch layer outputs or the residual stream at specific layers and positions from one run into the other, and see which patches significantly shift the answer from " Apple" to " Microsoft".

In [26]:
prompt_1 = 'Bill Gates founded'
response_1 = ' Microsoft'
logit_index_1 = model.to_tokens(response_1)[0][-1]
print(model.to_str_tokens(prompt_1))
prompt_2 = 'Steve Jobs founded'
response_2 = ' Apple'
logit_index_2 = model.to_tokens(response_2)[0][-1]
print(model.to_str_tokens(prompt_2))

logits_1, uncorrupted_cache = model.run_with_cache(prompt_1)

uncorrupted_logits = model(prompt_2)
uncorrupted_log_probs = F.log_softmax(uncorrupted_logits, dim=-1)
print('Uncorrupted log prob for', response_1, uncorrupted_log_probs[0, -1, logit_index_1].item())
print('Uncorrupted log prob for', response_2, uncorrupted_log_probs[0, -1, logit_index_2].item())

# Patch the residual stream from the Bill Gates run into the Steve Jobs run,
# at the start of layer 7, at position 1 (the 'Bill'/'Steve' token; position 2 is ' Gates'/' Jobs')
layer = 7
position = 1

def patch_resid_pre(resid_pre, hook):
    uncorrupted_resid_pre = uncorrupted_cache[hook.name]
    # Overwrite the residual stream at the patched position with the Bill Gates run's value
    resid_pre[:, position] = uncorrupted_resid_pre[:, position]
    return resid_pre

corrupted_logits = model.run_with_hooks(prompt_2, 
                    fwd_hooks=[(f'blocks.{layer}.hook_resid_pre', patch_resid_pre)])
corrupted_log_probs = F.log_softmax(corrupted_logits, dim=-1)
print('Corrupted (Residual) log prob for', response_1, corrupted_log_probs[0, -1, logit_index_1].item())
print('Corrupted (Residual) log prob for', response_2, corrupted_log_probs[0, -1, logit_index_2].item())

['<|endoftext|>', 'Bill', ' Gates', ' founded']
['<|endoftext|>', 'Steve', ' Jobs', ' founded']
Uncorrupted log prob for  Microsoft -2.841726303100586
Uncorrupted log prob for  Apple -0.4552706778049469
Corrupted (Residual) log prob for  Microsoft -2.785998821258545
Corrupted (Residual) log prob for  Apple -0.4480985999107361


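We can also patch the post-activation neurons (`mlp.hook_post`) of MLP layers 0 to 6 (`layer_start <= layer < layer_end`) at every position, which fully determines those layers' outputs. This time, rather than giving a hook name, we give a Boolean function that filters for the names of those hooks.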

In [27]:
layer_start = 0
layer_end = 7

def patch_mlp_post(mlp_post, hook):
    return uncorrupted_cache[hook.name]

def filter_middle_mlps(name):
    split_name = name.split('.')
    if split_name[-1]=='hook_post':
        layer = int(split_name[1])
        return (layer_start<=layer<layer_end)
    return False

corrupted_logits = model.run_with_hooks(prompt_2, 
                    fwd_hooks=[(filter_middle_mlps, patch_mlp_post)])
corrupted_log_probs = F.log_softmax(corrupted_logits, dim=-1)
print('Corrupted (MLP) log prob for', response_1, corrupted_log_probs[0, -1, logit_index_1].item())
print('Corrupted (MLP) log prob for', response_2, corrupted_log_probs[0, -1, logit_index_2].item())

Corrupted (MLP) log prob for  Microsoft -1.0558686256408691
Corrupted (MLP) log prob for  Apple -3.8480467796325684
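To see exactly which hook points the filter selects, the same filtering logic can be run on a list of hook names built from the naming scheme printed earlier (the list is constructed by hand here, not read from the model). It confirms that `layer_start <= layer < layer_end` picks layers 0 to 6.

```python
layer_start, layer_end = 0, 7

def filter_middle_mlps(name):
    # Same logic as the cell above: mlp.hook_post in layers [layer_start, layer_end)
    split_name = name.split('.')
    if split_name[-1] == 'hook_post':
        layer = int(split_name[1])
        return layer_start <= layer < layer_end
    return False

# Hand-built hook names for GPT-2 small (12 layers), plus two non-matching examples
names = [f'blocks.{l}.mlp.hook_post' for l in range(12)]
names += ['blocks.0.hook_resid_post', 'ln_final.hook_scale']
selected = [n for n in names if filter_middle_mlps(n)]
print(len(selected), selected[0], selected[-1])  # 7 blocks.0.mlp.hook_post blocks.6.mlp.hook_post
```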


Looking for [induction heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html), by feeding in a random sequence of tokens repeated twice and looking for heads that attend from a second copy of a token to the token just after the first copy.

In [28]:
seq_len = 100
rand_tokens = torch.randint(1000, 10000, (4, seq_len))
rand_tokens_repeat = einops.repeat(rand_tokens, 'batch pos -> batch (2 pos)').to(device)

induction_scores_array = np.zeros((model.cfg.n_layers, model.cfg.n_heads))
def calc_induction_score(attn_pattern, hook):
    # Pattern has shape [batch, index, query_pos, key_pos]
    induction_stripe = attn_pattern.diagonal(1-seq_len, dim1=-2, dim2=-1)
    induction_scores = einops.reduce(induction_stripe, 'batch index pos -> index', 'mean')
    # Store the scores in a common array
    induction_scores_array[hook.layer()] = induction_scores.detach().cpu().numpy()
    
def filter_attn_hooks(hook_name):
    split_name = hook_name.split('.')
    return split_name[-1]=='hook_attn'

induction_logits = model.run_with_hooks(rand_tokens_repeat, fwd_hooks=[(filter_attn_hooks, calc_induction_score)])
px.imshow(induction_scores_array, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues')
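Why the offset `1 - seq_len`? For a query in the second copy at position `q`, the earlier copy of the same token sits at `q - seq_len`, and the token right after it sits at `q - seq_len + 1`. A plain-Python sketch with a hypothetical 3-token sequence lists the (query, key) pairs on that diagonal.

```python
toy_seq_len = 3
toy_tokens = ['A', 'B', 'C'] * 2  # stand-ins for random token ids, repeated twice

offset = 1 - toy_seq_len  # the diagonal used in calc_induction_score
stripe = []
for query_pos in range(toy_seq_len, 2 * toy_seq_len):
    key_pos = query_pos + offset  # entries on diagonal(offset) satisfy key = query + offset
    stripe.append((query_pos, key_pos))
    # The key is the token just after the first copy of the current token,
    # ie the token an induction head should predict next
    print(query_pos, toy_tokens[query_pos], '->', key_pos, toy_tokens[key_pos])
```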

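**Validation:** We can ablate the top few heads by this metric (those with an induction score above 0.8) and show that performance drops substantially. Note that the printed "loss" is actually the mean log prob of the correct next token on the repeated half, so a more negative value means worse performance.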

In [29]:
induction_logits = model(rand_tokens_repeat)
induction_log_probs = F.log_softmax(induction_logits, dim=-1)
induction_pred_log_probs = torch.gather(induction_log_probs[:, :-1], -1, rand_tokens_repeat[:, 1:, None])[..., 0]
print('Original loss on repeated sequence:', induction_pred_log_probs[:, seq_len:].mean())

# Mask out the heads with a high induction score
attn_head_mask = induction_scores_array>0.8

def prune_attn_heads(value, hook):
    # Value has shape [batch, pos, index, d_head]
    mask = attn_head_mask[hook.layer()]
    value[:, :, mask] = 0.
    return value

def filter_value_hooks(name):
    return name.split('.')[-1]=='hook_v'

ablated_logits = model.run_with_hooks(rand_tokens_repeat, fwd_hooks=[(filter_value_hooks, prune_attn_heads)])
ablated_log_probs = F.log_softmax(ablated_logits, dim=-1)
ablated_pred_log_probs = torch.gather(ablated_log_probs[:, :-1], -1, rand_tokens_repeat[:, 1:, None])[..., 0]
print('Loss on repeated sequence without induction heads:', ablated_pred_log_probs[:, seq_len:].mean())

px.imshow(attn_head_mask, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues', title='Mask').show()

Original loss on repeated sequence: tensor(-0.1913, device='cuda:0', grad_fn=<MeanBackward0>)
Loss on repeated sequence without induction heads: tensor(-6.2510, device='cuda:0', grad_fn=<MeanBackward0>)


# Further Examples + Features

## Ablation experiments

We provide a wrapper to facilitate ablation experiments.

An `EasyAblation` object is the combination of:
* The `EasyTransformer` model to be ablated
* An `AblationConfig` object that stores all the parameters of the ablation
* An `ExperimentMetric` object that defines how we measure the effect of the ablation. This can be the loss on a given dataset or the attention score of a specific head.

Here, we define a metric function that takes the model and the dataset as inputs and outputs a tensor. You can use `model.run_with_hooks()` in the metric function, but you *have* to pass the option `reset_hooks_start=False`, otherwise the ablation hooks will be ignored.

You can set `target_module` to "mlp", "attn_layer" or "attn_head".
If you choose "attn_head", you can define which part of the head computation to ablate with `head_circuit` ("z", "q", "v", "k", "attn", "attn_scores").

The supported ablation types are `mean`, `zero`, `neg` and `custom`. For mean ablations, you can specify a `mean_dataset` in the config that can differ from the one used in the metric.

The `verbose` option prints all the experiment parameters before running the ablations.
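As a rough mental model of the `zero` and `mean` types, here is a NumPy sketch of what they do to a single head's output. `ablate_head` is a hypothetical helper, and the `[batch, pos, n_heads, d_head]` layout is an assumption (it matches the value shape noted in the hook used earlier). `neg` isn't sketched, since its exact behaviour isn't spelled out here.

```python
import numpy as np

def ablate_head(z, head, abl_type, mean_z=None):
    """Return a copy of z with one head's output replaced (hypothetical helper).

    z: [batch, pos, n_heads, d_head]. mean_z: per-position mean over a (possibly
    different) mean dataset, shape [pos, n_heads, d_head]; used only for "mean".
    """
    out = z.copy()
    if abl_type == "zero":
        out[:, :, head] = 0.0
    elif abl_type == "mean":
        if mean_z is None:
            mean_z = z.mean(axis=0)  # fall back to the mean over this batch
        out[:, :, head] = mean_z[:, head]  # broadcasts over the batch dimension
    else:
        raise ValueError(f"abl_type not covered by this sketch: {abl_type}")
    return out

rng = np.random.default_rng(0)
z_demo = rng.normal(loc=3.0, size=(4, 5, 12, 8))  # activations centred around 3, not 0
zeroed = ablate_head(z_demo, head=1, abl_type="zero")
meaned = ablate_head(z_demo, head=1, abl_type="mean")
typical = z_demo[:, :, 1].mean(axis=0)
# Average distance from the head's typical activation after each ablation:
# large for zero ablation, essentially zero for mean ablation
print(np.abs(zeroed[:, :, 1] - typical).mean(), np.abs(meaned[:, :, 1] - typical).mean())
```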

In [30]:
def induction_loss(model, dataset):
    # Mean log prob of the correct next token on the repeated half of each sequence
    # (higher is better, despite the name)
    induction_logits = model(dataset)
    induction_log_probs = F.log_softmax(induction_logits, dim=-1)
    induction_pred_log_probs = torch.gather(induction_log_probs[:, :-1], -1, dataset[:, 1:, None])[..., 0]
    return induction_pred_log_probs[:, seq_len:].mean()

metric = ExperimentMetric(metric=induction_loss, dataset=rand_tokens_repeat, relative_metric=True)
config = AblationConfig(abl_type="zero", target_module="attn_head", head_circuit="z", cache_means=True, verbose=True)
abl = EasyAblation(model, config, metric)
result = abl.run_ablation()

px.imshow(result, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues', title='Induction Score Variation after Ablation').show()

--- AblationConfig: ---
* target_module: attn_head
* head_circuit: z
* layers: all
* heads: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
* dataset: tensor([1145, 7879, 3717, 5613, 7860, 7867, 9639, 6182, 2987, 2705],
       device='cuda:0') ... 
tensor([7796, 5309, 6540, 2826, 3098, 9550, 7882, 4919, 7166, 5802],
       device='cuda:0') ... 
tensor([5909, 5771, 7850, 8157, 7896, 8485, 3135, 1822, 2248, 8013],
       device='cuda:0') ... 
* verbose: True
* beg_layer: 0
* end_layer: 12
* abl_type: zero
* mean_dataset: None
* cache_means: True
* compute_means: False
* abl_fn: <function zero_fn at 0x7f97a46225f0>



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:06<00:00,  1.79it/s]


You can also use the `EasyAblation` object to generate ablation hooks without running the ablations. This is useful when you want to ablate several heads at once. The generated hooks respect the configuration used to create the `EasyAblation` object.

Here we reproduce the previous result, where we ablated the induction heads by replacing their activations with zero. Notice that the combined effect of ablating the 5 heads is much greater than the sum of their individual effects.
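One possible intuition for this (a toy illustration, not something verified for GPT-2 here): if heads are partly redundant, e.g. one can back up another, ablating either alone changes little, while ablating both removes the behaviour entirely. `toy_output` is a hypothetical stand-in where the output is the max of two "heads".

```python
def toy_output(head_a, head_b):
    # Hypothetical: either head alone is enough to produce the behaviour
    return max(head_a, head_b)

clean = toy_output(1.0, 1.0)
drop_a = clean - toy_output(0.0, 1.0)     # ablate head a only
drop_b = clean - toy_output(1.0, 0.0)     # ablate head b only
drop_both = clean - toy_output(0.0, 0.0)  # ablate both together
print(drop_a + drop_b, drop_both)  # 0.0 1.0
```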

In [31]:
model.reset_hooks()
for (l,h) in [(5,1), (5,5), (6,9), (7,2), (7,10)]:
    hook_name, hook = abl.get_hook(l,h)
    model.add_hook(hook_name, hook)
print(f"Loss on the repeated random token after zero-ablations of the induction heads {induction_loss(model, rand_tokens_repeat)}")

Loss on the repeated random token after zero-ablations of the induction heads -6.250967979431152


However, replacing an activation with zero is a rather unnatural intervention, since zero can be far from the head's typical activation. Instead, we can replace the activation with its mean over the dataset: all effects specific to a particular sample are washed out, but the head's average contribution is kept. The drop is still significant, but smaller than with zero ablation.

In [32]:
mean_abl_config = AblationConfig(abl_type="mean", target_module="attn_head",head_circuit="z", cache_means=True)
mean_abl = EasyAblation(model, mean_abl_config, metric)
model.reset_hooks()
for (l,h) in [(5,1), (5,5), (6,9), (7,2), (7,10)]:
    hook_name, hook = mean_abl.get_hook(l,h)
    model.add_hook(hook_name, hook)
    
print(f"Loss on the repeated random token after mean-ablations of the induction heads {induction_loss(model, rand_tokens_repeat)}")

Loss on the repeated random token after mean-ablations of the induction heads -1.8151707649230957


### Custom ablation function

When using `abl_type="custom"`, you can pass an arbitrary function as `abl_fn`. It must take as input the normal output of the module, its mean activation and the hook object, and return a tensor of the same shape as the normal output.

The mean has the same shape as the normal output; it is constant along the batch dimension.

If `target_module="attn_head"`, the output is that of a single head, with shape `(batch, seq_len, head_dim)`, except for the attention scores and attention pattern, which have shape `(batch, seq_len, seq_len)`.

In the example below, we reflect the activation about its mean, i.e. we return `mean - (z - mean)`. This reverses the head's contribution without pushing the activation as far out of distribution as naively flipping its sign would.
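A quick NumPy check of that claim: reflecting about the mean keeps the activation at the same distance from the mean, while flipping the sign moves it much further away whenever the mean is far from zero. `reflect_about_mean` computes the same thing as `sym_mean`; the mean value of 5.0 is an arbitrary made-up number.

```python
import numpy as np

def reflect_about_mean(z, mean):
    # Same computation as sym_mean: mean - (z - mean) == 2 * mean - z
    return mean - (z - mean)

rng = np.random.default_rng(0)
mean_act = np.full(8, 5.0)              # hypothetical mean activation, far from zero
z_act = mean_act + rng.normal(size=8)   # one sample's activation near that mean

reflected = reflect_about_mean(z_act, mean_act)
negated = -z_act

# The reflection preserves the distance to the mean; the sign flip does not
print(np.linalg.norm(reflected - mean_act), np.linalg.norm(z_act - mean_act))
print(np.linalg.norm(negated - mean_act))
```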

In [33]:

def sym_mean(z, mean, hook): 
    return mean-(z-mean)

metric = ExperimentMetric(induction_loss, rand_tokens_repeat, relative_metric=True)
config = AblationConfig(abl_type="custom", abl_fn=sym_mean, target_module="attn_head", head_circuit="v", cache_means=True, verbose=False)
abl = EasyAblation(model, config, metric)
result = abl.run_ablation()
fig = px.imshow(result, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues', title='Induction Score Variation after Custom Ablation')
# fig.update_xtick()
fig.show()


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00,  3.52it/s]


## Patching experiments

We can also run patching experiments, where we take activations from a source dataset and copy them into the model while it processes the target dataset. We can then measure which modules cause the model to change its output.
Similarly to the `EasyAblation`, we use an `EasyPatching` object that depends on a `PatchingConfig` and an `ExperimentMetric`.

In the example below, we locate which heads influence the next-token prediction at "founded", i.e. which ones can make it go from "Microsoft" to "Apple" when the sentence is "Bill Gates founded". Head 10.0 (layer 10, head 0) seems to play the major role here: by changing only its activation, we can make the next-token prediction switch from "Microsoft" to "Apple".
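The idea behind patching can be seen in a toy NumPy network (hypothetical random weights, not a transformer): run the source input and cache an intermediate activation, then rerun on the target input with that activation overwritten. In this toy, the output depends on the input only through `hidden`, so patching transfers the source behaviour completely; in a real transformer, the residual stream carries information around each head, so patching a single head usually transfers it only partially.

```python
import numpy as np

# Toy two-layer network (hypothetical): hidden = W1 @ x, output = W2 @ relu(hidden)
rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 3))
W2 = rng.normal(size=(2, 4))

def toy_run(x, patch_hidden=None):
    hidden = W1 @ x
    if patch_hidden is not None:
        hidden = patch_hidden  # overwrite the activation with a cached one
    return W2 @ np.maximum(hidden, 0.0), hidden

x_source = rng.normal(size=3)
x_target = rng.normal(size=3)

out_source, cached_hidden = toy_run(x_source)          # 1. cache source activations
out_target, _ = toy_run(x_target)                       # 2. clean run on the target
out_patched, _ = toy_run(x_target, cached_hidden)       # 3. target run with the patch

print(np.allclose(out_patched, out_source))  # True
```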

In [34]:
from easy_transformer.experiments import EasyPatching, PatchingConfig

source_facts = ["Steve Jobs founded Apple", "Bill Gates founded Microsoft"]
target_facts = ["Bill Gates founded Microsoft", "Steve Jobs founded Apple"]

source_labels = [ " Apple", " Microsoft"]
target_labels = [ " Microsoft"," Apple"]

# Token ids of the answers (not logits, despite the variable names)
source_logits = model.to_tokens(source_labels).squeeze()
target_logits = model.to_tokens(target_labels).squeeze()

tokens_pos = [2,2] # The position of "founded" in the target sentences, where to get the next token prediction

def fact_transfer_score(model, target_dataset):
    logits = model(target_dataset)
    log_probs = F.log_softmax(logits, dim=-1)
    logit_diff = (log_probs[torch.arange(len(target_logits)),tokens_pos,target_logits] - # logit target - logit source (positive by default)
                  log_probs[torch.arange(len(source_logits)),tokens_pos,source_logits])

    return logit_diff.mean() 

metric = ExperimentMetric(fact_transfer_score, target_facts, relative_metric=False)
config = PatchingConfig(source_dataset=source_facts, target_dataset=target_facts, target_module="attn_head", head_circuit="v", cache_act=True, verbose=False)
patching = EasyPatching(model, config, metric)
result = patching.run_patching()
px.imshow(result, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues', title='Absolute Log Logit Prob Difference After Patching').show()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00,  3.90it/s]


We can be more precise and patch only at certain token positions. For that, we can use a custom patching function. Below we show that patching only at the last-name position is enough to recover the previous plot.
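Patching at specific positions just means overwriting the chosen positions of the activation and leaving the rest alone. `patch_positions` is a hypothetical NumPy helper generalising `patch_last_name` to a list of positions. Unlike `patch_last_name`, it copies its input rather than modifying it in place, which makes it safer to reuse cached tensors.

```python
import numpy as np

def patch_positions(z, source_act, positions):
    """Copy source activations into z at the given positions only (hypothetical helper).

    z, source_act: arrays of shape [batch, pos, d_head]; positions: list of ints.
    """
    z = z.copy()
    z[:, positions, :] = source_act[:, positions, :]
    return z

target_z = np.zeros((2, 4, 3))
source_z = np.ones((2, 4, 3))
patched = patch_positions(target_z, source_z, [1])
print(patched[0, :, 0])  # [0. 1. 0. 0.]
```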

In [35]:
def patch_last_name(z, source_act, hook):
    z[:, 1, :] = source_act[:, 1, :] # We patch at the token of the last name 
    return z

config = PatchingConfig(
    source_dataset=source_facts,
    target_dataset=target_facts,
    patch_fn=patch_last_name,
    target_module="attn_head",
    head_circuit="v",
    cache_act=True,
    verbose=False,
)
patching = EasyPatching(model, config, metric)
result = patching.run_patching()
px.imshow(result, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues', title='Log Logit Prob difference after Patching').show()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00,  3.82it/s]


## Loading Checkpointed Models
Researchers at the Stanford Center for Research on Foundation Models kindly [created and open sourced 5 training runs of GPT-2 Small and GPT-2 Medium](https://huggingface.co/stanford-crfm), with 600 checkpoints taken during training. These can be loaded in via the same interface as above.

These are called `stanford-gpt2-small-A` and so on, with `small` or `medium` and `A`, `B`, `C`, `D` or `E` as the possible options.

You can see the available checkpoints [here](https://huggingface.co/stanford-crfm/alias-gpt2-small-x21/tree/main) (each checkpoint has a separate Git branch).
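The naming pattern described above can be enumerated in a couple of lines. This only builds the name strings; it doesn't check that each one is accepted by `from_pretrained`.

```python
# Two sizes x five training runs, following the pattern stanford-gpt2-{size}-{run}
stanford_names = [
    f"stanford-gpt2-{size}-{run}"
    for size in ("small", "medium")
    for run in "ABCDE"
]
print(len(stanford_names))  # 10
print(stanford_names[0], stanford_names[-1])  # stanford-gpt2-small-A stanford-gpt2-medium-E
```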

In [36]:

print_checkpoints = False #@param [False, True]
if print_checkpoints:
    print(EasyTransformer.STANFORD_CRFM_CHECKPOINTS)
else:
    print("Available checkpoints - it's long, so toggle the flag to print")

Available checkpoints - it's long, so toggle the flag to print


### Looking for induction heads
We can use this to analyse whether models contain induction heads at different points in training (by checking whether they can predict repeated sequences of random tokens), and we can see something like a phase change early in training.
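The per-token log probs here come from `torch.gather`. Below is a NumPy sketch of the same computation (via `np.take_along_axis`) on a hand-built, hypothetical `log_probs` array that mimics a model with working induction heads: uniform guesses on the first copy, and probability 0.99 on the correct token wherever the continuation can be copied from the first copy.

```python
import numpy as np

def predicted_log_probs(log_probs, tokens):
    """Log prob assigned to each actual next token, for one sequence.

    log_probs: [seq, vocab]; tokens: [seq] ints. NumPy analogue of
    torch.gather(log_probs[:-1], -1, tokens[1:, None])[:, 0].
    """
    return np.take_along_axis(log_probs[:-1], tokens[1:, None], axis=-1)[:, 0]

rng = np.random.default_rng(0)
vocab_size, half_len = 50, 10
first = rng.integers(0, vocab_size, size=half_len)
toks = np.concatenate([first, first])  # random tokens, repeated once

lp = np.full((2 * half_len, vocab_size), np.log(1.0 / vocab_size))  # uniform guesses...
for p in range(half_len, 2 * half_len - 1):
    # ...except where an ideal induction head can copy the continuation
    lp[p] = np.log(np.full(vocab_size, 0.01 / (vocab_size - 1)))
    lp[p, toks[p + 1]] = np.log(0.99)

plp = predicted_log_probs(lp, toks)
# First copy is about log(1/vocab_size); the repeated copy is about log(0.99)
print(plp[:half_len].mean(), plp[half_len:].mean())
```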

In [37]:

plps = {}
tokens = torch.randint(1000, 20000, (1, 100))
tokens = torch.cat([tokens, tokens], dim=1).to(device)
for check in [1000, 2500, 5000]:
    checkpointed_model = EasyTransformer.from_pretrained('stanford-gpt2-small-E', checkpoint=check)
    logits = checkpointed_model(tokens)
    log_probs = F.log_softmax(logits, dim=-1)
    plp = torch.gather(log_probs[:, :-1], -1, tokens[:, 1:, None])[0, :, 0]
    plps[check] = plp.detach().cpu().numpy()
px.line(plps).show()


## Visualisations

Visualisations are an extremely important tool for doing good interpretability work - fundamentally, neural networks are complex, high-dimensional objects, and we want to decompose and understand them.

It's important to get as close as you can to the ground truth of what's really going on. It's easy to fool yourself and shoot yourself in the foot, so summary statistics alone are rarely enough - which is where data visualisation techniques come in!

Visualisations are not currently planned to be part of EasyTransformer, but Anthropic released a very rough library called PySvelte for doing nice, transformer-relevant visualisations within Python. A slightly less rough fork can be found here: https://github.com/neelnanda-io/pysvelte (Note: this library is under active development by a friend of mine and will hopefully be much more stable and usable by October-ish. I don't recommend putting significant effort into understanding how to edit and write your own components unless you already have a bunch of webdev experience.) Credit to Oliver Balfour for helping me figure out how to get it to work!

### Installation

In [None]:
%%capture
# Install the right version of node - this is needed to install Svelte which is used to build components
# v16, an older version, seems to work more reliably than v18 on the systems I've tried it on.
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
# Get an up-to-date PySvelte, deleting old versions if present
!touch PySvelte
!rm -r PySvelte
!git clone https://github.com/neelnanda-io/PySvelte.git
import sys
sys.path.append('/content/PySvelte')

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

In [None]:
import pysvelte
import numpy as np

In [None]:
# Running pysvelte the first time will re-compile the Svelte UI code, which might take a while
pysvelte.Hello(name='World')

pysvelte components appear to be unbuilt or stale
Building pysvelte components with webpack...

### Visualising Attention Patterns

A component to visualise attention patterns over some text. This is a particularly hard problem, as attention patterns are rank 3 tensors, with destination_pos, source_pos and num_heads dimensions.

This plots the attention pattern as a dest_pos x source_pos grid in the top left, averaged across the heads. Along the bottom, it shows the tokenized text, with each token highlighted according to the average attention paid to it.

Each head gets a colour, and by default the colours are averaged, but we can hover over or click on a head to focus on just that head's colour.

If we hover over or click on a token, it instead shades the other tokens according to the attention paid from that token to them.

There's a toggle to flip the token view to show the tokens attending TO the current token (i.e. later tokens in the sequence attending back to it).
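The cells that follow reorder the cached pattern from `[num_heads, dest_pos, src_pos]` to `[dest_pos, src_pos, num_heads]` with einops before passing it to `AttentionMulti`. Here is a NumPy sketch of that reordering and of the head-averaged grid, on a made-up causal attention pattern (random scores, softmaxed row by row).

```python
import numpy as np

# Hypothetical causal attention pattern for one prompt: [n_heads, dest_pos, src_pos]
rng = np.random.default_rng(0)
n_heads, n_pos = 12, 6
scores = rng.normal(size=(n_heads, n_pos, n_pos))
causal = np.tril(np.ones((n_pos, n_pos), dtype=bool))  # attend only to earlier positions
scores = np.where(causal, scores, -np.inf)
pattern = np.exp(scores - scores.max(axis=-1, keepdims=True))
pattern /= pattern.sum(axis=-1, keepdims=True)          # each row sums to 1

# Same reordering as einops "num_heads dest_pos src_pos -> dest_pos src_pos num_heads"
pattern_for_vis = np.transpose(pattern, (1, 2, 0))
head_avg_grid = pattern_for_vis.mean(axis=-1)  # [dest_pos, src_pos], averaged over heads
print(pattern_for_vis.shape, np.allclose(head_avg_grid.sum(axis=-1), 1.0))  # (6, 6, 12) True
```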

In [None]:
model = EasyTransformer.from_pretrained('gpt2')
vis_text = "Help, I live in three dimensions but need to interact with models with too many dimensions!! Help, I live in three dimensions but need to interact with models with too many dimensions!"
logits, vis_cache = model.run_with_cache(vis_text)

In [None]:
layer = 5
attn_pattern = einops.rearrange(vis_cache[f'blocks.{layer}.attn.hook_attn'][0], 
                                "num_heads dest_pos src_pos -> dest_pos src_pos num_heads") # Indexing into shape [batch, n_heads, dest_pos, src_pos]

tokenized_text = model.to_str_tokens(vis_text)
html_object = pysvelte.AttentionMulti(tokens=tokenized_text, attention=attn_pattern, head_labels=None)
html_object.show()

pysvelte components appear to be unbuilt or stale
Building pysvelte components with webpack...

Earlier we saw that layer 5 contained some induction heads - can you figure out which ones they are from the diagram above?

(Hint: What should the attention pattern visualised as a grid for each head look like?)


<details> <summary> Answer: </summary> Heads 1 and 5 strongly, head 0 weakly. Weirdly, head 8 seems somewhat induction-y here, but wasn't at all earlier - I'm not sure what's happening here, if you figure it out then please let me know! </details>
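To check an answer numerically rather than by eye, one option is an induction score: on a sequence made of one chunk repeated twice (like `rand_tokens_repeat` earlier), an induction head at position `i` of the second copy should attend to the token *after* the earlier copy of the current token, i.e. to `i - period + 1`. `induction_score` is a hypothetical helper; to apply it to real patterns you would need to set `period` to the chunk length and account for any prepended BOS token.

```python
import numpy as np

def induction_score(pattern, period):
    """Mean of pattern[i, i - period + 1] over the second copy (hypothetical helper).

    pattern: [dest_pos, src_pos] for one head; assumes exactly two copies of a chunk.
    """
    dest = np.arange(period, 2 * period)
    return pattern[dest, dest - period + 1].mean()

period = 5
n = 2 * period
ideal = np.zeros((n, n))
ideal[np.arange(period), 0] = 1.0                              # first copy: nothing to copy yet
ideal[np.arange(period, n), np.arange(1, period + 1)] = 1.0    # i -> i - period + 1
prev_token = np.eye(n, k=-1)                                   # i -> i - 1
prev_token[0, 0] = 1.0
print(induction_score(ideal, period), induction_score(prev_token, period))  # 1.0 0.0
```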

### Visualising Neuron Activations

We can also plot neuron activations over text: we input a list of the tokens in the text and a 1D array of activations, one per token.

Positive activations are coloured green and negative ones red, and the activations are automatically normalised to [-1, 1] (the max and min are printed at the top).

(This currently only supports activations for a single neuron, though it wouldn't be too hard to extend.)

In [None]:
layer = 7
neuron = 124
neuron_activations = vis_cache[f"blocks.{layer}.mlp.hook_post"][0, :, neuron] # Indexing into shape [batch, pos, d_mlp]

html_object = pysvelte.TextSingle(tokens=tokenized_text, activations=neuron_activations, neuron_name='Test neuron')
html_object.show()

pysvelte components appear to be unbuilt or stale
Building pysvelte components with webpack...


## Training an Algorithmic Model

EasyTransformer also supports passing in a custom config and initialising weights to create your own model. This isn't optimised for performance, so it's best suited to training small LMs or small transformers on algorithmic tasks.

We demonstrate training a (very!) small model to predict a string of consecutive numbers (with a random initial offset).
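Each training batch is built by broadcasting a column of random offsets against a row of positions. Here is a NumPy sketch of that construction, using values matching `tiny_cfg` (`n_ctx=50`, `d_vocab=150`) and the batch size of 20; it also checks that the largest token, 99 + 49 = 148, fits in the vocabulary.

```python
import numpy as np

n_ctx_demo, d_vocab_demo, batch_size_demo = 50, 150, 20
rng = np.random.default_rng(0)
offsets = rng.integers(0, 100, size=batch_size_demo)   # like torch.randint(0, 100, (20,))
demo_batch = offsets[:, None] + np.arange(n_ctx_demo)[None, :]  # [batch, n_ctx]

print(demo_batch.shape)  # (20, 50)
print(demo_batch.max() < d_vocab_demo)  # True
```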

In [None]:
tiny_cfg = EasyTransformerConfig(
    d_model = 32,
    d_head = 16,
    n_heads = 2,
    d_mlp = 128,
    n_layers=1,
    n_ctx = 50,
    act_fn='solu_ln',
    d_vocab=150,
    normalization_type='LN',
    seed=23, # Since we're training a custom model, it's good to set the seed to get reproducible results. It defaults to 42.
    )

tiny_model = EasyTransformer(tiny_cfg).to(device)
tiny_optimizer = torch.optim.Adam(tiny_model.parameters(), lr=1e-3)
batch_size = 20
num_epochs=301

In [None]:
for epoch in tqdm.tqdm(range(num_epochs)):
    batch_offset = torch.randint(0, 100, (batch_size,))
    range_over_ctx = torch.arange(tiny_model.cfg.n_ctx)
    # Fancy indexing to get a batch of consecutive tokens, with each row starting with batch_offset
    batch = batch_offset[:, None] + range_over_ctx[None, :]
    loss = tiny_model(batch, return_type='loss')
    loss.backward()
    tiny_optimizer.step()
    tiny_optimizer.zero_grad()
    if epoch%100 == 0:
        print(f"Epoch: {epoch}. Loss: {loss}")


Epoch: 0. Loss: 5.026259899139404
Epoch: 100. Loss: 1.738516092300415
Epoch: 200. Loss: 0.33298417925834656
Epoch: 300. Loss: 0.11569619923830032


In [None]:
loss, tiny_cache = tiny_model.run_with_cache(batch, return_type='loss')
loss  # Display the loss on the final training batch

tensor(0.1146, device='cuda:0', grad_fn=<NegBackward0>)


## Training a Language Model

Though EasyTransformer is not designed for high-performance model training, we provide some utilities for training small language models.

See `train.py` for an example training script. The following shows how to use it for a simple training task:

In [None]:
micro_gpt_cfg = EasyTransformerConfig(
    d_model = 64,
    d_head = 32,
    n_heads = 2,
    d_mlp = 256,
    n_layers=3,
    n_ctx = 512,
    act_fn='gelu_new',
    normalization_type='LN',
    tokenizer_name='EleutherAI/gpt-neox-20b',
    )
micro_gpt = EasyTransformer(micro_gpt_cfg)

We download 10K samples of the Pile (via a small utility dataset on HuggingFace), and use a utility to tokenize them, concatenate them (separated by EOS tokens), and reshape them into rows of length `n_ctx`.
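A minimal sketch of the tokenize-concatenate-reshape idea, on already-tokenized toy documents. `concat_and_chunk` is a hypothetical stand-in for the library utility: appending an EOS after every document and dropping leftover tokens that don't fill a whole row are assumptions of this sketch.

```python
import numpy as np

def concat_and_chunk(token_lists, eos_id, n_ctx):
    """Join tokenized documents with EOS separators and cut into rows of n_ctx."""
    flat = []
    for toks in token_lists:
        flat.extend(toks)
        flat.append(eos_id)  # separator after each document
    n_rows = len(flat) // n_ctx  # any incomplete final row is dropped
    return np.array(flat[: n_rows * n_ctx]).reshape(n_rows, n_ctx)

docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
print(concat_and_chunk(docs, eos_id=0, n_ctx=4))
# [[ 5  6  7  0]
#  [ 8  9  0 10]
#  [11 12 13  0]]
```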

In [None]:

import easy_transformer.utils  # Binds the `easy_transformer` name used below
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
dataset = easy_transformer.utils.tokenize_and_concatenate(dataset, micro_gpt.tokenizer, max_length=micro_gpt.cfg.n_ctx, add_bos_token=False)

Using custom data configuration NeelNanda--pile-10k-698b4c44102ba425
Reusing dataset parquet (/workspace/cache/NeelNanda___parquet/NeelNanda--pile-10k-698b4c44102ba425/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


 



As an example, we train our micro model for 500 steps with a batch size of 2, using AdamW.
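For reference, here is a NumPy sketch of what a single AdamW step computes, following the update rule documented for `torch.optim.AdamW` (decoupled weight decay applied directly to the weights). Whether the training utility uses exactly these defaults (other than `weight_decay=0.01`, which the config sets) is an assumption.

```python
import numpy as np

def adamw_step(param, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update; t is the 1-based step count."""
    b1, b2 = betas
    param = param * (1 - lr * weight_decay)   # decoupled weight decay
    m = b1 * m + (1 - b1) * grad              # first-moment estimate
    v = b2 * v + (1 - b2) * grad ** 2         # second-moment estimate
    m_hat = m / (1 - b1 ** t)                 # bias corrections
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

p = np.array([1.0, -2.0])
# With a zero gradient, only the weight decay acts: weights shrink by (1 - lr * weight_decay)
p_new, _, _ = adamw_step(p, np.zeros(2), np.zeros(2), np.zeros(2), t=1)
print(p_new)
```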

In [None]:

import easy_transformer.train  # Loads the train submodule used below

training_cfg = easy_transformer.train.EasyTransformerTrainConfig(
    num_epochs = 1,
    batch_size = 2,
    weight_decay = 0.01,
    optimizer_name = 'AdamW',
    max_steps = 500,
)
micro_gpt = easy_transformer.train.train(micro_gpt, training_cfg, dataset)


Epoch 1 Samples 2 Step 0 Loss 10.832801818847656
Epoch 1 Samples 102 Step 50 Loss 8.344871520996094
Epoch 1 Samples 202 Step 100 Loss 7.829599380493164
Epoch 1 Samples 302 Step 150 Loss 8.067302703857422
Epoch 1 Samples 402 Step 200 Loss 7.894046306610107
Epoch 1 Samples 502 Step 250 Loss 7.74192476272583
Epoch 1 Samples 602 Step 300 Loss 8.823657035827637
Epoch 1 Samples 702 Step 350 Loss 7.71836519241333
Epoch 1 Samples 802 Step 400 Loss 7.205349445343018
Epoch 1 Samples 902 Step 450 Loss 7.272680759429932
Epoch 1 Samples 1002 Step 500 Loss 6.653236389160156




## Generating Text

This isn't a core feature of the library, but is pretty useful to have! A key move in both ML and interpretability is to really get your hands dirty, play around with your models and your data, and try to understand what's going on. Generating a bunch of text and playing around is a good way to engage with that.


Thanks to Ansh Radhakrishnan for adding this feature!
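The `temperature` and `top_k` arguments in the call below control how each new token is sampled. Here is a generic NumPy sketch of temperature plus top-k sampling from one logits vector; `sample_top_k` is a hypothetical helper illustrating the technique, not the library's `generate()` code.

```python
import numpy as np

def sample_top_k(logits, temperature=1.0, top_k=None, rng=None):
    """Sample one token id from a 1D logits vector with temperature and top-k."""
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=float) / temperature   # <1 sharpens, >1 flattens
    if top_k is not None:
        kth_largest = np.sort(scaled)[-top_k]
        # Keep only the top k logits (ties at the k-th value can let extra tokens through)
        scaled = np.where(scaled >= kth_largest, scaled, -np.inf)
    probs = np.exp(scaled - scaled.max())  # softmax, shifted for numerical stability
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

demo_logits = np.array([2.0, 1.0, 0.5, -1.0, 3.0])
demo_rng = np.random.default_rng(0)
print(sample_top_k(demo_logits, temperature=0.9, top_k=1, rng=demo_rng))  # 4 (greedy)
draws = {sample_top_k(demo_logits, temperature=0.9, top_k=2, rng=demo_rng) for _ in range(200)}
print(draws <= {0, 4})  # True: only the two highest-logit tokens can be sampled
```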

In [None]:
model = EasyTransformer.from_pretrained("gpt2-medium")

In [None]:
prompt = "The following work gives an insightful and original solution to the alignment problem:"
model.generate(prompt, max_new_tokens=50, temperature=0.9, top_k=5)


'<|endoftext|>The following work gives an insightful and original solution to the alignment problem:\n\nIn an effort to find a solution to the alignment problem, the authors propose a solution that is both elegant and simple to implement. The solution is based on the concept of "boundedness" and the concept of "bounds". The'