In [None]:
# Parameters cell (papermill-compatible: plain assignments with `# type:` hints).
# `outpath` is never written to directly: only its parent directory is used to
# derive outpath_nexus, outpath_mantis and outpath_complete further down.

# --- input / output ---
inpath = ""  # type: str
outpath = ""  # type: str

# --- XRF windowing, map-size consistency and stacking ---
edge_element = ""  # type: str
sztol = 0.9  # type: float
normalised = True  # type: bool

# --- matrix completion (looped ASD) ---
completion_rank = 6  # type: int
tol_residual = 1e-4  # type: float
num_short_iteration = 75  # type: int
num_final_iteration = 2000  # type: int

# --- run mode ---
auto_processing = True  # type: bool

## Usage 

This notebook takes the last scan file of a sparse XANES scan (`inpath`), defines the 2D full grid, inserts the data in the correct rows, stacks the images, and completes the missing data using looped alternating steepest descent (ASD). There are 3 output files:
- the incomplete NeXus file
- the incomplete MANTIS file
- the complete MANTIS file

### Parameters
`inpath` : str  
the full path of the last scan file of a sparse XANES scan. e.g. "/dls/i14/data/2024/cm37259-1/scan/i14-280251.nxs"  
`outpath` : str  
the full path of the output file. e.g. "/dls/i14/data/2024/cm37259-1/processed/i14-280251_xanes_sparse_stack_autoprocessing0.nxs". `outpath` will be overridden if it is triggered by auto-processing, as defined by the flag `auto_processing` (defaults to `True`); for post-processing, `outpath` will not be modified. `outpath` will then be used to make the file paths for 3 files: the incomplete NeXus file, the incomplete MANTIS file and the complete MANTIS file.  
`edge_element` : str  
the transition for the XRF window. e.g. "Zr-Ka"   
`sztol` : float   
The fractional tolerance on the scan size used to decide whether a scan is removed or cropped.    
The number must be between 0 and 1, and a sensible value is above 0.9.    
The reference size is always taken from the first scan. This applies to both the x and y axes.      
For example, if the size of the x axis in the first scan is 71 and sztol is 0.9:   
datasets with an x-axis size >63 (71\*0.9 rounded down) will be cropped;   
datasets with an x-axis size <=63 will be removed from the stack.
```
<-0___________remove___________63-><-_____crop_____71->
     |                                 |
     |                                 |  
     |     all datasets will be cropped to the minimum x-axis size of the whole sequence
     |     e.g. if one dataset has an x-axis size of 67, all datasets are cropped to 67 in the x axis
     |          it is not removed as it is higher than the tolerance
     |
     |
     |
datasets in this range will be removed from the stack
e.g. if a dataset has an x-axis size of 21, it is removed;
     it is not cropped as it is lower than the tolerance
```
`completion_rank` : int   
the maximum rank that the sparse XANES stack will be decomposed to for matrix completion. It should be between 3 and 12.   
`tol_residual` : float   
the iteration will stop once the residual is below this value.    
`num_short_iteration` : int    
the number of iterations for every rank decomposition except the final one.   
`num_final_iteration` : int   
the number of iterations for the final rank decomposition.   
`auto_processing` : bool    
the flag stating whether this is an auto-processing (`True`) or post-processing (`False`) run.    

**The parameters should be provided by explicitly modifying the top cell content or using tools such as [papermill](https://papermill.readthedocs.io/en/latest/index.html). If the notebook is run as is, please define the parameters accordingly.**

### Dependencies
The majority of the work is carried out by [i14-utility-xanes](https://gitlab.diamond.ac.uk/i14/i14_utility).   
- numpy
- matplotlib
- h5py
- hyperspy
- i14-utility (https://gitlab.diamond.ac.uk/i14/i14_utility)

In [None]:
import time
from pathlib import Path
import re

import numpy as np
import matplotlib.pyplot as plt
import h5py
import hyperspy.api as hs

from i14_utility.xanes.window_xrf import (channel_start_end, read_raw_data, window_mca, check_inconsistent_axis, 
print_file_summary, full_x_axis, full_y_axis, sparse_y_indices, sparse_row_map, 
sparse_window_stack)
from i14_utility.xanes.completion import LoopedASD, imagesc
from i14_utility.xanes.io import save_mantis

In [None]:
# Record the installed HyperSpy version in the notebook output (useful when
# comparing saved files produced by different environments).
from importlib.metadata import version
print(f"HyperSpy version: {version('hyperspy')}")

### Get scan file list

In [None]:
# Build the list of scan files that make up the sparse XANES sequence. The
# earlier scans are recorded inside the last scan file (`inpath`) under
# `previous_file_dataset_path`, and the last scan itself is appended.
#
# This used to be guarded by `if auto_processing:`, which left `file_list`
# undefined for post-processing runs (NameError on the print below). `inpath`
# is documented as the last scan file in both modes, so it is always read.
previous_file_dataset_path = "/entry/previous_scan_files/paths"

try:
    with h5py.File(inpath, "r") as nxs_file:
        # atleast_1d: a scalar dataset (a single previous scan) would otherwise
        # be iterated byte-by-byte instead of path-by-path
        previous_files = np.atleast_1d(nxs_file[previous_file_dataset_path][()])
except FileNotFoundError as err:
    msg = f"The file {inpath} cannot be found, perhaps the year/visit is wrong?"
    raise FileNotFoundError(msg) from err
except KeyError as err:
    msg = (f"The dataset path {previous_file_dataset_path} seems not present, " 
        f"please check if {inpath} is the last scan of a sparse XANES experiment.")
    raise KeyError(msg) from err
else:
    # h5py may hand back the stored paths as bytes; decode those, keep str as-is
    file_list = [
        path.decode() if isinstance(path, bytes) else str(path)
        for path in previous_files
    ] + [inpath]

print(f"Number of scan files in the list: {len(file_list)}")

In [None]:
# Echo the XRF line (e.g. "Zr-Ka") whose channel window is used below.
print(f"Line group to be aligned: {edge_element}")

### Window here 
The raw scan files are windowed here directly, instead of reading the processed files, to avoid problems with differing data locations and dataset paths.

In [None]:
t_start = time.perf_counter()

# Per-scan accumulators, filled in file_list order.
windowed = []      # XRF maps windowed to the edge_element channel range
energy_data = []   # energy recorded for each scan
SampleX = []       # x positions of each scan
SampleY = []       # y positions of each scan
scan_shapes = []   # scan shape reported by each file
I0_total = []      # I0_1 + I0_2 + I0_3 + I0_4, squeezed

# The channel window depends only on edge_element, so compute it once.
lg_start, lg_end = channel_start_end(edge_element)

for scan_file in file_list:
    raw = read_raw_data(scan_file)

    mca_window = window_mca(raw["mca"], lg_start, lg_end, raw["scan_shape"], raw["scan_model"])
    i0_sum = np.squeeze(raw["I0_1"] + raw["I0_2"] + raw["I0_3"] + raw["I0_4"])

    windowed.append(mca_window)
    energy_data.append(raw["energy"])
    SampleX.append(raw["x"])
    SampleY.append(raw["y"])
    scan_shapes.append(raw["scan_shape"])
    I0_total.append(i0_sum)

elapsed_min = (time.perf_counter() - t_start) / 60
print(f"Time reading raw data: {elapsed_min:.2f} min")

###  Check inconsistency of element maps

In [None]:
# Check x and y extents of every scan against the `sztol` tolerance (see the
# `sztol` description in the usage notes); a scan is valid only if both pass.
(valid_scan_x, dim_x), (valid_scan_y, dim_y) = (
    check_inconsistent_axis(positions, scan_shapes, sztol=sztol)
    for positions in (SampleX, SampleY)
)
valid_scan = valid_scan_x & valid_scan_y
# common (cropped) map size used for reporting
dim_ = min(dim_x, dim_y)
# NOTE(review): in this notebook `valid_scan` only feeds print_file_summary and
# the count printed below; it is not passed to sparse_window_stack -- confirm
# that invalid scans are really excluded from the stack.

### A summary of scan files

In [None]:
# Print a per-file summary. The returned `first`/`last` are used below to name
# the output files (presumably the first and last valid scan files -- confirm
# in i14_utility.xanes.window_xrf).
first, last = print_file_summary(file_list, valid_scan)

In [None]:
# Compare the original map shape (taken from the first scan in the list) with
# the map size used for the stack and the number of valid maps.
n_maps_in_stack = np.count_nonzero(valid_scan)
print(
    f"Original elemental map shape: {scan_shapes[0]}",
    f"Elemental map shape of the stack: ({dim_},)",
    f"Number of elemental maps in the stack: {n_maps_in_stack}",
    sep="\n",
)

### Override paths of output files 

In [None]:
# All outputs go into a "xanes_sparse" sub-folder next to `outpath`; `outpath`
# itself only contributes its parent directory.
# NOTE(review): the usage notes say `outpath` is only overridden for
# auto-processing, but this folder is used regardless of `auto_processing`.
output_folder = Path(outpath).parent.joinpath("xanes_sparse")
output_folder.mkdir(exist_ok=True, parents=True)

print(f"output_folder is {output_folder}")

In [None]:
def _scan_number(scan_file):
    """Return the scan number ``N`` from a scan file named ``...i14-N.nxs``.

    Parameters
    ----------
    scan_file : str or os.PathLike
        Path of a raw scan file.

    Returns
    -------
    str
        The digits between ``i14-`` and ``.nxs``.

    Raises
    ------
    ValueError
        If the name does not end in ``i14-<digits>.nxs``. Previously this
        surfaced as an opaque ``AttributeError`` on ``None.group``.
    """
    match = re.search(r"^.*i14-(\d+)\.nxs$", str(scan_file))
    if match is None:
        raise ValueError(
            f"Cannot extract a scan number from {scan_file!s}: "
            "expected a file name ending in 'i14-<number>.nxs'."
        )
    return match.group(1)


first_num = _scan_number(first)
last_num = _scan_number(last)

# Output names carry the scan-number range of the sequence.
outpath_nexus = str(output_folder / f"i14_{first_num}_{last_num}_stack.nxs")
outpath_mantis = str(output_folder / f"i14_{first_num}_{last_num}_mantis.hdf5")
outpath_complete = str(output_folder / f"i14_{first_num}_{last_num}_completion.hdf5")

In [None]:
# Report where each of the three output files will be written.
print(
    f"The NeXus file: {outpath_nexus}",
    f"The MANTIS (incomplete) file: {outpath_mantis}",
    f"The MANTIS (complete) file: {outpath_complete}",
    sep="\n",
)

### Determine overall scan size

In [None]:
# Full x and y axes covering every scan; presumably the union of all scanned
# positions -- confirm in i14_utility.xanes.window_xrf.
xall = full_x_axis(SampleX)
# y_coords: per-scan y coordinates, used below to locate each scan's rows in yall
yall, y_coords = full_y_axis(SampleY, return_coords=True)

In [None]:
# Size of the full (dense) grid that the sparse rows are inserted into.
y_size = yall.size
x_size = xall.size

### Determine scan row and put into the scan stack
The left figure should look fairly random; the right figure should show only one value if every sparse scan contains the same number of rows.

In [None]:
# Row indices (into yall) covered by each sparse scan.
scanned_rows = sparse_y_indices(yall, y_coords)
# Diagnostic figures described in the markdown above; the return value is discarded.
_ = sparse_row_map(scanned_rows, y_size)

In [None]:
# Place each windowed scan into its rows of a (y_size, x_size) grid and stack
# the scans; `normalised` controls I0 normalisation with I0_total.
# NOTE(review): valid_scan is not passed here -- confirm invalid scans are handled.
stack = sparse_window_stack(windowed, scanned_rows, y_size, x_size, I0_total=I0_total, normalised=normalised)

### Convert to a HyperSpy signal and save it as NeXus file

In [None]:
# Wrap the stack in a HyperSpy 2D signal so it can be plotted and saved as NeXus.
sig = hs.signals.Signal2D(stack, signal_axes=(0,1))

In [None]:
def _axis_calibration(positions):
    """Return ``(offset, scale)`` for a HyperSpy axis from 1-D motor positions.

    ``offset`` is the smallest position and ``scale`` the smallest absolute
    spacing between consecutive positions. The absolute value is taken on
    every step *before* the minimum: the previous ``abs(np.diff(x).min())``
    returned the largest step when positions were stored in descending order.

    Raises
    ------
    ValueError
        If fewer than two positions are given (no spacing can be derived).
    """
    positions = np.asarray(positions)
    if positions.size < 2:
        raise ValueError(
            f"At least two positions are needed to derive an axis scale, got {positions.size}."
        )
    return np.min(positions), np.abs(np.diff(positions)).min()


# axes_manager[1] and [2] are calibrated as X and Y (presumably index 0 is the
# per-scan energy navigation axis -- TODO confirm against the stack layout).
# NOTE(review): offset = min(position) assumes ascending pixel order in the stack.
for axis_index, positions, axis_name in ((1, xall, "X"), (2, yall, "Y")):
    axis = sig.axes_manager[axis_index]
    axis.offset, axis.scale = _axis_calibration(positions)
    axis.name = axis_name
    axis.units = "mm"

sig.metadata.set_item("Acquisition_instrument.XRF.beam_energy", np.asarray(energy_data))

In [None]:
# Quick visual check of the sparse stack before saving.
sig.plot()

In [None]:
# Output 1/3: the incomplete (sparse) stack, saved via HyperSpy.
print(f"Saving NeXus file at {outpath_nexus}")
sig.save(outpath_nexus)

### Save as MANTIS Exchange format

In [None]:
# Inputs for the MANTIS exchange files: an all-ones white spectrum shaped like
# the energy axis, and the energies multiplied by 1000 (presumably keV -> eV --
# confirm against the units returned by read_raw_data).
energy_axis = np.asarray(energy_data)
new_energy = energy_axis * 1000
new_white_spectrum = np.ones_like(energy_axis)
comment = ""

In [None]:
# Output 2/3: the incomplete stack in MANTIS exchange format. The stack is
# transposed (stack.T) first -- presumably to the axis order save_mantis expects.
print(f"Saving MANTIS file (incomplete) at {outpath_mantis}")

save_mantis(outpath_mantis, stack.T,
            ax_energy=new_energy,
            ax_white=new_white_spectrum,
            comment=comment)

### Matrix completion
Using looped alternating steepest descent (original code by Oliver Townsend/Paul Quinn).

In [None]:
# Configure looped ASD matrix completion from the notebook parameters:
# completion_rank -> maximum rank, tol_residual -> stopping residual,
# num_short_iteration / num_final_iteration -> iterations per intermediate / final rank.
loop_asd = LoopedASD(stack.T, 
                     rank_max=completion_rank, 
                     tol=tol_residual, 
                     niter_short=num_short_iteration, 
                     niter_final=num_final_iteration, 
                     verbose=True)

In [None]:
# Run the completion; ts/te bracket the run so its duration can be reported below.
ts = time.perf_counter()

loop_asd.start_looping()

te = time.perf_counter()

In [None]:
# Summary of the completion run: wall time, undersampling ratio, final residual.
print(
    f"Completion run time: {te-ts} s",
    f"True undersample ratio: {loop_asd.undersampling_ratio}",
    f"Completion Residual: {loop_asd.residuals[-1]}",
    sep="\n",
)

In [None]:
# Side-by-side check: the flattened sparse input vs the completed low-rank matrix.
imagesc(loop_asd.flatten, im_title='Sparse Data Color Map')
imagesc(loop_asd.low_rank_matrix, im_title='Completed Data Color Map')

In [None]:
# Convergence history of the completion: log10 of the residual per iteration.
fig, ax = plt.subplots()
ax.plot(np.log10(loop_asd.residuals))
ax.set(
    xlabel="Number of iterations",
    ylabel=r"$\mathrm{log}_{10}$ R",
    title="Log residual",
);

In [None]:
# Completed stack, reshaped back from the low-rank matrix by LoopedASD.
data_complete = loop_asd.stack_complete

In [None]:
# Output 3/3: the completed stack in MANTIS exchange format, sharing the energy
# and white-spectrum axes of the incomplete file.
print(f"Saving MANTIS file (complete) at {outpath_complete}")

save_mantis(outpath_complete, data_complete,
            ax_energy=new_energy,
            ax_white=new_white_spectrum,
            comment=comment)