# SoLU Circuits

## Imports

In [1]:
import collections
import copy
import gc
import itertools
import json
import math
import os
import pickle
import random
import sys
import time
from functools import partial
from os import path
from pathlib import Path
from pprint import pprint

import datasets
import einops
import gdown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import pysvelte
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm.auto as tqdm
import tqdm.notebook as tqdm
import transformers
import wandb
from datasets import load_dataset
from easy_transformer import EasyTransformer, EasyTransformerConfig
from easy_transformer.EasyTransformer import Embed, Unembed, PosEmbed, LayerNorm, Attention
from easy_transformer.hook_points import HookedRootModule, HookPoint
from IPython.display import clear_output
from rich import print
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Fix for pysvelte import bug
# NOTE(review): this runs *after* `import pysvelte` earlier in this cell, so it
# cannot affect that import itself; presumably it is needed for imports that
# pysvelte performs later — confirm. The path is specific to one machine and
# Python 3.9 install.
sys.path.append('/home/user/.local/lib/python3.9/site-packages/pysvelte')

## Model Setup

### Config

In [2]:
# EasyTransformerConfig settings
cfg = dict(
    d_model=1024,
    d_head=64,
    n_layers=1,
    n_ctx=1024,
    d_vocab=50278,
    use_attn_result=False,
    act_fn='SoLU',
    eps=1e-5,
)

# Derived settings: the number of heads is d_model / d_head, and the MLP
# hidden width is 4x the residual width.
cfg.update(
    n_heads=cfg['d_model'] // cfg['d_head'],
    d_mlp=cfg['d_model'] * 4,
)

# Custom settings not supported by EasyTransformer directly
custom_cfg = dict(
    normalization='RMS',  # one of 'LN', 'RMS' or None
    model_checkpoint_name='SoLU_1L_1024W_final_checkpoint.pth',
    device='cuda',
)

In [3]:
class LayerNormPre(nn.Module):
    """Layer pre-normalization (no learned weight or bias).

    Centres each vector along the last dimension, then divides it by
    sqrt(mean(x**2) + eps) of the centred values — the normalization step of
    LayerNorm without any affine transform.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.eps = self.cfg['eps']

        # Adds a hook point for the normalization scale factor
        self.hook_scale = HookPoint()  # [batch, pos, 1]

    def forward(self, x):
        # Centre along the feature dimension.
        x = x - x.mean(dim=-1, keepdim=True)  # [batch, pos, d_model]
        # The hook wraps the full sqrt so hook_scale sees the scale factor
        # itself (as in RMSNorm), not the variance.
        scale = self.hook_scale(
            (x.pow(2).mean(dim=-1, keepdim=True) + self.eps).sqrt()
        )  # [batch, pos, 1]
        return x / scale


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature gain.

    Unlike LayerNorm the input is not mean-centred: each vector along the
    last dimension is divided by sqrt(mean(x**2) + eps) and then multiplied
    elementwise by the learned weight ``w`` (initialised to ones).
    """

    def __init__(self, cfg, length):
        super().__init__()
        self.cfg = cfg
        self.eps = cfg['eps']
        self.length = length
        # Learned gain, one entry per normalized feature.
        self.w = nn.Parameter(torch.ones(length))

        # Hook point exposing the per-position normalization scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)  # [batch, pos, 1]
        rms = self.hook_scale((mean_square + self.eps).sqrt())
        normalized = x / rms
        return normalized * self.w


class MLP(nn.Module):
    """MLP Layer

    Uses weights & biases (in & out), with an activation function in between,
    selected by ``cfg['act_fn']`` (case-insensitive): 'relu', 'gelu_new' or
    'SoLU'.

    W_in is d_mlp x d_model (i.e. changes size from d_model to d_mlp before
    the activation), and then vice versa for W_out. For SoLU the activation
    output is additionally passed through a LayerNorm over d_mlp before W_out.

    Raises:
        ValueError: if ``cfg['act_fn']`` is not one of the supported names.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty(
            self.cfg['d_mlp'], self.cfg['d_model']))
        nn.init.kaiming_uniform_(self.W_in, a=np.sqrt(5))
        self.b_in = nn.Parameter(torch.zeros(self.cfg['d_mlp']))
        self.W_out = nn.Parameter(torch.empty(
            self.cfg['d_model'], self.cfg['d_mlp']))
        nn.init.kaiming_uniform_(self.W_out, a=np.sqrt(5))
        self.b_out = nn.Parameter(torch.zeros(self.cfg['d_model']))

        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

        act_fn_name = self.cfg['act_fn'].lower()
        if act_fn_name == 'relu':
            self.act_fn = F.relu
        elif act_fn_name == 'gelu_new':
            # `gelu_new` is not defined or imported in this file, so this
            # branch uses the local tanh-approximation implementation.
            self.act_fn = self._gelu_new
        elif act_fn_name == 'solu':
            # SoLU: softmax over the d_mlp dimension, scaled elementwise by x.
            self.act_fn = lambda x: F.softmax(x, dim=-1) * x
            self.hook_post_ln = HookPoint()  # [batch, pos, d_mlp]
            self.ln = LayerNorm(self.cfg, self.cfg['d_mlp'])
        else:
            raise ValueError(
                f"Invalid activation function name: {self.cfg['act_fn']}")

    @staticmethod
    def _gelu_new(x):
        """GELU with the tanh approximation (the GPT-2 ``gelu_new`` variant)."""
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward(self, x):
        # Project up: [batch, pos, d_model] -> [batch, pos, d_mlp]
        x = self.hook_pre(torch.einsum(
            'md,bpd->bpm', self.W_in, x) + self.b_in)
        x = self.hook_post(self.act_fn(x))  # [batch, pos, d_mlp]
        if self.cfg['act_fn'].lower() == 'solu':
            # SoLU output is normalized over d_mlp before projecting back.
            x = self.hook_post_ln(self.ln(x))
        # Project down: [batch, pos, d_mlp] -> [batch, pos, d_model]
        x = torch.einsum('dm,bpm->bpd', self.W_out, x) + \
            self.b_out  # [batch, pos, d_model]
        return x


class TransformerBlock(nn.Module):
    """Transformer block with pre-normalization.

    NOTE: There was a bug below with RMS using LayerNorm by mistake - will
    likely need to be folded into weights/biases anyway for our work.

    When ``custom_cfg['normalization']`` is not None, forward does:

     - Norm (norm1), then Attn
     - Add residual (attn output + inputs)
     - Norm (norm2), then MLP
     - Add residual (MLP output + mid residual)

    When it is None, the two Norm steps are skipped.

    Args:
        cfg: EasyTransformer-style config dict.
        block_index: position of this block in the stack (currently unused).

    Raises:
        ValueError: if ``custom_cfg['normalization']`` is not 'RMS', 'LN' or
            None, rather than leaving norm1/norm2 undefined until forward.
    """

    def __init__(self, cfg, block_index):
        super().__init__()
        self.cfg = cfg
        normalization = custom_cfg['normalization']
        if normalization in ('RMS', 'LN'):
            # Both options use LayerNorm. For 'RMS' this is a dumb bug, but
            # it was also there during training *shrug*, so it is kept to
            # match the checkpoint.
            self.norm1 = LayerNorm(self.cfg, self.cfg['d_model'])
            self.norm2 = LayerNorm(self.cfg, self.cfg['d_model'])
        elif normalization is not None:
            raise ValueError(
                f"Invalid normalization: {normalization!r} "
                "(expected 'RMS', 'LN' or None)")
        self.attn = Attention(self.cfg)
        self.mlp = MLP(self.cfg)

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        # Note that resid_pre of layer k+1 is resid_post of layer k - given for convenience
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(self, x, pos_embed):
        """Apply attention and MLP sublayers, each with a residual connection.

        Args:
            x: residual stream [batch, pos, d_model].
            pos_embed: accepted for interface compatibility but not used here
                (Transformer.forward adds it to the residual stream before
                calling the blocks).

        Returns:
            The updated residual stream [batch, pos, d_model].
        """
        use_norm = custom_cfg['normalization'] is not None

        resid_pre = self.hook_resid_pre(x)  # [batch, pos, d_model]
        attn_in = self.norm1(resid_pre) if use_norm else resid_pre
        attn_out = self.hook_attn_out(self.attn(attn_in))  # [batch, pos, d_model]
        resid_mid = self.hook_resid_mid(
            resid_pre + attn_out)  # [batch, pos, d_model]

        mlp_in = self.norm2(resid_mid) if use_norm else resid_mid
        mlp_out = self.hook_mlp_out(self.mlp(mlp_in))  # [batch, pos, d_model]
        resid_post = self.hook_resid_post(
            resid_mid + mlp_out)  # [batch, pos, d_model]
        return resid_post

def loss_fn(logits, batch):
    """Mean next-token cross-entropy loss.

    The logits at position t are scored against token t+1 of ``batch``; the
    final position has no target and is dropped.

    Args:
        logits: tensor [batch, pos, d_vocab].
        batch: int64 token ids [batch, pos] (torch.gather needs a LongTensor
            index).

    Returns:
        Scalar tensor: the negative mean log-probability of the true next
        tokens.
    """
    targets = batch[:, 1:].unsqueeze(-1)  # [batch, pos - 1, 1]
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)  # [batch, pos - 1, d_vocab]
    target_log_probs = log_probs.gather(-1, targets).squeeze(-1)
    return -target_log_probs.mean()


class Transformer(HookedRootModule):
    """SoLU transformer assembled from EasyTransformer components.

    Token + positional embeddings feed ``cfg['n_layers']`` TransformerBlocks,
    followed by an optional final normalization (chosen by
    ``custom_cfg['normalization']``) and the unembedding.

    Args:
        cfg: EasyTransformer-style config dict.
        tokenizer: tokenizer used by ``to_tokens`` for string inputs.

    Raises:
        ValueError: if ``custom_cfg['normalization']`` is not 'RMS', 'LN' or
            None, rather than leaving ``self.norm`` undefined until forward.
    """

    def __init__(self, cfg, tokenizer):
        super().__init__()

        self.cfg = cfg
        self.tokenizer = tokenizer

        self.embed = Embed(self.cfg)
        self.hook_embed = HookPoint()  # [batch, pos, d_model]

        self.pos_embed = PosEmbed(self.cfg)
        self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]

        normalization = custom_cfg['normalization']
        if normalization == 'RMS':
            self.norm = RMSNorm(self.cfg, self.cfg['d_model'])
        elif normalization == 'LN':
            self.norm = LayerNorm(self.cfg, self.cfg['d_model'])
        elif normalization is not None:
            raise ValueError(
                f"Invalid normalization: {normalization!r} "
                "(expected 'RMS', 'LN' or None)")

        self.blocks = nn.ModuleList([
            TransformerBlock(self.cfg, block_index)
            for block_index in range(self.cfg['n_layers'])
        ])

        self.unembed = Unembed(self.cfg)

        # Gives each module a parameter with its name (relative to this root module)
        # Needed for HookPoints to work
        self.setup()

    def forward(self, tokens, return_type='both', calc_logits=True):
        """Run the model on a batch of tokens or a text string.

        Args:
            tokens: token ids [batch, pos], or a str, which is converted with
                ``to_tokens`` (BOS prepended, batch_size=1).
            return_type: 'both' -> (logits, loss); 'logits' -> logits;
                'loss' -> loss. Any other value returns None.
            calc_logits: if False, skip the final norm and unembedding and
                return None. This significantly speeds up runtime on small
                models and reduces memory consumption, and can be used when
                we only want the (hooked) activations, e.g. for finding max
                activating dataset examples.
        """
        if isinstance(tokens, str):
            # If text, convert to tokens (batch_size=1)
            tokens = self.to_tokens(tokens)
        embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]
        pos_embed = self.hook_pos_embed(
            self.pos_embed(tokens))  # [batch, pos, d_model]

        # Positional embeddings are always added to the residual stream: the
        # standard Attention used here offers no way to disable them.
        residual = embed + pos_embed  # [batch, pos, d_model]

        for block in self.blocks:
            # Note that each block includes skip connections, so we don't need
            # residual + block(residual)
            residual = block(residual, pos_embed)  # [batch, pos, d_model]
        if not calc_logits:
            return None
        if custom_cfg['normalization'] is not None:
            residual = self.norm(residual)
        logits = self.unembed(residual)  # [batch, pos, d_vocab]
        if return_type == 'both':
            return (logits, loss_fn(logits, tokens))
        if return_type == 'logits':
            return logits
        if return_type == 'loss':
            return loss_fn(logits, tokens)
        # Unrecognised return_type (e.g. None): no output.
        return None

    def to_tokens(self, text):
        """Tokenize ``text`` with the tokenizer's BOS token prepended.

        Returns the ``input_ids`` tensor (batch_size=1) moved to
        ``custom_cfg['device']``.
        """
        return self.tokenizer(self.tokenizer.bos_token+text, return_tensors='pt')['input_ids'].to(custom_cfg['device'])


# GPT-NeoX-20B tokenizer, extended with a dedicated padding token.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
pad_token = '<PAD>'
tokenizer.add_special_tokens(dict(pad_token=pad_token))

# Build the model and move its parameters to the configured device.
model = Transformer(cfg, tokenizer)
model.to(custom_cfg['device'])

Transformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (norm): RMSNorm(
    (hook_scale): HookPoint()
  )
  (blocks): ModuleList(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (norm2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_attn): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
        (hook_post_ln): HookPoint()
        (ln): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoi

### Load from the checkpoint

In [4]:
# Checkpoint provided by Neel Nanda
checkpoint_url = "https://drive.google.com/file/d/16bqEZg9Oq0WT2xOcNS1HJkmR7qB2G14o/view"

# Create the directory if it doesn't exist
checkpoint_dir = "/tmp/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Download the checkpoint if it isn't cached locally yet. fuzzy=True lets
# gdown extract the file id from a Drive "view" URL.
checkpoint_file = path.join(checkpoint_dir, custom_cfg['model_checkpoint_name'])
if not path.exists(checkpoint_file):
    gdown.download(checkpoint_url, checkpoint_file, quiet=False, fuzzy=True)

# Load the checkpoint onto the CPU first: without map_location, tensors saved
# from a GPU are restored onto a GPU, which fails on CPU-only machines.
# load_state_dict copies the values into the model's existing parameters on
# whatever device the model lives on.
# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints from
# a trusted source (recent PyTorch versions also offer weights_only=True).
state_dict = torch.load(checkpoint_file, map_location='cpu')
# EasyTransformer has an additional bias term for the unembedding, so we simply set it to zero.
state_dict["unembed.b_U"] = torch.zeros(cfg['d_vocab'])

model.load_state_dict(state_dict)

<All keys matched successfully>

## Find interesting activations

A 1-layer model without an MLP can't do much more than skip-trigrams. Whilst the added MLP layer may improve this a little, the prompts will still need to have quite simple answers.

In this case we'll look for the ability of the model to close HTML tags. As a simple overview of how HTML tags work, whenever a tag is opened (e.g. `<b>` for bold) it must be closed once you no longer want it to apply (e.g. `<b>bold text</b> normal text`).

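Note that `</` is a single token, so we can't end a prompt with `<` and expect the model to predict `/`. Instead, some prompts below end with `</`, so the model only needs to predict the tag name.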

In [5]:
def get_next_token(prompt: str, verbose: bool = False) -> str:
    """Greedily predict the token that follows ``prompt``.

    Runs one forward pass of the global ``model`` (which tokenizes the string
    with a BOS token prepended) and decodes the most likely token at the final
    position.

    Args:
        prompt: Text to continue.
        verbose: If True, print the full log-probability tensor for the
            prompt. Off by default because it is [1, pos, d_vocab] in size.

    Returns:
        The decoded text of the most likely next token.
    """
    # Inference only: no autograd graph, and skip computing the unused loss.
    with torch.no_grad():
        logits = model(prompt, return_type='logits')  # [1, pos, d_vocab]
    if verbose:
        print(F.log_softmax(logits, dim=-1))
    # Only the last position predicts the *next* token.
    last_log_probs = F.log_softmax(logits[0, -1], dim=-1)  # [d_vocab]
    next_token_id = torch.argmax(last_log_probs).item()
    return model.tokenizer.decode(next_token_id)

In [6]:
# Example prompts to run through the model
prompts = [
    "<h1>Title",
    "<b>Some bold text</",
    "<p>An interesting paragraph</",
    "<table><tr><th>Model name"
]

# How many tokens the model greedily appends to each prompt.
additional_tokens = 2

# Extend each prompt token by token, then show the completed text.
for prompt in prompts:
    result = prompt
    for _ in range(additional_tokens):
        next_token = get_next_token(result)
        result += next_token
    print(result)