In [None]:
%reload_ext autoreload
%autoreload 2

import sqlite3
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.colors import Normalize

from GKTH.constants import kB
from GKTH.Global_Parameter import GlobalParams
from GKTH.Green_Function import GKTH_Greens
from GKTH.H_eig import GKTH_find_spectrum
from GKTH.k_space_flipping import GKTH_flipflip
from GKTH.Layer import Layer
from GKTH.self_consistency_delta import GKTH_self_consistency_1S_residual
from plotting import (
    FIGURE_SIZE,
    get_critical_current_data,
    get_critical_current_subplot,
    plot_critical_current,
    plot_for_lambda,
    plot_for_lambda_h_list,
    plot_for_lambda_zeros,
)
from script_single_layer import (
    PRESENTATION_MEDIA_DIR,
    get_contour,
    get_delta_vs_h,
    get_residuals
)

DATA_DIR = Path("data")


# Band energy spectra


In [None]:
# Band energy spectrum of a single default layer (lambda = 0) at h = 0.01, drawn as
# translucent 3D surfaces over the (kx, ky) grid and exported as SVG.
h = 0.01

# Default parameters and single default layer
p = GlobalParams()
layers = [Layer(_lambda=0)]

# kx and ky share one symmetric axis spanning +/- max|p.k1| with 2 * p.nkpoints samples.
k_max = np.max(np.abs(p.k1))
k_axis = np.linspace(-k_max, k_max, p.nkpoints * 2)
kx, ky = np.meshgrid(k_axis, k_axis)

p.h = h

# energy is 3d array, nkpoints x nkpoints x (4*nlayers);
# each k-point has its own eigenvalues list
energy = GKTH_find_spectrum(p, layers)

# Every band slice goes through GKTH_flipflip before plotting. The result is drawn on the
# 2 * nkpoints mesh above, so its shape is assumed to match that mesh (TODO confirm).
bands = [GKTH_flipflip(energy[:, :, i]) for i in range(4 * len(layers))]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

for i, band in enumerate(bands):
    # Skip bands with NaNs or infinities
    if not np.isfinite(band).all():
        print(f"Skipping band {i} due to NaN or infinity values.")
        continue
    ax.plot_surface(kx, ky, band, facecolor=f"C{i % 10}", alpha=0.5)

ax.view_init(azim=-120, elev=20)  # adjust the view angle

ax.set_xlabel("$k_x$")
ax.set_ylabel("$k_y$")
ax.set_zlabel("Energy (eV)")
ax.set_title(f"Energy Spectrum\n$h={p.h}$ eV")

fig.savefig(
    PRESENTATION_MEDIA_DIR / f"energy_spectrum_h_{p.h}.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)

In [None]:
%matplotlib inline
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

for i, band in enumerate(bands):
    ax: Axes = axes[i // 2, i % 2]
    ax.imshow(band, cmap="viridis")
    ax.set_title(f"Band {i}")
    ax.set_xlabel("$k_x$")
    ax.set_ylabel("$k_y$")

fig.tight_layout()

fig.savefig(PRESENTATION_MEDIA_DIR / "bands.svg", transparent=True, bbox_inches="tight", pad_inches=None)

# Residual-Gap Plot


## Check root finding


In [None]:
# Residual-vs-gap curves for lambda = 0.15 at a few values of h, to eyeball the root finding.
root_check_h_values = [0, 0.009, 0.012]
plot_for_lambda_h_list(h_list=root_check_h_values, _lambda=0.15)
plt.show()

In [None]:
%matplotlib inline
lambda_list = [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.15, 0.2]
for _lambda in lambda_list:
    fig = plot_for_lambda(_lambda)
    fig.savefig(
        PRESENTATION_MEDIA_DIR / f"residuals_delta_lambda_{_lambda}.svg",
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )
    plt.show()

In [None]:
def set_axes(ax: Axes):
    """Zoom a residual-vs-gap plot onto the fixed x/y window used for the lambda = 0.1 figures."""
    ax.set(xlim=(0.0, 2.0), ylim=(-0.12, 0.04))


_lambda = 0.1
# The plain residual curves and the variant with zeros marked share one zoom window.
# NOTE(review): the lambda sweep cell above also writes residuals_delta_lambda_0.1.svg
# (unzoomed); whichever cell runs last determines the file on disk.
zoomed_figures = (
    (plot_for_lambda, f"residuals_delta_lambda_{_lambda}.svg"),
    (plot_for_lambda_zeros, f"residuals_delta_lambda_{_lambda}_zeros.svg"),
)
for plot_fn, file_name in zoomed_figures:
    fig = plot_fn(_lambda)
    set_axes(fig.gca())
    fig.savefig(
        PRESENTATION_MEDIA_DIR / file_name,
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )
    plt.show()

## Gap vs h


In [None]:
# Self-consistent gap Delta_s vs h (both in meV) on linear axes for the lambdas below.
lambda_list = [0.1]

fig, ax = plt.subplots(figsize=FIGURE_SIZE)

for _lambda in lambda_list:
    h, delta = get_delta_vs_h(_lambda)
    ax.plot(
        h * 1e3,
        delta * 1e3,
        linestyle="--",
        marker="o",
        label=rf"$\lambda = {_lambda}$",
    )

ax.legend()
ax.set_xlabel("h (meV)")
ax.set_ylabel(r"$\Delta_s$ (meV)")
# Apply the y-limit BEFORE saving so the exported SVG matches the displayed figure.
ax.set_ylim(0, None)

# File name encodes the plotted lambdas ("delta_vs_h_lambda_0.1.svg" for [0.1]).
lambda_tag = "_".join(str(lam) for lam in lambda_list)
fig.savefig(
    PRESENTATION_MEDIA_DIR / f"delta_vs_h_lambda_{lambda_tag}.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)
plt.show()

In [None]:
# Gap Delta_s vs h (both in meV) for several lambdas, on log-log axes.
fig, ax = plt.subplots(figsize=FIGURE_SIZE)

lambda_list = [0.05, 0.1, 0.15, 0.2]
for _lambda in lambda_list:
    h, delta = get_delta_vs_h(_lambda)
    h_mev = h * 1e3
    delta_mev = delta * 1e3
    ax.plot(h_mev, delta_mev, linestyle="--", marker="o", label=rf"$\lambda = {_lambda}$")

ax.legend()

# Log scale for x and y axes
ax.set_xscale("log")
ax.set_yscale("log")

ax.set_xlabel("h (meV)")
ax.set_ylabel(r"$\Delta_s$ (meV)")

fig.savefig(
    PRESENTATION_MEDIA_DIR / "delta_vs_h.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)
plt.show()

# Stability


In [None]:
# Sweep a (Delta_0, h) grid for each lambda and cache every self-consistency residual in
# DATA_DIR / "residuals.db". Points already in the table are read back instead of recomputed,
# so the sweep can be interrupted and resumed.
lambda_list = [0.1, 0.15, 0.2]
h_end_list = [1e-3, 2e-2, 5e-2]
max_Delta_list = [2e-3, 2e-2, 50e-3]
N = 41

# One connection for the whole sweep, always closed (also on cache hits and on errors).
conn = sqlite3.connect(DATA_DIR / "residuals.db")
try:
    conn.execute(
        "CREATE TABLE IF NOT EXISTS residuals (lambda REAL, Delta REAL, h REAL, residual REAL)"
    )

    for _lambda, h_end, max_Delta in zip(lambda_list, h_end_list, max_Delta_list):
        # Rounded to 9 decimals so grid values are reproducible cache keys.
        Delta_lin = np.round(np.linspace(0.00, max_Delta, N), 9)
        h_lin = np.round(np.linspace(0.00, h_end, N), 9)
        Delta_mesh, h_mesh = np.meshgrid(Delta_lin, h_lin)

        for Delta, h in zip(Delta_mesh.flatten(), h_mesh.flatten()):
            key = (float(_lambda), float(Delta), float(h))
            cached = conn.execute(
                "SELECT residual FROM residuals WHERE lambda=? AND Delta=? AND h=?",
                key,
            ).fetchone()

            if cached is not None:
                residual = cached[0]
            else:
                # Point-local names so the notebook-level `p` / `layers` used by later
                # cells are not silently replaced depending on cache state.
                p_point = GlobalParams(h=h)
                layers_point = [Layer(_lambda=_lambda)]

                residual = GKTH_self_consistency_1S_residual(
                    Delta_0_fit=Delta, p=p_point, layers=layers_point, layers_to_check=[0]
                )

                conn.execute(
                    "INSERT INTO residuals VALUES (?, ?, ?, ?)",
                    (*key, float(residual)),
                )
                # Commit per point so an interrupted sweep keeps its progress.
                conn.commit()

            print(f"Delta: {Delta}, h: {h}, residual: {residual}")
finally:
    conn.close()

In [None]:
# Residual maps over (Delta_0, h) for each lambda, with the zero-residual contour split into
# stable and unstable branches. Residuals come from get_residuals (presumably the SQLite
# cache filled by the sweep cell above - confirm).
lambda_list = [0.1, 0.15, 0.2]
h_end_list = [1e-3, 2e-2, 5e-2]
max_Delta_list = [2e-3, 2e-2, 50e-3]
# Grid resolution. Defined here so this cell also runs on a fresh kernel without the sweep;
# it must match the sweep cell's N for the cached residuals to line up with the grid.
N = 41


def _stability_midpoint(zeros_x: np.ndarray, zeros_y: np.ndarray) -> int:
    """Return the index splitting the sorted zero-residual contour into two branches.

    The caller draws points at indices >= the returned value as "Stable" and indices
    <= it as "Unstable" (the split point is shared so the two curves join).

    Args:
        zeros_x: contour x coordinates (Delta_0 in meV), sorted ascending, >= 2 points.
        zeros_y: matching contour y coordinates (h in meV).
    """
    zeros_grad = np.gradient(zeros_y, zeros_x)
    # Contour points where h does not increase with Delta_0.
    (neg_grad_idxs,) = np.where(zeros_grad <= 0)
    if neg_grad_idxs.size == 0:
        # h increases along the whole contour: every point is drawn as unstable.
        return len(zeros_x)

    # Pick the first non-increasing point that is followed by an index jump larger than
    # 90% of the number of non-increasing points. With no such jump (e.g. one contiguous
    # run), argmax over all-False is 0 and the first non-increasing point is used.
    # NOTE(review): this is not a "more than 90% of subsequent gradients are negative" rule;
    # a stray early non-increasing point followed by the big jump is returned instead of the
    # start of the run after it - confirm this is the intended split.
    big_jumps = np.diff(neg_grad_idxs) > 0.9 * len(neg_grad_idxs)
    first_jump = int(np.argmax(big_jumps)) if big_jumps.size else 0
    return int(neg_grad_idxs[first_jump])


for _lambda, h_end, max_Delta in zip(lambda_list, h_end_list, max_Delta_list):
    Delta_lin = np.round(np.linspace(0.00, max_Delta, N), 9)
    h_lin = np.round(np.linspace(0.00, h_end, N), 9)
    Delta_mesh, h_mesh = np.meshgrid(Delta_lin, h_lin)

    # Query residuals for every grid point
    Delta_list = Delta_mesh.flatten()
    h_list = h_mesh.flatten()
    residual_list = get_residuals(_lambda, Delta_list, h_list)

    residual_mesh = np.array(residual_list).reshape(Delta_mesh.shape)
    residual_mesh_mev = residual_mesh * 1e3

    # Symmetric colour range covering the largest |residual| of either sign, so the
    # diverging map is white at zero and neither side is clipped.
    bound = np.nanmax(np.abs(residual_mesh_mev))
    normalize = Normalize(vmin=-bound, vmax=bound)

    plt.figure(figsize=(FIGURE_SIZE[0], FIGURE_SIZE[1] * 3 / 4))
    plt.xlabel(r"$\Delta_0$ (meV)")
    plt.ylabel("h (meV)")
    plt.title(rf"Residuals for $\lambda$={_lambda}")

    Delta_mesh_mev = Delta_mesh * 1e3
    h_mesh_mev = h_mesh * 1e3
    plt.contourf(
        Delta_mesh_mev,
        h_mesh_mev,
        residual_mesh_mev,
        cmap="bwr",
        norm=normalize,
        levels=100,
    )
    plt.colorbar(label="Residual (meV)")

    plt.savefig(
        PRESENTATION_MEDIA_DIR / f"residuals_contourf_{_lambda}.svg",
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )

    zeros_x, zeros_y = get_contour(Delta_mesh_mev, h_mesh_mev, residual_mesh_mev, 0.0)

    # Remove ill-defined points at x = 0
    nonzero_x = zeros_x != 0
    zeros_x, zeros_y = zeros_x[nonzero_x], zeros_y[nonzero_x]

    # Sort by x coord
    sort_idxs = np.argsort(zeros_x)
    zeros_x, zeros_y = zeros_x[sort_idxs], zeros_y[sort_idxs]

    # np.gradient needs at least two points; otherwise save the map without the overlay.
    if len(zeros_x) >= 2:
        midpoint = _stability_midpoint(zeros_x, zeros_y)
        plt.plot(zeros_x[midpoint:], zeros_y[midpoint:], color="k", label="Stable")
        plt.plot(
            zeros_x[: midpoint + 1],
            zeros_y[: midpoint + 1],
            color="k",
            linestyle="--",
            label="Unstable",
        )
        plt.legend()

    plt.savefig(
        PRESENTATION_MEDIA_DIR / f"residuals_contourf_{_lambda}_with_stability.svg",
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )

    plt.show()

# Bilayer


In [None]:
# Marker colour per pairing symmetry, and marker style per layer index (0 filled, 1 hollow).
color_dict = {"s": "C0", "d": "C1"}
marker_kwargs_dict = {0: {}, 1: {"facecolor": "none"}}
marker_size = 10


def plot_column(axes: np.ndarray[Axes], database: str, col: int):
    """Scatter both layers' gaps vs temperature into column `col` of `axes`.

    Reads DATA_DIR / "ss_bilayer" / f"{database}.db" (table `ss_bilayer`). Each distinct
    `tunneling` value goes to its own row, in ascending order.

    Args:
        axes: 2D array of Axes, indexed [row, col]; needs one row per tunneling value.
        database: database stem whose first two letters give the pairing of layer 0 and
            layer 1 (e.g. "sd_bilayer" -> layer 0 "s", layer 1 "d").
        col: column of `axes` to draw into.

    For each row, columns Ds_0 / Ds_1 are scaled by 1e3 and plotted against
    temperature / kB. The x limits are fixed to (0, 12) and the y limits to (0, 3).
    """
    conn = sqlite3.connect(DATA_DIR / "ss_bilayer" / f"{database}.db")
    try:
        query = "SELECT temperature, tunneling, Ds_0, Ds_1 FROM ss_bilayer"
        df = pd.read_sql_query(query, conn)
    finally:
        conn.close()

    # Sorted so row r always holds the r-th smallest tunneling, independent of the
    # order rows were inserted into the database.
    for row, tunneling in enumerate(np.sort(df["tunneling"].unique())):
        ax: Axes = axes[row, col]
        subset = df[df["tunneling"] == tunneling]
        for layer, wave in enumerate(database[:2]):
            ax.scatter(
                subset["temperature"] / kB,
                subset[f"Ds_{layer}"] * 1e3,
                color=color_dict[wave],
                **marker_kwargs_dict[layer],
                s=marker_size,
            )

        ax.set_xlim(0, 12)
        ax.set_ylim(0, 3)

In [None]:
# Gap vs temperature for the four bilayer pairings: one row per tunneling t, one column
# per database (ss, dd, sd, ds).
N_T_ROWS = 7
fig, axes = plt.subplots(
    N_T_ROWS,
    4,
    figsize=(FIGURE_SIZE[0], FIGURE_SIZE[1] * N_T_ROWS / 4),
    sharex="all",
    sharey="all",
    gridspec_kw={"wspace": 0, "hspace": 0},
)
for col, database in enumerate(["ss_bilayer", "dd_bilayer", "sd_bilayer", "ds_bilayer"]):
    plot_column(axes, database, col)

# Shared axis labels, in figure coordinates
fig.text(0.04, 0.5, r"$\Delta_S$ (meV)", va="center", rotation="vertical", fontsize=12)
fig.text(0.5, 0.04, "Temperature (K)", ha="center", fontsize=12)
fig.tight_layout(rect=[0.06, 0.06, 1, 1])

# Row labels: the tunneling value of each row, just right of its rightmost panel
t_values = [0.0, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0]
for t_val, row_axes in zip(t_values, axes):
    right_ax: Axes = row_axes[-1]
    right_ax.text(
        1.05,
        0.5,
        str(t_val),
        transform=right_ax.transAxes,
        rotation=0,
        va="center",
        ha="left",
    )

fig.text(1.03, 1.03, r"$t$ (meV)", va="top", ha="right", fontsize=12)

# Legend proxies: empty scatters on the bottom-right panel, one per (wave, layer style)
legend_ax: Axes = axes[-1, -1]
legend_handles = [
    legend_ax.scatter(
        [],
        [],
        color=color_dict[wave],
        **marker_kwargs_dict[j],
        s=marker_size,
        label=f"{wave}-wave, {'high' if j == 0 else 'low'} $T_c$",
    )
    for wave in ["s", "d"]
    for j in range(2)
]

fig.legend(
    handles=legend_handles,
    loc="upper center",
    ncol=4,
    fontsize=11,
    bbox_to_anchor=(0.5, 1.04),
    handletextpad=0.2,  # Reduce spacing between legend markers and text
    columnspacing=0.8,  # Reduce spacing between legend columns
)

fig.savefig(
    PRESENTATION_MEDIA_DIR / "ds_vs_temperature.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)

plt.show()

# Matsubara Frequencies


In [None]:
%matplotlib inline
# Evaluate GKTH_Greens on the notebook-level `p` / `layers` and keep its three outputs
# (matsubara_freqs, ksums, F_kresolved_final) for the plotting cell below.
# NOTE(review): `p` and `layers` are whatever an earlier cell last bound. The band-spectrum
# cell sets them to default GlobalParams with h = 0.01 and a single lambda = 0 layer, but
# other cells above may rebind them - confirm which parameter set is intended here.
matsubara_freqs, ksums, F_kresolved_final = GKTH_Greens(p, layers, verbose = True)

In [None]:
# Real part of the first slice (index 0 on axis 0) of F_kresolved_final as a colour map
fig, ax = plt.subplots()
mesh = ax.pcolormesh(np.real(F_kresolved_final[0, ...]))
fig.colorbar(mesh, ax=ax)
ax.set_title("Real part of F_kresolved_final")
plt.show()

# |ksums| at each Matsubara frequency, restricted to the 0..5200 window
fig, ax = plt.subplots()
ax.scatter(matsubara_freqs, np.abs(ksums))
ax.set_xlim(0, 5200)
ax.set_title("Matsubara Frequencies vs |ksums|")
ax.set_xlabel("Matsubara Frequencies")
ax.set_ylabel("|ksums|")
plt.show()

# Junction


In [None]:
# Critical-current figure for the S1/N/S2 junction, second argument 0.5e-3 (presumably the
# tunneling t in eV, as in the next cell), exported for the presentation.
# The same call also works for the other junctions, e.g. "S1_N_S1" and "S2_N_S2".
fig = plot_critical_current("S1_N_S2", 0.5e-3)
critical_current_svg = PRESENTATION_MEDIA_DIR / "critical_current.svg"
fig.savefig(critical_current_svg, transparent=True, bbox_inches="tight", pad_inches=None)

In [None]:
# Critical current density j_c vs temperature for the S1/N/S2 junction at several
# tunneling strengths t.
fig = plt.figure(figsize=FIGURE_SIZE)
ax = fig.add_subplot(111)

for t_mev in (0.25, 0.5, 1.0):
    # get_critical_current_data takes t in eV; the legend shows it in meV.
    t = t_mev * 1e-3
    df = get_critical_current_data("S1_N_S2", t)
    ax.plot(df["temperature"], df["jc"], label=f"$t = {t_mev:.2f}$ meV")

ax.legend()
ax.set_xlim(0, 12)
ax.set_xlabel("Temperature (K)")
ax.set_ylabel(r"$j_c$ $(M A\ m^{-2})$")
# End on show() so the cell does not echo the Text repr returned by set_ylabel.
plt.show()

# Dump


In [None]:
# Archived experiment (kept commented out): delta_best_fit vs tNN for a range of ts values.
# NOTE(review): it references `Nb` and `GKTH_self_consistency_1S`, neither of which is
# defined or imported in this notebook as shown. Restore them before uncommenting.
# fig = plt.figure()
# ts_list = []
# for i in range(10):
#     ts = 0.9 + i/50
#     p.ts = np.zeros(100) + ts
#     delta_list = []
#     tNN_list = []
#     for j in range(10):
#         print("current iteration:", i, j)
#         ts_list.append(ts)
#         tNN = -1 - j/20
#         Nb.tNN = tNN
#         Nb.tNNN = Nb.tNN * 0.1
#         tNN_list.append(tNN)
#         delta,_, _ = GKTH_self_consistency_1S(p,layers)
#         delta_list.append(delta)
#     plt.plot(tNN_list, delta_list, label = f"{ts:.2f}")

# plt.xlabel("tNN (eV)")
# plt.ylabel("delta_best_fit (eV)")
# plt.legend()