In [2]:
# import pyscfad
# from pyscfad import gto,dft,scf
import matplotlib.pyplot as plt
import equinox as eqx
import pyscf
from pyscf import gto,dft,scf



In [3]:
import numpy as np
import jax.numpy as jnp
import scipy
from ase import Atoms
from ase.io import read
import xcquinox as xce
from functools import partial
from ase.units import Bohr
import os, optax, jax


In [38]:
def get_mol(atoms, basis='6-311++G**'):
    pos = atoms.positions
    spec = atoms.get_chemical_symbols()
    mol_input = [[s, p] for s, p in zip(spec, pos)]
    try:
        mol = gto.Mole(atom=mol_input, basis=atoms.info.get('basis',basis),spin=atoms.info.get('spin',0))
    except Exception:
        mol = gto.Mole(atom=mol_input, basis=atoms.info.get('basis','STO-3G'),spin=atoms.info.get('spin',0))
    return mol 

def get_rhos(rho, spin):
    rho0 = rho[0,0]
    drho = rho[0,1:4] + rho[1:4,0]
    tau = 0.5*(rho[1,1] + rho[2,2] + rho[3,3])

    if spin != 0:
        rho0_a = rho0[0]
        rho0_b = rho0[1]
        gamma_a, gamma_b = jnp.einsum('ij,ij->j',drho[:,0],drho[:,0]), jnp.einsum('ij,ij->j',drho[:,1],drho[:,1])              
        gamma_ab = jnp.einsum('ij,ij->j',drho[:,0],drho[:,1])
        tau_a, tau_b = tau
    else:
        rho0_a = rho0_b = rho0*0.5
        gamma_a=gamma_b=gamma_ab= jnp.einsum('ij,ij->j',drho[:],drho[:])*0.25
        tau_a = tau_b = tau*0.5
    return rho0_a, rho0_b, gamma_a, gamma_b, gamma_ab, tau_a, tau_b
    
def get_data_synth(xcmodel, xc_func, n=100):
    def get_rho(s, a):
        c0 = 2*(3*np.pi**2)**(1/3)
        c1 = 3/10*(3*np.pi**2)**(2/3)
        gamma = c0*s
        tau = c1*a+c0**2*s**2/8
        rho = np.zeros([len(a),6])
        rho[:, 1] = gamma
        rho[:,-1] = tau
        rho[:, 0] = 1
        return rho
    
    s_grid = jnp.concatenate([[0],jnp.exp(jnp.linspace(-10,4,n))])
    rho = []
    for s in s_grid:
        if 'MGGA' in xc_func:
            a_grid = jnp.concatenate([jnp.exp(jnp.linspace(jnp.log((s/100)+1e-8),8,n))])
        else:
            a_grid = jnp.array([0])
        rho.append(get_rho(s, a_grid))
        
    rho = jnp.concatenate(rho)
    
    fxc =  dft.numint.libxc.eval_xc(xc_func,rho.T, spin=0)[0]/dft.numint.libxc.eval_xc('LDA_X',rho.T, spin=0)[0] -1
 
    rho = jnp.asarray(rho)
    
    tdrho = xcmodel.get_descriptors(rho[:,0]/2,rho[:,0]/2,(rho[:,1]/2)**2,(rho[:,1]/2)**2,(rho[:,1]/2)**2,rho[:,5]/2,rho[:,5]/2, spin_scaling=True, mf=mf, dm=dm)
    


    tFxc = jnp.array(fxc)
    return tdrho[0], tFxc

def get_data(mol, xcmodel, xc_func, localnet=None):
    print('mol: ', mol.atom)
    try:
        mf = scf.UKS(mol)
    except:
        mf = dft.RKS(mol)
    mf.xc = 'PBE'
    mf.grids.level = 1
    mf.kernel()
    ao = mf._numint.eval_ao(mol, mf.grids.coords, deriv=2)
    dm = mf.make_rdm1()
    if len(dm.shape) == 2:
        #artificially spin-polarize
        dm = np.array([0.5*dm, 0.5*dm])
    print('New DM shape: {}'.format(dm.shape))
    print('ao.shape', ao.shape)

    if localnet.spin_scaling:
        print('spin scaling, indicates exchange network')
        rho_alpha = mf._numint.eval_rho(mol, ao, dm[0], xctype='metaGGA',hermi=True)
        rho_beta = mf._numint.eval_rho(mol, ao, dm[1], xctype='metaGGA',hermi=True)
        fxc_a =  mf._numint.eval_xc(xc_func,(rho_alpha,rho_alpha*0), spin=1)[0]/mf._numint.eval_xc('LDA_X',(rho_alpha,rho_alpha*0), spin=1)[0] -1
        fxc_b =  mf._numint.eval_xc(xc_func,(rho_beta*0,rho_beta), spin=1)[0]/mf._numint.eval_xc('LDA_X',(rho_beta*0,rho_beta), spin=1)[0] -1
        print('fxc with xc_func = {} = {}'.format(fxc_a, xc_func))
        print(f'rho_a.shape={rho_alpha.shape}, rho_b.shape={rho_beta.shape}')
        print(f'fxc_a.shape={fxc_a.shape}, fxc_b.shape={fxc_b.shape}')

        if mol.spin != 0 and sum(mol.nelec)>1:
            print('mol.spin != 0 and sum(mol.nelec) > 1')
            rho = jnp.concatenate([rho_alpha, rho_beta], axis=-1)
            fxc = jnp.concatenate([fxc_a, fxc_b])
            print(f'rho.shape={rho.shape}, fxc.shape={fxc.shape}')
        else:
            print('NOT (mol.spin != 0 and sum(mol.nelec) > 1)')
            rho = rho_alpha
            fxc = fxc_a
            print(f'rho.shape={rho.shape}, fxc.shape={fxc.shape}')
    else:    
        print('no spin scaling, indicates correlation network')
        rho_alpha = mf._numint.eval_rho(mol, ao, dm[0], xctype='metaGGA',hermi=True)
        rho_beta = mf._numint.eval_rho(mol, ao, dm[1], xctype='metaGGA',hermi=True)
        exc = mf._numint.eval_xc(xc_func,(rho_alpha,rho_beta), spin=1)[0]
        print('exc with xc_func = {} = {}'.format(exc, xc_func))
        fxc = exc/mf._numint.eval_xc('LDA_C_PW',(rho_alpha, rho_beta), spin=1)[0] -1
        rho = jnp.stack([rho_alpha,rho_beta], axis=-1)
    
    dm = jnp.array(mf.make_rdm1())
    print('get_data, dm shape = {}'.format(dm.shape))
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=1))
    print(f'ao_eval.shape={ao_eval.shape}')
    rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)        
    rho0 = rho[0,0]
    drho = rho[0,1:4] + rho[1:4,0]
    tau = 0.5*(rho[1,1] + rho[2,2] + rho[3,3])

    print('rho shape', rho.shape)
    if dm.ndim == 3:
        rho_filt = (jnp.sum(rho0,axis=0) > 1e-6)
    else:
        rho_filt = (rho0 > 1e-6)
    print('rho_filt shape:', rho_filt.shape)

    
    mf.converged=True
    tdrho = xcmodel.get_descriptors(*get_rhos(rho, spin=1), spin_scaling=localnet.spin_scaling, mf=mf, dm=dm)
    print(f'tdrho.shape={tdrho.shape}')
    if localnet.spin_scaling:
        if mol.spin != 0 and sum(mol.nelec) > 1:
            print('mol.spin != 0 and sum(mol.nelec) > 1')
            #tdrho not returned in a spin-polarized form regardless,
            #but the enhancement factors sampled as polarized, so double
            tdrho = jnp.concatenate([tdrho,tdrho], axis=1)
            rho_filt2 = rho_filt.copy()
            rho_filt = jnp.concatenate([rho_filt]*2)
            print(f'tdrho.shape={tdrho.shape}')
            print(f'rho_filt.shape={rho_filt.shape}')
        elif sum(mol.nelec) == 1:
            pass
    try:
        tdrho = tdrho[rho_filt]
        tFxc = jnp.array(fxc)[rho_filt]
    except:
        tdrho = tdrho[:, rho_filt, :]
        tFxc = jnp.array(fxc)[rho_filt]
    return tdrho, tFxc

def get_data_exc(mol, xcmodel, xc_func, localnet=None, xorc=None):
    print('mol: ', mol.atom)
    try:
        mf = scf.UKS(mol)
    except:
        mf = dft.RKS(mol)
    mf.xc = 'PBE'
    mf.grids.level = 1
    mf.kernel()
    ao = mf._numint.eval_ao(mol, mf.grids.coords, deriv=2)
    dm = mf.make_rdm1()
    if len(dm.shape) == 2:
        #artificially spin-polarize
        dm = np.array([0.5*dm, 0.5*dm])
    print('New DM shape: {}'.format(dm.shape))
    print('ao.shape', ao.shape)

    #depending on the x or c type, choose the generation of the exchange or correlation density
    if xorc == 'x':
        print('Exchange contribution only')
        xc_func = xc_func+','
        if xc_func.lower() == 'pbe0,':
            print('PBE0 detected. changing xc_func to be combination of HF and PBE')
            xc_func = '0.25*HF + 0.75*PBE,'
        print(xc_func)
    elif xorc == 'c':
        print('Correlation contribution only')
        xc_func = ','+xc_func
        if xc_func.lower() == ',pbe0':
            print('PBE0 detected. Changing correlation to be just PBE')
            xc_func = ',pbe'
        print(xc_func)
    if localnet.spin_scaling:
        print('spin scaling')
        rho_alpha = mf._numint.eval_rho(mol, ao, dm[0], xctype='metaGGA',hermi=True)
        rho_beta = mf._numint.eval_rho(mol, ao, dm[1], xctype='metaGGA',hermi=True)
        fxc_a =  mf._numint.eval_xc(xc_func,(rho_alpha,rho_alpha*0), spin=1)[0]
        fxc_b =  mf._numint.eval_xc(xc_func,(rho_beta*0,rho_beta), spin=1)[0]
        print('fxc with xc_func = {} = {}'.format(fxc_a, xc_func))
        print(f'rho_a.shape={rho_alpha.shape}, rho_b.shape={rho_beta.shape}')
        print(f'fxc_a.shape={fxc_a.shape}, fxc_b.shape={fxc_b.shape}')

        if mol.spin != 0 and sum(mol.nelec)>1:
            print('mol.spin != 0 and sum(mol.nelec) > 1')
            rho = jnp.concatenate([rho_alpha, rho_beta], axis=-1)
            fxc = jnp.concatenate([fxc_a, fxc_b])
            print(f'rho.shape={rho.shape}, fxc.shape={fxc.shape}')
        else:
            print('NOT (mol.spin != 0 and sum(mol.nelec) > 1)')
            rho = rho_alpha
            fxc = fxc_a
            print(f'rho.shape={rho.shape}, fxc.shape={fxc.shape}')
    else:    
        print('no spin scaling')
        rho_alpha = mf._numint.eval_rho(mol, ao, dm[0], xctype='metaGGA',hermi=True)
        rho_beta = mf._numint.eval_rho(mol, ao, dm[1], xctype='metaGGA',hermi=True)
        exc = mf._numint.eval_xc(xc_func,(rho_alpha,rho_beta), spin=1)[0]
        print('exc with xc_func = {} = {}'.format(exc, xc_func))
        fxc = exc
        rho = jnp.stack([rho_alpha,rho_beta], axis=-1)
    
    dm = jnp.array(mf.make_rdm1())
    print('get_data, dm shape = {}'.format(dm.shape))
    ao_eval = jnp.array(mf._numint.eval_ao(mol, mf.grids.coords, deriv=1))
    print(f'ao_eval.shape={ao_eval.shape}')
    rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)        
    rho0 = rho[0,0]

    print('rho shape', rho.shape)
    if dm.ndim == 3:
        rho_filt = (jnp.sum(rho0,axis=0) > 1e-6)
    else:
        rho_filt = (rho0 > 1e-6)
    print('rho_filt shape:', rho_filt.shape)

    
    mf.converged=True
    tdrho = xcmodel.get_descriptors(*get_rhos(rho, spin=1), spin_scaling=localnet.spin_scaling, mf=mf, dm=dm)
    print(f'get descriptors tdrho.shape={tdrho.shape}')
    if localnet.spin_scaling:
        if mol.spin != 0 and sum(mol.nelec) > 1:
            print('mol.spin != 0 and sum(mol.nelec) > 1')
            #tdrho not returned in a spin-polarized form regardless,
            #but the enhancement factors sampled as polarized, so double
            if len(tdrho.shape) == 3:
                print('concatenating spin channels along axis=0')
                tdrho = jnp.concatenate([tdrho[0],tdrho[1]], axis=0)
            else:
                print('concatenating along axis=0')
                tdrho = jnp.concatenate([tdrho, tdrho], axis=0)
            rho_filt = jnp.concatenate([rho_filt]*2)
            print(f'tdrho.shape={tdrho.shape}')
            print(f'rho_filt.shape={rho_filt.shape}')
        else:
            #spin == 0 or hydrogen
            tdrho = tdrho[0]
            
    try:
        tdrho = tdrho[rho_filt]
        tFxc = jnp.array(fxc)[rho_filt]
    except:
        tdrho = tdrho[:, rho_filt, :]
        tFxc = jnp.array(fxc)[rho_filt]
    return tdrho, tFxc


level_dict = {'GGA':2, 'MGGA':3, 'NONLOCAL':4}

x_lob_level_dict = {'GGA': 1.804, 'MGGA': 1.174, 'NONLOCAL': 1.174}

class PT_E_Loss(eqx.Module):

    def __call__(self, model, inp, ref):

        pred = jax.vmap(model.net)(inp)[:, 0]

        err = pred-ref

        return jnp.mean(jnp.square(err))


In [39]:
PRETRAIN_LEVEL = 'MGGA'

TRAIN_NET = 'x'
SS = True

REFERENCE_XC = 'PBE0'

N_HIDDEN = 16
DEPTH = 3
if PRETRAIN_LEVEL == 'GGA':
    localx = xce.net.eX(n_input=1, n_hidden=N_HIDDEN, use=[1], depth=DEPTH, spin_scaling=SS, lob=1.804)
    localc = xce.net.eC(n_input=3, n_hidden=N_HIDDEN, use=[2], depth=DEPTH, spin_scaling=SS, ueg_limit=True)
elif PRETRAIN_LEVEL == 'MGGA':
    localx = xce.net.eX(n_input=2, n_hidden=N_HIDDEN, use=[1, 2], depth=DEPTH, ueg_limit=True, spin_scaling=SS, lob=1.174)
    localc = xce.net.eC(n_input=4, n_hidden=N_HIDDEN, depth=DEPTH, use=[2,3], spin_scaling=SS, ueg_limit=True)
elif PRETRAIN_LEVEL == 'NONLOCAL':
    localx = xce.net.eX(n_input=18, n_hidden=N_HIDDEN, depth=DEPTH, ueg_limit=True, spin_scaling=SS, lob=1.174)
    #n_input = 4 from base, 12 from NL
    localc = xce.net.eC(n_input=16, n_hidden=N_HIDDEN, depth=DEPTH, spin_scaling=SS, ueg_limit=True)

if TRAIN_NET == 'x':
    thislocal = localx
else:
    thislocal = localc
ueg = xce.xc.LDA_X()
xc = xce.xc.eXC(grid_models=[thislocal], heg_mult=True, level= {'GGA':2, 'MGGA':3, 'NONLOCAL':4}[PRETRAIN_LEVEL])

In [40]:
spins = {
    'Al': 1,
    'B' : 1,
    'Li': 1,
    'Na': 1,
    'Si': 2 ,
    'Be':0,
    'C': 2,
    'Cl': 1,
    'F': 1,
    'H': 1,
    'N': 3,
    'O': 2,
    'P': 3,
    'S': 2
}

selection = [2, 113, 25, 18, 11, 17, 114, 121, 101, 0, 20, 26, 29, 67, 28, 110, 125, 10, 115, 89, 105, 50]
try:
    atoms = [read('/home/awills/Documents/Research/ogdpyscf/dpyscf/data/haunschild_g2/g2_97.traj',':')[s] for s in selection]
except:
    atoms = [read('/home/awills/Documents/Research/ogdpyscf/data/haunschild_g2/g2_97.traj',':')[s] for s in selection]
ksr_atoms = atoms
if PRETRAIN_LEVEL=='MGGA':
    ksr_atoms = ksr_atoms[2:]
ksr_atoms = [Atoms('P',info={'spin':3}), Atoms('N', info={'spin':3}), Atoms('H', info={'spin':1}),Atoms('Li', info={'spin':1}), Atoms('O',info={'spin':2}),Atoms('Cl',info={'spin':1}),Atoms('Al',info={'spin':1}), Atoms('S',info={'spin':2})] + ksr_atoms
# ksr_atoms = [Atoms('H',info={'spin':1})]
mols = [get_mol(atoms) for atoms in ksr_atoms]
mols = [i for i in mols if len(i.atom) < 8]
for idx, i in enumerate(mols):
    print(i, ksr_atoms[idx].get_chemical_formula(), i.atom, len(i.atom))
mols = mols[:1]

<pyscf.gto.mole.Mole object at 0x7faf7d50f940> P [['P', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8af1f0> N [['N', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8ad0c0> H [['H', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8ae860> Li [['Li', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8af190> O [['O', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8afc10> Cl [['Cl', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8ac430> Al [['Al', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8af310> S [['S', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7faf7d8ad8a0> FLi [['Li', array([ 0.      ,  0.      , -1.172697])], ['F', array([0.      , 0.      , 0.390899])]] 2
<pyscf.gto.mole.Mole object at 0x7faf7d8ac8e0> CHN [['C', array([ 0.      ,  0.      , -0.499686])], ['N', array([0.      , 0.      , 0.652056])], ['H', array([ 0.        ,  0.        , -1.56627401])]] 3


In [41]:
# data = [get_data(mol, xc_func=ref, full=i<14) for i,mol in enumerate(mols)]
ref = 'PBE0'
datax = [get_data_exc(mol, xcmodel=xc, xc_func=ref, localnet=localx, xorc='x') for i,mol in enumerate(mols)]
datac = [get_data_exc(mol, xcmodel=xc, xc_func=ref, localnet=localc, xorc='c') for i,mol in enumerate(mols)]
# 

mol:  [['P', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7faf7d50f940> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -341.104145992717  <S^2> = 3.7502984  2S+1 = 4.0001492
New DM shape: (2, 30, 30)
ao.shape (10, 6320, 30)
Exchange contribution only
PBE0 detected. changing xc_func to be combination of HF and PBE
0.25*HF + 0.75*PBE,
spin scaling
fxc with xc_func = [-1.02859844e-02 -2.86073945e-03 -4.09993757e-03 ... -7.11627837e+00
 -7.11627837e+00 -7.11627837e+00] = 0.25*HF + 0.75*PBE,
rho_a.shape=(6, 6320), rho_b.shape=(6, 6320)
fxc_a.shape=(6320,), fxc_b.shape=(6320,)
mol.spin != 0 and sum(mol.nelec) > 1
rho.shape=(6, 12640), fxc.shape=(12640,)
get_data, dm shape = (2, 30, 30)
ao_eval.shape=(4, 6320, 30)
rho shape (4, 4, 2, 6320)
rho_filt shape: (6320,)
get descriptors tdrho.shape=(2, 6320, 3)
mol.spin != 0 and sum(mol.nelec) > 1
concatenating spin channels along axis=0
tdrho.shape=(12640, 3)
rho_filt.shape=(12640,)
mol:  [['P', array([0., 0., 0.])]]
converged SCF energy = -341.104145992718  <S^2> = 3.7502984  2S+1 = 4.0001492
New DM shape: (2, 30, 30)
ao.shape (10, 6320, 30)


In [25]:
[datax[i][1] - datac[i][1] for i in range(len(datax))]

[Array([1.66533454e-15, 1.41553436e-15, 1.81105131e-15, ...,
        6.21724894e-15, 6.21724894e-15, 6.21724894e-15], dtype=float64)]

In [23]:
datax[1]

IndexError: list index out of range

In [8]:
# data = [get_data_synth(ref, 100)]
tdrho = jnp.concatenate([d[0] for d in data], axis=1)
tFxc = jnp.concatenate([d[1] for d in data])

In [11]:
nan_filt.shape

(2, 258825)

In [13]:
nan_filt = ~jnp.any((tdrho != tdrho),axis=-1)

tFxc = tFxc[nan_filt[0, :]]
tdrho = tdrho[nan_filt]

In [14]:
tdrho[::].shape, tdrho[::], tFxc.shape, tFxc

((517650, 3),
 Array([[-3.32842976,  2.23725351,  1.44783222],
        [-3.81180913,  2.88156858,  1.84297971],
        [-2.97912881,  1.80658656,  1.05260213],
        ...,
        [-4.21483038,  3.0933172 ,  0.43034293],
        [-4.25518687,  3.15235971,  0.11670809],
        [-4.31527705,  3.09062045,  1.31631199]], dtype=float64),
 (258825,),
 Array([0.32315335, 0.3454216 , 0.27872126, ..., 0.3483771 , 0.3489568 ,
        0.34849783], dtype=float64))

In [15]:
cpus = jax.devices(backend='cpu')

In [None]:
first_pred = jax.vmap(thislocal.net)(tdrho[...,thislocal.use])

In [None]:
PRINT_EVERY=200
scheduler = optax.exponential_decay(init_value = 1e-2, transition_begin=50, transition_steps=500, decay_rate=0.9)
# optimizer = optax.adam(learning_rate = 1e-2)
optimizer = optax.adam(learning_rate = scheduler)

class PT_E_Loss():

    def __call__(self, model, inp, ref):

        pred = jax.vmap(model.net)(inp)[:, 0]

        err = pred-ref

        return jnp.mean(jnp.square(err))

trainer = xce.train.xcTrainer(model=thislocal, optim=optimizer, steps=500, loss = PT_E_Loss(), do_jit=True)
if TRAIN_NET == 'x':
    inp = [tdrho[..., trainer.model.use]]
else:
    inp = [tdrho]
with jax.default_device(cpus[0]):
    newm = trainer(1, trainer.model, inp, [tFxc])
        

# for epoch in range(100000):
#     total_loss = 0
#     results = thislocal(tdrho[::])
#     loss = eloss(results, tFxc[::])
#     total_loss += loss.item()
#     loss.backward()

#     optimizer.step()
#     optimizer.zero_grad()
#     if epoch%PRINT_EVERY==0:
#         print('total loss {:.12f}'.format(total_loss))
        


In [None]:
best_pred = jax.vmap(newm.net)(tdrho[...,newm.use])[:, 0]

In [None]:
f = plt.figure()
ax = f.add_subplot(111)
sel = 1
xs = np.arange(0, len(best_pred[::sel]))
ax.grid()
ax.scatter(xs, abs(best_pred[::sel]-tFxc[::sel]))
# ax.scatter(best_pred[::sel], tFxc[::sel])
rmse = np.sqrt( np.mean( (best_pred[::sel] - tFxc[::sel])**2))
print(rmse)

# ax.set_xlim(0, 1)
# ax.set_ylim(0, 10)
ax.set_yscale('log')
ax.set_ylabel('|Pred. - Ref.|')
ax.set_xlabel('Gridpoint Index')

In [None]:
f = plt.figure()
ax = f.add_subplot(111)
sel = 1
xs = np.arange(0, len(first_pred[::sel]))
ax.grid()
ax.scatter(xs, abs(first_pred.T[0][::sel]-tFxc[::sel]))
rmse = np.sqrt( np.mean( (first_pred.T[0][::sel] - tFxc[::sel])**2))
print(rmse)
# ax.set_xlim(0, 1)
# ax.set_ylim(-1, 1000)
ax.set_yscale('log')
ax.set_ylabel('|Pred. - Ref.|')
ax.set_xlabel('Gridpoint Index')


Below for nonlocal, not specifically shaped MGGA

In [None]:
first_pred = jax.vmap(thislocal.net)(tdrho)

In [None]:
first_pred.shape

In [None]:
plt.scatter(first_pred[::100], tFxc[::100])

In [None]:
f = plt.figure()
ax = f.add_subplot(111)

ax.scatter(first_pred[::100], tFxc[::100])
ax.set_xlim(0, 10)

In [None]:
PRINT_EVERY=200
scheduler = optax.exponential_decay(init_value = 1e-2, transition_begin=50, transition_steps=500, decay_rate=0.9)
# optimizer = optax.adam(learning_rate = 1e-2)
optimizer = optax.adam(learning_rate = scheduler)

class PT_E_Loss():

    def __call__(self, model, inp, ref):

        pred = jax.vmap(model.net)(inp)[:, 0]

        err = pred-ref

        return jnp.mean(jnp.square(err))

trainer = xce.train.xcTrainer(model=thislocal, optim=optimizer, steps=500, loss = PT_E_Loss(), do_jit=True)
if TRAIN_NET == 'x':
    inp = [tdrho[:, trainer.model.use]]
else:
    inp = [tdrho]
with jax.default_device(cpus[0]):
    newm = trainer(1, trainer.model, inp, [tFxc])
        

# for epoch in range(100000):
#     total_loss = 0
#     results = thislocal(tdrho[::])
#     loss = eloss(results, tFxc[::])
#     total_loss += loss.item()
#     loss.backward()

#     optimizer.step()
#     optimizer.zero_grad()
#     if epoch%PRINT_EVERY==0:
#         print('total loss {:.12f}'.format(total_loss))
        


In [None]:
new_pred = jax.vmap(newm.net)(tdrho).T[0]

In [None]:
chkpts = sorted([i for i in os.listdir() if 'xc.eqx' in i], key=lambda x: int(x.split('.')[-1]))

In [None]:
chkpts[-1]

In [None]:
bestnet = eqx.tree_deserialise_leaves(chkpts[-2], newm)

In [None]:
best_pred = jax.vmap(bestnet.net)(tdrho)[:, 0]

In [None]:
f = plt.figure()
ax = f.add_subplot(111)
sel = 1
xs = np.arange(0, len(best_pred[::sel]))
ax.grid()
ax.scatter(xs, abs(best_pred[::sel]-tFxc[::sel]))
# ax.scatter(best_pred[::sel], tFxc[::sel])
rmse = np.sqrt( np.mean( (best_pred[::sel] - tFxc[::sel])**2))
print(rmse)

# ax.set_xlim(0, 1)
# ax.set_ylim(0, 10)
ax.set_yscale('log')
ax.set_ylabel('|Pred. - Ref.|')
ax.set_xlabel('Gridpoint Index')


In [None]:
f = plt.figure()
ax = f.add_subplot(111)
sel = 1
xs = np.arange(0, len(new_pred[::sel]))

ax.scatter(xs, new_pred[::sel]-tFxc[::sel])
rmse = np.sqrt( np.mean( (new_pred[::sel] - tFxc[::sel])**2))
print(rmse)

# ax.set_xlim(0, 1)
ax.set_ylim(-1, 1)

In [None]:
f = plt.figure()
ax = f.add_subplot(111)
sel = 1
xs = np.arange(0, len(first_pred[::sel]))
ax.grid()
ax.scatter(xs, abs(first_pred.T[0][::sel]-tFxc[::sel]))
rmse = np.sqrt( np.mean( (first_pred.T[0][::sel] - tFxc[::sel])**2))
print(rmse)
# ax.set_xlim(0, 1)
# ax.set_ylim(-1, 1000)
ax.set_yscale('log')
ax.set_ylabel('|Pred. - Ref.|')
ax.set_xlabel('Gridpoint Index')


In [None]:
first_pred.shape

In [None]:
xs.shape, (new_pred[::sel]-tFxc[::sel]).shape

In [None]:

new_pred.shape, tFxc.shape