In [18]:
import h5py
import jax.numpy as jnp
import jax.random as random
import scipy.optimize as opt
import numpy as np
import jax

f = h5py.File('inputdata.h5', "r")

def van_jax(params, order):
    """
    Construct the Vandermonde matrix.
    If params is a 2-D array, the highest dimension indicates number of parameters.
    """
    try:
        dim = len(params[0])
    except:
        dim = 1
    params = jnp.array(params)

    #We will take params to the power of grlex_pow element-wise
    if dim == 1:
        grlex_pow = jnp.array(range(order+1))
    else:
        term_list = [[0]*dim]
        for i in range(1, numCoeffsPoly(dim, order)):
            term_list.append(mono_next_grlex(dim, term_list[-1][:]))
        grlex_pow = jnp.array(term_list)
    
    if dim == 1:
        V = jnp.zeros((len(params), numCoeffsPoly(dim, order)), dtype=jnp.float32)
        for a, p in enumerate(params): 
            V = V.at[a].set(p**grlex_pow)
        return V
    else:
        V = jnp.power(params, grlex_pow[:, jnp.newaxis])
        return jnp.prod(V, axis=2).T

def numCoeffsPoly(dim, order):
    """
    Number of coefficients a dim-dimensional polynomial of order order has (C(dim+order, dim)).
    """
    ntok = 1
    r = min(order, dim)
    for i in range(r):
        ntok = ntok * (dim + order - i) / (i + 1)
    return int(ntok)

#Next term of grlex ordering
def mono_next_grlex(m, x):
    #  Author:
    #
    #    John Burkardt
    #
    #     TODO --- figure out the licensing thing https://people.sc.fsu.edu/~jburkardt/py_src/monomial/monomial.html

    #  Find I, the index of the rightmost nonzero entry of X.
    i = 0
    for j in range(m, 0, -1):
        if 0 < x[j-1]:
            i = j
            break

    #  set T = X(I)
    #  set X(I) to zero,
    #  increase X(I-1) by 1,
    #  increment X(M) by T-1.
    if i == 0:
        x[m-1] = 1
        return x
    elif i == 1:
        t = x[0] + 1
        im1 = m
    elif 1 < i:
        t = x[i-1]
        im1 = i - 1

    x[i-1] = 0
    x[im1-1] = x[im1-1] + 1
    x[m-1] = x[m-1] + t - 1

    return x


param_real = jnp.array([23, 16, 15, 4, 7, 5, 2, 1, 1, 2])
x_sample = jnp.array([[2, 3], [4, 5], [6, 7], [2, 5], [9, 8], [10, 11], [1, 9], [20, 10], [40, 2], [40, 20], [0.1, 0], [2, 0]])


param_real_2 = jnp.array([23, 16, 15, 2])
x_sample_2 = jnp.array([[1.0], [2], [3], [4], [5], [6], [7], [8]])
x_sample_2_flat = x_sample_2.reshape(8,)

def real_func(X):
    return (van_jax(X, 3))@param_real_2

def model_func(X, *params):
    return ((van_jax(X, 3))@jnp.array(params).T).flatten()
    

y_real = real_func(x_sample_2)

key = random.PRNGKey(27)
y_real = y_real + 10*random.normal(key, shape=(y_real.size,))


VM = van_jax(x_sample_2, 3)

p_coeffs, res, rank, s  = jnp.linalg.lstsq(VM, y_real, rcond=None)

cov = jnp.linalg.inv(VM.T@VM)
fac = res/(VM.shape[0]-VM.shape[1])
cov = cov*fac



p, V = jnp.polyfit(x_sample_2_flat, y_real, 3, cov=True)

X = jnp.array([[2, 3]])
#print(timing_van_jax(X, 3)[0])

print("y_real: ", y_real, " end")
print("y_fit: ", model_func(x_sample_2, p.tolist()), " end")

print(p_coeffs)
print(jnp.flip(p))
print(jnp.sqrt(jnp.diagonal(cov)))
print(jnp.flip(jnp.sqrt(jnp.diagonal(V))))
print(VM.T@VM)


y_real:  [  55.10936242  125.49150181  260.25133322  455.05066567  730.29218516
 1088.5304862  1546.03744832 2115.64479861]  end
y_fit:  [   53.86055512   272.96166786   805.05160004  1796.11765469
  3392.14713483  5739.12734347  8983.04558363 13269.88915834]  end
[24.33121717 10.5071067  17.26127243  1.76095882]
[24.33121717 10.5071067  17.26127243  1.76095882]
[5.49501039 4.97225145 1.24722001 0.09150285]
[5.49445884 4.97166918 1.24707455 0.09149266]
[[8.00000e+00 3.60000e+01 2.04000e+02 1.29600e+03]
 [3.60000e+01 2.04000e+02 1.29600e+03 8.77200e+03]
 [2.04000e+02 1.29600e+03 8.77200e+03 6.17760e+04]
 [1.29600e+03 8.77200e+03 6.17760e+04 4.46964e+05]]




In [22]:
f = h5py.File('inputdata.h5', "r")
X = jnp.array(f['params'][:], dtype=jnp.float32)
Y = jnp.array(f['values'][0])

VM = van_jax(X, 3)

def curve_fit_func(X, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9):
    p = jnp.array([p0, p1, p2, p3, p4, p5, p6, p7, p8, p9])
    return van_jax(X, 3)@p
p_coeff_alt, p_cov_alt = opt.curve_fit(curve_fit_func, (X), Y)




p_coeffs, res_here, rank, s  = jnp.linalg.lstsq(VM, Y, rcond=None)

cov = jnp.linalg.inv(VM.T@VM)
fac_here = res_here/(VM.shape[0]-VM.shape[1])
cov = cov*fac_here

def res_sq(coeff):
    return jnp.sum(jnp.square(Y-VM@coeff))
def Hessian(func):
    return jax.jacfwd(jax.jacrev(res_sq))

jnp.printoptions(linewidth=1000)

print("y_real: ", Y, " end")
print("y_fit: ", VM@p_coeffs, " end")
print("y_fit using curve_fit: ", VM@p_coeff_alt, " end")

print("cov given by inverse Hessian: \n", jnp.linalg.inv(Hessian(res_sq)(p_coeffs))*fac_here)
print("cov given by scipy curve_fit: \n", p_cov_alt)
print(cov)


y_real:  [4.974412   2.492576   3.27064    4.552948   2.96389267 2.76797133
 1.63777933 4.88243267 4.26869267 2.45834    1.807726   6.34289467
 4.07663    3.09181    1.47149133 1.701142   2.62115733 4.36851933
 1.44821133 5.54064133 2.23663267 1.204874   3.43674467 3.30419133
 2.359856   3.63903333 3.23851933 4.266862   1.17176667 5.06567467
 5.45114467 2.126532   1.85267467 3.87413333 2.89222333 3.833762
 1.402982   4.41896867 1.50100867 3.649254   3.61486867 2.960642  ]  end
y_fit:  [4.60839772 2.74903858 3.23036213 4.21670845 2.76309843 2.46501569
 1.42870649 5.21336776 4.58572443 2.40359798 1.68862244 5.90623642
 3.597959   3.88407672 1.58392246 2.04241646 2.83854608 4.28711445
 1.6542727  5.28611255 1.96357019 1.0776957  4.29352468 3.23342089
 2.53469601 3.80443238 3.74082706 3.83150487 1.15989808 4.49936924
 5.85168336 2.09246    1.76559913 3.96713278 2.7274942  3.83881562
 1.75429853 3.50518078 1.575348   3.78073395 3.7034497  3.10782459]  end
y_fit using curve_fit:  [4.60839746

In [3]:
from polyfit import Polyfit

dummy_fits = Polyfit('dummy_p_coeffs.npz', 'dummy_chi2res.npz', input_h5='inputdata.h5', order=3, cov_npz = 'dummy_cov.npz')
test_surrogate, chi2_ndf, res, cov = dummy_fits.get_surrogate_func('/func0#0')


func0#0 
 [  80.62972621  104.38691504  -78.86316804   22.8934713  -118.4974257
    7.38508711  -13.12712689  -45.0598705     9.27279875    0.98583797] 
 [ 63.35120204 156.53672888  37.8115283  142.3519086   47.69563864
  12.44700243  46.14162883  19.2047809    8.43895188   2.34190177] 
end
func0#1 
 [  -9.06684082 -151.08618601  -55.02320836 -204.8598869   -59.50088039
   10.58171505  -83.664558    -26.22340188    1.5444825    -1.49583436] 
 [ 95.21013423 235.2580928   56.8267147  213.93981312  71.68148371
  18.70652385  69.34597187  28.86274781  12.68284919   3.51962985] 
end
func0#2 
 [150.4308984  289.43369588 -77.35680929 224.11656476 -67.15013875
  23.15696082  43.28677066 -51.44319631 -11.12765578  -6.7135658 ] 
 [ 74.36241416 183.7447229   44.38363341 167.09440773  55.98572276
  14.61044336  54.16160709  22.54280622   9.90574472   2.74895288] 
end
func0#3 
 [ 154.72120367  334.70876121  -72.98427325  243.493349   -138.8571259
   -3.38708843   60.00090054  -50.83377668   11.3573

  chi2res = {b: np.array([self.chi2[b], self.res[b]]) for b in self.chi2.keys()}
