In [12]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt
#os.chdir('..')

In [None]:
# --- Run configuration -------------------------------------------------------
data_config = 1                                # dataset family, selects datasets/type_<n>
data_dir = f"datasets/type_{data_config}"      # directory holding the .npz datasets
results_dir = "results_relu_exhaustive/slow"   # where the fitted models were written
model_names = ["Dirichlet Horseshoe"]          # models requested from get_model_fits

# Maps dataset name (file name without extension) -> value returned by get_model_fits.
all_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Drop only the trailing ".npz"; str.replace would also remove any ".npz"
    # occurring elsewhere in the file name.
    base_config_name = fname[:-len(".npz")]  # e.g., "GAM_N100_p8_sigma1.00_seed1"
    full_config_path = f"type_{data_config}/{base_config_name}"  # e.g., "type_1/GAM_N100_p8_sigma1.00_seed1"
    print(full_config_path)
    print(base_config_name)
    fits = get_model_fits(
        config=full_config_path,
        results_dir=results_dir,
        models=model_names,
        include_prior=False,
    )

    all_fits[base_config_name] = fits  # use clean key



In [6]:
def forward_pass(X, W1, b1, W2, b2):
    """
    Evaluate a one-hidden-layer ReLU network.

    Computes ``relu(X @ W1 + b1) @ W2 + b2``; both bias arguments are
    reshaped to a single row so they broadcast over the rows of ``X``.
    """
    hidden_bias_row = b1.reshape(1, -1)
    output_bias_row = b2.reshape(1, -1)
    hidden = np.maximum(0, X @ W1 + hidden_bias_row)  # ReLU activation
    return hidden @ W2 + output_bias_row

def compute_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass,
                         sparsity=0.0, prune_fn=None, data_dir=None):
    """
    Compute test-set RMSE per posterior draw and for the posterior-mean prediction.

    Args:
        seeds: Iterable of dataset seeds to evaluate.
        models: Iterable of model names used as keys into ``all_fits[dataset_key]``.
        all_fits: Mapping ``dataset_key -> model -> {'posterior': fit}`` where
            ``fit.stan_variable(name)`` returns the posterior draws of ``name``.
        get_N_sigma: Callable ``seed -> (N, sigma)`` used to build the dataset key.
        forward_pass: Callable ``(X, W1, b1, W2, b2) -> predictions``.
        sparsity: Fraction passed to ``prune_fn``; pruning is applied only when > 0.
        prune_fn: Optional callable ``([W1, W2], sparsity) -> [mask1, mask2]``.
        data_dir: Directory containing ``<dataset_key>.npz``. Defaults to
            ``datasets/type_{data_config}`` using the notebook-level ``data_config``.

    Returns:
        (df_rmse, df_posterior_rmse): one row per posterior draw with column
        ``rmse``, and one row per (seed, model) with ``posterior_mean_rmse``.

    Datasets whose file is missing and models missing from ``all_fits`` are
    skipped with a printed ``[SKIP]`` message.
    """
    if data_dir is None:
        # Backward-compatible default: rely on the notebook-level configuration.
        data_dir = f"datasets/type_{data_config}"

    results = []
    posterior_means = []
    apply_pruning = prune_fn is not None and sparsity > 0.0

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'GAM_N{N}_p8_sigma{sigma:.2f}_seed{seed}'
        path = f"{data_dir}/{dataset_key}.npz"

        try:
            # Context manager closes the underlying .npz file handle.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten once so (n,) and (n, 1) targets behave identically;
                # subtracting an (n, 1) target from (n,) predictions would
                # silently broadcast to an (n, n) matrix.
                y_true = np.ravel(data["y_test"])
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                # Shape annotations are the original author's; not verified here.
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            y_hats = np.zeros((S, y_true.size))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Apply pruning mask if requested
                if apply_pruning:
                    masks = prune_fn([W1, W2], sparsity)
                    W1 = W1 * masks[0]
                    W2 = W2 * masks[1]

                # Only the first row of the hidden-bias draw is used.
                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                y_hats[i] = np.ravel(y_hat)  # Store the prediction for each sample

            # RMSE of every individual draw, and of the draw-averaged prediction.
            rmses = np.sqrt(np.mean((y_hats - y_true) ** 2, axis=1))
            posterior_mean = np.mean(y_hats, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'rmse': rmses[i]
                }
                for i in range(S)
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [7]:
# Fractions of weights removed in each pruning experiment.
sparsity_levels = [
    0.0, 0.1, 0.2, 0.3, 0.4, 0.5,
    0.6, 0.7, 0.8, 0.9, 0.95,
]

# Dataset seeds evaluated; get_N_sigma maps each one to its (N, sigma) setting.
seeds = [1, 2, 3, 4, 7, 8, 9, 10]

def get_N_sigma(seed):
    """
    Map a dataset seed to its ``(N, sigma)`` setting.

    Seeds 1-4 give N = 100, every other seed gives N = 200.
    Seeds 1, 2, 7 and 8 give sigma = 1.0, every other seed gives sigma = 3.0.
    """
    if seed in (1, 2, 3, 4):
        N = 100
    else:
        N = 200

    if seed in (1, 2, 7, 8):
        sigma = 1.0
    else:
        sigma = 3.0

    return N, sigma


def global_prune_weights(weight_matrices, sparsity_level):
    """
    Prune globally across multiple weight matrices.

    The ``floor(sparsity_level * total)`` entries with the smallest absolute
    value, pooled over all matrices, are masked out.

    Args:
        weight_matrices: List of numpy arrays (weight matrices).
        sparsity_level: Float in [0, 1], fraction of weights to prune.

    Returns:
        List of float masks (1.0 = keep, 0.0 = pruned) with the same shapes as
        weight_matrices. An empty input list yields an empty list.

    Raises:
        ValueError: If sparsity_level is outside [0, 1] (including NaN).
    """
    if not 0.0 <= sparsity_level <= 1.0:
        raise ValueError(f"sparsity_level must be in [0, 1], got {sparsity_level}")
    if len(weight_matrices) == 0:
        return []

    # Flatten all weights and concatenate
    abs_weights = np.abs(np.concatenate([w.flatten() for w in weight_matrices]))

    # Determine number of weights to prune
    total_weights = abs_weights.size
    num_to_prune = int(np.floor(sparsity_level * total_weights))

    # Create global mask
    global_mask_flat = np.ones(total_weights, dtype=bool)
    if num_to_prune >= total_weights:
        # argpartition requires kth < size, so "prune everything" is handled directly.
        global_mask_flat[:] = False
    elif num_to_prune > 0:
        # Indices of the num_to_prune smallest magnitudes (order among them is irrelevant).
        prune_indices = np.argpartition(abs_weights, num_to_prune)[:num_to_prune]
        global_mask_flat[prune_indices] = False

    # Split the global mask back into original shapes
    masks = []
    idx = 0
    for w in weight_matrices:
        size = w.size
        mask = global_mask_flat[idx:idx + size].reshape(w.shape)
        masks.append(mask.astype(float))
        idx += size

    return masks

def local_prune_weights(weights, sparsity_level, index_to_prune=0):
    """
    Apply pruning to only one weight matrix in a list, specified by index.

    The ``floor(sparsity_level * size)`` smallest-magnitude entries of the
    selected matrix are masked out; every other matrix gets an all-ones mask.

    Parameters:
    - weights: list of np.ndarray (e.g., [W1, W2])
    - sparsity_level: fraction of weights to prune (0.0 to 1.0)
    - index_to_prune: which weight matrix to prune in the list

    Returns:
    - list of masks (one for each weight matrix)

    Raises:
    - ValueError: if sparsity_level is outside [0, 1] (including NaN)
    """
    if not 0.0 <= sparsity_level <= 1.0:
        raise ValueError(f"sparsity_level must be in [0, 1], got {sparsity_level}")

    masks = [np.ones_like(W) for W in weights]

    W = weights[index_to_prune]
    flat = np.abs(W.flatten())
    num_to_prune = int(np.floor(sparsity_level * flat.size))

    if num_to_prune > 0:
        mask_flat = np.ones_like(flat, dtype=bool)
        if num_to_prune >= flat.size:
            # argpartition requires kth < size, so "prune everything" is handled directly.
            mask_flat[:] = False
        else:
            idx = np.argpartition(flat, num_to_prune)[:num_to_prune]
            mask_flat[idx] = False
        masks[index_to_prune] = mask_flat.reshape(W.shape).astype(float)

    return masks



# Results of the pruning sweep, keyed by sparsity level.
df_rmse_global, df_posterior_rmse_global = {}, {}
df_rmse_local, df_posterior_rmse_local = {}, {}

# (pruning function, per-draw store, posterior-mean store); global runs before local.
_pruning_runs = (
    (global_prune_weights, df_rmse_global, df_posterior_rmse_global),
    (local_prune_weights, df_rmse_local, df_posterior_rmse_local),
)

for sparsity in sparsity_levels:
    for _prune_fn, _per_draw_store, _posterior_store in _pruning_runs:
        _per_draw_store[sparsity], _posterior_store[sparsity] = compute_rmse_results(
            seeds, model_names, all_fits, get_N_sigma, forward_pass,
            sparsity=sparsity, prune_fn=_prune_fn
        )


In [8]:
import pandas as pd

# Stack the per-sparsity frames from local pruning into one long table,
# (re)writing the `sparsity` column from the dictionary key.
_local_frames = []
for _level, _frame in df_rmse_local.items():
    _local_frames.append(_frame.assign(sparsity=_level))

df_rmse_full_local = pd.concat(_local_frames, ignore_index=True)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

# One panel per noise level. Within a panel, colour distinguishes the model and
# line style distinguishes the training-set size N.
for ax, sigma in zip(axes, [1.0, 3.0]):
    sns.lineplot(
        data=df_rmse_full_local[df_rmse_full_local['sigma'] == sigma],
        x='sparsity', y='rmse',
        hue='model', style='N', marker='o', errorbar=None, ax=ax
    )

    ax.set_title(f"sigma = {sigma}")
    ax.set_xlabel("Sparsity Level")
    ax.set_ylabel("RMSE")
    ax.grid(True)

# The legend entries are models and N values (sigma is encoded by the panel),
# so title it accordingly.
axes[1].legend(title="Model / N")
plt.tight_layout()
#plt.savefig(f"figures/GAM_{data_config}/prune_local_sigma.png", dpi=300, bbox_inches='tight')
plt.show()

In [10]:
import pandas as pd

# Stack the per-sparsity frames from global pruning into one long table,
# (re)writing the `sparsity` column from the dictionary key.
_global_frames = []
for _level, _frame in df_rmse_global.items():
    _global_frames.append(_frame.assign(sparsity=_level))

df_rmse_full_global = pd.concat(_global_frames, ignore_index=True)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

# One panel per noise level. Within a panel, colour distinguishes the model and
# line style distinguishes the training-set size N.
for ax, sigma in zip(axes, [1.0, 3.0]):
    sns.lineplot(
        data=df_rmse_full_global[df_rmse_full_global['sigma'] == sigma],
        x='sparsity', y='rmse',
        hue='model', style='N', marker='o', errorbar=None, ax=ax
    )

    ax.set_title(f"sigma = {sigma}")
    ax.set_xlabel("Sparsity Level")
    ax.set_ylabel("RMSE")
    ax.grid(True)

# The legend entries are models and N values (sigma is encoded by the panel),
# so title it accordingly.
axes[1].legend(title="Model / N")
plt.tight_layout()
#plt.savefig(f"figures/GAM_{data_config}/prune_global_sigma.png", dpi=300, bbox_inches='tight')
plt.show()

## Analyze

In [95]:
# Posterior fits of the three models on the seed-1 dataset.
# NOTE(review): the loading cell requests only model_names = ["Dirichlet Horseshoe"];
# confirm that all_fits also holds 'Gaussian' and 'Regularized Horseshoe',
# otherwise these lookups raise KeyError.
_seed1_fits = all_fits['GAM_N100_p8_sigma1.00_seed1']
fit_gauss = _seed1_fits['Gaussian']['posterior']
fit_rhs = _seed1_fits['Regularized Horseshoe']['posterior']
fit_dhs = _seed1_fits['Dirichlet Horseshoe']['posterior']

_analysis_fits = (fit_gauss, fit_rhs, fit_dhs)

# Posterior draws per parameter, in the order (Gaussian, RHS, DHS).
# Shape annotations are the original author's and are not verified here.
W1_samples_gauss, W1_samples_rhs, W1_samples_dhs = (
    _fit.stan_variable("W_1") for _fit in _analysis_fits            # (S, P, H)
)
W2_samples_gauss, W2_samples_rhs, W2_samples_dhs = (
    _fit.stan_variable("W_L") for _fit in _analysis_fits            # (S, H, O)
)
b1_samples_gauss, b1_samples_rhs, b1_samples_dhs = (
    _fit.stan_variable("hidden_bias") for _fit in _analysis_fits    # (S, O, H)
)
b2_samples_gauss, b2_samples_rhs, b2_samples_dhs = (
    _fit.stan_variable("output_bias") for _fit in _analysis_fits    # (S, O)
)

# Matching test split for the same dataset.
dataset_key = f'GAM_N{100}_p8_sigma{1:.2f}_seed{1}'
path = f"datasets/type_{data_config}/{dataset_key}.npz"

data = np.load(path)
X_test = data["X_test"]
y_test = data["y_test"]

In [None]:
# The einsum sums X_test over observations (index n), so after dividing by the
# number of rows each entry [s, d, h] is (mean of feature d) * W_1[s, d, h].
_n_test = len(X_test)
mean_contribs_all_gauss = np.einsum('nd,sdh->sdh', X_test, W1_samples_gauss) / _n_test
mean_contribs_all_rhs = np.einsum('nd,sdh->sdh', X_test, W1_samples_rhs) / _n_test
mean_contribs_all_dhs = np.einsum('nd,sdh->sdh', X_test, W1_samples_dhs) / _n_test

H = W1_samples_gauss.shape[2]  # number of hidden units

# Average over posterior samples (axis 0).
mean_contribs_gauss = mean_contribs_all_gauss.mean(axis=0)
mean_contribs_rhs = mean_contribs_all_rhs.mean(axis=0)
mean_contribs_dhs = mean_contribs_all_dhs.mean(axis=0)

_hidden_labels = [f"H{h}" for h in range(H)]
_input_labels = [f"X{d}" for d in range(X_test.shape[1])]

fig, ax = plt.subplots(1, 3, figsize=(18, 5))
_panels = zip(
    ax,
    [mean_contribs_gauss, mean_contribs_rhs, mean_contribs_dhs],
    ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe"],
)
for _panel, mean_contribs, model_name in _panels:
    # Shared colour limits keep the three models directly comparable.
    sns.heatmap(
        mean_contribs,
        annot=False,
        fmt=".2f",
        cmap="coolwarm",
        xticklabels=_hidden_labels,
        yticklabels=_input_labels,
        ax=_panel,
        vmin=-0.5,
        vmax=0.5,
    )
    _panel.set_title(f"Average Input-to-Hidden Weights ({model_name})")
    _panel.set_xlabel("Hidden Unit")
    _panel.set_ylabel("Input Feature")
plt.tight_layout()
plt.show()



In [None]:

# Flatten each draw of W_L to one row so statistics are taken across draws.
_W2_flat_gauss = W2_samples_gauss.reshape(W2_samples_gauss.shape[0], -1)
_W2_flat_rhs = W2_samples_rhs.reshape(W2_samples_rhs.shape[0], -1)
_W2_flat_dhs = W2_samples_dhs.reshape(W2_samples_dhs.shape[0], -1)

# Posterior mean of every hidden-to-output weight.
mean_contribs_W2_gauss = _W2_flat_gauss.mean(axis=0)
mean_contribs_W2_rhs = _W2_flat_rhs.mean(axis=0)
mean_contribs_W2_dhs = _W2_flat_dhs.mean(axis=0)

# Posterior standard deviation of every hidden-to-output weight.
std_contribs_W2_gauss = _W2_flat_gauss.std(axis=0)
std_contribs_W2_rhs = _W2_flat_rhs.std(axis=0)
std_contribs_W2_dhs = _W2_flat_dhs.std(axis=0)

H = mean_contribs_W2_gauss.shape[0]
x = np.arange(H)
_unit_labels = [f"H{i}" for i in x]

fig, ax = plt.subplots(1, 3, figsize=(18, 4), sharey=True)
_panels = zip(
    ax,
    [mean_contribs_W2_gauss, mean_contribs_W2_rhs, mean_contribs_W2_dhs],
    [std_contribs_W2_gauss, std_contribs_W2_rhs, std_contribs_W2_dhs],
    ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe"],
)
for _panel, mean, std, title in _panels:
    # Bars show the posterior mean; error bars show +/- one posterior SD.
    _panel.bar(x, mean, yerr=std, capsize=3)
    _panel.set_title(f"Hidden-to-Output Weights ({title})")
    _panel.set_xlabel("Hidden Unit")
    _panel.set_xticks(x)
    _panel.set_xticklabels(_unit_labels)
ax[0].set_ylabel("Weight Strength")
plt.tight_layout()
plt.show()


In [None]:
# Config
eps = 1e-4
n_obs = X_test.shape[0]
n_features = X_test.shape[1]
n_samples = W1_samples.shape[0]

# Allocate sensitivity array: shape (n_samples, n_obs, n_features)
sensitivities_all = np.zeros((n_samples, n_obs, n_features))

# Loop over posterior samples and test points
for s in range(n_samples):
    W1 = W1_samples[s]
    W2 = W2_samples[s]
    b1 = b1_samples[s]
    b2 = b2_samples[s]

    for n in range(n_obs):
        baseline = X_test[n:n+1]  # shape (1, D)
        y_base = forward_pass(baseline, W1, b1, W2, b2).item()

        for j in range(n_features):
            x_eps = baseline.copy()
            x_eps[0, j] += eps
            y_eps = forward_pass(x_eps, W1, b1, W2, b2).item()
            sensitivities_all[s, n, j] = (y_eps - y_base) / eps

# Now you can compute:
# (1) Sensitivities per observation (averaged over posterior)
sens_per_obs = sensitivities_all.mean(axis=0)  # shape (n_obs, n_features)

# (2) Global feature importance (average over both posterior + inputs)
global_mean_sens = sens_per_obs.mean(axis=0)
global_std_sens = sensitivities_all.std(axis=(0, 1))

# Plot global sensitivity summary
plt.figure(figsize=(8, 4))
plt.bar(range(n_features), global_mean_sens, yerr=global_std_sens, capsize=4)
plt.xlabel("Input Feature")
plt.ylabel("Sensitivity")
plt.title("Global Mean Sensitivity to Input Features (± SD)")
plt.tight_layout()
plt.show()


In [None]:
# Config
eps = 1e-4      # finite-difference step added to one input feature at a time
obs_index = 10  # row of X_test used as the baseline point

# Posterior draws analysed in this cell. These names were previously never
# assigned at notebook level, so the cell only ran with leftover kernel state.
# NOTE(review): the Dirichlet Horseshoe draws are assumed to be the intended
# ones (it is the only entry in model_names) - swap the suffix to analyse
# another model.
W1_samples, W2_samples = W1_samples_dhs, W2_samples_dhs
b1_samples, b2_samples = b1_samples_dhs, b2_samples_dhs

n_features = X_test.shape[1]
n_samples = W1_samples.shape[0]
baseline = X_test[obs_index:obs_index + 1]  # shape (1, D)

# Allocate sensitivity array: one row per posterior draw
sensitivities_all = np.zeros((n_samples, n_features))

# Loop over posterior samples
for s in range(n_samples):
    W1 = W1_samples[s]
    W2 = W2_samples[s]
    b1 = b1_samples[s]
    b2 = b2_samples[s]

    y_base = forward_pass(baseline, W1, b1, W2, b2).item()

    for j in range(n_features):
        x_eps = baseline.copy()
        x_eps[0, j] += eps  # perturb only feature j
        y_eps = forward_pass(x_eps, W1, b1, W2, b2).item()
        # Forward finite-difference estimate of d(output)/d(feature j).
        sensitivities_all[s, j] = (y_eps - y_base) / eps

# Average + std across posterior
mean_sens = sensitivities_all.mean(axis=0)
std_sens = sensitivities_all.std(axis=0)

# Plot
plt.figure(figsize=(8, 4))
plt.bar(range(n_features), mean_sens, yerr=std_sens, capsize=4)
plt.xlabel("Input Feature")
plt.ylabel("Sensitivity")
plt.title("Mean Output Sensitivity to Input Features (± SD)")
plt.tight_layout()
plt.show()


In [None]:
# Forward pass intermediates
pre_acts = X_test @ W1_samples + b1_samples.reshape(1, -1)  # shape (N_test, H)
post_acts = np.maximum(0, pre_acts)                         # shape (N_test, H)

dead_mask = (post_acts == 0)
dead_neurons = np.all(dead_mask, axis=0)
num_dead = dead_neurons.sum()

print(f"Number of dead neurons: {num_dead} / {post_acts.shape[1]}")
print("Mean pre-activations per neuron:", pre_acts.mean(axis=0))

bias_values = b1_samples.flatten()  # shape (H,)
num_units = pre_acts.shape[1]
y_positions = np.arange(1, num_units + 1)  # boxplot y-ticks start at 1

plt.boxplot(pre_acts, vert=False)
#plt.scatter(bias_values, y_positions, color='orange', marker='x', label='Bias value')
plt.title("Distribution of Pre-Activations per Hidden Unit")
plt.xlabel("Pre-Activation Value")
plt.legend()
plt.tight_layout()
plt.show()



In [None]:
# Fraction of test points on which each hidden unit fires (post-activation > 0).
activation_rate = np.mean(post_acts > 0, axis=0)

plt.scatter(bias_values, activation_rate)
plt.title("Hidden Bias vs Activation Rate")
plt.xlabel("Bias")
plt.ylabel("Activation Frequency")

# Add hidden unit index labels
for _unit_idx, (_bias, _rate) in enumerate(zip(bias_values, activation_rate)):
    plt.text(_bias, _rate, str(_unit_idx), fontsize=8, ha='right', va='bottom')

plt.grid(True)
plt.tight_layout()
plt.show()



## VISUALIZE

The cells below visualize only the weights. I think the biases encode a lot of information, so I try to improve the interpretation by exploring the biases in the cells above.

In [5]:
from network import extract_posterior_means

# Network architecture used for the visualisation.
P = 8            # input layer size
H = 16           # hidden layer size
L = 1            # number of hidden layers
out_nodes = 1    # output layer size

# Node count of every layer, input first.
layer_sizes = [P, *([H] * L), out_nodes]

# Name and shape of each weight matrix, as passed to extract_posterior_means.
layer_structure = {
    'input_to_hidden': {
        'name': 'W_1',
        'shape': (P, H),
    },
    'hidden_to_output': {
        'name': 'W_L',
        'shape': (H, out_nodes),
    },
}

def extract_all_posterior_means_test(fits, layer_structure):
    """
    Collect posterior-mean weights for every model in ``fits``.

    Parameters:
        fits (dict): ``{model_name: {"posterior": fit}}``; each ``fit`` is
                    handed to extract_posterior_means().
        layer_structure (dict): Layer description forwarded unchanged to
                    extract_posterior_means().

    Returns:
        dict: ``{model_name: result of extract_posterior_means(...)}``,
            in the iteration order of ``fits``.
    """
    return {
        model_name: extract_posterior_means(entry['posterior'], layer_structure)
        for model_name, entry in fits.items()
    }

# Posterior-mean weights of every model fitted on the seed-2 dataset.
# NOTE(review): this rebinds the name `dict`, shadowing the Python builtin for
# the rest of the kernel session. The network-plotting cell passes `dict` to
# plot_all_networks_subplots, so a rename has to update both places together.
dict = extract_all_posterior_means_test(all_fits['GAM_N100_p8_sigma1.00_seed2'], layer_structure)

In [None]:
import networkx as nx

def plot_all_networks_subplots(model_dicts, layer_sizes, max_width=5.0, ncols=3, figsize_per_plot=(5, 4), signed_colors=False):
    """
    Plot multiple neural networks as subplots, with edge thickness equal to |weight|.

    Parameters:
        model_dicts (dict): Dictionary mapping model names to weight dicts.
                            Each weight dict must include:
                            - 'W_1': input-to-hidden matrix, indexed [input, hidden]
                            - 'W_L': hidden-to-output matrix, indexed [hidden, output]
                            - optionally 'W_internal': sequence of hidden-to-hidden matrices
        layer_sizes (list[int]): List of node counts for each layer (e.g. [5, 9, 1]).
        max_width (float): Currently unused - the width normalisation it controlled
                           is disabled and raw |weight| is used as the line width.
                           Kept so existing calls remain valid.
        ncols (int): Number of subplot columns. Default is 3.
        figsize_per_plot (tuple): Base figure size per subplot (width, height).
        signed_colors (bool): If True, non-negative weights are red and negative
                              weights are blue; otherwise every edge is red.

    Returns:
        (fig, edge_widths): the figure and the edge widths of the last model
        drawn, in ``G.edges()`` order (empty list when ``model_dicts`` is empty).

    Notes:
        - Automatically handles any number of models and hides unused subplot slots.
        - Weight matrices are assumed to follow (input_dim, output_dim) format.
    """

    n_models = len(model_dicts)
    # At least one row so an empty model_dicts still yields a valid figure.
    nrows = max(1, int(np.ceil(n_models / ncols)))
    figsize = (figsize_per_plot[0] * ncols, figsize_per_plot[1] * nrows)

    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for ax in axes[n_models:]:
        ax.axis('off')

    def add_edges(G, W, in_nodes, out_nodes):
        # Width and colour are stored on the edge itself so they can be read
        # back later in exactly the order in which the edges are drawn.
        for i, in_node in enumerate(in_nodes):
            for j, out_node in enumerate(out_nodes):
                w = W[i, j]
                G.add_edge(in_node, out_node,
                           weight=abs(w),
                           color='red' if w >= 0 else 'blue')

    edge_widths = []  # defined up front so the return works with no models

    for idx, (title, weight_dict) in enumerate(model_dicts.items()):
        G = nx.DiGraph()
        pos = {}
        node_ids_per_layer = []

        # Add nodes: one column per layer, vertically centred around y = 0.
        for layer_idx, size in enumerate(layer_sizes):
            nodes = []
            y_coords = np.linspace(size - 1, 0, size) - (size - 1) / 2
            for i in range(size):
                nid = f"L{layer_idx}_{i}"
                G.add_node(nid)
                pos[nid] = (layer_idx, y_coords[i])
                nodes.append(nid)
            node_ids_per_layer.append(nodes)

        # Input-to-hidden
        add_edges(G, weight_dict['W_1'], node_ids_per_layer[0], node_ids_per_layer[1])

        # Hidden-to-hidden
        if 'W_internal' in weight_dict:
            for l in range(len(weight_dict['W_internal'])):
                add_edges(G, weight_dict['W_internal'][l], node_ids_per_layer[l+1], node_ids_per_layer[l+2])

        # Hidden-to-output
        add_edges(G, weight_dict['W_L'], node_ids_per_layer[-2], node_ids_per_layer[-1])

        labels = {nid: nid for nid in G.nodes}
        nx.draw_networkx_labels(G, pos, labels=labels, ax=axes[idx], font_size=8)

        # G.edges() iterates grouped by source node, which is not the order in
        # which edges were inserted; derive widths AND colours from one explicit
        # edge list so every edge gets its own attributes.
        edgelist = list(G.edges())
        edge_widths = [G[u][v]['weight'] for u, v in edgelist]
        edge_colors = [G[u][v]['color'] for u, v in edgelist]

        nx.draw(G, pos, ax=axes[idx], edgelist=edgelist, node_color='lightgray',
                edge_color=edge_colors if signed_colors else 'red',
                width=edge_widths, with_labels=False,
                node_size=400, arrows=False)

        axes[idx].set_title(title, fontsize=10)
        axes[idx].axis('off')

    plt.tight_layout()
    return fig, edge_widths

# Draw one network diagram per model in `dict` (the posterior-mean weights bound
# in the previous cell - note that the name shadows the Python builtin), with
# edges coloured by sign. Keeps the figure and the edge widths returned for the
# last model drawn.
p1, widths_1 = plot_all_networks_subplots(dict, layer_sizes, signed_colors=True)


In [None]:
# Display the first posterior draw of W_1 for the Regularized Horseshoe fit on
# the seed-2 dataset.
# NOTE(review): requires 'Regularized Horseshoe' to be present in all_fits; the
# loading cell's model_names lists only "Dirichlet Horseshoe" - confirm.
all_fits['GAM_N100_p8_sigma1.00_seed2']['Regularized Horseshoe']['posterior'].stan_variable("W_1")[0, :, :]