In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import neural_tangents as nt

from idiots.dataset.dataloader import DataLoader
from idiots.experiments.grokking.training import restore, eval_step
from idiots.utils import metrics

In [None]:
def preprocess_labels(y, num_classes: int):
    """One-hot encode class-index labels and center each class column.

    Maps labels of shape (n,) to an (n, num_classes) array, then subtracts
    the per-class mean taken over the n examples, so every column has mean zero.
    """
    encoded = jax.nn.one_hot(y, num_classes)
    column_means = encoded.mean(axis=0)
    return encoded - column_means

In [None]:
def losses_after_ntk_descent(
    init_state,
    ds_train,
    ds_test,
    train_size: int,
    test_size: int,
    batch_size: int = 64,
    diag_reg: float = 1e-3,
):
    """Evaluate infinite-time gradient descent on the empirical NTK at `init_state`.

    The empirical NTK of `init_state.apply_fn` at `init_state.params` is used to
    solve MSE gradient descent to convergence (``t=None``) on the first
    `train_size` training examples. Predictions are then made for the first
    `test_size` test examples.

    Args:
        init_state: object exposing `apply_fn(params, x)` and `params`; the
            point at which the network's kernel is computed.
        ds_train, ds_test: datasets indexable by "x" and "y" whose
            `features["y"].num_classes` gives the number of classes.
        train_size, test_size: number of leading examples taken from each split.
            NOTE(review): `nt.batch` presumably needs these to be multiples of
            `batch_size`; confirm against the neural_tangents docs.
        batch_size: block size used by `nt.batch` when evaluating the kernel.
        diag_reg: diagonal regularization passed to `gradient_descent_mse`.

    Returns:
        (loss_train, loss_test, acc_train, acc_test) as 0-d JAX arrays. Losses are
        the mean squared error against one-hot targets centered by the training
        label mean.
    """
    # Use one class count for both splits so train/test targets share a space.
    num_classes = ds_train.features["y"].num_classes

    # Load the data
    x_train = ds_train["x"][:train_size]
    y_train_raw = ds_train["y"][:train_size]
    x_test = ds_test["x"][:test_size]
    y_test_raw = ds_test["y"][:test_size]

    # Targets: one-hot labels centered by the *training* mean. The test targets
    # must use the same offset, because the test predictions are fitted to train
    # targets centered this way. Centering by the test mean instead would add a
    # constant bias to the test MSE.
    y_train = preprocess_labels(y_train_raw, num_classes=num_classes)
    train_label_mean = jnp.mean(jax.nn.one_hot(y_train_raw, num_classes), axis=0)
    y_test = jax.nn.one_hot(y_test_raw, num_classes) - train_label_mean

    # Build the empirical NTK function once (rather than once per batch);
    # nt.batch evaluates it block-wise and keeps the result on device.
    ntk_fn = nt.empirical_ntk_fn(init_state.apply_fn, trace_axes=(), vmap_axes=0)

    @partial(nt.batch, batch_size=batch_size, store_on_device=True)
    def kernel_fn(x1, x2, params):
        return ntk_fn(x1, x2, params)

    # Perform kernel descent
    k_train_train = kernel_fn(x_train, x_train, init_state.params)
    predict_fn = nt.predict.gradient_descent_mse(
        k_train_train, y_train, trace_axes=(), diag_reg=diag_reg
    )

    # Start from the network's outputs at `init_state`; t=None asks for the
    # infinite-time solution.
    y_train_0 = init_state.apply_fn(init_state.params, x_train)
    y_test_0 = init_state.apply_fn(init_state.params, x_test)
    k_test_train = kernel_fn(x_test, x_train, init_state.params)
    y_train_t, y_test_t = predict_fn(None, y_train_0, y_test_0, k_test_train)

    # Accuracy compares predicted class indices against the raw labels
    acc_train = jnp.mean(jnp.argmax(y_train_t, axis=-1) == y_train_raw)
    acc_test = jnp.mean(jnp.argmax(y_test_t, axis=-1) == y_test_raw)

    # Mean squared error over all examples and classes
    loss_train = jnp.mean(jnp.square(y_train_t - y_train))
    loss_test = jnp.mean(jnp.square(y_test_t - y_test))

    return loss_train, loss_test, acc_train, acc_test

In [None]:
# Resolved relative to the current working directory.
checkpoint_dir = Path("logs/grokking/exp22/checkpoints")


def checkpoint_ntk_descent_losses(
    step: int, train_size: int = 256, test_size: int = 128
):
    """Restore checkpoint `step` and run infinite-time NTK descent from it.

    Args:
        step: checkpoint step restored from `checkpoint_dir`.
        train_size, test_size: number of leading train/test examples passed on
            to `losses_after_ntk_descent`.

    Returns:
        (loss_train, loss_test, acc_train, acc_test) as Python floats.
    """
    _config, state, ds_train, ds_test = restore(checkpoint_dir, step)

    results = losses_after_ntk_descent(
        state, ds_train, ds_test, train_size=train_size, test_size=test_size
    )
    # Convert 0-d device arrays to plain floats for DataFrame construction.
    return tuple(value.item() for value in results)

In [None]:
# Run NTK descent from every 1000th checkpoint between steps 5000 and 19000,
# collecting one long-format record per (step, split).
data = []
for step in range(5000, 20000, 1000):
    loss_train, loss_test, acc_train, acc_test = checkpoint_ntk_descent_losses(step)
    print(step, loss_test, acc_test)  # print to make sure they're not NaN
    data += [
        {"step": step, "split": "train", "loss": loss_train, "acc": acc_train},
        {"step": step, "split": "test", "loss": loss_test, "acc": acc_test},
    ]

In [None]:
# Accuracy of the infinite-time NTK predictor at each checkpoint.
df = pd.DataFrame(data)
ax = sns.lineplot(data=df, x="step", y="acc", hue="split", marker="o")
ax.set(
    title="Accuracy after NTK descent (infinite time)",
    xlabel="Step",
    ylabel="Accuracy",
);

In [None]:
batch_size = 256


def eval_checkpoint(step):
    """Evaluate the restored checkpoint `step` on its full train and test splits.

    Each split is streamed through `eval_step` in batches of `batch_size`. The
    per-batch logs go into the `metrics` store and are read back with
    `metrics.collect`.

    Returns:
        (train_loss, train_acc, test_loss, test_acc) as Python floats.
    """
    config, state, ds_train, ds_test = restore(checkpoint_dir, step)

    def mean_of(chunks):
        # Join the per-batch arrays and reduce to a single Python float.
        return jnp.concatenate(chunks).mean().item()

    def eval_loss_acc(ds):
        for batch in DataLoader(ds, batch_size):
            metrics.log(**eval_step(state, batch, config.loss_variant))
        losses, accuracies = metrics.collect("eval_loss", "eval_accuracy")
        return mean_of(losses), mean_of(accuracies)

    train_loss, train_acc = eval_loss_acc(ds_train)
    test_loss, test_acc = eval_loss_acc(ds_test)
    return train_loss, train_acc, test_loss, test_acc

In [None]:
# Evaluate the trained network itself at the same checkpoints as the NTK sweep,
# storing one long-format record per (step, split).
raw_losses_data = []
for step in range(5000, 20000, 1000):
    train_loss, train_acc, test_loss, test_acc = eval_checkpoint(step)
    raw_losses_data.extend(
        {"step": step, "split": split, "loss": loss, "acc": acc}
        for split, loss, acc in (
            ("train", train_loss, train_acc),
            ("test", test_loss, test_acc),
        )
    )

In [None]:
# Rows: accuracy, loss. Columns: trained network, infinite-time NTK descent.
fig, axs = plt.subplots(2, 2, figsize=(10, 6), sharey="row")
axs = axs.flatten()
df_raw = pd.DataFrame(raw_losses_data)
df_ntk = pd.DataFrame(data)

panels = [(df_raw, "acc"), (df_ntk, "acc"), (df_raw, "loss"), (df_ntk, "loss")]
for ax, (frame, metric) in zip(axs, panels):
    sns.lineplot(data=frame, x="step", y=metric, hue="split", marker="o", ax=ax)
    ax.set(xlabel="Step")

axs[0].set(title="Training curve", ylabel="Accuracy")
axs[1].set(title="NTK descent (infinite time) curve", ylabel="Accuracy")
# NOTE(review): the raw loss comes from eval_step with config.loss_variant,
# which may not be MSE. Confirm it before comparing it to the NTK MSE on a
# shared y-axis.
axs[2].set(ylabel="Loss")
axs[3].set(ylabel="MSE")
fig.tight_layout()