In [None]:
import os
import sys

import arviz as av
import numpy as np

from cosmogrb.universe.survey import Survey
from threeML import update_logging_level

# Only show fatal-level messages from threeML in this notebook.
update_logging_level("FATAL")

# Make the directory one level above the notebook importable, so the local
# `zusammen` package can be imported from a source checkout (presumably the
# repository root -- confirm). The path is appended at most once, even when
# this cell is re-run.
parent_dir = os.path.abspath(os.pardir)
if not any(entry == parent_dir for entry in sys.path):
    sys.path.append(parent_dir)

# These imports depend on the sys.path adjustment above.
from zusammen import AnalysisBuilder, DataSet  # noqa: E402
from zusammen.stan_models.stan_model import get_model  # noqa: E402

In [None]:
# Input: the simulated survey is read from <simulation_folder><survey_name>.h5.
# The folder names keep their trailing "/" because later cells build paths by
# plain string concatenation.
simulation_folder, survey_name = "simulation/", "survey"

# Output: the dataset (.yml / .h5) is written under <inference_folder><data_name>.
inference_folder, data_name = "inference/", "data_sig_5"

# Stan model to build, and the file stem used for the saved posterior (.nc).
model_name, inference_name = "cpl_simple_chunked_gc_global", "global_sig_5_1000"

Load the survey and build the dataset

In [None]:
# Build the inference dataset from the simulated survey. It is written both as
# YAML and as HDF5; the next cell reloads the HDF5 file.
survey = Survey.from_file(simulation_folder + survey_name + ".h5")

# The writes below need the output folder to exist. Create it up front so the
# notebook also runs on a fresh checkout (this is a no-op if it already exists).
os.makedirs(inference_folder, exist_ok=True)

# NOTE(review): `data_name` is "data_sig_5" but `sig_min=2` here -- confirm
# which significance cut is intended before trusting the file name.
ab = AnalysisBuilder(
    survey,
    use_bb=True,
    intervals_min=5,
    sig_min=2,
    all_above_limit=False,
    save_directory=inference_folder,
)

yaml_path = inference_folder + data_name + ".yml"
ab.write_yaml(yaml_path)

# Read the YAML back into a DataSet and cache it as HDF5 for the next cell.
ds = DataSet.from_yaml(yaml_path)
ds.to_hdf5_file(inference_folder + data_name + ".h5")

In [None]:
# Reload the dataset from the HDF5 file written by the previous cell. Once that
# file exists, this cell can run without re-processing the survey. Then convert
# the dataset into the dictionary that is passed to Stan as `data=`.
ds = DataSet.from_hdf5_file(f"{inference_folder}{data_name}.h5")
data = ds.to_stan_dict()

Stan inference

In [None]:
# Look up the Stan model wrapper named by `model_name` in the local `zusammen`
# package. Its `.model` attribute is what `.sample()` is called on below.
m = get_model(model_name)

In [None]:
# NOTE(review): presumably deletes previously built model artifacts so the
# next build starts from scratch -- confirm against zusammen's model wrapper.
m.clean_model()

In [None]:
# Build the Stan model before sampling.
# NOTE(review): `opt_exp=True` presumably turns on experimental compiler
# optimisations, and the build must support the `threads_per_chain` used
# below -- confirm both in zusammen.stan_models.
m.build_model(opt_exp=True)

In [None]:
# --- Sampler configuration -------------------------------------------------
n_threads = 2      # threads used within each chain (threads_per_chain)
n_chains = 2       # all chains run in parallel
n_warmup = 2000
n_sampling = 1000
seed = 1234        # fixed RNG seed so Restart & Run All reproduces the posterior

# Optional NUTS tuning. Leaving these as None is the same as not passing them.
# Values tried previously: max_treedepth=12, adapt_delta=0.99, step_size=0.1.
max_treedepth = None
adapt_delta = None
step_size = None

# Starting values. Per-interval parameters get one value per interval.
# Parameters not listed here are initialised randomly by Stan (using `seed`).
# NOTE(review): a single dict gives every chain the same starting point, which
# weakens between-chain convergence checks. Consider passing a list of dicts,
# one per chain.
inits = {
    'alpha': -1 * np.ones(data['N_intervals']),
    'log_ec': 2 * np.ones(data['N_intervals']),

    'gamma_sig_meta': 1,
    'log_Nrest_sig_meta': 1,
    'gamma_mu_meta': 1.5,
    'log_Nrest_mu_meta': 52,
    'gamma': 1.5,
    'log_Nrest': 52,
}

fit = m.model.sample(
    data=data,
    chains=n_chains,
    parallel_chains=n_chains,
    threads_per_chain=n_threads,
    seed=seed,
    inits=inits,  # type: ignore
    iter_warmup=n_warmup,
    iter_sampling=n_sampling,
    max_treedepth=max_treedepth,
    adapt_delta=adapt_delta,
    step_size=step_size,
    show_progress=True,
    refresh=1,
)

In [None]:
# `diagnose()` returns CmdStan's diagnostic report as a single multi-line
# string. Print it so the line breaks render, instead of displaying an escaped
# string repr as the cell output.
print(fit.diagnose())

Save the posterior samples to NetCDF

In [None]:
# Convert the CmdStanPy fit into ArviZ's data container and write it to
# <inference_folder><inference_name>.nc for later analysis. The last
# expression's return value is shown as the cell output.
res = av.from_cmdstanpy(fit)
netcdf_path = f"{inference_folder}{inference_name}.nc"
res.to_netcdf(netcdf_path)