# Introduction

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

This is code to accompany my blog post building on Kenneth Li et al.'s paper Emergent World Representations. I found that the network actually learns a **linear** world model, but in terms of whether a cell contains a piece of **my colour** vs the **opponent's colour**. I demonstrate how to use and intervene with the linear probe I found, how to use the probe to start interpreting the model and studying circuits, and provide some starter code for neuron interpretability and activation patching.

If you're new to mechanistic interpretability, check out [my blog post on getting started](https://neelnanda.io/getting-started). This notebook heavily uses my TransformerLens library; check out [the main tutorial for a better introduction](https://neelnanda.io/transformer-lens-demo).

Read the blog post here: https://neelnanda.io/othello

Look up unfamiliar terms here: https://neelnanda.io/glossary

The paper: https://arxiv.org/pdf/2210.13382.pdf



## Setup (Don't Read This)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens==1.2.1
    %pip install git+https://github.com/neelnanda-io/neel-plotly

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML


In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
torch.set_grad_enabled(False)
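As a small illustration of what this global switch does, here is a toy sketch (the tensor names `w`, `y` and `z` are made up for the example and are not used elsewhere in the notebook):

```python
import torch

torch.set_grad_enabled(False)  # same global switch as in the cell above

# A toy parameter that would normally be tracked by autograd
w = torch.ones(3, requires_grad=True)
y = (w * 2).sum()

# With grad mode off, no computation graph is built for y
print(y.requires_grad, y.grad_fn)

# Gradients can still be re-enabled locally when needed, e.g. for training
with torch.enable_grad():
    z = (w * 2).sum()
print(z.requires_grad)
```

Because the switch is global, any later cell that trains something needs gradients turned back on, either globally or with a local `torch.enable_grad()` block as sketched here.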

Plotting helper functions:

In [None]:
from neel_plotly import line, scatter, imshow, histogram

# Othello GPT

<details><summary>I was in a massive rush when I made this codebase, so it's a bit of a mess, sorry! This Colab is an attempt to put a smiley face on the underlying shoggoth, but if you want more of the guts, here's an info dump</summary>

This codebase is a bit of a mess! This colab is an attempt to be a pretty mask on top of the shoggoth, but if it helps, here's an info dump I wrote for someone about interpreting this codebase:

Technical details:

-   Games are 60 moves, but the model can only take in 59. It's trained to predict the next move, so they give it the first 59 moves (0<=...<59) and evaluate the predictions for each next move (1<=...<60). There is no Beginning of Sequence token, and the model never tries to predict the first move of the game

-   This means that, while in Othello black plays first, here white plays "first" because first is actually second

-   You can get code to load the synthetic model (ie trained to play uniformly random legal moves) into TransformerLens here: [https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello\_GPT.ipynb](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb)
-   You can load in their synthetically generated games [from their Github](https://github.com/likenneth/othello_world) (there's a google drive link)
-   Their model has 8 layers, residual stream width 512, 8 heads per layer and 2048 neurons per layer.

-   The vocab size is 61. 0 is -100, which I *think* means pass; I just filtered out the rare games that include that move and ignore it. 1 to 60 (inclusive) means the board moves in lexicographic order (A0, A1, ..., A7, B0, ...) but *skipping* D3, D4, E3 and E4. These are at the center of the board and so can never be played, because Othello starts with them filled.

-   There's 3 ways to denote a board cell. I call them "int", "string" and "label" (which is terrible notation, sorry).

-   "label" means the label for a board cell, \["A0", ..., "A7", ..., "H7"\] (I index at 0 not 1, sorry!).
-   "int" means "part of the model vocabulary", so 1 means A0, we then count up but miss the center squares, so 27 is D2, 28 is D5, 33 is E2 and 34 is E5.
-   "string" means "the input format of the OthelloBoardState class". These are integers (sorry!) from 0 to 63, and exactly correspond to labels A0, ..., H7, without skipping any center cells. OthelloBoardState is a class in data/othello.py that can play the Othello game and tell you the board state and valid moves (created by the authors, not me)
-   I have utility functions to\_int, to\_string, string\_to\_label and int\_to\_label in tl\_othello\_utils.py to do this

-   The embedding and unembedding are untied (ie, in contrast to most language models, the map W\_U from final residual to the logits is *not* the transpose of W\_E, the map from tokens to the initial residual. They're unrelated matrices)
-   tl\_othello\_utils.py is my utils file, with various functions to load games, etc. `board_seqs_string` and `board_seqs_int` are massive saved tensors with every move across all 4.5M synthetic games in both string and int format; these are 2.3GB so I haven't attached them lol. You can recreate them from the synthetic games they provide. It also provides a bunch of plotting functions to make nice Othello board states, and some random other utilities
-   `tl_probing.py` is my probe training file. But it was used to train a *second* probe, linear\_probe\_L4\_blank\_vs\_color\_v1.pth. This probe actually didn't work very well for analysing the model (despite getting great accuracy) and I don't know why. It was trained on layer 4, to do a binary classification on blank vs not blank, and on my color vs their color *conditional* on not being blank (ie not evaluated if blank). For some reason, the "this cell is my color" direction has a significant dot product with the "is blank" direction, and this makes it much worse for eg interpreting neurons.
-   `tl_scratch.py` is where I did some initial exploration, including activation patching between different final moves
-   `tl_exploration.py` is where I did my most recent exploration, verifying that the probe works, doing probe interventions (CTRL F for `newly_legal`) and using the probe to interpret neurons

</details>
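To make the "int" / "string" / "label" conventions from the info dump concrete, here is a minimal sketch of the mapping. The `sketch_*` helpers and constants are hypothetical names written for this illustration; the real utilities are `to_int`, `to_string`, `string_to_label` and `int_to_label` from the utils file, and their implementations may differ.

```python
ALPHA = "ABCDEFGH"
# The four center squares (D3, D4, E3, E4) in "string" format, i.e. 0..63
CENTER_STRINGS = {27, 28, 35, 36}
# Position k holds the "string" index for "int" token k + 1 (token 0 is reserved)
PLAYABLE_STRINGS = [s for s in range(64) if s not in CENTER_STRINGS]


def sketch_int_to_string(token: int) -> int:
    """Model vocabulary token (1..60) -> board square index (0..63)."""
    return PLAYABLE_STRINGS[token - 1]


def sketch_string_to_int(square: int) -> int:
    """Board square index (0..63, not a center square) -> model token (1..60)."""
    return PLAYABLE_STRINGS.index(square) + 1


def sketch_string_to_label(square: int) -> str:
    """Board square index (0..63) -> label such as "A0" or "H7"."""
    return f"{ALPHA[square // 8]}{square % 8}"


# Print the tokens that sit on either side of the skipped center squares
for token in (1, 27, 28, 33, 34, 60):
    print(token, sketch_string_to_label(sketch_int_to_string(token)))
```

The loop prints the tokens quoted in the list above, so it doubles as a quick sanity check on the skipping rule.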


## Loading the model

This loads a conversion of the authors' synthetic model checkpoint to TransformerLens format. See [this notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb) for how.

In [None]:
import transformer_lens.utils as utils
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)

In [None]:

sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

Code to load and convert one of the author's checkpoints to TransformerLens:

Testing code for the synthetic checkpoint giving the correct outputs

In [None]:
# An example input
sample_input = torch.tensor([[20, 19, 18, 10, 2, 1, 27, 3, 41, 42, 34, 12, 4, 40, 11, 29, 43, 13, 48, 56, 33, 39, 22, 44, 24, 5, 46, 6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37, 9, 25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7]])
# The argmax of the output (ie the most likely next move from each position)
sample_output = torch.tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]])
model(sample_input).argmax(dim=-1)

tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]], device='cuda:0')

## Loading Othello Content
Boring setup code to load in 100K sample Othello games, the linear probe, and some utility functions

In [None]:

if IN_COLAB:
    !git clone https://github.com/likenneth/othello_world
    OTHELLO_ROOT = Path("/content/othello_world/")
    import sys
    sys.path.append(str(OTHELLO_ROOT/"mechanistic_interpretability"))
    from mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState
else:
    OTHELLO_ROOT = Path("/workspace/othello_world/")
    from tl_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState


We load in a big tensor of 100,000 games, each with 60 moves. This is in the format the model wants, with 1-60 representing the 60 playable squares, and 0 representing pass.

We also load in the same set of games, in the same order, but in "string" format. This is still a tensor of ints, but it refers to moves with numbers from 0 to 63 rather than in the model's compressed format of 1 to 60.

Number of games: 100000
Length of game: 60
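Each game has 60 moves but the model's context is 59, so a game is split into 59 inputs and 59 next-move targets. A minimal sketch with a made-up game (the token values here are placeholders, not a legal Othello game):

```python
# A made-up "game" of 60 tokens standing in for one row of the games tensor
game = list(range(1, 61))

inputs = game[:-1]   # moves 0..58, what the model sees (n_ctx = 59)
targets = game[1:]   # moves 1..59, what its predictions are scored against

assert len(inputs) == len(targets) == 59

# The prediction made at position t is compared against the move at t + 1
pairs = list(zip(inputs, targets))
print(pairs[0], pairs[-1])
```

This is why later cells pass `seqs[:, :-1]` to the model: the final move is only ever a target, never an input.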


In [None]:
board_seqs_int.unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

In [None]:
import pickle
from google.colab import drive
drive.mount('/content/drive')
HOME_DIR = Path("drive/MyDrive/6S898/")

board_seqs_int = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_int_small.npy"), dtype=torch.long)
board_seqs_string = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_string_small.npy"), dtype=torch.long)

num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games,)
print("Length of game:", length_of_game)
pkl = list(HOME_DIR.glob("*.pickle"))
"""
def int_to_string(y):
  x = y.clone()
  x[x > 26] = x[x > 26] - 2
  x[x > 34] = x[x > 34] - 2
  x += 1
  return x

for f in pkl:
  t = pickle.load(f.open("rb"))
  t = torch.tensor(list(filter(lambda x: len(x) == 60, t)))
  #print(t.shape)
  board_seqs_int = torch.cat((board_seqs_int, int_to_string(t)))
print(board_seqs_int.shape)
"""

In [None]:
board_seqs_int[0]
stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
itos_mapping = {value: index for index, value in enumerate(stoi_indices)}
print(itos_mapping)


In [None]:


alpha = "ABCDEFGH"

def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"


board_labels = list(map(to_board_label, stoi_indices))

## Making some utilities

At this point, I'll stop and get some aggregate data that will be useful later: a tensor of valid moves, a tensor of board states, and a cache of all model activations across a fixed set of games (1,000 in the cell below; in practice, you want as many as will comfortably fit into GPU memory). It's really convenient to have the ability to quickly run an experiment across a bunch of games! And one of the great things about small models on algorithmic tasks is that you can just do stuff like this.

For want of creativity, let's call these the **focus games**

In [None]:
num_games = 1000
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

A big stack of each move's board state and a big stack of the valid moves in each game (one hot encoded to be a nice tensor)

In [None]:
def one_hot(list_of_ints, num_classes=64):
    out = torch.zeros((num_classes,), dtype=torch.float32)
    out[list_of_ints] = 1.
    return out
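A quick sketch of how `one_hot` can be stacked into a tensor of valid moves. The move lists below are made up for illustration; in the notebook they would come from the board state utilities.

```python
import torch


def one_hot(list_of_ints, num_classes=64):
    # Same helper as in the cell above
    out = torch.zeros((num_classes,), dtype=torch.float32)
    out[list_of_ints] = 1.
    return out


# Hypothetical valid moves (in "string" format, 0..63) at two consecutive positions
valid_moves_per_position = [[19, 26, 37, 44], [18, 20, 34]]

# One row per position, one column per board square
stacked = torch.stack([one_hot(moves) for moves in valid_moves_per_position])
print(stacked.shape)
# Summing over squares recovers the number of valid moves at each position
print(stacked.sum(dim=-1))
```

Stacking like this gives a dense tensor that can be compared directly against model outputs, at the cost of storing 64 floats per position.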

In [None]:
cfg = {
    "seed": 49,
    "batch_size": 50,
    "buffer_mult": 50*59,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 8,
    "seq_len": 60,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
    "model_name": "gelu-2l",
    "site": "mlp_out",
    "layer": 5,
    "device": "cuda:0",
    "act_name": f"blocks.4.hook_mlp_out",
    "model_batch_size": 50
}
site_to_size = {
    "mlp_out": 512,
    "post": 2048,
    "resid_pre": 512,
    "resid_mid": 512,
    "resid_post": 512,
}
def post_init_cfg(cfg):
    #cfg["model_batch_size"] = cfg["batch_size"] // cfg["seq_len"] * 16
    cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
    cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]
    cfg["act_size"] = 512
    cfg["dict_size"] = cfg["act_size"] * cfg["dict_mult"]
    cfg["name"] = f"{cfg['model_name']}_{cfg['layer']}_{cfg['dict_size']}_{cfg['site']}"
post_init_cfg(cfg)
from pathlib import Path
SAVE_DIR = Path(".")

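@torch.no_grad()
def get_acts(boards, batch_size=50):
    # Cache only the activation site the autoencoder is trained on
    _, cache = model.run_with_cache(boards, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
    # Index into the cache to get the tensor, then flatten [batch, pos, d] -> [batch*pos, d]
    acts = cache[cfg["act_name"]]
    acts = acts.reshape(-1, acts.shape[-1])
    # Return a random subsample of rows alongside the full set
    subsample = torch.randperm(acts.shape[0])[:batch_size]
    subsampled_acts = acts[subsample, :]
    return subsampled_acts, acts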

In [None]:
def replacement_hook(mlp_post, hook, encoder):
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post
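To see what the two ablation hooks do to an activation tensor, here is a toy sketch that calls them directly on a small made-up tensor. The `hook` argument is unused by these functions, so `None` is passed; in the notebook they would presumably be attached with something like `model.run_with_hooks(..., fwd_hooks=[(act_name, hook_fn)])`.

```python
import torch


def mean_ablate_hook(mlp_post, hook):
    # Same as above: overwrite every position with the per-feature mean
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post


def zero_ablate_hook(mlp_post, hook):
    # Same as above: overwrite every activation with zero
    mlp_post[:] = 0.
    return mlp_post


# Toy activations with shape [batch=2, position=3, d=4]
acts = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
expected_mean = acts.mean([0, 1])  # one mean per feature, shape [4]

# Clone first, because both hooks modify their input in place
mean_ablated = mean_ablate_hook(acts.clone(), hook=None)
zero_ablated = zero_ablate_hook(acts.clone(), hook=None)

print(mean_ablated[0, 0], mean_ablated[1, 2])
print(zero_ablated.abs().sum())
```

Mean ablation keeps the average contribution of the layer while removing all input-specific information, which is usually a gentler baseline than setting the activations to zero.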

In [None]:
# Frequency
#SPLIT = 0.8
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).cuda()
    total = 0
    for i in tqdm.trange(num_batches):
        seqs = board_seqs_int[torch.randperm(len(board_seqs_int))[:cfg["model_batch_size"]]]

        # Read the same activation site the autoencoder was trained on (as in Buffer.refresh),
        # rather than the layer-0 "post" activations, which have width 2048, not 512
        _, cache = model.run_with_cache(seqs[:, :-1], stop_at_layer=cfg["layer"]+1)
        mlp_acts = cache[cfg["act_name"]]
        mlp_acts = mlp_acts.reshape(-1, cfg["act_size"])

        hidden = local_encoder(mlp_acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total+=hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores
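The frequency statistic above is the fraction of inputs on which each dictionary feature is active. A toy sketch of the same calculation on a small hand-written matrix (the values are invented for illustration):

```python
import torch

# Toy hidden activations: 6 inputs x 4 dictionary features (post-ReLU, so >= 0)
hidden = torch.tensor([
    [0.0, 1.2, 0.0, 0.0],
    [0.3, 0.9, 0.0, 0.0],
    [0.0, 0.4, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.7],
    [0.1, 2.0, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0],
])

# Fraction of inputs on which each feature is active
act_freq = (hidden > 0).float().mean(0)
# Fraction of features that never fire on this sample ("dead")
frac_dead = (act_freq == 0).float().mean()
print(act_freq, frac_dead)
```

Note that the "Num dead" value printed by `get_freqs` is computed the same way as `frac_dead` here, so it is a fraction of features rather than a count.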

In [None]:
import json    # used by AutoEncoder.save / AutoEncoder.load
import pprint  # used by AutoEncoder.load

class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = torch.float32
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")

        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self

# %%


# %%
class Buffer():
    """
    This defines a data buffer, to store a bunch of MLP acts that can be used to train the autoencoder. It'll automatically run the model to generate more when it gets halfway empty.
    """
    def __init__(self, cfg):
        self.buffer = torch.zeros((cfg["buffer_size"], cfg["act_size"]), dtype=torch.float16, requires_grad=False).to(cfg["device"])
        self.cfg = cfg
        self.token_pointer = 0
        self.first = True
        self.refresh()

    @torch.no_grad()
    def refresh(self):
        self.pointer = 0
        with torch.autocast("cuda", torch.float16):
            if self.first:
                num_batches = self.cfg["buffer_batches"]
            else:
                num_batches = self.cfg["buffer_batches"]//2
            self.first = False
            for _ in range(0, num_batches, self.cfg["model_batch_size"]):
                seqs = board_seqs_int[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]]

                _, cache = model.run_with_cache(seqs[:,:-1], stop_at_layer=cfg["layer"]+1) #, names_filter=cfg["act_name"])
                cache = {cfg["act_name"]: cache[cfg["act_name"]]}
                acts = cache[(cfg["act_name"])].reshape(-1, self.cfg["act_size"])



                # print(tokens.shape, acts.shape, self.pointer, self.token_pointer)
                #print(acts.shape, self.buffer.shape)
                self.buffer[self.pointer: self.pointer+acts.shape[0]] = acts
                self.pointer += acts.shape[0]
                self.token_pointer += self.cfg["model_batch_size"]
                # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]:
                #     self.token_pointer = 0

        self.pointer = 0
        self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(cfg["device"])]

    @torch.no_grad()
    def next(self):
        out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]]
        self.pointer += self.cfg["batch_size"]
        if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]:
            # print("Refreshing the buffer!")
            self.refresh()
        return out
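The refill logic in `Buffer` can be easier to follow without the model in the loop. Below is a toy sketch with the same control flow, storing integers instead of activations. `ToyBuffer` and its attribute names are hypothetical, written only for this illustration.

```python
import random


class ToyBuffer:
    """Hypothetical stand-in for Buffer: stores integers instead of activations."""

    def __init__(self, size, batch_size):
        self.size, self.batch_size = size, batch_size
        self.next_item = 0   # plays the role of token_pointer
        self.refreshes = 0
        self.data = [None] * size
        self.refresh(first=True)

    def refresh(self, first=False):
        # Fill the whole buffer the first time, and only half of it afterwards
        n_new = self.size if first else self.size // 2
        self.data[:n_new] = range(self.next_item, self.next_item + n_new)
        self.next_item += n_new
        random.shuffle(self.data)  # mix fresh items with the ones that were kept
        self.pointer = 0
        self.refreshes += 1

    def next(self):
        out = self.data[self.pointer:self.pointer + self.batch_size]
        self.pointer += self.batch_size
        # Once roughly half the buffer has been served, generate new items
        if self.pointer > self.size // 2 - self.batch_size:
            self.refresh()
        return out


buf = ToyBuffer(size=8, batch_size=2)
first_batch = buf.next()
second_batch = buf.next()  # this call crosses the halfway mark and triggers a refresh
print(first_batch, second_batch, buf.refreshes)
```

Shuffling after each refill means a batch mixes activations from many different games and positions, which is the point of buffering rather than training on one game at a time.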

In [None]:
encoder = AutoEncoder(cfg)
buffer = Buffer(cfg)
encoder.load_state_dict(torch.load("drive/MyDrive/6S898/81.pt"))

<All keys matched successfully>
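For reference, here is a self-contained sketch of the two calculations at the heart of the `AutoEncoder` class: the loss (reconstruction error plus an L1 penalty on the hidden activations) and the projection that removes the gradient component parallel to each decoder row. The sizes and random tensors are toy values chosen for illustration, not the notebook's real dimensions.

```python
import torch

torch.manual_seed(0)
d_act, d_hidden, l1_coeff = 4, 8, 3e-4

W_enc = torch.randn(d_act, d_hidden)
b_enc = torch.zeros(d_hidden)
b_dec = torch.zeros(d_act)
W_dec = torch.randn(d_hidden, d_act)
W_dec = W_dec / W_dec.norm(dim=-1, keepdim=True)  # unit-norm decoder rows

x = torch.randn(5, d_act)  # toy batch of activations

# Forward pass, mirroring AutoEncoder.forward
acts = torch.relu((x - b_dec) @ W_enc + b_enc)
x_reconstruct = acts @ W_dec + b_dec
l2_loss = (x_reconstruct - x).pow(2).sum(-1).mean(0)
l1_loss = l1_coeff * acts.abs().sum()
loss = l2_loss + l1_loss

# Gradient projection, mirroring make_decoder_weights_and_grad_unit_norm:
# drop the part of a (made-up) gradient that points along each decoder row,
# so that a small step leaves the row norms unchanged to first order
fake_grad = torch.randn_like(W_dec)
parallel = (fake_grad * W_dec).sum(-1, keepdim=True) * W_dec
projected = fake_grad - parallel

print(loss.item(), (projected * W_dec).sum(-1).abs().max().item())
```

Keeping decoder rows at unit norm stops the autoencoder from shrinking the L1 penalty simply by scaling activations down and decoder weights up.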

In [None]:
# %%

print(encoder)
# %%
try:
    num_batches = cfg["num_tokens"] // cfg["batch_size"]
    # model_num_batches = cfg["model_batch_size"] * num_batches
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
    recons_scores = []
    act_freq_scores_list = []
    for i in tqdm.trange(num_batches):
        acts = buffer.next()
        loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
        loss.backward()
        encoder.make_decoder_weights_and_grad_unit_norm()
        encoder_optim.step()
        encoder_optim.zero_grad()
        loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()}
        del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
        if (i) % 100 == 0:
            print(loss_dict)
        if (i) % 1000 == 0:

            x = (get_recons_loss(local_encoder=encoder))
            print("Reconstruction:", x)
            recons_scores.append(x[0])
            freqs = get_freqs(5, local_encoder=encoder)
            act_freq_scores_list.append(freqs)
            # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
            print({
                "recons_score": x[0],
                "dead": (freqs==0).float().mean().item(),
                "below_1e-6": (freqs<1e-6).float().mean().item(),
                "below_1e-5": (freqs<1e-5).float().mean().item(),
            })
        if (i+1) % 30000 == 0:
            encoder.save()
            freqs = get_freqs(50, local_encoder=encoder)
            to_be_reset = (freqs<10**(-5.5))
            print("Resetting neurons!", to_be_reset.sum())
finally:
    encoder.save()


In [None]:
buffer = Buffer(cfg)



In [None]:
board_seqs_int.device

device(type='cpu')

In [None]:
focus_tasks = board_seqs_int[-20000:].clone()
ids = list(range(len(board_seqs_int)))[-20000:]

In [None]:
focus_tasks.shape

torch.Size([20000, 60])

In [None]:
focus_tasks[99*100: 99*100 + 100].unique()

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
        55, 56, 57, 58, 59, 60])

## Interpreting using the SAE

In [None]:
acts = torch.zeros((100, 512)).cuda()
encoded_acts = torch.zeros(1, 4096).cuda()
model_c = model
for lm in range(0, len(focus_tasks), 100):
  if lm > 20 * 100:
    break
  seqs = focus_tasks[lm: lm+100]
  _, cache = model_c.run_with_cache(seqs[:,:-1], stop_at_layer=cfg["layer"]+1) #, names_filter=cfg["act_name"])
  cache = {cfg["act_name"]: cache[cfg["act_name"]]}
  acts = cache[(cfg["act_name"])].reshape(-1, cfg["act_size"])
  loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
  encoded_acts = torch.cat((encoded_acts, mid_acts))
encoded_acts = encoded_acts[1:]
encoded_acts.shape

torch.Size([5900, 512])

torch.Size([182900, 4096])

In [None]:
#encoded_acts.shape

torch.Size([106201, 4096])

In [None]:
z = len(board_seqs_int) - 20000 + i//59
plot_single_board(int_to_label(board_seqs_int[z]))

In [None]:
captures_a_lot = 2315
local_regions = [774, 1892, 457, 831]

In [None]:
c_id = local_regions[0]
indices = torch.randperm(encoded_acts.shape[0])[:20]
# top_acts = encoded_acts[top_acts_idx]
for i in indices:
  print(encoded_acts[i][c_id])
  if i % 59 < 3: continue
  board = board_seqs_int[len(board_seqs_int) - 20000 + i//59][:i%59]
  print(board)
  plot_single_board(int_to_label(board))
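The `i//59` and `i%59` arithmetic above converts a row of `encoded_acts` back into a game and a position within that game. A minimal sketch of that bookkeeping follows; `locate` and the two constants are hypothetical names, and the total of 100,000 games is taken from the dataset description earlier.

```python
POSITIONS_PER_GAME = 59  # each game contributes 59 input positions
NUM_HELD_OUT = 20000     # size of the focus_tasks slice taken from the end


def locate(flat_index: int, total_games: int):
    """Hypothetical helper: map a row of encoded_acts to (game row, position)."""
    game_in_slice, position = divmod(flat_index, POSITIONS_PER_GAME)
    game_row = total_games - NUM_HELD_OUT + game_in_slice
    return game_row, position


print(locate(0, 100_000))
print(locate(59 * 3 + 7, 100_000))
```

One thing to keep in mind when reading the plots: slicing a game with `[:position]` gives the moves before `position`, whereas the activation at that position has also seen the move at `position` itself.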

In [None]:
location = "G2"

In [None]:
c_id = local_regions[0]
top_acts_idx = torch.topk(encoded_acts[:, c_id], k=30).indices
# top_acts = encoded_acts[top_acts_idx]
for i in top_acts_idx:
  print(encoded_acts[i][c_id], board_seqs_int[len(board_seqs_int) - 20000 + i//59][i%59])
  board = board_seqs_int[len(board_seqs_int) - 20000 + i//59][:i%59]
  plot_single_board(int_to_label(board))

tensor(0.1425, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1292, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1292, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1269, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1263, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1250, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1220, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1217, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1215, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1146, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1118, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1103, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1102, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1097, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1094, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1077, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1073, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1070, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1065, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1060, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1058, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1053, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1050, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1048, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1047, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1047, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1034, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1026, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1016, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1005, device='cuda:0', grad_fn=<SelectBackward0>)


In [None]:
for id in local_regions:
  top_acts_idx = torch.topk(encoded_acts[:, id], k=10).indices
 # top_acts = encoded_acts[top_acts_idx]
  print(id)
  for i in top_acts_idx:
    print(encoded_acts[i][id])
    board = board_seqs_int[len(board_seqs_int) - 20000 + i//59][:i%59]
    plot_single_board(int_to_label(board))

774
tensor(0.1425, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1292, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1292, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1269, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1263, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1250, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1220, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1217, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1215, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1146, device='cuda:0', grad_fn=<SelectBackward0>)


1892
tensor(0.2806, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2746, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2586, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2404, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2396, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2374, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2343, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2280, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2234, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2170, device='cuda:0', grad_fn=<SelectBackward0>)


457
tensor(1.2946, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(1.1436, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.9577, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.9504, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.9293, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.9027, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.8817, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.8746, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.8633, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.8501, device='cuda:0', grad_fn=<SelectBackward0>)


831
tensor(0.4673, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2785, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2692, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2657, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.2194, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1789, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1784, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1686, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1448, device='cuda:0', grad_fn=<SelectBackward0>)


tensor(0.1367, device='cuda:0', grad_fn=<SelectBackward0>)


In [None]:
# prompt: pick a random autoencoder dimension and find / display top activating boards
import random
for _ in range(5):
  id = random.randint(0, encoded_acts.shape[1] - 1)  # include feature 0; randint is inclusive at both ends
  print(id)
  top_acts_idx = torch.topk(encoded_acts[:, id], k=5).indices
  print(top_acts_idx)
  top_acts = encoded_acts[top_acts_idx]

  for i in top_acts_idx:
    board = board_seqs_int[len(board_seqs_int) - 20000 + i//59][:i%59]
    plot_single_board(int_to_label(board))

2791
tensor([58663, 66938, 84471, 15547, 60988], device='cuda:0')


3764
tensor([51860, 40452,  1290, 52799, 72563], device='cuda:0')


831
tensor([ 86716, 105891,  18995,  23804,   8783], device='cuda:0')


457
tensor([44294, 11439, 30612, 96146, 30782], device='cuda:0')


2344
tensor([ 7122, 87350, 28016, 10905, 32296], device='cuda:0')


A cache of every model activation and the logits