In [None]:
%reload_ext autoreload
%autoreload 2
import sqlite3
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.colors import Normalize

from Global_Parameter import GlobalParams
from Green_Function import GKTH_Greens
from H_eig import GKTH_find_spectrum
from k_space_flipping import GKTH_flipflip
from Layer import Layer
from main import (
    FIGURE_SIZE,
    PRESENTATION_MEDIA_DIR,
    drop_lambda,
    get_contour,
    get_delta_vs_h,
    get_residuals,
    plot_for_lambda,
    plot_for_lambda_h_list,
    plot_for_lambda_zeros,
    run_for_lambda,
)
from self_consistency_delta import GKTH_self_consistency_1S_residual

# Directory (relative to the working directory) that holds the SQLite cache
# "residuals.db" read and written by the stability sweep further down.
DATA_DIR = Path("data")

# Band energy spectra


In [None]:
# Compute the band energies of a single default layer and draw them as
# semi-transparent 3D surfaces over the (kx, ky) plane.
h = 0.01  # field value; reported in eV by the figure title below

# Default parameters and single default layer
p = GlobalParams()
layers = [Layer(_lambda=0)]

# Symmetric k-axis spanning [-max|k1|, +max|k1|] with 2 * nkpoints samples.
# The same axis is used for both kx and ky, so it is built only once.
max_val_kx = np.max(np.abs(p.k1))
k_axis = np.linspace(-max_val_kx, max_val_kx, p.nkpoints * 2)
kx, ky = np.meshgrid(k_axis, k_axis)

p.h = h

# energy is a 3d array, nkpoints x nkpoints x (4*nlayers):
# each k-point has its own list of eigenvalues
energy = GKTH_find_spectrum(p, layers)

# One 2D array per eigenvalue index, each passed through GKTH_flipflip.
# NOTE(review): presumably GKTH_flipflip expands the computed k-grid onto the
# full (2*nkpoints) x (2*nkpoints) mesh used for plotting -- confirm.
n_bands = 4 * len(layers)
bands = [GKTH_flipflip(energy[:, :, band_idx]) for band_idx in range(n_bands)]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

for i, band in enumerate(bands):
    # Skip bands with NaNs or infinities (single pass over the array)
    if not np.isfinite(band).all():
        print(f"Skipping band {i} due to NaN or infinity values.")
        continue

    # Cycle through the ten default colours "C0".."C9"
    ax.plot_surface(kx, ky, band, facecolor=f"C{i%10}", alpha=0.5)

ax.view_init(azim=-120, elev=20)  # adjust the view angle

ax.set_xlabel("$k_x$")
ax.set_ylabel("$k_y$")
ax.set_zlabel("Energy (eV)")
ax.set_title(f"Energy Spectrum\n$h={p.h}$ eV")

fig.savefig(
    PRESENTATION_MEDIA_DIR / f"energy_spectrum_h_{p.h}.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)

In [None]:
%matplotlib inline
# Show every band as a 2D image on a fixed 2 x 2 grid of panels
# (filled row by row, so this layout holds exactly four bands).
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

for i, band in enumerate(bands):
    row, col = divmod(i, 2)
    ax: Axes = axes[row, col]
    ax.imshow(band, cmap="viridis")
    ax.set(title=f"Band {i}", xlabel="$k_x$", ylabel="$k_y$")

fig.tight_layout()

bands_figure_path = PRESENTATION_MEDIA_DIR / "bands.svg"
fig.savefig(bands_figure_path, transparent=True, bbox_inches="tight", pad_inches=None)

# Residual-Gap Plot


## Check root finding


In [None]:
# Alternative, denser sweep over lambda (kept for reference):
# lambda_list = np.round(np.linspace(0.0, 0.2, 11), 9)

# The three lists are matched element-wise: each lambda gets its own upper
# limit for h and for Delta.
lambda_list = [0.1, 0.15, 0.2]
h_end_list = [1e-3, 2e-2, 5e-2]
max_Delta_list = [2e-3, 2e-2, 50e-3]

sweep_settings = zip(lambda_list, h_end_list, max_Delta_list)
for _lambda, h_end, max_Delta in sweep_settings:
    run_for_lambda(_lambda, h_end=h_end, delta_end=max_Delta)

In [None]:
# Plot lambda = 0.15 for a hand-picked set of h values.
selected_h_values = [0, 0.009, 0.012]
plot_for_lambda_h_list(_lambda=0.15, h_list=selected_h_values)
plt.show()

In [None]:
%matplotlib inline
# Produce and export one figure per lambda value.
lambda_list = [0.0, 0.02, 0.04, 0.05, 0.06, 0.08, 0.1, 0.15, 0.2]

# Export options shared by every figure written in this cell
svg_options = {"transparent": True, "bbox_inches": "tight", "pad_inches": None}

for _lambda in lambda_list:
    fig = plot_for_lambda(_lambda)
    target = PRESENTATION_MEDIA_DIR / f"residuals_delta_lambda_{_lambda}.svg"
    fig.savefig(target, **svg_options)
    plt.show()

In [None]:
def set_axes(ax: Axes):
    """Apply the fixed view window used for the lambda = 0.1 figures.

    Args:
        ax: Axes whose x range is set to [0.0, 2.0] and whose y range is set
            to [-0.12, 0.04]. The limits are in the plot's own data units.
    """
    x_window = (0.0, 2.0)
    y_window = (-0.12, 0.04)
    ax.set_xlim(*x_window)
    ax.set_ylim(*y_window)


# Re-draw the lambda = 0.1 figures with the fixed view window from set_axes:
# first the plain plot, then the "zeros" variant. Each is exported and shown
# before the next one is created.
_lambda = 0.1

figure_jobs = [
    (plot_for_lambda, f"residuals_delta_lambda_{_lambda}.svg"),
    (plot_for_lambda_zeros, f"residuals_delta_lambda_{_lambda}_zeros.svg"),
]

for make_figure, file_name in figure_jobs:
    fig = make_figure(_lambda)
    set_axes(fig.gca())
    fig.savefig(
        PRESENTATION_MEDIA_DIR / file_name,
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )
    plt.show()

## Gap vs h


In [None]:
# Gap versus h on linear axes for a single coupling strength.
lambda_list = [0.1]

gap_fig, gap_ax = plt.subplots(figsize=FIGURE_SIZE)
curve_style = {"linestyle": "--", "marker": "o"}

for _lambda in lambda_list:
    h, delta = get_delta_vs_h(_lambda)
    # Factor 1e3 matches the meV axis labels
    gap_ax.plot(h * 1e3, delta * 1e3, label=rf"$\lambda = {_lambda}$", **curve_style)

gap_ax.legend()
gap_ax.set_xlabel("h (meV)")
gap_ax.set_ylabel(r"$\Delta_s$ (meV)")
gap_fig.savefig(
    PRESENTATION_MEDIA_DIR / "delta_vs_h_lambda_0.1.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)
plt.show()

In [None]:
# Gap versus h for several coupling strengths, drawn on log-log axes.
loglog_fig, loglog_ax = plt.subplots(figsize=FIGURE_SIZE)
curve_style = {"linestyle": "--", "marker": "o"}

lambda_list = [0.05, 0.1, 0.15, 0.2]
for _lambda in lambda_list:
    h, delta = get_delta_vs_h(_lambda)
    # Factor 1e3 matches the meV axis labels
    loglog_ax.plot(h * 1e3, delta * 1e3, label=rf"$\lambda = {_lambda}$", **curve_style)

loglog_ax.legend()

# Log scale for x and y axes
loglog_ax.set_xscale("log")
loglog_ax.set_yscale("log")

loglog_ax.set_xlabel("h (meV)")
loglog_ax.set_ylabel(r"$\Delta_s$ (meV)")

loglog_fig.savefig(
    PRESENTATION_MEDIA_DIR / "delta_vs_h.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=None,
)
plt.show()

# Stability


In [None]:
# Sweep a (Delta, h) grid for each lambda and cache every self-consistency
# residual in DATA_DIR / "residuals.db". Grid points that are already stored
# are only printed, so re-running this cell computes just the missing ones.
lambda_list = [0.1, 0.15, 0.2]
h_end_list = [1e-3, 2e-2, 5e-2]
max_Delta_list = [2e-3, 2e-2, 50e-3]
N = 41  # grid points per axis; the contour-plot cell below reads this too

# One connection for the whole sweep. The try/finally guarantees it is closed
# on every path: cache hits, errors, and a manual kernel interrupt.
conn = sqlite3.connect(DATA_DIR / "residuals.db")
try:
    # Create table if it doesn't exist
    conn.execute(
        "CREATE TABLE IF NOT EXISTS residuals (lambda REAL, Delta REAL, h REAL, residual REAL)"
    )

    for _lambda, h_end, max_Delta in zip(lambda_list, h_end_list, max_Delta_list):
        # Rounded to 9 decimals so the same grid values are produced (and
        # matched in the database) on every run.
        Delta_lin = np.round(np.linspace(0.00, max_Delta, N), 9)
        h_lin = np.round(np.linspace(0.00, h_end, N), 9)

        Delta_mesh, h_mesh = np.meshgrid(Delta_lin, h_lin)

        for Delta, h in zip(Delta_mesh.flatten(), h_mesh.flatten()):
            # Bind plain Python floats through placeholders instead of
            # formatting the values into the SQL text.
            key = (float(_lambda), float(Delta), float(h))

            # If data is in the database, skip it
            query = conn.execute(
                "SELECT residual FROM residuals WHERE lambda=? AND Delta=? AND h=?",
                key,
            ).fetchone()
            if query is not None:
                residual = query[0]
                print(f"Delta: {Delta}, h: {h}, residual: {residual}")
                continue

            # Get result
            p = GlobalParams(h=h)
            layers = [Layer(_lambda=_lambda)]

            residual = GKTH_self_consistency_1S_residual(
                Delta_0_fit=Delta, p=p, layers=layers, layers_to_check=[0]
            )

            # Insert into database; commit per row so an interrupted sweep
            # keeps everything computed so far.
            conn.execute(
                "INSERT INTO residuals VALUES (?, ?, ?, ?)",
                (*key, float(residual)),
            )
            conn.commit()

            print(f"Delta: {Delta}, h: {h}, residual: {residual}")
finally:
    conn.close()

In [None]:
# For each lambda: draw the cached residuals as a filled contour map over the
# (Delta_0, h) grid, export it, then overlay the zero-residual contour split
# into a "Stable" and an "Unstable" branch and export it again.
# NOTE: the grid size N is the value defined in the sweep cell above.
lambda_list = [0.1, 0.15, 0.2]
h_end_list = [1e-3, 2e-2, 5e-2]
max_Delta_list = [2e-3, 2e-2, 50e-3]

for _lambda, h_end, max_Delta in zip(lambda_list, h_end_list, max_Delta_list):
    # Rebuild exactly the grid that was used when the residuals were stored
    Delta_lin = np.round(np.linspace(0.00, max_Delta, N), 9)
    h_lin = np.round(np.linspace(0.00, h_end, N), 9)
    Delta_mesh, h_mesh = np.meshgrid(Delta_lin, h_lin)

    # Query residuals from database and put them back on the grid
    residual_list = get_residuals(_lambda, Delta_mesh.flatten(), h_mesh.flatten())
    residual_mesh = np.array(residual_list).reshape(Delta_mesh.shape)

    # Factor 1e3 matches the meV labels on the axes and colour bar
    Delta_mesh_mev = Delta_mesh * 1e3
    h_mesh_mev = h_mesh * 1e3
    residual_mesh_mev = residual_mesh * 1e3

    # Colour scale symmetric about zero.
    # NOTE(review): bound is |largest residual|, so residuals below -bound lie
    # outside the normalized range -- confirm this is intended rather than
    # the largest absolute residual.
    bound = np.abs(residual_mesh_mev.max())
    normalize = Normalize(vmin=-bound, vmax=bound)

    contour_fig, contour_ax = plt.subplots(figsize=FIGURE_SIZE)
    contour_ax.set_xlabel(r"$\Delta_0$ (meV)")
    contour_ax.set_ylabel("h (meV)")
    contour_ax.set_title(rf"Residuals for $\lambda$={_lambda}")

    filled = contour_ax.contourf(
        Delta_mesh_mev,
        h_mesh_mev,
        residual_mesh_mev,
        cmap="bwr",
        norm=normalize,
        levels=100,
    )
    contour_fig.colorbar(filled, label="Residual (meV)")

    # First export: residual map only
    contour_fig.savefig(
        PRESENTATION_MEDIA_DIR / f"residuals_contourf_{_lambda}.svg",
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )

    zeros_x, zeros_y = get_contour(Delta_mesh_mev, h_mesh_mev, residual_mesh_mev, 0.0)

    # Remove ill-defined points at x = 0
    keep = zeros_x != 0
    zeros_x, zeros_y = zeros_x[keep], zeros_y[keep]

    # Sort by x coord
    order = np.argsort(zeros_x)
    zeros_x, zeros_y = zeros_x[order], zeros_y[order]

    zeros_grad = np.gradient(zeros_y, zeros_x)
    (neg_grad_idxs,) = np.where(zeros_grad <= 0)

    # Split point between the two branches. The stated intent is "the first
    # index after which more than 90% of gradients are negative".
    # NOTE(review): the expression instead picks the non-positive-gradient
    # index just before the first gap (between consecutive such indices) that
    # exceeds 0.9 * their count, or the first such index if there is no such
    # gap -- confirm this matches the intent.
    gap_is_large = np.diff(neg_grad_idxs) > 0.9 * len(neg_grad_idxs)
    midpoint = neg_grad_idxs[np.argmax(gap_is_large)]

    # The two branches share the point at `midpoint` so the curve is unbroken
    contour_ax.plot(zeros_x[midpoint:], zeros_y[midpoint:], color="k", label="Stable")
    contour_ax.plot(
        zeros_x[: midpoint + 1],
        zeros_y[: midpoint + 1],
        color="k",
        linestyle="--",
        label="Unstable",
    )
    contour_ax.legend()

    # Second export: residual map with the zero contour overlaid
    contour_fig.savefig(
        PRESENTATION_MEDIA_DIR / f"residuals_contourf_{_lambda}_with_stability.svg",
        transparent=True,
        bbox_inches="tight",
        pad_inches=None,
    )

    plt.show()

# Matsubara Frequencies


In [None]:
%matplotlib inline
# Uses whichever `p` and `layers` were most recently assigned by the cells above.
greens_output = GKTH_Greens(p, layers, verbose=True)
matsubara_freqs, ksums, F_kresolved_final = greens_output

In [None]:
# Figure 1: real part of the first slice (index 0 on the leading axis) of the
# k-resolved output.
kres_fig, kres_ax = plt.subplots()
kres_mesh = kres_ax.pcolormesh(np.real(F_kresolved_final[0, ...]))
kres_fig.colorbar(kres_mesh)
kres_ax.set_title("Real part of F_kresolved_final")
plt.show()

# Figure 2: magnitude of the k-sums against the Matsubara frequencies,
# with the x axis limited to [0, 5200].
ksum_fig, ksum_ax = plt.subplots()
ksum_ax.scatter(matsubara_freqs, np.abs(ksums))
ksum_ax.set_xlim(0, 5200)
ksum_ax.set_title("Matsubara Frequencies vs |ksums|")
ksum_ax.set_xlabel("Matsubara Frequencies")
ksum_ax.set_ylabel("|ksums|")
plt.show()

# Misc


In [None]:
# drop_lambda(0.15)
# _lambda = 0.1
# conn = sqlite3.connect(DATA_DIR / "residuals.db")
# c = conn.cursor()
# c.execute("DELETE FROM residuals WHERE lambda = ?", (_lambda,))
# conn.commit()
# conn.close()

# Dump


In [None]:
# fig = plt.figure()
# ts_list = []
# for i in range(10):
#     ts = 0.9 + i/50
#     p.ts = np.zeros(100) + ts
#     delta_list = []
#     tNN_list = []
#     for j in range(10):
#         print("current iteration:", i, j)
#         ts_list.append(ts)
#         tNN = -1 - j/20
#         Nb.tNN = tNN
#         Nb.tNNN = Nb.tNN * 0.1
#         tNN_list.append(tNN)
#         delta,_, _ = GKTH_self_consistency_1S(p,layers)
#         delta_list.append(delta)
#     plt.plot(tNN_list, delta_list, label = f"{ts:.2f}")

# plt.xlabel("tNN (eV)")
# plt.ylabel("delta_best_fit (eV)")
# plt.legend()