In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
import warnings
from functools import partial
from pathlib import Path
from itertools import chain

import jax
import jax.numpy as jnp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import neural_tangents as nt
import orbax.checkpoint as ocp
import numpy as np
from sklearn.model_selection import train_test_split
from einops import rearrange

from idiots.dataset.dataloader import DataLoader
from idiots.experiments.grokking.training import restore as restore_grokking
from idiots.experiments.grokking.training import eval_step
from idiots.experiments.classification.training import (
    restore as restore_classification,
    restore_partial as restore_partial_classification,
)
from idiots.utils import metrics

In [None]:
# Silence every warning (all categories, all modules) for the rest of the session.
# NOTE(review): this also hides deprecation warnings from JAX / neural_tangents /
# orbax (possibly including the one recent JAX emits for `jax.tree_map`); consider
# narrowing it with `category=` / `module=` so real problems stay visible.
warnings.filterwarnings("ignore")

In [None]:
def preprocess_labels(y, num_classes: int):
    """Encode integer labels as zero-mean one-hot regression targets.

    Maps labels of shape (n,) to an (n, num_classes) one-hot matrix. Each row
    is then centred by subtracting its own mean across the class axis, so every
    row sums to zero. For an in-range label the true class gets
    1 - 1/num_classes and every other class gets -1/num_classes.

    Args:
        y: Integer class labels, shape (n,).
        num_classes: Number of classes k.

    Returns:
        Float array of shape (n, k) whose rows have zero mean.
    """
    encoded = jax.nn.one_hot(y, num_classes)
    row_means = encoded.mean(axis=-1, keepdims=True)
    return encoded - row_means

In [None]:
def _stratified_subset(ds, size: int, random_state: int = 0):
    """Draw a fixed-seed stratified subset of `size` examples from a dataset split.

    Returns `(x, y_raw, y)`: inputs, integer labels as a NumPy array, and the
    centred one-hot targets produced by `preprocess_labels`.
    """
    x, _, y_raw, _ = train_test_split(
        ds["x"],
        ds["y"],
        train_size=size,
        stratify=ds["y"],
        random_state=random_state,
    )
    # Coerce to an array: comparing a jax array with a Python list does not
    # broadcast elementwise, which would silently yield an accuracy of 0.
    y_raw = np.asarray(y_raw)
    y = preprocess_labels(y_raw, num_classes=ds.features["y"].num_classes)
    return x, y_raw, y


def losses_after_ntk_descent(
    apply_fn,
    init_params,
    params,
    ds_train,
    ds_test,
    train_size: int,
    test_size: int,
    batch_size: int = 64,
):
    """Evaluate kernel (NTK) gradient descent on an MSE loss.

    The empirical NTK of `apply_fn` is computed at `params` on stratified
    subsets of `ds_train` / `ds_test`. The predictions come from
    `nt.predict.gradient_descent_mse` called with `t=None`, which
    neural_tangents treats as the infinite-time limit, without diagonal
    regularisation. The starting outputs are `apply_fn(init_params, .)`.

    Args:
        apply_fn: Network apply function `apply_fn(params, x)`.
        init_params: Parameters defining the starting outputs f_0.
        params: Parameters at which the empirical NTK is evaluated.
        ds_train, ds_test: Dataset splits with "x" / "y" columns and a
            `features["y"].num_classes` attribute.
        train_size, test_size: Number of examples drawn from each split.
        batch_size: Batch size used by `nt.batch` when computing kernels.

    Returns:
        Dict of Python floats: "loss_train", "loss_test" (MSE against the
        centred one-hot targets) and "acc_train", "acc_test" (argmax accuracy).
    """
    x_train, y_train_raw, y_train = _stratified_subset(ds_train, train_size)
    x_test, y_test_raw, y_test = _stratified_subset(ds_test, test_size)

    # Build the empirical-NTK function once, outside the batched closure, so it
    # is not reconstructed for every batch.
    ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(), vmap_axes=0)

    @partial(nt.batch, batch_size=batch_size)
    def kernel_fn(x1, x2):
        return ntk_fn(x1, x2, params)

    k_train_train = kernel_fn(x_train, x_train)
    predict_fn = nt.predict.gradient_descent_mse(
        k_train_train, y_train, trace_axes=(), diag_reg=0
    )

    # Starting outputs f_0 at `init_params` (alternative: the zero function,
    # y_train_0 = y_test_0 = 0).
    y_train_0 = apply_fn(init_params, x_train)
    y_test_0 = apply_fn(init_params, x_test)
    k_test_train = kernel_fn(x_test, x_train)
    y_train_t, y_test_t = predict_fn(None, y_train_0, y_test_0, k_test_train)

    acc_train = jnp.mean(jnp.argmax(y_train_t, axis=-1) == y_train_raw)
    acc_test = jnp.mean(jnp.argmax(y_test_t, axis=-1) == y_test_raw)

    loss_train = jnp.mean(jnp.square(y_train_t - y_train))
    loss_test = jnp.mean(jnp.square(y_test_t - y_test))

    return {
        "loss_train": loss_train.item(),
        "loss_test": loss_test.item(),
        "acc_train": acc_train.item(),
        "acc_test": acc_test.item(),
    }

In [None]:
# checkpoint_dir = Path("/home/dc755/idiots/logs/grokking/exp22/checkpoints")
checkpoint_dir = Path("logs/checkpoints/mnist/mnist_gd_grokking/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/mnist/mnist_grokking_slower/checkpoints")

# Multiplier applied to the step-0 parameters that define the starting outputs
# f_0 of the NTK descent; 1.0 reproduces the plain initialisation.
INIT_PARAM_SCALE = 1.0

mngr, config, init_state, ds_train, ds_test = restore_classification(checkpoint_dir, 0)
# `jax.tree_map` is deprecated (removed in recent JAX); use jax.tree_util.
scaled_init_params = jax.tree_util.tree_map(
    lambda p: INIT_PARAM_SCALE * p, init_state.params
)

# NOTE: `data` is rebound by the gradient-flow cell further down. The final
# comparison plot reads `data`, so it only sees these per-checkpoint results
# if this cell is (re-)run after that one.
data = []
for step in chain(range(0, 30000, 2000), range(30000, 100_000 + 1, 10000)):
    # Step 0 is the freshly restored initial state; later steps come from disk.
    if step == 0:
        state = init_state
    else:
        state = mngr.restore(step, args=ocp.args.StandardRestore(init_state))
    # NTK evaluated at this checkpoint, starting from the (scaled) initial outputs.
    out = losses_after_ntk_descent(
        init_state.apply_fn,
        scaled_init_params,
        state.params,
        ds_train,
        ds_test,
        train_size=64,
        test_size=128,
        batch_size=256,
    )
    print(step, out)
    data.append(
        {
            "step": step,
            "split": "train",
            "loss": out["loss_train"],
            "acc": out["acc_train"],
        }
    )
    data.append(
        {
            "step": step,
            "split": "test",
            "loss": out["loss_test"],
            "acc": out["acc_test"],
        }
    )

In [None]:
mngr, config, state, ds_train, ds_test = restore_classification(checkpoint_dir, 20000)

# Fixed-seed stratified subsets of 128 examples from each split. The centred
# one-hot targets are built afterwards.
x_train, _, y_train_raw, _ = train_test_split(
    ds_train["x"], ds_train["y"], train_size=128, stratify=ds_train["y"], random_state=0
)
x_test, _, y_test_raw, _ = train_test_split(
    ds_test["x"], ds_test["y"], train_size=128, stratify=ds_test["y"], random_state=0
)
y_train = preprocess_labels(y_train_raw, num_classes=ds_train.features["y"].num_classes)
y_test = preprocess_labels(y_test_raw, num_classes=ds_test.features["y"].num_classes)

# Batched empirical NTK (no traced output axes); called as kernel_fn(x1, x2, params).
empirical_ntk = nt.empirical_ntk_fn(state.apply_fn, trace_axes=(), vmap_axes=0)
kernel_fn = nt.batch(empirical_ntk, batch_size=128)

In [None]:
# Inspect the centred one-hot training targets built above (128 rows, one column per class).
y_train

In [None]:
def f_t(t: float, lr: float, x, zero_init: bool = True):
    """Linearised (NTK) network outputs on `x` after gradient flow for time `t`.

    Uses the global batched `kernel_fn` at `state.params` and the global
    training set `x_train` / `y_train` to evaluate the closed-form solution of
    MSE gradient flow in function space:

        f_t(x) = f_0(x) - lr * K(x, X) @ A_t @ (f_0(X) - Y),
        A_t    = integral_0^t exp(-lr * K(X, X) * s) ds.

    A_t is read off the upper-right block of expm(t * [[-lr K, I], [0, 0]]),
    which avoids inverting the (possibly singular) train kernel.

    Args:
        t: Gradient-flow time.
        lr: Learning rate multiplying the kernel.
        x: Inputs to predict on.
        zero_init: If True (default), f_0 is the identically-zero function.
            If False, f_0 is the network output at `state.params`.

    Returns:
        Array of shape (len(x), num_outputs), matching `state.apply_fn(state.params, x)`.
    """
    k_train_train = rearrange(
        kernel_fn(x_train, x_train, state.params), "b1 b2 d1 d2 -> (b1 d1) (b2 d2)"
    )
    k_pred_train = rearrange(
        kernel_fn(x, x_train, state.params), "b1 b2 d1 d2 -> (b1 d1) (b2 d2)"
    )
    n = k_train_train.shape[0]
    generator = jnp.block(
        [
            [-lr * k_train_train, jnp.eye(n)],
            [jnp.zeros_like(k_train_train), jnp.zeros_like(k_train_train)],
        ]
    )
    integral = jax.scipy.linalg.expm(t * generator, max_squarings=32)[:n, n:]

    if zero_init:
        # Only the output shape/dtype is needed here; jax.eval_shape avoids
        # running two full forward passes whose values would be discarded.
        pred_struct = jax.eval_shape(state.apply_fn, state.params, x)
        train_struct = jax.eval_shape(state.apply_fn, state.params, x_train)
        f_0_pred = jnp.zeros(pred_struct.shape, pred_struct.dtype)
        f_0_train = jnp.zeros(train_struct.shape, train_struct.dtype)
    else:
        f_0_pred = state.apply_fn(state.params, x)
        f_0_train = state.apply_fn(state.params, x_train)

    b, d = f_0_pred.shape
    residual = rearrange(f_0_train - y_train, "b d -> (b d) 1")
    update = rearrange(
        k_pred_train @ integral @ residual, "(b d) 1 -> b d", b=b, d=d
    )
    return f_0_pred - lr * update


# NOTE: this rebinds `data`, replacing the per-checkpoint NTK results gathered
# by the earlier checkpoint loop. Here `step` is the continuous gradient-flow time t.
data = []
for step in np.linspace(0, 5, 21):
    y_pred_train = f_t(step, 1e-2, x_train)
    y_pred_test = f_t(step, 1e-2, x_test)
    assert jnp.all(jnp.isfinite(y_pred_train)) and jnp.all(jnp.isfinite(y_pred_test))
    # One long-format record per split: MSE against the centred targets and
    # argmax accuracy against the integer labels.
    data.extend(
        {
            "step": step,
            "split": split,
            "loss": jnp.mean(jnp.square(preds - targets)).item(),
            "acc": jnp.mean(jnp.argmax(preds, axis=-1) == labels).item(),
        }
        for split, preds, targets, labels in (
            ("train", y_pred_train, y_train, y_train_raw),
            ("test", y_pred_test, y_test, y_test_raw),
        )
    )

In [None]:
# Plot the most recent contents of `data` (loss and accuracy per split).
df = pd.DataFrame(data)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
sns.lineplot(data=df, x="step", y="loss", hue="split", ax=ax[0], marker="o")
sns.lineplot(data=df, x="step", y="acc", hue="split", ax=ax[1], marker="o")
# Every producer of `data` records the loss as a mean squared error.
ax[0].set(ylabel="MSE")
ax[1].set(ylabel="Accuracy")
fig.tight_layout()

In [None]:
# Same data as the previous plotting cell, with the panels swapped (accuracy first).
df = pd.DataFrame(data)
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
sns.lineplot(data=df, x="step", y="acc", hue="split", marker="o", ax=axs[0])
sns.lineplot(data=df, x="step", y="loss", hue="split", marker="o", ax=axs[1])
axs[0].set(ylabel="Accuracy")
axs[1].set(ylabel="MSE")
fig.tight_layout()

In [None]:
batch_size = 256

mngr, config, init_state, ds_train, ds_test = restore_classification(checkpoint_dir, 0)


def eval_checkpoint(step):
    """Evaluate the regular (non-NTK) model at checkpoint `step` on both full splits.

    Step 0 uses the freshly restored `init_state`. Any other step is restored
    from the global `mngr`, with `init_state` as the target structure. Each
    batch is passed through `eval_step`, its outputs are logged to `metrics`,
    and the collected "eval_loss" / "eval_accuracy" entries are concatenated
    and averaged.

    NOTE(review): assumes `metrics.collect` empties the log between calls.
    Otherwise the test averages would include train batches — verify.

    Returns:
        Dict with "step", "train_loss", "train_acc", "test_loss", "test_acc".
    """
    state = (
        init_state
        if step == 0
        else mngr.restore(step, args=ocp.args.StandardRestore(init_state))
    )

    def _mean_loss_and_accuracy(dataset):
        for batch in DataLoader(dataset, batch_size):
            metrics.log(**eval_step(state, batch, config.loss_variant))
        losses, accuracies = metrics.collect("eval_loss", "eval_accuracy")
        return (
            jnp.concatenate(losses).mean().item(),
            jnp.concatenate(accuracies).mean().item(),
        )

    train_loss, train_acc = _mean_loss_and_accuracy(ds_train)
    test_loss, test_acc = _mean_loss_and_accuracy(ds_test)

    return dict(
        step=step,
        train_loss=train_loss,
        train_acc=train_acc,
        test_loss=test_loss,
        test_acc=test_acc,
    )

In [None]:
# Long-format loss/accuracy records of the regular model, one per split per checkpoint.
raw_losses_data = []
for step in chain(range(0, 30000, 2000), range(30000, 100000 + 1, 10000)):
    out = eval_checkpoint(step)
    print(step, out)
    raw_losses_data.extend(
        {
            "step": step,
            "split": split,
            "loss": out[loss_key],
            "acc": out[acc_key],
        }
        for split, loss_key, acc_key in (
            ("train", "train_loss", "train_acc"),
            ("test", "test_loss", "test_acc"),
        )
    )

In [None]:
# Compare NTK-descent results with the regular model: accuracy and MSE per split.
fig, axs = plt.subplots(1, 2, figsize=(11, 4))
axs = axs.flatten()
df_raw = pd.DataFrame(raw_losses_data).assign(Model="Regular")
# NOTE(review): `data` is rebound by the gradient-flow cell (x-values are
# continuous time t in [0, 5]) after the per-checkpoint NTK loop. On a
# top-to-bottom run, this therefore places those t-values on the same x-axis as
# checkpoint steps. Re-run the per-checkpoint NTK cell, or give its results a
# distinct name, before this cell.
df_ntk = pd.DataFrame(data).assign(Model="NTK")

# ignore_index avoids duplicate row labels after concatenation.
df = (
    pd.concat([df_ntk, df_raw], ignore_index=True)
    .rename(columns={"split": "Split"})
    .assign(
        Split=lambda frame: frame["Split"].map({"train": "Train", "test": "Test"})
    )
)

sns.lineplot(
    data=df, x="step", y="acc", hue="Split", style="Model", marker="o", ax=axs[0]
)
sns.lineplot(
    data=df, x="step", y="loss", hue="Split", style="Model", marker="o", ax=axs[1]
)

axs[0].set(ylabel="Accuracy", xlabel="Step")
axs[1].set(ylabel="MSE", yscale="log", xlabel="Step")
fig.tight_layout()