In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
from jax import jit, numpy as jnp, lax
from jax.tree_util import register_pytree_node_class
import matplotlib.pyplot as plt
from functools import partial
from typing import Optional, Sequence

print = partial(print, flush=True)
from dataclasses import dataclass

np.set_printoptions(suppress=True, precision=5)

In [None]:
def make_phonon_basis(n_sites, max_phonons):
    """Enumerate all phonon occupation vectors with at most ``max_phonons`` quanta.

    The result is ordered by total phonon number (0, 1, ..., ``max_phonons``).
    Within one total, the order is whatever ``frobenius`` yields for unit
    coefficients.

    Args:
    - n_sites: int, number of phonon modes (callers pass the number of lattice sites)
    - max_phonons: int, largest total phonon number to include

    Returns:
    - list of occupation vectors as produced by ``frobenius``
    """
    unit_weights = [1] * n_sites
    return [
        occupation
        for total in range(max_phonons + 1)
        for occupation in frobenius(n_sites, unit_weights, total)
    ]


def frobenius(n, coefficients, target):
    """
    List every non-negative integer vector ``a`` of length ``n`` with
    ``sum(a[j] * coefficients[j]) == target``.

    A small dynamic program first checks whether ``target`` is reachable at
    all. If it is, the solutions are enumerated by depth-first search from the
    last coefficient down to the first, trying the largest count first.

    Args:
    - n: int, number of integers in each solution
    - coefficients: list of ints, weight of each integer
    - target: int, required weighted sum

    Returns:
    - list of jnp arrays, one per solution (empty if none exists)
    """
    # reachable[v] holds the index of the first coefficient that extends some
    # reachable smaller value to v, or -1 when v cannot be reached.
    reachable = [0] + [-1] * target
    for value in range(1, target + 1):
        reachable[value] = next(
            (
                j
                for j in range(n)
                if coefficients[j] <= value
                and reachable[value - coefficients[j]] != -1
            ),
            -1,
        )

    if reachable[target] == -1:
        return []  # target cannot be written with these coefficients

    found = []
    partial_solution = [0] * n

    def fill(position, left):
        # All positions have been assigned; keep the vector only if it hits target.
        if position < 0:
            if left == 0:
                found.append(jnp.array(partial_solution))
            return
        weight = coefficients[position]
        for count in reversed(range(left // weight + 1)):
            partial_solution[position] = count
            fill(position - 1, left - count * weight)

    fill(n - 1, target)
    return found


@dataclass
@register_pytree_node_class
class one_dimensional_chain:
    """Periodic one-dimensional chain of ``n_sites`` sites.

    Sites are labelled by 1-tuples ``(i,)`` so they can be used as array
    indices in the same way as multi-dimensional site tuples. The class is
    registered as a JAX pytree with no leaves (all state is auxiliary data),
    and ``__hash__`` makes instances usable as static arguments to ``jax.jit``.
    """

    n_sites: int
    shape: tuple = (1,)  # overwritten in __post_init__ with (n_sites,)

    def __post_init__(self):
        self.shape = (self.n_sites,)
        # Derived attribute (not a dataclass field / __init__ argument).
        self.sites = tuple([(i,) for i in range(self.n_sites)])

    def get_site_num(self, pos):
        """Return the flat site index of ``pos``, i.e. its single coordinate."""
        return pos[0]

    def make_polaron_basis(self, max_n_phonons):
        """Single-band basis: every (electron site, phonon state) pair.

        Returns a tuple of ``[(i,), phonon_state]`` lists, with the electron
        site as the outer loop and phonon states (from ``make_phonon_basis``)
        as the inner loop.
        """
        phonon_basis = make_phonon_basis(self.n_sites, max_n_phonons)
        polaron_basis = tuple(
            [(i,), phonon_state]
            for i in range(self.n_sites)
            for phonon_state in phonon_basis
        )
        return polaron_basis

    def make_polaron_basis_n(self, n_bands, max_n_phonons):
        """Multi-band basis: every ((band, site), phonon state) pair.

        Returns a tuple of ``[(n, site), phonon_state]`` lists ordered by band,
        then site, then phonon state.
        """
        phonon_basis = make_phonon_basis(self.n_sites, max_n_phonons)
        assert self.sites is not None
        electronic_basis = tuple(
            [(n, site) for n in range(n_bands) for site in self.sites]
        )
        polaron_basis = tuple(
            [site, phonon_state]
            for site in electronic_basis
            for phonon_state in phonon_basis
        )
        return polaron_basis

    def __hash__(self):
        return hash(tuple(self.__dict__.values()))

    def tree_flatten(self):
        # Only the dataclass fields go into the auxiliary data. ``sites`` is
        # derived in __post_init__ and is not accepted by __init__, so it must
        # not be passed back to ``cls(...)`` in tree_unflatten.
        return (), (self.n_sites, self.shape)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)


@dataclass
@register_pytree_node_class
class three_dimensional_grid:
    """Periodic three-dimensional ``l_z x l_y x l_x`` grid.

    Positions are ``(z, y, x)`` tuples and ``shape == (l_z, l_y, l_x)``. The
    flat site number of a position (``get_site_num``) is its row-major
    (C-order) index into ``shape``, ``x + l_x * y + l_x * l_y * z``, which is
    also the order of ``sites``. The polaron bases therefore reshape phonon
    occupations to ``shape``, so that ``phonons[site]`` and
    ``phonons.flatten()[get_site_num(site)]`` refer to the same mode.
    """

    l_x: int
    l_y: int
    l_z: int
    # Derived in __post_init__. They are dataclass fields, so __dict__ holds
    # exactly the __init__ arguments in field order, which tree_flatten /
    # tree_unflatten rely on.
    shape: Optional[tuple] = None
    shell_distances: Optional[Sequence] = None
    sites: Optional[Sequence] = None
    bonds: Optional[Sequence] = None  # not assigned anywhere in this class
    n_sites: Optional[int] = None
    coord_num: int = 6  # matches the 6 neighbours from get_nearest_neighbors

    def __post_init__(self):
        self.shape = (self.l_z, self.l_y, self.l_x)
        self.n_sites = self.l_x * self.l_y * self.l_z
        # Sorted, deduplicated squared minimum-image distances over offsets of
        # up to half the box per direction. The position of a value in this
        # tuple is its shell number.
        distances = []
        for x in range(self.l_x // 2 + 1):
            for y in range(self.l_y // 2 + 1):
                for z in range(self.l_z // 2 + 1):
                    dist = x**2 + y**2 + z**2
                    distances.append(dist)
        distances = [*set(distances)]
        distances.sort()
        self.shell_distances = tuple(distances)
        # (z, y, x) tuples in row-major order over shape.
        self.sites = tuple(
            [
                (
                    i // (self.l_x * self.l_y),
                    (i % (self.l_x * self.l_y)) // self.l_x,
                    (i % (self.l_x * self.l_y)) % self.l_x,
                )
                for i in range(self.l_x * self.l_y * self.l_z)
            ]
        )

    def make_polaron_basis(self, max_n_phonons):
        """Single-band basis: every (electron site, phonon state) pair.

        Phonon states are reshaped to ``self.shape`` so they can be indexed
        with the ``(z, y, x)`` site tuples.
        """
        phonon_basis = make_phonon_basis(self.l_x * self.l_y * self.l_z, max_n_phonons)
        assert self.sites is not None
        polaron_basis = tuple(
            [site, phonon_state.reshape(self.shape)]
            for site in self.sites
            for phonon_state in phonon_basis
        )
        return polaron_basis

    def make_polaron_basis_n(self, n_bands, max_n_phonons):
        """Multi-band basis: every ((band, site), phonon state) pair.

        Ordered by band, then site, then phonon state. Phonon states are
        reshaped to ``self.shape``.
        """
        phonon_basis = make_phonon_basis(self.l_x * self.l_y * self.l_z, max_n_phonons)
        assert self.sites is not None
        electronic_basis = tuple(
            [(n, site) for n in range(n_bands) for site in self.sites]
        )
        polaron_basis = tuple(
            [site, phonon_state.reshape(self.shape)]
            for site in electronic_basis
            for phonon_state in phonon_basis
        )
        return polaron_basis

    def get_site_num(self, pos):
        """Row-major flat index of a ``(z, y, x)`` position."""
        return pos[2] + self.l_x * pos[1] + (self.l_x * self.l_y) * pos[0]

    def get_distance(self, pos_1, pos_2):
        """Shell number of the periodic (minimum-image) distance between two positions.

        The squared minimum-image distance is looked up in ``shell_distances``
        with ``searchsorted``.
        """
        dist_z = jnp.min(
            jnp.array(
                [jnp.abs(pos_1[0] - pos_2[0]), self.l_z - jnp.abs(pos_1[0] - pos_2[0])]
            )
        )
        dist_y = jnp.min(
            jnp.array(
                [jnp.abs(pos_1[1] - pos_2[1]), self.l_y - jnp.abs(pos_1[1] - pos_2[1])]
            )
        )
        dist_x = jnp.min(
            jnp.array(
                [jnp.abs(pos_1[2] - pos_2[2]), self.l_x - jnp.abs(pos_1[2] - pos_2[2])]
            )
        )
        dist = dist_x**2 + dist_y**2 + dist_z**2
        shell_number = jnp.searchsorted(jnp.array(self.shell_distances), dist)
        return shell_number

    # ignoring side length 1 and 2 special cases (there the +1 and -1
    # neighbours along a direction coincide and are returned twice)
    def get_nearest_neighbors(self, pos):
        """Return the 6 periodic nearest neighbours of ``pos`` as a (6, 3) array."""
        right = (pos[0], (pos[1] + 1) % self.l_y, pos[2])
        down = ((pos[0] + 1) % self.l_z, pos[1], pos[2])
        left = (pos[0], (pos[1] - 1) % self.l_y, pos[2])
        up = ((pos[0] - 1) % self.l_z, pos[1], pos[2])
        front = (pos[0], pos[1], (pos[2] + 1) % self.l_x)
        back = (pos[0], pos[1], (pos[2] - 1) % self.l_x)
        neighbors = [right, down, left, up, front, back]
        return jnp.array(neighbors)

    def __hash__(self):
        return hash(tuple(self.__dict__.values()))

    def tree_flatten(self):
        # Every __dict__ entry is a dataclass field, in field order.
        return (), tuple(self.__dict__.values())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)

In [None]:
# assumes x and y have equal momentum
@partial(jit, static_argnums=(5,))
def ham_element_kq(x, y, e_k, omega_q, g_kq, lattice):
    """Matrix element <x|H|y> of a single-band electron-phonon Hamiltonian in k space.

    Basis states are flat arrays ``[k, n_ph_0, ..., n_ph_{N-1}]``: one electron
    momentum coordinate (so only 1D lattices are supported) followed by the
    phonon occupations. The contributions are:

    - the diagonal energy ``sum(omega_q * n_ph) + e_k[k]`` when x == y;
    - phonon emission (x has one more phonon than y) with amplitude
      ``g_kq[k_y, (k_x - k_y) mod N] * sqrt(n_y[q] + 1)``, q = (k_y - k_x) mod N;
    - phonon absorption (x has one fewer phonon) with amplitude
      ``g_kq[k_y, (k_x - k_y) mod N] * sqrt(n_y[q])``, q = (k_x - k_y) mod N.

    The mode that changed is not checked explicitly. That is only valid
    because x and y share the same total momentum.

    Because the phonon occupations are indexed with a length-1 momentum
    array, the returned value has shape (1,).
    """
    k_x, phonons_x = x[0], x[1:]
    k_y, phonons_y = y[0], y[1:]
    dims = jnp.array(lattice.shape)

    delta = phonons_x - phonons_y
    n_changed = jnp.sum(jnp.where(phonons_x == phonons_y, 0, 1))
    single_change = n_changed == 1

    same_state = (jnp.sum(k_x - k_y) == 0) * (n_changed == 0)
    diagonal = same_state * (jnp.sum(omega_q * phonons_x) + e_k[k_x])

    # Both one-phonon processes use the same coupling entry.
    coupling = g_kq[
        lattice.get_site_num([k_y]), lattice.get_site_num((k_x - k_y) % dims)
    ]
    q_created = (k_y - k_x) % dims
    q_absorbed = (k_x - k_y) % dims
    emission = (
        single_change
        * (jnp.sum(delta) == 1)
        * coupling
        * jnp.sqrt(phonons_y[q_created] + 1)
    )
    absorption = (
        single_change
        * (jnp.sum(delta) == -1)
        * coupling
        * jnp.sqrt(phonons_y[q_absorbed])
    )
    return diagonal + emission + absorption

In [None]:
n_sites = 6
omega_0 = 1.0
g = 1.0
t = 1.0
max_n_phonons = 4
lattice = one_dimensional_chain(n_sites)

# k space
omega_q = np.array([omega_0 for _ in range(n_sites)])
e_k = t * np.array([-2.0 * np.cos(2.0 * np.pi * k / n_sites) for k in range(n_sites)])
g_kq = np.array([[-g / n_sites**0.5 for _ in range(n_sites)] for _ in range(n_sites)])

basis = lattice.make_polaron_basis(max_n_phonons=max_n_phonons)
# Separate the basis into total-momentum sectors. The sector index is the flat
# site index of the (folded) total momentum, so each sector holds exactly one
# momentum, as ham_element_kq assumes.
bases_k = [[] for _ in range(n_sites)]
states_k = []
energies_k = []
# (electron site, phonon occupations) -> [sector index, position within sector]
basis_dict = {}
counters = [0 for _ in range(n_sites)]
for b in basis:
    phonon_k = np.zeros(len(lattice.shape), dtype=int)
    for site in lattice.sites:
        phonon_k += b[1][site] * np.array(site)
    total_k = (np.array(b[0]) + phonon_k) % np.array(lattice.shape)
    k_sector = lattice.get_site_num(total_k)
    flat_basis_state = np.concatenate([np.array(b[0]), b[1].flatten()])
    bases_k[k_sector].append(flat_basis_state)
    basis_dict[(b[0], tuple(np.array(b[1])))] = [k_sector, counters[k_sector]]
    counters[k_sector] += 1

ham_element = partial(
    ham_element_kq, e_k=e_k, omega_q=omega_q, g_kq=g_kq, lattice=lattice
)
for k_i in range(n_sites):
    basis_k = bases_k[k_i]
    basis_k = jnp.array(basis_k)

    # Nested scans evaluate ham_element for every (row, column) pair of the sector.
    def scan_inner(carry, x):
        def inner(carry, y):
            return carry, ham_element(x, y)

        _, result = lax.scan(inner, None, basis_k)
        return carry, result

    _, ham_mat = lax.scan(scan_inner, None, basis_k)
    # ham_element_kq returns shape (1,) on the 1D chain; drop that trailing axis.
    ham_mat = ham_mat.reshape(ham_mat.shape[0], ham_mat.shape[1])
    print(f"Built hamiltonian for k = {k_i} with shape {ham_mat.shape}")
    assert np.allclose(ham_mat, ham_mat.T.conj()), "Hamiltonian is not Hermitian"
    energies, states = jnp.linalg.eigh(ham_mat)
    print(energies[:5], "\n")
    states_k.append(states)
    energies_k.append(energies)

In [None]:
# zero T spectral function
# A(k, omega) = sum_n |<n|k, no phonons>|^2 * eta / pi / ((omega - E_n)^2 + eta^2),
# i.e. delta functions broadened into Lorentzians of half-width eta.
eta = 0.05
omega = np.linspace(-4, 6, 1000)
specs = []
for k in range(n_sites):
    states = states_k[k]
    energies = energies_k[k]
    # find the state with zero phonons (it lives in total-momentum sector k)
    k_state_index = basis_dict[((k,), tuple(np.zeros(lattice.shape, dtype=int)))][1]
    abs_states_sq = np.abs(states[k_state_index, :]) ** 2
    energy_diff = omega[:, np.newaxis] - energies[np.newaxis, :]
    spec = np.sum(abs_states_sq * eta / (energy_diff**2 + eta**2) / np.pi, axis=1)
    specs.append(spec)

specs = np.array(specs)

shift = 0.8
n_k = specs.shape[0]
for k in range(n_k):
    # k and n_k - k (i.e. -k) share a colour
    kp = min(k, n_k - k)
    plt.fill_between(omega, specs[k] + shift * k, shift * k, color=f"C{kp}", alpha=0.2)
    plt.plot(omega, specs[k] + shift * k, color=f"C{kp}", linestyle="-", alpha=1.0)
plt.xlim(-4, 6)
plt.xlabel(r"$\omega$")
plt.ylabel(r"$A(k, \omega)$ (offset by $k$)")
plt.title("Zero-temperature spectral function")
plt.show()

In [None]:
# Finite-temperature spectral function: every basis state b (electron at k,
# phonon configuration ph) contributes its eigenstate weights, shifted by the
# phonon energy and weighted by exp(-beta * E_ph). The result is normalised by
# the phonon partition function z.
eta = 0.05
omega = np.linspace(-4, 6, 1000)
specs_t = []
omega_col = omega.reshape(-1, 1)  # loop-invariant column for broadcasting

for beta in [2.0]:
    z = 0.0
    specs = np.array([0.0 * omega for _ in range(n_sites)])
    for b in basis:
        energy_phonon = np.sum(omega_q * b[1])
        kpp, n = basis_dict[(b[0], tuple(np.array(b[1])))]
        if kpp > n_sites // 2:
            # Use the mirrored state (electron k -> -k, phonon q -> -q), whose
            # total-momentum sector is at most n_sites // 2.
            k_mirror = (1 - (b[0][0] == 0)) * (n_sites - b[0][0])
            phonons_mirror = np.concatenate([b[1][:1], np.flip(b[1][1:])])
            kpp, n = basis_dict[((k_mirror,), tuple(np.array(phonons_mirror)))]
        k = b[0][0]
        weights = np.abs(states_k[kpp][n, :]) ** 2
        denominator = (
            omega_col - energies_k[kpp].reshape(1, -1) + energy_phonon
        ) ** 2 + eta**2
        lorentzian = eta / (denominator * np.pi)
        boltzmann = np.exp(-beta * energy_phonon)
        specs[k] += np.sum(weights * lorentzian, axis=1) * boltzmann
        z += boltzmann / n_sites
    specs /= z
    specs_t.append(specs)

In [None]:
# Plot the spectra for the first beta in specs_t, stacked by k.
specs = specs_t[0]
shift = 0.8
n_k = specs.shape[0]
for k in range(n_k):
    # k and n_k - k (i.e. -k) share a colour
    kp = min(k, n_k - k)
    plt.fill_between(omega, specs[k] + shift * k, shift * k, color=f"C{kp}", alpha=0.2)
    plt.plot(omega, specs[k] + shift * k, color=f"C{kp}", linestyle="-", alpha=1.0)
plt.xlim(-4, 6)
plt.xlabel(r"$\omega$")
plt.ylabel(r"$A(k, \omega)$ (offset by $k$)")
plt.title("Finite-temperature spectral function")
plt.show()

In [None]:
@partial(jit, static_argnums=(5,))
def ham_element_kq_multiband(x, y, e_n_k, omega_q, g_mn_kq, lattice):
    """Matrix element <x|H|y> of a multi-band electron-phonon Hamiltonian in k space.

    Basis states are flat integer arrays
    ``[band, k_1, ..., k_d, n_ph_0, ..., n_ph_{N-1}]`` with ``d = len(lattice.shape)``
    electron-momentum coordinates, followed by the flattened phonon
    occupations. The flat phonon index of a momentum is ``lattice.get_site_num``.

    Args:
    - x, y: flat basis states (bra and ket)
    - e_n_k: band energies, indexed ``e_n_k[band, flat k index]``
    - omega_q: phonon frequencies, one per flat phonon momentum
    - g_mn_kq: couplings, indexed
      ``g_mn_kq[band_x, band_y, flat k_y, flat (k_x - k_y) mod shape]``
    - lattice: static, hashable object providing ``shape`` and ``get_site_num``

    Returns:
    - the diagonal energy ``sum(omega_q * n_ph) + e_n_k[band, k]`` when x == y;
    - ``g * sqrt(n_y[q] + 1)`` when x has exactly one extra phonon, at
      q = (k_y - k_x) mod shape (emission);
    - ``g * sqrt(n_y[q])`` when x has exactly one fewer phonon, at
      q = (k_x - k_y) mod shape (absorption);
    - 0 otherwise.
    """
    lattice_ndim = len(lattice.shape)

    x_elec_n = x[0]  # band index
    x_elec = x[1 : 1 + lattice_ndim]  # electron momentum
    x_phonon = x[1 + lattice_ndim :]  # phonon occupation

    y_elec_n = y[0]  # band index
    y_elec = y[1 : 1 + lattice_ndim]  # electron momentum
    y_phonon = y[1 + lattice_ndim :]  # phonon occupation

    lattice_shape = jnp.array(lattice.shape)
    diff = x_phonon - y_phonon
    diff_sites = jnp.where(x_phonon == y_phonon, 0, 1)
    diff_count = jnp.sum(diff_sites)

    # diagonal term: band, every momentum component and every phonon
    # occupation must match. (Comparing only the sum of the momentum
    # difference would accept e.g. (0, 1) vs (1, 0) in 2D.)
    diag = (
        (x_elec_n == y_elec_n)
        * jnp.all(x_elec == y_elec)
        * (diff_count == 0)
        * (jnp.sum(omega_q * x_phonon) + e_n_k[x_elec_n, lattice.get_site_num(x_elec)])
    )

    # off diagonal terms
    elec_k_change = (x_elec - y_elec) % lattice_shape
    k_i = lattice.get_site_num(y_elec)
    q_i = lattice.get_site_num(elec_k_change)

    # emission: the single extra phonon must sit exactly at q = k_y - k_x
    q_c = (y_elec - x_elec) % lattice_shape
    q_c_i = lattice.get_site_num(q_c)
    creation_term = (
        (diff_count == 1)
        * (jnp.sum(diff) == 1)
        * (diff[q_c_i] == 1)
        * g_mn_kq[x_elec_n, y_elec_n, k_i, q_i]
        * jnp.sqrt(y_phonon[q_c_i] + 1)
    )

    # absorption: the single missing phonon must be at q = k_x - k_y
    q_d = (x_elec - y_elec) % lattice_shape
    q_d_i = lattice.get_site_num(q_d)
    annihilation_term = (
        (diff_count == 1)
        * (jnp.sum(diff) == -1)
        * (diff[q_d_i] == -1)
        * g_mn_kq[x_elec_n, y_elec_n, k_i, q_i]
        * jnp.sqrt(y_phonon[q_d_i])
    )

    return diag + creation_term + annihilation_term


@partial(jit, static_argnums=1)
def calc_momentum(phonon_occ, lattice):
    """Total phonon momentum ``sum_site n(site) * site`` (not folded into the zone).

    Args:
    - phonon_occ: occupation array indexable by the tuples in ``lattice.sites``
    - lattice: static, hashable object providing ``shape`` and ``sites``

    Returns:
    - integer array of length ``len(lattice.shape)``
    """
    site_coords = jnp.array(lattice.sites)  # (n_sites, ndim)
    # Gather each site's occupation using the coordinate columns as index arrays.
    occupations = phonon_occ[tuple(site_coords.T)]
    return jnp.sum(occupations[:, None] * site_coords, axis=0)

In [None]:
n_sites = 6
n_bands = 2
omega_0 = 1.0
g = 1.0
t = 1.0
max_n_phonons = 4
lattice = one_dimensional_chain(n_sites)

basis = lattice.make_polaron_basis_n(n_bands=n_bands, max_n_phonons=max_n_phonons)
print(f"built basis, length: {len(basis)}")

# Separate the basis into total-momentum sectors. The sector index is the flat
# site index of the folded total momentum, so each sector holds one momentum.
bases_k = [[] for _ in range(lattice.n_sites)]
# ((band, electron k), phonon occupations) -> [sector index, position within sector]
basis_dict = {}
counters = [0 for _ in range(lattice.n_sites)]

for b in basis:
    phonon_k = calc_momentum(b[1], lattice)
    band_idx = b[0][0]
    elec_k = np.array(b[0][1])
    total_k = (elec_k + phonon_k) % np.array(lattice.shape)
    flat_basis_state = np.concatenate([[band_idx], elec_k, b[1].flatten()])
    k_group = lattice.get_site_num(total_k)
    bases_k[k_group].append(flat_basis_state)
    basis_dict[(b[0], tuple(np.array(b[1])))] = [k_group, counters[k_group]]
    counters[k_group] += 1

omega_q = jnp.array([omega_0 for _ in range(n_sites)])
e_n_k = jnp.array(
    [
        [t * (-2.0 * np.cos(2.0 * np.pi * k / n_sites)) for k in range(n_sites)],
        [t * (-1.0 * np.cos(2.0 * np.pi * k / n_sites)) for k in range(n_sites)],
    ]
)
# Uniform coupling g_mn_kq[m, n, k, q] = -g / sqrt(N) for every band pair.
# (Alternative tried earlier: couple only band 0, i.e. zero every (m, n) != (0, 0) entry.)
g_mn_kq = jnp.full((n_bands, n_bands, n_sites, n_sites), -g / n_sites**0.5)

ham_element = partial(
    ham_element_kq_multiband,
    e_n_k=e_n_k,
    omega_q=omega_q,
    g_mn_kq=g_mn_kq,
    lattice=lattice,
)

states_k = []
energies_k = []

for k_i in range(lattice.n_sites):
    if len(bases_k[k_i]) == 0:
        # keep states_k / energies_k aligned with the sector index
        states_k.append(None)
        energies_k.append(None)
        continue

    basis_k = jnp.array(bases_k[k_i])

    # Nested scans evaluate ham_element for every (row, column) pair of the sector.
    def scan_inner(carry, x):
        def inner(carry, y):
            return carry, ham_element(x, y)

        _, result = lax.scan(inner, None, basis_k)
        return carry, result

    _, ham_mat = lax.scan(scan_inner, None, basis_k)
    ham_mat = ham_mat.reshape(ham_mat.shape[0], ham_mat.shape[1])

    print(f"Built Hamiltonian for k = {k_i} with shape {ham_mat.shape}")

    assert np.allclose(ham_mat, ham_mat.T.conj()), "Hamiltonian is not Hermitian"

    energies, states = jnp.linalg.eigh(ham_mat)
    print(energies[:5], "\n")

    states_k.append(states)
    energies_k.append(energies)

In [None]:
# zero T spectral function, projected on the band-0 electron with no phonons
eta = 0.05
omega = np.linspace(-4, 6, 1000)
specs = []
for k in range(n_sites):
    states = states_k[k]
    energies = energies_k[k]
    # find the band-0 state at momentum k with zero phonons (sector k)
    k_state_index = basis_dict[((0, (k,)), tuple(np.zeros(lattice.shape, dtype=int)))][
        1
    ]
    abs_states_sq = np.abs(states[k_state_index, :]) ** 2
    energy_diff = omega[:, np.newaxis] - energies[np.newaxis, :]
    spec = np.sum(abs_states_sq * eta / (energy_diff**2 + eta**2) / np.pi, axis=1)
    specs.append(spec)

specs = np.array(specs)

shift = 0.8
n_k = specs.shape[0]
for k in range(n_k):
    # k and n_k - k (i.e. -k) share a colour
    kp = min(k, n_k - k)
    plt.fill_between(omega, specs[k] + shift * k, shift * k, color=f"C{kp}", alpha=0.2)
    plt.plot(omega, specs[k] + shift * k, color=f"C{kp}", linestyle="-", alpha=1.0)
plt.xlim(-4, 6)
plt.xlabel(r"$\omega$")
plt.ylabel(r"$A_{0}(k, \omega)$ (offset by $k$)")
plt.title("Zero-temperature spectral function (two-band model, band 0)")
plt.show()