In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import neural_tangents as nt
import orbax.checkpoint as ocp
import pandas as pd
import scienceplots  # noqa: F401  (imported for its side effect: registers the "science"/"grid" styles)
import seaborn as sns
from einops import rearrange
from sklearn.model_selection import train_test_split

from idiots.experiments.compute_results.compute_results import restore_checkpoint

In [None]:
# Checkpoint directory of the run to analyse. The commented-out paths are
# alternative runs: activate one of them (together with the matching
# `experiment_type` below) to analyse a different experiment.
# checkpoint_dir = Path("logs/checkpoints/gradient_flow/exp34/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/mnist/exp66/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/mnist/exp75/checkpoints")
# checkpoint_dir = Path("logs/checkpoints/mnist/mnist_gd_grokking/checkpoints")
checkpoint_dir = Path("logs/checkpoints/gradient_flow/exp39/checkpoints")
# restore_checkpoint returns six values, unpacked here. The rest of the notebook
# uses apply_fn (the function whose NTK is computed), get_params(step),
# ds_train / ds_test (indexed by "x" and "y") and all_steps.
config, apply_fn, get_params, ds_train, ds_test, all_steps = restore_checkpoint(
    checkpoint_dir,
    # NOTE(review): the trailing underscore in "gradient_flow_" looks unusual —
    # confirm it is the value restore_checkpoint expects.
    experiment_type="gradient_flow_",
    # experiment_type="mnist",
    # experiment_type="grokking",
)


# Build the batched empirical-NTK function once, next to the `apply_fn` it wraps.
# Re-creating it inside kernel_fn would pass a brand-new function object through
# nt.batch on every call, so nothing compiled for one checkpoint could be reused
# for the next.
_batched_ntk_fn = nt.batch(
    nt.empirical_ntk_fn(
        apply_fn,
        vmap_axes=0,
        trace_axes=(),
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
    ),
    batch_size=64,
)


def kernel_fn(x, params):
    """Return the empirical NTK of ``apply_fn`` on ``x`` as a 2-D matrix.

    Args:
        x: Batch of inputs. The kernel is evaluated in batches of 64, so the
            number of inputs is presumably required to be a multiple of 64
            (see the ``nt.batch`` documentation).
        params: Parameters of ``apply_fn`` at which the kernel is evaluated.

    Returns:
        The kernel of ``x`` with itself (``None`` is passed as the second
        input), with the two batch axes and the two output axes merged:
        the raw ``(b1, b2, d1, d2)`` array is flattened to
        ``(b1 * d1, b2 * d2)``.
    """
    k = _batched_ntk_fn(x, None, params)
    # trace_axes=() keeps the output axes, hence the 4-D raw kernel.
    return rearrange(k, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")


@jax.jit
def kernel_distance(k1, k2):
    """Return one minus the cosine similarity of two kernel matrices.

    The similarity is the Frobenius inner product of ``k1`` and ``k2`` divided
    by the product of their Frobenius norms, so the result is 0 for kernels
    that are positive multiples of each other and 1 for orthogonal kernels.

    Args:
        k1: First kernel matrix.
        k2: Second kernel matrix, same shape as ``k1``.

    Returns:
        Scalar array ``1 - <k1, k2>_F / (||k1||_F * ||k2||_F)``.
    """
    # trace(a @ b.T) == sum(a * b): use the element-wise form so no full
    # matrix product is formed just to read off its diagonal.
    inner = jnp.sum(k1 * k2)
    norm1 = jnp.sqrt(jnp.sum(k1 * k1))
    norm2 = jnp.sqrt(jnp.sum(k2 * k2))
    return 1 - inner / (norm1 * norm2)

In [None]:
# Pool the train and test examples and draw a class-stratified subsample of 128
# points; the kernel is evaluated on this same subsample at every checkpoint.
xs = jnp.concatenate([ds_train["x"], ds_test["x"]], axis=0)
ys = jnp.concatenate([ds_train["y"], ds_test["y"]], axis=0)
# Fixed random_state: without it a different subsample is drawn on every run,
# so the resulting curve would not be reproducible.
x, _, y, _ = train_test_split(
    xs, ys, train_size=128, stratify=ys, random_state=0
)

# One record per checkpoint: distance between its kernel and the previous one.
data = []

# Baseline is the kernel at step 0; every later checkpoint is compared with the
# checkpoint immediately before it.
prev_kernel = kernel_fn(x, get_params(0))
for step in all_steps[1:]:
    kernel = kernel_fn(x, get_params(step))
    data.append({"step": step, "distance": kernel_distance(prev_kernel, kernel).item()})
    prev_kernel = kernel

In [None]:
# "distance" is the kernel distance between consecutive checkpoints; dividing
# by the number of steps separating them gives a per-step rate of change.
df = pd.DataFrame(data)
# The first row was compared against the step-0 kernel, hence fill_value=0
# for its "previous step".
df["velocity"] = df["distance"] / (df["step"] - df["step"].shift(1, fill_value=0))

# The "science" and "grid" styles are registered by `import scienceplots`
# in the notebook's import cell.
with plt.style.context(["science", "grid"]):
    fig, ax = plt.subplots(figsize=(3, 2.6))
    sns.lineplot(data=df, x="step", y="velocity", ax=ax)
    ax.set(ylim=(0, 1e-5), ylabel="Kernel velocity", xlabel="$t$")

    # fig.savefig("logs/plots/gradient_flow_velocity.pdf", bbox_inches="tight")