In [67]:
import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
sys.path.append(chapter_dir + f"{chapter}/exercises")

os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import torch as t
from torch import Tensor
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from pathlib import Path
import itertools
import random
from IPython.display import display
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import dataclasses
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

from plotly_utils import imshow
from pathlib import Path
from typing import List, Union, Optional, Tuple, Callable, Dict
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from jinja2 import Template

# Section setup: locate the working directory, make the Othello utilities
# importable, fix the RNG seed and pick the compute device.
# os.chdir(section_dir)
section_dir = Path.cwd()
# The relative paths used below ("probes", "data/...", "last_plot", ...) are
# resolved against the current working directory, so fail early if the
# notebook was started somewhere else.
assert section_dir.name == "interpretability"

OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# `mech_interp_othello_utils` is imported from OTHELLO_MECHINT_ROOT, so that
# directory has to be on sys.path before the import below.
sys.path.append(str(OTHELLO_MECHINT_ROOT))
from mech_interp_othello_utils import (
    OthelloBoardState,
    to_int,
    to_string,
    string_to_label,
    str_to_int,
)

# Seed torch's RNG.
t.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

# True when this code runs as a script / notebook main module.
MAIN = __name__ == "__main__"

# Module-level model: same hyperparameters and checkpoint as `load_model`
# further down in this file.
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

# Download the checkpoint from the Hugging Face hub and load it into the model.
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)
# Games encoded as token ids (the assert below checks the maximum id is 60).
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
board_seqs_string_train = t.load(
    os.path.join(
        section_dir,
        "data/board_seqs_string_train.pth",
    )
)
# The string literal below is disabled code (loading the validation split);
# as a bare expression statement it has no effect.
'''board_seqs_string_test = t.load(
    os.path.join(
        section_dir,
        "data/board_seqs_string_valid.pth",
    )
)'''
# Games encoded as board-square indices; the assert below checks that the
# four squares 27, 28, 35 and 36 never occur.
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)
assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60
# NOTE: `num_games` is the size of the whole loaded dataset here; several
# functions below use a local `num_games = 50` that does not change this one.
num_games, length_of_game = board_seqs_int.shape
# Square indices that can appear as moves: all 64 squares except the four
# middle ones.
_MIDDLE_SQUARES = (27, 28, 35, 36)
stoi_indices = [square for square in range(64) if square not in _MIDDLE_SQUARES]

# Row letters used in board labels such as "E2".
alpha = "ABCDEFGH"


def to_board_label(i):
    """Return the label of square index `i`: row letter followed by column digit."""
    row_letter = alpha[i // 8]
    col_digit = i % 8
    return f"{row_letter}{col_digit}"


# Labels of the playable squares, and of every square on the 8x8 board.
board_labels = [to_board_label(square) for square in stoi_indices]
full_board_labels = [to_board_label(square) for square in range(64)]
# start = 30000
def get_focus_logits_and_cache():
    """Run the module-level `model` on the first 50 games of `board_seqs_int`.

    The last token of every game is dropped before the forward pass.

    Returns:
        (focus_logits, focus_cache) as produced by `model.run_with_cache`.
    """
    first_game, game_count = 0, 50
    focus_tokens = board_seqs_int[first_game : first_game + game_count]
    model_input = focus_tokens[:, :-1].to(device)
    logits, cache = model.run_with_cache(model_input)
    return logits, cache
# focus_logits.shape
def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` with ones at the given indices.

    All positions listed in `list_of_ints` are set to 1.0, every other
    position stays 0.0 (so several indices give a multi-hot vector).
    """
    encoding = t.zeros(num_classes, dtype=t.float32)
    encoding[list_of_ints] = 1.0
    return encoding
# Zero-initialised buffers for board states and valid-move masks.
# NOTE(review): `num_games` here is the module-level value unpacked from
# `board_seqs_int.shape` (the whole dataset), not the local 50 used inside
# `get_focus_logits_and_cache` — confirm buffers of that size are intended.
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

# Board cell values (matches the "-1 white, 0 blank, 1 black" convention
# documented on `seq_to_state_stack`).
BLANK1 = 0
BLACK = 1
WHITE = -1

# The constants below are class indices; elsewhere in this file they are used
# to index the last axis of a probe tensor (e.g. `probe[..., MINE]`).

# Classes of the "linear" probe.
EMPTY = 0
YOURS = 1
MINE = 2

# Classes of the "flipped" probe.
FLIPPED = 0
NOT_FLIPPED = 1

# Classes of the "placed" probe.
PLACED = 0
NOT_PLACED = 1

# Classes of the "placed_and_flipped" / "placed_and_flipped_stripe" probes:
# one index per direction.
FLIPPED_TOP = 0
FLIPPED_TOP_RIGHT = 1
FLIPPED_RIGHT = 2
FLIPPED_BOTTOM_RIGHT = 3
FLIPPED_BOTTOM = 4
FLIPPED_BOTTOM_LEFT = 5
FLIPPED_LEFT = 6
FLIPPED_TOP_LEFT = 7

# Classes of the "accesible" probe (spelling is part of the public names).
ACCESIBLE = 0
NOT_ACCESIBLE = 1

# Classes of the "legal" probe.
LEGAL = 0
NOT_LEGAL = 1

# Load Model
def load_model(device):
    """Build the Othello-GPT `HookedTransformer` and load its pretrained weights.

    The weights come from the `synthetic_model.pth` file of the
    "NeelNanda/Othello-GPT-Transformer-Lens" Hugging Face repository.

    Args:
        device: device passed to `HookedTransformerConfig`.

    Returns:
        The model with the downloaded state dict loaded.
    """
    hyperparameters = dict(
        n_layers=8,
        d_model=512,
        d_head=64,
        n_heads=8,
        d_mlp=2048,
        d_vocab=61,
        n_ctx=59,
        act_fn="gelu",
        normalization_type="LNPre",
        device=device,
    )
    othello_gpt = HookedTransformer(HookedTransformerConfig(**hyperparameters))

    state_dict = transformer_lens.utils.download_file_from_hf(
        "NeelNanda/Othello-GPT-Transformer-Lens",
        "synthetic_model.pth",
    )
    othello_gpt.load_state_dict(state_dict)
    return othello_gpt

alpha = "ABCDEFGH"

# Load Probes
#
# Directory layout read here:
#     probes/<probe_module>/<probe_type>/resid_<layer>_<probe_type>.pth
# Every file that exists for layers 0-7 is loaded onto `device` and stored in
# `probes` under the key (probe_module, probe_type, layer).

probes = dict()
probe_modules = os.listdir("probes")
for probe_module in probe_modules:
    # Skip stray files (e.g. ".DS_Store", ".gitkeep"): calling os.listdir on
    # them would raise NotADirectoryError and abort the whole load.
    if not os.path.isdir(f"probes/{probe_module}"):
        continue
    probe_types = os.listdir(f"probes/{probe_module}")
    for probe_type in probe_types:
        for layer in range(8):
            path = f"probes/{probe_module}/{probe_type}/resid_{layer}_{probe_type}.pth"
            if not os.path.exists(path):
                continue
            # map_location puts the tensor on the target device directly, for
            # both the CPU and the GPU case.
            probe = t.load(path, map_location=device).detach()
            probes[(probe_module, probe_type, layer)] = probe

def get_probe(layer : Int = 5, probe_type : str = "linear", probe_module : str = "post"):
    """Look up a loaded probe in the module-level `probes` dict.

    Args:
        layer: layer index of the probe.
        probe_type: probe type name, e.g. "linear" or "flipped".
        probe_module: probe module name, e.g. "post" or "mid".

    Raises:
        KeyError: if no probe was loaded for this combination.
    """
    lookup_key = (probe_module, probe_type, layer)
    return probes[lookup_key]

# Direction name -> class index shared by the two "placed_and_flipped" probe
# variants. Each variant gets its own copy below.
_flip_direction_ids = {
    "top": FLIPPED_TOP,
    "top_right": FLIPPED_TOP_RIGHT,
    "right": FLIPPED_RIGHT,
    "bottom_right": FLIPPED_BOTTOM_RIGHT,
    "bottom": FLIPPED_BOTTOM,
    "bottom_left": FLIPPED_BOTTOM_LEFT,
    "left": FLIPPED_LEFT,
    "top_left": FLIPPED_TOP_LEFT,
}

# probe name -> {direction name -> class index of that direction}
probe_directions = {
    "linear": {"empty": EMPTY, "yours": YOURS, "mine": MINE},
    "flipped": {"flipped": FLIPPED, "not_flipped": NOT_FLIPPED},
    "placed": {"placed": PLACED, "not_placed": NOT_PLACED},
    "accesible": {"accesible": ACCESIBLE, "not_accesible": NOT_ACCESIBLE},
    "legal": {"legal": LEGAL, "not_legal": NOT_LEGAL},
    "placed_and_flipped": dict(_flip_direction_ids),
    "placed_and_flipped_stripe": dict(_flip_direction_ids),
}

# probe name -> list of its direction names, in definition order
probe_directions_list = {
    probe_name: list(directions)
    for probe_name, directions in probe_directions.items()
}

# Abbreviations for probe names and probe direction names.
# NOTE: "legal", "left" and "linear" all abbreviate to "L".
short_cuts = {
    # "linear" probe directions
    "empty": "E", "yours": "Y", "mine": "M",
    # two-class probes
    "flipped": "F", "not_flipped": "NF",
    "placed": "P", "not_placed": "NP",
    "accesible": "A", "not_accesible": "NA",
    "legal": "L", "not_legal": "NL",
    # the eight directions
    "top": "T", "top_right": "TR",
    "right": "R", "bottom_right": "BR",
    "bottom": "B", "bottom_left": "BL",
    "left": "L", "top_left": "TL",
    # probe names
    "linear": "L",
    "placed_and_flipped": "PF",
    "placed_and_flipped_stripe": "PFS",
}


def get_short_cut(name):
    """Return the abbreviation registered for `name`; raises KeyError if unknown."""
    abbreviation = short_cuts[name]
    return abbreviation

def get_probe_names():
    """Return the probe names (the keys of `probe_directions`) as a new list."""
    return [probe_name for probe_name in probe_directions]

def get_direction_str(probe_name, direction_int):
    """Return the direction name that has class index `direction_int` for `probe_name`.

    Raises:
        KeyError: if `probe_name` is not a key of `probe_directions`.
        AssertionError: if no direction of that probe has this class index.
    """
    for direction_str, direction_value in probe_directions[probe_name].items():
        if direction_value == direction_int:
            return direction_str
    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would turn this into a silent
    # `return None`.
    raise AssertionError(
        f"probe {probe_name!r} has no direction with value {direction_int!r}"
    )

def get_direction_int(directions_str):
    """Return the class index of the direction named `directions_str`.

    The name is lower-cased first. Probes are searched in the order of
    `probe_directions`, and the first probe that defines the name wins.

    Raises:
        AssertionError: if no probe defines a direction with this name.
    """
    directions_str = directions_str.lower()
    for directions in probe_directions.values():
        if directions_str in directions:
            return directions[directions_str]
    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would turn this into a silent
    # `return None`.
    raise AssertionError(f"unknown direction name {directions_str!r}")

def seq_to_state_stack(str_moves):
    """
    Replay `str_moves` on a fresh board and return the board after every move.

    The result is a stacked array with dimensions (num_moves, rows, cols),
    using -1 for white, 0 for blank and 1 for black.
    """
    board = OthelloBoardState()

    def _snapshots():
        # Play one move at a time and copy the board, because the board object
        # is reused for the following moves.
        for move in str_moves:
            board.umpire(move)
            yield np.copy(board.state)

    return np.stack(list(_snapshots()), axis=0)


# Plotting Functions

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Plot an 8 by 8 input as a board; a stack of boards works via facet_col=0.

    With `diverging_scale` the "RdBu" colour scale centred on 0 is used,
    otherwise "Blues" without a midpoint. Extra keyword arguments are passed
    on to `imshow` and override the defaults set here.
    """
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None

    plot_kwargs = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    plot_kwargs.update(kwargs)
    imshow(state, **plot_kwargs)

def plot_probe_outputs(focus_cache, linear_probe, layer, game_index, move, **kwargs):
    """Plot the probe's per-square class probabilities for one game position.

    The "resid_post" activation of `layer` at (`game_index`, `move`) is
    projected through `linear_probe`; only the first probe mode is used. The
    softmax over the last axis is shown as three boards.
    """
    resid = focus_cache["resid_post", layer][game_index, move]
    logits_per_mode = einops.einsum(
        resid,
        linear_probe,
        "d_model, modes d_model row col options -> modes row col options",
    )
    square_probs = logits_per_mode[0].softmax(dim=-1)
    plot_square_as_board(
        square_probs,
        facet_col=2,
        facet_labels=["P(EMPTY)", "P(YOURS)", "P(MINE)"],
        **kwargs,
    )

def plot_game(games_str, game_index=0, end_move=16):
    '''
    Plot the board after each of the first `end_move` moves of one game.

    "Move 0" is the board after the first move was made. If the game has
    fewer than `end_move` moves, only the available boards are shown.

    Args:
        games_str: indexable collection of games; `games_str[game_index]` is
            passed to `seq_to_state_stack`.
        game_index: which game to plot.
        end_move: number of boards to show.
    '''
    focus_states = seq_to_state_stack(games_str[game_index])
    shown_states = focus_states[:end_move]
    # Labels and title follow what is actually plotted, so they stay correct
    # for any `end_move` / `game_index` and for short games.
    num_shown = len(shown_states)
    imshow(
        shown_states,
        facet_col=0,
        facet_col_wrap=8,
        facet_labels=[f"Move {i}" for i in range(num_shown)],
        title=f"First {num_shown} moves of game {game_index}",
        color_continuous_scale="Greys",
        y=list(alpha),
    )

def square_to_tuple(square, is_int=False):
    """Split a square index into its (row, col) pair on an 8-wide board.

    When `is_int` is true, `square` is first converted with `to_string`.
    """
    board_index = to_string(square) if is_int else square
    return (board_index // 8, board_index % 8)

def to_board_label(i):
    """Return the label of square index `i`: row letter followed by column digit."""
    row_letter = alpha[i // 8]
    col_digit = i % 8
    return f"{row_letter}{col_digit}"

def tuple_to_label(t):
    """Format a (row, col) pair as a board label: row letter followed by the column."""
    row, col = t[0], t[1]
    row_letter = alpha[row]
    return f"{row_letter}{col}"

def get_focus_games(model = None, device = "cpu"):
    """Load the small game datasets and return the first 50 games.

    Args:
        model: optional model; when given, it is run (with caching) on the
            focus games with their last token dropped.
        device: device the model input is moved to.

    Returns:
        (focus_games_int, focus_games_string) when `model` is None, otherwise
        (focus_games_int, focus_games_string, focus_logits, focus_cache).
    """
    def _load_long(filename):
        return t.tensor(np.load(OTHELLO_MECHINT_ROOT / filename), dtype=t.long)

    # Games as ints (i.e. 0 to 60) ...
    board_seqs_int = _load_long("board_seqs_int_small.npy")
    # ... and as "strings" (i.e. 0 to 63 with middle squares skipped out).
    board_seqs_string = _load_long("board_seqs_string_small.npy")

    assert not any(middle_sq in board_seqs_string for middle_sq in (27, 28, 35, 36))
    assert board_seqs_int.max() == 60

    # Unpacking fails early if the loaded data is not 2-dimensional.
    _total_games, _game_length = board_seqs_int.shape

    focus = slice(0, 50)
    focus_games_int = board_seqs_int[focus]
    focus_games_string = board_seqs_string[focus]

    if model is None:
        return focus_games_int, focus_games_string

    focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
    return focus_games_int, focus_games_string, focus_logits, focus_cache

def square_tuple_from_square(square : str):
    """Parse a board label such as "C3" into its (row, col) tuple.

    The row is the position of the first character in `alpha` (ValueError if
    it is not a row letter); the column is the second character as an int.
    """
    row = alpha.index(square[0])
    col = int(square[1])
    return (row, col)

# Row letters in reverse order; used as the y tick labels of the plotly
# heatmaps in `plot_boards_general` and `plot_boards`.
reverse_alpha = ["H", "G", "F", "E", "D", "C", "B", "A"]

def save_plotly_to_html(fig, filename):
    """Render `fig` into the HTML template and write the result to `filename`.

    The template "interactive_plots/template.html" is a Jinja2 template that
    receives the figure's HTML fragment as the variable `fig`.

    Args:
        fig: figure object providing `to_html(full_html=False)`.
        filename: path of the HTML file to write (UTF-8).

    Raises:
        AssertionError: if the template file does not exist.
    """
    TEMPLATE_PATH = "interactive_plots/template.html"
    assert os.path.exists(TEMPLATE_PATH), f"missing HTML template: {TEMPLATE_PATH}"
    plotly_jinja_data = {"fig": fig.to_html(full_html=False)}
    # Read the template with an explicit encoding so it matches the UTF-8
    # output regardless of the platform default.
    with open(TEMPLATE_PATH, encoding="utf-8") as template_file:
        j2_template = Template(template_file.read())
    # Render before opening the output, so a template or render error cannot
    # leave an empty or truncated output file behind.
    rendered = j2_template.render(plotly_jinja_data)
    with open(filename, "w", encoding="utf-8") as output_file:
        output_file.write(rendered)

def save_plotly_to_png(fig, filename):
    """Write `fig` to `filename` by calling its `write_image` method."""
    fig.write_image(filename)

def save_plotly(fig, name):
    """Save `fig` as interactive_plots/<name>.html and as plots/<name>.png."""
    html_path = f"interactive_plots/{name}.html"
    png_path = f"plots/{name}.png"
    save_plotly_to_html(fig, html_path)
    save_plotly_to_png(fig, png_path)

def plot_boards_general(x_labels : List[str],
                        y_labels : List[str],
                        boards : Float[Tensor, "x y rows cols"],
                        size_of_board : Int = 200,
                        margin_t : Int = 100,
                        title_text : str = "",
                        color_scale : str = "RdBu",
                        color_range  : str = "symmetric",
                        static_image : bool = False,
                        save : bool = False):
    """Plot a grid of board heatmaps, one subplot per (x, y) entry of `boards`.

    Args:
        x_labels: label per subplot column (first axis of `boards`).
        y_labels: label per subplot row (second axis of `boards`).
        boards: values to plot, indexed as [x, y, row, col].
        size_of_board: size in pixels reserved per board.
        margin_t: top margin in pixels.
        title_text: figure title; also the file name when `save` is true.
        color_scale: plotly colour scale name.
        color_range: "symmetric" uses [-max|v|, +max|v|]; any other value
            uses [min, max] of `boards`.
        static_image: write a PNG into "last_plot/" instead of showing the
            figure.
        save: additionally store the figure via `save_plotly`.
    """
    # The row axis is flipped to go with the reversed row letters used as
    # y tick labels below.
    boards = boards.flip(2)
    x_len, y_len, rows, cols = boards.shape
    # make_subplots consumes titles row by row, hence y is the outer loop.
    subplot_titles = [f"{y_label}, {x_label}" for y_label in y_labels for x_label in x_labels]
    width = x_len * size_of_board
    height = y_len * size_of_board + margin_t
    vertical_spacing = 70 / height
    fig = make_subplots(rows=y_len, cols=x_len, subplot_titles=subplot_titles, vertical_spacing=vertical_spacing)
    boards_min = boards.min().item()
    boards_max = boards.max().item()
    if color_range == "symmetric":
        abs_max = max(abs(boards_min), abs(boards_max))
        begin, end = -abs_max, abs_max
    else:
        begin, end = boards_min, boards_max
    # The x axis runs over board columns, so its ticks are built from `cols`
    # (they used to be built from `rows`, which is only right for square boards).
    x_ticks = list(range(cols))
    for x in range(x_len):
        for y in range(y_len):
            heatmap = go.Heatmap(
                z=boards[x, y].cpu(),
                x=x_ticks,
                y=reverse_alpha,
                hoverongaps=False,
                zmin=begin,
                zmax=end,
                colorscale=color_scale,
            )
            fig.add_trace(heatmap, row=y + 1, col=x + 1)
    fig.layout.update(width=width, height=height, margin_t=margin_t, title_text=title_text)
    if static_image:
        # Number the image after the PNGs already present; create the
        # directory first so the write cannot fail on a fresh checkout.
        out_dir = Path("last_plot")
        out_dir.mkdir(parents=True, exist_ok=True)
        num_images = len(list(out_dir.glob("*.png")))
        fig.write_image(f'last_plot/last_plot{num_images+1}.png')
    else:
        fig.show()
    if save:
        save_plotly(fig, title_text)

def get_color(val : float):
    """Map a value to one of five shading characters.

    The range [0, 1] is split into five equal buckets; values of 1.0 and
    above fall into the last bucket and negative values into the first.

    Args:
        val: number (or anything supporting `* 5` and `int()`), normally a
            probability in [0, 1].

    Returns:
        A one-character string from " " (lowest) to "█" (highest).
    """
    # Gradient characters from lightest (low values) to darkest (high values).
    gradient_chars = [" ", "░", "▒", "▓", "█"]
    # Clamp at both ends: without the lower bound a negative value would
    # become a negative index and select a character from the dark end.
    bucket = max(0, min(int(val * 5), len(gradient_chars) - 1))
    return gradient_chars[bucket]

@dataclass
class VisualzeBoardArguments:
    """Options for `get_boards`, `plot_boards` and `visualize_game`.

    Every attribute is annotated so that `@dataclass` registers it as a
    field: options can then be passed to the constructor by keyword and show
    up in `repr`. Creating an instance without arguments and assigning
    attributes afterwards keeps working.
    """
    # Which activations the probes are applied to; every enabled option adds
    # subplot columns per layer (layer norm adds two: ln1 and ln2).
    include_attn_only: bool = False
    include_mlp_only: bool = False
    include_pre_resid: bool = False
    include_layer_norm: bool = False
    include_resid_post: bool = True
    # Positions [start_pos, end_pos) are plotted, one subplot row each.
    start_pos: int = 0
    end_pos: int = 59
    # Layers [layer_start, layers) are plotted; `layers` is an exclusive end,
    # and probe boards are computed for layers 0 .. layers-1.
    layer_start: int = 0
    layers: int = 8
    # Write a PNG into "last_plot/" instead of showing the figure.
    static_image: bool = False
    # Pixel size per board and top margin of the figure.
    size_of_board: int = 225
    margin_t: int = 100
    # "linear" or "flipped"; any other value makes `plot_boards` raise.
    mode: str = "linear"
    # Passed to make_subplots as `horizontal_spacing` and used to scale the
    # figure width.
    # NOTE(review): the notebook sets this to 0.07 before plotting; the
    # default of 20 looks too large for that use — confirm.
    horizontal_spacing: float = 20
    margin_l: int = 100

def get_score_from_resid(resid, layer):
    """Apply the "linear" and "flipped" post-probes of `layer` to `resid`.

    Args:
        resid: 2-D activation tensor indexed as [pos, d_model].
        layer: layer whose probes are used.

    Returns:
        (color_score, flip_score), both indexed as [pos, rows, cols] with the
        row axis flipped. `color_score` is 0.5 + (p[2] - p[1]) / 2 of the
        linear-probe probabilities after options 1 and 2 were swapped at
        every even position; `flip_score` is option 0 of the flipped-probe
        probabilities.
    """
    linear_probe = get_probe(layer, "linear", "post")
    flipped_probe = get_probe(layer, "flipped", "post")
    assert resid.ndim == 2

    pattern = "pos d_model, modes d_model rows cols options -> modes pos rows cols options"
    # Only the first probe mode is used.
    probs = einops.einsum(resid, linear_probe, pattern)[0].softmax(dim=-1)
    flipped_probs = einops.einsum(resid, flipped_probe, pattern)[0].softmax(dim=-1)

    # Convert back to black/white: swap options 1 and 2 at every even position.
    unswapped = probs.clone()
    probs[0::2, :, :, 1] = unswapped[0::2, :, :, 2]
    probs[0::2, :, :, 2] = unswapped[0::2, :, :, 1]

    # Flip both scores along the rows dimension.
    color_score = (0.5 + (probs[..., 2] - probs[..., 1]) / 2).flip(1)
    flip_score = flipped_probs[..., 0].flip(1)
    return color_score, flip_score

def get_boards(input_int : Float[Tensor, "pos"], vis_args : VisualzeBoardArguments, model: HookedTransformer):
    """Run `model` on `input_int` and compute probe boards for every layer.

    For each layer in `range(vis_args.layers)` and each activation enabled in
    `vis_args`, the activation of batch element 0 is scored with
    `get_score_from_resid`.

    Returns:
        (boards, flip_boards): stacked colour and flip scores, indexed as
        [layer, mode, ...], where modes follow the order resid_post,
        resid_pre, attn only, mlp only, ln1, ln2 (enabled ones only).
    """
    _, cache = model.run_with_cache(input_int)

    def _first(name, layer):
        # Activation `name` of `layer` for batch element 0, detached.
        return cache[name, layer][0].detach()

    def _block_delta_without(component, layer):
        # resid_post - <component output> - resid_pre for one layer.
        removed = t.stack([_first(component, layer)]).sum(dim=0)
        return _first("resid_post", layer) - removed - _first("resid_pre", layer)

    # (enabled?, function returning the activation of one layer), in mode order.
    sources = [
        (vis_args.include_resid_post, lambda layer: _first("resid_post", layer)),
        (vis_args.include_pre_resid, lambda layer: _first("resid_pre", layer)),
        (vis_args.include_attn_only, lambda layer: _block_delta_without("mlp_out", layer)),
        (vis_args.include_mlp_only, lambda layer: _block_delta_without("attn_out", layer)),
        (vis_args.include_layer_norm, lambda layer: cache[f"blocks.{layer}.ln1.hook_normalized"][0].detach()),
        (vis_args.include_layer_norm, lambda layer: cache[f"blocks.{layer}.ln2.hook_normalized"][0].detach()),
    ]
    active_sources = [fetch for enabled, fetch in sources if enabled]

    boards = []
    flip_boards = []
    for layer in range(vis_args.layers):
        layer_scores = [get_score_from_resid(fetch(layer), layer) for fetch in active_sources]
        boards.append(t.stack([color for color, _ in layer_scores], dim=0))
        flip_boards.append(t.stack([flip for _, flip in layer_scores], dim=0))
    return t.stack(boards), t.stack(flip_boards)

def plot_boards(label_list: List[str], boards : Float[Tensor, "layers mode pos rows cols"], flip_boards : Float[Tensor, "layers mode pos rows cols"], vis_args: VisualzeBoardArguments):
    """Plot probe boards on a grid: one row per position, one column per (layer, mode).

    Args:
        label_list: label per position, used in the subplot titles.
        boards: colour scores indexed as [layer, mode, pos, row, col].
        flip_boards: flip scores with the same indexing.
        vis_args: selects positions, layers, modes and layout. In "linear"
            mode the colour scores are drawn with the flip scores overlaid as
            shading characters; in "flipped" mode only the flip scores are
            drawn.

    Raises:
        ValueError: if `vis_args.mode` is neither "linear" nor "flipped".
    """
    _, _, _, rows, cols = boards.shape
    print(boards.shape)
    seq_len = vis_args.end_pos - vis_args.start_pos
    # Mode tags in the same order in which `get_boards` stacks the modes.
    modes = []
    if vis_args.include_resid_post:
        modes += ["N"]
    if vis_args.include_pre_resid:
        modes += ["P"]
    if vis_args.include_attn_only:
        modes += ["A"]
    if vis_args.include_mlp_only:
        modes += ["M"]
    if vis_args.include_layer_norm:
        modes += ["Attn"]
        modes += ["MLP"]
    num_layers = vis_args.layers - vis_args.layer_start
    num_cols = num_layers * len(modes)
    # make_subplots consumes titles row by row: position, then layer, then mode.
    subplot_titles = [f"P: {i}, T: {label_list[i]}, L: {j}, {mode}" for i in range(vis_args.start_pos, vis_args.end_pos) for j in range(vis_args.layer_start, vis_args.layers) for mode in modes]
    width = (num_cols * vis_args.size_of_board) * (1 + 2 * vis_args.horizontal_spacing)
    height = vis_args.margin_t + seq_len * vis_args.size_of_board
    vertical_spacing = 70 / height
    fig = make_subplots(rows=seq_len, cols=num_cols, subplot_titles=subplot_titles, vertical_spacing=vertical_spacing, horizontal_spacing=vis_args.horizontal_spacing)
    # The x axis runs over board columns, so its ticks are built from `cols`
    # (they used to be built from `rows`, which is only right for square boards).
    x_ticks = list(range(cols))
    for pos_idx, pos in enumerate(range(vis_args.start_pos, vis_args.end_pos)):
        for layer_idx, layer in enumerate(range(vis_args.layer_start, vis_args.layers)):
            for mode_idx in range(len(modes)):
                if vis_args.mode == "linear":
                    # The shading overlay is only used by this branch, so it
                    # is only computed here.
                    text_data = [[get_color(flip_boards[layer, mode_idx, pos, i, j]) for j in range(cols)] for i in range(rows)]
                    heatmap = go.Heatmap(
                        z=boards[layer, mode_idx, pos].cpu(),
                        text=text_data,
                        x=x_ticks,
                        y=reverse_alpha,
                        hoverongaps=False,
                        zmin=0.0,
                        zmax=1.0,
                        colorscale="RdBu",
                        texttemplate="%{text}",
                        showscale=False,
                    )
                elif vis_args.mode == "flipped":
                    heatmap = go.Heatmap(
                        z=flip_boards[layer, mode_idx, pos].cpu(),
                        x=x_ticks,
                        y=reverse_alpha,
                        hoverongaps=False,
                        zmin=0.0,
                        zmax=1.0,
                        colorscale="Greens",
                    )
                else:
                    raise ValueError("Invalid Mode")
                fig.add_trace(
                    heatmap,
                    row=pos_idx + 1,
                    col=layer_idx * len(modes) + mode_idx + 1,
                )
    fig.layout.update(width=width, height=height, margin_t=vis_args.margin_t, margin_l=vis_args.margin_l, title_text="Probe Results per Position per Layer")
    if vis_args.static_image:
        # Number the image after the PNGs already present; create the
        # directory first so the write cannot fail on a fresh checkout.
        out_dir = Path("last_plot")
        out_dir.mkdir(parents=True, exist_ok=True)
        num_images = len(list(out_dir.glob("*.png")))
        fig.write_image(f'last_plot/last_plot{num_images+1}.png')
    else:
        fig.show()


def visualize_game(input_str, vis_args: VisualzeBoardArguments, model: HookedTransformer):
    """Plot the probe boards of one game.

    Steps: truncate the game to at most 59 moves, run the model and the
    probes on it (`get_boards`), then plot the results (`plot_boards`).

    Args:
        input_str: the game's moves in "string" (board square) encoding.
        vis_args: visualisation options.
        model: model to run.
    """
    max_moves = 59
    if len(input_str) > max_moves:
        input_str = input_str[:max_moves]
    label_list = string_to_label(input_str)
    tokens = t.Tensor(to_int(input_str)).to(t.int32)
    boards, flip_boards = get_boards(tokens, vis_args, model)
    plot_boards(label_list, boards, flip_boards, vis_args)


def label_to_tuple(label):
    """Convert a board label such as "C3" into its (row, col) tuple, e.g. (2, 3).

    This is the inverse of `to_board_label` / `tuple_to_label`.

    Raises:
        ValueError: if the first character is not a row letter from `alpha`
            or the second character is not a digit.
    """
    # str.index raises on an unknown row letter. The previous str.find
    # returned -1 instead, which callers then used as a valid index and
    # silently addressed the last row.
    row = alpha.index(label[0])
    return (row, int(label[1]))

def label_to_string(label):
    """Convert a board label into its square index, computed as row * 8 + col."""
    row, col = label_to_tuple(label)
    return row * 8 + col

def label_to_int(label):
    """Convert a board label to its square index and pass that through `str_to_int`."""
    square_index = label_to_string(label)
    return str_to_int(square_index)



if __name__ == "__main__":
    # Demo setup: positions 0-19, layers 0-5, layer-norm activations,
    # flip scores, written as a static image.
    vis_args = VisualzeBoardArguments()
    vis_args.start_pos = 0
    vis_args.end_pos = 20
    vis_args.layers = 6
    vis_args.include_attn_only = False
    vis_args.include_mlp_only = False
    vis_args.include_layer_norm = True
    vis_args.mode = "flipped"
    vis_args.static_image = True

    # Use the device selected at module level (CUDA when available, else CPU)
    # instead of a hard-coded "cuda", so this also starts without a GPU.
    model = load_model(device)
    _, focus_games_str = get_focus_games()

    clean_input_str = focus_games_str[0][:30]
    # visualize_game(clean_input_str, vis_args, model)
    # print(label_to_int("B3"))
    # print(label_to_string("B3"))
    # print(label_to_tuple("B3"))

# Helper function to collect the requested activations for a chosen number of games
def get_activation(act_names, num_games, start=0, board_seqs_int=board_seqs_int):
    """Collect the named activations of the module-level `model` for a range of games.

    Games `start` .. `start + num_games - 1` are run in batches of at most
    1000, with the last token of every game dropped.

    Args:
        act_names: activation (hook) names to collect.
        num_games: number of games to run.
        start: index of the first game.
        board_seqs_int: games as token ids; the default is the module-level
            tensor as it was when this function was defined.

    Returns:
        Dict mapping each activation name to the activations of all requested
        games, concatenated along the first axis.
    """
    # TODO: If this takes too long, add a filter step.
    act_name_results = {act_name : [] for act_name in act_names}
    inference_size = 1000
    stop = start + num_games
    for batch in range(start, stop, inference_size):
        # Never run past the last requested game: a request for 50 games used
        # to run a full batch of 1000 and discard the rest afterwards.
        batch_end = min(batch + inference_size, stop)
        with t.inference_mode():
            _, cache = model.run_with_cache(
                board_seqs_int[batch:batch_end, :-1].to(device),
                return_type=None,
                names_filter=lambda name: name in act_names,
            )
        for act_name in act_names:
            act_name_results[act_name].append(cache[act_name])
    return {
        act_name: t.cat(chunks, dim=0).detach()[:num_games]
        for act_name, chunks in act_name_results.items()
    }

    

In [68]:
# board_seqs_int_test = t.load("data/board_seqs_int_valid.pth")

In [70]:
# Visualise positions 14-19 of game 0 for layers 1-2, using the layer-norm
# activations (ln1 and ln2) instead of resid_post.
game = 0
vis_args = VisualzeBoardArguments()
vis_args.start_pos = 14
vis_args.end_pos = 20
vis_args.layer_start = 1
vis_args.layers = 3
# vis_args.static_image = True
vis_args.include_resid_post = False
vis_args.include_layer_norm = True
# Overrides the class default of 20.
vis_args.horizontal_spacing = 0.07
# Write a PNG into "last_plot/" instead of showing the figure.
vis_args.static_image = True
# visualize_game(to_string(board_seqs_int_test[game]), vis_args, model)
visualize_game(board_seqs_string[game], vis_args, model)

torch.Size([3, 2, 59, 8, 8])


In [3]:
focus_logits, focus_cache = get_focus_logits_and_cache()

In [4]:
# Attention output of layer 0 for the focus games.
attn_out = focus_cache["attn_out", 0]
# Not the last expression of the cell, so this shape is not displayed.
attn_out.shape
import pickle

# load data/neurons_in_famility/flipping_neurons.pkl using pickle to a dictionary
# NOTE: pickle.load can execute arbitrary code; only load files from a
# trusted source.
with open('data/neurons_in_famility/flipping_neurons.pkl', 'rb') as f:
    flipping_neurons = pickle.load(f)

In [5]:
# NOTE(review): `layer` is not defined in this cell; it relies on a global
# left over from an earlier loop — confirm which layer is intended.
# NOTE(review): despite its name, `mlp_out` holds the "mlp_post" activations.
mlp_out = focus_cache["mlp_post", layer]
W_out = model.W_out
# Display both shapes as the cell output.
W_out.shape, mlp_out.shape

(torch.Size([8, 2048, 512]), torch.Size([50, 59, 2048]))

In [6]:
model.b_out.shape

torch.Size([8, 512])

In [None]:
# Project the MINE direction of the previous layer's linear probe through
# the OV matrices of each layer, summing over heads.
# assumes model.OV.AB is indexed as [layer, head, d_model_in, d_model_out] — TODO confirm
OV = model.OV.AB
for layer in range(1, 8):
    OV_layer = OV[layer]
    # resid = focus_cache[f"blocks.{layer}.ln1.hook_normalized", layer][0].detach()
    probe = get_probe(layer-1, "linear", "post")[0, :, :, :, MINE]
    # Use this layer's OV matrices: the einsum pattern names three axes
    # (head, d_model_in, d_model_out), which fits `OV_layer` but not the
    # full `OV` with its extra leading layer axis.
    output = einops.einsum(probe, OV_layer, "d_model_in rows cols, head d_model_in d_model_out -> rows cols d_model_out")


In [17]:
def get_w_in(
    model: HookedTransformer,
    layer: int,
    neuron: int,
    normalize: bool = False,
) -> Float[Tensor, "d_model"]:
    '''
    Returns a detached copy of the input weights of one MLP neuron,
    i.e. `model.W_in[layer, :, neuron]`.

    If normalize is True, the weight vector is scaled to unit norm.
    '''
    weights = model.W_in[layer, :, neuron].detach().clone()
    if not normalize:
        return weights
    # In-place division is safe here because `weights` is a private clone.
    weights /= weights.norm(dim=0, keepdim=True)
    return weights


def get_w_out(
    model: HookedTransformer,
    layer: int,
    neuron: int,
    normalize: bool = False,
) -> Float[Tensor, "d_model"]:
    '''
    Returns a detached copy of the output weights of one MLP neuron,
    i.e. `model.W_out[layer, neuron, :]`.

    If normalize is True, the weight vector is scaled to unit norm.
    '''
    weights = model.W_out[layer, neuron, :].detach().clone()
    if not normalize:
        return weights
    # In-place division is safe here because `weights` is a private clone.
    weights /= weights.norm(dim=0, keepdim=True)
    return weights

def get_similiarity(neuron : Int, layer : Int, tiles : List[Tuple[str, str, str, str]], metric = "avg"):
    """Measure how well a neuron's weights align with probe directions.

    Args:
        neuron: neuron index in the MLP of `layer`.
        layer: layer index.
        tiles: non-empty list of (label, probe_type, feature_str, in_or_out).
            `label` is a board label such as "C3", `feature_str` a direction
            name known to `get_direction_int`. With in_or_out == "in" the
            neuron's input weights and the "mid" probe are used, otherwise
            its output weights and the "post" probe.
        metric: "avg" returns the mean over tiles of the dot product between
            the unit-norm probe direction and the unit-norm neuron weights
            (values for the "empty" feature are divided by 3). Any other
            value returns the dot product between the normalised sum of all
            directions and the weights selected by the LAST tile.

    Raises:
        ValueError: if `tiles` is empty.
    """
    if not tiles:
        # Previously this failed later on the unbound weight vector.
        raise ValueError("tiles must contain at least one entry")
    avg_similiarity = 0
    # Accumulated from the probe directions themselves, so no model width or
    # device has to be hard-coded.
    direction_all = None
    for label, probe_type, feature_str, in_or_out in tiles:
        y, x = label_to_tuple(label)
        feature = get_direction_int(feature_str)
        if in_or_out == "in":
            probe_module = "mid"
            w = get_w_in(model, layer, neuron, normalize=True)
        else:
            probe_module = "post"
            w = get_w_out(model, layer, neuron, normalize=True)
        probe = get_probe(layer, probe_type=probe_type, probe_module=probe_module)
        # First probe mode, direction of `feature` at square (y, x).
        direction = probe[0, :, y, x, feature]
        direction = direction / direction.norm()
        direction_all = direction if direction_all is None else direction_all + direction
        similiarity = einops.einsum(direction, w, "d_model, d_model ->").item()
        if feature_str == "empty":
            similiarity = similiarity / 3
        avg_similiarity += similiarity
    if metric == "avg":
        return avg_similiarity / len(tiles)
    direction_all = direction_all / direction_all.norm()
    return einops.einsum(direction_all, w, "d_model, d_model ->").item()

In [29]:
# For every layer 1-7 and every board square, rank all MLP neurons by how
# well their output weights align with the "flipped" direction of the post
# probe at that square, and keep the top `how_many` per square.
labels = []
how_many = 10

# All 64 board labels in row-major order.
for row in range(8):
    for col in range(8):
        labels.append(tuple_to_label((row, col)))

# layer -> de-duplicated list of selected neuron indices
flipping_neurons2 = {}

for layer in range(1, 8):
    print(f"Layer {layer}")
    # Created once per layer, but every neuron key is overwritten for each
    # label below, so no values carry over between labels.
    similiarities = dict()
    flipping_neurons2[layer] = []
    for label in tqdm(labels):
        tiles_out = [
            (label, "flipped", "flipped", "out")
        ]

        # 4 * 512 = 2048 neurons per MLP layer.
        for neuron1 in range(4 * 512):
            similiarity_out = get_similiarity(neuron1, layer, tiles_out)
            similiarities[neuron1] = similiarity_out

        # Highest similarity first.
        sorted_similiarities = sorted(similiarities.items(), key=lambda x: x[1], reverse=True)
        neurons_list = [neuron for neuron, similiarity in sorted_similiarities[:how_many]]
        # Only referenced by the commented-out debug print below.
        similiarities_list = [round(similiarity, 2) for neuron, similiarity in sorted_similiarities[:how_many]]
        flipping_neurons2[layer] += neurons_list
        # print(f"Layer: {layer}, Tile: {label}, Neurons: {neurons}, Similiarities: {similiarities}")
    # A neuron can be in the top list of several squares; keep it once.
    flipping_neurons2[layer] = list(set(flipping_neurons2[layer]))

Layer 1


  0%|          | 0/64 [00:00<?, ?it/s]

Layer 2


  0%|          | 0/64 [00:00<?, ?it/s]

Layer 3


  0%|          | 0/64 [00:00<?, ?it/s]

Layer 4


  0%|          | 0/64 [00:00<?, ?it/s]

Layer 5


  0%|          | 0/64 [00:00<?, ?it/s]

Layer 6


  0%|          | 0/64 [00:00<?, ?it/s]

Layer 7


  0%|          | 0/64 [00:00<?, ?it/s]

In [33]:
# TODO: There is room to experiment a bit more here ...
# TODO: Only look at a set of Flipping Neurons

attn_out = focus_cache["attn_out", 0]
attn_out.shape
import pickle

# load data/neurons_in_famility/flipping_neurons.pkl using pickle to a dictionary
with open('data/neurons_in_famility/flipping_neurons.pkl', 'rb') as f:
    flipping_neurons = pickle.load(f)

# flipping_neurons = flipping_neurons2

# check how much percent of Flipped Logits come from MLP vs Attention
avg_flipped_logit_percent_from_mlp : Float[Tensor, "layers"] = t.zeros((8, ))
flipped_logit_from_mlp = t.zeros((8, ))
flipped_logit_from_attn = t.zeros((8, ))
for layer in range(1, 8):
    mlp_post = focus_cache["mlp_post", layer]
    mlp_post_mean = mlp_post.mean(dim=0)
    mlp_post_mean = einops.repeat(mlp_post_mean, "pos neurons -> game pos neurons", game=mlp_post.shape[0]).clone()
    # get flipping neurons of the layer
    neurons = []
    # neurons += flipping_neurons[layer]
    for rule_name, neuron_list in flipping_neurons[layer].items():
        neurons += neuron_list
    neurons = list(set(neurons))
    neurons.sort()
    # neurons = list(range(2048))
    # print(neurons[:10])
    print(len(neurons))
    probe = get_probe(layer, "flipped", "post")[0]
    attn_out = focus_cache["attn_out", layer]
    mlp_out = focus_cache["mlp_out", layer]
    mlp_post = focus_cache["mlp_post", layer]
    W_out = model.W_out[layer].detach()
    b_out = model.b_out[layer].detach()
    # mlp_post = mlp_post[:, :, neurons]
    # mlp_post = t.max(mlp_post, t.zeros_like(mlp_post))
    # W_out = W_out[neurons]
    # print(mlp_post.shape, mlp_post_mean.shape)
    mlp_post_mean[:, :, neurons] = mlp_post[:, :, neurons]
    mlp_post = mlp_post_mean
    mlp_out = einops.einsum(mlp_post, W_out, "game pos neurons, neurons d_model -> game pos d_model") + b_out
    # print(mlp_out.shape)
    resid_post = focus_cache["resid_post", layer]
    logits_from_attn = einops.einsum(attn_out, probe, "game pos d_model, d_model rows cols options -> game pos rows cols options")
    logits_from_mlp = einops.einsum(mlp_out, probe, "game pos d_model, d_model rows cols options -> game pos rows cols options")
    logits_from_resid_post = einops.einsum(resid_post, probe, "game pos d_model, d_model rows cols options -> game pos rows cols options")
    tiles_flipped = logits_from_resid_post[:, :, :, :, 0] > logits_from_resid_post[:, :, :, :, 1]
    # tiles_flipped = t.ones_like(tiles_flipped)
    logits_from_attn = logits_from_attn[tiles_flipped]
    logits_from_mlp = logits_from_mlp[tiles_flipped]
    # print(logits_from_attn.shape, logits_from_mlp.shape)
    sum_mlp_flipped = logits_from_mlp[:, 0].sum()
    sum_attn_flipped = logits_from_attn[:, 0].sum()
    # sum_mlp_not_flipped = logits_from_mlp[:, :, :, :, 1].sum()
    # sum_attn_not_flipped = logits_from_attn[:, :, :, :, 1].sum()
    # print(f"Layer: {layer}, MLP Flipped: {sum_mlp_flipped}, MLP Not Flipped: {sum_mlp_not_flipped}, Attn Flipped: {sum_attn_flipped}, Attn Not Flipped: {sum_attn_not_flipped}")
    # print(sum_mlp_flipped, sum_attn_flipped)
    avg_flipped_logit_percent_from_mlp[layer] = sum_mlp_flipped / (sum_mlp_flipped + sum_attn_flipped)
    flipped_logit_from_mlp[layer] = sum_mlp_flipped
    flipped_logit_from_attn[layer] = sum_attn_flipped

# avg_flipped_logit_percent_from_mlp
print(flipped_logit_from_mlp)
print(flipped_logit_from_attn)

564
544
540
454
204
127
258
tensor([     0.0000,   5157.6152,   4808.4189,   5780.5054,   3884.3809,
         -3282.2759, -12740.1523,   -617.6957])
tensor([     0.0000,  -1889.1267,  -1515.0182,  -1069.9028,  -1671.4114,
         -1823.5784,  -3657.9995, -12940.5439])


In [None]:
# Okay, just using single neurons doesn't seem to work ... I think all the negative neurons may be relevant.
# Also, I don't have the single-flipped neurons and the very negative flipped-to-flipped neurons.
#
# The string below is a scratch record of printed per-layer tensors (index 0 = layer 0, unused).
# The "PRED (Avg + 600 extra Neurons ...)" row equals the flipped_logit_from_mlp tensor
# printed by the attribution cell above; "REAL" is presumably the unpatched MLP value - TODO confirm.
'''
REAL: tensor([     0.0000,   6352.9609,   2439.9839,   1518.7620,   -814.8992,
         -8579.4238, -17324.6836,   -250.1274])
PRED: Top 20 Neurons per flipped : tensor([     0.0000,   4387.9111,   5138.4956,   4367.2275,   3122.8298,
         -2957.8474, -13642.4121,  -1687.0056])
PRED (Avg + 600 extra Neurons ...) tensor([     0.0000,   5157.6152,   4808.4189,   5780.5054,   3884.3809,
         -3282.2759, -12740.1523,   -617.6957])

[0.0000,   6352.9609,   2439.9839,   1518.7620,   -814.8992, -8579.4238, -17324.6836,   -250.1274]
[0.0000,   4387.9111,   5138.4956,   4367.2275,   3122.8298, -2957.8474, -13642.4121,  -1687.0056]
[0.0000,   5157.6152,   4808.4189,   5780.5054,   3884.3809, -3282.2759, -12740.1523,   -617.6957]
'''

In [8]:
# Pasted debug printout kept as a string for reference (the cell has no other effect).
# Each group of four lines appears to come from the commented-out prints in the
# attribution cell: neurons[:10], len(neurons), mlp_out.shape, and the pair
# (sum_mlp_flipped, sum_attn_flipped) - one group per layer 1-7. The second value of
# each pair matches the flipped_logit_from_attn entries printed by that cell.
'''ACCESIBLE
5, 7, 10, 11, 12, 14, 17, 19, 20, 23]
662
torch.Size([50, 59, 512])
tensor(473.8499, device='cuda:0') tensor(-1889.1267, device='cuda:0')
[10, 12, 13, 18, 23, 29, 30, 31, 33, 36]
634
torch.Size([50, 59, 512])
tensor(2472.4529, device='cuda:0') tensor(-1515.0182, device='cuda:0')
[10, 13, 17, 20, 21, 22, 24, 25, 26, 30]
630
torch.Size([50, 59, 512])
tensor(4428.3608, device='cuda:0') tensor(-1069.9028, device='cuda:0')
[2, 3, 6, 8, 9, 20, 21, 22, 23, 28]
536
torch.Size([50, 59, 512])
tensor(7608.5562, device='cuda:0') tensor(-1671.4114, device='cuda:0')
[2, 3, 8, 12, 19, 22, 34, 37, 55, 67]
340
torch.Size([50, 59, 512])
tensor(5427.8564, device='cuda:0') tensor(-1823.5784, device='cuda:0')
[1, 23, 60, 62, 64, 86, 88, 89, 92, 100]
219
torch.Size([50, 59, 512])
tensor(4537.5273, device='cuda:0') tensor(-3657.9995, device='cuda:0')
[3, 4, 13, 19, 23, 42, 55, 59, 66, 75]
284
torch.Size([50, 59, 512])
tensor(-2168.6604, device='cuda:0') tensor(-12940.5439, device='cuda:0')
'''

"ACCESIBLE\n5, 7, 10, 11, 12, 14, 17, 19, 20, 23]\n662\ntorch.Size([50, 59, 512])\ntensor(473.8499, device='cuda:0') tensor(-1889.1267, device='cuda:0')\n[10, 12, 13, 18, 23, 29, 30, 31, 33, 36]\n634\ntorch.Size([50, 59, 512])\ntensor(2472.4529, device='cuda:0') tensor(-1515.0182, device='cuda:0')\n[10, 13, 17, 20, 21, 22, 24, 25, 26, 30]\n630\ntorch.Size([50, 59, 512])\ntensor(4428.3608, device='cuda:0') tensor(-1069.9028, device='cuda:0')\n[2, 3, 6, 8, 9, 20, 21, 22, 23, 28]\n536\ntorch.Size([50, 59, 512])\ntensor(7608.5562, device='cuda:0') tensor(-1671.4114, device='cuda:0')\n[2, 3, 8, 12, 19, 22, 34, 37, 55, 67]\n340\ntorch.Size([50, 59, 512])\ntensor(5427.8564, device='cuda:0') tensor(-1823.5784, device='cuda:0')\n[1, 23, 60, 62, 64, 86, 88, 89, 92, 100]\n219\ntorch.Size([50, 59, 512])\ntensor(4537.5273, device='cuda:0') tensor(-3657.9995, device='cuda:0')\n[3, 4, 13, 19, 23, 42, 55, 59, 66, 75]\n284\ntorch.Size([50, 59, 512])\ntensor(-2168.6604, device='cuda:0') tensor(-12940.54

* Layer 0 - 4 macht MLP Flipped und Attn Not Flipped. Ab Layer 5 ist es ausgeglichen bzw. anders rum
* weirdly: Wenn ich nur die flipped neurons benutze, dann funktioniert es bei Layer 1 nicht dafür aber bei 2 bis 6
    * Jetzt, wo ich das Threshold runtergesetzt habe, geht es gut ...
    * ABER wenn ich alle Neurons hinzunehme, ist es nochmal viel höher bei Layer 1 bis 3 und ab Layer 4 geringer ...

* Also Neuronen hinzu zu nehmen kann die Logits verringern ... Liegt glaube ich an negativen Aktivierungen ...
    * Nope! Wenn ich negative Aktivierungen hochrunde dann gehen die Logits in den Bach

* Ich glaube, der MLP-Layer macht so eine Background-Strahlung von ein bisschen positivem Flipped
    * Wie kann ich das nutzen ...

* Mission ist es eine Art Average Logit zu nehmen plus dann die Logits von einer Handvoll Neurons je nach 
    * Das geht gut: Ich kalkuliere das Average MLP out (ohne Bias) und kalkuliere das Average mlp_out von jedem Neuron, dann ziehe ich die Average mlp_out der Neurons ab und nehme stattdessen die richtigen Aktivierungen rein (dann nehme ich aber die Aktivierungen aus dem Nichts ...). Wild wäre vorberechnet die Average-Aktivierung, wenn halt eine bestimmte Regel aktiv ist ...

* Grober Plan: Gucken welche Regel Aktiv ist, dann Handvoll Neurons raussuchen, dann neue Neuron activations rein pluggen und dann bin ich glücklich ...
    * Was ist, wenn es Neurons gibt, die Flipped schreiben, ohne so eine Regel zu befolgen ... Das würde alles zerstören

* Der Average Background sollte auf jeden Fall Position abhängig sein
* OKAY! Also ich habe mir eine Menge Neurons angeschaut die Tile Flipped schreiben und es gibt ein Paar Arten
    * Normale Flipped
        * Unterschied: auch Tile MINE und nächstes Tile Yours sind relevant
    * Weird (L1N23)
    * Aktiviert wenn von verschiedenen Seiten geflipped
    * Aktiviert wenn besonders start nicht geflipped ...
* TODO: Ich muss überlegen ob das meinen Plan zerstört ...
* Ich glaube aber das meiste sollte schon noch von den Dingern kommen ...

* Neurons zu random einsetzen klappt nicht!
    * Entweder: Falsche Neurons, ich sollte die nehmen, die am meisten Tile Flipped schreiben ...
    * Oder: Viele Neurons, selbst negative sind relevant ...

* Ich habe davor voll die Falschen Schlüsse gezogen, (Average nehmen und dann bestimmte Neurons rein )