In [1]:
import os
import numpy as np
np.random.seed(42)

import jax
import jax.numpy as np

from pyscf import gto, scf, mcscf
import pyqmc.api as pyq

from qmc.pyscftools import orbital_evaluator_from_pyscf
from qmc.orbitals import *
from qmc.determinants import *


jax.config.update("jax_enable_x64",True)

def initialize_calculation(mol: gto.Mole, nconfig: int):
    """
    Initial Calculation
    Args:
        mol: pyscf object
        nconfig: configuration
        
    Returns:
        coords, max_orb, det_coeff, det_map, mo_coeff, occup_hash, nelec
    """
    mf = scf.RHF(mol)
    mf.kernel()
    
    configs = pyq.initial_guess(mol, nconfig)
    coords = configs.configs
    
    max_orb, det_coeff, det_map, mo_coeff, occup_hash, _nelec = \
        orbital_evaluator_from_pyscf(mol, mf)
    
    nelec = np.sum(mol.nelec)
    return coords, max_orb, det_coeff, det_map, mo_coeff, occup_hash, nelec

def determine_complex_settings(mo_coeff: jnp.ndarray, 
                             det_coeff: jnp.ndarray):
    """
    
    Args:
        mo_coeff: MO's coefficient
        det_coeff: determinant's coefficent
        
    Returns:
        iscomplex, mo_dtype, get_phase
    """
    ao_dtype = float
    
    iscomplex = check_parameters_complex(mo_coeff)
    mo_dtype = complex if iscomplex else float
    
    
    iscomplex = mo_dtype == complex or check_parameters_complex(det_coeff, mo_coeff)
    get_phase = get_complex_phase if iscomplex else jnp.sign
    
    return iscomplex, mo_dtype, get_phase


In [6]:
coords, max_orb, det_coeff, det_map, mo_coeff, occup_hash, _nelec = initialize_calculation(mol,10)

converged SCF energy = -74.963146775618


In [10]:
_nelec

(5, 5)

In [2]:
from pyscf import gto, scf, mcscf

# 물 분자 정의
mol = gto.Mole()
mol.atom = '''
O 0.000000 0.000000 0.117790
H 0.000000 0.755453 -0.471161
H 0.000000 -0.755453 -0.471161
'''
mol.basis = 'sto-3g'
mol.build()

# HF 계산 먼저 수행
mf = scf.RHF(mol)
mf.kernel()

nconfig = 10
configs = pyq.initial_guess(mol, nconfig)
# # CASSCF 계산
# # 6개 궤도함수에 6개 전자를 넣고 계산
# mc = mcscf.CASSCF(mf, ncas=6, nelecas=6)
# mc.kernel()
coords = configs.configs

converged SCF energy = -74.963146775618


In [4]:
# Parameter setting
max_orb, det_coeff, det_map, mo_coeff, occup_hash, _nelec = orbital_evaluator_from_pyscf(mol,mf)
nelec = np.sum(mol.nelec)
ao_dtype = float
iscomplex = check_parameters_complex(mo_coeff)
mo_dtype = complex if iscomplex else ao_dtype

iscomplex = mo_dtype == complex or check_parameters_complex(det_coeff, mo_coeff)

get_phase = get_complex_phase if iscomplex else jnp.sign

In [9]:
coords

array([[[-1.47852199e+00, -7.19844208e-01, -2.38047931e-01],
        [ 1.05712223e+00,  3.43618290e-01, -1.54044932e+00],
        [ 3.24083969e-01, -3.85082280e-01, -4.54331160e-01],
        [ 6.11676289e-01,  1.03099952e+00,  1.15387096e+00],
        [-8.39217523e-01, -3.09212376e-01,  5.53854272e-01],
        [ 9.75545127e-01, -4.79174238e-01,  3.69318635e-02],
        [-1.10633497e+00, -1.19620662e+00,  1.03511666e+00],
        [ 1.35624003e+00, -7.20101216e-02,  1.22612374e+00],
        [ 3.61636025e-01, -6.45119755e-01,  5.83986446e-01],
        [ 1.53803657e+00, -1.46342531e+00,  6.74278405e-01]],

       [[-2.61974510e+00,  8.21902504e-01,  3.09637908e-01],
        [-2.99007350e-01,  9.17607765e-02, -1.76497807e+00],
        [-2.19671888e-01,  3.57112572e-01,  1.70048488e+00],
        [-5.18270218e-01, -8.08493603e-01, -2.79166203e-01],
        [ 9.15402118e-01, -1.09884816e+00, -1.42012545e+00],
        [ 5.13267433e-01,  9.70775493e-02,  1.19123583e+00],
        [-7.02053094e-

In [5]:
nelec = np.sum(mol.nelec)
gtoval = "GTOval_sph"

e = 0
epos = coords[:,e, :]

atomic_orbitals = aos(mol,gtoval, coords)
aovals = atomic_orbitals.reshape(-1, nconfig, nelec, atomic_orbitals.shape[-1])
mo = mos(aovals, mo_coeff[0])
mo_vals = mo[:, occup_hash[0]]
print(atomic_orbitals.shape)
print(aovals.shape)
print(mo.shape)
print(mo_vals.shape)

# inverse[0].shape

(1, 100, 7)
(1, 10, 10, 7)
(10, 10, 5)
(10, 1, 5, 5)


In [6]:
mask = None
ao = aos(mol, gtoval, epos)
print(ao.shape)
mask = np.ones(epos.shape[0], dtype=bool)
print(mask.shape)
#aovals[:, mask, e, :] = ao
aovals = aovals.at[:, mask, e, :].set(ao)

mo = mos(ao, mo_coeff[0])
print(mo.shape)
print(aovals.shape)

mo_vals = mo[:, occup_hash[0]]
print(mo_vals.shape)



(1, 10, 7)
(10,)
(10, 5)
(1, 10, 10, 7)
(10, 1, 5)


In [7]:
dets, inverse, aovals = recompute(configs=coords,
                                  atomic_orbitals =atomic_orbitals,
                                  mo_coeff = mo_coeff,
                                  _nelec = _nelec,
                                  occup_hash = occup_hash)

wf_value = compute_wf_value(configs, 
                            dets,
                            det_coeff,
                            det_map)

wf_value

(Array([-1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1.], dtype=float64),
 Array([-24.91851571, -25.2200766 , -22.61672913, -24.5454139 ,
        -25.30044541, -33.15232444, -32.59060734, -28.77909548,
        -22.64761219, -20.84011955], dtype=float64))

In [8]:

e = 0
epos = coords[:,e, :]

g, val, saved  = gradient_value(mol, 
                                e, 
                                epos,
                                dets, 
                                inverse,
                                mo_coeff,
                                det_coeff,
                                det_map,
                                _nelec,
                                occup_hash)
g

Array([[  1.63960001,   1.20706873,   0.20594085,   0.79131941,
         -1.59822779,  29.80298152, -12.3821985 ,   1.14878567,
         -0.52862865,  -2.48844549],
       [ -2.16440261,  -0.448461  ,  -0.03814711,  -0.89822094,
          0.34266312,  26.60420471,  -4.68935388,  -0.18134237,
         -0.60558799,  -0.98823798],
       [ -0.04978068,  -0.09247176,  -1.74237462,   0.69911575,
         -0.29240211, -33.98807898,   4.53514892,  -1.19578169,
         -1.53739607,   1.95881056]], dtype=float64)

In [15]:
gradient_laplacian(mol,
                   e,
                   epos,
                   dets,
                   inverse,
                   mo_coeff,
                   det_coeff,
                   det_map,
                   _nelec,
                   occup_hash)

(Array([[  1.63960001,   1.20706873,   0.20594085,   0.79131941,
          -1.59822779,  29.80298152, -12.3821985 ,   1.14878567,
          -0.52862865,  -2.48844549],
        [ -2.16440261,  -0.448461  ,  -0.03814711,  -0.89822094,
           0.34266312,  26.60420471,  -4.68935388,  -0.18134237,
          -0.60558799,  -0.98823798],
        [ -0.04978068,  -0.09247176,  -1.74237462,   0.69911575,
          -0.29240211, -33.98807898,   4.53514892,  -1.19578169,
          -1.53739607,   1.95881056]], dtype=float64),
 Array([ 1.08138154e-01,  2.13391145e-01,  9.24971541e-01,  5.81858224e-01,
        -3.12574521e+00, -4.12983512e+02, -4.53819341e+01,  1.04376415e+00,
        -1.62369023e+00,  3.12933062e+00], dtype=float64))