# Analyze JS estimates

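This notebook computes Jensen–Shannon (JS) divergences between the 1D marginal histograms of a reference ("repro") run and those of a set of varied runs, and summarises them per parameter.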

In [1]:
import os
import copy
import psutil
# Restrict the current process to CPU core 0.
p = psutil.Process()
p.cpu_affinity([0])
# Hide all CUDA devices; this is set before jax is imported below,
# presumably so that jax runs on the CPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
# Turn off JIT compilation in jax for this session.
jax.config.update("jax_disable_jit", True)

from scipy.spatial.distance import jensenshannon

# Names of the 13 chain columns, in the column order used to index the chain arrays below.
NAMING = ['M_c', 'q', 's1_z', 's2_z', 'lambda_1', 'lambda_2', 'd_L', 't_c', 'phase_c', 'cos_iota', 'psi', 'ra', 'sin_dec']

from ripple import get_chi_eff, Mc_eta_to_ms, lambdas_to_lambda_tildes

In [2]:
def convert_chains(chains: np.array,
                   iota_index: int = 9,
                   dec_index: int = 12):
    """
    Convert a 2D array of samples to a derived parameterisation.

    The input is left untouched; a converted copy is returned in which

    * column ``iota_index`` is replaced by its arccos,
    * column ``dec_index`` is replaced by its arcsin,
    * column 2 holds the output of ``get_chi_eff``,
    * columns 4 and 5 hold the two outputs of ``lambdas_to_lambda_tildes``,
    * column 3 is removed.

    Args:
        chains: Samples with one row per sample. Columns 0-5 are read as
            (M_c, q, s1_z, s2_z, lambda_1, lambda_2), following ``NAMING``.
        iota_index: Column holding cos(iota).
        dec_index: Column holding sin(dec).

    Returns:
        A new array with one column fewer than ``chains``.
    """
    converted = copy.deepcopy(chains)

    # Angles: undo the cosine / sine parameterisation
    converted[:, iota_index] = np.arccos(chains[:, iota_index])
    converted[:, dec_index] = np.arcsin(chains[:, dec_index])

    # Component masses from the first two columns (eta = q / (1 + q)^2)
    chirp_mass = chains[:, 0]
    mass_ratio = chains[:, 1]
    sym_mass_ratio = mass_ratio / (1 + mass_ratio)**2
    mass_1, mass_2 = Mc_eta_to_ms(jnp.array([chirp_mass, sym_mass_ratio]))

    # Spin and tidal columns that get combined
    spin_1 = chains[:, 2]
    spin_2 = chains[:, 3]
    lam_1 = chains[:, 4]
    lam_2 = chains[:, 5]

    effective_spin = get_chi_eff(jnp.array([mass_1, mass_2, spin_1, spin_2]))
    lam_tilde, delta_lam_tilde = lambdas_to_lambda_tildes(jnp.array([lam_1, lam_2, mass_1, mass_2]))

    # Overwrite the individual spin / tidal columns with the combined quantities
    converted[:, 4] = lam_tilde
    converted[:, 5] = delta_lam_tilde
    converted[:, 2] = effective_spin

    # The second spin column no longer carries information: drop it
    return np.delete(converted, 3, axis=1)

In [3]:
def get_JS_estimate(default_samples: np.ndarray,
                    variation_outdir: str = "/home/thibeau.wouters/TurboPE-BNS/JS_estimate/GW190425_TaylorF2/varied_runs/outdir/",
                    debug: bool = False,
                    nb_bins: int = 20) -> defaultdict:
    """
    Compute 1D Jensen-Shannon divergences between a reference run and varied runs.

    For every parameter in ``NAMING`` a density histogram of the reference
    samples is built. Each subdirectory of ``variation_outdir`` containing a
    ``results_production.npz`` file is treated as one run: its ``"chains"``
    array is flattened to one row per sample, histogrammed on the reference
    bin edges, and compared against the reference histogram.

    Note:
        ``np.histogram`` ignores values outside the supplied bin edges, so
        samples of a varied run that fall outside the range of the reference
        samples do not contribute to the comparison.

    Args:
        default_samples: Reference samples, one row per sample and one column
            per entry of ``NAMING``.
        variation_outdir: Directory whose subdirectories are the varied runs.
            A trailing path separator is optional.
        debug: If True, plot the histogram of each run on top of the
            reference histogram for every parameter.
        nb_bins: ``bins`` argument used for the reference histograms.

    Returns:
        A ``defaultdict(list)`` mapping each parameter name to the list of JS
        divergences (base 2), one entry per run, in sorted order of the run
        directory names.
    """
    n_params = len(NAMING)

    # Make the histogram for the repro run, histogram per 1D parameter:
    histogram_dict = {}
    edges_dict = {}
    for i, key in enumerate(NAMING):
        histogram, edges = np.histogram(default_samples[:, i], bins=nb_bins, density=True)
        histogram_dict[key] = histogram
        edges_dict[key] = edges

    # Will save the JS divergence for each run in it
    js_dict = defaultdict(list)

    # Sort the listing: os.listdir returns entries in arbitrary order, which
    # would make the order of the result lists irreproducible.
    for path in sorted(os.listdir(variation_outdir)):
        # Only get subdirectory, which is a run
        run_dir = os.path.join(variation_outdir, path)
        if not os.path.isdir(run_dir):
            continue

        # Check if the run has a results_production.npz file
        file = os.path.join(run_dir, "results_production.npz")

        if not os.path.exists(file):
            print(f"WARNING: {file} does not exist, skipping this")
            continue

        # Load the data; the context manager closes the underlying archive
        with np.load(file) as data:
            chains = np.reshape(data["chains"], (-1, n_params))

        # chains = convert_chains(chains)

        # Iterate over all the parameters
        for i, key in enumerate(NAMING):
            edges = edges_dict[key]

            # Histogram of the current run on the bin edges of the default run
            histogram, _ = np.histogram(chains[:, i], bins=edges, density=True)

            if debug:
                # Draw the already-binned densities: one weighted entry per bin
                plt.hist(edges[:-1], bins=edges, weights=histogram, histtype="step", label = "Current run")
                plt.hist(edges[:-1], bins=edges, weights=histogram_dict[key], histtype="step", label = "Default run")
                plt.title(key)
                plt.legend()
                plt.show()

            # jensenshannon returns the JS distance; its square is the divergence
            js_div = jensenshannon(histogram, histogram_dict[key], base = 2) ** 2
            js_dict[key].append(js_div)

    return js_dict

## Load the samples

In [4]:
# Output file of the reference ("repro") run
repro_samples = "/home/thibeau.wouters/TurboPE-BNS/JS_estimate/GW190425_TaylorF2/repro/outdir/results_production.npz"

# Read the stored chains and flatten them to one row per sample with 13 columns
repro_data = np.load(repro_samples)
repro_chains = repro_data["chains"].reshape(-1, 13)

# repro_chains = convert_chains(repro_chains)

## Compute the JS divergences

In [5]:
# Compare the reference chains against every run found in the default variation directory.
js_dict = get_JS_estimate(repro_chains)

In [6]:
# Show the JS divergences: one list per parameter, one entry per run.
js_dict

defaultdict(list,
            {'M_c': [0.00021476234426517647,
              0.00026117638337234776,
              0.00017303249554248,
              0.00015825397140405576,
              0.0002203251414816731,
              0.0001605458849025018,
              0.00019527393133460198,
              0.00012663123192121773,
              7.643742296414048e-05,
              0.0001386299038126467],
             'q': [0.00024167965999735778,
              0.00017561668150511462,
              0.00013300875811472825,
              0.00013808419953950546,
              0.00013831985308409114,
              0.00017987363373826943,
              0.0001223679862769129,
              0.00015245634897474196,
              0.00011068449041929593,
              0.00013413671246207795],
             's1_z': [0.0002430886251553477,
              0.00011378149487373951,
              0.00027607710948396763,
              0.00029108076943635663,
              0.0003247404041144363,
              0.0003

## Postprocessing

In [7]:
# Average JS divergence over all runs, per parameter
mean_js_dict = {name: np.mean(divergences) for name, divergences in js_dict.items()}

for key, value in mean_js_dict.items():
    print(f"{key}: {np.round(value, 5)}")

M_c: 0.00017
q: 0.00015
s1_z: 0.00027
s2_z: 0.00019
lambda_1: 0.00022
lambda_2: 0.00027
d_L: 0.00059
t_c: 0.00062
phase_c: 0.00016
cos_iota: 0.00021
psi: 0.00013
ra: 0.00112
sin_dec: 0.00033


In [8]:
# Largest JS divergence over all runs, per parameter
max_js_dict = {name: np.max(divergences) for name, divergences in js_dict.items()}

for key, value in max_js_dict.items():
    print(f"{key}: {np.round(value, 5)}")

M_c: 0.00026
q: 0.00024
s1_z: 0.00049
s2_z: 0.00025
lambda_1: 0.00036
lambda_2: 0.00048
d_L: 0.00097
t_c: 0.00105
phase_c: 0.00028
cos_iota: 0.00029
psi: 0.0002
ra: 0.00162
sin_dec: 0.00064


Drop `t_c` from both summary dictionaries, then save the remaining JS divergence values to `js_div_noise.npz`:

In [9]:
# Remove the t_c entry from both summaries before saving
max_js_dict = {name: js for name, js in max_js_dict.items() if name != "t_c"}
mean_js_dict = {name: js for name, js in mean_js_dict.items() if name != "t_c"}

# Store the remaining values as two arrays ("max" and "mean"), in dictionary order
max_values = list(max_js_dict.values())
mean_values = list(mean_js_dict.values())
np.savez("js_div_noise.npz", max = max_values, mean = mean_values)