In [1]:
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
import numpy as np
from pyscf import dft, scf, gto, cc
from pyscfad import dft as dft_ad
from pyscfad import gto as gto_ad
from pyscfad import scf as scf_ad
from functools import partial
import pylibxc
import pyscfad.dft as dftad
from jax import custom_jvp
jax.config.update("jax_enable_x64", True) #Enables 64 bit precision
from ase.io import read
import pyscfad, pyscf
from xcquinox import net, xc
from xcquinox.loss import compute_loss_mae
from xcquinox.train import Pretrainer, Optimizer
from xcquinox.utils import gen_grid_s, PBE_Fx, PBE_Fc, calculate_stats, lda_x, pw92c_unpolarized
from xcquinox.pyscf import eval_xc_gga_j2

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [2]:
# Display the installed pyscfad and pyscf versions so the run can be reproduced.
pyscfad.__version__, pyscf.__version__

('0.1.9', '2.9.0')

In [3]:
# Location of the trajectory file holding the structures used below.
DIET_TRAJECTORY_PATH = "/home/awills/Documents/Research/xcquinox/scripts/script_data/dietgmtkn55-50/diet50.traj"

# ':' is passed as the index argument so that `read` returns a sequence of
# structures (it is enumerated below) rather than a single one.
traj = read(DIET_TRAJECTORY_PATH, index=':')
print("Trajectory loaded in.")
print("Printing information...")
# Dump, for every structure: a separator, its index and repr, then its
# metadata dict, chemical symbols and Cartesian positions (one per line).
for idx, at in enumerate(traj):
    print('=' * 20)
    print(idx, at)
    print(at.info, at.get_chemical_symbols(), at.positions, sep='\n')

Trajectory loaded in.
Printing information...
0 Atoms(symbols='FNH3', pbc=False)
{'spin': 0, 'subset': 'AHB21', 'subsetind': 0, 'species': 1, 'count': 1, 'charge': -1, 'refweight': 2.53, 'refen': -17.79, 'energy': -17.79}
['F', 'N', 'H', 'H', 'H']
[[-0.20512  1.19195 -0.69629]
 [-1.07147 -0.00913  1.50279]
 [-0.61757 -0.91698  1.47415]
 [-0.6009   0.48688  2.25356]
 [-0.76307  0.48318  0.60817]]
1 Atoms(symbols='F', pbc=False)
{'spin': 0, 'subset': 'AHB21', 'subsetind': 0, 'species': '1A', 'count': -1, 'charge': -1, 'refweight': 2.53, 'refen': -17.79, 'energy': -17.79}
['F']
[[-0.20512  1.19195 -0.69629]]
2 Atoms(symbols='NH3', pbc=False)
{'spin': 0, 'subset': 'AHB21', 'subsetind': 0, 'species': '1B', 'count': -1, 'charge': 0, 'refweight': 2.53, 'refen': -17.79, 'energy': -17.79}
['N', 'H', 'H', 'H']
[[-1.07147 -0.00913  1.50279]
 [-0.61757 -0.91698  1.47415]
 [-0.6009   0.48688  2.25356]
 [-0.76307  0.48318  0.60817]]
3 Atoms(symbols='Li2', pbc=False)
{'spin': 0, 'subset': 'ALK8', 'su

In [120]:
def _gga_unpolarized(rho, xcmodel):
    """Evaluate the model and its 1st/2nd derivatives for a spin-unpolarized density.

    `rho[:4]` is unpacked as (rho0, dx, dy, dz); sigma is dx**2 + dy**2 + dz**2.
    Returns (exc, vxc, fxc, kxc) with vxc a 4-tuple and fxc a 10-tuple whose
    unused (laplacian / tau) slots are None, and kxc = None.
    """
    try:
        rho0, dx, dy, dz = rho[:4]
        sigma = jnp.array(dx**2 + dy**2 + dz**2)
    except (TypeError, ValueError):
        # Only an unpacking failure triggers the two-row fallback; any other
        # error (including KeyboardInterrupt) now propagates instead of being
        # silently swallowed by a bare `except`.
        print("Unpacking failed...")
        rho0, drho = rho[:4]
        sigma = jnp.array(drho**2)
    rho0 = jnp.array(rho0)
    # One row per grid point: [rho, sigma].
    rhosig = jnp.stack([rho0, sigma], axis=1)

    # The model output is divided by the density before being returned.
    exc = jnp.array(jax.vmap(xcmodel)(rhosig)) / rho0

    # jax.grad is used (rather than eqx.filter_grad) for consistency with the
    # jax.hessian call below and with the polarized path; both differentiate
    # the model with respect to its single [rho, sigma] array argument.
    grads = jnp.array(jax.vmap(jax.grad(xcmodel))(rhosig))
    # vxc = (vrho, vsigma, vlapl, vtau); nothing beyond vsigma for a GGA.
    vxc = (grads[:, 0], grads[:, 1], None, None)

    hess = jnp.array(jax.vmap(jax.hessian(xcmodel))(rhosig))
    v2rho2 = hess[:, 0, 0]
    v2rhosigma = hess[:, 0, 1]
    v2sigma2 = hess[:, 1, 1]
    # (v2rho2, v2rhosigma, v2sigma2, v2lapl2, v2tau2, v2rholapl, v2rhotau,
    #  v2lapltau, v2sigmalapl, v2sigmatau)
    fxc = (v2rho2, v2rhosigma, v2sigma2, None, None, None, None, None, None, None)
    # Third-order derivatives are not provided.
    kxc = None
    return exc, vxc, fxc, kxc


def _gga_polarized(rho, xcmodel):
    """Evaluate the model and its 1st/2nd derivatives for a spin-polarized density.

    `rho` unpacks into (rho_up, rho_down); the first four rows of each are
    (rho0, dx, dy, dz). The model itself only accepts [rho, sigma], so the two
    spin densities and the three sigma components are summed before the model
    is called, and derivatives are taken with respect to the five polarized
    inputs (rho_u, rho_d, sigma_uu, sigma_ud, sigma_dd).

    Returns (exc, vxc, fxc, kxc) where vxc = [vrho (ngrid, 2), vsigma (ngrid, 3)],
    fxc = [v2rho2 (ngrid, 3), v2rhosigma (ngrid, 6), v2sigma2 (ngrid, 6)] and
    kxc = None. Unlike the unpolarized path, no None placeholders are included.
    """
    def model_epsilon(arr):
        # Deliberately not vmapped here: it is vmapped/differentiated by the caller.
        rhou, rhod, sig_uu, sig_ud, sig_dd = arr
        # NOTE(review): the three sigma components are summed with equal weight;
        # the squared gradient of the total density would be
        # sigma_uu + 2*sigma_ud + sigma_dd. Left unchanged -- confirm intent.
        return xcmodel(jnp.stack([rhou + rhod, sig_uu + sig_ud + sig_dd]))

    rho_u, rho_d = rho
    rho0u, dxu, dyu, dzu = rho_u[:4]
    rho0d, dxd, dyd, dzd = rho_d[:4]
    # up-up, up-down and down-down gradient contractions.
    sigma1 = dxu*dxu + dyu*dyu + dzu*dzu
    sigma2 = dxu*dxd + dyu*dyd + dzu*dzd
    sigma3 = dxd*dxd + dyd*dyd + dzd*dzd

    rho0 = jnp.array(rho0u + rho0d)
    # One row per grid point: [rho_u, rho_d, sigma_uu, sigma_ud, sigma_dd].
    input_arr = jnp.stack([rho0u, rho0d, sigma1, sigma2, sigma3], axis=1)

    # The model output is divided by the total density before being returned.
    exc = jnp.array(jax.vmap(model_epsilon)(input_arr)) / rho0

    v1 = jax.vmap(jax.grad(model_epsilon))(input_arr)
    # Columns are stacked directly into (ngrid, ncomponents) layout.
    vrho = jnp.stack((v1[:, 0], v1[:, 1]), axis=1)
    vsigma = jnp.stack((v1[:, 2], v1[:, 3], v1[:, 4]), axis=1)
    vxc = [vrho, vsigma]

    v2 = jax.vmap(jax.hessian(model_epsilon))(input_arr)
    # v2rho2 = (uu, ud, dd)
    v2rho2 = jnp.stack((v2[:, 0, 0], v2[:, 0, 1], v2[:, 1, 1]), axis=1)
    # v2rhosigma = (u,1), (u,2), (u,3), (d,1), (d,2), (d,3)
    v2rhosigma = jnp.stack((v2[:, 0, 2], v2[:, 0, 3], v2[:, 0, 4],
                            v2[:, 1, 2], v2[:, 1, 3], v2[:, 1, 4]), axis=1)
    # v2sigma2 = (1,1), (1,2), (1,3), (2,2), (2,3), (3,3)
    v2sigma2 = jnp.stack((v2[:, 2, 2], v2[:, 2, 3], v2[:, 2, 4],
                          v2[:, 3, 3], v2[:, 3, 4], v2[:, 4, 4]), axis=1)
    # Built explicitly instead of filtering a 10-tuple by concrete array type:
    # that filter would drop every entry if the derivatives were JAX tracers.
    fxc = [v2rho2, v2rhosigma, v2sigma2]
    # Third-order derivatives are not provided.
    kxc = None
    return exc, vxc, fxc, kxc


def local_eval_xc_gga_j2(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
                   xcmodel=None):
    """Evaluate a GGA-type XC model and its derivatives on a grid of densities.

    Parameters
    ----------
    xc_code, spin, relativity, deriv, omega, verbose
        Accepted but not used by this function. In particular, first and
        second derivatives are always computed regardless of `deriv`.
        (Presumably kept to match the pyscf `eval_xc` calling convention.)
    rho
        Density data. It is treated as spin-polarized when it has a `.shape`
        with three dimensions, or when it has no `.shape` and has length 2
        (e.g. a `(rho_up, rho_down)` tuple). Anything else is treated as
        spin-unpolarized, with rows (rho0, dx, dy, dz).
    xcmodel
        Callable taking a length-2 array `[rho, sigma]` and returning a scalar;
        it is vmapped over grid points and differentiated with JAX.

    Returns
    -------
    (exc, vxc, fxc, kxc)
        `exc` is the model output divided by the (total) density.
        Unpolarized: `vxc = (vrho, vsigma, None, None)` and `fxc` is a 10-tuple
        `(v2rho2, v2rhosigma, v2sigma2, None, ...)`.
        Polarized: `vxc = [vrho, vsigma]` and `fxc = [v2rho2, v2rhosigma,
        v2sigma2]`, each array laid out as (ngrid, ncomponents).
        `kxc` is always None.

    Raises
    ------
    TypeError
        If `xcmodel` is None.
    """
    if xcmodel is None:
        raise TypeError("local_eval_xc_gga_j2 requires a callable `xcmodel`")

    # Array-like input: polarized iff it is 3-dimensional.
    # Plain sequence: polarized iff it holds exactly two entries (up, down).
    if hasattr(rho, 'shape'):
        polarized = len(rho.shape) == 3
    else:
        polarized = len(rho) == 2

    if polarized:
        return _gga_polarized(rho, xcmodel)
    return _gga_unpolarized(rho, xcmodel)

In [None]:
# Settings applied to every SCF calculation in the loop that follows.
GRID_LEVEL = 1
MAX_SCF_STEPS = 25

# Random-weight (untrained) exchange and correlation networks, built with the
# same depth, width and seed, then combined into a single XC model.
rw_fx = net.GGA_FxNet_sigma(depth=3, nodes=16, seed=92017)
rw_fc = net.GGA_FcNet_sigma(depth=3, nodes=16, seed=92017)
rw_xc = xc.RXCModel_GGA(xnet=rw_fx, cnet=rw_fc)

# Bind the model so the evaluator can be handed to `define_xc_` as-is.
OVERWRITE_EVAL_XC = partial(local_eval_xc_gga_j2, xcmodel=rw_xc)

# One slot per structure, initialised to 0 and later replaced by
# an (energy, converged) tuple.
results = dict.fromkeys(range(len(traj)), 0)
# Run one SCF calculation per structure with the custom XC evaluator and
# record (total energy, convergence flag) under the structure's index.
for idx, sys in enumerate(traj):
    # Geometry string: one "symbol x y z" line per atom.
    atom_lines = []
    for aidx, sysat in enumerate(sys.get_chemical_symbols()):
        atom_xyz = sys.positions[aidx]
        atom_lines.append(f"{sysat} {atom_xyz[0]} {atom_xyz[1]} {atom_xyz[2]}\n")
    atstr = ''.join(atom_lines)

    # Charge and spin come from the structure metadata, defaulting to 0.
    sys_charge = sys.info.get('charge', 0)
    sys_spin = sys.info.get('spin', 0)
    mol = gto_ad.Mole(atom=atstr, charge=sys_charge, spin=sys_spin)
    mol.build()
    # Author's note: this value assumes 32GB of RAM. If memory use reaches
    # max_memory, the SCF cycles are broken into sub-loops over small sections
    # of the grid, which are very slow to get through.
    mol.max_memory = 32000

    print("Beginning calculation...")
    print(f"{idx} -- {sys.symbols}/{sys.get_chemical_formula()}")

    # Spin 0 uses the restricted solver, anything else the unrestricted one;
    # the remaining setup is identical for both.
    if sys_spin == 0:
        print("SPIN 0 -> RKS")
        mf = dft_ad.RKS(mol)
    else:
        print("NONZERO SPIN -> UKS")
        mf = dft_ad.UKS(mol)
    mf.grids.level = GRID_LEVEL
    mf.max_cycle = MAX_SCF_STEPS
    mf.define_xc_(OVERWRITE_EVAL_XC, 'GGA')
    mf.kernel()

    results[idx] = (mf.e_tot, mf.converged)
    print(f"Results: CONVERGED = {mf.converged}, ENERGY = {mf.e_tot}")

Beginning calculation...
0 -- FNH3/H3FN
SPIN 0 -> RKS




converged SCF energy = -153.016352700298
Results: CONVERGED = True, ENERGY = -153.01635270029846
Beginning calculation...
1 -- F/F
SPIN 0 -> RKS




converged SCF energy = -97.4288601354338
Results: CONVERGED = True, ENERGY = -97.42886013543381
Beginning calculation...
2 -- NH3/H3N
SPIN 0 -> RKS




converged SCF energy = -55.4136748617568
Results: CONVERGED = True, ENERGY = -55.413674861756846
Beginning calculation...
3 -- Li2/Li2
SPIN 0 -> RKS




converged SCF energy = -14.5544510346765
Results: CONVERGED = True, ENERGY = -14.554451034676548
Beginning calculation...
4 -- CLiH3Li2/CH3Li3
SPIN 0 -> RKS




converged SCF energy = -60.9997484399142
Results: CONVERGED = True, ENERGY = -60.999748439914214
Beginning calculation...
5 -- CLiH3/CH3Li
SPIN 0 -> RKS




converged SCF energy = -46.3793332033898
Results: CONVERGED = True, ENERGY = -46.37933320338983
Beginning calculation...
6 -- F/F
NONZERO SPIN -> UKS




converged SCF energy = -97.6902080510096  <S^2> = 0.75  2S+1 = 2
Results: CONVERGED = True, ENERGY = -97.69020805100962
Beginning calculation...
7 -- K/K
NONZERO SPIN -> UKS




converged SCF energy = -592.445062345202  <S^2> = 0.75000045  2S+1 = 2.0000004
Results: CONVERGED = True, ENERGY = -592.4450623452022
Beginning calculation...
8 -- KF/FK
SPIN 0 -> RKS




converged SCF energy = -690.41200520415
Results: CONVERGED = True, ENERGY = -690.41200520415
Beginning calculation...
9 -- NHCHC2H2CO3NCH4C2OH4/C7H12N2O4
SPIN 0 -> RKS


