# This notebook contains the functions required to pre-process atomic force microscopy (AFM) images

In [None]:
"""
Author: Ruben Millan-Solsona
Date of Creation: August 2024

Description:
This module contains functions for flattening images, including methods to reduce
noise in images, particularly for atomic force microscopy (AFM) data. It offers plane
and polynomial surface fitting, optimal plane subtraction, and various filtering techniques
to improve image quality.

A plane fit is suited to samples with a simple tilt or uniform slope.
A polynomial fit is suited to backgrounds with curvature or more complex distortions.
In our case the background is a simple tilt, so the plane fit is used.

Dependencies:
- os
- numpy
- cv2
- matplotlib.pyplot
- typing
- scipy.optimize
- scipy.ndimage
- scipy.stats
- skimage.metrics
- skvideo.measure
- AFMclasses (contains clImage, ChannelType, ExtentionType)
- managefiles (custom module for file management)

"""

import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from typing import List, Tuple
from scipy.optimize import curve_fit, differential_evolution
from scipy.ndimage import label, find_objects
from scipy.stats import pearsonr
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


def resize_image(image: np.ndarray, size: tuple) -> np.ndarray:
    """
    Return ``image`` rescaled to ``size`` using OpenCV.

    ``size`` is forwarded unchanged as the ``dsize`` argument of
    ``cv2.resize`` (OpenCV interprets it as (width, height)). Area
    interpolation (``cv2.INTER_AREA``) is used.
    """
    method = cv2.INTER_AREA
    resized = cv2.resize(image, size, interpolation=method)
    return resized

def plane(coords, a, b, c):
    """
    Evaluate the plane surface z = a*x + b*y + c.

    ``coords`` is an (x, y) pair. x and y may be scalars or NumPy arrays;
    with arrays the expression is evaluated element-wise.
    """
    xs, ys = coords
    tilt = a * xs + b * ys
    return tilt + c

def poly(coords, a, b, c, d, e, f):
    """
    Evaluate the quadratic surface z = a*x^2 + b*y^2 + c*x*y + d*x + e*y + f.

    ``coords`` is an (x, y) pair. x and y may be scalars or NumPy arrays;
    with arrays the expression is evaluated element-wise.
    """
    xs, ys = coords
    second_order = a * xs**2 + b * ys**2 + c * xs * ys
    return second_order + d * xs + e * ys + f

def FitPlane(img, mask=None):
    """
    Fit the plane z = a*x + b*y + c to a 2-D image with ``curve_fit``.

    Parameters
    ----------
    img : 2-D array
        Image values; columns are the x coordinate and rows the y
        coordinate, both in pixel units.
    mask : 2-D array, optional
        Same shape as ``img``. Only pixels where ``mask == 1`` are used in
        the fit. When omitted, every pixel is used.

    Returns
    -------
    plane_fitted : ndarray
        The fitted plane evaluated on the full pixel grid, shaped like ``img``.
    correlation : float
        Pearson correlation between the fitted pixel values and the plane
        evaluated at those same pixels.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    if mask is not None:
        selected = mask == 1
        sample_x = grid_x[selected].flatten()
        sample_y = grid_y[selected].flatten()
        sample_z = img[selected].flatten()
    else:
        sample_x = grid_x.flatten()
        sample_y = grid_y.flatten()
        sample_z = img.flatten()

    # Start the optimiser from the zero plane (a = b = c = 0).
    params, _cov = curve_fit(plane, (sample_x, sample_y), sample_z, p0=np.zeros(3))
    surface = plane((grid_x, grid_y), *params).reshape(img.shape)
    correlation, _pvalue = pearsonr(sample_z, plane((sample_x, sample_y), *params))
    return surface, correlation

def FitPoly(img, mask=None):
    """
    Fit the quadratic surface
    z = a*x^2 + b*y^2 + c*x*y + d*x + e*y + f to a 2-D image with ``curve_fit``.

    Parameters
    ----------
    img : 2-D array
        Image values; columns are the x coordinate and rows the y
        coordinate, both in pixel units.
    mask : 2-D array, optional
        Same shape as ``img``. Only pixels where ``mask == 1`` are used in
        the fit. When omitted, every pixel is used.

    Returns
    -------
    poly_fitted : ndarray
        The fitted surface evaluated on the full pixel grid, shaped like ``img``.
    correlation : float
        Pearson correlation between the fitted pixel values and the surface
        evaluated at those same pixels.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    if mask is not None:
        selected = mask == 1
        sample_x = grid_x[selected].flatten()
        sample_y = grid_y[selected].flatten()
        sample_z = img[selected].flatten()
    else:
        sample_x = grid_x.flatten()
        sample_y = grid_y.flatten()
        sample_z = img.flatten()

    # Start the optimiser from the zero surface (all six coefficients = 0).
    params, _cov = curve_fit(poly, (sample_x, sample_y), sample_z, p0=np.zeros(6))
    surface = poly((grid_x, grid_y), *params).reshape(img.shape)
    correlation, _pvalue = pearsonr(sample_z, poly((sample_x, sample_y), *params))
    return surface, correlation

def SubtractGlobalPoly(img, show=False):
    """
    Remove a global quadratic background (fitted with ``FitPoly``) from ``img``.

    Parameters
    ----------
    img : 2-D array
        Image to flatten.
    show : bool, optional
        If True, show the original and flattened images side by side.

    Returns
    -------
    img_flattened : ndarray
        ``img`` minus the fitted surface.
    poly_fitted : ndarray
        The fitted background surface.
    """
    background, _ = FitPoly(img)
    flattened = img - background

    if show:
        _, axes = plt.subplots(1, 2, figsize=(12, 6))
        panels = ((img, 'Original Image'), (flattened, 'Flattened Image'))
        for axis, (picture, title) in zip(axes, panels):
            axis.imshow(picture, cmap='gray')
            axis.set_title(title)
        plt.show()

    return flattened, background

def SubtractGlobalPlane(img, show=False):
    """
    Remove a global plane background (fitted with ``FitPlane``) from ``img``.

    Parameters
    ----------
    img : 2-D array
        Image to flatten.
    show : bool, optional
        If True, show the original and flattened images side by side.

    Returns
    -------
    img_flattened : ndarray
        ``img`` minus the fitted plane.
    plane_fitted : ndarray
        The fitted background plane.
    """
    background, _ = FitPlane(img)
    flattened = img - background

    if show:
        _, axes = plt.subplots(1, 2, figsize=(12, 6))
        panels = ((img, 'Original Image'), (flattened, 'Flattened Image'))
        for axis, (picture, title) in zip(axes, panels):
            axis.imshow(picture, cmap='gray')
            axis.set_title(title)
        plt.show()

    return flattened, background

def FitOffsetToFlattingImageByDiffAndMask(img, mask=None):
    """
    Align the rows (scan lines) of an image by removing row-to-row offsets.

    Rows are processed top to bottom. For row ``i`` the offset is the median
    of the pixel-wise difference between row ``i`` and the already corrected
    row ``i - 1``. That offset is then subtracted from row ``i``. Because
    each row is compared with its corrected predecessor, all rows end up
    aligned to the first one.

    Parameters
    ----------
    img : 2-D array
        Input image; it is not modified. Non-floating (e.g. integer) input
        is converted to float so fractional offsets can be subtracted.
    mask : 2-D array, optional
        Same shape as ``img``. A column contributes to the offset between
        rows ``i - 1`` and ``i`` only if ``mask == 1`` in BOTH rows. If no
        column qualifies, the offset for that row is 0.

    Returns
    -------
    ndarray
        The offset-corrected image. When ``mask`` is given, pixels where
        ``mask == 0`` are set to 0 in the output.
    """
    offset_img = np.array(img, copy=True)
    # In-place subtraction of a float median from an integer array raises a
    # casting error, so promote non-floating input to float first.
    if not np.issubdtype(offset_img.dtype, np.inexact):
        offset_img = offset_img.astype(float)

    if mask is not None:
        mask = np.asarray(mask)
        valid = mask == 1

    for i in range(1, offset_img.shape[0]):
        row_diff = offset_img[i, :] - offset_img[i - 1, :]
        if mask is not None:
            # Ignore a column if it is masked out in either of the two rows,
            # so features excluded by the mask cannot bias the median.
            row_diff = row_diff[valid[i, :] & valid[i - 1, :]]
        offset = np.median(row_diff) if row_diff.size > 0 else 0
        offset_img[i, :] -= offset

    if mask is not None:
        offset_img[mask == 0] = 0

    return offset_img




In [None]:
"""
Author: Huanhuan Zhao
Date of Creation: November 2024

Description:
This module contains functions for loading the data from Gwyddion (.gwy) files, preprocessing the images, and computing differential (gradient) images along the x-axis.

"""


import gwyfile
import cv2

def load_data(img_path, channel='Forward - Topography'):
    '''
    Load one data channel from a Gwyddion (.gwy) file.

    Parameters
    ----------
    img_path : str
        Path to the .gwy file.
    channel : str, optional
        Title of the data field to read. Defaults to the forward topography
        channel.

    Returns
    -------
    The ``.data`` attribute of the selected data field.

    Raises
    ------
    KeyError
        If the file contains no data field with that title. The message
        lists the titles that are available.
    '''
    obj = gwyfile.load(img_path)
    channels = gwyfile.util.get_datafields(obj)
    try:
        field = channels[channel]
    except KeyError:
        raise KeyError(
            f'Channel {channel!r} not found in {img_path}; '
            f'available channels: {list(channels)}'
        ) from None

    return field.data


def load_amplitude(file, channel='Forward - Z-Controller In'):
    '''
    Load the signal used as the "amplitude" image from a Gwyddion (.gwy) file.

    NOTE: by default this reads the 'Forward - Z-Controller In' channel, not
    an analyzer amplitude channel. To read the analyzer amplitude instead,
    pass ``channel='Forward - Analyzer 1 Amplitude'``.

    Parameters
    ----------
    file : str
        Path to the .gwy file.
    channel : str, optional
        Title of the data field to read.

    Returns
    -------
    The ``.data`` attribute of the selected data field.

    Raises
    ------
    KeyError
        If the file contains no data field with that title. The message
        lists the titles that are available.
    '''
    obj = gwyfile.load(file)
    channels = gwyfile.util.get_datafields(obj)
    try:
        field = channels[channel]
    except KeyError:
        raise KeyError(
            f'Channel {channel!r} not found in {file}; '
            f'available channels: {list(channels)}'
        ) from None

    return field.data

def flatten_img(data):
    '''
    Flatten an image in two passes.

    1. Align rows with ``FitOffsetToFlattingImageByDiffAndMask`` (no mask).
    2. Fit a plane to the aligned image with ``FitPlane`` and subtract it.

    Returns the aligned image minus the fitted plane.
    '''
    aligned = FitOffsetToFlattingImageByDiffAndMask(data, mask=None)
    background = FitPlane(aligned)[0]
    return aligned - background

def normalize_img(img_flattened, target_width=1024):
    '''
    Convert an image to 8-bit and upsample it to a fixed width.

    The image is min-max normalised to the 0..255 range as ``uint8``. It is
    then resized with bicubic interpolation so that its width is
    ``target_width``; the height is scaled by the same factor, preserving the
    aspect ratio. A higher resolution helps the later feature-detection step.

    Parameters
    ----------
    img_flattened : 2-D array
        Image to convert (e.g. a flattened topography).
    target_width : int, optional
        Width of the output image in pixels (default 1024).

    Returns
    -------
    ndarray of uint8
        The normalised, resized image.
    '''
    # Scale up the (small) raw values. Min-max normalisation itself does not
    # depend on a global scale factor; the factor is kept so results match
    # the established pipeline.
    scaled = img_flattened * 1e9
    img_8bit = cv2.normalize(scaled, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    height, width = img_8bit.shape[0], img_8bit.shape[1]
    new_width = int(target_width)
    # Use round() rather than int(): int(w * (1024 / w)) can truncate to 1023
    # (e.g. 1024 / 49 * 49 == 1023.9999999999999 in floating point).
    new_height = max(1, round(height * new_width / width))

    # Resize the image using INTER_CUBIC for better quality
    image_high_res = cv2.resize(img_8bit, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

    return image_high_res

def diff(file):
    '''
    Return the x-derivative of the topography in a .gwy file as an 8-bit image.

    The 'Forward - Topography' channel is read with ``load_data``. Its
    horizontal gradient is computed with ``cv2.Sobel`` and converted to an
    upsampled 8-bit image with ``normalize_img``.
    '''
    topology = load_data(file)
    # ksize=1: OpenCV uses a 1x3 derivative kernel with no Gaussian smoothing.
    grad_x = cv2.Sobel(topology, cv2.CV_64F, 1, 0, ksize=1)
    grad_x_norm = normalize_img(grad_x)
    return grad_x_norm


In [None]:
"""
Generating flattened topographical images

"""

path = '/home/huanhuan/Desktop/stitch_papers/26th_ruben/overlap10'
files= os.listdir(path)
images = [i for i in files if i.endswith('.gwy')]
images_path = [os.path.join(path, i) for i in images]
images = []
for i in images_path:
    filename = i.split('.') [0].split('/')[-1]
    data = load_data(i)
    img_flattened = flatten_img(data)
    image_high_res = normalize_img(img_flattened)
    images.append(image_high_res)

    cv2.imwrite(f'{path}/topography/{filename}.png', image_high_res)

    fig, ax = plt.subplots(figsize=(3.8, 3.8))

    im = ax.imshow(data, cmap="gray")

    plt.imshow(image_high_res)

    plt.close()

In [None]:
"""
Generating amplitude images

"""

path = '/home/huanhuan/Desktop/stitch_papers/26th_ruben/overlap10'
files= os.listdir(path)
images = [i for i in files if i.endswith('.gwy')]
images_path = [os.path.join(path, i) for i in images]
images = []
for i in images_path:
    filename = i.split('.') [0].split('/')[-1]
    amplitude = load_amplitude(i)
    amplitude_norm = normalize_img(amplitude)
    images.append(amplitude_norm)

    cv2.imwrite(f'{path}/amplitude/{filename}.png', amplitude_norm)

    fig, ax = plt.subplots(figsize=(3.8, 3.8))

    im = ax.imshow(amplitude_norm, cmap="gray")

    plt.imshow(amplitude_norm)

    plt.close()

In [None]:
"""
Generating differential images

"""

path = '/home/huanhuan/Desktop/stitch_papers/figures for paper/6th_7*7'
files= os.listdir(path)
images = [i for i in files if i.endswith('.gwy')]
images_path = [os.path.join(path, i) for i in images]
images = []
for i in images_path:
    filename = i.split('.') [0].split('/')[-1]
    grad_x_norm = diff(i)

    cv2.imwrite(f'{path}/grad_x/{filename}.png', grad_x_norm)

    fig, ax = plt.subplots(figsize=(3.8, 3.8))

    im = ax.imshow(grad_x_norm, cmap="gray")

    plt.imshow(grad_x_norm)

    plt.close()