In [1]:
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Optional

from pyscf import gto, scf, mcscf
import pyqmc.api as pyq

from qmc.pyscftools import orbital_evaluator_from_pyscf
from qmc.setting import initialize_calculation, determine_complex_settings
from qmc.mc import limdrift
from qmc.orbitals import *
from qmc.determinants import *
from qmc.extract import *

# JAX computes in 32-bit by default; switch to 64-bit before any JAX array is created.
jax.config.update("jax_enable_x64", True)
# Seed NumPy's global RNG so the random walker moves below are reproducible.
np.random.seed(42)

In [2]:
# Water molecule; coordinates are given in PySCF's default length unit (Angstrom).
water_geometry = [
    ("O", (0.000000, 0.000000, 0.117790)),
    ("H", (0.000000, 0.755453, -0.471161)),
    ("H", (0.000000, -0.755453, -0.471161)),
]
mol = gto.Mole()
mol.atom = water_geometry
mol.basis = "sto-3g"  # minimal basis
mol.build()

<pyscf.gto.mole.Mole at 0x1055ce8f0>

In [4]:
# Set up walkers and the determinant expansion for `mol`.
# (The captured output shows this step also runs an SCF calculation.)
nconfig = 10  # number of configurations; presumably equals coords.shape[0]
seed = 42
(coords, max_orb, det_coeff, det_map, mo_coeff,
 occup_hash, _nelec, nelec) = initialize_calculation(mol, nconfig, seed)
# Presumably decides real vs. complex arithmetic from the MO/determinant coefficients.
iscomplex, mo_dtype, get_phase = determine_complex_settings(mo_coeff, det_coeff)

converged SCF energy = -74.963146775618


In [6]:
# VMC equilibration by one-electron-at-a-time drift-diffusion Metropolis-Hastings.
# Each sweep proposes a move for every electron in turn and accepts it with
# probability |new_val|^2 * T(new -> old) / T(old -> new).
# NOTE(review): a saved run of this cell raised
# "IndexError: index 8 is out of bounds for axis 0 with size 8". The failing call is not
# visible in the saved output; investigate before trusting the results.
np.random.seed(seed)
nsteps = 1000  # NOTE(review): unused here; the loop length is equilibration_step
tstep = 0.5    # time step of the drift-diffusion proposal
nconf, nelec, _ = coords.shape  # coords is indexed (configuration, electron, xyz)

dets, inverse, aovals = recompute(mol, coords, mo_coeff, _nelec, occup_hash)

equilibration_step = 800
position_save_every = 50   # snapshot configuration 0's electron positions every N sweeps
movement_save_every = 100  # record configuration 0's per-electron drift every N sweeps

electron_positions = []
electron_movements = []
save_times = []
acceptance_history = []  # mean acceptance ratio of each sweep (acc is reset every sweep)

for i in range(equilibration_step):
    acc = 0
    record_movement = i % movement_save_every == 0
    step_movements = []

    if i % position_save_every == 0:
        save_times.append(i)
        # coords is modified in place further down, so a bare slice would be a view
        # that keeps changing after it is stored; copy to freeze this snapshot.
        electron_positions.append(np.array(coords[0, :, :], copy=True))

    for e in range(nelec):
        g, _, _ = gradient_value(mol, e, coords[:, e, :], dets, inverse, mo_coeff,
                                 det_coeff, det_map, _nelec, occup_hash)
        grad = limdrift(jnp.real(g.T))

        if record_movement:
            step_movements.append(grad[0].copy())

        # Gaussian diffusion (variance tstep) plus drift along the limited gradient.
        gauss = jnp.array(np.random.normal(scale=np.sqrt(tstep), size=(nconf, 3)))
        newcoorde = coords[:, e, :] + gauss + grad * tstep

        # TODO: periodic boundary conditions (make_irreducible) are not handled yet.
        # new_val is presumably the psi_new / psi_old ratio; confirm in gradient_value.
        g, new_val, saved = gradient_value(mol, e, newcoorde, dets, inverse, mo_coeff,
                                           det_coeff, det_map, _nelec, occup_hash)
        new_grad = limdrift(jnp.real(g.T))

        # t_prob = T(new -> old) / T(old -> new) for the Gaussian drift-diffusion proposal.
        forward = jnp.sum(gauss**2, axis=1)
        backward = jnp.sum((gauss + tstep * (grad + new_grad))**2, axis=1)
        t_prob = jnp.exp(1 / (2 * tstep) * (forward - backward))

        ratio = jnp.abs(new_val) ** 2 * t_prob
        accept = ratio > np.random.rand(nconf)
        coords[accept, e, :] = newcoorde[accept, :]
        aovals, dets, inverse = sherman_morrison(
            e, newcoorde, coords, mask=accept, gtoval="GTOval_sph", aovals=aovals,
            saved_value=saved, get_phase=get_phase, dets=dets, inverse=inverse,
            mo_coeff=mo_coeff, occup_hash=occup_hash, _nelec=_nelec,
        )

        acc += jnp.mean(accept) / nelec

    acceptance_history.append(acc)
    if record_movement:
        electron_movements.append(np.array(step_movements))


IndexError: index 8 is out of bounds for axis 0 with size 8

In [6]:
# Mean acceptance ratio of the most recent sweep only: `acc` is reset to 0 at the
# start of every sweep in the VMC loop, so this is not a run-wide average.
acc

Array(0.68, dtype=float32)

In [14]:
# Rebuild determinants, inverses and AO values from scratch at the current positions.
recomputed = recompute(mol, coords, mo_coeff, _nelec, occup_hash)
dets, inverse, aovals = recomputed
# NOTE(review): the displayed value looks like stacked (sign, log|det|) arrays of shape
# (2, nconf, 1); confirm against recompute's definition.
dets[0]

Array([[[ -1.        ],
        [ -1.        ],
        [ -1.        ],
        [ -1.        ],
        [  1.        ],
        [  1.        ],
        [ -1.        ],
        [  1.        ],
        [  1.        ],
        [  1.        ]],

       [[-10.30106326],
        [-13.01126718],
        [ -9.55197587],
        [-13.30125016],
        [-11.87155652],
        [-11.21617926],
        [-15.79358409],
        [-15.33547236],
        [-12.42074359],
        [-10.03953025]]], dtype=float64)