In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import pickle
import json
import math
from pathlib import Path
from functools import partial

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from lightning import fabric
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from transformer_lens import utils
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import log_loss
from einops import einsum, rearrange, unpack, repeat
from matplotlib import cm, colors
from tqdm import tqdm

from tic_tac_gpt.data import TicTacToeDataset, TicTacToeState, tensor_to_state

import tic_tac_gpt.utils

In [None]:
# Inference-only notebook: turn autograd off globally.
torch.set_grad_enabled(False)
# New tensors default to the GPU when one is available, otherwise the CPU.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Directory holding the pickled model config and the checkpoint loaded below.
checkpoint_dir = Path("out/model/exp23")

# NOTE(review): pickle.load can execute arbitrary code; only load config files
# from a trusted source.
with open(checkpoint_dir / "config.pkl", "rb") as f:
    config: HookedTransformerConfig = pickle.load(f)
model = HookedTransformer(config)
# Fabric is used here only to read the checkpoint file.
F = fabric.Fabric(precision="16-mixed")
state_dict = F.load(checkpoint_dir / "model_40000.pt")
model.load_and_process_state_dict(state_dict)
model = model.eval()

In [None]:
# Train/test splits. The commented-out lines load the "50_50" dataset instead of
# the "50_50_even" one used here.
# ds_train = TicTacToeDataset.from_file(Path("out/dataset/50_50/train.jsonl"))
# ds_test = TicTacToeDataset.from_file(Path("out/dataset/50_50/test.jsonl"))
ds_train = TicTacToeDataset.from_file(Path("out/dataset/50_50_even/train.jsonl"))
ds_test = TicTacToeDataset.from_file(Path("out/dataset/50_50_even/test.jsonl"))

In [None]:
def visualise_logit(logit, ax1=None, ax2=None, vmin=None, vmax=None, center: float = 0):
    """Draw a vector of values as a 3x3 board heatmap plus a column of special tokens.

    Entries 0-8 fill the board row by row; entries 9, 10 and 11 are shown as
    the "[X]", "[O]" and "[D]" cells. Any further entries are ignored.

    Args:
        logit: tensor of values to plot (must support ``tolist``/``min``/``max``).
        ax1: axes for the board. If ``ax1`` or ``ax2`` is None a new two-panel
            figure is created and both axes are replaced.
        ax2: axes for the special-token column.
        vmin: lower colour limit; defaults to the smallest value in ``logit``.
        vmax: upper colour limit; defaults to the largest value in ``logit``.
        center: value placed at the middle of the diverging colormap.
    """
    values = logit.tolist()

    # Board squares come first, in row-major order.
    board = [[0] * 3 for _ in range(3)]
    for cell, value in enumerate(values[:9]):
        row, col = divmod(cell, 3)
        board[row][col] = value

    # The three entries after the board are the special tokens.
    special = dict(zip(("[X]", "[O]", "[D]"), values[9:12]))

    if ax1 is None or ax2 is None:
        _, (ax1, ax2) = plt.subplots(
            ncols=2, figsize=(3, 2), gridspec_kw={"width_ratios": [1, 0.4]}
        )
    vmin = logit.min().item() if vmin is None else vmin
    vmax = logit.max().item() if vmax is None else vmax

    # Both panels share the colormap settings so colours are comparable.
    shared = dict(
        cmap="RdBu", center=center, square=True, cbar=False, vmin=vmin, vmax=vmax
    )
    sns.heatmap(board, ax=ax1, **shared)
    ax1.set(yticks=[], xticks=[])
    sns.heatmap([[value] for value in special.values()], ax=ax2, **shared)
    ax2.set(yticks=[], xticks=[])


def visualise_board(state):
    """Render ``state.board`` as a grayscale image (colour range 0-2) and return the figure."""
    figure, axis = plt.subplots()
    axis.imshow(state.board, vmin=0, vmax=2, cmap="gray", aspect="equal")
    axis.set(title="Board", xticks=[], yticks=[])
    return figure


# Sanity check on one random training game: print it, then plot the model
# output at sequence index len(game).
(x,) = random.choice(ds_train)
game = tensor_to_state(x)
print(game)
visualise_logit(model(x)[0, len(game)])

In [None]:
# Up to 5000 randomly chosen games from each split. No seed is set in this
# notebook, so the sample changes from run to run.
focus_games = ds_train[torch.randperm(len(ds_train))[:5000]][0]
focus_games_test = ds_test[torch.randperm(len(ds_test))[:5000]][0]
focus_games.shape

In [None]:
# Run the model on every focus game with its final token dropped, keeping the
# activation caches for the analyses below.
logits, cache = model.run_with_cache(focus_games[:, :-1])
cache.compute_head_results()
logits_test, cache_test = model.run_with_cache(focus_games_test[:, :-1])
cache_test.compute_head_results()
print(cache.keys())

In [None]:
# NOTE(review): kurtosis is only referenced by the commented-out line below.
from scipy.stats import kurtosis

a = cache["resid_mid", 0]

# Histogram over every layer-0 "resid_mid" activation value.
sns.histplot(a.flatten().cpu().numpy(), bins=100)
# kurtosis(a.flatten().cpu().numpy(), fisher=False)

In [None]:
def extract_states(game: torch.Tensor):
    """Yield the state decoded from each successively longer prefix of ``game``.

    Prefixes start at length 1. Iteration stops after the first state whose
    ``result`` is not "in_progress", or when the whole tensor has been used.
    """
    prefix_len = 0
    while prefix_len < game.shape[0]:
        prefix_len += 1
        state = tensor_to_state(game[:prefix_len])
        yield state
        if state.result != "in_progress":
            return


def board_to_tensor(board):
    """Convert ``board`` to a tensor and merge its last two dimensions into one."""
    grid = torch.tensor(board)
    return grid.flatten(start_dim=-2, end_dim=-1)


# Walk through one focus game, printing each state and its flattened board.
board_states = extract_states(focus_games[7])
for s in board_states:
    print(s)
    print(board_to_tensor(s.board))

In [None]:
# Layer index used by extract_XY when reading activations from the cache.
probe_layer = 0


def invert(x):
    """Return a new tensor equal to ``x`` with the values 1 and 2 swapped.

    All other values are kept; ``x`` itself is not modified.
    """
    # Both masks come from the original tensor, so the two fills cannot interfere.
    ones, twos = x == 1, x == 2
    return x.masked_fill(ones, 2).masked_fill(twos, 1)


def remap(x):
    """Return a copy of ``x`` relabelled through a value table.

    The table currently maps 0, 1 and 2 to themselves, so the result equals
    ``x``; values outside the table are left unchanged.
    """
    table = {0: 0, 1: 1, 2: 2}
    relabelled = x.clone()
    for old, new in table.items():
        # Masks are taken from the original tensor, not the partially edited copy.
        relabelled[x == old] = new
    return relabelled


def extract_XY(games, cache):
    """Build probe inputs and targets from games and their cached activations.

    For every state produced by ``extract_states`` the matching activation from
    ``cache["resid_mid", probe_layer]`` is paired with the flattened board.
    Boards at even step indices are passed through ``invert`` first.

    Returns:
        ``(X, Y)``: the stacked activations and the stacked board targets.
    """
    activations = cache["resid_mid", probe_layer]
    # activations = cache["attn_out", probe_layer]
    # activations = cache["z", probe_layer, "attn"][:, :, 2, :]
    features, targets = [], []
    for game, game_activation in zip(games, activations):
        steps = zip(extract_states(game), game_activation)
        for step, (state, activation) in enumerate(steps):
            board = board_to_tensor(state.board)
            # Swap the 1/2 labels on even steps.
            targets.append(invert(board) if step % 2 == 0 else board)
            features.append(activation)
    return torch.stack(features), torch.stack(targets)


# Probe datasets for the train and test focus games.
X, Y = extract_XY(focus_games, cache)
X_test, Y_test = extract_XY(focus_games_test, cache_test)
X.shape, Y.shape, X_test.shape, Y_test.shape

In [None]:
# Move the probe data to NumPy for scikit-learn.
X_np = X.cpu().numpy()
Y_np = Y.cpu().numpy()
# Y_np = Y.flatten(-2, -1).cpu().numpy()
X_test_np = X_test.cpu().numpy()
Y_test_np = Y_test.cpu().numpy()
# Y_test_np = Y_test.flatten(-2, -1).cpu().numpy()

# One independent linear regression per target column; prints train and test scores.
regressor = MultiOutputRegressor(LinearRegression())
regressor.fit(X_np, Y_np)
print(regressor.score(X_np, Y_np), regressor.score(X_test_np, Y_test_np))

In [None]:
# Same NumPy conversion as the regression cell, repeated so this cell runs on its own.
X_np = X.cpu().numpy()
Y_np = Y.cpu().numpy()
X_test_np = X_test.cpu().numpy()
Y_test_np = Y_test.cpu().numpy()

# One independent logistic-regression classifier per target column.
classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
classifier.fit(X_np, Y_np)

# predict_proba returns one array per output; stack them and move the sample
# axis first so the layout lines up with Y.
# NOTE(review): np.array(...) assumes every output saw the same number of classes.
Y_pred = np.array(classifier.predict_proba(X_np)).transpose(1, 0, 2)
Y_test_pred = np.array(classifier.predict_proba(X_test_np)).transpose(1, 0, 2)
print(classifier.score(X_np, Y_np), classifier.score(X_test_np, Y_test_np))
# Log loss pooled over all samples and outputs, train then test.
print(
    log_loss(Y_np.reshape(-1), Y_pred.reshape(-1, Y_pred.shape[-1])),
    log_loss(Y_test_np.reshape(-1), Y_test_pred.reshape(-1, Y_test_pred.shape[-1])),
)

In [None]:
def plot_attn(x, pattern):
    """Plot one attention heatmap per head next to a shared 0-1 colorbar.

    Args:
        x: token labels, used as tick labels on both axes of every heatmap.
        pattern: attention patterns with shape (n_heads, seq, seq).

    Returns:
        The matplotlib figure.
    """
    n_heads = pattern.shape[0]
    tick_labels = [f"\\texttt{{{token}}}" for token in x]

    # One panel per head plus a narrow final panel for the colorbar.
    fig, axs = plt.subplots(
        ncols=n_heads + 1,
        figsize=(n_heads * 4, 4),
        gridspec_kw={"width_ratios": [1] * n_heads + [0.07]},
    )
    for head, ax in zip(range(n_heads), axs):
        sns.heatmap(
            pattern[head],
            ax=ax,
            cmap="viridis",
            square=True,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            cbar=False,
            linewidth=0.5,
            linecolor="black",
            clip_on=False,
        )
        ax.set(title=f"Head {head}")

    colour_scale = cm.ScalarMappable(
        cmap="viridis", norm=colors.Normalize(vmin=0, vmax=1)
    )
    fig.colorbar(colour_scale, cax=axs[-1])
    return fig


# Pick a random focus game. random.randrange excludes its upper bound, whereas
# random.randint(0, n) can return n and index one past the end of focus_games.
idx = random.randrange(len(focus_games))
game = focus_games[idx]
# game = torch.tensor([ds_train.bos_token, 1, 5, 4, 8, 7])
print(tensor_to_state(game))

# Plot the layer-0 attention pattern of every head for this single game.
_, tmp_cache = model.run_with_cache(game)
fig = plot_attn(
    TicTacToeDataset.decode(game),
    tmp_cache["pattern", 0][0].cpu().numpy(),
)
# fig.savefig("out/figs/attn_pattern.pdf", bbox_inches="tight")

In [None]:
def cosine_sim(x: torch.Tensor, y: torch.Tensor):
    """Cosine similarity between every last-dim vector of ``x`` and of ``y``.

    The result has shape ``x.shape[:-1] + y.shape[:-1]``.
    """
    unit_x, unit_y = (torch.nn.functional.normalize(t, dim=-1) for t in (x, y))
    return torch.tensordot(unit_x, unit_y, dims=([-1], [-1]))


def normalize(x, dim=-1):
    """Centre ``x`` along ``dim`` and divide by the root-mean-square of the centred values.

    No epsilon is added, so a slice that is constant along ``dim`` divides by zero.
    """
    centred = x - x.mean(dim, keepdim=True)
    rms = centred.pow(2).mean(dim, keepdim=True).sqrt()
    return centred / rms


# Symmetry check: swapping the arguments should transpose the result. Compared
# with a tolerance because the two calls may accumulate the dot products in a
# different order, which makes exact float equality fragile.
x = torch.randn(4, 5)
y = torch.randn(3, 5)
assert torch.allclose(cosine_sim(x, y), cosine_sim(y, x).T)

In [None]:
head_idx = 1

# Group layer-0 "resid_mid" activations by sequence position. `cache` was
# computed from `focus_games`, so the activations must be paired with those same
# games; pairing them with `focus_games_test` filtered each activation by a
# token from an unrelated game.
seq_activations = [[] for _ in range(ds_train.max_seq_len - 1)]
for activations, game in zip(cache["resid_mid", 0], focus_games):
    # activations = activations[:, head_idx]
    for i, (activation, move) in enumerate(zip(activations, game)):
        move = move.item()
        # Keep only positions whose input token is a board move or the BOS token.
        if move in (0, 1, 2, 3, 4, 5, 6, 7, 8, ds_train.bos_token):
            seq_activations[i].append(activation)
seq_activations = [torch.stack(s) for s in seq_activations]
print([s.shape for s in seq_activations])

# Singular values of each position's activation matrix, scaled by the largest one.
data = []
for i, activations in enumerate(seq_activations):
    singulars = torch.linalg.svdvals(activations).cpu().numpy()
    singulars = singulars / singulars[0]
    for j, s in enumerate(sorted(singulars, reverse=True)):
        data.append({"idx": j, "singular": s, "position": str(i)})

ax = sns.lineplot(data=pd.DataFrame(data), x="idx", y="singular", hue="position")
# Vertical guides at hand-picked singular-value indices.
# for x in (1, 9, 10, 11, 12, 13, 15, 17, 19, 20):
for x in (0, 8, 10, 16, 20, 22, 24, 26, 29, 32, 35, 38):
    ax.axvline(x, color="black", linestyle="--", alpha=0.5)
ax.set(yscale="log", ylim=(1e-3, 1), xlim=(-1, 50))
# ax.set(ylim=(None, None), xlim=(-1, 40))
# ax.set(yscale="log")

In [None]:
# NOTE(review): game_1 and game_2 are the same token sequence, so the "clean"
# and "corrupted" runs below are identical — confirm which game was meant to differ.
game_1 = torch.tensor([[ds_train.bos_token, 0, 5, 1, 8, 2]])
game_2 = torch.tensor([[ds_train.bos_token, 0, 5, 1, 8, 2]])
print(tensor_to_state(game_1[0]))
print(tensor_to_state(game_2[0]))

# Logits and activation caches used by the patching cells that follow.
clean_logits, clean_cache = model.run_with_cache(game_1)
corrupted_logits, corrupted_cache = model.run_with_cache(game_2)

In [None]:
def patch_head_vector(corrupted_head_vector, hook, head_index):
    """Forward hook that overwrites one head's activation at position -2.

    The replacement is read from the module-level ``corrupted_cache`` under
    ``hook.name``, at position -1 for the same head. The incoming tensor is
    edited in place and returned.
    """
    replacement = corrupted_cache[hook.name][:, -1, head_index, :]
    corrupted_head_vector[:, -2, head_index, :] = replacement
    return corrupted_head_vector


def logit_diff(patched_logits):
    """Batch mean of logit 10 minus logit 9 at the second-to-last position.

    Args:
        patched_logits: logits indexed as (batch, position, vocab).

    Returns:
        A scalar tensor. The clean-run baseline is not subtracted (that term
        was already disabled), so the function no longer reads the global
        ``clean_logits``.
    """
    # terminate_idx = [9, 10, 11]
    # other_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    terminate_idx = [10]
    other_idx = [9]
    seq_idx = -2
    position_logits = patched_logits[:, seq_idx]
    diff_patched = position_logits[:, terminate_idx].mean(-1) - position_logits[
        :, other_idx
    ].mean(-1)
    return diff_patched.mean()


# Result grid indexed as (layer, head, sequence position).
patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, game_1.shape[1]
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        # patch_head_vector always patches position -2 and takes no position
        # argument, so the patched run is the same for every sequence index.
        # Run it once per (layer, head) and fill the whole row.
        # NOTE(review): if per-position patching is intended, the hook needs a
        # position parameter.
        hook_fn = partial(patch_head_vector, head_index=head_index)
        patched_logits = model.run_with_hooks(
            game_1,
            fwd_hooks=[(utils.get_act_name("v", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_residual_stream_diff[layer, head_index, :] = logit_diff(
            patched_logits
        )

In [None]:
# One heatmap per head (layers x sequence positions), all on a shared colour range.
# NOTE(review): iterating `axs` assumes plt.subplots returned an array, i.e. n_heads > 1.
fig, axs = plt.subplots(nrows=model.cfg.n_heads, figsize=(8, 8))
_min = patched_residual_stream_diff.min().item()
_max = patched_residual_stream_diff.max().item()
for i, ax in enumerate(axs):
    sns.heatmap(
        patched_residual_stream_diff[:, i].cpu(),
        ax=ax,
        cmap="RdBu",
        square=True,
        center=0,
        vmin=_min,
        vmax=_max,
    )
    ax.set(
        xlabel="Seq",
        ylabel="Layer",
        xticklabels=ds_train.decode(game_1[0]),
    )
fig.tight_layout()

In [None]:
# Pick a random focus game to inspect. random.randint is inclusive at both ends,
# so the index lies in 0..2048; assumes focus_games has at least 2049 rows — TODO confirm.
game_idx = random.randint(0, 2048)
game_tensor = focus_games[game_idx]
# game_tensor = torch.tensor([ds_train.bos_token, 8, 6, 0, 4, 1, 5, 2])
# game_tensor = torch.tensor([ds_train.bos_token, 3, 6, 1, 5, 4, 8, 7])
game = tensor_to_state(game_tensor)

print(game)

In [None]:
# Output token whose logit is attributed, and the number of positions plotted.
out_idx = 1
s = len(game) + 1


def logit_contribution(activation):
    """Project ``activation`` onto the unembedding column of the global ``out_idx``."""
    return torch.matmul(activation, model.W_U[:, out_idx])


# Per-head outputs with the final layer norm applied, reshaped so layer and head
# are separate axes.
head_results = cache.stack_head_results()
head_results = cache.apply_ln_to_stack(head_results)
head_results = rearrange(
    head_results, "(l nh) b s d -> b s l nh d", l=model.cfg.n_layers
)

# out[layer, head, position]: that head's contribution to the out_idx logit for
# the selected game.
out = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, s)
for layer_idx in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        for seq_idx in range(s):
            out[layer_idx, head_idx, seq_idx] = logit_contribution(
                head_results[game_idx, seq_idx, layer_idx, head_idx]
            )

# One row per position: model logits (board + special tokens) next to the
# layer x head contribution heatmap.
fig, axs = plt.subplots(
    ncols=3,
    nrows=s,
    figsize=(10, 18),
    gridspec_kw={"width_ratios": [3, 1, 3]},
)
# v_min = out.min().item()
# v_max = out.max().item()
for seq_idx in range(s):
    visualise_logit(logits[game_idx, seq_idx], axs[seq_idx, 0], axs[seq_idx, 1])
    sns.heatmap(
        out[:, :, seq_idx].cpu(),
        ax=axs[seq_idx, 2],
        cmap="RdBu",
        center=0,
        square=True,
    )
fig.tight_layout()

In [None]:
# Position whose residual stream is decomposed.
# seq_idx = len(game) - 2
seq_idx = len(game)


# NOTE(review): this redefines logit_contribution with an extra out_idx
# parameter, shadowing the one-argument version from the previous cell.
def logit_contribution(activation, out_idx):
    """Project ``activation`` onto unembedding column ``out_idx``."""
    return torch.matmul(activation, model.W_U[:, out_idx])


tmp_logits, tmp_cache = model.run_with_cache(game_tensor)
all_acts = tmp_cache.get_full_resid_decomposition(pos_slice=seq_idx, apply_ln=True)
# all_acts = tmp_cache.apply_ln_to_stack(all_acts, pos_slice=seq_idx)
# Keep batch element 0.
all_acts = all_acts[:, 0, :]

# Split the stacked components into heads, neurons and the three single-row terms.
heads, neurons, emb, pos, bias = unpack(
    all_acts,
    [
        [model.cfg.n_layers, model.cfg.n_heads],
        [model.cfg.n_layers, model.cfg.d_mlp],
        [1],
        [1],
        [1],
    ],
    "* d",
)
print(game)
# Print each single-row term's contribution to every output logit.
for name, act in {"emb": emb, "pos": pos, "bias": bias}.items():
    logits_contributions = [
        logit_contribution(act, i).item() for i in range(model.cfg.d_vocab_out)
    ]
    print(f"{name:>4}: {' '.join(f'{x:>5.2f}' for x in logits_contributions)}")

# One row per output token: head contributions (left) and neuron contributions (right).
with plt.style.context("default"):
    fig, axs = plt.subplots(
        ncols=2,
        nrows=model.cfg.d_vocab_out,
        figsize=(12, 1 * model.cfg.d_vocab_out),
        gridspec_kw={"width_ratios": [1, 10]},
        sharey="row",
        sharex="col",
    )
    for i in range(model.cfg.d_vocab_out):
        heads_logits = logit_contribution(heads, i).flatten()
        sns.scatterplot(
            x=np.arange(heads_logits.numel()),
            y=heads_logits.cpu().numpy(),
            ax=axs[i, 0],
        )
        neurons_logits = logit_contribution(neurons, i).flatten()
        sns.scatterplot(
            x=np.arange(neurons_logits.numel()),
            y=neurons_logits.cpu().numpy(),
            ax=axs[i, 1],
        )
        # label outliers
        for x, y in enumerate(neurons_logits):
            if y.abs() > 0.5:
                axs[i, 1].text(x, y, f"{x}", fontsize=8, ha="center", va="center")
        axs[i, 0].set(ylabel=f"{ds_train.decode_one(i)}")
    fig.tight_layout()
    visualise_logit(tmp_logits[0, seq_idx])

In [None]:
# Layer-0 MLP input weights.
V_in = model.W_in[0, :, :]

# Token and positional embeddings, stacked: first through every layer-0 head's
# OV circuit, then directly, then summed over all of those paths.
E_base = torch.concat([model.W_E[:9], model.W_pos[:10]], dim=0)
E = E_base @ model.OV[0, :].AB
E = torch.concat([E, E_base.unsqueeze(0)], dim=0)
E = torch.concat([E, E.sum(0, keepdim=True)], dim=0)
E = normalize(E)

# Only the first len(E) - 2 slices are heads; the last two are the direct
# embeddings and the sum over paths, so they get their own titles instead of
# being labelled as extra heads.
panel_titles = [f"Head {h}" for h in range(len(E) - 2)] + ["Direct", "Sum"]

# sim = cosine_sim(V_in.T, E)
sim = torch.tensordot(V_in.T, E, dims=([-1], [-1]))

print(f"{sim.shape=}")
with plt.style.context("default"):
    fig, axs = plt.subplots(ncols=len(E), figsize=(20, 100))
    # Show the 200 neurons with the largest similarity on any path.
    order_ = torch.argsort(torch.amax(sim, dim=(-2, -1)), descending=True)[:200]
    for i, ax in enumerate(axs):
        sns.heatmap(
            sim[order_, i].cpu(),
            cmap="RdBu",
            center=0,
            square=True,
            ax=ax,
            # vmin=0,
            yticklabels=order_.cpu().numpy(),
            cbar=True,
        )
        ax.set(title=panel_titles[i])
    fig.tight_layout()

In [None]:
# Direct effect of each layer-0 MLP neuron on the output logits.
out_circuit = model.W_out[0] @ model.W_U

# Rank neurons by the spread between their largest and smallest logit effect;
# keep the top 30.
effects = out_circuit.amax(dim=-1) - out_circuit.amin(dim=-1)
neurons = torch.argsort(effects, descending=True)[:30]

per_row = 10
scale = 0.25
# NOTE(review): vmin/vmax are computed here but the call below passes a
# per-neuron vmax and leaves vmin commented out.
vmin = out_circuit[neurons].min().item()
vmax = out_circuit[neurons].max().item()

# NOTE(review): "no-latex" is not a built-in matplotlib style; presumably a
# project style file — confirm it is installed.
with plt.style.context("no-latex"):
    # Two axes per neuron: board panel and special-token panel.
    fig, axs = plt.subplots(
        ncols=2 * per_row,
        nrows=math.ceil(len(neurons) / per_row),
        figsize=(4 * scale * per_row, 4 * scale * math.ceil(len(neurons) / per_row)),
        width_ratios=[3, 1] * per_row,
        squeeze=False,
    )
    for i, neuron in enumerate(neurons):
        ax1 = axs[i // per_row, 2 * (i % per_row)]
        ax2 = axs[i // per_row, 2 * (i % per_row) + 1]
        ax1.set(title=f"{neuron}")

        # Colour scale is anchored to each neuron's own maximum effect.
        visualise_logit(
            out_circuit[neuron],
            ax1,
            ax2,
            # vmin,
            vmax=out_circuit[neuron].max().item(),
            center=out_circuit[neuron].max().item() - 1,
        )

    fig.tight_layout()

In [None]:
# Layer-0 "mlp_mid" activations flattened to (samples, d_mlp).
neuron_acts = cache["mlp_mid", 0].view(-1, model.cfg.d_mlp)

# Activity threshold: 99th percentile, estimated from every 100th activation value.
threshold = torch.quantile(neuron_acts.flatten()[::100], 0.99)
# threshold = 1e-2
active = neuron_acts > threshold
# Number of samples on which each neuron exceeds the threshold.
activity = active.sum(0)

print(threshold)
print(torch.argsort(activity, descending=True)[:100])
# print(torch.where(active))

fig, ax = plt.subplots(figsize=(6, 4))

# Empirical CDF of the per-neuron activity counts.
sns.ecdfplot(activity.flatten().cpu().numpy(), ax=ax)
# ax.set(xscale="log")

In [None]:
def visualise_neuron_input_output(neuron: int):
    """Plot what drives a layer-0 MLP neuron and what the neuron writes to the logits.

    Three input circuits are shown, each as a 3x3 board panel (first nine
    embedding rows) and a positional column (remaining rows): the direct path,
    then the paths through the OV circuits of layer-0 heads 0 and 1. The last
    two panels show the neuron's output weights projected through ``W_U``.

    Args:
        neuron: index of the neuron in the layer-0 MLP.
    """
    fig, axs = plt.subplots(
        ncols=8,
        nrows=1,
        figsize=(8, 2),
        gridspec_kw={"width_ratios": [1, 1 / 9] * 3 + [1, 1 / 3]},
    )

    E_base = torch.concat([model.W_E[:9], model.W_pos[:10]], dim=0)
    w_in = model.W_in[0, :, neuron]
    b_in = model.b_in[0, neuron]

    # Direct path first, then one path per attention head.
    in_circuits = [E_base @ w_in + b_in] + [
        E_base @ model.OV[0, head, :].AB @ w_in + b_in for head in (0, 1)
    ]
    vmin = min(circuit.min().item() for circuit in in_circuits)
    vmax = max(circuit.max().item() for circuit in in_circuits)

    # All input panels share one colour range.
    shared = dict(
        cmap="RdBu",
        # center=0,
        square=True,
        cbar=False,
        vmin=vmin,
        vmax=vmax,
        xticklabels=[],
        yticklabels=[],
    )
    for panel, circuit in enumerate(in_circuits):
        board = circuit[:9].reshape(3, 3)
        pos = circuit[9:].unsqueeze(-1)
        sns.heatmap(board.cpu().numpy(), ax=axs[panel * 2], **shared)
        sns.heatmap(pos.cpu().numpy(), ax=axs[panel * 2 + 1], **shared)

    out_circuit = model.W_out[0, neuron] @ model.W_U
    visualise_logit(out_circuit, axs[-2], axs[-1])


# Example: input/output circuits of neuron 7.
visualise_neuron_input_output(7)

In [None]:
# Normalised positional embeddings, directly and after each layer-0 head's OV circuit.
P = normalize(model.W_pos[:10])
P0 = P @ model.OV[0, 0, :].AB
P1 = P @ model.OV[0, 1, :].AB

# Dot product of each positional vector with every layer-0 MLP input direction.
sim = P @ model.W_in[0]
sim0 = P0 @ model.W_in[0]
sim1 = P1 @ model.W_in[0]
# sim = cosine_sim(P, model.W_in[0].T)
# sim0 = cosine_sim(P0, model.W_in[0].T)
# sim1 = cosine_sim(P1, model.W_in[0].T)

# Overlaid histograms of the three sets of dot products.
fig, ax = plt.subplots(figsize=(6, 4))
sns.histplot(sim.flatten().cpu().numpy(), ax=ax, label="P")
sns.histplot(sim0.flatten().cpu().numpy(), ax=ax, label="P0")
sns.histplot(sim1.flatten().cpu().numpy(), ax=ax, label="P1")
ax.legend()

In [None]:
def kl_divergence(p_logits, q_logits):
    """KL(p || q) over the last dimension, computed from logits.

    Entries where ``p_logits`` is -inf contribute 0. Entries where ``q_logits``
    is -inf but ``p_logits`` is not contribute NaN, so the sum for that row is NaN.
    """
    neg_inf = float("-inf")
    p_masked = p_logits == neg_inf
    q_masked = q_logits == neg_inf

    log_ratio = p_logits.log_softmax(-1) - q_logits.log_softmax(-1)
    pointwise = p_logits.softmax(-1) * log_ratio
    # Zero-probability entries of p add nothing (and would otherwise be NaN).
    pointwise = torch.where(p_masked, 0, pointwise)
    # p puts mass where q has none: flag the term as NaN.
    pointwise = torch.where(~p_masked & q_masked, float("nan"), pointwise)
    return pointwise.sum(dim=-1)


def measure_neuron_effect(neuron: int):
    """Symmetrised KL between the cached logits and logits with one neuron zeroed.

    The neuron is zero-ablated at the layer-0 "mlp_mid" hook while the model is
    re-run on ``focus_games[:, :-1]``; the result is the mean KL in both
    directions against the module-level ``logits``, added together.
    """

    def zero_neuron(value, hook):
        # In-place ablation of the chosen neuron at every batch row and position.
        value[:, :, neuron] = 0
        return value

    hook_point = utils.get_act_name("mlp_mid", 0)
    ablated_logits = model.run_with_hooks(
        focus_games[:, :-1],
        fwd_hooks=[(hook_point, zero_neuron)],
    )
    forward_kl = kl_divergence(ablated_logits, logits).mean()
    reverse_kl = kl_divergence(logits, ablated_logits).mean()
    return forward_kl + reverse_kl


# Ablate every layer-0 MLP neuron in turn (one forward pass per neuron).
# NOTE(review): the next cell rebinds `data` to values loaded from a JSON file,
# so this list is discarded unless it is saved or used first.
data = []
for neuron in tqdm(range(model.cfg.d_mlp)):
    data.append({"neuron": neuron, "loss": measure_neuron_effect(neuron).item()})

In [None]:
# NOTE(review): this reads results for "exp21" while the model above is loaded
# from "exp23" — confirm the experiment IDs are meant to differ.
with open("out/model/exp21/neurons_40000.json", "r") as f:
    data = json.load(f)

# Importance is the negated stored value for each neuron.
df = pd.DataFrame({"neuron": int(k), "importance": -v} for k, v in data.items())

fig, ax = plt.subplots(figsize=(4, 3))
sns.ecdfplot(data=df, x="importance", ax=ax)
ax.set_xscale("symlog", linthresh=1e-3)
ax.set(xlabel="Increase in loss")

# draw 90% line
importance_q90 = df["importance"].quantile(0.9)
ax.axvline(importance_q90, color="k", linestyle="--")
ax.text(
    importance_q90 + 1e-4,
    0.75,
    f"{importance_q90:.2e}",
    color="k",
)

fig.savefig("out/figs/neuron_importance.pdf", bbox_inches="tight")

# df = df.sort_values("importance", ascending=False)
# df["importance"] -= df["importance"].min()
# df["cum_importance"] = df["importance"].cumsum()
# df["cum_importance"] /= df["cum_importance"].max()

# top_neurons = df[df["cum_importance"] < 0.99]
# len(top_neurons)

# print(df["neuron"].to_list().index(505))
# display(top_neurons)

In [None]:
# Example: input/output circuits of neuron 423.
visualise_neuron_input_output(423)

In [None]:
# Token ids whose embeddings are compared below.
# NOTE(review): id 12 is presumably the BOS token — confirm against ds_train.bos_token.
target_vocab = [12, 0, 1, 2, 3, 4, 5, 6, 7, 8]
E = model.W_E[target_vocab]
# First ten positional embeddings.
P = model.W_pos[:10]


def labels_for(A):
    """Tick labels for the rows of ``A``.

    ``A`` must be the module-level ``E`` or ``P`` object (compared by identity).
    Rows of ``E`` use the decoded token when it is a string, otherwise
    ``$E_<token>$``; rows of ``P`` are labelled ``$P_<index>$``.

    Raises:
        ValueError: if ``A`` is neither ``E`` nor ``P``. This replaces a bare
            ``assert False``, which is removed under ``python -O``, and matches
            ``title_for``.
    """
    if A is E:
        decoded = [ds_train.decode_one(i) for i in target_vocab]
        return [tok if isinstance(tok, str) else f"$E_{tok}$" for tok in decoded]
    if A is P:
        return [f"$P_{i}$" for i in range(len(A))]
    raise ValueError("Unknown matrix")


def title_for(A):
    """Return the display name of ``A``: "W_E" for ``E``, "W_P" for ``P``.

    Raises:
        ValueError: if ``A`` is neither of those two objects.
    """
    for matrix, title in ((E, "W_E"), (P, "W_P")):
        if A is matrix:
            return title
    raise ValueError("Unknown matrix")


# Symmetric colour limit shared by every heatmap and by the colorbar. Setting
# only vmax (with center=0) lets seaborn widen the range whenever the data
# minimum is below -color_limit, so panels would stop matching the colorbar.
color_limit = 0.06

fig = plt.figure(figsize=(9, 5))
subfigs = fig.subfigures(1, 3, width_ratios=[1, 1, 0.07])

for i, subfig in enumerate(subfigs[:2]):
    subfig.suptitle(f"Head {i}", fontsize="xx-large")
    subfig.supxlabel("Key", fontsize="x-large")
    subfig.supylabel("Query", fontsize="x-large")
    axs = subfig.subplots(2, 2)
    # Query/key combinations: token-token, token-position, position-token,
    # position-position.
    for j, (A, B) in enumerate([(E, E), (E, P), (P, E), (P, P)]):
        sim = A @ model.QK[0, i].AB @ B.T
        if A is P and B is P:
            # Keep only key positions at or before the query position.
            sim = torch.tril(sim)
        if A is E:
            # Blank the first query row except for its first key column.
            sim[0, 1:] = 0
        sns.heatmap(
            sim.cpu(),
            ax=axs.flat[j],
            vmin=-color_limit,
            vmax=color_limit,
            center=0,
            yticklabels=labels_for(A),
            xticklabels=labels_for(B),
            cmap="RdBu",
            square=True,
            cbar=False,
            linewidth=0.5,
            linecolor="black",
            clip_on=False,
        )
        # axs.flat[j].set(title=f"${title_for(A)}^T W_Q^T W_K {title_for(B)}$")

# Single colorbar in the narrow third sub-figure.
cbar_ax = subfigs[-1].add_subplot()
im = cm.ScalarMappable(
    cmap="RdBu", norm=colors.Normalize(vmin=-color_limit, vmax=color_limit)
)
fig.colorbar(im, cax=cbar_ax)
# fig.tight_layout()

# fig.savefig("out/figs/attention.pdf", bbox_inches="tight")

In [None]:
# Nine token embeddings followed by ten positional embeddings.
# NOTE: this rebinds E, which labels_for/title_for compare against by identity.
E = torch.concat([model.W_E[:9], model.W_pos[:10]], dim=0)
print(f"{E.shape=}")
# P = model.W_pos[:10]

# Pairwise dot products; the diagonal is zeroed so self-similarity does not
# dominate the colour scale.
# sim = cosine_sim(E, E)
sim = torch.tensordot(E, E, dims=([-1], [-1]))
sim.fill_diagonal_(0)

fig, ax = plt.subplots()
# Decoded tokens for the first nine rows, then position indices 0-9.
labels = [ds_train.decode_one(i) for i in range(9)] + list(range(10))
sns.heatmap(
    sim.cpu(),
    cmap="RdBu",
    center=0,
    square=True,
    ax=ax,
    yticklabels=labels,
    xticklabels=labels,
)

In [None]:
# Dot product of every layer-0 MLP output direction with every unembedding column.
sim = torch.tensordot(model.W_out[0], model.W_U.T, dims=([-1], [-1]))
# sim = cosine_sim(model.W_out[0], model.W_U.T)
# Shift each neuron's row so its minimum is 0.
sim = sim - torch.amin(sim, dim=-1, keepdim=True)
# Sort neurons by their largest (shifted) value, descending.
order_ = torch.argsort(torch.amax(sim.abs(), dim=(-1)), descending=True)
sim = sim[order_]

fig, ax = plt.subplots(figsize=(6, 100))
sns.heatmap(
    sim.cpu(),
    cmap="RdBu",
    center=0,
    square=True,
    ax=ax,
    yticklabels=order_.tolist(),
)

In [None]:
def qk_solbol_indices(head_idx):
    """Variance decomposition of a layer-0 head's attention scores.

    Scores are computed for every (query position, query token, key position,
    key token) combination from normalised token + position embeddings, with
    key positions after the query position set to 0. The total variance is then
    split into first-order and pairwise terms.

    Args:
        head_idx: index of the head in layer 0.

    Returns:
        Dict mapping a LaTeX label for each single variable and each pair to
        its share of the total variance, plus "Higher\\norder" for the remainder.
    """
    tokens = model.W_E[[0, 1, 2, 3, 4, 5, 6, 7, 8]]
    positions = model.W_pos[[1, 2, 3, 4, 5, 6, 7, 8, 9]]
    # Residual input for every (position, token) pair: axes are (p, v, d).
    combined = normalize(tokens.unsqueeze(0) + positions.unsqueeze(1), dim=-1)
    queries = combined @ model.W_Q[0, head_idx] + model.b_Q[0, head_idx]
    keys = combined @ model.W_K[0, head_idx] + model.b_K[0, head_idx]
    # Axes of `scores`: (p_q, t_q, p_k, t_k).
    scores = torch.tensordot(queries, keys, dims=([-1], [-1]))
    n_pos = len(positions)
    causal_mask = torch.tril(torch.ones(n_pos, n_pos))[:, None, :, None]
    scores.masked_fill_(causal_mask == 0, 0)

    axes = range(scores.ndim)
    grand_mean = torch.mean(scores)

    def main_effect(axis):
        # Mean over every other axis, relative to the grand mean.
        others = [a for a in axes if a != axis]
        return torch.mean(scores, dim=others) - grand_mean

    first_order = [main_effect(axis) for axis in axes]

    def interaction(i, j):
        # Two-way mean with the grand mean and both main effects removed.
        assert i < j
        others = [a for a in axes if a != i and a != j]
        joint = torch.mean(scores, dim=others)
        return joint - grand_mean - first_order[i].unsqueeze(-1) - first_order[j]

    second_order = [[interaction(i, j) for i in range(j)] for j in axes]

    def variance(t):
        return torch.var(t, unbiased=False).item()

    v = variance(scores)
    v1 = [variance(g) for g in first_order]
    v2 = [[variance(g) for g in row] for row in second_order]

    sobol = {
        "$p_q$": v1[0] / v,
        "$t_q$": v1[1] / v,
        "$p_k$": v1[2] / v,
        "$t_k$": v1[3] / v,
        "$t_q, p_q$": v2[1][0] / v,
        "$p_q, p_k$": v2[2][0] / v,
        "$t_q, p_k$": v2[2][1] / v,
        "$p_q, t_k$": v2[3][0] / v,
        "$t_q, t_k$": v2[3][1] / v,
        "$t_k, p_k$": v2[3][2] / v,
    }
    sobol["Higher\norder"] = 1 - sum(sobol.values())
    return sobol


# One row of variance shares per head, melted to long form for seaborn.
data = []
for head_idx in range(model.cfg.n_heads):
    data.append({"Head": head_idx, **qk_solbol_indices(head_idx)})
df = pd.DataFrame(data).melt(id_vars="Head")

fig, ax = plt.subplots(figsize=(4, 3))
# Show only the five terms that explain the most variance for head 0.
order = df[df["Head"] == 0].sort_values("value", ascending=False)["variable"]
order = order[:5]
sns.barplot(data=df, x="variable", y="value", hue="Head", ax=ax, order=order)
ax.set(
    ylabel="Variance of $a^{(i)}(t_q, p_q, t_k, p_k)$ explained",
    xlabel="",
)
# fig.savefig("out/figs/sobol.pdf", bbox_inches="tight")
# fig.savefig("out/figs/sobol.pdf", bbox_inches="tight")