# Variational Monte Carlo (VMC) for Ground State Energy

This notebook demonstrates how to use the custom VMC library to find the ground state energy of a quantum system using:
- Custom Hamiltonian builder from `lib/vmc/hamiltonian`
- Custom RBM architectures from `lib/vmc/rbm`, with connectivity patterns from `lib/vmc/connectivity`
- Custom samplers from `lib/vmc/samplers`

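We'll use the **Transverse Field Ising Model** (TFIM) as our example, written in the sign convention used by the code below:
$$H = J \sum_{i} Z_i Z_{i+1} + h \sum_i X_i$$
With $J < 0$ the Ising coupling is ferromagnetic.

The TFIM is a standard VMC benchmark: it has a quantum phase transition at $|h/J| = 1$ and is exactly solvable in 1D, so variational energies can be checked against exact results.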

## 1. Setup and Imports

First, let's import all necessary libraries and configure JAX.

In [None]:
import sys
import os

# Add the lib directory to the path so the custom `vmc` package can be imported
sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))

# JAX configuration: enable 64-bit floats
import jax
jax.config.update("jax_enable_x64", True)

import jax.random as random
import jax.numpy as jnp
import flax.linen as nn

# Standard libraries
import numpy as np
import matplotlib.pyplot as plt

# jVMC library
import jVMC

# Custom VMC library - Hamiltonian module
from vmc.hamiltonian import (
    build_hamiltonian_from_pauli_strings,
    print_hamiltonian_terms
)

# Custom VMC library - RBM module
from vmc.rbm import create_sparse_rbm

# Custom VMC library - Connectivity patterns
from vmc.connectivity import (
    fully_connected_pattern,
    local_connectivity_pattern,
    nearest_neighbor_pattern,
    stripe_pattern,
    random_sparse_pattern,
    ring_pattern,
    checkerboard_pattern,
    hierarchical_pattern,
    visualize_connectivity,
    print_connectivity_stats,
    compare_patterns
)

# Custom VMC library - Samplers and update proposers
from vmc.samplers import (
    custom_single_flip,
    custom_k_spin_flip,
    custom_domain_flip,
    custom_neighbor_swap,
    custom_cluster_flip,
    adaptive_proposer,
    create_sampler_with_proposer,
    test_proposer_efficiency
)

print("✓ All imports successful!")
print(f"JAX backend: {jax.default_backend()}")
print("\n📦 Available modules:")
print("  • vmc.hamiltonian - Build custom Hamiltonians from Pauli strings")
print("  • vmc.rbm - Create sparse RBM architectures")
print("  • vmc.connectivity - 8 connectivity patterns + visualization tools")
print("  • vmc.samplers - 6 update proposers + sampler utilities")

## 2. Define the Physical System

We'll study the 1D Transverse Field Ising Model (TFIM) with:
- System size: L = 8 spins
- Periodic boundary conditions
- J = -1.0 (ferromagnetic Ising coupling in the convention $H = J \sum Z Z + h \sum X$)
- h = -0.5 (transverse field strength, so $|h/J| = 0.5$)

With these values the Hamiltonian is:
$$H = -\sum_{i=0}^{L-1} Z_i Z_{(i+1) \bmod L} - 0.5 \sum_{i=0}^{L-1} X_i$$

In [None]:
# System parameters
L = 8  # Number of spins
J = -1.0  # Ising coupling (negative for ferromagnetic)
h = -0.5  # Transverse field strength

print("System: 1D Transverse Field Ising Model")
print(f"  Size: L = {L}")
print(f"  Ising coupling: J = {J}")
print(f"  Transverse field: h = {h}")
print("  Boundary conditions: Periodic")

## 3. Build the Hamiltonian

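Using our custom Hamiltonian builder, we construct the TFIM from a list of `(coefficient, pauli_string, sites)` terms: `(J, "ZZ", (i, j))` stands for $J Z_i Z_j$ and `(h, "X", (i,))` for $h X_i$. `build_hamiltonian_from_pauli_strings` then turns the list into a jVMC operator.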

In [None]:
# Build Hamiltonian terms
terms = []

# Add ZZ interactions (nearest neighbor with periodic BC)
for i in range(L):
    j = (i + 1) % L  # Periodic boundary conditions
    terms.append((J, "ZZ", (i, j)))

# Add transverse field (X terms)
for i in range(L):
    terms.append((h, "X", (i,)))

# Print the Hamiltonian
print("Hamiltonian Terms:")
print("="*60)
print_hamiltonian_terms(terms)
print("="*60)

# Build the jVMC Hamiltonian
hamiltonian = build_hamiltonian_from_pauli_strings(terms)
print("\n✓ Hamiltonian successfully constructed!")
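To make the term format concrete, here is a minimal NumPy sketch, independent of `vmc.hamiltonian`, of how a list of `(coefficient, pauli_string, sites)` tuples acts on a basis configuration and produces the local energy $E_{\text{loc}}(s) = \sum_{s'} \langle s|H|s'\rangle\, \psi(s')/\psi(s)$ that VMC averages over samples. It assumes the 0/1 spin encoding used by jVMC (0 corresponds to $Z = +1$), supports only `Z` and `X`, and assumes each site appears at most once per term; the helper names `apply_pauli_term` and `local_energy` are hypothetical.

```python
import numpy as np

def apply_pauli_term(s, pauli, sites):
    """Apply a Z/X Pauli string to the 0/1 basis state s (0 <-> Z = +1).

    Returns (matrix_element, s_new) with P|s> = matrix_element * |s_new>.
    Assumes each site appears at most once in the term.
    """
    s_new = np.array(s, copy=True)
    amp = 1.0
    for op, site in zip(pauli, sites):
        if op == "Z":
            amp *= 1.0 - 2.0 * s_new[site]   # +1 for 0, -1 for 1
        elif op == "X":
            s_new[site] = 1 - s_new[site]    # spin flip
        else:
            raise ValueError(f"unsupported Pauli operator {op!r}")
    return amp, s_new


def local_energy(s, terms, log_psi):
    """E_loc(s) = sum_{s'} <s|H|s'> psi(s') / psi(s) for real-coefficient Z/X terms."""
    log_psi_s = log_psi(s)
    e_loc = 0.0
    for coeff, pauli, sites in terms:
        amp, s_new = apply_pauli_term(s, pauli, sites)
        e_loc += coeff * amp * np.exp(log_psi(s_new) - log_psi_s)
    return e_loc


# Usage: TFIM terms in the same (coefficient, pauli_string, sites) format as above
L_demo, J_demo, h_demo = 8, -1.0, -0.5
demo_terms = [(J_demo, "ZZ", (i, (i + 1) % L_demo)) for i in range(L_demo)]
demo_terms += [(h_demo, "X", (i,)) for i in range(L_demo)]

# A constant amplitude psi(s) = 1 is the |+>^L product state; evaluate E_loc on all-up
print(local_energy(np.zeros(L_demo, dtype=int), demo_terms, lambda s: 0.0))  # -12.0
```

For the all-up configuration each Ising bond contributes $J = -1$ and each field term contributes $h = -0.5$ (the flipped configuration has the same amplitude), giving $-8 - 4 = -12$.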

## 4. Create the Neural Quantum State (RBM)

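We'll use a custom sparse RBM whose weights exist only for the chosen visible-hidden connections. The default below is the fully connected pattern; the commented-out alternative is a local pattern in which each visible unit connects only to nearby hidden units.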

In [None]:
# RBM parameters
num_hidden = 8  # Number of hidden units

# Create connectivity pattern
# Option 1: Fully connected (default)
connections = fully_connected_pattern(L, num_hidden)

# Option 2: Local connectivity (uncomment to use)
# connections = local_connectivity_pattern(L, num_hidden, locality=2)

print(f"RBM Configuration:")
print(f"  Visible units: {L}")
print(f"  Hidden units: {num_hidden}")
print(f"  Number of connections: {len(connections)}")
print(f"  Density: {len(connections) / (L * num_hidden):.2%} of all possible connections")

# Create the sparse RBM network
net = create_sparse_rbm(
    num_visible=L,
    num_hidden=num_hidden,
    connections=connections,
    bias=False
)

# Wrap in NQS (Neural Quantum State)
psi = jVMC.vqs.NQS(net, seed=1234)

print("\n✓ RBM neural quantum state created!")
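A minimal NumPy sketch of what a sparse RBM amplitude computes, assuming a connection list of `(visible, hidden)` index pairs, real weights, and the 0/1 spin encoding. It is not the implementation behind `create_sparse_rbm` (which may use complex parameters or a different spin mapping), and `local_pattern` is a hypothetical stand-in that only illustrates the idea of `local_connectivity_pattern`. Tracing out the hidden units gives $\log\psi(s) = \sum_j \log\cosh\big(b_j + \sum_i M_{ij} W_{ij}\sigma_i\big)$, where the mask $M$ zeroes out absent connections.

```python
import numpy as np

def local_pattern(num_visible, num_hidden, locality):
    """Hypothetical local pattern on a ring: hidden unit j connects to visible units
    within `locality` sites of its center c_j = j * num_visible // num_hidden."""
    conns = []
    for j in range(num_hidden):
        c = j * num_visible // num_hidden
        for v in range(num_visible):
            d = (v - c) % num_visible
            if min(d, num_visible - d) <= locality:   # periodic distance
                conns.append((v, j))
    return conns

def connection_mask(connections, num_visible, num_hidden):
    """Binary (num_visible, num_hidden) mask with ones at the (visible, hidden) pairs."""
    mask = np.zeros((num_visible, num_hidden))
    for v, j in connections:
        mask[v, j] = 1.0
    return mask

def sparse_rbm_log_psi(s, W, mask, b=None):
    """log psi(s) = sum_j log cosh(b_j + sum_i mask_ij W_ij sigma_i), sigma_i = 1 - 2 s_i."""
    sigma = 1.0 - 2.0 * np.asarray(s)      # map 0/1 encoding to +1/-1 spins
    theta = sigma @ (mask * W)             # hidden-unit pre-activations, shape (num_hidden,)
    if b is not None:
        theta = theta + b                  # optional hidden bias
    return np.sum(np.log(np.cosh(theta)))

# Usage
rng = np.random.default_rng(0)
conns = local_pattern(8, 8, locality=2)
mask = connection_mask(conns, 8, 8)
print(len(conns), mask.mean())             # 40 0.625
W = 0.1 * rng.normal(size=(8, 8))
s = rng.integers(0, 2, size=8)
print(sparse_rbm_log_psi(s, W, mask))
```

Without a bias the amplitude is invariant under flipping every spin, because $\cosh$ is even; this matches the global $Z_2$ symmetry of the TFIM.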

## 5. Set Up the Monte Carlo Sampler

We'll use a Metropolis-Hastings sampler with single spin flips.
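The single-flip Metropolis-Hastings rule can be sketched in plain NumPy. Each proposal flips one randomly chosen spin; because this proposal is symmetric, it is accepted with probability $\min\big(1, |\psi(s')|^2/|\psi(s)|^2\big)$, so the chains sample $|\psi(s)|^2$. This is an illustrative sketch under the 0/1 encoding, not how `create_sampler_with_proposer` or jVMC implement it (they run vectorized JAX chains); the function name is hypothetical.

```python
import numpy as np

def metropolis_single_flip(log_psi, L, num_chains, num_sweeps, rng):
    """Metropolis-Hastings sampling of |psi(s)|^2 with single-spin-flip proposals.

    Each sweep makes L proposals per chain. Returns the final configurations,
    shape (num_chains, L), and the overall acceptance rate.
    """
    s = rng.integers(0, 2, size=(num_chains, L))
    logp = np.array([2.0 * np.real(log_psi(c)) for c in s])   # log |psi|^2 per chain
    accepted, proposed = 0, 0
    for _ in range(num_sweeps):
        for _ in range(L):
            for c in range(num_chains):
                i = rng.integers(L)
                s_new = s[c].copy()
                s_new[i] = 1 - s_new[i]                        # propose a single flip
                logp_new = 2.0 * np.real(log_psi(s_new))
                # Accept with probability min(1, |psi(s')|^2 / |psi(s)|^2)
                if rng.random() < np.exp(min(0.0, logp_new - logp[c])):
                    s[c], logp[c] = s_new, logp_new
                    accepted += 1
                proposed += 1
    return s, accepted / proposed

# Usage: a toy state that strongly favors the all-up configuration (all zeros)
rng = np.random.default_rng(0)
toy_log_psi = lambda s: -2.0 * np.sum(s)
configs, acc_rate = metropolis_single_flip(toy_log_psi, L=8, num_chains=4, num_sweeps=25, rng=rng)
print(configs.mean(), acc_rate)
```

A very high or very low acceptance rate is a hint that a different proposer (for example the k-spin or cluster flips imported above) may decorrelate samples faster.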

In [None]:
# Sampler parameters
num_chains = 50          # Number of parallel Markov chains
num_samples = 5000       # Total samples per VMC step
thermalization = 25      # Thermalization sweeps
sweep_steps = L          # Steps per sweep

# Create sampler using custom proposer
sampler = create_sampler_with_proposer(
    psi=psi,
    system_size=L,
    proposer_type='single_flip',  # Standard single spin flip
    seed=4321,
    num_chains=num_chains,
    sweep_steps=sweep_steps,
    num_samples=num_samples,
    thermalization_sweeps=thermalization
)

print(f"Monte Carlo Sampler Configuration:")
print(f"  Number of chains: {num_chains}")
print(f"  Samples per step: {num_samples}")
print(f"  Thermalization: {thermalization} sweeps")
print(f"  Sweep steps: {sweep_steps}")
print("\n✓ Sampler configured!")

## 6. Run Variational Monte Carlo Optimization

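We find the ground state by imaginary-time evolution, $|\psi(\tau)\rangle \propto e^{-\tau H}|\psi(0)\rangle$, which damps every excited-state component and converges to the ground state as $\tau \to \infty$ (provided the initial state overlaps with it). The TDVP (Time-Dependent Variational Principle) projects this evolution onto the RBM parameters, and we integrate it with explicit Euler steps; each Euler step amounts to a stochastic-reconfiguration update.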
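A minimal NumPy sketch of one such Euler step, assuming real parameters and given per-sample log-derivatives $O_k(s) = \partial_{\theta_k}\log\psi(s)$ and local energies. It solves $S\,\dot\theta = -F$ with $S_{kl} = \langle O_k^* O_l\rangle - \langle O_k^*\rangle\langle O_l\rangle$ and $F_k = \langle O_k^* E_{\text{loc}}\rangle - \langle O_k^*\rangle\langle E_{\text{loc}}\rangle$. The plain absolute diagonal shift here is an assumption for illustration and may differ from how jVMC applies `diagonalShift`; jVMC additionally uses the `snrTol` cutoff, which this sketch omits. The function name is hypothetical.

```python
import numpy as np

def sr_euler_step(params, O, Eloc, tau=1e-2, diag_shift=1e-3):
    """One explicit-Euler step of imaginary-time TDVP (stochastic reconfiguration).

    params: (P,) real parameters theta
    O:      (N, P) per-sample log-derivatives d log psi(s) / d theta_k
    Eloc:   (N,) per-sample local energies
    Returns theta - tau * (S + diag_shift * 1)^{-1} F.
    """
    N, P = O.shape
    O_c = O - O.mean(axis=0)                     # centered log-derivatives
    E_c = Eloc - Eloc.mean()                     # centered local energies
    S = np.real(O_c.conj().T @ O_c) / N          # quantum geometric tensor (real part)
    F = np.real(O_c.conj().T @ E_c) / N          # energy gradient ("force")
    S_reg = S + diag_shift * np.eye(P)           # simple absolute shift (assumption)
    return params - tau * np.linalg.solve(S_reg, F)

# Usage with synthetic data (stand-ins for quantities a VMC code would sample)
rng = np.random.default_rng(0)
O_demo = rng.normal(size=(500, 4))
E_demo = rng.normal(size=500)
print(sr_euler_step(np.zeros(4), O_demo, E_demo, tau=1e-2))
```

If every local energy is the same, $F = 0$ and the parameters stop moving: this is the zero-variance property that makes the energy variance a useful convergence check.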

In [None]:
# VMC optimization parameters
n_steps = 200            # Number of optimization steps
time_step = 1e-2         # Time step for Euler integrator

# Set up TDVP equation
tdvpEquation = jVMC.util.tdvp.TDVP(
    sampler,
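    rhsPrefactor=1.0,    # Real prefactor: imaginary-time evolution (ground-state search)
    snrTol=1e-2,         # Signal-to-noise threshold for discarding noisy modes of the TDVP equation
    diagonalShift=10,    # Diagonal shift of the S matrix (regularization)
    makeReal='real'      # Keep the real part of the TDVP equation (used for ground-state search)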
)

# Set up Euler time stepper
stepper = jVMC.util.stepper.Euler(timeStep=time_step)

print(f"VMC Optimization Configuration:")
print(f"  Optimization steps: {n_steps}")
print(f"  Time step: {time_step}")
print(f"  Method: TDVP with Euler stepper")
print("\n" + "="*60)
print("Starting VMC optimization...")
print("="*60)

In [None]:
# Storage for results
results = []

# Main VMC loop
for step in range(n_steps):
    # Perform one optimization step
    dp, _ = stepper.step(
        0, 
        tdvpEquation, 
        psi.get_parameters(), 
        hamiltonian=hamiltonian, 
        psi=psi, 
        numSamples=None
    )
    
    # Update parameters
    psi.set_parameters(dp)
    
    # Extract energy and variance
    energy_mean = jax.numpy.real(tdvpEquation.ElocMean0) / L  # Energy per site
    energy_var = tdvpEquation.ElocVar0 / L  # Variance per site
    
    # Store results
    results.append([step, energy_mean, energy_var])
    
    # Print progress every 20 steps
    if step % 20 == 0 or step == n_steps - 1:
        print(f"Step {step:4d}: E/L = {energy_mean:+.6f}, Var/L = {energy_var:.6e}")

results = np.array(results)
print("="*60)
print("✓ VMC optimization completed!")

## 7. Analyze Results

Let's plot the energy convergence and compare it with the exact ground-state energy of the same finite chain. The energy variance is a second convergence check: it vanishes for an exact eigenstate, so Var/L should decrease as the optimization proceeds.

In [None]:
# Exact ground-state energy for comparison.
# For L = 8 the Hilbert space has only 2^8 = 256 states, so full exact diagonalization (ED) is cheap.
def exact_tfim_energy(L, J, h):
    """
    Ground-state energy per site of H = J sum_i Z_i Z_{i+1} + h sum_i X_i
    (periodic boundary conditions), computed by exact diagonalization.
    For |h/J| = 0.5 the system is in the ordered (ferromagnetic) phase.
    """
    X = np.array([[0.0, 1.0], [1.0, 0.0]])
    Z = np.array([[1.0, 0.0], [0.0, -1.0]])
    I2 = np.eye(2)

    def embed(ops):
        # Tensor product with ops[site] on the given sites and the identity elsewhere
        out = np.array([[1.0]])
        for site in range(L):
            out = np.kron(out, ops.get(site, I2))
        return out

    H = np.zeros((2**L, 2**L))
    for i in range(L):
        H += J * embed({i: Z, (i + 1) % L: Z})   # Ising bond (periodic BC)
        H += h * embed({i: X})                    # transverse field
    return np.linalg.eigvalsh(H)[0] / L          # eigvalsh returns eigenvalues in ascending order

exact_energy = exact_tfim_energy(L, J, h)
final_energy = results[-1, 1]

print("\nResults Summary:")
print(f"  Final VMC energy/L: {final_energy:.6f}")
print(f"  Exact (ED) energy/L: {exact_energy:.6f}")
print(f"  Difference: {abs(final_energy - exact_energy):.6f}")
print(f"  Relative error: {abs(final_energy - exact_energy) / abs(exact_energy):.2e}")
print(f"  Final variance/L: {results[-1, 2]:.6e}")

In [None]:
# Create visualization
fig, ax = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

# Plot energy convergence
ax[0].plot(results[:, 0], results[:, 1], 'b-', linewidth=2, label='VMC Energy')
ax[0].axhline(y=exact_energy, color='r', linestyle='--', linewidth=1.5,
              label='Exact (ED)')
ax[0].set_ylabel('Energy per site $E/L$', fontsize=12)
ax[0].legend(fontsize=10)
ax[0].grid(True, alpha=0.3)
ax[0].set_title(f'VMC Ground State Search: TFIM (L={L}, h={h})', fontsize=14)

# Plot variance (log scale)
ax[1].semilogy(results[:, 0], results[:, 2], 'g-', linewidth=2)
ax[1].set_xlabel('Optimization Step', fontsize=12)
ax[1].set_ylabel(r'Energy Variance per site $\mathrm{Var}(E)/L$', fontsize=12)  # raw string for LaTeX
ax[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('vmc_ground_state_convergence.pdf', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Plot saved as 'vmc_ground_state_convergence.pdf'")
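For the 1D TFIM any reference energy can be cross-checked against the Jordan-Wigner free-fermion solution. For a periodic chain with an even number of sites, the ground state lies in the even-parity sector, whose fermion momenta are $k = \pi(2n+1)/L$, giving $E_0 = -\sum_k \sqrt{J^2 + h^2 - 2|J||h|\cos k}$. For even $L$ the energy depends only on $|J|$ and $|h|$, since the sign of either coupling can be flipped by a spin rotation. The sketch below restricts itself to even $L$; the function name is hypothetical.

```python
import numpy as np

def tfim_free_fermion_energy(L, J_abs, h_abs):
    """Ground-state energy per site of the periodic 1D TFIM (even L) via Jordan-Wigner.

    E0 = -sum_k sqrt(J^2 + h^2 - 2 J h cos k),  k = pi (2n + 1) / L,  n = 0..L-1.
    """
    if L % 2:
        raise ValueError("this sketch only handles even L")
    k = np.pi * (2 * np.arange(L) + 1) / L          # antiperiodic fermion momenta
    e0 = -np.sum(np.sqrt(J_abs**2 + h_abs**2 - 2 * J_abs * h_abs * np.cos(k)))
    return e0 / L

# Usage: L = 8, |J| = 1, |h| = 0.5 as in Section 2 (compare with the ED value above)
print(tfim_free_fermion_energy(8, 1.0, 0.5))

# At the critical point |h| = |J| the sum has the closed form E0 = -2 / sin(pi / (2L))
print(tfim_free_fermion_energy(8, 1.0, 1.0), -2 / np.sin(np.pi / 16) / 8)
```

Because it only sums $L$ terms, this formula also gives exact references for chains far too long for exact diagonalization.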

## 8. Test the Converged State

Let's verify the converged wavefunction by sampling and computing observables.
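Samples in the $Z$ basis give $Z$-diagonal observables directly; off-diagonal ones such as the transverse magnetization need amplitude ratios, $\langle X_i\rangle = \big\langle \psi(s^{(i)})/\psi(s) \big\rangle_{s \sim |\psi|^2}$, where $s^{(i)}$ is $s$ with spin $i$ flipped. A minimal NumPy sketch, assuming a plain Python `log_psi` callable and the 0/1 encoding (it is not wired to the jVMC `psi` object; the function name is hypothetical):

```python
import numpy as np

def transverse_magnetization(configs, log_psi):
    """Estimate <X>/L = (1/L) sum_i < psi(s with spin i flipped) / psi(s) > over samples.

    configs: array of 0/1 configurations sampled from |psi|^2, last axis = sites
    log_psi: function mapping one configuration to log psi(s)
    """
    configs = np.asarray(configs).reshape(-1, np.shape(configs)[-1])
    N, L = configs.shape
    total = 0.0
    for s in configs:
        lp = log_psi(s)
        for i in range(L):
            s_flip = s.copy()
            s_flip[i] = 1 - s_flip[i]                # flip spin i
            total += np.exp(log_psi(s_flip) - lp)   # amplitude ratio psi(s') / psi(s)
    return np.real(total) / (N * L)

# Usage: for a constant psi (the |+>^L state) every configuration is an equally valid sample
print(transverse_magnetization(np.zeros((5, 8), dtype=int), lambda s: 0.0))  # 1.0
```

By translation invariance, combining this with $\langle Z_i Z_{i+1}\rangle$ reproduces the energy per site, $E/L = J\langle Z_i Z_{i+1}\rangle + h\langle X_i\rangle$.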

In [None]:
# Generate samples from the converged state (configurations and their log-amplitudes)
sample_configs, sample_logpsi, _ = sampler.sample()

print(f"Generated samples: {sample_configs.shape}")
print("Sample configurations (first 5):")
print(sample_configs.reshape(-1, L)[:5])

# Magnetization along Z: configurations are stored in the Z basis as 0/1 (0 <-> Z = +1)
spins_z = 1 - 2 * sample_configs      # map 0/1 to +1/-1
mag_z = jnp.mean(spins_z, axis=-1)    # M_z / L for each sample (average over sites)
avg_mag = jnp.mean(jnp.abs(mag_z))

print(f"\n<|M_z|>/L: {avg_mag:.4f}")
print("(Approaches 1 deep in the ordered phase and is small in the paramagnetic phase;")
print(" quantum fluctuations and finite size keep it away from exactly 1 or 0)")

# Compute local energy statistics
Eloc = hamiltonian.get_O_loc(sample_configs, psi, sample_logpsi)
E_mean = jnp.mean(jnp.real(Eloc)) / L
E_std = jnp.std(jnp.real(Eloc)) / L

print("\nLocal Energy Statistics:")
print(f"  Mean E/L: {E_mean:.6f}")
print(f"  Std  E_loc/L: {E_std:.6f}")
print(f"  Relative spread of E_loc: {E_std / abs(E_mean) * 100:.2f}%")
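The standard deviation of $E_{\text{loc}}$ measures the spread of local energies, not the uncertainty of the mean energy. A common estimate of that uncertainty for correlated Markov-chain data is binning (blocking): average consecutive samples in bins long enough to absorb autocorrelations, then take the standard error of the bin means. The sketch below assumes consecutive entries come from the same chain; check how your sampler orders its output before relying on it. The function name is hypothetical.

```python
import numpy as np

def binned_standard_error(x, n_bins=20):
    """Standard error of mean(x) estimated from n_bins equal-size bins of consecutive samples."""
    x = np.asarray(x, dtype=float).ravel()
    bin_len = len(x) // n_bins
    if bin_len < 1:
        raise ValueError("need at least n_bins samples")
    # Drop the remainder so every bin has the same length, then average within bins
    bin_means = x[: bin_len * n_bins].reshape(n_bins, bin_len).mean(axis=1)
    return bin_means.std(ddof=1) / np.sqrt(n_bins)

# Usage with independent synthetic data: should be close to 1 / sqrt(10_000) = 0.01
rng = np.random.default_rng(0)
print(binned_standard_error(rng.normal(size=10_000)))
# With the notebook's data one would call: binned_standard_error(np.real(Eloc)) / L
```

If the estimate keeps growing as `n_bins` decreases (bins get longer), the samples are correlated over longer stretches than the bins cover.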

## Summary

In this notebook, we successfully:

1. ✅ Built a custom Hamiltonian using the `vmc.hamiltonian` module
2. ✅ Created a sparse RBM neural quantum state using `vmc.rbm`
3. ✅ Set up a Monte Carlo sampler using `vmc.samplers`
4. ✅ Ran VMC optimization using TDVP to find the ground state
5. ✅ Analyzed convergence and computed observables

### Key Features Demonstrated

- **Modular design**: Each component (Hamiltonian, RBM, sampler) is independent
- **Flexibility**: Easy to change system size, connectivity patterns, or update proposers
- **Integration**: Seamlessly works with jVMC's optimization machinery
- **Reproducibility**: All random seeds are controlled

### Next Steps

Try modifying:
- System size `L`
- Transverse field strength `h` (try crossing the quantum phase transition at |h/J| = 1)
- RBM connectivity pattern (local vs. fully connected)
- Sampler update strategy (k-spin flip, cluster updates)
- Number of hidden units
