In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# %env XLA_PYTHON_CLIENT_MEM_FRACTION=0.95
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
from functools import partial
from pathlib import Path
from typing import cast
from copy import deepcopy

import jax
import jax.numpy as jnp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import neural_tangents as nt
import optax
from sklearn.model_selection import train_test_split

from idiots.dataset.dataloader import DataLoader
from idiots.experiments.grokking.training import (
    restore as restore_grokking,
    train_step,
    eval_step,
    TrainState,
    loss_fn,
)
from idiots.experiments.classification.training import (
    restore as restore_classification,
    restore_partial as restore_partial_classification,
)
from idiots.utils import metrics

In [None]:
# Restore the trained classification checkpoint (MNIST, exp55) at a fixed step.
checkpoint_dir = Path("logs/checkpoints/mnist/exp55/checkpoints")
restore_step = 20000  # training step of the checkpoint to restore

mngr, config, state, ds_train, ds_test = restore_classification(
    checkpoint_dir, restore_step
)

In [None]:
# Train the linearized network (nt.linearize: first-order expansion of
# `state.apply_fn` around the restored params), reusing the original optimizer.
num_linear_steps = 2000  # total optimizer steps for the linearized model
log_every = 100  # print running train metrics every N steps
eval_every = 200  # run a full pass over the test set every N steps


def mean_collected_metrics(*names):
    """Collect the named metrics via `metrics.collect` and return the mean of each.

    Each collected entry is a sequence of per-batch arrays; they are concatenated
    and averaged into a Python float, in the same order as `names`.
    """
    return [jnp.concatenate(values).mean().item() for values in metrics.collect(*names)]


linear_state = TrainState.create(
    apply_fn=nt.linearize(state.apply_fn, state.params),
    params=state.params,
    tx=state.tx,
)

train_loader = DataLoader(
    ds_train, config.train_batch_size, shuffle=True, infinite=True, drop_last=True
)
train_iter = iter(train_loader)

step = int(linear_state.step)
while step < num_linear_steps:
    linear_state, logs = train_step(linear_state, next(train_iter), config.loss_variant)
    linear_state = cast(TrainState, linear_state)  # For better typing
    metrics.log(**logs)
    # Read the (device) step counter back to a Python int once per iteration
    # instead of doing three separate device-side comparisons.
    step = int(linear_state.step)

    if step % log_every == 0:
        loss, acc = mean_collected_metrics("loss", "accuracy")
        print(f"Train {loss=} {acc=}")

    if step % eval_every == 0:
        for batch in DataLoader(ds_test, config.test_batch_size):
            logs = eval_step(linear_state, batch, config.loss_variant)
            metrics.log(**logs)
        loss, acc = mean_collected_metrics("eval_loss", "eval_accuracy")
        print(f"Eval {loss=} {acc=}")

## Kernel gradient descent with an ODE solver (currently not working: runs out of memory)

In [None]:
@partial(nt.batch, batch_size=64, store_on_device=True)  # type: ignore
def kernel_fn(x1, x2, params):
    k = nt.empirical_ntk_fn(state.apply_fn, trace_axes=(), vmap_axes=0)(x1, x2, params)
    return k

In [None]:
# Stratified subsample of the training set. Kept small because the pairwise
# kernel grows quadratically with train_size.
# NOTE(review): keep train_size a multiple of kernel_fn's nt.batch batch size.
train_size = 128
split_seed = 0  # fixed so the subsample (and hence the kernel) is reproducible

x_train, _, y_train, _ = train_test_split(
    ds_train["x"],
    ds_train["y"],
    train_size=train_size,
    stratify=ds_train["y"],
    random_state=split_seed,
)

k_train_train = kernel_fn(x_train, x_train, state.params)

In [None]:
# Sanity check: the stratified subsample returned exactly `train_size` labels.
assert y_train.shape[0] == train_size, y_train.shape
y_train.shape

In [None]:
def preprocess_y(y, num_classes=None):
    """One-hot encode integer class labels.

    Args:
        y: array of integer class indices.
        num_classes: width of the one-hot axis. When None (the default), it is
            read from ``ds_train.features["y"].num_classes``, matching the
            original behavior.

    Returns:
        ``jax.nn.one_hot(y, num_classes=num_classes)``, an array of shape
        ``y.shape + (num_classes,)``.
    """
    if num_classes is None:
        num_classes = ds_train.features["y"].num_classes
    return jax.nn.one_hot(y, num_classes=num_classes)

In [None]:
# nt.predict.gradient_descent
# def cross_entropy(fx, y_hat):
#     return -jnp.mean(jax.nn.log_softmax(fx) * y_hat)


def mse(fx, y):
    y = y - jnp.mean(y, axis=-1, keepdims=True)
    return jnp.mean(jnp.square(fx - y))


# Gradient-flow predictor for the linearized network trained with `mse` on the
# subsample. trace_axes=() matches the untraced kernel from `kernel_fn`, and the
# targets are one-hot encoded to line up with the network's output axis.
predict_fn = nt.predict.gradient_descent(
    mse,
    k_train_train,
    preprocess_y(y_train),
    trace_axes=(),
)

In [None]:
# Network outputs on the subsample at the restored parameters; passed to the
# predictor as the initial train outputs of the linearized dynamics.
fx_train_0 = state.apply_fn(state.params, x_train)

# Use a finite horizon: for nt.predict.gradient_descent, t=None means t=inf,
# which the ODE solver may never reach. With the default learning_rate=1.0,
# t is measured in (continuous) training steps. Tune as needed.
ode_time = 1_000.0
predict_fn(ode_time, fx_train_0)