## Imports and Setup

In [None]:
import os
import sys

import arviz as av
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import seaborn as sns

# Make modules one directory up importable from this notebook (added only once).
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Seaborn's default theme is applied first; the rcParams update below is
# layered on top of it and wins wherever the two overlap.
sns.set_theme()

# Switch matplotlib to the PGF backend (figures are saved as .pgf further down).
mpl.use("pgf")

# LaTeX typesetting with a serif font and small tick/legend labels.  The
# preamble loads siunitx and declares an \erg unit used in the axis labels.
pgf_with_latex = {
    "text.usetex": True,
    "font.family": "serif",
    "axes.labelsize": 10,
    "font.size": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    # "pgf.texsystem": "lualatex",
    "pgf.preamble": "\n".join(
        (
            r"\usepackage{siunitx}",
            r"\DeclareSIUnit{\erg}{erg}",
        )
    ),
}
mpl.rcParams.update(pgf_with_latex)

# Document width in points, passed as width_pt to set_size().
width = 455

#### Set size

In [None]:
def pt_to_inch(pt):
    """Convert a length from TeX points to inches (72.27 pt = 1 in)."""
    points_per_inch = 72.27
    return pt / points_per_inch

def set_size(width_pt, fraction=1, subplots=(1, 1), height=None):
    """Set figure dimensions to sit nicely in our document.

    Parameters
    ----------
    width_pt: float
            Document width in points (TeX points, 72.27 per inch).
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy.
    subplots: array-like, optional
            ``(rows, columns)`` of the subplot grid. Only used when
            ``height`` is not given: the golden-ratio height is then
            scaled by ``rows / columns``.
    height: float, optional
            Explicit figure height in points. When given it replaces the
            golden-ratio height and ``subplots`` is ignored.

    Returns
    -------
    fig_dim: tuple
            ``(width, height)`` of the figure in inches, suitable for
            ``plt.subplots(figsize=...)``.
    """
    # Width of figure (in pts), then in inches
    fig_width_pt = width_pt * fraction
    fig_width_in = pt_to_inch(fig_width_pt)

    if height is None:
        # Golden ratio (~0.618) gives an aesthetically pleasing aspect;
        # grids with more rows than columns get proportionally taller.
        golden_ratio = (5**.5 - 1) / 2
        fig_height_in = fig_width_in * golden_ratio * (subplots[0] / subplots[1])
    else:
        fig_height_in = pt_to_inch(height)

    return (fig_width_in, fig_height_in)

## Band function

In [None]:
def band(E, Ec, piv, K, alpha, beta):
    """Evaluate the Band function at energies ``E``.

    Below the break energy ``(alpha - beta) * Ec`` this is a cutoff power
    law ``K * (E/piv)**alpha * exp(-E/Ec)``. Above it, it is a pure power
    law with index ``beta``, normalised so that both pieces agree at the
    break.

    Parameters
    ----------
    E: array-like or float
            Energies at which to evaluate (same units as ``Ec`` and ``piv``).
    Ec: float
            Cutoff energy of the low-energy branch.
    piv: float
            Pivot energy used to normalise the power laws.
    K: float
            Normalisation.
    alpha, beta: float
            Low- and high-energy indices.

    Returns
    -------
    numpy.ndarray
            Float array with the same shape as ``E``.
    """
    E = np.asarray(E, dtype=float)
    e_break = (alpha - beta) * Ec

    # Evaluate both branches on the whole grid and select per element;
    # this replaces the element-by-element Python loop.
    low = K * (E / piv)**alpha * np.exp(-E / Ec)
    high = K * (e_break / piv)**(alpha - beta) * np.exp(beta - alpha) * (E / piv)**beta
    return np.where(E < e_break, low, high)

In [None]:
# Band-function parameters for the illustration below.
K = 10        # normalisation
alpha = -1    # low-energy index
beta = -2.2   # high-energy index
Ec = 1        # cutoff energy [MeV]
piv = 0.1     # pivot energy [MeV]

# Energy range shown on the x axis [MeV].
x1, x2 = 1e-2, 1e2

In [None]:
fig, (ax1, ax2) = plt.subplots(
    2, 1, sharex=True,
    figsize=set_size(width, height=280, subplots=(2, 1)),
    height_ratios=(3, 2),
)

# Log-spaced energy grid covering exactly [x1, x2].  np.geomspace takes the
# endpoints directly; np.logspace would need base-10 exponents (np.log10).
E = np.geomspace(x1, x2, 100)

# 1 MeV = 1.602e-6 erg: turns E^2 B (MeV s^-1 cm^-2) into erg s^-1 cm^-2.
mev_to_erg = 1.602e-6

# Dashed: the cutoff power law (low-energy branch continued past the break);
# solid: the full Band function.
cutoff_pl = K * (E / piv)**alpha * np.exp(-E / Ec)
band_E = band(E, Ec, piv, K, alpha, beta)

# Panel a): photon spectrum B(E).
ax1.plot(E, cutoff_pl, "b--")
ax1.plot(E, band_E, "b")

ax1.loglog()
ax1.set_xlim(x1, x2)
ax1.set_ylim(1.5e-6, 1e3)

ax1.text(.98, .88, 'a)', horizontalalignment='right', transform=ax1.transAxes,
         bbox={"boxstyle": "square", "facecolor": ax1.get_facecolor(), "edgecolor": "gray"})
ax1.set_ylabel(r"$B$ [$\si{\per\second\per\centi\meter\squared\per\mega\electronvolt}$]")

# Panel b): energy spectrum E^2 B(E) in erg units.
ax2.plot(E, mev_to_erg * E**2 * cutoff_pl, "b--")
ax2.plot(E, mev_to_erg * E**2 * band_E, "b")

ax2.loglog()
ax2.set_ylim(8e-9, 2e-6)
ax2.text(.98, .82, 'b)', horizontalalignment='right', transform=ax2.transAxes,
         bbox={"boxstyle": "square", "facecolor": ax2.get_facecolor(), "edgecolor": "gray"})
ax2.set_xlabel(r"Photon Energy $E_\gamma$ $[\si{\mega\electronvolt}]$")
ax2.set_ylabel(r"$E_\gamma^2 B$ [$\si{\erg\per\second\per\centi\meter\squared}$]")


fig.align_ylabels([ax1, ax2])
fig.tight_layout()
fig.savefig("band.pgf")

## ArviZ

In [None]:
# Load the saved ArviZ InferenceData (its posterior and sample_stats groups
# are read below).
netcdf_path = "inference_data/testing_gc_2.nc"
res = av.from_netcdf(netcdf_path)

In [None]:
def _flatten_draws(var):
    """Return ``var`` as a NumPy array with ``chain`` and ``draw`` merged into
    a single ``sample`` axis; xarray appends the stacked dimension last."""
    return var.stack(sample=("chain", "draw")).values


# Sizes read from the posterior.  This assumes per-interval / per-GRB variables
# have dims (chain, draw, <index>), as the shape[2] lookups require.
N_intervals = res.posterior.alpha.shape[2]
N_grbs = res.posterior.gamma.shape[2]
length = res.posterior.gamma.shape[0] * res.posterior.gamma.shape[1]

# Per-interval quantities: one row per interval, one column per posterior sample.
alpha = np.zeros((N_intervals, length))
log_ec = np.zeros((N_intervals, length))
K_prime = np.zeros((N_intervals, length))  # NOTE(review): never filled in this cell
K = np.zeros((N_intervals, length))
log_energy_flux = np.zeros((N_intervals, length))
log_epeak = np.zeros((N_intervals, length))
div = np.zeros((N_intervals, length))

# Per-GRB quantities: one row per GRB.
gamma = np.zeros((N_grbs, length))
log_Nrest = np.zeros((N_grbs, length))

# NOTE(review): never populated here (the data-list loading it was meant for
# is not active).
dl = []

# Stack each variable once and copy it into the preallocated float64 arrays.
# This avoids re-stacking inside a loop and does not bind a loop variable
# named `id`, which would shadow the builtin.
alpha[:] = _flatten_draws(res.posterior.alpha)
log_ec[:] = _flatten_draws(res.posterior.log_ec)
K[:] = _flatten_draws(res.posterior.K)
log_epeak[:] = _flatten_draws(res.posterior.log_epeak)
log_energy_flux[:] = _flatten_draws(res.posterior.log_energy_flux)

# Divergence flags are not indexed per interval: the same row is broadcast
# to every interval (booleans become 0.0 / 1.0).
div[:] = _flatten_draws(res.sample_stats.diverging)

# Shape (N_intervals, 3, length); the middle axis holds K, alpha and Ec = 10**log_ec.
samples = np.stack((K, alpha, 10.**log_ec), axis=1)

gamma[:] = _flatten_draws(res.posterior.gamma)
log_Nrest[:] = _flatten_draws(res.posterior.log_Nrest)

gamma_mu_meta = _flatten_draws(res.posterior.gamma_mu_meta)
log_Nrest_mu_meta = _flatten_draws(res.posterior.log_Nrest_mu_meta)

In [None]:
# KDE of the gamma_mu_meta posterior samples, labelled \mu_gamma.
fig, ax = plt.subplots(figsize=set_size(width))
av.plot_kde(gamma_mu_meta, ax=ax)
ax.set(xlim=(1.377, 1.545), ylim=(-.5, 35), xlabel=r"$\mu_{\gamma}$")

In [None]:
# KDE of the log_Nrest_mu_meta posterior samples, labelled \mu_{log N_rest}.
# (Previously this cell plotted gamma_mu_meta under this label.)
fig, ax = plt.subplots(1, 1, figsize=set_size(width))
av.plot_kde(log_Nrest_mu_meta, ax=ax)
# The fixed x-range (1.377, 1.545) and y-top of 35 were copied from the
# mu_gamma plot; let matplotlib autoscale until limits are tuned for this
# variable, keeping only the density floor at 0.
ax.set_ylim(bottom=0)
ax.set_xlabel(r"$\mu_{\log N_\mathrm{rest}}$")