# Setup
(No need to read)

In [1]:
import os
DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2
        
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
IN_GITHUB = True


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Jack")

In [4]:
# Import stuff
import torch
import torch.nn as nn
import einops
import pickle as pkl
import numpy as np
import datasets
from torch.utils.data import DataLoader, Dataset


from fancy_einsum import einsum
from tqdm.notebook import tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

In [5]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix, HookedTransformerConfig
from transformer_lens.train import HookedTransformerTrainConfig, train
from transformer_lens.utils import tokenize_and_concatenate, lm_cross_entropy_loss

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [6]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f11044b9e50>

Plotting helper functions:

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Local Pretrained Hooked Transformers
HookedTransformers can now be loaded from your local checkpoints and supports continued pretraining pipelines or other custom training pipelines.

## Demo 1: Dump a pretrained model during training and reload

**For demonstration purposes download a pretrained model**

In [8]:
# load gpt2-small 
reference_gpt2_small = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
# view the model
print(reference_gpt2_small)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [10]:
# view config
print(reference_gpt2_small.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cpu'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': Fa

In [11]:
# print type of config
print(type(reference_gpt2_small.cfg))

<class 'transformer_lens.HookedTransformerConfig.HookedTransformerConfig'>


In [12]:
# show generate output
reference_gpt2_small.generate("(CNN) President Barack Obama caught in embarrassing new scandal\n", max_new_tokens=50, temperature=0.7, prepend_bos=True)

  0%|          | 0/50 [00:00<?, ?it/s]

"(CNN) President Barack Obama caught in embarrassing new scandal\n\nThere was a 10-minute moment of silence for the victims of the Benghazi attack on American soil, and the president called for a swift investigation.\n\nAnd they didn't complain.\n\nJUST WATCHED Obama: We know we killed an American"

### Save the model to a local directory
The Transformer_Lens train file now allows the
model to be saved as a pickle file to the save_dir where the current state_dict is saved.
<br>
```python
pickle_dump: bool = False # set to True to save the model as a pickle file to the same location as the state_dict
```
```

In [13]:
# save the same way as saved as checkpoint during transformer lens training
pickle_path = "gpt2-small-checkpoint.pkl"
with open(pickle_path, "wb") as f:
    pkl.dump(reference_gpt2_small, f)

### Reload as local model

In [14]:
# reload the model
with open(pickle_path, "rb") as f:
    loaded_model = pkl.load(f)

In [15]:
print(loaded_model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [16]:
loaded_model.hook_points

<bound method HookedRootModule.hook_points of HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoin

In [17]:
loaded_model.generate("(CNN) President Barack Obama caught in embarrassing new scandal\n", max_new_tokens=50, temperature=0.7, prepend_bos=True)

  0%|          | 0/50 [00:00<?, ?it/s]

'(CNN) President Barack Obama caught in embarrassing new scandal\n\nOn Wednesday, the White House released the original email chain sent by WikiLeaks founder Julian Assange, which revealed the agency was cooperating with the Russian government in its efforts to disrupt the election.\n\nThe email chain, obtained by CNN under the Freedom of'

### Continued Pretraining

In [18]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f10cd9711c0>

In [19]:
batch_size = 2
num_epochs = 2
max_steps = 2
log_every = 1
lr = 1e-3
weight_decay = 1e-2

In [20]:
train_config = HookedTransformerTrainConfig(
    batch_size=2,
    num_epochs=2,
    max_steps=2, 
    optimizer_name="AdamW",
    lr=1e-3,
    weight_decay=1e-2,
    print_every=2,
    save_dir="models",
    seed=123,
)

dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
tokens_dataset = tokenize_and_concatenate(dataset, loaded_model.tokenizer, streaming=False, max_length=loaded_model.cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)

In [21]:
cont_pretrain_model = train(loaded_model, train_config, tokens_dataset)

Moving model to device:  cpu


  0%|          | 0/2 [00:00<?, ?it/s]

0it [00:00, ?it/s]

## Demo 2- Load from state_dict
This option requires a path to a state_dict and a loaded config file of type HookedTransformerConfig.


### Load new base model

In [None]:
reference_gpt2_small = HookedTransformer.from_pretrained("gpt2-small")

# save as checkpoint as in training
base_path = "gpt2-small-checkpoint"
torch.save(reference_gpt2_small.state_dict(), f"{base_path}.pt")

# save the config to pickle
with open(f"{base_path}.pkl", "wb") as f:
    pkl.dump(reference_gpt2_small.cfg, f)

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
# load the config
with open(f"{base_path}.pkl", "rb") as f:
    local_cfg = pkl.load(f)
    
type(local_cfg)

transformer_lens.HookedTransformerConfig.HookedTransformerConfig

In [None]:
# Set local_model_path
local_model_path = "gpt2-small-checkpoint.pt"

### Using from_local
The from_local method allows you to load a model from a local directory. 
You pass in the path to the directory containing the state_dict, a config file that is preloaded and
formatted as a HookedTransformerConfig.
<br>
You also have the option of passing in a weight_conversion_function, that enables you to format the
state_dict if needed. This takes in the state_dict and local_cfg as arguments and returns a state_dict.

In [None]:
pretrained_model = HookedTransformer.from_local(local_model_path=local_model_path, local_cfg=local_cfg)
# weight_conversion_function = None

Loading model from local path: gpt2-small-checkpoint.pt
Loaded model state dict: OrderedDict([('embed.W_E', tensor([[-0.1106, -0.0398,  0.0326,  ..., -0.1369,  0.0146,  0.0448],
        [ 0.0359, -0.0531,  0.0418,  ...,  0.0816, -0.0019,  0.0387],
        [-0.1301,  0.0453,  0.1815,  ...,  0.0873, -0.1324, -0.0905],
        ...,
        [-0.0447, -0.0550,  0.0121,  ...,  0.1042,  0.0976, -0.0697],
        [ 0.1870,  0.0176,  0.0471,  ..., -0.0953,  0.0795, -0.0215],
        [ 0.0517, -0.0274,  0.0502,  ...,  0.0074,  0.1555,  0.1210]])), ('pos_embed.W_pos', tensor([[-1.3368e-02, -1.9197e-01,  9.4797e-03,  ..., -3.7591e-02,
          3.3720e-02,  5.9943e-02],
        [ 2.4966e-02, -5.2785e-02, -9.3872e-02,  ...,  3.5177e-02,
          1.1178e-02,  8.5082e-04],
        [ 6.4541e-03, -8.2526e-02,  5.6753e-02,  ...,  2.1983e-02,
          2.1563e-02, -1.9186e-02],
        ...,
        [-5.1990e-03, -1.7951e-03, -5.8503e-02,  ...,  1.0216e-02,
         -1.0581e-02,  3.5489e-04],
        [ 1

In [None]:
pretrained_model.generate("(CNN) President Barack Obama caught in embarrassing new scandal\n", max_new_tokens=50, temperature=0.7, prepend_bos=True)

  0%|          | 0/50 [00:00<?, ?it/s]

'(CNN) President Barack Obama caught in embarrassing new scandal\n\nThe President of the United States, Barack Obama, has been caught in the cross-hairs of the FBI, "detailing" his contacts with an on-the-ground Russian spyoughest of government officials, according to a report on'

In [None]:
batch_size = 2
num_epochs = 2
max_steps = 2
log_every = 1
lr = 1e-3
weight_decay = 1e-2

In [None]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
tokens_dataset = tokenize_and_concatenate(dataset, pretrained_model.tokenizer, streaming=False, max_length=pretrained_model.cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)

In [None]:
train_config = HookedTransformerTrainConfig(
    batch_size=2,
    num_epochs=2,
    max_steps=5, 
    optimizer_name="AdamW",
    lr=1e-3,
    weight_decay=1e-2,
    print_every=100,
    save_dir="models",
    seed=123,
)

In [None]:
cont_pretrained_model = train(pretrained_model, train_config, tokens_dataset)

Moving model to device:  cpu


  0%|          | 0/2 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Epoch 1 Samples 2 Step 0 Loss 3.2271852493286133


0it [00:00, ?it/s]

Epoch 2 Samples 2 Step 0 Loss 7.868700981140137


: 