In [1]:
import os, sys

# XLA reads these allocator settings when the JAX backend is initialised, so they
# have to be in the environment before jax (or any package that may import jax and
# touch the backend at import time) is loaded. Setting them after the imports, as
# this cell previously did, can leave them without effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

from ase import Atoms
from ase.io import read
import xcquinox as xce
import torch, jax, optax
import numpy as np
import equinox as eqx
import jax.numpy as jnp
import pyscfad as psa
from pyscf import dft, scf, gto, df
from pyscf.pbc import scf as scfp
from pyscf.pbc import gto as gtop
from pyscf.pbc import dft as dftp

from mp_api.client import MPRester
from mldftdat.density import get_exchange_descriptors2
from mldftdat.analyzers import RKSAnalyzer

from ase.build import bulk

CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
mpr = MPRester(api_key = '')
si = mpr.get_bandstructure_by_material_id('mp-149')
c = mpr.get_bandstructure_by_material_id('mp-66')

Retrieving ElectronicStructureDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving ElectronicStructureDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
si.get_band_gap()

{'direct': False,
 'energy': 0.6105,
 'transition': '\\Gamma-(0.413,0.000,0.413)'}

In [4]:
# Serialise the silicon structure, then take its lattice matrix with the row
# order reversed and the sign of every entry flipped.
sisd = si.structure.as_dict()
si_lat = -np.array(sisd['lattice']['matrix'])[[2, 1, 0]]
si_lat

array([[-0.      ,  2.734463,  2.734463],
       [ 2.734463, -0.      ,  2.734463],
       [ 2.734463,  2.734463, -0.      ]])

In [5]:
mp = mpr.get_bandstructure_by_material_id('mp-984')
mps = mp.structure.as_dict()

Retrieving ElectronicStructureDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

In [6]:
mps

{'@module': 'pymatgen.core.structure',
 '@class': 'Structure',
 'charge': 0,
 'lattice': {'matrix': [[-1.256785, -2.176028, 0.0],
   [-1.256785, 2.176028, 0.0],
   [0.0, 0.0, -6.804991]],
  'pbc': (True, True, True),
  'a': 2.512888058193003,
  'b': 2.512888058193003,
  'c': 6.804991,
  'alpha': 90.0,
  'beta': 90.0,
  'gamma': 119.98204498301234,
  'volume': 37.2205699268395},
 'properties': {},
 'sites': [{'species': [{'element': 'B', 'occu': 1}],
   'abc': [0.666495, 0.333505, 0.75],
   'xyz': [-1.256785, -0.72459556372, -5.10374325],
   'properties': {},
   'label': 'B'},
  {'species': [{'element': 'B', 'occu': 1}],
   'abc': [0.333505, 0.666495, 0.25],
   'xyz': [-1.256785, 0.72459556372, -1.70124775],
   'properties': {},
   'label': 'B'},
  {'species': [{'element': 'N', 'occu': 1}],
   'abc': [0.666502, 0.333498, 0.25],
   'xyz': [-1.256785, -0.724626028112, -1.70124775],
   'properties': {},
   'label': 'N'},
  {'species': [{'element': 'N', 'occu': 1}],
   'abc': [0.333498, 0.6

In [7]:
rets = mps
# (element symbol, Cartesian position) for every site, as stored in the structure dict.
at_coor_xyz = [ (i['species'][0]['element'], i['xyz']) for i in rets['sites']]
# (element symbol, fractional coordinates scaled by the lattice parameter `a`).
# NOTE(review): scaling every fractional component by `a` is a Cartesian position
# only for an axis-aligned cubic cell — confirm the intent before using this as `cell.atom`.
at_coor_abc = [ (i['species'][0]['element'], [rets['lattice']['a']*j for j in i['abc']]) for i in rets['sites']]
lat = np.array(rets['lattice']['matrix'])
# A left-handed set of lattice vectors has a *negative* determinant, so test the
# sign (comparing against 1 would instead flag any cell with a volume below 1).
if np.linalg.det(lat) < 0:
    print('left handed array, switching')
    # Exchanging the first and last lattice vectors flips the sign of the
    # determinant while describing the same lattice. Negating the matrix as well
    # would flip the sign back (det(-A) = -det(A) for a 3x3 matrix) and leave
    # the handedness unchanged.
    # NOTE(review): the Cartesian positions in at_coor_xyz stay valid; anything
    # derived from the fractional 'abc' order would need the same permutation.
    lat = lat[[2, 1, 0]]

In [8]:
np.linalg.det(lat)

37.2205699268395

In [9]:
lat

array([[-1.256785, -2.176028,  0.      ],
       [-1.256785,  2.176028,  0.      ],
       [ 0.      ,  0.      , -6.804991]])

In [10]:
at_coor_abc

[('B', [1.6748273263453455, 0.8380607318476575, 1.8846660436447524]),
 ('B', [0.8380607318476575, 1.6748273263453455, 0.6282220145482508]),
 ('N', [1.674844916561753, 0.8380431416312502, 0.6282220145482508]),
 ('N', [0.8380431416312502, 1.674844916561753, 1.8846660436447524])]

In [11]:
mp.get_band_gap()

{'direct': False,
 'energy': 4.273999999999999,
 'transition': '(0.304,0.304,0.000)-M'}

In [12]:
# cell = gtop.Cell()
# cell.atom = at_coor_abc
# cell.a = lat
# cell.basis = 'gth-szv'
# cell.pseudo = 'gth-pade'
# cell.exp_to_discard = 0.1
# cell.build()
# kpts = cell.make_kpts([2,2,2])
# mf = scfp.RHF(cell)
# e = mf.kernel()

In [13]:
# mf.mo_energy - mf.mo_energy[mf.mo_occ == 0][0], mf.mo_energy - mf.mo_energy[mf.mol.nelectron//2-1], mf.mo_energy, mf.mo_occ

In [14]:
cisd = c.structure.as_dict()

In [15]:
sisd, cisd

({'@module': 'pymatgen.core.structure',
  '@class': 'Structure',
  'charge': 0,
  'lattice': {'matrix': [[-2.734463, -2.734463, 0.0],
    [-2.734463, 0.0, -2.734463],
    [0.0, -2.734463, -2.734463]],
   'pbc': (True, True, True),
   'a': 3.8671146604074202,
   'b': 3.8671146604074202,
   'c': 3.8671146604074202,
   'alpha': 59.99999999999999,
   'beta': 59.99999999999999,
   'gamma': 59.99999999999999,
   'volume': 40.89273419687557},
  'properties': {},
  'sites': [{'species': [{'element': 'Si', 'occu': 1}],
    'abc': [0.75, 0.75, 0.75],
    'xyz': [-4.1016945, -4.1016945, -4.1016945],
    'properties': {},
    'label': 'Si'},
   {'species': [{'element': 'Si', 'occu': 1}],
    'abc': [0.0, 0.0, 0.0],
    'xyz': [0.0, 0.0, 0.0],
    'properties': {},
    'label': 'Si'}]},
 {'@module': 'pymatgen.core.structure',
  '@class': 'Structure',
  'charge': 0,
  'lattice': {'matrix': [[-1.786855, -1.786855, 0.0],
    [-1.786855, 0.0, -1.786855],
    [0.0, -1.786855, -1.786855]],
   'pbc': (Tru

In [16]:
a1 = bulk('Si', a=3.867114, b=3.867114, c=3.867114, alpha=60)

In [17]:
a1.cell

Cell([[0.0, 1.933557, 1.933557], [1.933557, 0.0, 1.933557], [1.933557, 1.933557, 0.0]])

In [18]:
# Two-atom silicon cell: atoms at the origin and at (a/4, a/4, a/4), with lattice
# vectors (0, a/2, a/2), (a/2, 0, a/2), (a/2, a/2, 0).
# The accumulator lists (mfs, dms, eris, ...) are created in the cell that fills
# them, so they are not initialised a second time here.
cell = gtop.Cell()
a = 5.43  # lattice constant, in the length unit the Cell is built with
cell.atom = [['Si', [0, 0, 0]],
             ['Si', [a/4, a/4, a/4]]]
# One lattice vector per row. Built with NumPy because the value is handed to
# PySCF, whose Cell setup is NumPy-based; a JAX array adds nothing here and its
# precision depends on the JAX x64 setting.
cell.a = np.array([[0, a/2, a/2],
                   [a/2, 0, a/2],
                   [a/2, a/2, 0]])
cell.basis = 'gth-szv'
cell.pseudo = 'gth-pade'
cell.exp_to_discard = 0.1
cell.build()
kpts = cell.make_kpts([2,2,2])
# Single-k-point calculation used for the training data below.
mf = dftp.RKS(cell, xc='pbe0')
# k-point sampled counterpart; constructed but not converged (kernel call is disabled).
mf2 = dftp.KRKS(cell, xc='pbe0', kpts=kpts)
e = mf.kernel()
# e2 = mf2.kernel()

<class 'pyscf.pbc.dft.rks.RKS'> does not have attributes  nlc nlcgrids


converged SCF energy = -7.21836790602506


In [19]:
mf

RKS-KohnShamDFT object of <class 'pyscf.pbc.dft.rks.RKS'>

In [20]:
# Per-system accumulators handed to the trainer / losses further down.
# Each list receives one entry per converged calculation (only `mf` here).
mfs = []
mols = []
energies = []
dms = []
ao_evals = []
gws = []
eris = []
mo_occs = []
hcs = []
vs = []
ts = []
ss = []
hologaps = []
ogds = []

mfs.append(mf)
# Density matrix of the converged calculation, plus a JAX copy of it.
dm = mf.make_rdm1()
dmj = jnp.array(dm)
# Copies the NumPy `flags` attribute onto the JAX copy.
# NOTE(review): presumably something downstream reads `dm.flags` — confirm it is needed.
dmj.flags = dm.flags
# Atomic orbitals evaluated on the integration grid with deriv=2
# (presumably values plus first and second derivatives — TODO confirm layout).
ao_eval = jnp.array(mf._numint.eval_ao(mf.mol, mf.grids.coords, deriv=2))
# `exc` attribute of the effective potential returned by the converged calculation.
energies.append(jnp.array(mf.get_veff().exc))
dms.append(dmj)
# Original shape of the density matrix.
ogds.append(dm.shape)
ao_evals.append(jnp.array(ao_eval))
gws.append(jnp.array(mf.grids.weights))
# NOTE(review): `mf.mol` is the periodic Cell built above, while `hcs` comes from
# `mf.get_hcore()`. Verify that `intor('int1e_kin' / 'int1e_nuc' / 'int2e' /
# 'int1e_ovlp')` returns the integrals intended for a periodic system (PySCF also
# offers `pbc_intor` / `mf.get_ovlp()`), so that these arrays are mutually consistent.
ts.append(jnp.array(mf.mol.intor('int1e_kin')))
vs.append(jnp.array(mf.mol.intor('int1e_nuc')))
mo_occs.append(jnp.array(mf.mo_occ))
hcs.append(jnp.array(mf.get_hcore()))
eris.append(jnp.array(mf.mol.intor('int2e')))
# Note: `ss` holds the inverse of the Cholesky factor of the overlap integrals,
# not the overlap matrix itself.
ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mf.mol.intor('int1e_ovlp'))))
# Gap between the first orbital with zero occupation and the last orbital whose
# occupation is greater than 1.
hologaps.append(jnp.array(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1]))

In [21]:
class Band_gap_1shot_loss(eqx.Module):
    """Loss comparing a gap estimate from one-shot orbital energies with a reference gap."""

    def __init__(self):
        super().__init__()

    def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7):
        """
        Evaluate the loss for one system.

        The model is wrapped as ``lambda x: model(x, ao_eval, gw, mf)`` and passed,
        together with ``dm``, ``eri``, ``mo_occ``, ``hc``, ``s``, ``ogd`` and ``alpha0``,
        to ``xce.utils.get_dm_moe``. The orbital energies it returns are shifted so that
        the entry at index ``mf.mol.nelectron//2 - 1`` is zero, and the minimum of the
        shifted energies is taken as the predicted gap.

        NOTE(review): the reference level itself is part of the shifted array, so the
        minimum is never positive — confirm that this is the intended gap definition
        and that ``refgap`` uses the same sign convention and energy unit.

        :param model: Callable invoked as ``model(dm, ao_eval, gw, mf)``
        :param ao_eval: Forwarded to ``model``
        :param gw: Forwarded to ``model``
        :param dm: Input density matrix for ``get_dm_moe``
        :param eri: Forwarded to ``get_dm_moe``
        :param mo_occ: Occupation numbers, forwarded to ``get_dm_moe``
        :param hc: Forwarded to ``get_dm_moe``
        :param s: Forwarded to ``get_dm_moe``
        :param ogd: Forwarded to ``get_dm_moe``
        :param refgap: Reference value the predicted gap is compared with
        :param mf: Calculation object; forwarded to ``model`` and used for ``mf.mol.nelectron``
        :param alpha0: Forwarded to ``get_dm_moe``, defaults to 0.7
        :return: Absolute difference between the predicted gap and ``refgap``
        """
        vgf = lambda x: model(x, ao_eval, gw, mf)
        dmp, moep, mocp = xce.utils.get_dm_moe(dm, eri, vgf, mo_occ, hc, s, ogd, alpha0)

        # Measure every orbital energy from the level at index nelectron//2 - 1.
        efermi = moep[mf.mol.nelectron//2-1]
        moep = moep - efermi
        moep_gap = jnp.min(moep)
        # |x| has the same value as sqrt(x**2) but a finite gradient at x == 0;
        # differentiating sqrt(x**2) there yields NaN (inf * 0).
        return jnp.abs(moep_gap - refgap)


In [43]:
xce.net.eX?

[0;31mInit signature:[0m
[0mxce[0m[0;34m.[0m[0mnet[0m[0;34m.[0m[0meX[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mn_input[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mn_hidden[0m[0;34m=[0m[0;36m16[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdepth[0m[0;34m=[0m[0;36m3[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0muse[0m[0;34m=[0m[0;34m[[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mueg_limit[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlob[0m[0;34m=[0m[0;36m1.804[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mseed[0m[0;34m=[0m[0;36m92017[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m      eX(n_input, n_hidden=16, depth=3, use=[], ueg_limit=False, lob=1.804, seed=92017)
[0;31mInit docstring:[0m
__init__ Local exchange model based on MLP.

Receives density descriptors in this order : [rho, s, alpha, nl], where the input may be truncated depending on XC-level of a

In [44]:
# TODO: update the xcquinox docs. The exchange network is built with n_input=2 while
# `use` selects descriptor indices 1 and 2 (the eX docstring orders the descriptors as
# [rho, s, alpha, nl]); it is unclear why n_input is 2 for a meta-GGA-level model —
# apparently inherited from earlier code.
xnet = xce.net.eX(n_input = 2, n_hidden=32, depth=4, use = [1, 2], ueg_limit=True, lob=1.174)
# No `lob` is passed to the correlation network, so it keeps its default.
cnet = xce.net.eC(n_input = 4, n_hidden=32, depth=4, use = [2, 3], ueg_limit=True)
# Freshly initialised level-3 functional combining the two networks.
blankxc = xce.xc.eXC(grid_models = [xnet, cnet], level=3)
# Location of pretrained weights; only referenced by the disabled load on the next line.
p = '/home/awills/Documents/Research/xcquinox/models/pretrained/scan'
# xc = eqx.tree_deserialise_leaves(os.path.join(p, 'xc.eqx'), blankxc)
# With the load disabled, `xc` is the freshly initialised functional.
xc = blankxc
# Networks for the level-4 functional: 15 and 13 inputs, no descriptor sub-selection.
nlxnet = xce.net.eX(n_input = 15, use = [], ueg_limit=True, lob=1.174)
nlcnet = xce.net.eC(n_input = 13, use = [], ueg_limit=True)

nlxc = xce.xc.eXC(grid_models = [nlxnet, nlcnet], level=4)



In [45]:

xc(dms[0], ao_evals[0], gws[0])

spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)


Array(-2.51022075, dtype=float64)

In [24]:
class Band_gap_janak_loss(eqx.Module):
    def __init__(self):
        """
        Initializer for the loss module, which measures the error of a predicted band gap w.r.t. a reference

        .. todo: Make more robust for non-local descriptors
        """
        super().__init__()

    def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7):
        """
        Forward pass for loss object

        The model output, evaluated on the one-shot density matrix produced by
        ``xce.utils.get_dm_moe``, is differentiated with respect to the occupation
        numbers. The predicted gap is the derivative for the orbital directly above
        the highest occupied one minus the derivative for the highest occupied orbital.

        NOTE(review): only the model output is differentiated here, not a total
        energy — confirm that this is the quantity intended for the gap estimate and
        that ``refgap`` is given in the same energy unit.

        :param model: The model that is evaluated and differentiated, called as ``model(dm, ao_eval, gw)``
        :type model: xcquinox.xc.eXC
        :param ao_eval: The atomic orbitals evaluated on the grid for the given system
        :type ao_eval: jax.Array
        :param gw: The grid weights associated to the current system's grids
        :type gw: jax.Array
        :param dm: Input reference density matrix for use during the one-shot forward pass to generate the new DM
        :type dm: jax.Array
        :param eri: Electron repulsion integrals associated with this system
        :type eri: jax.Array
        :param mo_occ: The orbital occupation numbers; the derivative is taken with respect to these
        :type mo_occ: jax.Array
        :param hc: The system's core Hamiltonian
        :type hc: jax.Array
        :param s: Overlap-derived matrix forwarded to ``get_dm_moe``
        :type s: jax.Array
        :param ogd: The original dimensions of this system's density matrix, forwarded to ``get_dm_moe``
        :type ogd: jax.Array
        :param refgap: The reference gap to optimize against
        :type refgap: jax.Array
        :param mf: Not used by this loss; kept so the call signature matches Band_gap_1shot_loss
        :type mf: pyscfad.dft.RKS kernel
        :param alpha0: The mixing parameter for the one-shot density matrix generation, defaults to 0.7
        :type alpha0: float, optional
        :return: Absolute error between the predicted gap (difference of occupation derivatives) and the reference
        :rtype: jax.Array
        """
        def janak_theorem_deriv(model, ao_eval, gw, dm, eri, hc, s, ogd, alpha0=0.7):
            def ret_func(mo_occ):
                vgf = lambda x: model(x, ao_eval, gw)
                dmp, moep, mocp = xce.utils.get_dm_moe(dm, eri, vgf, mo_occ, hc, s, ogd, alpha0=alpha0)
                return model(dmp, ao_eval, gw)

            return ret_func

        janak_f = janak_theorem_deriv(model, ao_eval, gw, dm, eri, hc, s, ogd, alpha0)
        # Index of the last orbital with non-zero occupation. `size` fixes the length
        # of the index array; unused slots are filled with 0 and do not affect the max.
        homo_i = jnp.max(jnp.nonzero(mo_occ, size=dm.shape[0])[0])

        # Only the derivatives are needed; the function value is discarded.
        _, derivs = eqx.filter_value_and_grad(janak_f)(mo_occ)

        pred_diff = derivs[homo_i+1] - derivs[homo_i]

        # |x| has the same value as sqrt(x**2) but a finite gradient at x == 0;
        # differentiating sqrt(x**2) there yields NaN (inf * 0).
        loss = jnp.abs(pred_diff - refgap)
        return loss


In [25]:
def janak_theorem_deriv(model, ao_eval, gw, dm, eri, moocc, hc, s, ogd, alpha0=0.7):
    """
    Build a function of the occupation numbers that can be differentiated.

    All arguments are captured by the returned closure. ``moocc`` is accepted but
    not used; the occupations are supplied when the returned function is called.

    :return: Callable taking ``mo_occ``. It runs ``xce.utils.get_dm_moe`` with the
        captured inputs and returns ``model`` evaluated on the resulting density matrix.
    """
    def evaluate_model(density):
        # The model always sees the same grid data; only the density matrix varies.
        return model(density, ao_eval, gw)

    def ret_func(mo_occ):
        new_dm, _orbital_energies, _orbital_coeffs = xce.utils.get_dm_moe(
            dm, eri, evaluate_model, mo_occ, hc, s, ogd, alpha0=alpha0
        )
        return evaluate_model(new_dm)

    return ret_func
    
    
    

In [26]:
checkd = janak_theorem_deriv(xc, ao_evals[0], gws[0], dms[0], eris[0], mo_occs[0], hcs[0], ss[0], ogds[0], alpha0=0.7)
eqx.filter_value_and_grad(checkd)(mo_occs[0])

spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)


(Array(-3.11735723, dtype=float64),
 Array([-0.76418695, -0.23435005, -0.51950138, -0.51950137,  0.        ,
         0.        ,  0.        ,  0.        ], dtype=float64))

In [48]:
# The quantities the loss compares with the reference come from PySCF-based
# calculations, which work in Hartree atomic units. The reference value 1.17 is
# presumably the band gap of silicon in eV, so convert it to Hartree first.
# NOTE(review): confirm the source and unit of the 1.17 reference.
EV_PER_HARTREE = 27.211386245988
si_ref_gap = 1.17 / EV_PER_HARTREE
xct = xce.train.xcTrainer(model=xc, optim=optax.adamw(1e-2), steps=100, loss = Band_gap_janak_loss(), do_jit=True)
newm = xct(1, xct.model, ao_evals, gws, dms, eris, mo_occs, hcs, ss, ogds, [si_ref_gap], mfs)

Epoch 0
Epoch 0 :: Batch 0/1
spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)



KeyboardInterrupt



In [39]:
e1 = nlxc(dms[0], ao_evals[0], gws[0], mfs[0])
e2 = newm(dms[0], ao_evals[0], gws[0], mfs[0])



spin_scaling = True; input descr to exc shape: (50653, 15)
eX.__call__, rho shape: (50653, 15)
spin_scaling = False; input descr to exc shape: (50653, 13)
eC.__call__, rho shape: (50653, 13)
spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)


In [40]:
e1, e2

(Array(-2.57218399, dtype=float64), Array(-2.90765126, dtype=float64))

In [41]:
vgf1 = lambda x: xc(x, ao_evals[0], gws[0], mfs[0])
vgf2 = lambda x: newm(x, ao_evals[0], gws[0], mfs[0])
dm1, moe1, moc1 = xce.utils.get_dm_moe(dms[0], eris[0], vgf1, mo_occs[0], hcs[0], ss[0], ogds[0])
dm2, moe2, moc2 = xce.utils.get_dm_moe(dms[0], eris[0], vgf2, mo_occs[0], hcs[0], ss[0], ogds[0])

spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()
spin_scaling = True; input descr to exc shape: (2, 50653, 3)
eX.__call__, rho shape: (2, 50653, 3)
spin_scaling = False; input descr to exc shape: (50653, 4)
eC.__call__, rho shape: (50653, 4)
[8] (8,)
(8, 8) (8, 8)
Spin unpolarized make_rdm1()


In [42]:
print(moe1 - moe1[mf.mol.nelectron//2-1])
print(moe2 - moe2[mf.mol.nelectron//2-1])

[-6.57726192e-01 -7.60975504e-02 -3.70073737e-07  0.00000000e+00
  1.02667394e-01  1.02667614e-01  3.18068487e-01  5.38626343e-01]
[-6.86420513e-01 -2.79444128e-02 -3.70097756e-07  0.00000000e+00
  1.49307013e-01  1.49307234e-01  3.48533950e-01  5.46699981e-01]


In [32]:
mf.mo_energy

array([-0.24878937,  0.22904513,  0.22904513,  0.22904513,  0.43234039,
        0.43234039,  0.43234039,  0.45784889])

In [33]:
# NOTE(review): `b1` is not defined in any cell shown in this notebook; it must be
# created (presumably per-k-point band energies) before this cell can run.
# Start from -inf so the running maximum is always defined before it is compared.
vbmax = -np.inf
for en in b1[0]:
    # Energy of the highest occupied band at this k-point.
    vb_k = en[cell.nelectron//2-1]
    print('This vb_k', vb_k)
    if vb_k > vbmax:
        vbmax = vb_k
# Band energies measured from the highest valence-band energy over all k-points.
e_kn = [en - vbmax for en in b1[0]]

NameError: name 'b1' is not defined

In [None]:
e_kn

In [None]:
mf2 = scfp.RHF(cell)
e2 = mf2.kernel()

In [None]:
cell.nelectron//2-1

In [None]:
t1 = mf2.mo_energy 
t2 = mf2.mo_energy - mf2.mo_energy[cell.nelectron//2-1]

In [None]:
t2[jnp.where(abs(t2[jnp.where( (t2 < 0) )[0]]) > 1e-4)[0]]

In [None]:
dm2 = mf2.make_rdm1()

In [None]:
dmk = mf.make_rdm1()

In [None]:
mpr = MPRester(api_key = '')
mpid = 'mp-149'
# ret = mpr.get_bandstructure_by_material_id(mpid)
# rets = ret.structure.as_dict()
ret = mpr.get_structure_by_material_id(mpid, conventional_unit_cell=False)
rets = ret.as_dict()
at_coor_xyz = [ (i['species'][0]['element'], [-j for j in i['xyz']]) for i in rets['sites']]
at_coor_abc = [ (i['species'][0]['element'], [rets['lattice']['a']*j for j in i['abc']]) for i in rets['sites']]
cella = -np.asarray(rets['lattice']['matrix'])

In [None]:
at_coor_xyz, cella

In [None]:
cell = gtop.Cell()
cell.atom = at_coor_xyz
# `cella` comes from the structure's lattice matrix, which stores one lattice vector
# per row — the same layout PySCF expects for `cell.a` — so it is passed without a
# transpose. (Transposing is harmless only when the matrix happens to be symmetric,
# as it is for this silicon cell, and gives the wrong lattice otherwise.)
cell.a = cella
cell.basis = 'gth-szv'
cell.pseudo = 'gth-pade'
cell.exp_to_discard = 0.1
cell.build()
kpts = cell.make_kpts([2,2,2])
mf = scfp.KRHF(cell, kpts=kpts)
e = mf.kernel()

In [None]:
cella