# SoLU Circuits

## Imports

In [197]:
import collections
import copy
import gc
import html
import itertools
import json
import math
import os
import pickle
import random
import sys
import time
from functools import partial
from os import path
from pathlib import Path
from pprint import pprint

import datasets
import einops
import gdown
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm.auto as tqdm
import tqdm.notebook as tqdm
import transformers
import wandb
from datasets import load_dataset
from easy_transformer.EasyTransformer import (MLP, Attention, EasyTransformer,
                                              Embed, LayerNorm, PosEmbed,
                                              TransformerBlock, Unembed)
from easy_transformer.EasyTransformerConfig import EasyTransformerConfig
from easy_transformer.hook_points import HookedRootModule, HookPoint
from IPython.core.display import HTML
from IPython.display import clear_output, display
from rich import print
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Fix for pysvelte import bug
sys.path.append('/workspaces/solu-circuits/PySvelte')

In [2]:
# Run this after the above fix
import pysvelte

## Model

### Config

Given that we're using a checkpoint of a model that has already been run, we add in the config settings from that model here.

In [3]:
# Hyperparameters handed to EasyTransformerConfig; these mirror the run that
# produced the checkpoint loaded further down.
cfg = dict(
    d_model=1024,
    d_head=64,
    n_layers=1,
    n_ctx=1024,
    d_vocab=50278,
    use_attn_result=False,
    act_fn='solu_ln',
    eps=1e-5,
    # Training used LayerNorm everywhere except for an RMS norm directly before
    # the unembedding. Here the norm weights/biases are folded into the weights
    # that follow them, so the blocks use LNPre; the final norm is swapped for
    # RMSNormPre in the model code below. See the 'Fold in weights and biases'
    # section for more details.
    normalization_type='LNPre',
    model_name="SoLU",
)

# Values derived from the settings above
cfg.update(
    n_heads=cfg['d_model'] // cfg['d_head'],
    d_mlp=4 * cfg['d_model'],
)

# Notebook-only settings that EasyTransformerConfig has no field for
custom_cfg = dict(
    model_checkpoint_name='SoLU_1L_1024W_final_checkpoint.pth',
    device='cuda',
)

### Model Setup

This uses the `EasyTransformer` components, where possible (as they can be configured identically to the code that was used for training).

In [4]:
class RMSNormPre(nn.Module):
    """RMS normalization with the learned scale removed.

    Divides the input by its root-mean-square over the last dimension. The
    usual multiplication by a weight vector is omitted because that weight is
    folded into the next layer's weights instead.

    Args:
        cfg: Config object; only `cfg.eps` is read.
        length: Stored as `self.length`; not used in the forward pass.
    """

    def __init__(self, cfg, length):
        super().__init__()
        self.eps = cfg.eps
        self.length = length

        # Hook point exposing the normalization scale factor  # [batch, pos]
        self.hook_scale = HookPoint()

    def forward(self, x):
        # Mean of squares over the last dimension, kept for broadcasting
        mean_square = x.pow(2).mean(-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()

        # Pass the scale through the hook so it can be cached or edited
        scale = self.hook_scale(rms)  # [batch, pos, 1]
        return x / scale

    
class Transformer(EasyTransformer):
    """`EasyTransformer` adapted to match the checkpointed SoLU model.

    Two things differ from the stock class: the tokenizer (GPT-NeoX with an
    extra pad token) and the final normalization (RMS norm with its weights
    folded out).
    """
    def __init__(self, cfg: EasyTransformerConfig):
        super().__init__("custom", cfg=cfg)

        # Same tokenizer as the trained model, including its custom pad token
        tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
        tokenizer.add_special_tokens({'pad_token': '<PAD>'})
        self.tokenizer = tokenizer

        # Replace the final norm: the trained model used RMS norm here, and its
        # weights are folded into the unembedding
        self.ln_final = RMSNormPre(self.cfg, self.cfg.d_model)

    def to_tokens(self, text):
        """Tokenise `text` with the BOS token prepended, on the configured device."""
        text_with_bos = self.tokenizer.bos_token + text
        encoded = self.tokenizer(text_with_bos, return_tensors='pt')
        return encoded['input_ids'].to(custom_cfg['device'])

    
# Build the model from the config and move it onto the configured device.
# `.to` returns the module itself, so the cell still ends by displaying it.
model = Transformer(EasyTransformerConfig.from_dict(cfg)).to(custom_cfg['device'])
model

Transformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_attn): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
        (hook_post_ln): HookPoint()
        (ln): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hoo

## Load from the checkpoint

### Download Checkpoint

In [5]:
# Checkpoint provided by Neel Nanda
checkpoint_url = "https://drive.google.com/file/d/16bqEZg9Oq0WT2xOcNS1HJkmR7qB2G14o/view"

# Local cache location for the checkpoint; the directory is created on demand
checkpoint_dir = "/tmp/checkpoints"
checkpoint_file = path.join(checkpoint_dir, custom_cfg['model_checkpoint_name'])
os.makedirs(checkpoint_dir, exist_ok=True)

# Only fetch the file when there is no local copy yet
if not path.exists(checkpoint_file):
    gdown.download(checkpoint_url, checkpoint_file, quiet=False, fuzzy=True)

Downloading...
From: https://drive.google.com/uc?id=16bqEZg9Oq0WT2xOcNS1HJkmR7qB2G14o
To: /tmp/checkpoints/SoLU_1L_1024W_final_checkpoint.pth
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 468M/468M [00:05<00:00, 80.7MB/s]


### Fold in weights and biases

We fold the `LayerNorm` weights and biases into the weights that come after them, for simplicity, as per [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html#model-simplifications).

In [6]:
# Get the state dictionary from the checkpoint
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
sd = torch.load(checkpoint_file)
print(sd.keys())

# Fold in layer normalization weights & biases, for each layer (just one in the toy example)
# Identity used throughout: W @ (w * x + b) + c == (W * w) @ x + (W @ b + c), so the norm's
# scale `w` is absorbed into the following weight and its bias `b` into the following bias.
# NOTE(review): the element-wise `W * w` and the `W @ b` products below assume the last axis of
# each weight tensor is d_model - confirm against the checkpoint's tensor shapes.
# `layer` stays defined after this loop and is read again by later cells.
for layer in range(cfg['n_layers']):
    # Pre-attention layer norm weights -> Query/Key/Value weights
    # The un-folded weights are kept in *_old because the bias fold below needs them.
    pre_ln_w = sd[f"blocks.{layer}.norm1.w"]
    W_Q_old = sd[f"blocks.{layer}.attn.W_Q"]
    W_K_old = sd[f"blocks.{layer}.attn.W_K"]
    W_V_old = sd[f"blocks.{layer}.attn.W_V"]
    sd[f"blocks.{layer}.attn.W_Q"] = W_Q_old * pre_ln_w
    sd[f"blocks.{layer}.attn.W_K"] = W_K_old * pre_ln_w
    sd[f"blocks.{layer}.attn.W_V"] = W_V_old * pre_ln_w
    
    # Pre-attention layer norm biases -> Query/Key/Value biases
    # Uses the original (un-scaled) weights: the norm bias is added after the norm scale.
    pre_ln_b = sd[f"blocks.{layer}.norm1.b"]
    sd[f"blocks.{layer}.attn.b_Q"] = W_Q_old @ pre_ln_b + sd[f"blocks.{layer}.attn.b_Q"]
    sd[f"blocks.{layer}.attn.b_K"] = W_K_old @ pre_ln_b + sd[f"blocks.{layer}.attn.b_K"]
    sd[f"blocks.{layer}.attn.b_V"] = W_V_old @ pre_ln_b + sd[f"blocks.{layer}.attn.b_V"]
    
    # Post-attention layer weights/biases -> MLP weights/biases
    # Same fold as above, applied to the norm that feeds the MLP input projection.
    W_in_old = sd[f"blocks.{layer}.mlp.W_in"]
    sd[f"blocks.{layer}.mlp.W_in"] = W_in_old * sd[f"blocks.{layer}.norm2.w"]
    sd[f"blocks.{layer}.mlp.b_in"] = W_in_old @ sd[f"blocks.{layer}.norm2.b"] \
                                        + sd[f"blocks.{layer}.mlp.b_in"]
    
    # Delete the weights/biases that are no longer used (as they're folded in)
    del sd[f"blocks.{layer}.norm1.w"]
    del sd[f"blocks.{layer}.norm1.b"]
    del sd[f"blocks.{layer}.norm2.w"]
    del sd[f"blocks.{layer}.norm2.b"]

# Fold the post-blocks (pre-unembed) RMS norm weights -> unembed weights
# Only a weight is folded here; no "norm.b" key is read for the final norm.
# The multiplication is in place, so the tensor held in `sd` is modified directly.
sd["unembed.W_U"] *= sd["norm.w"]
del sd["norm.w"] # Delete as no longer used (folded in)
    
# EasyTransformer has an additional bias term for the unembedding, so we simply set it to zero.
sd["unembed.b_U"] = torch.zeros(cfg['d_vocab'])

# Load the state dict into the model
model.load_state_dict(sd)

<All keys matched successfully>

## Find interesting activations

A 1-layer model without an MLP can't do much more than skip trigrams. Whilst the MLP layer added may improve this a little, the prompts will still need to have quite simple answers.

In this case we'll look at the model's ability to close HTML tags. As a simple overview of how HTML tags work: whenever a tag is opened (e.g. `<b>` for bold), it must be closed when you no longer want it to apply (e.g. `<b>bold text</b> normal text`).

Note that `</` is a single token - so we can't use `<` as the last token and expect to see `/`.

In [7]:
def get_next_token(prompt: str) -> str:
    """Run a forward pass and return the model's most likely next token.

    Args:
        prompt: Text to feed to the global `model`.

    Returns:
        The decoded text of the highest-scoring token at the final position.
    """
    # Inference only, so there is no need to track gradients
    with torch.no_grad():
        logits = model(prompt, "logits")  # [batch, pos, d_vocab]

    # Softmax is monotonic, so the argmax of the raw logits is already the most
    # probable token. Indexing the last position directly also avoids decoding
    # every position, and works when the sequence holds a single token.
    next_token_id = logits[0, -1].argmax(dim=-1)
    return model.tokenizer.decode(next_token_id)

In [8]:
# HTML snippets that each end on the `</` token, so the next token should be a tag name
prompts = [
    "<h1>Title</",
    "<b>Some bold text</",
    "<p>An interesting paragraph</",
    "<table><tr><th>Model name</",
    "<li>List item</"
]

# How many tokens the model appends to each prompt
additional_tokens = 2

# Greedily extend every prompt and show the completed text
for prompt in prompts:
    result = prompt
    for _ in range(additional_tokens):
        result += get_next_token(result)

    print(result)

## Find Important Neurons

### MLP neurons

The MLP activations are multiplied by $W_\text{out}$ (with $b_\text{out}$ added), added to the residual stream, and passed through the final RMS normalization, whose weights have been folded out (i.e. removed). The result is then multiplied by the unembedding weights ($W_U$) to get the logits. This means we can calculate the importance of each neuron by multiplying its activation by the corresponding combined weight.

In [9]:
prompts_neuron_importance: dict[str, np.ndarray] = {} # {prompt: [importance of each neuron]}

# The model has a single block (n_layers = 1), so the MLP under study lives in block 0
mlp_layer = 0

# Combined weights that take MLP neuron activations straight to logits.
# These do not depend on the prompt, so they are computed once, outside the loop.
mlp_neuron_weights = sd['unembed.W_U'] @ sd[f'blocks.{mlp_layer}.mlp.W_out'] # [ d_vocab x d_mlp ]

# Get the importance of each neuron, for the list of prompts
for prompt in prompts:
    # Setup the cache. Hooks from earlier runs are removed first; otherwise every
    # iteration would stack another set of caching hooks onto the model.
    model.reset_hooks()
    cache = {}
    model.cache_all(cache)

    # Get the logits
    logits = model(prompt, "logits")[0] # First batch item -> [ tokens x d_vocab ]
    predictions = torch.argmax(logits, 1) # [ tokens ]

    # Get the neurons
    mlp_neurons = cache[f'blocks.{mlp_layer}.mlp.hook_post_ln'][0] # [ tokens x d_mlp ]

    # Get the values for just the last position (the one we've predicted from).
    # `last_token_idx` is the vocabulary index of the predicted token.
    last_token_idx = predictions[-1].item()
    last_token_weights = mlp_neuron_weights[last_token_idx]
    last_token_mlp_neurons = mlp_neurons[-1]
    last_token_neuron_importance = last_token_weights * last_token_mlp_neurons

    # Express each neuron's contribution as a fraction of the total. Individual
    # entries can be negative: those neurons push against the predicted token.
    importance_percentage = last_token_neuron_importance / last_token_neuron_importance.sum()

    # Add to the results
    prompts_neuron_importance[prompt] = importance_percentage.cpu().numpy()

In [10]:
# One column per prompt, one row per neuron; add the cross-prompt average,
# label the index, and put the most important neurons first
importance = (
    pd.DataFrame(prompts_neuron_importance)
    .assign(average=lambda frame: frame.sum(axis=1) / len(frame.columns))
    .rename_axis("neuron")
    .sort_values(by="average", ascending=False)
)

importance.head(10)

Unnamed: 0_level_0,<h1>Title</,<b>Some bold text</,<p>An interesting paragraph</,<table><tr><th>Model name</,<li>List item</,average
neuron,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
3444,0.096613,0.075791,0.110834,0.09613,0.183113,0.112496
1785,0.101765,0.23889,0.040109,0.017412,0.027491,0.085133
2538,0.024398,0.029209,0.033149,0.034083,0.064888,0.037146
3810,0.051312,0.014371,0.056061,0.015946,0.0452,0.036578
733,0.059635,0.013014,0.009346,0.095818,0.003096,0.036182
1292,0.016477,0.031212,0.018242,0.060624,0.017517,0.028814
1931,0.064833,0.029711,0.038862,0.001635,0.006584,0.028325
3369,0.005495,0.0097,0.039516,-0.001043,0.080424,0.026818
2844,0.059883,0.016307,0.021545,0.017598,0.007661,0.024599
3313,0.023691,-0.001589,0.058233,-0.005241,0.025868,0.020193


From these results it's clear that a few neurons are important for working out how to close html tags (e.g. `3444`).

##### Visualising key neurons

We'll now run the model with a larger chunk of html, and see which parts trigger these neurons most.

In [11]:
# Create the prompt: a small HTML page with nested and repeated tags, plus one
# self-closing tag. This rebinds `prompt`, which earlier cells used as a loop variable.
prompt = """<html>
  <head>
    <title>Dubious HTML Skills</title>
  </head>
  <div id="heading">
    <h1>
      Dubious HTML Skills
    </h1>
    <p class="subtitle" style="color: red;">
      By Alan Cooney
    </p>
  </div>
  <hr/><!-- Self closing tag to add confusion -->
  <div id="main">
    <h2>
      An unordered list
    </h2>
    <div>
      <ul>
        <li>List item</li>
        <li>Another interesting thing</li>
      </ul>
      
      <h2>
        An ordered list
      </h2>
      <ol>
         <li>List item</li>
        <li>Another interesting thing</li>
      </ol>
    </div>
  </div>
</html>"""

# Tokenise the HTML sample
tokens = model.to_tokens(prompt)

# Drop any hooks registered by earlier cells, then record every hook point
# into a fresh cache
cache = {}
model.reset_hooks()
model.cache_all(cache)

# Forward pass; keep the first (only) batch item
logits = model(prompt, "logits")[0] # [225, 50278]

# MLP neuron activations for every position of the prompt
mlp_neurons = cache[f'blocks.{layer}.mlp.hook_post_ln'][0] # [225, 4096]
mlp_neurons.shape

torch.Size([225, 4096])

### Visualising Neuron Activations

We can now run a larger amount of html through the model, and see which parts of the text activate the most important neurons.

#### CSS

In [208]:
%%html
<style>
    /* Container: one inline block per rendered token. It is positioned
       relatively so the tooltip inside it can be positioned absolutely. */
    .outerBlock {
        position: relative;
        display: inline-block;
        width: auto;
        padding: 0 5px;
        margin: 1px 1px 1px -2px;
        color: #fff;
    }

    /* Tooltip text: hidden until the token is hovered */
    .outerBlock .tooltiptext {
      visibility: hidden;
      width: auto;
      background-color: black;
      color: #fff;
      text-align: center;
      padding: 5px 0;
      border-radius: 6px;

      /* Take the tooltip out of normal flow and stack it above neighbouring tokens */
      position: absolute;
      z-index: 1;
    }

    /* Show the tooltip text when you mouse over the tooltip container */
    .outerBlock:hover .tooltiptext {
      visibility: visible;
    }
    </style>

#### Neuron visualisation tool

In [12]:
def text_to_token_strings(text: str) -> list[str]:
    """Split text into the string form of each of its tokens.

    The text is encoded with the global model's tokenizer, and each resulting
    token id is decoded back to text on its own (not as a token index).

    NOTE(review): `Transformer.to_tokens` prepends a BOS token before the
    forward pass, while this function encodes the bare text. If so, the
    returned list is one entry shorter than the cached activations and offset
    by one position - confirm before aligning the two by index.
    """
    tokenizer = model.tokenizer
    token_ids = tokenizer.encode(text)
    return tokenizer.batch_decode(token_ids, clean_up_tokenization_spaces=False)

# Show as a dataframe: one row per token position, one column per MLP neuron
# NOTE(review): the token labels are attached by position. If the model input
# had a BOS token prepended, the labels are shifted by one row relative to the
# activations - confirm the two lengths match.
mlp_neurons_df = pd.DataFrame(mlp_neurons.cpu())
mlp_neurons_df["token"] = pd.Series(text_to_token_strings(prompt))
mlp_neurons_df = mlp_neurons_df.set_index(["token"])

# Set sum to 1: divide every neuron's column by its total over all positions.
# A single vectorised division replaces a Python loop over all 4096 columns.
mlp_neurons_df = mlp_neurons_df / mlp_neurons_df.sum(axis=0)

mlp_neurons_df.head()

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,4086,4087,4088,4089,4090,4091,4092,4093,4094,4095
token,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
<,0.003334,-0.016737,-0.294515,0.007116,0.098163,0.039022,-0.040334,0.023278,-0.003482,0.019633,...,-0.035695,0.020276,-0.005266,0.025283,0.007698,0.040952,0.008989,-0.138265,-0.001714,-0.254747
html,-0.004978,-0.013288,0.036868,0.003957,0.015209,-0.036187,0.053214,0.002984,-0.0016,0.00434,...,-0.001678,0.006071,-0.002647,0.004438,0.003236,0.002682,0.003131,0.001807,0.000993,0.115502
>,-0.007064,0.003895,-0.024963,0.006397,0.022449,3.5e-05,0.006373,0.004162,-0.002383,0.006273,...,0.002674,0.000766,-0.002899,0.00997,0.004745,0.005126,0.006045,0.008157,0.005207,-0.071425
\n,-0.006206,-0.008013,-0.005806,0.003963,0.025213,-0.012366,0.024937,0.004367,-0.001676,0.001181,...,0.032919,0.007068,-0.001562,0.010807,0.003854,8.9e-05,0.0056,-0.004245,-0.002289,-0.056427
,-0.01043,0.009132,-0.102923,0.00213,0.03219,0.009086,-0.00332,0.001905,0.001117,0.005754,...,-0.00143,0.009375,-0.002398,0.015501,0.004504,0.010736,0.005808,0.011842,-0.004847,-0.065786


In [225]:
def visualise_neuron_activation(prompt: list[str], hook_cache: np.ndarray, neuron_index: int) -> "HTML":
    """Visualise the activation of a specific neuron, for an input of text

    Each token is drawn as a coloured block: green gets brighter for positive
    activations and red gets brighter for negative ones. Hovering a block
    shows the token and its activation divided by the activation range.

    Args:
        prompt (list[str]): The prompt as a list of string tokens (as text not as indices)
        hook_cache (np.ndarray): The cache from a specific hook
            [ number_of_tokens x number_of_neurons ]. A torch tensor of the
            same shape also works.
        neuron_index (int): The index of the neuron that we're considering

    Returns:
        HTML: An IPython display object containing the rendered tokens.

    NOTE(review): tokens are matched to activations by position. If the model
    input had a BOS token prepended, `prompt` is offset by one relative to
    `hook_cache` - confirm the caller passes aligned inputs.
    """

    # Activations of the chosen neuron at every position
    activations = hook_cache[:, neuron_index]

    # Plain floats keep the maths (and the tooltip text) independent of whether
    # the cache is a numpy array or a torch tensor
    highest = float(activations.max())
    lowest = float(activations.min())

    # A constant column has no range; fall back to 1 to avoid dividing by zero
    activation_range = (highest - lowest) or 1.0

    rendered_tokens = []
    for idx, token in enumerate(prompt):

        # Tokens made only of spaces and/or new lines are layout, not content:
        # keep them visible in HTML instead of drawing a coloured block
        if token.strip(" \n") == "":
            rendered_tokens.append(token.replace(" ", "&nbsp;").replace("\n", "<br/>"))
            continue

        activation = float(activations[idx])
        scaled = activation / activation_range

        # Brightness grows with the magnitude of the activation, whichever its
        # sign, and is capped at the largest valid colour channel value
        intensity = min(255, int(abs(scaled) * 100 + 125))
        green = intensity if activation > 0 else 125
        red = intensity if activation < 0 else 125
        color = (red, green, 125)

        block = f"""<div style='background: rgb{color}' class="outerBlock">
                {html.escape(token)}
                <span class="tooltiptext">{html.escape(token)} : {round(scaled, 4)}</span>
            </div>"""

        rendered_tokens.append(block)

    return HTML("<div>" + "\n".join(rendered_tokens) + "</div>")

In [226]:
def visualise_neurons_activations(prompt: list[str], hook_cache: np.ndarray, default_neuron: int) -> None:
    """Show an interactive neuron picker with the activation view beneath it.

    Displays a dropdown listing every neuron index and an output area holding
    the rendering from `visualise_neuron_activation`. Picking another neuron
    re-renders the output area.

    Args:
        prompt (list[str]): The prompt as a list of string tokens
        hook_cache (np.ndarray): The cache from a specific hook
            [ number_of_tokens x number_of_neurons ]. A torch tensor of the
            same shape also works.
        default_neuron (int): The neuron selected when the widget first appears
    """

    # `.shape` exists on numpy arrays and torch tensors alike; `.size()` is
    # torch-only and fails for the annotated np.ndarray
    number_of_neurons = hook_cache.shape[1]

    neuron_dropdown = ipywidgets.Dropdown(
        options=list(range(number_of_neurons)),
        value=default_neuron,
        description='Neuron:',
    )
    out = ipywidgets.Output()

    def render(neuron_index: int) -> None:
        # Draw inside the output widget, replacing whatever was shown before
        with out:
            clear_output(wait=True)
            display(visualise_neuron_activation(prompt, hook_cache, neuron_index))

    def neuron_dropdown_onchange(change):
        render(change['new'])

    display(neuron_dropdown)
    display(out)
    render(default_neuron)

    # Only react to changes of the selected value, not to other widget traits
    neuron_dropdown.observe(neuron_dropdown_onchange, names='value')

In [227]:
# Neuron 3444 is the default selection: it has the highest average importance in the table above.
# The activations are moved to the CPU before being handed to the visualisation.
visualise_neurons_activations(text_to_token_strings(prompt), mlp_neurons.cpu(), 3444)

Dropdown(description='Neuron:', index=3444, options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,…

Output()