In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Location of the simulated Friedman datasets and of the saved model fits
# (one results directory per activation function).
data_dir = "datasets/friedman"
results_dir_relu = "results/regression/single_layer/relu/friedman"
results_dir_tanh = "results/regression/single_layer/tanh/friedman"

# Model names handed to get_model_fits; the tanh variants carry a " tanh" suffix.
# To inspect a single model, narrow a list, e.g. ["Dirichlet Horseshoe"].
model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
model_names_tanh = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh"]

# Fits keyed by dataset configuration name (the .npz file name without its extension).
relu_fits = {}
tanh_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Remove only the trailing extension, e.g. "Friedman_N100_p10_sigma1.00_seed1";
    # str.replace would also strip ".npz" if it appeared elsewhere in the name.
    base_config_name = os.path.splitext(fname)[0]
    # The datasets are listed directly from data_dir, so the config is the bare name.
    full_config_path = base_config_name

    relu_fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_relu,
        models=model_names_relu,
        include_prior=False,
    )

    tanh_fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        include_prior=False,
    )

    relu_fits[base_config_name] = relu_fit
    tanh_fits[base_config_name] = tanh_fit
    

In [3]:
from utils.sparsity import compute_sparse_rmse_results, forward_pass_relu, forward_pass_tanh, local_prune_weights

# Sparsity grid passed to compute_sparse_rmse_results together with
# local_prune_weights (presumably the fraction of weights pruned - confirm
# against utils.sparsity).
sparsity_levels = [
    0.0, 0.1, 0.2, 0.3, 0.4,
    0.5, 0.6, 0.7, 0.8, 0.9,
    0.95,
]

# Dataset seeds evaluated at every sparsity level.
seeds = [1, 2]

def get_N_sigma(seed):
    """Return the ``(N, sigma)`` pair associated with a dataset seed.

    Seed 1 maps to N = 100; any other seed maps to N = 200.
    sigma is 1 for every seed.
    """
    sigma = 1
    if seed == 1:
        return 100, sigma
    return 200, sigma

In [4]:
# Result tables keyed by sparsity level, one pair of dicts per activation.
df_rmse_relu = {}
df_posterior_rmse_relu = {}
df_rmse_tanh = {}
df_posterior_rmse_tanh = {}

for sparsity in sparsity_levels:
    # ReLU fits, pruned with local_prune_weights at this sparsity level.
    relu_results = compute_sparse_rmse_results(
        seeds,
        model_names_relu,
        relu_fits,
        get_N_sigma,
        forward_pass_relu,
        sparsity=sparsity,
        prune_fn=local_prune_weights,
    )
    df_rmse_relu[sparsity], df_posterior_rmse_relu[sparsity] = relu_results

    # Same evaluation for the tanh fits.
    tanh_results = compute_sparse_rmse_results(
        seeds,
        model_names_tanh,
        tanh_fits,
        get_N_sigma,
        forward_pass_tanh,
        sparsity=sparsity,
        prune_fn=local_prune_weights,
    )
    df_rmse_tanh[sparsity], df_posterior_rmse_tanh[sparsity] = tanh_results

In [5]:
import pandas as pd

# Stack the per-sparsity tables into one long frame per activation, tagging
# every row with the sparsity level it was computed at.
relu_frames = []
for level, frame in df_rmse_relu.items():
    relu_frames.append(frame.assign(sparsity=level))
df_rmse_full_relu = pd.concat(relu_frames, ignore_index=True)

tanh_frames = []
for level, frame in df_rmse_tanh.items():
    tanh_frames.append(frame.assign(sparsity=level))
df_rmse_full_tanh = pd.concat(tanh_frames, ignore_index=True)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Drop the " tanh" suffix so both activations share the same model labels
# (a no-op for names that do not contain it).
df_rmse_full_relu = df_rmse_full_relu.assign(
    model=df_rmse_full_relu["model"].str.replace(" tanh", "", regex=False)
)
df_rmse_full_tanh = df_rmse_full_tanh.assign(
    model=df_rmse_full_tanh["model"].str.replace(" tanh", "", regex=False)
)

# Legend order; pairing it with the colour cycle gives each model the same
# colour in both panels.
desired_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
custom_palette = dict(zip(desired_order, ["C0", "C1", "C2", "C3"]))

# One panel per activation, sharing both axes.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 5), sharex=True, sharey=True)

activation_data = [("ReLU", df_rmse_full_relu), ("tanh", df_rmse_full_tanh)]
all_handles_labels = []

for idx, (name, df) in enumerate(activation_data):
    ax = axes[idx]

    # Colour encodes the model, line style encodes N.
    sns.lineplot(
        x='sparsity',
        y='rmse',
        hue='model',
        style='N',
        data=df,
        palette=custom_palette,
        marker='o',
        errorbar=None,
        ax=ax,
    )

    # Keep this panel's legend entries, then remove the per-axes legend in
    # favour of a single figure-level one.
    handles, labels = ax.get_legend_handles_labels()
    for entry in zip(handles, labels):
        all_handles_labels.append(entry)
    ax.get_legend().remove()

    ax.set_xlabel("Sparsity Level")
    ax.set_ylabel("RMSE")
    ax.set_title(f"{name} activation")
    ax.grid(True)

# First handle seen for each label, skipping the entries labelled with the
# hue/style column names themselves.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label in ("model", "N"):
        continue
    legend_dict.setdefault(label, handle)

# Only the model entries survive, in the desired order; the N (line style)
# entries are not part of the shared legend.
filtered = []
for label in desired_order:
    if label in legend_dict:
        filtered.append((legend_dict[label], label))

if filtered:
    filtered_handles = tuple(pair[0] for pair in filtered)
    filtered_labels = tuple(pair[1] for pair in filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="Model",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=2,
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()


## VISUALIZE

In [7]:
from utils.visualize_networks import compute_activation_frequency, extract_all_pruned_means, plot_all_networks_subplots_activations
# Training inputs of the N = 100, seed 1 Friedman dataset; passed to
# compute_activation_frequency further down.
path = "datasets/friedman/Friedman_N100_p10_sigma1.00_seed1.npz"
# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once the array has been read into memory.
with np.load(path) as data:
    x_train = data["X_train"]

In [None]:
# Result of compute_activation_frequency on the training inputs for every
# model fitted to the N = 100, seed 1 dataset, keyed by model name.
node_activation_colors = {}
for model_name, fit in relu_fits['Friedman_N100_p10_sigma1.00_seed1'].items():
    node_activation_colors[model_name] = compute_activation_frequency(fit, x_train)

# Largest value across all models; handed to the plotting helper as
# activation_color_max.
all_freqs = np.concatenate(tuple(node_activation_colors.values()))
global_max = np.max(all_freqs)
print(global_max)

In [9]:
# Network dimensions: P inputs, L hidden layer(s) of H nodes each, one output.
P, H, L, out_nodes = 10, 16, 1, 1
layer_sizes = [P, *([H] * L), out_nodes]

# Name and shape of each weight matrix; passed to the pruned-mean extractors.
layer_structure = dict(
    input_to_hidden=dict(name='W_1', shape=(P, H)),
    hidden_to_output=dict(name='W_L', shape=(H, out_nodes)),
)

# Sparsity used for the weight-pruned network visualisation.
sparsity_level = 0.5

In [None]:
# Fits for the N = 100, seed 1 dataset, pruned at sparsity_level and drawn
# with the node colouring computed above.
seed1_relu_fits = relu_fits['Friedman_N100_p10_sigma1.00_seed1']
pruned_model_means = extract_all_pruned_means(seed1_relu_fits, layer_structure, sparsity_level)

p1, widths_1 = plot_all_networks_subplots_activations(
    pruned_model_means,
    layer_sizes,
    node_activation_colors,
    activation_color_max=global_max,
    signed_colors=False,
)


**TODO:** Move the contents of `Networks.ipynb` into this notebook so that the network plots are shown here.

## Node pruning

In [None]:
from utils.sparsity import compute_sparse_rmse_results, prune_nodes_by_output_weights

# Node-pruning result tables, keyed by sparsity level, per activation.
df_rmse_node_relu = {}
df_posterior_rmse_node_relu = {}
df_rmse_node_tanh = {}
df_posterior_rmse_node_tanh = {}

def nodes_to_sparsity(nodes_to_prune_list, total_nodes):
    """
    Express node-pruning counts as fractions of the layer width.

    Args:
        nodes_to_prune_list: iterable of node counts to prune.
        total_nodes: total number of nodes in the layer (must be non-zero).

    Returns:
        List with one entry per count: ``count / total_nodes`` rounded to
        four decimal places.
    """
    fractions = []
    for n_prune in nodes_to_prune_list:
        fractions.append(round(n_prune / total_nodes, 4))
    return fractions

# Node-pruning grid for a hidden layer of 16 nodes.
total_nodes = 16
nodes_to_prune = [0, 1, 2, 4, 6, 8, 10, 12, 14]

node_sparsity = nodes_to_sparsity(nodes_to_prune, total_nodes)
# Sanity check: a prune count above the layer width would give a fraction > 1.
assert all(0.0 <= level <= 1.0 for level in node_sparsity), node_sparsity
print(node_sparsity)
# Output: [0.0, 0.0625, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]

for sparsity in node_sparsity:
    # ReLU fits, pruned with prune_nodes_by_output_weights at this level.
    df_rmse_node_relu[sparsity], df_posterior_rmse_node_relu[sparsity] = compute_sparse_rmse_results(
        seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu,
        sparsity=sparsity, prune_fn=prune_nodes_by_output_weights
    )

    # Same evaluation for the tanh fits.
    df_rmse_node_tanh[sparsity], df_posterior_rmse_node_tanh[sparsity] = compute_sparse_rmse_results(
        seeds, model_names_tanh, tanh_fits, get_N_sigma, forward_pass_tanh,
        sparsity=sparsity, prune_fn=prune_nodes_by_output_weights
    )

In [12]:
import pandas as pd

# Stack the node-pruning tables into one long frame per activation, tagging
# every row with the sparsity level it was computed at.
node_relu_frames = []
for level, frame in df_rmse_node_relu.items():
    node_relu_frames.append(frame.assign(sparsity=level))
df_rmse_full_node_relu = pd.concat(node_relu_frames, ignore_index=True)

node_tanh_frames = []
for level, frame in df_rmse_node_tanh.items():
    node_tanh_frames.append(frame.assign(sparsity=level))
df_rmse_full_node_tanh = pd.concat(node_tanh_frames, ignore_index=True)


In [13]:
total_nodes = 16  # adjust this if needed

# Recover the pruned-node count from the sparsity fraction. The fraction was
# rounded to four decimals when it was derived from the count, so the product
# can land just below the intended integer for some layer widths
# (e.g. 0.0833 * 12 = 0.9996). Round to nearest before casting; a bare
# astype(int) would truncate such a value to 0.
df_rmse_full_node_relu['nodes_pruned'] = (df_rmse_full_node_relu['sparsity'] * total_nodes).round().astype(int)
df_rmse_full_node_tanh['nodes_pruned'] = (df_rmse_full_node_tanh['sparsity'] * total_nodes).round().astype(int)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Drop the " tanh" suffix so both activations share the same model labels
# (a no-op for names that do not contain it).
df_rmse_full_node_relu = df_rmse_full_node_relu.assign(
    model=df_rmse_full_node_relu["model"].str.replace(" tanh", "", regex=False)
)
df_rmse_full_node_tanh = df_rmse_full_node_tanh.assign(
    model=df_rmse_full_node_tanh["model"].str.replace(" tanh", "", regex=False)
)

# Legend order; pairing it with the colour cycle gives each model the same
# colour in both panels.
desired_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
custom_palette = dict(zip(desired_order, ["C0", "C1", "C2", "C3"]))

# One panel per activation, sharing both axes.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 5), sharex=True, sharey=True)

activation_data = [("ReLU", df_rmse_full_node_relu), ("tanh", df_rmse_full_node_tanh)]
all_handles_labels = []

for idx, (name, df) in enumerate(activation_data):
    ax = axes[idx]

    # Colour encodes the model, line style encodes N.
    sns.lineplot(
        x='nodes_pruned',
        y='rmse',
        hue='model',
        style='N',
        data=df,
        palette=custom_palette,
        marker='o',
        errorbar=None,
        ax=ax,
    )

    # Keep this panel's legend entries, then remove the per-axes legend in
    # favour of a single figure-level one.
    handles, labels = ax.get_legend_handles_labels()
    for entry in zip(handles, labels):
        all_handles_labels.append(entry)
    ax.get_legend().remove()

    ax.set_xlabel("Nodes pruned")
    ax.set_ylabel("RMSE")
    ax.set_title(f"{name} activation")
    ax.grid(True)

# First handle seen for each label, skipping the entries labelled with the
# hue/style column names themselves.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label in ("model", "N"):
        continue
    legend_dict.setdefault(label, handle)

# Only the model entries survive, in the desired order; the N (line style)
# entries are not part of the shared legend.
filtered = []
for label in desired_order:
    if label in legend_dict:
        filtered.append((legend_dict[label], label))

if filtered:
    filtered_handles = tuple(pair[0] for pair in filtered)
    filtered_labels = tuple(pair[1] for pair in filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="Model",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=2,
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()

## Visualize

In [15]:
from utils.visualize_networks import compute_activation_frequency, extract_all_pruned_means, plot_all_networks_subplots_activations
# Training inputs of the N = 100, seed 1 Friedman dataset; passed to
# compute_activation_frequency in the next cell.
path = "datasets/friedman/Friedman_N100_p10_sigma1.00_seed1.npz"
# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once the array has been read into memory.
with np.load(path) as data:
    x_train = data["X_train"]

In [None]:
# Result of compute_activation_frequency on the training inputs for every
# model fitted to the N = 100, seed 1 dataset, keyed by model name.
node_activation_colors = {}
for model_name, fit in relu_fits['Friedman_N100_p10_sigma1.00_seed1'].items():
    node_activation_colors[model_name] = compute_activation_frequency(fit, x_train)

# Largest value across all models; handed to the plotting helper as
# activation_color_max.
all_freqs = np.concatenate(tuple(node_activation_colors.values()))
global_max = np.max(all_freqs)
print(global_max)

## NODE PRUNE VISUALIZE

In [None]:
from utils.visualize_networks import extract_all_pruned_node_means, plot_all_networks_subplots_activations
num_nodes_to_prune = 14  # for example

# Fits for the N = 100, seed 1 dataset with num_nodes_to_prune passed to the
# node-level extractor, drawn with the node colouring computed above.
seed1_relu_fits = relu_fits['Friedman_N100_p10_sigma1.00_seed1']
pruned_model_means_nodes = extract_all_pruned_node_means(
    seed1_relu_fits, layer_structure, num_nodes_to_prune
)

p_nodes, widths_nodes = plot_all_networks_subplots_activations(
    pruned_model_means_nodes,
    layer_sizes,
    node_activation_colors,
    activation_color_max=global_max,
    signed_colors=False,
)
