In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import sys; sys.path.append('../')
from misc import h5file
import numpy as np

import scipy.io as sio
import sympy
from sympy import Symbol, parse_expr, sympify, lambdify, Lambda

import pandas as pd
import subprocess

from scipy.optimize import curve_fit

MAIN_SEED = 1234

In [2]:
fp1 = "./IPI_output_files/PMS_data.h5"

In [3]:
_, _, un, _ = h5file(file_path=fp1, mode='r', return_dict=False)

['X_pre', 'best_subsets', 'un', 'y_pre']


In [4]:
data = sio.loadmat('../Datasets/burgers.mat')
x = data['x'].real.reshape(-1, 1)
del data

In [5]:
y = un[:, 0]

In [6]:
import feyn
ql = feyn.QLattice()

This version of feyn and the QLattice is available for academic, personal, and non-commercial use. By using the community version of this software you agree to the terms and conditions which can be found at `https://abzu.ai/eula`.


In [7]:
train = pd.DataFrame({'x':x.flatten(), 'y':y})
# starting from max_complexity = 10
models = ql.auto_run(train, output_name = 'y', max_complexity=2)

In [8]:
best = models[0]

In [9]:
sympy_model = best.sympify()
sympy_model.as_expr()

0.00221307 + 1.00782*exp(-4.1475*(-0.497272*x - 1)**2)

In [10]:
import pickle
# Save equation to file
save_to = f"./hall/hof.pkl"
# save_to = f"hof_{smoother_name}.pkl"
with open(save_to, 'wb') as f:
    pickle.dump(sympy_model.as_expr(), f)

In [11]:
# Load equation from file
with open(save_to, 'rb') as f:
    eq = pickle.load(f)
eq

0.00221307 + 1.00782*exp(-4.1475*(-0.497272*x - 1)**2)

In [12]:
from jaxfit import CurveFit
from jax import numpy as jnp

In [13]:
equation = eq

In [14]:
feyn_params = np.array(sorted([float(atom) for atom in sympify(equation).atoms() if atom.is_number]))
feyn_params = feyn_params[[0, 2, 3, 4]]
feyn_params

array([-4.14750004e+00, -4.97272015e-01,  2.21307017e-03,  1.00781989e+00])

In [15]:
def initial_function(x, a, b, c, d):
    return c+d*np.exp(a*np.square(b*x-1))

def jax_initial_function(x, a, b, c, d):
    return c+d*jnp.exp(a*jnp.square(b*x-1))

recovered_params = np.array(CurveFit().curve_fit(jax_initial_function, x.flatten(), un[:, 0], 
                                                  p0=feyn_params)[0])
recovered_params

array([-4.15259791e+00, -4.97142427e-01,  2.95012304e-03,  1.00797403e+00])