# Inferring the Redshift Distribution of Gravitational Waves

# Hierarchical Likelihood Inference

In [1]:
import numpy as np
import pandas as pd
import json
import jax.numpy as jnp

with open(f"./event_posterior_samples.json", "r") as f:
    event_samples = {k:pd.DataFrame(v) for k,v in json.load(f).items()}

eventnames = []
posteriors = []
for event, posterior in event_samples.items():
    eventnames.append(event)
    posteriors.append(posterior)

In [2]:
with open(f"./selection_function_samples.json", "r") as f:
    selection_samples = pd.DataFrame(json.load(f))

selections_jaxed = {col: jnp.array(selection_samples[col][1:100]) for col in selection_samples.columns if col not in ['waveform_name', 'name']}

In [3]:
posteriors_jaxed = [{col: jnp.array(post[col][1:100]) for col in post.columns if col not in ['waveform_name']} for post in posteriors]

## Population Model Recap

We use individual event likelihoods to infer the hierarchical model. 
$$
\begin{equation}
p(\Lambda \mid d) \propto p(\Lambda)\prod_i^N \mathcal{L}\left(d_i \mid \Lambda\right)
\end{equation}
$$
where
$$
\begin{equation}
\mathcal{L}\left(d_i \mid \Lambda\right) \propto \int \frac{p\left(\theta \mid d_i\right) p(\theta \mid \Lambda)}{p(\theta)} d \theta \approx 
\left\langle  \frac{p(\theta \mid \Lambda)}{p(\theta)}  \right\rangle_{\sim p\left(\theta \mid d_i\right)}
\end{equation}
$$

We will "recycle" these posteriors and get the result for the following population model:


In [4]:
import jax.numpy as jnp
import jax

In [5]:
## Mass model
def power_law(x, L, a, b):
    normalization = (L + 1) / (jnp.power(b, L+1) - jnp.power(a, L+1))
    return jnp.power(x,L) * normalization

def trunc_normal_pdf(x, mu, sig, a, b):
    a_std, b_std = (a - mu) / sig, (b - mu) / sig
    return jax.scipy.stats.truncnorm.pdf(x, a_std, b_std, loc=mu, scale=sig)

In [6]:
class MassModel:
    def __init__(self, m_min, m_max):
        self.m_min = m_min
        self.m_max = m_max

    def pdf(self, data, params):
        p_m_power_law = power_law(data['mass_1_source'], params['lambda'], self.m_min, self.m_max)
        p_m_peak = trunc_normal_pdf(data['mass_1_source'], params['mu_m'], params['sigma_m'], self.m_min, self.m_max)
        p_q = power_law(data['mass_ratio'], params['gamma'], self.m_min/data['mass_1_source'], 1.0)
        p_m = params['fp'] * p_m_power_law + (1 - params['fp']) * p_m_peak
        return p_m*p_q

    def __call__(self, data, params):
        return self.pdf(data, params)

In [7]:
class Redshift:
    def __init__(self, z_max):
        from astropy.cosmology import Planck15
        self.z = jnp.linspace(0,z_max,300)
        self.y = Planck15.differential_comoving_volume(self.z).value * 4 * np.pi
        self.kappas = jnp.linspace(-10,10,3000)
        self._kappa_norms = None

    def dVdz(self, z):
        return jnp.interp(z, self.z, self.y)

    def normalization_func(self, z, kappa):
        return (1 + z)**(kappa - 1) * self.dVdz(z)

    @property
    def kappa_norms(self):
        if self._kappa_norms is None:
            self._normalize()
        return self._kappa_norms

    def _normalize(self):
        self._kappa_norms = jnp.array([jax.scipy.integrate.trapezoid(self.normalization_func(self.z, float(self.kappas[i])), self.z) for i in range(len(self.kappas))])

    def normalization(self, kappa):
        return jnp.interp(kappa, self.kappas, self.kappa_norms)
    
    def __call__(self, data, params):
        un_normalized = self.dVdz(data['redshift']) * ((1 + data['redshift']) ** (params["kappa"]-1))
        return un_normalized#/self.normalization(params['kappa'])

In [8]:
R = Redshift(1.9)
M = MassModel(5.0, 100.0)

In [18]:
Lambda_0 = {'lambda' : 0.35, "gamma": 1.1, 'fp':0.98, 'mu_m':33.0, 'sigma_m':4.0, 'kappa':2.9}

class CustomLikelihood:
    def __init__(self, all_posteriors, selections, domain_changer=None):
        self.all_posteriors = all_posteriors
        self.selections = selections
        self.domain_changer = domain_changer

    def logpdf(self, x):
        #self.domain_changer.inverse_transform_in_place(x)

        #event_likelihoods = jnp.sum(jnp.array([jnp.log(jnp.mean( ( R(post, x) / post["prior"]) )) for post in self.all_posteriors]))
        #selection_effects = -len(self.all_posteriors) * jnp.log( jnp.mean(  R(self.selections, x) / self.selections["prior"]) )

        event_likelihoods = jnp.sum(jnp.array([jnp.log(jnp.mean( (M(post, x) * R(post, x) / post["prior"]) )) for post in self.all_posteriors]))
        selection_effects = -len(self.all_posteriors) * jnp.log( jnp.mean( M(self.selections, x) * R(self.selections, x) / self.selections["prior"]) )
        return event_likelihoods + selection_effects

In [None]:
class DomainChanger:
    def __init__(self, ranges):
        self.ranges = ranges
    
    def transform_to_infinite(self, x, a, b):
        # Normalize x to [0, 1]
        normalized_x = (x - a) / (b - a)
    
        # Avoid division by zero or log of zero
        epsilon = 1e-5
        normalized_x = jnp.clip(normalized_x, epsilon, 1 - epsilon)
    
        # Transform to (-∞, +∞)
        infinite_x = jnp.log(normalized_x / (1 - normalized_x))
        return infinite_x

    def inverse_transform_from_infinite(self, y, a, b):
        # Transform from (-∞, +∞) to [0, 1]
        normalized_y = 1 / (1 + jnp.exp(-y))
    
        # Inverse normalize to [a, b]
        x = normalized_y * (b - a) + a
        return x

    def transform(self, x):
        keys = list(self.ranges.keys())
        new_x = {}
        for key in keys:
            if self.ranges[key] == 'infinite':
                new_x[key + '_transformed'] = x[key]
            else:
                new_x[key + '_transformed'] = self.transform_to_infinite(x[key], self.ranges[key][0], self.ranges[key][1])

        return new_x

    def inverse_transform(self, x):
        keys = list(self.ranges.keys())
        new_x = {}
        for key in keys:
            if self.ranges[key] == 'infinite':
                new_x[key] = x[key + '_transformed']
            else:
                new_x[key] = self.inverse_transform_from_infinite(x[key + '_transformed'], self.ranges[key][0], self.ranges[key][1])

        return new_x

    def transform_in_place(self, x):
        keys = list(self.ranges.keys())
        for key in keys:
            if self.ranges[key] == 'infinite':
                x[key] = x[key]
            else:
                x[key] = self.transform_to_infinite(x[key], self.ranges[key][0], self.ranges[key][1])

        return x

    def inverse_transform_in_place(self, x):
        keys = list(self.ranges.keys())
        for key in keys:
            if self.ranges[key] == 'infinite':
                x[key] = x[key]
            else:
                x[key] = self.inverse_transform_from_infinite(x[key], self.ranges[key][0], self.ranges[key][1])

        return x
        

In [19]:
from bayesian_inference import NUTS, DomainChanger
import jax

#DC = DomainChanger()
limits = {'lambda' : [-5,2], "gamma": [-3,3], 'fp':[0,1], 'mu_m':[10,50], 'sigma_m':[3,10], 'kappa':[0,10]}
DC = DomainChanger(limits)

Lambda_0 = {'lambda' : -2.35, "gamma": 1.1, 'fp':0.98, 'mu_m':33.0, 'sigma_m':4.0, 'kappa':2.9}
#DC.transform_in_place(Lambda_0)
Lambda_0

{'lambda': -2.35,
 'gamma': 1.1,
 'fp': 0.98,
 'mu_m': 33.0,
 'sigma_m': 4.0,
 'kappa': 2.9}

In [20]:
CL = CustomLikelihood(posteriors_jaxed[1:2], selections_jaxed, domain_changer=DC)

In [21]:
new = jax.grad(CL.logpdf)
new(Lambda_0)
#new(kappa_0)

{'fp': Array(3.093314, dtype=float32, weak_type=True),
 'gamma': Array(-0.13727583, dtype=float32, weak_type=True),
 'kappa': Array(-0.14747958, dtype=float32, weak_type=True),
 'lambda': Array(-0.46813202, dtype=float32, weak_type=True),
 'mu_m': Array(-0.02301218, dtype=float32, weak_type=True),
 'sigma_m': Array(0.01370422, dtype=float32, weak_type=True)}

In [22]:
N = NUTS(CL, Lambda_0, limits=limits)

{'lambda': -2.35, 'gamma': 1.1, 'fp': 0.98, 'mu_m': 33.0, 'sigma_m': 4.0, 'kappa': 2.9} {'lambda': -2.35, 'gamma': 1.1, 'fp': 0.98, 'mu_m': 33.0, 'sigma_m': 4.0, 'kappa': 2.9}


In [23]:
result = N.run(1000)

Running the inference for 1000 samples


In [15]:
result_transformed = {k:jnp.array(result[k]) for k in result.copy().columns}
CL.domain_changer.inverse_transform_in_place(result_transformed)
result_transformed = pd.DataFrame(result_transformed)

In [24]:
result

Unnamed: 0,fp,gamma,kappa,lambda,mu_m,sigma_m
0,-37.710251,41.599049,-30.462132,-7.388054,69.337708,92.552216
1,-36.679070,42.258064,-30.829433,-7.430704,69.930542,94.866295
2,-36.089043,42.489300,-31.050999,-7.281521,69.629967,95.284805
3,-36.334530,42.300198,-30.824821,-7.522876,68.887932,95.476013
4,-33.955784,42.340042,-31.121170,-7.533043,71.190529,91.666054
...,...,...,...,...,...,...
995,-950.742615,77.138329,-31.781605,-7.685198,-260.766907,1819.616577
996,-953.370911,75.955338,-32.552479,-7.422150,-259.784729,1819.188843
997,-953.342712,75.725456,-32.306728,-7.205547,-260.462158,1818.735229
998,-953.154663,75.460754,-31.068443,-7.089154,-259.677643,1818.796387


In [385]:
result_transformed

Unnamed: 0,fp,gamma,kappa,lambda,mu_m,sigma_m
0,0.979445,0.188128,0.400622,-3.514095,30.141695,3.505382
1,0.979445,0.188132,0.400623,-3.514109,30.141726,3.505388
2,0.979445,0.188131,0.400622,-3.514128,30.141689,3.505388
3,0.979445,0.188131,0.400622,-3.514128,30.141689,3.505388
4,0.979445,0.188131,0.400622,-3.514128,30.141689,3.505388
...,...,...,...,...,...,...
995,0.979445,0.188152,0.400619,-3.514153,30.142056,3.505388
996,0.979445,0.188152,0.400619,-3.514153,30.142056,3.505388
997,0.979445,0.188152,0.400619,-3.514153,30.142056,3.505388
998,0.979445,0.188152,0.400619,-3.514153,30.142056,3.505388
