# Holography Tools prototype

Demonstrates the prototype holography tools GUI panels.

This currently relies on the branch https://github.com/sk1p/LiberTEM-holo/tree/core-reconstr-changes; uncomment the line below to install it:

In [None]:
# !pip -q install -q git+https://github.com/sk1p/LiberTEM-holo.git@core-reconstr-changes

In [None]:
import numpy as np
import panel as pn
import libertem.api as lt

In [None]:
from libertem_ui.ui_context import UIContext
from libertem_ui.live_plot import ApertureFigure

## Data

Build a synthetic hologram stack from `skimage` sample images (amplitude from `cat`, phase from `gravel`), with drift, noise and a per-frame phase offset:
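For intuition, an off-axis hologram records the interference of the object wave with a tilted plane reference wave. The result is a set of fringes whose local displacement encodes the object phase. The sketch below is a simplified, noise-free model written for illustration. The helper `toy_hologram` is hypothetical, and its parameters (`sampling` as the fringe period in pixels, `angle_deg` as the fringe direction) are assumptions; it is not how `hologram_frame` is implemented:

```python
import numpy as np


def toy_hologram(amp, phase, sampling, angle_deg, visibility=1.0):
    """Hypothetical helper: ideal off-axis hologram intensity, no noise.

    sampling: fringe period in pixels (assumption); angle_deg: fringe direction.
    """
    h, w = amp.shape
    y, x = np.mgrid[0:h, 0:w]
    theta = np.deg2rad(angle_deg)
    # carrier spatial frequency (cycles per pixel) set by the reference-wave tilt
    ky = np.sin(theta) / sampling
    kx = np.cos(theta) / sampling
    carrier = 2 * np.pi * (ky * y + kx * x)
    # |reference + object|^2 with unit reference amplitude and reduced fringe contrast
    return 1 + amp ** 2 + 2 * visibility * amp * np.cos(carrier + phase)
```

The `1 + amp ** 2` term forms the centre band of the FFT, while the cosine term produces the two sidebands at plus and minus the carrier frequency; the object phase rides on the sidebands.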

In [None]:
from libertem_holo.base.generate import hologram_frame
from skimage.data import cat, gravel, astronaut
from skimage.transform import resize, rotate

In [None]:
shape = (512, 512)
angle = 30.  # np.random.uniform(0, 360)
sampling = 5.5  # np.random.uniform(4, 5)
frames = []
drift = np.asarray((-3, 5))
border = 48
border_s = np.s_[border:-border, border:-border]
gen_phase_offsets = np.random.uniform(-1, 1, size=(4,))
gen_phase_offsets[0] = 0.
for c in range(4):
    _drift = np.round(drift * np.asarray((c, np.sqrt(c)))).astype(int)
    amp = resize(cat()[..., 0] / 255, shape)
    phase = (resize(gravel()[:128, :128] / 255, amp.shape, order=2) - 0.5) * 3 * np.pi
    phase += gen_phase_offsets[c]
    amp = np.roll(amp, _drift, axis=(0, 1))
    phase = np.roll(phase, _drift, axis=(0, 1))
    amp = amp[border_s]
    phase = phase[border_s]
    frames.append(hologram_frame(amp, phase, visibility=0.2, sampling=sampling, f_angle=angle, gaussian_noise=0.5, poisson_noise=0.03))
frames = np.asarray(frames).astype(np.float32)

Display the generated wave amplitude and phase (from the last iteration of the loop, i.e. the final frame):

In [None]:
left = ApertureFigure.new(amp, title="amp")
right = ApertureFigure.new(phase, title="phase")
pn.Row(left.layout, right.layout)

## Aperture Builder

This panel is used to define an aperture that extracts the sideband and reconstructs the wavefront. It shows the stack images and their FFTs, allows the aperture to be placed interactively, and displays the reconstruction for the current image and parameters.
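Conceptually, the reconstruction displayed by the panel follows the usual off-axis steps: FFT the hologram, isolate one sideband with an aperture, move the sideband centre to zero frequency, crop a window and inverse FFT. Below is a minimal numpy sketch with a hard circular aperture. `extract_sideband` and its parameters are hypothetical and only loosely mirror `sb_pos`, `radius` and `window_size`; here `sb_pos` is assumed to be given in the unshifted FFT layout (DC at `[0, 0]`):

```python
import numpy as np


def extract_sideband(frame, sb_pos, radius, window_size):
    """Hypothetical sketch: complex wave reconstructed from one sideband.

    sb_pos: (y, x) of the sideband centre in the unshifted FFT (DC at [0, 0]).
    radius: radius of a hard circular aperture, in FFT pixels.
    window_size: (h, w) of the crop taken around the sideband.
    """
    ny, nx = frame.shape
    fft = np.fft.fft2(frame)
    # move the sideband centre to [0, 0]: this removes the carrier fringes
    fft = np.roll(fft, (-sb_pos[0], -sb_pos[1]), axis=(0, 1))
    # hard circular aperture around [0, 0], using signed frequency indices
    fy = np.fft.fftfreq(ny) * ny
    fx = np.fft.fftfreq(nx) * nx
    fft = fft * ((fy[:, None] ** 2 + fx[None, :] ** 2) <= radius ** 2)
    # crop a window around zero frequency: centre it, slice, then un-centre
    h, w = window_size
    centred = np.fft.fftshift(fft)
    y0, x0 = ny // 2 - h // 2, nx // 2 - w // 2
    crop = np.fft.ifftshift(centred[y0:y0 + h, x0:x0 + w])
    # inverse FFT of the crop gives a downsampled complex wave
    return np.fft.ifft2(crop)
```

Cropping to a window rather than keeping the full FFT reduces the output size to match the information the aperture actually passes.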

In [None]:
from libertem_ui.applications.holography.panels import ApertureBuilder, ApertureConfig

The initial state of the parameters can be extracted and loaded into the window via a `NamedTuple`:

In [None]:
state = ApertureConfig(sb_pos=(142, 170), radius=44.1, window_size=(94, 94), ap_but_order=7)

Passing `state` is optional; if it is not provided, a default state is used.
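Since the state is a `NamedTuple`, it can be modified with `_replace` and converted to a plain dict with `_asdict`, for example to save parameters between sessions. The sketch below uses a hypothetical stand-in class with the same four fields as above, so it runs without `libertem_ui`. The real `ApertureConfig` may carry additional fields (the reconstruction step reads `state.ap_type`, for example):

```python
import json
from typing import NamedTuple, Tuple


class ApertureConfigSketch(NamedTuple):
    # hypothetical stand-in mirroring the fields passed to ApertureConfig above
    sb_pos: Tuple[int, int]
    radius: float
    window_size: Tuple[int, int]
    ap_but_order: int


example = ApertureConfigSketch(sb_pos=(142, 170), radius=44.1, window_size=(94, 94), ap_but_order=7)
# NamedTuples are immutable: _replace returns a modified copy
larger = example._replace(radius=50.0)
# _asdict -> JSON; tuples are stored as JSON lists, so convert them back on load
saved = json.dumps(larger._asdict())
restored = ApertureConfigSketch(
    **{k: tuple(v) if isinstance(v, list) else v for k, v in json.loads(saved).items()}
)
```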

Workflow / Interactions:

The left column is used to define the sideband aperture and view it atop each image in the stack. The right column shows the reconstruction steps (aperture, crop from FFT and reconstructed amplitude / phase).

General:
- The `σ` button on each plot toolbar applies intensity scaling based on the image standard deviation, which improves image contrast. Images sometimes load with an invalid colormap; clicking this icon corrects the issue.
- The 'sliders' icon opens a plot toolbox to manually adjust the colormap / scale.
- The general tools on the plot toolbar allow zooming / panning / reset / etc.
- Other tools are sometimes available here, for drawing / editing / selecting annotations and data. Click these to activate them.

*Left column:*
- `Image` / `FFT` selector: display the current frame as image or FFT
- `Display channel` slider: change the displayed image
- `FFT` mode selected: Use the `point select tool` to drag the sideband (circle) annotation
- `Aperture radius` and `Window size` sliders: When in FFT display, adjust to change the aperture size parameters
- `Estimate SB` / `Lower / Upper` etc: Automatically estimate the sideband position and update the displays

*Right column:*
- `Aperture` / `Crop` / `Reconstruction` selector: display the relevant view using the current parameters
- `Amplitude` / `Phase` selector: when displaying the reconstruction show either the amplitude or phase image
- `Unwrap` button: when displaying the phase image, replace it with the unwrapped phase image (temporary)
- `Aperture type` and `Order` selector: choose the aperture function and associated parameters
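The `ap_but_order` field suggests that one aperture option is Butterworth-shaped; this is an assumption based on the name. The standard Butterworth low-pass profile, `1 / (1 + (r / R) ** (2 * n))`, is a soft-edged disk that approaches a hard disk as the order `n` grows, which reduces ringing compared with a sharp cutoff. A sketch; `butterworth_aperture` is a hypothetical helper, not the panel's implementation:

```python
import numpy as np


def butterworth_aperture(shape, radius, order):
    """Butterworth-style soft circular aperture centred on [0, 0] (unshifted FFT layout).

    Transmission is 1 at r = 0, exactly 0.5 at r == radius, and falls off
    faster beyond the radius as the order increases.
    """
    fy = np.fft.fftfreq(shape[0]) * shape[0]
    fx = np.fft.fftfreq(shape[1]) * shape[1]
    r = np.hypot(fy[:, None], fx[None, :])
    return 1.0 / (1.0 + (r / radius) ** (2 * order))
```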

In [None]:
window = ApertureBuilder.using(frames, state=state)
window.layout()

Extract the current state of the parameters from the window:

In [None]:
window.state

## Reconstruct

The window state (`ApertureConfig`) can then be passed to LiberTEM functions to reconstruct the whole stack, optionally with a reference stack.

*TODO*
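A common way to use a reference (vacuum) hologram is to reconstruct it with the same aperture parameters and divide the object wave by the reference wave, which removes phase and amplitude distortions shared by both. Until the LiberTEM function is available, here is a hedged numpy sketch with a hypothetical helper `normalise_by_reference`; it assumes both waves were reconstructed identically:

```python
import numpy as np


def normalise_by_reference(obj_waves, ref_waves, eps=1e-12):
    """Divide object waves by reference waves (per frame, or by one averaged reference).

    Resulting phase = phase(obj) - phase(ref); amplitude = |obj| / |ref|.
    """
    obj_waves = np.asarray(obj_waves)
    ref_waves = np.asarray(ref_waves)
    # guard against division by ~0 in dark regions of the reference
    safe_ref = np.where(np.abs(ref_waves) < eps, eps, ref_waves)
    # broadcasting allows a single (h, w) reference for an (n, h, w) stack
    return obj_waves / safe_ref
```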

In [None]:
state = window.state

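# NOTE: uses the panel's private `_data` helper until an equivalent LiberTEM function exists
aperture = window._data.get_aperture(
    state.ap_type,
    state.radius,
    state.window_size,
    order=state.ap_but_order,
)
# one aperture-filtered FFT crop around the sideband per frame in the stack
fft_crops = np.stack(
    [window._data.get_crop(idx, *state.sb_pos, aperture) for idx in range(window._data.stack_len)],
    axis=0,
)
# ifft2 acts on the last two axes, giving the complex wave of every frame
reconstruction = np.fft.ifft2(fft_crops) * np.prod(window._data.sig_shape)
print(reconstruction.shape, reconstruction.dtype)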

## Align and average stack

Once reconstructed, the stack is aligned with the `StackAlignWindow` panel before the wavefront images are averaged. The panel provides automatic alignment on the whole image or on subregions, as well as manual alignment via mouse interaction (clicking the arrows, dragging the moving image or dragging the scatter plot markers). Images can also be marked as "to-skip" if their reconstruction is incorrect.
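Automatic whole-image alignment is typically based on cross-correlation. As an illustration only (not necessarily the method the panel uses), this numpy sketch estimates the integer-pixel shift between two images by phase correlation. `estimate_shift` is a hypothetical helper; unlike the panel, whose shifts are sub-pixel, it does no sub-pixel refinement around the peak:

```python
import numpy as np


def estimate_shift(static, moving):
    """Integer (dy, dx) such that np.roll(moving, (dy, dx), axis=(0, 1)) best matches static."""
    cross = np.fft.fft2(static) * np.conj(np.fft.fft2(moving))
    # keep only the phase of the cross-power spectrum (phase correlation)
    cross /= np.maximum(np.abs(cross), 1e-12)
    corr = np.fft.ifft2(cross).real
    peak = np.array(np.unravel_index(np.argmax(corr), corr.shape))
    shape = np.array(corr.shape)
    # indices past the midpoint correspond to negative shifts
    wrap = peak > shape // 2
    peak[wrap] -= shape[wrap]
    return int(peak[0]), int(peak[1])
```

Normalising the cross-power spectrum makes the correlation peak sharp regardless of the image's intensity spectrum, which is why phase correlation is often preferred over plain cross-correlation for this task.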

In [None]:
from libertem_ui.applications.holography.panels import StackAlignWindow, AlignState

Again we can pre-load shift values into the panel:

In [None]:
state = AlignState(
    static_idx=0,
    shifts_y=np.asarray([0.0, 0.65, 1.15, 2.15]),
    shifts_x=np.asarray([0.0, -1.1, -1.7, -1.95]),
    skip_frame=np.asarray([False, False, False, False]),
)

Workflow / Interactions:

The interface is an image viewer on the left, showing the static reference image overlaid with the moving (active) image, and a scatter plot of the image shifts on the right. Adjusting the transparency in the image viewer makes it possible to judge whether the shift values are correct.

*Left column:*
- `Prev` / `Next` / `Align Image` slider: choose the image to be 'active' in the alignment, relative to the fixed reference image in the stack
- `Clear ROI` button: if any ROI annotations are on the image, delete these
- `Lasso` tool: Select this then click, drag and release to shift the currently active image
- `Static alpha`, `Align alpha`, `Toggle alpha`, `Equal alpha`: Control the transparency of the static reference and the currently selected image
- `Cycle alpha`: automatically cycle the transparency of the two images when checked

*Right column:*
- `Skip image` checkbox: Mark the current image as "skipped", i.e. ignore in downstream processing (no impact on this GUI)
- Arrow buttons and `0.1 / 1 / 10` selector: shift the currently selected image by the selected number of pixels
- Double click on a scatter point: select the image nearest the mouse (blue when selected, red if unselected, grey if skipped, black if static / reference)
- With 'points select' tool active in the plot toolbar: drag a scatter plot point to move the anchor of the active image
- `Auto-align` buttons: automatically align either all images or only the currently active image, using the current parameters
- `Reset` buttons: reset the shift of either all images or only the currently active image to 0
- `Align option` selector:
  - `Whole image`: align using the whole image
  - `Subregion`: align using a single rectangular subregion of the image. Set this subregion by using the 'Rectangle draw' tool on the left figure. Long press to begin, long press to end drawing a rectangle. Only the first region is used (`Clear ROI` to restart).
  - `Arb. mask`: align using masked cross-correlation on an arbitrary mask. Combine rectangular regions as before with polygon regions drawn with the Polygon draw / edit tools on the left figure.
- `Relative ROI` checkbox: take the current shift values into account when computing the masks / regions to align on.

In [None]:
align_window = StackAlignWindow.using(
    np.abs(reconstruction), state=state, static_idx=0
)
align_window.layout()

The state of the alignment panel can be extracted using the `.state` attribute; its value can be passed to LiberTEM to perform the full alignment and averaging.

In [None]:
align_window.state

To correctly average the wavefronts we must match the phase of the sideband carrier across the stack. This can be done in the FFT domain: the phase of the `[0, 0]` pixel of each crop (the zero frequency of the reconstructed wave) is compared with that of the first frame, and the difference is removed. *TODO check this with someone who knows holography!*
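A related estimate, assuming the frames are already aligned: the global phase factor that best maps one complex wave onto another, in the least-squares sense, is the angle of `sum(w * conj(w_ref))`. The sketch below uses a hypothetical helper `relative_phase_offsets`. It correlates the waves pixel by pixel, whereas the `[0, 0]` approach compares the phases of the waves' means:

```python
import numpy as np


def relative_phase_offsets(waves, ref_idx=0):
    """Global phase of each complex wave relative to waves[ref_idx], in (-pi, pi].

    angle(sum(w * conj(w_ref))) is the phase phi minimising
    sum(|w - exp(1j * phi) * w_ref| ** 2), i.e. the least-squares global phase factor.
    """
    waves = np.asarray(waves)
    inner = np.sum(waves * np.conj(waves[ref_idx]), axis=(-2, -1))
    return np.angle(inner)
```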

Calculate the phase offsets; these should approximately match the offsets applied to each frame during data generation:

In [None]:
# phase difference to frame 0, computed as angle(z * conj(z0)) so it stays wrapped to (-pi, pi]
phase_offset_carrier = np.angle(fft_crops[:, 0, 0] * np.conj(fft_crops[0, 0, 0]))
print(f"Computed phase difference: {phase_offset_carrier}")
print(f"Actual phase difference: {gen_phase_offsets}")

We can now apply this phase correction and the image shifts in the Fourier domain before inverting again. Because `ndimage.fourier_shift` shifts the image periodically (content leaving one edge re-enters at the opposite edge), the edges of the shifted stack are invalid. This can be solved by simple cropping, or (more complex) by masking out the invalid pixels and computing the mean at each pixel from only the valid layers.
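The masking alternative can be sketched as follows. `masked_mean_after_shift` is a hypothetical helper; it assumes the shift convention of `np.roll` / `fourier_shift` (a positive shift moves content towards higher indices) and simply rounds sub-pixel shifts up to whole pixels:

```python
import numpy as np


def masked_mean_after_shift(stack, shifts):
    """Per-pixel mean of periodically shifted images, ignoring wrapped-around pixels.

    stack: (n, h, w) images that were shifted periodically (e.g. np.roll, fourier_shift).
    shifts: (n, 2) array of the (dy, dx) shifts that were applied.
    Pixels with no valid layer are returned as NaN.
    """
    _, h, w = stack.shape
    valid = np.ones(stack.shape, dtype=bool)
    for i, (dy, dx) in enumerate(shifts):
        # rows/cols touched by wrapped content, rounded up for sub-pixel shifts
        ny, nx = int(np.ceil(abs(dy))), int(np.ceil(abs(dx)))
        # a positive shift wraps content in at the start of an axis, a negative one at the end
        if dy > 0:
            valid[i, :ny, :] = False
        elif dy < 0:
            valid[i, max(h - ny, 0):, :] = False
        if dx > 0:
            valid[i, :, :nx] = False
        elif dx < 0:
            valid[i, :, max(w - nx, 0):] = False
    count = valid.sum(axis=0)
    total = np.where(valid, stack, 0).sum(axis=0)
    with np.errstate(invalid="ignore", divide="ignore"):
        return total / count
```

Compared with cropping to the region valid in every layer, this keeps the full field of view, at the cost of a spatially varying number of averaged frames (and therefore noise level).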

In [None]:
# TODO This needs to be validated for all cases !!!
yx_shifts = np.stack((
    align_window.state.shifts_y,
    align_window.state.shifts_x,
), axis=1)

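def ceil_away(val):
    # round away from zero, conservatively: exact integers are also pushed one pixel
    # further out (0.65 -> 1, -1.95 -> -2, 2.0 -> 3), so the crop never keeps wrapped pixels
    return (np.sign(val) + np.trunc(val)).astype(int)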

miny, minx = ceil_away(yx_shifts.min(axis=0))
maxy, maxx = ceil_away(yx_shifts.max(axis=0))
print(miny, maxy, minx, maxx)
_, h, w = reconstruction.shape
valid_crop = np.s_[
    :,
    max(0, maxy):h + min(miny, 0),
    max(0, maxx):w + min(minx, 0),
]

print(valid_crop)

Now that we have the valid crop, we can align and average (ideally this would be done inside a LiberTEM function):

In [None]:
# TODO image skipping

from scipy import ndimage
from skimage.restoration import unwrap_phase

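recon_shifted = np.zeros_like(reconstruction)
for idx, (fft_crop, phase_corr) in enumerate(zip(fft_crops, phase_offset_carrier)):
    # a constant phase offset on the wave multiplies every Fourier coefficient
    # by the same factor, so correct the whole crop (not only [0, 0])
    corrected = fft_crop * np.exp(-1j * phase_corr)
    # the crop is already in Fourier space, so fourier_shift applies the (sub-pixel) shift directly
    shift = (align_window.state.shifts_y[idx], align_window.state.shifts_x[idx])
    shifted = ndimage.fourier_shift(corrected, shift=shift)
    # same scaling as in the reconstruction cell above
    recon_shifted[idx, ...] = np.fft.ifft2(shifted) * np.prod(window._data.sig_shape)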

avg_recon = np.mean(
    recon_shifted[valid_crop],
    axis=0,
)

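Display the results for comparison: the phase of the plain mean of the unaligned, uncorrected stack next to the aligned and phase-corrected average: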

In [None]:
mean_phase_image_uc = np.angle(
    np.mean(
        reconstruction,
        axis=0
    )
)
mean_phase_image = np.angle(avg_recon)
pn.Row(
    ApertureFigure.new(
        unwrap_phase(mean_phase_image_uc),
        title="uncorrected + unwrapped",
    ).layout,
    ApertureFigure.new(
        unwrap_phase(mean_phase_image),
        title="corrected + unwrapped",
    ).layout,
)

## Phase unwrap tool

Once we have averaged the wavefronts we can unwrap the phase to obtain (ideally) continuous phase profiles. Phase unwrapping is normally a single function call, but this panel compares two algorithms and allows the raw phase image to be offset before unwrapping:

In [None]:
from libertem_ui.applications.holography.panels import PhaseUnwrapWindow

Workflow / Interactions:

*Left column:*
- `Lasso` tool: Select this, then click, drag and release to draw an arbitrary region where the phase unwrapping should begin (only used by `Quality-guided` unwrapping)

*Right column:*
- `Run` button: Run the currently selected unwrapping algorithm with the current parameters / ROI
- `Reset` button: Clear the current unwrapped image, offset and ROI
- `Method` select: Choose the unwrapping algorithm used
  - `Reliability-guided` uses the `scikit-image` function `unwrap_phase`
  - `Quality-guided` uses an implementation of flood filling following a path of minimum local variance in the phase image
- `Offset phase` slider: roll the phase image around its 2π period; the offset image is used as the starting point for unwrapping
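Assuming the `Offset phase` slider adds a constant modulo 2π, its effect can be illustrated in one dimension with `np.unwrap` (which is not one of the panel's 2-D algorithms). The offset changes where the 2π jumps fall, but for clean data the unwrapped result only differs by that constant. `offset_wrapped_phase` is a hypothetical helper:

```python
import numpy as np


def offset_wrapped_phase(phase, offset):
    """Add a constant offset to a wrapped phase image and re-wrap it into [-pi, pi)."""
    return np.mod(phase + offset + np.pi, 2 * np.pi) - np.pi


# 1-D illustration: a phase ramp spanning several periods
true_phase = np.linspace(0, 6 * np.pi, 200)
wrapped = np.angle(np.exp(1j * true_phase))
# offsetting moves the 2*pi jumps, but after unwrapping the result only
# differs from the true phase by the constant offset
unwrapped = np.unwrap(offset_wrapped_phase(wrapped, 1.0))
```

On noisy data the position of the jumps matters, because unwrapping errors tend to start where a jump coincides with noise; this is one motivation for letting the user choose the offset.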

In [None]:
phase_window = PhaseUnwrapWindow.using(
    image=np.angle(avg_recon),
)
phase_window.layout()

The quality-guided implementation allows the starting region to be set with the lasso selection tool. It is not yet clear whether this feature is beneficial!

This panel is a work in progress until a more interactive unwrapping approach is developed.

## Point alignment tool

The points alignment tool lets the user define two corresponding pointsets on a pair of images for computing a transform matrix (or other image registration approach). At this time the entry is purely manual, but could be semi-automated by seeding the pointsets with SIFT or similar (with appropriate thresholds).
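For the affine case, a transform can be fitted to the two pointsets by linear least squares. Below is a numpy sketch with a hypothetical helper `fit_affine`. The transformation types offered by the panel and its implementation may differ:

```python
import numpy as np


def fit_affine(src, dst):
    """Least-squares 2-D affine transform mapping src points onto dst points.

    src, dst: (n, 2) arrays of corresponding (x, y) points, n >= 3.
    Returns a 3x3 homogeneous matrix M with dst ~= (M @ [x, y, 1])[:2].
    """
    src = np.asarray(src, dtype=float)
    dst = np.asarray(dst, dtype=float)
    n = src.shape[0]
    # design matrix rows: [x, y, 1]
    A = np.hstack([src, np.ones((n, 1))])
    # solve A @ P ~= dst for the 3x2 parameter matrix P
    P, *_ = np.linalg.lstsq(A, dst, rcond=None)
    M = np.eye(3)
    M[:2, :] = P.T
    return M
```

With more than three point pairs the fit is overdetermined, so small placement errors in individual points are averaged out rather than propagated directly into the matrix.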

In [None]:
from libertem_ui.applications.holography.panels import PointsAlignWindow

Workflow / Interactions:

*Left column:*
- The figure shows the static reference image, overlaid with the transformed moving image (initially transparent)
- `Overlay alpha` slider / `Toggle alpha` button: adjust the transparency of the overlay image
- With the `point tool` selected: double click on the figure to create points in the image, click and drag these points to position them. Each point created will be mirrored in the right column, and vice-versa.
- `Transformation type` select: choose the transformation matrix to calculate from the two corresponding pointsets
- `Run` button: Compute the selected transform from the pointsets and update the overlay image
- `Clear points` button: Clear the current pointsets

*Right column:*
- With the `point tool` selected: double click on the figure to create points in the image, click and drag these points to position them. Each point created will be mirrored in the left column, and vice-versa. Position the points on corresponding patches of the image.
- Transformation matrix display: show the current transformation matrix, or error messages, if any

In [None]:
points_window = PointsAlignWindow.using(
    static=gravel()[256:256+128, 256:256+128],
    moving=rotate(gravel(), 20)[256:256+128, 256:256+128],
)
points_window.layout()

The pointsets can be extracted using the `.state` attribute. The last-computed transformation matrix is also available, though this does not necessarily match the current 'live' pointset.

In [None]:
points_window.state