In [1]:
import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
sys.path.append(chapter_dir + f"{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
# from generate_patches import generate_patch
from pprint import pprint
from utils import plot_game
from training_utils import get_state_stack_num_flipped
from utils import plot_probe_outputs
from utils import seq_to_state_stack
from utils import VisualzeBoardArguments
from utils import visualize_game

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import plot_boards_general
import numpy as np
import pickle

# import part6_othellogpt.tests as tests

t.manual_seed(42)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (ie the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

num_games, length_of_game = board_seqs_int.shape

# Define possible indices (excluding the four center squares)
stoi_indices = [i for i in range(64) if i not in [27, 28, 35, 36]]

# Define our rows, and the function that converts an index into a (row, column) label, e.g. `E2`
alpha = "ABCDEFGH"

def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"

# Get our list of board labels
board_labels = list(map(to_board_label, stoi_indices))
full_board_labels = list(map(to_board_label, range(64)))

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    kwargs = {
        "y": [i for i in alpha],
        "x": [str(i) for i in range(8)],
        "color_continuous_scale": "RdBu" if diverging_scale else "Blues",
        "color_continuous_midpoint": 0. if diverging_scale else None,
        "aspect": "equal",
        **kwargs
    }
    imshow(state, **kwargs)

start = 0
num_games = 50
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    out = t.zeros((num_classes,), dtype=t.float32)
    out[list_of_ints] = 1.
    return out

focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

rows = 8
cols = 8
options = 3

black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

from utils import *

In [2]:
def get_estimated_attention_pattern(num_games):
    estimated_attention_pattern : Float[Tensor, "layer head pos_from pos_to"] = t.zeros((8, 8, 59, 59)).to(device)
    estimated_attention_pattern_variance = t.zeros((8, 8)).to(device)
    for layer in range(8):
        _, cache = model.run_with_cache(
            board_seqs_int[:num_games, :-1].to(device),
            return_type=None,
            names_filter = lambda name : name in [utils.get_act_name("pattern", layer)]
        )
        attention_pattern = cache["pattern", layer]
        estimated_attention_pattern[layer] = attention_pattern.mean(dim=0)
        estimated_attention_pattern_variance[layer] = attention_pattern.var(dim=0).mean()
    estimated_attention_pattern_variance = estimated_attention_pattern_variance.mean()
    print(estimated_attention_pattern_variance)
    return estimated_attention_pattern

def get_avg_resid(layer, num_games):
    _, cache = model.run_with_cache(
        board_seqs_int[:num_games, :-1].to(device),
        return_type=None,
        names_filter = lambda name : name in [f"blocks.{layer}.ln1.hook_normalized"]
    )
    return cache[f"blocks.{layer}.ln1.hook_normalized"].mean(dim=0)

In [3]:
'''# Wenn ich so an allen stellen etwas ändere muss ich wirklich schon beim coden alles testen, so viel es geht!!!
layer = 1
tile_tuple = (4, 4)
probe_flipped_b = get_probe(layer-1, "flipped", "post")[0, :, *tile_tuple, FLIPPED].detach()
yours_probe_b = get_probe(layer, "linear", "mid")[0, :, *tile_tuple, YOURS].detach()
mine_probe_b = get_probe(layer, "linear", "mid")[0, :, *tile_tuple, MINE].detach()
probe_flipped_normalized_b = probe_flipped_b / probe_flipped_b.norm()
flipped_after_OV_b = einops.einsum(probe_flipped_normalized_b, OV.AB[layer, :], "d_model_in, head_idx d_model_in d_model_out -> head_idx d_model_out")
conversion_factors_yours_b = einops.einsum(flipped_after_OV_b, yours_probe_b, "head_idx d_model, d_model -> head_idx")
conversion_factors_mine_b = einops.einsum(flipped_after_OV_b, mine_probe_b, "head_idx d_model, d_model -> head_idx")

avg_resids_b = focus_cache[f"blocks.{layer}.ln1.hook_normalized"].mean(dim=0)
projection_b = einops.repeat(avg_resids_b @ probe_flipped_b, "pos -> pos d_model", d_model=512) / (probe_flipped_b @ probe_flipped_b) * probe_flipped_b
avg_resids_without_flipped_b = avg_resids_b - projection_b
avg_resids_after_OV_b = einops.einsum(avg_resids_without_flipped_b, OV.AB[layer, :], "pos d_model_in, head_idx d_model_in d_model_out -> head_idx pos d_model_out")
avg_resids_yours_bias_b = einops.einsum(avg_resids_after_OV_b, yours_probe_b, "head_idx pos d_model, d_model -> head_idx pos")
avg_resids_mine_bias_b = einops.einsum(avg_resids_after_OV_b, mine_probe_b, "head_idx pos d_model, d_model -> head_idx pos")'''

# All the Things needed later for all layers head row col
# Bei den layout von dimensionen sollte man immer darauf achten, je wahrscheinlicher es ist, dass eine dimension für mul oder add benutz wird, desto weiter nach rechts sollte sie sein
# TODO: Does it make more sense if this is normalized? NO Because we want the Logit and thats not normalized! (DONE)
num_games = 200

OV = model.OV
flipped_probe : Float[Tensor, "d_model layer row col"] = t.Tensor(size=(512, 8, 8, 8)).to(device)
flipped_probe_normalized : Float[Tensor, "d_model layer row col"] = t.Tensor(size=(512, 8, 8, 8)).to(device)
yours_probe : Float[Tensor, "d_model layer row col"] = t.Tensor(size=(512, 8, 8, 8)).to(device)
mine_probe : Float[Tensor, "d_model layer row col"] = t.Tensor(size=(512, 8, 8, 8)).to(device)
empty_probe : Float[Tensor, "d_model layer row col"] = t.Tensor(size=(512, 8, 8, 8)).to(device)
conversion_factors_mine : Float[Tensor, "layer head_idx row col"] = t.Tensor(size=(8, 8, 8, 8)).to(device)
conversion_factors_yours : Float[Tensor, "layer head_idx row col"] = t.Tensor(size=(8, 8, 8, 8)).to(device) 
avg_resids_mine_bias : Float[Tensor, "layer head_idx pos_to row col"] = t.Tensor(size=(8, 8, 59, 8, 8)).to(device)
avg_resids_yours_bias : Float[Tensor, "layer head_idx pos_to row col"] = t.Tensor(size=(8, 8, 59, 8, 8)).to(device)
avg_resids_without_flipped : Float[Tensor, "layer d_model pos_to row col"] = t.Tensor(size=(8, 512, 59, 8, 8)).to(device)

module = "post" # Eigentlich "mid" ...

for layer in range(1, 8):
    flipped_probe_s = get_probe(layer-1, "flipped", "post")[0, :, :, :, FLIPPED].detach()
    flipped_probe[:, layer, :, :] = flipped_probe_s
    flipped_probe_normalized = flipped_probe / flipped_probe.norm(dim=0)
    yours_probe_s = get_probe(layer, "linear", module)[0, :, :, :, YOURS].detach()
    yours_probe[:, layer, :, :] = yours_probe_s
    mine_probe_s = get_probe(layer, "linear", module)[0, :, :, :, MINE].detach()
    mine_probe[:, layer, :, :] = mine_probe_s
    empty_probe_s = get_probe(layer, "linear", module)[0, :, :, :, EMPTY].detach()
    empty_probe[:, layer, :, :] = empty_probe_s

    # avg_resids = focus_cache[f"blocks.{layer}.ln1.hook_normalized"].mean(dim=0)
    avg_resids = get_avg_resid(layer, 200)
    projection = einops.einsum(avg_resids, flipped_probe_s, "p d, d r c -> p r c") / einops.einsum(flipped_probe_s, flipped_probe_s, "d r c, d r c -> r c") * einops.repeat(flipped_probe_s, "d r c -> d p r c", p=59)
    avg_resids_without_flipped_s = einops.repeat(avg_resids, "pos_to d_model -> d_model pos_to row col", row = 8, col = 8) - projection
    avg_resids_without_flipped[layer] = avg_resids_without_flipped_s
    avg_resids_after_OV = einops.einsum(avg_resids_without_flipped_s, OV.AB[layer, :], "d_model_in pos_to row col, head_idx d_model_in d_model_out -> d_model_out head_idx pos_to row col")
    avg_resids_yours_bias[layer] = einops.einsum(avg_resids_after_OV, yours_probe_s, "d_model head_idx pos_to row col, d_model row col -> head_idx pos_to row col")
    avg_resids_mine_bias[layer] = einops.einsum(avg_resids_after_OV, mine_probe_s, "d_model head_idx pos_to row col, d_model row col -> head_idx pos_to row col")

flipped_after_OV = einops.einsum(flipped_probe_normalized, OV.AB, "d_model_in layer row col, layer head_idx d_model_in d_model_out -> d_model_out layer head_idx row col")
conversion_factors_yours = einops.einsum(flipped_after_OV, yours_probe, "d_model layer head_idx row col, d_model layer row col -> layer head_idx row col")
conversion_factors_mine = einops.einsum(flipped_after_OV, mine_probe, "d_model layer head_idx row col, d_model layer row col -> layer head_idx row col")

#yours_after_OV = einops.einsum(yours_probe_normalized, OV.AB, "d_model_in layer row col, layer head_idx d_model_in d_model_out -> d_model_out layer head_idx row col")
#conversion_factors_yours = einops.einsum(flipped_after_OV, yours_probe, "d_model layer head_idx row col, d_model layer row col -> layer head_idx row col")
#conversion_factors_mine = einops.einsum(flipped_after_OV, mine_probe, "d_model layer head_idx row col, d_model layer row col -> layer head_idx row col")


# TODO: Make this work (DONE)
def get_probe_dir2(resid : Float[Tensor, "batch pos d_model"], layer : int, row, col):
    flipped_probe_normalized_small = flipped_probe_normalized[:, layer, row, col]
    avg_resids_without_flipped_small = avg_resids_without_flipped[layer, :, :, row, col]
    flipped_in_resid : Float[Tensor, "batch pos"] = resid @ flipped_probe_normalized_small
    dir = einops.repeat(avg_resids_without_flipped_small, "d_model pos -> batch pos d_model", batch=200) + einops.repeat(flipped_in_resid, "batch pos -> batch pos d_model", d_model=512) * flipped_probe_normalized_small
    return dir

estimated_attention_pattern = get_estimated_attention_pattern(200)

tensor(0.0007, device='cuda:0')


In [4]:
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict


In [5]:
DEBUG = False

def add_bias(yours_logits_pred, mine_logits_pred, layer, tile_tuple):
    yours_probe_s = yours_probe[:, layer, *tile_tuple]
    mine_probe_s = mine_probe[:, layer, *tile_tuple]
    bias = model.b_O[layer]
    yours_logit_bias = einops.einsum(bias, yours_probe_s, "d_model, d_model -> ")
    yours_logits_pred += yours_logit_bias
    mine_logit_bias = einops.einsum(bias, mine_probe_s, "d_model, d_model -> ")
    mine_logits_pred += mine_logit_bias
    return yours_logits_pred, mine_logits_pred

def get_attn_pattern(layer, use_attn_pattern_approx, batch_size, cache = None) -> Float[Tensor, "head pos_from pos_to"]:
    assert cache is not None or use_attn_pattern_approx
    if use_attn_pattern_approx:
        attention_pattern = einops.repeat(estimated_attention_pattern[layer, :, :], "head pos_from pos_to -> batch head pos_from pos_to", batch=batch_size)
    else:
        attention_pattern : Float[Tensor, "head_idx pos_from pos_to"] = cache["pattern", layer][:, :, :]
    return attention_pattern

def get_yours_and_mine_pred_old(
        resid_real : Float[Tensor, "batch pos d_model"],
        layer : int, 
        tile_tuple : Tuple[int, int],
        cache,
        use_attn_pattern_approx : bool = True
    ):
    yours_logits_pred = t.zeros((200, 59)).to(device)
    mine_logits_pred = t.zeros((200, 59)).to(device)
    for head in range(8):
        yours_probe_s = yours_probe[:, layer, *tile_tuple]
        mine_probe_s = mine_probe[:, layer, *tile_tuple]
        resid = get_probe_dir2(resid_real, layer, *tile_tuple) # TODO: this could be wrong i guess
        # resid = resid_real
        # print(f"resid shape: {resid.shape}")
        # test_probe_dir(resid_real, resid)
        head_V = model.W_V[layer, head]
        head_v = einops.einsum(resid, head_V, "batch pos d_model, d_model d_head -> batch pos d_head")
        attention_pattern = get_attn_pattern(layer, use_attn_pattern_approx, cache)[:, head]
        # attention_pattern : Float[Tensor, "pos"] = focus_cache["pattern", layer][:, head, pos_from]
        z = einops.repeat(head_v, "batch pos_to d_head -> batch pos_from pos_to d_head", pos_from = 59) * einops.repeat(attention_pattern, "batch pos_from pos_to -> batch pos_from pos_to d_head", d_head=64)
        z = einops.reduce(z, "batch pos_from pod_to d_head -> batch pos_from d_head", "sum")
        # z = focus_cache["z", layer][:, pos_from, head] # TODO: Remove
        result = einops.einsum(z, model.W_O[layer, head], "batch pos_from d_head, d_head d_model -> batch pos_from d_model")
        yours_logit_head = einops.einsum(result, yours_probe_s, "batch pos_from d_model, d_model -> batch pos_from")
        yours_logits_pred += yours_logit_head
        mine_logit_head = einops.einsum(result, mine_probe_s, "batch pos_from d_model, d_model -> batch pos_from")
        mine_logits_pred += mine_logit_head
        # attn_out_fake += result
    yours_logits_pred, mine_logits_pred = add_bias(yours_logits_pred, mine_logits_pred, layer, tile_tuple)
    return yours_logits_pred, mine_logits_pred
        

def get_yours_and_mine_pred_math2(
        resid_real : Float[Tensor, "batch pos d_model"],
        layer : int,
        tile_tuple : tuple[int, int],
        cache,
        use_attn_pattern_approx=True,
    ):
    # flipped_probe_normalized[layer] acutally means the flipped probe of the previous layer
    flipped_probe_normalized_s : Float[Tensor, "d_model"] = flipped_probe_normalized[:, layer, *tile_tuple]
    conversion_factors_mine_s : Float[Tensor, "head_idx"] = conversion_factors_mine[layer, :, *tile_tuple]
    conversion_factors_yours_s : Float[Tensor, "head_idx"] = conversion_factors_yours[layer, :, *tile_tuple]
    avg_resids_mine_bias_s : Float[Tensor, "head_idx pos_to"] = avg_resids_mine_bias[layer, :, :, *tile_tuple]
    avg_resids_yours_bias_s : Float[Tensor, "head_idx pos_to"] = avg_resids_yours_bias[layer, :, :, *tile_tuple]
    # "layer head_idx pos_to row col"
    # TODO: Use Avg Resid, Dont Remove negative Flipped Logits
    attention_pattern : Float[Tensor, "head pos_from pos_to"] = get_attn_pattern(layer, use_attn_pattern_approx, cache)
    flipped_logit = einops.einsum(resid_real, flipped_probe_normalized_s, "batch pos_to d_model, d_model -> batch pos_to")
    # not_flipped_logit = einops.einsum(resid_real, probe_not_flipped_normalized, "batch pos d_model, d_model -> batch pos")
    # Negative Logits are doing a lot work ..
    # flipped_logit = flipped_logit * ((flipped_logit > not_flipped_logit) & (flipped_logit > 0)).to(device)
    # flipped_logit = t.max(flipped_logit, t.zeros_like(flipped_logit).to(device))
    flipped_logit = einops.repeat(flipped_logit, "batch pos_to -> batch head_idx pos_to", head_idx=8)
    yours_logits_pred = einops.repeat(flipped_logit * einops.repeat(conversion_factors_yours_s, "head_idx -> head_idx pos_to", pos_to=59) + avg_resids_yours_bias_s, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
    yours_logits_pred = einops.reduce(yours_logits_pred, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
    mine_logits_pred = einops.repeat(flipped_logit * einops.repeat(conversion_factors_mine_s, "head_idx -> head_idx pos_to", pos_to=59) + avg_resids_mine_bias_s, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
    mine_logits_pred = einops.reduce(mine_logits_pred, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
    yours_logits_pred, mine_logits_pred = add_bias(yours_logits_pred, mine_logits_pred, layer, tile_tuple)
    return yours_logits_pred, mine_logits_pred

def get_logits_real(layer, tile_tuple, cache=None, attn_out=None):
    if attn_out is None:
        attn_out = cache["attn_out", layer]
    # attn_out = cache["attn_out", layer]
    yours_probe_s = yours_probe[:, layer, *tile_tuple]
    mine_probe_s = mine_probe[:, layer, *tile_tuple]
    empty_probe_s = empty_probe[:, layer, *tile_tuple]
    yours_logits = einops.einsum(attn_out, yours_probe_s, "batch pos_from d_model, d_model -> batch pos_from")
    mine_logits = einops.einsum(attn_out, mine_probe_s, "batch pos_from d_model, d_model -> batch pos_from")
    empty_logits = einops.einsum(attn_out, empty_probe_s, "batch pos_from d_model, d_model -> batch pos_from")
    return yours_logits, mine_logits, empty_logits

def get_mind_change_mask(cache, layer, tile_tuple):
    resid_real = cache[f"blocks.{layer}.ln1.hook_normalized"]
    resid_mid = cache[f"blocks.{layer}.hook_resid_mid"]
    yours_probe_mid_layer = yours_probe[:, layer, *tile_tuple]
    mine_probe_mid_layer = mine_probe[:, layer, *tile_tuple]
    yours_probe_prev_layer = get_probe(layer-1, "linear", "post")[0, :, *tile_tuple, YOURS].detach()
    mine_probe_prev_layer = get_probe(layer-1, "linear", "post")[0, :, *tile_tuple, MINE].detach()
    yours_logits_prev_layer = einops.einsum(resid_real, yours_probe_prev_layer, "batch pos d_model, d_model -> batch pos")
    mine_logits_prev_layer = einops.einsum(resid_real, mine_probe_prev_layer, "batch pos d_model, d_model -> batch pos")
    yours_logits_mid_layer = einops.einsum(resid_mid, yours_probe_mid_layer, "batch pos d_model, d_model -> batch pos")
    mine_logits_mid_layer = einops.einsum(resid_mid, mine_probe_mid_layer, "batch pos d_model, d_model -> batch pos")
    mask = t.ones(size=(200, 59)).to(device)
    mask[(yours_logits_mid_layer < mine_logits_mid_layer) & (yours_logits_prev_layer < mine_logits_prev_layer)] = 0
    mask[(mine_logits_mid_layer < yours_logits_mid_layer) & (mine_logits_prev_layer < yours_logits_prev_layer)] = 0
    return mask.to(dtype=t.int)

# Input: resid_real, Outpu: logits pred and real
def get_yours_and_mine_pred(cache, layer, tile_label, use_attn_pattern_approx, func_to_evaluate, only_mind_changes=False):
    # TODO: Output Cool Ass Dataframe
    # resid_real is correct. I thought it should be layer -1 but NO!
    resid_real = cache[f"blocks.{layer}.ln1.hook_normalized"]
    tile_tuple = label_to_tuple(tile_label)
    yours_logits_pred, mine_logits_pred = func_to_evaluate(resid_real, layer, tile_tuple, cache, use_attn_pattern_approx)
    yours_logits, mine_logits, empty_logits = get_logits_real(cache, layer, tile_tuple)
    all_logits = t.stack([empty_logits, mine_logits, yours_logits], dim=-1)
    mask = all_logits.argmax(dim=-1) != 0
    logits_diff = yours_logits - mine_logits
    logits_pred_diff = yours_logits_pred - mine_logits_pred
    # TOOD: Evaluate only on not empty tiles
    # correct = (logits_pred_diff > 0) == (logits_diff > 0)
    if only_mind_changes:
        mind_change_mask = get_mind_change_mask(cache, layer, tile_tuple)
        mask = mask * mind_change_mask
    # correct = correct * mask
    # return correct.float().sum(dim=0) / mask.sum(dim=0)

    # remove everything where mask is 0
    # logits_diff = logits_diff[mask]
    # logits_pred_diff = logits_pred_diff[mask]
    return logits_diff, logits_pred_diff, mask

def get_scores(logit_diff, logit_diff_preds, mask):
    # calculate tp, tn, fp, fn
    if DEBUG:
        mask = t.ones_like(mask).to(device)
    tp = einops.reduce((logit_diff > 0) & (logit_diff_preds > 0) & mask, "batch pos -> pos", "sum")
    tn = einops.reduce((logit_diff < 0) & (logit_diff_preds < 0) & mask, "batch pos -> pos", "sum")
    fp = einops.reduce((logit_diff < 0) & (logit_diff_preds > 0) & mask, "batch pos -> pos", "sum")
    fn = einops.reduce((logit_diff > 0) & (logit_diff_preds < 0) & mask, "batch pos -> pos", "sum")
    return tp, tn, fp, fn

def get_yours_and_mine_pred_results(num_batches, batch_size, use_attn_pattern_approx, func_to_evaluate, only_mind_changes=False, start=200):
    # TODO: Seperate the batches used for attention approximation and the rest
    results = {
        "TP" : t.zeros((8, 59, 8, 8)).to(device),
        "TN" : t.zeros((8, 59, 8, 8)).to(device),
        "FP" : t.zeros((8, 59, 8, 8)).to(device),
        "FN" : t.zeros((8, 59, 8, 8)).to(device),
    }

    for layer in range(1, 8):
        for batch in range(num_batches):
            indeces = t.arange(start + batch * batch_size, start + (batch + 1) * batch_size).to(dtype=t.int)
            _, cache = model.run_with_cache(
                board_seqs_int[indeces, :-1].to(device),
                return_type=None,
                names_filter=lambda name: name in [f"blocks.{layer}.ln1.hook_normalized", f"blocks.{layer}.hook_resid_mid", utils.get_act_name("attn_out", layer), f"blocks.{layer}.attn.hook_pattern"]
            )
            for row in range(8):
                for col in range(8):
                    if DEBUG:
                        row, col = 3, 3
                    tile_label = tuple_to_label((row, col))
                    # get logits pred, real
                    logits_diff, logits_diff_pred, mask = get_yours_and_mine_pred(cache, layer, tile_label, use_attn_pattern_approx, func_to_evaluate, only_mind_changes)
                    tp, tn, fp, fn = get_scores(logits_diff, logits_diff_pred, mask)
                    if DEBUG:
                        print(tp[10], tn[10], fp[10], fn[10])
                    results["TP"][layer, :, row, col] += tp
                    results["TN"][layer, :, row, col] += tn
                    results["FP"][layer, :, row, col] += fp
                    results["FN"][layer, :, row, col] += fn
                    if DEBUG:
                        break
                if DEBUG:
                    break
        if DEBUG:
            break
    return results

EPSILON = 1e-6
def get_score_from_results(results : dict[str, Tensor], dimensions : list[str]):
    # TODO: Also do Weighted F1 (mhhh idk. I have to think about useful metrics here ...)
    assert all([dimension in ["layer", "pos", "row", "col"] for dimension in dimensions])
    scores = {}
    # compress results to the specified dimensions
    results_compressed = {}
    for key in results.keys():
        results_compressed[key] = einops.reduce(results[key], f"layer pos row col -> {' '.join(dimensions)}", "sum")
    # calculate scores
    tp = results_compressed["TP"]
    tn = results_compressed["TN"]
    fp = results_compressed["FP"]
    fn = results_compressed["FN"]
    scores["Accuracy"] = (tp + tn) / (tp + tn + fp + fn + EPSILON)
    scores["Precision"] = tp / (tp + fp + EPSILON)
    scores["Recall"] = tp / (tp + fn + EPSILON)
    scores["F1"] = 2 * (scores["Precision"] * scores["Recall"]) / (scores["Precision"] + scores["Recall"])
    return scores

def evaluate_yours_and_mine_pred(num_batches, batch_size, use_attn_pattern_approx, func_to_evaluate, only_mind_changes=False):
    results = get_yours_and_mine_pred_results(num_batches, batch_size, use_attn_pattern_approx, func_to_evaluate, only_mind_changes)
    scores = get_score_from_results(results, ["layer", "pos", "row", "col"])
    return scores

In [6]:
'''DEBUG = True
batches = 10
batch_size = 200
only_mind_changes = False
for use_attn_pattern_approx in [True, False]:
    if use_attn_pattern_approx:
        print("Using Attention Pattern Approximation")
        approx_str = "approx"
    else:
        print("Using Attention Pattern")
        approx_str = "real"
    print("Math")
    results_math = get_yours_and_mine_pred_results(batches, batch_size, use_attn_pattern_approx, get_yours_and_mine_pred_math2, only_mind_changes)
    scores_math = get_score_from_results(results_math, ["layer", "pos", "row", "col"])
    print(f"Math Acc: {scores_math['Accuracy'][1, 10, *label_to_tuple('D3')].item():.4f}")
    # with open(f"results_math_{approx_str}_tesing.pkl", "wb") as file:
    #     pickle.dump(results_math, file)
    print("Test")
    results_test = get_yours_and_mine_pred_results(batches, batch_size, use_attn_pattern_approx, get_yours_and_mine_pred_old, only_mind_changes)
    scores_test = get_score_from_results(results_test, ["layer", "pos", "row", "col"])
    print(f"Real Acc: {scores_test['Accuracy'][1, 10, *label_to_tuple('D3')].item():.4f}")
    # with open(f"results_test_{approx_str}_testing.pkl", "wb") as file:
    #     pickle.dump(results_test, file)'''

'DEBUG = True\nbatches = 10\nbatch_size = 200\nonly_mind_changes = False\nfor use_attn_pattern_approx in [True, False]:\n    if use_attn_pattern_approx:\n        print("Using Attention Pattern Approximation")\n        approx_str = "approx"\n    else:\n        print("Using Attention Pattern")\n        approx_str = "real"\n    print("Math")\n    results_math = get_yours_and_mine_pred_results(batches, batch_size, use_attn_pattern_approx, get_yours_and_mine_pred_math2, only_mind_changes)\n    scores_math = get_score_from_results(results_math, ["layer", "pos", "row", "col"])\n    print(f"Math Acc: {scores_math[\'Accuracy\'][1, 10, *label_to_tuple(\'D3\')].item():.4f}")\n    # with open(f"results_math_{approx_str}_tesing.pkl", "wb") as file:\n    #     pickle.dump(results_math, file)\n    print("Test")\n    results_test = get_yours_and_mine_pred_results(batches, batch_size, use_attn_pattern_approx, get_yours_and_mine_pred_old, only_mind_changes)\n    scores_test = get_score_from_results(resu

In [7]:
def orthogonalize_vectors(a, B, normalize=True):
    """Orthogonalizes vector a against a list of vectors B without in-place modification using PyTorch"""
    orthogonal_a = a.clone()  # Create a copy of a to avoid in-place modification
    B_prev = []
    for b in B:
        if not all([b @ b_prev < 1e-6 for b_prev in B_prev]):
            b = orthogonalize_vectors(b, B_prev)
        # Project orthogonal_a onto b
        projection = einops.repeat(einops.einsum(a, b, "... d_model, d_model -> ...") / t.dot(b, b), "... -> ... d_model", d_model = b.shape[0]) * b
        # Update orthogonal_a by subtracting the projection
        orthogonal_a = orthogonal_a - projection
        B_prev += [b]
    
    # Normalize the resulting vector orthogonal_a
    if normalize:
        orthogonal_a = orthogonal_a / t.norm(orthogonal_a)
    
    return orthogonal_a

In [59]:
tile_label = "D3"
layer = 1
tile_tuple = label_to_tuple(tile_label)
cache = focus_cache
pos = 10


# EXTRA
flipped_probe_normalized_s : Float[Tensor, "d_model"] = flipped_probe_normalized[:, layer, *tile_tuple]
yours_probe_s_prev = get_probe(layer-1, "linear", "post")[0, :, *tile_tuple, YOURS].detach()
mine_probe_s_prev = get_probe(layer-1, "linear", "post")[0, :, *tile_tuple, MINE].detach()
yours_probe_s_curr = get_probe(layer, "linear", "post")[0, :, *tile_tuple, YOURS].detach()
mine_probe_s_curr = get_probe(layer, "linear", "post")[0, :, *tile_tuple, MINE].detach()
# othogonalize yours to flipped and mine to flipped and yours
yours_probe_normalized_s_prev = orthogonalize_vectors(yours_probe_s_prev, [flipped_probe_normalized_s])
mine_probe_normalized_s_prev = orthogonalize_vectors(mine_probe_s_prev, [flipped_probe_normalized_s, yours_probe_normalized_s_prev])
# probes = [flipped_probe_normalized_s, yours_probe_s_prev, mine_probe_s_prev]
# probes = [flipped_probe_normalized_s]
# probes_new = []
# for probe in probes:
#     probe = orthogonalize_vectors(probe, probes_new, normalize=True)
#     probes_new += [probe]

mine_probe_normalized_s_prev = orthogonalize_vectors(mine_probe_s_prev, [flipped_probe_normalized_s])
yours_probe_normalized_s_prev = orthogonalize_vectors(yours_probe_s_prev, [flipped_probe_normalized_s, mine_probe_normalized_s_prev])
avg_resids = get_avg_resid(layer, 200)
# avg_resids = orthogonalize_vectors(avg_resids, probes_new, normalize=True)
avg_resids = orthogonalize_vectors(avg_resids, [flipped_probe_normalized_s, yours_probe_normalized_s_prev, mine_probe_normalized_s_prev], normalize=True)
avg_resids_after_OV = einops.einsum(avg_resids, OV.AB[layer, :], "pos d_model_in, head_idx d_model_in d_model_out -> d_model_out head_idx pos")
yours_logits_pred_head_pos_avg_resid = einops.einsum(avg_resids_after_OV, yours_probe_s_curr, "d_model head_idx pos, d_model -> head_idx pos")
mine_logits_pred_head_pos_avg_resid = einops.einsum(avg_resids_after_OV, mine_probe_s_curr, "d_model head_idx pos, d_model -> head_idx pos")

'''projection = einops.einsum(avg_resids, yours_probe_normalized_s, "pos d_model, d_model -> pos") / yours_probe_normalized_s
avg_resids_without_flipped_s = einops.repeat(avg_resids, "pos_to d_model -> d_model pos_to row col", row = 8, col = 8) - projection
avg_resids_without_flipped[layer] = avg_resids_without_flipped_s
avg_resids_after_OV = einops.einsum(avg_resids_without_flipped_s, OV.AB[layer, :], "d_model_in pos_to row col, head_idx d_model_in d_model_out -> d_model_out head_idx pos_to row col")
avg_resids_yours_bias[layer] = einops.einsum(avg_resids_after_OV, yours_probe_s, "d_model head_idx pos_to row col, d_model row col -> head_idx pos_to row col")
avg_resids_mine_bias[layer] = einops.einsum(avg_resids_after_OV, mine_probe_s, "d_model head_idx pos_to row col, d_model row col -> head_idx pos_to row col")

flipped_after_OV = einops.einsum(flipped_probe_normalized, OV.AB, "d_model_in layer row col, layer head_idx d_model_in d_model_out -> d_model_out layer head_idx row col")
conversion_factors_yours = einops.einsum(flipped_after_OV, yours_probe, "d_model layer head_idx row col, d_model layer row col -> layer head_idx row col")
conversion_factors_mine = einops.einsum(flipped_after_OV, mine_probe, "d_model layer head_idx row col, d_model layer row col -> layer head_idx row col")'''

resid_real = cache[f"blocks.{layer}.ln1.hook_normalized"]
batch_size = resid_real.shape[0]
tile_tuple = label_to_tuple(tile_label)

# FUNCTION CALL
conversion_factors_mine_s : Float[Tensor, "head_idx"] = conversion_factors_mine[layer, :, *tile_tuple]
conversion_factors_yours_s : Float[Tensor, "head_idx"] = conversion_factors_yours[layer, :, *tile_tuple]
avg_resids_mine_bias_s : Float[Tensor, "head_idx pos_to"] = avg_resids_mine_bias[layer, :, :, *tile_tuple]
avg_resids_yours_bias_s : Float[Tensor, "head_idx pos_to"] = avg_resids_yours_bias[layer, :, :, *tile_tuple]
# "layer head_idx pos_to row col"
# TODO: Use Avg Resid, Dont Remove negative Flipped Logits
use_attn_pattern_approx = False
attention_pattern : Float[Tensor, "head pos_from pos_to"] = get_attn_pattern(layer, use_attn_pattern_approx, 50, cache)
# flipped_logit = einops.einsum(resid_real, flipped_probe_normalized_s, "batch pos_to d_model, d_model -> batch pos_to")
# not_flipped_logit = einops.einsum(resid_real, probe_not_flipped_normalized, "batch pos d_model, d_model -> batch pos")
# Negative Logits are doing a lot work ..
# flipped_logit = flipped_logit * ((flipped_logit > not_flipped_logit) & (flipped_logit > 0)).to(device)
# flipped_logit = t.max(flipped_logit, t.zeros_like(flipped_logit).to(device))
'''flipped_logit = einops.repeat(flipped_logit, "batch pos_to -> batch head_idx pos_to", head_idx=8)
yours_logits_pred_head_pos = einops.repeat(flipped_logit * einops.repeat(conversion_factors_yours_s, "head_idx -> head_idx pos_to", pos_to=59) + avg_resids_yours_bias_s, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
yours_logits_pred = einops.reduce(yours_logits_pred_head_pos, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
mine_logits_pred_head_pos = einops.repeat(flipped_logit * einops.repeat(conversion_factors_mine_s, "head_idx -> head_idx pos_to", pos_to=59) + avg_resids_mine_bias_s, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
mine_logits_pred = einops.reduce(mine_logits_pred_head_pos, "batch head_idx pos_from pos_to -> batch pos_from", "sum")'''
yours_logits_pred_head_pos_sum = t.zeros((batch_size, 8, 59)).to(device)
mine_logits_pred_head_pos_sum = t.zeros((batch_size, 8, 59)).to(device)
# for probe in probes_new:
for probe in [flipped_probe_normalized_s, yours_probe_normalized_s_prev, mine_probe_normalized_s_prev]:
    logit = einops.einsum(resid_real, probe, "batch pos_to d_model, d_model -> batch pos_to")
    probe_scales = logit * einops.repeat(probe, "d_model -> d_model batch pos_to", batch=logit.shape[0], pos_to=logit.shape[1])
    probe_after_OV = einops.einsum(probe_scales, OV.AB[layer, :], "d_model batch pos_to, head_idx d_model d_model_out -> d_model_out batch head_idx pos_to")
    yours_logits_pred_head_pos = einops.einsum(probe_after_OV, yours_probe_s_curr, "d_model batch head_idx pos_to, d_model -> batch head_idx pos_to")
    mine_logits_pred_head_pos = einops.einsum(probe_after_OV, mine_probe_s_curr, "d_model batch head_idx pos_to, d_model -> batch head_idx pos_to")
    yours_logits_pred_head_pos_sum += yours_logits_pred_head_pos
    mine_logits_pred_head_pos_sum += mine_logits_pred_head_pos
yours_logits_pred_head_pos_sum += yours_logits_pred_head_pos_avg_resid
mine_logits_pred_head_pos_sum += mine_logits_pred_head_pos_avg_resid
yours_logits_pred_head_pos_sum = einops.repeat(yours_logits_pred_head_pos_sum, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
yours_logits_pred = einops.reduce(yours_logits_pred_head_pos_sum, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
mine_logits_pred_head_pos_sum = einops.repeat(mine_logits_pred_head_pos_sum, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
mine_logits_pred = einops.reduce(mine_logits_pred_head_pos_sum, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
yours_logits_pred, mine_logits_pred = add_bias(yours_logits_pred, mine_logits_pred, layer, tile_tuple)
yours_logits, mine_logits, empty_logits = get_logits_real(layer, tile_tuple, cache=cache)
all_logits = t.stack([empty_logits, mine_logits, yours_logits], dim=-1)
mask = all_logits.argmax(dim=-1) != 0
logits_diff = yours_logits - mine_logits
logits_pred_diff = yours_logits_pred - mine_logits_pred
# TOOD: Evaluate only on not empty tiles
correct = (logits_pred_diff > 0) == (logits_diff > 0)
# if only_mind_changes:
#     mind_change_mask = get_mind_change_mask(cache, layer, tile_tuple)
#     mask = mask * mind_change_mask
num_games = correct.shape[0]
for i in range(num_games):
    if mask[i, pos] == 0:
        continue
    if correct[i, pos] == 1:
        continue
    print(f"Game: {i}")

Game: 0
Game: 3
Game: 5
Game: 28
Game: 47


In [11]:
def get_activation(act_names, num_games, start=0):
    # TODO: If this takes to long or something, Make a filter step!
    act_name_results = {act_name : [] for act_name in act_names}
    inference_size = 1000
    for batch in range(start, start+num_games, inference_size):
        with t.inference_mode():
            _, cache = model.run_with_cache(
                board_seqs_int[batch:batch+inference_size, :-1].to(device),
                return_type=None,
                names_filter=lambda name: name in act_names
                # names_filter=lambda name: name == f"blocks.{layer}.hook_resid_mid" or name == f"blocks.{layer}.mlp.hook_post"
                # names_filter=lambda name: name == f"blocks.{layer}.hook_resid_pre" or name == f"blocks.{layer}.mlp.hook_post"
            )
        for act_name in act_names:
            act_name_results[act_name] += [cache[act_name]]
    for act_name in act_names:
        act_name_results[act_name] = t.cat(act_name_results[act_name], dim=0)
        act_name_results[act_name] = act_name_results[act_name].detach()[:num_games]
    return act_name_results

In [12]:
# This function is weird but okay
def get_probe2(direction_str : str, layer, tile : tuple):
    if direction_str == "mine":
        return mine_probe[:, layer, *tile]
    elif direction_str == "yours":
        return yours_probe[:, layer, *tile]
    elif direction_str == "flipped":
        return flipped_probe_normalized[:, layer+1, *tile]

def get_neuron_out_direction_scaled(pos_to, neuron, layer, game=None, no_mean=True):
    if game is not None:
        neruon_activation = focus_cache["mlp_post", layer-1][game, pos_to, neuron]
    else:
        if pos_to is not None:
            neruon_activations = focus_cache["mlp_post", layer-1][:, pos_to, neuron]
        else:
            neruon_activations = focus_cache["mlp_post", layer-1][:, :, neuron]
        neruon_activations_positive = neruon_activations[neruon_activations > 0]
        neruon_activation = neruon_activations_positive.mean()
    if no_mean:
        neruon_activation = t.Tensor([1]).to(device)
    direction = model.W_out[layer-1, neuron].detach()
    direction_scaled = neruon_activation * direction
    return direction_scaled, neruon_activation.item()

def get_logit_after_ov(layer, head, direction_scaled, direction_str, tile):
    flipped_dir = get_probe2("flipped", layer-1, tile)
    mine_dir = get_probe2("mine", layer-1, tile)
    yours_dir = get_probe2("yours", layer-1, tile)
    dir_next= get_probe2(direction_str, layer, tile)
    direction_scaled = orthogonalize_vectors(direction_scaled, [flipped_dir, yours_dir, mine_dir], normalize=False)
    direction_after_OV = einops.einsum(direction_scaled, OV.AB[layer, head], "d_model_in, d_model_in d_model_out -> d_model_out")
    logit = einops.einsum(direction_after_OV, dir_next, "d_model, d_model -> ")
    return logit.item()

def get_logit_before_ov(layer, head, direction_scaled, direction_str, tile):
    # flipped_dir = flipped_probe_normalized[:, layer, *tile]
    # mine_dir = mine_probe[:, layer-1, *tile]#
    dir = get_probe2(direction_str, layer-1, tile)
    # yours_dir = yours_probe[:, layer-1, *tile]
    logit = einops.einsum(direction_scaled, dir, "d_model, d_model -> ")
    return logit.item()

def get_activation_for_all_neurons(pos_to, layer, head, tile, game=None, get_logit_function=get_logit_after_ov, direction_str="yours"):
    activations_list = []
    logits_list = []
    for neuron in range(2048):
        direction_scaled, neuron_activation = get_neuron_out_direction_scaled(pos_to, neuron, layer, game)
        logit = get_logit_function(layer, head, direction_scaled, direction_str, tile)
        activations_list.append(neuron_activation)
        logits_list.append(logit)
    return activations_list, logits_list

In [13]:
def get_neuron_weights_as_probes(layer, tile_tuple, how_many=40):
    final_neurons = []
    direction_strs = ["yours", "mine", "flipped"]
    # direction_strs = ["yours", "mine"]
    for direction_str in direction_strs:
        logits_list_result = [0] * 2048
        for head in range(8):
            activations_list, logits_list = get_activation_for_all_neurons(None, layer, head, tile_tuple, game=None, get_logit_function=get_logit_before_ov, direction_str=direction_str)
            logits_list_result = [max(logits_list_result[i], logits_list[i]) for i in range(2048)]
        # get the top 40 neuron, ideces
        neuron_with_logits_list = zip(logits_list_result, range(2048))
        neuron_with_logits_list = sorted(neuron_with_logits_list, key=lambda x: x[0], reverse=True)
        top_neurons = [neuron for _, neuron in neuron_with_logits_list[:how_many]]
        final_neurons += top_neurons
    final_neurons = list(set(final_neurons))
    final_directions = [model.W_out[layer, neuron].detach() for neuron in final_neurons]
    return final_directions

In [14]:
from tqdm import tqdm

In [15]:
def get_probes_curr():
    yours_probe_s_curr = get_probe(layer, "linear", "post")[0, :, *tile_tuple, YOURS].detach()
    mine_probe_s_curr = get_probe(layer, "linear", "post")[0, :, *tile_tuple, MINE].detach()
    return yours_probe_s_curr, mine_probe_s_curr


def get_logits_pred_diff(probes, layer, num_games, resid_real, attention_pattern):
    yours_probe_s_curr, mine_probe_s_curr = get_probes_curr()
    avg_resids = get_avg_resid(layer, 200)
    probes += [avg_resids]

    # resid_real = cache[f"blocks.{layer}.ln1.hook_normalized"]
    # act_name = f"blocks.{layer}.ln1.hook_normalized"
    # resid_real = get_activation([act_name], num_games, cache_start)[act_name]
    new_probes = []
    for probe in probes:
        probe = orthogonalize_vectors(probe, new_probes)
        new_probes += [probe]
    probes = new_probes
    # attention pattern
    # TODO: don't use the old cache!
    # attention_pattern : Float[Tensor, "head pos_from pos_to"] = get_attn_pattern(layer, use_attn_pattern_approx, num_games, cache)
    yours_logits_pred_head_pos_sum = t.zeros((num_games, 8, 59)).to(device)
    mine_logits_pred_head_pos_sum = t.zeros((num_games, 8, 59)).to(device)
    for probe in probes:
        if probe.shape[0] == 59:
            logit = einops.einsum(resid_real, probe, "batch pos_to d_model, pos_to d_model -> batch pos_to")
            probe_scales = logit * einops.repeat(probe, "pos_to d_model -> d_model batch pos_to", batch=logit.shape[0])
        else:
            logit = einops.einsum(resid_real, probe, "batch pos_to d_model, d_model -> batch pos_to")
            probe_scales = logit * einops.repeat(probe, "d_model -> d_model batch pos_to", batch=logit.shape[0], pos_to=logit.shape[1])
        probe_after_OV = einops.einsum(probe_scales, OV.AB[layer, :], "d_model batch pos_to, head_idx d_model d_model_out -> d_model_out batch head_idx pos_to")
        yours_logits_pred_head_pos = einops.einsum(probe_after_OV, yours_probe_s_curr, "d_model batch head_idx pos_to, d_model -> batch head_idx pos_to")
        mine_logits_pred_head_pos = einops.einsum(probe_after_OV, mine_probe_s_curr, "d_model batch head_idx pos_to, d_model -> batch head_idx pos_to")
        yours_logits_pred_head_pos_sum += yours_logits_pred_head_pos
        mine_logits_pred_head_pos_sum += mine_logits_pred_head_pos
    yours_logits_pred_head_pos_sum = einops.repeat(yours_logits_pred_head_pos_sum, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
    yours_logits_pred = einops.reduce(yours_logits_pred_head_pos_sum, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
    mine_logits_pred_head_pos_sum = einops.repeat(mine_logits_pred_head_pos_sum, "batch head_idx pos_to -> batch head_idx pos_from pos_to", pos_from=59) * attention_pattern
    mine_logits_pred = einops.reduce(mine_logits_pred_head_pos_sum, "batch head_idx pos_from pos_to -> batch pos_from", "sum")
    yours_logits_pred, mine_logits_pred = add_bias(yours_logits_pred, mine_logits_pred, layer, tile_tuple)
    logits_pred_diff = yours_logits_pred - mine_logits_pred
    return logits_pred_diff, yours_logits_pred_head_pos_sum, mine_logits_pred_head_pos_sum

def get_results_yours_mine_pred(layer, tile_tuple, get_probes_function, cache_start=50, num_games = 200, use_attn_pattern_approx = True, inference_size=1000):
    mask_list = []
    correct_list = []
    for batch in tqdm(range(cache_start, cache_start + num_games, inference_size)):
        fake_cache = get_activation([utils.get_act_name("attn_out", layer), f"blocks.{layer}.ln1.hook_normalized", utils.get_act_name("pattern", layer)], num_games=inference_size, start=batch)
        resid_real = fake_cache[f"blocks.{layer}.ln1.hook_normalized"]
        attn_out = fake_cache[utils.get_act_name("attn_out", layer)]
        attention_pattern = fake_cache[utils.get_act_name("pattern", layer)]
        if use_attn_pattern_approx:
            attention_pattern = get_attn_pattern(layer, use_attn_pattern_approx, batch_size=inference_size)
        probes = get_probes_function(layer, tile_tuple)
        yours_logits, mine_logits, empty_logits = get_logits_real(layer, tile_tuple, attn_out=attn_out)
        all_logits = t.stack([empty_logits, mine_logits, yours_logits], dim=-1)
        mask = all_logits.argmax(dim=-1) != 0
        logits_diff = yours_logits - mine_logits
        logits_pred_diff, yours_logits_pred_head_pos, mine_logits_pred_head_pos = get_logits_pred_diff(probes, layer, inference_size, resid_real, attention_pattern)
        correct = (logits_pred_diff > 0) == (logits_diff > 0)
        mask_list.append(mask)
        correct_list.append(correct)
    # if only_mind_changes:
    #     mind_change_mask = get_mind_change_mask(cache, layer, tile_tuple)
    #     mask = mask * mind_change_mask
    mask = t.cat(mask_list, dim=0)
    correct = t.cat(correct_list, dim=0)
    return mask, correct, yours_logits_pred_head_pos, mine_logits_pred_head_pos


def get_mean_of_yours_mine_pred(mask, correct):
    mean_result = t.zeros((59, )).to(device)
    not_blank_count = t.zeros((59, )).to(device)
    correct_count = t.zeros((59, )).to(device)
    for i in range(59):
        mean_result[i] = correct[:, i][mask[:, i] == 1].float().mean()
        not_blank_count[i] = mask[:, i].float().sum()
        correct_count[i] = correct[:, i][mask[:, i] == 1].float().sum()
    return mean_result, not_blank_count, correct_count


def get_probes_flipped_yours_mine(layer, tile_tuple):
    flipped_probe_normalized_s : Float[Tensor, "d_model"] = flipped_probe_normalized[:, layer, *tile_tuple]
    yours_probe_s_prev = get_probe(layer-1, "linear", "post")[0, :, *tile_tuple, YOURS].detach()
    mine_probe_s_prev = get_probe(layer-1, "linear", "post")[0, :, *tile_tuple, MINE].detach()
    probes = [flipped_probe_normalized_s, yours_probe_s_prev, mine_probe_s_prev]
    return probes

def get_probes_flipped(layer, tile_tuple):
    flipped_probe_normalized_s : Float[Tensor, "d_model"] = flipped_probe_normalized[:, layer, *tile_tuple]
    probes = [flipped_probe_normalized_s]
    return probes


def all_probes(layer, tile_tuple):
    probes = []
    probes += get_probes_flipped_yours_mine(layer, tile_tuple)
    probes += get_neuron_weights_as_probes(layer, tile_tuple)
    return probes

In [16]:
tile_label = "E6"
tile_tuple = label_to_tuple(tile_label)
# cache = focus_cache
pos = 10
use_attn_pattern_approx = False
layer = 1
cache_start = 0
num_games = 1000
# TODO: Make the num_games actually variable, without getting cuda memory error
# TODO: Then make new results, vary attn_pattern_approx, layer, tile, get_probe function, I could for example only use flipped ...
mask, correct, yours_logits_pred_head_pos, mine_logits_pred_head_pos = get_results_yours_mine_pred(layer, tile_tuple, get_probes_flipped, cache_start=cache_start, num_games=num_games, use_attn_pattern_approx=use_attn_pattern_approx, inference_size=1000)
mean, not_blank_count, correct_count = get_mean_of_yours_mine_pred(mask, correct)
print(mean)

100%|██████████| 1/1 [00:03<00:00,  3.16s/it]


tensor([   nan,    nan,    nan,    nan, 1.0000, 0.8333, 0.9630, 0.9545, 0.9273,
        0.9605, 0.9310, 0.9515, 0.9430, 0.9268, 0.9317, 0.9108, 0.9063, 0.9391,
        0.9038, 0.9005, 0.9070, 0.9165, 0.9038, 0.9073, 0.8870, 0.8825, 0.8908,
        0.8776, 0.8911, 0.9054, 0.8676, 0.8844, 0.8655, 0.8688, 0.8645, 0.8534,
        0.8562, 0.8626, 0.8370, 0.8338, 0.8132, 0.8032, 0.8148, 0.8160, 0.8176,
        0.8017, 0.8080, 0.7972, 0.8034, 0.7935, 0.8033, 0.7959, 0.7935, 0.8148,
        0.7759, 0.7791, 0.7650, 0.7696, 0.7379], device='cuda:0')


In [42]:
if __name__ == "__main__":
    num_games = 10000
    script_name = sys.argv[0]
    if len(sys.argv) < 2:
        print(f"Usage: {script_name} <dataset>")
        sys.exit(1)
    layer = int(sys.argv[1])
    final_scores = t.zeros((8, 8, 2, 59))
    cache_start = 200
    for get_probes_function in [get_probes_flipped_yours_mine, get_probes_flipped, all_probes]:
        for row in range(8):
            for col in range(8):
                for attn_pattern_approx in [True, False]:
                    tile_tuple = (row, col)
                    tile_label = tuple_to_label(tile_tuple)
                    mask, correct, yours_logits_pred_head_pos, mine_logits_pred_head_pos = get_results_yours_mine_pred(
                        layer,
                        tile_tuple,
                        get_probes_function,
                        cache_start=cache_start,
                        num_games=num_games,
                        use_attn_pattern_approx=use_attn_pattern_approx,
                        inference_size=1000
                    )
                    mean, not_blank_count, correct_count = get_mean_of_yours_mine_pred(mask, correct)
                    final_scores[layer, row, col, int(attn_pattern_approx)] = mean
        # save final_scores
        with open(f"attn_approx_results/attn_approx_mean_L{layer}.pkl", "wb") as file:
            pickle.dump(final_scores, file)
        # save not_blank_count
        with open(f"attn_approx_results/attn_approx_not_blank_count_L{layer}.pkl", "wb") as file:
            pickle.dump(not_blank_count, file)
        # save correct_count
        with open(f"attn_approx_results/attn_approx_correct_count_L{layer}.pkl", "wb") as file:
            pickle.dump(correct_count, file)
    print("Done!")


ValueError: invalid literal for int() with base 10: '--f=/hpi/fs00/home/jim.maar/.local/share/jupyter/runtime/kernel-v3ecfdcdf610773ef69b1781c5cc25bd53d357d5fb.json'

In [None]:
'''
tensor(
[   nan,    nan,    nan, 1.0000, 0.8388, 0.8835, 0.9248, 0.9291, 0.9348,
0.9535, 0.9370, 0.9564, 0.9309, 0.9434, 0.9268, 0.9445, 0.9253, 0.9333,
0.9213, 0.9333, 0.9219, 0.9206, 0.9153, 0.9211, 0.9059, 0.9138, 0.9086,
0.9126, 0.8983, 0.9089, 0.8988, 0.9022, 0.8979, 0.9024, 0.8990, 0.8961,
0.8891, 0.8880, 0.8885, 0.8833, 0.8729, 0.8824, 0.8781, 0.8839, 0.8770,
0.8788, 0.8652, 0.8718, 0.8639, 0.8671, 0.8648, 0.8589, 0.8536, 0.8521,
0.8428, 0.8383, 0.8363, 0.8327, 0.8300], device='cuda:0')'''

In [None]:
# TODO: Convert this to python and add __main__ arguments and 
# TOOD: Add saving functionality
# Ich finds grade schade, dass das mit den ganzen Neuron Probes immer cuda error gibt ... ich würde gerne das comparen ...


In [16]:
for i in range(100):
    if mask[i, pos] == 0:
        continue
    if correct[i, pos] == 1:
        continue
    print(f"Game: {cache_start + i}")

NameError: name 'cache_start' is not defined

- When using L1 flipped: (5 weg, aber 20 und 31 dazu)
    - Game: 0
        Game: 3
        Game: 20
        Game: 28
        Game: 31
        Game: 47
- L0 Flipped
    - Game: 0
    Game: 3
    Game: 5
    Game: 28
    Game: 47

In [18]:
game = 0

In [60]:
# Stimmt nicht mit anderen überein, kommt weil ich für mathe dings den resid_mid probe benutzt habe ... (Vielleicht gut, das zu ändern)
result_dict = {
    "yours_logits" : yours_logits[game, pos].item(),
    "yours_logits_pred" : yours_logits_pred[game, pos].item(),
    "mine_logits" : mine_logits[game, pos].item(),
    "mine_logits_pred" : mine_logits_pred[game, pos].item(),
    "logits_diff" : logits_diff[game, pos].item(),
    "logits_pred_diff" : logits_pred_diff[game, pos].item(),
}
# yours_logits[game, pos].item(), yours_logits_pred[game, pos].item(), logits_diff[game, pos].item(), logits_pred_diff[game, pos].item(),
results_df = pd.DataFrame(result_dict, index=[0])
# round
results_df = results_df.round(4)
results_df

Unnamed: 0,yours_logits,yours_logits_pred,mine_logits,mine_logits_pred,logits_diff,logits_pred_diff
0,0.9498,0.216,-0.6512,0.572,1.601,-0.3559


In [20]:
yours_logits_pred_head_pos_sum[game, :, :, pos].shape
# -0.3812	0.2438	1.3489	0.2263	-1.7301	0.0175

torch.Size([8, 59])

In [27]:
tile = tile_tuple
layer = 2
layer

2

In [41]:
layer

1

In [48]:
game = 0
'''
Game: 0
Game: 2
Game: 3
Game: 5
Game: 16
Game: 33
Game: 47
'''
layer = 1

In [61]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def direciton_attribution(game_idx, layer, pos, direction, with_heads = False):
    attributions = {}
    for module_name2 in ["resid_pre", "attn_out", "mlp_out"]:
        resid : Float[Tensor, "pos d_model"] = focus_cache[module_name2, layer][game_idx]
        attribution = einops.einsum(resid, direction, "pos d_model, d_model -> pos")
        print(f"Attributino of {module_name2} to the {direction_label} direction of {pos}: {attribution[pos].item():.2f}")
        attributions[f"{module_name2} ({layer})"] = attribution[:].unsqueeze(0)

    if not with_heads:
        return attributions
    for head in range(8):
        # print(focus_cache.keys())
        resid : Float[Tensor, "pos d_model"] = focus_cache["z", layer][game_idx, :, head, :]
        resid = einops.einsum(resid, model.W_O[layer, head], "pos d_head, d_head d_model -> pos d_model").detach()
        # print(resid.shape)
        attribution = einops.einsum(resid, direction, "pos d_model, d_model -> pos")
        attributions[f"head_{head} ({layer})"] = attribution[:].unsqueeze(0)
    return attributions

def plot_direction_attribution(game_idx, layer, pos, direction, title_text, save=False):
    attributions = direciton_attribution(game_idx, layer, pos, direction)
    fig = make_subplots(rows=len(attributions), cols=1, subplot_titles=list(attributions.keys()), shared_xaxes=True)
    abs_max = max([t.abs(attributions[module_name][:, 1:]).max().item() for module_name in attributions])
    for i, module_name in enumerate(attributions):
        # print(attributions[module_name].shape)
        fig.add_trace(go.Heatmap(z=attributions[module_name].cpu().numpy(), colorscale="RdBu", zmin=-abs_max, zmax=abs_max), row=i+1, col=1)
        fig.update_yaxes(showticklabels=False, row=i+1, col=1)
    # fig.update_layout(title_text=title_text, height = 1000, width = 1500)
    # fig.update_layout(title_text=title_text)
    if save:
        save_plotly(fig, title_text)
    fig.show()

def plot_logits_head_pos(game, pos, logits_head_pos, title_text="", save=False, abs_max = None):
    # attributions = logits_head_pos[game, :, :pos+1, pos].detach()
    attributions = logits_head_pos[game, :, pos, :pos+1].detach()
    if abs_max is None:
        abs_max = t.abs(attributions).max().item()
    fig = go.Figure(data=go.Heatmap(z=attributions.cpu().numpy(), colorscale="RdBu", zmin=-abs_max, zmax=abs_max))
    fig.update_layout(title = "Predicted Logit Attribution", xaxis_title="Position", yaxis_title="Head")
    # fig.update_layout(title_text=title_text)
    if save:
        save_plotly(fig, title_text)
    fig.show()

def plot_direction_attribution_head_pos(game_idx, layer, pos, direction, title_text, save=False):
    attributions = t.zeros((8, pos+1))
    for head in range(8):
        # print(focus_cache.keys())
        v : Float[Tensor, "pos, d_head"] = focus_cache["v", layer][game_idx, :, head, :]
        pattern : Float[Tensor, "pos"] = focus_cache["pattern", layer][game_idx, head, pos, :]
        z : Float[Tensor, "pos d_head"] = pattern.unsqueeze(dim=-1) * v
        result = einops.einsum(z, model.W_O[layer, head], "pos d_head, d_head d_model -> pos d_model").detach()
        # print(resid.shape)
        attribution = einops.einsum(result, direction, "pos d_model, d_model -> pos")
        attributions[head] = attribution[:pos+1]
    # create heatmap
    abs_max = t.abs(attributions).max().item()
    fig = go.Figure(data=go.Heatmap(z=attributions.cpu().numpy(), colorscale="RdBu", zmin=-abs_max, zmax=abs_max))
    fig.update_layout(title = "Real Logit Attribution", xaxis_title="Position", yaxis_title="Head")
    # fig.update_layout(title_text=title_text)
    if save:
        save_plotly(fig, title_text)
    fig.show()
    return abs_max

# game = 5
# tile_label = "B4"
direction_label = "yours"
tile = label_to_tuple(tile_label)
direction = get_probe(layer, "linear", "post")[0, :, *tile, get_direction_int(direction_label)]
if direction_label == "mine":
    logits_pred_head_pos = mine_logits_pred_head_pos_sum
else:
    logits_pred_head_pos = yours_logits_pred_head_pos_sum

plot_direction_attribution(game, layer, pos, direction, title_text=f"Attributions of Layer {layer} to the {direction_label} direction of {tile_label} (We are looking at Position {pos}) Game {game}", save=False)
abs_max = plot_direction_attribution_head_pos(game, layer, pos, direction, title_text=f"Direct Logit Attribution per head and position for ({tile_label} is {direction_label}) at pos {pos} in game {game} Real", save = True)
plot_logits_head_pos(game, pos, logits_pred_head_pos, abs_max=abs_max, title_text=f"Direct Logit Attribution per head and position for ({tile_label} is {direction_label}) at pos {pos} in game {game} Predicted", save = True)

Attributino of resid_pre to the yours direction of 10: 1.57
Attributino of attn_out to the yours direction of 10: 0.95
Attributino of mlp_out to the yours direction of 10: 1.01


In [29]:
# Position 5 hat andere Farbe als Position 10, aber es gibt trotzdem leichte positive Dingers
# Position 7 gibt leichte positive Logits, das macht Sinn, weil es da die gleiche Farbe hat
model.W_out.shape

torch.Size([8, 2048, 512])

In [23]:
'''# This function is weird but okay
def get_probe2(direction_str : str, layer, tile : tuple):
    if direction_str == "mine":
        return mine_probe[:, layer, *tile]
    elif direction_str == "yours":
        return yours_probe[:, layer, *tile]
    elif direction_str == "flipped":
        return flipped_probe_normalized[:, layer+1, *tile]

def get_neuron_out_direction_scaled(pos_to, neuron, layer, game=None):
    if game is not None:
        neruon_activation = focus_cache["mlp_post", layer-1][game, pos_to, neuron]
    else:
        if pos_to is not None:
            neruon_activations = focus_cache["mlp_post", layer-1][:, pos_to, neuron]
        else:
            neruon_activations = focus_cache["mlp_post", layer-1][:, :, neuron]
        neruon_activations_positive = neruon_activations[neruon_activations > 0]
        neruon_activation = neruon_activations_positive.mean()
    direction = model.W_out[layer-1, neuron]
    direction_scaled = neruon_activation * direction
    return direction_scaled, neruon_activation.item()

def get_logit_after_ov(layer, head, direction_scaled, direction_str):
    flipped_dir = get_probe2("flipped", layer-1, tile)
    mine_dir = get_probe2("mine", layer-1, tile)
    yours_dir = get_probe2("yours", layer-1, tile)
    dir_next= get_probe2(direction_str, layer, tile)
    direction_scaled = orthogonalize_vectors(direction_scaled, [flipped_dir, yours_dir, mine_dir], normalize=False)
    direction_after_OV = einops.einsum(direction_scaled, OV.AB[layer, head], "d_model_in, d_model_in d_model_out -> d_model_out")
    logit = einops.einsum(direction_after_OV, dir_next, "d_model, d_model -> ")
    return logit.item()

def get_logit_before_ov(layer, head, direction_scaled, direction_str):
    # flipped_dir = flipped_probe_normalized[:, layer, *tile]
    # mine_dir = mine_probe[:, layer-1, *tile]#
    dir = get_probe(direction_str, layer-1, tile)
    # yours_dir = yours_probe[:, layer-1, *tile]
    logit = einops.einsum(direction_scaled, dir, "d_model, d_model -> ")
    return logit.item()

def get_activation_for_all_neurons(pos_to, layer, head, game=None, get_logit_function=get_logit_after_ov, direction_str="yours"):
    activations_list = []
    logits_list = []
    for neuron in range(2048):
        direction_scaled, neuron_activation = get_neuron_out_direction_scaled(pos_to, neuron, layer, game)
        logit = get_logit_function(layer, head, direction_scaled, direction_str)
        activations_list.append(neuron_activation)
        logits_list.append(logit)
    return activations_list, logits_list'''

'# This function is weird but okay\ndef get_probe2(direction_str : str, layer, tile : tuple):\n    if direction_str == "mine":\n        return mine_probe[:, layer, *tile]\n    elif direction_str == "yours":\n        return yours_probe[:, layer, *tile]\n    elif direction_str == "flipped":\n        return flipped_probe_normalized[:, layer+1, *tile]\n\ndef get_neuron_out_direction_scaled(pos_to, neuron, layer, game=None):\n    if game is not None:\n        neruon_activation = focus_cache["mlp_post", layer-1][game, pos_to, neuron]\n    else:\n        if pos_to is not None:\n            neruon_activations = focus_cache["mlp_post", layer-1][:, pos_to, neuron]\n        else:\n            neruon_activations = focus_cache["mlp_post", layer-1][:, :, neuron]\n        neruon_activations_positive = neruon_activations[neruon_activations > 0]\n        neruon_activation = neruon_activations_positive.mean()\n    direction = model.W_out[layer-1, neuron]\n    direction_scaled = neruon_activation * direc

In [63]:
# resid_real
head = 3
pos_to = 7
logits_list = []
activations_list = []
tile_label = "D3"
tile = label_to_tuple(tile_label)

activations_list, logits_list = get_activation_for_all_neurons(None, layer, head, tile, game=None, get_logit_function=get_logit_after_ov, direction_str="yours")
attributions = np.array(logits_list)
fig = px.scatter(x=list(range(attributions.shape[0])), y=attributions.flatten(), labels={"x": "Neuron Index", "y": "Attribution"})
fig.update_layout(title_text="title_text")
# fig.update_yaxes(exponentformat = 'E')
fig.show()
# NOT Flipped und Yours tragen auch zum Dings hinzu
# Okay. Ich könnte mir auch noch MLP Dinger anschauen
# Dann weiter überlegen

neuron_with_logits_list = zip(attributions.round(decimals=2), range(2048), np.array(activations_list).round(decimals=2))
# neuron_with_logits_list = zip(neuron_with_logits_list, np.array(activations_list).round(decimals=2))
neuron_with_logits_list = sorted(neuron_with_logits_list, key=lambda x: x[0], reverse=True)
print(neuron_with_logits_list[:10])

[(0.37, 391, 1.0), (0.32, 258, 1.0), (0.32, 506, 1.0), (0.32, 521, 1.0), (0.31, 766, 1.0), (0.3, 1486, 1.0), (0.25, 1756, 1.0), (0.25, 1970, 1.0), (0.24, 1575, 1.0), (0.24, 1733, 1.0)]


In [25]:
'''how_many = 40
final_neurons = []
for direction_str in ["yours", "mine", "flipped"]:
    logits_list_result = [0] * 2048
    for head in range(8):
        activations_list, logits_list = get_activation_for_all_neurons(None, layer, head, game=None, get_logit_function=get_logit_before_ov, direction_str=direction_str)
        logits_list_result = [max(logits_list_result[i], logits_list[i]) for i in range(2048)]
    # get the top 40 neuron, ideces
    neuron_with_logits_list = zip(logits_list_result, range(2048))
    neuron_with_logits_list = sorted(neuron_with_logits_list, key=lambda x: x[0], reverse=True)
    top_neurons = [neuron for _, neuron in neuron_with_logits_list[:how_many]]
    final_neurons += top_neurons
final_neurons = list(set(final_neurons))
final_directions = [model.W_out[layer, neuron] for neuron in final_neurons]'''

'how_many = 40\nfinal_neurons = []\nfor direction_str in ["yours", "mine", "flipped"]:\n    logits_list_result = [0] * 2048\n    for head in range(8):\n        activations_list, logits_list = get_activation_for_all_neurons(None, layer, head, game=None, get_logit_function=get_logit_before_ov, direction_str=direction_str)\n        logits_list_result = [max(logits_list_result[i], logits_list[i]) for i in range(2048)]\n    # get the top 40 neuron, ideces\n    neuron_with_logits_list = zip(logits_list_result, range(2048))\n    neuron_with_logits_list = sorted(neuron_with_logits_list, key=lambda x: x[0], reverse=True)\n    top_neurons = [neuron for _, neuron in neuron_with_logits_list[:how_many]]\n    final_neurons += top_neurons\nfinal_neurons = list(set(final_neurons))\nfinal_directions = [model.W_out[layer, neuron] for neuron in final_neurons]'

In [26]:
final_directions = [model.W_out[layer, neuron] for neuron in final_neurons]

In [27]:
model.W_out.shape

torch.Size([8, 2048, 512])

In [28]:
'''for probe_name in ["linear", "flipped", "placed", "legal", "accesible"]:
    for direction_str in probe_directions[probe_name]:
        probe = get_probe(layer, probe_name, "post")
        direction = probe[0, :, *tile, get_direction_int(direction_str)]
        direction_normalized = direction / direction.norm()
        direction_logit = einops.einsum(resid_real[game, pos_to], direction_normalized, "d_model, d_model -> ")
        direction_scaled = direction_logit * direction_normalized
        direction_after_OV = einops.einsum(direction_scaled, OV.AB[layer, head], "d_model_in, d_model_in d_model_out -> d_model_out")
        yours_logit = einops.einsum(direction_after_OV, yours_probe[:, layer, *tile], "d_model, d_model -> ")
        yours_logits_list.append(yours_logit.item())'''
for neuron in range(2048):
    direction = model.W_out[layer-1, neuron]
    # direction_normalized = direction / direction.norm()
    # direction_logit = einops.einsum(resid_real[game, pos_to], direction_normalized, "d_model, d_model -> ")
    # direction_scaled = direction_logit * direction_normalized
    '''neuron_acitvation = focus_cache["mlp_post", layer-1][game, pos_to, neuron]
    direction_scaled = neuron_acitvation * direction
    flipped_dir = flipped_probe_normalized[:, layer, *tile]
    mine_dir = mine_probe[:, layer-1, *tile]
    yours_dir = yours_probe[:, layer-1, *tile]
    direction_scaled = orthogonalize_vectors(direction_scaled, [flipped_dir, yours_dir, mine_dir], normalize=False)'''
    # print(direction_scaled @ flipped_dir, direction_scaled @ yours_dir, direction_scaled @ mine_dir)
    neuron_acitvations = focus_cache["mlp_post", layer-1][:, pos_to, neuron]
    neuron_acitvations_positive = neuron_acitvations[neuron_acitvations > 0]
    neuron_acitvation = neuron_acitvations_positive.mean()
    direction_scaled = neuron_acitvation * direction
    flipped_dir = flipped_probe_normalized[:, layer, *tile]
    mine_dir = mine_probe[:, layer-1, *tile]
    yours_dir = yours_probe[:, layer-1, *tile]
    direction_scaled = orthogonalize_vectors(direction_scaled, [flipped_dir, yours_dir, mine_dir], normalize=False)
    direction_after_OV = einops.einsum(direction_scaled, OV.AB[layer, head], "d_model_in, d_model_in d_model_out -> d_model_out")
    yours_logit = einops.einsum(direction_after_OV, yours_probe[:, layer, *tile], "d_model, d_model -> ")
    yours_logits_list.append(yours_logit.item())
    activations_list += [neuron_acitvation.item()]

NameError: name 'yours_logits_list' is not defined

In [46]:
sum([t[0] for t in neuron_with_logits_list[:50]])

nan

In [None]:
# [(0.17, 1114, 0.92), (0.17, 1126, 0.94), (0.15, 1166, 1.85), (0.1, 1442, 2.54), (0.03, 202, -0.17), (0.03, 1486, -0.13), (0.03, 1575, -0.15), (0.02, 11, -0.17), (0.02, 388, 0.89), (0.02, 432, -0.13)] orthogonalize to everything
# [(0.17, 1114, 0.92), (0.17, 1126, 0.94), (0.16, 1166, 1.85), (0.1, 1442, 2.54), (0.03, 202, -0.17), (0.03, 460, 1.83), (0.03, 1486, -0.13), (0.03, 1575, -0.15), (0.02, 11, -0.17), (0.02, 388, 0.89)] orthogonalize to layer 1 flipped
# [(0.19, 1114, 0.92), (0.19, 1126, 0.94), (0.19, 1166, 1.85), (0.14, 1442, 2.54), (0.04, 1486, -0.13), (0.04, 1575, -0.15), (0.03, 202, -0.17), (0.03, 432, -0.13), (0.03, 2001, 0.13), (0.02, 11, -0.17)] don't orthogonalize to everything

In [46]:
neurons = [neuron for neuron, _ in neuron_with_logits_list[:30]]
mlp_activation = focus_cache["mlp_post", 0][game, pos_to, neurons]
(mlp_activation > 0).sum()

tensor(0, device='cuda:0')

In [23]:
layer

1

In [65]:
flipped_probe_normalized.shape

torch.Size([512, 8, 8, 8])

In [101]:
layer = 1
head = 5
pos_to = 6
yours_logits_list = []
my_tile = (3, 3)
# previous_directions = []
# for probe_name, direction_str in [("flipped", "flipped"), ("flipped", "not_flipped"), ("linear", "yours")]:
probe_information = []
for probe_name in ["linear", "flipped", "placed", "legal", "accesible"]:
    for direction_str in probe_directions[probe_name]:
        probe = get_probe(layer-1, probe_name, "post")
        for row in range(8):
            for col in range(8):
                tile = (row, col)
                direction = probe[0, :, *tile, get_direction_int(direction_str)]
                direction_normalized = direction / direction.norm()
                if not (row == 3 and col == 3 and direction_str == "flipped"):
                    flipped_dir = flipped_probe_normalized[:, layer, *my_tile]
                    yours_dir = yours_probe[:, layer, *my_tile]
                    yours_dir = yours_dir / yours_dir.norm()
                    mine_dir = mine_probe[:, layer, *my_tile]
                    mine_dir = mine_dir / mine_dir.norm()
                    direction_normalized = orthogonalize_vectors(direction_normalized, [flipped_dir, yours_dir, mine_dir])
                    # print(direction_normalized @ flipped_dir, direction_normalized @ yours_dir, direction_normalized @ mine_dir)
                # make direction orthogonal to previous directions
                # for previous_direction in previous_directions:
                #     direction = direction - einops.einsum(direction, previous_direction, "d_model, d_model -> ") * previous_direction
                #     break
                # revious_directions.append(direction_normalized)
                direction_logit = einops.einsum(resid_real[game, pos_to], direction_normalized, "d_model, d_model -> ")
                direction_scaled = direction_logit * direction_normalized
                direction_after_OV = einops.einsum(direction_scaled, OV.AB[layer, head], "d_model_in, d_model_in d_model_out -> d_model_out")
                yours_logit = einops.einsum(direction_after_OV, yours_probe[:, layer, *my_tile], "d_model, d_model -> ")
                yours_logits_list.append(yours_logit.item())
                probe_information += [f"{tuple_to_label(tile)} {direction_str}"]

attributions = np.array(yours_logits_list)
fig = go.Scatter(
    x=list(range(attributions.shape[0])),
    y=attributions.flatten(), # labels={"x": "Probe Index", "y": "Attribution"}
    mode='markers',  # Display as points
    marker=dict(size=10),  # Size of the points
    hovertext=probe_information,  # Text to display when hovering over each point
    hoverinfo='text'  # Only show the hover text when hovering
)
layout = go.Layout(
    title="title_text",
    xaxis=dict(title="Probe Index"),
    yaxis=dict(title="Attribution")
)
# fig.update_layout(hover_text=probe_information)
# fig.update_layout(title_text="title_text")
# fig.update_yaxes(exponentformat = 'E')
# fig.show()

In [102]:
fig = go.Figure(data=fig, layout=layout)
fig.show()

In [None]:
'''def plot_neuron_attribution(game_idx, layer, position, direction : Float[Tensor, "d_model"], title_text):
    W_out : Float[Tensor, "d_mlp d_model"] = model.W_out[layer].detach()
    activations = focus_cache["post", layer][game_idx]
    attributions = activations[position] * einops.einsum(W_out, direction, "d_mlp d_d_model, d_model -> d_mlp")
    print(attributions.shape)
    # Make a plotly scatter plot. y axis is the attribution, x axis is the neuron index
    fig = px.scatter(x=list(range(attributions.shape[0])), y=attributions.cpu().numpy().flatten(), labels={"x": "Neuron Index", "y": "Attribution"})
    fig.update_layout(title_text=title_text)
    fig.update_yaxes(exponentformat = 'E')
    fig.show()

position = 10

direction_label = "yours"
direction = get_probe(layer, "linear", "post")[0, :, *tile, get_direction_int(direction_label)]
plot_neuron_attribution(0, 0, position, direction, title_text=f"Attributions of Layer 0 Neurons to the {direction_label} direction of {tile_label} in game 0 (Posiiton {position})")'''

In [80]:
game

28

In [50]:
print(game)
vis_args = VisualzeBoardArguments()
vis_args.start_pos = 0
vis_args.end_pos = 11
vis_args.layers = 6
vis_args.include_layer_norm = True
vis_args.include_resid_post = False
# vis_args.include_attn_only = True
visualize_game(board_seqs_string[game], vis_args, model)

0
torch.Size([6, 2, 59, 8, 8])


In [13]:
correct.shape

torch.Size([50, 59])

In [14]:
mask.shape

torch.Size([50, 59])

In [9]:
results_math.keys()

dict_keys(['TP', 'TN', 'FP', 'FN'])

In [10]:
results_math["TP"].shape

torch.Size([8, 59, 8, 8])

In [None]:
# TODO: First get results for math version and test version, then make score with interesting two dimensions, then create heatmap ...
# TODO: Übelegen was dann kommt
# TODO: Nur Falls es langsam ist ... Alle Tiles gleichzeitig (Könnte maybe keinen Sinn machen, vielleicht will ich ja andere Tiles beim input mit rein nehmen so)

In [36]:
'''# save results_math and results_test using pickle
with open("results_math.pkl", "wb") as file:
    pickle.dump(results_math, file)
with open("results_test.pkl", "wb") as file:
    pickle.dump(results_test, file)'''

In [39]:
print(f"Math Acc: {scores_math['Accuracy'][1, 10, *label_to_tuple('C3')].item():.2f}") # This should be > 0.8
print(f"Real Acc: {scores_test['Accuracy'][1, 10, *label_to_tuple('C3')].item():.2f}") # This should be the same


Math Acc: 0.44
Real Acc: 0.44


In [45]:
# load results_math and results_test using pickle
with open("results_math_approx.pkl", "rb") as file:
    results_math_approx = pickle.load(file)

In [46]:
results_math_approx["TP"][3, 12, 0, 2]

tensor(31., device='cuda:0')

In [47]:
scores = get_score_from_results(results_math_approx, ["layer", "pos", "row", "col"])

In [48]:
scores["Accuracy"][6, 43, 2, 2]

tensor(0.6997, device='cuda:0')

In [49]:
scores_vis = get_score_from_results(results_math_approx, ["row", "col"])

In [50]:
def get_scores_tensor(results, dimensions, metric, title, layer=None, pos=None, row=None, col=None, save=False):
    # results_compressed = results[layer, pos, row, col]
    results_compressed = {}
    for key in results.keys():
        results_compressed[key] = results[key]
        if layer is not None:
            results_compressed[key] = results_compressed[key][layer].unsqueeze(0)
        if pos is not None:
            results_compressed[key] = results_compressed[key][:, pos].unsqueeze(1)
        if row is not None:
            results_compressed[key] = results_compressed[key][:, :, row].unsqueeze(2)
        if col is not None:
            results_compressed[key] = results_compressed[key][:, :, :, col].unsqueeze(3)
    scores = get_score_from_results(results_compressed, dimensions)
    scores = scores[metric]
    # print(scores.shape)
    # print(scores[:4, :4])
    if dimensions == ["layer", "pos", "row", "col"]:
        scores = einops.reduce(scores, "layer pos row col -> pos row col", "max")
        scores = einops.reduce(scores, "pos row col -> row col", "mean")
    if dimensions == ["layer", "row", "col"]:
        scores = einops.reduce(scores, "layer row col -> row col", "max")
    # print(scores.shape)
    # print(scores[:4, :4])
    scores = scores.cpu().numpy()
    return scores


In [51]:
# Create a plotly headmap of scores_vis["Accuracy"]
# Notes
# Also ich finde diesen Plot cool. (Vielleicht average ich aber doch lieber alles zusammen)
save = True
dimensions = ["layer", "row", "col"]
metric = "Accuracy"
title = "Accuracy of the Last-Flipped Heuristic of all Board Tiles"

scores = get_scores_tensor(
    results_math_approx,
    dimensions,
    metric,
    title,
    layer = None,
    pos = None,
    row = None,
    col = None,
)

fig = px.imshow(scores, labels=dict(x=dimensions[-1], y=dimensions[-2], color=metric))
fig.update_layout(width=800, height=800)
fig.show()
if save:
    save_plotly(fig, title)
# 

In [52]:
scores.shape
scores.mean()
# Real : 0.768
# Approx: 0.764

0.7642852

In [27]:
scores = get_score_from_results(results_math_approx[:, :10, :, :], ["layer", "row", "col"])["Accuracy"]
scores.shape

TypeError: unhashable type: 'slice'

In [16]:
def plot_boards_general(x_labels : List[str],
                        y_labels : List[str],
                        boards : Float[Tensor, "x y rows cols"],
                        size_of_board : Int = 200,
                        margin_t : Int = 100,
                        title_text : str = "",
                        color_range  : str = "symmetric",
                        static_image : bool = False,
                        save : bool = False):
    # TODO: add attn/mlp only
    # TODO: Change Width and Height accordingly
    boards = boards.flip(2)
    x_len, y_len, rows, cols = boards.shape
    subplot_titles = [f"{y_label}, {x_label}" for y_label in y_labels for x_label in x_labels]
    # subplot_titles = [f"P: {i}, T: {label_list[i]}, L: {j}" for i in range(vis_args.start_pos, vis_args.end_pos) for j in range(vis_args.layers)]
    width = x_len * size_of_board
    height = y_len * size_of_board + margin_t
    vertical_spacing = 70 / height
    fig = make_subplots(rows=y_len, cols=x_len, subplot_titles=subplot_titles, vertical_spacing=vertical_spacing)
    boards_min = boards.min().item()
    boards_max = boards.max().item()        
    abs_max = max(abs(boards_min), abs(boards_max))
    if color_range == "symmetric":
        begin = -abs_max
        end = abs_max
    else:
        begin = boards_min
        end = boards_max
    for x in range(x_len):
        for y in range(y_len):
            heatmap = go.Heatmap(
                z=boards[x, y].cpu(),
                x=list(range(0, rows)),
                y=reverse_alpha,
                hoverongaps = False,
                zmin=begin,
                zmax=end,
            )
            fig.add_trace(
                heatmap,
                row=y + 1,
                col=x + 1
            )
    fig.layout.update(width=width, height=height, margin_t=margin_t, title_text=title_text) 
    if static_image:
        # count the number of images in the last_plot directory
        num_images = len(list(Path("last_plot").glob("*.png")))
        fig.write_image(f'last_plot/last_plot{num_images+1}.png')
    else:
        fig.show()
    if save:
        save_plotly(fig, title_text)

In [17]:
# Ich mache noch schnell den Plot für alle Layer und dann kann ich ja morgen den Part schreiben

scores = t.zeros((8, 6, 8, 8))

for interval in range(6):
    results_new ={}
    for key in results_math_approx.keys():
        results_new[key] = results_math_approx[key][:, interval*10:(interval+1)*20]
    scores[:, interval, :, :] = get_score_from_results(results_new, ["layer", "row", "col"])["Accuracy"]

print(scores.shape)
# It would be interesting to see the 
# TODO: Next: It would be interesting to visualize Rows with the first 10 moves, then the next 10 moves, and so on

'''plot_boards_general(
    x_labels=[f"Layer {i}" for i in range(1, 8)],
    y_labels=[f"Interval {i}" for i in range(6)],
    boards=scores[1:],
    color_range="non_symmetric",
    title_text="Accuracy of the Last-Flipped Heuristic of all Board Tile over the Layers and Positions",
    save=True,
)'''

scores = get_score_from_results(results_new, ["layer", "row", "col"])["Accuracy"]
plot_boards_general(
    x_labels=[f"Layer {i}" for i in range(1, 8)],
    y_labels=[""],
    boards=scores[1:].unsqueeze(1),
    color_range="non_symmetric",
    title_text="Accuracy of the Last-Flipped Heuristic of all Board Tile over the Layers",
    save=True,
)

torch.Size([8, 6, 8, 8])


In [153]:
create_heatmap(
    results_math_approx,
    ["layer", "pos"],
    "Accuracy",
    "Accuracy of the Board over all Layers and Positions",
    layer = None,
    pos = None,
    row = 3,
    col = 3,
    save=False
)

In [64]:
scores_vis["Accuracy"][1][:]

tensor([0.5919, 0.6293, 0.6652, 0.6617, 0.7093, 0.6949, 0.7123, 0.5914],
       device='cuda:0')