In [1]:
import os, sys
# XLA reads these allocator settings when the JAX backend is initialised, which can
# already happen while the JAX-based packages below are being imported. Set them
# before any of those imports so they are guaranteed to take effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

from ase import Atoms
from ase.io import read
import xcquinox as xce
import torch, jax, optax
import numpy as np
import equinox as eqx
import jax.numpy as jnp
import pyscfad as psa
from pyscfad import dft, scf, gto, df
from pyscfad.pbc import scf as scfp
from pyscfad.pbc import gto as gtop
from pyscfad.pbc import dft as dftp


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


The utilities below need torch to load the old (PyTorch-based) models, but torch is not intended to be a prerequisite. Eventually we want a self-contained folder of translated models.

The cell below reproduces the torch module structure needed to load the old models.

In [2]:
# Relevant classes copied from dpyscfl, to check whether loading old models can be self-contained in this notebook.
# xcdiff names the local exchange model XC_L rather than X_L; the name X_L is kept here for consistency.

class LOB(torch.nn.Module):
    """Squash an unbounded input into the open interval (-1, limit - 1).

    Used to enforce non-negativity and a local Lieb-Oxford bound. The sigmoid is
    shifted by ``log(limit - 1)`` so that an input of 0 maps to 0. ``limit`` must
    exceed 1 for the shift to be defined.
    """

    def __init__(self, limit=1.804):
        super().__init__()
        self.sig = torch.nn.Sigmoid()
        self.limit = limit

    def forward(self, x):
        """Return ``limit * sigmoid(x - log(limit - 1)) - 1`` elementwise."""
        offset = np.log(self.limit - 1)
        return self.limit * self.sig(x - offset) - 1


class X_L(torch.nn.Module):
    """Local exchange enhancement model (MLP).

    ``spin_scaling`` is True, so ``XC`` evaluates this model on per-spin descriptors.
    """

    def __init__(self, n_input, n_hidden=16, use=[], device='cpu', ueg_limit=False, lob=1.804, one_e=False):
        """Local exchange model based on an MLP.

        Receives density descriptors in this order: [rho, s, alpha, nl]; the input may
        be truncated depending on the level of approximation.

        Args:
            n_input (int): Input dimensions (LDA: 1, GGA: 2, meta-GGA: 3, ...).
            n_hidden (int, optional): Width of each of the three hidden layers. Defaults to 16.
            use (list of ints, optional): Descriptor indices fed to the network; empty means
                ``range(n_input)``. They also select the UEG-limit terms in ``forward``:
                ``use[0]`` (raw), ``use[1]`` (tanh squared) and ``use[2:]`` (summed).
                Defaults to [].
            device (str, optional): {'cpu','cuda'}. Defaults to 'cpu'.
            ueg_limit (bool, optional): Multiply the network output by the UEG-limit factor.
                Defaults to False.
            lob (float, optional): Limit for the ``LOB`` squashing of the output; a falsy
                value disables it. Defaults to 1.804.
            one_e (bool, optional): Accepted for compatibility; not used by this class.
                Defaults to False.
        """
        super().__init__()
        self.ueg_limit = ueg_limit
        self.spin_scaling = True
        self.lob = lob

        indices = use if use else np.arange(n_input)
        self.use = torch.Tensor(indices).long().to(device)

        # Layer order/indices match the original Sequential (Linear at 0, 2, 4, 6),
        # so saved state-dict keys such as ``net.0.weight`` still load.
        layers = [torch.nn.Linear(n_input, n_hidden), torch.nn.GELU()]
        for _ in range(2):
            layers.extend([torch.nn.Linear(n_hidden, n_hidden), torch.nn.GELU()])
        layers.append(torch.nn.Linear(n_hidden, 1))
        self.net = torch.nn.Sequential(*layers).double().to(device)

        self.tanh = torch.nn.Tanh().to(device)
        self.lobf = LOB(lob).to(device)
        # Kept for parity with xcdiff; neither is used in forward.
        self.sig = torch.nn.Sigmoid()
        self.shift = 1/(1+np.exp(-1e-3))

    def forward(self, rho, **kwargs):
        """Evaluate the exchange enhancement.

        Args:
            rho (torch.Tensor): descriptor tensor whose last axis is indexed by ``self.use``.
            **kwargs: ignored (``XC`` passes ``grid_coords``).

        Returns:
            torch.Tensor: network output with all singleton dims squeezed, scaled by the
            UEG factor when ``ueg_limit`` is set and passed through ``LOB`` when ``lob``
            is truthy.
        """
        output = self.net(rho[..., self.use]).squeeze()

        if not self.ueg_limit:
            factor = 1
        else:
            n_used = len(self.use)
            factor = rho[..., self.use[0]]
            factor = factor + (torch.pow(self.tanh(rho[..., self.use[1]]), 2) if n_used > 1 else 0)
            factor = factor + (torch.sum(rho[..., self.use[2:]], dim=-1) if n_used > 2 else 0)

        scaled = output*factor
        return self.lobf(scaled) if self.lob else scaled

class C_L(torch.nn.Module):
    """Local correlation enhancement model (MLP).

    ``spin_scaling`` is False, so ``XC`` evaluates this model on the spin-summed descriptors.
    """

    def __init__(self, n_input=2,n_hidden=16, device='cpu', ueg_limit=False, lob=2.0, use = []):
        """Local correlation model based on an MLP.

        Receives density descriptors in this order: [rho, spinscale, s, alpha, nl]; the
        input may be truncated depending on the level of approximation.

        Args:
            n_input (int, optional): Input dimensions (LDA: 2, GGA: 3, meta-GGA: 4). Defaults to 2.
            n_hidden (int, optional): Width of each of the three hidden layers. Defaults to 16.
            device (str, optional): {'cpu','cuda'}. Defaults to 'cpu'.
            ueg_limit (bool, optional): Multiply the network output by the UEG-limit factor.
                Defaults to False.
            lob (float, optional): Limit for the ``LOB`` squashing of the output (used here to
                enforce non-negativity). A falsy value is replaced by 1000.0. Defaults to 2.0.
            use (list of ints, optional): Descriptor indices used only for the UEG-limit terms
                (the network itself sees the full input); empty means ``range(n_input)``.
                Defaults to [].
        """
        super().__init__()
        self.spin_scaling = False
        self.ueg_limit = ueg_limit
        self.n_input = n_input

        indices = use if use else np.arange(n_input)
        self.use = torch.Tensor(indices).long().to(device)

        # Layer order/indices match the original Sequential so saved state dicts still load.
        layers = [torch.nn.Linear(n_input, n_hidden), torch.nn.GELU()]
        for _ in range(2):
            layers.extend([torch.nn.Linear(n_hidden, n_hidden), torch.nn.GELU()])
        layers.extend([torch.nn.Linear(n_hidden, 1), torch.nn.Softplus()])
        self.net = torch.nn.Sequential(*layers).double().to(device)
        self.sig = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

        # xcdiff always assumed 2.0; a falsy lob becomes a very loose bound of 1000.0,
        # so ``self.lob`` is always truthy after construction.
        self.lob = lob if lob else 1000.0
        self.lobf = LOB(self.lob)

    def forward(self, rho, **kwargs):
        """Evaluate the correlation enhancement.

        Args:
            rho (torch.Tensor): descriptor tensor; fed to the network unchanged, while
                ``self.use`` picks the columns for the UEG-limit factor.
            **kwargs: ignored (``XC`` passes ``grid_coords``).

        Returns:
            torch.Tensor: negated Softplus output (so <= 0) with singleton dims squeezed,
            scaled by the UEG factor when ``ueg_limit`` is set and passed through ``LOB``
            when ``lob`` is truthy.
        """
        output = -self.net(rho).squeeze()

        if not self.ueg_limit:
            factor = 1
        else:
            # xcdiff form: every UEG term goes through tanh.
            n_used = len(self.use)
            factor = self.tanh(rho[..., self.use[0]])
            factor = factor + (torch.pow(self.tanh(rho[..., self.use[1]]), 2) if n_used > 1 else 0)
            factor = factor + (torch.sum(self.tanh(rho[..., self.use[2:]])**2, dim=-1) if n_used > 2 else 0)

        scaled = output*factor
        return self.lobf(scaled) if self.lob else scaled
class LDA_X(torch.nn.Module):
    """Uniform-electron-gas exchange energy per particle."""

    def __init__(self):
        super().__init__()

    def forward(self, rho, **kwargs):
        """Return ``-(3/4) (3/pi)^(1/3) rho^(1/3)``; extra keyword arguments are ignored."""
        prefactor = -3/4*(3/np.pi)**(1/3)
        return prefactor*rho**(1/3)
# Perdew & Wang fit parameters consumed by PW_C below. Each list holds one entry per
# channel k of PW_C's g(k, rs). With the interpolation in PW_C.f_pw, channel 0 gives the
# zeta = 0 value, channel 1 the zeta = 1 value, and channel 2 enters as the
# spin-stiffness term (divided by params_a_fz20).
params_a_pp     = [1,  1,  1]
params_a_alpha1 = [0.21370,  0.20548,  0.11125]
params_a_a      = [0.031091, 0.015545, 0.016887]
params_a_beta1  = [7.5957, 14.1189, 10.357]
params_a_beta2  = [3.5876, 6.1977, 3.6231]
params_a_beta3  = [1.6382, 3.3662,  0.88026]
params_a_beta4  = [0.49294, 0.62517, 0.49671]
# Numerically equal to f_zeta''(0) = 8 / (9 * (2**(4/3) - 2)) for PW_C's f_zeta.
params_a_fz20   = 1.709921
       
class PW_C(torch.nn.Module):
    """Perdew & Wang correlation energy per particle of the uniform electron gas.

    Uses the module-level ``params_a_*`` tables. Channel 0 of the fit is the zeta = 0
    value, channel 1 the zeta = 1 value, and channel 2 enters the spin interpolation
    divided by ``params_a_fz20``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, rs, zeta):
        """Evaluate the correlation energy per particle.

        Args:
            rs (torch.Tensor): Wigner-Seitz radius (as computed in ``XC.eval_grid_models``).
            zeta (torch.Tensor): spin polarisation (rho_a - rho_b) / (rho_a + rho_b).

        Returns:
            torch.Tensor: e_c(rs, zeta), broadcast over ``rs`` and ``zeta``.
        """
        def channel(k):
            a = params_a_a[k]
            denom = params_a_beta1[k]*torch.sqrt(rs) + params_a_beta2[k]*rs \
                + params_a_beta3[k]*rs**1.5 + params_a_beta4[k]*rs**(params_a_pp[k] + 1)
            return -2*a*(1 + params_a_alpha1[k]*rs) * torch.log(1 + 1/(2*a*denom))

        # Spin interpolation function, 0 at zeta = 0 and 1 at zeta = +-1.
        spin_interp = ((1+zeta)**(4/3) + (1-zeta)**(4/3) - 2)/(2**(4/3)-2)
        g0, g1, g2 = channel(0), channel(1), channel(2)
        return g0 + zeta**4*spin_interp*(g1 - g0 + g2/params_a_fz20) \
            - spin_interp*g2/params_a_fz20

class XC(torch.nn.Module):
    """Exchange-correlation functional evaluated on a real-space integration grid.

    Combines learned local enhancement factors (``grid_models``, e.g. ``X_L`` / ``C_L``)
    with uniform-electron-gas exchange (``LDA_X``) and Perdew-Wang correlation (``PW_C``).

    :meth:`forward` reads attributes that ``__init__`` does not set; the caller must attach
    them first: ``ao_eval`` (assumed: AO values followed by their x/y/z derivatives on the
    grid -- see the ``rho`` contraction in :meth:`forward`) and ``grid_weights``. For
    ``level > 3`` it also needs ``df_3c``, ``df_2c_inv``, ``vh_on_grid`` and ``nl_ueg``.
    """

    def __init__(self, grid_models=None, heg_mult=True, pw_mult=True,
                    level = 1, exx_a=None, epsilon=1e-8):
        """Defines the XC functional on a grid

        Args:
            grid_models (list, optional): list of X_L (local exchange) or C_L (local correlation). Defines the xc-models/enhancement factors. Defaults to None (no learned models).
            heg_mult (bool, optional): Use homogeneous electron gas exchange (multiplicative if grid_models is not empty). Defaults to True.
            pw_mult (bool, optional): Use homogeneous electron gas correlation (Perdew & Wang). Defaults to True.
            level (int, optional): Controls the number of density "descriptors" generated. 1: LDA, 2: GGA, 3:meta-GGA, 4: meta-GGA + electrostatic (nonlocal). Defaults to 1.
            exx_a (float, optional): Exact exchange mixing parameter; stored as a trainable parameter, or 0 when None. Defaults to None.
            epsilon (float, optional): Offset to avoid div/0 in calculations. Defaults to 1e-8.
        """

        super().__init__()
        self.heg_mult = heg_mult
        self.pw_mult = pw_mult
        self.grid_coords = None
        self.training = True
        self.level = level
        self.epsilon = epsilon
        if level > 3:
            print('WARNING: Non-local models highly experimental and likely will not work ')
        self.loge = 1e-5
        self.s_gam = 1

        if heg_mult:
            self.heg_model = LDA_X()
        if pw_mult:
            self.pw_model = PW_C()
        # list(None) raises TypeError, so the documented default means "no learned models".
        self.grid_models = list(grid_models) if grid_models is not None else []
        if self.grid_models:
            self.grid_models = torch.nn.ModuleList(self.grid_models)
        self.model_mult = [1 for _ in self.grid_models]

        if exx_a is not None:
            self.exx_a = torch.nn.Parameter(torch.Tensor([exx_a]))
            self.exx_a.requires_grad = True
        else:
            self.exx_a = 0

    def evaluate(self):
        """Switches self.training flag to False
        """
        self.training=False

    def train(self, mode=True):
        """Set the training flag (True by default), like ``torch.nn.Module.train``.

        The previous zero-argument override broke the inherited ``Module.eval()``, which
        calls ``self.train(False)``. Also propagates the flag to child modules and
        returns ``self``.
        """
        return super().train(mode)

    def add_model_mult(self, model_mult):
        """Replace ``self.model_mult`` with a registered buffer holding ``model_mult``.

        .. todo::
            Unclear what the purpose of this is; ``model_mult`` is not read elsewhere in this class.

        Args:
            model_mult (sequence of float): one multiplier per grid model (presumably).
        """
        del(self.model_mult)
        self.register_buffer('model_mult',torch.Tensor(model_mult))

    def add_exx_a(self, exx_a):
        """Adds exact-exchange mixing parameter after initialization

        Args:
            exx_a (float): Exchange mixing parameter
        """
        self.exx_a = torch.nn.Parameter(torch.Tensor([exx_a]))
        self.exx_a.requires_grad = True

    # Density (rho)
    def l_1(self, rho):
        """Level 1 Descriptor -- Creates dimensionless quantity from rho.
        Eq. 3 in `base paper <https://link.aps.org/doi/10.1103/PhysRevB.104.L161109>`_

        .. math:: x_0 = \\rho^{1/3}

        Args:
            rho (torch.Tensor): density

        Returns:
            torch.Tensor: dimensionless density
        """
        return rho**(1/3)

    # Reduced density gradient s
    def l_2(self, rho, gamma):
        """Level 2 Descriptor -- Reduced gradient density
        Eq. 5 in `base paper <https://link.aps.org/doi/10.1103/PhysRevB.104.L161109>`_

        .. math:: x_2=s=\\frac{1}{2(3\\pi^2)^{1/3}} \\frac{|\\nabla \\rho|}{\\rho^{4/3}}

        Args:
            rho (torch.Tensor): density
            gamma (torch.Tensor): squared density gradient

        Returns:
            torch.Tensor: reduced density gradient s
        """
        return torch.sqrt(gamma)/(2*(3*np.pi**2)**(1/3)*rho**(4/3)+self.epsilon)

    # Reduced kinetic energy density alpha
    def l_3(self, rho, gamma, tau):
        """Level 3 Descriptor -- Reduced kinetic energy density
        Eq. 6 in `base paper <https://link.aps.org/doi/10.1103/PhysRevB.104.L161109>`_

        .. math:: x_3 = \\alpha = \\frac{\\tau-\\tau^W}{\\tau^{unif}},

        where

        .. math:: \\tau^W = \\frac{|\\nabla \\rho|^2}{8\\rho}, \\tau^{unif} = \\frac{3}{10} (3\\pi^2)^{2/3}\\rho^{5/3}.

        Args:
            rho (torch.Tensor): density
            gamma (torch.Tensor): squared density gradient
            tau (torch.Tensor): kinetic energy density

        Returns:
            torch.Tensor: reduced kinetic energy density
        """
        uniform_factor = (3/10)*(3*np.pi**2)**(2/3)
        tw = gamma/(8*(rho+self.epsilon))
        # xcdiff version; the dpyscflite variant was
        # relu((tau - tw)/(uniform_factor*rho**(5/3)+tw*1e-3 + 1e-12))
        return (tau - tw)/(uniform_factor*rho**(5/3)+self.epsilon)

    # Unit-less electrostatic potential
    def l_4(self, rho, nl):
        """Level 4 Descriptor -- Unitless electrostatic potential

        .. todo:: Figure out what exactly this part is. Requires ``self.nl_ueg``, which is
            not set in this class.

        Args:
            rho (torch.Tensor): density
            nl (torch.Tensor): some non-local descriptor

        Returns:
            torch.Tensor: ReLU of the concatenated scaled non-local features
        """
        u = nl[:,:1]/((rho.unsqueeze(-1)**(1/3))*self.nl_ueg[:,:1] + self.epsilon)
        wu = nl[:,1:]/((rho.unsqueeze(-1))*self.nl_ueg[:,1:] + self.epsilon)
        return torch.nn.functional.relu(torch.cat([u,wu],dim=-1))

    def get_descriptors(self, rho0_a, rho0_b, gamma_a, gamma_b, gamma_ab,nl_a,nl_b, tau_a, tau_b, spin_scaling = False):
        """Creates 'ML-compatible' descriptors from the electron density and its gradients, a & b correspond to spin channels

        Args:
            rho0_a (torch.Tensor): :math:`\\rho` in spin-channel a
            rho0_b (torch.Tensor): :math:`\\rho` in spin-channel b
            gamma_a (torch.Tensor): :math:`|\\nabla \\rho|^2` in spin-channel a
            gamma_b (torch.Tensor): :math:`|\\nabla \\rho|^2` in spin-channel b
            gamma_ab (torch.Tensor): cross term :math:`\\nabla \\rho_a \\cdot \\nabla \\rho_b`
            nl_a (torch.Tensor): non-local features, spin a (only used for level > 3)
            nl_b (torch.Tensor): non-local features, spin b (only used for level > 3)
            tau_a (torch.Tensor): KE density in spin-channel a
            tau_b (torch.Tensor): KE density in spin-channel b
            spin_scaling (bool, optional): Flag for spin-scaling. Defaults to False.

        Returns:
            torch.Tensor: ``(ngrid, n_descr)`` without spin scaling; with spin scaling the
            per-spin columns are split into ``(2, ngrid, n_descr // 2)``.
        """

        if not spin_scaling:
            #If no spin-scaling, calculate polarization and use for X1
            zeta = (rho0_a - rho0_b)/(rho0_a + rho0_b + self.epsilon)
            spinscale = 0.5*((1+zeta)**(4/3) + (1-zeta)**(4/3)) # zeta

        if self.level > 0:  #  LDA
            if spin_scaling:
                descr1 = torch.log(self.l_1(2*rho0_a) + self.loge)
                descr2 = torch.log(self.l_1(2*rho0_b) + self.loge)
            else:
                descr1 = torch.log(self.l_1(rho0_a + rho0_b) + self.loge)# rho
                descr2 = torch.log(spinscale) # zeta
            descr = torch.cat([descr1.unsqueeze(-1), descr2.unsqueeze(-1)],dim=-1)
        if self.level > 1: # GGA
            if spin_scaling:
                descr3a = self.l_2(2*rho0_a, 4*gamma_a) # s
                descr3b = self.l_2(2*rho0_b, 4*gamma_b) # s
                descr3 = torch.cat([descr3a.unsqueeze(-1), descr3b.unsqueeze(-1)],dim=-1)
                descr3 = (1-torch.exp(-descr3**2/self.s_gam))*torch.log(descr3 + 1)
            else:
                descr3 = self.l_2(rho0_a + rho0_b, gamma_a + gamma_b + 2*gamma_ab) # s
                # From xcdiff, not dpyscfl.
                # NOTE(review): `(1-zeta)**2/3` parses as ((1-zeta)**2)/3, not (1-zeta)**(2/3).
                # Kept as-is because the trained weights were fitted with this descriptor.
                descr3 = descr3/((1+zeta)**(2/3) + (1-zeta)**2/3)
                descr3 = descr3.unsqueeze(-1)
                descr3 = (1-torch.exp(-descr3**2/self.s_gam))*torch.log(descr3 + 1)
            descr = torch.cat([descr, descr3],dim=-1)
        if self.level > 2: # meta-GGA
            if spin_scaling:
                descr4a = self.l_3(2*rho0_a, 4*gamma_a, 2*tau_a)
                descr4b = self.l_3(2*rho0_b, 4*gamma_b, 2*tau_b)
                descr4 = torch.cat([descr4a.unsqueeze(-1), descr4b.unsqueeze(-1)],dim=-1)
                # From xcdiff, not dpyscfl.
                descr4 = descr4**3/(descr4**2+self.epsilon)
            else:
                descr4 = self.l_3(rho0_a + rho0_b, gamma_a + gamma_b + 2*gamma_ab, tau_a + tau_b)
                # Next two lines from xcdiff, not dpyscfl.
                descr4 = 2*descr4/((1+zeta)**(5/3) + (1-zeta)**(5/3))
                descr4 = descr4**3/(descr4**2+self.epsilon)

                descr4 = descr4.unsqueeze(-1)
            descr4 = torch.log((descr4 + 1)/2)
            descr = torch.cat([descr, descr4],dim=-1)
        if self.level > 3: # meta-GGA + V_estat
            if spin_scaling:
                descr5a = self.l_4(2*rho0_a, 2*nl_a)
                descr5b = self.l_4(2*rho0_b, 2*nl_b)
                descr5 = torch.log(torch.stack([descr5a, descr5b],dim=-1) + self.loge)
                descr5 = descr5.view(descr5.size()[0],-1)
            else:
                descr5= torch.log(self.l_4(rho0_a + rho0_b, nl_a + nl_b) + self.loge)

            descr = torch.cat([descr, descr5],dim=-1)
        if spin_scaling:
            # Columns alternate (spin a, spin b); split them into a leading spin axis.
            descr = descr.view(descr.size()[0],-1,2).permute(2,0,1)
        return descr


    def forward(self, dm):
        """Exchange-correlation energy for a density matrix.

        Requires ``self.ao_eval`` and ``self.grid_weights`` to have been attached by the caller.

        Args:
            dm (torch.Tensor): density matrix; 2D for restricted, 3D (spin-first) for unrestricted

        Returns:
            torch.Tensor or int: the XC energy (a scalar tensor), or ``0`` when neither
            grid models nor HEG exchange are enabled.
        """
        Exc = 0
        if self.grid_models or self.heg_mult:
            if self.ao_eval.dim()==2:
                ao_eval = self.ao_eval.unsqueeze(0)
            else:
                ao_eval = self.ao_eval

            # Create density (and gradients) from atomic orbitals evaluated on grid
            # and density matrix
            # rho[ijsp]: del_i phi del_j phi dm (s: spin, p: grid point index)
            rho = torch.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)+1e-10
            rho0 = rho[0,0]
            drho = rho[0,1:4] + rho[1:4,0]
            tau = 0.5*(rho[1,1] + rho[2,2] + rho[3,3])

            # Non-local electrostatic potential
            if self.level > 3:
                non_loc = torch.einsum('mnQ,QP,Pki,...mn->...ki', self.df_3c, self.df_2c_inv, self.vh_on_grid, dm)
            else:
                non_loc = torch.zeros_like(tau).unsqueeze(-1)

            if dm.dim() == 3: # If unrestricted (open-shell) calculation

                # Density
                rho0_a = rho0[0]
                rho0_b = rho0[1]

                # Contracted density gradient
                gamma_a = torch.einsum('ij,ij->j', drho[:,0], drho[:,0])
                gamma_b = torch.einsum('ij,ij->j', drho[:,1], drho[:,1])
                gamma_ab = torch.einsum('ij,ij->j', drho[:,0], drho[:,1])

                # Kinetic energy density
                tau_a, tau_b = tau

                # E.-static
                non_loc_a, non_loc_b = non_loc
            else:
                rho0_a = rho0_b = rho0*0.5
                gamma_a=gamma_b=gamma_ab= torch.einsum('ij,ij->j', drho[:], drho[:])*0.25
                tau_a = tau_b = tau*0.5
                non_loc_a=non_loc_b = non_loc*0.5

            # xc-energy per unit particle
            exc = self.eval_grid_models(torch.cat([rho0_a.unsqueeze(-1),
                                                    rho0_b.unsqueeze(-1),
                                                    gamma_a.unsqueeze(-1),
                                                    gamma_ab.unsqueeze(-1),
                                                    gamma_b.unsqueeze(-1),
                                                    torch.zeros_like(rho0_a).unsqueeze(-1), #Dummy for laplacian
                                                    torch.zeros_like(rho0_a).unsqueeze(-1), #Dummy for laplacian
                                                    tau_a.unsqueeze(-1),
                                                    tau_b.unsqueeze(-1),
                                                    non_loc_a,
                                                    non_loc_b],dim=-1))
            # .clone() kept: in-place use of exc reportedly raised MulBackward0 errors.
            Exc += torch.sum(((rho0_a + rho0_b)*exc.clone()[:,0])*self.grid_weights)

        # xcdiff additionally adds non-local xc models here (self.nxc_models); not implemented.
        return Exc

    def eval_grid_models(self, rho, debug=False):
        """Evaluates all models stored in self.grid_models along with HEG exchange and correlation

        Args:
            rho (torch.Tensor): ``(ngrid, 9 + n_nl)`` columns [rho0_a, rho0_b, gamma_a, gamma_ab, gamma_b, dummy for laplacian, dummy for laplacian, tau_a, tau_b, non_loc_a, non_loc_b]
            debug (bool, optional): print NaN summaries before and after evaluation. Defaults to False.

        Returns:
            torch.Tensor: ``(ngrid, 1)`` xc energy per particle
        """
        rho0_a = rho[:, 0]
        rho0_b = rho[:, 1]
        gamma_a = rho[:, 2]
        gamma_ab = rho[:, 3]
        gamma_b = rho[:, 4]
        tau_a = rho[:, 7]
        tau_b = rho[:, 8]
        nl = rho[:,9:]
        nl_size = nl.size()[-1]//2
        nl_a = nl[:,:nl_size]
        nl_b = nl[:,nl_size:]

        # In xcdiff, self.meta_local would change the two assignments below; not used here.
        rho0_a_ueg = rho0_a
        rho0_b_ueg = rho0_b

        zeta = (rho0_a_ueg - rho0_b_ueg)/(rho0_a_ueg + rho0_b_ueg + 1e-8)
        rs = (4*np.pi/3*(rho0_a_ueg+rho0_b_ueg + 1e-8))**(-1/3)
        rs_a = (4*np.pi/3*(rho0_a_ueg + 1e-8))**(-1/3)
        rs_b = (4*np.pi/3*(rho0_b_ueg + 1e-8))**(-1/3)

        exc_a = torch.zeros_like(rho0_a)
        exc_b = torch.zeros_like(rho0_a)
        exc_ab = torch.zeros_like(rho0_a)

        if debug:
            print('eval_grid_models nan summary:')
            print('zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab')
            print('{}, {}, {}, {}, {}, {}, {}'.format(
                torch.isnan(zeta).any().sum(),
                torch.isnan(rs).any().sum(),
                torch.isnan(rs_a).any().sum(),
                torch.isnan(rs_b).any().sum(),
                torch.isnan(exc_a).any().sum(),
                torch.isnan(exc_b).any().sum(),
                torch.isnan(exc_ab).any().sum(),
            ))

        # Descriptors are cached per kind: 'c' (spin-summed) and 'x' (spin-scaled).
        descr_dict = {}
        rho_tot = rho0_a + rho0_b
        if self.grid_models:

            for grid_model in self.grid_models:
                if not grid_model.spin_scaling:
                    if 'c' not in descr_dict:
                        descr_dict['c'] = self.get_descriptors(rho0_a, rho0_b, gamma_a, gamma_b,
                                                               gamma_ab, nl_a, nl_b, tau_a, tau_b, spin_scaling = False)
                    descr = descr_dict['c']
                    for name, param in grid_model.named_parameters():
                        if torch.isnan(param).any():
                            print("NANS IN NETWORK WEIGHT -- {}".format(name))
                            raise ValueError("NaNs in Network Weights.")

                    # Evaluate network with descriptors on grid (xcdiff also passes edge_index).
                    exc = grid_model(descr,
                                      grid_coords = self.grid_coords)

                    # From xcdiff: a 2D output means spin-decomposed correlation.
                    if exc.dim() == 2:
                        pw_alpha = self.pw_model(rs_a, torch.ones_like(rs_a))
                        pw_beta = self.pw_model(rs_b, torch.ones_like(rs_b))
                        pw = self.pw_model(rs, zeta)
                        ec_alpha = (1 + exc[:,0])*pw_alpha*rho0_a/(rho_tot+1e-8)
                        ec_beta =  (1 + exc[:,1])*pw_beta*rho0_b/(rho_tot+1e-8)
                        ec_mixed = (1 + exc[:,2])*(pw*rho_tot - pw_alpha*rho0_a - pw_beta*rho0_b)/(rho_tot+1e-8)
                        exc_ab = ec_alpha + ec_beta + ec_mixed
                    else:
                        if self.pw_mult:
                            exc_ab += (1 + exc)*self.pw_model(rs, zeta)
                        else:
                            exc_ab += exc
                else:
                    if 'x' not in descr_dict:
                        descr_dict['x'] = self.get_descriptors(rho0_a, rho0_b, gamma_a, gamma_b,
                                                               gamma_ab, nl_a, nl_b, tau_a, tau_b, spin_scaling = True)
                    descr = descr_dict['x']

                    # xcdiff also passes edge_index here.
                    exc = grid_model(descr,
                                  grid_coords = self.grid_coords)

                    if self.heg_mult:
                        exc_a += (1 + exc[0])*self.heg_model(2*rho0_a_ueg)*(1-self.exx_a)
                    else:
                        exc_a += exc[0]*(1-self.exx_a)

                    if torch.all(rho0_b == torch.zeros_like(rho0_b)): #Otherwise produces NaN's
                        exc_b += exc[0]*0
                    else:
                        if self.heg_mult:
                            exc_b += (1 + exc[1])*self.heg_model(2*rho0_b_ueg)*(1-self.exx_a)
                        else:
                            exc_b += exc[1]*(1-self.exx_a)

        else:
            if self.heg_mult:
                exc_a = self.heg_model(2*rho0_a_ueg)
                exc_b = self.heg_model(2*rho0_b_ueg)
            if self.pw_mult:
                exc_ab = self.pw_model(rs, zeta)

        # Spin-weighted exchange plus (already per-particle) correlation.
        exc = exc_a * (rho0_a_ueg/ (rho_tot + self.epsilon)) + exc_b*(rho0_b_ueg / (rho_tot + self.epsilon)) + exc_ab
        if debug:
            print('eval_grid_models nan summary:')
            print('zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab')
            print('{}, {}, {}, {}, {}, {}, {}'.format(
                torch.isnan(zeta).any().sum(),
                torch.isnan(rs).any().sum(),
                torch.isnan(rs_a).any().sum(),
                torch.isnan(rs_b).any().sum(),
                torch.isnan(exc_a).any().sum(),
                torch.isnan(exc_b).any().sum(),
                torch.isnan(exc_ab).any().sum(),
            ))

        return exc.unsqueeze(-1)
class make_rdm1(torch.nn.Module):
    """One-particle reduced density matrix from MO coefficients and occupations."""

    def __init__(self):
        """ Generate one-particle reduced density matrix"""
        super().__init__()

    def forward(self, mo_coeff, mo_occ):
        """Build the one-particle reduced density matrix ``C_occ diag(n_occ) C_occ^T``.

        Args:
            mo_coeff (torch.Tensor): ``(nao, nmo)`` (restricted) or ``(2, nao, nmo)``
                (unrestricted) molecular orbital coefficients.
            mo_occ (torch.Tensor): matching occupations, ``(nmo,)`` or ``(2, nmo)``.

        Returns:
            torch.Tensor: ``(nao, nao)``, or ``(2, nao, nao)`` for unrestricted input. When no
            beta orbital is occupied, the beta block is all zeros.
        """
        if mo_coeff.ndim == 3:
            occ_a = mo_occ[0] > 0
            mocc_a = mo_coeff[0, :, occ_a]
            dm_a = (mocc_a*mo_occ[0, occ_a]) @ mocc_a.T
            if torch.sum(mo_occ[1]) > 0:
                occ_b = mo_occ[1] > 0
                mocc_b = mo_coeff[1, :, occ_b]
                dm_b = (mocc_b*mo_occ[1, occ_b]) @ mocc_b.T
            else:
                # Must be (nao, nao) like dm_a; zeros_like(mo_coeff)[0] is (nao, nmo).
                dm_b = torch.zeros_like(dm_a)
            return torch.stack([dm_a, dm_b], dim=0)

        occ = mo_occ > 0
        mocc = mo_coeff[:, occ]
        return (mocc*mo_occ[occ]) @ mocc.T

class get_rho(torch.nn.Module):
    """Electron density on the grid from a density matrix and AO values."""

    def __init__(self):
        super().__init__()

    def forward(self, dm, results):
        """Compute ``rho_i = sum_jk phi_ij phi_ik dm_jk``.

        Args:
            dm (torch.Tensor): ``(nao, nao)`` density matrix, or with a leading spin axis.
            results (dict): must contain ``'ao_eval'``; its first entry is used as the
                ``(ngrid, nao)`` AO values (assumed: entry 0 holds the values, not derivatives).

        Returns:
            torch.Tensor: ``(ngrid,)``, or ``(nspin, ngrid)`` for a 3D ``dm``.
        """
        ao_eval = results['ao_eval'][0]
        # The former debug prints also read results['n_elec'], making that key mandatory.
        if dm.ndim == 2:
            return torch.einsum('ij,ik,jk->i', ao_eval, ao_eval, dm)
        return torch.einsum('ij,ik,xjk->xi', ao_eval, ao_eval, dm)

class energy_tot(torch.nn.Module):
    """Electronic energy from a density matrix, core Hamiltonian and effective potential."""

    def __init__(self):
        """
        Total energy (electron-electron + electron-ion; ion-ion not included)
        """
        super().__init__()

    def forward(self, dm, hcore, veff):
        """Compute ``sum(dm * hcore) + 0.5 * sum(dm * veff)``.

        Uses ``torch.einsum`` instead of ``contract``, which is never imported in this notebook.

        Args:
            dm (torch.Tensor): Density matrix, ``(nao, nao)`` or with a leading spin axis
            hcore (torch.Tensor): Core Hamiltonian, ``(nao, nao)``
            veff (torch.Tensor): Effective potential, broadcastable against ``dm``

        Returns:
            torch.Tensor: shape ``(1,)``, the electronic energy summed over any spin axis
        """
        e_core = torch.einsum('...ij,ij->...', dm, hcore)
        e_two = torch.einsum('...ij,...ij->...', dm, veff)
        return torch.sum(e_core + .5*e_two).unsqueeze(0)

class get_veff(torch.nn.Module):
    def __init__(self, exx=False, model=None, req_grad=False):
        """Builds the one-electron effective potential (not including local xc-potential)

        Args:
            exx (bool, optional): Exact exchange flag. Defaults to False.
            model (xc-model): Only used for the exact exchange mixing parameter
                (``model.exx_a``). Defaults to None.
            req_grad (bool, optional): ``requires_grad`` flag of the tensor returned by
                ``forward2``. Defaults to False.
        """
        super().__init__()
        self.exx = exx
        self.model = model
        self.req_grad = req_grad

    def forward(self, dm, eri):
        """Coulomb (minus scaled exact exchange) potential from full 4-index integrals.

        Args:
            dm (torch.Tensor): Density matrix, (nao, nao) or spin-resolved (2, nao, nao)
            eri (torch.Tensor): Electron repulsion integrals, (nao, nao, nao, nao)

        Returns:
            torch.Tensor: J - 0.5*K for a 2D dm; for a 3D dm, the total Coulomb J[0]+J[1]
            minus the per-spin K, which broadcasts to shape (2, nao, nao).
        """
        J = contract('...ij,ijkl->...kl', dm, eri)
        if self.exx:
            K = self.model.exx_a * contract('...ij,ikjl->...kl', dm, eri)
        else:
            K = torch.zeros_like(J)

        if J.ndim == 3:
            return J[0] + J[1] - K
        else:
            return J - 0.5*K

    def forward2(self, dm, eri):
        """Reimplementation of pyscf's ``hf.dot_eri_dm``; returns a detached tensor.

        Uses a dense contraction when ``eri`` holds all nao**4 elements, otherwise pyscf's
        incore J/K builder on numpy copies (NOTE(review): relies on ``scf._vhf`` being
        available from the notebook's ``scf`` import -- confirm).

        Args:
            dm (torch.Tensor): Density matrix
            eri (torch.Tensor): ERIs, full (nao**4 elements) or symmetry-packed

        Returns:
            torch.Tensor: veff as a new leaf tensor with ``requires_grad=self.req_grad``.
        """
        nao = dm.shape[-1]
        if eri.nelement() == nao**4:
            vj = contract('...ij,ijkl->...kl', dm, eri)
            if self.exx:
                vk = self.model.exx_a * contract('...ij,ikjl->...kl', dm, eri)
            else:
                vk = torch.zeros_like(vj)
        else:
            vj, vk = scf._vhf.incore(eri.detach().numpy(), dm.detach().numpy(), 0,
                                     with_j=True, with_k=self.exx)
            if not self.exx:
                # incore returns no K when with_k is False; numpy zeros only on this numpy path.
                vk = np.zeros_like(vj)
            vj = torch.as_tensor(vj)
            vk = torch.as_tensor(vk)

        if vj.ndim == 3:
            veff = vj[0] + vj[1] - vk
        else:
            veff = vj - 0.5*vk

        # Equivalent to torch.tensor(veff, requires_grad=...) but without the copy-construct warning.
        return veff.detach().clone().requires_grad_(self.req_grad)
        
        

def get_veff_np(dm, eri):
    """Coulomb-only effective potential as a plain function (exact exchange is always zero).

    Args:
        dm (torch.Tensor): Density matrix, (nao, nao) or spin-resolved (2, nao, nao)
        eri (torch.Tensor): Electron repulsion integrals, (nao, nao, nao, nao)

    Returns:
        torch.Tensor: J for a 2D dm; for a 3D dm, J[0] + J[1] broadcast against a zero
        exchange term of shape (2, nao, nao).
    """
    coulomb = contract('...ij,ijkl->...kl', dm, eri)
    exchange = torch.zeros_like(coulomb)
    if coulomb.ndim != 3:
        return coulomb - 0.5*exchange
    return coulomb[0] + coulomb[1] - exchange
def energy_tot_np(dm, hcore, veff):
    """Plain-function version of the electronic energy (ion-ion term excluded).

    Args:
        dm (torch.Tensor): Density matrix
        hcore (torch.Tensor): Core Hamiltonian
        veff (torch.Tensor): Effective potential

    Returns:
        torch.Tensor: shape (1,), tr(dm hcore) + 0.5 tr(dm veff) summed over leading axes
    """
    one_electron = contract('...ij,ij', dm, hcore)
    two_electron = contract('...ij,...ij', dm, veff)
    return torch.sum(one_electron + .5*two_electron).unsqueeze(0)
def make_rdm1_np(mo_coeff, mo_occ):
    """One-particle reduced density matrix from MO coefficients and occupations.

    Args:
        mo_coeff (torch.Tensor): (nao, nmo) restricted, or (2, nao, nmo) spin-resolved
        mo_occ (torch.Tensor): (nmo,) or (2, nmo) occupation numbers

    Returns:
        torch.Tensor: (nao, nao) restricted RDM1, or stacked (alpha, beta) RDM1s; the beta
        block is a zero slice when there are no beta electrons.
    """
    if mo_coeff.ndim != 3:
        occupied = mo_occ > 0
        c_occ = mo_coeff[:, occupied]
        return contract('ij,jk->ik', c_occ*mo_occ[occupied], c_occ.T)

    occ_a = mo_occ[0] > 0
    occ_b = mo_occ[1] > 0
    c_a = mo_coeff[0, :, occ_a]
    c_b = mo_coeff[1, :, occ_b]
    dm_a = contract('ij,jk->ik', c_a*mo_occ[0, occ_a], c_a.T)
    if torch.sum(mo_occ[1]) > 0:
        dm_b = contract('ij,jk->ik', c_b*mo_occ[1, occ_b], c_b.T)
    else:
        dm_b = torch.zeros_like(mo_coeff)[0]
    return torch.stack([dm_a, dm_b], dim=0)



def get_fock(hc, veff):
    """Assemble the Fock matrix as core Hamiltonian plus effective potential.

    Args:
        hc (torch.Tensor): Core Hamiltonian
        veff (torch.Tensor): Effective potential

    Returns:
        torch.Tensor: ``hc + veff``
    """
    fock = hc + veff
    return fock
def get_hcore(v, t):
    """Core Hamiltonian: electron-ion attraction plus kinetic energy.

    .. math:: H_{core} = T + V_{nuc-elec}

    Args:
        v (torch.Tensor, np.array): Electron-ion interaction matrix
        t (torch.Tensor, np.array): Kinetic-energy matrix

    Returns:
        ``v + t`` (same container type as the inputs)
    """
    h_core = v + t
    return h_core


class eig(torch.nn.Module):

    def __init__(self):
        """Generalized eigenvalue solver H C = S C e via an (inverse) Cholesky factor of S."""
        super().__init__()

    def forward(self, h, s_chol):
        """Orthogonalize h with s_chol, diagonalize with torch.linalg.eigh, back-transform.

        Args:
            h (torch.Tensor): Hamiltonian / Fock matrix, optionally with leading batch axes
            s_chol (torch.Tensor): Inverse Cholesky factor of the overlap matrix S,
                s_chol = np.linalg.inv(np.linalg.cholesky(S))

        Returns:
            (torch.Tensor, torch.Tensor): eigenvalues (MO energies), eigenvectors (MO coeffs)
        """
        h_orth = contract('ij,...jk,kl->...il', s_chol, h, s_chol.T)
        # Only the lower triangle of h_orth is read by eigh.
        energies, c_orth = torch.linalg.eigh(h_orth, UPLO="L")
        coeffs = contract('ij,...jk ->...ik', s_chol.T, c_orth.clone())
        return energies, coeffs
# Debug-only switch on a private PyTorch API (torch._C). Judging by its name, it makes vmap
# warn whenever an operator falls back to a slow non-batched path.
# NOTE(review): confirm; this is a process-wide side effect, and private APIs may disappear
# in newer torch releases.
torch._C._debug_only_display_vmap_fallback_warnings(True)
class SCF(torch.nn.Module):

    def __init__(self, alpha=0.8, nsteps=10, xc=None, device='cpu', exx=False):
        """This class implements the self-consistent field (SCF) equations

        Args:
            alpha (float, optional): Linear mixing parameter. Defaults to 0.8.
            nsteps (int, optional): Number of scf steps (must be >= 1). Defaults to 10.
            xc (dpyscfl.net.XC, optional): Class containing the exchange-correlation models. Defaults to None.
            device (str, optional): {'cpu','cuda'}, which device to use. Defaults to 'cpu'.
            exx (bool, optional): Use exact exchange flag. Defaults to False.
        """
        super().__init__()
        self.nsteps = nsteps
        self.alpha = alpha
        # Include Fock (exact) exchange? REQ_GRAD is a notebook-level global defined elsewhere.
        self.get_veff = get_veff(exx, xc, req_grad=REQ_GRAD).to(device)

        self.eig = eig().to(device)
        self.energy_tot = energy_tot().to(device)
        self.make_rdm1 = make_rdm1().to(device)
        self.xc = xc
        #ncore parameter used in xcdiff, not here

    def forward(self, dm, matrices, sc=True, **kwargs):
        """Forward pass SCF cycle

        Args:
            dm (torch.Tensor): Initial density matrix; entry [0] along the first axis is used.
            matrices (dict of torch.Tensors): Contains all other matrices that are considered fixed during
                SCF calculations (e-integrals etc.); each value is indexed with [0].
            sc (bool, optional): If True does self-consistent calculations, else single-pass. Defaults to True.
            **kwargs: ``verbose`` (bool) progress prints; ``debug`` (bool) intermediate dumps plus a
                jacobian-vs-autograd.grad cross-check of vxc; ``erisym_veff`` (bool) use get_veff.forward2.

        Returns:
            dict of torch.Tensors: results: E (energy of every step, concatenated), dm, and mo_energy

        Raises:
            ValueError: if ``self.nsteps < 1`` (no step would run, so nothing could be returned).
            NotImplementedError: if density-fitting integrals are given but get_veff has no ``forward_df``.
        """
        if self.nsteps < 1:
            raise ValueError('SCF requires nsteps >= 1, got {}'.format(self.nsteps))
        dm = dm[0]

        # Required matrices
        # ===================
        # v: Electron-ion pot.
        # t: Kinetic
        # mo_occ: MO occupations
        # e_nuc: Ion-Ion energy contribution
        # s: overlap matrix
        # s_chol: inverse Cholesky decomposition of overlap matrix
        v, t, mo_occ, e_nuc, s, s_chol = [matrices[key][0] for key in \
                                             ['v','t','mo_occ',
                                             'e_nuc','s','s_chol']]
        hc = get_hcore(v,t)

        # Optional matrices
        # ====================

        # Electron repulsion integrals
        eri = matrices.get('eri',[None])[0]

        grid_weights = matrices.get('grid_weights',[None])[0]
        grid_coords = matrices.get('grid_coords',[None])[0]
        #edge index called for here in xcdiff, not here

        # Atomic orbitals evaluated on grid
        ao_eval = matrices.get('ao_eval',[None])[0]

        # Used to restore correct potential after symmetrization:
        L = matrices.get('L', [torch.eye(dm.size()[-1])])[0]
        scaling = matrices.get('scaling',[torch.ones([dm.size()[-1]]*2)])[0]

        # Density fitting integrals
        df_2c_inv = matrices.get('df_2c_inv',[None])[0]
        df_3c = matrices.get('df_3c',[None])[0]

        # Electrostatic potential on grid
        vh_on_grid = matrices.get('vh_on_grid',[None])[0]

        dm_old = dm

        E = []
        nsteps = self.nsteps

        vvv = kwargs.get('verbose', False)
        debug = kwargs.get('debug', False)
        if vvv:
            print('SCF Loop Beginning: {} Steps'.format(nsteps))

        # SCF iteration loop
        for step in range(nsteps):
            #some diis happens here in xcdiff, not implemented here
            if vvv:
                print('Step {}'.format(step))
            alpha = (self.alpha)**(step)+0.3
            beta = (1-alpha)
            dm = alpha * dm + beta * dm_old

            dm_old = dm
            if vvv:
                print("Density Matrix stats: ")
                print("Mean: ", torch.mean(dm))
                print("Min/Max: ", torch.min(dm), torch.max(dm))
                print("Select Indices: dm.flatten()[[0, 5, 10, 100]]", dm.flatten()[[0,5,10,100]])

            if df_3c is not None:
                forward_df = getattr(self.get_veff, 'forward_df', None)
                if forward_df is None:
                    raise NotImplementedError('Density fitting (df_3c) requested, but get_veff has no forward_df.')
                veff = forward_df(dm, df_3c, df_2c_inv, eri)
            elif kwargs.get('erisym_veff', False):
                veff = self.get_veff.forward2(dm, eri)
            else:
                veff = self.get_veff(dm, eri)

            if debug:
                print('STEP-{}/VEFF: '.format(step), veff)

            if self.xc: #If using xc-functional (not Hartree-Fock)
                self.xc.ao_eval = ao_eval
                self.xc.grid_weights = grid_weights
                self.xc.grid_coords = grid_coords
                #edge index, ml_ovlp called for here in xcdiff
                if vh_on_grid is not None:
                    self.xc.vh_on_grid = vh_on_grid
                    self.xc.df_2c_inv = df_2c_inv
                    self.xc.df_3c = df_3c

                if torch.sum(mo_occ) == 1:   # Otherwise H produces NaNs
                    dm[1] = dm.clone()[0]*1e-12
                    dm_old[1] = dm.clone()[0]*1e-12

                exc = self.xc(dm)

                if debug:
                    print('STEP-{}/exc: '.format(step), exc)

                vxc = torch.autograd.functional.jacobian(self.xc, dm, create_graph=False,
                                                         vectorize=False)
                if debug:
                    # Cross-check against reverse-mode autograd. retain_graph keeps exc's graph
                    # alive, because exc is added into e_tot below and may be back-propagated.
                    vxc1 = torch.autograd.grad(exc, dm, retain_graph=True)[0]
                    print('vxc/vxc1 shapes,', vxc.shape, vxc1.shape)
                    msize = vxc.element_size() * vxc.nelement()
                    msize1 = vxc1.element_size() * vxc1.nelement()
                    print('vxc: SHAPE = {}. SIZE = {} KB / {} MB / {} GB'.format(vxc.shape, msize/(1000), msize/(1000**2), msize/(1000**3)))
                    print('vxc1: SHAPE = {}. SIZE = {} KB / {} MB / {} GB'.format(vxc1.shape, msize1/(1000), msize1/(1000**2), msize1/(1000**3)))
                    print('|vxc - vxc1|.max(): ', abs(vxc-vxc1).max())
                # Restore correct symmetry for vxc
                if vxc.dim() > 2:
                    vxc = contract('ij,xjk,kl->xil',L,vxc.clone(),L.T)
                    vxc = torch.where(scaling.unsqueeze(0) > 0 , vxc.clone(), scaling.unsqueeze(0))
                else:
                    vxc = torch.mm(L,torch.mm(vxc.clone(),L.T))
                    vxc = torch.where(scaling > 0 , vxc.clone(), scaling)

                if torch.sum(mo_occ) == 1:   # Otherwise H produces NaNs
                    vxc[1] = torch.zeros_like(vxc.clone()[1])

                veff += vxc

                if debug:
                    print('STEP-{}/VEFF+VXC: '.format(step), veff)

                #Add random noise to potential to avoid degeneracies in EVs
                if self.xc.training:
                    if step == 0:
                        print("Noise generation to avoid potential degeneracies")
                    noise = torch.abs(torch.randn(vxc.size(),device=vxc.device)*1e-4)
                    noise = noise + torch.transpose(noise,-1,-2)
                    veff = veff.clone() + noise
                if debug:
                    print('STEP-{}/VEFF+VXC+NOISE: '.format(step), veff)

            else:
                exc=0
                vxc=torch.zeros_like(veff)
            f = get_fock(hc, veff)
            if debug:
                print('STEP-{}/FOCK: '.format(step), f)

            mo_e, mo_coeff = self.eig(f, s_chol)
            dm = self.make_rdm1(mo_coeff, mo_occ)

            # vxc is subtracted because exc already carries the xc energy.
            e_tot = self.energy_tot(dm, hc, veff-vxc)+ e_nuc + exc
            E.append(e_tot)
            if vvv:
                print("{} Energy: {}".format(step, e_tot))
                print("History: {}".format(E))
            if not sc:
                break

        #in xcdiff, things happen here with mo_occ[:self.ncore], e_ip etc. not implemented here

        results = {'E': torch.cat(E), 'dm':dm, 'mo_energy':mo_e}

        return results

def get_optimizer(model, path='', hybrid=None, lr=1e-3, l2=1e-6):
    """Create an Adam optimizer plus a ReduceLROnPlateau scheduler for `model`.

    Args:
        model (torch.nn.Module): Model whose parameters are optimized.
        path (str, optional): If given, optimizer state dict to load (mapped to CPU).
        hybrid (optional): If truthy, ``model.xc.exx_a`` is appended to the trainable parameters.
        lr (float, optional): Learning rate. Defaults to 1e-3.
        l2 (float, optional): Weight decay. Defaults to 1e-6.

    Returns:
        (torch.optim.Adam, torch.optim.lr_scheduler.ReduceLROnPlateau)

    Note:
        The scheduler patience depends on the notebook-level global ``PRINT_EVERY``.
    """
    if hybrid:
        trainable = list(model.parameters()) + [model.xc.exx_a]
    else:
        trainable = model.parameters()
    optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=l2)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        verbose=True, patience=int(10/PRINT_EVERY),
        factor=0.1, min_lr=1e-7)

    if path:
        saved_state = torch.load(path, map_location=torch.device('cpu'))
        optimizer.load_state_dict(saved_state)
    return optimizer, scheduler

In [3]:
def get_torch_xc(xctype, pretrain_loc='', hyb_par=0, path='', DEVICE='cpu', ueg_limit=True, meta_x=None, freec=False,
            inserts = 0, nhidden = 16):
    """Build a torch exchange-correlation model (exchange + correlation MLPs) and optionally load weights.

    Args:
        xctype (str): 'GGA' or 'MGGA'; selects the network inputs and the XC level (2 or 3).
        pretrain_loc (str, optional): Directory with pre-trained state dicts ``<pretrain_loc>/x`` and
            ``<pretrain_loc>/c``. Defaults to '' (skip).
        hyb_par (float, optional): Exact-exchange fraction. NOTE(review): not used by this function. Defaults to 0.
        path (str, optional): Saved model (its ``.xc`` attribute is used) or TorchScript XC module whose
            state dict is loaded into the new model. Defaults to '' (skip).
        DEVICE (str, optional): Device for the sub-networks. Defaults to 'cpu'.
        ueg_limit (bool, optional): Enforce the uniform-electron-gas limit; also selects the exchange
            bound ``lob`` (1.804 for GGA, 1.174 for MGGA, 0 when disabled). Defaults to True.
        meta_x: Unused; kept for signature compatibility.
        freec (bool, optional): If True the correlation net does not enforce the UEG limit. Defaults to False.
        inserts: Unused; kept for signature compatibility.
        nhidden (int, optional): Hidden-layer width of both networks. Defaults to 16.

    Returns:
        XC: the assembled exchange-correlation model.

    Raises:
        ValueError: if ``xctype`` is neither 'GGA' nor 'MGGA'.
    """
    print('FREEC', freec)
    if xctype == 'GGA':
        lob = 1.804 if ueg_limit else 0
        x = X_L(device=DEVICE, n_input=1, n_hidden=nhidden, use=[1], lob=lob, ueg_limit=ueg_limit) # PBE_X
        c = C_L(device=DEVICE, n_input=3, n_hidden=nhidden, use=[2], ueg_limit=ueg_limit and not freec)
        xc_level = 2
    elif xctype == 'MGGA':
        lob = 1.174 if ueg_limit else 0
        # Pass the computed bound (as the GGA branch does) so ueg_limit=False actually disables it.
        x = X_L(device=DEVICE, n_input=2, n_hidden=nhidden, use=[1, 2], lob=lob, ueg_limit=ueg_limit) # PBE_X
        c = C_L(device=DEVICE, n_input=4, n_hidden=nhidden, use=[2, 3], ueg_limit=ueg_limit and not freec)
        xc_level = 3
    else:
        raise ValueError("xctype must be 'GGA' or 'MGGA', got {!r}".format(xctype))

    if pretrain_loc:
        print("Loading pre-trained models from " + pretrain_loc)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        x.load_state_dict(torch.load(pretrain_loc + '/x'))
        c.load_state_dict(torch.load(pretrain_loc + '/c'))

    xc = XC(grid_models=[x, c], heg_mult=True, level=xc_level)
    if path:
        try:
            xcp = torch.load(path, map_location=torch.device('cpu')).xc
            xc.load_state_dict(xcp.state_dict())
        except AttributeError:
            # AttributeError: 'RecursiveScriptModule' object has no attribute ...
            # occurs when `path` holds a finished TorchScript xc from xcdiff.
            xcp = torch.jit.load(path)
            xc.load_state_dict(xcp.state_dict())

    return xc
def get_torch_weights_and_biases(torch_net):
    """Extract the weight and bias tensors of a torch network as JAX arrays.

    Layers that lack a ``weight`` or a ``bias`` (activations, bias-free layers) are skipped
    with a message. Conversion errors are NOT swallowed, so weights cannot silently go missing.

    Args:
        torch_net (iterable of torch.nn.Module): e.g. a ``torch.nn.Sequential``.

    Returns:
        (list, list): weights and biases as ``jnp`` arrays, in layer order.
    """
    weights = []
    biases = []
    for layer in torch_net:
        weight = getattr(layer, 'weight', None)
        bias = getattr(layer, 'bias', None)
        if weight is None or bias is None:
            print('This torch layer is not a Linear model.')
            continue
        # .cpu() so parameters living on a GPU can be converted as well.
        weights.append(jnp.array(weight.detach().cpu().numpy()))
        biases.append(jnp.array(bias.detach().cpu().numpy()))
    return (weights, biases)
    
#per https://docs.kidger.site/equinox/tricks/
def trunc_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    """Truncated-normal initializer with stddev 1/sqrt(fan_in), values clipped to +-2 stddev.

    Args:
        weight: Existing weight; only its (out, in) shape is used.
        key: JAX PRNG key.

    Returns:
        jax.Array of shape (out, in).
    """
    import math  # the notebook's import cell does not import `math`
    out, in_ = weight.shape
    stddev = math.sqrt(1 / in_)
    return stddev * jax.random.truncated_normal(key, shape=(out, in_), lower=-2, upper=2)

def init_linear_weight(model, seed, new_weights, new_bias):
    """Return a copy of `model` whose eqx.nn.Linear weights and biases are replaced.

    Args:
        model (eqx.Module): Model containing eqx.nn.Linear layers.
        seed (int): Unused (no randomness is involved); kept for backward compatibility.
        new_weights (list): Replacement weight arrays, one per Linear layer in pytree-leaf order.
        new_bias (list): Replacement bias arrays, same order.

    Returns:
        eqx.Module: the updated model (the input model is not modified).

    Raises:
        ValueError: if the number of replacements does not match the number of Linear layers.
    """
    def is_linear(x):
        return isinstance(x, eqx.nn.Linear)

    def linear_layers(m):
        return [x for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear) if is_linear(x)]

    n_linear = len(linear_layers(model))
    if len(new_weights) != n_linear or len(new_bias) != n_linear:
        raise ValueError('Expected {} weights/biases, got {}/{}'.format(
            n_linear, len(new_weights), len(new_bias)))

    new_model = eqx.tree_at(lambda m: [layer.weight for layer in linear_layers(m)], model, new_weights)
    new_model = eqx.tree_at(lambda m: [layer.bias for layer in linear_layers(m)], new_model, new_bias)
    return new_model

In [4]:
def jax_exc_func(model, ao_eval, gw):
    """Bind the AO values and grid weights so the model becomes a function of the density matrix only.

    Returns:
        callable: ``f(dm) -> model(dm, ao_eval, gw)``, suitable for ``jax.grad``.
    """
    return lambda inp: model(inp, ao_eval, gw)

def jax_loss_func(loss_func, model, en, ao, gw, eri, mooc, hc, s):
    """Close over every loss argument except the density matrix.

    Returns:
        callable: ``f(dm) -> loss_func(model, dm, en, ao, gw, eri, mooc, hc, s)``.
    """
    return lambda dm: loss_func(model, dm, en, ao, gw, eri, mooc, hc, s)

# @eqx.filter_jit
def jax_dm(dm, eri, vxc_grad_func, mo_occ, hc, s, ogd, alpha0=0.7, noise_scale=1e-6, noise_seed=92017):
    """One SCF-like update of the density matrix in JAX.

    Builds veff from the ERIs and vxc = d(vxc_grad_func)/d(dm), forms the Fock matrix,
    diagonalizes it (with a small fixed-seed symmetric perturbation to lift degeneracies)
    and rebuilds the density matrix via the xcquinox utilities.

    Args:
        dm: Input density matrix.
        eri: Electron repulsion integrals.
        vxc_grad_func: Scalar XC energy as a function of dm (differentiated with jax.grad).
        mo_occ: MO occupations.
        hc: Core Hamiltonian.
        s: Inverse Cholesky factor of the overlap (passed to xce.utils.eig).
        ogd: Original dm shape, forwarded to xce.utils.eig.
        alpha0 (float, optional): Mixing base; on this single step alpha = alpha0**0 + 0.3.
        noise_scale (float, optional): Amplitude of the Fock-matrix perturbation. Defaults to 1e-6.
        noise_seed (int, optional): PRNG seed of the perturbation. Defaults to 92017.

    Returns:
        (dm, mo_e, mo_c): updated density matrix, MO energies, MO coefficients.
    """
    nao = dm.shape[-1]
    L = jnp.eye(nao)
    scaling = jnp.ones([nao]*2)

    def zero_second_vxc(vxc):
        # JAX arrays are immutable: .at[].set() returns a NEW array that must be returned.
        return vxc.at[1].set(jnp.zeros_like(vxc[1]))

    def keep_vxc(vxc):
        return vxc

    # Linear mixing; with dm_old == dm this reduces to dm (up to rounding) on this single step.
    dm_old = dm
    alpha = jnp.power(alpha0, 0)+0.3
    beta = (1-alpha)
    dm = alpha * dm + beta * dm_old

    veff = xce.utils.get_veff()(dm, eri)
    vxc = jax.grad(vxc_grad_func)(dm)
    if vxc.ndim > 2:
        vxc = jnp.einsum('ij,xjk,kl->xil', L, vxc, L.T)
        vxc = jnp.where(jnp.expand_dims(scaling, 0) > 0, vxc, jnp.expand_dims(scaling, 0))
    else:
        vxc = jnp.matmul(L, jnp.matmul(vxc, L.T))
        vxc = jnp.where(scaling > 0, vxc, scaling)

    # One-electron systems: zero the second-spin potential (mirrors the torch SCF loop).
    vxc = jax.lax.cond(jnp.sum(mo_occ) == 1, zero_second_vxc, keep_vxc, vxc)

    veff = veff + vxc
    f = xce.utils.get_fock()(hc, veff)
    noise = noise_scale*jax.random.uniform(key=jax.random.PRNGKey(noise_seed), shape=f.shape)
    # Symmetrize so the perturbed Fock matrix stays symmetric for the eigensolver.
    noise = 0.5*(noise + jnp.swapaxes(noise, -1, -2))
    mo_e, mo_c = xce.utils.eig()(f + noise, s, ogd)
    dm = xce.utils.make_rdm1()(mo_c, mo_occ)
    return dm, mo_e, mo_c
    
# @eqx.filter_grad
def e_loss(model, inp_dm, ref_en, ao_eval, grid_weights, *args):
    """RMS error between the model's predicted energy and the reference energy.

    Uses ``jnp.mean`` (not ``np.mean``) so the loss can be traced by jax.grad / jax.jit.

    Args:
        model: callable ``model(inp_dm, ao_eval, grid_weights) -> energy``.
        inp_dm: Density matrix.
        ref_en: Reference energy.
        ao_eval: AO values on the grid (only its shape is printed here).
        grid_weights: Grid weights.
        *args: Ignored (allows a uniform loss signature).

    Returns:
        Scalar ``sqrt(mean((e_pred - ref_en)**2))``.
    """
    print(f"e_loss; input stats. inp_dm.shape = {inp_dm.shape}, ref_en = {ref_en}, ao_eval.shape = {ao_eval.shape}, grid_weights.shape = {grid_weights.shape}")
    e_pred = model(inp_dm, ao_eval, grid_weights)
    eL = jnp.sqrt(jnp.mean((e_pred-ref_en)**2))
    return eL

class E_loss(eqx.Module):
    """Root-mean-square error between a model's predicted energy and a reference energy."""

    def __init__(self):
        super().__init__()

    def __call__(self, model, inp_dm, ref_en, ao_eval, grid_weights):
        """Return ``sqrt(mean((model(inp_dm, ao_eval, grid_weights) - ref_en)**2))``."""
        residual = model(inp_dm, ao_eval, grid_weights) - ref_en
        return jnp.sqrt(jnp.mean(residual**2))

def holo_loss(model, inp_dm, ref_en, ao_eval, grid_weights, vxc_grad_func, mo_occ, hc, s, eri, ogd, alpha0):
    """RMS error of the predicted HOMO-LUMO gap after one jax_dm update.

    ``model``, ``ao_eval`` and ``grid_weights`` are not used directly; ``vxc_grad_func`` is
    expected to close over them. Uses ``jnp.mean`` so the loss is traceable by jax.grad / jit.

    Args:
        ref_en: Reference HOMO-LUMO gap.
        (other args are forwarded to jax_dm)

    Returns:
        Scalar ``sqrt(mean((pred_gap - ref_en)**2))``.
    """
    dm, mo_e, mo_c = jax_dm(inp_dm, eri, vxc_grad_func, mo_occ, hc, s, ogd, alpha0)
    # Highest occupied index; `size` keeps the output shape static under tracing and the
    # padding index (0 by default) cannot exceed the true maximum.
    homo_i = jnp.max(jnp.nonzero(mo_occ, size=dm.shape[0])[0])
    homo_e = mo_e[homo_i]
    lumo_e = mo_e[homo_i+1]
    pred_holo = lumo_e - homo_e
    print('pred_holo', pred_holo)
    return jnp.sqrt(jnp.mean((pred_holo - ref_en)**2))

def loop_e_loss(model, inp_dms, ref_ens, ao_evals, grid_weights):
    """RMS energy error over a list of systems.

    Entry i of inp_dms / ao_evals / grid_weights is used for reference energy ref_ens[i].

    Returns:
        Scalar ``sqrt(mean((ref - pred)**2))`` over all systems.
    """
    predictions = jnp.array([model(inp_dms[i], ao_evals[i], grid_weights[i])
                             for i in range(len(ref_ens))])
    references = jnp.array(ref_ens)
    return jnp.sqrt(jnp.mean((references - predictions)**2))
# @eqx.filter_grad

def dm_loss(model, inp_dm, ref_en, ao_eval, gw, eri, mo_occ, hc, s, ogd, *args):
    """Frobenius-norm change of the density matrix after one jax_dm update.

    ``ref_en`` and ``*args`` are ignored (uniform loss signature).
    """
    exc_of_dm = jax_exc_func(model, ao_eval, gw)
    updated_dm = jax_dm(inp_dm, eri, exc_of_dm, mo_occ, hc, s, ogd)[0]
    return jnp.sqrt(jnp.sum((updated_dm - inp_dm)**2))


def loop_dm_loss(model, inp_dms, eris, mo_occs, hcs, ss, ao_evals, gws, ogds=None):
    """Root of the summed mean-squared density-matrix change over a list of systems.

    Args:
        model: XC model, wrapped with jax_exc_func per system.
        inp_dms, eris, mo_occs, hcs, ss, ao_evals, gws: per-system inputs (indexed together).
        ogds (list, optional): per-system original dm shapes for jax_dm; defaults to each
            input dm's own ``.shape`` (how the notebook's data-prep cell builds them).

    Returns:
        Scalar ``sqrt(sum_i mean((dm_i' - dm_i)**2))``.
    """
    dmL = 0
    for idx, inp_dm in enumerate(inp_dms):
        ogd = inp_dm.shape if ogds is None else ogds[idx]
        # jax_dm returns (dm, mo_e, mo_c); only the updated dm enters the loss.
        dmp, _, _ = jax_dm(inp_dm, eris[idx], jax_exc_func(model, ao_evals[idx], gws[idx]),
                           mo_occs[idx], hcs[idx], ss[idx], ogd)
        dmL += jnp.mean((dmp - inp_dm)**2)
    return jnp.sqrt(dmL)
    
# @eqx.filter_value_and_grad
def total_loss(model, inp_dms, ref_ens, ref_holos, ao_evals, grid_weights, eris, mo_occs, hcs, ss, ogd):
    """Training loss for one system; currently only the HOMO-LUMO-gap term is active.

    ``ref_ens`` is accepted but unused while the energy term is disabled. The result is
    written as sqrt(sum of squares) so further loss terms can be added.

    Returns:
        Scalar loss.
    """
    # Use the ao_evals argument (not a notebook global) so the loss only depends on its inputs.
    vxcgf = jax_exc_func(model, ao_evals, grid_weights)
    holoL = holo_loss(model, inp_dms, ref_holos, ao_evals, grid_weights, vxcgf, mo_occs, hcs, ss, eris, ogd, alpha0=0.7)
    return jnp.sqrt(holoL**2)

def total_loop_loss(model, inp_dms, ref_ens, ao_evals, grid_weights, eris, mo_occs, hcs, ss):
    """Combine the looped energy loss and looped density-matrix loss in quadrature."""
    energy_term = loop_e_loss(model, inp_dms, ref_ens, ao_evals, grid_weights)
    density_term = loop_dm_loss(model, inp_dms, eris, mo_occs, hcs, ss, ao_evals, grid_weights)
    combined = energy_term**2 + density_term**2
    return jnp.sqrt(combined)


In [8]:
# Freshly initialised equinox MGGA exchange (eX) and correlation (eC) networks plus the
# combined eXC; `blankxc` is also used as a template for deserialisation further down.
# TODO(review): the exchange net uses n_input=2 with use=[1, 2] for MGGA, which the original
# author flagged as a possible holdover -- update the docs once confirmed.
xnet = xce.net.eX(n_input = 2, use = [1, 2], ueg_limit=True, lob=1.174)
# Correlation net: no lob given, so presumably the package default LOB applies -- confirm.
cnet = xce.net.eC(n_input = 4, use = [2, 3], ueg_limit=True)
blankxc = xce.xc.eXC(grid_models = [xnet, cnet], level=3)

In [6]:
# Load the pre-trained torch SCAN-level MGGA model and copy its Linear weights/biases into the
# equinox networks `xnet` / `cnet`, then assemble the equinox XC functional `xc`.
# NOTE(review): absolute, user-specific path -- consider a configurable models directory.
ptscan = get_torch_xc(xctype='MGGA', pretrain_loc='/home/awills/Documents/Research/dpyscfl/models/pretrained/scan',
                nhidden=16)
tgms = ptscan.grid_models
t_x_w, t_x_b = get_torch_weights_and_biases(tgms[0].net)
t_c_w, t_c_b = get_torch_weights_and_biases(tgms[1].net)

xnet = init_linear_weight(xnet, seed=92017, new_weights = t_x_w, new_bias = t_x_b)
cnet = init_linear_weight(cnet, seed=92017, new_weights = t_c_w, new_bias = t_c_b)
gms = [xnet, cnet]
xc = xce.xc.eXC(grid_models = gms, level=3)

FREEC False
Loading pre-trained models from /home/awills/Documents/Research/dpyscfl/models/pretrained/scan
This torch layer is not a Linear model.
This torch layer is not a Linear model.
This torch layer is not a Linear model.
This torch layer is not a Linear model.
This torch layer is not a Linear model.
This torch layer is not a Linear model.
This torch layer is not a Linear model.


In [18]:
# Load a finished xcdiff MGGA model from `path` into a fresh torch XC (absolute local path).
xcd = get_torch_xc(xctype='MGGA', path='/home/awills/Documents/Research/torch_dpy/models/xcdiff/MODEL_MGGA/xc')

FREEC False




In [14]:
# Sanity check: a non-zero difference confirms the transferred first-layer exchange weights
# differ from the blank initialisation.
xc.grid_models[0].net.layers[0].weight - blankxc.grid_models[0].net.layers[0].weight

Array([[-2.15969814e+00, -3.33072779e-01],
       [-1.33774353e+00, -6.14078521e-01],
       [ 5.30909233e-01, -4.35234039e+00],
       [-2.82881466e-01, -1.18875395e+00],
       [-1.20849832e+01, -1.49310464e+00],
       [-4.79616419e+01, -3.27104576e-01],
       [-7.55614816e+00,  1.06021431e-01],
       [-8.53568427e-01, -3.04049449e-01],
       [-1.81156744e+00, -6.52161338e-01],
       [-4.16262999e+00, -2.73711612e+00],
       [ 3.95412026e-01, -7.49266771e-02],
       [-4.85004866e-01, -2.93522760e-01],
       [-4.21466678e+00, -9.35071154e-03],
       [-1.09003024e+00,  6.17014809e-02],
       [ 5.53233943e-01,  2.69094329e-01],
       [-3.36713487e+00, -1.55620977e-02]], dtype=float64)

In [7]:
# Save the converted equinox XC leaves to <p>/xc.eqx (absolute, user-specific directory).
p = '/home/awills/Documents/Research/xcquinox/models/pretrained/scan'
eqx.tree_serialise_leaves(os.path.join(p, 'xc.eqx'), xc)

In [15]:
# Reload the saved leaves, using `blankxc` as the structural template (requires `p` from the save cell).
loadxc = eqx.tree_deserialise_leaves(os.path.join(p, 'xc.eqx'), blankxc)

Test molecule: generate reference inputs (density matrices, integrals, XC energies, HOMO-LUMO gaps) with pyscfad

In [57]:
# Read all structures from the reference trajectory (absolute local path).
trainms = read('/home/awills/Documents/Research2/torch_dpy/subset09_nf/subat_ref_corrected.traj', ':')
# Per-molecule training inputs, appended in lockstep (index i refers to the same molecule).
energies = []
dms = []
ao_evals = []
gws = []
eris = []
mo_occs = []
hcs = []
vs = []
ts = []
ss = []
hologaps = []
ogds = []
# Only the second structure is used for now.
for idx, at in enumerate(trainms[1:2]):
    name, mol = xce.utils.ase_atoms_to_mol(at, basis='def2tzvpd')
    mol.build()
    # Closed-shell SCAN reference calculation.
    mf = dft.RKS(mol, xc='SCAN')
    e_tot = mf.kernel()
    dm = mf.make_rdm1()
    # AO values and derivatives up to 2nd order on the DFT grid.
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=2))
    # Presumably the XC energy tagged onto veff by pyscf(ad) -- confirm.
    energies.append(mf.get_veff().exc)
    dms.append(dm)
    ogds.append(dm.shape)
    ao_evals.append(ao_eval)
    gws.append(mf.grids.weights)
    ts.append(mol.intor('int1e_kin'))
    vs.append(mol.intor('int1e_nuc'))
    mo_occs.append(mf.mo_occ)
    hcs.append(mf.get_hcore())
    eris.append(mol.intor('int2e'))
    # Inverse Cholesky factor of the overlap matrix (what the eigensolvers expect).
    ss.append(jnp.linalg.inv(jnp.linalg.cholesky(mol.intor('int1e_ovlp'))))
    # LUMO - HOMO: first empty orbital minus last doubly occupied one.
    hologaps.append(mf.mo_energy[mf.mo_occ == 0][0] - mf.mo_energy[mf.mo_occ > 1][-1])



In [8]:
# Evaluate the converted equinox XC functional on the first molecule's reference density.
xc(dms[0], ao_evals[0], gws[0])

Array(-0.69754768, dtype=float64)

In [172]:
class xcTrainer(eqx.Module):
    """Minimal epoch/batch training loop for an equinox XC model with an optax optimizer.

    Fields mirror the constructor arguments; ``opt_state`` is initialised from the array
    leaves of ``model``.
    """
    model: eqx.Module
    optim: optax.GradientTransformation
    loss: eqx.Module
    steps: int
    print_every: int
    clear_every: int
    memory_profile: bool
    verbose: bool
    do_jit: bool
    opt_state: tuple

    def __init__(self, model, optim, loss, steps=50, print_every=1, clear_every=1, memory_profile=False, verbose=False, do_jit=True):
        """
        Args:
            model: equinox model; training starts from it.
            optim: optax gradient transformation.
            loss: callable ``loss(model, *inputs) -> scalar``.
            steps (int): number of epochs.
            print_every (int): print the epoch loss every this many epochs (and on the last one).
            clear_every (int): every this many epochs (after the first), clear JAX/equinox caches
                and build a fresh jitted step function; 0 disables clearing.
            memory_profile (bool): save a device memory profile after every batch.
            verbose (bool): print progress messages inside make_step.
            do_jit (bool): wrap the step function with eqx.filter_jit.
        """
        super().__init__()
        self.model = model
        self.optim = optim
        self.loss = loss
        self.steps = steps
        self.print_every = print_every
        self.clear_every = clear_every
        self.memory_profile = memory_profile
        self.verbose = verbose
        self.do_jit = do_jit
        self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array))

    def clear_caches(self):
        """Best effort: call ``cache_clear()`` on cached objects in loaded ``jax*`` modules, then run gc."""
        import gc  # not imported by the notebook's import cell
        # Snapshot sys.modules: touching attributes can import modules and mutate it mid-iteration.
        for module_name, module in list(sys.modules.items()):
            if not module_name.startswith("jax") or module_name == "jax.interpreters.partial_eval":
                continue
            for obj_name in dir(module):
                try:
                    obj = getattr(module, obj_name)
                    if hasattr(obj, "cache_clear"):
                        obj.cache_clear()
                except Exception:
                    # Deliberately best-effort: lazy/deprecated attributes or unbound cache_clear.
                    pass
        gc.collect()

    def _clear_compilation_caches(self):
        """Drop JAX/equinox compilation caches (only the APIs this installed version provides)."""
        for clear in (getattr(jax, "clear_caches", None), getattr(eqx, "clear_caches", None)):
            if clear is not None:
                clear()
        self.clear_caches()

    def vprint(self, output):
        if self.verbose:
            print(output)

    def make_step(self, model, opt_state, *args):
        """One optimizer step; returns (new_model, new_opt_state, loss before the update)."""
        self.vprint('loss_value, grads')
        loss_value, grads = eqx.filter_value_and_grad(self.loss)(model, *args)
        self.vprint('updates, opt_state')
        updates, opt_state = self.optim.update(grads, opt_state, model)
        self.vprint('model update')
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def __call__(self, epoch_batch_len, model, *loss_input_lists):
        """Train for ``self.steps`` epochs and return the TRAINED model.

        Args:
            epoch_batch_len (int): batches per epoch; batch ``idx`` uses ``inp[idx]`` from every
                list in ``loss_input_lists``.
            model: unused -- training starts from ``self.model``, for which ``self.opt_state``
                was initialised. Kept for backward compatibility.
            *loss_input_lists: one list per loss argument, each holding one entry per batch.

        Returns:
            The trained model (equinox modules are immutable, so it is not stored on self).
        """
        inp_model = self.model
        inp_opt_state = self.opt_state
        fmake_step = None
        for step in range(self.steps):
            print('Epoch {}'.format(step))
            epoch_loss = 0
            if self.clear_every and step > 0 and step % self.clear_every == 0:
                # Bound memory growth; the next line then builds a fresh jit wrapper.
                self._clear_compilation_caches()
                fmake_step = None
            if fmake_step is None:
                fmake_step = eqx.filter_jit(self.make_step) if self.do_jit else self.make_step
            for idx in range(epoch_batch_len):
                print('Epoch {} :: Batch {}/{}'.format(step, idx, epoch_batch_len))

                # Select this batch's entry from every per-argument input list.
                loss_inputs = [inp[idx] for inp in loss_input_lists]

                # value_and_grad already reports the pre-update loss, so no extra forward pass is needed.
                inp_model, inp_opt_state, train_loss = fmake_step(inp_model, inp_opt_state, *loss_inputs)

                if self.memory_profile:
                    train_loss.block_until_ready()
                    jax.profiler.save_device_memory_profile(f"memory{step}_{idx}.prof")

                this_loss = train_loss.item()
                print('Batch Loss = {}'.format(this_loss))
                epoch_loss += this_loss

            if (step % self.print_every) == 0 or (step == self.steps - 1):
                print(
                    f"{step=}, epoch_train_loss={epoch_loss}"
                )

        return inp_model

In [173]:
# Check the container type of an optax AdamW state (used for xcTrainer's `opt_state: tuple` annotation).
type(optax.adamw(1e-4).init(eqx.filter(xc, eqx.is_array)))

tuple

In [174]:
# Energy-only training of the converted SCAN functional with AdamW (lr 1e-4).
trainer = xcTrainer(model=xc, optim=optax.adamw(1e-4), loss = E_loss())

In [175]:
# One batch per epoch; each list supplies one positional loss argument (dm, ref energy, ao_eval, weights).
trainer(1, trainer.model, dms, energies, ao_evals, gws)

Epoch 0
Epoch 0 :: Batch 0/1
Batch Loss = 0.0003424572856989272
step=0, epoch_train_loss=0.0003424572856989272
Epoch 1
Epoch 1 :: Batch 0/1
Batch Loss = 0.0029619014513926345
step=1, epoch_train_loss=0.0029619014513926345
Epoch 2
Epoch 2 :: Batch 0/1
Batch Loss = 0.002782009913946837
step=2, epoch_train_loss=0.002782009913946837
Epoch 3
Epoch 3 :: Batch 0/1
Batch Loss = 0.0014471899868890858
step=3, epoch_train_loss=0.0014471899868890858
Epoch 4
Epoch 4 :: Batch 0/1
Batch Loss = 0.0004575789342489145
step=4, epoch_train_loss=0.0004575789342489145
Epoch 5
Epoch 5 :: Batch 0/1
Batch Loss = 0.0010923289809205983
step=5, epoch_train_loss=0.0010923289809205983
Epoch 6
Epoch 6 :: Batch 0/1
Batch Loss = 0.000889729042844678
step=6, epoch_train_loss=0.000889729042844678
Epoch 7
Epoch 7 :: Batch 0/1
Batch Loss = 9.479842606374689e-05
step=7, epoch_train_loss=9.479842606374689e-05
Epoch 8
Epoch 8 :: Batch 0/1
Batch Loss = 0.0011404489887336666
step=8, epoch_train_loss=0.0011404489887336666
Epoch

eXC(
  grid_models=[
    eX(
      n_input=2,
      n_hidden=16,
      ueg_limit=True,
      spin_scaling=True,
      lob=1.174,
      use=[1, 2],
      net=MLP(
        layers=(
          Linear(
            weight=f64[16,2],
            bias=f64[16],
            in_features=2,
            out_features=16,
            use_bias=True
          ),
          Linear(
            weight=f64[16,16],
            bias=f64[16],
            in_features=16,
            out_features=16,
            use_bias=True
          ),
          Linear(
            weight=f64[16,16],
            bias=f64[16],
            in_features=16,
            out_features=16,
            use_bias=True
          ),
          Linear(
            weight=f64[1,16],
            bias=f64[1],
            in_features=16,
            out_features=1,
            use_bias=True
          )
        ),
        activation=<function gelu>,
        final_activation=<function <lambda>>,
        use_bias=True,
        use_final_bias=True

In [88]:
def func2(*args):
    """Print all positional arguments together as one list."""
    print(list(args))
def test_func(*args, blen=1):
    """Call ``func2`` once for each position ``0 .. blen-1``.

    On each call ``func2`` receives, in order, the element at that position
    from every sequence in ``args`` (i.e. the inputs are walked row by row).
    """
    for position in range(blen):
        # Star-unpacking consumes the whole generator before func2 is called.
        func2(*(seq[position] for seq in args))

In [91]:
test_func(list('abc'), list('def'), list('ghi'), blen=3)

['a', 'd', 'g']
['b', 'e', 'h']
['c', 'f', 'i']


In [74]:
def clear_caches():
    """Best-effort clear of every ``cache_clear``-able object in loaded JAX modules.

    Walks the modules in ``sys.modules`` whose name starts with ``"jax"``
    (skipping ``jax.interpreters.partial_eval``), calls ``cache_clear()`` on
    every attribute that provides one, and finally runs the garbage collector.
    Failures while reading an attribute or clearing a cache are ignored so
    one odd object cannot abort the sweep.
    """
    import gc  # not imported at the top of the notebook

    # Iterate over a snapshot: attribute access / cache_clear() may import new
    # modules, and mutating sys.modules during live iteration raises RuntimeError.
    for module_name, module in list(sys.modules.items()):
        if module is None or not module_name.startswith("jax"):
            continue
        if module_name in ("jax.interpreters.partial_eval",):
            continue
        for obj_name in dir(module):
            try:
                obj = getattr(module, obj_name)
                cache_clear = getattr(obj, "cache_clear", None)
            except Exception:
                # Deprecated/removed attributes can raise on access; skip them.
                continue
            if not callable(cache_clear):
                continue
            try:
                cache_clear()
            except Exception:
                # Deliberately best-effort: some caches cannot be cleared this way.
                pass
    gc.collect()
# Loss optimised by `train`; other losses tried in this notebook:
# loop_e_loss, total_loop_loss, e_loss, dm_loss.
# do_jit=True wraps each training step in eqx.filter_jit.
chosen_loss, do_jit = total_loss, True

def _clear_jit_caches(*jitted_fns):
    """Drop compiled-function caches to release memory during training.

    Calls ``_clear_cache``/``clear_cache`` on each of ``jitted_fns`` when the
    object has one, then the global Equinox and JAX cache clearers, then runs
    the garbage collector. APIs missing from the installed library versions
    are skipped instead of raising.
    """
    import gc

    for fn in jitted_fns:
        clear = getattr(fn, "_clear_cache", None) or getattr(fn, "clear_cache", None)
        if callable(clear):
            clear()
    eqx_clear = getattr(eqx, "clear_caches", None)
    if callable(eqx_clear):
        eqx_clear()
    jax_clear = getattr(jax, "clear_caches", None)
    if callable(jax_clear):
        jax_clear()
    gc.collect()


def train(model: eqx.Module,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
    clear_every: int,
    memory_profile: bool):
    """Train ``model`` with ``optim``, one molecule per optimiser step.

    Training data come from notebook globals (``energies``, ``dms``,
    ``ao_evals``, ``ogds``, ``gws``, ``eris``, ``mo_occs``, ``hcs``, ``ss``,
    ``hologaps``); the objective is the global ``chosen_loss`` and the global
    ``do_jit`` selects whether each step is wrapped in ``eqx.filter_jit``.
    Molecules are visited in reverse index order.

    Args:
        model: Equinox XC model to optimise.
        optim: optax optimiser, initialised on the array leaves of ``model``.
        steps: number of epochs over all molecules.
        print_every: print the summed epoch loss every this many epochs
            (and always after the last epoch). Must be >= 1.
        clear_every: on every epoch that is a positive multiple of this value,
            the jitted step is rebuilt at the start of the epoch and the
            JAX/Equinox compilation caches are cleared after each molecule.
            Must be >= 1.
        memory_profile: if True, save a device-memory profile after each molecule.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``print_every`` or ``clear_every`` is smaller than 1.
    """
    if print_every < 1 or clear_every < 1:
        raise ValueError("print_every and clear_every must be >= 1")

    # Only the array leaves of the model are trainable.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    def make_step(model, opt_state, inp_dm, ref_en, holos, ao_eval, grid_weights, eris, mo_occs, hcs, ss, ogd):
        """One optimisation step: loss and gradients, optimiser update, apply."""
        loss_value, grads = eqx.filter_value_and_grad(chosen_loss)(model, inp_dm, ref_en, holos, ao_eval, grid_weights, eris, mo_occs, hcs, ss, ogd)
        # Give the optimiser the array leaves only (Equinox convention); optimisers
        # with weight decay such as adamw read these params.
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    fmake_step = make_step
    for step in range(steps):
        print('epoch {}'.format(step))
        epoch_loss = 0
        clear_this_epoch = step > 0 and step % clear_every == 0
        if do_jit and (step == 0 or clear_this_epoch):
            # Fresh jit wrapper; reused (not dropped to un-jitted) on the
            # epochs in between.
            fmake_step = eqx.filter_jit(make_step)

        n_mols = len(energies)
        for idx in reversed(range(n_mols)):
            print('e {} mol {}/{}'.format(step, idx, n_mols))
            en = energies[idx]
            dm = dms[idx]
            ao = ao_evals[idx]
            ogd = ogds[idx]
            print(ao.shape)
            gw = gws[idx]
            eri = eris[idx]
            mooc = mo_occs[idx]
            hc = hcs[idx]
            s = ss[idx]
            holo = hologaps[idx]

            # Pre-update diagnostics: energy, density-matrix and gap errors.
            e_pred = model(dm, ao, gw)
            dmp, mo_e, _ = jax_dm(dm, eri, jax_exc_func(model, ao, gw), mooc, hc, s, ogd)
            # Gap = first unoccupied minus last occupied orbital energy; any
            # nonzero occupation counts as occupied.
            holo_pred = mo_e[mooc == 0][0] - mo_e[mooc > 0][-1]
            print('e_pred - e_ref = {}'.format(e_pred-en))
            print('dm_pred - dm sum = {}'.format((dmp-dm).sum()))
            print('holo_pred - ref_holo = {}'.format(holo_pred-holo))

            model, opt_state, _ = fmake_step(model, opt_state, dm, en, holo, ao, gw, eri, mooc, hc, s, ogd)
            # Loss re-evaluated with the updated model (.item() synchronises).
            mol_loss = chosen_loss(model, dm, en, holo, ao, gw, eri, mooc, hc, s, ogd).item()
            if memory_profile:
                jax.profiler.save_device_memory_profile(f"memory{step}_{idx}.prof")

            print('mol loss = {}'.format(mol_loss))
            epoch_loss += mol_loss
            if clear_this_epoch:
                _clear_jit_caches(fmake_step, jax_dm)

        if (step % print_every) == 0 or (step == steps - 1):
            print(f"{step=}, epoch_train_loss={epoch_loss}")

    return model

In [75]:
# AdamW with learning rate 1e-4 for 250 epochs; see `train` for the
# meaning of print_every / clear_every.
m = train(
    model=xc,
    optim=optax.adamw(learning_rate=1e-4),
    steps=250,
    print_every=1,
    clear_every=1,
    memory_profile=False,
)

epoch 0
e 0 mol 0/1
(10, 25728, 74)
[74] (74,)
(74, 74) (74, 74)
Spin unpolarized make_rdm1()
e_pred - e_ref = -0.0003424572856989272
dm_pred - dm sum = -0.09088607376748996
holo_pred - ref_holo = -0.0001811382552374674
loss_value, grads
[74] (74,)
(74, 74) (74, 74)
Spin unpolarized make_rdm1()
pred_holo Traced<ShapedArray(float64[])>with<JVPTrace(level=3/0)> with
  primal = Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=1/0)>
  tangent = Traced<ShapedArray(float64[])>with<JaxprTrace(level=2/0)> with
    pval = (ShapedArray(float64[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7faa7c131560>, in_tracers=(Traced<ShapedArray(float64[]):JaxprTrace(level=2/0)>, Traced<ShapedArray(float64[]):JaxprTrace(level=2/0)>), out_tracer_refs=[<weakref at 0x7faa7c1195d0; to 'JaxprTracer' at 0x7faa7c119580>], out_avals=[ShapedArray(float64[])], primitive=pjit, params={'jaxpr': { lambda ; a:f64[] b:f64[]. let c:f64[] = sub a b in (c,) }, 'in_shardings': (UnspecifiedValue, U


KeyboardInterrupt



In [58]:
# Occupations `mo_occs[0]` shown next to the MO energies of `mf`.
mocc, moe = mo_occs[0], mf.mo_energy
print(mocc, moe)

[2. 2. 2. 2. 2. 2. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.] [-14.26603092 -14.26463934  -1.09180816  -0.52450794  -0.4397367
  -0.4397367   -0.39399782  -0.06022901  -0.06022901   0.08870731
   0.14128726   0.27233178   0.28895819   0.28898985   0.30133212
   0.30133212   0.35393616   0.35393616   0.39494556   0.4825326
   0.48267082   0.48952567   0.48952567   0.51522962   0.53951392
   0.72012031   0.7208078    0.80602798   0.80602798   1.08150586
   1.10496897   1.10504532   1.25532858   1.25532858   1.48594543
   1.4861282    1.54890009   1.72721476   1.72721476   1.9802559
   2.00327286   2.00327286   2.19298531   2.50156108   2.57656842
   2.57656842   2.72569718   3.01045198   3.5792937    3.57929581
   3.70628862   3.70628862   3.78633564   3.78633564   4.20792428
   4.20792428   4.23703183   4.60767191   4.607816

In [68]:
# Highest index with nonzero occupation in `mocc`. `size` fixes the output
# length (padding with index 0), which keeps the call jit-compatible.
# NOTE(review): assumes ogds[0][0] >= number of occupied orbitals; a smaller
# size would truncate the indices — confirm.
homo_i = jnp.nonzero(mocc, size=ogds[0][0])[0].max()

In [69]:
# Orbital energy at `homo_i`, the highest index with nonzero occupation in `mocc`.
moe[homo_i]

Array(-0.39399782, dtype=float64)

In [None]:
# NOTE(review): `ao_eval` is not defined in the cells visible here (the
# training loop uses `ao_evals`); presumably left over from an earlier cell.
np.shape(ao_eval)

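Create a two-atom silicon cell (diamond structure, lattice constant 5.43 Å, GTH-SZV basis with GTH-Pade pseudopotentials) and run a k-point RHF calculation on a 2×2×2 mesh.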

In [None]:
# Two Si atoms at (0,0,0) and (a/4,a/4,a/4): the diamond structure.
cell = gtop.Cell()
a = 5.43
cell.atom = [['Si', [0, 0, 0]],
             ['Si', [a / 4] * 3]]
# FCC primitive vectors: a/2 in every entry except zeros on the diagonal.
cell.a = (a / 2) * (jnp.ones((3, 3)) - jnp.eye(3))
cell.basis, cell.pseudo = 'gth-szv', 'gth-pade'
cell.exp_to_discard = 0.1
cell.build(trace_lattice_vectors=True)
kpts = cell.make_kpts([2, 2, 2])
mf = scfp.KRHF(cell, kpts=kpts)
e = mf.kernel()