In [19]:
import sacc
import emcee
import matplotlib.pylab as plt 
import numpy as np
import pyccl as ccl
import jax.numpy as jnp

In [3]:
import jax_cosmo as jc 

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
def ccl_get_nz(sfile, tracertype="wl"):
    """Collect the per-bin redshift distributions of the DES tracers in a sacc file.

    Parameters
    ----------
    sfile : sacc.Sacc
        Loaded sacc data. Tracers are looked up by the names ``DESwl__<i>``
        (weak lensing) or ``DESgc__<i>`` (galaxy clustering); bins are assumed
        to be numbered contiguously from 0.
    tracertype : str, optional
        ``"wl"`` selects the weak-lensing tracers; any other value selects the
        galaxy-clustering tracers.

    Returns
    -------
    list of tuple
        One ``(z, nz)`` pair per tomographic bin, in bin order, taken from each
        tracer's ``z`` and ``nz`` attributes.
    """
    tname = "DESwl__" if tracertype == "wl" else "DESgc__"
    # Number of bins of this kind (substring match on the tracer names).
    nbin = sum(tname in name for name in sfile.tracers)

    nz_distributions = []
    for i in range(nbin):
        distribution = sfile.tracers[f"{tname}{i}"]
        # The original appended the undefined name `jaxred` (NameError); the
        # (z, nz) pair is what this function is meant to collect.
        nz_distributions.append((distribution.z, distribution.nz))

    return nz_distributions


def calculate_lmax_gc(sfile, kmax, cosmo=None):
    """Maximum multipole for each galaxy-clustering bin from a kmax cut.

    For every ``DESgc__<i>`` tracer, the n(z)-weighted mean redshift is mapped
    to a comoving distance chi. The Limber-style cut ``kmax * chi - 0.5`` is
    then truncated to an integer and floored at 10.

    Parameters
    ----------
    sfile : sacc.Sacc
        Loaded sacc data with tracers named ``DESgc__<i>`` (numbered from 0).
    kmax : float
        Maximum wavenumber. NOTE(review): treated as 1/Mpc, matching the CCL
        convention this pipeline was ported from; confirm with the analysis spec.
    cosmo : jax_cosmo.Cosmology, optional
        Fiducial cosmology used for distances. Defaults to ``jc.Planck15()``.

    Returns
    -------
    list of int
        One lmax per clustering bin, in bin order.
    """
    if cosmo is None:
        cosmo = jc.Planck15()
    nbin_gc = sum("DESgc__" in name for name in sfile.tracers)

    lmaxs = []
    for i in range(nbin_gc):
        tracer = sfile.tracers[f"DESgc__{i}"]
        zmid = jnp.average(jnp.asarray(tracer.z), weights=jnp.asarray(tracer.nz))
        chi = jc.background.radial_comoving_distance(cosmo, 1.0 / (1.0 + zmid))
        # jax_cosmo distances are in Mpc/h; divide by h to get Mpc so that
        # kmax * chi is dimensionless, as with the CCL distances in Mpc.
        chi_mpc = float(jnp.ravel(chi)[0]) / float(cosmo.h)
        # int() truncates like the previous astype(int); never go below ell = 10.
        lmaxs.append(max(10, int(kmax * chi_mpc - 0.5)))
    return lmaxs

def scale_cuts(sfile, kmax=0.15, lmin_wl=30, lmax_wl=2000):
    """Remove B-modes and apply scale cuts to a sacc file.

    The file is modified in place via ``remove_selection`` and is also returned.

    Cuts applied:
    * all data of types ``cl_bb``, ``cl_be``, ``cl_eb`` and ``cl_0b``;
    * ``cl_00`` auto-spectra of each ``DESgc__<i>`` above that bin's lmax
      (from ``calculate_lmax_gc(sfile, kmax)``);
    * ``cl_0e`` spectra ``(DESgc__<i>, DESwl__<j>)`` above the same lmax;
    * ``cl_ee`` spectra ``(DESwl__<i>, DESwl__<j>)`` with ``i <= j`` outside
      ``[lmin_wl, lmax_wl]``.

    Parameters
    ----------
    sfile : sacc.Sacc
        Data to cut (mutated).
    kmax : float, optional
        Wavenumber cut forwarded to ``calculate_lmax_gc``.
    lmin_wl, lmax_wl : int, optional
        Multipole range kept for cosmic shear.

    Returns
    -------
    sacc.Sacc
        The same ``sfile`` object, after the cuts.
    """
    # First we remove all B-modes.
    for bmode_type in ("cl_bb", "cl_be", "cl_eb", "cl_0b"):
        sfile.remove_selection(data_type=bmode_type)

    nbin_wl = sum("DESwl__" in name for name in sfile.tracers)
    lmaxs_gc = calculate_lmax_gc(sfile, kmax)

    for i, lmax in enumerate(lmaxs_gc):
        print(f"Maximum ell is {lmax}")
        gc_name = f"DESgc__{i}"

        # Galaxy clustering: auto-spectrum of this lens bin.
        sfile.remove_selection(
            data_type="cl_00", tracers=(gc_name, gc_name), ell__gt=lmax
        )

        # Galaxy-galaxy lensing: cut with the lens bin's lmax.
        for j in range(nbin_wl):
            sfile.remove_selection(
                data_type="cl_0e", tracers=(gc_name, f"DESwl__{j}"), ell__gt=lmax
            )

    # Weak lensing: keep lmin_wl <= ell <= lmax_wl for every pair i <= j.
    for i in range(nbin_wl):
        for j in range(i, nbin_wl):
            pair = (f"DESwl__{i}", f"DESwl__{j}")
            sfile.remove_selection(data_type="cl_ee", tracers=pair, ell__gt=lmax_wl)
            sfile.remove_selection(data_type="cl_ee", tracers=pair, ell__lt=lmin_wl)

    return sfile

In [None]:
def ccl_load_data(fname="cls_DESY1", kmax=0.15, lmin_wl=30, lmax_wl=2000, jitter=1e-18):
    """Load a DES sacc file, apply scale cuts and build the likelihood inputs.

    Parameters
    ----------
    fname : str, optional
        Base name of the FITS file, read from ``data/<fname>.fits``.
    kmax, lmin_wl, lmax_wl : optional
        Scale-cut settings forwarded to ``scale_cuts``.
    jitter : float, optional
        Value added to the covariance diagonal before inversion.

    Returns
    -------
    tuple
        ``(data, precision, nz_gc, nz_wl, bw_gc, bw_gc_wl, bw_wl)``. The n(z)
        lists come from ``ccl_get_nz``; the band windows, data vector and
        covariance come from the cut file.
    """
    saccfile = sacc.Sacc.load_fits(f"data/{fname}.fits")
    # n(z) are read before the cuts; scale_cuts mutates saccfile in place.
    jax_nz_wl = ccl_get_nz(saccfile, tracertype="wl")
    jax_nz_gc = ccl_get_nz(saccfile, tracertype="gc")
    # `scale_cuts` is the scale-cut function defined in this notebook; the
    # previous call targeted `ccl_scale_cuts`, which is not defined here.
    # NOTE(review): if a distinct ccl_scale_cuts exists elsewhere, reconcile.
    saccfile_cut = scale_cuts(saccfile, kmax=kmax, lmin_wl=lmin_wl, lmax_wl=lmax_wl)
    bw_gc, bw_gc_wl, bw_wl = ccl_extract_bandwindow(saccfile_cut)
    data, covariance = ccl_extract_data_covariance(saccfile_cut)
    # Regularize and invert in float64 NumPy. Adding a jnp array here would
    # promote the covariance to JAX's default float32 before the inversion.
    newcov = np.asarray(covariance, dtype=np.float64) + jitter * np.eye(data.shape[0])
    precision = np.linalg.inv(newcov)
    return data, precision, jax_nz_gc, jax_nz_wl, bw_gc, bw_gc_wl, bw_wl