In [1]:
import numpy as np
import numpy.linalg as nla
import numpy.random as random
import scipy.linalg as la
import scipy.sparse as sp
from matplotlib import pyplot as plt
from itertools import product
import time
from tqdm import tqdm
from IPython.display import clear_output
from numpy import fft

In [None]:
# ------------------------------------------------------------
# Majorana Wannier function constructor

def majorana_wannier_function(mu, R, N, nshell=None):
    """Build the two real Majorana Wannier components centred on site ``R``.

    On an N-point periodic k-grid the vector
    d(k) = (0, -2 sin k, -(2 cos k + mu)) defines the lower-band projector
    P_-(k) = (1 - d_hat(k) . sigma) / 2. The fixed trial spinor
    tau = (1, 1)/sqrt(2) is projected with P_-(k) and normalised at every k.
    An inverse FFT with the phase exp(-i k R) then places the Wannier spinor
    chi at site R. The returned components are Re(chi @ (sigma_x + sigma_z)),
    each column normalised to unit Euclidean norm.

    Parameters
    ----------
    mu : float
        Parameter entering d_z(k) = -(2 cos k + mu). By its name this is
        presumably the chemical potential.
    R : int
        Lattice site on which the function is centred (periodic, mod N).
    N : int
        Number of lattice sites (equal to the number of k-points).
    nshell : int, optional
        If given, keep only sites whose periodic distance from ``R`` is at
        most ``nshell``. These are the sites R-nshell .. R+nshell, wrapping
        around the ring, returned in ascending lattice-index order. Each
        column is renormalised over the kept sites.

    Returns
    -------
    numpy.ndarray
        Real array of shape (N, 2), or (number of kept sites, 2) when
        ``nshell`` is given. Each column has unit norm.
    """
    # k-points in radians per lattice spacing, in FFT ordering.
    k = 2 * np.pi * np.fft.fftfreq(N, d=1.0)

    dy = -2 * np.sin(k)
    dz = -(2 * np.cos(k) + mu)
    dmag = np.sqrt(dy**2 + dz**2)
    # Where d(k) vanishes exactly, avoid 0/0: d_hat becomes 0 and P_- = Id/2.
    dmag = np.where(dmag == 0, 1e-15, dmag)

    pauli_x = np.array([[0, 1], [1, 0]], dtype=complex)
    pauli_y = np.array([[0, -1j], [1j, 0]], dtype=complex)
    pauli_z = np.array([[1, 0], [0, -1]], dtype=complex)
    Id = np.eye(2, dtype=complex)

    # Per-k lower-band projectors, shape (N, 2, 2).
    dk = (dy[:, None, None] * pauli_y + dz[:, None, None] * pauli_z) / dmag[:, None, None]
    Pminus = 0.5 * (Id - dk)

    tau = (1 / np.sqrt(2)) * np.array([1, 1], dtype=complex)

    # Projected trial spinor normalised at every k, shape (N, 2).
    alpha_k = Pminus @ tau
    alpha_k /= np.linalg.norm(alpha_k, axis=1, keepdims=True)

    # Multiplying by exp(-i k R) before the inverse FFT shifts chi(r) -> chi(r - R).
    phase_shift = np.exp(-1j * k * R)
    chi_R = np.fft.ifft(phase_shift[:, None] * alpha_k, axis=0)

    M = pauli_x + pauli_z
    Pi_R = np.real(chi_R @ M)

    if nshell is not None:
        # Symmetric distance on the ring: a site at R+d and a site at R-d
        # are both at distance d. A one-sided (R - r) % N would drop the
        # R+1 .. R+nshell sites.
        offset = (np.arange(N) - R) % N
        dist = np.minimum(offset, N - offset)
        Pi_R = Pi_R[dist <= int(nshell)]

    # Normalise each column independently.
    Pi_R = Pi_R / np.linalg.norm(Pi_R, axis=0)

    return Pi_R


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_majorana_pi_components(Pi_R):
    """
    Plot the two Pi components of a Majorana Wannier function vs lattice site.

    Rows of ``Pi_R`` are taken to be lattice sites 0..N-1 of a periodic ring.
    The rows are fft-shifted so that site 0 is drawn in the middle. The
    x-axis then runs from -(N // 2) to N - N // 2 - 1, so sites past the
    midpoint appear at their wrapped (negative) positions.

    The real part is always plotted. The imaginary part is added when
    ``Pi_R`` is complex. If ``Pi_R`` has fewer rows than the full lattice
    (e.g. a truncated shell), the x-axis labels no longer correspond to
    physical sites.

    Parameters
    ----------
    Pi_R : array_like, shape (N, 2)
        Column 0 is plotted as Pi_1, column 1 as Pi_2.

    Raises
    ------
    ValueError
        If ``Pi_R`` is not two-dimensional with exactly two columns.
    """
    Pi_R = np.asarray(Pi_R)
    if Pi_R.ndim != 2 or Pi_R.shape[1] != 2:
        raise ValueError("Pi_R must have shape (N, 2).")

    N = Pi_R.shape[0]
    Pi_shift = np.fft.fftshift(Pi_R, axes=0)
    sites = np.arange(N) - N // 2
    # Real-valued input has no imaginary part worth drawing.
    show_imag = np.iscomplexobj(Pi_R)

    components = [(r"$\Pi_1(r-R)$", Pi_shift[:, 0]), (r"$\Pi_2(r-R)$", Pi_shift[:, 1])]

    fig, axes = plt.subplots(1, 2, figsize=(11, 4), sharex=True)
    for ax, (label, data) in zip(axes, components):
        ax.plot(sites, data.real, '-o', lw=1.2, color="#3366cc", label="Re")
        if show_imag:
            ax.plot(sites, data.imag, '--s', lw=1.0, color="#dc3912", label="Im")
        ax.axhline(0.0, color="0.4", lw=0.6)
        ax.set_title(label)
        ax.set_xlabel("Lattice site")
        ax.set_ylabel("Amplitude")
        ax.grid(alpha=0.25)

    # Both panels share the same line styles, so one shared legend suffices.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=len(handles), frameon=False)
    fig.tight_layout(rect=(0, 0, 1, 0.92))
    plt.show()


In [None]:
# Build the function centred on site 2 of a 32-site ring, keeping only the
# sites within one lattice spacing of the centre, then plot its components.
mu, R, N = 4, 2, 32

Pi_R = majorana_wannier_function(mu=mu, R=R, N=N, nshell=1)
plot_majorana_pi_components(Pi_R=Pi_R)


In [None]:
# Self-overlap of each column; the constructor divides every column by its
# norm, so both values should print as 1.0 up to rounding.
for col in (0, 1):
    print(Pi_R[:, col] @ Pi_R[:, col])

In [None]:
# Overlap of Majorana Wannier functions centred on two different sites of a
# 128-site ring. R2 must differ from R1. If they were equal, each product
# below would be a self-overlap, which is exactly 1 because the constructor
# normalises every column.
mu = -1
R1 = 4
R2 = R1 + 1
N = 128
Pi_1 = majorana_wannier_function(mu, R1, N)
Pi_2 = majorana_wannier_function(mu, R2, N)
# NOTE(review): presumably expected to be close to 0 if neighbouring
# Wannier components are orthogonal. Confirm against the physics.
print(Pi_1[:, 0] @ Pi_2[:, 0])
print(Pi_1[:, 1] @ Pi_2[:, 1])