In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# --- Configuration -----------------------------------------------------------
# Dataset family index; also read by later cells when building dataset paths.
data_config = 1
data_dir = f"datasets/type_{data_config}"

# Root directory handed to get_model_fits. Its "slow" and "extremely_slow"
# sub-directories are deliberately not read in this notebook.
results_dir_relu = "results_relu_exhaustive"

# Models to compare ("Dirichlet Student T" is currently left out).
model_names_relu = ["Dirichlet Horseshoe", "Regularized Horseshoe", "Gaussian"]

# Fits keyed by model name. The two "slow" containers are kept as empty
# placeholders; nothing is loaded into them here.
relu_fits = {}
relu_slow_fits = {}
relu_extremely_slow_fits = {}

# A single dataset configuration is inspected, so the config path is the same
# for every model and is built once up front.
base_config_name = "GAM_N100_p8_sigma3.00_seed4"
full_config_path = "/".join((f"type_{data_config}", base_config_name))  # "type_1/GAM_N100_p8_sigma3.00_seed4"

# --- Load one fit per model --------------------------------------------------
for model in model_names_relu:
    fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_relu,
        models=[model],
        include_prior=False,
    )
    relu_fits[model] = fit




In [None]:
import cmdstanpy 
import arviz as az

def _count_divergences(header, posterior_fit):
    """Print the number of divergent transitions of one fit.

    Converts ``posterior_fit`` with ``az.from_cmdstanpy``, reads the
    ``diverging`` sample statistic, prints ``header`` followed by the total
    count, then prints the shape of the flag array (expected to be
    ``(n_chains, n_draws)``).

    Returns the InferenceData object and the flag array.
    """
    inference = az.from_cmdstanpy(posterior_fit)
    flags = inference.sample_stats["diverging"].values
    print(header, np.sum(flags))
    print(flags.shape)
    return inference, flags


# Each entry of relu_fits is indexed by the model name twice before reaching
# the 'posterior' fit object.
gauss_fit = relu_fits['Gaussian']['Gaussian']['posterior']
idata, divergent = _count_divergences("Divergent gaussian transitions:", gauss_fit)

rhs_fit = relu_fits['Regularized Horseshoe']['Regularized Horseshoe']['posterior']
idata, divergent = _count_divergences("Divergent RHS transitions:", rhs_fit)

dhs_fit = relu_fits['Dirichlet Horseshoe']['Dirichlet Horseshoe']['posterior']
idata, divergent = _count_divergences("Divergent DHS transitions:", dhs_fit)



In [None]:
import numpy as np
import arviz as az
import matplotlib.pyplot as plt


# Divergence flags of the Dirichlet Horseshoe fit, flattened to one flag per draw.
idata = az.from_cmdstanpy(dhs_fit)
divergent = idata.sample_stats["diverging"].values.reshape(-1)

# Draws of the test-set output; indexed below as (draw, test point).
output_test = dhs_fit.stan_variable("output_test")

# Keep at most the first four test points so the pair plot stays readable.
n_shown = min(4, output_test.shape[1])
selected_outputs = {}
selected_outputs.update(
    (f"Ey[{col}]", output_test[:, col]) for col in range(n_shown)
)

# Re-package the selected columns together with the divergence flags so that
# ArviZ can mark divergent draws in the scatter plots.
diverging_flags = {"diverging": divergent.astype(bool)}
idata_output = az.from_dict(posterior=selected_outputs, sample_stats=diverging_flags)

# Pairwise scatter plots of the selected outputs with divergences highlighted.
az.plot_pair(
    idata_output,
    var_names=list(selected_outputs),
    kind='scatter',
    divergences=True,
    marginal_kwargs={'fill_last': True},
)
plt.suptitle("Divergences in output_test space", fontsize=14)
plt.show()


In [4]:
import pandas as pd
import numpy as np
import arviz as az

def get_N_sigma(seed):
    """Return the ``(N, sigma)`` pair associated with a dataset seed.

    Seed 19 is a special case mapped to ``(1000, 1.0)``. Otherwise ``N`` is
    100 for seeds 1-4 and 200 for any other seed, and ``sigma`` is 1.0 for
    seeds 1, 2, 7 and 8 and 3.0 for any other seed.
    """
    # The special case overrides both values, so handle it first.
    if seed == 19:
        return 1000, 1.0

    small_sample_seeds = (1, 2, 3, 4)
    low_noise_seeds = (1, 2, 7, 8)

    if seed in small_sample_seeds:
        n_obs = 100
    else:
        n_obs = 200

    if seed in low_noise_seeds:
        noise = 1.0
    else:
        noise = 3.0

    return n_obs, noise

def _parse_N_sigma(path):
    """Read the sample size and noise level encoded in a dataset file name.

    Looks for ``_N<digits>`` and ``_sigma<number>`` in the base name, e.g.
    ``GAM_N100_p8_sigma3.00_seed4.npz`` gives ``(100, 3.0)``. A component that
    is not present in the name is returned as NaN.
    """
    name = os.path.basename(str(path))
    n_match = re.search(r"_N(\d+)", name)
    sigma_match = re.search(r"_sigma(\d+(?:\.\d+)?)", name)
    N = int(n_match.group(1)) if n_match else np.nan
    sigma = float(sigma_match.group(1)) if sigma_match else np.nan
    return N, sigma


def _rmse_per_draw(y_pred, y_test):
    """Root-mean-squared error of every draw (first axis of ``y_pred``) against ``y_test``."""
    draws = np.asarray(y_pred)
    # Flatten everything after the draw axis; an explicit width (instead of -1)
    # keeps the reshape valid even when there are zero draws.
    width = int(np.prod(draws.shape[1:]))
    residuals = draws.reshape(draws.shape[0], width) - np.asarray(y_test).reshape(1, -1)
    return np.sqrt(np.mean(residuals ** 2, axis=1))


def get_all_convergence_diagnostics(all_fits, data_path=None):
    """Summarise convergence and test-set accuracy for each fitted model.

    Parameters
    ----------
    all_fits : dict
        Maps a model name to a nested mapping in which
        ``all_fits[name][name]['posterior']`` is the fit object. That object
        is passed to ``az.from_cmdstanpy`` and must provide
        ``stan_variable('output_test')``.
    data_path : str, optional
        ``.npz`` file containing a ``"y_test"`` array. When omitted, the
        ``GAM_N100_p8_sigma3.00_seed4`` dataset under
        ``datasets/type_{data_config}`` is used, where ``data_config`` is the
        module-level setting.

    Returns
    -------
    pandas.DataFrame
        One row per model with max/median R-hat, median bulk/tail ESS, the
        mean per-draw RMSE over all draws (``rmse``) and over non-divergent
        draws only (``rmse_no_div``), and the ``N`` / ``sigma`` of the dataset
        as encoded in its file name. A model whose processing raises gets a
        row with NaN R-hat values and the message under ``error``. If the
        data file is missing, a ``[SKIP]`` line is printed and the model
        contributes no row.
    """
    if data_path is None:
        data_path = f'datasets/type_{data_config}/GAM_N100_p8_sigma3.00_seed4.npz'

    # Report the N / sigma of the dataset that is actually loaded rather than
    # fixed placeholder values that can drift out of sync with the path.
    N, sigma = _parse_N_sigma(data_path)

    diagnostics = []
    y_test = None  # loaded on first use, then shared by all models

    for model_name, fit in all_fits.items():
        try:
            posterior = fit[model_name]['posterior']
            idata = az.from_cmdstanpy(posterior)
            y_pred = posterior.stan_variable('output_test')

            if y_test is None:
                try:
                    # The context manager closes the archive once the array is read.
                    with np.load(data_path) as data:
                        y_test = data["y_test"]
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {data_path}")
                    continue

            # Flags come as (n_chains, n_draws); flatten to one flag per draw.
            # NOTE(review): assumes stan_variable returns draws in the same
            # chain-major order as this flattening — confirm.
            divergent_flat = idata.sample_stats["diverging"].values.flatten()

            rmses = _rmse_per_draw(y_pred, y_test)
            rmses_no_div = rmses[~divergent_flat]

            # NOTE(review): R-hat / ESS are summarised for the variable
            # "output" while the RMSE uses "output_test" — confirm intended.
            summary = az.summary(idata, var_names=["output"], round_to=3)
            rhat = summary["r_hat"]
            ess_bulk = summary["ess_bulk"]
            ess_tail = summary["ess_tail"]

            diagnostics.append({
                "model": model_name,
                "max_rhat": rhat.max(),
                "median_rhat": rhat.median(),
                "rmse": np.mean(rmses),
                "rmse_no_div": np.mean(rmses_no_div),
                "median_ess_bulk": ess_bulk.median(),
                "median_ess_tail": ess_tail.median(),
                "N": N,
                "sigma": sigma,
            })

        except Exception as e:
            # Best effort: keep going with the other models and record what
            # went wrong for this one.
            diagnostics.append({
                "model": model_name,
                "max_rhat": np.nan,
                "median_rhat": np.nan,
                "error": str(e)
            })

    return pd.DataFrame(diagnostics)


In [5]:
# Build the per-model diagnostics table for the loaded ReLU fits.
# NOTE(review): the name is spelled `relu_diagostic` (sic); it is kept as-is
# because the following cell refers to it.
relu_diagostic = get_all_convergence_diagnostics(relu_fits)

In [None]:
# Show the diagnostics table as this cell's output.
relu_diagostic