In [1]:
from ase import Atoms
from ase.io import read, write
import xcquinox as xce
import torch, jax, optax
import numpy as np
import equinox as eqx
import jax.numpy as jnp
import pyscfad as psa
import os, sys, pickle
from pyscf import dft, scf, gto, df
from pyscf.pbc import scf as scfp
from pyscf.pbc import gto as gtop
from pyscf.pbc import dft as dftp
# from pyscfad.pbc import scf as scfp
# from pyscfad.pbc import gto as gtop
# from pyscfad.pbc import dft as dftp

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
import pandas as pd
from mp_api.client import MPRester
from mldftdat.density import get_exchange_descriptors2
from mldftdat.analyzers import RKSAnalyzer
from emmet.core.summary import HasProps
from ase.build import bulk
from pymatgen.io.ase import AseAtomsAdaptor

# os.environ['JAX_LOG_COMPILES'] = '1'
# jax.config.update('jax_log_compiles', True)

CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
from xcquinox.net import eX, eC

In [3]:
# Directory holding the reference trajectories and their saved SCF arrays
# (<base>.dm.npy, <base>.mo_energy.npy, <base>.mo_occ.npy, <base>.mo_coeff.npy).
refdp = '/home/awills/Documents/Research/datasets/borlido2019/smalls_data'
# Trajectory files: names containing '_' and 'traj' but not 'smalls', ordered by
# the integer prefix that precedes the first underscore.
trjs = sorted([i for i in os.listdir(refdp) if '_' in i and 'traj' in i and 'smalls' not in i], key = lambda x: int(x.split('_')[0]))
# Per-system accumulators filled by the loop below; entry k of every list
# belongs to the k-th accepted system.
mols = []
mfs = []
rdms = []
init_dms = []
rmoes = []
rmooccs = []
rmocoeffs = []
refgaps = []
reftotes = []
gws = []
vns = []
gcs = []
aos = []
ts = []
proportions = []
hcs = []
eris = []
ss = []
ogds = []
# Stride used to thin the integration grid: every GRID_SUBSAMPLE_SIZE-th point is kept.
GRID_SUBSAMPLE_SIZE=100
# Only the first two trajectory files are considered.
for idx, i in enumerate(trjs[:2]):
    atoms = read(os.path.join(refdp, i))
    fbase = i.split('.')[0]
    rdm = np.load(os.path.join(refdp, fbase+'.dm.npy'))
    # Skip systems whose reference density matrix has 19 or more rows.
    if rdm.shape[0] < 19:
        rdms.append(rdm)
        # Remember the unpadded density-matrix shape; the arrays are padded in a later cell.
        ogds.append(rdm.shape)
        refgaps.append(atoms.info['reference'])
        rmoes.append(np.load(os.path.join(refdp, fbase+'.mo_energy.npy')))
        rmooccs.append(np.load(os.path.join(refdp, fbase+'.mo_occ.npy')))
        rmocoeffs.append(np.load(os.path.join(refdp, fbase+'.mo_coeff.npy')))
        print(atoms, atoms.info, atoms.calc.results)
        # Reference total energy is taken from the calculator results stored in the trajectory.
        atoms.info['e_calc'] = atoms.calc.results['energy']
        reftotes.append(atoms.info['e_calc'])
        pos = atoms.positions
        spec = atoms.get_chemical_symbols()
        mol_input = [[s,p] for s,p in zip(spec,pos)]
        cell = np.array(atoms.cell)
        # Periodic cell built from the ASE lattice vectors and atom list.
        # NOTE(review): rcut=0.1 and verbose=9 look like debugging settings — confirm.
        mol = gtop.Cell(a=cell, rcut=0.1, atom=mol_input, basis='lanl2dz', charge=0, pseudo='gth-pbe', verbose=9, spin=0)
        mol.exp_to_discard = 0.1
        mol.build()
        mols.append(mol)
        mf = dftp.RKS(mol)
        mf.grids.level = 1
        mf.grids.build()
        # Thin the grid: keep every GRID_SUBSAMPLE_SIZE-th weight/coordinate.
        gw = mf.grids.weights[::GRID_SUBSAMPLE_SIZE]
        gc = mf.grids.coords[::GRID_SUBSAMPLE_SIZE]
        # Fraction of grid points retained after thinning.
        proportion = len(gw)/len(mf.grids.weights)
        proportions.append(proportion)
        # The mean-field object is switched to the thinned grid before the
        # initial guess and AO values are generated below.
        mf.grids.weights = gw
        mf.grids.coords = gc
        gws.append(gw)
        gcs.append(gc)
        # aos.append(mf._numint.eval_ao(mol, mf.grids.coords, deriv=2))
        init_dms.append(mf.get_init_guess())
        # AO values (deriv=2) evaluated on the thinned grid coordinates.
        aos.append(jnp.array(mf._numint.eval_ao(mol, gc, deriv=2)))
        hcs.append(jnp.array(mf.get_hcore()))
        ts.append(jnp.array(mol.intor('int1e_kin')))
        vns.append(jnp.array(mf.energy_nuc()))
        eris.append(jnp.array(mol.intor('int2e')))
        # print('eri shape', eri.shape)
        # print('ao_eval shape', ao_eval.shape)
        # ss stores the inverse of the Cholesky factor of the overlap matrix,
        # not the overlap matrix itself.
        ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp'))))
        mfs.append(mf)

Atoms(symbols='Ar', pbc=True, cell=[[0.0, 2.820386, 2.820386], [2.820386, 0.0, 2.820386], [2.820386, 2.820386, 0.0]], calculator=SinglePointCalculator(...)) {'borlido_idx': 35, 'mp_id': 'mp-23155', 'reference': 14.15} {'energy': -19.14554484483416}
arg.atm = [[ 8 20  1 23  0  0]]
arg.bas = [[ 0  0  2  1  0 24 26  0]
 [ 0  0  1  1  0 28 29  0]
 [ 0  1  2  1  0 30 32  0]
 [ 0  1  1  1  0 34 35  0]]
arg.env = [ 0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  2.613       0.5736     -2.65356551  2.11503642  0.2014      0.75955569
  7.86        0.7387     -2.13150432  2.02105235  0.2081      0.41003836]
ecpbas  = []
System: uname_result(system='Linux', node='raider', release='6.5.0-27-generic', version='#28~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 15 10:51:06 UTC 2', machine='x86_6

In [4]:
# Run every per-system list through xce.utils.pad_array_list, in the same order
# as before, and rebind the results to the original names.
# (pad_array_list presumably pads entries to a common shape — TODO confirm.)
(rdms, init_dms, rmoes, rmooccs, rmocoeffs,
 gws, gcs, aos, hcs, ts, eris, ss) = [
    xce.utils.pad_array_list(unpadded)
    for unpadded in (rdms, init_dms, rmoes, rmooccs, rmocoeffs,
                     gws, gcs, aos, hcs, ts, eris, ss)
]
# vns = xce.utils.pad_array_list(vns)

# Point each mean-field object's grid at its processed weights and coordinates.
for idx, mf in enumerate(mfs):
    mf.grids.weights, mf.grids.coords = gws[idx], gcs[idx]

[8, 8] (8, 8)
[8, 8] (8, 8)
[8] (8,)
[8] (8,)
[8, 8] (8, 8)
[8574] (8574,)
[8574, 3] (8574, 3)
[10, 8574, 8] (10, 8574, 8)
[8, 8] (8, 8)
[8, 8] (8, 8)
[8, 8, 8, 8] (8, 8, 8, 8)
[8, 8] (8, 8)


In [5]:
class E_Rho_Gap_loss(eqx.Module):
    """
    Loss combining a total-energy error, a grid-density error and a gap error
    for a single system, with all integrals rebuilt from ``mol``/``mf`` on
    every call.
    """
    def __init__(self):
        super().__init__()

    def __call__(self, model, mol, mf, gc, gw, inp_dm, mo_occ, ogd, refE, refDM, refGap):
        """
        Forward pass for the loss object.

        :param model: Callable XC model, invoked as ``model(dm, ao_eval, gw)``
        :param mol: Object providing ``intor`` integrals and ``nelectron``
        :param mf: Mean-field object providing ``_numint.eval_ao`` and ``get_hcore``
        :param gc: Grid coordinates on which the atomic orbitals are evaluated
        :param gw: Grid weights paired with ``gc``
        :param inp_dm: Input density matrix for the one-shot density-matrix update
        :param mo_occ: Orbital occupation numbers
        :param ogd: Original (unpadded) density-matrix dimensions, forwarded to ``get_dm_moe``
        :param refE: Reference total energy
        :param refDM: Reference density matrix
        :param refGap: Reference gap
        :return: First element of ``sqrt(energy_err**2 + density_err + gap_err**2)``
        """
        #get static matrices from mol and mf
        ao_eval = jnp.array(mf._numint.eval_ao(mol, gc, deriv=2))
        vn = jnp.array(mol.intor('int1e_nuc'))
        hc = jnp.array(mf.get_hcore())
        eri = jnp.array(mol.intor('int2e'))
        # inverse Cholesky factor of the overlap matrix
        s = jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp')))
        # largest index carrying a non-zero occupation
        homo_i = jnp.max(jnp.nonzero(mo_occ, size=inp_dm.shape[0])[0])

        #vxc function for gradient
        # Use the same weights (gw) as the final XC evaluation below; ao_eval is
        # built on gc, so gw is the matching quadrature, whereas mf.grids.weights
        # is only correct when the caller happens to have synchronised it.
        vgf = lambda x: model(x, ao_eval, gw)
        print('Generating dmp, moep, mocoep')
        dmp, moep, mocoep = xce.utils.get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)

        veff = xce.utils.get_veff()(dmp, eri)
        print('Getting exc, vxc from dmp')
        exc, vxc = eqx.filter_value_and_grad(model)(dmp, ao_eval, gw)
        # NOTE(review): vn is added to the effective potential although hc may
        # already contain the nuclear term — confirm against xce.utils.energy_tot.
        e_pred = xce.utils.energy_tot()(dmp, hc, veff+vxc+vn)
        # rho_pred = xce.utils.get_rho()(dmp, ao_eval)
        # rho_ref = xce.utils.get_rho()(refDM, ao_eval)
        print('Getting rho_pred and rho_ref')
        # [0,0] selects the value-value component; 1e-10 is added to both densities
        rho_pred = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dmp)[0,0]+1e-10
        rho_ref = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, refDM)[0,0]+1e-10
        print('Getting gap_pred')
        gap_pred = moep[homo_i+1]-moep[homo_i]

        eL = (e_pred-refE)**2

        # weighted squared density difference, normalised by nelectron**2
        rho_L = jnp.sum(gw*(rho_pred-rho_ref)**2)/mol.nelectron**2

        gap_L = (gap_pred-refGap)**2

        return jnp.sqrt(eL+rho_L+gap_L)[0]


class E_DM_Gap_loss(eqx.Module):
    """
    Density-matrix + gap loss for a single system, with integrals rebuilt from
    ``mol``/``mf`` and padded to the (padded) size of ``inp_dm``.

    The energy term is currently disabled: ``refE`` is accepted for interface
    compatibility but does not enter the returned value.
    """
    def __init__(self):
        super().__init__()

    def __call__(self, model, mol, mf, gc, gw, inp_dm, mo_occ, ogd, refE, refDM, refGap):
        """
        Forward pass for the loss object.

        :param model: Callable XC model, invoked as ``model(dm, ao_eval, gw)``
        :param mol: Object providing ``intor`` integrals
        :param mf: Mean-field object providing ``_numint.eval_ao`` and ``get_hcore``
        :param gc: Grid coordinates on which the atomic orbitals are evaluated
        :param gw: Grid weights; its length sets the padded grid dimension of the AO array
        :param inp_dm: Input density matrix; its leading dimension sets the padded basis size
        :param mo_occ: Orbital occupation numbers
        :param ogd: Original (unpadded) density-matrix dimensions, forwarded to ``get_dm_moe``
        :param refE: Reference total energy (unused while the energy term is disabled)
        :param refDM: Reference density matrix
        :param refGap: Reference gap
        :return: ``sqrt(sum((dm_pred - refDM)**2) + (gap_pred - refGap)**2)``
        """
        n_pad = inp_dm.shape[0]

        #get static matrices from mol and mf
        # Only the matrices that feed the loss are built; the kinetic and
        # nuclear-attraction integrals were computed but never used.
        ao_eval = jnp.array(jax.lax.stop_gradient(mf._numint.eval_ao(mol, gc, deriv=2)))
        hc = jnp.array(mf.get_hcore())
        eri = jnp.array(mol.intor('int2e'))
        # inverse Cholesky factor of the overlap matrix
        s = jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp')))
        # largest index carrying a non-zero occupation
        homo_i = jnp.max(jnp.nonzero(mo_occ, size=n_pad)[0])

        ao_eval = xce.utils.pad_array(ao_eval, ao_eval, shape=(ao_eval.shape[0], gw.shape[0], n_pad))
        s = xce.utils.pad_array(s, inp_dm)
        hc = xce.utils.pad_array(hc, inp_dm)
        eri = xce.utils.pad_array(eri, eri, shape=(n_pad,)*4)

        #vxc function for gradient
        vgf = lambda x: model(x, ao_eval, gw)
        print('Generating dmp, moep, mocoep')
        print(f'inp_dm, eri, mo_occ, hc, s shapes: {inp_dm.shape}, {eri.shape}, {mo_occ.shape}, {hc.shape}, {s.shape}')
        dmp, moep, mocoep = xce.utils.get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)
        print(f'dmp, moep, mocoep shapes: {dmp.shape}, {moep.shape}, {mocoep.shape}')

        # bring the predicted quantities back to the padded size
        dmp = xce.utils.pad_array(dmp, inp_dm)
        moep = xce.utils.pad_array(moep, moep,  shape=(dmp.shape[0],))
        print('Getting gap_pred')
        gap_pred = moep[homo_i+1]-moep[homo_i]
        print('gap_pred, refGap: ', gap_pred, refGap)

        dm_L = jnp.sum( (dmp-refDM)**2)

        gap_L = (gap_pred-refGap)**2

        return jnp.sqrt(dm_L+gap_L)

class DM_Gap_loss(eqx.Module):
    """
    Density-matrix + gap loss for one system, using precomputed matrices.

    ``t``, ``mf``, ``gc`` and ``refE`` are accepted but not read.
    """
    def __init__(self):
        super().__init__()

    def __call__(self, model, ao, t, hc, eri, s, mf, gc, gw, inp_dm, mo_occ, ogd, refE, refDM, refGap):
        """
        Return ``sqrt(sum((dm_pred - refDM)**2) + (gap_pred - refGap)**2)``,
        where the predicted density matrix and orbital energies come from
        ``xce.utils.get_dm_moe`` and the gap is the difference of the orbital
        energies at ``homo_i+1`` and ``homo_i``.
        """
        n_basis = inp_dm.shape[0]
        # indices with non-zero occupation; the largest is taken as the HOMO
        occupied_idx = jnp.nonzero(mo_occ, size=n_basis)[0]
        homo_i = jnp.max(occupied_idx)

        #vxc function for gradient
        def vgf(x):
            return model(x, ao, gw)

        print('Generating dmp, moep, mocoep')
        print(f'inp_dm, eri, mo_occ, hc, s shapes: {inp_dm.shape}, {eri.shape}, {mo_occ.shape}, {hc.shape}, {s.shape}')
        dmp, moep, mocoep = xce.utils.get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)
        print(f'dmp, moep, mocoep shapes: {dmp.shape}, {moep.shape}, {mocoep.shape}')

        # pad predictions back to the size of the input density matrix
        dmp = xce.utils.pad_array(dmp, inp_dm)
        moep = xce.utils.pad_array(moep, moep,  shape=(dmp.shape[0],))

        print('Getting gap_pred')
        e_above, e_homo = moep[homo_i+1], moep[homo_i]
        gap_pred = e_above-e_homo
        print('gap_pred, refGap: ', gap_pred, refGap)

        dm_residual = dmp-refDM
        gap_residual = gap_pred-refGap
        return jnp.sqrt(jnp.sum(dm_residual**2)+gap_residual**2)

class DM_Gap_Loop_loss(eqx.Module):
    """
    Sum over systems of the density-matrix + gap loss, looping in Python over
    per-system inputs.

    Side effect: each ``mfs[idx]`` is mutated in place (marked converged and
    stamped with the reference orbitals, density matrix and AO values).
    """
    def __init__(self):
        super().__init__()

    def __call__(self, model, aos, ts, hcs, eris, ss, mfs, gcs, gws, inp_dms, mo_occs, mo_coeffs, mo_energies, ogds, refEs, refDMs, refGaps):
        """
        Forward pass for the loss object.

        Every sequence argument is indexed by system; ``len(aos)`` sets the
        number of systems. ``ts``, ``gcs`` and ``refEs`` are accepted but not read.

        :param model: Callable XC model, invoked as ``model(dm, ao, gw, mf=mf)``
        :return: Sum over systems of ``sqrt(sum((dm_pred - refDM)**2) + (gap_pred - refGap)**2)``;
            the integer 0 when ``aos`` is empty
        """
        total_loss = 0
        for idx in range(len(aos)):
            mo_occ = mo_occs[idx]
            mo_coeff = mo_coeffs[idx]
            mo_energy = mo_energies[idx]
            inp_dm = inp_dms[idx]
            ao = aos[idx]
            gw = gws[idx]
            eri = eris[idx]
            hc = hcs[idx]
            s = ss[idx]
            ogd = ogds[idx]
            refGap = refGaps[idx]
            refDM = refDMs[idx]

            #for non-local
            mf = mfs[idx]
            #just say it's converged
            mf.converged = True
            #assign reference calculation values to the mean field object
            mf.mo_occ = mo_occ
            mf.mo_coeff = mo_coeff
            # this system's orbital energies, not the whole batch
            mf.mo_energy = mo_energy
            #set flags for use in the RKSAnalyzer
            mf.idm = True
            mf.iao = True
            mf.dm = inp_dm
            mf.ao = ao

            # largest index carrying a non-zero occupation
            homo_i = jnp.max(jnp.nonzero(mo_occ, size=inp_dm.shape[0])[0])
            #vxc function for gradient
            # the closure is consumed within this iteration, so capturing the
            # loop variables is safe
            vgf = lambda x: model(x, ao, gw, mf=mf)
            print('Generating dmp, moep, mocoep')
            print(f'inp_dm, eri, mo_occ, hc, s shapes: {inp_dm.shape}, {eri.shape}, {mo_occ.shape}, {hc.shape}, {s.shape}')
            dmp, moep, mocoep = xce.utils.get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)
            print(f'dmp, moep, mocoep shapes: {dmp.shape}, {moep.shape}, {mocoep.shape}')

            dmp = xce.utils.pad_array(dmp, inp_dm)
            moep = xce.utils.pad_array(moep, moep,  shape=(dmp.shape[0],))
            print('Getting gap_pred')
            gap_pred = moep[homo_i+1]-moep[homo_i]
            print('gap_pred, refGap: ', gap_pred, refGap)

            dm_L = jnp.sum( (dmp-refDM)**2)

            gap_L = (gap_pred-refGap)**2

            total_loss += jnp.sqrt(dm_L+gap_L)
        return total_loss

class Gap_Loop_loss(eqx.Module):
    """
    Sum over systems of a scaled gap error, looping in Python over per-system
    inputs.

    Side effect: each ``mfs[idx]`` is mutated in place (marked converged and
    stamped with the reference orbitals, density matrix and AO values).
    """
    def __init__(self):
        super().__init__()

    def __call__(self, model, aos, ts, hcs, eris, ss, mfs, gcs, gws, inp_dms, mo_occs, mo_coeffs, mo_energies, ogds, refEs, refDMs, refGaps):
        """
        Forward pass for the loss object.

        Every sequence argument is indexed by system; ``len(aos)`` sets the
        number of systems. ``ts``, ``gcs``, ``refEs`` and ``refDMs`` are
        accepted but not read.

        :param model: Callable XC model, invoked as ``model(dm, ao, gw, mf=mf)``
        :return: Sum over systems of ``sqrt((gap_pred - refGap)**2 / 0.01)``;
            the integer 0 when ``aos`` is empty
        """
        total_loss = 0
        for idx in range(len(aos)):
            mo_occ = mo_occs[idx]
            mo_coeff = mo_coeffs[idx]
            mo_energy = mo_energies[idx]
            inp_dm = inp_dms[idx]
            ao = aos[idx]
            gw = gws[idx]
            eri = eris[idx]
            hc = hcs[idx]
            s = ss[idx]
            ogd = ogds[idx]
            refGap = refGaps[idx]

            #for non-local
            mf = mfs[idx]
            #just say it's converged
            mf.converged = True
            #assign reference calculation values to the mean field object
            mf.mo_occ = mo_occ
            mf.mo_coeff = mo_coeff
            # this system's orbital energies, not the whole batch
            mf.mo_energy = mo_energy
            #set flags for use in the RKSAnalyzer
            mf.idm = True
            mf.iao = True
            mf.dm = inp_dm
            mf.ao = ao

            # largest index carrying a non-zero occupation
            homo_i = jnp.max(jnp.nonzero(mo_occ, size=inp_dm.shape[0])[0])
            #vxc function for gradient
            vgf = lambda x: model(x, ao, gw, mf=mf)
            dmp, moep, mocoep = xce.utils.get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)
            print('moep, ', moep)

            print('Getting gap_pred')
            gap_pred = moep[homo_i+1]-moep[homo_i]
            print('gap_pred, refGap: ', gap_pred, refGap)

            # squared gap error scaled by 1/0.01 before the square root
            gap_L = ((gap_pred-refGap)**2)/0.01
            # starting from 0, plain accumulation covers the first iteration too
            total_loss += jnp.sqrt(gap_L)
        return total_loss

class Gap_loss(eqx.Module):
    """
    Squared gap error for a single system.

    Shares the plural-named signature of the loop losses, but every argument
    is used directly as one system's data (no indexing). ``ts``, ``gcs``,
    ``refEs`` and ``refDMs`` are accepted but not read.

    Side effect: ``mfs`` is mutated in place (marked converged and stamped with
    the reference orbitals, density matrix and AO values).
    """
    def __init__(self):
        super().__init__()

    def __call__(self, model, aos, ts, hcs, eris, ss, mfs, gcs, gws, inp_dms, mo_occs, mo_coeffs, mo_energies, ogds, refEs, refDMs, refGaps):
        """Return ``(gap_pred - refGaps)**2`` for the supplied system."""
        # Stamp the mean-field object: "converged" marker, reference orbital
        # data, then the flags/arrays used by the RKSAnalyzer.
        for attr_name, attr_value in (
            ('converged', True),
            ('mo_occ', mo_occs),
            ('mo_coeff', mo_coeffs),
            ('mo_energy', mo_energies),
            ('idm', True),
            ('iao', True),
            ('dm', inp_dms),
            ('ao', aos),
        ):
            setattr(mfs, attr_name, attr_value)

        # largest index carrying a non-zero occupation
        occupied_idx = jnp.nonzero(mo_occs, size=inp_dms.shape[0])[0]
        homo_i = jnp.max(occupied_idx)

        #vxc function for gradient
        def vgf(x):
            return model(x, aos, gws, mf=mfs)

        _, moep, _ = xce.utils.get_dm_moe(inp_dms, eris, vgf, mo_occs, hcs, ss, ogds)
        print('moep, ', moep)

        print('Getting gap_pred')
        gap_pred = moep[homo_i+1]-moep[homo_i]
        print('gap_pred, refGap: ', gap_pred, refGaps)

        gap_residual = gap_pred-refGaps
        return gap_residual**2

class Band_gap_janak_loss(eqx.Module):
    def __init__(self):
        """
        Initializer for the loss module, which attempts to find the loss of band gaps w.r.t. reference

        .. todo: Make more robust for non-local descriptors
        """
        super().__init__()

    def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf=None, alpha0=0.7):
        """
        Forward pass for loss object

        The predicted gap is the difference between the derivatives of the model energy
        with respect to the occupation numbers at indices ``homo_i+1`` and ``homo_i``,
        where the energy is evaluated on the one-shot density matrix from ``xce.utils.get_dm_moe``.

        NOTE (from the original author, not verifiable from this code): described as differing
        from the HoLu loss in that it selects the deepest minimum w.r.t. the LUMO (Fermi energy)

        :param model: The model that will be used in generating the molecular orbital energies ('band' energies)
        :type model: xcquinox.xc.eXC
        :param ao_eval: The atomic orbitals evaluated on the grid for the given molecule
        :type ao_eval: jax.Array
        :param gw: The grid weights associated to the current molecule's grids
        :type gw: jax.Array
        :param dm: Input reference density matrix for use during the one-shot forward pass to generate the new DM
        :type dm: jax.Array
        :param eri: Electron repulsion integrals associated with this molecule
        :type eri: jax.Array
        :param mo_occ: The molecule's molecular orbital occupation numbers
        :type mo_occ: jax.Array
        :param hc: The molecule's core Hamiltonian
        :type hc: jax.Array
        :param s: Overlap-derived matrix forwarded to ``get_dm_moe`` (this notebook supplies the inverse Cholesky factor of the overlap)
        :type s: jax.Array
        :param ogd: The original dimensions of this molecule's density matrix, used if padded to constrict the eigendecomposition to a relevant shape
        :type ogd: jax.Array
        :param refgap: The reference gap to optimize against
        :type refgap: jax.Array
        :param mf: A pyscf(ad) converged calculation kernel if self.level > 3, used for building the CIDER nonlocal descriptors, defaults to None
        :type mf: pyscfad.dft.RKS kernel
        :param alpha0: The mixing parameter for the one-shot density matrix generation, defaults to 0.7
        :type alpha0: float, optional
        :return: Absolute error, ``sqrt((pred_gap - refgap)**2)``, between the predicted and reference gap
        :rtype: jax.Array
        """
        def janak_f(occ):
            # Energy as a function of the occupation numbers only. ``mf`` and
            # ``alpha0`` are captured by name from the enclosing call: passing
            # them positionally previously bound alpha0 to the mf slot, so the
            # model received mf=0.7 and the caller's alpha0 was ignored.
            vgf = lambda x: model(x, ao_eval, gw, mf=mf)
            dmp, moep, mocp = xce.utils.get_dm_moe(dm, eri, vgf, occ, hc, s, ogd, alpha0=alpha0)
            # NOTE(review): the final energy evaluation does not forward mf,
            # unlike vgf above — confirm whether non-local models need it here.
            return model(dmp, ao_eval, gw)

        # largest index carrying a non-zero occupation
        homo_i = jnp.max(jnp.nonzero(mo_occ, size=dm.shape[0])[0])

        e, derivs = eqx.filter_value_and_grad(janak_f)(mo_occ)
        print('derivs, ', derivs)
        pred_diff = derivs[homo_i+1] - derivs[homo_i]

        loss = jnp.sqrt( (pred_diff - refgap)**2)
        # print(loss)
        return loss

class E_tot_loss(eqx.Module):
    """
    Squared error of a one-shot total energy against a reference.

    Side effect: ``mf`` is mutated in place (marked converged and stamped with
    the reference orbitals, density matrix and AO values).
    """
    def __init__(self):
        super().__init__()


    def __call__(self, model, ao, gw, dm, eri, mo_occ, mo_coeffs, moes, hc, s, ogd, mf, vn, proportion, refEn):
        """
        Return ``((etot + exc + vn) - refEn)**2``, where ``etot`` is the first
        element of ``xce.utils.energy_tot()(dmp, hc, veff)``, ``exc`` is the
        model energy on the updated density matrix divided by ``proportion``,
        and ``vn`` is added as supplied by the caller.
        """
        # mark the calculation as converged and attach the reference orbital data
        mf.converged = True
        mf.mo_occ, mf.mo_coeff, mf.mo_energy = mo_occ, mo_coeffs, moes
        # flags and arrays for use in the RKSAnalyzer
        mf.idm = True
        mf.iao = True
        mf.dm, mf.ao = dm, ao

        def vgf(x):
            return model(x, ao, gw, mf=mf)

        # one-shot density-matrix update; only the new density matrix is used
        dmp, _, _ = xce.utils.get_dm_moe(dm, eri, vgf, mo_occ, hc, s, ogd)

        veff = xce.utils.get_veff()(dmp, eri)
        etot = xce.utils.energy_tot()(dmp, hc, veff)[0]

        # model energy divided by the supplied proportion
        exc = model(dmp, ao, gw, mf=mf)/proportion

        e_pred = etot+exc+vn
        print('etot+exc+enuc', etot, exc, vn, e_pred)
        return ( e_pred - refEn )**2


In [6]:
# gpus = jax.devices(backend='gpu')
# CPU devices only; the training cell pins computation to cpus[0] via
# jax.default_device.
cpus = jax.devices(backend='cpu')

In [14]:
#update docs, only input =2 ??? for MGGA? holdover from sebastian for some reason
# Level-3 exchange/correlation networks; these act as the template whose
# leaves are filled from the serialised pretrained model below.
xnet = eX(n_input = 2, n_hidden=16, depth=3, use = [1, 2], ueg_limit=True, lob=1.174, seed=17209)
# I guess use default LOB
cnet = eC(n_input = 4, n_hidden=16, depth=3, use = [2, 3], ueg_limit=True, seed=170920)
blankxc = xce.xc.eXC(grid_models = [xnet, cnet], level=3)
p = '/home/awills/Documents/Research/xcquinox/models/pretrained/scan'
# NOTE(review): the recorded output of this cell is a TreePathError at leaf
# 'nlstart_i' — presumably the saved file and the `blankxc` template disagree
# on that leaf; confirm. When this line raises, the statements below do not
# run, so `nlxc` is only defined if it survives from an earlier execution.
xc = eqx.tree_deserialise_leaves(os.path.join(p, 'xc.eqx'), blankxc)
# xc = blankxc
# Level-4 networks with 15 / 13 inputs and an empty `use` list; no seed is
# passed here, unlike the level-3 networks above.
nlxnet = eX(n_input = 15, depth=8, n_hidden=32, use = [], ueg_limit=True, lob=1.174)
nlcnet = eC(n_input = 13, depth=8, n_hidden=32, use = [], ueg_limit=True)

# This is the model handed to the trainer in the training cell.
nlxc = xce.xc.eXC(grid_models = [nlxnet, nlcnet], level=4)

TreePathError: Error at leaf with path (GetAttrKey(name='nlstart_i'),)

In [8]:
# Display energy_nuc() for `mf`. `mf` is not defined in this cell: it is the
# loop variable left over from the earlier data-loading / padding cells, so
# this inspects whichever mean-field object was iterated last.
mf.energy_nuc()

mesh for ewald [11 11 11]
Ewald components = 0.001985449436657, -16.1635769066223, 2.39777600876165


-13.763815448424035

In [13]:
# Value passed to the trainer as `steps`.
NSTEPS = 100
# lr_scheduler = optax.cosine_onecycle_schedule(transition_steps = NSTEPS, peak_value=5e-1)
# lr_scheduler = optax.exponential_decay(transition_steps = NSTEPS, init_value=1e-1, decay_rate=0.1, transition_begin=NSTEPS//10)
# optimizer = optax.optimistic_gradient_descent(learning_rate=10.0)
# optimizer = optax.noisy_sgd(learning_rate=10.0, seed=92017)
# NOTE(review): a learning rate of 10.0 is unusually large — confirm it is intended.
optimizer = optax.lamb(learning_rate=10.0)
# optimize = optax.sgd(learning_rate = 15.0, momentum=3.3, nesterov=True)
# Train the level-4 model `nlxc` with the total-energy loss; JIT is disabled.
xct = xce.train.xcTrainer(model=nlxc, optim=optimizer, steps=NSTEPS,
                          # loss = Gap_Loop_loss(),
                          # loss = Gap_loss(),
                          # loss = Band_gap_janak_loss(),
                          loss = E_tot_loss(),
                          do_jit=False)
with jax.default_device(cpus[0]):
    # newm = xct(len(mols), xct.model, mols, mfs, gcs, gws, init_dms, rmooccs, ogds, reftotes, rdms, refgaps)
    # newm = xct(1, xct.model, [aos], [ts], [hcs], [eris], [ss], [mfs], [gcs], [gws], [init_dms], [rmooccs], [rmocoeffs], [rmoes], [ogds], [reftotes], [rdms], [refgaps])
    # newm = xct(1, xct.model, aos, ts, hcs, eris, ss, mfs, gcs, gws, init_dms, rmooccs, rmocoeffs, rmoes, ogds, reftotes, rdms, refgaps)
    # newm = xct(1, xct.model, aos, gws, init_dms, eris, rmooccs, hcs, ss, ogds, refgaps, mfs)
    # The arguments after xct.model follow the parameter order of
    # E_tot_loss.__call__ after `model`:
    # (ao, gw, dm, eri, mo_occ, mo_coeffs, moes, hc, s, ogd, mf, vn, proportion, refEn).
    # NOTE(review): the meaning of the leading `1` is defined by xcTrainer,
    # which is not visible here — confirm.
    newm = xct(1, xct.model, aos, gws, init_dms, eris, rmooccs, rmocoeffs, rmoes, hcs, ss, ogds, mfs, vns, proportions, reftotes)
    

Epoch 0
Epoch 0 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.232670550062918 -2.212946734085206 -13.763815448424035 -5.744091632446322
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concret



step=0, epoch_train_loss=179.59894820382027
Epoch 1
Epoch 1 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho sha



Epoch 2
Epoch 2 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 3
Epoch 3 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 4
Epoch 4 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 5
Epoch 5 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 6
Epoch 6 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 7
Epoch 7 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 8
Epoch 8 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 9
Epoch 9 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concrete



Epoch 10
Epoch 10 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concre



Epoch 11
Epoch 11 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concre



Epoch 12
Epoch 12 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc 10.230809441522071 -2.770710693745047 -13.763815448424035 -6.30371670064701
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<Concre



Epoch 13
Epoch 13 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc nan nan -13.763815448424035 nan
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)
etot+exc+enuc Traced<ConcreteArray(nan, dtype=float64)>with<JVPTrace(le



Epoch 14
Epoch 14 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (8574, 15)
eX.__call__, rho shape: (8574, 15)
spin_scaling = False; input descr to exc shape: (8574, 13)
eC.__call__, rho shape: (8574, 13)


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7e6a9d0ca6b0>>
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt



In [None]:
borlido = pd.read_excel('/home/awills/Documents/Research/datasets/borlido2019/gg_direct_gaps.xlsx')

In [None]:
# Load the cached list of Borlido-2019 entries. Downstream cells unpack each
# entry as (dataset index, Materials Project id, object exposing `.structure`).
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# with a trusted, locally generated pickle. The absolute path is machine-specific.
with open('/home/awills/Documents/Research/datasets/borlido2019/ggdirect.pkl', 'rb') as f:
    bstruc_list = pickle.load(f)

In [None]:
smalls = [i for i in bstruc_list if len(i[-1].structure.as_dict()['sites']) <= 4]

In [None]:
smalls[0]

In [None]:
smallids = [i[1] for i in smalls]

In [None]:
smalldf = borlido[borlido['MP_ID'].isin(smallids)]

In [None]:
# Convert every small structure into an ASE Atoms object and store the dataset
# bookkeeping (index, MP id, reference gap) in each Atoms.info dictionary.
traj = []
for small in smalls:
    # Each entry unpacks as (dataset index, Materials Project id, object exposing `.structure`).
    datidx, mpid, bstruc = small
    at = AseAtomsAdaptor.get_atoms(bstruc.structure)
    at.info['borlido_idx'] = datidx
    at.info['mp_id'] = mpid
    # First 'Experimental' value among the smalldf rows whose MP_ID matches;
    # `.values[0]` raises IndexError when no row matches.
    at.info['reference'] = smalldf[smalldf['MP_ID']==mpid]['Experimental'].values[0]
    traj.append(at)

In [None]:
traj[2].info

In [None]:
tt = traj[2]

In [None]:
from pyscf.pbc.tools import pyscf_ase

In [None]:
pyscf_ase.ase_atoms_to_pyscf(tt)

In [None]:
write('/home/awills/Documents/Research/datasets/borlido2019/smalls.traj', traj)

In [None]:
test = smalls[0][-1]

In [None]:
gtop

In [None]:
test.as_dict()['lattice_rec']

In [None]:
at  = AseAtomsAdaptor.get_atoms(test.structure)

In [None]:
mol.pbc

In [None]:
gtop.Cell(a=np.array(at.cell), charge=0).spin

In [None]:
from pyscf.pbc import cc

In [None]:
# List the public names of the periodic coupled-cluster module (the original
# cell held an unfinished attribute access, which is a SyntaxError on Run All).
dir(cc)

In [None]:
# 3x3x3 float32 test volume holding 1..27 in row-major order, with the single
# entry at [0, 2, 0] (which would be 7) replaced by -1.
image = np.arange(1, 28, dtype=np.float32).reshape(3, 3, 3)
image[0, 2, 0] = -1

# 3x3 uint8 mask paired with the volume above.
mask = np.array([[1, 4, 7], [10, 4, 16], [19, 22, 255]], dtype=np.uint8)

In [None]:
# Search the Materials Project summary endpoint for stable, direct-gap materials
# (gap between 0 and 5, at most 4 elements and 4 sites) that have a band
# structure, then download the band structure for every hit.
with MPRester(api_key = '') as mpr:
    docs = mpr.materials.summary.search(
        band_gap=(0, 5),
        is_gap_direct=True,
        is_stable=True,
        num_elements=(0, 4),
        num_sites=(0, 4),
        has_props=[HasProps.bandstructure],
        fields=["material_id"],
    )
    gamma_directs = [
        mpr.get_bandstructure_by_material_id(str(doc.material_id))
        for doc in docs
    ]

In [None]:
t = gamma_directs[10]

In [None]:
t.get_band_gap()

In [None]:
gamgam = [i for i in gamma_directs if i.get_band_gap()['transition'] == '\\Gamma-\\Gamma']

In [None]:
len(borlido_gg_direct[0][-1].structure.as_dict()['sites'])

In [None]:
smalls = [(i[1], i[-1]) for i in borlido_gg_direct if len(i[-1].structure.as_dict()['sites']) < 4]

In [None]:
[i[1].structure.as_dict()['sites'][0]['species'] for i in smalls]

In [None]:
mpr = MPRester(api_key = '')

In [None]:
# mp = mpr.get_bandstructure_by_material_id('mp-1266')
# Entries of `smalls` are tuples whose last element is the band-structure
# object, so take that element rather than the tuple itself (a tuple has no
# `.structure` / `.get_band_gap`).
mp = smalls[0][-1]
# Plain-dict view of the structure, consumed by the lattice/coordinate cell below.
mps = mp.structure.as_dict()
mp.get_band_gap()

In [None]:
rets = mps
# (element, Cartesian xyz) per site, taken directly from the structure dict.
at_coor_xyz = [ (i['species'][0]['element'], i['xyz']) for i in rets['sites']]
# (element, fractional abc scaled by the lattice parameter a) per site.
# NOTE(review): scaling all three fractional components by `a` only gives
# Cartesian positions for a cubic, axis-aligned cell -- confirm this is intended.
at_coor_abc = [ (i['species'][0]['element'], [rets['lattice']['a']*j for j in i['abc']]) for i in rets['sites']]
lat = np.array(rets['lattice']['matrix'])
# A left-handed set of lattice vectors has a negative determinant. Exchanging
# the first and last rows flips the sign; the previous version also negated the
# matrix, which for a 3x3 matrix flips the sign back and left the cell unchanged
# in handedness.
if np.linalg.det(lat) < 0:
    print('left handed array, switching')
    lat = lat[::-1].copy()

In [None]:
np.linalg.det(lat)

In [None]:
traj

In [None]:
at_coor_abc

In [None]:
# Build a periodic cell from the coordinates/lattice prepared above and run a
# restricted Kohn-Sham PBE calculation (no explicit k-point mesh is passed).
cell = gtop.Cell(verbose=9)
cell.atom = at_coor_abc
cell.a = lat
# NOTE(review): 'def2-svp' is paired with the 'gth-pbe' pseudopotential here,
# whereas the other cells in this notebook pair gth bases with gth
# pseudopotentials -- confirm the combination is intended.
cell.basis = 'def2-svp'
cell.pseudo = 'gth-pbe'
cell.exp_to_discard = 0.1
cell.build()
# kpts = cell.make_kpts([2,2,2])
mf = dftp.RKS(cell, xc='pbe')
# mf = scfp.RHF(cell)
# `e` holds whatever mf.kernel() returns (presumably the converged total energy).
e = mf.kernel()

In [None]:
mf.mo_occ, mf.mo_energy

In [None]:
mf.mo_energy - mf.mo_energy[mf.mo_occ == 0][0], mf.mo_energy - mf.mo_energy[mf.mol.nelectron//2-1], mf.mo_energy, mf.mo_occ

In [None]:
a1 = bulk('Si', a=3.867114, b=3.867114, c=3.867114, alpha=60)

In [None]:
a1.cell

In [None]:
gtop.Cell?

In [None]:
gtop

In [None]:
# Reset the per-system accumulators that the collection cell fills in.
mfs = []
mols = []
energies = []
dms = []
ao_evals = []
gws = []
eris = []
mo_occs = []
hcs = []
vs = []
ts = []
ss = []
hologaps = []
ogds = []

# Two-atom silicon cell: atoms at the origin and at (a/4, a/4, a/4), with
# lattice vectors given as the rows of the matrix below.
cell = gtop.Cell()
a = 5.43
cell.atom = [['Si', [0,0,0]],
              ['Si', [a/4,a/4,a/4]]]
cell.a = jnp.asarray([[0, a/2, a/2],
                     [a/2, 0, a/2],
                     [a/2, a/2, 0]])
cell.basis = 'gth-szv'
cell.pseudo = 'gth-pade'
cell.exp_to_discard = 0.1
cell.build()
kpts = cell.make_kpts([2,2,2])
# `mf` (no k-point argument) is the calculation that is actually run; `mf2`
# (2x2x2 k-points) is only constructed because its kernel call is commented out.
mf = dftp.RKS(cell, xc='pbe0')
mf2 = dftp.KRKS(cell, xc='pbe0', kpts=kpts)
e = mf.kernel()
# e2 = mf2.kernel()

In [None]:
mf

In [None]:
# Reset the per-system accumulators, then append the arrays extracted from the
# calculation object `mf` defined in an earlier cell.
mfs = []
mols = []
energies = []
dms = []
ao_evals = []
gws = []
eris = []
mo_occs = []
hcs = []
vs = []
ts = []
ss = []
hologaps = []
ogds = []

mfs.append(mf)
dm = mf.make_rdm1()
dmj = jnp.array(dm)
# NOTE(review): this assigns an attribute on a JAX array; confirm that JAX
# arrays accept attribute assignment and that `dm` exposes `.flags` here.
dmj.flags = dm.flags
# Atomic orbitals evaluated on the integration grid, requested with deriv=2.
ao_eval = jnp.array(mf._numint.eval_ao(mf.mol, mf.grids.coords, deriv=2))
energies.append(jnp.array(mf.get_veff().exc))
dms.append(dmj)
# Shape of the density matrix as returned by make_rdm1().
ogds.append(dm.shape)
ao_evals.append(jnp.array(ao_eval))
gws.append(jnp.array(mf.grids.weights))
# NOTE(review): the integrals below come from mf.mol.intor(...); `mf` was built
# from a periodic cell with a pseudopotential -- confirm these are the intended
# integrals for that case.
ts.append(jnp.array(mf.mol.intor('int1e_kin')))
vs.append(jnp.array(mf.mol.intor('int1e_nuc')))
mo_occs.append(jnp.array(mf.mo_occ))
hcs.append(jnp.array(mf.get_hcore()))
eris.append(jnp.array(mf.mol.intor('int2e')))
# Stores the inverse of the Cholesky factor of the overlap matrix, not the overlap itself.
ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mf.mol.intor('int1e_ovlp'))))
# First energy with zero occupation minus the last energy with occupation > 1.
hologaps.append(jnp.array(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1]))

In [None]:
class Band_gap_1shot_loss(eqx.Module):
    """
    Loss comparing a one-shot orbital-energy "gap" against a reference value.

    The orbital energies come from a single call to ``xce.utils.get_dm_moe``
    using the supplied density matrix and the model as the potential generator.
    """
    def __init__(self):
        """
        Initializer for the loss module; the module holds no state of its own.
        """
        super().__init__()

    def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7):
        """
        Forward pass for the loss object.

        :param model: Callable invoked as ``model(dm, ao_eval, gw, mf)``
        :param ao_eval: Atomic orbitals evaluated on the grid, forwarded to the model
        :param gw: Grid weights, forwarded to the model
        :param dm: Input density matrix for the one-shot update
        :param eri: Forwarded to ``xce.utils.get_dm_moe``
        :param mo_occ: Occupation numbers, forwarded to ``xce.utils.get_dm_moe``
        :param hc: Forwarded to ``xce.utils.get_dm_moe``
        :param s: Forwarded to ``xce.utils.get_dm_moe``
        :param ogd: Forwarded to ``xce.utils.get_dm_moe``
        :param refgap: Reference gap the prediction is compared against
        :param mf: Calculation object; forwarded to the model and used for ``mf.mol.nelectron``
        :param alpha0: Mixing parameter forwarded to ``xce.utils.get_dm_moe``, defaults to 0.7
        :type alpha0: float, optional
        :return: Absolute difference between the predicted gap and ``refgap``
        """
        vgf = lambda x: model(x, ao_eval, gw, mf)
        _, moep, _ = xce.utils.get_dm_moe(dm, eri, vgf, mo_occ, hc, s, ogd, alpha0)

        # Shift the energies so that the level at index nelectron//2 - 1 sits at zero.
        efermi = moep[mf.mol.nelectron//2-1]
        moep = moep - efermi
        # NOTE(review): the minimum of the shifted energies is the lowest level
        # relative to that reference (<= 0), not a LUMO-HOMO difference --
        # confirm this is the intended definition of the gap.
        moep_gap = jnp.min(moep)
        # jnp.abs has the same value as sqrt(x**2) but a finite gradient when
        # the residual is exactly zero, where sqrt(x**2) differentiates to NaN.
        return jnp.abs(moep_gap - refgap)


In [None]:
xce.net.eX?

In [None]:
# TODO: update the docs -- the exchange network is built with n_input=2 even
# for the meta-GGA case (apparently a holdover from Sebastian's code); confirm why.
xnet = xce.net.eX(
    n_input=2,
    n_hidden=32,
    depth=4,
    use=[1, 2],
    ueg_limit=True,
    lob=1.174,
)
# The correlation network is left on its default LOB setting.
cnet = xce.net.eC(
    n_input=4,
    n_hidden=32,
    depth=4,
    use=[2, 3],
    ueg_limit=True,
)
blankxc = xce.xc.eXC(grid_models=[xnet, cnet], level=3)
p = '/home/awills/Documents/Research/xcquinox/models/pretrained/scan'
# xc = eqx.tree_deserialise_leaves(os.path.join(p, 'xc.eqx'), blankxc)
# The pretrained weights are not loaded; the freshly initialised model is used as-is.
xc = blankxc

# Level-4 variant: 15 exchange inputs and 13 correlation inputs with an empty `use` list.
nlxnet = xce.net.eX(n_input=15, use=[], ueg_limit=True, lob=1.174)
nlcnet = xce.net.eC(n_input=13, use=[], ueg_limit=True)

nlxc = xce.xc.eXC(grid_models=[nlxnet, nlcnet], level=4)

In [None]:

xc(dms[0], ao_evals[0], gws[0])

In [None]:
class Band_gap_janak_loss(eqx.Module):
    def __init__(self):
        """
        Initializer for the loss module, which compares a band gap predicted from
        occupation-number derivatives of the model output against a reference gap

        .. todo: Make more robust for non-local descriptors
        """
        super().__init__()

    def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7):
        """
        Forward pass for loss object

        The model output, evaluated at the one-shot density matrix returned by
        ``xce.utils.get_dm_moe``, is differentiated with respect to the occupation
        numbers. The predicted gap is the derivative at index HOMO+1 minus the
        derivative at index HOMO, where HOMO is the largest index with non-zero occupation.

        :param model: The model, invoked here as ``model(dm, ao_eval, gw)``
        :type model: xcquinox.xc.eXC
        :param ao_eval: The atomic orbitals evaluated on the grid for the given molecule
        :type ao_eval: jax.Array
        :param gw: The grid weights associated to the current molecule's grids
        :type gw: jax.Array
        :param dm: Input reference density matrix for use during the one-shot forward pass to generate the new DM
        :type dm: jax.Array
        :param eri: Electron repulsion integrals associated with this molecule
        :type eri: jax.Array
        :param mo_occ: The molecule's molecular orbital occupation numbers; the derivative is taken with respect to these
        :type mo_occ: jax.Array
        :param hc: The molecule's core Hamiltonian
        :type hc: jax.Array
        :param s: Forwarded to ``xce.utils.get_dm_moe`` in the overlap-matrix slot
        :type s: jax.Array
        :param ogd: The original dimensions of this molecule's density matrix, forwarded to ``xce.utils.get_dm_moe``
        :type ogd: jax.Array
        :param refgap: The reference gap to optimize against
        :type refgap: jax.Array
        :param mf: Accepted for signature compatibility with the other losses; it is NOT forwarded to the model here
        :type mf: pyscfad.dft.RKS kernel
        :param alpha0: The mixing parameter for the one-shot density matrix generation, defaults to 0.7
        :type alpha0: float, optional
        :return: Absolute error between the predicted gap (derivative at HOMO+1 minus derivative at HOMO) and the reference
        :rtype: jax.Array
        """
        def model_output_from_occ(occ):
            # One-shot density matrix for the trial occupations, then the model output on it.
            vgf = lambda x: model(x, ao_eval, gw)
            dmp, _, _ = xce.utils.get_dm_moe(dm, eri, vgf, occ, hc, s, ogd, alpha0=alpha0)
            return model(dmp, ao_eval, gw)

        # Largest index with non-zero occupation. `size` fixes the output length
        # (padding entries are index 0 by default, which cannot exceed the maximum).
        homo_i = jnp.max(jnp.nonzero(mo_occ, size=dm.shape[0])[0])

        # Only the gradient with respect to the occupations is needed.
        _, derivs = eqx.filter_value_and_grad(model_output_from_occ)(mo_occ)

        pred_diff = derivs[homo_i+1] - derivs[homo_i]

        # jnp.abs has the same value as sqrt(x**2) but a finite gradient when
        # the residual is exactly zero, where sqrt(x**2) differentiates to NaN.
        return jnp.abs(pred_diff - refgap)


In [None]:
def janak_theorem_deriv(model, ao_eval, gw, dm, eri, moocc, hc, s, ogd, alpha0=0.7):
    """
    Build a function of the occupation numbers suitable for differentiation.

    The returned callable takes occupation numbers, obtains a one-shot density
    matrix from ``xce.utils.get_dm_moe`` (with ``model`` as the potential
    generator) and returns ``model`` evaluated on that density matrix.

    :param model: Callable invoked as ``model(dm, ao_eval, gw)``
    :param ao_eval: Forwarded to ``model``
    :param gw: Forwarded to ``model``
    :param dm: Input density matrix for the one-shot update
    :param eri: Forwarded to ``xce.utils.get_dm_moe``
    :param moocc: Unused; the occupations are supplied to the returned callable instead
    :param hc: Forwarded to ``xce.utils.get_dm_moe``
    :param s: Forwarded to ``xce.utils.get_dm_moe``
    :param ogd: Forwarded to ``xce.utils.get_dm_moe``
    :param alpha0: Mixing parameter forwarded to ``xce.utils.get_dm_moe``, defaults to 0.7
    :return: Callable mapping occupation numbers to the model output
    """
    def occupations_to_output(mo_occ):
        def potential_fn(trial_dm):
            return model(trial_dm, ao_eval, gw)

        updated_dm, _energies, _occupations = xce.utils.get_dm_moe(
            dm, eri, potential_fn, mo_occ, hc, s, ogd, alpha0=alpha0
        )
        return model(updated_dm, ao_eval, gw)

    return occupations_to_output
    
    
    

In [None]:
checkd = janak_theorem_deriv(xc, ao_evals[0], gws[0], dms[0], eris[0], mo_occs[0], hcs[0], ss[0], ogds[0], alpha0=0.7)
eqx.filter_value_and_grad(checkd)(mo_occs[0])

In [None]:
xct = xce.train.xcTrainer(model=xc, optim=optax.adamw(1e-2), steps=100, loss = Band_gap_janak_loss(), do_jit=True)
newm = xct(1, xct.model, ao_evals, gws, dms, eris, mo_occs, hcs, ss, ogds, [1.17], mfs)

In [None]:
e1 = nlxc(dms[0], ao_evals[0], gws[0], mfs[0])
e2 = newm(dms[0], ao_evals[0], gws[0], mfs[0])

In [None]:
e1, e2

In [None]:
vgf1 = lambda x: xc(x, ao_evals[0], gws[0], mfs[0])
vgf2 = lambda x: newm(x, ao_evals[0], gws[0], mfs[0])
dm1, moe1, moc1 = xce.utils.get_dm_moe(dms[0], eris[0], vgf1, mo_occs[0], hcs[0], ss[0], ogds[0])
dm2, moe2, moc2 = xce.utils.get_dm_moe(dms[0], eris[0], vgf2, mo_occs[0], hcs[0], ss[0], ogds[0])

In [None]:
print(moe1 - moe1[mf.mol.nelectron//2-1])
print(moe2 - moe2[mf.mol.nelectron//2-1])

In [None]:
mf.mo_energy

In [None]:
# Find the largest value, over the entries of b1[0], of the energy at index
# nelectron//2 - 1, then reference every entry to it.
# NOTE(review): `b1` is not defined in the visible cells; b1[0] is presumably a
# sequence of per-k-point band-energy arrays -- confirm.
# Start below any possible energy so the first entry always wins; previously the
# initialisation was commented out and the cell raised NameError on a fresh kernel.
vbmax = -np.inf
for en in b1[0]:
    vb_k = en[cell.nelectron//2-1]
    print('This vb_k', vb_k)
    if vb_k > vbmax:
        vbmax = vb_k
e_kn = [en - vbmax for en in b1[0]]

In [None]:
e_kn

In [None]:
mf2 = scfp.RHF(cell)
e2 = mf2.kernel()

In [None]:
cell.nelectron//2-1

In [None]:
t1 = mf2.mo_energy 
t2 = mf2.mo_energy - mf2.mo_energy[cell.nelectron//2-1]

In [None]:
# Shifted energies lying more than 1e-4 below the reference level. A direct
# boolean mask avoids the earlier form, which computed indices on the filtered
# negative-only sub-array and then applied them to the full array (correct only
# when all negative entries come first).
t2[t2 < -1e-4]

In [None]:
dm2 = mf2.make_rdm1()

In [None]:
dmk = mf.make_rdm1()

In [None]:
mpr = MPRester(api_key = '')
mpid = 'mp-149'
# ret = mpr.get_bandstructure_by_material_id(mpid)
# rets = ret.structure.as_dict()
ret = mpr.get_structure_by_material_id(mpid, conventional_unit_cell=False)
rets = ret.as_dict()
at_coor_xyz = [ (i['species'][0]['element'], [-j for j in i['xyz']]) for i in rets['sites']]
at_coor_abc = [ (i['species'][0]['element'], [rets['lattice']['a']*j for j in i['abc']]) for i in rets['sites']]
cella = -np.asarray(rets['lattice']['matrix'])

In [None]:
at_coor_xyz, cella

In [None]:
# Periodic cell built from the sign-flipped Cartesian coordinates and lattice
# prepared above, followed by a restricted Hartree-Fock run on a 2x2x2 k-point mesh.
cell = gtop.Cell()
cell.atom = at_coor_xyz
# Cell.a expects one lattice vector per row, which is also how the pymatgen
# lattice matrix stored in `cella` is laid out, so it is passed without a
# transpose (transposing is only harmless for a symmetric matrix).
cell.a = cella
cell.basis = 'gth-szv'
cell.pseudo = 'gth-pade'
cell.exp_to_discard = 0.1
cell.build()
kpts = cell.make_kpts([2,2,2])
mf = scfp.KRHF(cell, kpts=kpts)
# `e` holds whatever mf.kernel() returns (presumably the converged total energy).
e = mf.kernel()

In [None]:
cella