In [None]:
# Re-import edited project modules (e.g. tic_tac_gpt) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import pickle
from pathlib import Path
from functools import partial

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from lightning import fabric
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from transformer_lens import utils
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import log_loss
from einops import einsum

from tic_tac_gpt.data import TicTacToeDataset, TicTacToeState, tensor_to_state

In [None]:
# This notebook only runs inference/analysis, so autograd is disabled globally.
torch.set_grad_enabled(False)
# New tensors (e.g. the torch.randperm / torch.zeros / torch.tensor calls below)
# are created on the GPU when one is available.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Experiment checkpoint directory holding the pickled config and the weights.
checkpoint_dir = Path("out/model/exp6")

# NOTE(review): pickle.load can execute arbitrary code — only load trusted checkpoints.
with open(checkpoint_dir / "config.pkl", "rb") as f:
    config: HookedTransformerConfig = pickle.load(f)
model = HookedTransformer(config)
# Lightning Fabric is only used here to read the checkpoint file (presumably
# written by the training script via Fabric — confirm).
# NOTE(review): `F` is commonly the alias for torch.nn.functional; a more
# descriptive name would avoid confusion.
F = fabric.Fabric(precision="16-mixed")
state_dict = F.load(checkpoint_dir / "model.pt")
model.load_and_process_state_dict(state_dict)
model = model.eval()

In [None]:
# Train / test splits of the "50_50" dataset, read from JSONL files.
ds_train = TicTacToeDataset(Path("out/dataset/50_50/train.jsonl"))
ds_test = TicTacToeDataset(Path("out/dataset/50_50/test.jsonl"))

In [None]:
# One random training game; dataset items unpack as a 1-tuple holding the token tensor.
# NOTE(review): `random` is unseeded here, so this pick changes on every run.
(x,) = random.choice(ds_train)

In [None]:
def visualise_logit(logit):
    """Plot the next-token distribution implied by one logit vector.

    The logits are softmaxed over their only axis. Token ids that decode to an
    int (board moves) are drawn as a 3x3 heatmap in row-major order (cell
    ``move`` goes to row ``move // 3``, column ``move % 3``). Every other
    token is shown in a bar chart of "special" moves.

    Args:
        logit: 1-D tensor of logits indexed by token id (presumably of length
            d_vocab — TODO confirm).

    Returns:
        The matplotlib Figure.
    """
    probs = logit.softmax(0)
    board = [[0] * 3 for _ in range(3)]
    special = {}
    for i, p in enumerate(probs.tolist()):
        move = TicTacToeDataset.decode_one(i)
        if isinstance(move, int):
            board[move // 3][move % 3] = p
        else:
            special[move] = p

    fig, axs = plt.subplots(
        ncols=2, figsize=(10, 4), gridspec_kw={"width_ratios": [1.5, 1]}
    )
    # The values are post-softmax probabilities, so title the panel accordingly
    # (it was previously mislabelled "Logits").
    m = axs[0].imshow(board, aspect="equal", vmin=0)
    axs[0].set(xticks=[], yticks=[], title="Move probabilities")
    fig.colorbar(m, ax=axs[0])
    sns.barplot(x=list(special.keys()), y=list(special.values()), ax=axs[1])
    axs[1].set(title="Special moves", ylim=(0, 1))

    return fig


def visualise_board(state):
    """Render ``state.board`` as a greyscale image and return the Figure.

    The colour scale is pinned to 0..2, so figures for different states share
    the same grey levels.
    """
    figure, axis = plt.subplots()
    axis.imshow(state.board, aspect="equal", cmap="gray", vmax=2, vmin=0)
    axis.set(title="Board", xticks=[], yticks=[])
    return figure


# Sanity check on the sampled game: print its decoded state, then plot the
# model's next-token distribution at sequence position 3. The trailing `;`
# keeps the returned Figure from being echoed as cell output.
print(tensor_to_state(x))
visualise_logit(model(x)[0, 3]);

In [None]:
# Seed torch (CPU and all CUDA devices) so the sampled focus games — and every
# probe / SVD / PCA result derived from them — are reproducible across restarts.
torch.manual_seed(0)
# Up to 4096 random games from each split; indexing the dataset with an index
# tensor returns a tuple whose first element is the stacked token tensor.
focus_games = ds_train[torch.randperm(len(ds_train))[:4096]][0]
focus_games_test = ds_test[torch.randperm(len(ds_test))[:4096]][0]
focus_games.shape

In [None]:
# Cache all activations for the sampled games. Each game's last token is
# dropped from the input (`[:, :-1]`), so there is one cached position for each
# of the remaining input tokens.
_, cache = model.run_with_cache(focus_games[:, :-1])
_, cache_test = model.run_with_cache(focus_games_test[:, :-1])
print(cache.keys())

In [None]:
def extract_states(game: torch.Tensor):
    """Lazily yield the decoded state after each growing prefix of ``game``.

    Prefixes ``game[:1]``, ``game[:2]``, ... up to the full length are decoded
    with ``tensor_to_state``. The generator stops right after yielding the
    first state whose ``result`` is not ``"in_progress"``, so any tokens after
    the end of the game are never decoded.
    """
    length = 1
    while length <= game.shape[0]:
        current = tensor_to_state(game[:length])
        yield current
        if current.result != "in_progress":
            return
        length += 1


def board_to_tensor(board):
    """Convert ``board`` to a tensor and merge its last two axes.

    For a 3x3 nested list this yields a length-9 vector in row-major order.
    """
    as_tensor = torch.tensor(board)
    return as_tensor.flatten(start_dim=-2, end_dim=-1)


# Walk one sample game move by move, printing each decoded state followed by
# its flattened board tensor.
board_states = extract_states(focus_games[7])
for s in board_states:
    print(s, board_to_tensor(s.board), sep="\n")

In [None]:
# Layer index whose "resid_post" activations extract_XY reads when building probe data.
probe_layer = 0


def invert(x):
    """Return a new tensor equal to ``x`` but with the values 1 and 2 swapped.

    All other values are kept as-is. Both masks are computed from the
    original ``x`` before any filling, so the two swaps cannot interfere.
    """
    ones, twos = x == 1, x == 2
    return x.masked_fill(ones, 2).masked_fill(twos, 1)


def remap(x, mapping=None):
    """Return a copy of ``x`` with values relabelled according to ``mapping``.

    Args:
        x: tensor of (integer) cell values; it is not modified.
        mapping: dict ``{old_value: new_value}``. Every mask is computed on the
            original ``x``, so overlapping mappings such as ``{1: 2, 2: 1}``
            act as one simultaneous relabelling. Values that are not keys are
            left unchanged. Defaults to the identity ``{0: 0, 1: 1, 2: 2}``,
            which reproduces the previous hard-coded behaviour (a plain copy).

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    if mapping is None:
        mapping = {0: 0, 1: 1, 2: 2}
    y = x.clone()
    for old, new in mapping.items():
        y[x == old] = new
    return y


def extract_XY(games, cache, layer=None, act_name="resid_post"):
    """Build a probe dataset pairing cached activations with board states.

    For each game and each position ``i``, pairs the activation at position
    ``i`` with the board decoded from ``game[:i + 1]`` (via ``extract_states``),
    flattened by ``board_to_tensor``. On even positions the target is passed
    through ``invert`` (values 1 and 2 swapped) — presumably so labels are
    relative to the player to move; confirm. Pairing stops at whichever comes
    first: the end of the game or the last cached position.

    Args:
        games: iterable of per-game token tensors (e.g. a (batch, seq) tensor).
        cache: ActivationCache from running the model on ``games`` (or a
            prefix of each game along the sequence axis).
        layer: layer index to read. Defaults to the module-level
            ``probe_layer``, looked up at call time as before.
        act_name: activation name, read as ``cache[act_name, layer]``.

    Returns:
        ``(X, Y)``: stacked activations (one row per game/position pair) and
        the matching flattened board targets.

    Raises:
        ValueError: if no (activation, state) pair could be extracted.
    """
    if layer is None:
        layer = probe_layer
    activations = cache[act_name, layer]
    X, Y = [], []
    for game, game_activation in zip(games, activations):
        for i, (state, activation) in enumerate(
            zip(extract_states(game), game_activation)
        ):
            X.append(activation)
            target = board_to_tensor(state.board)
            if i % 2 == 0:
                target = invert(target)
            Y.append(target)
    if not X:
        # torch.stack on an empty list gives an opaque error; fail clearly instead.
        raise ValueError("extract_XY: no (activation, state) pairs were extracted")
    return torch.stack(X), torch.stack(Y)


# Probe datasets: each row pairs one cached activation with the flattened board
# at that position (train games / test games).
X, Y = extract_XY(focus_games, cache)
X_test, Y_test = extract_XY(focus_games_test, cache_test)
X.shape, Y.shape, X_test.shape, Y_test.shape

In [None]:
# Baseline probe: one ordinary-least-squares regressor per board cell, fitted on
# train activations and scored on both train and test.
X_np, Y_np = X.cpu().numpy(), Y.cpu().numpy()
X_test_np, Y_test_np = X_test.cpu().numpy(), Y_test.cpu().numpy()

regressor = MultiOutputRegressor(LinearRegression()).fit(X_np, Y_np)
print(regressor.score(X_np, Y_np), regressor.score(X_test_np, Y_test_np))

In [None]:
# Classification probe: one logistic regression per board cell.
X_np = X.cpu().numpy()
Y_np = Y.cpu().numpy()
X_test_np = X_test.cpu().numpy()
Y_test_np = Y_test.cpu().numpy()

classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
classifier.fit(X_np, Y_np)

# predict_proba returns one (N, n_classes) array per cell; stacking gives
# (cells, N, n_classes), transposed here to (N, cells, n_classes).
Y_pred = np.array(classifier.predict_proba(X_np)).transpose(1, 0, 2)
Y_test_pred = np.array(classifier.predict_proba(X_test_np)).transpose(1, 0, 2)
# NOTE(review): per the sklearn docs, MultiOutputClassifier.score is exact-match
# accuracy (every cell of the board must be right) — confirm; mean per-cell
# accuracy would be at least as high.
print(classifier.score(X_np, Y_np), classifier.score(X_test_np, Y_test_np))
# Cross-entropy pooled over every (position, cell) prediction.
print(
    log_loss(Y_np.reshape(-1), Y_pred.reshape(-1, Y_pred.shape[-1])),
    log_loss(Y_test_np.reshape(-1), Y_test_pred.reshape(-1, Y_test_pred.shape[-1])),
)

In [None]:
def plot_attn(x, pattern):
    """Plot one row-normalised attention heatmap per head.

    Args:
        x: token labels used for both axes; its length should match the last
            two dimensions of ``pattern``.
        pattern: numpy array of shape (n_heads, seq, seq). Rows are presumably
            query positions and columns key positions — confirm.

    Each row is divided by its own maximum, so the strongest entry in every
    row is 1. The input array is NOT modified.

    Returns:
        None.
    """
    n_heads = pattern.shape[0]
    # squeeze=False keeps `axs` a 2-D array, so a single head still iterates.
    fig, axs = plt.subplots(
        ncols=n_heads, figsize=(n_heads * 3, 3), squeeze=False
    )
    for i, ax in enumerate(axs[0]):
        # Divide out-of-place: callers pass `.cpu().numpy()` views, which share
        # memory with the activation cache when it lives on the CPU, so an
        # in-place `/=` would silently rescale the cached attention patterns.
        p = pattern[i] / np.max(pattern[i], axis=-1, keepdims=True)
        ax.imshow(p, cmap="viridis", aspect="equal")
        ax.set(
            title=f"Head {i}",
            xticks=range(len(x)),
            xticklabels=x,
            yticks=range(len(x)),
            yticklabels=x,
        )
    fig.tight_layout()


# Attention patterns of every head in layers 0 and 1 for one training game.
idx = 10
print(tensor_to_state(focus_games[idx]))
attn_labels = TicTacToeDataset.decode(focus_games[idx])[:-1]
for layer in (0, 1):
    plot_attn(attn_labels, cache["pattern", layer][idx].cpu().numpy())

In [None]:
def cosine_sim(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12):
    """Pairwise cosine similarity between the rows of ``x`` and ``y``.

    Args:
        x: (..., n, d) tensor.
        y: (..., m, d) tensor, or a single (d,) vector.
        eps: lower bound applied to every norm, so all-zero rows give a
            similarity of 0 instead of NaN.

    Returns:
        Tensor of shape (..., n, m) with
        ``out[..., i, j] = cos(x[..., i, :], y[..., j, :])``.
    """
    x = x / x.norm(dim=-1, keepdim=True).clamp_min(eps)
    y = y / y.norm(dim=-1, keepdim=True).clamp_min(eps)
    # `.mT` swaps only the last two dims; `.T` would reverse *all* dims for
    # batched (ndim > 2) inputs and produce a wrong product.
    return x @ (y.mT if y.ndim >= 2 else y)

In [None]:
print(cache.keys())

# Layer-0 attention outputs grouped by sequence position. Only positions whose
# input token id is 0-8 or BOS are kept.
# `cache` was computed from `focus_games`, so the tokens used to filter it must
# come from those same games; pairing it with `focus_games_test` mismatched games.
keep_tokens = (0, 1, 2, 3, 4, 5, 6, 7, 8, ds_train.bos_token)
attn_out_dim = cache["attn_out", 0].shape[-1]
seq_activations = [[] for _ in range(ds_train.max_seq_len - 1)]
for activations, game in zip(cache["attn_out", 0], focus_games):
    for i, (activation, move) in enumerate(zip(activations, game)):
        if move.item() in keep_tokens:
            seq_activations[i].append(activation)
# A position no kept token reached becomes an empty (0, d) tensor instead of
# crashing torch.stack.
seq_activations = [
    torch.stack(s) if s else torch.empty(0, attn_out_dim) for s in seq_activations
]
print([s.shape for s in seq_activations])

# Singular-value spectrum of the activations at each position.
data = []
for i, activations in enumerate(seq_activations):
    if activations.shape[0] == 0:
        continue
    singulars = torch.linalg.svdvals(activations).cpu().numpy()
    for j, s in enumerate(sorted(singulars, reverse=True)):
        data.append({"idx": j, "singular": s, "position": str(i)})

ax = sns.lineplot(data=pd.DataFrame(data), x="idx", y="singular", hue="position")
# Dashed guides at hand-picked indices (TODO: document what each one marks).
# The loop variable is deliberately not `x`, which would overwrite the sample
# game `x` used by the earlier sanity-check cell.
for rank_marker in (1, 9, 45, 49, 53, 57, 65, 73, 81, 85):
    ax.axvline(rank_marker, color="black", linestyle="--", alpha=0.5)
ax.set(yscale="log", ylim=(1e-9, None))

In [None]:
print(cache.keys())

# PCA of layer-0 query / key / value vectors for a single head, each point
# coloured by whether its sequence position is even.
head_idx = 3
act_Q, act_K, act_V = (
    cache[hook_name][:, :, head_idx, :]
    for hook_name in (
        "blocks.0.attn.hook_q",
        "blocks.0.attn.hook_k",
        "blocks.0.attn.hook_v",
    )
)

seq_idx = (torch.arange(act_Q.shape[1]) % 2 == 0).tile((act_Q.shape[0], 1))
seq_idx = seq_idx.flatten(0, 1).cpu().numpy()

fig, axs = plt.subplots(ncols=3, figsize=(12, 4))
for ax, act in zip(axs, (act_Q, act_K, act_V)):
    act = act.flatten(0, 1).cpu().numpy()
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(act)
    print(pca.explained_variance_ratio_)
    ax.scatter(*X_pca.T, c=seq_idx)

In [None]:
# SVD of layer-0, head-3 query and key weight matrices (head index is
# hard-coded here, independent of `head_idx` above).
U_Q, S_Q, Vh_Q = torch.linalg.svd(model.W_Q[0, 3], full_matrices=False)
U_K, S_K, Vh_K = torch.linalg.svd(model.W_K[0, 3], full_matrices=False)
# print(U_head.shape, S_head.shape, Vh_head.shape)

# Cosine similarity between the right-singular vectors (rows of Vh) of W_Q and
# W_K; assumes both have the same trailing (d_head) dimension — TODO confirm.
sim = cosine_sim(Vh_Q, Vh_K)
# sim = cosine_sim(model.W_pos, U_Q.T)
# sim = cosine_sim(model.W_pos, U_K.T)

fig, ax = plt.subplots(figsize=(10, 4))
sns.heatmap(sim.cpu(), cmap="RdBu", center=0, square=True, ax=ax)

In [None]:
print(cache.keys())
# Layer-0 value vectors with the last two axes merged (presumably head and
# d_head — confirm), compared against the residual stream entering layer 0.
A = cache["v", 0, "attn"].flatten(-2, -1)
# A = einsum(A, model.W_O[0], "b s h d, h d d2 -> b s h d2")
# A, B = A[:, :, 0], A[:, :, 3]
B = cache["resid_pre", 0]
print(A.shape, B.shape, model.W_O[0].shape)

# Compare both activation sets at a single sequence position.
seq_idx = 6
A = A[:, seq_idx, :]
B = B[:, seq_idx, :]

# torch.linalg.svd returns Vh, whose *rows* are the right singular vectors.
# The deprecated Tensor.svd() returns V instead (vectors in the columns), so
# slicing `Vh[:R]` on its output picked rows of V — not singular vectors.
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
R = torch.linalg.matrix_rank(A)
U_B, S_B, Vh_B = torch.linalg.svd(B, full_matrices=False)
R_B = torch.linalg.matrix_rank(B)
print(U.shape, S.shape, Vh.shape, R)
print(U_B.shape, S_B.shape, Vh_B.shape, R_B)

# Cosine similarity between the top-R right-singular directions of A and B.
# NOTE(review): requires A and B to share their feature width
# (n_heads * d_head == d_model) — confirm for this config.
sim = cosine_sim(Vh[:R], Vh_B[:R_B])

fig, ax = plt.subplots(figsize=(10, 4))
sns.heatmap(sim.cpu(), cmap="RdBu", center=0, square=True, ax=ax)
# sns.histplot(sim.cpu().numpy().flatten(), bins=100, ax=ax, stat="density")
# sns.lineplot(x=range(1, len(S) + 1), y=S.cpu().numpy(), ax=ax)
# sns.lineplot(x=range(1, len(S_resid) + 1), y=S_resid.cpu().numpy(), ax=ax)
# ax.set(yscale="log")

In [None]:
# Layer-0 attention output with batch and position merged into one row axis,
# versus the model's W_E_pos matrix (presumably token + positional embeddings
# stacked — confirm).
V_out_0 = cache["attn_out", 0].flatten(0, 1)
V_base = model.W_E_pos

U_out_0, S_out_0, Vh_out_0 = torch.linalg.svd(V_out_0, full_matrices=False)
U_base, S_base, Vh_base = torch.linalg.svd(V_base, full_matrices=False)

# Singular-value spectra on log scales; the dashed line at index 85 is a
# hand-picked marker (TODO: document its origin).
fig, axs = plt.subplots(ncols=2, figsize=(10, 4))
sns.lineplot(x=range(len(S_out_0)), y=S_out_0.cpu(), ax=axs[0])
sns.lineplot(x=range(len(S_base)), y=S_base.cpu(), ax=axs[1])
axs[0].axvline(85, color="black", linestyle="--", alpha=0.5)
axs[0].set(yscale="log")
axs[1].set(yscale="log")
fig.tight_layout()


# V_base = V_base / torch.linalg.norm(V_base, dim=-1, keepdim=True)
# V_out_0 = V_out_0 / torch.linalg.norm(V_out_0, dim=-1, keepdim=True)

# sim_out_0_out_0 = V_out_0 @ V_out_0.transpose(-2, -1)
# NOTE(review): sim_out_0_base is computed but only plotted in a commented-out
# line below; it is one row per (game, position), so it can be large.
sim_out_0_base = cosine_sim(V_out_0, V_base)
sim_base_base = cosine_sim(V_base, V_base)

# Distribution of pairwise cosine similarities between rows of W_E_pos.
fig, ax = plt.subplots()
# sns.histplot(sim_out_0_out_0.flatten().cpu(), ax=ax, label="out_0 vs out_0")
# sns.histplot(sim_out_0_base.flatten().cpu(), ax=ax, label="out_0 vs base")
sns.histplot(sim_base_base.flatten().cpu(), ax=ax, label="base vs base", stat="density")
# ax.set(yscale="log")
# ax.legend()

In [None]:
# Two hand-written games that are identical except that the 2nd and 3rd moves
# (cells 1 and 4) are swapped. game_1 is the "clean" run and game_2 the
# "corrupted" run used for activation patching below.
game_1 = torch.tensor([[ds_train.bos_token, 5, 1, 4, 2, 8, 0]])
game_2 = torch.tensor([[ds_train.bos_token, 5, 4, 1, 2, 8, 0]])
print(tensor_to_state(game_1[0]))
print(tensor_to_state(game_2[0]))

clean_logits, clean_cache = model.run_with_cache(game_1)
corrupted_logits, corrupted_cache = model.run_with_cache(game_2)

In [None]:
def patch_head_vector(
    corrupted_head_vector, hook, seq_idx, head_index, source_cache=None
):
    """Forward hook that overwrites one head at one position with cached values.

    Copies ``source_cache[hook.name][:, seq_idx, head_index, :]`` into the same
    slice of the activation flowing through ``hook`` and returns it.

    Args:
        corrupted_head_vector: the activation being hooked, modified in place.
            It is indexed as (batch, pos, head, d) — presumably ``hook_z``;
            confirm. Despite the name, this is the run being *patched*, not
            the source of the patch.
        hook: the hook point; only ``hook.name`` is used.
        seq_idx: sequence position to patch.
        head_index: attention head to patch.
        source_cache: mapping from hook name to activation tensor to copy
            from. Defaults to the module-level ``corrupted_cache`` (looked up
            at call time), matching the previous behaviour.

    Returns:
        ``corrupted_head_vector`` after patching.
    """
    if source_cache is None:
        source_cache = corrupted_cache
    corrupted_head_vector[:, seq_idx, head_index, :] = source_cache[hook.name][
        :, seq_idx, head_index, :
    ]
    return corrupted_head_vector


def logit_diff(
    patched_logits, reference_logits=None, target_idx=None, other_idx=None, seq_idx=-1
):
    """Change in a mean-logit difference at one position versus a reference run.

    For a logits tensor ``L`` indexed as (batch, pos, token), defines
    ``d(L) = mean(L[:, seq_idx, target_idx]) - mean(L[:, seq_idx, other_idx])``
    per batch element. Returns the batch mean of
    ``d(patched_logits) - d(reference_logits)``. This is an absolute
    difference, not a ratio.

    Args:
        patched_logits: logits from the patched run.
        reference_logits: logits to compare against. Defaults to the
            module-level ``clean_logits``, looked up at call time.
        target_idx: token ids of the first group; defaults to ids 0-8.
        other_idx: token ids of the second group; defaults to ``[10]``
            (meaning of token 10 not visible here — TODO document).
        seq_idx: position to read; defaults to the last one.

    Returns:
        0-d tensor.
    """
    if reference_logits is None:
        reference_logits = clean_logits
    if target_idx is None:
        target_idx = list(range(9))
    if other_idx is None:
        other_idx = [10]

    def _diff(logits):
        # Mean over each token group at the chosen position -> shape (batch,).
        return logits[:, seq_idx, target_idx].mean(-1) - logits[
            :, seq_idx, other_idx
        ].mean(-1)

    return (_diff(patched_logits) - _diff(reference_logits)).mean()


# Activation patching: for every (layer, head, position), re-run the clean game
# with that head's `z` activation at that position taken from the corrupted
# run, and record logit_diff of the result.
# NOTE(review): despite its name, this tensor holds per-head `z` patching
# results, not residual-stream patching results.
patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, game_1.shape[1]
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        for seq_idx in range(game_1.shape[1]):
            hook_fn = partial(patch_head_vector, head_index=head_index, seq_idx=seq_idx)
            patched_logits = model.run_with_hooks(
                game_1,
                fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
                return_type="logits",
            )
            patched_logit_diff = logit_diff(patched_logits)
            patched_residual_stream_diff[
                layer, head_index, seq_idx
            ] = patched_logit_diff

In [None]:
# One heatmap per head (rows = layers, columns = sequence positions), coloured
# by the logit_diff change from patching that head's `z` at that position.
token_labels = ds_train.decode(game_1[0])
# squeeze=False keeps `axs` a 2-D array, so a single-head model still works.
fig, axs = plt.subplots(nrows=model.cfg.n_heads, figsize=(8, 8), squeeze=False)
# Shared colour limits so that panels for different heads are comparable.
_min = patched_residual_stream_diff.min().item()
_max = patched_residual_stream_diff.max().item()
for i, ax in enumerate(axs[:, 0]):
    sns.heatmap(
        patched_residual_stream_diff[:, i].cpu(),
        ax=ax,
        cmap="RdBu",
        square=True,
        center=0,
        vmin=_min,
        vmax=_max,
        # Pass the labels to seaborn so they sit on its one-per-column ticks
        # instead of overwriting whatever tick labels it chose.
        xticklabels=token_labels,
    )
    ax.set(title=f"Head {i}", xlabel="Seq", ylabel="Layer")
fig.tight_layout()