In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import numpy as np
import scipy as sp
import os
from scipy.linalg import fractional_matrix_power

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "True"
from jax import vmap, jit, numpy as jnp, random, lax, jvp, scipy as jsp
import matplotlib.pyplot as plt
import pickle
from functools import reduce

import matplotlib.animation as animation
from IPython.display import HTML

# Location of a local ad_afqmc checkout. Set the AD_AFQMC_PATH environment
# variable to point somewhere else instead of editing this notebook; the
# original developer path is kept as the fallback so existing setups still work.
module_path = os.path.abspath(
    os.environ.get("AD_AFQMC_PATH", "/Users/ankitmahajan/softwares/ad_afqmc")
)
if module_path not in sys.path:
    sys.path.append(module_path)

from ad_afqmc import driver, pyscf_interface, mpi_jax, linalg_utils, lattices

from pyscf import fci, gto, scf, mp, ao2mo

import itertools
from functools import partial

np.set_printoptions(precision=5, suppress=True)

In [None]:
@jit
def calc_s2(sd_up, sd_dn):
    """Return the S^2 expectation value computed from two orbital-coefficient matrices.

    Each argument is a 2D array whose columns are the occupied orbitals of one
    spin channel; the columns need not be orthonormal because the inverse
    overlap matrix is applied. The column counts are used as the electron
    numbers of the two channels.

    The value returned is ``<S+ S-> + Sz * (Sz - 1)`` with
    ``<S+ S-> = n_up - tr(G_up G_dn)``.
    """

    def _transposed_projector(orbitals):
        # C (C^dagger C)^-1 C^dagger, transposed.
        dagger = orbitals.T.conj()
        overlap_inv = jnp.linalg.inv(dagger.dot(orbitals))
        return orbitals.dot(overlap_inv).dot(dagger).T

    green_up = _transposed_projector(sd_up)
    green_dn = _transposed_projector(sd_dn)
    n_up = sd_up.shape[1]
    n_dn = sd_dn.shape[1]
    spsm = n_up - jnp.trace(green_up.dot(green_dn))
    sz = (n_up - n_dn) / 2
    return spsm + sz * (sz - 1)


@jit
def calc_green(sd):
    """Return the transpose of ``sd (sd^dagger sd)^-1 sd^dagger``.

    ``sd`` is a 2D array whose columns are orbitals; the inverse overlap makes
    the result independent of the normalisation of those columns.
    """
    sd_dagger = sd.T.conj()
    overlap_inv = jnp.linalg.inv(sd_dagger.dot(sd))
    projector = sd.dot(overlap_inv).dot(sd_dagger)
    return projector.T


@jit
def calc_norm(sd):
    """Return the determinant of the overlap matrix ``sd^dagger sd``."""
    overlap = sd.T.conj().dot(sd)
    return jnp.linalg.det(overlap)

In [None]:
@jit
def normalize(walker):
    """Scale every column of the 2D array ``walker`` to unit Euclidean norm."""

    def _unit_column(column):
        return column / jnp.linalg.norm(column)

    # Map over axis 1 (columns) and put the results back on axis 1.
    return vmap(_unit_column, in_axes=1, out_axes=1)(walker)

In [None]:
@partial(jit, static_argnums=(3,))
def ham_element_hubbard(x, y, u, lattice):
    """Hubbard matrix element between two occupation-number configurations.

    Args:
        x, y: integer arrays indexed as ``[spin, site]`` holding 0/1 occupations.
        u: on-site interaction strength.
        lattice: object exposing ``n_sites``; treated as a static (hashable)
            argument by ``jit``.

    Returns:
        ``u`` times the number of doubly occupied sites of ``x`` when the two
        configurations are identical, plus a hopping term when they differ in
        exactly two entries: -1.0 if the two differing sites are adjacent
        (index difference 1), and a signed +/-1.0 if the index difference is
        ``n_sites - 1`` (the bond that closes the ring).
    """
    chain_length = lattice.n_sites
    electrons_per_spin = jnp.array((jnp.sum(x[0]), jnp.sum(x[1])))

    # Entries that differ between the two configurations, per spin channel.
    changed = jnp.array(
        (jnp.bitwise_xor(x[0], y[0]), jnp.bitwise_xor(x[1], y[1]))
    )
    changed_per_spin = jnp.array((jnp.sum(changed[0]), jnp.sum(changed[1])))
    n_changed = jnp.sum(changed_per_spin)

    # Diagonal element: count sites occupied in both spin channels.
    on_site = (n_changed == 0) * u * jnp.bitwise_and(x[0], x[1]).sum()

    # size=2 keeps the output shape static under jit; only meaningful when
    # exactly two entries differ, which the (n_changed == 2) mask enforces.
    changed_index = jnp.nonzero(changed, size=2)
    first_site = changed_index[1][0]
    second_site = changed_index[1][1]
    separation = second_site - first_site

    # Sign picked up when hopping across the ring-closing bond; the exponent
    # depends on the electron count of the spin channel that changed.
    wrap_sign = (-1.0) ** (
        1 + jnp.sum((electrons_per_spin - 1) * changed_per_spin / 2)
    )
    hopping = (n_changed == 2) * (
        (separation == 1) * -1.0 + (separation == chain_length - 1) * wrap_sign
    )
    return on_site + hopping


def make_basis_spin(n_sites, n_elec):
    """Enumerate all 0/1 occupation vectors of length ``n_sites`` with ``n_elec`` ones.

    The vectors are produced in ascending lexicographic order, starting from
    the one with all electrons on the last sites, by repeatedly applying the
    standard next-permutation step.

    Returns:
        Integer array of shape ``(n_configurations, n_sites)``.
    """
    occ = np.zeros(n_sites)
    for offset in range(1, n_elec + 1):
        occ[-offset] = 1
    configs = [occ.copy()]

    while True:
        # Pivot: right-most position whose right neighbour is larger.
        pivot = next(
            (p for p in range(n_sites - 2, -1, -1) if occ[p] < occ[p + 1]), -1
        )
        if pivot == -1:
            # Sequence is non-increasing: this was the last permutation.
            break
        # Right-most position after the pivot holding a larger value; one
        # always exists because occ[pivot] < occ[pivot + 1].
        swap = next(
            q for q in range(n_sites - 1, pivot, -1) if occ[pivot] < occ[q]
        )
        occ[pivot], occ[swap] = occ[swap], occ[pivot]
        # Reverse the tail so it becomes the smallest arrangement.
        occ[pivot + 1 :] = occ[pivot + 1 :][::-1].copy()
        configs.append(occ.copy())

    return np.array(configs, dtype=int)


def make_basis(n_sites, n_elec: tuple):
    """Build the two-spin configuration basis as a product of single-spin bases.

    Args:
        n_sites: number of lattice sites.
        n_elec: electron counts; element 0 is used for the first spin channel
            and element 1 for the second.

    Returns:
        Array of shape ``(n_first * n_second, 2, n_sites)`` where the first
        spin configuration varies slowest.
    """
    n_first, n_second = n_elec[0], n_elec[1]
    first_configs = make_basis_spin(n_sites, n_first)
    # Reuse the enumeration when both channels hold the same electron count.
    second_configs = (
        first_configs if n_first == n_second else make_basis_spin(n_sites, n_second)
    )
    pairs = [pair for pair in itertools.product(first_configs, second_configs)]
    return np.array(pairs)

In [None]:
# Exact diagonalization of a small Hubbard chain in the configuration basis.
n_sites = 4
n_elec = (2, 2)  # electron counts per spin channel, passed to make_basis
u = 4.0  # on-site interaction strength passed to ham_element_hubbard
lattice = lattices.one_dimensional_chain(n_sites)
# int32 so the bitwise operations inside ham_element_hubbard apply.
ci_basis = jnp.array(make_basis(n_sites, n_elec), dtype=jnp.int32)
print(f"built basis, length: {len(ci_basis)}")
# Bind u and lattice by keyword so only the two configurations stay positional.
ham_element = partial(ham_element_hubbard, u=u, lattice=lattice)
# Outer vmap runs over the first argument, inner vmap over the second, giving
# the full matrix of elements between every pair of basis configurations.
ham_mat = vmap(vmap(ham_element, (None, 0)), (0, None))(ci_basis, ci_basis)
print("built hamiltonian")
# eigh returns eigenvalues in ascending order, so the slice shows the lowest ones.
ene_mat, vec = jnp.linalg.eigh(ham_mat)
print(ene_mat[:10])

In [None]:
# Six-site Hubbard model: build the integrals by hand, run RHF, FCI and UHF
# through PySCF with overridden integrals, then run AFQMC with ad_afqmc.
n_sites = 6
n_elec = (3, 3)
u = 4.0
integrals = {}
integrals["h0"] = 0.0
# toeplitz(unit_vec) is symmetric with -1.0 on the first off-diagonals,
# i.e. nearest-neighbour hopping along an open chain.
unit_vec = np.zeros(n_sites)
unit_vec[1] = -1.0
integrals["h1"] = sp.linalg.toeplitz(unit_vec)
# pbc
integrals["h1"][0, -1] = -1.0
integrals["h1"][-1, 0] = -1.0
# 2 x 4 grid
# integrals["h1"][0, 3] = -1.0
# integrals["h1"][3, 0] = -1.0
# integrals["h1"][4, 7] = -1.0
# integrals["h1"][7, 4] = -1.0
# integrals["h1"][1, 6] = -1.0
# integrals["h1"][6, 1] = -1.0
# integrals["h1"][2, 5] = -1.0
# integrals["h1"][5, 2] = -1.0
# 2 x 3 grid
integrals["h1"][1, 4] = -1.0
integrals["h1"][4, 1] = -1.0
# integrals["h1"][0, 2] = -1.0
# integrals["h1"][2, 0] = -1.0
# integrals["h1"][3, 5] = -1.0
# integrals["h1"][5, 3] = -1.0
# Two-body integrals: only the fully diagonal (i, i, i, i) entries are non-zero.
h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = u
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

# dummy molecule
mol = gto.Mole()
mol.nelectron = sum(n_elec)
mol.incore_anyway = True
mol.build()

# rhf
# The mean-field object is fed the model integrals by overriding its
# core-Hamiltonian / overlap getters and assigning _eri directly.
mf = scf.RHF(mol)
mf.get_hcore = lambda *args: integrals["h1"]
mf.get_ovlp = lambda *args: np.eye(n_sites)
mf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
mf.kernel()

# fci from integrals
ci = fci.FCI(mol)
e, ci_coeffs = ci.kernel(
    h1e=integrals["h1"], eri=integrals["h2"], norb=n_sites, nelec=n_elec
)
print(f"fci energy: {e}")

# uhf
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
# Zeroed array with the shape of the one-electron guess, then a staggered
# starting density: spin index 0 on even sites, spin index 1 on odd sites.
dm_init = 0.0 * umf.init_guess_by_1e()
for i in range(n_sites // 2):
    dm_init[0, 2 * i, 2 * i] = 1.0
    dm_init[1, 2 * i + 1, 2 * i + 1] = 1.0
# NOTE(review): this noise comes from NumPy's global RNG and no seed is set in
# the visible cells, so the UHF starting point differs between runs.
dm_init += 0.1 * np.random.randn(*dm_init.shape)
umf.kernel(dm_init)

# ad afqmc
pyscf_interface.prep_afqmc(umf, mo_coeff=np.eye(n_sites), integrals=integrals)
options = {
    "dt": 0.005,
    "n_eql": 5,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_walkers": 50,
    "seed": 98,
    "walker_type": "uhf",
    "save_walkers": True,
}

# NOTE(review): _prep_afqmc receives only `options`, so it presumably picks up
# whatever prep_afqmc prepared above — confirm in ad_afqmc.
ham_data, ham, prop, trial, wave_data, observable, options = mpi_jax._prep_afqmc(
    options
)
e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, observable, options
)

In [None]:
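# 2D square lattice with periodic boundaries.
# Sites, rows and columns are all 1-based; sites are numbered along a serpentine
# path (odd rows run left to right, even rows run right to left).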
def findSiteAtRowNCol(row, col, size):
    """Return the 1-based site number at 1-based (row, col) on a size-wide grid.

    Odd rows are numbered with increasing column, even rows with decreasing
    column, so consecutive site numbers trace a serpentine path.
    """
    on_odd_row = row % 2 == 1
    return (row - 1) * size + col if on_odd_row else row * size - (col - 1)


def findRowNColAtSite(site, size):
    """Return ``[row, col]`` (both 1-based) for a 1-based site number.

    Inverse of the serpentine numbering: odd rows count columns upward,
    even rows count them downward.
    """
    full_rows, offset = divmod(site - 1, size)
    row = full_rows + 1
    col = offset + 1 if row % 2 == 1 else size - offset
    return [row, col]


def findNeighbors(site, size):
    """Return the four neighbours of ``site`` on a periodic size x size grid.

    The result lists 1-based site numbers in the order up, left, down, right;
    a neighbour that would fall off an edge wraps to the opposite edge.
    """
    row, col = findRowNColAtSite(site, size)

    # Wrap each step around the grid edges.
    row_above = size if row == 1 else row - 1
    col_left = size if col == 1 else col - 1
    row_below = 1 if row == size else row + 1
    col_right = 1 if col == size else col + 1

    return [
        findSiteAtRowNCol(row_above, col, size),
        findSiteAtRowNCol(row, col_left, size),
        findSiteAtRowNCol(row_below, col, size),
        findSiteAtRowNCol(row, col_right, size),
    ]


# Hubbard model on a periodic size x size square lattice: build integrals,
# run RHF and UHF through PySCF with overridden integrals, then AFQMC.
size = 4
n_sites = size**2
u = 4.0
n_elec = (8, 8)

integrals = {}
integrals["h0"] = 0.0

# One-body term: -1.0 between every site and each of its lattice neighbours.
# findNeighbors works with 1-based site numbers, hence the index shifts.
h1 = np.zeros((n_sites, n_sites))
for i in range(n_sites):
    currentSite = i + 1
    neighbors = findNeighbors(currentSite, size)
    for neighbor in neighbors:
        h1[currentSite - 1, neighbor - 1] = -1.0
integrals["h1"] = h1

# Two-body integrals: only the fully diagonal (i, i, i, i) entries are non-zero.
h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = u
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

# make dummy molecule
mol = gto.Mole()
mol.nelectron = sum(n_elec)
mol.incore_anyway = True
mol.build()

# The mean-field objects are fed the model integrals by overriding their
# core-Hamiltonian / overlap getters and assigning _eri directly.
mf = scf.RHF(mol)
mf.get_hcore = lambda *args: integrals["h1"]
mf.get_ovlp = lambda *args: np.eye(n_sites)
mf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
mf.kernel()

umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
# Zeroed array with the shape of the one-electron guess, then a staggered
# starting density: spin index 0 on even sites, spin index 1 on odd sites.
dm_init = 0.0 * umf.init_guess_by_1e()
for i in range(n_sites // 2):
    dm_init[0, 2 * i, 2 * i] = 1.0
    dm_init[1, 2 * i + 1, 2 * i + 1] = 1.0
# NOTE(review): this noise comes from NumPy's global RNG and no seed is set in
# the visible cells, so the UHF starting point differs between runs.
dm_init += 0.1 * np.random.randn(*dm_init.shape)
umf.kernel(dm_init)

# ad afqmc
pyscf_interface.prep_afqmc(umf, mo_coeff=np.eye(n_sites), integrals=integrals)
options = {
    "dt": 0.005,
    "n_eql": 5,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_walkers": 50,
    "seed": 98,
    "walker_type": "uhf",
    "save_walkers": True,
}

# NOTE(review): _prep_afqmc receives only `options`, so it presumably picks up
# whatever prep_afqmc prepared above — confirm in ad_afqmc.
ham_data, ham, prop, trial, wave_data, observable, options = mpi_jax._prep_afqmc(
    options
)
e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, observable, options
)

In [None]:
# triangular lattice


def find_nearest_neighbors(q, r, width, height):
    """Return the six neighbours of cell (q, r) as ``(row, column)`` tuples.

    ``q`` is the row index (wrapped modulo ``height``) and ``r`` the column
    index (wrapped modulo ``width``). Besides the two vertical and two
    horizontal neighbours, each cell gets two diagonal neighbours whose
    column is ``r - 1`` for even ``q`` and ``r + 1`` for odd ``q``.
    """
    next_row = (q + 1) % height
    prev_row = (q - 1) % height
    same_row = q % height
    same_col = r % width
    # Parity of the row decides the direction of the diagonal shift.
    diagonal_col = (r - (-1) ** q) % width
    return [
        (next_row, same_col),
        (prev_row, same_col),
        (same_row, (r + 1) % width),
        (same_row, (r - 1) % width),
        (next_row, diagonal_col),
        (prev_row, diagonal_col),
    ]


def create_adjacency_matrix(width, height):
    """Build the symmetric 0/1 adjacency matrix of a ``height`` x ``width`` lattice.

    Cell (q, r) maps to flat index ``q * width + r``. Every neighbour reported
    by ``find_nearest_neighbors`` is linked in both directions.

    Returns:
        Integer array of shape ``(width * height, width * height)``.
    """
    n_cells = width * height
    adjacency = np.zeros((n_cells, n_cells), dtype=int)

    for col in range(width):
        for row in range(height):
            here = row * width + col
            for n_row, n_col in find_nearest_neighbors(row, col, width, height):
                inside = 0 <= n_row < height and 0 <= n_col < width
                if not inside:
                    continue
                there = n_row * width + n_col
                # Set both entries so the matrix stays symmetric.
                adjacency[here, there] = 1
                adjacency[there, here] = 1
    return adjacency


# Print the adjacency matrix of a small lattice as a visual check.
# NOTE(review): with an odd height the periodic wrap joins rows height-1 and 0,
# which are both even, so the parity-dependent diagonal shift used by
# find_nearest_neighbors is not consistent across that seam — confirm intended.
width = 4
height = 3
h = create_adjacency_matrix(width, height)
print(h)

In [None]:
# Hubbard model on the lattice produced by create_adjacency_matrix: run RHF,
# UHF and GHF through PySCF with overridden integrals, then AFQMC with a GHF
# trial. The commented lines below are a disabled 5-site chain alternative.
# n_sites = 5
# n_elec = (3, 2)
# u = 4.0
# integrals = {}
# integrals["h0"] = 0.0
# unit_vec = np.zeros(n_sites)
# unit_vec[1] = -1.0
# integrals["h1"] = sp.linalg.toeplitz(unit_vec)
# # pbc
# integrals["h1"][0, -1] = -1.0
# integrals["h1"][-1, 0] = -1.0
height = 4
width = 4
n_sites = height * width
u = 4.0
n_elec = (n_sites // 2, n_sites // 2)

integrals = {}
integrals["h0"] = 0.0

# One-body term: -1.0 on every bond of the adjacency matrix.
h1 = -1.0 * create_adjacency_matrix(width, height)
integrals["h1"] = h1

# Two-body integrals: only the fully diagonal (i, i, i, i) entries are non-zero.
h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = u
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

# dummy molecule
mol = gto.Mole()
mol.nelectron = sum(n_elec)
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.build()

# rhf
# The mean-field objects are fed the model integrals by overriding their
# core-Hamiltonian / overlap getters and assigning _eri directly.
mf = scf.RHF(mol)
mf.get_hcore = lambda *args: integrals["h1"]
mf.get_ovlp = lambda *args: np.eye(n_sites)
mf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
mf.kernel()

# # fci from integrals
# ci = fci.FCI(mol)
# e, ci_coeffs = ci.kernel(
#     h1e=integrals["h1"], eri=integrals["h2"], norb=n_sites, nelec=n_elec
# )
# print(f"fci energy: {e}")

# uhf
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
# Zeroed array with the shape of the one-electron guess, then a staggered
# starting density: spin index 0 on even sites, spin index 1 on odd sites.
dm_init = 0.0 * umf.init_guess_by_1e()
for i in range(n_sites // 2):
    dm_init[0, 2 * i, 2 * i] = 1.0
    dm_init[1, 2 * i + 1, 2 * i + 1] = 1.0
# NOTE(review): this noise comes from NumPy's global RNG and no seed is set in
# the visible cells, so the SCF starting point differs between runs.
dm_init += 0.1 * np.random.randn(*dm_init.shape)
umf.kernel(dm_init)
# Three rounds of: stability analysis, then a Newton solve restarted from the
# orbitals the analysis returned (element 0 of its result).
mo1 = umf.stability(external=True)[0]
umf = umf.newton().run(mo1, umf.mo_occ)
mo1 = umf.stability(external=True)[0]
umf = umf.newton().run(mo1, umf.mo_occ)
mo1 = umf.stability(external=True)[0]
umf = umf.newton().run(mo1, umf.mo_occ)

# ghf
# The core Hamiltonian and overlap are doubled in size: h1 is repeated on the
# two diagonal blocks and the overlap is a 2 * n_sites identity.
gmf = scf.GHF(mol)
gmf.get_hcore = lambda *args: sp.linalg.block_diag(integrals["h1"], integrals["h1"])
gmf.get_ovlp = lambda *args: np.eye(2 * n_sites)
gmf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
# Reuses the (already perturbed) UHF starting density as the two diagonal
# blocks, then adds a second, larger unseeded perturbation to every entry.
dm_init = sp.linalg.block_diag(dm_init[0], dm_init[1])
dm_init += 1.0 * np.random.randn(*dm_init.shape)
gmf.kernel(dm_init)
# Same stability / Newton cycle as for UHF; here the stability result is
# passed on whole rather than indexed.
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)

# ad afqmc
pyscf_interface.prep_afqmc(umf, mo_coeff=np.eye(n_sites), integrals=integrals)
options = {
    "dt": 0.005,
    "n_eql": 5,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_walkers": 50,
    "seed": 98,
    "walker_type": "uhf",
    "save_walkers": True,
}

# NOTE(review): _prep_afqmc receives only `options`, so it presumably picks up
# whatever prep_afqmc prepared above — confirm in ad_afqmc.
ham_data, ham, prop, trial, wave_data, observable, options = mpi_jax._prep_afqmc(
    options
)

from ad_afqmc import wavefunctions, hamiltonian

# Replace the trial, its wave data and the Hamiltonian returned above with GHF
# versions; wave_data keeps the first n_up + n_down columns of the GHF
# orbital coefficients. `prop`, `observable` and `options` are left as returned.
trial = wavefunctions.ghf(n_sites, n_elec)
wave_data = gmf.mo_coeff[:, : n_elec[0] + n_elec[1]]
ham = hamiltonian.hamiltonian_ghf(n_sites, n_elec, ham_data["chol"].shape[0])
ham_data = ham.rot_ham(ham_data, wave_data)

e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, observable, options
)