In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import numpy as np
import sys
sys.path.append("..")
from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer
from typing import List, Tuple, Callable
from functools import partial
import matplotlib.pyplot as plt
import scipy
sys.path.append("../../orient")
import optax

from eql.np_utils import flatten, unflatten

In [2]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [3]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2

#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [4]:
params = e.init({'params':key}, x);

In [5]:
params

{'params': {'layers_0': {'linear_layer': {'kernel': Array([[ 0.6292481 , -0.54229003, -0.4677973 ,  0.70746535,  0.09610943,
             0.4903576 ,  0.06954016,  0.31511047]], dtype=float32),
    'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}},
  'layers_1': {'linear_layer': {'kernel': Array([[ 0.19629148,  0.06706532,  0.00969497,  0.23161124, -0.11875771,
             0.01930422,  0.05764793, -0.0596979 ],
           [ 0.4696742 ,  0.46638718,  0.21155392, -0.8553514 ,  0.6172892 ,
            -0.44233188,  0.22190586, -0.18106432],
           [-0.44605988, -0.20469512, -0.32464647,  0.5807687 ,  0.02217815,
            -0.2559955 ,  0.17005439,  0.71249163],
           [ 0.45161504, -0.10305008, -0.47209048,  0.35264972, -0.3844344 ,
             0.2527192 , -0.5963687 ,  0.310487  ],
           [-0.04032504,  0.10622181,  0.5348187 ,  0.14317314,  0.38765785,
            -0.48200536, -0.28618744,  0.3023454 ],
           [-0.90012056, -0.02397138, -0.5772944 , -0

In [6]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)

In [7]:
def loss(params):
    return mse_fn(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [8]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [9]:
for i in range(10000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

2.2196107
0.002707143
0.0004208714
0.00026540947
0.00011064221
2.0543432e-05
4.4613707e-06
3.6789017e-06
3.5207515e-06
3.363988e-06
3.2019302e-06
3.0358692e-06
2.86721e-06
2.6971832e-06
2.5270917e-06
2.3581754e-06
2.1915253e-06
2.0283308e-06
1.8695702e-06
1.7162504e-06
1.5691925e-06
1.4292357e-06
1.2970082e-06
1.1730391e-06
1.0577944e-06
9.515139e-07
8.543414e-07
7.6624104e-07
6.8713007e-07
6.166335e-07
8.093855e-07
4.9545355e-07
4.5026394e-07
9.476678e-07
3.7751127e-07
3.7427924e-07
3.4744824e-07
3.1015563e-07
1.2761588e-05
2.8459053e-07
2.7043447e-07
4.3855079e-07
2.5587454e-07
2.4743045e-07
2.937638e-07
2.3595061e-07
2.2421938e-07
3.1366207e-07
2.317404e-07
2.1934987e-07
2.0726701e-07
3.055188e-07
2.060244e-07
1.9359871e-07
0.00011240191
2.123392e-07
1.9560291e-07
1.8297541e-07
1.715815e-07
2.2103261e-07
1.6975203e-07
1.5881751e-07
2.3513586e-07
1.6708915e-07
1.5486822e-07
8.557334e-05
1.5615831e-07
1.4272484e-07
1.9601503e-05
1.6594392e-07
1.4705307e-07
1.3505674e-07
1.2428454e-07


In [10]:
params

{'params': {'last': {'bias': Array([0.25736916], dtype=float32),
   'kernel': Array([[ 0.6017721 ],
          [ 0.99713594],
          [-0.40698075],
          [ 0.14944981],
          [ 0.01900841],
          [-0.39706507]], dtype=float32)},
  'layers_0': {'linear_layer': {'bias': Array([-0.04229239,  0.0458124 , -0.28680083, -0.11495356, -0.0570096 ,
           -0.14381441, -0.37563142, -0.20593354], dtype=float32),
    'kernel': Array([[ 0.94337577, -0.86841416, -0.6794713 ,  0.77071905, -0.09809233,
             0.30688858, -0.7023072 , -0.06171607]], dtype=float32)}},
  'layers_1': {'linear_layer': {'bias': Array([-0.0066725 ,  0.00834501,  0.151604  ,  0.21296903,  0.40795207,
           -0.06007641,  0.26944926, -0.01032709], dtype=float32),
    'kernel': Array([[ 0.16826028,  0.05169917,  0.17407492,  0.4332904 ,  0.23796213,
            -0.1557705 ,  0.34648117, -0.04225495],
           [ 0.2069332 ,  0.20985135,  1.0525651 , -1.0614916 ,  0.45836076,
            -0.10832226, 

In [11]:
symb = get_symbolic_expr(params, funs)[0]
print(symb)

-0.397065073251724*(-1.34127986431122*(0.0458123981952667 - 0.868414163589478*x0)*(0.943375766277313*x0 - 0.0422923900187016) + 0.0730874538421631*(-0.0980923250317574*x0 - 0.0570096001029015)*(0.306888580322266*x0 - 0.143814414739609) - 0.277734339237213*sin(0.0617160685360432*x0 + 0.205933541059494) - 0.108322262763977*sin(0.770719051361084*x0 - 0.114953555166721) - 0.155770495533943*cos(0.6794713139534*x0 + 0.28680083155632) - 0.420195698738098*cos(0.702307224273682*x0 + 0.375631421804428) - 0.0600764080882072)*(-0.374362200498581*(0.0458123981952667 - 0.868414163589478*x0)*(0.943375766277313*x0 - 0.0422923900187016) - 0.34632220864296*(-0.0980923250317574*x0 - 0.0570096001029015)*(0.306888580322266*x0 - 0.143814414739609) + 0.612203001976013*sin(0.0617160685360432*x0 + 0.205933541059494) + 0.458360761404037*sin(0.770719051361084*x0 - 0.114953555166721) + 0.237962126731873*cos(0.6794713139534*x0 + 0.28680083155632) + 0.349282592535019*cos(0.702307224273682*x0 + 0.375631421804428) + 