In [None]:
import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from scipy.ndimage import shift as ndimage_shift
from scipy.signal import correlate2d
from skimage.registration import phase_cross_correlation

from topostats.classes import TopoStats
from topostats.io import LoadScans
from topostats.plottingfuncs import Colormap

# Shared colour-scale limits for every height image shown below (same units as the loaded images).
VMIN, VMAX = -3, 4

# Colormap built with no arguments -- presumably TopoStats' default height colormap; verify if that matters.
colourmap = Colormap()
cmap = colourmap.get_cmap()

In [None]:
# Root of the before/after dataset. Set the TOPO_DATA_DIR environment variable to run this notebook on another
# machine; when it is unset the original author's location is used, so existing behaviour is unchanged.
dir_base = Path(os.environ.get("TOPO_DATA_DIR", "/Users/sylvi/topo_data/kinga-before-after-proteins")).expanduser()
assert dir_base.exists(), f"data root not found: {dir_base} (set TOPO_DATA_DIR to override)"
dir_data = dir_base / "data" / "before_and_afters"
assert dir_data.exists(), f"raw data directory not found: {dir_data}"
dir_processed_data = dir_base / "output" / "data" / "before_and_afters" / "processed"
assert dir_processed_data.exists(), f"processed data directory not found: {dir_processed_data}"
dir_processed_data_topostats = dir_processed_data / "topostats"
assert dir_processed_data_topostats.exists(), f"topostats output directory not found: {dir_processed_data_topostats}"
dir_results = dir_base / "results"
dir_results.mkdir(exist_ok=True)

# glob() yields files in filesystem order, which differs between machines; sort so that any later
# indexing into raw_files is reproducible.
raw_files = sorted(dir_data.glob("*.spm"))
print(f"found {len(raw_files)} raw files")

In [None]:
def align_images_phase_correlation(
    image_1: npt.NDArray,
    image_2: npt.NDArray,
    upscale_factor: int = 1,
    fill_value: float = 0.0,
) -> npt.NDArray:
    """
    Align ``image_2`` onto ``image_1`` using phase cross-correlation.

    The translation between the images is estimated with ``skimage.registration.phase_cross_correlation`` on
    mean-subtracted, Hann-windowed copies of both images. That translation is then applied to the original
    (un-windowed) ``image_2`` with linear interpolation.

    Parameters
    ----------
    image_1 : npt.NDArray
        Reference image (2D).
    image_2 : npt.NDArray
        Image to move onto the reference; must have the same shape as ``image_1``.
    upscale_factor : int
        Passed to ``phase_cross_correlation`` as ``upsample_factor``: the shift is estimated to a precision of
        ``1 / upscale_factor`` pixels. ``1`` gives whole-pixel precision.
    fill_value : float
        Value written into pixels that have no source data after shifting. Defaults to ``0.0``. Pass ``np.nan``
        (with a floating-point ``image_2``) to mark those pixels as missing instead.

    Returns
    -------
    npt.NDArray
        ``image_2`` shifted to register with ``image_1``; same shape and dtype as ``image_2``.

    Raises
    ------
    ValueError
        If either image is not 2D, or the two images differ in shape.
    """
    image_1 = np.asarray(image_1)
    image_2 = np.asarray(image_2)
    # Explicit checks rather than assert, so validation still happens when Python runs with -O.
    if image_1.ndim != 2 or image_2.ndim != 2:
        raise ValueError("images must be 2D")
    if image_1.shape != image_2.shape:
        raise ValueError("images must be the same size")

    height, width = image_1.shape
    # 2D Hann window: the outer product of two 1D windows gives every (row, col) weight combination. It tapers
    # smoothly to zero at the borders so the FFT's implicit periodic wrap-around introduces no edge artefact.
    window = np.outer(np.hanning(height), np.hanning(width))
    # Remove each image's mean before windowing. Otherwise a constant height offset is multiplied by the
    # window, and the window's shape (identical in both images) adds a correlation peak at zero shift.
    reference_image = (image_1 - image_1.mean()) * window
    moving_image = (image_2 - image_2.mean()) * window

    # shift: per-axis translation (pixels) that registers moving_image with reference_image.
    # The RMS error and global phase difference also returned by skimage are not needed here.
    shift, _error, _phase_difference = phase_cross_correlation(
        reference_image=reference_image, moving_image=moving_image, upsample_factor=upscale_factor
    )

    # order=1: linear interpolation, needed for sub-pixel shifts.
    # mode="constant": pixels shifted in from outside the image are set to fill_value.
    return ndimage_shift(image_2, shift=shift, order=1, mode="constant", cval=fill_value)

In [None]:
# Select which before/after pair to analyse: PAIR<n>_JustDNA (pre-treatment) vs PAIR<n>_PARP (post-treatment).
pair_index = 3
filename_pre_treatment = dir_processed_data_topostats / f"PAIR{pair_index}_JustDNA.topostats"
filename_post_treatment = dir_processed_data_topostats / f"PAIR{pair_index}_PARP.topostats"
assert filename_pre_treatment.exists(), f"missing pre-treatment file: {filename_pre_treatment}"
assert filename_post_treatment.exists(), f"missing post-treatment file: {filename_post_treatment}"

# Load both processed .topostats files, requesting the "Height Sensor" channel.
files = [filename_pre_treatment, filename_post_treatment]
loadscans = LoadScans(files, config={"loading": {"channel": "Height Sensor"}})
loadscans.get_data()
img_dict = loadscans.img_dict
file_pre_treatment: TopoStats = img_dict[filename_pre_treatment.name]
file_post_treatment: TopoStats = img_dict[filename_post_treatment.name]

pre_treatment_image = file_pre_treatment.image
post_treatment_image = file_post_treatment.image
# Pixel-wise subtraction below only makes sense for images of identical shape.
assert pre_treatment_image.shape == post_treatment_image.shape, (
    f"image shapes differ: {pre_treatment_image.shape} vs {post_treatment_image.shape}"
)

# Difference images are signed, so they are drawn on a range symmetric about zero. The height range
# (VMIN, VMAX) is asymmetric, which would put "no change" off the midpoint of the greyscale.
DIFF_LIMIT = max(abs(VMIN), abs(VMAX))


def _show_side_by_side(
    left_image: npt.NDArray, right_image: npt.NDArray, left_title: str, right_title: str, suptitle: str
):
    """Show two height images side by side on the shared height colour scale; return ``(fig, axes)``."""
    figure, axes = plt.subplots(1, 2, figsize=(10, 5))
    for axis, image, title in zip(axes, (left_image, right_image), (left_title, right_title)):
        axis.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        axis.set_title(title)
    figure.suptitle(suptitle)
    plt.show()
    return figure, axes


def _show_difference(difference: npt.NDArray, title: str) -> None:
    """Show a signed height-difference image on a zero-centred greyscale with a colour bar."""
    figure, axis = plt.subplots()
    mappable = axis.imshow(difference, cmap="gray", vmin=-DIFF_LIMIT, vmax=DIFF_LIMIT)
    axis.set_title(title)
    figure.colorbar(mappable, ax=axis, label="height difference (nm)")
    plt.show()


fig, ax = _show_side_by_side(
    pre_treatment_image, post_treatment_image, "pre-treatment", "post-treatment", f"PAIR{pair_index}"
)

# Naive difference (post - pre) with no registration, for comparison with the aligned version below.
difference_image = post_treatment_image - pre_treatment_image
_show_difference(difference_image, f"pair {pair_index} difference (post - pre)")

# Register the post-treatment image onto the pre-treatment one to 1/10-pixel precision.
aligned_post_treatment_image = align_images_phase_correlation(
    pre_treatment_image, post_treatment_image, upscale_factor=10
)
fig, ax = _show_side_by_side(
    pre_treatment_image,
    aligned_post_treatment_image,
    "pre-treatment",
    "aligned post-treatment",
    f"PAIR{pair_index}",
)

# Difference after alignment. NOTE: the alignment fills pixels shifted in from outside the image with a
# constant, so a thin border of this difference is an artefact rather than a real height change.
aligned_difference_image = aligned_post_treatment_image - pre_treatment_image
_show_difference(aligned_difference_image, f"pair {pair_index} aligned difference (post - pre)")