In [1]:
import torch
from tqdm.notebook import tqdm

from dem.energies.gmm_pseudoenergy import GMMPseudoEnergy

In [2]:
import hydra
from omegaconf import DictConfig
from hydra.core.global_hydra import GlobalHydra

# Only initialize if not already initialized
if not GlobalHydra().is_initialized():
    # Initialize hydra with the same config path as train.py
    hydra.initialize(config_path="../../configs", version_base="1.3")
    # Load the experiment config for GMM with pseudo-energy
    cfg = hydra.compose(config_name="train", overrides=["experiment=gmm_idem_pseudo"])

# Instantiate the energy function using hydra, similar to train.py
energy_function = hydra.utils.instantiate(cfg.energy)

energy_function.gmm_potential.gmm.to(energy_function.device)

In [3]:
x_base = torch.tensor([0.1, 0.1], device=energy_function.gmm_potential.gmm.device)
x_batch_base = torch.tensor([[0.0, 0.0], [1.0, 1.0]], device=energy_function.gmm_potential.gmm.device)


### No derivatives in pseudo-potential

In [4]:
potential_fn = lambda x: energy_function.gmm_potential(x)

_x = x_base.clone()
print("potential:", potential_fn(_x))

potential: tensor(-14.8496, device='cuda:0')


In [5]:
# vmapped grad (as used in estimating the score DEM/dem/models/components/score_estimator.py)
grad_fxn = torch.func.grad(potential_fn, argnums=0)
vmapped_fxn = torch.vmap(grad_fxn, in_dims=(0), randomness="different")

_x_batch = x_batch_base.clone()
print("grad:", vmapped_fxn(_x_batch))

grad: tensor([[-186.8610,  122.7425],
        [-440.4272, -426.3380]], device='cuda:0')


### Forces (grad) in pseudo-potential

In [6]:
def potential_with_grad_fn(x): 
    energy = energy_function.gmm_potential(x)
    forces = -torch.func.grad(energy_function.gmm_potential)(x)
    force_magnitude = torch.norm(forces)
    return energy + force_magnitude

_x = x_base.clone()
print("potential with grad:", potential_with_grad_fn(_x))

potential with grad: tensor(144.6745, device='cuda:0')


In [7]:
# vmapped grad (as used in estimating the score DEM/dem/models/components/score_estimator.py)
grad_fxn_grad = torch.func.grad(potential_with_grad_fn, argnums=0)
vmapped_fxn_grad = torch.vmap(grad_fxn_grad, in_dims=(0), randomness="different")

_x_batch = x_batch_base.clone()
print("grad:", vmapped_fxn_grad(_x_batch))

grad: tensor([[1018.4152, -677.0424],
        [ 601.0931,  581.8644]], device='cuda:0')


### Forces (grad) and Hessian in pseudo-potential

In [8]:
def potential_with_grad_and_hessian_fn(x):
    energy = energy_function.gmm_potential(x)
    grad = torch.func.grad(energy_function.gmm_potential)(x)
    hessian = torch.func.hessian(energy_function.gmm_potential)(x)
    return energy + torch.norm(grad) + torch.norm(hessian)

_x = x_base.clone()
print("potential with grad and hessian:", potential_with_grad_and_hessian_fn(_x))

potential with grad and hessian: tensor(2194.6680, device='cuda:0')


In [9]:
# vmapped grad (as used in estimating the score DEM/dem/models/components/score_estimator.py)
grad_fxn_hessian = torch.func.grad(potential_with_grad_and_hessian_fn, argnums=0)
vmapped_fxn_hessian = torch.vmap(grad_fxn_hessian, in_dims=(0), randomness="different")

_x_batch = x_batch_base.clone()
print("grad:", vmapped_fxn_hessian(_x_batch))

grad: tensor([[-1894.3495, -2516.1553],
        [  601.0931,   581.8644]], device='cuda:0')


### Forces (grad) and smallest Hessian eigenvalue in pseudo-potential

In [10]:
def potential_with_grad_and_hessian_ev_fn(x):
    energy = energy_function.gmm_potential(x)
    grad = torch.func.grad(energy_function.gmm_potential)(x)
    hessian = torch.func.hessian(energy_function.gmm_potential)(x)
    
    if hessian.shape[0] > 2:
        # Get smallest 2 eigenvalues using LOBPCG
        k = 2  # Number of eigenvalues to compute
        init_X = torch.randn(hessian.shape[0], k, device=hessian.device)  # Initial guess
        eigenvalues, _ = torch.lobpcg(hessian, k=k, largest=False, X=init_X)
        smallest_eigenvalues = eigenvalues[:k]  # Get k smallest eigenvalues
    
    else:
        # Get eigenvalues using torch.linalg.eigvals since Hessian is small
        eigenvalues = torch.linalg.eigvals(hessian)
        # Sort eigenvalues in ascending order
        eigenvalues = torch.sort(eigenvalues.real)[0]  # Take real part and sort
        smallest_eigenvalues = eigenvalues[:2]  # Get 2 smallest eigenvalues
    
    # Bias toward index-1 saddle points:
    # - First eigenvalue should be negative (minimize positive values)
    # - Second eigenvalue should be positive (minimize negative values) 
    # Using softplus which is differentiable everywhere but still creates one-sided penalties
    ev1_bias = torch.nn.functional.softplus(smallest_eigenvalues[0])  # Penalize if first eigenvalue > 0
    ev2_bias = torch.nn.functional.softplus(-smallest_eigenvalues[1])  # Penalize if second eigenvalue < 0
    saddle_bias = ev1_bias + ev2_bias
    
    return energy + torch.norm(grad) + saddle_bias


_x = x_base.clone()
print("potential with grad and hessian eigenvalue:", potential_with_grad_and_hessian_ev_fn(_x))

potential with grad and hessian eigenvalue: tensor(1594.2388, device='cuda:0')


In [11]:
# vmapped grad (as used in estimating the score DEM/dem/models/components/score_estimator.py)
grad_fxn_hessian_ev = torch.func.grad(potential_with_grad_and_hessian_fn, argnums=0)
vmapped_fxn_hessian_ev = torch.vmap(grad_fxn_hessian_ev, in_dims=(0), randomness="different")

_x_batch = x_batch_base.clone()
print("grad:", vmapped_fxn_hessian_ev(_x_batch))

grad: tensor([[-1894.3495, -2516.1553],
        [  601.0931,   581.8644]], device='cuda:0')


### Forces (grad) and smallest Hessian eigenvalues (using Lanczos) in pseudo-potential

In [12]:
"""
Lanczos loop includes a branch that depends on a tensor value (the check if b < 1e-8). When you use vmap, any data‐dependent control flow (i.e. branching on tensor values) is not supported. One common fix is to remove early termination based on the norm bb and instead always run a fixed number of iterations. 
"""

def hessian_vector_product(grad, params, vec):
    """
    Computes Hessian-vector product using autograd.
    grad: gradient of the loss with respect to params.
    params: list of parameters (or flattened tensor) on which the Hessian is defined.
    vec: a flattened vector (same size as the concatenated params).
    """
    # Compute the dot product between grad and vec
    grad_dot = torch.dot(torch.cat([g.view(-1) for g in grad]), vec)
    # Compute second derivative (Hessian-vector product)
    hvp = torch.autograd.grad(grad_dot, params, retain_graph=True, create_graph=True)
    return torch.cat([h.contiguous().view(-1) for h in hvp])

def lanczos(Hv, v0, m, reg=1e-6):
    """
    Lanczos algorithm to approximate a symmetric matrix's eigenvalues.
    
    Args:
        Hv: a function that takes a vector and returns the Hessian-vector product.
        v0: initial vector (1D tensor)
        m: number of Lanczos iterations (fixed to ensure compatibility with vmap)
        reg: regularization parameter for Hessian-vector product. Add small regularization to diagonal of Hessian (default=1e-6)
        
    Returns:
        T: the tridiagonal matrix (m_actual x m_actual)
        Q: orthonormal Lanczos basis (n x m_actual)
    """
    Q = []
    alpha = []
    beta = []

    # Normalize initial vector
    q = v0 / v0.norm()
    Q.append(q)

    for j in range(m):
        # Compute w = H*q_j - beta_{j-1} * q_{j-1} (skip the term for j == 0)
        w = Hv(Q[j]) + reg * Q[j]  # Add regularization
        if j > 0:
            w = w - beta[j-1] * Q[j-1]
        
        # Compute alpha_j = q_j^T * w
        a = torch.dot(Q[j], w)
        alpha.append(a)
        
        # Orthogonalize w against q_j
        w = w - a * Q[j]
        
        # Full reorthogonalization against all previous vectors
        for k in range(j+1):
            w = w - torch.dot(Q[k], w) * Q[k]
            
        # Compute beta_j = norm(w)
        b = w.norm()
        beta.append(b)
        
        if j == m - 1:
            break
            
        # Safe division with larger epsilon
        q_next = w / (b + 1e-6)
        Q.append(q_next)
        
    m_actual = len(alpha)
    # Build the tridiagonal matrix T from alpha and beta
    T = torch.diag(torch.stack(alpha))
    for i in range(m_actual - 1):
        T[i, i+1] = beta[i+1]
        T[i+1, i] = beta[i+1]
    # Stack the basis vectors as columns in a matrix
    Q_mat = torch.stack(Q, dim=1)
    return T, Q_mat

def compute_hessian_eigs(loss, params, lanczos_steps=100, reg=1e-6):  # Increased steps
    # Compute first derivatives with create_graph for higher-order derivatives.
    grad = torch.autograd.grad(loss, params, create_graph=True)
    grad_flat = torch.cat([g.contiguous().view(-1) for g in grad])
    
    # Define Hessian-vector product function using the current grads and params.
    def Hv(v):
        return hessian_vector_product(grad, params, v)
    
    # Choose a random initial vector matching grad_flat's shape.
    v0 = torch.randn_like(grad_flat)
    
    # Run the fixed-iteration Lanczos algorithm.
    T, Q_mat = lanczos(Hv, v0, m=lanczos_steps, reg=reg)
    
    # Compute eigenvalues of the tridiagonal matrix T.
    eigvals, _ = torch.linalg.eigh(T)
    
    # Return the smallest two eigenvalues.
    smallest_two = eigvals[:2]
    return smallest_two

# Example potential function using functorch for Hessian-vector products:
def potential_with_grad_and_hessian_lanczos_fn(x, lanczos_steps=100, reg=1e-6):
    energy = energy_function.gmm_potential(x)
    grad = torch.func.grad(energy_function.gmm_potential)(x)
    
    # Define loss function (here, simply the energy).
    def loss_fn(x):
        return energy_function.gmm_potential(x)
    
    # Get Hessian-vector product using functorch transforms.
    grad_fn = torch.func.grad(loss_fn)
    def hvp(v):
        # Ensure v has the same shape as x.
        v = v.reshape(x.shape)
        return torch.func.jvp(grad_fn, (x,), (v,))[1]
    
    # Run Lanczos on the HVP function with increased iterations
    v0 = torch.randn_like(x)
    T, Q_mat = lanczos(hvp, v0, m=lanczos_steps)  # Increased from 40
    
    # Compute eigenvalues of the tridiagonal matrix.
    smallest_eigenvalues, _ = torch.linalg.eigh(T)
    smallest_eigenvalues = smallest_eigenvalues[:2]
    
    # Bias toward index-1 saddle points:
    # - Penalize if the first eigenvalue is positive.
    # - Penalize if the second eigenvalue is negative.
    ev1_bias = torch.nn.functional.softplus(smallest_eigenvalues[0])
    ev2_bias = torch.nn.functional.softplus(-smallest_eigenvalues[1])
    saddle_bias = ev1_bias + ev2_bias
    
    return energy + torch.norm(grad) + saddle_bias


In [13]:
# Test the function (ensure that energy_function and x_base, x_batch_base are defined in your context)
_x = x_base.clone()
print("potential with grad and hessian (Lanczos):", potential_with_grad_and_hessian_lanczos_fn(_x))

# vmapped grad:
grad_fxn_hessian_lanczos = torch.func.grad(potential_with_grad_and_hessian_lanczos_fn, argnums=0)
vmapped_fxn_hessian_lanczos = torch.vmap(grad_fxn_hessian_lanczos, in_dims=0, randomness="different")

"""
The first row corresponds to x=[0,0] and the second to x=[1,1] in the batch. 
The second row is stable because at [1,1] the Hessian eigenvalues are well-conditioned, making the Lanczos approximation consistent. 
At [0,0], the system is near a critical point where eigenvalues are more sensitive to the random initialization.
"""
_x_batch = x_batch_base.clone()
print("grad:\n", vmapped_fxn_hessian_lanczos(_x_batch))

# compare to potential_with_grad_and_hessian_ev_fn
_x_batch = x_batch_base.clone()
print("comparison grad:\n", vmapped_fxn_hessian_ev(_x_batch))

potential with grad and hessian (Lanczos): tensor(1592.6531, device='cuda:0')
grad:
 tensor([[-2240.0139, -2734.5093],
        [  601.0931,   581.8644]], device='cuda:0')
comparison grad:
 tensor([[-1894.3495, -2516.1553],
        [  601.0931,   581.8644]], device='cuda:0')


In [14]:
# Search for minimum number of Lanczos steps needed for convergence

# Compare to reference gradient
comparison = vmapped_fxn_hessian_ev(x_batch_base.clone())

def get_grad_diff(steps, reg):
    # Create function with specific number of steps
    def potential_fn(x):
        return potential_with_grad_and_hessian_lanczos_fn(x, lanczos_steps=steps, reg=reg)
    
    # Get vmapped gradient
    grad_fn = torch.func.grad(potential_fn, argnums=0)
    vmapped_fn = torch.vmap(grad_fn, in_dims=0, randomness="different")
    
    # Compute gradient
    grad = vmapped_fn(x_batch_base.clone())
    
    # Return max absolute difference
    return torch.max(torch.abs(grad - comparison)).item()

def do_step_search(reg=1e-6):
    # Try increasing numbers of steps
    step_sizes = [5, 10, 20, 50, 70, 100, 150]
    diffs = []

    print("Lanczos steps  | Max gradient difference")
    print("-" * 40)

    # Find minimum steps needed for reasonable convergence
    threshold = 1.0  # Maximum acceptable difference
    min_steps = None
    for steps in tqdm(step_sizes):
        diff = get_grad_diff(steps, reg)
        diffs.append(diff)
        if diff < threshold:
            min_steps = steps
            break
        tqdm.write(f"steps={steps:8d} | diff={diff:.1f}")

    if min_steps is not None:
        print(f"Minimum steps needed for diff < {threshold}: {min_steps}")
    else:
        print(f"No convergence achieved with tested step sizes (diff < {threshold})\n")

do_step_search(reg=1e-6)
do_step_search(reg=0.)
do_step_search(reg=1e-12)

Lanczos steps  | Max gradient difference
----------------------------------------


  0%|          | 0/7 [00:00<?, ?it/s]

steps=       5 | diff=492.5
steps=      10 | diff=631.5


steps=      20 | diff=159.4
steps=      50 | diff=59.6
steps=      70 | diff=557.5
steps=     100 | diff=329.0
steps=     150 | diff=393.4
No convergence achieved with tested step sizes (diff < 1.0)

Lanczos steps  | Max gradient difference
----------------------------------------


  0%|          | 0/7 [00:00<?, ?it/s]

steps=       5 | diff=719.9
steps=      10 | diff=1218.7
steps=      20 | diff=441.7
steps=      50 | diff=1015.8
steps=      70 | diff=1165.4
steps=     100 | diff=685.5
steps=     150 | diff=482.5
No convergence achieved with tested step sizes (diff < 1.0)

Lanczos steps  | Max gradient difference
----------------------------------------


  0%|          | 0/7 [00:00<?, ?it/s]

steps=       5 | diff=721.8
steps=      10 | diff=550.0
steps=      20 | diff=14.7
steps=      50 | diff=722.1
steps=      70 | diff=465.3
steps=     100 | diff=1116.2
steps=     150 | diff=63.6
No convergence achieved with tested step sizes (diff < 1.0)

