In [1]:
# import pyscfad
# from pyscfad import gto,dft,scf
import pyscf
from pyscf import gto,dft,scf



In [2]:
import numpy as np
import jax.numpy as jnp
import scipy
from ase import Atoms
from ase.io import read
import xcquinox as xce
from functools import partial
from ase.units import Bohr
import os, optax, jax


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
def get_mol(atoms, basis='6-311++G**'):
    """Create a PySCF ``gto.Mole`` from an ASE-style ``Atoms`` object.

    Parameters
    ----------
    atoms : object exposing ``positions``, ``get_chemical_symbols()`` and ``info``
        Source geometry. ``atoms.info['basis']`` and ``atoms.info['spin']``,
        when present, take precedence over the defaults used here.
    basis : str, optional
        Basis set used when ``atoms.info`` has no ``'basis'`` entry.

    Returns
    -------
    Whatever ``gto.Mole(...)`` returns; ``build()`` is not called here.
    """
    geometry = [[symbol, xyz]
                for symbol, xyz in zip(atoms.get_chemical_symbols(), atoms.positions)]

    def _make(default_basis):
        # One construction attempt; the per-atoms overrides are looked up each time.
        return gto.Mole(atom=geometry,
                        basis=atoms.info.get('basis', default_basis),
                        spin=atoms.info.get('spin', 0))

    try:
        return _make(basis)
    except Exception:
        # Second attempt with a minimal default basis if the first one raised.
        return _make('STO-3G')

def get_rhos(rho, spin):
    """Split a derivative-indexed density tensor into per-spin ingredients.

    Parameters
    ----------
    rho : array indexed as ``rho[x, y, ...]`` with ``x, y`` in ``0..3``
        ``rho[0, 0]`` is read as the density, ``rho[0, 1:4] + rho[1:4, 0]`` as
        its three gradient components and ``0.5 * (rho[1,1] + rho[2,2] +
        rho[3,3])`` as the kinetic-energy density. For ``spin != 0`` the axis
        following ``x, y`` must have length 2 (first entry -> "a", second -> "b").
    spin : int
        ``0`` selects the spin-unpolarised branch, anything else the polarised one.

    Returns
    -------
    tuple
        ``(rho0_a, rho0_b, gamma_a, gamma_b, gamma_ab, tau_a, tau_b)``. In the
        unpolarised branch the density and tau are halved and all three gammas
        equal a quarter of the squared gradient.
    """
    def _grid_dot(u, v):
        # Contract the leading (component) axis, keep the trailing (grid) axis.
        return jnp.einsum('ij,ij->j', u, v)

    density = rho[0, 0]
    gradient = rho[0, 1:4] + rho[1:4, 0]
    kinetic = 0.5 * (rho[1, 1] + rho[2, 2] + rho[3, 3])

    if spin == 0:
        half_density = density * 0.5
        gamma = _grid_dot(gradient[:], gradient[:]) * 0.25
        half_kinetic = kinetic * 0.5
        return (half_density, half_density,
                gamma, gamma, gamma,
                half_kinetic, half_kinetic)

    grad_a = gradient[:, 0]
    grad_b = gradient[:, 1]
    tau_a, tau_b = kinetic
    return (density[0], density[1],
            _grid_dot(grad_a, grad_a), _grid_dot(grad_b, grad_b),
            _grid_dot(grad_a, grad_b),
            tau_a, tau_b)
    
def get_data_synth(xc_func, n=100):
    """Build a synthetic (descriptor, enhancement-factor) training set.

    A table of unit-density rows is generated on a grid of ``s`` values (and,
    when ``'MGGA'`` appears in ``xc_func``, a grid of ``a`` values per ``s``).
    The reference is ``eval_xc(xc_func) / eval_xc('LDA_X') - 1`` on that table.

    Parameters
    ----------
    xc_func : str
        libxc functional identifier passed to ``dft.numint.libxc.eval_xc``.
    n : int, optional
        Number of log-spaced grid points per variable (``s`` gets one extra
        point at exactly 0).

    Returns
    -------
    tuple
        ``(descriptors, fxc)`` where ``descriptors`` is the first element
        returned by the module-level ``xc.get_descriptors(..., spin_scaling=True)``
        and ``fxc`` is a JAX array.

    Notes
    -----
    Relies on the notebook globals ``xc`` and ``dft``.
    """
    def get_rho(s, a):
        # One row per entry of `a`: column 0 = 1 (density), column 1 = c0*s,
        # last column = c1*a + (c0*s)**2/8; all other columns stay zero.
        # Presumably s is the reduced gradient and a the iso-orbital indicator
        # (alpha) at unit density -- TODO confirm.
        c0 = 2*(3*np.pi**2)**(1/3)
        c1 = 3/10*(3*np.pi**2)**(2/3)
        gamma = c0*s
        tau = c1*a+c0**2*s**2/8
        rho = np.zeros([len(a),6])
        rho[:, 1] = gamma
        rho[:,-1] = tau
        rho[:, 0] = 1
        return rho

    # jnp.concatenate only accepts array-likes, not nested Python lists, so the
    # leading s = 0 point is supplied as an array.
    s_grid = jnp.concatenate([jnp.zeros(1), jnp.exp(jnp.linspace(-10, 4, n))])
    uses_tau = 'MGGA' in xc_func
    rows = []
    for s in s_grid:
        if uses_tau:
            a_grid = jnp.exp(jnp.linspace(jnp.log((s/100)+1e-8), 8, n))
        else:
            a_grid = jnp.array([0])
        rows.append(get_rho(s, a_grid))

    # Keep the table as a NumPy array for the libxc calls.
    rho_np = np.concatenate(rows)
    rho_t = rho_np.T
    fxc = (dft.numint.libxc.eval_xc(xc_func, rho_t, spin=0)[0]
           / dft.numint.libxc.eval_xc('LDA_X', rho_t, spin=0)[0] - 1)

    rho = jnp.asarray(rho_np)
    half_rho = rho[:, 0]/2
    quarter_gamma = (rho[:, 1]/2)**2
    half_tau = rho[:, 5]/2
    tdrho = xc.get_descriptors(half_rho, half_rho,
                               quarter_gamma, quarter_gamma, quarter_gamma,
                               half_tau, half_tau, spin_scaling=True)

    # The rest of the notebook is JAX-based; torch is never imported here.
    tFxc = jnp.asarray(fxc)
    return tdrho[0], tFxc

def get_data(mol, xc_func ,full=False, enhance_spin=False, localnet=None):
    """Generate (descriptor, reference) training pairs for one molecule.

    A PBE SCF (grid level 1) is converged for ``mol``. Its spin densities are
    fed to ``xc_func`` to build the reference, and the module-level ``xc``
    object turns the same density into network descriptors.

    Parameters
    ----------
    mol : PySCF molecule
        System to converge.
    xc_func : str
        Functional identifier used for the reference values.
    full : bool, optional
        When False, the SCF grid is replaced by the module-level ``coords`` /
        ``weights`` before densities are evaluated.
    enhance_spin : bool, optional
        Accepted for interface compatibility; not read by this function.
    localnet : object with a ``spin_scaling`` attribute
        Network being pre-trained. ``spin_scaling=True`` selects the
        per-spin-channel reference relative to ``'LDA_X'``; otherwise the
        reference is relative to ``'LDA_C_PW'``.

    Returns
    -------
    tuple
        ``(tdrho, tFxc)``, both restricted to grid points whose density
        exceeds 1e-6.

    Raises
    ------
    ValueError
        If ``localnet`` is None.
    """
    if localnet is None:
        # Fail before the expensive SCF instead of with an AttributeError after it.
        raise ValueError('get_data requires `localnet` (an object exposing `spin_scaling`)')
    spin_scaling = localnet.spin_scaling

    print('mol: ', mol.atom)
    try:
        mf = scf.UKS(mol)
    except Exception:
        # NOTE(review): the code below indexes make_rdm1() as [alpha, beta];
        # confirm that this restricted fallback is still wanted.
        mf = dft.RKS(mol)
    mf.xc = 'PBE'
    mf.grids.level = 1
    mf.kernel()
    if not full:
        # NOTE(review): `coords` and `weights` are read from module globals that
        # are not defined in the visible cells -- confirm before using full=False.
        mf.grids.coords = coords
        mf.grids.weights = weights

    ni = mf._numint
    grid_coords = mf.grids.coords
    dm_spin = mf.make_rdm1()
    # Both spin channels share one AO evaluation and one density matrix.
    ao_deriv2 = ni.eval_ao(mol, grid_coords, deriv=2)
    rho_alpha = ni.eval_rho(mol, ao_deriv2, dm_spin[0], xctype='metaGGA', hermi=True)
    rho_beta = ni.eval_rho(mol, ao_deriv2, dm_spin[1], xctype='metaGGA', hermi=True)

    stack_channels = False
    if spin_scaling:
        print('spin scaling, indicates exchange network')
        # Each channel is evaluated with the other channel zeroed.
        fxc_a = ni.eval_xc(xc_func, (rho_alpha, rho_alpha*0), spin=1)[0]/ni.eval_xc('LDA_X', (rho_alpha, rho_alpha*0), spin=1)[0] - 1
        fxc_b = ni.eval_xc(xc_func, (rho_beta*0, rho_beta), spin=1)[0]/ni.eval_xc('LDA_X', (rho_beta*0, rho_beta), spin=1)[0] - 1
        print('fxc with xc_func = {} = {}'.format(xc_func, fxc_a))

        stack_channels = mol.spin != 0 and sum(mol.nelec) > 1
        if stack_channels:
            fxc = jnp.concatenate([fxc_a, fxc_b])
        else:
            fxc = fxc_a
    else:
        print('no spin scaling, indicates correlation network')
        exc = ni.eval_xc(xc_func, (rho_alpha, rho_beta), spin=1)[0]
        print('exc with xc_func = {} = {}'.format(xc_func, exc))
        fxc = exc/ni.eval_xc('LDA_C_PW', (rho_alpha, rho_beta), spin=1)[0] - 1

    # Descriptor density: rho[x, y, (spin,) grid], built from AO values and first derivatives.
    dm = jnp.array(dm_spin)
    ao_eval = jnp.array(ni.eval_ao(mol, grid_coords, deriv=1))
    rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)

    if dm.ndim == 3:
        rho_filt = (jnp.sum(rho[0,0],axis=0) > 1e-6)
    else:
        rho_filt = (rho[0,0] > 1e-6)
    tdrho = xc.get_descriptors(*get_rhos(rho, spin=1), spin_scaling=spin_scaling)

    if spin_scaling:
        if stack_channels:
            # Stack both channels along the grid axis and repeat the mask to match.
            tdrho = jnp.concatenate([tdrho[0],tdrho[1]])
            rho_filt = jnp.concatenate([rho_filt]*2)
        else:
            tdrho = tdrho[0]
    tdrho = tdrho[rho_filt]

    tFxc = jnp.array(fxc)[rho_filt]
    return tdrho, tFxc

In [4]:
# ---- Pre-training configuration --------------------------------------------
PRETRAIN_LEVEL = 'MGGA'   # key into the level table at the bottom of this cell
TRAIN_NET = 'c'           # 'x' builds an eX network, 'c' an eC network
REFERENCE_XC = 'PBE0'     # not read anywhere else in this cell
N_HIDDEN = 16
DEPTH = 3

# ---- Network construction ---------------------------------------------------
# Only the network selected by TRAIN_NET is created; `thislocal` aliases it.
if TRAIN_NET == 'x':
    if PRETRAIN_LEVEL == 'GGA':
        localx = xce.net.eX(n_input=1, n_hidden=N_HIDDEN, use=[1], depth=DEPTH, lob=1.804)
    elif PRETRAIN_LEVEL == 'MGGA':
        localx = xce.net.eX(n_input=2, n_hidden=N_HIDDEN, use=[1, 2], depth=DEPTH, ueg_limit=True, lob=1.174)
    thislocal = localx
else:
    if TRAIN_NET == 'c':
        if PRETRAIN_LEVEL == 'GGA':
            localc = xce.net.eC(n_input=3, n_hidden=N_HIDDEN, use=[2], depth=DEPTH, ueg_limit=True)
        elif PRETRAIN_LEVEL == 'MGGA':
            localc = xce.net.eC(n_input=4, n_hidden=N_HIDDEN, depth=DEPTH, use=[2,3], ueg_limit=True)
    thislocal = localc

ueg = xce.xc.LDA_X()
# `xc` is the module-level functional object used by get_data / get_data_synth.
xc = xce.xc.eXC(
    grid_models=[thislocal],
    heg_mult=True,
    level={'GGA':2, 'MGGA':3, 'NONLOCAL':4}[PRETRAIN_LEVEL],
)

In [5]:
# Value passed as info['spin'] for each isolated atom; get_mol forwards it to
# gto.Mole(spin=...).
spins = {
    'Al': 1,
    'B' : 1,
    'Li': 1,
    'Na': 1,
    'Si': 2 ,
    'Be':0,
    'C': 2,
    'Cl': 1,
    'F': 1,
    'H': 1,
    'N': 3,
    'O': 2,
    'P': 3,
    'S': 2
}

# Frames of the G2/97 trajectory used for pre-training.
selection = [2, 113, 25, 18, 11, 17, 114, 121, 101, 0, 20, 26, 29, 67, 28, 110, 125, 10, 115, 89, 105, 50]

# The location can be overridden with the G2_97_TRAJ environment variable;
# the default is the original hard-coded path.
G2_TRAJ_PATH = os.environ.get(
    'G2_97_TRAJ',
    '/home/awills/Documents/Research/ogdpyscf/dpyscf/data/haunschild_g2/g2_97.traj',
)
# Parse the trajectory once and index into it, instead of re-reading the whole
# file for every selected frame.
g2_frames = read(G2_TRAJ_PATH, ':')
atoms = [g2_frames[s] for s in selection]

ksr_atoms = atoms
if PRETRAIN_LEVEL=='MGGA':
    ksr_atoms = ksr_atoms[2:]

# Isolated atoms are prepended; their spins come from the table above rather
# than from repeated literals.
single_atom_symbols = ('P', 'N', 'H', 'Li', 'O', 'Cl', 'Al', 'S')
single_atoms = [Atoms(symbol, info={'spin': spins[symbol]}) for symbol in single_atom_symbols]
ksr_atoms = single_atoms + ksr_atoms
# ksr_atoms = [Atoms('H',info={'spin':1})]

mols = [get_mol(frame) for frame in ksr_atoms]
# Keep only systems with fewer than 8 atoms.
mols = [mol for mol in mols if len(mol.atom) < 8]
for mol in mols:
    print(mol, mol.atom, len(mol.atom))

<pyscf.gto.mole.Mole object at 0x7451a04314b0> [['P', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a0432b60> [['N', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a04324d0> [['H', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a04322f0> [['Li', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a0432830> [['O', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a0433430> [['Cl', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a0431e40> [['Al', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a0433580> [['S', array([0., 0., 0.])]] 1
<pyscf.gto.mole.Mole object at 0x7451a0433490> [['Li', array([ 0.      ,  0.      , -1.172697])], ['F', array([0.      , 0.      , 0.390899])]] 2
<pyscf.gto.mole.Mole object at 0x7451a0432cb0> [['C', array([ 0.      ,  0.      , -0.499686])], ['N', array([0.      , 0.      , 0.652056])], ['H', array([ 0.        ,  0.        , -1.56627401])]] 3
<pyscf.gto.mole.Mole object

In [6]:
# Show whether the selected network uses spin scaling; get_data branches on
# this flag when it builds the reference values.
thislocal.spin_scaling

False

In [7]:
# Take the reference functional from the configuration cell so the two cannot
# drift apart.
ref = REFERENCE_XC
# Alternative (restricted grid for later molecules):
# data = [get_data(mol, xc_func=ref, full=i<14) for i,mol in enumerate(mols)]
data = [get_data(mol, xc_func=ref, full=True, localnet=thislocal) for mol in mols]
# Alternative (synthetic grid instead of molecular densities):
# data = [get_data_synth(ref, 100)]

# Stack the per-molecule samples along the sample axis.
tdrho = jnp.concatenate([d[0] for d in data])
tFxc = jnp.concatenate([d[1] for d in data])

mol:  [['P', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04314b0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -341.104145992717  <S^2> = 3.7502984  2S+1 = 4.0001492
no spin scaling, indicates correlation network
exc with xc_func = [-1.02664237e-02 -2.75034982e-03 -4.08130021e-03 ... -7.25655317e+00
 -7.25655317e+00 -7.25655317e+00] = PBE0
mol:  [['N', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432b60> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -54.5289742046674  <S^2> = 3.7524945  2S+1 = 4.0012471
no spin scaling, indicates correlation network
exc with xc_func = [-5.96866425e-03 -2.51798665e-03 -5.83130694e-04 ... -3.33703592e+00
 -3.33703592e+00 -3.33703592e+00] = PBE0
mol:  [['H', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04324d0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -0.49981298400854  <S^2> = 0.75  2S+1 = 2
no spin scaling, indicates correlation network
exc with xc_func = [-5.68454275e-03 -2.30635936e-03 -4.18219349e-04 ... -4.97345779e-01
 -4.97345779e-01 -4.97345779e-01] = PBE0
mol:  [['Li', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04322f0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -7.46006188627842  <S^2> = 0.75000049  2S+1 = 2.0000005
no spin scaling, indicates correlation network
exc with xc_func = [-0.01101625 -0.00387793 -0.00387793 ... -0.02009147 -1.39639206
 -1.39639206] = PBE0
mol:  [['O', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432830> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -75.0033795084925  <S^2> = 2.0027447  2S+1 = 3.0018292
no spin scaling, indicates correlation network
exc with xc_func = [-3.68102742e-03 -1.13732790e-03 -1.57783875e-04 ... -3.81896192e+00
 -3.81896192e+00 -3.81896192e+00] = PBE0
mol:  [['Cl', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433430> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -459.95757712076  <S^2> = 0.7516194  2S+1 = 2.0016187
no spin scaling, indicates correlation network
exc with xc_func = [-2.11982273e-03 -2.48633744e-03 -1.44197115e-03 ... -8.24204296e+00
 -8.24204296e+00 -8.24204296e+00] = PBE0
mol:  [['Al', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0431e40> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -242.226561015217  <S^2> = 0.75226414  2S+1 = 2.0022629
no spin scaling, indicates correlation network
exc with xc_func = [-0.02501281 -0.01759485 -0.0104445  ... -0.00091049 -0.00407204
 -0.00067045] = PBE0
mol:  [['S', array([0., 0., 0.])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433580> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -397.938786804559  <S^2> = 2.0022329  2S+1 = 3.0014882
no spin scaling, indicates correlation network
exc with xc_func = [-6.01728008e-03 -2.66238833e-03 -2.59666135e-03 ... -7.75011279e+00
 -7.75011279e+00 -7.75011279e+00] = PBE0
mol:  [['Li', array([ 0.      ,  0.      , -1.172697])], ['F', array([0.      , 0.      , 0.390899])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433490> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -107.339357395734  <S^2> = 5.3290705e-15  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00214665 -0.00173898 -0.00179938 ... -0.00203991 -0.61562248
 -0.61562248] = PBE0
mol:  [['C', array([ 0.      ,  0.      , -0.499686])], ['N', array([0.      , 0.      , 0.652056])], ['H', array([ 0.        ,  0.        , -1.56627401])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432cb0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -93.337792446513  <S^2> = 4.0072479e-10  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-1.38296163e-03 -2.66832598e-04 -3.06425201e-04 ... -5.03343606e-01
 -5.03343606e-01 -5.03343606e-01] = PBE0
mol:  [['C', array([0., 0., 0.])], ['O', array([0.      , 0.      , 1.162879])], ['O', array([ 0.      ,  0.      , -1.162879])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0430100> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -188.456965322844  <S^2> = 7.1054274e-15  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-1.29084098e-03 -1.15920157e-03 -3.32405317e-03 ... -6.61945569e-05
 -2.84591976e+00 -2.84591976e+00] = PBE0
mol:  [['Cl', array([0.      , 0.      , 1.008241])], ['Cl', array([ 0.      ,  0.      , -1.008241])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04339d0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -920.00560888896  <S^2> = 4.938272e-13  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00376872 -0.00297928 -0.00401242 ... -0.00297928 -0.3467079
 -0.3467079 ] = PBE0
mol:  [['F', array([0.      , 0.      , 0.693963])], ['F', array([ 0.      ,  0.      , -0.693963])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433820> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -199.394370591172  <S^2> = 1.1901591e-13  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-1.28783793e-03 -2.82669550e-04 -4.34324741e-03 ... -4.25493449e-01
 -4.25493449e-01 -4.25493449e-01] = PBE0
mol:  [['O', array([0.      , 0.      , 0.603195])], ['O', array([ 0.      ,  0.      , -0.603195])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432d70> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -150.21489454177  <S^2> = 1.0018599  2S+1 = 2.2377309
no spin scaling, indicates correlation network
exc with xc_func = [-2.88673908e-03 -9.06450598e-04 -1.28008902e-04 ... -5.16165985e-01
 -5.16165985e-01 -5.16165985e-01] = PBE0
mol:  [['C', array([0.      , 0.      , 0.599454])], ['C', array([ 0.      ,  0.      , -0.599454])], ['H', array([ 0.        ,  0.        , -1.66162301])], ['H', array([0.        , 0.        , 1.66162301])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432710> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -77.2435048346375  <S^2> = 1.5099033e-14  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00060514 -0.00060514 -0.0032515  ... -0.00031967 -0.00031865
 -0.00031865] = PBE0
mol:  [['O', array([0.      , 0.      , 0.484676])], ['C', array([ 0.      ,  0.      , -0.646235])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0431b40> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -113.221335689652  <S^2> = 6.6346928e-13  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00219858 -0.00302532 -0.00639712 ... -0.56161386 -0.56161386
 -0.56161386] = PBE0
mol:  [['Cl', array([0.      , 0.      , 0.071315])], ['H', array([ 0.      ,  0.      , -1.212358])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0431de0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -460.624592374078  <S^2> = 6.5725203e-14  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-1.16123674e-03 -2.16564032e-04 -2.58966276e-04 ... -2.41956464e+00
 -2.41956464e+00 -2.41956464e+00] = PBE0
mol:  [['Li', array([0.      , 0.      , 0.403632])], ['H', array([ 0.      ,  0.      , -1.210897])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433880> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -8.04458854018922  <S^2> = 7.9269924e-14  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00280244 -0.00292882 -0.00292882 ... -0.42329395 -0.42329395
 -0.42329395] = PBE0
mol:  [['Na', array([0.        , 0.        , 1.50747901])], ['Na', array([ 0.        ,  0.        , -1.50747901])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04317b0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -324.340512506578  <S^2> = 1.5857538e-11  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00664309 -0.00664309 -0.00843956 ... -0.00664309 -0.00664309
 -0.00843956] = PBE0
mol:  [['Al', array([0., 0., 0.])], ['Cl', array([0.        , 2.08019101, 0.        ])], ['Cl', array([ 1.80149801, -1.040095  ,  0.        ])], ['Cl', array([-1.80149801, -1.040095  ,  0.        ])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433fd0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -1622.57507845814  <S^2> = 8.2422957e-13  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-9.24417223e-04 -1.96285983e-03 -4.54661166e-03 ... -6.26905723e+00
 -6.26905723e+00 -6.26905723e+00] = PBE0
mol:  [['P', array([0.      , 0.      , 0.128906])], ['H', array([ 0.      ,  1.19333 , -0.644531])], ['H', array([ 1.033455, -0.596665, -0.644531])], ['H', array([-1.033455, -0.596665, -0.644531])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0433c70> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -342.979728469575  <S^2> = 2.5393021e-11  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-3.12430446e-03 -1.72805187e-03 -3.47626992e-04 ... -1.52422293e+00
 -1.52422293e+00 -1.52422293e+00] = PBE0
mol:  [['Si', array([0.      , 0.      , 1.135214])], ['Si', array([ 0.      ,  0.      , -1.135214])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432590> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -578.56533666059  <S^2> = 1.0034708  2S+1 = 2.2391702
no spin scaling, indicates correlation network
exc with xc_func = [-0.00093916 -0.00130488 -0.00130487 ... -0.30880861 -0.30880861
 -0.30880861] = PBE0
mol:  [['C', array([0., 0., 0.])], ['H', array([0.630382, 0.630382, 0.630382])], ['H', array([-0.630382, -0.630382,  0.630382])], ['H', array([ 0.630382, -0.630382, -0.630382])], ['H', array([-0.630382,  0.630382, -0.630382])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a0432a10> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -40.4598214864075  <S^2> = 3.170797e-13  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00295239 -0.00121221 -0.00020696 ... -0.00027031 -0.00027031
 -0.00027031] = PBE0
mol:  [['C', array([0.      , 0.      , 0.179918])], ['H', array([ 0.      ,  0.855475, -0.539754])], ['H', array([ 0.      , -0.855475, -0.539754])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04304f0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -39.0756147483505  <S^2> = 6.2030381e-12  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-2.63869543e-03 -4.00286742e-04 -1.08589328e-03 ... -9.56249175e-01
 -9.56249175e-01 -9.56249175e-01] = PBE0
mol:  [['Si', array([0., 0., 0.])], ['H', array([0.855876, 0.855876, 0.855876])], ['H', array([-0.855876, -0.855876,  0.855876])], ['H', array([-0.855876,  0.855876, -0.855876])], ['H', array([ 0.855876, -0.855876, -0.855876])]]


Initialize <pyscf.gto.mole.Mole object at 0x7451a04313c0> in UKS object of <class 'pyscf.dft.uks.UKS'>


converged SCF energy = -291.719272437819  <S^2> = 1.3146817e-11  2S+1 = 1
no spin scaling, indicates correlation network
exc with xc_func = [-0.00288016 -0.00118216 -0.00020993 ... -0.0002218  -0.0002218
 -0.0002218 ] = PBE0


In [8]:
# Keep only the samples whose descriptor row contains no NaN, and apply the
# same row mask to the reference values.
nan_filt = jnp.logical_not(jnp.isnan(tdrho).any(axis=-1))

tFxc = tFxc[nan_filt]
tdrho = tdrho[nan_filt, :]

In [9]:
# Display the shape and contents of the filtered descriptor array.
tdrho.shape, tdrho

((224145, 4),
 Array([[-3.55646847e+00,  2.20899060e-01,  2.04142939e+00,
          1.43859114e+00],
        [-4.04188739e+00,  2.26935426e-01,  2.65887948e+00,
          1.38496111e+00],
        [-3.20301698e+00,  2.08517375e-01,  1.63659032e+00,
          1.14131231e+00],
        ...,
        [-4.21483060e+00,  9.81437154e-14,  2.82063903e+00,
          5.40655126e-01],
        [-4.25518732e+00,  4.00346423e-13,  2.87882770e+00,
          2.76902720e-01],
        [-4.31527649e+00,  6.20392626e-13,  2.81798141e+00,
          1.35860915e+00]], dtype=float64))

In [14]:
# CPU devices known to JAX; the training cell pins execution to the first one.
cpus = jax.devices('cpu')

In [16]:
# Only referenced by the commented-out loop at the end of this cell.
PRINT_EVERY = 200

# Learning-rate schedule handed to Adam below.
scheduler = optax.exponential_decay(
    init_value=1e-2,
    transition_begin=50,
    transition_steps=1000,
    decay_rate=0.9,
)
# Constant-rate alternative: optax.adam(learning_rate = 1e-2)
optimizer = optax.adam(learning_rate=scheduler)

class PT_E_Loss():
    """Mean-squared-error loss between network outputs and reference values."""

    def __call__(self, model, inp, ref):
        """Return the mean squared difference between predictions and ``ref``.

        ``model.net`` is mapped over the leading axis of ``inp`` with
        ``jax.vmap``; the first output component of each sample is compared
        against the matching entry of ``ref``.
        """
        outputs = jax.vmap(model.net)(inp)
        residual = outputs[:, 0] - ref
        return jnp.mean(jnp.square(residual))

trainer = xce.train.xcTrainer(
    model=thislocal,
    optim=optimizer,
    steps=100000,
    loss=PT_E_Loss(),
    do_jit=True,
)
# For the exchange network, pass only the descriptor columns listed in
# `trainer.model.use`; otherwise pass every column.
inp = [tdrho[:, trainer.model.use]] if TRAIN_NET == 'x' else [tdrho]
# Run the training on the first CPU device.
with jax.default_device(cpus[0]):
    newm = trainer(1, trainer.model, inp, [tFxc])
        

# for epoch in range(100000):
#     total_loss = 0
#     results = thislocal(tdrho[::])
#     loss = eloss(results, tFxc[::])
#     total_loss += loss.item()
#     loss.backward()

#     optimizer.step()
#     optimizer.zero_grad()
#     if epoch%PRINT_EVERY==0:
#         print('total loss {:.12f}'.format(total_loss))
        


Epoch 0
Epoch 0 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 142.01953322295498
0, epoch_train_loss=142.01953322295498
Epoch 1
Epoch 1 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 141.52423514100153
1, epoch_train_loss=141.52423514100153
Epoch 2
Epoch 2 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 140.93775898579273
2, epoch_train_loss=140.93775898579273
Epoch 3
Epoch 3 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 140.21737178188835
3, epoch_train_loss=140.21737178188835
Epoch 4
Epoch 4 :: Batch 0/1
Batch Loss = 139.3136881036083
4, epoch_train_loss=139.3136881036083
Epoch 5
Epoch 5 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 138.16118932122865
5, epoch_train_loss=138.16118932122865
Epoch 6
Epoch 6 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 136.6826233323594
6, epoch_train_loss=136.6826233323594
Epoch 7
Epoch 7 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 134.78788687460062
7, epoch_train_loss=134.78788687460062
Epoch 8
Epoch 8 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 132.37911096797714
8, epoch_train_loss=132.37911096797714
Epoch 9
Epoch 9 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 129.354262105626
9, epoch_train_loss=129.354262105626
Epoch 10
Epoch 10 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 125.61655599566436
10, epoch_train_loss=125.61655599566436
Epoch 11
Epoch 11 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 121.08010315284187
11, epoch_train_loss=121.08010315284187
Epoch 12
Epoch 12 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 115.6694970421447
12, epoch_train_loss=115.6694970421447
Epoch 13
Epoch 13 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 109.32586802565598
13, epoch_train_loss=109.32586802565598
Epoch 14
Epoch 14 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 102.0103293245059
14, epoch_train_loss=102.0103293245059
Epoch 15
Epoch 15 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 93.70808314475755
15, epoch_train_loss=93.70808314475755
Epoch 16
Epoch 16 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 84.44260081666202
16, epoch_train_loss=84.44260081666202
Epoch 17
Epoch 17 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 74.29156860914269
17, epoch_train_loss=74.29156860914269
Epoch 18
Epoch 18 :: Batch 0/1
Batch Loss = 63.38767684473789
18, epoch_train_loss=63.38767684473789
Epoch 19
Epoch 19 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 51.972113491236826
19, epoch_train_loss=51.972113491236826
Epoch 20
Epoch 20 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 40.46697595175366
20, epoch_train_loss=40.46697595175366
Epoch 21
Epoch 21 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 29.50917794143992
21, epoch_train_loss=29.50917794143992
Epoch 22
Epoch 22 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 19.95310610270469
22, epoch_train_loss=19.95310610270469
Epoch 23
Epoch 23 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 12.794355444919017
23, epoch_train_loss=12.794355444919017
Epoch 24
Epoch 24 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 9.070226942027253
24, epoch_train_loss=9.070226942027253
Epoch 25
Epoch 25 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 9.593325266358981
25, epoch_train_loss=9.593325266358981
Epoch 26
Epoch 26 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 13.936150284997348
26, epoch_train_loss=13.936150284997348
Epoch 27
Epoch 27 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 19.405037816890903
27, epoch_train_loss=19.405037816890903
Epoch 28
Epoch 28 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 22.574512404975096
28, epoch_train_loss=22.574512404975096
Epoch 29
Epoch 29 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 21.964669938538268
29, epoch_train_loss=21.964669938538268
Epoch 30
Epoch 30 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 18.450993116757395
30, epoch_train_loss=18.450993116757395
Epoch 31
Epoch 31 :: Batch 0/1
Batch Loss = 14.015423323540983
31, epoch_train_loss=14.015423323540983
Epoch 32
Epoch 32 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 10.510461477466128
32, epoch_train_loss=10.510461477466128
Epoch 33
Epoch 33 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 8.952402583293074
33, epoch_train_loss=8.952402583293074
Epoch 34
Epoch 34 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 9.169911007891397
34, epoch_train_loss=9.169911007891397
Epoch 35
Epoch 35 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 10.227060895302841
35, epoch_train_loss=10.227060895302841
Epoch 36
Epoch 36 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 11.189091922177269
36, epoch_train_loss=11.189091922177269
Epoch 37
Epoch 37 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 11.51738354812142
37, epoch_train_loss=11.51738354812142
Epoch 38
Epoch 38 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 11.11031736998523
38, epoch_train_loss=11.11031736998523
Epoch 39
Epoch 39 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 10.142906633018372
39, epoch_train_loss=10.142906633018372
Epoch 40
Epoch 40 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 8.909145430876883
40, epoch_train_loss=8.909145430876883
Epoch 41
Epoch 41 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 7.721528636463038
41, epoch_train_loss=7.721528636463038
Epoch 42
Epoch 42 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 6.843084919161483
42, epoch_train_loss=6.843084919161483
Epoch 43
Epoch 43 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 6.424971977348707
43, epoch_train_loss=6.424971977348707
Epoch 44
Epoch 44 :: Batch 0/1
Batch Loss = 6.458571877450885
44, epoch_train_loss=6.458571877450885
Epoch 45
Epoch 45 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 6.7772549284648385
45, epoch_train_loss=6.7772549284648385
Epoch 46
Epoch 46 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 7.130108135352308
46, epoch_train_loss=7.130108135352308
Epoch 47
Epoch 47 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 7.298281879134176
47, epoch_train_loss=7.298281879134176
Epoch 48
Epoch 48 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 7.186078002505767
48, epoch_train_loss=7.186078002505767
Epoch 49
Epoch 49 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 6.839310197806178
49, epoch_train_loss=6.839310197806178
Epoch 50
Epoch 50 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 6.395365082935501
50, epoch_train_loss=6.395365082935501
Epoch 51
Epoch 51 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 6.004991330386296
51, epoch_train_loss=6.004991330386296
Epoch 52
Epoch 52 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.7681307739371865
52, epoch_train_loss=5.7681307739371865
Epoch 53
Epoch 53 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.706312765895923
53, epoch_train_loss=5.706312765895923
Epoch 54
Epoch 54 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.772217148433129
54, epoch_train_loss=5.772217148433129
Epoch 55
Epoch 55 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.882596899645913
55, epoch_train_loss=5.882596899645913
Epoch 56
Epoch 56 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.9563907013926825
56, epoch_train_loss=5.9563907013926825
Epoch 57
Epoch 57 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.942775279914039
57, epoch_train_loss=5.942775279914039
Epoch 58
Epoch 58 :: Batch 0/1
Batch Loss = 5.832419379409219
58, epoch_train_loss=5.832419379409219
Epoch 59
Epoch 59 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.652861462986839
59, epoch_train_loss=5.652861462986839
Epoch 60
Epoch 60 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.452866895034855
60, epoch_train_loss=5.452866895034855
Epoch 61
Epoch 61 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.282123214181726
61, epoch_train_loss=5.282123214181726
Epoch 62
Epoch 62 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.172922207867301
62, epoch_train_loss=5.172922207867301
Epoch 63
Epoch 63 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.129790832571869
63, epoch_train_loss=5.129790832571869
Epoch 64
Epoch 64 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.130713271230361
64, epoch_train_loss=5.130713271230361
Epoch 65
Epoch 65 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.139242380715564
65, epoch_train_loss=5.139242380715564
Epoch 66
Epoch 66 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.1218181833102765
66, epoch_train_loss=5.1218181833102765
Epoch 67
Epoch 67 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 5.061978395004324
67, epoch_train_loss=5.061978395004324
Epoch 68
Epoch 68 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.965015634266676
68, epoch_train_loss=4.965015634266676
Epoch 69
Epoch 69 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.852051773251655
69, epoch_train_loss=4.852051773251655
Epoch 70
Epoch 70 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.747946046123675
70, epoch_train_loss=4.747946046123675
Epoch 71
Epoch 71 :: Batch 0/1
Batch Loss = 4.669724393885189
71, epoch_train_loss=4.669724393885189
Epoch 72
Epoch 72 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.620655297989225
72, epoch_train_loss=4.620655297989225
Epoch 73
Epoch 73 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.591419115039264
73, epoch_train_loss=4.591419115039264
Epoch 74
Epoch 74 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.566384867938147
74, epoch_train_loss=4.566384867938147
Epoch 75
Epoch 75 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.5313136719703335
75, epoch_train_loss=4.5313136719703335
Epoch 76
Epoch 76 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.479052150569882
76, epoch_train_loss=4.479052150569882
Epoch 77
Epoch 77 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.411295666917683
77, epoch_train_loss=4.411295666917683
Epoch 78
Epoch 78 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.336386907127048
78, epoch_train_loss=4.336386907127048
Epoch 79
Epoch 79 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.2646470377744095
79, epoch_train_loss=4.2646470377744095
Epoch 80
Epoch 80 :: Batch 0/1


CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Batch Loss = 4.203499225050835
80, epoch_train_loss=4.203499225050835
Epoch 81
Epoch 81 :: Batch 0/1


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7452a39d66b0>>
Traceback (most recent call last):
  File "/home/awills/anaconda3/envs/pyscfad/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt



In [None]:
# Shape check: take columns 1 and 2 of `tdrho`, map `thislocal.net` over the
# leading axis with jax.vmap, then index 0 along the second axis of the batched
# output and display the resulting shape.
# NOTE(review): `thislocal` and `tdrho` are not defined in this part of the
# notebook and this cell has no execution count — confirm it runs on a fresh
# kernel. Presumably columns 1 and 2 are the network's input descriptors; verify.
jax.vmap(thislocal.net)(tdrho[:, [1,2]])[:, 0].shape

In [None]:
# Evaluate the trainer's network on the same inputs: take columns 1 and 2 of
# `tdrho` and map `trainer.model.net` over the leading axis with jax.vmap,
# displaying the full batched output.
# NOTE(review): the training run recorded above ended in a KeyboardInterrupt —
# confirm that `trainer.model` holds the parameters you intend to inspect
# (whether it is updated during training is not visible from this cell).
jax.vmap(trainer.model.net)(tdrho[:, [1,2]])