In [None]:
"""Reproduce Fig6 of [Phys. Status Solidi B 258, 2000081 (2021)]."""


In [None]:
import numpy as np
from scipy.constants import pi
import scipy.linalg as la
import matplotlib
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt

# Print arrays in full on one line: no summarization and no line wrapping.
np.set_printoptions(threshold=np.inf, linewidth=np.inf)

# Use STIX fonts for regular text, the serif family and mathtext.
matplotlib.rcParams.update(
    {
        "font.family": "STIXGeneral",
        "font.serif": "STIXGeneral",
        "mathtext.fontset": "stix",
    }
)

In [None]:
# Pauli matrices for the sublattice (pseudospin) degree of freedom.
sigma_0 = np.eye(2, dtype=int)
sigma_x = np.array([[0, 1], [1, 0]])
sigma_y = np.array([[0, -1j], [1j, 0]])
sigma_z = np.diag([1, -1])
# sigma_plus = (sigma_z + sigma_0) / 2 = diag(1, 0)  -> selects sublattice A
# sigma_minus = (sigma_z - sigma_0) / 2 = diag(0, -1) -> selects sublattice B (negative sign)
sigma_plus = (sigma_z + sigma_0) / 2
sigma_minus = (sigma_z - sigma_0) / 2

# Pauli matrices for the real spin. They are numerically identical to the
# sigma_* matrices but kept as independent copies so H() can name the spin
# factor separately from the pseudospin factor.
s_0, s_x, s_y, s_z = (m.copy() for m in (sigma_0, sigma_x, sigma_y, sigma_z))

In [None]:
def H(kx, ky, **kwargs):
    """Build and diagonalize the kp Hamiltonian of graphene at (kx, ky).

    This kp Hamiltonian takes the Dirac point as the origin, namely K = [0, 0],
    so (kx, ky) is used directly without subtracting the K point.

    The paper gives the model Hamiltonian in the basis
    |Psi_A, ↑>, |Psi_A, ↓>, |Psi_B, ↑>, |Psi_B, ↓>. Every term is therefore
    written as np.kron(pseudospin part (sigma_*), spin part (s_*)).

    Parameters:
        kx, ky (float): momentum relative to K (nm^-1 with the module-level
            parameters).
        **kwargs: model parameters hbar, vF, a, Delta, lambda_A, lambda_B,
            lambda_R, lambda_A_PIA, lambda_B_PIA, E_D, tau, and the matrix
            size nbnd (must equal 4 to match the Kronecker products).

    Returns:
        eigenvalues (array): shape (nbnd,), ordered from the LARGEST to the
            smallest, so index 0 is the topmost band.
        eigenvectors (array): column j is the eigenvector of eigenvalues[j].
    """
    vF = kwargs["vF"]
    Delta = kwargs["Delta"]
    lambda_A = kwargs["lambda_A"]
    lambda_B = kwargs["lambda_B"]
    lambda_R = kwargs["lambda_R"]
    lambda_A_PIA = kwargs["lambda_A_PIA"]
    lambda_B_PIA = kwargs["lambda_B_PIA"]
    E_D = kwargs["E_D"]
    tau = kwargs["tau"]
    nbnd = kwargs["nbnd"]

    # Orbital term, linear in k and diagonal (s_0) in spin.
    H_0 = kwargs["hbar"] * vF * np.kron((tau * kx * sigma_x - ky * sigma_y), s_0)
    # Staggered sublattice potential (sigma_z), diagonal in spin.
    H_Delta = Delta * np.kron(sigma_z, s_0)
    # Sublattice-resolved spin-orbit term proportional to s_z.
    H_I = tau * np.kron((lambda_A * sigma_plus + lambda_B * sigma_minus), s_z)
    # lambda_R term: couples sigma_x with s_y and sigma_y with s_x.
    H_R = -lambda_R * (tau * np.kron(sigma_x, s_y) + np.kron(sigma_y, s_x))
    # PIA term: linear in k and scaled by the lattice constant a.
    H_PIA = kwargs["a"] * np.kron(
        (lambda_A_PIA * sigma_plus - lambda_B_PIA * sigma_minus), (kx * s_y - ky * s_x)
    )
    # Constant energy offset of the Dirac point.
    H_E_D = E_D * np.eye(nbnd)

    hamiltonian = H_0 + H_Delta + H_I + H_R + H_PIA + H_E_D

    eigenvalues, eigenvectors = la.eigh(hamiltonian)

    # eigh returns eigenvalues in ascending order; argsort()[::-1] REVERSES that
    # to descending order (largest first), which the plotting code relies on
    # (band 0 = CB_2 ... band 3 = VB_1). Eigenvector columns are permuted with
    # the same index so each eigenvalue keeps its eigenvector.
    idx = eigenvalues.argsort()[::-1]
    eigenvalues = eigenvalues[idx]
    eigenvectors = eigenvectors[:, idx]

    return eigenvalues, eigenvectors

In [None]:
def generate_points(**kwargs):
    """
    Generate 2D points evenly spaced along a straight line through the origin.

    The line makes an angle ``angle`` (degrees) with the kx axis. Its end points
    are the points with kx = xmin and kx = xmax, i.e.
    (xmin, tan(angle) * xmin) and (xmax, tan(angle) * xmax). M-K-G lies on one
    such straight line if you consider the supercell of the BZ, with K at the
    origin.

    Keyword Args:
        xmin (float): x coordinate of the starting point of the path.
        xmax (float): x coordinate of the ending point of the path.
        num_points (int): number of points to generate (both ends included).
        angle (float): angle of graphene being rotated (twisted), in degrees.
            For Sb2Te3/Gr, graphene is rotated by 60 deg for compensation.

    Returns:
        list: ``num_points`` tuples (kx, ky) running from start to end.

    Note:
        The path is parametrized by kx, so angle must not be +-90 deg
        (tan diverges there).
    """
    theta = np.deg2rad(kwargs["angle"])
    xmin = kwargs["xmin"]
    xmax = kwargs["xmax"]
    start = np.array([xmin, np.tan(theta) * xmin])
    end = np.array([xmax, np.tan(theta) * xmax])
    # Interpolate between the two end points directly. Stepping along
    # (cos(theta), sin(theta)) from `start` walks away from `end` whenever
    # cos(theta) < 0, and distance / (num_points - 1) divides by zero for a
    # single point. linspace handles both, and its last point is exactly `end`.
    path = np.linspace(start, end, kwargs["num_points"])
    return [tuple(point) for point in path]

In [None]:
def calc_bands(**kwargs):
    """Diagonalize H at every point of the k-path.

    Returns:
        kpath (array): signed |k| of each point (negative where kx < 0), so the
            path is symmetric about the [0, 0] point; shape (num_points,).
        energies (array): eigenvalues returned by H for each k-point (rows)
            and band (columns); shape (num_points, nbnd).
    """
    n_k = kwargs["num_points"]
    coords = np.asarray(generate_points(**kwargs), dtype=float).reshape(-1, 2)
    kx_all = coords[:, 0]
    ky_all = coords[:, 1]

    # Signed distance from the origin: negative on the kx < 0 half of the path.
    radius = np.sqrt(kx_all**2 + ky_all**2)
    kpath = np.where(kx_all < 0, -radius, radius)

    energies = np.zeros((n_k, kwargs["nbnd"]))
    for row, (kx, ky) in enumerate(zip(kx_all, ky_all)):
        eigvals, _ = H(kx, ky, **kwargs)
        energies[row, :] = eigvals

    return kpath, energies

In [None]:
def calc_spins(operator, **kwargs):
    """Calculate spin expectation values of all bands along the k-path.

    Parameters:
        operator (array): 2x2 spin Pauli matrix (e.g. s_x, s_y, s_z). It is
            lifted to kron(eye(2), operator) so it acts only on the spin factor
            of the basis |A↑>, |A↓>, |B↑>, |B↓> used by H().
    Returns:
        kpath (array): signed |k| of each point (negative where kx < 0);
            shape (num_points,).
        spins (array): 0.5 * <psi_n(k)| kron(eye(2), operator) |psi_n(k)> for
            every k-point (rows) and band n (columns); shape
            (num_points, nbnd). Values lie in [-0.5, 0.5], i.e. spin in units
            of hbar (+-0.5 = fully polarized).
    """
    num_points = kwargs["num_points"]
    points = generate_points(**kwargs)
    # Loop-invariant: build the 4x4 spin operator once, not per band and k-point.
    spin_operator = np.kron(np.eye(2), operator)

    spins = np.zeros((num_points, kwargs["nbnd"]))
    kpath = np.zeros(num_points)
    for i in range(num_points):
        kx, ky = points[i]
        # Signed |k| so the path is symmetric about the [0, 0] point.
        kpath[i] = np.sqrt(kx**2 + ky**2) * -1 if kx < 0 else np.sqrt(kx**2 + ky**2)

        _, eigenvectors = H(kx, ky, **kwargs)
        # Column-wise <v_n|S|v_n> = sum_a conj(v_n[a]) * (S v_n)[a]; it is
        # real up to rounding because S is Hermitian.
        expectation = np.sum(eigenvectors.conj() * (spin_operator @ eigenvectors), axis=0)
        spins[i, :] = 0.5 * expectation.real

    return kpath, spins

In [None]:
def plot_bands(**kwargs):
    """Plot the band energies along the k-path as plain lines.

    Labels and limits of the main axes are set BEFORE plot_ticks(). That
    function creates a twiny() axes, which becomes the current pyplot axes and
    has its y-axis hidden. A plt.ylabel() issued afterwards would land on that
    hidden axis and never be shown.
    """
    kpath, energies = calc_bands(**kwargs)

    for j in range(kwargs["nbnd"]):
        plt.plot(kpath, energies[:, j])

    plt.tick_params(axis="x", which="both", top=False, direction="in")
    plt.tick_params(axis="y", which="both", top=False, direction="in")
    plt.hlines(0, plt.xlim()[0], plt.xlim()[1], lw=0.5, colors="gray")
    plt.ylabel(r"$\mathrm{E}-\mathrm{E}_\mathrm{F}$ (meV)", fontsize=11)
    plt.ylim(-12, 8)
    # Must come last: it adds a twin axes and makes it current.
    plot_ticks(**kwargs)

In [None]:
def plot_spinz_bands(**kwargs):
    """Plot the bands as scatter points colored by their s_z expectation value.

    Colors use the "coolwarm" map on the fixed range [-0.5, 0.5], which bounds
    the values produced by calc_spins().

    Returns:
        The scatter collection of the last band. All bands share the same cmap
        and color limits, so it can serve as the mappable for plt.colorbar().
        None if nbnd is 0.
    """
    kpath, energies = calc_bands(**kwargs)
    _, spins = calc_spins(operator=s_z, **kwargs)

    mappable = None
    for i in range(kwargs["nbnd"]):
        # Fixed color scale: max spin is 0.5 acc. to the left figures in Fig6.
        mappable = plt.scatter(
            kpath,
            energies[:, i],
            c=spins[:, i],
            s=1,
            cmap="coolwarm",
            vmin=-0.5,
            vmax=0.5,
        )

    plt.tick_params(axis="x", which="both", top=False, direction="in")
    plt.tick_params(axis="y", which="both", top=False, direction="in")
    plt.hlines(0, plt.xlim()[0], plt.xlim()[1], lw=0.5, colors="gray")
    plt.ylabel(r"$\mathrm{E}-\mathrm{E}_\mathrm{F}$ [meV]")
    plt.ylim(-12, 8)
    # Keep last: plot_ticks() adds a twin axes and makes it current.
    plot_ticks(**kwargs)

    return mappable

In [None]:
def plot_spins(operator, ib, color, **kwargs):
    """Draw the spin expectation of band ``ib`` versus k on the current axes.

    Parameters:
        operator (array): 2x2 spin Pauli matrix passed to calc_spins().
        ib (int): band index (column of the spins array).
        color: matplotlib color of the line.
    """
    kpath, spin_table = calc_spins(operator, **kwargs)
    band_spin = spin_table[:, ib]
    plt.plot(kpath, band_spin, c=color)
    # Spin values are bounded by +-0.5.
    plt.ylim(-0.5, 0.5)

In [None]:
def plot_split(**kwargs):
    """Plot the energy splitting of the conduction and valence band pairs.

    H() returns eigenvalues in descending order. Bands 0/1 are therefore the
    conduction pair and bands 2/3 the valence pair, and both differences
    plotted here are non-negative.
    """
    kpath, energies = calc_bands(**kwargs)
    # Raw strings: "\D" and "\m" are invalid escape sequences in ordinary
    # literals (SyntaxWarning since Python 3.12). The label text is unchanged.
    cb_label = r"$\Delta$E$_\mathrm{CB}$"
    vb_label = r"$\Delta$E$_\mathrm{VB}$"
    plt.plot(kpath, energies[:, 0] - energies[:, 1], c="r", label=cb_label)
    plt.plot(kpath, energies[:, 2] - energies[:, 3], c="b", label=vb_label)
    plt.legend(fontsize=9)
    plt.ylim(0.3, 0.6)
    plt.ylabel("splitting [meV]")
    plot_ticks(**kwargs)

In [None]:
def plot_ticks(**kwargs):
    """Format the current axes as a k-path panel.

    The bottom axis spans [xmin, xmax], has no ticks and carries the
    "M <- K -> Gamma" path label. A twiny() axes sharing the same x range
    carries the numeric k ticks and the "k [nm^-1]" label on top. The figure
    layout is then tightened.
    """
    k_range = (kwargs["xmin"], kwargs["xmax"])
    gap = 20 * " "  # hand-tuned padding that spreads M / K / Gamma across the label
    path_label = "".join(
        ["M", gap, r"$\leftarrow$", "K", r"$\rightarrow$", gap, r"$\Gamma$"]
    )

    bottom = plt.gca()
    bottom.set_xlim(*k_range)
    bottom.set_xlabel(path_label)
    bottom.set_xticks([])
    bottom.tick_params(axis="both", which="both", top=False, bottom=False, direction="in")

    top = bottom.twiny()
    top.set_xlim(*k_range)
    top.set_xlabel(r"k [nm$^{-1}$]")
    top.tick_params(axis="both", which="both", direction="in")

    plt.tight_layout()

In [None]:
def plot_figure(**kwargs):
    """Plot the figure laid out as Fig6.

    Layout (7 x 3 GridSpec):
      * left column, rows 0/2/4/6: <s_x> (black), <s_y> (red) and <s_z>
        (blue) of bands 0..3, labelled CB_2, CB_1, VB_2, VB_1;
      * right column, rows 0-4: bands colored by <s_z>;
      * right column, row 6: CB/VB energy splitting;
      * narrow middle column, rows 0-4: axis-less spacer that hosts the colorbar.
    """
    fig = plt.figure(figsize=(6, 6), dpi=144)
    # Relative sizes: a = panel rows/columns, b = gap rows, c = middle column.
    a = 1
    b = 1
    c = 0.5
    gs = GridSpec(
        7,
        3,
        width_ratios=(a, c, a),
        height_ratios=(a, b, a, b, a, b, a),
        hspace=0,
        wspace=0,
        figure=fig,
    )

    ylabels = [r"CB$_2$", r"CB$_1$", r"VB$_2$", r"VB$_1$"]
    for i in range(kwargs["nbnd"]):
        ax = plt.subplot(gs[i * 2, 0])
        plot_spins(operator=s_x, ib=i, color="k", **kwargs)
        plot_spins(operator=s_y, ib=i, color="r", **kwargs)
        plot_spins(operator=s_z, ib=i, color="b", **kwargs)
        ax.set_ylabel(ylabels[i])
        plot_ticks(**kwargs)

    plt.subplot(gs[0:5, 2])
    # If this is None, plt.colorbar falls back to the current image (plt.gci()),
    # i.e. the last scatter drawn in the band panel.
    mappable = plot_spinz_bands(**kwargs)

    cbar_host = plt.subplot(gs[0:5, 1])
    cbar_host.axis("off")
    # Pass ax= explicitly. Without it, Matplotlib >= 3.6 steals space from the
    # mappable's Axes (the band panel) instead of the current Axes, which moves
    # the colorbar out of this spacer column.
    cbar = plt.colorbar(
        mappable,
        ax=cbar_host,
        orientation="vertical",
        fraction=0.1,
        aspect=10,
        shrink=0.2,
        anchor=(-2, 0.5),
    )
    cbar.ax.tick_params(direction="in")

    plt.subplot(gs[6:, 2])
    plot_split(**kwargs)

In [None]:
# Physical and numerical parameters, passed to every function as **kwargs.
# angle=60 because the M-K-Gamma line runs along this direction of the BZ.
# The length unit is nm and the energy unit is meV. If you change units, update
# the unit-carrying constants (a, vF, hbar, xmin, xmax) consistently.
kwargs = {
    "a": 2.486e-1,  # lattice constant in nm
    "vF": 8.119e5 * 1e9,  # Fermi velocity: m/s -> nm/s
    "hbar": 6.582119569e-16 * 1e3,  # hbar: eV*s -> meV*s
    "Delta": 0.2e-3,  # in meV
    "lambda_R": -0.221,  # in meV
    "lambda_A": 0.147,  # in meV
    "lambda_B": -0.139,  # in meV
    "lambda_A_PIA": 2.623,  # in meV
    "lambda_B_PIA": 1.177,  # in meV
    "E_D": -2,  # in meV
    "tau": 1,
    "angle": 60,  # in degrees
    "xmin": -2e-2,  # in nm^-1
    "xmax": 2e-2,  # in nm^-1
    "num_points": 1001,  # dimensionless
    "nbnd": 4,  # dimensionless
}

In [None]:
def main():
    """Draw the Fig6 reproduction with the module-level parameters and show it."""
    params = kwargs
    plot_figure(**params)
    plt.show()

In [None]:
if __name__ == "__main__":
    # Entry point: build and display the figure.
    main()