In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Experiment configuration: which dataset family to read and where the fitted
# models for the two runs ("relu" and "relu_large") are stored.
data_config = 1
data_dir = f"datasets/type_{data_config}"
results_dir_relu = "results_relu"
results_dir_large = "results_relu_large"
model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
model_names_large = ["Dirichlet Horseshoe"]

# Fits returned by get_model_fits, keyed by dataset name (file name without ".npz").
relu_fits = {}
large_fits = {}

# Only files ending in this exact name are kept, i.e. the single
# N=100, p=8, sigma=1.00, seed=1 dataset.
files = sorted(f for f in os.listdir(data_dir) if f.endswith("GAM_N100_p8_sigma1.00_seed1.npz"))
for fname in files:
    base_config_name = fname.replace(".npz", "")  # e.g., "GAM_N100_p8_sigma1.00_seed1"

    # ReLU run: the config is addressed as "type_<k>/<dataset name>",
    # e.g. "type_1/GAM_N100_p8_sigma1.00_seed1".
    full_config_path_relu = f"type_{data_config}/{base_config_name}"
    relu_fit = get_model_fits(
        config=full_config_path_relu,  # was the undefined name `full_config_path` (NameError)
        results_dir=results_dir_relu,
        models=model_names_relu,
        include_prior=False,
    )

    # Large run: the config is the bare dataset name.
    # NOTE(review): this omits the "type_<k>/" prefix used for the ReLU run --
    # confirm that it matches the layout of `results_relu_large`.
    full_config_path_large = f"{base_config_name}"
    large_fit = get_model_fits(
        config=full_config_path_large,
        results_dir=results_dir_large,
        models=model_names_large,
        include_prior=False,
    )

    relu_fits[base_config_name] = relu_fit  # use clean key
    large_fits[base_config_name] = large_fit  # use clean key
    


In [None]:
# For every model, collect the largest absolute entry of "W_1" at each index
# along the first axis of the returned array (presumably one entry per
# posterior draw -- confirm against the Stan output).
largest_weights = {}

seed1_relu_fits = relu_fits['GAM_N100_p8_sigma1.00_seed1']
w1_sources = [(name, seed1_relu_fits[name]) for name in model_names_relu]
# The separately stored "large" Dirichlet Horseshoe fit gets its own label.
w1_sources.append(
    ('Dirichlet Horseshoe large', large_fits['GAM_N100_p8_sigma1.00_seed1']['Dirichlet Horseshoe'])
)

for label, fit in w1_sources:
    w1_draws = fit['posterior'].stan_variable("W_1")
    # Reduce over the two trailing axes: one maximum per leading index.
    largest_weights[label] = list(np.abs(w1_draws).max(axis=(1, 2)))

# Overlay one histogram (with KDE) per model on a shared axis.
fig, ax = plt.subplots(figsize=(10, 6))
for label, maxima in largest_weights.items():
    sns.histplot(maxima, kde=True, label=label, bins=30, ax=ax)

ax.set_xlabel("Largest Weight Value")
ax.set_ylabel("Frequency")
ax.set_title("Histogram of Largest Weight Values Across Models")
ax.legend()
plt.show()

In [None]:
# Same summary as for "W_1", but for the "W_L" variable: the largest absolute
# entry at each index along the first axis of the returned array (presumably
# one entry per posterior draw -- confirm against the Stan output).
# Only the models in `model_names_relu` are included here.
largest_weights = {}

seed1_relu_fits = relu_fits['GAM_N100_p8_sigma1.00_seed1']
for label in model_names_relu:
    wl_draws = seed1_relu_fits[label]['posterior'].stan_variable("W_L")
    # Reduce over the two trailing axes: one maximum per leading index.
    largest_weights[label] = list(np.abs(wl_draws).max(axis=(1, 2)))

# Overlay one histogram (with KDE) per model on a shared axis.
fig, ax = plt.subplots(figsize=(10, 6))
for label, maxima in largest_weights.items():
    sns.histplot(maxima, kde=True, label=label, bins=30, ax=ax)

ax.set_xlabel("Largest Weight Value")
ax.set_ylabel("Frequency")
ax.set_title("Histogram of Largest Weight Values Across Models")
ax.legend()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import os

def forward_pass_relu(X, W1, b1, W2, b2):
    """Evaluate a one-hidden-layer ReLU network.

    Computes ``relu(X @ W1 + b1) @ W2 + b2``, with each bias reshaped to a
    single row so that it broadcasts across the rows of ``X``.

    Args:
        X: Input array; its rows are multiplied into ``W1``.
        W1: First-layer weight matrix.
        b1: First-layer bias array (must provide ``.reshape``).
        W2: Output-layer weight matrix.
        b2: Output-layer bias array (must provide ``.reshape``).

    Returns:
        The network output, one row per row of ``X``.
    """
    hidden_bias_row = b1.reshape(1, -1)
    output_bias_row = b2.reshape(1, -1)

    # Hidden layer: affine map followed by ReLU (element-wise max with 0).
    hidden = np.maximum(0, X @ W1 + hidden_bias_row)

    # Output layer: plain affine map, no activation.
    return hidden @ W2 + output_bias_row

# Load dataset: take the test input at index 2 as the evaluation point x*.
dataset_key = f'GAM_N100_p8_sigma{1:.2f}_seed{1}'
# Use the archive as a context manager so the underlying file is closed.
with np.load(f"datasets/type_{1}/{dataset_key}.npz") as dataset:
    x_star = dataset["X_test"][2].reshape(1, -1)

# Set up priors to compare
priors = ['Gaussian', 'Regularized Horseshoe', 'Dirichlet Horseshoe', 'Dirichlet Horseshoe large', 'Dirichlet Student T']
results = {}

for prior in priors:
    # The "large" variant lives in a separate fit dictionary under the plain
    # 'Dirichlet Horseshoe' key; every other prior is looked up by its own name.
    if prior == 'Dirichlet Horseshoe large':
        posterior = large_fits[dataset_key]['Dirichlet Horseshoe']['posterior']
    else:
        posterior = relu_fits[dataset_key][prior]['posterior']

    W1 = posterior.stan_variable("W_1")
    b1 = posterior.stan_variable("hidden_bias")
    W2 = posterior.stan_variable("W_L")
    b2 = posterior.stan_variable("output_bias")

    # One network evaluation at x* per index along the first axis of W1.
    S = W1.shape[0]
    preds = np.zeros(S)
    for s in range(S):
        f_xs = forward_pass_relu(x_star, W1[s], b1[s], W2[s], b2[s])
        preds[s] = f_xs.item()

    mu_hat, sigma_hat = np.mean(preds), np.std(preds)
    # fisher=False: Pearson kurtosis, for which a normal distribution gives 3.
    kurt = stats.kurtosis(preds, fisher=False)

    # Hill-type tail estimate from the k largest predictions (top 10%), using
    # the smallest of them as the threshold.
    y_sorted = np.sort(preds)
    k = int(S * 0.1)
    if k >= 2 and y_sorted[-k] > 0:
        tail = y_sorted[-k:]
        hill_est = (1 / (k - 1)) * np.sum(np.log(tail / tail.min()))
    else:
        # k == 0 would make y_sorted[-k:] select the whole array, k == 1 would
        # divide by zero, and a non-positive threshold makes the log-ratios
        # undefined or meaningless; report NaN instead.
        hill_est = np.nan

    results[prior] = {
        'preds': preds,
        'mu': mu_hat,
        'sigma': sigma_hat,
        'kurtosis': kurt,
        'hill': hill_est
    }

# Overlay the predictive distributions of all priors as step histograms.
plt.figure(figsize=(8,5))
for prior in priors:
    preds = results[prior]['preds']
    plt.hist(preds, bins=50, density=True, histtype='step', label=prior)
plt.title("Predictive Posterior Distributions at x*")
plt.xlabel('f(x*)')
plt.ylabel('Density')
plt.legend()
plt.tight_layout()
plt.show()



In [None]:
# ---- QQ-plot Grid across Inputs ----
# For each prior, draw one figure holding a grid of normal QQ-plots of the
# network output at several test inputs.
from matplotlib import gridspec

dataset_key = f'GAM_N100_p8_sigma{1:.2f}_seed{1}'
# Use the archive as a context manager so the underlying file is closed.
with np.load(f"datasets/type_{1}/{dataset_key}.npz") as dataset:
    X_test = dataset["X_test"]

# One panel per grid cell; the test points are spread evenly over X_test.
n_rows, n_cols = 4, 4
test_indices = np.linspace(0, len(X_test)-1, n_rows * n_cols, dtype=int)

for prior in priors:

    # The "large" variant lives in a separate fit dictionary under the plain
    # 'Dirichlet Horseshoe' key; every other prior is looked up by its own name.
    if prior == 'Dirichlet Horseshoe large':
        posterior = large_fits[dataset_key]['Dirichlet Horseshoe']['posterior']
    else:
        posterior = relu_fits[dataset_key][prior]['posterior']

    W1 = posterior.stan_variable("W_1")
    b1 = posterior.stan_variable("hidden_bias")
    W2 = posterior.stan_variable("W_L")
    b2 = posterior.stan_variable("output_bias")
    S = W1.shape[0]

    fig = plt.figure(figsize=(10, 10))
    # Report the number of panels actually drawn (the title used to say 9).
    fig.suptitle(f'{prior}: QQ-Plots Across {len(test_indices)} Test Inputs', fontsize=16)
    gs = gridspec.GridSpec(n_rows, n_cols)

    for i, idx in enumerate(test_indices):
        x_star = X_test[idx].reshape(1, -1)

        # One network evaluation at this input per index along W1's first axis.
        preds = np.zeros(S)
        for s in range(S):
            f_xs = forward_pass_relu(x_star, W1[s], b1[s], W2[s], b2[s])
            preds[s] = f_xs.item()

        ax = fig.add_subplot(gs[i])
        stats.probplot(preds, dist="norm", plot=ax)
        ax.set_title(f'x[{idx}]', fontsize=10)

    # Leave the top 4% of the figure free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
