## Usage
This notebook takes phase data and a sample mask, aligns the frames vertically (phase cross-correlation) and horizontally (centre of mass), and saves the results in `outpath`. 

### Parameters
`inpath` : str    
    The full path to the input file containing the phase data (`entry/data/phase`).

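`maskpath` : str    
    The full path to the file containing the background and sample masks (`entry/data/background_mask` and `entry/data/sample_mask`). If empty, the masks are read from `inpath`.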

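`outpath` : str    
    The full path to the output folder. If empty, the output file (`<input name>_aligned.nxs`) is written to the folder containing `inpath`. 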


**The parameters should be provided by explicitly modifying the top cell content or using tools such as [papermill](https://papermill.readthedocs.io/en/latest/index.html). If the notebook is run as is, please define the parameters accordingly.**

### Dependencies
- h5py
- matplotlib
- numpy
- ptypy
- scipy
- scikit-image (`skimage`)

In [None]:
import os
import shutil
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
from skimage import registration

In [None]:
# Load the background and sample masks from a separate mask file, if one was given.
# If that file does not exist, clear `maskpath` so that the next cell falls back to
# the masks stored in the input file. Previously only a message was printed and
# `background_mask` / `sample_mask` stayed undefined, which failed later with a NameError.
if maskpath:
    try:
        with h5py.File(maskpath, "r") as f:
            background_mask = f["entry/data/background_mask"][:]
            sample_mask = f["entry/data/sample_mask"][:]
    except FileNotFoundError:
        print(f"Mask file {maskpath} is not found. Falling back to the masks stored in {inpath}.")
        maskpath = None

In [None]:
# Read the phase stack from the input file. When no separate mask file is in use,
# the background and sample masks are taken from the same file.
with h5py.File(inpath, "r") as h5_in:
    data_group = h5_in["entry/data"]
    phase = data_group["phase"][:]
    if not maskpath:
        background_mask = data_group["background_mask"][:]
        sample_mask = data_group["sample_mask"][:]


## Vertical alignment using phase correlation

In [None]:
# Estimate the drift of every frame with masked phase cross-correlation.
# phase_cross_correlation(reference, moving) returns the shift that registers `moving`
# (frame i-1) onto `reference` (frame i), i.e. a shift *relative* to the previous frame.
# Frame 0 is the reference and keeps a zero shift; starting the loop at 1 avoids
# correlating frame 0 with the last frame through the wrap-around index phase[-1].
shifts = np.zeros((len(phase), 2))
for i in range(1, len(phase)):
    shift, _, _ = registration.phase_cross_correlation(
        phase[i],
        phase[i - 1],
        reference_mask=sample_mask[i],
        moving_mask=sample_mask[i - 1],
    )
    shifts[i] = shift
# Accumulate the relative shifts so shifts[i] is the offset of frame i with respect to
# frame 0, which is what the alignment cells subtract from each frame.
shifts = np.cumsum(shifts, axis=0)

In [None]:
# Make the vertical shifts non-negative (relative to the smallest one) and round them to
# whole pixels. np.rint rounds to the nearest integer; astype(int) truncated, so a value
# such as 2.9999999 from floating-point accumulation would have become 2 instead of 3.
shifts[:, 0] = np.rint(shifts[:, 0] - shifts[:, 0].min())

In [None]:
# Plot the estimated per-frame shifts as a visual sanity check.
fig_shifts, ax_shifts = plt.subplots()
for column, curve_label in ((0, "vertical shift"), (1, "horizontal shift")):
    ax_shifts.plot(shifts[:, column], label=curve_label)
ax_shifts.legend()
plt.show()

In [None]:
# Align vertically: move each frame (and its sample mask) by minus its vertical shift.
# `np.bool` was removed in NumPy 1.24, so the builtin `bool` is used as the mask dtype.
phase_vertically_aligned = np.empty(np.shape(phase), dtype=np.float32)
sample_mask_shifted = np.empty(np.shape(sample_mask), dtype=bool)

for i in range(len(phase)):
    vertical_shift = (-shifts[i, 0], 0)
    phase_vertically_aligned[i] = ndimage.shift(phase[i], vertical_shift, mode="constant", cval=0.0)
    # Nearest-neighbour interpolation (order=0) is the natural choice for a binary mask
    # and skips the spline pre-filtering; the shifts are whole pixels at this point.
    sample_mask_shifted[i] = ndimage.shift(sample_mask[i], vertical_shift, order=0, mode="constant", cval=0.0)

In [None]:
# Show a sinogram-style view of the vertically aligned stack: one row of every frame,
# with frames along the horizontal axis. Row 400 is clamped to the frame height so
# smaller frames do not raise an IndexError.
sinogram_row = min(400, phase_vertically_aligned.shape[1] - 1)
plt.figure()
plt.imshow(phase_vertically_aligned[:, sinogram_row].T, aspect="auto")
plt.colorbar()
plt.show()

## Horizontal alignment using centre of mass


In [None]:
# Set a minimum and maximum row on which to focus the horizontal alignment.
# A pixel counts as "sample" when it lies inside the (vertically aligned) sample mask in
# at least half of the frames, i.e. its median over frames is non-zero.
sample_mask_median = np.median(sample_mask_shifted, axis=0)
indices = np.nonzero(sample_mask_median)
# This could be set differently
if indices[0].size:
    # indices[0] holds row (vertical) coordinates, which is the axis sliced by
    # `[i, ymin:ymax]` in the next cell; indices[1] would be column coordinates.
    # The +1 makes the exclusive slice end include the last sample row.
    ymin, ymax = indices[0].min(), indices[0].max() + 1
else:
    # No pixel is inside the sample for most frames: use every row.
    ymin, ymax = 0, sample_mask_shifted.shape[1]

In [None]:
# Horizontal position of every frame: the phase-weighted centre of mass of the largest
# connected sample region within the row band [ymin, ymax). The result is stored in
# shifts[:, 1] and turned into a shift in the next cell.
# NOTE(review): a weighted centre of mass assumes non-negative weights; confirm the phase
# values here are offset/unwrapped so that they are not negative.
for i in range(len(phase)):
    d = phase_vertically_aligned[i, ymin:ymax]
    m = sample_mask_shifted[i, ymin:ymax]
    labels, n_labels = ndimage.label(m)
    if n_labels == 0:
        # No sample pixels in this frame: a centre of mass would be NaN and poison every
        # shift, so reuse the previous frame's position (or the band centre for frame 0).
        fallback = shifts[i - 1, 1] if i > 0 else (d.shape[1] - 1) / 2
        print(f"Frame {i}: no sample pixels in rows {ymin}:{ymax}; using position {fallback}.")
        shifts[i, 1] = fallback
        continue
    # Label 1 is merely the first region in scan order; pick the largest region so a small
    # spurious blob in the mask cannot take over the centre of mass.
    largest = int(np.argmax(np.bincount(labels.ravel())[1:])) + 1
    _, x = ndimage.center_of_mass(d, labels=labels, index=largest)
    shifts[i, 1] = x

In [None]:
# Turn the centre-of-mass positions into non-negative horizontal shifts relative to the
# leftmost frame, rounded to whole pixels. np.rint rounds to the nearest pixel; the
# previous astype(int) truncated, biasing every shift by up to one pixel.
shifts[:, 1] = np.rint(shifts[:, 1] - shifts[:, 1].min())

## Apply shifts to phase data and mask

In [None]:

# Apply the vertical and horizontal shifts to the original phase frames and sample masks.
# `np.bool` was removed in NumPy 1.24, so the builtin `bool` is used as the mask dtype.
phase_aligned = np.empty(np.shape(phase), dtype=np.float32)
sample_mask_aligned = np.empty(np.shape(sample_mask), dtype=bool)

for i in range(len(phase)):
    frame_shift = (-shifts[i, 0], -shifts[i, 1])
    phase_aligned[i] = ndimage.shift(phase[i], frame_shift, mode="constant", cval=0.0)
    # Nearest-neighbour interpolation (order=0) is the natural choice for a binary mask
    # and skips the spline pre-filtering; the shifts are whole pixels at this point.
    sample_mask_aligned[i] = ndimage.shift(sample_mask[i], frame_shift, order=0, mode="constant", cval=0.0)

## Copy the original file and save the results to the output path

In [None]:
# If an output folder isn't specified, save the output file to the same directory as the input data.
outpath = Path(outpath) if outpath else Path(inpath).parent
# Create the output folder if it does not exist yet, so copying the input file into it
# cannot fail on a missing directory.
outpath.mkdir(parents=True, exist_ok=True)

In [None]:
# Name the output after the input file: "<input stem>_aligned.nxs" inside `outpath`.
outfile_stem = Path(inpath).stem
outfile = outpath / f"{outfile_stem}_aligned.nxs"

In [None]:
# Copy the input file so all of its other entries are preserved, then replace the raw
# phase with the aligned phase and store the aligned sample mask.
# shutil.copy is used instead of os.system("cp ..."): it handles paths containing spaces
# or shell metacharacters, works on every platform, and raises on failure instead of
# silently continuing with a missing or stale output file.
shutil.copy(inpath, outfile)
with h5py.File(outfile, "r+") as f:
    del f["entry/data/phase"]
    f["entry/data/aligned"] = phase_aligned
    f["entry/data/mask"] = sample_mask_aligned

## Figures

In [None]:
# Show the first, middle and last aligned frames side by side.
start = 0
# Integer division gives the true middle index. np.rint(n / 2) rounds halves to even,
# so for a 3-frame stack it picked index 2, i.e. the same frame as "End slice".
mid = phase_aligned.shape[0] // 2
end = phase_aligned.shape[0] - 1

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True)
ax1.set_title("First slice")
ax1.imshow(phase_aligned[start])
ax2.set_title("Mid slice")
ax2.imshow(phase_aligned[mid])
ax3.set_title("End slice")
ax3.imshow(phase_aligned[end])

plt.suptitle('Aligned data', fontsize=12)
# NOTE(review): top > 1 extends the subplot area above the figure edge; presumably this
# compensates for the whitespace of three equal-aspect images — confirm it is intended.
plt.subplots_adjust(top=1.4)
plt.show()