# Imports

In [None]:
%load_ext lab_black

import h5py
import os

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
from scipy.interpolate import interp2d
from functools import lru_cache
import lmfit as lm

from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
import plotly.colors as pc
import matplotlib.pyplot as plt


import sys

sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    cnn,
)

# Plot palette plus Unicode symbols used when building axis labels.
colors = pc.qualitative.D3
angstrom = "\N{ANGSTROM SIGN}"  # U+212B
theta = "\N{GREEK SMALL LETTER THETA}"  # U+03B8
Theta = "\N{GREEK CAPITAL LETTER THETA}"  # U+0398

# Load Data

In [None]:
# Load FFT-filtered, k-corrected dataset TR3.
# The loaded angle axis is bound to `theta_axis` (not `theta`) so that the
# Unicode label constant `theta` defined after the imports is not overwritten
# by a data array.

ddir = r"E:\atully\arpes_data\2023_February\6eV\TR"
files = ["TR3_avg_g_kw_filteredFFT_0.00int.h5"]  # 2.15 eV center energy; -1 to 2 ps

ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for file in tqdm(files):
    data, theta_axis, phi_or_time, energy = loading_functions.load_hdf5(ddir, file)
    ARPES_DATA[file] = tr_functions.ArpesData(
        data=data, theta=theta_axis, phi_or_time=phi_or_time, energy=energy
    )
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

ad3_fft = ARPES_DATA[files[0]]

In [None]:
# Load FFT-filtered, k-corrected dataset TR4.
# As for TR3, the loaded angle axis goes into `theta_axis` so the Unicode label
# constant `theta` defined after the imports is not overwritten.
# NOTE(review): ARPES_DATA / ARPES_ATTRS are re-created here, so after this cell
# they only hold TR4 (TR3 remains available as ad3_fft).

ddir = r"E:\atully\arpes_data\2023_February\6eV\TR"
files = [
    "TR4_avg_g_kw_filteredFFT_0.00int.h5"
]  # 2.6 eV center energy; -1 to 1 ps, same number of steps as first 2 ps of TR3

ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for file in tqdm(files):
    data, theta_axis, phi_or_time, energy = loading_functions.load_hdf5(ddir, file)
    ARPES_DATA[file] = tr_functions.ArpesData(
        data=data, theta=theta_axis, phi_or_time=phi_or_time, energy=energy
    )
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

ad4_fft = ARPES_DATA[files[0]]

# General Parameters

In [None]:
# Axis labels (Plotly HTML markup) shared by the figures below.
xaxis_title = f"k<sub>x</sub> [{angstrom}<sup>-1</sup>]"
yaxis_title = "E - E<sub>HOMO</sub> [eV]"

In [None]:
## Zero Delay ##
time_zero = 37.958  # from BiSe

## Energy references ##
# HOMO sits 2.05 eV below E_F, based on fits from this data averaged with
# fits from tr-ARPES results.
homo = -2.05

# Fermi level, in kinetic energy, for each slit setting.
EF_400, EF_700 = 1.91, 1.94

# HOMO position on each slit's kinetic-energy axis.
homo_400, homo_700 = homo + EF_400, homo + EF_700

In [None]:
## Set up general parameters ##

# Delay window: slice_center +/- integration / 2 (ps).
# Presets: slice_center from -1 to 2 ps in 0.25 ps steps with integration = 0.5
#   (e.g. slice_center = -1 -> -1.25 to -0.75 ps; slice_center = 2 -> 1.75 to 2.25 ps);
#   slice_center = 0.5, integration = 1 -> zero delay to 1 ps;
#   slice_center = 0,   integration = 1 -> -0.5 to 0.5 ps.
integration = 0.5
slice_center = 2  # 1.75 to 2.25 ps

# Convert the delay window with tr_functions.ps_to_mm (presumably delay-stage
# travel in mm, with time_zero the zero-delay stage position -- confirm).
# NOTE(review): time_zero is both passed to ps_to_mm and added to its result;
# confirm ps_to_mm returns an offset relative to time_zero, not an absolute position.
slice_val = time_zero + tr_functions.ps_to_mm(slice_center, time_zero)
int_range = tr_functions.ps_to_mm(integration)  # TODO: make this able to be a tuple...

## Slicing in time to look for angular dispersion
slice_dim = "z"

xlim = (-0.15, 0.15)  # k-corrected; use (-12, 12) for theta (uncorrected) data
ylim = None
x_bin = y_bin = 2

In [None]:
# Slice both datacubes at the same delay window, cropped in k to `xlim` and
# binned by (x_bin, y_bin). The energy window uses hard-coded indices 57 and
# 1007 to drop the zero padding at both ends of each dataset's energy axis.
all_vals = []
for ad in [ad3_fft, ad4_fft]:
    energy_window = (ad.energy[57], ad.energy[1007])  # get rid of zero padding on datasets
    sliced = tr_functions.slice_datacube(
        ad_dataclass=ad,
        slice_dim=slice_dim,
        slice_val=slice_val,
        int_range=int_range,
        xlim=xlim,
        ylim=energy_window,
        x_bin=x_bin,
        y_bin=y_bin,
        norm_data=False,
        plot_data=False,
    )
    all_vals.append(sliced)

(x3_fft, y3_fft, d3_fft), (x4_fft, y4_fft, d4_fft) = all_vals

In [None]:
## Adjust energy axis to be relative to HOMO ##

# Toggle: True re-references both energy axes to the HOMO; False leaves them as-is.
homo_zero = True

if homo_zero:
    # TR3 uses the slit-700 HOMO reference, TR4 the slit-400 one. Both references
    # are negative, so subtracting them shifts each energy axis upward.
    # NOTE(review): the arrays are rebound from their current values, so re-running
    # this cell without re-running the slicing cell applies the shift twice.
    y3_fft = y3_fft - homo_700
    y4_fft = y4_fft - homo_400

In [None]:
## Optional Limit Dataset ##
# Crop each slice to the k-window `xlim` and its own energy window, then normalize.
# For theta (uncorrected) data the TR3 window was: ylim = (2.05, np.max(y3_fft))

## TR3 ##
ylim = (2.05, 2.55)  # k-corrected, TR3
x3fft, y3fft, d3fft = analysis_functions.limit_dataset(
    x3_fft, y3_fft, d3_fft, xlim=xlim, ylim=ylim
)
d3fft = analysis_functions.norm_data(d3fft)

## TR4 ##
ylim = (2.45, np.max(y4_fft))  # k-corrected, TR4: from 2.45 eV to the top of the axis
x4fft, y4fft, d4fft = analysis_functions.limit_dataset(
    x4_fft, y4_fft, d4_fft, xlim=xlim, ylim=ylim
)
d4fft = analysis_functions.norm_data(d4fft)

## Enhance Contrast on CT2 (optional): clip TR4 at 0.3 and renormalize for merging purposes
# d4fft[np.where(d4fft > 0.3)] = 0.3
# d4fft = analysis_functions.norm_data(d4fft)

In [None]:
## Plot data ##
# Show the cropped, normalized CT2 (TR4) and CT1 (TR3) slices separately.

fig = tr_functions.thesis_fig(
    title="CT<sub>2</sub> Angular Dispersion",
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=False,
    height=600,
    width=600,
)

## TR 4
# Use the cropped axes (x4fft, y4fft) that limit_dataset returned together with
# d4fft. The uncropped y4_fft need not match d4fft's row count, and plotting
# against it would misplace the data in energy.
fig.add_trace(
    go.Heatmap(
        x=x4fft,
        y=y4fft,
        z=d4fft,
        coloraxis="coloraxis",
    )
)

# fig.update_coloraxes(cmin=0, cmax=0.5)

fig.show()


fig = tr_functions.thesis_fig(
    title="CT<sub>1</sub> Angular Dispersion",
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=False,
    height=600,
    width=600,
)


## TR 3
fig.add_trace(go.Heatmap(x=x3fft, y=y3fft, z=d3fft, coloraxis="coloraxis"))

# fig.update_coloraxes(cmin=0, cmax=0.9)  # when dataset is limited

fig.show()

In [None]:
(d3fft.shape, d4fft.shape)  # compare cropped array shapes before interpolating

In [None]:
## Linearly interpolate TR4 (d4fft) onto TR3's k grid (x3fft) so the two can be stitched ##

x = x4fft
y = y4fft
d = d4fft

new_d = tr_functions.interpolate_dataset(x, y, d, xref=x3fft)

print(new_d.shape)

In [None]:
## Stitch Data ##
# Join TR3 (lower energy window) with the interpolated TR4 slice along y (energy).
x_s1, y_s1, data_s1 = tr_functions.stitch_2_datasets(
    d3fft,
    x3fft,
    y3fft,
    new_d,
    x3fft,  # new_d was interpolated onto TR3's k grid
    y4fft,
    stitch_dim="y",
)

In [None]:
# Plot the stitched TR3 + TR4 slice, renormalized with norm_data.
fig = tr_functions.thesis_fig(
    title=f"Angular Dispersion at {slice_center} ps",
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=False,
    height=600,
    width=600,
)

## Stitched TR3 + TR4
fig.add_trace(
    go.Heatmap(
        z=analysis_functions.norm_data(data_s1),
        x=x_s1,
        y=y_s1,
        coloraxis="coloraxis",
    )
)

fig.update_coloraxes(cmin=0, cmax=None)  # cmax=None: presumably leaves the upper limit automatic
fig.show()

In [None]:
## Get and Plot 1D Data ##
# EDC of the stitched, normalized slice: intensity on x, energy on y.

fig = tr_functions.thesis_fig(
    title=f"EDC at {slice_center} ps",
    yaxis_title=yaxis_title,
    xaxis_title="Intensity [arb. u]",
    equiv_axes=False,
    gridlines=False,
    height=600,
    width=300,
    dtick_y=0.2,
    dtick_x=0.2,
)

# NOTE(review): x_range=None presumably collapses over the full k range -- confirm.
y_1d, col = tr_functions.get_1d_x_slice(
    x=x_s1, y=y_s1, data=analysis_functions.norm_data(data_s1), ylims=None, x_range=None
)

# Plot Data
fig.add_trace(go.Scatter(x=col, y=y_1d, name="data", line=dict(color=colors[0])))

# Show explicitly, like every other plotting cell, rather than relying on the
# notebook rendering the cell's last expression.
fig.show()

# Denoised Data

In [None]:
# Export the stitched, normalized slice to an Igor text (.itx) file for the
# denoising website.
x_dn, y_dn, d_dn = x_s1, y_s1, analysis_functions.norm_data(data_s1)

# One name for the export file so the save and the format fix can't drift apart.
itx_fn = "test.itx"

# Save to .itx (igor) file. The full y axis is written (no subsampling);
# pass y_dn[::2] (and the matching data rows) if a smaller file is needed.
cnn.save_to_igor_itx(itx_fn, [x_dn], [d_dn], ["trarpes"], [y_dn])
cnn.fix_itx_format(itx_fn)  # fix itx formatting for denoising website

In [None]:
# Denoised slice for the current delay, presumably the denoising website's output -- confirm.
fp_dn = r"C:\Users\atully\OneDrive\Physics.UBC\PhD\exciton movie\CNN"
fn_dn = f"{slice_center}ps_movie_dn.itx"
# fn_dn = f"{slice_center}ps_contrast_movie_dn.itx"  # variant for the contrast-enhanced export

# Labels for the denoised plot (xaxis_title is reused unchanged).
title = f"Angular Dispersion {slice_center} ps"
yaxis_title = "E - E<sub>HOMO</sub> [eV]"

x, y, data_dn = loading_functions.load_denoised_data(fp_dn, fn_dn)

In [None]:
# Plot the denoised slice, renormalized with norm_data.
fig = tr_functions.thesis_fig(
    title=title,
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=False,
    height=600,
    width=600,
)

fig.add_trace(
    go.Heatmap(
        x=x,
        y=y,
        z=analysis_functions.norm_data(data_dn),
        coloraxis="coloraxis",
    )
)

# Optional colour scheme:
# fig.update_coloraxes(colorscale="Plasma", reversescale=False, cmin=0, cmax=1)

fig.show()