In [None]:
import sys
# Make the locally modified jax_cosmo checkout importable; the `import jax_cosmo`
# further down is expected to resolve to it (its path is printed at the end of
# this cell so that can be checked).
# NOTE(review): append() puts this directory at the END of sys.path, so a
# jax_cosmo that is already installed in the environment would be found first.
# Confirm via the printed jc.__file__. The absolute path is machine-specific.
sys.path.append('/global/homes/k/kunhaoz/des/projects/jax_cosmo_late_mod/')

%pylab inline
import jax
import jax_cosmo as jc
import jax.numpy as np
import numpy as onp
import os

import numpyro
import numpyro.distributions as dist

# Report which JAX build and which jax_cosmo module file the kernel picked up.
#print("jax-cosmo version:", jc.__version__)
for label, value in (
    ("JAX version:", jax.__version__),
    ("Using Modified Jax-cosmo from: ", jc.__file__),
):
    print(label, value)

In [None]:
# Baseline cosmology: jax-cosmo's Planck15 preset called with no overrides.
# It is the reference that the modified-cosmology spectra are compared to below.
cosmo_P15 = jc.Planck15()

In [None]:
# Now let's try to build the equivalent with jax-cosmo

# Let's grab the data file
from astropy.io import fits
if not os.path.isfile('2pt_NG_mcal_1110.fits'):
    !wget http://desdr-server.ncsa.illinois.edu/despublic/y1a1_files/chains/2pt_NG_mcal_1110.fits

nz_source=fits.getdata('2pt_NG_mcal_1110.fits', 6)
nz_lens=fits.getdata('2pt_NG_mcal_1110.fits', 7)

# This is the effective number of sources from the cosmic shear paper
neff_s = [1.47, 1.46, 1.50, 0.73]

nzs_s = [jc.redshift.kde_nz(nz_source['Z_MID'].astype('float32'),
                            nz_source['BIN%d'%i].astype('float32'), 
                            bw=0.01,
                            gals_per_arcmin2=neff_s[i-1])
           for i in range(1,5)]

nzs_l = [jc.redshift.kde_nz(nz_lens['Z_MID'].astype('float32'),
                              nz_lens['BIN%d'%i].astype('float32'), bw=0.01)
           for i in range(1,6)]

In [None]:
# Cosmology with a late-time modification in "bin_custom" form.
# NOTE(review): the intent is to keep every other parameter equal to Planck15,
# but the values below are hard-coded rather than read from the Planck15 preset.
# Confirm they agree, otherwise a comparison against cosmo_P15 mixes both effects.
from jax_cosmo.core import Cosmology

# Bin specifications handed to Cosmology as z_bin / k_bin. The third entry is
# read below as the number of bins; the first two are presumably the range
# limits (TODO confirm against the modified jax_cosmo).
_z_bin = onp.array([0.0,   1.0, 5])
_k_bin = onp.array([0.025, 0.8, 5])
# Alternative k specification left by the author:
# _k_bin = onp.array([-2, 0, 5])

n_z_bins, n_k_bins = int(_z_bin[2]), int(_k_bin[2])

# All-zero amplitude vectors, one entry per bin. Within this cell tmpa is only
# printed; a_late is given explicitly below.
tmpa = np.array([0.0] * n_z_bins)
print(tmpa)
tmpb = np.array([0.0] * n_k_bins)
print(tmpb)

# Background parameters, with w0 fixed to -1.
base_params = dict(
    sigma8=0.801,
    Omega_c=0.2545,
    Omega_b=0.0485,
    h=0.682,
    n_s=0.971,
    w0=-1.0,
    Omega_k=0.,
    wa=0.,
)

# Late-time modification: only the last a_late entry is non-zero, and b_late is
# all zeros.
late_time_params = dict(
    a_late=[0.0, 0.0, 0.0, 0.0, 0.1],
    z_mod_form="bin_custom",
    z_bin=_z_bin,
    b_late=tmpb,
    k_mod_form="bin_custom",
    k_bin=_k_bin,
)

cosmo_modified2 = Cosmology(**base_params, **late_time_params)

# Every argument is passed by keyword; inspect the Cosmology documentation
# for the meaning of each one.


# Nuisance parameters, all held at fixed values here.
dz = [0.0] * 4                       # one redshift shift per source n(z)
A, eta = 0.5, 0.0                    # first two arguments of the IA bias model
bias = [1.2, 1.4, 1.6, 1.8, 2.0]     # one constant linear bias per lens n(z)
m = [0.0] * 4                        # multiplicative_bias handed to WeakLensing

# Apply the per-bin shift to each source redshift distribution.
nzs_s_sys = list(map(jc.redshift.systematic_shift, nzs_s, dz))

# IA model; the third argument (z0) is fixed at 0.62.
b_ia = jc.bias.des_y1_ia_bias(A, eta, 0.62)
# Constant linear bias objects for the lenses.
b = list(map(jc.bias.constant_linear_bias, bias))

# Lensing probe first, number-counts probe second.
weak_lensing = jc.probes.WeakLensing(nzs_s_sys,
                                     ia_bias=b_ia,
                                     multiplicative_bias=m)
number_counts = jc.probes.NumberCounts(nzs_l, b)
probes = [weak_lensing, number_counts]

# Multipoles: log-spaced between 10^1 and 10^3 (logspace's default sample count).
ell = np.logspace(1, 3)

# Data vectors for the baseline and then the modified cosmology, using the
# same probes and the same multipoles.
cls_P15, cls_modified2 = (
    jc.angular_cl.angular_cl(cosmology, ell, probes)
    for cosmology in (cosmo_P15, cosmo_modified2)
)

# Fractional deviation of the modified spectra from the baseline.
residue2 = (cls_modified2 - cls_P15) / cls_P15

# Fractional residuals of the modified-cosmology spectra relative to the
# baseline, one panel per spectrum, for the first n_rows * n_cols entries
# of residue2.
n_rows, n_cols = 2, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5))

for i in range(n_rows):
    for j in range(n_cols):
        k = i*n_cols + j  # row-major panel position == index into residue2
        ax = axes[i][j]
        ax.plot(ell, residue2[k])
        ax.set_title("spectrum %d" % k)
        # Label only the outer edges so the grid stays uncluttered.
        if i == n_rows - 1:
            ax.set_xlabel(r"$\ell$")
        if j == 0:
            ax.set_ylabel("(modified - P15) / P15")

fig.tight_layout()

In [None]:
# Modified-cosmology spectra themselves, one panel per spectrum, for the first
# n_rows * n_cols entries of cls_modified2.
n_rows, n_cols = 2, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5))

for i in range(n_rows):
    for j in range(n_cols):
        k = i*n_cols + j  # row-major panel position == index into cls_modified2
        ax = axes[i][j]
        ax.plot(ell, cls_modified2[k])
        ax.set_title("spectrum %d" % k)
        # Label only the outer edges so the grid stays uncluttered.
        if i == n_rows - 1:
            ax.set_xlabel(r"$\ell$")
        if j == 0:
            ax.set_ylabel(r"$C_\ell$ (modified)")

fig.tight_layout()

In [None]:
# Show the first entry of the modified-cosmology spectra as the cell output.
cls_modified2[0]