# Building Neural Functionals

In this tutorial, we show how to build a parameterized neural functional. The workflow is a modified version of the one in `~/examples/basic_notebooks/example_lda_functional_02.ipynb`, except that the coefficient function $\mathbf{c}_{\boldsymbol{\theta}}[\rho](\mathbf{r})$ is no longer a constant but a basic neural network.
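For reference, the exchange-correlation energy of such a functional combines the coefficients with a vector of energy densities $\mathbf{e}[\rho](\mathbf{r})$ through a dot product integrated over space. This is our summary of the Grad DFT functional form, not a quote from its documentation:

$$
E_{xc}^{\boldsymbol{\theta}}[\rho] = \int \mathbf{c}_{\boldsymbol{\theta}}[\rho](\mathbf{r}) \cdot \mathbf{e}[\rho](\mathbf{r}) \, d\mathbf{r}
$$

In this notebook, $\mathbf{e}[\rho]$ has a single component (the LSDA exchange energy density), so $\mathbf{c}_{\boldsymbol{\theta}}$ is a single number at each point $\mathbf{r}$.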

As before, we begin with a PySCF calculation, this time an unrestricted Kohn-Sham DFT calculation of $\text{H}_2$ using PySCF's default LDA exchange-correlation functional.


In [1]:
from pyscf import gto, dft
import grad_dft as gd

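# Define the geometry of H2 (PySCF uses Angstrom by default, so the bond length is 1 Angstrom)
mol = gto.M(atom=[["H", (0, 0, 0)], ["H", (0, 0, 1)]], basis="def2-tzvp", charge=0, spin=0)
# Unrestricted Kohn-Sham DFT; since mf.xc is not set, PySCF uses its default LDA functional
mf = dft.UKS(mol)
ground_truth_energy = mf.kernel()

# Convert the converged PySCF calculation into a Grad DFT Molecule object
HH_molecule = gd.molecule_from_pyscf(mf)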



converged SCF energy = -1.11599939445016  <S^2> = 4.4408921e-16  2S+1 = 1


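Next, we define the ingredients of our `NeuralFunctional`. You will recognize them from when we built a basic `Functional` in `~/examples/basic_notebooks/example_lda_functional_02.ipynb`: a function returning the inputs of the coefficient function, and a function returning the energy densities.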

In [2]:
import jax.numpy as jnp

def coefficient_inputs(molecule: gd.Molecule, *_, **__):
    r"""Inputs of the coefficient network: density and kinetic energy density, stacked feature-wise."""
    rho = molecule.density()
    kinetic = molecule.kinetic_density()
    return jnp.concatenate((rho, kinetic), axis=1)

def energy_densities(molecule: gd.Molecule, clip_cte: float = 1e-30, *_, **__):
    r"""Auxiliary function to generate the features of LSDA."""
    # The Molecule object computes the spin-resolved electronic density on the grid.
    # Clip it from below so that rho**(4/3) stays well defined.
    rho = jnp.clip(molecule.density(), a_min=clip_cte)
    # LSDA exchange energy density from the paper, summed over the spin channels.
    lda_e = -3/2 * (3/(4*jnp.pi)) ** (1/3) * (rho**(4/3)).sum(axis=1, keepdims=True)
    # For simplicity we do not include the exchange polarization correction;
    # see the function exchange_polarization_correction in functional.py.
    # The output must be an array of shape (n_grid, n_features); here n_features = 1.
    return lda_e
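The two functions above can be mirrored in plain numpy to check array shapes and the LSDA formula. This is a minimal sketch: `toy_coefficient_inputs` and `toy_lsda_exchange` are hypothetical names, and the random arrays stand in for `molecule.density()` and `molecule.kinetic_density()`, assuming one column per spin channel.

```python
import numpy as np

# Hypothetical toy data standing in for molecule.density() and
# molecule.kinetic_density(): one column per spin channel (an assumption).
rng = np.random.default_rng(0)
n_grid = 5
rho = rng.uniform(0.0, 1.0, size=(n_grid, 2))
tau = rng.uniform(0.0, 1.0, size=(n_grid, 2))

def toy_coefficient_inputs(rho, tau):
    """Mirror of coefficient_inputs: stack the two densities feature-wise."""
    return np.concatenate((rho, tau), axis=1)

def toy_lsda_exchange(rho, clip_cte=1e-30):
    """Mirror of energy_densities: LSDA exchange energy density per grid point."""
    rho = np.clip(rho, clip_cte, None)
    return -1.5 * (3 / (4 * np.pi)) ** (1 / 3) * (rho ** (4 / 3)).sum(axis=1, keepdims=True)

inputs = toy_coefficient_inputs(rho, tau)
e_x = toy_lsda_exchange(rho)
print(inputs.shape, e_x.shape)  # (5, 4) (5, 1)

# Sanity check: for a closed-shell density (rho_up = rho_down = n / 2) the
# spin-resolved form reduces to the spin-unpolarized -(3/4)(3/pi)^(1/3) n^(4/3).
n = rng.uniform(0.1, 1.0, size=(n_grid, 1))
closed_shell = np.concatenate((n / 2, n / 2), axis=1)
unpolarized = -0.75 * (3 / np.pi) ** (1 / 3) * n ** (4 / 3)
print(np.allclose(toy_lsda_exchange(closed_shell), unpolarized))  # True
```

The closed-shell check is a convenient way to catch a wrong prefactor: the spin-resolved constant $-\tfrac{3}{2}(3/4\pi)^{1/3}$ only reproduces the textbook unpolarized LDA exchange when the density is split equally between the spins.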

This time, however, the coefficient function is a simple neural network built with `flax` and `jax.nn`: a dense layer followed by layer normalization and a sigmoid activation.

In [3]:
from flax import linen as nn
from jax.nn import sigmoid

out_features = 1  # a single coefficient, matching the single energy density above
def coefficients(instance, rhoinputs):
    r"""
    instance: the Functional or NeuralFunctional instance calling this function.
    rhoinputs: the input to the neural network, i.e. the array returned by coefficient_inputs.

    Returns the coefficients c_\theta[\rho](r) on the grid, an array of shape
    (n_grid, out_features) with values between 0 and 1 because of the final sigmoid.
    """
    # Dense layer, then layer normalization, then sigmoid activation
    x = nn.Dense(features=out_features)(rhoinputs)
    x = nn.LayerNorm()(x)
    return sigmoid(x)
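One property of this particular architecture is worth checking. Assuming flax's `nn.LayerNorm` defaults (normalization over the last, feature axis, then a learnable scale initialized to ones and a bias initialized to zeros), a layer with a single output feature is normalized against its own mean, so the normalized value is exactly zero and the coefficient reduces to the sigmoid of the LayerNorm bias, whatever the input. The sketch below is a hand-written numpy re-implementation of Dense, LayerNorm and sigmoid (hypothetical helper names, not flax code) that illustrates this:

```python
import numpy as np

def np_sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def toy_coefficients(rhoinputs, w, b, scale, bias, eps=1e-6):
    """Hypothetical re-implementation of Dense -> LayerNorm -> sigmoid."""
    x = rhoinputs @ w + b                   # Dense: (n_grid, n_in) -> (n_grid, 1)
    mean = x.mean(axis=-1, keepdims=True)   # LayerNorm statistics over the feature axis
    var = x.var(axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps) * scale + bias
    return np_sigmoid(x)

rng = np.random.default_rng(1)
n_grid, n_in = 6, 4
w = rng.normal(size=(n_in, 1))   # random Dense kernel
b = rng.normal(size=(1,))        # random Dense bias
scale = np.ones((1,))            # LayerNorm scale (default init: ones)
bias = np.zeros((1,))            # LayerNorm bias (default init: zeros)

# Two unrelated random inputs give the same coefficient everywhere.
c1 = toy_coefficients(rng.uniform(size=(n_grid, n_in)), w, b, scale, bias)
c2 = toy_coefficients(rng.uniform(size=(n_grid, n_in)), w, b, scale, bias)
print(np.unique(c1), np.unique(c2))  # [0.5] [0.5]
```

With `out_features = 1`, the coefficient therefore starts at 0.5 on every grid point and does not depend on `rhoinputs`; only the LayerNorm bias changes it. Letting the coefficient vary in space would require normalizing over more than one feature, or dropping the normalization.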

We can now combine these three functions into a `NeuralFunctional` instance:

In [4]:
nf = gd.NeuralFunctional(coefficients, energy_densities, coefficient_inputs)

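Next, we initialize `nf` with a set of parameters. As with a `flax` module, `init` takes a random key and an example input, here the coefficient inputs computed for $\text{H}_2$: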

In [5]:
from jax.random import PRNGKey

key = PRNGKey(42)
cinputs = coefficient_inputs(HH_molecule)
params = nf.init(key, cinputs)

We can then compute the total energy for these randomly initialized parameters in the same way as for a regular `Functional` instance:

In [6]:
E = nf.energy(params, HH_molecule)
print("Neural functional energy with random parameters is", E)

Neural functional energy with random parameters is -0.7769992
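As we understand it, the exchange-correlation part of this energy comes from integrating the product of coefficients and energy densities over the molecular grid, $E_{xc} \approx \sum_i w_i\,\mathbf{c}_{\boldsymbol{\theta}}(\mathbf{r}_i)\cdot\mathbf{e}(\mathbf{r}_i)$. A minimal numpy sketch of that quadrature, with hypothetical toy arrays standing in for the grid weights and for the outputs of `coefficients` and `energy_densities` (our illustration, not Grad DFT's implementation):

```python
import numpy as np

rng = np.random.default_rng(2)
n_grid, n_features = 8, 1
weights = rng.uniform(0.0, 0.1, size=n_grid)                    # hypothetical weights w_i
coeffs = np.full((n_grid, n_features), 0.5)                      # c_theta(r_i), constant here
e_density = -rng.uniform(0.0, 1.0, size=(n_grid, n_features))    # e(r_i), negative like LSDA exchange

def xc_energy(weights, coeffs, e_density):
    """Grid quadrature of E_xc = sum_i w_i * c(r_i) . e(r_i)."""
    per_point = (coeffs * e_density).sum(axis=1)  # dot product over the feature axis
    return float(weights @ per_point)

E_half = xc_energy(weights, coeffs, e_density)
E_full = xc_energy(weights, np.ones_like(coeffs), e_density)
print(np.isclose(E_half, 0.5 * E_full))  # True
```

A spatially constant coefficient simply rescales the integrated energy density, which is why a coefficient of 0.5 halves it in the final check.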


More complex `NeuralFunctional`s are defined in `intermediate_notebooks` and `advanced_scripts`.

## Basic Neural Functional Training

The simplest way to train a `NeuralFunctional` is to fit it to the total energy from a high-accuracy calculation such as full CI, CCSD or CISD. Here, we instead use the LDA total energy computed with PySCF above as a "dummy" high-accuracy target.

We first create an Adam optimizer for training:

In [7]:
from optax import adam

learning_rate = 0.01
momentum = 0.9
tx = adam(learning_rate=learning_rate, b1=momentum)
opt_state = tx.init(params)
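In optax's `adam`, `b1` (called `momentum` above) is the exponential decay rate of the first-moment estimate and `b2` that of the second-moment estimate; 0.9 is also optax's default for `b1`. The numpy sketch below writes out one bias-corrected step of the standard Adam algorithm by hand (it is not optax's source code; `adam_step` and `theta` are hypothetical names). It shows that the very first step moves each parameter by roughly `learning_rate`, whatever the scale of its gradient:

```python
import numpy as np

def adam_step(theta, grads, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a flat parameter array; the step counter t starts at 1."""
    m = b1 * m + (1 - b1) * grads            # first-moment (momentum) estimate
    v = b2 * v + (1 - b2) * grads ** 2       # second-moment estimate
    m_hat = m / (1 - b1 ** t)                # bias corrections
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.zeros(3)
grads = np.array([0.01, 2.0, -50.0])         # gradients of very different scales
m = np.zeros_like(theta)
v = np.zeros_like(theta)
new_theta, m, v = adam_step(theta, grads, m, v, t=1)
# Each entry is approximately -lr * sign(grad).
print(new_theta)
```

This per-parameter normalization is one reason the predicted energy in the training loop moves by a similar amount at every iteration.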

Next, we create the simplest kind of predictor: the non-SCF predictor. It assumes that the charge density $\rho(\mathbf{r})$ passed to the Grad DFT `Molecule` object is a good approximation to the ground-state density of the neural functional for all neural network parameters $\boldsymbol{\theta}$. The energy is therefore evaluated on this fixed density, without running self-consistent field (SCF) iterations.

In [8]:
predictor = gd.non_scf_predictor(nf)
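To make the non-SCF assumption concrete, here is a toy model in numpy (all numbers and names are hypothetical, not Grad DFT's API): a single parameter `beta` sets a spatially constant coefficient $\mathrm{sigmoid}(\beta)$, while every density-dependent quantity is computed once on the fixed density and never updated. The gradient of the predicted energy then flows only through the coefficient, which a finite-difference check confirms:

```python
import numpy as np

# Hypothetical stand-ins (not values from this notebook), both evaluated once
# on the fixed density: the theta-independent energy terms, and the integrated
# LSDA exchange energy.
E_FIXED = -0.5
E_X = -0.6

def np_sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def non_scf_energy(beta):
    """Toy non-SCF prediction: the density, hence E_FIXED and E_X, never changes with beta."""
    return E_FIXED + np_sigmoid(beta) * E_X

def non_scf_grad(beta):
    """dE/dbeta flows only through the coefficient because the density is held fixed."""
    s = np_sigmoid(beta)
    return s * (1.0 - s) * E_X

beta, h = 0.2, 1e-6
finite_diff = (non_scf_energy(beta + h) - non_scf_energy(beta - h)) / (2 * h)
print(np.isclose(finite_diff, non_scf_grad(beta)))  # True
```

A self-consistent predictor would also update the density as $\boldsymbol{\theta}$ changes, making the other energy terms depend on $\boldsymbol{\theta}$ too.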

We now train for a fixed number of epochs, printing the predicted total energy and the cost at each iteration, and finally save the trained model to a checkpoint.

In [9]:
from tqdm import tqdm
from optax import apply_updates

n_epochs = 20
for iteration in tqdm(range(n_epochs), desc="Training epoch"):
    # Returns the loss, the predicted energy and the gradient of the loss w.r.t. params
    (cost_value, predicted_energy), grads = gd.simple_energy_loss(
        params, predictor, HH_molecule, ground_truth_energy
    )
    print("Iteration", iteration, "Predicted energy:", predicted_energy, "Cost value:", cost_value)
    # Adam step: transform the gradients into updates and apply them to the parameters
    updates, opt_state = tx.update(grads, opt_state, params)
    params = apply_updates(params, updates)

# Save the trained model to a checkpoint
nf.save_checkpoints(params, tx, step=n_epochs)

Training epoch: 100%|██████████| 20/20 [00:02<00:00,  6.85it/s]

Iteration 0 Predicted energy: -0.7769992 Cost value: 0.114921115
Iteration 1 Predicted energy: -0.7782445 Cost value: 0.114078335
Iteration 2 Predicted energy: -0.7794926 Cost value: 0.11323678
Iteration 3 Predicted energy: -0.7807367 Cost value: 0.112401046
Iteration 4 Predicted energy: -0.7819853 Cost value: 0.11156539
Iteration 5 Predicted energy: -0.78322804 Cost value: 0.11073674
Iteration 6 Predicted energy: -0.78447235 Cost value: 0.109910145
Iteration 7 Predicted energy: -0.7857164 Cost value: 0.10908681
Iteration 8 Predicted energy: -0.7869568 Cost value: 0.108269
Iteration 9 Predicted energy: -0.7881994 Cost value: 0.10745279
Iteration 10 Predicted energy: -0.7894391 Cost value: 0.106641605
Iteration 11 Predicted energy: -0.7906773 Cost value: 0.105834424
Iteration 12 Predicted energy: -0.7919148 Cost value: 0.105030775
Iteration 13 Predicted energy: -0.7931471 Cost value: 0.10423358
Iteration 14 Predicted energy: -0.794382 Cost value: 0.10343773
Iteration 15 Predicted energy: -0.7956114 Cost value: 0.102648444
Iteration 16 Predicted energy: -0.796839 Cost value: 0.101863325
Iteration 17 Predicted energy: -0.7980643 Cost value: 0.1010827
Iteration 18 Predicted energy: -0.79928565 Cost value: 0.10030756
Iteration 19 Predicted energy: -0.80050814 Cost value: 0.0995347
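The printed cost values are consistent with `simple_energy_loss` returning the squared error between the predicted and reference energies; this is inferred from the output above, not taken from the Grad DFT source. A quick check with values copied from the first and last iterations and from the PySCF calculation in the first cell:

```python
# Values copied from the outputs printed in this notebook.
reference_energy = -1.11599939445016   # PySCF LDA energy (first cell)
logged = {
    0: (-0.7769992, 0.114921115),      # iteration: (predicted energy, cost value)
    19: (-0.80050814, 0.0995347),
}

for iteration, (predicted, cost) in logged.items():
    squared_error = (predicted - reference_energy) ** 2
    # The two numbers agree up to the float32 precision of the printed values.
    print(iteration, f"{squared_error:.6f}", f"{cost:.6f}")
```

Read this way, the cost falls from about 0.115 to 0.0995 over 20 epochs because the predicted energy moves from -0.777 Ha to -0.801 Ha, still far from the -1.116 Ha target.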



