In [46]:
from alphatoe import plot, game, interpretability
import torch
from torch import Tensor
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from importlib import reload
import tqdm
import pandas as pd

In [47]:
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)

In [48]:
games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")
games.head()

  games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")


Unnamed: 0,moves played,steps till end state,winner,rotation 1,rotation 2,rotation 3,horizontal flip,vertical flip,training index,train or test,first win condition,second win condition,end move loss
0,"[0, 1, 3, 2, 6]",5,X,399,1439,1040,1114,325,241912,test,left column,,5e-06
1,"[0, 1, 3, 4, 6]",5,X,396,1438,1043,1112,327,190522,train,left column,,0.000114
2,"[0, 1, 3, 5, 6]",5,X,398,1437,1041,1113,326,90275,train,left column,,7e-06
3,"[0, 1, 3, 7, 6]",5,X,395,1436,1044,1110,329,21994,train,left column,,8e-06
4,"[0, 1, 3, 8, 6]",5,X,397,1435,1042,1111,328,48696,train,left column,,1.3e-05


### Get games by win conditions

In [49]:
non_9_move_games = games[games["steps till end state"] != 9]
end_game_types = list(games["first win condition"].unique())

game_kinds = [
    [
        [10] + eval(move)
        for move in non_9_move_games[
            (non_9_move_games["first win condition"] == game_type)
            | (non_9_move_games["second win condition"] == game_type)
        ]["moves played"]
    ]
    for game_type in end_game_types[:-1]
]

In [50]:
for games in game_kinds:
    print(len(games))

14436
20340
14436
14436
20340
14436
14436
14436


In [51]:
all_non_9_move_games = [[10] + eval(move) for move in non_9_move_games["moves played"]]

### Get pre-mlp residual stream contents

In [75]:
def pre_mlp_residual(seq):
    def hook(module, input, output):
        result = output.clone()
        module.captured = result

    try:
        handle = model.blocks[0].hook_resid_mid.register_forward_hook(hook)
        with torch.no_grad():
            model(seq)
        activations = model.blocks[0].hook_resid_mid.captured
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return activations

In [53]:
residual_stream = pre_mlp_residual(torch.tensor([10, 1, 2, 3]))

In [54]:
residual_stream.shape

torch.Size([1, 4, 128])

### Get residual stream across 1000 games of a type

In [55]:
games = game_kinds[0][:1000]

In [56]:
games[0]

[10, 0, 1, 3, 2, 6]

In [57]:
len(games)

1000

In [58]:
# left column
games[0]

[10, 0, 1, 3, 2, 6]

In [59]:
# TODO: games are different sizes, batching requires some work, who cares for now
resid_vectors = []
for game in games:
    resid_vectors.append(
        interpretability.capture_forward_pass(
            model, model.blocks[0].hook_resid_mid, torch.tensor(game)
        )[0, -1]
    )
resid_tensor = torch.stack(resid_vectors).detach().cpu()

In [60]:
resid_tensor.shape

torch.Size([1000, 128])

### Taking the PCA of the residual stream across games

In [61]:
mean_centered = resid_tensor - resid_tensor.mean(dim=0)

In [62]:
resid_tensor

tensor([[ 0.0431,  0.5041, -0.3265,  ..., -0.6198, -0.3301, -0.5209],
        [-0.0807,  0.5625, -0.3957,  ..., -0.4737, -0.2211, -0.6115],
        [-0.0240,  0.5285, -0.2413,  ..., -0.5456, -0.3942, -0.6143],
        ...,
        [-0.3022,  0.0194,  0.1407,  ..., -0.1070, -0.3671, -0.2893],
        [-0.4039,  0.0594,  0.0815,  ...,  0.0336, -0.2753, -0.3656],
        [-0.3621,  0.0250,  0.2009,  ..., -0.0330, -0.4281, -0.3731]])

In [63]:
PCA_info = interpretability.pca(resid_tensor)

In [64]:
R2 = PCA_info["R2"]
variances = PCA_info["variances"]
principal_components = PCA_info["principal components"]
projected_data = PCA_info["projected data"]

In [65]:
print(R2)
print(variances)
print(principal_components)
print(projected_data)

tensor([0.2270, 0.1849, 0.1348, 0.1120, 0.0887, 0.0730, 0.0708, 0.0604])
tensor([0.6929, 0.5643, 0.4114, 0.3417, 0.2709, 0.2228, 0.2162, 0.1842])
tensor([[ 0.0262, -0.0082, -0.1195,  ...,  0.0749,  0.0488, -0.0993],
        [-0.0438,  0.0193, -0.0306,  ..., -0.0423,  0.0093, -0.1454],
        [-0.0173, -0.0830,  0.0602,  ...,  0.0776, -0.2071,  0.0710],
        ...,
        [ 0.0472, -0.2161, -0.0483,  ..., -0.0894, -0.0569,  0.0392],
        [-0.0538,  0.0824,  0.0281,  ..., -0.1813,  0.1037,  0.0822],
        [-0.0968, -0.0381,  0.0638,  ...,  0.0479,  0.0616, -0.0289]])
tensor([[-0.5320,  1.0366, -0.1912,  ..., -0.5966,  0.2388, -0.4008],
        [-0.5350,  0.9054, -0.2894,  ..., -0.5321,  0.2063, -0.3617],
        [-0.5000,  1.0873, -0.2263,  ..., -0.3949, -0.5261, -0.5436],
        ...,
        [-1.4458, -0.7565,  0.1498,  ...,  0.0038, -0.2217,  1.7975],
        [-1.4579, -0.8730,  0.0678,  ...,  0.0669, -0.2429,  1.8418],
        [-1.4305, -0.7070,  0.1236,  ...,  0.1837, -0.920

### This isn't enough, we need a way to associate each game with it's residual stream projected onto each principal component

In [66]:
# If we sort and take the indices across the 1000 dimension, that'll tell us what games occured in which placement of the variance
projected_data.shape

torch.Size([1000, 8])

In [67]:
# We can "recover" our data by doing projected_data @ principal_components.T
principal_components.shape

torch.Size([128, 8])

In [68]:
def sort_PCA_projections(projected_data, index):
    data_vec = projected_data[:, index]
    # principal_component = principal_components[:, index]
    game_indices = torch.argsort(data_vec)
    return game_indices

In [69]:
first_component_games = [
    game_kinds[0][i] for i in sort_PCA_projections(resid_tensor, 0)
]
for moves in first_component_games:
    print(moves)
    print("-------------------")

[10, 0, 1, 4, 5, 3, 7, 6]
-------------------
[10, 0, 1, 4, 7, 3, 5, 6]
-------------------
[10, 0, 1, 3, 5, 4, 7, 6]
-------------------
[10, 0, 1, 3, 7, 4, 5, 6]
-------------------
[10, 0, 1, 3, 5, 7, 4, 6]
-------------------
[10, 0, 1, 3, 4, 7, 5, 6]
-------------------
[10, 0, 1, 4, 5, 6, 7, 3]
-------------------
[10, 0, 1, 4, 7, 6, 5, 3]
-------------------
[10, 0, 1, 6, 5, 4, 7, 3]
-------------------
[10, 0, 1, 6, 7, 4, 5, 3]
-------------------
[10, 0, 1, 4, 2, 3, 7, 6]
-------------------
[10, 0, 1, 6, 5, 7, 4, 3]
-------------------
[10, 0, 1, 6, 4, 7, 5, 3]
-------------------
[10, 0, 1, 4, 7, 3, 2, 6]
-------------------
[10, 0, 1, 3, 2, 4, 7, 6]
-------------------
[10, 0, 1, 3, 7, 4, 2, 6]
-------------------
[10, 0, 1, 3, 2, 7, 4, 6]
-------------------
[10, 0, 1, 3, 4, 7, 2, 6]
-------------------
[10, 0, 1, 4, 2, 6, 7, 3]
-------------------
[10, 0, 1, 4, 7, 6, 2, 3]
-------------------
[10, 0, 1, 4, 8, 3, 7, 6]
-------------------
[10, 0, 1, 6, 2, 4, 7, 3]
--------

In [70]:
from alphatoe import game

for moves in first_component_games:
    game.play_game(moves)
    print("-------------------")

| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | O | O |
| X | X |   |
-------------------
| X | O |   |
| X | O | O |
| X | X |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O |   |
| X | X | O |
| X | O |   |
-------------------
| X | O | O |
| X | X |   |
| X | O |   |
-------------------
| X | O |   |
| X | O | O |
| X | X |   |
-------------------
| X | O |   |
| X | O | O |
| X | X |   |
-------------------
| X | O | O |
| X | X |   |
| X | O |   |
-------------------
| X | O | O |
| X | X |   |
| X | O |   |
-------------------
| X | O | O |
| X | X |   |
| X | O |   |
-------------------
| X | O 

### That didn't really work. The principal vectors aren't really interpretable..

### Maybe we can just get some information about the variance across all games

In [71]:
games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")

  games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")


In [72]:
all_games = [eval(moves) for moves in games["moves played"]]

In [77]:
all_logits = interpretability.inference_on_games(
    interpretability.capture_forward_pass(
        model, model.blocks[0].hook_resid_mid, torch.tensor(game)
    ),
    all_games,
)

RuntimeError: Could not infer dtype of module

In [None]:
for logits in all_logits:
    print(len(logits))

In [None]:
with torch.no_grad():
    all_residuals = [
        pre_mlp_residual(torch.tensor(game)) for game in tqdm.tqdm(all_games)
    ]

### Working across all games

In [78]:
for game in tqdm.tqdm(all_non_9_move_games):
    resid_vectors.append(
        interpretability.capture_forward_pass(
            model, model.blocks[0].hook_resid_mid, torch.tensor(game)
        )[0, -1]
    )
resid_tensor = torch.stack(resid_vectors).detach().cpu()

100%|██████████| 127296/127296 [00:56<00:00, 2259.08it/s]


In [79]:
resid_tensor.shape

torch.Size([128296, 128])

In [80]:
all_PCA = interpretability.pca(resid_tensor)

In [81]:
R2 = all_PCA["R2"]
variances = all_PCA["variances"]
principal_components = PCA_info["principal components"]
projected_data = PCA_info["projected data"]

In [82]:
print(R2)
print(variances)
print(principal_components)
print(projected_data)

tensor([0.2270, 0.1849, 0.1348, 0.1120, 0.0887, 0.0730, 0.0708, 0.0604])
tensor([0.6929, 0.5643, 0.4114, 0.3417, 0.2709, 0.2228, 0.2162, 0.1842])
tensor([[ 0.0262, -0.0082, -0.1195,  ...,  0.0749,  0.0488, -0.0993],
        [-0.0438,  0.0193, -0.0306,  ..., -0.0423,  0.0093, -0.1454],
        [-0.0173, -0.0830,  0.0602,  ...,  0.0776, -0.2071,  0.0710],
        ...,
        [ 0.0472, -0.2161, -0.0483,  ..., -0.0894, -0.0569,  0.0392],
        [-0.0538,  0.0824,  0.0281,  ..., -0.1813,  0.1037,  0.0822],
        [-0.0968, -0.0381,  0.0638,  ...,  0.0479,  0.0616, -0.0289]])
tensor([[-0.5320,  1.0366, -0.1912,  ..., -0.5966,  0.2388, -0.4008],
        [-0.5350,  0.9054, -0.2894,  ..., -0.5321,  0.2063, -0.3617],
        [-0.5000,  1.0873, -0.2263,  ..., -0.3949, -0.5261, -0.5436],
        ...,
        [-1.4458, -0.7565,  0.1498,  ...,  0.0038, -0.2217,  1.7975],
        [-1.4579, -0.8730,  0.0678,  ...,  0.0669, -0.2429,  1.8418],
        [-1.4305, -0.7070,  0.1236,  ...,  0.1837, -0.920