In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import numpy as np
import sys
sys.path.append("..")
from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer
from typing import List, Tuple, Callable
from functools import partial
import matplotlib.pyplot as plt
import scipy
sys.path.append("../../orient")
import optax

from eql.np_utils import flatten, unflatten

In [2]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [3]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2

#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [4]:
params = e.init({'params':key}, x);

In [5]:
params

{'params': {'layers_0': {'linear_layer': {'kernel': Array([[ 0.6292481 , -0.54229003, -0.4677973 ,  0.70746535,  0.09610943,
             0.4903576 ,  0.06954016,  0.31511047]], dtype=float32),
    'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}},
  'layers_1': {'linear_layer': {'kernel': Array([[ 0.19629148,  0.06706532,  0.00969497,  0.23161124, -0.11875771,
             0.01930422,  0.05764793, -0.0596979 ],
           [ 0.4696742 ,  0.46638718,  0.21155392, -0.8553514 ,  0.6172892 ,
            -0.44233188,  0.22190586, -0.18106432],
           [-0.44605988, -0.20469512, -0.32464647,  0.5807687 ,  0.02217815,
            -0.2559955 ,  0.17005439,  0.71249163],
           [ 0.45161504, -0.10305008, -0.47209048,  0.35264972, -0.3844344 ,
             0.2527192 , -0.5963687 ,  0.310487  ],
           [-0.04032504,  0.10622181,  0.5348187 ,  0.14317314,  0.38765785,
            -0.48200536, -0.28618744,  0.3023454 ],
           [-0.90012056, -0.02397138, -0.5772944 , -0

In [6]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)

In [7]:
def loss(params):
    return mse_fn(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [8]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [9]:
for i in range(10000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

2.2196107
0.0027071459
0.00042087253
0.00026541023
0.000110642635
2.054354e-05
4.461378e-06
3.6788986e-06
3.520743e-06
3.3639694e-06
3.2019138e-06
3.0358747e-06
2.8671943e-06
2.6971602e-06
2.5270863e-06
2.3581583e-06
2.1915125e-06
2.0283485e-06
1.8695698e-06
1.7162208e-06
1.569199e-06
1.4292385e-06
1.2969997e-06
1.1730492e-06
1.0577837e-06
9.51501e-07
8.5435147e-07
7.662375e-07
6.87114e-07
6.166305e-07
6.315026e-07
4.9386915e-07
4.5009946e-07
3.1735233e-06
3.7769922e-07
3.5093436e-07
5.5306725e-07
3.1114087e-07
2.9419556e-07
3.2072515e-07
2.7255803e-07
4.5385406e-07
2.6688596e-07
2.4793906e-07
1.4758874e-06
2.591921e-07
2.3213772e-07
2.2033004e-07
2.690237e-07
2.2012121e-07
2.0796222e-07
0.00033590486
2.2703011e-07
2.0785988e-07
1.9522857e-07
1.8318725e-07
3.5375515e-07
1.8236975e-07
1.7156994e-07
1.9076366e-07
1.708597e-07
1.5901729e-07
3.6061323e-07
1.6430673e-07
1.5226041e-07
1.2444579e-05
1.5615336e-07
1.4388971e-07
1.5345247e-05
1.4493673e-07
1.3257339e-07
3.2619166e-06
1.3538528e

In [10]:
params

{'params': {'last': {'bias': Array([0.25747997], dtype=float32),
   'kernel': Array([[ 0.6018571 ],
          [ 0.9976524 ],
          [-0.4062651 ],
          [ 0.14818056],
          [ 0.01753187],
          [-0.39656726]], dtype=float32)},
  'layers_0': {'linear_layer': {'bias': Array([-0.04495414,  0.04862367, -0.28518608, -0.11209451, -0.06013893,
           -0.14098652, -0.3750776 , -0.20590888], dtype=float32),
    'kernel': Array([[ 0.9447709 , -0.86982423, -0.67828965,  0.77243936, -0.10001896,
             0.3091637 , -0.7011913 , -0.07298253]], dtype=float32)}},
  'layers_1': {'linear_layer': {'bias': Array([-0.00438732,  0.01088454,  0.15295629,  0.21148321,  0.40738818,
           -0.05994988,  0.26964036, -0.0109344 ], dtype=float32),
    'kernel': Array([[ 0.1705253 ,  0.05426419,  0.17521882,  0.4319874 ,  0.2371408 ,
            -0.1555451 ,  0.34666124, -0.04287156],
           [ 0.20493104,  0.20742822,  1.0536268 , -1.0619944 ,  0.46606606,
            -0.11224312, 

In [11]:
symb = get_symbolic_expr(params, funs)[0]
print(symb)

-0.39656725525856*(-1.34242975711823*(0.0486236698925495 - 0.869824230670929*x0)*(0.944770872592926*x0 - 0.0449541434645653) + 0.0693667232990265*(-0.100018955767155*x0 - 0.0601389296352863)*(0.309163689613342*x0 - 0.140986517071724) - 0.278187543153763*sin(0.0729825273156166*x0 + 0.205908879637718) - 0.112243123352528*sin(0.772439360618591*x0 - 0.112094506621361) - 0.155545100569725*cos(0.678289651870728*x0 + 0.28518608212471) - 0.419852674007416*cos(0.701191306114197*x0 + 0.375077605247498) - 0.0599498823285103)*(-0.373554795980453*(0.0486236698925495 - 0.869824230670929*x0)*(0.944770872592926*x0 - 0.0449541434645653) - 0.347931414842606*(-0.100018955767155*x0 - 0.0601389296352863)*(0.309163689613342*x0 - 0.140986517071724) + 0.611740410327911*sin(0.0729825273156166*x0 + 0.205908879637718) + 0.466066062450409*sin(0.772439360618591*x0 - 0.112094506621361) + 0.23714080452919*cos(0.678289651870728*x0 + 0.28518608212471) + 0.34836158156395*cos(0.701191306114197*x0 + 0.375077605247498) + 