In [7]:
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.units import Hartree
from ase.io import read, write
from time import time

#pyscfad imports
from pyscfad import dft as dfta
from pyscfad import gto as gtoa 
from pyscfad import scf as scfa
from pyscfad import cc as cca
from pyscfad.scf import hf as hfa
from pyscfad import pbc as pbca
from pyscfad.pbc import scf as scfpa
from pyscfad.pbc import gto as gtopa
from pyscfad.pbc import dft as dftpa
import equinox as eqx
import jax.numpy as jnp
import xcquinox as xce
import os, optax, jax
# Default spin for isolated atoms, expressed as the number of unpaired
# electrons (the value later handed to Mole(spin=...)). pyscf doesn't guess
# this correctly for a lone atom, so get_spin() looks it up here.
spins_dict = dict(
    Al=1,
    B=1,
    Li=1,
    Na=1,
    Si=2,
    Be=0,
    C=2,
    Cl=1,
    F=1,
    H=1,
    N=3,
    O=2,
    P=3,
    S=2,
    Ar=0,  # noble gas, closed shell
    Br=1,  # halogen: one unpaired electron
    Ne=0,  # noble gas, closed shell
    Sb=3,  # pnictogen, same column as N/P
    Bi=3,  # pnictogen, same column as N/P/Sb
    Te=2,  # chalcogen, same column as O/S
    I=1,   # halogen: one unpaired electron
)
def get_spin(at):
    """Return the spin (number of unpaired electrons) to use for an ASE Atoms object.

    Resolution order:
      1. Single atom with no ``'spin'`` key in ``at.info``: look up the element
         symbol in the module-level ``spins_dict``.
      2. ``at.info['spin']`` is an integer (Python or NumPy; bools are ignored):
         use it.
      3. ``'radical'`` appears in ``at.info['name']``: spin 1.
      4. ``at.info['openshell']`` is truthy: spin 2.
      5. Otherwise: spin 0.

    Diagnostic information is printed along the way.

    Parameters
    ----------
    at : ase.Atoms
        Structure to inspect; only ``positions``, ``symbols`` and ``info`` are read.

    Returns
    -------
    int
        The spin value.

    Raises
    ------
    KeyError
        If ``at`` is a single atom without ``at.info['spin']`` and its element
        has no entry in ``spins_dict``.
    """
    import numbers

    print('======================')
    print("GET SPIN: Atoms Info")
    print(at)
    print(at.info)
    print('======================')

    # Single atom and no spin specified in at.info: fall back to spins_dict.
    if len(at.positions) == 1 and 'spin' not in at.info:
        print("Single atom and no spin specified in at.info")
        symbol = str(at.symbols)
        if symbol not in spins_dict:
            raise KeyError(
                f"No default spin for element {symbol!r}; add it to spins_dict "
                f"or set at.info['spin']."
            )
        return spins_dict[symbol]

    print("Not a single atom, or spin in at.info")
    raw_spin = at.info.get('spin', None)
    # numbers.Integral also accepts NumPy integer scalars, which an exact
    # `type(...) == int` test would reject. bool subclasses int but is not a
    # meaningful spin, so it is excluded explicitly.
    if isinstance(raw_spin, numbers.Integral) and not isinstance(raw_spin, bool):
        print('Spin specified in atom info.')
        return int(raw_spin)

    # 'name' may be missing or explicitly None; normalise before the substring test.
    name = at.info.get('name') or ''
    if 'radical' in str(name):
        print('Radical specified in atom.info["name"], assuming spin 1.')
        return 1
    if at.info.get('openshell', None):
        print("Openshell specified in atom info, attempting spin 2.")
        return 2

    print("No specifications in atom info to help, assuming no spin.")
    return 0

In [40]:
# Single phosphorus atom at the origin. Its chemical symbols and positions are
# zipped into the pyscfad Mole atom list in a later cell.
atoms = Atoms(symbols='P', positions=[[0, 0, 0]])
spec = atoms.get_chemical_symbols()
pos = atoms.positions


In [48]:
# Serialized xcquinox GGA functional to evaluate. Set the XCQUINOX_XC_PATH
# environment variable to point at a different model instead of editing the
# notebook; the default is the original author's local checkout.
XC_PATH = os.environ.get(
    'XCQUINOX_XC_PATH',
    '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/pt/pbe2/xc_3_16_c0_gga',
)
xc = xce.xc.get_xcfunc('GGA', XC_PATH)

XNET spin scaling: True
CNET spin scaling: False
Deserializing XC Functional over created object


In [49]:
# Build the pyscfad Mole. If Mole.build raises RuntimeError (e.g. the requested
# spin is inconsistent with the electron count), step the spin away from the
# initial guess and retry: downwards toward 0 if the guess was positive,
# upwards if it was 0.
basis = '6-311++G(3df,2pd)'
sping = get_spin(atoms)
initspin = sping
mol_input = [[s, p] for s, p in zip(spec, pos)]
# Spin (unpaired electrons) can never exceed the electron count, which for
# charge=0 is the sum of atomic numbers. This bounds the "increase spin" branch,
# which otherwise loops forever if Mole.build keeps failing for a reason
# unrelated to spin.
max_spin = int(atoms.get_atomic_numbers().sum())
molgen = False
scount = 0
while not molgen:
    try:
        mol = gtoa.Mole(atom=mol_input, basis=basis, spin=sping - scount, charge=0)
        mol.build()
        molgen = True
    except RuntimeError as err:
        if initspin > 0:
            print("RuntimeError. Trying with reduced spin.")
            scount += 1
        elif initspin == 0:
            print("RuntimeError. Trying with increased spin.")
            scount -= 1
        trial_spin = sping - scount
        # Only the upward search needs the max_spin cap; the downward search
        # ends naturally when the spin would go negative.
        if trial_spin < 0 or (initspin == 0 and trial_spin > max_spin):
            raise ValueError(
                f'Could not build Mole (basis={basis!r}, initial spin={initspin}, '
                f'stopped before spin={trial_spin}); last error: {err}'
            ) from err
print('S: ', mol.spin)
print(f'generated pyscfad mol: {type(mol), mol}')


GET SPIN: Atoms Info
Atoms(symbols='P', pbc=False)
{}
Single atom and no spin specified in at.info
S:  3
generated pyscfad mol: (<class 'pyscfad.gto.mole.Mole'>, <pyscfad.gto.mole.Mole object at 0x7f05d6365300>)


In [50]:
# Restricted Kohn-Sham object for `mol`, plus its initial-guess density matrix.
# `method` is reused below to build a second, short-run SCF object.
# NOTE(review): for an open-shell atom (mol.spin != 0, e.g. S: 3 for P above)
# a restricted method is questionable; pyscf warns "Invalid number of electrons
# 15 for RHF method". An unrestricted method is presumably intended, but later
# cells index init_dm / mo_occ assuming the restricted layout, so confirm
# network support for spin-polarised densities before switching.
method = dfta.RKS
mf = method(mol)
init_dm = mf.get_init_guess()

In [51]:
# Re-bind `mol` to the molecule held by the SCF object from the previous cell,
# so the cells below use mf's molecule (presumably the same object that was
# passed to dfta.RKS -- confirm if pyscfad copies it).
mol = mf.mol

In [52]:
# Inspect the atom specification stored on the Mole (symbol and position pairs).
mol.atom

[['P', array([0., 0., 0.])]]

In [53]:
# Run the network-functional SCF for `mol`. If it fails (or gives a non-negative
# total energy), fall back to UHF. The energy that is kept is recorded on
# `result`, an ase.Atoms copy of `atoms`, via a SinglePointCalculator.
ATOMGRID = 3  # DFT integration grid level; `if ATOMGRID else 3` below keeps 3 as fallback
# Whether to replace a non-negative total energy by the HOMO orbital energy.
# This replaces a `kwargs.get('above0mo', False)` lookup: `kwargs` is not
# defined in this notebook, so that lookup raised NameError and silently sent
# every non-negative-energy run into the UHF fallback.
ABOVE0MO = False

print('Running short calculation to get ingredients for potential non-local network run...')
result = Atoms(atoms)
mf0 = method(mol)
mf0.max_cycle = -1  # pyscf: max_cycle <= 0 skips SCF iterations (initial guess only)
mf0.conv_tol = 1e-5
mf0.kernel()
print('Starting kernel calculation complete.')
# NOTE(review): a previous variant built eval_xc from the short run `mf0`
# (with kwargs['custom_xc_net']); this one uses `mf` -- confirm which is intended.
evxc = xce.pyscf.generate_network_eval_xc(mf, init_dm, xc)
mf.grids.level = ATOMGRID if ATOMGRID else 3
mf.max_cycle = 50
mf.max_memory = 64000
print("Running calculation")
# NOTE(review): the functional was loaded as 'GGA' but is registered as 'MGGA'
# here; presumably intentional for the network's input features -- confirm.
mf.define_xc_(evxc, 'MGGA')
mf.mo_coeff = mf0.mo_coeff  # start from the short run's orbitals
try:
    mf.kernel()
    if mf.e_tot >= 0 and ABOVE0MO:
        print('non-negative total energy; trying to rewrite with mo_energy of homo')
        print(f'mf.e_tot = {mf.e_tot}')
        print(f'mf.mo_occ = {mf.mo_occ}\nmf.mo_energy={mf.mo_energy}')
        # Highest index with non-zero occupation; `size` keeps the call
        # shape-static for JAX (padding entries are index 0).
        homo_i = jnp.max(jnp.nonzero(mf.mo_occ, size=init_dm.shape[0])[0])
        homo_e = mf.mo_energy[homo_i]
        print(f'homo_e = {homo_e}')
        mf.e_tot = homo_e
    elif mf.e_tot >= 0:
        print(f'NON-NEGATIVE ENERGY DETECTED.\n{str(atoms.symbols), mol, mf}\nENERGY={mf.e_tot}')
        # Explicit exception (a bare `raise` here has no active exception to
        # re-raise); it is caught below and routes to the UHF fallback.
        raise ValueError(f'Non-negative total energy {mf.e_tot} for {str(atoms.symbols)}')
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy': mf.e_tot}
except Exception as e:
    print(f'{type(e).__name__}: {e}')
    print('Kernel calculation failed, perhaps hydrogen is acting up or there is another issue')
    print('Trying with UHF')
    vgf = lambda x: xc(x, mf._numint.eval_ao(mol, mf.grids.coords, deriv=2), mf.grids.weights, mf=mf, coor=mf.grids.coords)
    # NOTE(review): scfa.UHF is Hartree-Fock. Nothing shown here makes UHF
    # read the `network` / `network_eval` attributes set below, so the stored
    # energy may be a plain UHF energy rather than the network functional's --
    # confirm.
    mf2 = scfa.UHF(mol)
    mf2.max_cycle = 50
    mf2.max_memory = 64000
    print('Setting network and network_eval1')
    mf2.network = xc
    mf2.network_eval = vgf
    print('Running UHF calculation')
    mf2.kernel()
    result.calc = SinglePointCalculator(result)
    result.calc.results = {'energy': mf2.e_tot}


Running short calculation to get ingredients for potential non-local network run...

WARN: Invalid number of electrons 15 for RHF method.

LDA ni.block_loop; input ao.shape=(18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
Exception raised: NumInt.eval_xc() got multiple values for argument 'spin'
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/pyscfad/dft/numint.py", line 130, in nr_rks
    exc, vxc = ni.eval_xc(xc_code, rho, ao, weight, coords, spin=0,
TypeError: NumInt.eval_xc() got multiple values for argument 'spin'

Falling back to regular form
Starting kernel calculation complete.
Running calculation

WARN: Invalid number of electrons 15 for RHF method.





MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shape=(18806, 11)
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
Exc_exc and vs returned: Exc = -23.090804853630072, exc.shape=(18806, 1), vs.shape=(18806, 11)
eval_xc Exc = -23.090804853630072
MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shape=(18806, 11)
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
Exc_exc and vs returned: Exc = -20.300628691743057, exc.shape=(18806, 1), vs.shape=(18806, 11)
eval_xc Exc = -20.300628691743057
eX.__call__, rho shape: (2, 18806, 2)
eX.__call__, rho nans: 0
eC.__call__, rho shape: (18806, 3)
eC.__call__, rho nans: 0
MGGA ni.block_loop; input ao.shape=(10, 18806, 47), weight.shape=(18806,), coords.shape=(18806, 3)
v_and_g_inp.shap

In [54]:
# Display the energy actually recorded on `result` by the calculation cell.
# That is mf.e_tot on success, or the UHF fallback's mf2.e_tot if the network
# run failed. Showing mf.e_tot directly would report the failed run's value.
result.calc.results['energy']

Array(-299.54077136, dtype=float64)