## Imports and Functions

In [None]:
%load_ext lab_black

import h5py
import os

from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy.signal import savgol_filter
from scipy.interpolate import interp2d
from functools import lru_cache
import lmfit as lm

from typing import Dict, List, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
import plotly.colors as pc
import matplotlib.pyplot as plt


import sys

# Make the local `arpes_functions` package (imported below) importable.
# NOTE(review): machine-specific absolute path — adjust it (or install the
# package) when running this notebook on another machine.
sys.path.append(r"C:\Users\atully\Code\GitHub\ARPES Code\arpes-code-python")
from arpes_functions import (
    fitting_functions,
    analysis_functions,
    plotting_functions,
    HDF5_loader,
    misc_functions,
    filter_functions,
    tr_functions,
    loading_functions,
    cnn,
    polygons,
)

# Qualitative colour palette for traces/lines. (pc.qualitative.D3 is a
# drop-in alternative; only one assignment is needed.)
colors = pc.qualitative.Plotly

# Unicode symbols for plot labels: Å, θ, Θ.
angstrom = "\u212B"
theta = "\u03B8"
Theta = "\u0398"

# XUV Diffraction Data

In [None]:
# Directory holding the k-corrected Fermi-surface data (machine-specific path)
# and the HDF5 file(s) to read from it.
ddir = r"E:\atully\k-corrected data\Apr_2021\XUV_FS_gamma0"
files = [
    "XUV_FS_gamma0_gkw11_filteredFFT_0.00int.h5",
]

In [None]:
# Per-file containers, keyed by file name: the data cube and its attributes.
ARPES_DATA: Dict[str, tr_functions.ArpesData] = {}
ARPES_ATTRS: Dict[str, tr_functions.ArpesAttrs] = {}

for file in tqdm(files):
    # load_hdf5 yields (data, kx, ky, energy); kx and ky are stored in the
    # `theta` and `phi_or_time` fields of ArpesData respectively.
    data, kx, ky, energy = loading_functions.load_hdf5(ddir, file)
    cube = tr_functions.ArpesData(
        energy=energy,
        phi_or_time=ky,
        theta=kx,
        data=data,
    )
    ARPES_DATA[file] = cube
    ARPES_ATTRS[file] = tr_functions.load_attrs_hdf5(ddir, file)

# All later cells work on the first file.
ad = ARPES_DATA[files[0]]

In [None]:
# Reference the energy axis to the Fermi level: afterwards ad.energy is E - E_F.
# NOTE(review): not idempotent — re-running this cell subtracts EF again;
# re-run the loading cell first if you need to repeat it.
EF = 18.3  # in the units of the loaded energy axis (presumably eV — TODO confirm)
ad.energy = ad.energy - EF

In [None]:
# Plot-title prefix (the slice energy is appended per plot) and kx/ky axis
# titles in HTML markup for plotly.
title = "E - E<sub>F</sub> = "  # plain string: there is nothing to interpolate
xaxis_title = f"k<sub>x</sub> ({angstrom}<sup>-1</sup>)"
yaxis_title = f"k<sub>y</sub> ({angstrom}<sup>-1</sup>)"

In [None]:
# Parameters passed to tr_functions.slice_datacube. Each map is a cut at
# slice_val = v (the E - E_F value shown in the title); int_range is presumably
# the integration width around v, and xlim/ylim the kx/ky window — TODO confirm.
slice_dim = "y"
int_range = 0.2
xlim = (-0.73, 0.52)
ylim = (-1.4, 0.1)  # ylim = None was also used, presumably meaning "no crop"
x_bin = 1
y_bin = 1

# Energies (E - E_F) at which constant-energy maps are drawn.
energies_ct2 = [0, -0.2, -0.4]  # CT2 (plotted below)
energies_ct1 = [2.4, 2.3, 2.2, 2.1, 2.0]  # CT1 (swap into the loop to use)

# Overlay dashed hexagons (gen_polygon(6, 0.42, rotation=30)) centred at
# (0, 0) and (0, 2 * (-0.36)) on every map.
show_bz = False

for v in energies_ct2:
    # Load data
    x, y, d = tr_functions.slice_datacube(
        ad_dataclass=ad,
        slice_dim=slice_dim,
        slice_val=v,
        int_range=int_range,
        xlim=xlim,
        ylim=ylim,
        x_bin=x_bin,
        y_bin=y_bin,
        norm_data=False,
        plot_data=False,
    )

    # Plot data (plotly)
    fig = tr_functions.thesis_fig(
        title=f"{title}{v}",
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        equiv_axes=True,
        height=500,
        width=500,
        dtick_y=0.4,
    )

    fig.add_trace(go.Heatmap(x=x, y=y, z=d, coloraxis="coloraxis"))

    if show_bz:
        for centre in [(0, 0), (0, 2 * (-0.36))]:
            hexagon = polygons.gen_polygon(6, 0.42, rotation=30, translation=centre)
            fig = polygons.plot_polygon(
                hexagon, color="yellow", fig=fig, show=False, dash=True, dash_width=3
            )

    if xlim is not None:
        fig.update_xaxes(range=[xlim[0], xlim[1]], constrain="domain")

    fig.update_coloraxes(colorscale="Blues", reversescale=False)
    fig.show()

In [None]:
# Single cut at E - E_F = v that the window/MDC analysis below works on
# (same slicing parameters as the constant-energy maps above).
v = -0.2

slice_dim, int_range = "y", 0.2
xlim, ylim = (-0.73, 0.52), (-1.4, 0.1)
x_bin = y_bin = 1

x, y, d = tr_functions.slice_datacube(
    ad_dataclass=ad,
    slice_val=v,
    slice_dim=slice_dim,
    int_range=int_range,
    xlim=xlim,
    ylim=ylim,
    x_bin=x_bin,
    y_bin=y_bin,
    plot_data=False,
    norm_data=False,
)

In [None]:
# ky integration windows, as (lower, upper) tuples, for the MDCs below.
gm = -0.36  # ky spacing used to place the windows (presumably Γ–M in Å⁻¹ — TODO confirm)
window = 0.06  # half-width of every ky window

# Choose which Brillouin zone the two feature windows sit in.
use_second_bz = True

if use_second_bz:
    offset = gm
    ylim1 = (offset * 2 - window, offset * 2 + window)
    ylim2 = (offset * 1.45 - window, offset * 1.45 + window)
else:
    offset = 0
    ylim1 = (offset * 2 - window, offset * 2 + window)
    ylim2 = (gm * 0.5 - window, gm * 0.5 + window)  # down

# Reference window used as the "noise" trace.
ylim3 = (-1.3 - window, -1.3 + window)

ylims = [ylim1, ylim2, ylim3]

In [None]:
## Heatmap of the cut, marking the three ky windows (dotted, one colour each),
## the kx range used for the MDCs (grey) and two hexagon outlines.
fig = tr_functions.thesis_fig(
    title=f"{title}{v}",
    xaxis_title=xaxis_title,
    yaxis_title=yaxis_title,
    equiv_axes=True,
    height=500,
    width=500,
    dtick_y=0.4,
)

fig.add_trace(go.Heatmap(x=x, y=y, z=d, coloraxis="coloraxis"))

# Both edges of each ky window, coloured colors[0], colors[1], colors[2].
for y_window, line_color in zip((ylim1, ylim2, ylim3), colors):
    for y_edge in y_window:
        fig.add_hline(y_edge, line_width=1, line_dash="dot", line_color=line_color)

# kx interval (0.3, 0.5) that the MDC cells pass as `xlims`.
for x_edge in (0.3, 0.5):
    fig.add_vline(x_edge, line_width=1, line_dash="dot", line_color="grey")

# Two hexagons centred at (0, 2 * (-0.36)): rotated 30 (colors[0]) and 0 (colors[1]).
hex_centre = (0, 2 * (-0.36))
for rotation, hex_color in ((30, colors[0]), (0, colors[1])):
    hexagon = polygons.gen_polygon(6, 0.42, rotation=rotation, translation=hex_centre)
    fig = polygons.plot_polygon(
        hexagon, color=hex_color, fig=fig, show=False, dash=True, dash_width=3
    )

fig.update_coloraxes(colorscale="Blues", reversescale=False)

# Zoom on the region of interest (a full-xlim range would be overridden here).
fig.update_xaxes(range=[0.25, 0.52], constrain="domain")
fig.update_yaxes(range=[-0.8, -0.4], constrain="domain")

fig.show()

In [None]:
## MDCs: one curve per ky window, restricted to kx in (0.3, 0.5), collapsed
## along ky by tr_functions.get_1d_y_slice.
fig = tr_functions.thesis_fig(
    title="MDCs",
    xaxis_title=xaxis_title,
    yaxis_title="Intensity (arb. u.)",
    equiv_axes=False,
    gridlines=False,
)

cs = colors[:3]
names = ["dominant", "rot 30", "noise"]
# The loop variable is not called `ylim`, so the global slicing parameter
# `ylim` is left intact.
for y_window, c, name in zip(ylims, cs, names):
    x_1d, row = tr_functions.get_1d_y_slice(
        x=x,
        y=y,
        data=d,
        xlims=(0.3, 0.5),
        y_range=y_window,
    )

    fig.add_trace(go.Scatter(x=x_1d, y=row, line=dict(color=c), name=name))

fig.show()

In [None]:
# kx interval whose mean is subtracted from each MDC as its background level.
noise_x_range = [0.47, 0.5]
# Vertical scale factor for the hand-drawn feature triangle.
fraction = 0.05

# Re-extract the three MDCs (same windows/order as `names`) and keep the arrays.
# The loop variable is not called `ylim`, so the global slicing parameter is
# left intact.
xs = []
datas = []
for y_window in ylims:
    x_1d, row = tr_functions.get_1d_y_slice(
        x=x,
        y=y,
        data=d,
        xlims=(0.3, 0.5),
        y_range=y_window,
    )
    xs.append(x_1d)
    datas.append(row)

In [None]:
# Raw MDCs with the background interval marked by dotted black lines.
fig = go.Figure()
for x_, data, name in zip(xs, datas, names):
    fig.add_trace(go.Scatter(x=x_, y=data, name=name))
fig.update_layout(template="plotly_white", width=600, height=400)
# The loop variable is not called `v`: that global holds the slice energy used
# by the heatmap title and the slicing cell.
for x_noise in noise_x_range:
    fig.add_vline(x_noise, line=dict(color="black", dash="dot"))
fig.show()

In [None]:
# Zero every MDC by subtracting its NaN-ignoring mean over the open interval
# noise_x_range. `-=` modifies the arrays held in `datas` in place, so later
# cells (e.g. the copy of datas[1]) see the background-subtracted curves.
fig = go.Figure()
noise_lo, noise_hi = noise_x_range
for x_, data, name in zip(xs, datas, names):
    in_noise = np.logical_and(x_ > noise_lo, x_ < noise_hi)
    data -= np.nanmean(data[in_noise])
    fig.add_trace(go.Scatter(x=x_, y=data, name=name))
fig.update_layout(template="plotly_white", width=600, height=400)
fig.show()

In [None]:
# Hand-placed triangle (x, y vertices) outlining the feature, drawn as a dashed
# line on the background-subtracted MDC figure.
triangle_coords = np.array([(0.325, 0), (0.377, 3.35), (0.45, 0)])
tri_x, tri_y = triangle_coords.T
feature_trace = go.Scatter(
    x=tri_x,
    y=tri_y,
    mode="lines",
    line={"color": "black", "dash": "dash"},
    name="feature fit",
)
fig.add_trace(feature_trace)
fig.show()

In [None]:
# The feature triangle scaled vertically by `fraction` (x vertices unchanged).
new_coords = np.copy(triangle_coords)
new_coords[:, 1] *= fraction
fig.add_trace(
    go.Scatter(
        x=new_coords[:, 0],
        y=new_coords[:, 1],
        mode="lines",
        line=dict(color="black", dash="dash"),
        # `:g` avoids float artefacts such as "7.000000000000001%".
        name=f"{fraction * 100:g}%",
    )
)
fig.show()

In [None]:
# Add the scaled feature triangle (`new_coords`) on top of the background-
# subtracted "rot 30" MDC (datas[1]).
#
# The triangle is linear between consecutive vertices and zero outside
# [first x, last x], which is exactly linear interpolation with zero fill.
# np.interp needs the vertex x values increasing (they are), but not new_x, so
# this also works if the kx axis is stored in descending order.
new_x = xs[1]
feature = np.interp(
    new_x, new_coords[:, 0], new_coords[:, 1], left=0.0, right=0.0
)
new_data = datas[1] + feature  # new array; datas[1] itself is not modified

In [None]:
# "rot 30" MDC with the scaled feature added, overlaid on the current figure.
fig.add_trace(
    go.Scatter(
        x=new_x,
        y=new_data,
        mode="lines",
        line_color="hotpink",
        # `:g` avoids float artefacts such as "7.000000000000001%".
        name=f"rot 30 + {fraction * 100:g}%",
    )
)
fig.show()