In [1]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
import numpy as np
from pyscf import dft, scf, gto, cc
from pyscfad import dft as dft_ad
from pyscfad import gto as gto_ad
from pyscfad import scf as scf_ad
from functools import partial
import pylibxc
import pyscfad.dft as dftad
from jax import custom_jvp
jax.config.update("jax_enable_x64", True) #Enables 64 bit precision
import pyscf as PSCF
import pyscfad as PSCFAD

from xcquinox import net
from xcquinox.loss import compute_loss_mae
from xcquinox.train import Pretrainer, Optimizer
from xcquinox.utils import gen_grid_s, PBE_Fx, PBE_Fc, calculate_stats, lda_x, pw92c_unpolarized
from xcquinox.pyscf import eval_xc_gga_j2



In [2]:
PSCF.__version__, PSCFAD.__version__

('2.11.0', '0.1.11')

The following cells are copied verbatim from `pbe_notebook_v5`, without the extra plotting.

In [3]:
#sigma-based networks
spbe_fx = net.GGA_FxNet_sigma(depth=3, nodes=16, seed=92017, lower_rho_cutoff = 0)
spbe_fc = net.GGA_FcNet_sigma(depth=3, nodes=16, seed=92017, lower_rho_cutoff = 0)

spbe_fx_lg = net.GGA_FxNet_sigma(depth=5, nodes=32, seed=92017, lower_rho_cutoff = 0)
spbe_fc_lg = net.GGA_FcNet_sigma(depth=5, nodes=32, seed=92017, lower_rho_cutoff = 0)

In [4]:
#training and validation values
inds, vals, tflats, vflats = gen_grid_s(npts = 1e5)
train_inds, val_inds = inds
rv, grv, sv = vals
trf, tgrf, tsf = tflats
vrf, vgrf, vsf = vflats
#training and validation values for SIGMA
sinds, svals, stflats, svflats = gen_grid_s(npts = 1e5, sigma=True)
strain_inds, sval_inds = sinds
srv, sgrv, ssv = svals
strf, stgrf, stsf = stflats
svrf, svgrf, svsf = svflats

shapes- r/gr/s: (315,)/(315,)/(315,)
shapes- r/gr/s: (315,)/(315,)/(315,)


In [5]:
ref_fx = PBE_Fx(trf, tgrf)
ref_fc = PBE_Fc(trf, tgrf)

sinputs = jnp.stack([trf, tsf], axis=1)
siginputs = jnp.stack([strf, stgrf], axis=1)
sval = jnp.stack([vrf, vsf], axis=1)
print(sval.shape, sinputs.shape)

(3844, 2) (64009, 2)


In [6]:
#create optimizers

PT_INIT_LR = 1e-3
PT_END_LR = 1e-5
PTSTEPS = 500
PTDECAYBEGIN = int(PTSTEPS/2)
scheduler = optax.linear_schedule(
    init_value = PT_INIT_LR,
    transition_steps = PTSTEPS-PTDECAYBEGIN,
    transition_begin = PTDECAYBEGIN,
    end_value = PT_END_LR,
)

soptimizer_fx = optax.adam(learning_rate=scheduler)
soptimizer_fc = optax.adam(learning_rate=scheduler)

spt_pbe_fx = Pretrainer(model = spbe_fx, optim = soptimizer_fx, inputs = siginputs, ref = ref_fx, loss = compute_loss_mae, steps = PTSTEPS)
spt_pbe_fc = Pretrainer(model = spbe_fc, optim = soptimizer_fc, inputs = siginputs, ref = ref_fc, loss = compute_loss_mae, steps = PTSTEPS)
spt_pbe_fx_lg = Pretrainer(model = spbe_fx_lg, optim = soptimizer_fx, inputs = siginputs, ref = ref_fx, loss = compute_loss_mae, steps = PTSTEPS)
spt_pbe_fc_lg = Pretrainer(model = spbe_fc_lg, optim = soptimizer_fc, inputs = siginputs, ref = ref_fc, loss = compute_loss_mae, steps = PTSTEPS)

Here, we only do the `snmx` and `snmc` pre-training from the mentioned notebook.

In [7]:
snmx, snlx = spt_pbe_fx()
snmc, snlc = spt_pbe_fc()

Epoch 0: Loss = 0.15148956514419687
Epoch 100: Loss = 0.0004852712515070963
Epoch 200: Loss = 0.00027408608613717583
Epoch 300: Loss = 0.00025285528076710327
Epoch 400: Loss = 0.00024138166315321546
Epoch 0: Loss = 0.3699024597167876
Epoch 100: Loss = 0.04856361873077554
Epoch 200: Loss = 0.021877972910576547
Epoch 300: Loss = 0.01485900824250185
Epoch 400: Loss = 0.013560472898018764


In [8]:
from ase.io import read
trj = read('/home/awills/Documents/Research/xcquinox/scripts/script_data/training_subsets/06wf/subat_ref.traj', ':')
for idx, at in enumerate(trj):
    print(idx, at, at.symbols, at.info)
    if str(at.symbols) in ['Cl2', 'HF2', 'FH']:
        print(at.get_chemical_symbols(), at.get_positions())

0 Atoms(symbols='N2', pbc=False, calculator=SinglePointCalculator(...)) N2 {'N2': True, 'name': 'Dinitrogen', 'n_rad': 6, 'n_ang': 10, 'pol': False, 'target_energy': -0.36405261150717777, 'energy': -109.41916258265782, 'atomization': -0.36405261150717777, 'atomization_ev': -9.906376144900815, 'atomization_H': -0.36405261150717777, 'calc_energy': -109.41916258265782, 'e_calc': -109.41916258265782}
1 Atoms(symbols='LiF', pbc=False, calculator=SinglePointCalculator(...)) LiF {'LiF': True, 'name': 'Lithium fluoride', 'n_rad': 3, 'n_ang': 15, 'pol': False, 'target_energy': -0.22243004393513086, 'energy': -107.3059420504875, 'atomization': -0.22243004393513086, 'atomization_ev': -6.052629788935811, 'atomization_H': -0.22243004393513086, 'calc_energy': -107.3059420504875, 'e_calc': -107.3059420504875}
2 Atoms(symbols='ClH', pbc=False, calculator=SinglePointCalculator(...)) ClH {'HCl': True, 'name': 'Hydrogen chloride', 'n_rad': 3, 'n_ang': 10, 'target_energy': -0.17123597861668197, 'energy': 

In [9]:
#reference energies:
#[H_TE, O_TE, H2O_AE] <- step 1
kjMol_to_H = 2625.5
refs1 = [-0.5, -75.0673, -974.94/kjMol_to_H]
#[Cl_TE, Cl2_AE, HF+F -> H+F2 barrier]
refs2 = [-460.148, -0.09454200500963746, 0.16920860069537955]
refs3 = refs1 + refs2

mol_params = {'H': {'atoms':['H'], 'coords': [[0,0,0]], 'spin':1, 'charge':0},
              'O': {'atoms':['O'], 'coords':[[0,0,0]], 'spin':2, 'charge':0},
              'H2O': {'atoms':['O','H','H'], 'coords':[[0,0,0],[0,-0.757,0.587],[0,0.757,0.587]], 'spin':0,'charge':0},
              'Cl':{'atoms':['Cl'], 'coords':[[0,0,0]], 'spin':1, 'charge':0},
              'Cl2':{'atoms':['Cl','Cl'], 'coords': [ [0,0,1.008241], [0,0,-1.008241] ], 'spin':0, 'charge':0},
              'HF2':{'atoms':['H', 'F', 'F'], 'coords':[ [0,0,-2.2312757],[0,0,-0.61621628],[0,0,0.8641358] ], 'spin':1, 'charge':0},
              'FH':{'atoms':['F', 'H'], 'coords':[ [0,0,0.09153813], [0,0,-0.82384424] ], 'spin':0, 'charge':0},
              'F':{'atoms':['F'], 'coords':[ [0,0,0] ], 'spin':1, 'charge':0}
             }
mol_dct = {k:0 for k in mol_params.keys()}
for sys in mol_dct.keys():
    atstr = ''
    print(sys)
    for idx, at in enumerate(mol_params[sys]['atoms']):
        atstr += '{} {} {} {}\n'.format(at, *mol_params[sys]['coords'][idx])
    print(atstr)
    mol = gto_ad.Mole(atom = atstr, charge = mol_params[sys]['charge'], spin = mol_params[sys]['spin'])
    mol.basis = 'dzvp'
    mol.build()
    mol.max_memory = 32000

    mol_dct[sys] = mol

mols1 = [mol_dct['H'], mol_dct['O'], mol_dct['H2O']]
mols2 = [mol_dct['Cl'], mol_dct['Cl2'], mol_dct['HF2'], mol_dct['FH'], mol_dct['F']]
mols3 = mols1+mols2

H
H 0 0 0

O
O 0 0 0

H2O
O 0 0 0
H 0 -0.757 0.587
H 0 0.757 0.587

Cl
Cl 0 0 0

Cl2
Cl 0 0 1.008241
Cl 0 0 -1.008241

HF2
H 0 0 -2.2312757
F 0 0 -0.61621628
F 0 0 0.8641358

FH
F 0 0 0.09153813
H 0 0 -0.82384424

F
F 0 0 0



In [10]:
class RXCModel(eqx.Module):
    xnet: eqx.Module
    cnet: eqx.Module

    def __init__(self, xnet, cnet):
        self.xnet = xnet
        self.cnet = cnet
        
    def __call__(self, inputs):
        #this generate epsilon, not exc -- divide end result by rho when needed
        try:
            #if this runs, it is unpolarized
            rho = inputs[0]
            sigma = inputs[1]
            ninputs = [rho, sigma]
        except:
            #if this runs, it is polarized
            rhoa, rhob, sigma0, sigma1, sigma2 = inputs
            rho = rhoa+rhob
            sigma = (sigma0+sigma1+sigma2)/3 #average sigma
            ninputs = [rho, sigma]
        # print('rho/sigma shape:', rho.shape, sigma.shape)
        # print('RXCModel call - inputs {}'.format(inputs))
        # return rho*(lda_x(rho)*jax.vmap(self.xnet)(inputs[..., jnp.newaxis]) + pw92c_unpolarized(rho)*jax.vmap(self.cnet)(inputs[..., jnp.newaxis])).flatten()[0]
        ldaxrho = lda_x(rho)
        pw92c = pw92c_unpolarized(rho)
        # print('ldax/pw92c shape:', ldaxrho.shape, pw92c.shape)
        # print('inputs[...,jnp.newaxis] shape:', inputs[...,jnp.newaxis].shape)
        # print('inputs.shape :', inputs.shape)
        # xvmap = jax.vmap(self.xnet)(inputs.T)
        # cvmap = jax.vmap(self.cnet)(inputs.T)
        xvmap = self.xnet(ninputs)
        cvmap = self.cnet(ninputs)
        # print('xvmap/cvmap shape:', xvmap.shape, cvmap.shape)
        retarr = rho*(ldaxrho*xvmap + pw92c*cvmap).flatten()[0]
        # print('retarr shape:', retarr.shape)
        return retarr
        # return rho*(lda_x(rho)*self.xnet(inputs) + pw92c_unpolarized(rho)*self.cnet(inputs)).flatten()[0]

In [11]:
pt_model = RXCModel(xnet = snmx, cnet = snmc)

Below is the polarized-capable PySCF driver.

In [12]:
def eval_xc_gga_j2(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
                   xcmodel=None):
    # we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
    # pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
    # so since LDA calculation, check for size first.
    # print('EVAL_XC SPIN:', spin)
    # try:
    #     print('rho shape', rho.shape)
    # except:
    #     print('rho len', len(rho))
    #     print('rho_a shape', rho[0].shape)
    try:
        #if this works, we are not spin-polarized -- GGA rho is passed as [rho0, dx, dy, dz]
        rho0, dx, dy, dz = rho
        sigma = jnp.array(dx**2+dy**2+dz**2)
        rho0 = jnp.array(rho0)
        # print('rho0, dx, dy, dz unpacked')
        # print('rho0/sigma shape:', rho0.shape, sigma.shape)
        # print('DEBUG eval_xc_gga_j: rho0/sigma shapes: ', rho0.shape, sigma.shape)
        # rhosig = (rho0, sigma)
        rhosig = jnp.stack([rho0, sigma], axis=1)
        # print('rhosig shape', rhosig.shape)
        # calculate the "custom" energy with rho -- THIS IS e
        # cast back to np.array since that's what pyscf works with
        # pass as tuple -- (rho, sigma)
        exc = jax.vmap(xcmodel)(rhosig)
        exc = jnp.array(exc)/rho0
        # print('exc shape = {}'.format(exc.shape))
        # first order derivatives w.r.t. rho and sigma
        vrho_f = eqx.filter_grad(xcmodel)
        vrhosigma = jnp.array(jax.vmap(vrho_f)(rhosig))
        # print('vrhosigma shape:', vrhosigma.shape)
        vxc = (vrhosigma[:, 0], vrhosigma[:, 1], None, None)

        #2nd order derivatives
        v2_f = jax.hessian(xcmodel)
        # v2_f = jax.hessian(custom_pbe_epsilon, argnums=[0, 1])
        v2 = jnp.array(jax.vmap(v2_f)(rhosig))
        # print('v2 shape', v2.shape)
        v2rho2 = v2[:, 0, 0]
        v2rhosigma = v2[:, 0, 1]
        v2sigma2 = v2[:, 1, 1]
        v2lapl2 = None
        vtau2 = None
        v2rholapl = None
        v2rhotau = None
        v2lapltau = None
        v2sigmalapl = None
        v2sigmatau = None
        # 2nd order functional derivative
        fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
        # 3rd order
        kxc = None

    except:
        #if this runs, we are spin-polarized -- GGA rho is passed as [GGA rho_a, GGA rho_b]
        rho_a, rho_b = rho
        rho0a, dxa, dya, dza = rho_a
        rho0b, dxb, dyb, dzb = rho_b

        rho0 = rho0a+rho0b
        

        #per libxc manual, sigma contracts as follows
        sigma0 = dxa**2 + dya**2 + dza**2
        sigma1 = dxa*dxb + dya*dyb + dza*dzb
        sigma2 = dxb**2 + dyb**2 + dzb**2
        
        # print('rho0, drho unpacked')
        # print('rho0/sigma0 shape:', rho0.shape, sigma0.shape)
        #at this point, we should send the rho_a, rho_b, sigma0, sigma1, sigma2 to the network
        #we must send it this way, rather than condensing beforehand, because we need the derivatives
        #w.r.t. the separate sigma derivatives/spin channels to send back to pyscf as it expects

        #we can condense these to feed into the already existing structures as above
        rhosig = jnp.stack([rho0a, rho0b, sigma0, sigma1, sigma2], axis=1)
        # print('rhosig shape', rhosig.shape)
        # calculate the "custom" energy with rho -- THIS IS e
        # cast back to np.array since that's what pyscf works with
        # pass as tuple -- (rho, sigma)
        exc = jax.vmap(xcmodel)(rhosig)
        exc = jnp.array(exc)/rho0
        # print('exc shape = {}'.format(exc.shape))
        
        # first order derivatives w.r.t. rho and sigma
        vrho_f = eqx.filter_grad(xcmodel)
        vrhosigma = jnp.array(jax.vmap(vrho_f)(rhosig))
        print('vrhosigma shape:', vrhosigma.shape)
        #for spin-polarized, vrho = [vrho_a, vrho_b]
        #and vsigma = [vsigma0, vsigma1, vsigma2]
        vrhoa = vrhosigma[:, 0]
        vrhob = vrhosigma[:, 1]
        vsigma0 = vrhosigma[:, 2]
        vsigma1 = vrhosigma[:, 3]
        vsigma2 = vrhosigma[:, 4]
        vxc = (jnp.stack([vrhoa, vrhob], axis=1), jnp.stack([vsigma0, vsigma1, vsigma2], axis=1),
               None, None)

        #2nd order derivatives
        v2_f = jax.hessian(xcmodel)
        # v2_f = jax.hessian(custom_pbe_epsilon, argnums=[0, 1])
        v2 = jnp.array(jax.vmap(v2_f)(rhosig))
        print('hessian shape:', v2.shape)
        #for spin-polarized, v2rho2 = [v2rho2_aa, v2rho2_ab, v2rho2_bb]
        #v2rhosigma = [v2rho(a)sigma0, v2rho(a)sigma1, v2rho(a)sigma2, v2rho(b)sigma0, v2rho(b)sigma1, v2rho(b)sigma2]
        #v2sigma2 = [v2sigma2_00, v2sigma2_01, v2sigma2_02, v2sigma2_11, v2sigma2_12, v2sigma2_22]
        #again, these will be provided assuming the call to xcmodel is equipped to take the full list of inputs
        # print('v2 shape', v2.shape)
        #v2rho2
        v2rho2aa = v2[:, 0, 0]
        v2rho2ab = v2[:, 0, 1]
        v2rho2bb = v2[:, 1, 1]
        #v2rhosigma
        v2rhoasigma0 = v2[:, 0, 2]
        v2rhoasigma1 = v2[:, 0, 3]
        v2rhoasigma2 = v2[:, 0, 4]
        v2rhobsigma0 = v2[:, 1, 2]
        v2rhobsigma1 = v2[:, 1, 3]
        v2rhobsigma2 = v2[:, 1, 4]
        #v2sigma2
        v2sigma200 = v2[:, 2, 2]
        v2sigma201 = v2[:, 2, 3]
        v2sigma202 = v2[:, 2, 4]
        v2sigma211 = v2[:, 3, 3]
        v2sigma212 = v2[:, 3, 4]
        v2sigma222 = v2[:, 4, 4]
        #not used
        v2lapl2 = None
        vtau2 = None
        v2rholapl = None
        v2rhotau = None
        v2lapltau = None
        v2sigmalapl = None
        v2sigmatau = None
        # 2nd order functional derivative
        fxc = (jnp.stack([v2rho2aa, v2rho2ab, v2rho2bb], axis=1),
               jnp.stack([v2rhoasigma0, v2rhoasigma1, v2rhoasigma2, v2rhobsigma0, v2rhobsigma1, v2rhobsigma2], axis=1),
               jnp.stack([v2sigma200, v2sigma201, v2sigma202, v2sigma211, v2sigma212, v2sigma222], axis=1),
               v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
        # 3rd order
        kxc = None


    return exc, vxc, fxc, kxc

In [13]:
refs1

[-0.5, -75.0673, -0.37133498381260716]

In [14]:
@eqx.filter_value_and_grad
def opt_loss1(model, mols, refs):
    #assumes mols are [H, O, H2O]
    #assumes refs are [H_TE, O_TE, H2O_AE]
    total_loss = 0
    preds = []
    for idx, mol in enumerate(mols):
        print(10*'=')
        if mol.spin:
            mf = dft_ad.UKS(mol)
        else:
            mf = dft_ad.RKS(mol)
        if idx == 0:
            mf.max_cycle = -1
        custom_eval_xc = partial(eval_xc_gga_j2, xcmodel=model)
        mf.grids.level = 1
        mf.define_xc_(custom_eval_xc, 'GGA')
        mf.kernel()
        pred = mf.e_tot
        preds.append(pred)
        this_ae = jnp.abs(refs[idx]-pred)
        print(f'{idx} -- {mol.atom}')
        if idx != 2:
            pref = refs[idx]
            total_loss += this_ae
        elif idx == 2:
            #create a "total energy" for water to aim for, based on the atomization energy we want
            refwat_e = 2*refs[0]+refs[1]+refs[2]
            pref = refwat_e
            this_ae = jnp.abs(refwat_e - pred)
            total_loss += this_ae
        jax.debug.print(f'REF={pref}, PRED={pred}, ABS. ERR.={this_ae})')
        jax.debug.print("PRED:{parameter}", parameter=pred)
        jax.debug.print("ABS. ERR.:{parameter}", parameter=this_ae)
        print(10*'=')
    print(preds)
    water_ae = preds[-1] - (2*preds[0] + preds[1])
    water_ae_loss = jnp.abs(-water_ae - refs[2])
    print('Energy calculations complete.')
    jax.debug.print(f'WATER AE REF={refs[2]}, WATER AE PRED={water_ae}, WATER ABS. ERR.={water_ae_loss})')
    jax.debug.print("PRED:{parameter}", parameter=water_ae)
    jax.debug.print("ABS. ERR.:{parameter}", parameter=water_ae_loss)
    total_loss += 100*water_ae_loss
    return total_loss[..., jnp.newaxis][0]

OPT_INIT_LR = 5e-3
OPT_END_LR = 1e-5
OPTSTEPS = 100
OPTDECAYBEGIN = 50
scheduler = optax.linear_schedule(
    init_value = OPT_INIT_LR,
    transition_steps = OPTSTEPS-OPTDECAYBEGIN,
    transition_begin = OPTDECAYBEGIN,
    end_value = OPT_END_LR,
)
opt_opt1 = optax.adam(learning_rate=scheduler)

optnet1_o = Optimizer(model=pt_model, optim=opt_opt1, mols = mols1, refs = refs1, loss=opt_loss1, print_every=1, steps=OPTSTEPS)

In [None]:
optnet1, optnet1l = optnet1_o()

vrhosigma shape: (2472, 5)
hessian shape: (2472, 5, 5)
SCF not converged.
SCF energy = -0.294018097862274 after -1 cycles  <S^2> = 0.75  2S+1 = 2
0 -- H 0 0 0

REF=-0.5, PRED=LinearizeTracer<float64[]>, ABS. ERR.=LinearizeTracer<float64[]>)
PRED:-0.29401809786227395
ABS. ERR.:0.20598190213772605
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)




vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma shape: (5184, 5)
hessian shape: (5184, 5, 5)
vrhosigma 