In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Directory holding the Friedman datasets (one .npz file per configuration).
data_dir = "datasets/friedman"
#results_dir_relu = "results/regression/single_layer/relu/friedman/convergence"
results_dir_tanh = "results/regression/single_layer/tanh/friedman/convergence"
#model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
model_names_tanh = [
    "Gaussian tanh",
    "Regularized Horseshoe tanh",
    "Dirichlet Horseshoe tanh nodewise",
    # "Dirichlet Student T tanh",
]

#relu_fits = {}
# config name -> result of get_model_fits for that config; later cells index it
# as tanh_fits[config][model_name]['posterior'].
tanh_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Strip only the trailing extension. str.replace(".npz", "") would also
    # remove the substring anywhere else in the file name.
    base_config_name = os.path.splitext(fname)[0]  # e.g., "Friedman_N100_p10_sigma1.00_seed1"
    full_config_path = base_config_name
    # relu_fit = get_model_fits(
    #     config=full_config_path,
    #     results_dir=results_dir_relu,
    #     models=model_names_relu,
    #     include_prior=False,
    # )

    tanh_fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        include_prior=False,
    )

    #relu_fits[base_config_name] = relu_fit  # use clean key
    tanh_fits[base_config_name] = tanh_fit  # use clean key
    


In [3]:
import pandas as pd
import numpy as np
import arviz as az

# Seeds of the datasets analysed here; get_N_sigma maps seed 1 -> N=100,
# seed 2 -> N=200 and every other seed (here 11) -> N=500.
seeds = [1, 2, 11]

def get_N_sigma(seed):
    """Translate a dataset seed into its ``(N, sigma)`` pair.

    Seed 1 corresponds to N=100 and seed 2 to N=200; every other seed is
    treated as N=500. ``sigma`` is always 1.
    """
    N = 100 if seed == 1 else (200 if seed == 2 else 500)
    return N, 1

def _per_draw_rmse(y_pred, y_test):
    """Return the RMSE of every posterior draw against the test targets.

    Each draw (first axis of ``y_pred``) and ``y_test`` are flattened before
    being compared, so a trailing singleton output dimension on either side is
    harmless. Assumes one draw holds as many values as ``y_test``.
    """
    n_draws = y_pred.shape[0]
    if n_draws == 0:
        return np.zeros(0)
    residuals = y_pred.reshape(n_draws, -1) - np.asarray(y_test).reshape(1, -1)
    return np.sqrt(np.mean(residuals ** 2, axis=1))


def _mean_or_nan(values):
    """Mean of ``values``, or NaN when there is nothing to average."""
    return float(np.mean(values)) if values.size else np.nan


def _sample_size_from_config(config_name):
    """Return N for a config name ending in ``_seed<k>``, or NaN if no integer seed can be parsed."""
    try:
        seed = int(config_name.split("_seed")[-1])
    except (ValueError, AttributeError):
        return np.nan
    N, _sigma = get_N_sigma(seed)
    return N


def get_all_convergence_diagnostics(all_fits, data_dir="datasets/friedman", include_rmse=False):
    """Collect convergence diagnostics for every (config, model) fit.

    Parameters
    ----------
    all_fits : dict
        ``{config_name: {model_name: fit}}`` where ``fit['posterior']`` is
        passed to ``az.from_cmdstanpy`` and must expose
        ``stan_variable('output_test')``.
    data_dir : str, optional
        Directory containing ``<config_name>.npz`` with a ``"y_test"`` array.
    include_rmse : bool, optional
        When True, add the columns ``rmse`` (mean per-draw test RMSE) and
        ``rmse_no_div`` (same, restricted to non-divergent draws).

    Returns
    -------
    (pandas.DataFrame, dict)
        The DataFrame has one row per model fit with the columns ``model``,
        ``max_rhat``, ``median_rhat``, ``prop_divergent``,
        ``median_ess_tail``, ``median_ess_bulk`` (both ESS medians are divided
        by the number of draws) and ``N``. Rows whose diagnostics could not be
        computed keep the same columns filled with NaN and carry the message
        in an extra ``error`` column. The dict maps
        ``config_name -> {model_name: r_hat Series}`` for the ``output``
        variable; models that failed or were skipped are absent from it.

    Notes
    -----
    A model whose dataset file is missing is skipped with a ``[SKIP]`` message
    and produces no row. The dataset is opened even when ``include_rmse`` is
    False so that this skip happens in both modes.
    """
    diagnostics = []
    rhats_global = {}
    for config_name, model_fits in all_fits.items():
        rhats = {}
        for model_name, fit in model_fits.items():
            # Start from an all-NaN row so success and failure rows share one schema.
            row = {
                "model": model_name,
                "max_rhat": np.nan,
                "median_rhat": np.nan,
                "prop_divergent": np.nan,
            }
            if include_rmse:
                row.update({"rmse": np.nan, "rmse_no_div": np.nan})
            row.update({
                "median_ess_tail": np.nan,
                "median_ess_bulk": np.nan,
                "N": _sample_size_from_config(config_name),
            })

            try:
                posterior = fit['posterior']
                idata = az.from_cmdstanpy(posterior)
                y_pred = posterior.stan_variable('output_test')

                path = f'{data_dir}/{config_name}.npz'
                try:
                    # Context manager closes the underlying zip file handle.
                    with np.load(path) as data:
                        y_test = data["y_test"]
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {path}")
                    continue

                # Flatten the per-chain divergence flags to one flag per draw.
                divergent_flat = idata.sample_stats["diverging"].values.flatten().astype(bool)
                divergences = np.sum(divergent_flat)
                S = y_pred.shape[0]

                measured = {}
                if include_rmse:
                    measured["rmse"] = _mean_or_nan(_per_draw_rmse(y_pred, y_test))
                    measured["rmse_no_div"] = _mean_or_nan(
                        _per_draw_rmse(y_pred[~divergent_flat], y_test)
                    )

                summary = az.summary(idata, var_names=["output"], round_to=3)
                rhat = summary["r_hat"]
                rhats[model_name] = rhat
                ess_bulk = summary["ess_bulk"]
                ess_tail = summary["ess_tail"]

                measured.update({
                    "max_rhat": rhat.max(),
                    "median_rhat": rhat.median(),
                    "prop_divergent": divergences / S,
                    "median_ess_tail": ess_tail.median() / S,
                    "median_ess_bulk": ess_bulk.median() / S,
                })
                # Apply all values at once so a failure never leaves a half-filled row.
                row.update(measured)

            except Exception as e:
                # Best effort: keep going with the other fits and record why this one failed.
                row["error"] = str(e)

            diagnostics.append(row)
        rhats_global[config_name] = rhats

    return pd.DataFrame(diagnostics), rhats_global


In [4]:
#relu_diagostic, rhats_relu = get_all_convergence_diagnostics(relu_fits)
# tanh_diagostic: DataFrame of per-fit diagnostics; rhats_tanh: {config: {model: r_hat Series}}.
tanh_diagostic, rhats_tanh = get_all_convergence_diagnostics(tanh_fits)

In [5]:
# relu_grouped = (
#     relu_diagostic
#     .assign(row_index=relu_diagostic.index)
#     .sort_values(by=["model", "row_index"])
#     .drop(columns=["row_index"])
#     .reset_index(drop=True)
# )

# Group the rows by model while keeping, within each model, the order in which
# the rows were produced (the temporary row_index column is the tie-breaker).
tanh_grouped = (
    tanh_diagostic
    .assign(row_index=tanh_diagostic.index)
    .sort_values(by=["model", "row_index"])
    .drop(columns=["row_index"])
    .reset_index(drop=True)
)


In [None]:
#print(relu_grouped.to_latex(index=False))
# Render the grouped diagnostics, rounded to three decimals, as a LaTeX table.
tanh_latex = tanh_grouped.round(3).to_latex(index=False)
print(tanh_latex)

## CONVERGENCE vs ERROR

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# One panel per configuration: (panel title, config name).
panels = [
    ("N = 100", "Friedman_N100_p10_sigma1.00_seed1"),
    ("N = 200", "Friedman_N200_p10_sigma1.00_seed2"),
]

fig, axes = plt.subplots(1, len(panels), figsize=(14, 6), sharex=True, sharey=True)

for ax, (title, config_name) in zip(axes, panels):
    with np.load(f"datasets/friedman/{config_name}.npz") as data:
        # Flatten the targets: an (n, 1) array would otherwise broadcast
        # against the (n,) predictions into an (n, n) matrix.
        y_test = np.ravel(data["y_test"])

    for model_name, model_dict in tanh_fits[config_name].items():
        model = model_dict['posterior']
        idata = az.from_cmdstanpy(model)
        # One R-hat value per test-point prediction.
        rhat = az.summary(idata, var_names=["output_test"], round_to=3)["r_hat"].values
        # Posterior-mean prediction per test point.
        y_pred = np.mean(model.stan_variable("output_test"), axis=0).squeeze(-1)
        # The root of a single squared error is the absolute error.
        abs_error = np.abs(y_test - y_pred)
        ax.scatter(rhat, abs_error, label=model_name, alpha=0.7)

    ax.set_title(title)
    ax.set_xlabel(r"$\hat{R}$")
    ax.grid(True)

# The y axis is shared, so only the first panel is labelled.
axes[0].set_ylabel("RMSE per test point")

# Shared legend
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=len(labels))

plt.tight_layout()
plt.subplots_adjust(top=0.88)
plt.show()


## TRACEPLOTS

In [None]:
import arviz as az
config_fits = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']
gauss_fit = config_fits['Gaussian tanh']['posterior']
idata = az.from_cmdstanpy(gauss_fit)

# Divergence flags from the sampler statistics; expected layout (n_chains, n_draws).
divergent = idata.sample_stats["diverging"].values
print("Divergent gaussian transitions:", np.sum(divergent))
print(divergent.shape)

# Trace plots of output[i, 0] for the selected first-axis indices i.
indices = [1, 3, 4, 12]
trace_coords = {"output_dim_0": indices, "output_dim_1": [0]}
az.plot_trace(idata, var_names=["output"], coords=trace_coords)


In [None]:
config_fits = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']
rhs_fit = config_fits['Regularized Horseshoe tanh']['posterior']
idata = az.from_cmdstanpy(rhs_fit)

# Divergence flags from the sampler statistics; expected layout (n_chains, n_draws).
divergent = idata.sample_stats["diverging"].values
print("Divergent RHS transitions:", np.sum(divergent))
print(divergent.shape)

# Trace plots of output[i, 0] for the selected first-axis indices i.
indices = [1, 3, 4, 12]
trace_coords = {"output_dim_0": indices, "output_dim_1": [0]}
az.plot_trace(idata, var_names=["output"], coords=trace_coords)

In [None]:
config_fits = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']
dhs_fit = config_fits['Dirichlet Horseshoe tanh nodewise']['posterior']
idata = az.from_cmdstanpy(dhs_fit)

# Divergence flags from the sampler statistics; expected layout (n_chains, n_draws).
divergent = idata.sample_stats["diverging"].values
print("Divergent DHS transitions:", np.sum(divergent))
print(divergent.shape)

# Trace plots of output[i, 0] for the selected first-axis indices i.
indices = [1, 3, 4, 12]
trace_coords = {"output_dim_0": indices, "output_dim_1": [0]}
az.plot_trace(idata, var_names=["output"], coords=trace_coords)

In [None]:
# Dirichlet Student T (DST) trace plots. This model is not in the list of
# models loaded at the top of the notebook, so look it up defensively instead
# of failing the whole run with a KeyError.
# NOTE(review): elsewhere the model is spelled "Dirichlet Student T tanh"
# (without "nodewise") - confirm which key get_model_fits actually produces.
dst_model_name = 'Dirichlet Student T tanh nodewise'
config_fits = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']

if dst_model_name not in config_fits:
    print(f"[SKIP] Model not loaded: {dst_model_name}")
else:
    # Use a DST-specific name so the DHS fit from the previous cell is not overwritten.
    dst_fit = config_fits[dst_model_name]['posterior']
    idata = az.from_cmdstanpy(dst_fit)
    divergent = idata.sample_stats["diverging"].values  # expected layout (n_chains, n_draws)
    print("Divergent DST transitions:", np.sum(divergent))
    print(divergent.shape)
    # Trace plots of output[i, 0] for the selected first-axis indices i.
    indices = [1, 3, 4, 12]
    az.plot_trace(
        idata,
        var_names=["output"],
        coords={"output_dim_0": indices, "output_dim_1": [0]}
    )

## RHAT PLOTS

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Short display labels for the tanh model names. The plotting code looks names
# up with abbr.get(model, model), so a model missing here keeps its full name.
abbr = {
    "Gaussian tanh": "Gauss",
    "Regularized Horseshoe tanh": "RHS",
    "Dirichlet Horseshoe tanh nodewise": "DHS",
    "Dirichlet Student T tanh": "DST",
}

def plot_output_rhats_tanh_overlay(rhats_tanh, bins=60, log_y=True):
    """Overlay one outlined R-hat histogram per model in a single figure.

    Parameters
    ----------
    rhats_tanh : dict
        ``{model_name: Series of R-hat values}`` for ONE configuration (not
        the config-nested dict returned by get_all_convergence_diagnostics).
    bins : int
        Number of histogram bins.
    log_y : bool
        Use a logarithmic count axis.

    All models share the x-range ``(1.0, q)``, where ``q`` is the 99.9%
    quantile of the pooled finite values (1.1 if that is not above 1.0 or
    there is no data). Values outside the range are not drawn. The figure is
    shown with ``plt.show()`` and nothing is returned.
    """
    # Clean every series once: float dtype, finite values only.
    cleaned = {}
    for model, s in rhats_tanh.items():
        x = np.asarray(s.values, dtype=float)
        cleaned[model] = x[np.isfinite(x)]

    # Choose a common x-range so models are comparable. np.concatenate cannot
    # take an empty list, so an empty input falls back to the default range.
    pooled = np.concatenate(list(cleaned.values())) if cleaned else np.array([])
    xmax = np.quantile(pooled, 0.999) if pooled.size else 1.1
    xrng = (1.0, xmax if xmax > 1.0 else 1.1)

    plt.figure(figsize=(6, 4))
    for model, x in cleaned.items():
        plt.hist(
            x,
            bins=bins,
            range=xrng,
            histtype="step",   # outlines instead of filled bars
            linewidth=2,
            label=abbr.get(model, model),
        )

    if log_y:
        plt.yscale("log")

    plt.title(r"$\tanh$")
    plt.xlabel(r"$\hat{R}$", fontsize=15)
    plt.ylabel("Frequency", fontsize=15)
    plt.grid(True, which="both", axis="y", alpha=0.3)
    if cleaned:
        # Without any histogram there is nothing to put in a legend.
        plt.legend()
    plt.tight_layout()
    plt.show()


# rhats_tanh is nested as {config: {model: Series}}; the overlay plot expects
# the inner {model: Series} dict, so select one configuration.
plot_output_rhats_tanh_overlay(rhats_tanh['Friedman_N100_p10_sigma1.00_seed1'], bins=60, log_y=True)

In [7]:
import numpy as np
import matplotlib.pyplot as plt

# Short display labels for the tanh model names (same mapping as in the
# overlay cell); unknown names fall back to the full name via abbr.get(model, model).
abbr = {
    "Gaussian tanh": "Gauss",
    "Regularized Horseshoe tanh": "RHS",
    "Dirichlet Horseshoe tanh nodewise": "DHS",
    "Dirichlet Student T tanh": "DST",
}

def plot_output_rhats_tanh(rhats_tanh, bins=20, log_y=True, figsize_per_plot=(3.6, 2.8)):
    """Draw a one-row figure with a separate R-hat histogram per tanh model.

    ``rhats_tanh`` maps model name -> Series of R-hat values. Each panel uses
    its own x-range ``(1.0, q)``, with ``q`` the 99.9% quantile of that
    model's finite values (1.1 when that is not above 1.0 or the model has no
    finite values). The count axis is shared. Returns the Figure.
    """
    n_panels = len(rhats_tanh)
    total_size = (figsize_per_plot[0] * n_panels, figsize_per_plot[1])
    fig, axes = plt.subplots(1, n_panels, figsize=total_size, sharey=True)

    # A single panel comes back as a bare Axes; wrap it so indexing works.
    if n_panels == 1:
        axes = [axes]

    for panel_idx, (model, s) in enumerate(rhats_tanh.items()):
        ax = axes[panel_idx]

        values = np.asarray(s.values, dtype=float)
        values = values[np.isfinite(values)]

        upper = 1.1
        if values.size:
            upper = np.quantile(values, 0.999)
        if not upper > 1.0:
            upper = 1.1

        ax.hist(values, bins=bins, range=(1.0, upper))
        if log_y:
            ax.set_yscale("log")
        ax.set_xlabel(r"$\hat{R}$", fontsize=15)
        ax.set_title(abbr.get(model, model), fontsize=15)
        ax.tick_params(axis='both', labelsize=15)
        ax.grid(True, which="both", axis="y", alpha=0.3)

    axes[0].set_ylabel("Frequency", fontsize=15)
    fig.tight_layout()
    return fig


In [None]:
# Per-model R-hat histograms for the N100 / seed1 configuration.
p = plot_output_rhats_tanh(rhats_tanh['Friedman_N100_p10_sigma1.00_seed1'], bins=60, log_y=True)

In [None]:
# Per-model R-hat histograms for the N200 / seed2 configuration.
p = plot_output_rhats_tanh(rhats_tanh['Friedman_N200_p10_sigma1.00_seed2'], bins=60, log_y=True)

In [None]:
# Per-model R-hat histograms for the N500 / seed11 configuration.
p = plot_output_rhats_tanh(rhats_tanh['Friedman_N500_p10_sigma1.00_seed11'], bins=60, log_y=True)

In [19]:
import numpy as np
import matplotlib.pyplot as plt

# Short display labels for the tanh model names, used for the panel titles;
# unknown names fall back to the full name via abbr.get(model, model).
abbr = {
    "Gaussian tanh": "Gauss",
    "Regularized Horseshoe tanh": "RHS",
    "Dirichlet Horseshoe tanh nodewise": "DHS",
    "Dirichlet Student T tanh": "DST",
}

def plot_output_rhats_tanh_overlay_by_N(
    rhats_by_N,          # dict: { "N=100": rhats_dict, "N=200": rhats_dict, ... }
    bins=60,
    log_y=True,
    figsize_per_plot=(3.6, 2.8),
    clip_q=0.999,
):
    """
    One row with one panel per model. In each panel, overlay histograms for each N (different colors).

    Parameters
    ----------
    rhats_by_N : dict
        ``{label: {model_name: Series of R-hat values}}``, one entry per N.
    bins : int
        Number of histogram bins.
    log_y : bool
        Use a logarithmic count axis.
    figsize_per_plot : tuple
        (width, height) of a single panel.
    clip_q : float
        Quantile of the pooled values used as the upper end of each panel's x-range.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``rhats_by_N`` is empty.

    Notes
    -----
    A model missing from some entries (e.g. because its diagnostics failed for
    that N) is simply not drawn for that N. Each N keeps the same colour in
    every panel regardless of which entries are missing.
    """
    if not rhats_by_N:
        raise ValueError("rhats_by_N must contain at least one entry")

    # Panel order: models in first-seen order across all entries, which equals
    # the order of the first entry when every entry has the same models.
    models = []
    for rhats_dict in rhats_by_N.values():
        for model in rhats_dict:
            if model not in models:
                models.append(model)
    ncols = len(models)

    fig, axes = plt.subplots(
        1, ncols,
        figsize=(figsize_per_plot[0] * ncols, figsize_per_plot[1]),
        sharey=True
    )
    if ncols == 1:
        axes = [axes]

    # Fixed colour per N label (default colour cycle order).
    colors = {label: f"C{i}" for i, label in enumerate(rhats_by_N)}

    for ax, model in zip(axes, models):
        # Clean each series once: float dtype, finite values only.
        samples = {}
        for label, rhats_dict in rhats_by_N.items():
            s = rhats_dict.get(model)
            if s is None:
                continue
            x = np.asarray(s.values, dtype=float)
            samples[label] = x[np.isfinite(x)]

        # Use a common x-range per model across N so overlays are comparable.
        non_empty = [x for x in samples.values() if x.size]
        all_x = np.concatenate(non_empty) if non_empty else np.array([1.0])
        xmax = np.quantile(all_x, clip_q)
        xrng = (1.0, xmax if xmax > 1.0 else 1.1)

        # overlay one histogram per N
        for label, x in samples.items():
            ax.hist(
                x,
                bins=bins,
                range=xrng,
                histtype="step",   # outline so overlays remain readable
                linewidth=3,
                color=colors[label],
                label=label
            )

        if log_y:
            ax.set_yscale("log")

        ax.set_xlabel(r"$\hat{R}$", fontsize=15)
        ax.set_title(abbr.get(model, model), fontsize=15)
        ax.tick_params(axis="both", labelsize=12)
        ax.grid(True, which="both", axis="y", alpha=0.3)

    axes[0].set_ylabel("Frequency", fontsize=15)

    # One shared legend for the whole figure. Gather entries from every panel
    # so an N missing from the first panel still appears.
    legend_entries = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    fig.legend(
        list(legend_entries.values()),
        list(legend_entries.keys()),
        loc="upper right",
        frameon=True,
    )

    fig.tight_layout()
    return fig


In [None]:
# R-hat Series per model, grouped by training-set size.
rhats_by_N = {
    "N=100": rhats_tanh["Friedman_N100_p10_sigma1.00_seed1"],
    "N=200": rhats_tanh["Friedman_N200_p10_sigma1.00_seed2"],
    # The N=500 configuration uses seed 11 (see `seeds` and the per-config
    # histogram cell); "..._seed3" is not among the seeds used in this notebook.
    "N=500": rhats_tanh["Friedman_N500_p10_sigma1.00_seed11"],
}

fig = plot_output_rhats_tanh_overlay_by_N(rhats_by_N, bins=15, log_y=True)
plt.show()
