# Load packages

In [None]:
import os
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

import cv2

from aespm import *
from utils import *
from tools import *

In [None]:
# Import packages for VAE
import atomai as aoi
import cv2
import torch
import torch.nn as nn
tt = torch.tensor  # short alias for the torch tensor constructor

import gpax
import jax.numpy as jnp
# NOTE(review): presumably switches JAX to 64-bit precision — confirm against the gpax docs
gpax.utils.enable_x64()
import pickle

# Make the connection and pre-define functions

In [None]:
# Connection credentials passed to Experiment below (empty placeholders here).
host = ''
username = ''
password = ''

# Data folder handed to the Experiment object.
folder = r"C:\Users\Asylum User\Documents\Asylum Research Data\240410"

# Experiment handle used by every later cell.
exp = Experiment(folder=folder, connection=[host, username, password])

## Custom functions for the workflow

In [None]:
## Commonly used custom functions

def load_ibw(self, folder="C:\\Users\\Asylum User\\Documents\\AEtesting\\data_exchange"):
    '''
    Read the first ibw file that ``get_files`` reports for ``folder``.

    Args:
        folder: Directory to look in.

    Returns:
        Whatever ``ibw_read`` returns for that file.
    '''
    available = get_files(path=folder, client=self.client)
    # NOTE(review): assumes get_files lists the latest file first — confirm
    first_file = available[0]
    return ibw_read(first_file, copy=False, connection=self.connection)

exp.add_func(load_ibw)

def convert_coord(self, data, coord):
    '''
    Convert pixel coordinates to scan distances.

    The points are first rotated by minus the scan angle about the image
    centre, then scaled by the size-per-pixel factors read from the header.

    Args:
        data: Loaded image object exposing ``header`` (dict-like) and
            ``data`` (channel stack; channel 0 supplies the image shape).
        coord: Pair ``[x, y]`` of equal-length sequences of pixel coordinates.

    Returns:
        List of length-2 arrays ``[x_dist, y_dist]``, sorted by the second
        component first and by the first component second.

    Side effects:
        Stores the scan geometry (angle, pixel counts, sizes, factors) on the
        experiment through ``self.update_param``.
    '''
    x, y = coord
    # Work in float so rotated coordinates are not truncated for integer input.
    pixels = np.column_stack([np.asarray(x, dtype=float),
                              np.asarray(y, dtype=float)])

    scan_angle = data.header['ScanAngle']

    img = data.data[0]  # This is the height channel

    # Convert angle to radians
    theta_rad = np.radians(-scan_angle)

    # Create 2D rotation matrix
    rot_matrix = np.array([[np.cos(theta_rad), -np.sin(theta_rad)],
                           [np.sin(theta_rad), np.cos(theta_rad)]])

    # Rotate every point about the image centre in one vectorized step
    center = (np.array(np.shape(img)) - 1) // 2
    rotated = (pixels - center) @ rot_matrix.T + center

    # Convert the pixels to the distance
    # NOTE(review): x is paired with 'PointsLines' and y with 'ScanPoints';
    # this only matters for non-square scans — confirm the pairing.
    xpixels, ypixels = data.header['PointsLines'], data.header['ScanPoints']
    xsize, ysize = data.header['FastScanSize'], data.header['SlowScanSize']

    xfactor = xsize / xpixels
    yfactor = ysize / ypixels

    positions = rotated * np.array([xfactor, yfactor])

    # Sort the positions by y first and x second
    pos_sorted = sorted(positions, key=lambda pos: (pos[1], pos[0]))

    scan_params = {
        'ScanAngle': scan_angle,
        'xpixels': xpixels,
        'ypixels': ypixels,
        'xsize': xsize,
        'ysize': ysize,
        'xfactor': xfactor,
        'yfactor': yfactor,
    }

    for key, value in scan_params.items():
        self.update_param(key=key, value=value)

    return pos_sorted

# Register on the experiment so it can be invoked as exp.convert_coord(...)
exp.add_func(convert_coord)

# Keep a handle on the library-level ``move_tip`` (expected to come from the
# star imports at the top of the file) before the name is rebound by the
# wrapper below; otherwise the wrapper would end up calling itself.
# The guard keeps the handle intact if this cell is run more than once.
if '_spm_move_tip' not in globals():
    _spm_move_tip = move_tip


# Function to move the probe with the given displacement
def move_tip(self, distance, v0=None, s=None):
    '''
    Move the probe by ``distance`` using the library-level ``move_tip``.

    Args:
        distance: Displacement forwarded as the ``r`` argument.
        v0: Forwarded unchanged as ``v0``.
        s: Forwarded unchanged as ``s``.
    '''
    # Enable the stage move --> 5 sec, 8 seconds for safety
    _spm_move_tip(r=distance, v0=v0, s=s, connection=self.connection)

exp.add_func(move_tip)

# Function to move the probe to the location r and start force distance measurement.
def measure_fd(self, fname, r, v0=None):
    '''
    Move the probe to the location r and start force distance measurement.

    Args:
        fname: Base file name passed to the 'ChangeName' action.
        r: Displacement handed to the custom 'move_tip' action.
        v0: Forwarded to 'move_tip' as its ``v0`` argument.

    Each entry of the action list has three items; judging by how they are
    used here they are (action name, value, wait).
    Reads ``self.param['sensitivity']``, which must have been set beforehand.
    '''
    action_list = [
        ['ChangeName', fname, None], # Change file names
        ['ClearForce', None, None], # Clear any existing force points
        ['GoThere', None, 1], # Move to the center of the image
        # Move the tip to location r. A list literal cannot hold keyword
        # arguments, so the move_tip arguments go positionally in the value
        # list as (distance, v0, s).
        # NOTE(review): assumes execute_sequence unpacks a list value into
        # positional arguments of the custom action — confirm.
        ['move_tip', [r, v0, self.param['sensitivity']], None],
        ['SingleForce', None, None], # Start a DART spec
        ['check_files', None, 1], # Check file numbers in the data save folder
    ]

    self.execute_sequence(action_list)
    
# Register on the experiment so it can be invoked as exp.measure_fd(...)
exp.add_func(measure_fd)
    
# Function to check the file number in a given folder
def check_files(self):
    '''
    Return the result of ``check_file_number`` for the experiment's folder.
    '''
    file_count = check_file_number(path=self.folder,
                                   connection=self.connection)
    return file_count
exp.add_func(check_files)

## Custom functions for VAE

In [None]:
# Help function to crop images into patches
from typing import Tuple, Optional, Dict, Union, List

def get_imgstack(imgdata: np.ndarray,
                 coord: np.ndarray,
                 r: int) -> Tuple[np.ndarray]:
    """
    Crop square windows out of one image around the given centers.

    Args:
        imgdata (numpy array):
            Image whose first two axes are cropped
        coord (N x 2 numpy array):
            Window centers; component 0 indexes the first image axis and
            component 1 the second (values are rounded to integers)
        r (int):
            Window size in pixels
    Returns:
        2-element tuple containing
        - Stack of the windows that were kept
        - The centers belonging to those windows
        Both elements are None when no window was kept. A window is dropped
        when it does not have the full r x r shape or contains NaN.
    """
    half = r // 2
    # Odd sizes span [c - half, c + half]; even sizes span [c - half, c + half)
    reach = half + 1 if r % 2 != 0 else half
    wanted_shape = (int(r), int(r))

    kept_windows, kept_centers = [], []
    for center in coord:
        row = int(np.around(center[0]))
        col = int(np.around(center[1]))
        window = np.copy(imgdata[row - half:row + reach,
                                 col - half:col + reach])
        if window.shape[0:2] != wanted_shape:
            continue
        if np.isnan(window).any():
            continue
        kept_windows.append(window[None, ...])
        kept_centers.append(center[None, ...])

    if not kept_windows:
        return None, None
    return (np.concatenate(kept_windows, axis=0),
            np.concatenate(kept_centers, axis=0))


def extract_subimages(imgdata: np.ndarray,
                      coordinates: Union[Dict[int, np.ndarray], np.ndarray],
                      window_size: int, coord_class: int = 0) -> Tuple[np.ndarray]:
    """
    Crop windows of a given size around coordinates, frame by frame.

    Args:
        imgdata: Image stack iterated along its first axis; a 2D image is
            expanded to a single frame with one trailing channel axis.
        coordinates: Either a dictionary of per-frame arrays whose third
            column holds a class label, or a plain array of coordinates
            (a zero class column is appended and it is treated as frame 0).
        window_size: Side length of the square window.
        coord_class: Only coordinates whose class column equals this value
            are used.
    Returns:
        3-element tuple: window stack, window centers, and the frame index
        of every window. When nothing was extracted the three elements are
        empty lists.
    """
    if isinstance(coordinates, np.ndarray):
        class_column = np.zeros((coordinates.shape[0], 1))
        coordinates = {0: np.concatenate((coordinates, class_column), axis=-1)}
    if np.ndim(imgdata) == 2:
        imgdata = imgdata[None, ..., None]

    stacks, centers, frame_ids = [], [], []
    frames = zip(imgdata, coordinates.values())
    for frame_no, (frame, frame_coord) in enumerate(frames):
        in_class = frame_coord[:, 2] == coord_class
        selected = frame_coord[in_class][:, :2]
        stack, com = get_imgstack(frame, selected, window_size)
        if stack is not None:
            stacks.append(stack)
            centers.append(com)
            frame_ids.append(np.ones(len(com), int) * frame_no)

    if len(stacks) > 0:
        stacks = np.concatenate(stacks, axis=0)
        centers = np.concatenate(centers, axis=0)
        frame_ids = np.concatenate(frame_ids, axis=0)

    return stacks, centers, frame_ids

# Acquire an image

In [None]:
# Run the 'DownScan' action, then load the image from the experiment folder
exp.execute('DownScan')

w = exp.load_ibw(folder=exp.folder)

# Show channel 0 of the loaded data
plt.imshow(w.data[0], origin='lower')

# Extract features with VAE

In [None]:
# Normalize the image to the range [0, 1]
img = w.data[0]

s1, s2 = np.shape(img)

# np.ptp(...) is used because the ndarray.ptp() method is gone in NumPy 2.0
image = (img - img.min()) / np.ptp(img)

In [None]:
# We need to crop the whole image into small patches before running the VAE

# skip is the distance between the centers of neighboring patches
skip = 1
coordinates = aoi.utils.get_coord_grid(image[::skip,::skip], step = 1, return_dict=False)
# coordinates

In [None]:
import random
import matplotlib.patches as p

# image patch size
window_size = 10
patches, coords, _ = extract_subimages(image[::skip, ::skip], coordinates, window_size)
patches = patches.squeeze()
# A bare expression in the middle of a cell displays nothing, so print it
print(np.shape(patches), np.shape(coords))

# Let's visualize the image patches after cropping
# randrange excludes len(patches); randint includes its upper bound and
# could pick an index one past the end
img_index = [random.randrange(len(patches)) for _ in range(16)]

fig,ax=plt.subplots()
ax.imshow(img)

# coords holds patch centers, while Rectangle is anchored at its corner:
# shift by half a window (plus half a pixel, since imshow centers pixels on
# integer coordinates) and size the box with window_size.
half = window_size // 2
for i in range(16):
    y,x = coords[img_index[i]]
    rec = p.Rectangle((x - half - 0.5, y - half - 0.5), window_size, window_size,
                      linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(rec)

fig,ax=plt.subplots(4,4,figsize=[12,12])
for i in range(16):
    ax[i//4,i%4].imshow(patches[img_index[i]])

In [None]:
# Train an atomai VAE (aoi.models.VAE) with a 2D latent space on the patches
input_dim = (window_size,window_size)
vae = aoi.models.VAE(input_dim, latent_dim=2,
                    numlayers_encoder=3, numhidden_encoder=1024,
                    numlayers_decoder=3, numhidden_decoder=1024,
                    skip=True)

# NOTE(review): filename='vae' presumably names the saved model file — confirm
vae.fit(patches, training_cycles=100, batch_size=75, loss="ce",
         filename='vae')

In [None]:
# Latent representation: visualize the learned 2D manifold
vae.manifold2d(d = 10, origin='lower')

In [None]:
# Encode every patch; the columns m[:, 0] and m[:, 1] are used as latent coordinates
# NOTE(review): s is presumably the matching latent spread — confirm
m, s = vae.encode(patches)

In [None]:
# Decode a specific position in the latent space

z = [-0.6,-0.4]
thres = 0.25
# Indices of the patches whose latent encoding lies within thres of z
index = np.where(np.sqrt((m[:,0]-z[0])**2 + (m[:,1]-z[1])**2) < thres)

plt.figure(figsize=[4,4])
plt.imshow(vae.decode(np.array(z)).squeeze(), origin='lower')
fig,ax=plt.subplots(3,4,figsize=[12, 10])
# Show at most 12 matches; slicing avoids an IndexError when fewer than 12
# patches fall inside the radius
for i, patch_idx in enumerate(index[0][:12]):
    ax[i//4, i%4].imshow(patches[patch_idx], origin='lower')


In [None]:
# Pixel coordinates of the patches selected in the latent space
coord = np.array(coords[index])
# Column 0 is plotted on the vertical axis, column 1 on the horizontal axis
y, x = coord[:,0], coord[:,1]
# Number of selected points
print(len(coord))
plt.imshow(img, origin='lower')
plt.plot(x, y, 'ro')

# Convert the coordinate unit from pixel to distance

In [None]:
# Convert the selected pixel coordinates to distances; this also stores the
# scan geometry (e.g. 'xsize', 'ysize') on exp through update_param
pos_sorted = exp.convert_coord(data=w, coord=[x, y])

# Take force spectrum on the edge points

In [None]:
# Move the probe to the center of the image, and record the X Y sensor readings there
action_list = [
    ['ClearForce', None, None], # Clear any existing force points
    ['GoThere', None, 1], # Move to the center of the image
]

exp.execute_sequence(action_list)

v0 = read_spm(key=['PIDSLoop.0.Setpoint', 'PIDSLoop.1.Setpoint'], connection=exp.connection)

# Center of the image in distance units. measure_fd re-centers the probe
# ('GoThere') before every move, so each displacement is taken from the image
# center rather than from the previously measured point.
pos0 = np.array([exp.param['xsize']/2, exp.param['ysize']/2])

for i, pos_next in enumerate(pos_sorted):
    # Displacement from the image center to the next point
    distance_to_move = pos_next - pos0
    # '{:03d}' zero-pads the index to three digits; without the colon,
    # str.format looks for a keyword named '03d' and raises KeyError
    exp.measure_fd(fname='FD_loc_{:03d}_'.format(i), r=distance_to_move, v0=v0)
    