In [None]:
# Enable IPython's autoreload so edits to imported project modules (e.g. idiots.*)
# are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# `!export PYTHONPATH=...` ran in a throwaway subshell and never reached this kernel.
# PYTHONPATH is also only read at interpreter start-up. Put the repo root on sys.path
# directly instead. Appending (rather than prepending) keeps any already-resolvable
# `idiots` package taking precedence, and the guard keeps re-runs idempotent.
import sys

IDIOTS_ROOT = "/home/dc755/idiots"  # NOTE(review): machine-specific absolute path
if IDIOTS_ROOT not in sys.path:
    sys.path.append(IDIOTS_ROOT)

In [None]:
from pathlib import Path
import json

import jax
import jax.numpy as jnp
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import neural_tangents as nt
from einops import rearrange
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from idiots.dataset.dataloader import DataLoader
from idiots.experiments.grokking.training import restore as restore_grokking
from idiots.experiments.grokking.training import eval_step
from idiots.experiments.classification.training import (
    restore as restore_classification,
    restore_partial as restore_partial_classification,
)
from idiots.utils import metrics

In [None]:
checkpoint_dir = Path("logs/checkpoints/mnist/exp26/checkpoints")
N_train = 32  # number of training points the SVM is fit on
N_test = 256  # number of held-out points the SVM is scored on
SEED = 0  # fixes the train/test subsampling so Restart & Run All reproduces the curve

# NOTE(review): the second argument is presumably the checkpoint step (0 = initialisation) — confirm.
mngr, config, state, ds_train, ds_test = restore_classification(checkpoint_dir, 0)

# Empirical NTK of the restored network's apply_fn.
# vmap_axes=0 tells neural_tangents that axis 0 of the inputs is the batch axis.
kernel_fn = nt.empirical_kernel_fn(
    state.apply_fn,
    vmap_axes=0,
    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
)

# Subsample small train/test sets; the remainders are discarded.
X_train, _, Y_train, _ = train_test_split(
    ds_train["x"], ds_train["y"], train_size=N_train, random_state=SEED
)
X_test, _, Y_test, _ = train_test_split(
    ds_test["x"], ds_test["y"], train_size=N_test, random_state=SEED
)


def eval_init(init_scale: float, batch_size: int = 64) -> dict:
    """Fit and score an SVM whose kernel is the empirical NTK at rescaled parameters.

    Every leaf of the restored ``state.params`` is multiplied by ``init_scale``. The
    empirical NTK of the network at those parameters is used as a callable kernel for
    an ``SVC`` fit on ``(X_train, Y_train)`` and evaluated on ``(X_test, Y_test)``.

    Args:
        init_scale: Multiplicative factor applied to every parameter leaf.
        batch_size: Batch size passed to ``nt.batch`` when computing kernel blocks.

    Returns:
        A dict with ``"accuracy"`` (test accuracy of the SVM) and ``"init_scale"``
        (the factor that was passed in).
    """
    # jax.tree_map is deprecated (removed in recent JAX); tree_util.tree_map is the stable API.
    params = jax.tree_util.tree_map(lambda leaf: leaf * init_scale, state.params)

    # Build the batched kernel once instead of on every SVC kernel evaluation.
    kernel_fn_batched = nt.batch(kernel_fn, batch_size=batch_size)

    def custom_kernel(X1, X2):
        # SVC calls a callable kernel as kernel(X, X_fit) and expects an
        # (n_samples_X, n_samples_X_fit) Gram matrix back.
        return kernel_fn_batched(X1, X2, "ntk", params)

    svc = SVC(kernel=custom_kernel)
    svc.fit(X_train, Y_train)

    predictions = svc.predict(X_test)
    accuracy = accuracy_score(Y_test, predictions)

    return {"accuracy": accuracy, "init_scale": init_scale}

In [None]:
# Relative parameter scales to sweep: 100 log-spaced factors from 1e-4 to 0.2.
init_scale_grid = np.geomspace(1e-4, 0.2, 100)

# One result record (accuracy + scale) per factor.
data = [eval_init(scale) for scale in init_scale_grid]

In [None]:
df = pd.DataFrame(data)
# eval_init sweeps a relative factor on the restored params. Multiplying by the config's
# init scale expresses it on the model's absolute scale.
# NOTE(review): assumes the restored params were initialised with config.model.init_scale — confirm.
df["init_scale"] *= config.model.init_scale

fig, ax = plt.subplots()
ax = sns.lineplot(data=df, x="init_scale", y="accuracy", marker="o", ax=ax)
# The sweep values are geometrically spaced (np.geomspace). A log x-axis spreads them
# evenly instead of crowding almost every point next to zero.
ax.set_xscale("log")
ax.set(
    xlabel="init scale",
    ylabel="test accuracy",
    title="NTK-kernel SVM test accuracy vs. init scale",
)
plt.show()