In [None]:
project_path, base_atom, energy_dispersion = (
    "X:/Asmara/",  # data root; the loaders read <root>/XAS and <root>/RIXS
    "O",           # file-name prefix of every scan, e.g. O_0002.xas
    0.00457,       # eV per subpixel on the RIXS detector
)


In [None]:
%pylab agg
%matplotlib nbagg

import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import correlate
from scipy.signal import find_peaks
# from scipy.optimize import curve_fit


In [None]:
# Cache of parsed XAS files, keyed by (path, base, scan_number), so that
# re-running plotting cells does not re-read the same file from disk.
xas_dict = {}


def load_xas(scan_number, path=project_path.replace("\\", "/"), base=base_atom):
    """Load one XAS scan, caching the parsed columns.

    Reads ``{path}/XAS/{base}_{scan_number:04d}.xas`` as whitespace-separated
    text (lines starting with ``#`` are skipped) and returns its first four
    columns.

    Parameters
    ----------
    scan_number : int
        Scan index; zero-padded to four digits in the file name.
    path : str
        Project root directory. Backslashes are converted to forward slashes
        and a trailing slash is ignored.
    base : str
        File-name prefix (``base_atom`` by default).

    Returns
    -------
    tuple of numpy.ndarray
        ``(EN, TEY, TFY, RMU)`` -- columns 0, 1, 2 and 3 of the file.
        These are the cached arrays; copy them before modifying in place.
    """
    path = path.replace("\\", "/").rstrip("/")
    # Key on the full file identity, not only the scan number: otherwise a call
    # with a different path or base would silently return another file's data.
    key = (path, base, scan_number)
    xas = xas_dict.get(key)
    if xas is None:
        # ndmin=2 keeps a single-row file 2-D so the column slicing still works.
        data = np.loadtxt(
            f"{path}/XAS/{base}_{scan_number:04d}.xas", comments="#", ndmin=2
        )
        xas = {
            "EN": data[:, 0],
            "TEY": data[:, 1],
            "TFY": data[:, 2],
            "RMU": data[:, 3],
        }
        xas_dict[key] = xas

    return xas["EN"], xas["TEY"], xas["TFY"], xas["RMU"]


In [None]:
# Cache of per-scan detector spectra; see load_ccds.
rixs_dict = {}


def load_h5(file_name):
    """Return the ``entry/analysis/spectrum`` dataset of an HDF5 file as an array.

    The file is opened read-only and is closed even if reading fails.
    """
    with h5py.File(file_name, "r") as f:
        # [()] reads the whole dataset into memory before the file is closed.
        ccd = np.array(f["entry"]["analysis"]["spectrum"][()])

    return ccd


def load_ccds(scan_number, path=project_path.replace("\\", "/"), base=base_atom):
    """Load the three detector spectra of one RIXS scan, caching the result.

    Reads ``{path}/RIXS/{base}_{scan_number:04d}_d{i}.h5`` for i = 1, 2, 3
    with ``load_h5``.

    Parameters
    ----------
    scan_number : int
        Scan index; zero-padded to four digits in the file names.
    path : str
        Project root directory. Backslashes are converted to forward slashes
        and a trailing slash is ignored.
    base : str
        File-name prefix (``base_atom`` by default).

    Returns
    -------
    tuple of numpy.ndarray
        ``(d1, d2, d3)``. These are the cached arrays; copy them before
        modifying in place.
    """
    path = path.replace("\\", "/").rstrip("/")
    # Key on the full file identity so a call with a different path/base is not
    # served another dataset's cached spectra.
    key = (path, base, scan_number)
    rixs = rixs_dict.get(key)
    if rixs is None:
        rixs = {
            i: load_h5(f"{path}/RIXS/{base}_{scan_number:04d}_d{i}.h5")
            for i in (1, 2, 3)
        }
        rixs_dict[key] = rixs

    return rixs[1], rixs[2], rixs[3]


In [None]:
def x_corr(refData, uncorrData):
    """Circularly shift ``uncorrData`` so that it best aligns with ``refData``.

    The shift is the lag that maximises the full cross-correlation of the two
    1-D signals.

    Parameters
    ----------
    refData : 1-D array
        Reference signal.
    uncorrData : 1-D array
        Signal to be aligned.

    Returns
    -------
    numpy.ndarray
        ``uncorrData`` rolled (with wrap-around) by the best-matching lag.
    """
    corr = correlate(refData, uncorrData, mode="full")
    # In "full" mode, output index k corresponds to a lag of
    # k - (len(uncorrData) - 1). Using the raw argmax as the shift misaligned
    # the result by len(uncorrData) - 1 samples (one sample, modulo the length,
    # for equal-length inputs).
    lag = int(np.argmax(corr)) - (len(uncorrData) - 1)
    return np.roll(uncorrData, lag)


def elastic_shift(
    pixelData,
    zeropixel=None,
    height=10,
    width=3,
    enDisp=energy_dispersion,
    pixels_before=2000,
    pixels_after=200,
):
    """Crop a pixel spectrum around its zero pixel and convert the axis to energy.

    Parameters
    ----------
    pixelData : 1-D array
        Intensity per detector (sub)pixel.
    zeropixel : int, optional
        Pixel to use as zero energy. If None, the last (highest-index) peak
        found by ``scipy.signal.find_peaks`` with ``height`` and ``width`` is
        used.
    height, width : float
        ``find_peaks`` thresholds used for automatic detection.
    enDisp : float
        Energy per pixel (``energy_dispersion`` by default).
    pixels_before, pixels_after : int
        Number of pixels kept below and above the zero pixel.

    Returns
    -------
    (xDataEnergy, energyData)
        Energy axis ``(pixel - zero) * enDisp`` and the matching slice of
        ``pixelData``. The window is clipped to the data bounds, so it is
        shorter than ``pixels_before + pixels_after`` when the zero pixel lies
        near an edge.

    Raises
    ------
    ValueError
        If ``zeropixel`` is None and no peak is found.
    """
    if zeropixel is None:
        peaks, _ = find_peaks(pixelData, height=height, width=width)
        if len(peaks) == 0:
            raise ValueError(
                f"elastic_shift: no peak found with height={height}, "
                f"width={width}; pass zeropixel explicitly"
            )
        # NOTE(review): assumes the elastic line is the highest-index peak -- confirm.
        zero = int(peaks[-1])
    else:
        zero = zeropixel

    # Clip at 0: a negative start index would wrap around and give a wrong
    # (often empty) window.
    start = max(zero - pixels_before, 0)
    stop = zero + pixels_after
    energyData = pixelData[start:stop]
    xDataEnergy = (np.arange(start, start + len(energyData)) - zero) * enDisp

    return xDataEnergy, energyData

In [None]:
def load_rix(scan_number):
    """Load one RIXS scan as a single spectrum on an energy axis.

    The three detector spectra from ``load_ccds`` are aligned to detector 2
    with ``x_corr`` and summed. The sum is then cropped and converted to an
    energy axis by ``elastic_shift``, using automatic peak detection.

    Returns
    -------
    (xdata, ydata)
        Energy axis and summed intensity.
    """
    ccd1, ccd2, ccd3 = load_ccds(scan_number)
    # Detector 2 is the alignment reference for the other two.
    ydata = x_corr(ccd2, ccd1) + ccd2 + x_corr(ccd2, ccd3)
    # Shift automatically so the detected peak sits at zero energy.
    return elastic_shift(ydata)


def load_rixs(scans):
    """Load one RIXS scan, or several scans summed after alignment.

    Parameters
    ----------
    scans : int or iterable of int
        A single scan number, or the scan numbers to combine. When several are
        given, each spectrum after the first is aligned to the first with
        ``x_corr`` before summing. The energy axis of the first scan is
        returned.

    Returns
    -------
    (xdata, ydata)
        Energy axis and (summed) intensity. Intensities are not normalised.

    Raises
    ------
    ValueError
        If ``scans`` is an empty iterable.
    """
    # np.integer also accepts scan numbers taken from numpy arrays, which the
    # previous `type(scans) is int` check rejected.
    if isinstance(scans, (int, np.integer)):
        return load_rix(scans)

    scan_iter = iter(scans)
    first = next(scan_iter, None)
    if first is None:
        raise ValueError("load_rixs: no scan numbers given")

    xdata, refdata = load_rix(first)
    ydata = refdata
    for scan_number in scan_iter:
        _, onedata = load_rix(scan_number)
        ydata = ydata + x_corr(refdata, onedata)

    return xdata, ydata

In [None]:
def plot_map(run_list, Y=None):
    """Plot a sequence of RIXS spectra as a colour map, one row per entry.

    Parameters
    ----------
    run_list : sequence
        Each entry is a scan number, or a list/tuple/range of scan numbers
        to be summed by ``load_rixs``.
    Y : array-like, optional
        Row coordinates passed to ``Axes.pcolorfast``: either ``(ymin, ymax)``
        or ``len(run_list) + 1`` row edges. Defaults to ``0 .. len(run_list)``.

    Returns
    -------
    (fig, ax, im)
        The figure, its axes and the image returned by ``pcolorfast``.

    Raises
    ------
    ValueError
        If ``run_list`` is empty, or if the spectra differ in length.
    """
    if len(run_list) == 0:
        raise ValueError("plot_map: run_list is empty")

    rows = []
    for runs in run_list:
        # A bare scan number becomes a one-element group; any other iterable
        # of scan numbers is summed as given.
        scans = [runs] if isinstance(runs, (int, np.integer)) else runs
        X, d = load_rixs(scans)
        rows.append(d)
    # Take the width from the spectra rather than a hard-coded 2200 pixels.
    data = np.vstack(rows)

    if Y is None:
        Y = np.arange(len(rows) + 1)

    fig, ax = plt.subplots()
    # NOTE(review): X holds one value per column rather than N+1 edges, so the
    # horizontal extent may be off by about one pixel -- confirm acceptable.
    im = ax.pcolorfast(X, Y, data)
    fig.colorbar(im, ax=ax)
    ax.set_xlabel('Energy Loss ( eV )')

    return fig, ax, im
        

## Example: plot XAS data

In [None]:
# Explicit axes instead of the %pylab globals `figure`/`plot`.
fig, ax = plt.subplots()
EN, TEY, TFY, RMU = load_xas(2)
ax.plot(EN, TEY)
ax.set_xlabel('Photon Energy ( eV )')
ax.set_ylabel('TEY')
ax.set_title('XAS scan 2');


## Example: plot the three detector spectra of one RIXS acquisition

In [None]:
# Raw spectra of the three detectors of scan 28, plotted against array index.
fig, ax = plt.subplots()
d1, d2, d3 = load_ccds(28)
ax.plot(d1, label='d1')
ax.plot(d2, label='d2')
ax.plot(d3, label='d3')
ax.set_xlabel('Pixel')
ax.set_ylabel('Intensity')
ax.set_title('RIXS scan 28, detectors d1-d3')
ax.legend();


## Example: align, sum and plot repeated RIXS scans

In [None]:
# Scans 18-21 are aligned to scan 18 and summed by load_rixs.
fig, ax = plt.subplots()
energy_loss, intensity = load_rixs([18, 19, 20, 21])
ax.plot(energy_loss, intensity)
ax.set_xlabel('Energy Loss ( eV )')
ax.set_ylabel('Intensity (summed)')
ax.set_title('RIXS scans 18-21 combined');


## Example: plot a sequence of RIXS scans as a colour map

In [None]:
first_run, last_run = 16, 26
run_list = range(first_run, last_run + 1)
# Vertical extent (ymin, ymax) for pcolorfast. Presumably one scan per incident
# energy from 526 to 536 eV, padded by half a step so rows are centred -- confirm.
photon_energy_extent = (526 - 0.5, 536 + 0.5)
fig, ax, im = plot_map(run_list, Y=photon_energy_extent)
# Fixed colour scale so that maps from different runs are comparable.
im.set_clim(0, 10)
ax.set_ylabel('Photon Energy ( eV )')