In [1]:
print("Running as a Jupyter notebook - intended for development only!")
from IPython import get_ipython

ipython = get_ipython()
# Automatically reload edited modules (e.g. the HookedTransformer code) without restarting the kernel.
# `InteractiveShell.magic(...)` is deprecated in IPython; `run_line_magic(name, args)` is the
# supported way to invoke a line magic from Python code.
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import plotly.io as pio
# Pick the plotly renderer used by default for figures shown in this notebook.
pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Running as a Jupyter notebook - intended for development only!
Using renderer: notebook_connected


In [2]:
import circuitsvis as cv
# Smoke test that circuitsvis is importable and can render: call its bundled hello example.
# Being the last expression of the cell, whatever it returns is what the notebook displays.
cv.examples.hello("Neel")

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# Switch autograd off globally for the rest of the session; the cells visible in this
# notebook only run forward passes, so no gradients are needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fed782a5a30>

In [4]:
from neel_plotly import line, scatter, imshow, histogram

### Load 8L 8H Model and Convert to Transformer Lens

In [6]:
def convert_to_transformer_lens_format(in_sd, n_layers=8, n_heads=8):
    """Re-key a checkpoint state dict into TransformerLens parameter names.

    Args:
        in_sd: Mapping from the checkpoint's parameter names (``pos_emb``,
            ``tok_emb.weight``, ``blocks.{i}.attn.query.weight``, ...) to arrays/tensors.
        n_layers: Number of ``blocks.{i}`` entries to convert (``i`` in ``0..n_layers-1``).
        n_heads: Number of attention heads; the fused ``(head d_head)`` axis of every
            attention projection is split into this many heads.

    Returns:
        A new dict keyed by TransformerLens names (``embed.W_E``, ``blocks.{i}.attn.W_Q``, ...).
        Only the keys read below are carried over; anything else in ``in_sd`` is not copied.

    Raises:
        KeyError: if ``in_sd`` lacks one of the expected checkpoint keys.
    """
    # Embeddings, final layer norm and unembedding. The positional embedding loses its
    # leading axis and the output head is transposed.
    converted = {
        "pos_embed.W_pos": in_sd["pos_emb"].squeeze(0),
        "embed.W_E": in_sd["tok_emb.weight"],
        "ln_final.w": in_sd["ln_f.weight"],
        "ln_final.b": in_sd["ln_f.bias"],
        "unembed.W_U": in_sd["head.weight"].T,
    }

    # (TransformerLens suffix, checkpoint sub-module name) for the three input projections.
    qkv_names = (("Q", "query"), ("K", "key"), ("V", "value"))

    for layer in range(n_layers):
        block = f"blocks.{layer}"

        # Layer norms: a straight rename of weight/bias to w/b.
        for norm in ("ln1", "ln2"):
            converted[f"{block}.{norm}.w"] = in_sd[f"{block}.{norm}.weight"]
            converted[f"{block}.{norm}.b"] = in_sd[f"{block}.{norm}.bias"]

        # Q/K/V: split the fused (head d_head) axis and move d_model ahead of d_head.
        for letter, source in qkv_names:
            converted[f"{block}.attn.W_{letter}"] = einops.rearrange(
                in_sd[f"{block}.attn.{source}.weight"],
                "(head d_head) d_model -> head d_model d_head",
                head=n_heads,
            )
            converted[f"{block}.attn.b_{letter}"] = einops.rearrange(
                in_sd[f"{block}.attn.{source}.bias"],
                "(head d_head) -> head d_head",
                head=n_heads,
            )

        # Output projection: here the fused head axis is the second one.
        converted[f"{block}.attn.W_O"] = einops.rearrange(
            in_sd[f"{block}.attn.proj.weight"],
            "d_model (head d_head) -> head d_head d_model",
            head=n_heads,
        )
        converted[f"{block}.attn.b_O"] = in_sd[f"{block}.attn.proj.bias"]

        # MLP: sub-modules 0 and 2 of the checkpoint's `mlp` become the in/out layers;
        # weights are transposed, biases are copied as-is.
        converted[f"{block}.mlp.b_in"] = in_sd[f"{block}.mlp.0.bias"]
        converted[f"{block}.mlp.W_in"] = in_sd[f"{block}.mlp.0.weight"].T
        converted[f"{block}.mlp.b_out"] = in_sd[f"{block}.mlp.2.bias"]
        converted[f"{block}.mlp.W_out"] = in_sd[f"{block}.mlp.2.weight"].T

    return converted

## Make sure to change path

In [22]:
heads, layers = 8,8
# Change this path
path = "../EWOthello/ckpts/DeanKLi_GPT_Synthetic_8L8H/GPT_Synthetic_8Layers_8Heads.ckpt"
# NOTE(review): torch.load unpickles the file, and unpickling can execute arbitrary code.
# Only load checkpoints from a source you trust (or pass weights_only=True where the
# installed PyTorch version supports it).
synthetic_checkpoint = torch.load(path, map_location='cpu')
# Show every non-block parameter plus block 0 as a representative layer.
for name, param in synthetic_checkpoint.items():
    if name.startswith("blocks.0") or not name.startswith("blocks"):
        print(name, param.shape)

cfg = HookedTransformerConfig(
    n_layers = layers,
    d_model = 512,
    # convert_to_transformer_lens_format splits each 512-wide attention projection evenly
    # across `heads`, so d_head must follow `heads` instead of being pinned to 64
    # (still 64 for the default 8 heads).
    d_head = 512 // heads,
    n_heads = heads,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)


# Re-key the checkpoint into TransformerLens names and load it into the freshly built model.
model.load_and_process_state_dict(convert_to_transformer_lens_format(synthetic_checkpoint, n_layers=layers, n_heads=heads))


pos_emb torch.Size([1, 59, 512])
tok_emb.weight torch.Size([61, 512])
blocks.0.ln1.weight torch.Size([512])
blocks.0.ln1.bias torch.Size([512])
blocks.0.ln2.weight torch.Size([512])
blocks.0.ln2.bias torch.Size([512])
blocks.0.attn.mask torch.Size([1, 1, 59, 59])
blocks.0.attn.key.weight torch.Size([512, 512])
blocks.0.attn.key.bias torch.Size([512])
blocks.0.attn.query.weight torch.Size([512, 512])
blocks.0.attn.query.bias torch.Size([512])
blocks.0.attn.value.weight torch.Size([512, 512])
blocks.0.attn.value.bias torch.Size([512])
blocks.0.attn.proj.weight torch.Size([512, 512])
blocks.0.attn.proj.bias torch.Size([512])
blocks.0.mlp.0.weight torch.Size([2048, 512])
blocks.0.mlp.0.bias torch.Size([2048])
blocks.0.mlp.2.weight torch.Size([512, 2048])
blocks.0.mlp.2.bias torch.Size([512])
ln_f.weight torch.Size([512])
ln_f.bias torch.Size([512])
head.weight torch.Size([61, 512])


### Load Othello Content
Boring setup code to load in 100K sample Othello games, the linear probe, and some utility functions

In [23]:
from EWOthello.othello_world.mechanistic_interpretability.mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState
# Directory holding the saved game sequences, relative to the notebook's working directory.
OTHELLO_ROOT = Path("../EWOthello/othello_world/")

# Load the saved sequences as int64 tensors. Presumably the two files hold the same games in
# the "int" and "string" move encodings used by the imported helpers — TODO confirm.
board_seqs_int = torch.tensor(np.load(OTHELLO_ROOT/"board_seqs_int_small.npy"), dtype=torch.long)
board_seqs_string = torch.tensor(np.load(OTHELLO_ROOT/"board_seqs_string_small.npy"), dtype=torch.long)

# The array is 2-D: first axis is the game, second axis is the move within that game.
num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games,)
print("Length of game:", length_of_game)

Number of games: 100000
Length of game: 60


In [24]:
# Board-square indices (0..63) in ascending order, leaving out 27, 28, 35 and 36.
# That leaves 60 entries.
stoi_indices = [square for square in range(64) if square not in (27, 28, 35, 36)]
# Row letters of the board, top to bottom.
alpha = "ABCDEFGH"


def to_board_label(i):
    """Return the label of square index `i`: row letter from `alpha`, then column digit."""
    row_letter = alpha[i // 8]
    column = i % 8
    return f"{row_letter}{column}"


# Label for every entry of `stoi_indices`, in the same order.
board_labels = [to_board_label(square) for square in stoi_indices]

### Running the Model
Keep in mind that the vocabulary has 61 tokens: index 0 corresponds to passing (not making a move), and the remaining 60 indices are board moves.

In [25]:
# First 30 moves of game 0, taken from the integer-encoded sequences.
moves_int = board_seqs_int[0, :30]

# This is implicitly converted to a batch of size 1
logits = model(moves_int)
# The printed shape shows one logit vector per input position.
print("logits:", logits.shape)

logits: torch.Size([1, 30, 61])


In [26]:
# Logits at the last position of the (single) sequence in the batch.
logit_vec = logits[0, -1]
# Normalise to log-probabilities, then drop index 0 (passing) so only the board moves remain.
log_probs = logit_vec.log_softmax(-1)[1:]
assert len(log_probs)==60

# Start every one of the 64 squares at -13, a very negative log prob, so the squares that
# have no entry in `stoi_indices` (the middle cells) don't show up as mattering.
temp_board_state = torch.full((64,), -13., device=logit_vec.device)
# Scatter the 60 move log-probs onto their board squares.
temp_board_state[stoi_indices] = log_probs

In [27]:
def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Show an 8x8 array as a board heatmap with A-H rows and 0-7 columns.

    Args:
        state: The 8 by 8 values to plot (a stack of boards can be plotted by passing
            ``facet_col=0`` through ``kwargs``).
        diverging_scale: If truthy, use the "RdBu" colour scale centred on 0; otherwise
            use "Blues" with no midpoint.
        **kwargs: Forwarded unchanged to ``imshow``.
    """
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None

    imshow(
        state,
        y=list(alpha),
        x=list(map(str, range(8))),
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
        **kwargs,
    )
plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs")

### Exploring Gameplay

In [28]:
# Replay the example moves on a fresh board and plot the resulting state.
board = OthelloBoardState()
# moves_int is integer-encoded; it is converted with to_string before being handed to
# update() — presumably update() expects the string encoding, confirm against the utils module.
board.update(to_string(moves_int))
plot_square_as_board(board.state, title="Example Board State (+1 is Black, -1 is White)")

In [29]:
# Turn the integer-encoded moves into square labels and pass them to the plotting helper
# (presumably it draws the board reached after these moves — confirm against the utils module).
plot_single_board(int_to_label(moves_int))

In [30]:
# Legal moves from the board built above, converted to square labels for display.
print("Valid moves:", string_to_label(board.get_valid_moves()))

Valid moves: ['A3', 'A5', 'A6', 'B2', 'C7', 'D2', 'E6', 'F7', 'G6', 'H2', 'H3', 'H4', 'H6']


In [31]:
# NOTE: this rebinds `num_games`, which previously held the size of the full dataset;
# from here on it means the number of focus games.
num_games = 50
# The first `num_games` games, in both encodings.
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

In [32]:
def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` with 1.0 at the given indices.

    Despite the name, this is a multi-hot encoding: every index in `list_of_ints`
    is set to 1.0 and all other entries stay 0.0.
    """
    encoding = torch.zeros(num_classes, dtype=torch.float32)
    encoding[list_of_ints] = 1.0
    return encoding
# Per-game, per-move records for the focus games: the 8x8 board state after each move and
# a 64-wide mask of the moves that are valid from that state.
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = torch.zeros((num_games, 60, 64), dtype=torch.float32)
for i in (range(num_games)):
    # Start each game from a fresh board. NOTE: this rebinds the notebook-global `board`.
    board = OthelloBoardState()
    for j in range(60):
        # Play move j of game i (taken from the string-encoded sequences), then record the
        # board and the valid moves that follow it.
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())
print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)


focus states: (50, 60, 8, 8)
focus_valid_moves torch.Size([50, 60, 64])


In [33]:
# Sanity check: the shape of a single focus game (one entry per move).
focus_games_int[0, :].shape

torch.Size([60])

In [34]:
# Index of the last move to include; the slice below therefore keeps moves + 1 tokens.
moves = 25
sample_game = focus_games_int[0, :moves+1]
# Run the model on that prefix, keeping both the logits and the cache returned alongside them.
focus_logits, focus_cache = model.run_with_cache(sample_game)

In [37]:
print(type(focus_cache))
# Attention pattern of layer 0. Going by the printed shape, the axes are presumably
# (batch, head, query position, key position) — TODO confirm the order of the last two.
attention_pattern = focus_cache["pattern", 0, "attn"]
print(attention_pattern.shape)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([1, 8, 26, 26])


In [38]:
# Visualise the layer-0 attention pattern of the sample game (batch entry 0).
# The tokens come from the integer-encoded sequences, so they are labelled with int_to_label,
# the same helper used for integer moves elsewhere in this notebook. string_to_label is for
# board-square indices and would put the wrong square names on the axes.
cv.attention.attention_patterns(tokens=int_to_label(focus_games_int[0, :moves+1]), attention=attention_pattern[0])