# Model

In [2]:
# models/dibs.py
import torch
from torch.distributions import Normal
import numpy as np
import logging
from typing import Dict, Any, Tuple


log = logging.getLogger(__name__)


def acyclic_constr(g: torch.Tensor, d: int) -> torch.Tensor:
    """H(G) from NOTEARS (Zheng et al.) with a series fallback for large *d*."""
    alpha = 1.0 / d
    eye = torch.eye(d, device=g.device, dtype=g.dtype)
    m = eye + alpha * g

    if d <= 10:
        return torch.trace(torch.linalg.matrix_power(m, d)) - d

    try:
        eigvals = torch.linalg.eigvals(m)
        return torch.sum(torch.real(eigvals ** d)) - d
    except RuntimeError:
        trace, p = torch.tensor(0.0, device=g.device, dtype=g.dtype), g.clone()
        for k in range(1, min(d + 1, 20)):
            trace += (alpha ** k) * torch.trace(p) / k
            if k < 19:
                p = p @ g
        return trace


def log_gaussian_likelihood(x: torch.Tensor, pred_mean: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    sigma_tensor = torch.tensor(sigma, dtype=pred_mean.dtype, device=pred_mean.device)
    
    residuals = x - pred_mean
    #old incorrect log_prob = -0.5 * (np.log(2 * np.pi) -  (1/2)* torch.log(sigma_tensor**2) -  0.5*(residuals / sigma_tensor) ** 2) old
    log_prob = -0.5 * (torch.log(2 * torch.pi * sigma_tensor**2)) - 0.5 * ((residuals / sigma_tensor)**2)
    #normal_dist = Normal(loc=pred_mean, scale=sigma_tensor)
    #log_prob = normal_dist.log_prob(x)

    return torch.sum(log_prob)

def scores(z: torch.Tensor, alpha: float) -> torch.Tensor:
    u, v = z[..., 0], z[..., 1]
    raw_scores = alpha * torch.einsum('...ik,...jk->...ij', u, v)
    *batch_dims, d, _ = z.shape[:-1]
    diag_mask = 1.0 - torch.eye(d, device=z.device, dtype=z.dtype)
    if batch_dims:
        diag_mask = diag_mask.expand(*batch_dims, d, d)
    return raw_scores * diag_mask

def bernoulli_soft_gmat(z: torch.Tensor, hparams: Dict[str, Any]) -> torch.Tensor:
    probs = torch.sigmoid(scores(z, hparams["alpha"]))
    d = probs.shape[-1]
    diag_mask = 1.0 - torch.eye(d, device=probs.device, dtype=probs.dtype)
    if probs.ndim == 3:
        diag_mask = diag_mask.expand(probs.shape[0], d, d)
    return probs * diag_mask

def gumbel_soft_gmat(z: torch.Tensor,
                     hparams: Dict[str, Any]) -> torch.Tensor:
    """
    Soft Gumbel–Softmax adjacency  (Eq. B.6)

        g_ij  = σ_τ( L_ij + α⟨u_i , v_j⟩ )

    where  L_ij ~ Logistic(0,1)  and  τ = hparams['tau']. appendix b2
    """
    raw = scores(z, hparams["alpha"])

    # Logistic(0,1) noise   L = log U - log(1-U)
    u = torch.rand_like(raw)
    L = torch.log(u) - torch.log1p(-u)

    logits = (raw + L) / hparams["tau"]
    g_soft = torch.sigmoid(logits)

    d = g_soft.size(-1)
    mask = 1.0 - torch.eye(d, device=z.device, dtype=z.dtype)
    return g_soft * mask

def log_full_likelihood(data: Dict[str, Any], soft_gmat: torch.Tensor, theta: torch.Tensor, hparams: Dict[str, Any]) -> torch.Tensor:
    ## TODO: Expert belief: update this to use interventions, change the full likelihood 
    # and also add log bernoulli likelihood calculatior
    x_data = data['x']
    effective_W = theta * soft_gmat
    pred_mean = torch.matmul(x_data, effective_W)
    sigma_obs = hparams.get('sigma_obs_noise', 0.1)
    return log_gaussian_likelihood(x_data, pred_mean, sigma=sigma_obs)

def log_theta_prior(theta_effective: torch.Tensor, sigma: float) -> torch.Tensor:
    return log_gaussian_likelihood(theta_effective, torch.zeros_like(theta_effective), sigma=sigma)

def gumbel_acyclic_constr_mc(z: torch.Tensor, d: int, hparams: Dict[str, Any]) -> torch.Tensor:
    h_samples = []
    for _ in range(hparams['n_nongrad_mc_samples']):
        # FOR NOW, JUST GIVE THE SOFT MATRIX, AND BY ANNEALING IT TO HARD MATRIX
        g_soft = gumbel_soft_gmat(z, hparams)
        h_samples.append(acyclic_constr(g_soft, d))
        
        # should gumbel soft gmat to hard gmat be done with >0.5 or with a sigmoid?  
        #print(f'g_soft shape: {g_soft.shape}, values: \n {g_soft}')
        #if hparams['current_iteration'] % 1 == 0:
        #    print(f'g_soft shape: {g_soft.shape}, values: \n {g_soft}')
        #g_hard = torch.bernoulli(g_soft)
        #if hparams['current_iteration'] % 1 == 0:
        #    print(f'g_hard shape: {g_hard.shape}, values: \n {g_hard}')
        #print(f'g_hard shape: {g_hard.shape}, values: \n {g_hard}')
        #h_samples.append(acyclic_constr(g_hard, d))
        #g_hard = (g_soft > 0.5).float()
        #how about this  mentioned in dibs       g_ST   = g_hard + (g_soft - g_soft.detach())   # straight-through
        
        #TODO fix above
        # for now use g_soft
        
    h_samples = torch.stack(h_samples)


    return torch.mean(h_samples, dim=0)

def grad_z_log_joint_gumbel(z: torch.Tensor, theta: torch.Tensor, data: Dict[str, Any], hparams: Dict[str, Any]) -> torch.Tensor:
    d = z.shape[0]
    theta_const = theta
    
    #z.requires_grad_(True)
    # --- Part 1: Prior Gradient ---
    # MC estimate of gradient of acyclicity constraint using Gumbel soft graphs

    h_mean = gumbel_acyclic_constr_mc(z, d, hparams)
    grad_h_mc = torch.autograd.grad(h_mean, z)[0]
    grad_log_z_prior_total = -hparams['beta'] * grad_h_mc - (z / hparams['sigma_z']**2)

    # --- Part 2: Likelihood Gradient ---
    
    # 1. We need to collect the log-probability AND the gradient for each sample.
    log_density_samples = []
    grad_samples = []

    for _ in range( hparams['n_grad_mc_samples']):
        # 2. Generate a single soft graph sample.
        g_soft = gumbel_soft_gmat(z, hparams)

        # 3. Calculate the log-joint for this single sample.
        log_density_one_sample = log_full_likelihood(data, g_soft, theta_const, hparams) + \
                                 log_theta_prior(theta_const * g_soft, hparams.get('theta_prior_sigma', 1.0))

        # 4. Calculate the gradient for this single sample.
        # We must use retain_graph=True because we are doing a backward pass
        # inside a loop, and PyTorch would otherwise free the graph memory.
        grad, = torch.autograd.grad(log_density_one_sample, z, retain_graph=True)
        
        log_density_samples.append(log_density_one_sample)
        grad_samples.append(grad)

    # 5. After the loop, we can safely detach z_ from any further graph history.

    # 6. Compute the final likelihood gradient using the stable weighted average.
    # This correctly computes E[p*∇log(p)] / E[p]
    log_p = torch.stack(log_density_samples)
    grad_p = torch.stack(grad_samples)
    grad_lik = weighted_grad(log_p, grad_p)


    #if z.grad is not None:
    #    z.grad.zero_()
    #z.requires_grad_(False)
    
    # Final combined gradient
    


    total = grad_log_z_prior_total + grad_lik
    # 3) Combine
    # ------------------------------------------------
    return total.detach()


## SCORE BASED ESTIMATOR FOR GRADIENT Z 


# ------------------------------------------------------------
#  Score-function estimator for ∇_Z log p(Z,Θ | D)
#  (Section B.2 of the paper, b = 0)
# ------------------------------------------------------------
def analytic_score_g_given_z(z, g, hparams):
    # 1. logits and probabilities
    probs = bernoulli_soft_gmat(z, hparams)
    diff   = g - probs                 # (g_ij − σ(s_ij))
    u, v   = z[..., 0], z[..., 1]      # (d,k)

    # 2. gradients wrt u and v
    grad_u = hparams['alpha'] * torch.einsum('ij,jk->ik', diff, v)   # (d,k)
    grad_v = hparams['alpha'] * torch.einsum('ij,ik->jk', diff, u)   # (d,k)

    return torch.stack([grad_u, grad_v], dim=-1)          # (d,k,2)


def grad_z_log_joint_score(z: torch.Tensor,
                           theta: torch.Tensor,
                           data: Dict[str, Any],
                           hparams: Dict[str, Any]) -> torch.Tensor:
    """
    ∇_Z log p(Z,Θ | D)  using the score-function (REINFORCE) estimator.

    This replaces the Gumbel-soft estimator.
    """
    sigma_z2 = hparams['sigma_z'] ** 2
    beta     = hparams['beta']

    M = hparams['n_grad_mc_samples'] # M = 50
    d = z.shape[0]                   # d = 4
    theta_const = theta


    # 1. sample hard graphs 
    with torch.no_grad():
        g_hard_samples = [torch.bernoulli(bernoulli_soft_gmat(z, hparams)) for _ in range(M)]

    ll = []
    scores = []
    for g in g_hard_samples:
        log_lik = log_full_likelihood(data, g, theta_const, hparams)
        theta_eff = theta_const * g
        log_theta_prior_val = log_theta_prior(theta_eff, hparams.get('theta_prior_sigma', 1.0))
        
        # log likelihood 
        ll.append(log_lik + log_theta_prior_val)
        
        # score 
        scores.append(analytic_score_g_given_z(z, g, hparams))
    
    log_p = torch.stack(ll)
    grad_p = torch.stack(scores)

    log_p_max = log_p.max()
    log_p_shifted = log_p - log_p_max
    #print(f'log_p_shifted shape: {log_p_shifted.shape}, values: \n {log_p_shifted}')
    unnormalized_w = torch.exp(log_p_shifted/10)
    #print(f'unnormalized_w shape: {unnormalized_w.shape}, values: \n {unnormalized_w}')
    w = unnormalized_w / unnormalized_w.sum()
                     # (M,1,1,...)
    #print(f'w shape: {w.shape}, values: \n {w}')

    while w.dim() < grad_p.dim():
        w = w.unsqueeze(-1)                    
    
    ## compute the weighted avg 
    grad_lik = (w * grad_p).sum(dim=0)




    # ---- Z-prior: Gaussian + acyclicity penalty --------------
    # gumbel is possible cuz no differentiable function in expectation 
    z_ = z.detach().clone().requires_grad_(True)
    h_mean = gumbel_acyclic_constr_mc(z_, d, hparams)       # differentiable w.r.t z_
    grad_h, = torch.autograd.grad(h_mean, z_, retain_graph=False)
    grad_prior = -beta * grad_h - z_ / sigma_z2

    return (grad_lik + grad_prior).detach()



def weighted_grad(log_p: torch.Tensor,
                  grad_p: torch.Tensor) -> torch.Tensor:
    """
    Return   Σ softmax(log_p)_m * grad_p[m]
    Shapes
        log_p   : (M,)
        grad_p  : (M, …)   (any extra dims)
    """
    # 1. numerically stable soft-max weights
    #print(f'log_p shape: {log_p.shape}, values:\n {log_p}')
    #print(f'grad_p shape: {grad_p.shape}, values: \n{grad_p}')
    log_p_shifted = log_p - log_p.max()          # (M,)
    #print(f'log_p_shifted shape: {log_p_shifted.shape}, values: \n {log_p_shifted}')
    w = torch.exp(log_p_shifted)
    #print(f'w shape: {w.shape}, values:\n {w}')
    w = w / w.sum()
    #print(f'w after normalization shape: {w.shape}, values:\n {w}')

    # 2. broadcast weights onto grad tensor
    while w.dim() < grad_p.dim():
        w = w.unsqueeze(-1)                      # (M,1,1,...)

    return (w * grad_p).sum(dim=0)               # same shape as grad slice



def grad_theta_log_joint(z: torch.Tensor, theta: torch.Tensor, data: Dict[str, Any], hparams: Dict[str, Any]) -> torch.Tensor:
    #theta.requires_grad_(True)
    n_samples = hparams.get('n_grad_mc_samples', 1)
    theta_ = theta.clone().detach().requires_grad_(True)
    log_density_samples = []
    grad_samples = []
    for _ in range(n_samples):
        g_soft = bernoulli_soft_gmat(z, hparams)
        #print(f"g_soft values: {g_soft}")
        g_hard = torch.bernoulli(g_soft)
        #print(f"g_hard values: {g_hard}")

        # tryign with gumbel to be consistent with grad z and gumbel mc acylci impelmentation
        #g_soft = gumbel_soft_gmat(z, hparams)




        log_lik_val = log_full_likelihood(data, g_hard, theta_, hparams)
        theta_eff = theta_ * g_hard
        log_theta_prior_val = log_theta_prior(theta_eff, hparams.get('theta_prior_sigma', 1.0))
        #ll_grad, = torch.autograd.grad(log_lik_val, theta_, retain_graph=True)
        #log_theta_prior_grad, = torch.autograd.grad(log_theta_prior_val, theta_ , retain_graph=True)
        #print(f"ll_grad shape: {ll_grad.shape}, values: {ll_grad}")
        #print(f"log_theta_prior_grad shape: {log_theta_prior_grad.shape}, values: {log_theta_prior_grad}")

        current_log_density = log_lik_val + log_theta_prior_val
        current_grad ,= torch.autograd.grad(current_log_density, theta_)
        log_density_samples.append(current_log_density) 

        grad_samples.append(current_grad)
    #print(f" END OF Grad_theta mc_samples, iter number: {hparams.get('current_iteration',1)}  \n")

    log_p_tensor = torch.stack(log_density_samples)
    grad_p_tensor = torch.stack(grad_samples)


    # Cleanup
    #if theta.grad is not None:
    #    theta.grad.zero_()
    #theta.requires_grad_(False)

    grad =weighted_grad(log_p_tensor, grad_p_tensor)
    #grad = stable_gradient_estimator(log_p_tensor, grad_p_tensor)
    #print(f"Grad_theta shape: {grad.shape}, values: \n {grad}")

    return  grad.detach()


def grad_log_joint(params: Dict[str, torch.Tensor], data: Dict[str, Any], hparams: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    grad_z = grad_z_log_joint_gumbel(params["z"], params["theta"].detach(), data, hparams)
    #grad_z = grad_z_log_joint_score(params["z"], params["theta"].detach(), data, hparams)
    grad_theta = grad_theta_log_joint(params["z"].detach(), params["theta"], data, hparams)
    
    return {"z": grad_z, "theta": grad_theta}

def log_joint(params: Dict[str, torch.Tensor], data: Dict[str, Any], hparams: Dict[str, Any]) -> torch.Tensor:
    hparams_updated = update_dibs_hparams(hparams, params["t"].item())
    z, theta = params['z'], params['theta']
    d = z.shape[0]

    g_soft = bernoulli_soft_gmat(z, hparams_updated)
    log_lik = log_full_likelihood(data, g_soft, theta, hparams_updated)

    log_prior_z_gaussian = torch.sum(Normal(0.0, hparams_updated['sigma_z']).log_prob(z))
    expected_h_val = gumbel_acyclic_constr_mc(z, d, hparams_updated)
    log_prior_z_acyclic = -hparams_updated['beta'] * expected_h_val
    log_prior_z = log_prior_z_gaussian + log_prior_z_acyclic
    
    theta_eff = theta * g_soft
    log_prior_theta = log_theta_prior(theta_eff, hparams_updated.get('theta_prior_sigma', 1.0))

    if (hparams_updated['current_iteration'] > 850 and hparams_updated['current_iteration'] < 1200):
        with torch.no_grad():
            log_terms = {
                "log_lik":      log_lik.item(),
                "z_prior_gauss":log_prior_z_gaussian.item(),
                "z_prior_acyc": log_prior_z_acyclic.item(),   # usually ≤ 0
                "theta_prior":  log_prior_theta.item(),
                "log_joint": log_lik + log_prior_theta + log_prior_z + log_prior_z_acyclic,
                "penalty": -hparams_updated['beta'] * expected_h_val.item()
            }
        print(f"[dbg] {log_terms}")

    
    return log_lik + log_prior_z + log_prior_theta

def update_dibs_hparams(hparams: Dict[str, Any], t_step: float) -> Dict[str, Any]:

    hparams['beta'] = hparams['beta_base'] * t_step # linear 

    hparams['alpha'] = hparams['alpha_base'] * t_step  # linear slope 0.2





    hparams['current_iteration'] = t_step # Store current iteration
    return hparams


def hard_gmat_from_z(z: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    s = scores(z, alpha)
    return (s > 0).float()


# Single tests

In [3]:
# ... existing code ...

def test_acyclic_constr():
    """
    Test the acyclic_constr function with various scenarios to debug potential issues
    """
    print("=" * 60)
    print("Testing acyclic_constr function")
    print("=" * 60)
    
    # Test case 1: Small acyclic graph (d <= 10)
    print("\n1. Testing small acyclic graph (d=3):")
    d = 3
    g_acyclic = torch.tensor([
        [0.0, 0.5, 0.3],
        [0.0, 0.0, 0.7], 
        [0.0, 0.0, 0.0]
    ], dtype=torch.float32)
    
    try:
        h_acyclic = acyclic_constr(g_acyclic, d)
        print(f"   Acyclic graph constraint: {h_acyclic.item():.6f}")
        print(f"   Expected: close to 0 (should be <= 0 for acyclic)")
    except Exception as e:
        print(f"   ERROR: {e}")
    
    # Test case 2: Small cyclic graph
    print("\n2. Testing small cyclic graph (d=3):")
    g_cyclic = torch.tensor([
        [0.0, 0.5, 0.0],
        [0.0, 0.0, 0.7], 
        [0.3, 0.0, 0.0]
    ], dtype=torch.float32)
    
    try:
        h_cyclic = acyclic_constr(g_cyclic, d)
        print(f"   Cyclic graph constraint: {h_cyclic.item():.6f}")
        print(f"   Expected: > 0 (penalizes cycles)")
    except Exception as e:
        print(f"   ERROR: {e}")
    
    # Test case 3: Identity matrix (no edges)
    print("\n3. Testing identity/no edges (d=4):")
    d = 4
    g_identity = torch.eye(d, dtype=torch.float32)
    
    try:
        h_identity = acyclic_constr(g_identity, d)
        print(f"   Identity matrix constraint: {h_identity.item():.6f}")
        print(f"   Expected: close to 0")
    except Exception as e:
        print(f"   ERROR: {e}")
    
    # Test case 4: Large graph (d > 10, triggers eigenvalue computation)
    print("\n4. Testing large graph (d=12, eigenvalue path):")
    d = 12
    torch.manual_seed(42)  # For reproducibility
    g_large = torch.randn(d, d) * 0.1
    g_large = torch.triu(g_large, diagonal=1)  # Upper triangular (acyclic)
    
    try:
        h_large = acyclic_constr(g_large, d)
        print(f"   Large acyclic graph constraint: {h_large.item():.6f}")
        print(f"   Expected: close to 0")
    except Exception as e:
        print(f"   ERROR: {e}")
        
    # Test case 5: Large graph that might cause eigenvalue issues
    print("\n5. Testing large graph with potential eigenvalue issues (d=15):")
    d = 15
    g_problematic = torch.ones(d, d) * 0.5
    g_problematic.fill_diagonal_(0)  # No self-loops
    
    try:
        h_problematic = acyclic_constr(g_problematic, d)
        print(f"   Problematic graph constraint: {h_problematic.item():.6f}")
        print(f"   Expected: large positive value (many cycles)")
    except Exception as e:
        print(f"   ERROR (expected - should fallback to series): {e}")
    
    # Test case 6: Check gradients work
    print("\n6. Testing gradient computation:")
    d = 4
    g_test = torch.randn(d, d, requires_grad=True) * 0.1
    g_test.data.fill_diagonal_(0)
    
    try:
        h_test = acyclic_constr(g_test, d)
        h_test.backward()
        print(f"   Constraint value: {h_test.item():.6f}")
        print(f"   Gradient computed successfully: {g_test.grad is not None}")
        print(f"   Gradient norm: {g_test.grad.norm().item():.6f}")
    except Exception as e:
        print(f"   ERROR in gradient computation: {e}")
    
    # Test case 7: Edge cases
    print("\n7. Testing edge cases:")
    
    # Very small values
    d = 3
    g_small = torch.ones(d, d) * 1e-10
    g_small.fill_diagonal_(0)
    
    try:
        h_small = acyclic_constr(g_small, d)
        print(f"   Very small values constraint: {h_small.item():.10f}")
    except Exception as e:
        print(f"   ERROR with small values: {e}")
    
    # Very large values (might cause overflow)
    g_large_vals = torch.ones(d, d) * 10.0
    g_large_vals.fill_diagonal_(0)
    
    try:
        h_large_vals = acyclic_constr(g_large_vals, d)
        print(f"   Large values constraint: {h_large_vals.item():.6f}")
    except Exception as e:
        print(f"   ERROR with large values: {e}")
    
    print("\n" + "=" * 60)
    print("acyclic_constr testing complete")
    print("=" * 60)

# Run the test
test_acyclic_constr()

Testing acyclic_constr function

1. Testing small acyclic graph (d=3):
   Acyclic graph constraint: 0.000000
   Expected: close to 0 (should be <= 0 for acyclic)

2. Testing small cyclic graph (d=3):
   Cyclic graph constraint: 0.011667
   Expected: > 0 (penalizes cycles)

3. Testing identity/no edges (d=4):
   Identity matrix constraint: 5.765625
   Expected: close to 0

4. Testing large graph (d=12, eigenvalue path):
   Large acyclic graph constraint: 0.000000
   Expected: close to 0

5. Testing large graph with potential eigenvalue issues (d=15):
   Problematic graph constraint: 306.006714
   Expected: large positive value (many cycles)

6. Testing gradient computation:
   Constraint value: -0.042117
   Gradient computed successfully: False
   ERROR in gradient computation: 'NoneType' object has no attribute 'norm'

7. Testing edge cases:
   Very small values constraint: 0.0000000000
   Large values constraint: 422.222290

acyclic_constr testing complete


  print(f"   Gradient computed successfully: {g_test.grad is not None}")
  print(f"   Gradient norm: {g_test.grad.norm().item():.6f}")


In [4]:
# Test case 6: Check gradients work
print("\n6. Testing gradient computation:")
d = 4

# Method 1: Create tensor properly to maintain leaf status
g_test = torch.randn(d, d) * 0.1
# Zero out diagonal elements without modifying .data
mask = ~torch.eye(d, dtype=torch.bool)
g_test = g_test * mask.float()
g_test.requires_grad_(True)

try:
    h_test = acyclic_constr(g_test, d)
    h_test.backward()
    print(f"   Constraint value: {h_test.item():.6f}")
    print(f"   Gradient computed successfully: {g_test.grad is not None}")
    if g_test.grad is not None:
        print(f"   Gradient norm: {g_test.grad.norm().item():.6f}")
    else:
        print("   Gradient is None - tensor may not be leaf")
except Exception as e:
    print(f"   ERROR in gradient computation: {e}")

# Method 2: Alternative approach using retain_grad()
print("\n6b. Testing gradient computation (alternative method):")
g_test2 = torch.randn(d, d, requires_grad=True) * 0.1
g_test2.data.fill_diagonal_(0)
g_test2.retain_grad()  # This ensures gradients are kept even for non-leaf tensors

try:
    h_test2 = acyclic_constr(g_test2, d)
    h_test2.backward()
    print(f"   Constraint value: {h_test2.item():.6f}")
    print(f"   Gradient computed successfully: {g_test2.grad is not None}")
    if g_test2.grad is not None:
        print(f"   Gradient norm: {g_test2.grad.norm().item():.6f}")
    else:
        print("   Gradient is None")
except Exception as e:
    print(f"   ERROR in gradient computation: {e}")

# Method 3: Test with a simple differentiable operation
print("\n6c. Testing gradient computation (simple test):")
g_test3 = torch.randn(d, d, requires_grad=True) * 0.1
# Create off-diagonal matrix using multiplication
off_diag_mask = 1.0 - torch.eye(d)
g_test3_masked = g_test3 * off_diag_mask

try:
    h_test3 = acyclic_constr(g_test3_masked, d)
    grad_g3 = torch.autograd.grad(h_test3, g_test3, retain_graph=False)[0]
    print(f"   Constraint value: {h_test3.item():.6f}")
    print(f"   Gradient computed using autograd.grad: {grad_g3 is not None}")
    if grad_g3 is not None:
        print(f"   Gradient norm: {grad_g3.norm().item():.6f}")
        print(f"   Gradient shape: {grad_g3.shape}")
except Exception as e:
    print(f"   ERROR in gradient computation: {e}")


6. Testing gradient computation:
   Constraint value: -0.000009
   Gradient computed successfully: True
   Gradient norm: 2.013367

6b. Testing gradient computation (alternative method):
   Constraint value: 0.016230
   Gradient computed successfully: True
   Gradient norm: 2.014625

6c. Testing gradient computation (simple test):
   Constraint value: 0.020845
   Gradient computed using autograd.grad: True
   Gradient norm: 0.276399
   Gradient shape: torch.Size([4, 4])


# Scores z to prob and Bernouilli soft_gmat


In [5]:
def test_scores_and_bernoulli_soft_gmat():
    """
    Test the scores and bernoulli_soft_gmat functions based on the mathematical formulation
    from section 4.2 of the paper.
    
    Mathematical background:
    - Z = [U, V] where U, V ∈ R^(k×d)
    - scores should compute α * u_i^T v_j for all i,j
    - bernoulli_soft_gmat should compute σ_α(u_i^T v_j) = 1/(1 + exp(-α * u_i^T v_j))
    - Diagonal elements should be 0 (no self-loops)
    """
    print("=" * 70)
    print("Testing scores and bernoulli_soft_gmat functions")
    print("=" * 70)
    
    # Test case 1: Simple 2D case with known values
    print("\n1. Testing simple 2D case with known values:")
    d, k = 2, 3
    alpha = 1.0
    
    # Create simple Z = [U, V] with known values
    U = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=torch.float32)  # (k, d)
    V = torch.tensor([[0.5, 1.0], [1.0, 0.5], [0.0, 0.0]], dtype=torch.float32)  # (k, d)
    Z = torch.stack([U.T, V.T], dim=-1)  # (d, k, 2)
    
    print(f"   Z shape: {Z.shape}")
    print(f"   Z:\n{Z}")
    
    # Test scores function
    try:
        scores_result = scores(Z, alpha)
        print(f"   Scores shape: {scores_result.shape}")
        print(f"   Scores:\n{scores_result}")
        
        # Manual computation for verification
        u1, v1 = Z[0, :, 0], Z[0, :, 1]  # u1, v1 for node 0
        u2, v2 = Z[1, :, 0], Z[1, :, 1]  # u2, v2 for node 1
        
        manual_score_01 = alpha * torch.dot(u1, v2)
        manual_score_10 = alpha * torch.dot(u2, v1)
        
        print(f"   Manual score (0->1): {manual_score_01.item():.6f}")
        print(f"   Computed score (0->1): {scores_result[0, 1].item():.6f}")
        print(f"   Manual score (1->0): {manual_score_10.item():.6f}")
        print(f"   Computed score (1->0): {scores_result[1, 0].item():.6f}")
        
        # Check diagonal is zero
        print(f"   Diagonal elements (should be 0): {torch.diag(scores_result)}")
        
    except Exception as e:
        print(f"   ERROR in scores: {e}")
    
    # Test bernoulli_soft_gmat function
    print("\n2. Testing bernoulli_soft_gmat:")
    hparams = {"alpha": alpha}
    
    try:
        probs = bernoulli_soft_gmat(Z, hparams)
        print(f"   Probabilities shape: {probs.shape}")
        print(f"   Probabilities:\n{probs}")
        
        # Manual computation for verification
        manual_prob_01 = torch.sigmoid(manual_score_01)
        manual_prob_10 = torch.sigmoid(manual_score_10)
        
        print(f"   Manual prob (0->1): {manual_prob_01.item():.6f}")
        print(f"   Computed prob (0->1): {probs[0, 1].item():.6f}")
        print(f"   Manual prob (1->0): {manual_prob_10.item():.6f}")
        print(f"   Computed prob (1->0): {probs[1, 0].item():.6f}")
        
        # Check diagonal is zero
        print(f"   Diagonal elements (should be 0): {torch.diag(probs)}")
        
        # Check probabilities are in [0, 1]
        print(f"   All probs in [0,1]: {torch.all((probs >= 0) & (probs <= 1))}")
        
    except Exception as e:
        print(f"   ERROR in bernoulli_soft_gmat: {e}")
    
    # Test case 3: Test with different alpha values
    print("\n3. Testing with different alpha values:")
    alphas = [0.1, 1.0, 5.0, 10.0]
    
    for alpha_test in alphas:
        hparams_test = {"alpha": alpha_test}
        try:
            scores_test = scores(Z, alpha_test)
            probs_test = bernoulli_soft_gmat(Z, hparams_test)
            
            print(f"   Alpha = {alpha_test}:")
            print(f"     Max score: {scores_test.max().item():.6f}")
            print(f"     Min score: {scores_test.min().item():.6f}")
            print(f"     Max prob: {probs_test.max().item():.6f}")
            print(f"     Min prob: {probs_test.min().item():.6f}")
            
        except Exception as e:
            print(f"   ERROR with alpha {alpha_test}: {e}")
    
    # Test case 4: Test gradient flow
    print("\n4. Testing gradient flow:")
    d, k = 3, 4
    Z_grad = torch.randn(d, k, 2, requires_grad=True) * 0.5
    hparams_grad = {"alpha": 2.0}
    
    try:
        scores_grad = scores(Z_grad, hparams_grad["alpha"])
        probs_grad = bernoulli_soft_gmat(Z_grad, hparams_grad)
        
        # Compute some loss and backpropagate
        loss = torch.sum(probs_grad ** 2)
        loss.backward()
        
        print(f"   Z gradient shape: {Z_grad.grad.shape}")
        print(f"   Z gradient norm: {Z_grad.grad.norm().item():.6f}")
        print(f"   Loss value: {loss.item():.6f}")
        
    except Exception as e:
        print(f"   ERROR in gradient flow: {e}")
    
    # Test case 5: Test consistency between scores and probabilities
    print("\n5. Testing consistency between scores and probabilities:")
    d, k = 4, 3
    Z_test = torch.randn(d, k, 2) * 0.3
    alpha_test = 1.5
    hparams_test = {"alpha": alpha_test}
    
    try:
        scores_manual = scores(Z_test, alpha_test)
        probs_from_scores = torch.sigmoid(scores_manual)
        
        # Zero out diagonal
        diag_mask = 1.0 - torch.eye(d)
        probs_from_scores = probs_from_scores * diag_mask
        
        probs_direct = bernoulli_soft_gmat(Z_test, hparams_test)
        
        # Check if they match
        max_diff = torch.max(torch.abs(probs_from_scores - probs_direct))
        print(f"   Max difference between manual and direct computation: {max_diff.item():.10f}")
        print(f"   Are they approximately equal: {torch.allclose(probs_from_scores, probs_direct, atol=1e-6)}")
        
    except Exception as e:
        print(f"   ERROR in consistency test: {e}")
    
    # Test case 6: Test edge cases
    print("\n6. Testing edge cases:")
    
    # Zero embeddings
    Z_zero = torch.zeros(3, 2, 2)
    hparams_zero = {"alpha": 1.0}
    
    try:
        scores_zero = scores(Z_zero, 1.0)
        probs_zero = bernoulli_soft_gmat(Z_zero, hparams_zero)
        
        print(f"   Zero embeddings - scores: {scores_zero.unique()}")
        print(f"   Zero embeddings - probs: {probs_zero.unique()}")
        print(f"   Zero embeddings - all probs should be 0.5: {torch.allclose(probs_zero, torch.ones_like(probs_zero) * 0.5)}")
        
    except Exception as e:
        print(f"   ERROR with zero embeddings: {e}")
    
    # Very large embeddings (test numerical stability)
    Z_large = torch.ones(2, 2, 2) * 10.0
    hparams_large = {"alpha": 1.0}
    
    try:
        scores_large = scores(Z_large, 1.0)
        probs_large = bernoulli_soft_gmat(Z_large, hparams_large)
        
        print(f"   Large embeddings - max score: {scores_large.max().item():.6f}")
        print(f"   Large embeddings - max prob: {probs_large.max().item():.6f}")
        print(f"   Large embeddings - are probs valid: {torch.all((probs_large >= 0) & (probs_large <= 1))}")
        
    except Exception as e:
        print(f"   ERROR with large embeddings: {e}")
    
    # Test case 7: Test batched operation
    print("\n7. Testing batched operation:")
    batch_size = 2
    d, k = 3, 2
    Z_batch = torch.randn(batch_size, d, k, 2) * 0.5
    hparams_batch = {"alpha": 1.0}
    
    try:
        scores_batch = scores(Z_batch, 1.0)
        # Note: bernoulli_soft_gmat might not support batching directly
        print(f"   Batch scores shape: {scores_batch.shape}")
        print(f"   Expected shape: ({batch_size}, {d}, {d})")
        print(f"   Shapes match: {scores_batch.shape == (batch_size, d, d)}")
        
    except Exception as e:
        print(f"   ERROR in batch operation: {e}")
    
    print("\n" + "=" * 70)
    print("scores and bernoulli_soft_gmat testing complete")
    print("=" * 70)

# Run the test
test_scores_and_bernoulli_soft_gmat()

Testing scores and bernoulli_soft_gmat functions

1. Testing simple 2D case with known values:
   Z shape: torch.Size([2, 3, 2])
   Z:
tensor([[[1.0000, 0.5000],
         [0.0000, 1.0000],
         [1.0000, 0.0000]],

        [[0.0000, 1.0000],
         [1.0000, 0.5000],
         [1.0000, 0.0000]]])
   Scores shape: torch.Size([2, 2])
   Scores:
tensor([[0., 1.],
        [1., 0.]])
   Manual score (0->1): 1.000000
   Computed score (0->1): 1.000000
   Manual score (1->0): 1.000000
   Computed score (1->0): 1.000000
   Diagonal elements (should be 0): tensor([0., 0.])

2. Testing bernoulli_soft_gmat:
   Probabilities shape: torch.Size([2, 2])
   Probabilities:
tensor([[0.0000, 0.7311],
        [0.7311, 0.0000]])
   Manual prob (0->1): 0.731059
   Computed prob (0->1): 0.731059
   Manual prob (1->0): 0.731059
   Computed prob (1->0): 0.731059
   Diagonal elements (should be 0): tensor([0., 0.])
   All probs in [0,1]: True

3. Testing with different alpha values:
   Alpha = 0.1:
     Max 

  print(f"   Z gradient shape: {Z_grad.grad.shape}")


In [7]:
# Standalone test cell for debug_notebook.ipynb
#
# To use this, you would typically have the functions available 
# in the same notebook or imported from your 'models/dibs.py' script.

import torch
import numpy as np
import unittest
import logging

# --- Setup basic logger ---
# This is to prevent errors if log.warning is called in acyclic_constr
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


# --- Functions to be tested ---
# Pasted here for stand-alone execution.
# In a real scenario, you would import these from your scripts.

def acyclic_constr(g: torch.Tensor, d: int) -> torch.Tensor:
    """H(G) from NOTEARS (Zheng et al.) with a series fallback for large *d*."""
    # Ensure g is a floating point tensor for matrix operations
    g = g.float()
    alpha = 1.0 / d
    eye = torch.eye(d, device=g.device, dtype=g.dtype)
    m = eye + alpha * g

    # Using matrix_power for d <= 10 as it's generally stable for smaller matrices
    if d <= 10:
        return torch.trace(torch.linalg.matrix_power(m, d)) - d

    # For larger d, eigenvalues are more efficient but can be numerically unstable
    try:
        # Eigenvalue decomposition is faster for large d
        eigvals = torch.linalg.eigvals(m)
        # The constraint is based on the sum of the d-th power of eigenvalues
        return torch.sum(torch.real(eigvals ** d)) - d
    except torch.linalg.LinAlgError:
        # Fallback to series expansion if eigenvalue computation fails
        # This is a less precise but more stable approximation
        log.warning(f"Eigenvalue computation failed for d={d}. Falling back to series expansion.")
        trace = torch.tensor(0.0, device=g.device, dtype=g.dtype)
        p = eye.clone() # Start with identity matrix for power calculation
        for k in range(1, min(d + 1, 20)): # Limit to 20 terms for practical purposes
            p = p @ m
            trace += torch.trace(p) / k
        return trace

def scores(z: torch.Tensor, alpha: float) -> torch.Tensor:
    """Calculates the raw edge scores from latent embeddings."""
    # z has shape [d, k, 2]
    # u and v have shape [d, k]
    u, v = z[..., 0], z[..., 1]
    
    # einsum performs batch matrix multiplication of u and v.T
    # 'ik,jk->ij' means: sum over k for each i and j
    raw_scores = alpha * torch.einsum('ik,jk->ij', u, v)
    
    # Ensure no self-loops by masking the diagonal
    d = z.shape[0]
    diag_mask = 1.0 - torch.eye(d, device=z.device, dtype=z.dtype)
    
    return raw_scores * diag_mask

def bernoulli_soft_gmat(z: torch.Tensor, hparams: dict) -> torch.Tensor:
    """Generates a soft adjacency matrix using a Bernoulli parameterization."""
    # Get probabilities by applying a sigmoid to the raw scores
    probs = torch.sigmoid(scores(z, hparams["alpha"]))
    
    # The scores function already handles the diagonal masking, but as a safeguard:
    d = probs.shape[-1]
    diag_mask = 1.0 - torch.eye(d, device=probs.device, dtype=probs.dtype)
    
    return probs * diag_mask


# --- Test Cases ---

class TestAcyclicConstraint(unittest.TestCase):

    def test_strictly_acyclic_graph(self):
        """Tests a graph with no cycles (a Directed Acyclic Graph)."""
        g_acyclic = torch.tensor([[0., 1., 1.], [0., 0., 1.], [0., 0., 0.]])
        d = g_acyclic.shape[0]
        h_val = acyclic_constr(g_acyclic, d)
        print(f"Acyclic graph H(G): {h_val.item():.6f}")
        self.assertAlmostEqual(h_val.item(), 0.0, places=5, msg="Acyclic graph should have H(G) = 0")

    def test_self_loop_cycle(self):
        """Tests a graph with a self-loop (the simplest cycle)."""
        g_cyclic = torch.tensor([[1., 1., 0.], [0., 0., 1.], [0., 0., 0.]])
        d = g_cyclic.shape[0]
        h_val = acyclic_constr(g_cyclic, d)
        print(f"Graph with self-loop H(G): {h_val.item():.6f}")
        self.assertTrue(h_val.item() > 1e-4, msg="Cyclic graph should have H(G) > 0")

    def test_two_node_cycle(self):
        """Tests a graph with a 2-cycle (A -> B, B -> A)."""
        g_cyclic = torch.tensor([[0., 1., 0.], [1., 0., 0.], [0., 1., 0.]])
        d = g_cyclic.shape[0]
        h_val = acyclic_constr(g_cyclic, d)
        print(f"Graph with 2-cycle H(G): {h_val.item():.6f}")
        self.assertTrue(h_val.item() > 1e-4, msg="Cyclic graph should have H(G) > 0")
        
    def test_large_graph_eigenvalue_path(self):
        """Tests the eigenvalue code path with a larger (d=12) acyclic graph."""
        d = 12
        g_large_acyclic = torch.triu(torch.ones(d, d), diagonal=1)
        h_val = acyclic_constr(g_large_acyclic, d)
        print(f"Large acyclic graph (d=12) H(G): {h_val.item():.6f}")
        self.assertAlmostEqual(h_val.item(), 0.0, places=4, msg="Large acyclic graph should have H(G) near 0")


class TestGraphGeneration(unittest.TestCase):

    def setUp(self):
        """Set up common variables for the tests."""
        self.d = 3  # Number of nodes
        self.k = 2  # Latent dimension
        # Latent variable Z = [U, V]
        self.z = torch.arange(self.d * self.k * 2, dtype=torch.float32).view(self.d, self.k, 2)
        # z will be:
        # [[[ 0,  1], [ 2,  3]],
        #  [[ 4,  5], [ 6,  7]],
        #  [[ 8,  9], [10, 11]]]
        self.alpha = 0.5
        self.hparams = {"alpha": self.alpha}

    def test_scores_calculation(self):
        """Tests the bilinear score calculation G_ij = alpha * u_i^T v_j."""
        u = self.z[..., 0] # [[[0, 2], [4, 6], [8, 10]]]
        v = self.z[..., 1] # [[[1, 3], [5, 7], [9, 11]]]
        
        # Manually calculate expected scores
        expected_scores = self.alpha * torch.matmul(u, v.T)
        # Set diagonal to zero
        expected_scores.fill_diagonal_(0)
        
        # Get scores from function
        s = scores(self.z, self.alpha)
        print(f"\nCalculated scores:\n{s}")
        print(f"Expected scores:\n{expected_scores}")
        
        self.assertTrue(torch.allclose(s, expected_scores), "Scores do not match expected values.")
        # Check that diagonal is exactly zero
        self.assertTrue(torch.all(torch.diag(s) == 0), "Diagonal of scores matrix should be zero.")

    def test_bernoulli_soft_gmat(self):
        """Tests the sigmoid transformation of scores to get probabilities."""
        # Calculate scores first
        s = scores(self.z, self.alpha)
        
        # Manually calculate expected probabilities
        expected_probs = torch.sigmoid(s)
        
        # Get probabilities from function
        g_soft = bernoulli_soft_gmat(self.z, self.hparams)
        print(f"\nCalculated soft G-matrix:\n{g_soft}")
        print(f"Expected soft G-matrix:\n{expected_probs}")
        expected_probs.fill_diagonal_(0)
        self.assertTrue(torch.allclose(g_soft, expected_probs), "Soft G-matrix probabilities do not match expected values.")
        # Check that diagonal is exactly zero
        self.assertTrue(torch.all(torch.diag(g_soft) == 0), "Diagonal of soft G-matrix should be zero.")


# --- Running the tests ---
# This allows running the tests directly from the cell.
suite = unittest.TestSuite()
print("--- Running tests for acyclic_constr ---")
suite.addTest(unittest.makeSuite(TestAcyclicConstraint))
print("\n--- Running tests for Graph Generation ---")
suite.addTest(unittest.makeSuite(TestGraphGeneration))

runner = unittest.TextTestRunner()
runner.run(suite)


.....

.
----------------------------------------------------------------------
Ran 6 tests in 0.007s

OK


--- Running tests for acyclic_constr ---

--- Running tests for Graph Generation ---
Large acyclic graph (d=12) H(G): 0.000000
Graph with self-loop H(G): 1.370371
Acyclic graph H(G): 0.000000
Graph with 2-cycle H(G): 0.666667

Calculated soft G-matrix:
tensor([[0.0000, 0.9991, 1.0000],
        [1.0000, 0.0000, 1.0000],
        [1.0000, 1.0000, 0.0000]])
Expected soft G-matrix:
tensor([[0.5000, 0.9991, 1.0000],
        [1.0000, 0.5000, 1.0000],
        [1.0000, 1.0000, 0.5000]])

Calculated scores:
tensor([[ 0.,  7., 11.],
        [11.,  0., 51.],
        [19., 55.,  0.]])
Expected scores:
tensor([[ 0.,  7., 11.],
        [11.,  0., 51.],
        [19., 55.,  0.]])


<unittest.runner.TextTestResult run=6 errors=0 failures=0>

In [19]:
def test_simple_linear_model_log_joint():
    """
    Test log joint computation with a simple linear model:
    - Linear model: y = X @ theta + noise
    - Gaussian likelihood: p(y | X, theta) = N(X @ theta, sigma^2)
    - Gaussian prior: p(theta) = N(0, sigma_prior^2)
    - Use gradient ascent to maximize log p(theta | X, y)
    """
    print("=" * 70)
    print("Testing Simple Linear Model Log Joint")
    print("=" * 70)
    
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Generate synthetic linear data
    n_samples, n_features = 100, 3
    true_theta = torch.tensor([2.0, -1.5, 0.8])
    X = torch.randn(n_samples, n_features)
    noise = torch.randn(n_samples) * 0.1
    y = X @ true_theta + noise
    
    print(f"Data: {n_samples} samples, {n_features} features")
    print(f"True theta: {true_theta}")
    
    # Model parameters
    sigma_likelihood = torch.tensor(0.1)  # Known noise level
    sigma_prior = torch.tensor(2.0)       # Prior variance for theta
    
    def log_gaussian_likelihood_simple(y, X, theta, sigma):
        """Simple Gaussian likelihood for linear model"""
        pred = X @ theta
        residuals = y - pred
        log_prob = -0.5 * torch.log(2 * torch.pi * sigma**2) - 0.5 * (residuals**2 / sigma**2)
        return torch.sum(log_prob) / n_samples
    
    def log_gaussian_prior_simple(theta, sigma):
        """Simple Gaussian prior for theta"""
        log_prob = -0.5 * torch.log(2 * torch.pi * sigma**2) - 0.5 * (theta**2 / sigma**2)
        return torch.sum(log_prob) / n_samples
    
    def log_joint_simple(theta, X, y, sigma_lik, sigma_prior):
        """Log joint probability"""
        log_lik = log_gaussian_likelihood_simple(y, X, theta, sigma_lik)
        log_prior = log_gaussian_prior_simple(theta, sigma_prior)
        return log_lik + log_prior
    
    # Test 1: Evaluate log joint at true parameters
    print("\n1. Testing log joint evaluation:")
    try:
        log_joint_true = log_joint_simple(true_theta, X, y, sigma_likelihood, sigma_prior)
        print(f"   Log joint at true theta: {log_joint_true.item():.6f}")
        
        # Test individual components
        log_lik_true = log_gaussian_likelihood_simple(y, X, true_theta, sigma_likelihood)
        log_prior_true = log_gaussian_prior_simple(true_theta, sigma_prior)
        print(f"   Log likelihood: {log_lik_true.item():.6f}")
        print(f"   Log prior: {log_prior_true.item():.6f}")
        print(f"   Sum: {(log_lik_true + log_prior_true).item():.6f}")
        
    except Exception as e:
        print(f"   ERROR in log joint evaluation: {e}")
    
    # Test 2: Gradient computation
    print("\n2. Testing gradient computation:")
    try:
        theta_test = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        log_joint_val = log_joint_simple(theta_test, X, y, sigma_likelihood, sigma_prior)
        
        # Compute gradient using autograd
        grad = torch.autograd.grad(log_joint_val, theta_test)[0]
        print(f"   Gradient at zero: {grad}")
        print(f"   Gradient norm: {grad.norm().item():.6f}")
        
        # The gradient should point towards the true parameters
        print(f"   Gradient direction vs true theta direction:")
        print(f"   Normalized gradient: {grad / grad.norm()}")
        print(f"   Normalized true theta: {true_theta / true_theta.norm()}")
        
    except Exception as e:
        print(f"   ERROR in gradient computation: {e}")
    
    # Test 3: Gradient ascent optimization
    print("\n3. Testing gradient ascent optimization:")
    try:
        # Initialize parameters
        theta_opt = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        learning_rate = 1e-5
        n_iterations = 500
        
        log_joints = []
        thetas = []
        
        for i in range(n_iterations):
            # Zero gradients
            if theta_opt.grad is not None:
                theta_opt.grad.zero_()
            
            # Forward pass
            log_joint_val = log_joint_simple(theta_opt, X, y, sigma_likelihood, sigma_prior)
            
            # Backward pass
            log_joint_val.backward()
            
            # Store values
            log_joints.append(log_joint_val.item())
            thetas.append(theta_opt.detach().clone())
            
            # Gradient ascent step (maximize log joint)
            with torch.no_grad():
                theta_opt += learning_rate * theta_opt.grad
            
            if i % 1 == 0:
                print(f"   Iteration {i}: log_joint = {log_joint_val.item():.6f}, theta = {theta_opt.detach()}")
                print(f"   Gradient: {theta_opt.grad}")

        print(f"\n   Final theta: {theta_opt.detach()}")
        print(f"   True theta:  {true_theta}")
        print(f"   Error: {(theta_opt.detach() - true_theta).norm().item():.6f}")
        
        # Check if optimization improved
        print(f"   Initial log joint: {log_joints[0]:.6f}")
        print(f"   Final log joint: {log_joints[-1]:.6f}")
        print(f"   Improvement: {log_joints[-1] - log_joints[0]:.6f}")
        
    except Exception as e:
        print(f"   ERROR in gradient ascent: {e}")
    
    # Test 4: Compare with analytical solution
    print("\n4. Comparing with analytical solution:")
    try:
        # For linear regression with Gaussian prior, the MAP estimate is:
        # theta_MAP = (X^T X + (sigma_lik^2 / sigma_prior^2) * I)^{-1} X^T y
        lambda_reg = (sigma_likelihood / sigma_prior) ** 2
        XtX = X.T @ X
        Xty = X.T @ y
        
        theta_analytical = torch.linalg.solve(XtX + lambda_reg * torch.eye(n_features), Xty)
        
        print(f"   Analytical MAP: {theta_analytical}")
        print(f"   Optimized theta: {theta_opt.detach()}")
        print(f"   Difference: {(theta_analytical - theta_opt.detach()).norm().item():.6f}")
        
        # Evaluate log joint at analytical solution
        log_joint_analytical = log_joint_simple(theta_analytical, X, y, sigma_likelihood, sigma_prior)
        print(f"   Log joint (analytical): {log_joint_analytical.item():.6f}")
        print(f"   Log joint (optimized): {log_joints[-1]:.6f}")
        
    except Exception as e:
        print(f"   ERROR in analytical comparison: {e}")
    
    # Test 5: Test gradient with different starting points
    print("\n5. Testing robustness with different starting points:")
    starting_points = [
        torch.tensor([1.0, 1.0, 1.0]),
        torch.tensor([-2.0, 0.5, -1.0]),
        torch.tensor([10.0, -5.0, 2.0])
    ]
    
    for i, start_point in enumerate(starting_points):
        try:
            theta_test = start_point.clone().detach().requires_grad_(True)
            log_joint_val = log_joint_simple(theta_test, X, y, sigma_likelihood, sigma_prior)
            grad = torch.autograd.grad(log_joint_val, theta_test)[0]
            
            print(f"   Start {i+1}: theta={start_point}, log_joint={log_joint_val.item():.6f}, grad_norm={grad.norm().item():.6f}")
            
        except Exception as e:
            print(f"   ERROR with starting point {i+1}: {e}")
    
    print("\n" + "=" * 70)
    print("Simple Linear Model Log Joint Testing Complete")
    print("=" * 70)

# Run the test
test_simple_linear_model_log_joint()

Testing Simple Linear Model Log Joint
Data: 100 samples, 3 features
True theta: tensor([ 2.0000, -1.5000,  0.8000])

1. Testing log joint evaluation:
   Log joint at true theta: 82.612999
   Log likelihood: 88.310509
   Log prior: -5.697507
   Sum: 82.612999

2. Testing gradient computation:
   Gradient at zero: tensor([ 23538.4922, -11948.1221,   6987.8818])
   Gradient norm: 27306.568359
   Gradient direction vs true theta direction:
   Normalized gradient: tensor([ 0.8620, -0.4376,  0.2559])
   Normalized true theta: tensor([ 0.7619, -0.5715,  0.3048])

3. Testing gradient ascent optimization:
   Iteration 0: log_joint = -35074.003906, theta = tensor([ 0.2354, -0.1195,  0.0699])
   Gradient: tensor([ 23538.4922, -11948.1221,   6987.8818])
   Iteration 1: log_joint = -28022.892578, theta = tensor([ 0.4431, -0.2297,  0.1329])
   Gradient: tensor([ 20766.8730, -11024.6943,   6300.6553])
   Iteration 2: log_joint = -22418.244141, theta = tensor([ 0.6263, -0.3314,  0.1898])
   Gradient: 