## Imports and Setup

In [None]:
import os
import sys

import arviz as av
import matplotlib.pyplot as plt
import numpy as np

from threeML import update_logging_level
update_logging_level("FATAL")

# Put the directory above the working directory on the import path (once),
# so modules located there can be imported by the cells below.
parent_dir = os.path.abspath(os.pardir)
if sys.path.count(parent_dir) == 0:
    sys.path.append(parent_dir)

from zusammen import DataSet

## ArviZ

### Load the data

In [None]:
# Folder and file stem of the HDF5 data set (".h5" is appended when loading).
data_folder, data_name = "real_data/", "data"

# Folder and file stem of the netCDF inference result (".nc" is appended when loading).
inference_folder, inference_name = "inference/", "real_relaxed_sig_5_500"

In [None]:
# Read the observed data set and convert it to the dictionary form used as Stan input.
ds = DataSet.from_hdf5_file(f"{data_folder}{data_name}.h5")
data = ds.to_stan_dict()

# Read the stored inference result.
res = av.from_netcdf(f"{inference_folder}{inference_name}.nc")

In [None]:
# Problem sizes taken from the Stan input dictionary.
N_intervals = data["N_intervals"]
N_grbs = data["N_grbs"]

# The two leading axes of the posterior arrays are treated as (chain, draw);
# `length` is the number of samples once both are flattened together.
chains, draws = res.posterior.gamma.shape[:2]
length = chains * draws


def _stacked_rows(name, n_rows):
    """Return posterior variable *name* as an ``(n_rows, length)`` float array.

    The chain and draw dimensions are stacked into one trailing sample axis.
    Stacking is done once per variable rather than once per row, which is
    what the previous per-index loops did.
    """
    stacked = res.posterior[name].stack(sample=("chain", "draw")).values
    rows = np.zeros((n_rows, length))
    rows[:] = stacked[:n_rows]
    return rows


# Per-interval parameters, one row per interval.
alpha = _stacked_rows("alpha", N_intervals)
log_ec = _stacked_rows("log_ec", N_intervals)
K = _stacked_rows("K", N_intervals)
log_epeak = _stacked_rows("log_epeak", N_intervals)
log_energy_flux = _stacked_rows("log_energy_flux", N_intervals)

# Per-GRB parameters, one row per GRB.
gamma = _stacked_rows("gamma", N_grbs)
log_Nrest = _stacked_rows("log_Nrest", N_grbs)

# Zero-initialised buffers that this cell does not fill; kept so that any
# other cell relying on these names keeps working.
K_prime = np.zeros((N_intervals, length))
div = np.zeros((N_intervals, length))
samples = np.zeros((N_intervals, 3, length))

z = data["z"]
dl = data["dl"]

# NOTE: the `gamma_mu_meta` summary and KDE cells further down need the first
# of these two lines to be enabled; with both disabled they raise NameError.
# gamma_mu_meta = res.posterior.gamma_mu_meta.stack(sample=("chain", "draw")).values
# log_Nrest_mu_meta = res.posterior.log_Nrest_mu_meta.stack(sample=("chain", "draw")).values

### HDI
$1 \sigma$

In [None]:
# Posterior mean of gamma, one value per row of `gamma`.
print(np.mean(gamma, 1))

# 68.3% highest-density interval per row of `gamma`. Computed once and reused
# below instead of calling `av.hdi` a second time on the same input.
hdi_1sigma = av.hdi(gamma.T, 0.683)
print(hdi_1sigma)

# Fraction of intervals that strictly contain gamma = 1.5.
bounds_1sigma = np.asarray(hdi_1sigma)
contains_1p5 = (bounds_1sigma[:, 0] < 1.5) & (bounds_1sigma[:, 1] > 1.5)
print(np.count_nonzero(contains_1p5) / gamma.shape[0])

$ 2 \sigma $

In [None]:
# 95.4% highest-density interval per row of `gamma`. Computed once and reused
# below instead of calling `av.hdi` a second time on the same input.
hdi_2sigma = av.hdi(gamma.T, 0.954)
print(hdi_2sigma)

# Fraction of intervals that strictly contain gamma = 1.5.
bounds_2sigma = np.asarray(hdi_2sigma)
contains_1p5 = (bounds_2sigma[:, 0] < 1.5) & (bounds_2sigma[:, 1] > 1.5)
print(np.count_nonzero(contains_1p5) / gamma.shape[0])

$\mu_\gamma$

In [None]:
# NOTE: `gamma_mu_meta` is only assigned by a line that is currently commented
# out in the posterior-extraction cell; enable it before running this cell.
print(np.mean(gamma_mu_meta))

# The probability mass belongs inside the `hdi` call. Previously it was passed
# to `print`, so the interval used ArviZ's default probability and a stray
# "0.683" was printed after it.
print(av.hdi(gamma_mu_meta, 0.683))

### Plots

In [None]:
%matplotlib widget

fig, ax = plt.subplots(1, 1)
for i in gamma:
    av.plot_kde(i, ax=ax)
# ax.set_xlim(1.377,1.545)
# ax.set_ylim(-.5,35)
ax.set_xlabel(r"$\gamma$")

In [None]:
# Kernel density estimate of `gamma_mu_meta`.
# NOTE: `gamma_mu_meta` is only assigned by a line that is currently commented
# out in the posterior-extraction cell; enable it before running this cell.
fig, ax = plt.subplots(1, 1)
av.plot_kde(gamma_mu_meta, ax=ax)
ax.set_xlim(1.377, 1.545)
ax.set_ylim(0, 35)
# The plotted quantity is gamma_mu_meta, so the axis is labelled mu_gamma
# (the label previously read mu_{log N_rest}).
ax.set_xlabel(r"$\mu_\gamma$")

In [None]:
%matplotlib widget

fig, ax = plt.subplots(1, 1)
av.plot_kde(log_epeak[0:6], log_energy_flux[0:6], ax=ax)
ax.plot([1.35, 2.4], 52 - (1.099 + 2 * np.log10(dl[0])) + 1.5 * (np.log10(1 + z[0]) + np.array([1.35, 2.4]) - 2), "b--")
# ax.set_xlim(1.375,2.375)
# ax.set_ylim(-5.7,-4.2)
# ax.set_xlabel(r"$\log E_\mathrm{peak}$")
# ax.set_ylabel(r"log F")

### Posterior Predictive Checks

In [None]:
from posterior_predictive_check import PPC

In [None]:
# Selection handed to the posterior predictive check.
ppc_selection = dict(
    chain=1,
    draws=500,
    grb=0,
    interval=1,
    detector=1,
    interval_in_grb=True,
)

p = PPC(data, res)
cenergies, ppc_sampled_counts = p.ppc(**ppc_selection)

# Reduce the sampled counts to the three summaries plotted in the next cell.
ppc_mu, ppc_1s, ppc_2s = PPC.ppc_summary(ppc_sampled_counts)

In [None]:
%matplotlib widget
plt.stairs(data["observed_counts"][1,1], cenergies, color="#dd8452")
# plt.stairs(ppc_mu, cenergies)
plt.stairs(ppc_1s[0], cenergies, baseline=ppc_1s[1], fill=True, color="#55a868", alpha=.5)
plt.stairs(ppc_2s[0], cenergies, baseline=ppc_2s[1], fill=True, color="#4c72b0", alpha=.2)

# plt.ylim(.8,1e2)
plt.xlabel("Energy")
plt.ylabel("Count Rate")
plt.loglog()