## Imports and Setup

In [None]:
import os
import sys

import arviz as av
from cosmogrb.universe.survey import Survey
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from plots import *
from threeML import update_logging_level

# Only let FATAL-level threeML log messages through to the notebook output.
update_logging_level("FATAL")

# Make the parent directory importable before importing `zusammen`
# (presumably the package lives one level up -- confirm).
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from zusammen import DataSet

# Seaborn's theme is applied first so the rcParams update below wins.
sns.set_theme(context="paper")

# Render figures through the pgf backend with LuaLaTeX.
mpl.use("pgf")
pgf_with_latex = {
    "text.usetex": True,
    "font.family": "serif",
    "axes.labelsize": 10,
    "font.size": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "pgf.texsystem": "lualatex",
    # A real boolean: the string "False" is truthy in plain Python and only
    # worked because matplotlib's rc validator parses it.
    "pgf.rcfonts": False,
    "pgf.preamble": "\n".join([r"\usepackage{siunitx}", r"\DeclareSIUnit{\erg}{erg}"]),
}
mpl.rcParams.update(pgf_with_latex)

# Figure size in inches: 455 TeX points (72.27 pt per inch) wide, with the
# height set by the golden ratio.
width = 455. / 72.27
height = width / 1.61803398875

In [None]:
# ---------------------------------------------------------------------------
# Run configuration: exactly one data set and one inference result are active.
# ---------------------------------------------------------------------------

# Simulated data (disabled). Uncomment this group, and one inference_name,
# to switch the notebook to the simulation.
# data_folder = "simulation/"
# data_name = "data_2_sig_5"
# survey_name = "survey_2"
#
# inference_folder = "inference/"
# inference_name = "simulated_2_sig_5_1000"
# inference_name = "simulated_int_sig_5_1000"
# inference_name = "simulated_relaxed_2_sig_5_1000"
# inference_name = "simulated_global_2_sig_5_1000"

# Real data (active).
data_folder, data_name = "real_data/", "data"
inference_folder, inference_name = "inference/", "real_sig_5_1000"

# Alternative inference results for the real data:
# inference_name = "real_int_sig_5_1000"
# inference_name = "real_relaxed_3_sig_5_1000"
# inference_name = "real_global_sig_5_1000"

In [None]:
# Output locations for figures and tables (machine-specific absolute paths).
plots_folder = "/Users/chrobin/LRZ Sync+Share/Uni/Bachelorarbeit/Thesis/figures/"
tables_folder = "/Users/chrobin/LRZ Sync+Share/Uni/Bachelorarbeit/Thesis/tables/"

# Base label: "real" when the data folder says so, otherwise "simulated".
model = "real" if "real" in data_folder else "simulated"
if model == "simulated":
    # The survey file is only loaded for simulated data.
    survey = Survey.from_file(data_folder + survey_name + ".h5")

# Append the first variant tag found in the inference name (checked in the
# order relaxed, global, int); nothing is appended when none matches.
model += next(
    (
        "_" + variant
        for variant in ("relaxed", "global", "int")
        if variant in inference_name
    ),
    "",
)

# Load the data set, its Stan dictionary and the stored inference result.
ds = DataSet.from_hdf5_file(data_folder + data_name + ".h5")
data = ds.to_stan_dict()
res = av.from_netcdf(inference_folder + inference_name + ".nc")

# The interval count in the data must equal the third axis of the alpha posterior.
assert data["N_intervals"] == res.posterior.alpha.shape[2]
model

## Basics

### Band function

In [None]:
# Draw the Band-function figure with the widget backend; saving is disabled.
%matplotlib widget
fig = plot_band(width)
# fig.savefig(plots_folder + "band.pdf")

### Light Curve and Spectrum

In [None]:
# %matplotlib widget
# Example light curve of one GRB in one detector, saved as lightcurve_ex.pdf.
fig = plot_light_curve_basics(
    model=model,
    data_folder=data_folder,
    grb_name="GRB160509374",
    detector="n1",
    width=width,
    height=height,
    siunitx=True,
)
fig.savefig(plots_folder + "lightcurve_ex.pdf")

In [None]:
# Example spectrum of one GRB in one detector, saved as spectrum_ex.pdf.
fig = plot_spectrum_basics(
    model=model,
    data_folder=data_folder,
    grb_name="GRB160509374",
    detector="n1",
    width=width,
    height=height,
    siunitx=True,
)
fig.savefig(plots_folder + "spectrum_ex.pdf")

## Methods

### Weak Light Curve

In [None]:
# %matplotlib widget
# `survey_name` only exists when a simulated configuration is selected, so
# guard on the model label like the other survey-dependent cells instead of
# failing with a NameError under the real-data configuration.
if "simulated" in model:
    fig = plot_weak_light_curve(
        model=model,
        data_folder=data_folder,
        survey_name=survey_name,
        width=width,
        siunitx=True,
    )
    fig.savefig(plots_folder + "lc_insignificant.pdf")
else:
    print("Skipping weak light curve: it needs a simulated survey (survey_name is not set).")

## Results

### Light Curves

In [None]:
# %matplotlib widget

# Choose the GRBs and detectors to show; the simulated case also needs the
# survey and data file names.
if "simulated" in model:
    kwargs = {
        "survey_name": survey_name,
        "data_name": data_name,
        "grb_names": ("SynthGRB_5", "SynthGRB_10"),
        "det_names": ("n1", "n4"),
    }
else:
    kwargs = {
        "grb_names": ("GRB160509374", "GRB120119170"),
        "det_names": ("n1", "n9"),
    }
fig = plot_light_curve(model, data_folder, width, siunitx=True, **kwargs)
# The file name uses only the base label, so every model variant writes the same file.
fig.savefig(plots_folder + "lc_" + ("real" if "real" in model else "simulated") + ".pdf")

### PPC

In [None]:
# Draw the PPC figure for the current model and save it as ppc_<model>.pdf.
fig = plot_ppc(model, data_folder, ds, data, res, width, siunitx=True)
# NOTE(review): this saves pyplot's *current* figure rather than calling
# `fig.savefig` as the other cells do -- confirm that the figure produced by
# plot_ppc is the current one when this line runs.
plt.savefig(plots_folder + "ppc_" + model + ".pdf")

### Corner Plot

In [None]:
# %matplotlib widget
# Corner plot of the inference result, saved as corner_<model>.pdf.
fig = plot_corner(model, res, width)
fig.savefig(plots_folder + "corner_" + model + ".pdf")

### Violin Plot

In [None]:
# %matplotlib widget
# The survey name is passed only for simulated models; it is not evaluated otherwise.
kwargs = {"survey_name": survey_name} if "simulated" in model else {}
fig_gamma, fig_Nrest = plot_violin(model, data_folder, ds, res, width, height, **kwargs)

# One file per returned figure.
fig_gamma.savefig(plots_folder + "violin_gamma_" + model + ".pdf")
fig_Nrest.savefig(plots_folder + "violin_log_Nrest_" + model + ".pdf")

### Trace Plot

In [None]:
# Trace plot shown with the widget backend; saving is disabled.
%matplotlib widget
fig = plot_trace(model, res, width, divergences="bottom")
# fig.savefig(plots_folder + "trace_" + model + ".pdf")

### GC

In [None]:
# %matplotlib widget
# GC figure built from the Stan data and the inference result, saved as gc_<model>.pdf.
fig = plot_gc(model, data, res, width, height)
fig.savefig(plots_folder + "gc_" + model + ".pdf")

### Correlation of Hyperparameters

In [None]:
# %matplotlib widget
# KDE figure from the inference result, saved as meta_<model>.pdf.
fig = plot_gc_kde(model, res, width)
fig.savefig(plots_folder + "meta_" + model + ".pdf")

### Tables

#### GRBs

In [None]:
# Table of the simulated GRBs' input parameters; skipped for real data.
if "simulated" in model:
    grbs = {}
    for i in range(len(survey)):
        synth_name = f"SynthGRB_{i}"
        # Only GRBs that are present in the data set are tabulated.
        if synth_name not in ds._grbs.keys():
            continue
        synth_grb = survey[synth_name].grb
        grbs[f"GRB {i}"] = [
            synth_grb.z,
            synth_grb._source_params["alpha"],
            synth_grb._source_params["ep_start"],
            1.00e52,  # hard-coded value for the "Nrest" column
            1.50,  # hard-coded value for the "gamma" column
        ]
    grb_df = pd.DataFrame.from_dict(
        grbs,
        orient="index",
        columns=["z", "alpha", "ep", "Nrest", "gamma"],
    ).round(2)
    # Nrest is written in exponent notation without decimals.
    grb_df["Nrest"] = grb_df["Nrest"].map(lambda x: '%.0e' % x)
    grb_df.index.name = "grb"
    grb_df.to_csv(tables_folder + "simulated.csv")

#### Inference Results

In [None]:
# NOTE(review): `pd` and `yaml` are not imported explicitly in the visible part
# of this notebook; presumably they arrive through `from plots import *` -- confirm.


def _posterior_summary(posterior, name, hdi_prob=0.954):
    """Return an (n, 3) array of [mean, hdi_low, hdi_high] for ``posterior[name]``.

    The mean is taken over all chains and draws after stacking them into one
    sample axis; the interval is ArviZ's HDI at ``hdi_prob``.
    """
    samples = posterior[name].stack(sample=("chain", "draw")).values
    hdi = av.hdi(posterior[name], hdi_prob=hdi_prob)[name].values
    return np.column_stack([np.mean(samples, axis=1), hdi])


# Row labels: "GRB <i>" for simulated GRBs present in the data set, otherwise
# the names from grb_names.yml with their first three characters removed.
if "simulated" in model:
    grbs = [
        f"GRB {i}"
        for i in range(len(survey))
        if f"SynthGRB_{i}" in ds._grbs.keys()
    ]
else:
    with open(data_folder + "grb_names.yml") as f:
        real_names = yaml.load(f, Loader=yaml.SafeLoader)
    real_names = {k: v[3:] for k, v in real_names.items()}
    grbs = [real_names[i] for i in ds._grbs.keys()]

# Mean and HDI bounds per GRB. log_Nrest is exponentiated, so despite their
# names the "logNrest*" columns hold 10**log_Nrest.
log_Nrest = 10 ** _posterior_summary(res.posterior, "log_Nrest")
gamma = _posterior_summary(res.posterior, "gamma")

columns = ["logNrest", "logNrest-", "logNrest+", "gamma", "gamma-", "gamma+"]
summaries = [log_Nrest, gamma]
if "int" in model:
    # Named `int_scatter` so the builtin `int` is not shadowed for later cells.
    int_scatter = _posterior_summary(res.posterior, "int_scatter")
    columns += ["int", "int-", "int+"]
    summaries.append(int_scatter)

grb_df = pd.DataFrame(np.hstack(summaries), index=grbs, columns=columns)

nrest_columns = ["logNrest", "logNrest-", "logNrest+"]
if "int" in model:
    # Express Nrest in units of 1e50, then round the whole table.
    grb_df[nrest_columns] = grb_df[nrest_columns] / 1e50
    grb_df = grb_df.round(2)
else:
    # Round everything except Nrest, which is written in exponent notation.
    other_columns = [col for col in grb_df.columns if "Nrest" not in col]
    grb_df[other_columns] = grb_df[other_columns].round(2)
    for col in nrest_columns:
        grb_df[col] = grb_df[col].map("{:.2e}".format)

if "real" in model:
    # Real GRBs additionally get their redshift from the data set.
    for k, v in ds._grbs.items():
        grb_df.loc[real_names[k], "z"] = v.z
grb_df.index.name = "grb"
grb_df.to_csv(tables_folder + "res_" + model + ".csv")