In [12]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import numpy as np
import sys
sys.path.append("..")
from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer
from typing import List, Tuple, Callable
from functools import partial
import matplotlib.pyplot as plt
import scipy
sys.path.append("../../orient")
import optax

from np_utils import flatten, unflatten

In [13]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [14]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2

#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [15]:
params = e.init({'params':key}, x);

In [16]:
params

FrozenDict({
    params: {
        layers_0: {
            linear_layer: {
                kernel: DeviceArray([[-7.2592092e-01,  1.4928404e-03, -1.1177226e+00,
                              -1.5228953e+00, -8.1515509e-01, -2.2600982e+00,
                               2.4191648e-01, -1.8337251e+00]], dtype=float32),
                bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            },
        },
        layers_1: {
            linear_layer: {
                kernel: DeviceArray([[ 0.02277731,  0.17892736, -0.51479936, -0.6999076 ,
                              -0.43175164, -0.38907874,  0.21593012, -0.31170368],
                             [ 0.34880134, -0.19892956, -0.57989043, -0.01286852,
                              -0.1202417 , -0.08883591,  0.30248502, -0.74179024],
                             [ 0.4252634 ,  0.31946254, -0.28948838,  0.7172954 ,
                               0.3699872 ,  0.30644968, -0.339205  ,  0.31891167],
                    

In [17]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)

In [18]:
def loss(params):
    return mse_fn(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [19]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [21]:
for i in range(10000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

1.0019127e-05
8.54314e-06
7.2747416e-06
6.210226e-06
5.3362523e-06
4.631923e-06
4.070804e-06
3.6245508e-06
3.2655016e-06
2.9697164e-06
2.7179026e-06
2.496307e-06
2.2955846e-06
2.1100368e-06
1.936609e-06
1.7737342e-06
1.620818e-06
1.4775273e-06
1.3439228e-06
1.2199718e-06
1.1056237e-06
1.0008273e-06
9.053834e-07
8.189853e-07
7.41341e-07
6.720133e-07
6.105197e-07
5.5630824e-07
5.088538e-07
4.6747743e-07
4.3158423e-07


KeyboardInterrupt: 

In [22]:
params

FrozenDict({
    params: {
        last: {
            bias: DeviceArray([0.18907651], dtype=float32),
            kernel: DeviceArray([[ 0.3093719 ],
                         [-0.20909262],
                         [ 0.2651188 ],
                         [ 0.49299037],
                         [ 0.588305  ],
                         [ 0.3081454 ]], dtype=float32),
        },
        layers_0: {
            linear_layer: {
                bias: DeviceArray([ 0.12463151,  0.05000146, -0.15119174, -0.30613348,
                              0.09569556,  0.11766236, -0.03175413, -0.29099533],            dtype=float32),
                kernel: DeviceArray([[-0.6132989,  0.0657657, -1.1314934, -1.2790749, -0.8241023,
                              -2.2584727,  0.3420058, -1.4921745]], dtype=float32),
            },
        },
        layers_1: {
            linear_layer: {
                bias: DeviceArray([ 0.13890126,  0.2596124 ,  0.21440063, -0.21265374,
                              0.01

In [23]:
symb = get_symbolic_expr(params, funs)[0]
print(symb)

0.588304996490479*(-0.44148987531662*(0.0956955552101135 - 0.824102282524109*x0)*(0.117662355303764 - 2.25847268104553*x0) + 0.641749799251556*(0.124631509184837 - 0.613298892974854*x0)*(0.0657657012343407*x0 + 0.0500014573335648) - 0.146095961332321*sin(1.27907490730286*x0 + 0.306133478879929) - 0.2213164716959*sin(1.49217450618744*x0 + 0.290995329618454) + 0.564487814903259*cos(0.342005789279938*x0 - 0.03175413236022) + 0.233935862779617*cos(1.13149344921112*x0 + 0.151191741228104) + 0.138901263475418)*(0.377752512693405*(0.0956955552101135 - 0.824102282524109*x0)*(0.117662355303764 - 2.25847268104553*x0) + 0.729577660560608*(0.124631509184837 - 0.613298892974854*x0)*(0.0657657012343407*x0 + 0.0500014573335648) + 0.028958685696125*sin(1.27907490730286*x0 + 0.306133478879929) + 0.0681050047278404*sin(1.49217450618744*x0 + 0.290995329618454) + 0.579850435256958*cos(0.342005789279938*x0 - 0.03175413236022) + 0.480754405260086*cos(1.13149344921112*x0 + 0.151191741228104) + 0.259612411260