# SoLU Circuits

## Imports

In [1]:
import collections
import copy
import gc
import itertools
import json
import math
import os
import pickle
import random
import sys
import time
from functools import partial
from os import path
from pathlib import Path
from pprint import pprint

import datasets
import einops
import gdown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import pysvelte
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm.auto as tqdm
import tqdm.notebook as tqdm
import transformers
import wandb
from datasets import load_dataset
from easy_transformer import EasyTransformer, EasyTransformerConfig
from easy_transformer.EasyTransformer import Embed, Unembed, PosEmbed, LayerNorm, Attention, MLP, TransformerBlock
from easy_transformer.hook_points import HookedRootModule, HookPoint
from IPython.display import clear_output
from rich import print
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Fix for pysvelte import bug: put the pysvelte package directory on the module
# search path. The membership check keeps the cell idempotent, so re-running it
# in the notebook does not pile up duplicate sys.path entries.
_pysvelte_dir = '/home/user/.local/lib/python3.9/site-packages/pysvelte'
if _pysvelte_dir not in sys.path:
    sys.path.append(_pysvelte_dir)

## Model

### Config

In [2]:
# Settings consumed by the EasyTransformer components
cfg = {
    'd_model': 1024,
    'd_head': 64,
    'n_layers': 1,
    'n_ctx': 1024,
    'd_vocab': 50278,
    'use_attn_result': False,
    'act_fn': 'solu_ln',
    'eps': 1e-5,
    # The trained model used LN everywhere, except for RMS just before the
    # final unembedding. The blocks are still built with 'LN' here; the norm
    # weights/biases are folded into the following weights when the checkpoint
    # is loaded (see the 'Fold in weights and biases' section), and the final
    # RMS normalization is replaced by RMSNormPre in the model code below.
    'normalization_type': 'LN'
}

# Settings derived from the values above
cfg.update({
    'n_heads': cfg['d_model'] // cfg['d_head'],
    'd_mlp': 4 * cfg['d_model'],
})

# Notebook-specific settings that EasyTransformer does not read itself
custom_cfg = {
    'model_checkpoint_name': 'SoLU_1L_1024W_final_checkpoint.pth',
    'device': 'cuda',
}

### Model Setup

This uses the `EasyTransformer` components, where possible (as they can be configured identically to the code that was used for training).

In [3]:
class RMSNormPre(nn.Module):
    """RMS pre-normalization.

    Divides the input by its root-mean-square over the last dimension.
    Unlike full RMS normalization there is no multiplication by a learned
    weight term, as that has been folded into the next weights instead.

    Args:
        cfg: Config dict; only ``cfg['eps']`` is read here.
        length: Size of the normalized dimension (stored, not used in forward).
    """

    def __init__(self, cfg, length):
        super().__init__()
        self.cfg = cfg
        self.eps = cfg['eps']
        self.length = length
        # No `w` parameter: the weights are folded into the following layer.

        # Hook point exposing the normalization scale factor
        self.hook_scale = HookPoint()  # [batch, pos, 1]

    def forward(self, x):
        """Return ``x`` divided by sqrt(mean(x^2) + eps) along the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        scale = self.hook_scale(rms)  # [batch, pos, 1]
        # No multiplication by weights afterwards, as they are folded
        return x / scale


def loss_fn(logits, batch):
    """Mean next-token negative log-likelihood.

    Args:
        logits: Model outputs, indexed as [batch, pos, d_vocab].
        batch: Token ids, indexed as [batch, pos].

    Returns:
        A scalar tensor: the negated mean log-probability assigned to each
        token given the logits at the preceding position.
    """
    # Logits at position i predict the token at position i + 1, so the last
    # position's logits and the first token have no counterpart.
    shifted_logits = logits[:, :-1]
    targets = batch[:, 1:].unsqueeze(-1)
    token_log_probs = F.log_softmax(shifted_logits, dim=-1).gather(-1, targets).squeeze(-1)
    return -token_log_probs.mean()


class Transformer(HookedRootModule):
    """Hooked transformer assembled from EasyTransformer components.

    The pipeline is: token embedding + positional embedding -> transformer
    blocks -> RMSNormPre -> unembedding.

    Args:
        cfg: Config dict passed to every component; ``d_model`` and
            ``n_layers`` are read directly here.
        tokenizer: Tokenizer used by ``to_tokens`` (must provide ``bos_token``
            and be callable with ``return_tensors='pt'``).
    """

    def __init__(self, cfg, tokenizer):
        super().__init__()

        self.cfg = cfg
        self.tokenizer = tokenizer

        self.embed = Embed(self.cfg)
        self.hook_embed = HookPoint()  # [batch, pos, d_model]

        self.pos_embed = PosEmbed(self.cfg)
        self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]

        # Switched to RMSNormPre as we've folded in the weights (was RMS in training)
        self.norm = RMSNormPre(self.cfg, self.cfg['d_model'])

        self.blocks = nn.ModuleList([TransformerBlock(
            self.cfg, block_index) for block_index in range(self.cfg['n_layers'])])

        self.unembed = Unembed(self.cfg)

        # Gives each module a parameter with its name (relative to this root module)
        # Needed for HookPoints to work
        self.setup()

    def forward(self, tokens, return_type='both', calc_logits=True):
        """Run the model.

        Args:
            tokens: Either a batch of tokens ([batch, pos]) or a text string,
                which is tokenized with ``to_tokens`` (batch_size=1).
            return_type: 'both' -> ``(logits, loss)``, 'logits' -> logits only,
                'loss' -> loss only. Any other value yields ``None``.
            calc_logits: When False, stop after the blocks and return ``None``.
                This avoids the unembedding, which speeds up runtime on small
                models and reduces memory consumption; useful when only the
                activations are wanted (e.g. for finding max activating
                dataset examples).

        Returns:
            Logits ([batch, pos, d_vocab]) and/or the scalar loss, as selected
            by ``return_type``, or ``None``.
        """
        # isinstance (rather than an exact type comparison) also accepts str subclasses
        if isinstance(tokens, str):
            # If text, convert to tokens (batch_size=1)
            tokens = self.to_tokens(tokens)
        embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]
        pos_embed = self.hook_pos_embed(
            self.pos_embed(tokens))  # [batch, pos, d_model]

        # The positional embedding is always added to the residual stream, as
        # this can't be disabled with the standard Attention component.
        residual = embed + pos_embed  # [batch, pos, d_model]

        for block in self.blocks:
            # Note that each block includes skip connections, so we don't need
            # residual + block(residual)
            residual = block(residual)  # [batch, pos, d_model]
        if not calc_logits:
            return None

        # Norm the residual
        residual = self.norm(residual)

        logits = self.unembed(residual)  # [batch, pos, d_vocab]
        if return_type == 'both':
            return (logits, loss_fn(logits, tokens))
        elif return_type == 'logits':
            return logits
        elif return_type == 'loss':
            return loss_fn(logits, tokens)
        # Unrecognised return_type: nothing is returned
        return None

    def to_tokens(self, text):
        """Tokenize ``text`` with the BOS token prepended.

        Returns the 'input_ids' tensor moved to the module-level
        ``custom_cfg['device']``.
        """
        encoded = self.tokenizer(self.tokenizer.bos_token + text, return_tensors='pt')
        return encoded['input_ids'].to(custom_cfg['device'])

# Custom Tokenizer
# Loads the pretrained 'EleutherAI/gpt-neox-20b' tokenizer and registers
# '<PAD>' as its padding token.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
pad_token = '<PAD>'
tokenizer.add_special_tokens({'pad_token': pad_token})

# Create the model
# Built from the `cfg` dict and the tokenizer, then moved to the device named
# in custom_cfg['device']. As the last expression of the cell, the result of
# `.to(...)` is what the notebook displays.
model = Transformer(cfg, tokenizer)
model.to(custom_cfg['device'])

Transformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (norm): RMSNormPre(
    (hook_scale): HookPoint()
  )
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_attn): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
        (hook_post_ln): HookPoint()
        (ln): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoin

### Load from the checkpoint

### Download Checkpoint

In [4]:
# Google Drive location of the trained checkpoint (provided by Neel Nanda)
checkpoint_url = "https://drive.google.com/file/d/16bqEZg9Oq0WT2xOcNS1HJkmR7qB2G14o/view"

# Local directory for checkpoints; created if it is not there yet
checkpoint_dir = "/tmp/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Fetch the checkpoint only when there is no local copy already
checkpoint_file = os.path.join(checkpoint_dir, custom_cfg['model_checkpoint_name'])
if not os.path.exists(checkpoint_file):
    gdown.download(url=checkpoint_url, output=checkpoint_file, quiet=False, fuzzy=True)

### Fold in weights and biases

We fold the `LayerNorm` weights and biases into the weights after them, for simplicity, as per [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html#model-simplifications).

In [80]:
# Scratch cell: a 2x2 integer tensor, used to check how elementwise
# multiplication broadcasts in the cells that follow.
a = torch.tensor([[1,1], [1,1]])
a.size()           

torch.Size([2, 2])

In [81]:
# Scratch cell: a 1-D integer tensor with two elements, the second operand
# for the broadcasting check.
b = torch.tensor([2,2])
b.size()

torch.Size([2])

In [82]:
# Elementwise product: the 1-D tensor `b` is broadcast against every row of
# the 2-D tensor `a`, so the result keeps the shape of `a`.
a * b

tensor([[2, 2],
        [2, 2]])

In [107]:
# Get the state dictionary from the checkpoint
sd = torch.load(checkpoint_file)
print(sd.keys())

# Fold in layer normalization weights & biases, for each layer (just one in the toy example).
# For a linear map (W, b) applied to a layer-norm output x_norm * w_ln + b_ln:
#     W @ (x_norm * w_ln + b_ln) + b  ==  (W * w_ln) @ x_norm + (W @ b_ln + b)
# so the scale goes into the weights and the shift goes into the bias.
# NOTE(review): this relies on the last axis of every W below being d_model -
# confirm against the checkpoint's tensor shapes.
for layer in range(cfg['n_layers']):
    block_key = f"blocks.{layer}"

    # Pre-attention layer norm weights -> Query/Key/Value weights
    pre_ln_w = sd[f"{block_key}.norm1.w"]
    W_Q_old = sd[f"{block_key}.attn.W_Q"]
    W_K_old = sd[f"{block_key}.attn.W_K"]
    W_V_old = sd[f"{block_key}.attn.W_V"]
    sd[f"{block_key}.attn.W_Q"] = W_Q_old * pre_ln_w
    sd[f"{block_key}.attn.W_K"] = W_K_old * pre_ln_w
    sd[f"{block_key}.attn.W_V"] = W_V_old * pre_ln_w

    # Pre-attention layer norm biases -> Query/Key/Value biases
    # (uses the *unscaled* weights: the shift is added after the scale)
    pre_ln_b = sd[f"{block_key}.norm1.b"]
    sd[f"{block_key}.attn.b_Q"] = W_Q_old @ pre_ln_b + sd[f"{block_key}.attn.b_Q"]
    sd[f"{block_key}.attn.b_K"] = W_K_old @ pre_ln_b + sd[f"{block_key}.attn.b_K"]
    sd[f"{block_key}.attn.b_V"] = W_V_old @ pre_ln_b + sd[f"{block_key}.attn.b_V"]

    # Post-attention layer weights/biases -> MLP weights/biases
    W_in_old = sd[f"{block_key}.mlp.W_in"]
    sd[f"{block_key}.mlp.W_in"] = W_in_old * sd[f"{block_key}.norm2.w"]
    sd[f"{block_key}.mlp.b_in"] = W_in_old @ sd[f"{block_key}.norm2.b"] \
                                        + sd[f"{block_key}.mlp.b_in"]

    # The model's blocks still contain LayerNorm modules, registered as ln1/ln2
    # (the checkpoint names them norm1/norm2). Their scale/shift now live in the
    # weights above, so load identity parameters (w = 1, b = 0) under the model's
    # names and drop the checkpoint's keys so that load_state_dict sees no
    # unexpected entries. Keys are built per layer so this also works for
    # n_layers > 1.
    # TODO: remove once the blocks no longer carry LayerNorm parameters.
    for ckpt_name, model_name in (("norm1", "ln1"), ("norm2", "ln2")):
        folded_w = sd.pop(f"{block_key}.{ckpt_name}.w")
        folded_b = sd.pop(f"{block_key}.{ckpt_name}.b")
        sd[f"{block_key}.{model_name}.w"] = torch.ones_like(folded_w)
        sd[f"{block_key}.{model_name}.b"] = torch.zeros_like(folded_b)

# Fold the post-blocks (pre-unembed) RMS norm weights -> unembed weights
sd["unembed.W_U"] *= sd["norm.w"]
del sd["norm.w"] # Delete as no longer used (folded in)

# EasyTransformer has an additional bias term for the unembedding, so we simply set it to zero.
sd["unembed.b_U"] = torch.zeros(cfg['d_vocab'])

# Load the state dict into the model
model.load_state_dict(sd)

<All keys matched successfully>

### Load the checkpoint

## Find interesting activations

A 1-layer model without an MLP can't do much more than skip trigrams. Whilst the MLP layer added may improve this a little, the prompts will still need to have quite simple answers.

In this case we'll look for the ability of the model to close HTML tags. As a simple overview of how HTML tags work, whenever a tag is used (e.g. `<b>` for bold) it must be closed when you no longer want it to apply (e.g. `<b>bold text</b> normal text`).

Note that `</` is a single token - so we can't use `<` as the last token and expect to see `/`.

In [108]:
def get_next_token(prompt: str) -> str:
    """Run a forward pass and return the most likely next token as text.

    Args:
        prompt: Text to continue; the model tokenizes it as a batch of one.

    Returns:
        The decoded token with the highest logit at the final position.
    """
    # Inference only: no gradients needed, and only the logits are required
    # (so the loss is not computed).
    with torch.no_grad():
        logits = model(prompt, return_type='logits')  # [1, pos, d_vocab]

    # Index the final position directly. This also works when the sequence
    # holds a single token, where squeezing the predictions would leave a 0-d
    # tensor that cannot be iterated. Softmax is monotonic, so the argmax of
    # the raw logits picks the same token as the argmax of the log-probabilities.
    next_token_id = logits[0, -1].argmax(dim=-1)
    return model.tokenizer.decode(next_token_id)

In [109]:
# Prompts to try: each ends where an HTML closing tag (or part of one) could follow
prompts = [
    "<h1>Title",
    "<b>Some bold text</",
    "<p>An interesting paragraph</",
    "<table><tr><th>Model name"
]

# Number of tokens the model appends to every prompt
additional_tokens = 2

# Extend each prompt greedily, one token at a time, then show the result
for prompt in prompts:
    result = prompt
    for _ in range(additional_tokens):
        result += get_next_token(result)

    print(result)