In [2]:
import sys
# Make a local jax-cosmo checkout importable (path is relative to the notebook's
# working directory). Presumably this copy provides the emulator functions
# (`linear_matter_power_emu`, KGRID, ZGRID) imported in the next cell — confirm.
sys.path.append('../reference/jax-cosmo/jax_cosmo-master/')

In [3]:
import jax
import jax.numpy as jnp

# Enable 64-bit floats right after importing JAX and *before* importing
# jax_cosmo. The flag should be set at startup, before any arrays are created,
# and a library import may already build module-level arrays (possibly
# KGRID/ZGRID) that would otherwise be float32. `jax.config.update` replaces
# the deprecated `from jax.config import config` import path.
jax.config.update("jax_enable_x64", True)

import jax_cosmo as jc
from jax_cosmo.power import linear_matter_power, nonlinear_matter_power
import jax_cosmo.power as jcp

# the emulator part
from jax_cosmo.power import linear_matter_power_emu
from jax_cosmo.power import KGRID, ZGRID

import matplotlib.pyplot as plt
# Render all figure text through LaTeX.
plt.rc('text', usetex=True)
# NOTE(review): the active family is 'sans-serif', so the 'serif' list
# (Palatino) is not the default font — confirm which was intended.
plt.rc('font', **{'family': 'sans-serif', 'serif': ['Palatino']})
figSize = (12, 8)
fontSize = 20

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# Name of the platform JAX dispatches computations to by default.
jax.default_backend()

'cpu'

In [5]:
# All devices visible to JAX.
jax.devices()

[CpuDevice(id=0)]

In [6]:
# Scale factor a at which every spectrum in this notebook is evaluated; z = 1/a - 1.
scalefactor = 0.65
redshift = 1.0 / scalefactor - 1.0

In [7]:
# Show the redshift corresponding to `scalefactor`.
print(f'The redshift is {redshift:.3f}.')

The redshift is 0.538.


In [8]:
# Flat (Omega_k = 0) cosmology with a cosmological-constant dark energy (w0 = -1, wa = 0).
# NOTE(review): the source of these specific parameter values is not stated — document it.
cosmo = jc.Cosmology(sigma8=0.933523, Omega_c=0.096535, Omega_b=0.035931, h=0.816626, n_s=0.902618, 
                      w0=-1., Omega_k=0., wa=0.)

In [9]:
# 500 logarithmically spaced wavenumbers between 1e-4 and 50.
new_k = jnp.geomspace(1E-4, 50, 500)

# Emulator 

In [10]:
%%time
# Linear P(k) from the emulator (plain, un-jitted call).
prediction = linear_matter_power_emu(cosmo, new_k, scalefactor)

CPU times: user 959 ms, sys: 465 ms, total: 1.42 s
Wall time: 811 ms


### Using JIT

In [11]:
# Jitted emulator; compilation happens lazily on the first call.
emu_jit = jax.jit(linear_matter_power_emu)

In [12]:
%%time
# First call traces and compiles, so this timing includes compilation.
prediction_jit = emu_jit(cosmo, new_k, scalefactor)

CPU times: user 1.78 s, sys: 257 ms, total: 2.04 s
Wall time: 1.47 s


In [13]:
%%time
# Repeat call with identical argument shapes/dtypes reuses the compiled function.
prediction_jit = emu_jit(cosmo, new_k, scalefactor)

CPU times: user 1.28 ms, sys: 407 µs, total: 1.68 ms
Wall time: 1.02 ms


# JAX-Cosmo

In [14]:
%%time
# Linear P(k) from jax_cosmo's own implementation, for comparison with the emulator.
pklin_jax = linear_matter_power(cosmo, new_k, scalefactor)

CPU times: user 2.3 s, sys: 15.1 ms, total: 2.32 s
Wall time: 2.3 s


### Using JIT

In [15]:
# Jitted jax_cosmo linear power spectrum.
jc_jit = jax.jit(linear_matter_power)

In [16]:
%%time 
# First call includes tracing and compilation.
pklin_jit = jc_jit(cosmo, new_k, scalefactor)

CPU times: user 2.21 s, sys: 13.3 ms, total: 2.22 s
Wall time: 2.18 s


In [17]:
%%time 
# Cached compiled version: measures execution time only.
pklin_jit = jc_jit(cosmo, new_k, scalefactor)

CPU times: user 698 µs, sys: 108 µs, total: 806 µs
Wall time: 694 µs


# Non-Linear Matter Power Spectrum

### With the Emulator

In [18]:
# Module-level switch, presumably read by `nonlinear_matter_power` to choose the
# emulator path (it prints "Using the emulator" below). Under jax.jit, a Python
# global like this is only read while tracing.
jcp.USE_EMU = True

In [19]:
%%time
# Non-linear P(k) via the emulator path (un-jitted).
pk_non_linear_fine_emu = nonlinear_matter_power(cosmo, new_k, scalefactor)

Using the emulator
CPU times: user 1.39 s, sys: 17.3 ms, total: 1.4 s
Wall time: 1.4 s


In [20]:
# Jit a fresh function object rather than `nonlinear_matter_power` itself.
# `jcp.USE_EMU` is a Python global that is read only during tracing, and JAX's
# trace cache is keyed on the wrapped function, so several jits of the *same*
# function object may reuse whichever trace happened first. A dedicated lambda
# gives this wrapper its own cache. It is traced on its first call below,
# while USE_EMU is True.
emu_nl_jit = jax.jit(lambda cosmo, k, a: nonlinear_matter_power(cosmo, k, a))

In [21]:
%%time
# First jitted call (tracing + compilation).
# NOTE(review): the saved output shows this run was interrupted (KeyboardInterrupt);
# re-run before relying on its timing.
pk_non_linear_fine_emu = emu_nl_jit(cosmo, new_k, scalefactor) 

Using the emulator


KeyboardInterrupt: 

In [22]:
%%time
# Jitted emulator-path non-linear P(k); output of this call is used in the comparison plot.
pk_non_linear_fine_emu = emu_nl_jit(cosmo, new_k, scalefactor) 

CPU times: user 4.13 s, sys: 828 ms, total: 4.96 s
Wall time: 3.23 s


### Without the Emulator

In [23]:
# Switch `nonlinear_matter_power` to its non-emulator path (it prints
# "Not using the emulator" below). Already-traced jitted functions are unaffected.
jcp.USE_EMU = False

In [None]:
%%time
# Non-linear P(k) from jax_cosmo without the emulator (un-jitted).
pk_non_linear_fine_jax = nonlinear_matter_power(cosmo, new_k, scalefactor)

Not using the emulator


In [None]:
# Jit a fresh function object, not `nonlinear_matter_power` itself.
# USE_EMU is only read at trace time, and jitting the same function object again
# may hit a trace already cached for the emulator path. This wrapper is traced
# on its first call, while USE_EMU is False.
jc_nl_jit = jax.jit(lambda cosmo, k, a: nonlinear_matter_power(cosmo, k, a))

In [None]:
%%time
pk_non_linear_fine_jax = emu_nl_jit(cosmo, new_k, scalefactor) 

In [None]:
%%time
pk_non_linear_fine_jax = emu_nl_jit(cosmo, new_k, scalefactor) 

In [None]:
# Emulator vs. jax_cosmo, linear and non-linear P(k), at the chosen redshift.
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_title(f'Redshift, $z={redshift:.3f}$', fontsize=fontSize)
ax.loglog(new_k, prediction, label='Emulator, $P_{l}$', lw=3)
ax.loglog(new_k, pklin_jax, lw=3, linestyle='--', label='JAX Cosmo, $P_{l}$')
ax.loglog(new_k, pk_non_linear_fine_emu, label='Emulator, $P_{nl}$', lw=3)
ax.loglog(new_k, pk_non_linear_fine_jax, lw=3, linestyle='--', label='JAX Cosmo, $P_{nl}$')
# Restrict the x-range to the span of KGRID (presumably the emulator's k grid).
ax.set_xlim(float(jnp.min(KGRID)), float(jnp.max(KGRID)))
ax.legend(loc='best', prop={'family': 'sans-serif', 'size': 15})
# jax_cosmo takes k in h/Mpc, so P(k) is in (Mpc/h)^3.
# NOTE(review): the emulator curves are assumed to use the same units — confirm.
ax.set_ylabel(r'$P(k)\;[h^{-3}\,\textrm{Mpc}^{3}]$', fontsize=fontSize)
ax.set_xlabel(r'$k\;[h\,\textrm{Mpc}^{-1}]$', fontsize=fontSize)
ax.tick_params(axis='both', labelsize=fontSize)
plt.show()

# BlackJAX

In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
import blackjax
import pandas as pd 

In [None]:
# Synthetic data: 1000 draws from Normal(mean=10, sd=20). A seeded Generator
# makes the notebook reproducible under Restart & Run All.
observed = np.random.default_rng(42).normal(10, 20, size=1_000)

In [None]:
def logdensity_fn(x):
    """Gaussian log-likelihood of the global ``observed`` data.

    ``x`` is a dict with keys ``"loc"`` (mean) and ``"scale"`` (standard
    deviation); the pointwise log-densities are summed into one scalar.
    """
    loc, scale = x["loc"], x["scale"]
    pointwise = stats.norm.logpdf(observed, loc, scale)
    return pointwise.sum()

In [None]:
# Build the kernel
# Fixed NUTS step size with no warm-up/adaptation run in this notebook.
step_size = 1e-3
nsamples = 20_000
# Number of leading draws dropped as burn-in in the histogram cell (10%).
burnin = int(0.1 * nsamples)
# One inverse-mass entry per parameter (loc, scale), all set to 1.
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

In [None]:
# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate: run all `nsamples` NUTS transitions inside one compiled `lax.scan`
# instead of a Python loop that makes one jitted call and one `.item()`
# device-to-host sync per draw.
rng_key = jax.random.PRNGKey(0)
step = jax.jit(nuts.step)


def _nuts_transition(current_state, step_key):
    """Advance the chain by one NUTS step; emit the new position as the scan output."""
    next_state, _ = step(step_key, current_state)
    return next_state, next_state.position


sample_keys = jax.random.split(rng_key, nsamples)
state, positions = jax.lax.scan(_nuts_transition, state, sample_keys)

# Column-oriented trace {"loc": (nsamples,), "scale": (nsamples,)} as NumPy
# arrays. `pd.DataFrame(record)` builds the same table from this as it would
# from a list of per-draw dicts.
record = {name: np.asarray(draws) for name, draws in positions.items()}

In [None]:
# One row per NUTS draw, one column per parameter ('loc', 'scale').
df = pd.DataFrame(record)

In [None]:
# Marginal posterior histograms, burn-in draws dropped, one panel per parameter.
panels = [
    ('loc', r'$p(\mu)$', r'$\mu$'),
    ('scale', r'$p(\sigma)$', r'$\sigma$'),
]
plt.figure(figsize=(16, 6))
for grid_slot, (column, y_text, x_text) in enumerate(panels, start=1):
    plt.subplot(1, 2, grid_slot)
    plt.hist(df[column].values[burnin:], density=True, bins=20, ec='blue')
    plt.ylabel(y_text, fontsize=fontSize)
    plt.xlabel(x_text, fontsize=fontSize)
    for which_axis in ('x', 'y'):
        plt.tick_params(axis=which_axis, labelsize=fontSize)
plt.show()

### Pathfinder

In [None]:
def logdensity_fn_arr(x):
    """Gaussian log-likelihood of ``observed`` with parameters packed as ``x = [loc, scale]``."""
    loc, scale = x[0], x[1]
    return stats.norm.logpdf(observed, loc, scale).sum()

In [None]:
rng_key = jax.random.PRNGKey(314)
# Starting point for pathfinder, ordered as [loc, scale].
w0 = jnp.array([5.0, 15.0])
_, info = blackjax.vi.pathfinder.approximate(rng_key, logdensity_fn_arr, w0, ftol=1e-4)
# Pathfinder trajectory; its `elbo` field is inspected in the next cells.
path = info.path

In [None]:
# Number of path entries with a finite ELBO.
# NOTE(review): presumably the remaining entries are unused padding — confirm.
steps = (jnp.isfinite(path.elbo)).sum()

In [None]:
# For each pathfinder iteration with a finite ELBO, draw samples from its
# Gaussian approximation and print their empirical mean and covariance.
for i in range(int(steps)):
    # Slice iteration `i` out of every leaf of the stacked path state
    # (`jax.tree_map` is deprecated in favour of `jax.tree_util.tree_map`).
    path_state = jax.tree_util.tree_map(lambda leaf: leaf[i], path)
    # Independent key per iteration, derived from (not equal to) the key
    # already consumed by `approximate`.
    step_key = jax.random.fold_in(rng_key, i)
    sample_state, _ = blackjax.vi.pathfinder.sample(step_key, path_state, 10_000)
    mu_i, cov_i = sample_state.mean(0), jnp.cov(sample_state, rowvar=False)
    print(mu_i)
    print(cov_i)
    print('-'*50)

### Periodic Orbital MCMC

In [None]:
import blackjax.mcmc.integrators as integrators
from blackjax import orbital_hmc as orbital

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Sequentially apply ``kernel`` ``num_samples`` times and stack every state.

    ``kernel(key, state)`` must return ``(new_state, info)``; ``info`` is
    discarded. Returns the pytree of states produced after each step, stacked
    along a new leading axis of length ``num_samples``.
    """
    step_keys = jax.random.split(rng_key, num_samples)

    def _advance(carry, key):
        next_state, _info = kernel(key, carry)
        return next_state, next_state

    return jax.lax.scan(_advance, initial_state, step_keys)[1]

In [None]:
def plot_contour(logdensity, orbits=None, weights=None,
                 x1_range=(8.0, 13.0), x2_range=(18.0, 23.0),
                 ref_point=None, num_points=1000):
    """Contour plots of a 2-D density, optionally overlaid with samples.

    Draws two side-by-side panels of ``exp(logdensity(p) - logdensity(ref_point))``
    on a regular grid. When ``orbits`` is given, both panels also scatter the
    samples: the left one plain, the right one using ``weights`` as alpha.

    Parameters
    ----------
    logdensity : callable
        Maps a dict ``{"x1": ..., "x2": ...}`` to a scalar log density.
    orbits : dict, optional
        Samples with keys ``"x1"`` and ``"x2"``.
    weights : array_like, optional
        Per-sample transparency for the right panel. Flattened so that it lines
        up with the flattened sample coordinates; values should lie in [0, 1].
    x1_range, x2_range : (float, float)
        Grid limits along x1 and x2.
    ref_point : dict, optional
        Point whose log density is subtracted before exponentiating, to keep
        values in a plottable range. Defaults to ``{"x1": 10, "x2": 20}``.
    num_points : int
        Grid resolution along each axis.
    """
    x1 = jnp.linspace(x1_range[0], x1_range[1], num_points)
    x2 = jnp.linspace(x2_range[0], x2_range[1], num_points)
    if ref_point is None:
        ref_point = {"x1": 10, "x2": 20}
    ref = logdensity(ref_point)
    # Inner vmap runs over x1 (columns), outer over x2 (rows), so `y` has shape
    # (len(x2), len(x1)) as `contour(X, Y, Z)` expects.
    y = jax.vmap(
        jax.vmap(lambda a, b: jnp.exp(logdensity({"x1": a, "x2": b}) - ref), (0, None)),
        (None, 0),
    )(x1, x2)
    fig, ax = plt.subplots(1, 2, figsize=(17, 6))
    for panel in ax:
        contours = panel.contour(x1, x2, y, levels=10, colors="k")
        panel.clabel(contours, inline=1, fontsize=10)
    if orbits is not None:
        alpha = None if weights is None else np.ravel(np.asarray(weights))
        ax[0].set_title("Unweighted samples")
        ax[0].scatter(orbits["x1"], orbits["x2"], marker=".")
        ax[1].set_title("Weighted samples")
        ax[1].scatter(orbits["x1"], orbits["x2"], marker=".", alpha=alpha)

In [None]:
def logdensity_fn_x1x2(x1, x2):
    """Gaussian log-likelihood of ``observed`` with mean ``x1`` and standard deviation ``x2``."""
    return jnp.sum(stats.norm.logpdf(observed, loc=x1, scale=x2))

In [None]:
def logdensity(x):
    """Adapt ``logdensity_fn_x1x2`` to a dict position ``{"x1": ..., "x2": ...}``."""
    return logdensity_fn_x1x2(**x)

In [None]:
# Orbital HMC settings: unit inverse mass matrix (one entry per coordinate),
# orbit `period` (presumably the number of points per orbit — confirm),
# fixed integrator step size, and the starting point.
inv_mass_matrix = jnp.ones(2)
period = 2
step_size = 0.5
initial_position = {"x1": 10.0, "x2": 20.0}

In [None]:
# Periodic orbital HMC with the McLachlan integrator as the bijection.
# NOTE(review): the name `vv_kernel` suggests velocity Verlet, but the
# integrator passed here is `integrators.mclachlan`.
init_fn, vv_kernel = orbital(
    logdensity, step_size, inv_mass_matrix, period, bijection=integrators.mclachlan
)
initial_state = init_fn(initial_position)
vv_kernel = jax.jit(vv_kernel)

In [None]:
rng_key = jax.random.PRNGKey(0)
# 10,000 orbital-HMC transitions via the lax.scan-based `inference_loop`.
states = inference_loop(rng_key, vv_kernel, initial_state, 10_000)

# Stacked positions and per-point weights from every transition.
samples = states.positions
weights = states.weights

In [None]:
# Inspect the sampler weights.
weights

In [None]:
# Target density alone, without samples.
plot_contour(logdensity)

In [None]:
# Density with the orbital-HMC samples overlaid.
# NOTE(review): with weights=None the "Weighted samples" panel shows unweighted
# points; pass `weights` if the weighted view is intended.
plot_contour(logdensity, orbits=samples, weights=None)

# NumPyro

In [None]:
import jax
# from jax.experimental import mesh_utils
# from jax.sharding import PositionalSharding
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_value

In [None]:
# Initial values for the coefficient sites of `model`, used via init_to_value.
ref_params = {'m1': 0.0, 'm2': 0.0, 'm3': 0.0}

In [None]:
# Synthetic regression data; a seeded Generator makes the run reproducible.
data_rng = np.random.default_rng(0)
X = data_rng.standard_normal((128, 3))
y = data_rng.standard_normal(128)

def model(X, y):
    """Bayesian linear regression ``y ~ Normal(X @ beta, 1)``.

    Each coefficient m1, m2, m3 has an independent Normal(0.012, 0.023) prior.
    They are separate sample sites rather than one vector site, presumably so
    that each can be initialised by name from ``ref_params``.
    """
    m1 = numpyro.sample("m1", dist.Normal(0.012, 0.023))
    m2 = numpyro.sample("m2", dist.Normal(0.012, 0.023))
    m3 = numpyro.sample("m3", dist.Normal(0.012, 0.023))
    beta = jnp.array([m1, m2, m3])
    numpyro.sample("obs", dist.Normal(X @ beta, 1), obs=y)

# NOTE(review): 10 warm-up steps is very short for NUTS step-size / mass-matrix
# adaptation — confirm this is intentional.
mcmc = MCMC(NUTS(model, init_strategy=init_to_value(values=ref_params)), num_warmup=10, num_samples=1000)

# Optional: shard the data across 8 devices (requires 8 accelerators).
# # See https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
# sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
# X_shard = jax.device_put(X, sharding.reshape(8, 1))
# y_shard = jax.device_put(y, sharding.reshape(8))
# mcmc.run(jax.random.PRNGKey(0), X_shard, y_shard)

In [None]:
# [ 1.39441888 -0.38114053  0.38801493]

In [None]:
%%time
# Warm-up followed by sampling on the synthetic (X, y).
mcmc.run(jax.random.PRNGKey(124545), X, y)

In [None]:
# Posterior draws keyed by site name ('m1', 'm2', 'm3').
beta_samples = mcmc.get_samples()

In [None]:
# Display the posterior draws.
beta_samples

In [None]:
def normal_distribution(mean, std, n_sigma=5, num=1000):
    """Evaluate a 1-D normal pdf on an evenly spaced grid centred on its mean.

    Parameters
    ----------
    mean, std : float
        Location and standard deviation of the normal distribution.
    n_sigma : float, optional
        Half-width of the grid in units of ``std`` (default 5).
    num : int, optional
        Number of grid points (default 1000).

    Returns
    -------
    xrange : ndarray, shape (num,)
        Grid from ``mean - n_sigma * std`` to ``mean + n_sigma * std``.
    pdf : ndarray, shape (num,)
        Normal density evaluated on ``xrange``.
    """
    import scipy.stats as ss

    xrange = np.linspace(mean - n_sigma * std, mean + n_sigma * std, num)
    pdf = ss.norm(mean, std).pdf(xrange)
    return xrange, pdf

In [None]:
# Posterior draws of m1 against its Normal(0.012, 0.023) prior from `model`.
plt.figure()
plt.hist(beta_samples['m1'], density=True, label='posterior samples')
x, p = normal_distribution(0.012, 0.023)
plt.plot(x, p, label='prior')
plt.xlabel(r'$m_1$')
plt.ylabel(r'$p(m_1)$')
plt.legend()
plt.show()

In [None]:
import scipy.stats as ss 

In [None]:
# Reference 2-D normal: mean (1, 1), identity covariance.
distribution = ss.multivariate_normal(np.ones(2), np.eye(2))

In [None]:
# Random 2-D test point (unseeded).
vec = np.random.randn(2)

In [None]:
# Log density of the reference distribution at `vec`.
distribution.logpdf(vec)

In [None]:
# Scale factor for the change-of-variables check below.
factor = 5

In [None]:
# Law of factor * Z for Z ~ distribution: mean factor * (1, 1), covariance factor**2 * I.
distribution2 = ss.multivariate_normal(factor*np.ones(2), np.eye(2)*factor**2)

In [None]:
# Should equal distribution.logpdf(vec) - 2 * np.log(factor): the Jacobian of
# scaling a 2-D vector by `factor`.
distribution2.logpdf(vec*factor)

In [None]:
def pdf(x, mean, cov):
    """Multivariate normal density ``exp(logpdf(x, mean, cov))``; same argument conventions."""
    return np.exp(logpdf(x, mean, cov))


def logpdf(x, mean, cov):
    """Log-density of a multivariate normal N(mean, cov), via an eigendecomposition.

    Parameters
    ----------
    x : array_like, shape (D,) or (..., D)
        Point(s) at which to evaluate; leading axes are treated as a batch.
    mean : array_like, shape (D,)
        Mean vector.
    cov : array_like, shape (D, D)
        Symmetric covariance matrix, assumed positive definite. Non-positive
        eigenvalues produce nan/inf from the log and sqrt below.

    Returns
    -------
    float or ndarray of shape (...,)
        One log-density per input point.
    """
    # `eigh` assumes a symmetric/Hermitian matrix (it reads only one triangle).
    vals, vecs = np.linalg.eigh(cov)
    logdet = np.sum(np.log(vals))
    # cov^{-1} = V diag(1/vals) V^T = U U^T with U = V diag(1/sqrt(vals)).
    # Broadcasting scales column j of the D x D eigenvector matrix by 1/sqrt(vals[j]).
    U = vecs * np.sqrt(1.0 / vals)
    dim = vals.shape[0]
    dev = np.asarray(x) - mean
    # Squared Mahalanobis distance dev^T cov^{-1} dev = ||dev @ U||^2. It is
    # summed over the last axis only, so batched inputs give one value per point.
    maha = np.square(dev @ U).sum(axis=-1)
    log2pi = np.log(2 * np.pi)
    return -0.5 * (dim * log2pi + maha + logdet)

In [None]:
# Random 2x2 covariance: A @ A.T is symmetric positive semi-definite
# (positive definite almost surely). Unseeded.
cov = np.random.randn(2,2)
cov = cov @ cov.T

In [None]:
# Scratch check: the intermediate quantities computed inside `logpdf`, for the random `cov`.
vals, vecs = np.linalg.eigh(cov)
logdet     = np.sum(np.log(vals))
valsinv    = np.array([1./v for v in vals])
# `vecs` is the D x D matrix whose columns are eigenvectors, and `vals` holds the
# D eigenvalues. The asterisk broadcasts over columns, scaling column j of `vecs`
# by 1/sqrt(vals[j]), so that U @ U.T equals inv(cov) for a positive-definite cov.
U          = vecs * np.sqrt(valsinv)
# Dimension D (equals the matrix rank only when cov is full rank).
rank       = len(vals)

In [None]:
# Random test point and zero mean, for checking the Mahalanobis term by hand.
x = np.random.randn(2)
mean = np.zeros(2)
dev = x - mean

In [None]:
# Deviation from the mean.
dev

In [None]:
# Whitened deviation; its squared norm is the Mahalanobis distance used in `logpdf`.
np.dot(dev, U)

In [None]:
# Uniform on [loc, loc + scale] = [0, 2]: log density at 1 is -log(2).
jax.scipy.stats.uniform.logpdf(jnp.ones(1), 0, 2)

In [None]:
# Standard normal at 0: -0.5 * log(2 * pi), about -0.919.
jax.scipy.stats.norm.logpdf(jnp.zeros(1), 0, 1)

In [None]:
import scipy.stats as ss 

In [None]:
# SciPy reference value for the same standard-normal log density at 0.
ss.norm(0, 1).logpdf(0)