In [None]:
%load_ext lab_black
# Notebook setup: imports, the local arpes_functions package and shared constants.
# NOTE(review): several imports below (h5py, os, dataclass, tqdm, savgol_filter,
# interp2d, lru_cache, the typing names) are not used in the visible cells.

import h5py
import os

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
# NOTE(review): interp2d is deprecated (and removed in recent SciPy releases);
# it is unused here, so drop this import if it breaks on a newer SciPy.
from scipy.interpolate import interp2d
from functools import lru_cache

from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt


import sys

# NOTE(review): machine-specific absolute path to the arpes_functions checkout;
# consider an environment variable or an installed package instead.
sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    cnn,
)

# U+212B ANGSTROM SIGN, used in the plotly axis titles (inverse-angstrom units).
angstrom = "\u212B"

# Au Data

In [None]:
# Load the k-corrected FS1 map: intensity `data` plus its theta / phi / energy axes.
fp, fn = (
    r"E:\atully\arpes_data\2022_October\k_corrected",
    r"FS1_avg_gkw_filteredFFT_0.00int.h5",
)
data, theta, phi, energy = HDF5_loader.load_hdf5(fp, fn)

In [None]:
# Constant-energy cut of the FS1 map: slice_dim="y" selects the energy axis
# (y=energy below), cut at slice_val with window int_range.
slice_dim, int_range, slice_val = "y", 0.1, 16.8

fig, ax = plotting_functions.plot_3D_mpl(
    data=data,
    x=theta,
    y=energy,
    z=phi,
    slice_dim=slice_dim,
    slice_val=slice_val,
    int_range=int_range,
    title=f"FS ({slice_val} eV)",
    cmap="Blues",
)

# Square plot box (scaled by `ratio`): aspect = x-span / y-span.
ratio = 1.0
(x_left, x_right), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)

In [None]:
# FS4 map from the same directory (`fp` is set in the FS1 loading cell above).
fn = r"FS4_avg4_gkw_filteredFFT_0.00int.h5"
data, theta, phi, energy = HDF5_loader.load_hdf5(fp, fn)

In [None]:
# Same constant-energy preview as for FS1, now for the FS4 map currently held
# in data / theta / phi / energy.
slice_dim, int_range, slice_val = "y", 0.1, 16.8

fig, ax = plotting_functions.plot_3D_mpl(
    data=data,
    x=theta,
    y=energy,
    z=phi,
    slice_dim=slice_dim,
    slice_val=slice_val,
    int_range=int_range,
    title=f"FS ({slice_val} eV)",
    cmap="Blues",
)

# Square plot box (scaled by `ratio`): aspect = x-span / y-span.
ratio = 1.0
(x_left, x_right), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)

## Combine Datasets
### Procedure:
1. Take a constant-energy slice from each dataset at the relevant energy and normalize its intensity.
2. Make a new theta axis that runs from the absolute minimum to the absolute maximum of the two datasets, and interpolate both slices onto it.
3. Find the overlap region of the two datasets along the new theta axis.
4. Create two weighting arrays, one going from 1 to 0 and one from 0 to 1, each spanning the overlap.
5. Weight (multiply) each dataset's overlap region by its corresponding "fading" (weighting) array.
6. Sum the two weighted overlap regions (a linear cross-fade, i.e. a weighted average whose weights sum to 1), then join the result with the non-overlapping parts of each dataset.

In [None]:
# Cut energy and the int_range window passed to get_2Dslice for both datasets below.
slice_val, int_range = 16.8, 0.1

In [None]:
# Step 1, dataset 1 (FS1): reload the map, cut it at constant energy
# (slice_dim="y" is the energy axis here) and normalize the 2D slice.
fp = r"E:\atully\arpes_data\2022_October\k_corrected"
fn = r"FS1_avg_gkw_filteredFFT_0.00int.h5"
data1, theta1, phi1, energy1 = HDF5_loader.load_hdf5(fp, fn)

xaxis1, yaxis1, dataslice1 = analysis_functions.get_2Dslice(
    data=data1,
    x=theta1,
    y=energy1,
    z=phi1,
    slice_dim="y",
    int_range=int_range,
    slice_val=slice_val,
)
dataslice1 = analysis_functions.norm_data(dataslice1)

In [None]:
# Step 1, dataset 2 (FS4): same constant-energy cut and normalization,
# loaded from the directory `fp` set in the previous cell.
fn = r"FS4_avg4_gkw_filteredFFT_0.00int.h5"
data4, theta4, phi4, energy4 = HDF5_loader.load_hdf5(fp, fn)

xaxis4, yaxis4, dataslice4 = analysis_functions.get_2Dslice(
    data=data4,
    x=theta4,
    y=energy4,
    z=phi4,
    slice_dim="y",
    int_range=int_range,
    slice_val=slice_val,
)
dataslice4 = analysis_functions.norm_data(dataslice4)

In [None]:
# 2 — resample both constant-energy slices onto one common theta grid that
# spans the union of their ranges; points outside a dataset's own theta range
# are filled with NaN.

from scipy.interpolate import interp1d

N_THETA = 2000  # number of points on the combined theta axis

new_theta = np.linspace(
    min(np.min(theta1), np.min(theta4)),
    max(np.max(theta1), np.max(theta4)),
    N_THETA,
)

# NOTE(review): only the theta axis is resampled. Rows of the two slices are
# combined one-to-one later (and plotted against yaxis1), which assumes both
# slices share the same second axis (yaxis1 vs yaxis4) — TODO confirm.
# The comprehension variables are distinct from the notebook-level `theta` /
# `data`, so this cell no longer overwrites the 3D FS4 arrays used by the
# preview cells above. new_datas[0] is FS1, new_datas[1] is FS4.
new_datas = [
    interp1d(
        src_theta, src_slice, axis=-1, fill_value=np.nan, bounds_error=False
    )(new_theta)
    for src_theta, src_slice in zip([theta1, theta4], [dataslice1, dataslice4])
]

In [None]:
# Quick look at dataset 2 (FS4, new_datas[1]) on the common theta grid.
fig = go.Figure(go.Heatmap(x=new_theta, y=yaxis1, z=new_datas[1]))
fig.update_layout(width=800, height=600)
fig.show(renderer="svg")

In [None]:
# 3 — locate the overlap on the common theta grid.
# `left`: lowest theta where FS1 has intensity > 0.01 in row 100;
# `right`: highest theta where FS4 has intensity > 0.01 in row 100.
# NOTE(review): row 100 and the 0.01 threshold are hard-coded.
left = theta1[dataslice1[100] > 0.01].min()
right = theta4[dataslice4[100] > 0.01].max()

# (first grid index strictly above `left`, last grid index strictly below `right`)
overlap_indices = (
    np.flatnonzero(new_theta > left)[0],
    np.flatnonzero(new_theta < right)[-1],
)
print(overlap_indices)

In [None]:
# 4 — linear cross-fade weights across the overlap columns.
# FS1 (new_datas[0]) is the higher-theta dataset, so its weight w1 ramps
# 0 -> 1 across the overlap; FS4's weight w4 ramps 1 -> 0.

n_overlap = overlap_indices[1] - overlap_indices[0]
# linspace(0, 1, n) + its flip sum to 1 only for n >= 2: a negative n makes
# linspace raise a cryptic error, and n == 1 gives w1 == w4 == [0], which
# would silently zero that column of the stitched map.
if n_overlap < 2:
    raise ValueError(
        f"Datasets overlap by only {n_overlap} grid points on new_theta "
        f"(indices {overlap_indices}); need at least 2 to cross-fade."
    )
w1 = np.linspace(0, 1, n_overlap)
w4 = np.flip(w1)

In [None]:
# 5 — weight each dataset's overlap columns (w1 <-> new_datas[0] = FS1,
# w4 <-> new_datas[1] = FS4), then add them to form the cross-faded overlap.
overlap1, overlap4 = (
    weights * dataset[:, overlap_indices[0] : overlap_indices[1]]
    for weights, dataset in zip((w1, w4), new_datas)
)
overlap = overlap1 + overlap4

In [None]:
# Stitch columns: FS4-only (low theta) | cross-faded overlap | FS1-only (high theta).
new_data = np.hstack(
    [
        new_datas[1][:, : overlap_indices[0]],
        overlap,
        new_datas[0][:, overlap_indices[1] :],
    ]
)
new_data.shape

In [None]:
# Combined constant-energy map rendered with matplotlib. The cut was taken
# along the energy axis (slice_dim="y" with y=energy), so both axes are
# momenta: x = kx (new_theta), y = ky (yaxis1) — not E_k.
fig, ax = plotting_functions.plot_2D_mpl(
    x=new_theta,
    y=yaxis1,
    data=new_data,
    xlabel="kx",
    ylabel="ky",
    title="",
    cmap="Blues",
)

# plt.savefig(r'C:\Users\atully\OneDrive\Physics.UBC\PhD\Dissertation\Data\Au\Au111_FS.png')

In [None]:
## limit to quadrant ##
# Same combined map as above, zoomed to the kx <= 0, ky <= 0 quadrant.
# Both axes are momenta (constant-energy cut), so the y label is ky, not E_k.

fig, ax = plotting_functions.plot_2D_mpl(
    x=new_theta,
    y=yaxis1,
    data=new_data,
    xlabel="kx",
    ylabel="ky",
    title="",
    cmap="Blues",
)

# Use the returned Axes explicitly rather than pyplot's implicit "current axes".
ax.set_ylim(-1.4, 0.0)
ax.set_yticks(np.arange(-1.0, 0.1, 0.5))
ax.set_xlim(-1.25, 0.0)
ax.set_xticks(np.arange(-1.0, 0.1, 0.5))
ax.tick_params(axis="both", labelsize=14)

# Square plot box (scaled by `ratio`): aspect = x-span / y-span.
ratio = 1.0
x_left, x_right = ax.get_xlim()
y_low, y_high = ax.get_ylim()
ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)

# plt.savefig(r'C:\Users\atully\OneDrive\Physics.UBC\PhD\Dissertation\Data\Au\Au111_FS_quadrant.png')

In [None]:
## Create Plotly Images for Symmetrizing ##
# Plotly version of the combined, re-normalized map (largest raw dataset);
# the starting point for the flip / rotate steps below.

fig = tr_functions.thesis_fig(
    title=f"Au(111) FS",
    xaxis_title=f"k<sub>x</sub> [{angstrom}<sup>-1</sup>]",
    yaxis_title=f"k<sub>y</sub> [{angstrom}<sup>-1</sup>]",
    equiv_axes=True,
    gridlines=False,
)
fig.add_trace(
    go.Heatmap(
        x=new_theta,
        y=yaxis1,
        z=analysis_functions.norm_data(new_data),
        coloraxis="coloraxis",
    )
)

# Single styling pass: shared color axis, centered title, square canvas.
fig.update_coloraxes(colorscale="Blues", showscale=True, reversescale=False)
fig.update_layout(
    title=dict(
        text=f"Au(111) FS", x=0.5, xanchor="center", yanchor="top", font_size=22
    ),
    width=600,
    height=600,
)
fig.show()

In [None]:
## For Symmetrizing: Flip Dataset ##
# Mirror the combined map about ky = 0 (y -> -y) and preview it.

fig = tr_functions.default_fig()

fig.add_trace(go.Heatmap(x=new_theta, y=-1 * yaxis1, z=new_data, coloraxis="coloraxis"))

fig.update_coloraxes(colorscale="Blues", showscale=True)
fig.update_layout(
    title=dict(
        text="Au(111) FS", x=0.5, xanchor="center", yanchor="top", font_size=22
    ),
)

# Raw strings: "\;" is not a valid Python escape sequence, so the plain
# literals emitted DeprecationWarning / SyntaxWarning (Python 3.12+). The
# string values (containing the LaTeX spacing command "\;") are unchanged.
fig.update_xaxes(
    title_text=r"$k_x \; [A^{-1}]$", title_font=dict(size=20), range=[-1.25, 0.68]
)
fig.update_yaxes(
    title_text=r"$k_y \; [A^{-1}]$", title_font=dict(size=20), range=[-0.263, 1.4]
)

fig.update_layout(
    width=700, height=600, margin=dict(l=100)
)  # margin=dict(b=0, t=30, l=20, r=0)
fig.show(renderer="svg")

In [None]:
## Limit Dataset ##
# Crop the combined map to kx in [-1.25, 0.68] and ky in [-1.4, 0.263],
# then renormalize the cropped intensities.

x, y, z = analysis_functions.limit_dataset(
    x=new_theta,
    y=yaxis1,
    data=new_data,
    xlim=(-1.25, 0.68),
    ylim=(-1.4, 0.263),
)
z = analysis_functions.norm_data(z)

In [None]:
## Combine Raw and Flipped Data ##
# Overlay the cropped map (x, y, z) with its mirror image about ky = 0.

fig = tr_functions.thesis_fig(
    title=f"Au(111) FS",
    xaxis_title=f"k<sub>x</sub> [{angstrom}<sup>-1</sup>]",
    yaxis_title=f"k<sub>y</sub> [{angstrom}<sup>-1</sup>]",
    equiv_axes=True,
    gridlines=False,
)

# Mirrored copy first, then the as-measured map on top.
fig.add_traces(
    [
        go.Heatmap(x=x, y=-1 * y, z=z, coloraxis="coloraxis"),
        go.Heatmap(x=x, y=y, z=z, coloraxis="coloraxis"),
    ]
)

fig.update_coloraxes(colorscale="Blues", showscale=True, reversescale=False)
fig.update_layout(
    title=dict(
        text=f"Au(111) FS", x=0.5, xanchor="center", yanchor="top", font_size=22
    ),
    width=600,
    height=600,
)
fig.update_xaxes(range=[-1.25, 1.25])
fig.update_yaxes(range=[-1.3, 1.3], scaleanchor="x", scaleratio=1)
fig.show()

In [None]:
## Rotate Dataset ##
# Rotate the cropped map's (x, y) sample points by 120 about (0, 0)
# (units per tr_functions.rotate_2d_array — presumably degrees), then let
# tr_functions.interpolate map z onto the rotated coordinates.
coords = tr_functions.x_y_to_coords(x, y)
rotated_coords = tr_functions.rotate_2d_array(coords, 120, (0, 0))
nx, ny, nd = tr_functions.interpolate(rotated_coords, z)

In [None]:
# Optional (disabled): replace NaN entries of the rotated map `nd` with 0.02,
# the same level as the constant backdrop trace in the all-datasets plot.
# nd[np.isnan(nd)] = 0.02

In [None]:
## Plot All Datasets ##
# Overlay, bottom to top: a constant 0.02 backdrop, the rotated copy
# (nx, ny, nd), the mirrored map and the cropped map, all on one color axis
# pinned to [0, 1].

lim = 1.23  # half-width of the square kx / ky window

fig = tr_functions.thesis_fig(
    title=f"Au(111) FS",
    xaxis_title=f"k<sub>x</sub> [{angstrom}<sup>-1</sup>]",
    yaxis_title=f"k<sub>y</sub> [{angstrom}<sup>-1</sup>]",
    equiv_axes=False,
    gridlines=False,
)

fig.add_traces(
    [
        # backdrop: value 0.02 at the four corners of the window
        go.Heatmap(
            x=[-lim, lim, -lim, lim],
            y=[-lim, -lim, lim, lim],
            z=[0.02] * 4,
            coloraxis="coloraxis",
        ),
        go.Heatmap(x=nx, y=ny, z=nd, coloraxis="coloraxis"),
        go.Heatmap(x=x, y=-1 * y, z=z, coloraxis="coloraxis"),
        go.Heatmap(x=x, y=y, z=z, coloraxis="coloraxis"),
    ]
)

fig.update_coloraxes(
    colorscale="Blues", showscale=True, cmin=0, cmax=1, reversescale=False
)
fig.update_layout(
    title=dict(
        text=f"Au(111) FS", x=0.5, xanchor="center", yanchor="top", font_size=22
    ),
    width=600,
    height=600,
    margin=dict(l=100),
)
fig.update_xaxes(range=[-lim, lim])
fig.update_yaxes(range=[-lim, lim], scaleanchor="x", scaleratio=1)
fig.show()