# 1D Hubbard model exact diagonalization

Implements a matrix-vector product function for the Hubbard Hamiltonian on a 1D chain. It does not make use of translational symmetry.


In [None]:
import os
import sys

module_path = os.path.abspath(os.path.join(os.getcwd() + "/.."))
if module_path not in sys.path:
    sys.path.append(module_path)
from nn_eph import lattices

import jax
from jax import jit, vmap, numpy as jnp
from itertools import product, combinations

from dataclasses import dataclass
from typing import Any, Tuple
from functools import partial

print = partial(print, flush=True)
import itertools
import numpy as np

In [None]:
def make_basis_spin(n_sites, n_elec):
    """Enumerate all occupation patterns of ``n_elec`` electrons on ``n_sites`` sites.

    The patterns are generated in ascending lexicographic order by stepping
    from one permutation to the next, starting from the pattern in which the
    last ``n_elec`` sites are occupied.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per pattern; entries are 0 (empty) or
        1 (occupied).
    """
    occupation = np.zeros(n_sites)
    for offset in range(1, n_elec + 1):
        occupation[-offset] = 1
    patterns = [occupation.copy()]

    while True:
        # pivot: right-most site whose value is smaller than its right neighbour
        ascents = [
            site for site in range(n_sites - 1) if occupation[site] < occupation[site + 1]
        ]
        if not ascents:
            # no ascent left: this was the last permutation
            break
        pivot = ascents[-1]

        # partner: right-most site after the pivot holding a larger value
        larger = [
            site
            for site in range(pivot + 1, n_sites)
            if occupation[pivot] < occupation[site]
        ]
        partner = larger[-1] if larger else pivot

        # swap pivot and partner, then reverse everything after the pivot
        occupation[[pivot, partner]] = occupation[[partner, pivot]]
        occupation[pivot + 1 :] = occupation[pivot + 1 :][::-1].copy()
        patterns.append(occupation.copy())

    return np.array(patterns, dtype=int)


def make_basis(n_sites, n_elec: tuple):
    """Build the two-spin basis as all (up, down) pairs of single-spin patterns.

    ``n_elec[0]`` and ``n_elec[1]`` are the electron counts passed to
    ``make_basis_spin`` for the first and second spin species. When both
    counts agree the single-spin basis is generated once and reused.

    Returns a NumPy array built from the list of pairs, with the first spin
    species varying slowest.
    """
    n_up, n_down = n_elec[0], n_elec[1]
    spin_up = make_basis_spin(n_sites, n_up)
    spin_down = spin_up if n_up == n_down else make_basis_spin(n_sites, n_down)
    # same ordering as a Cartesian product: outer loop over up, inner over down
    pairs = [(up, down) for up in spin_up for down in spin_down]
    return np.array(pairs)


def encode_basis_vector(vector):
    """Encode a binary occupation vector as an integer.

    The first entry is the most significant bit, e.g. ``[1, 0, 1] -> 5``.

    This function is deliberately NOT jit-compiled: it returns a plain Python
    ``int`` (so the result can be used as a hashable dictionary key) and it
    iterates over concrete values, neither of which works on traced JAX
    arrays.

    Parameters
    ----------
    vector : iterable of 0/1 values
        Occupation numbers; each entry is converted with ``int``, so integer
        and float-typed arrays are both accepted.

    Returns
    -------
    int
        The integer whose binary digits are the entries of ``vector``
        (0 for an empty vector).
    """
    code = 0
    for occupied in vector:
        # shift the bits collected so far and append the next occupation
        code = (code << 1) | int(occupied)
    return code


def make_basis_with_lookup(n_sites, n_elec):
    """Build the two-spin basis together with a dictionary index.

    The basis is the list of all (up, down) pairs of single-spin patterns
    from ``make_basis_spin``. Each pair is mapped, via
    ``encode_basis_vector``, to a tuple ``(up_code, down_code)`` that serves
    as the dictionary key; the value is the position of the pair in the basis.

    Returns
    -------
    tuple
        ``(basis, lookup)`` where ``basis`` is a NumPy array built from the
        pairs and ``lookup`` is the dictionary described above.
    """
    spin_up = make_basis_spin(n_sites, n_elec[0])
    spin_down = (
        spin_up if n_elec[0] == n_elec[1] else make_basis_spin(n_sites, n_elec[1])
    )

    # all (up, down) pairs, up varying slowest
    pairs = [(up, down) for up in spin_up for down in spin_down]

    # hashable key -> position in the basis
    lookup = {}
    for position, (up, down) in enumerate(pairs):
        lookup[(encode_basis_vector(up), encode_basis_vector(down))] = position

    return np.array(pairs), lookup


def lookup_vector(basis_lookup, up_vector, down_vector):
    """Return the basis index stored for an (up, down) pair, or -1 if absent.

    Both vectors are encoded with ``encode_basis_vector`` and the resulting
    pair of codes is used as the key into ``basis_lookup``.
    """
    key = tuple(encode_basis_vector(spin) for spin in (up_vector, down_vector))
    # -1 signals that the pair is not in the dictionary
    return basis_lookup.get(key, -1)

In [None]:
@dataclass
class hubbard:
    """Hubbard model Hamiltonian with (anti)periodic boundary conditions.

    Currently only works for 1D because parity is not implemented for a general lattice.

    Instances are passed as static arguments to ``jax.jit`` (see the
    ``static_argnums`` of the jitted methods), which is why ``__hash__`` is
    defined explicitly.

    Attributes
    ----------
    u : float
        On-site Coulomb repulsion
    n_orbs : int
        Number of orbitals
    n_elec : Tuple[int, int]
        Number of (up, down) electrons
    antiperiodic: bool
        Antiperiodic boundary conditions
    """

    u: float
    n_orbs: int
    n_elec: Tuple[int, int]
    antiperiodic: bool = False

    def encode_config(self, vector):
        """Encode a binary occupation vector as an integer.

        The first entry of ``vector`` is the most significant bit.
        """
        return jnp.dot(vector, 2 ** jnp.arange(vector.size)[::-1])

    def make_basis_spin(self, n_elec_sp):
        """
        Generate all possible spin configurations with exactly ``n_elec_sp``
        electrons distributed among ``self.n_orbs`` sites as a JAX array.

        Configurations follow the order of ``itertools.combinations`` over the
        occupied site indices. Note that ``make_basis`` does not use this
        method; it calls the module-level ``make_basis_spin`` function.
        """
        n_sites = self.n_orbs
        # Generate all unique combinations of `n_elec_sp` occupied sites out of `n_sites`
        indices = list(combinations(range(n_sites), n_elec_sp))

        # Create basis vectors where each combination sets the selected sites to 1
        def create_vector(idx):
            vec = jnp.zeros(n_sites, dtype=int)
            # Explicit integer dtype: an empty `idx` (zero electrons) would
            # otherwise become a float array, which is not a valid index.
            return vec.at[jnp.array(idx, dtype=int)].set(1)

        # Generate basis set as a JAX array
        basis = jnp.array([create_vector(idx) for idx in indices])
        return basis

    def make_basis(self):
        """Generate the full (up, down) basis as a JAX array.

        The single-spin bases come from the module-level ``make_basis_spin``
        function; the full basis is their Cartesian product with the up
        configuration varying slowest. The lookup table is built separately
        by ``make_basis_with_lookup``.
        """
        n_sites = self.n_orbs
        n_elec = self.n_elec
        basis_up = make_basis_spin(n_sites, n_elec[0])
        if n_elec[0] == n_elec[1]:
            # same electron count for both spins: reuse the single-spin basis
            basis_down = basis_up
        else:
            basis_down = make_basis_spin(n_sites, n_elec[1])

        # Generate the full basis
        basis = jnp.array(list(product(basis_up, basis_down)))
        return basis

    def make_basis_with_lookup(self):
        """Generate the basis and a sorted key table for ``lookup_config``.

        Returns
        -------
        basis : jax.Array
            Output of ``make_basis``.
        sorted_map : Tuple
            ``(sorted_keys, sorted_values, max_key)`` where each key is
            ``up_code + down_code * max_key``, ``sorted_values`` holds the
            basis index belonging to each sorted key, and ``max_key`` is one
            more than the largest up code.
        """
        basis = self.make_basis()

        # Encode every up and down configuration in one vectorized call:
        # encoded_basis[i] = (up_code, down_code) of basis[i]
        encoded_basis = vmap(vmap(self.encode_config))(basis)

        # Flatten the encoded basis as unique keys
        # NOTE(review): the combined key grows like 4**n_orbs; with 32-bit
        # integers it may overflow for large n_orbs -- confirm integer width.
        max_key = jnp.max(encoded_basis[:, 0] + 1)

        keys = encoded_basis[:, 0] + encoded_basis[:, 1] * max_key
        values = jnp.arange(len(keys))

        # Sort keys and corresponding values
        sorted_indices = jnp.argsort(keys)
        sorted_keys = keys[sorted_indices]
        sorted_values = values[sorted_indices]

        return basis, (sorted_keys, sorted_values, max_key)

    @partial(jit, static_argnums=(0,))
    def lookup_config(self, sorted_map, config):
        """
        Perform a lookup using the sorted map with binary search.

        Parameters:
        - sorted_map: A tuple (sorted_keys, sorted_values, max_key) as returned
          by ``make_basis_with_lookup``
        - config: Up and down configuration, shape (2, n_sites)

        Returns:
        - Index of the configuration if found, otherwise -1. Configurations
          containing any entry other than 0 or 1 always give -1.
        """
        up_config, down_config = config
        sorted_keys, sorted_values, max_key = sorted_map
        # Encode the query vector
        encoded_query = (
            self.encode_config(up_config) + self.encode_config(down_config) * max_key
        )

        # Use binary search to find the index of the query vector
        idx = jnp.searchsorted(sorted_keys, encoded_query)
        # searchsorted returns len(sorted_keys) when the query exceeds every
        # key; clamp so the comparison below reads a valid entry
        idx = jnp.minimum(idx, sorted_keys.size - 1)

        # Keys are integers, so compare exactly. A tolerance-based comparison
        # (jnp.isclose) would treat distinct neighbouring keys as equal once
        # the keys are large, returning the index of a different configuration.
        match = sorted_keys[idx] == encoded_query

        # if any element in config is different from 0 or 1, return -1
        invalid_config = jnp.any(jnp.logical_and(config != 0, config != 1))

        # Return the corresponding index or -1 if no match
        return (
            jnp.where(match, sorted_values[idx], -1) * (1 - invalid_config)
            - 1 * invalid_config
        )

    @partial(jit, static_argnums=(0, 2))
    def generate_excitations(self, configuration: jax.Array, lattice: Any) -> Tuple:
        """
        Generate all possible excitations for a given configuration.

        Every electron is moved to each of its nearest neighbors, whether or
        not the target site is already occupied; moves onto an occupied site
        produce an occupation of 2, which ``lookup_config`` later rejects.
        The diagonal (on-site interaction) term is appended as the last entry
        together with the unchanged configuration.

        Parameters
        ----------
        configuration : jax.Array
            Spin up and down occupation numbers: (2, *lattice.shape)
        lattice : Any
            Lattice object providing ``sites``, ``get_nearest_neighbors`` and
            ``get_nearest_neighbors_edge_bond``

        Returns
        -------
        excitation : Tuple
            coefficients and configurations
        """

        # positions of the occupied sites for each spin; the number of
        # electrons per spin is fixed, so the nonzero size is static
        elec_idx_up = jnp.nonzero(configuration[0].reshape(-1), size=self.n_elec[0])[0]
        elec_pos_up = jnp.array(lattice.sites)[elec_idx_up]
        elec_idx_dn = jnp.nonzero(configuration[1].reshape(-1), size=self.n_elec[1])[0]
        elec_pos_dn = jnp.array(lattice.sites)[elec_idx_dn]

        # diagonal: u times the number of doubly occupied sites
        diagonal = self.u * jnp.sum(configuration[0] * configuration[1])

        # edge hopping parities
        up_parity = (-1) ** (self.n_elec[0] - 1)
        dn_parity = (-1) ** (self.n_elec[1] - 1)
        parity = jnp.array([up_parity, dn_parity])

        # electron hops
        # scan over neighbors
        def hop(spin, elec_pos, neighbor, neighbor_edge_bond):
            new_configuration = configuration.at[(spin, *neighbor)].add(1)
            new_configuration = new_configuration.at[(spin, *elec_pos)].add(-1)
            # -1 hopping amplitude; on an edge bond it is multiplied by the
            # spin parity and, for antiperiodic boundaries, by an extra -1
            coeff = (
                -1.0
                * (1 - 2 * neighbor_edge_bond * self.antiperiodic)
                * (1 + neighbor_edge_bond * (parity[spin] - 1))
            )
            return coeff, new_configuration

        # scan over electrons
        def outer_mapped_fun(spin, elec_pos):
            neighbors = lattice.get_nearest_neighbors(elec_pos)
            neighbor_edge_bond = lattice.get_nearest_neighbors_edge_bond(elec_pos)
            coeffs, new_configurations = vmap(hop, in_axes=(None, None, 0, 0))(
                spin, elec_pos, neighbors, neighbor_edge_bond
            )
            return coeffs, new_configurations

        coeffs_up, new_configurations_up = vmap(outer_mapped_fun, in_axes=(None, 0))(
            0, elec_pos_up
        )
        coeffs_dn, new_configurations_dn = vmap(outer_mapped_fun, in_axes=(None, 0))(
            1, elec_pos_dn
        )
        coeffs = jnp.concatenate([coeffs_up.reshape(-1), coeffs_dn.reshape(-1)])
        new_configurations = jnp.concatenate(
            [
                new_configurations_up.reshape(-1, 2, self.n_orbs),
                new_configurations_dn.reshape(-1, 2, self.n_orbs),
            ]
        )
        # add diagonal contribution at the end
        coeffs = jnp.concatenate([coeffs, jnp.array([diagonal])])
        new_configurations = jnp.concatenate(
            [new_configurations, jnp.array([configuration])]
        )
        return coeffs, new_configurations

    @partial(jit, static_argnums=(0,))
    def update_vec(self, sorted_map, coeffs, new_configurations, vec, coeff_0):
        """Look up indices for new configs and update the vector.

        Adds ``coeffs * coeff_0`` to the entries of ``vec`` that correspond to
        ``new_configurations``; configurations that are not found (index -1)
        contribute nothing.
        """
        indices = vmap(self.lookup_config, in_axes=(None, 0))(
            sorted_map, new_configurations
        )
        # get rid of negative indices
        # this avoids jax non concrete size issue: not-found entries are sent
        # out of bounds and explicitly dropped by the scatter-add below
        indices = jnp.where(indices >= 0, indices, vec.size)
        vec = vec.at[indices].add(coeffs * coeff_0, mode="drop")
        return vec

    @partial(jit, static_argnums=(0, 2))
    def ham_vec_prod(self, vec, lattice, basis, sorted_map):
        """Apply the Hamiltonian to a state vector.

        ``basis`` and ``sorted_map`` are the two outputs of
        ``make_basis_with_lookup``; ``vec[i]`` is the amplitude of ``basis[i]``.
        """

        # loop over all configurations
        # defining as scan instead of vmap because of potential memory issues
        # carry: [h_v, i]
        def scanned_fun(carry, config):
            h_v, i = carry
            coeffs, new_configs = self.generate_excitations(config, lattice)
            return [
                self.update_vec(sorted_map, coeffs, new_configs, h_v, vec[i]),
                i + 1,
            ], 0

        carry, _ = jax.lax.scan(scanned_fun, [0.0 * vec, 0], basis)
        h_v, _ = carry
        return h_v

    @partial(jit, static_argnums=(0, 3))
    def ham_element(self, x, y, lattice):
        """Calculate the matrix element of the Hamiltonian between configs x and y, only 1D."""
        n_sites = lattice.n_sites
        n_elec = jnp.array((jnp.sum(x[0]), jnp.sum(x[1])))
        # sites where x and y differ, per spin
        diff = jnp.array((jnp.bitwise_xor(x[0], y[0]), jnp.bitwise_xor(x[1], y[1])))
        diff_count = jnp.array((jnp.sum(diff[0]), jnp.sum(diff[1])))
        # identical configurations: u times the number of doubly occupied sites
        on_site = (
            (jnp.sum(diff_count) == 0) * self.u * jnp.bitwise_and(x[0], x[1]).sum()
        )
        # the two differing sites (only meaningful when exactly two sites differ)
        diff_pos = jnp.nonzero(diff, size=2)
        site_1 = diff_pos[1][0]
        site_2 = diff_pos[1][1]
        edge_bond_q = (site_1 == 0) * (site_2 == n_sites - 1)
        # single hop: -1 between adjacent sites, parity-dependent sign across
        # the chain boundary, extra sign flip for antiperiodic boundaries
        hopping = (
            (jnp.sum(diff_count) == 2)
            * (
                (site_2 - site_1 == 1) * -1.0
                + (site_2 - site_1 == n_sites - 1)
                * (-1.0) ** (1 + jnp.sum((n_elec - 1) * diff_count / 2))
            )
            * (1 - 2 * edge_bond_q * self.antiperiodic)
        )
        return on_site + hopping

    def make_ham_mat(self, lattice):
        """Build the dense Hamiltonian matrix over the basis from ``make_basis``."""
        print("Building basis")
        basis = self.make_basis()
        print(f"Built basis, length: {len(basis)}")
        ham_element = partial(self.ham_element, lattice=lattice)
        # outer vmap runs over rows (x), inner vmap over columns (y)
        ham_mat = vmap(vmap(ham_element, (None, 0)), (0, None))(basis, basis)
        return ham_mat

    def __hash__(self):
        """Hash the field values so instances can be static ``jit`` arguments."""
        return hash(tuple(self.__dict__.values()))

In [None]:
import scipy.sparse.linalg as spla


# wrapper for direct methods
class jax_linear_operator(spla.LinearOperator):
    """SciPy ``LinearOperator`` backed by a JAX matrix-vector product.

    ``ham_vec_prod`` is a callable taking a JAX array and returning the
    product as an array; the operator is square with side ``vec_size``.
    """

    def __init__(self, ham_vec_prod, vec_size, dtype=np.float64):
        self.ham_vec_prod = ham_vec_prod
        self.vec_size = vec_size
        square_shape = (vec_size, vec_size)
        super().__init__(dtype=dtype, shape=square_shape)

    def _matvec(self, x):
        # NumPy in -> JAX product -> NumPy out
        return np.array(self.ham_vec_prod(jnp.array(x)))

In [None]:
# Dense exact diagonalization: build the full Hamiltonian matrix and
# diagonalize it with jnp.linalg.eigh.
n_sites = 6
lattice = lattices.one_dimensional_chain(n_sites)
# n_sites // 2 electrons of each spin
n_elec = (n_sites // 2, n_sites // 2)
u = 4.0
# antiperiodic boundary conditions are enabled only when n_sites is a multiple of 4
ham = hubbard(u, n_sites, n_elec, antiperiodic=(n_sites % 4 == 0))
# NOTE(review): basis and sorted_map are not used by the dense path below;
# make_ham_mat builds its own basis internally.
basis, sorted_map = ham.make_basis_with_lookup()
ham_mat = ham.make_ham_mat(lattice)
print("Built Hamiltonian matrix", flush=True)
ene, states = jnp.linalg.eigh(ham_mat)
# print the first six eigenvalues returned by eigh
print(f"energies: {ene[:6]}")

In [None]:
# Matrix-free diagonalization: wrap ham_vec_prod in a SciPy LinearOperator
# and pass it to spla.eigsh, for the same parameters as the dense cell.
n_sites = 6
lattice = lattices.one_dimensional_chain(n_sites)
# n_sites // 2 electrons of each spin
n_elec = (n_sites // 2, n_sites // 2)
u = 4.0
# antiperiodic boundary conditions are enabled only when n_sites is a multiple of 4
ham = hubbard(u, n_sites, n_elec, antiperiodic=(n_sites % 4 == 0))
basis, sorted_map = ham.make_basis_with_lookup()
print(f"Number of basis states: {len(basis)}")
vec_size = len(basis)
# bind everything except the state vector so the operator only takes `vec`
ham_vec_prod = partial(
    ham.ham_vec_prod, lattice=lattice, basis=basis, sorted_map=sorted_map
)
ham_op = jax_linear_operator(ham_vec_prod, vec_size)
# six smallest algebraic ("SA") eigenvalues, with a convergence tolerance of 1e-4
eigenvalues, eigenvectors = spla.eigsh(ham_op, k=6, which="SA", tol=1.0e-4)
eigenvalues