In [None]:
# Reload imported modules (e.g. the `idiots` project package) before each cell
# executes, so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make XLA allocate GPU memory on demand instead of pre-allocating a large share
# up front. Must be set before JAX initialises its backend, hence it precedes the imports.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from diffrax import (
    diffeqsolve,
    Tsit5,
    ODETerm,
    SaveAt,
    PIDController,
    TqdmProgressMeter,
)
import neural_tangents as nt
from einops import rearrange  # NOTE(review): not used in the visible cells
# Imported only for its side effect of registering the SciencePlots matplotlib
# styles ("science", "grid") used by plt.style.use below; the name is not referenced.
import scienceplots

# Project helpers: checkpoint restoration and the training loss of the experiments.
from idiots.experiments.compute_results.compute_results import restore_checkpoint
from idiots.experiments.grokking.training import loss_fn

plt.style.use(["science", "grid"])

In [None]:
# Restore the trained gradient-flow run: its config, the model's apply function,
# a `step -> params` accessor, the train/test datasets and the checkpoint steps.
config, apply_fn, get_params, ds_train, ds_test, all_steps = restore_checkpoint(
    Path("logs/checkpoints/gradient_flow/exp39/checkpoints"),
    "gradient_flow_algorithmic",
)

# Move both splits to the default device once; each (inputs, labels) pair is
# transferred as a single tuple pytree.
xs_train, ys_train = jax.device_put((ds_train["x"], ds_train["y"]))
xs_test, ys_test = jax.device_put((ds_test["x"], ds_test["y"]))

config

In [None]:
def linearise_train_from_step(step):
    """Run gradient flow on the model linearised around checkpoint `step`.

    The network is replaced by its first-order Taylor expansion around the
    parameters stored at `step` (``nt.linearize``). Those parameters are then
    evolved under the ODE

        dθ/dt = -(∇θ mean_train_loss(θ) + config.weight_decay · θ)

    on the full training set, from t = 0 to t = config.T - step. The solver is
    diffrax's Tsit5 with an adaptive PID step-size controller whose tolerances and
    gains come from `config.ode`.

    Args:
        step: checkpoint step to linearise around (passed to `get_params`).

    Returns:
        A tuple ``(apply_fn_lin, params_seq, ts)``:
          - apply_fn_lin: linearised apply function, called as ``apply_fn_lin(params, xs)``.
          - params_seq: list with one parameter pytree per saved time.
          - ts: the saved times, measured from `step` (not absolute training time).
    """
    params_at_step = get_params(step)
    apply_fn_lin = nt.linearize(apply_fn, params_at_step)

    def train_loss(params):
        # Mean loss of the linearised model over the whole training set.
        ys_pred = apply_fn_lin(params, xs_train)
        return loss_fn(ys_pred, ys_train, variant=config.loss_variant).mean()

    def vector_field(t, params, args):
        # Gradient flow plus the L2 weight-decay term; the field does not depend on
        # `t` or `args`.
        grads = jax.grad(train_loss)(params)
        return jax.tree_util.tree_map(
            lambda g, p: -(g + config.weight_decay * p),
            grads,
            params,
        )

    t0 = 0
    t1 = config.T - step

    sol = diffeqsolve(
        ODETerm(vector_field),
        Tsit5(),
        t0=t0,
        t1=t1,
        dt0=None,  # let the adaptive controller choose the initial step size
        y0=params_at_step,  # reuse the restored params instead of restoring again
        # jnp.arange excludes t1, so the end time itself is not saved.
        saveat=SaveAt(ts=jnp.arange(t0, t1, config.save_every)),
        stepsize_controller=PIDController(
            rtol=config.ode.rtol,
            atol=config.ode.atol,
            pcoeff=config.ode.pcoeff,
            icoeff=config.ode.icoeff,
        ),
        max_steps=None,
        progress_meter=TqdmProgressMeter(),
    )
    print(sol.stats)

    # Every leaf of sol.ys is indexed by save time along its leading axis; split
    # it into one parameter pytree per saved time.
    params_seq = [
        jax.tree_util.tree_map(lambda leaf, i=i: leaf[i], sol.ys)
        for i in range(len(sol.ts))
    ]

    return apply_fn_lin, params_seq, sol.ts

In [None]:
# Linearise around every 10th checkpoint (the `-1` stop of the slice leaves out
# the final checkpoint) and keep each run's apply function, parameter trajectory
# and save times.
apply_fn_lins, params_lins, ts_lins = [], [], []

for step in all_steps[0:-1:10]:
    apply_fn_lin, params_lin, ts_lin = linearise_train_from_step(step)
    # Solver times start at 0 for each run; shift them to absolute training time.
    ts_lins.append(ts_lin + step)
    params_lins.append(params_lin)
    apply_fn_lins.append(apply_fn_lin)

In [None]:
@jax.jit
def accuracy(y_pred, xs, ys):
    """Fraction of rows whose arg-max over the last axis of `y_pred` equals `ys`.

    `xs` is not used; it is presumably accepted so all metrics share the
    ``(y_pred, xs, ys)`` signature.
    """
    predicted = y_pred.argmax(axis=-1)
    hits = predicted == ys
    return hits.mean()


@jax.jit
def loss(y_pred, xs, ys):
    """Mean of `loss_fn` (variant `config.loss_variant`) of `y_pred` against `ys`.

    `xs` is not used. `config.loss_variant` is read when the function is traced,
    so the compiled version keeps whatever value it had at trace time.
    """
    losses = loss_fn(y_pred, ys, variant=config.loss_variant)
    return jnp.mean(losses)


def global_norm(params):
    """L2 norm of all leaves of the pytree `params`, treated as one flat vector."""
    total_sq = jax.tree_util.tree_reduce(
        lambda acc, leaf: acc + jnp.sum(leaf**2),
        params,
        0,
    )
    return jnp.sqrt(total_sq)


# One record per (starting checkpoint, saved time): weight norm plus loss and
# accuracy of the linearised model on both splits.
data = []
for step, apply_fn_lin, params_lin, ts_lin in zip(
    all_steps[0:-1:10], apply_fn_lins, params_lins, ts_lins
):
    for param, t in zip(params_lin, ts_lin):
        # (loss key, accuracy key, predictions, inputs, labels) for each split.
        split_specs = (
            ("train_loss", "train_accuracy", apply_fn_lin(param, xs_train), xs_train, ys_train),
            ("test_loss", "test_accuracy", apply_fn_lin(param, xs_test), xs_test, ys_test),
        )
        record = {
            "from": step,
            "step": t.item(),
            "weight_norm": global_norm(param).item(),
        }
        for loss_key, acc_key, split_pred, split_xs, split_ys in split_specs:
            record[loss_key] = loss(split_pred, split_xs, split_ys).item()
            record[acc_key] = accuracy(split_pred, split_xs, split_ys).item()
        data.append(record)

In [None]:
# Persist the per-time-point metrics so the plots can be regenerated without
# re-solving the ODEs.
df = pd.DataFrame(data)
results_path = Path("logs/results/linearisation.json")
# DataFrame.to_json does not create missing parent directories.
results_path.parent.mkdir(parents=True, exist_ok=True)
df.to_json(results_path, orient="records")

In [None]:
import orbax.checkpoint as ocp

# Save each linearised run (its start step, parameter trajectory and absolute
# save times) to its own directory under logs/linearisation/.
checkpointer = ocp.StandardCheckpointer()

for step, params_lin, ts_lin in zip(all_steps[0:-1:10], params_lins, ts_lins):
    checkpointer.save(
        Path(f"logs/linearisation/{step}").absolute().resolve(),
        {
            "step": step,
            # Only this run's trajectory, not the list of every run's trajectory.
            "params_lin": params_lin,
            "ts": ts_lin,
        },
    )

# Newer orbax versions save asynchronously; block until every write is committed
# (older synchronous checkpointers have no such method).
if hasattr(checkpointer, "wait_until_finished"):
    checkpointer.wait_until_finished()

In [None]:
# Base model (solid) vs. gradient flow on the models linearised around every 10th
# checkpoint (dashed; a dot marks where each linearised run starts).
df = pd.read_json("logs/results/linearisation.json")

with open("logs/results/division-gf-mlp.json", "r") as f:
    df_main = pd.read_json(f)

SAVE_LINEARISATION_FIG = False  # set to True to write the figure to logs/plots/


def _plot_linearised_panel(ax, df_base, df_lin, base_cols, lin_cols):
    """Draw base-model and linearised-model curves for the train and test splits.

    Args:
        ax: matplotlib Axes to draw on.
        df_base: base-model metrics with a "step" column and the columns in `base_cols`.
        df_lin: linearised metrics with "from", "step" and the columns in `lin_cols`.
        base_cols: (train column, test column) in `df_base`.
        lin_cols: (train column, test column) in `df_lin`.

    Returns:
        List of (line, label) legend entries: base train, base test, then the first
        linearised train curve and the first linearised test curve.
    """
    split_names = ("Train", "Test")

    base_entries = []
    for i, (col, name) in enumerate(zip(base_cols, split_names)):
        [line] = ax.plot(df_base["step"], df_base[col], color=f"C{i}")
        base_entries.append((line, f"{name} (base model)"))

    # First dashed curve drawn for each split; it serves as that split's legend handle.
    lin_entries = {}
    for _, df_run in df_lin.groupby("from"):
        for i, (col, name) in enumerate(zip(lin_cols, split_names)):
            # Same colour as the split's base-model curve.
            [line] = ax.plot(
                df_run["step"], df_run[col], linestyle="--", alpha=0.9, color=f"C{i}"
            )
            ax.plot(
                df_run["step"].iloc[0],
                df_run[col].iloc[0],
                marker="o",
                color=line.get_color(),
                markersize=3,
                alpha=0.9,
            )
            lin_entries.setdefault(i, (line, f"{name} (linearised)"))

    return base_entries + [lin_entries[i] for i in sorted(lin_entries)]


fig, axs = plt.subplots(1, 2, figsize=(8, 8 / 3))

_plot_linearised_panel(
    axs[0], df_main, df, ("train_loss", "test_loss"), ("train_loss", "test_loss")
)
# No separate legend on the loss panel: both panels use identical styling, so the
# accuracy panel's legend describes both.
axs[0].set(xlabel="$t$", ylabel="Loss")

accuracy_legend = _plot_linearised_panel(
    axs[1],
    df_main,
    df,
    ("training_acc", "test_acc"),
    ("train_accuracy", "test_accuracy"),
)
axs[1].legend(*zip(*accuracy_legend), loc="lower right", fontsize="small")
axs[1].set(xlabel="$t$", ylabel="Accuracy")

if SAVE_LINEARISATION_FIG:
    Path("logs/plots").mkdir(parents=True, exist_ok=True)
    fig.savefig("logs/plots/linearisation.pdf", bbox_inches="tight")

In [None]:
# Base-model curves on their own: loss (log scale), accuracy and weight norm.
with open("logs/results/division-gf-mlp.json", "r") as f:
    df = pd.read_json(f)

# Fail fast with a readable message if the results file lacks a plotted column.
required_cols = {"step", "train_loss", "test_loss", "training_acc", "test_acc", "weight_norm"}
missing_cols = required_cols - set(df.columns)
assert not missing_cols, f"division-gf-mlp.json is missing columns: {sorted(missing_cols)}"

df_loss = df.melt(
    id_vars=["step"],
    value_vars=["train_loss", "test_loss"],
    var_name="split",
    value_name="loss",
)
df_accuracy = df.melt(
    id_vars=["step"],
    value_vars=["training_acc", "test_acc"],
    var_name="split",
    value_name="accuracy",
)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
sns.lineplot(data=df_loss, x="step", y="loss", hue="split", ax=axs[0])
sns.lineplot(data=df_accuracy, x="step", y="accuracy", hue="split", ax=axs[1])
sns.lineplot(data=df, x="step", y="weight_norm", ax=axs[2])

axs[0].set(yscale="log")

fig.tight_layout()