In [None]:
from pathlib import Path

# import numba as nb
import numpy as np
import scipy.stats as stats

# from natsort import natsorted

import matplotlib.pyplot as plt


# plt.style.use("mike")
# import warnings
# warnings.simplefilter("ignore")
# warnings.filterwarnings("ignore")


# import astropy.units as u

# import cmasher as cmr

green = "#33FF86"
purple = "#CE33FF"

%matplotlib widget
from cosmogrb.universe.survey import Survey

import os, sys
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from zusammen.stan_models.stan_model import get_model
from zusammen import AnalysisBuilder, DataSet
from zusammen.spectral_plot import display_posterior_model_counts

from threeML import update_logging_level

import arviz as av


update_logging_level("FATAL")


from astromodels import Band_Calderone, PointSource, Model


from threeML import JointLikelihood, DataList, display_spectrum_model_counts

from astromodels import Cutoff_powerlaw

import popsynth as ps

Import the survey and process the GRBs

In [None]:
# Load the survey stored in data/survey.h5 and hand it to the AnalysisBuilder.
# NOTE(review): the keyword meanings are not visible from this notebook
# (use_bb -> Bayesian-blocks binning?, sig_min -> minimum significance?) --
# confirm against the zusammen AnalysisBuilder documentation.
survey = Survey.from_file('data/survey.h5')
ab = AnalysisBuilder(survey, use_bb=True, intervals_min=2, sig_min=5, all_above_limit=False)

In [None]:
# Write the builder's output to test_proc.yml; the next cell reads it back.
ab.write_yaml("test_proc.yml")

In [None]:
# Build the DataSet from the YAML file written by the AnalysisBuilder.
ds = DataSet.from_yaml("test_proc.yml")

In [None]:
# Store the DataSet in sgrb.h5 so it can be reloaded without the YAML step.
ds.to_hdf5_file("sgrb.h5")

In [None]:
# Reload the DataSet from sgrb.h5; this rebinds `ds`.
ds = DataSet.from_hdf5_file('sgrb.h5')

In [None]:
# Convert the DataSet into the dictionary that is later passed to the Stan
# sampler as `data=`, and display the whole dictionary.
data = ds.to_stan_dict()
data

In [None]:
# Largest value of every entry of "observed_counts": collected in `maxl`,
# printed next to its index, followed by the overall maximum.
observed_counts = ds.to_stan_dict()["observed_counts"]
maxl = [counts.max() for counts in observed_counts]

for idx, peak in enumerate(maxl):
    print(idx, peak)

print(max(maxl))

In [None]:
%matplotlib widget
# Plot response[i, j].T @ observed_counts[i, j] for the entry at (i, j).
i, j = 0, 0
stan_data = ds.to_stan_dict()
response_ij = stan_data["response"][i, j]
counts_ij = stan_data['observed_counts'][i, j]
plt.plot(response_ij.T @ counts_ij)

Make Stan model

In [None]:
# Build the model registered as "test" and run it for one warm-up and one
# sampling iteration on a single chain, echoing the console output. No data
# dictionary is passed.
# NOTE(review): presumably a smoke test of the Stan toolchain -- confirm.
mt = get_model("test")
mt.build_model()
mt.model.sample(chains=1, show_console=True, iter_warmup=1, iter_sampling=1)

In [None]:
# Model used for the actual fit; it is cleaned, built and sampled in the cells below.
m = get_model("cpl_simple_chunked_gc")

In [None]:
# NOTE(review): presumably removes previously compiled artefacts so that the
# following build starts from scratch -- confirm against zusammen's stan_model.
m.clean_model()

In [None]:
# Build the model with opt_exp=True.
# NOTE(review): presumably switches on experimental compiler optimisations -- confirm.
m.build_model(opt_exp=True)

In [None]:
data = ds.to_stan_dict()

n_threads = 2
n_chains = 2

n_intervals = data['N_intervals']
n_grbs = data['N_grbs']

# Initial values for the sampler: per-interval parameters get one value per
# interval, per-GRB parameters one value per GRB, "*_meta" entries are scalars.
initial_values = {
    'alpha': np.full(n_intervals, -1.0),
    'log_ec': np.full(n_intervals, 2.0),

    # Alternative parameterisations, currently disabled:
    # 'log_energy_flux': -7 * np.ones(data['N_intervals']),
    # 'log_K': -1 * np.ones(data['N_intervals']),

    # 'log_energy_flux_mu_raw': 0,
    # 'log_energy_flux_sigma': 1,
    # 'log_energy_flux_raw': np.zeros(data['N_intervals']),

    'gamma_sig_meta': 1,
    'log_Nrest_sig_meta': 1,
    'gamma_mu_meta': 1.5,
    'log_Nrest_mu_meta': 52,
    'gamma': np.full(n_grbs, 1.5),
    'log_Nrest': np.full(n_grbs, 52.0),

    # 'z': 0.5 * np.ones(data['N_unknown'])
}

# All chains run in parallel, each with `n_threads` threads; the seed is
# fixed so the run can be repeated.
fit = m.model.sample(
    data=data,
    chains=n_chains,
    parallel_chains=n_chains,
    threads_per_chain=n_threads,
    inits=initial_values,  # type: ignore
    seed=1234,
    iter_warmup=1000,
    iter_sampling=500,
    max_treedepth=12,
    adapt_delta=0.9,
    # step_size=0.1,
    show_progress=True,
    # show_console=True,
    refresh=1
)

In [None]:
# Display the object returned by m.model.sample().
fit

In [None]:
# Run the sampler diagnostics on the fit and show the report.
fit.diagnose()

Import Stan results into arviz

In [None]:
# Convert the CmdStanPy fit into an ArviZ inference-data object.
res = av.from_cmdstanpy(fit)

In [None]:
# Persist the inference data as netCDF.
# NOTE(review): the inference_data/ directory presumably has to exist already -- confirm.
res.to_netcdf("inference_data/testing_analytic_gc.nc")

In [None]:
# Reload the stored inference data; this rebinds `res`, so the cells below
# can run from the file without repeating the sampling.
res = av.from_netcdf("inference_data/testing_analytic_gc.nc")

In [None]:
# Largest tree depth recorded in the sample statistics; compare with the
# max_treedepth=12 limit used for sampling.
res.sample_stats.tree_depth.max()

In [None]:
%matplotlib widget
# Trace plot of the inference data with no variable selection.
av.plot_trace(res)

In [None]:
%matplotlib widget
# One trace figure of "ec" per interval, selected through the ec_dim_0 coordinate.
interval_indices = range(data['N_intervals'])
for interval_idx in interval_indices:
    av.plot_trace(res, var_names=["ec"], coords={"ec_dim_0": interval_idx})

In [None]:
%matplotlib widget
# Pair plot of the inference data with divergent draws marked.
av.plot_pair(res, divergences=True)

In [None]:
# Merge the chain and draw dimensions of the divergence flags and count the
# flagged draws.
diverging_flags = res.sample_stats.diverging
div = diverging_flags.stack(sample=("chain", "draw")).values
div.sum()

In [None]:
# Display the inference-data object.
res

Load parameters  

In [None]:
def _flatten_draws(variable):
    """Return the values of ``variable`` with chain and draw merged into one ``sample`` dimension."""
    return variable.stack(sample=("chain", "draw")).values


n_intervals = data["N_intervals"]
n_grbs = data["N_grbs"]

# Stack every variable once. Doing this inside the loops would repeat the
# same work for each interval or GRB.
alpha_draws = _flatten_draws(res.posterior.alpha)
log_ec_draws = _flatten_draws(res.posterior.log_ec)
K_draws = _flatten_draws(res.posterior.K)
log_energy_flux_draws = _flatten_draws(res.posterior.log_energy_flux)
diverging_draws = _flatten_draws(res.sample_stats.diverging)

# Per-interval quantities: entry k holds the draws of interval k.
alpha = [alpha_draws[k] for k in range(n_intervals)]
log_ec = [log_ec_draws[k] for k in range(n_intervals)]
K = [K_draws[k] for k in range(n_intervals)]
log_energy_flux = [log_energy_flux_draws[k] for k in range(n_intervals)]

# The divergence flags are not indexed by interval: every entry of `div` is
# the same full array, repeated once per interval.
div = [diverging_draws for _ in range(n_intervals)]

# Rows of each sample matrix are K, alpha and 10**log_ec, in that order.
samples = [
    np.vstack((K[k], alpha[k], 10.0 ** log_ec[k])) for k in range(n_intervals)
]

dl = [ds.get_data_list_of_interval(k) for k in range(n_intervals)]

# Per-GRB quantities: entry k holds the draws of GRB k.
log_epeak_draws = _flatten_draws(res.posterior.log_epeak)
gamma_draws = _flatten_draws(res.posterior.gamma)
log_Nrest_draws = _flatten_draws(res.posterior.log_Nrest)

log_epeak = [log_epeak_draws[k] for k in range(n_grbs)]
gamma = [gamma_draws[k] for k in range(n_grbs)]
log_Nrest = [log_Nrest_draws[k] for k in range(n_grbs)]

In [None]:
# Mean over all draws of gamma for the GRB at index 1.
gamma[1].mean()

In [None]:
# Number of divergent draws. Every entry of `div` holds the same full array
# of divergence flags, so the index is arbitrary; index 5 needs at least six
# intervals.
div[5].sum()

In [None]:
# Cutoff power law with piv=100 whose index, K and xc bounds are all removed,
# wrapped in a point source named "ps" at position (0, 0).
bc = Cutoff_powerlaw(piv=100)

for unbounded in (bc.index, bc.K, bc.xc):
    unbounded.bounds = (None, None)

point_source = PointSource("ps", 0, 0, spectral_shape=bc)
model = Model(point_source)

In [None]:
%matplotlib widget
# Posterior draws of log_ec against alpha for all intervals, drawn with low opacity.
plt.scatter(log_ec, alpha, alpha=0.1)

In [None]:
%matplotlib widget
# Posterior draws of K against alpha for all intervals, drawn with low opacity.
plt.scatter(K, alpha, alpha=0.1)

In [None]:
%matplotlib widget
# Model counts from the posterior samples, one call per interval. Only the
# entry at position 1 of each interval's data list is shown, using every 20th
# posterior sample.
for interval_idx in range(data["N_intervals"]):
    plugin = dl[interval_idx][1]
    thinned_samples = samples[interval_idx].T[::20]
    display_posterior_model_counts(
        plugin, model, thinned_samples, min_rate=1e-99, shade=False
    )

In [None]:
%matplotlib widget

def gc(log_epeak, log_Nrest, gamma, z, dl):
    """Evaluate the relation used to compare peak energy with energy flux.

    Computes
    ``log_Nrest - (1.099 + 2*log10(dl)) + gamma*(log10(1 + z) + log_epeak - 2)``.
    Scalars and NumPy arrays are both accepted and broadcast together.

    NOTE(review): 1.099 is presumably log10(4*pi) and the ``- 2`` a pivot of
    100 in the units of the peak energy -- confirm against the Stan model.
    """
    log_distance_term = 1.099 + 2 * np.log10(dl)
    log_scaled_epeak = np.log10(1 + z) + log_epeak - 2
    return log_Nrest - log_distance_term + gamma * log_scaled_epeak

# log_Nrest and gamma are Python lists of per-GRB draw arrays (filled in the
# "Load parameters" cell), so they have no .mean() method. np.mean averages
# over every draw of every GRB, and also works once gamma has been rebound
# to a scalar further down.
# NOTE(review): log_epeak has one row per GRB while log_energy_flux has one
# row per interval -- confirm the two line up before trusting this figure.
log_epeak_draws = np.asarray(log_epeak)
plt.plot(log_epeak_draws, log_energy_flux)
plt.plot(
    log_epeak_draws,
    gc(log_epeak_draws, np.mean(log_Nrest), np.mean(gamma), data["z"][0], data["dl"][0]),
)

In [None]:
# Cutoff power law used for a stand-alone likelihood fit of one interval.
cpl = Cutoff_powerlaw(piv=100,K=1e-1,xc=200)


# Data list of the interval with index 2. This rebinds `dl`, which the
# "Load parameters" cell filled with one data list per interval.
dl = ds.get_data_list_of_interval(2)

# Rebinds `model`, previously the unbounded model used for the posterior plots.
model = Model(PointSource("ps",0,0, spectral_shape=cpl))

ba = JointLikelihood(model,DataList(*dl))

In [None]:
# Run the joint-likelihood fit set up in the previous cell.
ba.fit()

In [None]:
# Show the model counts for the fitted analysis object.
display_spectrum_model_counts(ba)

In [None]:
# Count spectrum of the first entry of the interval's data list. `dl` must be
# the single-interval data list here, not the per-interval list of lists.
dl[0].view_count_spectrum()

In [None]:
# Show the parameters of the spectral function used in the fit above.
cpl.parameters

In [None]:
import astropy.units as u
# Flux of the fitted model between 10 keV and 10e4 keV (= 1e5 keV); takes the
# value of the first entry of the "flux" column.
# NOTE(review): confirm that 10e4 (1e5) keV is the intended upper bound rather than 1e4 keV.
ba.results.get_flux(10*u.keV, 10e4*u.keV)["flux"][0].value

In [None]:
# "value" column of the fit results, indexed by parameter path.
ba.results.get_data_frame()["value"]

In [None]:
# Fit a cutoff power law to every interval and record the peak energy
# (2 + index) * xc and the 10 keV - 10e4 keV flux of each fit.
F, Epeak = [], []

# One spectral function shared by all fits.
# NOTE(review): each fit therefore presumably starts from the values left by
# the previous one -- confirm this is intended.
cpl = Cutoff_powerlaw(piv=100,K=1e-1,xc=200)

for i in range(data["N_intervals"]):
    # Fetch the data of interval i. A hard-coded interval here would refit the
    # same data N_intervals times.
    interval_plugins = ds.get_data_list_of_interval(i)
    model = Model(PointSource("ps",0,0, spectral_shape=cpl))
    ba = JointLikelihood(model,DataList(*interval_plugins))
    ba.fit()

    # Read the results table once. The fitted index gets its own name so that
    # the list of posterior draws called `alpha` is not overwritten.
    best_fit = ba.results.get_data_frame()["value"]
    Ec = best_fit["ps.spectrum.main.Cutoff_powerlaw.xc"]
    index = best_fit["ps.spectrum.main.Cutoff_powerlaw.index"]

    Epeak.append((2 + index) * Ec)
    F.append(ba.results.get_flux(10*u.keV, 10e4*u.keV)["flux"][0].value)

In [None]:
# Drop every fit whose peak energy is negative, together with its flux.
# The lists are rebuilt rather than edited with .remove() inside the loop:
# removing while iterating skips the element after each removal, and
# .remove() deletes the first equal value instead of the entry at index i.
kept = [(peak, flux) for peak, flux in zip(Epeak, F) if not peak < 0]
Epeak = [peak for peak, _ in kept]
F = [flux for _, flux in kept]
len(F)

In [None]:
from scipy.optimize import curve_fit

def power_law(x, a, b):
    """Power law ``a * (x / 100)**b``."""
    return a * (x/100)**b


# Fit the power law to flux against peak energy; the covariance matrix is
# discarded. This rebinds `gamma` to the fitted exponent.
best_fit_params, _ = curve_fit(power_law, Epeak, F)
a, gamma = best_fit_params
gamma

In [None]:
# Order the fluxes by their peak energy, then sort the peak energies
# themselves so both lists stay aligned.
epeak_flux_pairs = sorted(zip(Epeak, F))
F = [flux for _, flux in epeak_flux_pairs]
Epeak.sort()

In [None]:
%matplotlib widget
# Fitted fluxes and the power law a * (E / 100)**gamma against peak energy,
# on logarithmic axes.
power_law_flux = [a * (peak/100)**gamma for peak in Epeak]
plt.plot(Epeak, F)
plt.plot(Epeak, power_law_flux)
plt.loglog()