# Setup and imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# notebooks/dibs_experiment.ipynb
import torch
import numpy as np
import logging
import math
import sys
import os
import mlflow

# Visualization
import matplotlib.pyplot as plt
import networkx as nx

# Add project root to the Python path
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from data.graph_data import generate_synthetic_data
from models.dibs import grad_log_joint, log_joint, hard_gmat_from_z, bernoulli_soft_gmat, update_dibs_hparams
from models.utils import acyclic_constr

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger()

# CONFIGURATION

In [3]:
class Config:
    seed = 42
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    mlflow_experiment_name = "DiBS Simple Experiment"

    # --- Data Generation ---
    # 'simple_chain' or 'synthetic'
    data_source = 'simple_chain'
    
    # Parameters for 'simple_chain'
    num_samples = 50
    obs_noise_std = 0.1

    # Parameters for 'synthetic'
    d_nodes = 4
    graph_type = 'scale-free'
    graph_params = {'p_edge': 0.70, 'm_edges':3}
    synthetic_obs_noise_std = 0.1

    # Particle and Model parameters
    k_latent = 3
    alpha_val = 0.05
    beta_val = 1.0
    tau_val = 1.0
    theta_prior_sigma_val = 1.
    n_grad_mc_samples = 64
    n_nongrad_mc_samples = 64

    # Training parameters
    lr = 0.005
    num_iterations = 1000
    debug_print_iter = 100

cfg = Config()

# Set random seed for reproducibility
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)

log.info(f"Running on device: {cfg.device}")

2025-06-23 14:04:26,723 - INFO - Running on device: cpu


# Syntetic data generation

In [4]:
# notebooks/dibs_experiment.ipynb

# ---- [Cell 3: Data Generation] ----
# Generate data based on the selected `data_source` from the configuration.

def generate_ground_truth_data_x1_x2_x3(num_samples, obs_noise_std, seed=None):
    """Generates data for the ground truth causal chain X1 -> X2 -> X3."""
    if seed is not None:
        torch.manual_seed(seed)
    D_nodes = 3
    G_true = torch.zeros(D_nodes, D_nodes, dtype=torch.float32)
    G_true[0, 1] = 1.0
    G_true[1, 2] = 1.0
    Theta_true = torch.zeros(D_nodes, D_nodes, dtype=torch.float32)
    Theta_true[0, 1] = 2.0
    Theta_true[1, 2] = -1.5
    X_data = torch.zeros(num_samples, D_nodes)
    X_data[:, 0] = torch.randn(num_samples) * obs_noise_std  # FIXME Should have variance obs_noise_std 
    noise_x2 = torch.randn(num_samples) * obs_noise_std
    X_data[:, 1] = Theta_true[0, 1] * X_data[:, 0] + noise_x2
    noise_x3 = torch.randn(num_samples) * obs_noise_std
    X_data[:, 2] = Theta_true[1, 2] * X_data[:, 1] + noise_x3
    return X_data, G_true, Theta_true

if cfg.data_source == 'simple_chain':
    log.info("Using 'simple_chain' data source.")
    data_x, graph_adj, graph_weights = generate_ground_truth_data_x1_x2_x3(
        num_samples=cfg.num_samples,
        obs_noise_std=cfg.obs_noise_std,
        seed=cfg.seed
    )
    # Update d_nodes based on the simple chain's size
    cfg.d_nodes = 3
    
elif cfg.data_source == 'synthetic':
    log.info("Using 'synthetic' data source.")
    graph_adj, graph_weights, data_x = generate_synthetic_data(
        n_samples=cfg.num_samples,
        n_nodes=cfg.d_nodes,
        graph_type=cfg.graph_type,
        graph_params=cfg.graph_params,
        noise_std=cfg.obs_noise_std
    )
else:
    raise ValueError(f"Unknown data_source: {cfg.data_source}")

data = {'x': data_x.to(cfg.device)}
log.info(f"Data generated with {cfg.d_nodes} nodes.")

2025-06-23 14:04:26,760 - INFO - Using 'simple_chain' data source.
2025-06-23 14:04:26,763 - INFO - Data generated with 3 nodes.


# MLflow tracking

In [5]:
# End any existing active run before starting a new one
if mlflow.active_run():
    mlflow.end_run()

mlflow.set_experiment(cfg.mlflow_experiment_name)
mlflow.start_run()

# Log all hyperparameters from the Config class
for param, value in vars(cfg).items():
    if not param.startswith('__') and not callable(value):
        mlflow.log_param(param, value)

log.info(f"Started MLflow run for experiment: '{cfg.mlflow_experiment_name}'")


2025-06-23 14:04:26,983 - INFO - Started MLflow run for experiment: 'DiBS Simple Experiment'


# Model initialization

In [6]:
# notebooks/dibs_experiment.ipynb

# ---- [Cell 4: Model and Particle Initialization] ----
# We initialize the learnable parameters (particles) z and theta.

def init_particle(d: int, k: int, device: str) -> dict:
    return {
        'z': torch.randn(d, k, 2, device=device),
        'theta': torch.randn(d, d, device=device)
    }

particle = init_particle(cfg.d_nodes, cfg.k_latent, cfg.device)

# Hparams dictionary, as used by the model functions
sigma_z = (1.0 / math.sqrt(cfg.k_latent))
hparams = {
    "alpha": cfg.alpha_val,
    "beta": cfg.beta_val,
    "alpha_base":cfg.alpha_val,
    "beta_base": cfg.beta_val,
    "tau": cfg.tau_val,
    "sigma_z": sigma_z,
    "sigma_obs_noise": cfg.synthetic_obs_noise_std,
    "theta_prior_sigma": cfg.theta_prior_sigma_val,
    "n_grad_mc_samples": cfg.n_grad_mc_samples,
    "n_nongrad_mc_samples": cfg.n_nongrad_mc_samples,
    "d": cfg.d_nodes,
    "debug_print_iter": cfg.debug_print_iter
}

# Training loop
Basic training loop using PyTorch optimizers with proper gradient hooking.

In [14]:
def basic_training_loop():

    def init_particle(d: int, k: int, device: str) -> dict:
        return {
            'z': torch.randn(d, k, 2, device=device),
            'theta': torch.randn(d, d, device=device)
        }

    particle = init_particle(cfg.d_nodes, cfg.k_latent, cfg.device)

    # Hparams dictionary, as used by the model functions
    sigma_z = (1.0 / math.sqrt(cfg.k_latent))
    hparams = {
    "alpha": cfg.alpha_val,
    "beta": cfg.beta_val,
    "alpha_base":cfg.alpha_val,
    "beta_base": cfg.beta_val,
    "tau": cfg.tau_val,
    "sigma_z": sigma_z,
    "sigma_obs_noise": cfg.synthetic_obs_noise_std,
    "theta_prior_sigma": cfg.theta_prior_sigma_val,
    "n_grad_mc_samples": cfg.n_grad_mc_samples,
    "n_nongrad_mc_samples": cfg.n_nongrad_mc_samples,
    "d": cfg.d_nodes,
    "debug_print_iter": cfg.debug_print_iter
}

    # Initialize PyTorch optimizers for z and theta parameters
    optimizer_z = torch.optim.RMSprop([particle['z']], lr=cfg.lr) # CHANGED Added requires_grad to tensors here
    optimizer_theta = torch.optim.RMSprop([particle['theta']], lr=cfg.lr)

    # Training loop using PyTorch optimizers with gradient hooking
    for t in range(1, cfg.num_iterations + 1):
        hparams = update_dibs_hparams(hparams, t)
        
        # Clear gradients using optimizers
        optimizer_z.zero_grad()
        optimizer_theta.zero_grad()
        
        # Set requires_grad for the particles
        particle['z'].requires_grad_(True)
        particle['theta'].requires_grad_(True)

        # Get gradients of the log-joint
        params_for_grad = {"z": particle['z'], "theta": particle['theta'], "t": torch.tensor(float(t))}
        grads = grad_log_joint(params_for_grad, data, hparams)

        # Hook gradients into PyTorch's gradient system
        # This is crucial: assign the computed gradients to .grad attributes
        #particle['z'].grad = grads['z']
        #particle['theta'].grad = grads['theta']
        
        # use gradient ascent 
        particle['z'].grad = -grads['z']
        particle['theta'].grad = -grads['theta']

        all_params = [particle['z'], particle['theta']]
        #torch.nn.utils.clip_grad_norm_(all_params, max_norm=10.0) # max_norm is a hyperparameter to tune


        # Use PyTorch optimizers to update parameters
        optimizer_z.step()
        optimizer_theta.step()

        # Logging
        if t % cfg.debug_print_iter == 0 or t == cfg.num_iterations:
            with torch.no_grad():
                # Calculate required values for logging
                lj_val = log_joint(params_for_grad, data, hparams).item()
                z_norm = torch.linalg.norm(particle['z']).item()
                theta_norm = torch.linalg.norm(particle['theta']).item()
                grad_z_norm = torch.linalg.norm(grads['z']).item()
                grad_theta_norm = torch.linalg.norm(grads['theta']).item()
                
                # Main log entry
                log.info(f"Iter {t}: Z_norm={z_norm:.4f}, Theta_norm={theta_norm:.4f}, "
                         f"log_joint={lj_val:.4f}, grad_Z_norm={grad_z_norm:.4e}, "
                         f"grad_Theta_norm={grad_theta_norm:.4e}")

                # Gradient sample
                log.info(f"      grad_Theta (sample from iter {t}):")
                # Ensure the logged tensor fits the console nicely
                grad_sample_str = str(grads['theta'][:3, :3]).replace("\n", "\n         ")
                log.info(f" {grad_sample_str}")
                
                # Annealed parameters
                log.info(f"      Annealed: alpha={hparams['alpha']:.3f}, "
                         f"beta={hparams['beta']:.3f}, tau={hparams['tau']:.3f}")
                         
                # Edge probabilities
                log.info(f"      Current Edge Probs (from Z, alpha={hparams['alpha']:.3f}):")
                edge_probs = bernoulli_soft_gmat(particle['z'], hparams)
                #Ensure the logged tensor fits the console nicely
                edge_probs_str = str(edge_probs).replace("\n", "\n         ")
                log.info(f" {edge_probs_str}")

                # Log to MLflow
                mlflow.log_metric("log_joint", lj_val, step=t)
    return particle,hparams


particle, hparams_final = basic_training_loop()


tensor(0.5260, grad_fn=<MeanBackward1>)
tensor(0.4583, grad_fn=<MeanBackward1>)
tensor(0.4948, grad_fn=<MeanBackward1>)
tensor(0.4983, grad_fn=<MeanBackward1>)
tensor(0.5625, grad_fn=<MeanBackward1>)
tensor(0.5712, grad_fn=<MeanBackward1>)
tensor(0.5885, grad_fn=<MeanBackward1>)
tensor(0.6580, grad_fn=<MeanBackward1>)
tensor(0.5486, grad_fn=<MeanBackward1>)
tensor(0.4931, grad_fn=<MeanBackward1>)
tensor(0.5122, grad_fn=<MeanBackward1>)
tensor(0.5781, grad_fn=<MeanBackward1>)
tensor(0.4236, grad_fn=<MeanBackward1>)
tensor(0.6215, grad_fn=<MeanBackward1>)
tensor(0.5677, grad_fn=<MeanBackward1>)
tensor(0.5243, grad_fn=<MeanBackward1>)
tensor(0.5608, grad_fn=<MeanBackward1>)
tensor(0.5191, grad_fn=<MeanBackward1>)
tensor(0.4722, grad_fn=<MeanBackward1>)
tensor(0.4705, grad_fn=<MeanBackward1>)
tensor(0.3941, grad_fn=<MeanBackward1>)
tensor(0.6545, grad_fn=<MeanBackward1>)
tensor(0.5122, grad_fn=<MeanBackward1>)
tensor(0.4792, grad_fn=<MeanBackward1>)
tensor(0.7569, grad_fn=<MeanBackward1>)


2025-06-23 14:26:48,889 - INFO - Iter 100: Z_norm=2.4661, Theta_norm=3.7427, log_joint=-167.4289, grad_Z_norm=3.5090e+00, grad_Theta_norm=6.1905e+00
2025-06-23 14:26:48,890 - INFO -       grad_Theta (sample from iter 100):
2025-06-23 14:26:48,892 - INFO -  tensor([[ 0.0000e+00,  6.0288e+00, -8.2731e-05],
                 [ 8.9127e-03,  0.0000e+00,  0.0000e+00],
                 [-4.6622e-05, -1.4055e+00,  0.0000e+00]], requires_grad=True)
2025-06-23 14:26:48,893 - INFO -       Annealed: alpha=0.488, beta=190.325, tau=1.000
2025-06-23 14:26:48,894 - INFO -       Current Edge Probs (from Z, alpha=0.488):
2025-06-23 14:26:48,895 - INFO -  tensor([[0.0000, 0.4841, 0.6119],
                 [0.5331, 0.0000, 0.2235],
                 [0.5110, 0.5045, 0.0000]])


tensor(0.5017)
tensor(0.5330, grad_fn=<MeanBackward1>)
tensor(0.5313, grad_fn=<MeanBackward1>)
tensor(0.5434, grad_fn=<MeanBackward1>)
tensor(0.5885, grad_fn=<MeanBackward1>)
tensor(0.6406, grad_fn=<MeanBackward1>)
tensor(0.6441, grad_fn=<MeanBackward1>)
tensor(0.4410, grad_fn=<MeanBackward1>)
tensor(0.4358, grad_fn=<MeanBackward1>)
tensor(0.4549, grad_fn=<MeanBackward1>)
tensor(0.5191, grad_fn=<MeanBackward1>)
tensor(0.4722, grad_fn=<MeanBackward1>)
tensor(0.5660, grad_fn=<MeanBackward1>)
tensor(0.5156, grad_fn=<MeanBackward1>)
tensor(0.5000, grad_fn=<MeanBackward1>)
tensor(0.5642, grad_fn=<MeanBackward1>)
tensor(0.4948, grad_fn=<MeanBackward1>)
tensor(0.4375, grad_fn=<MeanBackward1>)
tensor(0.4149, grad_fn=<MeanBackward1>)
tensor(0.5208, grad_fn=<MeanBackward1>)
tensor(0.4861, grad_fn=<MeanBackward1>)
tensor(0.5139, grad_fn=<MeanBackward1>)
tensor(0.4792, grad_fn=<MeanBackward1>)
tensor(0.5000, grad_fn=<MeanBackward1>)
tensor(0.4392, grad_fn=<MeanBackward1>)
tensor(0.4601, grad_fn=<M

2025-06-23 14:27:00,253 - INFO - Iter 200: Z_norm=2.7043, Theta_norm=3.7395, log_joint=-124.3563, grad_Z_norm=3.8265e+00, grad_Theta_norm=5.4608e+00
2025-06-23 14:27:00,254 - INFO -       grad_Theta (sample from iter 200):
2025-06-23 14:27:00,256 - INFO -  tensor([[ 0.0000e+00,  2.8589e+00, -1.4544e-05],
                 [ 1.2834e-05,  0.0000e+00,  0.0000e+00],
                 [-1.7558e-06,  4.6526e+00,  0.0000e+00]], requires_grad=True)
2025-06-23 14:27:00,256 - INFO -       Annealed: alpha=0.952, beta=362.538, tau=1.000
2025-06-23 14:27:00,257 - INFO -       Current Edge Probs (from Z, alpha=0.952):
2025-06-23 14:27:00,258 - INFO -  tensor([[0.0000, 0.4755, 0.8494],
                 [0.4637, 0.0000, 0.0523],
                 [0.4970, 0.5048, 0.0000]])


tensor(0.4601)
tensor(0.4826, grad_fn=<MeanBackward1>)
tensor(0.6042, grad_fn=<MeanBackward1>)
tensor(0.5382, grad_fn=<MeanBackward1>)
tensor(0.5295, grad_fn=<MeanBackward1>)
tensor(0.3924, grad_fn=<MeanBackward1>)
tensor(0.4844, grad_fn=<MeanBackward1>)
tensor(0.4861, grad_fn=<MeanBackward1>)
tensor(0.6181, grad_fn=<MeanBackward1>)
tensor(0.4497, grad_fn=<MeanBackward1>)
tensor(0.6128, grad_fn=<MeanBackward1>)
tensor(0.5226, grad_fn=<MeanBackward1>)
tensor(0.5000, grad_fn=<MeanBackward1>)
tensor(0.3941, grad_fn=<MeanBackward1>)
tensor(0.4514, grad_fn=<MeanBackward1>)
tensor(0.4826, grad_fn=<MeanBackward1>)
tensor(0.4340, grad_fn=<MeanBackward1>)
tensor(0.4635, grad_fn=<MeanBackward1>)
tensor(0.5208, grad_fn=<MeanBackward1>)
tensor(0.5382, grad_fn=<MeanBackward1>)
tensor(0.4462, grad_fn=<MeanBackward1>)
tensor(0.5382, grad_fn=<MeanBackward1>)
tensor(0.4878, grad_fn=<MeanBackward1>)
tensor(0.4097, grad_fn=<MeanBackward1>)
tensor(0.3993, grad_fn=<MeanBackward1>)
tensor(0.4410, grad_fn=<M

2025-06-23 14:27:11,280 - INFO - Iter 300: Z_norm=2.6345, Theta_norm=3.7548, log_joint=-192.0087, grad_Z_norm=1.7454e+00, grad_Theta_norm=1.5416e+00
2025-06-23 14:27:11,282 - INFO -       grad_Theta (sample from iter 300):
2025-06-23 14:27:11,283 - INFO -  tensor([[ 0.0000e+00,  1.1254e+00, -1.4544e-05],
                 [ 4.4132e-06,  0.0000e+00,  0.0000e+00],
                 [-7.5690e-07,  1.0536e+00,  0.0000e+00]], requires_grad=True)
2025-06-23 14:27:11,283 - INFO -       Annealed: alpha=1.393, beta=518.364, tau=1.000
2025-06-23 14:27:11,284 - INFO -       Current Edge Probs (from Z, alpha=1.393):
2025-06-23 14:27:11,285 - INFO -  tensor([[0.0000, 0.6819, 0.9339],
                 [0.5305, 0.0000, 0.0250],
                 [0.4933, 0.5581, 0.0000]])


tensor(0.4948)
tensor(0.5729, grad_fn=<MeanBackward1>)
tensor(0.5712, grad_fn=<MeanBackward1>)
tensor(0.6858, grad_fn=<MeanBackward1>)
tensor(0.6372, grad_fn=<MeanBackward1>)
tensor(0.4167, grad_fn=<MeanBackward1>)
tensor(0.5677, grad_fn=<MeanBackward1>)
tensor(0.6042, grad_fn=<MeanBackward1>)
tensor(0.6736, grad_fn=<MeanBackward1>)
tensor(0.6302, grad_fn=<MeanBackward1>)
tensor(0.5156, grad_fn=<MeanBackward1>)
tensor(0.6875, grad_fn=<MeanBackward1>)
tensor(0.5764, grad_fn=<MeanBackward1>)
tensor(0.5521, grad_fn=<MeanBackward1>)
tensor(0.5712, grad_fn=<MeanBackward1>)
tensor(0.6111, grad_fn=<MeanBackward1>)
tensor(0.6076, grad_fn=<MeanBackward1>)
tensor(0.5382, grad_fn=<MeanBackward1>)
tensor(0.6094, grad_fn=<MeanBackward1>)
tensor(0.5434, grad_fn=<MeanBackward1>)
tensor(0.5035, grad_fn=<MeanBackward1>)
tensor(0.5903, grad_fn=<MeanBackward1>)
tensor(0.5955, grad_fn=<MeanBackward1>)
tensor(0.5382, grad_fn=<MeanBackward1>)
tensor(0.6302, grad_fn=<MeanBackward1>)
tensor(0.4861, grad_fn=<M

2025-06-23 14:27:22,386 - INFO - Iter 400: Z_norm=2.4669, Theta_norm=3.7646, log_joint=-331.9135, grad_Z_norm=5.3650e+00, grad_Theta_norm=1.5851e-01
2025-06-23 14:27:22,387 - INFO -       grad_Theta (sample from iter 400):
2025-06-23 14:27:22,389 - INFO -  tensor([[ 0.0000e+00,  1.1720e-01, -2.3842e-06],
                 [ 4.7528e-06,  0.0000e+00,  0.0000e+00],
                 [-5.2687e-07,  1.0673e-01,  0.0000e+00]], requires_grad=True)
2025-06-23 14:27:22,390 - INFO -       Annealed: alpha=1.813, beta=659.360, tau=1.000
2025-06-23 14:27:22,391 - INFO -       Current Edge Probs (from Z, alpha=1.813):
2025-06-23 14:27:22,393 - INFO -  tensor([[0.0000, 0.8119, 0.9493],
                 [0.6028, 0.0000, 0.0274],
                 [0.4527, 0.6900, 0.0000]])


tensor(0.6215)
tensor(0.6024, grad_fn=<MeanBackward1>)
tensor(0.6128, grad_fn=<MeanBackward1>)
tensor(0.5868, grad_fn=<MeanBackward1>)
tensor(0.6181, grad_fn=<MeanBackward1>)
tensor(0.5816, grad_fn=<MeanBackward1>)
tensor(0.7083, grad_fn=<MeanBackward1>)
tensor(0.5764, grad_fn=<MeanBackward1>)
tensor(0.6302, grad_fn=<MeanBackward1>)
tensor(0.6441, grad_fn=<MeanBackward1>)
tensor(0.6597, grad_fn=<MeanBackward1>)
tensor(0.5469, grad_fn=<MeanBackward1>)
tensor(0.6285, grad_fn=<MeanBackward1>)
tensor(0.5208, grad_fn=<MeanBackward1>)
tensor(0.5486, grad_fn=<MeanBackward1>)
tensor(0.7778, grad_fn=<MeanBackward1>)
tensor(0.6354, grad_fn=<MeanBackward1>)
tensor(0.7049, grad_fn=<MeanBackward1>)
tensor(0.5972, grad_fn=<MeanBackward1>)
tensor(0.5990, grad_fn=<MeanBackward1>)
tensor(0.6215, grad_fn=<MeanBackward1>)
tensor(0.6788, grad_fn=<MeanBackward1>)


KeyboardInterrupt: 

In [8]:
print("Final soft matrix (2 decimals):")
soft_gmat = bernoulli_soft_gmat(particle['z'], hparams_final)
soft_gmat_rounded = torch.round(soft_gmat * 100) / 100
print(soft_gmat_rounded)

print(f"\nHard graph from Z:")
hard_graph = hard_gmat_from_z(particle['z'], hparams_final['alpha'])
print(hard_graph)

print(f"\nLearned Theta:")
print(particle['theta'])

print(f"\nLearned Theta * Learned Graph:")
weighted_learned = hard_graph * particle['theta']
print(weighted_learned)

print(f"\nTrue graph (graph_adj):")
print(graph_adj)

print(f"\nTrue weights (graph_weights):")
print(graph_weights)

Final soft matrix (2 decimals):
tensor([[0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.9200],
        [0.0100, 0.0000, 0.0000]])

Hard graph from Z:
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

Learned Theta:
tensor([[-1.4723,  1.9488,  0.5655],
        [-0.9521, -0.2732, -1.4886],
        [-0.2534, -0.2550, -0.4564]], requires_grad=True)

Learned Theta * Learned Graph:
tensor([[-0.0000,  1.9488,  0.0000],
        [-0.0000, -0.0000, -1.4886],
        [-0.0000, -0.0000, -0.0000]], grad_fn=<MulBackward0>)

True graph (graph_adj):
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

True weights (graph_weights):
tensor([[ 0.0000,  2.0000,  0.0000],
        [ 0.0000,  0.0000, -1.5000],
        [ 0.0000,  0.0000,  0.0000]])


# Results

In [9]:

print("=========== Final Results ===========")
# print("Final probs:")
print("Edge probabilities:")
edge_probs = bernoulli_soft_gmat(particle['z'], hparams)
print(edge_probs)


# End the MLflow run
mlflow.end_run()
log.info("MLflow run finished and artifacts logged.")

2025-06-23 14:06:10,054 - INFO - MLflow run finished and artifacts logged.


Edge probabilities:
tensor([[0.0000, 0.5199, 0.4791],
        [0.4691, 0.0000, 0.5078],
        [0.4866, 0.4709, 0.0000]])


# Enhanced traning loop

In [10]:
def enchaced_training_loop():
    # ========================================================================
    # ENHANCED TRAINING LOOP WITH COMPREHENSIVE MONITORING
    # ========================================================================
    # This version extends the basic training loop with:
    # 1. Edge probability tracking during training
    # 2. Parameter norm monitoring (Z and Theta norms)
    # 3. Gradient norm computation and optional clipping
    # 4. Structural Hamming Distance (SHD) computation
    # 5. Enhanced MLflow logging with multiple metrics
    # 6. Numerical stability checks (NaN/Inf detection)
    # 7. Weighted theta matrix visualization (G * Theta)
    # 8. PyTorch optimizers with proper gradient hooking

    log.info("\n" + "="*80)
    log.info("STARTING ENHANCED TRAINING LOOP WITH DETAILED MONITORING")
    log.info("="*80)

    # --- Enhanced Training Configuration ---
    # Reset particles for a clean experiment
    particle_enhanced = init_particle(cfg.d_nodes, cfg.k_latent, cfg.device)

    # Enhanced training hyperparameters
    num_iterations_enhanced = cfg.num_iterations  # Match the final iteration from your output
    lr_z_enhanced = cfg.lr          # Separate learning rate for Z
    lr_theta_enhanced = cfg.lr      # Separate learning rate for Theta  
    max_grad_norm = 10000       # Gradient clipping threshold
    logging_interval = 50          # Log every N iterations (more frequent for better monitoring)

    # Initialize PyTorch optimizers for enhanced training
    optimizer_z_enhanced = torch.optim.Adam([particle_enhanced['z']], lr=lr_z_enhanced)
    optimizer_theta_enhanced = torch.optim.Adam([particle_enhanced['theta']], lr=lr_theta_enhanced)

    # Log initial configuration
    log.info(f"Enhanced Training Configuration:")
    log.info(f"  Iterations: {num_iterations_enhanced}")
    log.info(f"  Learning rates - Z: {lr_z_enhanced}, Theta: {lr_theta_enhanced}")
    log.info(f"  Gradient clipping threshold: {max_grad_norm}")
    log.info(f"  Logging interval: {logging_interval}")
    log.info(f"  Initial Z norm: {particle_enhanced['z'].norm().item():.4f}")
    log.info(f"  Initial Theta norm: {particle_enhanced['theta'].norm().item():.4f}")

    # --- Main Enhanced Training Loop ---
    for t in range(1, num_iterations_enhanced + 1):
        hparams = update_dibs_hparams(hparams, t)
        
        # ------------------------------------------------
        # STEP 1: Parameter Setup and Gradient Clearing
        # ------------------------------------------------
        # Clear gradients using optimizers
        optimizer_z_enhanced.zero_grad()
        optimizer_theta_enhanced.zero_grad()
        
        # Enable gradient computation for both Z and Theta parameters
        particle_enhanced['z'].requires_grad_(True)
        particle_enhanced['theta'].requires_grad_(True)
        
        # ------------------------------------------------
        # STEP 2: Forward Pass - Compute Log-Joint and Gradients
        # ------------------------------------------------
        # Prepare parameters dictionary for gradient computation
        params_for_grad = {
            "z": particle_enhanced['z'], 
            "theta": particle_enhanced['theta'], 
            "t": torch.tensor(float(t))  # Time step for annealing
        }
        
        try:
            # Compute log-joint probability (objective function)
            lj_val = log_joint(params_for_grad, data, hparams).item()
            
            # Compute gradients of log-joint w.r.t. Z and Theta
            grads = grad_log_joint(params_for_grad, data, hparams)
            grad_z = grads['z']
            grad_theta = grads['theta']
            
            # ------------------------------------------------
            # STEP 3: Gradient Analysis and Clipping
            # ------------------------------------------------
            # Compute gradient norms BEFORE clipping for monitoring
            grad_z_norm_original = grad_z.norm().item()
            grad_theta_norm_original = grad_theta.norm().item()
            
            # Apply gradient clipping if gradients exceed threshold
            grad_z_clipped = False
            grad_theta_clipped = False
            
            if grad_z_norm_original > max_grad_norm:
                grad_z = grad_z * (max_grad_norm / grad_z_norm_original)
                grad_z_clipped = True
                
            if grad_theta_norm_original > max_grad_norm:
                grad_theta = grad_theta * (max_grad_norm / grad_theta_norm_original)
                grad_theta_clipped = True
                
            # Compute final gradient norms AFTER clipping
            grad_z_norm_final = grad_z.norm().item()
            grad_theta_norm_final = grad_theta.norm().item()
                
        except Exception as e:
            log.error(f"Error in forward pass at iteration {t}: {e}")
            break
        
        # ------------------------------------------------
        # STEP 4: Parameter Update using PyTorch Optimizers
        # ------------------------------------------------
        # Hook gradients into PyTorch's gradient system
        # This is crucial: assign the computed gradients to .grad attributes
        particle_enhanced['z'].grad = grad_z
        particle_enhanced['theta'].grad = grad_theta
        
        # Use PyTorch optimizers to update parameters
        optimizer_z_enhanced.step()
        optimizer_theta_enhanced.step()
        
        # ------------------------------------------------
        # STEP 5: Comprehensive Logging and Monitoring
        # ------------------------------------------------
        if t % logging_interval == 0 or t == 1 or t == num_iterations_enhanced:
            
            # Compute current parameter norms
            z_norm = particle_enhanced['z'].norm().item()
            theta_norm = particle_enhanced['theta'].norm().item()
            
            # ------------------------------------------------
            # STEP 5a: Edge Probability Analysis
            # ------------------------------------------------
            with torch.no_grad():
                # Soft edge probabilities (continuous values)
                edge_probs = bernoulli_soft_gmat(particle_enhanced['z'], hparams).detach().cpu()
                
                # Hard graph (binary adjacency matrix after thresholding)
                hard_graph = hard_gmat_from_z(particle_enhanced['z'], hparams['alpha']).detach().cpu()
                
                # Weighted theta matrix (element-wise multiplication of G and Theta)
                theta_cpu = particle_enhanced['theta'].detach().cpu()
                weighted_theta = hard_graph * theta_cpu
            
            # ------------------------------------------------
            # STEP 5b: Ground Truth Comparison
            # ------------------------------------------------
            # Compute Structural Hamming Distance (SHD) if ground truth is available
            shd = float('nan')
            if 'graph_adj' in locals():
                # SHD = number of edge differences between learned and true graph
                shd = torch.sum(torch.abs(hard_graph.int() - graph_adj.int())).item()
            
            # ------------------------------------------------
            # STEP 5c: Concise Console Logging (Updated Format)
            # ------------------------------------------------
            # Implement beta annealing (beta increases with iteration for better convergence)
            current_beta = cfg.beta_val + t * 0.001  # Annealing: beta increases over time
            hparams['beta'] = current_beta  # Update beta in hparams
            
            # Concise logging format matching the target output
            log.info(f"Iter {t}: Z_norm={z_norm:.4f}, Theta_norm={theta_norm:.4f}, log_joint={lj_val:.4f}, grad_Z_norm={grad_z_norm_original:.4e}, grad_Theta_norm={grad_theta_norm_original:.4e}")
            
            # Show a sample of the current grad_Theta matrix
            log.info(f"    grad_Theta (sample from iter {t}):")
            log.info(f"{grad_theta.detach().cpu()}")
            
            # Show annealed hyperparameters
            current_alpha = hparams.get('alpha', cfg.alpha_val)
            current_tau = hparams.get('tau', cfg.tau_val)
            log.info(f"    Annealed: alpha={current_alpha:.3f}, beta={current_beta:.3f}, tau={current_tau:.3f}")
            
            # Show current edge probabilities
            log.info(f"    Current Edge Probs (from Z, alpha={current_alpha:.3f}):")
            log.info(f"{edge_probs}")
            
            # Edge probability statistics for MLflow logging
            max_edge_prob = edge_probs.max().item()
            mean_edge_prob = edge_probs.mean().item()
            num_edges_hard = hard_graph.sum().item()
            
            # Ground truth comparison
            if not math.isnan(shd):
                pass  # Skip detailed SHD logging during training for cleaner output
            
            # ------------------------------------------------
            # STEP 5e: Enhanced MLflow Logging
            # ------------------------------------------------
            # Log all metrics to MLflow for experiment tracking
            mlflow.log_metric("enhanced_log_joint", lj_val, step=t)
            mlflow.log_metric("z_norm", z_norm, step=t)
            mlflow.log_metric("theta_norm", theta_norm, step=t)
            mlflow.log_metric("grad_z_norm_original", grad_z_norm_original, step=t)
            mlflow.log_metric("grad_theta_norm_original", grad_theta_norm_original, step=t)
            mlflow.log_metric("grad_z_norm_final", grad_z_norm_final, step=t)
            mlflow.log_metric("grad_theta_norm_final", grad_theta_norm_final, step=t)
            mlflow.log_metric("max_edge_prob", max_edge_prob, step=t)
            mlflow.log_metric("mean_edge_prob", mean_edge_prob, step=t)
            mlflow.log_metric("num_hard_edges", num_edges_hard, step=t)
            
            if not math.isnan(shd):
                mlflow.log_metric("structural_hamming_distance", shd, step=t)
            
            # Log boolean indicators as metrics
            mlflow.log_metric("z_gradient_clipped", float(grad_z_clipped), step=t)
            mlflow.log_metric("theta_gradient_clipped", float(grad_theta_clipped), step=t)
            
            # ------------------------------------------------
            # STEP 5f: Numerical Stability Checks
            # ------------------------------------------------
            # Check for NaN or Inf values that could break training
            z_has_nan = torch.isnan(particle_enhanced['z']).any()
            z_has_inf = torch.isinf(particle_enhanced['z']).any()
            theta_has_nan = torch.isnan(particle_enhanced['theta']).any()
            theta_has_inf = torch.isinf(particle_enhanced['theta']).any()
            
            if z_has_nan or theta_has_nan:
                log.error("!!! NaN DETECTED IN PARAMETERS - STOPPING TRAINING !!!")
                log.error(f"Z has NaN: {z_has_nan}, Theta has NaN: {theta_has_nan}")
                break
                
            if z_has_inf or theta_has_inf:
                log.error("!!! INFINITY DETECTED IN PARAMETERS - STOPPING TRAINING !!!")
                log.error(f"Z has Inf: {z_has_inf}, Theta has Inf: {theta_has_inf}")
                break

    # ========================================================================
    # ENHANCED TRAINING COMPLETION AND FINAL ANALYSIS
    # ========================================================================

    log.info("\n" + "="*80)
    log.info("ENHANCED TRAINING LOOP COMPLETED")
    log.info("="*80)

    # Compute final graph structures
    with torch.no_grad():
        final_edge_probs_enhanced = bernoulli_soft_gmat(particle_enhanced['z'], hparams).detach().cpu()
        final_hard_graph_enhanced = hard_gmat_from_z(particle_enhanced['z'], hparams['alpha']).detach().cpu()
        final_theta_enhanced = particle_enhanced['theta'].detach().cpu()
        final_weighted_theta_enhanced = final_hard_graph_enhanced * final_theta_enhanced

    # Final comparison with ground truth - matching the target output format
    log.info("\n        --- Comparison with Ground Truth ---")
    log.info(f"Final G_learned_hard:")
    log.info(f"{final_hard_graph_enhanced.int()}")
    log.info("**************************************************")

    # Create final weighted matrix using learned hard graph and learned theta
    log.info(f"Final G_learned_hard * Theta_learned:")
    log.info(f"{final_weighted_theta_enhanced}")

    # Additional analysis - show ground truth comparison if available
    if 'graph_adj' in locals():
        final_shd = torch.sum(torch.abs(final_hard_graph_enhanced.int() - graph_adj.int())).item()
        log.info(f"\nFinal Structural Hamming Distance: {final_shd}")
        
        log.info(f"\nDetailed Comparison:")
        log.info(f"Ground Truth Graph:")
        log.info(f"{graph_adj.int()}")
        log.info(f"Learned Graph:")
        log.info(f"{final_hard_graph_enhanced.int()}")
        
        if 'graph_weights' in locals():
            log.info(f"Ground Truth Theta:")
            log.info(f"{graph_weights}")
            log.info(f"Learned Theta:")
            log.info(f"{final_theta_enhanced}")
            
            # Show weighted matrices comparison
            ground_truth_weighted = graph_adj * graph_weights
            log.info(f"Ground Truth G * Theta:")
            log.info(f"{ground_truth_weighted}")

    log.info("\nEnhanced training analysis complete!")

# Training Methods Comparison
Compare results between the basic training loop and the enhanced training loop.

# Results

In [11]:
# ========================================================================
# FINAL RESULTS: Matrix Outputs (Enhanced Training)
# ========================================================================
# Display the final learned matrices from the enhanced training loop

log.info("\n" + "="*80)
log.info("FINAL RESULTS FROM ENHANCED TRAINING")
log.info("="*80)

# Use results from enhanced training loop
final_graph_enhanced = hard_gmat_from_z(particle_enhanced['z'], hparams['alpha']).detach().cpu()
edge_probs_enhanced = bernoulli_soft_gmat(particle_enhanced['z'], hparams).detach().cpu()
theta_enhanced = particle_enhanced['theta'].detach().cpu()
weighted_theta_enhanced = final_graph_enhanced * theta_enhanced

log.info(f"\n1. LEARNED EDGE PROBABILITIES (Soft Graph):")
log.info(f"{edge_probs_enhanced}")

log.info(f"\n2. LEARNED HARD GRAPH (Binary Adjacency Matrix):")
log.info(f"{final_graph_enhanced.int()}")

log.info(f"\n3. LEARNED THETA MATRIX (Edge Weights):")
log.info(f"{theta_enhanced}")

log.info(f"\n4. FINAL WEIGHTED GRAPH (G ⊙ Theta):")
log.info(f"{weighted_theta_enhanced}")

# Ground truth comparison
if 'graph_adj' in locals() and 'graph_weights' in locals():
    log.info(f"\n" + "-"*60)
    log.info("GROUND TRUTH COMPARISON")
    log.info("-"*60)
    
    log.info(f"\nGround Truth Graph:")
    log.info(f"{graph_adj.int()}")
    
    log.info(f"\nGround Truth Weights:")
    log.info(f"{graph_weights}")
    
    ground_truth_weighted = graph_adj * graph_weights
    log.info(f"\nGround Truth Weighted (G_true ⊙ Theta_true):")
    log.info(f"{ground_truth_weighted}")
    
    # Compute final metrics
    final_shd = torch.sum(torch.abs(final_graph_enhanced.int() - graph_adj.int())).item()
    log.info(f"\nStructural Hamming Distance (SHD): {final_shd}")
    
    # Check if structure is correctly recovered
    structure_match = torch.equal(final_graph_enhanced.int(), graph_adj.int())
    log.info(f"Perfect Structure Recovery: {structure_match}")

# MLflow logging for final results
mlflow.log_metric("final_z_norm", particle_enhanced['z'].norm().item())
mlflow.log_metric("final_theta_norm", particle_enhanced['theta'].norm().item())
mlflow.log_metric("final_max_edge_prob", edge_probs_enhanced.max().item())
mlflow.log_metric("final_mean_edge_prob", edge_probs_enhanced.mean().item())
mlflow.log_metric("final_num_edges", final_graph_enhanced.sum().item())

if 'graph_adj' in locals():
    mlflow.log_metric("final_shd", final_shd)
    mlflow.log_metric("perfect_structure_recovery", float(structure_match))

# End the MLflow run
mlflow.end_run()
log.info("\nFinal results logging complete and MLflow run ended!")

2025-06-23 14:06:10,164 - INFO - 
2025-06-23 14:06:10,166 - INFO - FINAL RESULTS FROM ENHANCED TRAINING


NameError: name 'particle_enhanced' is not defined

# LOGSUMEXP ISSUE

In [None]:
# Test function for manual_stable_gradient2 with Gaussian log densities
import torch
import torch.nn.functional as F




def logsumexp_v1(log_tensor: torch.Tensor) -> torch.Tensor:


    M = log_tensor.shape[0]
    logM = torch.log(torch.tensor(M, dtype=log_tensor.dtype, device=log_tensor.device))

    
    log_sum_exp = torch.logsumexp(log_tensor, dim=0)

    total = log_sum_exp - logM
    return total # torch.exp(total)

def manual_stable_gradient(log_p_tensor: torch.Tensor, grad_p_tensor: torch.Tensor) -> torch.Tensor:
# uses the logsumexp_v1 function to compute the stable gradient

    print(f'log density values and shape: {log_p_tensor}, {log_p_tensor.shape}')
    log_density_lse = torch.exp(torch.logsumexp(log_p_tensor, dim=0) - log_p_tensor.shape[0])  # logsumexp_v1(log_p_tensor)
    # logsumexp_v1(log_p_tensor)
    print(f'log density lse value: {log_density_lse}, shape: {log_density_lse.shape}')

    print('-' * 50)
    print(f'grad density values and shape: {grad_p_tensor}, {grad_p_tensor.shape}')
    grad_lse = logsumexp_v1(grad_p_tensor)
    print(f'grad density lse value: {grad_lse}, shape: {grad_lse.shape}')

    return torch.exp(logsumexp_v1(grad_p_tensor) - logsumexp_v1(log_p_tensor)) # grad_lse / log_density_lse[:, None]





def test_softmax_vs_manual_logsumexp():
    """
    Compare softmax weights with manual logsumexp computation to verify equivalence
    """
    print("="*60)
    print("COMPARING SOFTMAX VS MANUAL LOGSUMEXP")
    print("="*60)
    
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Parameters
    n_samples = 5
    theta_shape = (3, 3)
    
    # Generate log densities from normal distribution
    log_densities = torch.randn(n_samples) * 2.0 - 10.0  # Mean around -10, std=2
    print(f"Log densities: {log_densities}")
    
    # Generate some gradients
    grad_samples = torch.randn(n_samples, *theta_shape)
    print(f"Gradient samples shape: {grad_samples.shape}")
    
    print(f"\n" + "-"*50)
    print("METHOD 1: USING PYTORCH SOFTMAX")
    print("-"*50)
    
    # Method 1: Using PyTorch's built-in softmax
    weights_softmax = F.softmax(log_densities, dim=0)
    weighted_grad_softmax = torch.sum(weights_softmax.view(-1, 1, 1) * grad_samples, dim=0)
    
    print(f"Softmax weights: {weights_softmax}")
    print(f"Sum of softmax weights: {weights_softmax.sum()}")
    print(f"Softmax weighted gradient:\n{weighted_grad_softmax}")
    
    print(f"\n" + "-"*50)

    print("METHOD 2: USING LOGSUMEXP_V1 FUNCTION")
    print("-"*50)
    
    # Method 3: Using your logsumexp_v1 function
    log_sum_exp_v1 = logsumexp_v1(log_densities)
    print(f"logsumexp_v1 result: {log_sum_exp_v1}")
    
    # Convert to weights
    weights_v1 = torch.exp(log_densities - log_sum_exp_v1)
    print(f"Weights from logsumexp_v1: {weights_v1}")
    print(f"Sum of logsumexp_v1 weights: {weights_v1.sum()}")
    
    # Weighted gradient using logsumexp_v1
    weighted_grad_v1 = torch.sum(weights_v1.view(-1, 1, 1) * grad_samples, dim=0)
    print(f"logsumexp_v1 weighted gradient:\n{weighted_grad_v1}")
    
    print(f"\n" + "-"*50)
    print("COMPARISON RESULTS")
    print("-"*50)
    
    # Compare weights
    weights_diff_softmax_manual = torch.abs(weights_softmax - weights_manual).max()
    weights_diff_softmax_v1 = torch.abs(weights_softmax - weights_v1).max()
    
    print(f"Max difference between softmax and manual weights: {weights_diff_softmax_manual:.10f}")
    print(f"Max difference between softmax and logsumexp_v1 weights: {weights_diff_softmax_v1:.10f}")
    
    # Compare gradients
    grad_diff_softmax_manual = torch.abs(weighted_grad_softmax - weighted_grad_manual).max()
    grad_diff_softmax_v1 = torch.abs(weighted_grad_softmax - weighted_grad_v1).max()
    
    print(f"Max difference between softmax and manual gradients: {grad_diff_softmax_manual:.10f}")
    print(f"Max difference between softmax and logsumexp_v1 gradients: {grad_diff_softmax_v1:.10f}")
    
    # Check if they're essentially equal
    tolerance = 1e-6
    softmax_manual_equal = weights_diff_softmax_manual < tolerance
    softmax_v1_equal = weights_diff_softmax_v1 < tolerance
    
    print(f"\nAre softmax and manual weights equal (tol={tolerance})? {softmax_manual_equal}")
    print(f"Are softmax and logsumexp_v1 weights equal (tol={tolerance})? {softmax_v1_equal}")

# Run the comparison test
#test_softmax_vs_manual_logsumexp()

In [None]:
import torch
import torch.nn.functional as F

# ---------- numerically stable, vector-valued estimator ----------
def weighted_grad(log_p: torch.Tensor,
                  grad_p: torch.Tensor) -> torch.Tensor:
    """
    Return  ∑ softmax(log_p)_m * grad_p[m]
    Shapes
        log_p  : (M,)
        grad_p : (M, …)
    """
    w = torch.softmax(log_p, dim=0)           # (M,)
    while w.dim() < grad_p.dim():             # make w broadcastable
        w = w.unsqueeze(-1)
    return (w * grad_p).sum(dim=0)
# ----------------------------------------------------------------


# ------------------ quick comparison against softmax ------------
torch.manual_seed(42)
M, shape = 5, (3, 3)

log_p   = torch.randn(M) * 2.0 - 10.0               # log densities
grad_p  = torch.randn(M, *shape)                    # per-sample grads

w_soft  = F.softmax(log_p, dim=0)                   # reference weights
grad_sm = (w_soft.unsqueeze(-1).unsqueeze(-1) * grad_p).sum(0)

grad_wg = weighted_grad(log_p, grad_p)

print("log densities:\n", log_p, "\n")
print("softmax weights  :", w_soft)
print("weighted_grad wts:", torch.softmax(log_p, 0))        # same
print("‖grad diff‖_∞     :", (grad_sm - grad_wg).abs().max())
