In [None]:
import os
import sys

import arviz as av
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np

from astromodels import Cutoff_powerlaw, Model, PointSource
from threeML import (
    DataList,
    JointLikelihood,
    update_logging_level
)
# Restrict threeML logging to FATAL messages.
update_logging_level("FATAL")

# Make the local `zusammen` package in the parent directory importable.
# NOTE(review): append() places this directory *after* site-packages, so an
# installed `zusammen` would shadow the local checkout — confirm intended.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from zusammen import DataSet
from zusammen.spectral_plot import display_posterior_model_counts

# Make ../examples importable for the `Cutoff_powerlaw_prime` spectral shape.
example_dir = os.path.abspath("../examples")
if example_dir not in sys.path:
    sys.path.append(example_dir)
from cpl_prime import Cutoff_powerlaw_prime

### Load data

In [None]:
# Run configuration: folder holding both inputs, the data-set file stem
# (<folder><data_name>.h5) and the posterior file stem (<folder><inference_name>.nc).
# Later cells switch between a shared and a per-GRB gamma/log_Nrest depending
# on whether `inference_name` starts with "global".
inference_folder, data_name, inference_name = (
    "inference/",
    "data_2_sig_5",
    "relaxed_2_sig_5_1000",
)

In [None]:
# Load the data set (HDF5) and the matching posterior (NetCDF).
# os.path.join keeps this working whether or not `inference_folder` ends in "/".
ds = DataSet.from_hdf5_file(os.path.join(inference_folder, data_name + ".h5"))
# Stan-style input dict; later cells read N_intervals, N_grbs, z and dl from it.
data = ds.to_stan_dict()
res = av.from_netcdf(os.path.join(inference_folder, inference_name + ".nc"))

### Quick ArviZ sampler diagnostics

In [None]:
res.sample_stats["tree_depth"].max()  # deepest NUTS tree reached over all chains/draws

In [None]:
%matplotlib widget
av.plot_trace(res)

In [None]:
# Flag per draw (chains concatenated); the sum is the total number of divergent transitions.
div = res.sample_stats["diverging"].stack(sample=("chain", "draw")).values
div.sum()

In [None]:
av.rhat(data=res)  # R-hat per posterior variable (ArviZ default method)

### Extract posterior draws into NumPy arrays

In [None]:
# Flatten the MCMC (chain, draw) dimensions into one trailing "sample" axis and
# copy the draws into plain NumPy arrays: one row per interval (alpha, log_ec,
# K_prime, ...) or per GRB (gamma, log_Nrest in non-"global" runs).
N_intervals = data["N_intervals"]
N_grbs = data["N_grbs"]
length = res.posterior.gamma.shape[0] * res.posterior.gamma.shape[1]  # n_chains * n_draws

# Runs whose name starts with "global" presumably fit one shared gamma/log_Nrest;
# otherwise there is one row per GRB.
is_global = inference_name.startswith("global")


def _stacked_draws(var):
    """Return ``var``'s values with (chain, draw) stacked into a trailing sample axis."""
    return var.stack(sample=("chain", "draw")).values


def _first_rows(var, n_rows):
    """Return the first ``n_rows`` rows of ``var``'s stacked draws as a float (n_rows, length) array.

    Raises:
        IndexError: if the posterior variable has fewer than ``n_rows`` rows.
    """
    draws = _stacked_draws(var)[:n_rows]
    if draws.shape[0] != n_rows:
        raise IndexError(f"expected {n_rows} rows, posterior has {draws.shape[0]}")
    rows = np.zeros((n_rows, length))
    rows[:] = draws
    return rows


# Each variable is stacked exactly once (re-stacking the whole posterior inside a
# per-interval loop made this cell quadratic in the number of intervals).
alpha = _first_rows(res.posterior.alpha, N_intervals)
log_ec = _first_rows(res.posterior.log_ec, N_intervals)
K_prime = _first_rows(res.posterior.K, N_intervals)
K = (10**log_ec) ** (-alpha)
log_epeak = _first_rows(res.posterior.log_epeak, N_intervals)
log_energy_flux = _first_rows(res.posterior.log_energy_flux, N_intervals)

# Same divergence flags repeated on every interval row (stored as 0.0 / 1.0).
div = np.zeros((N_intervals, length))
div[:] = _stacked_draws(res.sample_stats.diverging)

# (N_intervals, 3, length): rows per interval are (K_prime, alpha, 10**log_ec).
samples = np.stack((K_prime, alpha, 10.0**log_ec), axis=1)
dl = []  # placeholder; not used in this cell

if is_global:
    gamma = _stacked_draws(res.posterior.gamma)
    log_Nrest = _stacked_draws(res.posterior.log_Nrest)
else:
    gamma = _first_rows(res.posterior.gamma, N_grbs)
    log_Nrest = _first_rows(res.posterior.log_Nrest, N_grbs)

In [None]:
# Posterior mean of gamma plus its highest-density intervals at the Gaussian
# 1-sigma (68.3 %) and 2-sigma (95.4 %) probabilities.
hdi_probs = (0.683, 0.954)
if inference_name.startswith("global"):
    # Single shared gamma: a 1-D array of draws.
    print(gamma.mean())
    for prob in hdi_probs:
        print(av.hdi(gamma, hdi_prob=prob))
else:
    # One row of draws per GRB. The HDI is computed row by row on 1-D draws:
    # handing av.hdi a 2-D array relies on its deprecated "(draw, shape)"
    # interpretation, which emits a FutureWarning and is slated to change.
    print(np.mean(gamma, 1))
    for prob in hdi_probs:
        print(np.array([av.hdi(gamma_grb, hdi_prob=prob) for gamma_grb in gamma]))

### Show posterior model counts

In [None]:
%matplotlib widget

bc = Cutoff_powerlaw_prime()

bc.index.bounds = (None, None)
bc.K.bounds = (None, None)
bc.xc.bounds = (None, None)

model = Model(PointSource("ps",0,0, spectral_shape=bc))

# for id in range(2):#range(data["N_intervals"]):
id = 1
display_posterior_model_counts(
    ds.get_data_list_of_interval(id)[1], model, samples[id].T[::20], min_rate=1e-99
)

### Maximum-likelihood fit of each interval with 3ML

In [None]:
# Independent maximum-likelihood cutoff-power-law fit of every interval, for
# comparison with the hierarchical posterior: stores E_peak and energy flux.
F_3ml, epeak_3ml = np.zeros(data["N_intervals"]), np.zeros(data["N_intervals"])

cpl = Cutoff_powerlaw(piv=100, K=1e-1, xc=200)
model = Model(PointSource("ps", 0, 0, spectral_shape=cpl))

# NOTE(review): the same `model` is reused for every interval, so each fit starts
# from the previous interval's best-fit values; results may depend on order.
for i in range(data["N_intervals"]):
    dl = ds.get_data_list_of_interval(i)
    ba = JointLikelihood(model, DataList(*dl))
    ba.fit()

    best_fit = ba.results.get_data_frame()["value"]
    ec_3ml = best_fit["ps.spectrum.main.Cutoff_powerlaw.xc"]
    alpha_3ml = best_fit["ps.spectrum.main.Cutoff_powerlaw.index"]
    # nuFnu of E**index * exp(-E/xc) peaks at E = (2 + index) * xc.
    epeak_3ml[i] = (2 + alpha_3ml) * ec_3ml

    # .iloc: positional lookup on the label-indexed "flux" column (plain [0]
    # positional fallback is deprecated in pandas).
    # NOTE(review): 10e4 keV == 1e5 keV (100 MeV); confirm this upper bound is
    # intended rather than 1e4 keV.
    F_3ml[i] = ba.results.get_flux(10 * u.keV, 10e4 * u.keV)["flux"].iloc[0].value

### Show the GC: energy flux vs. peak energy

In [None]:
%matplotlib widget
def gc_log(log_epeak, log_Nrest, gamma, z, dl):
    return log_Nrest - (1.099 + 2 * np.log10(dl)) + gamma * (np.log10(1 + z) + log_epeak - 2)

# Posterior-mean (log10 E_peak, log10 energy flux) of every interval with one GC
# curve per GRB overlaid. An explicit figure is created: with the persistent
# `widget` backend, plt.scatter would otherwise draw onto whatever figure an
# earlier cell left current.
fig, ax = plt.subplots()
ax.scatter(np.mean(log_epeak, 1), np.mean(log_energy_flux, 1))
log_epeak_sort = np.linspace(0.5, 3)


def _first_of_each_run(values):
    """Return the first element of every run of consecutive equal values."""
    return [values[0]] + [cur for prev, cur in zip(values, values[1:]) if prev != cur]


# data["z"] / data["dl"] presumably hold one entry per interval; collapsing runs
# of equal neighbours recovers one entry per GRB. This assumes each GRB's
# intervals are contiguous and neighbouring GRBs never share a z (or dl) exactly,
# so check the result instead of silently plotting misaligned curves.
z = _first_of_each_run(data["z"])
d_l = _first_of_each_run(data["dl"])
assert len(z) == len(d_l) == data["N_grbs"], (
    f"recovered {len(z)} redshifts and {len(d_l)} distances for {data['N_grbs']} GRBs"
)

# Fixed GC parameters for the overlay. NOTE(review): these are not the posterior
# `log_Nrest` / `gamma`; presumably the values used to simulate the data — confirm.
gc_log_Nrest_ref = 52
gc_gamma_ref = 1.5
for grb in range(data["N_grbs"]):
    ax.plot(
        log_epeak_sort,
        gc_log(log_epeak_sort, gc_log_Nrest_ref, gc_gamma_ref, z[grb], d_l[grb]),
    )

ax.set_xlabel("log10 E_peak")
ax.set_ylabel("log10 energy flux");