In [2]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn
import sympy as sy
import numpy as np
import sys
sys.path.append("..")
from eql.eqlearner import EQL
from eql.symbolic import get_symbolic_expr, get_symbolic_expr_layer
from typing import List, Tuple, Callable
from functools import partial
import matplotlib.pyplot as plt
import scipy
sys.path.append("../../orient")
import optax
from eql.np_utils import flatten, unflatten

In [3]:
res = optimize_fire2(flat, energy, denergy, 0, logoutput=True)

NameError: name 'flat' is not defined

In [None]:
" Global variables for the FIRE algorithm"
alpha0 = 0.1
Ndelay = 5
Nmax = 400
finc = 1.1
fdec = 0.5
fa = 0.99
Nnegmax = 2000

#@partial(jax.jit, static_argnums=[1,2,3,4,5,6])
def optimize_fire(x0,f,df,params,atol=1e-4,dt = 0.002,logoutput=False):
    error = 10*atol 
    dtmax = 100*dt
    dtmin = 1e-3*dt
    alpha = alpha0
    Npos = 0

    x = x0.copy()
    V = jnp.zeros(x.shape)
    F = -df(x,params)

    for i in range(Nmax):

        P = (F*V).sum() # dissipated power
        
        if (P>0.0):
            Npos = Npos + 1
            if Npos>Ndelay:
                dt = min(dt*finc,dtmax)
                alpha = alpha*fa
        else:
            Npos = 0
            dt = max(dt*fdec,dtmin)
            alpha = alpha0
            V = jnp.zeros(x.shape)

        V = V + 0.5*dt*F
        V = (1-alpha)*V + alpha*F*jnp.linalg.norm(V)/jnp.linalg.norm(F)
        x = x + dt*V
        F = -df(x,params)
        V = V + 0.5*dt*F

        error = max(abs(F))
        if error < atol: break

        if logoutput: print(f(x,params),error)

    del V, F  
    return [x,f(x,params),i]

def optimize_fire2(x0,f,df,params,atol=1e-4,dt = 0.002,logoutput=False):
    error = 10*atol 
    dtmax = 100*dt
    dtmin = 1e-3*dt
    alpha = alpha0
    Npos = 0
    Nneg = 0
    key = jax.random.PRNGKey(1)

    x = x0.copy()
    V = jnp.zeros(x.shape)
    F = -df(x,params, key)

    for i in range(Nmax):
        key, _ = jax.random.split(key)
        P = (F*V).sum() # dissipated power
        
        if (P>0):
            Npos = Npos + 1
            Nneg = 0
            if Npos>Ndelay:
                dt = min(dt*finc,dtmax)
                alpha = alpha*fa
        else:
            Npos = 0
            Nneg = Nneg + 1
            if Nneg > Nnegmax: break
            if i> Ndelay:
                dt = max(dt*fdec,dtmin)
                alpha = alpha0
            x = x - 0.5*dt*V
            V = jnp.zeros(x.shape)
            
        V = V + 0.5*dt*F
        V = (1-alpha)*V + alpha*F*jnp.linalg.norm(V)/jnp.linalg.norm(F)
        x = x + dt*V
        F = -df(x,params, key)*0.9 + 0.1*F
        V = V + 0.5*dt*F

        error = max(abs(F))
        if error < atol: break

        if logoutput: print(f(x,params, key),error)

    del V, F  
    return [x,f(x,params, key),i]

In [None]:
spec, flat = flatten(params)

In [None]:
@jax.jit
def energy(x, params, key):
    return mse_fn(unflatten(spec, x), key)

denergy = jax.jit(jax.grad(energy))

In [None]:
funs = ['mul', 'cos', 'sin', 'exp', 'square']*2
e = EQL(n_layers=2, functions=funs, features=1, use_l0=True, drop_rate=0.03)
key = random.PRNGKey(0)

In [None]:
N = 10000
xdim = 3
x = (random.uniform(key, (N, xdim))-.5) * 3

y = x[:,0] + jnp.cos(x[:,1]) - 4.2*jnp.exp(-x[:,2]**2)
#y = jnp.cos(x) + 1 - x**2 

In [None]:
#params = e.init({'params':key}, x)
params = e.init({'params':key, 'l0': key}, x);

In [None]:
def mse_fn(params, key):
    pred = e.apply(params, x, rngs={'l0': key})
    return jnp.mean((pred-y)**2) + 1e-2*e.apply(params, rngs={'l0': key}, method=e.l0_reg)

In [None]:
def l2_fn(params):
    return sum(
        jnp.square(w).mean() for w in jax.tree_leaves(params["params"])
    )

In [None]:
def loss(params, key):
    return mse_fn(params, key)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [None]:
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

In [None]:

for i in range(1000):
    key, _ = jax.random.split(key)
    loss_val, grads = loss_grad_fn(params, key)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 9 == 0:
        print(loss_val)

7.4362807
3.8227262
2.8572721
2.7632446
2.6407542
2.6195922
2.6266215
2.6010928
2.6209812


KeyboardInterrupt: 

In [None]:
params = unflatten(spec, res[0])

In [None]:
symb = get_symbolic_expr(params, funs, use_l0=True)[0]
print(symb)

-0.0833462253212929*(-0.1974916055734*(-0.783446629679925*x0 + 0.339917521660799*x1 + x2 - 0.0213253617043436)**2 + 0.171413645148277*(-0.163496926426888*x0 - 0.0909073203802109*x1 - 0.493996649980545*x2 + 0.00570506509393454)*(0.172961160540581*x0 + 0.0903534665703773*x1 - 0.310840874910355*x2 + 0.00215031718835235) + 0.380757182836533*(0.178383767604828*x0 - 0.254964083433151*x1 + 0.759177207946777*x2 - 0.0386561527848244)*(0.411183476448059*x0 + 0.230472907423973*x1 + 0.50562059879303*x2 - 0.0271826516836882) - 0.284336374057733*(0.525652755316313*x0 + x1 + 0.382663090690845*x2 + 0.0362508920736499)**2 - 0.0652149003232912*exp(-0.644405901432037*x0 - 0.621454954147339*x1 + 0.191928163170815*x2) + 0.304508230937823*exp(-0.202667981386185*x0 + 0.54872190952301*x1 + 0.823585569858551*x2) - 0.589221239089966*sin(-0.798131346702576*x0 + 0.534926295280457*x1 + 0.243915945291519*x2 + 0.0107335988432169) + 0.268852531909943*sin(0.139502450823784*x0 + 0.197748526930809*x1 + 0.033848278224468