In [None]:
# Parameters cell: these names are the notebook's inputs (presumably overridden by
# papermill or a similar runner — confirm). Empty paths must be replaced by the caller.
inpath = ""  # path to the stacked XRF NXtomo file; must end with .nxs
inpath_shifts = ""  # path to the text file of per-projection shifts; must end with .txt
outpath_nexus = None  # output .nxs path; None -> "<inpath stem>_aligned.nxs"
normalise = True  # divide each projection by its total mass relative to the stack mean
cropping = "auto"  # "auto" to detect the sample extent from the data, or an int crop width from the edges
cropping_margin = 20  # extra columns kept on each side of the auto-detected sample extent

## Usage 

This notebook aligns and normalises stacked XRF tomography data and saves the aligned projections to `outpath_nexus`.

### Parameters

`inpath` : str
The relative path to a stacked XRF tomography Nexus (NXtomo) file containing the projection data.

`inpath_shifts` : str
The relative path to a .txt file containing the per-projection shifts from the initial alignment.

`outpath_nexus` : str
The relative path to the output Nexus file containing the aligned and normalised projection data. If `None`, defaults to `<input stem>_aligned.nxs`.

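`normalise` : bool
Whether to normalise the projections based on the area under the vertical mass profiles, default is True.

`cropping` : str or int
`"auto"` to detect the horizontal extent of the sample from the data, or an integer number of columns to crop from each side, default is `"auto"`.

`cropping_margin` : int
Number of extra columns kept on each side of the automatically detected sample extent, default is 20.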


### Dependencies

- NumPy
- SciPy
- h5py
- matplotlib
- imageio
- IPython


In [None]:
import h5py
import imageio
import pathlib
import numpy as np
import scipy.ndimage as ndi
import matplotlib.pyplot as plt

## Create folder for assets

In [None]:
# Figures and animations produced by this notebook are written to ./_assets.
assets_folder = pathlib.Path(".") / "_assets"
assets_folder.mkdir(exist_ok=True, parents=True)

## Check parameters

In [None]:
# Validate the user-supplied parameters and derive the cropping mode.
# Truthiness checks also reject the "" defaults of the parameters cell.
assert inpath, "Need to provide Nexus NXtomo file."
assert inpath.endswith(".nxs"), f"The provided input file {inpath} needs to end with .nxs"
inpath_nexus = pathlib.Path(inpath)

assert inpath_shifts, "Need to provide TXT file with shifts."
assert inpath_shifts.endswith(".txt"), f"The provided shifts file {inpath_shifts} needs to end with .txt"

if outpath_nexus is None:
    # Default output name is relative to the current working directory.
    outpath_nexus = inpath_nexus.stem + "_aligned.nxs"

# `cropping` is either "auto" (detect the sample extent from the data), a
# non-negative integer number of columns to trim from each side, or None (no crop).
autocrop = False
cropval = 0
if cropping == "auto":
    autocrop = True
elif isinstance(cropping, (int, np.integer)) and not isinstance(cropping, bool):
    assert cropping >= 0, f"cropping must be non-negative, got {cropping}"
    cropval = int(cropping)
elif cropping is not None:
    raise ValueError(f'cropping must be "auto", an int or None, got {cropping!r}')

## Loading data from Nexus NXtomo file

In [None]:
# Read the projection stack, refusing any file that is not declared as NXtomo.
with h5py.File(inpath, "r") as nexus_in:
    has_definition = "entry" in nexus_in and "definition" in nexus_in["entry"]
    is_nxtomo = has_definition and nexus_in["entry/definition"][()] == b"NXtomo"
    assert is_nxtomo, f"{inpath} is not a Nexus file of type NXtomo -> use nxstacker to generate stacked XRF projections."
    # Non-finite values are replaced by finite numbers (NaN -> 0) before any shifting.
    tomo = np.nan_to_num(nexus_in["entry/data/data"][:])

## Loading shifts from TXT file

In [None]:
# One row of shifts per projection and one column per image axis, matching how
# shifts[i] is passed to ndi.shift for a single projection. ndmin=2 keeps a
# single-row file 2-D so shifts[i] is still a vector.
shifts = np.loadtxt(inpath_shifts, ndmin=2)
expected_shape = (len(tomo), tomo.ndim - 1)
assert shifts.shape == expected_shape, (
    f"Expected shifts of shape {expected_shape} in {inpath_shifts}, got {shifts.shape}"
)

## Apply shifts to projections

In [None]:
# Shift every projection by the negated shift from the text file (linear interpolation),
# then stack the results and replace any non-finite values.
shifted_projections = []
for proj_index in range(len(tomo)):
    correction = -shifts[proj_index]
    shifted_projections.append(ndi.shift(tomo[proj_index], correction, order=1))
aligned = np.nan_to_num(np.array(shifted_projections))

In [None]:
# Write the shifted stack as an animated GIF, scaled with the global min/max to 0-255.
minv = np.min(aligned)
maxv = np.max(aligned)
value_range = maxv - minv
if value_range == 0:
    # A constant stack would otherwise divide by zero (NaN frames); render it black instead.
    value_range = 1
output_file = assets_folder / f"{inpath_nexus.stem}_align_stack.gif"
image = [((projection - minv) / value_range * 255).astype(np.uint8) for projection in aligned]
imageio.mimsave(output_file, image)

In [None]:
from IPython.display import Image

# Report where the GIF was written and embed it in the notebook output.
print(f"File {output_file}")
display(Image(output_file, width=512))

## Normalise projections

In [None]:
# Use the mean horizontal mass profile to determine where the sample starts and ends:
# columns where |second derivative| is large relative to its peak are treated as edges.
edge_threshold = 0.25  # fraction of the peak curvature counted as an edge
hmp_mean = aligned.sum(axis=1).mean(axis=0)
hmp_mean_laplace = np.abs(np.gradient(np.gradient(hmp_mean)))
laplace_peak = hmp_mean_laplace.max()
if laplace_peak > 0:
    # A perfectly flat profile has zero curvature; skip the 0/0 normalisation.
    hmp_mean_laplace /= laplace_peak
hmp_mean_laplace_high = np.flatnonzero(hmp_mean_laplace > edge_threshold)
n_columns = len(hmp_mean)
if len(hmp_mean_laplace_high) > 2:
    # Widen the detected extent by the margin, clamped to the image bounds (the
    # margin used to be skipped entirely near a border). hmax is an exclusive
    # slice end, hence the +1 to keep the last edge column.
    hmin = max(int(hmp_mean_laplace_high[0]) - cropping_margin, 0)
    hmax = min(int(hmp_mean_laplace_high[-1]) + 1 + cropping_margin, n_columns)
else:
    hmin, hmax = 0, n_columns

In [None]:
# Choose the column window [cmin, cmax) used for normalisation and centring.
n_columns = aligned.shape[-1]
if autocrop:
    cmin, cmax = hmin, hmax
else:
    # Trim `cropval` columns from both sides. A positive end index keeps the last
    # column when cropval == 0 and makes cmax - cmin the true window width.
    cmin, cmax = cropval, n_columns - cropval
assert 0 <= cmin < cmax <= n_columns, f"Empty or invalid crop window [{cmin}, {cmax}) for {n_columns} columns"

In [None]:
# Visual check: detected sample extent (black) and crop window actually used (green).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(hmp_mean / hmp_mean.max(), label="mean horizontal mass profile")
ax.plot(hmp_mean_laplace / hmp_mean_laplace.max(), label="|second derivative|")
ax.axvline(hmin, color="k", ls=":", label="detected extent")
ax.axvline(hmax, color="k", ls=":")
ax.axvline(cmin, color="g", ls=":", label="crop window")
ax.axvline(cmax, color="g", ls=":")
ax.set(xlabel="column (pixel)", ylabel="normalised value")
ax.legend()
plt.show()

In [None]:
# Vertical mass profile of each projection (sum over the cropped columns) and the
# per-projection scale = total mass / mean total mass over the stack.
vmp = aligned[:, :, cmin:cmax].sum(axis=2)
vmp_total = vmp.sum(axis=1)
mean_total = vmp_total.mean()
relative_total = vmp_total / mean_total if mean_total != 0 else np.ones(vmp_total.shape)
# A projection with zero total mass would be divided by zero during normalisation; leave it unscaled.
vmp_scale = np.where(relative_total == 0, 1.0, relative_total).reshape(vmp.shape[0], 1, 1)

In [None]:
# Vertical mass profiles before normalisation: as curves (left) and as an image (right).
vmp_peak = vmp.max()
fig, (ax_lines, ax_image) = plt.subplots(ncols=2, figsize=(10, 5))
ax_lines.plot(vmp.T / vmp_peak)
ax_image.imshow(vmp.T / vmp_peak)
plt.show()

In [None]:
# Optionally rescale each projection so its total mass inside the crop window
# equals the stack mean, then show the resulting vertical mass profiles.
if normalise:
    # Out-of-place division: `aligned /= vmp_scale` raises for integer input stacks
    # (ndi.shift keeps the input dtype).
    # NOTE: re-running only this cell divides again; re-run from the shifting cell instead.
    aligned = aligned / vmp_scale
    vmp_after = aligned[:, :, cmin:cmax].sum(axis=2)
    fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
    axes[0].plot(vmp_after.T / vmp_after.max())
    axes[1].imshow(vmp_after.T / vmp_after.max())
    fig.suptitle("Vertical mass profiles after normalisation")
    plt.show()

## Centre of Mass Alignment

In [None]:
def plot_sinogram(projections, y_slice=None):
    """Display the sinogram of a projection stack at a single detector row.

    Parameters
    ----------
    projections : ndarray
        3-D stack indexed as ``projections[projection, row, column]``.
    y_slice : int, optional
        Row to display; defaults to the middle row.
    """
    if y_slice is None:
        y_slice = projections.shape[1] // 2
    fig, ax = plt.subplots(figsize=[15, 5])
    # Transpose so projections run along x and detector columns along y.
    ax.imshow(projections[:, y_slice, :].T, aspect='auto', cmap='gray')
    ax.set(xlabel="projection index", ylabel="column (pixel)", title=f"Sinogram at row {y_slice}")
    plt.show()

In [None]:
# Sinogram of the shifted (and possibly normalised) stack inside the crop window.
plot_sinogram(projections=aligned[:, :, cmin:cmax], y_slice=None)

In [None]:
# Horizontal centre of mass of each projection inside the crop window, measured
# relative to the window centre: this is the residual offset to remove.
cropped = aligned[:, :, cmin:cmax]
dx = np.array([ndi.center_of_mass(projection)[1] for projection in cropped])
# Use the real window width, which stays correct even if cmax is a negative index.
# NOTE(review): centre taken as width/2 as before; the pixel-index centre would be
# (width - 1)/2 — confirm which convention the downstream reconstruction expects.
dx -= cropped.shape[2] / 2
# An all-zero projection has no centre of mass (NaN); leave it unshifted.
dx = np.nan_to_num(dx)

In [None]:
# Per-projection centre-of-mass offsets that will be removed.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(dx)
plt.show()

In [None]:
# Apply the centre-of-mass correction along the horizontal axis only. Linear
# interpolation (order=1) matches the first shifting step and, unlike the default
# cubic spline, does not overshoot the range of neighbouring values (no negative
# ringing ahead of the unsigned-integer export).
aligned_refine = np.array([
    ndi.shift(projection, (0, -offset), order=1)
    for projection, offset in zip(aligned, dx)
])

In [None]:
# Sinogram after the centre-of-mass refinement, inside the crop window.
plot_sinogram(projections=aligned_refine[:, :, cmin:cmax], y_slice=None)

## Saving the aligned stack into Nexus NXtomo file

In [None]:
# Copy the input file's /entry tree and replace the projection data with the
# aligned, cropped stack.
assert pathlib.Path(outpath_nexus).resolve() != inpath_nexus.resolve(), \
    f"Output {outpath_nexus} would overwrite the input file {inpath}"
pathlib.Path(outpath_nexus).parent.mkdir(parents=True, exist_ok=True)

# uint16 cannot hold NaN, negative or > 65535 values: sanitise, round and clip
# instead of letting astype() wrap around.
# NOTE(review): uint16 discards fractional values; if the data are small floats
# (e.g. after normalisation) consider storing float32 instead.
uint16_max = np.iinfo(np.uint16).max
export = np.clip(np.rint(np.nan_to_num(aligned_refine[:, :, cmin:cmax])), 0, uint16_max).astype(np.uint16)

with h5py.File(inpath, "r") as fin:
    with h5py.File(outpath_nexus, "w") as fout:
        fin.copy(fin["entry"], fout, "entry")
        # Deleting the dataset drops its attributes; keep them and re-apply.
        data_attrs = dict(fout["entry/data/data"].attrs)
        del fout["entry/data/data"]
        fout["entry/data/data"] = export
        fout["entry/data/data"].attrs.update(data_attrs)