In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Directory of the simulated Friedman datasets (one .npz file per dataset configuration).
data_dir = "datasets/friedman_correlated"
#results_dir_relu = "results/regression/single_layer/relu/friedman"
results_dir_tanh = "results/regression/single_layer/tanh/friedman_correlated"

# Model names as stored in each results directory. ReLU loading is currently disabled:
# any cell that needs `relu_fits` / `model_names_relu` requires these lines re-enabled.
#model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
model_names_tanh = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh"]  # "Dirichlet Student T tanh" not loaded

#relu_fits = {}
tanh_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Fits are keyed by the dataset file stem, e.g. "Friedman_N200_p10_sigma1.0_seed6".
    base_config_name = os.path.splitext(fname)[0]
    # relu_fits[base_config_name] = get_model_fits(
    #     config=base_config_name,
    #     results_dir=results_dir_relu,
    #     models=model_names_relu,
    #     include_prior=False,
    # )
    tanh_fits[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        include_prior=False,
    )
    


In [None]:
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble
import numpy as np
import pandas as pd

# Evaluate every tanh model on one held-out dataset. The same key names both the .npz
# file and the entry in `tanh_fits`, so it is defined once.
dataset_key = "Friedman_N200_p10_sigma1.0_seed6"
path = f"{data_dir}/{dataset_key}.npz"
with np.load(path) as data:  # closes the underlying NpzFile handle
    X_test = data["X_test"]
    # Flatten so (n, 1) targets cannot broadcast against the (S, n) draw matrix.
    y_test = np.asarray(data["y_test"]).reshape(-1)

rows = []
for model_name, model_entry in tanh_fits[dataset_key].items():
    post = model_entry["posterior"]

    # Per-draw network outputs on the test inputs: (S, n_test) after dropping the last axis.
    y_samps = post.stan_variable("output_test").squeeze(-1)

    # Optional: limit to first S draws if desired
    # S = min(4000, y_samps.shape[0])
    # y_samps = y_samps[:S]

    # RMSE of the posterior-mean prediction.
    y_mean = y_samps.mean(axis=0)                                   # (n_test,)
    rmse_post_mean = float(np.sqrt(mean_squared_error(y_test, y_mean)))

    # RMSE of each individual draw, averaged over draws.
    per_draw_rmse = np.sqrt(((y_samps - y_test[None, :]) ** 2).mean(axis=1))  # (S,)
    rmse_draw_mean = float(per_draw_rmse.mean())

    # CRPS of the draw ensemble; crps_ensemble takes forecasts shaped (n_test, S).
    crps = float(np.mean(crps_ensemble(y_test, y_samps.T)))

    rows.append({
        "Model": model_name,
        "RMSE_posterior_mean": rmse_post_mean,
        "RMSE_mean_over_draws": rmse_draw_mean,
        "CRPS": crps,
        "n_draws": y_samps.shape[0]
    })

results_df = pd.DataFrame(rows).sort_values("RMSE_posterior_mean")
print(results_df)


In [4]:
def compute_sparse_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass,
                         sparsity=0.0, prune_fn=None, data_dir="datasets/friedman_correlated"):
    """Evaluate test-set RMSE of every posterior draw, optionally after pruning W1.

    For each seed the test set ``{data_dir}/Friedman_N{N}_p10_sigma{sigma:.1f}_seed{seed}.npz``
    is loaded and, for each model, every posterior draw is pushed through ``forward_pass``.

    Args:
        seeds: iterable of dataset seeds.
        models: model names looked up in ``all_fits[dataset_key]``.
        all_fits: nested dict ``{dataset_key: {model: {'posterior': fit}}}`` where ``fit``
            provides ``stan_variable(name)``.
        get_N_sigma: callable mapping a seed to ``(N, sigma)``.
        forward_pass: callable ``(X, W1, b1, W2, b2) -> predictions``.
        sparsity: fraction passed to ``prune_fn``; no pruning is done when it is 0.
        prune_fn: optional callable ``([W1, W2], sparsity) -> masks``; only ``masks[0]``
            is applied (to W1).
        data_dir: directory containing the ``.npz`` datasets.

    Returns:
        ``(df_rmse, df_posterior_rmse)``: one row per posterior draw with its RMSE, and one
        row per (seed, model) with the RMSE of the posterior-mean prediction. Missing
        dataset files or fits are reported with a ``[SKIP]`` message and skipped.
    """
    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.1f}_seed{seed}'
        path = f"{data_dir}/{dataset_key}.npz"
        try:
            # Context manager closes the NpzFile handle once the arrays are read.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten targets: an (n, 1) array would otherwise broadcast against
                # (n,) predictions into an (n, n) error matrix.
                y_test = np.asarray(data["y_test"]).reshape(-1)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                # NOTE(review): indexed as b1_samples[i][0] below, so presumably
                # (S, n_hidden_layers, H) — confirm against the Stan model.
                b1_samples = fit.stan_variable("hidden_bias")
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            n_draws = W1_samples.shape[0]
            y_hats = np.zeros((n_draws, y_test.shape[0]))

            for i in range(n_draws):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Apply pruning mask if requested
                if prune_fn is not None and sparsity > 0.0:
                    masks = prune_fn([W1, W2], sparsity)
                    # NOTE(review): only the W1 mask is applied; masks[1] (for W2) is
                    # ignored — confirm this is intended.
                    W1 = W1 * masks[0]

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                y_hats[i] = np.asarray(y_hat).reshape(-1)

            # RMSE of each individual draw, vectorised over draws: shape (n_draws,).
            rmses = np.sqrt(np.mean((y_hats - y_test[None, :]) ** 2, axis=1))
            # RMSE of the prediction averaged over all draws.
            posterior_mean = y_hats.mean(axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'rmse': rmse
                }
                for rmse in rmses
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

# Fraction of weights pruned at each evaluation step (0.0 is the unpruned baseline).
sparsity_levels = [step / 10 for step in range(10)] + [0.95]

# Dataset seeds to evaluate; seed 11 (N=500) is currently excluded.
seeds = [1, 6]


def get_N_sigma(seed):
    """Return the (N, sigma) pair encoded in the dataset filename for ``seed``.

    Seed 1 maps to N=100, seed 6 to N=200 and every other seed to N=500;
    sigma is always 1.0.
    """
    sample_size_by_seed = {1: 100, 6: 200}
    return sample_size_by_seed.get(seed, 500), 1.00

In [5]:
# Evaluate the tanh models at every sparsity level; both dicts are keyed by sparsity.
# (ReLU evaluation is disabled because the ReLU fits are not loaded.)
df_rmse_tanh, df_posterior_rmse_tanh = {}, {}

for sparsity in sparsity_levels:
    (df_rmse_tanh[sparsity],
     df_posterior_rmse_tanh[sparsity]) = compute_sparse_rmse_results(
        seeds,
        model_names_tanh,
        tanh_fits,
        get_N_sigma,
        forward_pass_tanh,
        sparsity=sparsity,
        prune_fn=local_prune_weights,
    )

In [18]:
import pandas as pd

# Long-format per-draw RMSE table for tanh: one row per (seed, model, draw, sparsity).
df_rmse_full_tanh = pd.concat(
    [frame.assign(sparsity=level) for level, frame in df_rmse_tanh.items()],
    ignore_index=True,
)


In [6]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_sparsity_rmse(df_dict, metric="rmse", show_boxplot=False, title_prefix='RMSE vs sparsity', sample_points=2000):
    """Plot an error metric against pruning sparsity, one line per model.

    Args:
        df_dict: dict mapping sparsity -> DataFrame with a 'model' column and a
            ``metric`` column.
        metric: column holding the values to plot (e.g. 'rmse' or
            'posterior_mean_rmse'); it is copied into an 'rmse' column.
        show_boxplot: if True, also draw a boxplot + jittered-points figure.
        title_prefix: title text for the first figure.
        sample_points: maximum number of points drawn in the strip plot
            (limits overplotting).

    Returns:
        ``(df_all, summary)``: the stacked long-format data, and a per-(model, sparsity)
        table with n, mean, std, SEM and a normal-approximation 95% CI.

    Raises:
        ValueError: if ``df_dict`` is empty.
    """
    if not df_dict:
        raise ValueError("df_dict is empty: nothing to plot")

    # Stack the per-sparsity frames, tagging each row with its sparsity level.
    df_all = pd.concat(
        [df.assign(sparsity=float(sp)) for sp, df in df_dict.items()],
        axis=0,
        ignore_index=True,
    )

    # Clean: copy the chosen metric into 'rmse' so the rest of the function is metric-agnostic.
    df_all['sparsity'] = df_all['sparsity'].astype(float)
    df_all['rmse'] = pd.to_numeric(df_all[metric], errors='coerce')
    # .copy() so the column assignment below does not write into a filtered view.
    df_all = df_all.dropna(subset=['rmse', 'sparsity', 'model']).copy()

    # Sort models consistently
    models_order = sorted(df_all['model'].unique())
    df_all['model'] = pd.Categorical(df_all['model'], categories=models_order, ordered=True)

    # One explicit colour per model, shared by the lines and their CI bars. Indexing the
    # global palette by position breaks once there are more models than palette colours.
    palette = dict(zip(models_order, sns.color_palette(n_colors=len(models_order))))

    # Summary statistics. observed=True keeps only (model, sparsity) pairs that occur;
    # with a categorical grouper the default also emits empty combinations.
    summary = (
        df_all.groupby(['model', 'sparsity'], observed=True)
        .agg(n=('rmse', 'size'),
             mean_rmse=('rmse', 'mean'),
             std_rmse=('rmse', 'std'))
        .reset_index()
    )
    summary['sem'] = summary['std_rmse'] / np.sqrt(summary['n'])
    summary['ci95'] = 1.96 * summary['sem']  # normal-approximation 95% interval of the mean
    summary['ymin'] = summary['mean_rmse'] - summary['ci95']
    summary['ymax'] = summary['mean_rmse'] + summary['ci95']

    # Plot styling (must precede figure creation to take effect)
    sns.set_context('talk')
    sns.set_style('whitegrid')

    # ---- Figure 1: mean + CI ----
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(
        data=summary.sort_values(['model', 'sparsity']),
        x='sparsity', y='mean_rmse', hue='model',
        hue_order=models_order, palette=palette,
        linewidth=2.5, marker='o', markersize=7, ax=ax,
    )
    # error bars
    for _, row in summary.iterrows():
        ax.plot([row['sparsity'], row['sparsity']], [row['ymin'], row['ymax']],
                color=palette[row['model']], lw=2)

    ax.set_title(f'{title_prefix} (mean ± 95% CI)')
    ax.set_xlabel('Sparsity')
    ax.set_ylabel('RMSE')
    ax.legend(title='Model', loc='best')
    fig.tight_layout()
    plt.show()

    # ---- Optionally: Figure 2 ----
    if show_boxplot:
        fig2, ax2 = plt.subplots(figsize=(12, 6))
        sns.boxplot(
            data=df_all,
            x='sparsity', y='rmse', hue='model',
            hue_order=models_order, palette=palette,
            showfliers=False,
            linewidth=1.2,
            ax=ax2,
        )

        # Only sample points to avoid heavy overplot; hue_order keeps the dodge positions
        # aligned with the boxes even if a model is missing from the sample.
        df_sample = df_all.sample(min(len(df_all), sample_points), random_state=42)
        sns.stripplot(
            data=df_sample,
            x='sparsity', y='rmse', hue='model',
            hue_order=models_order,
            dodge=True, size=2, alpha=0.25, palette='dark',
            ax=ax2,
        )

        # Remove duplicate legend caused by stripplot
        handles, labels = ax2.get_legend_handles_labels()
        ax2.legend(handles[:len(models_order)], labels[:len(models_order)], title='Model', loc='best')

        ax2.set_title('RMSE distribution per sparsity and model')
        ax2.set_xlabel('Sparsity')
        ax2.set_ylabel('RMSE')
        fig2.tight_layout()
        plt.show()

    return df_all, summary


In [None]:
# Per-draw RMSE: mean ± 95% CI over all posterior draws and seeds.
df_all, summary = plot_sparsity_rmse(
    df_rmse_tanh, show_boxplot=False, title_prefix="Per-draw RMSE vs sparsity"
)
# RMSE of the posterior-mean prediction: mean ± 95% CI over seeds.
df_all, summary = plot_sparsity_rmse(
    df_posterior_rmse_tanh, metric="posterior_mean_rmse", show_boxplot=False,
    title_prefix="Posterior-mean RMSE vs sparsity",
)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Short legend labels for each prior.
abbr = {
    "Gaussian": "Gaussian",
    "Regularized Horseshoe": "RHS",
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DS-T"
}

# Define consistent color palette
custom_palette = {
    "Gaussian": "C0",
    "Regularized Horseshoe": "C1",
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
}

# Clean model names: drop the " tanh" suffix so both activations share labels/colours.
df_rmse_full_tanh = df_rmse_full_tanh.copy()
df_rmse_full_tanh["model"] = df_rmse_full_tanh["model"].str.replace(" tanh", "", regex=False)

# The ReLU table only exists when the ReLU fits are loaded (that code is commented out in
# the loading cell). Plot whichever activations are available instead of raising NameError.
activation_data = []
if "df_rmse_full_relu" in globals():
    df_rmse_full_relu = df_rmse_full_relu.copy()
    df_rmse_full_relu["model"] = df_rmse_full_relu["model"].str.replace(" tanh", "", regex=False)
    activation_data.append(("ReLU", df_rmse_full_relu))
activation_data.append(("tanh", df_rmse_full_tanh))

# One panel per activation; squeeze=False keeps `axes` indexable even with a single panel.
fig, axes = plt.subplots(
    1, len(activation_data), figsize=(9, 7), sharex=True, sharey=True, squeeze=False
)
axes = axes[0]
all_handles_labels = []

# Plot
for ax, (name, df) in zip(axes, activation_data):
    sns.lineplot(
        data=df,
        x='sparsity', y='rmse',
        hue='model', style='N', marker='o', errorbar=None, ax=ax,
        palette=custom_palette
    )

    # Collect legend entries, then drop the per-axes legend in favour of shared legends.
    handles, labels = ax.get_legend_handles_labels()
    all_handles_labels.extend(zip(handles, labels))
    if ax.get_legend() is not None:
        ax.get_legend().remove()

    ax.set_title(f"{name}")
    ax.set_ylabel("RMSE")
    ax.set_xlabel("Sparsity Level")
    #ax.set_ylim((0.15, 3))
    ax.grid(True)
handles, labels = axes[0].get_legend_handles_labels()

# Keep only the dataset-size (N) entries; model names are handled separately. Note this
# filter also matches a bare "N" header label if seaborn emits one.
n_handles = []
n_labels = []
for h, l in zip(handles, labels):
    if l.startswith("N=") or l.isdigit() or l.strip().lower().startswith("n"):  # adjust if your N labels differ
        n_handles.append(h)
        n_labels.append(l)

# Legend for N (line styles) on the left-most panel
if n_handles:
    axes[0].legend(
        n_handles,
        n_labels,
        title=None,
        loc="upper left",
        bbox_to_anchor=(0.02, 0.98),
        frameon=True,
        fontsize='medium'
    )

# Shared model legend: first handle per model label, in a fixed order, with short names.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label not in {"model", "N"} and label not in legend_dict:
        legend_dict[label] = handle

desired_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
filtered = [(legend_dict[label], abbr[label]) for label in desired_order if label in legend_dict]

if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title=None,
        loc="upper left",
        bbox_to_anchor=(0.75, 0.9),
        ncol=1,
        fontsize='large'
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()


In [7]:
import pandas as pd

# Posterior-mean RMSE tables stacked across sparsity levels. These reuse (and replace)
# the df_rmse_full_* names that held per-draw RMSEs earlier in the notebook.
df_rmse_full_tanh = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_posterior_rmse_tanh.items()],
    ignore_index=True
)

# ReLU results exist only when the ReLU fits were loaded; skip instead of raising NameError.
if "df_posterior_rmse_relu" in globals():
    df_rmse_full_relu = pd.concat(
        [df.assign(sparsity=sparsity) for sparsity, df in df_posterior_rmse_relu.items()],
        ignore_index=True
    )

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Short legend labels for each prior.
abbr = {
    "Gaussian": "Gaussian",
    "Regularized Horseshoe": "RHS",
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DS-T"
}

# Define consistent color palette
custom_palette = {
    "Gaussian": "C0",
    "Regularized Horseshoe": "C1",
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
}

# Clean model names: drop the " tanh" suffix so both activations share labels/colours.
df_rmse_full_tanh = df_rmse_full_tanh.copy()
df_rmse_full_tanh["model"] = df_rmse_full_tanh["model"].str.replace(" tanh", "", regex=False)

# The ReLU table only exists when the ReLU fits are loaded (that code is commented out in
# the loading cell). Plot whichever activations are available instead of raising NameError.
activation_data = []
if "df_rmse_full_relu" in globals():
    df_rmse_full_relu = df_rmse_full_relu.copy()
    df_rmse_full_relu["model"] = df_rmse_full_relu["model"].str.replace(" tanh", "", regex=False)
    activation_data.append(("ReLU", df_rmse_full_relu))
activation_data.append(("tanh", df_rmse_full_tanh))

# One panel per activation; squeeze=False keeps `axes` indexable even with a single panel.
fig, axes = plt.subplots(
    1, len(activation_data), figsize=(9, 7), sharex=True, sharey=True, squeeze=False
)
axes = axes[0]
all_handles_labels = []

# Plot
for ax, (name, df) in zip(axes, activation_data):
    sns.lineplot(
        data=df,
        x='sparsity', y='posterior_mean_rmse',
        hue='model', style='N', marker='o', errorbar=None, ax=ax,
        palette=custom_palette
    )

    # Collect legend entries, then drop the per-axes legend in favour of shared legends.
    handles, labels = ax.get_legend_handles_labels()
    all_handles_labels.extend(zip(handles, labels))
    if ax.get_legend() is not None:
        ax.get_legend().remove()

    ax.set_title(f"{name}")
    ax.set_ylabel("Posterior-mean RMSE")
    ax.set_xlabel("Sparsity Level")
    #ax.set_ylim((0.15, 3))
    ax.grid(True)
handles, labels = axes[0].get_legend_handles_labels()

# Keep only the dataset-size (N) entries; model names are handled separately. Note this
# filter also matches a bare "N" header label if seaborn emits one.
n_handles = []
n_labels = []
for h, l in zip(handles, labels):
    if l.startswith("N=") or l.isdigit() or l.strip().lower().startswith("n"):  # adjust if your N labels differ
        n_handles.append(h)
        n_labels.append(l)

# Legend for N (line styles) on the left-most panel
if n_handles:
    axes[0].legend(
        n_handles,
        n_labels,
        title=None,
        loc="upper left",
        bbox_to_anchor=(0.02, 0.98),
        frameon=True,
        fontsize='medium'
    )

# Shared model legend: first handle per model label, in a fixed order, with short names.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label not in {"model", "N"} and label not in legend_dict:
        legend_dict[label] = handle

desired_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
filtered = [(legend_dict[label], abbr[label]) for label in desired_order if label in legend_dict]

if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title=None,
        loc="upper left",
        bbox_to_anchor=(0.75, 0.9),
        ncol=1,
        fontsize='large'
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()


## VISUALIZE

In [10]:
from utils.visualize_networks import compute_activation_frequency, extract_all_pruned_means, plot_all_networks_subplots_activations
# Training inputs used to compute node activation frequencies for the network plots.
# NOTE(review): this reads from datasets/friedman/many/, not datasets/friedman_correlated
# used by the rest of the notebook — confirm this dataset is intended.
path = "datasets/friedman/many/Friedman_N100_p10_sigma1.00_seed6.npz"
with np.load(path) as data:  # closes the underlying NpzFile handle
    x_train = data["X_train"]

In [None]:
# Per-model node activation frequencies on the training inputs, used to colour nodes.
# NOTE(review): `relu_fits` is only defined when the ReLU loading code is enabled in the
# loading cell; otherwise this cell raises NameError.
node_activation_colors = dict(
    (model_name, compute_activation_frequency(fit, x_train))
    for model_name, fit in relu_fits['Friedman_N100_p10_sigma1.00_seed6'].items()
)

# Common colour-scale maximum across every model's frequencies.
all_freqs = np.concatenate(tuple(node_activation_colors.values()))
global_max = np.max(all_freqs)
print(global_max)

In [13]:
# Architecture of the single-hidden-layer networks being visualised.
P = 10          # input width
H = 16          # hidden units per hidden layer
L = 1           # number of hidden layers
out_nodes = 1   # output units
layer_sizes = [P, *([H] * L), out_nodes]

# Stan variable name and weight-matrix shape for each layer transition.
layer_structure = dict(
    input_to_hidden=dict(name='W_1', shape=(P, H)),
    hidden_to_output=dict(name='W_L', shape=(H, out_nodes)),
)

# Sparsity level passed to the pruning helper for the network plots.
sparsity_level = 0.5

In [None]:
# Pruned posterior-mean weights at `sparsity_level`, then one network plot per model,
# coloured by node activation frequency on a shared scale.
pruned_model_means = extract_all_pruned_means(
    relu_fits['Friedman_N100_p10_sigma1.00_seed6'],
    layer_structure,
    sparsity_level,
)

p1, widths_1 = plot_all_networks_subplots_activations(
    pruned_model_means,
    layer_sizes,
    node_activation_colors,
    activation_color_max=global_max,
    signed_colors=False,
)


TODO: move the network-drawing code from Networks.ipynb into this notebook so the networks can be shown here.

## Node pruning

In [None]:
from utils.sparsity import compute_sparse_rmse_results, prune_nodes_by_output_weights

# Node-pruning results, keyed by sparsity fraction: per-draw and posterior-mean tables.
df_rmse_node_relu, df_posterior_rmse_node_relu = dict(), dict()
df_rmse_node_tanh, df_posterior_rmse_node_tanh = dict(), dict()

def nodes_to_sparsity(nodes_to_prune_list, total_nodes):
    """
    Express node-pruning counts as fractions of a layer's width.

    Args:
        nodes_to_prune_list: iterable of integers, each a number of nodes to prune.
        total_nodes: number of nodes in the layer.

    Returns:
        List of ``n_prune / total_nodes`` values rounded to 4 decimal places,
        in input order.
    """
    fractions = []
    for count in nodes_to_prune_list:
        fractions.append(round(count / total_nodes, 4))
    return fractions

# Suppose you have 16 nodes in the hidden layer
total_nodes = 16
nodes_to_prune = [0, 1, 2, 4, 6, 8, 10, 12, 14]

node_sparsity = nodes_to_sparsity(nodes_to_prune, total_nodes)
print(node_sparsity)
# Output: [0.0, 0.0625, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]

# NOTE(review): `compute_sparse_rmse_results` here is the version imported from
# utils.sparsity at the top of this cell, which shadows the notebook-level definition.
# ReLU fits exist only when their loading code is enabled; skip ReLU instead of raising
# NameError on a fresh kernel.
run_relu = all(name in globals() for name in ("relu_fits", "model_names_relu"))
if not run_relu:
    print("[SKIP] ReLU fits not loaded; running node pruning for tanh only.")

for sparsity in node_sparsity:
    if run_relu:
        df_rmse_node_relu[sparsity], df_posterior_rmse_node_relu[sparsity] = compute_sparse_rmse_results(
            seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu,
            sparsity=sparsity, prune_fn=prune_nodes_by_output_weights,
        )

    df_rmse_node_tanh[sparsity], df_posterior_rmse_node_tanh[sparsity] = compute_sparse_rmse_results(
        seeds, model_names_tanh, tanh_fits, get_N_sigma, forward_pass_tanh,
        sparsity=sparsity, prune_fn=prune_nodes_by_output_weights,
    )

In [16]:
import pandas as pd

# Stack the per-sparsity node-pruning tables (one row per posterior draw).
df_rmse_full_node_tanh = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_rmse_node_tanh.items()],
    ignore_index=True
)

# The ReLU dict stays empty when the ReLU fits were not loaded, and pd.concat rejects an
# empty list, so only build the ReLU table when there is something to stack.
if df_rmse_node_relu:
    df_rmse_full_node_relu = pd.concat(
        [df.assign(sparsity=sparsity) for sparsity, df in df_rmse_node_relu.items()],
        ignore_index=True
    )


In [17]:
total_nodes = 16  # hidden-layer width used to build node_sparsity; adjust if needed

# Convert sparsity fractions back to node counts. Round before casting: the fractions were
# rounded to 4 decimals, so for widths that do not divide evenly the product can fall just
# below the true integer (e.g. 0.3333 * 3 = 0.9999) and astype(int) would truncate it.
df_rmse_full_node_tanh['nodes_pruned'] = (
    (df_rmse_full_node_tanh['sparsity'] * total_nodes).round().astype(int)
)
# The ReLU table is only built when ReLU results exist.
if "df_rmse_full_node_relu" in globals():
    df_rmse_full_node_relu['nodes_pruned'] = (
        (df_rmse_full_node_relu['sparsity'] * total_nodes).round().astype(int)
    )


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Clean model names: drop the " tanh" suffix so both activations share labels/colours.
df_rmse_full_node_tanh = df_rmse_full_node_tanh.copy()
df_rmse_full_node_tanh["model"] = df_rmse_full_node_tanh["model"].str.replace(" tanh", "", regex=False)

# The ReLU table only exists when ReLU results were computed; plot whichever activations
# are available instead of raising NameError.
activation_data = []
if "df_rmse_full_node_relu" in globals():
    df_rmse_full_node_relu = df_rmse_full_node_relu.copy()
    df_rmse_full_node_relu["model"] = df_rmse_full_node_relu["model"].str.replace(" tanh", "", regex=False)
    activation_data.append(("ReLU", df_rmse_full_node_relu))
activation_data.append(("tanh", df_rmse_full_node_tanh))

# Define consistent color palette
custom_palette = {
    "Gaussian": "C0",
    "Regularized Horseshoe": "C1",
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
}

# One panel per activation; squeeze=False keeps `axes` indexable even with a single panel.
fig, axes = plt.subplots(
    1, len(activation_data), figsize=(16, 5), sharex=True, sharey=True, squeeze=False
)
axes = axes[0]
all_handles_labels = []

# Plot
for ax, (name, df) in zip(axes, activation_data):
    sns.lineplot(
        data=df,
        x='nodes_pruned', y='rmse',
        hue='model', style='N', marker='o', errorbar=None, ax=ax,
        palette=custom_palette
    )

    # Collect legend entries, then drop the per-axes legend in favour of the shared one.
    handles, labels = ax.get_legend_handles_labels()
    all_handles_labels.extend(zip(handles, labels))
    if ax.get_legend() is not None:
        ax.get_legend().remove()

    ax.set_title(f"{name} activation")
    ax.set_ylabel("RMSE")
    ax.set_xlabel("Nodes pruned")
    ax.grid(True)

# Shared model legend: first handle per model label, in a fixed order.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label not in {"model", "N"} and label not in legend_dict:
        legend_dict[label] = handle

desired_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
filtered = [(legend_dict[label], label) for label in desired_order if label in legend_dict]

if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="Model",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=2,
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()

## Visualize

In [20]:
from utils.visualize_networks import compute_activation_frequency, extract_all_pruned_means, plot_all_networks_subplots_activations
# Training inputs used to compute node activation frequencies for the network plots.
# NOTE(review): this reads from datasets/friedman/many/, not datasets/friedman_correlated
# used by the rest of the notebook — confirm this dataset is intended.
path = "datasets/friedman/many/Friedman_N100_p10_sigma1.00_seed6.npz"
with np.load(path) as data:  # closes the underlying NpzFile handle
    x_train = data["X_train"]

In [None]:
# Per-model node activation frequencies on the training inputs, used to colour nodes.
# NOTE(review): `relu_fits` is only defined when the ReLU loading code is enabled in the
# loading cell; otherwise this cell raises NameError.
node_activation_colors = dict(
    (model_name, compute_activation_frequency(fit, x_train))
    for model_name, fit in relu_fits['Friedman_N100_p10_sigma1.00_seed6'].items()
)

# Common colour-scale maximum across every model's frequencies.
all_freqs = np.concatenate(tuple(node_activation_colors.values()))
global_max = np.max(all_freqs)
print(global_max)

## Visualize node-pruned networks

In [None]:
from utils.visualize_networks import extract_all_pruned_node_means, plot_all_networks_subplots_activations
# Number of hidden nodes removed before extracting the posterior-mean weights to plot.
# NOTE(review): relies on `relu_fits`, which is only defined when ReLU loading is enabled.
num_nodes_to_prune = 14  # for example
pruned_model_means_nodes = extract_all_pruned_node_means(
    relu_fits['Friedman_N100_p10_sigma1.00_seed6'],
    layer_structure,
    num_nodes_to_prune,
)

p_nodes, widths_nodes = plot_all_networks_subplots_activations(
    pruned_model_means_nodes,
    layer_sizes,
    node_activation_colors,
    activation_color_max=global_max,
    signed_colors=False,
)
