In [64]:
from pyscf import gto,dft,scf
import pickle
import numpy as np
import jax.numpy as jnp
from ase import Atoms
from ase.io import read
import xcquinox as xce
import equinox as eqx
import os, optax, jax, argparse
import faulthandler
import pandas as pd
faulthandler.enable()
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

In [59]:
def get_mol(atoms, basis='6-311++G**'):
    """Convert an ASE ``Atoms`` object into an (unbuilt) PySCF ``Mole``.

    ``atoms.info`` may override the basis (key ``'basis'``) and the spin
    (key ``'spin'``, passed as ``Mole.spin``; default 0). If constructing the
    ``Mole`` raises, construction is retried with ``'STO-3G'`` as the default
    basis (an explicit ``atoms.info['basis']`` still wins).

    NOTE(review): ``gto.Mole(...)`` only stores attributes and does not build
    the molecule. Basis/spin problems therefore usually surface later, when the
    molecule is built (e.g. inside ``scf.UKS``), so this fallback may rarely
    trigger. Confirm the intent.
    """
    geometry = [[symbol, xyz] for symbol, xyz in zip(atoms.get_chemical_symbols(), atoms.positions)]

    def _make(default_basis):
        return gto.Mole(atom=geometry,
                        basis=atoms.info.get('basis', default_basis),
                        spin=atoms.info.get('spin', 0))

    try:
        return _make(basis)
    except Exception:
        return _make('STO-3G')

def get_rhos(rho, spin):
    """Split a stacked density array into per-spin rho, gamma and tau pieces.

    ``rho[x, y, ...]`` holds products of AO-value/derivative components
    ``x, y`` in 0..3 contracted with a density matrix (layout produced by the
    einsum in ``get_data``). When ``spin != 0`` the next axis holds the two
    spin channels. When ``spin == 0`` it is absent, and the closed-shell
    quantities are split evenly between the channels.

    Returns
    -------
    tuple
        ``(rho_a, rho_b, gamma_aa, gamma_bb, gamma_ab, tau_a, tau_b)``
    """
    dens = rho[0, 0]
    grad = rho[0, 1:4] + rho[1:4, 0]
    kin = 0.5 * (rho[1, 1] + rho[2, 2] + rho[3, 3])

    if spin == 0:
        # Closed shell: each channel gets half of rho and tau, and a quarter of |grad rho|^2.
        half_dens = dens * 0.5
        sigma = jnp.einsum('ij,ij->j', grad[:], grad[:]) * 0.25
        half_kin = kin * 0.5
        return half_dens, half_dens, sigma, sigma, sigma, half_kin, half_kin

    grad_a = grad[:, 0]
    grad_b = grad[:, 1]
    g_aa = jnp.einsum('ij,ij->j', grad_a, grad_a)
    g_bb = jnp.einsum('ij,ij->j', grad_b, grad_b)
    g_ab = jnp.einsum('ij,ij->j', grad_a, grad_b)
    tau_a, tau_b = kin
    return dens[0], dens[1], g_aa, g_bb, g_ab, tau_a, tau_b
    
def get_data_synth(xcmodel, xc_func, n=100, mf=None, dm=None):
    """Build synthetic (descriptor, enhancement-factor) pairs on an (s, alpha) grid.

    Every sample has unit density. The reduced gradient ``s`` runs over
    ``[0] + exp(linspace(-10, 4, n))``, and the density gradient is set to
    ``c0 * s`` with ``c0 = 2 (3 pi^2)^(1/3)``. For meta-GGA functionals
    (``'MGGA' in xc_func``) a second parameter ``a`` runs over
    ``exp(linspace(log(s/100 + 1e-8), 8, n))`` and
    ``tau = c1 * a + c0^2 s^2 / 8`` with ``c1 = 3/10 (3 pi^2)^(2/3)``.
    Otherwise a single ``a = 0`` is used.

    Parameters
    ----------
    xcmodel : object providing ``get_descriptors``
    xc_func : str
        Functional name passed to PySCF's libxc wrapper.
    n : int
        Number of grid points along each log-spaced axis.
    mf, dm : optional
        Forwarded to ``xcmodel.get_descriptors`` (default None: no SCF context).
        NOTE(review): confirm get_descriptors accepts None for these.

    Returns
    -------
    (descriptors, fxc)
        ``tdrho[0]`` from ``get_descriptors`` and
        ``eps_xc[xc_func] / eps_x[LDA_X] - 1`` as a jax array.
    """
    c0 = 2 * (3 * np.pi**2) ** (1 / 3)
    c1 = 3 / 10 * (3 * np.pi**2) ** (2 / 3)

    def get_rho(s, a):
        # Columns follow PySCF's meta-GGA layout (rho, dx, dy, dz, laplacian, tau);
        # the laplacian column is left at zero.
        rho = np.zeros([len(a), 6])
        rho[:, 0] = 1
        rho[:, 1] = c0 * s
        rho[:, -1] = c1 * a + c0**2 * s**2 / 8
        return rho

    s_grid = np.concatenate([[0], np.exp(np.linspace(-10, 4, n))])
    rows = []
    for s in s_grid:
        if 'MGGA' in xc_func:
            a_grid = np.exp(np.linspace(np.log((s / 100) + 1e-8), 8, n))
        else:
            a_grid = np.array([0])
        rows.append(get_rho(s, a_grid))
    # Keep the libxc input as a float64 numpy array.
    rho = np.concatenate(rows)

    exc = dft.numint.libxc.eval_xc(xc_func, rho.T, spin=0)[0]
    exc_lda = dft.numint.libxc.eval_xc('LDA_X', rho.T, spin=0)[0]
    fxc = exc / exc_lda - 1

    rho = jnp.asarray(rho)
    # Closed shell: each spin channel carries half the density, gradient and tau.
    half_rho = rho[:, 0] / 2
    half_grad_sq = (rho[:, 1] / 2) ** 2
    half_tau = rho[:, 5] / 2
    tdrho = xcmodel.get_descriptors(half_rho, half_rho,
                                    half_grad_sq, half_grad_sq, half_grad_sq,
                                    half_tau, half_tau,
                                    spin_scaling=True, mf=mf, dm=dm)

    tFxc = jnp.array(fxc)
    return tdrho[0], tFxc

def get_data(mol, xcmodel, xc_func, localnet=None):
    """Run a PBE SCF on ``mol`` and build (descriptors, target) pairs on its grid.

    The target depends on ``localnet.spin_scaling``:

    * True (exchange network): ``eps[xc_func] / eps[LDA_X] - 1`` evaluated on
      each spin channel as if it were fully polarized;
    * False (correlation network): ``eps[xc_func] / eps[LDA_C_PW] - 1`` on the
      (alpha, beta) density pair.

    Descriptors come from ``xcmodel.get_descriptors`` on the same grid. Grid
    points whose total density is <= 1e-6 are dropped from both outputs.

    Parameters
    ----------
    mol : pyscf.gto.Mole
    xcmodel : object providing ``get_descriptors``
    xc_func : str
        Reference functional name understood by PySCF/libxc.
    localnet : object with a ``spin_scaling`` attribute
        Required despite the ``None`` default (kept for call compatibility).

    Returns
    -------
    (tdrho, tFxc) or None
        Filtered descriptors and targets, or None if the SCF kernel raised.

    Raises
    ------
    ValueError
        If ``localnet`` is None (checked before any SCF work).
    """
    if localnet is None:
        raise ValueError('get_data needs localnet: its spin_scaling flag selects '
                         'the exchange or correlation target')
    print('mol: ', mol.atom)
    try:
        mf = scf.UKS(mol)
    except Exception:
        mf = dft.RKS(mol)
    mf.xc = 'PBE'
    mf.grids.level = 1
    try:
        mf.kernel()
    except Exception as e:
        print('ERROR IN CALCULATION')
        print(e)
        print('SKIPPING MOLECULE: ', mol)
        return None
    ao = mf._numint.eval_ao(mol, mf.grids.coords, deriv=2)
    dm = mf.make_rdm1()
    if len(dm.shape) == 2:
        # restricted DM: artificially spin-polarize by splitting it evenly
        dm = np.array([0.5*dm, 0.5*dm])
    print('New DM shape: {}'.format(dm.shape))
    print('ao.shape', ao.shape)

    if localnet.spin_scaling:
        print('spin scaling, indicates exchange network')
        rho_alpha = mf._numint.eval_rho(mol, ao, dm[0], xctype='metaGGA', hermi=True)
        rho_beta = mf._numint.eval_rho(mol, ao, dm[1], xctype='metaGGA', hermi=True)
        # each channel is evaluated with the other channel zeroed (fully polarized)
        fxc_a = (mf._numint.eval_xc(xc_func, (rho_alpha, rho_alpha*0), spin=1)[0]
                 / mf._numint.eval_xc('LDA_X', (rho_alpha, rho_alpha*0), spin=1)[0] - 1)
        fxc_b = (mf._numint.eval_xc(xc_func, (rho_beta*0, rho_beta), spin=1)[0]
                 / mf._numint.eval_xc('LDA_X', (rho_beta*0, rho_beta), spin=1)[0] - 1)
        print('fxc with xc_func = {} = {}'.format(xc_func, fxc_a))
        print(f'rho_a.shape={rho_alpha.shape}, rho_b.shape={rho_beta.shape}')
        print(f'fxc_a.shape={fxc_a.shape}, fxc_b.shape={fxc_b.shape}')

        if mol.spin != 0 and sum(mol.nelec) > 1:
            print('mol.spin != 0 and sum(mol.nelec) > 1')
            rho = jnp.concatenate([rho_alpha, rho_beta], axis=-1)
            fxc = jnp.concatenate([fxc_a, fxc_b])
            print(f'rho.shape={rho.shape}, fxc.shape={fxc.shape}')
        else:
            print('NOT (mol.spin != 0 and sum(mol.nelec) > 1)')
            rho = rho_alpha
            fxc = fxc_a
            print(f'rho.shape={rho.shape}, fxc.shape={fxc.shape}')
    else:
        print('no spin scaling, indicates correlation network')
        rho_alpha = mf._numint.eval_rho(mol, ao, dm[0], xctype='metaGGA', hermi=True)
        rho_beta = mf._numint.eval_rho(mol, ao, dm[1], xctype='metaGGA', hermi=True)
        exc = mf._numint.eval_xc(xc_func, (rho_alpha, rho_beta), spin=1)[0]
        print('exc with xc_func = {} = {}'.format(xc_func, exc))
        fxc = exc/mf._numint.eval_xc('LDA_C_PW', (rho_alpha, rho_beta), spin=1)[0] - 1
        rho = jnp.stack([rho_alpha, rho_beta], axis=-1)

    # The descriptors are built from a density recomputed below from AO values
    # (deriv=1) and the SCF density matrix; the `rho` above is only logged.
    # NOTE(review): for the RKS fallback make_rdm1() is 2-D, yet get_rhos is
    # called with spin=1 below — confirm that path.
    dm = jnp.array(mf.make_rdm1())
    print('get_data, dm shape = {}'.format(dm.shape))
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=1))
    print(f'ao_eval.shape={ao_eval.shape}')
    rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)
    rho0 = rho[0, 0]

    print('rho shape', rho.shape)
    if dm.ndim == 3:
        rho_filt = (jnp.sum(rho0, axis=0) > 1e-6)
    else:
        rho_filt = (rho0 > 1e-6)
    print('rho_filt shape:', rho_filt.shape)

    # NOTE(review): presumably marks the SCF as finished for get_descriptors — confirm.
    mf.converged = True
    tdrho = xcmodel.get_descriptors(*get_rhos(rho, spin=1), spin_scaling=localnet.spin_scaling, mf=mf, dm=dm)
    print(f'get descriptors tdrho.shape={tdrho.shape}')
    if localnet.spin_scaling:
        if mol.spin != 0 and sum(mol.nelec) > 1:
            print('mol.spin != 0 and sum(mol.nelec) > 1')
            # tdrho is not returned in a spin-polarized form regardless,
            # but the enhancement factors were sampled per channel, so double
            if len(tdrho.shape) == 3:
                print('concatenating spin channels along axis=0')
                tdrho = jnp.concatenate([tdrho[0], tdrho[1]], axis=0)
            else:
                print('concatenating along axis=0')
                tdrho = jnp.concatenate([tdrho, tdrho], axis=0)
            rho_filt = jnp.concatenate([rho_filt]*2)
            print(f'tdrho.shape={tdrho.shape}')
            print(f'rho_filt.shape={rho_filt.shape}')
        else:
            # spin == 0 or hydrogen
            tdrho = tdrho[0]

    # Filter the grid axis explicitly instead of relying on a bare except:
    # it is axis 0 unless the descriptors still carry a leading (spin) axis.
    if tdrho.shape[0] == rho_filt.shape[0]:
        tdrho = tdrho[rho_filt]
    else:
        tdrho = tdrho[:, rho_filt, :]
    tFxc = jnp.array(fxc)[rho_filt]
    return tdrho, tFxc
# Descriptor-level name -> integer level code (same values gen_network_model passes to xce.xc.eXC).
level_dict = dict(GGA=2, MGGA=3, NONLOCAL=4)
# Descriptor-level name -> `lob` value for exchange networks.
# NOTE(review): 1.804 / 1.174 match the PBE (1 + kappa) and SCAN (h0x) exchange bounds — presumably intentional.
x_lob_level_dict = dict(GGA=1.804, MGGA=1.174, NONLOCAL=1.174)

class PT_E_Loss(eqx.Module):
    """Mean squared error between per-grid-point network outputs and reference values.

    Called as ``PT_E_Loss()(model, inp, ref)``. ``model.net`` is mapped over the
    descriptor rows of ``inp``, and the first output component is compared with ``ref``.
    """

    def __call__(self, model, inp, ref):
        spin_stacked = model.spin_scaling and len(inp.shape) == 3
        if spin_stacked:
            # inp is (2, N, len(use)): map over grid points (axis 1), then over the
            # spin axis within each point.
            # NOTE(review): this yields (N, 2, n_out), and [:, 0] keeps (N, n_out). With
            # n_out == 1 and ref of shape (N,), `pred - ref` broadcasts to (N, N) —
            # confirm the intended reduction.
            apply_net = jax.vmap(jax.vmap(model.net), in_axes=1)
        else:
            apply_net = jax.vmap(model.net)
        pred = apply_net(inp)[:, 0]
        return jnp.mean(jnp.square(pred - ref))

def get_model_info(xcdir, model_dir, tlogf = 'ptlog.dat'):
    """Parse a pretrained-network directory name and pick its best checkpoint.

    ``model_dir`` must look like ``xorc_depth_nodes_level[_lr2[_use]]``:
    ``xorc`` is 'x' or 'c', ``level`` is 'mgga' or 'nl', and a non-empty sixth
    token selects the alternative nonlocal input set.

    Parameters
    ----------
    xcdir : str
        Directory containing the network directories. Its basename is reported
        as the reference functional.
    model_dir : str
        Name of the network directory inside ``xcdir``.
    tlogf : str or None
        Name of the tab-separated training log inside ``xcdir/model_dir``,
        with ``'#Epoch'`` and ``'Loss'`` columns. The checkpoint of the
        lowest-loss epoch is chosen. If ``tlogf`` is falsy, or the log is
        missing/unreadable or names an unsaved epoch, the last checkpoint is used.

    Returns
    -------
    tuple or None
        ``(refxc, xorc, depth, nodes, use, n_input, LEVEL, checkpoint_name)``,
        or None (with a message printed) when the name cannot be parsed or the
        directory holds no ``xc.eqx`` checkpoints.
    """
    refxc = os.path.basename(os.path.normpath(xcdir))

    # (xorc, level, alternative-use?) -> (n_input, use indices)
    def_nl_x_use = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
    use2_nl_x_use = [1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    use2_nl_c_use = [0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    input_specs = {
        ('x', 'mgga', False): (2, [1, 2]),
        ('c', 'mgga', False): (4, []),
        ('x', 'nl', False): (len(def_nl_x_use), def_nl_x_use),
        ('c', 'nl', False): (16, []),
        ('x', 'nl', True): (len(use2_nl_x_use), use2_nl_x_use),
        ('c', 'nl', True): (len(use2_nl_c_use), use2_nl_c_use),
    }

    # The optional lr2 token only marks a different learning-rate schedule.
    nd_split = model_dir.split('_')
    if len(nd_split) not in (4, 5, 6):
        print(f'Unrecognised network directory name: {model_dir}')
        return None
    xorc, depth, nodes, level = nd_split[:4]
    use = nd_split[5] if len(nd_split) == 6 else ''
    # The alternative input selection only exists for nonlocal networks.
    spec = input_specs.get((xorc, level, bool(use) and level == 'nl'))
    if spec is None or not (depth.isdigit() and nodes.isdigit()):
        print(f'Unrecognised network directory name: {model_dir}')
        return None
    rinp, ruse = spec

    model_path = os.path.join(xcdir, model_dir)
    xcs = sorted((f for f in os.listdir(model_path) if 'xc.eqx' in f),
                 key=lambda f: int(f.split('.')[-1]))
    if not xcs:
        print('No networks in directory')
        return None

    xcf = xcs[-1]
    if tlogf:
        try:
            loss = pd.read_csv(os.path.join(model_path, tlogf), delimiter='\t')
            epoch_min = loss.loc[loss['Loss'] == loss['Loss'].min(), '#Epoch'].values[0]
            xcf = [f for f in xcs if int(f.split('.')[-1]) == epoch_min][0]
        except (OSError, KeyError, IndexError, ValueError):
            # Missing/unparsable log, or best epoch not checkpointed: keep the last checkpoint.
            xcf = xcs[-1]

    rlevel = 'NONLOCAL' if level == 'nl' else level.upper()
    return (refxc, xorc, int(depth), int(nodes), ruse, int(rinp), rlevel, xcf)

def gen_network_model(xorc, depth, nodes, ninp, use, level='MGGA', ptpath = None, genverbose=False):
    """Build an xcquinox exchange or correlation network and wrap it in an eXC model.

    Parameters
    ----------
    xorc : {'x', 'c'}
        'x' builds ``xce.net.eX`` (with a level-dependent ``lob``); 'c' builds ``xce.net.eC``.
    depth, nodes, ninp : int
        Forwarded as ``depth``, ``n_hidden`` and ``n_input``.
    use : list of int
        Forwarded as the network's ``use`` argument.
    level : {'GGA', 'MGGA', 'NONLOCAL'}
        Selects the ``lob`` value (exchange only) and the eXC ``level`` code.
    ptpath : str, optional
        Serialised leaves loaded into the freshly built network.
    genverbose : bool
        Forwarded as ``verbose`` to ``xce.xc.eXC``.

    Returns
    -------
    (xc, net)

    Raises
    ------
    ValueError
        If ``xorc`` is neither 'x' nor 'c'.
    KeyError
        If ``level`` is not one of the keys listed above.
    """
    level_dict = {'GGA':2, 'MGGA':3, 'NONLOCAL':4}
    x_lob_level_dict = {'GGA': 1.804, 'MGGA': 1.174, 'NONLOCAL': 1.174}
    if xorc == 'x':
        net = xce.net.eX(n_input = ninp,
                         n_hidden = nodes,
                         depth = depth,
                         use = use,
                         ueg_limit = True,
                         lob=x_lob_level_dict[level],
                         spin_scaling=True)
    elif xorc == 'c':
        net = xce.net.eC(n_input = ninp,
                         n_hidden = nodes,
                         depth = depth,
                         use = use,
                         ueg_limit = True,
                         spin_scaling=True)
    else:
        raise ValueError(f"xorc must be 'x' or 'c', got {xorc!r}")

    if ptpath:
        net = eqx.tree_deserialise_leaves(ptpath, net)

    xc = xce.xc.eXC(grid_models=[net], heg_mult=True, level=level_dict[level], verbose=genverbose)
    return xc, net

In [63]:
# NOTE(review): presumably the number of unpaired electrons (PySCF `spin` = 2S) of each
# isolated atom's ground state; not referenced by the code visible here.
spins = {
    'Al': 1,
    'B' : 1,
    'Li': 1,
    'Na': 1,
    'Si': 2 ,
    'Be':0,
    'C': 2,
    'Cl': 1,
    'F': 1,
    'H': 1,
    'N': 3,
    'O': 2,
    'P': 3,
    'S': 2
}
g297_path = '../scripts/script_data/haunschild_g2/g2_97.traj'
g297 = read(g297_path, ':')
ng297 = len(g297)
inds = np.arange(0, ng297)
# Indices of the molecules used for pretraining; all others are validation candidates.
pt_selection = [2, 113, 25, 18, 11, 17, 114, 121, 101, 0, 20, 26, 29, 67, 28, 110, 125, 10, 115, 89, 105, 50]

np.random.seed(seed=92017)

# Number of validation molecules to draw at random (e.g. 30); 0 uses every
# molecule not in pt_selection.
SIZE = 0

if SIZE:
    # Draw without replacement so no molecule is validated twice. Drawn pretraining
    # molecules are dropped afterwards, so the result may hold fewer than SIZE.
    val_selection = [i for i in np.random.choice(inds, size=SIZE, replace=False) if i not in pt_selection]
else:
    val_selection = [i for i in inds if i not in pt_selection]
pt_atoms = [g297[s] for s in pt_selection]
val_atoms = [g297[s] for s in val_selection]

mols = [get_mol(atoms) for atoms in val_atoms]
# Optional quick-check filter: keep only small molecules.
# mols = [i for i in mols if len(i.atom) < 8]
for idx, i in enumerate(mols):
    print(idx, i.atom, len(i.atom))


0 [['C', array([ 0.      ,  0.      , -1.120678])], ['S', array([0.      , 0.      , 0.420254])]] 2
1 [['F', array([0.      , 0.      , 0.091946])], ['H', array([ 0.      ,  0.      , -0.827512])]] 2
2 [['P', array([0.      , 0.      , 0.947658])], ['P', array([ 0.      ,  0.      , -0.947658])]] 2
3 [['S', array([0.      , 0.      , 0.956078])], ['S', array([ 0.      ,  0.      , -0.956078])]] 2
4 [['S', array([0.      , 0.      , 0.079416])], ['H', array([ 0.      ,  0.      , -1.270651])]] 2
5 [['O', array([ 0.      ,  0.      , -0.997879])], ['S', array([0.     , 0.     , 0.49894])]] 2
6 [['C', array([ 0.013445, -0.731846,  0.      ])], ['C', array([0.013445, 0.477327, 0.      ])], ['H', array([-0.161343  ,  1.52711101,  0.        ])]] 3
7 [['C', array([0., 0., 0.])], ['F', array([0.762931, 0.762931, 0.762931])], ['F', array([-0.762931, -0.762931,  0.762931])], ['F', array([-0.762931,  0.762931, -0.762931])], ['F', array([ 0.762931, -0.762931, -0.762931])]] 5
8 [['S', array([0.    

In [None]:
# Root of the pretrained-network tree; set XCQUINOX_PT_DIR instead of editing this path.
PT_ROOT = os.environ.get('XCQUINOX_PT_DIR', '/home/awills/Documents/Research/xcquinox_pt')
pbe0d = os.path.join(PT_ROOT, 'pbe0')
scand = os.path.join(PT_ROOT, 'scan')


def list_network_dirs(refdir):
    """Sorted names of the subdirectories of ``refdir`` whose names contain '_'."""
    return sorted(i for i in os.listdir(refdir)
                  if '_' in i and os.path.isdir(os.path.join(refdir, i)))


def group_networks(netnames):
    """Nest network directory names as ``{xorc: {level: {name: []}}}``.

    Names are parsed with the same ``xorc_depth_nodes_level[...]`` token layout
    that get_model_info uses. Names whose xorc is not 'x'/'c' or whose level is
    not 'mgga'/'nl' are ignored.
    """
    grouped = {xorc: {level: {} for level in ('mgga', 'nl')} for xorc in ('x', 'c')}
    for name in netnames:
        parts = name.split('_')
        if len(parts) >= 4 and parts[0] in grouped and parts[3] in grouped[parts[0]]:
            grouped[parts[0]][parts[3]][name] = []
    return grouped


pbe0nets = list_network_dirs(pbe0d)
scannets = list_network_dirs(scand)

refxcps = {'PBE0': pbe0d,
           'SCAN': scand}

# Filled by the validation loop with one list of per-molecule losses per network.
val_dct = {'PBE0': group_networks(pbe0nets),
           'SCAN': group_networks(scannets)}

In [None]:
def _load_network(refdir, knet):
    """Build (xc, net) from the best checkpoint in ``refdir/knet``.

    Returns ``(info, xc, net)``, where ``info`` is get_model_info's tuple, or
    None when get_model_info finds nothing usable.
    """
    info = get_model_info(refdir, knet)
    if info is None:
        print('no networks found')
        return None
    print(info)
    refxc, xorc, depth, nodes, ruse, rinp, level, xcf = info
    xc, net = gen_network_model(xorc, depth, nodes, rinp, ruse, level,
                                ptpath=os.path.join(refdir, knet, xcf))
    return info, xc, net


def _generate_data(mols, xc, ref, net):
    """Run get_data on every molecule and return ``(data, calcs, rejects)``.

    A molecule is rejected when get_data raises or returns None (failed SCF).
    """
    data, calcs, rejects = [], [], []
    for mol in mols:
        try:
            result = get_data(mol, xcmodel=xc, xc_func=ref, localnet=net)
        except Exception as err:
            print(f'get_data failed for {mol.atom}: {err}')
            result = None
        if result is None:
            rejects.append(mol)
        else:
            data.append(result)
            calcs.append(mol)
    return data, calcs, rejects


calc_dct = {}

for krefxc, krefdct in val_dct.items():
    refdir = refxcps[krefxc]
    calc_dct[krefxc] = {}
    for kxorc, krefxcdct in krefdct.items():
        calc_dct[krefxc][kxorc] = {}
        for klevel, krxcldct in krefxcdct.items():
            group = {'calcs': [], 'rejects': [], 'calc_losses': {}}
            calc_dct[krefxc][kxorc][klevel] = group
            # The reference data depends on the descriptor model and the network's
            # spin_scaling flag, so it is generated once per (reference, x/c, level)
            # group, using the first network that loads.
            tdrhos = tfxcs = None
            for knet in krxcldct:
                print(krefxc, kxorc, klevel, knet)
                loaded = _load_network(refdir, knet)
                if loaded is None:
                    continue
                info, xc, net = loaded
                refxc, ruse = info[0], info[4]
                if tdrhos is None:
                    data, group['calcs'], group['rejects'] = _generate_data(mols, xc, refxc, net)
                    tdrhos = [d[0] for d in data]
                    tfxcs = [d[1] for d in data]
                # Select the network's inputs along the last (feature) axis, which works
                # for both (N, F) and spin-stacked (2, N, F) descriptors.
                ruse_idx = jnp.asarray(ruse) if ruse else None
                losses = []
                for this_tdrho, this_tFxc in zip(tdrhos, tfxcs):
                    inp = this_tdrho if ruse_idx is None else this_tdrho[..., ruse_idx]
                    losses.append(PT_E_Loss()(net, inp, this_tFxc))
                val_dct[krefxc][kxorc][klevel][knet] = losses
                group['calc_losses'][knet] = losses
                # Checkpoint the cumulative results next to each evaluated network.
                with open(os.path.join(refdir, knet, 'valdct.pkl'), 'wb') as f:
                    pickle.dump(val_dct, f)
                with open(os.path.join(refdir, knet, 'calcdct.pkl'), 'wb') as f:
                    pickle.dump(calc_dct, f)

In [9]:

# ref = refxc
# data = []
# rejects = []
# for idx, mol in enumerate(mols):
#     try:
#         data.append(get_data(mol, xcmodel=xc, xc_func=ref, localnet=net))
#     except:
#         rejects.append(mol)
#         continue

# tdrhos = [i[0] for i in data]
# tfxcs = [i[1] for i in data]

# if net.spin_scaling:
#     print(f'localnet.spin_scaling: concatenating the data')
#     fdshape = data[0][0].shape
#     print(f'first data shape = {fdshape}')
#     if len(fdshape) == 3:
#         tdrho = jnp.concatenate([d[0] for d in data], axis=1)
#     else:
#         tdrho = jnp.concatenate([d[0] for d in data], axis=0)
#     print(f'concatenated: tdrho.shape={tdrho.shape}')
# else:
#     tdrho = jnp.concatenate([d[0] for d in data])

# tFxc = jnp.concatenate([d[1] for d in data])
# print(f'PRE NAN FILT: tFxc.shape={tFxc.shape}, tdrho.shape={tdrho.shape}')

# nan_filt_rho = ~jnp.any((tdrho != tdrho), axis=-1)
# nan_filt_fxc = ~jnp.isnan(tFxc)
# print(f'nan_filt_rho.shape={nan_filt_rho.shape}')
# print(f'nan_filt_fxc.shape={nan_filt_fxc.shape}')
# tFxc = tFxc[nan_filt_fxc]
# tdrho = tdrho[nan_filt_rho, :]

# print(f'tFxc.shape={tFxc.shape}, tdrho.shape={tdrho.shape}')
# cpus = jax.devices(backend='cpu')


<pyscf.gto.mole.Mole object at 0x72e8b0ba9390> [['P', array([0.      , 0.      , 0.947658])], ['P', array([ 0.      ,  0.      , -0.947658])]] 2
<pyscf.gto.mole.Mole object at 0x72e8b0ba96f0> [['O', array([ 0.      ,  0.      , -0.997879])], ['S', array([0.     , 0.     , 0.49894])]] 2
<pyscf.gto.mole.Mole object at 0x72e8b0ba9660> [['C', array([0.      , 0.      , 0.661747])], ['C', array([ 0.      ,  0.      , -0.661747])], ['F', array([0.        , 1.098469  , 1.38337501])], ['F', array([ 0.        , -1.098469  ,  1.38337501])], ['F', array([ 0.        ,  1.098469  , -1.38337501])], ['F', array([ 0.        , -1.098469  , -1.38337501])]] 6
<pyscf.gto.mole.Mole object at 0x72e8b0bab070> [['C', array([0., 0., 0.])], ['F', array([0.762931, 0.762931, 0.762931])], ['F', array([-0.762931, -0.762931,  0.762931])], ['F', array([-0.762931,  0.762931, -0.762931])], ['F', array([ 0.762931, -0.762931, -0.762931])]] 5
<pyscf.gto.mole.Mole object at 0x72e8b0ba8f70> [['C', array([0., 0., 0.])], ['C'

Initialize <pyscf.gto.mole.Mole object at 0x72e8b0ba9390> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -682.392471889732  <S^2> = 9.467982e-13  2S+1 = 1
New DM shape: (2, 60, 60)
ao.shape (10, 12528, 60)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00097767 -0.00119365 -0.00119365 ... -0.00159786 -0.36854854
 -0.36854854] = pbe0
get_data, dm shape = (2, 60, 60)
ao_eval.shape=(4, 12528, 60)
rho shape (4, 4, 2, 12528)
rho_filt shape: (12528,)
get descriptors tdrho.shape=(12528, 16)
mol:  [['O', array([ 0.      ,  0.      , -0.997879])], ['S', array([0.     , 0.     , 0.49894])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e8b0ba96f0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -473.138531430299  <S^2> = 1.0027149  2S+1 = 2.2384949
New DM shape: (2, 52, 52)
ao.shape (10, 11288, 52)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00524762 -0.00228259 -0.00067822 ... -0.43018482 -0.43018482
 -0.43018482] = pbe0
get_data, dm shape = (2, 52, 52)
ao_eval.shape=(4, 11288, 52)
rho shape (4, 4, 2, 11288)
rho_filt shape: (11288,)
get descriptors tdrho.shape=(11288, 16)
mol:  [['C', array([0.      , 0.      , 0.661747])], ['C', array([ 0.      ,  0.      , -0.661747])], ['F', array([0.        , 1.098469  , 1.38337501])], ['F', array([ 0.        , -1.098469  ,  1.38337501])], ['F', array([ 0.        ,  1.098469  , -1.38337501])], ['F', array([ 0.        , -1.098469  , -1.38337501])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e8b0ba9660> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -475.2123619051  <S^2> = 7.4251716e-13  2S+1 = 1
New DM shape: (2, 132, 132)
ao.shape (10, 31440, 132)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00056419 -0.00068809 -0.00182286 ... -0.00031377 -0.0007201
 -0.00282036] = pbe0
get_data, dm shape = (2, 132, 132)
ao_eval.shape=(4, 31440, 132)
rho shape (4, 4, 2, 31440)
rho_filt shape: (31440,)
get descriptors tdrho.shape=(31440, 16)
mol:  [['C', array([0., 0., 0.])], ['F', array([0.762931, 0.762931, 0.762931])], ['F', array([-0.762931, -0.762931,  0.762931])], ['F', array([-0.762931,  0.762931, -0.762931])], ['F', array([ 0.762931, -0.762931, -0.762931])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e8b0bab070> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -437.217292208887  <S^2> = 4.1247006e-12  2S+1 = 1
New DM shape: (2, 110, 110)
ao.shape (10, 24672, 110)
no spin scaling, indicates correlation network
exc with xc_func = [-1.81782888e-03 -2.19689199e-04 -2.19692653e-04 ... -3.61892080e-04
 -2.84265937e+00 -2.84265937e+00] = pbe0
get_data, dm shape = (2, 110, 110)
ao_eval.shape=(4, 24672, 110)
rho shape (4, 4, 2, 24672)
rho_filt shape: (24672,)
get descriptors tdrho.shape=(24672, 16)
mol:  [['C', array([0., 0., 0.])], ['C', array([0.      , 0.      , 1.301399])], ['C', array([ 0.      ,  0.      , -1.301399])], ['H', array([0.        , 0.925815  , 1.86943001])], ['H', array([ 0.        , -0.925815  ,  1.86943001])], ['H', array([ 0.925815  ,  0.        , -1.86943001])], ['H', array([-0.925815  ,  0.        , -1.86943001])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e8b0ba8f70> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -116.527040748631  <S^2> = 1.5667467e-12  2S+1 = 1
New DM shape: (2, 94, 94)
ao.shape (10, 25296, 94)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00086626 -0.00116412 -0.00254959 ... -0.00254959 -0.00036546
 -0.00035847] = pbe0
get_data, dm shape = (2, 94, 94)
ao_eval.shape=(4, 25296, 94)
rho shape (4, 4, 2, 25296)
rho_filt shape: (25296,)
get descriptors tdrho.shape=(25296, 16)
mol:  [['O', array([-1.029332, -0.445176,  0.      ])], ['C', array([0.     , 0.41975, 0.     ])], ['O', array([1.159211, 0.11656 , 0.      ])], ['H', array([-0.646185, -1.337139,  0.      ])], ['H', array([-0.39285   ,  1.44756601,  0.        ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e8b0ba8fd0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -189.625165692314  <S^2> = 9.890222e-11  2S+1 = 1
New DM shape: (2, 80, 80)
ao.shape (10, 19184, 80)
no spin scaling, indicates correlation network
exc with xc_func = [-8.12418463e-04 -1.70844637e-04 -1.89755388e-04 ... -4.09849631e-01
 -4.09849631e-01 -4.09849631e-01] = pbe0
get_data, dm shape = (2, 80, 80)
ao_eval.shape=(4, 19184, 80)
rho shape (4, 4, 2, 19184)
rho_filt shape: (19184,)
get descriptors tdrho.shape=(19184, 16)
mol:  [['C', array([0.      , 0.763812, 0.      ])], ['C', array([1.293467, 1.044723, 0.      ])], ['Cl', array([-0.626081, -0.865349,  0.      ])], ['H', array([-0.789043  ,  1.50465501,  0.        ])], ['H', array([2.05549701, 0.275045  , 0.        ])], ['H', array([1.61612701, 2.08003201, 0.        ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd263e0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -537.925558087699  <S^2> = 4.1104897e-12  2S+1 = 1
New DM shape: (2, 95, 95)
ao.shape (10, 22496, 95)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00397457 -0.00300487 -0.00168806 ... -0.35548505 -0.35548505
 -0.35548505] = pbe0
get_data, dm shape = (2, 95, 95)
ao_eval.shape=(4, 22496, 95)
rho shape (4, 4, 2, 22496)
rho_filt shape: (22496,)
get descriptors tdrho.shape=(22496, 16)
mol:  [['C', array([0.      , 0.      , 0.661747])], ['C', array([ 0.      ,  0.      , -0.661747])], ['F', array([0.        , 1.098469  , 1.38337501])], ['F', array([ 0.        , -1.098469  ,  1.38337501])], ['F', array([ 0.        ,  1.098469  , -1.38337501])], ['F', array([ 0.        , -1.098469  , -1.38337501])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd24220> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -475.212361905099  <S^2> = 7.3185902e-13  2S+1 = 1
New DM shape: (2, 132, 132)
ao.shape (10, 31440, 132)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00056419 -0.00068809 -0.00182286 ... -0.00031377 -0.0007201
 -0.00282036] = pbe0
get_data, dm shape = (2, 132, 132)
ao_eval.shape=(4, 31440, 132)
rho shape (4, 4, 2, 31440)
rho_filt shape: (31440,)
get descriptors tdrho.shape=(31440, 16)
mol:  [['C', array([-1.319773,  0.323264,  0.      ])], ['O', array([0.      , 0.886226, 0.      ])], ['H', array([-1.98878301,  1.182938  ,  0.        ])], ['H', array([-1.47994701, -0.292376  ,  0.889614  ])], ['H', array([-1.47994701, -0.292376  , -0.889614  ])], ['N', array([ 1.04264, -0.03814,  0.     ])], ['O', array([ 0.696105, -1.170075,  0.      ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd27f10> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -244.8313779461  <S^2> = 3.7877257e-11  2S+1 = 1
New DM shape: (2, 109, 109)
ao.shape (10, 25784, 109)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00041181 -0.00104945 -0.00092364 ... -0.29424018 -0.29424018
 -0.29424018] = pbe0
get_data, dm shape = (2, 109, 109)
ao_eval.shape=(4, 25784, 109)
rho shape (4, 4, 2, 25784)
rho_filt shape: (25784,)
get descriptors tdrho.shape=(25784, 16)
mol:  [['O', array([ 0.      ,  0.      , -0.997879])], ['S', array([0.     , 0.     , 0.49894])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd24940> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -473.138531432376  <S^2> = 1.002715  2S+1 = 2.238495
New DM shape: (2, 52, 52)
ao.shape (10, 11288, 52)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00524612 -0.00228201 -0.00067808 ... -0.43018495 -0.43018495
 -0.43018495] = pbe0
get_data, dm shape = (2, 52, 52)
ao_eval.shape=(4, 11288, 52)
rho shape (4, 4, 2, 11288)
rho_filt shape: (11288,)
get descriptors tdrho.shape=(11288, 16)
mol:  [['C', array([0.      , 0.187439, 0.      ])], ['O', array([0.172598  , 1.35977101, 0.        ])], ['F', array([-1.249462, -0.324212,  0.      ])], ['C', array([ 1.016674, -0.915681,  0.      ])], ['H', array([ 2.02085701, -0.494499  ,  0.        ])], ['H', array([ 0.871739  , -1.54815501,  0.880587  ])], ['H', array([ 0.871739  , -1.54815501, -0.880587  ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd24f10> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -252.905508994251  <S^2> = 2.1127988e-11  2S+1 = 1
New DM shape: (2, 109, 109)
ao.shape (10, 26008, 109)
no spin scaling, indicates correlation network
exc with xc_func = [-9.37122250e-05 -2.68820005e-04 -1.65125590e-04 ... -9.14995513e-01
 -9.14995513e-01 -9.14995513e-01] = pbe0
get_data, dm shape = (2, 109, 109)
ao_eval.shape=(4, 26008, 109)
rho shape (4, 4, 2, 26008)
rho_filt shape: (26008,)
get descriptors tdrho.shape=(26008, 16)
mol:  [['C', array([0.      , 0.      , 0.769005])], ['Cl', array([ 0.        ,  1.48857001, -0.216701  ])], ['Cl', array([ 0.        , -1.48857001, -0.216701  ])], ['H', array([-0.900326  ,  0.        ,  1.37690301])], ['H', array([0.900326  , 0.        , 1.37690301])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd266b0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -959.307683207229  <S^2> = 2.0712321e-12  2S+1 = 1
New DM shape: (2, 96, 96)
ao.shape (10, 21496, 96)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00179105 -0.00059081 -0.00165719 ... -0.30593888 -0.30593888
 -0.30593888] = pbe0
get_data, dm shape = (2, 96, 96)
ao_eval.shape=(4, 21496, 96)
rho shape (4, 4, 2, 21496)
rho_filt shape: (21496,)
get descriptors tdrho.shape=(21496, 16)
mol:  [['B', array([0., 0., 0.])], ['Cl', array([0.        , 1.75062101, 0.        ])], ['Cl', array([ 1.51608201, -0.87531   ,  0.        ])], ['Cl', array([-1.51608201, -0.87531   ,  0.        ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd255a0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -1405.01141832645  <S^2> = 1.6349588e-11  2S+1 = 1
New DM shape: (2, 112, 112)
ao.shape (10, 23336, 112)
no spin scaling, indicates correlation network
exc with xc_func = [-1.51591693e-03 -1.21750195e-03 -1.56720553e-03 ... -2.35960702e+00
 -2.35960702e+00 -2.35960702e+00] = pbe0
get_data, dm shape = (2, 112, 112)
ao_eval.shape=(4, 23336, 112)
rho shape (4, 4, 2, 23336)
rho_filt shape: (23336,)
get descriptors tdrho.shape=(23336, 16)
mol:  [['B', array([0., 0., 0.])], ['F', array([0.      , 1.308815, 0.      ])], ['F', array([ 1.133467, -0.654408,  0.      ])], ['F', array([-1.133467, -0.654408,  0.      ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd25f60> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -324.336275765114  <S^2> = 1.961098e-12  2S+1 = 1
New DM shape: (2, 88, 88)
ao.shape (10, 20024, 88)
no spin scaling, indicates correlation network
exc with xc_func = [-1.69699472e-03 -3.70814878e-04 -4.61823803e-04 ... -2.35913448e+00
 -2.35913448e+00 -2.35913448e+00] = pbe0
get_data, dm shape = (2, 88, 88)
ao_eval.shape=(4, 20024, 88)
rho shape (4, 4, 2, 20024)
rho_filt shape: (20024,)
get descriptors tdrho.shape=(20024, 16)
mol:  [['C', array([0., 0., 0.])], ['Cl', array([1.030124, 1.030124, 1.030124])], ['Cl', array([-1.030124, -1.030124,  1.030124])], ['Cl', array([-1.030124,  1.030124, -1.030124])], ['Cl', array([ 1.030124, -1.030124, -1.030124])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd26aa0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -1878.14149018999  <S^2> = 3.7161385e-12  2S+1 = 1
New DM shape: (2, 142, 142)
ao.shape (10, 29224, 142)
no spin scaling, indicates correlation network
exc with xc_func = [-1.72014924e-03 -2.59090456e-04 -3.91440771e-04 ... -5.73756456e-05
 -2.84470685e+00 -2.84470685e+00] = pbe0
get_data, dm shape = (2, 142, 142)
ao_eval.shape=(4, 29224, 142)
rho shape (4, 4, 2, 29224)
rho_filt shape: (29224,)
get descriptors tdrho.shape=(29224, 16)
mol:  [['N', array([0.      , 0.      , 0.322871])], ['O', array([ 0.      ,  1.101927, -0.141256])], ['O', array([ 0.      , -1.101927, -0.141256])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77bd25930> in UKS object of <class 'pyscf.dft.uks.UKS'>
Initialize <pyscf.gto.mole.Mole object at 0x72e77bd25930> in RKS object of <class 'pyscf.dft.rks.RKS'>
Initialize <pyscf.gto.mole.Mole object at 0x72e77be051e0> in UKS object of <class 'pyscf.dft.uks.UKS'>


mol:  [['O', array([0.      , 0.      , 1.315714])], ['C', array([0.      , 0.      , 0.140943])], ['F', array([ 0.      ,  1.060556, -0.631743])], ['F', array([ 0.      , -1.060556, -0.631743])]]
converged SCF energy = -312.817990698189  <S^2> = 2.0133228e-11  2S+1 = 1
New DM shape: (2, 88, 88)
ao.shape (10, 19824, 88)
no spin scaling, indicates correlation network
exc with xc_func = [-0.00152257 -0.0024238  -0.00287804 ... -0.0001242  -0.00164712
 -0.00020123] = pbe0
get_data, dm shape = (2, 88, 88)
ao_eval.shape=(4, 19824, 88)
rho shape (4, 4, 2, 19824)
rho_filt shape: (19824,)
get descriptors tdrho.shape=(19824, 16)
mol:  [['C', array([0.     , 0.     , 0.49611])], ['F', array([ 0.      ,  1.099786, -0.28882 ])], ['F', array([ 0.      , -1.099786, -0.28882 ])], ['H', array([-0.909099,  0.      ,  1.111055])], ['H', array([0.909099, 0.      , 1.111055])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77be07520> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -238.822480433493  <S^2> = 3.7836401e-13  2S+1 = 1
New DM shape: (2, 80, 80)
ao.shape (10, 19336, 80)
no spin scaling, indicates correlation network
exc with xc_func = [-1.35849699e-03 -2.97947818e-04 -3.69098277e-05 ... -5.19554999e-04
 -4.56091422e-04 -3.73000182e-01] = pbe0
get_data, dm shape = (2, 80, 80)
ao_eval.shape=(4, 19336, 80)
rho shape (4, 4, 2, 19336)
rho_filt shape: (19336,)
get descriptors tdrho.shape=(19336, 16)
mol:  [['C', array([0.      , 0.      , 0.860477])], ['C', array([ 0.      ,  0.645288, -0.50076 ])], ['C', array([ 0.      , -0.645288, -0.50076 ])], ['H', array([-0.911889  ,  0.        ,  1.46224101])], ['H', array([0.911889  , 0.        , 1.46224101])], ['H', array([ 0.        ,  1.57654301, -1.039113  ])], ['H', array([ 0.        , -1.57654301, -1.039113  ])]]


Initialize <pyscf.gto.mole.Mole object at 0x72e77be06800> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -116.491407945108  <S^2> = 1.5649704e-12  2S+1 = 1
New DM shape: (2, 94, 94)
ao.shape (10, 23320, 94)
no spin scaling, indicates correlation network
exc with xc_func = [-1.78158560e-03 -3.78102382e-04 -3.72304377e-04 ... -3.91531882e-01
 -3.91531882e-01 -3.91531882e-01] = pbe0
get_data, dm shape = (2, 94, 94)
ao_eval.shape=(4, 23320, 94)
rho shape (4, 4, 2, 23320)
rho_filt shape: (23320,)
get descriptors tdrho.shape=(23320, 16)
PRE NAN FILT: tFxc.shape=(356713,), tdrho.shape=(356713, 16)
nan_filt_rho.shape=(356713,)
nan_filt_fxc.shape=(356713,)
tFxc.shape=(356713,), tdrho.shape=(356713, 16)


In [33]:
# losses = []
# for idx, dat in enumerate(tdrhos):
#     this_tFxc = tfxcs[idx]
#     this_tdrho = dat
#     if ruse:
#         if net.spin_scaling:
#             if len(tdrho.shape) == 3:
#                 inp = this_tdrho[:, :, ruse]
#             else:
#                 inp = this_tdrho[:, ruse]
#         else:
#             inp = this_tdrho[:, ruse]
#     else:
#         inp = this_tdrho
#     # print(f'inp[0].shape = {inp[0].shape}')
#     loss = PT_E_Loss()(net, inp, this_tFxc)
#     losses.append(loss)

In [34]:
# losses

[Array(0.30406356, dtype=float64),
 Array(0.16487314, dtype=float64),
 Array(0.24927765, dtype=float64),
 Array(0.26863706, dtype=float64),
 Array(0.13990741, dtype=float64),
 Array(0.14148157, dtype=float64),
 Array(0.17739353, dtype=float64),
 Array(0.24927765, dtype=float64),
 Array(0.14128046, dtype=float64),
 Array(0.16448226, dtype=float64),
 Array(0.17070659, dtype=float64),
 Array(0.21838447, dtype=float64),
 Array(0.29479163, dtype=float64),
 Array(0.27463935, dtype=float64),
 Array(0.29111053, dtype=float64),
 Array(0.22563895, dtype=float64),
 Array(0.20694636, dtype=float64),
 Array(0.1305856, dtype=float64)]

In [14]:
# loss

Array(0.36802255, dtype=float64)

In [None]:
# if ruse:
#     if net.spin_scaling:
#         if len(tdrho.shape) == 3:
#             inp = [tdrho[:, :, ruse]]
#         else:
#             inp = [tdrho[:, ruse]]
#     else:
#         inp = [tdrho[:, ruse]]
# else:
#     inp = [tdrho]
# print(f'inp[0].shape = {inp[0].shape}')
# loss = PT_E_Loss()(net, inp[0], tFxc)