In [10]:
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Optional
from pyscf import gto, scf, mcscf
import pyqmc.api as pyq
from pyqmc.energy import *

from qmc.pyscftools import orbital_evaluator_from_pyscf
from qmc.setting import initialize_calculation, determine_complex_settings
from qmc.mc import limdrift
from qmc.orbitals import *
from qmc.determinants import *
from qmc.extract import *

np.random.seed(42)
jax.config.update("jax_enable_x64",True)

In [11]:
mol = gto.Mole()
mol.atom = '''
H  0.000000  0.000000  0.000000  # First hydrogen atom at origin
H  0.740000  0.000000  0.000000  # Second hydrogen atom, typical bond length ~0.74 Å
'''
mol.basis = 'sto-3g'
# mol.ecp = "ccECP"
mol.build()

nconfig, seed = 5, 42

coords, max_orb, det_coeff, det_map, mo_coeff, occup_hash, _nelec, nelec = \
    initialize_calculation(mol, nconfig, seed)
    
iscomplex, mo_dtype, get_phase = \
    determine_complex_settings(mo_coeff, det_coeff)

coords = jnp.array(coords)
atom_coords = jnp.array(mol.atom_coords())
atom_charges = jnp.array(mol.atom_charges())

np.random.seed(seed)

config = pyq.initial_guess(mol, nconfig)

converged SCF energy = -1.11675930739643


In [14]:
import pyscf
mc = pyscf.mcscf.CASSCF(mf, 2, 2)
mc.kernel()

CASSCF energy = -1.13728383448850
CASCI E = -1.13728383448850  E(CI) = -1.85238817356958  S^2 = 0.0000000


(np.float64(-1.1372838344885026),
 np.float64(-1.8523881735695837),
 FCIvector([[ 9.93646755e-01,  1.72113803e-16],
            [-5.43349285e-17, -1.12543887e-01]]),
 array([[ 0.54884228,  1.21245192],
        [ 0.54884228, -1.21245192]]),
 array([-0.57258233,  0.66546196]))

In [22]:
from pyqmc.wftools import generate_slater
wf1, to_opt1 = generate_slater(mol, mf)
wf1.recompute(configs)
wf1._dets

[array([[[ 1.        ],
         [ 1.        ],
         [ 1.        ],
         [ 1.        ],
         [ 1.        ]],
 
        [[-2.36860036],
         [-2.34523235],
         [-1.36665484],
         [-1.90908932],
         [-1.06569021]]]),
 array([[[ 1.        ],
         [ 1.        ],
         [ 1.        ],
         [ 1.        ],
         [ 1.        ]],
 
        [[-3.17244924],
         [-2.01238898],
         [-2.02178677],
         [-3.60859197],
         [-2.3990064 ]]])]

In [23]:
to_opt1

{'det_coeff': array([False])}

In [7]:
from pyqmc.api import Slater
import pyqmc.api as pyq

np.random.seed(42)
configs = pyq.initial_guess(mol, nconfig)
coords = configs.configs
coords = jnp.array(configs.configs)

mf = scf.RHF(mol)
mf.kernel()

aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)
wf, to_opt = pyq.generate_wf(mol, mf)
grad =pyq.gradient_generator(mol, wf, to_opt)
wf.recompute(configs)
import time
start = time.time()f
for i in range(500):
  aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)

end = time.time()
print(end -start)

start = time.time()
for i in range(500):
  wf.recompute(configs)

end = time.time()
print(end -start)


converged SCF energy = -1.11675930739643
0.08291411399841309
0.19236111640930176


In [9]:
wf._dets

AttributeError: 'MultiplyWF' object has no attribute '_dets'

In [11]:
wf.recompute(config)

(array([1., 1., 1., 1., 1.]),
 array([-5.48846046, -4.27934161, -3.25168698, -5.44174978, -3.30041133]))

In [12]:
print(ei_energy(mol, configs))
print(ee_energy(configs))
print(ii_energy(mol))
print(kinetic(configs, wf)[0])
total = ei_energy(mol, configs) + ee_energy(configs) + ii_energy(mol) + kinetic(configs, wf)[0]
print(total)

[-1.90453407 -2.4670126  -3.41080024 -2.10215388 -3.85402988]
[0.71597433 1.04066599 0.84595687 0.3591553  0.55762922]
0.7151043390810812
[-0.15626125  0.12732074  0.44068302 -0.14745074  0.66278783]
[-0.62971665 -0.58392153 -1.40905601 -1.17534499 -1.9185085 ]


In [71]:
# ei_mean = np.mean(ei_energy(mol, configs), axis = 0)
# ee_mean = np.mean(ee_energy(configs), axis = 0)
# ii_mean = np.mean(ii_energy(mol), axis= 0)
# ke_mean = np.mean(kinetic(configs,wf)[0], axis= 0)

In [9]:
print(jax_ei_energy(coords, atom_charges, atom_coords))
print(jax_ee_energy(coords))
print(jax_ii_energy(mol))
print(jax_kinetic_energy(coords, mol, dets, inverse, mo_coeff, det_coeff, det_map, _nelec, occup_hash)[0])

total = jax_ei_energy(coords, atom_charges, atom_coords) + jax_ee_energy(coords) + jax_ii_energy(mol) + jax_kinetic_energy(coords, mol, dets, inverse, mo_coeff, det_coeff, det_map, _nelec, occup_hash)[0]

print(total)

[-1.90453407 -2.4670126  -3.41080024 -2.10215388 -3.85402988]
[0.71597433 1.04066599 0.84595687 0.3591553  0.55762922]
0.7151043390810812
[-0.09396923  0.21639611  0.53020582 -0.08242941  0.6450487 ]
[-0.56742463 -0.49484616 -1.31953321 -1.11032365 -1.93624763]


In [73]:
grad(config, wf)

{'ke': array([-0.15626125,  0.12732074,  0.44068302, -0.14745074,  0.66278783]),
 'ee': array([0.71597433, 1.04066599, 0.84595687, 0.3591553 , 0.55762922]),
 'ei': array([-1.90453407, -2.4670126 , -3.41080024, -2.10215388, -3.85402988]),
 'ecp': array([0., 0., 0., 0., 0.]),
 'grad2': array([3.1355144 , 2.86026606, 2.25569794, 2.78992666, 1.69145976]),
 'total': array([-0.62971665, -0.58392153, -1.40905601, -1.17534499, -1.9185085 ]),
 'dpH': array([[-0.51262811+0.j, -0.41507876+0.j, -0.29545804+0.j,
         -0.17682499+0.j, -0.09536131+0.j, -0.04601202+0.j,
         -0.02189997+0.j, -0.00986495+0.j, -0.3304717 +0.j,
         -0.24590547+0.j, -0.11480599+0.j, -0.07212592+0.j,
         -0.02712589+0.j, -0.01602704+0.j, -0.00567161+0.j,
         -0.00330289+0.j,  0.        +0.j, -0.48976118+0.j,
          0.        +0.j,  0.        +0.j, -0.26070892+0.j,
          0.        +0.j,  0.        +0.j, -0.07861113+0.j,
          0.        +0.j],
        [-0.37515328+0.j, -0.40110065+0.j, -0.15

In [11]:
import pyqmc.mc

df_vmc, coords = pyqmc.mc.vmc(
    wf,
    configs,
    accumulators={"pgrad": grad},
)

In [26]:
df_vmc["pgradtotal"]

array([-1.11293998, -1.02229444, -1.03308748, -1.09977122, -1.27617765,
       -1.15337538, -1.17873983, -1.19174489, -1.02438249, -1.03846931])

In [27]:
df_vmc

{'pgradke': array([1.00914793, 1.18944444, 1.56201179, 1.5362026 , 1.0849196 ,
        0.61071501, 0.89528079, 1.15546622, 1.09190092, 1.05833295]),
 'pgradee': array([0.62042943, 0.87031513, 0.74346155, 0.69357557, 0.57780602,
        0.68911782, 0.61119074, 0.64290674, 0.65118102, 0.65430755]),
 'pgradei': array([-3.45762167, -3.79715835, -4.05366516, -4.04465373, -3.6540076 ,
        -3.16831255, -3.4003157 , -3.70522219, -3.48256876, -3.46621416]),
 'pgradecp': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'pgradgrad2': array([2.76291613, 2.72266334, 2.80891008, 2.69333234, 2.72107494,
        2.69002718, 2.73268775, 2.6970969 , 2.83886068, 2.80573356]),
 'pgradtotal': array([-1.11293998, -1.02229444, -1.03308748, -1.09977122, -1.27617765,
        -1.15337538, -1.17873983, -1.19174489, -1.02438249, -1.03846931]),
 'pgraddpH': array([[-0.90588663+0.j, -0.84892156+0.j, -0.597677  +0.j,
         -0.50483676+0.j, -0.27283644+0.j, -0.1981548 +0.j,
         -0.09140485+0.j, -0.053960

In [22]:
accumulators = {"pgrad" : grad}
for k, accumulator in accumulators.items():
    data =accumulator.avg(configs, wf)
    
data

{'ke': np.float64(0.5652793020204011),
 'ee': np.float64(0.5740119558498835),
 'ei': np.float64(-3.0198530590447406),
 'ecp': np.float64(0.0),
 'grad2': np.float64(2.773261665597687),
 'total': np.float64(-1.1654574620933746),
 'dpH': array([-0.87987491+0.j, -0.74025184+0.j, -0.4875116 +0.j, -0.37599864+0.j,
        -0.16883894+0.j, -0.12131195+0.j, -0.04129962+0.j, -0.0283496 +0.j,
        -0.94239756+0.j, -0.82914628+0.j, -0.62882745+0.j, -0.51376614+0.j,
        -0.27443396+0.j, -0.24584701+0.j, -0.07666104+0.j, -0.08244438+0.j,
         0.        +0.j, -0.67502344+0.j,  0.        +0.j,  0.        +0.j,
        -0.31708948+0.j,  0.        +0.j,  0.        +0.j, -0.10072625+0.j,
         0.        +0.j]),
 'dppsi': array([0.75912032+0.j, 0.61904342+0.j, 0.42332927+0.j, 0.31063856+0.j,
        0.14694685+0.j, 0.09894346+0.j, 0.03594173+0.j, 0.0230013 +0.j,
        0.79441994+0.j, 0.68174063+0.j, 0.5223659 +0.j, 0.41515284+0.j,
        0.22581728+0.j, 0.20058628+0.j, 0.06283257+0.j, 0.

In [15]:
df_vmc

{'pgradke': array([1.00914793, 1.18944444, 1.56201179, 1.5362026 , 1.0849196 ,
        0.61071501, 0.89528079, 1.15546622, 1.09190092, 1.05833295]),
 'pgradee': array([0.62042943, 0.87031513, 0.74346155, 0.69357557, 0.57780602,
        0.68911782, 0.61119074, 0.64290674, 0.65118102, 0.65430755]),
 'pgradei': array([-3.45762167, -3.79715835, -4.05366516, -4.04465373, -3.6540076 ,
        -3.16831255, -3.4003157 , -3.70522219, -3.48256876, -3.46621416]),
 'pgradecp': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'pgradgrad2': array([2.76291613, 2.72266334, 2.80891008, 2.69333234, 2.72107494,
        2.69002718, 2.73268775, 2.6970969 , 2.83886068, 2.80573356]),
 'pgradtotal': array([-1.11293998, -1.02229444, -1.03308748, -1.09977122, -1.27617765,
        -1.15337538, -1.17873983, -1.19174489, -1.02438249, -1.03846931]),
 'pgraddpH': array([[-0.90588663+0.j, -0.84892156+0.j, -0.597677  +0.j,
         -0.50483676+0.j, -0.27283644+0.j, -0.1981548 +0.j,
         -0.09140485+0.j, -0.053960

In [13]:
data = {}
for k in grad.keys():
    data[k] = np.mean(df_vmc["pgrad" + k], axis=0)
    
data

{'grad2': np.float64(2.7473302897156318),
 'ke': np.float64(1.119342224117029),
 'ei': np.float64(-3.6229739863211115),
 'dppsi': array([0.77093512+0.j, 0.76644661+0.j, 0.47743554+0.j, 0.46660295+0.j,
        0.2056848 +0.j, 0.19324485+0.j, 0.06472787+0.j, 0.05859185+0.j,
        0.77356268+0.j, 0.78468615+0.j, 0.47911447+0.j, 0.50309596+0.j,
        0.20356007+0.j, 0.22769842+0.j, 0.06242031+0.j, 0.07435683+0.j,
        0.        +0.j, 0.63928431+0.j, 0.        +0.j, 0.        +0.j,
        0.32610357+0.j, 0.        +0.j, 0.        +0.j, 0.11410196+0.j,
        0.        +0.j]),
 'ecp': np.float64(0.0),
 'ee': np.float64(0.6754291564706612),
 'dpH': array([-0.85780874+0.j, -0.85200955+0.j, -0.52917973+0.j, -0.51880343+0.j,
        -0.22147884+0.j, -0.2095444 +0.j, -0.06746906+0.j, -0.05982225+0.j,
        -0.85168274+0.j, -0.86376399+0.j, -0.51468098+0.j, -0.54001186+0.j,
        -0.20243372+0.j, -0.22598322+0.j, -0.05246995+0.j, -0.06463499+0.j,
         0.        +0.j, -0.66900071+0