In [20]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import numpy as np
import sys
sys.path.append("..")
from eql.eqlearner import EQL
#import custom_functions
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer
from typing import List, Tuple, Callable
from functools import partial
import matplotlib.pyplot as plt

In [4]:
import scipy
sys.path.append("../../orient")
from np_utils import flatten, unflatten

In [23]:
#funs = ['sin', 'cos', 'id', 'mul']
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1, use_l0=False)

In [24]:
key = random.PRNGKey(0)
key, k1, k2 = random.split(key, 3)

In [62]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

def make_l0_func():
    def l0(params, key):
        return e.apply(params, rngs={'l0':key}, method=e.l0_reg)
    return jax.jit(l0)


def my_mse(params):
    pred = e.apply(params, x)
    #pred = e.apply(params, x, rngs={'l0': key})
    return jnp.mean((pred-y)**2)

def my_mse_det(params):
    #pred = e.apply(params, x)
    pred = e.apply(params, x, deterministic=True)
    return jnp.mean((pred-y)**2)

mse_fn = jax.jit(my_mse)#make_mse_func(x, y)
l0_fn = make_l0_func()


#params = e.init({'params':k1, 'l0': k2}, x);
params = e.init({'params':k1}, x);

In [50]:
#mse_det_fn = jax.jit(my_mse_det)
det_loss_grad_fn = jax.jit(jax.value_and_grad(my_mse_det))

In [51]:
def loss(params):
    return mse_fn(params)# + 1e-3*l0_fn(params, key)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [52]:
spec, flat = flatten(params)

In [53]:
def np_fn(params, key):
    params = unflatten(spec, params)
    
    key = jax.random.fold_in(key, np_fn.counter)
    loss, grad = loss_grad_fn(params, key)
    _, grad = flatten(grad)

    return loss, np.array(grad)

In [54]:
def np_fn_det(params):
    params = unflatten(spec, params)
    
    loss, grad = loss_grad_fn(params)
    _, grad = flatten(grad)
    
    print(loss)

    return loss, np.array(grad)

In [55]:
# key, _ = random.split(key)
# x0, _, info = scipy.optimize.fmin_l_bfgs_b(
#         np_fn,
#         args=[key],
#         x0 = np.array(flat),
#         maxfun=2000,
#         factr=1.,
#         m=500,
#         pgtol=1e-20,
#         maxls = 1000)
# print(mse_fn(unflatten(spec, x0), key))
# flat = x0

In [56]:
#symb = get_symbolic_expr(unflatten(spec, x0), funs)[0]
#symb = get_symbolic_expr(unflatten(spec,flat), funs)[0]
#outs = sy.lambdify("x0", symb)(x[:,0])#, x[:,1], x[:,2])
#outs = sy.lambdify("x0, x1, x2", symb)(x[:,0])#, x[:,1], x[:,2])

In [57]:
#plt.scatter(x,y)
#plt.scatter(x,e.apply(unflatten(spec, x0), x),alpha=.3,marker='.')
#plt.scatter(x,outs)

In [30]:
import optax
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [32]:
for i in range(1000):
    key, k1 = random.split(key)
    loss_val, grads = loss_grad_fn(params, key)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)
        print(mse_fn(params, key))
        #print(l0_fn(params, key))

4.4585567
3.9813523
0.0044944966
0.0043987934
0.0016995181
0.0016875749
0.00087663805
0.0008708908
0.0004579496
0.00045497963
0.00024350798
0.00024199794
0.00013436153
0.00013359007
7.8335455e-05
7.793713e-05
4.9279955e-05
4.9072387e-05
3.406455e-05
3.3955046e-05
2.5964495e-05
2.5905209e-05


In [34]:
symb = get_symbolic_expr(params, funs, use_l0=False)[0]

In [35]:
symb

-0.761843681335449*(-0.0769314616918564*(-0.792992651462555*x0 - 0.0266538988798857)*(1.20189297199249*x0 + 0.0326758101582527) - 0.483024954795837*(0.409473866224289*x0 - 0.0535124987363815)*(0.458635061979294*x0 - 0.0637139305472374) - 0.539164900779724*sin(0.588991343975067*x0 + 0.218530684709549) - 0.217753425240517*sin(0.757192313671112*x0 - 0.401704877614975) + 0.519339263439178*cos(0.944219052791595*x0 + 0.119548119604588) + 0.486043214797974*cos(0.961868822574615*x0 + 0.0442788749933243) - 0.00856994744390249)*(0.0462779328227043*(-0.792992651462555*x0 - 0.0266538988798857)*(1.20189297199249*x0 + 0.0326758101582527) - 0.384604781866074*(0.409473866224289*x0 - 0.0535124987363815)*(0.458635061979294*x0 - 0.0637139305472374) - 0.438745021820068*sin(0.588991343975067*x0 + 0.218530684709549) + 0.177688211202621*sin(0.757192313671112*x0 - 0.401704877614975) - 0.224421381950378*cos(0.944219052791595*x0 + 0.119548119604588) - 0.0823211073875427*cos(0.961868822574615*x0 + 0.044278874993

In [58]:
_, flat_params = flatten(params)
x0, _, info = scipy.optimize.fmin_l_bfgs_b(
        np_fn_det,
        x0 = np.array(flat_params),
        maxfun=500,
        factr=1.,
        m=500,
        pgtol=1e-13,
        maxls=20)

4.571755
0.28928602
0.17935148
0.16201106
0.106317736
0.05408343
4.893029
0.009474921
0.002595358
0.0021586067
0.0017636507
0.0014862885
0.00074124924
0.0005076983
0.0004125147
0.0003547914
0.00024908505
0.00010206136
4.0977226e-05
3.685751e-05
3.6697493e-05
3.6540823e-05
3.5887697e-05
3.4599278e-05
3.1277596e-05
2.4964751e-05
1.5799695e-05
8.033768e-06
6.959201e-06
6.807201e-06
6.778847e-06
6.7617343e-06
6.7180677e-06
6.653607e-06
6.441908e-06
5.9650433e-06
5.0079457e-06
4.1077374e-06
3.0776025e-06
2.1784824e-06
1.5216764e-06
4.1482367e-06
9.520898e-07
4.371858e-06
7.6112735e-07
6.760302e-07
5.803523e-07
3.942491e-07
2.738113e-07
2.3723774e-07
2.0688191e-07
1.8814626e-07
1.745227e-07
1.655817e-07
1.4804932e-07
1.296257e-07
9.8683515e-08
8.000273e-08
7.6959e-08
5.7598207e-08
6.0734514e-08
3.4588034e-08
2.9376466e-08
2.808023e-08
2.7523331e-08
2.6939361e-08
2.4157789e-08
2.0083135e-08
1.4308092e-08
3.5398024e-08
1.2945061e-08
1.0030074e-08
9.288619e-09
8.580911e-09
7.851426e-09
6.846807

In [59]:
get_symbolic_expr(unflatten(spec, x0), funs, use_l0=False)[0]

-0.53101785473171*(-0.0131915958657225*(-1.12383126429409*x0 - 0.0098654003411941)*(-1.01451619670063*x0 - 0.00874413515476435) + 0.480637305208576*(0.446504662344*x0 - 0.0482387813739118)*(1.50413786850424*x0 + 0.0523266252083551) - 0.290916413625701*sin(1.58944357756049*x0 - 0.36410563416881) + 0.335708930048034*sin(1.68646077541525*x0 + 0.145793885492396) + 0.664926168829447*cos(0.344557237147682*x0 + 0.0122060841696888) + 0.606550041418038*cos(0.959284263573444*x0 - 0.024713889252506) + 0.251296668103114)*(0.32160364113499*(-1.12383126429409*x0 - 0.0098654003411941)*(-1.01451619670063*x0 - 0.00874413515476435) + 0.6501229539*(0.446504662344*x0 - 0.0482387813739118)*(1.50413786850424*x0 + 0.0523266252083551) + 0.508392665637239*sin(1.58944357756049*x0 - 0.36410563416881) - 0.209937611520086*sin(1.68646077541525*x0 + 0.145793885492396) + 0.21143026446638*cos(0.344557237147682*x0 + 0.0122060841696888) - 0.655871131019627*cos(0.959284263573444*x0 - 0.024713889252506) - 0.19407826534349

In [139]:
jit_apply = jax.jit(e.apply, static_argnames='deterministic')

In [13]:
jit_part = partial(jit_apply, params)
#jax.grad(jit_part)(x)

In [9]:
e.apply(params, rngs={'l0': k1}, method=e.l0_reg)

DeviceArray(508.99625, dtype=float32)

In [12]:
k1, _ = random.split(k1)
e.apply(params, x, deterministic=False, rngs={'l0': k1})

DeviceArray([[ 8.58321381e+00,  1.51409264e+01, -2.02603173e+00,
               2.00298309e+01,  1.47476797e+01, -4.98244190e+00,
               2.11784592e+01, -1.21116362e+01, -1.42601099e+01,
              -4.60530281e+00],
             [ 5.63098877e+02,  2.89830261e+02,  6.18977844e+02,
               4.19883545e+02,  2.96393799e+02, -3.36902313e+02,
              -1.66018234e+02, -5.47439697e+02, -7.59330521e+01,
              -1.87522095e+02],
             [ 1.41679535e+01,  2.39808059e+00,  7.43775249e-01,
               1.33280349e+00,  1.10605164e+01,  1.10182362e+01,
               1.54590921e+01, -3.53029418e+00,  8.96689796e+00,
               1.93303025e+00],
             [ 7.56713000e+05,  6.11489188e+05, -5.11581875e+05,
              -7.13832312e+05, -4.38223906e+05, -5.07743945e+04,
               2.48641734e+05,  7.31543625e+05,  2.68542062e+05,
              -3.95877188e+05],
             [ 4.73204404e-01, -9.08424735e-01, -7.54241526e-01,
              -2.32303977e+

In [16]:
k1, _ = random.split(k1)
jit_apply(params, x, deterministic=False, rngs={'l0': k1})

DeviceArray([[-3.78209448e+00,  4.46613979e+00, -5.88704634e+00,
               1.94581795e+00,  1.25876462e+00,  5.66233540e+00,
              -5.79753447e+00, -3.79149890e+00,  7.15724325e+00,
               7.17711067e+00],
             [ 3.95358396e+00,  6.34684658e+00, -2.64497590e+00,
               3.13939035e-01, -1.19828969e-01,  1.12981653e+00,
              -2.60230207e+00,  1.02555215e+00, -2.70521045e+00,
              -4.04525566e+00],
             [-4.18014479e+00, -3.11423206e+00, -3.79945159e+00,
              -4.71738863e+00,  2.46585870e+00, -5.51736546e+00,
              -4.07708788e+00,  1.80600092e-01, -1.94446886e+00,
               9.08502400e-01],
             [ 3.12757816e+01, -8.57553899e-01,  1.08915466e+02,
               1.92732449e+01, -2.73622179e+00, -1.32710707e+00,
               3.48873405e+01, -1.96826286e+01, -4.13103142e+01,
               5.81769943e-01],
             [-1.20899057e+00,  5.42552710e-01, -7.33462906e+00,
              -3.19516850e+

In [31]:
#symb = get_symbolic_expr(unflatten(spec, x0), funs)
sy.lambdify("x0, x1, x2", symb[0])(x[:,0], x[:,1], x[:,2])

DeviceArray([ 0.57243985,  0.4264363 ,  0.20558922, ...,  0.4450177 ,
             -0.22650672,  0.41479236], dtype=float32)

In [8]:
sy.lambdify(sy.symbols("x0, x1"), symb[1], "numpy")(x[:,0],x[:,1])

DeviceArray([ 0.29320884, 20.998196  ,  3.2638202 ], dtype=float32)