Given
$ \displaystyle
    \frac{d^2 y}{d x^2} - \frac{c_1 y}{c_2 + y} = 0
$
find $y$

<!-- $ \displaystyle
\mathcal{L} = \int_{C} \left| \frac{d^2 y}{d x^2} - \frac{c_1 y}{c_2 + y} \right| \; dC
$ -->

$ \displaystyle
\mathcal{L} = \int_{C} \left( \frac{d^2 y}{d x^2} - \frac{c_1 y}{c_2 + y} \right)^2 \; dC
$


$C$ – $c_1$, $c_2$ space


\begin{cases}
    c_1 = \sigma^2 \quad \text{(normalizuotas)}\\
    c_2 = 1
\end{cases}



# DARTS

$ \displaystyle
\overline{o}^{(i, j)}(x) = \sum_{o \in \mathcal{O}} \left( \alpha_o^{(i,j)} o(x) \right)
\\ \displaystyle
\forall (i, j) \; R = (\sum_{o \in \mathcal{O}} \alpha_o^{(i, j)} - 1)^2
$

# Code

In [1]:
import sympy as sp
from jax import random
import jax.numpy as np

# !rm /etc/localtime
# !ln -s /usr/share/zoneinfo/Europe/Vilnius /etc/localtime

from util.plot import Plotting
from util.print import a, d, pad, info
from util.dotdict import DotDict
from network import Network
from train import train

In [2]:
# def actual_func(z, c_1, c_2=1, c_3=0):
#     return c_1 * z - c_2 * c_1 * np.log(c_2 + z) + c_3

In [3]:
c1 = sp.symbols('c_1')
x_bounds = (0, 4)
c1_bounds = (1.4, 1.6)

In [4]:
loss_model_func = lambda model_y, x, model_d2y: sp.Pow(model_d2y - c1 * (model_y) / (1 + model_y), 2, evaluate=False)
loss_integration_func = lambda loss_model: (loss_model, (c1, *c1_bounds))

In [11]:
cell_count = 5
network = Network(
  loss_model_func,
  loss_integration_func,
  [
    lambda z: 0,
    lambda z: 1,
    lambda z: z,
    lambda z: -z,
    lambda z: z*z,
    lambda z: z*z*z*z,
    lambda x: sp.sin(x),
    lambda z: sp.exp(z),
    # lambda x: sp.Max(x, np.array(0)),
    lambda x: 1 / (1 + sp.exp(-x))
  ],
  cell_count,
  x_bounds)

_, model_y, loss_and_grad, _ = network.get_model()

14:48:15.394 [INFO] Constructed symbolic model
14:48:15.443 [INFO] Integrated
14:48:34.280 [INFO] Substituted y's with replacements
14:48:47.139 [INFO] Lambdified
14:48:47.140 [INFO] Constructed JAXified model


In [12]:
# model_y_numeric = model_y.subs(list(zip(network.alphas, [0.01]*19)))
# loss_model = loss_model_func(model_y_numeric, network.x)
# loss_integrated = sp.integrate(*loss_integration_func(loss_model))
# loss_integrated += network.penalties
# loss_func = network.lambdify_no_alphas(model_y_numeric)

In [None]:
# plotting = Plotting(actual_func, network, x_bounds, c1_bounds)
plotting = Plotting(None, network, x_bounds, c1_bounds)

key = random.PRNGKey(7)

key, subkey = random.split(key)
W = random.uniform(subkey, shape=(len(network.alphas),), minval=0, maxval=0.001)

is_final = False

best = DotDict({"loss": np.inf})

while not is_final:
    plotting.funcs = []
    train_results = train(
        network,
        dataset = [(0, -0.6081976)],
        plotting = plotting,
        key = key,
        lr = 0.00001,
        lr_2 = 0,
        epochs = 20,
        verbose = 0,
        batch_size = 16,
        W_init = W,
        best = best
    )

    W = train_results.W
    loss_history = train_results.loss_history

    info('Pruning weights...')
    network.assign_weights(W)
    W, model_y, loss_and_grad, is_final = network.prune_auto()

In [14]:
best

{'loss': DeviceArray(0.00680969, dtype=float32),
 'model_y': -a_o2__01__04*a_o3__00__01*x + a_o3__00__02*a_o3__02__04*x - a_o3__03__04*(a_o3__00__01*a_o3__01__03*x + a_o3__00__02*a_o3__02__03*x - a_o3__00__03*x) + a_o4__00__04*x**2 + b,
 'alphas': [a_o3__00__01,
  a_o3__00__02,
  a_o3__00__03,
  a_o4__00__04,
  a_o0__01__02,
  a_o3__01__03,
  a_o2__01__04,
  a_o3__02__03,
  a_o3__02__04,
  a_o3__03__04,
  b],
 'W': DeviceArray([ 0.02642444,  0.04561374,  0.06300088,  0.07093308,
               0.07661124,  0.08353736,  0.09056003,  0.09431309,
               0.09922865,  0.10418268, -0.00886083], dtype=float32)}

In [15]:
best.model_y

-a_o2__01__04*a_o3__00__01*x + a_o3__00__02*a_o3__02__04*x - a_o3__03__04*(a_o3__00__01*a_o3__01__03*x + a_o3__00__02*a_o3__02__03*x - a_o3__00__03*x) + a_o4__00__04*x**2 + b

In [17]:
best.model_y.subs(zip(best.alphas, best.W))

0.0709330812096596*x**2 + 0.00801862575846662*x - 0.00886083394289017