In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Where the Friedman datasets and the saved convergence runs live.
data_dir = "datasets/friedman"
results_dir_relu = "results/regression/single_layer/relu/friedman/convergence"
results_dir_tanh = "results/regression/single_layer/tanh/friedman/convergence"

# Model labels handed to get_model_fits for each activation.
model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
model_names_tanh = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh"]

# {config name -> whatever get_model_fits returns}, one dict per activation.
relu_fits, tanh_fits = {}, {}

files = sorted(entry for entry in os.listdir(data_dir) if entry.endswith(".npz"))
for fname in files:
    # Config name is the dataset file name without ".npz",
    # e.g. "Friedman_N100_p10_sigma1.00_seed1".
    base_config_name = fname.replace(".npz", "")
    full_config_path = f"{base_config_name}"

    # Arguments shared by the ReLU and tanh lookups for this config.
    shared_kwargs = dict(config=full_config_path, include_prior=False)
    relu_fit = get_model_fits(
        results_dir=results_dir_relu,
        models=model_names_relu,
        **shared_kwargs,
    )
    tanh_fit = get_model_fits(
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        **shared_kwargs,
    )

    # Store both results under the clean config name.
    relu_fits[base_config_name], tanh_fits[base_config_name] = relu_fit, tanh_fit
    


In [18]:
import pandas as pd
import numpy as np
import arviz as az

# Seeds of the three dataset configurations handled by get_N_sigma.
seeds = [1, 2, 11]

def get_N_sigma(seed):
    """Map a dataset seed to its ``(N, sigma)`` pair.

    Seed 1 gives N=100, seed 2 gives N=200 and every other seed gives N=500;
    sigma is always 1.
    """
    N = 100 if seed == 1 else (200 if seed == 2 else 500)
    return N, 1

def _load_y_test(path):
    """Read the ``y_test`` array from an ``.npz`` archive and close the file."""
    with np.load(path) as archive:
        return archive["y_test"]


def _rmse_per_draw(y_pred, y_test):
    """Return the test RMSE of every posterior draw as a 1-D float array."""
    target = y_test.squeeze()
    return np.array(
        [np.sqrt(np.mean((draw.squeeze() - target) ** 2)) for draw in y_pred],
        dtype=float,
    )


def _N_sigma_from_config(config_name):
    """Return ``(N, sigma)`` for a name ending in ``_seed<int>``, else ``(nan, nan)``."""
    try:
        seed = int(config_name.split("_seed")[-1])
    except (AttributeError, ValueError):
        # Not a string, or no integer seed suffix to parse.
        return np.nan, np.nan
    return get_N_sigma(seed)


def get_all_convergence_diagnostics(all_fits, data_dir="datasets/friedman"):
    """Summarise sampler convergence and test error for every fitted model.

    Parameters
    ----------
    all_fits : dict
        ``{config_name: {model_name: fit}}`` where ``fit['posterior']`` is
        passed to ``az.from_cmdstanpy`` and queried with
        ``stan_variable('output_test')``.
    data_dir : str, optional
        Directory containing ``<config_name>.npz`` files with a ``y_test``
        entry. Defaults to the Friedman dataset directory.

    Returns
    -------
    (pandas.DataFrame, dict)
        One row per (config, model) with max/median R-hat of ``output``,
        proportion of divergent draws, mean per-draw test RMSE with and
        without divergent draws, median tail ESS divided by the number of
        draws, and ``N``/``sigma``. Rows for fits that failed hold NaN in the
        metric columns plus an ``error`` message, using the same column names
        as successful rows. The dict maps model name to the R-hat Series of
        ``output``; it is keyed by model name only, so a later config
        overwrites an earlier one for the same model.

    Notes
    -----
    A (config, model) pair whose dataset file is missing is reported with a
    ``[SKIP]`` message and produces no row.
    """
    diagnostics = []
    rhats = {}
    for config_name, model_fits in all_fits.items():
        # Resolved once per config so failed fits still carry N and sigma.
        N, sigma = _N_sigma_from_config(config_name)
        path = f"{data_dir}/{config_name}.npz"

        for model_name, fit in model_fits.items():
            try:
                posterior = fit['posterior']
                idata = az.from_cmdstanpy(posterior)
                y_pred = posterior.stan_variable('output_test')

                try:
                    y_test = _load_y_test(path)
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {path}")
                    continue

                # (n_chains, n_draws) flags flattened to one entry per draw.
                divergent_flat = idata.sample_stats["diverging"].values.flatten()
                divergences = np.sum(divergent_flat)
                S = y_pred.shape[0]

                # Per-draw RMSE is independent across draws, so the
                # divergence-free subset is just a mask of the full vector.
                rmses = _rmse_per_draw(y_pred, y_test)
                rmses_no_div = rmses[~divergent_flat]

                summary = az.summary(idata, var_names=["output"], round_to=3)
                rhat = summary["r_hat"]
                ess_tail = summary["ess_tail"]
                rhats[model_name] = rhat

                diagnostics.append({
                    "model": model_name,
                    "max_rhat": rhat.max(),
                    "median_rhat": rhat.median(),
                    "prop_divergent": divergences/S,
                    "rmse": np.mean(rmses, axis=0),
                    "rmse_no_div": np.mean(rmses_no_div, axis=0),
                    "median_ess_tail": ess_tail.median()/S,
                    "N": N,
                    "sigma": sigma,
                })

            except Exception as e:
                # Best effort: record the failure with the same columns as a
                # successful row so the resulting table stays rectangular.
                diagnostics.append({
                    "model": model_name,
                    "max_rhat": np.nan,
                    "median_rhat": np.nan,
                    "prop_divergent": np.nan,
                    "rmse": np.nan,
                    "rmse_no_div": np.nan,
                    "median_ess_tail": np.nan,
                    "N": N,
                    "sigma": sigma,
                    "error": str(e),
                })

    return pd.DataFrame(diagnostics), rhats


In [21]:
# Build, per activation, the diagnostics DataFrame and the {model name -> R-hat Series} dict.
# NOTE: the "diagostic" spelling is the name later cells read, so it is kept as is.
relu_diagostic, rhats_relu = get_all_convergence_diagnostics(relu_fits)
tanh_diagostic, rhats_tanh = get_all_convergence_diagnostics(tanh_fits)

In [5]:
def _order_by_model(diagnostics):
    """Return the rows sorted by model name, then by their original index value."""
    indexed = diagnostics.assign(row_index=lambda df: df.index)
    ordered = indexed.sort_values(["model", "row_index"])
    return ordered.drop(columns="row_index").reset_index(drop=True)


# Same ordering for both activations so the two tables line up row by row.
relu_grouped = _order_by_model(relu_diagostic)
tanh_grouped = _order_by_model(tanh_diagostic)


In [None]:
# Emit the ordered tables as LaTeX (ReLU first, then tanh), without the index column.
for grouped_table in (relu_grouped, tanh_grouped):
    print(grouped_table.to_latex(index=False))


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def _strip_tanh_suffix(frame):
    """Return a copy of ``frame`` whose model names have " tanh" removed."""
    cleaned = frame.copy()
    cleaned["model"] = cleaned["model"].str.replace(" tanh", "", regex=False)
    return cleaned


def _scatter_rhat_vs_rmse(frame, ax, title, ylabel, **scatter_kwargs):
    """Draw max R-hat against RMSE on ``ax`` and apply the shared axis styling."""
    drawn = sns.scatterplot(
        data=frame,
        x="max_rhat", y="rmse",
        hue="model",
        style="sigma",
        size="N", sizes=(100, 300),
        ax=ax,
        **scatter_kwargs,
    )
    ax.set_title(title)
    ax.set_xlabel(r"Max $\hat{R}$")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    return drawn


# Use the same model labels in both panels so the hue mapping matches.
relu_diagostic = _strip_tanh_suffix(relu_diagostic)
tanh_diagostic = _strip_tanh_suffix(tanh_diagostic)

# Side-by-side panels on shared axes.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True, sharey=True)

# ReLU panel carries no legend; the tanh panel keeps seaborn's default one
# so its entries can be harvested below.
_scatter_rhat_vs_rmse(relu_diagostic, axes[0], "ReLU", "RMSE", legend=False)
plot_obj = _scatter_rhat_vs_rmse(tanh_diagostic, axes[1], r"$\tanh$", "")

# Take the legend entries from the tanh axes, then drop the per-axes legend.
handles, labels = axes[1].get_legend_handles_labels()
axes[1].legend_.remove()

# Remove the section-header entries; the final remaining entry is dropped too.
exclude_labels = {"model", "N", "sigma"}
filtered = [entry for entry in zip(handles, labels) if entry[1] not in exclude_labels]
if filtered:
    kept_handles, kept_labels = zip(*filtered)
    fig.legend(
        kept_handles[:-1],
        kept_labels[:-1],
        title="",
        loc="upper right",
        bbox_to_anchor=(0.9, 0.8),
        ncol=2,
    )

# Final layout; leave head-room above the axes.
fig.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()


## CONVERGENCE vs ERROR

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Per-test-point R-hat against the error of the posterior-mean prediction for
# the ReLU fits, one panel per dataset size on shared axes.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True, sharey=True)

# (panel title, config name) for each panel, left to right.
panels = [
    ("N = 100", "Friedman_N100_p10_sigma1.00_seed1"),
    ("N = 200", "Friedman_N200_p10_sigma1.00_seed2"),
]

for ax, (panel_title, config_name) in zip(axes, panels):
    dataset_path = f"datasets/friedman/{config_name}.npz"
    # Close the archive after reading. The squeeze keeps an (n, 1) target
    # from broadcasting against the (n,) prediction into an (n, n) matrix.
    with np.load(dataset_path) as data:
        y_test = data["y_test"].squeeze()
    models_dict = relu_fits[config_name]

    for model_name, model_dict in models_dict.items():
        model = model_dict['posterior']
        idata = az.from_cmdstanpy(model)
        rhat = az.summary(idata, var_names=["output_test"], round_to=3)["r_hat"].values
        # Posterior mean over draws; squeeze(-1) assumes output_test has a
        # trailing singleton axis — TODO confirm against the Stan model.
        y_pred = np.mean(model.stan_variable("output_test"), axis=0).squeeze(-1)
        # RMSE of a single point is the absolute error.
        abs_error = np.abs(y_test - y_pred)
        ax.scatter(rhat, abs_error, label=model_name, alpha=0.7)

    ax.set_title(panel_title)
    ax.set_xlabel(r"$\hat{R}$")
    ax.grid(True)

# The y axis is shared, so only the left panel is labelled.
axes[0].set_ylabel("RMSE per test point")

# Shared legend
handles, labels = axes[1].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=len(labels))

plt.tight_layout()
plt.subplots_adjust(top=0.88)
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Per-test-point R-hat against the error of the posterior-mean prediction for
# the tanh fits, one panel per dataset size on shared axes.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True, sharey=True)

# (panel title, config name) for each panel, left to right.
panels = [
    ("N = 100", "Friedman_N100_p10_sigma1.00_seed1"),
    ("N = 200", "Friedman_N200_p10_sigma1.00_seed2"),
]

for ax, (panel_title, config_name) in zip(axes, panels):
    dataset_path = f"datasets/friedman/{config_name}.npz"
    # Close the archive after reading. The squeeze keeps an (n, 1) target
    # from broadcasting against the (n,) prediction into an (n, n) matrix.
    with np.load(dataset_path) as data:
        y_test = data["y_test"].squeeze()
    models_dict = tanh_fits[config_name]

    for model_name, model_dict in models_dict.items():
        model = model_dict['posterior']
        idata = az.from_cmdstanpy(model)
        rhat = az.summary(idata, var_names=["output_test"], round_to=3)["r_hat"].values
        # Posterior mean over draws; squeeze(-1) assumes output_test has a
        # trailing singleton axis — TODO confirm against the Stan model.
        y_pred = np.mean(model.stan_variable("output_test"), axis=0).squeeze(-1)
        # RMSE of a single point is the absolute error.
        abs_error = np.abs(y_test - y_pred)
        ax.scatter(rhat, abs_error, label=model_name, alpha=0.7)

    ax.set_title(panel_title)
    ax.set_xlabel(r"$\hat{R}$")
    ax.grid(True)

# The y axis is shared, so only the left panel is labelled.
axes[0].set_ylabel("RMSE per test point")

# Shared legend
handles, labels = axes[1].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=len(labels))

plt.tight_layout()
plt.subplots_adjust(top=0.88)
plt.show()


## TRACEPLOTS

In [None]:
import arviz as az
# Trace plots for the "Gaussian tanh" fit of config Friedman_N100_p10_sigma1.00_seed1.
gauss_fit = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']['Gaussian tanh']['posterior']
idata = az.from_cmdstanpy(gauss_fit)
divergent = idata.sample_stats["diverging"].values  # shape (n_chains, n_draws)
# Total number of divergent transitions over all chains, then the flag array's shape.
print("Divergent gaussian transitions:", np.sum(divergent))
print(divergent.shape)
# Plot traces of output[i, 0] for the first-axis indices listed below.

indices = [1, 3, 4, 12]
az.plot_trace(
    idata,
    var_names=["output"],
    coords={"output_dim_0": indices, "output_dim_1": [0]}
)


In [None]:
# Trace plots for the "Regularized Horseshoe tanh" fit of config Friedman_N100_p10_sigma1.00_seed1.
rhs_fit = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']['Regularized Horseshoe tanh']['posterior']
idata = az.from_cmdstanpy(rhs_fit)
divergent = idata.sample_stats["diverging"].values  # shape (n_chains, n_draws)
# Total number of divergent transitions over all chains, then the flag array's shape.
print("Divergent RHS transitions:", np.sum(divergent))
print(divergent.shape)
# Plot traces of output[i, 0] for the first-axis indices listed below.
indices = [1, 3, 4, 12]
az.plot_trace(
    idata,
    var_names=["output"],
    coords={"output_dim_0": indices, "output_dim_1": [0]}
)

In [None]:
# Trace plots for the "Dirichlet Horseshoe tanh" fit of config Friedman_N100_p10_sigma1.00_seed1.
dhs_fit = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']['Dirichlet Horseshoe tanh']['posterior']
idata = az.from_cmdstanpy(dhs_fit)
divergent = idata.sample_stats["diverging"].values  # shape (n_chains, n_draws)
# Total number of divergent transitions over all chains, then the flag array's shape.
print("Divergent DHS transitions:", np.sum(divergent))
print(divergent.shape)
# Plot traces of output[i, 0] for the first-axis indices listed below.
indices = [1, 3, 4, 12]
az.plot_trace(
    idata,
    var_names=["output"],
    coords={"output_dim_0": indices, "output_dim_1": [0]}
)

In [None]:
# Trace plots for the "Dirichlet Student T tanh" fit of config Friedman_N100_p10_sigma1.00_seed1.
# Stored under its own name so the Dirichlet Horseshoe fit (`dhs_fit`) is not overwritten.
dst_fit = tanh_fits['Friedman_N100_p10_sigma1.00_seed1']['Dirichlet Student T tanh']['posterior']
idata = az.from_cmdstanpy(dst_fit)
divergent = idata.sample_stats["diverging"].values  # shape (n_chains, n_draws)
# Total number of divergent transitions over all chains, then the flag array's shape.
print("Divergent DST transitions:", np.sum(divergent))
print(divergent.shape)
# Plot traces of output[i, 0] for the first-axis indices listed below.
indices = [1, 3, 4, 12]
az.plot_trace(
    idata,
    var_names=["output"],
    coords={"output_dim_0": indices, "output_dim_1": [0]}
)

## RHAT PLOTS

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _plot_rhat_dict(axs, rhats_dict, row_title=None, bins=60, log_y=True):
    """Plot one histogram per model from a {model -> Series} dict on a row of axes."""
    for ax, (model, s) in zip(axs, rhats_dict.items()):
        x = np.asarray(s.values, dtype=float)
        x = x[np.isfinite(x)]
        # Clip extreme tails so the bulk is visible (optional)
        xmax = np.quantile(x, 0.999) if x.size else 1.1
        ax.hist(x, bins=bins, range=(1.0, xmax if xmax > 1.0 else 1.1))
        if log_y: ax.set_yscale("log")
        ax.set_xlabel(r"$\hat{R}$")
        ax.set_title(model)
        ax.grid(True, which="both", axis="y", alpha=0.3)
    axs[0].set_ylabel("count")
    if row_title:
        axs[0].annotate(row_title, xy=(0, 1.02), xycoords="axes fraction",
                        fontsize=12, fontweight="bold")

def plot_output_rhats_two_rows(rhats_relu, rhats_tanh, bins=60, log_y=True, figsize_per_plot=(3.6, 2.8)):
    """Draw a two-row panel of R-hat histograms: ReLU models on top, tanh below.

    Parameters
    ----------
    rhats_relu, rhats_tanh : dict
        ``{model name -> Series of R-hat values}`` for each activation.
    bins : int
        Number of histogram bins per panel.
    log_y : bool
        Use a logarithmic count axis when true.
    figsize_per_plot : tuple of float
        ``(width, height)`` of one panel; the figure is ``ncols`` panels wide
        and two panels tall.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding both rows. Axes are shared within each row, and
        axes without a model are switched off.
    """
    # At least one column so two empty dicts still give a (blank) figure
    # rather than an error from plt.subplots.
    ncols = max(len(rhats_relu), len(rhats_tanh), 1)
    # squeeze=False always yields a 2-D axes array, so a single column needs
    # no special casing.
    fig, axes = plt.subplots(
        2, ncols,
        figsize=(figsize_per_plot[0] * ncols, figsize_per_plot[1] * 2),
        sharey='row', sharex='row',
        squeeze=False,
    )

    rows = (("ReLU", axes[0], rhats_relu), ("tanh", axes[1], rhats_tanh))
    for row_title, row_axes, rhats in rows:
        n_used = len(rhats)
        if n_used:
            # Skipped for an empty row: there are no axes to label.
            _plot_rhat_dict(row_axes[:n_used], rhats, row_title=row_title, bins=bins, log_y=log_y)
        # Hide the axes this row does not need.
        for ax in row_axes[n_used:]:
            ax.axis("off")

    fig.tight_layout()
    return fig

# --- Use it ---
# Draws the two-row R-hat histogram panel; the returned Figure is this cell's last expression.
plot_output_rhats_two_rows(rhats_relu, rhats_tanh, bins=60, log_y=True)
