# Imports

In [None]:
%load_ext lab_black

import h5py
import os

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
from functools import lru_cache

from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt


import sys

sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    cnn,
)

# Load Data

In [None]:
# TR1: load each listed scan and keep it in dictionaries keyed by file name.
ddir = r"E:\atully\arpes_data\2023_February\6eV\TR"
files = ["TR_001_1.h5"]

# ARPES_DATA holds the datacube together with its three axes;
# ARPES_ATTRS holds whatever tr_functions.load_attrs_hdf5 returns for the same file.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for file in tqdm(files):
    data, theta, phi_or_time, energy = loading_functions.load_hdf5(ddir, file)
    ARPES_DATA[file] = tr_functions.ArpesData(
        data=data, theta=theta, phi_or_time=phi_or_time, energy=energy
    )
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

In [None]:
# Sanity check on the first loaded scan: print the shape of every axis array
# and of the datacube itself.
ad = ARPES_DATA[files[0]]
for k in ("energy", "theta", "phi_or_time"):
    axis_shape = getattr(ad, k).shape
    print(f"{k}.shape = {axis_shape}")
cube_shape = ad.data.shape
print(f"Data.shape = {cube_shape}")

In [None]:
# Quick look at the extent of each axis, printed as (min, max) tuples.
delay_bounds = (np.min(ad.phi_or_time), np.max(ad.phi_or_time))
print(f"Delay range (mm): {delay_bounds}")

# Energy is rounded to 2 decimals, theta to 1 decimal, for readability only.
energy_bounds = (np.round(np.min(ad.energy), 2), np.round(np.max(ad.energy), 2))
print(f"Energy range (eV): {energy_bounds}")

theta_bounds = (np.round(np.min(ad.theta), 1), np.round(np.max(ad.theta), 1))
print(f"Theta range: {theta_bounds}")

# Analysis

In [None]:
## Zero Delay ##

# Zero-delay value, from BiSe. It is later handed to tr_functions.mm_to_ps,
# so it is presumably a stage position in mm -- confirm.
time_zero = 37.958

## HOMO is at 2.05 eV below EF, based on fits from this data averaged with fits from tr-ARPES results ##

homo = -2.05  # HOMO relative to EF (negative = below EF)

# Fermi level in kinetic energy, one value per slit setting
EF_400 = 1.91  # slit 400
EF_700 = 1.94  # slit 700

# HOMO position on the kinetic-energy axis for each slit setting
homo_400, homo_700 = homo + EF_400, homo + EF_700

In [None]:
## Integrate over desired angular range ##

# Which dimension to slice along, and the value at which the slice is centred.
slice_dim, slice_val = "x", 0

# If this value is more than the available range, the slicing function will just
# integrate over the max range (author's note, written about get_2D_slice; the
# slicing cell calls tr_functions.slice_datacube -- confirm it behaves the same).
int_range = 20

# No axis limits by default; bin values of 1 on both axes.
xlim = ylim = None
x_bin = y_bin = 1

In [None]:
# Axis labels for the plotly figures; the y label uses an HTML <sub> tag
# for the HOMO subscript.
xaxis_title, yaxis_title = "Delay [ps]", "E - E<sub>HOMO</sub> [eV]"

## TR3 --> -1 to 2 ps; Ec = 2.15 eV

In [None]:
# Reduce the datacube to a 2D map (x axis, y axis, data) using the slicing
# parameters chosen above.
x_2d, y_2d, d_2d = tr_functions.slice_datacube(
    ad_dataclass=ad,
    # what to slice and how far to integrate
    slice_dim=slice_dim,
    slice_val=slice_val,
    int_range=int_range,
    # binning
    x_bin=x_bin,
    y_bin=y_bin,
    # axis limits
    xlim=xlim,
    # get rid of zero padding on datasets: keep only energy[57] .. energy[1007]
    ylim=(ad.energy[57], ad.energy[1007]),
    # output options
    norm_data=True,
    plot_data=False,
)

In [None]:
## Adjust energy axis to be relative to HOMO ##

# Set to False to leave the energy axis as returned by the slicing step.
homo_zero = True

if homo_zero:
    # homo_700 (= homo + EF_700) is negative, so subtracting it increases the
    # energy values: we are referencing a negative number rather than zero.
    # The slit-400 counterpart is homo_400.
    # NOTE: y_2d is rebound from itself, so running this cell twice without
    # re-running the slicing cell shifts the axis twice.
    y_2d = y_2d - homo_700

In [None]:
# Convert the delay axis from mm to ps. time_zero is passed as the second
# argument -- presumably the zero-delay position the conversion is referenced to.
# Alternative value, left disabled:
# time_zero = 37.95
# NOTE: x_2d is rebound from itself, so running this cell twice without
# re-running the slicing cell applies the conversion twice.
x_2d = tr_functions.mm_to_ps(x_2d, time_zero)

In [None]:
## Plot Data ##
# Plotly heatmap of the sliced data: x_2d on x, y_2d on y, d_2d as intensity.
fig = tr_functions.thesis_fig(
    title="TR1",
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=False,
    width=800,
    height=600,
)

# The trace is attached to the figure's shared colour axis.
tr1_heatmap = go.Heatmap(x=x_2d, y=y_2d, z=d_2d, coloraxis="coloraxis")
fig.add_trace(tr1_heatmap)

# Optional greyscale rendering without a colour bar:
# fig.update_coloraxes(colorscale="greys", showscale=False)

fig.show(renderer="svg")

In [None]:
# Plot Data
# Same 2D map drawn through the project's matplotlib helper.
# vmin / vmax are passed straight through -- presumably the colour limits.
fig, ax = plotting_functions.plot_2D_mpl(
    x=x_2d,
    y=y_2d,
    data=d_2d,
    xlabel="delay",
    ylabel="energy",
    title="TR1",  # plain literal: the previous f-prefix had no placeholders
    cmap="Blues",  # alternative: cmap="gray"
    vmin=0,
    vmax=0.05,
)

In [None]:
## Difference Map TR1 ##
title = "Difference Map of TR1"
x, y, d = x_2d, y_2d, d_2d

# Subtract a reference spectrum from every column of d. The reference is the
# mean over the first four columns; d is plotted below as z with x on the
# horizontal axis, so these are the first four points of x.
# keepdims=True keeps the result as a column so it broadcasts across d.
# Single-column alternative, left disabled:
# d_diff = d_2d - d_2d[:, 2][:, None]
# NOTE(review): the `_11` suffix does not match the TR1 title; the name is kept
# unchanged in case later cells reference it.
d_diff_11 = d - np.mean(d[:, 0:4], axis=1, keepdims=True)

# Plot Data
fig = tr_functions.thesis_fig(
    title=title,
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=False,
    gridlines=False,
    height=400,
)

x_plot = x
y_plot = y

fig.add_trace(go.Heatmap(x=x_plot, y=y_plot, z=d_diff_11, coloraxis="coloraxis"))
# Optional guide lines at fixed energies:
# for h in [1.63, 1.8, 1.98]:
#     fig.add_hline(y=h, line=dict(color="black", width=1, dash="dash"))

# Diverging colour scale centred on zero and clipped to +/-0.1.
fig.update_coloraxes(colorscale="RdBu", cmid=0, cmin=-0.1, cmax=0.1, showscale=True)
# Alternative colour scale, left disabled:
# fig.update_coloraxes(colorscale="balance", reversescale=True, cmid=0, showscale=True)