Skip to content

Commit

Permalink
Minor tweaks in the implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAMC committed Sep 13, 2023
1 parent 400a37a commit fd1f319
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 29 deletions.
8 changes: 5 additions & 3 deletions grad_dft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from typeguard import typechecked
from grad_dft.molecule import abs_clip

from grad_dft.utils import Scalar, Array, PyTree, DType, default_dtype
from grad_dft.utils import DType, default_dtype
from grad_dft.molecule import Grid, Molecule

import sys
Expand Down Expand Up @@ -179,7 +179,7 @@ def compute_densities(self, molecule: Molecule, *args, **kwargs):

elif self.nograd_densities:
densities = stop_gradient(self.nograd_densities(molecule, *args, **kwargs))
densities = abs_clip(densities, 1e-20)
densities = abs_clip(densities, 1e-20) #todo: investigate if we can lower this
return densities

def compute_coefficient_inputs(self, molecule: Molecule, *args, **kwargs):
Expand Down Expand Up @@ -309,6 +309,7 @@ def _integrate(
Scalar
"""

#todo: study if we can lower this clipping constants
return jnp.einsum("r,r->", abs_clip(gridweights, 1e-20), abs_clip(energy_density, 1e-20), precision=precision)


Expand Down Expand Up @@ -684,7 +685,8 @@ def dm21_hfgrads_densities(
)
return vxc_hf.sum(axis=0) # Sum over omega


@jaxtyped
@typechecked
def dm21_hfgrads_cinputs(
functional: nn.Module,
params: PyTree,
Expand Down
25 changes: 18 additions & 7 deletions grad_dft/popular_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ def lyp_c_e(rho: Float[Array, "grid spin"], grad_rho: Float[Array, "grid spin 3"
Returns
-------
Float[Array, "grid"]
Notes:
------
Libxc implementation:
https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_c_lyp.mpl
Important: This implementation uses the original LYP functional definition
in C. Lee, W. Yang, and R. G. Parr., Phys. Rev. B 37, 785 (1988) (doi: 10.1103/PhysRevB.37.785)
instead of the one in libxc: B. Miehlich, A. Savin, H. Stoll, and H. Preuss., Chem. Phys. Lett. 157, 200 (1989) (doi: 10.1016/0009-2614(89)87234-3)
This sometimes gives rise to <1 kcal/mol differences in spin-polarized systems.
"""

a = 0.04918
Expand Down Expand Up @@ -242,19 +253,19 @@ def lyp_c_e(rho: Float[Array, "grid spin"], grad_rho: Float[Array, "grid spin 3"
rho_grad2rho = (rho * grad2rho).sum(axis=1)
# assert not jnp.isnan(rho_grad2rho).any() and not jnp.isinf(rho_grad2rho).any()

exp_factor = jnp.where(rho.sum(axis=1) > 0, jnp.exp(-c * rho.sum(axis=1) ** (-1 / 3)), 0)
# assert not jnp.isnan(exp_factor).any() and not jnp.isinf(exp_factor).any()

rhom1_3 = (rho.sum(axis=1)) ** (-1 / 3)
rho8_3 = (rho ** (8 / 3.0)).sum(axis=1)
rho8_3 = (rho ** (8 / 3)).sum(axis=1)
rhom5_3 = (rho.sum(axis=1)) ** (-5 / 3)

par = 2 ** (2 / 3) * CF * (rho8_3) - rhos_ts + rho_t / 9 + rho_grad2rho / 18
exp_factor = jnp.where(rho.sum(axis=1) > 0, jnp.exp(-c * rhom1_3), 0)
# assert not jnp.isnan(exp_factor).any() and not jnp.isinf(exp_factor).any()

parenthesis = 2 ** (2 / 3) * CF * (rho8_3) - rhos_ts + rho_t / 9 + rho_grad2rho / 18

sum_ = jnp.where(rho.sum(axis=1) > clip_cte, 2 * b * rhom5_3 * par * exp_factor, 0.0)
braket_m_rho = jnp.where(rho.sum(axis=1) > clip_cte, 2 * b * rhom5_3 * parenthesis * exp_factor, 0.0)

return -a * jnp.where(
rho.sum(axis=1) > clip_cte, gamma / (1 + d * rhom1_3) * (rho.sum(axis=1) + sum_), 0.0
rho.sum(axis=1) > clip_cte, gamma / (1 + d * rhom1_3) * (rho.sum(axis=1) + braket_m_rho), 0.0
)

def lsda_density(molecule: Molecule, clip_cte: float = 1e-27, *_, **__) -> Float[Array, "grid densities"]:
Expand Down
29 changes: 29 additions & 0 deletions tests/integration/test_classical_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
from flax.core import freeze
from jax import numpy as jnp
import pytest
from grad_dft.functional import DM21 # A class, needs to be instanciated!
from grad_dft.popular_functionals import B3LYP, B88, LSDA, LYP, VWN, PW92

from grad_dft.interface.pyscf import molecule_from_pyscf
from grad_dft.external import NeuralNumInt
from grad_dft.external import Functional

# This file aims to test, given some electronic density, whether our
# implementation of classical functionals closely matches libxc (pyscf default).
Expand Down Expand Up @@ -102,6 +105,9 @@ def test_vwn(mol):
assert jnp.allclose(vwndiff, 0, atol=1)

##### LYP ####
# This test differs slightly due to the use of the original LYP functional definition
# in C. Lee, W. Yang, and R. G. Parr., Phys. Rev. B 37, 785 (1988) (doi: 10.1103/PhysRevB.37.785)
# instead of the one in libxc: B. Miehlich, A. Savin, H. Stoll, and H. Preuss., Chem. Phys. Lett. 157, 200 (1989) (doi: 10.1016/0009-2614(89)87234-3)
@pytest.mark.parametrize("mol", mols)
def test_lyp(mol):
mf = dft.UKS(mol)
Expand All @@ -121,6 +127,9 @@ def test_lyp(mol):
#### B3LYP ####
# This test will only pass if you set B3LYP_WITH_VWN5 = True in pyscf_conf.py.
# See pyscf_conf.py in .github/workflows
# This test differs slightly due to the use of the original LYP functional definition
# in C. Lee, W. Yang, and R. G. Parr., Phys. Rev. B 37, 785 (1988) (doi: 10.1103/PhysRevB.37.785)
# instead of the one in libxc: B. Miehlich, A. Savin, H. Stoll, and H. Preuss., Chem. Phys. Lett. 157, 200 (1989) (doi: 10.1016/0009-2614(89)87234-3)
@pytest.mark.parametrize("mol", mols)
def test_b3lyp(mol):
mf = dft.UKS(mol)
Expand Down Expand Up @@ -154,3 +163,23 @@ def test_pw92(mol):

assert not jnp.isnan(fock).any()
assert jnp.allclose(pw92diff, 0, atol=1e-3)


#### DM21 ####
@pytest.mark.parametrize("mol", mols)
def test_dm21(mol):
mf = dft.UKS(mol)
mf._numint = NeuralNumInt(Functional.DM21)
ground_truth_energy = mf.kernel()

functional = DM21() # Note that DM21 is a class, that needs to be instantiated.
params = functional.generate_DM21_weights()

molecule = molecule_from_pyscf(mf, omegas=[0.0, 0.4])
predict_molecule = molecule_predictor(functional)
predicted_e, fock = predict_molecule(params, molecule)

dm21diff = (ground_truth_energy - predicted_e) * Hartree2kcalmol

assert not jnp.isnan(fock).any()
assert jnp.allclose(dm21diff, 0, atol=1)
17 changes: 6 additions & 11 deletions tests/integration/test_predict_B3LYP.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,21 @@ def test_predict(mol_and_name: tuple[gto.Mole, str]) -> None:
mol_and_name (tuple[gto.Mole, str]): PySCF molecule object and the name of the molecule.
"""
mol, name = mol_and_name
if name == "water":
mf = dft.RKS(mol)
elif name == "Li":
mf = dft.UKS(mol)
mf = dft.UKS(mol)
mf.max_cycle = 0

energy = mf.kernel()

molecule = molecule_from_pyscf(mf, energy=energy, omegas=[0.0], scf_iteration=0)

iterator = make_scf_loop(FUNCTIONAL, verbose=2, max_cycles=25)
iterator = make_scf_loop(FUNCTIONAL, verbose=2, max_cycles=10)
e_XND = iterator(PARAMS, molecule)

if name == "water":
if mol.spin == 0:
mf = dft.RKS(mol)
elif name == "Li":
else:
mf = dft.UKS(mol)
mf.xc = "B3LYP"
mf.max_cycle = 25
mf.max_cycle = 10
e_DM = mf.kernel()
kcalmoldiff = (e_XND - e_DM) * Hartree2kcalmol
assert np.allclose(kcalmoldiff, 0, atol=1)

assert np.allclose(kcalmoldiff, 0, atol=1)
18 changes: 10 additions & 8 deletions tests/integration/test_predict_DM21.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# This only works on startup!
from jax.config import config

from grad_dft.train import molecule_predictor

config.update("jax_enable_x64", True)

dirpath = os.path.dirname(os.path.dirname(__file__))
Expand All @@ -45,7 +47,7 @@
from openfermion import geometry_from_pubchem

from pyscf import gto, dft, cc, scf
import numpy as np
import jax.numpy as jnp
from grad_dft.functional import DM21
from grad_dft.utils.types import Hartree2kcalmol

Expand Down Expand Up @@ -87,11 +89,11 @@ def test_predict(mol):
e_DM = mf.kernel()

kcalmoldiff = (e_XND - e_DM) * Hartree2kcalmol
assert np.allclose(kcalmoldiff, 0, atol=1)
assert jnp.allclose(kcalmoldiff, 0, atol=1)


##################
test_predict(mol)
# test_predict(mol)


###################### Open shell ############################
Expand Down Expand Up @@ -122,16 +124,16 @@ def test_predict(mol):
# iterator = make_orbital_optimizer(functional, tx, omegas = [0., 0.4], verbose = 2, functional_type = 'DM21')
# e_XND_DF4T = iterator(params, molecule)

iterator = make_scf_loop(functional, verbose=2, max_cycles=10)
iterator = make_scf_loop(functional, verbose=2, max_cycles=1)
e_XND = iterator(params, molecule)

mf = dft.UKS(mol)
mf._numint = NeuralNumInt(Functional.DM21)
mf.max_cycle = 10
mf.max_cycle = 1
e_DM = mf.kernel()

kcalmoldiff = (e_XND - e_DM) * Hartree2kcalmol
assert np.allclose(kcalmoldiff, 0, atol=1)
assert jnp.allclose(kcalmoldiff, 0, atol=1)


##################
Expand Down Expand Up @@ -167,7 +169,7 @@ def test_rks():
e_XND = iterator(params, molecule)

kcalmoldiff = (e_XND - e_DM) * Hartree2kcalmol
assert np.allclose(kcalmoldiff, 0, atol=1)
assert jnp.allclose(kcalmoldiff, 0, atol=1)


def test_uks():
Expand Down Expand Up @@ -197,7 +199,7 @@ def test_uks():
e_XND = iterator(params, molecule)

kcalmoldiff = (e_XND - e_DM) * Hartree2kcalmol
assert np.allclose(kcalmoldiff, 0, atol=1)
assert jnp.allclose(kcalmoldiff, 0, atol=1)


##################
Expand Down

0 comments on commit fd1f319

Please sign in to comment.