In [None]:
# Load IPython's autoreload extension and enable mode 2, so that edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import random

import jax.numpy as jnp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import neural_tangents as nt

from idiots.dataset.dataloader import DataLoader
from idiots.experiments.grokking.training import eval_step, dots
from idiots.experiments.classification.training import restore
from idiots.utils import metrics

In [None]:
# Checkpoint directory of the run under analysis. Point this at another run's
# `checkpoints` folder (e.g. a grokking experiment) to analyse that run instead.
checkpoint_dir = Path("logs/mnist/mnist_adamw/checkpoints")
# Batch size used when iterating a whole dataset split for loss/accuracy.
batch_size = 512

# Seed Python's global RNG so the random subsets drawn via `random.sample`
# are identical on every Restart & Run All.
seed = 0
random.seed(seed)

# Only the second element (the train state) of the step-0 checkpoint is needed
# here: its `apply_fn` is what the empirical NTK is built from.
_, state, _, _ = restore(checkpoint_dir, 0)

# Empirical NTK function for the model's forward pass.
# NOTE(review): per the neural-tangents docs, `trace_axes=()` keeps the output
# axes in the kernel instead of tracing them out, and `vmap_axes=0` marks
# axis 0 as the batch axis — confirm this matches what `dots` expects.
kernel_fn = nt.empirical_ntk_fn(
    state.apply_fn,
    trace_axes=(),
    vmap_axes=0,
    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
)


def compute_dots(
    kernel_fn, params, ds, sample_size: int, batch_size: int, seed=None
) -> int:
    """Evaluate `dots` on a random subset of a dataset.

    Draws `sample_size` distinct row indices from `ds` (sampling without
    replacement), selects those rows, and passes their "x" column to `dots`
    together with `kernel_fn`, `params` and `batch_size`.

    Args:
        kernel_fn: kernel function, forwarded unchanged to `dots`.
        params: model parameters, forwarded unchanged to `dots`.
        ds: dataset supporting `len(ds)`, `ds.select(indices)` and `["x"]`.
        sample_size: number of examples to draw from `ds`.
        batch_size: batch size forwarded to `dots`.
        seed: optional seed. When given, indices are drawn from a private
            `random.Random(seed)`, so calls with the same seed use the same
            subset and the global RNG state is left untouched. When None
            (the default) the module-level `random` generator is used, which
            is the original behaviour.

    Returns:
        The scalar produced by `dots`, unwrapped with `.item()`.

    Raises:
        ValueError: if `sample_size` is negative or larger than `len(ds)`
            (raised by `sample`).
    """
    # Fall back to the shared global generator unless a seed was requested.
    rng = random if seed is None else random.Random(seed)
    random_indices = rng.sample(range(len(ds)), sample_size)
    return dots(kernel_fn, params, ds.select(random_indices)["x"], batch_size).item()


def eval_checkpoint(step, dots_sample_size: int = 128, dots_batch_size: int = 64):
    """Restore the checkpoint at `step` and compute its evaluation metrics.

    Relies on the notebook-level `checkpoint_dir`, `batch_size` and
    `kernel_fn` defined in the configuration cell.

    Args:
        step: training step of the checkpoint to restore.
        dots_sample_size: number of randomly drawn examples per split used
            for the `dots` computation (default 128).
        dots_batch_size: batch size passed to `compute_dots` (default 64).

    Returns:
        A dict with keys "step", "train_loss", "train_acc", "test_loss",
        "test_acc", "train_dots" and "test_dots".
    """
    config, state, ds_train, ds_test = restore(checkpoint_dir, step)

    def eval_loss_acc(ds):
        """Return (mean loss, mean accuracy) of `state` over all of `ds`."""
        for batch in DataLoader(ds, batch_size):
            logs = eval_step(state, batch, config.loss_variant)
            metrics.log(**logs)
        # NOTE(review): presumably `collect` returns the values logged above
        # and clears them, so consecutive calls do not mix splits — confirm.
        [losses, accuracies] = metrics.collect("eval_loss", "eval_accuracy")
        # Concatenate the per-batch arrays first so the mean is taken over
        # all entries rather than being an average of per-batch averages.
        loss = jnp.concatenate(losses).mean().item()
        acc = jnp.concatenate(accuracies).mean().item()
        return loss, acc

    train_loss, train_acc = eval_loss_acc(ds_train)
    test_loss, test_acc = eval_loss_acc(ds_test)
    train_dots = compute_dots(
        kernel_fn, state.params, ds_train, dots_sample_size, dots_batch_size
    )
    test_dots = compute_dots(
        kernel_fn, state.params, ds_test, dots_sample_size, dots_batch_size
    )

    return {
        "step": step,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "test_loss": test_loss,
        "test_acc": test_acc,
        "train_dots": train_dots,
        "test_dots": test_dots,
    }

In [None]:
# Sanity check: evaluate the step-0 checkpoint on its own and display the
# resulting metrics dict before running the full sweep.
eval_checkpoint(0)

In [None]:
# Evaluate one checkpoint every 1000 steps, from step 0 through step 20000
# inclusive, collecting one metrics record per checkpoint. The explicit loop
# keeps the records gathered so far in `data` if a later checkpoint fails.
checkpoint_steps = range(0, 20000 + 1, 1000)
data = []
for step in checkpoint_steps:
    data += [eval_checkpoint(step)]

In [None]:
# One row per evaluated checkpoint, one column per metric.
df = pd.DataFrame(data)


def melt_metric(frame: pd.DataFrame, metric: str, value_name: str) -> pd.DataFrame:
    """Reshape the train/test columns of one metric into long form.

    Args:
        frame: wide frame with a "step" column plus "train_<metric>" and
            "test_<metric>" columns.
        metric: column suffix identifying the metric (e.g. "loss", "acc").
        value_name: name of the value column in the returned frame.

    Returns:
        A new frame with columns "step", "split" ("train" or "test") and
        `value_name`, suitable for seaborn's `hue="split"`.
    """
    suffix = f"_{metric}"
    long_frame = frame[["step", f"train{suffix}", f"test{suffix}"]].melt(
        "step", var_name="split", value_name=value_name
    )
    # Literal replacement ("train_loss" -> "train"). regex=False pins the
    # behaviour across pandas versions, whose default for `regex` has changed.
    long_frame["split"] = long_frame["split"].str.replace(suffix, "", regex=False)
    return long_frame


df_loss = melt_metric(df, "loss", "loss")
df_acc = melt_metric(df, "acc", "accuracy")
df_dots = melt_metric(df, "dots", "dots")

# Three side-by-side panels: loss, accuracy and dots against training step,
# each with one line per split.
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
loss_ax, acc_ax, dots_ax = axs

panels = [
    (df_loss, "loss", loss_ax),
    (df_acc, "accuracy", acc_ax),
    (df_dots, "dots", dots_ax),
]
for frame, metric, ax in panels:
    sns.lineplot(data=frame, x="step", y=metric, hue="split", marker="o", ax=ax)

# Fixed y-ranges for the loss and accuracy panels; the dots panel autoscales.
loss_ax.set(ylim=(0, 0.6))
acc_ax.set(ylim=(0.9, 1))
fig.tight_layout()