In [None]:
import os, time
from collections.abc import Callable
from typing import TypeVar

import numpy as np
from pyDOE import lhs
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax
import optimistix as optx


In [1]:
np.random.seed(1234)
jax.config.update("jax_enable_x64", True)

dir_path = "./checkpoints"
os.makedirs(dir_path, exist_ok=True)


nu = 0.01 / np.pi
x_min, x_max = -1.0, 1.0
t_min, t_max = 0.0, 1.0

model_name = "burgers_1D.eqx"
MODEL_FILE_NAME = os.path.join(dir_path, model_name)



class SSBFGSTrustRegion(optx.AbstractSSBFGS):
    rtol: float
    atol: float
    norm: Callable = optx.max_norm
    use_inverse: bool = True
    search: optx.AbstractSearch = optx.LinearTrustRegion()
    descent: optx.AbstractDescent = optx.NewtonDescent()
    verbose: frozenset[str] = frozenset()

class BroydenTrustRegion(optx.AbstractSSBroyden):                                                                                 
    rtol: float 
    atol: float 
    norm: Callable = optx.max_norm 
    use_inverse: bool = True
    search: optx.AbstractSearch = optx.LinearTrustRegion()
    descent: optx.AbstractDescent = optx.NewtonDescent()
    verbose: frozenset[str] = frozenset()  



initializer = jax.nn.initializers.glorot_normal()

def trunc_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    out, in_ = weight.shape
    return initializer(key, shape=(out, in_))

def init_linear_weight(model, init_fn, key):
    is_linear = lambda x_loc: isinstance(x_loc, eqx.nn.Linear)

    get_weights = lambda m: [x.weight
                             for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                             if is_linear(x)]
    get_bias = lambda m: [x.bias
                          for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                          if is_linear(x)]

    weights = get_weights(model)
    biases  = get_bias(model)

    new_biases = jax.tree_util.tree_map(lambda p: 0.0 * jnp.abs(p), biases)
    new_weights = [init_fn(w, sk)
                   for w, sk in zip(weights, jax.random.split(key, len(weights)))]

    model = eqx.tree_at(get_weights, model, new_weights)
    model = eqx.tree_at(get_bias, model, new_biases)
    return model


class Burgers1D(eqx.Module):
    layers: list

    def __init__(self, key, units=30, depth=4):
        keys = jax.random.split(key, depth + 2)
        self.layers = [
            eqx.nn.Linear(2, units, key=keys[0]),
            *[eqx.nn.Linear(units, units, key=keys[i+1]) for i in range(depth - 1)],
            eqx.nn.Linear(units, 1, key=keys[-1]),]

    def __call__(self, t, x):
        z = jnp.hstack((t, x)) 
        for layer in self.layers[:-1]:
            z = jax.nn.tanh(layer(z))
        return self.layers[-1](z)  



U0_fn = lambda x: -jnp.sin(jnp.pi * x) 
Ub_fn = lambda t, x: 0.0              


@eqx.filter_jit
def burgers_residual(net, tt, xx):
    u_fn = lambda _t, _x: net(_t, _x)[0]  # scalar

    u_t = jax.grad(u_fn, argnums=0)(tt, xx)
    u_x = jax.grad(u_fn, argnums=1)(tt, xx)
    u_xx = jax.grad(jax.grad(u_fn, argnums=1), argnums=1)(tt, xx)

    u = u_fn(tt, xx)
    return u_t + u * u_x - nu * u_xx


@eqx.filter_jit
def loss_fn(net, w_data, w_phys, Xf, X0, u0, Xb, ub):
    u0_pred = jax.vmap(net, in_axes=(0, 0))(X0[:, 0], X0[:, 1])[:, 0:1]
    loss_ic = jnp.mean((u0_pred - u0) ** 2)

    ub_pred = jax.vmap(net, in_axes=(0, 0))(Xb[:, 0], Xb[:, 1])[:, 0:1]
    loss_bc = jnp.mean((ub_pred - ub) ** 2)

    R = jax.vmap(burgers_residual, in_axes=(None, 0, 0))(net, Xf[:, 0], Xf[:, 1]).reshape(-1, 1)
    loss_pde = jnp.mean(R ** 2)

    loss_data = loss_ic + loss_bc
    return w_data * loss_data + w_phys * loss_pde


def sample_interior(nf, seed=0):
    pts = lhs(2, nf)  # (nf,2)
    t = t_min + (t_max - t_min) * pts[:, 0:1]
    x = x_min + (x_max - x_min) * pts[:, 1:2]
    return jnp.array(np.hstack([t, x]))

def sample_initial(n0, seed=0):
    np.random.seed(seed)
    x = np.random.uniform(x_min, x_max, size=(n0, 1))
    t = np.zeros_like(x)
    X0 = np.hstack([t, x])
    u0 = jax.vmap(U0_fn)(jnp.array(x).reshape(-1)).reshape(-1, 1)
    return jnp.array(X0), u0

def sample_boundary(nb, seed=0):
    np.random.seed(seed)
    t = np.random.uniform(t_min, t_max, size=(nb, 1))
    # half left, half right
    nb_l = nb // 2
    nb_r = nb - nb_l
    xl = x_min * np.ones((nb_l, 1))
    xr = x_max * np.ones((nb_r, 1))
    tl = t[:nb_l]
    tr = t[nb_l:]

    Xbl = np.hstack([tl, xl])
    Xbr = np.hstack([tr, xr])
    Xb = np.vstack([Xbl, Xbr])

    ub = jnp.zeros((nb, 1), dtype=jnp.float64)
    return jnp.array(Xb), ub


def load(filename, model):
    with open(filename, "rb") as f:
        return eqx.tree_deserialise_leaves(f, model)


if __name__ == "__main__":
    Nf = 50_000
    N0 = 2_000
    Nb = 2_000

    lambda_data = 5.0   
    lambda_phys = 1.0

    adam_iters = 10_000
    print_every = 200

    Xf = sample_interior(Nf, seed=123)
    X0, u0 = sample_initial(N0, seed=456)
    Xb, ub = sample_boundary(Nb, seed=789)

    key = jr.PRNGKey(1234)
    key, init_key = jr.split(key)
    pinn = Burgers1D(init_key, units=30, depth=4)
    pinn = init_linear_weight(pinn, trunc_init, init_key)

    lr = 1e-3
    optimizer = optax.radam(lr)
    opt_state = optimizer.init(eqx.filter(pinn, eqx.is_inexact_array))

    @eqx.filter_jit
    def train_step(net, state):
        l, g = eqx.filter_value_and_grad(loss_fn)(net, lambda_data, lambda_phys, Xf, X0, u0, Xb, ub)
        updates, new_state = optimizer.update(g, state, net)
        new_net = eqx.apply_updates(net, updates)
        return new_net, new_state, l

    def ic_rel_l2(net):
        u0_pred = jax.vmap(net, in_axes=(0, 0))(X0[:, 0], X0[:, 1])[:, 0:1]
        return float(jnp.linalg.norm(u0_pred - u0) / (jnp.linalg.norm(u0) + 1e-12))

    print("Training (Adam) started!")
    t0 = time.time()

    last_adam_loss = None
    last_adam_ic_err = None

    for epoch in range(1, adam_iters + 1):
        pinn, opt_state, l = train_step(pinn, opt_state)

        if epoch % print_every == 0:
            last_adam_loss = float(l)
            last_adam_ic_err = ic_rel_l2(pinn)
            print(f"[Adam] step {epoch:6d} | loss {last_adam_loss:.3e} | IC relL2 {last_adam_ic_err:.3e}")

    eqx.tree_serialise_leaves(MODEL_FILE_NAME, pinn)
    t1 = time.time()
    print(f"Elapsed Time for Adam: {t1 - t0:.2f} sec")
    
    params, static = eqx.partition(pinn, eqx.is_inexact_array)

    def loss_newton(dynamic_model, static_model):
        model = eqx.combine(dynamic_model, static_model)
        return loss_fn(model, lambda_data, lambda_phys, Xf, X0, u0, Xb, ub)

    solver = SSBFGSTrustRegion(rtol=1e-16, atol=1e-16, verbose=["loss"])
    solver = optx.BestSoFarMinimiser(solver)

    print("Training (SSBFGS) started!")
    t2 = time.time()
    sol = optx.minimise(loss_newton, solver, params, args=static, max_steps=100_000, throw=False)
    t3 = time.time()

    pinn = eqx.combine(sol.value, static)
    eqx.tree_serialise_leaves(MODEL_FILE_NAME, pinn)

    final_loss = float(loss_fn(pinn, lambda_data, lambda_phys, Xf, X0, u0, Xb, ub))
    final_ic_err = ic_rel_l2(pinn)
    total_time = (t1 - t0) + (t3 - t2)


PINN Simulation for 1D Burgers: nu=0.003183098861837907, x∈[-1.0,1.0], t∈[0.0,1.0]


AttributeError: module 'optimistix' has no attribute 'AbstractSSBFGS'