In [1]:
import numpy as np
np.random.seed(42)  # seeds NumPy's global RNG; presumably what pyq.initial_guess draws walker positions from
import jax

# 64-bit mode must be switched on before any JAX array is created -- including any
# arrays the project modules imported below might build at import time -- so it is
# set immediately after importing jax rather than at the end of the cell.
jax.config.update("jax_enable_x64", True)
# Use the conventional `jnp` alias: binding jax.numpy to `np` would shadow the
# NumPy module imported above and make every later `np.` call ambiguous.
import jax.numpy as jnp

from pyscf import gto, scf, mcscf
import pyqmc.api as pyq

from qmc.pyscftools import orbital_evaluator_from_pyscf
# NOTE(review): these star imports supply aos, compute_wf_value, recompute,
# gradient_value and gradient_laplacian, and may also rebind module-level names
# such as `np`. Switch to explicit imports once each name's home module is confirmed.
from qmc.orbitals import *
from qmc.determinants import *

In [2]:
from pyscf import gto, scf, mcscf

# Define the water molecule (O-H geometry) in a minimal STO-3G basis.
mol = gto.Mole()
mol.atom = '''
O 0.000000 0.000000 0.117790
H 0.000000 0.755453 -0.471161
H 0.000000 -0.755453 -0.471161
'''
mol.basis = 'sto-3g'
mol.build()

# Run restricted Hartree-Fock first; the orbital evaluator below is built from `mf`.
mf = scf.RHF(mol)
mf.kernel()
# Do not build a trial wavefunction from unconverged orbitals.
assert mf.converged, "RHF did not converge; trial-wavefunction orbitals would be unreliable"

# Number of walker configurations used for the checks below.
nconfig = 10
configs = pyq.initial_guess(mol, nconfig)
# To use a multi-determinant CASSCF reference (6 electrons in 6 orbitals) instead:
#     mc = mcscf.CASSCF(mf, ncas=6, nelecas=6); mc.kernel()
coords = configs.configs

converged SCF energy = -74.963146775618


In [4]:
# Extract the determinant and orbital data for the trial wavefunction from the RHF result.
max_orb, det_coeff, det_map, mo_coeff, occup_hash, _nelec = orbital_evaluator_from_pyscf(mol, mf)
# Total electron count from PySCF's per-spin counts. The built-in `sum` works whichever
# module `np` ends up bound to (jax.numpy reductions reject tuple arguments).
nelec = sum(mol.nelec)

In [5]:
# Total electron count; built-in `sum` does not depend on what `np` is bound to.
nelec = sum(mol.nelec)
gtoval = "GTOval_sph"  # name of the GTO evaluator passed through to `aos`
atomic_orbitals = aos(mol, gtoval, coords)
# The reshape below assumes axis 1 is the flattened (walker, electron) batch. Check that
# explicitly: otherwise the leading -1 would silently absorb a walker/electron mismatch.
assert atomic_orbitals.shape[1] == nconfig * nelec, (
    f"expected {nconfig * nelec} electron positions on axis 1, got shape {atomic_orbitals.shape}"
)
aovals = atomic_orbitals.reshape(-1, nconfig, nelec, atomic_orbitals.shape[-1])
print(atomic_orbitals.shape)
print(aovals.shape)

(1, 100, 7)
(1, 10, 10, 7)


In [6]:
# Wavefunction value for every walker. Judging from the printed output, it is a
# (sign, log-magnitude)-style pair -- TODO confirm against compute_wf_value.
initial_wf_value = compute_wf_value(coords, atomic_orbitals, mo_coeff, det_coeff,
                                    _nelec, occup_hash, det_map)

# Determinant values and inverse matrices for each spin channel (s=0 -> up,
# s=1 -> down). Only the determinant columns selected by det_map[s] are kept;
# the inverses are stored unsliced.
per_spin = []
for spin in (0, 1):
    spin_dets, spin_inverse = recompute(configs=coords,
                                        atomic_orbitals=atomic_orbitals,
                                        mo_coeff=mo_coeff,
                                        _nelec=_nelec,
                                        occup_hash=occup_hash,
                                        s=spin)
    per_spin.append((spin_dets[:, :, det_map[spin]], spin_inverse))

(updets, up_inverse), (dndets, down_inverse) = per_spin
inverse = (up_inverse, down_inverse)
print(initial_wf_value)

(Array([-1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1.], dtype=float64), Array([-24.91851571, -25.2200766 , -22.61672913, -24.5454139 ,
       -25.30044541, -33.15232444, -32.59060734, -28.77909548,
       -22.64761219, -20.84011955], dtype=float64))


In [7]:
# Evaluate gradient_value for a single electron across all walkers at once.
e = 0  # index of the electron being examined
epos = coords[:, e]  # that electron's coordinates in every walker (trailing axis kept)

g, _, _ = gradient_value(
    mol, e, epos,
    inverse, updets, dndets,
    mo_coeff, det_coeff, det_map,
    _nelec, occup_hash,
)
print(g)

[[  1.63960001   1.20706873   0.20594085   0.79131941  -1.59822779
   29.80298152 -12.3821985    1.14878567  -0.52862865  -2.48844549]
 [ -2.16440261  -0.448461    -0.03814711  -0.89822094   0.34266312
   26.60420471  -4.68935388  -0.18134237  -0.60558799  -0.98823798]
 [ -0.04978068  -0.09247176  -1.74237462   0.69911575  -0.29240211
  -33.98807898   4.53514892  -1.19578169  -1.53739607   1.95881056]]


In [8]:
# Gradient and Laplacian for the same electron `e` and walkers as the cell above.
grad_lap = gradient_laplacian(mol=mol,
                              e=e,
                              epos=epos,
                              inverse=inverse,
                              updets=updets,
                              dndets=dndets,
                              mo_coeff=mo_coeff,
                              det_coeff=det_coeff,
                              det_map=det_map,
                              _nelec=_nelec,
                              occup_hash=occup_hash)

# The first element must match the gradient from gradient_value above. Use NumPy
# explicitly, since the star imports may have rebound the name `np`.
import numpy
numpy.testing.assert_allclose(numpy.asarray(grad_lap[0]), numpy.asarray(g))
grad_lap

(Array([[  1.63960001,   1.20706873,   0.20594085,   0.79131941,
          -1.59822779,  29.80298152, -12.3821985 ,   1.14878567,
          -0.52862865,  -2.48844549],
        [ -2.16440261,  -0.448461  ,  -0.03814711,  -0.89822094,
           0.34266312,  26.60420471,  -4.68935388,  -0.18134237,
          -0.60558799,  -0.98823798],
        [ -0.04978068,  -0.09247176,  -1.74237462,   0.69911575,
          -0.29240211, -33.98807898,   4.53514892,  -1.19578169,
          -1.53739607,   1.95881056]], dtype=float64),
 Array([ 1.08138154e-01,  2.13391145e-01,  9.24971541e-01,  5.81858224e-01,
        -3.12574521e+00, -4.12983512e+02, -4.53819341e+01,  1.04376415e+00,
        -1.62369023e+00,  3.12933062e+00], dtype=float64))