In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Which simulated dataset family to analyse, and where the ReLU fits live.
data_config = 1
data_dir = f"datasets/type_{data_config}"
results_dir_relu = "results_relu/slow"
#results_dir_tanh = "results_tanh/slow"
model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
#model_names_tanh = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh"]

# {dataset stem: fits returned by get_model_fits for that dataset}
relu_fits = {}
#tanh_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Strip only the trailing extension (str.replace would remove ".npz" anywhere in the name).
    base_config_name = os.path.splitext(fname)[0]  # e.g., "GAM_N100_p8_sigma1.00_seed1"
    full_config_path = f"type_{data_config}/{base_config_name}"  # -> "type_1/GAM_N100_p8_sigma1.00_seed1"
    relu_fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_relu,
        models=model_names_relu,
        include_prior=False,
    )

    # tanh_fit = get_model_fits(
    #     config=full_config_path,
    #     results_dir=results_dir_tanh,
    #     models=model_names_tanh,
    #     include_prior=False,
    # )

    relu_fits[base_config_name] = relu_fit  # key by the bare dataset stem
    #tanh_fits[base_config_name] = tanh_fit  # key by the bare dataset stem
    


In [3]:
def get_N_sigma(seed):
    """Return the ``(N, sigma)`` pair that a dataset seed was generated with.

    Seeds 1-4 map to N=100 and every other seed to N=200. Seeds 1, 2, 7
    and 8 map to sigma=1.0 and every other seed to sigma=3.0. Seed 19 is an
    explicit override: (1000, 1.0). Unlisted seeds are not rejected; they
    fall through to (200, 3.0).
    """
    if seed == 19:
        return 1000, 1.0
    small_sample = seed in (1, 2, 3, 4)
    low_sigma = seed in (1, 2, 7, 8)
    return (100 if small_sample else 200), (1.0 if low_sigma else 3.0)

def compute_rmse_results(seeds, models, all_fits, get_N_sigma, data_dir=None):
    """Compute test-set RMSE for every (seed, model) pair.

    Two summaries are produced per pair:

    * one RMSE per posterior draw (``output_test[s]`` against ``y_test``), and
    * the RMSE of the posterior-mean prediction.

    Parameters
    ----------
    seeds : iterable of int
        Dataset seeds; ``get_N_sigma(seed)`` maps each to ``(N, sigma)``.
    models : iterable of str
        Model names looked up in ``all_fits[dataset_key]``.
    all_fits : dict
        ``{dataset_key: {model: {'posterior': fit}}}`` where ``fit`` exposes
        ``stan_variable("output_test")`` with draws along axis 0.
    get_N_sigma : callable
        ``seed -> (N, sigma)``, used to rebuild the dataset key.
    data_dir : str, optional
        Directory holding ``{dataset_key}.npz`` files that contain ``y_test``.
        Defaults to ``f"datasets/type_{data_config}"`` (module-level global).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``df_rmse`` with columns ``seed, N, sigma, model, rmse`` (one row per
        draw) and ``df_posterior_rmse`` with columns
        ``seed, N, sigma, model, posterior_mean_rmse`` (one row per pair).
        Both keep these columns even when empty. Missing dataset files or
        fits are skipped with a ``[SKIP]`` message.
    """
    if data_dir is None:
        data_dir = f"datasets/type_{data_config}"

    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'GAM_N{N}_p8_sigma{sigma:.2f}_seed{seed}'
        path = f"{data_dir}/{dataset_key}.npz"

        try:
            # NpzFile keeps the archive open until closed; the context manager releases it.
            with np.load(path) as data:
                y_test = data["y_test"]
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        y_true = y_test.reshape(1, -1)

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                output_test = fit.stan_variable("output_test")  # draws on axis 0
            except (KeyError, ValueError):
                # KeyError: no entry in all_fits; ValueError: cmdstanpy's
                # stan_variable rejects an unknown variable name.
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = output_test.shape[0]
            # One RMSE per draw, vectorized over the draw axis; a size mismatch
            # with y_test raises instead of silently broadcasting.
            residuals = output_test.reshape(S, -1) - y_true
            rmses = np.sqrt(np.mean(residuals ** 2, axis=1))

            posterior_mean = np.mean(output_test, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean.reshape(1, -1) - y_true) ** 2))

            row_meta = {'seed': seed, 'N': N, 'sigma': sigma, 'model': model}
            posterior_means.append({**row_meta, 'posterior_mean_rmse': posterior_mean_rmse})
            results.extend({**row_meta, 'rmse': rmse} for rmse in rmses)

    # Explicit columns keep downstream filters (e.g. df['model']) working on empty results.
    df_rmse = pd.DataFrame(results, columns=['seed', 'N', 'sigma', 'model', 'rmse'])
    df_posterior_rmse = pd.DataFrame(
        posterior_means, columns=['seed', 'N', 'sigma', 'model', 'posterior_mean_rmse']
    )

    return df_rmse, df_posterior_rmse


In [4]:
# Seeds analysed below; seed 19 (the N=1000 override in get_N_sigma) is left out.
seeds = [1, 2, 3, 4, 7, 8, 9, 10]

df_rmse_relu, df_posterior_rmse_relu = compute_rmse_results(
    seeds=seeds,
    models=model_names_relu,
    all_fits=relu_fits,
    get_N_sigma=get_N_sigma,
)

# df_rmse_tanh, df_posterior_rmse_tanh = compute_rmse_results(
#     seeds, model_names_tanh, tanh_fits, get_N_sigma
# )

In [None]:
# Mean over posterior draws of the per-draw test RMSE, for one dataset seed.
summary_seed = 1
df_summary_seed = df_rmse_relu[df_rmse_relu["seed"] == summary_seed]
mean_rmse_by_model = df_summary_seed.groupby("model")["rmse"].mean()

# Named scalars kept from the original cell; a model with no rows yields NaN.
rmse_gauss = mean_rmse_by_model.get("Gaussian", np.nan)
rmse_rhs = mean_rmse_by_model.get("Regularized Horseshoe", np.nan)
rmse_dhs = mean_rmse_by_model.get("Dirichlet Horseshoe", np.nan)
rmse_dst = mean_rmse_by_model.get("Dirichlet Student T", np.nan)

for model_name in model_names_relu:
    print(f"RMSE {model_name}: {mean_rmse_by_model.get(model_name, np.nan):.3f}")




In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from matplotlib.ticker import FixedLocator

# Combine per-draw RMSEs from every available activation and tag each row.
# The tanh results are optional (their cells may be commented out), so include
# them only when `df_rmse_tanh` is defined instead of raising NameError.
rmse_by_activation = {"ReLU": df_rmse_relu}
if "df_rmse_tanh" in globals():
    rmse_by_activation["tanh"] = df_rmse_tanh

df_all = pd.concat(
    [frame.assign(activation=label) for label, frame in rmse_by_activation.items()],
    ignore_index=True,
)
# tanh model names carry a " tanh" suffix; strip it so both activations share x categories.
df_all["model"] = df_all["model"].str.replace(" tanh", "", regex=False)

# Order of models and activations
model_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
activation_order = list(rmse_by_activation)

# Grid layout: rows are sigma (1.0, 3.0), columns are N (100, 200).
fig, axes = plt.subplots(2, 2, figsize=(18, 7), sharey=False)

for i, N_val in enumerate([100, 200]):
    for j, sigma_val in enumerate([1.0, 3.0]):
        ax = axes[j, i]
        df_plot = df_all[(df_all["N"] == N_val) & (df_all["sigma"] == sigma_val)]

        if df_plot.empty:
            # No runs for this (N, sigma) cell: draw a placeholder instead of a boxplot.
            ax.text(0.5, 0.5, "no data", ha="center", va="center", transform=ax.transAxes)
        else:
            # Use model as x, activation as hue
            sns.boxplot(
                data=df_plot,
                x="model",
                y="rmse",
                hue="activation",
                order=model_order,
                hue_order=activation_order,
                ax=ax
            )

        ax.set_title(f"N = {N_val}, Sigma = {sigma_val}")
        ax.set_xlabel("")
        ax.set_ylabel("RMSE")
        ax.set_ylim(0, 8 if sigma_val == 1.0 else 12)
        ax.grid(True)

        # Per-axis legends are replaced by one shared figure legend below.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

# Shared legend at top center, taken from the first panel that has hue handles.
handles, labels = [], []
for ax in axes.flat:
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        break
if handles:
    fig.legend(handles, labels, title="Activation", loc="upper center", ncol=len(labels))

plt.tight_layout(rect=[0, 0, 1, 0.93])
plt.show()


## Convergence

In [7]:
import pandas as pd
import numpy as np
import arviz as az

def get_N_sigma(seed):
    """Return ``(N, sigma)`` for a dataset seed.

    This repeats the seed mapping used by the RMSE section. Defaults are
    N=200 and sigma=3.0. Seeds 1-4 lower N to 100, and seeds 1, 2, 7, 8
    lower sigma to 1.0. Seed 19 is a fixed override of (1000, 1.0).
    """
    if seed == 19:
        N, sigma = 1000, 1.0
    else:
        N, sigma = 200, 3.0
        if seed in (1, 2, 3, 4):
            N = 100
        if seed in (1, 2, 7, 8):
            sigma = 1.0
    return N, sigma

def get_all_convergence_diagnostics(all_fits, data_dir=None):
    """Collect convergence diagnostics and test RMSE for every fitted model.

    Parameters
    ----------
    all_fits : dict
        ``{config_name: {model_name: fit_dict}}``. ``fit_dict['posterior']``
        must be accepted by ``az.from_cmdstanpy`` and expose
        ``stan_variable('output_test')``.
    data_dir : str, optional
        Directory with ``{config_name}.npz`` files containing ``y_test``.
        Defaults to ``f'datasets/type_{data_config}'`` (module-level global).

    Returns
    -------
    pandas.DataFrame
        One row per (config, model), with columns ``model, max_rhat,
        median_rhat, prop_divergent, rmse, rmse_no_div, median_ess_tail, N,
        sigma``. R-hat and ESS come from ``az.summary`` on the ``output``
        variable, and ``median_ess_tail`` is divided by the number of draws.
        Pairs whose dataset file is missing are skipped. Pairs that raise
        keep NaN metrics and gain an ``error`` message column.
    """
    if data_dir is None:
        data_dir = f'datasets/type_{data_config}'

    diagnostics = []

    for config_name, model_fits in all_fits.items():
        # (N, sigma) is recovered from the trailing "_seed<k>" of the dataset name.
        try:
            seed = int(config_name.split("_seed")[-1])
            N, sigma = get_N_sigma(seed)
        except ValueError:
            N, sigma = np.nan, np.nan

        for model_name, fit in model_fits.items():
            # Success and failure rows share one schema so the table has stable columns.
            row = {
                "model": model_name,
                "max_rhat": np.nan,
                "median_rhat": np.nan,
                "prop_divergent": np.nan,
                "rmse": np.nan,
                "rmse_no_div": np.nan,
                "median_ess_tail": np.nan,
                "N": N,
                "sigma": sigma,
            }
            try:
                posterior = fit['posterior']
                idata = az.from_cmdstanpy(posterior)
                y_pred = posterior.stan_variable('output_test')

                path = f'{data_dir}/{config_name}.npz'
                try:
                    with np.load(path) as data:
                        y_test = data["y_test"]
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {path}")
                    continue

                # Flatten (chain, draw) divergence flags to align with y_pred's draw axis.
                # NOTE(review): assumes stan_variable stacks draws chain by chain — confirm.
                diverging = np.asarray(idata.sample_stats["diverging"].values, dtype=bool).reshape(-1)
                S = y_pred.shape[0]

                # Per-draw RMSE, vectorized over draws.
                residuals = y_pred.reshape(S, -1) - y_test.reshape(1, -1)
                draw_rmse = np.sqrt(np.mean(residuals ** 2, axis=1))
                draw_rmse_no_div = draw_rmse[~diverging]

                summary = az.summary(idata, var_names=["output"], round_to=3)
                rhat = summary["r_hat"]
                ess_tail = summary["ess_tail"]

                row.update({
                    "max_rhat": rhat.max(),
                    "median_rhat": rhat.median(),
                    "prop_divergent": diverging.sum() / S,
                    "rmse": draw_rmse.mean(),
                    # If every draw diverged there is nothing to average.
                    "rmse_no_div": draw_rmse_no_div.mean() if draw_rmse_no_div.size else np.nan,
                    "median_ess_tail": ess_tail.median() / S,
                })

            except Exception as e:
                # Best-effort: keep one failing fit from aborting the whole table.
                row["error"] = str(e)

            diagnostics.append(row)

    return pd.DataFrame(diagnostics)


In [8]:
relu_diagostic = get_all_convergence_diagnostics(relu_fits)
#tanh_diagostic = get_all_convergence_diagnostics(tanh_fits)

# Group rows by model while keeping each model's rows in their original
# order; a stable sort on "model" alone gives exactly that ordering.
relu_grouped = (
    relu_diagostic
    .sort_values("model", kind="stable")
    .reset_index(drop=True)
)

# tanh_grouped = (
#     tanh_diagostic
#     .sort_values("model", kind="stable")
#     .reset_index(drop=True)
# )


In [None]:
# Display the per-model diagnostics table (rows grouped by model).
relu_grouped

In [None]:
# Render the grouped diagnostics as a LaTeX table, floats to 3 decimals.
latex_relu = relu_grouped.to_latex(
    index=False,
    float_format="%.3f",
)
print(latex_relu)
#latex_tanh = tanh_grouped.to_latex(index=False, float_format="%.3f")
#print(latex_tanh)