In [None]:
import os
import sys

import arviz as av
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats

from threeML import (
    DataList,
    JointLikelihood,
    display_spectrum_model_counts,
    update_logging_level,
)
update_logging_level("FATAL")
from astromodels import Cutoff_powerlaw, Model, PointSource
from cosmogrb.universe.survey import Survey

# Make modules in the parent directory importable (presumably `zusammen` and
# `cpl_prime`, imported right below). The membership check keeps re-runs of
# this cell from appending duplicate sys.path entries.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from zusammen import AnalysisBuilder, DataSet
from zusammen.spectral_plot import display_posterior_model_counts
from zusammen.stan_models.stan_model import get_model
from cpl_prime import Cutoff_powerlaw_prime

Load the survey and process its GRBs into a data set

In [None]:
survey = Survey.from_file('data/survey.h5')
# Build the analysis from the survey; see AnalysisBuilder for the exact
# meaning of each selection option.
ab = AnalysisBuilder(
    survey,
    use_bb=True,
    intervals_min=5,
    sig_min=10,
    all_above_limit=False,
)

In [None]:
# Write the analysis setup to a YAML file.
ab.write_yaml("test_proc.yml")

In [None]:
# Build the data set from the YAML file written above.
ds = DataSet.from_yaml("test_proc.yml")

In [None]:
# Save the data set to HDF5 so it can be reloaded without reprocessing.
ds.to_hdf5_file("sgrb.h5")

In [None]:
# Reload the data set from HDF5; later cells use this `ds`.
ds = DataSet.from_hdf5_file('sgrb.h5')

In [None]:
data = ds.to_stan_dict()
# Summarise each entry by its shape instead of dumping the full arrays; some
# entries (e.g. `response`, `observed_counts`) are large multi-dimensional arrays.
{key: np.shape(value) for key, value in data.items()}

In [None]:
# Subset of the Stan input intended to be attached to the ArviZ InferenceData
# as observed data (the `res.add_groups(...)` call further down is commented out).
data_arviz = {i: data[i] for i in ("grb_id", "observed_counts", "response", "z", "dl")}
data_arviz

In [None]:
# Largest observed count in each interval, followed by the overall maximum.
# Reuses `data` rather than rebuilding the Stan dict.
maxl = [counts.max() for counts in data["observed_counts"]]
for interval_idx, peak in enumerate(maxl):
    print(interval_idx, peak)

print(max(maxl))

In [None]:
%matplotlib widget
i,j = 0,0
plt.plot(ds.to_stan_dict()["response"][i,j].T @ ds.to_stan_dict()['observed_counts'][i,j])

Build the Stan model

In [None]:
# Look up the Stan model by name via zusammen's `get_model`.
m = get_model("cpl_simple_chunked_gc_relaxed")

In [None]:
# Presumably removes previously built artefacts so the next cell rebuilds — TODO confirm.
m.clean_model()

In [None]:
# Build (compile) the model; see zusammen's build_model for what `opt_exp` enables.
m.build_model(opt_exp=True)

In [None]:
data = ds.to_stan_dict()

n_threads = 2
n_chains = 2
n_warmup = 1000
n_sampling = 500
# Fixed sampler seed so that Restart & Run All reproduces the same chains.
seed = 1234

fit = m.model.sample(
    data=data,
    chains=n_chains,
    parallel_chains=n_chains,
    threads_per_chain=n_threads,
    # Initial values for per-interval spectral parameters and for the per-GRB /
    # population parameters; keys must match the Stan model's parameter names.
    inits= {
        'alpha': -1 * np.ones(data['N_intervals']),
        'log_ec': 2 * np.ones(data['N_intervals']),

        # Alternative initialisations for other parametrisations:
        # 'log_energy_flux': -7 * np.ones(data['N_intervals']),
        # 'log_K': -1 * np.ones(data['N_intervals']),

        # 'log_energy_flux_mu_raw': 0,
        # 'log_energy_flux_sigma': 1,
        # 'log_energy_flux_raw': np.zeros(data['N_intervals']),

        'gamma_sig_meta': 1,
        'log_Nrest_sig_meta': 1,
        'gamma_mu_meta': 1.5,
        'log_Nrest_mu_meta': 52,
        'gamma': 1.5 * np.ones(data['N_grbs']),
        'log_Nrest': 52 * np.ones(data['N_grbs']),
    },  # type: ignore
    seed=seed,
    iter_warmup=n_warmup,
    iter_sampling=n_sampling,
    # max_treedepth=12,
    # adapt_delta=0.99,
    # step_size=0.1,
    show_progress=True,
    # show_console=True,
    refresh=1
)

In [None]:
# Show the cmdstanpy fit object's summary repr.
fit

In [None]:
# diagnose() returns the diagnostic report as a string; print it so the
# newlines render instead of showing an escaped string repr.
print(fit.diagnose())

Import the Stan results into ArviZ

In [None]:
# Convert the cmdstanpy fit into an ArviZ InferenceData object.
res = av.from_cmdstanpy(fit)

In [None]:
# Optional: attach the Stan input subset `data_arviz` as observed data.
# Currently disabled.
# res.add_groups(observed_data=data_arviz)

In [None]:
# Persist the posterior so later sessions can skip sampling. Create the output
# directory first; writing into a missing directory fails.
os.makedirs("inference_data", exist_ok=True)
res.to_netcdf("inference_data/testing_gc_relaxed_data.nc")

In [None]:
# Reload the posterior written by the previous cell. This path must match the
# one passed to `res.to_netcdf(...)`; otherwise a different, possibly stale,
# file is analysed.
res = av.from_netcdf("inference_data/testing_gc_relaxed_data.nc")

In [None]:
# Overview of the InferenceData groups.
res

In [None]:
# Largest tree depth reached by the sampler; values at the sampler's
# max_treedepth limit indicate truncated trajectories.
res.sample_stats.tree_depth.max()

In [None]:
%matplotlib widget
# Trace plots for every posterior variable.
av.plot_trace(res)

In [None]:
%matplotlib widget
for i in range(data['N_intervals']):
    av.plot_trace(res, var_names=["ec"], coords={"ec_dim_0": i})

In [None]:
%matplotlib widget
# Pair plot of the posterior variables with divergent draws highlighted.
av.plot_pair(res, divergences=True)

In [None]:
# Divergence flag per draw (chains and draws stacked into one `sample` axis)
# and the total number of divergent draws.
div = res.sample_stats.diverging.stack(sample=("chain", "draw")).values
div.sum()

In [None]:
# Shape of the gamma posterior: (chain, draw, GRB index).
res.posterior.gamma.shape

Load the posterior samples of the model parameters

In [None]:
def _stack_draws(var):
    """Return `var` as a numpy array with (chain, draw) merged into a trailing `sample` axis."""
    return var.stack(sample=("chain", "draw")).values


# Posterior variables are laid out as (chain, draw, index): axis 2 gives the
# number of intervals / GRBs, and chain * draw gives the number of samples.
N_intervals = res.posterior.alpha.shape[2]
N_grbs = res.posterior.gamma.shape[2]
length = res.posterior.gamma.shape[0] * res.posterior.gamma.shape[1]

# Per-interval parameters, each (N_intervals, length). Each variable is stacked
# once instead of once per interval inside a loop.
alpha = np.asarray(_stack_draws(res.posterior.alpha), dtype=float)
log_ec = np.asarray(_stack_draws(res.posterior.log_ec), dtype=float)
K_prime = np.asarray(_stack_draws(res.posterior.K), dtype=float)
log_epeak = np.asarray(_stack_draws(res.posterior.log_epeak), dtype=float)
log_energy_flux = np.asarray(_stack_draws(res.posterior.log_energy_flux), dtype=float)
assert all(
    arr.shape == (N_intervals, length)
    for arr in (alpha, log_ec, K_prime, log_epeak, log_energy_flux)
), "per-interval posterior variables do not share the alpha layout"

# NOTE(review): K = E_c**(-alpha) with E_c = 10**log_ec; presumably a
# normalisation conversion between parametrisations — confirm.
K = (10**log_ec)**(-alpha)

# Per-GRB relation parameters, each (N_grbs, length).
gamma = np.asarray(_stack_draws(res.posterior.gamma), dtype=float)
log_Nrest = np.asarray(_stack_draws(res.posterior.log_Nrest), dtype=float)
assert gamma.shape == log_Nrest.shape == (N_grbs, length), "per-GRB posterior variables have unexpected shape"

# Divergence flag per stacked draw, shape (length,). A divergence belongs to the
# whole draw, not to one interval, so it is stored once; copying it into every
# interval row made `div.sum()` count each divergence N_intervals times.
div = _stack_draws(res.sample_stats.diverging)

# (K_prime, alpha, E_c) samples per interval, shape (N_intervals, 3, length).
# These feed the Cutoff_powerlaw_prime posterior-predictive plot; presumably
# this is that model's free-parameter order.
samples = np.stack((K_prime, alpha, 10.**log_ec), axis=1)

# 3ML data list for every interval.
dl = [ds.get_data_list_of_interval(interval_idx) for interval_idx in range(N_intervals)]

# log_epeak = np.log10(2 + alpha) + log_ec

In [None]:
# Total number of divergent draws, counted straight from sample_stats so each
# divergence is counted exactly once regardless of how `div` was laid out.
int(res.sample_stats.diverging.sum())

In [None]:
# Posterior means of log_Nrest and gamma for each GRB (mean over the sample axis).
np.mean(log_Nrest,1), np.mean(gamma,1)

In [None]:
bc = Cutoff_powerlaw_prime()

# Remove the default bounds on index, K and xc; presumably so that arbitrary
# posterior draws can be assigned to them — TODO confirm.
for _param in (bc.index, bc.K, bc.xc):
    _param.bounds = (None, None)

model = Model(PointSource("ps", 0, 0, spectral_shape=bc))

In [None]:
%matplotlib widget
plt.scatter(log_ec, alpha, alpha=0.1)

In [None]:
%matplotlib widget
plt.scatter(K, alpha, alpha=0.1)

In [None]:
%matplotlib widget
for id in range(1):#range(data["N_intervals"]):
    display_posterior_model_counts(
        dl[id][1], model, samples[id].T[::20], min_rate=1e-99
    )

In [None]:
# Posterior mean of log_epeak for each interval.
np.mean(log_epeak, 1)

In [None]:
%matplotlib widget
def gc_log(log_epeak, log_Nrest, gamma, z, dl):
    """Log10 of the relation flux  N_rest / (4 pi dl**2) * (E_peak * (1 + z) / 100)**gamma.

    Parameters
    ----------
    log_epeak : float or array
        log10 of the (observed-frame) peak energy; multiplied by (1 + z) inside.
    log_Nrest : float or array
        log10 of the relation normalisation N_rest.
    gamma : float or array
        Slope of the relation.
    z : float
        Redshift.
    dl : float
        Presumably the luminosity distance, in the unit matching N_rest.

    Returns
    -------
    float or array
        log10 of the predicted flux (broadcast over the inputs).
    """
    # Use log10(4 pi) exactly rather than the rounded 1.099 so this agrees with
    # the linear-space form to floating-point precision.
    return log_Nrest - (np.log10(4 * np.pi) + 2 * np.log10(dl)) + gamma * (np.log10(1 + z) + log_epeak - 2)

# Posterior-mean (log E_peak, log flux) per interval, against the relation
# curve of each GRB. The curves use fixed log_Nrest = 52 and gamma = 1.5, the
# values used to initialise the sampler. Fresh figure, so the widget backend
# does not draw into an earlier cell's still-open figure.
fig, ax = plt.subplots()
ax.scatter(np.mean(log_epeak, 1), np.mean(log_energy_flux, 1))
log_epeak_sort = np.linspace(0.5, 3)

# One (z, dl) pair per GRB: collapse consecutive entries with the same pair.
# Deduplicating z and dl together keeps the two lists aligned. Checking the
# count gives a clear error instead of an IndexError or mismatched curves.
z_dl = list(zip(data["z"], data["dl"]))
grb_z_dl = z_dl[:1] + [cur for prev, cur in zip(z_dl, z_dl[1:]) if cur != prev]
assert len(grb_z_dl) == data["N_grbs"], "could not recover one (z, dl) pair per GRB"
z = [zi for zi, _ in grb_z_dl]
d_l = [dli for _, dli in grb_z_dl]

for zi, dli in zip(z, d_l):
    ax.plot(log_epeak_sort, gc_log(log_epeak_sort, 52, 1.5, zi, dli))
ax.set(xlabel="log10 E_peak", ylabel="log10 energy flux")
plt.show()

In [None]:
# Maximum-likelihood 3ML fit of a plain cutoff power law to the first interval,
# for comparison with the Stan posterior. Use a separate name so the
# Cutoff_powerlaw_prime `model` used by the posterior-predictive cell is not
# overwritten.
cpl = Cutoff_powerlaw(piv=1, K=1e-1, xc=200)

cpl_model = Model(PointSource("ps", 0, 0, spectral_shape=cpl))

ba = JointLikelihood(cpl_model, DataList(*dl[0]))

In [None]:
# Bayesian estimates (with confidence intervals) of the mean, variance and std
# of the first GRB's gamma samples.
stats.bayes_mvs(gamma[0])

In [None]:
# Run the 3ML likelihood fit for interval 0.
ba.fit()

In [None]:
# 10**(mean of the log10 flux samples) for interval 0, i.e. the geometric mean
# of the flux; compare with the 3ML flux below.
10**log_energy_flux[0].mean()

In [None]:
# 3ML flux of the fitted model between 10 keV and 10e4 keV (= 1e5 keV).
# NOTE(review): confirm this band matches the one used for log_energy_flux in Stan.
ba.results.get_flux(10*u.keV, 10e4*u.keV)["flux"][0].value

In [None]:
# Plot the observed counts against the fitted model.
display_spectrum_model_counts(ba)

In [None]:
# Fit every interval independently with 3ML and derive E_peak and the flux,
# for comparison with the Stan posterior.
F_3ml, epeak_3ml = np.zeros(data["N_intervals"]), np.zeros(data["N_intervals"])

cpl = Cutoff_powerlaw(piv=100, K=1e-1, xc=200)
# Separate name so the Cutoff_powerlaw_prime `model` used for posterior-predictive
# plots is not overwritten. The same model object is reused across intervals,
# so each fit presumably starts from the previous interval's best fit.
cpl_model = Model(PointSource("ps", 0, 0, spectral_shape=cpl))

for i in range(data["N_intervals"]):
    # Local name: assigning to `dl` here would clobber the list of
    # per-interval data lists that other cells index as dl[k].
    interval_data = ds.get_data_list_of_interval(i)
    ba = JointLikelihood(cpl_model, DataList(*interval_data))
    ba.fit()
    best_fit = ba.results.get_data_frame()["value"]
    ec_3ml = best_fit["ps.spectrum.main.Cutoff_powerlaw.xc"]
    alpha_3ml = best_fit["ps.spectrum.main.Cutoff_powerlaw.index"]
    # nuFnu peak of a cutoff power law: E_peak = (2 + alpha) * E_c.
    epeak_3ml[i] = (2 + alpha_3ml) * ec_3ml
    F_3ml[i] = ba.results.get_flux(10*u.keV, 10e4*u.keV)["flux"][0].value

In [None]:
def _pow10_row_means(log_rows):
    """10 ** (mean of each row): the per-row geometric mean in linear space."""
    return np.array([10**(row.mean()) for row in log_rows])


epeak_mean = _pow10_row_means(log_epeak)
F_mean = _pow10_row_means(log_energy_flux)
epeak_mean, F_mean

In [None]:
# 3ML point estimates per interval, for comparison with the posterior means above.
epeak_3ml, F_3ml

In [None]:
%matplotlib widget

def gc(epeak, Nrest, gamma, z, dl):
    """Relation flux  N_rest / (4 pi dl**2) * (E_peak * (1 + z) / 100)**gamma.

    `epeak` is scaled by (1 + z) and normalised to 100 before applying the slope
    `gamma`; `dl` is presumably the luminosity distance. Works element-wise on
    numpy arrays.
    """
    dilution = 4 * np.pi * dl * dl
    spectral_term = (epeak * (1 + z) / 100)**gamma
    return Nrest / dilution * spectral_term

# Compare the 3ML point estimates and the posterior geometric means with the
# relation curve of every GRB (fixed N_rest = 1e52, gamma = 1.5). Fresh figure,
# so the widget backend does not draw into an earlier cell's still-open figure.
fig, ax = plt.subplots()
ax.scatter(epeak_3ml, F_3ml, label="3ML fit")

# Log-spaced grid so the curves are evenly resolved on log-log axes (the linear
# grid np.linspace(1, 1000) had only one point below ~20).
x = np.logspace(0, 3)
# One (z, dl) pair per GRB: collapse consecutive entries with the same pair,
# keeping z and dl aligned, and check that exactly N_grbs pairs remain.
z_dl = list(zip(data["z"], data["dl"]))
grb_z_dl = z_dl[:1] + [cur for prev, cur in zip(z_dl, z_dl[1:]) if cur != prev]
assert len(grb_z_dl) == data["N_grbs"], "could not recover one (z, dl) pair per GRB"
z = [zi for zi, _ in grb_z_dl]
d_l = [dli for _, dli in grb_z_dl]
for zi, dli in zip(z, d_l):
    ax.plot(x, gc(x, 1e52, 1.5, zi, dli))

ax.scatter(10**(np.mean(log_epeak, 1)), 10**(np.mean(log_energy_flux, 1)), label="posterior mean")

ax.set(xscale="log", yscale="log", xlabel="E_peak", ylabel="energy flux")
ax.legend();

In [None]:
%matplotlib widget
# KDE of log_epeak against log_energy_flux. Both are (N_intervals, length)
# arrays — TODO confirm how plot_kde handles 2-D input (presumably it pools
# all intervals).
av.plot_kde(log_epeak, log_energy_flux)

In [None]:
%matplotlib widget
for i in gamma:
    av.plot_kde(i)

In [None]:
%matplotlib widget
# KDE of the log_Nrest posterior for the first GRB.
av.plot_kde(log_Nrest[0])

In [None]:
# Mean deviation of R-hat from 1 for each posterior variable; values near 0
# indicate well-mixed chains. Subtracting the scalar 1 broadcasts over any
# shape, whereas np.ones(len(...)) raises TypeError on the 0-d R-hat of scalar
# parameters.
for var_name, rhat in av.rhat(res).items():
    print(var_name, float((np.asarray(rhat) - 1).mean()))