In [12]:
import jax
from jax import lax, random, numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

import flax
from flax import linen as nn

import sympy as sy
import numpy as np

import sys
sys.path.append("..")

from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer

import optax
import scipy
from functools import partial

In [13]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [14]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [15]:
params = e.init({'params':key}, x);

In [18]:
def f(x, p):
    return e.apply(params, x)

In [28]:
print(jax.make_jaxpr(f)(x,params))

{ lambda a:f32[1,8] b:f32[8] c:f32[6,8] d:f32[8] e:f32[6,1] f:f32[1]; g:f32[1000,1]
    h:f32[1] i:f32[6,1] j:f32[8] k:f32[1,8] l:f32[8] m:f32[6,8]. let
    n:f32[1000,8] = dot_general[dimension_numbers=(([1], [0]), ([], []))] g a
    o:f32[1,8] = reshape[dimensions=None new_sizes=(1, 8)] b
    p:f32[1000,8] = add n o
    q:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
    r:f32[1000] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1,), start_index_map=(1,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1000, 1)
      unique_indices=True
    ] p q
    s:f32[1000] = cos r
    t:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 3
    u:f32[1000] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1,), start_index_map=(1,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScat

In [16]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)


def get_mask_spec(thresh, params):
    flat, spec = tree_flatten(params)
    mask = [jnp.abs(f) > thresh for f in flat]
    return mask, spec

def apply_mask(mask, spec, params):
    flat, _ = tree_flatten(params)
    masked_params = tree_unflatten(spec, [f*m for f,m in zip(flat, mask)])
    return masked_params


def get_masked_mse(thresh, params):
    mask, spec = get_mask_spec(thresh, params)
    def masked_mse(params):
        masked_params = apply_mask(mask, spec, params)
        return mse_fn(masked_params)
    return jax.jit(masked_mse)
    

def l1_fn(params):
    return sum(
        jnp.abs(w).mean() for w in jax.tree_leaves(params["params"])
    )

In [7]:
def get_loss(lamba):
    def loss_fn(params):
        return mse_fn(params)  + lamba * l1_fn(params)
    return loss_fn

def get_loss_grad(lamba):
    loss = get_loss(lamba)
    return jax.jit(jax.value_and_grad(loss))

In [8]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)

In [9]:
loss_grad_1 = get_loss_grad(0)
loss_grad_2 = get_loss_grad(1e-2)

for i in range(1000):
    loss_val, grads = loss_grad_1(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

for i in range(3000):
    loss_val, grads = loss_grad_2(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)
        print(l1_fn(params))
        
thr = 1e-3
loss_grad_masked = jax.jit(jax.value_and_grad(get_masked_mse(thr, params)))
mask, spec = get_mask_spec(thr, params)

for i in range(1000):
    loss_val, grads = loss_grad_masked(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

  jnp.abs(w).mean() for w in jax.tree_leaves(params["params"])


2.2196107
0.0027071456
0.00042087326
0.0002654106
0.00011064316
2.0543643e-05
4.4613835e-06
3.6789131e-06
3.520719e-06
3.36399e-06
3.2019113e-06
0.01937606
1.9362588
0.015194671
1.5163498
0.013030679
1.3007544
0.011318412
1.1294502
0.010037606
1.0021918
0.009304737
0.9281473
0.008803532
0.8784107
0.008343009
0.83223176
0.007927037
0.79103863
0.0074824025
0.74656904
0.007209731
0.7198598
0.0070424033
0.7031468
0.0068800915
0.68686366
0.00671784
0.6692767
0.0065376675
0.6522559
0.006375549
0.63532984
0.0062242905
0.62092984
0.006067932
0.6057775
0.0059399065
0.59048784
0.0057820575
0.5734067
0.0055793603
0.5554023
0.005396762
0.5378386
0.0052450113
0.5219404
0.0051027182
0.5084261
0.004969709
0.49459642
0.0048168446
0.48009127
0.0047163074
0.46769136
0.004551922
0.45303023
0.004463114
0.43974265
0.004315797
0.42962217
0.0042193984
0.4202358
1.2062541e-05
4.2246083e-06
1.8918056e-06
8.218631e-07
3.4434774e-07
1.3880249e-07
1.3301224e-07
1.9009605e-08
8.138984e-06
2.7658458e-09
6.4804455e-

In [10]:
symb = get_symbolic_expr(apply_mask(mask, spec, params), funs)[0]
symb

-0.532491683959961*(3.76640653219053 - 0.102846192442174*x0**2)*(0.721492115372879*x0**2 - 2.31828871619655e-5*sin(0.0143262036144733*x0) - 0.949661410410718) + 0.0959305601010714

In [11]:
sy.expand(symb)

0.0395123296994625*x0**4 - 1.26960483893245e-6*x0**2*sin(0.0143262036144733*x0) - 1.49901823220305*x0**2 + 4.64951384694626e-5*sin(0.0143262036144733*x0) + 2.00055264050324