# Imports

In [None]:
%load_ext lab_black

import h5py
import os

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
from scipy.interpolate import interp2d
from functools import lru_cache

from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt


import sys

sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    cnn,
)

# Load Data

In [None]:
# Load TR11 (1.7 eV center energy; -1 to 100 ps) and keep a handle to it as ad_11.
ddir = r"E:\atully\arpes_data\2023_February\6eV\TR"
files = ["TR11_Ali_avg.h5"]  # 1.7 eV center energy; -1 to 100 ps

# NOTE: every loading cell re-creates these two dicts, so after all loading
# cells have run they only hold the last file loaded; ad_11 / ad_3 / ad_4 keep
# a reference to each dataset.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for file in tqdm(files):
    # data / theta / phi_or_time / energy stay bound as notebook globals after the loop.
    data, theta, phi_or_time, energy = loading_functions.load_hdf5(ddir, file)
    ARPES_DATA[file] = tr_functions.ArpesData(
        data=data, theta=theta, phi_or_time=phi_or_time, energy=energy
    )
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

ad_11 = ARPES_DATA[files[0]]

In [None]:
# Load TR3 (2.15 eV center energy; -1 to 2 ps) and keep a handle to it as ad_3.
ddir = r"E:\atully\arpes_data\2023_February\6eV\TR"
files = ["TR3_Ali_avg.h5"]  # 2.15 eV center energy; -1 to 2 ps

# NOTE: re-created here, so any previously loaded entries are discarded.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for file in tqdm(files):
    # data / theta / phi_or_time / energy stay bound as notebook globals after the loop.
    data, theta, phi_or_time, energy = loading_functions.load_hdf5(ddir, file)
    ARPES_DATA[file] = tr_functions.ArpesData(
        data=data, theta=theta, phi_or_time=phi_or_time, energy=energy
    )
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

ad_3 = ARPES_DATA[files[0]]

In [None]:
# Load TR4 (2.6 eV center energy; -1 to 1 ps) and keep a handle to it as ad_4.
ddir = r"E:\atully\arpes_data\2023_February\6eV\TR"
files = [
    "TR4_Ali_avg.h5"
]  # 2.6 eV center energy; -1 to 1 ps, same number of steps as first 2 ps of TR3

# NOTE: re-created here, so any previously loaded entries are discarded.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for file in tqdm(files):
    # data / theta / phi_or_time / energy stay bound as notebook globals after the loop.
    data, theta, phi_or_time, energy = loading_functions.load_hdf5(ddir, file)
    ARPES_DATA[file] = tr_functions.ArpesData(
        data=data, theta=theta, phi_or_time=phi_or_time, energy=energy
    )
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

ad_4 = ARPES_DATA[files[0]]

In [None]:
# Dataset to inspect in this and the next cell.
ad = ad_4

# Print the shape of each axis array, then of the data cube itself.
for k in ["energy", "theta", "phi_or_time"]:
    print(f"{k}.shape = {getattr(ad, k).shape}")
print(f"Data.shape = {ad.data.shape}")

In [None]:
# Print the min/max of each axis of `ad` as plain numbers. Formatting the values
# explicitly (rather than interpolating a tuple of NumPy scalars) avoids the
# "(np.float64(...), np.float64(...))" output NumPy >= 2 produces for tuples.
print(
    f"Delay range (mm): ({float(np.min(ad.phi_or_time))}, {float(np.max(ad.phi_or_time))})"
)
print(f"Energy range (eV): ({np.min(ad.energy):.2f}, {np.max(ad.energy):.2f})")
print(f"Theta range: ({np.min(ad.theta):.1f}, {np.max(ad.theta):.1f})")

# Analysis

## Plan
ad_11 --> 1.7 eV center energy; -1 to 100 ps; 80 steps in delay (mm), but -1 to 1 ps in 21 steps (not great time resolution...)

ad_3 --> 2.15 eV center energy; -1 to 2 ps; 62 steps in delay (mm)

ad_4 --> 2.6 eV center energy; -1 to 1 ps; 42 steps in delay (mm)

1. Cut all three datasets down to the same delay window (-1 to 1 ps --> 37.81 to 38.11 mm stage position).
2. Put them on a common delay (x) axis.

In [None]:
## Set up general parameters ##

# toggle_time = "picoseconds"
toggle_time = "mm"
time_zero = 37.96

slice_dim = "x"
slice_val = 0
int_range = 50  # integrate over all angles; if this value is more that the integration range, my get_2D_slice function will just integrate over the max range.

xlim = None
ylim = None
x_bin = 1
y_bin = 1

In [None]:
# Reduce each 3D datacube to a 2D (delay, energy, intensity) slice, in the order
# TR11, TR3, TR4. The for-loop leaves `ad` bound to ad_4 afterwards.
all_vals = []
for ad in [ad_11, ad_3, ad_4]:
    all_vals.append(
        tr_functions.slice_datacube(
            ad_dataclass=ad,
            slice_dim=slice_dim,
            slice_val=slice_val,
            int_range=int_range,
            xlim=xlim,
            # Energy indices 57 and 1007 are assumed to bound the non-padded
            # region in all three files (see the commented padding-search cells
            # further down) — re-check if the file layout changes.
            ylim=(
                ad.energy[57],
                ad.energy[1007],
            ),  # get rid of zero padding on datasets
            x_bin=x_bin,
            y_bin=y_bin,
            norm_data=False,
            plot_data=False,
        )
    )
x_11, y_11, d_11 = all_vals[0]
x_3, y_3, d_3 = all_vals[1]
x_4, y_4, d_4 = all_vals[2]

In [None]:
## Plot Data: Plotly ##

# Overlay the three un-normalized datasets. All traces share one coloraxis,
# i.e. a single common colour scale.
fig = tr_functions.default_fig()
fig.add_trace(go.Heatmap(x=x_4, y=y_4, z=d_4, coloraxis="coloraxis"))
fig.add_trace(go.Heatmap(x=x_3, y=y_3, z=d_3, coloraxis="coloraxis"))
fig.add_trace(go.Heatmap(x=x_11, y=y_11, z=d_11, coloraxis="coloraxis"))
fig.update_coloraxes(colorscale="plasma", showscale=True)  # cmin=0, cmax=1.8

# Show only the common delay window (stage position, mm).
fig.update_layout(xaxis_range=(37.81, 38.11))

# fig.write_image(r'C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\TR3&TR4&T11_plasma_plotly.png')

In [None]:
## Plot Data: MPL ##

# Same overlay with matplotlib. Each dataset gets its own vmin/vmax, so colours
# are NOT directly comparable between the three datasets.
fig, ax = plt.subplots(1)

ax.pcolormesh(x_11, y_11, d_11, shading="auto", cmap="plasma", vmin=0, vmax=2.5)
ax.pcolormesh(x_3, y_3, d_3, shading="auto", cmap="plasma", vmin=0, vmax=0.2)
ax.pcolormesh(x_4, y_4, d_4, shading="auto", cmap="plasma", vmin=0, vmax=0.1)

# Common delay window (stage position, mm).
ax.set_xlim(xmin=37.81, xmax=38.11)

# fig.savefig(r'C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\TR3&TR4&T11_plasma_mpl.png')

In [None]:
## Normalize Plots relative to backgrounds

In [None]:
# Average background level of each dataset, measured in a small window at the
# earliest delays (37.81 mm = -1 ps per the plan). The windows are chosen per
# dataset, and the result is used to put all three on a common intensity scale.
# NOTE: this cell overwrites the global xlim / ylim.

# TR4
xlim = (37.81, 37.84)
ylim = (2.3, 2.9)

tr4_bg = tr_functions.get_avg_background(x_4, y_4, d_4, xlim, ylim)
# tr4_bg = get_avg_background(x_4, y_4, d4_norm, xlim, ylim)  # check

# TR3
xlim = (37.81, 37.84)
ylim = (2.15, 2.48)

tr3_bg = tr_functions.get_avg_background(x_3, y_3, d_3, xlim, ylim)

# TR11
xlim = (37.81, 37.85)
ylim = (1.93, 2.04)

tr11_bg = tr_functions.get_avg_background(x_11, y_11, d_11, xlim, ylim)
# tr11_bg = get_avg_background(x_11, y_11, d11_norm, xlim, ylim)  # check

tr4_bg, tr3_bg, tr11_bg

In [None]:
# Rescale TR4 and TR11 so their background level matches TR3's.
# TR3 is the reference and stays unscaled.
norm_tr4_to_tr3 = tr3_bg / tr4_bg
d4_norm = d_4 * norm_tr4_to_tr3

norm_tr11_to_tr3 = tr3_bg / tr11_bg
d11_norm = d_11 * norm_tr11_to_tr3

In [None]:
# Plot Data
# Overlay after background normalization (TR4 and TR11 scaled to TR3). The
# shared coloraxis now gives a meaningful common scale.
fig = tr_functions.default_fig()
fig.add_trace(go.Heatmap(x=x_4, y=y_4, z=d4_norm, coloraxis="coloraxis"))
fig.add_trace(go.Heatmap(x=x_3, y=y_3, z=d_3, coloraxis="coloraxis"))
fig.add_trace(go.Heatmap(x=x_11, y=y_11, z=d11_norm, coloraxis="coloraxis"))
fig.update_coloraxes(colorscale="plasma", showscale=True, cmin=0, cmax=None)
fig.update_layout(xaxis_range=(37.81, 38.11))

In [None]:
# # figure out ylims to eliminate d=0 padding --> bottom
# np.where(np.isclose(y_4, 2.2528)), np.where(np.isclose(y_3, 1.8028)), np.where(
#     np.isclose(y_11, 1.35281)
# )

In [None]:
# # figure out ylims to eliminate d=0 padding --> top
# np.where(np.isclose(y_4, 2.9479)), np.where(np.isclose(y_3, 2.4979)), np.where(
#     np.isclose(y_11, 2.0479)
# ),

In [None]:
# give equivalent x axes

# Crop all three datasets to the common delay window (-1 to 1 ps per the plan).
# The energy axis is left unrestricted (ylim=None). TR4 and TR11 use their
# background-normalized data.
xlim = (37.81, 38.11)

x4, y4, d4 = analysis_functions.limit_dataset(x_4, y_4, d4_norm, xlim=xlim, ylim=None)
x3, y3, d3 = analysis_functions.limit_dataset(x_3, y_3, d_3, xlim=xlim, ylim=None)
x11, y11, d11 = analysis_functions.limit_dataset(
    x_11, y_11, d11_norm, xlim=xlim, ylim=None
)

In [None]:
# Plot TR4
# Cropped, background-normalized TR4 in greyscale.
fig = tr_functions.default_fig()
fig.add_trace(go.Heatmap(x=x4, y=y4, z=d4, coloraxis="coloraxis"))
fig.update_coloraxes(colorscale="greys", showscale=True)  # greys

# # Plot Data
# fig, ax = plotting_functions.plot_2D_mpl(
#     x=x,
#     y=y,
#     data=d,
#     xlabel="delay",
#     ylabel="energy",
#     title=f"TR4",
#     # cmap="gray",
#     cmap="Blues",
#     vmin=0,
#     vmax=1,
# )

In [None]:
# Plot TR3
# Cropped TR3 with a fixed colour range of 0 to 0.4 and no colourbar.
fig = tr_functions.default_fig()
fig.add_trace(go.Heatmap(x=x3, y=y3, z=d3, coloraxis="coloraxis"))
fig.update_coloraxes(colorscale="plasma", showscale=False, cmin=0, cmax=0.4)

In [None]:
# Plot TR11
# Cropped, background-normalized TR11. Its shape is printed to compare its delay
# resolution with TR3 / TR4.
fig = tr_functions.default_fig()
fig.add_trace(go.Heatmap(x=x11, y=y11, z=d11, coloraxis="coloraxis"))
fig.update_coloraxes(colorscale="plasma", showscale=True)  # can set cmin & cmax here
fig.show()

print(d11.shape)

In [None]:
## Linearly interpolate x11 to match resolution of TR3 and TR4 ##

# Resample TR11 onto TR3's delay grid (xref=x3) so it can be stitched with TR3/TR4.
# NOTE: x, y, d stay bound as notebook globals after this cell.
x, y, d = x11, y11, d11

new_d = tr_functions.interpolate_dataset(x, y, d, xref=x3)

# new_d now lives on the x3 delay grid with TR11's energy axis (y).
fig = tr_functions.default_fig()
fig.add_trace(go.Heatmap(x=x3, y=y, z=new_d))
fig.show()

print(new_d.shape)

In [None]:
### If I didn't linearly interpolate, I would need to bin the TR3 and TR4 data to match the time resolution of TR11 ###

In [None]:
## Bin TR3 and TR4 ##

# # TR4
# x_bin = 2
# y_bin = 1

# d4_bin = misc_functions.bin_data(data=d4, bin_x=x_bin, bin_y=y_bin)
# x4_bin = misc_functions.bin_data(data=x4, bin_x=x_bin)
# y4_bin = misc_functions.bin_data(data=y4, bin_x=y_bin)

# # TR3
# x_bin = 2
# y_bin = 1

# d3_bin = misc_functions.bin_data(data=d3, bin_x=x_bin, bin_y=y_bin)
# x3_bin = misc_functions.bin_data(data=x3, bin_x=x_bin)
# y3_bin = misc_functions.bin_data(data=y3, bin_x=y_bin)

# d4_bin.shape, d3_bin.shape, d11.shape

In [None]:
## Plot binned data ##

# x, y, d = x4_bin, y4_bin, d4_bin
# # x, y, d = x3_bin, y3_bin, d3_bin
# # x, y, d = x11, y11, d11

# # Plot Data
# fig = tr_functions.default_fig()
# fig.add_trace(
#     go.Heatmap(x=x, y=y, z=d, coloraxis="coloraxis")  # can set zmin & zmax here
# )
# fig.update_coloraxes(colorscale="greys", showscale=False)

In [None]:
# # Plot Stitched Data (Binned)

# x, y, data = stitch_2_datasets(d4_bin, x4_bin, y4_bin, d3_bin, x3_bin, y3_bin)

# fig = tr_functions.default_fig()
# fig.add_trace(
#     go.Heatmap(x=x, y=y, z=data, coloraxis="coloraxis")
#     # np.log(data)
# )
# fig.update_coloraxes(
#     colorscale="plasma", showscale=False, cmin=0, cmax=0.3
# )  # cmin=0, cmax=0.3

# # fig, ax = plt.subplots(1)
# # ax.pcolormesh(x, y, data, shading="auto", cmap="plasma", vmin=0, vmax=0.3)

In [None]:
## Plot Stitched Datasets (linearly interpolated): Step 1 ##

# Stitch background-normalized TR4 onto TR3. x_s1 stays in stage position (mm),
# so Step 2 can stitch it with TR11 (resampled onto x3, also in mm).
x_s1, y_s1, data_s1 = tr_functions.stitch_2_datasets(d4, x4, y4, d3, x3, y3)
print(data_s1.shape)

# NOTE(review): this overrides time_zero = 37.96 from the parameters cell, which
# is also the centre of the plan's "-1 to 1 ps --> 37.81 to 38.11 mm"; confirm
# which value is intended.
time_zero = 37.95
# Stage position (mm) -> delay (ps): 2 * dx / c, i.e. 0.15 mm per ps, matching
# the plan's mm <-> ps mapping. Kept in its own variable so x_s1 is not overwritten.
x_s1_ps = ((x_s1 - time_zero) * 1e-3 * 2) / (3e8) * 1e12

# Plot Data
fig = tr_functions.default_fig()
fig.add_trace(
    go.Heatmap(x=x_s1_ps, y=y_s1, z=data_s1, coloraxis="coloraxis")
    # np.log(data)
)
fig.update_coloraxes(colorscale="plasma", showscale=True, cmin=0, cmax=0.8)

fig.update_layout(
    title=f"TR3 & TR4: Backgrounds normalized and stitched",
    xaxis_title="delay (ps)",
    yaxis_title="energy (eV)",
)

# fig.update_layout(
#     width=600,
#     height=600,
#     autosize=False,
# )

# fig.write_image(r'C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\TR3&TR4_plasma.png')

In [None]:
## Plot Stitched Datasets (linearly interpolated): Step 2 ##

# Stitch TR11 (new_d: resampled onto TR3's delay grid x3, with TR11's energy
# axis y11) onto the TR3+TR4 result from Step 1.
# NOTE(review): x3 is a stage position in mm; x_s1 must be in the same units
# (not converted to a time delay) for the two delay axes to line up.
x_s2, y_s2, data_s2 = tr_functions.stitch_2_datasets(
    new_d, x3, y11, data_s1, x_s1, y_s1
)

# Plot Data
fig = tr_functions.default_fig()
fig.add_trace(
    go.Heatmap(x=x_s2, y=y_s2, z=data_s2, coloraxis="coloraxis")
    # np.log(data)
)
fig.update_coloraxes(colorscale="plasma", showscale=True, cmin=0, cmax=1.8)

# Fit for EF

Fermi-Dirac function: $f(E) = \frac{1}{2}\left[1 - \tanh\left(\frac{1}{2}\beta(E-\mu)\right)\right]$

Note that $f(\mu + \delta) = 1 - f(\mu - \delta)$, i.e. the step is point-symmetric about $(E, f) = (\mu, \tfrac{1}{2})$.

$\beta = \frac{1}{k_BT}$

In [None]:
# Boltzmann constant, used for the Fermi-function temperature guess.
k_B = 8.617333e-5  # eV/K

In [None]:
# Thermal energy k_B * T (eV) at T = 10.6 K; the same value seeds params["theta"] in the fit below.
k_B * 10.6

In [None]:
# Intensity-vs-energy profile of the fully stitched dataset, used for the fit below.
xlim = None  # time (delay) integration range passed as x_range; None presumably means all delays
ylim = (1.85, 2.5)  # energy window (eV) kept in the profile

# NOTE(review): this previously referenced undefined names x_2d / y_2d together
# with `d`, which at this point holds TR11 alone (top of its energy range is
# ~2.05 eV, short of ylim). The stitched TR11+TR3+TR4 result from Step 2 is
# used instead — confirm this is the intended dataset.
y_1d, col = tr_functions.get_1d_x_slice(
    x=x_s2,
    y=y_s2,
    data=data_s2,
    ylims=ylim,
    x_range=xlim,
)

# Plot Data
fig = tr_functions.default_fig()
fig.add_trace(go.Scatter(x=y_1d, y=col, name="data"))
fig.update_layout(
    title=f"Time integration limits: {xlim}",
    xaxis_title="Energy",
    yaxis_title="Intensity (arb. u)",
)

In [None]:
import lmfit as lm

# gauss1 = lm.models.GaussianModel(prefix="A_")
# gauss2 = lm.models.GaussianModel(prefix="B_")

# Two Gaussian peaks with initial guesses centred at 1.65 eV and 2.05 eV.
# make_gaussian presumably builds an lmfit GaussianModel with prefix "A_" / "B_"
# (cf. the commented lines above) — confirm.
gauss1 = fitting_functions.make_gaussian(num="A", amplitude=1, center=1.65, sigma=0.5)
gauss2 = fitting_functions.make_gaussian(num="B", amplitude=1, center=2.05, sigma=0.5)

# TODO: add a linear offset term to this model.
def fermi(x, center, theta, amp):
    """Return a Fermi-Dirac step scaled by ``amp``.

    Implements ``amp * f(E)`` with
    ``f(E) = 1/2 * [1 - tanh(beta * (E - mu) / 2)]`` and ``beta = 1 / (k_B T)``.

    Args:
        x: Energy E (eV); scalar or NumPy array.
        center: Chemical potential mu (eV).
        theta: Thermal energy k_B * T (eV); must be non-zero.
        amp: Step height, i.e. the value well below ``center``.

    Returns:
        ``amp`` far below ``center``, ``amp / 2`` at ``center``, and 0 far above it.
    """
    arg = (x - center) / (2 * theta)  # beta * (E - mu) / 2
    # Keep the constant 1/2 term of f(E) so the step runs from amp down to 0,
    # rather than from +amp/2 to -amp/2.
    return amp / 2 * (1 - np.tanh(arg))


# Wrap fermi() as an lmfit model: x is the independent variable; center, theta
# and amp become fit parameters.
fermi_model = lm.models.Model(fermi)

# Composite model: two Gaussian peaks plus the Fermi step.
# NOTE(review): fermi's parameters are unprefixed (center, theta, amp); this
# assumes make_gaussian prefixes its own parameters (e.g. "A_center") so the
# names do not collide — confirm.
full_model = gauss1 + gauss2 + fermi_model
# full_model = fermi_model + gauss2
params = full_model.make_params()

# params["A_center"].value = 1.6
# params["B_center"].value = 2.0

# Initial guesses for the Fermi step: mu = 1.8 eV, k_B*T at 10.6 K, unit height.
params["center"].value = 1.8
# params["theta"].value = 0.1
params["theta"].value = k_B * (10.6)
params["amp"].value = 1


# Fit the profile `col` as a function of energy `y_1d` and plot data vs. fit.
fit = full_model.fit(col, x=y_1d, params=params)
fit.plot()

In [None]:
# Show the lmfit fit result object as the cell output.
fit