# Setup
(No need to read)

In [1]:
import os
DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2
        
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
IN_GITHUB = True


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

In [4]:
# Import stuff
import torch
import torch.nn as nn
import einops
import pickle as pkl
import numpy as np
import datasets
from torch.utils.data import DataLoader, Dataset


from fancy_einsum import einsum
from tqdm.notebook import tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

In [5]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix, HookedTransformerConfig
from transformer_lens.train import HookedTransformerTrainConfig
from transformer_lens.utils import tokenize_and_concatenate, lm_cross_entropy_loss

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [6]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7facfc454520>

Plotting helper functions:

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Local Pretrained Hooked Transformers
Transformers trained using the Transformer_Lens Hook architecture are also supported. This gives
support to continued pretraining pipelines or other custom training pipelines.

Users can provide a path to a local pretrained model, and define a dictionary for the model config
to load the model. 

## Download a pretrained model

In [8]:
# load gpt2-small 
reference_gpt2_small = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
# view the model
print(reference_gpt2_small)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [10]:
# view config
print(reference_gpt2_small.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cpu'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': Fa

In [11]:
# print type of config
print(type(reference_gpt2_small.cfg))

<class 'transformer_lens.HookedTransformerConfig.HookedTransformerConfig'>


In [12]:
# show generate output
reference_gpt2_small.generate("(CNN) President Barack Obama caught in embarrassing new scandal\n", max_new_tokens=50, temperature=0.7, prepend_bos=True)

  0%|          | 0/50 [00:00<?, ?it/s]

"(CNN) President Barack Obama caught in embarrassing new scandal\n\nA former CIA contractor who was fired for leaking classified information about the CIA's surveillance program, Edward Snowden, has been named by the Senate Judiciary Committee as a suspect in the leak of top-secret documents, and Senate Republicans are expecting a hearing on"

## Save the model to a local directory
The Transformer_Lens train file now allows the
model to be saved as a pickle file to the save_dir where the current state_dict is saved.
<br>
```python
pickle_dump: bool = False # set to True to save the model as a pickle file
```

In [13]:
# save the same way as saved as checkpoint during transformer lens training
pickle_path = "gpt2-small-checkpoint.pkl"
with open(pickle_path, "wb") as f:
    pkl.dump(reference_gpt2_small, f)

## Reload as local model

In [14]:
# reload the model
with open(pickle_path, "rb") as f:
    loaded_model = pkl.load(f)

In [15]:
print(loaded_model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [16]:
loaded_model.hook_points

<bound method HookedRootModule.hook_points of HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoin

In [17]:
loaded_model.generate("(CNN) President Barack Obama caught in embarrassing new scandal\n", max_new_tokens=50, temperature=0.7, prepend_bos=True)

  0%|          | 0/50 [00:00<?, ?it/s]

'(CNN) President Barack Obama caught in embarrassing new scandal\n\nIn a meeting with top officials on Friday, the former first lady apologized to the nation and her family for her actions.\n\n"I am very sorry," she wrote. "I have been there for many, many years. I know it'

## Continued Pretraining

In [18]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fadd5acf3d0>

In [19]:
batch_size = 2
num_epochs = 2
max_steps = 2
log_every = 1
lr = 1e-3
weight_decay = 1e-2

In [20]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
tokens_dataset = tokenize_and_concatenate(dataset, loaded_model.tokenizer, streaming=False, max_length=loaded_model.cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)
data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

In [21]:
loaded_model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [22]:
optimizer = torch.optim.AdamW(loaded_model.parameters(), lr=lr, weight_decay=weight_decay)

In [23]:
losses = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Number of batches:", len(data_loader))
for epoch in range(num_epochs):
    for c, batch in tqdm(enumerate(data_loader)):
        tokens = batch['tokens'].cuda()
        logits = loaded_model(tokens)
        loss = lm_cross_entropy_loss(logits, tokens)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        if c % log_every == 0:
            print(f"Step: {c}, Loss: {loss.item():.4f}")
        if c > max_steps:
            break

Number of batches: 8478


0it [00:00, ?it/s]

torch.Size([2, 1024, 50257])
torch.Size([2, 1024])
tensor([[[  7.5261,  11.1214,   7.8919,  ...,  -3.1299,  -3.3873,   8.5934],
         [  5.3532,   5.2245,   0.5649,  ...,   0.4176,  -1.1736,   3.6051],
         [  5.5376,   4.8996,   1.8676,  ...,  -2.0804,  -1.6262,   3.2844],
         ...,
         [ 17.3733,  18.9965,  15.3096,  ...,  -7.3962, -13.1367,  13.4542],
         [  8.7477,  12.2562,  11.4773,  ...,  -7.1824,  -4.2026,   7.9818],
         [  6.2303,   8.0677,   6.1254,  ...,  -4.9028,  -2.8257,   5.8537]],

        [[  7.5261,  11.1214,   7.8919,  ...,  -3.1299,  -3.3873,   8.5934],
         [  5.6205,   4.5042,   2.5711,  ...,  -3.8182,  -1.4842,   3.9433],
         [  1.2403,   2.6366,   1.4573,  ...,  -2.0364,  -2.6426,   3.4255],
         ...,
         [  7.9584,  12.1716,  12.2138,  ...,  -2.3043,  -8.1238,   9.3807],
         [ 11.9549,  14.7538,  13.3306,  ...,  -3.5024, -10.7934,  11.0000],
         [  9.2820,  10.4914,  12.5014,  ...,  -3.9947,  -8.6783,   9.60

: 

## Now try loading from state_dict