In [1]:
import pathlib
from itertools import islice
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import optax
from jax import vmap
from jaxtyping import Array, Float, PyTree

import diff_ml as dml
import diff_ml.nn as dnn
from diff_ml.model import Bachelier
from diff_ml.nn.utils import init_model_weights, predict
from diff_ml.plotting import plot_eval
from diff_ml.typing import Data



In [2]:
def loss_fn(model, batch: Data) -> Float[Array, ""]:
    """Mean-squared error of ``model`` on a batch.

    The model is vectorised over the leading axis of ``batch["x"]`` with
    ``eqx.filter_vmap`` and its predictions are compared against
    ``batch["y"]`` using ``dml.losses.mse``.
    """
    batched_model = eqx.filter_vmap(model)
    predictions = batched_model(batch["x"])
    return dml.losses.mse(batch["y"], predictions)


def eval_fn(model, batch: Data) -> Float[Array, ""]:
    """Root-mean-squared error: the square root of ``loss_fn`` on the same batch."""
    mse_value = loss_fn(model, batch)
    return jnp.sqrt(mse_value)


def train_generator(xs, n_samples: int, n_batch_size: int, *, key):
    """Yield an endless stream of random minibatches drawn from ``xs``.

    Each iteration splits ``key``, draws ``n_batch_size`` indices from
    ``range(n_samples)`` with ``jrandom.choice`` (default arguments, i.e. with
    replacement), and indexes every leaf of the pytree ``xs`` with those same
    indices, so all leaves stay row-aligned.

    Args:
        xs: Pytree of arrays sharing a leading sample axis of length ``n_samples``.
        n_samples: Number of rows to sample indices from.
        n_batch_size: Number of rows per yielded batch.
        key: JAX PRNG key; it is split on every iteration.

    Yields:
        A pytree with the structure of ``xs`` whose leaves hold the sampled rows.
    """
    while True:
        key, subkey = jrandom.split(key)
        batch_idx = jrandom.choice(key=subkey, a=n_samples, shape=(n_batch_size,))
        yield jtu.tree_map(lambda leaf: leaf[batch_idx], xs)



In [None]:
# Specify model
key = jrandom.key(0)
n_dims: int = 7
n_samples: int = 8 * 1024

key, subkey = jrandom.split(key)
weights = jrandom.uniform(subkey, shape=(n_dims,), minval=1.0, maxval=10.0)

# Give the reference model its own subkey. Passing `key` itself and then
# splitting that same `key` again later (for the batch generator) would reuse
# one PRNG key for two independent purposes and correlate their randomness.
key, subkey = jrandom.split(key)
ref_model = Bachelier(subkey, n_dims, weights)

# Generate data
train_ds = ref_model.sample(n_samples)
test_ds = ref_model.analytic(n_samples)




n_epochs = 100
n_batch_size = 256

# Endless stream of random training minibatches (see train_generator).
key, subkey = jrandom.split(key)
train_gen = train_generator(train_ds, n_samples, n_batch_size, key=subkey)

# TODO move to diff_ml/plotting.py
# Summary statistics of the training set; the surrogate below is wrapped in
# Normalization/Denormalization layers built from them.
# NOTE(review): jnp.mean / jnp.std are called without `axis`, so each statistic
# is one scalar over all n_dims inputs; per-feature statistics would need
# axis=0 — confirm which form dnn.Normalization is meant to receive.
x_train_mean, x_train_std = jnp.mean(train_ds["x"]), jnp.std(train_ds["x"])
y_train_mean, y_train_std = jnp.mean(train_ds["y"]), jnp.std(train_ds["y"])

# Plot data: inputs, targets and target gradients for both splits.
xs_train, ys_train, zs_train = (jnp.asarray(train_ds[field]) for field in ("x", "y", "dydx"))
xs_test, ys_test, zs_test = (jnp.asarray(test_ds[field]) for field in ("x", "y", "dydx"))

baskets = ref_model.baskets(xs_test)





# Specify the surrogate model architecture: a scalar-output MLP with SiLU
# activations, whose weights are then re-initialised with Glorot-normal.
key, subkey = jrandom.split(key)
mlp = eqx.nn.MLP(
    in_size=n_dims,
    out_size="scalar",
    width_size=20,
    depth=3,
    activation=jax.nn.silu,
    key=subkey,
)

key, subkey = jrandom.split(key)
mlp = init_model_weights(mlp, jax.nn.initializers.glorot_normal(), key=subkey)

# Sandwich the MLP between a Normalization layer built from the training-input
# statistics and a Denormalization layer built from the training-target ones.
input_norm = dnn.Normalization(x_train_mean, x_train_std)
output_denorm = dnn.Denormalization(y_train_mean, y_train_std)
surrogate = dnn.Normalized(input_norm, mlp, output_denorm)


# NOTE(review): the recorded output of this cell is a TracerArrayConversionError
# on a traced float32[256, 49] array (256 = n_batch_size, 49 = n_dims ** 2),
# presumably raised from the SECOND_ORDER_PCA sobolev loss while the training
# code below was still active; it stays commented out until that is resolved.
## Train the surrogate using sobolev loss
#optim = optax.adam(learning_rate=1e-3)
#sobolev_loss_fn = dml.losses.sobolev(dml.losses.mse, method=dml.losses.SobolevLossType.SECOND_ORDER_PCA, ref_model=ref_model)
#surrogate, metrics = dml.train(
#    surrogate, sobolev_loss_fn, train_gen, eval_fn, test_ds, optim, n_epochs=n_epochs)
#
#
#
#
## visualize predictions of model trained on second order data
#pred_y, pred_dydx, pred_ddyddx = predict(surrogate, test_ds["x"])
#plot_eval(pred_y, pred_dydx, pred_ddyddx, test_ds)



  0%|          | 0/100 [00:00<?, ?it/s]


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[256,49]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [None]:

def train(
    model,
    loss_fn,
    train_data: DataGenerator,
    eval_fn,
    test_data: Optional[Data],
    optim: optax.GradientTransformation,
    n_epochs: int,
    n_batches_per_epoch: int = 64,
) -> tuple[PyTree, dict]:
    """Canonical training loop.

    Runs ``n_epochs`` epochs of ``n_batches_per_epoch`` optimisation steps each.
    Every step pulls a batch from ``train_data`` and delegates to
    ``train_step``. After each epoch the accumulated per-loss statistics in
    ``loss_state`` are rolled over. When ``test_data`` is given, the model is
    also scored on it with ``evaluate``/``eval_fn``.

    Args:
        model: Equinox model to train; only its array leaves are optimised.
        loss_fn: Training loss, passed through to ``train_step``.
        train_data: Iterator yielding batches indexable by ``"x"``.
        eval_fn: Metric passed to ``evaluate`` for the test set.
        test_data: Optional held-out data; evaluation is skipped when ``None``.
        optim: Optax optimiser.
        n_epochs: Number of epochs to run.
        n_batches_per_epoch: Optimisation steps per epoch; must be >= 1.

    Returns:
        ``(model, metrics)``: the trained model and a dict holding
        ``"train_loss"`` and ``"test_loss"`` arrays of length ``n_epochs``.

    Raises:
        ValueError: If ``n_batches_per_epoch`` is less than 1.
    """
    # NOTE(review): DataGenerator, LossState, train_step, evaluate,
    # metrics_update_element and tqdm are not imported anywhere in this
    # notebook (they presumably live in diff_ml — module paths to confirm).
    # Until they are imported, this cell fails on a fresh kernel.
    if n_batches_per_epoch < 1:
        # With no steps, train_loss stays the shape-(1,) placeholder, which the
        # ":.5f" progress format cannot format. The epoch mean would also
        # divide by an iteration counter that was reset to zero.
        raise ValueError(f"n_batches_per_epoch must be >= 1, got {n_batches_per_epoch}")

    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    train_loss = jnp.zeros(1)
    # Peek one batch to learn the batch size reused for evaluation; this batch
    # is consumed from train_data and not trained on.
    batch_size = len(next(train_data)["x"])
    metrics = {"train_loss": jnp.zeros(n_epochs), "test_loss": jnp.zeros(n_epochs)}
    # Initial state for three loss terms.
    # NOTE(review): the meaning of each positional field is defined by LossState — confirm.
    loss_state = LossState(
        jnp.array([0.0, 0.0, 1.0]),
        jnp.array([1 / 3, 1 / 3, 1 / 3]),
        jnp.array([0.0, 0.0, 1.0]),
        jnp.array([0.0, 0.0, 0.0]),
        jnp.array([0.0, 0.0, 0.0]),
    )

    pbar = tqdm(range(n_epochs))
    for epoch in pbar:
        for batch in islice(train_data, n_batches_per_epoch):
            model, opt_state, train_loss, loss_state = train_step(model, loss_fn, optim, opt_state, batch, loss_state)

        # Store accum_losses / current_iter[0] (presumably the epoch's mean of
        # each loss term) as prev_mean_losses, then zero the accumulators and
        # iteration counters. Order matters: the mean is taken before the reset.
        loss_state = loss_state.update_prev_mean_losses(loss_state.accum_losses / loss_state.current_iter[0])
        loss_state = loss_state.update_accum_losses(jnp.zeros(len(loss_state.losses)))
        loss_state = loss_state.update_current_iter(jnp.zeros(len(loss_state.losses)))

        metrics_update_element(metrics, "train_loss", epoch, train_loss)
        epoch_stats = f"Epoch: {epoch:3d} | Train: {train_loss:.5f}"

        # Explicit None check: test_data is Optional, and truthiness of a data
        # container is not the same as "was test data provided".
        if test_data is not None:
            test_loss = evaluate(model, test_data, batch_size, eval_fn)
            metrics_update_element(metrics, "test_loss", epoch, test_loss)
            epoch_stats += f" | Test: {test_loss:.5f}"

        pbar.set_description(epoch_stats)

    return model, metrics
