In [None]:
# Re-import edited local modules (e.g. tic_tac_gpt) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import pickle
from pathlib import Path

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from lightning import fabric
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import log_loss

from tic_tac_gpt.data import TicTacToeDataset, TicTacToeState, tensor_to_state

In [None]:
# Inference-only notebook: disable autograd globally and allocate new tensors
# on the GPU when one is available.
torch.set_grad_enabled(False)
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from (`random.choice`, `torch.randperm`,
# numpy-based sklearn code) so sampled games and probe results are
# reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
checkpoint_dir = Path("out/model/exp6")

# NOTE: unpickling can execute arbitrary code -- only load checkpoints you trust.
config: HookedTransformerConfig = pickle.loads(
    (checkpoint_dir / "config.pkl").read_bytes()
)
model = HookedTransformer(config)

# The weights are read through Fabric and handed straight to load_state_dict.
F = fabric.Fabric(precision="16-mixed")
state_dict = F.load(checkpoint_dir / "model.pt")
model.load_state_dict(state_dict)
model.eval()  # nn.Module.eval() switches mode in place and returns the same object

In [None]:
# Train / test splits of the "50_50" dataset, one JSONL file per split.
ds_train = TicTacToeDataset(Path("out/dataset/50_50/train.jsonl"))
ds_test = TicTacToeDataset(Path("out/dataset/50_50/test.jsonl"))

In [None]:
# Pick one random training game; each dataset item is a 1-tuple holding the game tensor.
(x,) = random.choice(ds_train)

In [None]:
def visualise_logit(logit):
    """Plot the softmax distribution over the vocabulary for one position.

    ``logit`` is softmaxed along dim 0, so it should be a 1-D vector of raw
    logits over the vocabulary (e.g. ``model(x)[0, pos]``). Each vocabulary
    index is decoded with ``TicTacToeDataset.decode_one``:

    * integer moves are drawn on a 3x3 heatmap at cell ``(move // 3, move % 3)``;
    * every other token is shown as a bar in a "Special moves" panel.

    Returns:
        The matplotlib Figure containing both panels.
    """
    probs = logit.softmax(0)
    board = [[0] * 3 for _ in range(3)]
    special = {}
    for i, p in enumerate(probs.tolist()):
        move = TicTacToeDataset.decode_one(i)
        if isinstance(move, int):
            board[move // 3][move % 3] = p
        else:
            special[move] = p

    fig, axs = plt.subplots(
        ncols=2, figsize=(10, 4), gridspec_kw={"width_ratios": [1.5, 1]}
    )
    # The heatmap shows softmax probabilities, not raw logits. Pin the colour
    # scale to [0, 1] so it matches the bar panel's ylim and stays comparable
    # across calls.
    m = axs[0].imshow(board, aspect="equal", vmin=0, vmax=1)
    axs[0].set(xticks=[], yticks=[], title="Move probabilities")
    fig.colorbar(m, ax=axs[0])
    sns.barplot(x=list(special.keys()), y=list(special.values()), ax=axs[1])
    axs[1].set(title="Special moves", ylim=(0, 1))

    return fig


def visualise_board(state):
    """Render ``state.board`` as a greyscale image and return the new Figure.

    The colour range is fixed to 0..2, so a given cell value always maps to
    the same shade regardless of what else is on the board.
    """
    fig, ax = plt.subplots()
    ax.imshow(state.board, vmin=0, vmax=2, cmap="gray", aspect="equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("Board")
    return fig


# Board after the first `n_prefix` tokens of the sampled game, next to the
# model's output distribution at the last of those tokens.
n_prefix = 4
visualise_board(tensor_to_state(x[:n_prefix]))
visualise_logit(model(x)[0, n_prefix - 1]);

In [None]:
# Random subsets of games: the train subset fits the probes, the test subset
# evaluates them (and feeds the activation cache further down).
n_focus_train, n_focus_test = 1000, 2048
train_pick = torch.randperm(len(ds_train))[:n_focus_train]
focus_games = ds_train[train_pick][0]
test_pick = torch.randperm(len(ds_test))[:n_focus_test]
focus_games_test = ds_test[test_pick][0]

focus_games.shape

In [None]:
def extract_states(game: torch.Tensor):
    """Lazily yield the decoded state after each successive prefix of ``game``.

    Prefixes ``game[:1]``, ``game[:2]``, ... are decoded with ``tensor_to_state``.
    Iteration ends right after the first state whose ``result`` is not
    ``"in_progress"`` (that terminal state is still yielded), or once the
    whole tensor has been consumed.
    """
    prefix_len = 0
    while prefix_len < game.shape[0]:
        prefix_len += 1
        current = tensor_to_state(game[:prefix_len])
        yield current
        if current.result != "in_progress":
            return


def board_to_tensor(board):
    """Build a tensor from ``board`` and merge its last two dimensions.

    For example, a 3x3 nested list becomes a length-9 tensor in row-major order.
    """
    as_tensor = torch.tensor(board)
    return torch.flatten(as_tensor, start_dim=-2, end_dim=-1)


# Sanity check: print each intermediate state of one game and its flattened board.
board_states = extract_states(focus_games[7])
for s in board_states:
    print(s)
    print(board_to_tensor(s.board))

In [None]:
# Layer index whose `resid_post` activations extract_XY collects for probing.
probe_layer = 1


def invert(x):
    """Return a copy of ``x`` with the values 1 and 2 swapped.

    All other values are copied unchanged; ``x`` itself is not modified.
    """
    out = x.clone()
    # Both masks are computed from the untouched input, so the two fills do
    # not cascade into each other.
    out.masked_fill_(x == 1, 2)
    out.masked_fill_(x == 2, 1)
    return out


def remap(x):
    """Return a re-encoded copy of ``x``: 0 -> -1, 1 -> 2, 2 -> 1.

    Values outside {0, 1, 2} are copied unchanged; ``x`` is not modified.
    """
    out = x.clone()
    # Every mask is taken from the original input, so the mappings are
    # applied simultaneously rather than chained.
    for old, new in ((0, -1), (1, 2), (2, 1)):
        out.masked_fill_(x == old, new)
    return out


def extract_XY(games):
    """Build a probing dataset (activations, board targets) from a batch of games.

    Runs ``model`` on ``games[:, :-1]`` and pairs the ``resid_post`` activation
    of layer ``probe_layer`` at each position with the board after the
    matching prefix (from ``extract_states``). Because ``zip`` stops at the
    shorter iterator, positions past a game's terminal state are dropped.

    Targets are flattened boards, passed through ``invert`` on even positions
    and then always through ``remap``. Since ``remap`` also swaps 1 and 2, the
    net effect is: 0 -> -1 everywhere, and 1/2 are swapped only on odd
    positions. NOTE(review): presumably this expresses the board relative to
    the player to move -- confirm this is intended.

    Returns:
        ``(X, Y)``: stacked activations (one row per kept position) and the
        matching stacked targets.
    """
    _, act_cache = model.run_with_cache(games[:, :-1])
    resid = act_cache["resid_post", probe_layer]
    features, targets = [], []
    for game, game_resid in zip(games, resid):
        states = extract_states(game)
        for pos, (state, vec) in enumerate(zip(states, game_resid)):
            board = board_to_tensor(state.board)
            if pos % 2 == 0:
                board = invert(board)
            features.append(vec)
            targets.append(remap(board))
    return torch.stack(features), torch.stack(targets)


# Probe datasets: X = layer-`probe_layer` activations, Y = per-cell board targets.
X, Y = extract_XY(focus_games)
X_test, Y_test = extract_XY(focus_games_test)
X.shape, Y.shape, X_test.shape, Y_test.shape

In [None]:
# Linear-regression probe: one LinearRegression per board cell, fit on the
# residual-stream activations. NOTE(review): targets are categorical codes
# (-1 / 1 / 2 after `remap`), so regression treats them as ordinal values.
X_np, Y_np, X_test_np, Y_test_np = (
    t.cpu().numpy() for t in (X, Y, X_test, Y_test)
)

regressor = MultiOutputRegressor(LinearRegression())
regressor.fit(X_np, Y_np)
# Train / test R^2 (sklearn's regressor score, averaged over the outputs).
train_r2 = regressor.score(X_np, Y_np)
test_r2 = regressor.score(X_test_np, Y_test_np)
print(train_r2, test_r2)

In [None]:
# Per-cell regression weights stacked as (n_outputs, n_features).
# Note: `W` is reassigned further down to the logistic-probe weights.
W = np.stack([est.coef_ for est in regressor.estimators_])

In [None]:
# Logistic-regression probe: one multiclass classifier per board cell, fit on
# the same residual-stream activations as the linear probe.
X_np = X.cpu().numpy()
Y_np = Y.cpu().numpy()
X_test_np = X_test.cpu().numpy()
Y_test_np = Y_test.cpu().numpy()

classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
classifier.fit(X_np, Y_np)

# Stacking the per-cell probabilities below only makes sense if every cell's
# estimator saw the same label set (same columns in predict_proba).
probe_classes = classifier.estimators_[0].classes_
assert all(
    np.array_equal(est.classes_, probe_classes) for est in classifier.estimators_
), "per-cell classifiers were trained on different label sets"

# predict_proba returns one (n_samples, n_classes) array per cell; stack to
# (n_cells, n_samples, n_classes) and put samples first so the flattening
# below lines up with Y.reshape(-1).
Y_pred = np.array(classifier.predict_proba(X_np)).transpose(1, 0, 2)
Y_test_pred = np.array(classifier.predict_proba(X_test_np)).transpose(1, 0, 2)

# MultiOutputClassifier.score is subset accuracy: a sample counts as correct
# only if every cell is predicted correctly.
print(classifier.score(X_np, Y_np), classifier.score(X_test_np, Y_test_np))
# Pass `labels` so the probability columns are matched to the classifier's
# classes explicitly, instead of being inferred from y_true (which would fail
# or misalign if a split happened to miss a class).
print(
    log_loss(
        Y_np.reshape(-1),
        Y_pred.reshape(-1, Y_pred.shape[-1]),
        labels=probe_classes,
    ),
    log_loss(
        Y_test_np.reshape(-1),
        Y_test_pred.reshape(-1, Y_test_pred.shape[-1]),
        labels=probe_classes,
    ),
)

In [None]:
# Cache activations for the full held-out games; the attention-pattern and
# SVD cells below index into this `cache` by position in `focus_games_test`.
_, cache = model.run_with_cache(focus_games_test)
print(cache.keys())

In [None]:
def plot_attn(x, pattern):
    """Plot one attention heatmap per head.

    Args:
        x: token labels for the sequence, used as tick labels on both axes;
            its length should match the last two dimensions of ``pattern``.
        pattern: array of shape (n_heads, seq, seq). Each row is rescaled by
            its maximum so structure is visible in every row.

    ``pattern`` is not modified. Callers typically pass
    ``tensor.cpu().numpy()``, which shares memory with a CPU tensor, so
    in-place normalisation would corrupt the activation cache.
    """
    n_heads = pattern.shape[0]
    # squeeze=False keeps `axs` a 2-D array even when there is a single head,
    # so iterating over it never fails.
    fig, axs = plt.subplots(
        ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )
    ticks = range(len(x))
    for i, ax in enumerate(axs[0]):
        # Out-of-place division: never write into the caller's array.
        p = pattern[i] / np.max(pattern[i], axis=-1, keepdims=True)
        ax.imshow(p, cmap="viridis", aspect="equal")
        ax.set(
            title=f"Head {i}",
            xticks=ticks,
            xticklabels=x,
            yticks=ticks,
            yticklabels=x,
            xlabel="Key",
            ylabel="Query",
        )


# Attention patterns of one held-out game. `cache` was produced by running the
# model on `focus_games_test`, so the state and token labels must come from the
# same tensor (using `focus_games` would label another game's attention).
idx = 3
game = focus_games_test[idx]
print(tensor_to_state(game))
for layer in (0, 1):
    plot_attn(
        TicTacToeDataset.decode(game), cache["pattern", layer][idx].cpu().numpy()
    )

In [None]:
# Per-cell logistic-regression weights stacked along a new leading axis
# (one `coef_` matrix per board cell); this overwrites the regression `W`.
W = np.stack([est.coef_ for est in classifier.estimators_])
W.shape

In [None]:
# Singular-value spectrum of the final LayerNorm output over the cached
# held-out games, computed separately for each of the first `n_positions`
# sequence positions.
n_positions = 9

data = []
for i in range(n_positions):
    # presumably (n_games, d_model): all cached games at sequence position i
    activations = cache["ln_final.hook_normalized"][:, i]
    singulars = torch.linalg.svdvals(activations).cpu().numpy()
    for j, s in enumerate(sorted(singulars, reverse=True)):
        data.append({"idx": j, "singular": s, "position": str(i)})

ax = sns.lineplot(data=pd.DataFrame(data), x="idx", y="singular", hue="position")
# NOTE(review): the dashed line marks index 9, presumably the number of board
# cells -- confirm the intended reference rank.
ax.axvline(9, color="k", linestyle="--")
ax.set(
    yscale="log",
    ylim=(1e-1, 1e3),
    xlabel="Singular value index",
    ylabel="Singular value",
    title="ln_final output spectrum by position",
);