In [1]:
# Reload edited modules (e.g. the local ad_afqmc package) automatically before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import numpy as np
import os

from jax import jit, numpy as jnp
import pickle

# Make the parent of the current working directory (presumably the repository
# root) importable so that `ad_afqmc` can be imported below. The path is
# appended, so entries already on sys.path keep precedence.
module_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

from ad_afqmc import driver, pyscf_interface, mpi_jax, wavefunctions

from pyscf import fci, gto, scf

from typing import Tuple
from dataclasses import dataclass

# Compact numpy output: 5 decimal places, fixed-point notation for small values.
np.set_printoptions(suppress=True, precision=5)



# Hostname: MacBook-Air-548
# System Type: Darwin
# Machine Type: x86_64
# Processor: i386


### Minimal RHF implementation

Only the overlap needs to be defined, plus `__hash__` for a technical reason: `@dataclass` otherwise makes instances unhashable. A simple definition like the one below is sufficient.


In [3]:
@dataclass
class rhf(wavefunctions.wave_function_auto):
    """Minimal restricted Hartree-Fock trial wave function.

    Only the walker overlap is implemented here; all other functionality is
    inherited from ``wavefunctions.wave_function_auto``.
    """

    norb: int  # number of orbitals (mol.nao in this notebook)
    nelec: Tuple[int, int]  # (n_up, n_down) electron counts, e.g. mol.nelec

    def _calc_overlap(
        self, walker_up: jnp.ndarray, walker_dn: jnp.ndarray, wave_data: dict
    ) -> jnp.ndarray:
        """Return det(C.T @ walker_up) * det(C.T @ walker_dn).

        ``C = wave_data["mo_coeff"]`` is used for both spin channels. It is
        expected to hold the trial orbitals as columns (in this notebook, the
        first nelec[0] columns of the identity).
        """
        orbitals = wave_data["mo_coeff"]
        det_up = jnp.linalg.det(orbitals.T @ walker_up)
        det_dn = jnp.linalg.det(orbitals.T @ walker_dn)
        return det_up * det_dn

    def __hash__(self):
        # @dataclass (eq=True, not frozen) would otherwise set __hash__ to None,
        # so hash the instance attributes explicitly. Presumably the framework
        # needs hashable trial objects (e.g. as static JAX args) -- TODO confirm.
        return hash(tuple(vars(self).values()))

In [None]:
# NH3 geometry: N at the origin, the three H atoms in the plane z = rz, each at
# distance r from N, with H-N-H angle theta.
r = 1.012  # N-H bond length (PySCF's default unit is Angstrom)
theta = 106.67 * np.pi / 180.0  # H-N-H angle, degrees -> radians
sin_half = np.sin(theta / 2)
cos_half = np.cos(theta / 2)
rz = r * np.sqrt(cos_half**2 - sin_half**2 / 3)
dc = 2 * r * sin_half / np.sqrt(3)
atomstring = f"""
                 N 0. 0. 0.
                 H 0. {dc} {rz}
                 H {r * sin_half} {-dc/2} {rz}
                 H {-r * sin_half} {-dc/2} {rz}
              """
mol = gto.M(atom=atomstring, basis="sto-6g", verbose=3, symmetry=0)
mf = scf.RHF(mol)
mf.kernel()

# Exact (FCI) energy in the same basis, as a reference for the AFQMC result.
cisolver = fci.FCI(mol, mf.mo_coeff)
efci, ci = cisolver.kernel()
print(f"FCI energy: {efci}")

# AFQMC is performed in the RHF MO basis, where the occupied orbitals are
# simply the first nelec[0] unit vectors.
trial = rhf(mol.nao, mol.nelec)
wave_data = dict(mo_coeff=jnp.eye(mol.nao)[:, : mol.nelec[0]])

# Write the trial to disk (presumably read back by the AFQMC driver -- TODO confirm).
with open("trial.pkl", "wb") as f:
    pickle.dump([trial, wave_data], f)

pyscf_interface.prep_afqmc(mf)
options = dict(
    dt=0.005,
    n_eql=3,
    n_ene_blocks=1,
    n_sr_blocks=5,
    n_blocks=100,
    n_walkers=50,
    seed=98,
    walker_type="rhf",
)
e_afqmc, err_afqmc = driver.afqmc(*mpi_jax._prep_afqmc(options))

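A simple way to speed up the calculation above is to also define the overlap for restricted walkers, i.e., walkers whose spin-up and spin-down determinants are the same. For these walkers, the overlap is just the square of a single determinant.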


In [4]:
@dataclass
class rhf_faster(rhf):
    """`rhf` trial with a dedicated overlap for restricted walkers.

    For a restricted walker the spin-up and spin-down determinants are the
    same, so the overlap is the square of a single determinant.
    """

    def _calc_overlap_restricted(
        self, walker: jnp.ndarray, wave_data: dict
    ) -> jnp.ndarray:
        """Return det(mo_coeff.T @ walker) ** 2."""
        single_spin = jnp.linalg.det(wave_data["mo_coeff"].T @ walker)
        return single_spin**2

    def __hash__(self) -> int:
        # Must be redeclared: @dataclass on this subclass would otherwise set
        # __hash__ to None instead of keeping the one inherited from rhf.
        return hash(tuple(vars(self).values()))

In [None]:
# Same AFQMC settings as the previous run, now with the restricted-walker overlap.
trial = rhf_faster(mol.nao, mol.nelec)
# AFQMC is performed in the RHF MO basis: occupied orbitals are unit vectors.
wave_data = dict(mo_coeff=jnp.eye(mol.nao)[:, : mol.nelec[0]])

# Write the trial to disk (presumably read back by the AFQMC driver -- TODO confirm).
with open("trial.pkl", "wb") as f:
    pickle.dump([trial, wave_data], f)

pyscf_interface.prep_afqmc(mf)
options = dict(
    dt=0.005,
    n_eql=3,
    n_ene_blocks=1,
    n_sr_blocks=5,
    n_blocks=100,
    n_walkers=50,
    seed=98,
    walker_type="rhf",
)
e_afqmc, err_afqmc = driver.afqmc(*mpi_jax._prep_afqmc(options))