In [1]:
# %% Imports
import jax
from jax import random, numpy as jnp
from flax import optim
from modax.models import Deepmod
from modax.training import create_update
from modax.losses import loss_fn_mse, loss_fn_pinn
from modax.logging import Logger
from modax.data.burgers import burgers
from time import time

%load_ext autoreload
%autoreload 2

In [2]:
# Making dataset
# Sample the Burgers solution on a regular (t, x) grid:
# 100 points in x over [-3, 4] and 20 points in t over [0.5, 5.0].
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)

# indexing="ij" -> both grids have shape (len(t), len(x)) = (20, 100).
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
# NOTE(review): 0.1 and 1.0 are passed positionally to modax's `burgers`;
# presumably viscosity and an amplitude/initial-condition parameter — confirm
# against modax.data.burgers and consider naming them in a config cell.
u = burgers(x_grid, t_grid, 0.1, 1.0)

# Flatten to one sample per row: X_train has columns [t, x] (shape (2000, 2)),
# y_train is the matching single column of u values.
X_train = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y_train = u.ravel()[:, None]

In [3]:
# Instantiating model and optimizers
# Deepmod network with layer widths [50, 50, 1].
model = Deepmod(features=[50, 50, 1])
# Fixed PRNG seed so parameter initialization is reproducible.
key = random.PRNGKey(42)
params = model.init(key, X_train)
# NOTE(review): beta1=0.99 / beta2=0.99 differ from the common Adam defaults
# (0.9 / 0.999); presumably deliberate for this problem — confirm.
# NOTE(review): `flax.optim` is deprecated upstream in favour of optax; keep as
# long as `create_update` expects a flax Optimizer.
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)

In [4]:
# Compiling train step
# Bind the loss, model and full training set once; afterwards a training step
# only needs the current optimizer: `update(optimizer)`.
update = create_update(
    loss_fn_pinn,
    model=model,
    x=X_train,
    y=y_train,
)
# Warm-up call: the first invocation triggers compilation. The returned
# (optimizer, metrics) pair is discarded and `optimizer` is not reassigned,
# so training below still starts from the initial parameters.
_ = update(optimizer)

In [5]:
# Running to convergence
# Epochs 0..10000 inclusive, so the final step is both printed and logged.
max_epochs = 10001
print_every = 1000  # stdout progress interval (epochs)
log_every = 100  # logger write interval (epochs)

logger = Logger()  # created outside the timed region
t_start = time()
try:
    # `range` (not `jnp.arange`) keeps `epoch` a plain Python int: the modulo
    # checks run on the host instead of dispatching a device op (and a sync)
    # every step, and the logger receives an ordinary integer step.
    for epoch in range(max_epochs):
        optimizer, metrics = update(optimizer)
        if epoch % print_every == 0:
            print(f"Loss step {epoch}: {metrics['mse']}")
        if epoch % log_every == 0:
            logger.write(metrics, epoch)
    t_end = time()
finally:
    # Close the logger even if training is interrupted or raises, so its
    # underlying writer is presumably flushed/released — confirm in modax.logging.
    logger.close()
print(t_end - t_start)

# The last two outputs of `model.apply` are unpacked as `theta` and `coeffs`
# (presumably the candidate-term library and the fitted coefficients — confirm
# against Deepmod).
theta, coeffs = model.apply(optimizer.target, X_train)[2:]
# Scale each coefficient by the L2 norm of the matching `theta` column
# (norm over axis 0, transposed to a column); presumably undoes a column
# normalization used during fitting — confirm.
print(coeffs * jnp.linalg.norm(theta, axis=0, keepdims=True).T)

Loss step 0: 0.053290676325559616
Loss step 1000: 2.422910984023474e-05
Loss step 2000: 7.069579623930622e-07
Loss step 3000: 2.1927279192368587e-07
Loss step 4000: 2.7093531684840855e-07
Loss step 5000: 1.5286441623629798e-07
Loss step 6000: 1.0731930188967453e-07
Loss step 7000: 1.97639351995349e-07
Loss step 8000: 9.446227977605304e-08
Loss step 9000: 5.8184117079917996e-08
Loss step 10000: 1.0624194857200564e-07
15.614403247833252
[[ 6.4229132e-03]
 [-2.1348760e-02]
 [ 5.0729265e+00]
 [ 4.2522405e-03]
 [ 6.5230890e-03]
 [-7.2715979e+00]
 [ 1.2733145e-02]
 [-3.8976349e-02]
 [-2.1299455e-02]
 [-3.7962530e-02]
 [-7.3218577e-02]
 [ 9.1348916e-02]]


In [11]:
# Final training-step MSE from the most recent `update` call.
# NOTE(review): the saved output below (8.8e-08, execution count 11) does not
# match the epoch-10000 value printed by the training cell (1.06e-07); it was
# produced from a different kernel state. Restart & Run All to refresh it.
final_mse = metrics["mse"]
final_mse

0

mse 8.819382e-08
