In [None]:
from flax import linen as nn, struct
from flax.training.train_state import TrainState
from jax.flatten_util import ravel_pytree
from optax import (
    softmax_cross_entropy_with_integer_labels as xent
)
import jax
import jax.numpy as jnp
import optax


class MLP(nn.Module):
    """Fully-connected network: (Dense -> GELU) per hidden layer, then a linear head.

    Attributes:
        hidden_sizes: width of each hidden layer, in order. May be empty, in
            which case the module is a single Dense layer.
        out_features: width of the output (logit) layer.
    """
    hidden_sizes: tuple[int, ...]
    out_features: int

    @nn.compact
    def __call__(self, x):
        for feat in self.hidden_sizes:
            # Biases ~ N(0, 1/feat) instead of the Dense default bias init.
            hidden_bias_init = nn.initializers.normal(stddev=feat ** -0.5)
            x = nn.Dense(feat, bias_init=hidden_bias_init)(x)
            x = nn.gelu(x)

        # Head bias std is (fan-in of the head) ** -0.5. With at least one hidden
        # layer this equals hidden_sizes[-1] ** -0.5; it is also well-defined
        # when hidden_sizes is empty (fan-in is then the input width).
        head_bias_init = nn.initializers.normal(stddev=x.shape[-1] ** -0.5)
        x = nn.Dense(self.out_features, bias_init=head_bias_init)(x)
        return x


@struct.dataclass
class TrainConfig:
    """Hyperparameters consumed by `train`.

    Every field is declared with ``pytree_node=False`` so it is static metadata
    rather than a pytree leaf. The config can then be passed through
    ``jax.jit`` / ``jax.vmap`` without ``batch_size`` / ``num_epochs`` becoming
    tracers (they are used as array shapes and loop lengths), and without the
    string ``opt`` being treated as an array leaf.
    """
    # Number of examples per minibatch.
    batch_size: int = struct.field(pytree_node=False, default=64)
    # Number of full passes over the training set.
    num_epochs: int = struct.field(pytree_node=False, default=25)

    # Optimizer name: "adam" selects Adam in `train`; any other value selects SGD with momentum.
    opt: str = struct.field(pytree_node=False, default="sgd")


def make_apply_full(model, unraveler):
    """Wrap ``model.apply`` so that it accepts a flat parameter vector.

    Args:
        model: object exposing ``apply(params, x)`` (a Flax module in this notebook).
        unraveler: callable mapping a flat vector back to the params pytree that
            ``model.apply`` expects (e.g. the function returned by ``ravel_pytree``).

    Returns:
        A function ``apply_full(raveled, x)`` equal to
        ``model.apply(unraveler(raveled), x)``.
    """
    def apply_full(raveled, x):
        return model.apply(unraveler(raveled), x)

    return apply_full


def make_apply_subspace(model, unraveler, params0, basis):
    """Wrap ``model.apply`` to take coordinates in an affine parameter subspace.

    Full flat parameters are rebuilt as ``params0 + basis.T @ small_params``.
    The rows of `basis` span the subspace and `params0` is its origin.
    Presumably `basis` is (k, P) and `small_params` is (k,) for P flat
    parameters (TODO confirm against callers).

    Args:
        model: object exposing ``apply(params, x)``.
        unraveler: maps a flat P-vector back to the params pytree.
        params0: flat origin of the subspace.
        basis: matrix whose rows are the subspace directions.

    Returns:
        A function ``apply_subspace(small_params, x)``.
    """
    def apply_subspace(small_params, x):
        offset = basis.T @ small_params
        full_params = unraveler(params0 + offset)
        return model.apply(full_params, x)

    return apply_subspace


# Loss function
def compute_loss(params, apply_fn, X, Y):
    """Mean softmax cross-entropy and accuracy of `apply_fn` on one batch.

    Args:
        params: dict whose ``'p'`` entry is passed to `apply_fn` as its parameters.
        apply_fn: ``apply_fn(p, X) -> logits`` with classes on the last axis.
        X: batch of inputs.
        Y: integer class labels for `X`.

    Returns:
        ``(loss, acc)``: mean cross-entropy, and the fraction of argmax
        predictions equal to `Y`. The pair is shaped for
        ``jax.value_and_grad(..., has_aux=True)``: `loss` is differentiated and
        `acc` is carried along as auxiliary output.
    """
    logits = apply_fn(params['p'], X)
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == Y)
    loss = jnp.mean(xent(logits, Y))
    return loss, acc


def _make_optimizer(opt: str, num_steps: int) -> optax.GradientTransformation:
    """Build the optimizer named by `opt` with an LR cosine-decayed over `num_steps`."""
    if opt == "adam":
        sched = optax.cosine_decay_schedule(3e-3, num_steps)
        # eps_root > 0 keeps the square root in Adam's denominator differentiable
        # at zero, which matters because `train` is differentiated (jacfwd).
        return optax.adam(learning_rate=sched, eps_root=1e-8)
    sched = optax.cosine_decay_schedule(0.1, num_steps)
    return optax.sgd(learning_rate=sched, momentum=0.9)


def train(params, x_train, y_train, x_test, y_test, apply_fn, cfg: TrainConfig):
    """Train from `params` and return the flattened final parameters and test loss.

    The whole loop (epochs x minibatches) runs inside ``jax.lax.scan`` with no
    data shuffling, so the function is differentiable end to end, e.g. with
    ``jax.jacfwd(train, has_aux=True)``.

    Args:
        params: initial parameters, in whatever form `apply_fn` accepts as its
            first argument (a flat vector in this notebook).
        x_train: training inputs, assumed shape (N, F).
        y_train: integer training labels, shape (N,).
        x_test: test inputs.
        y_test: integer test labels.
        apply_fn: ``apply_fn(params, x) -> logits``.
        cfg: batch size, number of epochs and optimizer choice.

    Returns:
        ``(raveled, test_loss)``: the final ``{'p': params}`` flattened with
        ``ravel_pytree``, and the mean cross-entropy on the test set.

    Raises:
        ValueError: if there are fewer training examples than ``cfg.batch_size``.
    """
    # Drop any trailing partial batch so every scan step sees a full batch.
    num_batches = x_train.shape[0] // cfg.batch_size
    if num_batches == 0:
        raise ValueError(
            f"need at least batch_size={cfg.batch_size} training examples, "
            f"got {x_train.shape[0]}"
        )
    n_used = num_batches * cfg.batch_size
    X_batched = jnp.reshape(x_train[:n_used], (num_batches, cfg.batch_size, x_train.shape[-1]))
    Y_batched = jnp.reshape(y_train[:n_used], (num_batches, cfg.batch_size))

    # The LR schedule spans exactly the number of optimizer steps taken.
    num_steps = cfg.num_epochs * num_batches
    tx = _make_optimizer(cfg.opt, num_steps)

    state = TrainState.create(apply_fn=apply_fn, params=dict(p=params), tx=tx)

    # Forward and backward pass; the aux output is the batch accuracy.
    loss_and_grad = jax.value_and_grad(compute_loss, has_aux=True)

    def train_step(state: TrainState, batch):
        (loss, acc), grads = loss_and_grad(state.params, state.apply_fn, *batch)
        return state.apply_gradients(grads=grads), (loss, acc)

    def epoch_step(state: TrainState, _) -> tuple[TrainState, tuple[jnp.ndarray, jnp.ndarray]]:
        state, (losses, accs) = jax.lax.scan(train_step, state, (X_batched, Y_batched))
        return state, (losses.mean(), accs.mean())

    # Per-epoch (mean loss, mean acc) metrics are computed but not returned.
    state, _ = jax.lax.scan(epoch_step, state, None, length=cfg.num_epochs)
    raveled, _ = ravel_pytree(state.params)

    # Test loss of the final parameters.
    logits = state.apply_fn(state.params['p'], x_test)
    test_loss = xent(logits, y_test).mean()
    return raveled, test_loss


# Forward-mode Jacobian of `train` w.r.t. its first argument (the initial
# parameters). With has_aux=True, jac_fn(...) returns
# (d final_params / d initial_params, test_loss).
jac_fn = jax.jacfwd(train, argnums=0, has_aux=True)

In [None]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split


# Load data: 8x8 digit images flattened to 64 features, labels 0-9.
X, Y = load_digits(return_X_y=True)
X = X / 16.0  # Normalize (load_digits pixel values lie in 0..16)

# Split data. test_size=197 is presumably chosen so the remaining training set
# (1797 - 197 = 1600 samples) divides evenly into batches of 64 for `train`.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=197, random_state=0)
spectra = []  # NOTE(review): not used in the cells shown here.

# Hidden width: half the input dimension.
d_inner = X.shape[1] // 2

model = MLP(hidden_sizes=(d_inner,), out_features=10)

# Do a single run with seed 0 to get the Jac
key = jax.random.key(0)

# Initialize the parameter pytree and flatten it. raveled_0 is a 1-D vector and
# unraveler maps such a vector back to the pytree, so apply_fn takes flat params.
params_0 = model.init(key, X_train)
raveled_0, unraveler = ravel_pytree(params_0)
apply_fn = make_apply_full(model, unraveler)

# Baseline (unperturbed) training run; `loss` is the returned test loss.
raveled_f, loss = train(raveled_0, X_train, Y_train, X_test, Y_test, apply_fn, TrainConfig(opt="sgd"))
params_f = unraveler(raveled_f)

# Rebinds params_0 from the pytree to its flat vector (same values as raveled_0).
# The next cell relies on params_0 being flat. `unravel` is equivalent to `unraveler`.
params_0, unravel = ravel_pytree(params_0)
# jac[i, j] = d(final param i) / d(initial param j): a dense (P, P) matrix for P
# parameters, from forward-mode differentiation through all of training.
# `losses` is train's aux output, i.e. the (scalar) test loss.
jac, losses = jac_fn(raveled_0, X_train, Y_train, X_test, Y_test, apply_fn, TrainConfig(opt="sgd"))
# jac = jac_u @ diag(jac_s) @ jac_vh, with jac_s in descending order.
jac_u, jac_s, jac_vh = jnp.linalg.svd(jac)

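## Perturbations

This experiment perturbs the initial parameters along one right singular vector of the Jacobian (`jac_vh[num]`) and then retrains. It compares the resulting change in the final parameters with the linear prediction `stim * jac_s[num]` along the matching left singular vector (`jac_u[:, num]`).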

In [None]:
import matplotlib.pyplot as plt

# Index of the singular direction to probe. jnp.linalg.svd sorts singular values
# in descending order, so a large index selects a direction from the bulk of the
# spectrum rather than one of the top directions.
num = 2000
assert num < jac_s.shape[0], f"num={num} is out of range for {jac_s.shape[0]} singular values"
# Perturbation magnitudes ("stimulus sizes"), log-spaced over 1e-3 .. 1e3.
grid = jnp.logspace(-3, 3, 30)
train_cfg = TrainConfig(opt="sgd")


def train_from(init):
    """Train from flat initial parameters `init`; return (flat final params, test loss)."""
    return train(init, X_train, Y_train, X_test, Y_test, apply_fn, train_cfg)


# Right singular vector: the direction in initial-parameter space to perturb along.
v_dir = jac_vh[num]
# Left singular vector: where the Jacobian predicts the final parameters move.
u_dir = jac_u[:, num]

# Unperturbed reference run (raveled_0 is the flat initialization).
params_f, loss = train_from(raveled_0)

complements = []
bulk_y = []
responses = []
for stim in grid:
    perturbed_params, loss = train_from(raveled_0 + stim * v_dir)
    delta = perturbed_params - params_f
    # Signed component of the change along u_dir, and the norm of the remainder.
    res = delta @ u_dir
    complement = jnp.linalg.norm(delta - res * u_dir)

    complements.append(complement)
    bulk_y.append(loss)
    responses.append(res)

fig, ax = plt.subplots(figsize=(5, 4))  # paper-friendly size
# Linear prediction: jac @ (stim * v_dir) = stim * jac_s[num] * u_dir.
ax.plot(grid, grid * jac_s[num], label="Pred. of Jacobian", c="black", linestyle="--")
ax.plot(grid, jnp.abs(jnp.array(responses)), marker="o", label="Proj. on left singular vector", alpha=0.75)
ax.plot(grid, complements, marker="o", label="Proj. on complement", alpha=0.75)
ax.plot(grid, bulk_y, marker="x", label="Test Loss", alpha=0.75)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Stimulus Size")
ax.set_ylabel("Response in Final Model")
ax.set_title("Response to Perturbing Init. Along Bulk Direction")
ax.legend()

# Tight layout so labels are not clipped in the saved PDF.
fig.tight_layout()
fig.savefig('bulk perturb.pdf', format='pdf')
plt.show()