In [11]:
import jax
from jax import lax, random, numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

import flax
from flax import linen as nn

import sympy as sy
import numpy as np

import sys
sys.path.append("..")

from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer

import optax
import scipy
from functools import partial

In [12]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [13]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [14]:
params = e.init({'params':key}, x);

In [15]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)


def get_mask_spec(thresh, params):
    flat, spec = tree_flatten(params)
    mask = [jnp.abs(f) > thresh for f in flat]
    return mask, spec

def apply_mask(mask, spec, params):
    flat, _ = tree_flatten(params)
    masked_params = tree_unflatten(spec, [f*m for f,m in zip(flat, mask)])
    return masked_params


def get_masked_mse(thresh, params):
    mask, spec = get_mask_spec(thresh, params)
    def masked_mse(params):
        masked_params = apply_mask(mask, spec, params)
        return mse_fn(masked_params)
    return jax.jit(masked_mse)
    

def l1_fn(params):
    return sum(
        jnp.abs(w).mean() for w in jax.tree.leaves(params["params"])
    )

In [16]:
params

{'params': {'layers_0': {'linear_layer': {'kernel': Array([[ 0.6292481 , -0.54229003, -0.4677973 ,  0.70746535,  0.09610943,
             0.4903576 ,  0.06954016,  0.31511047]], dtype=float32),
    'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}},
  'layers_1': {'linear_layer': {'kernel': Array([[ 0.19629148,  0.06706532,  0.00969497,  0.23161124, -0.11875771,
             0.01930422,  0.05764793, -0.0596979 ],
           [ 0.4696742 ,  0.46638718,  0.21155392, -0.8553514 ,  0.6172892 ,
            -0.44233188,  0.22190586, -0.18106432],
           [-0.44605988, -0.20469512, -0.32464647,  0.5807687 ,  0.02217815,
            -0.2559955 ,  0.17005439,  0.71249163],
           [ 0.45161504, -0.10305008, -0.47209048,  0.35264972, -0.3844344 ,
             0.2527192 , -0.5963687 ,  0.310487  ],
           [-0.04032504,  0.10622181,  0.5348187 ,  0.14317314,  0.38765785,
            -0.48200536, -0.28618744,  0.3023454 ],
           [-0.90012056, -0.02397138, -0.5772944 , -0

In [17]:
def get_loss(lamba):
    def loss_fn(params):
        return mse_fn(params)  + lamba * l1_fn(params)
    return loss_fn

def get_loss_grad(lamba):
    loss = get_loss(lamba)
    return jax.jit(jax.value_and_grad(loss))

In [18]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)

In [19]:
loss_grad_1 = get_loss_grad(0)
loss_grad_2 = get_loss_grad(1e-2)

for i in range(1000):
    loss_val, grads = loss_grad_1(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

for i in range(3000):
    loss_val, grads = loss_grad_2(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)
        print(l1_fn(params))
        
thr = 1e-3
loss_grad_masked = jax.jit(jax.value_and_grad(get_masked_mse(thr, params)))
mask, spec = get_mask_spec(thr, params)

for i in range(1000):
    loss_val, grads = loss_grad_masked(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

2.2196107
0.0027071459
0.00042087253
0.00026541023
0.000110642635
2.054354e-05
4.461378e-06
3.6788986e-06
3.520743e-06
3.3639694e-06
3.2019138e-06
0.01937606
1.9362589
0.015194671
1.5163494
0.013030678
1.3007543
0.011318412
1.1294502
0.010037608
1.0021915
0.0093057
0.92825687
0.008812487
0.8789349
0.008337744
0.83210635
0.007929334
0.7913948
0.0074878247
0.74704474
0.007223872
0.7198574
0.007066483
0.7029166
0.0068825274
0.68673295
0.006710681
0.66948044
0.0065550273
0.65283126
0.006368868
0.6353921
0.0062357467
0.62077147
0.0060761394
0.60593516
0.0059197755
0.59000623
0.0057567125
0.5735217
0.0055760015
0.55551046
0.0053894017
0.537537
0.005245107
0.5218966
0.0050993543
0.50813997
0.0049803494
0.49398902
0.004815019
0.48012897
0.0047136415
0.46753278
0.004540698
0.45279384
0.0044062613
0.438745
0.004316985
0.42961138
0.0042387904
0.4201668
5.9980455e-05
4.1937633e-06
1.8825775e-06
8.189966e-07
3.4347718e-07
1.3840443e-07
5.3546447e-08
6.074402e-05
6.448685e-09
2.4348121e-09
3.1852775

In [20]:
symb = get_symbolic_expr(apply_mask(mask, spec, params), funs)[0]
symb

-0.532041907310486*(0.722420658706981*x0**2 - 0.300174593925476*cos(0.0342495925724506*x0 - 1.95161028386792e-5) - 0.650032102755954)*(-0.102325013628188*x0**2 - 1.82454196927268e-9*x0 + 1.85237622261047*cos(0.0342495925724506*x0 - 1.95161028386792e-5) + 1.9119619212137) + 0.122631274163723*sin(0.271295964717865*cos(0.0342495925724506*x0 - 1.95161028386792e-5) + 0.627886056604718)

In [21]:
sy.expand(symb)

0.0393294442534478*x0**4 + 7.01277421074476e-10*x0**3 - 0.728317604558885*x0**2*cos(0.0342495925724506*x0 - 1.95161028386792e-5) - 0.77026650939008*x0**2 - 2.91389320838233e-10*x0*cos(0.0342495925724506*x0 - 1.95161028386792e-5) - 6.31007476242746e-10*x0 + 0.122631274163723*sin(0.271295964717865*cos(0.0342495925724506*x0 - 1.95161028386792e-5) + 0.627886056604718) + 0.295834603168116*cos(0.0342495925724506*x0 - 1.95161028386792e-5)**2 + 0.945984559434695*cos(0.0342495925724506*x0 - 1.95161028386792e-5) + 0.66124117005553