In [19]:
import jax.numpy as jnp
from timing import poly_gen_D, timing_van_jax
from out_loop import objective_func


dummy_cov_arr = jnp.load("dummy_cov.npz")
dummy_chi2res = jnp.load("dummy_chi2res.npz")
dummy_pcoeffs = jnp.load("dummy_pceoffs.npz")
jnp.shape(dummy_chi2res)
coeff = dict(zip((range(jnp.size(dummy_pcoeffs))), (dummy_pcoeffs[k] for k in dummy_pcoeffs)))
cov = dict(zip((range(jnp.size(dummy_cov_arr))), (dummy_cov_arr[k] for k in dummy_cov_arr)))

In [20]:
p = jnp.array([2, 3])

poly = timing_van_jax([p], 3)[0]

jnp.matmul(coeff[1], poly.T)
coeff[1]

#Setting target to be something completely random, we expect the objective function to be extremely high
target_d = range(40)
targ_sig = jnp.ones(40)

print("We expect this to be high ", objective_func(p, target_d, targ_sig))

#Setting target to be exactly what it would be with p = 2, 3, we expect the objective function to be arbitrarily low
target_d = jnp.matmul(jnp.array(list(coeff.values())), poly.T)
print("We expect this to be low ", objective_func(p, target_d, targ_sig))

jnp.size(jnp.array(list(coeff.values())), axis=0)

We expect this to be high  3608.5068
We expect this to be low  4.3565507e-13


40

In [21]:
#Testing

import scipy.optimize as opt
import jax.random as random
from jax import jacfwd, jacrev

#Toy data
p_real = jnp.array([300, 480])
poly_real = timing_van_jax([p_real], 3)[0]
key = random.PRNGKey(20)
target_d_real = jnp.matmul(jnp.array(list(coeff.values())), poly_real.T) + 50*random.normal(key, shape=(40,))
targ_sig_real = 200*jnp.ones(40)


p_guess = jnp.array([0, 0])
p_opt = opt.minimize(objective_func, p_guess, args=(target_d_real, targ_sig_real), method='Nelder-Mead')



#Math for covariance
def res_sq_sum(P):
    p_poly = timing_van_jax([P], 3)[0]
    return jnp.sum(jnp.square(jnp.matmul(jnp.array(list(coeff.values())), p_poly.T) - target_d_real))
def Hessian(func):
    return jacfwd(jacrev(res_sq_sum))
p_opt_cov = jnp.linalg.inv(Hessian(res_sq_sum)(p_opt.x))*res_sq_sum(p_opt.x)/(38)
p_opt_unc = jnp.sqrt(jnp.diagonal(p_opt_cov))


print("Returned parameters, we expect this to be the same as p_real ", p_opt.x)
print("Uncertainty in returned parameters, currently broken, expect this to be higher in correct versions ", p_opt_unc)
print("Target distribution ", target_d_real)

poly_opt = timing_van_jax([p_opt.x], 3)[0]
print("Distribution created from returned parameters ", jnp.matmul(jnp.array(list(coeff.values())), poly_opt.T).reshape(10, 4)) #Reshaped for comparison with target distribution



Returned parameters, we expect this to be the same as p_real  [299.99749678 480.00053822]
Uncertainty in returned parameters, currently broken, expect this to be higher in correct versions  [6.910124e-04 7.512431e-05]
Target distribution  [-4.1501885e+09 -1.1093762e+10  6.1360230e+08  3.7448952e+09
 -1.3400934e+10  2.7636050e+09 -2.0341055e+10 -1.0259394e+09
 -1.5658098e+10  5.1413217e+09 -7.0312059e+09  8.0410264e+07
 -4.5577994e+09  1.9528223e+09  3.3425312e+08  2.0825925e+09
  6.5380832e+08  7.3925594e+08  3.7174941e+08  5.9552230e+08
 -7.0241500e+09  3.8048297e+09 -2.2802297e+10 -2.9366188e+10
  4.7016894e+10  7.6632735e+09 -3.3695437e+09  1.0345909e+10
 -5.6908104e+09  2.1669012e+09  3.7736755e+09  9.1455283e+08
  8.5443904e+08  1.6842616e+08  8.3342400e+07  2.7056820e+07
  8.2515930e+06  2.5090695e+06 -4.1598564e+01 -8.1214439e+01]
Distribution created from returned parameters  [[-4.15018163e+09 -1.10937836e+10  6.13652480e+08  3.74492928e+09]
 [-1.34009651e+10  2.76354022e+09 -2