# Setup

In [1]:
import os, sys
# Name of the chapter folder and of the repository this notebook belongs to.
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"

# Prefix leading to the chapter folder: "./" when the folder sits directly in
# the current directory, otherwise whatever precedes the chapter name in the
# current working directory path.
if chapter in os.listdir():
    chapter_dir = r"./"
else:
    chapter_dir = os.getcwd().split(chapter)[0]

# Put the chapter's exercises folder on the import path.
sys.path.append(f"{chapter_dir}{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
from generate_patches import generate_patch
from pprint import pprint
from utils import plot_game
from training_utils import get_state_stack_num_flipped
from utils import plot_probe_outputs
from utils import seq_to_state_stack
# import part6_othellogpt.tests as tests

# Seed PyTorch's global random number generator with a fixed value.
t.manual_seed(42)

# Use the GPU when CUDA is available, otherwise the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# True only when this file is executed as the main module.
MAIN = (__name__ == "__main__")

# Architecture hyperparameters of the transformer instantiated just below.
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

# Fetch "synthetic_model.pth" from the "NeelNanda/Othello-GPT-Transformer-Lens" repo
# (presumably a state dict, since it is handed straight to load_state_dict below).
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
# (a single sequence of 59 token ids, moved to `device`)
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (ie the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

# Sanity check on the loaded weights: the model's argmax prediction at every
# position of sample_input must equal the stored sample_output.
# (sample_output is already on `device`; the extra .to(device) is a no-op.)
assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
# Treat the current working directory as the section directory, make it
# importable, and print its name.
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

# Resolved paths of the "othello_world" folder inside the section directory and
# of its "mechanistic_interpretability" sub-folder.
OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

# Needed so that mech_interp_othello_utils can be imported from that folder.
sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

# Sanity checks on the loaded data: none of the four middle squares appears in
# the "string" encoding, and the largest id in the int encoding is 60.
assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

# Both dimensions of the int-encoded array (one row per game).
# NOTE: `num_games` is reassigned to 50 further down when the focus games are selected.
num_games, length_of_game = board_seqs_int.shape

# Define possible indices (excluding the four center squares)
stoi_indices = [i for i in range(64) if i not in [27, 28, 35, 36]]

# Define our rows, and the function that converts an index into a (row, column) label, e.g. `E2`
# Row letters used when building board labels.
alpha = "ABCDEFGH"


def to_board_label(i):
    """Return the label of flat board index `i`.

    The label is the row letter (`i // 8` looked up in `alpha`) followed by
    the 0-based column number (`i % 8`), e.g. 12 -> "B4".
    """
    row, col = divmod(i, 8)
    return alpha[row] + str(col)

# Labels for the squares listed in stoi_indices, and for all 64 squares.
board_labels = [to_board_label(square) for square in stoi_indices]
full_board_labels = [to_board_label(square) for square in range(64)]

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Show a square (8 by 8) input as a board via `imshow`.

    Rows are labelled with the letters of `alpha` and columns with "0".."7".
    With `diverging_scale` the "RdBu" colour scale centred on 0 is used,
    otherwise "Blues" with no midpoint. Any extra keyword arguments are
    forwarded to `imshow` and override these defaults; a stack of boards can
    be plotted by passing facet_col=0.
    """
    plot_options = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale="RdBu" if diverging_scale else "Blues",
        color_continuous_midpoint=0. if diverging_scale else None,
        aspect="equal",
    )
    # Caller-supplied options take precedence over the defaults above.
    plot_options.update(kwargs)
    imshow(state, **plot_options)

# Select a window of `num_games` consecutive games, beginning at row `start`,
# from both encodings of the loaded sequences.
start = 30000
num_games = 50
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

# Run the model on the focus games with the last move of each game dropped,
# keeping both the logits and the activation cache.
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
# Bare expression: it only displays the shape when run as a notebook cell.
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` with 1 at every index
    in `list_of_ints` and 0 everywhere else."""
    encoded = t.zeros(num_classes, dtype=t.float32)
    encoded[list_of_ints] = 1.0
    return encoded

# Buffers filled by the loop below: an 8x8 array and a 64-entry mask for each
# of the 60 moves of every focus game.
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

# Replay every focus game on a fresh OthelloBoardState. After each call to
# umpire() (presumably applies the move — confirm in mech_interp_othello_utils)
# record board.state and a one-hot mask of whatever get_valid_moves() returns.
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

# Probe tensor read from the layer-6 file under probes/linear.
# NOTE(review): no map_location is passed here, so the tensor is restored to
# the device it was saved from rather than to `device` — confirm this is intended.
linear_probe2 = t.load("probes/linear/resid_6_linear.pth")

# Expected layout of the probe: (1, d_model, rows, cols, options).
rows = 8
cols = 8
options = 3
assert linear_probe2.shape == (1, cfg.d_model, rows, cols, options)

# Index constants. Presumably the first two index a "player to move" axis and
# the last three index the `options` axis of a probe — TODO confirm at the use sites.
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

# The triple-quoted string below is a bare string statement acting as
# commented-out code; it has no effect at runtime.
'''LAYER = 6
game_index = 0
move = 29'''

# Square values in the black/white notation.
BLANK1 = 0
BLACK = 1
WHITE = -1

# MINE = 0
# YOURS = 1
# BLANK2 = 2

# Square classes in the mine/yours notation; the numbers match
# blank_index / their_index / my_index above.
EMPTY = 0
YOURS = 1
MINE = 2



interpretability
interpretability
focus states: (50, 60, 8, 8)
focus_valid_moves (50, 60, 64)


## More Imports

In [2]:
from utils import visualize_game

# Code
- Plan:
    - We want something that takes a hand-coded algorithm and evaluates how good it is.
    - The input depends on the layer:
        - Layer 0: all the previous moves.
        - Layer n ≥ 1: all the previous moves, plus all the board logits of the earlier layers at the previous and the current moves.
        - Also the flipped probs.
        - Open question: should I just use the cache?

In [3]:
# Load Probes
# One "linear" probe and one "flipped" probe per model layer, stored so that
# linear_probes[layer] / flipped_probes[layer] belong to that layer.
linear_probes = []
flipped_probes = []
for layer in range(cfg.n_layers):
    # map_location=device lets checkpoints that were saved on another device
    # (e.g. a GPU) be loaded on this machine; without it torch.load tries to
    # restore them to their original device and fails when it is unavailable.
    linear_probe = t.load(f"probes/linear/resid_{layer}_linear.pth", map_location=device).to(device)
    flipped_probe = t.load(f"probes/flipped/resid_{layer}_flipped.pth", map_location=device).to(device)
    linear_probes.append(linear_probe)
    flipped_probes.append(flipped_probe)

In [4]:
def get_feature_logits(resid : Float[Tensor, "batch pos d_model"], layer : int):
    """Project residual-stream activations through the probes of one layer.

    Args:
        resid: activations indexed as (batch, position, d_model).
        layer: index into the module-level `linear_probes` and
            `flipped_probes` lists.

    Returns:
        A tuple `(board_logits, flipped_logits)`. Each is the first "modes"
        slice of the corresponding probe output, indexed as
        (batch, position, rows, cols, options).
    """
    # Use fresh local names: assigning to `linear_probes` here would make it a
    # local variable and shadow the module-level list being indexed.
    linear_probe = linear_probes[layer]
    flipped_probe = flipped_probes[layer]
    # Both probes are contracted over d_model with the same pattern
    # (assumes the flipped probe has the same axis layout as the linear one).
    pattern = "batch seq d_model, modes d_model rows cols options -> modes batch seq rows cols options"
    board_logits = einops.einsum(resid, linear_probe, pattern)[0]
    flipped_logits = einops.einsum(resid, flipped_probe, pattern)[0]
    return board_logits, flipped_logits

In [6]:
class Algo():
    """Base class for a hand-coded algorithm configured by a dict of parameters."""

    def __init__(self, params):
        # Mapping from parameter name to value. get_random_variation requires
        # the values to be tensors (it calls t.randn_like on them).
        self.params = params

    def forward(self, resid : Float[Tensor, "batch pos d_model"], layer) -> Tuple[Float[Tensor, "batch pos rows cols options"], Float[Tensor, "batch pos rows cols options"]]:
        """Hook for subclasses: compute the two predicted logit tensors.

        The base implementation does nothing and returns None.
        """
        pass

    def get_random_variation(self, change_factor : float):
        """Return a new instance with every parameter randomly perturbed.

        Each value `v` becomes `v + randn_like(v) * v * change_factor`, so the
        noise is proportional to the parameter itself and zero entries stay
        zero. The original `params` dict is not modified.

        The result is built with `type(self)`, so calling this on a subclass
        returns that subclass (whose constructor must accept the params dict)
        instead of a bare `Algo` that has lost the subclass's `forward`.
        """
        params_new = {
            key: value + t.randn_like(value) * value * change_factor
            for key, value in self.params.items()
        }
        return type(self)(params_new)

In [4]:
class history_circuit(Algo):
    """Work-in-progress hand-coded "history circuit".

    Expects `params` to contain the keys "mine_threshold" and
    "flipped_threshold". `forward` is not finished yet (see the TODO at its end).
    """

    def __init__(self, params):
        # Hyperparams: How far to go back, Threshold where you count a tile as Flipped, Threshold for Mine/Yours
        super().__init__(params)
        # Read from self.params, which the base-class constructor has just set.
        self.mine_threshold = self.params["mine_threshold"]
        self.flipped_threshold = self.params["flipped_threshold"]

    def forward(self, prev_resid : Float[Tensor, "batch pos d_model"], prev_layer) -> Tuple[Float[Tensor, "batch pos rows cols options"], Float[Tensor, "batch pos rows cols options"]]:
        """Unfinished: reads the previous layer's probe logits position by position.

        Args:
            prev_resid: residual-stream activations of the previous layer.
            prev_layer: index of that layer, passed on to `get_feature_logits`.

        Returns:
            Currently None; the annotated tuple is the intended result.
        """
        # Go over each previous position B of layer -1
        # Check what Color the Tile should be...
        #    From the Beginning the first time the MINE/YOURS is above a certain Threshold count the tile as Mine / Yours, then count the Flipped ... (Not what the model is actually doing though)
        #    I could try other Things like: Look at the Last time the Tile has been Flipped, this is probably the Color
        #        If the Tile has not been Flipped => First Placed ...
        #        I could first visualize and look what makes sense. Or I just code it out and try different things (better evidence I think, but I need to have good code for fast feedback loops)
        # If Tile has Flipped at Pos A in the Last Layer, then it should be the opposite color
        # If Tile has not Flipped, then it should be the same color
        batch, seq_len, _ = prev_resid.shape
        board_logits_prev, flipped_logits_prev = get_feature_logits(prev_resid, prev_layer)
        # Allocate the output on the same device as the input so it can later
        # be combined with the probe logits without a device mismatch.
        board_logits_new = t.zeros((batch, seq_len, 8, 8, 3), device=prev_resid.device)
        for pos in range(seq_len):
            logits_at_pos = board_logits_prev[:, pos]
            flipped_at_pos = flipped_logits_prev[:, pos]
        # TODO: fill board_logits_new from the per-position logits and return
        # the (board_logits, flipped_logits) pair promised by the annotation.
        

def flipping_circuit():
    """Placeholder for the flipping circuit; not implemented yet, returns None.

    Planned behaviour:
      * Go in all directions; a YOURS tile with Flipped turns into MINE,
        until a MINE tile without Flipped is hit.
      * One threshold decides when a tile counts as Flipped.
      * One threshold decides when a tile counts as MINE vs YOURS.
    """
    return None

def layer_n_algo(layer):
    """Placeholder for the per-layer algorithm; not implemented yet, returns None.

    Open question: handle all positions at once, or go over each position?

    Planned behaviour, going over each tile:
      * If the tile is Blank, leave it Blank.
      * Otherwise run the History Circuit.
      * Then run the Flipping Circuit on the History Circuit's update.
    """
    return None

In [None]:
def evaluator(layer: int, model: HookedTransformer) -> Tuple[Tensor, Tensor]:
    """Score the hand-coded algorithm against the probe outputs at `layer`.

    The reference comes from projecting the cached "resid_post" activations of
    the focus games (module-level `focus_cache`) through that layer's probes.
    Reference and prediction are compared by their argmax over the last axis.

    Args:
        layer: layer index; must satisfy 0 <= layer < model.cfg.n_layers.
        model: the transformer; only its config is read, to bound `layer`.

    Returns:
        `(board_accuracy, flipped_accuracy)`: the fraction of entries where
        the predicted argmax matches the probe argmax, as scalar tensors.

    Note:
        `layer_n_algo(layer)` must return a `(board_logits, flipped_logits)`
        pair shaped like the probe outputs for the unpacking below to work.
    """
    # Bound by the model's own depth instead of a hard-coded 8.
    assert 0 <= layer < model.cfg.n_layers
    # Use focus_cache ...
    # Do I want to give 1 game or many to the function? I think many is better. I might want to scale this
    # algo function should return the board_logits and the flipped logits
    # get the real board_logits
    # batch, seq, d_model
    resid = focus_cache["resid_post", layer]
    # modes, d_model, rows, cols, options
    board_logits, flipped_logits = get_feature_logits(resid, layer)
    board_logits_pred, flipped_logits_pred = layer_n_algo(layer)
    board_correct = board_logits.argmax(dim=-1)
    flipped_correct = flipped_logits.argmax(dim=-1)
    board_pred = board_logits_pred.argmax(dim=-1)
    flipped_pred = flipped_logits_pred.argmax(dim=-1)
    # Calculate the Accuracy
    board_accuracy = (board_correct == board_pred).float().mean()
    flipped_accuracy = (flipped_correct == flipped_pred).float().mean()
    return board_accuracy, flipped_accuracy
