## Imports

In [None]:
%load_ext lab_black

import h5py
import os
import numpy as np
from typing import Dict, List, Optional, Tuple

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
from scipy.interpolate import interp2d
from functools import lru_cache


import plotly.graph_objects as go
import plotly.colors as pc
import matplotlib.pyplot as plt


import sys

sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    kw_data_loader,
    cnn,
    polygons,
)

# Qualitative D3 colour palette from plotly
colors = pc.qualitative.D3
# U+212B ANGSTROM SIGN, used when building axis labels
angstrom = "\N{ANGSTROM SIGN}"

# Load Data

In [None]:
# # Convert k-corrected data to h5 ##

# ddir = r"E:\atully\arpes_data\2023_February\6eV\FS"

# # STEP 1 ##
# # Convert ibw to hdf5
# fn = "FS2_avg_gkw.ibw"
# HDF5_loader.ibw_to_hdf5(ddir, fn, export=True)

# # Check conversion worked
# data, kx, ky, energy = HDF5_loader.load_hdf5(
#     ddir, "FS2_avg_gkw.h5"
# )  # load data from hdf5
# data.shape, kx.shape, ky.shape, energy.shape

In [None]:
## Load averaged, K corrected data ##

ddir = r"E:\atully\arpes_data\2023_February\6eV\FS"
files = ["FS2_avg_gkw.h5"]

# NOTE: ArpesData has no kx/ky fields, so kx is stored in `theta` and ky in
# `phi_or_time`. Downstream code must read them under those names.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for fname in tqdm(files):
    data, kx, ky, energy = loading_functions.load_hdf5(ddir, fname)
    cube = tr_functions.ArpesData(
        data=data,
        theta=kx,
        phi_or_time=ky,
        energy=energy,
    )
    ARPES_DATA[fname] = cube
    ARPES_ATTRS[fname] = tr_functions.load_attrs_hdf5(ddir, fname)

In [None]:
# Work with the first (and only) loaded dataset
ad = ARPES_DATA[files[0]]

# Report the axis and data-cube shapes as a quick sanity check
for axis_name in ("energy", "theta", "phi_or_time"):
    axis_values = getattr(ad, axis_name)
    print(f"{axis_name}.shape = {axis_values.shape}")
print(f"Data.shape = {ad.data.shape}")

# Analysis -- stitching and normalizing datasets

In [None]:
## HOMO is at 2.05 eV below EF, based on fits from this data averaged with fits from tr-ARPES results ##

EF_400 = 1.91  # Fermi level, in kinetic energy, slit 400
homo = -2.05  # HOMO position relative to EF
homo_400 = EF_400 + homo  # HOMO position on the kinetic-energy axis

# Re-reference the energy axis to the HOMO (E - E_HOMO).
# NOTE: this rebinds ad.energy on the shared ARPES_DATA entry, so running this
# cell a second time without reloading shifts the axis again.
shifted_energy = ad.energy - homo_400
ad.energy = shifted_energy

In [None]:
## Initial params ##

# Axis passed to tr_functions.slice_datacube for the constant-energy cut.
# NOTE(review): presumably "y" is the energy axis of the cube; confirm in slice_datacube.
slice_dim = "y"

# Cut position, E - E_HOMO in eV (used in the plot titles).
# Previously used values:
# slice_val = np.round(2.65 - homo_400, 3)
# slice_val = np.round(2.55 - homo_400, 3)
# slice_val = np.round(2.45 - homo_400, 3)
# slice_val = 2.5
# slice_val = 2.75
slice_val = 2.96

# Width passed to slice_datacube for the cut (presumably the integration window in eV; confirm).
int_range = 0.05

In [None]:
# Plotly HTML-markup labels for the constant-energy maps
title = f"CT<sub>2</sub> (E - E<sub>HOMO</sub> = {slice_val})"
inverse_angstrom = f"{angstrom}<sup>-1</sup>"
xaxis_title = f"k<sub>x</sub> [{inverse_angstrom}]"
yaxis_title = f"k<sub>y</sub> [{inverse_angstrom}]"

In [None]:
## Slice Data ##

# Crop window in k-space. The commented alternative (None) is presumably "no crop";
# confirm in tr_functions.slice_datacube.
xlim = (-0.13, 0.47)
ylim = (-0.57, 0.18)
# xlim = None
# ylim = None
x_bin = 2
y_bin = 2

## Get data
# Constant-energy cut of the datacube at slice_val, normalised (norm_data=True).
# NOTE(review): exact integration/cropping/binning semantics live in
# tr_functions.slice_datacube.
x_2d, y_2d, d_2d = tr_functions.slice_datacube(
    ad,
    slice_dim,
    slice_val,
    int_range,
    xlim,
    ylim,
    x_bin,
    y_bin,
    norm_data=True,
    plot_data=False,
)

## Plot data
# Raw strings: "\;" is not a recognised Python escape sequence, so in a plain
# string literal it emits a DeprecationWarning (SyntaxWarning from 3.12).
# The resulting label text is unchanged.
fig = tr_functions.thesis_fig(
    title=f"Excited State: {slice_val} eV",
    xaxis_title=r"$k_x \; [A^{-1}]$",
    yaxis_title=r"$k_y \; [A^{-1}]$",
)

fig.add_trace(
    go.Heatmap(
        x=x_2d, y=y_2d, z=analysis_functions.norm_data(d_2d), coloraxis="coloraxis"
    )
)

# 6-sided polygon overlay (size parameter 0.42, rotated 30).
# NOTE(review): presumably the Brillouin-zone outline; confirm against polygons.gen_polygon.
hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
fig = polygons.plot_polygon(hexagon, color="green", fig=fig, show=False)

fig.update_coloraxes(
    colorscale="ice",
    reversescale=True,
    showscale=True,
    cmin=0,
    cmax=None,
)

fig.update_xaxes(constrain="domain")
if xlim is not None:
    # Guard: with the `xlim = None` alternative above, xlim[0] would raise TypeError.
    fig.update_xaxes(range=[xlim[0], xlim[1]])
# Lock the y scale to the x scale so the k-space map keeps a 1:1 aspect ratio
fig.update_yaxes(scaleanchor="x", scaleratio=1)

fig.update_layout(width=600, height=600)
fig.show(renderer="svg")

In [None]:
## Rotate Data ##

In [None]:
# Start from the normalised, un-rotated slice.
# (Alternative source: x_s2, y_s2, analysis_functions.norm_data(data_s2))
x = x_2d
y = y_2d
z = analysis_functions.norm_data(d_2d)

# Combine the x/y axes into the coordinate format expected by rotate_2d_array
# (presumably one (x, y) pair per grid point; see tr_functions.x_y_to_coords).
coords = tr_functions.x_y_to_coords(x, y)

In [None]:
# Two copies of the slice coordinates rotated by 120 and 240 about (0, 0)
# (angle units presumably degrees; see tr_functions.rotate_2d_array).
rotation_centre = (0, 0)
rotated_coords = tr_functions.rotate_2d_array(coords, 120, rotation_centre)
rotated_coords_2 = tr_functions.rotate_2d_array(coords, 240, rotation_centre)

In [None]:
# Re-grid each rotated coordinate set, reusing the same intensities z
interpolated_120 = tr_functions.interpolate(rotated_coords, z)
interpolated_240 = tr_functions.interpolate(rotated_coords_2, z)
nx, ny, nd = interpolated_120
nx_2, ny_2, nd_2 = interpolated_240

In [None]:
## Plot raw data and rotations on same figure ##

fig = tr_functions.thesis_fig(
    title=title,
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=True,
    gridlines=False,
    height=600,
    width=600,
)

# Heatmaps are added in this order: 240 deg copy, 120 deg copy, then the
# original slice. All three share a single colour axis.
# (Add opacity=0.85 to the Heatmap below to see through the overlaps;
# alternative original slice: x_s2, y_s2, analysis_functions.norm_data(data_s2).)
overlay_layers = (
    (nx_2, ny_2, nd_2),
    (nx, ny, nd),
    (x_2d, y_2d, analysis_functions.norm_data(d_2d)),
)
for layer_x, layer_y, layer_z in overlay_layers:
    fig.add_trace(
        go.Heatmap(x=layer_x, y=layer_y, z=layer_z, coloraxis="coloraxis")
    )

hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
fig = polygons.plot_polygon(
    hexagon, color="yellow", fig=fig, show=False, dash=True, dash_width=3
)

fig.update_coloraxes(
    colorscale="ice",
    reversescale=True,
    showscale=True,
    cmin=0,
    cmax=None,
)
# 1:1 aspect ratio between k_x and k_y
fig.update_yaxes(scaleanchor="x", scaleratio=1)

fig.update_layout(width=600, height=600)
fig.show()

# fig.write_image(
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.65eV_full.png"
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.55eV_full.png"
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.45eV_full.png"
# )

# fig.write_image(
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.65eV_full_opacity0.85.png"
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.55eV_full_opacity0.85.png"
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.45eV_full_opacity0.85.png"
# )

In [None]:
## Average these datasets all together (requires interpolation of data)  ##

In [None]:
# def stitch_and_avg(x1, y1, data1, x2, y2, data2):
#     # Create new axes, 1000 x 1000 is the desired final resolution
#     new_x = np.linspace(min(min(x1), min(x2)), max(max(x1), max(x2)), 1000)
#     new_y = np.linspace(min(min(y1), min(y2)), max(max(y1), max(y2)), 1000)

#     # Generate new grid for data
#     new_datas = []

#     # Interpolate datasets onto new meshgrid (requires defining interper function)
#     for x, y, data in zip([x1, x2], [y1, y2], [data1, data2]):
#         interper = RegularGridInterpolator(
#             (y, x), data, fill_value=np.nan, bounds_error=False
#         )
#         xx, yy = np.meshgrid(new_x, new_y, indexing="ij")

#         new_datas.append(interper((yy, xx)).T)

#     # Average dataslices together where they overlap (otherwise keep the original data)
#     new_data = np.nanmean(new_datas, axis=0)

#     return new_x, new_y, new_data

In [None]:
## Average original dataset with 1st rotated dataset ##

# First pass: stitch the normalised original slice with the 120 deg rotated copy.
# (Alternative first input: x_s2, y_s2, analysis_functions.norm_data(data_s2))
x1 = x_2d
y1 = y_2d
dataslice1 = analysis_functions.norm_data(d_2d)
x2 = nx
y2 = ny
dataslice2 = nd

new_x, new_y, new_data = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2
)

# Optional preview of the intermediate result:
# fig = tr_functions.thesis_fig()
# fig.add_trace(
#     go.Heatmap(
#         x=new_x,
#         y=new_y,
#         z=analysis_functions.norm_data(new_data),
#         coloraxis="coloraxis",
#     )
# )
# fig.show(renderer="svg")

In [None]:
## Average new dataset with 2nd rotated dataset ##

# Second pass: fold the 240 deg rotated copy into the map produced by the
# previous stitch (original + 120 deg copy).
x1, y1, dataslice1 = new_x, new_y, new_data
x2, y2, dataslice2 = nx_2, ny_2, nd_2

new_x, new_y, new_data = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2
)

# Raw strings: "\;" is not a recognised Python escape sequence, so in a plain
# string literal it emits a DeprecationWarning (SyntaxWarning from 3.12).
# The resulting label text is unchanged.
fig = tr_functions.thesis_fig(
    title=f"Excited State: {slice_val} eV",
    xaxis_title=r"$k_x \; [A^{-1}]$",
    yaxis_title=r"$k_y \; [A^{-1}]$",
)

fig.add_trace(
    go.Heatmap(
        x=new_x,
        y=new_y,
        z=analysis_functions.norm_data(new_data),
        coloraxis="coloraxis",
    )
)

# 6-sided polygon overlay (size parameter 0.42, rotated 30).
# NOTE(review): presumably the Brillouin-zone outline; confirm against polygons.gen_polygon.
hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
fig = polygons.plot_polygon(
    hexagon, color="firebrick", fig=fig, show=False, dash=True, dash_width=3
)

fig.show()

# fig.write_image(
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.65eV_full_averaged.png"
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.55eV_full_averaged.png"
#     # r"C:\Users\atully\OneDrive\Physics.UBC\TR-ARPES\Data\FS_2.45eV_full_averaged.png"
# )