In [None]:
# Reload edited project modules (e.g. the `idiots` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Ask JAX/XLA not to pre-allocate most of the GPU memory up front (allocate on demand instead).
# Only takes effect if set before JAX initialises its backend, i.e. before the import cell.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
from pathlib import Path
import random

import jax
import neural_tangents as nt
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from einops import rearrange
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Imported only for its side effect (the name is not used below): presumably it
# registers the "science" / "grid" matplotlib styles applied at the end of this cell.
import scienceplots

# Project helper that restores a saved training run from its checkpoint directory.
from idiots.experiments.compute_results.compute_results import restore_checkpoint

# Apply the scienceplots styles to all figures in the notebook.
plt.style.use(["science", "grid"])

In [None]:
# Training run to analyse; the commented paths are alternative runs.
# checkpoint_dir = Path("logs/checkpoints/mnist/mnist_grokking_slower/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/mnist/mnist_gd_grokking/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/grokking/division_47/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/gradient_flow/exp63/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/grokking/exp112/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/grokking/exp123/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/mnist/mnist_adamw/checkpoints/")
# checkpoint_dir = Path("logs/checkpoints/grokking/division_adamw_mlp/checkpoints/")
# checkpoint_dir = Path(
#     "logs/checkpoints/grokking/division_adamw_transformer/checkpoints/"
# )
checkpoint_dir = Path(
    "logs/checkpoints/grokking/division_adamw_mlp_1_layer/checkpoints/"
)

config, apply_fn, get_params, ds_train, ds_test, all_steps = restore_checkpoint(
    checkpoint_dir,
    # experiment_type="gradient_flow_mnist",
    # experiment_type="mnist",
    experiment_type="algorithmic",
)
n_classes = ds_train.features["y"].num_classes


def _batched_ntk(x, params, trace_axes):
    """Empirical NTK of `apply_fn` between every pair of rows of `x`, in batches.

    Shared implementation of `kernel_traced_fn` and `kernel_fn`; they differ
    only in which output axes are traced.

    Args:
        x: model inputs; passing `None` as the second input makes
            neural_tangents compute the kernel of `x` with itself.
        params: model parameters, e.g. from `get_params(step)`.
        trace_axes: output axes to trace over. `(-1,)` traces the last output
            axis; `()` keeps the output axes in the result.

    Returns:
        The kernel array returned by neural_tangents.
    """
    ntk_fn = nt.empirical_ntk_fn(
        apply_fn,
        vmap_axes=0,
        trace_axes=trace_axes,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
    )
    # Batch over inputs to bound memory; batch size comes from the run config.
    return nt.batch(ntk_fn, batch_size=config.dots_batch_size)(x, None, params)


def kernel_traced_fn(x, params):
    """NTK traced over the last output axis (plotted below as a 2-D matrix)."""
    return _batched_ntk(x, params, trace_axes=(-1,))


def kernel_fn(x, params):
    """Full NTK keeping the output axes (used below as a 4-D array)."""
    return _batched_ntk(x, params, trace_axes=())

In [None]:
# Display the configuration restored with the checkpoint.
config

In [None]:
# MNIST: a small class-stratified subset of the test set, sorted by label so the
# rows/columns of the kernel plots are grouped by class. Seeded for reproducibility.
x, _, y, _ = train_test_split(
    ds_test["x"], ds_test["y"], train_size=128, stratify=ds_test["y"], random_state=0
)
order_ = np.argsort(y)
x, y = x[order_], y[order_]

In [None]:
# Traced NTK of the current `x` at several training steps, one panel per step.
ts = np.linspace(0, 2000, 11, dtype=int)

fig, axs = plt.subplots(1, len(ts), figsize=(16, 6))

# Ks = [kernel_fn(x, get_params(t)).mean((-2, -1)) for t in ts]
Ks = [kernel_traced_fn(x, get_params(t)) for t in ts]

for ax, t, K in zip(axs, ts, Ks):
    # Symmetric limits centre the diverging colormap on zero, so colour encodes sign.
    vabs = float(jnp.max(jnp.abs(K)))
    ax.imshow(K, cmap="RdBu", vmin=-vabs, vmax=vabs)
    ax.set_title(f"t={t}")
    ax.axis("off")

In [None]:
# Arithmetic: all (train + test) equations, sorted by label y first, then by
# input column 0, then by input column 2 (np.lexsort uses the last key as primary).
x = jnp.concat([ds_train["x"], ds_test["x"]], axis=0)
y = jnp.concat([ds_train["y"], ds_test["y"]], axis=0)
order_ = np.lexsort((x[:, 2], x[:, 0], y))
x = x[order_]
y = y[order_]

# indices = np.arange(1024, 1536)
# Random subset of 256 equations. The sampled indices are sorted so the subset
# keeps the ordering above; an unsorted draw would shuffle the kernel's
# rows/columns and hide any label structure. Seeded for reproducibility.
subset_rng = np.random.default_rng(0)
indices = np.sort(subset_rng.choice(len(x), 256, replace=False))
x, y = x[indices], y[indices]

K = kernel_fn(x, get_params(50000))
K.shape, K.min(), K.max()

In [None]:
# Collapse the full 4-D NTK to a (b1, b2) matrix by averaging over the two
# output axes, then plot it.
# K_reduced = K
K_reduced = jnp.mean(K, axis=(-2, -1))
# K_reduced = jnp.mean(K, axis=(0, 1))
# K_reduced = K[0, 0]
print(f"{K_reduced.shape=}")
# y_pred = jnp.argmax(apply_fn(params, x), axis=-1)
# print(f"{y_pred=}")

# fig, axs = plt.subplots(2, 1, figsize=(10, 6), gridspec_kw={"height_ratios": [1, 10]})
# axs[0].imshow(rearrange(x, "b h w -> h (b w)"), aspect="equal", cmap="gray")
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# labels = x[:, [0, 2]].tolist()
# Symmetric limits centre the diverging colormap on zero, so colour encodes sign.
vabs = float(jnp.max(jnp.abs(K_reduced)))
ax.imshow(K_reduced, cmap="RdBu", vmin=-vabs, vmax=vabs)
ax.axis("off")

In [None]:
# Inspect the parameter tree (array shape per leaf) at step 100000.
params = get_params(100000)
# `jax.tree_map` is deprecated and removed in recent JAX; use the tree_util API.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

In [None]:
# 2-D PCA of the rows of Dense_2's transposed kernel at step 100000; each point
# is labelled with its row index i and linked to row (-i) mod n.
params = get_params(100000)
# E = params["params"]["Embed_0"]["embedding"]
E = params["params"]["Dense_2"]["kernel"].T
# E = params["params"]["Dense_0"]["kernel"].T
print(f"{E.shape=}")

pca = PCA(n_components=2)
E = pca.fit_transform(E)
# tsne = TSNE(n_components=2, perplexity=30, n_iter=1000)
# E = tsne.fit_transform(E)
print(f"{E.shape=}")

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.scatter(x=E[:, 0], y=E[:, 1], c=np.arange(len(E)))

# Pair each index i with (-i) mod n, i.e. n - i for i > 0 and 0 for i == 0.
idx_1 = np.arange(len(E))
idx_2 = (-idx_1) % len(E)
print(np.stack([idx_1, idx_2]))
for x1, y1, x2, y2 in zip(E[idx_1, 0], E[idx_1, 1], E[idx_2, 0], E[idx_2, 1]):
    ax.plot([x1, x2], [y1, y2], color="black", alpha=0.1)

# Loop variables are named px/py so they do not overwrite the notebook-level
# `x`/`y` data arrays used by other cells.
for i, (px, py) in enumerate(E):
    ax.text(px, py, str(i), fontsize=12)

In [None]:
# Multiplication table modulo the operation's number of classes; the commented
# code below plots `op["fn"]` itself instead.
from idiots.dataset.algorithmic import OPERATIONS

op = OPERATIONS["x / y (mod 47)"]
# Named `residues` (not `x`) so the notebook-level kernel inputs `x` survive.
residues = range(op["n_classes"])
n_residues = len(residues)


# print table
table = [[(i * j) % n_residues for j in residues] for i in residues]
df = pd.DataFrame(table, index=residues, columns=residues)
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(df)

# fig, axs = plt.subplots(2, 2, figsize=(8, 8))
# table = [[op["fn"](i, j) for j in residues] for i in residues]
# axs[0, 0].imshow(table, cmap="RdBu")
# table = [[op["fn"](i, j) for j in reversed(residues)] for i in residues]
# axs[0, 1].imshow(table, cmap="RdBu")
# table = [[op["fn"](i, j) for j in residues] for i in reversed(residues)]
# axs[1, 0].imshow(table, cmap="RdBu")
# table = [[op["fn"](i, j) for j in reversed(residues)] for i in reversed(residues)]
# axs[1, 1].imshow(table, cmap="RdBu")

In [None]:
# MNIST representations and kernel evolution: for each training step, the top
# row shows a 2-D PCA of the first-layer ReLU features and the bottom row the
# traced NTK of the same (label-sorted) test subset.

X, _, Y, _ = train_test_split(
    ds_test["x"], ds_test["y"], train_size=512, stratify=ds_test["y"], random_state=0
)
order_ = np.argsort(Y)
X, Y = X[order_], Y[order_]

ts = range(0, 100000 + 1, 20000)

fig, axs = plt.subplots(2, len(ts), figsize=(2 * len(ts), 4), squeeze=False)

for axs_, t in zip(axs.T, ts):
    ax1, ax2 = axs_[0], axs_[1]
    params = get_params(t)
    W_1 = params["params"]["Dense_0"]["kernel"]
    b_1 = params["params"]["Dense_0"]["bias"]
    # Flatten images and scale by 1/255 -- presumably matching the training
    # preprocessing; TODO confirm.
    E = jax.nn.relu(X.reshape(len(X), -1) / 255 @ W_1 + b_1)
    pca = PCA(n_components=2)
    E = pca.fit_transform(E)
    v1, v2 = pca.explained_variance_ratio_

    ax1.scatter(x=E[:, 0], y=E[:, 1], c=Y, cmap="tab10", alpha=0.7, s=5)
    ax1.set(xticks=[], yticks=[], aspect=1 / ax1.get_data_ratio())
    # Raw f-strings: "\%" escapes the percent sign for LaTeX text rendering
    # (presumably enabled by the "science" style). In a non-raw literal "\%" is
    # an invalid escape sequence (SyntaxWarning on Python 3.12+).
    ax1.set_xlabel(rf"PCA 1: {v1*100:.2f}\%", fontsize="small")
    ax1.set_ylabel(rf"PCA 2: {v2*100:.2f}\%", fontsize="small")

    K = kernel_traced_fn(X, params)
    # Clip the colour scale at the 99.5th percentile so a few large entries
    # do not wash out the rest of the matrix.
    vmax = np.percentile(K, 99.5)
    im = ax2.imshow(K, cmap="Blues", vmin=0, vmax=vmax)
    # cbar = plt.colorbar(im, ax=ax2)
    ax2.set(xticks=[], yticks=[], xlabel=f"Step {t}")
    for spine in ax2.spines.values():
        spine.set_visible(False)

fig.tight_layout()
# fig.savefig("logs/plots/mnist-representations-blue.pdf", bbox_inches="tight", dpi=300)

In [None]:
# Algorithmic representations: for each training step, a 2-D PCA of each
# equation's concatenated token embeddings, coloured by label.

X = np.concatenate([ds_train["x"], ds_test["x"]], axis=0)
Y = np.concatenate([ds_train["y"], ds_test["y"]], axis=0)
# print(jnp.unique(Y, return_counts=True))

# Keep only labels at least `margin` away from both ends of the label range.
# The range is derived from `n_classes` rather than a hard-coded 46 (= 47 - 1),
# so the cell also works for other moduli.
margin = 5
max_label = n_classes - 1
# filter_ = X[:, 0] == 24
# filter_ = (1 <= Y) & (Y <= max_label - 1)
filter_ = (margin <= Y) & (Y <= max_label - margin)
# filter_ = (10 <= Y) & (Y <= max_label - 10)
# filter_ = slice(None)
# filter_ = (14 <= Y) & (Y <= 33)
X, Y = X[filter_], Y[filter_]

# X, _, Y, _ = train_test_split(xs, ys, train_size=len(xs), stratify=ys)

ts = range(0, 50000 + 1, 10000)

fig, axs = plt.subplots(1, len(ts), figsize=(2 * len(ts), 2), squeeze=False)

for ax, t in zip(axs[0], ts):
    params = get_params(t)
    E = params["params"]["Embed_0"]["embedding"]
    # Look up the embedding of every token in each row of X and concatenate them.
    E = E[X].reshape(len(X), -1)
    pca = PCA(n_components=2)
    E = pca.fit_transform(E)
    v1, v2 = pca.explained_variance_ratio_

    ax.scatter(x=E[:, 0], y=E[:, 1], c=Y, cmap="tab10", alpha=0.6, s=1)
    ax.set(xticks=[], yticks=[], aspect=1 / ax.get_data_ratio(), title=f"Step {t}")
    # Raw f-strings: "\%" escapes the percent sign for LaTeX text rendering; in a
    # non-raw literal it is an invalid escape sequence (SyntaxWarning on 3.12+).
    ax.set_xlabel(rf"PCA 1: {v1*100:.2f}\%", fontsize="small")
    ax.set_ylabel(rf"PCA 2: {v2*100:.2f}\%", fontsize="small")

# fig.savefig(
#     "logs/plots/algorithmic-representations-mlp-1.jpg",
#     bbox_inches="tight",
#     dpi=300,
# )

In [None]:
# Confusion matrix of the errors only (diagonal zeroed) on a random test subset.
n = 10000
# The original referenced an undefined `state`; evaluate the latest restored
# checkpoint via `get_params` / `apply_fn` instead. NOTE(review): confirm the
# latest step is the intended one.
params = get_params(max(all_steps))
samples = ds_test.shuffle(seed=0).select(range(min(n, len(ds_test))))
xs, ys = samples["x"], samples["y"]
ys_pred = jnp.argmax(apply_fn(params, xs), axis=-1)

# plot the confusion matrix; reindex over every class so the table is square
# and aligned even if some class never appears as true or predicted label
confusion_matrix = pd.crosstab(
    np.asarray(ys), np.asarray(ys_pred), rownames=["True"], colnames=["Predicted"]
).reindex(
    index=pd.RangeIndex(n_classes, name="True"),
    columns=pd.RangeIndex(n_classes, name="Predicted"),
    fill_value=0,
)
# remove the diagonal on an explicit copy (writing into `.values` may hit a
# read-only or temporary array) and without hard-coding 10 classes
counts = confusion_matrix.to_numpy(copy=True)
np.fill_diagonal(counts, 0)
confusion_matrix = pd.DataFrame(
    counts, index=confusion_matrix.index, columns=confusion_matrix.columns
)
sns.heatmap(confusion_matrix, annot=True, square=True, fmt="d", cmap="RdBu", center=0)

In [None]:
# Eigendecomposition of the full NTK flattened to a (b*d, b*d) matrix.
# This needs the 4-D kernel from `kernel_fn`. Other cells reassign `K` to 2-D
# traced kernels, so fail early with a clear message instead of an einops error.
assert K.ndim == 4, f"expected a 4-D kernel from kernel_fn, got shape {K.shape}"
K_flat = rearrange(K, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")
# eigh treats K_flat as symmetric and returns eigenvalues in ascending order.
lambda_, es = jnp.linalg.eigh(K_flat)
print(jnp.linalg.matrix_rank(K_flat))

In [None]:
# Top-9 NTK eigenvectors, each reshaped to a (sample, output) heatmap.
# The output dimension is read from the 4-D kernel instead of hard-coding 10
# (MNIST), so the cell also works for kernels with another number of outputs.
n_out = K.shape[-1]

fig, axs = plt.subplots(3, 3, figsize=(10, 8))
# sns.lineplot(lambda_[::-1], ax=ax)
# ax.set(yscale="log")

# eigh returns ascending eigenvalues, so column -k holds the k-th largest.
for eig_idx, ax in enumerate(axs.flat, start=1):
    e = rearrange(es[:, -eig_idx], "(b d) -> b d", d=n_out)
    sns.heatmap(e, ax=ax, cmap="RdBu", center=0)
    ax.set(
        yticks=[],
        ylabel="Sample (sorted)",
        xlabel="Class",
        title=f"Eigenvector {eig_idx} (λ={lambda_[-eig_idx]:.1e})",
    )

fig.tight_layout()

# eig_idx = -1
# e = rearrange(es[:, eig_idx], "(b d) -> b d", d=n_out)

# print(lambda_[-1])
# sns.heatmap(e, ax=ax, cmap="RdBu", center=0)
# ax.set(yticks=[], ylabel="Sample (sorted)", xlabel="Class")

# e_max_idx = jnp.argmax(jnp.abs(e), axis=1)
# e_max = jnp.abs(e)[jnp.arange(e.shape[0]), e_max_idx]
# e_max_sign = e[jnp.arange(e.shape[0]), e_max_idx]
# top_samples = jnp.argsort(e_max)[-20:]

# img = x[top_samples] * rearrange(e_max_sign[top_samples], "b -> b 1 1")
# img = rearrange(img[::-1], "b h w -> h (b w)")
# v = jnp.max(jnp.abs(img))
# ax.imshow(img, aspect="equal", cmap="RdBu", vmin=-v, vmax=v)
# ax.set(xticks=[], yticks=[])