In [None]:
# Automatically re-import edited project modules (e.g. the `idiots` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Turn off XLA's up-front GPU memory preallocation so memory is allocated on demand.
# This has to be set before JAX touches the GPU, which is why this cell precedes the imports.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
from pathlib import Path
import random

import neural_tangents as nt
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from einops import rearrange

from idiots.experiments.compute_results.compute_results import restore_checkpoint

In [None]:
# Run to analyse. Other checkpoint directories this notebook has been pointed at
# (swap one in together with a matching `experiment_type` below):
#   logs/checkpoints/mnist/mnist_grokking_slower/checkpoints
#   logs/checkpoints/mnist/mnist_gd_grokking/checkpoints
#   logs/checkpoints/grokking/division_47/checkpoints
#   logs/checkpoints/gradient_flow/exp37/checkpoints
#   logs/checkpoints/grokking/exp112/checkpoints
# `experiment_type` values used so far: "algorithmic", "mnist", "gradient_flow_".
checkpoint_dir = Path("logs") / "checkpoints" / "grokking" / "exp123" / "checkpoints"

(config, apply_fn, get_params, ds_train, ds_test, all_steps) = restore_checkpoint(
    checkpoint_dir, experiment_type="algorithmic"
)
# Number of label classes, read from the training split's `y` feature.
n_classes = ds_train.features["y"].num_classes


def kernel_fn(x, params, trace_axes=(-1,), batch_size=64):
    """Compute the empirical NTK of ``apply_fn`` between ``x`` and itself.

    Args:
        x: Batch of model inputs. Axis 0 is the sample axis (``vmap_axes=0``).
            NOTE(review): ``nt.batch`` splits this axis into chunks of
            ``batch_size`` and, as far as I know, requires ``len(x)`` to be a
            multiple of ``batch_size`` -- confirm for the installed version.
        params: Model parameters at which the kernel is evaluated. neural_tangents
            calls ``apply_fn(params, x)``.
        trace_axes: Output axes to trace over. The default ``(-1,)`` traces out
            the last (logit) axis and yields a ``(len(x), len(x))`` kernel. Pass
            ``()`` to keep the per-output blocks, giving
            ``(len(x), len(x), C, C)`` with ``C`` the output width (presumably
            the number of classes).
        batch_size: Number of samples per ``nt.batch`` chunk.

    Returns:
        The empirical NTK array (shape as described under ``trace_axes``).
    """
    ntk_fn = nt.empirical_ntk_fn(
        apply_fn,
        vmap_axes=0,
        trace_axes=trace_axes,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
    )
    # x2=None: evaluate the kernel of x against itself.
    return nt.batch(ntk_fn, batch_size=batch_size)(x, None, params)

In [None]:
# Display the restored experiment configuration for reference.
config

In [None]:
# MNIST: NTK of the final checkpoint on a fixed random subset of test examples,
# sorted by label so same-class blocks sit along the diagonal of the heatmap.
# Run either this cell or the arithmetic one below -- both (re)define x, y and K.
# The seed makes the subset (and every downstream figure) reproducible.
samples = ds_test.shuffle(seed=0).select(range(128))  # 128 = 2 batches of kernel_fn's default 64
x, y = samples["x"], samples["y"]
order_ = np.argsort(y)
x, y = x[order_], y[order_]

K = kernel_fn(x, get_params(all_steps[-1]))
K.shape

In [None]:
# Arithmetic: NTK over train+test examples, sorted so related examples are adjacent.
# Run either this cell or the MNIST one above -- both (re)define x, y and K.
x = jnp.concatenate([ds_train["x"], ds_test["x"]], axis=0)
y = jnp.concatenate([ds_train["y"], ds_test["y"]], axis=0)
# np.lexsort treats the LAST key as primary: sort by y, then x[:, 0], then x[:, 2].
order_ = np.lexsort((x[:, 2], x[:, 0], y))
x = x[order_]
y = y[order_]

n_kernel_samples = 2048
# JAX clamps out-of-bounds gather indices instead of raising, so requesting more
# rows than exist would silently repeat the last example. Fail loudly instead.
assert n_kernel_samples <= len(x), (
    f"only {len(x)} examples available, requested {n_kernel_samples}"
)
indices = np.arange(n_kernel_samples)
# Alternatives:
# indices = np.sort(np.random.default_rng(0).choice(len(x), 256, replace=False))
# [indices] = np.where(np.isin(x[:, 0], np.arange(1)))
x, y = x[indices], y[indices]

kernel_step = 16000  # training step whose parameters define the kernel
K = kernel_fn(x, get_params(kernel_step))
K.shape

In [None]:
# Heatmap of the NTK over the (sorted) samples from the kernel cell above.
# kernel_fn as written traces the output axis, so K is already (n, n). An
# untraced kernel (n, n, d, d) is averaged over its two output axes first.
K_reduced = jnp.mean(K, axis=(-2, -1)) if K.ndim == 4 else K
print(f"{K_reduced.shape=}")

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
sns.heatmap(
    np.asarray(K_reduced),
    ax=ax,
    square=True,
    cmap="RdBu",
    cbar=False,
)
ax.set(title="Empirical NTK (samples in sorted order)");

In [None]:
# Confusion matrix of the final checkpoint on a fixed random test subset, with
# the diagonal (correct predictions) zeroed so that the errors stand out.
n = min(10000, len(ds_test))  # never ask for more rows than the test split has
samples = ds_test.shuffle(seed=0).select(range(n))
xs, ys = samples["x"], samples["y"]
# `state` is not defined anywhere in this notebook. Use the restored apply_fn at
# the final step's parameters (same apply_fn(params, x) order kernel_fn relies on).
ys_pred = jnp.argmax(apply_fn(get_params(all_steps[-1]), xs), axis=-1)

# Reindex onto every class so the matrix is square and its diagonal really is the
# "correct" cells, even if some class never appears as a label or prediction.
class_labels = np.arange(n_classes)
counts = pd.crosstab(np.asarray(ys), np.asarray(ys_pred)).reindex(
    index=class_labels, columns=class_labels, fill_value=0
)
# Work on an explicit copy rather than writing through `.values`, which may be
# a read-only view under pandas copy-on-write.
error_counts = counts.to_numpy(copy=True)
np.fill_diagonal(error_counts, 0)
confusion_matrix = pd.DataFrame(
    error_counts,
    index=pd.Index(class_labels, name="True"),
    columns=pd.Index(class_labels, name="Predicted"),
)
sns.heatmap(
    confusion_matrix,
    annot=n_classes <= 20,  # per-cell numbers are unreadable for many classes
    square=True,
    fmt="d",
    cmap="RdBu",
    center=0,
);

In [None]:
# Eigendecomposition of the NTK viewed as a square matrix over (sample, output) rows.
# kernel_fn as written traces the output axis, giving a 2-D (n, n) kernel; an
# untraced kernel is 4-D (n, n, d, d) and is flattened to (n*d, n*d).
if K.ndim == 4:
    K_flat = rearrange(K, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")
elif K.ndim == 2:
    K_flat = K
else:
    raise ValueError(f"expected a 2-D or 4-D kernel, got shape {K.shape}")
# eigh returns eigenvalues in ascending order; the columns of `es` are eigenvectors.
lambda_, es = jnp.linalg.eigh(K_flat)
print(jnp.linalg.matrix_rank(K_flat))

In [None]:
# Top eigenvectors of the flattened NTK, each reshaped to (sample, output).
# Rows of `es` are ordered (sample, output) by the flattening cell, so the number
# of outputs per sample is es.shape[0] // len(x): 1 for a traced kernel,
# presumably the number of classes for an untraced one.
n_samples = len(x)
outputs_per_sample, remainder = divmod(es.shape[0], n_samples)
assert remainder == 0, (
    f"eigenvector length {es.shape[0]} is not a multiple of len(x)={n_samples}"
)
n_eigs = min(9, es.shape[1])

fig, axs = plt.subplots(3, 3, figsize=(10, 8))
for eig_idx, ax in zip(range(1, n_eigs + 1), axs.flat):
    # eigh sorts ascending, so index -eig_idx is the eig_idx-th largest eigenpair.
    e = rearrange(es[:, -eig_idx], "(b d) -> b d", d=outputs_per_sample)
    sns.heatmap(e, ax=ax, cmap="RdBu", center=0)
    ax.set(
        yticks=[],
        ylabel="Sample (sorted)",
        xlabel="Class",
        title=f"Eigenvector {eig_idx} (λ={lambda_[-eig_idx]:.1e})",
    )
# Hide any panels left empty when fewer than 9 eigenvectors exist.
for ax in axs.flat[n_eigs:]:
    ax.axis("off")

fig.tight_layout()