## Setup

### Setup 1

In [1]:
import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
sys.path.append(chapter_dir + f"{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

import pickle

from training_utils import (
    get_state_stack_one_hot_empty_yours_mine,
    get_state_stack_one_hot_flipped,
    get_state_stack_one_hot_placed,
    get_state_stack_one_hot_placed_and_flipped,
    get_state_stack_one_hot_placed_and_flipped_stripe
)
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jinja2 import Template
from utils import save_plotly_to_html
from utils import save_plotly_to_png
from utils import save_plotly
from training_utils import get_state_stack_num_flipped
from utils import get_probe

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
# import part6_othellogpt.tests as tests

t.manual_seed(42)

from training_utils import seq_to_state_stack
from training_utils import seq_to_state_stack_flipped

device = t.device("cuda" if t.cuda.is_available() else "cpu")



### Setup 2

In [2]:
# True when this code runs as the top-level module rather than being imported.
MAIN = __name__ == "__main__"

# Architecture of the transformer whose weights are loaded from the checkpoint below.
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

# Fetch the "synthetic_model.pth" checkpoint from the Hugging Face hub and load it into `model`.
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (ie the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

# Sanity check: the loaded weights must reproduce the reference predictions at every position.
assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
# Use the current working directory as the section root and make it importable.
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

# Expected location of the `othello_world` checkout (see the commented-out clone command below).
OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

# The helper module imported next is looked up via this path entry.
sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

# The four centre squares must never appear as a move, and the int encoding must top out at 60.
assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

# NOTE: `num_games` is reassigned further down when the focus games are selected.
num_games, length_of_game = board_seqs_int.shape

# Squares that can ever be played: all 64 except the four centre squares.
stoi_indices = [square for square in range(64) if square not in (27, 28, 35, 36)]

# Row letters; columns are numbered 0-7, so square 34 is labelled `E2`.
alpha = "ABCDEFGH"

def to_board_label(i):
    """Convert a square index (0-63, row-major) into a label made of row letter and column digit."""
    row_letter = alpha[i // 8]
    column = i % 8
    return f"{row_letter}{column}"

# Labels for the playable squares and for the full board.
board_labels = [to_board_label(square) for square in stoi_indices]
full_board_labels = [to_board_label(square) for square in range(64)]

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Plot an 8x8 input as a board via `imshow`.

    A diverging scale uses the "RdBu" colours centred on 0; otherwise "Blues" with no midpoint.
    Any extra keyword arguments override the defaults and are forwarded to `imshow`
    (e.g. `facet_col=0` to plot a stack of boards).
    """
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None
    board_kwargs = dict(
        y=list(alpha),
        x=[str(column) for column in range(8)],
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    # Caller-supplied options take precedence over the defaults above.
    board_kwargs.update(kwargs)
    imshow(state, **board_kwargs)

# Select a fixed window of games to inspect in detail.
start = 30000
num_games = 50
focus_window = slice(start, start + num_games)
focus_games_int = board_seqs_int[focus_window]
focus_games_string = board_seqs_string[focus_window]

# Run the model on every move except the last one, keeping the activation cache.
focus_inputs = focus_games_int[:, :-1].to(device)
focus_logits, focus_cache = model.run_with_cache(focus_inputs)

def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` that is 1 at the given indices and 0 elsewhere."""
    indicator = t.zeros(num_classes, dtype=t.float32)
    indicator[list_of_ints] = 1.
    return indicator

# Per game and per move: the board state and a 64-way indicator of the valid next moves.
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

for game_idx in range(num_games):
    # Replay each focus game from a fresh board, recording the state after every move.
    board = OthelloBoardState()
    for move_idx in range(60):
        move = focus_games_string[game_idx, move_idx].item()
        board.umpire(move)
        focus_states[game_idx, move_idx] = board.state
        focus_valid_moves[game_idx, move_idx] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

# linear_probe2 = t.load("probes/linear/resid_6_linear.pth")

# Board dimensions and number of classes per square
# (presumably the last three axes of the linear probe, per the commented-out assert below).
rows = 8
cols = 8
options = 3
# assert linear_probe2.shape == (1, cfg.d_model, rows, cols, options)

# Index constants for the player to move and for the per-square classes.
# NOTE(review): where these are consumed is not visible in this notebook.
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

from utils import *

interpretability
focus states: (50, 60, 8, 8)
focus_valid_moves (50, 60, 64)


In [3]:
device

device(type='cuda')

## Code

In [4]:
# Rebind `board_seqs_int` to the games stored on disk (this replaces the array loaded during setup).
board_seqs_int = t.load("data/board_seqs_int_valid.pth")
board_seqs_int.shape

torch.Size([500000, 60])

In [5]:
def state_stack_to_mine_yours(state_stack):
    """Relabel a stack of board states in place and return the same object.

    Input encoding (per the original notes): 0 is blank, -1 is white, 1 is black.
    Every second entry along the first axis (starting at index 0) is negated first,
    then 1 is rewritten to MINE and -1 to YOURS, in that order.
    """
    # Flip the sign on alternating entries so that 1 means "mine" and -1 "theirs" throughout.
    state_stack[::2] *= -1
    # Order matters: MINE is written before the -1 entries are looked up.
    for raw_value, relabelled in ((1, MINE), (-1, YOURS)):
        state_stack[state_stack == raw_value] = relabelled
    return state_stack  # 0 empty, YOURS, MINE

In [6]:
# TODO: Add the Flipped probe (put off because of the F1 score). No: if I only look at the flipped tile I'm fine
# TODO: Do Plotly of board_probe_results_layer_pos
# TODO: Do Plotly of board_probe_results_layer_pos_flipped
# TODO: board_probe_result
# TODO: board_results_flipped
# TODO: Batches (Flipped per Tile, move and position requires a lot of Data (Very Sparse))

def load_probes(probe_name, layers):
    """Load `probes/<probe_name>/resid_<layer>_<probe_name>.pth` for layers 0..layers-1.

    Returns the loaded probes as a list ordered by layer, each moved to `device`.
    """
    return [
        t.load(f"probes/{probe_name}/resid_{layer}_{probe_name}.pth").to(device)
        for layer in range(layers)
    ]

def get_state_stack_board(input_seq_str):
    """Build the mine/yours state stack for a move sequence, dropping the final entry."""
    relabelled = state_stack_to_mine_yours(seq_to_state_stack(input_seq_str))
    return t.Tensor(relabelled[:-1])

def get_state_stack_board_flipped(input_seq_str):
    """Build a binary state stack from the flipped-state stack, dropping the final entry.

    Entries equal to 2 become 0; entries equal to 0, 1, 3, 4 or 5 become 1.
    """
    state_stack = seq_to_state_stack_flipped(input_seq_str)
    # Collapse every code except 2 onto 1 first (1 already has the target value) ...
    for code in (0, 3, 4, 5):
        state_stack[state_stack == code] = 1
    # ... then map the remaining 2s to 0, so they are not swept up by the 0 -> 1 rule.
    state_stack[state_stack == 2] = 0
    return t.Tensor(state_stack[:-1])

# TODO: I have to make this clean ...
# TODO: Use one-hot encoding for the state stack (it should work for any state_stack and probe)
# TODO: You can switch between Recall and Accuracy.
# TODO: You can send a List of Dimensions that should not be viewed
# TODO: Remove use_num_flipped to make my life easier
# TODO: Implement F1Score
def get_accuracy_detailed(
        probe_type : str,
        probe_module : str,
        len_data : int,
        get_state_stack_one_hot_function : Callable,
        batch_size : int = 128,
        options : int = 3,
        ignore_dimensions : Optional[List[int]] = None,
        multi_label : bool = False,
        multi_label_threshold : float = 0.5, # metric : str = "recall"
    ) -> Dict[str, Dict[str, Tensor]]:
    """Evaluate the probes of every layer against one-hot board labels.

    The first `len_data` games of `board_seqs_int` are processed in full batches of
    `batch_size` games (a trailing partial batch is skipped). For each batch the model
    is run with caching and every layer's probe is applied to `resid_<probe_module>`.

    Args:
        probe_type: probe identifier passed to `get_probe`.
        probe_module: residual stream location, used as `resid_<probe_module>`.
        len_data: number of games (from the start of `board_seqs_int`) to evaluate.
        get_state_stack_one_hot_function: maps the move sequences (string encoding) to
            labels of shape (batch, pos, rows, cols, options).
        batch_size: number of games per batch.
        options: size of the last (class) axis of labels and probe outputs.
        ignore_dimensions: class indices left out of the per-layer/position average.
            Defaults to `[0]`.
        multi_label: if False the prediction is the one-hot argmax over `options`;
            if True each option is predicted independently via a sigmoid threshold.
        multi_label_threshold: threshold on the sigmoid output, used only if `multi_label`.

    Returns:
        A nested dict `outputs[granularity][metric]` with granularity in
        {"full", "layer_pos"} and metric in {"recall", "accuracy"}.
        "full" tensors have shape (layers, pos, rows, cols, options); "layer_pos" tensors
        have shape (layers, pos). Recall counts correct predictions only where the label
        is 1 and divides by the number of positive labels; accuracy counts every agreement
        and divides by the number of samples.
    """
    # TODO: Accuracy doesn't work! Add accuracy
    if ignore_dimensions is None:
        ignore_dimensions = [0]
    _, len_seq = board_seqs_int.shape
    layers = 8
    rows = 8
    cols = 8
    not_ignore_dimensions = [i for i in range(options) if i not in ignore_dimensions]

    board_probe_correct_recall = t.zeros(size=(layers, len_seq-1, rows, cols, options)).to(device)
    board_probe_correct_accuracy = t.zeros(size=(layers, len_seq-1, rows, cols, options)).to(device)
    # Number of positive labels (recall) / of samples (accuracy) per position, tile and option,
    # accumulated over all games so the correct counts can be averaged later.
    total_count_recall = t.zeros(size=(len_seq-1, rows, cols, options)).to(device)
    total_count_accuracy = t.zeros(size=(len_seq-1, rows, cols, options)).to(device)

    # The probes do not depend on the batch, so fetch them once per layer.
    # NOTE(review): assumes get_probe returns the same tensor for the same arguments.
    probes = [get_probe(layer, probe_type, probe_module) for layer in range(layers)]

    # `+ 1` keeps the last full batch when len_data is an exact multiple of batch_size
    # (previously that batch was dropped, and len_data == batch_size processed nothing).
    for i in tqdm(range(0, len_data - batch_size + 1, batch_size)):
        indeces = t.arange(i, i + batch_size)
        input_seqs = board_seqs_int[indeces]
        input_seqs = input_seqs[:, :-1]
        state_stacks = get_state_stack_one_hot_function(t.Tensor(to_string(input_seqs)).to(dtype=t.long))
        assert state_stacks.shape == (batch_size, len_seq-1, rows, cols, options)

        not_blank_recall = state_stacks > 0
        not_blank_accuracy = t.ones_like(state_stacks)
        not_blank_recall = einops.reduce(not_blank_recall, "batch pos rows cols options -> pos rows cols options", reduction="sum")
        total_count_recall += not_blank_recall
        not_blank_accuracy = einops.reduce(not_blank_accuracy, "batch pos rows cols options -> pos rows cols options", reduction="sum")
        total_count_accuracy += not_blank_accuracy

        _, cache = model.run_with_cache(
            input_seqs.to(device)
        )
        for layer in range(layers):
            probe = probes[layer]
            resid = cache[f"resid_{probe_module}", layer].to(device)
            result = einops.einsum(probe, resid, "modes d_model rows cols options, batch pos d_model -> modes batch pos rows cols options")[0]
            if not multi_label:
                pred = result.argmax(dim=-1)
                pred = t.nn.functional.one_hot(pred, num_classes=options)
            else:
                pred = result
                pred = t.nn.functional.sigmoid(pred)
                pred = (pred > multi_label_threshold).int()
            correct_accuracy = (pred == state_stacks)
            # A recall hit is an agreement on a positive label.
            correct_recall = correct_accuracy & (state_stacks == 1)
            correct_recall = einops.reduce(correct_recall, "batch pos rows cols options -> pos rows cols options", reduction="sum")
            correct_accuracy = einops.reduce(correct_accuracy, "batch pos rows cols options -> pos rows cols options", reduction="sum")
            board_probe_correct_recall[layer] += correct_recall
            board_probe_correct_accuracy[layer] += correct_accuracy
    # TODO: Make this work for multi_label
    # TODO: (Keep the option dimension, later make it go away) For all dimensions save how many correct there are
    # Per position and option: number of tiles that had at least one counted sample.
    count_of_tiles_not_blank_recall : Float[Tensor, "seq_len options"] = (total_count_recall > 0).sum(dim=-3).sum(dim=-2)
    count_of_tiles_not_blank_accuracy : Float[Tensor, "seq_len options"] = (total_count_accuracy > 0).sum(dim=-3).sum(dim=-2)
    # As to not divide by zero
    total_count_recall[total_count_recall == 0] = 1
    total_count_accuracy[total_count_accuracy == 0] = 1
    # First average the number of correct predictions per tile
    board_probe_result_recall : Float[Tensor, "layer seq_len rows cols options"] = (board_probe_correct_recall / total_count_recall)
    board_probe_result_accuracy : Float[Tensor, "layer seq_len rows cols options"] = (board_probe_correct_accuracy / total_count_accuracy)
    # Now average over the tiles that were counted, then over the options that are not ignored.
    # A position/option with no counted tile yields 0/0 = NaN here.
    board_probe_results_layer_pos_recall = einops.reduce(board_probe_result_recall, "layers pos rows cols options -> layers pos options", reduction="sum") / count_of_tiles_not_blank_recall
    board_probe_results_layer_pos_recall = board_probe_results_layer_pos_recall[:, :, not_ignore_dimensions].mean(dim=-1)
    board_probe_results_layer_pos_accuracy = einops.reduce(board_probe_result_accuracy, "layers pos rows cols options -> layers pos options", reduction="sum") / count_of_tiles_not_blank_accuracy
    board_probe_results_layer_pos_accuracy = board_probe_results_layer_pos_accuracy[:, :, not_ignore_dimensions].mean(dim=-1)
    return {
        "full" : {
            "recall" : board_probe_result_recall,
            "accuracy" : board_probe_result_accuracy,
        },
        "layer_pos" : {
            "recall" : board_probe_results_layer_pos_recall,
            "accuracy" : board_probe_results_layer_pos_accuracy,
        },
    }

In [7]:
from sklearn.metrics import f1_score

In [18]:
def get_f1score_detailed(
        probe_type : str,
        probe_module : str,
        len_data : int,
        get_state_stack_one_hot_function : Callable,
        batch_size : int = 128,
        multi_label : bool = True,
        options : int = 8,
        multi_label_threshold : float = 0.5
    ) -> Tensor:
    """Compute the macro F1 score of every layer's probe over flattened board labels.

    The first `len_data` games of `board_seqs_int` are processed in full batches of
    `batch_size` games (a trailing partial batch is skipped). Labels and predictions are
    flattened over all axes before scoring, and the per-batch scores are averaged.

    Args:
        probe_type: probe identifier passed to `get_probe`.
        probe_module: residual stream location, used as `resid_<probe_module>`.
        len_data: number of games (from the start of `board_seqs_int`) to evaluate.
        get_state_stack_one_hot_function: maps the move sequences (string encoding) to
            labels with axes (batch, pos, rows, cols, options).
        batch_size: number of games per batch.
        multi_label: if True each option is predicted via a sigmoid threshold;
            otherwise the prediction is the one-hot argmax over `options`.
        options: number of classes for the one-hot encoding of argmax predictions.
        multi_label_threshold: threshold on the sigmoid output, used only if `multi_label`.

    Returns:
        Tensor of shape (8,) on `device` with the mean per-batch macro F1 score of each
        layer (all zeros if no full batch fits into `len_data`).
    """
    layers = 8
    f1_score_sums = t.zeros(layers).to(device)
    num_batches = 0
    # The probes do not depend on the batch, so fetch them once per layer.
    # NOTE(review): assumes get_probe returns the same tensor for the same arguments.
    probes = [get_probe(layer, probe_type, probe_module) for layer in range(layers)]

    # `+ 1` keeps the last full batch when len_data is an exact multiple of batch_size.
    for i in tqdm(range(0, len_data - batch_size + 1, batch_size)):
        indeces = t.arange(i, i + batch_size)
        input_seqs = board_seqs_int[indeces]
        input_seqs = input_seqs[:, :-1]
        state_stacks = get_state_stack_one_hot_function(t.Tensor(to_string(input_seqs)).to(dtype=t.long))
        # The flattened targets are identical for every layer, so build them once per batch.
        state_stacks_flat = einops.rearrange(state_stacks, "batch pos rows cols options -> (batch pos rows cols options)")
        state_stacks_flat = state_stacks_flat.to(dtype=t.long).cpu()
        _, cache = model.run_with_cache(
            input_seqs.to(device)
        )
        for layer in range(layers):
            probe = probes[layer]
            resid = cache[f"resid_{probe_module}", layer].to(device)
            result = einops.einsum(probe, resid, "modes d_model rows cols options, batch pos d_model -> modes batch pos rows cols options")[0]
            if not multi_label:
                pred = result.argmax(dim=-1)
                pred = t.nn.functional.one_hot(pred, num_classes=options)
            else:
                pred = result
                pred = t.nn.functional.sigmoid(pred)
                pred = (pred > multi_label_threshold).int()
            pred_flat = einops.rearrange(pred, "batch pos rows cols options -> (batch pos rows cols options)")
            # Accumulate instead of overwriting, so every batch contributes to the result
            # (previously only the score of the last batch survived).
            f1_score_sums[layer] += f1_score(state_stacks_flat, pred_flat.to(dtype=t.long).cpu(), average="macro")
        num_batches += 1
    return f1_score_sums / max(num_batches, 1)
            
# Disabled variant for the single-label "placed" probe:
# get_f1score_detailed(
#     "placed",
#     "mid",
#     100,
#     get_state_stack_one_hot_placed,
#     batch_size=90,
#     multi_label=False,
#     options=2,
#     multi_label_threshold=0.999999
# )

# F1 score per layer for the multi-label "placed_and_flipped" probe on one batch of 90 games.
get_f1score_detailed(
    probe_type="placed_and_flipped",
    probe_module="mid",
    len_data=100,
    get_state_stack_one_hot_function=get_state_stack_one_hot_placed_and_flipped,
    batch_size=90,
    multi_label=True,
    options=8,
    multi_label_threshold=0.99999997,
)

  0%|          | 0/1 [00:00<?, ?it/s]

tensor([0.9049, 0.9125, 0.9099, 0.9094, 0.9116, 0.9164, 0.9072, 0.4216],
       device='cuda:0')

In [25]:
# 0.4782, 0.4347, 0.3652, 0.2977, 0.2459, 0.1913, 0.1435, 0.1275
# 0.6448, 0.5850, 0.5200, 0.4541, 0.3836, 0.3019, 0.1671, 0.1301
# 0.8224, 0.7447, 0.6659, 0.5908, 0.5072, 0.4224, 0.2144, 0.1395
# 0.8936, 0.8948, 0.8925, 0.8892, 0.8786, 0.8326, 0.4863, 0.3099 (multi_label_threshold=0.99999)
# 0.9063, 0.8966, 0.8948, 0.8923, 0.8872, 0.8770, 0.7814, 0.3058 (for mid)
# 0.9049, 0.9125, 0.9099, 0.9094, 0.9116, 0.9164, 0.9072, 0.4216

In [75]:
# get_accuracy_detailed returns a nested dict (granularity -> metric -> tensor).
# Unpacking it into two names would bind the keys "full" and "layer_pos" instead of tensors.
linear_probe_outputs = get_accuracy_detailed(
    probe_type="linear",
    probe_module="mid",
    len_data=100,
    get_state_stack_one_hot_function=get_state_stack_one_hot_empty_yours_mine,
    batch_size=32,
    options=3,
    ignore_dimensions=[0],
    multi_label=False,
    multi_label_threshold=None
)
# NOTE(review): "recall" (correct predictions among positive labels) is chosen to match the
# default metric of the plotting helpers; switch to "accuracy" if that is what should be shown.
linear_probe_result = linear_probe_outputs["full"]["recall"]
linear_probe_results_layer_pos = linear_probe_outputs["layer_pos"]["recall"]

  0%|          | 0/3 [00:00<?, ?it/s]

In [27]:
# Label builder for each probe type.
state_stack_one_hot_functions = dict(
    linear=get_state_stack_one_hot_empty_yours_mine,
    flipped=get_state_stack_one_hot_flipped,
    placed=get_state_stack_one_hot_placed,
    placed_and_flipped=get_state_stack_one_hot_placed_and_flipped,
    placed_and_flipped_stripe=get_state_stack_one_hot_placed_and_flipped_stripe,
)

# Size of the class axis for each probe type.
probe_option_count = dict(
    linear=3,
    flipped=2,
    placed=2,
    placed_and_flipped=8,
    placed_and_flipped_stripe=8,
)

# Class indices excluded from the per-layer/position average for each probe type.
probe_ignore_dimensions = dict(
    linear=[0],
    flipped=[1],
    placed=[1],
    placed_and_flipped=[],
    placed_and_flipped_stripe=[],
)

# Whether each probe type is evaluated as multi-label (sigmoid threshold) or single-label (argmax).
probe_multi_label = dict(
    linear=False,
    flipped=False,
    placed=False,
    placed_and_flipped=True,
    placed_and_flipped_stripe=True,
)

In [28]:
# Disabled experiment kept as a bare string literal: nothing here executes, the cell only echoes the text.
# NOTE(review): get_accuracy_detailed returns a nested dict, so the two-name unpacking below would bind
# the dict keys "full" and "layer_pos" rather than result tensors if this were re-enabled as written.
'''probe_name = "placed_and_flipped"
probe_module = "post"
len_data = 100

probe_result, probe_results_layer_pos = get_accuracy_detailed(
    probe_type=probe_name,
    probe_module=probe_module,
    len_data=len_data,
    get_state_stack_one_hot_function=state_stack_one_hot_functions[probe_name],
    batch_size=32,
    options=probe_option_count[probe_name],
    ignore_dimensions=probe_ignore_dimensions[probe_name],
    multi_label=probe_multi_label[probe_name],
    multi_label_threshold=0.8
)'''

'probe_name = "placed_and_flipped"\nprobe_module = "post"\nlen_data = 100\n\nprobe_result, probe_results_layer_pos = get_accuracy_detailed(\n    probe_type=probe_name,\n    probe_module=probe_module,\n    len_data=len_data,\n    get_state_stack_one_hot_function=state_stack_one_hot_functions[probe_name],\n    batch_size=32,\n    options=probe_option_count[probe_name],\n    ignore_dimensions=probe_ignore_dimensions[probe_name],\n    multi_label=probe_multi_label[probe_name],\n    multi_label_threshold=0.8\n)'

In [71]:
# RELOAD = True recomputes every result; RELOAD = False reuses a cached pickle when one exists.
RELOAD = True
len_data = 1000
# probe_modules = ["mid", "post"]
# probe_names = ["linear", "placed", "placed_and_flipped", "placed_and_flipped_stripe"]
probe_modules = ["post", "mid"]
probe_names = ["placed_and_flipped"]
# Maps (probe_name, probe_module, "full" | "layer_pos", metric) to a result tensor.
probe_results = {}
for probe_module in probe_modules:
    for probe_name in probe_names:
        result_dir = f"probe_results/{probe_module}/{probe_name}"
        # Creates missing parent directories too, and is a no-op if the directory exists.
        os.makedirs(result_dir, exist_ok=True)
        result_path = f"{result_dir}/probe_result.pickle"

        outputs = None
        if os.path.exists(result_path) and not RELOAD:
            # The cache is written by this cell; never point this at an untrusted pickle file.
            with open(result_path, "rb") as f:
                cached = pickle.load(f)
            # Older caches may not hold the nested dict; recompute in that case.
            if isinstance(cached, dict):
                outputs = cached

        if outputs is None:
            outputs = get_accuracy_detailed(
                probe_type=probe_name,
                probe_module=probe_module,
                len_data=len_data,
                get_state_stack_one_hot_function=state_stack_one_hot_functions[probe_name],
                batch_size=32,
                options=probe_option_count[probe_name],
                ignore_dimensions=probe_ignore_dimensions[probe_name],
                multi_label=probe_multi_label[probe_name],
                multi_label_threshold=0.999999,
            )
            with open(result_path, "wb") as f:
                pickle.dump(outputs, f)

        # Cached and freshly computed results are registered under the same 4-tuple keys,
        # which is the key layout the plotting helpers look up.
        for layer_pos_or_full, d in outputs.items():
            for metric, probe_result in d.items():
                probe_results[(probe_name, probe_module, layer_pos_or_full, metric)] = probe_result

  0%|          | 0/31 [00:00<?, ?it/s]

  0%|          | 0/31 [00:00<?, ?it/s]

In [35]:
from utils import save_plotly

In [74]:
def plot_results_layer_pos(probe_name : str, probe_module : str, metric : str = "recall"):
    """Show a layer-by-position heatmap of one probe result.

    Looks up `probe_results[(probe_name, probe_module, "layer_pos", metric)]`, which is
    expected to be a 2-D tensor with layers along the first axis and positions along the
    second, and displays it with layers on the y axis and positions on the x axis.
    """
    results_layer_pos = probe_results[(probe_name, probe_module, "layer_pos", metric)].cpu()
    # Take the axis lengths from the data so every row and column gets a tick value
    # (the y axis previously listed one layer fewer than the data contains).
    num_layers, num_positions = results_layer_pos.shape
    fig = go.Figure(
        data=go.Heatmap(
            z=results_layer_pos,
            x=list(range(num_positions)),
            y=list(range(num_layers)),
            hoverongaps = False),
        layout=go.Layout(width=750,height=400,margin_t=50)
    )
    fig.update_layout(
        title_text=f"{metric.capitalize()} of the {probe_name} probe ({probe_module}) per Layer and Position",
        # Heatmap columns follow x (positions) and rows follow y (layers).
        xaxis_title="Position",
        yaxis_title="Layer",
        )
    fig.show()

plot_results_layer_pos("placed_and_flipped", "mid", "recall")

In [None]:
# board_probe_results_layer_pos
# Layer-by-position heatmap of the linear probe result (rows = layers, columns = positions).
layers = 8
len_seq = 60
fig = go.Figure(
    data=go.Heatmap(
        z=linear_probe_results_layer_pos.cpu(),
        x=list(range(0, len_seq-1)),
        # One tick per layer; `layers-1` would leave the last row without a y value.
        y=list(range(0, layers)),
        hoverongaps = False),
    layout=go.Layout(width=750,height=400,margin_t=50)
)
fig.update_layout(
    title_text="Accuracy of the MINE/YOURS direction of the Linear Probe per Layer and Position",
    # x runs over positions and y over layers (the two titles were swapped before).
    xaxis_title="Position",
    yaxis_title="Layer",
    )
fig.show()
# save_plotly(fig, "linear_probe_results_layer_pos")

## Plot Board
- Plan:
    - Subplots per layer or per position
    - Then show the board as a heatmap

In [51]:
# The unfinished `from utils import ` line that followed was a syntax error and has been removed.
from utils import plot_boards_general

In [62]:
def plot_results_layer(probe_name : str, probe_module : str, metric : str = "recall"):
    """Plot one board per (option, layer) for a probe result, averaged over positions.

    Looks up `probe_results[(probe_name, probe_module, "full", metric)]`, takes the mean
    over the position axis and hands the boards to `plot_boards_general` with the option
    axis first.
    """
    per_position = probe_results[(probe_name, probe_module, "full", metric)]
    per_layer = einops.reduce(per_position, "layers pos rows cols options -> layers rows cols options", reduction="mean")
    boards_by_option = einops.rearrange(per_layer, "layers rows cols options -> options layers rows cols")
    layer_labels = [f"layer {layer}" for layer in range(8)]
    plot_boards_general(
        x_labels = probe_directions_list[probe_name],
        y_labels = layer_labels,
        boards = boards_by_option,
    )

plot_results_layer("placed_and_flipped", "post", "recall")

In [None]:
# Average the flipped-probe result over positions, leaving one 8x8 board per layer.
flipped_probe_result_layer = einops.reduce(flipped_probe_result, "layers pos rows cols -> layers rows cols", reduction="mean")

layer_titles = [f"Layer {i}" for i in range(8)]
fig = make_subplots(rows=2, cols=5, subplot_titles=layer_titles, vertical_spacing=0.1)
for layer in range(8):
    # Fill the 2x5 grid row by row.
    grid_row, grid_col = divmod(layer, 5)
    layer_heatmap = go.Heatmap(
        z=flipped_probe_result_layer[layer].cpu(),
        x=list(range(0, 8)),
        y=list(range(0, 8)),
        hoverongaps=False,
        zmin=0.5,
        zmax=1.0,
    )
    fig.add_trace(layer_heatmap, row=grid_row + 1, col=grid_col + 1)
fig.layout.update(width=1800,height=750,margin_t=100, title_text="Recall of the Flipped Probe per Layer per Tile")
save_plotly(fig, "flipped_probe_result_layer")

fig.show()

In [None]:
# Average the linear-probe result over layers, leaving one 8x8 board per position.
linear_probe_result_pos = einops.reduce(linear_probe_result, "layers pos rows cols -> pos rows cols", reduction="mean")

# Every fifth position from 6 onwards is shown.
shown_positions = range(6, 59, 5)
fig = make_subplots(rows=2, cols=6, subplot_titles=[f"Position {i}" for i in shown_positions], vertical_spacing=0.15)
for i, pos in enumerate(shown_positions):
    # Fill the 2x6 grid row by row.
    grid_row, grid_col = divmod(i, 6)
    position_heatmap = go.Heatmap(
        z=linear_probe_result_pos[pos].cpu(),
        x=list(range(0, 8)),
        y=list(range(0, 8)),
        hoverongaps=False,
        zmin=0.75,
        zmax=1.0,
    )
    fig.add_trace(position_heatmap, row=grid_row + 1, col=grid_col + 1)
fig.layout.update(width=1800,height=700,margin_t=100, title_text="Accuracy of the MINE/YOURS direction per Position per Tile")
save_plotly(fig, "linear_probe_result_pos")
fig.show()

In [None]:
# Average the flipped-probe result over layers, leaving one 8x8 board per position.
flipped_probe_result_pos = einops.reduce(flipped_probe_result, "layers pos rows cols -> pos rows cols", reduction="mean")

# Every fifth position from 6 onwards is shown.
shown_positions = range(6, 59, 5)
fig = make_subplots(rows=2, cols=6, subplot_titles=[f"Position {i}" for i in shown_positions], vertical_spacing=0.15)
for i, pos in enumerate(shown_positions):
    # Fill the 2x6 grid row by row.
    grid_row, grid_col = divmod(i, 6)
    position_heatmap = go.Heatmap(
        z=flipped_probe_result_pos[pos].cpu(),
        x=list(range(0, 8)),
        y=list(range(0, 8)),
        hoverongaps=False,
        zmin=0.5,
        zmax=1.0,
    )
    fig.add_trace(position_heatmap, row=grid_row + 1, col=grid_col + 1)
fig.layout.update(width=1800,height=700,margin_t=100, title_text="Recall of the Flipped direction per Position per Tile")
save_plotly(fig, "flipped_probe_result_pos")
fig.show()


In [None]:
def _plot_board_grid(boards, subplot_titles, grid_cols, title_text, save_name, height, vertical_spacing):
    """Draw one 8x8 heatmap per entry of `boards` on a two-row grid, then save and show the figure.

    Subplots are filled row by row with `grid_cols` boards per row; the colour range is fixed to [0, 1].
    Returns the figure.
    """
    fig = make_subplots(rows=2, cols=grid_cols, subplot_titles=subplot_titles, vertical_spacing=vertical_spacing)
    for i, board in enumerate(boards):
        fig.add_trace(
            go.Heatmap(
                z=board.cpu(),
                x=list(range(0, 8)),
                y=list(range(0, 8)),
                hoverongaps = False,
                zmin=0.0,
                zmax=1.0,
            ),
            row=i // grid_cols + 1,
            col=i % grid_cols + 1,
        )
    fig.layout.update(width=1800, height=height, margin_t=100, title_text=title_text)
    save_plotly(fig, save_name)
    fig.show()
    return fig


layer_titles = [f"Layer {i}" for i in range(8)]
flip_counts = range(1, 14)
flip_titles = [f"Number of Flips: {i}" for i in flip_counts]

# Per layer (mean over the second axis of the num-flipped results).
linear_probe_result_layer = einops.reduce(num_flipped_linear_probe_result, "layers pos rows cols -> layers rows cols", reduction="mean")
fig = _plot_board_grid(
    boards=[linear_probe_result_layer[layer] for layer in range(8)],
    subplot_titles=layer_titles,
    grid_cols=5,
    title_text="Accuracy of the MINE/YOURS direction per Layer per Tile",
    save_name="num_flipped_linear_probe_result_layer",
    height=750,
    vertical_spacing=0.1,
)

flipped_probe_result_layer = einops.reduce(num_flipped_flipped_probe_result, "layers pos rows cols -> layers rows cols", reduction="mean")
fig = _plot_board_grid(
    boards=[flipped_probe_result_layer[layer] for layer in range(8)],
    subplot_titles=layer_titles,
    grid_cols=5,
    title_text="Recall of the Flipped Probe per Layer per Tile",
    save_name="num_flipped_flipped_probe_result_layer",
    height=750,
    vertical_spacing=0.1,
)

# Per number of flips (mean over layers); the second axis is indexed by the flip count here.
linear_probe_result_pos = einops.reduce(num_flipped_linear_probe_result, "layers pos rows cols -> pos rows cols", reduction="mean")
fig = _plot_board_grid(
    boards=[linear_probe_result_pos[n] for n in flip_counts],
    subplot_titles=flip_titles,
    grid_cols=7,
    title_text="Accuracy of the MINE/YOURS direction per Number of Flips per Tile",
    save_name="num_flipped_linear_probe_result_pos",
    height=700,
    vertical_spacing=0.15,
)

flipped_probe_result_pos = einops.reduce(num_flipped_flipped_probe_result, "layers pos rows cols -> pos rows cols", reduction="mean")
fig = _plot_board_grid(
    boards=[flipped_probe_result_pos[n] for n in flip_counts],
    subplot_titles=flip_titles,
    grid_cols=7,
    title_text="Recall of the Flipped direction per Number of Flips per Tile",
    save_name="num_flipped_flipped_probe_result_pos",
    height=700,
    vertical_spacing=0.15,
)


In [None]:
from itertools import product
# [f"Layer: {j}, Number Flips: {i}" for i, j in list(product(range(8), range(1, 14)))]

In [None]:
# linear_probe_result_pos = linear_probe_result[:, :]
# linear_probe_result_pos = einops.reduce(linear_probe_result_pos, "layers pos rows cols -> pos rows cols", reduction="mean")
# One heatmap per (position, layer): rows are positions, columns are layers.
shown_positions = list(range(0, 59, 5))
num_layers = 8
fig = make_subplots(
    # Allocate exactly one row per plotted position and one column per layer.
    rows=len(shown_positions),
    cols=num_layers,
    # One title per column / per row. The previous list had one entry per (layer, position)
    # pair and had the two labels swapped.
    column_titles=[f"Layer: {layer}" for layer in range(num_layers)],
    row_titles=[f"Position: {pos}" for pos in shown_positions],
    vertical_spacing=0.02,
)
for i, pos in enumerate(shown_positions):
    for layer in range(num_layers):
        fig.add_trace(
            go.Heatmap(
                z=linear_probe_result[layer, pos].cpu(),
                x=list(range(0, 8)),
                y=list(range(0, 8)),
                hoverongaps = False,
                zmin=0.6,
                zmax=1.0,
            ),
            row=i + 1,
            col=layer + 1
        )
fig.layout.update(width=2000,height=4000,margin_t=50)
save_plotly(fig, "linear_probe_result_pos_layer_expanded")
fig.show()




In [None]:
# linear_probe_result_pos = linear_probe_result[:, :]
# linear_probe_result_pos = einops.reduce(linear_probe_result_pos, "layers pos rows cols -> pos rows cols", reduction="mean")
# One heatmap per (number of flips, layer): rows are flip counts 1..13, columns are layers.
flip_counts = list(range(1, 14))
num_layers = 8
fig = make_subplots(
    rows=len(flip_counts),
    cols=num_layers,
    # One title per column / per row. The previous list had one entry per (layer, flips)
    # pair and had the two labels swapped.
    column_titles=[f"Layer: {layer}" for layer in range(num_layers)],
    row_titles=[f"Number Flips: {n}" for n in flip_counts],
    vertical_spacing=0.02,
)
for pos in flip_counts:
    for layer in range(num_layers):
        fig.add_trace(
            go.Heatmap(
                z=num_flipped_linear_probe_result[layer, pos].cpu(),
                x=list(range(0, 8)),
                y=list(range(0, 8)),
                hoverongaps = False,
                zmin=0.5,
                zmax=1.0,
            ),
            # Flip counts start at 1, so they double as 1-based subplot rows.
            row=pos,
            col=layer + 1
        )
fig.layout.update(width=2000,height=4000,margin_t=50)
save_plotly(fig, "num_flipped_linear_probe_result_num_layer_expanded")
fig.show()


