## Setup

### Setup 1

In [1]:
import os, sys
# Locate the chapter directory and make its `exercises` folder importable.
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"  # not referenced by any code visible in this notebook
# Use the current directory when the chapter folder sits directly inside it; otherwise
# take the part of the working-directory path that precedes the chapter name.
chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
# os.path.join supplies the separator when `chapter_dir` has no trailing one. That happens
# when the chapter name does not occur in the working directory, so split() returns the
# whole path; plain string concatenation then fused the two components into one name.
sys.path.append(os.path.join(chapter_dir, f"{chapter}/exercises"))

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
# import part6_othellogpt.tests as tests

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")



### Setup 2

In [3]:
# True when this code runs as the top-level script/notebook rather than being imported.
MAIN = __name__ == "__main__"

# Architecture of the Othello-GPT model whose weights are loaded below.
# d_vocab = 61 matches the int move encoding (values 0..60) checked later in this cell;
# n_ctx = 59 matches the 59-token inputs used here (sample_input, and games with the last move dropped).
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

# Fetch the "synthetic" checkpoint from the Hugging Face Hub and load it into the model.
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (ie the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

# Sanity check that the weights loaded correctly: the model's top prediction at every
# position must reproduce the reference output above.
assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
# Treat the current working directory as the section directory and make it importable.
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

# Expected location of the othello_world checkout (see the commented-out clone command
# below) and of its mechanistic_interpretability subfolder, which holds the helper
# module and data files loaded later in this cell.
OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

# Needed so that `mech_interp_othello_utils` can be imported from that folder.
sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

# Sanity checks on the two encodings: the four centre squares never occur in the
# "string" data, and the largest value in the int data is 60.
assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

# board_seqs_int is 2-D: (number of games, moves per game).
num_games, length_of_game = board_seqs_int.shape

# Playable move indices: every board position 0..63 apart from the four centre squares
stoi_indices = [sq for sq in range(64) if sq not in (27, 28, 35, 36)]

# Row letters, plus a helper that turns a flat index into a row-letter/column-digit label such as `E2`
alpha = "ABCDEFGH"

def to_board_label(i):
    """Return the label for flat board index `i`: row letter followed by column digit."""
    row, col = divmod(i, 8)
    return alpha[row] + str(col)

# Labels for the playable squares only, and for all 64 squares
board_labels = [to_board_label(sq) for sq in stoi_indices]
full_board_labels = [to_board_label(sq) for sq in range(64)]

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Plot an 8 by 8 input as a board with lettered rows and numbered columns.

    A stack of boards can be plotted by passing e.g. facet_col=0. The colour scale is
    "RdBu" centred on 0 when `diverging_scale` is truthy and "Blues" with no midpoint
    otherwise. Any keyword argument supplied by the caller overrides these defaults
    and is forwarded to `imshow`.
    """
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None
    board_defaults = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    # Caller-supplied options take precedence over the board defaults.
    imshow(state, **{**board_defaults, **kwargs})

# Select a contiguous slice of games to analyse in detail.
# NOTE: this rebinds `num_games`, which was previously set from board_seqs_int.shape.
start = 30000
num_games = 50
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

# Run the model on every move except the last one of each game, keeping the activation cache.
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
# Bare expression: shown as cell output only when it is the last statement of a notebook cell.
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` with 1.0 at each index in `list_of_ints`."""
    indicator = t.zeros(num_classes, dtype=t.float32)
    indicator[list_of_ints] = 1.0
    return indicator

# Buffers filled by the loop below: one 8x8 board and one 64-entry move mask
# per (game, move) pair.
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

# Replay every focus game from a fresh board. After feeding each move to the board,
# record `board.state` and the one-hot encoding of `board.get_valid_moves()`.
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        # Presumably applies move j and updates board.state in place — confirm in OthelloBoardState.
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# Load the saved linear probe from the mechanistic_interpretability folder onto `device`.
full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

rows = 8
cols = 8
options = 3
# Expected axes: (3 probes, d_model, rows, cols, options). The first axis is indexed
# below with black_to_play_index / white_to_play_index.
assert full_linear_probe.shape == (3, cfg.d_model, rows, cols, options)

# Indices into the first axis of `full_linear_probe`.
black_to_play_index = 0
white_to_play_index = 1
# Indices into the last (options) axis of the combined `linear_probe` built below.
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")
# Blank averages option 0 of both probes; "their" and "my" average options 1 and 2 with
# the two options swapped between the black-to-play and white-to-play probes.
linear_probe = t.zeros(cfg.d_model, rows, cols, options, device=device)
linear_probe[..., blank_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 0] + full_linear_probe[white_to_play_index, ..., 0])
linear_probe[..., their_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 1] + full_linear_probe[white_to_play_index, ..., 2])
linear_probe[..., my_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 2] + full_linear_probe[white_to_play_index, ..., 1])

# Default layer / game / move; later cells rebind `layer` and `move` before plotting.
layer = 6
game_index = 0
move = 29


interpretability
focus states: (50, 60, 8, 8)
focus_valid_moves (50, 60, 64)


## Code

In [4]:
# Rebind `board_seqs_int` to the data stored in data/board_seqs_int_valid.pth.
board_seqs_int = t.load("data/board_seqs_int_valid.pth")
# Bare expression so the notebook echoes the shape as the cell output.
board_seqs_int.shape

torch.Size([500000, 60])

In [5]:
# from utils import plot_game
def plot_game(focus_states, game_index=0, end_move=16):
    """Plot the opening board states of one game as a grid of greyscale heatmaps.

    Args:
        focus_states: Indexed as `focus_states[game_index, :end_move]`, yielding one
            board per move (the notebook passes an array of shape (games, moves, 8, 8)).
        game_index: Which game to plot.
        end_move: Plot the moves before this index.
    """
    boards = focus_states[game_index, :end_move]
    # Build the labels and title from the boards actually selected, so they stay in
    # step with the data when `end_move` exceeds the number of moves or is negative.
    # The title was previously a fixed string whatever the arguments were.
    num_boards = len(boards)
    game_desc = "first game" if game_index == 0 else f"game {game_index}"
    imshow(
        boards,
        facet_col=0,
        facet_col_wrap=8,
        facet_labels=[f"Move {i}" for i in range(num_boards)],
        title=f"First {num_boards} moves of {game_desc}",
        color_continuous_scale="Greys",
        y=list(alpha),
    )

In [None]:
# Disabled scratch cell: the code below sits inside a string literal, so it is never executed.
'''move = 7
layer = 5
probe_name = "mine_theirs_first_tile"

mine_theirs_first_tile = t.load(
    os.path.join(
        f"models2/{probe_name}_L{layer}.pth",
    )
)

blank_or_not = mine_theirs_first_tile[]'''

In [29]:
def plot_probe_outputs_intervention_template(model_name, facet_labels, even_odd_all_index, option):
    """Create a plotting function for one output class of a saved probe.

    Args:
        model_name: Probe file stem; the probe is read from
            `models2/{model_name}_L{layer}.pth` each time the returned function runs.
        facet_labels: Titles for the two panels: the selected class and its complement.
        even_odd_all_index: Index into the first axis of the loaded probe tensor
            (presumably even-move / odd-move / all-move probes, going by the name —
            confirm against the probe-training code).
        option: Index on the probe's last axis of the class whose probability is plotted.

    Returns:
        A function `(layer, game_index, move, **kwargs)` that applies the probe to the
        cached `resid_post` activation at that layer, game and move, and plots
        P(option) next to 1 - P(option) for every square. Extra keyword arguments are
        forwarded to `plot_square_as_board`.
    """
    def plot_probe_outputs_intervention(layer, game_index, move, **kwargs):
        residual_stream = focus_cache["resid_post", layer][game_index, move]
        # Load the probe onto the same device as the cached activation. Without
        # map_location a probe saved on a different device would make the einsum below
        # fail with a device mismatch.
        probe = t.load(
            f"models2/{model_name}_L{layer}.pth",
            map_location=residual_stream.device,
        )
        # Shape printouts kept from the original cell.
        print(residual_stream.shape)
        print(probe.shape)
        probe = probe[even_odd_all_index]
        options = probe.shape[-1]
        probe_out = einops.einsum(residual_stream, probe, "d_model, d_model row col options -> row col options")
        # Softmax over all classes first, then reduce to "option" versus "everything else".
        probabilities = probe_out.softmax(dim=-1)
        probabilities = probabilities[:, :, option].unsqueeze(dim=-1)
        probabilities = t.cat([probabilities, 1-probabilities], dim=-1)
        print(probabilities.shape)
        print(options)
        print(facet_labels)
        plot_square_as_board(probabilities, facet_col=2, facet_labels=facet_labels, **kwargs)
    return plot_probe_outputs_intervention

# Disabled (string literal, never executed). Note that it omits the `option` argument
# that plot_probe_outputs_intervention_template requires.
'''plot_probe_outputs_intervention_flipped = plot_probe_outputs_intervention_template(
    model_name="flipped",
    facet_labels=["P(Flipped)", "P(Not Flipped)"],
    even_odd_all_index=2
)'''

# Disabled earlier three-panel variant (string literal, never executed); it also omits `option`.
'''plot_probe_outputs_intervention_mine_theirs_first_tile_odd_blank_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Blank)", "P(First Tile Their)", "P(First Tile Mine)"],
    even_odd_all_index=1
)'''
# Six plotters for the "mine_theirs_first_tile" probe: two probe indices (the "odd"
# names use even_odd_all_index=1, the "even" names use 0) times three classes
# (option 0 = blank, 1 = their, 2 = mine, per the facet labels).
plot_probe_outputs_intervention_mine_theirs_first_tile_odd_blank_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Blank)", "P(not Blank)"],
    even_odd_all_index=1,
    option=0
)

plot_probe_outputs_intervention_mine_theirs_first_tile_odd_mine_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Mine)", "P(not Mine)"],
    even_odd_all_index=1,
    option=2
)

plot_probe_outputs_intervention_mine_theirs_first_tile_odd_their_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Their)", "P(not Their)"],
    even_odd_all_index=1,
    option=1
)

plot_probe_outputs_intervention_mine_theirs_first_tile_even_blank_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Blank)", "P(not Blank)"],
    even_odd_all_index=0,
    option=0
)

plot_probe_outputs_intervention_mine_theirs_first_tile_even_mine_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Mine)", "P(not Mine)"],
    even_odd_all_index=0,
    option=2
)

plot_probe_outputs_intervention_mine_theirs_first_tile_even_their_or_not = plot_probe_outputs_intervention_template(
    model_name="mine_theirs_first_tile",
    facet_labels=["P(Their)", "P(not Their)"],
    even_odd_all_index=0,
    option=1
)

In [37]:
from training_utils import get_state_stack_first_tile_places_mine_theirs
# Position to inspect in this cell; `game_index` keeps its earlier value.
move = 8
layer = 5
end_move = 16
# Plot the actual board states for reference.
plot_game(focus_states, game_index)
# Exactly one of the six probe plotters is enabled at a time; the rest stay commented out.
# plot_probe_outputs_intervention_mine_theirs_first_tile_odd_blank_or_not(layer, game_index, move)
# plot_probe_outputs_intervention_mine_theirs_first_tile_odd_mine_or_not(layer, game_index, move)
# plot_probe_outputs_intervention_mine_theirs_first_tile_odd_their_or_not(layer, game_index, move)
# plot_probe_outputs_intervention_mine_theirs_first_tile_even_blank_or_not(layer, game_index, move)
plot_probe_outputs_intervention_mine_theirs_first_tile_even_mine_or_not(layer, game_index, move)
# plot_probe_outputs_intervention_mine_theirs_first_tile_even_their_or_not(layer, game_index, move)
# NOTE: despite the variable name, this holds the output of
# get_state_stack_first_tile_places_mine_theirs, not a "number flipped" stack.
focus_states_num_flipped = get_state_stack_first_tile_places_mine_theirs(focus_games_string)
# Plot that state stack for the first `end_move` moves of the same game.
imshow(
        focus_states_num_flipped[game_index, :end_move],
        facet_col=0,
        facet_col_wrap=8,
        facet_labels=[f"Move {i}" for i in range(0, end_move)],
        title="First 16 moves of first game",
        color_continuous_scale="Greys",
        y = [i for i in alpha],
    )

torch.Size([512])
torch.Size([3, 512, 8, 8, 3])
torch.Size([8, 8, 2])
3
['P(Mine)', 'P(not Mine)']


In [17]:
# Move and layer used by the probe plot at the end of this cell.
move = 7
layer = 5
def plot_probe_outputs_intervention_flipped(layer, game_index, move, **kwargs):
    """Plot the "flipped" probe's per-square probabilities for one cached position.

    Args:
        layer: Layer whose `resid_post` activation is probed; also selects the probe
            file `models/flipped_L{layer}.pth`.
        game_index: Index of the game within the cached focus games.
        move: Position within that game.
        **kwargs: Forwarded to `plot_square_as_board`.
    """
    residual_stream = focus_cache["resid_post", layer][game_index, move]
    # Load the probe onto the same device as the cached activation. Without
    # map_location a probe saved on a different device would make the einsum below fail.
    probe = t.load(
        f"models/flipped_L{layer}.pth",
        map_location=residual_stream.device,
    )
    # Shape printouts kept from the original cell.
    print(residual_stream.shape)
    print(probe.shape)
    # Entry 2 on the probe's first axis (presumably the all-moves probe — confirm
    # against the probe-training code).
    probe = probe[2]
    probe_out = einops.einsum(residual_stream, probe, "d_model, d_model row col options -> row col options")
    probabilities = probe_out.softmax(dim=-1)
    print(probabilities.shape)
    plot_square_as_board(probabilities, facet_col=2, facet_labels=["P(Flipped)", "P(Not Flipped)"], **kwargs)

# Show the actual board states, then the flipped-probe output for the chosen layer and move.
plot_game(focus_states, game_index)
plot_probe_outputs_intervention_flipped(layer, game_index, move)

torch.Size([512])
torch.Size([3, 512, 8, 8, 2])
torch.Size([8, 8, 2])


In [6]:
# Move used by the probe plot at the end of this cell; `layer` keeps its earlier value.
move = 24
def plot_probe_outputs_intervention_even_odd_flipped(layer, game_index, move, **kwargs):
    """Plot the "even_odd_flipped" probe's per-square probabilities for one cached position.

    Args:
        layer: Layer whose `resid_post` activation is probed; also selects the probe
            file `models/even_odd_flipped_L{layer}.pth`.
        game_index: Index of the game within the cached focus games.
        move: Position within that game.
        **kwargs: Forwarded to `plot_square_as_board`.
    """
    residual_stream = focus_cache["resid_post", layer][game_index, move]
    # Load the probe onto the same device as the cached activation. Without
    # map_location a probe saved on a different device would make the einsum below fail.
    probe = t.load(
        f"models/even_odd_flipped_L{layer}.pth",
        map_location=residual_stream.device,
    )
    # Shape printouts kept from the original cell.
    print(residual_stream.shape)
    print(probe.shape)
    # Entry 2 on the probe's first axis (presumably the all-moves probe — confirm
    # against the probe-training code).
    probe = probe[2]
    probe_out = einops.einsum(residual_stream, probe, "d_model, d_model row col options -> row col options")
    probabilities = probe_out.softmax(dim=-1)
    plot_square_as_board(probabilities, facet_col=2, facet_labels=["P(Flipped_Even)", "P(Flipped_Odd)"], **kwargs)

# Show the actual board states, then the even/odd-flipped probe output for the chosen layer and move.
plot_game(focus_states, game_index)
plot_probe_outputs_intervention_even_odd_flipped(layer, game_index, move)

torch.Size([512])
torch.Size([3, 512, 8, 8, 2])


In [7]:
from training_utils import get_state_stack_num_flipped
from training_probes import ProbeTrainingArgs
# Default probe-training arguments; passed to get_state_stack_num_flipped in a later cell.
args = ProbeTrainingArgs()

In [13]:
# Move used by the probe plot at the end of this cell; `layer` keeps its earlier value.
move = 12
def plot_probe_outputs_intervention_num_flipped(layer, game_index, move, **kwargs):
    """Plot the "num_flipped" probe's per-square probabilities for one cached position.

    Only the first five classes on the probe's last axis are used (labelled
    "0 Flipped" to "4 Flipped"). The function also prints the mean pairwise cosine
    similarity between those five class directions at square (3, 3).

    Args:
        layer: Layer whose `resid_post` activation is probed; also selects the probe
            file `models/num_flipped_L{layer}.pth`.
        game_index: Index of the game within the cached focus games.
        move: Position within that game.
        **kwargs: Forwarded to `plot_square_as_board`.
    """
    # Number of leading classes inspected; the slice, the pair count and the labels
    # all derive from this value instead of separate literals.
    num_classes = 5
    residual_stream = focus_cache["resid_post", layer][game_index, move]
    # Load the probe onto the same device as the cached activation. Without
    # map_location a probe saved on a different device would make the einsum below fail.
    probe = t.load(
        f"models/num_flipped_L{layer}.pth",
        map_location=residual_stream.device,
    )
    # Shape printouts kept from the original cell.
    print(residual_stream.shape)
    print(probe.shape)
    # Entry 2 on the probe's first axis (presumably the all-moves probe — confirm
    # against the probe-training code), restricted to the first `num_classes` classes.
    probe = probe[2][:, :, :, :num_classes]
    # Average the cosine similarity over every unordered pair of class directions.
    num_pairs = num_classes * (num_classes - 1) // 2
    avg_cosine_sim = 0
    for i in range(num_classes):
        for j in range(i + 1, num_classes):
            cosine_sim = t.nn.functional.cosine_similarity(probe[:, 3, 3, i], probe[:, 3, 3, j], dim=0)
            # print(f"Cosine similarity between {i} and {j} flipped squares:", cosine_sim.item())
            avg_cosine_sim += cosine_sim
    avg_cosine_sim /= num_pairs
    print("Average cosine similarity between flipped squares:", avg_cosine_sim.item())
    probe_out = einops.einsum(residual_stream, probe, "d_model, d_model row col options -> row col options")
    probabilities = probe_out.softmax(dim=-1)
    facet_labels = [f"P({i} Flipped)" for i in range(num_classes)]
    plot_square_as_board(probabilities, facet_col=2, facet_labels=facet_labels, **kwargs)

# Number of opening moves shown in the state-stack plot below.
end_move = 16
# focus_games_num_flipped = focus_games_string[start : start + num_games]
# print(focus_games_num_flipped.shape)

# Plot the actual board states for reference.
plot_game(focus_states, game_index)
# Build the "number flipped" state stack for the focus games and plot its first
# `end_move` moves for the same game.
focus_states_num_flipped = get_state_stack_num_flipped(focus_games_string, args, only_middle=False)
imshow(
        focus_states_num_flipped[game_index, :end_move],
        facet_col=0,
        facet_col_wrap=8,
        facet_labels=[f"Move {i}" for i in range(0, end_move)],
        title="First 16 moves of first game",
        color_continuous_scale="Greys",
        y = [i for i in alpha],
    )
# Finally, plot the probe's predictions for the chosen layer and move.
plot_probe_outputs_intervention_num_flipped(layer, game_index, move)

torch.Size([512])
torch.Size([3, 512, 8, 8, 18])
Cosine similarity between 0 and 1 flipped squares: 0.41522470116615295
Cosine similarity between 0 and 2 flipped squares: 0.13425834476947784
Cosine similarity between 0 and 3 flipped squares: -0.14185288548469543
Cosine similarity between 0 and 4 flipped squares: -0.35662245750427246
Cosine similarity between 1 and 2 flipped squares: 0.29375189542770386
Cosine similarity between 1 and 3 flipped squares: 0.09530141949653625
Cosine similarity between 1 and 4 flipped squares: -0.1668206751346588
Cosine similarity between 2 and 3 flipped squares: 0.262576699256897
Cosine similarity between 2 and 4 flipped squares: 0.13699443638324738
Cosine similarity between 3 and 4 flipped squares: 0.34099310636520386
Average cosine similarity between flipped squares: 0.10138045996427536
