## Imports and Setup

In [None]:
import os
import sys

import arviz as av
import matplotlib.pyplot as plt
import numpy as np

from threeML import update_logging_level
update_logging_level("FATAL")

# Put the parent of the current working directory on the import path
# (once), presumably so the local modules imported below can be found.
parent_dir = os.path.abspath(os.pardir)
if sys.path.count(parent_dir) == 0:
    sys.path.append(parent_dir)

from zusammen import DataSet
from posterior_predictive_check import PPC

## Arviz

### Load the data

In [None]:
# Active configuration: which data set and which inference result to load.
# The loading cell builds the paths as folder + name + ".h5" (data) and
# folder + name + ".nc" (inference result).
data_folder = "simulation/"
data_name = "data_2_sig_5"
# NOTE(review): survey_name is only referenced from a commented-out line in
# the visible part of the notebook.
survey_name = "survey_2"

inference_folder = "inference/"
# Exactly one inference_name should be active. The substrings "relaxed",
# "global" and "int" in the name select the model variant further down.
inference_name = "simulated_2_sig_5_1000"
# inference_name = "simulated_int_sig_5_1000"
# inference_name = "simulated_relaxed_2_sig_5_1000"
# inference_name = "simulated_global_2_sig_5_1000"


# Alternative configuration for the real data set; a data_folder containing
# "real" switches the model label to "real".
# data_folder = "real_data/"
# data_name = "data"

# inference_folder = "inference/"
# inference_name = "real_sig_5_1000"
# inference_name = "real_int_sig_5_1000"
# inference_name = "real_relaxed_3_sig_5_1000"
# inference_name = "real_global_sig_5_1000"

In [None]:
# Base model label comes from the data folder name.
model = "real" if "real" in data_folder else "simulated"
# (disabled, simulated data only)
# survey = Survey.from_file(data_folder + survey_name + ".h5")

# Append the suffix of the first variant keyword found in the inference
# name; the order below fixes the precedence.
for variant in ("relaxed", "global", "int"):
    if variant in inference_name:
        model += "_" + variant
        break

# Load the data set, convert it to the Stan input dictionary, and load the
# stored inference result.
ds = DataSet.from_hdf5_file(f"{data_folder}{data_name}.h5")
data = ds.to_stan_dict()
res = av.from_netcdf(f"{inference_folder}{inference_name}.nc")

# Sanity check: data and posterior must agree on the number of intervals.
assert data["N_intervals"] == res.posterior.alpha.shape[2]
model

In [None]:
def _stack_samples(var_name, n_rows=None):
    """Return posterior variable *var_name* with chain and draw merged.

    The ``chain`` and ``draw`` dimensions of ``res.posterior[var_name]`` are
    stacked into a single trailing ``sample`` axis. When *n_rows* is given,
    the first *n_rows* entries along the leading axis are copied into a new
    float array; otherwise the stacked values are returned unchanged.
    """
    stacked = res.posterior[var_name].stack(sample=("chain", "draw")).values
    if n_rows is None:
        return stacked
    return np.array(stacked[:n_rows], dtype=float)


# Sizes read from the posterior arrays: axes 0 and 1 are used as chain and
# draw, axis 2 as the per-interval (alpha) or per-GRB (gamma) axis.
N_intervals = res.posterior.alpha.shape[2]
N_grbs = res.posterior.gamma.shape[2]
chains = res.posterior.gamma.shape[0]
draws = res.posterior.gamma.shape[1]
length = chains * draws

# Per-interval quantities, one row per interval and one column per sample.
# Each variable is stacked once instead of once per row, and no loop
# variable named `id` is left behind to shadow the builtin.
alpha = _stack_samples("alpha", N_intervals)
log_ec = _stack_samples("log_ec", N_intervals)
K = _stack_samples("K", N_intervals)
log_epeak = _stack_samples("log_epeak", N_intervals)
log_energy_flux = _stack_samples("log_energy_flux", N_intervals)

# Per-GRB quantities, one row per GRB.
gamma = _stack_samples("gamma", N_grbs)
log_Nrest = _stack_samples("log_Nrest", N_grbs)

# Zero-initialised arrays that this cell allocates but does not fill.
K_prime = np.zeros((N_intervals, length))
div = np.zeros((N_intervals, length))
samples = np.zeros((N_intervals, 3, length))

z = data["z"]
dl = data["dl"]

# The *_mu_meta variables are only read for models that are neither
# "relaxed" nor "global"; int_scatter only for "int" models.
if "relaxed" not in inference_name and "global" not in inference_name:
    gamma_mu_meta = _stack_samples("gamma_mu_meta")
    log_Nrest_mu_meta = _stack_samples("log_Nrest_mu_meta")

if "int" in inference_name:
    int_scatter = _stack_samples("int_scatter")

In [None]:
# int_scatter is only extracted when "int" is in inference_name; for any
# other model the cell shows nothing instead of raising NameError.
np.mean(int_scatter, 1) if "int" in inference_name else None

### HDI
GC parameters: per-GRB posterior mean and 95.4% HDI of `gamma`, plus the fraction of HDIs that contain 1.5.

In [None]:
# Posterior mean of gamma for each row of `gamma`.
print(gamma.mean(axis=1))

# 95.4% HDI per row; the transposed array puts the samples on axis 0.
gamma_hdi = av.hdi(gamma.T, 0.954)
print(gamma_hdi)

# Fraction of rows whose HDI strictly contains 1.5.
n_containing = sum(lower < 1.5 and upper > 1.5 for lower, upper in gamma_hdi)
print(n_containing / gamma.shape[0])

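Hyperpriors: posterior mean and 95.4% HDI of the shared `gamma` and `Nrest` parameters (the `*_mu_meta` variables unless the model is global; nothing is printed for the relaxed model).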

In [None]:
# Choose which posterior variables to summarise. Each entry is
# (printed label, posterior variable name, whether to print 10**value).
if "global" in model:
    reported = (
        ("gamma", "gamma", False),
        ("\nNrest", "log_Nrest", True),
    )
elif "relaxed" not in model:
    # The second label is "Nrest", not "log_Nrest": the printed numbers are
    # 10**(...), matching the label used for the global model.
    reported = (
        ("gamma", "gamma_mu_meta", False),
        ("\nNrest", "log_Nrest_mu_meta", True),
    )
else:
    # Relaxed models: nothing is printed.
    reported = ()

for label, var_name, is_log10 in reported:
    posterior_var = res.posterior[var_name]
    print(label)

    # Mean over all draws of all chains.
    centre = np.mean(posterior_var.stack(sample=("draw", "chain")).values)
    print(10**centre if is_log10 else centre)

    # 95.4% highest-density interval.
    interval = av.hdi(posterior_var, 0.954)[var_name].values
    print(10**interval if is_log10 else interval)

### Plots

In [None]:
# Select the interactive "widget" matplotlib backend, then draw the ArviZ
# trace plot of the loaded inference result.
%matplotlib widget
av.plot_trace(res)

In [None]:
%matplotlib widget

fig, ax = plt.subplots(1, 1)
for i in gamma:
    av.plot_kde(i, ax=ax)
# ax.set_xlim(1.377,1.545)
# ax.set_ylim(-.5,35)
ax.set_xlabel(r"$\gamma$")

In [None]:
# KDE of the gamma_mu_meta samples. That array only exists for models that
# are neither "relaxed" nor "global".
fig, ax = plt.subplots()
av.plot_kde(gamma_mu_meta, ax=ax)
# ax.set_xlim(1.377,1.545)
# ax.set_ylim(0,35)
ax.set_xlabel(r"$\mu_{\gamma}$")

In [None]:
# Show the indices at which data["grb_id"] equals 4.
# NOTE(review): presumably used to look up the hard-coded row range in the
# next plotting cell -- confirm.
np.argwhere(data["grb_id"] == 4)

In [None]:
%matplotlib widget

fig, ax = plt.subplots(1, 1)
av.plot_kde(log_epeak[22:31], log_energy_flux[22:31], ax=ax)
ax.plot([0, 2.4], 52 - (1.099 + 2 * np.log10(dl[23])) + 1.5 * (np.log10(1 + z[23]) + np.array([0, 2.4]) - 2), "b--")
ax.set_xlim(1,2.5)
ax.set_ylim(-7,-4.2)
# ax.set_xlabel(r"$\log E_\mathrm{peak}$")
# ax.set_ylabel(r"log F")

### PPC

In [None]:
# Tally, channel by channel, whether the observed counts fall inside the
# two PPC bands (presumably 1-sigma and 2-sigma, judging by the names).
# NOTE(review): grb_*, interval_*, detector and ppc_* are not defined in
# the visible part of the notebook.
# The tallies are NOT reset between the two passes, so the second pair of
# printed fractions covers both configurations together.
in_1_sigma = []
in_2_sigma = []
for grb, interval, bands_1s, bands_2s in (
    (grb_1, interval_1, ppc_1s_1, ppc_2s_1),
    (grb_2, interval_2, ppc_1s_2, ppc_2s_2),
):
    # Row of the requested interval within the GRB (grb_id is compared
    # against grb + 1).
    row = np.where(data["grb_id"] == grb + 1)[0][interval]
    observed = data["observed_counts"][row, detector]

    for count, band_1s, band_2s in zip(observed, bands_1s.T, bands_2s.T):
        in_1_sigma.append(band_1s[0] <= count <= band_1s[1])
        in_2_sigma.append(band_2s[0] <= count <= band_2s[1])

    print(sum(in_1_sigma) / len(in_1_sigma))
    print(sum(in_2_sigma) / len(in_2_sigma))

### Inference Stats

In [None]:
# R-hat diagnostics of the inference result, reduced with max().
rhat_values = av.rhat(res)
rhat_values.max()

In [None]:
# Sum of the sampler's "diverging" flags over the whole run.
diverging_flags = res.sample_stats["diverging"]
diverging_flags.sum()

### Correlation

In [None]:
# Posterior mean of log_Nrest_mu_meta and the offsets of its 95.4% HDI
# bounds from that mean.
amean = np.mean(log_Nrest_mu_meta)
a1, a2 = av.hdi(log_Nrest_mu_meta, hdi_prob=0.954)
print(amean, a1 - amean, a2 - amean)

# Same summary for gamma_mu_meta.
bmean = np.mean(gamma_mu_meta)
b1, b2 = av.hdi(gamma_mu_meta, hdi_prob=0.954)
print(bmean, b1 - bmean, b2 - bmean)

### Reverse Correlation

In [None]:
# Samples of the two transformed quantities, computed once each.
intercept_samples = log_Nrest_mu_meta / gamma_mu_meta - 2
slope_samples = 1 / gamma_mu_meta

# Mean and HDI-bound offsets from the mean for each transformed quantity.
# NOTE(review): hdi_prob is 0.685 here while the other cells use 0.954;
# confirm this is intended (1 sigma is usually quoted as about 0.683).
amean = np.mean(intercept_samples)
a1, a2 = av.hdi(intercept_samples, hdi_prob=0.685)
print(amean, a1 - amean, a2 - amean)

bmean = np.mean(slope_samples)
b1, b2 = av.hdi(slope_samples, hdi_prob=0.685)
print(bmean, b1 - bmean, b2 - bmean)