In [None]:
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
from IPython.display import clear_output
from loguru import logger

from topostats.filters import Filters
from topostats.io import LoadScans
from topostats.plottingfuncs import Colormap

colormap = Colormap()
CMAP = colormap.get_cmap()
# Colour-scale limits for height images; None lets matplotlib autoscale each image.
VMIN = None
VMAX = None

# Machine-specific default; set the TOPO_DATA_DIR environment variable to point elsewhere
# instead of editing this path.
DATA_DIR = Path(os.environ.get("TOPO_DATA_DIR", "/Users/sylvi/topo_data/perovskite/textured_silicon/"))
# Only regular, non-hidden files are candidate scans: pathlib's "*" also matches sub-directories
# and dot-files such as .DS_Store. Sorted so that load and plot order are reproducible.
files = sorted(
    path for path in DATA_DIR.glob("*") if path.is_file() and not path.name.startswith(".")
)
logger.info(f"Found {len(files)} files in {DATA_DIR}")
for file in files:
    logger.info(file)
if not files:
    # Fail here with a clear message rather than with a confusing error in a later cell.
    raise FileNotFoundError(f"No scan files found in {DATA_DIR}; set TOPO_DATA_DIR to the data directory")

loadscans = LoadScans(files, channel="Height Sensor")
loadscans.get_data()
image_dictionaries = loadscans.img_dict
logger.info(f"Loaded {len(image_dictionaries)} images")

In [None]:
# Flatten every loaded scan with TopoStats' Filters and keep the Gaussian-filtered result.
# Row-alignment quantile and threshold method are passed as None and scar removal is switched off.
# Results are keyed by bare file name (no directory).
flattened_images = dict()
for filename, topostats_object in image_dictionaries.items():
    filename = Path(filename).name
    logger.info(f"Flattening {filename}")
    filters = Filters(
        image=topostats_object["image_original"],
        filename=filename,
        pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"],
        row_alignment_quantile=None,
        threshold_method=None,
        gaussian_size=0.01,
        remove_scars=dict(run=False),
    )
    filters.filter_image()
    flattened_images[filename] = dict(
        original_image=topostats_object["image_original"],
        flattened_image=filters.images["gaussian_filtered"],
        pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"],
    )
# Wipe the per-file log output so only the summary line remains visible.
clear_output()
logger.info(f"Flattened {len(flattened_images)} images")

In [None]:
def plot_images(images: list, masks: list, px_to_nms: list, filenames: list, width=3, VMIN=VMIN, VMAX=VMAX, cmap=CMAP):
    """Plot each image next to its mask, ``width`` image/mask pairs per row.

    Parameters
    ----------
    images : list
        2D arrays to display with ``cmap`` and the ``VMIN``/``VMAX`` colour limits.
    masks : list
        One array per image, cast to int and shown in the panel right of its image.
    px_to_nms : list
        Pixel-to-nanometre scaling (nm per pixel) of each image, shown in its title.
    filenames : list
        Name shown in each image's title.
    width : int
        Number of image/mask pairs per row.
    VMIN, VMAX : float | None
        Colour-scale limits for the images (None autoscales).
    cmap
        Colormap used for the images.
    """
    num_images = len(images)
    panels_per_image = 2  # image + mask
    # Ceiling division gives exactly enough rows; keep at least one so subplots() is valid.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even when there is only one row,
    # so the [row, col] indexing below always works.
    fig, axes = plt.subplots(
        num_rows,
        width * panels_per_image,
        figsize=(width * 8, num_rows * 4),
        squeeze=False,
    )
    for i, (image, mask, p_to_nm, filename) in enumerate(zip(images, masks, px_to_nms, filenames)):
        row, col = divmod(i, width)
        # Plot image
        im_ax = axes[row, col * panels_per_image]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"{p_to_nm:.3f} nm/px\n{filename}")
        # Plot mask
        mask_ax = axes[row, col * panels_per_image + 1]
        mask_ax.imshow(mask.astype(int))

    # Hide axis decoration on every panel, including unused trailing ones in the last row.
    for ax in axes.flat:
        ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
# Show every flattened image; masks are all-zero placeholders of matching shape.
logger.info(f"{flattened_images.keys()}")
plot_images(
    images=[entry["flattened_image"] for entry in flattened_images.values()],
    masks=[np.zeros_like(entry["flattened_image"]) for entry in flattened_images.values()],
    px_to_nms=[entry["pixel_to_nm_scaling"] for entry in flattened_images.values()],
    filenames=list(flattened_images),
)

In [None]:
# Frequency split: start from the first flattened image (its dict of arrays and scaling).
image: dict = list(flattened_images.values())[0]


def remove_low_frequencies(
    image: npt.NDArray[np.float32],
    pixel_to_nm_scaling: float,
    mountain_spacing_nm: float,
    debug: bool = False,
) -> npt.NDArray[np.float32]:
    """Subtract the long-wavelength (low-frequency) background from a height image.

    The steps are:

    1. Shift the image so its minimum is zero.
    2. Reflect-pad it by half its size on every side.
    3. Take a 2D DFT and low-pass filter it with a centred rectangular mask. The mask keeps
       every spatial frequency whose wavelength is at least ``mountain_spacing_nm``.
    4. Crop the low-pass result back to the original size.
    5. Subtract it from the zero-shifted image, leaving only the short-wavelength detail.

    Parameters
    ----------
    image : npt.NDArray[np.float32]
        2D height image. It is not modified.
    pixel_to_nm_scaling : float
        Size of one pixel in nanometres.
    mountain_spacing_nm : float
        Cut-off wavelength in nanometres; longer-wavelength content is removed.
    debug : bool
        If True, show intermediate images, spectra and height distributions.

    Returns
    -------
    npt.NDArray
        High-frequency component, same shape as ``image``.

    Raises
    ------
    ValueError
        If ``image`` is not 2D or either length scale is not positive.
    """
    if image.ndim != 2:
        raise ValueError(f"Expected a 2D image, got shape {image.shape}")
    if pixel_to_nm_scaling <= 0 or mountain_spacing_nm <= 0:
        raise ValueError("pixel_to_nm_scaling and mountain_spacing_nm must be positive")

    # Set the zero point to the minimum height. This makes a new array, so the caller's image
    # (e.g. an entry of `flattened_images`) is not modified in place.
    image = image - np.min(image)

    if debug:
        plt.imshow(image, cmap=CMAP, vmin=VMIN, vmax=VMAX)
        plt.title("Original image")
        plt.show()
        # plot kde of image heights
        sns.kdeplot(image.flatten())
        plt.show()

    # Reflect-pad by half the image size on each side. Without padding, the DFT would see a
    # hard discontinuity where opposite image edges wrap around.
    pad_rows = image.shape[0] // 2
    pad_cols = image.shape[1] // 2
    extended_image = cv2.copyMakeBorder(
        image,
        pad_rows,
        pad_rows,
        pad_cols,
        pad_cols,
        cv2.BORDER_REFLECT,
    )

    # 2D DFT with a two-channel (real, imaginary) output. Only the spatial axes are shifted,
    # to put the zero-frequency component at the centre; the channel axis must not be shifted.
    dft = cv2.dft(np.float32(extended_image), flags=cv2.DFT_COMPLEX_OUTPUT)
    dft_shift = np.fft.fftshift(dft, axes=(0, 1))
    if debug:
        # log1p avoids log(0) = -inf in empty bins
        plt.imshow(20 * np.log1p(cv2.magnitude(dft_shift[:, :, 0], dft_shift[:, :, 1])), cmap="viridis")
        plt.show()

    rows, cols = extended_image.shape
    crow, ccol = rows // 2, cols // 2
    # A wavelength of `spacing_px` pixels lies `n / spacing_px` frequency bins from the centre
    # along an axis of length n. Keeping bins within that distance keeps every wavelength
    # >= mountain_spacing_nm.
    spacing_px = mountain_spacing_nm / pixel_to_nm_scaling
    radius_rows = rows / spacing_px
    radius_cols = cols / spacing_px
    logger.info(f"Radius (frequency bins): rows={radius_rows:.3f}, cols={radius_cols:.3f}")

    # Low-pass mask, symmetric about the zero-frequency bin. The filtered spectrum therefore
    # stays conjugate-symmetric and its inverse transform is real.
    row_idx, col_idx = np.ogrid[:rows, :cols]
    keep = (np.abs(row_idx - crow) <= radius_rows) & (np.abs(col_idx - ccol) <= radius_cols)
    fshift = dft_shift * keep[:, :, np.newaxis].astype(np.float32)
    if debug:
        plt.imshow(20 * np.log1p(cv2.magnitude(fshift[:, :, 0], fshift[:, :, 1])), cmap="viridis")
        plt.show()

    # Undo the shift and invert the DFT. The real channel is the low-pass image; taking the
    # magnitude instead would flip the sign of any negative ringing. cv2.idft is unscaled,
    # hence the division by the number of samples.
    f_ishift = np.fft.ifftshift(fshift, axes=(0, 1))
    img_back = cv2.idft(f_ishift)[:, :, 0] / (rows * cols)

    # Crop the padding off. Explicit end indices also work for odd sizes
    # (a negative `-n // 2` end would drop one row/column).
    img_back = img_back[
        pad_rows : pad_rows + image.shape[0],
        pad_cols : pad_cols + image.shape[1],
    ]

    if debug:
        plt.imshow(img_back, cmap="viridis")
        plt.show()

    logger.info(
        f"Image stats:\nmin: {np.min(img_back)}\n max: {np.max(img_back)}\n mean: {np.mean(img_back)}\n median: {np.median(img_back)}"
    )
    if debug:
        # plot kde of low-frequency heights, then the low-frequency image itself
        sns.kdeplot(img_back.flatten())
        plt.show()
        plt.imshow(img_back, cmap=CMAP, vmin=VMIN, vmax=VMAX)
        plt.show()

    image_high_freq = image - img_back
    if debug:
        plt.imshow(image_high_freq, cmap=CMAP, vmin=VMIN, vmax=VMAX)
        plt.show()

    return image_high_freq