In [2]:
import h5py
import jax.numpy as jnp
import jax.random as random
import scipy.optimize as opt
import numpy as np
import jax

f = h5py.File('dummy_data/inputdata.h5', "r")

def van_jax(params, order):
    """
    Construct the Vandermonde matrix.
    If params is a 2-D array, the highest dimension indicates number of parameters.
    """
    try:
        dim = len(params[0])
    except:
        dim = 1
    params = jnp.array(params)

    #We will take params to the power of grlex_pow element-wise
    if dim == 1:
        grlex_pow = jnp.array(range(order+1))
    else:
        term_list = [[0]*dim]
        for i in range(1, numCoeffsPoly(dim, order)):
            term_list.append(mono_next_grlex(dim, term_list[-1][:]))
        grlex_pow = jnp.array(term_list)
    
    if dim == 1:
        V = jnp.zeros((len(params), numCoeffsPoly(dim, order)), dtype=jnp.float32)
        for a, p in enumerate(params): 
            V = V.at[a].set(p**grlex_pow)
        return V
    else:
        V = jnp.power(params, grlex_pow[:, jnp.newaxis])
        return jnp.prod(V, axis=2).T

def numCoeffsPoly(dim, order):
    """
    Number of coefficients a dim-dimensional polynomial of order order has (C(dim+order, dim)).
    """
    ntok = 1
    r = min(order, dim)
    for i in range(r):
        ntok = ntok * (dim + order - i) / (i + 1)
    return int(ntok)

#Next term of grlex ordering
def mono_next_grlex(m, x):
    #  Author:
    #
    #    John Burkardt
    #
    #     TODO --- figure out the licensing thing https://people.sc.fsu.edu/~jburkardt/py_src/monomial/monomial.html

    #  Find I, the index of the rightmost nonzero entry of X.
    i = 0
    for j in range(m, 0, -1):
        if 0 < x[j-1]:
            i = j
            break

    #  set T = X(I)
    #  set X(I) to zero,
    #  increase X(I-1) by 1,
    #  increment X(M) by T-1.
    if i == 0:
        x[m-1] = 1
        return x
    elif i == 1:
        t = x[0] + 1
        im1 = m
    elif 1 < i:
        t = x[i-1]
        im1 = i - 1

    x[i-1] = 0
    x[im1-1] = x[im1-1] + 1
    x[m-1] = x[m-1] + t - 1

    return x


param_real = jnp.array([23, 16, 15, 4, 7, 5, 2, 1, 1, 2])
x_sample = jnp.array([[2, 3], [4, 5], [6, 7], [2, 5], [9, 8], [10, 11], [1, 9], [20, 10], [40, 2], [40, 20], [0.1, 0], [2, 0]])


param_real_2 = jnp.array([23, 16, 15, 2])
x_sample_2 = jnp.array([[1.0], [2], [3], [4], [5], [6], [7], [8]])
x_sample_2_flat = x_sample_2.reshape(8,)

def real_func(X):
    return (van_jax(X, 3))@param_real_2

def model_func(X, *params):
    return ((van_jax(X, 3))@jnp.array(params).T).flatten()
    

y_real = real_func(x_sample_2)

key = random.PRNGKey(27)
y_real = y_real + 10*random.normal(key, shape=(y_real.size,))


VM = van_jax(x_sample_2, 3)

p_coeffs, res, rank, s  = jnp.linalg.lstsq(VM, y_real, rcond=None)

def res_sq(coeff):
    return jnp.sum(jnp.square(y_real-VM@coeff))
def Hessian(func):
    return jax.jacfwd(jax.jacrev(func))

cov = jnp.linalg.inv(Hessian(res_sq)(p_coeffs))
fac = 2*res/(VM.shape[0]-VM.shape[1])
cov = cov*fac



p, V = jnp.polyfit(x_sample_2_flat, y_real, 3, cov=True)

X = jnp.array([[2, 3]])
#print(timing_van_jax(X, 3)[0])

print("y_real: ", y_real, " end")
print("y_fit: ", model_func(x_sample_2, p.tolist()), " end")

print(p_coeffs)
print(jnp.flip(p))
print(jnp.sqrt(jnp.diagonal(cov)))
print(jnp.flip(jnp.sqrt(jnp.diagonal(V))))
print(VM.T@VM)




y_real:  [  57.001892  145.35753   261.34067   457.46725   719.05005  1085.7313
 1562.6531   2128.626   ]  end
y_fit:  [   61.074776   308.8406     890.64526   1951.6664    3637.0815
  6092.068     9461.804    13891.466   ]  end
[24.1969    21.842394  12.8663025  2.1702225]
[24.19626   21.841858  12.866441   2.1702142]
[21.158699   19.1458      4.8024573   0.35233438]
[21.159895  19.147148   4.802814   0.3523605]
[[8.00000e+00 3.60000e+01 2.04000e+02 1.29600e+03]
 [3.60000e+01 2.04000e+02 1.29600e+03 8.77200e+03]
 [2.04000e+02 1.29600e+03 8.77200e+03 6.17760e+04]
 [1.29600e+03 8.77200e+03 6.17760e+04 4.46964e+05]]


In [24]:
import scipy.linalg as scila
f = h5py.File('dummy_data/inputdata.h5', "r")
X = jnp.array(f['params'][:], dtype=jnp.float32)
Y = jnp.array(f['values'][0])

VM = van_jax(X, 3)

def curve_fit_func(X, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9):
    p = jnp.array([p0, p1, p2, p3, p4, p5, p6, p7, p8, p9])
    return van_jax(X, 3)@p
p_coeff_alt, p_cov_alt = opt.curve_fit(curve_fit_func, (X), Y)



p_coeffs_lstsq, res_here, rank, s  = jnp.linalg.lstsq(VM, Y, rcond=None)

##RIDGE
def ridge_obj(coeff, target, VM, alpha):
        res_sq = jnp.sum(jnp.square(target - VM@coeff))
        penalty = alpha*(coeff@coeff)
        return res_sq + penalty

obj_args = (Y, VM, 0.01)

guess = jnp.zeros((VM.shape[1],), dtype=jnp.float32)
c_opt = opt.minimize(ridge_obj, guess, args=obj_args, method='Nelder-Mead')
p_coeffs_ridge = c_opt.x
res_ridge = [jnp.sum(jnp.square(Y-VM@p_coeffs_ridge))]
def ridge(coeff):
    return ridge_obj(coeff, Y, VM, 0.01)
#################################33


cov = np.linalg.inv(VM.T@VM)
fac_here = res_here/(VM.shape[0]-VM.shape[1])
cov = cov*fac_here
fac_ridge = res_ridge[0]/(VM.shape[0]-VM.shape[1])

def res_sq(coeff):
    return jnp.sum(jnp.square(Y-VM@coeff))
def Hessian(func):
    return jax.jacfwd(jax.jacrev(func))

with jnp.printoptions(precision=8, linewidth=1000, suppress=True, floatmode="fixed"):
    #print(fac_here)
    print("\n", jnp.linalg.cond(Hessian(res_sq)(p_coeffs_lstsq)) - jnp.finfo(Hessian(res_sq)(p_coeffs_lstsq).dtype).eps)
    #print("y_real: ", Y, " end")
"""
print("y_fit: ", VM@p_coeffs, " end")
print("y_fit using curve_fit: ", VM@p_coeff_alt, " end")"""
print("scaling factor, ridge: ", fac_ridge)
with jnp.printoptions(precision=3, linewidth=1000, suppress=True, floatmode="fixed"):
    print("y: ", Y)
    print("coeff, lstsq: ", p_coeffs_lstsq)
    print("coeff, ridge: ", p_coeffs_ridge)
    print("cov given by inverse Hessian, lstsq: \n", np.linalg.inv(Hessian(res_sq)(p_coeffs_lstsq))*fac_here, " end")
    print("cov given by inverse Hessian, ridge: \n", np.linalg.inv(Hessian(ridge)(p_coeffs_ridge)), " end")
    print("VM: ", VM)

    print("Difference in inverse Hessians using lstsq regression: \n", np.linalg.inv(Hessian(res_sq)(p_coeffs_lstsq)) - np.load("polyfit_inv_hess_lst_sq.npy"))
    print("Difference in inverse Hessians using ridge regression: \n", np.linalg.inv(Hessian(ridge)(p_coeffs_ridge)) - np.load("polyfit_inv_hess_ridge.npy"))
#print("cov given by scipy curve_fit: \n", p_cov_alt, " end")
#print("cov given by old method: \n", cov, " end")
#print(p_coeffs, p_coeff_alt)



 182565870.0
scaling factor, ridge:  0.24130988
y:  [4.974 2.493 3.271 4.553 2.964 2.768 1.638 4.882 4.269 2.458 1.808 6.343 4.077 3.092 1.471 1.701 2.621 4.369 1.448 5.541 2.237 1.205 3.437 3.304 2.360 3.639 3.239 4.267 1.172 5.066 5.451 2.127 1.853 3.874 2.892 3.834 1.403 4.419 1.501 3.649 3.615 2.961]
coeff, lstsq:  [  80.631  104.382  -78.869   22.884 -118.503    7.387  -13.131  -45.061    9.274    0.986]
coeff, ridge:  [ 8.470  0.241 -1.904  0.427  1.828 -1.643 -2.090  1.411  2.279  1.357]
cov given by inverse Hessian, lstsq: 
 [[-1539.067 -3601.886   706.977 -2850.625  1038.876  -120.157  -756.955   396.314   -78.815     8.058]
 [-3601.886 -6655.653  2742.832 -4033.265  3436.893  -656.976  -771.057  1129.576  -368.962    60.409]
 [  706.977  2742.832   355.224  2915.536   141.016  -184.967   955.569   -54.653   -76.152    24.223]
 [-2850.625 -4033.265  2915.536 -1162.333  3667.144  -689.033   213.031  1155.815  -422.326    55.615]
 [ 1038.876  3436.893   141.016  3667.144   214.

In [3]:
from modules.polyfit import Polyfit

dummy_fits = Polyfit('dummy_data/dummy_p_coeffs.npz', 'dummy_data/dummy_chi2res.npz',
 input_h5='dummy_data/inputdata.h5', order=3, cov_npz = 'dummy_data/dummy_cov.npz')
test_surrogate, chi2_ndf, res, cov = dummy_fits.get_surrogate_func('/func0#0')


['/func0', '/func1']
fitting /func0#0
func0#0 
 [  80.62972621  104.38691504  -78.86316804   22.8934713  -118.4974257
    7.38508711  -13.12712689  -45.0598705     9.27279875    0.98583797] 
 [ 63.351202   156.53672883  37.81152827 142.35190859  47.69563861
  12.44700242  46.14162884  19.2047809    8.43895188   2.34190177] 
end
fitting /func0#1
func0#1 
 [  -9.06684082 -151.08618601  -55.02320836 -204.8598869   -59.50088039
   10.58171505  -83.664558    -26.22340188    1.5444825    -1.49583436] 
 [ 95.21013417 235.25809271  56.82671467 213.93981311  71.68148366
  18.70652384  69.34597188  28.8627478   12.68284919   3.51962985] 
end
fitting /func0#2
func0#2 
 [150.4308984  289.43369588 -77.35680929 224.11656476 -67.15013875
  23.15696082  43.28677066 -51.44319631 -11.12765578  -6.7135658 ] 
 [ 74.36241411 183.74472283  44.38363338 167.09440772  55.98572273
  14.61044336  54.16160709  22.54280621   9.90574472   2.74895288] 
end
fitting /func0#3
func0#3 
 [ 154.72120367  334.70876121  -72

In [7]:
print(jnp.linalg.cond(Hessian(res_sq)(p_coeffs)))
print(jnp.linalg.cond(VM.T@VM))

182565870.0
182565870.0


In [16]:
Hessian(res_sq)(p_coeffs)

DeviceArray([[  84.      ,  -83.98069 ,  127.74591 ,   84.94728 ,
              -127.58001 ,  199.8724  ,  -86.901634,  128.94978 ,
              -199.40294 ,  321.06458 ],
             [ -83.98069 ,   84.94728 , -127.58001 ,  -86.901634,
               128.9498  , -199.40294 ,   89.86794 , -131.85355 ,
               201.38052 , -320.0215  ],
             [ 127.74591 , -127.58001 ,  199.87242 ,  128.94978 ,
              -199.40292 ,  321.06458 , -131.85356 ,  201.38052 ,
              -320.0215  ,  528.26416 ],
             [  84.94728 ,  -86.901634,  128.94978 ,   89.86794 ,
              -131.85355 ,  201.38054 ,  -93.892555,  136.32425 ,
              -205.79607 ,  322.96472 ],
             [-127.58001 ,  128.9498  , -199.40292 , -131.85355 ,
               201.38052 , -320.02148 ,  136.32425 , -205.79608 ,
               322.96472 , -526.2069  ],
             [ 199.8724  , -199.40294 ,  321.06458 ,  201.38054 ,
              -320.02148 ,  528.2642  , -205.79607 ,  322.96472 ,
   