In [12]:
import sympy as sp
import jax.numpy as np
import matplotlib.pyplot as plt

from util.interfaces import Config, EqInfo, Hyperparameters, VarInfo
from main import run

$ \displaystyle
\frac{d^2y}{dx^2} - 3\frac{dy}{dx} + 2 y = 0
$

$ \displaystyle
y(0) = c (e-1)
\\ \displaystyle
y(1) = 0
$

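Analytical solution (the characteristic equation $r^2 - 3r + 2 = 0$ has roots $r = 1, 2$, so $y = A e^x + B e^{2x}$; the boundary conditions give $A = c e$, $B = -c$):
$ \displaystyle
y = c e^x (e - e^{x})
$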

In [17]:
# Problem definition for the symbolic solver: the ODE residual, the domains of
# the input variables, the boundary conditions, and the candidate primitive
# operations the network may combine.
config = Config(
  eq = EqInfo(
    name = 'y',
    # Residual of y'' - 3 y' + 2 y = 0; an exact solution makes this 0.
    function = lambda s: s.d2ydx2 - 3 * s.dydx + 2 * s.y,
  ),
  vars = {
    # NOTE(review): the boundary conditions sit at x = 0 and x = 1 and later
    # cells evaluate on [0, 1], but x is bounded to (0, 3) here — confirm intended.
    'x': VarInfo(bounds=(0, 3), integrable=False),
    # c is the free parameter of the solution family y = c e^x (e - e^x);
    # presumably `integrable=True` means the loss is also integrated over c — confirm.
    'c': VarInfo(bounds=(-20, 20), integrable=True)
  },
  # (weight?, residual) pairs — presumably the float weights each boundary
  # condition's residual in the loss; confirm against `Config`.
  conditions = [
    # y(0) = c (e - 1)
    (4., lambda s: s.y.subs(s.x, 0) - s.c * (sp.exp(1) - 1)),
    # y(1) = 0
    (4., lambda s: s.y.subs(s.x, 1)),
  ],
  # Candidate expressions built directly from the inputs (x, c).
  preoperations = [
    lambda x, c: 0,
    lambda x, c: 1,
    lambda x, c: c,
    lambda x, c: x,
    lambda x, c: -x,
    lambda x, c: sp.exp(x),
    lambda x, c: x * c,
    lambda x, c: sp.exp(x) * c,
  ],
  # Candidate unary operations applied to an intermediate value z.
  operations = [
    lambda z: 0,
    lambda z: 1,
    lambda z: z,
    lambda z: z + 1,
    lambda z: -z,
    lambda z: sp.exp(z) + 0,
  ],
  hyperparameters = Hyperparameters(
    lr = 0.0001,
    penalty = 1,
    cellcount = 5,
  ),
  epochs = 128,
  batchsize = 64,
  verbosity = 1,
)

In [18]:
# Train on the configuration above. `run` yields three values: the network
# object (its `symbols` are used later), the best candidate (exposing
# `model`, `alphas` and `W`; its loss is logged as "Best loss"), and the
# recorded loss histories.
(network,
 best,
 loss_histories) = run(config)

14:41:38.300 [INFO] Constructed symbolic model
14:41:39.927 [INFO] Constructed loss equation
14:41:42.249 [INFO] Constructed JAXified model
14:41:46.672 [INFO] Epoch: 1, Loss: 1591.5024414062
14:41:46.986 [INFO] Epoch: 10, Loss: 1562.7587890625
14:41:47.327 [INFO] Epoch: 20, Loss: 1529.9873046875
14:41:47.668 [INFO] Epoch: 30, Loss: 1495.3302001953
14:41:48.180 [INFO] Epoch: 40, Loss: 1459.3657226562
14:41:48.360 [INFO] Epoch: 50, Loss: 1422.1982421875
14:41:48.715 [INFO] Epoch: 60, Loss: 1388.0090332031
14:41:49.570 [INFO] Epoch: 70, Loss: 1358.6391601562
14:41:49.389 [INFO] Epoch: 80, Loss: 1348.7102050781
14:41:49.720 [INFO] Epoch: 90, Loss: 1343.7312011719
14:41:50.530 [INFO] Epoch: 100, Loss: 2514.5395507812
14:41:50.402 [INFO] Epoch: 110, Loss: 1597.4221191406
14:41:50.738 [INFO] Epoch: 120, Loss: 1282.8747558594
14:41:51.110 [INFO] Epoch: 128, Loss: 4557.3916015625
14:41:51.120 [INFO] Pruning weights...
14:41:51.160 [INFO] Shed -0.02265925146639347 weight
14:41:51.137 [INFO] Con

Nothing more to prune!


$\displaystyle 0.353709277686205 c - 0.421936243772507$

Best loss: 691.0830078125


In [19]:
# Replace each placeholder alphas[i] in the best symbolic model with its
# trained weight W[i] (in order), yielding a concrete closed-form expression,
# then print its LaTeX source.
y_prediction_best = best.model.subs(
    [(alpha, weight) for alpha, weight in zip(best.alphas, best.W)]
)
print(sp.latex(y_prediction_best))

0.353709277686205 c - 0.421936243772507


In [20]:
# Compare the learned model with the analytical solution
#     y = c e^x (e - e^x)
# at one representative value of the free parameter c (the config trains over
# c in (-20, 20)). Both curves are exposed as functions of x only, so the
# plotting cell can evaluate them on a numeric grid.
C_VALUE = 1.0
x = network.symbols.x
# NOTE(review): assumes `network.symbols` exposes c the same way it exposes x
# (the config lambdas use `s.c`) — confirm.
c_sym = network.symbols.c

y_pred_at_c = y_prediction_best.subs(c_sym, C_VALUE)
y_real_at_c = (c_sym * sp.exp(x) * (sp.E - sp.exp(x))).subs(c_sym, C_VALUE)

_y_pred_raw_fn = sp.lambdify([x], y_pred_at_c)
_y_real_raw_fn = sp.lambdify([x], y_real_at_c)


def y_pred_fn(x_values):
    """Evaluate the learned model at `x_values`, with c = C_VALUE as set above.

    A model that does not depend on x lambdifies to a plain scalar; adding a
    zero array broadcasts it to the input's shape so it can be plotted.
    """
    return _y_pred_raw_fn(x_values) + np.zeros_like(x_values)


def y_real_fn(x_values):
    """Evaluate the analytical solution at `x_values`, with c = C_VALUE as set above."""
    return _y_real_raw_fn(x_values) + np.zeros_like(x_values)

In [None]:
# Plot the learned model against the analytical solution on [0, 1], the
# interval spanned by the boundary conditions. A separate name (`x_grid`) is
# used so the sympy symbol `x` from the previous cell is not clobbered.
x_grid = np.linspace(0, 1, 50)

fig, ax = plt.subplots()
ax.plot(x_grid, y_pred_fn(x_grid), label='prediction')
ax.plot(x_grid, y_real_fn(x_grid), label='actual')
ax.set(xlabel='x', ylabel='y', title='Best model vs. analytical solution')
ax.legend()
plt.show()

In [None]:
# Sanity check: integrate the squared residual of the configured ODE,
#     (y'' - 3 y' + 2 y)^2,
# for the best model over x in [0, 1]. An exact solution gives 0; the result
# is still an expression in c. The network's own x symbol is used (not the
# string 'x') so derivatives/integration act on the symbol the model contains.
x_sym = network.symbols.x
y = y_prediction_best
dydx = sp.diff(y, x_sym)
d2ydx2 = sp.diff(y, x_sym, 2)

loss = (d2ydx2 - 3 * dydx + 2 * y)**2

sp.integrate(loss, (x_sym, 0, 1))