## Imports and Setup

In [None]:
import os
import sys

import arviz as av
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from threeML import update_logging_level
# Set threeML's logging level to FATAL so that less severe log output does not
# clutter the notebook.
update_logging_level("FATAL")

# Make the parent directory importable so that the local `zusammen` package
# (which provides DataSet) can be found. The check avoids adding it twice when
# the cell is re-run.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from zusammen import DataSet

# Seaborn theme with "paper" context scaling.
sns.set_theme(context="paper")

# Render figures through the PGF backend with LaTeX text rendering, so figure
# text is typeset like the surrounding LaTeX document. Font sizes are in points.
mpl.use("pgf")
pgf_with_latex = {
    "text.usetex": True,
    "font.family": "serif",
    "axes.labelsize": 10,
    "font.size": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    # Uncomment to use LuaLaTeX instead of matplotlib's default TeX system.
    # "pgf.texsystem": "lualatex",
    # siunitx typesets the units in the axis labels (\si{...}); the \erg unit
    # used in those labels is declared here.
    "pgf.preamble": "\n".join([
        r"\usepackage{siunitx}",
        r"\DeclareSIUnit{\erg}{erg}"
        ])
    }
mpl.rcParams.update(pgf_with_latex)

# Text width of the target document in points; passed to set_size() as width_pt.
width = 455

In [None]:
# Where the simulated survey and the inference outputs live (relative paths,
# each folder with a trailing separator).
simulation_folder, survey_name = "simulation/", "corr_cpl_survey"
inference_folder, data_name = "inference/", "data"

# The inference result file is named after the model that was fitted.
model_name = "cpl_simple_chunked_gc_relaxed"
inference_name = model_name

In [None]:
# Load the observed data set (HDF5) and the saved inference result (netCDF,
# read back as an ArviZ InferenceData object). os.path.join keeps this working
# even if a folder name is given without a trailing separator.
ds = DataSet.from_hdf5_file(os.path.join(inference_folder, data_name + ".h5"))
data = ds.to_stan_dict()
res = av.from_netcdf(os.path.join(inference_folder, inference_name + ".nc"))

#### Figure size helpers

In [None]:
def pt_to_inch(pt):
    """Convert a length from TeX points to inches (1 in = 72.27 pt)."""
    return pt / 72.27

def set_size(width_pt, fraction=1, subplots=(1, 1), height=None):
    """Set figure dimensions to sit nicely in our document.

    Parameters
    ----------
    width_pt: float
            Document width in points
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy
    subplots: array-like, optional
            The number of rows and columns of subplots. Only used to scale the
            golden-ratio height when ``height`` is not given.
    height: float, optional
            Explicit figure height in points. If None (default), the height is
            derived from the figure width via the golden ratio.

    Returns
    -------
    fig_dim: tuple
            Dimensions of figure in inches
    """
    # Width of figure in inches
    fig_width_in = pt_to_inch(width_pt * fraction)

    # Compare against None so an explicit height of 0 is honoured rather than
    # being mistaken for "not given".
    if height is None:
        # Golden ratio to set aesthetic figure height, scaled by rows/columns
        golden_ratio = (5**.5 - 1) / 2
        fig_height_in = fig_width_in * golden_ratio * (subplots[0] / subplots[1])
    else:
        fig_height_in = pt_to_inch(height)

    return (fig_width_in, fig_height_in)

## Band function

In [None]:
def band(E, Ec, piv, K, alpha, beta):
    """Evaluate the Band function at energies ``E``.

    Below the break energy ``(alpha - beta) * Ec`` the spectrum is a power law
    with index ``alpha`` and an exponential cutoff at ``Ec``. At and above the
    break it is a power law with index ``beta``, normalised so that the two
    pieces join continuously at the break.

    Parameters
    ----------
    E: array-like
            Energies (same units as ``Ec`` and ``piv``); any shape.
    Ec: float
            Cutoff energy of the low-energy segment.
    piv: float
            Pivot energy used to normalise the power laws.
    K: float
            Normalisation.
    alpha: float
            Photon index below the break.
    beta: float
            Photon index above the break.

    Returns
    -------
    ndarray
            Float array with the same shape as ``E``.
    """
    E = np.asarray(E, dtype=float)
    # Energy at which the spectrum switches from the cutoff power law to the
    # pure power law.
    e_break = (alpha - beta) * Ec

    ret = np.empty_like(E)
    # Evaluate each branch only where it applies. NaN energies fail the `<`
    # comparison and therefore land in the high-energy branch.
    low = E < e_break
    high = ~low
    ret[low] = K * (E[low] / piv)**alpha * np.exp(-E[low] / Ec)
    ret[high] = (K * (e_break / piv)**(alpha - beta) * np.exp(beta - alpha)
                 * (E[high] / piv)**beta)
    return ret

In [None]:
# Illustrative Band-function parameters for the figure below; energies are in
# MeV, matching that figure's x-axis label.
K, piv = 10, .1          # normalisation and pivot energy
alpha, beta = -1, -2.2   # photon indices below / above the break
Ec = 1                   # cutoff energy of the low-energy segment
x1, x2 = 1e-2, 1e2       # energy range shown on the x-axis

In [None]:
# Band function (solid) versus the cutoff power law that forms its low-energy
# segment (dashed): a) photon spectrum B, b) E^2-weighted spectrum in erg.
MEV_TO_ERG = 1.602e-6  # 1 MeV in erg; turns E^2 B from MeV into erg units

fig, (ax1, ax2) = plt.subplots(
    2, 1, sharex=True,
    figsize=set_size(width, height=270, subplots=(2, 1)),
    height_ratios=(3, 2),
)

# np.logspace takes base-10 exponents; geomspace samples [x1, x2] directly, so
# all points fall inside the plotted x-range.
E = np.geomspace(x1, x2, 100)
cutoff_pl = K * (E / piv)**alpha * np.exp(-E / Ec)
band_spec = band(E, Ec, piv, K, alpha, beta)

ax1.plot(E, cutoff_pl, "b--")
ax1.plot(E, band_spec, "b")

ax1.loglog()
ax1.set_xlim(x1, x2)
ax1.set_ylim(1.5e-6, 1e3)

ax1.text(.98, .88, 'a)', horizontalalignment='right', transform=ax1.transAxes,
         bbox={"boxstyle": "square", "facecolor": ax1.get_facecolor(), "edgecolor": "gray"})
ax1.set_ylabel(r"$B$ [$\si{\per\second\per\centi\meter\squared\per\mega\electronvolt}$]")


ax2.plot(E, MEV_TO_ERG * E**2 * cutoff_pl, "b--")
ax2.plot(E, MEV_TO_ERG * E**2 * band_spec, "b")

ax2.loglog()
ax2.set_ylim(8e-9, 2e-6)
ax2.text(.98, .82, 'b)', horizontalalignment='right', transform=ax2.transAxes,
         bbox={"boxstyle": "square", "facecolor": ax2.get_facecolor(), "edgecolor": "gray"})
ax2.set_xlabel(r"Photon Energy $E_\gamma$ $[\si{\mega\electronvolt}]$")
ax2.set_ylabel(r"$E_\gamma^2 B$ [$\si{\erg\per\second\per\centi\meter\squared}$]")


fig.align_ylabels([ax1, ax2])
fig.tight_layout()
# Uncomment to export the figure for inclusion in the LaTeX document.
# fig.savefig("band.pgf")

## Posterior predictive check (PPC)

In [None]:
from posterior_predictive_check import PPC

# Posterior predictive check helper, built from the Stan data dict and the
# inference result loaded earlier in the notebook.
p = PPC(data, res)

# What to check: chain index and number of posterior draws, plus the
# GRB / interval / detector indices passed to PPC.ppc.
# NOTE(review): the plot indexes observed_counts by [interval, detector] only,
# without grb — confirm that interval is a global index.
chain, draws = 0, 100
grb, interval, detector = 0, 1, 1

In [None]:
# Draw posterior predictive counts for the selection above, then reduce them
# to the summaries that the plot cell draws.
cenergies, ppc_counts = p.ppc(chain, draws, grb, interval, detector)
(ppc_mu,
 ppc_1s,
 ppc_2s) = PPC.ppc_summary(ppc_counts)

In [None]:
%matplotlib widget
fig, ax = plt.subplots(1, 1, figsize=set_size(width, height=270))

plt.stairs(data["observed_counts"][interval,detector], cenergies, color="#dd8452")
plt.stairs(ppc_1s[0], cenergies, baseline=ppc_1s[1], fill=True, color="#55a868", alpha=.5)
plt.stairs(ppc_2s[0], cenergies, baseline=ppc_2s[1], fill=True, color="#4c72b0", alpha=.2)

# plt.ylim(.8,1e2)
plt.xlabel("Energy")
plt.ylabel("Count Rate")
plt.loglog()