In [None]:
# default settings
# Parameters cell: these are defaults and may be overridden at run time
# (presumably injected by papermill or a similar runner — TODO confirm).
inpath = "" # type: str
outpath = "" # type: str
edge_element = "" # type: str
window_width = 40 # type: int
sztol = 0.9  # type: float
normalised = True  # type: bool
method = "mutual"  # type: str
ref_index = None  # type: int
max_fractional_shift = 0.2  # type: float
auto_processing = True # type: bool

# The following are only read when auto_processing is False (post-processing).
# Without defaults here, a post-processing run that did not inject every one of
# them failed with NameError.
# Tracking line group; empty or "None" falls back to edge_element.
element_to_align = "" # type: str
# Visit directory that contains the "scan" folder.
visit_dir = "" # type: str
# Inclusive range of scan numbers forming the stack.
first = 0 # type: int
last = 0 # type: int
# Python-literal list of scan numbers to skip, e.g. "[234360, 234362]".
exclude = "[]" # type: str

# Stacking and alignment for XANES
This notebook stacks a sequence of datasets acquired at different energies for a particular line group and aligns them using a tracking line group (which may be the same or different). Normalisation by I0, cropping, and removal of scans with inconsistent scan-axis sizes are also incorporated. Alignment is performed by mutual information or Fourier (phase) correlation. The aligned stack is saved as an HDF5 file that can be read by Mantis for further analysis.

## Parameters
    
**inpath** : string  
    The last raw data file of the XANES stack; it must contain the `"/entry/previous_scan_files/paths"` and `"/entry/line"` entries. Only used when _auto_processing_ is `True`.  
    For example, `"/dls/i14/data/2023/cm33895-3/scan/i14-234367.nxs"`. 

**outpath** : string  
    The output file of the XANES stack, which is compatible with Mantis.  
    For example, `"/dls/i14/data/2023/cm33895-3/processed/mantis_xanes_234367.hdf5"`.   
    
**edge_element** : string  
    The line group of interest. This can be the same as the tracking line group _element_to_align_ (read from `"/entry/line"` in auto-processing) or different.  
    For example, `"Ca-Ka"` or `"Fe-Ka"`.   
    
**window_width** : int  
    The channel width for the XRF window.  
    For example, `40`.  
    
**sztol** : float   
    The fractional tolerance on the scan size used to decide whether a scan is removed or cropped.  
    The number must be between 0 and 1 and sensible value should be above 0.9.    
    The reference is always set by the scan _first_.    
    This applies to both x and y axis.   
       
       
   For example, if the size of x axis in _first_ is 71 and _sztol_ is 0.9:   
   dataset with x-axis size >63 (71\*0.9 rounded down) will be cropped;   
   dataset with x-axis size <=63 will be removed from the stack.
    
    <-0___________remove___________63-><-_____crop_____71->
         |                                 |
         |                                 |  
         |     all datasets are cropped to the minimum x-axis size of the whole sequence
         |     e.g. if one dataset has an x-axis size of 67, all datasets are cropped to 67 in the x axis;
         |          it is not removed as it is above the tolerance
         |
         |
         |
    any dataset in this range is removed from the stack
    e.g. if a dataset has an x-axis size of 21, it is removed;
         it is not cropped as it is below the tolerance
       
**normalised** : boolean    
    Whether to normalise the stack by I0.
    `True` or `False`.
    
**method** : string   
    The method to use for alignment. `mutual` (mutual information) or `fourier` (Fourier correlation).
    
**ref_index** : None or integer   
    The index of the reference image for alignment.  
    If `None`, the reference image is the one with maximum total intensity. This works most of the time.
    
**max_fractional_shift** : float   
    <span style="color:blue">This is only relevant when _method_ is "mutual".</span>   
    This number is used to define the maximum shift that could be applied to an axis, as a fraction of the size of the axis.  
    This should be between 0 and 1. Setting it to a higher number results in longer duration of alignment.   
    A sensible value is between `0.2` and `0.4`.

**auto_processing** : bool    
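    Whether this is an auto-processing run (`True`: the scan list is read from _inpath_) or a post-processing run (`False`: the scan list is built from _visit_dir_, _first_, _last_ and _exclude_).  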

         
## Legacy workflow
1. There is no need to have the variable _element_to_load_ as this is redundant; it is determined from _element_to_align_ and _edge_element_.
2. There is no need to specify the problematic scan numbers to be removed from the stack (although you can still do this through _to_remove_). Both removal and cropping are handled automatically via the tolerance _sztol_.
3. The default alignment algorithm has changed from Fourier correlation (built into HyperSpy) to mutual information (written by Paul Q), which is believed to be more robust. Fourier correlation is still available by setting _method_ to `"fourier"`.

## Working environment
This notebook is developed under the `python/epsic3.10` environment (available after `module load python/epsic3.10`). It works both locally and on the cluster.

### Package import

In [None]:
import ast
import time
from pathlib import Path

import h5py
import numpy as np
import matplotlib.pyplot as plt
import hyperspy.api as hs

from i14_utility.xanes.mutual import estimate_shift2D_mutual
from i14_utility.xanes.window_xrf import (channel_start_end, read_raw_data, window_mca, check_inconsistent_xy, 
crop_stack_map, print_file_summary)
from i14_utility.xanes.io import save_mantis, save_stack_as_gif

In [None]:
from importlib.metadata import version
# Record the installed HyperSpy version alongside the results for reproducibility.
print("HyperSpy version:", version("hyperspy"))

### Override the outpath
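All outputs (`.hdf5`, `.gif` and `.png`) are written to a `xanes` sub-folder of the directory containing _outpath_ (e.g. `.../processed/xanes`), reusing the file name of _outpath_ with the appropriate extension.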

In [None]:
# Redirect every output into a "xanes" sub-folder next to the requested outpath,
# keeping the requested file name but forcing a per-format suffix.
op = Path(outpath)
if not op.name:
    # An empty (or name-less) outpath would resolve to "xanes.hdf5" etc. in the
    # current working directory instead of inside the "xanes" folder.
    raise ValueError(f"outpath must name an output file, got {outpath!r}")

processed_xanes = op.parent / "xanes"
processed_xanes.mkdir(parents=True, exist_ok=True)

processed_xanes_nxs = processed_xanes / op.name

mantis_outpath = str(processed_xanes_nxs.with_suffix(".hdf5"))
gif_outpath = str(processed_xanes_nxs.with_suffix(".gif"))
png_outpath = str(processed_xanes_nxs.with_suffix(".png"))

print(f"The MANTIS file will be saved as {mantis_outpath}")
print(f"The GIF file will be saved as {gif_outpath}")
print(f"The PNG file will be saved as {png_outpath}")

### Read the list of scans and tracking line from the last file

In [None]:
if auto_processing:
    # The last scan of the sequence records the paths of all earlier scans and
    # the line group used for tracking.
    with h5py.File(inpath, "r") as f:
        previous_paths = f["/entry/previous_scan_files/paths"][()]
        element_to_align = f["/entry/line"][()]
    # h5py may hand back strings as bytes; decode only when needed.
    file_list = [p.decode() if isinstance(p, bytes) else str(p) for p in previous_paths] + [inpath]
    if isinstance(element_to_align, bytes):
        element_to_align = element_to_align.decode()
else:
    # Post-processing: rebuild the scan list from the visit directory and the
    # inclusive scan-number range [first, last], skipping excluded scans.
    if isinstance(exclude, str):
        # exclude is a Python literal such as "[234360, 234362]". Parsing only
        # while it is still a string keeps this cell safe to re-run.
        exclude = ast.literal_eval(exclude)
    excluded_scans = set(exclude)
    file_list = [f"{visit_dir}/scan/i14-{x}.nxs" for x in range(first, last + 1) if x not in excluded_scans]

# Without a recorded tracking line, track on the edge line group itself.
if not element_to_align or element_to_align == "None":
    element_to_align = edge_element

print(f"Number of scan files in the list: {len(file_list)}")
print(f'Line group used for tracking: {element_to_align}')
print(f'Line group to be aligned: {edge_element}')
print(f'Line group to be aligned: {edge_element}')

### Window both element_to_align and edge_element

In [None]:
start = time.perf_counter()

# Per-scan accumulators, all kept in the same order as file_list.
stack_element_align, stack_edge_element = [], []
energy_data, SampleX, SampleY, scan_shapes, I0_total = [], [], [], [], []

# Channel bounds of the XRF window for each line group (as returned by channel_start_end).
start_align, end_align = channel_start_end(element_to_align, width=window_width)
start_edge, end_edge = channel_start_end(edge_element, width=window_width)
same_line_group = edge_element == element_to_align

for scan_file in file_list:
    data = read_raw_data(scan_file)
    mca, shape, model = data["mca"], data["scan_shape"], data["scan_model"]

    # Window the MCA spectra; reuse the tracking result when both line groups coincide.
    windowed_align = window_mca(mca, start_align, end_align, shape, model)
    windowed_edge = windowed_align if same_line_group else window_mca(mca, start_edge, end_edge, shape, model)

    stack_element_align.append(windowed_align)
    stack_edge_element.append(windowed_edge)
    energy_data.append(data["energy"])
    SampleX.append(data["x"])
    SampleY.append(data["y"])
    scan_shapes.append(shape)
    # Total I0: sum of the I0_1..I0_4 readings with singleton dimensions removed.
    I0_total.append(np.squeeze(data["I0_1"] + data["I0_2"] + data["I0_3"] + data["I0_4"]))

print(f"Time reading raw data: {(time.perf_counter() - start)/60:.2f} min")

### Check inconsistency of element maps

In [None]:
# Flag which scans are kept (see sztol) and get the common map size the kept
# scans are cropped to.
valid_scan, dim_x, dim_y = check_inconsistent_xy(
    SampleX,
    SampleY,
    scan_shapes,
    sztol=sztol,
)

### A summary of scan files

In [None]:
# Print a per-file summary of which scans are kept. The returned pair rebinds the
# names first/last (also read by the post-processing branch of the scan-list cell);
# presumably the first/last scan numbers of the stack — TODO confirm.
first, last = print_file_summary(file_list, valid_scan)

In [None]:
# Report the map size before (first scan in the file list) and after the size
# check, plus how many maps survive.
print("\n".join([
    f"Original elemental map shape in (y, x): {scan_shapes[0]}",
    f"Elemental map shape of the stack in (y, x): ({dim_y}, {dim_x})",
    f"Number of elemental maps in the stack: {np.count_nonzero(valid_scan)}",
]))

### Construct the stack as HyperSpy signal
This allows easier inspection and interfacing with the alignment functions.

In [None]:
# Axis calibration for the HyperSpy signals, taken from the first scan that
# passed the size check.
valid_indices = np.flatnonzero(valid_scan)
if valid_indices.size == 0:
    raise ValueError("No scan passed the size check; lower sztol or check the scan list.")
first_valid = int(valid_indices[0])

# Index the per-scan lists directly: scans may have different numbers of
# positions, so np.asarray(SampleX) can be ragged (an error in NumPy >= 1.24).
# Truncate to the cropped map size so each axis size matches the cropped stack.
# NOTE(review): assumes SampleX/SampleY entries are 1D position arrays and that
# crop_stack_map keeps the leading positions (offset unchanged) — confirm.
x_axis = np.asarray(SampleX[first_valid])[:dim_x]
y_axis = np.asarray(SampleY[first_valid])[:dim_y]
energy_data = np.asarray(energy_data)[valid_scan]

xax_dict = {"offset": x_axis[0], "scale": x_axis[1]-x_axis[0], "size": x_axis.size, "units": "mm"}
yax_dict = {"offset": y_axis[0], "scale": y_axis[1]-y_axis[0], "size": y_axis.size, "units": "mm"}

# Number of maps kept in the stack (valid_scan is a boolean mask).
ndata = int(np.count_nonzero(valid_scan))

In [None]:
def _stack_to_signal(stack):
    """Crop a per-scan list of maps with crop_stack_map and wrap it as a Signal2D.

    The first array axis of the cropped result becomes the navigation axis; the
    image axes use the shared x/y calibration dictionaries.
    """
    cropped = crop_stack_map(stack, valid_scan, dim_x, dim_y)
    axes = [{"size": cropped.shape[0]}, xax_dict, yax_dict]
    return hs.signals.Signal2D(cropped, axes=axes)


s_element_align = _stack_to_signal(stack_element_align)
s_edge_element = _stack_to_signal(stack_edge_element)
s_i0_total = _stack_to_signal(I0_total)

### Plot the element_to_align

In [None]:
# Choose the reference image for alignment. By default, use the map of the
# tracking line group with the largest total intensity.
if ref_index is None:
    # Sum each map over its two image dimensions directly on the array (the
    # navigation axis is the first array axis). Calling np.argmax on a HyperSpy
    # signal only worked through HyperSpy's array dispatch and needed `.data[0]`.
    ref_index = int(np.argmax(s_element_align.data.sum(axis=(1, 2))))

if not 0 <= ref_index <= ndata-1:
    raise ValueError(f"The reference index {ref_index} for alignment is out of bound, should be between 0 and {ndata-1}")

# align2D(reference="current") aligns to the currently selected navigation
# index, so point the navigator at the reference image.
s_element_align.axes_manager[0].index = ref_index

<div class="alert alert-block alert-info">
The image shown here is the one used as the reference for alignment.
</div>

In [None]:
# Plot the tracking stack; the navigator is positioned at ref_index.
s_element_align.plot()

### Align element_to_align

In [None]:
# Normalise the method name once, reject unknown values up front, then dispatch.
method = method.lower()

if method not in ("mutual", "fourier"):
    raise ValueError(f'Method "{method}" not recognised, either "mutual" or "fourier".')

if method == "mutual":
    print("Using mutual information")
    # Estimate shifts by mutual information, then apply them.
    shifts, max_vals = estimate_shift2D_mutual(
        s_element_align,
        reference="current",
        bin_rule="sturges",
        max_fractional_shift=max_fractional_shift,
        brute=False,
    )
    s_element_align.align2D(shifts=shifts)
else:
    print("Using Fourier correlation")
    # HyperSpy estimates and applies the shifts in one call and returns them.
    shifts = s_element_align.align2D(reference="current", normalize_corr=True)

### Plot the shifts for element_to_align

In [None]:
# One panel per shift component: column 0 of `shifts` holds row shifts,
# column 1 column shifts.
fig, ax = plt.subplots(1, 2)
for col, label in enumerate(("Row", "Column")):
    ax[col].plot(shifts[:, col])
    ax[col].set_title(f"{label} shifts")

for col, label in enumerate(("row", "column")):
    print(f"Sum of all {label} shifts: {np.sum(shifts[:, col])}")

### Align I0 sum

In [None]:
# Apply the tracking-line shifts to the I0 stack so it stays registered with the
# fluorescence maps; only needed when normalising by I0.
if normalised:
    s_i0_total.align2D(shifts=shifts)

### Align the edge_element

In [None]:
# Apply the shifts estimated on the tracking line group to the edge line group.
s_edge_element.align2D(shifts=shifts)

In [None]:
# Inspect the aligned edge_element stack.
s_edge_element.plot()

### Plot the I0 sum

In [None]:
# Inspect the aligned I0 stack (only aligned when normalising).
if normalised:
    s_i0_total.plot()

### Normalise the stack by I0 sum

In [None]:
# Normalise the edge_element stack by I0, rescaled so that the largest I0 value is 1.
if normalised:
    # nanmax rather than .max(): HyperSpy's align2D fills pixels shifted in from
    # outside the image with NaN by default. A single NaN surviving the crop would
    # make .max() NaN and turn the whole normalised stack into NaN.
    i0_max = np.nanmax(s_i0_total.data)
    stack_norm_reduced = s_edge_element / (s_i0_total / i0_max)
else:
    stack_norm_reduced = s_edge_element

### Plot energy of the stack

In [None]:
# Energy of each map in the stack, in stack order.
fig, ax = plt.subplots()
ax.plot(energy_data, "C0.-")
ax.set(xlabel="index", ylabel="energy (keV)", title="Energy of the stack")

### Save the file

In [None]:
# Energies converted from keV to eV (presumably what MANTIS expects — confirm).
new_energy = np.asarray(energy_data) * 1000
# Placeholder white spectrum: one value of 1 per energy.
new_white_spectrum = np.ones_like(energy_data)
comment = ""

In [None]:
# save_mantis receives the stack with all array axes reversed (.T).
mantis_stack = stack_norm_reduced.data.T
save_mantis(
    mantis_outpath,
    mantis_stack,
    ax_energy=new_energy,
    ax_white=new_white_spectrum,
    comment=comment,
    shifts=shifts,
)
print(f"File written: {mantis_outpath}")

### Save edge_element image stack as a gif

In [None]:
# Announce before writing (consistent with the PNG cell), so the target path is
# visible even if the export fails.
print(f"Saving edge_element image stack gif to {gif_outpath}")

_ = save_stack_as_gif(stack_norm_reduced.data, gif_outpath)

### Save edge_element integral intensity plot

In [None]:
# Total (normalised) edge_element intensity of each map: sum over the two image axes.
edge_integral = np.sum(stack_norm_reduced.data, axis=(-2, -1))

In [None]:
print(f"Saving edge_element integral intensity to {png_outpath}")

# Integrated edge_element intensity versus beam energy.
fig, ax = plt.subplots()
ax.plot(energy_data, edge_integral, "r-")
ax.set_title(edge_element)
ax.set_xlabel("energy (keV)")
fig.tight_layout()
# Save this specific figure rather than whatever pyplot considers "current".
fig.savefig(png_outpath)