In [47]:
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Optional
from pyscf import gto, scf, mcscf
import pyqmc.api as pyq
from pyqmc.energy import *

from qmc.pyscftools import orbital_evaluator_from_pyscf
from qmc.setting import initialize_calculation, determine_complex_settings
from qmc.mc import limdrift
from qmc.orbitals import *
from qmc.determinants import *
from qmc.extract import *

# Switch JAX to 64-bit precision before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Seed NumPy's legacy global RNG so later np.random draws are repeatable.
np.random.seed(42)

In [48]:
# H2 geometry, handed to PySCF exactly as written (including the inline notes).
h2_geometry = '''
H  0.000000  0.000000  0.000000  # First hydrogen atom at origin
H  0.740000  0.000000  0.000000  # Second hydrogen atom, typical bond length ~0.74 Å
'''

mol = gto.Mole()
mol.basis = 'sto-3g'
mol.atom = h2_geometry
# Effective core potential intentionally left disabled:
# mol.ecp = "ccECP"
mol.build()

<pyscf.gto.mole.Mole at 0x337282c20>

In [49]:
# Number of walkers and the RNG seed shared by the setup helpers below.
nconfig = 10
seed = 42

# Project helper: returns the initial walker coordinates plus the orbital /
# determinant bookkeeping that the JAX routines in later cells take as inputs.
(
    coords,
    max_orb,
    det_coeff,
    det_map,
    mo_coeff,
    occup_hash,
    _nelec,
    nelec,
) = initialize_calculation(mol, nconfig, seed)

iscomplex, mo_dtype, get_phase = determine_complex_settings(mo_coeff, det_coeff)

# Move the walker positions and the nuclear data onto JAX arrays.
coords = jnp.array(coords)
atom_charges = jnp.array(mol.atom_charges())
atom_coords = jnp.array(mol.atom_coords())

# Re-seed NumPy immediately before drawing the PyQMC initial guess.
np.random.seed(seed)
# NOTE(review): `config` (singular) is not referenced by any cell visible in this
# notebook; later cells build `configs` instead -- confirm whether it is needed.
config = pyq.initial_guess(mol, nconfig)

converged SCF energy = -1.11675930739643


In [5]:
from pyqmc.api import Slater
import pyqmc.api as pyq
import time

# Draw the PyQMC walkers with a fixed seed and mirror their positions as a JAX
# array so both implementations are evaluated on identical coordinates.
np.random.seed(42)
configs = pyq.initial_guess(mol, nconfig)
coords = jnp.array(configs.configs)

# Reference wave function: RHF orbitals wrapped in a PyQMC Slater determinant.
mf = scf.RHF(mol)
mf.kernel()
wf = Slater(mol, mf)

# Warm-up calls: run each implementation once outside the timed region so that
# one-off costs (e.g. JAX tracing/compilation inside `recompute`) are excluded.
aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)
jax.block_until_ready((aovals, dets, inverse))
wf.recompute(configs)

n_repeat = 500

# --- JAX implementation -----------------------------------------------------
start = time.perf_counter()
for i in range(n_repeat):
    aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)
# JAX dispatches work asynchronously; wait for the results before stopping the
# clock, otherwise only the time needed to enqueue the work is measured.
jax.block_until_ready((aovals, dets, inverse))
end = time.perf_counter()
print(end -start)

# --- PyQMC (NumPy) implementation --------------------------------------------
start = time.perf_counter()
for i in range(n_repeat):
    wf.recompute(configs)
end = time.perf_counter()
print(end -start)

coords.shape

converged SCF energy = -1.11675930739643
0.07281088829040527
0.03743696212768555


(10, 2, 3)

In [6]:
# PyQMC reference values, printed one per line in the order:
# electron-ion, electron-electron, ion-ion, kinetic (first element of the
# value returned by `kinetic`).
for reference_term in (
    ei_energy(mol, configs),
    ee_energy(configs),
    ii_energy(mol),
    kinetic(configs, wf)[0],
):
    print(reference_term)

[-2.19675946 -5.48163356 -2.77732104 -3.22420964 -2.80514586 -1.8281549
 -2.83082577 -1.5447141  -2.79575494 -2.69499821]
[0.32648665 0.91723771 0.49263281 0.46253137 0.30685109 0.2483952
 0.73159971 0.23143991 0.39261861 0.63749858]
0.7151043390810812
[-0.12422153  3.12965725  0.08306495  0.51012208  0.43618889 -0.18659063
  0.28733993 -0.61984565  0.35376302  0.42808914]


In [8]:
# JAX counterparts of the PyQMC energy terms, printed one per line in the
# order: electron-ion, electron-electron, ion-ion, kinetic (first element of
# the value returned by `jax_kinetic_energy`).
jax_terms = [
    jax_ei_energy(coords, atom_charges, atom_coords),
    jax_ee_energy(coords),
    jax_ii_energy(mol),
    jax_kinetic_energy(
        coords, mol, dets, inverse, mo_coeff, det_coeff, det_map, _nelec, occup_hash
    )[0],
]
print(*jax_terms, sep="\n")

[-2.19675946 -5.48163356 -2.77732104 -3.22420964 -2.80514586 -1.8281549
 -2.83082577 -1.5447141  -2.79575494 -2.69499821]
[0.32648665 0.91723771 0.49263281 0.46253137 0.30685109 0.2483952
 0.73159971 0.23143991 0.39261861 0.63749858]
0.7151043390810812
[-0.12422153  3.12965725  0.08306495  0.51012208  0.43618889 -0.18659063
  0.28733993 -0.61984565  0.35376302  0.42808914]


In [9]:
from pyqmc.accumulators import EnergyAccumulator
def invert_list_of_dicts(A, asarray=True):
    """
    Turn a list of dicts into a dict of lists.

    Given ``[{'A': 1, 'B': 2}, {'A': 3, 'B': 5}]`` the result is
    ``{'A': [1, 3], 'B': [2, 5]}``.

    The keys of the first dict define the keys of the result.

    :parameter A: a sequence of dicts
    :parameter asarray: if True, each collected list is converted with ``np.asarray``
    :returns: a dict mapping each key to the values collected across ``A``;
        an empty dict when ``A`` is empty.
    :raises KeyError: if a later dict lacks one of the first dict's keys.
        Keys that appear only in later dicts are ignored.
    """
    # Nothing to invert: avoid indexing A[0] on an empty sequence.
    if len(A) == 0:
        return {}

    keys = A[0].keys()
    if asarray:
        return {k: np.asarray([a[k] for a in A]) for k in keys}
    return {k: [a[k] for a in A] for k in keys}

# Build PyQMC's EnergyAccumulator for this molecule and evaluate it on the
# current walkers; the cell output shows a dict with the keys
# 'ke', 'ee', 'ei', 'ecp', 'grad2' and 'total', one value per walker.
ac = EnergyAccumulator(mol)
ac(configs, wf)

{'ke': array([-0.12422153,  3.12965725,  0.08306495,  0.51012208,  0.43618889,
        -0.18659063,  0.28733993, -0.61984565,  0.35376302,  0.42808914]),
 'ee': array([0.32648665, 0.91723771, 0.49263281, 0.46253137, 0.30685109,
        0.2483952 , 0.73159971, 0.23143991, 0.39261861, 0.63749858]),
 'ei': array([-2.19675946, -5.48163356, -2.77732104, -3.22420964, -2.80514586,
        -1.8281549 , -2.83082577, -1.5447141 , -2.79575494, -2.69499821]),
 'ecp': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'grad2': array([2.78820109, 1.97067214, 2.81761838, 2.27388503, 2.81503776,
        3.16131159, 2.96063378, 2.90494438, 2.35971796, 2.68159172]),
 'total': array([-1.27939   , -0.71963426, -1.48651893, -1.53645185, -1.34700155,
        -1.05124598, -1.09678179, -1.2180155 , -1.33426896, -0.91430615])}

In [10]:
# Shape of the 'ee' entry returned by the accumulator (re-evaluates `ac`).
ac(configs, wf)["ee"].shape

(10,)

In [38]:
import time

# Drift-diffusion Metropolis sampling with the JAX helpers, moving one electron
# at a time. The NumPy RNG calls are made in the same order as in the PyQMC
# loop in the following cell so the two runs can be compared draw for draw.
tstep = 0.5
# NOTE: this rebinds `nelec`, which was previously returned by
# initialize_calculation; from here on it is the size of axis 1 of `coords`.
nconf, nelec, _ = coords.shape

start = time.time()

# Refresh the cached quantities for the current coordinates before sampling.
aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)

equilibration_step = 100

np.random.seed(seed)

# One dict of energy components per outer step.
energies = []

for i in range(equilibration_step):
    # Reset every outer step, so after the loop `acc` holds the acceptance
    # ratio of the final step only.
    acc = 0
        
    for e in range(nelec):
        
        # Gradient at the current position of electron e; limdrift is the
        # project's drift limiter (presumably caps the drift magnitude -- TODO confirm).
        g, _, _  = gradient_value(mol, e, coords[:, e, :], dets, inverse, mo_coeff, \
                                  det_coeff, det_map, _nelec, occup_hash)
        grad = limdrift(jnp.real(g.T))
        
        
        # Proposal = current position + Gaussian step (variance tstep) + drift * tstep.
        gauss = np.random.normal(scale=np.sqrt(tstep), size=(nconf, 3))
        gauss = jnp.array(gauss)
        newcoorde = coords[:, e, :] + gauss + grad * tstep
        
        # pbc -> make_irreducible -> Not yet
        # (the PyQMC loop calls configs.make_irreducible here; this loop does not)
        g, new_val, saved = gradient_value(mol, e, newcoorde, dets, inverse, mo_coeff, \
                                           det_coeff, det_map, _nelec, occup_hash)
        
        new_grad = limdrift(jnp.real(g.T))
        
        # Squared displacements entering the forward and reverse proposal
        # densities; their difference gives the transition-probability ratio.
        forward = jnp.sum(gauss**2, axis = 1)
        backward = jnp.sum((gauss + tstep * (grad + new_grad))**2, axis = 1)
        t_prob = jnp.exp(1 / (2 * tstep) * (forward - backward))

        # Accept where |new_val|^2 * t_prob exceeds a uniform random number.
        # (new_val is presumably the wave-function ratio new/old -- TODO confirm.)
        ratio = jnp.abs(new_val) ** 2 * t_prob
        accept = ratio > np.random.rand(nconf)
        # Overwrite electron e's position only for the accepted walkers.
        coords = coords.at[accept, e, :].set(newcoorde[accept, :])
        # Project helper that returns updated aovals/dets/inverse; it receives
        # the already-updated `coords`, so it must stay after the line above.
        aovals, dets, inverse = sherman_morrison(e, newcoorde, coords, accept, aovals, saved, get_phase, dets, inverse, mo_coeff, occup_hash, _nelec)
        
        acc += jnp.mean(accept) / nelec
        
    # Energy components evaluated on the walkers after this sweep.
    ee = jax_ee_energy(coords)
    ei = jax_ei_energy(coords, atom_charges, atom_coords)
    ii = jax_ii_energy(mol)
    ke = jax_kinetic_energy(coords, mol, dets, inverse, mo_coeff, det_coeff, det_map, _nelec, occup_hash)[0]
    
    energies.append({'ee': ee,
                    'ei': ei,
                    'ii': ii,
                    'ke': ke, 
                    'total': ee + ei + ii + ke})    


end = time.time()
print(end - start)

# Mean and standard deviation are taken over every entry of the collected
# totals (no axis argument), i.e. over all steps and walkers together.
jax_energies = np.array([e['total'] for e in energies])
mean_energy = np.mean(jax_energies)
std_energy = np.std(jax_energies)
print(f"Mean total energy: {mean_energy:.6f} ± {std_energy:.6f}")

0.8883788585662842
Mean total energy: -1.105117 ± 0.727576


In [39]:
import pyscf
from pyscf import gto, scf, mcscf
from pyqmc.api import Slater
import pyqmc.api as pyq
import numpy as np
from pyqmc.api import vmc
from pyqmc.energy import kinetic

def np_limdrift(g, cutoff=1):
    """
    Limit a vector to have a maximum magnitude of cutoff while maintaining direction

    The input is left untouched; a new array is returned. Integer input is
    promoted to float so the rescaled components are not truncated.

    :parameter g: a [nconf,ndim] vector
    :parameter cutoff: the maximum magnitude
    :returns: A new array equal to ``g`` with the cutoff applied row by row.
    """
    # Work on a private copy so the caller's array is never modified.
    g = np.array(g, copy=True)
    if not np.issubdtype(g.dtype, np.inexact):
        g = g.astype(float)

    tot = np.linalg.norm(g, axis=1)
    # Only rows longer than the cutoff are rescaled, so zero-length rows are
    # never divided by their norm.
    mask = tot > cutoff
    g[mask, :] = cutoff * g[mask, :] / tot[mask, np.newaxis]
    return g


# Reference run: the same drift-diffusion Metropolis sweep, driven through the
# PyQMC wave function and configs objects. The seed and the order of RNG calls
# match the JAX loop so the two trajectories can be compared directly.
# Relies on `nelec`, `nconf` and `acc` left behind by the JAX cell.
np.random.seed(42)

np_energies = []
equilibration_step = 100
tstep = 0.5

for _ in range(equilibration_step):
    # Reset every outer step: after the loop `acc2` is the last step's ratio.
    acc2 = 0.0
    for e in range(nelec):
        # Propose move
        g, _, _ = wf.gradient_value(e, configs.electron(e))
        grad = np_limdrift(np.real(g.T))
        gauss = np.random.normal(scale=np.sqrt(tstep), size=(nconf, 3))
        newcoorde = configs.configs[:, e, :] + gauss + grad * tstep
        newcoorde = configs.make_irreducible(e, newcoorde)

        # Compute reverse move
        g, new_val, saved = wf.gradient_value(e, newcoorde)
        new_grad = np_limdrift(np.real(g.T))
        forward = np.sum(gauss**2, axis=1)
        backward = np.sum((gauss + tstep * (grad + new_grad)) ** 2, axis=1)

        # Acceptance
        t_prob = np.exp(1 / (2 * tstep) * (forward - backward))
        ratio = np.abs(new_val) ** 2 * t_prob
        accept = ratio > np.random.rand(nconf)
        # Update wave function
        configs.move(e, newcoorde, accept)
        wf.updateinternals(e, newcoorde, configs, mask=accept, saved_values=saved)
        acc2 += np.mean(accept) / nelec

    ee = ee_energy(configs)
    ei = ei_energy(mol, configs)
    ii = ii_energy(mol)
    ke = kinetic(configs, wf)[0]

    # Record this loop's own acceptance ratio (acc2). The original stored
    # `acc`, the stale value left over from the JAX cell.
    np_energies.append({'ee': ee,
                        'ei': ei,
                        'ii': ii,
                        'ke': ke,
                        'total': ee + ei + ii + ke,
                        'accept_ratio': acc2})

# Side-by-side acceptance ratios of the final step: `acc` comes from the JAX
# cell above, `acc2` from this cell.
print("jaejun_qmc result is" ,acc)
print("pyqmc result is", acc2)

jaejun_qmc result is 0.75
pyqmc result is 0.75


In [40]:
# Summarise the JAX run: stack the per-step total energies and report their
# mean and standard deviation over all steps and walkers together.
jax_energies = np.array([step['total'] for step in energies])
mean_energy, std_energy = np.mean(jax_energies), np.std(jax_energies)
print(f"Mean total energy: {mean_energy:.6f} ± {std_energy:.6f}")

Mean total energy: -1.105117 ± 0.727576


In [41]:
# Summarise the PyQMC run. The totals go into a separate name so that
# `np_energies` stays a list of per-step dicts: the cell can then be re-run,
# and the other recorded components remain available.
np_total_energies = np.array([e['total'] for e in np_energies])
mean_energy = np.mean(np_total_energies)
std_energy = np.std(np_total_energies)
print(f"Mean total energy: {mean_energy:.6f} ± {std_energy:.6f}")

Mean total energy: -1.105117 ± 0.727576


In [9]:
from qmc.mc import jax_energy_mc_simulation
# Explicit JAX PRNG key handed to the packaged sampler.
key = jax.random.PRNGKey(seed=20)

# Run the project's sampler starting from the current `coords`; it returns the
# final coordinates together with two further values bound to `energy` and
# `accuracy` (presumably the mean energy and the acceptance ratio -- TODO confirm).
coords, energy, accuracy = jax_energy_mc_simulation(
    coords,
    mol,
    mo_coeff,
    det_coeff,
    det_map,
    _nelec,
    occup_hash,
    get_phase,
    key,
    equilibration_step=500,
    tstep=0.5,
)



In [11]:
# Display the `energy` value returned by jax_energy_mc_simulation.
energy

Array(-1.11202313, dtype=float64)

In [13]:
from qmc.mc import run_mc_simulation
import jax.random as jrand
# Time the project's `run_mc_simulation` sampler on the current walkers.
start = time.perf_counter()
# NOTE(review): `key` / `subkey` are created here but never passed on; the
# sampler is driven by the integer `seed` argument instead -- confirm intent.
key = jrand.PRNGKey(42)
key, subkey = jrand.split(key)

coords, acc = run_mc_simulation(
    coords, mol, mo_coeff, det_coeff, det_map,
    _nelec, occup_hash, get_phase,
    equilibration_step=500, tstep=0.5, seed=42,
)
# JAX dispatches asynchronously: wait for the results before stopping the
# clock so the measurement covers the computation, not just its dispatch.
jax.block_until_ready((coords, acc))

end = time.perf_counter()
print(end-start)

3.0138938426971436


In [12]:
# Display the per-step energy records collected in `energies`.
# NOTE(review): this dumps the whole list; a slice such as energies[:2] would
# keep the output readable.
energies

[{'ee': Array([0.40117162, 0.96334681, 0.58657009, 0.65952545, 0.38509744,
         0.23014942, 0.95381741, 0.2375545 , 0.6026147 , 0.53849642],      dtype=float64),
  'ei': Array([ -2.65168875, -10.16193632,  -3.4550606 ,  -3.43210967,
          -2.28967237,  -1.69476882,  -2.37695265,  -1.65848615,
          -3.56217018,  -2.43442724], dtype=float64),
  'ii': Array(0.71510434, dtype=float64),
  'ke': Array([ 0.45682253,  6.99896531,  0.97034146,  0.53329311,  0.00915376,
         -0.41908748,  0.0166355 , -0.42783117,  1.43125485,  0.12000223],      dtype=float64),
  'total': Array([-1.07859025, -1.48451985, -1.1830447 , -1.52418677, -1.18031683,
         -1.16860255, -0.69139541, -1.13365848, -0.8131963 , -1.06082426],      dtype=float64),
  'accept_ratio': Array(0.5, dtype=float32)},
 {'ee': Array([0.44262115, 0.93718596, 0.65829169, 0.78646749, 0.38360768,
         0.34391544, 0.4419953 , 0.25444936, 0.70369312, 0.52978027],      dtype=float64),
  'ei': Array([-2.96513267, -6.5526

In [25]:
# Display what the PyQMC wave function reports for its current internal state.
# The output is a pair of per-walker arrays (presumably sign/phase and
# log-magnitude -- TODO confirm against the PyQMC documentation).
wf.value()

(array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 array([-4.93192245, -2.48304699, -4.19878575, -3.5358314 , -3.91724478,
        -5.61033333, -4.06560732, -6.90089779, -4.01644801, -4.02440578]))

In [21]:
# Inspect the private `_dets` cache of the PyQMC Slater object: entry [1] of
# its first two elements (presumably one element per spin channel -- TODO confirm).
print(wf._dets[0][1])
print(wf._dets[1][1])

[[-2.11944941]
 [-1.47702455]
 [-1.6230271 ]
 [-1.46267089]
 [-1.87042323]
 [-2.86978243]
 [-1.65028171]
 [-4.08550942]
 [-2.39432369]
 [-2.04516223]]
[[-2.81247304]
 [-1.00602244]
 [-2.57575865]
 [-2.07316051]
 [-2.04682155]
 [-2.74055089]
 [-2.41532561]
 [-2.81538837]
 [-1.62212432]
 [-1.97924355]]


In [28]:
# Shape of the first `_dets` entry with its leading axis dropped.
wf._dets[0].shape[1:]

(10, 1)

In [51]:
# Shape of the JAX walker-coordinate array (the output shows three axes).
coords.shape

(10, 2, 3)