In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import pickle
from pathlib import Path
from functools import partial

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from lightning import fabric
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from transformer_lens import utils
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import log_loss
from einops import einsum, rearrange, unpack

from tic_tac_gpt.data import TicTacToeDataset, TicTacToeState, tensor_to_state

In [None]:
# Analysis-only session: put newly created tensors on the GPU when one exists and
# switch autograd off globally, since nothing in this notebook is trained.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
torch.autograd.set_grad_enabled(False)

In [None]:
checkpoint_dir = Path("out/model/exp18")

# Model config pickled next to the weights at training time.
# NOTE: pickle can execute arbitrary code - only load checkpoints you produced yourself.
with (checkpoint_dir / "config.pkl").open("rb") as f:
    config: HookedTransformerConfig = pickle.load(f)
model = HookedTransformer(config)

# The checkpoint is read back through Fabric; whatever F.load returns is handed to
# TransformerLens as the state dict. The model is then put in eval mode in place.
F = fabric.Fabric(precision="16-mixed")
state_dict = F.load(checkpoint_dir / "model_20000.pt")
model.load_and_process_state_dict(state_dict)
model.eval()

In [None]:
# Train and held-out test splits (JSON-lines files) of the 50_50 dataset.
ds_train, ds_test = (
    TicTacToeDataset(Path(split_path))
    for split_path in (
        "out/dataset/50_50/train.jsonl",
        "out/dataset/50_50/test.jsonl",
    )
)

In [None]:
# One random training item; it unpacks to exactly one element, the token tensor `x`.
[x] = random.choice(ds_train)

In [None]:
def visualise_logit(logit, ax1=None, ax2=None):
    """Plot one output vector: board squares as a 3x3 heatmap, other tokens as bars.

    Args:
        logit: 1-D tensor over the output vocabulary. Values are plotted as given
            (no softmax is applied).
        ax1: axes for the board heatmap.
        ax2: axes for the bar plot of non-square tokens.
            Unless both ax1 and ax2 are supplied, a new 1x2 figure is created.
    """
    board = [[0] * 3 for _ in range(3)]
    special = {}
    for token_id, value in enumerate(logit.tolist()):
        move = TicTacToeDataset.decode_one(token_id)
        if not isinstance(move, int):
            # Tokens that do not decode to a square index go to the bar plot.
            special[move] = value
            continue
        row, col = divmod(move, 3)
        board[row][col] = value

    if ax1 is not None and ax2 is not None:
        axs = (ax1, ax2)
    else:
        fig, axs = plt.subplots(
            ncols=2, figsize=(6, 2), gridspec_kw={"width_ratios": [1.5, 1]}
        )
    sns.heatmap(
        board, ax=axs[0], annot=True, fmt=".2f", square=True, cmap="RdBu", center=0
    )
    sns.barplot(ax=axs[1], x=list(special.keys()), y=list(special.values()))
    axs[1].set(title="Special moves")


def visualise_board(state):
    """Draw ``state.board`` as a grayscale image on a new figure and return it.

    Cell values are mapped onto a fixed 0..2 gray range, so colours are comparable
    across boards.
    """
    fig, ax = plt.subplots()
    ax.imshow(state.board, cmap="gray", aspect="equal", vmin=0, vmax=2)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("Board")
    return fig


# Sanity check: show the sampled game and the model's output at position 3.
print(tensor_to_state(x))
visualise_logit(logit=model(x)[0, 3])

In [None]:
# Fixed seed so that Restart & Run All draws the same probe subsets every time
# (otherwise the probe scores below change from run to run).
torch.manual_seed(0)

# Up to 2048 randomly chosen games per split; indexing the dataset with a tensor of
# indices returns a tuple whose element 0 holds the stacked token tensors.
focus_games = ds_train[torch.randperm(len(ds_train))[:2048]][0]
focus_games_test = ds_test[torch.randperm(len(ds_test))[:2048]][0]
focus_games.shape

In [None]:
# Run both subsets with the final token dropped, caching every activation.
logits, cache = model.run_with_cache(focus_games[:, :-1])
logits_test, cache_test = model.run_with_cache(focus_games_test[:, :-1])

# Materialise the per-head "result" activations used by the head-level analyses below.
cache.compute_head_results()
cache_test.compute_head_results()
print(cache.keys())

In [None]:
def extract_states(game: torch.Tensor):
    """Yield the decoded state of each prefix ``game[:1]``, ``game[:2]``, ...

    Iteration stops right after the first state whose ``result`` is not
    ``"in_progress"``; that finished state is still yielded, and later tokens are
    never decoded.
    """
    prefix_len = 0
    while prefix_len < game.shape[0]:
        prefix_len += 1
        state = tensor_to_state(game[:prefix_len])
        yield state
        if state.result != "in_progress":
            return


def board_to_tensor(board):
    """Convert a nested board (``... x rows x cols``) to a tensor whose last two axes are flattened."""
    as_tensor = torch.tensor(board)
    return as_tensor.flatten(start_dim=-2, end_dim=-1)


# Step through game #7 of the probe subset: each decoded state, then its flattened board.
board_states = extract_states(focus_games[7])
for s in board_states:
    print(s, board_to_tensor(s.board), sep="\n")

In [None]:
probe_layer = 0


def invert(x):
    """Return a copy of ``x`` with the labels 1 and 2 swapped; 0 is left unchanged."""
    y = x.clone()
    y[x == 1] = 2
    y[x == 2] = 1
    return y


def remap(x):
    """Relabelling hook for probe targets; currently the identity (0->0, 1->1, 2->2).

    Returns a copy, so the mapping below can be edited to try other label encodings.
    """
    y = x.clone()
    y[x == 0] = 0
    y[x == 1] = 1
    y[x == 2] = 2
    return y


def extract_XY(games, cache, layer=None):
    """Build a linear-probe dataset of (activation, board) pairs.

    For every game, the activation at position ``i`` is paired with the board of the
    ``i``-th state produced by ``extract_states`` (prefix ``game[:i + 1]``), stopping
    where the game ends. Boards at even positions are passed through ``invert``
    (labels 1 and 2 swapped) - presumably so every target is expressed relative to
    the player to move; TODO confirm.

    Args:
        games: token tensor with one game per row.
        cache: ActivationCache from running the model on ``games`` (rows must align).
        layer: layer whose ``resid_post`` is probed. Defaults to the global
            ``probe_layer`` read at call time, matching the previous behaviour.

    Returns:
        ``(X, Y)``: stacked activations and stacked flattened board targets.
    """
    if layer is None:
        layer = probe_layer
    activations = cache["resid_post", layer]
    X, Y = [], []
    for game, game_activation in zip(games, activations):
        for i, (state, activation) in enumerate(
            zip(extract_states(game), game_activation)
        ):
            target = board_to_tensor(state.board)
            if i % 2 == 0:
                target = invert(target)
            X.append(activation)
            Y.append(target)
    X, Y = torch.stack(X), torch.stack(Y)
    return X, Y


(X, Y), (X_test, Y_test) = (
    extract_XY(focus_games, cache),
    extract_XY(focus_games_test, cache_test),
)
tuple(t.shape for t in (X, Y, X_test, Y_test))

In [None]:
X_np, Y_np = X.cpu().numpy(), Y.cpu().numpy()
X_test_np, Y_test_np = X_test.cpu().numpy(), Y_test.cpu().numpy()

# Linear-regression probe: one least-squares regressor per board cell, treating the
# cell labels as real-valued targets. Prints R^2 on the train and test sets.
regressor = MultiOutputRegressor(LinearRegression()).fit(X_np, Y_np)
print(regressor.score(X_np, Y_np), regressor.score(X_test_np, Y_test_np))

In [None]:
X_np, Y_np = X.cpu().numpy(), Y.cpu().numpy()
X_test_np, Y_test_np = X_test.cpu().numpy(), Y_test.cpu().numpy()

# Logistic-regression probe: one LogisticRegression per board cell (output column).
classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000)).fit(X_np, Y_np)

# predict_proba returns one (samples, classes) array per cell; stacking them on axis 1
# gives (samples, cells, classes).
Y_pred = np.stack(classifier.predict_proba(X_np), axis=1)
Y_test_pred = np.stack(classifier.predict_proba(X_test_np), axis=1)

# Subset accuracy (every cell of a board must be right), then the log-loss pooled over
# all cells and samples.
print(classifier.score(X_np, Y_np), classifier.score(X_test_np, Y_test_np))
print(
    log_loss(Y_np.reshape(-1), Y_pred.reshape(-1, Y_pred.shape[-1])),
    log_loss(Y_test_np.reshape(-1), Y_test_pred.reshape(-1, Y_test_pred.shape[-1])),
)

In [None]:
def plot_attn(x, pattern):
    """Plot one annotated attention heatmap per head, side by side.

    Args:
        x: token labels used for both axes.
        pattern: array of shape (n_heads, seq, seq), e.g. ``cache["pattern", layer][0]``;
            rows are treated as query positions and columns as key positions
            (TransformerLens' pattern layout).
    """
    n_heads = pattern.shape[0]
    # squeeze=False keeps `axs` 2-D even for a single head; with the default, a
    # one-column figure returns a bare Axes and iterating it raises TypeError.
    fig, axs = plt.subplots(ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False)
    for i, ax in enumerate(axs[0]):
        sns.heatmap(
            pattern[i],
            ax=ax,
            cmap="viridis",
            square=True,
            annot=True,
            annot_kws={"fontsize": 6},
            xticklabels=x,
            yticklabels=x,
        )
        ax.set(title=f"Head {i}", xlabel="Key", ylabel="Query")
    fig.tight_layout()


# Pick one game from the probe subset. random.randint includes its upper bound, so
# randint(0, len(focus_games)) could return an out-of-range index; randrange cannot.
idx = random.randrange(len(focus_games))
game = focus_games[idx]
print(tensor_to_state(game))

# Attention patterns of every layer for this single game (batch element 0).
_, tmp_cache = model.run_with_cache(game)
for layer in range(model.cfg.n_layers):
    plot_attn(
        TicTacToeDataset.decode(game),
        tmp_cache["pattern", layer][0].cpu().numpy(),
    )

In [None]:
def cosine_sim(x: torch.Tensor, y: torch.Tensor):
    """Pairwise cosine similarity between the last-dim vectors of ``x`` and ``y``.

    The result has shape ``x.shape[:-1] + y.shape[:-1]``.
    """
    to_unit = partial(torch.nn.functional.normalize, dim=-1)
    return torch.tensordot(to_unit(x), to_unit(y), dims=([-1], [-1]))


def normalize(x):
    """Center ``x`` along its last dim and rescale it to unit RMS.

    This is a LayerNorm without eps or affine parameters; a constant row gives NaNs.
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    rms = centered.pow(2).mean(dim=-1, keepdim=True).sqrt()
    return centered / rms


# Sanity check: cosine_sim(a, b) is the transpose of cosine_sim(b, a).
# - Compare with a tolerance: the two matmuls may accumulate in a different order, so
#   exact float equality can fail spuriously.
# - Use private names so the notebook-level sample game `x` is not overwritten.
_a = torch.randn(4, 5)
_b = torch.randn(3, 5)
assert torch.allclose(cosine_sim(_a, _b), cosine_sim(_b, _a).T, atol=1e-6)

In [None]:
head_idx = 1

# For each sequence position, collect the layer-0 output of head `head_idx` over the
# test subset. The activations must come from `cache_test` (the run on
# `focus_games_test`). Pairing the train `cache` with test games would mix up games.
seq_activations = [[] for _ in range(ds_train.max_seq_len - 1)]
for activations, game in zip(cache_test["result", 0, "attn"], focus_games_test):
    activations = activations[:, head_idx]
    for i, (activation, move) in enumerate(zip(activations, game)):
        # Keep only positions whose token is a board square (0-8) or BOS.
        if move.item() in (0, 1, 2, 3, 4, 5, 6, 7, 8, ds_train.bos_token):
            seq_activations[i].append(activation)
# torch.stack([]) raises, so positions that never received an activation become None.
seq_activations = [torch.stack(acts) if acts else None for acts in seq_activations]
print([None if acts is None else acts.shape for acts in seq_activations])

data = []
for i, activations in enumerate(seq_activations):
    if activations is None:
        continue
    # svdvals already returns the singular values in descending order.
    singulars = torch.linalg.svdvals(activations).cpu().numpy()
    for j, s in enumerate(singulars):
        data.append({"idx": j, "singular": s, "position": str(i)})

ax = sns.lineplot(data=pd.DataFrame(data), x="idx", y="singular", hue="position")
# Reference lines at hand-picked ranks (distinct loop name so the global `x` survives).
for rank in (0, 8, 10, 16, 20, 22, 24, 26, 29, 32, 35, 38):
    ax.axvline(rank, color="black", linestyle="--", alpha=0.5)
ax.set(yscale="log", ylim=(1e-5, None), xlim=(-1, 22))
# ax.set(yscale="log")

In [None]:
# Inputs for the patching experiment (batch of one game each). game_2 is currently
# a copy of game_1 with the same move sequence.
game_1 = torch.tensor([[ds_train.bos_token, 0, 5, 1, 8, 2]])
game_2 = game_1.clone()
print(tensor_to_state(game_1[0]))
print(tensor_to_state(game_2[0]))

clean_logits, clean_cache = model.run_with_cache(game_1)
corrupted_logits, corrupted_cache = model.run_with_cache(game_2)

In [None]:
def patch_head_vector(corrupted_head_vector, hook, head_index, dst_pos=-2, src_pos=-1):
    """Hook: copy one head's vector from ``corrupted_cache`` into the running forward pass.

    Overwrites position ``dst_pos`` of head ``head_index`` with the vector that
    ``corrupted_cache`` holds for the same hook point at position ``src_pos``.
    The defaults reproduce the original experiment (copy position -1 into -2).
    The tensor is modified in place and returned.
    """
    corrupted_head_vector[:, dst_pos, head_index, :] = corrupted_cache[hook.name][
        :, src_pos, head_index, :
    ]
    return corrupted_head_vector


def logit_diff(patched_logits):
    """Mean logit difference (token 10 minus token 9) at position -2 of ``patched_logits``.

    Only the patched run's difference is returned. Subtract ``_diff(clean_logits)``
    to get the change relative to the clean run instead.
    """
    terminate_idx = [10]
    other_idx = [9]
    seq_idx = -2

    def _diff(logits):
        return logits[:, seq_idx, terminate_idx].mean(-1) - logits[
            :, seq_idx, other_idx
        ].mean(-1)

    return _diff(patched_logits).mean()


patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, game_1.shape[1]
)
for layer in range(model.cfg.n_layers):
    v_hook_name = utils.get_act_name("v", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        for seq_idx in range(game_1.shape[1]):
            # Patch destination = this column's position. Previously seq_idx never
            # reached the hook, so every column repeated the dst=-2 measurement; that
            # measurement is now the column at seq_idx == game_1.shape[1] - 2.
            hook_fn = partial(patch_head_vector, head_index=head_index, dst_pos=seq_idx)
            patched_logits = model.run_with_hooks(
                game_1,
                fwd_hooks=[(v_hook_name, hook_fn)],
                return_type="logits",
            )
            patched_residual_stream_diff[layer, head_index, seq_idx] = logit_diff(
                patched_logits
            )

In [None]:
# One heatmap per head (layer x sequence position), all on a shared colour scale.
# squeeze=False keeps `axs` 2-D so a single-head model does not break the loop.
fig, axs = plt.subplots(nrows=model.cfg.n_heads, figsize=(8, 8), squeeze=False)
_min = patched_residual_stream_diff.min().item()
_max = patched_residual_stream_diff.max().item()
for i, ax in enumerate(axs[:, 0]):
    sns.heatmap(
        patched_residual_stream_diff[:, i].cpu(),
        ax=ax,
        cmap="RdBu",
        square=True,
        center=0,
        vmin=_min,
        vmax=_max,
    )
    ax.set(
        title=f"Head {i}",
        xlabel="Seq",
        ylabel="Layer",
        xticklabels=ds_train.decode(game_1[0]),
    )
fig.tight_layout()

In [None]:
# Pick one game from the probe subset (its index also selects rows of `cache`/`logits`).
# randint(0, 2048) included 2048 and hard-coded the subset size; randrange over
# len(focus_games) is always in range.
game_idx = random.randrange(len(focus_games))
game_tensor = focus_games[game_idx]
game = tensor_to_state(game_tensor)

print(game)

In [None]:
out_idx = 1
# Number of positions to show: len(game) + 1 (presumably BOS plus one per move - TODO
# confirm), clipped to the positions present in the cached run.
s = min(len(game) + 1, logits.shape[1])


def logit_contribution(activation, out_idx=out_idx):
    """Project ``activation`` (last dim d_model) onto the unembedding column ``out_idx``.

    ``out_idx`` defaults to the value set above. Taking it as a parameter keeps this
    signature compatible with the two-argument ``logit_contribution`` defined in a later
    cell, so running cells out of order cannot break either caller.
    """
    return torch.matmul(activation, model.W_U[:, out_idx])


# Per-head outputs of the cached run with the final LayerNorm scale applied,
# rearranged to (batch, pos, layer, head, d_model).
head_results = cache.stack_head_results()
head_results = cache.apply_ln_to_stack(head_results)
head_results = rearrange(
    head_results, "(l nh) b s d -> b s l nh d", l=model.cfg.n_layers
)

# Contribution of every head at each shown position, as a (layer, head, pos) tensor.
out = rearrange(
    logit_contribution(head_results[game_idx, :s]), "s l nh -> l nh s"
).float()

# squeeze=False keeps `axs` 2-D even when only one position is shown.
fig, axs = plt.subplots(
    ncols=3,
    nrows=s,
    figsize=(10, 18),
    gridspec_kw={"width_ratios": [3, 1, 3]},
    squeeze=False,
)
for seq_idx in range(s):
    visualise_logit(logits[game_idx, seq_idx], axs[seq_idx, 0], axs[seq_idx, 1])
    sns.heatmap(
        out[:, :, seq_idx].cpu(),
        ax=axs[seq_idx, 2],
        cmap="RdBu",
        center=0,
        square=True,
    )
fig.tight_layout()

In [None]:
# Position whose residual stream is decomposed below.
seq_idx = len(game)


def logit_contribution(activation, out_idx):
    """Project ``activation`` (last dim d_model) onto the unembedding column ``out_idx``."""
    return torch.matmul(activation, model.W_U[:, out_idx])


tmp_logits, tmp_cache = model.run_with_cache(game_tensor)
# Additive components of the residual stream at `seq_idx`, with the final LayerNorm
# scale applied. Index 0 of axis 1 selects the single game in the batch.
all_acts = tmp_cache.get_full_resid_decomposition(pos_slice=seq_idx, apply_ln=True)
all_acts = all_acts[:, 0, :]

# Regroup the component axis. This assumes the components come in the order heads,
# MLP neurons, token embedding, positional embedding, biases - TODO confirm for this config.
heads, neurons, emb, pos, bias = unpack(
    all_acts,
    [
        [model.cfg.n_layers, model.cfg.n_heads],
        [model.cfg.n_layers, model.cfg.d_mlp],
        [1],
        [1],
        [1],
    ],
    "* d",
)
print(game)
for name, act in {"emb": emb, "pos": pos, "bias": bias}.items():
    logits_contributions = [
        logit_contribution(act, i).item() for i in range(model.cfg.d_vocab_out)
    ]
    print(f"{name:>4}: {' '.join(f'{x:>5.2f}' for x in logits_contributions)}")

# One row per output token: per-head (left) and per-neuron (right) contributions.
fig, axs = plt.subplots(
    ncols=2,
    nrows=model.cfg.d_vocab_out,
    figsize=(12, 1 * model.cfg.d_vocab_out),
    gridspec_kw={"width_ratios": [1, 10]},
    sharey="row",
    sharex="col",
)
for i in range(model.cfg.d_vocab_out):
    heads_logits = logit_contribution(heads, i).flatten()
    sns.scatterplot(
        x=np.arange(heads_logits.numel()), y=heads_logits.cpu().numpy(), ax=axs[i, 0]
    )
    neurons_logits = logit_contribution(neurons, i).flatten().cpu()
    sns.scatterplot(
        x=np.arange(neurons_logits.numel()),
        y=neurons_logits.numpy(),
        ax=axs[i, 1],
    )
    # Label outlier neurons with their flat (layer * d_mlp + neuron) index. Use host
    # Python floats: matplotlib cannot place text at CUDA-tensor coordinates. Distinct
    # loop names also avoid overwriting the notebook-level `x` / `y`.
    for neuron_idx, value in enumerate(neurons_logits.tolist()):
        if abs(value) > 0.3:
            axs[i, 1].text(
                neuron_idx, value, f"{neuron_idx}", fontsize=8, ha="center", va="center"
            )
    axs[i, 0].set(ylabel=f"{ds_train.decode_one(i)}")
fig.tight_layout()
visualise_logit(tmp_logits[0, seq_idx])

In [None]:
# Layer-0 MLP input weights, one column per neuron.
V_in = model.W_in[0, :, :]

# Inputs are the first 9 token embeddings (presumably the board squares) and the first
# 10 positional embeddings. Panels show them:
# - as written by each layer-0 head's OV circuit,
# - unchanged ("Direct"),
# - and the sum of all of the above.
E_base = torch.concat([model.W_E[:9], model.W_pos[:10]], dim=0)
E = E_base @ model.OV[0, :].AB
E = torch.concat([E, E_base.unsqueeze(0)], dim=0)
E = torch.concat([E, E.sum(0, keepdim=True)], dim=0)
E = normalize(E)
# The last two panels are not heads; label them as such instead of "Head n".
titles = [f"Head {h}" for h in range(len(E) - 2)] + ["Direct", "Sum (heads + direct)"]

# Dot product of every neuron's input weights with every (panel, input) vector:
# shape (d_mlp, panel, input).
sim = torch.tensordot(V_in.T, E, dims=([-1], [-1]))

print(f"{sim.shape=}")
fig, axs = plt.subplots(ncols=len(E), figsize=(20, 100))
# Rows are neurons, sorted by their largest similarity over all panels and inputs.
order_ = torch.argsort(torch.amax(sim, dim=(-2, -1)), descending=True)
for i, ax in enumerate(axs):
    sns.heatmap(
        sim[order_, i].cpu(),
        cmap="RdBu",
        center=0,
        square=True,
        ax=ax,
        yticklabels=order_.cpu().numpy(),
        cbar=True,
    )
    ax.set(title=titles[i])
fig.tight_layout()

In [None]:
E = torch.concat([model.W_E[:9], model.W_pos[:10]], dim=0)
print(f"{E.shape=}")

# Map the same token/position embeddings through the OV circuits of layer-0 heads 0
# and 1, then compare the two heads' outputs.
V1 = E @ model.OV[0, 0].AB
V2 = E @ model.OV[0, 1].AB
print(f"{V1.shape=} {V2.shape=}")

sim = cosine_sim(V1, V2)
# Symmetry sanity check with a tolerance: the two matmuls may round differently, so
# exact equality can fail spuriously (especially on GPU).
assert torch.allclose(sim, cosine_sim(V2, V1).T, atol=1e-6)

fig, ax = plt.subplots()
labels = [ds_train.decode_one(i) for i in range(9)] + list(range(10))
sns.heatmap(
    sim.cpu(),
    cmap="RdBu",
    center=0,
    square=True,
    ax=ax,
    yticklabels=labels,
    xticklabels=labels,
)
ax.set(title="Cosine similarity", ylabel="Head 0 OV output", xlabel="Head 1 OV output")

In [None]:
# W_E_pos stacks the token embeddings (W_E) on top of all n_ctx positional embeddings (W_pos).
P = model.W_E_pos

# Row/column labels: token names, then position indices 0..n_ctx-1. This was previously
# hard-coded to 11 positions, which mislabels the axes whenever n_ctx != 11.
labels = [ds_train.decode_one(i) for i in range(ds_train.vocab_size)] + list(
    range(model.cfg.n_ctx)
)

# Layer-0 QK bilinear form evaluated on every pair of embeddings, one panel per head.
n_heads = model.cfg.n_heads
fig, axs = plt.subplots(ncols=n_heads, figsize=(7.5 * n_heads, 5), squeeze=False)
for head, ax in enumerate(axs[0]):
    sim = P @ model.QK[0, head].AB @ P.T
    sns.heatmap(
        sim.cpu(),
        cmap="RdBu",
        center=0,
        square=True,
        ax=ax,
        xticklabels=labels,
        yticklabels=labels,
    )
    ax.set(title=f"Head {head}", xlabel="Key position", ylabel="Query position")

In [None]:
E = torch.concat([model.W_E[:9], model.W_pos[:10]], dim=0)
print(f"{E.shape=}")

# Raw (unnormalised) Gram matrix: the dot product of every pair of the first 9 token
# embeddings and the first 10 positional embeddings.
sim = E @ E.T

fig, ax = plt.subplots()
labels = [ds_train.decode_one(i) for i in range(9)] + list(range(10))
sns.heatmap(
    sim.cpu(),
    ax=ax,
    cmap="RdBu",
    center=0,
    square=True,
    xticklabels=labels,
    yticklabels=labels,
)