In [1]:
!pip install corner pymc pytensor arviz astropy matplotlib numpy 

Collecting corner
  Downloading corner-2.2.3-py3-none-any.whl.metadata (2.2 kB)
Downloading corner-2.2.3-py3-none-any.whl (15 kB)
Installing collected packages: corner
Successfully installed corner-2.2.3


In [2]:
import pymc as pm
import pytensor.tensor as pt
import pytensor.tensor.extra_ops as pte
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from astropy.cosmology import Planck18 as cosmo
from astropy import units as u
import corner

# --- Functions ---
def Ez(z, Om, w, wDM):
    opz = 1 + z
    return pt.sqrt(Om * opz ** (3 * (1 + wDM)) + (1 - Om) * opz ** (3 * (1 + w)))

def dCs(zs, Om, w, wDM):
    dz = zs[1:] - zs[:-1]
    fz = 1 / Ez(zs, Om, w, wDM)
    I = 0.5 * dz * (fz[:-1] + fz[1:])
    return pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(I)])

def dLs(zs, dCs):
    return dCs * (1 + zs)

def pt_interp(x, xs, ys):
    x = pt.as_tensor(x)
    xs = pt.as_tensor(xs)
    ys = pt.as_tensor(ys)
    ind = pte.searchsorted(xs, x)
    ind = pt.clip(ind, 1, xs.shape[0] - 1)
    r = (x - xs[ind - 1]) / (xs[ind] - xs[ind - 1])
    return r * ys[ind] + (1 - r) * ys[ind - 1]

# --- Pivot optimization helper ---
def fracHz(z, Om=0.3, w=-1.0):
    """Proxy for fractional uncertainty ÏƒH/H from log-derivatives wrt parameters."""
    opz = 1 + z
    denom = Om + (1 - Om) * opz ** (3*w)
    num_Om = 1 - opz ** (3*w)
    num_w  = (1 - Om) * 3 * opz ** (3*w) * np.log(opz)
    return np.sqrt(num_Om**2 + num_w**2) / (2 * denom)

#np.random.seed(97685295)

#Finding the current seed
current_state = np.random.get_state()
current_seed = current_state[1][0]
print(f"Current random seed: {current_seed}")
with open("seed.txt", "w") as f:
    f.write(str(current_seed))

# --- Mock data ---
Nobs = 1000
z_true = np.random.beta(3, 9, Nobs) * 10
DL_true = cosmo.luminosity_distance(z_true).to(u.Gpc).value
sigma_DL = DL_true * 0.07 #5.0 #0.001 #
DL_obs = np.random.normal(DL_true, sigma_DL)

# True chirp mass distribution
mu_p_true = 1.17
sigma_p_true = 0.1
M_source_true = np.random.normal(mu_p_true, sigma_p_true, size=Nobs)
Mz_true = (1 + z_true) * M_source_true

# --- Mass ratio (q) ---
mu_q_true = 0.5
sigma_q_true = 0.1
# Generate in [0,1]
q_true = np.clip(np.random.normal(mu_q_true, sigma_q_true, size=Nobs), 1e-4, 1-1e-4)

# Measurement noise
sigma_q_obs = q_true * 0.07 #5.0 #0.001 #
q_obs = np.clip(np.random.normal(q_true, sigma_q_obs), 1e-4, 1-1e-4)

m1_true = M_source_true * (1 + q_true) ** (1/5) / q_true ** (3/5)
m2_true = M_source_true * (1 + q_true) ** (1/5) * q_true ** (2/5)

m1t_mean = mu_p_true * (1 + mu_q_true) ** (1/5) / mu_q_true ** (3/5)
m2t_mean = mu_p_true * (1 + mu_q_true) ** (1/5) * mu_q_true ** (2/5)
print("m1t_mean =", m1t_mean)
print("m2t_mean =", m2t_mean)
# --- Tidal deformabilities ---
c0_true = 4.8
c1_true = -5.0
Lambda1_true = c0_true + c1_true * (m1_true - ((m1t_mean + m2t_mean)/2.0))
Lambda2_true = c0_true + c1_true * (m2_true - ((m1t_mean + m2t_mean)/2.0))

Lambda_til_true = (16.0 / 13.0) * (
    (m1_true + 12 * m2_true) * m1_true**4 * Lambda1_true
    + (m2_true + 12 * m1_true) * m2_true**4 * Lambda2_true
) / (m1_true + m2_true) ** 5
sigma_lambda = np.abs(Lambda_til_true) * 0.07 #5.0 #0.001 #
Lambda_obs = np.random.normal(Lambda_til_true, sigma_lambda)

# Observed detector-frame chirp masses
Mz_obs = Mz_true

zinterp = np.linspace(0, 10, 1000)

# Choosing pivot redshift
z_grid = z_true
frac_vals = fracHz(z_grid)
z_pivot = float(z_grid[np.argmin(frac_vals)])
print("Using optimal z_pivot =", z_pivot)

with open("z_pivot.txt", "w") as f:
    f.write(str(z_pivot))

# --- Model ---
with pm.Model() as model:

    # Cosmology priors
    H_pivot = pm.Uniform("H_pivot", 0.01, 1.0)
    Om = pm.Uniform("Om", 0.1, 0.5)
    w = pm.Uniform("w", -2.5, -0.3)
    wDM = 0.0

    # H0 defined from pivot: H(z_pivot) = H0 * Ez(z_pivot)
    z_piv_t = pt.as_tensor(z_pivot)
    H0 = pm.Deterministic("H0", H_pivot / Ez(z_piv_t, Om, w, wDM))
    dH = pm.Deterministic("dH", 2.99792 / H0)  # Gpc

    # Chirp Mass 
    mu_p = pm.Uniform("mu_p", 0.7, 1.7)#0.5, 2.0
    sigma_p = pm.Uniform("sigma_p", 0.0, 0.5)

    # Mass ratio hyper-parameters 
    mu_q = pm.Uniform("mu_q", 0.01, 0.99)
    sigma_q = pm.Uniform("sigma_q", 0.0, 0.5)

    # EOS
    c0 = pm.Uniform("c0", 3.0, 6.0)
    c1 = pm.Uniform("c1", -7.0, -3.0)

    # Latent redshifts 
    z_unit = pm.Beta("z_unit", 3, 9, shape=Nobs)
    z = pm.Deterministic("z", z_unit * 10)

    # Distance DL
    dCinterp = dH * dCs(zinterp, Om, w, wDM)
    dLinterp = dLs(zinterp, dCinterp)
    dL = pm.Deterministic("dL", pt_interp(z, zinterp, dLinterp))

    # Chirp Mass in source frame
    Mc = pm.Deterministic("Mc", Mz_obs / (1 + z))

    # Latent q with hierarchical prior
    q = pm.TruncatedNormal("q", mu=mu_q, sigma=sigma_q,
                                  lower=0.0, upper=1.0, shape=Nobs)

    m1 = pm.Deterministic("m1", Mc * (1 + q) ** (1/5) / q ** (3/5))
    m2 = pm.Deterministic("m2", Mc * (1 + q) ** (1/5) * q ** (2/5))

    m1mean = ( mu_p * (1 + mu_q) ** (1/5) / mu_q ** (3/5))
    m2mean = ( mu_p * (1 + mu_q) ** (1/5) * mu_q ** (2/5))

    # Lambda for each component 
    Lambda1 = pm.Deterministic("Lambda1", c0 + c1 * (m1 - ((m1mean + m2mean)/2.0)))
    Lambda2 = pm.Deterministic("Lambda2", c0 + c1 * (m2 - ((m1mean + m2mean)/2.0)))
    # Tilde Lambda (mass-weighted combination) 
    Lambda_til = pm.Deterministic("Lambda_til",(16.0 / 13.0)*((m1 + 12 * m2) * m1**4 * Lambda1 + (m2 + 12 * m1) * m2**4 * Lambda2)/(m1 + m2) ** 5)

    # Likelihoods & priors
    pm.Potential("mcprior", pt.sum(pm.logp(pm.Normal.dist(mu_p, sigma_p), Mc)))
    pm.Potential("mcjac", pt.sum(-pt.log1p(z)))
    pm.Normal("q_likelihood", mu=q, sigma=sigma_q_obs, observed=q_obs)
    pm.Normal("dL_likelihood", mu=dL, sigma=sigma_DL, observed=DL_obs)
    pm.Normal("Lambda_til_likelihood", mu=Lambda_til, sigma=sigma_lambda, observed=Lambda_obs)



    # Initial values
    initvals = {
        "H_pivot": cosmo.H(z_pivot).value / 100,
        "Om": cosmo.Om0,
        "w": -1.0,
        "mu_p": 1.17,
        "sigma_p": 0.1,
        "c0": 4.8,
        "c1": -5.0,
        "mu_q": 0.5,
        "sigma_q": 0.1,
    }

    # Sampling
    trace = pm.sample(1000, tune=1000, chains=2, target_accept=0.95, initvals=initvals, max_treedepth=15)

summary = az.summary(trace, var_names=["H_pivot", "H0", "Om", "w", "mu_p", "sigma_p", "c0", "c1", "mu_q", "sigma_q"])# "c0", "c1",
print(summary)

# Save trace
az.to_netcdf(trace, "trace_pymc_frac.nc")
#np.savez("trace_pymc_061225_woeos.npz", **{k: trace.posterior[k].values for k in trace.posterior.data_vars})
print("Trace saved as trace_pymc_frac.nc")

# Trace plot
az.plot_trace(trace, var_names=["H_pivot", "H0", "Om", "w", "mu_p", "sigma_p", "c0", "c1", "mu_q", "sigma_q"],# "c0", "c1",
              lines=[('H_pivot', {}, cosmo.H(z_pivot).value/100),
                     ('H0', {}, cosmo.H0.value/100),
                     ('Om', {}, cosmo.Om0),
                     ('w', {}, -1),
                     ('mu_p', {}, 1.17),
                     ('sigma_p', {}, 0.1),
                     ('c0', {}, 4.8),
                     ('c1', {}, -5.0),
                     ('mu_q', {}, 0.5),
                     ('sigma_q', {}, 0.1)])
plt.tight_layout()
plt.savefig("trace_plot_frac.png")
print("Saved trace_plot_frac.png")

# Corner plot
trace_data = trace.posterior.stack(samples=("chain", "draw"))[["H_pivot","H0", "Om", "w", "mu_p", "sigma_p", "c0", "c1", "mu_q", "sigma_q"]].to_array().values.T
labels = ["H_pivot","H0", "Om", "w", "mu_p", "sigma_p", "c0", "c1","mu_q","sigma_q"]
true_values = [cosmo.H(z_pivot).value/100,cosmo.H0.value/100, cosmo.Om0, -1.0, 1.17, 0.1, 4.8, -5.0, 0.5, 0.1]

corner.corner(trace_data, labels=labels, show_titles=True, color="blue", hist_kwargs={"density": True}, truths=true_values)
plt.savefig("corner_plot_frac.png")
print("Saved corner_plot_frac.png")

summary.to_csv("summary.csv")
with open("summary.txt", "w") as f:
    f.write(str(summary))

Current random seed: 2147483648
m1t_mean = 1.923189640535154
m2t_mean = 0.9615948202675769
Using optimal z_pivot = 0.15294016678301522


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [H_pivot, Om, w, mu_p, sigma_p, mu_q, sigma_q, c0, c1, z_unit, q]


Output()