In [1]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
import numpy as np
import pyscf
from pyscf import dft, scf, gto
import pylibxc
import pyscfad.dft as dftad
from jax import custom_jvp

# Print the installed PySCF version so the outputs below can be tied to it.
print(pyscf.__version__)

2.3.0




# Building a Custom XC Functional

## An Interface to PySCF

There is a brief description of how to customize an XC functional in PySCF [here](https://pyscf.org/user/dft.html). The GitHub repo has further examples [here](https://github.com/pyscf/pyscf/blob/master/examples/dft/24-custom_xc_functional.py) and [here](https://github.com/pyscf/pyscf/blob/master/examples/dft/24-define_xc_functional.py). The former focuses on custom combinations of existing functionals, while the latter focuses on truly custom functionals.

This notebook will aim to walk through generating "custom functionals" for i) LDA and ii) PBE exchange energies that match PySCF predictions, to hopefully make it more clear how one structures the custom function.

## Libxc Nomenclature

It is *very critical* to get the correct derivatives, as libxc's nomenclature is a bit confusing.

For a given density, the components of `rho` are $$\rho[0] = \rho_\uparrow,\ \ \rho[1] = \rho_\downarrow.$$

The total energy is given by $$E = \int \epsilon(\mathbf{r})d^3r = \int \rho(\mathbf{r})e(\mathbf{r})d^3 r,$$

where $\epsilon$ is the *energy density per unit volume* and $e$ is the *energy density per unit particle*. 

**All of Libxc's expected derivative inputs are with respect to $\epsilon$, which will be important to consider!**

## LDA

We first define the "custom" LDA exchange energy function.

In [2]:
def custom_x(rho):
    """Return the LDA exchange energy per particle, e(rho), for the density `rho`.

    Evaluates -(3/4) * (3/pi)^(1/3) * sign(rho) * |rho|^(1/3) element-wise; the
    sign/abs split keeps the cube root real for negative inputs.
    """
    prefactor = -3/4*(3/np.pi)**(1/3)
    cube_root_magnitude = np.abs(rho) ** (1 / 3)
    return prefactor*np.sign(rho) * cube_root_magnitude

**CRITICALLY,** this is $e$ from above, **NOT** $\epsilon$, so the functional derivatives will be of $\rho$*`custom_x(rho)`!

Now we define the molecule we wish to calculate and do a baseline LDA-exchange calculation.

In [3]:
# Water molecule in the cc-pVDZ basis. gto.M() already returns a built Mole
# object, so no separate mol.build() call is needed afterwards.
mol = gto.M(
    atom = '''
    O  0.   0.       0.
    H  0.   -0.757   0.587
    H  0.   0.757    0.587 ''',
    basis = 'ccpvdz')

#pure pyscf calculation for reference to check that our custom function is correct
mf = dft.RKS(mol)
mf.xc = 'lda_x,'  # LDA exchange only: nothing after the comma, i.e. no correlation part
mf.kernel()
print(mf.e_tot)

converged SCF energy = -75.1897796609274
-75.18977966092743


Now we can evaluate the same functional with Pylibxc on the converged density, so that we can check that our custom function produces the same data that Pylibxc does.

In [4]:
#density on the integration grid, built from the converged reference calculation
dm = mf.make_rdm1()
ao = dft.numint.eval_ao(mol, mf.grids.coords, deriv=0)
rho = dft.numint.eval_rho(mol, ao, dm, xctype='LDA')

#libxc's own LDA exchange functional, evaluated on that density
func_lda_x = pylibxc.LibXCFunctional("lda_x", "unpolarized")
plxc_lda_x = func_lda_x.compute({'rho':rho})
print(plxc_lda_x)

{'zk': array([[-1.53163542e-04, -8.31846414e-05, -1.51526401e-04, ...,
        -4.92478424e+00, -4.92478424e+00, -4.92478424e+00]]), 'vrho': array([[-2.04218057e-04, -1.10912855e-04, -2.02035201e-04, ...,
        -6.56637899e+00, -6.56637899e+00, -6.56637899e+00]])}


In [5]:
# e (energy per particle) from the custom function, evaluated on the grid density
exc_x = custom_x(rho)
# vrho = d(rho*e)/drho; because e is proportional to rho^(1/3), this equals (4/3)*e
vxc_x = (4/3)*exc_x

In [6]:
#pylibxc's 'zk' ENTRY IS e, NOT epsilon -- its 'vrho' entry is then (d/drho)(rho*e)
#mean absolute deviation of our e and vrho from the pylibxc values
(
    np.mean(abs(exc_x - plxc_lda_x['zk'])),
    np.mean(abs(vxc_x - plxc_lda_x['vrho'])),
)

(1.4089226552908693e-16, 2.1550787671829716e-16)

We must now define the custom `eval_xc` function we will use to overwrite PySCF's.

In [7]:
def eval_xc_lda(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
    """Custom LDA-exchange `eval_xc` driver with analytic derivatives.

    Intended to be handed to `define_xc_(eval_xc_lda, 'LDA')`.

    Parameters
    ----------
    xc_code, spin, relativity, deriv, omega, verbose
        Present to keep the `eval_xc` signature used in the PySCF example;
        none of them is used in the body.
    rho : array
        Either a 1D array with the density on the grid, or a 2D array whose
        first row is the density (any further rows are ignored).

    Returns
    -------
    exc : array
        Energy per particle e (NOT epsilon = rho*e), from `custom_x`.
    vxc : tuple
        (vrho, vgamma, vlapl, vtau); only vrho = d(rho*e)/drho = (4/3)*e is set.
    fxc : tuple
        Second derivatives; only v2rho2 = d^2(rho*e)/drho^2 = (4/9)*e/rho is set.
    kxc : None
        Third derivatives are not provided.
    """
    #the size of the 'rho' array depends on the xc type (LDA, GGA, etc.);
    #only the density itself (first row of a 2D input) is needed for LDA
    rho0 = rho[0] if len(rho.shape) > 1 else rho

    #calculate the "custom" energy with rho -- THIS IS e
    exc = custom_x(rho0)

    #vrho is known analytically -- vxc = (vrho, vgamma, vlapl, vtau)
    vrho = (4/3)*exc
    vxc = (vrho, None, None, None)

    #v2rho2 is also known analytically. Divide by rho0 (not the full `rho` array) so the
    #result has the same shape as exc; the small offset avoids a division by zero
    #where the density vanishes.
    v2rho2 = (4/9)*exc/(rho0+1e-10)
    v2rhosigma = None
    v2sigma2 = None
    v2lapl2 = None
    vtau2 = None
    v2rholapl = None
    v2rhotau = None
    v2lapltau = None
    v2sigmalapl = None
    v2sigmatau = None
    # 2nd order functional derivative
    fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
    #3rd order
    kxc = None

    return exc, vxc, fxc, kxc


Now we want to overwrite the standard driver function.

In [8]:
# IPython `??` introspection: display the signature and source of define_xc_
mf.define_xc_??

Signature: mf.define_xc_(description, xctype='LDA', hyb=0, rsh=(0, 0, 0))
Docstring: <no docstring>
Source:
def define_xc_(ks, description, xctype='LDA', hyb=0, rsh=(0,0,0)):
    libxc = ks._numint.libxc
    ks._numint = libxc.define_xc_(

In [9]:
#new RKS object whose xc driver is replaced by our analytic eval_xc_lda
mfc = dft.RKS(mol).define_xc_(eval_xc_lda, 'LDA')
mfc.kernel()

converged SCF energy = -75.1897796609274


-75.18977966092737

In [10]:
#absolute difference between the reference and the custom-functional total energies
print(abs(mfc.e_tot - mf.e_tot))

5.684341886080802e-14


So we've successfully re-created the LDA exchange energy via a custom functional. Now let's do the same, but using JAX and its automatic differentiation capabilities.

In [11]:
def custom_x_j(rho):
    """JAX version of the LDA exchange energy per particle, e(rho)."""
    prefactor = -3/4*(3/np.pi)**(1/3)
    cube_root_magnitude = jnp.abs(rho) ** (1 / 3)
    return prefactor*jnp.sign(rho) * cube_root_magnitude
def custom_x_rho_j(rho):
    """JAX version of the LDA exchange energy per volume, epsilon(rho) = rho * e(rho)."""
    prefactor = -3/4*(3/np.pi)**(1/3)
    energy_per_particle = prefactor*jnp.sign(rho) * (jnp.abs(rho)) ** (1 / 3)
    return (rho)*energy_per_particle

def eval_xc_lda_j(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
    """Custom LDA-exchange `eval_xc` driver whose derivatives come from JAX autodiff.

    Intended to be handed to `define_xc_(eval_xc_lda_j, 'LDA')`.

    Parameters
    ----------
    xc_code, spin, relativity, deriv, omega, verbose
        Present to keep the `eval_xc` signature used in the PySCF example;
        none of them is used in the body.
    rho : array
        Either a 1D array with the density on the grid, or a 2D array whose
        first row is the density (any further rows are ignored).

    Returns
    -------
    exc : numpy array
        Energy per particle e, from `custom_x_j`.
    vxc : tuple
        (vrho, vgamma, vlapl, vtau); only vrho, the gradient of
        `custom_x_rho_j` (epsilon = rho*e) at each grid point, is set.
    fxc : tuple
        Second derivatives; only v2rho2, the Hessian of `custom_x_rho_j`
        at each grid point, is set.
    kxc : None
        Third derivatives are not provided.
    """
    #the size of the 'rho' array depends on the xc type (LDA, GGA, etc.);
    #only the density itself (first row of a 2D input) is needed for LDA
    rho0 = rho[0] if len(rho.shape) > 1 else rho

    #calculate the "custom" energy with rho -- THIS IS e
    #cast back to np.array since that's what pyscf works with
    exc = np.array(jax.vmap(custom_x_j)(rho0))

    #vrho is obtained automatically by differentiating epsilon point-wise
    # -- vxc = (vrho, vgamma, vlapl, vtau)
    vrho_f = eqx.filter_grad(custom_x_rho_j)
    vrho = np.array(jax.vmap(vrho_f)(rho0))
    vxc = (vrho, None, None, None)

    #v2rho2 is obtained automatically as well. Map over rho0 (not the full `rho` array)
    #so that each grid point contributes one scalar second derivative.
    v2rho2_f = eqx.filter_hessian(custom_x_rho_j)
    v2rho2 = np.array(jax.vmap(v2rho2_f)(rho0))
    v2rhosigma = None
    v2sigma2 = None
    v2lapl2 = None
    vtau2 = None
    v2rholapl = None
    v2rhotau = None
    v2lapltau = None
    v2sigmalapl = None
    v2sigmatau = None
    # 2nd order functional derivative
    fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
    #3rd order
    kxc = None

    return exc, vxc, fxc, kxc


In [14]:
#new RKS object whose xc driver is replaced by the jax-differentiated eval_xc_lda_j
mfcj = dft.RKS(mol).define_xc_(eval_xc_lda_j, 'LDA')
mfcj.kernel()

converged SCF energy = -75.1897796609274


-75.18977966092739

In [17]:
#deviation of the jax-based total energy from (reference pyscf, analytic custom functional)
abs(mfcj.e_tot - mf.e_tot), abs(mfcj.e_tot - mfc.e_tot)

(4.263256414560601e-14, 7.105427357601002e-14)

So we've re-created the LDA exchange functional via i) a custom, analytic functional which we can manually take derivatives of, and ii) a custom, auto-differentiable functional where we only specify $e$ and $\epsilon$ and rely on jax to differentiate for us.

## GGA

[TO DO]