In [42]:
import jax
from jax import lax, random, numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten
import wandb
import flax
from flax import linen as nn

import sympy as sy
import numpy as np

import sys
sys.path.append("..")

from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer

import optax
import scipy
from functools import partial

In [None]:
wandb.init(
    # set the wandb project where this run will be logged
    project="Example_2",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.01,
    "architecture": "EQL_2_Layers",
    "epochs": 5000,
    "optimizer": "Adam",
    "regularization": "No_reg(1000)+L1(3000)+Pruning(1000)",
    "Batchsize": 1000,
    "Reg_Factor": 0.01,
    "Threshold": 0.001,
    "input_dim": 1,
    "output_dim": 1
    }
)

In [44]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [45]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [46]:
params = e.init({'params':key}, x)

In [47]:
def f(x, p):
    return e.apply(params, x)

In [48]:
print(jax.make_jaxpr(f)(x,params))

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[1,8][39m b[35m:f32[8][39m c[35m:f32[6,8][39m d[35m:f32[8][39m e[35m:f32[6,1][39m f[35m:f32[1][39m; g[35m:f32[1000,1][39m
    h[35m:f32[1][39m i[35m:f32[6,1][39m j[35m:f32[8][39m k[35m:f32[1,8][39m l[35m:f32[8][39m m[35m:f32[6,8][39m. [34m[22m[1mlet
    [39m[22m[22mn[35m:f32[1000,8][39m = dot_general[dimension_numbers=(([1], [0]), ([], []))] g a
    o[35m:f32[1,8][39m = reshape[dimensions=None new_sizes=(1, 8) sharding=None] b
    p[35m:f32[1000,8][39m = add n o
    q[35m:f32[1000,1][39m = slice[
      limit_indices=(1000, 3)
      start_indices=(0, 2)
      strides=None
    ] p
    r[35m:f32[1000][39m = squeeze[dimensions=(1,)] q
    s[35m:f32[1000][39m = cos r
    t[35m:f32[1000,1][39m = slice[
      limit_indices=(1000, 4)
      start_indices=(0, 3)
      strides=None
    ] p
    u[35m:f32[1000][39m = squeeze[dimensions=(1,)] t
    v[35m:f32[1000][39m = sin u
    w[35m:f32[1000,1][39m 

In [49]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)


def get_mask_spec(thresh, params):
    flat, spec = tree_flatten(params)
    mask = [jnp.abs(f) > thresh for f in flat]
    return mask, spec

def apply_mask(mask, spec, params):
    flat, _ = tree_flatten(params)
    masked_params = tree_unflatten(spec, [f*m for f,m in zip(flat, mask)])
    return masked_params


def get_masked_mse(thresh, params):
    mask, spec = get_mask_spec(thresh, params)
    def masked_mse(params):
        masked_params = apply_mask(mask, spec, params)
        return mse_fn(masked_params)
    return jax.jit(masked_mse)
    

def l1_fn(params):
    return sum(
        jnp.abs(w).mean() for w in jax.tree.leaves(params["params"])
    )

In [50]:
def get_loss(lamba):
    def loss_fn(params):
        return mse_fn(params)  + lamba * l1_fn(params)
    return loss_fn

def get_loss_grad(lamba):
    loss = get_loss(lamba)
    return jax.jit(jax.value_and_grad(loss))

In [51]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)

In [52]:
loss_grad_1 = get_loss_grad(0)
loss_grad_2 = get_loss_grad(1e-2)

for i in range(1000):
    loss_val, grads = loss_grad_1(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0 and i > 0:
        print(loss_val)
        wandb.log({"loss": loss_val})

for i in range(3000):
    loss_val, grads = loss_grad_2(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0 and i > 0:
        print(loss_val)
        print(l1_fn(params))
        wandb.log({"loss": loss_val})
        
thr = 1e-3
loss_grad_masked = jax.jit(jax.value_and_grad(get_masked_mse(thr, params)))

for i in range(1000):
    loss_val, grads = loss_grad_masked(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    mask, spec = get_mask_spec(thr, params)
    params = apply_mask(mask, spec, params)
    if i % 100 == 0 and i > 0:
        print(loss_val)
        wandb.log({"loss": loss_val})

0.0027960988
0.00041793875
0.00026049122
0.00010524316
1.8509174e-05
4.316204e-06
3.6661515e-06
3.5083053e-06
3.3494446e-06
0.015170485
1.514023
0.012992911
1.2966099
0.011273896
1.1249194
0.010014574
0.9998518
0.009265739
0.92502797
0.008786042
0.8761539
0.008316746
0.8292042
0.007901113
0.78824556
0.007455474
0.7426778
0.0072117704
0.71789104
0.0070529766
0.701208
0.0068570278
0.68472946
0.0067139305
0.6674716
0.006532983
0.6504694
0.006347271
0.63327837
0.0062148757
0.6185476
0.006051357
0.6032225
0.005900569
0.58672076
0.005724036
0.5700405
0.005568296
0.5526626
0.005346085
0.532651
0.005204119
0.5183987
0.005270538
0.5052099
0.00494418
0.49054587
0.004792625
0.4772443
0.0046483744
0.46311858
0.0045114206
0.4486933
0.0043865885
0.43633038
0.004286748
0.42675477
4.161805e-06
1.8518053e-06
7.9769654e-07
3.3133182e-07
1.3207847e-07
5.0497704e-08
1.0379539e-05
5.798387e-09
6.9755015e-09


In [53]:
params

{'params': {'last': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[ 0.        ],
          [ 0.12263943],
          [ 0.        ],
          [ 0.        ],
          [-0.00386547],
          [-0.53217864]], dtype=float32)},
  'layers_0': {'linear_layer': {'bias': Array([ 0.        ,  0.        ,  0.        ,  0.        , -0.01496816,
            0.        ,  0.        ,  0.        ], dtype=float32),
    'kernel': Array([[ 0.49164796, -0.49165705,  0.03423718,  0.        ,  0.00303308,
             0.        ,  0.        ,  0.        ]], dtype=float32)}},
  'layers_1': {'linear_layer': {'bias': Array([ 0.        , -0.00319671,  0.        ,  0.        ,  0.        ,
            0.01109702,  0.        ,  0.01310011], dtype=float32),
    'kernel': Array([[ 0.        ,  0.        ,  0.        ,  0.27129987,  1.8525475 ,
            -0.30027294,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  0.        , -0.00344228,
            -0.00859725,  0.     

In [54]:
symb = get_symbolic_expr(apply_mask(mask, spec, params), funs)[0]
symb

-0.532178640365601*(-0.102202316171661*x0**2 + 1.85254752635956*cos(0.0342371799051762*x0) + 1.91213011741638)*(0.722062601609799*x0**2 - 0.300272941589355*cos(0.0342371799051762*x0) - 0.639033070765436) + 0.122639432549477*sin(0.271299868822098*cos(0.0342371799051762*x0) + 0.627889752388)

In [55]:
sy.expand(symb)

0.0392729052309385*x0**4 - 0.728203383803228*x0**2*cos(0.0342371799051762*x0) - 0.769524091278662*x0**2 + 0.122639432549477*sin(0.271299868822098*cos(0.0342371799051762*x0) + 0.627889752388) + 0.296034956490051*cos(0.0342371799051762*x0)**2 + 0.935570086784625*cos(0.0342371799051762*x0) + 0.650276733729862

In [56]:
wandb.finish()

0,1
loss,▂▁▁▁▁▁▁▁█▇▆▆▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁

0,1
loss,0.0
