# Basic Functionals in Grad DFT

In this basic tutorial we will introduce the concept of a functional.

First, we will prepare a `Molecule` instance, as we did in the previous tutorial:

In [1]:
from pyscf import gto, dft
import grad_dft as gd

# Define the geometry of the molecule and mean-field object
mol = gto.M(atom=[["H", (0, 0, 0)]], basis="def2-tzvp", charge=0, spin=1)
mf = dft.UKS(mol)
mf.kernel()
# Then we can use the following function to generate the molecule object
HF_molecule = gd.molecule_from_pyscf(mf)



converged SCF energy = -0.478343887897971  <S^2> = 0.75  2S+1 = 2


## Building an LDA exchange functional from scratch

To create any functional in Grad DFT, neural or otherwise, we need to define at least the following methods:

1. A coefficient-inputs (features) function, which takes a molecule and returns the array of features fed to the coefficient function, such as the density, its spatial derivatives and/or the kinetic energy density. Here we use only the density.

In [2]:
import jax.numpy as jnp

def coefficient_inputs(molecule: gd.Molecule, *_, **__):
    rho = molecule.density()
    return rho
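Conceptually, the density used as a feature is the standard Kohn-Sham expression $\rho_\sigma(\mathbf{r}) = \sum_{ij} \phi_i(\mathbf{r}) P^\sigma_{ij} \phi_j(\mathbf{r})$ evaluated on the integration grid. The NumPy sketch below illustrates this, assuming `molecule.density()` follows that convention; the atomic-orbital values, grid weights and density matrices are random placeholders, not data from `HF_molecule`. It also checks that integrating each spin density on the grid equals $\mathrm{tr}(P^\sigma S)$, which is the number of electrons of that spin when $P^\sigma$ is a physical density matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n_grid, n_ao = 500, 6
ao = rng.normal(size=(n_grid, n_ao))          # hypothetical AO values phi_i(r_g)
weights = rng.uniform(0.1, 1.0, size=n_grid)  # hypothetical quadrature weights w_g
a = rng.normal(size=(2, n_ao, n_ao))
rdm1 = a @ a.transpose(0, 2, 1) / n_ao        # symmetric, positive semidefinite P^sigma per spin

def spin_density(ao, rdm1):
    """rho_sigma(r_g) = sum_ij phi_i(r_g) P^sigma_ij phi_j(r_g); shape (n_grid, 2)."""
    return np.einsum("gi,sij,gj->gs", ao, rdm1, ao)

rho = spin_density(ao, rdm1)
overlap = ao.T @ (weights[:, None] * ao)               # S_ij = sum_g w_g phi_i phi_j
n_elec_grid = weights @ rho                            # grid integral of each spin density
n_elec_trace = np.einsum("sij,ji->s", rdm1, overlap)   # tr(P^sigma S)
print(rho.shape, np.allclose(n_elec_grid, n_elec_trace))  # (500, 2) True
```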

2. A function which takes a `Molecule` instance and returns the energy densities $\mathbf{e}[\rho](\mathbf{r})$. Here we use only the LDA exchange energy density.

In [3]:
def energy_densities(molecule: gd.Molecule, clip_cte: float = 1e-30):
    r"""Auxiliary function to generate the features of LSDA."""
    # Molecule can evaluate the spin densities on the grid (one column per spin).
    rho = molecule.density()
    # To avoid numerical issues in JAX, we clip very small values.
    rho = jnp.clip(rho, clip_cte)
    # Spin-polarized LDA exchange energy density, summed over both spin channels.
    lda_e = -3 / 2 * (3 / (4 * jnp.pi)) ** (1 / 3) * (rho ** (4 / 3)).sum(axis=1, keepdims=True)
    # For simplicity we do not include the exchange polarization correction;
    # see the function exchange_polarization_correction in functional.py.
    # The output must be an array of shape (n_grid, n_features).
    return lda_e
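The prefactor in `energy_densities` may look unfamiliar. It is the textbook (Dirac) exchange written for spin-polarized densities through the spin-scaling relation $E_x[\rho_\alpha, \rho_\beta] = \tfrac{1}{2}\left(E_x[2\rho_\alpha] + E_x[2\rho_\beta]\right)$. The short NumPy check below, which does not use Grad DFT, confirms numerically that the same expression reproduces the unpolarized formula $-\tfrac{3}{4}(3/\pi)^{1/3}\rho^{4/3}$ for a closed-shell density and obeys spin scaling for an open-shell one. The density values are arbitrary test inputs.

```python
import numpy as np

def lsda_x_energy_density(rho, clip_cte=1e-30):
    """Same formula as `energy_densities` above; rho has shape (n_grid, 2)."""
    rho = np.clip(rho, clip_cte, None)
    return -3 / 2 * (3 / (4 * np.pi)) ** (1 / 3) * (rho ** (4 / 3)).sum(axis=1, keepdims=True)

def dirac_x_unpolarized(n):
    """Textbook spin-unpolarized (Dirac) exchange energy density for total density n."""
    return -3 / 4 * (3 / np.pi) ** (1 / 3) * n ** (4 / 3)

# Closed shell: rho_alpha = rho_beta = n / 2 must reproduce the unpolarized formula
n = np.linspace(0.01, 2.0, 50)
closed = lsda_x_energy_density(np.stack([n / 2, n / 2], axis=1))[:, 0]
print(np.allclose(closed, dirac_x_unpolarized(n)))  # True

# Open shell: e[rho_a, rho_b] = (e_unpol[2 rho_a] + e_unpol[2 rho_b]) / 2
rho_a = np.linspace(0.01, 1.0, 50)
rho_b = np.linspace(0.5, 0.02, 50)
open_shell = lsda_x_energy_density(np.stack([rho_a, rho_b], axis=1))[:, 0]
scaled = 0.5 * (dirac_x_unpolarized(2 * rho_a) + dirac_x_unpolarized(2 * rho_b))
print(np.allclose(open_shell, scaled))  # True
```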

3. A coefficient function $\mathbf{c}_{\boldsymbol{\theta}}[\rho](\mathbf{r})$. In the neural case, this is where a neural network enters the scene (see the next tutorial). For a functional without trainable parameters, such as the LDA, the coefficient function is simply a constant:

In [4]:
def coefficients(instance, rho):
    # The LDA has no trainable parameters: the coefficient is 1 at every grid point.
    return jnp.array([[1.]])

With the above ingredients, we can now build the exchange-correlation energy

$$
E_{xc}[\rho] = \int \mathbf{c}_{\boldsymbol{\theta}}[\rho](\mathbf{r}) \cdot \mathbf{e}[\rho](\mathbf{r}) \, d\mathbf{r}
$$

which, since the coefficient is the constant $1$ in our exchange-only LDA case, reduces to

$$
E_{x, \mathrm{LDA}}[\rho] = \int e_{x, \mathrm{LDA}}[\rho](\mathbf{r}) \, d\mathbf{r}
$$
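On a real grid, this integral is evaluated as a weighted sum over grid points, $E_{xc} \approx \sum_g w_g\, \mathbf{c}_{\boldsymbol{\theta}}(\mathbf{r}_g) \cdot \mathbf{e}(\mathbf{r}_g)$. The NumPy sketch below illustrates the quadrature with a simple radial trapezoid grid (an assumption for illustration, not the PySCF grid stored in `HF_molecule`) and the exact hydrogen-atom density $\rho_\alpha = e^{-2r}/\pi$, $\rho_\beta = 0$, for which the exchange-only LSDA energy is known in closed form: $-\tfrac{81}{128}(3/4)^{1/3}\pi^{-2/3} \approx -0.268$ Ha. The constant coefficient array has the same `(1, 1)` shape as in `coefficients` and simply broadcasts against the energy densities.

```python
import numpy as np

def lsda_x_energy_density(rho, clip_cte=1e-30):
    rho = np.clip(rho, clip_cte, None)
    return -3 / 2 * (3 / (4 * np.pi)) ** (1 / 3) * (rho ** (4 / 3)).sum(axis=1, keepdims=True)

def xc_energy(weights, coefficients, energy_densities):
    """E_xc ~ sum_g w_g sum_k c_k(r_g) e_k(r_g); coefficients broadcast against (n_grid, n_k)."""
    return np.sum(weights * np.sum(coefficients * energy_densities, axis=1))

# Radial grid with trapezoid weights; the 4 pi r^2 factor accounts for the angular integral
r = np.linspace(0.0, 30.0, 30001)
weights = 4 * np.pi * r**2 * (r[1] - r[0])
weights[[0, -1]] *= 0.5

# Exact hydrogen-atom spin densities: all density in the alpha channel
rho_alpha = np.exp(-2 * r) / np.pi
rho = np.stack([rho_alpha, np.zeros_like(rho_alpha)], axis=1)  # (n_grid, 2)

e = lsda_x_energy_density(rho)   # (n_grid, 1)
c = np.array([[1.0]])            # constant coefficient, same shape as in `coefficients`
E_x = xc_energy(weights, c, e)

E_exact = -(81 / 128) * (3 / 4) ** (1 / 3) * np.pi ** (-2 / 3)
print(E_x, E_exact)  # both about -0.268 Ha
```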

In [5]:
LSDA = gd.Functional(coefficients, energy_densities)

## Simple computations with functionals

We can compute the predicted energy using the following code:

In [6]:
from flax.core import freeze

params = freeze({'params': {}})  # The functional is not neural, so we pass an empty frozen dict of parameters
compute_energy = gd.energy_predictor(LSDA)
predicted_energy_0, fock = compute_energy(params=params, molecule=HF_molecule)

You may notice we also predicted a Fock matrix. More on that in the next section!

We may use `compute_energy` (built with `energy_predictor`) to compute the energy of any other molecule too.

Another way of obtaining the same energy is to call the functional's `energy` method directly:

In [7]:
predicted_energy_1 = LSDA.energy(params, HF_molecule)

Under the hood, computing the energy involves the following steps.

First, we compute the energy densities $\mathbf{e}[\rho](\mathbf{r})$:

In [8]:
densities = LSDA.compute_densities(molecule=HF_molecule)

Then we compute the coefficient inputs

In [9]:
cinputs = LSDA.compute_coefficient_inputs(molecule=HF_molecule)

Next, we compute the exchange-correlation energy

In [10]:
predicted_energy_2 = LSDA.xc_energy(params, HF_molecule.grid, cinputs, densities)

and, finally, add the non-exchange-correlation energy component

In [11]:
predicted_energy_2 += HF_molecule.nonXC()

We can check that all methods return the same energy

In [12]:
print("Predicted energies", predicted_energy_0, predicted_energy_1, predicted_energy_2)

Predicted energies -0.45662177 -0.45662177 -0.45662177


## Computing the Fock matrix using automatic differentiation

How did we compute the Fock matrix above? We used the JAX `value_and_grad` function.

Let us start by defining a function that computes the energy from a one-particle reduced density matrix `rdm1`:

In [13]:
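def compute_energy_and_fock(rdm1, molecule):
    # Despite its name, this returns only the total energy; the Fock matrix
    # is obtained below by differentiating it with respect to rdm1.
    molecule = molecule.replace(rdm1=rdm1)
    return LSDA.energy(params, molecule)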

Now comes the magic of JAX. We can compute the gradient of the energy with respect to `rdm1` using `jax.grad` (or obtain the energy and its gradient together with `jax.value_and_grad`), using `argnums` to indicate which argument we are differentiating with respect to:

In [14]:
from jax import grad

new_fock = grad(compute_energy_and_fock, argnums=0)(HF_molecule.rdm1, HF_molecule)

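Next, we symmetrize the gradient over its two orbital indices, separately for each spin channel, to obtain the Fock matrix:

In [15]:
# Average the matrix and its transpose; axis 0 indexes the spin channel
new_fock = 1 / 2 * (new_fock + new_fock.transpose(0, 2, 1))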

In [16]:
print("Is the newly computed fock matrix correct?:", jnp.isclose(fock, new_fock).all())

Is the newly computed fock matrix correct?: True
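The statement that the Fock matrix is the derivative of the energy with respect to the density matrix can also be checked on a toy problem, independently of Grad DFT. The JAX sketch below uses random placeholder atomic-orbital values, grid weights and spin density matrices (all hypothetical), differentiates the LSDA exchange energy with `jax.grad`, applies the same symmetrization as above, and compares the result with the analytic exchange potential matrix $V^\sigma_{ij} = \sum_g w_g \tfrac{4}{3} C_x \rho_\sigma(\mathbf{r}_g)^{1/3} \phi_i(\mathbf{r}_g)\phi_j(\mathbf{r}_g)$, where $C_x = -\tfrac{3}{2}(3/(4\pi))^{1/3}$. Only the exchange part is covered here; in the tutorial, `LSDA.energy` also contains the non-XC terms, whose derivatives supply the remaining contributions to the Fock matrix.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for a tight comparison

C_X = -3 / 2 * (3 / (4 * jnp.pi)) ** (1 / 3)  # spin-polarized LDA exchange prefactor

def toy_setup(n_grid=200, n_ao=4, seed=0):
    """Hypothetical AO values, quadrature weights and spin density matrices."""
    key_phi, key_p, key_w = jax.random.split(jax.random.PRNGKey(seed), 3)
    ao = jax.random.normal(key_phi, (n_grid, n_ao))
    weights = jax.random.uniform(key_w, (n_grid,)) + 0.1
    a = jax.random.normal(key_p, (2, n_ao, n_ao))
    rdm1 = jnp.einsum("sij,skj->sik", a, a) / n_ao  # symmetric PSD matrix per spin
    return ao, weights, rdm1

def lda_x_energy(rdm1, ao, weights):
    # rho_s(r_g) = sum_ij ao_i(r_g) P^s_ij ao_j(r_g)
    rho = jnp.clip(jnp.einsum("gi,sij,gj->gs", ao, rdm1, ao), 1e-30)
    return jnp.sum(weights * C_X * (rho ** (4 / 3)).sum(axis=1))

def analytic_vx(rdm1, ao, weights):
    rho = jnp.clip(jnp.einsum("gi,sij,gj->gs", ao, rdm1, ao), 1e-30)
    v = 4 / 3 * C_X * rho ** (1 / 3)  # dE/drho_s at each grid point
    return jnp.einsum("g,gs,gi,gj->sij", weights, v, ao, ao)

ao, weights, rdm1 = toy_setup()
fock_ad = jax.grad(lda_x_energy, argnums=0)(rdm1, ao, weights)
fock_ad = 1 / 2 * (fock_ad + fock_ad.transpose(0, 2, 1))  # same symmetrization as above
fock_exact = analytic_vx(rdm1, ao, weights)
print(jnp.allclose(fock_ad, fock_exact))  # True
```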
