In [None]:
import sys

from functools import partial


import os.path
import jax
# Enable 64-bit floats in JAX. JAX's documentation says to set this at startup,
# before any arrays are created, which is why it sits in the import cell.
# NOTE(review): recent JAX releases no longer provide `from jax.config import config`;
# `jax.config.update("jax_enable_x64", True)` is the portable spelling. Confirm
# against the JAX version this notebook is pinned to.
# NOTE(review): a later cell defines `def config(x)` (the reparameterization
# policy), which rebinds this name. `config.update(...)` only works before that cell runs.
from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as np
import jax_cosmo as jc
from optax import exponential_decay

import numpy as onp


In [None]:
import numpyro
import numpyro.distributions as dist
# Ask numpyro for 64-bit precision too. This appears redundant with the
# `jax_enable_x64` update in the first cell, but it is harmless.
numpyro.util.enable_x64()
# Select the GPU backend before any computation runs.
# NOTE(review): presumably fails on machines without a GPU-enabled jaxlib.
# Confirm, and consider making the platform configurable.
numpyro.set_platform('gpu')

import jax
# Print the devices JAX actually picked up, to confirm the requested backend is in use.
print("DEBUG Jax device:",jax.devices())


In [None]:
# Fetch the DES Y1 2pt FITS file (once) and build the source/lens redshift
# distributions from it. The path and URL are constants so that the cache check,
# the download and the reads cannot drift apart.
import urllib.request

from astropy.io import fits

DES_Y1_2PT_FILE = '2pt_NG_mcal_1110.fits'
DES_Y1_2PT_URL = ('http://desdr-server.ncsa.illinois.edu/despublic/'
                  'y1a1_files/chains/2pt_NG_mcal_1110.fits')

if not os.path.isfile(DES_Y1_2PT_FILE):
    # Download to a temporary name and rename only after the transfer completes.
    # An interrupted download therefore never leaves a truncated file that the
    # isfile() check above would accept on the next run.
    # This is pure Python (unlike get_ipython().system('wget ...')), so it also
    # works outside IPython and on machines without wget. Download errors now
    # raise here instead of surfacing later as a missing/corrupt file.
    partial_path = DES_Y1_2PT_FILE + '.part'
    urllib.request.urlretrieve(DES_Y1_2PT_URL, partial_path)
    os.replace(partial_path, DES_Y1_2PT_FILE)

# HDU 6 = source n(z), HDU 7 = lens n(z); each has a Z_MID column plus BIN1..BINn.
nz_source = fits.getdata(DES_Y1_2PT_FILE, 6)
nz_lens = fits.getdata(DES_Y1_2PT_FILE, 7)

# Effective number of sources (gals_per_arcmin2) per source bin, from the cosmic shear paper.
neff_s = [1.47, 1.46, 1.50, 0.73]
# One KDE-smoothed n(z) per source bin; the number of bins follows neff_s.
nzs_s = [jc.redshift.kde_nz(nz_source['Z_MID'].astype('float32'),
                            nz_source['BIN%d' % i].astype('float32'),
                            bw=0.01,
                            gals_per_arcmin2=neff)
         for i, neff in enumerate(neff_s, start=1)]
# Five lens bins (BIN1..BIN5).
nzs_l = [jc.redshift.kde_nz(nz_lens['Z_MID'].astype('float32'),
                            nz_lens['BIN%d' % i].astype('float32'), bw=0.01)
         for i in range(1, 6)]

# Multipoles: 50 log-spaced values from 10 to 1000 (logspace default num=50).
ell = np.logspace(1, 3)

In [None]:
# Let's define our model using numpyro
def model():
    """Forward model: draw every parameter from its prior and predict C_ell.

    Sample sites, in this exact order:

    - cosmology (Omega_c, sigma8, Omega_b, h, n_s, w0);
    - intrinsic alignment (A, eta);
    - lens linear galaxy bias (b1..b5);
    - shear multiplicative bias (m1..m4);
    - source photo-z shifts (dz1..dz4).

    Keep this order. Under `seed`, each unobserved site takes the next RNG
    split, so reordering would change the random draws.

    Reads the module-level `nzs_s`, `nzs_l` and `ell`.

    Returns:
        tuple ``(cl, P, C)``: the C_ell mean returned by
        ``gaussian_cl_covariance_and_mean``, the dense precision matrix, and
        the dense Gaussian covariance matrix (f_sky = 0.25).
    """
    # --- Cosmology (Omega_k and wa are pinned to 0 when building `cosmo`).
    Omega_c = numpyro.sample('Omega_c', dist.Uniform(0.1, 0.9))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.4, 1.0))
    Omega_b = numpyro.sample('Omega_b', dist.Uniform(0.03, 0.07))
    h = numpyro.sample('h', dist.Uniform(0.55, 0.91))
    n_s = numpyro.sample('n_s', dist.Uniform(0.87, 1.07))
    w0 = numpyro.sample('w0', dist.Uniform(-2.0, -0.33))

    # --- Intrinsic alignment parameters.
    A = numpyro.sample('A', dist.Uniform(-5., 5.))
    eta = numpyro.sample('eta', dist.Uniform(-5., 5.))

    # --- Linear galaxy bias, one per lens bin.
    lens_bias = [numpyro.sample('b%d' % i, dist.Uniform(0.8, 3.0))
                 for i in range(1, 6)]

    # --- Systematics: multiplicative shear bias, then per-bin photo-z shifts.
    m = [numpyro.sample('m%d' % i, dist.Normal(0.012, 0.023))
         for i in range(1, 5)]
    dz_priors = (('dz1', 0.001, 0.016),
                 ('dz2', -0.019, 0.013),
                 ('dz3', 0.009, 0.011),
                 ('dz4', -0.018, 0.022))
    dz = [numpyro.sample(site, dist.Normal(loc, scale))
          for site, loc, scale in dz_priors]

    cosmo = jc.Cosmology(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                         h=h, n_s=n_s, w0=w0, Omega_k=0., wa=0.)

    # Source n(z)s displaced by their photo-z systematic shifts.
    shifted_source_nzs = [jc.redshift.systematic_shift(nz, shift, zmax=2.0)
                          for nz, shift in zip(nzs_s, dz)]

    lensing = jc.probes.WeakLensing(
        shifted_source_nzs,
        ia_bias=jc.bias.des_y1_ia_bias(A, eta, 0.62),  # IA pivot z0 = 0.62 is fixed
        multiplicative_bias=m)
    clustering = jc.probes.NumberCounts(
        nzs_l, [jc.bias.constant_linear_bias(b_i) for b_i in lens_bias])

    cl, cov_sparse = jc.angular_cl.gaussian_cl_covariance_and_mean(
        cosmo, ell, [lensing, clustering], f_sky=0.25, sparse=True)

    # Invert in the sparse representation, then densify both matrices.
    precision = jc.sparse.to_dense(jc.sparse.inv(cov_sparse))
    covariance = jc.sparse.to_dense(cov_sparse)
    return cl, precision, covariance


In [None]:
from numpyro.handlers import seed, trace, condition
# Generate the synthetic "observed" data vector at the fiducial parameters.
# `condition` pins every sample site of `model`, so the output is deterministic.
# The trace check below fails loudly if a site is ever added to `model` without
# a fiducial value here. Such a site would otherwise be silently drawn at random
# under `seed`, and the "fiducial" data would no longer be fiducial.
fiducial_params = {
    'Omega_c': 0.2545, 'sigma8': 0.801, 'h': 0.682, 'Omega_b': 0.0485,
    'w0': -1., 'n_s': 0.971,
    'A': 0.5, 'eta': 0.,
    'm1': 0.0, 'm2': 0.0, 'm3': 0.0, 'm4': 0.0,
    'dz1': 0.0, 'dz2': 0.0, 'dz3': 0.0, 'dz4': 0.0,
    'b1': 1.2, 'b2': 1.4, 'b3': 1.6, 'b4': 1.8, 'b5': 2.0,
}
fiducial_model = condition(model, fiducial_params)

with trace() as fiducial_trace, seed(rng_seed=42):
    cl_obs, P, C = fiducial_model()

# Any sample site that `condition` did not observe was drawn at random.
unconditioned_sites = [name for name, site in fiducial_trace.items()
                       if site['type'] == 'sample' and not site['is_observed']]
assert not unconditioned_sites, (
    f'No fiducial value for sample site(s) {unconditioned_sites}; '
    'add them to fiducial_params.')


In [None]:
def model_spl(cl_obs=None):
    """Posterior model: priors, theory C_ell, and Gaussian likelihood on the data.

    Draws the same sample sites (same names, priors and order) as `model`. It
    then computes theory angular power spectra for the weak-lensing and
    number-counts probes and scores them against ``cl_obs``. The score uses a
    multivariate normal with the fixed, fiducial covariance ``C``.

    Reads the module-level `nzs_s`, `nzs_l`, `ell` and `C`.

    Args:
        cl_obs: observed, flattened C_ell data vector. When ``None``, the 'cl'
            site is sampled instead of observed (e.g. for predictive checks).

    Returns:
        The value of the 'cl' sample site (``cl_obs`` itself when provided).
    """
    #  Cosmological params
    Omega_c = numpyro.sample('Omega_c', dist.Uniform(0.1, 0.9))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.4, 1.0))
    Omega_b = numpyro.sample('Omega_b', dist.Uniform(0.03, 0.07))
    h = numpyro.sample('h', dist.Uniform(0.55, 0.91))
    n_s = numpyro.sample('n_s', dist.Uniform(0.87, 1.07))
    w0 = numpyro.sample('w0', dist.Uniform(-2.0, -0.33))

    # Intrinsic Alignment
    A = numpyro.sample('A', dist.Uniform(-5., 5.))
    eta = numpyro.sample('eta', dist.Uniform(-5., 5.))

    # linear galaxy bias, one per lens bin
    bias = [numpyro.sample('b%d' % i, dist.Uniform(0.8, 3.0))
            for i in range(1, 6)]

    # parameters for systematics: multiplicative shear bias and photo-z shifts
    m = [numpyro.sample('m%d' % i, dist.Normal(0.012, 0.023))
         for i in range(1, 5)]
    dz1 = numpyro.sample('dz1', dist.Normal(0.001, 0.016))
    dz2 = numpyro.sample('dz2', dist.Normal(-0.019, 0.013))
    dz3 = numpyro.sample('dz3', dist.Normal(0.009, 0.011))
    dz4 = numpyro.sample('dz4', dist.Normal(-0.018, 0.022))
    dz = [dz1, dz2, dz3, dz4]

    # Now that params are defined, here is the forward model
    cosmo = jc.Cosmology(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                         h=h, n_s=n_s, w0=w0, Omega_k=0., wa=0.)

    # Build source nz with redshift systematic bias
    nzs_s_sys = [jc.redshift.systematic_shift(nzi, dzi, zmax=2.0)
                 for nzi, dzi in zip(nzs_s, dz)]

    # Define IA model, z0 is fixed
    b_ia = jc.bias.des_y1_ia_bias(A, eta, 0.62)

    # Bias for the lenses
    b = [jc.bias.constant_linear_bias(bi) for bi in bias]

    # Define the lensing and number counts probe
    probes = [jc.probes.WeakLensing(nzs_s_sys,
                                    ia_bias=b_ia,
                                    multiplicative_bias=m),
              jc.probes.NumberCounts(nzs_l, b)]

    # Flatten to match the data-vector layout of `cl_obs`.
    cl = jc.angular_cl.angular_cl(cosmo, ell, probes).flatten()

    # Gaussian likelihood with a parameter-independent covariance (the module-level
    # `C`, computed at the fiducial cosmology). Pass exactly ONE parametrization.
    # numpyro's MultivariateNormal uses `covariance_matrix` when it is given and
    # silently ignores `precision_matrix`. Also passing the precision matrix `P`
    # therefore only suggested it was used, when it never entered the likelihood.
    return numpyro.sample('cl', dist.MultivariateNormal(cl, covariance_matrix=C),
                          obs=cl_obs)


In [None]:

from numpyro.infer.reparam import LocScaleReparam, TransformReparam

def config(x):
    """Reparameterization policy for ``numpyro.handlers.reparam``.

    Args:
        x: a sample-site message. Only its ``'fn'`` (the distribution) and
            ``'name'`` entries are read.

    Returns:
        - ``TransformReparam()`` for sites whose distribution is exactly a
          ``TransformedDistribution``.
        - ``LocScaleReparam(centered=0)`` (fully non-centered) for sites whose
          distribution is exactly ``Normal`` and whose name does not contain
          'decentered'. That guard presumably skips the auxiliary
          ``<name>_decentered`` sites LocScaleReparam creates. Verify against
          the numpyro version in use.
        - ``None`` otherwise, leaving the site untouched.

    Types are matched exactly, so subclasses or wrapped distributions are not
    reparameterized.

    NOTE(review): defining this function rebinds the name ``config`` that the
    first cell imported from ``jax.config``.
    """
    site_fn_type = type(x['fn'])
    if site_fn_type is dist.TransformedDistribution:
        return TransformReparam()
    is_plain_normal = site_fn_type is dist.Normal
    if is_plain_normal and 'decentered' not in x['name']:
        return LocScaleReparam(centered=0)
    return None



# Apply the `config` reparameterization policy to the likelihood model.
model_reparam = numpyro.handlers.reparam(model_spl, config=config)

# ---------------------------------------------------------------------------
# Stage 1: stochastic variational inference (SVI)
# ---------------------------------------------------------------------------
import numpyro.infer.autoguide as autoguide
from numpyro.infer import Predictive, SVI, Trace_ELBO,  TraceMeanField_ELBO
from numpyro.optim import Adam

print('Do SVI...')

# Full-covariance Gaussian guide over all latent sites, initialised with init_to_median.
guide = autoguide.AutoMultivariateNormal(
    model_reparam,
    init_loc_fn=numpyro.infer.init_to_median(),
)
optimizer = Adam(1e-3)
svi = SVI(model_reparam, guide, optimizer, loss=Trace_ELBO(num_particles=10))

# The extra positional argument (cl_obs) is forwarded to model_spl.
n_steps = 20_000
svi_result = svi.run(jax.random.PRNGKey(0), n_steps, cl_obs)

# Draws from the fitted variational posterior.
samples = guide.sample_posterior(
    jax.random.PRNGKey(1),
    svi_result.params,
    sample_shape=(100_000,),
)

# ---------------------------------------------------------------------------
# Stage 2: NUTS on the NeuTra-reparameterized model (neural transport)
# ---------------------------------------------------------------------------
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer import MCMC, NUTS, init_to_sample

# Use the trained guide to build the NeuTra-reparameterized model that NUTS samples.
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(model_reparam)

nuts_kernel = NUTS(
    neutra_model,
    init_strategy=numpyro.infer.init_to_median(),
    dense_mass=True,
    max_tree_depth=5,
)
mcmc_neutra = MCMC(
    nuts_kernel,
    num_warmup=200,
    num_samples=1000,
    num_chains=1,
    chain_method="vectorized",
    progress_bar=True,
)

print('SVI+NUTS')
mcmc_neutra.run(jax.random.PRNGKey(42), cl_obs)
mcmc_neutra.print_summary()

# NUTS draws live in the guide's shared latent space; map them back to the
# model's parameters.
zs = mcmc_neutra.get_samples()["auto_shared_latent"]
samples_nuts_neutra = neutra.transform_sample(zs)
