In [None]:
import os
import numpy as np
import pandas as pd

from glob import glob
from scipy.stats import lognorm
from scipy.interpolate import interp1d
from astropy.cosmology import Planck18 as cosmo
from frb.dm.igm import average_DM
from frb.dm import igm
from frb.dm import cosmic
from frb.dm import mcmc

from mockFRBhosts import draw_galaxies, observed_bands, draw_DM
from mockFRBhosts.mcmc_simulations import do_mcmc

In [None]:
# Directory that receives the MCMC posterior files; it is created if missing
# (an existing directory is left untouched).
outdir = '../Posteriors/'
os.makedirs(outdir, exist_ok=True)

In [None]:
# Collect the simulated FRB populations. The file stem is split at its first
# underscore: the part before it is the radio-survey model, the rest the redshift model.
pickles = sorted(glob('../Simulated_FRBs/*.pickle'))

survey_models = []
z_models = []
for pickle_path in pickles:
    stem = os.path.splitext(os.path.basename(pickle_path))[0]
    name_parts = stem.split('_', 1)
    survey_models.append(name_parts[0])
    z_models.append(name_parts[1])

# Show which files were found.
pickles

In [None]:
# Number of FRBs that should be used throughout (the first rows of the chosen file).
n_frbs = 1000

# Pick an FRB survey and redshift distribution by its position in `pickles`.
# A single index is used for both lists so `chosen` and `radio_survey` cannot drift apart.
file_index = 0
chosen = pickles[file_index]
radio_survey = survey_models[file_index]
print(chosen)

# Weight the galaxy choice depending on the file name: populations whose file stem
# ends in 'sfr' use weights='mstardot', all others weights='mstars_total'.
if os.path.splitext(chosen)[0].endswith('sfr'):
    weights = 'mstardot'
else:
    weights = 'mstars_total'

# NOTE(review): allow_pickle=True makes np.load unpickle the file, which can execute
# arbitrary code -- only load trusted simulation outputs.
frbs = np.load(chosen, allow_pickle=True)
print(frbs.shape[0], "FRBs in file, using only first", n_frbs)
frbs = frbs.iloc[:n_frbs].copy()

galaxies, snapnum = draw_galaxies(frbs['z'], weights=weights, seed=42)

# Order FRBs such that they correspond to galaxies at the same positions.
# The sort must be stable: the default quicksort may arbitrarily reorder FRBs that
# share a snapnum, which would pair them with the wrong galaxies.
# NOTE(review): assumes draw_galaxies returns galaxies grouped by ascending snapnum and,
# within a snapnum, in the FRBs' original order -- confirm against mockFRBhosts.
frbs['snapnum'] = snapnum
frbs = frbs.sort_values('snapnum', ascending=True, kind='stable')

n_bands_obs_SDSS, n_bands_obs_LSST, n_bands_obs_Euclid, n_bands_obs_DES = observed_bands(frbs, galaxies)

# .to_numpy() assigns positionally, so the values follow the snapnum-sorted row order
# regardless of the index carried by observed_bands' return values.
frbs['n_bands_SDSS'] = n_bands_obs_SDSS.to_numpy()
frbs['n_bands_LSST'] = n_bands_obs_LSST.to_numpy()
frbs['n_bands_Euclid'] = n_bands_obs_Euclid.to_numpy()
frbs['n_bands_DES'] = n_bands_obs_DES.to_numpy()

In [None]:
# Draw one dispersion measure (DM) per FRB from its redshift with a fixed seed.
# F, mu and lognorm_s are forwarded unchanged to mockFRBhosts.draw_DM
# (presumably IGM feedback and log-normal host-DM parameters -- see draw_DM's docs).
rng = np.random.default_rng(seed=42)
dm_draws = draw_DM(frbs['z'], F=0.2, mu=100, lognorm_s=1, rng=rng)
frbs['DM'] = dm_draws

In [None]:
# Optical survey whose band coverage defines an identified host.
survey = 'SDSS'
n_bands_obs = frbs[f'n_bands_{survey}']
n_bands = n_bands_obs.max()

# Keep FRBs whose host is seen in the largest number of bands observed for any FRB,
# then shuffle them reproducibly (the same generator is reused by later cells).
host_mask = n_bands_obs.to_numpy() == n_bands
rng = np.random.default_rng(seed=42)
frbs_w_host = frbs[host_mask].sample(frac=1, ignore_index=True, random_state=rng)

In [None]:
# MCMC settings: `cores` parallel chains, each producing `draws` samples.
cores = 20
draws = 1500

# Fit the complete shuffled host-identified sample.
frb_set = frbs_w_host
n_frbs = len(frb_set)
print(n_frbs)

In [None]:
# Run the MCMC on all FRBs with an identified host, unless a posterior file
# for this configuration is already on disk.
post_path = os.path.join(outdir, f"{radio_survey}_{survey}_{n_frbs}_zs_{cores}x{draws}_draws.nc")
if os.path.isfile(post_path):
    print("Already existing, skip.")
else:
    frb_set = frbs_w_host.iloc[:n_frbs]
    idata = do_mcmc(frb_set['z'], frb_set['DM'], draws=draws, cores=cores)
    idata.to_netcdf(post_path)

In [None]:
# Creat random samples from the FRB population to compare with.
rndm_sample1 = frbs.sample(n=n_frbs, ignore_index=True, random_state=rng)
rndm_sample2 = frbs.sample(n=n_frbs, ignore_index=True, random_state=rng)
rndm_sample3 = frbs.sample(n=n_frbs, ignore_index=True, random_state=rng)

for frb_set in [rndm_sample1, rndm_sample2, rndm_sample3]:
    i = 0
    post_path = os.path.join(outdir, f"{radio_survey}_{survey}_random_sample_of_{len(frb_set)}_zs_run_{i}_{cores}x{draws}_draws.nc")
    # Don't overwrite existing files.
    while os.path.isfile(post_path):
        i += 1
        post_path = os.path.join(outdir, f"{radio_survey}_{survey}_random_sample_of_{len(frb_set)}_zs_run_{i}_{cores}x{draws}_draws.nc")
        
    print(f"Will save to {post_path}")

    frb_set = frbs_w_host.iloc[:n_frbs]

    idata = do_mcmc(frb_set['z'], frb_set['DM'], draws=draws, cores=cores)
    idata.to_netcdf(post_path)


In [None]:
# Preview the sample sizes swept below: 30 log-spaced values from 10 to the number of
# host-identified FRBs, cast to int, with duplicates removed and sorted.
sample_sizes_preview = np.logspace(1, np.log10(len(frbs_w_host)), num=30, dtype=int)
print(sorted(set(sample_sizes_preview)))

In [None]:
# Rerun a shorter MCMC on growing prefixes of the shuffled host sample; sizes are
# log-spaced between 10 and all host-identified FRBs. Existing posterior files are skipped.
cores = 20
draws = 150

sample_sizes = sorted(set(np.logspace(1, np.log10(len(frbs_w_host)), 30, dtype=int)))
for n_frbs in sample_sizes:
    print(n_frbs)

    post_path = os.path.join(outdir, f"{radio_survey}_{survey}_{n_frbs}_zs_{cores}x{draws}_draws.nc")
    if not os.path.isfile(post_path):
        frb_set = frbs_w_host.iloc[:n_frbs]
        idata = do_mcmc(frb_set['z'], frb_set['DM'], draws=draws, cores=cores)
        idata.to_netcdf(post_path)