In [1]:
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import gaussian_kde
from sklearn.mixture import GaussianMixture
from sklearn.cluster import DBSCAN

# Ask NumPyro to expose 10 host (CPU) devices to JAX.
# NOTE(review): presumably this is so MCMC chains can run in parallel — confirm
# against the sampler settings used by BayesianNAM. Per the NumPyro docs this
# call only takes effect before the JAX backend is initialised, so keep it in
# the first cell, ahead of any JAX computation.
numpyro.set_host_device_count(10)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Synthetic Data Generation
def generate_synthetic_data(n_samples=500, seed=42):
    """Generate a toy additive regression dataset with two features.

    The target is ``y = sin(2*pi*x1) + x2**2 + 0.1 * eps`` with
    ``x1, x2 ~ Uniform(-1, 1)`` and ``eps ~ Normal(0, 1)``.

    Parameters
    ----------
    n_samples : int, optional
        Number of observations to draw (default 500).
    seed : int, optional
        Seed for the JAX PRNG key (default 42), so results are reproducible.

    Returns
    -------
    dict
        ``{"x1": ..., "x2": ..., "y": ...}``, each a JAX array of shape
        ``(n_samples,)``.
    """
    # JAX PRNG keys are deterministic: drawing twice with the same key yields
    # the same random bits. Split the key so x1, x2 and the noise are
    # independent (reusing one key made x1 == x2 and tied the noise to x1).
    key_x1, key_x2, key_noise = random.split(random.PRNGKey(seed), 3)

    x1 = random.uniform(key_x1, (n_samples,), minval=-1, maxval=1)
    x2 = random.uniform(key_x2, (n_samples,), minval=-1, maxval=1)
    noise = 0.1 * random.normal(key_noise, (n_samples,))

    y = jnp.sin(2 * jnp.pi * x1) + x2 ** 2 + noise
    return {"x1": x1, "x2": x2, "y": y}

In [3]:
from namgcv.basemodels.bnam import BayesianNAM

# Train BNAM
def train_bnam(data, config):
    """Build a BayesianNAM over the numerical features x1 and x2 and train it.

    Parameters
    ----------
    data : mapping
        Must provide the keys ``"x1"``, ``"x2"`` and ``"y"``.
    config :
        Configuration object passed straight through to ``BayesianNAM``.

    Returns
    -------
    BayesianNAM
        The model instance after ``train_model`` has been called on it.
    """
    numeric_names = ("x1", "x2")

    # Each numerical feature gets the same sub-network spec; no categorical
    # features are used.
    numeric_spec = {
        feature: {"input_dim": 1, "output_dim": 2} for feature in numeric_names
    }
    bnam = BayesianNAM(
        cat_feature_info={},
        num_feature_info=numeric_spec,
        config=config,
    )

    numeric_inputs = {feature: data[feature] for feature in numeric_names}
    bnam.train_model(
        num_features=numeric_inputs,
        cat_features={},
        target=data["y"],
    )
    return bnam

In [4]:
from itertools import combinations


# Posterior Analysis
def analyze_posterior(
    model,
    gmm_components=2,
    dbscan_eps=0.1,
    dbscan_min_samples=10,
    random_state=None,
):
    """Plot pairwise joint posteriors and run simple clustering diagnostics.

    For every pair among the selected posterior quantities (first-layer
    weights of the x1 and x2 sub-networks, and column 1 of the scale
    ``final_params``), this draws a 2-D KDE plot, fits a Gaussian mixture and
    runs DBSCAN on the paired samples, printing how many groups each found.

    Parameters
    ----------
    model :
        Object exposing ``_get_posterior_param_samples()``; the result is
        indexed as ``["weights"][...]`` and ``["scale"]["final_params"]``.
    gmm_components : int, optional
        Number of mixture components to fit (default 2, as before).
    dbscan_eps : float, optional
        DBSCAN neighbourhood radius (default 0.1, as before).
    dbscan_min_samples : int, optional
        DBSCAN core-point threshold (default 10, as before).
    random_state : int or None, optional
        Seed forwarded to ``GaussianMixture``; ``None`` keeps the previous
        unseeded behaviour.

    Returns
    -------
    None
        Results are shown as figures and printed lines.
    """
    posterior_samples = model._get_posterior_param_samples()

    # Extract joint samples and convert to NumPy once, so that seaborn and
    # scikit-learn receive plain arrays.
    # NOTE(review): the code below assumes each array is 1-D with one entry
    # per posterior draw — TODO confirm the weight arrays do not carry extra
    # dimensions that need selecting/flattening first.
    pairs = [
        ("w1", np.asarray(posterior_samples["weights"]["x1_num_subnetwork_w0"])),
        ("w2", np.asarray(posterior_samples["weights"]["x2_num_subnetwork_w0"])),
        ("sigma", np.asarray(posterior_samples["scale"]["final_params"][:, 1])),
    ]

    for (name1, samples1), (name2, samples2) in combinations(pairs, 2):
        # Pairwise joint density plot.
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.kdeplot(
            x=samples1, y=samples2, fill=True, cmap="Blues", thresh=0.05, ax=ax
        )
        ax.set_xlabel(name1)
        ax.set_ylabel(name2)
        ax.set_title(f"Joint Posterior: {name1} vs {name2}")
        plt.show()
        # Release the figure explicitly; one is created per pair.
        plt.close(fig)

        # Gaussian Mixture Model (GMM)
        X = np.vstack([samples1, samples2]).T
        gmm = GaussianMixture(
            n_components=gmm_components, random_state=random_state
        ).fit(X)
        labels = gmm.predict(X)
        # Counts the components that actually received points, which is at
        # most `gmm_components`.
        print(f"GMM found {len(set(labels))} components for {name1} vs {name2}")

        # DBSCAN Clustering
        dbscan = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples).fit(X)
        # DBSCAN marks noise points with the label -1; it is not a cluster and
        # must not be counted as one.
        cluster_ids = set(dbscan.labels_) - {-1}
        n_noise = int(np.sum(dbscan.labels_ == -1))
        print(
            f"DBSCAN found {len(cluster_ids)} clusters for {name1} vs {name2}"
            f" ({n_noise} noise points)"
        )

In [None]:
from namgcv.configs.bayesian_nam_config import DefaultBayesianNAMConfig

# Run Ablation Study
# NOTE(review): this cell performs a single fit with the default config; no
# ablated variants are compared here.
# Generate the toy data, build the default configuration and train the model.
# `model` is left as a notebook-level variable because the following cell
# passes it to analyze_posterior.
data = generate_synthetic_data()
config = DefaultBayesianNAMConfig()
model = train_bnam(data, config)

INFO:2025-02-10 11:53:34,678:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-02-10 11:53:34,678 INFO:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-02-10 11:53:34,684:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
2025-02-10 11:53:34,684 INFO:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
2025-02-10 11:53:35,131 INFO:Bayesian NN successfully initialized.
2025-02-10 11:53:35,133 INFO:Bayesian NN successfully initialized.
2025-02-10 11:53:35,134 INFO:Bayesian NN successfully initialized.
2025-02-10 11:53:35,135 INFO:
+---------------------------------------+
| Bayesian NAM successfully initialized.|
+---------------------------------------+

2025-02-10 11:53:35,135 INFO:Numerical feature network: x1
Netwo

In [None]:
# Plot the pairwise joint posteriors and print the GMM / DBSCAN diagnostics
# for the model trained in the previous cell.
analyze_posterior(model)