## Imports

In [None]:
%load_ext lab_black

import h5py
import os
import numpy as np
from typing import Dict, List, Optional, Tuple

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
from scipy.interpolate import interp2d
from functools import lru_cache


import plotly.graph_objects as go
import plotly.colors as pc
import matplotlib.pyplot as plt


import sys

sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    kw_data_loader,
    cnn,
    polygons,
)

# Unicode symbols used in axis labels (angstrom sign, capital theta, small phi);
# kept as escapes so the source stays ASCII.
angstrom, Theta, phi = "\u212B", "\u0398", "\u03C6"
# Qualitative color palette for line traces.
colors = pc.qualitative.D3

In [None]:
def _combine_timescans(files, ddir, new_filename, reduce_fn, arpes_data=None):
    """Reduce several loaded scans element-wise and write the result to HDF5.

    Shared implementation behind ``average_timescans`` and ``sum_timescans``.

    Args:
        files: Keys into ``arpes_data`` naming the scans to combine.
        ddir: Directory the new file is written to.
        new_filename: Name of the HDF5 file to create (overwritten if present).
        reduce_fn: NumPy reduction applied across the scans (``np.mean`` or
            ``np.sum``), called with ``axis=0`` on the list of data arrays.
        arpes_data: Mapping of filename -> loaded scan object; ``None`` means
            the notebook-global ``ARPES_DATA``.

    Returns:
        Full path of the written file.

    Raises:
        ValueError: If ``files`` is empty or the scans' data shapes differ.
            Both checks run before the output file is opened, so a bad call
            never truncates an existing file.
    """
    if arpes_data is None:
        # Resolved at call time because the notebook rebinds ARPES_DATA per cell.
        arpes_data = ARPES_DATA
    if not files:
        raise ValueError("files must name at least one scan")

    scans = [arpes_data[fn] for fn in files]
    shapes = {np.shape(scan.data) for scan in scans}
    if len(shapes) != 1:
        raise ValueError(
            f"cannot combine scans with different data shapes: {sorted(shapes)}"
        )

    combined = reduce_fn([scan.data for scan in scans], axis=0)
    print(combined.shape)

    # Axis values are copied from the last scan.
    # NOTE(review): assumes every scan shares the same angle/energy/scan axes.
    ref = scans[-1]
    new_fn = os.path.join(ddir, new_filename)
    # 'w' creates a new file or truncates an existing one
    # (use 'r+' to modify an existing file instead).
    with h5py.File(new_fn, "w") as f:
        f["data"] = combined.T
        # Change these dataset names to match your axes labels.
        for name, axis in (("angles", ref.theta), ("energies", ref.energy)):
            f[name] = np.atleast_2d(axis).T
        entry_group = f.require_group("entry1")
        entry_group["ScanValues"] = np.atleast_2d(ref.phi_or_time).T
    return new_fn


def average_timescans(files, ddir, new_filename, *, arpes_data=None):
    """Average the data of several loaded scans and save it as a new HDF5 file.

    The file layout is ``data`` (transposed), ``angles``, ``energies`` and
    ``entry1/ScanValues``. The axis arrays are taken from the last scan.

    Args:
        files: Filenames (keys of ``ARPES_DATA``) of the scans to average.
        ddir: Output directory.
        new_filename: Output filename; an existing file is overwritten.
        arpes_data: Optional mapping to read scans from instead of the
            global ``ARPES_DATA``.

    Returns:
        Path of the written file.

    Raises:
        ValueError: If ``files`` is empty or the scans' shapes differ.
    """
    return _combine_timescans(files, ddir, new_filename, np.mean, arpes_data)


def sum_timescans(files, ddir, new_filename, *, arpes_data=None):
    """Sum the data of several loaded scans and save it as a new HDF5 file.

    Same file layout and arguments as ``average_timescans``, but the scans
    are summed instead of averaged.

    Returns:
        Path of the written file.

    Raises:
        ValueError: If ``files`` is empty or the scans' shapes differ.
    """
    return _combine_timescans(files, ddir, new_filename, np.sum, arpes_data)

# Convert Data

In [None]:
# ## Convert K corrected .ibw data to .h5 ##

# ddir = r"E:\atully\arpes_data\2023_June\C60\ARPES\FS\k_corrected"

# # STEP 1 ##
# # Convert ibw to hdf5
# fn = "FS1_avg_gkw.ibw"
# HDF5_loader.ibw_to_hdf5(ddir, fn, export=True)

# # Check conversion worked
# data, kx, ky, energy = HDF5_loader.load_hdf5(
#     ddir, "FS1_avg_gkw.h5"
# )  # load data from hdf5
# data.shape, kx.shape, ky.shape, energy.shape

# Stitch Full CT1 

In [None]:
# Directory containing the k-corrected FS*_avg_gkw.h5 files loaded by the cells below.
ddir = r"E:\atully\arpes_data\2023_June\C60\ARPES\FS\k_corrected"

In [None]:
## Energy reference: the HOMO lies 2.05 eV below E_F ##
# (value from fits of this data averaged with fits of the tr-ARPES results)

homo = -2.05  # E_HOMO relative to E_F, in eV
EF_400 = 1.91  # E_F in kinetic energy for slit 400

# HOMO position on the slit-400 kinetic-energy axis
homo_400 = EF_400 + homo

## Left Side

In [None]:
## CT1 -- largest area matching stats
# Left half of the map is split across three scans:
#   FS0 -> bottom, FS3 -> top, FS2 -> full range
files = ["FS0_avg_gkw.h5", "FS3_avg_gkw.h5", "FS2_avg_gkw.h5"]

# Loads into the generic ArpesData container, so kx and ky end up in the
# theta and phi_or_time fields rather than dedicated kx/ky fields.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for fname in tqdm(files):
    cube, kx_axis, ky_axis, energy_axis = loading_functions.load_hdf5(ddir, fname)
    ARPES_DATA[fname] = tr_functions.ArpesData(
        data=cube, theta=kx_axis, phi_or_time=ky_axis, energy=energy_axis
    )
    ARPES_ATTRS[fname] = tr_functions.load_attrs_hdf5(ddir, fname)

ad_bottom_left, ad_top_left, ad_full_left = (ARPES_DATA[fname] for fname in files)

In [None]:
## Re-reference the left-side energy axes to the HOMO (E = 0 at the HOMO) ##
# NOTE: not idempotent -- every re-run of this cell subtracts homo_400 again.
# Reload the scans before re-running it.
homo_zero = True  # set to False to keep the kinetic-energy axis

if homo_zero:
    for ad in (ad_bottom_left, ad_top_left, ad_full_left):
        ad.energy = ad.energy - homo_400

In [None]:
# Constant-energy slice at E - E_HOMO = slice_val (eV).
# Other energies examined: 2.1, 2.15, 2.3, 2.4, 2.5, 2.6
slice_val = 2.2
slice_dim = "y"
int_range = 0.1  # passed to slice_datacube as int_range

# Crop window for the slices (alternative: None for either axis)
xlim = (-0.13, 0.31)
ylim = (-0.4, 0.14)

# Binning factors passed to slice_datacube
x_bin = y_bin = 2

In [None]:
# Figure labels using HTML-style <sub>/<sup> tags for the plotly figures below.
title = f"E - E<sub>HOMO</sub> = {slice_val} eV"
xaxis_title, yaxis_title = (
    f"k<sub>{axis}</sub> [{angstrom}<sup>-1</sup>]" for axis in ("x", "y")
)

In [None]:
## Get Slices ##
# All three left-side slices share these settings; only the dataset and the
# y crop differ. The bottom (FS0) and top (FS3) crops overlap on
# 0.090 <= y <= 0.094 so the two can be stitched afterwards.
_slice_kwargs = dict(
    slice_dim=slice_dim,
    slice_val=slice_val,
    int_range=int_range,
    xlim=xlim,
    x_bin=x_bin,
    y_bin=y_bin,
    norm_data=True,
    plot_data=False,
)

x_bottom_left, y_bottom_left, d_bottom_left = tr_functions.slice_datacube(
    ad_dataclass=ad_bottom_left, ylim=(-0.4, 0.094), **_slice_kwargs
)
x_top_left, y_top_left, d_top_left = tr_functions.slice_datacube(
    ad_dataclass=ad_top_left, ylim=(0.090, 0.14), **_slice_kwargs
)
x_full_left, y_full_left, d_full_left = tr_functions.slice_datacube(
    ad_dataclass=ad_full_left, ylim=ylim, **_slice_kwargs
)

In [None]:
# ## Plot Data ##
# x_plot, y_plot, d_plot = x_full_left, y_full_left, d_full_left
# x_plot, y_plot, d_plot = x_bottom_left, y_bottom_left, d_bottom_left

# fig = tr_functions.thesis_fig(
#     title=f"CT<sub>1</sub> (E - E<sub>HOMO</sub> = {slice_val})",
#     xaxis_title=xaxis_title,
#     yaxis_title=yaxis_title,
#     equiv_axes=True,
#     height=500,
#     width=500,
# )

# fig.add_trace(
#     go.Heatmap(
#         x=x_plot,
#         y=y_plot,
#         z=d_plot,
#         coloraxis="coloraxis",
#     )
# )

# hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
# fig = polygons.plot_polygon(
#     hexagon, color="yellow", fig=fig, show=False, dash=True, dash_width=3
# )

# fig.update_coloraxes(cmin=0, cmax=1)

# fig.update_xaxes(range=[np.min(x_plot), np.max(x_plot)], constrain="domain")
# # fig.update_yaxes(range=[np.min(y_plot), np.max(y_plot)], scaleanchor="x", scaleratio=1)
# fig.update_yaxes(range=[np.min(y_plot), np.max(y_plot)], constrain="domain")

# fig.show()

In [None]:
## Stitch left side: bottom (FS0), top (FS3) ##
# The top slice is halved so its background level matches the bottom slice
# (the unscaled top slice was also tried).
x1, y1, dataslice1 = x_bottom_left, y_bottom_left, d_bottom_left
x2, y2 = x_top_left, y_top_left
dataslice2 = d_top_left / 2

xs, ys, ds = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2, no_avg=False
)

In [None]:
## Fix 4 rows of missing data ##
# Rows 922-925 of the stitched left-side slice (ds) are bad traces. Replace
# each one with the mean of the nearest good rows on either side (921, 926).
# NOTE(review): the indices are hard-coded and presumably depend on the
# current ylim / y_bin / stitching settings -- re-check if those change.
# ds is modified in place.
first_bad_row, last_bad_row = 922, 925
ds[first_bad_row : last_bad_row + 1] = 0.5 * (
    ds[first_bad_row - 1] + ds[last_bad_row + 1]
)  # one averaged row, broadcast across all bad rows

In [None]:
## Average with full left side: (FS2) ##
# Combine the stitched FS0+FS3 map with the full-range FS2 slice
# (no_avg=False, as in the FS0/FS3 stitch).
x1, y1, dataslice1 = xs, ys, ds
x2, y2, dataslice2 = x_full_left, y_full_left, d_full_left

xs_2, ys_2, ds_2 = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2, no_avg=False
)

In [None]:
# ## Fix 8 cols of missing data ##
# x_fix, y_fix, d_fix = xs_2, ys_2, ds_2

# # Bad traces (cols): 792 through 814
# # Set data cols 792 - 814 = 0.5 * (791 + 815)

# col_left = d_fix[:, 791]
# col_right = d_fix[:, 815]
# d_fix[:, 792] = 0.5 * (col_left + col_right)
# d_fix[:, 793] = 0.5 * (col_left + col_right)
# d_fix[:, 288] = 0.5 * (col_left + col_right)
# d_fix[:, 289] = 0.5 * (col_left + col_right)
# d_fix[:, 290] = 0.5 * (col_left + col_right)
# d_fix[:, 291] = 0.5 * (col_left + col_right)
# d_fix[:, 292] = 0.5 * (col_left + col_right)
# d_fix[:, 293] = 0.5 * (col_left + col_right)
# ds_2 = d_fix

In [None]:
# Left half = FS0+FS3 stitch averaged with FS2 (use xs, ys, ds to skip FS2).
x_left = xs_2
y_left = ys_2
d_left = ds_2

In [None]:
# ## Plot Data ##

# # x_plot, y_plot, d_plot = x_bottom_left, y_bottom_left, d_bottom_left
# # x_plot, y_plot, d_plot = x_top_left, y_top_left, d_top_left
# # x_plot, y_plot, d_plot = x_full_left, y_full_left, d_full_left

# x_plot, y_plot, d_plot = x_left, y_left, d_left

# fig = tr_functions.thesis_fig(
#     title=title,
#     xaxis_title=xaxis_title,
#     yaxis_title=yaxis_title,
#     equiv_axes=False,
#     height=500,
#     width=500,
# )

# fig.add_trace(
#     go.Heatmap(
#         x=x_plot,
#         y=y_plot,
#         z=analysis_functions.norm_data(d_plot),
#         coloraxis="coloraxis",
#     )
# )

# fig.update_coloraxes(cmin=0, cmax=1)

# fig.show()

In [None]:
# ## Rotate Data ##

# x, y, z = x_left, y_left, d_left
# coords = tr_functions.x_y_to_coords(x, y)

# rotated_coords = tr_functions.rotate_2d_array(coords, 120, (0, 0))
# rotated_coords_2 = tr_functions.rotate_2d_array(coords, 240, (0, 0))

# nx, ny, nd = tr_functions.interpolate(rotated_coords, z)
# nx_2, ny_2, nd_2 = tr_functions.interpolate(rotated_coords_2, z)

In [None]:
# ## Plot raw data and rotations on same figure ##

# fig = tr_functions.thesis_fig(
#     title=title,
#     xaxis_title=xaxis_title,
#     yaxis_title=yaxis_title,
#     equiv_axes=True,
#     gridlines=False,
#     height=600,
#     width=600,
# )

# fig.add_trace(
#     go.Heatmap(
#         x=nx_2,
#         y=ny_2,
#         z=nd_2,
#         coloraxis="coloraxis",
#         # opacity=0.85,
#     )
# )

# fig.add_trace(
#     go.Heatmap(
#         x=nx,
#         y=ny,
#         z=nd,
#         coloraxis="coloraxis",
#         # opacity=0.85,
#     )
# )

# fig.add_trace(
#     go.Heatmap(
#         x=x,
#         y=y,
#         z=z,
#         coloraxis="coloraxis",
#         # opacity=0.85,
#     )
# )

# hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
# fig = polygons.plot_polygon(
#     hexagon, color="yellow", fig=fig, show=False, dash=True, dash_width=3
# )

# fig.update_coloraxes(
#     colorscale="ice",
#     reversescale=True,
#     showscale=True,
#     cmin=0,
#     cmax=None,
# )
# fig.update_yaxes(scaleanchor="x", scaleratio=1)
# fig.show()

## Right Side

In [None]:
# CT1 -- largest area matching stats
# Right half of the map is split across three scans:
#   FS1 -> bottom, FS4 -> top, FS567 -> full range
files = ["FS1_avg_gkw.h5", "FS4_avg_gkw.h5", "FS567_avg_gkw.h5"]

# The dicts are re-created here, replacing the left-side entries; the
# ad_*_left variables still reference the left-side scans.
# kx and ky are stored in the theta and phi_or_time fields of ArpesData.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}
for fname in tqdm(files):
    cube, kx_axis, ky_axis, energy_axis = loading_functions.load_hdf5(ddir, fname)
    ARPES_DATA[fname] = tr_functions.ArpesData(
        data=cube, theta=kx_axis, phi_or_time=ky_axis, energy=energy_axis
    )
    ARPES_ATTRS[fname] = tr_functions.load_attrs_hdf5(ddir, fname)

ad_bottom_right, ad_top_right, ad_full_right = (ARPES_DATA[fname] for fname in files)

In [None]:
## Re-reference the right-side energy axes to the HOMO (E = 0 at the HOMO) ##
# NOTE: not idempotent -- every re-run of this cell subtracts homo_400 again.
# Reload the scans before re-running it.
homo_zero = True  # set to False to keep the kinetic-energy axis

if homo_zero:
    for ad in (ad_bottom_right, ad_top_right, ad_full_right):
        ad.energy = ad.energy - homo_400

In [None]:
# Right-side slice settings. slice_val and int_range are not reset here, so
# both halves are cut with the left-side values.
slice_dim = "y"
xlim, ylim = (0.051, 0.468), (-0.4, 0.14)  # alternative: xlim = None
x_bin = y_bin = 2

In [None]:
## Get Slices ##
# All three right-side slices share these settings; only the dataset and the
# y crop differ (the bottom FS1 slice is cropped at y = 0.1).
_slice_kwargs = dict(
    slice_dim=slice_dim,
    slice_val=slice_val,
    int_range=int_range,
    xlim=xlim,
    x_bin=x_bin,
    y_bin=y_bin,
    norm_data=True,
    plot_data=False,
)

x_bottom_right, y_bottom_right, d_bottom_right = tr_functions.slice_datacube(
    ad_dataclass=ad_bottom_right, ylim=(-0.4, 0.1), **_slice_kwargs
)
x_top_right, y_top_right, d_top_right = tr_functions.slice_datacube(
    ad_dataclass=ad_top_right, ylim=ylim, **_slice_kwargs
)
x_full_right, y_full_right, d_full_right = tr_functions.slice_datacube(
    ad_dataclass=ad_full_right, ylim=ylim, **_slice_kwargs
)

In [None]:
## Stitch right side: bottom (FS1) & top (FS4) ##
# Unlike the left side, the top slice is used unscaled and the stitch is
# called with no_avg=True.
x1, y1, dataslice1 = x_bottom_right, y_bottom_right, d_bottom_right
x2, y2, dataslice2 = x_top_right, y_top_right, d_top_right

xs, ys, ds = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2, no_avg=True
)

In [None]:
# ## Fix 4 rows of missing data ##
# x_fix, y_fix, d_fix = xs, ys, ds

# # Bad traces (rows): 922 through 925
# # Set data rows 922 - 925 = 0.5 * (921 + 926)

# row_below = d_fix[921]
# row_above = d_fix[926]
# d_fix[922] = 0.5 * (row_below + row_above)
# d_fix[923] = 0.5 * (row_below + row_above)
# d_fix[924] = 0.5 * (row_below + row_above)
# d_fix[925] = 0.5 * (row_below + row_above)

# ds = d_fix

In [None]:
## Average with full right side: (FS567) ##
# Combine the stitched FS1+FS4 map with the full-range FS567 slice
# (no_avg=False).
x1, y1, dataslice1 = xs, ys, ds
x2, y2, dataslice2 = x_full_right, y_full_right, d_full_right

xs_2, ys_2, ds_2 = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2, no_avg=False
)

In [None]:
# Right half = FS1+FS4 stitch averaged with FS567 (use xs, ys, ds to skip FS567).
x_right = xs_2
y_right = ys_2
d_right = ds_2

In [None]:
# ## Plot Data ##
# # x_plot, y_plot, d_plot = x_bottom_right, y_bottom_right, d_bottom_right
# x_plot, y_plot, d_plot = x_full_right, y_full_right, d_full_right

# x_plot, y_plot, d_plot = x_right, y_right, d_right

# fig = tr_functions.thesis_fig(
#     title=title,
#     xaxis_title=xaxis_title,
#     yaxis_title=yaxis_title,
#     equiv_axes=False,
#     height=500,
#     width=500,
# )

# fig.add_trace(
#     go.Heatmap(
#         x=x_plot,
#         y=y_plot,
#         z=analysis_functions.norm_data(d_plot),
#         coloraxis="coloraxis",
#     )
# )

# fig.update_coloraxes(cmin=0, cmax=1)

# fig.show()

In [None]:
## Stitch left and right side ##
# Both halves are normalized, then the right half is scaled to match the
# left half's intensity. Scale factors noted for other slice energies:
#   2.6 eV -> 0.7    2.5 eV -> 0.8    2.4 eV -> 1.2 (?)
#   2.3 eV -> 1.0 / 0.85    2.2 eV -> 1.1    2.1 eV -> 1.3
# NOTE(review): the table lists 1.1 for 2.2 eV (the slice_val set earlier),
# but 0.9 is applied below -- confirm which factor is intended.
x1, y1 = x_left, y_left
dataslice1 = analysis_functions.norm_data(d_left)
x2, y2 = x_right, y_right
dataslice2 = analysis_functions.norm_data(d_right) * 0.9

xs_3, ys_3, ds_3 = tr_functions.stitch_and_avg(
    x1, y1, dataslice1, x2, y2, dataslice2, no_avg=True
)

In [None]:
## Plot Data ##
# NOTE(review): this currently plots only the right half (x_right, ...); the
# stitched map (xs_3, ...) is the commented-out alternative -- confirm intent.
x_plot, y_plot, d_plot = x_right, y_right, d_right
# x_plot, y_plot, d_plot = analysis_functions.limit_dataset(
#     xs_3, ys_3, ds_3, xlim=None, ylim=(-0.395, 0.134)
# )

fig = tr_functions.thesis_fig(
    title=title,
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=True,
    height=500,
    width=500,
)
fig.add_trace(go.Heatmap(x=x_plot, y=y_plot, z=d_plot, coloraxis="coloraxis"))

# Dashed yellow hexagon overlay (presumably the zone boundary -- confirm).
hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
fig = polygons.plot_polygon(
    hexagon, color="yellow", fig=fig, show=False, dash=True, dash_width=3
)

fig.update_coloraxes(cmin=0.15, cmax=0.7)
# alternative: fig.update_coloraxes(colorscale="Blues", reversescale=False)

# Clamp both axes to the extent of the plotted data.
fig.update_xaxes(range=[np.min(x_plot), np.max(x_plot)], constrain="domain")
fig.update_yaxes(range=[np.min(y_plot), np.max(y_plot)], constrain="domain")

fig.show()

## Reconstruct Full BZ

In [None]:
# Crop the stitched map at y = 0.13, keeping the full x range and the
# original lower y edge.
y_window = (np.min(ys_3), 0.13)
x, y, z = analysis_functions.limit_dataset(xs_3, ys_3, ds_3, xlim=None, ylim=y_window)

In [None]:
## Rotate Data ##
# Make copies of the cropped map rotated by 120 and 240 (presumably degrees)
# about the origin, then pass each through tr_functions.interpolate.

# x, y, z = x, y - 0.09, z

coords = tr_functions.x_y_to_coords(x, y)
rotated_coords, rotated_coords_2 = (
    tr_functions.rotate_2d_array(coords, angle, (0, 0)) for angle in (120, 240)
)
(nx, ny, nd), (nx_2, ny_2, nd_2) = (
    tr_functions.interpolate(rc, z) for rc in (rotated_coords, rotated_coords_2)
)

In [None]:
## Plot raw data and rotations on same figure ##
fig = tr_functions.thesis_fig(
    title=title,
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=True,
    gridlines=False,
    height=600,
    width=600,
)

# Traces are added in this order: 240-degree copy, 120-degree copy, then the
# measured data last. All share one color axis (opacity=0.85 was also tried).
for hx, hy, hz in ((nx_2, ny_2, nd_2), (nx, ny, nd), (x, y, z)):
    fig.add_trace(go.Heatmap(x=hx, y=hy, z=hz, coloraxis="coloraxis"))

# Dashed yellow hexagon overlay (presumably the zone boundary -- confirm).
hexagon = polygons.gen_polygon(6, 0.42, rotation=30)
fig = polygons.plot_polygon(
    hexagon, color="yellow", fig=fig, show=False, dash=True, dash_width=3
)

fig.update_coloraxes(
    colorscale="ice", reversescale=True, showscale=True, cmin=0, cmax=None
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)  # equal x/y scaling
fig.show()