# Analyze runtimes

The referee asked about the total number of samples produced by the different pipelines.

In [311]:
import os
import psutil
p = psutil.Process()
p.cpu_affinity([0])
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
jax.config.update("jax_disable_jit", True)
import arviz

import h5py
import json

In [312]:
gwosc_path = "/home/thibeau.wouters/gw-datasets/GW190425/posterior_samples.h5"
jim_root_path = "/home/thibeau.wouters/TurboPE-BNS/real_events/"
bilby_root_path = "/home/thibeau.wouters/jim_pbilby_samples/older_bilby_version/"
paths_dict = {"GW170817_TaylorF2": {"jim": jim_root_path + "GW170817_TaylorF2/outdir/results_production.npz",
                    "bilby": bilby_root_path + "GW170817_TF2_with_tukey_fix_result.json"},
              
              "GW170817_NRTidalv2": {"jim": jim_root_path + "GW170817_NRTidalv2/outdir/results_production.npz",
                                     "bilby": bilby_root_path + "GW170817_IMRPhenomD_NRTidalv2_result.json",
                    },
              
              "GW190425_TaylorF2": {"jim": jim_root_path + "GW190425_TaylorF2/outdir_gwosc_data/results_production.npz",
                                    "bilby": bilby_root_path + "GW190425_GWOSC_data_result.json",
                    },
              
              "GW190425_NRTidalv2": {"jim": jim_root_path + "GW190425_NRTidalv2/outdir/results_production.npz",
                                     "bilby": bilby_root_path + "GW190425_NRTv2_GWOSC_data_result.json",
                    },
}

RUN_NAMES = list(paths_dict.keys())
JIM_VAR_NAMES = ['M_c', 'q', 's1_z', 's2_z', 'lambda_1', 'lambda_2', 'd_L', 't_c', 'phase_c', 'cos_iota', 'psi', 'ra', 'sin_dec']
BILBY_VAR_NAMES = ['chirp_mass', 'mass_ratio', 'spin_1z', 'spin_2z', 'lambda_1', 'lambda_2', 'luminosity_distance', 'phase', 'iota', 'psi', 'ra', 'dec']

## Jim

In [321]:
def compute_ess(chains, 
                log_prob, 
                method = "arviz", 
                take_exp: bool = True,
                relative: bool = False):
    
    # Get the weights used for inference
    if take_exp:
        weights = np.exp(log_prob)
    else:
        weights = log_prob
    weights /= np.sum(weights) # to avoid overflow
    weights = np.array(weights)
    
    if method == "arviz":
        ess_list = []
        chains = np.array(chains).T
        for param_chains in chains:
            ess_list.append(arviz.ess(chains))
            
        ess = np.mean(ess_list)

    elif method == "rejection_sampling":
        weights = np.exp(log_prob)
        keep = weights > np.random.uniform(0, max(weights), weights.shape)
        ess = np.sum(keep)
    
    elif method == "kish":
        num = (np.sum(weights))**2
        denom = np.sum(weights**2)
        ess = num/denom
    
    else:
        raise ValueError("Unknown method")
    
    # If relative, divide by total sample size:
    if relative:
        ess /= len(log_prob)
    return ess

In [322]:
def get_nb_jim_samples(path: str, 
                       report_runtime_production: bool = False,
                       get_ess: bool = True):
    
    identifier = path.split("/")[-3]
    print(f"identifier is: {identifier}")
    
    data = np.load(path)
    chains = data["chains"]
    a, b = np.shape(chains)[0], np.shape(chains)[1]
    
    # Also show the runtimes of the production loop
    if report_runtime_production:
        new_path = path.replace("results_production.npz", "runtime_production.txt")
        
        time = np.loadtxt(new_path)
        print("Runtime production:", time)
        
    # Get the chains and the log prob
    data = np.load(path)
    chains = data["chains"]
    chains = np.reshape(chains, (int(a * b), 13))
    log_prob = data["log_prob"]
    log_prob = np.reshape(log_prob, (int(a * b),))
    
    # DEBUG: make a histogram of log prob:
    plt.hist(log_prob, bins=100)
    plt.savefig(f"./figures/hist_log_prob_{identifier}.png")
    plt.close()
    
    print(np.shape(chains))
    print(np.shape(log_prob))
    
    ess = compute_ess(chains, log_prob, method = "kish", relative = True)
    
    return a * b, ess

In [323]:
def my_avg(values):
    """Take the average but floor and int it"""
    
    return int(np.floor(np.mean(list(values))))

In [324]:
total_nb_samples_jim = {}
ess_jim = {}

for run_name in RUN_NAMES:
    total, ess = get_nb_jim_samples(paths_dict[run_name]["jim"])
    total_nb_samples_jim[run_name] = total
    ess_jim[run_name] = ess
    
print("Total samples for Jim")
print(total_nb_samples_jim)

print("ESS for Jim")
print(ess_jim)

identifier is: GW170817_TaylorF2
(220000, 13)
(220000,)
identifier is: GW170817_NRTidalv2
(220000, 13)
(220000,)
identifier is: GW190425_TaylorF2
(600000, 13)
(600000,)
identifier is: GW190425_NRTidalv2
(600000, 13)
(600000,)
Total samples for Jim
{'GW170817_TaylorF2': 220000, 'GW170817_NRTidalv2': 220000, 'GW190425_TaylorF2': 600000, 'GW190425_NRTidalv2': 600000}
ESS for Jim
{'GW170817_TaylorF2': 6.346562498299967e-05, 'GW170817_NRTidalv2': 6.333542479671354e-05, 'GW190425_TaylorF2': 2.669603090068237e-05, 'GW190425_NRTidalv2': 2.6565362696847772e-05}


## Bilby

In [325]:
def get_nb_bilby_samples(path: str):
    
    with open(path, "r") as f:
        data = json.load(f)
    posterior = data["posterior"]["content"]
    mc = posterior["chirp_mass"] # just taking the chirp mass samples as an example
    n = len(mc)
    
    log_likelihood = np.array(posterior["log_likelihood"])
    print(f"log_likelihood: {log_likelihood}")
    
    nested_samples = data["nested_samples"]['content'] # dict with keys: 'dec', 'ra', 'chirp_mass', 'mass_ratio', 'chi_1', 'chi_2', 'luminosity_distance', 'cos_theta_jn', 'psi', 'phase', 'lambda_1', 'lambda_2', 'geocent_time', 'weights', 'log_likelihood'
    nested_samples_weights = nested_samples["weights"]
    samples = data["samples"]['content']
    
    print(f"Nested samples and samples shape: {np.shape(nested_samples_weights)}, {np.shape(samples)[0]}")
    print(f"nested_samples_weights min and max: {np.min(nested_samples_weights)}, {np.max(nested_samples_weights)}")
    
    avg_ess = compute_ess(samples, nested_samples_weights, take_exp = False, method = "kish", relative = True)
    # avg_ess = compute_ess(samples, log_likelihood, take_exp = True, method = "arviz", relative = True)
        
    return n, avg_ess

### pBilby

In [326]:
total_nb_samples_bilby = {}
ess_bilby = {}

for run_name in RUN_NAMES:
    path = paths_dict[run_name]["bilby"]
    total, ess = get_nb_bilby_samples(path)
    
    total_nb_samples_bilby[run_name] = total
    ess_bilby[run_name] = ess
    
print("Total samples for Bilby")
print(total_nb_samples_bilby)

print("Average total samples for Bilby")
avg_total_nb_samples_bilby = my_avg(list(total_nb_samples_bilby.values()))
print(avg_total_nb_samples_bilby)

print("ESS for Bilby")
print(ess_bilby)

avg_ess_bilby = my_avg(list(ess_bilby.values()))
print("Average ESS for Bilby")
print(avg_ess_bilby)

log_likelihood: [547.00313918 548.69856653 549.43301019 ... 569.99966648 569.99966648
 569.99966648]
Nested samples and samples shape: (44782,), 44782
nested_samples_weights min and max: 0.0, 0.00016014352358832842
log_likelihood: [539.12220195 540.08677565 540.63059469 ... 561.13490206 561.13490206
 561.13490206]
Nested samples and samples shape: (45365,), 45365
nested_samples_weights min and max: 0.0, 0.00015848338278036988
log_likelihood: [65.05886983 66.21304446 66.81574675 ... 86.80823552 86.80823552
 86.80823552]
Nested samples and samples shape: (30782,), 30782
nested_samples_weights min and max: 0.0, 0.00014868191631388282
log_likelihood: [64.81474159 65.86292189 66.34509702 ... 85.93556712 85.93556712
 85.93556712]
Nested samples and samples shape: (30277,), 30277
nested_samples_weights min and max: 0.0, 0.00014336516234732138
Total samples for Bilby
{'GW170817_TaylorF2': 44782, 'GW170817_NRTidalv2': 45365, 'GW190425_TaylorF2': 30782, 'GW190425_NRTidalv2': 30277}
Average total

### Relative binning-Bilby

In [None]:
def get_nb_rb_bilby_samples(path: str):
    
    with h5py.File(path, "r") as f:
        posterior = f["posterior"]
        mc = posterior["chirp_mass"][()] # just taking the chirp mass samples as an example
        n = len(mc)
        
        my_ess_list = []
        for name in BILBY_VAR_NAMES:
            values = posterior[name][()]
            
            # NOTE: sometimes the json is broken and there are dicts instead of floats
            if isinstance(values[0], dict):
                new_values = [item['content'] for item in values]
                values = new_values

            values = np.array(values)

            ess = arviz.ess(values)
            # print(f"ESS for {name} production: {ess}")
            my_ess_list.append(ess)
                
        avg_ess = int(np.floor(np.mean(my_ess_list)))
    
    return n, avg_ess

In [None]:
rb_paths = ["../RB/gw170817_relbin_TaylorF2_result.hdf5", 
            "../RB/gw170817_relbin_result.hdf5",
            "../RB/gw190425_relbin_TaylorF2_result.hdf5",
            "../RB/gw190425_relbin_result.hdf5"]

total_nb_samples_rb_bilby = {}
ess_rb_bilby = {}

for path, run_name in zip(rb_paths, RUN_NAMES):
    print(run_name)
    total, ess = get_nb_rb_bilby_samples(path)

    total_nb_samples_rb_bilby[run_name] = total
    ess_rb_bilby[run_name] = ess
    
print("Total samples for RB Bilby")
print(total_nb_samples_rb_bilby)

print("Average total samples for RB Bilby")
avg_total_nb_samples_rb_bilby = my_avg(list(total_nb_samples_rb_bilby.values()))
print(avg_total_nb_samples_rb_bilby)

print("ESS for RB Bilby")
print(ess_rb_bilby)

avg_ess_rb_bilby = my_avg(list(ess_rb_bilby.values()))
print("Average ESS for RB Bilby")
print(avg_ess_rb_bilby)

GW170817_TaylorF2
GW170817_NRTidalv2
GW190425_TaylorF2


GW190425_NRTidalv2
Total samples for RB Bilby
{'GW170817_TaylorF2': 5258, 'GW170817_NRTidalv2': 5172, 'GW190425_TaylorF2': 6743, 'GW190425_NRTidalv2': 5172}
Average total samples for RB Bilby
5586
ESS for RB Bilby
{'GW170817_TaylorF2': 3274, 'GW170817_NRTidalv2': 3837, 'GW190425_TaylorF2': 3661, 'GW190425_NRTidalv2': 3837}
Average ESS for RB Bilby
3652


### ROQ-Bilby

In [None]:
roq_paths = ["/home/thibeau.wouters/TurboPE-BNS/ROQ/gw170817_ROQ_result.hdf5",
             "/home/thibeau.wouters/TurboPE-BNS/ROQ/gw190425_ROQ_result.hdf5"]

total_nb_samples_roq_bilby = {}
ess_roq_bilby = {}

for path, run_name in zip(rb_paths, roq_paths):
    print(run_name)
    total, ess = get_nb_rb_bilby_samples(path)

    total_nb_samples_roq_bilby[run_name] = total
    ess_roq_bilby[run_name] = ess
    
print("Total samples for ROQ Bilby")
print(total_nb_samples_roq_bilby)

print("Average total samples for ROQ Bilby")
avg_total_nb_samples_roq_bilby = my_avg(list(total_nb_samples_roq_bilby.values()))
print(avg_total_nb_samples_roq_bilby)

print("ESS for ROQ Bilby")
print(ess_roq_bilby)

avg_ess_roq_bilby = my_avg(list(ess_roq_bilby.values()))
print("Average ESS for ROQ Bilby")
print(avg_ess_roq_bilby)

/home/thibeau.wouters/TurboPE-BNS/ROQ/gw170817_ROQ_result.hdf5
/home/thibeau.wouters/TurboPE-BNS/ROQ/gw190425_ROQ_result.hdf5
Total samples for ROQ Bilby
{'/home/thibeau.wouters/TurboPE-BNS/ROQ/gw170817_ROQ_result.hdf5': 5258, '/home/thibeau.wouters/TurboPE-BNS/ROQ/gw190425_ROQ_result.hdf5': 5172}
Average total samples for ROQ Bilby
5215
ESS for ROQ Bilby
{'/home/thibeau.wouters/TurboPE-BNS/ROQ/gw170817_ROQ_result.hdf5': 3274, '/home/thibeau.wouters/TurboPE-BNS/ROQ/gw190425_ROQ_result.hdf5': 3837}
Average ESS for ROQ Bilby
3555


# Combine into a table

In [None]:
values = list(total_nb_samples_jim.values())
print((np.min(values), np.max(values)))

values = list(ess_jim.values())
print((np.min(values), np.max(values)))

(220000, 600000)
(0.0709587168064469, 0.12169029575884786)


In [None]:
def my_get_number(value):
    value = int(np.round(value))
    my_string = str(value)
    first, second = my_string[0], my_string[1]
    
    power = int(np.floor(np.log10(value)))
    
    return_string = r"{}.{} \times 10^{}".format(first, second, power)
    
    return return_string

In [None]:
latex_code = '& Number of samples & Number of effective samples \\\\\n \hline\\hline\n \\textsc{{Jim}} & ${}$ & ${}$ \\\\ \\hline \n \\textsc{{pBilby}} & ${}$ & ${}$ \\\\ \\hline \n \\textsc{{RB-Bilby}} & ${}$ & ${}$ \\\\ \\hline \n \\textsc{{ROQ-Bilby}} & ${}$ & ${}$ \\\\ \\hline\\hline'\
.format(my_get_number(np.mean(list(total_nb_samples_jim.values()))),
        my_get_number(np.mean(list(ess_jim.values()))),
        my_get_number(avg_total_nb_samples_bilby),
        my_get_number(avg_ess_bilby),
        my_get_number(avg_total_nb_samples_rb_bilby),
        my_get_number(avg_ess_rb_bilby),
        my_get_number(avg_total_nb_samples_roq_bilby),
        my_get_number(avg_ess_roq_bilby)
)

print(latex_code)

# my_get_number(np.min(list(total_nb_samples_jim.values()))) + " -- " + my_get_number(np.max(list(total_nb_samples_jim.values()))),
# my_get_number(np.min(list(ess_jim.values()))) + " -- " + my_get_number(np.max(list(ess_jim.values())))

IndexError: string index out of range

## Final table

Only quote ESS to make it more informed.

In [None]:
latex_code = '& Number of effective samples \\\\\n \hline\\hline\n \\textsc{{Jim}} & ${}$ \\\\ \\hline \n \\textsc{{pBilby}} & ${}$ \\\\ \\hline \n \\textsc{{RB-Bilby}} & ${}$ \\\\ \\hline \n \\textsc{{ROQ-Bilby}} & ${}$ \\\\ \\hline\\hline'\
.format(my_get_number(np.mean(list(ess_jim.values()))),
        my_get_number(avg_ess_bilby),
        my_get_number(avg_ess_rb_bilby),
        my_get_number(avg_ess_roq_bilby)
)

print(latex_code)

# my_get_number(np.min(list(total_nb_samples_jim.values()))) + " -- " + my_get_number(np.max(list(total_nb_samples_jim.values()))),
# my_get_number(np.min(list(ess_jim.values()))) + " -- " + my_get_number(np.max(list(ess_jim.values())))

& Number of effective samples \\
 \hline\hline
 \textsc{Jim} & $5.2 \times 10^3$ \\ \hline 
 \textsc{pBilby} & $4.5 \times 10^3$ \\ \hline 
 \textsc{RB-Bilby} & $3.6 \times 10^3$ \\ \hline 
 \textsc{ROQ-Bilby} & $3.5 \times 10^3$ \\ \hline\hline
