In [1]:
# import pyscfad
# from pyscfad import gto,dft,scf
import pyscf
from pyscf import gto,dft,scf



In [2]:
import numpy as np
import jax.numpy as jnp
import scipy
from ase import Atoms
from ase.io import read
import xcquinox as xce
from functools import partial
from ase.units import Bohr
import os, optax, jax


2024-04-16 13:39:35.230790: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:280] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [3]:
def get_mol(atoms, basis='6-311++G**'):
    """Create a PySCF ``gto.Mole`` from an ASE-style ``atoms`` object.

    Parameters
    ----------
    atoms : object exposing ``positions``, ``get_chemical_symbols()`` and an
        ``info`` dict. ``info['basis']`` and ``info['spin']`` override the
        defaults when present.
    basis : str
        Basis set used when ``atoms.info`` has no ``'basis'`` entry.

    Returns
    -------
    gto.Mole
        Constructed from the keyword arguments only; ``build()`` is not
        called here.

    Notes
    -----
    If the first construction raises, it is retried with ``'STO-3G'`` as the
    default basis (an explicit ``info['basis']`` still takes precedence).
    """
    coordinates = atoms.positions
    symbols = atoms.get_chemical_symbols()
    geometry = [[symbol, xyz] for symbol, xyz in zip(symbols, coordinates)]

    def _make(default_basis):
        # info entries win over the defaults supplied by the caller.
        return gto.Mole(atom=geometry, basis=atoms.info.get('basis',default_basis),spin=atoms.info.get('spin',0))

    try:
        return _make(basis)
    except Exception:
        return _make('STO-3G')

def get_rhos(rho, spin):
    """Split a derivative-resolved density array into per-spin ingredients.

    Parameters
    ----------
    rho : array indexed as ``rho[x, y, ...]`` with ``x, y`` in ``0..3``.
        For ``spin == 0`` the trailing axis is the grid; for ``spin != 0``
        there is an extra spin axis of length 2 before the grid axis.
        (Presumably index 0 is the AO value and 1..3 its Cartesian
        derivatives -- confirm against the caller that builds ``rho``.)
    spin : int
        ``0`` selects the unpolarised branch, anything else the polarised one.

    Returns
    -------
    tuple
        ``(rho0_a, rho0_b, gamma_a, gamma_b, gamma_ab, tau_a, tau_b)``.
        In the unpolarised branch the density and tau are halved and all
        three gamma terms equal a quarter of the squared total gradient.
    """
    def _grid_dot(u, v):
        # Contract the Cartesian axis, keep the grid axis.
        return jnp.einsum('ij,ij->j', u, v)

    density = rho[0,0]
    gradient = rho[0,1:4] + rho[1:4,0]
    kinetic = 0.5*(rho[1,1] + rho[2,2] + rho[3,3])

    if spin == 0:
        half_density = density*0.5
        quarter_gamma = _grid_dot(gradient[:], gradient[:])*0.25
        half_kinetic = kinetic*0.5
        return (half_density, half_density,
                quarter_gamma, quarter_gamma, quarter_gamma,
                half_kinetic, half_kinetic)

    grad_up, grad_dn = gradient[:,0], gradient[:,1]
    tau_up, tau_dn = kinetic
    return (density[0], density[1],
            _grid_dot(grad_up, grad_up), _grid_dot(grad_dn, grad_dn),
            _grid_dot(grad_up, grad_dn),
            tau_up, tau_dn)
    
def get_data_synth(xc_func, n=100):
    """Generate synthetic training data on a grid of ``s`` (and ``a``) values.

    For every ``s`` on a log-spaced grid (plus ``s = 0``) a 6-column density
    array is built with unit density, ``c0*s`` in column 1 and a tau value in
    the last column. The reference enhancement factor is the ratio of
    ``xc_func`` to ``'LDA_X'`` (both evaluated by libxc, ``spin=0``) minus one.

    Parameters
    ----------
    xc_func : str
        libxc functional name. If it contains ``'MGGA'`` an additional
        ``a`` grid of ``n`` points is sampled per ``s``; otherwise ``a = 0``.
    n : int
        Number of log-spaced samples per grid.

    Returns
    -------
    tuple
        ``(tdrho[0], tFxc)`` where ``tdrho`` is what the module-level ``xc``
        object's ``get_descriptors`` returns and ``tFxc`` is a JAX array of
        reference enhancement factors.

    Notes
    -----
    Relies on the module-level names ``xc`` and ``dft``.
    """
    def get_rho(s, a):
        # Fills a (len(a), 6) array: column 0 is 1, column 1 is c0*s,
        # the last column is c1*a + (c0*s)**2/8.
        c0 = 2*(3*np.pi**2)**(1/3)
        c1 = 3/10*(3*np.pi**2)**(2/3)
        gamma = c0*s
        tau = c1*a+c0**2*s**2/8
        rho = np.zeros([len(a),6])
        rho[:, 1] = gamma
        rho[:,-1] = tau
        rho[:, 0] = 1
        return rho

    # jnp.concatenate does not accept bare Python lists, so the leading
    # s = 0 sample is supplied as an array.
    s_grid = jnp.concatenate([jnp.zeros(1), jnp.exp(jnp.linspace(-10,4,n))])
    is_mgga = 'MGGA' in xc_func
    rho = []
    for s in s_grid:
        if is_mgga:
            a_grid = jnp.exp(jnp.linspace(jnp.log((s/100)+1e-8),8,n))
        else:
            a_grid = jnp.array([0])
        rho.append(get_rho(s, a_grid))

    # Keep the density as a numpy array for libxc; it is converted to a JAX
    # array only afterwards, for the descriptor call.
    rho = np.concatenate(rho)

    fxc =  dft.numint.libxc.eval_xc(xc_func,rho.T, spin=0)[0]/dft.numint.libxc.eval_xc('LDA_X',rho.T, spin=0)[0] -1

    rho = jnp.asarray(rho)

    tdrho = xc.get_descriptors(rho[:,0]/2,rho[:,0]/2,(rho[:,1]/2)**2,(rho[:,1]/2)**2,(rho[:,1]/2)**2,rho[:,5]/2,rho[:,5]/2, spin_scaling=True)

    # torch is not imported in this notebook; return a JAX array instead.
    tFxc = jnp.asarray(fxc)
    # NOTE(review): get_data() uses the get_descriptors result as a whole
    # (samples, features) array, whereas this returns only index 0 of it --
    # confirm whether the [0] is still intended.
    return tdrho[0], tFxc

def get_data(mol, xc_func ,full=False, enhance_spin=False, localnet=None):
    """Run a PBE SCF for ``mol`` and return (descriptors, reference factors).

    Parameters
    ----------
    mol : PySCF molecule handed to ``dft.RKS``.
    xc_func : str
        Reference functional evaluated through ``mf._numint.eval_xc``.
    full : bool
        When False the grid is replaced by ``coords``/``weights`` (see the
        review note in the body -- those names are not defined here).
    enhance_spin : bool
        Accepted for interface compatibility; not used in this function.
    localnet : object with a ``spin_scaling`` attribute
        ``spin_scaling`` truthy  -> reference is ``xc_func / LDA_X - 1``
        evaluated per spin channel (other channel zeroed);
        falsy -> reference is ``xc_func / LDA_C_PW - 1`` on both channels.

    Returns
    -------
    tuple
        ``(tdrho, tFxc)``: descriptor rows from the module-level
        ``xc.get_descriptors`` and the matching reference values, both
        restricted to grid points whose density exceeds ``1e-6``.

    Raises
    ------
    ValueError
        If ``localnet`` is not supplied.

    Notes
    -----
    Relies on the module-level names ``xc``, ``dft`` and ``get_rhos``.
    """
    if localnet is None:
        # Fail before the (expensive) SCF instead of with an AttributeError after it.
        raise ValueError('get_data() requires `localnet`; its `spin_scaling` attribute selects the exchange or correlation reference')
    spin_scaling = localnet.spin_scaling

    print('mol: ', mol.atom)
    mf = dft.RKS(mol)
    print(mf)   
    mf.xc = 'PBE'
    mf.grids.level = 1
    mf.kernel()
    ao = mf._numint.eval_ao(mol, mf.grids.coords, deriv=2)
    # Spin-resolved density matrix used only for the libxc reference below.
    dm_spin = mf.make_rdm1()
    if len(dm_spin.shape) == 2:
        dm_spin = np.array([0.5*dm_spin, 0.5*dm_spin])
    print('ao.shape', ao.shape)
    if not full:
        # NOTE(review): `coords` and `weights` are not defined in this
        # function; unless they exist as notebook globals this path raises
        # NameError. Every visible call passes full=True.
        mf.grids.coords = coords
        mf.grids.weights = weights

    # Reference values are sampled per spin only for open-shell systems with
    # more than one electron.
    polarized = mol.spin != 0 and sum(mol.nelec) > 1

    if spin_scaling:
        print('spin scaling, indicates exchange network')
    else:
        print('no spin scaling, indicates correlation network')
    rho_alpha = mf._numint.eval_rho(mol, ao, dm_spin[0], xctype='metaGGA',hermi=True)
    rho_beta = mf._numint.eval_rho(mol, ao, dm_spin[1], xctype='metaGGA',hermi=True)

    if spin_scaling:
        fxc_a =  mf._numint.eval_xc(xc_func,(rho_alpha,rho_alpha*0), spin=1)[0]/mf._numint.eval_xc('LDA_X',(rho_alpha,rho_alpha*0), spin=1)[0] -1
        fxc_b =  mf._numint.eval_xc(xc_func,(rho_beta*0,rho_beta), spin=1)[0]/mf._numint.eval_xc('LDA_X',(rho_beta*0,rho_beta), spin=1)[0] -1
        print('fxc with xc_func = {} = {}'.format(xc_func, fxc_a))
        fxc = jnp.concatenate([fxc_a, fxc_b]) if polarized else fxc_a
    else:
        exc = mf._numint.eval_xc(xc_func,(rho_alpha,rho_beta), spin=1)[0]
        print('exc with xc_func = {} = {}'.format(xc_func, exc))
        fxc = exc/mf._numint.eval_xc('LDA_C_PW',(rho_alpha, rho_beta), spin=1)[0] -1

    # Density (with first derivatives) rebuilt with JAX from the SCF density
    # matrix; this is what the descriptors are computed from.
    dm = jnp.array(mf.make_rdm1())
    print('get_data, dm shape = {}'.format(dm.shape))
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=1))
    rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)
    print('rho shape', rho.shape)
    if dm.ndim == 3:
        # Spin-resolved density matrix: threshold on the sum over the spin axis.
        rho_filt = (jnp.sum(rho[0,0],axis=0) > 1e-6)
    else:
        rho_filt = (rho[0,0] > 1e-6)
    print('rho_filt shape:', rho_filt.shape)

    mf.converged=True
    tdrho = xc.get_descriptors(*get_rhos(rho, spin=mol.spin), spin_scaling=spin_scaling, mf=mf, dm=dm)
    print('tdrho shape from get_descriptors', tdrho.shape)

    if spin_scaling and polarized:
        #tdrho not returned in a spin-polarized form regardless,
        #but the enhancement factors sampled as polarized, so double
        tdrho = jnp.concatenate([tdrho,tdrho])
        rho_filt = jnp.concatenate([rho_filt]*2)

    tdrho = tdrho[rho_filt]
    print('tdrho[rho_filt], rho_filt shapes:', tdrho.shape, rho_filt.shape)
    print('fxc shape', fxc.shape)
    tFxc = jnp.array(fxc)[rho_filt]
    print('returned tdrho, tFxc shape: ', tdrho.shape, tFxc.shape)
    return tdrho, tFxc

In [4]:
# --- Pretraining configuration -------------------------------------------
# PRETRAIN_LEVEL: one of 'GGA', 'MGGA', 'NONLOCAL' (selects inputs and level).
# TRAIN_NET:      'x' builds an eX network, 'c' builds an eC network.
PRETRAIN_LEVEL = 'NONLOCAL'
TRAIN_NET = 'x'
REFERENCE_XC = 'PBE0'

# Network size shared by every variant below.
N_HIDDEN = 16
DEPTH = 3

# Build only the network selected by (TRAIN_NET, PRETRAIN_LEVEL).
if TRAIN_NET == 'x':
    if PRETRAIN_LEVEL == 'GGA':
        localx = xce.net.eX(n_input=1, n_hidden=N_HIDDEN, use=[1], depth=DEPTH, lob=1.804)
    elif PRETRAIN_LEVEL == 'MGGA':
        localx = xce.net.eX(n_input=2, n_hidden=N_HIDDEN, use=[1, 2], depth=DEPTH, ueg_limit=True, lob=1.174)
    elif PRETRAIN_LEVEL == 'NONLOCAL':
        localx = xce.net.eX(n_input=15, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True, lob=1.174)
    thislocal = localx
else:
    if TRAIN_NET == 'c':
        if PRETRAIN_LEVEL == 'GGA':
            localc = xce.net.eC(n_input=3, n_hidden=N_HIDDEN, use=[2], depth=DEPTH, ueg_limit=True)
        elif PRETRAIN_LEVEL == 'MGGA':
            localc = xce.net.eC(n_input=4, n_hidden=N_HIDDEN, depth=DEPTH, use=[2,3], ueg_limit=True)
        elif PRETRAIN_LEVEL == 'NONLOCAL':
            #n_input = 4 from base, 12 from NL
            localc = xce.net.eC(n_input=16, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True)
    thislocal = localc

ueg = xce.xc.LDA_X()
# The functional wrapper's integer level is looked up from the pretraining level.
xc = xce.xc.eXC(grid_models=[thislocal], heg_mult=True, level= {'GGA':2, 'MGGA':3, 'NONLOCAL':4}[PRETRAIN_LEVEL])



In [5]:
# Per-element spin values. Not referenced by the other visible cells; the
# single-atom entries below set their spin explicitly through `info`.
spins = {
    'Al': 1,
    'B' : 1,
    'Li': 1,
    'Na': 1,
    'Si': 2 ,
    'Be':0,
    'C': 2,
    'Cl': 1,
    'F': 1,
    'H': 1,
    'N': 3,
    'O': 2,
    'P': 3,
    'S': 2
}

# Frame indices taken from the g2_97 trajectory.
selection = [2, 113, 25, 18, 11, 17, 114, 121, 101, 0, 20, 26, 29, 67, 28, 110, 125, 10, 115, 89, 105, 50]

# Candidate locations of the trajectory, tried in order. The first existing
# file is read ONCE and then indexed, instead of re-reading the whole file for
# every selected frame behind a bare `except`.
G2_TRAJ_CANDIDATES = [
    '/home/awills/Documents/Research/ogdpyscf/dpyscf/data/haunschild_g2/g2_97.traj',
    '/home/awills/Documents/Research/ogdpyscf/data/haunschild_g2/g2_97.traj',
]
g2_traj_path = next((path for path in G2_TRAJ_CANDIDATES if os.path.isfile(path)), None)
if g2_traj_path is None:
    raise FileNotFoundError('g2_97.traj not found in any of: {}'.format(G2_TRAJ_CANDIDATES))
g2_frames = read(g2_traj_path, ':')
atoms = [g2_frames[s] for s in selection]

ksr_atoms = atoms
if PRETRAIN_LEVEL=='MGGA':
    ksr_atoms = ksr_atoms[2:]
# Prepend single atoms with explicit spin to the molecular set.
ksr_atoms = [Atoms('P',info={'spin':3}), Atoms('N', info={'spin':3}), Atoms('H', info={'spin':1}),Atoms('Li', info={'spin':1}), Atoms('O',info={'spin':2}),Atoms('Cl',info={'spin':1}),Atoms('Al',info={'spin':1}), Atoms('S',info={'spin':2})] + ksr_atoms
# ksr_atoms = [Atoms('H',info={'spin':1})]
mols = [get_mol(entry) for entry in ksr_atoms]
# Keep only systems with fewer than 8 atoms.
mols = [m for m in mols if len(m.atom) < 8]
for i in mols:
    print(i, i.atom, len(i.atom))

<pyscf.gto.mole.Mole object at 0x7fed050732b0> [['P', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed05073790> [['N', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed050718a0> [['H', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed05070280> [['Li', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed05071b40> [['O', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed05073430> [['Cl', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed050715a0> [['Al', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed05073100> [['S', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7fed050726e0> [['H', array([0.      , 0.      , 0.371395])], ['H', array([ 0.      ,  0.      , -0.371395])]] 2
<pyscf.gto.mole.Mole object at 0x7fed05073640> [['N', array([0.      , 0.      , 0.549396])], ['N', array([ 0.      ,  0.      , -0.549396])]] 2
<pyscf.gto.mole.Mole object at 0x7fed050738b0> [['Li', array([ 0.      ,  0.      ,

In [6]:
# Reference data for every molecule, evaluated on the full grid.
# The reference functional comes from the configuration cell rather than
# being hard-coded a second time here.
# data = [get_data(mol, xc_func=ref, full=i<14) for i,mol in enumerate(mols)]
ref = REFERENCE_XC
data = [get_data(mol, xc_func=ref,full=True, localnet=thislocal) for mol in mols]
# 
# data = [get_data_synth(ref, 100)]
# Stack the per-molecule (descriptors, targets) pairs along the sample axis.
tdrho = jnp.concatenate([d[0] for d in data])
tFxc = jnp.concatenate([d[1] for d in data])

mol:  [['P', array([0., 0., 0.])]]
ROKS object of <class 'pyscf.dft.roks.ROKS'>


Initialize <pyscf.gto.mole.Mole object at 0x7fed050732b0> in ROKS object of <class 'pyscf.dft.roks.ROKS'>


converged SCF energy = -341.103975305829
ao.shape (10, 6320, 30)
spin scaling, indicates exchange network
fxc with xc_func = [ 0.35227731  0.35296584  0.33914136 ... -0.24267484 -0.24267484
 -0.24267484] = PBE0
get_data, dm shape = (2, 30, 30)
rho shape (4, 4, 2, 6320)
rho_filt shape: (6320,)
Shape mo_coeff present, will pad vele_mat
Initial vele_mat shape:  (30, 30, 6320)
shape_mo_coeff shape:  (30, 30)
Will try to pad to:  [(0, 0), (0, 0), (0, 0)]
get_vele_mat, returned shape: (30, 30, 6320)
get_vele_mat, points shape: (6320, 3)
Evaluating MGGA Data; shapes: 
grid.coords = (6320, 3)
generated ao_data = (20, 6320, 30)
rdm1 input shape: (2, 30, 30)
ao_data shape will be padded to match rdm1
paddeds ao_data = (20, 6320, 30)
nl_4, descr5 shape: (2, 12, 6320)
descr shape:  (6320, 6)
descr5 shape:  (6320, 12)
tdrho shape from get_descriptors (6320, 18)
tdrho[rho_filt], rho_filt shapes: (10940, 18) (12640,)
fxc shape (12640,)
returned tdrho, tFxc shape:  (10940, 18) (10940,)
mol:  [['N', ar

Initialize <pyscf.gto.mole.Mole object at 0x7fed05073790> in ROKS object of <class 'pyscf.dft.roks.ROKS'>


converged SCF energy = -54.5278183883327
ao.shape (10, 5184, 22)
spin scaling, indicates exchange network
fxc with xc_func = [ 0.35208183  0.35288518  0.35299592 ... -0.2364243  -0.2364243
 -0.2364243 ] = PBE0
get_data, dm shape = (2, 22, 22)
rho shape (4, 4, 2, 5184)
rho_filt shape: (5184,)


  fxc_b =  mf._numint.eval_xc(xc_func,(rho_beta*0,rho_beta), spin=1)[0]/mf._numint.eval_xc('LDA_X',(rho_beta*0,rho_beta), spin=1)[0] -1


Shape mo_coeff present, will pad vele_mat
Initial vele_mat shape:  (22, 22, 5184)
shape_mo_coeff shape:  (22, 22)
Will try to pad to:  [(0, 0), (0, 0), (0, 0)]
get_vele_mat, returned shape: (22, 22, 5184)
get_vele_mat, points shape: (5184, 3)
Evaluating MGGA Data; shapes: 
grid.coords = (5184, 3)
generated ao_data = (20, 5184, 22)
rdm1 input shape: (2, 22, 22)
ao_data shape will be padded to match rdm1
paddeds ao_data = (20, 5184, 22)
nl_4, descr5 shape: (2, 12, 5184)
descr shape:  (5184, 6)
descr5 shape:  (5184, 12)
tdrho shape from get_descriptors (5184, 18)
tdrho[rho_filt], rho_filt shapes: (8668, 18) (10368,)
fxc shape (10368,)
returned tdrho, tFxc shape:  (8668, 18) (8668,)
mol:  [['H', array([0., 0., 0.])]]
ROKS object of <class 'pyscf.dft.roks.ROKS'>


Initialize <pyscf.gto.mole.Mole object at 0x7fed050718a0> in ROKS object of <class 'pyscf.dft.roks.ROKS'>


converged SCF energy = -0.49981298402386
ao.shape (10, 2488, 7)
spin scaling, indicates exchange network
fxc with xc_func = [ 0.35187063  0.35279876  0.35299579 ... -0.19767117 -0.19767117
 -0.19767117] = PBE0
get_data, dm shape = (2, 7, 7)
rho shape (4, 4, 2, 2488)
rho_filt shape: (2488,)
Shape mo_coeff present, will pad vele_mat
Initial vele_mat shape:  (7, 7, 2488)
shape_mo_coeff shape:  (7, 7)
Will try to pad to:  [(0, 0), (0, 0), (0, 0)]
get_vele_mat, returned shape: (7, 7, 2488)
get_vele_mat, points shape: (2488, 3)
Evaluating MGGA Data; shapes: 
grid.coords = (2488, 3)
generated ao_data = (20, 2488, 7)
rdm1 input shape: (2, 7, 7)
ao_data shape will be padded to match rdm1
paddeds ao_data = (20, 2488, 7)
nl_4, descr5 shape: (2, 12, 2488)
descr shape:  (2488, 6)
descr5 shape:  (2488, 12)
tdrho shape from get_descriptors (2488, 18)
tdrho[rho_filt], rho_filt shapes: (2144, 18) (2488,)
fxc shape (2488,)
returned tdrho, tFxc shape:  (2144, 18) (2144,)
mol:  [['Li', array([0., 0., 0.])

Initialize <pyscf.gto.mole.Mole object at 0x7fed05070280> in ROKS object of <class 'pyscf.dft.roks.ROKS'>


ROKS object of <class 'pyscf.dft.roks.ROKS'>
converged SCF energy = -7.46006139980431
ao.shape (10, 4640, 22)
spin scaling, indicates exchange network
fxc with xc_func = [0.34328499 0.3516925  0.3516925  ... 0.3516925  0.2854847  0.32157416] = PBE0
get_data, dm shape = (2, 22, 22)
rho shape (4, 4, 2, 4640)
rho_filt shape: (4640,)
Shape mo_coeff present, will pad vele_mat
Initial vele_mat shape:  (22, 22, 4640)
shape_mo_coeff shape:  (22, 22)
Will try to pad to:  [(0, 0), (0, 0), (0, 0)]
get_vele_mat, returned shape: (22, 22, 4640)
get_vele_mat, points shape: (4640, 3)
Evaluating MGGA Data; shapes: 
grid.coords = (4640, 3)
generated ao_data = (20, 4640, 22)
rdm1 input shape: (2, 22, 22)
ao_data shape will be padded to match rdm1
paddeds ao_data = (20, 4640, 22)
nl_4, descr5 shape: (2, 12, 4640)
descr shape:  (4640, 6)
descr5 shape:  (4640, 12)
tdrho shape from get_descriptors (4640, 18)
tdrho[rho_filt], rho_filt shapes: (8600, 18) (9280,)
fxc shape (9280,)
returned tdrho, tFxc shape:  (

Initialize <pyscf.gto.mole.Mole object at 0x7fed05071b40> in ROKS object of <class 'pyscf.dft.roks.ROKS'>


ROKS object of <class 'pyscf.dft.roks.ROKS'>



KeyboardInterrupt



In [None]:
# Drop samples with a NaN in any descriptor column OR a non-finite target.
# The targets are ratios of libxc outputs, so they can be NaN/inf where the
# denominator vanishes; a single such value would make the mean-squared loss NaN.
nan_filt = (~jnp.any(jnp.isnan(tdrho), axis=-1)) & jnp.isfinite(tFxc)

tFxc = tFxc[nan_filt]
tdrho = tdrho[nan_filt,:]

In [None]:
# Quick look at the filtered training arrays: shapes and contents.
tdrho.shape, tdrho, tFxc.shape, tFxc

In [None]:
# Scratch check: summing three all-zero (before, after) pairs gives 0.
np.sum([(0, 0)] * 3)

In [None]:
# CPU devices known to JAX; training below is pinned to the first one.
cpus = jax.devices('cpu')

In [None]:
PRINT_EVERY = 200

# Adam driven by an exponentially decaying learning-rate schedule.
scheduler = optax.exponential_decay(
    init_value=1e-2,
    transition_begin=50,
    transition_steps=1000,
    decay_rate=0.9,
)
# Constant-rate alternative: optax.adam(learning_rate = 1e-2)
optimizer = optax.adam(learning_rate=scheduler)

class PT_E_Loss():
    """Mean-squared-error pretraining loss on the first network output."""

    def __call__(self, model, inp, ref):
        """Return the mean squared difference between prediction and ``ref``.

        ``model.net`` is mapped over the leading axis of ``inp`` with
        ``jax.vmap``; column 0 of the result is compared against ``ref``.
        """
        per_sample_net = jax.vmap(model.net)
        residual = per_sample_net(inp)[:, 0] - ref
        return jnp.mean(jnp.square(residual))

# Wrap the selected network, optimizer and loss in the xcquinox trainer.
trainer = xce.train.xcTrainer(
    model=thislocal,
    optim=optimizer,
    steps=100000,
    loss=PT_E_Loss(),
    do_jit=True,
)
# For the 'x' network only the descriptor columns listed in `model.use`
# are fed in; otherwise every descriptor column is used.
inp = [tdrho[:, trainer.model.use]] if TRAIN_NET == 'x' else [tdrho]
# Run the training on the first CPU device.
with jax.default_device(cpus[0]):
    newm = trainer(1, trainer.model, inp, [tFxc])
        

# for epoch in range(100000):
#     total_loss = 0
#     results = thislocal(tdrho[::])
#     loss = eloss(results, tFxc[::])
#     total_loss += loss.item()
#     loss.backward()

#     optimizer.step()
#     optimizer.zero_grad()
#     if epoch%PRINT_EVERY==0:
#         print('total loss {:.12f}'.format(total_loss))
        


In [None]:
# Scratch check: shape of the first output of `thislocal.net` when mapped
# over descriptor columns 1 and 2.
jax.vmap(thislocal.net)(tdrho[:, 1:3])[:, 0].shape

In [None]:
# Scratch check: raw output of the trainer's network when mapped over
# descriptor columns 1 and 2.
jax.vmap(trainer.model.net)(tdrho[:, 1:3])