## Example: Joint inference of $p(G, \Theta | D)$ for Gaussian Bayes nets

Setup for Google Colab. Selecting the **GPU** runtime available in Google Colab will make inference significantly faster.



In [None]:
# Install the `dibs-lib` package into the current runtime; --quiet reduces pip's output.
%pip install --quiet dibs-lib 

DiBS translates the task of inferring the posterior over Bayesian networks into an inference problem over the continuous latent variable $Z$. This is achieved by modeling the directed acyclic graph $G$ of the Bayesian network using the generative model $p(G | Z)$. The prior $p(Z)$ enforces the acyclicity of $G$.
Ultimately, this allows us to infer $p(G, \Theta | D)$ (and $p(G | D)$) using off-the-shelf inference methods such as Stein variational gradient descent (SVGD) (Liu and Wang, 2016).

In [None]:
import time

import jax
import jax.random as random
import numpy as np
# Root PRNG key of the notebook; the following cells split it to obtain fresh subkeys.
key = jax.random.PRNGKey(123)
print("JAX backend: " + jax.default_backend())

### Generate synthetic ground truth Bayesian network and BN model for inference

`data` contains information about a synthetic ground-truth causal model with `n_vars` variables, as well as observations sampled from it. By default, the conditional distributions are linear Gaussian. The random graph model is set by `graph_prior_str`, where `er` denotes Erdős–Rényi and `sf` scale-free graphs.

`graph_model` defines prior $p(G)$ and `likelihood_model` defines likelihood $p(x, \Theta| G ) = p(\Theta| G )p(x | G, \Theta )$ of the BN model for which DiBS will infer the posterior.

**For posterior inference of nonlinear Gaussian networks parameterized by fully-connected neural networks, use the function `make_nonlinear_gaussian_model`.** 


In [None]:
from dibs.target import make_linear_gaussian_model, make_nonlinear_gaussian_model
from dibs.utils import visualize_ground_truth

# Spend a fresh subkey on the synthetic model; `key` itself is carried forward.
key, subk = random.split(key)

# Linear-Gaussian ground truth with 20 variables and a scale-free ("sf") random graph.
synthetic_model = make_linear_gaussian_model(
    key=subk,
    n_vars=20,
    graph_prior_str="sf",
)
# Nonlinear alternative (swap it in for the call above):
# data, graph_model, likelihood_model = make_nonlinear_gaussian_model(key=subk, n_vars=20, graph_prior_str="sf")
data, graph_model, likelihood_model = synthetic_model

visualize_ground_truth(data.g)

### DiBS with SVGD

Infer $p(G, \Theta | D)$ under the prior and conditional distributions defined by the model.
The visualization below shows the *matrix of edge probabilities* $G_\alpha(Z^{(k)})$ implied by each transported latent particle (i.e., sample) $Z^{(k)}$ during the iterations of SVGD with DiBS. Refer to the paper for further details.

To explicitly perform posterior inference of $p(G | D)$ using a closed-form marginal likelihood $p(D | G)$, use the separate, analogous class `MarginalDiBS`, as demonstrated in the example notebook `dibs_marginal.ipynb`.



In [None]:
from dibs.inference import JointDiBS

# Inference object built from the observations and the prior/likelihood models.
dibs = JointDiBS(
    x=data.x,
    interv_mask=None,
    graph_model=graph_model,
    likelihood_model=likelihood_model,
)

# Run SVGD with 20 particles for 2000 steps, invoking the visualization
# callback every 100 steps.
key, subk = random.split(key)
sampler_options = dict(
    n_particles=20,
    steps=2000,
    callback_every=100,
    callback=dibs.visualize_callback(),
)
gs, thetas = dibs.sample(key=subk, **sampler_options)

### Evaluate on held-out data

Form the empirical (i.e., weighted by counts) and mixture distributions (i.e., weighted by unnormalized posterior probabilities, denoted DiBS+).

In [None]:
# Both particle distributions are formed from the same sampled (graph, parameter) pairs.
posterior_particles = (gs, thetas)
dibs_empirical = dibs.get_empirical(*posterior_particles)
dibs_mixture = dibs.get_mixture(*posterior_particles)

Compute some evaluation metrics.

In [None]:
from dibs.metrics import expected_shd, threshold_metrics, neg_ave_log_likelihood

# Each row pairs a printed label with the distribution to evaluate.
metric_rows = (
    ('DiBS ', dibs_empirical),
    ('DiBS+', dibs_mixture),
)

for descr, dist in metric_rows:
    # Structural metrics are computed against the ground-truth graph,
    # the likelihood on the held-out observations.
    eshd = expected_shd(dist=dist, g=data.g)
    roc_summary = threshold_metrics(dist=dist, g=data.g)
    auroc = roc_summary['roc_auc']
    negll = neg_ave_log_likelihood(
        dist=dist,
        eltwise_log_likelihood=dibs.eltwise_log_likelihood_observ,
        x=data.x_ho,
    )

    print(f'{descr} |  E-SHD: {eshd:4.1f}    AUROC: {auroc:5.2f}    neg. LL {negll:5.2f}')
    

## Deep Ensemble Comparison

In [None]:
from dibs.target import make_nonlinear_gaussian_model
from dibs.utils import visualize_ground_truth

# Restart from a fixed seed for the comparison and split off a subkey for the model.
key, subk = random.split(random.PRNGKey(0))

# Nonlinear-Gaussian ground truth with 20 variables and a scale-free ("sf") graph prior.
nonlinear_model = make_nonlinear_gaussian_model(
    key=subk,
    n_vars=20,
    graph_prior_str="sf",
)
data, graph_model, likelihood_model = nonlinear_model
visualize_ground_truth(data.g)

In [None]:
# A single SVGD run in which all 20 particles are transported together.
dibs = JointDiBS(
    x=data.x,
    interv_mask=None,
    graph_model=graph_model,
    likelihood_model=likelihood_model,
)
key, subk = random.split(key)
gs_20, thetas_20 = dibs.sample(
    key=subk,
    n_particles=20,
    steps=2000,
    callback_every=100,
    callback=dibs.visualize_callback(),
)

In [None]:
# Empirical and mixture distributions over the 20 jointly sampled particles.
particles_20 = (gs_20, thetas_20)
dibs_empirical_20 = dibs.get_empirical(*particles_20)
dibs_mixture_20 = dibs.get_mixture(*particles_20)

In [None]:
# Deep ensemble: 20 independent single-particle runs, each with its own subkey.
single_empiricals, single_mixtures = [], []

# Separate key chain for the ensemble, so `key` itself is left untouched by the loop.
key_de = key
for i in range(20):
    key_de, subk = random.split(key_de)
    dibs_single = JointDiBS(
        x=data.x,
        interv_mask=None,
        graph_model=graph_model,
        likelihood_model=likelihood_model,
    )
    # No visualization callback for speed
    gs, thetas = dibs_single.sample(key=subk, n_particles=1, steps=2000)
    single_empiricals.append(dibs_single.get_empirical(gs, thetas))
    single_mixtures.append(dibs_single.get_mixture(gs, thetas))

# Deep-ensemble aggregate: every member distribution is scored separately and
# the scores are averaged uniformly (no merged distribution is constructed).
def average_dist(dists, g_true):
    """Build an evaluator that averages a graph metric over several distributions.

    Despite its name, this does not return a distribution. It returns a
    function that applies a metric to each member and averages the results.

    Args:
        dists: Iterable of distribution objects understood by the metric
            functions that will be passed to the returned evaluator.
        g_true: Ground-truth graph, forwarded to every metric call as ``g``.

    Returns:
        A callable ``avg_metric(metric_fn)`` that evaluates
        ``metric_fn(dist=member, g=g_true)`` for every member and returns the
        arithmetic mean of the results.

    Raises:
        ValueError: If ``dists`` is empty.
    """
    # Snapshot the members so generators work and len() is always available.
    members = tuple(dists)
    if not members:
        raise ValueError("average_dist requires at least one distribution")

    def avg_metric(metric_fn):
        scores = [metric_fn(dist=member, g=g_true) for member in members]
        return sum(scores) / len(members)

    return avg_metric

# One averaged-metric evaluator per distribution type, both scored against data.g.
deep_empirical_avg, deep_mixture_avg = (
    average_dist(members, data.g)
    for members in (single_empiricals, single_mixtures)
)

In [None]:
from dibs.metrics import expected_shd, threshold_metrics, neg_ave_log_likelihood

# Helper to compute metrics for a dist (or averaged dist)
def compute_metrics(dist, name):
    """Print E-SHD, AUROC and held-out negative log-likelihood for one entry.

    Args:
        dist: Either a distribution object accepted by the ``dibs.metrics``
            functions, or a callable that takes a metric function, invokes it
            with the keyword arguments ``dist`` and ``g``, and returns the
            aggregated score (the evaluators returned by ``average_dist``).
        name: Label printed at the start of the result line.
    """
    def auroc_metric(*, dist, g):
        return threshold_metrics(dist=dist, g=g)['roc_auc']

    def negll_metric(*, dist, g=None):
        # The averaged evaluators always pass `g`. It is accepted and dropped
        # here because neg_ave_log_likelihood takes no ground-truth graph, and
        # forwarding it would raise a TypeError for an unexpected keyword.
        return neg_ave_log_likelihood(
            dist=dist,
            eltwise_log_likelihood=dibs.eltwise_log_likelihood_observ,
            x=data.x_ho,
        )

    if callable(dist):  # For averaged dists
        eshd = dist(expected_shd)
        auroc = dist(auroc_metric)
        negll = dist(negll_metric)
    else:  # For standard dists
        eshd = expected_shd(dist=dist, g=data.g)
        auroc = auroc_metric(dist=dist, g=data.g)
        negll = negll_metric(dist=dist)
    print(f'{name} |  E-SHD: {eshd:4.1f}    AUROC: {auroc:5.2f}    neg. LL {negll:5.2f}')

# Report the joint SVGD run first, then the averaged deep-ensemble scores.
report_rows = [
    ('SVGD empirical (20 particles)', dibs_empirical_20),
    ('SVGD mixture (20 particles)', dibs_mixture_20),
    ('Deep ensemble empirical avg (20 x 1-particle)', deep_empirical_avg),
    ('Deep ensemble mixture avg (20 x 1-particle)', deep_mixture_avg),
]
for label, entry in report_rows:
    compute_metrics(entry, label)

In [None]:
# Score every particle of the joint SVGD run on its own.
svgd_per_metrics = []
for i in range(gs_20.shape[0]):
    # Keep the leading particle axis (length 1) on the graph and on every
    # parameter leaf. The parameters may be a PyTree rather than one array, so
    # slicing the container directly would not select a particle.
    particle_g = gs_20[i:i+1]
    particle_theta = jax.tree_util.tree_map(lambda leaf: leaf[i:i+1], thetas_20)
    particle_dist = dibs.get_empirical(particle_g, particle_theta)  # Empirical for each particle

    eshd = expected_shd(dist=particle_dist, g=data.g)
    auroc = threshold_metrics(dist=particle_dist, g=data.g)['roc_auc']
    negll = neg_ave_log_likelihood(dist=particle_dist, eltwise_log_likelihood=dibs.eltwise_log_likelihood_observ, x=data.x_ho)
    svgd_per_metrics.append({'eshd': float(eshd), 'auroc': float(auroc), 'negll': float(negll)})

svgd_per_metrics  # Output the list for inspection

In [None]:
def _held_out_metrics(dist):
    """Return E-SHD, AUROC and held-out neg. LL of `dist` as plain Python floats."""
    return {
        'eshd': float(expected_shd(dist=dist, g=data.g)),
        'auroc': float(threshold_metrics(dist=dist, g=data.g)['roc_auc']),
        'negll': float(neg_ave_log_likelihood(
            dist=dist,
            eltwise_log_likelihood=dibs.eltwise_log_likelihood_observ,
            x=data.x_ho,
        )),
    }


# One metrics dict per ensemble member, for both distribution types.
deep_empirical_per_metrics = [_held_out_metrics(dist) for dist in single_empiricals]
deep_mixture_per_metrics = [_held_out_metrics(dist) for dist in single_mixtures]

# Only the last expression of a cell is displayed, so show both lists together.
{'empirical': deep_empirical_per_metrics, 'mixture': deep_mixture_per_metrics}

In [None]:
# Seed for this experiment; all keys used below are derived from it.
main_key = random.PRNGKey(42)
print(f"JAX backend: {jax.default_backend()}")

# Generate ground truth nonlinear Gaussian model
print("Generating ground truth nonlinear Gaussian model...")
key, subk = random.split(main_key)
data, graph_model, likelihood_model = make_nonlinear_gaussian_model(
    key=subk,
    n_vars=20,
    graph_prior_str="sf",
)

print(f"Ground truth graph has {np.sum(data.g)} edges")
print("Visualizing ground truth...")
try:
    visualize_ground_truth(data.g)
except Exception as err:
    # Plotting is optional, so the experiment continues. Unlike a bare
    # `except:`, this lets KeyboardInterrupt/SystemExit through and reports
    # why the plot failed.
    print("Visualization skipped (may not work in all environments)")
    print(f"  Reason: {type(err).__name__}: {err}")

# Experiment parameters
n_ensemble_runs = 20
n_particles_svgd = 20
n_steps = 2000
callback_every = 500

print("\n" + "="*60)
print("EXPERIMENT SETUP")
print(f"  Deep Ensemble: {n_ensemble_runs} runs × 1 particle each")
print(f"  SVGD: 1 run × {n_particles_svgd} particles")
print(f"  Training steps: {n_steps}")
print(f"  Variables: {data.x.shape[1]}")
print(f"  Training samples: {data.x.shape[0]}")
print(f"  Test samples: {data.x_ho.shape[0]}")
print("="*60)

# Storage for results: one dict per run, plus per-metric lists for aggregation
ensemble_results = []
ensemble_metrics = {
    'eshd_empirical': [],
    'auroc_empirical': [],
    'negll_empirical': [],
    'eshd_mixture': [],
    'auroc_mixture': [],
    'negll_mixture': [],
    'training_time': [],
}

print("\n" + "="*60)
# The header follows n_ensemble_runs instead of repeating a hard-coded count.
print(f"DEEP ENSEMBLE APPROACH ({n_ensemble_runs} runs × 1 particle)")
print("="*60)

# Deep Ensemble: independent runs with 1 particle each
for run_idx in range(n_ensemble_runs):
    print(f"\nRun {run_idx + 1}/{n_ensemble_runs}")

    # Use a different subkey for each run
    key, subk = random.split(key)

    start_time = time.time()

    # Create DiBS instance
    dibs = JointDiBS(
        x=data.x,
        interv_mask=None,
        graph_model=graph_model,
        likelihood_model=likelihood_model,
    )

    # Sample with 1 particle
    gs, thetas = dibs.sample(
        key=subk,
        n_particles=1,
        steps=n_steps,
        callback_every=callback_every,
    )

    training_time = time.time() - start_time

    # Get distributions
    dibs_empirical = dibs.get_empirical(gs, thetas)
    dibs_mixture = dibs.get_mixture(gs, thetas)

    # Compute metrics
    # Empirical
    eshd_emp = expected_shd(dist=dibs_empirical, g=data.g)
    auroc_emp = threshold_metrics(dist=dibs_empirical, g=data.g)['roc_auc']
    negll_emp = neg_ave_log_likelihood(
        dist=dibs_empirical,
        eltwise_log_likelihood=dibs.eltwise_log_likelihood_observ,
        x=data.x_ho,
    )

    # Mixture
    eshd_mix = expected_shd(dist=dibs_mixture, g=data.g)
    auroc_mix = threshold_metrics(dist=dibs_mixture, g=data.g)['roc_auc']
    negll_mix = neg_ave_log_likelihood(
        dist=dibs_mixture,
        eltwise_log_likelihood=dibs.eltwise_log_likelihood_observ,
        x=data.x_ho,
    )

    # Store results
    run_result = {
        'run_idx': run_idx,
        'eshd_empirical': eshd_emp,
        'auroc_empirical': auroc_emp,
        'negll_empirical': negll_emp,
        'eshd_mixture': eshd_mix,
        'auroc_mixture': auroc_mix,
        'negll_mixture': negll_mix,
        'training_time': training_time,
    }
    ensemble_results.append(run_result)

    # Mirror the run into the per-metric lists. Every key of ensemble_metrics
    # is also a key of run_result, so the two containers cannot drift apart.
    for metric_name, history in ensemble_metrics.items():
        history.append(run_result[metric_name])

    print(f"  Empirical - E-SHD: {eshd_emp:5.2f}, AUROC: {auroc_emp:5.3f}, NegLL: {negll_emp:6.2f}")
    print(f"  Mixture   - E-SHD: {eshd_mix:5.2f}, AUROC: {auroc_mix:5.3f}, NegLL: {negll_mix:6.2f}")
    print(f"  Time: {training_time:.1f}s")

print("\n" + "="*60)
# The header follows n_particles_svgd instead of repeating a hard-coded count.
print(f"SVGD APPROACH (1 run × {n_particles_svgd} particles)")
print("="*60)

# SVGD: 1 run in which all particles are sampled jointly
key, subk = random.split(key)

start_time = time.time()

dibs_svgd = JointDiBS(
    x=data.x,
    interv_mask=None,
    graph_model=graph_model,
    likelihood_model=likelihood_model,
)

gs_svgd, thetas_svgd = dibs_svgd.sample(
    key=subk,
    n_particles=n_particles_svgd,
    steps=n_steps,
    callback_every=callback_every,
)

svgd_training_time = time.time() - start_time

# Get distributions
svgd_empirical = dibs_svgd.get_empirical(gs_svgd, thetas_svgd)
svgd_mixture = dibs_svgd.get_mixture(gs_svgd, thetas_svgd)

# Compute metrics
# Empirical
svgd_eshd_emp = expected_shd(dist=svgd_empirical, g=data.g)
svgd_auroc_emp = threshold_metrics(dist=svgd_empirical, g=data.g)['roc_auc']
svgd_negll_emp = neg_ave_log_likelihood(
    dist=svgd_empirical,
    eltwise_log_likelihood=dibs_svgd.eltwise_log_likelihood_observ,
    x=data.x_ho,
)

# Mixture
svgd_eshd_mix = expected_shd(dist=svgd_mixture, g=data.g)
svgd_auroc_mix = threshold_metrics(dist=svgd_mixture, g=data.g)['roc_auc']
svgd_negll_mix = neg_ave_log_likelihood(
    dist=svgd_mixture,
    eltwise_log_likelihood=dibs_svgd.eltwise_log_likelihood_observ,
    x=data.x_ho,
)

print("SVGD Results:")
print(f"  Empirical - E-SHD: {svgd_eshd_emp:5.2f}, AUROC: {svgd_auroc_emp:5.3f}, NegLL: {svgd_negll_emp:6.2f}")
print(f"  Mixture   - E-SHD: {svgd_eshd_mix:5.2f}, AUROC: {svgd_auroc_mix:5.3f}, NegLL: {svgd_negll_mix:6.2f}")
print(f"  Time: {svgd_training_time:.1f}s")

banner = "=" * 60
print("\n" + banner)
print("RESULTS SUMMARY")
print(banner)


# Compute statistics for ensemble
def compute_stats(values):
    """Summarize `values` with NumPy's mean, std, min, max and median.

    Returns a dict keyed by 'mean', 'std', 'min', 'max' and 'median', in that
    order.
    """
    reducers = (
        ('mean', np.mean),
        ('std', np.std),
        ('min', np.min),
        ('max', np.max),
        ('median', np.median),
    )
    return {label: reducer(values) for label, reducer in reducers}

# Section headers follow the experiment parameters instead of hard-coded counts.
print(f"\nDEEP ENSEMBLE STATISTICS ({n_ensemble_runs} runs × 1 particle):")
print("-" * 50)

for metric_name in ['eshd_empirical', 'auroc_empirical', 'negll_empirical',
                    'eshd_mixture', 'auroc_mixture', 'negll_mixture']:
    stats = compute_stats(ensemble_metrics[metric_name])
    print(f"{metric_name:15s}: {stats['mean']:6.2f} ± {stats['std']:5.2f} "
          f"[{stats['min']:5.2f}, {stats['max']:5.2f}] (median: {stats['median']:5.2f})")

# Training time is reported separately: mean ± std per run plus the total.
training_stats = compute_stats(ensemble_metrics['training_time'])
print(f"{'training_time':15s}: {training_stats['mean']:6.1f} ± {training_stats['std']:5.1f}s "
      f"(total: {sum(ensemble_metrics['training_time']):.1f}s)")

print(f"\nSVGD RESULTS (1 run × {n_particles_svgd} particles):")
print("-" * 50)
svgd_rows = [
    ('eshd_empirical', svgd_eshd_emp),
    ('auroc_empirical', svgd_auroc_emp),
    ('negll_empirical', svgd_negll_emp),
    ('eshd_mixture', svgd_eshd_mix),
    ('auroc_mixture', svgd_auroc_mix),
    ('negll_mixture', svgd_negll_mix),
]
for row_label, row_value in svgd_rows:
    print(f"{row_label:15s}: {row_value:6.2f}")
print(f"{'training_time':15s}: {svgd_training_time:6.1f}s")

def _print_metric_comparison(heading, ensemble_values, svgd_value, spec, std_spec, better):
    """Print one ensemble-vs-SVGD comparison and return the ensemble mean.

    `spec` formats the mean, the SVGD value and (with a forced sign) their
    difference. `std_spec` formats the ensemble standard deviation. `better`
    is the word ("negative"/"positive") naming the sign that favours the
    ensemble.
    """
    ensemble_mean = np.mean(ensemble_values)
    print(heading)
    print(f"  Deep Ensemble: {ensemble_mean:{spec}} (± {np.std(ensemble_values):{std_spec}})")
    print(f"  SVGD:          {svgd_value:{spec}}")
    print(f"  Difference:    {ensemble_mean - svgd_value:+{spec}} ({better} is better for ensemble)")
    return ensemble_mean


print("\n" + "="*60)
print("COMPARISON ANALYSIS")
print("="*60)

# Compare approaches
print("\nEMPIRICAL DISTRIBUTION COMPARISON:")
print("-" * 40)
ensemble_mean_eshd_emp = _print_metric_comparison(
    "Expected SHD:", ensemble_metrics['eshd_empirical'], svgd_eshd_emp,
    "5.2f", ".2f", "negative")
ensemble_mean_auroc_emp = _print_metric_comparison(
    "\nAUROC:", ensemble_metrics['auroc_empirical'], svgd_auroc_emp,
    "5.3f", ".3f", "positive")
ensemble_mean_negll_emp = _print_metric_comparison(
    "\nNegative Log-Likelihood:", ensemble_metrics['negll_empirical'], svgd_negll_emp,
    "6.2f", ".2f", "negative")

print("\nMIXTURE DISTRIBUTION COMPARISON:")
print("-" * 40)
ensemble_mean_eshd_mix = _print_metric_comparison(
    "Expected SHD:", ensemble_metrics['eshd_mixture'], svgd_eshd_mix,
    "5.2f", ".2f", "negative")
ensemble_mean_auroc_mix = _print_metric_comparison(
    "\nAUROC:", ensemble_metrics['auroc_mixture'], svgd_auroc_mix,
    "5.3f", ".3f", "positive")
ensemble_mean_negll_mix = _print_metric_comparison(
    "\nNegative Log-Likelihood:", ensemble_metrics['negll_mixture'], svgd_negll_mix,
    "6.2f", ".2f", "negative")

# Final summary
rule = "=" * 60
print("\n" + rule)
print("FINAL SUMMARY")
print(rule)

# For every metric, record whether the ensemble mean beats the single SVGD run
# (lower is better for E-SHD and NegLL, higher is better for AUROC).
empirical_outcomes = (
    ("E-SHD", ensemble_mean_eshd_emp < svgd_eshd_emp),
    ("AUROC", ensemble_mean_auroc_emp > svgd_auroc_emp),
    ("NegLL", ensemble_mean_negll_emp < svgd_negll_emp),
)
mixture_outcomes = (
    ("E-SHD", ensemble_mean_eshd_mix < svgd_eshd_mix),
    ("AUROC", ensemble_mean_auroc_mix > svgd_auroc_mix),
    ("NegLL", ensemble_mean_negll_mix < svgd_negll_mix),
)
better_empirical = [label for label, ensemble_wins in empirical_outcomes if ensemble_wins]
better_mixture = [label for label, ensemble_wins in mixture_outcomes if ensemble_wins]

print(f"Deep Ensemble outperforms SVGD on:")
print(f"  Empirical distribution: {better_empirical or 'None'}")
print(f"  Mixture distribution:   {better_mixture or 'None'}")

# Wall-clock cost of all ensemble runs relative to the single SVGD run.
total_ensemble_time = sum(ensemble_metrics['training_time'])
time_ratio = total_ensemble_time / svgd_training_time
print(f"\nComputational efficiency:")
print(f"  Deep Ensemble total time: {total_ensemble_time:.1f}s")
print(f"  SVGD time:               {svgd_training_time:.1f}s")
print(f"  Time ratio (Ensemble/SVGD): {time_ratio:.1f}x")

print(f"\nThis comparison shows the trade-offs between:")
print(f"  - Deep Ensemble: Independent optimization, potential for diverse solutions")
print(f"  - SVGD: Particle interaction, computational efficiency, Bayesian approach")

# Save results for further analysis: the nested parts are built first and then
# assembled into one dictionary.
svgd_results = {
    'eshd_empirical': svgd_eshd_emp,
    'auroc_empirical': svgd_auroc_emp,
    'negll_empirical': svgd_negll_emp,
    'eshd_mixture': svgd_eshd_mix,
    'auroc_mixture': svgd_auroc_mix,
    'negll_mixture': svgd_negll_mix,
    'training_time': svgd_training_time,
}
experiment_params = {
    'n_ensemble_runs': n_ensemble_runs,
    'n_particles_svgd': n_particles_svgd,
    'n_steps': n_steps,
    'n_vars': data.x.shape[1],
    'n_train_samples': data.x.shape[0],
    'n_test_samples': data.x_ho.shape[0],
}
results_dict = {
    'ensemble_results': ensemble_results,
    'ensemble_metrics': ensemble_metrics,
    'svgd_results': svgd_results,
    'ground_truth_edges': np.sum(data.g),
    'experiment_params': experiment_params,
}

# Tell the reader where the collected results live.
closing_notes = (
    f"\nResults stored in 'results_dict' variable for further analysis.",
    f"Individual ensemble runs available in 'ensemble_results' list.",
    f"Ensemble aggregated metrics available in 'ensemble_metrics' dict.",
)
for note in closing_notes:
    print(note)

closing_rule = "=" * 60
print("\n" + closing_rule)
print("EXPERIMENT COMPLETED")
print(closing_rule)