In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Collect the "Dirichlet Horseshoe" fits stored directly under
# "results_relu_exhaustive", one entry per dataset file of the chosen type.
data_config = 1
data_dir = f"datasets/type_{data_config}"
results_dir = "results_relu_exhaustive"
model_names = ["Dirichlet Horseshoe"]

# Keyed by dataset file name without the ".npz" extension.
dhs_08_fits = {}

# Every .npz file in data_dir is treated as one dataset configuration,
# visited in sorted name order.
files = sorted(filter(lambda name: name.endswith(".npz"), os.listdir(data_dir)))
for fname in files:
    # Clean key, e.g. "GAM_N100_p8_sigma1.00_seed1".
    base_config_name = fname.replace(".npz", "")
    # Identifier handed to the loader, e.g. "type_1/GAM_N100_p8_sigma1.00_seed1".
    full_config_path = "/".join((f"type_{data_config}", base_config_name))

    fits = get_model_fits(
        config=full_config_path,
        results_dir=results_dir,
        models=model_names,
        include_prior=False,
    )
    dhs_08_fits[base_config_name] = fits



In [None]:
# Collect the "Dirichlet Horseshoe" fits stored under the "slow" sub-directory
# of the results folder, one entry per dataset file of the chosen type.
data_config = 1
data_dir = f"datasets/type_{data_config}"
results_dir = "results_relu_exhaustive/slow"
model_names = ["Dirichlet Horseshoe"]

# Keyed by dataset file name without the ".npz" extension.
dhs_095_fits = {}

# Every .npz file in data_dir is treated as one dataset configuration,
# visited in sorted name order.
files = sorted(filter(lambda name: name.endswith(".npz"), os.listdir(data_dir)))
for fname in files:
    # Clean key, e.g. "GAM_N100_p8_sigma1.00_seed1".
    base_config_name = fname.replace(".npz", "")
    # Identifier handed to the loader, e.g. "type_1/GAM_N100_p8_sigma1.00_seed1".
    full_config_path = "/".join((f"type_{data_config}", base_config_name))

    fits = get_model_fits(
        config=full_config_path,
        results_dir=results_dir,
        models=model_names,
        include_prior=False,
    )
    dhs_095_fits[base_config_name] = fits



In [4]:
import pandas as pd
import numpy as np
import arviz as az

def get_N_sigma(seed):
    """Map a dataset seed to its ``(N, sigma)`` pair.

    Presumably N is the sample size and sigma the noise level encoded in the
    dataset file names (``..._N100_..._sigma1.00_seed1``) -- confirm against
    the data-generation script.

    Parameters
    ----------
    seed : int
        Seed parsed from the dataset/config name.

    Returns
    -------
    tuple
        ``(1000, 1.0)`` for seed 19. Otherwise N is 100 for seeds 1-4 and 200
        for every other seed, and sigma is 1.0 for seeds 1, 2, 7, 8 and 3.0
        for every other seed.
    """
    # Seed 19 is the one special case and overrides both rules below.
    if seed == 19:
        return 1000, 1.0

    n_obs = 100 if seed in (1, 2, 3, 4) else 200
    noise = 1.0 if seed in (1, 2, 7, 8) else 3.0
    return n_obs, noise

def _mean_posterior_rmse(y_pred, y_test):
    """Return the test RMSE averaged over posterior draws.

    The first axis of ``y_pred`` indexes draws; all remaining axes are
    flattened and compared element-wise with the flattened ``y_test``.
    """
    target = np.ravel(y_test)
    draws = np.asarray(y_pred)
    # One row per draw; raises ValueError if a draw does not have exactly
    # as many predictions as there are test targets.
    draws = draws.reshape(draws.shape[0], target.size)
    per_draw_rmse = np.sqrt(np.mean((draws - target) ** 2, axis=1))
    return float(per_draw_rmse.mean())


def get_all_convergence_diagnostics(all_fits, data_dir=None):
    """Summarise convergence diagnostics and test RMSE for every fit.

    Parameters
    ----------
    all_fits : dict
        ``{config_name: {model_name: fit}}`` where ``fit['posterior']`` is an
        object accepted by ``az.from_cmdstanpy`` that exposes
        ``stan_variable('output_test')``.
    data_dir : str, optional
        Directory containing ``<config_name>.npz`` files with a ``"y_test"``
        array. When omitted, ``datasets/type_{data_config}`` is used, reading
        the notebook-level ``data_config`` variable.

    Returns
    -------
    pandas.DataFrame
        One row per (config, model) pair with the columns ``model``,
        ``max_rhat``, ``median_rhat``, ``p95_rhat``, ``rmse``,
        ``median_ess_bulk``, ``median_ess_tail``, ``N`` and ``sigma``.
        R-hat and ESS statistics are taken over the ``output`` entries of
        ``az.summary``. ``rmse`` is the per-draw test RMSE averaged over
        draws. Pairs whose dataset file is missing are skipped with a printed
        message. If anything else fails, the row holds NaN in every numeric
        column plus an ``error`` column with the exception text.
    """
    diagnostics = []

    for config_name, model_fits in all_fits.items():
        for model_name, fit in model_fits.items():
            try:
                posterior = fit['posterior']
                idata = az.from_cmdstanpy(posterior)
                y_pred = posterior.stan_variable('output_test')

                # Resolved lazily so a missing global surfaces as an error
                # row rather than aborting the whole run.
                base_dir = data_dir if data_dir is not None else f'datasets/type_{data_config}'
                path = f'{base_dir}/{config_name}.npz'
                try:
                    # The context manager closes the archive once y_test is read.
                    with np.load(path) as data:
                        y_test = data["y_test"]
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {path}")
                    continue

                # RMSE of each posterior draw against y_test, then averaged.
                rmse = _mean_posterior_rmse(y_pred, y_test)

                # Per-parameter R-hat and ESS for the "output" variable only.
                summary = az.summary(idata, var_names=["output"], round_to=3)
                rhat = summary["r_hat"]
                ess_bulk = summary["ess_bulk"]
                ess_tail = summary["ess_tail"]

                # The seed is the integer after the last "_seed" in the name;
                # names without a parsable seed get NaN for N and sigma.
                try:
                    seed = int(config_name.split("_seed")[-1])
                except ValueError:
                    N, sigma = np.nan, np.nan
                else:
                    N, sigma = get_N_sigma(seed)

                diagnostics.append({
                    "model": model_name,
                    "max_rhat": rhat.max(),
                    "median_rhat": rhat.median(),
                    "p95_rhat": rhat.quantile(0.95),
                    "rmse": rmse,
                    "median_ess_bulk": ess_bulk.median(),
                    "median_ess_tail": ess_tail.median(),
                    "N": N,
                    "sigma": sigma,
                })

            except Exception as e:
                # Best effort: record the failure with the same columns as a
                # successful row so the resulting frame has a single schema.
                diagnostics.append({
                    "model": model_name,
                    "max_rhat": np.nan,
                    "median_rhat": np.nan,
                    "p95_rhat": np.nan,
                    "rmse": np.nan,
                    "median_ess_bulk": np.nan,
                    "median_ess_tail": np.nan,
                    "N": np.nan,
                    "sigma": np.nan,
                    "error": str(e),
                })

    return pd.DataFrame(diagnostics)


In [5]:
# One diagnostics table per set of fits.
# NOTE: the "diagostic" spelling is kept because later cells reference these names.
dhs_08_diagostic = get_all_convergence_diagnostics(all_fits=dhs_08_fits)
dhs_095_diagostic = get_all_convergence_diagnostics(all_fits=dhs_095_fits)

In [None]:
# Display the diagnostics table computed from dhs_08_fits.
dhs_08_diagostic

In [None]:
# Display the diagnostics table computed from dhs_095_fits.
dhs_095_diagostic

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# RMSE against the largest R-hat, one panel per diagnostics table.
# Both axes are shared so the two panels can be compared directly.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True, sharey=True)

# (table, panel title, y-axis label) for the left and right panel.
panels = [
    (dhs_08_diagostic, r"$\delta = 0.8$", "RMSE"),
    (dhs_095_diagostic, r"$\delta = 0.95$", ""),  # shared y-axis, so don't repeat
]

for ax, (diagnostics, title, ylabel) in zip(axes, panels):
    # Colour encodes sigma, marker size encodes N.
    sns.scatterplot(
        data=diagnostics,
        x="max_rhat", y="rmse",
        hue="sigma",
        size="N", sizes=(100, 300),
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel(r"Max $\hat{R}$")
    ax.set_ylabel(ylabel)

# Clean layout
fig.tight_layout()

plt.show()


In [None]:
import numpy as np
import arviz as az
# Trace plots of the first "output" element for one dataset, drawn once for
# each of the two sets of fits.
dhs_08_fit = dhs_08_fits['GAM_N100_p8_sigma1.00_seed1']['Dirichlet Horseshoe']['posterior']
dhs_095_fit = dhs_095_fits['GAM_N100_p8_sigma1.00_seed1']['Dirichlet Horseshoe']['posterior']

# "output" is expected to carry a trailing axis of length 1 (e.g. shape
# (8000, 80, 1) -- TODO confirm); squeeze(-1) drops it and raises if the
# axis has any other length.
output_samples_08 = dhs_08_fit.stan_variable("output")
output_samples_08 = output_samples_08.squeeze(-1)

# Must be read from dhs_095_fit (not the 0.8 fit) and squeezed from its own
# raw array, otherwise the second trace plot cannot show the second fit.
output_samples_095 = dhs_095_fit.stan_variable("output")
output_samples_095 = output_samples_095.squeeze(-1)

# Create a dict: one key per column of the squeezed array.
output_dict_08 = {
    f"output[{i}]": output_samples_08[:, i]
    for i in range(output_samples_08.shape[1])
}

# Create a dict: one key per column of the squeezed array.
output_dict_095 = {
    f"output[{i}]": output_samples_095[:, i]
    for i in range(output_samples_095.shape[1])
}

# NOTE(review): each column is passed as a 1-D array, so chain structure is
# presumably not preserved in these trace plots -- confirm this is intended.

# Convert to InferenceData object
idata_output_08 = az.from_dict(posterior=output_dict_08)
az.plot_trace(idata_output_08, var_names=["output[0]"])

# Convert to InferenceData object
idata_output_095 = az.from_dict(posterior=output_dict_095)
az.plot_trace(idata_output_095, var_names=["output[0]"])
