This notebook is adapted from https://colab.research.google.com/drive/1yIlPo5CAjYrqWHeFEZrMlzWNCoNJ6_YP#scrollTo=eQwLElKmaowu

In [19]:
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)
import jax.numpy as jnp
from jax.example_libraries.stax import Dense, Relu, serial
import pandas as pd
import optax

<!-- # The Goal -->

In this notebook, we employ `RealNVP` to simplify the Potential Energy Surface (PES) of $\text{H}_2 \text{O}$ molecule.

First, we express the PES in terms of bond lengths $r_1, r_2$ of two $\text{O-H}$ bonds and bond angle $\theta$ between two $\text{O-H}$ bonds.
(Actually, we made a transformation, using the Morse variable and the cosine.)
$$E = V(\boldsymbol{x}),\quad \boldsymbol{x} = (e^{-r_1/\rho}, e^{-r_2/\rho}, \cos\theta),\ \rho = 1\text{\AA}$$

`RealNVP` allows us to make a invertible transformation, parameterized by $\text{params}$, of the input vector $\boldsymbol{x}$,
$$\boldsymbol{\tau} = f(\boldsymbol{x}|\text{params}).$$
And we assume that the PES is of a simple form in terms of $\boldsymbol{\tau}$.
In this notebook, we take a harmonic potential well, parameterized by $u_0$,
$$E = U(\boldsymbol{\tau}|c) = ||\boldsymbol{\tau}||^2 + u_0.$$

Ideally, we would have
$$E = V(\boldsymbol{x}) = U(f(\boldsymbol{x}|\text{params})|u_0).$$
So we can set our loss function to be the error
$$\mathcal{L}(\text{params}, u_0) = \frac{1}{N} \sum_{\{\boldsymbol{x}_n, E_n\} \in \mathcal{D}} [E_n - U(f(\boldsymbol{x}_n|\text{params})|u_0)]^2.$$

It is clear that the stable configuration is reached at $\boldsymbol{\tau}=\boldsymbol{0}$, and since `RealNVP` is invertible, we know that
$$\boldsymbol{x}^* = f^{-1}(\boldsymbol{0}|\text{params})$$
is the stable configuration.

In [20]:
def layer(transform):

    def init_fun(rng, input_dim):
        cutoff = input_dim // 2
        perm = jnp.arange(input_dim)[::-1]
        params, trans_fun = transform(rng, cutoff, 2 * (input_dim - cutoff))

        def direct_fun(params, inputs):
            lower, upper = inputs[:, :cutoff], inputs[:, cutoff:]

            # log_weight, bias = .split(2, axis=1)
            log_weight, bias = jnp.array_split(trans_fun(params, lower), 2, axis=1)
            upper = upper * jnp.exp(log_weight) + bias

            outputs = jnp.concatenate([lower, upper], axis=1)
            log_det_jacobian = log_weight.sum(-1)
            return outputs[:,perm], log_det_jacobian

        def inverse_fun(params, inputs):
            inputs = inputs[:, perm]
            lower, upper = inputs[:, :cutoff], inputs[:, cutoff:]

            log_weight, bias = jnp.array_split(trans_fun(params, lower), 2, axis=1)
            upper = (upper - bias) * jnp.exp(-log_weight)

            outputs = jnp.concatenate([lower, upper], axis=1)
            log_det_jacobian = log_weight.sum(-1)
            return outputs, log_det_jacobian

        return params, direct_fun, inverse_fun

    return init_fun

In [21]:
def RealNVP(transform, n: int):

    def init_fun(rng, input_dim):

        all_params, direct_funs, inverse_funs = [], [], []
        for _ in range(n):
            rng, layer_rng = jax.random.split(rng)
            init_fun = layer(transform)
            param, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)

            all_params.append(param)
            direct_funs.append(direct_fun)
            inverse_funs.append(inverse_fun)

        def feed_forward(params, apply_funs, inputs):
            log_det_jacobians = jnp.zeros(inputs.shape[:1])
            for apply_fun, param in zip(apply_funs, params):
                inputs, log_det_jacobian = apply_fun(param, inputs)
                log_det_jacobians += log_det_jacobian
            return inputs, log_det_jacobians

        def direct_fun(params, inputs):
            return feed_forward(params, direct_funs, inputs)

        def inverse_fun(params, inputs):
            return feed_forward(reversed(params), reversed(inverse_funs), inputs)

        return all_params, direct_fun, inverse_fun

    return init_fun

In [22]:
def harmonic_potential(tau, u0):
    return jnp.linalg.norm(tau) + u0

In [23]:
def make_error_loss(flow_forward, data_file):
    data = pd.read_csv(data_file, sep='\s+')
    inputs = jnp.array([[jnp.exp(-r1), jnp.exp(-r2), jnp.cos(theta*jnp.pi/180)] for r1, r2, theta in zip(data["r1"], data["r2"], data["theta"])])
    energy = jnp.array(data["energy"]) + 76
    batch_decoupled_energy = jax.vmap(harmonic_potential, (0, None), 0)

    def loss(params, u0):
        outputs, _ = flow_forward(params, inputs)
        decoupled_energy = batch_decoupled_energy(outputs, u0)
        return jnp.mean( (decoupled_energy - energy) ** 2 )
    
    return loss

The data are generated from _ab initio_ calculations.

See [J. Chem. Phys. 106, 4618–4639 (1997)](https://doi.org/10.1063/1.473987) for details.

The dataset used here is adapted from supplementary data of this article.

In [24]:
batchsize = 8192
n = 1
dim = 3
nlayers = 3
rng = jax.random.PRNGKey(42)

def transform(rng, cutoff: int, other: int):
            net_init, net_apply = serial(Dense(16), Relu, Dense(16), Relu, Dense(other))
            in_shape = (-1, cutoff)
            out_shape, net_params = net_init(rng, in_shape)
            return net_params, net_apply

flow_init = RealNVP(transform, nlayers)

init_rng, rng = jax.random.split(rng)
params, flow_forward, flow_inverse = flow_init(init_rng, 3)

loss = make_error_loss(flow_forward, "/Users/longli/pycode/ml4p/projects/h2opes/h2opes.txt")
value_and_grad = jax.value_and_grad(loss, argnums=(0, 1), has_aux=False)

params_optimizer = optax.adam(0.01)
params_opt_state = params_optimizer.init(params)

u0 = 0.0
u0_optimizer = optax.adam(0.01)
u0_opt_state = u0_optimizer.init(u0)

In [25]:
@jax.jit
def step(params, u0, params_opt_state, u0_opt_state):
    value, grad = value_and_grad(params, u0)
    params_grad, u0_grad = grad
    # u0_value, u0_grad = u0_value_and_grad(params, u0, z)
    params_updates, params_opt_state = params_optimizer.update(params_grad, params_opt_state)
    u0_updates, u0_opt_state = u0_optimizer.update(u0_grad, u0_opt_state)
    params = optax.apply_updates(params, params_updates)
    u0 = optax.apply_updates(u0, u0_updates)
    return value, params, u0, params_opt_state, u0_opt_state

In [26]:
loss_history = []
for i in range(5000):
    value, params, u0, params_opt_state, u0_opt_state = step(params, u0, params_opt_state, u0_opt_state)
    loss_history.append(value)
    print(i, value)
print(u0)
output = flow_inverse(params, jnp.array([[0, 0, 0]]))

0 0.9792184582951899
1 0.6712674039886329
2 0.489768217802322
3 0.3858051276461103
4 0.3217176718679287
5 0.2760783916340585
6 0.23825851967365702
7 0.20261518560781083
8 0.16839980275386734
9 0.1366079707058247
10 0.10762406785456805
11 0.08305285643940592
12 0.06426006410616797
13 0.049531483170920286
14 0.03823303142281727
15 0.029556565401276762
16 0.02363103824975295
17 0.019579256273682476
18 0.016351021504496355
19 0.013306583163794771
20 0.010446712699577273
21 0.008105997209184328
22 0.006555177171362793
23 0.0058033264411172245
24 0.00559805016401512
25 0.0056158691293088986
26 0.005632021371865649
27 0.005579409711953421
28 0.005515770596110218
29 0.0055123041636549654
30 0.005598362150792229
31 0.005756941703790708
32 0.005906866029139674
33 0.006006422545356963
34 0.006026149960444063
35 0.005997093132351486
36 0.0059261715415817266
37 0.005840655312853092
38 0.005752737737226428
39 0.0056985635092077
40 0.005675327632176208
41 0.005653337452529845
42 0.005657345217259723


In [34]:
print(-jnp.log(output[0][0][0]), -jnp.log(output[0][0][1]), jnp.acos(output[0][0][2]) * 180 / jnp.pi)

0.9501984663086563 1.024407421889236 105.57954257857459


The output is $r_1, r_2, \theta = 0.95\text{\AA}, 1.02\text{\AA}, 105.6^{\circ}$.

The actual configuration is $r_1, r_2, \theta = 0.95\text{\AA}, 0.95\text{\AA}, 104.5^{\circ}$.