In [64]:
import jax
from jax import lax, random, numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten
import wandb
import flax
from flax import linen as nn

import sympy as sy
import numpy as np

import sys
sys.path.append("..")

from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer

import optax
import scipy
from functools import partial

In [65]:
wandb.init(
    # set the wandb project where this run will be logged
    project="Example_2",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.01,
    "architecture": "EQL_2_Layers",
    "epochs": 5000,
    "optimizer": "Adam",
    "regularization": "No_reg(1000)+OPR(3000)+Pruning(1000)",
    "Batchsize": 1000,
    "Reg_Factor": 0.01,
    "Threshold": 0.001,
    "input_dim": 1,
    "output_dim": 1
    }
)

In [66]:
funs = ['mul', 'cos', 'sin']*2
e = EQL(n_layers=2, functions=funs, features=1)
key = random.PRNGKey(0)

In [67]:
N = 1000
xdim = 1
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
y = jnp.cos(x) + 1 - x**2 

In [68]:
params = e.init({'params':key}, x)

In [69]:
def f(x, p):
    return e.apply(params, x)

In [70]:
print(jax.make_jaxpr(f)(x,params))

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[1,8][39m b[35m:f32[8][39m c[35m:f32[6,8][39m d[35m:f32[8][39m e[35m:f32[6,1][39m f[35m:f32[1][39m; g[35m:f32[1000,1][39m
    h[35m:f32[1][39m i[35m:f32[6,1][39m j[35m:f32[8][39m k[35m:f32[1,8][39m l[35m:f32[8][39m m[35m:f32[6,8][39m. [34m[22m[1mlet
    [39m[22m[22mn[35m:f32[1000,8][39m = dot_general[dimension_numbers=(([1], [0]), ([], []))] g a
    o[35m:f32[1,8][39m = reshape[dimensions=None new_sizes=(1, 8) sharding=None] b
    p[35m:f32[1000,8][39m = add n o
    q[35m:f32[1000,1][39m = slice[
      limit_indices=(1000, 3)
      start_indices=(0, 2)
      strides=None
    ] p
    r[35m:f32[1000][39m = squeeze[dimensions=(1,)] q
    s[35m:f32[1000][39m = cos r
    t[35m:f32[1000,1][39m = slice[
      limit_indices=(1000, 4)
      start_indices=(0, 3)
      strides=None
    ] p
    u[35m:f32[1000][39m = squeeze[dimensions=(1,)] t
    v[35m:f32[1000][39m = sin u
    w[35m:f32[1000,1][39m 

In [71]:
def mse_fn(params):
    pred = e.apply(params, x)
    return jnp.mean((pred-y)**2)


def get_mask_spec(thresh, params):
    flat, spec = tree_flatten(params)
    mask = [jnp.abs(f) > thresh for f in flat]
    return mask, spec

def apply_mask(mask, spec, params):
    flat, _ = tree_flatten(params)
    masked_params = tree_unflatten(spec, [f*m for f,m in zip(flat, mask)])
    return masked_params


def get_masked_mse(thresh, params):
    mask, spec = get_mask_spec(thresh, params)
    def masked_mse(params):
        masked_params = apply_mask(mask, spec, params)
        return mse_fn(masked_params)
    return jax.jit(masked_mse)
    

def l1_fn(params):
    return sum(
        jnp.abs(w).mean() for w in jax.tree.leaves(params["params"])
    )

In [72]:
def get_loss_grad(lamba):
    
    def loss_grad_fn(params):
        mse_val, mse_grad = jax.value_and_grad(mse_fn)(params)
        
        l1_val, l1_grad = jax.value_and_grad(l1_fn)(params)
        
        mse_flat, spec = tree_flatten(mse_grad)
        l1_flat, _ = tree_flatten(l1_grad)
        
        dot_product = sum(jnp.dot(m1.ravel(), l1.ravel()) for m1, l1 in zip(mse_flat, l1_flat))
        norm_squared = sum(jnp.dot(m1.ravel(), m1.ravel()) for m1 in mse_flat)
        
        proj_scalar = dot_product / (norm_squared + 1e-8)
        
        proj_l1_flat = [l1 - proj_scalar * m1 for l1, m1 in zip(l1_flat, mse_flat)]
        combined_grad_flat = [m1 + lamba * p1 for m1, p1 in zip(mse_flat, proj_l1_flat)]
        combined_grad = tree_unflatten(spec, combined_grad_flat)
        combined_loss = mse_val + lamba * l1_val
        
        return combined_loss, combined_grad
    
    return jax.jit(loss_grad_fn)

In [73]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)

In [74]:
loss_grad_1 = get_loss_grad(0)
loss_grad_2 = get_loss_grad(1e-2)

for i in range(1000):
    loss_val, grads = loss_grad_1(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0 and i > 0:
        print(loss_val)
        wandb.log({"loss": loss_val})

for i in range(3000):
    loss_val, grads = loss_grad_2(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0 and i > 0:
        print(loss_val)
        print(l1_fn(params))
        wandb.log({"loss": loss_val})
        
thr = 1e-3
loss_grad_masked = jax.jit(jax.value_and_grad(get_masked_mse(thr, params)))

for i in range(1000):
    loss_val, grads = loss_grad_masked(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    mask, spec = get_mask_spec(thr, params)
    params = apply_mask(mask, spec, params)
    if i % 100 == 0 and i > 0:
        print(loss_val)
        wandb.log({"loss": loss_val})

0.0027960988
0.00041793875
0.00026049122
0.00010524316
1.8509174e-05
4.316204e-06
3.6661515e-06
3.5083053e-06
3.3494446e-06
0.015175413
1.5149567
0.013005482
1.2970622
0.01128583
1.1256647
0.010028285
1.0015178
0.009288606
0.9259709
0.0088067455
0.87797856
0.008343381
0.83063567
0.007914561
0.7898774
0.0074660387
0.74392545
0.0072041308
0.7188498
0.0070492313
0.70199364
0.0068699853
0.68555295
0.006716198
0.6679546
0.00654328
0.65061593
0.0063721742
0.6338
0.006221539
0.61881304
0.0060529835
0.6034317
0.0059031663
0.58733714
0.005744346
0.5704863
0.0055367034
0.5521358
0.005349946
0.5338613
0.005235622
0.51891446
0.005066209
0.5051458
0.0049210005
0.4911219
0.0047788573
0.47679925
0.004684058
0.46320108
0.0045221522
0.44916743
0.004381592
0.43605798
0.004286056
0.42724532
4.445052e-06
1.9368488e-06
8.1848566e-07
3.333035e-07
1.3026813e-07
4.934162e-05
1.7054077e-08
5.9102594e-09
8.545005e-07


In [75]:
params

{'params': {'last': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[ 0.        ],
          [ 0.12263211],
          [ 0.        ],
          [ 0.        ],
          [ 0.        ],
          [-0.53281856]], dtype=float32)},
  'layers_0': {'linear_layer': {'bias': Array([ 0.0109022 , -0.01115601,  0.00468342,  0.        ,  0.        ,
            0.        , -0.01791567,  0.00413576], dtype=float32),
    'kernel': Array([[ 0.49267825, -0.4926865 ,  0.        ,  0.        ,  0.00355938,
            -0.01648284,  0.01258602,  0.        ]], dtype=float32)}},
  'layers_1': {'linear_layer': {'bias': Array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        , -0.01516098], dtype=float32),
    'kernel': Array([[-4.0627629e-03,  0.0000000e+00,  0.0000000e+00,  2.7152407e-01,
             1.8476720e+00, -3.0167583e-01,  0.0000000e+00,  0.0000000e+00],
           [ 0.0000000e+00, -8.1290407e-03,  0.0000000e+00,  0.0000000e+00,
          

In [76]:
symb = get_symbolic_expr(apply_mask(mask, spec, params), funs)[0]
symb

-0.532818555831909*(6.94929926730461e-7*x0**2 + 0.4236179292202*(-0.49268651008606*x0 - 0.0111560123041272)*(0.492678254842758*x0 + 0.0109022026881576) + 1.90931916236877*cos(0.0125860152766109*x0 - 0.0179156735539436) + 1.84765172184128)*(7.7752730662506e-7*x0**2 - 2.9784038066864*(-0.49268651008606*x0 - 0.0111560123041272)*(0.492678254842758*x0 + 0.0109022026881576) - 0.649626970291138*cos(0.0125860152766109*x0 - 0.0179156735539436) - 0.301672517772522) + 0.12263210862875*sin(5.18916491933865e-7*x0**2 + 0.628131747245789*cos(0.0125860152766109*x0 - 0.0179156735539436) + 0.271521093835101)

In [77]:
sy.expand(symb)

0.0396098242762188*x0**4 + 0.00354680601222744*x0**3 - 0.771080373835565*x0**2*cos(0.0125860152766109*x0 - 0.0179156735539436) - 0.72814267193512*x0**2 - 0.0345225311220222*x0*cos(0.0125860152766109*x0 - 0.0179156735539436) - 0.0326036886884059*x0 + 0.12263210862875*sin(5.18916491933865e-7*x0**2 + 0.628131747245789*cos(0.0125860152766109*x0 - 0.0179156735539436) + 0.271521093835101) + 0.660878950328488*cos(0.0125860152766109*x0 - 0.0179156735539436)**2 + 0.946045129155795*cos(0.0125860152766109*x0 - 0.0179156735539436) + 0.296620576607424

In [78]:
wandb.finish()

0,1
loss,▂▁▁▁▁▁▁▁█▇▆▆▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁

0,1
loss,0.0
