In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


## PERFORMANCE

In [None]:
# Where the simulated datasets live and where each activation's fits were saved.
data_config = 1
data_dir = "datasets/interactions"
results_dir_relu = "results_relu_interaction"
results_dir_tanh = "results_tanh_interaction"

# Model names requested from the loader; the tanh runs carry a " tanh" suffix
# and do not include a Gaussian model.
model_names_relu = [
    "Gaussian",
    "Regularized Horseshoe",
    "Dirichlet Horseshoe",
    "Dirichlet Student T",
]
model_names_tanh = [
    "Regularized Horseshoe tanh",
    "Dirichlet Horseshoe tanh",
    "Dirichlet Student T tanh",
]

# Loaded fits, keyed by dataset name (the .npz file name without its extension).
relu_fits = {}
tanh_fits = {}

files = sorted(name for name in os.listdir(data_dir) if name.endswith(".npz"))
for fname in files:
    base_config_name = fname.replace(".npz", "")  # e.g. "INT_N100_p8_sigma1.00_seed1"
    full_config_path = base_config_name           # the loader is given the bare dataset name
    shared_kwargs = dict(config=full_config_path, include_prior=False)

    relu_fit = get_model_fits(
        results_dir=results_dir_relu, models=model_names_relu, **shared_kwargs
    )
    tanh_fit = get_model_fits(
        results_dir=results_dir_tanh, models=model_names_tanh, **shared_kwargs
    )

    relu_fits[base_config_name], tanh_fits[base_config_name] = relu_fit, tanh_fit
    


In [3]:
def get_N_sigma(seed):
    """
    Map a dataset seed to the (N, sigma) pair used in its file name.

    Seeds 1-4 give N = 100 and every other seed N = 200; seeds 1, 2, 7, 8 give
    sigma = 1.0 and every other seed sigma = 3.0. Seed 19 is a special case
    with N = 1000 and sigma = 1.0.
    """
    if seed == 19:
        return 1000, 1.0
    N = 100 if seed in (1, 2, 3, 4) else 200
    sigma = 1.0 if seed in (1, 2, 7, 8) else 3.0
    return N, sigma

def compute_rmse_results(seeds, models, all_fits, get_N_sigma):
    """
    Compute test RMSE per posterior draw and for the posterior-mean prediction.

    Args:
        seeds: Iterable of dataset seeds to evaluate.
        models: Iterable of model names to look up in ``all_fits``.
        all_fits: Mapping ``{dataset_key: {model: {'posterior': fit}}}`` where
            ``fit.stan_variable("output_test")`` returns the test predictions
            with the draw axis first.
        get_N_sigma: Callable mapping a seed to ``(N, sigma)``; used to build
            the dataset key ``INT_N{N}_p8_sigma{sigma:.2f}_seed{seed}``.

    Returns:
        Tuple ``(df_rmse, df_posterior_rmse)``:
        ``df_rmse`` has one row per (seed, model, draw) with column ``rmse``;
        ``df_posterior_rmse`` has one row per (seed, model) with column
        ``posterior_mean_rmse``.

    Datasets whose ``.npz`` file is missing and models that are not present in
    ``all_fits`` are skipped with a ``[SKIP]`` message.
    """
    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'INT_N{N}_p8_sigma{sigma:.2f}_seed{seed}'
        path = f"datasets/interactions/{dataset_key}.npz"

        try:
            # NpzFile keeps the archive open until closed, so read what is
            # needed inside a context manager instead of leaking the handle.
            with np.load(path) as data:
                y_true = data["y_test"].reshape(-1)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                output_test = fit.stan_variable("output_test")  # (S, N/5, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            # One row per draw, flattened so it lines up with y_true.
            draws = output_test.reshape(output_test.shape[0], -1)
            rmses = np.sqrt(np.mean((draws - y_true) ** 2, axis=1))

            posterior_mean = draws.mean(axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true) ** 2))

            row = {'seed': seed, 'N': N, 'sigma': sigma, 'model': model}
            posterior_means.append({**row, 'posterior_mean_rmse': posterior_mean_rmse})
            results.extend({**row, 'rmse': rmse} for rmse in rmses)

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [4]:
# Seeds evaluated in this comparison; seed 19 (the N = 1000 dataset) is left out.
seeds = [1, 2, 3, 4, 7, 8, 9, 10]

# Per-draw RMSE and posterior-mean RMSE tables, one pair per activation.
df_rmse_relu, df_posterior_rmse_relu = compute_rmse_results(
    seeds, model_names_relu, relu_fits, get_N_sigma
)
df_rmse_tanh, df_posterior_rmse_tanh = compute_rmse_results(
    seeds, model_names_tanh, tanh_fits, get_N_sigma
)

In [5]:
#df_rmse_tanh[df_rmse_tanh['model'] == 'Dirichlet Horseshoe tanh']

In [6]:
#df_rmse_tanh[df_rmse_tanh['model'] == 'Regularized Horseshoe tanh']

In [None]:
# Mean per-draw test RMSE of each ReLU model on the seed-10 dataset.
df_gauss = df_rmse_relu.loc[df_rmse_relu['model'] == 'Gaussian']
rmse_gauss = df_gauss.loc[df_gauss['seed'] == 10, 'rmse'].mean()

df_rhs = df_rmse_relu.loc[df_rmse_relu['model'] == 'Regularized Horseshoe']
rmse_rhs = df_rhs.loc[df_rhs['seed'] == 10, 'rmse'].mean()

df_dhs = df_rmse_relu.loc[df_rmse_relu['model'] == 'Dirichlet Horseshoe']
rmse_dhs = df_dhs.loc[df_dhs['seed'] == 10, 'rmse'].mean()

df_dst = df_rmse_relu.loc[df_rmse_relu['model'] == 'Dirichlet Student T']
rmse_dst = df_dst.loc[df_dst['seed'] == 10, 'rmse'].mean()

for label, value in (
    ("Gaussian", rmse_gauss),
    ("Regularized Horseshoe", rmse_rhs),
    ("Dirichlet Horseshoe", rmse_dhs),
    ("Dirichlet Student T", rmse_dst),
):
    print(f"RMSE {label}: {value:.3f}")




In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from matplotlib.ticker import FixedLocator

# Stack the per-draw RMSE tables of both activations, tagging each row with
# its activation and stripping the " tanh" suffix so model labels coincide.
df_relu = df_rmse_relu.assign(activation="ReLU")
df_tanh = df_rmse_tanh.assign(activation="tanh")

df_all = pd.concat([df_relu, df_tanh])
df_all["model"] = df_all["model"].str.replace(" tanh", "", regex=False)

# Fixed ordering so every panel uses the same x positions and hues.
model_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
activation_order = ["ReLU", "tanh"]

fig, axes = plt.subplots(2, 2, figsize=(18, 7), sharey=False)

# Columns index the sample size N, rows index the noise level sigma.
for i, N_val in enumerate([100, 200]):
    for j, sigma_val in enumerate([1.0, 3.0]):
        ax = axes[j, i]
        in_panel = (df_all["N"] == N_val) & (df_all["sigma"] == sigma_val)
        df_plot = df_all[in_panel].copy()

        # Models on the x axis, one box per activation within each model.
        sns.boxplot(
            data=df_plot,
            x="model",
            y="rmse",
            hue="activation",
            order=model_order,
            hue_order=activation_order,
            ax=ax,
        )

        ax.set_title(f"N = {N_val}, Sigma = {sigma_val}")
        ax.set_xlabel("")
        ax.set_ylabel("RMSE")
        # The high-noise panels get a taller y range.
        ax.set_ylim(0, 8 if sigma_val == 1.0 else 12)
        ax.grid(True)

        # Every panel except the top-left one drops its own legend.
        if (i, j) != (0, 0):
            ax.get_legend().remove()

# Figure-level legend built from the top-left panel's entries.
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, title="Activation", loc="upper center", ncol=2)

plt.tight_layout(rect=[0, 0, 1, 0.93])
plt.show()


## CONVERGENCE

In [10]:
import pandas as pd
import numpy as np
import arviz as az

def get_N_sigma(seed):
    """
    Return the (N, sigma) pair that belongs to a dataset seed.

    N is 100 for seeds 1-4 and 200 otherwise; sigma is 1.0 for seeds 1, 2, 7, 8
    and 3.0 otherwise. Seed 19 overrides both with (1000, 1.0).
    """
    small_sample = seed in (1, 2, 3, 4)
    low_noise = seed in (1, 2, 7, 8)
    regular = (100 if small_sample else 200, 1.0 if low_noise else 3.0)
    return (1000, 1.0) if seed == 19 else regular

def get_all_convergence_diagnostics(all_fits):
    """
    Collect R-hat / ESS summaries and test RMSE for every fitted model.

    Args:
        all_fits: Mapping ``{config_name: {model_name: fit}}`` where
            ``fit['posterior']`` is handed to ``az.from_cmdstanpy`` and exposes
            ``stan_variable('output_test')`` with the draw axis first.

    Returns:
        pandas.DataFrame with one row per (config, model) and columns
        ``model``, ``max_rhat``, ``median_rhat``, ``rmse``,
        ``median_ess_bulk``, ``median_ess_tail``, ``N`` and ``sigma``.
        ``rmse`` is the mean over posterior draws of the per-draw test RMSE.
        When a model cannot be processed its row holds NaN in every metric
        column plus an ``error`` column with the exception message.

    Models whose dataset file is missing are skipped with a ``[SKIP]`` message.
    """
    metric_columns = (
        "max_rhat", "median_rhat", "rmse",
        "median_ess_bulk", "median_ess_tail", "N", "sigma",
    )
    diagnostics = []

    for config_name, model_fits in all_fits.items():
        for model_name, fit in model_fits.items():
            try:
                idata = az.from_cmdstanpy(fit['posterior'])
                y_pred = fit['posterior'].stan_variable('output_test')

                path = f'datasets/interactions/{config_name}.npz'
                try:
                    # Close the archive as soon as the targets are read.
                    with np.load(path) as data:
                        y_test = data["y_test"]
                except FileNotFoundError:
                    print(f"[SKIP] File not found: {path}")
                    continue

                # RMSE of each individual draw against the test targets. The
                # previous loop never indexed the draw, so every entry was the
                # RMSE pooled over all draws rather than a per-draw value.
                draws = y_pred.reshape(y_pred.shape[0], -1)
                per_draw_rmse = np.sqrt(
                    np.mean((draws - y_test.reshape(1, -1)) ** 2, axis=1)
                )

                summary = az.summary(idata, var_names=["output"], round_to=3)
                rhat = summary["r_hat"]
                ess_bulk = summary["ess_bulk"]
                ess_tail = summary["ess_tail"]

                try:
                    seed = int(config_name.split("_seed")[-1])
                    N, sigma = get_N_sigma(seed)
                except ValueError:
                    # Config name does not end in a numeric seed.
                    N, sigma = np.nan, np.nan

                diagnostics.append({
                    "model": model_name,
                    "max_rhat": rhat.max(),
                    "median_rhat": rhat.median(),
                    "rmse": per_draw_rmse.mean(),
                    "median_ess_bulk": ess_bulk.median(),
                    "median_ess_tail": ess_tail.median(),
                    "N": N,
                    "sigma": sigma,
                })

            except Exception as e:
                # Keep failed fits in the table, using the same columns as the
                # successful rows so the frame has a single schema.
                failed_row = {"model": model_name}
                failed_row.update(dict.fromkeys(metric_columns, np.nan))
                failed_row["error"] = str(e)
                diagnostics.append(failed_row)

    return pd.DataFrame(diagnostics)


In [11]:
# One diagnostics table per activation.
relu_diagostic, tanh_diagostic = (
    get_all_convergence_diagnostics(fits) for fits in (relu_fits, tanh_fits)
)

In [12]:
def _group_rows_by_model(diagnostics):
    """Sort rows by model name, keeping the original row order within a model."""
    ordered = diagnostics.assign(row_index=diagnostics.index).sort_values(
        ["model", "row_index"]
    )
    return ordered.drop(columns="row_index").reset_index(drop=True)


relu_grouped = _group_rows_by_model(relu_diagostic)
tanh_grouped = _group_rows_by_model(tanh_diagostic)


In [118]:
# LaTeX source of the grouped tables, floats printed with three decimals.
latex_relu, latex_tanh = (
    table.to_latex(index=False, float_format="%.3f")
    for table in (relu_grouped, tanh_grouped)
)
#print(latex_relu)
#print(latex_tanh)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Strip the " tanh" suffix so both panels use the same model labels.
relu_diagostic = relu_diagostic.assign(
    model=relu_diagostic["model"].str.replace(" tanh", "", regex=False)
)
tanh_diagostic = tanh_diagostic.assign(
    model=tanh_diagostic["model"].str.replace(" tanh", "", regex=False)
)

# Two panels sharing both axes: ReLU on the left, tanh on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True, sharey=True)

# Encoding shared by both panels: colour = model, marker = sigma, size = N.
scatter_kwargs = dict(
    x="max_rhat", y="rmse",
    hue="model",
    style="sigma",
    size="N", sizes=(100, 300),
)

# ReLU panel, drawn without a legend.
sns.scatterplot(data=relu_diagostic, ax=axes[0], legend=False, **scatter_kwargs)
axes[0].set_title("ReLU")
axes[0].set_xlabel(r"Max $\hat{R}$")
axes[0].set_ylabel("RMSE")
axes[0].grid(True)

# tanh panel; its automatic legend supplies the entries for the shared legend.
plot_obj = sns.scatterplot(data=tanh_diagostic, ax=axes[1], **scatter_kwargs)
axes[1].set_title(r"$\tanh$")
axes[1].set_xlabel(r"Max $\hat{R}$")
axes[1].set_ylabel("")
axes[1].grid(True)

# Take the legend entries off the tanh panel, then remove its own legend.
handles, labels = axes[1].get_legend_handles_labels()
axes[1].legend_.remove()

# Drop the section-title entries seaborn inserts for each encoded variable.
exclude_labels = {"model", "N", "sigma"}
filtered = [entry for entry in zip(handles, labels) if entry[1] not in exclude_labels]
if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    print(filtered_labels)
    # Single legend for the whole figure, centred along the top edge.
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="",
        loc="upper center",
        bbox_to_anchor=(0.5, 1),
        ncol=4,
    )

# Leave head room above the panels for the figure legend.
fig.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()


## SPARSITY

In [16]:
def forward_pass_relu(X, W1, b1, W2, b2):
    """
    Forward pass of a one-hidden-layer network with ReLU activation.

    Computes ``max(0, X @ W1 + b1) @ W2 + b2``; both biases are reshaped to a
    single row so they broadcast over the rows of ``X``.
    """
    hidden_bias_row = np.reshape(b1, (1, -1))
    output_bias_row = np.reshape(b2, (1, -1))
    hidden = np.maximum(0, X @ W1 + hidden_bias_row)
    return hidden @ W2 + output_bias_row

def forward_pass_tanh(X, W1, b1, W2, b2):
    """
    Forward pass of a one-hidden-layer network with tanh activation.

    Computes ``tanh(X @ W1 + b1) @ W2 + b2``; both biases are reshaped to a
    single row so they broadcast over the rows of ``X``.
    """
    hidden_bias_row = np.reshape(b1, (1, -1))
    output_bias_row = np.reshape(b2, (1, -1))
    hidden = np.tanh(X @ W1 + hidden_bias_row)
    return hidden @ W2 + output_bias_row

def compute_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass,
                         sparsity=0.0, prune_fn=None):
    """
    Recompute test predictions from sampled weights, optionally pruned, and
    report RMSE per posterior draw and for the posterior-mean prediction.

    Args:
        seeds: Iterable of dataset seeds to evaluate.
        models: Iterable of model names to look up in ``all_fits``.
        all_fits: Mapping ``{dataset_key: {model: {'posterior': fit}}}`` where
            ``fit.stan_variable`` provides ``"W_1"``, ``"W_L"``,
            ``"hidden_bias"`` and ``"output_bias"`` with the draw axis first.
        get_N_sigma: Callable mapping a seed to ``(N, sigma)``; used to build
            the dataset key ``INT_N{N}_p8_sigma{sigma:.2f}_seed{seed}``.
        forward_pass: Callable ``(X, W1, b1, W2, b2) -> predictions``.
        sparsity: Fraction of weights to prune; recorded in both outputs.
        prune_fn: Callable ``([W1, W2], sparsity) -> [mask1, mask2]``. Pruning
            is applied only when this is given and ``sparsity > 0``.

    Returns:
        Tuple ``(df_rmse, df_posterior_rmse)``:
        ``df_rmse`` has one row per (seed, model, draw) with column ``rmse``;
        ``df_posterior_rmse`` has one row per (seed, model) with column
        ``posterior_mean_rmse``. Both carry a ``sparsity`` column.

    Datasets whose ``.npz`` file is missing and models that are not present in
    ``all_fits`` are skipped with a ``[SKIP]`` message.
    """
    results = []
    posterior_means = []
    apply_pruning = prune_fn is not None and sparsity > 0.0

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'INT_N{N}_p8_sigma{sigma:.2f}_seed{seed}'
        path = f"datasets/interactions/{dataset_key}.npz"

        try:
            # Read inside a context manager so the archive is closed again.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten once: a column-vector y_test would otherwise
                # broadcast against the flattened predictions into an
                # (N, N) error matrix and give a wrong per-draw RMSE.
                y_true = data["y_test"].reshape(-1)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            rmses = np.zeros(S)
            y_hats = np.zeros((S, y_true.shape[0]))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Masks are recomputed for every draw from that draw's weights.
                if apply_pruning:
                    masks = prune_fn([W1, W2], sparsity)
                    W1 = W1 * masks[0]
                    W2 = W2 * masks[1]

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                y_hats[i] = y_hat.reshape(-1)  # keep the draw for the posterior mean
                rmses[i] = np.sqrt(np.mean((y_hats[i] - y_true) ** 2))

            posterior_mean = np.mean(y_hats, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true) ** 2))

            row = {
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
            }
            posterior_means.append({**row, 'posterior_mean_rmse': posterior_mean_rmse})
            results.extend({**row, 'rmse': rmse} for rmse in rmses)

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse

# Pruning fractions for the sweep: 0.0 to 0.9 in steps of 0.1, then 0.95.
sparsity_levels = [tenth / 10 for tenth in range(10)] + [0.95]

# Dataset seeds included in the sweep.
seeds = [*range(1, 5), *range(7, 11)]

def get_N_sigma(seed):
    """
    Map a dataset seed to the (N, sigma) pair used in its file name.

    N is 100 for seeds 1-4 and 200 otherwise; sigma is 1.0 for seeds 1, 2, 7, 8
    and 3.0 otherwise. Seed 19 is the N = 1000, sigma = 1.0 dataset.

    This rebinds the notebook-wide ``get_N_sigma``, so it keeps the seed-19
    case of the earlier definitions; without it, any code that calls
    ``get_N_sigma`` after this cell has run would label seed 19 as
    (200, 3.0).
    """
    N = 100 if seed in [1, 2, 3, 4] else 200
    sigma = 1.0 if seed in [1, 2, 7, 8] else 3.0
    if seed == 19:
        N, sigma = 1000, 1.0
    return N, sigma


def global_prune_weights(weight_matrices, sparsity_level):
    """
    Prune globally across multiple weight matrices.

    The ``floor(sparsity_level * total)`` entries with the smallest absolute
    value, taken over all matrices together, are masked out.

    Args:
        weight_matrices: List of numpy arrays (weight matrices).
        sparsity_level: Float in [0, 1], fraction of weights to prune.

    Returns:
        List of float masks (1.0 = keep, 0.0 = pruned) with the same shapes as
        weight_matrices.

    Raises:
        ValueError: If sparsity_level lies outside [0, 1].
    """
    if not 0.0 <= sparsity_level <= 1.0:
        raise ValueError(f"sparsity_level must be in [0, 1], got {sparsity_level}")

    arrays = [np.asarray(w) for w in weight_matrices]
    if not arrays:
        return []

    # Magnitudes of all weights, concatenated in list order
    abs_weights = np.abs(np.concatenate([w.ravel() for w in arrays]))

    # Determine number of weights to prune
    total_weights = abs_weights.size
    num_to_prune = int(np.floor(sparsity_level * total_weights))

    # Create global mask
    global_mask_flat = np.ones(total_weights, dtype=bool)
    if num_to_prune >= total_weights:
        # argpartition needs kth < size, so full pruning is handled directly.
        global_mask_flat[:] = False
    elif num_to_prune > 0:
        # Indices of the num_to_prune smallest magnitudes (in arbitrary order)
        prune_indices = np.argpartition(abs_weights, num_to_prune)[:num_to_prune]
        global_mask_flat[prune_indices] = False

    # Split the global mask back into original shapes
    masks = []
    idx = 0
    for w in arrays:
        mask = global_mask_flat[idx:idx + w.size].reshape(w.shape)
        masks.append(mask.astype(float))
        idx += w.size

    return masks

def local_prune_weights(weights, sparsity_level, index_to_prune=0):
    """
    Apply pruning to only one weight matrix in a list, specified by index.

    The ``floor(sparsity_level * size)`` entries of the selected matrix with
    the smallest absolute value are masked out; every other matrix gets an
    all-ones mask.

    Parameters:
    - weights: list of np.ndarray (e.g., [W1, W2])
    - sparsity_level: fraction of weights to prune (0.0 to 1.0)
    - index_to_prune: which weight matrix to prune in the list

    Returns:
    - list of masks (one for each weight matrix)

    Raises:
    - ValueError: if sparsity_level lies outside [0, 1]
    """
    if not 0.0 <= sparsity_level <= 1.0:
        raise ValueError(f"sparsity_level must be in [0, 1], got {sparsity_level}")

    masks = [np.ones_like(W) for W in weights]

    W = np.asarray(weights[index_to_prune])
    flat = np.abs(W.ravel())
    num_to_prune = int(np.floor(sparsity_level * flat.size))

    if num_to_prune > 0:
        mask_flat = np.ones(flat.size, dtype=bool)
        if num_to_prune >= flat.size:
            # argpartition needs kth < size, so full pruning is handled directly.
            mask_flat[:] = False
        else:
            # Indices of the num_to_prune smallest magnitudes (in arbitrary order)
            idx = np.argpartition(flat, num_to_prune)[:num_to_prune]
            mask_flat[idx] = False
        masks[index_to_prune] = mask_flat.reshape(W.shape).astype(float)

    return masks




In [17]:
# From here on these four names are dicts keyed by sparsity level (earlier
# cells bound the same names to single DataFrames).
df_rmse_relu, df_posterior_rmse_relu = {}, {}
df_rmse_tanh, df_posterior_rmse_tanh = {}, {}

for sparsity in sparsity_levels:
    # Only the first weight matrix is pruned (local_prune_weights default).
    prune_kwargs = dict(sparsity=sparsity, prune_fn=local_prune_weights)

    relu_per_draw, relu_posterior = compute_rmse_results(
        seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu,
        **prune_kwargs
    )
    df_rmse_relu[sparsity] = relu_per_draw
    df_posterior_rmse_relu[sparsity] = relu_posterior

    tanh_per_draw, tanh_posterior = compute_rmse_results(
        seeds, model_names_tanh, tanh_fits, get_N_sigma, forward_pass_tanh,
        **prune_kwargs
    )
    df_rmse_tanh[sparsity] = tanh_per_draw
    df_posterior_rmse_tanh[sparsity] = tanh_posterior

In [18]:
import pandas as pd

def _stack_by_sparsity(frames_by_sparsity):
    """Concatenate the per-sparsity frames, stamping each with its dict key."""
    stamped = [
        frame.assign(sparsity=level) for level, frame in frames_by_sparsity.items()
    ]
    return pd.concat(stamped, ignore_index=True)


df_rmse_full_relu = _stack_by_sparsity(df_rmse_relu)
df_rmse_full_tanh = _stack_by_sparsity(df_rmse_tanh)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=True, sharey=True)

activation_data = [("ReLU", df_rmse_full_relu), ("tanh", df_rmse_full_tanh)]

# Legend entries are gathered from the panels drawn in this cell (first handle
# seen per label). Previously the legend was built from `handles` / `labels`
# left over from an earlier figure, so it did not describe these lines.
legend_entries = {}

# Columns hold the activation, rows hold the noise level sigma.
for col_idx, (name, df) in enumerate(activation_data):
    for row_idx, sigma in enumerate([1.0, 3.0]):
        ax = axes[row_idx, col_idx]

        sns.lineplot(
            data=df[df['sigma'] == sigma],
            x='sparsity', y='rmse',
            hue='model', style='N', marker='o', errorbar=None, ax=ax
        )

        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)

        # Raw f-string: "\s" is not a valid escape in a normal string literal.
        ax.set_title(rf"{name}, $\sigma = {sigma}$")
        ax.set_ylabel("RMSE")
        ax.set_xlabel("Sparsity Level")
        ax.grid(True)

        # Remove individual legends
        if ax.get_legend() is not None:
            ax.get_legend().remove()

# Drop the section-title entries seaborn inserts for each encoded variable.
exclude_labels = {"model", "N"}
filtered = [(h, l) for l, h in legend_entries.items() if l not in exclude_labels]
# Shared legend centred above the panels
if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    # Use fig.legend to add shared legend
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=3,
    )
plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Clean model names so both activations share the same labels and colours
df_rmse_full_relu = df_rmse_full_relu.copy()
df_rmse_full_relu["model"] = df_rmse_full_relu["model"].str.replace(" tanh", "", regex=False)

df_rmse_full_tanh = df_rmse_full_tanh.copy()
df_rmse_full_tanh["model"] = df_rmse_full_tanh["model"].str.replace(" tanh", "", regex=False)

# Define consistent color palette
custom_palette = {
    "Gaussian": "C0",  # blue
    "Regularized Horseshoe": "C1",  # orange
    "Dirichlet Horseshoe": "C2",  # green
    "Dirichlet Student T": "C3",  # red
}

# Set up plot
fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=True, sharey=True)

activation_data = [("ReLU", df_rmse_full_relu), ("tanh", df_rmse_full_tanh)]
all_handles_labels = []

# Plot: columns hold the activation, rows hold the noise level sigma.
for col_idx, (name, df) in enumerate(activation_data):
    for row_idx, sigma in enumerate([1.0, 3.0]):
        ax = axes[row_idx, col_idx]

        sns.lineplot(
            data=df[df['sigma'] == sigma],
            x='sparsity', y='rmse',
            hue='model', style='N', marker='o', errorbar=None, ax=ax,
            palette=custom_palette
        )

        # Capture legend handles before removing the per-panel legend
        handles, labels = ax.get_legend_handles_labels()
        all_handles_labels.extend(zip(handles, labels))
        if ax.get_legend() is not None:
            ax.get_legend().remove()

        # Raw f-string: "\s" is not a valid escape in a normal string literal.
        ax.set_title(rf"{name}, $\sigma = {sigma}$")
        ax.set_ylabel("RMSE")
        ax.set_xlabel("Sparsity Level")
        ax.grid(True)

# Filter and sort legend: keep the first handle seen for each label and drop
# the section-title entries seaborn inserts for each encoded variable.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label not in {"model", "N"} and label not in legend_dict:
        legend_dict[label] = handle

desired_order = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
filtered = [(legend_dict[label], label) for label in desired_order if label in legend_dict]

# Shared legend centred above the panels
if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=2,
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()


## MULTIMODALITY

In [None]:
# Shape of the squeezed test predictions for the Gaussian ReLU fit (seed 1).
gaussian_seed1_posterior = relu_fits['INT_N100_p8_sigma1.00_seed1']['Gaussian']['posterior']
gaussian_seed1_posterior.stan_variable('output_test').squeeze().shape

In [None]:
import numpy as np
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from sklearn.neighbors import KernelDensity
import matplotlib.pyplot as plt
import seaborn as sns

# Test-set predictions of the Gaussian ReLU fit on the seed-1 dataset, with
# singleton axes squeezed out (presumably one row per posterior draw - confirm
# against the shape printed in the previous cell).
f_matrix = relu_fits['INT_N100_p8_sigma1.00_seed1']['Gaussian']['posterior'].stan_variable('output_test').squeeze()

# Step 1: embed the rows of f_matrix in 2-D with t-SNE (fixed random_state).
tsne = TSNE(n_components=2, perplexity=5, random_state=42)
f_2d = tsne.fit_transform(f_matrix)

# Step 2: cluster the embedding with DBSCAN; label -1 marks noise points and
# is not counted as a mode.
db = DBSCAN(eps=2.5, min_samples=10).fit(f_2d)
labels = db.labels_
cluster_ids = set(labels)
n_modes = len(cluster_ids) - int(-1 in cluster_ids)

# Step 3: scatter the embedding, one series per cluster label.
fig, ax = plt.subplots(figsize=(8, 6))
palette = sns.color_palette("hsv", len(cluster_ids))
for label in cluster_ids:
    idx = labels == label
    series_name = 'Noise' if label == -1 else f'Mode {label}'
    ax.scatter(f_2d[idx, 0], f_2d[idx, 1], s=10, label=series_name, alpha=0.7)
ax.set_title(f't-SNE of Posterior Samples (Estimated Modes: {n_modes})')
ax.set_xlabel('Dim 1')
ax.set_ylabel('Dim 2')
ax.legend()
ax.grid(True)
plt.tight_layout()
plt.show()

# Optional follow-up, currently disabled: kernel density over the embedding.
# # Step 4: Density estimation over t-SNE space
# kde = KernelDensity(bandwidth=0.5).fit(f_2d)
# log_density = kde.score_samples(f_2d)

# # Step 5: Plot KDE density landscape
# plt.figure(figsize=(8, 6))
# plt.scatter(f_2d[:, 0], f_2d[:, 1], c=log_density, cmap='viridis', s=10)
# plt.title("Posterior Sample Density (KDE on t-SNE Projection)")
# plt.colorbar(label="Log Density")
# plt.grid(True)
# plt.tight_layout()
# plt.show()


In [None]:
# Same analysis as the previous cell, for the Dirichlet Horseshoe ReLU fit on
# the seed-1 dataset.
f_matrix = relu_fits['INT_N100_p8_sigma1.00_seed1']['Dirichlet Horseshoe']['posterior'].stan_variable('output_test').squeeze()

# Step 1: 2-D t-SNE embedding of the rows of f_matrix (fixed random_state).
tsne = TSNE(n_components=2, perplexity=5, random_state=42)
f_2d = tsne.fit_transform(f_matrix)

# Step 2: DBSCAN on the embedding; the noise label -1 is excluded from the
# mode count.
db = DBSCAN(eps=2.5, min_samples=10).fit(f_2d)
labels = db.labels_
found_labels = set(labels)
n_modes = len(found_labels) - (1 if -1 in found_labels else 0)

# Step 3: one scatter series per cluster label.
fig, ax = plt.subplots(figsize=(8, 6))
palette = sns.color_palette("hsv", len(found_labels))
for label in found_labels:
    idx = labels == label
    if label != -1:
        legend_name = f'Mode {label}'
    else:
        legend_name = 'Noise'
    ax.scatter(f_2d[idx, 0], f_2d[idx, 1], s=10, label=legend_name, alpha=0.7)
ax.set_title(f't-SNE of Posterior Samples (Estimated Modes: {n_modes})')
ax.set_xlabel('Dim 1')
ax.set_ylabel('Dim 2')
ax.legend()
ax.grid(True)
plt.tight_layout()
plt.show()