In [None]:
# Re-import edited modules (e.g. the project's `idiots` package) automatically
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Any

import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from diffrax import (
    diffeqsolve,
    Tsit5,
    Dopri8,
    ODETerm,
    SaveAt,
    PIDController,
    TqdmProgressMeter,
)
import diffrax
import flax.linen as nn
import optax
import neural_tangents as nt
from einops import rearrange

from idiots.dataset.image_classification import mnist_splits
from idiots.dataset.algorithmic import binary_op_splits
from idiots.experiments.grokking.model import EmbedMLP
from idiots.experiments.classification.model import ImageMLP
from idiots.experiments.classification.config import get_config

In [None]:
# Start from the classification defaults and override only what this run changes.
config: Any = get_config()

# Model architecture and initialisation (init_scale multiplies every initial weight).
config.model.d_model = 128
config.model.n_layers = 2
config.model.init_scale = 0.3
config.model.normalize_inputs = True

# Optimisation: `lr` and `weight_decay` are read by `update_fn` below.
config.opt.weight_decay = 0
config.opt.lr = 1

# Data.
config.train_size = 512

config

Grokking MNIST with Gradient Flow

Config:
```
steps: 700_000
dots_batch_size: 64
dots_sample_size: 128
eval_every: 1000
log_dir: logs/checkpoints/mnist
log_every: 100
loss_variant: cross_entropy
model:
  d_model: 256
  init_scale: 8.0
  n_layers: 2
  normalize_inputs: true
opt:
  lr: 0.001
  name: adamw
  warmup_steps: 10
  weight_decay: 0.004
save_every: -1
seed: 0
steps: 100000
test_batch_size: 128
test_size: 5000
train_batch_size: 128
train_size: 256
```

Grokking x + y (mod 47) GF

```
steps: 16_000_000
dots_batch_size: 64
dots_sample_size: 128
eval_every: 1000
log_dir: logs/checkpoints/mnist
log_every: 100
loss_variant: cross_entropy
model:
  d_model: 256
  init_scale: 1.0
  n_layers: 2
  normalize_inputs: false
opt:
  learning_rate: 1
  lr: 0.001
  name: adamw
  warmup_steps: 10
  weight_decay: 0.002
save_every: -1
seed: 0
steps: 100000
test_batch_size: 128
test_size: 5000
train_batch_size: 128
train_size: 128
```

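Grokking x + y (mod 47) GF MSE

Note: the config dump below still lists `loss_variant: cross_entropy`, although this run is labelled MSE. Confirm which loss was actually used.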

```
steps: 50_000_000
dots_batch_size: 64
dots_sample_size: 128
eval_every: 1000
log_dir: logs/checkpoints/mnist
log_every: 100
loss_variant: cross_entropy
model:
  d_model: 256
  init_scale: 1.0
  n_layers: 2
  normalize_inputs: false
opt:
  learning_rate: 1
  lr: 0.001
  name: adamw
  warmup_steps: 10
  weight_decay: 3.0e-05
save_every: -1
seed: 0
steps: 100000
test_batch_size: 128
test_size: 5000
train_batch_size: 128
train_size: 128
```

In [None]:
# Build the train/test splits. To run the algorithmic task instead, switch to the
# commented `binary_op_splits` line and the commented `EmbedMLP` model below.
ds_train, ds_test = mnist_splits(config.train_size, config.test_size, config.seed)
# ds_train, ds_test = binary_op_splits("x / y (mod 47)")

# One-hot encode the integer labels; the leading `1 *` is presumably a knob for
# the target scale (1 = plain one-hot).
xs_train, ys_train = ds_train["x"], ds_train["y"]
ys_train = 1 * jax.nn.one_hot(ys_train, ds_train.features["y"].num_classes)
xs_test, ys_test = ds_test["x"], ds_test["y"]
ys_test = 1 * jax.nn.one_hot(ys_test, ds_test.features["y"].num_classes)

xs_train, ys_train = jax.device_put(xs_train), jax.device_put(ys_train)
xs_test, ys_test = jax.device_put(xs_test), jax.device_put(ys_test)

# xs_train = xs_train[:, [0, 2]]
# xs_test = xs_test[:, [0, 2]]

# model = TransformerSingleOutput(
#     d_model=64,
#     n_layers=2,
#     n_heads=2,
#     old_parameterisation=False,
#     vocab_size=ds_train.features["y"].num_classes,
#     max_len=ds_train.features["x"].length,
# )

model = ImageMLP(
    hidden=config.model.d_model,
    n_layers=config.model.n_layers,
    normalize_inputs=config.model.normalize_inputs,
    out=ds_train.features["y"].num_classes,
)

# model = EmbedMLP(
#     hidden=config.model.d_model,
#     n_layers=config.model.n_layers,
#     n_classes=ds_train.features["y"].num_classes,
# )
params = model.init(jax.random.PRNGKey(config.seed), xs_train)
# Multiply every initial weight by `init_scale`, which scales the global parameter
# norm by the same factor. `jax.tree_map` is deprecated (removed in newer JAX
# releases); `jax.tree_util.tree_map` works in every version.
params = jax.tree_util.tree_map(lambda leaf: leaf * config.model.init_scale, params)

In [None]:
@jax.jit
def loss_fn(params, xs, ys):
    """Mean-squared error between the outputs of the global ``model`` and ``ys``.

    Args:
        params: Parameter pytree accepted by ``model.apply``.
        xs: Input batch passed to ``model.apply``.
        ys: Targets with the same shape as the model output.

    Returns:
        Scalar mean of the squared residuals over every element.
    """
    residual = model.apply(params, xs) - ys
    return jnp.square(residual).mean()
    # Cross-entropy alternative: optax.softmax_cross_entropy(model.apply(params, xs), ys).mean()


def update_fn(params, xs, ys):
    """Gradient-flow vector field with L2 weight decay.

    Computes ``-lr * (grad_params loss_fn(params, xs, ys) + weight_decay * params)``
    leaf by leaf, with ``lr`` and ``weight_decay`` taken from ``config.opt``.

    Args:
        params: Parameter pytree.
        xs: Inputs forwarded to ``loss_fn``.
        ys: Targets forwarded to ``loss_fn``.

    Returns:
        A pytree with the same structure as ``params``.
    """
    grad = jax.grad(loss_fn)(params, xs, ys)
    lr = config.opt.lr
    weight_decay = config.opt.weight_decay
    # `jax.tree_map` is deprecated/removed in newer JAX; the tree_util spelling is stable.
    return jax.tree_util.tree_map(
        lambda g, p: -lr * (g + weight_decay * p),
        grad,
        params,
    )


def fixed_norm_update_fn(params, xs, ys):
    """Gradient-flow vector field projected to keep the parameter norm fixed.

    Takes the ``update_fn`` update ``u`` and removes its component along the
    current parameters:

        u_fixed = u - (u . p_hat) p_hat,    p_hat = params / ||params||

    Because ``params . u_fixed == 0``, the continuous-time flow leaves the global
    L2 norm of ``params`` unchanged (up to ODE-solver error).

    Args:
        params: Parameter pytree.
        xs: Inputs forwarded to ``update_fn``.
        ys: Targets forwarded to ``update_fn``.

    Returns:
        A pytree with the same structure as ``params``.
    """
    update = update_fn(params, xs, ys)

    p_norm = optax.global_norm(params)
    # With all-zero parameters there is no radial direction to project out; use a
    # dummy norm so p_hat is zero and the update passes through, instead of NaNs.
    safe_norm = jnp.where(p_norm > 0, p_norm, 1.0)
    p_hat = jax.tree_util.tree_map(lambda p: p / safe_norm, params)

    # Global inner product <u, p_hat> summed over all leaves.
    u_dot_p_hat = sum(
        jnp.vdot(u, ph)
        for u, ph in zip(
            jax.tree_util.tree_leaves(update), jax.tree_util.tree_leaves(p_hat)
        )
    )
    return jax.tree_util.tree_map(lambda u, ph: u - u_dot_p_hat * ph, update, p_hat)


# Integrate the norm-preserving gradient flow from t = 0 to t = t1.
t1 = 10000


def _vector_field(t, ps, args):
    """ODE right-hand side d(params)/dt; ``t`` is unused and ``args`` is ``(xs, ys)``."""
    return fixed_norm_update_fn(ps, *args)


term = ODETerm(_vector_field)
solver = Tsit5()  # alternative: Dopri8()
# 101 evenly spaced snapshots on [0, t1], both endpoints included.
save_at = SaveAt(ts=jnp.linspace(0, t1, 101))
# Adaptive step-size control with tight tolerances.
step_size_controller = PIDController(
    rtol=1e-5,
    atol=1e-8,
    pcoeff=0.3,
    icoeff=0.3,
)

sol = diffeqsolve(
    term,
    solver,
    t0=0,
    t1=t1,
    dt0=1e-5,
    y0=params,
    args=(xs_train, ys_train),
    saveat=save_at,
    stepsize_controller=step_size_controller,
    max_steps=None,  # no cap on the number of adaptive steps
    progress_meter=TqdmProgressMeter(),
)
sol.stats

In [None]:
@jax.jit
def accuracy(params, xs, ys):
    """Fraction of rows where the arg-max of the model output matches the arg-max of ``ys``."""
    predicted_class = model.apply(params, xs).argmax(axis=-1)
    true_class = ys.argmax(axis=-1)
    return (predicted_class == true_class).mean()


def global_norm(params):
    """Global L2 norm of a pytree: square root of the sum of squares over all leaves."""
    squared_total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        squared_total = squared_total + jnp.sum(jnp.square(leaf))
    return jnp.sqrt(squared_total)


@jax.jit
def dots(params, x):
    """Numerical rank of the empirical NTK of the global ``model`` on inputs ``x``.

    The kernel keeps both output axes (``trace_axes=()``). It is flattened from
    ``(b1, b2, d1, d2)`` into a ``(b1*d1, b2*d2)`` matrix before taking the rank.
    """
    ntk_fn = nt.empirical_ntk_fn(
        model.apply,
        trace_axes=(),
        vmap_axes=0,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
    )
    batched_ntk_fn = nt.batch(ntk_fn, batch_size=512)
    ntk = batched_ntk_fn(x, None, params)
    gram = rearrange(ntk, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")
    return jnp.linalg.matrix_rank(gram)  # type: ignore


# Evaluate train/test metrics at each saved snapshot of the ODE solution.
# Raise `snapshot_stride` to evaluate only every k-th snapshot.
snapshot_stride = 1
data = []
for i in range(0, len(sol.ts), snapshot_stride):
    t = sol.ts[i]
    # Every leaf of `sol.ys` is indexed by snapshot along its leading axis; take the i-th.
    trained_param = jax.tree_util.tree_map(lambda leaf: leaf[i], sol.ys)
    data.append(
        {
            "step": t.item(),  # ODE time t, stored under "step" for the plots
            "weight_norm": global_norm(trained_param).item(),
            "train_loss": loss_fn(trained_param, xs_train, ys_train).item(),
            "train_accuracy": accuracy(trained_param, xs_train, ys_train).item(),
            "test_loss": loss_fn(trained_param, xs_test, ys_test).item(),
            "test_accuracy": accuracy(trained_param, xs_test, ys_test).item(),
            # "dots": dots(trained_param, xs_train[:64]).item(),
        }
    )

In [None]:
df = pd.DataFrame(data)

df_loss = df.melt(
    id_vars=["step"],
    value_vars=["train_loss", "test_loss"],
    var_name="split",
    value_name="loss",
)
df_accuracy = df.melt(
    id_vars=["step"],
    value_vars=["train_accuracy", "test_accuracy"],
    var_name="split",
    value_name="accuracy",
)

# Draw the NTK-rank ("dots") panel only when that metric was collected;
# otherwise a fourth, empty axis would be left in the figure.
has_dots = "dots" in df.columns
n_panels = 4 if has_dots else 3
fig, axs = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4))
sns.lineplot(data=df_loss, x="step", y="loss", hue="split", ax=axs[0])
sns.lineplot(data=df_accuracy, x="step", y="accuracy", hue="split", ax=axs[1])
sns.lineplot(data=df, x="step", y="weight_norm", ax=axs[2])
if has_dots:
    sns.lineplot(data=df, x="step", y="dots", ax=axs[3])

axs[0].set(yscale="log")

fig.tight_layout()

Transformer: 1_000_000 steps

In [None]:
# Ask XLA to allocate GPU memory on demand rather than preallocating most of it.
# NOTE(review): XLA reads this only when JAX initialises its backend, so it likely
# has no effect if JAX already ran in this kernel (e.g. the cells above). Restart
# the kernel before running this section.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import neural_tangents as nt
from einops import rearrange

from idiots.experiments.gradient_flow.init import restore

In [None]:
checkpoint_dir = Path("logs/checkpoints/gradient_flow/exp29/checkpoints/")
apply_fn, init_params, ds_train, ds_test, mngr, config = restore(checkpoint_dir)


def _device_arrays(ds):
    """Return ``(xs, ys)`` for one split as device arrays, with labels one-hot encoded."""
    num_classes = ds.features["y"].num_classes
    xs = jax.device_put(ds["x"])
    ys = jax.device_put(jax.nn.one_hot(ds["y"], num_classes))
    return xs, ys


xs_train, ys_train = _device_arrays(ds_train)
xs_test, ys_test = _device_arrays(ds_test)

In [None]:
def dots(params, x):
    """Numerical rank of the empirical NTK of the restored ``apply_fn`` on inputs ``x``.

    Not jitted itself; it is traced as part of the jitted ``metrics`` function.
    The kernel keeps both output axes (``trace_axes=()``) and is flattened from
    ``(b1, b2, d1, d2)`` into a ``(b1*d1, b2*d2)`` matrix before taking the rank.
    """
    ntk_fn = nt.empirical_ntk_fn(
        apply_fn,
        trace_axes=(),
        vmap_axes=0,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
    )
    ntk = nt.batch(ntk_fn, batch_size=512)(x, None, params)
    gram = rearrange(ntk, "b1 b2 d1 d2 -> (b1 d1) (b2 d2)")
    return jnp.linalg.matrix_rank(gram)  # type: ignore


def _mse_and_accuracy(y_pred, ys):
    """Mean-squared error and arg-max accuracy of ``y_pred`` against targets ``ys``."""
    err = y_pred - ys
    mse = jnp.mean(err**2)
    hits = jnp.argmax(y_pred, axis=-1) == jnp.argmax(ys, axis=-1)
    return mse, jnp.mean(hits)


@jax.jit
def metrics(params):
    """Evaluate one checkpoint.

    Returns a dict of scalar arrays:
    - train/test MSE and accuracy;
    - the global L2 norm of ``params``;
    - ``test_dots``: the ``dots`` NTK rank on 128 test inputs. The inputs are
      chosen by a fixed PRNGKey(0) permutation, so every call uses the same subset.
    """
    train_loss, train_acc = _mse_and_accuracy(apply_fn(params, xs_train), ys_train)
    test_loss, test_acc = _mse_and_accuracy(apply_fn(params, xs_test), ys_test)

    weight_norm = jnp.sqrt(
        sum(jnp.sum(leaf**2) for leaf in jax.tree_util.tree_leaves(params))
    )

    # Fixed key => identical subset at every checkpoint, so ranks are comparable.
    test_subset = jax.random.permutation(jax.random.PRNGKey(0), xs_test)[:128]
    # train_subset = jax.random.permutation(jax.random.PRNGKey(0), xs_train)[:128]

    return {
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "weight_norm": weight_norm,
        # "train_dots": dots(params, train_subset),
        "test_dots": dots(params, test_subset),
    }


data = []
# Evaluate every 5th saved checkpoint.
for step in mngr.all_steps()[::5]:
    params = mngr.restore(step)
    # Convert the metrics (train and test) to Python scalars once, then reuse
    # them for both the log line and the results table.
    step_metrics = {k: v.item() for k, v in metrics(params).items()}
    print(step, step_metrics)
    data.append({"step": step, **step_metrics})

In [None]:
df = pd.DataFrame(data)
# Shift steps by one, presumably so the first checkpoint (if it is step 0) can be
# shown on a log-scaled x-axis (see the commented-out `xscale` line below).
df["step"] += 1


def _melt_metric(frame, pattern, value_name):
    """Long-format view of every column of ``frame`` whose name contains ``pattern``."""
    return frame.melt(
        id_vars=["step"],
        value_vars=frame.columns[frame.columns.str.contains(pattern)],
        var_name="split",
        value_name=value_name,
    )


df_loss = _melt_metric(df, "loss", "loss")
df_accuracy = _melt_metric(df, "accuracy", "accuracy")
df_dots = _melt_metric(df, "dots", "dots")

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
sns.lineplot(data=df_loss, x="step", y="loss", hue="split", ax=axs[0])
sns.lineplot(data=df_accuracy, x="step", y="accuracy", hue="split", ax=axs[1])
sns.lineplot(data=df, x="step", y="weight_norm", ax=axs[2])
sns.lineplot(data=df_dots, x="step", y="dots", hue="split", ax=axs[3])

axs[0].set(yscale="log")
# axs[1].set(xscale="log")

# Same layout treatment as the gradient-flow figure: avoid overlapping labels.
fig.tight_layout()