In [13]:
import jax
from jax import lax, random, numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

import flax
from flax import linen as nn

import sympy as sy
import numpy as np

import sys
sys.path.append("..")

from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer

import optax
import scipy
from functools import partial

In [14]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [15]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [16]:
params = e.init({'params':key}, x);

In [17]:
def f(x, p):
    return e.apply(params, x)

In [18]:
print(jax.make_jaxpr(f)(x,params))

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[1,8][39m b[35m:f32[8][39m c[35m:f32[6,8][39m d[35m:f32[8][39m e[35m:f32[6,1][39m f[35m:f32[1][39m; g[35m:f32[1000,1][39m
    h[35m:f32[1][39m i[35m:f32[6,1][39m j[35m:f32[8][39m k[35m:f32[1,8][39m l[35m:f32[8][39m m[35m:f32[6,8][39m. [34m[22m[1mlet
    [39m[22m[22mn[35m:f32[1000,8][39m = dot_general[dimension_numbers=(([1], [0]), ([], []))] g a
    o[35m:f32[1,8][39m = reshape[dimensions=None new_sizes=(1, 8)] b
    p[35m:f32[1000,8][39m = add n o
    q[35m:f32[1000,1][39m = slice[
      limit_indices=(1000, 3)
      start_indices=(0, 2)
      strides=None
    ] p
    r[35m:f32[1000][39m = squeeze[dimensions=(1,)] q
    s[35m:f32[1000][39m = cos r
    t[35m:f32[1000,1][39m = slice[
      limit_indices=(1000, 4)
      start_indices=(0, 3)
      strides=None
    ] p
    u[35m:f32[1000][39m = squeeze[dimensions=(1,)] t
    v[35m:f32[1000][39m = sin u
    w[35m:f32[1000,1][39m = slice[
     

In [19]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)


def get_mask_spec(thresh, params):
    flat, spec = tree_flatten(params)
    mask = [jnp.abs(f) > thresh for f in flat]
    return mask, spec

def apply_mask(mask, spec, params):
    flat, _ = tree_flatten(params)
    masked_params = tree_unflatten(spec, [f*m for f,m in zip(flat, mask)])
    return masked_params


def get_masked_mse(thresh, params):
    mask, spec = get_mask_spec(thresh, params)
    def masked_mse(params):
        masked_params = apply_mask(mask, spec, params)
        return mse_fn(masked_params)
    return jax.jit(masked_mse)
    

def l1_fn(params):
    return sum(
        jnp.abs(w).mean() for w in jax.tree.leaves(params["params"])
    )

In [20]:
def get_loss(lamba):
    def loss_fn(params):
        return mse_fn(params)  + lamba * l1_fn(params)
    return loss_fn

def get_loss_grad(lamba):
    loss = get_loss(lamba)
    return jax.jit(jax.value_and_grad(loss))

def get_proj_loss_grad(lamba):
    def loss_fn(params):
        
        mse = mse_fn(params)
        l1 = l1_fn(params)

        mse_grad = jax.grad(mse_fn)(params)
        l1_grad = jax.grad(l1_fn)(params)

        mse_grad_flat, tree_def = tree_flatten(mse_grad)
        l1_grad_flat, _ = tree_flatten(l1_grad)

        dot_product = sum(jnp.vdot(m, l) for m, l in zip(mse_grad_flat, l1_grad_flat))
        mse_grad_norm_squared = sum(jnp.vdot(m, m) for m in mse_grad_flat) 
        
        if(mse_grad_norm_squared == 0):
            return lamba * l1_grad

        proj_l1_flat = (dot_product/mse_grad_norm_squared) * mse_grad_flat

        l1_grad_orthogonal = tree_unflatten(tree_def, l1_grad_orthogonal)

        return mse
    return loss_fn

In [21]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)

In [22]:
loss_grad_1 = get_loss_grad(0)
loss_grad_2 = get_loss_grad(1e-2)

for i in range(1000):
    loss_val, grads = loss_grad_1(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

for i in range(3000):
    loss_val, grads = loss_grad_2(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)
        print(l1_fn(params))
        
thr = 1e-3
loss_grad_masked = jax.jit(jax.value_and_grad(get_masked_mse(thr, params)))
mask, spec = get_mask_spec(thr, params)

for i in range(1000):
    loss_val, grads = loss_grad_masked(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 99 == 0:
        print(loss_val)

2.2196107
0.002707143
0.0004208714
0.00026540947
0.00011064221
2.0543432e-05
4.4613707e-06
3.6789017e-06
3.5207515e-06
3.363988e-06
3.2019302e-06
0.01937606
1.9362588
0.015194669
1.5163493
0.013030677
1.3007544
0.011318412
1.1294503
0.010037643
1.0021915
0.00930468
0.9281473
0.008803648
0.8784219
0.008348987
0.8320803
0.007947914
0.7910663
0.007518957
0.7469034
0.0072142705
0.72014046
0.007057825
0.70315003
0.006882266
0.6867191
0.00677687
0.66987675
0.0065443413
0.6523567
0.0063646003
0.63525206
0.0062312284
0.62128615
0.006071429
0.6060356
0.005926168
0.58973813
0.0057794475
0.57365894
0.00560639
0.5558394
0.005391321
0.5370958
0.0052899597
0.52133024
0.0050951154
0.508272
0.004964719
0.4939353
0.0048739086
0.48035818
0.0046893465
0.4669354
0.004565598
0.45338818
0.004405324
0.43899575
0.004319337
0.4299099
0.0042511504
0.42022023
1.312282e-05
9.035533e-06
9.031992e-06
9.02654e-06
9.017615e-06
9.002922e-06
1.0821866e-05
8.933939e-06
8.857901e-06
8.870014e-06
8.462183e-06


In [23]:
symb = get_symbolic_expr(apply_mask(mask, spec, params), funs)[0]
symb

-0.53297632932663*(1.91601252555847*cos(0.0608053095638752*x0) + 1.84970014420013)*(0.728689376026723*x0**2 - 0.00173055101186037*sin(0.0144784711301327*x0) - 0.648524641990662*cos(0.0608053095638752*x0) - 0.298678778403794) + 0.122283354401588*sin(0.627760708332062*cos(0.0608053095638752*x0) + 0.271160759674615)

In [24]:
sy.expand(symb)

-0.744129810447942*x0**2*cos(0.0608053095638752*x0) - 0.718375793126916*x0**2 + 0.00176722021590023*sin(0.0144784711301327*x0)*cos(0.0608053095638752*x0) + 0.00170605747330971*sin(0.0144784711301327*x0) + 0.122283354401588*sin(0.627760708332062*cos(0.0608053095638752*x0) + 0.271160759674615) + 0.662266439983932*cos(0.0608053095638752*x0)**2 + 0.944353258672449*cos(0.0608053095638752*x0) + 0.294451396417965