In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import neural_tangents as nt
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from einops import rearrange

# from idiots.experiments.classification.training import restore
from idiots.experiments.grokking.training import restore

In [None]:
# Checkpoint directory of the run to analyse. Exactly one path should be active;
# switch experiments by swapping which line is commented out (and keep the
# `restore` import in the imports cell pointing at the matching experiment).
# checkpoint_dir = Path("logs/mnist/mnist_adamw/checkpoints")
checkpoint_dir = Path("logs/grokking/division_47/checkpoints")

# Restore the experiment config, the train state and both dataset splits
# from the checkpoint saved at the requested step.
config, state, ds_train, ds_test = restore(checkpoint_dir, step=10000)

# Number of target classes, read from the metadata of the "y" feature.
n_classes = ds_train.features["y"].num_classes

In [None]:
# Display the configuration object returned by `restore` for the loaded checkpoint.
config

In [None]:
# Empirical NTK of the restored model. `trace_axes=()` keeps the output axes
# instead of tracing over them, so the kernel is 4-D; later cells index it as
# "b1 b2 d1 d2".
kernel_fn = nt.empirical_ntk_fn(
    state.apply_fn,
    trace_axes=(),
    vmap_axes=0,
    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
)

# Build the batched kernel once. Calling `nt.batch` inside `ntk` would create a
# fresh wrapper on every call, so any compilation cached on the wrapper would
# be thrown away and redone each time.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=64)


def ntk(x1, x2):
    """Compute the empirical NTK between two batches of inputs.

    The kernel is evaluated in batches of 64 with the current `state.params`
    (looked up at call time, so a re-restored `state` is picked up).

    Args:
        x1: first batch of model inputs.
        x2: second batch of model inputs. When it is the very same object as
            `x1`, `None` is passed on instead, which neural_tangents documents
            as "use x1 for both sides".

    Returns:
        The kernel as returned by neural_tangents, left unflattened. Use
        `rearrange(k, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")` if a 2-D matrix is
        needed.
    """
    x2 = None if x1 is x2 else x2
    return batched_kernel_fn(x1, x2, state.params)

In [None]:
# MNIST
# Draw 128 test samples. The shuffle is seeded so that the same subset (and
# therefore the same kernel) is obtained on every Restart & Run All.
samples = ds_test.shuffle(seed=0).select(range(128))
x, y = samples["x"], samples["y"]

# Sort the samples by label so that same-class samples are adjacent along the
# sample axes of the kernel.
order_ = np.argsort(y)
x, y = x[order_], y[order_]

# NTK of the selected samples against themselves.
K = ntk(x, x)
K.shape

In [None]:
# Arithmetic
# Pool the inputs of both splits into a single array.
samples = jnp.concat((ds_train["x"], ds_test["x"]), axis=0)

# Keep only the rows whose first column equals 4
# (presumably the first operand of the arithmetic task; verify against the dataset).
first_col_is_4 = samples[:, 0] == 4
samples = samples[first_col_is_4]

# Order the remaining rows by the value in their third column.
order_ = np.argsort(samples[:, 2])
x = samples[order_]

# NTK of the selected inputs against themselves.
K = ntk(x, x)
K.shape

In [None]:
# Collapse the 4-D kernel ("b1 b2 d1 d2") to a 2-D matrix for plotting.
# Active choice: average over the two sample axes, leaving a (d1, d2) matrix.
# Alternatives kept for quick switching:
# K_reduced = jnp.mean(K, axis=(-2, -1))
# K_reduced = K[-2, -2]
K_reduced = jnp.mean(K, axis=(0, 1))
print(f"{K_reduced.shape=}")

# Model predictions for the inputs the kernel was computed on.
logits = state.apply_fn(state.params, x)
y_pred = jnp.argmax(logits, axis=-1)
print(f"{y_pred=}")

# Two stacked axes: a thin strip on top (currently unused, see the commented
# imshow call) and the kernel heatmap below.
fig, axs = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(10, 6),
    gridspec_kw=dict(height_ratios=[1, 10]),
)
# axs[0].imshow(rearrange(x, "b h w -> h (b w)"), aspect="equal", cmap="gray")
sns.heatmap(K_reduced, ax=axs[1], square=True, cmap="RdBu", center=0)

In [None]:
# Evaluate on up to 10000 test samples; clamp to the split size so that
# `select` does not index past the end of a smaller test set.
n = min(10000, len(ds_test))
# Seeded shuffle so the evaluated subset is reproducible.
samples = ds_test.shuffle(seed=0).select(range(n))
xs, ys = samples["x"], samples["y"]
ys_pred = jnp.argmax(state.apply_fn(state.params, xs), axis=-1)

# Build the confusion matrix over *all* classes. `pd.crosstab` only emits
# rows/columns for labels that actually occur, so reindex to the full label
# range; otherwise a class that is never predicted makes the matrix
# non-square and shifts the diagonal.
true_labels = pd.Index(np.arange(n_classes), name="True")
pred_labels = pd.Index(np.arange(n_classes), name="Predicted")
confusion_matrix = pd.crosstab(
    np.asarray(ys), np.asarray(ys_pred), rownames=["True"], colnames=["Predicted"]
).reindex(index=true_labels, columns=pred_labels, fill_value=0)

# Remove the diagonal (correct predictions) so only the errors are shown.
# Done with a mask rather than by writing through `.values`, which is not
# guaranteed to modify the frame (and fails under pandas copy-on-write).
off_diagonal = ~np.eye(n_classes, dtype=bool)
confusion_matrix = confusion_matrix.where(off_diagonal, 0).astype(int)

sns.heatmap(confusion_matrix, annot=True, square=True, fmt="d", cmap="RdBu", center=0)

In [None]:
# Flatten the 4-D kernel into a square matrix whose rows and columns each
# enumerate (sample, output) pairs.
K_flat = rearrange(K, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")

# Eigen-decomposition of the flattened kernel: `lambda_` holds the
# eigenvalues, `es` the matching eigenvectors as columns.
lambda_, es = jnp.linalg.eigh(K_flat)

# Numerical rank of the flattened kernel.
K_rank = jnp.linalg.matrix_rank(K_flat)
print(K_rank)

In [None]:
# Eigenvalue spectrum of the flattened kernel, plotted in reversed order
# on a logarithmic y-axis.
fig, ax = plt.subplots(figsize=(10, 6))
eigvals_reversed = lambda_[::-1]
sns.lineplot(eigvals_reversed, ax=ax)
ax.set(yscale="log")

# print(lambda_[-1])
# sns.heatmap(e1, ax=ax, cmap="RdBu", center=0)
# ax.set(yticks=[])

# eig_idx = -6
# e = rearrange(es[:, eig_idx], "(b d) -> b d", d=10)
# e_max_idx = jnp.argmax(jnp.abs(e), axis=1)
# e_max = jnp.abs(e)[jnp.arange(e.shape[0]), e_max_idx]
# e_max_sign = e[jnp.arange(e.shape[0]), e_max_idx]
# top_samples = jnp.argsort(e_max)[-20:]

# img = x[top_samples] * rearrange(e_max_sign[top_samples], "b -> b 1 1")
# img = rearrange(img[::-1], "b h w -> h (b w)")
# v = jnp.max(jnp.abs(img))
# ax.imshow(img, aspect="equal", cmap="RdBu", vmin=-v, vmax=v)
# ax.set(xticks=[], yticks=[])