In [1]:
import jax
import os
import jax.numpy as np
import jax_cosmo as jc
import numpy as onp
from desy1 import theory_cov, get_params_vec, get_data


In [2]:

fid_cosmo = jc.Cosmology(sigma8=0.801,
                          Omega_c=0.2545,
                          Omega_b=0.0485,
                          h=0.682,
                          n_s=0.971,
                          w0=-1., Omega_k=0., wa=0.)

fid_params  = get_params_vec(fid_cosmo, 
                                          [0., 0., 0., 0.],
                                          [0., 0., 0., 0.],
                                          [0.5, 0.],
                                          [1.2, 1.4, 1.6, 1.8, 2.0])
nz_source, nz_lens = get_data()



In [3]:
neff_s = [1.47, 1.46, 1.50, 0.73]

nzs_s = [jc.redshift.kde_nz(nz_source['Z_MID'].astype('float32'),
                            nz_source['BIN%d'%i].astype('float32'), 
                            bw=0.01,
                            gals_per_arcmin2=neff_s[i-1])
           for i in range(1,5)]
nzs_l = [jc.redshift.kde_nz(nz_lens['Z_MID'].astype('float32'),
                              nz_lens['BIN%d'%i].astype('float32'), bw=0.01)
           for i in range(1,6)]    


In [7]:
# Define some ell range
ell = np.logspace(1, 3)
args = [nzs_s, nzs_l, ell]
covmat = theory_cov(fid_params, *args)
covmat = onp.array(covmat)
#print(covmat)
onp.save("covmat.npy", covmat)

In [29]:
covmat = onp.load("covmat.npy")
#print(covmat)
print(covmat.shape)

(2025, 50, 50)


In [66]:
#cov_mat # shape = (n_cls*n_cls, n_ell)
n_cls = 45
n_ell = 50
c = onp.zeros((n_cls * n_ell, n_cls * n_ell))
#onp.set_printoptions(threshold=np.inf)
#print(onp.nonzero(covmat)[2]-onp.nonzero(covmat)[1])
# for i in range(n_cls):
#     for j in range(n_cls):
#         s1 = i * n_ell
#         s2 = j * n_ell
#         for k in range(n_ell):
#             c[s1 + k, s2 + k] = covmat[i + n_cls*j, k]
# covmat = c

C = []
for i in covmat:
    for j in range(50):
        C.append(i[j][j])
        
Cov = onp.array(C)
print(Cov.reshape((50,45,45)))



[[[7.86629973e-17 6.39554686e-17 4.94299036e-17 3.80524680e-17
   2.92062765e-17 2.23665371e-17 1.71064224e-17 1.30756713e-17
   9.99654206e-18 7.64893746e-18 5.86132316e-18 4.50005869e-18
   3.46277575e-18 2.67148677e-18 2.06700861e-18 1.60437289e-18
   1.24956287e-18 9.76890067e-19 7.66811400e-19 6.04417216e-19
   4.78433768e-19 3.80349383e-19 3.03654421e-19 2.43392503e-19
   1.95807801e-19 1.58043185e-19 1.27936859e-19 1.03836541e-19
   8.44710092e-20 6.88619690e-20 5.62424638e-20 4.60090974e-20
   3.76897073e-20 3.09123351e-20 2.53803403e-20 2.08573048e-20
   1.71539640e-20 1.41179977e-20 1.16264560e-20 9.57983114e-21
   7.89731704e-21 6.51316736e-21 5.37377102e-21 4.43533396e-21
   3.66203037e-21]
  [3.02451666e-21 2.49873152e-21 2.06492742e-21 1.70688461e-21
   1.48080591e-21 6.00641393e-17 4.80316626e-17 3.63927641e-17
   2.73684444e-17 2.04440779e-17 1.51776366e-17 1.12076258e-17
   8.23656986e-18 6.02863616e-18 4.39752876e-18 3.19893864e-18
   2.32156872e-18 1.68135809e-18 1.2

In [62]:
def check_symmetric(a, tol=1e-30):
    return np.all(np.abs(a-a.T) < tol)

In [63]:
print(check_symmetric(Cov))

True
