In [1]:
import torch
from dem.energies.gmm_pseudoenergy import GMMPseudoEnergy

In [2]:
import hydra
from omegaconf import DictConfig
from hydra.core.global_hydra import GlobalHydra

# Initialize Hydra at most once per kernel so that re-running this cell does
# not call hydra.initialize a second time.
if not GlobalHydra.instance().is_initialized():
    # Initialize hydra with the same config path as train.py
    hydra.initialize(config_path="../../configs", version_base="1.3")

# Compose outside the guard: `cfg` is then always defined, and always fresh,
# even when Hydra was already initialized before this cell ran.
# Load the experiment config for GMM with pseudo-energy
cfg = hydra.compose(config_name="train", overrides=["experiment=gmm_idem_pseudo"])

# Instantiate the energy function using hydra, similar to train.py
energy_function = hydra.utils.instantiate(cfg.energy)

# Move the underlying GMM to the energy function's device. The return value is
# discarded, so this relies on `.to` updating the GMM object in place.
energy_function.gmm_potential.gmm.to(energy_function.device)

In [3]:
# Base evaluation points, created directly on the device the GMM reports.
gmm_device = energy_function.gmm_potential.gmm.device

# A single 2-D point, used for the scalar potential evaluations.
x_base = torch.tensor([0.0, 0.0], device=gmm_device)
# A batch of two 2-D points, used for the vmapped gradient evaluations.
x_batch_base = torch.tensor(
    [
        [0.0, 0.0],
        [1.0, 1.0],
    ],
    device=gmm_device,
)


### No derivatives in pseudo-potential

In [4]:
def potential_fn(x):
    """Evaluate the plain GMM potential at ``x`` (no derivative terms added).

    Thin wrapper around ``energy_function.gmm_potential``. The global
    ``energy_function`` is looked up at call time, so re-instantiating it in an
    earlier cell is picked up without redefining this function.

    Args:
        x: Input forwarded unchanged to ``energy_function.gmm_potential``.

    Returns:
        Whatever ``energy_function.gmm_potential(x)`` returns.
    """
    return energy_function.gmm_potential(x)

# Evaluate on a copy so that x_base itself is never handed to the potential.
_x = torch.clone(x_base)
print(
    "potential:",
    potential_fn(_x),
)

potential: tensor(-23.3163, device='cuda:0')


In [5]:
# Per-sample gradient of the plain potential: differentiate w.r.t. the first
# (and only) argument, then vmap over the leading batch dimension. This is the
# vmapped-grad construction used when estimating the score in
# DEM/dem/models/components/score_estimator.py.
grad_fxn = torch.func.grad(potential_fn)
vmapped_fxn = torch.vmap(grad_fxn, in_dims=0, randomness="different")

# Work on a copy of the shared batch of base points.
_x_batch = torch.clone(x_batch_base)
print(
    "grad:",
    vmapped_fxn(_x_batch),
)

grad: tensor([[-186.8610,  122.7425],
        [-440.4272, -426.3380]], device='cuda:0')


### Forces (grad) in pseudo-potential

In [6]:
def potential_with_grad_fn(x):
    """Pseudo-potential: the GMM potential plus the magnitude of its force.

    Computes ``U(x) + ||-grad U(x)||_2`` with ``U = energy_function.gmm_potential``.

    Args:
        x: A single point for which ``U(x)`` is a scalar (``torch.func.grad``
            only differentiates scalar outputs). Batches are handled by
            wrapping this function in ``torch.vmap``.

    Returns:
        Scalar tensor holding the potential plus the force magnitude.
    """
    potential = energy_function.gmm_potential
    energy = potential(x)
    # Force is the negative gradient of the potential.
    forces = -torch.func.grad(potential)(x)
    # torch.norm is deprecated; torch.linalg.vector_norm is the documented
    # replacement and computes the same 2-norm over all entries by default.
    force_magnitude = torch.linalg.vector_norm(forces)
    return energy + force_magnitude

# Evaluate the force-augmented potential on a copy of the single base point.
_x = torch.clone(x_base)
print(
    "potential with grad:",
    potential_with_grad_fn(_x),
)

potential with grad: tensor(200.2518, device='cuda:0')


In [7]:
# Per-sample gradient of the force-augmented potential, using the vmapped-grad
# construction from DEM/dem/models/components/score_estimator.py. The potential
# already contains a first derivative, so this differentiates through it.
grad_fxn_grad = torch.func.grad(potential_with_grad_fn)
vmapped_fxn_grad = torch.vmap(grad_fxn_grad, in_dims=0, randomness="different")

# Work on a copy of the shared batch of base points.
_x_batch = torch.clone(x_batch_base)
print(
    "grad:",
    vmapped_fxn_grad(_x_batch),
)

grad: tensor([[1018.4152, -677.0424],
        [ 601.0931,  581.8644]], device='cuda:0')


### Forces (grad) and Hessian in pseudo-potential

In [8]:
def potential_with_grad_and_hessian_fn(x):
    """Pseudo-potential: GMM potential plus gradient and Hessian magnitudes.

    Computes ``U(x) + ||grad U(x)||_2 + ||hess U(x)||_F`` with
    ``U = energy_function.gmm_potential``.

    Args:
        x: A single point for which ``U(x)`` is a scalar (``torch.func.grad``
            only differentiates scalar outputs). Batches are handled by
            wrapping this function in ``torch.vmap``.

    Returns:
        Scalar tensor holding the sum of the three terms.
    """
    potential = energy_function.gmm_potential
    energy = potential(x)
    grad = torch.func.grad(potential)(x)
    hessian = torch.func.hessian(potential)(x)
    # torch.norm is deprecated; torch.linalg.vector_norm is the documented
    # replacement. With no `dim` it takes the 2-norm over all entries, which
    # for the Hessian is the Frobenius norm that torch.norm returned by default.
    grad_magnitude = torch.linalg.vector_norm(grad)
    hessian_magnitude = torch.linalg.vector_norm(hessian)
    return energy + grad_magnitude + hessian_magnitude

# Evaluate the gradient- and Hessian-augmented potential on a copy of the
# single base point.
_x = torch.clone(x_base)
print(
    "potential with grad and hessian:",
    potential_with_grad_and_hessian_fn(_x),
)

potential with grad and hessian: tensor(2237.5999, device='cuda:0')


In [9]:
# Per-sample gradient of the gradient- and Hessian-augmented potential, using
# the vmapped-grad construction from
# DEM/dem/models/components/score_estimator.py. The potential already contains
# first and second derivatives, so this differentiates through both.
grad_fxn_hessian = torch.func.grad(potential_with_grad_and_hessian_fn)
vmapped_fxn_hessian = torch.vmap(grad_fxn_hessian, in_dims=0, randomness="different")

# Work on a copy of the shared batch of base points.
_x_batch = torch.clone(x_batch_base)
print(
    "grad:",
    vmapped_fxn_hessian(_x_batch),
)

grad: tensor([[-1894.3495, -2516.1553],
        [  601.0931,   581.8644]], device='cuda:0')
