In [1]:
import pyscfad
from pyscfad import gto,dft,scf
import matplotlib.pyplot as plt
import equinox as eqx
import pyscf
# from pyscf import gto,dft,scf
from pyscfad import dft as dfta
from pyscfad import gto as gtoa 
from pyscfad import scf as scfa
from pyscfad import cc as cca
from pyscfad.scf import hf as hfa
from pyscfad import pbc as pbca
from pyscfad.pbc import scf as scfpa
from pyscfad.pbc import gto as gtopa
from pyscfad.pbc import dft as dftpa

import numpy as np
import jax.numpy as jnp
import scipy
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.units import Hartree
from ase.io import read, write
import xcquinox as xce
from functools import partial
from ase.units import Bohr
import os, optax, jax
# JAX/XLA client memory settings: disable up-front memory preallocation and
# select the "platform" allocator.
# NOTE(review): these are assigned after `jax` has already been imported on the
# line above; JAX's documentation recommends setting them before the backend is
# initialised — confirm they actually take effect in this notebook.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
# Default spin per element symbol. get_spin() falls back to this table for an
# isolated atom whose `at.info` carries no 'spin' entry. Going by the inline
# notes below, each value is the number of unpaired electrons of the free atom.
spins_dict = {
    'Al': 1,
    'B' : 1,
    'Li': 1,
    'Na': 1,
    'Si': 2 ,
    'Be':0,
    'C': 2,
    'Cl': 1,
    'F': 1,
    'H': 1,
    'N': 3,
    'O': 2,
    'P': 3,
    'S': 2,
    'Ar':0, #noble
    'Br':1, #one unpaired electron
    'Ne':0, #noble
    'Sb':3, #same column as N/P
    'Bi':3, #same column as N/P/Sb
    'Te':2, #same column as O/S
    'I':1 #one unpaired electron
}
def get_spin(at):
    """Choose the spin to use for an ASE ``Atoms``-like object.

    Resolution order:

    1. A single atom with no ``'spin'`` key in ``at.info``: look the element up
       in the module-level ``spins_dict``.
    2. An integer ``at.info['spin']`` (Python ``int`` or NumPy integer, but not
       ``bool``): use it, returned as a plain ``int``.
    3. ``'radical'`` appears in ``at.info['name']``: assume spin 1.
    4. ``at.info['openshell']`` is truthy: assume spin 2.
    5. Otherwise: assume spin 0.

    :param at: Object exposing ``positions``, ``symbols`` and an ``info`` dict
    :return: The selected spin
    :rtype: int
    :raises KeyError: A single atom whose element is missing from ``spins_dict``
    """
    print('======================')
    print("GET SPIN: Atoms Info")
    print(at)
    print(at.info)
    print('======================')
    if len(at.positions) == 1 and 'spin' not in at.info:
        print("Single atom and no spin specified in at.info")
        symbol = str(at.symbols)
        try:
            return spins_dict[symbol]
        except KeyError:
            # Same exception type as before, but say which element is missing.
            raise KeyError(
                "No default spin for element '{}' in spins_dict; "
                "set at.info['spin'] explicitly.".format(symbol)
            ) from None

    print("Not a single atom, or spin in at.info")
    info_spin = at.info.get('spin', None)
    # Accept NumPy integers too (an exact type comparison against int rejects
    # them). bool is excluded explicitly because it is a subclass of int.
    if isinstance(info_spin, (int, np.integer)) and not isinstance(info_spin, bool):
        print('Spin specified in atom info.')
        return int(info_spin)
    # `or ''` guards against an explicit name of None.
    if 'radical' in (at.info.get('name', '') or ''):
        print('Radical specified in atom.info["name"], assuming spin 1.')
        return 1
    if at.info.get('openshell', None):
        print("Openshell specified in atom info, attempting spin 2.")
        return 2
    print("No specifications in atom info to help, assuming no spin.")
    return 0



In [2]:
# Develop a function for reading in networks, assuming a fixed directory structure:
# the checkpoint directory is named <net_type>_<depth>_<nhidden>_<level>.
# NOTE(review): absolute, user-specific path — consider a configurable base directory.
p = '/home/awills/Documents/Research/xcquinox_pt/pbe0/c_3_16_mgga'
# NOTE(review): `ninput` and `use` are defined here, but the call at the bottom
# of this cell passes a literal 4 and omits `use`.
ninput = 4
use = []

def loadnet_from_strucdir(path, ninput, use=None):
    """Build the xcquinox network described by a checkpoint directory name and load its weights.

    The directory that holds the checkpoint must be named
    ``<net_type>_<depth>_<nhidden>_<level>`` (for example ``c_3_16_mgga``), where
    ``net_type`` is ``'x'`` (built with ``xce.net.eX``) or ``'c'`` (built with
    ``xce.net.eC``) and ``level`` is one of ``'gga'``, ``'mgga'`` or ``'nl'``.

    :param path: Either a checkpoint file (its last path component contains
                 ``'.eqx'``) or a directory of checkpoints. For a directory, the
                 ``'.eqx'`` file with the largest integer after its final ``'.'``
                 is loaded. A trailing path separator is tolerated.
    :type path: str
    :param ninput: Forwarded to the network constructor as ``n_input``
    :type ninput: int
    :param use: Forwarded to the network constructor as ``use``. When empty or
                ``None`` a per-architecture default is used: ``[1]`` for GGA
                exchange, ``[2]`` for GGA correlation, ``[1, 2]`` for MGGA
                exchange and ``[]`` otherwise.
    :type use: list, optional
    :return: ``(network, level)`` where ``level`` is 2 (gga), 3 (mgga) or 4 (nl)
    :rtype: tuple
    :raises ValueError: The directory name does not follow the expected pattern,
                        or names an unknown network type or level
    :raises FileNotFoundError: ``path`` is a directory containing no ``'.eqx'`` file
    """
    levels = {'gga': 2, 'mgga': 3, 'nl': 4}
    # Defaults for `use`; any (level, net_type) pair not listed defaults to [].
    default_use = {
        ('gga', 'x'): [1],
        ('gga', 'c'): [2],
        ('mgga', 'x'): [1, 2],
    }

    # normpath drops a trailing separator so the last component is never ''.
    parent, leaf = os.path.split(os.path.normpath(path))
    is_file = '.eqx' in leaf
    sdir = os.path.basename(parent) if is_file else leaf

    # Validate the directory name before touching the filesystem or building
    # anything, so a bad name yields a clear error rather than an
    # UnboundLocalError further down.
    parts = sdir.split('_')
    if len(parts) != 4:
        raise ValueError(
            "Expected a directory named '<net_type>_<depth>_<nhidden>_<level>', "
            "got '{}'".format(sdir)
        )
    net_type, ndepth, nhidden, level = parts
    if net_type not in ('x', 'c'):
        raise ValueError("Unknown network type '{}' in '{}' (expected 'x' or 'c')".format(net_type, sdir))
    if level not in levels:
        raise ValueError(
            "Unknown level '{}' in '{}' (expected one of {})".format(level, sdir, sorted(levels))
        )

    if is_file:
        loadnet = path
    else:
        checkpoints = [i for i in os.listdir(path) if '.eqx' in i]
        if not checkpoints:
            raise FileNotFoundError("No '.eqx' checkpoint found in '{}'".format(path))
        # Checkpoints carry an integer suffix after the last '.'; take the largest.
        latest = sorted(checkpoints, key=lambda x: int(x.split('.')[-1]))[-1]
        loadnet = os.path.join(path, latest)

    use = use if use else list(default_use.get((level, net_type), []))
    common = dict(n_input=ninput, n_hidden=int(nhidden), use=use, depth=int(ndepth))
    if net_type == 'x':
        if level == 'gga':
            # GGA exchange is the only variant built without ueg_limit.
            thisnet = xce.net.eX(lob=1.804, **common)
        else:
            thisnet = xce.net.eX(ueg_limit=True, lob=1.174, **common)
    else:
        thisnet = xce.net.eC(ueg_limit=True, **common)

    thisnet = eqx.tree_deserialise_leaves(loadnet, thisnet)
    return thisnet, levels[level]

# Smoke test of the loader on the directory `p` defined above, with n_input = 4.
net, netlevel = loadnet_from_strucdir(p, 4)

In [3]:
# Choose which pre-trained exchange/correlation pair to load.
LEVL = 'NL'
# LEVL = 'MGGA'
if LEVL == 'MGGA':
    # Meta-GGA pair: exchange uses inputs 1 and 2, correlation uses its default `use`.
    px = '/home/awills/Documents/Research/xcquinox_pt/pbe0/x_3_16_mgga'
    pc = '/home/awills/Documents/Research/xcquinox_pt/pbe0/c_3_16_mgga'
    thisx, xl = loadnet_from_strucdir(px, ninput=2, use=[1, 2])
    thisc, cl = loadnet_from_strucdir(pc, ninput=4)
elif LEVL == 'NL':
    # Non-local pair: exchange uses inputs 1..14, correlation uses its default `use`.
    px = '/home/awills/Documents/Research/xcquinox_pt/pbe0/x_3_16_nl'
    pc = '/home/awills/Documents/Research/xcquinox_pt/pbe0/c_3_16_nl'
    thisx, xl = loadnet_from_strucdir(px, ninput=14, use=list(range(1, 15)))
    thisc, cl = loadnet_from_strucdir(pc, ninput=16)


In [4]:
# Combine the loaded exchange and correlation networks into one eXC object,
# using the level reported for the exchange network (`xl`).
xc = xce.xc.eXC(grid_models=[thisx, thisc], heg_mult=True, level= xl,
               verbose=True)



In [5]:
# PRETRAIN_LEVEL = 'MGGA'

# REFERENCE_XC = 'PBE0'

# N_HIDDEN = 16
# DEPTH = 3
# if PRETRAIN_LEVEL == 'GGA':
#     localx = xce.net.eX(n_input=1, n_hidden=N_HIDDEN, use=[1], depth=DEPTH, lob=1.804)
#     localc = xce.net.eC(n_input=3, n_hidden=N_HIDDEN, use=[2], depth=DEPTH, ueg_limit=True)
# elif PRETRAIN_LEVEL == 'MGGA':
#     localx = xce.net.eX(n_input=2, n_hidden=N_HIDDEN, use=[1, 2], depth=DEPTH, ueg_limit=True, lob=1.174)
#     localc = xce.net.eC(n_input=4, n_hidden=N_HIDDEN, depth=DEPTH, use=[], ueg_limit=True)
# elif PRETRAIN_LEVEL == 'NONLOCAL':
#     localx = xce.net.eX(n_input=18, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True, lob=1.174)
#     localc = xce.net.eC(n_input=16, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True)

# ptmgxp = '/home/awills/Documents/Research/xcquinox_pt/pbe0/x_3_16_mgga'
# ptmgcp = '/home/awills/Documents/Research/xcquinox_pt/pbe0/c_3_16_mgga'

# if PRETRAIN_LEVEL == 'MGGA':
#     try:
#         xcs = sorted([i for i in os.listdir(ptmgxp) if 'xc.eqx' in i], key=lambda x: int(x.split('.')[-1]))[-1]
#         localx = eqx.tree_deserialise_leaves(os.path.join(ptmgxp, xcs), localx)
#     except Exception as e:
#         print(e)
#         print('couldnt read in pt network to overwrite exchange')
#     try:
#         xcs = sorted([i for i in os.listdir(ptmgcp) if 'xc.eqx' in i], key=lambda x: int(x.split('.')[-1]))[-1]
#         localc = eqx.tree_deserialise_leaves(os.path.join(ptmgcp, xcs), localc)
#     except Exception as e:
#         print(e)
#         print('couldnt read in pt network to overwrite correlation')


# xc = xce.xc.eXC(grid_models=[localx, localc], heg_mult=True, level= {'GGA':2, 'MGGA':3, 'NONLOCAL':4}[PRETRAIN_LEVEL],
#                verbose=True)



In [3]:
# Reference structures: prefer the 'Research2' copy, fall back to 'Research'.
try:
    trainms = read('/home/awills/Documents/Research2/torch_dpy/subset09_nf/subat_ref_corrected.traj', ':')
except OSError:
    # Narrowed from a bare `except:` so only a missing/unreadable file triggers
    # the fallback; KeyboardInterrupt and genuine parsing errors now surface.
    trainms = read('/home/awills/Documents/Research/torch_dpy/subset09_nf/subat_ref_corrected.traj', ':')

# Per-structure accumulators; entry i of every list refers to the same structure.
mfs = []
mols = []
energies = []
dms = []
ao_evals = []
gws = []
eris = []
mo_occs = []
hcs = []
vs = []
ts = []
ss = []
hologaps = []
ogds = []
# Only the structure at index 6 is processed here.
for idx, at in enumerate(trainms[6:7]):
    name, mol = xce.utils.ase_atoms_to_mol(at, basis='def2tzvpd')
    mol.verbose=9
    mol.build()
    mols.append(mol)
    mf = dft.RKS(mol, xc='SCAN')
    # mf = scf.UHF(mol)
    mf.grids.level = 1
    # The convergence threshold must be set before kernel(); it was previously
    # assigned afterwards and therefore had no effect on this SCF run.
    mf.conv_tol = 1e-6
    e_tot = mf.kernel()
    mfs.append(mf)
    dm = mf.make_rdm1()
    # Atomic orbitals on the grid, with derivatives up to second order.
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=2))
    # XC energy attached to the effective potential of the finished calculation.
    energies.append(mf.get_veff().exc)
    dms.append(dm)
    ogds.append(dm.shape)
    ao_evals.append(ao_eval)
    gws.append(mf.grids.weights)
    ts.append(mol.intor('int1e_kin'))
    vs.append(mol.intor('int1e_nuc'))
    mo_occs.append(mf.mo_occ)
    hcs.append(mf.get_hcore())
    eris.append(mol.intor('int2e'))
    # Inverse of the Cholesky factor of the overlap matrix.
    ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp'))))
    # hologaps.append(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1])

System: uname_result(system='Linux', node='aegis', release='5.15.0-113-generic', version='#123~20.04.1-Ubuntu SMP Wed Jun 12 17:33:13 UTC 2024', machine='x86_64')  Threads 20
Python 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
numpy 1.26.4  scipy 1.11.4
Date: Thu Jul 18 14:25:03 2024
PySCF version 2.3.0
PySCF path  /home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscf

[CONFIG] ARGPARSE = False
[CONFIG] DEBUG = False
[CONFIG] MAX_MEMORY = 4000
[CONFIG] TMPDIR = .
[CONFIG] UNIT = angstrom
[CONFIG] VERBOSE = 3
[CONFIG] conf_file = /home/awills/.pyscf_conf.py
[CONFIG] pyscf_numpy_backend = jax
[CONFIG] pyscf_scipy_backend = jax
[CONFIG] pyscf_scipy_linalg_backend = pyscfad
[CONFIG] pyscfad = True
[CONFIG] pyscfad_ccsd_implicit_diff = True
[CONFIG] pyscfad_scf_implicit_diff = True
[INPUT] verbose = 9
[INPUT] max_memory = 4000 
[INPUT] num. atoms = 2
[INPUT] num. electrons = 18
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgro



Padding 0 grids
tot grids = 10600
Drop grids 460
    CPU time for setting up grids      2.68 sec, wall time      0.42 sec
MGGA ni.block_loop; input ao.shape=(10, 10140, 80), weight.shape=(10140,), coords.shape=(10140, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 241, in nr_rks
    exc, vxc = ni.eval_xc(xc_code, rho, ao, weight, coords, spin=0,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
nelec by numeric integration = 17.999761140470635
    CPU time for vxc      1.86 sec, wall time      0.47 sec
E1 = -339.6620071501802  Ecoul = 129.76614252592103  Exc = -20.645414851105684
init E= -199.658253950575
    CPU time for initialize scf      8.17 sec, wall time      1.27 sec
  HOMO = -0.391516470436875  LUMO = -0.19766535263376
  mo_energy =
[-24.57084021 -24.57081374  -1.3



MGGA ni.block_loop; input ao.shape=(10, 10140, 80), weight.shape=(10140,), coords.shape=(10140, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 241, in nr_rks
    exc, vxc = ni.eval_xc(xc_code, rho, ao, weight, coords, spin=0,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
nelec by numeric integration = 17.999981686981254
    CPU time for vxc      1.39 sec, wall time      0.12 sec
E1 = -339.5159612062326  Ecoul = 129.78774216390315  Exc = -20.696626263605793
cycle= 2 E= -199.541819781145  delta_E= -0.00431  |g|= 0.148  |ddm|= 0.149
    CPU time for cycle = 2      2.34 sec, wall time      0.22 sec
diis-norm(errvec)=0.341841
diis-c [-0.01169233  0.34598618  0.65401382]
  HOMO = -0.369391643895105  LUMO = -0.191908608430028
  mo_energy =
[-24.55462186 -24.55458958  -1.322050

In [23]:
# Isolated phosphorus atom at the origin, used as a single-atom test case below.
at = Atoms('P', [[0,0,0]])

In [28]:
# Same reference-data recipe as the training-set loop, applied to the single
# atom `at` with PBE. The results are appended to the same accumulator lists,
# so re-running this cell adds duplicate entries.
# NOTE(review): the recorded log for this cell reports 15 electrons for the atom
# but about 14 electrons from numerical integration of the density — check
# whether a restricted Kohn-Sham object is appropriate for this open-shell atom.
name, mol = xce.utils.ase_atoms_to_mol(at, basis='def2tzvpd')
mol.verbose=9
mol.build()
mols.append(mol)
mf = dft.RKS(mol, xc='PBE')
# mf = scf.UHF(mol)
mf.grids.level = 1
# The convergence threshold must be set before kernel(); it was previously
# assigned afterwards and therefore had no effect on this SCF run.
mf.conv_tol = 1e-6
e_tot = mf.kernel()
mfs.append(mf)
dm = mf.make_rdm1()
# Atomic orbitals on the grid, with derivatives up to second order.
ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=2))
energies.append(mf.get_veff().exc)
dms.append(dm)
ogds.append(dm.shape)
ao_evals.append(ao_eval)
gws.append(mf.grids.weights)
ts.append(mol.intor('int1e_kin'))
vs.append(mol.intor('int1e_nuc'))
mo_occs.append(mf.mo_occ)
hcs.append(mf.get_hcore())
eris.append(mol.intor('int2e'))
# Inverse of the Cholesky factor of the overlap matrix.
ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp'))))
# hologaps.append(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1])
# hologaps.append(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1])

System: uname_result(system='Linux', node='aegis', release='5.15.0-113-generic', version='#123~20.04.1-Ubuntu SMP Wed Jun 12 17:33:13 UTC 2024', machine='x86_64')  Threads 20
Python 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
numpy 1.26.4  scipy 1.11.4
Date: Thu Jul 18 14:34:39 2024
PySCF version 2.3.0
PySCF path  /home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscf

[CONFIG] ARGPARSE = False
[CONFIG] DEBUG = False
[CONFIG] MAX_MEMORY = 4000
[CONFIG] TMPDIR = .
[CONFIG] UNIT = angstrom
[CONFIG] VERBOSE = 3
[CONFIG] conf_file = /home/awills/.pyscf_conf.py
[CONFIG] pyscf_numpy_backend = jax
[CONFIG] pyscf_scipy_backend = jax
[CONFIG] pyscf_scipy_linalg_backend = pyscfad
[CONFIG] pyscfad = True
[CONFIG] pyscfad_ccsd_implicit_diff = True
[CONFIG] pyscfad_scf_implicit_diff = True
[INPUT] verbose = 9
[INPUT] max_memory = 4000 
[INPUT] num. atoms = 1
[INPUT] num. electrons = 15
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 1
[INPUT] symmetry False subgro



  mo_energy =
[-7.62613406e+01 -6.30715840e+00 -4.51860853e+00 -4.51860853e+00
 -4.51860853e+00 -4.96851379e-01 -1.91347166e-01 -1.91347166e-01
 -1.91347166e-01  2.33360687e-02  1.01268069e-01  1.01268069e-01
  1.01268069e-01  1.01268069e-01  1.01268069e-01  2.16196942e-01
  2.16196942e-01  2.16196942e-01  3.75998130e-01  3.75998130e-01
  3.75998130e-01  3.75998130e-01  3.75998130e-01  4.94765028e-01
  1.26942131e+00  1.26942131e+00  1.26942131e+00  1.26942131e+00
  1.26942131e+00  1.26942131e+00  1.26942131e+00  1.28149071e+00
  1.28149071e+00  1.28149071e+00  1.28149071e+00  1.28149071e+00
  1.40738761e+00  1.40738761e+00  1.40738761e+00  7.04774086e+00
  7.04774086e+00  7.04774086e+00  1.29206178e+01]
GGA ni.block_loop; input ao.shape=(4, 6146, 43), weight.shape=(6146,), coords.shape=(6146, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/d



E1 = -466.04809952546  Ecoul = 148.12639546302273  Exc = -22.725350361263846
cycle= 3 E= -340.647054423701  delta_E= -0.00801  |g|= 0.162  |ddm|= 1.68
    CPU time for cycle = 3      1.45 sec, wall time      0.10 sec
diis-norm(errvec)=0.358731
diis-c [-0.02985659  0.3702075   0.29218866  0.33760384]

WARN: HOMO -0.527310919292145 == LUMO -0.526675808070449

  mo_energy =
[-7.67927042e+01 -6.77521767e+00 -4.99322917e+00 -4.99246573e+00
 -4.99088290e+00 -8.58272323e-01 -5.27310919e-01 -5.26675808e-01
 -5.25432465e-01 -1.77925870e-01 -1.43069857e-01 -1.42343837e-01
 -1.42193907e-01 -1.38337768e-01 -1.37221378e-01 -3.11913967e-02
 -3.06838182e-02 -3.01636439e-02  1.02797895e-01  1.03306973e-01
  1.03425104e-01  1.07593479e-01  1.08692231e-01  2.25028727e-01
  9.28574736e-01  9.32056785e-01  9.33359242e-01  9.33944644e-01
  9.33984969e-01  9.34405470e-01  9.34697788e-01  9.36100790e-01
  9.37107716e-01  9.38045983e-01  9.38090821e-01  9.38340768e-01
  1.06306786e+00  1.06364086e+00  1.06486

In [32]:
# XC energy carried by the effective potential; get_veff() re-evaluates the
# potential for the current `mf`, as the log output below shows.
mf.get_veff().exc

GGA ni.block_loop; input ao.shape=(4, 6146, 43), weight.shape=(6146,), coords.shape=(6146, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 163, in nr_rks
    exc, vxc = ni.eval_xc(xc_code, rho, ao, weight, coords, spin=0,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
nelec by numeric integration = 13.99999987214007
    CPU time for vxc      0.99 sec, wall time      0.07 sec


Array(-22.69562117, dtype=float64)

In [31]:
# NOTE(review): as recorded, this call fails with a TypeError because eval_xc
# also requires `rho`. Leftover exploration — supply the missing arguments or
# remove the cell so Restart & Run All does not stop here.
mf._numint.eval_xc('PBE')

TypeError: NumInt.eval_xc() missing 1 required positional argument: 'rho'

In [30]:
# Total, electronic and nuclear energies of the current `mf`; energy_elec()
# returns a pair, as the output below shows.
mf.energy_tot(), mf.energy_elec(), mf.energy_nuc()

GGA ni.block_loop; input ao.shape=(4, 6146, 43), weight.shape=(6146,), coords.shape=(6146, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 163, in nr_rks
    exc, vxc = ni.eval_xc(xc_code, rho, ao, weight, coords, spin=0,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
nelec by numeric integration = 13.99999987214007
    CPU time for vxc      1.21 sec, wall time      0.08 sec
E1 = -465.75623783621006  Ecoul = 147.8013049616361  Exc = -22.6956211743425
GGA ni.block_loop; input ao.shape=(4, 6146, 43), weight.shape=(6146,), coords.shape=(6146, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 163, in nr_rks
    ex

(Array(-340.65055405, dtype=float64),
 (Array(-340.65055405, dtype=float64), Array(125.10568379, dtype=float64)),
 0.0)

In [7]:
# IPython help lookup (`?`) for define_xc_ — interactive introspection only;
# remove before sharing the notebook, since this is not plain Python syntax.
mf.define_xc_?

[0;31mSignature:[0m [0mmf[0m[0;34m.[0m[0mdefine_xc_[0m[0;34m([0m[0mdescription[0m[0;34m,[0m [0mxctype[0m[0;34m=[0m[0;34m'LDA'[0m[0;34m,[0m [0mhyb[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m [0mrsh[0m[0;34m=[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0;36m0[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscf/dft/rks.py
[0;31mType:[0m      method

In [8]:
# Check whether the loaded exchange network has spin scaling enabled.
thisx.spin_scaling

False

In [9]:
# Orbital energies of whichever `mf` is currently bound in the kernel.
# NOTE(review): `mf` is reassigned in several cells, so this output depends on
# execution order.
mf.mo_energy

Array([-24.54351844, -24.54348579,  -1.31780039,  -1.07049372,
        -0.57987558,  -0.50051198,  -0.50051198,  -0.36482604,
        -0.36482604,  -0.18662279,   0.08411717,   0.09380579,
         0.11310517,   0.11310517,   0.16355414,   0.16820664,
         0.16820664,   0.32756182,   0.43567852,   0.43575939,
         0.56914708,   0.57289653,   0.58949981,   0.58949981,
         0.62541633,   0.62541633,   0.64091874,   0.64101837,
         0.6731772 ,   0.6731772 ,   0.84139823,   0.99700564,
         1.05762003,   1.05762003,   1.11653403,   1.50989831,
         1.88078377,   1.88078377,   2.11780274,   2.11816264,
         2.13207825,   2.283855  ,   2.28426641,   2.55257369,
         2.55257369,   3.01741996,   3.16179668,   3.16179668,
         3.24140839,   3.51073557,   3.51073557,   3.77408443,
         4.05091309,   4.93682678,   6.48795933,   6.48795933,
         6.6683857 ,   6.73563041,   6.73638843,   6.91056913,
         6.91056913,   6.94384751,   6.94384751,   7.07

In [10]:
# Evaluate the combined network on the first stored density matrix, AO values
# and grid weights.
# NOTE(review): `mf` here is whichever object was assigned last, not necessarily
# the one that produced dms[0] — confirm they refer to the same calculation.
xc(dms[0], ao_evals[0], gws[0], mf=mf, coor=mf.grids.coords)

custom gw and coor present in eval_grid_models; shapes: gw=(10140,), coor=(10140, 3)
eval_grid_models initial nan summary:
zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab
0, 0, 0, 0, 0, 0, 0
l_1, descr shape: (10140,)
NaNs in descr from self.l_1 = 0
self.level > 0; descr1 Nans = 0
self.level > 0; descr2 Nans = 0
get_descriptors -> self.level > 0
descr1.shape=(10140,), descr2.shape=(10140,), descr.shape=(10140, 2)
l_2, descr shape: (10140,)
NaNs in descr from self.l_2 = 0
self.level > 1; descr3 Nans = 0
get_descriptors -> self.level > 1
descr3.shape=(10140, 1), descr.shape=(10140, 3)
l_3, descr shape: (10140,)
NaNs in descr from self.l_3 = 0
self.level > 2; pre-log descr4 Nans = 0
descr4.min/max: 0.0024562108654236467, 105.93784961211075
self.level > 2; descr4 Nans = 0
get_descriptors -> self.level > 2
descr4.shape=(10140, 1), descr.shape=(10140, 4)
Constructing non-local CIDER descriptor generator
Sending mf=RKS-KohnShamDFT object of <class 'pyscfad.dft.rks.RKS'> to RKSAnalyzer
mf.e_tot=-19



nl_4, descr5 shape: (12, 10140)
NaNs in descr from self.l_1 = 0
get_descriptors -> self.level > 3 -> returned descr5 shape=(10140, 12)
get_descriptors not_spin_scaling -> self.level > 3 -> descr5.shape=(10140, 12)
get_descriptors, not spin_scaling -> descr.shape=(10140, 16)
l_1, descr shape: (10140,)
NaNs in descr from self.l_1 = 0
l_1, descr shape: (10140,)
NaNs in descr from self.l_1 = 0
self.level > 0; descr1 Nans = 0
self.level > 0; descr2 Nans = 0
get_descriptors -> self.level > 0
descr1.shape=(10140,), descr2.shape=(10140,), descr.shape=(10140, 2)
l_2, descr shape: (10140,)
NaNs in descr from self.l_2 = 0
l_2, descr shape: (10140,)
NaNs in descr from self.l_2 = 0
self.level > 1; descr3a Nans = 0
self.level > 1; descr3b Nans = 0
get_descriptors -> self.level > 1 and spin_scaling
descr3a.shape=(10140,), descr3b.shape=(10140,)
self.level > 1; descr3 Nans = 0
get_descriptors -> self.level > 1
descr3.shape=(10140, 2), descr.shape=(10140, 4)
l_3, descr shape: (10140,)
NaNs in descr fro

Array(-1.39127476, dtype=float64)

In [11]:
def generate_network_eval_xc(mf, dm, network):
    '''
    Generates a function to overwrite eval_xc with on the mf object, for use in training with pyscfad's SCF cycle

    :param mf: Pyscfad calculation kernel object
    :type mf: Pyscfad calculation kernel object
    :param dm: Initial density matrix to use in the cycle
    :type dm: jax.Array
    :param network: The network to use in evaluating the SCF cycle
    :type network: xcquinox.xc.eXC
    :return: A function
             `eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None)`
             that uses an xcquinox network as the pyscfad kernel calculation driver.
             Its inputs and return value are documented on the returned function itself.
    :rtype: function
    '''
    def eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
        '''
        The function to use as driver for a pyscf(ad) calculation, using an xcquinox network.

        The network's `eval_grid_models` is differentiated with `jax.value_and_grad`
        with respect to an (N, 11) input whose columns are, in order:

            rho0_a, rho0_b, gamma_a, gamma_ab, gamma_b,
            0 (dummy for laplacian), 0 (dummy for laplacian),
            tau_a, tau_b, non_loc_a, non_loc_b

        Side effects on every call: sets `mf.converged = True`, `mf.network = network`
        and `mf.network_eval` (a closure evaluating the network on this ao/gw/coords).

        :param xc_code: The XC functional code string in libxc format, but it is ignored as the network is the calculation driver
        :type xc_code: str
        :param rho: The [..., *, N] arrays (... for spin polarized), N is the number of grid points.
                    rho (*,N) ordered as (rho, grad_x, grad_y, grad_z, laplacian, tau)
                    rho (2,*,N) is [(rho_up, grad_x_up, grad_y_up, grad_z_up, laplacian_up, tau_up),
                                    (rho_down, grad_x_down, grad_y_down, grad_z_down, laplacian_down, tau_down)]
                    PySCFAD doesn't do spin-polarized grid calculations yet, so this will be unpolarized.
        :type rho: jax.Array
        :param ao: The atomic orbitals on the grid to use in the network calculation. Explicitly specified as the block loops break down the grid if memory is too low
        :type ao: jax.Array
        :param gw: The grid weights to use in the network calculation. Explicitly specified as the block loops break down the grid if memory is too low
        :type gw: jax.Array
        :param coords: The grid coordinates to use in the network calculation. Explicitly specified as the block loops break down the grid if memory is too low
        :type coords: jax.Array
        :param spin: The spin of the calculation, integer valued, polarized if non-zero, defaults to zero.
                     Not read here: polarization is inferred from the rank of `rho`.
        :type spin: int
        :param relativity: Integer, unused right now, defaults to zero
        :type relativity: int
        :param deriv: Unused here, defaults to 1
        :type deriv: int
        :param omega: Hybrid mixing term, unused here, defaults to None
        :type omega: float
        :param verbose: Unused here, defaults to None (debug printing is controlled by `network.verbose`)
        :type verbose: int
        :return: exc, vxc, fxc, kxc
                 where: exc -> XC energy density on the grid (first column of the network output)
                        vxc -> (vrho, vgamma, vlapl, vtau)
                        vrho = vs[:, 0]+vs[:, 1]
                        vtau = vs[:, 7]+vs[:, 8]
                        vgamma is an array of zeros shaped like vrho; vlapl=fxc=kxc=None.
        :rtype: tuple
        :raises RuntimeError: If the network returns NaNs in the energy density
        '''
        vgf = lambda x: network(x, ao, gw, mf=mf, coor=coords)
        # NOTE(review): the converged flag is forced on every call — confirm this is intended.
        mf.converged = True
        mf.network = network
        mf.network_eval = vgf

        if rho.ndim == 2:
            #not spin-polarized
            rho0 = rho[0] #density
            drho = rho[1:4] #grad_x, grad_y, grad_z
            # NOTE(review): taking the last row as tau assumes a 6-row (meta-GGA style) rho;
            # with a 4-row GGA rho this row would be grad_z — confirm against callers.
            tau = rho[-1]

            non_loc = jnp.zeros_like(tau)
            #decompose into spin channels: each channel carries half the density and half the
            #gradient, hence the 1/4 on the squared gradient
            rho0_a = rho0_b = rho0*0.5
            gamma_a = gamma_b = gamma_ab = jnp.einsum('ij,ij->j', drho, drho)*0.25
            tau_a = tau_b = tau*0.5
            non_loc_a = non_loc_b = non_loc*0.5
            if network.verbose:
                print(f'decomposed shapes:\nrho0={rho0.shape}\ndrho={drho.shape}\ntau={tau.shape}\nnon_loc={non_loc.shape}')
                print(f'decomposed shapes:\ngamma_a={gamma_a.shape}\ngamma_b={gamma_b.shape}\ngamma_ab={gamma_ab.shape}')
        else:
            #spin-polarized density
            rho0_a = rho[0, 0]
            rho0_b = rho[1, 0]

            drho_a = rho[0, 1:4]
            drho_b = rho[1, 1:4]
            #contracted density gradients
            gamma_a = jnp.einsum('ij,ij->j', drho_a, drho_a)
            gamma_b = jnp.einsum('ij,ij->j', drho_b, drho_b)
            gamma_ab = jnp.einsum('ij,ij->j', drho_a, drho_b)
            #kinetic energy density
            tau_a = rho[0, -1]
            tau_b = rho[1, -1]

            non_loc_a, non_loc_b = jnp.zeros_like(tau_a), jnp.zeros_like(tau_b)
            if network.verbose:
                print(f'decomposed shapes:\nrho0(a,b)={rho0_a.shape},{rho0_b.shape}\ndrho(a,b)={drho_a.shape},{drho_b.shape}\ntau(a,b)={tau_a.shape},{tau_b.shape}\nnon_loc(a,b)={non_loc_a.shape},{non_loc_b.shape}')
                print(f'decomposed shapes:\ngamma_a={gamma_a.shape}\ngamma_b={gamma_b.shape}\ngamma_ab={gamma_ab.shape}')

        def EXC_exc_vs(x):
            # NOTE(review): the density prefactor below is closed over rather than read from
            # `x`, so the gradient only flows through the network output `exc`, and the grid
            # weights are part of the differentiated sum — confirm this is the intended potential.
            exc = network.eval_grid_models(x, mf=mf, dm=dm, ao=ao, gw=gw, coor=coords)
            Exc = jnp.sum(((rho0_a + rho0_b)*exc[:,0])*gw)
            return Exc, exc

        if network.verbose:
            print(f'eval_xc -> Exc_exc and potentials on grid via autodiff')
        zeros = jnp.zeros_like(rho0_a)
        # Stacking 1-D columns along a new last axis gives the (N, 11) network input.
        v_and_g_inp = jnp.stack([rho0_a, rho0_b,
                                 gamma_a, gamma_ab, gamma_b,
                                 zeros, zeros, #Dummies for laplacian
                                 tau_a, tau_b,
                                 non_loc_a, non_loc_b], axis=-1)
        (Exc, exc), vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(v_and_g_inp)
        # Diagnostics are gated on network.verbose like every other print in this function.
        if network.verbose:
            print(f'v_and_g_inp.shape={v_and_g_inp.shape}')
            print(f'Exc_exc and vs returned: Exc = {Exc}, exc.shape={exc.shape}, vs.shape={vs.shape}')
            print(f'eval_xc Exc = {Exc}')

        nan_count = jnp.sum(jnp.isnan(exc[:, 0]))
        if nan_count:
            # A bare `raise` with no active exception only produced
            # "RuntimeError: No active exception to reraise"; keep the type, add the reason.
            raise RuntimeError('NaNs detected in exc. Number of NaNs: {}'.format(nan_count))
        exc = exc[:, 0]

        #vrho; d Exc/d rho, summed over the two spin channels
        vrho = vs[:, 0]+vs[:, 1]
        #vtau; d Exc/d tau, summed over the two spin channels
        vtau = vs[:, 7]+vs[:, 8]

        #gradient and laplacian contributions are not propagated
        vgamma = jnp.zeros_like(vrho)
        vlapl = None

        fxc = None #second order functional derivative
        kxc = None #third order functional derivative
        if network.verbose:
            print(f'shapes: vrho={vrho.shape}, vgamma={vgamma.shape}')
        return exc, (vrho, vgamma, vlapl, vtau), fxc, kxc
    return eval_xc

In [12]:
# Reference RKS/SCAN calculation on `mol`; the resulting object is reused later as
# the `mf` argument when building the network eval_xc closure.
mft0 = dft.RKS(mol, xc='scan')
mft0.conv_tol = 1e-5
# NOTE(review): a negative max_cycle presumably means no SCF iterations are run, so
# kernel() only evaluates the initial guess (the recorded run reports converged=False).
mft0.max_cycle = -1
mft0.kernel()

Set gradient conv threshold to 0.00316228
Initial guess from minao.
cond(S) = 5161.200779305
atom F rad-grids = 75, ang-grids = [ 50  50  50  50  50  50  50  50  50  50  50  50  50  50  50  50  50  50
  50  50  50  86  86  86  86  86  86  86  86  86 266 266 266 266 266 266
 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302
 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266
 266 266 266]
tot_boxes 891, boxes in each direction [7 7 9]
Padding 4 grids
tot grids = 28888
Drop grids 1518
    CPU time for setting up grids      2.12 sec, wall time      0.28 sec
MGGA ni.block_loop; input ao.shape=(10, 27370, 80), weight.shape=(27370,), coords.shape=(27370, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Falling back to regular form
nelec by numeric integration = 17.999777216793575
    CPU time for vxc      2.70 sec, wall time      0.52 sec
E1 = -339.6620071501802  Ecoul = 129.76614252592103  Exc = -20.65265878410222
init 

Array(-199.66549788, dtype=float64)

In [13]:
mft0.e_tot, mft0.mo_coeff, mft0.converged

(Array(-199.66549788, dtype=float64),
 array([[-3.01034362e-01,  3.01011481e-01, -7.33342717e-02, ...,
          1.99674824e-02,  9.83642390e-01,  9.88791034e-01],
        [-4.58450356e-01,  4.58632757e-01, -1.78440415e-01, ...,
          5.91387195e-02, -1.19329670e+00, -1.22461939e+00],
        [-2.46292374e-02,  2.29195466e-02,  2.56646222e-01, ...,
         -6.20853866e-01,  1.02679074e+00,  1.26755795e+00],
        ...,
        [ 2.29711832e-19,  2.13453100e-19, -5.98677519e-18, ...,
          3.33945578e-18, -1.27674769e-18,  2.08954857e-18],
        [ 3.61178927e-20,  3.12847826e-20, -1.65105268e-18, ...,
         -3.86958594e-18, -7.39755221e-19,  9.81024106e-19],
        [ 3.67528181e-20,  2.86591142e-20, -1.33083723e-18, ...,
          2.39262787e-17,  9.95705097e-19, -4.32206734e-19]]),
 False)

In [14]:
# RKS object whose XC evaluation is replaced by the network `xc`.
mft = dft.RKS(mol, xc='scan')
mft.conv_tol = 1e-5
# The eval_xc closure is built from the reference object mft0 and the first stored
# density matrix, not from `mft` itself.
evxc = generate_network_eval_xc(mf=mft0, dm=dms[0], network=xc)
# Register the closure on mft with functional type 'MGGA'.
mft.define_xc_(evxc, 'MGGA')
# vgf = lambda x: xc(x, ao_evals[0], gws[0], mf=mf, coor=mf.grids.coords)
# mft.network = xc
# mft.network_eval = vgf
mft.kernel()

Set gradient conv threshold to 0.00316228
Initial guess from minao.
cond(S) = 5161.200779305
atom F rad-grids = 75, ang-grids = [ 50  50  50  50  50  50  50  50  50  50  50  50  50  50  50  50  50  50
  50  50  50  86  86  86  86  86  86  86  86  86 266 266 266 266 266 266
 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302 302
 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266 266
 266 266 266]
tot_boxes 891, boxes in each direction [7 7 9]
Padding 4 grids
tot grids = 28888
Drop grids 1518




    CPU time for setting up grids      2.40 sec, wall time      0.20 sec
MGGA ni.block_loop; input ao.shape=(10, 27370, 80), weight.shape=(27370,), coords.shape=(27370, 3)
decomposed shapes:
rho0=(27370,)
drho=(3, 27370)
tau=(27370,)
non_loc=(27370,)
decomposed shapes:
gamma_a=(27370,)
gamma_b=(27370,)
gamma_ab=(27370,)
eval_xc -> Exc_exc and potentials on grid via autodiff
v_and_g_inp.shape=(27370, 11)
custom gw and coor present in eval_grid_models; shapes: gw=(27370,), coor=(27370, 3)
eval_grid_models initial nan summary:
zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab
0, 0, 0, 0, 0, 0, 0
l_1, descr shape: (27370,)
NaNs in descr from self.l_1 = 0
self.level > 0; descr1 Nans = 0
self.level > 0; descr2 Nans = 0
get_descriptors -> self.level > 0
descr1.shape=(27370,), descr2.shape=(27370,), descr.shape=(27370, 2)
l_2, descr shape: (27370,)
NaNs in descr from self.l_2 = 0
self.level > 1; descr3 Nans = 0
get_descriptors -> self.level > 1
descr3.shape=(27370, 1), descr.shape=(27370, 3)
l_3, des



nl_4, descr5 shape: (12, 27370)
NaNs in descr from self.l_1 = 0
get_descriptors -> self.level > 3 -> returned descr5 shape=(27370, 12)
get_descriptors not_spin_scaling -> self.level > 3 -> descr5.shape=(27370, 12)
get_descriptors, not spin_scaling -> descr.shape=(27370, 16)
l_1, descr shape: (27370,)
NaNs in descr from self.l_1 = 0
l_1, descr shape: (27370,)
NaNs in descr from self.l_1 = 0
self.level > 0; descr1 Nans = 0
self.level > 0; descr2 Nans = 0
get_descriptors -> self.level > 0
descr1.shape=(27370,), descr2.shape=(27370,), descr.shape=(27370, 2)
l_2, descr shape: (27370,)
NaNs in descr from self.l_2 = 0
l_2, descr shape: (27370,)
NaNs in descr from self.l_2 = 0
self.level > 1; descr3a Nans = 0
self.level > 1; descr3b Nans = 0
get_descriptors -> self.level > 1 and spin_scaling
descr3a.shape=(27370,), descr3b.shape=(27370,)
self.level > 1; descr3 Nans = 0
get_descriptors -> self.level > 1
descr3.shape=(27370, 2), descr.shape=(27370, 4)
l_3, descr shape: (27370,)
NaNs in descr fro



nl_4, descr5 shape: (12, 27370)
NaNs in descr from self.l_1 = 0
get_descriptors -> self.level > 3 -> returned descr5 shape=(27370, 12)
spin_scaling and self.level > 3, descr5.shape=(27370, 12)
decomposing descriptors into spin channels, half each
new descr5.shape=(2, 27370, 12)
get_descriptors -> self.level > 3, descr5.shape=(2, 27370, 12)
spin_scaling, get_descriptors -> reshaping -> descr.shape=(2, 27370, 15)
NaNs in descr_dict[False] = 0
NaNs in descr_dict[True] = 0
Grid models present; looping over separate networks to construct exc
eX.__call__, rho shape: (27370, 16)
exc.shape, descr with spin_scaling=False -> (27370,)
NaNs in exc from gm_eval_func, spin_scaling=False -> = 0
eval_grid_models gm_eval_func [0] nan summary:
exc_a, exc_b, exc_ab
0, 0, 0
eC.__call__, rho shape: (27370, 16)
exc.shape, descr with spin_scaling=False -> (27370,)
NaNs in exc from gm_eval_func, spin_scaling=False -> = 0
eval_grid_models gm_eval_func [1] nan summary:
exc_a, exc_b, exc_ab
0, 0, 0
eval_grid_mod



nl_4, descr5 shape: (12, 27370)
NaNs in descr from self.l_1 = 0
get_descriptors -> self.level > 3 -> returned descr5 shape=(27370, 12)
get_descriptors not_spin_scaling -> self.level > 3 -> descr5.shape=(27370, 12)
get_descriptors, not spin_scaling -> descr.shape=(27370, 16)
l_1, descr shape: (27370,)
NaNs in descr from self.l_1 = 0
l_1, descr shape: (27370,)
NaNs in descr from self.l_1 = 0
self.level > 0; descr1 Nans = 0
self.level > 0; descr2 Nans = 0
get_descriptors -> self.level > 0
descr1.shape=(27370,), descr2.shape=(27370,), descr.shape=(27370, 2)
l_2, descr shape: (27370,)
NaNs in descr from self.l_2 = 0
l_2, descr shape: (27370,)
NaNs in descr from self.l_2 = 0
self.level > 1; descr3a Nans = 0
self.level > 1; descr3b Nans = 0
get_descriptors -> self.level > 1 and spin_scaling
descr3a.shape=(27370,), descr3b.shape=(27370,)
self.level > 1; descr3 Nans = 0
get_descriptors -> self.level > 1
descr3.shape=(27370, 2), descr.shape=(27370, 4)
l_3, descr shape: (27370,)
NaNs in descr fro



MGGA ni.block_loop; input ao.shape=(10, 27370, 80), weight.shape=(27370,), coords.shape=(27370, 3)
decomposed shapes:
rho0=(27370,)
drho=(3, 27370)
tau=(27370,)
non_loc=(27370,)
decomposed shapes:
gamma_a=(27370,)
gamma_b=(27370,)
gamma_ab=(27370,)
eval_xc -> Exc_exc and potentials on grid via autodiff
v_and_g_inp.shape=(27370, 11)
custom gw and coor present in eval_grid_models; shapes: gw=(27370,), coor=(27370, 3)
eval_grid_models initial nan summary:
zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab
0, 0, 0, 0, 0, 0, 0
l_1, descr shape: (27370,)
NaNs in descr from self.l_1 = 0
self.level > 0; descr1 Nans = 0
self.level > 0; descr2 Nans = 0
get_descriptors -> self.level > 0
descr1.shape=(27370,), descr2.shape=(27370,), descr.shape=(27370, 2)
l_2, descr shape: (27370,)
NaNs in descr from self.l_2 = 0
self.level > 1; descr3 Nans = 0
get_descriptors -> self.level > 1
descr3.shape=(27370, 1), descr.shape=(27370, 3)
l_3, descr shape: (27370,)
NaNs in descr from self.l_3 = 0
self.level > 2; pre-lo

Array(-180.90369937, dtype=float64)

In [15]:
mft.e_tot

Array(-180.90369937, dtype=float64)

In [16]:
L = jnp.eye(dm.shape[-1])
scaling = jnp.ones(dm.shape[-1]*2)

In [17]:
mft.make_rdm1()

Array([[ 3.37903679e-01,  5.78295517e-01,  2.92187768e-04, ...,
        -1.67119131e-17,  2.33769370e-11,  5.36457920e-20],
       [ 5.78295517e-01,  1.00616626e+00, -5.59027357e-02, ...,
        -4.09067152e-17,  5.95167218e-11,  1.43747902e-19],
       [ 2.92187768e-04, -5.59027357e-02,  1.93373109e-01, ...,
         5.02722752e-17, -5.38358955e-11, -1.96854320e-19],
       ...,
       [-1.67119131e-17, -4.09067152e-17,  5.02722752e-17, ...,
         7.92784994e-05, -1.40191478e-18, -7.71184401e-09],
       [ 2.33769370e-11,  5.95167218e-11, -5.38358955e-11, ...,
        -1.40191478e-18,  2.62298532e-18,  2.60833699e-22],
       [ 5.36457920e-20,  1.43747902e-19, -1.96854320e-19, ...,
        -7.71184401e-09,  2.60833699e-22,  1.06099057e-12]],      dtype=float64)

In [18]:
dm.ndim

2

In [19]:
mft.scf_summary

{'e1': Array(-328.91219064, dtype=float64),
 'coul': Array(118.41701374, dtype=float64),
 'exc': Array(-1.291548, dtype=float64),
 'nuc': Array(30.88302552, dtype=float64)}

In [20]:
mft2 = scf.UHF(mol)
mft2.kernel()

Set gradient conv threshold to 3.16228e-05
cond(S) = 5161.200779305




E1 = -339.9987227720976  Ecoul = 110.41980412185697
init E= -198.69589312545
    CPU time for initialize scf      4.19 sec, wall time      0.40 sec
  alpha nocc = 9  HOMO = -0.498746994186616  LUMO = -0.191321862432996
  beta  nocc = 9  HOMO = -0.499904457890085  LUMO = -0.194315603453644
  alpha mo_energy =
[-26.4045579  -26.4037571   -1.88684238  -1.45655895  -0.81972716
  -0.7291923   -0.7291923   -0.49874699  -0.49874699  -0.19132186
   0.1151307    0.15379408   0.15680022   0.15680022   0.19876028
   0.19876028   0.2171668    0.39596016   0.54218201   0.54218201
   0.6553805    0.68289492   0.70150073   0.70150073   0.72971975
   0.72971975   0.73009382   0.73009382   0.84142105   0.84142105
   1.00246296   1.14341945   1.19331654   1.19331654   1.29132396
   1.68623583   2.18817925   2.18817925   2.39627741   2.39627741
   2.47620792   2.56774034   2.56774034   2.87334465   2.87334465
   3.31695278   3.59913212   3.59913212   3.70231898   3.95389954
   3.95389954   4.15124117   4

Array(-198.76770027, dtype=float64)

In [21]:
raise  # NOTE(review): bare raise with no active exception -> RuntimeError; presumably a deliberate stop so "Run All" halts here

RuntimeError: No active exception to reraise

In [None]:
mft.get_veff??

In [None]:
from pyscf import dft as dft_pyscf

In [None]:
# Same network-driven calculation, but driven through plain pyscf's UKS.
mf3 = dft_pyscf.UKS(mol, xc='SCAN')
dminp = mf3.get_init_guess()
# Closure built specifically for mf3 and its initial-guess density matrix.
evxc2 = generate_network_eval_xc(mf=mf3, dm=dminp, network=xc)
mf3.grids.level =3
# Register the closure built above (evxc2). The previous code passed `evxc`, a stale
# closure left over from an earlier cell that was built for a different mf/dm.
mf3.define_xc_(evxc2, 'MGGA')
mf3.kernel()

In [None]:
vgf = lambda x: xc(x, ao_evals[0], gws[0], mf=mf, coor=mf.grids.coords)
mft.network = xc
mft.network_eval1 = vgf


In [None]:
class E_PySCFAD_loss(eqx.Module):
    '''
    Energy loss that runs a full ``mf.kernel()`` calculation with the network as XC driver.
    '''

    def __init__(self):
        '''
        Stateless loss module: RMSE between the predicted and the reference energy.
        '''
        super().__init__()

    def __call__(self, model, mf, inp_dm, ao, gw, ref_en):
        '''
        Installs ``model`` as the XC functional of ``mf``, runs the calculation and returns the RMSE
        between the resulting energy and ``ref_en``.

        The loss is an RMSE, so the predicted energy may be a jax.Array of several values.

        Side effects on ``mf``: ``max_memory`` is set to 16000 and its XC functional is overwritten
        through ``define_xc_``.

        :param model: The XC object handed to ``generate_network_eval_xc`` as the network.
        :type model: xcquinox.xc.eXC
        :param mf: Calculation object that provides ``define_xc_`` and ``kernel`` (presumably a pyscfad mean-field object).
        :param inp_dm: The density matrix handed to ``generate_network_eval_xc``.
        :type inp_dm: jax.Array
        :param ao: Accepted for interface compatibility; not used in this method.
        :param gw: Accepted for interface compatibility; not used in this method.
        :param ref_en: The reference energy the loss is taken against.
        :type ref_en: jax.Array
        :return: The RMSE error.
        :rtype: jax.Array
        '''
        print('generating eval_xc function to overwrite')
        mf.max_memory=16000
        # Swap the functional of mf for a closure that evaluates the network.
        network_eval_xc = generate_network_eval_xc(mf=mf, dm=inp_dm, network=model)
        mf.define_xc_(network_eval_xc, xctype='MGGA')

        print('predicting energy...')
        e_pred = mf.kernel()
        print('energy predicted')

        squared_error = (e_pred-ref_en)**2
        return jnp.sqrt( jnp.mean(squared_error))


In [None]:
cpus = jax.devices(backend='cpu')

In [None]:
# Exponentially decaying learning-rate schedule. NOTE: it is only referenced by the
# commented-out optimizer below; the active optimizer uses a constant rate of 1e-2.
scheduler = optax.exponential_decay(init_value = 1e-2, transition_begin=50, transition_steps=500, decay_rate=0.9)
optimizer = optax.adam(learning_rate = 1e-2)
# optimizer = optax.adam(learning_rate = scheduler)

# Trainer for the network `xc` using the SCF-based energy loss, without jit.
trainer = xce.train.xcTrainer(model=xc, optim=optimizer, steps=500, loss = E_PySCFAD_loss(), do_jit=False, logfile='log')
# with jax.default_device(cpus[0]):
#     newm = trainer(1, trainer.model, mfs, dms, ao_evals, gws, [-109.52596483])
# NOTE(review): the trailing list is presumably the reference energy (Hartree) matching
# the single entry of mfs/dms/ao_evals/gws -- confirm which system it belongs to.
newm = trainer(1, trainer.model, mfs, dms, ao_evals, gws, [-109.52596483])


Modifications in mldftdat, pyscfad, pyscf, xcquinox. 

xcquinox modifications took place in package repo, so no need to find differences

mldftdat and pyscfad changes occurred in the package directory; need to find the differences for mldftdat to put into xcquinox-cider, and to create a patch for pyscfad

In dft.rks; np.isnan check at line 57, import jax

In dft.numint, anywhere the loop generates subset ao/grids to loop over eval_xc with, tagged with #XCQUINOX MODIFICATION

In [None]:
jax.clear_backends()

In [None]:
jax.clear_caches()

In [None]:
eqx.clear_caches()

In [None]:
pyscfad.jax.clear_backends()

In [None]:
pyscfad.jax.clear_caches()

In [None]:
pyscfad.jax.c

# TO DEBUG THE INTERFACE AND GET CORRECT ENERGY CALCULATIONS

In [27]:
# Root directory of the serialized test networks. The default is the original
# author's location; set the XCQUINOX_CTESTS_DIR environment variable to run the
# notebook on another machine without editing this cell.
CTESTS_DIR = os.environ.get(
    'XCQUINOX_CTESTS_DIR',
    '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests',
)
# Three GGA-level networks loaded from the `ran`, `pt/pbe` and `pt/pbe2` checkpoints
# (presumably a randomly initialised net and two PBE pre-trained nets).
ranxc = xce.xc.get_xcfunc("GGA",
                          os.path.join(CTESTS_DIR, 'ran', 'xc_3_16_c0_gga'),
                         )
pbexc = xce.xc.get_xcfunc("GGA",
                          os.path.join(CTESTS_DIR, 'pt', 'pbe', 'xc_3_16_c0_gga'),
                         )
pbe2xc = xce.xc.get_xcfunc("GGA",
                          os.path.join(CTESTS_DIR, 'pt', 'pbe2', 'xc_3_16_c0_gga'),
                         )


XNET spin scaling: True
CNET spin scaling: False
Deserializing XC Functional over created object
XNET spin scaling: True
CNET spin scaling: False
Deserializing XC Functional over created object
XNET spin scaling: True
CNET spin scaling: False
Deserializing XC Functional over created object


In [28]:
ranxc.grid_models[0].net.layers[0].bias

Array([-0.6721236 ,  0.48915458, -0.64653359,  0.75173173,  0.95819129,
        0.65768199,  0.22365171,  0.64063636, -0.69597844,  0.4891379 ,
       -0.76016807,  0.43088474,  0.84921562, -0.27464495, -0.75629146,
        0.94038789], dtype=float64)

In [29]:
# Per layer of the first two grid models, print the summed weight difference and the
# summed bias difference between the random network and the PBE pre-trained network.
# NOTE(review): signed sums can cancel, so a value near zero does not prove equality.
for net_idx in range(2):
    ran_layers = ranxc.grid_models[net_idx].net.layers
    for layer_idx in range(len(ran_layers)):
        ran_layer = ran_layers[layer_idx]
        pbe_layer = pbexc.grid_models[net_idx].net.layers[layer_idx]
        print((ran_layer.weight - pbe_layer.weight).sum(),
              (ran_layer.bias - pbe_layer.bias).sum())

7.76188234938763 3.0043397714992004
-2.46148213539284 5.135813680891678
1.810438811826577 0.8738129831622448
0.26918024619148895 -0.0005994378219750518
18.99120203392856 -1.97182129735826
22.862121459084474 1.8951738813824273
44.87940493475185 6.854707053829723
1.3044266291878372 -0.38016087490199657


In [30]:
# Per layer of the first two grid models, print the summed weight difference and the
# summed bias difference between the random network and the PBE2 pre-trained network.
# NOTE(review): signed sums can cancel, so a value near zero does not prove equality.
for net_idx in range(2):
    ran_layers = ranxc.grid_models[net_idx].net.layers
    for layer_idx in range(len(ran_layers)):
        ran_layer = ran_layers[layer_idx]
        pbe2_layer = pbe2xc.grid_models[net_idx].net.layers[layer_idx]
        print((ran_layer.weight - pbe2_layer.weight).sum(),
              (ran_layer.bias - pbe2_layer.bias).sum())

3.2378630093471106 2.727299372969207
13.840822231286513 3.78122848956414
-0.03434163288361536 0.9722549419416174
0.5662886490583133 -0.0789865127355194
4.140205053600012 4.015335521224063
13.375862615363364 2.969000927522298
-11.544697851192598 2.5725671903038654
-2.5590382342663123 0.35710970861775976


In [31]:
# Build a single phosphorus atom and the matching pyscfad molecule. If the build
# raises RuntimeError, the spin is shifted by one and the build is retried.
atoms = Atoms('P', [[0, 0, 0]])
pos = atoms.positions
spec = atoms.get_chemical_symbols()
basis = '6-311++G(3df,2pd)'
sping = get_spin(atoms)
molgen = False
scount = 0
initspin = sping
# Neutral system (charge=0): the spin (n_alpha - n_beta) cannot exceed the electron
# count, which bounds the "increase spin" retry branch so the loop always terminates.
max_spin = int(sum(atoms.get_atomic_numbers()))
mol_input = [[s,p] for s,p in zip(spec,pos)]
while not molgen:
    try:
        mol = gtoa.Mole(atom=mol_input, basis=basis, spin=sping-scount, charge=0)
        mol.build()
        molgen=True
    except RuntimeError as err:
        #spin disparity somehow, try with one less until 0
        if initspin > 0:
            print("RuntimeError. Trying with reduced spin.")
            scount += 1
        elif initspin == 0:
            print("RuntimeError. Trying with increased spin.")
            scount -= 1
        # Give up once the candidate spin leaves the physically possible range.
        if sping-scount < 0 or sping-scount > max_spin:
            raise ValueError(
                f'Could not build molecule for {spec}: no valid spin found '
                f'(initial spin {initspin}, last candidate {sping-scount}).'
            ) from err
print('S: ', mol.spin)
print(f'generated pyscfad mol: {type(mol), mol}')
# NOTE(review): a restricted (RKS) driver is used for an open-shell atom; the recorded
# run warns "Invalid number of electrons 15 for RHF method" -- consider UKS here.
mf = dfta.RKS(mol)
method = dfta.RKS
init_dm = mf.get_init_guess()

GET SPIN: Atoms Info
Atoms(symbols='P', pbc=False)
{}
Single atom and no spin specified in at.info
S:  3
generated pyscfad mol: (<class 'pyscfad.gto.mole.Mole'>, <pyscfad.gto.mole.Mole object at 0x7f5a9b470fa0>)


Random network, then PBE network (evidently the same, which is an error), then PBE2 network

In [32]:
# Evaluate the atom with the random network (ranxc) as XC driver, falling back to a
# UHF-based run if the RKS-based calculation fails.
print('Running short calculation to get ingredients for potential non-local network run...')
mf = dfta.RKS(mol)
result = Atoms(atoms)
ATOMGRID = 3
# Local stand-in for the former `kwargs.get('above0mo', False)` lookup: `kwargs` is not
# defined at notebook scope, so a non-negative energy used to raise NameError, which the
# broad `except` below silently turned into the UHF fallback.
above0mo = False
mf0 = method(mol)
# NOTE(review): a negative max_cycle presumably skips the SCF iterations -- confirm.
mf0.max_cycle = -1
mf0.conv_tol = 1e-5
mf0.kernel()
print('Starting kernel calculation complete.')
# evxc = xce.pyscf.generate_network_eval_xc(mf0, init_dm, kwargs['custom_xc_net'])
evxc = xce.pyscf.generate_network_eval_xc(mf, init_dm, ranxc)
mf.grids.level = ATOMGRID if ATOMGRID else 3
mf.max_cycle = -1
mf.max_memory = 64000
print("Running calculation")
mf.define_xc_(evxc, 'MGGA')
# Reuse the orbitals of the starting calculation.
mf.mo_coeff = mf0.mo_coeff
try:
    mf.kernel()
    if mf.e_tot >= 0 and above0mo:
        print('non-negative total energy; trying to rewrite with mo_energy of homo')
        print(f'mf.e_tot = {mf.e_tot}')
        print(f'mf.mo_occ = {mf.mo_occ}\nmf.mo_energy={mf.mo_energy}')
        # Highest index with non-zero occupation = HOMO.
        homo_i = jnp.max(jnp.nonzero(mf.mo_occ, size=init_dm.shape[0])[0])
        homo_e = mf.mo_energy[homo_i]
        print(f'homo_e = {homo_e}')
        mf.e_tot = homo_e
    elif mf.e_tot >= 0:
        print(f'NON-NEGATIVE ENERGY DETECTED.\n{str(atoms.symbols), mol, mf}\nENERGY={mf.e_tot}')
        # Explicit exception: a bare `raise` has no active exception to re-raise here.
        raise RuntimeError(f'Non-negative total energy: {mf.e_tot}')
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf.e_tot}
except Exception as e:
    print(e)
    print('Kernel calculation failed, perhaps hydrogen is acting up or there is another issue')
    print('Trying with UHF')
    vgf = lambda x: ranxc(x, mf._numint.eval_ao(mol, mf.grids.coords, deriv=2), mf.grids.weights, mf=mf, coor=mf.grids.coords)
    mf2 = scfa.UHF(mol)
    mf2.max_cycle = 50
    mf2.max_memory = 64000
    # NOTE(review): the message says network_eval1 but the attribute set is network_eval.
    print('Setting network and network_eval1')
    mf2.network = ranxc
    mf2.network_eval = vgf
    print('Running UHF calculation')
    mf2.kernel()
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf2.e_tot}


Running short calculation to get ingredients for potential non-local network run...

WARN: Invalid number of electrons 15 for RHF method.

LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected keyword argument 'ao'

Falling back to regular form
LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected k

In [33]:
mf.get_veff().exc, result.calc.results

MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shape=(18806, 11)
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
Exc_exc and vs returned: Exc = -19.456102170415644, exc.shape=(18806, 1), vs.shape=(18806, 11)
eval_xc Exc = -19.456102170415644


(Array(-19.45610217, dtype=float64),
 {'energy': Array(-340.11157985, dtype=float64)})

In [34]:
# Evaluate the atom with the PBE pre-trained network (pbexc) as XC driver, falling back
# to a UHF-based run if the RKS-based calculation fails.
print('Running short calculation to get ingredients for potential non-local network run...')
mf = dfta.RKS(mol)
result = Atoms(atoms)
ATOMGRID = 3
# Local stand-in for the former `kwargs.get('above0mo', False)` lookup: `kwargs` is not
# defined at notebook scope, so a non-negative energy used to raise NameError, which the
# broad `except` below silently turned into the UHF fallback.
above0mo = False
mf0 = method(mol)
# NOTE(review): a negative max_cycle presumably skips the SCF iterations -- confirm.
mf0.max_cycle = -1
mf0.conv_tol = 1e-5
mf0.kernel()
print('Starting kernel calculation complete.')
# evxc = xce.pyscf.generate_network_eval_xc(mf0, init_dm, kwargs['custom_xc_net'])
evxc = xce.pyscf.generate_network_eval_xc(mf, init_dm, pbexc)
mf.grids.level = ATOMGRID if ATOMGRID else 3
mf.max_cycle = -1
mf.max_memory = 64000
print("Running calculation")
mf.define_xc_(evxc, 'MGGA')
# Reuse the orbitals of the starting calculation.
mf.mo_coeff = mf0.mo_coeff
try:
    mf.kernel()
    if mf.e_tot >= 0 and above0mo:
        print('non-negative total energy; trying to rewrite with mo_energy of homo')
        print(f'mf.e_tot = {mf.e_tot}')
        print(f'mf.mo_occ = {mf.mo_occ}\nmf.mo_energy={mf.mo_energy}')
        # Highest index with non-zero occupation = HOMO.
        homo_i = jnp.max(jnp.nonzero(mf.mo_occ, size=init_dm.shape[0])[0])
        homo_e = mf.mo_energy[homo_i]
        print(f'homo_e = {homo_e}')
        mf.e_tot = homo_e
    elif mf.e_tot >= 0:
        print(f'NON-NEGATIVE ENERGY DETECTED.\n{str(atoms.symbols), mol, mf}\nENERGY={mf.e_tot}')
        # Explicit exception: a bare `raise` has no active exception to re-raise here.
        raise RuntimeError(f'Non-negative total energy: {mf.e_tot}')
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf.e_tot}
except Exception as e:
    print(e)
    print('Kernel calculation failed, perhaps hydrogen is acting up or there is another issue')
    print('Trying with UHF')
    vgf = lambda x: pbexc(x, mf._numint.eval_ao(mol, mf.grids.coords, deriv=2), mf.grids.weights, mf=mf, coor=mf.grids.coords)
    mf2 = scfa.UHF(mol)
    mf2.max_cycle = 50
    mf2.max_memory = 64000
    # NOTE(review): the message says network_eval1 but the attribute set is network_eval.
    print('Setting network and network_eval1')
    mf2.network = pbexc
    mf2.network_eval = vgf
    print('Running UHF calculation')
    mf2.kernel()
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf2.e_tot}


Running short calculation to get ingredients for potential non-local network run...

WARN: Invalid number of electrons 15 for RHF method.

LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected keyword argument 'ao'

Falling back to regular form
LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected k

In [35]:
mf.get_veff().exc, result.calc.results

MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shape=(18806, 11)
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
Exc_exc and vs returned: Exc = -20.32673950035, exc.shape=(18806, 1), vs.shape=(18806, 11)
eval_xc Exc = -20.32673950035


(Array(-20.3267395, dtype=float64),
 {'energy': Array(-341.01018754, dtype=float64)})

In [36]:
# Evaluate the atom with the PBE2 pre-trained network (pbe2xc) as XC driver, falling
# back to a UHF-based run if the RKS-based calculation fails.
print('Running short calculation to get ingredients for potential non-local network run...')
mf = dfta.RKS(mol)
result = Atoms(atoms)
ATOMGRID = 3
# Local stand-in for the former `kwargs.get('above0mo', False)` lookup: `kwargs` is not
# defined at notebook scope, so a non-negative energy used to raise NameError, which the
# broad `except` below silently turned into the UHF fallback.
above0mo = False
mf0 = method(mol)
# NOTE(review): a negative max_cycle presumably skips the SCF iterations -- confirm.
mf0.max_cycle = -1
mf0.conv_tol = 1e-5
mf0.kernel()
print('Starting kernel calculation complete.')
# evxc = xce.pyscf.generate_network_eval_xc(mf0, init_dm, kwargs['custom_xc_net'])
evxc = xce.pyscf.generate_network_eval_xc(mf, init_dm, pbe2xc)
mf.grids.level = ATOMGRID if ATOMGRID else 3
mf.max_cycle = -1
mf.max_memory = 64000
print("Running calculation")
mf.define_xc_(evxc, 'MGGA')
# Reuse the orbitals of the starting calculation.
mf.mo_coeff = mf0.mo_coeff
try:
    mf.kernel()
    if mf.e_tot >= 0 and above0mo:
        print('non-negative total energy; trying to rewrite with mo_energy of homo')
        print(f'mf.e_tot = {mf.e_tot}')
        print(f'mf.mo_occ = {mf.mo_occ}\nmf.mo_energy={mf.mo_energy}')
        # Highest index with non-zero occupation = HOMO.
        homo_i = jnp.max(jnp.nonzero(mf.mo_occ, size=init_dm.shape[0])[0])
        homo_e = mf.mo_energy[homo_i]
        print(f'homo_e = {homo_e}')
        mf.e_tot = homo_e
    elif mf.e_tot >= 0:
        print(f'NON-NEGATIVE ENERGY DETECTED.\n{str(atoms.symbols), mol, mf}\nENERGY={mf.e_tot}')
        # Explicit exception: a bare `raise` has no active exception to re-raise here.
        raise RuntimeError(f'Non-negative total energy: {mf.e_tot}')
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf.e_tot}
except Exception as e:
    print(e)
    print('Kernel calculation failed, perhaps hydrogen is acting up or there is another issue')
    print('Trying with UHF')
    vgf = lambda x: pbe2xc(x, mf._numint.eval_ao(mol, mf.grids.coords, deriv=2), mf.grids.weights, mf=mf, coor=mf.grids.coords)
    mf2 = scfa.UHF(mol)
    mf2.max_cycle = 50
    mf2.max_memory = 64000
    # NOTE(review): the message says network_eval1 but the attribute set is network_eval.
    print('Setting network and network_eval1')
    mf2.network = pbe2xc
    mf2.network_eval = vgf
    print('Running UHF calculation')
    mf2.kernel()
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf2.e_tot}


Running short calculation to get ingredients for potential non-local network run...

WARN: Invalid number of electrons 15 for RHF method.

LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected keyword argument 'ao'

Falling back to regular form
LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected k

In [37]:
mf.get_veff().exc, result.calc.results

MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shape=(18806, 11)
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
Exc_exc and vs returned: Exc = -20.300628691743057, exc.shape=(18806, 1), vs.shape=(18806, 11)
eval_xc Exc = -20.300628691743057


(Array(-20.30062869, dtype=float64),
 {'energy': Array(-341.0190469, dtype=float64)})

In [38]:
# Reference run with the built-in 'PBE' functional (no network closure is installed),
# for comparison with the network-driven cells.
print('Running short calculation to get ingredients for potential non-local network run...')
mf = dfta.RKS(mol)
result = Atoms(atoms)
ATOMGRID = 3
# Local stand-in for the former `kwargs.get('above0mo', False)` lookup: `kwargs` is not
# defined at notebook scope, so a non-negative energy used to raise NameError, which the
# broad `except` below silently turned into the UHF fallback.
above0mo = False
mf0 = method(mol)
# NOTE(review): a negative max_cycle presumably skips the SCF iterations -- confirm.
mf0.max_cycle = -1
mf0.conv_tol = 1e-5
mf0.kernel()
print('Starting kernel calculation complete.')
mf.grids.level = ATOMGRID if ATOMGRID else 3
mf.max_cycle = 50
mf.max_memory = 64000
print("Running calculation")
mf.xc = 'PBE'
try:
    mf.kernel()
    if mf.e_tot >= 0 and above0mo:
        print('non-negative total energy; trying to rewrite with mo_energy of homo')
        print(f'mf.e_tot = {mf.e_tot}')
        print(f'mf.mo_occ = {mf.mo_occ}\nmf.mo_energy={mf.mo_energy}')
        # Highest index with non-zero occupation = HOMO.
        homo_i = jnp.max(jnp.nonzero(mf.mo_occ, size=init_dm.shape[0])[0])
        homo_e = mf.mo_energy[homo_i]
        print(f'homo_e = {homo_e}')
        mf.e_tot = homo_e
    elif mf.e_tot >= 0:
        print(f'NON-NEGATIVE ENERGY DETECTED.\n{str(atoms.symbols), mol, mf}\nENERGY={mf.e_tot}')
        # Explicit exception: a bare `raise` has no active exception to re-raise here.
        raise RuntimeError(f'Non-negative total energy: {mf.e_tot}')
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf.e_tot}
except Exception as e:
    print(e)
    print('Kernel calculation failed, perhaps hydrogen is acting up or there is another issue')
    print('Trying with UHF')
    # NOTE(review): this fallback drives UHF with the pbe2xc network although the cell is
    # a plain-PBE reference run -- likely a copy-paste leftover; confirm the intent.
    vgf = lambda x: pbe2xc(x, mf._numint.eval_ao(mol, mf.grids.coords, deriv=2), mf.grids.weights, mf=mf, coor=mf.grids.coords)
    mf2 = scfa.UHF(mol)
    mf2.max_cycle = 50
    mf2.max_memory = 64000
    # NOTE(review): the message says network_eval1 but the attribute set is network_eval.
    print('Setting network and network_eval1')
    mf2.network = pbe2xc
    mf2.network_eval = vgf
    print('Running UHF calculation')
    mf2.kernel()
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf2.e_tot}


Running short calculation to get ingredients for potential non-local network run...

WARN: Invalid number of electrons 15 for RHF method.

LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected keyword argument 'ao'

Falling back to regular form
LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected k



GGA ni.block_loop; input ao.shape=(4, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 169, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
GGA ni.block_loop; input ao.shape=(4, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 169, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form




GGA ni.block_loop; input ao.shape=(4, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 169, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
GGA ni.block_loop; input ao.shape=(4, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 169, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
GGA ni.block_loop; input ao.shape=(4, 18806, 47), weight.shape=(18806,

In [16]:
mf.get_veff().exc, result.calc.results

GGA ni.block_loop; input ao.shape=(4, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 169, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form


(Array(-22.69791352, dtype=float64),
 {'energy': Array(-340.65225412, dtype=float64)})

In [39]:
def generate_network_eval_xc2(mf, dm, network):
    '''
    Generates a function to overwrite eval_xc with on the mf object, for use in training with pyscfad's SCF cycle

    :param mf: Pyscfad calculation kernel object. It is forwarded to the network and mutated by the
               returned function (see the side effects listed on ``eval_xc``).
    :type mf: Pyscfad calculation kernel object
    :param dm: Initial density matrix to use in the cycle; forwarded to ``network.eval_grid_models``
    :type dm: jax.Array
    :param network: The network to use in evaluating the SCF cycle. The returned function reads
                    ``network.verbose``, calls ``network.eval_grid_models(...)`` and calls ``network(...)``.
    :type network: xcquinox.xc.eXC
    :return: A function ``eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None)``
             that uses an xcquinox network as the pyscfad kernel calculation driver.
             Its full contract is documented on the returned function itself.
    :rtype: function
    '''
    def eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
        '''
        The function to use as driver for a pyscf(ad) calculation, using an xcquinox network.

        The density ingredients are split into spin channels and stacked column-wise into an
        (N, 11) array ordered as

            [rho0_a, rho0_b, gamma_a, gamma_ab, gamma_b, 0, 0, tau_a, tau_b, non_loc_a, non_loc_b]

        (the two zero columns are dummies for the laplacian, and the non-local columns are zeros).
        The network energy density is evaluated on that array and
        ``Exc = sum((rho0_a + rho0_b) * exc[:, 0] * gw)`` is differentiated with respect to it via
        ``jax.value_and_grad(..., has_aux=True)``, giving ``vs`` with shape (N, 11).

        Side effects on every call: sets ``mf.converged = True``, ``mf.network = network`` and
        ``mf.network_eval`` (a closure calling ``network(x, ao, gw, mf=mf, coor=coords)``).

        :param xc_code: The XC functional code string in libxc format, but it is ignored as the network is the calculation driver
        :type xc_code: str
        :param rho: The [..., *, N] arrays (... for spin polarized), N is the number of grid points.
                    A 2-D rho (*,N) is treated as unpolarized: row 0 is the density, rows 1:4 the
                    gradient components and the last row is tau.
                    Any other rank is indexed as rho (2,*,N), i.e. rho[0] for spin-up and rho[1] for
                    spin-down, each with the same row layout.
        :type rho: jax.Array
        :param ao: The atomic orbitals on the grid to use in the network calculation. Explicitly specified as the block loops break down the grid if memory is too low
        :type ao: jax.Array
        :param gw: The grid weights to use in the network calculation. Explicitly specified as the block loops break down the grid if memory is too low
        :type gw: jax.Array
        :param coords: The grid coordinates to use in the network calculation. Explicitly specified as the block loops break down the grid if memory is too low
        :type coords: jax.Array
        :param spin: Unused here (polarization is decided from the rank of rho), defaults to zero
        :type spin: int
        :param relativity: Integer, unused right now, defaults to zero
        :type relativity: int
        :param deriv: Unused here, defaults to 1
        :type deriv: int
        :param omega: Hybrid mixing term, unused here, defaults to None
        :type omega: float
        :param verbose: Unused here (``network.verbose`` controls the optional printing), defaults to None
        :type verbose: int
        :raises RuntimeError: If the network returns NaNs in ``exc[:, 0]``
        :return: exc, vxc, fxc, kxc
                 where: exc -> XC energy density on the grid, ``exc[:, 0]`` of the network output
                        vxc -> (vrho, vgamma, vlapl, vtau)
                        vrho = vs[:, 0]+vs[:, 1]
                        vgamma = vs[:, 2]+vs[:, 3]+vs[:, 4]
                        vtau = vs[:, 7]+vs[:, 8]
                        vlapl=fxc=kxc=None

        :rtype: tuple
        '''
        if rho.ndim == 2:
            # Not spin-polarized: split every ingredient evenly between the two channels.
            rho0 = rho[0]    # density
            drho = rho[1:4]  # grad_x, grad_y, grad_z
            tau = rho[-1]    # last row is taken as tau
            non_loc = jnp.zeros_like(tau)

            rho0_a = rho0_b = rho0*0.5
            # |grad rho_a|^2 = |grad rho_b|^2 = grad rho_a . grad rho_b = |grad rho|^2 / 4
            gamma_a = gamma_b = gamma_ab = jnp.einsum('ij,ij->j', drho, drho)*0.25
            tau_a = tau_b = tau*0.5
            non_loc_a = non_loc_b = non_loc*0.5
            if network.verbose:
                print(f'decomposed shapes:\nrho0={rho0.shape}\ndrho={drho.shape}\ntau={tau.shape}\nnon_loc={non_loc.shape}')
                print(f'decomposed shapes:\ngamma_a={gamma_a.shape}\ngamma_b={gamma_b.shape}\ngamma_ab={gamma_ab.shape}')
        else:
            # Spin-polarized density: leading axis selects the spin channel.
            rho0_a = rho[0, 0]
            rho0_b = rho[1, 0]

            drho_a = rho[0, 1:4]
            drho_b = rho[1, 1:4]
            # Contracted density gradients
            gamma_a = jnp.einsum('ij,ij->j', drho_a, drho_a)
            gamma_b = jnp.einsum('ij,ij->j', drho_b, drho_b)
            gamma_ab = jnp.einsum('ij,ij->j', drho_a, drho_b)
            # Kinetic energy density
            tau_a = rho[0, -1]
            tau_b = rho[1, -1]

            non_loc_a, non_loc_b = jnp.zeros_like(tau_a), jnp.zeros_like(tau_b)
            if network.verbose:
                print(f'decomposed shapes:\nrho0(a,b)={rho0_a.shape},{rho0_b.shape}\ndrho(a,b)={drho_a.shape},{drho_b.shape}\ntau(a,b)={tau_a.shape},{tau_b.shape}\nnon_loc(a,b)={non_loc_a.shape},{non_loc_b.shape}')
                print(f'decomposed shapes:\ngamma_a={gamma_a.shape}\ngamma_b={gamma_b.shape}\ngamma_ab={gamma_ab.shape}')

        def EXC_exc_vs(x):
            # Returns (Exc, exc): Exc is the scalar that gets differentiated, exc rides along as aux.
            exc = network.eval_grid_models(x, mf=mf, dm=dm, ao=ao, gw=gw, coor=coords)
            Exc = jnp.sum(((rho0_a + rho0_b)*exc[:,0])*gw)
            return Exc, exc

        if network.verbose:
            print(f'eval_xc -> Exc_exc and potentials on grid via autodiff')
        laplacian_dummy = jnp.zeros_like(rho0_a)
        # Stacking 1-D (N,) arrays along a new last axis gives the (N, 11) network input.
        v_and_g_inp = jnp.stack([rho0_a,
                                 rho0_b,
                                 gamma_a,
                                 gamma_ab,
                                 gamma_b,
                                 laplacian_dummy, #Dummy for laplacian
                                 laplacian_dummy, #Dummy for laplacian
                                 tau_a,
                                 tau_b,
                                 non_loc_a,
                                 non_loc_b], axis=-1)
        print(f'v_and_g_inp.shape={v_and_g_inp.shape}')
        Exc_exc, vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(v_and_g_inp)
        print(f'Exc_exc and vs returned: Exc = {Exc_exc[0]}, exc.shape={Exc_exc[1].shape}, vs.shape={vs.shape}')
        Exc, exc = Exc_exc
        print(f'eval_xc Exc = {Exc}')

        nan_count = jnp.sum(jnp.isnan(exc[:, 0]))
        if nan_count:
            # An explicit exception: a bare `raise` here has no active exception to re-raise.
            raise RuntimeError('NaNs detected in exc. Number of NaNs: {}'.format(nan_count))
        exc = exc[:, 0]

        # Expose the network on the mf object; note that this also flags mf as converged.
        vgf = lambda x: network(x, ao, gw, mf=mf, coor=coords)
        mf.converged = True
        mf.network = network
        mf.network_eval = vgf

        #vrho; d Exc/d rho, separate spin channels
        vrho = 1.0*(vs[:, 0]+vs[:, 1])
        #vtau; d Exc/d tau, separate spin channels
        vtau = 1.0*(vs[:, 7]+vs[:, 8])

        #vgamma; d Exc/d gamma
        vgamma = 1.0*(vs[:, 2] + vs[:, 3] + vs[:, 4])

        vlapl = None

        fxc = None #second order functional derivative
        kxc = None #third order functional derivative
        if network.verbose:
            print(f'shapes: vrho={vrho.shape}, vgamma={vgamma.shape}')
        return exc, (vrho, vgamma, vlapl, vtau), fxc, kxc
    return eval_xc

In [40]:
print('Running short calculation to get ingredients for potential non-local network run...')
mf = dfta.RKS(mol)
result = Atoms(atoms)
ATOMGRID = 3

# Preliminary calculation: only its mo_coeff is reused further down.
mf0 = method(mol)
mf0.max_cycle = -1
mf0.conv_tol = 1e-5
mf0.kernel()
print('Starting kernel calculation complete.')

# evxc = xce.pyscf.generate_network_eval_xc(mf0, init_dm, kwargs['custom_xc_net'])
evxc = generate_network_eval_xc2(mf, init_dm, pbexc)
mf.grids.level = ATOMGRID if ATOMGRID else 3
mf.max_cycle = -1
mf.max_memory = 64000
print("Running calculation")
# Install the network-backed eval_xc on mf, declared as a meta-GGA.
mf.define_xc_(evxc, 'MGGA')
mf.mo_coeff = mf0.mo_coeff
try:
    mf.kernel()
    # NOTE(review): `kwargs` is not defined in this cell and is only evaluated when
    # mf.e_tot >= 0 -- confirm it exists in the notebook namespace; if it does not, the
    # resulting NameError is caught below and silently triggers the UHF fallback.
    if mf.e_tot >= 0 and kwargs.get('above0mo', False):
        print('non-negative total energy; trying to rewrite with mo_energy of homo')
        print(f'mf.e_tot = {mf.e_tot}')
        print(f'mf.mo_occ = {mf.mo_occ}\nmf.mo_energy={mf.mo_energy}')
        # Highest index with non-zero occupation
        homo_i = jnp.max(jnp.nonzero(mf.mo_occ, size=init_dm.shape[0])[0])
        homo_e = mf.mo_energy[homo_i]
        print(f'homo_e = {homo_e}')
        mf.e_tot = homo_e
    elif mf.e_tot >= 0:
        print(f'NON-NEGATIVE ENERGY DETECTED.\n{str(atoms.symbols), mol, mf}\nENERGY={mf.e_tot}')
        # Explicit exception: a bare `raise` has no active exception to re-raise here.
        raise RuntimeError('Non-negative total energy from the network-driven RKS calculation')
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf.e_tot}
except Exception as e:
    # Any failure above (including the explicit raise) falls back to a plain UHF run.
    print(e)
    print('Kernel calculation failed, perhaps hydrogen is acting up or there is another issue')
    print('Trying with UHF')
    vgf = lambda x: pbexc(x, mf._numint.eval_ao(mol, mf.grids.coords, deriv=2), mf.grids.weights, mf=mf, coor=mf.grids.coords)
    mf2 = scfa.UHF(mol)
    mf2.max_cycle = 50
    mf2.max_memory = 64000
    print('Setting network and network_eval1')
    mf2.network = pbexc
    mf2.network_eval = vgf
    print('Running UHF calculation')
    mf2.kernel()
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy' : mf2.e_tot}


Running short calculation to get ingredients for potential non-local network run...

WARN: Invalid number of electrons 15 for RHF method.

LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected keyword argument 'ao'

Falling back to regular form
LDA ni.block_loop; input ao.shape=(224, 47), weight.shape=(224,), coords.shape=(224, 3)
Exception raised: NumInt.eval_xc() got an unexpected keyword argument 'ao'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code=xc_code,
TypeError: NumInt.eval_xc() got an unexpected k

In [41]:
# Show the exc attribute of a fresh mf.get_veff() call, the results stored on result.calc, and mf.e_tot.
(
    mf.get_veff().exc,
    result.calc.results,
    mf.e_tot,
)

MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shape=(18806, 11)
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
Exc_exc and vs returned: Exc = -20.33136164852754, exc.shape=(18806, 1), vs.shape=(18806, 11)
eval_xc Exc = -20.33136164852754


(Array(-20.33136165, dtype=float64),
 {'energy': Array(-341.01018754, dtype=float64)},
 Array(-341.01018754, dtype=float64))