## Import Libraries

In [1]:
import os
import numpy as np
from collections import Counter
from itertools import product
import random

from pyscf import gto, scf, lo

import netket as nk
import netket.experimental as nkx

import flax.linen as nn
import jax.numpy as jnp
import jax

## Define Molecular

In [2]:
mol = gto.M(atom='H 0 0 0; H 0 0 0.75; H 0 0 1.5; H 0 0 2.25', basis='sto-3g')   # H4 chain for example case
mol.build()
mf = scf.RHF(mol).run()
n_orb = mf.mo_coeff.shape[1]

print(mf.mo_coeff)   # Molecular Orbital Coefficients
'''
Represents the transformation matrix from atomic orbitals to molecular orbitals
'''
print("number of orbitals:", n_orb)

converged SCF energy = -2.10329082299987
[[ 0.23646159  0.54410393  0.90037498 -0.87827586]
 [ 0.40172325  0.39852889 -0.63532245  1.63712543]
 [ 0.40172325 -0.39852889 -0.63532245 -1.63712543]
 [ 0.23646159 -0.54410393  0.90037498  0.87827586]]
number of orbitals: 4


## FCI Energy

In [3]:
from pyscf import fci

cisolver = fci.FCI(mol, mf.mo_coeff)
fci_energy = cisolver.kernel()[0]

print("FCI Energy:", fci_energy)

FCI Energy: -2.14511064718622


## Partition Molecular

In [4]:
n_active = 2
n_bath = n_orb - n_active
mo_coeff = lo.Boys(mol).kernel(mf.mo_coeff)   # Localized molecular orbitals
mo_active = mo_coeff[:, :n_active]    # How molecular orbitals are linearly combined from atomic orbitals
mo_bath = mo_coeff[:, n_active:]

print("active space:", mo_active)
print("bath space:", mo_bath)

active space: [[ 0.17290591  0.46998491]
 [ 0.44341077  0.53187291]
 [ 0.44877447 -0.53147553]
 [ 0.17013577 -0.4702411 ]]
bath space: [[ 0.91489617 -0.91988559]
 [-0.60462494  1.59967119]
 [-0.60536158 -1.59802807]
 [ 0.91504347  0.92012463]]


In [5]:
# Hamiltonian in atomic orbital space
h_ao = mf.get_hcore()
# Project to active MO space → get active space Hamiltonian
h_active = mo_active.T @ h_ao @ mo_active

## Define the second-quantized Hamiltonian acting on the active space

In [6]:
H = nkx.operator.from_pyscf_molecule(mol, mo_coeff=mo_active)  # The second-quantized Hamiltonian acting on the active space
print(H)

FermionOperator2nd(hilbert=SpinOrbitalFermions(n_orbitals=2, s=1/2, n_fermions=4, n_fermions_per_spin=(2, 2)), n_operators=27, dtype=float64)


## Define Hilbert Space

In [7]:
hi_active = H.hilbert

## Conditional Variational Wave Function Model $\alpha(\sigma \mid \eta)$

In [8]:
# class ConditionalMLP(nn.Module):
#     '''
#     Return: log psi(sigma | eta)
#     '''
#     n_active: int
#     n_bath: int

#     @nn.compact
#     def __call__(self, sigma_eta):
#         x = sigma_eta
#         x = nn.Dense(64)(x)
#         x = nn.tanh(x)
#         x = nn.Dense(64)(x)
#         x = nn.tanh(x)
#         x = nn.Dense(1)(x)
#         return jnp.squeeze(x, axis=-1)   # output: log amplitude

In [9]:
class ConditionalMLP(nn.Module):
    '''
    Return: log ψ(σ | η)
    '''
    n_active: int
    n_bath: int
    hidden_dim: int = 64  # tunable

    @nn.compact
    def __call__(self, sigma_eta):
        x = sigma_eta.astype(jnp.float32)  
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)                     
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)

        # log-amplitude
        return jnp.squeeze(x, axis=-1)

## Bath Configuration $\beta(\eta)$

In [10]:
def sample_bath_distribution(n_bath, n_samples):
    configs = list(product([0, 1], repeat=n_bath))
    sampled = random.choices(configs, k=n_samples)
    counter = Counter(sampled)
    result = [(np.array(eta), count / n_samples) for eta, count in counter.items()]  # (eta, p_eta)
    return result

eta_distribution = sample_bath_distribution(n_bath, 500)

## Use Variatioanl Monte Carlo (VMC) to approximate ground state energy

In [11]:
# class ConditionedState(nk.vqs.MCState):
#     def __init__(self, sampler, model, eta, n_samples=1024, debug=False):
#         self.eta = eta
#         self.debug = debug
#         super().__init__(sampler, model, n_samples=n_samples)   # sigma.shape = (1024, n_active)

#     def log_prob(self, sigma):
#         '''
#         Sigma shape: (n_chains, samples_per_chain, n_active)
#         '''
#         # sigma_eta = jnp.concatenate([sigma, jnp.tile(self.eta, (sigma.shape[0], 1))], axis=1)

#         eta_broadcasted = jnp.broadcast_to(self.eta, sigma.shape[:-1] + self.eta.shape)
#         sigma_eta = jnp.concatenate([sigma, eta_broadcasted], axis=-1)

#         # Debug: 印出 sigma 和 sigma_eta
#         if self.debug:
#             print("\nSampled sigma (first 1):")
#             print(sigma[:1])
#             print("Corresponding sigma_eta (first 1):")
#             print(sigma_eta[:1])

#         # return self.model(sigma_eta)
#         return self.model.apply(self.params, sigma_eta)

In [12]:
class ConditionedState(nk.vqs.MCState):
    def __init__(self, sampler, model, params, eta_array, n_samples=1024, debug=False):
        self.debug = debug

        def apply_fun(params, sigma, **kwargs):
            eta_broadcasted = jnp.broadcast_to(eta_array, sigma.shape[:-1] + eta_array.shape)
            sigma_eta = jnp.concatenate([sigma, eta_broadcasted], axis=-1)

            if debug:
                print("sigma shape:", sigma.shape)
                print("sigma_eta shape:", sigma_eta.shape)

            return model.apply(params, sigma_eta, **kwargs)


        super().__init__(
            sampler=sampler,
            apply_fun=apply_fun,
            variables=params,
            n_samples=n_samples,
        )

In [13]:
E_total = 0.0
E_eta_list = [] # record each eta expectation

for eta_array, p_eta in eta_distribution:
    sampler = nk.sampler.MetropolisLocal(hi_active)

    n_active = hi_active.size  
    n_bath = eta_array.shape[0]

    # print(n_active, n_bath)
    
    flax_model = ConditionalMLP(n_active=n_active, n_bath=n_bath)

    # Mpdel Initialization
    dummy_sigma = jnp.zeros((1, n_active))
    dummy_eta = jnp.reshape(eta_array, (1, n_bath)) 
    dummy_input = jnp.concatenate([dummy_sigma, dummy_eta], axis=-1)
    params = flax_model.init(jax.random.PRNGKey(0), dummy_input)

    vstate = ConditionedState(sampler, flax_model, params, eta_array)

    # VMC optimization
    opt = nk.optimizer.Sgd(learning_rate=0.01)
    sr = nk.optimizer.SR(diag_shift=0.1)
    driver = nk.driver.VMC(H, optimizer=opt, variational_state=vstate, preconditioner=sr)
    driver.run(n_iter=500, show_progress=False)  

    # Estimate E_η
    energy_stats = vstate.expect(H)
    E_eta = energy_stats.mean.item()  
    weighted_energy = p_eta * E_eta

    print(f"E_eta: {E_eta:.6f}, p_eta: {p_eta:.4f}, weighted: {weighted_energy:.6f}")
    E_eta_list.append((eta_array, p_eta, E_eta))

    E_total += weighted_energy

E_eta: -2.060501, p_eta: 0.2320, weighted: -0.478036
E_eta: -2.062798, p_eta: 0.2260, weighted: -0.466192
E_eta: -2.063180, p_eta: 0.2940, weighted: -0.606575
E_eta: -2.063231, p_eta: 0.2480, weighted: -0.511681


In [14]:
print(E_eta_list)
print(E_total)

[(array([1, 1]), 0.232, -2.0605005766899587), (array([0, 0]), 0.226, -2.0627977021436754), (array([0, 1]), 0.294, -2.0631804075442433), (array([1, 0]), 0.248, -2.063231399685427)]
-2.0624848414165347
