In [None]:
import sympy as sp
import jax.numpy as np
import matplotlib.pyplot as plt

from util.interfaces import Config, EqInfo, Hyperparameters, VarInfo
from main import run

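## Burgers' equation

We solve the viscous Burgers' equation on $t \in [0, 1]$, $x \in [-1, 1]$ with viscosity $\nu = 0.01/\pi$: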

$$
\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} - \frac{0.01}{\pi} \frac{\partial^2 u}{\partial x^2} = 0
$$

subject to the initial and boundary conditions

$$
\begin{aligned}
u(0, x) &= -\sin(\pi x), \\
u(t, -1) &= u(t, 1) = 0.
\end{aligned}
$$

In [4]:
# Viscosity coefficient nu in  u_t + u * u_x - nu * u_xx = 0  (0.01 / pi, as in the equation above).
VISCOSITY = 0.01 / sp.pi
# Weight attached to every initial/boundary condition below.
# NOTE(review): presumably scales each condition's contribution to the loss — confirm in `main.run`.
CONDITION_WEIGHT = 2.

config = Config(
  # PDE residual for the unknown `u`; it is zero exactly when Burgers' equation holds.
  # NOTE(review): `s` presumably exposes `u`, the variables, and derivatives
  # (dudt, dudx, d2udx2) as SymPy expressions — confirm against util.interfaces.
  eq = EqInfo(
    name = 'u',
    function = lambda s: s.dudt + s.u * s.dudx - VISCOSITY * s.d2udx2,
  ),
  # Domain: t in [0, 1], x in [-1, 1].
  vars = {
    't': VarInfo(bounds=(0, 1), integrable=False),
    'x': VarInfo(bounds=(-1, 1), integrable=True)
  },
  # (weight, residual) pairs; each residual vanishes when its condition is satisfied.
  conditions = [
    # Initial condition u(0, x) = -sin(pi x), written as u(0, x) + sin(pi x) = 0.
    (CONDITION_WEIGHT, lambda s: s.u.subs(s.t, 0) + sp.sin(sp.pi * s.x)),
    # Left boundary condition u(t, -1) = 0.
    (CONDITION_WEIGHT, lambda s: s.u.subs(s.x, -1)),
    # Right boundary condition u(t, 1) = 0.
    (CONDITION_WEIGHT, lambda s: s.u.subs(s.x, 1)),
  ],
  # Candidate functions of the raw inputs (t, x).
  # NOTE(review): presumably the input primitives of the symbolic model — TODO confirm in `main.run`.
  preoperations = [
    lambda t, x: 0,
    lambda t, x: 1,
    lambda t, x: t,
    lambda t, x: x,
    lambda t, x: -x,
    lambda t, x: sp.exp(x),
    lambda t, x: x * t,
    lambda t, x: sp.exp(x) * t,
  ],
  # Candidate unary operations applied to an intermediate value z.
  # NOTE(review): presumably the building blocks of the symbolic model's cells — TODO confirm.
  operations = [
    lambda z: 0,
    lambda z: 1,
    lambda z: z,
    lambda z: z + 1,
    lambda z: -z,
    # NOTE(review): "+ 0" looks like a no-op for SymPy expressions; confirm nothing relies on it before removing.
    lambda z: sp.exp(z) + 0,
  ],
  # NOTE(review): field semantics (lr = learning rate, penalty, cellcount) are defined
  # by util.interfaces.Hyperparameters — confirm there.
  hyperparameters = Hyperparameters(
    lr = 0.0001,
    penalty = 1,
    cellcount = 5,
  ),
  epochs = 128,
  batchsize = 64,
  verbosity = 1,
)

In [5]:
# Fit the model described by `config`; `run` yields three values, unpacked below.
# NOTE(review): the recorded execution of this cell raised
#   TypeError: Gradient only defined for scalar-output functions.
# The reported output appears to be a one-element list wrapping a batched
# (float32[64]) value instead of a scalar loss — TODO investigate the loss
# reduction / return type inside `main.run` before relying on these results.
run_outputs = run(config)
network, best, loss_histories = run_outputs

17:45:29.580 [INFO] Constructed symbolic model
17:45:30.941 [INFO] Constructed loss equation
17:45:34.360 [INFO] Constructed JAXified model


TypeError: Gradient only defined for scalar-output functions. Output was [Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float32[64])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0].