In [None]:
import matplotlib.pyplot as plt
import numpy as np
import nanonispy as nap
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
from PIL.PngImagePlugin import PngInfo
import scipy.linalg
from mpl_toolkits.mplot3d import Axes3D
import os
import sys
from IPython.display import display, clear_output
import time

def read_metadata(sxm_file, temperature = 5):
    """Load a Nanonis .sxm file and extract the scan parameters used for labelling.

    Parameters
    ----------
    sxm_file : str
        Path to the .sxm file.
    temperature : float, optional
        Temperature in K reported in ``scan_info_2``. It is not read from the
        header (see the commented-out header lookup below); defaults to 5 K.

    Returns
    -------
    list
        [scan_time, scan_date, scan_dir, scan_bias, scan_fb, scan_width, scan_temp,
         scan_fb_control, scan_pixels, pixel_size, scan_info_1, scan_info_2]
        with scan_fb in pA, scan_width in nm and pixel_size in nm/pixel.

    Raises
    ------
    SystemExit
        If the file cannot be loaded or a header field is missing or unparseable.
    """
    try:
        data = nap.read.Scan(sxm_file)
        print("SXM image file (" + sxm_file + ") loaded successfully.")

        # Extract the metadata
        scan_time = data.header["rec_time"]
        scan_date = data.header["rec_date"]
        scan_dir = data.header["scan_dir"]
        scan_bias = float(data.header["bias"])
        scan_fb = float(data.header["z-controller>setpoint"]) * 1E12        # feedback current in pA
        scan_width = float(data.header["scan_range"][0]) * 1E9              # Scan width in nm
        #scan_temp = float(data.header["temperature 1>temperature 1 (k)"])  
        scan_fb_control = data.header["z-controller>controller status"]     # feedback ON or OFF
        scan_pixels = int(data.header["scan_pixels"][0])                    # number of pixels in the x-direction of the scan
        pixel_size = scan_width / scan_pixels                               # nm/pixel
        scan_temp = temperature

        if scan_fb_control == "ON":
            scan_info_1 = "STM topographic scan (V_bias = " + str(scan_bias) + " V, I_fb = " + str(round(scan_fb, 3)) + " pA, width = " + str(round(scan_width, 3)) + " nm)"
        else:
            scan_info_1 = "Constant height scan (V_bias = " + str(scan_bias) + " V, width = " + str(round(scan_width, 3)) + " nm)"

        scan_info_2 = "Scan recorded on " + scan_date + " at " + scan_time + " at a temperature of " + str(round(scan_temp, 3)) + " K"

    except Exception as e:
        print(f"Error loading sxm file: {e}")
        # sys.exit always raises SystemExit. The bare `exit` helper is meant for the
        # interactive shell and is not guaranteed to stop execution (e.g. inside a
        # Jupyter kernel), which would fall through to the return with unbound names.
        sys.exit(1)

    return [scan_time, scan_date, scan_dir, scan_bias, scan_fb, scan_width, scan_temp, scan_fb_control, scan_pixels, pixel_size, scan_info_1, scan_info_2]

def crop_nan(matrix):
    """Trim leading and trailing rows that contain NaN values from a 2-D array.

    Only rows are cropped (columns are left untouched), and only from the top and
    bottom edges: the result runs from the first NaN-free row to the last NaN-free
    row inclusive, so an interior row containing NaN is kept.

    Parameters
    ----------
    matrix : 2-D numpy array.

    Returns
    -------
    A view of ``matrix`` restricted to that row range, or ``np.array([])`` when no
    row is NaN-free (including when ``matrix`` has no rows).
    """
    valid_rows = np.flatnonzero(~np.isnan(matrix).any(axis = 1))

    if valid_rows.size == 0:
        return np.array([])

    return matrix[valid_rows[0]:valid_rows[-1] + 1]

def cropshiftflip(matrix, scan_direction):
    """Crop unscanned edge rows, shift the minimum to zero, and orient the image.

    Parameters
    ----------
    matrix : 2-D array of channel data; leading/trailing rows containing NaN
        are removed via ``crop_nan``.
    scan_direction : scan direction string from the file header; when it equals
        "up" the rows are flipped (presumably so up- and down-scans share the same
        orientation — confirm), any other value leaves the row order unchanged.

    Returns
    -------
    The cropped array with its minimum value shifted to 0.

    Raises
    ------
    ValueError
        If no row of ``matrix`` is free of NaN values.
    """
    cropped_matrix = crop_nan(matrix)
    if cropped_matrix.size == 0:
        raise ValueError("The scan contains no complete (NaN-free) rows")

    # Subtract the minimum so the lowest point sits at zero (removes the absolute offset)
    shifted_matrix = cropped_matrix - np.min(cropped_matrix)

    if scan_direction == "up":
        return np.flipud(shifted_matrix)
    return shifted_matrix

def least_squares(X, Y, scan_data, order):
    """Fit a polynomial background surface to gridded data by linear least squares.

    Parameters
    ----------
    X, Y : arrays with the same shape as ``scan_data`` holding the x and y
        coordinate of every sample (e.g. from ``np.meshgrid``).
    scan_data : array of measured values on that grid.
    order : 1 for a plane     z = c0*x + c1*y + c2,
            2 for a quadric   z = c0 + c1*x + c2*y + c3*x*y + c4*x**2 + c5*y**2.

    Returns
    -------
    The fitted surface evaluated at every grid point, with the shape of ``X``.

    Raises
    ------
    ValueError
        If ``order`` is not 1 or 2.
    """
    XX = np.ravel(X)
    YY = np.ravel(Y)
    ZZ = np.ravel(scan_data)
    ones = np.ones(XX.size)

    if order == 1:
        # best-fit linear plane
        A = np.c_[XX, YY, ones]
    elif order == 2:
        # best-fit quadratic surface; the design matrix is built from the flattened
        # grid coordinates so the same basis is used for fitting and evaluation
        A = np.c_[ones, XX, YY, XX * YY, XX ** 2, YY ** 2]
    else:
        raise ValueError(f"Unsupported fit order {order!r}; expected 1 or 2")

    C, _, _, _ = scipy.linalg.lstsq(A, ZZ)    # coefficients

    # evaluate the fit on the grid
    return (A @ C).reshape(np.shape(X))

def matrix_normalize(matrix):
    """Linearly rescale ``matrix`` to the interval [0, 1].

    Returns
    -------
    (matrix_norm, data_range)
        ``matrix_norm`` is ``(matrix - min) / (max - min)`` and ``data_range`` is
        ``max - min`` in the units of the input. A constant matrix
        (``data_range == 0``) is returned as all zeros rather than dividing by
        zero, which would fill the result with NaN.
    """
    matrix_min = np.min(matrix)
    data_range = np.max(matrix) - matrix_min

    matrix_norm = matrix - matrix_min
    if data_range == 0:
        return np.zeros_like(matrix_norm, dtype = float), data_range

    return matrix_norm / data_range, data_range

def matrix_save_and_show(matrix, output_filename):
    """Write ``matrix`` to ``output_filename`` as a grayscale image and display it inline.

    The data is multiplied by 255 before plotting; with a colormap both imsave and
    imshow autoscale to the data limits, so the factor does not change the picture.
    """
    scaled = matrix * 255

    # imsave renders into its own figure, independent of the pyplot figure shown below
    plt.imsave(output_filename, scaled, cmap = "gray")
    print("Image saved as " + output_filename)

    plt.imshow(scaled, cmap = "gray")
    plt.axis("off")
    plt.show()

def add_metadata_to_png(filename, metadata):
    """Rewrite the PNG at ``filename`` in place with extra text chunks attached.

    Parameters
    ----------
    filename : path of an existing PNG image; it is overwritten.
    metadata : mapping whose key/value pairs are each stored as a PNG text chunk
        (values are passed straight to ``PngInfo.add_text``).
    """
    png_info = PngInfo()
    for key, value in metadata.items():
        png_info.add_text(key, value)

    # Use a context manager so the file handle opened by Image.open is released,
    # and load the pixel data fully before writing back over the same path.
    with Image.open(filename) as img:
        img.load()
        img.save(filename, pnginfo = png_info)

def process_and_save(data_matrix, scan_width, subtraction, output_filename):
    """Optionally background-subtract a 2-D height map, then save and show it as grayscale.

    Parameters
    ----------
    data_matrix : 2-D array of heights (the caller passes Z in nm).
    scan_width : physical scan width in nm; assumed to span the columns of
        ``data_matrix`` (rows are scan lines) — confirm for rotated scans.
    subtraction : "none" to keep the data as-is, or "plane" to subtract a global
        least-squares plane before saving.
    output_filename : path of the PNG to write.

    Returns
    -------
    str
        Text describing the full data range (max - min) of the saved image in nm.

    Raises
    ------
    ValueError
        If ``subtraction`` is not "none" or "plane".
    """
    if subtraction not in ("none", "plane"):
        raise ValueError(f"Unknown subtraction mode {subtraction!r}; expected 'none' or 'plane'")

    if subtraction == "none":
        processed = data_matrix
        description = "This is the raw Z image"
    else:
        # Build physical x/y coordinates for every pixel: columns span scan_width and
        # rows span the same pixel size times the (possibly cropped) number of rows.
        n_rows, n_cols = np.shape(data_matrix)
        pixel_size = scan_width / n_cols          # nm per pixel
        scan_height = pixel_size * n_rows         # nm
        Y_grid, X_grid = np.meshgrid(
            np.linspace(-scan_height / 2, scan_height / 2, n_rows),
            np.linspace(-scan_width / 2, scan_width / 2, n_cols),
            indexing = "ij"                       # grids have the shape of data_matrix
        )
        processed = data_matrix - least_squares(X_grid, Y_grid, data_matrix, 1)
        description = "This is the Z image subjected to a global plane subtraction"

    matrix_norm, data_range = matrix_normalize(processed)
    matrix_save_and_show(matrix_norm, output_filename)
    print(description)

    return "The full range of this channel is " + str(round(data_range, 3)) + " nm"

In [22]:
# Processing preferences used by the file-processing cell
scandirection = "forward"      # key into the channel's data; presumably "forward" or "backward" (nanonispy naming — confirm)
targetchannel = "Z"            # channel to export; only "Z" is handled by the processing code
planesubtraction = "plane"     # background treatment before saving: "none" or "plane"

In [None]:
# Select a file
root = tk.Tk()
root.withdraw()  # Hide the main window
root.attributes("-topmost", True)  # Keep the file dialog on top

file_path = filedialog.askopenfilename(
    title = "Select an .sxm file",
    filetypes = [("SXM files and spectra", "*.sxm *.dat")]
)
root.destroy()  # The hidden root window is only needed for the dialog

# A cancelled dialog returns an empty path. Without this guard, directory would be ""
# and the output folder would be created relative to the filesystem root.
if not file_path:
    print("No file selected.")
    sys.exit(1)

directory, filename = os.path.split(file_path)
stem, extension = os.path.splitext(filename)
# Keep a trailing separator so save_directory + "name" still forms a valid path
save_directory = os.path.join(directory, "Processed", "")
os.makedirs(save_directory, exist_ok = True)

# The selected file is a spectroscopy file
if extension.lower() == ".dat":
    try:
        data = nap.read.Spec(file_path)
        print("Spectroscopy file loaded successfully.")

        V_t = data.signals["Bias (V)"]
        I_t = data.signals["Current (A)"] * 1e12   # A -> pA

        IV = np.array([V_t, I_t])   # not used in this cell; kept in case later cells rely on it

        # Derive the name from the stem so ".dat" elsewhere in the name is left untouched
        spectroscopy_filename = os.path.join(save_directory, stem + "_IV.svg")

        fig, ax = plt.subplots()
        ax.plot(V_t, I_t)
        ax.set(xlabel = "bias (V)", ylabel = "current (pA)", title = "I(V)")
        fig.savefig(spectroscopy_filename)
        print("I(V) spectrum saved as " + spectroscopy_filename)
        plt.show()

    except Exception as e:
        print(f"Error loading spectroscopy file: {e}")
        # sys.exit reliably raises SystemExit; the interactive `exit` helper does not
        # necessarily stop execution inside a Jupyter kernel
        sys.exit(1)

# The selected file is an sxm image file
elif extension.lower() == ".sxm":
    (scan_time, scan_date, scan_dir, scan_bias, scan_fb, scan_width, scan_temp,
     scan_fb_control, scan_pixels, pixel_size, scan_info_1, scan_info_2) = read_metadata(file_path)
    print(scan_info_1)
    print(scan_info_2)

    try:
        data = nap.read.Scan(file_path)
        channels = data.signals.keys()

        print("The following scan channels were found:\n", list(channels))

        # Only process Z when the file actually contains a "Z" channel
        if targetchannel == "Z" and "Z" in channels:
            print("Z found")

            # Extract the Z channel (scaled to nm), crop off rows that have not been scanned, and shift to remove the offset
            Z = cropshiftflip(data.signals["Z"][scandirection] * 1E9, scan_dir)
            output_filename = os.path.join(save_directory, stem + "_Z.png")

            range_info = process_and_save(Z, scan_width, planesubtraction, output_filename)
            add_metadata_to_png(output_filename, {"1": scan_info_1, "2": scan_info_2, "3": range_info})
            print(range_info)
        else:
            print("Channel " + targetchannel + " was not found or is not supported.")

    except Exception as e:
        print(f"Error loading sxm file: {e}")
        sys.exit(1)

else:
    print("Unsupported file type: " + extension)