In [1]:
import numpyro
numpyro.set_host_device_count(5)

import numpy as np
from scipy.linalg import cholesky

import jax.numpy as jnp
from jax import random
from jax.scipy.linalg import solve_triangular
import numpyro.distributions as dist
from numpyro import factor, plate, sample
from numpyro.infer import MCMC, NUTS, init_to_median
import numpy as np
from jax import numpy as jnp

import joblib
import candel

import matplotlib.pyplot as plt
%matplotlib inline

import os

# Location of the joblib file holding the inputs (Cepheid/SN magnitudes and
# covariances, host velocities, anchor distance moduli, ...). Set the
# CANDEL_DATA_PATH environment variable to point elsewhere instead of editing
# this line; the default keeps the original location.
# NOTE: joblib.load unpickles the file and can execute arbitrary code, so only
# point it at trusted data.
DATA_PATH = os.environ.get("CANDEL_DATA_PATH",
                           "/Users/rstiskalek/Downloads/data.joblib")
data = joblib.load(DATA_PATH)
SPEED_OF_LIGHT = 299_792.458  # km / s


def mvn_logpdf_cholesky(y, mu, L):
    """
    Multivariate-normal log-density log N(y | mu, L @ L.T).

    Parameters
    ----------
    y, mu : 1-D arrays of equal length
        Observation and mean vector.
    L : 2-D array
        Lower-triangular Cholesky factor of the covariance matrix.

    Returns
    -------
    Scalar log-density.
    """
    # Whiten the residual: L w = (y - mu)  =>  w @ w is the Mahalanobis term.
    residual = y - mu
    whitened = solve_triangular(L, residual, lower=True)
    # Sum of log-diagonal of L is *half* of log|C| for C = L L^T.
    half_log_det = jnp.log(jnp.diag(L)).sum()
    mahalanobis = whitened @ whitened
    ndim = len(y)
    return -0.5 * (ndim * jnp.log(2 * jnp.pi) + 2 * half_log_det + mahalanobis)


# Lower-triangular Cholesky factors of the Cepheid and SN covariance matrices;
# they are passed to the model and consumed by mvn_logpdf_cholesky.
L_cepheid, L_SN = (
    cholesky(data[key], lower=True)
    for key in ("C_Cepheid", "C_SN_unique_Cepheid_host")
)


In [6]:
def r2mu(r):
    """Distance modulus of distance ``r``: ``mu = 5 log10(r) + 25``."""
    return 25 + 5 * jnp.log10(r)

def r2cz(r, H0):
    """Linear Hubble-law velocity: ``cz = H0 * r``."""
    return H0 * r

def mu2r(mu):
    """Invert the distance modulus: ``r = 10 ** ((mu - 25) / 5)``."""
    exponent = (mu - 25) / 5
    return 10 ** exponent

def drdmu(mu):
    """
    Derivative of distance with respect to distance modulus.

    With ``r = 10 ** ((mu - 25) / 5)`` this is ``dr/dmu = ln(10) / 5 * r``.
    """
    prefactor = 1 / 5 * np.log(10)
    return prefactor * 10**((mu - 25) / 5)

# Project-local (candel) conversion callables, constructed in the same order as
# before. In these cells they are called as mu2r_candel(mu, h),
# r2z_candel(r, h=h) and log_drdmu_candel(mu, h=h); r2mu_candel is not
# referenced in the visible cells.
(r2mu_candel, r2z_candel, mu2r_candel, log_drdmu_candel) = (
    candel.Distance2Distmod(),
    candel.Distance2Redshift(),
    candel.Distmod2Distance(),
    candel.LogGrad_Distmod2ComovingDistance(),
)



def model(num_hosts, num_cepheids, mag_cepheid, e_mag_cepheid, L_cepheid,
          logP, OH, L_Cepheid_host_dist, cz_host, e_cz_host, mag_SN, e_mag_SN,
          L_SN, mu_N4258_anchor, e_mu_N4258_anchor, mu_LMC_anchor,
          e_mu_LMC_anchor, version, use_cepheid_covariance, use_SN_covariance,
          sample_dZP):
    """
    NumPyro model jointly fitting H0, Cepheid relation parameters, the SN
    absolute magnitude and per-host distance moduli.

    A vector of ``num_hosts + 3`` distance moduli is sampled. The first
    ``num_hosts`` entries belong to the hosts with observed velocities and SN
    magnitudes. Entry ``-3`` is compared with the N4258 anchor and entry ``-2``
    with the LMC anchor. NOTE(review): the role of the last entry (``-1``) is
    not visible here; it only enters through ``L_Cepheid_host_dist``.

    Parameters
    ----------
    num_hosts, num_cepheids : int
        Plate sizes for hosts and Cepheids.
    mag_cepheid, e_mag_cepheid : arrays
        Observed Cepheid magnitudes and per-Cepheid uncertainties; the latter
        are used only when ``use_cepheid_covariance`` is False.
    L_cepheid : 2-D array
        Lower Cholesky factor of the Cepheid covariance. It is used only when
        ``use_cepheid_covariance`` is True and shadows the module-level name.
    logP, OH : arrays
        Per-Cepheid regressors in ``mu + M_W + b_W * logP + Z_W * OH``.
    L_Cepheid_host_dist : 2-D array
        Matrix mapping the host distance-modulus vector onto each Cepheid.
        Presumably of shape (num_cepheids, num_hosts + 3); TODO confirm.
    cz_host, e_cz_host : arrays
        Observed host velocities and their uncertainties. Units are presumably
        km/s, matching SPEED_OF_LIGHT.
    mag_SN, e_mag_SN, L_SN :
        SN magnitudes, per-SN uncertainties (diagonal likelihood) and the lower
        Cholesky factor of their covariance (full likelihood).
    mu_N4258_anchor, e_mu_N4258_anchor, mu_LMC_anchor, e_mu_LMC_anchor :
        Anchor distance moduli and their uncertainties.
    version : {"eleni", "distmod_approx", "distmod_exact"}
        Parametrisation of host distances and of the velocity prediction.
    use_cepheid_covariance, use_SN_covariance : bool
        Choose the full-covariance likelihood (True) or the independent
        Gaussian likelihood (False).
    sample_dZP : bool
        If True, sample a zero-point offset ``dZP``. It is added only to the
        Cepheid-side distance modulus of entry ``-2`` (the LMC slot).

    Raises
    ------
    ValueError
        If ``version`` is not one of the supported strings.
    """
    # Sample model parameters.
    H0 = sample("H0", dist.Uniform(10, 100))
    h = H0 / 100

    M_W = sample("M_W", dist.Uniform(-7, -5))
    b_W = sample("b_W", dist.Uniform(-4, -2))
    Z_W = sample("Z_W", dist.Uniform(-1, 1))

    M_B = sample("M_B", dist.Uniform(-20, -18))
    
    if sample_dZP:
        dZP = sample("dZP", dist.Normal(0, 0.1))
    else:
        dZP = 0

    # Extra velocity scatter, added in quadrature to e_cz_host below.
    sigma_v = sample("sigma_v", dist.Uniform(5, 1000))

    # MW calibration: Gaussian constraint on M_W centred on -5.8946 with width
    # 0.0239. NOTE(review): the provenance of these numbers is not shown here.
    sample("M_W_Combined", dist.Normal(M_W, 0.0239), obs=-5.8946)

    if version == "eleni":
        with plate("plate_host", num_hosts + 3):
            logr = sample("logr_host", dist.Uniform(jnp.log(1e-8), jnp.log(200)))
            # r^2 dr volume prior expressed in log r: r^2 dr = r^3 dlog r.
            factor("lp_r2", 3 * logr)

        r_host_all = jnp.exp(logr)
        z_cos_all = H0 * r_host_all / SPEED_OF_LIGHT
        # Distance modulus from the luminosity distance r (1 + z).
        DL = r_host_all * (1 + z_cos_all)
        mu_host_all = 5 * jnp.log10(DL) + 25

        # NOTE(review): the velocity prediction uses the linear H0 * r, while
        # mu above uses r (1 + z); confirm this mix is intended.
        cz_pred = H0 * r_host_all[:-3]
    elif version == "distmod_approx":
        with plate("plate_host", num_hosts + 3):
            mu_host_all = sample("mu_host", dist.Uniform(10, 43))
        r_host_all = 10**((mu_host_all - 25) / 5)
        # Uniform in mu with an r^2 dr prior: |dr/dmu| is proportional to r,
        # so the target is proportional to r^3 (constants dropped).
        factor("lp_r2", 3 * jnp.log(r_host_all))

        cz_pred = H0 * r_host_all[:-3]
    elif version == "distmod_exact":
        with plate("plate_host", num_hosts + 3):
            mu_host_all = sample("mu_host", dist.Uniform(10, 43))

        # Distance and Jacobian come from the h-dependent candel conversions.
        r_host_all = mu2r_candel(mu_host_all, h)
        factor(
            "lp_r2",
            2 * jnp.log(r_host_all) + log_drdmu_candel(mu_host_all, h=h))
        cz_pred = r2z_candel(r_host_all[:-3], h=h) * SPEED_OF_LIGHT
    else:
        raise ValueError("Unknown version")

    mu_N4258 = mu_host_all[-3]
    mu_LMC = mu_host_all[-2]

    # Distance calibration
    sample("ll_N4258", dist.Normal(mu_N4258, e_mu_N4258_anchor), obs=mu_N4258_anchor)
    sample("ll_LMC", dist.Normal(mu_LMC, e_mu_LMC_anchor), obs=mu_LMC_anchor)

    # dZP shifts only the Cepheid-side LMC modulus; the anchor constraint and
    # the SN/velocity terms use the unshifted mu_host_all. (.at[] already
    # returns a new array, so the explicit copy is redundant but harmless.)
    mu_host_all_for_cepheid = jnp.copy(mu_host_all)
    mu_host_all_for_cepheid = mu_host_all_for_cepheid.at[-2].add(dZP)

    mu_cepheid = L_Cepheid_host_dist @ mu_host_all_for_cepheid
    pred_mag_cepheid = mu_cepheid + M_W + b_W * logP + Z_W * OH

    if use_cepheid_covariance:
        factor("ll_cepheid", mvn_logpdf_cholesky(mag_cepheid, pred_mag_cepheid, L_cepheid))
    else:
        with plate("plate_ll_Cepheid", num_cepheids):
            sample("ll_cepheid",
                   dist.Normal(pred_mag_cepheid, e_mag_cepheid),
                   obs=mag_cepheid)

    # Drop the three trailing anchor/extra slots; only real hosts remain.
    mu_host = mu_host_all[:-3]

    with plate("plate_ll_host", num_hosts):
        sample(
            "ll_host",
            dist.Normal(cz_pred, jnp.sqrt(e_cz_host**2 + sigma_v**2)),
            obs=cz_host)

    mag_SN_pred = M_B + mu_host
    if use_SN_covariance:
        factor("ll_SN", mvn_logpdf_cholesky(mag_SN, mag_SN_pred, L_SN))
    else:
        with plate("plate_ll_SN", num_hosts):
            sample(
                "ll_SN",
                dist.Normal(mag_SN_pred, e_mag_SN),
                obs=mag_SN)

    # Adds 1.38 * num_hosts * M_B to the log density.
    # NOTE(review): presumably a selection-effect term; the origin of the
    # -1.38 coefficient and its sign are not visible here, so confirm them.
    logS = -1.38 * M_B
    factor("logS", - num_hosts * logS)

In [8]:
# NUTS sampler: 5 chains, each with 500 warmup and 1000 retained draws,
# initialised with init_to_median(num_samples=200).
kernel = NUTS(model, init_strategy=init_to_median(num_samples=200))
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=5,
    progress_bar=True,
)

# Inputs forwarded to `model`. `version` may be "eleni", "distmod_approx" or
# "distmod_exact" (the branches implemented in `model`).
model_kwargs = dict(
    num_hosts=data["num_hosts"],
    num_cepheids=data["num_cepheids"],
    mag_cepheid=jnp.asarray(data["mag_cepheid"]),
    e_mag_cepheid=jnp.sqrt(jnp.diag(data["C_Cepheid"])),
    L_cepheid=jnp.asarray(L_cepheid),
    logP=jnp.asarray(data["logP"]),
    OH=jnp.asarray(data["OH"]),
    cz_host=jnp.asarray(data["czcmb_cepheid_host"]),
    e_cz_host=jnp.asarray(data["e_czcmb_cepheid_host"]),
    L_Cepheid_host_dist=jnp.asarray(data["L_Cepheid_host_dist"]),
    mag_SN=jnp.asarray(data["mag_SN_unique_Cepheid_host"]),
    e_mag_SN=jnp.sqrt(jnp.diag(data["C_SN_unique_Cepheid_host"])),
    L_SN=jnp.asarray(L_SN),
    mu_N4258_anchor=data["mu_N4258_anchor"],
    e_mu_N4258_anchor=data["e_mu_N4258_anchor"],
    mu_LMC_anchor=data["mu_LMC_anchor"],
    e_mu_LMC_anchor=data["e_mu_LMC_anchor"],
    version="distmod_exact",
    use_cepheid_covariance=True,
    use_SN_covariance=True,
    sample_dZP=True,
)
mcmc.run(random.PRNGKey(41), **model_kwargs)

mcmc.print_summary()
samples = mcmc.get_samples()

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]


                 mean       std    median      5.0%     95.0%     n_eff     r_hat
         H0     68.75      1.97     68.73     65.50     71.99   2985.12      1.00
        M_B    -19.28      0.03    -19.28    -19.34    -19.23   1703.53      1.00
        M_W     -5.90      0.02     -5.90     -5.93     -5.87   1709.62      1.00
        Z_W     -0.19      0.05     -0.19     -0.28     -0.11   3225.60      1.00
        b_W     -3.28      0.02     -3.28     -3.31     -3.26   4111.42      1.00
        dZP     -0.03      0.04     -0.03     -0.09      0.03   2521.35      1.00
 mu_host[0]     29.17      0.04     29.17     29.11     29.25   3284.84      1.00
 mu_host[1]     32.94      0.08     32.94     32.80     33.08   4877.46      1.00
 mu_host[2]     32.85      0.08     32.85     32.72     33.00   4811.01      1.00
 mu_host[3]     32.56      0.07     32.56     32.45     32.68   4710.58      1.00
 mu_host[4]     32.53      0.05     32.53     32.44     32.62   3763.50      1.00
 mu_host[5]    