In [1]:
import pyscfad
from pyscfad import gto,dft,scf
import matplotlib.pyplot as plt
import equinox as eqx
import pyscf
# from pyscf import gto,dft,scf
import numpy as np
import jax.numpy as jnp
import scipy
from ase import Atoms
from ase.io import read
import xcquinox as xce
from functools import partial
from ase.units import Bohr
import os, optax, jax


2024-04-18 13:17:15.103905: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:280] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [2]:
PRETRAIN_LEVEL = 'MGGA'

TRAIN_NET = 'x'

REFERENCE_XC = 'PBE0'

N_HIDDEN = 16
DEPTH = 3
if PRETRAIN_LEVEL == 'GGA':
    localx = xce.net.eX(n_input=1, n_hidden=N_HIDDEN, use=[1], depth=DEPTH, lob=1.804)
    localc = xce.net.eC(n_input=3, n_hidden=N_HIDDEN, use=[2], depth=DEPTH, ueg_limit=True)
elif PRETRAIN_LEVEL == 'MGGA':
    localx = xce.net.eX(n_input=2, n_hidden=N_HIDDEN, use=[1, 2], depth=DEPTH, ueg_limit=True, lob=1.174)
    localc = xce.net.eC(n_input=4, n_hidden=N_HIDDEN, depth=DEPTH, use=[2,3], ueg_limit=True)
elif PRETRAIN_LEVEL == 'NONLOCAL':
    localx = xce.net.eX(n_input=18, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True, lob=1.174)
    localc = xce.net.eC(n_input=16, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True)

xc = xce.xc.eXC(grid_models=[localx, localc], heg_mult=True, level= {'GGA':2, 'MGGA':3, 'NONLOCAL':4}[PRETRAIN_LEVEL])

In [3]:
trainms = read('/home/awills/Documents/Research2/torch_dpy/subset09_nf/subat_ref_corrected.traj', ':')
mfs = []
mols = []
energies = []
dms = []
ao_evals = []
gws = []
eris = []
mo_occs = []
hcs = []
vs = []
ts = []
ss = []
hologaps = []
ogds = []
for idx, at in enumerate(trainms[1:2]):
    name, mol = xce.utils.ase_atoms_to_mol(at, basis='def2tzvpd')
    mol.build()
    mols.append(mol)
    mf = dft.RKS(mol, xc='SCAN')
    e_tot = mf.kernel()
    mfs.append(mf)
    dm = mf.make_rdm1()
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=2))
    energies.append(mf.get_veff().exc)
    dms.append(dm)
    ogds.append(dm.shape)
    ao_evals.append(ao_eval)
    gws.append(mf.grids.weights)
    ts.append(mol.intor('int1e_kin'))
    vs.append(mol.intor('int1e_nuc'))
    mo_occs.append(mf.mo_occ)
    hcs.append(mf.get_hcore())
    eris.append(mol.intor('int2e'))
    ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp'))))
    hologaps.append(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1])



In [4]:
e_tot

Array(-109.52596483, dtype=float64)

In [15]:
def generate_network_eval_xc(mf, dm, ao, gw, network):
    def eval_xc(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
        print('custom eval_xc; input rho shape: ', rho.shape)
        if len(rho.shape) == 2:
            rho0 = rho[0]
            drho = rho[1:4]
            tau = 0.5*(rho[1] + rho[2] + rho[3])
            
            non_loc = jnp.zeros_like(tau)
        # print(f'decomposed shapes:\nrho0={rho0.shape}\ndrho={drho.shape}\ntau={tau.shape}\nnon_loc={non_loc.shape}')
        if dm.ndim == 3: # If unrestricted (open-shell) calculation

            # Density
            rho0_a = rho0[0]
            rho0_b = rho0[1]

            # jnp.einsumed density gradient
            gamma_a, gamma_b = jnp.einsum('ij,ij->j',drho[:,0],drho[:,0]), jnp.einsum('ij,ij->j',drho[:,1],drho[:,1])
            gamma_ab = jnp.einsum('ij,ij->j',drho[:,0],drho[:,1])

            # Kinetic energy density
            tau_a, tau_b = tau

            # E.-static
            non_loc_a, non_loc_b = non_loc
        else:
            rho0_a = rho0_b = rho0*0.5
            gamma_a=gamma_b=gamma_ab= jnp.einsum('ij,ij->j',drho[:],drho[:])*0.25
            tau_a = tau_b = tau*0.5
            non_loc_a=non_loc_b = non_loc*0.5

        # xc-energy per unit particle
        exc = network.eval_grid_models(jnp.concatenate([jnp.expand_dims(rho0_a,-1),
                                                jnp.expand_dims(rho0_b,-1),
                                                jnp.expand_dims(gamma_a,-1),
                                                jnp.expand_dims(gamma_ab,-1),
                                                jnp.expand_dims(gamma_b,-1),
                                                jnp.expand_dims(jnp.zeros_like(rho0_a),-1), #Dummy for laplacian
                                                jnp.expand_dims(jnp.zeros_like(rho0_a),-1), #Dummy for laplacian
                                                jnp.expand_dims(tau_a,-1),
                                                jnp.expand_dims(tau_b,-1),
                                                jnp.expand_dims(non_loc_a,-1),
                                                jnp.expand_dims(non_loc_b,-1)],axis=-1),
                                   mf = mf, dm = dm)
        exc = exc[:, 0]
        # print('exc from network evaluation on grid models shape: ', exc.shape)
        vgf = lambda x: network(x, ao, gw, mf=mf)
        mf.network = xc
        mf.network_eval = vgf
        vxc = jax.grad(vgf)(dm)
        # vrho = vxc
        vrho = exc*rho0
        vgamma = jnp.zeros_like(vrho)
        vlapl = None
        vtau = jnp.nan_to_num(exc*tau)
        fxc = None #second order functional derivative
        kxc = None #third order functional derivative
        # print(f'shapes: vxc={vxc.shape}, vrho={vrho.shape}, vgamma={vgamma.shape}')
        return exc, (vrho, vgamma, vlapl, vtau), fxc, kxc
    return eval_xc

In [16]:
mf.mol.spin

0

In [17]:
evxc = generate_network_eval_xc(mf=mf, dm=dm, ao=ao_eval, gw=mf.grids.weights, network=xc)

In [18]:

mf.define_xc_(evxc, xctype='MGGA')

RKS-KohnShamDFT object of <class 'pyscfad.dft.rks.RKS'>

In [19]:
mf.kernel()

<class 'pyscfad.dft.rks.RKS'> does not have attributes  network network_eval


custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
custom eval_xc; input rho shape:  (6, 25728)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)
exc is nan, trying alternate calculation with network eval
eX.__call__, rho shape: (2, 25728, 3)
eC.__call__, rho shape: (25728, 4)




Array(-108.63733875, dtype=float64)