In [None]:
# IPython autoreload: re-import edited modules (e.g. the project's `idiots` package)
# automatically before each cell executes, so library edits are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import copy
from typing import Any

import jax
import jax.numpy as jnp
import neural_tangents as nt
import orbax.checkpoint as ocp
from absl import app, flags, logging
from datasets import Dataset
from ml_collections import config_flags
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import optax

from idiots.dataset.dataloader import DataLoader
from idiots.dataset.algorithmic import binary_op_splits
from idiots.experiments.grokking.training import TrainState, dots, eval_step, train_step
from idiots.experiments.grokking.model import TransformerSingleOutput
from idiots.experiments.grokking.config import get_config
from idiots.utils import metrics, num_params, get_optimizer

In [None]:
# Start from the default grokking experiment config and override a few fields.
config: Any = get_config()
config.steps = 20000  # bound of the dense training loop below
config.log_every = 500  # steps between train/test metric reports
config.opt.weight_decay = 0.1  # forwarded to get_optimizer("adamw", **config.opt)
config.model.d_model = 128  # model width passed to TransformerSingleOutput
# Display the resulting config.
config

In [None]:
# Train/test splits for the configured binary-operation task.
ds_train, ds_test = binary_op_splits(config.task, config.train_percentage, config.seed)

# Output vocabulary size comes from the label feature's class count; the maximum
# sequence length comes from the input feature's length.
model_cfg = config.model
model = TransformerSingleOutput(
    d_model=model_cfg.d_model,
    n_layers=model_cfg.n_layers,
    n_heads=model_cfg.n_heads,
    vocab_size=ds_train.features["y"].num_classes,
    max_len=ds_train.features["x"].length,
)

# Initialise parameters from the config seed using a single example as a shape template.
init_key = jax.random.PRNGKey(config.seed)
init_params = model.init(init_key, ds_train["x"][:1])

tx = get_optimizer("adamw", **config.opt)
state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)
print(f"Model has {num_params(init_params):,} parameters")

In [None]:
# Training batch iterator shared by both training loops (each draws with `next()`).
# infinite=True lets the loops keep drawing past one epoch;
# drop_last=True presumably discards a trailing partial batch.
train_loader = DataLoader(
    ds_train,
    config.train_batch_size,
    shuffle=True,
    infinite=True,
    drop_last=True,
)
train_iter = iter(train_loader)

In [None]:
def evaluate(ds, state: TrainState):
    """Average `eval_step` metrics of `state` over every batch of `ds`.

    Batches `ds` with `config.test_batch_size` and logs each batch's
    `eval_step` output into `metrics`. It then collects the accumulated
    "eval_loss" / "eval_accuracy" entries and averages each concatenated list
    of per-batch arrays.

    Returns:
        (val_loss, val_acc) as Python floats.
    """

    def mean_of(chunks):
        # Concatenate the per-batch arrays and reduce them to a single Python float.
        return jnp.concatenate(chunks).mean().item()

    eval_loader = DataLoader(ds, config.test_batch_size)
    for eval_batch in eval_loader:
        metrics.log(**eval_step(state, eval_batch, config.loss_variant))

    loss_chunks, acc_chunks = metrics.collect("eval_loss", "eval_accuracy")
    return mean_of(loss_chunks), mean_of(acc_chunks)

In [None]:
# Train the dense model for `config.steps` steps. Every `config.log_every` steps, report
# the train metrics logged since the last report, plus test-set metrics.
while state.step < config.steps:
    state, logs = train_step(state, next(train_iter), config.loss_variant)
    metrics.log(**logs)

    # Check `log_every > 0` first: `and` short-circuits, so disabling reporting
    # (log_every == 0) never evaluates `step % 0`.
    if config.log_every > 0 and state.step % config.log_every == 0:
        [losses, accuracies] = metrics.collect("loss", "accuracy")
        train_loss = jnp.concatenate(losses).mean().item()
        train_acc = jnp.concatenate(accuracies).mean().item()
        val_loss, val_acc = evaluate(ds_test, state)
        print(
            f"Step {state.step}: train/loss={train_loss:.4f} train/acc={train_acc:.4f} val/loss={val_loss:.4f} val/acc={val_acc:.4f}"
        )

In [None]:
# Magnitude pruning: keep a trained weight iff its absolute value exceeds this threshold.
# `mask` is a boolean pytree with the same structure as `state.params`.
PRUNE_THRESHOLD = 0.02
mask = jax.tree_util.tree_map(lambda p: jnp.abs(p) > PRUNE_THRESHOLD, state.params)

In [None]:
def mask_update(updates, params):
    """Zero the optimizer updates of pruned parameters (for `optax.stateless`).

    Multiplies every update leaf by the matching leaf of the global `mask`
    pytree, so positions where `mask` is False receive a zero update. `mask`
    is read at call time, so re-running the mask cell takes effect immediately.

    Args:
        updates: pytree of updates with the same structure as `mask`.
        params: unused; required by the `(updates, params)` signature that
            `optax.stateless` expects.

    Returns:
        The masked updates pytree.
    """
    del params  # unused, see docstring
    return jax.tree_util.tree_map(lambda u, m: u * m, updates, mask)


# Zero out the pruned weights of the trained dense model.
lottery_params = jax.tree_util.tree_map(lambda p, m: p * m, state.params, mask)

# Optimizer for the pruned model: AdamW with the dense run's hyperparameters,
# followed by `mask_update`. Masking runs *after* AdamW, so everything AdamW emits
# (presumably including its weight-decay term) is zeroed for pruned weights,
# keeping them at exactly zero.
# To try plain momentum SGD instead, swap the first entry for
# e.g. `optax.sgd(1e-3, momentum=0.9)`.
new_tx = optax.chain(
    get_optimizer("adamw", **config.opt),
    optax.stateless(mask_update),
)

# Train the first-order Taylor expansion of the network around `lottery_params`
# (neural_tangents.linearize) rather than the network itself.
# Pass `apply_fn=model.apply` below to fine-tune the non-linear network instead.
linear_apply_fn = nt.linearize(model.apply, lottery_params)

state_pruned = TrainState.create(
    apply_fn=linear_apply_fn,
    params=lottery_params,
    tx=new_tx,
)

# (test, train) metrics of the pruned, linearized model before any fine-tuning.
evaluate(ds_test, state_pruned), evaluate(ds_train, state_pruned)

In [None]:
# Fine-tune the pruned, linearized model for PRUNED_STEPS optimizer steps,
# reporting metrics every `config.log_every` steps.
PRUNED_STEPS = 5000

while state_pruned.step < PRUNED_STEPS:
    state_pruned, logs = train_step(state_pruned, next(train_iter), config.loss_variant)
    metrics.log(**logs)

    # Check `log_every > 0` first so `and` short-circuits before `step % 0`
    # can be evaluated when reporting is disabled.
    if config.log_every > 0 and state_pruned.step % config.log_every == 0:
        [losses, accuracies] = metrics.collect("loss", "accuracy")
        train_loss = jnp.concatenate(losses).mean().item()
        train_acc = jnp.concatenate(accuracies).mean().item()
        val_loss, val_acc = evaluate(ds_test, state_pruned)
        print(
            f"Step {state_pruned.step}: train/loss={train_loss:.4f} train/acc={train_acc:.4f} val/loss={val_loss:.4f} val/acc={val_acc:.4f}"
        )

In [None]:
def magnitude_vectorize(params):
    """Flatten a pytree of arrays into one 1-D array of absolute values.

    Leaves are taken in `jax.tree_util` flattening order. Each leaf is
    flattened before the leaves are concatenated.

    Args:
        params: pytree of arrays (e.g. model parameters).

    Returns:
        1-D array containing `|x|` for every element of every leaf. The array
        is empty when the pytree has no leaves.
    """
    leaves = [jnp.abs(leaf).ravel() for leaf in jax.tree_util.tree_leaves(params)]
    if not leaves:
        # jnp.concatenate rejects an empty list.
        return jnp.zeros((0,))
    return jnp.concatenate(leaves, axis=0)


# Parameter-magnitude distributions at init, after dense training, and after
# pruned (linearized) fine-tuning. `df` is the long-format (type, magnitude) table.
df = pd.DataFrame(
    {
        "init": np.asarray(magnitude_vectorize(init_params)),
        "trained": np.asarray(magnitude_vectorize(state.params)),
        "pruned_trained": np.asarray(magnitude_vectorize(state_pruned.params)),
    }
)
df = df.melt(var_name="type", value_name="magnitude")

# A log-scaled axis cannot place exact zeros, such as the weights zeroed by `mask`.
# Plot only the non-zero magnitudes explicitly and report the zero fraction per
# type, rather than relying on seaborn's handling of log10(0).
zero_fraction = df["magnitude"].eq(0).groupby(df["type"]).mean()
nonzero = df[df["magnitude"] > 0]

fig, ax = plt.subplots()
sns.ecdfplot(data=nonzero, x="magnitude", hue="type", ax=ax, log_scale=True)
ax.set(
    title="ECDF of non-zero parameter magnitudes",
    xlabel="|parameter|",
    ylabel="proportion (non-zero only)",
)
zero_fraction