# Self-Consistent Field Calculations with Neural Functionals

In this tutorial, we cover how to perform self-consistent field (SCF) calculations with a `NeuralFunctional`. Many of the SCF implementations in Grad DFT are fully differentiable, so they are useful when we wish to train functionals accurately.

Others are not fully differentiable but are still useful when one needs to converge a calculation in Grad DFT and the SCF loops fail. These methods bypass the SCF loop by directly minimizing the energy with respect to the Kohn-Sham orbital coefficients.

To begin, we will run most of the code cells from the previous tutorial `~/examples/basic_notebooks/example_neural_functional_04.ipynb` such that we have a basic `NeuralFunctional` instance and a dummy "truth energy" from an LDA calculation in PySCF.


In [1]:
from pyscf import gto, dft
import grad_dft as gd
 
# Define the geometry of the molecule
mol = gto.M(atom=[["H", (0, 0, 0)], ["H", (0, 0, 1)]], basis="def2-tzvp", charge=0, spin=0)
mf = dft.UKS(mol)
ground_truth_energy = mf.kernel()

# Then we can use the following function to generate the molecule object
HH_molecule = gd.molecule_from_pyscf(mf)

import jax.numpy as jnp

def coefficient_inputs(molecule: gd.Molecule, *_, **__):
    rho = molecule.density()
    kinetic = molecule.kinetic_density()
    return jnp.concatenate((rho, kinetic), axis = 1)

def energy_densities(molecule: gd.Molecule, clip_cte: float = 1e-30, *_, **__):
    r"""Auxiliary function to generate the features of LSDA."""
    # Molecule can compute the density matrix.
    rho = jnp.clip(molecule.density(), a_min=clip_cte)
    # Now we can implement the LDA energy density equation in the paper.
    lda_e = -3/2 * (3/(4*jnp.pi)) ** (1/3) * (rho**(4/3)).sum(axis = 1, keepdims = True)
    # For simplicity we do not include the exchange polarization correction
    # check function exchange_polarization_correction in functional.py
    # The output of features must be an Array of dimension n_grid x n_features.
    return lda_e

from flax import linen as nn
from jax.nn import sigmoid

out_features = 1
def coefficients(instance, rhoinputs):
    r"""
    Instance is an instance of the class Functional or NeuralFunctional.
    rhoinputs is the input to the neural network, in the form of an array.
    localfeatures represents the potentials e_\theta(r).

    The output of this function is the energy density of the system.
    """

    x = nn.Dense(features=out_features)(rhoinputs)
    x = nn.LayerNorm()(x)
    return sigmoid(x)

nf = gd.NeuralFunctional(coefficients, energy_densities, coefficient_inputs)

from jax.random import PRNGKey

key = PRNGKey(42)
cinputs = coefficient_inputs(HH_molecule)
params = nf.init(key, cinputs)



converged SCF energy = -1.11599939445016  <S^2> = 4.4408921e-16  2S+1 = 1


Recall that the total energy can be calculated in a non-self consistent way like:

In [2]:
E_non_scf = nf.energy(params, HH_molecule)
print("Neural functional non-SCF total energy with random parameters is", E_non_scf)

Neural functional non-SCF total energy with random parameters is -0.7769992


## Linear mixing

The simplest (but robust) way of performing an SCF calculation in DFT is [linear mixing of the density](http://www.numis.northwestern.edu/Presentations/DFT_Mixing_For_Dummies.pdf): at each cycle, the new density is a weighted average of the previous density and the density obtained from the current Kohn-Sham solution. This is implemented in `simple_scf_loop` in a robust (but non-JIT-compilable) form and in `diff_simple_scf_loop` in a JIT-compilable form.
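As a minimal sketch of the idea (not Grad DFT's implementation), the snippet below applies linear mixing to a toy scalar problem. The function `toy_scf_map` is a hypothetical stand-in for "density in, new density out"; it is chosen so that the plain iteration diverges, while a mixing factor of 0.3 converges.

```python
def toy_scf_map(x):
    # Hypothetical stand-in for "density in -> new density out".
    # Its slope is -1.5, so the plain iteration x <- g(x) overshoots and diverges.
    return 2.0 - 1.5 * x


def linear_mixing(g, x0, mixing_factor, cycles):
    """Iterate x <- (1 - a) * x + a * g(x) for a fixed number of cycles."""
    x = x0
    for _ in range(cycles):
        x = (1.0 - mixing_factor) * x + mixing_factor * g(x)
    return x


# The fixed point of the toy map solves x = 2 - 1.5 * x, i.e. x = 0.8.
x_mixed = linear_mixing(toy_scf_map, x0=0.0, mixing_factor=0.3, cycles=25)
x_plain = linear_mixing(toy_scf_map, x0=0.0, mixing_factor=1.0, cycles=25)

print("mixed iteration:", x_mixed)  # approaches 0.8
print("plain iteration:", x_plain)  # moves far away from 0.8
```

A smaller mixing factor damps the update more strongly: it is more robust against overshooting but needs more cycles to converge.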

Let's construct both the non-JIT and the JIT loop.

In [3]:
linear_mix = gd.simple_scf_loop(nf, mixing_factor=0.3, cycles=25)
linear_mix_jit = gd.diff_simple_scf_loop(nf, mixing_factor=0.3, cycles=25)

and compute the SCF total energy for both

In [4]:
mol_linear_mix = linear_mix(params, HH_molecule)
mol_linear_mix_jit = linear_mix_jit(params, HH_molecule)

print("Linear mixing (non-JIT) total energy is", mol_linear_mix.energy)
print("Linear mixing (JIT) total energy is", mol_linear_mix_jit.energy)



Linear mixing (non-JIT) total energy is -0.7908802
Linear mixing (JIT) total energy is -0.7908811


Try benchmarking the speed of both yourself. The JIT-compiled version should be faster.
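One possible way to run such a benchmark is sketched below. The helper `time_call` is hypothetical (it is not part of Grad DFT). It performs a warm-up call first, so that one-off costs such as JIT compilation are not measured, and accepts an optional `sync` callable for waiting on asynchronous computations.

```python
import time


def time_call(fn, *args, repeats=5, sync=None):
    """Best wall-clock time in seconds of fn(*args) over `repeats` runs."""
    def run():
        out = fn(*args)
        if sync is not None:
            sync(out)  # e.g. wait for asynchronous JAX computations to finish
        return out

    run()  # warm-up call, so one-off costs such as JIT compilation are excluded
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in workload so the snippet runs on its own.
t_toy = time_call(sum, range(100_000))
print(f"toy workload: {t_toy:.2e} s")
```

In this notebook one could then compare `time_call(linear_mix, params, HH_molecule)` with `time_call(linear_mix_jit, params, HH_molecule, sync=jax.block_until_ready)`, assuming a JAX version that provides `jax.block_until_ready`.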

## Direct Inversion of the Iterative Subspace

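[Direct Inversion of the Iterative Subspace](https://en.wikipedia.org/wiki/DIIS) (DIIS) is a more complex method used in many codes to quickly converge the SCF. Rather than mixing only the latest two densities with a fixed weight, DIIS combines several previous iterates with weights chosen to minimize the norm of the combined residual. As with linear mixing, both a non-JIT and a JIT version are implemented.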
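To illustrate the extrapolation step, here is a minimal sketch of DIIS restricted to a subspace of two iterates, for which the optimal weights have a closed form. It is not Grad DFT's implementation: `toy_map` is a hypothetical linear stand-in for the SCF map, and `diis_two_vectors` is a name introduced only for this example.

```python
import math


def toy_map(x):
    # Hypothetical linear stand-in for one SCF step, x -> A x + b,
    # with a contractive symmetric matrix A.
    return [0.5 * x[0] + 0.1 * x[1] + 1.0,
            0.1 * x[0] + 0.3 * x[1] + 2.0]


def diis_two_vectors(g, x0, tol=1e-10, max_cycles=100):
    """Fixed-point iteration accelerated with DIIS over the last two iterates."""
    x = list(x0)
    prev_gx, prev_r = None, None
    for cycle in range(1, max_cycles + 1):
        gx = g(x)
        r = [gi - xi for gi, xi in zip(gx, x)]  # residual of the current iterate
        if math.sqrt(sum(ri * ri for ri in r)) < tol:
            return x, cycle
        if prev_r is None:
            x_new = gx  # only one iterate so far: plain fixed-point step
        else:
            # Choose c to minimize |c * prev_r + (1 - c) * r|, i.e. |r - c * dr|.
            dr = [a - b for a, b in zip(r, prev_r)]
            denom = sum(d * d for d in dr)
            c = sum(ri * di for ri, di in zip(r, dr)) / denom if denom > 0.0 else 0.0
            x_new = [c * pg + (1.0 - c) * cg for pg, cg in zip(prev_gx, gx)]
        prev_gx, prev_r = gx, r
        x = x_new
    return x, max_cycles


x_diis, n_cycles = diis_two_vectors(toy_map, [0.0, 0.0])
print("DIIS fixed point:", x_diis, "after", n_cycles, "cycles")
```

Production DIIS implementations typically keep a longer history and solve a small linear system for the weights; the two-vector case is used here only to keep the sketch short.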

The default DIIS loops are created like this:

In [5]:
diis = gd.scf_loop(nf, cycles=5)
diis_jit = gd.diff_scf_loop(nf, cycles=5)

and are evaluated in the same way as before

In [6]:
mol_diis = diis(params, HH_molecule)
mol_diis_jit = diis_jit(params, HH_molecule)

print("DIIS (non-JIT) total energy is", mol_diis.energy)
print("DIIS (JIT) total energy is", mol_diis_jit.energy)

SCF not converged.
SCF energy = -0.918728046617933




DIIS (non-JIT) total energy is -0.79088175
DIIS (JIT) total energy is -0.7908813


Recall that all of the SCF iterators so far, whether DIIS or linear mixing, are fully differentiable. This means we have access to the gradients of any SCF-predicted property (such as the total energy or the density) with respect to the parameters of the neural functional, so neural functionals can be trained self-consistently. We will encounter this in the `intermediate_notebooks`.
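The following toy example sketches what "differentiable SCF" means in practice, under stated assumptions: `scf_like_energy` is a hypothetical function that unrolls a linear-mixing loop on a scalar map depending on a single parameter `theta`, and then evaluates a made-up "energy". It does not use Grad DFT; it only shows that JAX can differentiate through the whole loop, which is checked against a finite difference.

```python
import jax
import jax.numpy as jnp


def scf_like_energy(theta, cycles=25, mixing_factor=0.3):
    """Unrolled linear-mixing loop on a toy map, followed by a toy 'energy'."""
    x = jnp.asarray(1.0)
    for _ in range(cycles):
        g_x = jnp.cos(theta * x)  # hypothetical parametrised "SCF map"
        x = (1.0 - mixing_factor) * x + mixing_factor * g_x
    return x ** 2  # hypothetical "energy" of the converged state


theta = 0.5
grad_ad = jax.jit(jax.grad(scf_like_energy))(theta)

# Central finite difference as an independent check of the gradient.
eps = 1e-2
grad_fd = (scf_like_energy(theta + eps) - scf_like_energy(theta - eps)) / (2 * eps)

print("autodiff gradient:         ", float(grad_ad))
print("finite-difference gradient:", float(grad_fd))
```

In a training loop, a gradient of this kind would be taken with respect to the neural functional's parameters and passed to an optimizer.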

## Optimizing the Kohn-Sham orbitals

We have a further method for computing the total energy in Grad DFT: a direct minimization of the total energy with respect to the Kohn-Sham orbital coefficients. This process is not presently differentiable but is useful in cases where the total energy does not converge in the SCF loops above.

To use this method, we require an optimizer from `optax`

In [7]:
from optax import adam

learning_rate = 1e-5
momentum = 0.9
tx = adam(learning_rate=learning_rate, b1=momentum)

We can now make callable non-jittable and jittable versions of the orbital optimizer:

In [8]:
orb_opt = gd.mol_orb_optimizer(nf, tx, cycles=20)
orb_opt_jit = gd.jitted_mol_orb_optimizer(nf, tx, cycles=20)

and calculate the total energy

In [9]:
mol_orb_opt = orb_opt(params, HH_molecule)
mol_orb_opt_jit = orb_opt_jit(params, HH_molecule)

print("Orbital optimizer (non-JIT) total energy is", mol_orb_opt.energy)
print("Orbital optimizer (JIT) total energy is", mol_orb_opt_jit.energy)

Orbital optimizer (non-JIT) total energy is -0.707909
Orbital optimizer (JIT) total energy is -0.70790976
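To build intuition for direct minimization, the sketch below minimizes the energy of a single normalized orbital in a two-function basis by gradient descent on its coefficients. Everything here is an illustrative assumption: the 2x2 matrix `H` is made up, plain gradient descent stands in for the Adam optimizer used above, and none of the code is Grad DFT's implementation.

```python
import math

# Hypothetical 2x2 symmetric "Hamiltonian" in an orthonormal basis.
H = [[-1.0, 0.2],
     [0.2, 0.5]]


def energy_and_gradient(c):
    """Rayleigh quotient E(c) = c^T H c / c^T c and its gradient with respect to c."""
    Hc = [H[0][0] * c[0] + H[0][1] * c[1],
          H[1][0] * c[0] + H[1][1] * c[1]]
    norm2 = c[0] ** 2 + c[1] ** 2
    E = (c[0] * Hc[0] + c[1] * Hc[1]) / norm2
    grad = [2.0 * (Hc[i] - E * c[i]) / norm2 for i in range(2)]
    return E, grad


def minimize_energy(c0, learning_rate=0.1, cycles=200):
    """Gradient descent on the orbital coefficients, renormalizing after each step."""
    c = list(c0)
    for _ in range(cycles):
        _, grad = energy_and_gradient(c)
        c = [ci - learning_rate * gi for ci, gi in zip(c, grad)]
        norm = math.sqrt(c[0] ** 2 + c[1] ** 2)
        c = [ci / norm for ci in c]  # keep the orbital normalized
    return energy_and_gradient(c)[0], c


E_min, c_min = minimize_energy([1.0, 1.0])
E_exact = -0.25 - math.sqrt(0.75 ** 2 + 0.2 ** 2)  # lowest eigenvalue of H

print("minimized energy: ", E_min)
print("lowest eigenvalue:", E_exact)
```

No diagonalization or density mixing is involved: the energy is lowered directly, which is why this kind of approach can still make progress when an SCF loop oscillates.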
