In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt
#os.chdir('..')

In [None]:
# Experiment configuration.
data_config = 1
config_group = "interactions"           # dataset family; also the prefix of each fit's config path
data_dir = f"datasets/{config_group}"    # where the INT_*.npz dataset files live
results_dir = "results_relu"
model_names = ["Regularized Horseshoe", "Dirichlet Horseshoe", "Gaussian", "Dirichlet Student T"]

# Load the posterior fits of every model for every dataset file found in `data_dir`.
# Keys of `all_fits` are the bare dataset names (file name without ".npz").
all_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # e.g. "INT_N100_p8_sigma1.00_seed1.npz" -> "INT_N100_p8_sigma1.00_seed1"
    base_config_name = os.path.splitext(fname)[0]
    # Passed to get_model_fits as `config`, e.g. "interactions/INT_N100_p8_sigma1.00_seed1".
    full_config_path = f"{config_group}/{base_config_name}"
    print(f"Loading fits for {full_config_path}")
    all_fits[base_config_name] = get_model_fits(
        config=full_config_path,
        results_dir=results_dir,
        models=model_names,
        include_prior=False,
    )



In [7]:
def get_N_sigma(seed):
    """Return the ``(N, sigma)`` design behind a dataset seed.

    Seeds 1-4 have N = 100 and every other seed N = 200. Seeds 1, 2, 7 and 8
    have sigma = 1.0 and every other seed sigma = 3.0. Seed 19 is special-cased
    to N = 1000, sigma = 1.0.
    """
    if seed == 19:
        return 1000, 1.0
    small_n_seeds = (1, 2, 3, 4)
    low_noise_seeds = (1, 2, 7, 8)
    n_train = 100 if seed in small_n_seeds else 200
    noise_sd = 1.0 if seed in low_noise_seeds else 3.0
    return n_train, noise_sd

def compute_rmse_results(seeds, models, all_fits, get_N_sigma,
                         data_dir="datasets/interactions", prefix="INT"):
    """Compute per-draw and posterior-mean test RMSE for every (seed, model) pair.

    Parameters
    ----------
    seeds : iterable of int
        Dataset seeds; each is mapped to ``(N, sigma)`` by ``get_N_sigma`` to
        rebuild the dataset name ``{prefix}_N{N}_p8_sigma{sigma:.2f}_seed{seed}``.
    models : iterable of str
        Model names looked up in ``all_fits[dataset_key]``.
    all_fits : dict
        ``all_fits[dataset_key][model]['posterior']`` must provide
        ``stan_variable("output_test")``, returning an array whose first axis
        indexes posterior draws.
    get_N_sigma : callable
        ``seed -> (N, sigma)``.
    data_dir : str, optional
        Directory holding the ``.npz`` dataset files (each must contain ``y_test``).
    prefix : str, optional
        Dataset-name prefix.

    Returns
    -------
    df_rmse : pandas.DataFrame
        One row per posterior draw: seed, N, sigma, model, rmse.
    df_posterior_rmse : pandas.DataFrame
        One row per (seed, model): seed, N, sigma, model, posterior_mean_rmse
        (the RMSE of the draw-averaged prediction).

    Missing dataset files and missing models/variables are reported and skipped.
    """
    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'{prefix}_N{N}_p8_sigma{sigma:.2f}_seed{seed}'
        path = f"{data_dir}/{dataset_key}.npz"

        try:
            # The context manager closes the NpzFile handle; indexing loads y_test into memory.
            with np.load(path) as data:
                y_test = data["y_test"]
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue
        y_true = y_test.squeeze()

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                output_test = fit.stan_variable("output_test")  # first axis = posterior draws
            except (KeyError, ValueError):
                # KeyError: dataset/model/posterior missing from all_fits;
                # ValueError: cmdstanpy's response to an unknown Stan variable name.
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            # RMSE of every posterior draw individually.
            rmses = np.array([
                np.sqrt(np.mean((draw.squeeze() - y_true) ** 2))
                for draw in output_test
            ])

            # RMSE of the posterior-mean prediction (generally smaller than the per-draw RMSEs).
            posterior_mean = np.mean(output_test, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean.squeeze() - y_true) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {'seed': seed, 'N': N, 'sigma': sigma, 'model': model, 'rmse': rmse}
                for rmse in rmses
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [8]:
# Seed 19 (the N = 1000 design) is deliberately left out of this comparison.
seeds = [1, 2, 3, 4, 7, 8, 9, 10]
df_rmse, df_posterior_rmse = compute_rmse_results(
    seeds=seeds, models=model_names, all_fits=all_fits, get_N_sigma=get_N_sigma
)

In [None]:
# Per-draw test RMSE (boxes) with the posterior-mean RMSE overlaid (X markers),
# one panel per training-set size N present in the results.
assert not df_rmse.empty, "No RMSE results: check dataset paths / all_fits keys."
n_values = sorted(df_rmse['N'].unique())
fig, axes = plt.subplots(
    1, len(n_values), figsize=(8 * len(n_values), 6), sharey=True, squeeze=False
)
axes = axes.ravel()

for ax, N_val in zip(axes, n_values):
    sns.boxplot(
        data=df_rmse[df_rmse['N'] == N_val],
        x="model", y="rmse", hue="sigma", ax=ax
    )
    # NOTE(review): the X markers sit at each model's category centre; they are
    # not dodged onto the matching sigma box.
    sns.scatterplot(
        data=df_posterior_rmse[df_posterior_rmse['N'] == N_val],
        x="model", y="posterior_mean_rmse", hue="sigma",
        marker="X", s=100, ax=ax, zorder=10, legend=False
    )
    ax.set_title(f"RMSE Distribution (N = {N_val})")
    ax.set_xlabel("Model")
    ax.set_ylabel("RMSE")
    ax.grid(True)

# seaborn adds a hue legend to every panel: keep only the last one and retitle it
# (rebuilding it with ax.legend() can lose seaborn's legend handles).
for ax in axes[:-1]:
    if ax.get_legend() is not None:
        ax.get_legend().remove()
legend = axes[-1].get_legend()
if legend is not None:
    legend.set_title("Sigma")

plt.tight_layout()
#plt.savefig(f"figures/GAM_{data_config}/rmse_models.png", dpi=300, bbox_inches='tight')
plt.show()

In [12]:
import arviz as az

def get_N_sigma(seed):
    """Map a dataset seed to its ``(N, sigma)`` design.

    N is 100 for seeds 1-4 and 200 otherwise. sigma is 1.0 for seeds 1, 2, 7
    and 8, and 3.0 otherwise. Seed 19 overrides both to ``(1000, 1.0)``.

    NOTE(review): this duplicates (and shadows) the identical get_N_sigma
    defined in an earlier cell. Keep them in sync or delete one.
    """
    n_train, noise_sd = (
        100 if seed in (1, 2, 3, 4) else 200,
        1.0 if seed in (1, 2, 7, 8) else 3.0,
    )
    if seed == 19:
        n_train, noise_sd = 1000, 1.0
    return n_train, noise_sd

def get_all_convergence_diagnostics(all_fits, data_dir="datasets/interactions"):
    """Summarise convergence and test accuracy for every (config, model) fit.

    Convergence statistics come from ``az.summary`` over the Stan variable
    ``output``. Test RMSE is computed separately for each posterior draw of
    ``output_test`` against ``y_test`` from ``<data_dir>/<config_name>.npz``,
    and then averaged over draws.

    Parameters
    ----------
    all_fits : dict
        ``{config_name: {model_name: {'posterior': fit, ...}}}``. ``fit`` must be
        accepted by ``az.from_cmdstanpy`` and provide ``stan_variable``.
    data_dir : str, optional
        Directory holding the dataset ``.npz`` files.

    Returns
    -------
    pandas.DataFrame
        One row per (config, model) with the columns model, max_rhat,
        median_rhat, rmse, median_ess_bulk, median_ess_tail, N and sigma.
        A fit that raises while being processed gets NaN metrics plus an
        ``error`` message. Fits whose dataset file is missing are skipped.
        N/sigma come from the module-level ``get_N_sigma`` applied to the
        ``_seed<k>`` suffix of the config name (NaN when that does not parse).
    """
    diagnostics = []

    for config_name, model_fits in all_fits.items():
        # Recover the (N, sigma) design from the trailing "_seed<k>" of the name.
        try:
            seed = int(config_name.split("_seed")[-1])
            N, sigma = get_N_sigma(seed)
        except ValueError:
            N, sigma = np.nan, np.nan
        path = f'{data_dir}/{config_name}.npz'

        for model_name, fit in model_fits.items():
            try:
                idata = az.from_cmdstanpy(fit['posterior'])
                y_pred = fit['posterior'].stan_variable('output_test')  # first axis = draws

                try:
                    with np.load(path) as data:
                        y_test = data["y_test"]
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {path}")
                    continue

                # RMSE of each draw individually (not of all draws pooled together).
                y_true = y_test.squeeze()
                rmses = np.array([
                    np.sqrt(np.mean((draw.squeeze() - y_true) ** 2))
                    for draw in y_pred
                ])

                summary = az.summary(idata, var_names=["output"], round_to=3)
                rhat = summary["r_hat"]

                diagnostics.append({
                    "model": model_name,
                    "max_rhat": rhat.max(),
                    "median_rhat": rhat.median(),
                    "rmse": np.mean(rmses, axis=0),
                    "median_ess_bulk": summary["ess_bulk"].median(),
                    "median_ess_tail": summary["ess_tail"].median(),
                    "N": N,
                    "sigma": sigma,
                })

            except Exception as e:
                # Best effort: record the failure with the same columns as a
                # successful row, so downstream plots still see model/N/sigma.
                diagnostics.append({
                    "model": model_name,
                    "max_rhat": np.nan,
                    "median_rhat": np.nan,
                    "rmse": np.nan,
                    "median_ess_bulk": np.nan,
                    "median_ess_tail": np.nan,
                    "N": N,
                    "sigma": sigma,
                    "error": str(e),
                })

    return pd.DataFrame(diagnostics)


In [13]:
# One diagnostics row per (dataset config, model) fit.
df_diagostic = get_all_convergence_diagnostics(all_fits=all_fits)

In [None]:
# Show the per-(config, model) convergence diagnostics table.
# NOTE(review): "diagostic" is a typo for "diagnostic"; renaming requires updating every use.
df_diagostic

In [None]:
# ReLU models: worst-case R-hat of `output` against the `rmse` column, one point per fit.
fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(
    data=df_diagostic,
    x="max_rhat",
    y="rmse",
    hue="model",
    style="sigma",
    size="N",
    sizes=(100, 300),
    legend=True,
    ax=ax,
)
ax.set_title("ReLU")
ax.set_xlabel(r"Max $\hat{R}$")
ax.set_ylabel("RMSE")
ax.grid(True)
fig.tight_layout()
plt.show()


In [6]:
from properscoring import crps_ensemble
import pandas as pd
import numpy as np

def compute_crps_results(seeds, models, all_fits, get_N_sigma,
                         data_dir="datasets/interactions", prefix="INT"):
    """Compute the CRPS of the posterior-predictive ensemble for every (seed, model).

    The dataset name is rebuilt as ``{prefix}_N{N}_p8_sigma{sigma:.2f}_seed{seed}``.
    This is the same naming used for the keys of ``all_fits`` and for the files
    in ``data_dir``.

    Parameters
    ----------
    seeds : iterable of int
        Dataset seeds, mapped to ``(N, sigma)`` by ``get_N_sigma``.
    models : iterable of str
        Model names looked up in ``all_fits[dataset_key]``.
    all_fits : dict
        ``all_fits[dataset_key][model]['posterior']`` must provide
        ``stan_variable("output_test")`` with posterior draws on the first axis.
    get_N_sigma : callable
        ``seed -> (N, sigma)``.
    data_dir : str, optional
        Directory holding the ``.npz`` dataset files (each must contain ``y_test``).
    prefix : str, optional
        Dataset-name prefix.

    Returns
    -------
    df_crps : pandas.DataFrame
        One row per test observation: seed, N, sigma, model, obs, crps.
    df_crps_summary : pandas.DataFrame
        One row per (seed, model): seed, N, sigma, model, mean_crps.

    Missing dataset files and missing models/variables are reported and skipped.
    """
    results = []
    summaries = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'{prefix}_N{N}_p8_sigma{sigma:.2f}_seed{seed}'
        path = f"{data_dir}/{dataset_key}.npz"

        try:
            # The context manager closes the NpzFile handle.
            with np.load(path) as data:
                y_test = data["y_test"].squeeze()  # (N_test,)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                output_test = fit.stan_variable("output_test").squeeze()  # (S, N_test)
            except (KeyError, ValueError):
                # KeyError: dataset/model/posterior missing from all_fits;
                # ValueError: cmdstanpy's response to an unknown Stan variable name.
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            # crps_ensemble expects ensemble members on the last axis, so draws go last.
            crps_vals = crps_ensemble(y_test, output_test.T)  # (N_test,)
            mean_crps = np.mean(crps_vals)

            summaries.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'mean_crps': mean_crps
            })

            # Per-observation CRPS, for distribution plots.
            results.extend(
                {'seed': seed, 'N': N, 'sigma': sigma, 'model': model,
                 'obs': i, 'crps': crps_vals[i]}
                for i in range(len(y_test))
            )

    df_crps = pd.DataFrame(results)
    df_crps_summary = pd.DataFrame(summaries)

    return df_crps, df_crps_summary


In [7]:
# Same seed set as the RMSE comparison; seed 19 (N = 1000) is excluded.
seeds = [1, 2, 3, 4, 7, 8, 9, 10]
df_crps, df_crps_summary = compute_crps_results(
    seeds=seeds, models=model_names, all_fits=all_fits, get_N_sigma=get_N_sigma
)

In [None]:
# Per-observation CRPS table: one row per (seed, model, test observation).
df_crps

In [None]:
# Per-observation CRPS distributions, one panel per training-set size N present in the results.
assert not df_crps.empty, "No CRPS results: check dataset paths / all_fits keys."
n_values = sorted(df_crps['N'].unique())
fig, axes = plt.subplots(
    1, len(n_values), figsize=(8 * len(n_values), 6), sharey=True, squeeze=False
)
axes = axes.ravel()

for ax, N_val in zip(axes, n_values):
    sns.boxplot(
        data=df_crps[df_crps['N'] == N_val],
        x="model", y="crps", hue="sigma", ax=ax
    )
    ax.set_title(f"CRPS Distribution (N = {N_val})")
    ax.set_xlabel("Model")
    ax.set_ylabel("CRPS")
    ax.grid(True)

# seaborn adds a hue legend to every panel: keep only the last one and retitle it.
for ax in axes[:-1]:
    if ax.get_legend() is not None:
        ax.get_legend().remove()
legend = axes[-1].get_legend()
if legend is not None:
    legend.set_title("Sigma")

plt.tight_layout()
# Saved under its own name so it does not overwrite the RMSE figure.
#plt.savefig(f"figures/GAM_{data_config}/crps_models.png", dpi=300, bbox_inches='tight')
plt.show()