In [1]:
import h5py
import jax.numpy as jnp
import jax.random as random
import scipy.optimize as opt
import numpy as np
import jax

f = h5py.File('dummy_data/inputdata.h5', "r")

def van_jax(params, order):
    """
    Construct the Vandermonde matrix.
    If params is a 2-D array, the highest dimension indicates number of parameters.
    """
    try:
        dim = len(params[0])
    except:
        dim = 1
    params = jnp.array(params)

    #We will take params to the power of grlex_pow element-wise
    if dim == 1:
        grlex_pow = jnp.array(range(order+1))
    else:
        term_list = [[0]*dim]
        for i in range(1, numCoeffsPoly(dim, order)):
            term_list.append(mono_next_grlex(dim, term_list[-1][:]))
        grlex_pow = jnp.array(term_list)
    
    if dim == 1:
        V = jnp.zeros((len(params), numCoeffsPoly(dim, order)), dtype=jnp.float32)
        for a, p in enumerate(params): 
            V = V.at[a].set(p**grlex_pow)
        return V
    else:
        V = jnp.power(params, grlex_pow[:, jnp.newaxis])
        return jnp.prod(V, axis=2).T

def numCoeffsPoly(dim, order):
    """
    Number of coefficients a dim-dimensional polynomial of order order has (C(dim+order, dim)).
    """
    ntok = 1
    r = min(order, dim)
    for i in range(r):
        ntok = ntok * (dim + order - i) / (i + 1)
    return int(ntok)

#Next term of grlex ordering
def mono_next_grlex(m, x):
    #  Author:
    #
    #    John Burkardt
    #
    #     TODO --- figure out the licensing thing https://people.sc.fsu.edu/~jburkardt/py_src/monomial/monomial.html

    #  Find I, the index of the rightmost nonzero entry of X.
    i = 0
    for j in range(m, 0, -1):
        if 0 < x[j-1]:
            i = j
            break

    #  set T = X(I)
    #  set X(I) to zero,
    #  increase X(I-1) by 1,
    #  increment X(M) by T-1.
    if i == 0:
        x[m-1] = 1
        return x
    elif i == 1:
        t = x[0] + 1
        im1 = m
    elif 1 < i:
        t = x[i-1]
        im1 = i - 1

    x[i-1] = 0
    x[im1-1] = x[im1-1] + 1
    x[m-1] = x[m-1] + t - 1

    return x


param_real = jnp.array([23, 16, 15, 4, 7, 5, 2, 1, 1, 2])
x_sample = jnp.array([[2, 3], [4, 5], [6, 7], [2, 5], [9, 8], [10, 11], [1, 9], [20, 10], [40, 2], [40, 20], [0.1, 0], [2, 0]])


param_real_2 = jnp.array([23, 16, 15, 2])
x_sample_2 = jnp.array([[1.0], [2], [3], [4], [5], [6], [7], [8]])
x_sample_2_flat = x_sample_2.reshape(8,)

def real_func(X):
    return (van_jax(X, 3))@param_real_2

def model_func(X, *params):
    return ((van_jax(X, 3))@jnp.array(params).T).flatten()
    

y_real = real_func(x_sample_2)

key = random.PRNGKey(27)
y_real = y_real + 10*random.normal(key, shape=(y_real.size,))


VM = van_jax(x_sample_2, 3)

p_coeffs, res, rank, s  = jnp.linalg.lstsq(VM, y_real, rcond=None)

cov = jnp.linalg.inv(VM.T@VM)
fac = res/(VM.shape[0]-VM.shape[1])
cov = cov*fac



p, V = jnp.polyfit(x_sample_2_flat, y_real, 3, cov=True)

X = jnp.array([[2, 3]])
#print(timing_van_jax(X, 3)[0])

print("y_real: ", y_real, " end")
print("y_fit: ", model_func(x_sample_2, p.tolist()), " end")

print(p_coeffs)
print(jnp.flip(p))
print(jnp.sqrt(jnp.diagonal(cov)))
print(jnp.flip(jnp.sqrt(jnp.diagonal(V))))
print(VM.T@VM)




y_real:  [  57.001892  145.35753   261.34067   457.46725   719.05005  1085.7313
 1562.6531   2128.626   ]  end
y_fit:  [   61.074867   308.8413     890.6476    1951.6718    3637.0918
  6092.086     9461.832    13891.508   ]  end
[24.196215 21.842585 12.866302  2.170224]
[24.196342  21.841862  12.866453   2.1702123]
[21.158777   19.145866    4.8024755   0.35233578]
[21.159946   19.147196    4.802826    0.35236132]
[[8.00000e+00 3.60000e+01 2.04000e+02 1.29600e+03]
 [3.60000e+01 2.04000e+02 1.29600e+03 8.77200e+03]
 [2.04000e+02 1.29600e+03 8.77200e+03 6.17760e+04]
 [1.29600e+03 8.77200e+03 6.17760e+04 4.46964e+05]]


In [2]:
f = h5py.File('dummy_data/inputdata.h5', "r")
X = jnp.array(f['params'][:], dtype=jnp.float32)
Y = jnp.array(f['values'][0])

VM = van_jax(X, 3)

def curve_fit_func(X, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9):
    p = jnp.array([p0, p1, p2, p3, p4, p5, p6, p7, p8, p9])
    return van_jax(X, 3)@p
p_coeff_alt, p_cov_alt = opt.curve_fit(curve_fit_func, (X), Y)




p_coeffs, res_here, rank, s  = jnp.linalg.lstsq(VM, Y, rcond=None)

cov = jnp.linalg.inv(VM.T@VM)
fac_here = res_here/(VM.shape[0]-VM.shape[1])
cov = cov*fac_here

def res_sq(coeff):
    return jnp.sum(jnp.square(Y-VM@coeff))
def Hessian(func):
    return jax.jacfwd(jax.jacrev(res_sq))

jnp.printoptions(linewidth=1000)

print("y_real: ", Y, " end")
print("y_fit: ", VM@p_coeffs, " end")
print("y_fit using curve_fit: ", VM@p_coeff_alt, " end")

print("cov given by inverse Hessian: \n", jnp.linalg.inv(Hessian(res_sq)(p_coeffs))*fac_here)
print("cov given by scipy curve_fit: \n", p_cov_alt)
print(cov)


y_real:  [4.974412  2.492576  3.27064   4.552948  2.9638927 2.7679713 1.6377794
 4.8824325 4.2686925 2.45834   1.807726  6.3428946 4.07663   3.09181
 1.4714913 1.701142  2.6211574 4.3685193 1.4482113 5.5406413 2.2366326
 1.204874  3.4367447 3.3041914 2.359856  3.6390333 3.2385194 4.266862
 1.1717666 5.065675  5.4511447 2.126532  1.8526747 3.8741333 2.8922234
 3.833762  1.402982  4.4189687 1.5010086 3.649254  3.6148686 2.960642 ]  end
y_fit:  [4.608388  2.7490559 3.2304    4.2167416 2.763174  2.4650955 1.4287033
 5.213358  4.5857306 2.4035664 1.6886349 5.906246  3.597928  3.8841324
 1.5839882 2.0424957 2.8386269 4.287113  1.654314  5.2860575 1.9636059
 1.0777225 4.293498  3.2334328 2.5347557 3.8044186 3.7408056 3.8315372
 1.1599102 4.499341  5.851699  2.0925064 1.7655678 3.9671764 2.7274647
 3.838787  1.7543144 3.5052052 1.5754032 3.780696  3.703453  3.1079063]  end
y_fit using curve_fit:  [4.610564  2.74924   3.2318702 4.2449646 2.74537   2.4525375 1.4727325
 5.206974  4.583289  2.4055

In [3]:
from modules.polyfit import Polyfit

dummy_fits = Polyfit('dummy_data/dummy_p_coeffs.npz', 'dummy_data/dummy_chi2res.npz',
 input_h5='dummy_data/inputdata.h5', order=3, cov_npz = 'dummy_data/dummy_cov.npz')
test_surrogate, chi2_ndf, res, cov = dummy_fits.get_surrogate_func('/func0#0')


['/func0', '/func1']
fitting /func0#0
func0#0 
 [  80.62972621  104.38691504  -78.86316804   22.8934713  -118.4974257
    7.38508711  -13.12712689  -45.0598705     9.27279875    0.98583797] 
 [ 63.351202   156.53672883  37.81152827 142.35190859  47.69563861
  12.44700242  46.14162884  19.2047809    8.43895188   2.34190177] 
end
fitting /func0#1
func0#1 
 [  -9.06684082 -151.08618601  -55.02320836 -204.8598869   -59.50088039
   10.58171505  -83.664558    -26.22340188    1.5444825    -1.49583436] 
 [ 95.21013417 235.25809271  56.82671467 213.93981311  71.68148366
  18.70652384  69.34597188  28.8627478   12.68284919   3.51962985] 
end
fitting /func0#2
func0#2 
 [150.4308984  289.43369588 -77.35680929 224.11656476 -67.15013875
  23.15696082  43.28677066 -51.44319631 -11.12765578  -6.7135658 ] 
 [ 74.36241411 183.74472283  44.38363338 167.09440772  55.98572273
  14.61044336  54.16160709  22.54280621   9.90574472   2.74895288] 
end
fitting /func0#3
func0#3 
 [ 154.72120367  334.70876121  -72