In [None]:
from pathlib import Path

import numba as nb
import numpy as np
import scipy.stats as stats

from natsort import natsorted

import matplotlib.pyplot as plt


#plt.style.use("mike")
import warnings
warnings.simplefilter("ignore")
warnings.filterwarnings('ignore')


import astropy.units as u

import cmasher as cmr

green = "#33FF86"
purple = "#CE33FF"

%matplotlib widget
from cosmogrb.universe.survey import Survey

import os, sys
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from zusammen.stan_models.stan_model import get_model
from zusammen import AnalysisBuilder, DataSet
from zusammen.spectral_plot import display_posterior_model_counts

from threeML import update_logging_level

import arviz as av


update_logging_level("FATAL")


from astromodels import Band_Calderone, PointSource, Model


from threeML import JointLikelihood, DataList, display_spectrum_model_counts

from astromodels import Cutoff_powerlaw

import popsynth as ps

Load the GRB survey and build the processed data set

In [None]:
# Load the GRB survey from disk and wrap it in an AnalysisBuilder.
# NOTE(review): `use_bb=True` presumably selects Bayesian-block intervals — confirm in zusammen.
survey = Survey.from_file('data/survey.h5')
ab = AnalysisBuilder(survey, use_bb=True)

In [None]:
# Write the analysis configuration to YAML; the next cell rebuilds a DataSet from it.
ab.write_yaml("test_proc.yml")

In [None]:
# Build the DataSet from the YAML configuration written above.
ds = DataSet.from_yaml("test_proc.yml")

In [None]:
# NOTE(review): the next cell writes to the same "sgrb.h5" path with sig_threshold=40,
# so this default-threshold write appears to be superseded — confirm and drop this cell.
ds.to_hdf5_file("sgrb.h5")

In [None]:
# Write the DataSet to HDF5 with sig_threshold=40.
# NOTE(review): presumably intervals below this significance are excluded — confirm.
ds.to_hdf5_file("sgrb.h5", sig_threshold=40)

In [None]:
# Reload the DataSet from the HDF5 file; this rebinds `ds`, replacing the YAML-built one.
ds = DataSet.from_hdf5_file('sgrb.h5')

In [None]:
# Inspect the Stan input without dumping every array: show the shape of each entry.
# NOTE(review): assumes to_stan_dict() returns a mapping of name -> array-like/scalar.
{name: np.shape(value) for name, value in ds.to_stan_dict().items()}

Build the Stan model

In [None]:
# Fetch the Stan model registered under the name "cpl_simple_chunked".
m = get_model("cpl_simple_chunked")

In [None]:
# NOTE(review): presumably removes previously built artefacts so the next build starts fresh — confirm.
m.clean_model()

In [None]:
# Build/compile the Stan model.
# NOTE(review): the effect of opt_exp=True is not visible here — confirm in zusammen.stan_models.
m.build_model(opt_exp=True)

In [None]:
# Sample the posterior with CmdStanPy.
data = ds.to_stan_dict()

n_threads = 16  # passed to CmdStan as threads_per_chain
n_chains = 2    # all chains run in parallel

# Sampler configuration gathered in one place so it is easy to tweak and record.
sampler_settings = dict(
    chains=n_chains,
    parallel_chains=n_chains,
    threads_per_chain=n_threads,
    seed=1234,            # fixed seed for reproducible draws
    iter_warmup=2000,
    iter_sampling=1000,
    max_treedepth=15,
    adapt_delta=0.9,
    show_progress=True,
)

fit = m.model.sample(data=data, **sampler_settings)

Convert the Stan fit into an ArviZ InferenceData object

In [None]:
# Convert the CmdStanPy fit into an ArviZ InferenceData object.
res = av.from_cmdstanpy(fit)

In [None]:
# Persist the posterior so later sessions can skip the expensive sampling step.
# The netCDF writer does not create missing directories, so make sure it exists.
inference_file = Path("inference_data") / "testing_wo_gc_3.nc"
inference_file.parent.mkdir(parents=True, exist_ok=True)
res.to_netcdf(str(inference_file))

In [None]:
# Reload the posterior saved by the previous cell. The path must match the file
# just written; reading a different file would silently analyse a stale run
# rather than the fit sampled in this session.
res = av.from_netcdf("inference_data/testing_wo_gc_3.nc")

In [None]:
# Inspect the sampler diagnostics group (includes the `diverging` flags used below).
res.sample_stats

In [None]:
%matplotlib widget
# Trace plots of the posterior variables in `res`.
av.plot_trace(res)

In [None]:
%matplotlib widget
# Pairwise posterior scatter with divergent transitions highlighted.
# NOTE(review): with many per-interval variables this can be very slow; consider var_names=[...].
av.plot_pair(res, divergences=True)

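Load the posterior draws for a single interval:

- $\alpha$: photon index of the cutoff power law (CPL)
- $\log E_c$: base-10 logarithm of the cutoff energy $E_c$ (the code uses $E_c = 10^{\log E_c}$)
- $\log E_p$: logarithm of the peak energy (stored separately as `log_epeak`)
- $K$: normalization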

In [None]:
# Index of the interval to inspect (named so it does not shadow the builtin ``id``).
interval_id = 0


def _interval_draws(var_name, interval=None):
    """Return all posterior draws of ``var_name`` for one interval.

    Chains and draws are flattened into a trailing ``sample`` dimension, and the
    leading remaining axis is indexed with ``interval`` (default: ``interval_id``).
    NOTE(review): assumes the variable carries a per-interval dimension; for a
    population-level scalar this would select a single draw instead.
    """
    if interval is None:
        interval = interval_id
    stacked = res.posterior[var_name].stack(sample=("chain", "draw"))
    return stacked.values[interval]


alpha = _interval_draws("alpha")
log_ec = _interval_draws("log_ec")
K = _interval_draws("K")
log_epeak = _interval_draws("log_epeak")
energy_flux = _interval_draws("energy_flux")
gamma = _interval_draws("gamma")
log_Nrest = _interval_draws("log_Nrest")

# Divergence flag for every draw (not indexed by interval).
div = res.sample_stats.diverging.stack(sample=("chain", "draw")).values

# Rows: K, alpha and the cutoff energy in linear units; one column per draw.
samples = np.vstack((K, alpha, 10.**log_ec))

dl = ds.get_data_list_of_interval(interval_id)

In [None]:
# Display the log_Nrest draws extracted above.
log_Nrest

In [None]:
# Number of divergent transitions across all chains (ideally 0 for a trustworthy fit).
div.sum()

In [None]:
bc = Cutoff_powerlaw(piv=100)

# Clear the default bounds on the spectral parameters, presumably so arbitrary
# posterior draws can be assigned to the model without tripping bound checks.
for _par in (bc.index, bc.K, bc.xc):
    _par.bounds = (None, None)
del _par

model = Model(PointSource("ps", 0, 0, spectral_shape=bc))

In [None]:
%matplotlib widget
#fig, ax = plt.subplots()

plt.scatter(log_epeak, alpha, alpha=0.1)

In [None]:
# Posterior correlation between the normalization and the CPL index for this interval.
fig, ax = plt.subplots()

ax.scatter(K, alpha, alpha=0.1)
ax.set(xlabel=r"$K$", ylabel=r"$\alpha$");

In [None]:
# Overlay model counts from posterior draws on the data, using every 20th draw.
# Each row of samples.T is one draw of (K, alpha, 10**log_ec).
# NOTE(review): assumes this order matches the model's free-parameter order
# (K, index, xc) — confirm in display_posterior_model_counts.
# NOTE(review): dl[1] is the second plugin of the interval's data list; which detector
# that is cannot be seen here.
display_posterior_model_counts(
    dl[1], model, samples.T[::20], min_rate=1e-99, shade=False
)

In [None]:
# Maximum-likelihood fit of one interval with threeML's JointLikelihood.
fit_interval = 2  # index of the interval to fit

cpl = Cutoff_powerlaw(piv=100, K=1e-1, xc=200)

# NOTE: this rebinds `dl`, which the posterior cells also use; re-run the
# posterior-extraction cell before revisiting those plots.
dl = ds.get_data_list_of_interval(fit_interval)

# Use a separate name so the global `model` used for posterior plots is not clobbered.
ml_model = Model(PointSource("ps", 0, 0, spectral_shape=cpl))

ba = JointLikelihood(ml_model, DataList(*dl))

In [None]:
# Run the maximum-likelihood fit.
ba.fit()

In [None]:
# Compare the fitted model counts with the observed counts.
display_spectrum_model_counts(ba)

In [None]:
# View the count spectrum of the first plugin in the most recently built `dl`.
dl[0].view_count_spectrum()

In [None]:
# Display the CPL shape; after ba.fit() its parameters presumably hold the best-fit values.
cpl