In [1]:
import jax
from jax import lax, random, numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

import flax
from flax import linen as nn

import sympy as sy
from sympy.core.rules import Transform
import numpy as np

import sys
sys.path.append("..")
sys.path.append("../../orient/")


from eql.eqlearner import EQL, EQLdiv
from eql.symbolic import get_symbolic_expr_div, get_symbolic_expr
from eql.np_utils import flatten, unflatten


import optax
import scipy
from functools import partial
import matplotlib.pyplot as plt

In [2]:
funs = ['mul', 'cos' , 'sin', 'id', 'id', 'id', 'id']
e = EQLdiv(n_layers=1, functions=funs, features=1)
key = random.PRNGKey(0)

In [3]:
N = 1024
xdim = 2
x = (random.uniform(key, (N, xdim))-.5) * 2
#x = np.array([[1., 2.]]).T
#x = np.linspace(-1, 1, N)[:,None]
#y = x[:,0] + jnp.cos(x[:,1])
#y = (jnp.cos(x) + 1 - x**2)/(x-3)**3


y = np.sin(np.pi * x[:,0])/(x[:,1]**2 + 1)
#y = 1./3. * ((1.+x[:,1])*np.sin(np.pi*x[:,0]) + x[:,1]*x[:,2]*x[:,3])
#plt.scatter(x[:,0], x[:,1], c=y)

In [4]:
params = e.init({'params':key}, x, 1.0);

In [5]:
def mse_fn(params, threshold):
    pred, _ = e.apply(params, x, threshold)
    return jnp.mean((pred-y)**2)

def mse_b_fn(params, threshold):
    pred, b = e.apply(params, x, threshold)
    return jnp.mean((pred-y)**2), b

def mse_b_y_fn(params, threshold):
    pred, b = e.apply(params, x, threshold)
    return jnp.mean((pred-y)**2), b, pred


def get_mask_spec(thresh, params):
    flat, spec = tree_flatten(params)
    mask = [jnp.abs(f) > thresh for f in flat]
    return mask, spec

def apply_mask(mask, spec, params):
    flat, _ = tree_flatten(params)
    masked_params = tree_unflatten(spec, [f*m for f,m in zip(flat, mask)])
    return masked_params


def get_masked_mse(thresh, params):
    mask, spec = get_mask_spec(thresh, params)
    def masked_mse(params, threshold):
        masked_params = apply_mask(mask, spec, params)
        return mse_fn(masked_params, threshold)
    return jax.jit(masked_mse)
    

def l1_fn(params):
    return sum(
        jnp.abs(w).mean() for w in jax.tree.leaves(params["params"])
    )

def reg_fn(threshold, b):
    return (jnp.maximum(0, threshold - b)).sum()

def penalty_fn(y, B=10, supp=3):
    penalty_fn.key, _ = random.split(key)
    xr = (random.uniform(penalty_fn.key, (N, xdim))-.5) * supp
    return jnp.sum(jnp.maximum(y-B, 0)+jnp.maximum(-y-B, 0))
penalty_fn.key = key

In [6]:
def get_loss(lamba):
    def loss_fn(params, threshold):
        mse, b = mse_b_fn(params, threshold)
        return mse  + lamba * l1_fn(params) + reg_fn(threshold, b)
    return loss_fn

def get_loss_pen():
    def loss_fn(params, threshold):
        mse, b, y = mse_b_y_fn(params, threshold)
        return penalty_fn(y) + reg_fn(threshold, b)
    return loss_fn

def get_loss_grad(lamba=1e-3, is_penalty=False):
    if is_penalty:
        loss = get_loss_pen()
    else:
        loss = get_loss(lamba)
    return jax.jit(jax.value_and_grad(loss))

In [7]:
tx = optax.adam(learning_rate=1e-4)
opt_state = tx.init(params)

In [8]:
loss_grad_pen = get_loss_grad(is_penalty=True)
loss_grad_1 = get_loss_grad(0)
loss_grad_2 = get_loss_grad(1e-1)

In [9]:
def do_step(loss_grad, params, theta, opt_state):
    loss_val, grad = loss_grad(params, theta)
    updates, opt_state = tx.update(grad, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss_val

In [10]:
T1 = 10_000
Tpenalty = 500
for i in range(20_000):
    theta = 1./jnp.sqrt(i/1. + 1)
    if i < T1:
        lg = loss_grad_1
    elif i >= T1:
        lg = loss_grad_2
    params, opt_state, loss_val = do_step(lg, params, theta, opt_state)
    if i % 99 == 0:
        print(loss_val, theta)
        for j in range(100):
            params, opt_state, loss_val = do_step(loss_grad_pen, params, theta, opt_state)

0.3226874 1.0
1.3048428 0.1
1.2196879 0.07088812
1.1298747 0.057928447
1.0362647 0.050188564
0.9408289 0.044901326
0.8457495 0.040996004
0.7530881 0.037959483
0.66463464 0.035511043
0.5818316 0.033482477
0.50574565 0.031766046
0.43708462 0.030289127
0.37624228 0.029000739
0.32335082 0.02786391
0.27832562 0.026851078
0.24088985 0.02594123
0.21057938 0.02511802
0.18674725 0.02436851
0.16858725 0.023682324
0.1551885 0.023051023
0.14561483 0.022467656
0.13898471 0.02192645
0.134525 0.021422561
0.13159904 0.020951888
0.12971279 0.020510932
0.12850335 0.020096697
0.12771674 0.019706586
0.1271821 0.019338345
0.12678866 0.018990004
0.12646723 0.018659834
0.12617677 0.018346308
0.12589452 0.018048072
0.1256086 0.017763922
0.12531328 0.017492786
0.12500575 0.017233696
0.12468447 0.016985789
0.12434837 0.016748281
0.12399657 0.016520465
0.12362818 0.016301699
0.123242244 0.0160914
0.12283784 0.015889037
0.12241395 0.01569412
0.12196959 0.015506205
0.12150392 0.015324884
0.12101629 0.015149776
0.1

In [11]:
T = 0
for i in range(10_000):
    theta = 1./jnp.sqrt(T/1 + 1)
    loss_val, grads = loss_grad_1(params, theta)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    T +=1 
    if i % 99 == 0:
        print(loss_val, theta)

for i in range(10_000):
    theta = 1./jnp.sqrt(T/1 + 1)
    loss_val, grads = loss_grad_2(params, theta)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    T +=1
    if i % 99 == 0:
        print(loss_val, theta)
        print(l1_fn(params))

994.8426 1.0
0.21743366 0.1
0.17117876 0.07088812
0.15048471 0.057928447
0.1348757 0.050188564
0.12252143 0.044901326
0.112661496 0.040996004
0.10485084 0.037959483
0.09876719 0.035511043
0.094137296 0.033482477
0.09070925 0.031766046
0.08825055 0.030289127
0.08654542 0.029000739
0.08540001 0.02786391
0.08465097 0.026851078
0.08416934 0.02594123
0.083857864 0.02511802
0.08365011 0.02436851
0.08350092 0.023682324
0.083383314 0.023051023
0.08328149 0.022467656
0.08318655 0.02192645
0.083092615 0.021422561
0.082997516 0.020951888
0.08290012 0.020510932
0.082798734 0.020096697
0.08269339 0.019706586
0.08258347 0.019338345
0.08246842 0.018990004
0.08234818 0.018659834
0.08222247 0.018346308
0.08209077 0.018048072
0.08195293 0.017763922
0.08180863 0.017492786
0.0816575 0.017233696
0.081499234 0.016985789
0.08133345 0.016748281
0.08115977 0.016520465
0.08097772 0.016301699
0.080786906 0.0160914
0.08058686 0.015889037
0.080376916 0.01569412
0.080156684 0.015506205
0.07992534 0.015324884
0.0796

In [12]:
# thr = 1e-3
# loss_grad_masked = jax.jit(jax.value_and_grad(get_masked_mse(thr, params)))
# mask, spec = get_mask_spec(thr, params)

# for i in range(1000):
#     theta = 1./jnp.sqrt(T/1 + 1)
#     loss_val, grads = loss_grad_masked(params, theta)
#     updates, opt_state = tx.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)
#     T +=1
#     if i % 99 == 0:
#         print(loss_val)

In [13]:
#symb = get_symbolic_expr_div(apply_mask(mask, spec, params), funs)[0]
#symb = get_symbolic_expr_div(params, funs)[0]
#symb

In [14]:
spec, fparam = flatten(params)
full_shape = fparam.shape
mask = jnp.abs(fparam) > 0.01
idxs = jnp.arange(fparam.shape[0])[mask]
count = sum(mask).item()

In [15]:
def red_loss_grad_fn(red_param):
    full_param = jnp.zeros(full_shape).at[idxs].set(red_param)
    full_param = unflatten(spec, full_param)

    #return mse_fn(full_param, 1e-4)
    loss, grad = loss_grad_1(full_param, 1e-4)
    _, grad = flatten(grad)
    return loss, np.array(grad)[idxs,]
    
#red_loss_grad = jax.jit(jax.value_and_grad(red_mse_fn))

In [16]:
x0, f, info = scipy.optimize.fmin_l_bfgs_b(
        red_loss_grad_fn,
        x0 = np.array(fparam[mask]),
        factr=1.,
        m=500,
        pgtol=1e-13,
        maxls=100)
#x0[np.abs(x0) < 1e-3] = 0.0

In [17]:
f

0.00010573297186056152

In [18]:
final_param = unflatten(spec, jnp.zeros(full_shape).at[idxs].set(x0))

In [19]:
symb = get_symbolic_expr_div(final_param, funs)[0]
#symb

In [20]:
def clean_expr(expr):
    # WARNING: might return 0/NaN/inf if expression only contains small numbers
    def prune(expr, thr=1e-5):
        return expr.replace(lambda x: x.is_Number and abs(x) < thr, lambda x: 0)
    
    def rounding(expr, dig=3):
        return expr.xreplace(Transform(lambda x: x.round(dig), lambda x: x.is_Number))
    
    # prune small numbers
    expr = prune(expr)
    # round number
    expr = rounding(expr)
    # expand
    expr = prune(sy.expand(expr), 1e-3)
    return sy.simplify(expr)

In [21]:
clean_expr(symb)

(-0.108*x0 + 0.129*sin(0.988*x0))/(0.005*x0**2 - 0.318*cos(0.212*x1) + 0.324)

In [22]:
f

0.00010573297186056152

In [23]:
symb


(-0.107759978379892*x0 + 0.128756254911423*sin(0.987523376941681*x0))/(0.00494997866433353*x0**2 - 0.318348169326782*cos(0.2120620906353*x1) + 0.324296563863754)