In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Union
from scipy.optimize import curve_fit

import perovskite_flattening as perov_flatten


from topostats.filters import Filters
from topostats.io import LoadScans
from topostats.plottingfuncs import Colormap

# TopoStats colormap helper; kept available so the package colormap can be swapped in below.
colourmap = Colormap()
# Colormap shared by the plotting helpers in this notebook.
# Use `CMAP = colourmap.get_cmap()` instead to plot with the TopoStats colormap.
CMAP = plt.get_cmap("viridis")

In [None]:
# NOTE(review): these are absolute paths on one machine - point them at your own data before running.
DATA_DIR = Path("/Users/sylvi/topo_data/textured-silicon/Peak Force QNM/")
OUTPUT_DIR = Path("/Users/sylvi/topo_data/textured-silicon/flattening_output/")
# Extension of the scan files to process and the channel to read from each of them.
FILE_EXT = ".spm"
CHANNEL = "Height"

# parents=True so that a missing intermediate directory does not abort the notebook.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Print the contents of the data directory as a quick check of what is available.
directory_contents = DATA_DIR.iterdir()
for item in directory_contents:
    print(item)

In [None]:
def plot(image: np.ndarray):
    """Display a single image using the notebook-level colormap ``CMAP``.

    Parameters
    ----------
    image: np.ndarray
        Array handed to ``plt.imshow``.
    """
    plt.imshow(X=image, cmap=CMAP)
    plt.show()

In [None]:
def load_data(file: Path, channel=CHANNEL):
    """Load one scan file with TopoStats' ``LoadScans`` and return its image and scaling.

    Parameters
    ----------
    file: Path
        Path of the scan file to load.
    channel
        Channel name handed to ``LoadScans``; defaults to the notebook-level ``CHANNEL``.

    Returns
    -------
    tuple
        ``(image, pixel_to_nm_scaling)`` as exposed by the loader once ``get_data()`` has run.
    """
    scan_loader = LoadScans([file], channel=channel)
    scan_loader.get_data()
    return scan_loader.image, scan_loader.pixel_to_nm_scaling


def plot_gallery(
    images: list[np.ndarray], filenames: list[str], row_length: int = 4, title: str = "", save_figure: bool = False
):
    """Plot a grid ("gallery") of images, one panel per image, and optionally save the figure.

    Parameters
    ----------
    images: list[np.ndarray]
        Images to draw. Each one is shown with ``imshow`` using the notebook-level ``CMAP``.
    filenames: list[str]
        Panel titles; ``filenames[i]`` labels ``images[i]``.
    row_length: int
        Number of panels per row. As many rows as needed are created.
    title: str
        Figure-level title. It is also used as the file name inside ``OUTPUT_DIR`` when saving.
    save_figure: bool
        If True the figure is saved to ``OUTPUT_DIR / title``. Nothing is saved (and a message is
        printed) when ``title`` is empty.
    """
    if len(images) == 0:
        # A grid with zero rows makes plt.subplots raise, so bail out with a readable message instead.
        print("NO IMAGES TO PLOT")
        return

    print(f"plotting {len(images)} images")
    n_rows = int(np.ceil(len(images) / row_length))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row or column;
    # without it the [row, column] indexing below fails for len(images) <= row_length.
    fig, ax = plt.subplots(n_rows, row_length, figsize=(15, 15), squeeze=False)
    for image_index, image in enumerate(images):
        image_index_y, image_index_x = divmod(image_index, row_length)
        print(f"image index: {image_index}")
        print(f"image index x: {image_index_x}")
        print(f"image index y: {image_index_y}")
        ax[image_index_y, image_index_x].imshow(image, cmap=CMAP)
        ax[image_index_y, image_index_x].set_title(filenames[image_index], fontsize=10)

    # Panels left over in the last row carry no image; hide their empty frames and ticks.
    for unused_axis in ax.ravel()[len(images):]:
        unused_axis.axis("off")

    fig.suptitle(title)
    fig.tight_layout()
    # Save before showing so the written file does not depend on how the backend treats shown figures.
    if save_figure:
        if title != "":
            fig.savefig(OUTPUT_DIR / title)
        else:
            print("FIGURE NOT SAVED, TITLE MISSING")
    plt.show()

In [None]:
def remove_nonlinear_polynomial(image: np.ndarray, mask: Union[np.ndarray, None] = None) -> np.ndarray:
    """Fit and remove a "saddle" shaped nonlinear polynomial trend of the form a + b * x * y - c * x - d * y
    from the supplied image. AFM images sometimes contain a "saddle" shape trend to their background,
    and so to remove them we fit a nonlinear polynomial of x and y and then subtract the fit from the image.
    If these trends are not removed, then the image will not flatten properly and will leave opposite diagonal
    corners raised or lowered.

    Parameters
    ----------
    image: np.ndarray
        2D numpy heightmap array with a polynomial trend to remove. Integer and boolean arrays are
        converted to float64 before fitting; floating point arrays keep their dtype.
    mask: np.ndarray
        2D numpy boolean array used to mask out any points in the image that are deemed not to be part of the
        heightmap's background data (True = excluded from the fit). This argument is optional.

    Returns
    -------
    np.ndarray
        Copy of the supplied image with the polynomial trend subtracted. The input array is not modified.
    """
    # Script has a lot of locals but I feel this is necessary for readability?
    # pylint: disable=too-many-locals

    # Define the polynomial function to fit to the image
    def model_func(x, y, a, b, c, d):
        return a + b * x * y - c * x - d * y

    image = image.copy()
    if not np.issubdtype(image.dtype, np.inexact):
        # Integer/boolean data can neither hold the NaN fill value used for masking nor take the
        # in-place subtraction of the float fit below, so switch to floats up front.
        image = image.astype(np.float64)

    if mask is not None:
        # Masked points become NaN so that they can be dropped from the fit below.
        read_matrix = np.ma.masked_array(image, mask=mask, fill_value=np.nan).filled()
    else:
        read_matrix = image

    # Construct a meshgrid of x and y points for fitting to the z heights
    xdata, ydata = np.meshgrid(np.arange(read_matrix.shape[1]), np.arange(read_matrix.shape[0]))

    # Only use data that is not nan. Nans may be in the image from the
    # masked array. Curve fitting cannot handle nans.
    # Boolean indexing already returns 1D arrays, which is the layout curve_fit needs.
    not_nan = ~np.isnan(read_matrix)
    xy_data_stacked = np.vstack((xdata[not_nan], ydata[not_nan]))
    zdata_nans_removed = read_matrix[not_nan]

    # Fit the model to the data
    # Note: pylint is flagging the tuple unpacking regarding an internal line of scipy.optimize._minpack_py : 910.
    # This isn't actually an issue though as the extended tuple output is only provided if the 'full_output' flag is
    # provided as a kwarg in curve_fit.
    popt, _pcov = curve_fit(  # pylint: disable=unbalanced-tuple-unpacking
        lambda x, a, b, c, d: model_func(x[0], x[1], a, b, c, d), xy_data_stacked, zdata_nans_removed
    )

    # Unpack the optimised parameters
    a, b, c, d = popt
    print(f"Nonlinear polynomial removal optimal params: const: {a} xy: {b} x: {c} y: {d}")

    # Use the optimised parameters to construct a prediction of the underlying surface over the
    # full grid (masked points included), then subtract it from the image.
    z_pred = model_func(xdata, ydata, a, b, c, d)
    image -= z_pred

    return image

In [None]:
# Sorted so that the gallery layout is the same on every run; glob order is not guaranteed.
files = sorted(DATA_DIR.glob(f"*{FILE_EXT}"))
save_images = False
plot_individual_images = False
images = []
filenames = []
for file in files:
    # Dots inside the stem are replaced - presumably so they are not mistaken for a file
    # extension when the name is reused for output files.
    filename = str(file.stem).replace('.', '_')
    print(file)
    image, pixel_to_nm_scaling = load_data(file)
    if plot_individual_images:
        plot(image)

    # Flattening: subtract the fitted "saddle" polynomial background from the raw image.
    # Alternative: image = perov_flatten.plane_tilt_removal(image)
    image = remove_nonlinear_polynomial(image)

    images.append(image)
    filenames.append(filename)
    if save_images:
        plt.imsave(OUTPUT_DIR / f"{filename}.png", image, cmap=CMAP)

plot_gallery(images, filenames=filenames, title="polynomial removal images viridis", save_figure=True)