### Detailed Explanation of Metropolis-Hastings Algorithm

The Metropolis-Hastings algorithm is a versatile Markov Chain Monte Carlo (MCMC) method used across various fields for sampling from complex probability distributions. This section details the traditional use of the algorithm and its adaptation in quantum Monte Carlo (QMC) simulations.

#### General Metropolis-Hastings Algorithm

1. **Initialization:**
   - Start with an initial point \( x_0 \) in the state space of the probability distribution you wish to explore.

2. **Proposal Distribution:**
   - Select a proposal distribution \( q(x'|x) \) which suggests a new state \( x' \) given the current state \( x \). This is typically a symmetric distribution like Gaussian centered around \( x \).

3. **Generation of Candidate State:**
   - Generate a candidate state \( x' \) using the proposal distribution \( q(x'|x) \).

4. **Calculation of Acceptance Probability:**
   - Compute the acceptance probability \( A \) using the formula:
   $$ A(x, x') = \min\left(1, \frac{p(x') \cdot q(x|x')}{p(x) \cdot q(x'|x)}\right) $$
   where \( p(x) \) is the density of the target distribution at \( x \) and \( q(x'|x) \) is the probability of proposing \( x' \) from \( x \).
   $$ q(x'|x) = \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{(x'-x)^2}{2\sigma^2}} $$

5. **Acceptance Decision:**
   - Accept the candidate state \( x' \) with probability \( A(x, x') \). If \( x' \) is accepted, move to \( x' \); otherwise, remain at \( x \).

6. **Iteration:**
   - Repeat the process for a large number of iterations to ensure adequate exploration of the target distribution.

#### Quantum Monte Carlo Adaptation of Metropolis-Hastings

1. **Quantum Context:**
   - In QMC, particularly in the Variational Monte Carlo (VMC) and Diffusion Monte Carlo (DMC) techniques, the algorithm is used to sample electron configurations according to the probability distribution \( |\Psi|^2 \), where \( \Psi \) is the quantum wave function.

2. **Quantum Force:**
   - Introduce the quantum force \( F(R) \), computed as:
   $$ F(R) = \nabla \ln |\Psi(R)| $$
   This force guides the sampling towards regions of higher probability density, effectively enhancing the efficiency of the exploration.

3. **Modified Movement Proposal:**
   - The proposal for a new position \( R' \) incorporates both a random component and the quantum force:
   $$ R' = R + \eta + \tau F(R) $$
   Here, \( \eta \) typically follows a normal distribution, and \( \tau \) is a timestep parameter.

4. **Green's Function:**
   - Calculate the transition probability or Green's function, which now considers the non-symmetry introduced by the quantum force:
   $$ G(R \rightarrow R') = \exp\left(-\frac{(R' - R - \tau F(R))^2}{2\tau}\right) $$

5. **New Acceptance Ratio:**
   - The acceptance ratio \( A \) is adjusted to include the effects of the quantum dynamics:
   $$ A = \frac{|\Psi(R')|^2}{|\Psi(R)|^2} \times \frac{G(R' \rightarrow R)}{G(R \rightarrow R')} $$

6. **Expectation Value Calculation:**
   - Compute the expectation values for observables using the weighted average over the sampled states:
   $$ \langle O \rangle = \frac{\Sigma_i O(R_i) |\Psi(R_i)|^2}{\Sigma_i |\Psi(R_i)|^2} $$

This detailed walkthrough illustrates how the Metropolis-Hastings algorithm is adapted from its general form to specifically address the unique challenges and requirements of quantum mechanical systems in QMC methods.



In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import functools
from functools import partial
from pyscf import gto, scf
from dataclasses import dataclass
from typing import Tuple, List
from qmc.orbitals import aos, mos, orbital_from_pyscf
from qmc.mc import np_initial_guess

mol = gto.Mole()
mol.atom = '''
H 0.0 0.0 0.0
H 0.0 0.0 0.74
'''

mol.basis = 'sto-3g'

mf = scf.RHF(mol)
hf_energy = mf.kernel()
print(f"Hartree-Fock energy is {hf_energy:.6f} Hartree" )

converged SCF energy = -1.11675930739643
Hartree-Fock energy is -1.116759 Hartree


Initialize <pyscf.gto.mole.Mole object at 0x11a1a3a90> in <pyscf.scf.hf.RHF object at 0x168209460>


In [4]:
def mol_eval_gto(mol, evalstr: str, walker: np.ndarray) -> np.ndarray:
    """
    Calculate the value of molecule orbitals
    
    Parameters:
    mol (object) : The pyscf molecular objects
    evalstr (str) : The evaluation string for the molecular orbitals
    primcoords (np.ndarray): The coordinates of random walker for each electron
    
    Return
    np.ndarray --> to jnx.numpy.ndarray since mol cannot be utilized in jax environement
    """
    
    aos = mol.eval_gto(evalstr, walker)
    
    if "deriv2" in evalstr:
        aos[4] += aos[7] + aos[9]
        aos = aos[:5]
    
    return aos

def aos(mol, eval_str, configs, mask = None) -> jnp.ndarray:
    '''
    Evaluate atomic orbitals at given configurations.
    
    1) Parameters:
    configs -> mycoords (np.ndarray) : Configuration object containing electron positions
    mask (np.ndarray) : Optional mask for selecting specific configurations
    eval_str : Type of pyscf evaluation string
    φₖ(r) = Nᵢ × Rₙₗ(r) × Yₗₘ(θ,φ)
    
    2) eval_str
    (1) 'GTOval' : Evaluate the value of GTOs
    (2) 'GTOval_ip' : Evaluate the value of GTOs and their first derivatives
        ∂φ/∂x, ∂φ/∂y, ∂φ/∂z 반환
    (3) 'GTOval_ip_ip' : Evaluate the value of GTOs and their first and second derivatives
        ∂²φ/∂x², ∂²φ/∂y², ∂²φ/∂z², ∂²φ/∂x∂y, ∂²φ/∂x∂z, ∂²φ/∂y∂z 반환

    Returns:
    np.ndarray : Atomic orbitals
    [1, nconf, nao] or [1, 3, nconf, nao] (gradients)
    nao: Number of atomic orbitals
    nmo : Nuber of molecular orbitals
    '''
    mycoords = configs if mask is None else configs[mask] # [nconf, nelec, 3]
    mycoords = mycoords.reshape((-1, mycoords.shape[-1])) #[nconf*nelec, 3]
    eval_gto = functools.partial(mol_eval_gto, mol)
    aos = jnp.asarray(eval_gto(eval_str, mycoords))[jnp.newaxis] # [1, nconf, nelec]
    
    if len(aos.shape) == 4:  # derivatives included
        return aos.reshape((1, aos.shape[1], *mycoords.shape[:-1], aos.shape[-1]))

    return aos.reshape((1, *mycoords.shape[:-1], aos.shape[-1])) #[1, nconf*nelec, nelec]

@jax.jit
def mos(ao: jnp.ndarray, parameters) -> jnp.ndarray:
    '''
    Convert atomic orbitals to molecular orbitals for given spin.
    
    Parameters:
    ao : Atomic orbital values [1, nconf, nelec, nao]
    
    spin : Spin index (0 for alpha, 1 for beta)
    
    Φᵢ(r) = Σₖ cᵢₖ φₖ(r)
    
    Returns:
    Molecular orbital values
    '''
    return jnp.dot(ao[0], parameters)


# Initial configuration
mol = gto.Mole()
mol.atom = '''
O 0.000000 0.000000 0.117790
H 0.000000 0.755453 -0.471161
H 0.000000 -0.755453 -0.471161
'''
mol.basis = 'sto-3g'
mol.build()

# Hartree Fock Approximation
mf = scf.RHF(mol)
mf.kernel()
# Initial configuration
nconfig = 100
nelec = np.sum(mol.nelec)
config = np_initial_guess(mol, nconfig)

# 3. AO and MO casdsdalculation
gtoval = "GTOval_sph"
atomic_orbital = aos(mol, gtoval, config)
max_orb, mo_coeff, det_coeff, occup, det_map, _nelec = orbital_from_pyscf(mol, mf)
print("max orbital is", max_orb)
print("alpah mo_coeff shape is ", mo_coeff[0].shape)
print("det_coeff is ", det_coeff.shape)
print("det_map is ", det_map)
print("_nelec is ", _nelec)

# AO reshape
aovals = atomic_orbital.reshape(-1, nconfig, nelec, atomic_orbital.shape[-1])

# MO 계산
mo_coeff = jnp.array(mo_coeff)
mo_vals = mos(aovals, mo_coeff)

# 결과 확인
print("AO shape:", atomic_orbital.shape)
print("Reshaped AO shape:", aovals.shape)
print("MO shape:", mo_vals.shape)
print("mo_coeff shape:" ,mo_coeff.shape)

converged SCF energy = -74.963146775618
max orbital is [5 5]
alpah mo_coeff shape is  (7, 5)
det_coeff is  (1,)
det_map is  [[0]
 [0]]
_nelec is  (5, 5)
AO shape: (1, 1000, 7)
Reshaped AO shape: (1, 100, 10, 7)
MO shape: (100, 10, 2, 5)
mo_coeff shape: (2, 7, 5)


In [24]:
max_orb, mo_coeff, det_coeff, occup, det_map, _nelec
print(max_orb.type)

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'type'

In [40]:
occup

[[Array([0, 1, 2, 3, 4], dtype=int32)], [Array([0, 1, 2, 3, 4], dtype=int32)]]

In [28]:
type(max_orb)

jaxlib.xla_extension.ArrayImpl

In [29]:

def jax_organize_determinant_data(determinant_list, weight_threshold=0):
    '''
    Jax version of organize determinant data function.
    Note: This function is just for experimental version
    create package tool at pyqmc
    
       Input: determinant_list in format 
   [(weight1, ([up_orbs1], [down_orbs1])), 
    (weight2, ([up_orbs2], [down_orbs2])), ...]
   Example: [(0.9, ([0,1], [0,1])), (0.1, ([0,2], [0,1]))]
   
   Returns:
   1. detwt (determinant weights): Coefficient/weight for each determinant
      - Ex: array([0.9, 0.1])  
      - Meaning: First det has weight 0.9, second has 0.1
   
   2. occup (orbital occupations): Unique orbital occupation patterns for each spin
      - Format: [up_patterns, down_patterns]
      - Ex: [[array([0,1]), array([0,2])],  # Up spin occupation patterns
            [array([0,1])]]                 # Down spin occupation patterns
      - Meaning: Stores only unique occupation patterns without duplicates
   
   3. map_dets (pattern mapping): Maps each determinant to its occupation patterns
      - Format: array([[up_indices], [down_indices]])
      - Ex: array([[0,1],   # First det uses up[0], second uses up[1]
                   [0,0]])  # Both dets use down[0]
      - Meaning: Shows which patterns from occup are used by each weight in detwt
    '''
    
    # Initialize empty containers
    weights = []
    patterns = [[], []]
    mapping = [[], []]
    
    filtered_dets = [det for det in determinant_list 
                     if jnp.abs(det[0]) > weight_threshold]
    
    for det in filtered_dets:
        weights.append(det[0])
        spin_occupations = det[1]
        
        for spin in [0, 1]:
            curr_occupation = tuple(spin_occupations[spin])
            
            if curr_occupation not in patterns[spin]:
                mapping[spin].append(len(patterns[spin]))
                patterns[spin].append(curr_occupation)
                
            else:
                pattern_idx = patterns[spin].index(curr_occupation)
                mapping[spin].append(pattern_idx)
                
    determinant_weights = jnp.array(weights)
    pattern_mapping = jnp.array(mapping)
    
    orbital_patterns = [
        [jnp.array(list(pattern)) for pattern in spin_patterns]
        for spin_patterns in patterns
    ]
    
    return determinant_weights, orbital_patterns, pattern_mapping

list

In [409]:
@partial(jax.jit, static_argnums=(2,))
def compute_determinants(mo_values, det_occup, s):
    mo_vals = jnp.swapaxes(mo_values[:, :, det_occup[s]], 1, 2)
    compute_det = jnp.asarray(jnp.linalg.slogdet(mo_vals))
    inverse = jax.vmap(jnp.linalg.inv)(mo_vals)
    return compute_det, inverse

def recompute(configs, atomic_orbital, det_occup, mo_coeff, _nelec, s):
    nconf, nelec_tot, ndim = configs.shape
    aovals = atomic_orbital.reshape(-1, nconf, nelec_tot, atomic_orbital.shape[-1])
    if s == 0:
        ao_slice = aovals[:, :, :_nelec[0], :]
        param = mo_coeff[0]
    else:
        ao_slice = aovals[:, :, _nelec[0]:_nelec[0]+_nelec[1], :]
        param = mo_coeff[1]
    mo = mos(ao_slice, param)
    return compute_determinants(mo, occup_hash, s)

## 수정
def convert_to_hashable(occup):
    up_orbs = tuple(occup[0])
    dn_orbs = tuple(occup[1])
    
    return (up_orbs, dn_orbs)

occup_hash = convert_to_hashable(occup)

updets, up_inverse = recompute(
    configs=config,
    atomic_orbital=atomic_orbital,
    det_occup=occup_hash,
    mo_coeff=mo_coeff,
    _nelec=mol.nelec,
    s=0  # up spin
)
updets = updets[:, :, det_map[0]]

dndets, down_inverse = recompute(
    configs=config,
    atomic_orbital=atomic_orbital,
    det_occup=occup_hash,
    mo_coeff=mo_coeff,
    _nelec=mol.nelec,
    s=1  # down spin
)

dndets = dndets[:, :, det_map[1]]

determinant = tuple([updets, dndets])
inverse = tuple([up_inverse, down_inverse])

@jax.jit
def compute_wf_value(updets, dndets, def_coeff=None):
    
    """
    Args:
        updets: (signup, logvals_up) up spin determinant
        dndets: (signdown, logvals_down) down spin determinant
        def_coeffs: Coefficient of determinant
        
        Ψ = D↑ × D↓  
        logΨ = log D↑ + log D↓
        logΨ_reg = log D↑ - log D↑max + log D↓ - log D↓max
        Ψ_reg => check the zero value, which might be diverged at log scale
        => That's why we use jnp.where to check the jnp.inf
        
        wf_logval = logΨ_reg + log D↑max + log D↓max = log D↑ + log D↓
        
        return log D↑ + log D↓
    """
    
    if def_coeff is None:
        det_coeff = jnp.array([1.0])
        
        
    upref = jnp.amax(updets[1]).real
    dnref = jnp.amax(dndets[1]).real
    
    phases = updets[0] * dndets[0]
    logvals = updets[1] - upref + dndets[1] - dnref
    
    wf_val = jnp.einsum("d,id->i", def_coeff, phases * jnp.exp(logvals))
    wf_sign = jnp.where(wf_val == 0, 0.0, wf_val / jnp.abs(wf_val))
    wf_logval = jnp.where(wf_val == 0, -jnp.inf,
                            jnp.log(jnp.abs(wf_val)) + upref + dnref)
    
    return wf_sign, wf_logval

In [410]:
"""
Wave function gradient calculation
∇Ψ/Ψ = ∇ln|Ψ| = Σᵢ (∇Dᵢ/Dᵢ)
∇D/D = Σₖ (∇Φₖ/Φₖ)
"""

determinant = tuple([updets, dndets])
inverse = tuple([up_inverse, down_inverse])

def gradient_value(mol, e, epos, inverse, mo_coeff, det_occup, _nelec):
    """ 
    Compute the gradient value of the wave function
    
    Parameters
    ------------
    mol : pyscf.gto.Mole.object
    e : int
    epos : config
    wf : wavefunction
    """
    
    # determine the spin (if e is larger than up electorn, then the spin value is 1)
    s = int(e >= _nelec[0])
    
    # (φ ∂φ/∂x, ∂φ/∂y, ∂φ/∂z) -> (1, 4, config, number of coefficients)
    aograd = aos(mol, "GTOval_sph_deriv1", epos)

    # ∂Φᵢ/∂r = Σₖ cᵢₖ ∂φₖ/∂r -> (4, config, number_of_electrons)
    mograd = mos(aograd, mo_coeff[s])
    
    # (4, config, 1, number_of_electrons)
    mograd_vals = mograd[:, :, det_occup[s]]

    ratio = _testrow_deriv(e, mograd_vals, inverse, s, det_map, det_coeff, updets, dndets)
    
    derivatives = ratio[1:] / ratio[0]
    derivatives = derivatives.at[~jnp.isfinite(derivatives)].set(0.0)
    
    values = ratio[0]
    values = values.at[~jnp.isfinite(values)].set(1.0)
     
    return  derivatives, values, (aograd[:, 0], mograd[0])
    
    
def _testrow_deriv(e, vec, inverse, s, det_map, det_coeff, updets, dndets):
    
    #∇D/D = Σₖ (∇Φₖ/Φₖ)
    
    ratios = jnp.einsum("ei...dj, idj... ->ei...d",
                        vec,
                        inverse[s][..., e - s*_nelec[0]])

    upref = jnp.amax(updets[1]).real
    dnref = jnp.amax(dndets[1]).real
    
    det_array = (updets[0, :, det_map[0]] * 
                 dndets[0, :, det_map[1]] *
                 jnp.exp(
                     updets[1][:, det_map[0]] +
                     dndets[1][:, det_map[1]] -
                     upref - dnref
                     )
    )

    numer = jnp.einsum("ei...d,d,di->ei...",
                        ratios[..., det_map[s]],
                        det_coeff,               
                        det_array                 
    )
    
    denom = jnp.einsum("d,di->i...",
                      det_coeff,
                      det_array
    )

    if len(numer.shape) == 3:
        denom = denom[jnp.newaxis, :, jnp.newaxis]

    return numer / denom
    
s = 0
e = 1
epos = config[:, e]

g, _, _  = gradient_value(mol, e, epos, inverse, mo_coeff, occup_hash, _nelec)
print(g.shape)

(3, 10)


In [411]:
ao = aos(mol, "GTOval_sph_deriv2", epos)
mo = mos(ao, mo_coeff[0])
mo_vals = mo[:, :, occup[0]]
ratio = _testrow_deriv(e, mo_vals, inverse, s, det_map, det_coeff, updets, dndets)
ratio[0:1].shape


def _testrow_deriv(e, 
                   vec, 
                   inverse, 
                   s,
                   updets,
                   dndets, 
                   det_coeff,
                   det_map,
                   _nelec
                   ):
    

(1, 10)

In [419]:
def gradient_laplacian(mol, e, epos, inverse, mo_coeff, det_occup, _nelec, det_map, det_coeff, updets, dndets):
    
    s = int(e >= _nelec[0])
    
    ao = aos(mol, "GTOval_sph_deriv2", epos)
    mo = mos(ao, mo_coeff[s])
    mo_vals = mo[..., det_occup[s]]

    ratio = _testrow_deriv(e, mo_vals, inverse, s, det_map, det_coeff, updets, dndets)
    ratio = ratio/ratio[:1]
    return ratio[1:-1], ratio[-1]
    
def kinetic_energy(configs, mol, mo_coeff, det_occup, _nelec, det_map, det_coeff, updets, dndets, inverse):
    """
    운동에너지 계산
    
    Parameters
    ----------
    configs : jnp.ndarray
        전자 configurations
    기타 parameters는 gradient_laplacian과 동일
    
    Returns
    -------
    Tuple[jnp.ndarray, jnp.ndarray]
        (kinetic_energy, gradient_squared)
    """
    nconf = configs.shape[0]
    ke = jnp.zeros(nconf)
    grad2 = jnp.zeros(nconf)
    for e in range(configs.shape[1]):
        grad, lap = gradient_laplacian(
            mol, e, configs[:, e, :], inverse,
            mo_coeff, det_occup, _nelec,
            det_map, det_coeff, updets, dndets
        )
        
        # -1/2 ∇²Ψ/Ψ
        ke += -0.5 * jnp.real(lap)
        
        # gradient culmulation
        grad2 += jnp.sum(jnp.abs(grad)**2, axis=0)
    
    return ke, grad2

ke, grad2 = kinetic_energy(
    config, mol, mo_coeff, 
    occup, mol.nelec, det_map, 
    det_coeff, updets, dndets, inverse
)

print(np.mean(ke))

-17.66446


In [413]:
def ee_energy(configs):
    """
    Calculate the electron-electron Coulomb interaction energy.
    Uses the Coulomb formula: E = k * q₁ * q₂ / r, where r is the distance between electrons.
    
    Parameters
    ----------
    configs : jnp.ndarray
        Electron coordinates (nconf, nelec, 3)
        
    Returns
    -------
    jnp.ndarray
        Electron-electron energy (nconf,)
    """
    nconf, nelec, _ = configs.shape
    
    if nelec == 1:  # No interaction if there is only one electron
        return jnp.zeros(nconf)
    
    # Calculate distances between all pairs of electrons
    r_ee = configs[:, :, None, :] - configs[:, None, :, :]  # (nconf, nelec, nelec, 3)
    r_ee_dist = jnp.sqrt(jnp.sum(r_ee**2, axis=-1))  # (nconf, nelec, nelec)
    
    # Remove self-interaction (diagonal elements)
    mask = ~jnp.eye(nelec, dtype=bool)
    ee = jnp.where(mask, 1.0/r_ee_dist, 0.0)  # Coulomb interaction energy: 1/r
    
    # Correct for double counting by multiplying by 1/2
    return jnp.sum(ee, axis=(1,2)) * 0.5  # Sum over all pairs, correct for double counting

def ei_energy(mol, configs):
    """
    Calculate the electron-ion Coulomb interaction energy.
    Uses the Coulomb formula: E = k * q₁ * q₂ / r, where q₁ and q₂ are charges and r is the distance.
    
    Parameters
    ----------
    mol : pyscf.gto.Mole
        Molecule object
    configs : jnp.ndarray
        Electron coordinates (nconf, nelec, 3)
        
    Returns
    -------
    jnp.ndarray
        Electron-ion energy (nconf,)
    """
    ei = jnp.zeros(configs.shape[0])
    atom_coords = jnp.array(mol.atom_coords())
    atom_charges = jnp.array(mol.atom_charges())
    
    for coord, charge in zip(atom_coords, atom_charges):
        # Calculate distances between each electron and ion
        r_ei = configs - coord[None, None, :]  # (nconf, nelec, 3)
        r_ei_dist = jnp.sqrt(jnp.sum(r_ei**2, axis=-1))  # (nconf, nelec)
        
        # Accumulate energy (-Ze²/r): Negative because electrons are negatively charged
        ei -= charge * jnp.sum(1.0/r_ei_dist, axis=1)
    
    return ei

def dist_matrix(configs):
    """
    Calculate all pairwise distance vectors within a set of positions
    
    Parameters
    ----------
    configs : jnp.ndarray 
        Configuration array of shape (nconf, n, 3)
    
    Returns
    -------
    vs : jnp.ndarray
        Distance vectors of shape (nconf, n(n-1)/2, 3)
    ij : list
        List of index pairs [(i,j)] corresponding to distances
    """
    nconf, n = configs.shape[:2]
    npairs = int(n * (n - 1) / 2)
    
    if npairs == 0:
        return jnp.zeros((nconf, 0, 3)), []
    
    # Calculate distances and keep track of indices
    vs = []
    ij = []
    for i in range(n):
        # Calculate distance vectors from atom i to all atoms j > i
        dist_vectors = configs[:, i + 1:, :] - configs[:, i:i+1, :]
        vs.append(dist_vectors)
        ij.extend([(i, j) for j in range(i + 1, n)])
    
    vs = jnp.concatenate(vs, axis=1)
    return vs, ij

def ii_energy(mol):
    """
    Calculate ion-ion Coulomb interaction energy
    
    Parameters
    ----------
    mol : pyscf.gto.Mole
        Molecule object containing atomic positions and charges
    
    Returns
    -------
    float
        Ion-ion interaction energy
    """
    # Convert atomic coordinates and charges to JAX arrays
    coords = jnp.array(mol.atom_coords())[jnp.newaxis, :, :]
    charges = jnp.array(mol.atom_charges())
    
    # Get distance vectors and corresponding index pairs
    rij, ij = dist_matrix(coords)
    
    if len(ij) == 0:  # Single atom case
        return jnp.array(0.0)
    
    # Calculate magnitudes of distance vectors
    rij = jnp.linalg.norm(rij, axis=2)[0, :]
    
    # Sum up Coulomb interactions: Σ(ZᵢZⱼ/rᵢⱼ)
    energy = sum(charges[i] * charges[j] / r for (i, j), r in zip(ij, rij))
    
    return energy


def compute_potential_energy(mol, configs):
    """
    Total potential energy calculation
    
    Parameters
    ----------
    mol : pyscf.gto.Mole
    configs : jnp.ndarray
        Electron configuraton (nconf, nelec, 3)
        
    Returns
    -------
    dict
        total potential energy
    """
    # 각 성분 계산
    ee = ee_energy(configs)
    ei = ei_energy(mol, configs)
    ii = ii_energy(mol)
    
    # 결과 dictionary 생성
    potential_components = {
        'ee': ee,           # E-E
        'ei': ei,           # E-Ion
        'ii': jnp.full_like(ee, ii),  # Ion-Ion (상수)
        'total': ee + ei + ii  # Total potential
    }
    
    return potential_components

pot_energy = compute_potential_energy(mol, config)
print("\nPotential Energy Components:")
for key, val in pot_energy.items():
    if key != 'total':
        print(f"{key.upper()} Energy mean: {jnp.mean(val):.6f}")
print(f"Total Potential Energy mean: {jnp.mean(pot_energy['total']):.6f}")


Potential Energy Components:
EE Energy mean: 24.794577
EI Energy mean: -70.562401
II Energy mean: 9.189194
Total Potential Energy mean: -36.578632


In [414]:
def compute_total_energy(mol, configs, mo_coeff, det_occup, _nelec, det_map, det_coeff, updets, dndets, inverse):
    # 1. Kinetic energy
    ke, grad2 = kinetic_energy(
        configs, mol, mo_coeff, det_occup, _nelec,
        det_map, det_coeff, updets, dndets, inverse
    )
    
    # 2. Potential energy
    pot = compute_potential_energy(mol, configs)
    
    energy_components = {
        'ke': np.array(ke),
        'ee': np.array(pot['ee']),
        'ei': np.array(pot['ei']),
        'ii': np.array(pot['ii']),
        'grad2': np.array(grad2),
        'total': np.array(ke + pot['total'])
    }
    
    return energy_components

In [415]:
def limdrift(grad: jnp.ndarray, cutoff: float = 1.0) -> jnp.ndarray:
    """
    Limit gradient for large drift values to avoid instabilities in the simulation
    
    Parameters
    ----------
    grad : jnp.ndarray
        Gradient to limit, shape (nconf, 3)
    cutoff : float
        Maximum allowed magnitude for the drift velocity
        
    Returns
    -------
    jnp.ndarray
        Limited gradient with same shape as input
    """
    grad_squared = jnp.sum(grad**2, axis=-1, keepdims=True)
    mask = grad_squared > cutoff**2
    grad = jnp.where(mask, grad * cutoff / jnp.sqrt(grad_squared), grad)
    return grad


grad, values, saved = gradient_value(
            mol, e, config[:, e], inverse,
            mo_coeff, occup_hash, mol.nelec
        )
        
grad = jnp.real(grad.T)
        
    # Limit drift
grad = limdrift(grad)
print(grad.shape)

(10, 3)


In [416]:
# VMC parameters
nblocks = 10
nsteps = 1000
nsteps_per_block = nsteps // nblocks
tstep = 0.01
nconf = config.shape[0]
# Storage for block results
block_energies = []
block_acceptance = []

# Initial computation 
occup_hash = convert_to_hashable(occup)

updets, up_inverse = recompute(
    configs=config,
    atomic_orbital=atomic_orbital,
    det_occup=occup_hash,
    mo_coeff=mo_coeff,
    _nelec=mol.nelec,
    s=0  # up spin
)
updets = updets[:, :, det_map[0]]

dndets, down_inverse = recompute(
    configs=config,
    atomic_orbital=atomic_orbital,
    det_occup=occup_hash,
    mo_coeff=mo_coeff,
    _nelec=mol.nelec,
    s=1  # down spin
)

dndets = dndets[:, :, det_map[1]]

determinant = tuple([updets, dndets])
inverse = tuple([up_inverse, down_inverse])

# Run VMC with blocks
for block in range(nblocks):
    print(f"\nStarting block {block+1}/{nblocks}")
    
    block_avg = {
        'ke': 0.0, 'ee': 0.0, 'ei': 0.0,
        'ii': 0.0, 'total': 0.0, 'acceptance': 0.0
    }
    
    for step in range(nsteps_per_block):
        acc = 0.0
        
        for e in range(nelec):
            # Current gradient
            grad, values, saved = gradient_value(
                mol, e, config[:, e], inverse,
                mo_coeff, occup_hash, mol.nelec
            )
            grad = np.real(grad.T)
            
            # Apply drift limiting
            grad = limdrift(grad)
            
            # Propose move
            gauss = np.random.normal(scale=np.sqrt(tstep), size=(nconf, 3))
            new_pos = config[:, e] + gauss + grad * tstep
            
            # New gradient and wave function
            new_grad, new_val, saved_new = gradient_value(
                mol, e, new_pos, inverse,
                mo_coeff, occup_hash, mol.nelec
            )
            new_grad = np.real(new_grad.T)
            new_grad = limdrift(new_grad)
            
            # Metropolis acceptance
            forward = np.sum(gauss**2, axis=1)
            backward = np.sum((gauss + tstep * (grad + new_grad)) ** 2, axis=1)
            t_prob = np.exp(1/(2*tstep) * (forward - backward))
            ratio = np.abs(new_val)**2 * t_prob
            
            # Accept/reject step
            accept = ratio > np.random.random(ratio.shape)
            
            # Update positions
            config[accept, e, :] = new_pos[accept, :]
            
            # Determinant updates
            s = int(e >= mol.nelec[0])
            if accept.any():
                if s == 0:
                    updets, up_inverse = recompute(config, atomic_orbital, occup_hash, mo_coeff, mol.nelec, s=0)
                    updets = updets[:, :, det_map[0]]
                    inverse = (up_inverse, inverse[1])
                else:
                    dndets, down_inverse = recompute(config, atomic_orbital, occup_hash, mo_coeff, mol.nelec, s=1)
                    dndets = dndets[:, :, det_map[1]]
                    inverse = (inverse[0], down_inverse)
            
            acc += np.mean(accept) / nelec
        
        # Energy calculation
        energies = compute_total_energy(
            mol, config, mo_coeff, occup_hash, mol.nelec,
            det_map, det_coeff, updets, dndets, inverse
        )
        
        # Update block averages
        for key in block_avg.keys():
            if key != 'acceptance':
                block_avg[key] += np.mean(energies[key]) / nsteps_per_block
        block_avg['acceptance'] += acc / nsteps_per_block
        
        if step % 100 == 0:
            print(f"Block {block+1}, Step {step}: E = {np.mean(energies['total']):.6f}, Acc = {acc:.3f}")
    
    # Store block results
    block_energies.append(block_avg['total'])
    block_acceptance.append(block_avg['acceptance'])
    
    # Print block summary
    print(f"\nBlock {block+1} summary:")
    print(f"Energy components (atomic units):")
    for key, val in block_avg.items():
        if key != 'acceptance':
            print(f"{key.upper()}: {val:.6f}")
    print(f"Acceptance ratio: {block_avg['acceptance']:.3f}")

# Final statistical analysis
block_energies = np.array(block_energies)
mean_energy = np.mean(block_energies)
error_energy = np.std(block_energies) / np.sqrt(nblocks)
mean_acceptance = np.mean(block_acceptance)

print(f"\nFinal Results after {nblocks} blocks:")
print(f"Energy: {mean_energy:.6f} ± {error_energy:.6f}")
print(f"Mean acceptance ratio: {mean_acceptance:.3f}")


Starting block 1/10
Block 1, Step 0: E = -63.612312, Acc = 0.940

Block 1 summary:
Energy components (atomic units):
KE: -28.793247
EE: 26.210768
EI: -87.350685
II: 9.189191
TOTAL: -80.743988
Acceptance ratio: 0.825

Starting block 2/10
Block 2, Step 0: E = -78.830635, Acc = 0.820

Block 2 summary:
Energy components (atomic units):
KE: -31.997562
EE: 28.379446
EI: -97.486740
II: 9.189191
TOTAL: -91.915634
Acceptance ratio: 0.794

Starting block 3/10
Block 3, Step 0: E = -93.444908, Acc = 0.770

Block 3 summary:
Energy components (atomic units):
KE: -30.409754
EE: 27.441685
EI: -91.743225
II: 9.189191
TOTAL: -85.522087
Acceptance ratio: 0.775

Starting block 4/10
Block 4, Step 0: E = -86.285789, Acc = 0.730

Block 4 summary:
Energy components (atomic units):
KE: -39.472057
EE: 27.390001
EI: -89.726784
II: 9.189191
TOTAL: -92.619629
Acceptance ratio: 0.758

Starting block 5/10
Block 5, Step 0: E = -91.031929, Acc = 0.720

Block 5 summary:
Energy components (atomic units):
KE: -32.365292

In [417]:
print(f"\nFinal Results after {nblocks} blocks:")
print(f"Energy: {mean_energy:.6f} ± {error_energy:.6f}")
print(f"Mean acceptance ratio: {mean_acceptance:.3f}")


Final Results after 10 blocks:
Energy: -82.629105 ± 2.005051
Mean acceptance ratio: 0.754


In [406]:
# VMC parameters
nblocks = 10
nsteps = 1000
nsteps_per_block = nsteps // nblocks
tstep = 0.01
nconf = config.shape[0]
# Storage for block results
block_energies = []
block_acceptance = []

# Initial computation 
occup_hash = convert_to_hashable(occup)

updets, up_inverse = recompute(
    configs=config,
    atomic_orbital=atomic_orbital,
    det_occup=occup_hash,
    mo_coeff=mo_coeff,
    _nelec=mol.nelec,
    s=0  # up spin
)
updets = updets[:, :, det_map[0]]

dndets, down_inverse = recompute(
    configs=config,
    atomic_orbital=atomic_orbital,
    det_occup=occup_hash,
    mo_coeff=mo_coeff,
    _nelec=mol.nelec,
    s=1  # down spin
)

dndets = dndets[:, :, det_map[1]]

determinant = tuple([updets, dndets])
inverse = tuple([up_inverse, down_inverse])

# Run VMC with blocks
for block in range(nblocks):
    print(f"\nStarting block {block+1}/{nblocks}")
    
    block_avg = {
        'ke': 0.0, 'ee': 0.0, 'ei': 0.0,
        'ii': 0.0, 'total': 0.0, 'acceptance': 0.0
    }
    
    for step in range(nsteps_per_block):
        acc = 0.0
        
        for e in range(nelec):
            # Current gradient
            grad, values, saved = gradient_value(
                mol, e, config[:, e], inverse,
                mo_coeff, occup_hash, mol.nelec
            )
            grad = np.real(grad.T)
            
            # Apply drift limiting
            grad = limdrift(grad)
            
            # Propose move
            gauss = np.random.normal(scale=np.sqrt(tstep), size=(nconf, 3))
            new_pos = config[:, e] + gauss + grad * tstep
            
            # New gradient and wave function
            new_grad, new_val, saved_new = gradient_value(
                mol, e, new_pos, inverse,
                mo_coeff, occup_hash, mol.nelec
            )
            new_grad = np.real(new_grad.T)
            new_grad = limdrift(new_grad)
            
            # Metropolis acceptance
            forward = np.sum(gauss**2, axis=1)
            backward = np.sum((gauss + tstep * (grad + new_grad)) ** 2, axis=1)
            t_prob = np.exp(1/(2*tstep) * (forward - backward))
            ratio = np.abs(new_val)**2 * t_prob
            
            # Accept/reject step
            accept = ratio > np.random.random(ratio.shape)
            
            # Update positions
            config[accept, e, :] = new_pos[accept, :]
            
            # Determinant updates
            s = int(e >= mol.nelec[0])
            if accept.any():
                if s == 0:
                    updets, up_inverse = recompute(config, atomic_orbital, occup_hash, mo_coeff, mol.nelec, s=0)
                    updets = updets[:, :, det_map[0]]
                    inverse = (up_inverse, inverse[1])
                else:
                    dndets, down_inverse = recompute(config, atomic_orbital, occup_hash, mo_coeff, mol.nelec, s=1)
                    dndets = dndets[:, :, det_map[1]]
                    inverse = (inverse[0], down_inverse)
            
            acc += np.mean(accept) / nelec
        
        # Energy calculation
        energies = compute_total_energy(
            mol, config, mo_coeff, occup_hash, mol.nelec,
            det_map, det_coeff, updets, dndets, inverse
        )
        
        # Update block averages
        for key in block_avg.keys():
            if key != 'acceptance':
                block_avg[key] += np.mean(energies[key]) / nsteps_per_block
        block_avg['acceptance'] += acc / nsteps_per_block
        
        if step % 100 == 0:
            print(f"Block {block+1}, Step {step}: E = {np.mean(energies['total']):.6f}, Acc = {acc:.3f}")
    
    # Store block results
    block_energies.append(block_avg['total'])
    block_acceptance.append(block_avg['acceptance'])
    
    # Print block summary
    print(f"\nBlock {block+1} summary:")
    print(f"Energy components (atomic units):")
    for key, val in block_avg.items():
        if key != 'acceptance':
            print(f"{key.upper()}: {val:.6f}")
    print(f"Acceptance ratio: {block_avg['acceptance']:.3f}")

# Final statistical analysis
block_energies = np.array(block_energies)
mean_energy = np.mean(block_energies)
error_energy = np.std(block_energies) / np.sqrt(nblocks)
mean_acceptance = np.mean(block_acceptance)



Starting block 1/10
Block 1, Step 0: E = -1.131193, Acc = 0.950

Block 1 summary:
Energy components (atomic units):
KE: 0.754891
EE: 0.608928
EI: -3.216751
II: 0.715105
TOTAL: -1.137828
Acceptance ratio: 0.803

Starting block 2/10
Block 2, Step 0: E = -1.040040, Acc = 0.650

Block 2 summary:
Energy components (atomic units):
KE: 0.451954
EE: 0.521773
EI: -2.851694
II: 0.715105
TOTAL: -1.162863
Acceptance ratio: 0.713

Starting block 3/10
Block 3, Step 0: E = -1.053515, Acc = 0.500

Block 3 summary:
Energy components (atomic units):
KE: 0.203047
EE: 0.494712
EI: -2.440256
II: 0.715105
TOTAL: -1.027393
Acceptance ratio: 0.547

Starting block 4/10
Block 4, Step 0: E = -0.961665, Acc = 0.650

Block 4 summary:
Energy components (atomic units):
KE: 0.241882
EE: 0.478157
EI: -2.482720
II: 0.715105
TOTAL: -1.047576
Acceptance ratio: 0.564

Starting block 5/10
Block 5, Step 0: E = -0.893844, Acc = 0.550

Block 5 summary:
Energy components (atomic units):
KE: 0.426873
EE: 0.561529
EI: -2.721657

In [407]:
print(f"\nFinal Results after {nblocks} blocks:")
print(f"Energy: {mean_energy:.6f} ± {error_energy:.6f}")
print(f"Mean acceptance ratio: {mean_acceptance:.3f}")


Final Results after 10 blocks:
Energy: -1.057754 ± 0.019050
Mean acceptance ratio: 0.537
