# Setup

In [3]:
import os, sys
# Locate the chapter directory so the exercise helpers become importable.
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
if chapter in os.listdir():
    # The chapter folder sits directly inside the current working directory.
    chapter_dir = r"./"
else:
    # Otherwise keep whatever part of the cwd precedes the chapter name.
    chapter_dir = os.getcwd().split(chapter)[0]
sys.path.append(f"{chapter_dir}{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
from generate_patches import generate_patch
from pprint import pprint
from utils import plot_game
from training_utils import get_state_stack_num_flipped
from utils import plot_probe_outputs
from utils import seq_to_state_stack
from utils import VisualzeBoardArguments
from utils import visualize_game

# import part6_othellogpt.tests as tests

# Seed PyTorch's global RNG so that any random draws in this notebook are repeatable.
t.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

# True when this code runs as the main module (as opposed to being imported).
MAIN = __name__ == "__main__"

# Architecture hyper-parameters of the transformer whose weights are loaded below.
# NOTE(review): d_vocab = 61 is consistent with the token ids 0..60 asserted further down
# for `board_seqs_int`; the remaining values presumably mirror the checkpoint - confirm.
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
# Instantiate the model from the config; its weights are replaced by the checkpoint below.
model = HookedTransformer(cfg)

# Fetch the "synthetic" checkpoint from the Hugging Face repo and load it into the model.
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# Alternative checkpoint from the same repo (currently unused):
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# One example game (59 tokens), created directly on the target device.
sample_input = t.tensor(
    [[
        20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
        33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
        25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
    ]],
    device=device,
)

# Expected argmax of the model output at every position (i.e. the most likely next move).
sample_output = t.tensor(
    [[
        21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
        52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
        28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
    ]],
    device=device,
)

# Sanity check that the loaded weights reproduce the reference predictions.
assert model(sample_input).argmax(dim=-1).eq(sample_output).all()

# Treat the current working directory as the section directory and make it importable.
# os.chdir(section_dir)
section_dir = Path.cwd()
sys.path.append(str(section_dir))
# Print the directory name as a quick check of where the notebook is running.
print(section_dir.name)

# Expected location of the `othello_world` checkout and of its
# `mechanistic_interpretability` sub-folder (data files and helper module are read from there).
OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# The repository is NOT cloned automatically; the clone step below is disabled.
# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

# Needed so that `mech_interp_othello_utils` (imported right after) can be found.
sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Game transcripts as ints (i.e. 0 to 60).
board_seqs_int = t.tensor(
    np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"),
    dtype=t.long,
)
# The same transcripts as "strings" (i.e. 0 to 63 with the middle squares skipped out).
board_seqs_string = t.tensor(
    np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"),
    dtype=t.long,
)

# Sanity checks: the four centre squares are never played, and the largest int id is 60.
assert not any(middle_sq in board_seqs_string for middle_sq in (27, 28, 35, 36))
assert board_seqs_int.max() == 60

num_games, length_of_game = board_seqs_int.shape

# All board indices except the four centre squares.
stoi_indices = [square for square in range(64) if square not in (27, 28, 35, 36)]

# Row letters used by the label helper defined next (labels look like `E2`).
alpha = "ABCDEFGH"

def to_board_label(i):
    """Return the label of board index `i`: the row letter from `alpha` followed by the column digit."""
    row, col = divmod(i, 8)
    return f"{alpha[row]}{col}"

# Labels for the playable squares only, and for all 64 squares.
board_labels = [to_board_label(square) for square in stoi_indices]
full_board_labels = [to_board_label(square) for square in range(64)]

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Plot a square (8 by 8) input as a board via `imshow`.

    A stack of boards can be shown by passing `facet_col=0`. With `diverging_scale`
    the "RdBu" colour scale centred on 0 is used, otherwise "Blues" without a midpoint.
    Any extra keyword arguments are forwarded to `imshow` and override the defaults.
    """
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None
    plot_options = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    # Caller-supplied options win over the defaults above.
    plot_options.update(kwargs)
    imshow(state, **plot_options)

# Select a fixed window of games to focus on.
# NOTE(review): `num_games` was previously bound to the full dataset size
# (`board_seqs_int.shape[0]`); from here on it means the number of focus games.
start = 30000
num_games = 50
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

# Run the model on every focus game with the final move dropped, caching the activations.
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
# Bare expression: it only displays something when it is the last statement of a cell.
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` with 1.0 at every index in `list_of_ints`.

    All other entries are 0.0 (a multi-hot encoding when several indices are given).
    """
    encoded = t.zeros(num_classes, dtype=t.float32)
    encoded[list_of_ints] = 1.
    return encoded

# Per-game, per-move buffers: the 8x8 board state and a 64-long vector marking the valid moves.
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

# Replay every focus game move by move on a fresh board.
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        # Play move j, then record the resulting board and the moves reported as valid.
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

# Load the layer-6 linear probe and check that it has the expected shape.
# NOTE(review): no `map_location` is given, so the tensor lands on whatever device it was saved from.
linear_probe2 = t.load("probes/linear/resid_6_linear.pth")

# Board dimensions and the number of classes per square predicted by the probe.
rows = 8
cols = 8
options = 3
assert linear_probe2.shape == (1, cfg.d_model, rows, cols, options)

# Index conventions.
# NOTE(review): `their_index = 1` / `my_index = 2` below and `MINE = 1` / `YOURS = 2` further
# down assign opposite meanings to channels 1 and 2 - confirm which convention the probes use.
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

# The string literal below is a no-op expression that serves as a disabled block of settings.
'''LAYER = 6
game_index = 0
move = 29'''

# Square values in the "black/white" notation.
BLANK1 = 0
BLACK = 1
WHITE = -1

# Earlier class numbering, kept for reference only:
# MINE = 0
# YOURS = 1
# BLANK2 = 2

# Class indices used by the circuit code in this notebook.
EMPTY = 0
MINE = 1
YOURS = 2



interpretability
interpretability
focus states: (50, 60, 8, 8)
focus_valid_moves (50, 60, 64)


## More Imports

In [4]:
from utils import visualize_game

# Code

In [5]:
# Load Probes
# One board-state ("linear") probe and one "flipped" probe per layer, indexed by layer.
linear_probes = []
flipped_probes = []
for layer in range(8):
    # `map_location=device` places the tensors on `device` while deserialising, so
    # checkpoints that were saved from a GPU also load on a CPU-only machine.
    # The trailing `.to(device)` is then a no-op that keeps the final placement explicit.
    linear_probe = t.load(f"probes/linear/resid_{layer}_linear.pth", map_location=device).to(device)
    flipped_probe = t.load(f"probes/flipped/resid_{layer}_flipped.pth", map_location=device).to(device)
    linear_probes.append(linear_probe)
    flipped_probes.append(flipped_probe)
print(len(linear_probes), len(flipped_probes))

8 8


In [6]:
def get_feature_logits(resid : Float[Tensor, "batch pos d_model"], layer : int, softmax : bool = False) -> Tuple[Float[Tensor, "batch pos rows cols options"], Float[Tensor, "batch pos rows cols options"]]:
    """
    Apply the board-state probe and the flipped probe of one layer to residual-stream activations.

    resid: [batch pos d_model] activations to probe
    layer: index into `linear_probes` / `flipped_probes`
    softmax: if True, return probabilities over the last (options) axis instead of raw logits
    output:
        (board_logits, flipped_logits), each of shape [batch pos rows cols options].
        The probes carry a leading "modes" axis; only its first entry is used, so that
        axis does not appear in the outputs.
    """
    # Both probes share the same layout, so the same contraction pattern applies to each.
    probe_pattern = 'batch seq d_model, modes d_model rows cols options -> modes batch seq rows cols options'
    board_logits = einops.einsum(resid, linear_probes[layer], probe_pattern)[0]
    flipped_logits = einops.einsum(resid, flipped_probes[layer], probe_pattern)[0]
    if softmax:
        board_logits = F.softmax(board_logits, dim=-1)
        flipped_logits = F.softmax(flipped_logits, dim=-1)
    return board_logits, flipped_logits

In [23]:
# TODO: Print out the mistakes, so I can visualize them
# TODO: Print out the tiles where the mistake is
# TODO: Import the helpers needed to visualize
# TODO: Visualize
# TODO: Different layers: different thresholds, maybe different positions?
# TODO: Add the flipping circuit
# TODO: In the future, add an evaluation that distinguishes tiles that changed since the previous layer from tiles that did not
# TODO: Add hard eval (everything that's not Empty)

@dataclass
class Parameters():
    """Settings for the hand-written circuits and for how their predictions are evaluated."""
    # Minimum softmax probability of flipped-class 0 for a tile to count as flipped.
    flipped_thresh : float = 0.2
    # How many earlier positions the last-flipped circuit looks back over.
    how_far_back : int = 15
    # Half-open range [start_pos, end_pos) of sequence positions the circuits are applied to.
    start_pos : int = 10
    end_pos : int = 20
    # Exclusive upper bound of the layers that are evaluated (evaluation starts at layer 1).
    layers : int = 7
    # Activation name looked up in the cache to obtain the comparison target.
    evaluation_module : str = "attn_out"
    # Whether the look-back window includes the position being predicted.
    include_current_pos : bool = False
    # When True, only tiles whose target probability exceeds `easy_eval_thresh` are scored.
    easy_eval: bool = True
    easy_eval_thresh : float = 0.8
    # NOTE(review): only referenced from commented-out code in the flipping circuit.
    yours_thresh : float = 0.8
    # Maximum number of mistakes printed per evaluated layer.
    number_mistakes_to_print : int = 15
    # When True, every tile starts out marked as predicted.
    evaluate_on_everything : bool = True
    # Names of the circuits to run. Declared as a real dataclass field with a per-instance
    # default (so instances no longer share one list) and placed last so that positional
    # construction of the fields above is unchanged.
    circuits : List[str] = dataclasses.field(default_factory=lambda: ["last_flipped", "flipped"])

In [8]:
def last_flipped_circuit(
        params : Parameters,
        mine_yours_logits: Float[Tensor, "batch pos rows cols options"],
        flipped_logits : Float[Tensor, "batch pos rows cols options"],
        predicted : Bool[Tensor, "batch pos rows cols"],
    ) -> Tuple[Float[Tensor, "batch pos rows cols options"], Bool[Tensor, "batch pos rows cols"]]:
    """
    Predict each tile's logits from the most recent position at which the tile was flipped.

    For every target position `pos1` in [params.start_pos, params.end_pos) the previous
    `params.how_far_back` positions are scanned in ascending order (the window includes
    `pos1` itself when `params.include_current_pos` is set). A tile counts as flipped at
    position `pos` when the softmax of `flipped_logits` at option 0 exceeds
    `params.flipped_thresh`. Its logits at `pos` are then copied to `pos1`, with options 1
    and 2 swapped when `pos` and `pos1` have different parity. Later positions overwrite
    earlier ones, so the most recent flip wins.

    Returns the predicted logits (a modified copy of `mine_yours_logits`, with option 0 set
    to 1 wherever nothing was copied) and `predicted`, which is updated IN PLACE with the
    tiles that received a prediction.
    """
    last_flipped_values = mine_yours_logits.clone()
    last_flipped_values[:, :, :, :, 0] = 1
    # Loop-invariant work, done once instead of once per (pos1, pos) pair:
    # the flip mask for every position ...
    was_flipped_all = flipped_logits.softmax(dim=-1)[..., 0] > params.flipped_thresh
    # ... and a copy of the logits with options 1 and 2 exchanged.
    swapped_logits = mine_yours_logits.clone()
    swapped_logits[..., 1] = mine_yours_logits[..., 2]
    swapped_logits[..., 2] = mine_yours_logits[..., 1]
    for pos1 in range(params.start_pos, params.end_pos):
        local_end_pos = pos1 + 1 if params.include_current_pos else pos1
        for pos in range(max(local_end_pos - params.how_far_back, 0), local_end_pos):
            was_flipped = was_flipped_all[:, pos]
            # Same parity: copy as is; different parity: use the swapped options.
            source_logits = mine_yours_logits if pos1 % 2 == pos % 2 else swapped_logits
            last_flipped_values[:, pos1][was_flipped] = source_logits[:, pos][was_flipped]
            predicted[:, pos1] |= was_flipped
    return last_flipped_values, predicted

In [9]:
def tile_empty(mine_yours: Int[Tensor, "batch pos rows cols"], batch : int, pos1 : int, tile : Tuple[int, int]):
    """
    Return True when `tile` is off the board or holds the EMPTY class.

    mine_yours: per-tile class indices; its last two axes define the board size
    batch, pos1: which game and sequence position to look at
    tile: (row, col) coordinates, given as Python ints or 0-d tensors

    Always returns a plain Python bool.
    """
    # Normalise the coordinates so that the result type does not depend on the input type.
    row, col = int(tile[0]), int(tile[1])
    num_rows, num_cols = mine_yours.shape[-2], mine_yours.shape[-1]
    # Off-board tiles are treated like empty ones.
    if not (0 <= row < num_rows and 0 <= col < num_cols):
        return True
    return bool((mine_yours[batch, pos1, row, col] == EMPTY).item())

def flipping_circuit(
        params : Parameters,
        mine_yours_logits: Float[Tensor, "batch pos rows cols options"],
        flipped_logits : Float[Tensor, "batch pos rows cols options"],
        predicted: Bool[Tensor, "batch pos rows cols"],
    ) -> Tuple[Float[Tensor, "batch pos rows cols options"], Bool[Tensor, "batch pos rows cols"]]:
    """
    Swap the MINE and YOURS logits of tiles that the move played at each position captures.

    For every game in the batch and every position in [params.start_pos, params.end_pos),
    the tile played at that position is read from the module-level `focus_games_string`
    (so batch index b must correspond to focus game b). From that tile each of the eight
    directions is walked for up to 7 steps over the argmax classes of `mine_yours_logits`:
      * an empty or off-board tile ends the walk and discards the run;
      * a tile that is MINE and whose flipped-argmax is 1 ends the walk and commits the
        YOURS tiles collected so far;
      * a walk whose next tile is empty/off-board is abandoned;
      * otherwise a YOURS tile is added to the run.
    NOTE(review): a MINE tile whose flipped-argmax is not 1 neither ends the walk nor is
    collected - confirm that this is intended.

    Both `mine_yours_logits` and `predicted` are modified IN PLACE and returned.
    """
    # TODO: Add the history ...
    # TODO: Add changes to the flipping logits
    # Snapshot of the incoming logits so that every swap reads the original values.
    mine_yours_logits_clone = mine_yours_logits.clone()
    # TODO: use params.yours_thresh on the softmax instead of the argmax
    #       (a simple threshold only behaves as intended when yours_thresh > 0.5)
    mine_yours = mine_yours_logits.argmax(dim=-1)
    # TODO: Use flipped_thresh
    flipped = flipped_logits.argmax(dim=-1)
    batch_size = mine_yours_logits.shape[0]
    # The eight neighbouring directions. (0, 0) is skipped: it would revisit the played
    # tile itself and can never contribute a flip, so omitting it does not change the result.
    directions = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]
    for batch in range(batch_size):
        for pos1 in range(params.start_pos, params.end_pos):
            # Board index (0..63) of the tile played at this position, as plain ints.
            played_row, played_col = divmod(int(focus_games_string[batch, pos1]), 8)
            flipped_list = []
            for x_delta, y_delta in directions:
                potentially_flipped = []
                for dist in range(1, 8):
                    this_tile = (played_row + x_delta * dist, played_col + y_delta * dist)
                    if tile_empty(mine_yours, batch, pos1, this_tile):
                        break
                    next_tile = (this_tile[0] + x_delta, this_tile[1] + y_delta)
                    # TODO: turn this into better functions
                    if mine_yours[batch, pos1, this_tile[0], this_tile[1]] == MINE and flipped[batch, pos1, this_tile[0], this_tile[1]] == 1:
                        flipped_list += potentially_flipped
                        break
                    if tile_empty(mine_yours, batch, pos1, next_tile):
                        break
                    if mine_yours[batch, pos1, this_tile[0], this_tile[1]] == YOURS:
                        potentially_flipped += [this_tile]
            for row, col in flipped_list:
                mine_yours_logits[batch, pos1, row, col, YOURS] = mine_yours_logits_clone[batch, pos1, row, col, MINE]
                mine_yours_logits[batch, pos1, row, col, MINE] = mine_yours_logits_clone[batch, pos1, row, col, YOURS]
                predicted[batch, pos1, row, col] = True
    # Design notes for future work:
    #   Use a YOURS threshold (in params)
    #   create list of tiles (row, col) that should be flipped
    #   flip the tiles
    #       Option 1: add a fixed amount to the MINE logit
    #       Option 2: swap the values for MINE and YOURS (makes sense but isn't very general)
    #           Swap using a weight term (write 0.5 of MINE to YOURS ...) then add a fixed bias
    #           to both logits? (Use different weight and bias for mine and yours)
    return mine_yours_logits, predicted

In [24]:
def get_predictions(params : Parameters, resid_pre: Float[Tensor, "batch pos d_model"], layer : int) -> Tuple[Float[Tensor, "batch pos rows cols options"], Bool[Tensor, "batch pos rows cols"]]:
    """
    Run the circuits selected in `params.circuits` on the residual stream entering `layer`.

    The activations are probed with the probes of `layer - 1`, then the "last_flipped"
    and "flipped" circuits are applied in that order when enabled.
    NOTE(review): `layer == 0` indexes the probes with -1, i.e. silently uses the last
    layer's probes; callers in this notebook start at layer 1.

    Returns (predicted logits [batch pos rows cols options],
             bool mask [batch pos rows cols] of the tiles that count as predicted).
    """
    # TODO: Maybe remove the inputs and outputs and just use function side effects ...
    # TODO: use prev layer as base prediction
    batch_size, seq_len, _ = resid_pre.shape
    rows, cols = 8, 8
    # Start with every tile marked as predicted (evaluate on everything) or with none.
    make_mask = t.ones if params.evaluate_on_everything else t.zeros
    predicted = make_mask((batch_size, seq_len, rows, cols), dtype=t.bool, device=device)
    mine_yours_logits_pred, flipped_logits = get_feature_logits(resid_pre, layer-1)
    if "last_flipped" in params.circuits:
        mine_yours_logits_pred, predicted = last_flipped_circuit(params, mine_yours_logits_pred, flipped_logits, predicted)
    if "flipped" in params.circuits:
        mine_yours_logits_pred, predicted = flipping_circuit(params, mine_yours_logits_pred, flipped_logits, predicted)
    return mine_yours_logits_pred, predicted

In [11]:
def print_first_mistakes(pred : Int[Tensor, "batch pos rows cols"], real : Int[Tensor, "batch pos rows cols"], mask : Bool[Tensor, "batch pos rows cols"], num_mistakes : int):
    """
    Print the first `num_mistakes` tiles where `pred` and `real` disagree inside `mask`.

    Mistakes are listed in game, position, row, column order, one line each as
    "Game: <game>, Pos: <position>, Tile: <label>". Nothing is printed when
    `num_mistakes` is zero or negative.
    """
    if num_mistakes <= 0:
        return
    mistakes = (pred != real) & mask
    # nonzero() returns the indices in row-major order, i.e. the order in which nested
    # loops over (game, position, row, column) would visit them.
    for game, pos, row, col in mistakes.nonzero()[:num_mistakes].tolist():
        tile_label = string_to_label(row*8 + col)
        print(f"Game: {game}, Pos: {pos}, Tile: {tile_label}")

def get_evaluation_mask(params : Parameters, predicted : Bool[Tensor, "batch pos rows cols"], real_values : Float[Tensor, "batch pos rows cols options"]):
    """
    Select the tiles on which the predictions are scored.

    A tile is kept when it is marked in `predicted` and the argmax of `real_values` is not
    option 0. With `params.easy_eval`, the softmax of `real_values` must additionally
    exceed `params.easy_eval_thresh` for option 1 or option 2.

    Returns a new bool tensor of shape [batch pos rows cols]; the inputs are not modified.
    """
    # Built from the inputs only, so the result lives on whatever device they are on.
    mask = predicted & (real_values.argmax(dim=-1) != 0)
    if params.easy_eval:
        # Keep only tiles where the target itself is confident about option 1 or option 2.
        real_probs = real_values.softmax(dim=-1)
        confident = (real_probs[:, :, :, :, 2] > params.easy_eval_thresh) | (real_probs[:, :, :, :, 1] > params.easy_eval_thresh)
        mask = mask & confident
    return mask


def get_accuracy(params : Parameters,
        pred_values : Float[Tensor, "batch pos rows cols options"],
        resid_post : Float[Tensor, "batch pos d_model"],
        predicted: Bool[Tensor, "batch pos rows cols"],
        layer : Int,
    ) -> Float[Tensor, ""]:
    """
    Score the circuit predictions against the probe read-out of `resid_post` at `layer`.

    The target classes are the argmax of the board-state probe applied to `resid_post`.
    Only tiles selected by `get_evaluation_mask` are scored, and the first
    `params.number_mistakes_to_print` mistakes are printed along the way.

    Returns a scalar tensor: number of correct tiles divided by number of scored tiles.
    """
    # Unpacking doubles as a check that the predictions are 5-dimensional.
    _batch, _pos, _rows, _cols, _options = pred_values.shape
    probe_logits, _ = get_feature_logits(resid_post, layer)
    evaluation_mask = get_evaluation_mask(params, predicted, probe_logits)
    pred_classes = pred_values.argmax(dim=-1)
    real_classes = probe_logits.argmax(dim=-1)
    print_first_mistakes(pred_classes, real_classes, evaluation_mask, params.number_mistakes_to_print)
    hits = pred_classes[evaluation_mask] == real_classes[evaluation_mask]
    return hits.float().sum() / evaluation_mask.sum()

In [12]:
def test_last_flipped_accuracy(params : Parameters):
    """Print, for each layer from 1 up to (not including) `params.layers`, the accuracy of the circuit predictions.

    Predictions are computed from the cached `resid_pre` of the layer and compared with
    the cached activation named by `params.evaluation_module` at the same layer.
    """
    for layer in range(1, params.layers):
        circuit_input = focus_cache["resid_pre", layer]
        target_resid = focus_cache[params.evaluation_module, layer]
        # Predicted logits have shape [batch pos rows cols options].
        circuit_logits, predicted = get_predictions(params, circuit_input, layer)
        layer_accuracy = get_accuracy(params, circuit_logits, target_resid, predicted, layer)
        print(f"Layer: {layer}, Accuracy: {layer_accuracy.item()}")

In [35]:
# Experiment configuration: evaluate layers 1-5 on positions 0-29 against resid_post.
params = Parameters(
    layers=6,
    start_pos=0,
    end_pos=30,
    easy_eval=True,
    easy_eval_thresh=0.9,
    evaluation_module="resid_post",
    include_current_pos=True,
    flipped_thresh=0.35,
    evaluate_on_everything=True,
)
# Set by attribute rather than passed to the constructor.
params.circuits = ["last_flipped", "flipped"]

test_last_flipped_accuracy(params)

Game: 0, Pos: 13, Tile: C3
Game: 0, Pos: 14, Tile: F2
Game: 0, Pos: 15, Tile: C4
Game: 0, Pos: 16, Tile: C3
Game: 0, Pos: 17, Tile: C1
Game: 0, Pos: 17, Tile: C3
Game: 0, Pos: 18, Tile: F2
Game: 0, Pos: 21, Tile: E5
Game: 0, Pos: 22, Tile: E4
Game: 0, Pos: 23, Tile: E5
Game: 0, Pos: 23, Tile: G3
Game: 0, Pos: 24, Tile: F7
Game: 0, Pos: 25, Tile: D4
Game: 0, Pos: 26, Tile: D4
Game: 0, Pos: 30, Tile: C4
Layer: 1, Accuracy: 0.9465149641036987
Game: 0, Pos: 18, Tile: C3
Game: 0, Pos: 20, Tile: D3
Game: 0, Pos: 24, Tile: E5
Game: 0, Pos: 28, Tile: B4
Game: 0, Pos: 30, Tile: C3
Game: 0, Pos: 31, Tile: B4
Game: 0, Pos: 31, Tile: C6
Game: 0, Pos: 31, Tile: E2
Game: 0, Pos: 33, Tile: C4
Game: 0, Pos: 33, Tile: C6
Game: 0, Pos: 35, Tile: C2
Game: 0, Pos: 37, Tile: C6
Game: 0, Pos: 38, Tile: D4
Game: 0, Pos: 38, Tile: G3
Game: 0, Pos: 40, Tile: B2
Layer: 2, Accuracy: 0.9740877747535706
Game: 0, Pos: 0, Tile: C3
Game: 0, Pos: 27, Tile: E3
Game: 0, Pos: 27, Tile: E4
Game: 0, Pos: 28, Tile: E3
Game:

In [20]:
# Visualisation settings for the board-state ("linear") probe view, positions 0 to 20.
vis_args = VisualzeBoardArguments()
vis_args.start_pos = 0
vis_args.end_pos = 20
vis_args.layers = 6
vis_args.include_attn_only = False
vis_args.include_mlp_only = False
vis_args.mode = "linear"
# vis_args.static_image = True

# Focus game 1, first 59 moves.
clean_input_str = focus_games_string[1, :59]

# The rendering call itself is currently disabled.
# visualize_game(clean_input_str, vis_args, model)

In [2]:
# Visualisation settings for the "flipped" probe view, positions 5 to 20.
# NOTE(review): the recorded output of this cell is a NameError for `VisualzeBoardArguments`
# and its execution count precedes the setup cell's, so it was run before the imports;
# it needs the setup cell to have been executed first.
vis_args = VisualzeBoardArguments()
vis_args.start_pos = 5
vis_args.end_pos = 20
vis_args.layers = 6
vis_args.include_attn_only = False
vis_args.include_mlp_only = False
vis_args.mode = "flipped"
# vis_args.static_image = True

# Focus game 1, first 59 moves.
clean_input_str = focus_games_string[1, :59]

# The rendering call itself is currently disabled.
# visualize_game(clean_input_str, vis_args, model)

NameError: name 'VisualzeBoardArguments' is not defined