<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Othello_GPT.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This is a demo notebook porting the weights of the Othello-GPT model from the excellent [Emergent World Representations](https://arxiv.org/pdf/2210.13382.pdf) paper to my TransformerLens library. Check out the paper's [blog post](https://thegradient.pub/othello/), [paper](https://arxiv.org/pdf/2210.13382.pdf), and [github](https://github.com/likenneth/othello_world/).

I think this is a super interesting paper, and I want to better enable work trying to reverse-engineer this model! I'm particularly curious about: 
* Why do non-linear probes work much better than linear probes? 
    * Is the model internally representing the board in a usable yet non-linear way? 
    * Is there a representation of simpler concepts (eg diagonal lines in the board, number of black pieces, whether a cell is blank) that the non-linear probe uses to compute board positions, but where the model internally reasons in this simpler representation?
* What's going on with the model editing? 
    * The paper edits across many layers at once. What's the minimal edit that works? 
        * Can we edit just before the final layer? 
        * Can we do a single edit rather than across many layers?
    * If we contrast model activations pre and post edit, what changes? 
        * Which components shift their output and how does this affect the logits? 
        * Is there significant depth of composition, or does it just affect the output logits?
* Can we find any non-trivial circuits in the model?
    * Start with [exploratory techniques](https://neelnanda.io/exploratory-analysis-demo), like direct logit attribution, or just looking at head attention patterns, and try to get traction
    * Pick a simple sub-task, eg figuring out whether a cell is blank, and try to interpret that.
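For the first question, it may help to see how small the structural difference between the two probe families is. This is a minimal sketch, not the paper's code: it assumes each of the 64 squares is classified as empty, black, or white from a single residual-stream vector, and the MLP hidden width of 256 and the `probe_logits` helper are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

D_MODEL = 512    # residual stream width of this model (cfg.d_model)
N_SQUARES = 64   # 8x8 board
N_STATES = 3     # assumed classes per square: empty, black, white

# Linear probe: a single affine map from the residual stream to per-square logits
linear_probe = nn.Linear(D_MODEL, N_SQUARES * N_STATES)

# Non-linear probe: a small MLP (the hidden width of 256 is an arbitrary choice)
mlp_probe = nn.Sequential(
    nn.Linear(D_MODEL, 256),
    nn.ReLU(),
    nn.Linear(256, N_SQUARES * N_STATES),
)


def probe_logits(probe: nn.Module, resid: torch.Tensor) -> torch.Tensor:
    """Map residual activations [..., d_model] to logits [..., 64 squares, 3 states]."""
    return probe(resid).reshape(*resid.shape[:-1], N_SQUARES, N_STATES)


resid = torch.randn(2, 59, D_MODEL)  # stand-in for cached activations at one layer
print(probe_logits(linear_probe, resid).shape)  # torch.Size([2, 59, 64, 3])
```

The linear probe can only succeed if the board state is linearly decodable from the residual stream, while the MLP can also exploit non-linear encodings; the gap between their accuracies is what the question asks about.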


I uploaded pre-converted checkpoints to HuggingFace, which can be automatically downloaded, and there's a code snippet to do this after the setup.

If you'd rather start from the author's original checkpoints, I also wrote a script to load and convert them, given below.

To get started, check out the TransformerLens [main tutorial](https://neelnanda.io/transformer-lens-demo) and [tutorial on exploratory techniques](https://neelnanda.io/exploratory-analysis-demo), and the author's [excellent GitHub](https://github.com/likenneth/othello_world/) (Ot**hello world**) for various notebooks demonstrating their code, showing how to load inputs, etc. And check out my [concrete open problems in mechanistic interpretability](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj) sequence, especially the algorithmic problems post, for tips on this style of research.


# Setup (Skip)

In [5]:
# NBVAL_IGNORE_OUTPUT
import os

# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab

    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library; use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis
    %pip install torchtyping

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
import circuitsvis as cv

# Testing that the library works
cv.examples.hello("Neel")

In [6]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [7]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)

We turn automatic differentiation off to save GPU memory, since this notebook focuses on model inference, not model training.

In [8]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2b63dad90>
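The global switch above applies to everything that follows. If you later want gradients back for part of the notebook (for instance, to train a probe on cached activations, as in the questions at the top), PyTorch's `torch.enable_grad()` context manager re-enables them locally. A minimal sketch, using a toy tensor `w` that is purely illustrative and not part of the model:

```python
import torch

torch.set_grad_enabled(False)  # global switch, as in the cell above

w = torch.randn(3, requires_grad=True)  # toy parameter, purely for illustration
y = (w * 2).sum()
assert not y.requires_grad  # no autograd graph was recorded for y

with torch.enable_grad():  # locally re-enable autograd
    y = (w * 2).sum()
    y.backward()

assert torch.equal(w.grad, torch.full((3,), 2.0))  # d(sum(2w))/dw = 2
assert not torch.is_grad_enabled()  # the global setting is restored on exit
```

The `torch.no_grad()` context manager is the mirror image, if you prefer to keep gradients on globally and disable them only around inference.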

Plotting helper functions:

In [9]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    ).show(renderer)


def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs).show(
        renderer
    )


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y, x=x, labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs
    ).show(renderer)

# Othello GPT

In [10]:
LOAD_AND_CONVERT_CHECKPOINT = False

In [11]:
import transformer_lens.utils as utils

cfg = HookedTransformerConfig(
    n_layers=8,
    d_model=512,
    d_head=64,
    n_heads=8,
    d_mlp=2048,
    d_vocab=61,
    n_ctx=59,
    act_fn="gelu",
    normalization_type="LNPre",
)
model = HookedTransformer(cfg)
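The config values `d_vocab=61` and `n_ctx=59` follow from the game: a game has at most 60 moves (the 64 squares minus the 4 center squares that start occupied), and the model reads up to 59 moves and predicts the next move at each position. The sketch below decodes tokens to squares, assuming the othello_world convention that token 0 is padding and tokens 1 to 60 are the 60 playable squares in row-major order. The helper names and the `C3`-style labels are hypothetical, not part of TransformerLens or the author's code.

```python
# Assumption: token 0 is padding; tokens 1..60 are the non-center squares in row-major order
CENTER_SQUARES = {27, 28, 35, 36}  # the four squares occupied before the first move
PLAYABLE_SQUARES = [sq for sq in range(64) if sq not in CENTER_SQUARES]  # 60 squares


def token_to_square(token: int) -> int:
    """Map a move token (1..60) to a square index (0..63, row-major)."""
    if not 1 <= token <= 60:
        raise ValueError(f"expected a move token in 1..60, got {token}")
    return PLAYABLE_SQUARES[token - 1]


def square_to_token(square: int) -> int:
    """Inverse of token_to_square; raises ValueError for the center squares."""
    return PLAYABLE_SQUARES.index(square) + 1


def square_to_label(square: int) -> str:
    """Hypothetical label: <row letter><column digit>, eg 19 -> 'C3'."""
    row, col = divmod(square, 8)
    return "ABCDEFGH"[row] + str(col)


print([square_to_label(token_to_square(t)) for t in [20, 19, 18]])  # ['C3', 'C2', 'C1']
```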

In [12]:
# NBVAL_IGNORE_OUTPUT
sd = utils.download_file_from_hf(
    "NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth"
)
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)


<All keys matched successfully>

Code to load and convert one of the author's checkpoints to TransformerLens:

In [13]:
def convert_to_transformer_lens_format(in_sd, n_layers=8, n_heads=8):
    out_sd = {}
    out_sd["pos_embed.W_pos"] = in_sd["pos_emb"].squeeze(0)
    out_sd["embed.W_E"] = in_sd["tok_emb.weight"]

    out_sd["ln_final.w"] = in_sd["ln_f.weight"]
    out_sd["ln_final.b"] = in_sd["ln_f.bias"]
    out_sd["unembed.W_U"] = in_sd["head.weight"].T

    for layer in range(n_layers):
        out_sd[f"blocks.{layer}.ln1.w"] = in_sd[f"blocks.{layer}.ln1.weight"]
        out_sd[f"blocks.{layer}.ln1.b"] = in_sd[f"blocks.{layer}.ln1.bias"]
        out_sd[f"blocks.{layer}.ln2.w"] = in_sd[f"blocks.{layer}.ln2.weight"]
        out_sd[f"blocks.{layer}.ln2.b"] = in_sd[f"blocks.{layer}.ln2.bias"]

        out_sd[f"blocks.{layer}.attn.W_Q"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.query.weight"],
            "(head d_head) d_model -> head d_model d_head",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_Q"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.query.bias"],
            "(head d_head) -> head d_head",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.W_K"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.key.weight"],
            "(head d_head) d_model -> head d_model d_head",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_K"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.key.bias"],
            "(head d_head) -> head d_head",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.W_V"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.value.weight"],
            "(head d_head) d_model -> head d_model d_head",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_V"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.value.bias"],
            "(head d_head) -> head d_head",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.W_O"] = einops.rearrange(
            in_sd[f"blocks.{layer}.attn.proj.weight"],
            "d_model (head d_head) -> head d_head d_model",
            head=n_heads,
        )
        out_sd[f"blocks.{layer}.attn.b_O"] = in_sd[f"blocks.{layer}.attn.proj.bias"]

        out_sd[f"blocks.{layer}.mlp.b_in"] = in_sd[f"blocks.{layer}.mlp.0.bias"]
        out_sd[f"blocks.{layer}.mlp.W_in"] = in_sd[f"blocks.{layer}.mlp.0.weight"].T
        out_sd[f"blocks.{layer}.mlp.b_out"] = in_sd[f"blocks.{layer}.mlp.2.bias"]
        out_sd[f"blocks.{layer}.mlp.W_out"] = in_sd[f"blocks.{layer}.mlp.2.weight"].T

    return out_sd


if LOAD_AND_CONVERT_CHECKPOINT:
    synthetic_checkpoint = torch.load("/workspace/othello_world/gpt_synthetic.ckpt")
    for name, param in synthetic_checkpoint.items():
        if name.startswith("blocks.0") or not name.startswith("blocks"):
            print(name, param.shape)

    cfg = HookedTransformerConfig(
        n_layers=8,
        d_model=512,
        d_head=64,
        n_heads=8,
        d_mlp=2048,
        d_vocab=61,
        n_ctx=59,
        act_fn="gelu",
        normalization_type="LNPre",
    )
    model = HookedTransformer(cfg)

    model.load_and_process_state_dict(
        convert_to_transformer_lens_format(synthetic_checkpoint)
    )
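The only subtle part of the conversion is the head-splitting rearranges. Judging by the key names (`attn.query`, `attn.proj`, `mlp.0`, ...), the author's checkpoints use minGPT-style fused `nn.Linear` layers, whose weights have shape `[out_features, in_features]`, while TransformerLens stores per-head matrices. The sketch below checks on random tensors that the `W_Q` and `W_O` patterns used above compute the same thing as the fused layers. To avoid an einops dependency it uses the equivalent `reshape`/`permute` calls; this check is an illustration, not part of the original conversion script.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_heads, d_head, d_model, seq_len = 8, 64, 512, 5
dt = torch.float64  # double precision so the comparison is not dominated by rounding

# --- Query projection: fused weight [(head d_head), d_model] ---
query_weight = torch.randn(n_heads * d_head, d_model, dtype=dt)
query_bias = torch.randn(n_heads * d_head, dtype=dt)
x = torch.randn(seq_len, d_model, dtype=dt)

# Equivalent to "(head d_head) d_model -> head d_model d_head"
W_Q = query_weight.reshape(n_heads, d_head, d_model).transpose(1, 2)
# Equivalent to "(head d_head) -> head d_head"
b_Q = query_bias.reshape(n_heads, d_head)

q_per_head = torch.einsum("pm,hmd->phd", x, W_Q) + b_Q  # TransformerLens-style
q_fused = F.linear(x, query_weight, query_bias).reshape(seq_len, n_heads, d_head)
assert torch.allclose(q_per_head, q_fused)

# --- Output projection: fused weight [d_model, (head d_head)] ---
proj_weight = torch.randn(d_model, n_heads * d_head, dtype=dt)
proj_bias = torch.randn(d_model, dtype=dt)
z = torch.randn(seq_len, n_heads, d_head, dtype=dt)  # per-head attention outputs

# Equivalent to "d_model (head d_head) -> head d_head d_model"
W_O = proj_weight.reshape(d_model, n_heads, d_head).permute(1, 2, 0)

out_per_head = torch.einsum("phd,hdm->pm", z, W_O) + proj_bias
out_fused = F.linear(z.reshape(seq_len, n_heads * d_head), proj_weight, proj_bias)
assert torch.allclose(out_per_head, out_fused)
```

`W_K` and `W_V` use the same pattern as `W_Q`, and the MLP weights only need a transpose, because `nn.Linear` stores `[out, in]` while TransformerLens expects `[in, out]`.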

Test that the synthetic checkpoint gives the correct outputs:

In [16]:
# An example input: a sequence of 59 move tokens
sample_input = torch.tensor(
    [
        [
            20, 19, 18, 10, 2, 1, 27, 3, 41, 42, 34, 12, 4, 40, 11, 29, 43, 13, 48, 56,
            33, 39, 22, 44, 24, 5, 46, 6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37, 9,
            25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7,
        ]
    ]
)
# The argmax of the output (ie the most likely next move from each position)
sample_output = torch.tensor(
    [
        [
            21, 41, 40, 34, 40, 41, 3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33, 5, 33, 5,
            52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
            28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15, 8, 7, 8,
        ]
    ]
)
model_output = model(sample_input).argmax(dim=-1)
# Check that the loaded weights reproduce the reference predictions exactly
assert torch.equal(model_output, sample_output), "Model predictions differ from the reference"
model_output

tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]])
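Matching the reference argmax confirms the weights loaded correctly, but says nothing about whether the predictions are sensible. A cheap extra check, sketched below as a hypothetical helper `repeated_square_predictions`, flags any position where the model predicts a square that has already been played. This is a necessary condition for legality, not a sufficient one: it ignores Othello's flipping rules and assumes the sequence contains no pass or padding tokens.

```python
from typing import List, Sequence


def repeated_square_predictions(moves: Sequence[int], predictions: Sequence[int]) -> List[int]:
    """Return the positions where the predicted next move is an already-played square.

    moves[i] is the i-th move token and predictions[i] is the model's guess for move i + 1,
    so the prediction at position i must not repeat any of moves[0..i].
    """
    played = set()
    bad_positions = []
    for pos, (move, pred) in enumerate(zip(moves, predictions)):
        played.add(move)
        if pred in played:
            bad_positions.append(pos)
    return bad_positions


# Toy example: at position 1 the "prediction" 20 repeats the first move
print(repeated_square_predictions([20, 19, 18], [21, 20, 5]))  # [1]
```

In the notebook you could call it as `repeated_square_predictions(sample_input[0].tolist(), model(sample_input).argmax(dim=-1)[0].tolist())`; for the sample game it should return an empty list. Note also that the final prediction, 8, is the only move token from 1 to 60 that never appears in the input.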