# Setup

In [1]:
import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
sys.path.append(chapter_dir + f"{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
# from generate_patches import generate_patch
from pprint import pprint
from utils import plot_game
from training_utils import get_state_stack_num_flipped
from utils import plot_probe_outputs
from utils import seq_to_state_stack
from utils import VisualzeBoardArguments
from utils import visualize_game

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import networkx as nx
from utils import visualize_game
from utils import VisualzeBoardArguments
from utils import label_to_tuple
from utils import label_to_int
from utils import label_to_string

# import part6_othellogpt.tests as tests

t.manual_seed(42)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (ie the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState,
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

num_games, length_of_game = board_seqs_int.shape

# Define possible indices (excluding the four center squares)
stoi_indices = [i for i in range(64) if i not in [27, 28, 35, 36]]

# Define our rows, and the function that converts an index into a (row, column) label, e.g. `E2`
alpha = "ABCDEFGH"

def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"

# Get our list of board labels
board_labels = list(map(to_board_label, stoi_indices))
full_board_labels = list(map(to_board_label, range(64)))

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    kwargs = {
        "y": [i for i in alpha],
        "x": [str(i) for i in range(8)],
        "color_continuous_scale": "RdBu" if diverging_scale else "Blues",
        "color_continuous_midpoint": 0. if diverging_scale else None,
        "aspect": "equal",
        **kwargs
    }
    imshow(state, **kwargs)

# start = 30000
start = 0
num_games = 200
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    out = t.zeros((num_classes,), dtype=t.float32)
    out[list_of_ints] = 1.
    return out

focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

# linear_probe2 = t.load("probes/linear/resid_6_linear.pth")

rows = 8
cols = 8
options = 3
# assert linear_probe2.shape == (1, cfg.d_model, rows, cols, options)

black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

'''LAYER = 6
game_index = 0
move = 29'''

# BLANK1 = 0
# BLACK = 1
# WHITE = -1

# # MINE = 0
# # YOURS = 1
# # BLANK2 = 2

# EMPTY = 0
# MINE = 1
# YOURS = 2

# FLIPPED = 0
# NOT_FLIPPED = 1
from utils import *



interpretability
focus states: (200, 60, 8, 8)
focus_valid_moves (200, 60, 64)


In [2]:
device

device(type='cuda')

# Code
- Two possibilities
    - Input: Neuron, then visualize what it does
        - Look at output weights ... what direction feature is it writing
        - How much Variance is explained by some subspace e.g. 
        - Do fancy Neuron Graph shit
    - Find important Neurons
        - Throw some metric at all the neurons (like variance or something better)

In [3]:
from utils import visualize_game
from utils import VisualzeBoardArguments
vis_args = VisualzeBoardArguments()
vis_args.start_pos = 0
vis_args.end_pos = 10

# visualize_game(focus_games_string[0], vis_args, model)

In [4]:
from utils import get_probe
from utils import plot_boards_general
from utils import get_direction_int
from utils import get_direction_str

In [5]:
def get_w_in(
    model: HookedTransformer,
    layer: int,
    neuron: int,
    normalize: bool = False,
) -> Float[Tensor, "d_model"]:
    '''
    Returns the input weights for the given neuron.

    If normalize is True, the weight is normalized to unit norm.
    '''
    # SOLUTION
    w_in = model.W_in[layer, :, neuron].detach().clone()
    if normalize: w_in /= w_in.norm(dim=0, keepdim=True)
    return w_in


def get_w_out(
    model: HookedTransformer,
    layer: int,
    neuron: int,
    normalize: bool = False,
) -> Float[Tensor, "d_model"]:
    '''
    Returns the output weights for the given neuron.

    If normalize is True, the weight is normalized to unit norm.
    '''
    # SOLUTION
    w_out = model.W_out[layer, neuron, :].detach().clone()
    if normalize: w_out /= w_out.norm(dim=0, keepdim=True)
    return  w_out


def calculate_neuron_input_weights(
    model: HookedTransformer,
    probe: Float[Tensor, "modes d_model row col options"],
    layer: int,
    neuron: int,
    probe_option: int,
) -> Float[Tensor, "rows cols"]:
    '''
    Returns tensor of the input weights for the given neuron, at each square on the board,
    projected along the corresponding probe directions.

    Assume probe directions are normalized. You should also normalize the model weights.
    '''
    # SOLUTION
    w_in = get_w_in(model, layer, neuron, normalize=True)

    return einops.einsum(
        w_in, probe,
        "d_model, modes d_model row col options -> modes row col options",
    )[0, :, :, probe_option]


def calculate_neuron_output_weights(
    model: HookedTransformer,
    probe: Float[Tensor, "modes d_model row col options"],
    layer: int,
    neuron: int,
    probe_option: int,
) -> Float[Tensor, "rows cols"]:
    '''
    Returns tensor of the output weights for the given neuron, at each square on the board,
    projected along the corresponding probe directions.

    Assume probe directions are normalized. You should also normalize the model weights.
    '''
    # SOLUTION
    w_out = get_w_out(model, layer, neuron, normalize=True)

    return einops.einsum(
        w_out, probe,
        "d_model, modes d_model row col options -> modes row col options",
    )[0, :, :, probe_option]


layer = 5
neuron = 1393

linear_probe = get_probe(layer, "linear", "post")

w_in_L5N1393_blank = calculate_neuron_input_weights(model, linear_probe, layer, neuron, EMPTY)
w_in_L5N1393_my = calculate_neuron_input_weights(model, linear_probe, layer, neuron, MINE)
w_in_L5N1393_yours = calculate_neuron_input_weights(model, linear_probe, layer, neuron, YOURS)

'''imshow(
    t.stack([w_in_L5N1393_blank, w_in_L5N1393_my, w_in_L5N1393_yours]),
    facet_col=0,
    y=[i for i in "ABCDEFGH"],
    title=f"Input weights in terms of the probe for neuron L{layer}N{neuron}",
    facet_labels=["Blank In", "My In"],
    width=750,
)'''

'''plot_boards_general(
    x_labels = [f"Probe {i}" for i in range(3)],
    y_labels = [f"Input Weights for L{layer}N{neuron}"],
    boards = t.stack([w_in_L5N1393_blank, w_in_L5N1393_my, w_in_L5N1393_yours]).unsqueeze(dim=1),
    size_of_board= 400,
)'''

'plot_boards_general(\n    x_labels = [f"Probe {i}" for i in range(3)],\n    y_labels = [f"Input Weights for L{layer}N{neuron}"],\n    boards = t.stack([w_in_L5N1393_blank, w_in_L5N1393_my, w_in_L5N1393_yours]).unsqueeze(dim=1),\n    size_of_board= 400,\n)'

In [6]:
from utils import probe_directions_list

In [7]:
def get_fraction_of_variance_from_neuron_explained_by_probe(
        neuron : Int,
        layer : Int,
        in_out : str,
        tiles : List[Tuple[str, str, str]], # = probe_directions_list,
    ) -> Float:
    # assert type(tiles) == dict
    # TODO: Update asserts and add option for all tiles
    # if not all_tiles:
    # assert all([type(v) == dict and all ([type(v2) == list for _, v2 in v.items()]) for _, v in probe_names_and_directions.items()])
    # else:
    #     assert all([type(v) == list for _, v in probe_names_and_directions.items()])
    if in_out == "in":
        neuron_w = get_w_in(model, layer, neuron, normalize=True)
    else:
        neuron_w = get_w_out(model, layer, neuron, normalize=True)
    probes = []
    # probes_out = []
    for tile_label, probe_name, direction_name in tiles:
        tile_tuple = label_to_tuple(tile_label)
        y, x = tile_tuple
        if in_out == "in":
            probe = get_probe(layer, probe_name, "mid")
        else:
            probe = get_probe(layer, probe_name, "post")
        direction_int = get_direction_int(direction_name)
        probe_direction = probe[0, :, y, x, direction_int]
        probes.append(probe_direction)

    probes = t.stack(probes, dim=-1)
    U, S, Vh = t.svd(
        probes
    )
    return ((neuron_w @ U)[neuron_w @ U > 0]).norm().item()**2

# get_fraction_of_variance_from_neuron_explained_by_probe(neuron, layer, {"linear" : ["empty"], "flipped" : ["flipped"]})
tiles = [
    ("E4", "linear", "empty"),
    ("E4", "flipped", "flipped"),
]
get_fraction_of_variance_from_neuron_explained_by_probe(neuron, layer, "out", tiles)

# 0.2573759276456258, 0.11356346447632859

0.0

In [8]:
from utils import plot_boards_general

In [9]:
# TODO: I trained the linear probe wrong, I need to do it again ...

In [10]:
from utils import probes
from utils import probe_directions
from collections import defaultdict
from utils import get_short_cut

In [52]:
def kurtosis(tensor: Tensor, reduced_axes, fisher=True):
    """
    Computes the kurtosis of a tensor over specified dimensions.
    """
    return (((tensor - tensor.mean(dim=reduced_axes, keepdim=True)) / tensor.std(dim=reduced_axes, keepdim=True))**4).mean(dim=reduced_axes, keepdim=False) - fisher*3

def plot_neuron_weights(neurons : Float[Tensor, "d_mlp"], layer : Int, title : str, probe_names_and_directions : Dict[str, List[str]] = probe_directions_list, save=False):
    direction_dict = defaultdict(list)
    for neuron in neurons:
        neuron = neuron.item()
        for in_out in ["in", "out"]:
            for probe_name in probe_names_and_directions:
                for direction_str in probe_names_and_directions[probe_name]:
                    if in_out == "in":
                        probe_module = "mid"
                        probe = get_probe(layer = layer, probe_type = probe_name, probe_module = probe_module).clone()
                    else:
                        probe_module = "post"
                        probe = get_probe(layer = layer, probe_type = probe_name, probe_module = probe_module).clone()
                    if probe.isnan().sum().item() > 0:
                        print(f"Probe {probe_name} in is nan")
                        continue
                    probe_normalized = probe / probe.norm(dim=1, keepdim=True)
                    direction_int = probe_directions[probe_name][direction_str]
                    if in_out == "in":
                        neuron_weights = calculate_neuron_input_weights(model, probe_normalized, layer, neuron, direction_int)
                    else:
                        neuron_weights = calculate_neuron_output_weights(model, probe_normalized, layer, neuron, direction_int)
                    direction_dict[f"{get_short_cut(probe_name)}_{get_short_cut(direction_str)}_{in_out}"].append(neuron_weights)

    for direction, weights in direction_dict.items():
        direction_dict[direction] = t.stack(weights)

    boards = t.stack(list(direction_dict.values()))
    plot_boards_general(
        x_labels = list(direction_dict.keys()),
        y_labels = [f"N{i.item()}" for i in top_layer_3_neurons],
        boards = boards,
        title_text = title,
        save=save,
    )

layer = 2
top_layer_3_neurons = einops.reduce(focus_cache["post", layer][:, :], "game move neuron -> neuron", reduction=kurtosis).argsort(descending=True)[:1]
top_layer_3_neurons = focus_cache["post", layer][:, 3:-3].std(dim=[0, 1]).argsort(descending=True)[:1]

# top_layer_3_neurons[0] = 368
# top_layer_3_neurons[1] = 486
# top_layer_3_neurons[2] = 824
# top_layer_3_neurons[0] = 421#
neuron = 877
# top_layer_3_neurons[0] = neuron 
# top_layer_3_neurons[0] = 882
# top_layer_3_neurons[1] = 460
# top_layer_3_neurons[0] = 877
# top_layer_3_neurons[1] = 460
# top_layer_3_neurons = top_layer_3_neurons[:5]

# top_layer_3_neurons = t.arange(0, 6)

"""
Layer: 1, Tile: D3, Neurons: [894, 23, 1411, 441, 1412, 1417, 969, 592, 953, 304], Similiarities: [0.11, 0.08, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
Layer: 1, Tile: B3, Neurons: [1157, 496, 377, 1653, 778, 23, 218, 1365, 513, 900], Similiarities: [0.11, 0.1, 0.1, 0.09, 0.08, 0.07, 0.07, 0.06, 0.06, 0.06]
Layer: 1, Tile: A3, Neurons: [12, 1401, 23, 1893, 819, 531, 43, 528, 268, 142], Similiarities: [0.1, 0.08, 0.08, 0.07, 0.07, 0.06, 0.06, 0.05, 0.04, 0.03]
Layer: 2, Tile: D3, Neurons: [171, 740, 486, 161, 269, 819, 863, 1138, 46, 1198], Similiarities: [0.2, 0.08, 0.07, 0.05, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04]
Layer: 2, Tile: B3, Neurons: [1922, 710, 1022, 473, 613, 1884, 1106, 745, 234, 2010], Similiarities: [0.19, 0.12, 0.11, 0.09, 0.08, 0.08, 0.07, 0.07, 0.06, 0.06]
Layer: 2, Tile: A3, Neurons: [637, 225, 914, 11, 920, 514, 376, 111, 1768, 193], Similiarities: [0.1, 0.09, 0.08, 0.06, 0.03, 0.02, 0.02, 0.02, 0.02, 0.01]
Layer: 3, Tile: D3, Neurons: [850, 1800, 1341, 455, 253, 1245, 1023, 732, 390, 163], Similiarities: [0.2, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04]
Layer: 3, Tile: B3, Neurons: [1931, 853, 1966, 1012, 71, 248, 20, 1551, 228, 594], Similiarities: [0.19, 0.13, 0.1, 0.09, 0.08, 0.08, 0.08, 0.06, 0.06, 0.05]
Layer: 3, Tile: A3, Neurons: [711, 1181, 554, 895, 1409, 1057, 1845, 80, 1590, 1607], Similiarities: [0.04, 0.02, 0.02, 0.02, 0.02, 0.02, 0.01, 0.01, 0.01, 0.01]
Layer: 4, Tile: D3, Neurons: [1742, 1749, 1316, 538, 902, 1645, 588, 1794, 178, 625], Similiarities: [0.2, 0.07, 0.06, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04]
Layer: 4, Tile: B3, Neurons: [125, 798, 1329, 1539, 1461, 1609, 144, 988, 418, 600], Similiarities: [0.2, 0.13, 0.13, 0.1, 0.1, 0.09, 0.07, 0.06, 0.04, 0.03]
Layer: 4, Tile: A3, Neurons: [1155, 1670, 1813, 861, 1613, 1444, 738, 373, 260, 470], Similiarities: [0.04, 0.03, 0.02, 0.02, 0.02, 0.01, 0.01, 0.01, 0.01, 0.01]
Layer: 5, Tile: D3, Neurons: [1902, 402, 1093, 1896, 301, 1904, 709, 1946, 1910, 713], Similiarities: [0.25, 0.09, 0.08, 0.07, 0.06, 0.06, 0.05, 0.05, 0.05, 0.05]
Layer: 5, Tile: B3, Neurons: [193, 866, 1303, 332, 581, 1156, 74, 855, 1275, 526], Similiarities: [0.17, 0.12, 0.09, 0.08, 0.07, 0.06, 0.04, 0.04, 0.03, 0.02]
Layer: 5, Tile: A3, Neurons: [301, 1855, 185, 713, 1172, 850, 1247, 1826, 292, 1244], Similiarities: [0.13, 0.09, 0.07, 0.06, 0.05, 0.03, 0.01, 0.01, 0.01, 0.01]
"""
layer = 2
top_layer_3_neurons = t.Tensor(np.array([171, 740, 486, 161, 269, 819, 863, 1138, 46, 1198])).to(dtype=t.long)


probe_directions_new = {
    "linear" : ["empty", "mine", "yours"],
    "flipped" : ["flipped"],
    "placed" : ["placed"],
}
# title = f"Top 10 neurons in layer {layer} by kurtosis"
plot_neuron_weights(top_layer_3_neurons, layer, title=f"Neuon Weights of L{layer}_N{neuron} Projected to different Linear Probes", probe_names_and_directions=probe_directions_new, save=True)

In [40]:
from interpreting_neurons_utils import plot_max_activations_of_neuron

In [50]:
neuron = 23
plot_max_activations_of_neuron(focus_cache, focus_games_string, layer, neuron, random=False)

Top Game: 
tensor([ 75, 125, 125, 171, 173,  57, 171,  12], device='cuda:0')
Top Move: 
tensor([ 9, 22, 21, 11, 21, 13, 13,  6], device='cuda:0')
Activation values: 
tensor([ 1.8707e+00, -5.4178e-05, -1.4269e-04, -1.7239e-04, -1.9858e-04,
        -2.1395e-04, -2.1923e-04, -2.2386e-04], device='cuda:0')
Game: 75, Pos: 9, Activation: 1.8707242012023926
torch.Size([6, 2, 59, 8, 8])


Game: 125, Pos: 22, Activation: -5.417780630523339e-05
torch.Size([6, 2, 59, 8, 8])


Game: 125, Pos: 21, Activation: -0.00014268509403336793
torch.Size([6, 2, 59, 8, 8])


Game: 171, Pos: 11, Activation: -0.00017238935106433928
torch.Size([6, 2, 59, 8, 8])


Game: 173, Pos: 21, Activation: -0.00019857838924508542
torch.Size([6, 2, 59, 8, 8])


Game: 57, Pos: 13, Activation: -0.00021394983923528343
torch.Size([6, 2, 59, 8, 8])


Game: 171, Pos: 13, Activation: -0.00021923323220107704
torch.Size([6, 2, 59, 8, 8])


Game: 12, Pos: 6, Activation: -0.00022386174532584846
torch.Size([6, 2, 59, 8, 8])


In [12]:
# Spectrum plot
layer = 1
neuron = 2
probe_name = "flipped"
probe_module = "mid"
probe = get_probe(layer = layer, probe_type = probe_name, probe_module = probe_module)[0]

post_activations = focus_cache["post", layer][:, :, neuron]
batch_size, seq_len = post_activations.shape
post_activations = einops.rearrange(post_activations, "batch seq_len -> (batch seq_len)")
print(post_activations.shape)
top_activations_indeces = post_activations.argsort(descending=True)
activation_values = post_activations.sort(descending=True)
positive_count = (activation_values.values > 0.5).sum()
positive_activations_indeces = top_activations_indeces[:positive_count]

print(probe.shape)
print(focus_cache.keys())
resid : Float[Tensor, "batch pos d_model"] = focus_cache[f"blocks.{layer}.ln2.hook_normalized"]
probe_results = einops.einsum(resid, probe, "batch pos d_model, d_model row col options -> batch pos row col options")
print(probe_results.shape)
probe_results = einops.rearrange(probe_results, "batch pos row col options -> (batch pos) row col options")
probe_results = probe_results.softmax(dim=-1)
print(probe_results.shape)
print(positive_activations_indeces.shape)
probe_results = probe_results[positive_activations_indeces]
print(probe_results.shape)
probe_results = einops.reduce(probe_results, "batch row col options -> row col options", "mean")
for option in range(probe_results.shape[-1]):
    plot_boards_general(
        x_labels = [""],
        y_labels = [""],
        boards = probe_results[:, :, option].unsqueeze(dim=0).unsqueeze(dim=0),
        title_text = f"Probe Spectrum for {get_direction_str(probe_name, option)}",
        save=False,
        margin_t = 100,
        size_of_board=400, 
    )

torch.Size([11800])


torch.Size([512, 8, 8, 2])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'block

In [13]:
'''
Neuron pair: 1084 with similiarity: 0.23889845144822175
Neuron pair: 1920 with similiarity: 0.23444519623226157
Neuron pair: 705 with similiarity: 0.2317404346409555
Neuron pair: 1551 with similiarity: 0.22644795818111163
Neuron pair: 1860 with similiarity: 0.2042490093961682
Neuron pair: 1915 with similiarity: 0.19610049085385128
Neuron pair: 843 with similiarity: 0.19589928164777692
Neuron pair: 1198 with similiarity: 0.19580397753182766
Neuron pair: 938 with similiarity: 0.19567771486123675
Neuron pair: 1407 with similiarity: 0.18711266284867634
'''

'\nNeuron pair: 1084 with similiarity: 0.23889845144822175\nNeuron pair: 1920 with similiarity: 0.23444519623226157\nNeuron pair: 705 with similiarity: 0.2317404346409555\nNeuron pair: 1551 with similiarity: 0.22644795818111163\nNeuron pair: 1860 with similiarity: 0.2042490093961682\nNeuron pair: 1915 with similiarity: 0.19610049085385128\nNeuron pair: 843 with similiarity: 0.19589928164777692\nNeuron pair: 1198 with similiarity: 0.19580397753182766\nNeuron pair: 938 with similiarity: 0.19567771486123675\nNeuron pair: 1407 with similiarity: 0.18711266284867634\n'

In [14]:
# from utils import get_focus_logits_and_cache

In [15]:
# focus_logits, focus_cache = get_focus_logits_and_cache()

In [16]:
layer = 2 # 0
neuron = 877 # 368
neuron = neuron 
bias = model.b_in[layer, neuron].item()
print(bias)

# get_fraction_of_variance_from_neuron_explained_by_probe(neuron = neuron, layer = 1)
def get_max_acitvations_of_neuron(
    cache: ActivationCache,
    layer: int,
    neuron: int,
    num_activations: int = 10,
    random = False,
) -> Tuple[Float[Tensor, "game move"], Float[Tensor, "game move"]]:
    '''
    Returns the top activations for a given neuron in a given layer.
    '''
    # SOLUTION
    post_activations = cache["post", layer][:, :, neuron]
    batch_size, seq_len = post_activations.shape
    post_activations = post_activations.reshape(-1)
    top_activations = post_activations.argsort(descending=True)
    activation_values = post_activations.sort(descending=True)
    # set positive_activations to num_activation random examples
    if random:
        positive_count = (activation_values.values > 0).sum()
        positive_activations = top_activations[:positive_count]
        random_indeces = (t.randperm(positive_count)[:num_activations]).to(t.long)
        top_activations = positive_activations[random_indeces]
    else:
        top_activations = top_activations[:num_activations]
    activation_values = post_activations[top_activations]
    top_games = top_activations // seq_len
    top_moves = top_activations % seq_len
    print(f"Top Game: \n{top_games}")
    print(f"Top Move: \n{top_moves}")
    print(f"Activation values: \n{activation_values}")
    return top_games, top_moves, activation_values

top_games, top_moves, activation_values = get_max_acitvations_of_neuron(focus_cache, layer, neuron, 2, random=False)
for i in range(len(top_games)):
    game = top_games[i].item()
    pos = top_moves[i].item()
    activation = activation_values[i].item()
    vis_args = VisualzeBoardArguments()
    vis_args.start_pos = pos
    vis_args.end_pos = pos+1
    # vis_args.include_layer_norm = True
    # vis_args.include_pre_resid = False
    vis_args.layers = 6
    vis_args.static_image = True
    # vis_args.include_attn_only = True
    # vis_args.include_mlp_only = True
    # plot_game(focus_games_string, game)
    print(f"Game: {game}, Pos: {pos}, Activation: {activation}")
    visualize_game(
        focus_games_string[game, :59],
        vis_args,
        model,
    )
    # break

-3.611151695251465
Top Game: 
tensor([ 94, 183], device='cuda:0')
Top Move: 
tensor([ 8, 15], device='cuda:0')
Activation values: 
tensor([3.7089, 3.6964], device='cuda:0')
Game: 94, Pos: 8, Activation: 3.7089197635650635
torch.Size([6, 1, 59, 8, 8])
Game: 183, Pos: 15, Activation: 3.6964008808135986
torch.Size([6, 1, 59, 8, 8])


In [17]:
for layer in range(8):
    mlp_activations = focus_cache["post", layer]
    mlp_activations_high = mlp_activations > 0.0
    fraction = mlp_activations_high.sum() / mlp_activations_high.numel()
    print(f"Fraction of activations above 0.0 in layer {layer}: {fraction.item()}")

Fraction of activations above 0.0 in layer 0: 0.04423344135284424
Fraction of activations above 0.0 in layer 1: 0.09934781491756439
Fraction of activations above 0.0 in layer 2: 0.1093745008111
Fraction of activations above 0.0 in layer 3: 0.12304844707250595
Fraction of activations above 0.0 in layer 4: 0.1329999566078186
Fraction of activations above 0.0 in layer 5: 0.12842658162117004
Fraction of activations above 0.0 in layer 6: 0.12131679058074951
Fraction of activations above 0.0 in layer 7: 0.10062719136476517


In [18]:
layer = 1
linear_probe = get_probe(layer, "linear", "post")
# B2
direction = linear_probe[0, :, 1, 2, MINE]

W_out = model.W_out[layer, :, :].detach().clone() / model.W_out[layer, :, :].norm(dim=-1, keepdim=True)
similiarity_matrix = W_out @ W_out.T

In [19]:
'''similiarities = dict()
for a in range(similiarity_matrix.shape[0]):
    if a % 100 == 0:
        print(f"Neron1: {a}")
    for b in range(similiarity_matrix.shape[1]):
        similiarities[(a, b)] = similiarity_matrix[a, b].item()
        if a == b:
            similiarities[(a, b)] = 0'''

'similiarities = dict()\nfor a in range(similiarity_matrix.shape[0]):\n    if a % 100 == 0:\n        print(f"Neron1: {a}")\n    for b in range(similiarity_matrix.shape[1]):\n        similiarities[(a, b)] = similiarity_matrix[a, b].item()\n        if a == b:\n            similiarities[(a, b)] = 0'

In [20]:
'''sorted_similiarities = sorted(similiarities.items(), key=lambda x: x[1], reverse=True)
for neuron_pair, similiarity in sorted_similiarities[:10]:
    print(f"Neuron pair: {neuron_pair} with similiarity: {similiarity}")'''

'sorted_similiarities = sorted(similiarities.items(), key=lambda x: x[1], reverse=True)\nfor neuron_pair, similiarity in sorted_similiarities[:10]:\n    print(f"Neuron pair: {neuron_pair} with similiarity: {similiarity}")'

## Automatically Finding Flipping Neurons

In [21]:
def get_similiarity(neuron : Int, layer : Int, tiles : List[Tuple[str, str, str, str]], metric = "avg"):
    avg_similiarity = 0
    direction_all = t.zeros([512]).to(device)
    for label, probe_type, feature_str, in_or_out in tiles:
        tile_tuple = label_to_tuple(label)
        y, x = tile_tuple
        feature = get_direction_int(feature_str)
        if in_or_out == "in":
            probe_module = "mid"
            w = get_w_in(model, layer, neuron, normalize=True)
        else:
            probe_module = "post"
            w = get_w_out(model, layer, neuron, normalize=True)
        probe = get_probe(layer, probe_type=probe_type, probe_module=probe_module)
        direction = probe[0, :, y, x, feature]
        direction = direction / direction.norm()
        direction_all += direction
        similiarity = einops.einsum(direction, w, "d_model, d_model ->").item()
        if feature_str == "empty":
            similiarity = similiarity / 3
        avg_similiarity += similiarity
    direction_all = direction_all / direction_all.norm()
    similiarity_all = einops.einsum(direction_all, w, "d_model, d_model ->").item()
    if metric == "avg":
        return avg_similiarity / len(tiles)
    else:
        return similiarity_all


similiarities = dict()

tiles = [
    ("E2", "linear", "empty", "in"),
    ("B5", "placed", "placed", "in"),
    ("B5", "linear", "yours", "in"),
    ("B4", "linear", "mine", "in"),
    ("B3", "linear", "empty", "in"),
    ("C5", "linear", "empty", "in"),
    ("G4", "linear", "yours", "out"),
    ("G4", "flipped", "flipped", "out"),
    ("D3", "linear", "yours", "out"),
]

for neuron1 in range(4 * 512):
    similiarity = get_similiarity(neuron1, layer=0, tiles=tiles, metric="avg")
    similiarities[neuron1] = similiarity

# sort similiarites by decending value
sorted_similiarities = sorted(similiarities.items(), key=lambda x: x[1], reverse=True)
for neuron_pair, similiarity in sorted_similiarities[:10]:
    print(f"Neuron pair: {neuron_pair} with similiarity: {similiarity}")

Neuron pair: 409 with similiarity: 0.10315750263355397
Neuron pair: 1915 with similiarity: 0.0931953980276982
Neuron pair: 882 with similiarity: 0.09068340708122209
Neuron pair: 1551 with similiarity: 0.08845974794692463
Neuron pair: 56 with similiarity: 0.08499977294424617
Neuron pair: 1920 with similiarity: 0.08209401969280507
Neuron pair: 1293 with similiarity: 0.08157213418572037
Neuron pair: 705 with similiarity: 0.08045514445337985
Neuron pair: 925 with similiarity: 0.07430763677176502
Neuron pair: 1780 with similiarity: 0.07376143836450796


In [29]:
import math

In [33]:
'''tiles_in = [
    ("B5", "placed", "placed"),
    ("B5", "linear", "yours"),
    ("B4", "linear", "mine"),
    # ("B3", "linear", "empty"),
    # ("C5", "linear", "empty"),
]

tiles_out = [
    ("G4", "linear", "yours"),
    ("G4", "flipped", "flipped"),
    ("D3", "linear", "yours"),
]'''

tiles_in = [
]
for layer in range(1, 6):
    for label in ["D3", "B3", "A3"]:
        tiles_out = [
            (label, "flipped", "flipped")
        ]

        for neuron1 in range(4 * 512):
            # similiarity_in = get_fraction_of_variance_from_neuron_explained_by_probe(neuron1, 0, "in", tiles_in)
            similiarity_out = get_fraction_of_variance_from_neuron_explained_by_probe(neuron1, layer, "out", tiles_out)
            # similiarities[neuron1] = (similiarity_in + similiarity_out) / 2
            similiarities[neuron1] = similiarity_out

        # sort similiarites by decending value
        sorted_similiarities = sorted(similiarities.items(), key=lambda x: x[1], reverse=True)
        # for neuron_pair, similiarity in sorted_similiarities[:10]:
        #     print(f"Neuron pair: {neuron_pair} with similiarity: {similiarity}")
        print(f"Layer: {layer}, Tile: {label}, Neurons: {[neuron for neuron, similiarity in sorted_similiarities[:10]]}, Similiarities: {[round(similiarity, 2) for neuron, similiarity in sorted_similiarities[:10]]}")
        # get_fraction_of_variance_from_neuron_explained_by_probe

Layer: 1, Tile: D3, Neurons: [894, 23, 1411, 441, 1412, 1417, 969, 592, 953, 304], Similiarities: [0.11, 0.08, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
Layer: 1, Tile: B3, Neurons: [1157, 496, 377, 1653, 778, 23, 218, 1365, 513, 900], Similiarities: [0.11, 0.1, 0.1, 0.09, 0.08, 0.07, 0.07, 0.06, 0.06, 0.06]
Layer: 1, Tile: A3, Neurons: [12, 1401, 23, 1893, 819, 531, 43, 528, 268, 142], Similiarities: [0.1, 0.08, 0.08, 0.07, 0.07, 0.06, 0.06, 0.05, 0.04, 0.03]
Layer: 2, Tile: D3, Neurons: [171, 740, 486, 161, 269, 819, 863, 1138, 46, 1198], Similiarities: [0.2, 0.08, 0.07, 0.05, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04]
Layer: 2, Tile: B3, Neurons: [1922, 710, 1022, 473, 613, 1884, 1106, 745, 234, 2010], Similiarities: [0.19, 0.12, 0.11, 0.09, 0.08, 0.08, 0.07, 0.07, 0.06, 0.06]
Layer: 2, Tile: A3, Neurons: [637, 225, 914, 11, 920, 514, 376, 111, 1768, 193], Similiarities: [0.1, 0.09, 0.08, 0.06, 0.03, 0.02, 0.02, 0.02, 0.02, 0.01]
Layer: 3, Tile: D3, Neurons: [850, 1800, 1341, 455, 25

In [22]:
'''
Neuron pair: 1897 with similiarity: 3.733549118041992
Neuron pair: 1447 with similiarity: 2.6375696659088135
Neuron pair: 686 with similiarity: 1.5242090225219727
Neuron pair: 1324 with similiarity: 1.2681114673614502
Neuron pair: 1544 with similiarity: 1.0980175733566284
Neuron pair: 626 with similiarity: 1.0913395881652832
Neuron pair: 550 with similiarity: 1.0839078426361084
Neuron pair: 595 with similiarity: 1.0835928916931152
Neuron pair: 844 with similiarity: 0.9409590363502502
Neuron pair: 187 with similiarity: 0.9197345972061157
layer = 1
linear_probe = get_probe(layer, "linear", "post")
# B2
direction = linear_probe[0, :, 1, 2, MINE]
similiarities = dict()

for neuron1 in range(4 * 512):
    if neuron1 % 100 == 0:
        print(f"Neron1: {neuron1}")
    for neuron2 in range(neuron1):
        w_out1 = get_w_out(model, layer, neuron1, normalize=True)
        w_out2 = get_w_out(model, layer, neuron2, normalize=True)
        normalized_neuron_addition_out = w_out1 + w_out2 / (w_out1 + w_out2).norm()
        # First use addition, later use cosine similiarity to plane
        similiarity = einops.einsum(direction, normalized_neuron_addition_out, "d_model, d_model ->").item()
        similiarities[(neuron1, neuron2)] = similiarity

# sort similiarites by decending value
sorted_similiarities = sorted(similiarities.items(), key=lambda x: x[1], reverse=True)
for neuron_pair, similiarity in sorted_similiarities[:10]:
    print(f"Neuron pair: {neuron_pair} with similiarity: {similiarity}")'''

'\nNeuron pair: 1897 with similiarity: 3.733549118041992\nNeuron pair: 1447 with similiarity: 2.6375696659088135\nNeuron pair: 686 with similiarity: 1.5242090225219727\nNeuron pair: 1324 with similiarity: 1.2681114673614502\nNeuron pair: 1544 with similiarity: 1.0980175733566284\nNeuron pair: 626 with similiarity: 1.0913395881652832\nNeuron pair: 550 with similiarity: 1.0839078426361084\nNeuron pair: 595 with similiarity: 1.0835928916931152\nNeuron pair: 844 with similiarity: 0.9409590363502502\nNeuron pair: 187 with similiarity: 0.9197345972061157\nlayer = 1\nlinear_probe = get_probe(layer, "linear", "post")\n# B2\ndirection = linear_probe[0, :, 1, 2, MINE]\nsimiliarities = dict()\n\nfor neuron1 in range(4 * 512):\n    if neuron1 % 100 == 0:\n        print(f"Neron1: {neuron1}")\n    for neuron2 in range(neuron1):\n        w_out1 = get_w_out(model, layer, neuron1, normalize=True)\n        w_out2 = get_w_out(model, layer, neuron2, normalize=True)\n        normalized_neuron_addition_out 

In [23]:
pca = PCA(n_components=2)

mine_probe = get_probe(layer = 5, probe_type = "linear", probe_module = "post").clone()[:, :, :, :, MINE]

mine_probe_pca = pca.fit_transform(mine_probe.reshape(cfg.d_model, 64).T)
# print(mine_probe_pca.shape)

# x and y should be log scale
'''trace = go.Scatter(
    x = mine_probe_pca[:, 0],
    y = mine_probe_pca[:, 1],
    mode = 'markers',
    marker = dict(
        size = 10,
        color = "blue",
        opacity = 0.8,
    ),
    text = board_labels,
    log_x = True,
    log_y = True,
)

fig = go.Figure(data = [trace])
fig.show()'''
# px.scatter(x = mine_probe_pca[:, 0], y = mine_probe_pca[:, 1], labels = {"x": "PCA 1", "y": "PCA 2"}, x_labels = list(range(64)))

# perform log on mine_probe_pca

# mine_probe_pca[mine_probe_pca >= 0] = np.log10(mine_probe_pca[mine_probe_pca >= 0] + 1)  # / np.log10(0.1)
mine_probe_pca[mine_probe_pca < 0] = -np.log10(-mine_probe_pca[mine_probe_pca < 0] + 1)
board_labels = []
for i in range(64):
    label = string_to_label(i)
    board_labels.append(label)

# Dont use marker size, use text instead
fig = px.scatter(
    x = mine_probe_pca[:, 1],
    y = mine_probe_pca[:, 0] * -1,
    text = board_labels,
    log_x = False,
    log_y = False,
    labels = {"x": "PCA 1", "y": "PCA 2"},
    title = "PCA of the mine probe in layer 5",
    opacity = 0,
)

# fig = go.Figure(data = [trace])
# set width and height to 1000 and plot the figure
fig.update_layout(width = 1000, height = 1000)
fig.show()

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

# Abitious MLP Interpretability
- Search for Interpretabil Directions out of Neuron Input and Outputs (I think input might be more interpretabil ...)
- I'm assuming that many Neurons read multiple features, like Tile A placed here. Tile B Mine etc.
- I should test with a Toy model where I construct the features as an overloaded bases. Then create random combinations of these features and then try to reconstruc the original features from these
- Regular Language Models have much more feature then neurons (LLama 3 has 5325 Hidden Layer size). So this should work good
- Maybe start without superposition and noise, then increase superposition and noise
- To confirm a Solutions
    - All atomic elements should have < \epsilon similiarity
    -For all the inputs the specific atomic features should appr. add up to the the input
- Does this even have a solution. Could there be multiple possible slolutions?
- Is my Idea of finding the subet even somewhat principled
    - If I have Atomic Features A B C and I have A+B and B+C. Can I find B from the two Inputs: YES
    - Middle doesen't work ...
- I think this might be possible but very hard. Especially with superposition.
    - The Naive Idea didn't work
    - I want to assk John and Neel Nanda at the summer school what they think of the idea