In [1]:
from LIMxCMBL.init import *
from LIMxCMBL.kernels import *

from scipy.interpolate import interp1d, LinearNDInterpolator
from scipy.integrate import quad, quad_vec, trapezoid, qmc_quad

import sys
from os.path import isfile
import jax
import jax.numpy as jnp

from jax import config
config.update("jax_enable_x64", True)
# config.update('jax_platform_name', 'cpu')




Lambda_idx = 24#int(sys.argv[1])
n_bins = 100#int(sys.argv[2])
ell_idx = 58#int(sys.argv[3])


Lambda = Lambdas[Lambda_idx]

zmin = 2.4#float(sys.argv[4])
zmax = 3.4#float(sys.argv[5])

kernels = {}
kernels['CII'] = np.array(KI)
kernels['CO'] = np.array(KI_CO)
kernels['Lya'] = np.array(KI_Lya)
kernels['HI'] = np.array(KI_HI)


line_str = 'CO'#sys.argv[6]
print(line_str)
_KI = kernels[line_str]

oup_fname = '/scratch/users/delon/LIMxCMBL/I_auto/comb_'
oup_fname += '%s_zmin_%.1f_zmax_%.1f_Lambda_idx_%d_n_b_%d_l_%d_jax_quad.npy'%(line_str,
                                                                                zmin, zmax, 
                                                                                Lambda_idx, 
                                                                                n_bins,
                                                                                ell_idx)


print(oup_fname)

chimin = ccl.comoving_angular_distance(cosmo, 1/(1+zmin))
chimax = ccl.comoving_angular_distance(cosmo, 1/(1+zmax))

inner_dkparp_integral = np.load('/oak/stanford/orgs/kipac/users/delon/LIMxCMBL/inner_dkparp_integral.npy')
inner_dkparp_integral = inner_dkparp_integral.astype(np.float64)
inner_dkparp_integral = np.moveaxis(inner_dkparp_integral, 0, -1)

@jax.jit
def f_KILo(chi, external_chi, Lambda):
    return (Lambda / jnp.pi 
            * jnp.interp(x = chi, xp = chis, 
                         fp = _KI, left = 0, right = 0) 
            * jnp.sinc(Lambda * (external_chi - chi) / np.pi))


CO
/scratch/users/delon/LIMxCMBL/I_auto/comb_CO_zmin_2.4_zmax_3.4_Lambda_idx_24_n_b_100_l_58_jax_quad.npy


In [2]:
chi_bin_edges = np.linspace(chimin*(1+1e-8), chimax*(1 - 1e-8), n_bins + 1)
chi_bin_centers = (chi_bin_edges[1:] + chi_bin_edges[:-1])/2
dchi_binned = np.mean(np.diff(chi_bin_edges))

In [3]:
from interpax import interp2d
from interpax import interp1d as interp1dx

In [4]:
from tqdm import trange, tqdm

In [5]:
chimin

5858.14183362481

In [6]:
inner_dkparp_integral.shape

(256, 128, 100)

In [19]:
@jax.jit
def f_integrand(x):
    chi, chip, _chib = x[0], x[1], x[2]
    chi = chi.reshape(-1, 1)
    chip = chip.reshape(-1, 1)
    _chib = _chib.reshape(-1, 1)

    _delta = jnp.abs(1 - chi/_chib) #(p)
    _delta = jnp.where(_delta < 1e-6, 1e-6, 
                     jnp.where(_delta > 0.7, 0.7, _delta))

    _idx = ((chimin <= 2*_chib - chi) 
            & (2*_chib - chi <= chimax)) #(p)
    cross_integrand = (2 * jnp.interp(x = chi, xp = chis, fp = _KI, left = 0, right = 0) 
                       * interp2d(xq = _chib.reshape(-1), yq=jnp.log(_delta).reshape(-1), 
                           x = chibs, y = jnp.log(deltas), f=inner_dkparp_integral,
                           method='linear',) 
                       / (_chib**2))
    cross_integrand *= jnp.where(_idx,
                                 f_KILo(2*_chib - chi, 
                                        external_chi = chip,
                                        Lambda=Lambda), 0) #(p)

    _delta = jnp.abs(1 - chip/_chib) #(p)
    _delta = jnp.where(_delta < 1e-6, 1e-6, 
                     jnp.where(_delta > 0.7, 0.7, _delta))

    _idx = ((chimin <= 2*_chib - chip) 
            & (2*_chib - chip <= chimax)) #(p)


    cross_integrand_2 = (2 * jnp.interp(x = chip, xp = chis, fp = _KI, left = 0, right = 0) 
                       * interp2d(xq = _chib.reshape(-1), yq=jnp.log(_delta).reshape(-1), 
                           x = chibs, y = jnp.log(deltas), f=inner_dkparp_integral,
                           method='linear',) 
                       / (_chib**2))

    cross_integrand_2 *= jnp.where(_idx,
                               f_KILo(2*_chib - chip, 
                                        external_chi = chi,
                                        Lambda=Lambda),0)

    cross_integrand += cross_integrand_2

    #LoLo
    plus = _chib.reshape(-1, 1)*(1+deltas.reshape(1, -1))
    mins = _chib.reshape(-1, 1)*(1-deltas.reshape(1, -1))
    _idxs = (chimin < plus) & (plus < chimax) & (chimin < mins) & (mins < chimax)
    LoLo_integrand  = jnp.where(_idxs,
                               f_KILo(plus, 
                                      external_chi = chi.reshape(-1, 1),
                                      Lambda=Lambda) 
                                * f_KILo(mins, 
                                         external_chi = chip.reshape(-1, 1),
                                         Lambda=Lambda),
                               0)

    LoLo_integrand += jnp.where(_idxs,
                               f_KILo(mins, 
                                      external_chi = chi.reshape(-1, 1),
                                      Lambda=Lambda) 
                                * f_KILo(plus, 
                                         external_chi = chip.reshape(-1, 1),
                                         Lambda=Lambda),0)
    LoLo_integrand *= (2 / _chib.reshape(-1, 1)) 
    LoLo_integrand *= deltas.reshape(1, -1)
    LoLo_integrand = jnp.einsum('pd, pdl->pld', LoLo_integrand,
                                interp1dx(xq = _chib.reshape(-1),x = chibs, 
                                f=inner_dkparp_integral,
                                method='linear',))

    LoLo_integrand = jnp.trapezoid(x = np.log(deltas), y = LoLo_integrand, axis=-1)
    return LoLo_integrand - cross_integrand

In [20]:
from scipy.stats import qmc


In [21]:
qrng = qmc.Halton(d = 3)

In [22]:

params_list = []
for i in range(n_bins):
    l1, r1 = chi_bin_edges[i], chi_bin_edges[i+1]
    for j in range(i, n_bins):
        l2, r2 = chi_bin_edges[j], chi_bin_edges[j+1]
        params = (i, j, l1, r1, l2, r2)
        params_list.append(params)
        
    

In [23]:
oup = np.zeros((n_bins, n_bins), dtype=np.float64)


In [33]:
a = np.array([l1, l2, chimin])
b = np.array([r1, r2, chimax])

sample = qrng.random(n = 2**25)
scaled_samples = qmc.scale(sample, a, b)


In [29]:
scaled_samples = jnp.array(scaled_samples)

In [30]:
tmp = f_integrand(scaled_samples)

In [31]:
scaled_samples.shape

(1048576, 3)

In [32]:
tmp.shape

(3, 100)

In [None]:
jax.default_backend()

In [None]:
_oup

In [None]:
comb_n_external = 301
comb_unbinned = np.zeros((len(ells), comb_n_external, comb_n_external))
comb_fname = '/scratch/users/delon/LIMxCMBL/I_auto/comb_'
comb_fname += '%s_zmin_%.1f_zmax_%.1f_Lambda_idx_%d_n_ext_%d_l_%d_jax_quad.npy'%('CO',
                                                                              zmin, zmax, 
                                                                              Lambda_idx, 
                                                                              comb_n_external,
                                                                              ell_idx)


comb_unbinned[ell_idx] = np.load(comb_fname)

In [None]:
def get_binned(base, n_external = 300):
    external_chis = np.linspace(chimin*(1+1e-8), chimax*(1 - 1e-8), n_external)
    
    oup = np.zeros((100, n_bins, n_bins), dtype=np.float64)
    for i, (l1, r1) in enumerate(zip(chi_bin_edges, chi_bin_edges[1:])):
        for j, (l2, r2) in enumerate(zip(chi_bin_edges, chi_bin_edges[1:])):
            idx1 = np.where((external_chis > l1) & (external_chis <= r1))[0]
            idx2 = np.where((external_chis > l2) & (external_chis <= r2))[0]
            oup[:,i,j] = (np.sum(base[:,
                                      idx1[0]:idx1[-1]+1,
                                      idx2[0]:idx2[-1]+1], 
                                 axis=(1, 2)) / len(idx1) / len(idx2))
    return oup

In [None]:
comb = get_binned(comb_unbinned, n_external = comb_n_external)


In [None]:
comb[ell_idx][0][0]

In [None]:
2**21

In [None]:
301**2

In [None]:
@jax.jit
def f_integrand(_chib, chip, chi):
    _delta = jnp.abs(1 - chi/_chib) #(1)
    _delta = jnp.where(_delta < 1e-6, 1e-6, 
                     jnp.where(_delta > 0.7, 0.7, _delta))

    _idx = ((chimin <= 2*_chib - chi) 
            & (2*_chib - chi <= chimax)) #(1)
    cross_integrand = (2 * jnp.interp(x = chi, xp = chis, fp = _KI, left = 0, right = 0) 
                       * interp2d(xq = _chib, yq=jnp.log(_delta), 
                           x = chibs, y = jnp.log(deltas), f=inner_dkparp_integral,
                           method='linear',) 
                       / (_chib**2))
    cross_integrand *= jnp.where(_idx,
                                 f_KILo(2*_chib - chi, 
                                        external_chi = chip,
                                        Lambda=Lambda), 0) #(1)

    _delta = jnp.abs(1 - chip/_chib) #(1)
    _delta = jnp.where(_delta < 1e-6, 1e-6, 
                     jnp.where(_delta > 0.7, 0.7, _delta))

    _idx = ((chimin <= 2*_chib - chip) 
            & (2*_chib - chip <= chimax)) #(1)


    cross_integrand_2 = (2 * jnp.interp(x = chip, xp = chis, fp = _KI, left = 0, right = 0) 
                       * interp2d(xq = _chib, yq=jnp.log(_delta), 
                           x = chibs, y = jnp.log(deltas), f=inner_dkparp_integral,
                           method='linear',) 
                       / (_chib**2))

    cross_integrand_2 *= jnp.where(_idx,
                               f_KILo(2*_chib - chip, 
                                        external_chi = chi,
                                        Lambda=Lambda),0)

    cross_integrand += cross_integrand_2

    #LoLo
    plus = _chib*(1+deltas)
    mins = _chib*(1-deltas)
    _idxs = (chimin < plus) & (plus < chimax) & (chimin < mins) & (mins < chimax)

    LoLo_integrand  = jnp.where(_idxs,
                               f_KILo(plus, 
                                      external_chi = chi,
                                      Lambda=Lambda) 
                                * f_KILo(mins, 
                                         external_chi = chip,
                                         Lambda=Lambda),
                               0)
    LoLo_integrand += jnp.where(_idxs,
                               f_KILo(mins, 
                                      external_chi = chi,
                                      Lambda=Lambda) 
                                * f_KILo(plus, 
                                         external_chi = chip,
                                         Lambda=Lambda),0)
    LoLo_integrand *= (2 / _chib) * deltas
    LoLo_integrand *= interp1dx(xq = _chib,
                                x = chibs, f=inner_dkparp_integral,
                                method='linear',)

    LoLo_integrand = jnp.trapezoid(x = np.log(deltas), y = LoLo_integrand, axis=-1)
    return LoLo_integrand - cross_integrand

In [None]:
f_integrand(5733.5, 5734, 5733)

In [None]:
oup = np.zeros((n_bins, n_bins), dtype=np.float64)

params_list = []
for i in range(n_bins):
    l1, r1 = chi_bin_edges[i], chi_bin_edges[i+1]
    for j in range(i, n_bins):
        l2, r2 = chi_bin_edges[j], chi_bin_edges[j+1]
        params = (i, j, l1, r1, l2, r2)
        params_list.append(params)

In [None]:
from scipy.integrate import nquad

In [None]:
options={'limit':100000, 'epsabs': 0.0, 'epsrel':1e-3}

In [None]:
def elem(params):
    i, j, l1, r1, l2, r2 = params
    res, err, info = nquad(f_integrand,
                 [[10, chimax_sample],[l2, r2],[l1,r1]],
                opts=[options]*3, full_output = True)

    
    return (i, j, res / dchi_binned**2, info)

for params in tqdm(params_list):
    i, j, _oup, info = elem(params)
    oup[i,j] = oup[j,i] = _oup
    print(info.success, info.status, info.message)
    print('ninters', info.intervals.shape)
    break