In [1]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
import numpy as np
import pyscf
from pyscf import dft, scf, gto
import pylibxc
import pyscfad.dft as dftad
from jax import custom_jvp

print(pyscf.__version__)

2.3.0




# Building a Custom XC Functional

## An Interface to PySCF

There is a brief description of how to customize an XC functional in PySCF [here](https://pyscf.org/user/dft.html). The github repo has further examples [here](https://github.com/pyscf/pyscf/blob/master/examples/dft/24-custom_xc_functional.py) and [here](https://github.com/pyscf/pyscf/blob/master/examples/dft/24-define_xc_functional.py). The prior focuses on custom combinations of existing functionals, while the latter focuses on truly custom functionals.

This notebook will aim to walk through generating "custom functionals" for i) LDA and ii) PBE exchange energies that match PySCF predictions, to hopefully make it more clear how one structures the custom function.

## Libxc Nomenclature

It is *very critical* to get the correct derivatives, as libxc's nomenclature is a bit confusing.

For a given density, the components of `rho` are $$\rho[0] = \rho_\uparrow,\ \ \rho[1] = \rho_\downarrow.$$

The total energy is given by $$E = \int \epsilon(\mathbf{r})d^3r = \int \rho(\mathbf{r})e(\mathbf{r})d^3 r,$$

where $\epsilon$ is the *energy density per unit volume* and $e$ is the *energy density per unit particle*. 

**All of Libxc's expected derivative inputs are with respect to $\epsilon$, which will be important to consider!**

## LDA

First, we define the molecule we wish to calculate and do a baseline LDA-exchange calculation.

In [3]:
mol = gto.M(
    atom = '''
    O  0.   0.       0.
    H  0.   -0.757   0.587
    H  0.   0.757    0.587 ''',
    basis = 'ccpvdz')
mol.build()
mf = dft.RKS(mol)
#pure pyscf calculation for reference to check that our custom function is correct
mf.xc = 'lda_x,'
mf.kernel()
print(mf.e_tot)

converged SCF energy = -75.1897796609274
-75.18977966092743


We now define the "custom" LDA exchange energy function.

In [21]:
def custom_lda_x(rho):
    return -3/4*(3/np.pi)**(1/3)*np.sign(rho) * (np.abs(rho)) ** (1 / 3)

**CRITICALLY,** this is $e$ from above, **NOT** $\epsilon$, so the functional derivatives will be of $\rho$*`lda_x(rho)`!

Now, we can generate Pylibxc inputs to see that we will be generating the same data as Pylibxc expects.

In [4]:
#generate functional
func_lda_x = pylibxc.LibXCFunctional("lda_x", "unpolarized")
#grid data
ao = dft.numint.eval_ao(mol, mf.grids.coords, deriv=0)
dm = mf.make_rdm1()
rho = dft.numint.eval_rho(mol, ao, dm, xctype='LDA')
plxc_lda_x = func_lda_x.compute({'rho':rho})
print(plxc_lda_x)

{'zk': array([[-1.53163542e-04, -8.31846414e-05, -1.51526401e-04, ...,
        -4.92478424e+00, -4.92478424e+00, -4.92478424e+00]]), 'vrho': array([[-2.04218057e-04, -1.10912855e-04, -2.02035201e-04, ...,
        -6.56637899e+00, -6.56637899e+00, -6.56637899e+00]])}


In [5]:
exc_x = custom_lda_x(rho)
vxc_x = (4/3)*exc_x

In [6]:
#THE 'zk' KEY ENTRY IS FOR e, NOT epsilon -- 'vrho' key is then (d/drho)(rho*e)
np.mean(abs(exc_x-plxc_lda_x['zk'])), np.mean(abs(vxc_x-plxc_lda_x['vrho']))

(1.4089226552908693e-16, 2.1550787671829716e-16)

We must now define the custom `eval_xc` function we will use to overwrite PySCF's.

In [7]:
def eval_xc_lda(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
    #we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
    #pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
    #so since LDA calculation, check for size first.
    if len(rho.shape) > 1:
        rho0, dx, dy, dz = rho[:4]
    else:
        rho0 = rho

    #calculate the "custom" energy with rho -- THIS IS e
    exc = custom_lda_x(rho0)

    #we don't care about derivatives yet
    # but we can calculate vrho analytically -- vxc = (vrho, vgamma, vlapl, vtau)
    vrho = (4/3)*exc
    vxc = (vrho, None, None, None)
    #or higher order terms, but we can calculate v2rho2 analytically
    v2rho2 = (4/9)*exc/(rho+1e-10)
    v2rhosigma = None
    v2sigma2 = None
    v2lapl2 = None
    vtau2 = None
    v2rholapl = None
    v2rhotau = None
    v2lapltau = None
    v2sigmalapl = None
    v2sigmatau = None
    # 2nd order functional derivative
    fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
    #3rd order
    kxc = None
    
    return exc, vxc, fxc, kxc


Now we want to overwrite the standard driver function.

In [8]:
mf.define_xc_??

[0;31mSignature:[0m [0mmf[0m[0;34m.[0m[0mdefine_xc_[0m[0;34m([0m[0mdescription[0m[0;34m,[0m [0mxctype[0m[0;34m=[0m[0;34m'LDA'[0m[0;34m,[0m [0mhyb[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m [0mrsh[0m[0;34m=[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0;36m0[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
[0;32mdef[0m [0mdefine_xc_[0m[0;34m([0m[0mks[0m[0;34m,[0m [0mdescription[0m[0;34m,[0m [0mxctype[0m[0;34m=[0m[0;34m'LDA'[0m[0;34m,[0m [0mhyb[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m [0mrsh[0m[0;34m=[0m[0;34m([0m[0;36m0[0m[0;34m,[0m[0;36m0[0m[0;34m,[0m[0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mlibxc[0m [0;34m=[0m [0mks[0m[0;34m.[0m[0m_numint[0m[0;34m.[0m[0mlibxc[0m[0;34m[0m
[0;34m[0m    [0mks[0m[0;34m.[0m[0m_numint[0m [0;34m=[0m [0mlibxc[0m[0;34m.[0m[0mdefine_xc_[0m[0;34m([0m

In [9]:
#overwrite the kernel's driver
mfc = dft.RKS(mol)
mfc = mfc.define_xc_(eval_xc_lda, 'LDA')
mfc.kernel()

converged SCF energy = -75.1897796609274


-75.18977966092737

In [10]:
print(abs(mf.e_tot - mfc.e_tot))

5.684341886080802e-14


So we've successfully re-created the LDA exchange energy via a custom functional. Now let's do the same, but using jax and it's autodifferentiation capabilities.

In [11]:
def custom_x_j(rho):
    #this is e
    return -3/4*(3/np.pi)**(1/3)*jnp.sign(rho) * (jnp.abs(rho)) ** (1 / 3)
def custom_x_rho_j(rho):
    #this is epsilon
    return (rho)*(-3/4*(3/np.pi)**(1/3)*jnp.sign(rho) * (jnp.abs(rho)) ** (1 / 3))

def eval_xc_lda_j(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
    #we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
    #pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
    #so since LDA calculation, check for size first.
    if len(rho.shape) > 1:
        rho0, dx, dy, dz = rho[:4]
    else:
        rho0 = rho

    #calculate the "custom" energy with rho -- THIS IS e
    #cast back to np.array since that's what pyscf works with
    exc = np.array(jax.vmap(custom_x_j)(rho0))
    
    #we don't care about derivatives yet
    # but we can calculate vrho automatically -- vxc = (vrho, vgamma, vlapl, vtau)
    vrho_f = eqx.filter_grad(custom_x_rho_j)
    vrho = np.array(jax.vmap(vrho_f)(rho0))
    vxc = (vrho, None, None, None)
    #or higher order terms, but we can calculate v2rho2 automatically
    v2rho2_f = eqx.filter_hessian(custom_x_rho_j)
    v2rho2 = np.array(jax.vmap(v2rho2_f)(rho))
    v2rhosigma = None
    v2sigma2 = None
    v2lapl2 = None
    vtau2 = None
    v2rholapl = None
    v2rhotau = None
    v2lapltau = None
    v2sigmalapl = None
    v2sigmatau = None
    # 2nd order functional derivative
    fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
    #3rd order
    kxc = None
    
    return exc, vxc, fxc, kxc


In [14]:
#overwrite the kernel's driver
mfcj = dft.RKS(mol)
mfcj = mfcj.define_xc_(eval_xc_lda_j, 'LDA')
mfcj.kernel()

converged SCF energy = -75.1897796609274


-75.18977966092739

In [17]:
abs(mf.e_tot - mfcj.e_tot), abs(mfc.e_tot - mfcj.e_tot)

(4.263256414560601e-14, 7.105427357601002e-14)

So we've re-created the LDA exchange functional via i) a custom, analytic functional which we can manually take derivatives of, and ii) a custom, auto-differentiable functional where we only specify $e$ and $\epsilon$ and rely on jax to differentiate for us.

## GGA

Now we will try to implement a "custom" functional to reproduce PBE's energies.

To start, we note that the [PBE XC energy](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.77.3865) is broken into two parts:

$$E_X^\mathrm{PBE} = \int d^3r [\rho\cdot e_X^\mathrm{HEG}]\cdot F_X(s),$$ where $$s = \frac{|\nabla\rho|}{2k_F\rho}$$ and $$F_X(s) = 1+\kappa - \frac{\kappa}{1+\mu s^2/\kappa},$$ for $\kappa=0.804$ and $\mu \simeq 0.21951$.

The correlation is given by $$E_C^\mathrm{PBE} = \int d^3r [\rho\cdot e_C^\mathrm{HEG}(r_s,\zeta)]\cdot\bigg[1+\frac{H(r_s,\zeta,t)}{e_C^\mathrm{HEG}(r_s,\zeta)}\bigg]$$ for $$r_s = (\frac{3}{4\pi\rho})^{1/3},\ \zeta = \frac{\rho_\uparrow - \rho_\downarrow}{\rho},\ t = \frac{|\nabla\rho|}{2\phi(\zeta)k_s\rho},$$ and $$\phi(\zeta) = \frac{1}{2}\cdot[(1+\zeta)^{2/3} + (1-\zeta)^{2/3}],\ k_s = \sqrt{ \frac{4k_F}{\pi a_0}}.$$

Here, $$H(r_s,\zeta,t) = \frac{e^2}{a_0}\cdot \gamma \cdot \phi^3 \cdot \ln\bigg[1_+\frac{\beta}{\gamma}t^2\bigg( \frac{1-At^2}{1+At^2+A^2t^4} \bigg) \bigg],$$ with $$A = \frac{\beta}{\gamma} \cdot \bigg[ \exp\bigg(\frac{-e_C^\mathrm{HEG}a_0}{\gamma\phi^3 e^2}\bigg)-1 \bigg]^{-1}.$$ Here, $\beta\simeq 0.066725$ and $\gamma\simeq 0.031091$.

As a starting point, since it's easier I will do the PBE exchange energy. As before, we start with a baseline calculation for a reference energy.

In [96]:
mol = gto.M(
    atom = '''
    O  0.   0.       0.
    H  0.   -0.757   0.587
    H  0.   0.757    0.587 ''',
    basis = 'ccpvdz')
mol.build()
mfp = dft.RKS(mol)
#pure pyscf calculation for reference to check that our custom function is correct
mfp.xc = 'gga_x_pbe,'
mfp.kernel()
print(mfp.e_tot)

converged SCF energy = -76.0026653855811
-76.00266538558111


Since we are now in GGA territory, libxc will be expecting derivatives w.r.t. what they call $\sigma$ and what PySCF calls $\gamma$: $$\sigma[0] = \nabla\rho_\uparrow\cdot \nabla\rho_\uparrow, \sigma[1] = \nabla\rho_\uparrow\cdot \nabla\rho_\downarrow, \sigma[2] = \nabla\rho_\downarrow\cdot \nabla\rho_\downarrow,$$ thus for a spin-unpolarized calculation, $\sigma = |\nabla\rho|^2$ -- this is important to keep in mind when doing functional derivatives.

When coding up these functions, we want JAX to be able to give us the correct derivatives -- therefore, we should code using the inputs libxc will expect us to take derivatives of and translate to the appropriate values within the function (i.e., code up $F_X^\mathrm{PBE}$ to expect the inputs of $(\rho, \sigma)$ and translate $\sigma$ to $s$ internally.

In [198]:
def custom_pbe_Fx(rho, sigma):
    #Equation 14 from PBE paper -- DOI: 10.1103/PhysRevLett.77.3865
    #THIS FLOOR SETTING MAKES VSIGMA ERROR MUCH HIGHER
    # rho = jnp.maximum(1e-12, rho) #Prevents division by 0
    k_F = (3 * jnp.pi**2 * rho)**(1/3)
    s = jnp.sqrt(sigma) / (2 * k_F * rho)
    kappa, mu = 0.804, 0.21951

    Fx = 1 + kappa - kappa / (1 + mu * s**2 / kappa) #exchange enhancement factor

    return Fx

def custom_pbe_e(rho, sigma):
    Fx = custom_pbe_Fx(rho, sigma)

    exc = custom_x_j(rho)*Fx

    return exc

def custom_pbe_epsilon(rho, sigma):

    return rho*custom_pbe_e(rho, sigma)

Per the [equinox documentation](https://docs.kidger.site/equinox/api/transformations/#automatic-differentiation), if we want derivatives w.r.t. both of the inputs (rho and sigma), we need to wrap these functions in one that unpacks a tuple passing them both.

In [199]:
def derivable_custom_pbe_e(rhosigma):
    rho, sigma = rhosigma
    return custom_pbe_e(rho, sigma)

def derivable_custom_pbe_epsilon(rhosigma):
    rho, sigma = rhosigma
    return custom_pbe_epsilon(rho, sigma)

With these defined, let us make sure we are getting the values we expect in comparison to libxc.

In [200]:
#generate functional
func_gga_x = pylibxc.LibXCFunctional("gga_x_pbe", "unpolarized")
#grid data
ao = dft.numint.eval_ao(mol, mf.grids.coords, deriv=1)
dm = mf.make_rdm1()
rho = dft.numint.eval_rho(mol, ao, dm, xctype='GGA')
rho0, dx, dy, dz = rho
sigma = dx**2+dy**2+dz**2
plxc_gga_x = func_gga_x.compute({'rho':rho0, 'sigma':sigma})
print(plxc_gga_x)

{'zk': array([[-2.76306989e-04, -1.50065087e-04, -2.73353587e-04, ...,
        -4.92497656e+00, -4.92497656e+00, -4.92497656e+00]]), 'vrho': array([[-3.68409208e-04, -2.00086767e-04, -3.64471341e-04, ...,
        -6.56612259e+00, -6.56612259e+00, -6.56612259e+00]]), 'vsigma': array([[-2.61510629e-01, -2.17081991e-01, -2.61465869e-01, ...,
        -2.14186003e-06, -2.14186003e-06, -2.14186003e-06]])}


In [213]:
test_vrho_f = eqx.filter_grad(derivable_custom_pbe_epsilon)
vrhosig = jax.vmap(test_vrho_f)( (rho0, sigma) )
vrho, vsig = vrhosig
#print out error stats
titlestr = 'Error statistics -- Libxc PBE e/vrho/vsigma vs. JAX auto-derived'
# titlestr += '\nUSING MINIMUM RHO 1e-12'
titlestr += '\nNO MINIMUM RHO'
print(titlestr)
print(len(titlestr)*'-')

print(len(titlestr)//2*'-'+'e'+len(titlestr)//2*'-')
print('Maximum Absolute Error: ',np.max(abs(plxc_gga_x['zk'] - derivable_custom_pbe_e( (rho0, sigma)))))
print('Mean Absolute Error: ',np.mean(abs(plxc_gga_x['zk'] - derivable_custom_pbe_e( (rho0, sigma)))))

print(len(titlestr)//2*'-'+'vrho'+len(titlestr)//2*'-')
print('Maximum Absolute Error: ', np.max(abs(plxc_gga_x['vrho'] - vrho )))
print('Mean Absolute Error: ', np.mean(abs(plxc_gga_x['vrho'] - vrho )))

print(len(titlestr)//2*'-'+'vsigma'+len(titlestr)//2*'-')
print('Maximum Absolute Error: ', np.max(abs(plxc_gga_x['vsigma'] - vsig  )))
print('Mean Absolute Error: ', np.mean(abs(plxc_gga_x['vsigma'] - vsig  )))

Error statistics -- Libxc PBE e/vrho/vsigma vs. JAX auto-derived
NO MINIMUM RHO
-------------------------------------------------------------------------------
---------------------------------------e---------------------------------------
Maximum Absolute Error:  5.366654586058672e-06
Mean Absolute Error:  5.862072795727476e-07
---------------------------------------vrho---------------------------------------
Maximum Absolute Error:  4.21578888687435e-06
Mean Absolute Error:  4.4862803770546e-07
---------------------------------------vsigma---------------------------------------
Maximum Absolute Error:  0.00010551126643498776
Mean Absolute Error:  8.737072305806912e-06


So it appears that our functions give us good $e$ and $v_\rho$ values, but $v_\sigma$ seems to have the most error. Let's examine the $\rho$ and $\sigma$ values where these errors are large to see any trends.

In [214]:
vsigerr = abs(plxc_gga_x['vsigma'] - vsig)
large_error_inds = np.where(vsigerr > 0.0001)[1]
print(large_error_inds)

[12902]


In [215]:
rhoind = rho0[large_error_inds]
sigind = sigma[large_error_inds]
print(rhoind)
print(sigind)

[9.37650318e-11]
[3.7172212e-20]


**CONCLUSION**: Prior to some edits, the maximum absolute error for $v_\sigma$ was $\sim 24$, which was huge. Commenting out the line in `custom_pbe_Fx` that set the minimum `rho` value to be `1e-12` brought this maximum error down to 1e-4.

Now we build the custom functional driver.

In [210]:
def eval_xc_gga_j(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
    #we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
    #pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
    #so since LDA calculation, check for size first.
    rho0, dx, dy, dz = rho[:4]
    rho0 = jnp.array(rho0)
    sigma = jnp.array(dx**2+dy**2+dz**2)

    rhosig = (rho0, sigma)
    #calculate the "custom" energy with rho -- THIS IS e
    #cast back to np.array since that's what pyscf works with
    #pass as tuple -- (rho, sigma)
    exc = np.array(jax.vmap(derivable_custom_pbe_e)( rhosig ) )
    
    #first order derivatives w.r.t. rho and sigma
    vrho_f = eqx.filter_grad(derivable_custom_pbe_epsilon)
    vrhosigma = np.array(jax.vmap(vrho_f)( rhosig ))
    # print('vrhosigma shape:', vrhosigma.shape)
    vxc = (vrhosigma[0], vrhosigma[1], None, None)

    # v2_f = eqx.filter_hessian(derivable_custom_pbe_epsilon)
    v2_f = jax.hessian(derivable_custom_pbe_epsilon)
    # v2_f = jax.hessian(custom_pbe_epsilon, argnums=[0, 1])
    v2 = np.array(jax.vmap(v2_f)( rhosig ))
    # print('v2 shape', v2.shape)
    v2rho2 = v2[0][0]
    v2rhosigma = v2[0][1]
    v2sigma2 = v2[1][1]
    v2lapl2 = None
    vtau2 = None
    v2rholapl = None
    v2rhotau = None
    v2lapltau = None
    v2sigmalapl = None
    v2sigmatau = None
    # 2nd order functional derivative
    fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
    #3rd order
    kxc = None
    
    return exc, vxc, fxc, kxc


In [211]:
#overwrite the kernel's driver
mfcpj = dft.RKS(mol)
mfcpj = mfcpj.define_xc_(eval_xc_gga_j, 'GGA')
mfcpj.kernel()

converged SCF energy = -76.0026504505729


-76.00265045057293

In [212]:
print(abs(mfp.e_tot - mfcpj.e_tot))

1.4935008181282683e-05


So we've achieved PBE-predictions to an accuracy of $10^{-5}\ \mathrm{H} \sim 10^{-4}\ \mathrm{eV} \sim 10^{-3}\ \mathrm{kcal/mol} \sim 10^{-2}\ \mathrm{kJ/mol}.$