# Setup

In [2]:
import os, sys
# Locate the ARENA chapter directory and put its exercises folder on the import path.
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"  # not used in this file
if chapter in os.listdir():
    # Running from the directory that directly contains the chapter folder.
    chapter_dir = r"./"
else:
    # Running from somewhere inside the chapter: keep the path prefix before it.
    chapter_dir = os.getcwd().split(chapter)[0]
sys.path.append(f"{chapter_dir}{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
from generate_patches import generate_patch
from pprint import pprint
from utils import plot_game
from training_utils import get_state_stack_num_flipped
from utils import plot_probe_outputs
from utils import seq_to_state_stack
from utils import VisualzeBoardArguments
from utils import visualize_game

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import plot_boards_general
import numpy as np
import pickle

# import part6_othellogpt.tests as tests

# Seed torch's RNG so anything random below is reproducible.
t.manual_seed(42)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

# True when this code runs as the top-level module rather than being imported.
MAIN = __name__ == "__main__"

# Architecture of the synthetic Othello-GPT checkpoint loaded below: 8 layers x 8 heads,
# d_model 512. d_vocab = 61 matches the 0..60 integer move encoding of board_seqs_int;
# n_ctx = 59 matches the 59-token sample input (the model is fed every move but the last).
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

# Download the pretrained weights from the Hugging Face Hub (network access required).
# `utils` here is transformer_lens.utils.
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (i.e. the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

# Sanity check: the loaded weights must reproduce the reference predictions exactly.
assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
# Treat the current working directory as this section's root and make it importable.
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

# Expected location of a checkout of likenneth/othello_world (see the clone hint below).
OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

# Make mech_interp_othello_utils (imported right after this cell) importable.
sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

# Sanity checks: the four centre squares (27, 28, 35, 36) never appear as moves in the
# "string" encoding, and the int encoding tops out at 60.
assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

# The int sequences are 2-D: (number of games, moves per game).
# Note: num_games is rebound later when a small focus subset is selected.
num_games, length_of_game = board_seqs_int.shape

# Define possible indices (excluding the four center squares)
stoi_indices = [i for i in range(64) if i not in [27, 28, 35, 36]]

# Row letters, top to bottom; a square's label is its row letter followed by its
# 0-based column, e.g. `E2`.
alpha = "ABCDEFGH"

def to_board_label(i):
    """Return the label of row-major board index ``i`` (0..63), e.g. 34 -> "E2"."""
    row = i // 8
    col = i % 8
    return "{}{}".format(alpha[row], col)

# Labels for the 60 playable squares (centre squares excluded) and for all 64 squares.
board_labels = [to_board_label(square) for square in stoi_indices]
full_board_labels = [to_board_label(square) for square in range(64)]

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Plot an 8x8 array as a board; pass ``facet_col=0`` to plot a stack of boards.

    Rows are labelled with the letters of ``alpha`` and columns with "0".."7".
    With ``diverging_scale`` the red-blue scale is centred on 0; otherwise a
    sequential blue scale is used. Extra keyword arguments are forwarded to
    ``imshow`` and override these defaults.
    """
    if diverging_scale:
        scale, midpoint = "RdBu", 0.
    else:
        scale, midpoint = "Blues", None
    defaults = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale=scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    # Caller-supplied kwargs win over the defaults above.
    imshow(state, **{**defaults, **kwargs})

# Pick a small slice of games to analyse in detail. Note: this rebinds num_games,
# which previously held the size of the whole dataset.
start = 30000
num_games = 50
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

# Run the model on every move except the last (the model's n_ctx is 59) and keep all
# intermediate activations in focus_cache.
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
# Notebook display only: (num_games, length_of_game - 1, d_vocab).
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length ``num_classes`` with 1.0 at every listed index.

    Repeated indices are simply set to 1.0 again (multi-hot, not counts).
    """
    indicator = t.zeros(num_classes, dtype=t.float32)
    indicator[list_of_ints] = 1.0
    return indicator

# Replay each focus game move by move with OthelloBoardState and record, after move j of game i:
#   focus_states[i, j]      <- board.state (an 8x8 grid)
#   focus_valid_moves[i, j] <- one-hot (length 64) of board.get_valid_moves() at that point
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        # umpire() takes the move in the "string" (0..63) encoding.
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

# Board geometry and the number of probe classes per square.
rows = 8
cols = 8
options = 3

# Index conventions: probe "modes" (who is to play) and per-square classes
# (blank / theirs / mine). Presumably these match the probes loaded via utils — confirm.
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

# NOTE(review): star import. It presumably supplies get_probe, get_direction_int,
# EMPTY / MINE / YOURS and plot_boards_general, which later cells use but which are not
# defined in the visible code. It can also silently shadow names imported above.
from utils import *



interpretability
interpretability
focus states: (50, 60, 8, 8)
focus_valid_moves (50, 60, 64)


# Code

- First-played circuit
    - Check whether, in layer 0, Emb @ OV_mine = MINE and Emb @ OV_yours = YOURS
        - What is the role of the positional embedding?
            - Is it important for the QK circuit?

In [3]:
def string_to_tile(pos_str):
    """Split a row-major board index (0..63, the "string" encoding) into (row, col)."""
    row = pos_str // 8
    col = pos_str % 8
    return row, col

In [4]:
# Token embedding matrix and the per-(layer, head) OV circuits. OV is a FactoredMatrix
# of shape (n_layers, n_heads, d_model, d_model); use .AB for the dense product.
W_E = model.W_E
OV = model.OV
print(W_E.shape, OV.shape)
print(type(OV))

torch.Size([61, 512]) torch.Size([8, 8, 512, 512])
<class 'transformer_lens.FactoredMatrix.FactoredMatrix'>


- This plot shows the layer-0 base circuit, which predicts the tile that was last played: heads 0 and 7 for MINE, heads 3 and 6 for YOURS.
- It also shows how corners are handled!

In [5]:
def cosine_similiarity(x, y):
    """Cosine similarity of two 1-D tensors (NaN if either has zero norm)."""
    return x.dot(y) / (x.norm() * y.norm())

In [27]:
# For every layer/head, push each move token's embedding through that head's OV circuit
# and measure (cosine) how well the output aligns with the linear probe's
# YOURS-minus-MINE direction for the tile that move lands on.
# The tensor is indexed [head, layer, row, col]; the annotation matches that indexing.
mine_vs_yours : Float[Tensor, "head layer rows cols"] = t.zeros((8, 8, 8, 8))

for layer in range(0, 8):
    # NOTE(review): W_E is the direct input only of layer 0; for layer >= 1 this scores the
    # direct embedding path only. Only layers 0-1 are plotted below.
    linear_probe = get_probe(layer, "linear", "post")
    for head in range(8):
        positions = list(range(1, 61))
        # linear_probe : Float[Tensor, "modes d_model rows cols options"]= t.load(f"probes/linear/resid_{layer}_linear.pth").to(device).detach()
        # NOTE: OV.AB materialises the dense product for all layers/heads before indexing,
        # so it is recomputed on every head iteration (unless FactoredMatrix caches it).
        OV_layer_head = OV.AB[layer, head].detach()
        # mine_vs_yours = t.zeros((8, 8))
        for pos_int in positions:
            # Token id (1..60) -> "string" board index (0..63) -> (row, col).
            pos_str = to_string(pos_int)
            tile = string_to_tile(pos_str)
            # `*tile` inside a subscript requires Python 3.11+. pos_empty is currently unused.
            pos_empty = linear_probe[0, :, *tile, EMPTY]
            pos_mine = linear_probe[0, :, *tile, MINE]
            pos_yours = linear_probe[0, :, *tile, YOURS]
            pos_emb = W_E[pos_int, :].detach()
            result = einops.einsum(pos_emb, OV_layer_head, "d_model_in, d_model_in d_model_out -> d_model_out")
            # mine_vs_yours[head, layer, *tile] = t.dot(result, pos_mine) - t.dot(result, pos_yours)
            # Positive = the head pushes this tile towards YOURS, negative = towards MINE.
            mine_vs_yours[head, layer, *tile] = cosine_similiarity(result, pos_yours - pos_mine)

# Plot only layers 0-1 (second axis of the [head, layer, ...] tensor).
plot_boards_general(x_labels=[f"Head {i}" for i in range(8)], y_labels=[f"Layer {i}" for i in range(2)], boards=mine_vs_yours[:, :2], title_text = "How much does each Head add to the Yours vs Mine direction, when a Tile is played", save=True)
# fig.show()


- Next: last-flipped circuit
    - Are there heads with high attention to the flip direction?
        - I don't know whether this makes sense, because there are flips on every move.
    - I don't know how this would be implemented...
        - Maybe all flips write, but the most recent one is the strongest...
    - For the proof, I need to distinguish the heads attending to MINE from the heads attending to YOURS
        - Then show that FLIPPED @ OV_mine = MINE and FLIPPED @ OV_yours = YOURS

In [7]:
# Load the hand-labelled head categories (mine/yours/first/last/...).
# NOTE: unpickling can execute arbitrary code; only load heads_dict.pkl from a trusted source.
# The context manager closes the file handle, which the previous one-liner leaked.
with open("heads_dict.pkl", "rb") as heads_file:
    heads_dict = pickle.load(heads_file)
print(heads_dict)

{'mine_far': [(0, 3), (0, 6), (1, 0)], 'yours_far': [(0, 0), (0, 7)], 'mine_close': [(1, 0), (1, 4), (1, 5), (1, 6), (2, 0), (2, 6), (2, 7), (3, 0), (4, 4), (4, 6)], 'yours_close': [(0, 2), (1, 2), (1, 3), (2, 1), (2, 3), (2, 4), (2, 5), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (4, 0), (4, 1), (4, 2), (4, 3), (4, 7), (5, 0), (5, 3), (5, 4), (5, 5), (5, 7), (6, 2), (6, 5)], 'mine_last': [(1, 5), (3, 0)], 'yours_last': [(0, 2), (2, 3), (4, 5), (5, 7)], 'first': [(1, 1), (1, 7), (3, 2), (4, 0), (4, 3), (5, 0), (5, 3), (5, 4), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (7, 1), (7, 4), (7, 5), (7, 6)], 'stripe': [(0, 1), (0, 4), (0, 5)]}


In [21]:
def plot_OV_circuits(input_direction_1 : Tuple[str, str], input_direction_2 : Tuple[str, str], output_direction_1 : Tuple[str, str], output_direction_2 : Tuple[str, str]):
    """Score, per head and per tile, how an OV circuit maps one probe contrast onto another.

    Each direction is a ``(probe_name, direction_name)`` pair. ``probe_name`` is passed to
    ``get_probe`` and ``direction_name`` to ``get_direction_int``, which selects one option
    of the probe (mode 0 is used).

    For every layer L >= 1 and head H:
      * input contrast  = probe(input_direction_1) - probe(input_direction_2), taken from the
        "post" probe of layer L - 1 (the residual stream that layer L reads);
      * output contrast = probe(output_direction_1) - probe(output_direction_2), taken from
        the "post" probe of layer L (where the head's output is written);
      * score[H, L, row, col] = cosine(input_contrast[:, row, col] @ OV[L, H],
                                       output_contrast[:, row, col]).

    Despite the name, nothing is plotted here; the scores are returned.

    Returns:
        Tensor of shape (n_heads, n_layers, 8, 8), indexed ``[head, layer, row, col]``.
        Layer 0 stays zero because there is no preceding layer to read an input probe from.
    """
    def _probe_direction(layer: int, direction: Tuple[str, str]) -> Tensor:
        # Mode 0 of the named "post" probe at `layer`, restricted to one option:
        # shape (d_model, rows, cols).
        probe = get_probe(layer, direction[0], "post")
        return probe[0, :, :, :, get_direction_int(direction[1])]

    # Materialise the dense OV matrices once; OV.AB builds the full
    # (n_layers, n_heads, d_model, d_model) product when accessed.
    OV_dense = OV.AB.detach()
    n_layers, n_heads = OV_dense.shape[:2]
    direction1_vs_direction2 : Float[Tensor, "head layer rows cols"] = t.zeros((n_heads, n_layers, 8, 8))
    for layer in range(1, n_layers):
        # Bug fix: the input ("previous layer") probes were read from `layer` itself, but the
        # OV circuit of `layer` reads the residual stream produced by `layer - 1`.
        input_contrast = _probe_direction(layer - 1, input_direction_1) - _probe_direction(layer - 1, input_direction_2)
        output_contrast = _probe_direction(layer, output_direction_1) - _probe_direction(layer, output_direction_2)
        for head in range(n_heads):
            # Push every tile's input direction through this head in one einsum.
            result = einops.einsum(
                input_contrast,
                OV_dense[layer, head],
                "d_model_in rows cols, d_model_in d_model_out -> d_model_out rows cols",
            )
            # Per-tile cosine similarity along d_model (same formula as cosine_similiarity).
            scores = (result * output_contrast).sum(dim=0) / (result.norm(dim=0) * output_contrast.norm(dim=0))
            direction1_vs_direction2[head, layer] = scores.detach()
    return direction1_vs_direction2


# Input contrast: "flipped" minus "yours"; output contrast: "yours" minus "mine".
input_direction1 = ("flipped", "flipped")
input_direction2= ("linear", "yours")
direction1 = ("linear", "yours")
direction2 = ("linear", "mine")
direction1_vs_direction2_flipped = plot_OV_circuits(input_direction1, input_direction2, direction1, direction2)
# Disabled earlier variant, kept as an unused string literal. Its 3-argument
# plot_OV_circuits call no longer matches the current 4-argument signature.
'''input_directions = {
    "linear": ["yours"]
}
direction1 = ("linear", "mine")
direction2 = ("linear", "yours")
direction1_vs_direction2_yours = plot_OV_circuits(input_directions, direction1, direction2)

direction1_vs_direction2_flipped = direction1_vs_direction2_flipped[:, 1:3]
direction1_vs_direction2_yours = direction1_vs_direction2_yours[:, 1:3]
direction1_vs_direction2 = t.cat((direction1_vs_direction2_flipped, direction1_vs_direction2_yours), dim=1)'''

# Scores are indexed [head, layer, ...]; plot layers 1-3 (layer 0 is always zero).
plot_boards_general(
        x_labels=[f"Head {i}" for i in range(8)],
        y_labels=[f"Layer {i}" for i in range(1, 4)],
        boards=direction1_vs_direction2_flipped[:, 1:4],
        title_text = "How much does each Head add to the Yours vs Mine direction, when Flipped AND Not Yours",
        save=True,
    )

# Old

In [None]:
# Goal: split the heads into MINE-attending and YOURS-attending ones, and maybe by how far back
# they look. There were also first-token and last-token heads.
# Attention to: itself, last, mine, yours, first, how far back (how do I measure this, and what
# function describes how it decreases?)
# First do the mine / yours distinction, then look at that function.
def get_mine_their_attention_score(QK : Float[Tensor, "d_model_q d_model_k"]):
    """Average bilinear score ``W_E[earlier] @ QK @ W_E[pos]`` over same-parity earlier indices.

    For each pos in 0..59, the indices pos-2, pos-4, ... (stopping before 0) are paired
    with pos. The mean score over all such pairs is returned.

    NOTE(review): rows of ``model.W_E`` are token embeddings but are indexed here by
    sequence position, and the earlier ("mine") embedding sits on the query side of QK.
    Confirm both are intended.
    """
    scores = [
        einops.einsum(
            model.W_E[earlier, :],
            QK,
            model.W_E[pos, :],
            "d_model_q, d_model_q d_model_k, d_model_k -> ",
        ).item()
        for pos in range(60)
        for earlier in range(pos - 2, 0, -2)
    ]
    return sum(scores) / len(scores)

# Dense QK matrices for every (layer, head), computed once rather than on every iteration.
# Bug fix: the attention score needs the query-key circuit (model.QK); this loop used to
# read model.OV, so `QK` actually held an OV matrix.
all_QK = model.QK.AB
for layer in range(8):
    for head in range(8):
        QK = all_QK[layer, head]
        # print(f"L{layer}H{head}: {get_mine_their_attention_score(QK)}")

# With the per-head print disabled, this scores only the last head visited (layer 7, head 7).
get_mine_their_attention_score(QK)