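In this notebook we create a test dataset for our implementation and save it to disk. We fit a DeepMoD model to noisy Burgers data, then save two of the trained network's outputs as a regression pair (`y = dt`, `X = theta`). Later tests can then always use the same data without retraining the network each time.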

In [None]:
# %% Imports
import jax
from jax import random, numpy as jnp
from flax import optim
from modax.models import Deepmod
from modax.training import create_update
from modax.losses import loss_fn_pinn
from modax.logging import Logger

from sklearn.linear_model import BayesianRidge
from jax.scipy.stats import gamma
from modax.data.burgers import burgers
from time import time

from functools import partial
from jax import lax

In [None]:
# Single root key for the notebook; it is split once so that the data noise
# (key_data) and the network initialisation (key_network) use different keys.
key = random.PRNGKey(seed=42)
key_data, key_network = random.split(key, num=2)

In [None]:
# Making dataset: noisy samples of the `burgers` solution on a regular (t, x) grid.
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)

# indexing="ij" makes both grids shape (len(t), len(x)) == (20, 50).
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
# NOTE(review): the meaning of 0.1 / 1.0 is defined by modax.data.burgers
# (presumably viscosity and amplitude) — confirm against that module.
u = burgers(x_grid, t_grid, 0.1, 1.0)

# One (t, x) row per grid point: shape (1000, 2).
X_train = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y_train = u.reshape(-1, 1)
# Gaussian noise scaled to 5% of the standard deviation of the clean targets
# (the std is evaluated before the noise is added).
y_train = y_train + 0.05 * jnp.std(y_train) * jax.random.normal(key_data, y_train.shape)

In [10]:
# Instantiating model and optimizer.
# NOTE(review): `flax.optim` is the legacy Flax optimizer API (superseded by
# optax); this cell relies on its `Optimizer.create` / `.target` interface.
model = Deepmod(features=[50, 50, 1])
params = model.init(key_network, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)

# Compiling train step: this first call only triggers compilation. Its result
# is discarded, so training below still starts from the freshly created optimizer.
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

In [11]:
# Running to convergence.
# Iterate over a plain Python `range`, not `jnp.arange`. Iterating a JAX array
# yields device scalars, so every `epoch % ...` check would dispatch a device
# op and force a host sync inside the hot loop. The logger would also receive
# a JAX array instead of an int step.
max_epochs = 10001  # so that epoch 10000 is printed and logged
logger = Logger()
try:
    for epoch in range(max_epochs):
        optimizer, metrics = update(optimizer)
        if epoch % 1000 == 0:
            print(f"Loss step {epoch}: {metrics['loss']}")
        if epoch % 100 == 0:
            logger.write(metrics, epoch)
finally:
    # Close the logger even if training is interrupted (e.g. a kernel interrupt).
    logger.close()

Loss step 0: 0.1971043348312378
Loss step 1000: 0.00030297692865133286
Loss step 2000: 0.00017942434351425618
Loss step 3000: 0.00017744959041010588
Loss step 4000: 0.00017721494077704847
Loss step 5000: 0.00017684497288428247
Loss step 6000: 0.00017650557856541127
Loss step 7000: 0.00017631708760745823
Loss step 8000: 0.00017639024008531123
Loss step 9000: 0.0001760099403327331
Loss step 10000: 0.00017583364387974143


In [12]:
# Forward pass with the trained parameters (legacy flax.optim `.target`);
# only `dt` and `theta` are saved afterwards.
(prediction, dt, theta, coeffs) = model.apply(
    optimizer.target,
    X_train,
)

In [13]:
# Save the regression pair (y = dt, X = theta) for later tests.
# jax.device_get turns the JAX arrays into plain NumPy arrays first, so the
# pickled dict does not embed JAX array objects and can be read back without JAX.
# A dict is stored as a 0-d object array, so load it with:
#   np.load('test_data.npy', allow_pickle=True).item()
jnp.save('test_data.npy', jax.device_get({'y': dt, 'X': theta}))