In [None]:
# Automatically reload edited Python modules (e.g. the local `idiots` package)
# before each cell executes, so library edits apply without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from idiots.dataset.dataloader import DataLoader
from idiots.experiments.grokking.training import restore, eval_step
from idiots.utils import metrics
import neural_tangents as nt
from einops import rearrange
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import json

In [None]:
# Run directory whose checkpoints are evaluated, and the DataLoader batch size
# used when scoring them.
checkpoint_dir = Path("../../logs/division/exp21/checkpoints")
batch_size = 512


def eval_checkpoint(step):
    """Restore checkpoint ``step`` and evaluate it on the train and test splits.

    Returns a 7-tuple ``(state, ds_train, ds_test, train_loss, train_acc,
    test_loss, test_acc)``. Each loss/accuracy is a Python float equal to the
    mean of the concatenated per-batch ``eval_loss`` / ``eval_accuracy``
    values logged while streaming that split through ``eval_step``.
    """
    config, state, ds_train, ds_test = restore(checkpoint_dir, step)

    def eval_loss_acc(ds):
        # Push every batch through eval_step and record its outputs in the
        # shared metrics store...
        for batch in DataLoader(ds, batch_size):
            metrics.log(**eval_step(state, batch, config.loss_variant))
        # ...then fetch what was logged and reduce each metric to a scalar.
        # NOTE(review): assumes metrics.collect drains the store so the test
        # split never sees train-split values — confirm in idiots.utils.metrics.
        losses, accuracies = metrics.collect("eval_loss", "eval_accuracy")
        return tuple(
            jnp.concatenate(values).mean().item() for values in (losses, accuracies)
        )

    train_metrics = eval_loss_acc(ds_train)
    test_metrics = eval_loss_acc(ds_test)
    return (state, ds_train, ds_test, *train_metrics, *test_metrics)

In [None]:
# Evaluate the checkpoints at steps 0, 10000, ..., 40000 and keep one record
# (dict) per step.
# NOTE(review): the unpacked names below stay bound at notebook scope after the
# loop — e.g. `train_acc` ends up holding the last checkpoint's scalar, and
# later cells can pick it up by accident.
data = []
for step in range(0, 50_000, 10_000):
    state, ds_train, ds_test, train_loss, train_acc, test_loss, test_acc = eval_checkpoint(
        step
    )
    data.append(
        dict(
            step=step,
            state=state,
            ds_train=ds_train,
            ds_test=ds_test,
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
        )
    )

In [None]:
df = pd.DataFrame(data)


def _melt_metric(frame, suffix, value_name):
    """Reshape the ``train<suffix>`` / ``test<suffix>`` columns into long form.

    Returns a new frame with columns ``step``, ``split`` (the column name with
    ``suffix`` stripped, i.e. "train" / "test") and ``value_name``.
    """
    long_form = frame[["step", "train" + suffix, "test" + suffix]].melt(
        "step", var_name="split", value_name=value_name
    )
    long_form["split"] = long_form["split"].str.replace(suffix, "")
    return long_form


df_loss = _melt_metric(df, "_loss", "loss")
df_acc = _melt_metric(df, "_acc", "accuracy")

# Optional quick-look plot of the long-form frames (currently disabled):
# fig, axs = plt.subplots(1, 2, figsize=(12, 6))
#
# sns.lineplot(data=df_loss, x="step", y="loss", hue="split", marker="o", ax=axs[0])
# sns.lineplot(data=df_acc, x="step", y="accuracy", hue="split", marker="o", ax=axs[1])

In [None]:
def _split_series(long_df, split, column):
    """List of ``column`` values for rows whose ``split`` equals ``split``, in row order."""
    return long_df.loc[long_df["split"].eq(split), column].tolist()


# Per-checkpoint series. Note that `test_loss` / `test_acc` rebind names that the
# checkpoint-evaluation loop left holding scalars.
training_loss = _split_series(df_loss, "train", "loss")
test_loss = _split_series(df_loss, "test", "loss")
training_acc = _split_series(df_acc, "train", "accuracy")
test_acc = _split_series(df_acc, "test", "accuracy")

In [None]:
def get_dots(kernel_fn, X, params=None, batch_size=32):
    """Return the numerical rank of the empirical NTK Gram matrix of ``X``.

    Args:
        kernel_fn: empirical kernel function called as
            ``kernel_fn(x1, x2, "ntk", params)`` (e.g. from ``nt.empirical_kernel_fn``).
        X: inputs. ``None`` is passed as the second input, which neural_tangents
            treats as "X against itself".
        params: model parameters at which to evaluate the NTK. When omitted,
            falls back to ``state.params`` of the notebook-level ``state``
            (the original behaviour), so callers must have bound ``state`` to
            the checkpoint of interest.
        batch_size: batch size handed to ``nt.batch``.

    Returns:
        int: ``jnp.linalg.matrix_rank`` of the kernel matrix.
    """
    if params is None:
        # Backward-compatible fallback to the notebook-global checkpoint state.
        params = state.params
    kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=batch_size)
    kernel = kernel_fn_batched(X, None, "ntk", params)
    return jnp.linalg.matrix_rank(kernel).item()

In [None]:
import warnings

# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# not only those raised by this cell.
warnings.filterwarnings("ignore")

df = pd.DataFrame(data)
state_checkpoints = df["state"].tolist()
train_data_checkpoints = df["ds_train"].tolist()
test_data_checkpoints = df["ds_test"].tolist()

svm_accuracy = []  # SVC test accuracy, one entry per checkpoint
dots_results = []  # rank of the NTK on X_test, one entry per checkpoint

# Small subsets keep the NTK evaluation cheap; the trailing comments record the
# larger sizes presumably intended for full runs.
N_train = 10  # 512
N_test = 10  # 512

# The same fixed subsets (taken from the first checkpoint's datasets) are
# reused for every checkpoint, so results differ only by model parameters.
X_train = jnp.array(train_data_checkpoints[0]["x"][:N_train])
Y_train = jnp.array(train_data_checkpoints[0]["y"][:N_train])

X_test = jnp.array(test_data_checkpoints[0]["x"][:N_test])
Y_test = jnp.array(test_data_checkpoints[0]["y"][:N_test])

n_checkpoints = len(state_checkpoints)
# `state` is deliberately a notebook-level name: get_dots reads it when no
# params are passed.
for i, state in enumerate(state_checkpoints):
    print(f"Iteration: {i + 1}/{n_checkpoints}")

    kernel_fn = nt.empirical_kernel_fn(
        state.apply_fn,
        vmap_axes=0,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
    )

    dots = get_dots(kernel_fn, X_test)

    # Build the batched kernel once per checkpoint rather than on every call
    # SVC makes to the kernel (fit and predict each call it).
    kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=32)

    def custom_kernel(X1, X2, kernel_fn_batched=kernel_fn_batched, params=state.params):
        """Empirical-NTK matrix between X1 and X2, used as SVC's callable kernel.

        Defaults bind this checkpoint's kernel and params explicitly instead of
        relying on a late-binding closure over the global ``state``.
        """
        return kernel_fn_batched(X1, X2, "ntk", params)

    svc = SVC(kernel=custom_kernel)
    svc.fit(X_train, Y_train)

    predictions = svc.predict(X_test)
    accuracy = accuracy_score(Y_test, predictions)

    svm_accuracy.append(accuracy)
    dots_results.append(dots)

In [None]:
print(training_loss)
print(test_loss)
print(svm_accuracy)
print(dots_results)

# Per-checkpoint series, all in the checkpoint order of `data`.
graph_data = {
    "training_loss": training_loss,
    "test_loss": test_loss,
    # Use the per-checkpoint list built from df_acc. The bare name `train_acc`
    # is the leftover scalar from the last checkpoint of the evaluation loop.
    "training_acc": training_acc,
    "test_acc": test_acc,
    "svm_accuracy": svm_accuracy,
    "dots": dots_results,
}

with open("graph_data.json", "w") as json_file:
    json.dump(graph_data, json_file, indent=2)

In [None]:
# Compare train/test accuracy with NTK-SVM accuracy for each saved experiment.
experiments = ["div", "div_mse", "s5"]
# NOTE(review): assumes consecutive rows of results_<exp>.json are 1000 training
# steps apart — confirm against the interval used when the results were saved.
STEP_INTERVAL = 1000

# squeeze=False keeps `axs` 2-D, so iteration works even for a single experiment.
fig, axs = plt.subplots(
    1, len(experiments), figsize=(12, 4), sharey=True, squeeze=False
)

for ax, exp in zip(axs.flat, experiments):
    with open(f"results_{exp}.json", "r") as json_file:
        results = pd.read_json(json_file)
    # Use dedicated names so the notebook-level `df` (checkpoint records) is not clobbered.
    results_long = (
        results.assign(step=results.index * STEP_INTERVAL)[
            ["step", "training_acc", "test_acc", "svm_accuracy"]
        ].melt("step", var_name="type", value_name="Accuracy")
    )

    sns.lineplot(data=results_long, x="step", y="Accuracy", hue="type", marker="o", ax=ax)
    ax.set(title=exp, xlabel="Step", ylabel="Accuracy")
fig.tight_layout()