# In [3]:
import os
# Work from the peregrine experiment directory: the figures saved later in
# this cell use relative file names, so they are written here.
os.chdir('/home/scur2012/Thesis/master-thesis/experiments/peregrine')
import zarr
import pickle
import torch

import swyft.lightning as sl

# Root directory holding the bhardwaj2023 simulation / log-ratio outputs.
data_directory = '/scratch-shared/scur2012/peregrine_data/bhardwaj2023'

# Single-run defaults; the plotting loop below rebinds both names.
run_name, rnd_id = 'lowSNR', 1

# Ground-truth values of the injected signal. They are looked up by the
# parameter names read from the param_idxs_*.txt files and drawn as dashed
# reference lines; 'luminosity_distance' is overwritten per run by the loop.
injection = {
    'mass_1': 39.536,
    'mass_2': 34.872,
    'mass_ratio': 0.8858,
    'chirp_mass': 32.14,
    'luminosity_distance': 200,
    'dec': 0.071,
    'ra': 5.556,
    'theta_jn': 0.4432,
    'psi': 1.100,
    'phase': 5.089,
    'tilt_1': 1.497,
    'tilt_2': 1.102,
    'a_1': 0.9702,
    'a_2': 0.8118,
    'phi_12': 6.220,
    'phi_jl': 1.885,
    'geocent_time': 0.0,
}

# Plotting dependencies are imported once here instead of on every iteration
# of the round loop.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

N_ROUNDS = 8        # rounds R1..R8 are plotted for each run
N_PLOT_PARAMS = 15  # parameters shown; must fit the 3 x 5 subplot grid
N_BINS = 30         # histogram bins for both posterior and prior


def _plot_hist_with_spline(ax, sample, weights, label, color):
    """Draw a density histogram of ``sample`` plus a cubic-spline outline on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    sample : 1-D array of values to histogram.
    weights : per-sample weights forwarded to ``np.histogram``, or None.
    label : legend label attached to the bars.
    color : colour shared by the bars and the spline so the pair matches.
    """
    h, bins = np.histogram(sample, weights=weights, density=True, bins=N_BINS)
    bin_centres = 0.5 * (bins[1:] + bins[:-1])
    ax.bar(
        bin_centres,
        h,
        width=bin_centres[1] - bin_centres[0],
        alpha=0.3,
        color=color,
        label=label,
    )
    # A cubic interpolant through the bin heights can dip below zero between
    # bins. Those stretches are set to NaN so the line is broken there instead
    # of being bridged by a straight segment.
    fit = interp1d(bin_centres, h, kind=3, fill_value="extrapolate")
    x_grid = np.linspace(bin_centres[0], bin_centres[-1], 1000)
    y_grid = fit(x_grid)
    ax.plot(x_grid, np.where(y_grid > 0, y_grid, np.nan), color=color)


for run_name in ['lowSNR', 'highSNR']:
    injection['luminosity_distance'] = 200 if run_name == 'highSNR' else 900

    # Parameter names depend only on the run, so read them once per run.
    # Column 0 holds the names and column 1 is used as the index.
    # NOTE(review): column 1 is presumably each parameter's column position in
    # logratios / z_total; sorting by it keeps names aligned with p_id even if
    # the file is not stored in index order -- confirm against the writer.
    parameter_names = (
        pd.read_table(
            f'{data_directory}/param_idxs_{run_name}.txt',
            sep=r'\s+',
            usecols=[0, 1],
            header=None,
        )
        .set_index(1)
        .sort_index()
        .values[:, 0]
    )

    for rnd_id in range(1, N_ROUNDS + 1):
        # Prior: simulated samples for this round. mode='r' makes a wrong path
        # fail loudly; zarr's default mode 'a' would silently create an empty
        # store at that location.
        simulation_results = zarr.convenience.open(
            f'{data_directory}/simulations_{run_name}_R{rnd_id}', mode='r'
        )
        z_total = simulation_results['data']['z_total']

        # Posterior: pickled log-ratio results exposing .logratios and .params.
        # pickle can execute arbitrary code on load -- only open trusted files.
        logratio_path = f'{data_directory}/logratios_{run_name}/logratios_R{rnd_id}'
        with open(logratio_path, 'rb') as f:
            logratio_data = pickle.load(f)

        fig = plt.figure(figsize=(30, 20))
        try:
            for p_id in range(N_PLOT_PARAMS):
                ax = fig.add_subplot(3, 5, p_id + 1)
                ax.set_xlabel(parameter_names[p_id])

                # Posterior: samples reweighted by exp(logratio). The max is
                # subtracted first so exp cannot overflow; the constant factor
                # cancels because the histogram is density-normalised.
                logratios = np.asarray(logratio_data.logratios[:, p_id])
                _plot_hist_with_spline(
                    ax,
                    np.asarray(logratio_data.params[:, p_id, 0]),
                    weights=np.exp(logratios - logratios.max()),
                    label=f'Posterior (R{rnd_id})',
                    color='C0',
                )

                # Prior: unweighted simulated samples.
                _plot_hist_with_spline(
                    ax,
                    z_total[:, p_id],
                    weights=None,
                    label=f'Prior (R{rnd_id})',
                    color='C1',
                )

                if p_id == 0:
                    ax.legend()

                # Ground-truth injected value.
                ax.axvline(
                    injection[parameter_names[p_id]], 0, 1,
                    color='k', linestyle='--', linewidth=1,
                )

            fig.savefig(
                f'Posteriors_{run_name}_round_{rnd_id}.png',
                dpi=600,
                bbox_inches='tight',
            )
        finally:
            # Release the figure even if a round fails, so open figures do not
            # accumulate across rounds.
            plt.close(fig)