In [None]:
# Environment setup: GPU (CUDA 12) build of JAX, then the libraries used below.
! pip install -U "jax[cuda12]"
# CPU-only alternative:
# ! pip install -U jax
! pip install -U  diffrax flax distrax optax

Collecting jax[cuda12]
  Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.0,>=0.5.0 (from jax[cuda12])
  Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Collecting jax-cuda12-plugin<=0.5.0,>=0.5.0 (from jax-cuda12-plugin[with_cuda]<=0.5.0,>=0.5.0; extra == "cuda12"->jax[cuda12])
  Downloading jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.5.0 (from jax-cuda12-plugin<=0.5.0,>=0.5.0->jax-cuda12-plugin[with_cuda]<=0.5.0,>=0.5.0; extra == "cuda12"->jax[cuda12])
  Downloading jax_cuda12_pjrt-0.5.0-py3-none-manylinux2014_x86_64.whl.metadata (348 bytes)
Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with_cuda]<=0.5.0,>=0.5.0; extra == "cuda12"->jax[cuda12])
  Downloading nvidia_cuda_nvcc_cu12-12.8.61-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl (102.0 MB)
[2K

In [None]:
import distrax
import jax
import jax.numpy as jnp
import jax.random as jrnd
import optax
from flax import nnx

import diffrax
from diffrax import ODETerm, Euler, Dopri5, AbstractSolver, diffeqsolve
# from tensorflow_probability.substrates import jax as tfp

import matplotlib
import matplotlib.pyplot as plt
from typing import Any, Union,Callable
import chex
from optax import ema
from jax._src import prng

from distrax import MultivariateNormalDiag

from jax import lax, vmap
# Enable 64-bit precision in JAX (the energies printed by the training loop are float64).
jax.config.update("jax_enable_x64", True)

In [None]:
# Container used to accumulate the values of all functionals.
@chex.dataclass
class F_values:
    """Total energy and its individual contributions for one optimization step.

    ``grad_loss`` fills every field with a batch mean, and an all-zero instance
    seeds the exponential moving average used for logging.
    """
    energy: chex.ArrayDevice  # mean of the summed kinetic + nuclear + Hartree terms
    kin: chex.ArrayDevice  # mean of the Thomas-Fermi kinetic term
    vnuc: chex.ArrayDevice  # mean of the nuclear attraction term
    hart: chex.ArrayDevice  # mean of the soft-Coulomb (Hartree) term

The following equation defines the augmented vector field used in continuous normalizing flows.

$$
\partial_t \begin{bmatrix}
\mathbf{z}(t) \\
\log \rho_\phi(\mathbf{z}(t))
\end{bmatrix} =\begin{bmatrix}
g_\phi(\mathbf{z}(t),t) \\
-\nabla_{\mathbf{z}} \cdot g_\phi(\mathbf{z}(t),t)
\end{bmatrix},
$$
where $g_\phi(\mathbf{z}(t),t)$ is the neural network that parametrizes the vector field, and the second row tracks the change in log-density (the change of volume) along the flow.

In [None]:
class Flow(nnx.Module):
  """Time-conditioned MLP ``g_phi(x, t)`` that parametrizes the flow's vector field.

  The network maps ``din + 1`` input features (coordinates with the time
  appended) to ``din`` outputs through an input layer, three hidden layers of
  width ``dim`` and an output layer, with ``tanh`` after every layer except the
  last one.
  """

  def __init__(self, din: int, dim: int, rngs: nnx.Rngs):
    """Create the layers.

    Parameters
    ----------
    din : int
        Dimension of the coordinates (time is appended as one extra feature).
    dim : int
        Width of the hidden layers.
    rngs : nnx.Rngs
        Random number streams used to initialize the layers.
    """
    self.din = din
    self.dim = dim
    # Layers are built in the order they are applied: input, hidden, output.
    self.linear_in = nnx.Linear(din + 1, dim, rngs=rngs)
    hidden_layers = []
    for _ in range(3):
      hidden_layers.append(nnx.Linear(dim, dim, rngs=rngs))
    self.blocks = hidden_layers
    self.linear_out = nnx.Linear(dim, din, rngs=rngs)

  def __call__(self, x, t):
    """Evaluate the vector field at coordinates ``x`` and time ``t``.

    ``t`` is concatenated to ``x`` along the last axis, so both must agree on
    every other axis.
    """
    features = jnp.concatenate([x, t], axis=-1)
    hidden = jnp.tanh(self.linear_in(features))
    for layer in self.blocks:
      hidden = jnp.tanh(layer(hidden))
    return self.linear_out(hidden)

class CNF(nnx.Module):
  """Augmented dynamics of a continuous normalizing flow.

  For a single sample, the state is the concatenation of the coordinates
  (``din`` entries) and the running log-density (last entry). The time
  derivative returned by ``__call__`` is ``[g(x, t), -div_x g(x, t)]``, where
  ``g`` is the wrapped ``Flow`` network.
  """

  def __init__(self, din: int, dim: int, rngs: nnx.Rngs):
    """Build the underlying ``Flow`` network.

    Parameters
    ----------
    din : int
        Dimension of the coordinates.
    dim : int
        Width of the hidden layers of the network.
    rngs : nnx.Rngs
        Random number streams used to initialize the network.
    """
    self.din, self.dim = din, dim
    self.flow = Flow(din, dim, rngs)

  def __call__(self, states, t):
    """Return the time derivative of the augmented state of one sample.

    Parameters
    ----------
    states : Array
        1-D array of length ``din + 1``: coordinates followed by the
        log-density. The log-density entry does not enter the dynamics.
    t : Array
        Time, shaped so that it can be concatenated to the coordinates.

    Returns
    -------
    Array
        1-D array of length ``din + 1``: the velocity ``g(x, t)`` followed by
        minus the divergence of ``g`` with respect to ``x``.
    """
    x = states[:-1]
    dz, f_vjp = jax.vjp(self.flow, x, t)
    # Exact divergence = trace of the Jacobian d g / d x. Pulling back each
    # output basis vector yields one Jacobian row; a single pull-back of an
    # all-ones vector would instead sum *all* Jacobian entries, which matches
    # the trace only when din == 1.
    jac_x, _ = jax.vmap(f_vjp)(jnp.eye(self.din, dtype=dz.dtype))
    divergence = jnp.trace(jac_x)

    return jnp.concatenate([dz, -divergence[None]], axis=-1)

# Size of the problem: 1-D coordinates and the hidden width of the network.
# Note that model_dim is also passed as the batch size to batch_generator further down.
data_dim: int = 1
model_dim: int = 264
rngs = nnx.Rngs(0)
# NOTE(review): `flow` is not referenced again in the visible cells (CNF builds its own
# Flow). It is presumably still drawing from `rngs`, so removing it would change the
# initial weights of `flow_model` — confirm before deleting.
flow = Flow(data_dim, model_dim, rngs)
flow_model = CNF(data_dim, model_dim, rngs)

@nnx.vmap(in_axes=(None, 0, 0), out_axes=0)
def forward(model, x, t):
  """Apply ``model`` row by row to a batch.

  ``in_axes=(None, 0, 0)`` shares the same ``model`` across the batch and maps
  over the leading axis of both ``x`` and ``t``; the per-row outputs are stacked
  along the leading axis (``out_axes=0``).
  """
  return model(x,t)

This function generates random samples from the prior distribution which are used to compute the expectation value of the functionals.

In [None]:
def batch_generator(key: jax.Array, batch_size: int, prior_dist: Any):
    """
    Generator that yields batches of samples from the prior distribution.

    Every batch stacks two independent draws of ``batch_size`` samples each, so
    that the first and second halves can be used as independent pairs.

    Parameters
    ----------
    key : jax.Array
        PRNG key that seeds the (infinite) stream of batches.
    batch_size : int
        Number of samples in each of the two draws.
    prior_dist : Any
        Prior distribution object exposing
        ``sample(seed=..., sample_shape=...)`` and ``log_prob(samples)``.

    Yields
    ------
    jax.Array
        Array with ``2 * batch_size`` rows; each row holds one sample followed
        by its log-probability under the prior in the last column.
    """

    def _draw(subkey):
        # One draw: samples with their prior log-probability appended as a column.
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp_samples = prior_dist.log_prob(samples)
        return lax.concatenate((samples, logp_samples[:, None]), 1)

    while True:
        # A key must never be used both for sampling and for further splitting:
        # carry `key` forward and consume two fresh subkeys per batch.
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_draw(key0), _draw(key1)), 0)

# Root PRNG key of the notebook; it later seeds the batch generator.
key = jrnd.PRNGKey(0)
_,key = jrnd.split(key)

# information about LiH
Ne = 2 # Number of valence electrons
Z_alpha = 3 # Atomic number of Li
Z_beta = 1 # Atomic number of H
R =10. # Interatomic distance (presumably in atomic units — TODO confirm)

# Exponential moving average (decay 0.99) of the energy terms, used for logging
# in the training loop; its state starts from an all-zero F_values.
energies_ema = ema(decay=0.99)
energies_state = energies_ema.init(
    F_values(energy=jnp.array(0.), kin=jnp.array(0.), vnuc=jnp.array(0.),  hart = jnp.array(0.)))

# NOTE(review): base_dist is not referenced in the visible cells; sampling and the
# density evaluation both use `prior_dist`, which is defined next to the optimizer.
base_dist = distrax.MultivariateNormalDiag(jnp.array([0.]), jnp.array([1.]))

These functions are the functionals from our paper, Eqs. 11 to 21 in the Supplemental Information.
[Paper link](https://arxiv.org/pdf/2404.08764)

In [None]:
def thomas_fermi_1D(den: Any, Ne: int, c: float=(jnp.pi*jnp.pi)/24) -> jax.Array:
    r"""
    Pointwise integrand of the 1D Thomas-Fermi kinetic functional.
    See eq. 18 in https://pubs.aip.org/aip/jcp/article/139/22/224104/193579/Orbital-free-bond-breaking-via-machine-learning

    With the electron density written as :math:`N_e\,\rho(x)` and
    :math:`\rho` normalized to one,

    .. math::
        T_{\text{TF}} = \frac{\pi^2}{24} \int \left(N_e\,\rho(x)\right)^{3} \mathrm{d}x
                      = \frac{\pi^2}{24}\, N_e^{3}\; \mathbb{E}_{x \sim \rho}\left[\rho(x)^{2}\right],

    so averaging the returned values over samples drawn from :math:`\rho`
    gives a Monte Carlo estimate of the kinetic energy.

    Parameters
    ----------
    den : Array
        Normalized density evaluated at the sample points.
    Ne : int
        Number of electrons.
    c : float, optional
        Multiplication constant, by default (jnp.pi*jnp.pi)/24

    Returns
    -------
    jax.Array
        ``c * Ne**3 * den**2``, evaluated elementwise.
    """
    prefactor = c * (Ne ** 3)
    return prefactor * (den * den)

def soft_coulomb(x:Any,xp:Any,Ne: int) -> jax.Array:
    r"""
    Soft-Coulomb interaction between two points, scaled by ``Ne**2``.

    See eq 6 in https://pubs.aip.org/aip/jcp/article/139/22/224104/193579/Orbital-free-bond-breaking-via-machine-learning

    .. math::
        N_e^{2}\, \frac{1}{\sqrt{1 + (x - x')^{2}}}

    Parameters
    ----------
    x : Any
        First point of the pair.
    xp : Any
        Second point of the pair; must broadcast against ``x``.
    Ne : int
        Number of electrons.

    Returns
    -------
    jax.Array
        Soft version of the Coulomb potential, evaluated elementwise.
    """
    separation = x - xp
    kernel = 1/(jnp.sqrt(1 + separation*separation))
    return kernel*Ne**2

def attraction(x:Any, R:float, Z_alpha:int, Z_beta:int, Ne: int) -> jax.Array:
    """
    Soft-Coulomb attraction exerted at ``x`` by two nuclei, scaled by ``Ne``.

    The nucleus of charge ``Z_alpha`` is centred at ``-R/2`` and the nucleus of
    charge ``Z_beta`` at ``+R/2``.

    See eq 7 in https://pubs.aip.org/aip/jcp/article/139/22/224104/193579/Orbital-free-bond-breaking-via-machine-learning

    Parameters
    ----------
    x : Any
        A point where the potential is evaluated.
    R : float
        Distance between the two nuclei.
    Z_alpha : int
        Atomic number of the first nucleus.
    Z_beta : int
        Atomic number of the second nucleus.
    Ne : int
        Number of electrons.

    Returns
    -------
    jax.Array
        Attraction to the nuclei of charges Z_alpha and Z_beta (non-positive
        for non-negative charges), evaluated elementwise.
    """
    half_R = R/2
    v_alpha = - Z_alpha/(jnp.sqrt(1 + (x + half_R)**2))
    v_beta = Z_beta/(jnp.sqrt(1 + (x - half_R)**2))
    return (v_alpha - v_beta)*Ne

In continuous normalizing flows, we can move from the base distribution ($p_0(z)$) to the target ($p_x(x)$). \\
For this we need to run the joint ODE in forward or reverse order.  

In [None]:
def rev_ode(flow_model, z_and_logpz):
  """Integrate the augmented ODE backwards in time, from t=1 to t=0.

  Parameters
  ----------
  flow_model
      Model evaluated row by row through ``forward``.
  z_and_logpz : Array
      Batch of augmented states at t=1; the last column is the log-density
      component of the state.

  Returns
  -------
  tuple
      ``(coordinates, log_term)`` at t=0, i.e. all but the last column and the
      last column (kept 2-D) of the final state.
  """
  t_begin, t_end = 1., 0.

  def vector_field(t, state, args):
    # Every row of the batch is evaluated at the same time t.
    times = t*jnp.ones((state.shape[0],1))
    return forward(flow_model, state, times)

  controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
  save_points = diffrax.SaveAt(ts=jnp.array([1., 0.]))
  solution = diffeqsolve(ODETerm(vector_field), diffrax.Tsit5(),
                         t_begin, t_end, t_end - t_begin, z_and_logpz,
                         stepsize_controller=controller, saveat=save_points)
  # ys is indexed by save time; the last entry corresponds to t=0.
  final_state = solution.ys[-1,:,:]
  return final_state[:,:-1], final_state[:,-1:]

def fwd_ode(flow_model, x_and_logpx):
  """Integrate the augmented ODE forwards in time, from t=0 to t=1.

  Parameters
  ----------
  flow_model
      Model evaluated row by row through ``forward``; it is switched to
      evaluation mode before integrating.
  x_and_logpx : Array
      Batch of augmented states at t=0; the last column is the log-density
      component of the state.

  Returns
  -------
  tuple
      ``(coordinates, log_term)`` at t=1, i.e. all but the last column and the
      last column (kept 2-D) of the final state.
  """
  t_begin, t_end = 0., 1.
  flow_model.eval()

  def vector_field(t, state, args):
    # Every row of the batch is evaluated at the same time t.
    times = t*jnp.ones((state.shape[0],1))
    return forward(flow_model, state, times)

  controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
  save_points = diffrax.SaveAt(ts=jnp.array([0., 1.]))
  solution = diffeqsolve(ODETerm(vector_field), diffrax.Tsit5(),
                         t_begin, t_end, t_end - t_begin, x_and_logpx,
                         stepsize_controller=controller, saveat=save_points)
  # ys is indexed by save time; the last entry corresponds to t=1.
  final_state = solution.ys[-1,:,:]
  return final_state[:,:-1], final_state[:,-1:]

In [None]:
# compute the density using numerical integration.
def rho_rev(model, x):
  """Evaluate the model density at the points ``x``.

  The points are augmented with a zero log-term, pushed back to the prior
  with ``rev_ode``, and the density follows from the prior log-probability of
  the resulting latent points minus the accumulated log-term.

  Parameters
  ----------
  model
      Flow passed on to ``rev_ode``.
  x : Array
      2-D array with one point per row.

  Returns
  -------
  Array
      Density values, one row per point, with a trailing axis of length 1.
  """
  zero_log_term = jnp.zeros((x.shape[0],1))
  augmented = jnp.concatenate([x, zero_log_term], axis=-1)
  z0, delta_logp = rev_ode(model, augmented)
  log_density = prior_dist.log_prob(z0)[:, None] - delta_logp
  return jnp.exp(log_density)

def integral(model, x_and_logpx):
  """Integrate the model density over a 1-D grid with the trapezoidal rule.

  Useful as a normalization check: the result should be close to 1 when the
  grid covers the support of the density.

  Parameters
  ----------
  model
      Flow passed on to ``rho_rev``.
  x_and_logpx : Array
      Grid of positions with shape ``(n_points, 1)``. Despite the legacy name
      it holds positions only (no log-density column): ``rho_rev`` appends that
      column itself.

  Returns
  -------
  tuple
      ``(area, p_x)``: the trapezoidal estimate of the integral and the density
      on the grid as returned by ``rho_rev``.
  """
  x = x_and_logpx
  p_x = rho_rev(model, x)
  # Integrate against the actual coordinates. Slicing off the last column of a
  # positions-only grid leaves an empty array and makes the spacing shape (0,),
  # which cannot broadcast against the density values.
  return jnp.trapezoid(p_x.flatten(), x=x[:, 0]), p_x

In [None]:
# energy optimization
def grad_loss(model, z_and_logpz):
  """Monte Carlo estimate of the total energy for one batch of prior samples.

  The batch is pushed through the flow with ``fwd_ode`` and split into two
  halves of equal size. The first half provides the points (and densities) at
  which the functionals are evaluated; the second half provides the partner
  points ``x'`` of the pair interaction, so that every point is paired with its
  own independent sample instead of all points sharing a single one.

  Parameters
  ----------
  model
      Flow passed on to ``fwd_ode``.
  z_and_logpz : Array
      Batch of prior samples with their log-probability in the last column.
      It must contain at least two rows; with an odd number of rows the last
      one is ignored.

  Returns
  -------
  tuple
      ``(energy, f_values)``: the mean total energy and an ``F_values`` with
      the mean of every contribution.
  """
  x_all, log_px = fwd_ode(model, z_and_logpz)
  den_all = jnp.exp(log_px)

  # First half: evaluation points; second half: independent partners x'.
  half = x_all.shape[0] // 2
  den = den_all[:half]
  x, xp = x_all[:half], x_all[half:2 * half]

  # evaluate all the functionals locally F[x_i, \rho(x_i)]
  e_t = thomas_fermi_1D(den, Ne)
  e_h = soft_coulomb(x, xp, Ne)
  e_nuc_v = attraction(x, R, Z_alpha, Z_beta,Ne)
  e = e_t + e_nuc_v + e_h

  energy = jnp.mean(e)

  f_values = F_values(energy=energy,
                      kin=jnp.mean(e_t),
                      vnuc=jnp.mean(e_nuc_v),
                      hart=jnp.mean(e_h),
                      )

  return energy, f_values

@nnx.jit
def train_step(flow_model: Flow, optimizer: nnx.Optimizer, x_and_logpx):
  """Run one optimization step on a batch.

  Evaluates ``grad_loss`` and its gradient with respect to ``flow_model``,
  then lets ``optimizer`` apply the update.

  Returns
  -------
  tuple
      ``((energy, f_values), optimizer)``: the output of ``grad_loss`` (loss
      and auxiliary values) and the optimizer.
  """
  value_and_grad_fn = nnx.value_and_grad(grad_loss, has_aux=True)
  loss_and_aux, param_grads = value_and_grad_fn(flow_model, x_and_logpx)
  optimizer.update(param_grads)
  return loss_and_aux, optimizer

# Learning rate starts at 2e-3 and is multiplied by 0.95 at every optimizer step.
lr = optax.schedules.exponential_decay(2e-3, transition_steps = 1, decay_rate = 0.95)
# Clip the global gradient norm to 1 before the AdamW update.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(lr,weight_decay=1E-5)
        )
optimizer = nnx.Optimizer(flow_model, tx)

# Standard normal prior in 1-D; also read as a global by rho_rev.
prior_dist = MultivariateNormalDiag(jnp.zeros(1), 1.*jnp.ones(1))

# model_dim doubles as the size of each of the two draws stacked in a batch.
gen_batches = batch_generator(key, model_dim, prior_dist)

for itr in range(3):
    # NOTE(review): this new key is not read in the visible cells; gen_batches
    # received its key when it was created.
    _,key = jrnd.split(key)
    batch = next(gen_batches) # generate a random sample from p_z

    loss_value, optimizer = train_step(flow_model, optimizer, batch) # compute the energy
    # loss_value is the (energy, F_values) pair returned by grad_loss.
    loss_epoch, losses = loss_value
    # Smooth the noisy per-batch energies for reporting.
    energies_i_ema, energies_state = energies_ema.update(
            losses, energies_state)
    ei_ema = energies_i_ema.energy

    r_ema = {'epoch': itr,
                 'E': energies_i_ema.energy,
                 'T': energies_i_ema.kin, 'V': energies_i_ema.vnuc, 'H': energies_i_ema.hart

                 }
    print( r_ema)


{'epoch': 0, 'E': Array(1.57444175, dtype=float64), 'T': Array(0.23048938, dtype=float64), 'V': Array(-1.6199868, dtype=float64), 'H': Array(2.96393917, dtype=float64)}
{'epoch': 1, 'E': Array(-0.13409486, dtype=float64), 'T': Array(0.18302722, dtype=float64), 'V': Array(-2.65111808, dtype=float64), 'H': Array(2.333996, dtype=float64)}
{'epoch': 2, 'E': Array(-0.52899626, dtype=float64), 'T': Array(0.19273305, dtype=float64), 'V': Array(-2.85198934, dtype=float64), 'H': Array(2.13026003, dtype=float64)}


In [None]:
# Sanity check: integrate the learned density on a uniform grid of 1000 points in
# [-12, 12], which covers both nuclei (located at -R/2 and +R/2 with R = 10).
# x_grid is a column of positions only, shape (1000, 1).
x_grid = jnp.linspace(-12,12,1000)[:,None]
norm_val, rho_pred = integral(flow_model,x_grid)
print(norm_val)

TypeError: mul got incompatible shapes for broadcasting: (0,), (999,).

In [None]:
# Compare the density learned by the flow with the prior it starts from.
plt.plot(x_grid,rho_pred.flatten(),label='Normalizing Flow')
# Prior density evaluated on the same grid.
p0 = prior_dist.prob(x_grid)
plt.plot(x_grid,p0,label='Base distribution')
plt.legend()