In [None]:
import argparse
import pyscf
import pyscfad
from xcquinox.pyscf import eval_xc_gga_grho
from xcquinox.utils import gen_grid_s, PBE_Fx, PBE_Fc, calculate_stats, lda_x, pw92c_unpolarized
from xcquinox.train import Pretrainer, Optimizer
from xcquinox.loss import compute_loss_mae
from xcquinox import net, xc
from ase.io import read
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
import numpy as np
from pyscf import dft, scf, gto, cc
from pyscfad import dft as dft_ad
from pyscfad import gto as gto_ad
from pyscfad import scf as scf_ad
from functools import partial
import pylibxc
import pyscfad.dft as dftad
from jax import custom_jvp
jax.config.update("jax_enable_x64", True)  # Enables 64 bit precision


In [None]:
def eval_xc_gga_grho(xc_code, rho, spin=0, relativity=0, deriv=1,
                     omega=None, verbose=None,
                     xcmodel=None):
    """Evaluate a restricted GGA functional from a network that takes (rho, |grad rho|).

    PySCF's GGA interface expects first derivatives with respect to
    sigma = |grad rho|**2, so the network is wrapped in a function of
    (rho, sigma) that converts sigma back to |grad rho| before calling it.

    NOTE: this definition shadows ``xcquinox.pyscf.eval_xc_gga_grho`` imported
    in the first cell.

    Parameters
    ----------
    xc_code, spin, relativity, deriv, omega, verbose
        Accepted for compatibility with PySCF's ``eval_xc`` signature; unused.
    rho : array-like
        Density rows; only the first four (rho, d/dx, d/dy, d/dz) are read,
        each with one entry per grid point.
    xcmodel : callable
        Maps a length-2 array ``[rho, |grad rho|]`` for a single grid point to
        a scalar. The result is divided by rho to form ``exc`` (presumably the
        model returns an energy density per volume -- confirm against xcquinox).

    Returns
    -------
    exc, vxc, fxc, kxc
        ``exc`` is the model output divided by rho, ``vxc`` is
        ``(vrho, vsigma, None, None)``, ``fxc`` and ``kxc`` are ``None``.

    Raises
    ------
    ValueError
        If ``xcmodel`` is not provided.
    """
    if xcmodel is None:
        raise ValueError("xcmodel must be provided")
    rho0, dx, dy, dz = rho[:4]
    sigma = dx ** 2 + dy ** 2 + dz ** 2

    def xcmodel_sigma(rho_pt, sigma_pt):
        # d sqrt(sigma)/d sigma is infinite at sigma == 0, which would make
        # vsigma inf/NaN wherever the density gradient vanishes exactly.
        # The double-where keeps the value (0) and gives a finite gradient (0)
        # there, while leaving every sigma > 0 point untouched.
        positive = sigma_pt > 0
        safe_sigma = jnp.where(positive, sigma_pt, 1.0)
        grad_rho = jnp.where(positive, jnp.sqrt(safe_sigma), 0.0)
        # stack works for scalar inputs (the vmapped case) and batched inputs.
        return xcmodel(jnp.stack([rho_pt, grad_rho], axis=-1))

    # Model output per grid point, divided by rho for PySCF's exc convention.
    exc = jnp.asarray(jax.vmap(xcmodel_sigma)(rho0, sigma)) / rho0

    # First derivatives w.r.t. rho and sigma, one pair per grid point.
    vrho, vsigma = jax.vmap(
        jax.grad(xcmodel_sigma, argnums=(0, 1)))(rho0, sigma)

    vxc = (vrho, vsigma, None, None)
    fxc = None
    kxc = None
    return exc, vxc, fxc, kxc
def eval_xc_gga_j2(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
                   xcmodel=None):
    """Evaluate a restricted GGA functional from a network that takes (rho, sigma).

    First and second derivatives with respect to (rho, sigma) are obtained
    with JAX autodiff, one grid point at a time via ``jax.vmap``.

    Parameters
    ----------
    xc_code, spin, relativity, deriv, omega, verbose
        Accepted for compatibility with PySCF's ``eval_xc`` signature; unused.
    rho : array-like
        Either at least four rows (rho, d/dx, d/dy, d/dz, ...), giving
        sigma = dx**2 + dy**2 + dz**2, or two rows (rho, drho), giving
        sigma = drho**2. Each row has one entry per grid point.
    xcmodel : callable
        Maps a length-2 array ``[rho, sigma]`` for a single grid point to a
        scalar; the result is divided by rho to form ``exc``.

    Returns
    -------
    exc, vxc, fxc, kxc
        ``exc`` is the model output divided by rho, and ``vxc`` is
        ``(vrho, vsigma, None, None)``. ``fxc`` is the 10-tuple
        ``(v2rho2, v2rhosigma, v2sigma2, None, ...)``. ``kxc`` is ``None``.

    Raises
    ------
    ValueError
        If ``xcmodel`` is not provided, or ``rho`` has fewer than two rows.
    """
    if xcmodel is None:
        raise ValueError("xcmodel must be provided")
    # Choose the density layout explicitly instead of relying on a bare
    # `except:`, which would also hide unrelated errors.
    if len(rho) >= 4:
        rho0, dx, dy, dz = rho[:4]
        sigma = jnp.asarray(dx ** 2 + dy ** 2 + dz ** 2)
    else:
        rho0, drho = rho[:2]
        sigma = jnp.asarray(drho ** 2)
    rho0 = jnp.asarray(rho0)
    # One (rho, sigma) row per grid point.
    rhosig = jnp.stack([rho0, sigma], axis=1)

    # Model output per grid point, divided by rho for PySCF's exc convention.
    exc = jnp.asarray(jax.vmap(xcmodel)(rhosig)) / rho0

    # First-order derivatives w.r.t. rho (column 0) and sigma (column 1).
    vrhosigma = jnp.asarray(jax.vmap(jax.grad(xcmodel))(rhosig))
    vxc = (vrhosigma[:, 0], vrhosigma[:, 1], None, None)

    # Second-order derivatives: per-point 2x2 Hessian in (rho, sigma).
    v2 = jnp.asarray(jax.vmap(jax.hessian(xcmodel))(rhosig))
    v2rho2 = v2[:, 0, 0]
    v2rhosigma = v2[:, 0, 1]
    v2sigma2 = v2[:, 1, 1]
    # Laplacian / tau terms are not modelled here.
    v2lapl2 = None
    v2tau2 = None
    v2rholapl = None
    v2rhotau = None
    v2lapltau = None
    v2sigmalapl = None
    v2sigmatau = None
    fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, v2tau2, v2rholapl, v2rhotau,
           v2lapltau, v2sigmalapl, v2sigmatau)
    # Third-order derivatives are not provided.
    kxc = None
    return exc, vxc, fxc, kxc

In [None]:
# Pretrained exchange (Fx) and correlation (Fc) network checkpoints to load.
load_xnet_path = 'GGA_FxNet_G_d3_n16_s42_v_10000'
load_cnet_path = 'GGA_FcNet_G_d3_n16_s42_v_10000'

# ASE trajectory holding the benchmark systems (presumably the 50-system
# diet-GMTKN55 subset, judging by the path -- confirm).
diet_traj_path = '../../script_data/dietgmtkn55-50/diet50.traj'

# Tab-separated results file written by the evaluation loop.
outfile = './test_v.txt'

# Assigned to each molecule's `max_memory` (MB, per PySCF's convention).
calc_maxmem = 32_000

In [None]:


# Build the combined exchange-correlation model from the pretrained networks.
# Bound to `xc_model` (not `xc`) so the `xc` module imported in the first cell
# is not shadowed; otherwise re-running this cell fails on `xc.RXCModel_GGA`.
xnet = net.load_xcquinox_model(load_xnet_path)
cnet = net.load_xcquinox_model(load_cnet_path)
xc_model = xc.RXCModel_GGA(xnet=xnet, cnet=cnet)

OVERWRITE_EVAL_XC = partial(eval_xc_gga_grho, xcmodel=xc_model)
GRID_LEVEL = 1  # Default in Alec's scripts
MAX_SCF_STEPS = 25  # Default in Alec's scripts

# Columns of the tab-separated results file, in the order they are written.
RESULT_COLUMNS = ("atidx", "atsymbols", "atformula", "subset", "subsetind",
                  "species", "count", "refweight", "refen", "calcen", "calcconv")


def atoms_to_pyscf_atom_string(atoms):
    """Format an ASE Atoms object as one 'symbol x y z' line per atom for PySCF."""
    return "".join(
        f"{symbol} {pos[0]} {pos[1]} {pos[2]}\n"
        for symbol, pos in zip(atoms.get_chemical_symbols(), atoms.positions)
    )


def run_custom_xc_scf(mol, spin):
    """Run RKS (spin == 0) or UKS (otherwise) on `mol` with the custom XC functional."""
    if spin == 0:
        print("SPIN 0 -> RKS")
        mf = dft_ad.RKS(mol)
    else:
        # NOTE(review): eval_xc_gga_grho unpacks a single restricted density
        # (rho, dx, dy, dz); a UKS density presumably carries an extra spin
        # axis, so this branch likely needs a spin-polarized evaluator.
        print("NONZERO SPIN -> UKS")
        mf = dft_ad.UKS(mol)
    mf.grids.level = GRID_LEVEL
    mf.max_cycle = MAX_SCF_STEPS
    mf.define_xc_(OVERWRITE_EVAL_XC, 'GGA')
    mf.kernel()
    return mf


traj = read(diet_traj_path, ':')
print("Trajectory loaded in.")
print("Printing information...")
for idx, at in enumerate(traj):
    print(20*'=')
    print(idx, at)
    print(at.info)
    print(at.get_chemical_symbols())
    print(at.positions)

results = {idx: 0 for idx in range(len(traj))}
with open(outfile, 'w') as f:
    f.write("\t".join(RESULT_COLUMNS) + "\n")
for idx, atoms in enumerate(traj):
    spin = atoms.info.get('spin', 0)
    mol = gto_ad.Mole(atom=atoms_to_pyscf_atom_string(atoms),
                      charge=atoms.info.get('charge', 0), spin=spin)
    mol.build()
    # If the local memory usage reaches this max_memory value, the SCF cycles are broken down into sub-loops over small sections of the grid that take *forever* to get through
    mol.max_memory = calc_maxmem
    print("Beginning calculation...")
    print(f"{idx} -- {atoms.symbols}/{atoms.get_chemical_formula()}")
    mf = run_custom_xc_scf(mol, spin)
    results[idx] = (mf.e_tot, mf.converged)
    print(f"Results: CONVERGED = {mf.converged}, ENERGY = {mf.e_tot}")
    row = (idx, atoms.symbols, atoms.get_chemical_formula(),
           atoms.info['subset'], atoms.info['subsetind'], atoms.info['species'],
           atoms.info['count'], atoms.info['refweight'], atoms.info['refen'],
           mf.e_tot, mf.converged)
    # Append per system so partial results survive an interrupted run.
    with open(outfile, 'a') as f:
        f.write("\t".join(f"{value}" for value in row) + "\n")