In [1]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
import numpy as np
from pyscf import dft, scf, gto, cc
from pyscfad import dft as dft_ad
from pyscfad import gto as gto_ad
from pyscfad import scf as scf_ad
from functools import partial
import pylibxc
import pyscfad.dft as dftad
from jax import custom_jvp
jax.config.update("jax_enable_x64", True) #Enables 64 bit precision
import pyscf as PSCF
import pyscfad as PSCFAD

from xcquinox import net
from xcquinox.loss import compute_loss_mae
from xcquinox.train import Pretrainer, Optimizer
from xcquinox.utils import gen_grid_s, PBE_Fx, PBE_Fc, calculate_stats, lda_x, pw92c_unpolarized
from xcquinox.pyscf import eval_xc_gga_j2



In [2]:
# Record which PySCF / PySCFAD versions produced the outputs below.
(PSCF.__version__, PSCFAD.__version__)

('2.11.0', '0.1.11')

In [3]:
# Sigma-based exchange (Fx) and correlation (Fc) networks, built with the same
# architecture, seed and density cutoff.
spbe_fx = net.GGA_FxNet_sigma(
    depth=3,
    nodes=32,
    seed=92017,
    lower_rho_cutoff=0,
)
spbe_fc = net.GGA_FcNet_sigma(
    depth=3,
    nodes=32,
    seed=92017,
    lower_rho_cutoff=0,
)


In [4]:
# Training / validation grid: index split, grid values, and flattened train/val arrays.
inds, vals, tflats, vflats = gen_grid_s(npts = 1e5)
(train_inds, val_inds), (rv, grv, sv) = inds, vals
(trf, tgrf, tsf), (vrf, vgrf, vsf) = tflats, vflats

# The same construction with sigma=True, used as input for the sigma-based networks.
sinds, svals, stflats, svflats = gen_grid_s(npts = 1e5, sigma=True)
(strain_inds, sval_inds), (srv, sgrv, ssv) = sinds, svals
(strf, stgrf, stsf), (svrf, svgrf, svsf) = stflats, svflats

shapes- r/gr/s: (315,)/(315,)/(315,)
shapes- r/gr/s: (315,)/(315,)/(315,)


In [5]:
# PBE reference factors on the (non-sigma) training grid; these are the pretraining targets.
ref_fx = PBE_Fx(trf, tgrf)
ref_fc = PBE_Fc(trf, tgrf)
# NOTE(review): the targets above are built from (trf, tgrf) but are paired with
# `siginputs` (built from the sigma grid strf/stgrf) in the pretrainers below;
# confirm that both grids share the same point ordering.

# Two-column input arrays, one row per grid point.
sinputs = jnp.stack((trf, tsf), axis=1)
siginputs = jnp.stack((strf, stgrf), axis=1)
sval = jnp.stack((vrf, vsf), axis=1)
print(sval.shape, sinputs.shape)

(3844, 2) (64009, 2)


In [6]:
# Pretraining optimizers: the learning rate stays at PT_INIT_LR for the first
# PTDECAYBEGIN steps, then decays linearly to PT_END_LR by step PTSTEPS.
PT_INIT_LR, PT_END_LR = 1e-3, 1e-5
PTSTEPS = 1000
PTDECAYBEGIN = PTSTEPS // 2
scheduler = optax.linear_schedule(
    init_value=PT_INIT_LR,
    end_value=PT_END_LR,
    transition_begin=PTDECAYBEGIN,
    transition_steps=PTSTEPS - PTDECAYBEGIN,
)

# Separate Adam instances (sharing one schedule) for the exchange and correlation nets.
soptimizer_fx = optax.adam(learning_rate=scheduler)
soptimizer_fc = optax.adam(learning_rate=scheduler)

# Both networks are fit on the sigma-grid inputs against their PBE reference
# factors with the compute_loss_mae loss.
spt_pbe_fx = Pretrainer(
    model=spbe_fx, optim=soptimizer_fx, inputs=siginputs,
    ref=ref_fx, loss=compute_loss_mae, steps=PTSTEPS,
)
spt_pbe_fc = Pretrainer(
    model=spbe_fc, optim=soptimizer_fc, inputs=siginputs,
    ref=ref_fc, loss=compute_loss_mae, steps=PTSTEPS,
)

Here, we run only the sigma-based pretraining of the exchange (`snmx`) and correlation (`snmc`) networks from the original pretraining notebook.

In [7]:
# Run both pretraining loops (exchange first, then correlation).
# NOTE(review): each Pretrainer call presumably returns (trained model, loss history);
# confirm against xcquinox.train.Pretrainer.
snmx, snlx = spt_pbe_fx()
snmc, snlc = spt_pbe_fc()

Epoch 0: Loss = 0.15158353991257742
Epoch 100: Loss = 0.0005889393505375802
Epoch 200: Loss = 0.000221089486441477
Epoch 300: Loss = 0.00017561560276741707
Epoch 400: Loss = 0.00018537229588970274
Epoch 500: Loss = 0.000303390725213164
Epoch 600: Loss = 0.00013553651027617365
Epoch 700: Loss = 0.00010916811469073238
Epoch 800: Loss = 9.313785935697063e-05
Epoch 900: Loss = 7.733615435366374e-05
Epoch 0: Loss = 0.365125275729497
Epoch 100: Loss = 0.03012428407370393
Epoch 200: Loss = 0.01796879988541145
Epoch 300: Loss = 0.01349432050911389
Epoch 400: Loss = 0.008112128623746344
Epoch 500: Loss = 0.005430790435757247
Epoch 600: Loss = 0.00409220780098196
Epoch 700: Loss = 0.0032610135092225696
Epoch 800: Loss = 0.0027869962791801993
Epoch 900: Loss = 0.0025254708466412295


In [8]:
import os

from ase.io import read

# Reference trajectory of training subsets. The default is the original author's
# absolute path; set the XCQUINOX_SUBAT_TRAJ environment variable to use another location.
SUBAT_TRAJ_PATH = os.environ.get(
    'XCQUINOX_SUBAT_TRAJ',
    '/home/awills/Documents/Research/xcquinox/scripts/script_data/training_subsets/06wf/subat_ref.traj',
)
# The ':' index reads every frame of the trajectory.
trj = read(SUBAT_TRAJ_PATH, ':')
for idx, at in enumerate(trj):
    print(idx, at, at.symbols, at.info)
    # For a few systems of interest, also dump element symbols and Cartesian positions.
    if str(at.symbols) in ['Cl2', 'HF2', 'FH']:
        print(at.get_chemical_symbols(), at.get_positions())

0 Atoms(symbols='N2', pbc=False, calculator=SinglePointCalculator(...)) N2 {'N2': True, 'name': 'Dinitrogen', 'n_rad': 6, 'n_ang': 10, 'pol': False, 'target_energy': -0.36405261150717777, 'energy': -109.41916258265782, 'atomization': -0.36405261150717777, 'atomization_ev': -9.906376144900815, 'atomization_H': -0.36405261150717777, 'calc_energy': -109.41916258265782, 'e_calc': -109.41916258265782}
1 Atoms(symbols='LiF', pbc=False, calculator=SinglePointCalculator(...)) LiF {'LiF': True, 'name': 'Lithium fluoride', 'n_rad': 3, 'n_ang': 15, 'pol': False, 'target_energy': -0.22243004393513086, 'energy': -107.3059420504875, 'atomization': -0.22243004393513086, 'atomization_ev': -6.052629788935811, 'atomization_H': -0.22243004393513086, 'calc_energy': -107.3059420504875, 'e_calc': -107.3059420504875}
2 Atoms(symbols='ClH', pbc=False, calculator=SinglePointCalculator(...)) ClH {'HCl': True, 'name': 'Hydrogen chloride', 'n_rad': 3, 'n_ang': 10, 'target_energy': -0.17123597861668197, 'energy': 

In [9]:
# Systems used for fine-tuning: the isolated H and O atoms and the H2O molecule.
# Each entry gives element symbols, Cartesian coordinates, spin (2S) and charge.
mol_params = {
    'H': {'atoms': ['H'], 'coords': [[0, 0, 0]], 'spin': 1, 'charge': 0},
    'O': {'atoms': ['O'], 'coords': [[0, 0, 0]], 'spin': 2, 'charge': 0},
    'H2O': {
        'atoms': ['O', 'H', 'H'],
        'coords': [[0, 0, 0], [0, -0.757, 0.587], [0, 0.757, 0.587]],
        'spin': 0,
        'charge': 0,
    },
}

# Reference energies (Hartree)
# H: -0.5 Ha (exact)
# O: -75.0673 Ha (high-level reference)
# H2O atomization: -0.371 Ha (974.94 kJ/mol)
kjMol_to_H = 2625.5  # kJ/mol per Hartree: divide a kJ/mol value by this to get Ha
refs = {
    'H_TE': -0.5,
    'O_TE': -75.0673,
    'H2O_AE': -974.94 / kjMol_to_H,  # Atomization energy (negative = bound)
}

# Build one def2-SVP molecule per system. The loop variable is `sys_name` rather than
# `sys`, so the stdlib module name is not shadowed by a leaked notebook global.
mol_dct = {}
for sys_name, params in mol_params.items():
    # One "<symbol> x y z" line per atom.
    atstr = '\n'.join([f"{at} {' '.join(map(str, params['coords'][i]))}"
                       for i, at in enumerate(params['atoms'])])
    mol = gto_ad.Mole(atom=atstr, charge=params['charge'], spin=params['spin'])
    mol.basis = 'def2-svp'
    mol.build()
    mol.max_memory = 32000
    mol_dct[sys_name] = mol
    print(f"Built {sys_name}: {mol.nelectron} electrons, spin={params['spin']}")

# Order matters: opt_loss1 unpacks these as [H, O, H2O].
mols = [mol_dct['H'], mol_dct['O'], mol_dct['H2O']]

Built H: 1 electrons, spin=1
Built O: 8 electrons, spin=2
Built H2O: 10 electrons, spin=0


In [10]:
class RXCModel(eqx.Module):
    """Per-grid-point XC energy-density model built from an exchange and a correlation network.

    For one grid point, returns

        epsilon = rho * (lda_x(rho) * Fx + pw92c_unpolarized(rho) * Fc),

    where Fx = xnet([rho, sigma]) and Fc = cnet([rho, sigma]).

    NOTE(review): spin-polarized inputs are collapsed to the total density and the
    total sigma. Both energy terms are then evaluated with the same helpers as in the
    unpolarized case (`lda_x`, `pw92c_unpolarized`), so spin polarization enters only
    through that collapse.
    """

    # Exchange network; called with [rho, sigma].
    xnet: eqx.Module
    # Correlation network; called with [rho, sigma].
    cnet: eqx.Module

    def __init__(self, xnet, cnet):
        self.xnet = xnet
        self.cnet = cnet

    def __call__(self, inputs):
        """Compute epsilon = rho * e_xc for a single grid point.

        For unpolarized: inputs = [rho, sigma]
        For polarized: inputs = [rho_a, rho_b, sigma_aa, sigma_ab, sigma_bb]

        Returns a scalar. Raises ValueError for any other input length.
        """
        # Detect input format by length
        inputs = jnp.atleast_1d(inputs)

        if inputs.shape[-1] == 2:
            # Unpolarized case
            rho = inputs[0]
            sigma = inputs[1]
        elif inputs.shape[-1] == 5:
            # Polarized case - combine densities
            rho_a = inputs[0]
            rho_b = inputs[1]
            sigma_aa = inputs[2]
            sigma_ab = inputs[3]
            sigma_bb = inputs[4]

            rho = rho_a + rho_b
            # |grad(rho_a + rho_b)|^2 = sigma_aa + 2*sigma_ab + sigma_bb
            sigma = sigma_aa + 2*sigma_ab + sigma_bb
        else:
            raise ValueError(f"Unexpected input shape: {inputs.shape}")

        # Floor the density so the downstream density-dependent terms stay finite.
        rho = jnp.maximum(rho, 1e-18)

        # Network inputs
        net_inputs = [rho, sigma]

        # Enhancement factors
        Fx = self.xnet(net_inputs)
        Fc = self.cnet(net_inputs)

        # Per-particle base energies
        ex_lda = lda_x(rho)
        ec_pw92 = pw92c_unpolarized(rho)

        # Total epsilon = rho * e_xc
        epsilon = rho * (ex_lda * Fx + ec_pw92 * Fc)

        # Return scalar
        return jnp.squeeze(epsilon)


In [11]:
# Combine the pretrained exchange (snmx) and correlation (snmc) networks into one XC model.
pt_model = RXCModel(snmx, snmc)

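Below, `eval_xc_gga_j2` is redefined, shadowing the version imported from `xcquinox.pyscf`. The new version is a PySCF `eval_xc` driver that handles both unpolarized (RKS) and spin-polarized (UKS) densities.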

In [12]:
def eval_xc_gga_j2(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
                   xcmodel=None):
    """PySCF-style ``eval_xc`` for a JAX GGA model, supporting RKS and UKS densities.

    ``xcmodel`` maps the inputs of one grid point to epsilon = rho * e_xc:
    ``[rho, sigma]`` for unpolarized densities, or
    ``[rho_a, rho_b, sigma_aa, sigma_ab, sigma_bb]`` for spin-polarized densities.
    All potentials and kernels come from automatic differentiation of ``xcmodel``
    with respect to those per-point inputs.

    Args:
        xc_code, relativity, omega, verbose: accepted for signature compatibility; unused.
        rho: either an unpolarized density whose first four rows are (rho, d/dx, d/dy, d/dz),
            or a length-2 pair (rho_a, rho_b) of such arrays (spin-polarized).
        spin: accepted for compatibility; polarization is detected from ``rho`` itself.
        deriv: highest derivative order requested. ``vxc`` is returned for deriv >= 1 and
            ``fxc`` for deriv >= 2; lower orders get ``None`` (the Hessian is skipped
            during ordinary SCF, which requests deriv=1).
        xcmodel: per-point model callable (required).

    Returns:
        (exc, vxc, fxc, kxc). The polarized layouts are: vrho (N, 2), vsigma (N, 3)
        [aa, ab, bb], v2rho2 (N, 3), v2rhosigma (N, 6), v2sigma2 (N, 6). kxc is always None.

    Raises:
        TypeError: if ``xcmodel`` is None.
        ValueError: if ``rho`` is neither a (rho_a, rho_b) pair nor has >= 4 leading rows.
    """
    if xcmodel is None:
        raise TypeError("eval_xc_gga_j2 requires an xcmodel callable")

    # Explicit layout detection; exceptions are no longer used as control flow.
    nlead = len(rho)
    if nlead == 2:
        # ============ POLARIZED INPUT: rho = (rho_a, rho_b) ============
        is_polarized = True
        rho_a, rho_b = rho
        rho0a, dxa, dya, dza = rho_a[:4]
        rho0b, dxb, dyb, dzb = rho_b[:4]
        rho_tot = rho0a + rho0b
        sigma_aa = dxa**2 + dya**2 + dza**2
        sigma_ab = dxa*dxb + dya*dyb + dza*dzb
        sigma_bb = dxb**2 + dyb**2 + dzb**2
        rhosig = jnp.stack([rho0a, rho0b, sigma_aa, sigma_ab, sigma_bb], axis=1)
    elif nlead >= 4:
        # ============ UNPOLARIZED INPUT: rows (rho, dx, dy, dz, ...) ============
        is_polarized = False
        rho0, dx, dy, dz = rho[:4]
        rho_tot = jnp.asarray(rho0)
        sigma = jnp.asarray(dx**2 + dy**2 + dz**2)
        rhosig = jnp.stack([rho_tot, sigma], axis=1)
    else:
        raise ValueError(
            f"Unrecognized GGA density layout: leading dimension {nlead} "
            "(expected 2 for (rho_a, rho_b) or >= 4 for unpolarized rows)"
        )

    # Energy density and (optionally) its gradient w.r.t. the per-point inputs.
    if deriv >= 1:
        epsilon, v1 = jax.vmap(jax.value_and_grad(xcmodel))(rhosig)
    else:
        epsilon = jax.vmap(xcmodel)(rhosig)
    # exc is the energy per particle: epsilon / rho.
    exc = epsilon / (rho_tot + 1e-18)

    vxc = None
    if deriv >= 1:
        if is_polarized:
            vrho = v1[:, 0:2]     # [vrho_a, vrho_b]
            vsigma = v1[:, 2:5]   # [vsigma_aa, vsigma_ab, vsigma_bb]
        else:
            vrho = v1[:, 0]
            vsigma = v1[:, 1]
        vxc = (vrho, vsigma, None, None)

    fxc = None
    if deriv >= 2:
        v2 = jax.vmap(jax.hessian(xcmodel))(rhosig)
        if is_polarized:
            # v2rho2 = [aa, ab, bb]
            v2rho2 = jnp.stack([v2[:, 0, 0], v2[:, 0, 1], v2[:, 1, 1]], axis=1)
            # v2rhosigma = [a-aa, a-ab, a-bb, b-aa, b-ab, b-bb]
            v2rhosigma = jnp.stack(
                [v2[:, i, j] for i in (0, 1) for j in (2, 3, 4)], axis=1)
            # v2sigma2 = [aa-aa, aa-ab, aa-bb, ab-ab, ab-bb, bb-bb]
            v2sigma2 = jnp.stack(
                [v2[:, i, j] for i, j in ((2, 2), (2, 3), (2, 4), (3, 3), (3, 4), (4, 4))],
                axis=1)
        else:
            v2rho2 = v2[:, 0, 0]
            v2rhosigma = v2[:, 0, 1]
            v2sigma2 = v2[:, 1, 1]
        fxc = (v2rho2, v2rhosigma, v2sigma2,
               None, None, None, None, None, None, None)

    kxc = None
    return exc, vxc, fxc, kxc

In [13]:
@eqx.filter_value_and_grad
def opt_loss1(model, mols, refs):
    """Fine-tuning loss: self-consistent DFT on H, O and H2O with ``model`` as the XC functional.

    Runs one SCF per molecule, using ``model`` as a custom GGA via ``eval_xc_gga_j2``.
    It then combines the H2O atomization-energy error with the H and O total-energy
    errors into a weighted absolute-error loss. The ``eqx.filter_value_and_grad``
    decorator makes a call return ``(loss, grads)``, with gradients taken with
    respect to ``model``.

    Args:
        model: per-point XC model, passed to ``eval_xc_gga_j2`` as ``xcmodel``.
        mols: exactly three molecules, in the order [H, O, H2O].
        refs: mapping with keys 'H_TE' and 'O_TE' (total energies) and 'H2O_AE'
            (atomization energy, negative = bound), all in Hartree.

    Returns:
        Scalar loss: AE_WEIGHT * |AE_pred - AE_ref| + |E_H - H_TE| + |E_O - O_TE|.
    """
    # The atomization energy dominates the loss; total energies are secondary.
    AE_WEIGHT = 50.0
    # SCF iteration cap applied to single-atom systems.
    ATOM_MAX_CYCLE = 50

    preds = []
    for idx, mol in enumerate(mols):
        # Open-shell systems need an unrestricted calculation.
        if mol.spin:
            mf = dft_ad.UKS(mol)
        else:
            mf = dft_ad.RKS(mol)

        # Low grid level for speed during training
        mf.grids.level = 1

        # Plug the model in as a custom GGA functional
        custom_eval_xc = partial(eval_xc_gga_j2, xcmodel=model)
        mf.define_xc_(custom_eval_xc, 'GGA')

        # For single atoms, limit SCF cycles to avoid convergence issues.
        # mol.atom is the geometry *string* here, so len(mol.atom) counts characters;
        # natm is the actual number of atoms.
        if mol.natm == 1:
            mf.max_cycle = ATOM_MAX_CYCLE

        # Run SCF
        pred = mf.kernel()
        preds.append(pred)

        jax.debug.print("Mol {idx}: E = {pred}", idx=idx, pred=pred)

    E_H, E_O, E_H2O = preds

    # Atomization energy: AE = E(H2O) - 2*E(H) - E(O); negative for a bound molecule.
    pred_AE = E_H2O - 2*E_H - E_O
    ref_AE = refs['H2O_AE']

    AE_loss = jnp.abs(pred_AE - ref_AE)
    H_loss = jnp.abs(E_H - refs['H_TE'])
    O_loss = jnp.abs(E_O - refs['O_TE'])

    total_loss = AE_WEIGHT * AE_loss + H_loss + O_loss

    jax.debug.print("AE: pred={pred}, ref={ref}, error={err}",
                   pred=pred_AE, ref=ref_AE, err=AE_loss)
    jax.debug.print("Total loss: {loss}", loss=total_loss)

    return total_loss

# Fine-tuning schedule: the learning rate stays at OPT_INIT_LR for OPTDECAYBEGIN steps,
# then decays linearly to OPT_END_LR over the remaining OPTSTEPS - OPTDECAYBEGIN steps.
OPT_INIT_LR, OPT_END_LR = 5e-3, 1e-5
OPTSTEPS, OPTDECAYBEGIN = 100, 50
scheduler = optax.linear_schedule(
    init_value=OPT_INIT_LR,
    end_value=OPT_END_LR,
    transition_begin=OPTDECAYBEGIN,
    transition_steps=OPTSTEPS - OPTDECAYBEGIN,
)
opt_opt1 = optax.adam(learning_rate=scheduler)
# Gradient clipping is currently disabled. To enable it, use instead:
#   opt_opt1 = optax.chain(optax.clip_by_global_norm(1.0),
#                          optax.adam(learning_rate=scheduler))

# End-to-end fine-tuning of pt_model through SCF with the opt_loss1 objective.
optnet1_o = Optimizer(
    model=pt_model,
    optim=opt_opt1,
    mols=mols,
    refs=refs,
    loss=opt_loss1,
    print_every=1,
    steps=OPTSTEPS,
)

In [None]:
# Run the fine-tuning loop configured above.
# NOTE(review): presumably returns (optimized model, loss history); confirm against
# xcquinox.train.Optimizer.
optnet1, optnet1l = optnet1_o()



converged SCF energy = -0.45669628074606  <S^2> = 0.75  2S+1 = 2
Mol 0: E = -0.4566962807460596
SCF not converged.
SCF energy = -74.785701213049 after 50 cycles  <S^2> = 2  2S+1 = 3
Mol 1: E = -74.78570121304898
converged SCF energy = -76.2246981400229
Mol 2: E = -76.22469814002287
AE: pred=-0.5256043654817688, ref=-0.37133498381260716, error=0.15426938166916165
Total loss: 8.038371589663047
Epoch 0: Loss = 8.038371589663047
converged SCF energy = -0.447728281967961  <S^2> = 0.75  2S+1 = 2
Mol 0: E = -0.4477282819679607
SCF not converged.
SCF energy = -74.5983802020161 after 50 cycles  <S^2> = 2  2S+1 = 3
Mol 1: E = -74.59838020201606
converged SCF energy = -76.0305044606451
Mol 2: E = -76.03050446064509
AE: pred=-0.5366676946931079, ref=-0.37133498381260716, error=0.16533271088050078
Total loss: 8.787827060041021
Epoch 1: Loss = 8.787827060041021
converged SCF energy = -0.442196224191252  <S^2> = 0.75  2S+1 = 2
Mol 0: E = -0.4421962241912517
SCF not converged.
SCF energy = -74.4944948